Unverified commit d5096a7f, authored by Saurabh Khanduja and committed by GitHub

Bugfix - same output for PIL and tensor when centercrop size is greater than imgsize (#3333)



* Renamed original method to test center crop

* Added test method, docs and added padding when imgsize < cropsize.

* BugFix - keep odd_crop_size odd

* Do not crop when image size after padding matches crop size; updated test.
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 859a535f
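For context, here is a minimal sketch (not part of this diff) of the behavior the commit establishes; it assumes a torchvision build that already includes this change. An image smaller than the requested crop is zero padded up to the crop size, and the PIL and tensor code paths now return identical output.

```python
# Minimal repro sketch for the fixed behavior (assumes a torchvision build
# that contains this change). A 3x24x24 image is center cropped to 30x30:
# the result is zero padded to the requested size, and the PIL and tensor
# paths agree.
import torch
import torchvision.transforms as transforms

img = torch.ones(3, 24, 24)

out_tensor = transforms.CenterCrop(30)(img)
out_pil = transforms.Compose([
    transforms.ToPILImage(),
    transforms.CenterCrop(30),
    transforms.ToTensor(),
])(img)

assert out_tensor.shape == (3, 30, 30)
assert torch.equal(out_tensor, out_pil)  # identical output after the fix
```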
@@ -91,7 +91,7 @@ If you have modified the code by adding a new feature or a bug-fix, please add u
test:
```bash
pytest test/<test-module.py> -vvv -k <test_myfunc>
# e.g. pytest test/test_transforms.py -vvv -k test_center_crop
```
If you would like to run all tests:
......
import itertools
import os
import torch
import torchvision.transforms as transforms
@@ -29,7 +30,7 @@ GRACE_HOPPER = get_file_path_2(
class Tester(unittest.TestCase):
    def test_center_crop(self):
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        oheight = random.randint(5, (height - 2) // 2) * 2
@@ -70,6 +71,64 @@ class Tester(unittest.TestCase):
        self.assertGreater(sum2, sum1,
                           "height: {} width: {} oheight: {} owidth: {}".format(height, width, oheight, owidth))

    def test_center_crop_2(self):
        """ Tests when center crop size is larger than image size, along any dimension"""
        even_image_size = (random.randint(10, 32) * 2, random.randint(10, 32) * 2)
        odd_image_size = (even_image_size[0] + 1, even_image_size[1] + 1)

        # Since height is independent of width, we can ignore images with odd height and even width and vice-versa.
        input_image_sizes = [even_image_size, odd_image_size]

        # Get different crop sizes
        delta = random.choice((1, 3, 5))
        crop_size_delta = [-2 * delta, -delta, 0, delta, 2 * delta]
        crop_size_params = itertools.product(input_image_sizes, crop_size_delta, crop_size_delta)

        for (input_image_size, delta_height, delta_width) in crop_size_params:
            img = torch.ones(3, *input_image_size)
            crop_size = (input_image_size[0] + delta_height, input_image_size[1] + delta_width)

            # Test both transforms, one with PIL input and one with tensor
            output_pil = transforms.Compose([
                transforms.ToPILImage(),
                transforms.CenterCrop(crop_size),
                transforms.ToTensor()],
            )(img)
            self.assertEqual(output_pil.size()[1:3], crop_size,
                             "image_size: {} crop_size: {}".format(input_image_size, crop_size))

            output_tensor = transforms.CenterCrop(crop_size)(img)
            self.assertEqual(output_tensor.size()[1:3], crop_size,
                             "image_size: {} crop_size: {}".format(input_image_size, crop_size))

            # Ensure output for PIL and Tensor are equal
            self.assertEqual((output_tensor - output_pil).sum(), 0,
                             "image_size: {} crop_size: {}".format(input_image_size, crop_size))

            # Check if content in center of both image and cropped output is same.
            center_size = (min(crop_size[0], input_image_size[0]), min(crop_size[1], input_image_size[1]))
            crop_center_tl, input_center_tl = [0, 0], [0, 0]
            for index in range(2):
                if crop_size[index] > input_image_size[index]:
                    crop_center_tl[index] = (crop_size[index] - input_image_size[index]) // 2
                else:
                    input_center_tl[index] = (input_image_size[index] - crop_size[index]) // 2

            output_center = output_pil[
                :,
                crop_center_tl[0]:crop_center_tl[0] + center_size[0],
                crop_center_tl[1]:crop_center_tl[1] + center_size[1]
            ]
            img_center = img[
                :,
                input_center_tl[0]:input_center_tl[0] + center_size[0],
                input_center_tl[1]:input_center_tl[1] + center_size[1]
            ]
            self.assertEqual((output_center - img_center).sum(), 0,
                             "image_size: {} crop_size: {}".format(input_image_size, crop_size))

    def test_five_crop(self):
        to_pil_image = transforms.ToPILImage()
        h = random.randint(5, 25)
......
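Using the test command pattern shown in the first hunk, the new case can be run on its own, e.g. `pytest test/test_transforms.py -vvv -k test_center_crop_2`; note that `-k test_center_crop` also matches it, since pytest's `-k` does substring matching.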
@@ -451,7 +451,8 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
    """Crops the given image at the center.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
    If the image size is smaller than the output size along any edge, the image is padded with 0 and then center cropped.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
@@ -469,6 +470,18 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
    image_width, image_height = _get_image_size(img)
    crop_height, crop_width = output_size

    if crop_width > image_width or crop_height > image_height:
        padding_ltrb = [
            (crop_width - image_width) // 2 if crop_width > image_width else 0,
            (crop_height - image_height) // 2 if crop_height > image_height else 0,
            (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
            (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
        ]
        img = pad(img, padding_ltrb, fill=0)  # PIL uses fill value 0
        image_width, image_height = _get_image_size(img)

        if crop_width == image_width and crop_height == image_height:
            return img

    crop_top = int(round((image_height - crop_height) / 2.))
    crop_left = int(round((image_width - crop_width) / 2.))
    return crop(img, crop_top, crop_left, crop_height, crop_width)
......
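As an aside, here is a small illustrative sketch (a hypothetical stand-alone helper, not part of torchvision) that mirrors the padding arithmetic introduced above: when the requested crop exceeds the image by an odd number of pixels, the extra pixel of padding goes to the right/bottom side, so the padded image matches the crop size exactly.

```python
# Hypothetical helper mirroring the padding_ltrb computation added above.
def center_crop_padding(image_width, image_height, crop_height, crop_width):
    """Return [left, top, right, bottom] padding for a centered crop."""
    return [
        (crop_width - image_width) // 2 if crop_width > image_width else 0,
        (crop_height - image_height) // 2 if crop_height > image_height else 0,
        (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
        (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
    ]


# A 10x10 image cropped to 13x13 needs 3 extra pixels along each axis:
# 1 on the left/top and 2 on the right/bottom.
print(center_crop_padding(10, 10, 13, 13))  # [1, 1, 2, 2]
```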
@@ -290,7 +290,8 @@ class Scale(Resize):
class CenterCrop(torch.nn.Module):
    """Crops the given image at the center.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
    If the image size is smaller than the output size along any edge, the image is padded with 0 and then center cropped.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
......