Unverified Commit 491e9518 authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Move convert_to_rgb to image_transforms module (#20784)

* Move convert_to_rgb to image_transforms module

* Fix tests
parent 4bc723f8
...@@ -20,6 +20,7 @@ import numpy as np ...@@ -20,6 +20,7 @@ import numpy as np
from transformers.image_utils import ( from transformers.image_utils import (
ChannelDimension, ChannelDimension,
ImageInput,
get_channel_dimension_axis, get_channel_dimension_axis,
get_image_size, get_image_size,
infer_channel_dimension_format, infer_channel_dimension_format,
...@@ -687,3 +688,22 @@ def pad( ...@@ -687,3 +688,22 @@ def pad(
image = to_channel_dimension_format(image, data_format) if data_format is not None else image image = to_channel_dimension_format(image, data_format) if data_format is not None else image
return image return image
# TODO (Amy): Accept 1/3/4 channel numpy array as input and return np.array as default
def convert_to_rgb(image: ImageInput) -> ImageInput:
"""
Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
as is.
Args:
image (Image):
The image to convert.
"""
requires_backends(convert_to_rgb, ["vision"])
if not isinstance(image, PIL.Image.Image):
return image
image = image.convert("RGB")
return image
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
"""Image processor class for BiT.""" """Image processor class for BiT."""
from typing import Any, Dict, List, Optional, Union from typing import Dict, List, Optional, Union
import numpy as np import numpy as np
...@@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType ...@@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import ( from ...image_transforms import (
center_crop, center_crop,
convert_to_rgb,
get_resize_output_image_size, get_resize_output_image_size,
normalize, normalize,
rescale, rescale,
...@@ -41,20 +42,6 @@ if is_vision_available(): ...@@ -41,20 +42,6 @@ if is_vision_available():
import PIL import PIL
def convert_to_rgb(image: Union[Any, PIL.Image.Image]) -> Union[Any, PIL.Image.Image]:
"""
Converts `PIL.Image.Image` to RGB format. Images in other formats are returned as is.
Args:
image (`PIL.Image.Image`):
The image to convert.
"""
if not isinstance(image, PIL.Image.Image):
return image
return image.convert("RGB")
class BitImageProcessor(BaseImageProcessor): class BitImageProcessor(BaseImageProcessor):
r""" r"""
Constructs a BiT image processor. Constructs a BiT image processor.
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
"""Image processor class for Chinese-CLIP.""" """Image processor class for Chinese-CLIP."""
from typing import Any, Dict, List, Optional, Union from typing import Dict, List, Optional, Union
import numpy as np import numpy as np
...@@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType ...@@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import ( from ...image_transforms import (
center_crop, center_crop,
convert_to_rgb,
get_resize_output_image_size, get_resize_output_image_size,
normalize, normalize,
rescale, rescale,
...@@ -41,20 +42,6 @@ if is_vision_available(): ...@@ -41,20 +42,6 @@ if is_vision_available():
import PIL import PIL
def convert_to_rgb(image: Union[Any, PIL.Image.Image]) -> Union[Any, PIL.Image.Image]:
"""
Converts `PIL.Image.Image` to RGB format. Images in other formats are returned as is.
Args:
image (`PIL.Image.Image`):
The image to convert.
"""
if not isinstance(image, PIL.Image.Image):
return image
return image.convert("RGB")
class ChineseCLIPImageProcessor(BaseImageProcessor): class ChineseCLIPImageProcessor(BaseImageProcessor):
r""" r"""
Constructs a Chinese-CLIP image processor. Constructs a Chinese-CLIP image processor.
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
"""Image processor class for CLIP.""" """Image processor class for CLIP."""
from typing import Any, Dict, List, Optional, Union from typing import Dict, List, Optional, Union
import numpy as np import numpy as np
...@@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType ...@@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import ( from ...image_transforms import (
center_crop, center_crop,
convert_to_rgb,
get_resize_output_image_size, get_resize_output_image_size,
normalize, normalize,
rescale, rescale,
...@@ -41,20 +42,6 @@ if is_vision_available(): ...@@ -41,20 +42,6 @@ if is_vision_available():
import PIL import PIL
def convert_to_rgb(image: Union[Any, PIL.Image.Image]) -> Union[Any, PIL.Image.Image]:
"""
Converts `PIL.Image.Image` to RGB format. Images in other formats are returned as is.
Args:
image (`PIL.Image.Image`):
The image to convert.
"""
if not isinstance(image, PIL.Image.Image):
return image
return image.convert("RGB")
class CLIPImageProcessor(BaseImageProcessor): class CLIPImageProcessor(BaseImageProcessor):
r""" r"""
Constructs a CLIP image processor. Constructs a CLIP image processor.
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
"""Image processor class for ViT hybrid.""" """Image processor class for ViT hybrid."""
from typing import Any, Dict, List, Optional, Union from typing import Dict, List, Optional, Union
import numpy as np import numpy as np
...@@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType ...@@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import ( from ...image_transforms import (
center_crop, center_crop,
convert_to_rgb,
get_resize_output_image_size, get_resize_output_image_size,
normalize, normalize,
rescale, rescale,
...@@ -41,21 +42,6 @@ if is_vision_available(): ...@@ -41,21 +42,6 @@ if is_vision_available():
import PIL import PIL
# Copied from transformers.models.bit.image_processing_bit.convert_to_rgb
def convert_to_rgb(image: Union[Any, PIL.Image.Image]) -> Union[Any, PIL.Image.Image]:
"""
Converts `PIL.Image.Image` to RGB format. Images in other formats are returned as is.
Args:
image (`PIL.Image.Image`):
The image to convert.
"""
if not isinstance(image, PIL.Image.Image):
return image
return image.convert("RGB")
class ViTHybridImageProcessor(BaseImageProcessor): class ViTHybridImageProcessor(BaseImageProcessor):
r""" r"""
Constructs a ViT Hybrid image processor. Constructs a ViT Hybrid image processor.
......
...@@ -37,6 +37,7 @@ if is_vision_available(): ...@@ -37,6 +37,7 @@ if is_vision_available():
from transformers.image_transforms import ( from transformers.image_transforms import (
center_crop, center_crop,
center_to_corners_format, center_to_corners_format,
convert_to_rgb,
corners_to_center_format, corners_to_center_format,
get_resize_output_image_size, get_resize_output_image_size,
id_to_rgb, id_to_rgb,
...@@ -456,3 +457,32 @@ class ImageTransformsTester(unittest.TestCase): ...@@ -456,3 +457,32 @@ class ImageTransformsTester(unittest.TestCase):
self.assertTrue( self.assertTrue(
np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="reflect", data_format="channels_last")) np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="reflect", data_format="channels_last"))
) )
@require_vision
def test_convert_to_rgb(self):
# Test that an RGBA image is converted to RGB
image = np.array([[[1, 2, 3, 4], [5, 6, 7, 8]]], dtype=np.uint8)
pil_image = PIL.Image.fromarray(image)
self.assertEqual(pil_image.mode, "RGBA")
self.assertEqual(pil_image.size, (2, 1))
# For the moment, numpy images are returned as is
rgb_image = convert_to_rgb(image)
self.assertEqual(rgb_image.shape, (1, 2, 4))
self.assertTrue(np.allclose(rgb_image, image))
# And PIL images are converted
rgb_image = convert_to_rgb(pil_image)
self.assertEqual(rgb_image.mode, "RGB")
self.assertEqual(rgb_image.size, (2, 1))
self.assertTrue(np.allclose(np.array(rgb_image), np.array([[[1, 2, 3], [5, 6, 7]]], dtype=np.uint8)))
# Test that a grayscale image is converted to RGB
image = np.array([[0, 255]], dtype=np.uint8)
pil_image = PIL.Image.fromarray(image)
self.assertEqual(pil_image.mode, "L")
self.assertEqual(pil_image.size, (2, 1))
rgb_image = convert_to_rgb(pil_image)
self.assertEqual(rgb_image.mode, "RGB")
self.assertEqual(rgb_image.size, (2, 1))
self.assertTrue(np.allclose(np.array(rgb_image), np.array([[[0, 0, 0], [255, 255, 255]]], dtype=np.uint8)))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment