Commit 2e4a3f37 authored by Sasank Chilamkurthy's avatar Sasank Chilamkurthy Committed by Alykhan Tejani
Browse files

TenCrop and FiveCrop refactored (#273)

TenCrop and FiveCrop refactored
parent 8cd15cbe
......@@ -102,10 +102,12 @@ class Tester(unittest.TestCase):
if single_dim:
crop_h = min(crop_h, crop_w)
crop_w = crop_h
transform = transforms.TenCrop(crop_h, vflip=should_vflip)
transform = transforms.TenCrop(crop_h,
vertical_flip=should_vflip)
five_crop = transforms.FiveCrop(crop_h)
else:
transform = transforms.TenCrop((crop_h, crop_w), vflip=should_vflip)
transform = transforms.TenCrop((crop_h, crop_w),
vertical_flip=should_vflip)
five_crop = transforms.FiveCrop((crop_h, crop_w))
img = to_pil_image(torch.FloatTensor(3, h, w).uniform_())
......
......@@ -281,6 +281,73 @@ def vflip(img):
return img.transpose(Image.FLIP_TOP_BOTTOM)
def five_crop(img, size):
"""Crop the given PIL.Image into four corners and the central crop.
Note: this transform returns a tuple of images and there may be a mismatch in the number of
inputs and targets your `Dataset` returns.
Args:
size (sequence or int): Desired output size of the crop. If size is an
int instead of sequence like (h, w), a square crop (size, size) is
made.
Returns:
tuple: tuple (tl, tr, bl, br, center) corresponding top left,
top right, bottom left, bottom right and center crop.
"""
if isinstance(size, numbers.Number):
size = (int(size), int(size))
else:
assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
w, h = img.size
crop_h, crop_w = size
if crop_w > w or crop_h > h:
raise ValueError("Requested crop size {} is bigger than input size {}".format(size,
(h, w)))
tl = img.crop((0, 0, crop_w, crop_h))
tr = img.crop((w - crop_w, 0, w, crop_h))
bl = img.crop((0, h - crop_h, crop_w, h))
br = img.crop((w - crop_w, h - crop_h, w, h))
center = CenterCrop((crop_h, crop_w))(img)
return (tl, tr, bl, br, center)
def ten_crop(img, size, vertical_flip=False):
"""Crop the given PIL.Image into four corners and the central crop plus the
flipped version of these (horizontal flipping is used by default).
Note: this transform returns a tuple of images and there may be a mismatch in the number of
inputs and targets your `Dataset` returns.
Args:
size (sequence or int): Desired output size of the crop. If size is an
int instead of sequence like (h, w), a square crop (size, size) is
made.
vertical_flip (bool): Use vertical flipping instead of horizontal
Returns:
tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip,
br_flip, center_flip) corresponding top left, top right,
bottom left, bottom right and center crop and same for the
flipped image.
"""
if isinstance(size, numbers.Number):
size = (int(size), int(size))
else:
assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
first_five = five_crop(img, size)
if vertical_flip:
img = vflip(img)
else:
img = hflip(img)
second_five = five_crop(img, size)
return first_five + second_five
class Compose(object):
"""Composes several transforms together.
......@@ -661,17 +728,7 @@ class FiveCrop(object):
self.size = size
def __call__(self, img):
w, h = img.size
crop_h, crop_w = self.size
if crop_w > w or crop_h > h:
raise ValueError("Requested crop size {} is bigger than input size {}".format(self.size,
(h, w)))
tl = img.crop((0, 0, crop_w, crop_h))
tr = img.crop((w - crop_w, 0, w, crop_h))
bl = img.crop((0, h - crop_h, crop_w, h))
br = img.crop((w - crop_w, h - crop_h, w, h))
center = CenterCrop((crop_h, crop_w))(img)
return (tl, tr, bl, br, center)
return five_crop(img, self.size)
class TenCrop(object):
......@@ -685,25 +742,17 @@ class TenCrop(object):
size (sequence or int): Desired output size of the crop. If size is an
int instead of sequence like (h, w), a square crop (size, size) is
made.
vflip bool: Use vertical flipping instead of horizontal
vertical_flip(bool): Use vertical flipping instead of horizontal
"""
def __init__(self, size, vflip=False):
def __init__(self, size, vertical_flip=False):
self.size = size
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
self.size = size
self.vflip = vflip
self.vertical_flip = vertical_flip
def __call__(self, img):
five_crop = FiveCrop(self.size)
first_five = five_crop(img)
if self.vflip:
img = img.transpose(Image.FLIP_TOP_BOTTOM)
else:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
second_five = five_crop(img)
return first_five + second_five
return ten_crop(img, self.size, self.vertical_flip)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment