Unverified Commit 9b804659 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Unified input for resized crop op (#2396)

* [WIP] Unify random resized crop

* Unify input for RandomResizedCrop

* Fixed bugs and updated test

* Added resized crop functional test
- fixed bug with size convention

* Fixed incoherent sampling

* Fixed torch randint review remark
parent b572d5e6
...@@ -331,6 +331,23 @@ class Tester(unittest.TestCase): ...@@ -331,6 +331,23 @@ class Tester(unittest.TestCase):
pad_tensor_script = script_fn(tensor, size=script_size, interpolation=interpolation) pad_tensor_script = script_fn(tensor, size=script_size, interpolation=interpolation)
self.assertTrue(resized_tensor.equal(pad_tensor_script), msg="{}, {}".format(size, interpolation)) self.assertTrue(resized_tensor.equal(pad_tensor_script), msg="{}, {}".format(size, interpolation))
def test_resized_crop(self):
    """Check F.resized_crop on tensors against hand-computed expectations."""
    # Case 1: crop the whole image and resize back to its own size -> identity
    # for every supported interpolation mode.
    tensor, _ = self._create_data(26, 36)
    for interp in [0, 2, 3]:
        result = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=interp)
        self.assertTrue(tensor.equal(result), msg="{} vs {}".format(result[0, :5, :5], tensor[0, :5, :5]))
    # Case 2: take the 20x30 top-left corner and halve it with nearest
    # interpolation -> equivalent to slicing with stride 2.
    tensor, _ = self._create_data(26, 36)
    result = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=0)
    expected = tensor[:, :20:2, :30:2]
    self.assertTrue(
        expected.equal(result),
        msg="{} vs {}".format(expected[0, :10, :10], result[0, :10, :10])
    )
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -245,6 +245,25 @@ class Tester(unittest.TestCase): ...@@ -245,6 +245,25 @@ class Tester(unittest.TestCase):
s_resized_tensor = script_transform(tensor) s_resized_tensor = script_transform(tensor)
self.assertTrue(s_resized_tensor.equal(resized_tensor)) self.assertTrue(s_resized_tensor.equal(resized_tensor))
def test_resized_crop(self):
    """RandomResizedCrop must give identical output in eager and scripted mode."""
    tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8)
    scale = (0.7, 1.2)
    ratio = (0.75, 1.333)
    # Exercise every accepted size spelling (1- and 2-element tuple/list)
    # and every interpolation mode.
    for size in [(32, ), [32, ], [32, 32], (32, 32)]:
        for interpolation in [NEAREST, BILINEAR, BICUBIC]:
            transform = T.RandomResizedCrop(
                size=size, scale=scale, ratio=ratio, interpolation=interpolation
            )
            scripted = torch.jit.script(transform)
            # Same seed for both calls so the random crop parameters match.
            torch.manual_seed(12)
            eager_out = transform(tensor)
            torch.manual_seed(12)
            script_out = scripted(tensor)
            self.assertTrue(eager_out.equal(script_out))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -439,24 +439,26 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor: ...@@ -439,24 +439,26 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
