"integration-tests/vscode:/vscode.git/clone" did not exist on "967ced2ff4565a5358d45a1372d32fbab113700b"
transforms.py 83.2 KB
Newer Older
import math
import numbers
import random
import warnings
from collections.abc import Sequence
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor

try:
    import accimage
except ImportError:
    accimage = None

from ..utils import _log_api_usage_once
from . import functional as F
from .functional import _interpolation_modes_from_int, InterpolationMode


__all__ = [
    "Compose",
    "ToTensor",
    "PILToTensor",
    "ConvertImageDtype",
    "ToPILImage",
    "Normalize",
    "Resize",
    "CenterCrop",
    "Pad",
    "Lambda",
    "RandomApply",
    "RandomChoice",
    "RandomOrder",
    "RandomCrop",
    "RandomHorizontalFlip",
    "RandomVerticalFlip",
    "RandomResizedCrop",
    "FiveCrop",
    "TenCrop",
    "LinearTransformation",
    "ColorJitter",
    "RandomRotation",
    "RandomAffine",
    "Grayscale",
    "RandomGrayscale",
    "RandomPerspective",
    "RandomErasing",
    "GaussianBlur",
    "InterpolationMode",
    "RandomInvert",
    "RandomPosterize",
    "RandomSolarize",
    "RandomAdjustSharpness",
    "RandomAutocontrast",
    "RandomEqualize",
    "ElasticTransform",
]


class Compose:
    """Composes several transforms together. This transform does not support torchscript.
    Please, see the note below.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>> ])

    .. note::
        In order to script the transformations, please use ``torch.nn.Sequential`` as below.

        >>> transforms = torch.nn.Sequential(
        >>>     transforms.CenterCrop(10),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>> )
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. ones that work with ``torch.Tensor`` and do not
        require `lambda` functions or ``PIL.Image``.

    """

    def __init__(self, transforms):
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            _log_api_usage_once(self)
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += f"    {t}"
        format_string += "\n)"
        return format_string


class ToTensor:
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
    or if the numpy.ndarray has dtype = np.uint8

    In the other cases, tensors are returned without scaling.

    .. note::
        Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
        transforming target image masks. See the `references`_ for implementing the transforms for image masks.

    .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
    """

    def __init__(self) -> None:
        _log_api_usage_once(self)

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.to_tensor(pic)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


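# A minimal usage sketch for ToTensor (illustrative only; assumes NumPy is available):
# a uint8 H x W x C array is converted to a C x H x W float tensor scaled into [0.0, 1.0].
#
#   >>> import numpy as np
#   >>> pic = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
#   >>> tensor = ToTensor()(pic)
#   >>> tensor.shape, tensor.dtype   # (torch.Size([3, 32, 32]), torch.float32)
#   >>> float(tensor.max()) <= 1.0   # True, values were divided by 255
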
class PILToTensor:
    """Convert a ``PIL Image`` to a tensor of the same type. This transform does not support torchscript.

    Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W).
    """

    def __init__(self) -> None:
        _log_api_usage_once(self)

    def __call__(self, pic):
        """
        .. note::

            A deep copy of the underlying array is performed.

        Args:
            pic (PIL Image): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.pil_to_tensor(pic)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


class ConvertImageDtype(torch.nn.Module):
    """Convert a tensor image to the given ``dtype`` and scale the values accordingly.
    This function does not support PIL Image.

    Args:
        dtype (torch.dtype): Desired data type of the output

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """

    def __init__(self, dtype: torch.dtype) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.dtype = dtype

    def forward(self, image):
        return F.convert_image_dtype(image, self.dtype)


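# A minimal usage sketch for ConvertImageDtype (illustrative only): converting a
# uint8 image tensor to float rescales values from [0, 255] into [0.0, 1.0].
#
#   >>> img = torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8)
#   >>> float_img = ConvertImageDtype(torch.float32)(img)
#   >>> float_img.dtype                # torch.float32
#   >>> float(float_img.max()) <= 1.0  # True
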
class ToPILImage:
    """Convert a tensor or an ndarray to PIL Image. This transform does not support torchscript.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL Image while preserving the value range.

    Args:
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
            If ``mode`` is ``None`` (default) there are some assumptions made about the input data:

            - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
            - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
            - If the input has 2 channels, the ``mode`` is assumed to be ``LA``.
            - If the input has 1 channel, the ``mode`` is determined by the data type (i.e. ``int``, ``float``,
              ``short``).

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
    """

    def __init__(self, mode=None):
        _log_api_usage_once(self)
        self.mode = mode

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.

        Returns:
            PIL Image: Image converted to PIL Image.

        """
        return F.to_pil_image(pic, self.mode)

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        if self.mode is not None:
            format_string += f"mode={self.mode}"
        format_string += ")"
        return format_string


class Normalize(torch.nn.Module):
    """Normalize a tensor image with mean and standard deviation.
    This transform does not support PIL Image.
    Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
    channels, this transform will normalize each channel of the input
    ``torch.*Tensor`` i.e.,
    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``

    .. note::
        This transform acts out of place, i.e., it does not mutate the input tensor.

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace (bool, optional): Bool to make this operation in-place.

    """

    def __init__(self, mean, std, inplace=False):
        super().__init__()
        _log_api_usage_once(self)
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Args:
            tensor (Tensor): Tensor image to be normalized.

        Returns:
            Tensor: Normalized Tensor image.
        """
        return F.normalize(tensor, self.mean, self.std, self.inplace)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"


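# A minimal sketch of the per-channel formula documented above,
# output[channel] = (input[channel] - mean[channel]) / std[channel]
# (assumes a float image tensor, e.g. one produced by ToTensor):
#
#   >>> img = torch.rand(3, 224, 224)
#   >>> out = Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))(img)
#   >>> # a pixel equal to the channel mean maps to 0, e.g. (0.485 - 0.485) / 0.229 == 0.0
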
class Resize(torch.nn.Module):
    """Resize the input image to the given size.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    .. warning::
        The output image might be different depending on its type: when downsampling, the interpolation of PIL images
        and tensors is slightly different, because PIL applies antialiasing. This may lead to significant differences
        in the performance of a network. Therefore, it is preferable to train and serve a model with the same input
        types. See also below the ``antialias`` parameter, which can help make the output of PIL images and tensors
        closer.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e., if height > width, then image will be rescaled to
            (size * height / width, size).

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
        max_size (int, optional): The maximum allowed for the longer edge of
            the resized image: if the longer edge of the image is greater
            than ``max_size`` after being resized according to ``size``, then
            the image is resized again so that the longer edge is equal to
            ``max_size``. As a result, ``size`` might be overruled, i.e. the
            smaller edge may be shorter than ``size``. This is only supported
            if ``size`` is an int (or a sequence of length 1 in torchscript
            mode).
        antialias (bool, optional): Whether to apply antialiasing.
            It only affects **tensors** with bilinear or bicubic modes and it is
            ignored otherwise: on PIL images, antialiasing is always applied on
            bilinear or bicubic modes; on other modes (for PIL images and
            tensors), antialiasing makes no sense and this parameter is ignored.
            Possible values are:

            - ``True``: will apply antialiasing for bilinear or bicubic modes.
              Other modes aren't affected. This is probably what you want to use.
            - ``False``: will not apply antialiasing for tensors on any mode. PIL
              images are still antialiased on bilinear or bicubic modes, because
              PIL doesn't support no antialias.
            - ``None``: equivalent to ``False`` for tensors and ``True`` for
              PIL images. This value exists for legacy reasons and you probably
              don't want to use it unless you really know what you are doing.

            The current default is ``None`` **but will change to** ``True`` **in
            v0.17** for the PIL and Tensor backends to be consistent.
    """

    def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias="warn"):
        super().__init__()
        _log_api_usage_once(self)
        if not isinstance(size, (int, Sequence)):
            raise TypeError(f"Size should be int or sequence. Got {type(size)}")
        if isinstance(size, Sequence) and len(size) not in (1, 2):
            raise ValueError("If size is a sequence, it should have 1 or 2 values")
        self.size = size
        self.max_size = max_size

        self.interpolation = interpolation
        self.antialias = antialias

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be scaled.

        Returns:
            PIL Image or Tensor: Rescaled image.
        """
        return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias)

    def __repr__(self) -> str:
        detail = f"(size={self.size}, interpolation={self.interpolation.value}, max_size={self.max_size}, antialias={self.antialias})"
        return f"{self.__class__.__name__}{detail}"


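# A minimal usage sketch for Resize (illustrative only): with an int size, the
# smaller edge is matched and the aspect ratio is preserved; with an (h, w)
# sequence, the output size is exact. ``antialias=True`` is passed explicitly to
# avoid the default-change warning mentioned above.
#
#   >>> img = torch.rand(3, 300, 400)
#   >>> Resize(256, antialias=True)(img).shape         # torch.Size([3, 256, 341])
#   >>> Resize((224, 224), antialias=True)(img).shape  # torch.Size([3, 224, 224])
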
class CenterCrop(torch.nn.Module):
    """Crops the given image at the center.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
    If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
    """

    def __init__(self, size):
        super().__init__()
        _log_api_usage_once(self)
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            PIL Image or Tensor: Cropped image.
        """
        return F.center_crop(img, self.size)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size})"


class Pad(torch.nn.Module):
    """Pad the given image on all sides with the given "pad" value.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric,
    at most 3 leading dimensions for mode edge,
    and an arbitrary number of leading dimensions for mode constant

    Args:
        padding (int or sequence): Padding on each border. If a single int is provided this
            is used to pad all borders. If sequence of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a sequence of length 4 is provided
            this is the padding for the left, top, right and bottom borders respectively.

            .. note::
                In torchscript mode padding as single int is not supported, use a sequence of
                length 1: ``[padding, ]``.
        fill (number or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
            Only number is supported for torch Tensor.
            Only int or tuple value is supported for PIL Image.
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image.
              If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2

            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]
    """

    def __init__(self, padding, fill=0, padding_mode="constant"):
        super().__init__()
        _log_api_usage_once(self)
        if not isinstance(padding, (numbers.Number, tuple, list)):
            raise TypeError("Got inappropriate padding arg")

        if not isinstance(fill, (numbers.Number, tuple, list)):
            raise TypeError("Got inappropriate fill arg")

        if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
            raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

        if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]:
            raise ValueError(
                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
            )

        self.padding = padding
        self.fill = fill
        self.padding_mode = padding_mode

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be padded.

        Returns:
            PIL Image or Tensor: Padded image.
        """
        return F.pad(img, self.padding, self.fill, self.padding_mode)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(padding={self.padding}, fill={self.fill}, padding_mode={self.padding_mode})"


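# A minimal sketch of the padding modes documented above, using the same
# [1, 2, 3, 4] row as the docstring (the commented values restate the expected
# results described there; shapes are illustrative):
#
#   >>> row = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])  # shape (1, 1, 4)
#   >>> Pad((2, 0), padding_mode="symmetric")(row)
#   >>> # -> [2., 1., 1., 2., 3., 4., 4., 3.] along the last dimension
#   >>> Pad((2, 0), padding_mode="constant", fill=0)(row)
#   >>> # -> [0., 0., 1., 2., 3., 4., 0., 0.] along the last dimension
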
class Lambda:
    """Apply a user-defined lambda as a transform. This transform does not support torchscript.

    Args:
        lambd (function): Lambda/function to be used for transform.
    """

    def __init__(self, lambd):
        _log_api_usage_once(self)
        if not callable(lambd):
            raise TypeError(f"Argument lambd should be callable, got {repr(type(lambd).__name__)}")
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


class RandomTransforms:
    """Base class for a list of transformations with randomness

    Args:
        transforms (sequence): list of transformations
    """

    def __init__(self, transforms):
        _log_api_usage_once(self)
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence")
        self.transforms = transforms

    def __call__(self, *args, **kwargs):
        raise NotImplementedError()

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += f"    {t}"
        format_string += "\n)"
        return format_string


class RandomApply(torch.nn.Module):
    """Apply randomly a list of transformations with a given probability.

    .. note::
        In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of
        transforms as shown below:

        >>> transforms = transforms.RandomApply(torch.nn.ModuleList([
        >>>     transforms.ColorJitter(),
        >>> ]), p=0.3)
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. ones that work with ``torch.Tensor`` and do not
        require `lambda` functions or ``PIL.Image``.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
        p (float): probability
    """

    def __init__(self, transforms, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.transforms = transforms
        self.p = p

    def forward(self, img):
        if self.p < torch.rand(1):
            return img
        for t in self.transforms:
            img = t(img)
        return img

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        format_string += f"\n    p={self.p}"
        for t in self.transforms:
            format_string += "\n"
            format_string += f"    {t}"
        format_string += "\n)"
        return format_string


class RandomOrder(RandomTransforms):
    """Apply a list of transformations in a random order. This transform does not support torchscript."""

    def __call__(self, img):
        order = list(range(len(self.transforms)))
        random.shuffle(order)
        for i in order:
            img = self.transforms[i](img)
        return img


class RandomChoice(RandomTransforms):
    """Apply single transformation randomly picked from a list. This transform does not support torchscript."""

    def __init__(self, transforms, p=None):
        super().__init__(transforms)
        if p is not None and not isinstance(p, Sequence):
            raise TypeError("Argument p should be a sequence")
        self.p = p

    def __call__(self, *args):
        t = random.choices(self.transforms, weights=self.p)[0]
        return t(*args)

    def __repr__(self) -> str:
        return f"{super().__repr__()}(p={self.p})"


class RandomCrop(torch.nn.Module):
    """Crop the given image at a random location.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions,
    but if non-constant padding is used, the input is expected to have at most 2 leading dimensions

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is None. If a single int is provided this
            is used to pad all borders. If sequence of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a sequence of length 4 is provided
            this is the padding for the left, top, right and bottom borders respectively.

            .. note::
                In torchscript mode padding as single int is not supported, use a sequence of
                length 1: ``[padding, ]``.
        pad_if_needed (boolean): It will pad the image if smaller than the
            desired size to avoid raising an exception. Since cropping is done
            after padding, the padding seems to be done at a random offset.
        fill (number or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
            Only number is supported for torch Tensor.
            Only int or tuple value is supported for PIL Image.
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image.
              If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2

            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]
    """

    @staticmethod
    def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image or Tensor): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        _, h, w = F.get_dimensions(img)
        th, tw = output_size

        if h < th or w < tw:
            raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")

        if w == tw and h == th:
            return 0, 0, h, w

        i = torch.randint(0, h - th + 1, size=(1,)).item()
        j = torch.randint(0, w - tw + 1, size=(1,)).item()
        return i, j, th, tw

    def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
        super().__init__()
        _log_api_usage_once(self)

        self.size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))

        self.padding = padding
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        self.padding_mode = padding_mode

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            PIL Image or Tensor: Cropped image.
        """
        if self.padding is not None:
            img = F.pad(img, self.padding, self.fill, self.padding_mode)

        _, height, width = F.get_dimensions(img)
        # pad the width if needed
        if self.pad_if_needed and width < self.size[1]:
            padding = [self.size[1] - width, 0]
            img = F.pad(img, padding, self.fill, self.padding_mode)
        # pad the height if needed
        if self.pad_if_needed and height < self.size[0]:
            padding = [0, self.size[0] - height]
            img = F.pad(img, padding, self.fill, self.padding_mode)

        i, j, h, w = self.get_params(img, self.size)

        return F.crop(img, i, j, h, w)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, padding={self.padding})"


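# A minimal usage sketch for RandomCrop (illustrative only): ``pad_if_needed``
# pads inputs that are smaller than the requested crop instead of raising, and
# ``padding`` is applied before the random crop location is drawn.
#
#   >>> img = torch.rand(3, 48, 48)
#   >>> RandomCrop(32)(img).shape                                     # torch.Size([3, 32, 32])
#   >>> RandomCrop(64, pad_if_needed=True)(img).shape                 # torch.Size([3, 64, 64])
#   >>> RandomCrop(32, padding=4, padding_mode="reflect")(img).shape  # torch.Size([3, 32, 32])
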
class RandomHorizontalFlip(torch.nn.Module):
    """Horizontally flip the given image randomly with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be flipped.

        Returns:
            PIL Image or Tensor: Randomly flipped image.
        """
        if torch.rand(1) < self.p:
            return F.hflip(img)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"


class RandomVerticalFlip(torch.nn.Module):
    """Vertically flip the given image randomly with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be flipped.

        Returns:
            PIL Image or Tensor: Randomly flipped image.
        """
        if torch.rand(1) < self.p:
            return F.vflip(img)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"


class RandomPerspective(torch.nn.Module):
    """Performs a random perspective transformation of the given image with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.
            Default is 0.5.
        p (float): probability of the image being transformed. Default is 0.5.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number): Pixel fill value for the area outside the transformed
            image. Default is ``0``. If given a number, the value is used for all bands respectively.
    """

    def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationMode.BILINEAR, fill=0):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

        self.interpolation = interpolation
        self.distortion_scale = distortion_scale

        if fill is None:
            fill = 0
        elif not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")

        self.fill = fill

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be Perspectively transformed.

        Returns:
            PIL Image or Tensor: Randomly transformed image.
        """

        fill = self.fill
        channels, height, width = F.get_dimensions(img)
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            else:
                fill = [float(f) for f in fill]

        if torch.rand(1) < self.p:
            startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
            return F.perspective(img, startpoints, endpoints, self.interpolation, fill)
        return img

    @staticmethod
    def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[List[int]], List[List[int]]]:
        """Get parameters for ``perspective`` for a random perspective transform.

        Args:
            width (int): width of the image.
            height (int): height of the image.
            distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.

        Returns:
            List containing [top-left, top-right, bottom-right, bottom-left] of the original image,
            List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image.
        """
        half_height = height // 2
        half_width = width // 2
        topleft = [
            int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
            int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
        ]
        topright = [
            int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
            int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
        ]
        botright = [
            int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
            int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
        ]
        botleft = [
            int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
            int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
        ]
        startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
        endpoints = [topleft, topright, botright, botleft]
        return startpoints, endpoints

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"


class RandomResizedCrop(torch.nn.Module):
    """Crop a random portion of image and resize it to a given size.

    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    A crop of the original image is made: the crop has a random area (H * W)
    and a random aspect ratio. This crop is finally resized to the given
    size. This is popularly used to train the Inception networks.

    Args:
        size (int or sequence): expected output size of the crop, for each edge. If size is an
            int instead of sequence like (h, w), a square output size ``(size, size)`` is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        scale (tuple of float): Specifies the lower and upper bounds for the random area of the crop,
            before resizing. The scale is defined with respect to the area of the original image.
        ratio (tuple of float): lower and upper bounds for the random aspect ratio of the crop, before
            resizing.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
        antialias (bool, optional): Whether to apply antialiasing.
            It only affects **tensors** with bilinear or bicubic modes and it is
            ignored otherwise: on PIL images, antialiasing is always applied on
            bilinear or bicubic modes; on other modes (for PIL images and
            tensors), antialiasing makes no sense and this parameter is ignored.
            Possible values are:

            - ``True``: will apply antialiasing for bilinear or bicubic modes.
              Other modes aren't affected. This is probably what you want to use.
            - ``False``: will not apply antialiasing for tensors on any mode. PIL
              images are still antialiased on bilinear or bicubic modes, because
              PIL doesn't support no antialias.
            - ``None``: equivalent to ``False`` for tensors and ``True`` for
              PIL images. This value exists for legacy reasons and you probably
              don't want to use it unless you really know what you are doing.

            The current default is ``None`` **but will change to** ``True`` **in
            v0.17** for the PIL and Tensor backends to be consistent.
    """

    def __init__(
        self,
        size,
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
        interpolation=InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
    ):
        super().__init__()
        _log_api_usage_once(self)
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

        if not isinstance(scale, Sequence):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, Sequence):
            raise TypeError("Ratio should be a sequence")
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")

        self.interpolation = interpolation
        self.antialias = antialias
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img: Tensor, scale: List[float], ratio: List[float]) -> Tuple[int, int, int, int]:
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image or Tensor): Input image.
            scale (list): range of scale of the origin size cropped
            ratio (list): range of aspect ratio of the origin aspect ratio cropped

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
            sized crop.
        """
        _, height, width = F.get_dimensions(img)
        area = height * width

        log_ratio = torch.log(torch.tensor(ratio))
        for _ in range(10):
            target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
            aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if 0 < w <= width and 0 < h <= height:
                i = torch.randint(0, height - h + 1, size=(1,)).item()
                j = torch.randint(0, width - w + 1, size=(1,)).item()
                return i, j, h, w

        # Fallback to central crop
        in_ratio = float(width) / float(height)
        if in_ratio < min(ratio):
            w = width
            h = int(round(w / min(ratio)))
        elif in_ratio > max(ratio):
            h = height
            w = int(round(h * max(ratio)))
        else:  # whole image
            w = width
            h = height
        i = (height - h) // 2
        j = (width - w) // 2
        return i, j, h, w

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped and resized.

        Returns:
            PIL Image or Tensor: Randomly cropped and resized image.
        """
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        return F.resized_crop(img, i, j, h, w, self.size, self.interpolation, antialias=self.antialias)

    def __repr__(self) -> str:
        interpolate_str = self.interpolation.value
        format_string = self.__class__.__name__ + f"(size={self.size}"
        format_string += f", scale={tuple(round(s, 4) for s in self.scale)}"
        format_string += f", ratio={tuple(round(r, 4) for r in self.ratio)}"
        format_string += f", interpolation={interpolate_str}"
        format_string += f", antialias={self.antialias})"
        return format_string


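# A minimal usage sketch for RandomResizedCrop (illustrative only): a random
# region whose area is drawn from ``scale`` of the input area, with aspect
# ratio drawn from ``ratio``, is cropped and then resized to ``size``.
#
#   >>> img = torch.rand(3, 256, 256)
#   >>> rrc = RandomResizedCrop(224, scale=(0.5, 1.0), antialias=True)
#   >>> rrc(img).shape   # torch.Size([3, 224, 224]) regardless of the sampled crop
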
class FiveCrop(torch.nn.Module):
    """Crop the given image into four corners and the central crop.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
         inputs and targets your Dataset returns. See below for an example of how to deal with
         this.

    Args:
         size (sequence or int): Desired output size of the crop. If size is an ``int``
            instead of sequence like (h, w), a square crop of size (size, size) is made.
            If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).

    Example:
         >>> transform = Compose([
         >>>    FiveCrop(size), # this is a list of PIL Images
         >>>    Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor
         >>> ])
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """

    def __init__(self, size):
        super().__init__()
        _log_api_usage_once(self)
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            tuple of 5 images. Image can be PIL Image or Tensor
        """
        return F.five_crop(img, self.size)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size})"


class TenCrop(torch.nn.Module):
    """Crop the given image into four corners and the central crop plus the flipped version of
    these (horizontal flipping is used by default).
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
         inputs and targets your Dataset returns. See below for an example of how to deal with
         this.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
        vertical_flip (bool): Use vertical flipping instead of horizontal

    Example:
         >>> transform = Compose([
         >>>    TenCrop(size), # this is a list of PIL Images
         >>>    Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor
         >>> ])
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """

    def __init__(self, size, vertical_flip=False):
        super().__init__()
        _log_api_usage_once(self)
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
        self.vertical_flip = vertical_flip

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            tuple of 10 images. Image can be PIL Image or Tensor
        """
        return F.ten_crop(img, self.size, self.vertical_flip)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, vertical_flip={self.vertical_flip})"


class LinearTransformation(torch.nn.Module):
    """Transform a tensor image with a square transformation matrix and a mean_vector computed
    offline.
    This transform does not support PIL Image.
    Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and
    subtract mean_vector from it which is then followed by computing the dot
    product with the transformation matrix and then reshaping the tensor to its
    original shape.

    Applications:
        whitening transformation: Suppose X is a column vector zero-centered data.
        Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X),
        perform SVD on this matrix and pass it as transformation_matrix.

    Args:
        transformation_matrix (Tensor): tensor [D x D], D = C x H x W
        mean_vector (Tensor): tensor [D], D = C x H x W
    """

    def __init__(self, transformation_matrix, mean_vector):
        super().__init__()
        _log_api_usage_once(self)
        if transformation_matrix.size(0) != transformation_matrix.size(1):
            raise ValueError(
                "transformation_matrix should be square. Got "
                f"{tuple(transformation_matrix.size())} rectangular matrix."
            )

        if mean_vector.size(0) != transformation_matrix.size(0):
            raise ValueError(
                f"mean_vector should have the same length {mean_vector.size(0)}"
                f" as any one of the dimensions of the transformation_matrix [{tuple(transformation_matrix.size())}]"
            )

        if transformation_matrix.device != mean_vector.device:
            raise ValueError(
                f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
            )

        if transformation_matrix.dtype != mean_vector.dtype:
            raise ValueError(
                f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}"
            )

        self.transformation_matrix = transformation_matrix
        self.mean_vector = mean_vector

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Args:
            tensor (Tensor): Tensor image to be whitened.

        Returns:
            Tensor: Transformed image.
        """
        shape = tensor.shape
        n = shape[-3] * shape[-2] * shape[-1]
        if n != self.transformation_matrix.shape[0]:
            raise ValueError(
                "Input tensor and transformation matrix have incompatible shape."
                + f"[{shape[-3]} x {shape[-2]} x {shape[-1]}] != "
                + f"{self.transformation_matrix.shape[0]}"
            )

        if tensor.device.type != self.mean_vector.device.type:
            raise ValueError(
                "Input tensor should be on the same device as transformation matrix and mean vector. "
                f"Got {tensor.device} vs {self.mean_vector.device}"
            )

        flat_tensor = tensor.view(-1, n) - self.mean_vector

        transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype)
        transformed_tensor = torch.mm(flat_tensor, transformation_matrix)
        return transformed_tensor.view(shape)

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}(transformation_matrix="
            f"{self.transformation_matrix.tolist()}"
            f", mean_vector={self.mean_vector.tolist()})"
        )
        return s


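# A minimal sketch of the flatten/subtract/matrix-multiply pipeline described in
# the docstring (illustrative only; an identity matrix and zero mean leave the
# input unchanged, standing in for a real whitening matrix computed offline):
#
#   >>> c, h, w = 3, 8, 8
#   >>> d = c * h * w
#   >>> identity = LinearTransformation(torch.eye(d), torch.zeros(d))
#   >>> img = torch.rand(c, h, w)
#   >>> torch.allclose(identity(img), img)   # True
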
class ColorJitter(torch.nn.Module):
1157
    """Randomly change the brightness, contrast, saturation and hue of an image.
1158
    If the image is torch Tensor, it is expected
1159
1160
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.
1161
1162

    Args:
yaox12's avatar
yaox12 committed
1163
1164
1165
1166
1167
        brightness (float or tuple of float (min, max)): How much to jitter brightness.
            brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
            or the given [min, max]. Should be non negative numbers.
        contrast (float or tuple of float (min, max)): How much to jitter contrast.
            contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
1168
            or the given [min, max]. Should be non-negative numbers.
yaox12's avatar
yaox12 committed
1169
1170
1171
1172
1173
1174
        saturation (float or tuple of float (min, max)): How much to jitter saturation.
            saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
            or the given [min, max]. Should be non negative numbers.
        hue (float or tuple of float (min, max)): How much to jitter hue.
            hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
            Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
1175
1176
1177
            To jitter hue, the pixel values of the input image has to be non-negative for conversion to HSV space;
            thus it does not work if you normalize your image to an interval with negative values,
            or use an interpolation that generates negative values before using this function.
1178
    """
1179

1180
1181
1182
1183
1184
1185
1186
    def __init__(
        self,
        brightness: Union[float, Tuple[float, float]] = 0,
        contrast: Union[float, Tuple[float, float]] = 0,
        saturation: Union[float, Tuple[float, float]] = 0,
        hue: Union[float, Tuple[float, float]] = 0,
    ) -> None:
1187
        super().__init__()
1188
        _log_api_usage_once(self)
1189
1190
1191
1192
        self.brightness = self._check_input(brightness, "brightness")
        self.contrast = self._check_input(contrast, "contrast")
        self.saturation = self._check_input(saturation, "saturation")
        self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)
yaox12's avatar
yaox12 committed
1193

    @torch.jit.unused
    def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError(f"If {name} is a single number, it must be non negative.")
            value = [center - float(value), center + float(value)]
            if clip_first_on_zero:
                value[0] = max(value[0], 0.0)
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            value = [float(value[0]), float(value[1])]
        else:
            raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")

        if not bound[0] <= value[0] <= value[1] <= bound[1]:
            raise ValueError(f"{name} values should be between {bound}, but got {value}.")

        # if value is 0 or (1., 1.) for brightness/contrast/saturation
        # or (0., 0.) for hue, do nothing
        if value[0] == value[1] == center:
            return None
        else:
            return tuple(value)

    @staticmethod
    def get_params(
        brightness: Optional[List[float]],
        contrast: Optional[List[float]],
        saturation: Optional[List[float]],
        hue: Optional[List[float]],
    ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
        """Get the parameters for the randomized transform to be applied on image.

        Args:
            brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
                uniformly. Pass None to turn off the transformation.
            contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
                uniformly. Pass None to turn off the transformation.
            saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
                uniformly. Pass None to turn off the transformation.
            hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
                Pass None to turn off the transformation.

        Returns:
            tuple: The parameters used to apply the randomized transform
            along with their random order.
        """
        fn_idx = torch.randperm(4)

        b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
        c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
        s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
        h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))

        return fn_idx, b, c, s, h

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Input image.

        Returns:
            PIL Image or Tensor: Color jittered image.
        """
        fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
            self.brightness, self.contrast, self.saturation, self.hue
        )

        for fn_id in fn_idx:
            if fn_id == 0 and brightness_factor is not None:
                img = F.adjust_brightness(img, brightness_factor)
            elif fn_id == 1 and contrast_factor is not None:
                img = F.adjust_contrast(img, contrast_factor)
            elif fn_id == 2 and saturation_factor is not None:
                img = F.adjust_saturation(img, saturation_factor)
            elif fn_id == 3 and hue_factor is not None:
                img = F.adjust_hue(img, hue_factor)

        return img

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"brightness={self.brightness}"
            f", contrast={self.contrast}"
            f", saturation={self.saturation}"
            f", hue={self.hue})"
        )
        return s


class RandomRotation(torch.nn.Module):
    """Rotate the image by angle.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        degrees (sequence or number): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees).
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        expand (bool, optional): Optional expansion flag.
            If true, expands the output to make it large enough to hold the entire rotated image.
            If false or omitted, make the output image the same size as the input image.
            Note that the expand flag assumes rotation around the center and no translation.
        center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
            Default is the center of the image.
        fill (sequence or number): Pixel fill value for the area outside the rotated
            image. Default is ``0``. If given a number, the value is used for all bands respectively.
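
    Example (an illustrative sketch; ``img`` is assumed to be an already-loaded PIL Image or tensor):
        >>> rotate = transforms.RandomRotation(degrees=(-30, 30), expand=True)
        >>> rotated = rotate(img)  # a new angle is drawn uniformly from [-30, 30] on every call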

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """

    def __init__(self, degrees, interpolation=InterpolationMode.NEAREST, expand=False, center=None, fill=0):
        super().__init__()
        _log_api_usage_once(self)

        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))

        if center is not None:
            _check_sequence_input(center, "center", req_sizes=(2,))

        self.center = center

        self.interpolation = interpolation
        self.expand = expand

        if fill is None:
            fill = 0
        elif not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")

        self.fill = fill

    @staticmethod
    def get_params(degrees: List[float]) -> float:
        """Get parameters for ``rotate`` for a random rotation.

        Returns:
            float: angle parameter to be passed to ``rotate`` for random rotation.
        """
        angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
        return angle

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be rotated.

        Returns:
            PIL Image or Tensor: Rotated image.
        """
        fill = self.fill
        channels, _, _ = F.get_dimensions(img)
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            else:
                fill = [float(f) for f in fill]
        angle = self.get_params(self.degrees)

        return F.rotate(img, angle, self.interpolation, self.expand, self.center, fill)

    def __repr__(self) -> str:
        interpolate_str = self.interpolation.value
        format_string = self.__class__.__name__ + f"(degrees={self.degrees}"
        format_string += f", interpolation={interpolate_str}"
        format_string += f", expand={self.expand}"
        if self.center is not None:
            format_string += f", center={self.center}"
        if self.fill is not None:
            format_string += f", fill={self.fill}"
        format_string += ")"
        return format_string


class RandomAffine(torch.nn.Module):
    """Random affine transformation of the image keeping center invariant.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        degrees (sequence or number): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees). Set to 0 to deactivate rotations.
        translate (tuple, optional): tuple of maximum absolute fraction for horizontal
            and vertical translations. For example translate=(a, b), then horizontal shift
            is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
            randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
        scale (tuple, optional): scaling factor interval, e.g. (a, b), then scale is
            randomly sampled from the range a <= scale <= b. Will keep original scale by default.
        shear (sequence or number, optional): Range of degrees to select from.
            If shear is a number, a shear parallel to the x-axis in the range (-shear, +shear)
            will be applied. Else if shear is a sequence of 2 values a shear parallel to the x-axis in the
            range (shear[0], shear[1]) will be applied. Else if shear is a sequence of 4 values,
            an x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
            Will not apply shear by default.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number): Pixel fill value for the area outside the transformed
            image. Default is ``0``. If given a number, the value is used for all bands respectively.
        center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
            Default is the center of the image.
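
    Example (an illustrative sketch; ``img`` is assumed to be an already-loaded PIL Image or tensor):
        >>> affine = transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=10)
        >>> transformed = affine(img)  # angle, shift, scale and shear are re-sampled on every call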

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """

    def __init__(
        self,
        degrees,
        translate=None,
        scale=None,
        shear=None,
        interpolation=InterpolationMode.NEAREST,
        fill=0,
        center=None,
    ):
        super().__init__()
        _log_api_usage_once(self)

        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))

        if translate is not None:
            _check_sequence_input(translate, "translate", req_sizes=(2,))
            for t in translate:
                if not (0.0 <= t <= 1.0):
                    raise ValueError("translation values should be between 0 and 1")
        self.translate = translate

        if scale is not None:
            _check_sequence_input(scale, "scale", req_sizes=(2,))
            for s in scale:
                if s <= 0:
                    raise ValueError("scale values should be positive")
        self.scale = scale

        if shear is not None:
            self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4))
        else:
            self.shear = shear

        self.interpolation = interpolation

        if fill is None:
            fill = 0
        elif not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")

        self.fill = fill

        if center is not None:
            _check_sequence_input(center, "center", req_sizes=(2,))

        self.center = center

    @staticmethod
    def get_params(
        degrees: List[float],
        translate: Optional[List[float]],
        scale_ranges: Optional[List[float]],
        shears: Optional[List[float]],
        img_size: List[int],
    ) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]:
        """Get parameters for affine transformation

        Returns:
            params to be passed to the affine transformation
        """
        angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
        if translate is not None:
            max_dx = float(translate[0] * img_size[0])
            max_dy = float(translate[1] * img_size[1])
            tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
            ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
            translations = (tx, ty)
        else:
            translations = (0, 0)

        if scale_ranges is not None:
            scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item())
        else:
            scale = 1.0

        shear_x = shear_y = 0.0
        if shears is not None:
            shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item())
            if len(shears) == 4:
                shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item())

        shear = (shear_x, shear_y)

        return angle, translations, scale, shear

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Affine transformed image.
        """
        fill = self.fill
        channels, height, width = F.get_dimensions(img)
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * channels
            else:
                fill = [float(f) for f in fill]

        img_size = [width, height]  # flip for keeping BC on get_params call

        ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)

        return F.affine(img, *ret, interpolation=self.interpolation, fill=fill, center=self.center)

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}(degrees={self.degrees}"
        s += f", translate={self.translate}" if self.translate is not None else ""
        s += f", scale={self.scale}" if self.scale is not None else ""
        s += f", shear={self.shear}" if self.shear is not None else ""
        s += f", interpolation={self.interpolation.value}" if self.interpolation != InterpolationMode.NEAREST else ""
        s += f", fill={self.fill}" if self.fill != 0 else ""
        s += f", center={self.center}" if self.center is not None else ""
        s += ")"

        return s


class Grayscale(torch.nn.Module):
    """Convert image to grayscale.
    If the image is torch Tensor, it is expected
    to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        num_output_channels (int): (1 or 3) number of channels desired for output image

    Returns:
        PIL Image or Tensor: Grayscale version of the input.

        - If ``num_output_channels == 1`` : returned image is single channel
        - If ``num_output_channels == 3`` : returned image is 3 channel with r == g == b
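
    Example (an illustrative sketch; ``img`` is assumed to be an already-loaded RGB PIL Image or tensor):
        >>> to_gray = transforms.Grayscale(num_output_channels=3)
        >>> gray_img = to_gray(img)  # 3-channel output with r == g == b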

    """

    def __init__(self, num_output_channels=1):
        super().__init__()
        _log_api_usage_once(self)
        self.num_output_channels = num_output_channels

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be converted to grayscale.

        Returns:
            PIL Image or Tensor: Grayscaled image.
        """
        return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(num_output_channels={self.num_output_channels})"


class RandomGrayscale(torch.nn.Module):
    """Randomly convert image to grayscale with a probability of p (default 0.1).
    If the image is torch Tensor, it is expected
    to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        p (float): probability that image should be converted to grayscale.

    Returns:
        PIL Image or Tensor: Grayscale version of the input image with probability p and unchanged
        with probability (1-p).
        - If input image is 1 channel: grayscale version is 1 channel
        - If input image is 3 channel: grayscale version is 3 channel with r == g == b

    """

    def __init__(self, p=0.1):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be converted to grayscale.

        Returns:
            PIL Image or Tensor: Randomly grayscaled image.
        """
        num_output_channels, _, _ = F.get_dimensions(img)
        if torch.rand(1) < self.p:
            return F.rgb_to_grayscale(img, num_output_channels=num_output_channels)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"


class RandomErasing(torch.nn.Module):
    """Randomly selects a rectangle region in a torch.Tensor image and erases its pixels.
    This transform does not support PIL Image.
    'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896

    Args:
         p: probability that the random erasing operation will be performed.
         scale: range of proportion of erased area against input image.
         ratio: range of aspect ratio of erased area.
         value: erasing value. Default is 0. If a single int, it is used to
            erase all pixels. If a tuple of length 3, it is used to erase
            R, G, B channels respectively.
            If a str of 'random', erasing each pixel with random values.
         inplace: boolean to make this transform inplace. Default set to False.

    Returns:
        Erased Image.

    Example:
        >>> transform = transforms.Compose([
        >>>   transforms.RandomHorizontalFlip(),
        >>>   transforms.PILToTensor(),
        >>>   transforms.ConvertImageDtype(torch.float),
        >>>   transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>>   transforms.RandomErasing(),
        >>> ])
    """

    def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
        super().__init__()
        _log_api_usage_once(self)
        if not isinstance(value, (numbers.Number, str, tuple, list)):
            raise TypeError("Argument value should be either a number or str or a sequence")
        if isinstance(value, str) and value != "random":
            raise ValueError("If value is str, it should be 'random'")
        if not isinstance(scale, (tuple, list)):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, (tuple, list)):
            raise TypeError("Ratio should be a sequence")
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")
        if scale[0] < 0 or scale[1] > 1:
            raise ValueError("Scale should be between 0 and 1")
        if p < 0 or p > 1:
            raise ValueError("Random erasing probability should be between 0 and 1")

        self.p = p
        self.scale = scale
        self.ratio = ratio
        self.value = value
        self.inplace = inplace

    @staticmethod
    def get_params(
        img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None
    ) -> Tuple[int, int, int, int, Tensor]:
        """Get parameters for ``erase`` for a random erasing.

        Args:
            img (Tensor): Tensor image to be erased.
            scale (sequence): range of proportion of erased area against input image.
            ratio (sequence): range of aspect ratio of erased area.
            value (list, optional): erasing value. If None, it is interpreted as "random"
                (erasing each pixel with random values). If ``len(value)`` is 1, it is interpreted as a number,
                i.e. ``value[0]``.

        Returns:
            tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
        """
        img_c, img_h, img_w = img.shape[-3], img.shape[-2], img.shape[-1]
        area = img_h * img_w

        log_ratio = torch.log(torch.tensor(ratio))
        for _ in range(10):
            erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
            aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()

            h = int(round(math.sqrt(erase_area * aspect_ratio)))
            w = int(round(math.sqrt(erase_area / aspect_ratio)))
            if not (h < img_h and w < img_w):
                continue

            if value is None:
                v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
            else:
                v = torch.tensor(value)[:, None, None]

            i = torch.randint(0, img_h - h + 1, size=(1,)).item()
            j = torch.randint(0, img_w - w + 1, size=(1,)).item()
            return i, j, h, w, v

        # Return original image
        return 0, 0, img_h, img_w, img

    def forward(self, img):
        """
        Args:
            img (Tensor): Tensor image to be erased.

        Returns:
            img (Tensor): Erased Tensor image.
        """
        if torch.rand(1) < self.p:

            # cast self.value to script acceptable type
            if isinstance(self.value, (int, float)):
                value = [float(self.value)]
            elif isinstance(self.value, str):
                value = None
            elif isinstance(self.value, (list, tuple)):
                value = [float(v) for v in self.value]
            else:
                value = self.value

            if value is not None and not (len(value) in (1, img.shape[-3])):
                raise ValueError(
                    "If value is a sequence, it should have either a single value or "
                    f"{img.shape[-3]} (number of input channels)"
                )

            x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value)
            return F.erase(img, x, y, h, w, v, self.inplace)
        return img

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}"
            f"(p={self.p}, "
            f"scale={self.scale}, "
            f"ratio={self.ratio}, "
            f"value={self.value}, "
            f"inplace={self.inplace})"
        )
        return s


class GaussianBlur(torch.nn.Module):
    """Blurs image with randomly chosen Gaussian blur.
    If the image is torch Tensor, it is expected
    to have [..., C, H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        kernel_size (int or sequence): Size of the Gaussian kernel.
        sigma (float or tuple of float (min, max)): Standard deviation to be used for
            creating kernel to perform blurring. If float, sigma is fixed. If it is tuple
            of float (min, max), sigma is chosen uniformly at random to lie in the
            given range.

    Returns:
        PIL Image or Tensor: Gaussian blurred version of the input image.
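
    Example (an illustrative sketch; ``img`` is assumed to be an already-loaded PIL Image or tensor):
        >>> blur = transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))
        >>> blurred = blur(img)  # sigma is drawn uniformly from [0.1, 2.0] on every call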

    """

    def __init__(self, kernel_size, sigma=(0.1, 2.0)):
        super().__init__()
        _log_api_usage_once(self)
        self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
        for ks in self.kernel_size:
            if ks <= 0 or ks % 2 == 0:
                raise ValueError("Kernel size value should be an odd and positive number.")

        if isinstance(sigma, numbers.Number):
            if sigma <= 0:
                raise ValueError("If sigma is a single number, it must be positive.")
            sigma = (sigma, sigma)
        elif isinstance(sigma, Sequence) and len(sigma) == 2:
            if not 0.0 < sigma[0] <= sigma[1]:
                raise ValueError("sigma values should be positive and of the form (min, max).")
        else:
            raise ValueError("sigma should be a single number or a list/tuple with length 2.")

        self.sigma = sigma

    @staticmethod
    def get_params(sigma_min: float, sigma_max: float) -> float:
        """Choose sigma for random gaussian blurring.

        Args:
            sigma_min (float): Minimum standard deviation that can be chosen for blurring kernel.
            sigma_max (float): Maximum standard deviation that can be chosen for blurring kernel.

        Returns:
            float: Standard deviation to be passed to calculate kernel for gaussian blurring.
        """
        return torch.empty(1).uniform_(sigma_min, sigma_max).item()

    def forward(self, img: Tensor) -> Tensor:
        """
        Args:
            img (PIL Image or Tensor): image to be blurred.

        Returns:
            PIL Image or Tensor: Gaussian blurred image
        """
        sigma = self.get_params(self.sigma[0], self.sigma[1])
        return F.gaussian_blur(img, self.kernel_size, [sigma, sigma])

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}(kernel_size={self.kernel_size}, sigma={self.sigma})"
        return s


def _setup_size(size, error_msg):
    if isinstance(size, numbers.Number):
        return int(size), int(size)

    if isinstance(size, Sequence) and len(size) == 1:
        return size[0], size[0]

    if len(size) != 2:
        raise ValueError(error_msg)

    return size


def _check_sequence_input(x, name, req_sizes):
    msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join([str(s) for s in req_sizes])
    if not isinstance(x, Sequence):
        raise TypeError(f"{name} should be a sequence of length {msg}.")
    if len(x) not in req_sizes:
        raise ValueError(f"{name} should be a sequence of length {msg}.")


def _setup_angle(x, name, req_sizes=(2,)):
    if isinstance(x, numbers.Number):
        if x < 0:
            raise ValueError(f"If {name} is a single number, it must be positive.")
        x = [-x, x]
    else:
        _check_sequence_input(x, name, req_sizes)

    return [float(d) for d in x]


class RandomInvert(torch.nn.Module):
    """Inverts the colors of the given image randomly with a given probability.
    If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
    where ... means it can have an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        p (float): probability of the image being color inverted. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be inverted.

        Returns:
            PIL Image or Tensor: Randomly color inverted image.
        """
        if torch.rand(1).item() < self.p:
            return F.invert(img)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"


class RandomPosterize(torch.nn.Module):
    """Posterize the image randomly with a given probability by reducing the
    number of bits for each color channel. If the image is torch Tensor, it should be of type torch.uint8,
    and it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        bits (int): number of bits to keep for each channel (0-8)
        p (float): probability of the image being posterized. Default value is 0.5
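
    Example (an illustrative sketch; ``img`` is assumed to be an already-loaded ``uint8`` tensor or PIL Image):
        >>> posterize = transforms.RandomPosterize(bits=2, p=1.0)
        >>> posterized = posterize(img)  # keeps only the 2 most significant bits of each channel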
    """

    def __init__(self, bits, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.bits = bits
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be posterized.

        Returns:
            PIL Image or Tensor: Randomly posterized image.
        """
        if torch.rand(1).item() < self.p:
            return F.posterize(img, self.bits)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(bits={self.bits},p={self.p})"


class RandomSolarize(torch.nn.Module):
    """Solarize the image randomly with a given probability by inverting all pixel
    values above a threshold. If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
    where ... means it can have an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        threshold (float): all pixels equal or above this value are inverted.
        p (float): probability of the image being solarized. Default value is 0.5
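
    Example (an illustrative sketch; ``img`` is assumed to be an already-loaded ``uint8`` tensor or PIL Image):
        >>> solarize = transforms.RandomSolarize(threshold=128, p=0.5)
        >>> solarized = solarize(img)  # pixels >= 128 are inverted with probability 0.5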
    """

    def __init__(self, threshold, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.threshold = threshold
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be solarized.

        Returns:
            PIL Image or Tensor: Randomly solarized image.
        """
        if torch.rand(1).item() < self.p:
            return F.solarize(img, self.threshold)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(threshold={self.threshold},p={self.p})"


class RandomAdjustSharpness(torch.nn.Module):
    """Adjust the sharpness of the image randomly with a given probability. If the image is torch Tensor,
    it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        sharpness_factor (float):  How much to adjust the sharpness. Can be
            any non-negative number. 0 gives a blurred image, 1 gives the
            original image while 2 increases the sharpness by a factor of 2.
        p (float): probability of the image being sharpened. Default value is 0.5
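
    Example (an illustrative sketch; ``img`` is assumed to be an already-loaded PIL Image or tensor):
        >>> sharpen = transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.5)
        >>> sharpened = sharpen(img)  # sharpness is doubled with probability 0.5, otherwise unchanged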
    """

    def __init__(self, sharpness_factor, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.sharpness_factor = sharpness_factor
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be sharpened.

        Returns:
            PIL Image or Tensor: Randomly sharpened image.
        """
        if torch.rand(1).item() < self.p:
            return F.adjust_sharpness(img, self.sharpness_factor)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(sharpness_factor={self.sharpness_factor},p={self.p})"


class RandomAutocontrast(torch.nn.Module):
    """Autocontrast the pixels of the given image randomly with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        p (float): probability of the image being autocontrasted. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be autocontrasted.

        Returns:
            PIL Image or Tensor: Randomly autocontrasted image.
        """
        if torch.rand(1).item() < self.p:
            return F.autocontrast(img)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"


class RandomEqualize(torch.nn.Module):
    """Equalize the histogram of the given image randomly with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".

    Args:
        p (float): probability of the image being equalized. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        _log_api_usage_once(self)
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be equalized.

        Returns:
            PIL Image or Tensor: Randomly equalized image.
        """
        if torch.rand(1).item() < self.p:
            return F.equalize(img)
        return img

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(p={self.p})"


class ElasticTransform(torch.nn.Module):
    """Transform a tensor image with elastic transformations.
    Given alpha and sigma, it will generate displacement
    vectors for all pixels based on random offsets. Alpha controls the strength
    and sigma controls the smoothness of the displacements.
    The displacements are added to an identity grid and the resulting grid is
    used to grid_sample from the image.

    Applications:
        Randomly transforms the morphology of objects in images and produces a
        see-through-water-like effect.

    Args:
        alpha (float or sequence of floats): Magnitude of displacements. Default is 50.0.
        sigma (float or sequence of floats): Smoothness of displacements. Default is 5.0.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
        fill (sequence or number): Pixel fill value for the area outside the transformed
            image. Default is ``0``. If given a number, the value is used for all bands respectively.
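
    Example (an illustrative sketch; ``img`` is assumed to be an already-loaded PIL Image or tensor):
        >>> elastic = transforms.ElasticTransform(alpha=50.0, sigma=5.0)
        >>> distorted = elastic(img)  # a new random displacement field is drawn on every call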

    """

    def __init__(self, alpha=50.0, sigma=5.0, interpolation=InterpolationMode.BILINEAR, fill=0):
        super().__init__()
        _log_api_usage_once(self)
        if not isinstance(alpha, (float, Sequence)):
            raise TypeError(f"alpha should be float or a sequence of floats. Got {type(alpha)}")
        if isinstance(alpha, Sequence) and len(alpha) != 2:
            raise ValueError(f"If alpha is a sequence its length should be 2. Got {len(alpha)}")
        if isinstance(alpha, Sequence):
            for element in alpha:
                if not isinstance(element, float):
                    raise TypeError(f"alpha should be a sequence of floats. Got {type(element)}")

        if isinstance(alpha, float):
            alpha = [float(alpha), float(alpha)]
        if isinstance(alpha, (list, tuple)) and len(alpha) == 1:
            alpha = [alpha[0], alpha[0]]

        self.alpha = alpha

        if not isinstance(sigma, (float, Sequence)):
            raise TypeError(f"sigma should be float or a sequence of floats. Got {type(sigma)}")
        if isinstance(sigma, Sequence) and len(sigma) != 2:
            raise ValueError(f"If sigma is a sequence its length should be 2. Got {len(sigma)}")
        if isinstance(sigma, Sequence):
            for element in sigma:
                if not isinstance(element, float):
                    raise TypeError(f"sigma should be a sequence of floats. Got {type(element)}")

        if isinstance(sigma, float):
            sigma = [float(sigma), float(sigma)]
        if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
            sigma = [sigma[0], sigma[0]]

        self.sigma = sigma

        # Backward compatibility with integer value
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument interpolation should be of type InterpolationMode instead of int. "
                "Please, use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)
        self.interpolation = interpolation

        if isinstance(fill, (int, float)):
            fill = [float(fill)]
        elif isinstance(fill, (list, tuple)):
            fill = [float(f) for f in fill]
        else:
            raise TypeError(f"fill should be int or float or a list or tuple of them. Got {type(fill)}")
        self.fill = fill

    @staticmethod
    def get_params(alpha: List[float], sigma: List[float], size: List[int]) -> Tensor:
        dx = torch.rand([1, 1] + size) * 2 - 1
        if sigma[0] > 0.0:
            kx = int(8 * sigma[0] + 1)
            # if kernel size is even we have to make it odd
            if kx % 2 == 0:
                kx += 1
            dx = F.gaussian_blur(dx, [kx, kx], sigma)
        dx = dx * alpha[0] / size[0]

        dy = torch.rand([1, 1] + size) * 2 - 1
        if sigma[1] > 0.0:
            ky = int(8 * sigma[1] + 1)
            # if kernel size is even we have to make it odd
            if ky % 2 == 0:
                ky += 1
            dy = F.gaussian_blur(dy, [ky, ky], sigma)
        dy = dy * alpha[1] / size[1]
        return torch.concat([dx, dy], 1).permute([0, 2, 3, 1])  # 1 x H x W x 2

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Args:
            tensor (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        _, height, width = F.get_dimensions(tensor)
        displacement = self.get_params(self.alpha, self.sigma, [height, width])
        return F.elastic_transform(tensor, displacement, self.interpolation, self.fill)

    def __repr__(self):
        format_string = self.__class__.__name__
        format_string += f"(alpha={self.alpha}"
        format_string += f", sigma={self.sigma}"
        format_string += f", interpolation={self.interpolation}"
        format_string += f", fill={self.fill})"
        return format_string