Unverified Commit 9d2cee8b authored by Tobias Norlund, committed by GitHub

CLIPFeatureExtractor should resize images with kept aspect ratio (#11994)



* Resize with kept aspect ratio

* Fixed failing test

* Overload center_crop and resize methods instead

* resize should handle non-PIL images

* update slow test

* Tensor => tensor
Co-authored-by: patil-suraj <surajp815@gmail.com>
parent 472a8676
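
What the change does: when `size` is an int, `resize` now scales the shorter side of the image to `size` and the longer side by the same factor, preserving the aspect ratio instead of squashing the image to `(size, size)`; `center_crop` then cuts the final square. A minimal standalone sketch of that sizing arithmetic, mirroring the logic in the diff below (function and variable names are illustrative, not part of the library API):

from PIL import Image

def resize_shorter_side(image, size, resample=Image.BICUBIC):
    # Mirrors the diffed resize(): scale the shorter side to `size`,
    # keeping the aspect ratio, instead of squashing to (size, size).
    width, height = image.size
    short, long = (width, height) if width <= height else (height, width)
    if short == size:
        return image
    new_short, new_long = size, int(size * long / short)
    new_w, new_h = (new_short, new_long) if width <= height else (new_long, new_short)
    return image.resize((new_w, new_h), resample)

# Example: a 640x480 image with size=224. The 480 side becomes 224 and the
# 640 side becomes int(224 * 640 / 480) = 298, so the result is 298x224
# rather than a distorted 224x224.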
@@ -154,3 +154,56 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
         encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
 
         return encoded_inputs
+
+    def center_crop(self, image, size):
+        """
+        Crops :obj:`image` to the given size using a center crop. Note that if the image is too small to be cropped
+        to the given size, it will be padded (so the returned result has the size asked).
+
+        Args:
+            image (:obj:`PIL.Image.Image` or :obj:`np.ndarray` or :obj:`torch.Tensor`):
+                The image to crop.
+            size (:obj:`int` or :obj:`Tuple[int, int]`):
+                The size to which the image will be cropped.
+        """
+        self._ensure_format_supported(image)
+
+        if not isinstance(size, tuple):
+            size = (size, size)
+
+        if not isinstance(image, Image.Image):
+            image = self.to_pil_image(image)
+
+        image_width, image_height = image.size
+        crop_height, crop_width = size
+
+        crop_top = int((image_height - crop_height + 1) * 0.5)
+        crop_left = int((image_width - crop_width + 1) * 0.5)
+
+        return image.crop((crop_left, crop_top, crop_left + crop_width, crop_top + crop_height))
+
+    def resize(self, image, size, resample=Image.BICUBIC):
+        """
+        Resizes :obj:`image`. Note that this will trigger a conversion of :obj:`image` to a PIL Image.
+
+        Args:
+            image (:obj:`PIL.Image.Image` or :obj:`np.ndarray` or :obj:`torch.Tensor`):
+                The image to resize.
+            size (:obj:`int` or :obj:`Tuple[int, int]`):
+                The size to use for resizing the image. If :obj:`int`, the shorter side of the image is resized to
+                this value, keeping the aspect ratio.
+            resample (:obj:`int`, `optional`, defaults to :obj:`PIL.Image.BICUBIC`):
+                The filter to use for resampling.
+        """
+        self._ensure_format_supported(image)
+
+        if not isinstance(image, Image.Image):
+            image = self.to_pil_image(image)
+
+        if isinstance(size, tuple):
+            new_w, new_h = size
+        else:
+            width, height = image.size
+            short, long = (width, height) if width <= height else (height, width)
+
+            if short == size:
+                return image
+
+            new_short, new_long = size, int(size * long / short)
+            new_w, new_h = (new_short, new_long) if width <= height else (new_long, new_short)
+
+        return image.resize((new_w, new_h), resample)
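
A short usage sketch of the two overloaded methods. It assumes the openai/clip-vit-base-patch32 checkpoint, whose feature extractor resizes and crops to 224; the dummy image size is purely illustrative:

from PIL import Image
from transformers import CLIPFeatureExtractor

feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")

image = Image.new("RGB", (640, 480))               # illustrative dummy image
resized = feature_extractor.resize(image, 224)
print(resized.size)                                 # (298, 224): aspect ratio kept
cropped = feature_extractor.center_crop(resized, 224)
print(cropped.size)                                 # (224, 224): final square crop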
@@ -544,7 +544,8 @@ class CLIPModelIntegrationTest(unittest.TestCase):
         ).to(torch_device)
 
         # forward pass
-        outputs = model(**inputs)
+        with torch.no_grad():
+            outputs = model(**inputs)
 
         # verify the logits
         self.assertEqual(
@@ -556,6 +557,6 @@ class CLIPModelIntegrationTest(unittest.TestCase):
             torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
         )
 
-        expected_logits = torch.tensor([[24.5056, 18.8076]], device=torch_device)
+        expected_logits = torch.tensor([[24.5701, 19.3049]], device=torch_device)
 
         self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
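
Two things change in the slow test: the expected logits are updated because the aspect-ratio-preserving resize feeds the model slightly different pixel values than the old squashing resize, and the forward pass is wrapped in torch.no_grad(), which skips building the autograd graph during inference-only code. A self-contained toy illustration of the no_grad pattern (a stand-in linear model, not the CLIP test itself):

import torch
from torch import nn

model = nn.Linear(4, 2)                    # stand-in model for illustration
inputs = {"input": torch.randn(1, 4)}

# torch.no_grad() disables gradient tracking, so no autograd graph is built
# and inference uses less memory.
with torch.no_grad():
    outputs = model(inputs["input"])

print(outputs.requires_grad)               # False: nothing to backpropagate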