transforms.py 75.4 KB
Newer Older
1
import math
vfdev's avatar
vfdev committed
2
import numbers
3
import random
vfdev's avatar
vfdev committed
4
import warnings
vfdev's avatar
vfdev committed
5
from collections.abc import Sequence
6
from typing import Tuple, List, Optional
vfdev's avatar
vfdev committed
7
8
9
10

import torch
from torch import Tensor

11
12
13
14
15
16
try:
    import accimage
except ImportError:
    accimage = None

from . import functional as F
17
from .functional import InterpolationMode, _interpolation_modes_from_int
18

19

20
21
22
23
# Public API of this module (names exported by ``from transforms import *``).
__all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale",
           "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
           "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
           "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
           "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize",
           "RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize"]
26

27

28
class Compose:
    """Composes several transforms together. This transform does not support torchscript.
    Please, see the note below.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])

    .. note::
        In order to script the transformations, please use ``torch.nn.Sequential`` as below.

        >>> transforms = torch.nn.Sequential(
        >>>     transforms.CenterCrop(10),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>> )
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. those that work with ``torch.Tensor`` and do not
        require `lambda` functions or ``PIL.Image``.

    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        """Apply every transform in order, feeding each output into the next one.

        Args:
            img: Input passed to the first transform.

        Returns:
            The output of the last transform (``img`` unchanged if there are no transforms).
        """
        for t in self.transforms:
            img = t(img)
        return img

    def __repr__(self):
        # One transform per line, indented by four spaces, wrapped in "Name(" ... ")".
        lines = [self.__class__.__name__ + '(']
        lines.extend('    {0}'.format(t) for t in self.transforms)
        lines.append(')')
        return '\n'.join(lines)

71

72
class ToTensor:
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
    or if the numpy.ndarray has dtype = np.uint8

    In the other cases, tensors are returned without scaling.

    .. note::
        Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
        transforming target image masks. See the `references`_ for implementing the transforms for image masks.

    .. _references: https://github.com/pytorch/vision/tree/master/references/segmentation
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.to_tensor(pic)

    def __repr__(self):
        return self.__class__.__name__ + '()'

102

103
class PILToTensor:
    """Convert a ``PIL Image`` to a tensor of the same type. This transform does not support torchscript.

    Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W).
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.pil_to_tensor(pic)

    def __repr__(self):
        return self.__class__.__name__ + '()'


123
class ConvertImageDtype(torch.nn.Module):
    """Convert a tensor image to the given ``dtype`` and scale the values accordingly.
    This function does not support PIL Image.

    Args:
        dtype (torch.dtype): Desired data type of the output

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """

    def __init__(self, dtype: torch.dtype) -> None:
        super().__init__()
        self.dtype = dtype

    def forward(self, image: Tensor) -> Tensor:
        """
        Args:
            image (Tensor): Tensor image to be converted.

        Returns:
            Tensor: ``image`` converted to ``self.dtype`` by ``F.convert_image_dtype``.
        """
        return F.convert_image_dtype(image, self.dtype)


150
class ToPILImage:
    """Convert a tensor or an ndarray to PIL Image. This transform does not support torchscript.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL Image while preserving the value range.

    Args:
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
            If ``mode`` is ``None`` (default) there are some assumptions made about the input data:

            - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
            - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
            - If the input has 2 channels, the ``mode`` is assumed to be ``LA``.
            - If the input has 1 channel, the ``mode`` is determined by the data type (i.e. ``int``, ``float``,
              ``short``).

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
    """
    def __init__(self, mode=None):
        self.mode = mode

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.

        Returns:
            PIL Image: Image converted to PIL Image.

        """
        return F.to_pil_image(pic, self.mode)

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        # The mode is only shown when one was explicitly given.
        if self.mode is not None:
            format_string += 'mode={0}'.format(self.mode)
        format_string += ')'
        return format_string
187

188

189
class Normalize(torch.nn.Module):
    """Normalize a tensor image with mean and standard deviation.
    This transform does not support PIL Image.

    Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
    channels, this transform will normalize each channel of the input
    ``torch.*Tensor`` i.e.,
    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``

    .. note::
        By default (``inplace=False``) this transform acts out of place, i.e., it does not mutate the input tensor.

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace (bool, optional): Bool to make this operation in-place. Default is ``False``.

    """

    def __init__(self, mean, std, inplace=False):
        super().__init__()
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Args:
            tensor (Tensor): Tensor image to be normalized.

        Returns:
            Tensor: Normalized Tensor image.
        """
        return F.normalize(tensor, self.mean, self.std, self.inplace)

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

226

vfdev's avatar
vfdev committed
227
228
class Resize(torch.nn.Module):
    """Resize the input image to the given size.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e., if height > width, then image will be rescaled to
            (size * height / width, size).
            In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and
            ``InterpolationMode.BICUBIC`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.

    Raises:
        TypeError: If ``size`` is neither an int nor a sequence.
        ValueError: If ``size`` is a sequence whose length is not 1 or 2.
    """

    def __init__(self, size, interpolation=InterpolationMode.BILINEAR):
        super().__init__()
        if not isinstance(size, (int, Sequence)):
            raise TypeError("Size should be int or sequence. Got {}".format(type(size)))
        if isinstance(size, Sequence) and len(size) not in (1, 2):
            raise ValueError("If size is a sequence, it should have 1 or 2 values")
        self.size = size

        # Backward compatibility with integer value
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument interpolation should be of type InterpolationMode instead of int. "
                "Please, use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)

        self.interpolation = interpolation

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be scaled.

        Returns:
            PIL Image or Tensor: Rescaled image.
        """
        return F.resize(img, self.size, self.interpolation)

    def __repr__(self):
        interpolate_str = self.interpolation.value
        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
278

279
280
281
282
283
284
285
286
287
288
289

class Scale(Resize):
    """Deprecated alias of ``Resize``.

    Note: This transform is deprecated in favor of Resize. Construction emits a
    warning and then forwards every positional and keyword argument unchanged
    to ``Resize.__init__``.
    """
    def __init__(self, *args, **kwargs):
        deprecation_message = ("The use of the transforms.Scale transform is deprecated, "
                               "please use transforms.Resize instead.")
        warnings.warn(deprecation_message)
        super().__init__(*args, **kwargs)


vfdev's avatar
vfdev committed
290
291
class CenterCrop(torch.nn.Module):
    """Crops the given image at the center.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
    """

    def __init__(self, size):
        super().__init__()
        # _setup_size normalizes ``size`` and reports the given message on invalid input.
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            PIL Image or Tensor: Cropped image.
        """
        return F.center_crop(img, self.size)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)

318

319
320
class Pad(torch.nn.Module):
    """Pad the given image on all sides with the given "pad" value.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric,
    at most 3 leading dimensions for mode edge,
    and an arbitrary number of leading dimensions for mode constant

    Args:
        padding (int or sequence): Padding on each border. If a single int is provided this
            is used to pad all borders. If sequence of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a sequence of length 4 is provided
            this is the padding for the left, top, right and bottom borders respectively.
            In torchscript mode padding as single int is not supported, use a sequence of length 1: ``[padding, ]``.
        fill (number or str or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
            Only number is supported for torch Tensor.
            Only int or str or tuple value is supported for PIL Image.
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image,
              if input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2

            - reflect: pads with reflection of image without repeating the last value on the edge

              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge

              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]

    Raises:
        TypeError: If ``padding`` or ``fill`` has an unsupported type.
        ValueError: If ``padding_mode`` is unknown or ``padding`` is a sequence of length other than 1, 2 or 4.
    """

    def __init__(self, padding, fill=0, padding_mode="constant"):
        super().__init__()
        if not isinstance(padding, (numbers.Number, tuple, list)):
            raise TypeError("Got inappropriate padding arg")

        if not isinstance(fill, (numbers.Number, str, tuple)):
            raise TypeError("Got inappropriate fill arg")

        if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
            raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

        if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]:
            raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " +
                             "{} element tuple".format(len(padding)))

        self.padding = padding
        self.fill = fill
        self.padding_mode = padding_mode

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be padded.

        Returns:
            PIL Image or Tensor: Padded image.
        """
        return F.pad(img, self.padding, self.fill, self.padding_mode)

    def __repr__(self):
        return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\
            format(self.padding, self.fill, self.padding_mode)
388

389

390
class Lambda:
    """Apply a user-defined lambda as a transform. This transform does not support torchscript.

    Args:
        lambd (function): Lambda/function to be used for transform.

    Raises:
        TypeError: If ``lambd`` is not callable.
    """

    def __init__(self, lambd):
        if not callable(lambd):
            raise TypeError("Argument lambd should be callable, got {}".format(repr(type(lambd).__name__)))
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)

    def __repr__(self):
        return self.__class__.__name__ + '()'

408

409
class RandomTransforms:
    """Base class for a list of transformations with randomness

    Subclasses must implement ``__call__``.

    Args:
        transforms (sequence): list of transformations

    Raises:
        TypeError: If ``transforms`` is not a sequence.
    """

    def __init__(self, transforms):
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence")
        self.transforms = transforms

    def __call__(self, *args, **kwargs):
        raise NotImplementedError()

    def __repr__(self):
        # One transform per line, indented by four spaces, wrapped in "Name(" ... ")".
        lines = [self.__class__.__name__ + '(']
        lines.extend('    {0}'.format(t) for t in self.transforms)
        lines.append(')')
        return '\n'.join(lines)


433
class RandomApply(torch.nn.Module):
    """Apply randomly a list of transformations with a given probability.

    .. note::
        In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of
        transforms as shown below:

        >>> transforms = transforms.RandomApply(torch.nn.ModuleList([
        >>>     transforms.ColorJitter(),
        >>> ]), p=0.3)
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. those that work with ``torch.Tensor`` and do not
        require `lambda` functions or ``PIL.Image``.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
        p (float): probability
    """

    def __init__(self, transforms, p=0.5):
        super().__init__()
        self.transforms = transforms
        self.p = p

    def forward(self, img):
        """Apply all transforms in order with probability ``p``; otherwise return ``img`` unchanged."""
        # Skip when the uniform draw in [0, 1) exceeds p, so the transforms run with probability p.
        if self.p < torch.rand(1):
            return img
        for t in self.transforms:
            img = t(img)
        return img

    def __repr__(self):
        lines = [self.__class__.__name__ + '(', '    p={}'.format(self.p)]
        lines.extend('    {0}'.format(t) for t in self.transforms)
        lines.append(')')
        return '\n'.join(lines)


class RandomOrder(RandomTransforms):
    """Apply a list of transformations in a random order. This transform does not support torchscript.

    The order is drawn with Python's ``random.shuffle``.
    """
    def __call__(self, img):
        order = list(range(len(self.transforms)))
        random.shuffle(order)
        for i in order:
            img = self.transforms[i](img)
        return img


class RandomChoice(RandomTransforms):
    """Apply single transformation randomly picked from a list. This transform does not support torchscript.

    The transform is picked with Python's ``random.choice``.
    """
    def __call__(self, img):
        t = random.choice(self.transforms)
        return t(img)


vfdev's avatar
vfdev committed
494
495
class RandomCrop(torch.nn.Module):
    """Crop the given image at a random location.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions,
    but if non-constant padding is used, the input is expected to have at most 2 leading dimensions

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is None. If a single int is provided this
            is used to pad all borders. If sequence of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a sequence of length 4 is provided
            this is the padding for the left, top, right and bottom borders respectively.
            In torchscript mode padding as single int is not supported, use a sequence of length 1: ``[padding, ]``.
        pad_if_needed (boolean): It will pad the image if smaller than the
            desired size to avoid raising an exception. Since cropping is done
            after padding, the padding seems to be done at a random offset.
        fill (number or str or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
            Only number is supported for torch Tensor.
            Only int or str or tuple value is supported for PIL Image.
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value on the edge of the image

            - reflect: pads with reflection of image (without repeating the last value on the edge)

              padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image (repeating the last value on the edge)

              padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]

    """

    @staticmethod
    def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image or Tensor): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.

        Raises:
            ValueError: If the requested crop is larger than the image in either dimension.
        """
        w, h = F._get_image_size(img)
        th, tw = output_size

        # Any crop larger than the image is invalid. The previous ``h + 1 < th`` check let a crop one
        # pixel too large through, which then failed inside ``torch.randint(0, 0)`` with an opaque error.
        if h < th or w < tw:
            raise ValueError(
                "Required crop size {} is larger than input image size {}".format((th, tw), (h, w))
            )

        if w == tw and h == th:
            return 0, 0, h, w

        # Upper bounds are exclusive, so +1 lets the crop touch the bottom/right edge.
        i = torch.randint(0, h - th + 1, size=(1, )).item()
        j = torch.randint(0, w - tw + 1, size=(1, )).item()
        return i, j, th, tw

    def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
        super().__init__()

        self.size = tuple(_setup_size(
            size, error_msg="Please provide only two dimensions (h, w) for size."
        ))

        self.padding = padding
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        self.padding_mode = padding_mode

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            PIL Image or Tensor: Cropped image.
        """
        if self.padding is not None:
            img = F.pad(img, self.padding, self.fill, self.padding_mode)

        width, height = F._get_image_size(img)
        # pad the width if needed (both left and right sides get the full deficit)
        if self.pad_if_needed and width < self.size[1]:
            padding = [self.size[1] - width, 0]
            img = F.pad(img, padding, self.fill, self.padding_mode)
        # pad the height if needed (horizontal padding above does not change the height)
        if self.pad_if_needed and height < self.size[0]:
            padding = [0, self.size[0] - height]
            img = F.pad(img, padding, self.fill, self.padding_mode)

        i, j, h, w = self.get_params(img, self.size)

        return F.crop(img, i, j, h, w)

    def __repr__(self):
        return self.__class__.__name__ + "(size={0}, padding={1})".format(self.size, self.padding)
601

602

603
604
class RandomHorizontalFlip(torch.nn.Module):
    """Horizontally flip the given image randomly with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be flipped.

        Returns:
            PIL Image or Tensor: Randomly flipped image.
        """
        # A uniform draw in [0, 1) below p happens with probability p.
        if torch.rand(1) < self.p:
            return F.hflip(img)
        return img

    def __repr__(self):
        return self.__class__.__name__ + '(p={})'.format(self.p)
631

632

633
class RandomVerticalFlip(torch.nn.Module):
    """Vertically flip the given image randomly with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be flipped.

        Returns:
            PIL Image or Tensor: Randomly flipped image.
        """
        # A uniform draw in [0, 1) below p happens with probability p.
        if torch.rand(1) < self.p:
            return F.vflip(img)
        return img

    def __repr__(self):
        return self.__class__.__name__ + '(p={})'.format(self.p)
661

662

663
664
class RandomPerspective(torch.nn.Module):
    """Performs a random perspective transformation of the given image with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.
            Default is 0.5.
        p (float): probability of the image being transformed. Default is 0.5.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
        fill (sequence or number): Pixel fill value for the area outside the transformed
            image. Default is ``0``. If given a number, the value is used for all bands respectively.
            If input is PIL Image, this option is only available for ``Pillow>=5.0.0``.

    Raises:
        TypeError: If ``fill`` is neither ``None``, a sequence nor a number.
    """

    def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationMode.BILINEAR, fill=0):
        super().__init__()
        self.p = p

        # Backward compatibility with integer value
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument interpolation should be of type InterpolationMode instead of int. "
                "Please, use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)

        self.interpolation = interpolation
        self.distortion_scale = distortion_scale

        if fill is None:
            fill = 0
        elif not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")

        self.fill = fill

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be Perspectively transformed.

        Returns:
            PIL Image or Tensor: Randomly transformed image.
        """
        fill = self.fill
        if isinstance(img, Tensor):
            # Tensor inputs need one float per channel. Broadcast any scalar accepted by __init__
            # (numbers.Number, not only int/float) so e.g. numpy scalars are not iterated over.
            if isinstance(fill, numbers.Number):
                fill = [float(fill)] * F._get_image_num_channels(img)
            else:
                fill = [float(f) for f in fill]

        if torch.rand(1) < self.p:
            width, height = F._get_image_size(img)
            startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
            return F.perspective(img, startpoints, endpoints, self.interpolation, fill)
        return img

    @staticmethod
    def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[List[int]], List[List[int]]]:
        """Get parameters for ``perspective`` for a random perspective transform.

        Args:
            width (int): width of the image.
            height (int): height of the image.
            distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.

        Returns:
            List containing [top-left, top-right, bottom-right, bottom-left] of the original image,
            List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image.
        """
        half_height = height // 2
        half_width = width // 2
        # Each corner moves inward by a random amount of at most distortion_scale * half the side.
        topleft = [
            int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()),
            int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item())
        ]
        topright = [
            int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()),
            int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item())
        ]
        botright = [
            int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()),
            int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item())
        ]
        botleft = [
            int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()),
            int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item())
        ]
        startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
        endpoints = [topleft, topright, botright, botleft]
        return startpoints, endpoints

    def __repr__(self):
        return self.__class__.__name__ + '(p={})'.format(self.p)


764
765
class RandomResizedCrop(torch.nn.Module):
    """Crop the given image to random size and aspect ratio.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    A crop of random size (default: of 0.08 to 1.0) of the original size and a random
    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
    is finally resized to given size.
    This is popularly used to train the Inception networks.

    Args:
        size (int or sequence): expected output size of each edge. If size is an
            int instead of sequence like (h, w), a square output size ``(size, size)`` is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
            In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        scale (tuple of float): range ``(min, max)`` of the crop area, relative to the area of the
            original image, before resizing.
        ratio (tuple of float): range ``(min, max)`` of the crop aspect ratio (width / height),
            before resizing.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and
            ``InterpolationMode.BICUBIC`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.

    """

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=InterpolationMode.BILINEAR):
        super().__init__()
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

        if not isinstance(scale, Sequence):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, Sequence):
            raise TypeError("Ratio should be a sequence")
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")

        # Backward compatibility with integer value
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument interpolation should be of type InterpolationMode instead of int. "
                "Please, use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)

        self.interpolation = interpolation
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(
            img: Tensor, scale: List[float], ratio: List[float]
    ) -> Tuple[int, int, int, int]:
        """Get parameters for ``crop`` for a random sized crop.

        Up to 10 random (area, aspect ratio) pairs are tried. If none of them yields a crop that
        fits inside the image, a central crop is returned instead. Its aspect ratio is clamped to
        ``ratio``, or it covers the whole image when the image's own aspect ratio is within ``ratio``.

        Args:
            img (PIL Image or Tensor): Input image.
            scale (list): range of scale of the origin size cropped
            ratio (list): range of aspect ratio of the origin aspect ratio cropped

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop.
        """
        width, height = F._get_image_size(img)
        area = height * width

        # The aspect ratio is sampled uniformly in log space, so a ratio r and its reciprocal 1/r
        # are treated symmetrically. ``ratio`` is fixed across attempts, so take its log only once.
        log_ratio = torch.log(torch.tensor(ratio))
        for _ in range(10):
            target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
            aspect_ratio = torch.exp(
                torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
            ).item()

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if 0 < w <= width and 0 < h <= height:
                i = torch.randint(0, height - h + 1, size=(1,)).item()
                j = torch.randint(0, width - w + 1, size=(1,)).item()
                return i, j, h, w

        # Fallback to central crop
        in_ratio = float(width) / float(height)
        if in_ratio < min(ratio):
            w = width
            h = int(round(w / min(ratio)))
        elif in_ratio > max(ratio):
            h = height
            w = int(round(h * max(ratio)))
        else:  # whole image
            w = width
            h = height
        i = (height - h) // 2
        j = (width - w) // 2
        return i, j, h, w

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped and resized.

        Returns:
            PIL Image or Tensor: Randomly cropped and resized image.
        """
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)

    def __repr__(self):
        interpolate_str = self.interpolation.value
        format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
        format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
        format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
        format_string += ', interpolation={0})'.format(interpolate_str)
        return format_string


class RandomSizedCrop(RandomResizedCrop):
    """
    Note: This transform is deprecated in favor of RandomResizedCrop.

    Constructing it emits a warning. Apart from that it behaves exactly like ``RandomResizedCrop``
    and accepts the same arguments.
    """

    def __init__(self, *args, **kwargs):
        message = ("The use of the transforms.RandomSizedCrop transform is deprecated, "
                   "please use transforms.RandomResizedCrop instead.")
        warnings.warn(message)
        super().__init__(*args, **kwargs)


class FiveCrop(torch.nn.Module):
    """Crop the given image into four corners and the central crop.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
         inputs and targets your Dataset returns. See below for an example of how to deal with
         this.

    Args:
         size (sequence or int): Desired output size of the crop. If size is an ``int``
            instead of sequence like (h, w), a square crop of size (size, size) is made.
            If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).

    Example:
         >>> transform = Compose([
         >>>    FiveCrop(size), # this is a tuple of PIL Images
         >>>    Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
         >>> ])
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """

    def __init__(self, size):
        super().__init__()
        error_msg = "Please provide only two dimensions (h, w) for size."
        self.size = _setup_size(size, error_msg=error_msg)

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            tuple of 5 images. Image can be PIL Image or Tensor
        """
        crops = F.five_crop(img, self.size)
        return crops

    def __repr__(self):
        return '{}(size={})'.format(self.__class__.__name__, self.size)