return crop(img, crop_top, crop_left, crop_height, crop_width) return crop(img, crop_top, crop_left, crop_height, crop_width)
def resized_crop(
        img: Tensor, top: int, left: int, height: int, width: int,
        size: List[int], interpolation: int = Image.BILINEAR
) -> Tensor:
    """Crop the given image and resize it to desired size.

    The image can be a PIL Image or a Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Notably used in :class:`~torchvision.transforms.RandomResizedCrop`.

    Args:
        img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
        top (int): Vertical component of the top left corner of the crop box.
        left (int): Horizontal component of the top left corner of the crop box.
        height (int): Height of the crop box.
        width (int): Width of the crop box.
        size (sequence or int): Desired output size. Same semantics as ``resize``.
        interpolation (int, optional): Desired interpolation. Default is ``PIL.Image.BILINEAR``.

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    # Crop first, then scale the cropped region to the requested output size.
    cropped = crop(img, top, left, height, width)
    return resize(cropped, size, interpolation)
......
...@@ -532,7 +532,7 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor: ...@@ -532,7 +532,7 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
elif len(size) < 2: elif len(size) < 2:
size_w, size_h = size[0], size[0] size_w, size_h = size[0], size[0]
else: else:
size_w, size_h = size[0], size[1] size_w, size_h = size[1], size[0] # Convention (h, w)
if isinstance(size, int) or len(size) < 2: if isinstance(size, int) or len(size) < 2:
if w < h: if w < h:
......
...@@ -687,8 +687,10 @@ class RandomPerspective(object): ...@@ -687,8 +687,10 @@ class RandomPerspective(object):
return self.__class__.__name__ + '(p={})'.format(self.p) return self.__class__.__name__ + '(p={})'.format(self.p)
class RandomResizedCrop(object): class RandomResizedCrop(torch.nn.Module):
"""Crop the given PIL Image to random size and aspect ratio. """Crop the given image to random size and aspect ratio.
The image can be a PIL Image or a Tensor, in which case it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
A crop of random size (default: of 0.08 to 1.0) of the original size and a random A crop of random size (default: of 0.08 to 1.0) of the original size and a random
aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
...@@ -696,31 +698,45 @@ class RandomResizedCrop(object): ...@@ -696,31 +698,45 @@ class RandomResizedCrop(object):
This is popularly used to train the Inception networks. This is popularly used to train the Inception networks.
Args: Args:
size: expected output size of each edge size (int or sequence): expected output size of each edge. If size is an
scale: range of size of the origin size cropped int instead of sequence like (h, w), a square output size ``(size, size)`` is
ratio: range of aspect ratio of the origin aspect ratio cropped made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).
interpolation: Default: PIL.Image.BILINEAR scale (tuple of float): range of size of the origin size cropped
ratio (tuple of float): range of aspect ratio of the origin aspect ratio cropped.
interpolation (int): Desired interpolation. Default: ``PIL.Image.BILINEAR``
""" """
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
    super().__init__()
    # Normalize ``size`` to an (h, w) pair: a bare number means a square
    # output, a 1-element sequence is duplicated, a 2-element sequence is
    # taken as-is; anything else is rejected.
    if isinstance(size, numbers.Number):
        self.size = (int(size), int(size))
    elif isinstance(size, Sequence) and len(size) == 1:
        self.size = (size[0], size[0])
    else:
        if len(size) != 2:
            raise ValueError("Please provide only two dimensions (h, w) for size.")
        self.size = size

    # Both ranges must be real sequences so torchscript can type them.
    if not isinstance(scale, (tuple, list)):
        raise TypeError("Scale should be a sequence")
    if not isinstance(ratio, (tuple, list)):
        raise TypeError("Ratio should be a sequence")
    # Only warn (don't fail) on an inverted (min, max) range.
    if scale[0] > scale[1] or ratio[0] > ratio[1]:
        warnings.warn("Scale and ratio should be of kind (min, max)")

    self.interpolation = interpolation
    self.scale = scale
    self.ratio = ratio
@staticmethod @staticmethod
def get_params(img, scale, ratio): def get_params(
img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float]
) -> Tuple[int, int, int, int]:
"""Get parameters for ``crop`` for a random sized crop. """Get parameters for ``crop`` for a random sized crop.
Args: Args:
img (PIL Image): Image to be cropped. img (PIL Image or Tensor): Input image.
scale (tuple): range of size of the origin size cropped scale (tuple): range of scale of the origin size cropped
ratio (tuple): range of aspect ratio of the origin aspect ratio cropped ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
Returns: Returns:
...@@ -731,24 +747,26 @@ class RandomResizedCrop(object): ...@@ -731,24 +747,26 @@ class RandomResizedCrop(object):
area = height * width area = height * width
for _ in range(10): for _ in range(10):
target_area = random.uniform(*scale) * area target_area = area * torch.empty(1).uniform_(*scale).item()
log_ratio = (math.log(ratio[0]), math.log(ratio[1])) log_ratio = torch.log(torch.tensor(ratio))
aspect_ratio = math.exp(random.uniform(*log_ratio)) aspect_ratio = torch.exp(
torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
).item()
w = int(round(math.sqrt(target_area * aspect_ratio))) w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio))) h = int(round(math.sqrt(target_area / aspect_ratio)))
if 0 < w <= width and 0 < h <= height: if 0 < w <= width and 0 < h <= height:
i = random.randint(0, height - h) i = torch.randint(0, height - h + 1, size=(1,)).item()
j = random.randint(0, width - w) j = torch.randint(0, width - w + 1, size=(1,)).item()
return i, j, h, w return i, j, h, w
# Fallback to central crop # Fallback to central crop
in_ratio = float(width) / float(height) in_ratio = float(width) / float(height)
if (in_ratio < min(ratio)): if in_ratio < min(ratio):
w = width w = width
h = int(round(w / min(ratio))) h = int(round(w / min(ratio)))
elif (in_ratio > max(ratio)): elif in_ratio > max(ratio):
h = height h = height
w = int(round(h * max(ratio))) w = int(round(h * max(ratio)))
else: # whole image else: # whole image
...@@ -758,13 +776,13 @@ class RandomResizedCrop(object): ...@@ -758,13 +776,13 @@ class RandomResizedCrop(object):
j = (width - w) // 2 j = (width - w) // 2
return i, j, h, w return i, j, h, w
def forward(self, img):
    """
    Args:
        img (PIL Image or Tensor): Image to be cropped and resized.

    Returns:
        PIL Image or Tensor: Randomly cropped and resized image.
    """
    # Sample a random crop box, then delegate crop+resize to the functional op.
    top, left, crop_h, crop_w = self.get_params(img, self.scale, self.ratio)
    return F.resized_crop(img, top, left, crop_h, crop_w, self.size, self.interpolation)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment