Commit 8cd15cbe authored by Alykhan Tejani, committed by GitHub
Browse files

Add FiveCrop and TenCrop transforms (#261)

add FiveCrop and TenCrop
parent a5b75c8a
...@@ -61,6 +61,66 @@ class Tester(unittest.TestCase): ...@@ -61,6 +61,66 @@ class Tester(unittest.TestCase):
assert sum2 > sum1, "height: " + str(height) + " width: " \ assert sum2 > sum1, "height: " + str(height) + " width: " \
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
def test_five_crop(self):
    """Check FiveCrop against hand-sliced corner crops plus CenterCrop.

    A randomly sized random image is cropped twice: once with an int
    (square) size and once with an explicit (h, w) size.
    """
    pil_from_tensor = transforms.ToPILImage()
    height = random.randint(5, 25)
    width = random.randint(5, 25)
    for use_int_size in (True, False):
        ch = random.randint(1, height)
        cw = random.randint(1, width)
        if use_int_size:
            # A single int requests a square crop; taking the smaller draw
            # keeps it within both image dimensions.
            ch = cw = min(ch, cw)
            size = ch
        else:
            size = (ch, cw)
        five_crop = transforms.FiveCrop(size)

        tensor = torch.FloatTensor(3, height, width).uniform_()
        crops = five_crop(pil_from_tensor(tensor))
        assert len(crops) == 5
        for piece in crops:
            # PIL reports size as (width, height).
            assert piece.size == (cw, ch)

        # Reference corners: slice the (3, height, width) tensor directly and
        # convert each slice, in the order tl, tr, bl, br.
        top, bottom = slice(0, ch), slice(height - ch, None)
        left, right = slice(0, cw), slice(width - cw, None)
        corners = tuple(
            pil_from_tensor(tensor[:, rows, cols])
            for rows in (top, bottom)
            for cols in (left, right)
        )
        center = transforms.CenterCrop((ch, cw))(pil_from_tensor(tensor))
        assert crops == corners + (center,)
def test_ten_crop(self):
    """Check TenCrop equals FiveCrop of the image plus FiveCrop of its flip.

    Every combination of vertical/horizontal flipping and int/(h, w) crop
    size is exercised on a randomly sized random image.
    """
    pil_from_tensor = transforms.ToPILImage()
    height = random.randint(5, 25)
    width = random.randint(5, 25)
    for vflip in (True, False):
        for use_int_size in (True, False):
            ch = random.randint(1, height)
            cw = random.randint(1, width)
            if use_int_size:
                # A single int requests a square crop; taking the smaller
                # draw keeps it within both image dimensions.
                ch = cw = min(ch, cw)
                size = ch
            else:
                size = (ch, cw)
            ten_crop = transforms.TenCrop(size, vflip=vflip)
            reference = transforms.FiveCrop(size)

            image = pil_from_tensor(torch.FloatTensor(3, height, width).uniform_())
            crops = ten_crop(image)

            # Expected: five crops of the image followed by five crops of the
            # flipped image (tuple concatenation).
            flip_method = Image.FLIP_TOP_BOTTOM if vflip else Image.FLIP_LEFT_RIGHT
            expected = reference(image) + reference(image.transpose(flip_method))
            assert len(crops) == 10
            assert expected == crops
def test_scale(self): def test_scale(self):
height = random.randint(24, 32) * 2 height = random.randint(24, 32) * 2
width = random.randint(24, 32) * 2 width = random.randint(24, 32) * 2
......
...@@ -638,3 +638,72 @@ class RandomSizedCrop(object): ...@@ -638,3 +638,72 @@ class RandomSizedCrop(object):
""" """
i, j, h, w = self.get_params(img) i, j, h, w = self.get_params(img)
return scaled_crop(img, i, j, h, w, self.size, self.interpolation) return scaled_crop(img, i, j, h, w, self.size, self.interpolation)
class FiveCrop(object):
    """Crop the given PIL.Image into four corners and the central crop.

    Note: this transform returns a tuple of images and there may be a mismatch
    in the number of inputs and targets your `Dataset` returns.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            # A single number means a square crop.
            self.size = (int(size), int(size))
        else:
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
            self.size = size

    def __call__(self, img):
        """Return the five crops of ``img``.

        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            tuple: ``(tl, tr, bl, br, center)`` crops, each of the size given
            at construction.

        Raises:
            ValueError: If the requested crop is larger than ``img`` in either
                dimension.
        """
        # PIL reports size as (width, height); self.size is (height, width).
        w, h = img.size
        crop_h, crop_w = self.size
        if crop_w > w or crop_h > h:
            raise ValueError(
                "Requested crop size {} is bigger than input size {}".format(self.size, (h, w))
            )
        # PIL crop boxes are (left, upper, right, lower).
        tl = img.crop((0, 0, crop_w, crop_h))
        tr = img.crop((w - crop_w, 0, w, crop_h))
        bl = img.crop((0, h - crop_h, crop_w, h))
        br = img.crop((w - crop_w, h - crop_h, w, h))
        center = CenterCrop((crop_h, crop_w))(img)
        return (tl, tr, bl, br, center)

    def __repr__(self):
        return "{}(size={})".format(self.__class__.__name__, self.size)
class TenCrop(object):
    """Crop the given PIL.Image into four corners and the central crop plus
    the flipped version of these (horizontal flipping is used by default).

    Note: this transform returns a tuple of images and there may be a mismatch
    in the number of inputs and targets your `Dataset` returns.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        vflip (bool): Use vertical flipping instead of horizontal.
    """

    def __init__(self, size, vflip=False):
        if isinstance(size, numbers.Number):
            # A single number means a square crop.
            self.size = (int(size), int(size))
        else:
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
            self.size = size
        self.vflip = vflip

    def __call__(self, img):
        """Return ten crops of ``img``.

        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            tuple: The five crops of ``img`` followed by the five crops of the
            flipped ``img`` (flipped top-bottom when ``vflip`` is true,
            left-right otherwise).
        """
        five_crop = FiveCrop(self.size)
        first_five = five_crop(img)
        if self.vflip:
            flipped = img.transpose(Image.FLIP_TOP_BOTTOM)
        else:
            flipped = img.transpose(Image.FLIP_LEFT_RIGHT)
        second_five = five_crop(flipped)
        return first_five + second_five

    def __repr__(self):
        return "{}(size={}, vflip={})".format(
            self.__class__.__name__, self.size, self.vflip
        )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment