Commit 4e9c5c96 authored by Sasank Chilamkurthy's avatar Sasank Chilamkurthy
Browse files

Fix transforms documentation.

PIL.Image is renamed to 'PIL Image' because PIL.Image is the module and PIL.Image.Image is the class we are referring but is too long to put in docstrings
parent cd20d5f6
...@@ -7,9 +7,11 @@ Transforms are common image transforms. They can be chained together using :clas ...@@ -7,9 +7,11 @@ Transforms are common image transforms. They can be chained together using :clas
.. autoclass:: Compose .. autoclass:: Compose
Transforms on PIL.Image Transforms on PIL Image
----------------------- -----------------------
.. autoclass:: Resize
.. autoclass:: Scale .. autoclass:: Scale
.. autoclass:: CenterCrop .. autoclass:: CenterCrop
...@@ -18,10 +20,20 @@ Transforms on PIL.Image ...@@ -18,10 +20,20 @@ Transforms on PIL.Image
.. autoclass:: RandomHorizontalFlip .. autoclass:: RandomHorizontalFlip
.. autoclass:: RandomVerticalFlip
.. autoclass:: RandomResizedCrop
.. autoclass:: RandomSizedCrop .. autoclass:: RandomSizedCrop
.. autoclass:: FiveCrop
.. autoclass:: TenCrop
.. autoclass:: Pad .. autoclass:: Pad
.. autoclass:: ColorJitter
Transforms on torch.\*Tensor Transforms on torch.\*Tensor
---------------------------- ----------------------------
......
...@@ -30,12 +30,12 @@ def _is_numpy_image(img): ...@@ -30,12 +30,12 @@ def _is_numpy_image(img):
def to_tensor(pic): def to_tensor(pic):
"""Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor. """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
See ``ToTensor`` for more details. See ``ToTensor`` for more details.
Args: Args:
pic (PIL.Image or numpy.ndarray): Image to be converted to tensor. pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
Returns: Returns:
Tensor: Converted image. Tensor: Converted image.
...@@ -84,10 +84,10 @@ def to_pil_image(pic): ...@@ -84,10 +84,10 @@ def to_pil_image(pic):
See ``ToPIlImage`` for more details. See ``ToPIlImage`` for more details.
Args: Args:
pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image. pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
Returns: Returns:
PIL.Image: Image converted to PIL.Image. PIL Image: Image converted to PIL Image.
""" """
if not(_is_numpy_image(pic) or _is_tensor_image(pic)): if not(_is_numpy_image(pic) or _is_tensor_image(pic)):
raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic))) raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))
...@@ -143,10 +143,10 @@ def normalize(tensor, mean, std): ...@@ -143,10 +143,10 @@ def normalize(tensor, mean, std):
def resize(img, size, interpolation=Image.BILINEAR): def resize(img, size, interpolation=Image.BILINEAR):
"""Resize the input PIL.Image to the given size. """Resize the input PIL Image to the given size.
Args: Args:
img (PIL.Image): Image to be resized. img (PIL Image): Image to be resized.
size (sequence or int): Desired output size. If size is a sequence like size (sequence or int): Desired output size. If size is a sequence like
(h, w), the output size will be matched to this. If size is an int, (h, w), the output size will be matched to this. If size is an int,
the smaller edge of the image will be matched to this number maintaing the smaller edge of the image will be matched to this number maintaing
...@@ -156,7 +156,7 @@ def resize(img, size, interpolation=Image.BILINEAR): ...@@ -156,7 +156,7 @@ def resize(img, size, interpolation=Image.BILINEAR):
``PIL.Image.BILINEAR`` ``PIL.Image.BILINEAR``
Returns: Returns:
PIL.Image: Resized image. PIL Image: Resized image.
""" """
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
...@@ -186,10 +186,10 @@ def scale(*args, **kwargs): ...@@ -186,10 +186,10 @@ def scale(*args, **kwargs):
def pad(img, padding, fill=0): def pad(img, padding, fill=0):
"""Pad the given PIL.Image on all sides with the given "pad" value. """Pad the given PIL Image on all sides with the given "pad" value.
Args: Args:
img (PIL.Image): Image to be padded. img (PIL Image): Image to be padded.
padding (int or tuple): Padding on each border. If a single int is provided this padding (int or tuple): Padding on each border. If a single int is provided this
is used to pad all borders. If tuple of length 2 is provided this is the padding is used to pad all borders. If tuple of length 2 is provided this is the padding
on left/right and top/bottom respectively. If a tuple of length 4 is provided on left/right and top/bottom respectively. If a tuple of length 4 is provided
...@@ -199,7 +199,7 @@ def pad(img, padding, fill=0): ...@@ -199,7 +199,7 @@ def pad(img, padding, fill=0):
length 3, it is used to fill R, G, B channels respectively. length 3, it is used to fill R, G, B channels respectively.
Returns: Returns:
PIL.Image: Padded image. PIL Image: Padded image.
""" """
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
...@@ -217,17 +217,17 @@ def pad(img, padding, fill=0): ...@@ -217,17 +217,17 @@ def pad(img, padding, fill=0):
def crop(img, i, j, h, w): def crop(img, i, j, h, w):
"""Crop the given PIL.Image. """Crop the given PIL Image.
Args: Args:
img (PIL.Image): Image to be cropped. img (PIL Image): Image to be cropped.
i: Upper pixel coordinate. i: Upper pixel coordinate.
j: Left pixel coordinate. j: Left pixel coordinate.
h: Height of the cropped image. h: Height of the cropped image.
w: Width of the cropped image. w: Width of the cropped image.
Returns: Returns:
PIL.Image: Cropped image. PIL Image: Cropped image.
""" """
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
...@@ -236,12 +236,12 @@ def crop(img, i, j, h, w): ...@@ -236,12 +236,12 @@ def crop(img, i, j, h, w):
def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR): def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
"""Crop the given PIL.Image and resize it to desired size. """Crop the given PIL Image and resize it to desired size.
Notably used in RandomResizedCrop. Notably used in RandomResizedCrop.
Args: Args:
img (PIL.Image): Image to be cropped. img (PIL Image): Image to be cropped.
i: Upper pixel coordinate. i: Upper pixel coordinate.
j: Left pixel coordinate. j: Left pixel coordinate.
h: Height of the cropped image. h: Height of the cropped image.
...@@ -250,7 +250,7 @@ def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR): ...@@ -250,7 +250,7 @@ def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
interpolation (int, optional): Desired interpolation. Default is interpolation (int, optional): Desired interpolation. Default is
``PIL.Image.BILINEAR``. ``PIL.Image.BILINEAR``.
Returns: Returns:
PIL.Image: Cropped image. PIL Image: Cropped image.
""" """
assert _is_pil_image(img), 'img should be PIL Image' assert _is_pil_image(img), 'img should be PIL Image'
img = crop(img, i, j, h, w) img = crop(img, i, j, h, w)
...@@ -259,13 +259,13 @@ def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR): ...@@ -259,13 +259,13 @@ def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
def hflip(img): def hflip(img):
"""Horizontally flip the given PIL.Image. """Horizontally flip the given PIL Image.
Args: Args:
img (PIL.Image): Image to be flipped. img (PIL Image): Image to be flipped.
Returns: Returns:
PIL.Image: Horizontall flipped image. PIL Image: Horizontall flipped image.
""" """
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
...@@ -274,13 +274,13 @@ def hflip(img): ...@@ -274,13 +274,13 @@ def hflip(img):
def vflip(img): def vflip(img):
"""Vertically flip the given PIL.Image. """Vertically flip the given PIL Image.
Args: Args:
img (PIL.Image): Image to be flipped. img (PIL Image): Image to be flipped.
Returns: Returns:
PIL.Image: Vertically flipped image. PIL Image: Vertically flipped image.
""" """
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
...@@ -289,10 +289,11 @@ def vflip(img): ...@@ -289,10 +289,11 @@ def vflip(img):
def five_crop(img, size): def five_crop(img, size):
"""Crop the given PIL.Image into four corners and the central crop. """Crop the given PIL Image into four corners and the central crop.
Note: this transform returns a tuple of images and there may be a mismatch in the number of .. Note::
inputs and targets your `Dataset` returns. This transform returns a tuple of images and there may be a
mismatch in the number of inputs and targets your ``Dataset`` returns.
Args: Args:
size (sequence or int): Desired output size of the crop. If size is an size (sequence or int): Desired output size of the crop. If size is an
...@@ -321,11 +322,12 @@ def five_crop(img, size): ...@@ -321,11 +322,12 @@ def five_crop(img, size):
def ten_crop(img, size, vertical_flip=False): def ten_crop(img, size, vertical_flip=False):
"""Crop the given PIL.Image into four corners and the central crop plus the """Crop the given PIL Image into four corners and the central crop plus the
flipped version of these (horizontal flipping is used by default). flipped version of these (horizontal flipping is used by default).
Note: this transform returns a tuple of images and there may be a mismatch in the number of .. Note::
inputs and targets your `Dataset` returns. This transform returns a tuple of images and there may be a
mismatch in the number of inputs and targets your ``Dataset`` returns.
Args: Args:
size (sequence or int): Desired output size of the crop. If size is an size (sequence or int): Desired output size of the crop. If size is an
...@@ -359,13 +361,13 @@ def adjust_brightness(img, brightness_factor): ...@@ -359,13 +361,13 @@ def adjust_brightness(img, brightness_factor):
"""Adjust brightness of an Image. """Adjust brightness of an Image.
Args: Args:
img (PIL.Image): PIL Image to be adjusted. img (PIL Image): PIL Image to be adjusted.
brightness_factor (float): How much to adjust the brightness. Can be brightness_factor (float): How much to adjust the brightness. Can be
any non negative number. 0 gives a black image, 1 gives the any non negative number. 0 gives a black image, 1 gives the
original image while 2 increases the brightness by a factor of 2. original image while 2 increases the brightness by a factor of 2.
Returns: Returns:
PIL.Image: Brightness adjusted image. PIL Image: Brightness adjusted image.
""" """
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
...@@ -379,13 +381,13 @@ def adjust_contrast(img, contrast_factor): ...@@ -379,13 +381,13 @@ def adjust_contrast(img, contrast_factor):
"""Adjust contrast of an Image. """Adjust contrast of an Image.
Args: Args:
img (PIL.Image): PIL Image to be adjusted. img (PIL Image): PIL Image to be adjusted.
contrast_factor (float): How much to adjust the contrast. Can be any contrast_factor (float): How much to adjust the contrast. Can be any
non negative number. 0 gives a solid gray image, 1 gives the non negative number. 0 gives a solid gray image, 1 gives the
original image while 2 increases the contrast by a factor of 2. original image while 2 increases the contrast by a factor of 2.
Returns: Returns:
PIL.Image: Contrast adjusted image. PIL Image: Contrast adjusted image.
""" """
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
...@@ -399,13 +401,13 @@ def adjust_saturation(img, saturation_factor): ...@@ -399,13 +401,13 @@ def adjust_saturation(img, saturation_factor):
"""Adjust color saturation of an image. """Adjust color saturation of an image.
Args: Args:
img (PIL.Image): PIL Image to be adjusted. img (PIL Image): PIL Image to be adjusted.
saturation_factor (float): How much to adjust the saturation. 0 will saturation_factor (float): How much to adjust the saturation. 0 will
give a black and white image, 1 will give the original image while give a black and white image, 1 will give the original image while
2 will enhance the saturation by a factor of 2. 2 will enhance the saturation by a factor of 2.
Returns: Returns:
PIL.Image: Saturation adjusted image. PIL Image: Saturation adjusted image.
""" """
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
...@@ -428,7 +430,7 @@ def adjust_hue(img, hue_factor): ...@@ -428,7 +430,7 @@ def adjust_hue(img, hue_factor):
See https://en.wikipedia.org/wiki/Hue for more details on Hue. See https://en.wikipedia.org/wiki/Hue for more details on Hue.
Args: Args:
img (PIL.Image): PIL Image to be adjusted. img (PIL Image): PIL Image to be adjusted.
hue_factor (float): How much to shift the hue channel. Should be in hue_factor (float): How much to shift the hue channel. Should be in
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
HSV space in positive and negative direction respectively. HSV space in positive and negative direction respectively.
...@@ -436,7 +438,7 @@ def adjust_hue(img, hue_factor): ...@@ -436,7 +438,7 @@ def adjust_hue(img, hue_factor):
with complementary colors while 0 gives the original image. with complementary colors while 0 gives the original image.
Returns: Returns:
PIL.Image: Hue adjusted image. PIL Image: Hue adjusted image.
""" """
if not(-0.5 <= hue_factor <= 0.5): if not(-0.5 <= hue_factor <= 0.5):
raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor))
...@@ -471,7 +473,7 @@ def adjust_gamma(img, gamma, gain=1): ...@@ -471,7 +473,7 @@ def adjust_gamma(img, gamma, gain=1):
See https://en.wikipedia.org/wiki/Gamma_correction for more details. See https://en.wikipedia.org/wiki/Gamma_correction for more details.
Args: Args:
img (PIL.Image): PIL Image to be adjusted. img (PIL Image): PIL Image to be adjusted.
gamma (float): Non negative real number. gamma larger than 1 make the gamma (float): Non negative real number. gamma larger than 1 make the
shadows darker, while gamma smaller than 1 make dark regions shadows darker, while gamma smaller than 1 make dark regions
lighter. lighter.
...@@ -517,16 +519,16 @@ class Compose(object): ...@@ -517,16 +519,16 @@ class Compose(object):
class ToTensor(object): class ToTensor(object):
"""Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor. """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
Converts a PIL.Image or numpy.ndarray (H x W x C) in the range Converts a PIL Image or numpy.ndarray (H x W x C) in the range
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
""" """
def __call__(self, pic): def __call__(self, pic):
""" """
Args: Args:
pic (PIL.Image or numpy.ndarray): Image to be converted to tensor. pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
Returns: Returns:
Tensor: Converted image. Tensor: Converted image.
...@@ -538,16 +540,16 @@ class ToPILImage(object): ...@@ -538,16 +540,16 @@ class ToPILImage(object):
"""Convert a tensor or an ndarray to PIL Image. """Convert a tensor or an ndarray to PIL Image.
Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
H x W x C to a PIL.Image while preserving the value range. H x W x C to a PIL Image while preserving the value range.
""" """
def __call__(self, pic): def __call__(self, pic):
""" """
Args: Args:
pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image. pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
Returns: Returns:
PIL.Image: Image converted to PIL.Image. PIL Image: Image converted to PIL Image.
""" """
return to_pil_image(pic) return to_pil_image(pic)
...@@ -582,7 +584,7 @@ class Normalize(object): ...@@ -582,7 +584,7 @@ class Normalize(object):
class Resize(object): class Resize(object):
"""Resize the input PIL.Image to the given size. """Resize the input PIL Image to the given size.
Args: Args:
size (sequence or int): Desired output size. If size is a sequence like size (sequence or int): Desired output size. If size is a sequence like
...@@ -602,15 +604,18 @@ class Resize(object): ...@@ -602,15 +604,18 @@ class Resize(object):
def __call__(self, img): def __call__(self, img):
""" """
Args: Args:
img (PIL.Image): Image to be scaled. img (PIL Image): Image to be scaled.
Returns: Returns:
PIL.Image: Rescaled image. PIL Image: Rescaled image.
""" """
return resize(img, self.size, self.interpolation) return resize(img, self.size, self.interpolation)
class Scale(Resize): class Scale(Resize):
"""
Note: This transform is deprecated in favor of Resize.
"""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
warnings.warn("The use of the transforms.Scale transform is deprecated, " + warnings.warn("The use of the transforms.Scale transform is deprecated, " +
"please use transforms.Resize instead.") "please use transforms.Resize instead.")
...@@ -618,7 +623,7 @@ class Scale(Resize): ...@@ -618,7 +623,7 @@ class Scale(Resize):
class CenterCrop(object): class CenterCrop(object):
"""Crops the given PIL.Image at the center. """Crops the given PIL Image at the center.
Args: Args:
size (sequence or int): Desired output size of the crop. If size is an size (sequence or int): Desired output size of the crop. If size is an
...@@ -637,7 +642,7 @@ class CenterCrop(object): ...@@ -637,7 +642,7 @@ class CenterCrop(object):
"""Get parameters for ``crop`` for center crop. """Get parameters for ``crop`` for center crop.
Args: Args:
img (PIL.Image): Image to be cropped. img (PIL Image): Image to be cropped.
output_size (tuple): Expected output size of the crop. output_size (tuple): Expected output size of the crop.
Returns: Returns:
...@@ -652,17 +657,17 @@ class CenterCrop(object): ...@@ -652,17 +657,17 @@ class CenterCrop(object):
def __call__(self, img): def __call__(self, img):
""" """
Args: Args:
img (PIL.Image): Image to be cropped. img (PIL Image): Image to be cropped.
Returns: Returns:
PIL.Image: Cropped image. PIL Image: Cropped image.
""" """
i, j, h, w = self.get_params(img, self.size) i, j, h, w = self.get_params(img, self.size)
return crop(img, i, j, h, w) return crop(img, i, j, h, w)
class Pad(object): class Pad(object):
"""Pad the given PIL.Image on all sides with the given "pad" value. """Pad the given PIL Image on all sides with the given "pad" value.
Args: Args:
padding (int or tuple): Padding on each border. If a single int is provided this padding (int or tuple): Padding on each border. If a single int is provided this
...@@ -687,10 +692,10 @@ class Pad(object): ...@@ -687,10 +692,10 @@ class Pad(object):
def __call__(self, img): def __call__(self, img):
""" """
Args: Args:
img (PIL.Image): Image to be padded. img (PIL Image): Image to be padded.
Returns: Returns:
PIL.Image: Padded image. PIL Image: Padded image.
""" """
return pad(img, self.padding, self.fill) return pad(img, self.padding, self.fill)
...@@ -711,7 +716,7 @@ class Lambda(object): ...@@ -711,7 +716,7 @@ class Lambda(object):
class RandomCrop(object): class RandomCrop(object):
"""Crop the given PIL.Image at a random location. """Crop the given PIL Image at a random location.
Args: Args:
size (sequence or int): Desired output size of the crop. If size is an size (sequence or int): Desired output size of the crop. If size is an
...@@ -735,7 +740,7 @@ class RandomCrop(object): ...@@ -735,7 +740,7 @@ class RandomCrop(object):
"""Get parameters for ``crop`` for a random crop. """Get parameters for ``crop`` for a random crop.
Args: Args:
img (PIL.Image): Image to be cropped. img (PIL Image): Image to be cropped.
output_size (tuple): Expected output size of the crop. output_size (tuple): Expected output size of the crop.
Returns: Returns:
...@@ -753,10 +758,10 @@ class RandomCrop(object): ...@@ -753,10 +758,10 @@ class RandomCrop(object):
def __call__(self, img): def __call__(self, img):
""" """
Args: Args:
img (PIL.Image): Image to be cropped. img (PIL Image): Image to be cropped.
Returns: Returns:
PIL.Image: Cropped image. PIL Image: Cropped image.
""" """
if self.padding > 0: if self.padding > 0:
img = pad(img, self.padding) img = pad(img, self.padding)
...@@ -767,15 +772,15 @@ class RandomCrop(object): ...@@ -767,15 +772,15 @@ class RandomCrop(object):
class RandomHorizontalFlip(object): class RandomHorizontalFlip(object):
"""Horizontally flip the given PIL.Image randomly with a probability of 0.5.""" """Horizontally flip the given PIL Image randomly with a probability of 0.5."""
def __call__(self, img): def __call__(self, img):
""" """
Args: Args:
img (PIL.Image): Image to be flipped. img (PIL Image): Image to be flipped.
Returns: Returns:
PIL.Image: Randomly flipped image. PIL Image: Randomly flipped image.
""" """
if random.random() < 0.5: if random.random() < 0.5:
return hflip(img) return hflip(img)
...@@ -783,15 +788,15 @@ class RandomHorizontalFlip(object): ...@@ -783,15 +788,15 @@ class RandomHorizontalFlip(object):
class RandomVerticalFlip(object): class RandomVerticalFlip(object):
"""Vertically flip the given PIL.Image randomly with a probability of 0.5.""" """Vertically flip the given PIL Image randomly with a probability of 0.5."""
def __call__(self, img): def __call__(self, img):
""" """
Args: Args:
img (PIL.Image): Image to be flipped. img (PIL Image): Image to be flipped.
Returns: Returns:
PIL.Image: Randomly flipped image. PIL Image: Randomly flipped image.
""" """
if random.random() < 0.5: if random.random() < 0.5:
return vflip(img) return vflip(img)
...@@ -799,7 +804,7 @@ class RandomVerticalFlip(object): ...@@ -799,7 +804,7 @@ class RandomVerticalFlip(object):
class RandomResizedCrop(object): class RandomResizedCrop(object):
"""Crop the given PIL.Image to random size and aspect ratio. """Crop the given PIL Image to random size and aspect ratio.
A crop of random size of (0.08 to 1.0) of the original size and a random A crop of random size of (0.08 to 1.0) of the original size and a random
aspect ratio of 3/4 to 4/3 of the original aspect ratio is made. This crop aspect ratio of 3/4 to 4/3 of the original aspect ratio is made. This crop
...@@ -820,7 +825,7 @@ class RandomResizedCrop(object): ...@@ -820,7 +825,7 @@ class RandomResizedCrop(object):
"""Get parameters for ``crop`` for a random sized crop. """Get parameters for ``crop`` for a random sized crop.
Args: Args:
img (PIL.Image): Image to be cropped. img (PIL Image): Image to be cropped.
Returns: Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for a random tuple: params (i, j, h, w) to be passed to ``crop`` for a random
...@@ -851,16 +856,19 @@ class RandomResizedCrop(object): ...@@ -851,16 +856,19 @@ class RandomResizedCrop(object):
def __call__(self, img): def __call__(self, img):
""" """
Args: Args:
img (PIL.Image): Image to be flipped. img (PIL Image): Image to be flipped.
Returns: Returns:
PIL.Image: Randomly cropped and resize image. PIL Image: Randomly cropped and resize image.
""" """
i, j, h, w = self.get_params(img) i, j, h, w = self.get_params(img)
return resized_crop(img, i, j, h, w, self.size, self.interpolation) return resized_crop(img, i, j, h, w, self.size, self.interpolation)
class RandomSizedCrop(RandomResizedCrop): class RandomSizedCrop(RandomResizedCrop):
"""
Note: This transform is deprecated in favor of RandomResizedCrop.
"""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " + warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " +
"please use transforms.RandomResizedCrop instead.") "please use transforms.RandomResizedCrop instead.")
...@@ -868,7 +876,7 @@ class RandomSizedCrop(RandomResizedCrop): ...@@ -868,7 +876,7 @@ class RandomSizedCrop(RandomResizedCrop):
class FiveCrop(object): class FiveCrop(object):
"""Crop the given PIL.Image into four corners and the central crop.abs """Crop the given PIL Image into four corners and the central crop.abs
Note: this transform returns a tuple of images and there may be a mismatch in the number of Note: this transform returns a tuple of images and there may be a mismatch in the number of
inputs and targets your `Dataset` returns. inputs and targets your `Dataset` returns.
...@@ -892,7 +900,7 @@ class FiveCrop(object): ...@@ -892,7 +900,7 @@ class FiveCrop(object):
class TenCrop(object): class TenCrop(object):
"""Crop the given PIL.Image into four corners and the central crop plus the """Crop the given PIL Image into four corners and the central crop plus the
flipped version of these (horizontal flipping is used by default) flipped version of these (horizontal flipping is used by default)
Note: this transform returns a tuple of images and there may be a mismatch in the number of Note: this transform returns a tuple of images and there may be a mismatch in the number of
...@@ -972,10 +980,10 @@ class ColorJitter(object): ...@@ -972,10 +980,10 @@ class ColorJitter(object):
def __call__(self, img): def __call__(self, img):
""" """
Args: Args:
img (PIL.Image): Input image. img (PIL Image): Input image.
Returns: Returns:
PIL.Image: Color jittered image. PIL Image: Color jittered image.
""" """
transform = self.get_params(self.brightness, self.contrast, transform = self.get_params(self.brightness, self.contrast,
self.saturation, self.hue) self.saturation, self.hue)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment