Unverified Commit cf1b8c34 authored by amyeroberts, committed by GitHub

Fix donut image processor (#20625)

* fix donut image processor

* Update test values

* Apply lower bound on resizing size

* Add in missing size param

* Resolve resize channel_dimension bug

* Update src/transformers/image_transforms.py
parent e3cc4487
@@ -48,7 +48,11 @@ if is_flax_available():
     import jax.numpy as jnp


-def to_channel_dimension_format(image: np.ndarray, channel_dim: Union[ChannelDimension, str]) -> np.ndarray:
+def to_channel_dimension_format(
+    image: np.ndarray,
+    channel_dim: Union[ChannelDimension, str],
+    input_channel_dim: Optional[Union[ChannelDimension, str]] = None,
+) -> np.ndarray:
     """
     Converts `image` to the channel dimension format specified by `channel_dim`.
@@ -64,9 +68,11 @@ def to_channel_dimension_format(
     if not isinstance(image, np.ndarray):
         raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")

-    current_channel_dim = infer_channel_dimension_format(image)
+    if input_channel_dim is None:
+        input_channel_dim = infer_channel_dimension_format(image)
+
     target_channel_dim = ChannelDimension(channel_dim)
-    if current_channel_dim == target_channel_dim:
+    if input_channel_dim == target_channel_dim:
         return image

     if target_channel_dim == ChannelDimension.FIRST:
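Why the new `input_channel_dim` argument matters: channel-format inference has to guess from the array shape, and the guess becomes ambiguous once a spatial dimension equals the channel count. A minimal standalone sketch of the failure mode (the `guess_channel_dim` helper below is illustrative, not the library's `infer_channel_dimension_format`):

```python
import numpy as np

def guess_channel_dim(image: np.ndarray) -> str:
    # Naive shape-based inference, sketching the heuristic that breaks:
    # a dimension of size 1 or 3 is taken to be the channel dimension.
    if image.shape[0] in (1, 3):
        return "channels_first"
    if image.shape[-1] in (1, 3):
        return "channels_last"
    raise ValueError(f"Unable to infer channel dimension for shape {image.shape}")

# A channels-last image resized to height 3 has shape (3, 20, 3), which is
# ambiguous: a shape-based guess picks channels_first even when channels are last.
image = np.zeros((3, 20, 3))
print(guess_channel_dim(image))  # channels_first
```

Letting the caller state the known input format sidesteps the guess entirely.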
@@ -152,6 +158,7 @@ def to_pil_image(
     return PIL.Image.fromarray(image)


+# Logic adapted from torchvision resizing logic: https://github.com/pytorch/vision/blob/511924c1ced4ce0461197e5caa64ce5b9e558aab/torchvision/transforms/functional.py#L366
 def get_resize_output_image_size(
     input_image: np.ndarray,
     size: Union[int, Tuple[int, int], List[int], Tuple[int]],
@@ -202,9 +209,6 @@ def get_resize_output_image_size(
     short, long = (width, height) if width <= height else (height, width)
     requested_new_short = size

-    if short == requested_new_short:
-        return (height, width)
-
     new_short, new_long = requested_new_short, int(requested_new_short * long / short)

     if max_size is not None:
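Dropping the early return means the `max_size` cap is always applied: previously an image whose short edge already equaled `size` was handed back untouched even when its long edge exceeded `max_size`. A standalone re-derivation of the sizing arithmetic, checked against the existing test value further down (the helper name is ours, not the library's):

```python
from typing import Optional, Tuple

def shortest_edge_resize(height: int, width: int, size: int, max_size: Optional[int] = None) -> Tuple[int, int]:
    # Torchvision-style logic: scale the short edge to `size`, keep the aspect
    # ratio, then shrink both edges if the long edge would exceed `max_size`.
    short, long = (width, height) if width <= height else (height, width)
    new_short, new_long = size, int(size * long / short)
    if max_size is not None and new_long > max_size:
        new_short, new_long = int(max_size * new_short / new_long), max_size
    return (new_long, new_short) if width <= height else (new_short, new_long)

# Matches the existing test case: (height=50, width=40), size=20, max_size=22 -> (22, 17)
assert shortest_edge_resize(50, 40, size=20, max_size=22) == (22, 17)
# With the removed early return, a (20, 30) image with size=20, max_size=22 came
# back unchanged with long edge 30 > 22; now it is capped:
assert shortest_edge_resize(20, 30, size=20, max_size=22) == (14, 22)
```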
@@ -271,7 +275,10 @@ def resize(
     # If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image
     # so we need to add it back if necessary.
     resized_image = np.expand_dims(resized_image, axis=-1) if resized_image.ndim == 2 else resized_image
-    resized_image = to_channel_dimension_format(resized_image, data_format)
+    # The image is always in channels last format after converting from a PIL image
+    resized_image = to_channel_dimension_format(
+        resized_image, data_format, input_channel_dim=ChannelDimension.LAST
+    )
     return resized_image
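The explicit `input_channel_dim=ChannelDimension.LAST` hint is safe because converting a `PIL.Image` back to numpy always yields a channels-last array, and it protects exactly the case this fix targets: resized images so small that a spatial dimension equals the channel count. A quick check:

```python
import numpy as np
import PIL.Image

# PIL images convert to (height, width, channels) arrays, so after the PIL
# round trip inside resize() the layout is always channels last.
pil_image = PIL.Image.new("RGB", (20, 3))  # PIL sizes are (width, height)
array = np.asarray(pil_image)
print(array.shape)  # (3, 20, 3): height 3, width 20, 3 channels - ambiguous to infer
```

That is the last hunk in `src/transformers/image_transforms.py`; the remaining hunks update the Donut image processor and its tests.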
@@ -210,7 +210,8 @@ class DonutImageProcessor(BaseImageProcessor):
         **kwargs
     ) -> np.ndarray:
         """
-        Resize the image to the specified size using thumbnail method.
+        Resize the image to make a thumbnail. The image is resized so that no dimension is larger than any
+        corresponding dimension of the specified size.

         Args:
             image (`np.ndarray`):
@@ -222,8 +223,24 @@ class DonutImageProcessor(BaseImageProcessor):
             data_format (`Optional[Union[str, ChannelDimension]]`, *optional*):
                 The data format of the output image. If unset, the same format as the input image is used.
         """
-        output_size = (size["height"], size["width"])
-        return resize(image, size=output_size, resample=resample, reducing_gap=2.0, data_format=data_format, **kwargs)
+        input_height, input_width = get_image_size(image)
+        output_height, output_width = size["height"], size["width"]
+
+        # We always resize to the smallest of either the input or output size.
+        height = min(input_height, output_height)
+        width = min(input_width, output_width)
+
+        if height == input_height and width == input_width:
+            return image
+
+        if input_height > input_width:
+            width = int(input_width * height / input_height)
+        elif input_width > input_height:
+            height = int(input_height * width / input_width)
+
+        return resize(
+            image, size=(height, width), resample=resample, reducing_gap=2.0, data_format=data_format, **kwargs
+        )

     def resize(
         self,
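The rewritten `thumbnail` shrinks the image into the target box without upscaling and without distorting the aspect ratio. A pure-function mirror of the sizing branch (helper name is ours) to sanity-check the arithmetic:

```python
from typing import Tuple

def thumbnail_size(input_height: int, input_width: int, output_height: int, output_width: int) -> Tuple[int, int]:
    # Mirror of the sizing logic above: clamp to the target box, then rescale
    # the other edge so the aspect ratio is preserved.
    height = min(input_height, output_height)
    width = min(input_width, output_width)
    if height == input_height and width == input_width:
        return height, width
    if input_height > input_width:
        width = int(input_width * height / input_height)
    elif input_width > input_height:
        height = int(input_height * width / input_width)
    return height, width

# A tall 1000x600 page shrunk into a 720x720 box keeps its 5:3 aspect ratio:
assert thumbnail_size(1000, 600, 720, 720) == (720, 432)
# An image already inside the box is left unchanged (no upscaling):
assert thumbnail_size(100, 80, 720, 720) == (100, 80)
```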
@@ -250,7 +267,8 @@ class DonutImageProcessor(BaseImageProcessor):
         size = get_size_dict(size)
         shortest_edge = min(size["height"], size["width"])
         output_size = get_resize_output_image_size(image, size=shortest_edge, default_to_square=False)
-        return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
+        resized_image = resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
+        return resized_image

     def rescale(
         self,
@@ -403,7 +421,7 @@ class DonutImageProcessor(BaseImageProcessor):
         images = [to_numpy_array(image) for image in images]

         if do_align_long_axis:
-            images = [self.align_long_axis(image) for image in images]
+            images = [self.align_long_axis(image, size=size) for image in images]

         if do_resize:
             images = [self.resize(image=image, size=size, resample=resample) for image in images]
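The call-site change passes `size` so `align_long_axis` can compare the image's orientation against the target orientation instead of assuming one. The diff does not show the method body; the sketch below is an assumption about the rotation rule, for illustration only:

```python
import numpy as np

def align_long_axis(image: np.ndarray, size: dict) -> np.ndarray:
    # Assumed rule, not the library body: rotate 90 degrees when the image's
    # long axis disagrees with the target size's long axis.
    input_height, input_width = image.shape[0], image.shape[1]  # assumes channels last
    output_height, output_width = size["height"], size["width"]
    if (output_width < output_height and input_width > input_height) or (
        output_width > output_height and input_width < input_height
    ):
        image = np.rot90(image, 3)
    return image
```

The remaining hunks refresh test expectations: the fixed preprocessing shifts the model's logits slightly (the "Update test values" bullet in the commit message).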
@@ -836,7 +836,7 @@ class DonutModelIntegrationTest(unittest.TestCase):
         expected_shape = torch.Size([1, 1, 57532])
         self.assertEqual(outputs.logits.shape, expected_shape)

-        expected_slice = torch.tensor([24.2731, -6.4522, 32.4130]).to(torch_device)
+        expected_slice = torch.tensor([24.3873, -6.4491, 32.5394]).to(torch_device)
         self.assertTrue(torch.allclose(logits[0, 0, :3], expected_slice, atol=1e-4))

         # step 2: generation
@@ -872,7 +872,7 @@ class DonutModelIntegrationTest(unittest.TestCase):
         self.assertEqual(len(outputs.scores), 11)
         self.assertTrue(
             torch.allclose(
-                outputs.scores[0][0, :3], torch.tensor([5.3153, -3.5276, 13.4781], device=torch_device), atol=1e-4
+                outputs.scores[0][0, :3], torch.tensor([5.6019, -3.5070, 13.7123], device=torch_device), atol=1e-4
             )
         )
@@ -184,6 +184,25 @@ class ImageTransformsTester(unittest.TestCase):
         image = np.random.randint(0, 256, (3, 50, 40))
         self.assertEqual(get_resize_output_image_size(image, 20, default_to_square=False, max_size=22), (22, 17))

+        # Test the correct channel dimension is returned if the output size height == 3
+        # Defaults to input format - channels first
+        image = np.random.randint(0, 256, (3, 18, 97))
+        resized_image = resize(image, (3, 20))
+        self.assertEqual(resized_image.shape, (3, 3, 20))
+
+        # Defaults to input format - channels last
+        image = np.random.randint(0, 256, (18, 97, 3))
+        resized_image = resize(image, (3, 20))
+        self.assertEqual(resized_image.shape, (3, 20, 3))
+
+        image = np.random.randint(0, 256, (3, 18, 97))
+        resized_image = resize(image, (3, 20), data_format="channels_last")
+        self.assertEqual(resized_image.shape, (3, 20, 3))
+
+        image = np.random.randint(0, 256, (18, 97, 3))
+        resized_image = resize(image, (3, 20), data_format="channels_first")
+        self.assertEqual(resized_image.shape, (3, 3, 20))
+
     def test_resize(self):
         image = np.random.randint(0, 256, (3, 224, 224))