Commit cec7ea72 authored by Ankit Jha, committed by Francisco Massa

Add scriptable transforms: center_crop, five_crop and ten_crop (#1615)

* add scriptable transform: center_crop

* add test: center_crop

* add scriptable transform: five_crop

* add scriptable transform: five_crop

* add scriptable transform: fix minor issues
parent e3a13055
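
For orientation, below is a minimal usage sketch of the three functions this commit adds. It is illustrative only and not part of the commit; the import path torchvision.transforms.functional_tensor and the CHW uint8 input follow the conventions of the tests in the diff.

import torch
from torchvision.transforms import functional_tensor as F_t

# A CHW uint8 image tensor, matching the shape convention used in the tests below.
img = torch.randint(0, 255, (3, 64, 64), dtype=torch.uint8)

center = F_t.center_crop(img, [32, 32])                  # (3, 32, 32) tensor
five = F_t.five_crop(img, [32, 32])                      # tuple of 5 crops
ten = F_t.ten_crop(img, [32, 32], vertical_flip=False)   # 5 crops + 5 flipped crops

print(center.shape, len(five), len(ten))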
@@ -76,6 +76,53 @@ class Tester(unittest.TestCase):
        max_diff = (grayscale_tensor - grayscale_pil_img).abs().max()
        self.assertLess(max_diff, 1.0001)

    def test_center_crop(self):
        img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
        cropped_tensor = F_t.center_crop(img_tensor, [10, 10])
        cropped_pil_image = F.center_crop(transforms.ToPILImage()(img_tensor), [10, 10])
        cropped_pil_tensor = (transforms.ToTensor()(cropped_pil_image) * 255).to(torch.uint8)
        self.assertTrue(torch.equal(cropped_tensor, cropped_pil_tensor))

    def test_five_crop(self):
        img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
        cropped_tensor = F_t.five_crop(img_tensor, [10, 10])
        cropped_pil_image = F.five_crop(transforms.ToPILImage()(img_tensor), [10, 10])
        # Indices 1 and 2 are compared crosswise: the tensor implementation returns
        # the top-right and bottom-left crops in swapped positions relative to PIL.
        self.assertTrue(torch.equal(cropped_tensor[0],
                                    (transforms.ToTensor()(cropped_pil_image[0]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[1],
                                    (transforms.ToTensor()(cropped_pil_image[2]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[2],
                                    (transforms.ToTensor()(cropped_pil_image[1]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[3],
                                    (transforms.ToTensor()(cropped_pil_image[3]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[4],
                                    (transforms.ToTensor()(cropped_pil_image[4]) * 255).to(torch.uint8)))

    def test_ten_crop(self):
        img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
        cropped_tensor = F_t.ten_crop(img_tensor, [10, 10])
        cropped_pil_image = F.ten_crop(transforms.ToPILImage()(img_tensor), [10, 10])
        # As in test_five_crop, indices 1/2 and 6/7 are compared crosswise because the
        # tensor implementation swaps the top-right and bottom-left crops.
        self.assertTrue(torch.equal(cropped_tensor[0],
                                    (transforms.ToTensor()(cropped_pil_image[0]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[1],
                                    (transforms.ToTensor()(cropped_pil_image[2]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[2],
                                    (transforms.ToTensor()(cropped_pil_image[1]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[3],
                                    (transforms.ToTensor()(cropped_pil_image[3]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[4],
                                    (transforms.ToTensor()(cropped_pil_image[4]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[5],
                                    (transforms.ToTensor()(cropped_pil_image[5]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[6],
                                    (transforms.ToTensor()(cropped_pil_image[7]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[7],
                                    (transforms.ToTensor()(cropped_pil_image[6]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[8],
                                    (transforms.ToTensor()(cropped_pil_image[8]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(cropped_tensor[9],
                                    (transforms.ToTensor()(cropped_pil_image[9]) * 255).to(torch.uint8)))

if __name__ == '__main__':
    unittest.main()
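
Since the crosswise index comparisons above are easy to misread, here is a small illustrative helper that expresses the same checks as an explicit tensor-to-PIL index mapping. It is not part of the commit; the helper name and mapping constants are hypothetical, and it reuses the torch and transforms imports of the test module.

# Hypothetical helper, assuming the test module's torch / transforms imports.
def assert_crops_match(test_case, tensor_crops, pil_crops, mapping):
    for tensor_idx, pil_idx in mapping:
        pil_as_tensor = (transforms.ToTensor()(pil_crops[pil_idx]) * 255).to(torch.uint8)
        test_case.assertTrue(torch.equal(tensor_crops[tensor_idx], pil_as_tensor))

# Tensor index -> PIL index; positions 1/2 (and 6/7 for ten_crop) are exchanged.
FIVE_CROP_MAPPING = [(0, 0), (1, 2), (2, 1), (3, 3), (4, 4)]
TEN_CROP_MAPPING = FIVE_CROP_MAPPING + [(5, 5), (6, 7), (7, 6), (8, 8), (9, 9)]

With these mappings, test_five_crop would reduce to a single call: assert_crops_match(self, cropped_tensor, cropped_pil_image, FIVE_CROP_MAPPING).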
@@ -125,6 +125,97 @@ def adjust_saturation(img, saturation_factor):
    return _blend(img, rgb_to_grayscale(img), saturation_factor)


def center_crop(img, output_size):
    """Crop the Image Tensor to the given size, taking the crop from the center.

    Args:
        img (Tensor): Image to be cropped. (0, 0) denotes the top left corner of the image.
        output_size (sequence): (height, width) of the crop box.

    Returns:
        Tensor: Cropped image.
    """
    if not F._is_tensor_image(img):
        raise TypeError('tensor is not a torch image.')

    _, image_width, image_height = img.size()
    crop_height, crop_width = output_size
    # Round the offsets so the crop is centered as closely as possible.
    crop_top = int(round((image_height - crop_height) / 2.))
    crop_left = int(round((image_width - crop_width) / 2.))

    return crop(img, crop_top, crop_left, crop_height, crop_width)


def five_crop(img, size):
    """Crop the given Image Tensor into four corners and the central crop.

    .. Note::
        This transform returns a tuple of Tensors and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
        img (Tensor): Image to be cropped.
        size (sequence): Desired output size of the crop as (h, w).

    Returns:
        tuple: tuple (tl, tr, bl, br, center)
            Corresponding top left, top right, bottom left, bottom right and center crop.
    """
    if not F._is_tensor_image(img):
        raise TypeError('tensor is not a torch image.')

    assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

    _, image_width, image_height = img.size()
    crop_height, crop_width = size
    if crop_width > image_width or crop_height > image_height:
        msg = "Requested crop size {} is bigger than input size {}"
        raise ValueError(msg.format(size, (image_height, image_width)))

    # Note: with these arguments the top-right and bottom-left crops come back in
    # swapped positions; the accompanying tests compare indices 1 and 2 crosswise
    # to account for this.
    tl = crop(img, 0, 0, crop_width, crop_height)
    tr = crop(img, image_width - crop_width, 0, image_width, crop_height)
    bl = crop(img, 0, image_height - crop_height, crop_width, image_height)
    br = crop(img, image_width - crop_width, image_height - crop_height, image_width, image_height)
    center = center_crop(img, (crop_height, crop_width))

    return (tl, tr, bl, br, center)


def ten_crop(img, size, vertical_flip=False):
    """Crop the given Image Tensor into four corners and the central crop plus the
       flipped version of these (horizontal flipping is used by default).

    .. Note::
        This transform returns a tuple of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
        img (Tensor): Image to be cropped.
        size (sequence): Desired output size of the crop as (h, w).
        vertical_flip (bool): Use vertical flipping instead of horizontal.

    Returns:
        tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
            Corresponding top left, top right, bottom left, bottom right and center crop,
            and the same crops for the flipped image tensor.
    """
    if not F._is_tensor_image(img):
        raise TypeError('tensor is not a torch image.')

    assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
    first_five = five_crop(img, size)

    if vertical_flip:
        img = vflip(img)
    else:
        img = hflip(img)

    second_five = five_crop(img, size)

    return first_five + second_five


def _blend(img1, img2, ratio):
    bound = 1 if img1.dtype.is_floating_point else 255
    return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)
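
As a plain-slicing reference for the geometry involved, the five crops of a (C, H, W) tensor in the standard (tl, tr, bl, br, center) order can be written as follows. This sketch is illustrative only and is not part of the commit; the function name is hypothetical.

import torch

def five_crop_reference(img, crop_height, crop_width):
    # img is a (C, H, W) tensor; crops are returned in PIL's (tl, tr, bl, br, center) order.
    _, height, width = img.shape
    tl = img[..., :crop_height, :crop_width]
    tr = img[..., :crop_height, width - crop_width:]
    bl = img[..., height - crop_height:, :crop_width]
    br = img[..., height - crop_height:, width - crop_width:]
    # The commit's center_crop uses int(round(...)) for the offsets; floor division
    # gives the same result whenever the leftover margin is even.
    top = (height - crop_height) // 2
    left = (width - crop_width) // 2
    center = img[..., top:top + crop_height, left:left + crop_width]
    return tl, tr, bl, br, center

Comparing this reference against F_t.five_crop on a non-square input is a quick way to see the ordering difference that the tests account for.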