transforms.py 84.2 KB
Newer Older
1
import math
vfdev's avatar
vfdev committed
2
import numbers
3
import random
vfdev's avatar
vfdev committed
4
import warnings
vfdev's avatar
vfdev committed
5
from collections.abc import Sequence
6
from typing import List, Optional, Tuple, Union
vfdev's avatar
vfdev committed
7
8
9
10

import torch
from torch import Tensor

11
12
13
14
15
try:
    import accimage
except ImportError:
    accimage = None

16
from ..utils import _log_api_usage_once
17
from . import functional as F
18
from .functional import _interpolation_modes_from_int, InterpolationMode
19

20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "Compose",
    "ToTensor",
    "PILToTensor",
    "ConvertImageDtype",
    "ToPILImage",
    "Normalize",
    "Resize",
    "CenterCrop",
    "Pad",
    "Lambda",
    "RandomApply",
    "RandomChoice",
    "RandomOrder",
    "RandomCrop",
    "RandomHorizontalFlip",
    "RandomVerticalFlip",
    "RandomResizedCrop",
    "FiveCrop",
    "TenCrop",
    "LinearTransformation",
    "ColorJitter",
    "RandomRotation",
    "RandomAffine",
    "Grayscale",
    "RandomGrayscale",
    "RandomPerspective",
    "RandomErasing",
    "GaussianBlur",
    "InterpolationMode",
    "RandomInvert",
    "RandomPosterize",
    "RandomSolarize",
    "RandomAdjustSharpness",
    "RandomAutocontrast",
    "RandomEqualize",
    "ElasticTransform",
]
58

59

60
class Compose:
    """Chain several transforms so that each one is applied to the output of the previous one.

    This transform does not support torchscript. Please, see the note below.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>> ])

    .. note::
        In order to script the transformations, please use ``torch.nn.Sequential`` as below.

        >>> transforms = torch.nn.Sequential(
        >>>     transforms.CenterCrop(10),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>> )
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, do not require
        `lambda` functions or ``PIL.Image``.
    """

    def __init__(self, transforms):
        # Usage logging is skipped while TorchScript is scripting or tracing.
        in_jit = torch.jit.is_scripting() or torch.jit.is_tracing()
        if not in_jit:
            _log_api_usage_once(self)
        self.transforms = transforms

    def __call__(self, img):
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result

    def __repr__(self) -> str:
        # One transform per line, indented by four spaces, closed on its own line.
        parts = [self.__class__.__name__ + "("]
        parts.extend(f"\n    {t}" for t in self.transforms)
        parts.append("\n)")
        return "".join(parts)

106

107
class ToTensor:
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript.

    A PIL Image or numpy.ndarray of shape (H x W x C) with values in [0, 255] is converted to a
    torch.FloatTensor of shape (C x H x W) with values in [0.0, 1.0], provided the PIL Image has one of the
    modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) or the numpy.ndarray has dtype = np.uint8.

    In all other cases, tensors are returned without scaling.

    .. note::
        Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
        transforming target image masks. See the `references`_ for implementing the transforms for image masks.

    .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
    """

    def __init__(self) -> None:
        _log_api_usage_once(self)

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.to_tensor(pic)

    def __repr__(self) -> str:
        return type(self).__name__ + "()"
139

140

141
class PILToTensor:
    """Convert a ``PIL Image`` to a tensor of the same type. This transform does not support torchscript.

    A PIL Image of shape (H x W x C) becomes a Tensor of shape (C x H x W).
    """

    def __init__(self) -> None:
        _log_api_usage_once(self)

    def __call__(self, pic):
        """
        .. note::

            A deep copy of the underlying array is performed.

        Args:
            pic (PIL Image): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.pil_to_tensor(pic)

    def __repr__(self) -> str:
        return type(self).__name__ + "()"
166
167


168
class ConvertImageDtype(torch.nn.Module):
    """Convert a tensor image to the given ``dtype`` and rescale its values accordingly.

    This function does not support PIL Image.

    Args:
        dtype (torch.dtype): Desired data type of the output

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """

    def __init__(self, dtype: torch.dtype) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.dtype = dtype

    def forward(self, image):
        target_dtype = self.dtype
        return F.convert_image_dtype(image, target_dtype)


196
class ToPILImage:
    """Convert a tensor or an ndarray to PIL Image. This transform does not support torchscript.

    A torch.*Tensor of shape C x H x W or a numpy ndarray of shape H x W x C is converted
    to a PIL Image; the value range is preserved.

    Args:
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
            If ``mode`` is ``None`` (default) there are some assumptions made about the input data:

            - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
            - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
            - If the input has 2 channels, the ``mode`` is assumed to be ``LA``.
            - If the input has 1 channel, the ``mode`` is determined by the data type (i.e ``int``, ``float``,
              ``short``).

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
    """

    def __init__(self, mode=None):
        _log_api_usage_once(self)
        self.mode = mode

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.

        Returns:
            PIL Image: Image converted to PIL Image.

        """
        return F.to_pil_image(pic, self.mode)

    def __repr__(self) -> str:
        # The mode is only shown when one was explicitly given.
        mode_part = "" if self.mode is None else f"mode={self.mode}"
        return f"{self.__class__.__name__}({mode_part})"
235

236

237
class Normalize(torch.nn.Module):
    """Normalize a tensor image with mean and standard deviation.

    This transform does not support PIL Image.
    Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
    channels, each channel of the input ``torch.*Tensor`` is normalized as

    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``

    .. note::
        With the default ``inplace=False`` this transform acts out of place, i.e., it does not mutate
        the input tensor.

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace (bool, optional): Bool to make this operation in-place.
    """

    def __init__(self, mean, std, inplace=False):
        super().__init__()
        _log_api_usage_once(self)
        self.mean, self.std, self.inplace = mean, std, inplace

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Args:
            tensor (Tensor): Tensor image to be normalized.

        Returns:
            Tensor: Normalized Tensor image.
        """
        return F.normalize(tensor, self.mean, self.std, self.inplace)

    def __repr__(self) -> str:
        return "{}(mean={}, std={})".format(type(self).__name__, self.mean, self.std)
274

275

vfdev's avatar
vfdev committed
276
277
class Resize(torch.nn.Module):
    """Resize the input image to the given size.

    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    .. warning::
        The output image might be different depending on its type: when downsampling, the interpolation of PIL images
        and tensors is slightly different, because PIL applies antialiasing. This may lead to significant differences
        in the performance of a network. Therefore, it is preferable to train and serve a model with the same input
        types. See also below the ``antialias`` parameter, which can help making the output of PIL images and tensors
        closer.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size).

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
            but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
        max_size (int, optional): The maximum allowed for the longer edge of
            the resized image: if the longer edge of the image is greater
            than ``max_size`` after being resized according to ``size``, then
            the image is resized again so that the longer edge is equal to
            ``max_size``. As a result, ``size`` might be overruled, i.e. the
            smaller edge may be shorter than ``size``. This is only supported
            if ``size`` is an int (or a sequence of length 1 in torchscript
            mode).
        antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias
            is always used. If ``img`` is Tensor, the flag is False by default and can be set to True for
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` modes.
            This can help making the output for PIL images and tensors closer.
    """

    def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias=None):
        super().__init__()
        _log_api_usage_once(self)

        is_sequence = isinstance(size, Sequence)
        if not (is_sequence or isinstance(size, int)):
            raise TypeError(f"Size should be int or sequence. Got {type(size)}")
        if is_sequence and len(size) not in (1, 2):
            raise ValueError("If size is a sequence, it should have 1 or 2 values")
        self.size = size
        self.max_size = max_size

        if isinstance(interpolation, int):
            # Legacy integer interpolation codes are still accepted, with a deprecation warning.
            warnings.warn(
                "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
                "Please use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)

        self.interpolation = interpolation
        self.antialias = antialias

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be scaled.

        Returns:
            PIL Image or Tensor: Rescaled image.
        """
        return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias)

    def __repr__(self) -> str:
        fields = (
            f"size={self.size}",
            f"interpolation={self.interpolation.value}",
            f"max_size={self.max_size}",
            f"antialias={self.antialias}",
        )
        return f"{self.__class__.__name__}({', '.join(fields)})"
351

352

vfdev's avatar
vfdev committed
353
354
class CenterCrop(torch.nn.Module):
    """Crop the given image at its center.

    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
    If the image is smaller than the output size along any edge, it is padded with 0 and then center cropped.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
    """

    def __init__(self, size):
        super().__init__()
        _log_api_usage_once(self)
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            PIL Image or Tensor: Cropped image.
        """
        return F.center_crop(img, self.size)

    def __repr__(self) -> str:
        return "{}(size={})".format(self.__class__.__name__, self.size)
382

383

384
385
class Pad(torch.nn.Module):
    """Pad the given image on all sides with the given "pad" value.

    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric,
    at most 3 leading dimensions for mode edge,
    and an arbitrary number of leading dimensions for mode constant

    Args:
        padding (int or sequence): Padding on each border. If a single int is provided this
            is used to pad all borders. If sequence of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a sequence of length 4 is provided
            this is the padding for the left, top, right and bottom borders respectively.

            .. note::
                In torchscript mode padding as single int is not supported, use a sequence of
                length 1: ``[padding, ]``.
        fill (number or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
            Only number is supported for torch Tensor.
            Only int or tuple value is supported for PIL Image.
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image.
              If the input is a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2

            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]
    """

    def __init__(self, padding, fill=0, padding_mode="constant"):
        super().__init__()
        _log_api_usage_once(self)

        accepted_types = (numbers.Number, tuple, list)
        supported_modes = ["constant", "edge", "reflect", "symmetric"]

        if not isinstance(padding, accepted_types):
            raise TypeError("Got inappropriate padding arg")
        if not isinstance(fill, accepted_types):
            raise TypeError("Got inappropriate fill arg")
        if padding_mode not in supported_modes:
            raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
        if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]:
            raise ValueError(
                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
            )

        self.padding = padding
        self.fill = fill
        self.padding_mode = padding_mode

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be padded.

        Returns:
            PIL Image or Tensor: Padded image.
        """
        return F.pad(img, self.padding, self.fill, self.padding_mode)

    def __repr__(self) -> str:
        name = self.__class__.__name__
        return f"{name}(padding={self.padding}, fill={self.fill}, padding_mode={self.padding_mode})"
455

456

457
class Lambda:
    """Wrap a user-defined lambda/function so it can be used as a transform.

    This transform does not support torchscript.

    Args:
        lambd (function): Lambda/function to be used for transform.

    Raises:
        TypeError: if ``lambd`` is not callable.
    """

    def __init__(self, lambd):
        _log_api_usage_once(self)
        if callable(lambd):
            self.lambd = lambd
        else:
            raise TypeError(f"Argument lambd should be callable, got {repr(type(lambd).__name__)}")

    def __call__(self, img):
        func = self.lambd
        return func(img)

    def __repr__(self) -> str:
        return type(self).__name__ + "()"
475

476

477
class RandomTransforms:
    """Base class for a list of transformations with randomness.

    Subclasses must implement ``__call__``.

    Args:
        transforms (sequence): list of transformations

    Raises:
        TypeError: if ``transforms`` is not a sequence.
    """

    def __init__(self, transforms):
        _log_api_usage_once(self)
        if isinstance(transforms, Sequence):
            self.transforms = transforms
        else:
            raise TypeError("Argument transforms should be a sequence")

    def __call__(self, *args, **kwargs):
        raise NotImplementedError()

    def __repr__(self) -> str:
        body = "".join(f"\n    {t}" for t in self.transforms)
        return f"{self.__class__.__name__}({body}\n)"


502
class RandomApply(torch.nn.Module):
    """Randomly apply a list of transformations with a given probability.

    .. note::
        In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of
        transforms as shown below:

        >>> transforms = transforms.RandomApply(torch.nn.ModuleList([
        >>>     transforms.ColorJitter(),
        >>> ]), p=0.3)
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, do not require
        `lambda` functions or ``PIL.Image``.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
        p (float): probability of applying the whole list of transformations. Default value is 0.5
    """

    def __init__(self, transforms, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.transforms = transforms
        self.p = p

    def forward(self, img):
        # ``torch.rand`` samples from [0, 1), so the transforms are applied exactly when
        # ``rand < p``, i.e. with probability ``p``: never for p=0, always for p=1. Skipping
        # on ``rand >= p`` (rather than ``p < rand``) ensures a draw of exactly 0 does not
        # apply the transforms when p=0, and matches the flip transforms in this module.
        if torch.rand(1) >= self.p:
            return img
        for t in self.transforms:
            img = t(img)
        return img

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        format_string += f"\n    p={self.p}"
        for t in self.transforms:
            format_string += "\n"
            format_string += f"    {t}"
        format_string += "\n)"
        return format_string


class RandomOrder(RandomTransforms):
    """Apply a list of transformations in a random order. This transform does not support torchscript."""

    def __call__(self, img):
        # Shuffling a copy of the transforms draws the same random numbers and yields the
        # same permutation as shuffling their indices would.
        shuffled = list(self.transforms)
        random.shuffle(shuffled)
        for transform in shuffled:
            img = transform(img)
        return img


class RandomChoice(RandomTransforms):
    """Apply a single transformation randomly picked from a list. This transform does not support torchscript.

    Args:
        transforms (sequence): list of transformations to choose from
        p (sequence, optional): relative weights for the choice, passed to ``random.choices``.
            ``None`` (default) picks uniformly.
    """

    def __init__(self, transforms, p=None):
        super().__init__(transforms)
        if p is None or isinstance(p, Sequence):
            self.p = p
        else:
            raise TypeError("Argument p should be a sequence")

    def __call__(self, *args):
        (chosen,) = random.choices(self.transforms, weights=self.p)
        return chosen(*args)

    def __repr__(self) -> str:
        return super().__repr__() + f"(p={self.p})"
571
572


vfdev's avatar
vfdev committed
573
574
class RandomCrop(torch.nn.Module):
    """Crop the given image at a random location.

    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions,
    but if non-constant padding is used, the input is expected to have at most 2 leading dimensions

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is None. If a single int is provided this
            is used to pad all borders. If sequence of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a sequence of length 4 is provided
            this is the padding for the left, top, right and bottom borders respectively.

            .. note::
                In torchscript mode padding as single int is not supported, use a sequence of
                length 1: ``[padding, ]``.
        pad_if_needed (boolean): It will pad the image if smaller than the
            desired size to avoid raising an exception. Since cropping is done
            after padding, the padding seems to be done at a random offset.
        fill (number or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
            Only number is supported for torch Tensor.
            Only int or tuple value is supported for PIL Image.
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image.
              If the input is a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2

            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]
    """

    @staticmethod
    def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image or Tensor): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.

        Raises:
            ValueError: if the requested crop is larger than the image.
        """
        _, height, width = F.get_dimensions(img)
        crop_h, crop_w = output_size

        if height < crop_h or width < crop_w:
            raise ValueError(f"Required crop size {(crop_h, crop_w)} is larger than input image size {(height, width)}")

        if width == crop_w and height == crop_h:
            return 0, 0, height, width

        # Top offset is drawn before left offset.
        top = torch.randint(0, height - crop_h + 1, size=(1,)).item()
        left = torch.randint(0, width - crop_w + 1, size=(1,)).item()
        return top, left, crop_h, crop_w

    def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
        super().__init__()
        _log_api_usage_once(self)

        self.size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
        self.padding = padding
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        self.padding_mode = padding_mode

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            PIL Image or Tensor: Cropped image.
        """
        if self.padding is not None:
            img = F.pad(img, self.padding, self.fill, self.padding_mode)

        # Dimensions are measured once, after the optional fixed padding.
        _, height, width = F.get_dimensions(img)
        if self.pad_if_needed:
            if width < self.size[1]:
                img = F.pad(img, [self.size[1] - width, 0], self.fill, self.padding_mode)
            if height < self.size[0]:
                img = F.pad(img, [0, self.size[0] - height], self.fill, self.padding_mode)

        top, left, crop_h, crop_w = self.get_params(img, self.size)
        return F.crop(img, top, left, crop_h, crop_w)

    def __repr__(self) -> str:
        return "{}(size={}, padding={})".format(self.__class__.__name__, self.size, self.padding)
679

680

681
682
class RandomHorizontalFlip(torch.nn.Module):
    """Horizontally flip the given image at random with a given probability.

    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be flipped.

        Returns:
            PIL Image or Tensor: Randomly flipped image.
        """
        should_flip = torch.rand(1) < self.p
        if should_flip:
            return F.hflip(img)
        return img

    def __repr__(self) -> str:
        return "{}(p={})".format(type(self).__name__, self.p)
710

711

712
class RandomVerticalFlip(torch.nn.Module):
    """Vertically flip the given image at random with a given probability.

    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be flipped.

        Returns:
            PIL Image or Tensor: Randomly flipped image.
        """
        should_flip = torch.rand(1) < self.p
        if should_flip:
            return F.vflip(img)
        return img

    def __repr__(self) -> str:
        return "{}(p={})".format(type(self).__name__, self.p)
741

742

743
744
class RandomPerspective(torch.nn.Module):
    """Performs a random perspective transformation of the given image with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.
            Default is 0.5.
        p (float): probability of the image being transformed. Default is 0.5.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
            but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
        fill (sequence or number): Pixel fill value for the area outside the transformed
            image. Default is ``0``. If given a number, the value is used for all bands respectively.
    """

    def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationMode.BILINEAR, fill=0):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

        # Integer interpolation codes are a deprecated legacy input; translate them to the enum.
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
                "Please use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)
        self.interpolation = interpolation
        self.distortion_scale = distortion_scale

        # ``None`` is treated as a zero fill; anything other than a number or sequence is rejected.
        if fill is None:
            fill = 0
        elif not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")
        self.fill = fill

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be Perspectively transformed.

        Returns:
            PIL Image or Tensor: Randomly transformed image.
        """
        channels, height, width = F.get_dimensions(img)
        fill = self.fill
        if isinstance(img, Tensor):
            # For tensors, expand / convert the fill into a per-channel list of floats.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            else:
                fill = [float(f) for f in fill]

        if not (torch.rand(1) < self.p):
            return img
        startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
        return F.perspective(img, startpoints, endpoints, self.interpolation, fill)

    @staticmethod
    def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[List[int]], List[List[int]]]:
        """Get parameters for ``perspective`` for a random perspective transform.

        Args:
            width (int): width of the image.
            height (int): height of the image.
            distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.

        Returns:
            List containing [top-left, top-right, bottom-right, bottom-left] of the original image,
            List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image.
        """
        # Maximum inward displacement of each corner along x and y.
        max_dx = int(distortion_scale * (width // 2))
        max_dy = int(distortion_scale * (height // 2))

        # The random draws happen in the fixed order tl.x, tl.y, tr.x, tr.y, br.x, br.y, bl.x, bl.y.
        topleft = [
            int(torch.randint(0, max_dx + 1, size=(1,)).item()),
            int(torch.randint(0, max_dy + 1, size=(1,)).item()),
        ]
        topright = [
            int(torch.randint(width - max_dx - 1, width, size=(1,)).item()),
            int(torch.randint(0, max_dy + 1, size=(1,)).item()),
        ]
        botright = [
            int(torch.randint(width - max_dx - 1, width, size=(1,)).item()),
            int(torch.randint(height - max_dy - 1, height, size=(1,)).item()),
        ]
        botleft = [
            int(torch.randint(0, max_dx + 1, size=(1,)).item()),
            int(torch.randint(height - max_dy - 1, height, size=(1,)).item()),
        ]

        startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
        endpoints = [topleft, topright, botright, botleft]
        return startpoints, endpoints

    def __repr__(self) -> str:
        class_name = type(self).__name__
        return f"{class_name}(p={self.p})"


class RandomResizedCrop(torch.nn.Module):
    """Crop a random portion of image and resize it to a given size.

    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    A crop of the original image is made: the crop has a random area (H * W)
    and a random aspect ratio. This crop is finally resized to the given
    size. This is popularly used to train the Inception networks.

    Args:
        size (int or sequence): expected output size of the crop, for each edge. If size is an
            int instead of sequence like (h, w), a square output size ``(size, size)`` is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        scale (tuple of float): Specifies the lower and upper bounds for the random area of the crop,
            before resizing. The scale is defined with respect to the area of the original image.
        ratio (tuple of float): lower and upper bounds for the random aspect ratio of the crop, before
            resizing.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
            but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
        antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias
            is always used. If ``img`` is Tensor, the flag is False by default and can be set to True for
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` modes.
            This can help making the output for PIL images and tensors closer.
    """

    def __init__(
        self,
        size,
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
        interpolation=InterpolationMode.BILINEAR,
        antialias: Optional[bool] = None,
    ):
        super().__init__()
        _log_api_usage_once(self)
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

        # Both ranges must be indexable (min, max) pairs; an inverted range only warns.
        if not isinstance(scale, Sequence):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, Sequence):
            raise TypeError("Ratio should be a sequence")
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")

        # Integer interpolation codes are a deprecated legacy input; translate them to the enum.
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
                "Please use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)

        self.interpolation = interpolation
        self.antialias = antialias
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img: Tensor, scale: List[float], ratio: List[float]) -> Tuple[int, int, int, int]:
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image or Tensor): Input image.
            scale (list): range of scale of the origin size cropped
            ratio (list): range of aspect ratio of the origin aspect ratio cropped

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
            sized crop.
        """
        _, img_h, img_w = F.get_dimensions(img)
        area = img_h * img_w

        # The aspect ratio is sampled uniformly in log space.
        log_ratio = torch.log(torch.tensor(ratio))
        for _ in range(10):
            target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
            aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()

            crop_w = int(round(math.sqrt(target_area * aspect_ratio)))
            crop_h = int(round(math.sqrt(target_area / aspect_ratio)))

            if 0 < crop_w <= img_w and 0 < crop_h <= img_h:
                top = torch.randint(0, img_h - crop_h + 1, size=(1,)).item()
                left = torch.randint(0, img_w - crop_w + 1, size=(1,)).item()
                return top, left, crop_h, crop_w

        # Fallback to central crop, with the aspect ratio clamped into [min(ratio), max(ratio)].
        in_ratio = float(img_w) / float(img_h)
        if in_ratio < min(ratio):
            crop_w = img_w
            crop_h = int(round(crop_w / min(ratio)))
        elif in_ratio > max(ratio):
            crop_h = img_h
            crop_w = int(round(crop_h * max(ratio)))
        else:  # whole image
            crop_w = img_w
            crop_h = img_h
        top = (img_h - crop_h) // 2
        left = (img_w - crop_w) // 2
        return top, left, crop_h, crop_w

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped and resized.

        Returns:
            PIL Image or Tensor: Randomly cropped and resized image.
        """
        top, left, crop_h, crop_w = self.get_params(img, self.scale, self.ratio)
        return F.resized_crop(
            img, top, left, crop_h, crop_w, self.size, self.interpolation, antialias=self.antialias
        )

    def __repr__(self) -> str:
        parts = [
            f"size={self.size}",
            f"scale={tuple(round(s, 4) for s in self.scale)}",
            f"ratio={tuple(round(r, 4) for r in self.ratio)}",
            f"interpolation={self.interpolation.value}",
            f"antialias={self.antialias}",
        ]
        return f"{self.__class__.__name__}({', '.join(parts)})"


class FiveCrop(torch.nn.Module):
    """Crop the given image into four corners and the central crop.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
         inputs and targets your Dataset returns. See below for an example of how to deal with
         this.

    Args:
         size (sequence or int): Desired output size of the crop. If size is an ``int``
            instead of sequence like (h, w), a square crop of size (size, size) is made.
            If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).

    Example:
         >>> transform = Compose([
         >>>    FiveCrop(size), # this is a list of PIL Images
         >>>    Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor
         >>> ])
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """

    def __init__(self, size):
        super().__init__()
        _log_api_usage_once(self)
        # Normalised crop size, forwarded unchanged to ``F.five_crop``.
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            tuple of 5 images. Image can be PIL Image or Tensor
        """
        crops = F.five_crop(img, self.size)
        return crops

    def __repr__(self) -> str:
        class_name = type(self).__name__
        return f"{class_name}(size={self.size})"


class TenCrop(torch.nn.Module):
    """Crop the given image into four corners and the central crop plus the flipped version of
    these (horizontal flipping is used by default).
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
         inputs and targets your Dataset returns. See below for an example of how to deal with
         this.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
        vertical_flip (bool): Use vertical flipping instead of horizontal

    Example:
         >>> transform = Compose([
         >>>    TenCrop(size), # this is a list of PIL Images
         >>>    Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor
         >>> ])
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """

    def __init__(self, size, vertical_flip=False):
        super().__init__()
        _log_api_usage_once(self)
        # Both settings are forwarded unchanged to ``F.ten_crop``.
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
        self.vertical_flip = vertical_flip

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            tuple of 10 images. Image can be PIL Image or Tensor
        """
        crops = F.ten_crop(img, self.size, self.vertical_flip)
        return crops

    def __repr__(self) -> str:
        class_name = type(self).__name__
        return f"{class_name}(size={self.size}, vertical_flip={self.vertical_flip})"


class LinearTransformation(torch.nn.Module):
    """Transform a tensor image with a square transformation matrix and a mean_vector computed
    offline.
    This transform does not support PIL Image.
    Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and
    subtract mean_vector from it which is then followed by computing the dot
    product with the transformation matrix and then reshaping the tensor to its
    original shape.

    Applications:
        whitening transformation: Suppose X is a column vector zero-centered data.
        Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X),
        perform SVD on this matrix and pass it as transformation_matrix.

    Args:
        transformation_matrix (Tensor): tensor [D x D], D = C x H x W
        mean_vector (Tensor): tensor [D], D = C x H x W

    Raises:
        ValueError: if the matrix is not square, if the mean vector length does not match it,
            or if the two tensors live on different devices.
    """

    def __init__(self, transformation_matrix, mean_vector):
        super().__init__()
        _log_api_usage_once(self)
        if transformation_matrix.size(0) != transformation_matrix.size(1):
            raise ValueError(
                "transformation_matrix should be square. Got "
                f"{tuple(transformation_matrix.size())} rectangular matrix."
            )

        if mean_vector.size(0) != transformation_matrix.size(0):
            raise ValueError(
                f"mean_vector should have the same length {mean_vector.size(0)}"
                f" as any one of the dimensions of the transformation_matrix [{tuple(transformation_matrix.size())}]"
            )

        if transformation_matrix.device != mean_vector.device:
            raise ValueError(
                f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
            )

        self.transformation_matrix = transformation_matrix
        self.mean_vector = mean_vector

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Args:
            tensor (Tensor): Tensor image to be whitened, of shape [..., C, H, W] with C * H * W == D.

        Returns:
            Tensor: Transformed image with the same shape as ``tensor``.

        Raises:
            ValueError: if C * H * W does not match the matrix size, or if ``tensor`` is on a
                different device type than the mean vector.
        """
        shape = tensor.shape
        n = shape[-3] * shape[-2] * shape[-1]
        if n != self.transformation_matrix.shape[0]:
            raise ValueError(
                "Input tensor and transformation matrix have incompatible shape."
                + f"[{shape[-3]} x {shape[-2]} x {shape[-1]}] != "
                + f"{self.transformation_matrix.shape[0]}"
            )

        if tensor.device.type != self.mean_vector.device.type:
            raise ValueError(
                "Input tensor should be on the same device as transformation matrix and mean vector. "
                f"Got {tensor.device} vs {self.mean_vector.device}"
            )

        # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs such as permuted tensors.
        flat_tensor = tensor.reshape(-1, n) - self.mean_vector
        # ``torch.mm`` does not promote dtypes, so match the matrix to the data; otherwise e.g. a
        # float64 whitening matrix applied to float32 images raises a dtype-mismatch error.
        transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype)
        transformed_tensor = torch.mm(flat_tensor, transformation_matrix)
        # ``mm`` returns a contiguous result, so ``view`` back to the input shape is always valid.
        return transformed_tensor.view(shape)

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}(transformation_matrix="
            f"{self.transformation_matrix.tolist()}"
            f", mean_vector={self.mean_vector.tolist()})"
        )
        return s


class ColorJitter(torch.nn.Module):
    """Randomly change the brightness, contrast, saturation and hue of an image.
    If the image is torch Tensor, it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.

    Args:
        brightness (float or tuple of float (min, max)): How much to jitter brightness.
            brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
            or the given [min, max]. Should be non-negative numbers.
        contrast (float or tuple of float (min, max)): How much to jitter contrast.
            contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
            or the given [min, max]. Should be non-negative numbers.
        saturation (float or tuple of float (min, max)): How much to jitter saturation.
            saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
            or the given [min, max]. Should be non-negative numbers.
        hue (float or tuple of float (min, max)): How much to jitter hue.
            hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
            Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
            To jitter hue, the pixel values of the input image has to be non-negative for conversion to HSV space;
            thus it does not work if you normalize your image to an interval with negative values,
            or use an interpolation that generates negative values before using this function.
    """

    def __init__(
        self,
        brightness: Union[float, Tuple[float, float]] = 0,
        contrast: Union[float, Tuple[float, float]] = 0,
        saturation: Union[float, Tuple[float, float]] = 0,
        hue: Union[float, Tuple[float, float]] = 0,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        # Each attribute is either a (min, max) tuple or None when that adjustment is disabled.
        self.brightness = self._check_input(brightness, "brightness")
        self.contrast = self._check_input(contrast, "contrast")
        self.saturation = self._check_input(saturation, "saturation")
        self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    @torch.jit.unused
    def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
        """Validate one jitter argument and normalise it to a ``(min, max)`` tuple of floats.

        A single number ``v`` becomes ``(center - v, center + v)`` (lower end clipped at 0 when
        ``clip_first_on_zero``); a 2-element list/tuple is used as-is. Returns ``None`` when the
        range collapses to ``(center, center)``, i.e. the adjustment would be a no-op.

        Raises:
            ValueError: for a negative single number or a range outside ``bound`` / not ordered.
            TypeError: for any other kind of input.
        """
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError(f"If {name} is a single number, it must be non negative.")
            low = center - float(value)
            high = center + float(value)
            if clip_first_on_zero:
                low = max(low, 0.0)
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            low = float(value[0])
            high = float(value[1])
        else:
            raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")

        interval = [low, high]
        if not bound[0] <= low <= high <= bound[1]:
            raise ValueError(f"{name} values should be between {bound}, but got {interval}.")

        # A degenerate range equal to the neutral value means "do nothing".
        if low == high == center:
            return None
        return tuple(interval)

    @staticmethod
    def get_params(
        brightness: Optional[List[float]],
        contrast: Optional[List[float]],
        saturation: Optional[List[float]],
        hue: Optional[List[float]],
    ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
        """Get the parameters for the randomized transform to be applied on image.

        Args:
            brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
                uniformly. Pass None to turn off the transformation.
            contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
                uniformly. Pass None to turn off the transformation.
            saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
                uniformly. Pass None to turn off the transformation.
            hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
                Pass None to turn off the transformation.

        Returns:
            tuple: The parameters used to apply the randomized transform
            along with their random order.
        """
        # Random application order of the four adjustments (0=brightness, 1=contrast, 2=saturation, 3=hue).
        order = torch.randperm(4)

        # Factors are drawn in a fixed order; disabled adjustments yield None and consume no randomness.
        bright = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
        contr = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
        satur = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
        hue_f = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))

        return order, bright, contr, satur, hue_f

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Input image.

        Returns:
            PIL Image or Tensor: Color jittered image.
        """
        order, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
            self.brightness, self.contrast, self.saturation, self.hue
        )

        # Apply each enabled adjustment in the randomly drawn order.
        for op in order:
            if op == 0 and brightness_factor is not None:
                img = F.adjust_brightness(img, brightness_factor)
            elif op == 1 and contrast_factor is not None:
                img = F.adjust_contrast(img, contrast_factor)
            elif op == 2 and saturation_factor is not None:
                img = F.adjust_saturation(img, saturation_factor)
            elif op == 3 and hue_factor is not None:
                img = F.adjust_hue(img, hue_factor)

        return img

    def __repr__(self) -> str:
        fields = ", ".join(
            [
                f"brightness={self.brightness}",
                f"contrast={self.contrast}",
                f"saturation={self.saturation}",
                f"hue={self.hue}",
            ]
        )
        return f"{self.__class__.__name__}({fields})"


class RandomRotation(torch.nn.Module):
    """Rotate the image by angle.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        degrees (sequence or number): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees).
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
            but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
        expand (bool, optional): Optional expansion flag.
            If true, expands the output to make it large enough to hold the entire rotated image.
            If false or omitted, make the output image the same size as the input image.
            Note that the expand flag assumes rotation around the center and no translation.
        center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
            Default is the center of the image.
        fill (sequence or number): Pixel fill value for the area outside the rotated
            image. Default is ``0``. If given a number, the value is used for all bands respectively.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """

    def __init__(self, degrees, interpolation=InterpolationMode.NEAREST, expand=False, center=None, fill=0):
        super().__init__()
        _log_api_usage_once(self)

        # Integer interpolation codes are a deprecated legacy input; translate them to the enum.
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
                "Please use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)

        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))

        # An explicit rotation center must be an (x, y) pair.
        if center is not None:
            _check_sequence_input(center, "center", req_sizes=(2,))

        self.center = center
        self.interpolation = interpolation
        self.expand = expand

        # ``None`` is treated as a zero fill; anything other than a number or sequence is rejected.
        if fill is None:
            fill = 0
        elif not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")
        self.fill = fill

    @staticmethod
    def get_params(degrees: List[float]) -> float:
        """Get parameters for ``rotate`` for a random rotation.

        Args:
            degrees (list of float): ``[min, max]`` bounds the angle is drawn uniformly from.

        Returns:
            float: angle parameter to be passed to ``rotate`` for random rotation.
        """
        low = float(degrees[0])
        high = float(degrees[1])
        return float(torch.empty(1).uniform_(low, high).item())

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be rotated.

        Returns:
            PIL Image or Tensor: Rotated image.
        """
        fill = self.fill
        channels, _, _ = F.get_dimensions(img)
        if isinstance(img, Tensor):
            # For tensors, expand / convert the fill into a per-channel list of floats.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            else:
                fill = [float(f) for f in fill]

        angle = self.get_params(self.degrees)
        return F.rotate(img, angle, self.interpolation, self.expand, self.center, fill)

    def __repr__(self) -> str:
        pieces = [
            f"degrees={self.degrees}",
            f"interpolation={self.interpolation.value}",
            f"expand={self.expand}",
        ]
        if self.center is not None:
            pieces.append(f"center={self.center}")
        if self.fill is not None:
            pieces.append(f"fill={self.fill}")
        return f"{self.__class__.__name__}({', '.join(pieces)})"


class RandomAffine(torch.nn.Module):
    """Random affine transformation of the image keeping center invariant.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        degrees (sequence or number): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees). Set to 0 to deactivate rotations.
        translate (tuple, optional): tuple of maximum absolute fraction for horizontal
            and vertical translations. For example translate=(a, b), then horizontal shift
            is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
            randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
        scale (tuple, optional): scaling factor interval, e.g. (a, b), then scale is
            randomly sampled from the range a <= scale <= b. Will keep original scale by default.
        shear (sequence or number, optional): Range of degrees to select from.
            If shear is a number, a shear parallel to the x-axis in the range (-shear, +shear)
            will be applied. Else if shear is a sequence of 2 values a shear parallel to the x-axis in the
            range (shear[0], shear[1]) will be applied. Else if shear is a sequence of 4 values,
            an x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
            Will not apply shear by default.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
            but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
        fill (sequence or number): Pixel fill value for the area outside the transformed
            image. Default is ``0``. If given a number, the value is used for all bands respectively.
        center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
            Default is the center of the image.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """

    def __init__(
        self,
        degrees,
        translate=None,
        scale=None,
        shear=None,
        interpolation=InterpolationMode.NEAREST,
        fill=0,
        center=None,
    ):
        super().__init__()
        _log_api_usage_once(self)

        # Backward compatibility with integer value
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
                "Please use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)

        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))

        if translate is not None:
            _check_sequence_input(translate, "translate", req_sizes=(2,))
            for t in translate:
                if not (0.0 <= t <= 1.0):
                    raise ValueError("translation values should be between 0 and 1")
        self.translate = translate

        if scale is not None:
            _check_sequence_input(scale, "scale", req_sizes=(2,))
            for s in scale:
                if s <= 0:
                    raise ValueError("scale values should be positive")
        self.scale = scale

        if shear is not None:
            self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4))
        else:
            self.shear = shear

        self.interpolation = interpolation

        # ``None`` is accepted as an alias for the default fill of 0.
        if fill is None:
            fill = 0
        elif not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")

        self.fill = fill

        if center is not None:
            _check_sequence_input(center, "center", req_sizes=(2,))

        self.center = center

    @staticmethod
    def get_params(
        degrees: List[float],
        translate: Optional[List[float]],
        scale_ranges: Optional[List[float]],
        shears: Optional[List[float]],
        img_size: List[int],
    ) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]:
        """Get parameters for affine transformation

        Args:
            degrees (list of float): ``[min, max]`` rotation range in degrees.
            translate (list of float, optional): maximum absolute translation fractions ``[a, b]``.
            scale_ranges (list of float, optional): ``[min, max]`` scale range.
            shears (list of float, optional): 2 or 4 shear bounds; the last two, if present, bound the y shear.
            img_size (list of int): ``[width, height]`` of the image.

        Returns:
            params to be passed to the affine transformation
        """
        angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
        if translate is not None:
            max_dx = float(translate[0] * img_size[0])
            max_dy = float(translate[1] * img_size[1])
            tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
            ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
            translations = (tx, ty)
        else:
            translations = (0, 0)

        if scale_ranges is not None:
            scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item())
        else:
            scale = 1.0

        shear_x = shear_y = 0.0
        if shears is not None:
            shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item())
            if len(shears) == 4:
                shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item())

        shear = (shear_x, shear_y)

        return angle, translations, scale, shear

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Affine transformed image.
        """
        fill = self.fill
        channels, height, width = F.get_dimensions(img)
        if isinstance(img, Tensor):
            # For tensor inputs, normalize ``fill`` to a list of floats: a scalar is
            # repeated once per channel, a sequence is converted element-wise.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            else:
                fill = [float(f) for f in fill]

        img_size = [width, height]  # flip for keeping BC on get_params call

        ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)

        return F.affine(img, *ret, interpolation=self.interpolation, fill=fill, center=self.center)

    def __repr__(self) -> str:
        # Only non-default settings are listed, besides ``degrees``.
        s = f"{self.__class__.__name__}(degrees={self.degrees}"
        s += f", translate={self.translate}" if self.translate is not None else ""
        s += f", scale={self.scale}" if self.scale is not None else ""
        s += f", shear={self.shear}" if self.shear is not None else ""
        s += f", interpolation={self.interpolation.value}" if self.interpolation != InterpolationMode.NEAREST else ""
        s += f", fill={self.fill}" if self.fill != 0 else ""
        s += f", center={self.center}" if self.center is not None else ""
        s += ")"

        return s
1538
1539


1540
class Grayscale(torch.nn.Module):
    """Convert image to grayscale.
    If the image is torch Tensor, it is expected
    to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        num_output_channels (int): (1 or 3) number of channels desired for output image

    Returns:
        PIL Image or Tensor: Grayscale version of the input.

        - If ``num_output_channels == 1`` : returned image is single channel
        - If ``num_output_channels == 3`` : returned image is 3 channel with r == g == b

    """

    def __init__(self, num_output_channels=1):
        super().__init__()
        _log_api_usage_once(self)
        self.num_output_channels = num_output_channels

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be converted to grayscale.

        Returns:
            PIL Image or Tensor: Grayscaled image.
        """
        return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(num_output_channels={self.num_output_channels})"
1573

1574

1575
class RandomGrayscale(torch.nn.Module):
    """Randomly convert image to grayscale with a probability of p (default 0.1).
    If the image is torch Tensor, it is expected
    to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        p (float): probability that image should be converted to grayscale.

    Returns:
        PIL Image or Tensor: Grayscale version of the input image with probability p and unchanged
        with probability (1-p).
        - If input image is 1 channel: grayscale version is 1 channel
        - If input image is 3 channel: grayscale version is 3 channel with r == g == b

    """

    def __init__(self, p=0.1):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be converted to grayscale.

        Returns:
            PIL Image or Tensor: Randomly grayscaled image.
        """
        # Keep the input's channel count so the output has the same number of channels.
        num_output_channels, _, _ = F.get_dimensions(img)
        if torch.rand(1) < self.p:
            return F.rgb_to_grayscale(img, num_output_channels=num_output_channels)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"
1611
1612


1613
class RandomErasing(torch.nn.Module):
    """Randomly selects a rectangle region in a torch.Tensor image and erases its pixels.
    This transform does not support PIL Image.
    'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896

    Args:
        p: probability that the random erasing operation will be performed.
        scale: range of proportion of erased area against input image.
        ratio: range of aspect ratio of erased area.
        value: erasing value. Default is 0. If a single number, it is used to
            erase all pixels. If a tuple of length 3, it is used to erase
            R, G, B channels respectively.
            If a str of 'random', erasing each pixel with random values.
        inplace: boolean to make this transform inplace. Default set to False.

    Returns:
        Erased Image.

    Example:
        >>> transform = transforms.Compose([
        >>>   transforms.RandomHorizontalFlip(),
        >>>   transforms.PILToTensor(),
        >>>   transforms.ConvertImageDtype(torch.float),
        >>>   transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>>   transforms.RandomErasing(),
        >>> ])
    """

    def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
        super().__init__()
        _log_api_usage_once(self)
        if not isinstance(value, (numbers.Number, str, tuple, list)):
            raise TypeError("Argument value should be either a number or str or a sequence")
        if isinstance(value, str) and value != "random":
            raise ValueError("If value is str, it should be 'random'")
        if not isinstance(scale, (tuple, list)):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, (tuple, list)):
            raise TypeError("Ratio should be a sequence")
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")
        if scale[0] < 0 or scale[1] > 1:
            raise ValueError("Scale should be between 0 and 1")
        if p < 0 or p > 1:
            raise ValueError("Random erasing probability should be between 0 and 1")

        self.p = p
        self.scale = scale
        self.ratio = ratio
        self.value = value
        self.inplace = inplace

    @staticmethod
    def get_params(
        img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None
    ) -> Tuple[int, int, int, int, Tensor]:
        """Get parameters for ``erase`` for a random erasing.

        Args:
            img (Tensor): Tensor image to be erased.
            scale (sequence): range of proportion of erased area against input image.
            ratio (sequence): range of aspect ratio of erased area.
            value (list, optional): erasing value. If None, it is interpreted as "random"
                (erasing each pixel with random values). If ``len(value)`` is 1, it is interpreted as a number,
                i.e. ``value[0]``.

        Returns:
            tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
        """
        img_c, img_h, img_w = img.shape[-3], img.shape[-2], img.shape[-1]
        area = img_h * img_w

        # The aspect ratio is sampled uniformly in log space.
        log_ratio = torch.log(torch.tensor(ratio))
        for _ in range(10):
            erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
            aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()

            h = int(round(math.sqrt(erase_area * aspect_ratio)))
            w = int(round(math.sqrt(erase_area / aspect_ratio)))
            # Reject boxes that do not fit strictly inside the image and resample.
            if not (h < img_h and w < img_w):
                continue

            if value is None:
                v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
            else:
                v = torch.tensor(value)[:, None, None]

            i = torch.randint(0, img_h - h + 1, size=(1,)).item()
            j = torch.randint(0, img_w - w + 1, size=(1,)).item()
            return i, j, h, w, v

        # Return original image
        return 0, 0, img_h, img_w, img

    def forward(self, img):
        """
        Args:
            img (Tensor): Tensor image to be erased.

        Returns:
            img (Tensor): Erased Tensor image.
        """
        if torch.rand(1) < self.p:

            # cast self.value to script acceptable type
            if isinstance(self.value, (int, float)):
                value = [float(self.value)]
            elif isinstance(self.value, str):
                value = None
            elif isinstance(self.value, (list, tuple)):
                value = [float(v) for v in self.value]
            else:
                value = self.value

            if value is not None and not (len(value) in (1, img.shape[-3])):
                raise ValueError(
                    "If value is a sequence, it should have either a single value or "
                    f"{img.shape[-3]} (number of input channels)"
                )

            x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value)
            return F.erase(img, x, y, h, w, v, self.inplace)
        return img

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}"
            f"(p={self.p}, "
            f"scale={self.scale}, "
            f"ratio={self.ratio}, "
            f"value={self.value}, "
            f"inplace={self.inplace})"
        )
        return s
1747

1748

1749
1750
class GaussianBlur(torch.nn.Module):
    """Blurs image with randomly chosen Gaussian blur.
    If the image is torch Tensor, it is expected
    to have [..., C, H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        kernel_size (int or sequence): Size of the Gaussian kernel.
        sigma (float or tuple of float (min, max)): Standard deviation to be used for
            creating kernel to perform blurring. If float, sigma is fixed. If it is tuple
            of float (min, max), sigma is chosen uniformly at random to lie in the
            given range.

    Returns:
        PIL Image or Tensor: Gaussian blurred version of the input image.

    """

    def __init__(self, kernel_size, sigma=(0.1, 2.0)):
        super().__init__()
        _log_api_usage_once(self)
        self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
        for ks in self.kernel_size:
            if ks <= 0 or ks % 2 == 0:
                raise ValueError("Kernel size value should be an odd and positive number.")

        # A scalar sigma is stored as a degenerate (sigma, sigma) range.
        if isinstance(sigma, numbers.Number):
            if sigma <= 0:
                raise ValueError("If sigma is a single number, it must be positive.")
            sigma = (sigma, sigma)
        elif isinstance(sigma, Sequence) and len(sigma) == 2:
            if not 0.0 < sigma[0] <= sigma[1]:
                raise ValueError("sigma values should be positive and of the form (min, max).")
        else:
            raise ValueError("sigma should be a single number or a list/tuple with length 2.")

        self.sigma = sigma

    @staticmethod
    def get_params(sigma_min: float, sigma_max: float) -> float:
        """Choose sigma for random gaussian blurring.

        Args:
            sigma_min (float): Minimum standard deviation that can be chosen for blurring kernel.
            sigma_max (float): Maximum standard deviation that can be chosen for blurring kernel.

        Returns:
            float: Standard deviation to be passed to calculate kernel for gaussian blurring.
        """
        return torch.empty(1).uniform_(sigma_min, sigma_max).item()

    def forward(self, img: Tensor) -> Tensor:
        """
        Args:
            img (PIL Image or Tensor): image to be blurred.

        Returns:
            PIL Image or Tensor: Gaussian blurred image
        """
        sigma = self.get_params(self.sigma[0], self.sigma[1])
        # The same sampled sigma is used for both spatial axes.
        return F.gaussian_blur(img, self.kernel_size, [sigma, sigma])

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}(kernel_size={self.kernel_size}, sigma={self.sigma})"
        return s
1813
1814


1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
def _setup_size(size, error_msg):
    if isinstance(size, numbers.Number):
        return int(size), int(size)

    if isinstance(size, Sequence) and len(size) == 1:
        return size[0], size[0]

    if len(size) != 2:
        raise ValueError(error_msg)

    return size


def _check_sequence_input(x, name, req_sizes):
    msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join([str(s) for s in req_sizes])
    if not isinstance(x, Sequence):
1831
        raise TypeError(f"{name} should be a sequence of length {msg}.")
1832
    if len(x) not in req_sizes:
1833
        raise ValueError(f"{name} should be a sequence of length {msg}.")
1834
1835


1836
def _setup_angle(x, name, req_sizes=(2,)):
1837
1838
    if isinstance(x, numbers.Number):
        if x < 0:
1839
            raise ValueError(f"If {name} is a single number, it must be positive.")
1840
1841
1842
1843
1844
        x = [-x, x]
    else:
        _check_sequence_input(x, name, req_sizes)

    return [float(d) for d in x]
1845
1846
1847
1848


class RandomInvert(torch.nn.Module):
    """Inverts the colors of the given image randomly with a given probability.
    If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
    where ... means it can have an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        p (float): probability of the image being color inverted. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be inverted.

        Returns:
            PIL Image or Tensor: Randomly color inverted image.
        """
        if torch.rand(1).item() < self.p:
            return F.invert(img)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"
1876
1877
1878
1879


class RandomPosterize(torch.nn.Module):
    """Posterize the image randomly with a given probability by reducing the
    number of bits for each color channel. If the image is torch Tensor, it should be of type torch.uint8,
    and it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        bits (int): number of bits to keep for each channel (0-8)
        p (float): probability of the image being posterized. Default value is 0.5
    """

    def __init__(self, bits, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.bits = bits
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be posterized.

        Returns:
            PIL Image or Tensor: Randomly posterized image.
        """
        if torch.rand(1).item() < self.p:
            return F.posterize(img, self.bits)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(bits={self.bits},p={self.p})"
1909
1910
1911
1912


class RandomSolarize(torch.nn.Module):
    """Solarize the image randomly with a given probability by inverting all pixel
    values above a threshold. If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
    where ... means it can have an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        threshold (float): all pixels equal or above this value are inverted.
        p (float): probability of the image being solarized. Default value is 0.5
    """

    def __init__(self, threshold, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.threshold = threshold
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be solarized.

        Returns:
            PIL Image or Tensor: Randomly solarized image.
        """
        if torch.rand(1).item() < self.p:
            return F.solarize(img, self.threshold)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(threshold={self.threshold},p={self.p})"
1942
1943
1944


class RandomAdjustSharpness(torch.nn.Module):
    """Adjust the sharpness of the image randomly with a given probability. If the image is torch Tensor,
    it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        sharpness_factor (float):  How much to adjust the sharpness. Can be
            any non-negative number. 0 gives a blurred image, 1 gives the
            original image while 2 increases the sharpness by a factor of 2.
        p (float): probability of the image being sharpened. Default value is 0.5
    """

    def __init__(self, sharpness_factor, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.sharpness_factor = sharpness_factor
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be sharpened.

        Returns:
            PIL Image or Tensor: Randomly sharpened image.
        """
        if torch.rand(1).item() < self.p:
            return F.adjust_sharpness(img, self.sharpness_factor)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(sharpness_factor={self.sharpness_factor},p={self.p})"
1975
1976
1977
1978


class RandomAutocontrast(torch.nn.Module):
    """Autocontrast the pixels of the given image randomly with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        p (float): probability of the image being autocontrasted. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be autocontrasted.

        Returns:
            PIL Image or Tensor: Randomly autocontrasted image.
        """
        if torch.rand(1).item() < self.p:
            return F.autocontrast(img)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"
2006
2007
2008
2009


class RandomEqualize(torch.nn.Module):
    """Equalize the histogram of the given image randomly with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".

    Args:
        p (float): probability of the image being equalized. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be equalized.

        Returns:
            PIL Image or Tensor: Randomly equalized image.
        """
        if torch.rand(1).item() < self.p:
            return F.equalize(img)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106


class ElasticTransform(torch.nn.Module):
    """Transform a tensor image with elastic transformations.
    Given alpha and sigma, it will generate displacement
    vectors for all pixels based on random offsets. Alpha controls the strength
    and sigma controls the smoothness of the displacements.
    The displacements are added to an identity grid and the resulting grid is
    used to grid_sample from the image.

    Applications:
        Randomly transforms the morphology of objects in images and produces a
        see-through-water-like effect.

    Args:
        alpha (float or sequence of floats): Magnitude of displacements. Default is 50.0.
            A single number is used for both displacement fields; a sequence must hold
            exactly two numbers. Integers are accepted and converted to float.
        sigma (float or sequence of floats): Smoothness of displacements. Default is 5.0.
            Same conventions as ``alpha``. A non-positive entry disables the Gaussian
            smoothing of the corresponding displacement field.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
        fill (sequence or number): Pixel fill value for the area outside the transformed
            image. Default is ``0``. If given a number, the value is used for all bands respectively.

    """

    def __init__(self, alpha=50.0, sigma=5.0, interpolation=InterpolationMode.BILINEAR, fill=0):
        super().__init__()
        _log_api_usage_once(self)

        def is_real(x) -> bool:
            # ``bool`` subclasses ``int`` but is almost certainly a user error here, so reject it.
            return isinstance(x, numbers.Real) and not isinstance(x, bool)

        def as_float_pair(value, name: str) -> List[float]:
            # Validate ``value`` (a number or a length-2 sequence of numbers) and expand it
            # into a two-element list of floats. Shared by alpha and sigma so both get
            # identical checks and error messages.
            if is_real(value):
                return [float(value), float(value)]
            if not isinstance(value, Sequence):
                raise TypeError(f"{name} should be float or a sequence of floats. Got {type(value)}")
            if len(value) != 2:
                raise ValueError(f"If {name} is a sequence its length should be 2. Got {len(value)}")
            for element in value:
                if not is_real(element):
                    raise TypeError(f"{name} should be a sequence of floats. Got {type(element)}")
            # Always return a list (never the caller's tuple) so the stored attribute matches
            # the ``List[float]`` parameters of ``get_params``.
            return [float(element) for element in value]

        self.alpha = as_float_pair(alpha, "alpha")
        self.sigma = as_float_pair(sigma, "sigma")

        # Backward compatibility with integer value
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument interpolation should be of type InterpolationMode instead of int. "
                "Please, use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)
        self.interpolation = interpolation

        # fill is always stored as a list of floats.
        if isinstance(fill, (int, float)):
            fill = [float(fill)]
        elif isinstance(fill, (list, tuple)):
            fill = [float(f) for f in fill]
        else:
            raise TypeError(f"fill should be int or float or a list or tuple of them. Got {type(fill)}")

        self.fill = fill

    @staticmethod
    def get_params(alpha: List[float], sigma: List[float], size: List[int]) -> Tensor:
        """Sample a random, optionally Gaussian-smoothed displacement field.

        Args:
            alpha (list of float): Scale factors for the first (``dx``) and second (``dy``)
                displacement channels.
            sigma (list of float): Smoothing for the ``dx`` / ``dy`` channels; a non-positive
                entry skips the blur for that channel.
            size (list of int): ``[height, width]`` of the image.

        Returns:
            Tensor: displacement field of shape ``[1, H, W, 2]``.
        """
        # Uniform noise in [-1, 1) for the first displacement channel.
        dx = torch.rand([1, 1] + size) * 2 - 1
        if sigma[0] > 0.0:
            kx = int(8 * sigma[0] + 1)
            # if kernel size is even we have to make it odd
            if kx % 2 == 0:
                kx += 1
            # NOTE(review): the kernel size comes from sigma[0] only, but the full ``sigma``
            # list is passed; presumably F.gaussian_blur then uses sigma[1] on the second axis
            # with a kernel sized for sigma[0] when the sigmas differ — confirm this is intended.
            # Left unchanged to preserve existing outputs.
            dx = F.gaussian_blur(dx, [kx, kx], sigma)
        # NOTE(review): size is [height, width]. If the last dim of the returned field is
        # (x, y) as in grid_sample, dx would normally be normalized by the width (size[1]);
        # confirm against F.elastic_transform before changing (doing so alters outputs).
        dx = dx * alpha[0] / size[0]

        dy = torch.rand([1, 1] + size) * 2 - 1
        if sigma[1] > 0.0:
            ky = int(8 * sigma[1] + 1)
            # if kernel size is even we have to make it odd
            if ky % 2 == 0:
                ky += 1
            dy = F.gaussian_blur(dy, [ky, ky], sigma)
        dy = dy * alpha[1] / size[1]
        return torch.concat([dx, dy], 1).permute([0, 2, 3, 1])  # 1 x H x W x 2

    def forward(self, tensor: Tensor) -> Tensor:
        """Apply a freshly sampled elastic displacement to the image.

        Args:
            tensor (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        _, height, width = F.get_dimensions(tensor)
        displacement = self.get_params(self.alpha, self.sigma, [height, width])
        return F.elastic_transform(tensor, displacement, self.interpolation, self.fill)

    def __repr__(self) -> str:
        format_string = self.__class__.__name__
        format_string += f"(alpha={self.alpha}"
        format_string += f", sigma={self.sigma}"
        format_string += f", interpolation={self.interpolation}"
        format_string += f", fill={self.fill})"
        return format_string