"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "cb45f71c4dfb28613f8348716de821df7db68799"
Unverified Commit 11757e2b authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Add input_data_format argument, image transforms (#25462)

* Enable specifying input data format - overriding inferring

* Add tests
parent 0acf5622
...@@ -63,6 +63,8 @@ def to_channel_dimension_format( ...@@ -63,6 +63,8 @@ def to_channel_dimension_format(
The image to have its channel dimension set. The image to have its channel dimension set.
channel_dim (`ChannelDimension`): channel_dim (`ChannelDimension`):
The channel dimension format to use. The channel dimension format to use.
input_channel_dim (`ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred from the input image.
Returns: Returns:
`np.ndarray`: The image with the channel dimension set to `channel_dim`. `np.ndarray`: The image with the channel dimension set to `channel_dim`.
...@@ -88,7 +90,11 @@ def to_channel_dimension_format( ...@@ -88,7 +90,11 @@ def to_channel_dimension_format(
def rescale( def rescale(
image: np.ndarray, scale: float, data_format: Optional[ChannelDimension] = None, dtype=np.float32 image: np.ndarray,
scale: float,
data_format: Optional[ChannelDimension] = None,
dtype: np.dtype = np.float32,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray: ) -> np.ndarray:
""" """
Rescales `image` by `scale`. Rescales `image` by `scale`.
...@@ -103,6 +109,8 @@ def rescale( ...@@ -103,6 +109,8 @@ def rescale(
dtype (`np.dtype`, *optional*, defaults to `np.float32`): dtype (`np.dtype`, *optional*, defaults to `np.float32`):
The dtype of the output image. Defaults to `np.float32`. Used for backwards compatibility with feature The dtype of the output image. Defaults to `np.float32`. Used for backwards compatibility with feature
extractors. extractors.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred from the input image.
Returns: Returns:
`np.ndarray`: The rescaled image. `np.ndarray`: The rescaled image.
...@@ -112,7 +120,7 @@ def rescale( ...@@ -112,7 +120,7 @@ def rescale(
rescaled_image = image * scale rescaled_image = image * scale
if data_format is not None: if data_format is not None:
rescaled_image = to_channel_dimension_format(rescaled_image, data_format) rescaled_image = to_channel_dimension_format(rescaled_image, data_format, input_data_format)
rescaled_image = rescaled_image.astype(dtype) rescaled_image = rescaled_image.astype(dtype)
...@@ -149,6 +157,7 @@ def _rescale_for_pil_conversion(image): ...@@ -149,6 +157,7 @@ def _rescale_for_pil_conversion(image):
def to_pil_image( def to_pil_image(
image: Union[np.ndarray, "PIL.Image.Image", "torch.Tensor", "tf.Tensor", "jnp.ndarray"], image: Union[np.ndarray, "PIL.Image.Image", "torch.Tensor", "tf.Tensor", "jnp.ndarray"],
do_rescale: Optional[bool] = None, do_rescale: Optional[bool] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> "PIL.Image.Image": ) -> "PIL.Image.Image":
""" """
Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
...@@ -161,6 +170,8 @@ def to_pil_image( ...@@ -161,6 +170,8 @@ def to_pil_image(
Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will default Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will default
to `True` if the image type is a floating type and casting to `int` would result in a loss of precision, to `True` if the image type is a floating type and casting to `int` would result in a loss of precision,
and `False` otherwise. and `False` otherwise.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If unset, will use the inferred format from the input.
Returns: Returns:
`PIL.Image.Image`: The converted image. `PIL.Image.Image`: The converted image.
...@@ -179,7 +190,7 @@ def to_pil_image( ...@@ -179,7 +190,7 @@ def to_pil_image(
raise ValueError("Input image type not supported: {}".format(type(image))) raise ValueError("Input image type not supported: {}".format(type(image)))
# If the channel has been moved to first dim, we put it back at the end. # If the channel has been moved to first dim, we put it back at the end.
image = to_channel_dimension_format(image, ChannelDimension.LAST) image = to_channel_dimension_format(image, ChannelDimension.LAST, input_data_format)
# If there is a single channel, we squeeze it, as otherwise PIL can't handle it. # If there is a single channel, we squeeze it, as otherwise PIL can't handle it.
image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image
...@@ -200,6 +211,7 @@ def get_resize_output_image_size( ...@@ -200,6 +211,7 @@ def get_resize_output_image_size(
size: Union[int, Tuple[int, int], List[int], Tuple[int]], size: Union[int, Tuple[int, int], List[int], Tuple[int]],
default_to_square: bool = True, default_to_square: bool = True,
max_size: Optional[int] = None, max_size: Optional[int] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> tuple: ) -> tuple:
""" """
Find the target (height, width) dimension of the output image after resizing given the input image and the desired Find the target (height, width) dimension of the output image after resizing given the input image and the desired
...@@ -225,6 +237,8 @@ def get_resize_output_image_size( ...@@ -225,6 +237,8 @@ def get_resize_output_image_size(
than `max_size` after being resized according to `size`, then the image is resized again so that the longer than `max_size` after being resized according to `size`, then the image is resized again so that the longer
edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller edge may be shorter edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller edge may be shorter
than `size`. Only used if `default_to_square` is `False`. than `size`. Only used if `default_to_square` is `False`.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If unset, will use the inferred format from the input.
Returns: Returns:
`tuple`: The target (height, width) dimension of the output image after resizing. `tuple`: The target (height, width) dimension of the output image after resizing.
...@@ -241,7 +255,7 @@ def get_resize_output_image_size( ...@@ -241,7 +255,7 @@ def get_resize_output_image_size(
if default_to_square: if default_to_square:
return (size, size) return (size, size)
height, width = get_image_size(input_image) height, width = get_image_size(input_image, input_data_format)
short, long = (width, height) if width <= height else (height, width) short, long = (width, height) if width <= height else (height, width)
requested_new_short = size requested_new_short = size
...@@ -266,6 +280,7 @@ def resize( ...@@ -266,6 +280,7 @@ def resize(
reducing_gap: Optional[int] = None, reducing_gap: Optional[int] = None,
data_format: Optional[ChannelDimension] = None, data_format: Optional[ChannelDimension] = None,
return_numpy: bool = True, return_numpy: bool = True,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray: ) -> np.ndarray:
""" """
Resizes `image` to `(height, width)` specified by `size` using the PIL library. Resizes `image` to `(height, width)` specified by `size` using the PIL library.
...@@ -285,6 +300,8 @@ def resize( ...@@ -285,6 +300,8 @@ def resize(
return_numpy (`bool`, *optional*, defaults to `True`): return_numpy (`bool`, *optional*, defaults to `True`):
Whether or not to return the resized image as a numpy array. If False a `PIL.Image.Image` object is Whether or not to return the resized image as a numpy array. If False a `PIL.Image.Image` object is
returned. returned.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If unset, will use the inferred format from the input.
Returns: Returns:
`np.ndarray`: The resized image. `np.ndarray`: The resized image.
...@@ -298,14 +315,16 @@ def resize( ...@@ -298,14 +315,16 @@ def resize(
# For all transformations, we want to keep the same data format as the input image unless otherwise specified. # For all transformations, we want to keep the same data format as the input image unless otherwise specified.
# The resized image from PIL will always have channels last, so find the input format first. # The resized image from PIL will always have channels last, so find the input format first.
data_format = infer_channel_dimension_format(image) if data_format is None else data_format if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
data_format = input_data_format if data_format is None else data_format
# To maintain backwards compatibility with the resizing done in previous image feature extractors, we use # To maintain backwards compatibility with the resizing done in previous image feature extractors, we use
# the pillow library to resize the image and then convert back to numpy # the pillow library to resize the image and then convert back to numpy
do_rescale = False do_rescale = False
if not isinstance(image, PIL.Image.Image): if not isinstance(image, PIL.Image.Image):
do_rescale = _rescale_for_pil_conversion(image) do_rescale = _rescale_for_pil_conversion(image)
image = to_pil_image(image, do_rescale=do_rescale) image = to_pil_image(image, do_rescale=do_rescale, input_data_format=input_data_format)
height, width = size height, width = size
# PIL images are in the format (width, height) # PIL images are in the format (width, height)
resized_image = image.resize((width, height), resample=resample, reducing_gap=reducing_gap) resized_image = image.resize((width, height), resample=resample, reducing_gap=reducing_gap)
...@@ -330,6 +349,7 @@ def normalize( ...@@ -330,6 +349,7 @@ def normalize(
mean: Union[float, Iterable[float]], mean: Union[float, Iterable[float]],
std: Union[float, Iterable[float]], std: Union[float, Iterable[float]],
data_format: Optional[ChannelDimension] = None, data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray: ) -> np.ndarray:
""" """
Normalizes `image` using the mean and standard deviation specified by `mean` and `std`. Normalizes `image` using the mean and standard deviation specified by `mean` and `std`.
...@@ -345,12 +365,15 @@ def normalize( ...@@ -345,12 +365,15 @@ def normalize(
The standard deviation to use for normalization. The standard deviation to use for normalization.
data_format (`ChannelDimension`, *optional*): data_format (`ChannelDimension`, *optional*):
The channel dimension format of the output image. If unset, will use the inferred format from the input. The channel dimension format of the output image. If unset, will use the inferred format from the input.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If unset, will use the inferred format from the input.
""" """
if not isinstance(image, np.ndarray): if not isinstance(image, np.ndarray):
raise ValueError("image must be a numpy array") raise ValueError("image must be a numpy array")
input_data_format = infer_channel_dimension_format(image) if input_data_format is None:
channel_axis = get_channel_dimension_axis(image) input_data_format = infer_channel_dimension_format(image)
channel_axis = get_channel_dimension_axis(image, input_data_format=input_data_format)
num_channels = image.shape[channel_axis] num_channels = image.shape[channel_axis]
if isinstance(mean, Iterable): if isinstance(mean, Iterable):
...@@ -372,7 +395,7 @@ def normalize( ...@@ -372,7 +395,7 @@ def normalize(
else: else:
image = ((image.T - mean) / std).T image = ((image.T - mean) / std).T
image = to_channel_dimension_format(image, data_format) if data_format is not None else image image = to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
return image return image
...@@ -380,6 +403,7 @@ def center_crop( ...@@ -380,6 +403,7 @@ def center_crop(
image: np.ndarray, image: np.ndarray,
size: Tuple[int, int], size: Tuple[int, int],
data_format: Optional[Union[str, ChannelDimension]] = None, data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
return_numpy: Optional[bool] = None, return_numpy: Optional[bool] = None,
) -> np.ndarray: ) -> np.ndarray:
""" """
...@@ -396,6 +420,11 @@ def center_crop( ...@@ -396,6 +420,11 @@ def center_crop(
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use the inferred format of the input image. If unset, will use the inferred format of the input image.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use the inferred format of the input image.
return_numpy (`bool`, *optional*): return_numpy (`bool`, *optional*):
Whether or not to return the cropped image as a numpy array. Used for backwards compatibility with the Whether or not to return the cropped image as a numpy array. Used for backwards compatibility with the
previous ImageFeatureExtractionMixin method. previous ImageFeatureExtractionMixin method.
...@@ -418,13 +447,14 @@ def center_crop( ...@@ -418,13 +447,14 @@ def center_crop(
if not isinstance(size, Iterable) or len(size) != 2: if not isinstance(size, Iterable) or len(size) != 2:
raise ValueError("size must have 2 elements representing the height and width of the output image") raise ValueError("size must have 2 elements representing the height and width of the output image")
input_data_format = infer_channel_dimension_format(image) if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
output_data_format = data_format if data_format is not None else input_data_format output_data_format = data_format if data_format is not None else input_data_format
# We perform the crop in (C, H, W) format and then convert to the output format # We perform the crop in (C, H, W) format and then convert to the output format
image = to_channel_dimension_format(image, ChannelDimension.FIRST) image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format)
orig_height, orig_width = get_image_size(image) orig_height, orig_width = get_image_size(image, ChannelDimension.FIRST)
crop_height, crop_width = size crop_height, crop_width = size
crop_height, crop_width = int(crop_height), int(crop_width) crop_height, crop_width = int(crop_height), int(crop_width)
...@@ -438,7 +468,7 @@ def center_crop( ...@@ -438,7 +468,7 @@ def center_crop(
# Check if cropped area is within image boundaries # Check if cropped area is within image boundaries
if top >= 0 and bottom <= orig_height and left >= 0 and right <= orig_width: if top >= 0 and bottom <= orig_height and left >= 0 and right <= orig_width:
image = image[..., top:bottom, left:right] image = image[..., top:bottom, left:right]
image = to_channel_dimension_format(image, output_data_format) image = to_channel_dimension_format(image, output_data_format, ChannelDimension.FIRST)
return image return image
# Otherwise, we may need to pad if the image is too small. Oh joy... # Otherwise, we may need to pad if the image is too small. Oh joy...
...@@ -460,7 +490,7 @@ def center_crop( ...@@ -460,7 +490,7 @@ def center_crop(
right += left_pad right += left_pad
new_image = new_image[..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right)] new_image = new_image[..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right)]
new_image = to_channel_dimension_format(new_image, output_data_format) new_image = to_channel_dimension_format(new_image, output_data_format, ChannelDimension.FIRST)
if not return_numpy: if not return_numpy:
new_image = to_pil_image(new_image) new_image = to_pil_image(new_image)
...@@ -705,7 +735,7 @@ def pad( ...@@ -705,7 +735,7 @@ def pad(
else: else:
raise ValueError(f"Invalid padding mode: {mode}") raise ValueError(f"Invalid padding mode: {mode}")
image = to_channel_dimension_format(image, data_format) if data_format is not None else image image = to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
return image return image
...@@ -728,7 +758,11 @@ def convert_to_rgb(image: ImageInput) -> ImageInput: ...@@ -728,7 +758,11 @@ def convert_to_rgb(image: ImageInput) -> ImageInput:
return image return image
def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension] = None) -> np.ndarray: def flip_channel_order(
image: np.ndarray,
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
""" """
Flips the channel order of the image. Flips the channel order of the image.
...@@ -742,9 +776,14 @@ def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension ...@@ -742,9 +776,14 @@ def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format. - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format. - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use same as the input image. If unset, will use same as the input image.
input_data_format (`ChannelDimension`, *optional*):
The channel dimension format for the input image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use the inferred format of the input image.
""" """
input_data_format = infer_channel_dimension_format(image) if input_data_format is None else input_data_format
input_data_format = infer_channel_dimension_format(image)
if input_data_format == ChannelDimension.LAST: if input_data_format == ChannelDimension.LAST:
image = image[..., ::-1] image = image[..., ::-1]
elif input_data_format == ChannelDimension.FIRST: elif input_data_format == ChannelDimension.FIRST:
...@@ -753,5 +792,5 @@ def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension ...@@ -753,5 +792,5 @@ def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension
raise ValueError(f"Unsupported channel dimension: {input_data_format}") raise ValueError(f"Unsupported channel dimension: {input_data_format}")
if data_format is not None: if data_format is not None:
image = to_channel_dimension_format(image, data_format) image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
return image return image
...@@ -176,23 +176,28 @@ def infer_channel_dimension_format( ...@@ -176,23 +176,28 @@ def infer_channel_dimension_format(
raise ValueError("Unable to infer channel dimension format") raise ValueError("Unable to infer channel dimension format")
def get_channel_dimension_axis(image: np.ndarray) -> int: def get_channel_dimension_axis(
image: np.ndarray, input_data_format: Optional[Union[ChannelDimension, str]] = None
) -> int:
""" """
Returns the channel dimension axis of the image. Returns the channel dimension axis of the image.
Args: Args:
image (`np.ndarray`): image (`np.ndarray`):
The image to get the channel dimension axis of. The image to get the channel dimension axis of.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the image. If `None`, will infer the channel dimension from the image.
Returns: Returns:
The channel dimension axis of the image. The channel dimension axis of the image.
""" """
channel_dim = infer_channel_dimension_format(image) if input_data_format is None:
if channel_dim == ChannelDimension.FIRST: input_data_format = infer_channel_dimension_format(image)
if input_data_format == ChannelDimension.FIRST:
return image.ndim - 3 return image.ndim - 3
elif channel_dim == ChannelDimension.LAST: elif input_data_format == ChannelDimension.LAST:
return image.ndim - 1 return image.ndim - 1
raise ValueError(f"Unsupported data format: {channel_dim}") raise ValueError(f"Unsupported data format: {input_data_format}")
def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> Tuple[int, int]: def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> Tuple[int, int]:
......
...@@ -185,6 +185,11 @@ class ImageTransformsTester(unittest.TestCase): ...@@ -185,6 +185,11 @@ class ImageTransformsTester(unittest.TestCase):
image = to_channel_dimension_format(image, "channels_first") image = to_channel_dimension_format(image, "channels_first")
self.assertEqual(image.shape, (3, 4, 5)) self.assertEqual(image.shape, (3, 4, 5))
# Can pass in input_data_format and works if data format is ambiguous or unknown.
image = np.random.rand(4, 5, 6)
image = to_channel_dimension_format(image, "channels_first", input_channel_dim="channels_last")
self.assertEqual(image.shape, (6, 4, 5))
def test_get_resize_output_image_size(self): def test_get_resize_output_image_size(self):
image = np.random.randint(0, 256, (3, 224, 224)) image = np.random.randint(0, 256, (3, 224, 224))
...@@ -212,6 +217,14 @@ class ImageTransformsTester(unittest.TestCase): ...@@ -212,6 +217,14 @@ class ImageTransformsTester(unittest.TestCase):
image = np.random.randint(0, 256, (3, 50, 40)) image = np.random.randint(0, 256, (3, 50, 40))
self.assertEqual(get_resize_output_image_size(image, 20, default_to_square=False, max_size=22), (22, 17)) self.assertEqual(get_resize_output_image_size(image, 20, default_to_square=False, max_size=22), (22, 17))
# Test output size = (int(size * height / width), size) if size is an int and height > width and
# input has 4 channels
image = np.random.randint(0, 256, (4, 50, 40))
self.assertEqual(
get_resize_output_image_size(image, 20, default_to_square=False, input_data_format="channels_first"),
(25, 20),
)
# Test correct channel dimension is returned when the output size has height == 3 # Test correct channel dimension is returned when the output size has height == 3
# Defaults to input format - channels first # Defaults to input format - channels first
image = np.random.randint(0, 256, (3, 18, 97)) image = np.random.randint(0, 256, (3, 18, 97))
...@@ -258,6 +271,12 @@ class ImageTransformsTester(unittest.TestCase): ...@@ -258,6 +271,12 @@ class ImageTransformsTester(unittest.TestCase):
self.assertTrue(np.all(resized_image >= 0)) self.assertTrue(np.all(resized_image >= 0))
self.assertTrue(np.all(resized_image <= 1)) self.assertTrue(np.all(resized_image <= 1))
# Check that an image with 4 channels is resized correctly
image = np.random.randint(0, 256, (4, 224, 224))
resized_image = resize(image, (30, 40), input_data_format="channels_first")
self.assertIsInstance(resized_image, np.ndarray)
self.assertEqual(resized_image.shape, (4, 30, 40))
def test_normalize(self): def test_normalize(self):
image = np.random.randint(0, 256, (224, 224, 3)) / 255 image = np.random.randint(0, 256, (224, 224, 3)) / 255
...@@ -285,6 +304,15 @@ class ImageTransformsTester(unittest.TestCase): ...@@ -285,6 +304,15 @@ class ImageTransformsTester(unittest.TestCase):
self.assertEqual(normalized_image.shape, (3, 224, 224)) self.assertEqual(normalized_image.shape, (3, 224, 224))
self.assertTrue(np.allclose(normalized_image, expected_image)) self.assertTrue(np.allclose(normalized_image, expected_image))
# Test image with 4 channels is normalized correctly
image = np.random.randint(0, 256, (224, 224, 4)) / 255
mean = (0.5, 0.6, 0.7, 0.8)
std = (0.1, 0.2, 0.3, 0.4)
expected_image = (image - mean) / std
self.assertTrue(
np.allclose(normalize(image, mean=mean, std=std, input_data_format="channels_last"), expected_image)
)
def test_center_crop(self): def test_center_crop(self):
image = np.random.randint(0, 256, (3, 224, 224)) image = np.random.randint(0, 256, (3, 224, 224))
...@@ -308,6 +336,11 @@ class ImageTransformsTester(unittest.TestCase): ...@@ -308,6 +336,11 @@ class ImageTransformsTester(unittest.TestCase):
self.assertEqual(cropped_image.shape, (300, 260, 3)) self.assertEqual(cropped_image.shape, (300, 260, 3))
self.assertTrue(np.allclose(cropped_image, expected_image)) self.assertTrue(np.allclose(cropped_image, expected_image))
# Test image with 4 channels is cropped correctly
image = np.random.randint(0, 256, (224, 224, 4))
expected_image = image[52:172, 82:142, :]
self.assertTrue(np.allclose(center_crop(image, (120, 60), input_data_format="channels_last"), expected_image))
def test_center_to_corners_format(self): def test_center_to_corners_format(self):
bbox_center = np.array([[10, 20, 4, 8], [15, 16, 3, 4]]) bbox_center = np.array([[10, 20, 4, 8], [15, 16, 3, 4]])
expected = np.array([[8, 16, 12, 24], [13.5, 14, 16.5, 18]]) expected = np.array([[8, 16, 12, 24], [13.5, 14, 16.5, 18]])
...@@ -493,6 +526,22 @@ class ImageTransformsTester(unittest.TestCase): ...@@ -493,6 +526,22 @@ class ImageTransformsTester(unittest.TestCase):
np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="reflect", data_format="channels_last")) np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="reflect", data_format="channels_last"))
) )
# Test we can pad on an image with 2 channels
# fmt: off
image = np.array([
[[0, 1], [2, 3]],
])
expected_image = np.array([
[[0, 0], [0, 1], [2, 3]],
[[0, 0], [0, 0], [0, 0]],
])
# fmt: on
self.assertTrue(
np.allclose(
expected_image, pad(image, ((0, 1), (1, 0)), mode="constant", input_data_format="channels_last")
)
)
@require_vision @require_vision
def test_convert_to_rgb(self): def test_convert_to_rgb(self):
# Test that an RGBA image is converted to RGB # Test that an RGBA image is converted to RGB
...@@ -559,3 +608,20 @@ class ImageTransformsTester(unittest.TestCase): ...@@ -559,3 +608,20 @@ class ImageTransformsTester(unittest.TestCase):
self.assertTrue( self.assertTrue(
np.allclose(flip_channel_order(img_channels_last, "channels_first"), flipped_img_channels_first) np.allclose(flip_channel_order(img_channels_last, "channels_first"), flipped_img_channels_first)
) )
# Can flip when the image has 2 channels
# fmt: off
img_channels_first = np.array([
[[ 0, 1, 2, 3],
[ 4, 5, 6, 7]],
[[ 8, 9, 10, 11],
[12, 13, 14, 15]],
])
# fmt: on
flipped_img_channels_first = img_channels_first[::-1, :, :]
self.assertTrue(
np.allclose(
flip_channel_order(img_channels_first, input_data_format="channels_first"), flipped_img_channels_first
)
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment