Unverified Commit a99b6bd7 authored by vfdev, committed by GitHub

Unified Tensor/PIL crop (#2342)

* [WIP] Unified Tensor/PIL crop

* Fixed misplaced type annotation

* Fixed tests
- crop with padding
- other tests using missing private functions: _is_pil_image, _get_image_size

* Unified CenterCrop and F.center_crop
- sorted includes in transforms.py
- used py3 annotations

* Unified FiveCrop and F.five_crop

* Improved tests and docs

* Unified TenCrop and F.ten_crop

* Removed useless typing in functional_pil

* Updated code according to the review
- removed useless torch.jit.export
- added missing typing return type
- fixed F.F_pil._is_pil_image -> F._is_pil_image

* Removed useless torch.jit.export

* Improved code according to the review
parent 446eac61
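
The commit unifies each crop operation behind a single functional entry point that dispatches on the input type: PIL Images are routed to the F_pil backend and tensors to the F_t backend, as the diff below shows for crop, center_crop, five_crop and ten_crop. As a rough orientation, here is a minimal standalone sketch of that dispatch pattern in Python; the helper names are illustrative only and are not torchvision's internal API.

import torch
from torch import Tensor
from PIL import Image


def _crop_tensor(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    # Tensor backend: index the last two dimensions, so [..., H, W] inputs work.
    return img[..., top:top + height, left:left + width]


def _crop_pil(img: Image.Image, top: int, left: int, height: int, width: int) -> Image.Image:
    # PIL backend: Image.crop takes a (left, upper, right, lower) box.
    return img.crop((left, top, left + width, top + height))


def crop(img, top: int, left: int, height: int, width: int):
    # Single public entry point: pick the backend from the input type,
    # mirroring how F.crop routes to F_pil.crop / F_t.crop in this change.
    if isinstance(img, torch.Tensor):
        return _crop_tensor(img, top, left, height, width)
    return _crop_pil(img, top, left, height, width)
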
@@ -99,6 +99,120 @@ class Tester(unittest.TestCase):
             "pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
         )

+    def test_crop(self):
+        fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5}
+        # Test transforms.RandomCrop with size and padding as tuple
+        meth_kwargs = {"size": (4, 5), "padding": (4, 4), "pad_if_needed": True, }
+        self._test_geom_op(
+            'crop', 'RandomCrop', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
+        )
+
+        tensor = torch.randint(0, 255, (3, 10, 10), dtype=torch.uint8)
+        # Test torchscript of transforms.RandomCrop with size as int
+        f = T.RandomCrop(size=5)
+        scripted_fn = torch.jit.script(f)
+        scripted_fn(tensor)
+
+        # Test torchscript of transforms.RandomCrop with size as [int, ]
+        f = T.RandomCrop(size=[5, ], padding=[2, ])
+        scripted_fn = torch.jit.script(f)
+        scripted_fn(tensor)
+
+        # Test torchscript of transforms.RandomCrop with size as list
+        f = T.RandomCrop(size=[6, 6])
+        scripted_fn = torch.jit.script(f)
+        scripted_fn(tensor)
+
+    def test_center_crop(self):
+        fn_kwargs = {"output_size": (4, 5)}
+        meth_kwargs = {"size": (4, 5), }
+        self._test_geom_op(
+            "center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
+        )
+        fn_kwargs = {"output_size": (5,)}
+        meth_kwargs = {"size": (5, )}
+        self._test_geom_op(
+            "center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
+        )
+        tensor = torch.randint(0, 255, (3, 10, 10), dtype=torch.uint8)
+        # Test torchscript of transforms.CenterCrop with size as int
+        f = T.CenterCrop(size=5)
+        scripted_fn = torch.jit.script(f)
+        scripted_fn(tensor)
+
+        # Test torchscript of transforms.CenterCrop with size as [int, ]
+        f = T.CenterCrop(size=[5, ])
+        scripted_fn = torch.jit.script(f)
+        scripted_fn(tensor)
+
+        # Test torchscript of transforms.CenterCrop with size as tuple
+        f = T.CenterCrop(size=(6, 6))
+        scripted_fn = torch.jit.script(f)
+        scripted_fn(tensor)
+
+    def _test_geom_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None):
+        if fn_kwargs is None:
+            fn_kwargs = {}
+        if meth_kwargs is None:
+            meth_kwargs = {}
+        tensor, pil_img = self._create_data(height=20, width=20)
+        transformed_t_list = getattr(F, func)(tensor, **fn_kwargs)
+        transformed_p_list = getattr(F, func)(pil_img, **fn_kwargs)
+        self.assertEqual(len(transformed_t_list), len(transformed_p_list))
+        self.assertEqual(len(transformed_t_list), out_length)
+        for transformed_tensor, transformed_pil_img in zip(transformed_t_list, transformed_p_list):
+            self.compareTensorToPIL(transformed_tensor, transformed_pil_img)
+
+        scripted_fn = torch.jit.script(getattr(F, func))
+        transformed_t_list_script = scripted_fn(tensor.detach().clone(), **fn_kwargs)
+        self.assertEqual(len(transformed_t_list), len(transformed_t_list_script))
+        self.assertEqual(len(transformed_t_list_script), out_length)
+        for transformed_tensor, transformed_tensor_script in zip(transformed_t_list, transformed_t_list_script):
+            self.assertTrue(transformed_tensor.equal(transformed_tensor_script),
+                            msg="{} vs {}".format(transformed_tensor, transformed_tensor_script))
+
+        # test for class interface
+        f = getattr(T, method)(**meth_kwargs)
+        scripted_fn = torch.jit.script(f)
+        output = scripted_fn(tensor)
+        self.assertEqual(len(output), len(transformed_t_list_script))
+
+    def test_five_crop(self):
+        fn_kwargs = meth_kwargs = {"size": (5,)}
+        self._test_geom_op_list_output(
+            "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
+        )
+        fn_kwargs = meth_kwargs = {"size": [5, ]}
+        self._test_geom_op_list_output(
+            "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
+        )
+        fn_kwargs = meth_kwargs = {"size": (4, 5)}
+        self._test_geom_op_list_output(
+            "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
+        )
+        fn_kwargs = meth_kwargs = {"size": [4, 5]}
+        self._test_geom_op_list_output(
+            "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
+        )
+
+    def test_ten_crop(self):
+        fn_kwargs = meth_kwargs = {"size": (5,)}
+        self._test_geom_op_list_output(
+            "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
+        )
+        fn_kwargs = meth_kwargs = {"size": [5, ]}
+        self._test_geom_op_list_output(
+            "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
+        )
+        fn_kwargs = meth_kwargs = {"size": (4, 5)}
+        self._test_geom_op_list_output(
+            "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
+        )
+        fn_kwargs = meth_kwargs = {"size": [4, 5]}
+        self._test_geom_op_list_output(
+            "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
+        )
+

 if __name__ == '__main__':
     unittest.main()
@@ -2,6 +2,7 @@ import math
 import numbers
 import warnings
 from collections.abc import Iterable
+from typing import Any

 import numpy as np
 from numpy import sin, cos, tan
@@ -9,7 +10,7 @@ from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION

 import torch
 from torch import Tensor
-from torch.jit.annotations import List
+from torch.jit.annotations import List, Tuple

 try:
     import accimage
@@ -20,18 +21,25 @@ from . import functional_pil as F_pil
 from . import functional_tensor as F_t

-def _is_pil_image(img):
-    if accimage is not None:
-        return isinstance(img, (Image.Image, accimage.Image))
-    else:
-        return isinstance(img, Image.Image)
+_is_pil_image = F_pil._is_pil_image
+
+
+def _get_image_size(img: Tensor) -> List[int]:
+    """Returns the image size as (w, h)
+    """
+    if isinstance(img, torch.Tensor):
+        return F_t._get_image_size(img)
+
+    return F_pil._get_image_size(img)


-def _is_numpy(img):
+@torch.jit.unused
+def _is_numpy(img: Any) -> bool:
     return isinstance(img, np.ndarray)


-def _is_numpy_image(img):
+@torch.jit.unused
+def _is_numpy_image(img: Any) -> bool:
     return img.ndim in {2, 3}
@@ -46,7 +54,7 @@ def to_tensor(pic):
     Returns:
         Tensor: Converted image.
     """
-    if not(_is_pil_image(pic) or _is_numpy(pic)):
+    if not(F_pil._is_pil_image(pic) or _is_numpy(pic)):
         raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

     if _is_numpy(pic) and not _is_numpy_image(pic):
@@ -101,7 +109,7 @@ def pil_to_tensor(pic):
     Returns:
         Tensor: Converted image.
     """
-    if not(_is_pil_image(pic)):
+    if not(F_pil._is_pil_image(pic)):
         raise TypeError('pic should be PIL Image. Got {}'.format(type(pic)))

     if accimage is not None and isinstance(pic, accimage.Image):
@@ -319,7 +327,7 @@ def resize(img, size, interpolation=Image.BILINEAR):
     Returns:
         PIL Image: Resized image.
     """
-    if not _is_pil_image(img):
+    if not F_pil._is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
     if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)):
         raise TypeError('Got inappropriate size arg: {}'.format(size))
@@ -388,41 +396,58 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
     return F_t.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)


-def crop(img, top, left, height, width):
-    """Crop the given PIL Image.
+def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
+    """Crop the given image at specified location and output size.
+    The image can be a PIL Image or a Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading
+    dimensions

     Args:
-        img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image.
+        img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
         top (int): Vertical component of the top left corner of the crop box.
         left (int): Horizontal component of the top left corner of the crop box.
         height (int): Height of the crop box.
         width (int): Width of the crop box.

     Returns:
-        PIL Image: Cropped image.
+        PIL Image or Tensor: Cropped image.
     """
-    if not _is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
-
-    return img.crop((left, top, left + width, top + height))
+    if not isinstance(img, torch.Tensor):
+        return F_pil.crop(img, top, left, height, width)
+
+    return F_t.crop(img, top, left, height, width)


-def center_crop(img, output_size):
-    """Crop the given PIL Image and resize it to desired size.
+def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
+    """Crops the given image at the center.
+    The image can be a PIL Image or a Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

     Args:
-        img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image.
-        output_size (sequence or int): (height, width) of the crop box. If int,
-            it is used for both directions
+        img (PIL Image or Tensor): Image to be cropped.
+        output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int
+            it is used for both directions.

     Returns:
-        PIL Image: Cropped image.
+        PIL Image or Tensor: Cropped image.
     """
     if isinstance(output_size, numbers.Number):
         output_size = (int(output_size), int(output_size))
-
-    image_width, image_height = img.size
+    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
+        output_size = (output_size[0], output_size[0])
+
+    image_width, image_height = _get_image_size(img)
     crop_height, crop_width = output_size
-    crop_top = int(round((image_height - crop_height) / 2.))
-    crop_left = int(round((image_width - crop_width) / 2.))
+    # crop_top = int(round((image_height - crop_height) / 2.))
+    # Result can be different between python func and scripted func
+    # Temporary workaround:
+    crop_top = int((image_height - crop_height + 1) * 0.5)
+    # crop_left = int(round((image_width - crop_width) / 2.))
+    # Result can be different between python func and scripted func
+    # Temporary workaround:
+    crop_left = int((image_width - crop_width + 1) * 0.5)
     return crop(img, crop_top, crop_left, crop_height, crop_width)
@@ -443,23 +468,23 @@ def resized_crop(img, top, left, height, width, size, interpolation=Image.BILINE
     Returns:
         PIL Image: Cropped image.
     """
-    assert _is_pil_image(img), 'img should be PIL Image'
+    assert F_pil._is_pil_image(img), 'img should be PIL Image'
     img = crop(img, top, left, height, width)
     img = resize(img, size, interpolation)
     return img


 def hflip(img: Tensor) -> Tensor:
-    """Horizontally flip the given PIL Image or torch Tensor.
+    """Horizontally flip the given PIL Image or Tensor.

     Args:
-        img (PIL Image or Torch Tensor): Image to be flipped. If img
+        img (PIL Image or Tensor): Image to be flipped. If img
             is a Tensor, it is expected to be in [..., H, W] format,
             where ... means it can have an arbitrary number of trailing
             dimensions.

     Returns:
-        PIL Image: Horizontally flipped image.
+        PIL Image or Tensor: Horizontally flipped image.
     """
     if not isinstance(img, torch.Tensor):
         return F_pil.hflip(img)
@@ -512,8 +537,7 @@ def _get_perspective_coeffs(startpoints, endpoints):
     Args:
         List containing [top-left, top-right, bottom-right, bottom-left] of the original image,
-        List containing [top-left, top-right, bottom-right, bottom-left] of the transformed
-        image
+        List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image
     Returns:
         octuple (a, b, c, d, e, f, g, h) for transforming each pixel.
     """
@@ -545,7 +569,7 @@ def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC, fill=N
         PIL Image: Perspectively transformed Image.
     """
-    if not _is_pil_image(img):
+    if not F_pil._is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

     opts = _parse_fill(fill, img, '5.0.0')
@@ -558,7 +582,7 @@ def vflip(img: Tensor) -> Tensor:
     """Vertically flip the given PIL Image or torch Tensor.

     Args:
-        img (PIL Image or Torch Tensor): Image to be flipped. If img
+        img (PIL Image or Tensor): Image to be flipped. If img
             is a Tensor, it is expected to be in [..., H, W] format,
             where ... means it can have an arbitrary number of trailing
             dimensions.
@@ -572,17 +596,20 @@ def vflip(img: Tensor) -> Tensor:
     return F_t.vflip(img)


-def five_crop(img, size):
-    """Crop the given PIL Image into four corners and the central crop.
+def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+    """Crop the given image into four corners and the central crop.
+    The image can be a PIL Image or a Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

     .. Note::
         This transform returns a tuple of images and there may be a
         mismatch in the number of inputs and targets your ``Dataset`` returns.

     Args:
+        img (PIL Image or Tensor): Image to be cropped.
         size (sequence or int): Desired output size of the crop. If size is an
             int instead of sequence like (h, w), a square crop (size, size) is
-            made.
+            made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).

     Returns:
        tuple: tuple (tl, tr, bl, br, center)
@@ -590,37 +617,44 @@ def five_crop(img, size):
     """
     if isinstance(size, numbers.Number):
         size = (int(size), int(size))
-    else:
-        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+    elif isinstance(size, (tuple, list)) and len(size) == 1:
+        size = (size[0], size[0])
+
+    if len(size) != 2:
+        raise ValueError("Please provide only two dimensions (h, w) for size.")

-    image_width, image_height = img.size
+    image_width, image_height = _get_image_size(img)
     crop_height, crop_width = size
     if crop_width > image_width or crop_height > image_height:
         msg = "Requested crop size {} is bigger than input size {}"
         raise ValueError(msg.format(size, (image_height, image_width)))

-    tl = img.crop((0, 0, crop_width, crop_height))
-    tr = img.crop((image_width - crop_width, 0, image_width, crop_height))
-    bl = img.crop((0, image_height - crop_height, crop_width, image_height))
-    br = img.crop((image_width - crop_width, image_height - crop_height,
-                   image_width, image_height))
-    center = center_crop(img, (crop_height, crop_width))
-    return (tl, tr, bl, br, center)
+    tl = crop(img, 0, 0, crop_height, crop_width)
+    tr = crop(img, 0, image_width - crop_width, crop_height, crop_width)
+    bl = crop(img, image_height - crop_height, 0, crop_height, crop_width)
+    br = crop(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
+
+    center = center_crop(img, [crop_height, crop_width])
+
+    return tl, tr, bl, br, center


-def ten_crop(img, size, vertical_flip=False):
-    """Generate ten cropped images from the given PIL Image.
-    Crop the given PIL Image into four corners and the central crop plus the
+def ten_crop(img: Tensor, size: List[int], vertical_flip: bool = False) -> List[Tensor]:
+    """Generate ten cropped images from the given image.
+    Crop the given image into four corners and the central crop plus the
     flipped version of these (horizontal flipping is used by default).
+    The image can be a PIL Image or a Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

     .. Note::
         This transform returns a tuple of images and there may be a
         mismatch in the number of inputs and targets your ``Dataset`` returns.

     Args:
+        img (PIL Image or Tensor): Image to be cropped.
         size (sequence or int): Desired output size of the crop. If size is an
             int instead of sequence like (h, w), a square crop (size, size) is
-            made.
+            made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).
         vertical_flip (bool): Use vertical flipping instead of horizontal

     Returns:
@@ -630,8 +664,11 @@ def ten_crop(img, size, vertical_flip=False):
     """
     if isinstance(size, numbers.Number):
         size = (int(size), int(size))
-    else:
-        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+    elif isinstance(size, (tuple, list)) and len(size) == 1:
+        size = (size[0], size[0])
+
+    if len(size) != 2:
+        raise ValueError("Please provide only two dimensions (h, w) for size.")

     first_five = five_crop(img, size)
@@ -648,13 +685,13 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
     """Adjust brightness of an Image.

     Args:
-        img (PIL Image or Torch Tensor): Image to be adjusted.
+        img (PIL Image or Tensor): Image to be adjusted.
         brightness_factor (float): How much to adjust the brightness. Can be
             any non negative number. 0 gives a black image, 1 gives the
             original image while 2 increases the brightness by a factor of 2.

     Returns:
-        PIL Image or Torch Tensor: Brightness adjusted image.
+        PIL Image or Tensor: Brightness adjusted image.
     """
     if not isinstance(img, torch.Tensor):
         return F_pil.adjust_brightness(img, brightness_factor)
@@ -666,13 +703,13 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
     """Adjust contrast of an Image.

     Args:
-        img (PIL Image or Torch Tensor): Image to be adjusted.
+        img (PIL Image or Tensor): Image to be adjusted.
         contrast_factor (float): How much to adjust the contrast. Can be any
             non negative number. 0 gives a solid gray image, 1 gives the
             original image while 2 increases the contrast by a factor of 2.

     Returns:
-        PIL Image or Torch Tensor: Contrast adjusted image.
+        PIL Image or Tensor: Contrast adjusted image.
     """
     if not isinstance(img, torch.Tensor):
         return F_pil.adjust_contrast(img, contrast_factor)
@@ -684,13 +721,13 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
     """Adjust color saturation of an image.

     Args:
-        img (PIL Image or Torch Tensor): Image to be adjusted.
+        img (PIL Image or Tensor): Image to be adjusted.
         saturation_factor (float): How much to adjust the saturation. 0 will
             give a black and white image, 1 will give the original image while
             2 will enhance the saturation by a factor of 2.

     Returns:
-        PIL Image or Torch Tensor: Saturation adjusted image.
+        PIL Image or Tensor: Saturation adjusted image.
     """
     if not isinstance(img, torch.Tensor):
         return F_pil.adjust_saturation(img, saturation_factor)
@@ -749,7 +786,7 @@ def adjust_gamma(img, gamma, gain=1):
             while gamma smaller than 1 make dark regions lighter.
         gain (float): The constant multiplier.
     """
-    if not _is_pil_image(img):
+    if not F_pil._is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

     if gamma < 0:
@@ -789,7 +826,7 @@ def rotate(img, angle, resample=False, expand=False, center=None, fill=None):
     .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
     """
-    if not _is_pil_image(img):
+    if not F_pil._is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

     opts = _parse_fill(fill, img, '5.2.0')
@@ -870,7 +907,7 @@ def affine(img, angle, translate, scale, shear, resample=0, fillcolor=None):
         If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
         fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0)
     """
-    if not _is_pil_image(img):
+    if not F_pil._is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

     assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
@@ -897,7 +934,7 @@ def to_grayscale(img, num_output_channels=1):
         if num_output_channels = 3 : returned image is 3 channel with r = g = b
     """
-    if not _is_pil_image(img):
+    if not F_pil._is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

     if num_output_channels == 1:
...
 import numbers
+from typing import Any, List

 import torch

 try:
@@ -10,13 +11,20 @@ import numpy as np

 @torch.jit.unused
-def _is_pil_image(img):
+def _is_pil_image(img: Any) -> bool:
     if accimage is not None:
         return isinstance(img, (Image.Image, accimage.Image))
     else:
         return isinstance(img, Image.Image)


+@torch.jit.unused
+def _get_image_size(img: Any) -> List[int]:
+    if _is_pil_image(img):
+        return img.size
+    raise TypeError("Unexpected type {}".format(type(img)))
+
+
 @torch.jit.unused
 def hflip(img):
     """Horizontally flip the given PIL Image.
@@ -258,3 +266,23 @@ def pad(img, padding, fill=0, padding_mode="constant"):
         img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)

         return Image.fromarray(img)
+
+
+@torch.jit.unused
+def crop(img: Image.Image, top: int, left: int, height: int, width: int) -> Image.Image:
+    """Crop the given PIL Image.
+
+    Args:
+        img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image.
+        top (int): Vertical component of the top left corner of the crop box.
+        left (int): Horizontal component of the top left corner of the crop box.
+        height (int): Height of the crop box.
+        width (int): Width of the crop box.
+
+    Returns:
+        PIL Image: Cropped image.
+    """
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+    return img.crop((left, top, left + width, top + height))
@@ -3,12 +3,17 @@ from torch import Tensor
 from torch.jit.annotations import List, BroadcastingList2


-def _is_tensor_a_torch_image(input):
-    return input.ndim >= 2
+def _is_tensor_a_torch_image(x: Tensor) -> bool:
+    return x.ndim >= 2


-def vflip(img):
-    # type: (Tensor) -> Tensor
+def _get_image_size(img: Tensor) -> List[int]:
+    if _is_tensor_a_torch_image(img):
+        return [img.shape[-1], img.shape[-2]]
+    raise TypeError("Unexpected type {}".format(type(img)))
+
+
+def vflip(img: Tensor) -> Tensor:
     """Vertically flip the given the Image Tensor.

     Args:
@@ -23,8 +28,7 @@ def vflip(img):
     return img.flip(-2)


-def hflip(img):
-    # type: (Tensor) -> Tensor
+def hflip(img: Tensor) -> Tensor:
     """Horizontally flip the given the Image Tensor.

     Args:
@@ -39,12 +43,11 @@ def hflip(img):
     return img.flip(-1)


-def crop(img, top, left, height, width):
-    # type: (Tensor, int, int, int, int) -> Tensor
+def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
     """Crop the given Image Tensor.

     Args:
-        img (Tensor): Image to be cropped in the form [C, H, W]. (0,0) denotes the top left corner of the image.
+        img (Tensor): Image to be cropped in the form [..., H, W]. (0,0) denotes the top left corner of the image.
         top (int): Vertical component of the top left corner of the crop box.
         left (int): Horizontal component of the top left corner of the crop box.
         height (int): Height of the crop box.
@@ -54,13 +57,12 @@ def crop(img, top, left, height, width):
         Tensor: Cropped image.
     """
     if not _is_tensor_a_torch_image(img):
-        raise TypeError('tensor is not a torch image.')
+        raise TypeError("tensor is not a torch image.")

     return img[..., top:top + height, left:left + width]


-def rgb_to_grayscale(img):
-    # type: (Tensor) -> Tensor
+def rgb_to_grayscale(img: Tensor) -> Tensor:
     """Convert the given RGB Image Tensor to Grayscale.
     For RGB to Grayscale conversion, ITU-R 601-2 luma transform is performed which
     is L = R * 0.2989 + G * 0.5870 + B * 0.1140
@@ -78,8 +80,7 @@ def rgb_to_grayscale(img):
     return (0.2989 * img[0] + 0.5870 * img[1] + 0.1140 * img[2]).to(img.dtype)


-def adjust_brightness(img, brightness_factor):
-    # type: (Tensor, float) -> Tensor
+def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
     """Adjust brightness of an RGB image.

     Args:
@@ -97,8 +98,7 @@ def adjust_brightness(img, brightness_factor):
     return _blend(img, torch.zeros_like(img), brightness_factor)


-def adjust_contrast(img, contrast_factor):
-    # type: (Tensor, float) -> Tensor
+def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
     """Adjust contrast of an RGB image.

     Args:
@@ -166,8 +166,7 @@ def adjust_hue(img, hue_factor):
     return img_hue_adj


-def adjust_saturation(img, saturation_factor):
-    # type: (Tensor, float) -> Tensor
+def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
     """Adjust color saturation of an RGB image.

     Args:
@@ -185,12 +184,11 @@ def adjust_saturation(img, saturation_factor):
     return _blend(img, rgb_to_grayscale(img), saturation_factor)


-def center_crop(img, output_size):
-    # type: (Tensor, BroadcastingList2[int]) -> Tensor
+def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
     """Crop the Image Tensor and resize it to desired size.

     Args:
-        img (Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
+        img (Tensor): Image to be cropped.
         output_size (sequence or int): (height, width) of the crop box. If int,
             it is used for both directions
@@ -202,23 +200,29 @@ def center_crop(img, output_size):
     _, image_width, image_height = img.size()
     crop_height, crop_width = output_size
-    crop_top = int(round((image_height - crop_height) / 2.))
-    crop_left = int(round((image_width - crop_width) / 2.))
+    # crop_top = int(round((image_height - crop_height) / 2.))
+    # Result can be different between python func and scripted func
+    # Temporary workaround:
+    crop_top = int((image_height - crop_height + 1) * 0.5)
+    # crop_left = int(round((image_width - crop_width) / 2.))
+    # Result can be different between python func and scripted func
+    # Temporary workaround:
+    crop_left = int((image_width - crop_width + 1) * 0.5)

     return crop(img, crop_top, crop_left, crop_height, crop_width)


-def five_crop(img, size):
-    # type: (Tensor, BroadcastingList2[int]) -> List[Tensor]
+def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]:
     """Crop the given Image Tensor into four corners and the central crop.

     .. Note::
         This transform returns a List of Tensors and there may be a
         mismatch in the number of inputs and targets your ``Dataset`` returns.

     Args:
+        img (Tensor): Image to be cropped.
         size (sequence or int): Desired output size of the crop. If size is an
             int instead of sequence like (h, w), a square crop (size, size) is
             made.

     Returns:
        List: List (tl, tr, bl, br, center)
@@ -244,19 +248,20 @@ def five_crop(img, size):
     return [tl, tr, bl, br, center]


-def ten_crop(img, size, vertical_flip=False):
-    # type: (Tensor, BroadcastingList2[int], bool) -> List[Tensor]
+def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = False) -> List[Tensor]:
     """Crop the given Image Tensor into four corners and the central crop plus the
        flipped version of these (horizontal flipping is used by default).

     .. Note::
         This transform returns a List of images and there may be a
         mismatch in the number of inputs and targets your ``Dataset`` returns.

     Args:
+        img (Tensor): Image to be cropped.
         size (sequence or int): Desired output size of the crop. If size is an
             int instead of sequence like (h, w), a square crop (size, size) is
             made.
         vertical_flip (bool): Use vertical flipping instead of horizontal

     Returns:
        List: List (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
@@ -279,8 +284,7 @@ def ten_crop(img, size, vertical_flip=False):
     return first_five + second_five


-def _blend(img1, img2, ratio):
-    # type: (Tensor, Tensor, float) -> Tensor
+def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
     bound = 1 if img1.dtype in [torch.half, torch.float32, torch.float64] else 255
     return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)
...
-import torch
 import math
+import numbers
 import random
+import warnings
+from collections.abc import Sequence, Iterable
+from typing import Tuple

+import numpy as np
+import torch
 from PIL import Image
+from torch import Tensor

 try:
     import accimage
 except ImportError:
     accimage = None
-import numpy as np
-import numbers
-import types
-from collections.abc import Sequence, Iterable
-import warnings

 from . import functional as F
@@ -31,15 +34,6 @@ _pil_interpolation_to_str = {
 }

-def _get_image_size(img):
-    if F._is_pil_image(img):
-        return img.size
-    elif isinstance(img, torch.Tensor) and img.dim() > 2:
-        return img.shape[-2:][::-1]
-    else:
-        raise TypeError("Unexpected type {}".format(type(img)))
-

 class Compose(object):
     """Composes several transforms together.
@@ -98,7 +92,7 @@ class ToTensor(object):
 class PILToTensor(object):
     """Convert a ``PIL Image`` to a tensor of the same type.

-    Converts a PIL Image (H x W x C) to a torch.Tensor of shape (C x H x W).
+    Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W).
     """

     def __call__(self, pic):
@@ -258,28 +252,36 @@ class Scale(Resize):
         super(Scale, self).__init__(*args, **kwargs)


-class CenterCrop(object):
-    """Crops the given PIL Image at the center.
+class CenterCrop(torch.nn.Module):
+    """Crops the given image at the center.
+    The image can be a PIL Image or a torch Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

     Args:
         size (sequence or int): Desired output size of the crop. If size is an
             int instead of sequence like (h, w), a square crop (size, size) is
-            made.
+            made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).
     """

     def __init__(self, size):
+        super().__init__()
         if isinstance(size, numbers.Number):
             self.size = (int(size), int(size))
+        elif isinstance(size, Sequence) and len(size) == 1:
+            self.size = (size[0], size[0])
         else:
+            if len(size) != 2:
+                raise ValueError("Please provide only two dimensions (h, w) for size.")
             self.size = size

-    def __call__(self, img):
+    def forward(self, img):
         """
         Args:
-            img (PIL Image): Image to be cropped.
+            img (PIL Image or Tensor): Image to be cropped.

         Returns:
-            PIL Image: Cropped image.
+            PIL Image or Tensor: Cropped image.
         """
         return F.center_crop(img, self.size)
@@ -443,25 +445,30 @@ class RandomChoice(RandomTransforms):
         return t(img)


-class RandomCrop(object):
-    """Crop the given PIL Image at a random location.
+class RandomCrop(torch.nn.Module):
+    """Crop the given image at a random location.
+    The image can be a PIL Image or a Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading
+    dimensions

     Args:
         size (sequence or int): Desired output size of the crop. If size is an
             int instead of sequence like (h, w), a square crop (size, size) is
-            made.
+            made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).
         padding (int or sequence, optional): Optional padding on each border
-            of the image. Default is None, i.e no padding. If a sequence of length
-            4 is provided, it is used to pad left, top, right, bottom borders
-            respectively. If a sequence of length 2 is provided, it is used to
-            pad left/right, top/bottom borders, respectively.
+            of the image. Default is None. If a single int is provided this
+            is used to pad all borders. If tuple of length 2 is provided this is the padding
+            on left/right and top/bottom respectively. If a tuple of length 4 is provided
+            this is the padding for the left, top, right and bottom borders respectively.
+            In torchscript mode padding as single int is not supported, use a tuple or
+            list of length 1: ``[padding, ]``.
         pad_if_needed (boolean): It will pad the image if smaller than the
             desired size to avoid raising an exception. Since cropping is done
             after padding, the padding seems to be done at a random offset.
-        fill: Pixel fill value for constant fill. Default is 0. If a tuple of
+        fill (int or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
             length 3, it is used to fill R, G, B channels respectively.
             This value is only used when the padding_mode is constant
-        padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
+        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.

             - constant: pads with a constant value, this value is specified with fill
@@ -479,60 +486,70 @@ class RandomCrop(object):
     """

-    def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'):
-        if isinstance(size, numbers.Number):
-            self.size = (int(size), int(size))
-        else:
-            self.size = size
-        self.padding = padding
-        self.pad_if_needed = pad_if_needed
-        self.fill = fill
-        self.padding_mode = padding_mode
-
     @staticmethod
-    def get_params(img, output_size):
+    def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
         """Get parameters for ``crop`` for a random crop.

         Args:
-            img (PIL Image): Image to be cropped.
+            img (PIL Image or Tensor): Image to be cropped.
             output_size (tuple): Expected output size of the crop.

         Returns:
             tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
         """
-        w, h = _get_image_size(img)
+        w, h = F._get_image_size(img)
         th, tw = output_size
         if w == tw and h == th:
             return 0, 0, h, w

-        i = random.randint(0, h - th)
-        j = random.randint(0, w - tw)
+        i = torch.randint(0, h - th, size=(1, )).item()
+        j = torch.randint(0, w - tw, size=(1, )).item()
         return i, j, th, tw

-    def __call__(self, img):
+    def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
+        super().__init__()
+        if isinstance(size, numbers.Number):
+            self.size = (int(size), int(size))
+        elif isinstance(size, Sequence) and len(size) == 1:
+            self.size = (size[0], size[0])
+        else:
+            if len(size) != 2:
+                raise ValueError("Please provide only two dimensions (h, w) for size.")
+            # cast to tuple for torchscript
+            self.size = tuple(size)
+        self.padding = padding
+        self.pad_if_needed = pad_if_needed
+        self.fill = fill
+        self.padding_mode = padding_mode
+
+    def forward(self, img):
         """
         Args:
-            img (PIL Image): Image to be cropped.
+            img (PIL Image or Tensor): Image to be cropped.

         Returns:
-            PIL Image: Cropped image.
+            PIL Image or Tensor: Cropped image.
         """
         if self.padding is not None:
             img = F.pad(img, self.padding, self.fill, self.padding_mode)

+        width, height = F._get_image_size(img)
         # pad the width if needed
-        if self.pad_if_needed and img.size[0] < self.size[1]:
-            img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode)
+        if self.pad_if_needed and width < self.size[1]:
+            padding = [self.size[1] - width, 0]
+            img = F.pad(img, padding, self.fill, self.padding_mode)
         # pad the height if needed
-        if self.pad_if_needed and img.size[1] < self.size[0]:
-            img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)
+        if self.pad_if_needed and height < self.size[0]:
+            padding = [0, self.size[0] - height]
+            img = F.pad(img, padding, self.fill, self.padding_mode)

         i, j, h, w = self.get_params(img, self.size)

         return F.crop(img, i, j, h, w)

     def __repr__(self):
-        return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
+        return self.__class__.__name__ + "(size={0}, padding={1})".format(self.size, self.padding)


 class RandomHorizontalFlip(torch.nn.Module):
@@ -566,7 +583,7 @@ class RandomHorizontalFlip(torch.nn.Module):

 class RandomVerticalFlip(torch.nn.Module):
-    """Vertically flip the given PIL Image randomly with a given probability.
+    """Vertically flip the given image randomly with a given probability.
     The image can be a PIL Image or a torch Tensor, in which case it is expected
     to have [..., H, W] shape, where ... means an arbitrary number of leading
     dimensions
@@ -702,7 +719,7 @@ class RandomResizedCrop(object):
             tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                 sized crop.
         """
-        width, height = _get_image_size(img)
+        width, height = F._get_image_size(img)
         area = height * width

         for _ in range(10):
@@ -763,8 +780,11 @@ class RandomSizedCrop(RandomResizedCrop):
         super(RandomSizedCrop, self).__init__(*args, **kwargs)


-class FiveCrop(object):
-    """Crop the given PIL Image into four corners and the central crop
+class FiveCrop(torch.nn.Module):
+    """Crop the given image into four corners and the central crop.
+    The image can be a PIL Image or a Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading
+    dimensions

     .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
@@ -774,6 +794,7 @@ class FiveCrop(object):
     Args:
         size (sequence or int): Desired output size of the crop. If size is an ``int``
             instead of sequence like (h, w), a square crop of size (size, size) is made.
+            If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).

     Example:
         >>> transform = Compose([
@@ -788,23 +809,37 @@ class FiveCrop(object):
     """

     def __init__(self, size):
-        self.size = size
+        super().__init__()
         if isinstance(size, numbers.Number):
             self.size = (int(size), int(size))
+        elif isinstance(size, Sequence) and len(size) == 1:
+            self.size = (size[0], size[0])
         else:
-            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+            if len(size) != 2:
+                raise ValueError("Please provide only two dimensions (h, w) for size.")
             self.size = size

-    def __call__(self, img):
+    def forward(self, img):
+        """
+        Args:
+            img (PIL Image or Tensor): Image to be cropped.
+
+        Returns:
+            tuple of 5 images. Image can be PIL Image or Tensor
+        """
         return F.five_crop(img, self.size)

     def __repr__(self):
         return self.__class__.__name__ + '(size={0})'.format(self.size)


-class TenCrop(object):
-    """Crop the given PIL Image into four corners and the central crop plus the flipped version of
-    these (horizontal flipping is used by default)
+class TenCrop(torch.nn.Module):
+    """Crop the given image into four corners and the central crop plus the flipped version of
+    these (horizontal flipping is used by default).
+    The image can be a PIL Image or a Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading
+    dimensions

     .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
@@ -814,7 +849,7 @@ class TenCrop(object):
     Args:
         size (sequence or int): Desired output size of the crop. If size is an
             int instead of sequence like (h, w), a square crop (size, size) is
-            made.
+            made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).
         vertical_flip (bool): Use vertical flipping instead of horizontal

     Example:
@@ -830,15 +865,26 @@ class TenCrop(object):
     """

     def __init__(self, size, vertical_flip=False):
-        self.size = size
+        super().__init__()
         if isinstance(size, numbers.Number):
             self.size = (int(size), int(size))
+        elif isinstance(size, Sequence) and len(size) == 1:
+            self.size = (size[0], size[0])
         else:
-            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+            if len(size) != 2:
+                raise ValueError("Please provide only two dimensions (h, w) for size.")
             self.size = size
         self.vertical_flip = vertical_flip

-    def __call__(self, img):
+    def forward(self, img):
+        """
+        Args:
+            img (PIL Image or Tensor): Image to be cropped.
+
+        Returns:
+            tuple of 10 images. Image can be PIL Image or Tensor
+        """
         return F.ten_crop(img, self.size, self.vertical_flip)

     def __repr__(self):
...
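
For context, a small usage sketch of the unified transforms after this change, assuming a torchvision build that includes it: the crop transforms now accept tensors and, being nn.Module subclasses, can be compiled with torch.jit.script. Per the RandomCrop docstring above, padding given as a bare int is not supported in torchscript mode, so a length-1 list such as [2, ] is used instead.

import torch
import torchvision.transforms as T

img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)

# CenterCrop and FiveCrop are now nn.Module subclasses and can be scripted directly.
scripted_center_crop = torch.jit.script(T.CenterCrop(size=[24, 24]))
out = scripted_center_crop(img)  # tensor of shape (3, 24, 24)

scripted_five_crop = torch.jit.script(T.FiveCrop(size=[24, 24]))
crops = scripted_five_crop(img)  # 5 crops, each of shape (3, 24, 24)

# RandomCrop under torchscript: pass padding as a length-1 list, not a bare int.
scripted_random_crop = torch.jit.script(T.RandomCrop(size=[24, 24], padding=[2, ]))
out = scripted_random_crop(img)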