import math
import numbers
import random
import warnings
from collections.abc import Sequence
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor

try:
    import accimage
except ImportError:
    accimage = None

from ..utils import _log_api_usage_once
from . import functional as F
from .functional import _interpolation_modes_from_int, InterpolationMode

__all__ = [
    "Compose",
    "ToTensor",
    "PILToTensor",
    "ConvertImageDtype",
    "ToPILImage",
    "Normalize",
    "Resize",
    "CenterCrop",
    "Pad",
    "Lambda",
    "RandomApply",
    "RandomChoice",
    "RandomOrder",
    "RandomCrop",
    "RandomHorizontalFlip",
    "RandomVerticalFlip",
    "RandomResizedCrop",
    "FiveCrop",
    "TenCrop",
    "LinearTransformation",
    "ColorJitter",
    "RandomRotation",
    "RandomAffine",
    "Grayscale",
    "RandomGrayscale",
    "RandomPerspective",
    "RandomErasing",
    "GaussianBlur",
    "InterpolationMode",
    "RandomInvert",
    "RandomPosterize",
    "RandomSolarize",
    "RandomAdjustSharpness",
    "RandomAutocontrast",
    "RandomEqualize",
    "ElasticTransform",
]


class Compose:
    """Composes several transforms together. This transform does not support torchscript.
    Please, see the note below.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>> ])

    .. note::
        In order to script the transformations, please use ``torch.nn.Sequential`` as below.

        >>> transforms = torch.nn.Sequential(
        >>>     transforms.CenterCrop(10),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>> )
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor`` and do not require
        ``lambda`` functions or ``PIL.Image``.

    """

    def __init__(self, transforms):
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            _log_api_usage_once(self)
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += f"    {t}"
        format_string += "\n)"
        return format_string


class ToTensor:
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
    or if the numpy.ndarray has dtype = np.uint8

    In the other cases, tensors are returned without scaling.

    .. note::
        Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
        transforming target image masks. See the `references`_ for implementing the transforms for image masks.

    .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
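
    Example (illustrative sketch; assumes ``pil_img`` is an existing ``PIL.Image``):
        >>> img_t = ToTensor()(pil_img)  # uint8 H x W x C -> float32 C x H x W in [0.0, 1.0]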
    """

    def __init__(self) -> None:
        _log_api_usage_once(self)

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.to_tensor(pic)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


class PILToTensor:
    """Convert a ``PIL Image`` to a tensor of the same type. This transform does not support torchscript.

    Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W).
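
    Example (illustrative sketch; assumes ``pil_img`` is an existing ``PIL.Image``):
        >>> img_t = PILToTensor()(pil_img)  # keeps the PIL dtype (typically uint8), no value rescaling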
    """

    def __init__(self) -> None:
        _log_api_usage_once(self)

    def __call__(self, pic):
        """
        .. note::

            A deep copy of the underlying array is performed.

        Args:
            pic (PIL Image): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.pil_to_tensor(pic)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


class ConvertImageDtype(torch.nn.Module):
    """Convert a tensor image to the given ``dtype`` and scale the values accordingly.
    This function does not support PIL Image.

    Args:
        dtype (torch.dtype): Desired data type of the output

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
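
    Example (illustrative sketch; assumes ``img_u8`` is a ``uint8`` image tensor):
        >>> to_float = ConvertImageDtype(torch.float32)
        >>> img_f = to_float(img_u8)  # values rescaled from [0, 255] to [0.0, 1.0]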
    """

    def __init__(self, dtype: torch.dtype) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.dtype = dtype

    def forward(self, image):
        return F.convert_image_dtype(image, self.dtype)


class ToPILImage:
    """Convert a tensor or an ndarray to PIL Image. This transform does not support torchscript.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL Image while preserving the value range.

    Args:
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
            If ``mode`` is ``None`` (default) there are some assumptions made about the input data:
            - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
            - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
            - If the input has 2 channels, the ``mode`` is assumed to be ``LA``.
            - If the input has 1 channel, the ``mode`` is determined by the data type (i.e. ``int``, ``float``,
              ``short``).

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
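
    Example (illustrative sketch; assumes ``chw_tensor`` is a 3 x H x W float tensor in ``[0, 1]``):
        >>> pil_img = ToPILImage()(chw_tensor)  # mode inferred as "RGB" from the 3 channels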
    """

    def __init__(self, mode=None):
        _log_api_usage_once(self)
        self.mode = mode

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.

        Returns:
            PIL Image: Image converted to PIL Image.

        """
        return F.to_pil_image(pic, self.mode)

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        if self.mode is not None:
            format_string += f"mode={self.mode}"
        format_string += ")"
        return format_string


class Normalize(torch.nn.Module):
    """Normalize a tensor image with mean and standard deviation.
    This transform does not support PIL Image.
    Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
    channels, this transform will normalize each channel of the input
    ``torch.*Tensor`` i.e.,
    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``

    .. note::
        This transform acts out of place, i.e., it does not mutate the input tensor.

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace (bool, optional): Bool to make this operation in-place.
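
    Example (illustrative sketch; assumes ``img`` is a float image tensor, e.g. produced by ``ToTensor``):
        >>> normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        >>> out = normalize(img)  # per channel: (img[c] - mean[c]) / std[c]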

    """

    def __init__(self, mean, std, inplace=False):
        super().__init__()
        _log_api_usage_once(self)
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Args:
            tensor (Tensor): Tensor image to be normalized.

        Returns:
            Tensor: Normalized Tensor image.
        """
        return F.normalize(tensor, self.mean, self.std, self.inplace)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"


class Resize(torch.nn.Module):
    """Resize the input image to the given size.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    .. warning::
        The output image might be different depending on its type: when downsampling, the interpolation of PIL images
        and tensors is slightly different, because PIL applies antialiasing. This may lead to significant differences
        in the performance of a network. Therefore, it is preferable to train and serve a model with the same input
        types. See also below the ``antialias`` parameter, which can help make the output of PIL images and tensors
        closer.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size).

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.

        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        max_size (int, optional): The maximum allowed for the longer edge of
            the resized image: if the longer edge of the image is greater
            than ``max_size`` after being resized according to ``size``, then
            the image is resized again so that the longer edge is equal to
            ``max_size``. As a result, ``size`` might be overruled, i.e. the
            smaller edge may be shorter than ``size``. This is only supported
            if ``size`` is an int (or a sequence of length 1 in torchscript
            mode).
        antialias (bool, optional): Whether to apply antialiasing.
            It only affects **tensors** with bilinear or bicubic modes and it is
            ignored otherwise: on PIL images, antialiasing is always applied on
            bilinear or bicubic modes; on other modes (for PIL images and
            tensors), antialiasing makes no sense and this parameter is ignored.
            Possible values are:

            - ``True``: will apply antialiasing for bilinear or bicubic modes.
              Other modes aren't affected. This is probably what you want to use.
            - ``False``: will not apply antialiasing for tensors on any mode. PIL
              images are still antialiased on bilinear or bicubic modes, because
              antialiasing can't be turned off for PIL images.
            - ``None``: equivalent to ``False`` for tensors and ``True`` for
              PIL images. This value exists for legacy reasons and you probably
              don't want to use it unless you really know what you are doing.

            The current default is ``None`` **but will change to** ``True`` **in
            v0.17** for the PIL and Tensor backends to be consistent.
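
    Example (illustrative sketch; assumes ``img`` is a PIL image or an image tensor):
        >>> resize = Resize(256, antialias=True)  # shorter edge becomes 256, aspect ratio preserved
        >>> out = resize(img)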
    """

    def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias="warn"):
        super().__init__()
        _log_api_usage_once(self)
        if not isinstance(size, (int, Sequence)):
            raise TypeError(f"Size should be int or sequence. Got {type(size)}")
        if isinstance(size, Sequence) and len(size) not in (1, 2):
            raise ValueError("If size is a sequence, it should have 1 or 2 values")
        self.size = size
        self.max_size = max_size

        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)

        self.interpolation = interpolation
        self.antialias = antialias

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be scaled.

        Returns:
            PIL Image or Tensor: Rescaled image.
        """
        return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias)

    def __repr__(self) -> str:
        detail = f"(size={self.size}, interpolation={self.interpolation.value}, max_size={self.max_size}, antialias={self.antialias})"
        return f"{self.__class__.__name__}{detail}"


class CenterCrop(torch.nn.Module):
    """Crops the given image at the center.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
    If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
    """

    def __init__(self, size):
        super().__init__()
        _log_api_usage_once(self)
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            PIL Image or Tensor: Cropped image.
        """
        return F.center_crop(img, self.size)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size})"


class Pad(torch.nn.Module):
    """Pad the given image on all sides with the given "pad" value.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric,
    at most 3 leading dimensions for mode edge,
    and an arbitrary number of leading dimensions for mode constant.

    Args:
        padding (int or sequence): Padding on each border. If a single int is provided this
            is used to pad all borders. If sequence of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a sequence of length 4 is provided
            this is the padding for the left, top, right and bottom borders respectively.

            .. note::
                In torchscript mode padding as single int is not supported, use a sequence of
                length 1: ``[padding, ]``.
        fill (number or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
            Only number is supported for torch Tensor.
            Only int or tuple value is supported for PIL Image.
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image.
              If the input is a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2

            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]
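
    Example (illustrative sketch; assumes ``img`` is a PIL image or an image tensor):
        >>> pad = Pad(padding=[4, 8], fill=0, padding_mode="constant")  # 4 px left/right, 8 px top/bottom
        >>> out = pad(img)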
    """

    def __init__(self, padding, fill=0, padding_mode="constant"):
        super().__init__()
        _log_api_usage_once(self)
        if not isinstance(padding, (numbers.Number, tuple, list)):
            raise TypeError("Got inappropriate padding arg")

        if not isinstance(fill, (numbers.Number, tuple, list)):
            raise TypeError("Got inappropriate fill arg")

        if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
            raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

        if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]:
            raise ValueError(
                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
            )

        self.padding = padding
        self.fill = fill
        self.padding_mode = padding_mode

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be padded.

        Returns:
            PIL Image or Tensor: Padded image.
        """
        return F.pad(img, self.padding, self.fill, self.padding_mode)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(padding={self.padding}, fill={self.fill}, padding_mode={self.padding_mode})"


class Lambda:
    """Apply a user-defined lambda as a transform. This transform does not support torchscript.

    Args:
        lambd (function): Lambda/function to be used for transform.
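
    Example (illustrative sketch; the lambda here is arbitrary):
        >>> add_ten = Lambda(lambda x: x + 10)
        >>> out = add_ten(5)  # any callable works; here the "transform" simply adds 10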
    """

    def __init__(self, lambd):
        _log_api_usage_once(self)
        if not callable(lambd):
            raise TypeError(f"Argument lambd should be callable, got {repr(type(lambd).__name__)}")
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


class RandomTransforms:
    """Base class for a list of transformations with randomness

    Args:
        transforms (sequence): list of transformations
    """

    def __init__(self, transforms):
        _log_api_usage_once(self)
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence")
        self.transforms = transforms

    def __call__(self, *args, **kwargs):
        raise NotImplementedError()

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += f"    {t}"
        format_string += "\n)"
        return format_string


class RandomApply(torch.nn.Module):
    """Apply randomly a list of transformations with a given probability.

    .. note::
        In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of
        transforms as shown below:

        >>> transforms = transforms.RandomApply(torch.nn.ModuleList([
        >>>     transforms.ColorJitter(),
        >>> ]), p=0.3)
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor`` and do not require
        ``lambda`` functions or ``PIL.Image``.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
        p (float): probability
    """

    def __init__(self, transforms, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.transforms = transforms
        self.p = p

    def forward(self, img):
        if self.p < torch.rand(1):
            return img
        for t in self.transforms:
            img = t(img)
        return img

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        format_string += f"\n    p={self.p}"
        for t in self.transforms:
            format_string += "\n"
            format_string += f"    {t}"
        format_string += "\n)"
        return format_string


class RandomOrder(RandomTransforms):
    """Apply a list of transformations in a random order. This transform does not support torchscript."""

    def __call__(self, img):
        order = list(range(len(self.transforms)))
        random.shuffle(order)
        for i in order:
            img = self.transforms[i](img)
        return img


class RandomChoice(RandomTransforms):
    """Apply single transformation randomly picked from a list. This transform does not support torchscript."""

    def __init__(self, transforms, p=None):
        super().__init__(transforms)
        if p is not None and not isinstance(p, Sequence):
            raise TypeError("Argument p should be a sequence")
        self.p = p

    def __call__(self, *args):
        t = random.choices(self.transforms, weights=self.p)[0]
        return t(*args)

    def __repr__(self) -> str:
        return f"{super().__repr__()}(p={self.p})"


class RandomCrop(torch.nn.Module):
    """Crop the given image at a random location.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions,
    but if non-constant padding is used, the input is expected to have at most 2 leading dimensions

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is None. If a single int is provided this
            is used to pad all borders. If sequence of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a sequence of length 4 is provided
            this is the padding for the left, top, right and bottom borders respectively.

            .. note::
                In torchscript mode padding as single int is not supported, use a sequence of
                length 1: ``[padding, ]``.
        pad_if_needed (boolean): It will pad the image if smaller than the
            desired size to avoid raising an exception. Since cropping is done
            after padding, the padding seems to be done at a random offset.
        fill (number or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
            Only number is supported for torch Tensor.
            Only int or tuple value is supported for PIL Image.
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image.
              If the input is a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2

            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]
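
    Example (illustrative sketch; assumes ``img`` is a PIL image or an image tensor):
        >>> crop = RandomCrop(32, padding=4, padding_mode="reflect")
        >>> out = crop(img)  # pads by 4 on every side, then takes a random 32x32 crop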
    """

    @staticmethod
    def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image or Tensor): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        _, h, w = F.get_dimensions(img)
        th, tw = output_size

        if h < th or w < tw:
            raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")

        if w == tw and h == th:
            return 0, 0, h, w

        i = torch.randint(0, h - th + 1, size=(1,)).item()
        j = torch.randint(0, w - tw + 1, size=(1,)).item()
        return i, j, th, tw

    def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
        super().__init__()
        _log_api_usage_once(self)

        self.size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))

        self.padding = padding
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        self.padding_mode = padding_mode

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            PIL Image or Tensor: Cropped image.
        """
        if self.padding is not None:
            img = F.pad(img, self.padding, self.fill, self.padding_mode)

        _, height, width = F.get_dimensions(img)
        # pad the width if needed
        if self.pad_if_needed and width < self.size[1]:
            padding = [self.size[1] - width, 0]
            img = F.pad(img, padding, self.fill, self.padding_mode)
        # pad the height if needed
        if self.pad_if_needed and height < self.size[0]:
            padding = [0, self.size[0] - height]
            img = F.pad(img, padding, self.fill, self.padding_mode)

        i, j, h, w = self.get_params(img, self.size)

        return F.crop(img, i, j, h, w)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, padding={self.padding})"


class RandomHorizontalFlip(torch.nn.Module):
    """Horizontally flip the given image randomly with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
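
    Example (illustrative sketch; assumes ``img`` is a PIL image or an image tensor):
        >>> flip = RandomHorizontalFlip(p=0.5)
        >>> out = flip(img)  # flipped left-right with 50% probability, otherwise returned unchanged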
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be flipped.

        Returns:
            PIL Image or Tensor: Randomly flipped image.
        """
        if torch.rand(1) < self.p:
            return F.hflip(img)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"


class RandomVerticalFlip(torch.nn.Module):
    """Vertically flip the given image randomly with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be flipped.

        Returns:
            PIL Image or Tensor: Randomly flipped image.
        """
        if torch.rand(1) < self.p:
            return F.vflip(img)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"


class RandomPerspective(torch.nn.Module):
    """Performs a random perspective transformation of the given image with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.
            Default is 0.5.
        p (float): probability of the image being transformed. Default is 0.5.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        fill (sequence or number): Pixel fill value for the area outside the transformed
            image. Default is ``0``. If given a number, the value is used for all bands respectively.
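
    Example (illustrative sketch; assumes ``img`` is a PIL image or an image tensor):
        >>> rp = RandomPerspective(distortion_scale=0.6, p=1.0)
        >>> out = rp(img)  # each corner is displaced by up to 60% of half the image size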
    """

    def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationMode.BILINEAR, fill=0):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)

        self.interpolation = interpolation
        self.distortion_scale = distortion_scale

        if fill is None:
            fill = 0
        elif not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")

        self.fill = fill

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be Perspectively transformed.

        Returns:
            PIL Image or Tensor: Randomly transformed image.
        """

        fill = self.fill
        channels, height, width = F.get_dimensions(img)
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            else:
                fill = [float(f) for f in fill]

        if torch.rand(1) < self.p:
            startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
            return F.perspective(img, startpoints, endpoints, self.interpolation, fill)
        return img

    @staticmethod
    def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[List[int]], List[List[int]]]:
        """Get parameters for ``perspective`` for a random perspective transform.

        Args:
            width (int): width of the image.
            height (int): height of the image.
            distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.

        Returns:
            List containing [top-left, top-right, bottom-right, bottom-left] of the original image,
            List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image.
        """
        half_height = height // 2
        half_width = width // 2
        topleft = [
            int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
            int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
        ]
        topright = [
            int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
            int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
        ]
        botright = [
            int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
            int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
        ]
        botleft = [
            int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
            int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
        ]
        startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
        endpoints = [topleft, topright, botright, botleft]
        return startpoints, endpoints

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"


class RandomResizedCrop(torch.nn.Module):
    """Crop a random portion of image and resize it to a given size.

    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    A crop of the original image is made: the crop has a random area (H * W)
    and a random aspect ratio. This crop is finally resized to the given
    size. This is popularly used to train the Inception networks.

    Args:
        size (int or sequence): expected output size of the crop, for each edge. If size is an
            int instead of sequence like (h, w), a square output size ``(size, size)`` is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        scale (tuple of float): Specifies the lower and upper bounds for the random area of the crop,
            before resizing. The scale is defined with respect to the area of the original image.
        ratio (tuple of float): lower and upper bounds for the random aspect ratio of the crop, before
            resizing.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        antialias (bool, optional): Whether to apply antialiasing.
            It only affects **tensors** with bilinear or bicubic modes and it is
            ignored otherwise: on PIL images, antialiasing is always applied on
            bilinear or bicubic modes; on other modes (for PIL images and
            tensors), antialiasing makes no sense and this parameter is ignored.
            Possible values are:

            - ``True``: will apply antialiasing for bilinear or bicubic modes.
              Other modes aren't affected. This is probably what you want to use.
            - ``False``: will not apply antialiasing for tensors on any mode. PIL
              images are still antialiased on bilinear or bicubic modes, because
              antialiasing can't be turned off for PIL images.
            - ``None``: equivalent to ``False`` for tensors and ``True`` for
              PIL images. This value exists for legacy reasons and you probably
              don't want to use it unless you really know what you are doing.

            The current default is ``None`` **but will change to** ``True`` **in
            v0.17** for the PIL and Tensor backends to be consistent.
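
    Example (illustrative sketch; assumes ``img`` is a PIL image or an image tensor):
        >>> rrc = RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3), antialias=True)
        >>> out = rrc(img)  # random crop covering 8%-100% of the area, resized to 224x224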
    """

    def __init__(
        self,
        size,
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
        interpolation=InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
    ):
        super().__init__()
        _log_api_usage_once(self)
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

        if not isinstance(scale, Sequence):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, Sequence):
            raise TypeError("Ratio should be a sequence")
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")

        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)

        self.interpolation = interpolation
        self.antialias = antialias
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img: Tensor, scale: List[float], ratio: List[float]) -> Tuple[int, int, int, int]:
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image or Tensor): Input image.
            scale (list): range of scale of the origin size cropped
            ratio (list): range of aspect ratio of the origin aspect ratio cropped

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
            sized crop.
        """
        _, height, width = F.get_dimensions(img)
        area = height * width

        log_ratio = torch.log(torch.tensor(ratio))
        for _ in range(10):
            target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
            aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if 0 < w <= width and 0 < h <= height:
                i = torch.randint(0, height - h + 1, size=(1,)).item()
                j = torch.randint(0, width - w + 1, size=(1,)).item()
                return i, j, h, w

        # Fallback to central crop
        in_ratio = float(width) / float(height)
        if in_ratio < min(ratio):
            w = width
            h = int(round(w / min(ratio)))
        elif in_ratio > max(ratio):
            h = height
            w = int(round(h * max(ratio)))
        else:  # whole image
            w = width
            h = height
        i = (height - h) // 2
        j = (width - w) // 2
        return i, j, h, w

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped and resized.

        Returns:
            PIL Image or Tensor: Randomly cropped and resized image.
        """
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        return F.resized_crop(img, i, j, h, w, self.size, self.interpolation, antialias=self.antialias)

    def __repr__(self) -> str:
        interpolate_str = self.interpolation.value
        format_string = self.__class__.__name__ + f"(size={self.size}"
        format_string += f", scale={tuple(round(s, 4) for s in self.scale)}"
        format_string += f", ratio={tuple(round(r, 4) for r in self.ratio)}"
        format_string += f", interpolation={interpolate_str}"
        format_string += f", antialias={self.antialias})"
        return format_string


class FiveCrop(torch.nn.Module):
    """Crop the given image into four corners and the central crop.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
         inputs and targets your Dataset returns. See below for an example of how to deal with
         this.

    Args:
         size (sequence or int): Desired output size of the crop. If size is an ``int``
            instead of sequence like (h, w), a square crop of size (size, size) is made.
            If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).

    Example:
         >>> transform = Compose([
         >>>    FiveCrop(size), # this is a list of PIL Images
         >>>    Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor
         >>> ])
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """

    def __init__(self, size):
        super().__init__()
        _log_api_usage_once(self)
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            tuple of 5 images. Image can be PIL Image or Tensor
        """
        return F.five_crop(img, self.size)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size})"


class TenCrop(torch.nn.Module):
    """Crop the given image into four corners and the central crop plus the flipped version of
    these (horizontal flipping is used by default).
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
         inputs and targets your Dataset returns. See below for an example of how to deal with
         this.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
        vertical_flip (bool): Use vertical flipping instead of horizontal

    Example:
         >>> transform = Compose([
         >>>    TenCrop(size), # this is a tuple of PIL Images
         >>>    Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor
         >>> ])
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """

    def __init__(self, size, vertical_flip=False):
        super().__init__()
        _log_api_usage_once(self)
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
        self.vertical_flip = vertical_flip

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            tuple of 10 images. Image can be PIL Image or Tensor
        """
        return F.ten_crop(img, self.size, self.vertical_flip)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, vertical_flip={self.vertical_flip})"


class LinearTransformation(torch.nn.Module):
    """Transform a tensor image with a square transformation matrix and a mean_vector computed
    offline.
    This transform does not support PIL Image.
    Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and
    subtract mean_vector from it which is then followed by computing the dot
    product with the transformation matrix and then reshaping the tensor to its
    original shape.

    Applications:
        whitening transformation: Suppose X is a column vector of zero-centered data.
        Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X),
        perform SVD on this matrix and pass it as transformation_matrix.

    Args:
        transformation_matrix (Tensor): tensor [D x D], D = C x H x W
        mean_vector (Tensor): tensor [D], D = C x H x W
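
    Example (illustrative sketch with an identity transform; here ``D = 3 * 8 * 8``):
        >>> D = 3 * 8 * 8
        >>> t = LinearTransformation(torch.eye(D), torch.zeros(D))
        >>> out = t(torch.rand(3, 8, 8))  # flattened, mean-subtracted, matrix-multiplied, reshaped back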
    """

    def __init__(self, transformation_matrix, mean_vector):
        super().__init__()
        _log_api_usage_once(self)
        if transformation_matrix.size(0) != transformation_matrix.size(1):
            raise ValueError(
                "transformation_matrix should be square. Got "
                f"{tuple(transformation_matrix.size())} rectangular matrix."
            )

        if mean_vector.size(0) != transformation_matrix.size(0):
            raise ValueError(
                f"mean_vector should have the same length {mean_vector.size(0)}"
                f" as any one of the dimensions of the transformation_matrix [{tuple(transformation_matrix.size())}]"
            )

        if transformation_matrix.device != mean_vector.device:
            raise ValueError(
                f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
            )

        if transformation_matrix.dtype != mean_vector.dtype:
            raise ValueError(
                f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}"
            )

        self.transformation_matrix = transformation_matrix
        self.mean_vector = mean_vector

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Args:
            tensor (Tensor): Tensor image to be whitened.

        Returns:
            Tensor: Transformed image.
        """
        shape = tensor.shape
        n = shape[-3] * shape[-2] * shape[-1]
        if n != self.transformation_matrix.shape[0]:
            raise ValueError(
                "Input tensor and transformation matrix have incompatible shape."
                + f"[{shape[-3]} x {shape[-2]} x {shape[-1]}] != "
                + f"{self.transformation_matrix.shape[0]}"
            )

        if tensor.device.type != self.mean_vector.device.type:
            raise ValueError(
                "Input tensor should be on the same device as transformation matrix and mean vector. "
                f"Got {tensor.device} vs {self.mean_vector.device}"
            )

        flat_tensor = tensor.view(-1, n) - self.mean_vector
        transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype)
        transformed_tensor = torch.mm(flat_tensor, transformation_matrix)
        tensor = transformed_tensor.view(shape)
        return tensor

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}(transformation_matrix="
            f"{self.transformation_matrix.tolist()}"
            f", mean_vector={self.mean_vector.tolist()})"
        )
        return s


class ColorJitter(torch.nn.Module):
    """Randomly change the brightness, contrast, saturation and hue of an image.
    If the image is torch Tensor, it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.

    Args:
        brightness (float or tuple of float (min, max)): How much to jitter brightness.
            brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
            or the given [min, max]. Should be non negative numbers.
        contrast (float or tuple of float (min, max)): How much to jitter contrast.
            contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
            or the given [min, max]. Should be non-negative numbers.
        saturation (float or tuple of float (min, max)): How much to jitter saturation.
            saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
            or the given [min, max]. Should be non negative numbers.
        hue (float or tuple of float (min, max)): How much to jitter hue.
            hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
            Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
            To jitter hue, the pixel values of the input image has to be non-negative for conversion to HSV space;
            thus it does not work if you normalize your image to an interval with negative values,
            or use an interpolation that generates negative values before using this function.
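
    Example (illustrative sketch; assumes ``img`` is a PIL image or an image tensor):
        >>> jitter = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        >>> out = jitter(img)  # factors drawn uniformly from the configured ranges, applied in random order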
    """

    def __init__(
        self,
        brightness: Union[float, Tuple[float, float]] = 0,
        contrast: Union[float, Tuple[float, float]] = 0,
        saturation: Union[float, Tuple[float, float]] = 0,
        hue: Union[float, Tuple[float, float]] = 0,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.brightness = self._check_input(brightness, "brightness")
        self.contrast = self._check_input(contrast, "contrast")
        self.saturation = self._check_input(saturation, "saturation")
        self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    @torch.jit.unused
    def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
        if isinstance(value, numbers.Number):
            if value < 0:
1209
                raise ValueError(f"If {name} is a single number, it must be non negative.")
            value = [center - float(value), center + float(value)]
            if clip_first_on_zero:
                value[0] = max(value[0], 0.0)
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            value = [float(value[0]), float(value[1])]
        else:
            raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")

        if not bound[0] <= value[0] <= value[1] <= bound[1]:
            raise ValueError(f"{name} values should be between {bound}, but got {value}.")

        # if value is 0 or (1., 1.) for brightness/contrast/saturation
        # or (0., 0.) for hue, do nothing
        if value[0] == value[1] == center:
            return None
        else:
            return tuple(value)

    @staticmethod
    def get_params(
        brightness: Optional[List[float]],
        contrast: Optional[List[float]],
        saturation: Optional[List[float]],
        hue: Optional[List[float]],
    ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
        """Get the parameters for the randomized transform to be applied on image.

        Args:
            brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
                uniformly. Pass None to turn off the transformation.
            contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
                uniformly. Pass None to turn off the transformation.
            saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
                uniformly. Pass None to turn off the transformation.
            hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
                Pass None to turn off the transformation.

        Returns:
            tuple: The parameters used to apply the randomized transform
            along with their random order.
        """
        fn_idx = torch.randperm(4)

        b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
        c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
        s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
        h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))

        return fn_idx, b, c, s, h

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Input image.

        Returns:
            PIL Image or Tensor: Color jittered image.
        """
        fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
            self.brightness, self.contrast, self.saturation, self.hue
        )

        for fn_id in fn_idx:
            if fn_id == 0 and brightness_factor is not None:
                img = F.adjust_brightness(img, brightness_factor)
            elif fn_id == 1 and contrast_factor is not None:
                img = F.adjust_contrast(img, contrast_factor)
            elif fn_id == 2 and saturation_factor is not None:
                img = F.adjust_saturation(img, saturation_factor)
            elif fn_id == 3 and hue_factor is not None:
                img = F.adjust_hue(img, hue_factor)

        return img

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"brightness={self.brightness}"
            f", contrast={self.contrast}"
            f", saturation={self.saturation}"
            f", hue={self.hue})"
        )
        return s


class RandomRotation(torch.nn.Module):
    """Rotate the image by angle.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        degrees (sequence or number): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees).
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        expand (bool, optional): Optional expansion flag.
            If true, expands the output to make it large enough to hold the entire rotated image.
            If false or omitted, make the output image the same size as the input image.
            Note that the expand flag assumes rotation around the center and no translation.
        center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
            Default is the center of the image.
        fill (sequence or number): Pixel fill value for the area outside the rotated
            image. Default is ``0``. If given a number, the value is used for all bands respectively.
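
    Example:
        >>> # Illustrative sketch; assumes ``img`` is a PIL Image or tensor and ``transforms``
        >>> # refers to ``torchvision.transforms``. ``expand=True`` grows the canvas so the
        >>> # rotated corners are not cut off.
        >>> rotater = transforms.RandomRotation(degrees=(0, 180), expand=True)
        >>> rotated_img = rotater(img)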

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """

    def __init__(self, degrees, interpolation=InterpolationMode.NEAREST, expand=False, center=None, fill=0):
        super().__init__()
        _log_api_usage_once(self)

        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)

        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))

        if center is not None:
            _check_sequence_input(center, "center", req_sizes=(2,))

        self.center = center

        self.interpolation = interpolation
        self.expand = expand

        if fill is None:
            fill = 0
        elif not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")

        self.fill = fill

    @staticmethod
    def get_params(degrees: List[float]) -> float:
        """Get parameters for ``rotate`` for a random rotation.

        Returns:
            float: angle parameter to be passed to ``rotate`` for random rotation.
        """
        angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
        return angle

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be rotated.

        Returns:
            PIL Image or Tensor: Rotated image.
        """
        fill = self.fill
        channels, _, _ = F.get_dimensions(img)
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            else:
                fill = [float(f) for f in fill]
        angle = self.get_params(self.degrees)

        return F.rotate(img, angle, self.interpolation, self.expand, self.center, fill)

    def __repr__(self) -> str:
        interpolate_str = self.interpolation.value
        format_string = self.__class__.__name__ + f"(degrees={self.degrees}"
        format_string += f", interpolation={interpolate_str}"
        format_string += f", expand={self.expand}"
        if self.center is not None:
            format_string += f", center={self.center}"
        if self.fill is not None:
            format_string += f", fill={self.fill}"
        format_string += ")"
        return format_string


class RandomAffine(torch.nn.Module):
    """Random affine transformation of the image keeping center invariant.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        degrees (sequence or number): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees). Set to 0 to deactivate rotations.
        translate (tuple, optional): tuple of maximum absolute fraction for horizontal
            and vertical translations. For example translate=(a, b), then horizontal shift
            is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
            randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
        scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
            randomly sampled from the range a <= scale <= b. Will keep original scale by default.
        shear (sequence or number, optional): Range of degrees to select from.
            If shear is a number, a shear parallel to the x-axis in the range (-shear, +shear)
            will be applied. Else if shear is a sequence of 2 values a shear parallel to the x-axis in the
            range (shear[0], shear[1]) will be applied. Else if shear is a sequence of 4 values,
            an x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
            Will not apply shear by default.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        fill (sequence or number): Pixel fill value for the area outside the transformed
            image. Default is ``0``. If given a number, the value is used for all bands respectively.
        center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
            Default is the center of the image.
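
    Example:
        >>> # Illustrative sketch; assumes ``img`` is a PIL Image or tensor and ``transforms``
        >>> # refers to ``torchvision.transforms``.
        >>> affine = transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.8, 1.2), shear=10)
        >>> transformed_img = affine(img)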

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """

    def __init__(
        self,
        degrees,
        translate=None,
        scale=None,
        shear=None,
        interpolation=InterpolationMode.NEAREST,
        fill=0,
        center=None,
    ):
        super().__init__()
        _log_api_usage_once(self)

        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)

        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))

        if translate is not None:
            _check_sequence_input(translate, "translate", req_sizes=(2,))
            for t in translate:
                if not (0.0 <= t <= 1.0):
                    raise ValueError("translation values should be between 0 and 1")
        self.translate = translate

        if scale is not None:
            _check_sequence_input(scale, "scale", req_sizes=(2,))
            for s in scale:
                if s <= 0:
                    raise ValueError("scale values should be positive")
        self.scale = scale

        if shear is not None:
            self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4))
        else:
            self.shear = shear

        self.interpolation = interpolation

        if fill is None:
            fill = 0
        elif not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")

        self.fill = fill

        if center is not None:
            _check_sequence_input(center, "center", req_sizes=(2,))

        self.center = center

    @staticmethod
    def get_params(
        degrees: List[float],
        translate: Optional[List[float]],
        scale_ranges: Optional[List[float]],
        shears: Optional[List[float]],
        img_size: List[int],
    ) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]:
        """Get parameters for affine transformation

        Returns:
            params to be passed to the affine transformation
        """
        angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
        if translate is not None:
            max_dx = float(translate[0] * img_size[0])
            max_dy = float(translate[1] * img_size[1])
            tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
            ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
            translations = (tx, ty)
        else:
            translations = (0, 0)

        if scale_ranges is not None:
            scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item())
        else:
            scale = 1.0

        shear_x = shear_y = 0.0
        if shears is not None:
            shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item())
            if len(shears) == 4:
                shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item())

        shear = (shear_x, shear_y)

        return angle, translations, scale, shear

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Affine transformed image.
        """
        fill = self.fill
        channels, height, width = F.get_dimensions(img)
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            else:
                fill = [float(f) for f in fill]

        img_size = [width, height]  # flip for keeping BC on get_params call

        ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)

        return F.affine(img, *ret, interpolation=self.interpolation, fill=fill, center=self.center)

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}(degrees={self.degrees}"
        s += f", translate={self.translate}" if self.translate is not None else ""
        s += f", scale={self.scale}" if self.scale is not None else ""
        s += f", shear={self.shear}" if self.shear is not None else ""
        s += f", interpolation={self.interpolation.value}" if self.interpolation != InterpolationMode.NEAREST else ""
        s += f", fill={self.fill}" if self.fill != 0 else ""
        s += f", center={self.center}" if self.center is not None else ""
        s += ")"

        return s


class Grayscale(torch.nn.Module):
    """Convert image to grayscale.
    If the image is torch Tensor, it is expected
    to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        num_output_channels (int): (1 or 3) number of channels desired for output image

    Returns:
        PIL Image or Tensor: Grayscale version of the input.

        - If ``num_output_channels == 1`` : returned image is single channel
        - If ``num_output_channels == 3`` : returned image is 3 channel with r == g == b
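
    Example:
        >>> # Illustrative sketch; assumes ``img`` is an RGB PIL Image or a [..., 3, H, W] tensor.
        >>> # num_output_channels=3 keeps the channel count expected by models trained on RGB inputs.
        >>> gray = transforms.Grayscale(num_output_channels=3)
        >>> gray_img = gray(img)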

    """

    def __init__(self, num_output_channels=1):
        super().__init__()
        _log_api_usage_once(self)
        self.num_output_channels = num_output_channels

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be converted to grayscale.

        Returns:
            PIL Image or Tensor: Grayscaled image.
        """
        return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(num_output_channels={self.num_output_channels})"


class RandomGrayscale(torch.nn.Module):
    """Randomly convert image to grayscale with a probability of p (default 0.1).
    If the image is torch Tensor, it is expected
    to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        p (float): probability that image should be converted to grayscale.

    Returns:
        PIL Image or Tensor: Grayscale version of the input image with probability p and unchanged
        with probability (1-p).
        - If input image is 1 channel: grayscale version is 1 channel
        - If input image is 3 channel: grayscale version is 3 channel with r == g == b
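
    Example:
        >>> # Illustrative sketch; assumes ``img`` is an RGB PIL Image or a [..., 3, H, W] tensor.
        >>> maybe_gray = transforms.RandomGrayscale(p=0.2)
        >>> out_img = maybe_gray(img)  # grayscale with probability 0.2, unchanged otherwise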

    """

    def __init__(self, p=0.1):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be converted to grayscale.

        Returns:
            PIL Image or Tensor: Randomly grayscaled image.
        """
        num_output_channels, _, _ = F.get_dimensions(img)
        if torch.rand(1) < self.p:
            return F.rgb_to_grayscale(img, num_output_channels=num_output_channels)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"


class RandomErasing(torch.nn.Module):
    """Randomly selects a rectangle region in a torch.Tensor image and erases its pixels.
    This transform does not support PIL Image.
    'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896

    Args:
         p: probability that the random erasing operation will be performed.
         scale: range of proportion of erased area against input image.
         ratio: range of aspect ratio of erased area.
         value: erasing value. Default is 0. If a single int, it is used to
            erase all pixels. If a tuple of length 3, it is used to erase
            R, G, B channels respectively.
            If a str of 'random', erasing each pixel with random values.
         inplace: boolean to make this transform inplace. Default set to False.

    Returns:
        Erased Image.

    Example:
        >>> transform = transforms.Compose([
        >>>   transforms.RandomHorizontalFlip(),
        >>>   transforms.PILToTensor(),
        >>>   transforms.ConvertImageDtype(torch.float),
        >>>   transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>>   transforms.RandomErasing(),
        >>> ])
    """

    def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
        super().__init__()
        _log_api_usage_once(self)
        if not isinstance(value, (numbers.Number, str, tuple, list)):
            raise TypeError("Argument value should be either a number or str or a sequence")
        if isinstance(value, str) and value != "random":
            raise ValueError("If value is str, it should be 'random'")
        if not isinstance(scale, (tuple, list)):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, (tuple, list)):
            raise TypeError("Ratio should be a sequence")
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")
        if scale[0] < 0 or scale[1] > 1:
            raise ValueError("Scale should be between 0 and 1")
        if p < 0 or p > 1:
            raise ValueError("Random erasing probability should be between 0 and 1")

        self.p = p
        self.scale = scale
        self.ratio = ratio
        self.value = value
        self.inplace = inplace

    @staticmethod
    def get_params(
        img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None
    ) -> Tuple[int, int, int, int, Tensor]:
        """Get parameters for ``erase`` for a random erasing.

        Args:
            img (Tensor): Tensor image to be erased.
            scale (sequence): range of proportion of erased area against input image.
            ratio (sequence): range of aspect ratio of erased area.
            value (list, optional): erasing value. If None, it is interpreted as "random"
                (erasing each pixel with random values). If ``len(value)`` is 1, it is interpreted as a number,
                i.e. ``value[0]``.

        Returns:
            tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
        """
        img_c, img_h, img_w = img.shape[-3], img.shape[-2], img.shape[-1]
        area = img_h * img_w

        log_ratio = torch.log(torch.tensor(ratio))
        for _ in range(10):
            erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
            aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()

            h = int(round(math.sqrt(erase_area * aspect_ratio)))
            w = int(round(math.sqrt(erase_area / aspect_ratio)))
            if not (h < img_h and w < img_w):
                continue

            if value is None:
                v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
            else:
                v = torch.tensor(value)[:, None, None]

            i = torch.randint(0, img_h - h + 1, size=(1,)).item()
            j = torch.randint(0, img_w - w + 1, size=(1,)).item()
            return i, j, h, w, v

        # Return original image
        return 0, 0, img_h, img_w, img

    def forward(self, img):
        """
        Args:
            img (Tensor): Tensor image to be erased.

        Returns:
            img (Tensor): Erased Tensor image.
        """
        if torch.rand(1) < self.p:

            # cast self.value to script acceptable type
            if isinstance(self.value, (int, float)):
                value = [float(self.value)]
            elif isinstance(self.value, str):
                value = None
            elif isinstance(self.value, (list, tuple)):
                value = [float(v) for v in self.value]
            else:
                value = self.value

            if value is not None and not (len(value) in (1, img.shape[-3])):
                raise ValueError(
                    "If value is a sequence, it should have either a single value or "
                    f"{img.shape[-3]} (number of input channels)"
                )

            x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value)
            return F.erase(img, x, y, h, w, v, self.inplace)
        return img

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}"
            f"(p={self.p}, "
            f"scale={self.scale}, "
            f"ratio={self.ratio}, "
            f"value={self.value}, "
            f"inplace={self.inplace})"
        )
        return s


class GaussianBlur(torch.nn.Module):
    """Blurs image with randomly chosen Gaussian blur.
    If the image is torch Tensor, it is expected
    to have [..., C, H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        kernel_size (int or sequence): Size of the Gaussian kernel.
        sigma (float or tuple of float (min, max)): Standard deviation to be used for
            creating kernel to perform blurring. If float, sigma is fixed. If it is tuple
            of float (min, max), sigma is chosen uniformly at random to lie in the
            given range.

    Returns:
        PIL Image or Tensor: Gaussian blurred version of the input image.
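
    Example:
        >>> # Illustrative sketch; assumes ``img`` is a PIL Image or tensor. Both kernel
        >>> # dimensions must be odd; sigma is re-sampled from the given range on every call.
        >>> blurrer = transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5.0))
        >>> blurred_img = blurrer(img)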

    """

    def __init__(self, kernel_size, sigma=(0.1, 2.0)):
        super().__init__()
        _log_api_usage_once(self)
        self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
        for ks in self.kernel_size:
            if ks <= 0 or ks % 2 == 0:
                raise ValueError("Kernel size value should be an odd and positive number.")

        if isinstance(sigma, numbers.Number):
            if sigma <= 0:
                raise ValueError("If sigma is a single number, it must be positive.")
            sigma = (sigma, sigma)
        elif isinstance(sigma, Sequence) and len(sigma) == 2:
            if not 0.0 < sigma[0] <= sigma[1]:
                raise ValueError("sigma values should be positive and of the form (min, max).")
        else:
            raise ValueError("sigma should be a single number or a list/tuple with length 2.")

        self.sigma = sigma

    @staticmethod
    def get_params(sigma_min: float, sigma_max: float) -> float:
        """Choose sigma for random gaussian blurring.

        Args:
            sigma_min (float): Minimum standard deviation that can be chosen for blurring kernel.
            sigma_max (float): Maximum standard deviation that can be chosen for blurring kernel.

        Returns:
            float: Standard deviation to be passed to calculate kernel for gaussian blurring.
        """
        return torch.empty(1).uniform_(sigma_min, sigma_max).item()

    def forward(self, img: Tensor) -> Tensor:
        """
        Args:
            img (PIL Image or Tensor): image to be blurred.

        Returns:
            PIL Image or Tensor: Gaussian blurred image.
        """
        sigma = self.get_params(self.sigma[0], self.sigma[1])
        return F.gaussian_blur(img, self.kernel_size, [sigma, sigma])

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}(kernel_size={self.kernel_size}, sigma={self.sigma})"
        return s


def _setup_size(size, error_msg):
    if isinstance(size, numbers.Number):
        return int(size), int(size)

    if isinstance(size, Sequence) and len(size) == 1:
        return size[0], size[0]

    if len(size) != 2:
        raise ValueError(error_msg)

    return size


def _check_sequence_input(x, name, req_sizes):
    msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join([str(s) for s in req_sizes])
    if not isinstance(x, Sequence):
        raise TypeError(f"{name} should be a sequence of length {msg}.")
    if len(x) not in req_sizes:
        raise ValueError(f"{name} should be a sequence of length {msg}.")


def _setup_angle(x, name, req_sizes=(2,)):
    if isinstance(x, numbers.Number):
        if x < 0:
            raise ValueError(f"If {name} is a single number, it must be positive.")
        x = [-x, x]
    else:
        _check_sequence_input(x, name, req_sizes)

    return [float(d) for d in x]


class RandomInvert(torch.nn.Module):
    """Inverts the colors of the given image randomly with a given probability.
    If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
    where ... means it can have an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        p (float): probability of the image being color inverted. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be inverted.

        Returns:
            PIL Image or Tensor: Randomly color inverted image.
        """
        if torch.rand(1).item() < self.p:
            return F.invert(img)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"


class RandomPosterize(torch.nn.Module):
    """Posterize the image randomly with a given probability by reducing the
    number of bits for each color channel. If the image is torch Tensor, it should be of type torch.uint8,
    and it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        bits (int): number of bits to keep for each channel (0-8)
        p (float): probability of the image being posterized. Default value is 0.5
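
    Example:
        >>> # Illustrative sketch; assumes ``img`` is a PIL "L"/"RGB" image or a torch.uint8 tensor.
        >>> posterizer = transforms.RandomPosterize(bits=2, p=0.5)
        >>> posterized_img = posterizer(img)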
    """

    def __init__(self, bits, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.bits = bits
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be posterized.

        Returns:
            PIL Image or Tensor: Randomly posterized image.
        """
        if torch.rand(1).item() < self.p:
            return F.posterize(img, self.bits)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(bits={self.bits},p={self.p})"


class RandomSolarize(torch.nn.Module):
    """Solarize the image randomly with a given probability by inverting all pixel
    values above a threshold. If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
    where ... means it can have an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        threshold (float): all pixels equal or above this value are inverted.
        p (float): probability of the image being solarized. Default value is 0.5
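
    Example:
        >>> # Illustrative sketch; assumes ``img`` is a PIL Image or tensor. The threshold lives on the
        >>> # image's value scale, e.g. around 128 for uint8 inputs or 0.5 for float inputs.
        >>> solarizer = transforms.RandomSolarize(threshold=128.0)
        >>> solarized_img = solarizer(img)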
    """

    def __init__(self, threshold, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.threshold = threshold
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be solarized.

        Returns:
            PIL Image or Tensor: Randomly solarized image.
        """
        if torch.rand(1).item() < self.p:
            return F.solarize(img, self.threshold)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(threshold={self.threshold},p={self.p})"


class RandomAdjustSharpness(torch.nn.Module):
    """Adjust the sharpness of the image randomly with a given probability. If the image is torch Tensor,
    it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        sharpness_factor (float): How much to adjust the sharpness. Can be
            any non-negative number. 0 gives a blurred image, 1 gives the
            original image while 2 increases the sharpness by a factor of 2.
        p (float): probability of the image being sharpened. Default value is 0.5
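
    Example:
        >>> # Illustrative sketch; assumes ``img`` is a PIL Image or tensor.
        >>> sharpener = transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.5)
        >>> sharpened_img = sharpener(img)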
    """

    def __init__(self, sharpness_factor, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.sharpness_factor = sharpness_factor
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be sharpened.

        Returns:
            PIL Image or Tensor: Randomly sharpened image.
        """
        if torch.rand(1).item() < self.p:
            return F.adjust_sharpness(img, self.sharpness_factor)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(sharpness_factor={self.sharpness_factor},p={self.p})"


class RandomAutocontrast(torch.nn.Module):
    """Autocontrast the pixels of the given image randomly with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        p (float): probability of the image being autocontrasted. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be autocontrasted.

        Returns:
            PIL Image or Tensor: Randomly autocontrasted image.
        """
        if torch.rand(1).item() < self.p:
            return F.autocontrast(img)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"


class RandomEqualize(torch.nn.Module):
    """Equalize the histogram of the given image randomly with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".

    Args:
        p (float): probability of the image being equalized. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be equalized.

        Returns:
            PIL Image or Tensor: Randomly equalized image.
        """
        if torch.rand(1).item() < self.p:
            return F.equalize(img)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"


class ElasticTransform(torch.nn.Module):
    """Transform a tensor image with elastic transformations.
    Given alpha and sigma, it will generate displacement
    vectors for all pixels based on random offsets. Alpha controls the strength
    and sigma controls the smoothness of the displacements.
    The displacements are added to an identity grid and the resulting grid is
    used to grid_sample from the image.

    Applications:
        Randomly transforms the morphology of objects in images and produces a
        see-through-water-like effect.

    Args:
        alpha (float or sequence of floats): Magnitude of displacements. Default is 50.0.
        sigma (float or sequence of floats): Smoothness of displacements. Default is 5.0.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        fill (sequence or number): Pixel fill value for the area outside the transformed
            image. Default is ``0``. If given a number, the value is used for all bands respectively.
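
    Example:
        >>> # Illustrative sketch; assumes ``img`` is a PIL Image or tensor. Larger ``alpha`` gives
        >>> # stronger distortion; larger ``sigma`` gives a smoother displacement field.
        >>> elastic = transforms.ElasticTransform(alpha=250.0, sigma=5.0)
        >>> transformed_img = elastic(img)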

    """

    def __init__(self, alpha=50.0, sigma=5.0, interpolation=InterpolationMode.BILINEAR, fill=0):
        super().__init__()
        _log_api_usage_once(self)
        if not isinstance(alpha, (float, Sequence)):
            raise TypeError(f"alpha should be float or a sequence of floats. Got {type(alpha)}")
        if isinstance(alpha, Sequence) and len(alpha) != 2:
            raise ValueError(f"If alpha is a sequence its length should be 2. Got {len(alpha)}")
        if isinstance(alpha, Sequence):
            for element in alpha:
                if not isinstance(element, float):
                    raise TypeError(f"alpha should be a sequence of floats. Got {type(element)}")

        if isinstance(alpha, float):
            alpha = [float(alpha), float(alpha)]
        if isinstance(alpha, (list, tuple)) and len(alpha) == 1:
            alpha = [alpha[0], alpha[0]]

        self.alpha = alpha

        if not isinstance(sigma, (float, Sequence)):
            raise TypeError(f"sigma should be float or a sequence of floats. Got {type(sigma)}")
        if isinstance(sigma, Sequence) and len(sigma) != 2:
            raise ValueError(f"If sigma is a sequence its length should be 2. Got {len(sigma)}")
        if isinstance(sigma, Sequence):
            for element in sigma:
                if not isinstance(element, float):
                    raise TypeError(f"sigma should be a sequence of floats. Got {type(element)}")

        if isinstance(sigma, float):
            sigma = [float(sigma), float(sigma)]
        if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
            sigma = [sigma[0], sigma[0]]

        self.sigma = sigma

        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)
        self.interpolation = interpolation

        if isinstance(fill, (int, float)):
            fill = [float(fill)]
        elif isinstance(fill, (list, tuple)):
            fill = [float(f) for f in fill]
        else:
            raise TypeError(f"fill should be int or float or a list or tuple of them. Got {type(fill)}")
        self.fill = fill

    @staticmethod
    def get_params(alpha: List[float], sigma: List[float], size: List[int]) -> Tensor:
        dx = torch.rand([1, 1] + size) * 2 - 1
        if sigma[0] > 0.0:
            kx = int(8 * sigma[0] + 1)
            # if kernel size is even we have to make it odd
            if kx % 2 == 0:
                kx += 1
            dx = F.gaussian_blur(dx, [kx, kx], sigma)
        dx = dx * alpha[0] / size[0]

        dy = torch.rand([1, 1] + size) * 2 - 1
        if sigma[1] > 0.0:
            ky = int(8 * sigma[1] + 1)
            # if kernel size is even we have to make it odd
            if ky % 2 == 0:
                ky += 1
            dy = F.gaussian_blur(dy, [ky, ky], sigma)
        dy = dy * alpha[1] / size[1]
        return torch.concat([dx, dy], 1).permute([0, 2, 3, 1])  # 1 x H x W x 2

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Args:
            tensor (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        _, height, width = F.get_dimensions(tensor)
        displacement = self.get_params(self.alpha, self.sigma, [height, width])
        return F.elastic_transform(tensor, displacement, self.interpolation, self.fill)

    def __repr__(self) -> str:
        format_string = self.__class__.__name__
        format_string += f"(alpha={self.alpha}"
        format_string += f", sigma={self.sigma}"
        format_string += f", interpolation={self.interpolation}"
        format_string += f", fill={self.fill})"
        return format_string