"vscode:/vscode.git/clone" did not exist on "82eccae44e8603d6faee57d22b64c030f3490f0c"
transforms.py 84.1 KB
Newer Older
1
import math
import numbers
import random
import warnings
from collections.abc import Sequence
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor

try:
    import accimage
except ImportError:
    accimage = None

from ..utils import _log_api_usage_once
from . import functional as F
from .functional import _interpolation_modes_from_int, InterpolationMode


__all__ = [
    "Compose",
    "ToTensor",
    "PILToTensor",
    "ConvertImageDtype",
    "ToPILImage",
    "Normalize",
    "Resize",
    "CenterCrop",
    "Pad",
    "Lambda",
    "RandomApply",
    "RandomChoice",
    "RandomOrder",
    "RandomCrop",
    "RandomHorizontalFlip",
    "RandomVerticalFlip",
    "RandomResizedCrop",
    "FiveCrop",
    "TenCrop",
    "LinearTransformation",
    "ColorJitter",
    "RandomRotation",
    "RandomAffine",
    "Grayscale",
    "RandomGrayscale",
    "RandomPerspective",
    "RandomErasing",
    "GaussianBlur",
    "InterpolationMode",
    "RandomInvert",
    "RandomPosterize",
    "RandomSolarize",
    "RandomAdjustSharpness",
    "RandomAutocontrast",
    "RandomEqualize",
    "ElasticTransform",
]


class Compose:
    """Composes several transforms together. This transform does not support torchscript.
    Please see the note below.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>> ])

    .. note::
        In order to script the transformations, please use ``torch.nn.Sequential`` as below.

        >>> transforms = torch.nn.Sequential(
        >>>     transforms.CenterCrop(10),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>> )
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor`` and do not require
        `lambda` functions or ``PIL.Image``.

    """

    def __init__(self, transforms):
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            _log_api_usage_once(self)
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += f"    {t}"
        format_string += "\n)"
        return format_string


class ToTensor:
    """Convert a PIL Image or ndarray to tensor and scale the values accordingly.

    This transform does not support torchscript.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
    or if the numpy.ndarray has dtype = np.uint8.

    In the other cases, tensors are returned without scaling.

    .. note::
        Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
        transforming target image masks. See the `references`_ for implementing the transforms for image masks.

    .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
    """

    def __init__(self) -> None:
        _log_api_usage_once(self)

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.to_tensor(pic)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


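# Illustrative usage of ToTensor (a minimal sketch; assumes NumPy is available):
#   >>> import numpy as np
#   >>> arr = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)  # H x W x C, uint8
#   >>> tensor = ToTensor()(arr)                                           # C x H x W, float32 in [0.0, 1.0]

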
class PILToTensor:
    """Convert a PIL Image to a tensor of the same type - this does not scale values.

    This transform does not support torchscript.

    Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W).
    """

    def __init__(self) -> None:
        _log_api_usage_once(self)

    def __call__(self, pic):
        """
        .. note::

            A deep copy of the underlying array is performed.

        Args:
            pic (PIL Image): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.pil_to_tensor(pic)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


class ConvertImageDtype(torch.nn.Module):
    """Convert a tensor image to the given ``dtype`` and scale the values accordingly.

    This function does not support PIL Image.

    Args:
        dtype (torch.dtype): Desired data type of the output

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """

    def __init__(self, dtype: torch.dtype) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.dtype = dtype

    def forward(self, image):
        return F.convert_image_dtype(image, self.dtype)


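# Illustrative usage of ConvertImageDtype (a minimal sketch):
#   >>> to_float = ConvertImageDtype(torch.float32)
#   >>> out = to_float(torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8))  # uint8 values rescaled into [0.0, 1.0]

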
class ToPILImage:
    """Convert a tensor or an ndarray to PIL Image - this does not scale values.

    This transform does not support torchscript.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL Image while preserving the value range.

    Args:
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
            If ``mode`` is ``None`` (default) there are some assumptions made about the input data:

            - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
            - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
            - If the input has 2 channels, the ``mode`` is assumed to be ``LA``.
            - If the input has 1 channel, the ``mode`` is determined by the data type (i.e. ``int``, ``float``,
              ``short``).

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
    """

    def __init__(self, mode=None):
        _log_api_usage_once(self)
        self.mode = mode

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.

        Returns:
            PIL Image: Image converted to PIL Image.

        """
        return F.to_pil_image(pic, self.mode)

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        if self.mode is not None:
            format_string += f"mode={self.mode}"
        format_string += ")"
        return format_string


class Normalize(torch.nn.Module):
    """Normalize a tensor image with mean and standard deviation.
    This transform does not support PIL Image.
    Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
    channels, this transform will normalize each channel of the input
    ``torch.*Tensor`` i.e.,
    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``

    .. note::
        This transform acts out of place, i.e., it does not mutate the input tensor.

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace (bool, optional): Bool to make this operation in-place.

    """

    def __init__(self, mean, std, inplace=False):
        super().__init__()
        _log_api_usage_once(self)
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Args:
            tensor (Tensor): Tensor image to be normalized.

        Returns:
            Tensor: Normalized Tensor image.
        """
        return F.normalize(tensor, self.mean, self.std, self.inplace)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"


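# Illustrative usage of Normalize (a minimal sketch; the ImageNet statistics below are example values):
#   >>> normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
#   >>> out = normalize(torch.rand(3, 224, 224))  # per channel: (input - mean) / std

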
class Resize(torch.nn.Module):
    """Resize the input image to the given size.
285
    If the image is torch Tensor, it is expected
vfdev's avatar
vfdev committed
286
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
287

288
289
290
291
    .. warning::
        The output image might be different depending on its type: when downsampling, the interpolation of PIL images
        and tensors is slightly different, because PIL applies antialiasing. This may lead to significant differences
        in the performance of a network. Therefore, it is preferable to train and serve a model with the same input
292
293
        types. See also below the ``antialias`` parameter, which can help making the output of PIL images and tensors
        closer.
294

295
296
297
298
299
    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
vfdev's avatar
vfdev committed
300
            (size * height / width, size).
301
302
303

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
304
305
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
306
307
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
308
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
309
310
311
312
        max_size (int, optional): The maximum allowed for the longer edge of
            the resized image: if the longer edge of the image is greater
            than ``max_size`` after being resized according to ``size``, then
            the image is resized again so that the longer edge is equal to
313
            ``max_size``. As a result, ``size`` might be overruled, i.e. the
314
315
316
            smaller edge may be shorter than ``size``. This is only supported
            if ``size`` is an int (or a sequence of length 1 in torchscript
            mode).
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
        antialias (bool, optional): Whether to apply antialiasing.
            It only affects **tensors** with bilinear or bicubic modes and it is
            ignored otherwise: on PIL images, antialiasing is always applied on
            bilinear or bicubic modes; on other modes (for PIL images and
            tensors), antialiasing makes no sense and this parameter is ignored.
            Possible values are:

            - ``True``: will apply antialiasing for bilinear or bicubic modes.
              Other mode aren't affected. This is probably what you want to use.
            - ``False``: will not apply antialiasing for tensors on any mode. PIL
              images are still antialiased on bilinear or bicubic modes, because
              PIL doesn't support no antialias.
            - ``None``: equivalent to ``False`` for tensors and ``True`` for
              PIL images. This value exists for legacy reasons and you probably
              don't want to use it unless you really know what you are doing.

            The current default is ``None`` **but will change to** ``True`` **in
            v0.17** for the PIL and Tensor backends to be consistent.
335
336
    """

337
    def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias="warn"):
vfdev's avatar
vfdev committed
338
        super().__init__()
339
        _log_api_usage_once(self)
340
        if not isinstance(size, (int, Sequence)):
341
            raise TypeError(f"Size should be int or sequence. Got {type(size)}")
342
343
344
        if isinstance(size, Sequence) and len(size) not in (1, 2):
            raise ValueError("If size is a sequence, it should have 1 or 2 values")
        self.size = size
345
        self.max_size = max_size
346

347
348
349
        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)

350
        self.interpolation = interpolation
351
        self.antialias = antialias
352

vfdev's avatar
vfdev committed
353
    def forward(self, img):
354
355
        """
        Args:
vfdev's avatar
vfdev committed
356
            img (PIL Image or Tensor): Image to be scaled.
357
358

        Returns:
vfdev's avatar
vfdev committed
359
            PIL Image or Tensor: Rescaled image.
360
        """
361
        return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias)
362

Joao Gomes's avatar
Joao Gomes committed
363
    def __repr__(self) -> str:
364
        detail = f"(size={self.size}, interpolation={self.interpolation.value}, max_size={self.max_size}, antialias={self.antialias})"
Joao Gomes's avatar
Joao Gomes committed
365
        return f"{self.__class__.__name__}{detail}"
366

367

vfdev's avatar
vfdev committed
368
369
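# Illustrative usage of Resize (a minimal sketch; antialias=True is passed explicitly to avoid the legacy default):
#   >>> resize = Resize(256, interpolation=InterpolationMode.BILINEAR, antialias=True)
#   >>> out = resize(torch.rand(3, 512, 768))  # smaller edge matched to 256 -> shape (3, 256, 384)

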
class CenterCrop(torch.nn.Module):
    """Crops the given image at the center.
370
    If the image is torch Tensor, it is expected
371
372
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
    If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
373
374
375
376

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
377
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
378
379
380
    """

    def __init__(self, size):
vfdev's avatar
vfdev committed
381
        super().__init__()
382
        _log_api_usage_once(self)
383
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
384

vfdev's avatar
vfdev committed
385
    def forward(self, img):
386
387
        """
        Args:
vfdev's avatar
vfdev committed
388
            img (PIL Image or Tensor): Image to be cropped.
389
390

        Returns:
vfdev's avatar
vfdev committed
391
            PIL Image or Tensor: Cropped image.
392
393
394
        """
        return F.center_crop(img, self.size)

Joao Gomes's avatar
Joao Gomes committed
395
396
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size})"
397

398

399
400
class Pad(torch.nn.Module):
    """Pad the given image on all sides with the given "pad" value.
401
    If the image is torch Tensor, it is expected
402
403
404
    to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric,
    at most 3 leading dimensions for mode edge,
    and an arbitrary number of leading dimensions for mode constant
405
406

    Args:
407
408
409
        padding (int or sequence): Padding on each border. If a single int is provided this
            is used to pad all borders. If sequence of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a sequence of length 4 is provided
410
            this is the padding for the left, top, right and bottom borders respectively.
411
412
413
414

            .. note::
                In torchscript mode padding as single int is not supported, use a sequence of
                length 1: ``[padding, ]``.
415
        fill (number or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
416
            length 3, it is used to fill R, G, B channels respectively.
417
418
            This value is only used when the padding_mode is constant.
            Only number is supported for torch Tensor.
419
            Only int or tuple value is supported for PIL Image.
420
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
421
            Default is constant.
422
423
424

            - constant: pads with a constant value, this value is specified with fill

425
426
            - edge: pads with the last value at the edge of the image.
              If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2
427

428
429
430
            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]
431

432
433
434
            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]
435
436
    """

437
438
    def __init__(self, padding, fill=0, padding_mode="constant"):
        super().__init__()
439
        _log_api_usage_once(self)
440
441
442
        if not isinstance(padding, (numbers.Number, tuple, list)):
            raise TypeError("Got inappropriate padding arg")

443
        if not isinstance(fill, (numbers.Number, tuple, list)):
444
445
446
447
448
449
            raise TypeError("Got inappropriate fill arg")

        if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
            raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

        if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]:
450
            raise ValueError(
451
                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
452
            )
453
454
455

        self.padding = padding
        self.fill = fill
456
        self.padding_mode = padding_mode
457

458
    def forward(self, img):
459
460
        """
        Args:
461
            img (PIL Image or Tensor): Image to be padded.
462
463

        Returns:
464
            PIL Image or Tensor: Padded image.
465
        """
466
        return F.pad(img, self.padding, self.fill, self.padding_mode)
467

Joao Gomes's avatar
Joao Gomes committed
468
469
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(padding={self.padding}, fill={self.fill}, padding_mode={self.padding_mode})"
470

471

472
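# Illustrative usage of Pad (a minimal sketch; reflect padding requires the pad to be smaller than the padded dimension):
#   >>> pad = Pad(2, padding_mode="reflect")
#   >>> pad(torch.rand(1, 4, 4)).shape  # torch.Size([1, 8, 8])

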
class Lambda:
    """Apply a user-defined lambda as a transform. This transform does not support torchscript.

    Args:
        lambd (function): Lambda/function to be used for transform.
    """

    def __init__(self, lambd):
        _log_api_usage_once(self)
        if not callable(lambd):
            raise TypeError(f"Argument lambd should be callable, got {repr(type(lambd).__name__)}")
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


class RandomTransforms:
    """Base class for a list of transformations with randomness

    Args:
        transforms (sequence): list of transformations
    """

    def __init__(self, transforms):
        _log_api_usage_once(self)
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence")
        self.transforms = transforms

    def __call__(self, *args, **kwargs):
        raise NotImplementedError()

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += f"    {t}"
        format_string += "\n)"
        return format_string


class RandomApply(torch.nn.Module):
    """Apply randomly a list of transformations with a given probability.

    .. note::
        In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of
        transforms as shown below:

        >>> transforms = transforms.RandomApply(torch.nn.ModuleList([
        >>>     transforms.ColorJitter(),
        >>> ]), p=0.3)
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor`` and do not require
        `lambda` functions or ``PIL.Image``.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
        p (float): probability
    """

    def __init__(self, transforms, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.transforms = transforms
        self.p = p

    def forward(self, img):
        if self.p < torch.rand(1):
            return img
        for t in self.transforms:
            img = t(img)
        return img

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        format_string += f"\n    p={self.p}"
        for t in self.transforms:
            format_string += "\n"
            format_string += f"    {t}"
        format_string += "\n)"
        return format_string


class RandomOrder(RandomTransforms):
    """Apply a list of transformations in a random order. This transform does not support torchscript."""

    def __call__(self, img):
        order = list(range(len(self.transforms)))
        random.shuffle(order)
        for i in order:
            img = self.transforms[i](img)
        return img


class RandomChoice(RandomTransforms):
    """Apply a single transformation randomly picked from a list. This transform does not support torchscript."""

    def __init__(self, transforms, p=None):
        super().__init__(transforms)
        if p is not None and not isinstance(p, Sequence):
            raise TypeError("Argument p should be a sequence")
        self.p = p

    def __call__(self, *args):
        t = random.choices(self.transforms, weights=self.p)[0]
        return t(*args)

    def __repr__(self) -> str:
        return f"{super().__repr__()}(p={self.p})"


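# Illustrative usage of RandomChoice (a minimal sketch; one transform is drawn per call with weights 0.7/0.3):
#   >>> pick = RandomChoice([RandomHorizontalFlip(p=1.0), RandomVerticalFlip(p=1.0)], p=[0.7, 0.3])
#   >>> out = pick(torch.rand(3, 32, 32))

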
class RandomCrop(torch.nn.Module):
    """Crop the given image at a random location.
590
    If the image is torch Tensor, it is expected
591
592
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions,
    but if non-constant padding is used, the input is expected to have at most 2 leading dimensions
593
594
595
596

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
597
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
598
        padding (int or sequence, optional): Optional padding on each border
vfdev's avatar
vfdev committed
599
            of the image. Default is None. If a single int is provided this
600
601
            is used to pad all borders. If sequence of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a sequence of length 4 is provided
vfdev's avatar
vfdev committed
602
            this is the padding for the left, top, right and bottom borders respectively.
603
604
605
606

            .. note::
                In torchscript mode padding as single int is not supported, use a sequence of
                length 1: ``[padding, ]``.
607
        pad_if_needed (boolean): It will pad the image if smaller than the
ekka's avatar
ekka committed
608
            desired size to avoid raising an exception. Since cropping is done
609
            after padding, the padding seems to be done at a random offset.
610
        fill (number or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
611
            length 3, it is used to fill R, G, B channels respectively.
612
613
            This value is only used when the padding_mode is constant.
            Only number is supported for torch Tensor.
614
            Only int or tuple value is supported for PIL Image.
615
616
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is constant.
617

618
            - constant: pads with a constant value, this value is specified with fill
619

620
621
            - edge: pads with the last value at the edge of the image.
              If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2
622

623
624
625
            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]
626

627
628
629
            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]
630
631
632
    """

    @staticmethod
vfdev's avatar
vfdev committed
633
    def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
634
635
636
        """Get parameters for ``crop`` for a random crop.

        Args:
vfdev's avatar
vfdev committed
637
            img (PIL Image or Tensor): Image to be cropped.
638
639
640
641
642
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
643
        _, h, w = F.get_dimensions(img)
644
        th, tw = output_size
vfdev's avatar
vfdev committed
645

646
        if h < th or w < tw:
647
            raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
vfdev's avatar
vfdev committed
648

649
650
651
        if w == tw and h == th:
            return 0, 0, h, w

652
653
        i = torch.randint(0, h - th + 1, size=(1,)).item()
        j = torch.randint(0, w - tw + 1, size=(1,)).item()
654
655
        return i, j, th, tw

vfdev's avatar
vfdev committed
656
657
    def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
        super().__init__()
658
        _log_api_usage_once(self)
vfdev's avatar
vfdev committed
659

660
        self.size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
661

vfdev's avatar
vfdev committed
662
663
664
665
666
667
        self.padding = padding
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        self.padding_mode = padding_mode

    def forward(self, img):
668
669
        """
        Args:
vfdev's avatar
vfdev committed
670
            img (PIL Image or Tensor): Image to be cropped.
671
672

        Returns:
vfdev's avatar
vfdev committed
673
            PIL Image or Tensor: Cropped image.
674
        """
675
676
        if self.padding is not None:
            img = F.pad(img, self.padding, self.fill, self.padding_mode)
677

678
        _, height, width = F.get_dimensions(img)
679
        # pad the width if needed
vfdev's avatar
vfdev committed
680
681
682
        if self.pad_if_needed and width < self.size[1]:
            padding = [self.size[1] - width, 0]
            img = F.pad(img, padding, self.fill, self.padding_mode)
683
        # pad the height if needed
vfdev's avatar
vfdev committed
684
685
686
        if self.pad_if_needed and height < self.size[0]:
            padding = [0, self.size[0] - height]
            img = F.pad(img, padding, self.fill, self.padding_mode)
687

688
689
690
691
        i, j, h, w = self.get_params(img, self.size)

        return F.crop(img, i, j, h, w)

Joao Gomes's avatar
Joao Gomes committed
692
693
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, padding={self.padding})"
694

695

696
697
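# Illustrative usage of RandomCrop (a minimal sketch; padding=4 grows a 32x32 input to 40x40 before cropping back to 32x32):
#   >>> crop = RandomCrop(32, padding=4, padding_mode="reflect")
#   >>> out = crop(torch.rand(3, 32, 32))

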
class RandomHorizontalFlip(torch.nn.Module):
    """Horizontally flip the given image randomly with a given probability.
698
    If the image is torch Tensor, it is expected
699
700
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions
701
702
703
704
705
706

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
707
        super().__init__()
708
        _log_api_usage_once(self)
709
        self.p = p
710

711
    def forward(self, img):
712
713
        """
        Args:
714
            img (PIL Image or Tensor): Image to be flipped.
715
716

        Returns:
717
            PIL Image or Tensor: Randomly flipped image.
718
        """
719
        if torch.rand(1) < self.p:
720
721
722
            return F.hflip(img)
        return img

Joao Gomes's avatar
Joao Gomes committed
723
724
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"
725

726

727
class RandomVerticalFlip(torch.nn.Module):
    """Vertically flip the given image randomly with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be flipped.

        Returns:
            PIL Image or Tensor: Randomly flipped image.
        """
        if torch.rand(1) < self.p:
            return F.vflip(img)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"


class RandomPerspective(torch.nn.Module):
    """Performs a random perspective transformation of the given image with a given probability.
760
    If the image is torch Tensor, it is expected
761
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
762
763

    Args:
764
765
766
        distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.
            Default is 0.5.
        p (float): probability of the image being transformed. Default is 0.5.
767
768
769
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
770
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
771
772
        fill (sequence or number): Pixel fill value for the area outside the transformed
            image. Default is ``0``. If given a number, the value is used for all bands respectively.
773
774
    """

775
    def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationMode.BILINEAR, fill=0):
776
        super().__init__()
777
        _log_api_usage_once(self)
778
        self.p = p
779

780
781
782
        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)

783
784
        self.interpolation = interpolation
        self.distortion_scale = distortion_scale
785
786
787
788
789
790

        if fill is None:
            fill = 0
        elif not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")

791
        self.fill = fill
792

793
    def forward(self, img):
794
795
        """
        Args:
796
            img (PIL Image or Tensor): Image to be Perspectively transformed.
797
798

        Returns:
799
            PIL Image or Tensor: Randomly transformed image.
800
        """
801
802

        fill = self.fill
803
        channels, height, width = F.get_dimensions(img)
804
805
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
806
                fill = [float(fill)] * channels
807
808
809
            else:
                fill = [float(f) for f in fill]

810
        if torch.rand(1) < self.p:
811
            startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
812
            return F.perspective(img, startpoints, endpoints, self.interpolation, fill)
813
814
815
        return img

    @staticmethod
816
    def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[List[int]], List[List[int]]]:
817
818
819
        """Get parameters for ``perspective`` for a random perspective transform.

        Args:
820
821
822
            width (int): width of the image.
            height (int): height of the image.
            distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.
823
824

        Returns:
825
            List containing [top-left, top-right, bottom-right, bottom-left] of the original image,
826
827
            List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image.
        """
828
829
830
        half_height = height // 2
        half_width = width // 2
        topleft = [
831
832
            int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
            int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
833
834
        ]
        topright = [
835
836
            int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
            int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
837
838
        ]
        botright = [
839
840
            int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
            int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
841
842
        ]
        botleft = [
843
844
            int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
            int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
845
846
        ]
        startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
847
848
849
        endpoints = [topleft, topright, botright, botleft]
        return startpoints, endpoints

Joao Gomes's avatar
Joao Gomes committed
850
851
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"
852
853


854
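# Illustrative usage of RandomPerspective (a minimal sketch; p=1.0 forces the warp so the effect is always visible):
#   >>> warp = RandomPerspective(distortion_scale=0.6, p=1.0, fill=0)
#   >>> out = warp(torch.rand(3, 64, 64))

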
class RandomResizedCrop(torch.nn.Module):
    """Crop a random portion of image and resize it to a given size.

    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    A crop of the original image is made: the crop has a random area (H * W)
    and a random aspect ratio. This crop is finally resized to the given
    size. This is popularly used to train the Inception networks.

    Args:
        size (int or sequence): expected output size of the crop, for each edge. If size is an
            int instead of sequence like (h, w), a square output size ``(size, size)`` is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        scale (tuple of float): Specifies the lower and upper bounds for the random area of the crop,
            before resizing. The scale is defined with respect to the area of the original image.
        ratio (tuple of float): lower and upper bounds for the random aspect ratio of the crop, before
            resizing.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        antialias (bool, optional): Whether to apply antialiasing.
            It only affects **tensors** with bilinear or bicubic modes and it is
            ignored otherwise: on PIL images, antialiasing is always applied on
            bilinear or bicubic modes; on other modes (for PIL images and
            tensors), antialiasing makes no sense and this parameter is ignored.
            Possible values are:

            - ``True``: will apply antialiasing for bilinear or bicubic modes.
              Other modes aren't affected. This is probably what you want to use.
            - ``False``: will not apply antialiasing for tensors on any mode. PIL
              images are still antialiased on bilinear or bicubic modes, because
              PIL does not support disabling antialiasing.
            - ``None``: equivalent to ``False`` for tensors and ``True`` for
              PIL images. This value exists for legacy reasons and you probably
              don't want to use it unless you really know what you are doing.

            The current default is ``None`` **but will change to** ``True`` **in
            v0.17** for the PIL and Tensor backends to be consistent.
    """

    def __init__(
        self,
        size,
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
        interpolation=InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
    ):
        super().__init__()
        _log_api_usage_once(self)
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

        if not isinstance(scale, Sequence):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, Sequence):
            raise TypeError("Ratio should be a sequence")
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")

        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)

        self.interpolation = interpolation
        self.antialias = antialias
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img: Tensor, scale: List[float], ratio: List[float]) -> Tuple[int, int, int, int]:
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image or Tensor): Input image.
            scale (list): range of scale of the origin size cropped
            ratio (list): range of aspect ratio of the origin aspect ratio cropped

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
            sized crop.
        """
        _, height, width = F.get_dimensions(img)
        area = height * width

        log_ratio = torch.log(torch.tensor(ratio))
        for _ in range(10):
            target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
            aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if 0 < w <= width and 0 < h <= height:
                i = torch.randint(0, height - h + 1, size=(1,)).item()
                j = torch.randint(0, width - w + 1, size=(1,)).item()
                return i, j, h, w

        # Fallback to central crop
        in_ratio = float(width) / float(height)
        if in_ratio < min(ratio):
            w = width
            h = int(round(w / min(ratio)))
        elif in_ratio > max(ratio):
            h = height
            w = int(round(h * max(ratio)))
        else:  # whole image
            w = width
            h = height
        i = (height - h) // 2
        j = (width - w) // 2
        return i, j, h, w

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped and resized.

        Returns:
            PIL Image or Tensor: Randomly cropped and resized image.
        """
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        return F.resized_crop(img, i, j, h, w, self.size, self.interpolation, antialias=self.antialias)

    def __repr__(self) -> str:
        interpolate_str = self.interpolation.value
        format_string = self.__class__.__name__ + f"(size={self.size}"
        format_string += f", scale={tuple(round(s, 4) for s in self.scale)}"
        format_string += f", ratio={tuple(round(r, 4) for r in self.ratio)}"
        format_string += f", interpolation={interpolate_str}"
        format_string += f", antialias={self.antialias})"
        return format_string


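# Illustrative usage of RandomResizedCrop (a minimal sketch of the Inception-style training crop):
#   >>> rrc = RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3), antialias=True)
#   >>> out = rrc(torch.rand(3, 500, 375))  # random area/aspect crop, resized to (3, 224, 224)

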
class FiveCrop(torch.nn.Module):
    """Crop the given image into four corners and the central crop.
994
    If the image is torch Tensor, it is expected
vfdev's avatar
vfdev committed
995
996
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions
997
998
999
1000
1001
1002
1003
1004
1005

    .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
         inputs and targets your Dataset returns. See below for an example of how to deal with
         this.

    Args:
         size (sequence or int): Desired output size of the crop. If size is an ``int``
            instead of sequence like (h, w), a square crop of size (size, size) is made.
1006
            If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
1007
1008
1009
1010

    Example:
         >>> transform = Compose([
         >>>    FiveCrop(size), # this is a list of PIL Images
1011
         >>>    Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor
1012
1013
1014
1015
1016
1017
1018
1019
1020
         >>> ])
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """

    def __init__(self, size):
vfdev's avatar
vfdev committed
1021
        super().__init__()
1022
        _log_api_usage_once(self)
1023
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
1024

vfdev's avatar
vfdev committed
1025
1026
1027
1028
1029
1030
1031
1032
    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            tuple of 5 images. Image can be PIL Image or Tensor
        """
1033
1034
        return F.five_crop(img, self.size)

Joao Gomes's avatar
Joao Gomes committed
1035
1036
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size})"
1037

1038

vfdev's avatar
vfdev committed
1039
1040
1041
class TenCrop(torch.nn.Module):
    """Crop the given image into four corners and the central crop plus the flipped version of
    these (horizontal flipping is used by default).
1042
    If the image is torch Tensor, it is expected
vfdev's avatar
vfdev committed
1043
1044
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions
1045
1046
1047
1048
1049
1050
1051
1052
1053

    .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
         inputs and targets your Dataset returns. See below for an example of how to deal with
         this.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
1054
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
1055
        vertical_flip (bool): Use vertical flipping instead of horizontal
1056
1057
1058

    Example:
         >>> transform = Compose([
Philip Meier's avatar
Philip Meier committed
1059
         >>>    TenCrop(size), # this is a tuple of PIL Images
1060
         >>>    Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor
1061
1062
1063
1064
1065
1066
1067
1068
1069
         >>> ])
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """

    def __init__(self, size, vertical_flip=False):
vfdev's avatar
vfdev committed
1070
        super().__init__()
1071
        _log_api_usage_once(self)
1072
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
1073
1074
        self.vertical_flip = vertical_flip

vfdev's avatar
vfdev committed
1075
1076
1077
1078
1079
1080
1081
1082
    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            tuple of 10 images. Image can be PIL Image or Tensor
        """
1083
1084
        return F.ten_crop(img, self.size, self.vertical_flip)

Joao Gomes's avatar
Joao Gomes committed
1085
1086
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, vertical_flip={self.vertical_flip})"
1087

1088

1089
class LinearTransformation(torch.nn.Module):
    """Transform a tensor image with a square transformation matrix and a mean_vector computed
    offline.
    This transform does not support PIL Image.
    Given transformation_matrix and mean_vector, this transform will flatten the torch.*Tensor,
    subtract mean_vector from it, compute the dot product with the transformation matrix,
    and then reshape the tensor to its original shape.

    Applications:
        whitening transformation: Suppose X is a column vector of zero-centered data.
        Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X),
        perform SVD on this matrix and pass it as transformation_matrix.

    Args:
        transformation_matrix (Tensor): tensor [D x D], D = C x H x W
        mean_vector (Tensor): tensor [D], D = C x H x W
    """

    def __init__(self, transformation_matrix, mean_vector):
        super().__init__()
        _log_api_usage_once(self)
        if transformation_matrix.size(0) != transformation_matrix.size(1):
            raise ValueError(
                "transformation_matrix should be square. Got "
                f"{tuple(transformation_matrix.size())} rectangular matrix."
            )

        if mean_vector.size(0) != transformation_matrix.size(0):
            raise ValueError(
                f"mean_vector should have the same length {mean_vector.size(0)}"
                f" as any one of the dimensions of the transformation_matrix [{tuple(transformation_matrix.size())}]"
            )

        if transformation_matrix.device != mean_vector.device:
            raise ValueError(
                f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
            )

        if transformation_matrix.dtype != mean_vector.dtype:
            raise ValueError(
                f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}"
            )

        self.transformation_matrix = transformation_matrix
        self.mean_vector = mean_vector

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Args:
            tensor (Tensor): Tensor image to be whitened.

        Returns:
            Tensor: Transformed image.
        """
        shape = tensor.shape
        n = shape[-3] * shape[-2] * shape[-1]
        if n != self.transformation_matrix.shape[0]:
            raise ValueError(
                "Input tensor and transformation matrix have incompatible shape."
                + f"[{shape[-3]} x {shape[-2]} x {shape[-1]}] != "
                + f"{self.transformation_matrix.shape[0]}"
            )

        if tensor.device.type != self.mean_vector.device.type:
            raise ValueError(
                "Input tensor should be on the same device as transformation matrix and mean vector. "
                f"Got {tensor.device} vs {self.mean_vector.device}"
            )

        flat_tensor = tensor.view(-1, n) - self.mean_vector
        transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype)
        transformed_tensor = torch.mm(flat_tensor, transformation_matrix)
        tensor = transformed_tensor.view(shape)
        return tensor

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}(transformation_matrix="
            f"{self.transformation_matrix.tolist()}"
            f", mean_vector={self.mean_vector.tolist()})"
        )
        return s


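# Illustrative usage of LinearTransformation (a minimal sketch; a random orthogonal matrix stands in for a
# whitening matrix computed offline):
#   >>> D = 3 * 8 * 8
#   >>> q, _ = torch.linalg.qr(torch.randn(D, D))
#   >>> whiten = LinearTransformation(q, torch.zeros(D))
#   >>> out = whiten(torch.rand(3, 8, 8))  # flatten -> subtract mean -> matrix multiply -> reshape

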
class ColorJitter(torch.nn.Module):
    """Randomly change the brightness, contrast, saturation and hue of an image.
    If the image is torch Tensor, it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.

    Args:
        brightness (float or tuple of float (min, max)): How much to jitter brightness.
            brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
            or the given [min, max]. Should be non-negative numbers.
        contrast (float or tuple of float (min, max)): How much to jitter contrast.
            contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
            or the given [min, max]. Should be non-negative numbers.
        saturation (float or tuple of float (min, max)): How much to jitter saturation.
            saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
            or the given [min, max]. Should be non-negative numbers.
        hue (float or tuple of float (min, max)): How much to jitter hue.
            hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
            Should have 0 <= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
            To jitter hue, the pixel values of the input image have to be non-negative for conversion to HSV space;
            thus it does not work if you normalize your image to an interval with negative values,
            or use an interpolation that generates negative values before using this function.
    """

    def __init__(
        self,
        brightness: Union[float, Tuple[float, float]] = 0,
        contrast: Union[float, Tuple[float, float]] = 0,
        saturation: Union[float, Tuple[float, float]] = 0,
        hue: Union[float, Tuple[float, float]] = 0,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.brightness = self._check_input(brightness, "brightness")
        self.contrast = self._check_input(contrast, "contrast")
        self.saturation = self._check_input(saturation, "saturation")
        self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    @torch.jit.unused
    def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError(f"If {name} is a single number, it must be non negative.")
            value = [center - float(value), center + float(value)]
            if clip_first_on_zero:
                value[0] = max(value[0], 0.0)
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            value = [float(value[0]), float(value[1])]
        else:
            raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")

        if not bound[0] <= value[0] <= value[1] <= bound[1]:
            raise ValueError(f"{name} values should be between {bound}, but got {value}.")

        # if value is 0 or (1., 1.) for brightness/contrast/saturation
        # or (0., 0.) for hue, do nothing
        if value[0] == value[1] == center:
            return None
        else:
            return tuple(value)

    @staticmethod
    def get_params(
        brightness: Optional[List[float]],
        contrast: Optional[List[float]],
        saturation: Optional[List[float]],
        hue: Optional[List[float]],
    ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
        """Get the parameters for the randomized transform to be applied on image.

        Args:
            brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
                uniformly. Pass None to turn off the transformation.
            contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
                uniformly. Pass None to turn off the transformation.
            saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
                uniformly. Pass None to turn off the transformation.
            hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
                Pass None to turn off the transformation.

        Returns:
            tuple: The parameters used to apply the randomized transform
            along with their random order.
        """
        fn_idx = torch.randperm(4)

        b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
        c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
        s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
        h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))

        return fn_idx, b, c, s, h

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Input image.

        Returns:
            PIL Image or Tensor: Color jittered image.
        """
        fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
            self.brightness, self.contrast, self.saturation, self.hue
        )
1278

1279
        for fn_id in fn_idx:
1280
            if fn_id == 0 and brightness_factor is not None:
1281
                img = F.adjust_brightness(img, brightness_factor)
1282
            elif fn_id == 1 and contrast_factor is not None:
1283
                img = F.adjust_contrast(img, contrast_factor)
1284
            elif fn_id == 2 and saturation_factor is not None:
1285
                img = F.adjust_saturation(img, saturation_factor)
1286
            elif fn_id == 3 and hue_factor is not None:
1287
1288
1289
                img = F.adjust_hue(img, hue_factor)

        return img
1290

Joao Gomes's avatar
Joao Gomes committed
1291
1292
1293
1294
1295
1296
1297
1298
1299
    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"brightness={self.brightness}"
            f", contrast={self.contrast}"
            f", saturation={self.saturation}"
            f", hue={self.hue})"
        )
        return s


class RandomRotation(torch.nn.Module):
    """Rotate the image by angle.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        degrees (sequence or number): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees).
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        expand (bool, optional): Optional expansion flag.
            If true, expands the output to make it large enough to hold the entire rotated image.
            If false or omitted, make the output image the same size as the input image.
            Note that the expand flag assumes rotation around the center and no translation.
        center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
            Default is the center of the image.
        fill (sequence or number): Pixel fill value for the area outside the rotated
            image. Default is ``0``. If given a number, the value is used for all bands respectively.
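
    Example (illustrative; the degree range shown is arbitrary):
        >>> # ``img`` is assumed to be a PIL Image or a tensor of shape (..., H, W)
        >>> rotate = transforms.RandomRotation(degrees=(-30, 30))
        >>> rotated_img = rotate(img)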

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """

    def __init__(self, degrees, interpolation=InterpolationMode.NEAREST, expand=False, center=None, fill=0):
        super().__init__()
        _log_api_usage_once(self)

        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)

        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))

        if center is not None:
            _check_sequence_input(center, "center", req_sizes=(2,))

        self.center = center

        self.interpolation = interpolation
        self.expand = expand

        if fill is None:
            fill = 0
        elif not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")

        self.fill = fill

    @staticmethod
    def get_params(degrees: List[float]) -> float:
        """Get parameters for ``rotate`` for a random rotation.

        Returns:
            float: angle parameter to be passed to ``rotate`` for random rotation.
        """
        angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
        return angle

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be rotated.

        Returns:
            PIL Image or Tensor: Rotated image.
        """
        fill = self.fill
        channels, _, _ = F.get_dimensions(img)
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            else:
                fill = [float(f) for f in fill]
        angle = self.get_params(self.degrees)

        return F.rotate(img, angle, self.interpolation, self.expand, self.center, fill)

    def __repr__(self) -> str:
        interpolate_str = self.interpolation.value
        format_string = self.__class__.__name__ + f"(degrees={self.degrees}"
        format_string += f", interpolation={interpolate_str}"
        format_string += f", expand={self.expand}"
        if self.center is not None:
            format_string += f", center={self.center}"
        if self.fill is not None:
            format_string += f", fill={self.fill}"
        format_string += ")"
        return format_string


class RandomAffine(torch.nn.Module):
    """Random affine transformation of the image keeping center invariant.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        degrees (sequence or number): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees). Set to 0 to deactivate rotations.
        translate (tuple, optional): tuple of maximum absolute fraction for horizontal
            and vertical translations. For example translate=(a, b), then horizontal shift
            is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
            randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
        scale (tuple, optional): scaling factor interval, e.g. (a, b), then scale is
            randomly sampled from the range a <= scale <= b. Will keep original scale by default.
        shear (sequence or number, optional): Range of degrees to select from.
            If shear is a number, a shear parallel to the x-axis in the range (-shear, +shear)
            will be applied. Else if shear is a sequence of 2 values a shear parallel to the x-axis in the
            range (shear[0], shear[1]) will be applied. Else if shear is a sequence of 4 values,
            an x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
            Will not apply shear by default.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        fill (sequence or number): Pixel fill value for the area outside the transformed
            image. Default is ``0``. If given a number, the value is used for all bands respectively.
        center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
            Default is the center of the image.
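
    Example (illustrative; the parameter ranges below are arbitrary):
        >>> # ``img`` is assumed to be a PIL Image or a tensor of shape (..., H, W)
        >>> affine = transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1))
        >>> transformed_img = affine(img)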

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """

    def __init__(
        self,
        degrees,
        translate=None,
        scale=None,
        shear=None,
        interpolation=InterpolationMode.NEAREST,
        fill=0,
        center=None,
    ):
        super().__init__()
        _log_api_usage_once(self)

        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)

        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))

        if translate is not None:
            _check_sequence_input(translate, "translate", req_sizes=(2,))
            for t in translate:
                if not (0.0 <= t <= 1.0):
                    raise ValueError("translation values should be between 0 and 1")
        self.translate = translate

        if scale is not None:
            _check_sequence_input(scale, "scale", req_sizes=(2,))
            for s in scale:
                if s <= 0:
                    raise ValueError("scale values should be positive")
        self.scale = scale

        if shear is not None:
            self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4))
        else:
            self.shear = shear

        self.interpolation = interpolation

        if fill is None:
            fill = 0
        elif not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")

        self.fill = fill

        if center is not None:
            _check_sequence_input(center, "center", req_sizes=(2,))

        self.center = center

    @staticmethod
    def get_params(
        degrees: List[float],
        translate: Optional[List[float]],
        scale_ranges: Optional[List[float]],
        shears: Optional[List[float]],
        img_size: List[int],
    ) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]:
        """Get parameters for affine transformation

        Returns:
            params to be passed to the affine transformation
        """
        angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
        if translate is not None:
            max_dx = float(translate[0] * img_size[0])
            max_dy = float(translate[1] * img_size[1])
            tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
            ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
            translations = (tx, ty)
        else:
            translations = (0, 0)

        if scale_ranges is not None:
            scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item())
        else:
            scale = 1.0

        shear_x = shear_y = 0.0
        if shears is not None:
            shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item())
            if len(shears) == 4:
                shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item())

        shear = (shear_x, shear_y)

        return angle, translations, scale, shear

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Affine transformed image.
        """
        fill = self.fill
        channels, height, width = F.get_dimensions(img)
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            else:
                fill = [float(f) for f in fill]

        img_size = [width, height]  # flip for keeping BC on get_params call

        ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)

        return F.affine(img, *ret, interpolation=self.interpolation, fill=fill, center=self.center)

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}(degrees={self.degrees}"
        s += f", translate={self.translate}" if self.translate is not None else ""
        s += f", scale={self.scale}" if self.scale is not None else ""
        s += f", shear={self.shear}" if self.shear is not None else ""
        s += f", interpolation={self.interpolation.value}" if self.interpolation != InterpolationMode.NEAREST else ""
        s += f", fill={self.fill}" if self.fill != 0 else ""
        s += f", center={self.center}" if self.center is not None else ""
        s += ")"

        return s


class Grayscale(torch.nn.Module):
    """Convert image to grayscale.
    If the image is torch Tensor, it is expected
    to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions

    Args:
        num_output_channels (int): (1 or 3) number of channels desired for output image

    Returns:
        PIL Image: Grayscale version of the input.

        - If ``num_output_channels == 1`` : returned image is single channel
        - If ``num_output_channels == 3`` : returned image is 3 channel with r == g == b

    """

    def __init__(self, num_output_channels=1):
        super().__init__()
        _log_api_usage_once(self)
        self.num_output_channels = num_output_channels

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be converted to grayscale.

        Returns:
            PIL Image or Tensor: Grayscaled image.
        """
        return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(num_output_channels={self.num_output_channels})"


class RandomGrayscale(torch.nn.Module):
    """Randomly convert image to grayscale with a probability of p (default 0.1).
    If the image is torch Tensor, it is expected
    to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions

    Args:
        p (float): probability that image should be converted to grayscale.

    Returns:
        PIL Image or Tensor: Grayscale version of the input image with probability p and unchanged
        with probability (1-p).
        - If input image is 1 channel: grayscale version is 1 channel
        - If input image is 3 channel: grayscale version is 3 channel with r == g == b
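
    Example (illustrative; ``p=1.0`` always converts the input):
        >>> # ``img`` is assumed to be a PIL Image or a tensor of shape (..., 3, H, W)
        >>> to_gray = transforms.RandomGrayscale(p=1.0)
        >>> gray_img = to_gray(img)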

    """

    def __init__(self, p=0.1):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be converted to grayscale.

        Returns:
            PIL Image or Tensor: Randomly grayscaled image.
        """
        num_output_channels, _, _ = F.get_dimensions(img)
        if torch.rand(1) < self.p:
            return F.rgb_to_grayscale(img, num_output_channels=num_output_channels)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"


class RandomErasing(torch.nn.Module):
    """Randomly selects a rectangle region in a torch.Tensor image and erases its pixels.
    This transform does not support PIL Image.
    'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896

    Args:
         p: probability that the random erasing operation will be performed.
         scale: range of proportion of erased area against input image.
         ratio: range of aspect ratio of erased area.
         value: erasing value. Default is 0. If a single int, it is used to
            erase all pixels. If a tuple of length 3, it is used to erase
            R, G, B channels respectively.
            If a str of 'random', each pixel is erased with random values.
         inplace: boolean to make this transform inplace. Default set to False.

    Returns:
        Erased Image.

    Example:
        >>> transform = transforms.Compose([
        >>>   transforms.RandomHorizontalFlip(),
        >>>   transforms.PILToTensor(),
        >>>   transforms.ConvertImageDtype(torch.float),
        >>>   transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>>   transforms.RandomErasing(),
        >>> ])
    """

    def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
        super().__init__()
        _log_api_usage_once(self)
        if not isinstance(value, (numbers.Number, str, tuple, list)):
            raise TypeError("Argument value should be either a number or str or a sequence")
        if isinstance(value, str) and value != "random":
            raise ValueError("If value is str, it should be 'random'")
        if not isinstance(scale, (tuple, list)):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, (tuple, list)):
            raise TypeError("Ratio should be a sequence")
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")
        if scale[0] < 0 or scale[1] > 1:
            raise ValueError("Scale should be between 0 and 1")
        if p < 0 or p > 1:
            raise ValueError("Random erasing probability should be between 0 and 1")

        self.p = p
        self.scale = scale
        self.ratio = ratio
        self.value = value
        self.inplace = inplace

    @staticmethod
    def get_params(
        img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None
    ) -> Tuple[int, int, int, int, Tensor]:
        """Get parameters for ``erase`` for a random erasing.

        Args:
            img (Tensor): Tensor image to be erased.
            scale (sequence): range of proportion of erased area against input image.
            ratio (sequence): range of aspect ratio of erased area.
            value (list, optional): erasing value. If None, it is interpreted as "random"
                (erasing each pixel with random values). If ``len(value)`` is 1, it is interpreted as a number,
                i.e. ``value[0]``.

        Returns:
            tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
        """
        img_c, img_h, img_w = img.shape[-3], img.shape[-2], img.shape[-1]
        area = img_h * img_w

        log_ratio = torch.log(torch.tensor(ratio))
        for _ in range(10):
            erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
            aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()

            h = int(round(math.sqrt(erase_area * aspect_ratio)))
            w = int(round(math.sqrt(erase_area / aspect_ratio)))
            if not (h < img_h and w < img_w):
                continue

            if value is None:
                v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
            else:
                v = torch.tensor(value)[:, None, None]

            i = torch.randint(0, img_h - h + 1, size=(1,)).item()
            j = torch.randint(0, img_w - w + 1, size=(1,)).item()
            return i, j, h, w, v

        # Return original image
        return 0, 0, img_h, img_w, img

    def forward(self, img):
        """
        Args:
            img (Tensor): Tensor image to be erased.

        Returns:
            img (Tensor): Erased Tensor image.
        """
        if torch.rand(1) < self.p:

            # cast self.value to script acceptable type
            if isinstance(self.value, (int, float)):
                value = [float(self.value)]
            elif isinstance(self.value, str):
                value = None
            elif isinstance(self.value, (list, tuple)):
                value = [float(v) for v in self.value]
            else:
                value = self.value

            if value is not None and not (len(value) in (1, img.shape[-3])):
                raise ValueError(
                    "If value is a sequence, it should have either a single value or "
                    f"{img.shape[-3]} (number of input channels)"
                )

            x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value)
            return F.erase(img, x, y, h, w, v, self.inplace)
        return img

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}"
            f"(p={self.p}, "
            f"scale={self.scale}, "
            f"ratio={self.ratio}, "
            f"value={self.value}, "
            f"inplace={self.inplace})"
        )
        return s


class GaussianBlur(torch.nn.Module):
    """Blurs image with randomly chosen Gaussian blur.
    If the image is torch Tensor, it is expected
    to have [..., C, H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        kernel_size (int or sequence): Size of the Gaussian kernel.
        sigma (float or tuple of float (min, max)): Standard deviation to be used for
            creating kernel to perform blurring. If float, sigma is fixed. If it is tuple
            of float (min, max), sigma is chosen uniformly at random to lie in the
            given range.

    Returns:
        PIL Image or Tensor: Gaussian blurred version of the input image.
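
    Example (illustrative; the kernel size and sigma range below are arbitrary, kernel_size must be odd):
        >>> # ``img`` is assumed to be a PIL Image or a tensor of shape (..., C, H, W)
        >>> blur = transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))
        >>> blurred_img = blur(img)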

    """

    def __init__(self, kernel_size, sigma=(0.1, 2.0)):
        super().__init__()
        _log_api_usage_once(self)
        self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
        for ks in self.kernel_size:
            if ks <= 0 or ks % 2 == 0:
                raise ValueError("Kernel size value should be an odd and positive number.")

        if isinstance(sigma, numbers.Number):
            if sigma <= 0:
                raise ValueError("If sigma is a single number, it must be positive.")
            sigma = (sigma, sigma)
        elif isinstance(sigma, Sequence) and len(sigma) == 2:
            if not 0.0 < sigma[0] <= sigma[1]:
                raise ValueError("sigma values should be positive and of the form (min, max).")
        else:
            raise ValueError("sigma should be a single number or a list/tuple with length 2.")

        self.sigma = sigma

    @staticmethod
    def get_params(sigma_min: float, sigma_max: float) -> float:
        """Choose sigma for random gaussian blurring.

        Args:
            sigma_min (float): Minimum standard deviation that can be chosen for blurring kernel.
            sigma_max (float): Maximum standard deviation that can be chosen for blurring kernel.

        Returns:
            float: Standard deviation to be passed to calculate kernel for gaussian blurring.
        """
        return torch.empty(1).uniform_(sigma_min, sigma_max).item()

    def forward(self, img: Tensor) -> Tensor:
        """
        Args:
            img (PIL Image or Tensor): image to be blurred.

        Returns:
            PIL Image or Tensor: Gaussian blurred image
        """
        sigma = self.get_params(self.sigma[0], self.sigma[1])
        return F.gaussian_blur(img, self.kernel_size, [sigma, sigma])

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}(kernel_size={self.kernel_size}, sigma={self.sigma})"
        return s


def _setup_size(size, error_msg):
    if isinstance(size, numbers.Number):
        return int(size), int(size)

    if isinstance(size, Sequence) and len(size) == 1:
        return size[0], size[0]

    if len(size) != 2:
        raise ValueError(error_msg)

    return size


def _check_sequence_input(x, name, req_sizes):
    msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join([str(s) for s in req_sizes])
    if not isinstance(x, Sequence):
        raise TypeError(f"{name} should be a sequence of length {msg}.")
    if len(x) not in req_sizes:
        raise ValueError(f"{name} should be a sequence of length {msg}.")


def _setup_angle(x, name, req_sizes=(2,)):
    if isinstance(x, numbers.Number):
        if x < 0:
            raise ValueError(f"If {name} is a single number, it must be positive.")
        x = [-x, x]
    else:
        _check_sequence_input(x, name, req_sizes)

    return [float(d) for d in x]


class RandomInvert(torch.nn.Module):
    """Inverts the colors of the given image randomly with a given probability.
    If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
    where ... means it can have an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        p (float): probability of the image being color inverted. Default value is 0.5
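
    Example (illustrative; ``p=1.0`` always inverts the input):
        >>> # ``img`` is assumed to be a PIL Image or a tensor of shape (..., 1 or 3, H, W)
        >>> invert = transforms.RandomInvert(p=1.0)
        >>> inverted_img = invert(img)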
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be inverted.

        Returns:
            PIL Image or Tensor: Randomly color inverted image.
        """
        if torch.rand(1).item() < self.p:
            return F.invert(img)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"


class RandomPosterize(torch.nn.Module):
    """Posterize the image randomly with a given probability by reducing the
    number of bits for each color channel. If the image is torch Tensor, it should be of type torch.uint8,
    and it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        bits (int): number of bits to keep for each channel (0-8)
        p (float): probability of the image being posterized. Default value is 0.5
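
    Example (illustrative; ``bits=2`` and ``p=1.0`` are arbitrary choices):
        >>> # ``img`` is assumed to be a PIL Image or a uint8 tensor of shape (..., 1 or 3, H, W)
        >>> posterize = transforms.RandomPosterize(bits=2, p=1.0)
        >>> posterized_img = posterize(img)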
    """

    def __init__(self, bits, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.bits = bits
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be posterized.

        Returns:
            PIL Image or Tensor: Randomly posterized image.
        """
        if torch.rand(1).item() < self.p:
            return F.posterize(img, self.bits)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(bits={self.bits},p={self.p})"


class RandomSolarize(torch.nn.Module):
    """Solarize the image randomly with a given probability by inverting all pixel
    values above a threshold. If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
    where ... means it can have an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        threshold (float): all pixels equal to or above this value are inverted.
        p (float): probability of the image being solarized. Default value is 0.5
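
    Example (illustrative; a threshold of 128 assumes a uint8 image with values in [0, 255]):
        >>> # ``img`` is assumed to be a PIL Image or a uint8 tensor of shape (..., 1 or 3, H, W)
        >>> solarize = transforms.RandomSolarize(threshold=128, p=1.0)
        >>> solarized_img = solarize(img)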
    """

    def __init__(self, threshold, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.threshold = threshold
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be solarized.

        Returns:
            PIL Image or Tensor: Randomly solarized image.
        """
        if torch.rand(1).item() < self.p:
            return F.solarize(img, self.threshold)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(threshold={self.threshold},p={self.p})"


class RandomAdjustSharpness(torch.nn.Module):
    """Adjust the sharpness of the image randomly with a given probability. If the image is torch Tensor,
    it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        sharpness_factor (float):  How much to adjust the sharpness. Can be
            any non-negative number. 0 gives a blurred image, 1 gives the
            original image while 2 increases the sharpness by a factor of 2.
        p (float): probability of the image being sharpened. Default value is 0.5
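
    Example (illustrative; ``sharpness_factor=2`` doubles the sharpness and ``p=1.0`` always applies it):
        >>> # ``img`` is assumed to be a PIL Image or a tensor of shape (..., 1 or 3, H, W)
        >>> sharpen = transforms.RandomAdjustSharpness(sharpness_factor=2, p=1.0)
        >>> sharpened_img = sharpen(img)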
    """

    def __init__(self, sharpness_factor, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.sharpness_factor = sharpness_factor
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be sharpened.

        Returns:
            PIL Image or Tensor: Randomly sharpened image.
        """
        if torch.rand(1).item() < self.p:
            return F.adjust_sharpness(img, self.sharpness_factor)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(sharpness_factor={self.sharpness_factor},p={self.p})"


class RandomAutocontrast(torch.nn.Module):
    """Autocontrast the pixels of the given image randomly with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        p (float): probability of the image being autocontrasted. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be autocontrasted.

        Returns:
            PIL Image or Tensor: Randomly autocontrasted image.
        """
        if torch.rand(1).item() < self.p:
            return F.autocontrast(img)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"


class RandomEqualize(torch.nn.Module):
    """Equalize the histogram of the given image randomly with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".

    Args:
        p (float): probability of the image being equalized. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be equalized.

        Returns:
            PIL Image or Tensor: Randomly equalized image.
        """
        if torch.rand(1).item() < self.p:
            return F.equalize(img)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"


class ElasticTransform(torch.nn.Module):
    """Transform a tensor image with elastic transformations.
    Given alpha and sigma, it will generate displacement
    vectors for all pixels based on random offsets. Alpha controls the strength
    and sigma controls the smoothness of the displacements.
    The displacements are added to an identity grid and the resulting grid is
    used to grid_sample from the image.

    Applications:
        Randomly transforms the morphology of objects in images and produces a
        see-through-water-like effect.

    Args:
        alpha (float or sequence of floats): Magnitude of displacements. Default is 50.0.
        sigma (float or sequence of floats): Smoothness of displacements. Default is 5.0.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
        fill (sequence or number): Pixel fill value for the area outside the transformed
            image. Default is ``0``. If given a number, the value is used for all bands respectively.
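
    Example (illustrative; ``alpha`` and ``sigma`` below are arbitrary values):
        >>> # ``img`` is assumed to be a PIL Image or a tensor of shape (..., C, H, W)
        >>> elastic = transforms.ElasticTransform(alpha=250.0, sigma=10.0)
        >>> distorted_img = elastic(img)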

    """

    def __init__(self, alpha=50.0, sigma=5.0, interpolation=InterpolationMode.BILINEAR, fill=0):
        super().__init__()
        _log_api_usage_once(self)
        if not isinstance(alpha, (float, Sequence)):
            raise TypeError(f"alpha should be float or a sequence of floats. Got {type(alpha)}")
        if isinstance(alpha, Sequence) and len(alpha) != 2:
            raise ValueError(f"If alpha is a sequence its length should be 2. Got {len(alpha)}")
        if isinstance(alpha, Sequence):
            for element in alpha:
                if not isinstance(element, float):
                    raise TypeError(f"alpha should be a sequence of floats. Got {type(element)}")

        if isinstance(alpha, float):
            alpha = [float(alpha), float(alpha)]
        if isinstance(alpha, (list, tuple)) and len(alpha) == 1:
            alpha = [alpha[0], alpha[0]]

        self.alpha = alpha

        if not isinstance(sigma, (float, Sequence)):
            raise TypeError(f"sigma should be float or a sequence of floats. Got {type(sigma)}")
        if isinstance(sigma, Sequence) and len(sigma) != 2:
            raise ValueError(f"If sigma is a sequence its length should be 2. Got {len(sigma)}")
        if isinstance(sigma, Sequence):
            for element in sigma:
                if not isinstance(element, float):
                    raise TypeError(f"sigma should be a sequence of floats. Got {type(element)}")

        if isinstance(sigma, float):
            sigma = [float(sigma), float(sigma)]
        if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
            sigma = [sigma[0], sigma[0]]

        self.sigma = sigma

        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)
        self.interpolation = interpolation

        if isinstance(fill, (int, float)):
            fill = [float(fill)]
        elif isinstance(fill, (list, tuple)):
            fill = [float(f) for f in fill]
        else:
            raise TypeError(f"fill should be int or float or a list or tuple of them. Got {type(fill)}")
        self.fill = fill

    @staticmethod
    def get_params(alpha: List[float], sigma: List[float], size: List[int]) -> Tensor:
        dx = torch.rand([1, 1] + size) * 2 - 1
        if sigma[0] > 0.0:
            kx = int(8 * sigma[0] + 1)
            # if kernel size is even we have to make it odd
            if kx % 2 == 0:
                kx += 1
            dx = F.gaussian_blur(dx, [kx, kx], sigma)
        dx = dx * alpha[0] / size[0]

        dy = torch.rand([1, 1] + size) * 2 - 1
        if sigma[1] > 0.0:
            ky = int(8 * sigma[1] + 1)
            # if kernel size is even we have to make it odd
            if ky % 2 == 0:
                ky += 1
            dy = F.gaussian_blur(dy, [ky, ky], sigma)
        dy = dy * alpha[1] / size[1]
        return torch.concat([dx, dy], 1).permute([0, 2, 3, 1])  # 1 x H x W x 2

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Args:
            tensor (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        _, height, width = F.get_dimensions(tensor)
        displacement = self.get_params(self.alpha, self.sigma, [height, width])
        return F.elastic_transform(tensor, displacement, self.interpolation, self.fill)

    def __repr__(self):
        format_string = self.__class__.__name__
        format_string += f"(alpha={self.alpha}"
        format_string += f", sigma={self.sigma}"
        format_string += f", interpolation={self.interpolation}"
        format_string += f", fill={self.fill})"
        return format_string