Unverified Commit 05cda5df authored by amyeroberts, committed by GitHub

🚨🚨🚨 Fix rescale ViVit Efficientnet (#25174)

* Fix rescaling bug

* Add tests

* Update integration tests

* Fix up

* Update src/transformers/image_transforms.py

* Update test - new possible order in list
parent 03f98f96
src/transformers/image_transforms.py

@@ -110,10 +110,11 @@ def rescale(
     if not isinstance(image, np.ndarray):
         raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")
 
-    image = image.astype(dtype)
     rescaled_image = image * scale
     if data_format is not None:
         rescaled_image = to_channel_dimension_format(rescaled_image, data_format)
+
+    rescaled_image = rescaled_image.astype(dtype)
+
     return rescaled_image
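The reordering is easiest to see in isolation: the array is now cast to the target dtype once, after scaling, instead of before. A minimal sketch of the new behaviour (standalone, with the data_format handling omitted; rescale_sketch is illustrative, not the library function):

import numpy as np

def rescale_sketch(image: np.ndarray, scale: float, dtype=np.float32) -> np.ndarray:
    # Multiply first (NumPy promotes uint8 * float to float64),
    # then downcast to the requested dtype once at the end.
    rescaled = image * scale
    return rescaled.astype(dtype)

image = np.arange(0, 256, dtype=np.uint8).reshape(1, 8, 32)
out = rescale_sketch(image, scale=1 / 255)
print(out.dtype, out.min(), out.max())  # float32 0.0 1.0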
src/transformers/models/efficientnet/image_processing_efficientnet.py

@@ -153,7 +153,13 @@ class EfficientNetImageProcessor(BaseImageProcessor):
         **kwargs,
     ):
         """
-        Rescale an image by a scale factor. image = image * scale.
+        Rescale an image by a scale factor.
+
+        If offset is True, the image is rescaled between [-1, 1].
+            image = image * scale * 2 - 1
+
+        If offset is False, the image is rescaled between [0, 1].
+            image = image * scale
 
         Args:
             image (`np.ndarray`):
@@ -165,13 +171,12 @@ class EfficientNetImageProcessor(BaseImageProcessor):
             data_format (`str` or `ChannelDimension`, *optional*):
                 The channel dimension format of the image. If not provided, it will be the same as the input image.
         """
+        scale = scale * 2 if offset else scale
+        rescaled_image = rescale(image, scale=scale, data_format=data_format, **kwargs)
+
         if offset:
-            rescaled_image = (image - 127.5) * scale
-            if data_format is not None:
-                rescaled_image = to_channel_dimension_format(rescaled_image, data_format)
-            rescaled_image = rescaled_image.astype(np.float32)
-        else:
-            rescaled_image = rescale(image, scale=scale, data_format=data_format, **kwargs)
+            rescaled_image = rescaled_image - 1
+
         return rescaled_image
 
     def preprocess(
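To see why the old offset branch was wrong, plug in the usual scale of 1/255 (both formulas are taken straight from the diff above; the rest is just arithmetic):

import numpy as np

image = np.array([0.0, 127.5, 255.0])
scale = 1 / 255

old = (image - 127.5) * scale   # [-0.5, 0.0, 0.5], only half the intended range
new = image * scale * 2 - 1     # [-1.0, 0.0, 1.0], the full [-1, 1]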
src/transformers/models/vivit/image_processing_vivit.py

@@ -167,6 +167,7 @@ class VivitImageProcessor(BaseImageProcessor):
             raise ValueError(f"Size must have 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}")
         return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
 
+    # Copied from transformers.models.efficientnet.image_processing_efficientnet.EfficientNetImageProcessor.rescale
     def rescale(
         self,
         image: np.ndarray,
@@ -178,23 +179,29 @@ class VivitImageProcessor(BaseImageProcessor):
         """
         Rescale an image by a scale factor.
 
-        If offset is `True`, image scaled between [-1, 1]: image = (image - 127.5) * scale. If offset is `False`, image
-        scaled between [0, 1]: image = image * scale
+        If offset is True, the image is rescaled between [-1, 1].
+            image = image * scale * 2 - 1
+
+        If offset is False, the image is rescaled between [0, 1].
+            image = image * scale
 
         Args:
             image (`np.ndarray`):
                 Image to rescale.
             scale (`int` or `float`):
                 Scale to apply to the image.
             offset (`bool`, *optional*):
                 Whether to scale the image in both negative and positive directions.
             data_format (`str` or `ChannelDimension`, *optional*):
                 The channel dimension format of the image. If not provided, it will be the same as the input image.
         """
-        image = image.astype(np.float32)
+        scale = scale * 2 if offset else scale
+        rescaled_image = rescale(image, scale=scale, data_format=data_format, **kwargs)
 
         if offset:
-            image = image - (scale / 2)
+            rescaled_image = rescaled_image - 1
 
-        return rescale(image, scale=scale, data_format=data_format, **kwargs)
+        return rescaled_image
 
     def _preprocess_image(
         self,
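ViVit's old path had a different bug: it subtracted scale / 2 (about 0.002 for scale = 1/255) rather than half the pixel range, so offset=True was effectively a no-op and the output stayed in roughly [0, 1]. Again, plain arithmetic on the formulas in the diff:

import numpy as np

image = np.array([0.0, 127.5, 255.0])
scale = 1 / 255

old = (image - scale / 2) * scale   # approx. [0.0, 0.5, 1.0], offset barely changed anything
new = image * scale * 2 - 1         # [-1.0, 0.0, 1.0]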
tests/models/efficientnet/test_image_processing_efficientnet.py

@@ -193,3 +193,17 @@ class EfficientNetImageProcessorTest(ImageProcessingSavingTestMixin, unittest.TestCase):
                 self.image_processor_tester.size["width"],
             ),
         )
+
+    def test_rescale(self):
+        # EfficientNet optionally rescales between -1 and 1 instead of the usual 0 and 1
+        image = np.arange(0, 256, 1, dtype=np.uint8).reshape(1, 8, 32)
+        image_processor = self.image_processing_class(**self.image_processor_dict)
+
+        rescaled_image = image_processor.rescale(image, scale=1 / 255)
+        expected_image = image.astype(np.float32) * (2 / 255.0) - 1
+        self.assertTrue(np.allclose(rescaled_image, expected_image))
+
+        rescaled_image = image_processor.rescale(image, scale=1 / 255, offset=False)
+        expected_image = image.astype(np.float32) / 255.0
+        self.assertTrue(np.allclose(rescaled_image, expected_image))
tests/models/vivit/test_image_processing_vivit.py

@@ -212,3 +212,17 @@ class VivitImageProcessingTest(ImageProcessingSavingTestMixin, unittest.TestCase):
                 self.image_processor_tester.crop_size["width"],
             ),
         )
+
+    def test_rescale(self):
+        # ViVit optionally rescales between -1 and 1 instead of the usual 0 and 1
+        image = np.arange(0, 256, 1, dtype=np.uint8).reshape(1, 8, 32)
+        image_processor = self.image_processing_class(**self.image_processor_dict)
+
+        rescaled_image = image_processor.rescale(image, scale=1 / 255)
+        expected_image = image.astype(np.float32) * (2 / 255.0) - 1
+        self.assertTrue(np.allclose(rescaled_image, expected_image))
+
+        rescaled_image = image_processor.rescale(image, scale=1 / 255, offset=False)
+        expected_image = image.astype(np.float32) / 255.0
+        self.assertTrue(np.allclose(rescaled_image, expected_image))
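The new unit tests pin the behaviour down at the method level; the same effect is visible through the public class. A quick sketch (not part of the PR; the default constructor needs no checkpoint):

import numpy as np
from transformers import VivitImageProcessor

processor = VivitImageProcessor()
frame = np.arange(0, 256, dtype=np.uint8).reshape(1, 8, 32)

print(processor.rescale(frame, scale=1 / 255).min())                # -1.0 (offset defaults to True)
print(processor.rescale(frame, scale=1 / 255, offset=False).min())  # 0.0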
tests/models/vivit/test_modeling_vivit.py

@@ -345,6 +345,6 @@ class VivitModelIntegrationTest(unittest.TestCase):
         self.assertEqual(outputs.logits.shape, expected_shape)
 
         # taken from original model
-        expected_slice = torch.tensor([-1.0543, 2.0764, -0.2104, 0.4439, -0.9658]).to(torch_device)
+        expected_slice = torch.tensor([-0.9498, 2.7971, -1.4049, 0.1024, -1.8353]).to(torch_device)
         self.assertTrue(torch.allclose(outputs.logits[0, :5], expected_slice, atol=1e-4))
tests/pipelines/test_pipelines_zero_shot_image_classification.py

@@ -85,6 +85,7 @@ class ZeroShotImageClassificationPipelineTests(unittest.TestCase):
             [
                 [{"score": 0.333, "label": "a"}, {"score": 0.333, "label": "b"}, {"score": 0.333, "label": "c"}],
                 [{"score": 0.333, "label": "a"}, {"score": 0.333, "label": "c"}, {"score": 0.333, "label": "b"}],
+                [{"score": 0.333, "label": "b"}, {"score": 0.333, "label": "a"}, {"score": 0.333, "label": "c"}],
             ],
         )