class TenCrop(torch.nn.Module):
    """Crop the given image into four corners and the central crop plus the flipped version of
    these (horizontal flipping is used by default).
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
         inputs and targets your Dataset returns. See below for an example of how to deal with
         this.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
        vertical_flip (bool): Use vertical flipping instead of horizontal

    Example:
         >>> transform = Compose([
         >>>    TenCrop(size), # this is a tuple of PIL Images
         >>>    Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
         >>> ])
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """

    def __init__(self, size, vertical_flip=False):
        super().__init__()
        error_msg = "Please provide only two dimensions (h, w) for size."
        self.size = _setup_size(size, error_msg=error_msg)
        self.vertical_flip = vertical_flip

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            tuple of 10 images. Image can be PIL Image or Tensor
        """
        crops = F.ten_crop(img, self.size, self.vertical_flip)
        return crops

    def __repr__(self):
        return '{}(size={}, vertical_flip={})'.format(self.__class__.__name__, self.size, self.vertical_flip)


class LinearTransformation(torch.nn.Module):
    """Transform a tensor image with a square transformation matrix and a mean_vector computed
    offline.
    This transform does not support PIL Image.
    Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and
    subtract mean_vector from it which is then followed by computing the dot
    product with the transformation matrix and then reshaping the tensor to its
    original shape.

    Applications:
        whitening transformation: Suppose X is a column vector zero-centered data.
        Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X),
        perform SVD on this matrix and pass it as transformation_matrix.

    Args:
        transformation_matrix (Tensor): tensor [D x D], D = C x H x W
        mean_vector (Tensor): tensor [D], D = C x H x W

    Raises:
        ValueError: if ``transformation_matrix`` is not square, if the length of ``mean_vector``
            does not match it, or if the two tensors live on different devices.
    """

    def __init__(self, transformation_matrix, mean_vector):
        super().__init__()
        if transformation_matrix.size(0) != transformation_matrix.size(1):
            raise ValueError("transformation_matrix should be square. Got " +
                             "[{} x {}] rectangular matrix.".format(*transformation_matrix.size()))

        if mean_vector.size(0) != transformation_matrix.size(0):
            raise ValueError("mean_vector should have the same length {}".format(mean_vector.size(0)) +
                             " as any one of the dimensions of the transformation_matrix [{}]"
                             .format(tuple(transformation_matrix.size())))

        if transformation_matrix.device != mean_vector.device:
            raise ValueError("Input tensors should be on the same device. Got {} and {}"
                             .format(transformation_matrix.device, mean_vector.device))

        self.transformation_matrix = transformation_matrix
        self.mean_vector = mean_vector

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Args:
            tensor (Tensor): Tensor image to be whitened. Its last three dimensions (C, H, W)
                are flattened together; any leading dimensions are flattened into rows.

        Returns:
            Tensor: Transformed image, with the same shape as the input.

        Raises:
            ValueError: if C x H x W does not match the size of the transformation matrix, or if
                the input is on a different device type than ``mean_vector``.
        """
        shape = tensor.shape
        n = shape[-3] * shape[-2] * shape[-1]
        if n != self.transformation_matrix.shape[0]:
            raise ValueError("Input tensor and transformation matrix have incompatible shape. " +
                             "[{} x {} x {}] != ".format(shape[-3], shape[-2], shape[-1]) +
                             "{}".format(self.transformation_matrix.shape[0]))

        if tensor.device.type != self.mean_vector.device.type:
            raise ValueError("Input tensor should be on the same device as transformation matrix and mean vector. "
                             "Got {} vs {}".format(tensor.device, self.mean_vector.device))

        # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs such as permuted tensors.
        flat_tensor = tensor.reshape(-1, n) - self.mean_vector
        # ``torch.mm`` requires both operands to share a dtype. Promote both operands to a common
        # dtype instead of failing on e.g. a float64 image combined with a float32 matrix.
        dtype = torch.promote_types(flat_tensor.dtype, self.transformation_matrix.dtype)
        transformed_tensor = torch.mm(flat_tensor.to(dtype), self.transformation_matrix.to(dtype))
        return transformed_tensor.view(shape)

    def __repr__(self):
        format_string = self.__class__.__name__ + '(transformation_matrix='
        format_string += str(self.transformation_matrix.tolist())
        format_string += ', mean_vector=' + str(self.mean_vector.tolist()) + ')'
        return format_string


class ColorJitter(torch.nn.Module):
    """Randomly change the brightness, contrast, saturation and hue of an image.
    If the image is torch Tensor, it is expected
    to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, mode "1", "L", "I", "F" and modes with transparency (alpha channel) are not supported.

    Args:
        brightness (float or tuple of float (min, max)): How much to jitter brightness.
            brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
            or the given [min, max]. Should be non negative numbers.
        contrast (float or tuple of float (min, max)): How much to jitter contrast.
            contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
            or the given [min, max]. Should be non negative numbers.
        saturation (float or tuple of float (min, max)): How much to jitter saturation.
            saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
            or the given [min, max]. Should be non negative numbers.
        hue (float or tuple of float (min, max)): How much to jitter hue.
            hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
            Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        super().__init__()
        self.brightness = self._check_input(brightness, 'brightness')
        self.contrast = self._check_input(contrast, 'contrast')
        self.saturation = self._check_input(saturation, 'saturation')
        self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
                                     clip_first_on_zero=False)

    @torch.jit.unused
    def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
        """Validate one jitter argument and turn it into a ``[min, max]`` range.

        Args:
            value (number or list/tuple of length 2): jitter amount around ``center``, or an
                explicit ``(min, max)`` range.
            name (str): argument name, used in error messages.
            center (float): factor at which the adjustment is a no-op.
            bound (tuple of float): inclusive limits the resulting range must lie within.
            clip_first_on_zero (bool): clamp the lower end of a range derived from a number at 0.

        Returns:
            The range (a list for number input, the given list/tuple otherwise), or ``None`` when
            the range is exactly ``[center, center]``.

        Raises:
            ValueError: if a number is negative or the range falls outside ``bound``.
            TypeError: if ``value`` is neither a number nor a list/tuple of length 2.
        """
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError("If {} is a single number, it must be non negative.".format(name))
            value = [center - float(value), center + float(value)]
            if clip_first_on_zero:
                value[0] = max(value[0], 0.0)
        elif not (isinstance(value, (tuple, list)) and len(value) == 2):
            raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))

        # Validate the bounds for both input forms. Previously only explicit ranges were checked,
        # so e.g. ``hue=0.7`` (range [-0.7, 0.7]) slipped through despite the documented 0 <= hue <= 0.5.
        if not bound[0] <= value[0] <= value[1] <= bound[1]:
            raise ValueError("{} values should be between {}".format(name, bound))

        # if value is 0 or (1., 1.) for brightness/contrast/saturation
        # or (0., 0.) for hue, do nothing
        if value[0] == value[1] == center:
            value = None
        return value

    @staticmethod
    def get_params(brightness: Optional[List[float]],
                   contrast: Optional[List[float]],
                   saturation: Optional[List[float]],
                   hue: Optional[List[float]]
                   ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
        """Get the parameters for the randomized transform to be applied on image.

        Args:
            brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
                uniformly. Pass None to turn off the transformation.
            contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
                uniformly. Pass None to turn off the transformation.
            saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
                uniformly. Pass None to turn off the transformation.
            hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
                Pass None to turn off the transformation.

        Returns:
            tuple: The parameters used to apply the randomized transform
            along with their random order.
        """
        fn_idx = torch.randperm(4)

        b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
        c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
        s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
        h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))

        return fn_idx, b, c, s, h

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Input image.

        Returns:
            PIL Image or Tensor: Color jittered image.
        """
        fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \
            self.get_params(self.brightness, self.contrast, self.saturation, self.hue)

        for fn_id in fn_idx:
            if fn_id == 0 and brightness_factor is not None:
                img = F.adjust_brightness(img, brightness_factor)
            elif fn_id == 1 and contrast_factor is not None:
                img = F.adjust_contrast(img, contrast_factor)
            elif fn_id == 2 and saturation_factor is not None:
                img = F.adjust_saturation(img, saturation_factor)
            elif fn_id == 3 and hue_factor is not None:
                img = F.adjust_hue(img, hue_factor)

        return img

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        format_string += 'brightness={0}'.format(self.brightness)
        format_string += ', contrast={0}'.format(self.contrast)
        format_string += ', saturation={0}'.format(self.saturation)
        format_string += ', hue={0})'.format(self.hue)
        return format_string


class RandomRotation(torch.nn.Module):
    """Rotate the image by angle.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        degrees (sequence or number): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees).
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
        expand (bool, optional): Optional expansion flag.
            If true, expands the output to make it large enough to hold the entire rotated image.
            If false or omitted, make the output image the same size as the input image.
            Note that the expand flag assumes rotation around the center and no translation.
        center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
            Default is the center of the image.
        fill (sequence or number): Pixel fill value for the area outside the rotated
            image. Default is ``0``; ``None`` is treated as ``0``. If given a number, the value is used
            for all bands respectively.
            If input is PIL Image, this option is only available for ``Pillow>=5.2.0``.
        resample (int, optional): deprecated argument and will be removed since v0.10.0.
            Please use the ``interpolation`` parameter instead. When given, it overrides ``interpolation``.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """

    def __init__(
        self, degrees, interpolation=InterpolationMode.NEAREST, expand=False, center=None, fill=0, resample=None
    ):
        super().__init__()

        # The deprecated ``resample`` argument wins over ``interpolation`` when it is given.
        if resample is not None:
            warnings.warn(
                "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead"
            )
            interpolation = _interpolation_modes_from_int(resample)

        # Backward compatibility with integer value
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument interpolation should be of type InterpolationMode instead of int. "
                "Please, use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)

        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, ))

        if center is not None:
            _check_sequence_input(center, "center", req_sizes=(2, ))
        self.center = center

        self.resample = self.interpolation = interpolation
        self.expand = expand

        if fill is not None and not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")
        self.fill = 0 if fill is None else fill

    @staticmethod
    def get_params(degrees: List[float]) -> float:
        """Get parameters for ``rotate`` for a random rotation.

        Returns:
            float: angle parameter to be passed to ``rotate`` for random rotation.
        """
        low, high = float(degrees[0]), float(degrees[1])
        return float(torch.empty(1).uniform_(low, high).item())

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be rotated.

        Returns:
            PIL Image or Tensor: Rotated image.
        """
        fill = self.fill
        if isinstance(img, Tensor):
            # For tensors, fill becomes a list of floats (a number is repeated once per channel).
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * F._get_image_num_channels(img)
            else:
                fill = [float(value) for value in fill]
        angle = self.get_params(self.degrees)
        return F.rotate(img, angle, self.resample, self.expand, self.center, fill)

    def __repr__(self):
        fields = [
            'degrees={0}'.format(self.degrees),
            'interpolation={0}'.format(self.interpolation.value),
            'expand={0}'.format(self.expand),
        ]
        if self.center is not None:
            fields.append('center={0}'.format(self.center))
        if self.fill is not None:
            fields.append('fill={0}'.format(self.fill))
        return '{0}({1})'.format(self.__class__.__name__, ', '.join(fields))


class RandomAffine(torch.nn.Module):
    """Random affine transformation of the image keeping center invariant.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        degrees (sequence or number): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees). Set to 0 to deactivate rotations.
        translate (tuple, optional): tuple of maximum absolute fraction for horizontal
            and vertical translations. For example translate=(a, b), then horizontal shift
            is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
            randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
        scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
            randomly sampled from the range a <= scale <= b. Will keep original scale by default.
        shear (sequence or number, optional): Range of degrees to select from.
            If shear is a number, a shear parallel to the x axis in the range (-shear, +shear)
            will be applied. Else if shear is a sequence of 2 values a shear parallel to the x axis in the
            range (shear[0], shear[1]) will be applied. Else if shear is a sequence of 4 values,
            a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
            Will not apply shear by default.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
        fill (sequence or number): Pixel fill value for the area outside the transformed
            image. Default is ``0``; ``None`` is treated as ``0``. If given a number, the value is used
            for all bands respectively.
            If input is PIL Image, this option is only available for ``Pillow>=5.0.0``.
        fillcolor (sequence or number, optional): deprecated argument and will be removed since v0.10.0.
            Please use the ``fill`` parameter instead. When given, it overrides ``fill``.
        resample (int, optional): deprecated argument and will be removed since v0.10.0.
            Please use the ``interpolation`` parameter instead. When given, it overrides ``interpolation``.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """

    def __init__(
        self, degrees, translate=None, scale=None, shear=None, interpolation=InterpolationMode.NEAREST, fill=0,
        fillcolor=None, resample=None
    ):
        super().__init__()

        # Deprecated aliases take precedence over their replacements when they are given.
        if resample is not None:
            warnings.warn(
                "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead"
            )
            interpolation = _interpolation_modes_from_int(resample)

        # Backward compatibility with integer value
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument interpolation should be of type InterpolationMode instead of int. "
                "Please, use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)

        if fillcolor is not None:
            warnings.warn(
                "Argument fillcolor is deprecated and will be removed since v0.10.0. Please, use fill instead"
            )
            fill = fillcolor

        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, ))

        if translate is not None:
            _check_sequence_input(translate, "translate", req_sizes=(2, ))
            if any(not (0.0 <= t <= 1.0) for t in translate):
                raise ValueError("translation values should be between 0 and 1")
        self.translate = translate

        if scale is not None:
            _check_sequence_input(scale, "scale", req_sizes=(2, ))
            if any(s <= 0 for s in scale):
                raise ValueError("scale values should be positive")
        self.scale = scale

        self.shear = None if shear is None else _setup_angle(shear, name="shear", req_sizes=(2, 4))

        self.resample = self.interpolation = interpolation

        if fill is not None and not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")
        self.fillcolor = self.fill = 0 if fill is None else fill

    @staticmethod
    def get_params(
            degrees: List[float],
            translate: Optional[List[float]],
            scale_ranges: Optional[List[float]],
            shears: Optional[List[float]],
            img_size: List[int]
    ) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]:
        """Get parameters for affine transformation

        Random values are drawn in a fixed order: angle, x/y translation, scale, x/y shear.
        Disabled components (``None``) draw nothing and take their neutral value.

        Returns:
            params to be passed to the affine transformation
        """
        angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())

        translations = (0, 0)
        if translate is not None:
            # img_size is (width, height): horizontal shift scales with width, vertical with height.
            max_dx = float(translate[0] * img_size[0])
            max_dy = float(translate[1] * img_size[1])
            tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
            ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
            translations = (tx, ty)

        scale = 1.0
        if scale_ranges is not None:
            scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item())

        shear_x = shear_y = 0.0
        if shears is not None:
            shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item())
            if len(shears) == 4:
                shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item())

        return angle, translations, scale, (shear_x, shear_y)

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Affine transformed image.
        """
        fill = self.fill
        if isinstance(img, Tensor):
            # For tensors, fill becomes a list of floats (a number is repeated once per channel).
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * F._get_image_num_channels(img)
            else:
                fill = [float(value) for value in fill]

        img_size = F._get_image_size(img)
        params = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)
        return F.affine(img, *params, interpolation=self.interpolation, fill=fill)

    def __repr__(self):
        # Optional settings are only listed when they differ from their defaults.
        fields = ['degrees={}'.format(self.degrees)]
        if self.translate is not None:
            fields.append('translate={}'.format(self.translate))
        if self.scale is not None:
            fields.append('scale={}'.format(self.scale))
        if self.shear is not None:
            fields.append('shear={}'.format(self.shear))
        if self.interpolation != InterpolationMode.NEAREST:
            fields.append('interpolation={}'.format(self.interpolation.value))
        if self.fill != 0:
            fields.append('fill={}'.format(self.fill))
        return '{}({})'.format(self.__class__.__name__, ', '.join(fields))


class Grayscale(torch.nn.Module):
    """Convert an image to grayscale.

    If the image is a torch Tensor, it is expected to have [..., 3, H, W] shape,
    where ... means an arbitrary number of leading dimensions.

    Args:
        num_output_channels (int): (1 or 3) number of channels desired for output image

    Returns:
        PIL Image or Tensor: Grayscale version of the input.

         - If ``num_output_channels == 1`` : returned image is single channel
         - If ``num_output_channels == 3`` : returned image is 3 channel with r == g == b
    """

    def __init__(self, num_output_channels=1):
        super().__init__()
        self.num_output_channels = num_output_channels

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be converted to grayscale.

        Returns:
            PIL Image or Tensor: Grayscaled image.
        """
        channels = self.num_output_channels
        return F.rgb_to_grayscale(img, num_output_channels=channels)

    def __repr__(self):
        suffix = '(num_output_channels={0})'.format(self.num_output_channels)
        return type(self).__name__ + suffix
1469

1470

1471
class RandomGrayscale(torch.nn.Module):
    """Randomly convert an image to grayscale with probability ``p`` (default 0.1).

    If the image is a torch Tensor, it is expected to have [..., 3, H, W] shape,
    where ... means an arbitrary number of leading dimensions.

    Args:
        p (float): probability that image should be converted to grayscale.

    Returns:
        PIL Image or Tensor: Grayscale version of the input image with probability p and unchanged
        with probability (1-p).
        - If input image is 1 channel: grayscale version is 1 channel
        - If input image is 3 channel: grayscale version is 3 channel with r == g == b
    """

    def __init__(self, p=0.1):
        super().__init__()
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be converted to grayscale.

        Returns:
            PIL Image or Tensor: Randomly grayscaled image.
        """
        # The output keeps the input's channel count; it is queried before the
        # random draw, exactly as in every call regardless of the outcome.
        channels = F._get_image_num_channels(img)
        apply_transform = torch.rand(1) < self.p
        if not apply_transform:
            return img
        return F.rgb_to_grayscale(img, num_output_channels=channels)

    def __repr__(self):
        suffix = '(p={0})'.format(self.p)
        return type(self).__name__ + suffix
1506
1507


1508
class RandomErasing(torch.nn.Module):
    """Randomly selects a rectangle region in a torch Tensor image and erases its pixels.
    This transform does not support PIL Image.
    'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896

    Args:
         p: probability that the random erasing operation will be performed.
         scale: range of proportion of erased area against input image.
         ratio: range of aspect ratio of erased area. The ratio is sampled
            uniformly in log-space, so for a range such as ``(1/3, 3)`` a
            ratio ``r`` and its inverse ``1/r`` are equally likely.
         value: erasing value. Default is 0. If a single int, it is used to
            erase all pixels. If a tuple of length 3, it is used to erase
            R, G, B channels respectively.
            If a str of 'random', erasing each pixel with random values.
            A sequence must contain either a single value or one value per
            input channel.
         inplace: boolean to make this transform inplace. Default set to False.

    Returns:
        Erased Image.

    Example:
        >>> transform = transforms.Compose([
        >>>   transforms.RandomHorizontalFlip(),
        >>>   transforms.ToTensor(),
        >>>   transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>>   transforms.RandomErasing(),
        >>> ])
    """

    def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
        super().__init__()
        if not isinstance(value, (numbers.Number, str, tuple, list)):
            raise TypeError("Argument value should be either a number or str or a sequence")
        if isinstance(value, str) and value != "random":
            raise ValueError("If value is str, it should be 'random'")
        if not isinstance(scale, (tuple, list)):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, (tuple, list)):
            raise TypeError("Ratio should be a sequence")
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")
        if scale[0] < 0 or scale[1] > 1:
            raise ValueError("Scale should be between 0 and 1")
        if p < 0 or p > 1:
            raise ValueError("Random erasing probability should be between 0 and 1")

        self.p = p
        self.scale = scale
        self.ratio = ratio
        self.value = value
        self.inplace = inplace

    @staticmethod
    def get_params(
            img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None
    ) -> Tuple[int, int, int, int, Tensor]:
        """Get parameters for ``erase`` for a random erasing.

        Args:
            img (Tensor): Tensor image to be erased.
            scale (sequence): range of proportion of erased area against input image.
            ratio (sequence): range of aspect ratio of erased area.
            value (list, optional): erasing value. If None, it is interpreted as "random"
                (erasing each pixel with random values). If ``len(value)`` is 1, it is interpreted as a number,
                i.e. ``value[0]``.

        Returns:
            tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
            If no fitting rectangle is found within 10 attempts, ``(0, 0, img_h, img_w, img)``
            is returned, i.e. the whole image is "erased" with itself.
        """
        img_c, img_h, img_w = img.shape[-3], img.shape[-2], img.shape[-1]
        area = img_h * img_w

        # Sample the aspect ratio uniformly in log-space so that r and 1/r are
        # equally likely. Sampling linearly on e.g. (0.3, 3.3) would pick r < 1
        # (wide patches) only ~23% of the time.
        log_ratio = torch.log(torch.tensor(ratio))
        for _ in range(10):
            erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
            aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()

            h = int(round(math.sqrt(erase_area * aspect_ratio)))
            w = int(round(math.sqrt(erase_area / aspect_ratio)))
            if not (h < img_h and w < img_w):
                continue

            if value is None:
                v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
            else:
                # Shape (len(value), 1, 1) so it broadcasts over the erased region.
                v = torch.tensor(value)[:, None, None]

            i = torch.randint(0, img_h - h + 1, size=(1, )).item()
            j = torch.randint(0, img_w - w + 1, size=(1, )).item()
            return i, j, h, w, v

        # Return original image
        return 0, 0, img_h, img_w, img

    def forward(self, img):
        """
        Args:
            img (Tensor): Tensor image to be erased.

        Returns:
            img (Tensor): Erased Tensor image.

        Raises:
            ValueError: if ``value`` is a sequence whose length is neither 1 nor
                the number of input channels.
        """
        if torch.rand(1) < self.p:

            # cast self.value to script acceptable type
            if isinstance(self.value, (int, float)):
                value = [self.value, ]
            elif isinstance(self.value, str):
                value = None
            elif isinstance(self.value, tuple):
                value = list(self.value)
            else:
                value = self.value

            if value is not None and not (len(value) in (1, img.shape[-3])):
                raise ValueError(
                    "If value is a sequence, it should have either a single value or "
                    "{} (number of input channels)".format(img.shape[-3])
                )

            x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value)
            return F.erase(img, x, y, h, w, v, self.inplace)
        return img

    def __repr__(self):
        s = '(p={}, scale={}, ratio={}, value={}, inplace={})'.format(
            self.p, self.scale, self.ratio, self.value, self.inplace
        )
        return self.__class__.__name__ + s
1628
1629


1630
1631
class GaussianBlur(torch.nn.Module):
    """Blur an image with a Gaussian kernel whose sigma is drawn at random.

    If the image is a torch Tensor, it is expected to have [..., C, H, W] shape,
    where ... means an arbitrary number of leading dimensions.

    Args:
        kernel_size (int or sequence): Size of the Gaussian kernel. Every entry
            must be odd and positive.
        sigma (float or tuple of float (min, max)): Standard deviation to be used for
            creating kernel to perform blurring. If float, sigma is fixed. If it is tuple
            of float (min, max), sigma is chosen uniformly at random to lie in the
            given range.

    Returns:
        PIL Image or Tensor: Gaussian blurred version of the input image.

    """

    def __init__(self, kernel_size, sigma=(0.1, 2.0)):
        super().__init__()
        self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
        if any(ks <= 0 or ks % 2 == 0 for ks in self.kernel_size):
            raise ValueError("Kernel size value should be an odd and positive number.")

        if isinstance(sigma, numbers.Number):
            if sigma <= 0:
                raise ValueError("If sigma is a single number, it must be positive.")
            # A fixed sigma is stored as a degenerate (min, max) range.
            sigma = (sigma, sigma)
        elif not (isinstance(sigma, Sequence) and len(sigma) == 2):
            raise ValueError("sigma should be a single number or a list/tuple with length 2.")
        elif not 0. < sigma[0] <= sigma[1]:
            raise ValueError("sigma values should be positive and of the form (min, max).")

        self.sigma = sigma

    @staticmethod
    def get_params(sigma_min: float, sigma_max: float) -> float:
        """Draw the sigma used for one random Gaussian blur.

        Args:
            sigma_min (float): Minimum standard deviation that can be chosen for blurring kernel.
            sigma_max (float): Maximum standard deviation that can be chosen for blurring kernel.

        Returns:
            float: Standard deviation to be passed to calculate kernel for gaussian blurring.
        """
        draw = torch.empty(1).uniform_(sigma_min, sigma_max)
        return draw.item()

    def forward(self, img: Tensor) -> Tensor:
        """
        Args:
            img (PIL Image or Tensor): image to be blurred.

        Returns:
            PIL Image or Tensor: Gaussian blurred image
        """
        chosen_sigma = self.get_params(self.sigma[0], self.sigma[1])
        # The same sigma is used along both spatial axes.
        return F.gaussian_blur(img, self.kernel_size, [chosen_sigma, chosen_sigma])

    def __repr__(self):
        pieces = [
            '(kernel_size={}, '.format(self.kernel_size),
            'sigma={})'.format(self.sigma),
        ]
        return self.__class__.__name__ + ''.join(pieces)


1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
def _setup_size(size, error_msg):
    if isinstance(size, numbers.Number):
        return int(size), int(size)

    if isinstance(size, Sequence) and len(size) == 1:
        return size[0], size[0]

    if len(size) != 2:
        raise ValueError(error_msg)

    return size


def _check_sequence_input(x, name, req_sizes):
    msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join([str(s) for s in req_sizes])
    if not isinstance(x, Sequence):
        raise TypeError("{} should be a sequence of length {}.".format(name, msg))
    if len(x) not in req_sizes:
        raise ValueError("{} should be sequence of length {}.".format(name, msg))


def _setup_angle(x, name, req_sizes=(2, )):
    if isinstance(x, numbers.Number):
        if x < 0:
            raise ValueError("If {} is a single number, it must be positive.".format(name))
        x = [-x, x]
    else:
        _check_sequence_input(x, name, req_sizes)

    return [float(d) for d in x]
1726
1727
1728
1729


class RandomInvert(torch.nn.Module):
    """Invert the colors of an image with probability ``p``.

    If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
    where ... means it can have an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        p (float): probability of the image being color inverted. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be inverted.

        Returns:
            PIL Image or Tensor: Randomly color inverted image.
        """
        should_invert = torch.rand(1).item() < self.p
        return F.invert(img) if should_invert else img

    def __repr__(self):
        suffix = '(p={})'.format(self.p)
        return type(self).__name__ + suffix


class RandomPosterize(torch.nn.Module):
    """Posterize the image randomly with a given probability by reducing the
    number of bits for each color channel. If the image is torch Tensor, it should be of type torch.uint8,
    and it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        bits (int): number of bits to keep for each channel (0-8)
        p (float): probability of the image being posterized. Default value is 0.5
    """

    def __init__(self, bits, p=0.5):
        super().__init__()
        self.bits = bits
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be posterized.

        Returns:
            PIL Image or Tensor: Posterized image with probability ``p``, otherwise
            the input image unchanged.
        """
        if torch.rand(1).item() < self.p:
            return F.posterize(img, self.bits)
        return img

    def __repr__(self):
        return self.__class__.__name__ + '(bits={},p={})'.format(self.bits, self.p)


class RandomSolarize(torch.nn.Module):
    """Solarize the image randomly with a given probability by inverting all pixel
    values above a threshold. If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
    where ... means it can have an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        threshold (float): all pixels equal or above this value are inverted.
        p (float): probability of the image being solarized. Default value is 0.5
    """

    def __init__(self, threshold, p=0.5):
        super().__init__()
        self.threshold = threshold
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be solarized.

        Returns:
            PIL Image or Tensor: Solarized image with probability ``p``, otherwise
            the input image unchanged.
        """
        if torch.rand(1).item() < self.p:
            return F.solarize(img, self.threshold)
        return img

    def __repr__(self):
        return self.__class__.__name__ + '(threshold={},p={})'.format(self.threshold, self.p)


class RandomAdjustSharpness(torch.nn.Module):
    """Adjust the sharpness of the image randomly with a given probability. If the image is torch Tensor,
    it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        sharpness_factor (float):  How much to adjust the sharpness. Can be
            any non negative number. 0 gives a blurred image, 1 gives the
            original image while 2 increases the sharpness by a factor of 2.
        p (float): probability of the image sharpness being adjusted. Default value is 0.5
    """

    def __init__(self, sharpness_factor, p=0.5):
        super().__init__()
        self.sharpness_factor = sharpness_factor
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be sharpened.

        Returns:
            PIL Image or Tensor: Sharpness-adjusted image with probability ``p``,
            otherwise the input image unchanged.
        """
        if torch.rand(1).item() < self.p:
            return F.adjust_sharpness(img, self.sharpness_factor)
        return img

    def __repr__(self):
        return self.__class__.__name__ + '(sharpness_factor={},p={})'.format(self.sharpness_factor, self.p)


class RandomAutocontrast(torch.nn.Module):
    """Apply autocontrast to an image with probability ``p``.

    If the image is a torch Tensor, it is expected to have [..., 1 or 3, H, W]
    shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        p (float): probability of the image being autocontrasted. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be autocontrasted.

        Returns:
            PIL Image or Tensor: Randomly autocontrasted image.
        """
        should_apply = torch.rand(1).item() < self.p
        return F.autocontrast(img) if should_apply else img

    def __repr__(self):
        suffix = '(p={})'.format(self.p)
        return type(self).__name__ + suffix


class RandomEqualize(torch.nn.Module):
    """Equalize the histogram of an image with probability ``p``.

    If the image is a torch Tensor, it is expected to have [..., 1 or 3, H, W]
    shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".

    Args:
        p (float): probability of the image being equalized. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be equalized.

        Returns:
            PIL Image or Tensor: Randomly equalized image.
        """
        should_apply = torch.rand(1).item() < self.p
        return F.equalize(img) if should_apply else img

    def __repr__(self):
        suffix = '(p={})'.format(self.p)
        return type(self).__name__ + suffix