Commit 2e4a3f37 authored by Sasank Chilamkurthy's avatar Sasank Chilamkurthy Committed by Alykhan Tejani
Browse files

TenCrop and FiveCrop refactored (#273)

TenCrop and FiveCrop refactored
parent 8cd15cbe
...@@ -102,10 +102,12 @@ class Tester(unittest.TestCase): ...@@ -102,10 +102,12 @@ class Tester(unittest.TestCase):
if single_dim: if single_dim:
crop_h = min(crop_h, crop_w) crop_h = min(crop_h, crop_w)
crop_w = crop_h crop_w = crop_h
transform = transforms.TenCrop(crop_h, vflip=should_vflip) transform = transforms.TenCrop(crop_h,
vertical_flip=should_vflip)
five_crop = transforms.FiveCrop(crop_h) five_crop = transforms.FiveCrop(crop_h)
else: else:
transform = transforms.TenCrop((crop_h, crop_w), vflip=should_vflip) transform = transforms.TenCrop((crop_h, crop_w),
vertical_flip=should_vflip)
five_crop = transforms.FiveCrop((crop_h, crop_w)) five_crop = transforms.FiveCrop((crop_h, crop_w))
img = to_pil_image(torch.FloatTensor(3, h, w).uniform_()) img = to_pil_image(torch.FloatTensor(3, h, w).uniform_())
......
...@@ -281,6 +281,73 @@ def vflip(img): ...@@ -281,6 +281,73 @@ def vflip(img):
return img.transpose(Image.FLIP_TOP_BOTTOM) return img.transpose(Image.FLIP_TOP_BOTTOM)
def five_crop(img, size):
"""Crop the given PIL.Image into four corners and the central crop.
Note: this transform returns a tuple of images and there may be a mismatch in the number of
inputs and targets your `Dataset` returns.
Args:
size (sequence or int): Desired output size of the crop. If size is an
int instead of sequence like (h, w), a square crop (size, size) is
made.
Returns:
tuple: tuple (tl, tr, bl, br, center) corresponding top left,
top right, bottom left, bottom right and center crop.
"""
if isinstance(size, numbers.Number):
size = (int(size), int(size))
else:
assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
w, h = img.size
crop_h, crop_w = size
if crop_w > w or crop_h > h:
raise ValueError("Requested crop size {} is bigger than input size {}".format(size,
(h, w)))
tl = img.crop((0, 0, crop_w, crop_h))
tr = img.crop((w - crop_w, 0, w, crop_h))
bl = img.crop((0, h - crop_h, crop_w, h))
br = img.crop((w - crop_w, h - crop_h, w, h))
center = CenterCrop((crop_h, crop_w))(img)
return (tl, tr, bl, br, center)
def ten_crop(img, size, vertical_flip=False):
"""Crop the given PIL.Image into four corners and the central crop plus the
flipped version of these (horizontal flipping is used by default).
Note: this transform returns a tuple of images and there may be a mismatch in the number of
inputs and targets your `Dataset` returns.
Args:
size (sequence or int): Desired output size of the crop. If size is an
int instead of sequence like (h, w), a square crop (size, size) is
made.
vertical_flip (bool): Use vertical flipping instead of horizontal
Returns:
tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip,
br_flip, center_flip) corresponding top left, top right,
bottom left, bottom right and center crop and same for the
flipped image.
"""
if isinstance(size, numbers.Number):
size = (int(size), int(size))
else:
assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
first_five = five_crop(img, size)
if vertical_flip:
img = vflip(img)
else:
img = hflip(img)
second_five = five_crop(img, size)
return first_five + second_five
class Compose(object): class Compose(object):
"""Composes several transforms together. """Composes several transforms together.
...@@ -661,17 +728,7 @@ class FiveCrop(object): ...@@ -661,17 +728,7 @@ class FiveCrop(object):
self.size = size self.size = size
def __call__(self, img): def __call__(self, img):
w, h = img.size return five_crop(img, self.size)
crop_h, crop_w = self.size
if crop_w > w or crop_h > h:
raise ValueError("Requested crop size {} is bigger than input size {}".format(self.size,
(h, w)))
tl = img.crop((0, 0, crop_w, crop_h))
tr = img.crop((w - crop_w, 0, w, crop_h))
bl = img.crop((0, h - crop_h, crop_w, h))
br = img.crop((w - crop_w, h - crop_h, w, h))
center = CenterCrop((crop_h, crop_w))(img)
return (tl, tr, bl, br, center)
class TenCrop(object): class TenCrop(object):
...@@ -685,25 +742,17 @@ class TenCrop(object): ...@@ -685,25 +742,17 @@ class TenCrop(object):
size (sequence or int): Desired output size of the crop. If size is an size (sequence or int): Desired output size of the crop. If size is an
int instead of sequence like (h, w), a square crop (size, size) is int instead of sequence like (h, w), a square crop (size, size) is
made. made.
vflip bool: Use vertical flipping instead of horizontal vertical_flip(bool): Use vertical flipping instead of horizontal
""" """
def __init__(self, size, vflip=False): def __init__(self, size, vertical_flip=False):
self.size = size self.size = size
if isinstance(size, numbers.Number): if isinstance(size, numbers.Number):
self.size = (int(size), int(size)) self.size = (int(size), int(size))
else: else:
assert len(size) == 2, "Please provide only two dimensions (h, w) for size." assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
self.size = size self.size = size
self.vflip = vflip self.vertical_flip = vertical_flip
def __call__(self, img): def __call__(self, img):
five_crop = FiveCrop(self.size) return ten_crop(img, self.size, self.vertical_flip)
first_five = five_crop(img)
if self.vflip:
img = img.transpose(Image.FLIP_TOP_BOTTOM)
else:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
second_five = five_crop(img)
return first_five + second_five
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment