Unverified Commit 11757e2b authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Add input_data_format argument, image transforms (#25462)

* Enable specifying input data format - overriding inferring

* Add tests
parent 0acf5622
......@@ -63,6 +63,8 @@ def to_channel_dimension_format(
The image to have its channel dimension set.
channel_dim (`ChannelDimension`):
The channel dimension format to use.
input_channel_dim (`ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred from the input image.
Returns:
`np.ndarray`: The image with the channel dimension set to `channel_dim`.
......@@ -88,7 +90,11 @@ def to_channel_dimension_format(
def rescale(
image: np.ndarray, scale: float, data_format: Optional[ChannelDimension] = None, dtype=np.float32
image: np.ndarray,
scale: float,
data_format: Optional[ChannelDimension] = None,
dtype: np.dtype = np.float32,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Rescales `image` by `scale`.
......@@ -103,6 +109,8 @@ def rescale(
dtype (`np.dtype`, *optional*, defaults to `np.float32`):
The dtype of the output image. Defaults to `np.float32`. Used for backwards compatibility with feature
extractors.
input_data_format (`ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred from the input image.
Returns:
`np.ndarray`: The rescaled image.
......@@ -112,7 +120,7 @@ def rescale(
rescaled_image = image * scale
if data_format is not None:
rescaled_image = to_channel_dimension_format(rescaled_image, data_format)
rescaled_image = to_channel_dimension_format(rescaled_image, data_format, input_data_format)
rescaled_image = rescaled_image.astype(dtype)
......@@ -149,6 +157,7 @@ def _rescale_for_pil_conversion(image):
def to_pil_image(
image: Union[np.ndarray, "PIL.Image.Image", "torch.Tensor", "tf.Tensor", "jnp.ndarray"],
do_rescale: Optional[bool] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> "PIL.Image.Image":
"""
Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
......@@ -161,6 +170,8 @@ def to_pil_image(
Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will default
to `True` if the image type is a floating type and casting to `int` would result in a loss of precision,
and `False` otherwise.
input_data_format (`ChannelDimension`, *optional*):
The channel dimension format of the input image. If unset, will use the inferred format from the input.
Returns:
`PIL.Image.Image`: The converted image.
......@@ -179,7 +190,7 @@ def to_pil_image(
raise ValueError("Input image type not supported: {}".format(type(image)))
# If the channel as been moved to first dim, we put it back at the end.
image = to_channel_dimension_format(image, ChannelDimension.LAST)
image = to_channel_dimension_format(image, ChannelDimension.LAST, input_data_format)
# If there is a single channel, we squeeze it, as otherwise PIL can't handle it.
image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image
......@@ -200,6 +211,7 @@ def get_resize_output_image_size(
size: Union[int, Tuple[int, int], List[int], Tuple[int]],
default_to_square: bool = True,
max_size: Optional[int] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> tuple:
"""
Find the target (height, width) dimension of the output image after resizing given the input image and the desired
......@@ -225,6 +237,8 @@ def get_resize_output_image_size(
than `max_size` after being resized according to `size`, then the image is resized again so that the longer
edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller edge may be shorter
than `size`. Only used if `default_to_square` is `False`.
input_data_format (`ChannelDimension`, *optional*):
The channel dimension format of the input image. If unset, will use the inferred format from the input.
Returns:
`tuple`: The target (height, width) dimension of the output image after resizing.
......@@ -241,7 +255,7 @@ def get_resize_output_image_size(
if default_to_square:
return (size, size)
height, width = get_image_size(input_image)
height, width = get_image_size(input_image, input_data_format)
short, long = (width, height) if width <= height else (height, width)
requested_new_short = size
......@@ -266,6 +280,7 @@ def resize(
reducing_gap: Optional[int] = None,
data_format: Optional[ChannelDimension] = None,
return_numpy: bool = True,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Resizes `image` to `(height, width)` specified by `size` using the PIL library.
......@@ -285,6 +300,8 @@ def resize(
return_numpy (`bool`, *optional*, defaults to `True`):
Whether or not to return the resized image as a numpy array. If False a `PIL.Image.Image` object is
returned.
input_data_format (`ChannelDimension`, *optional*):
The channel dimension format of the input image. If unset, will use the inferred format from the input.
Returns:
`np.ndarray`: The resized image.
......@@ -298,14 +315,16 @@ def resize(
# For all transformations, we want to keep the same data format as the input image unless otherwise specified.
# The resized image from PIL will always have channels last, so find the input format first.
data_format = infer_channel_dimension_format(image) if data_format is None else data_format
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
data_format = input_data_format if data_format is None else data_format
# To maintain backwards compatibility with the resizing done in previous image feature extractors, we use
# the pillow library to resize the image and then convert back to numpy
do_rescale = False
if not isinstance(image, PIL.Image.Image):
do_rescale = _rescale_for_pil_conversion(image)
image = to_pil_image(image, do_rescale=do_rescale)
image = to_pil_image(image, do_rescale=do_rescale, input_data_format=input_data_format)
height, width = size
# PIL images are in the format (width, height)
resized_image = image.resize((width, height), resample=resample, reducing_gap=reducing_gap)
......@@ -330,6 +349,7 @@ def normalize(
mean: Union[float, Iterable[float]],
std: Union[float, Iterable[float]],
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Normalizes `image` using the mean and standard deviation specified by `mean` and `std`.
......@@ -345,12 +365,15 @@ def normalize(
The standard deviation to use for normalization.
data_format (`ChannelDimension`, *optional*):
The channel dimension format of the output image. If unset, will use the inferred format from the input.
input_data_format (`ChannelDimension`, *optional*):
The channel dimension format of the input image. If unset, will use the inferred format from the input.
"""
if not isinstance(image, np.ndarray):
raise ValueError("image must be a numpy array")
input_data_format = infer_channel_dimension_format(image)
channel_axis = get_channel_dimension_axis(image)
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
channel_axis = get_channel_dimension_axis(image, input_data_format=input_data_format)
num_channels = image.shape[channel_axis]
if isinstance(mean, Iterable):
......@@ -372,7 +395,7 @@ def normalize(
else:
image = ((image.T - mean) / std).T
image = to_channel_dimension_format(image, data_format) if data_format is not None else image
image = to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
return image
......@@ -380,6 +403,7 @@ def center_crop(
image: np.ndarray,
size: Tuple[int, int],
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
return_numpy: Optional[bool] = None,
) -> np.ndarray:
"""
......@@ -396,6 +420,11 @@ def center_crop(
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use the inferred format of the input image.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use the inferred format of the input image.
return_numpy (`bool`, *optional*):
Whether or not to return the cropped image as a numpy array. Used for backwards compatibility with the
previous ImageFeatureExtractionMixin method.
......@@ -418,13 +447,14 @@ def center_crop(
if not isinstance(size, Iterable) or len(size) != 2:
raise ValueError("size must have 2 elements representing the height and width of the output image")
input_data_format = infer_channel_dimension_format(image)
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
output_data_format = data_format if data_format is not None else input_data_format
# We perform the crop in (C, H, W) format and then convert to the output format
image = to_channel_dimension_format(image, ChannelDimension.FIRST)
image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format)
orig_height, orig_width = get_image_size(image)
orig_height, orig_width = get_image_size(image, ChannelDimension.FIRST)
crop_height, crop_width = size
crop_height, crop_width = int(crop_height), int(crop_width)
......@@ -438,7 +468,7 @@ def center_crop(
# Check if cropped area is within image boundaries
if top >= 0 and bottom <= orig_height and left >= 0 and right <= orig_width:
image = image[..., top:bottom, left:right]
image = to_channel_dimension_format(image, output_data_format)
image = to_channel_dimension_format(image, output_data_format, ChannelDimension.FIRST)
return image
# Otherwise, we may need to pad if the image is too small. Oh joy...
......@@ -460,7 +490,7 @@ def center_crop(
right += left_pad
new_image = new_image[..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right)]
new_image = to_channel_dimension_format(new_image, output_data_format)
new_image = to_channel_dimension_format(new_image, output_data_format, ChannelDimension.FIRST)
if not return_numpy:
new_image = to_pil_image(new_image)
......@@ -705,7 +735,7 @@ def pad(
else:
raise ValueError(f"Invalid padding mode: {mode}")
image = to_channel_dimension_format(image, data_format) if data_format is not None else image
image = to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
return image
......@@ -728,7 +758,11 @@ def convert_to_rgb(image: ImageInput) -> ImageInput:
return image
def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension] = None) -> np.ndarray:
def flip_channel_order(
image: np.ndarray,
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Flips the channel order of the image.
......@@ -742,9 +776,14 @@ def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use same as the input image.
input_data_format (`ChannelDimension`, *optional*):
The channel dimension format for the input image. Can be one of:
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use the inferred format of the input image.
"""
input_data_format = infer_channel_dimension_format(image) if input_data_format is None else input_data_format
input_data_format = infer_channel_dimension_format(image)
if input_data_format == ChannelDimension.LAST:
image = image[..., ::-1]
elif input_data_format == ChannelDimension.FIRST:
......@@ -753,5 +792,5 @@ def flip_channel_order(image: np.ndarray, data_format: Optional[ChannelDimension
raise ValueError(f"Unsupported channel dimension: {input_data_format}")
if data_format is not None:
image = to_channel_dimension_format(image, data_format)
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
return image
......@@ -176,23 +176,28 @@ def infer_channel_dimension_format(
raise ValueError("Unable to infer channel dimension format")
def get_channel_dimension_axis(image: np.ndarray) -> int:
def get_channel_dimension_axis(
image: np.ndarray, input_data_format: Optional[Union[ChannelDimension, str]] = None
) -> int:
"""
Returns the channel dimension axis of the image.
Args:
image (`np.ndarray`):
The image to get the channel dimension axis of.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the image. If `None`, will infer the channel dimension from the image.
Returns:
The channel dimension axis of the image.
"""
channel_dim = infer_channel_dimension_format(image)
if channel_dim == ChannelDimension.FIRST:
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
if input_data_format == ChannelDimension.FIRST:
return image.ndim - 3
elif channel_dim == ChannelDimension.LAST:
elif input_data_format == ChannelDimension.LAST:
return image.ndim - 1
raise ValueError(f"Unsupported data format: {channel_dim}")
raise ValueError(f"Unsupported data format: {input_data_format}")
def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> Tuple[int, int]:
......
......@@ -185,6 +185,11 @@ class ImageTransformsTester(unittest.TestCase):
image = to_channel_dimension_format(image, "channels_first")
self.assertEqual(image.shape, (3, 4, 5))
# Can pass in input_data_format and works if data format is ambiguous or unknown.
image = np.random.rand(4, 5, 6)
image = to_channel_dimension_format(image, "channels_first", input_channel_dim="channels_last")
self.assertEqual(image.shape, (6, 4, 5))
def test_get_resize_output_image_size(self):
image = np.random.randint(0, 256, (3, 224, 224))
......@@ -212,6 +217,14 @@ class ImageTransformsTester(unittest.TestCase):
image = np.random.randint(0, 256, (3, 50, 40))
self.assertEqual(get_resize_output_image_size(image, 20, default_to_square=False, max_size=22), (22, 17))
# Test output size = (int(size * height / width), size) if size is an int and height > width and
# input has 4 channels
image = np.random.randint(0, 256, (4, 50, 40))
self.assertEqual(
get_resize_output_image_size(image, 20, default_to_square=False, input_data_format="channels_first"),
(25, 20),
)
# Test correct channel dimension is returned if output size if height == 3
# Defaults to input format - channels first
image = np.random.randint(0, 256, (3, 18, 97))
......@@ -258,6 +271,12 @@ class ImageTransformsTester(unittest.TestCase):
self.assertTrue(np.all(resized_image >= 0))
self.assertTrue(np.all(resized_image <= 1))
# Check that an image with 4 channels is resized correctly
image = np.random.randint(0, 256, (4, 224, 224))
resized_image = resize(image, (30, 40), input_data_format="channels_first")
self.assertIsInstance(resized_image, np.ndarray)
self.assertEqual(resized_image.shape, (4, 30, 40))
def test_normalize(self):
image = np.random.randint(0, 256, (224, 224, 3)) / 255
......@@ -285,6 +304,15 @@ class ImageTransformsTester(unittest.TestCase):
self.assertEqual(normalized_image.shape, (3, 224, 224))
self.assertTrue(np.allclose(normalized_image, expected_image))
# Test image with 4 channels is normalized correctly
image = np.random.randint(0, 256, (224, 224, 4)) / 255
mean = (0.5, 0.6, 0.7, 0.8)
std = (0.1, 0.2, 0.3, 0.4)
expected_image = (image - mean) / std
self.assertTrue(
np.allclose(normalize(image, mean=mean, std=std, input_data_format="channels_last"), expected_image)
)
def test_center_crop(self):
image = np.random.randint(0, 256, (3, 224, 224))
......@@ -308,6 +336,11 @@ class ImageTransformsTester(unittest.TestCase):
self.assertEqual(cropped_image.shape, (300, 260, 3))
self.assertTrue(np.allclose(cropped_image, expected_image))
# Test image with 4 channels is cropped correctly
image = np.random.randint(0, 256, (224, 224, 4))
expected_image = image[52:172, 82:142, :]
self.assertTrue(np.allclose(center_crop(image, (120, 60), input_data_format="channels_last"), expected_image))
def test_center_to_corners_format(self):
bbox_center = np.array([[10, 20, 4, 8], [15, 16, 3, 4]])
expected = np.array([[8, 16, 12, 24], [13.5, 14, 16.5, 18]])
......@@ -493,6 +526,22 @@ class ImageTransformsTester(unittest.TestCase):
np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="reflect", data_format="channels_last"))
)
# Test we can pad on an image with 2 channels
# fmt: off
image = np.array([
[[0, 1], [2, 3]],
])
expected_image = np.array([
[[0, 0], [0, 1], [2, 3]],
[[0, 0], [0, 0], [0, 0]],
])
# fmt: on
self.assertTrue(
np.allclose(
expected_image, pad(image, ((0, 1), (1, 0)), mode="constant", input_data_format="channels_last")
)
)
@require_vision
def test_convert_to_rgb(self):
# Test that an RGBA image is converted to RGB
......@@ -559,3 +608,20 @@ class ImageTransformsTester(unittest.TestCase):
self.assertTrue(
np.allclose(flip_channel_order(img_channels_last, "channels_first"), flipped_img_channels_first)
)
# Can flip when the image has 2 channels
# fmt: off
img_channels_first = np.array([
[[ 0, 1, 2, 3],
[ 4, 5, 6, 7]],
[[ 8, 9, 10, 11],
[12, 13, 14, 15]],
])
# fmt: on
flipped_img_channels_first = img_channels_first[::-1, :, :]
self.assertTrue(
np.allclose(
flip_channel_order(img_channels_first, input_data_format="channels_first"), flipped_img_channels_first
)
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment