transforms.py 77.6 KB
Newer Older
1
import math
vfdev's avatar
vfdev committed
2
import numbers
3
import random
vfdev's avatar
vfdev committed
4
import warnings
vfdev's avatar
vfdev committed
5
from collections.abc import Sequence
6
from typing import Tuple, List, Optional
vfdev's avatar
vfdev committed
7
8
9
10

import torch
from torch import Tensor

11
12
13
14
15
16
try:
    import accimage
except ImportError:
    accimage = None

from . import functional as F
17
from .functional import InterpolationMode, _interpolation_modes_from_int
18

19

20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# Public names exported by ``from <this module> import *``.
# ``RandomTransforms`` (the shared base class of RandomOrder/RandomChoice) is not listed here;
# NOTE(review): presumably intentional because it is an internal base class — confirm.
# ``InterpolationMode`` is re-exported from ``.functional``.
__all__ = [
    "Compose",
    "ToTensor",
    "PILToTensor",
    "ConvertImageDtype",
    "ToPILImage",
    "Normalize",
    "Resize",
    "Scale",
    "CenterCrop",
    "Pad",
    "Lambda",
    "RandomApply",
    "RandomChoice",
    "RandomOrder",
    "RandomCrop",
    "RandomHorizontalFlip",
    "RandomVerticalFlip",
    "RandomResizedCrop",
    "RandomSizedCrop",
    "FiveCrop",
    "TenCrop",
    "LinearTransformation",
    "ColorJitter",
    "RandomRotation",
    "RandomAffine",
    "Grayscale",
    "RandomGrayscale",
    "RandomPerspective",
    "RandomErasing",
    "GaussianBlur",
    "InterpolationMode",
    "RandomInvert",
    "RandomPosterize",
    "RandomSolarize",
    "RandomAdjustSharpness",
    "RandomAutocontrast",
    "RandomEqualize",
]
59

60

61
class Compose:
    """Chain several transforms so that each one is applied to the output of the previous one.

    This transform does not support torchscript; see the note below.

    Args:
        transforms (list of ``Transform`` objects): the transforms to apply, in order.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>> ])

    .. note::
        To script the transformations, use ``torch.nn.Sequential`` instead:

        >>> transforms = torch.nn.Sequential(
        >>>     transforms.CenterCrop(10),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>> )
        >>> scripted_transforms = torch.jit.script(transforms)

        Only scriptable transformations can be used there, i.e. ones that work with ``torch.Tensor``
        and need neither `lambda` functions nor ``PIL.Image``.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        # Thread the image through every transform in order.
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result

    def __repr__(self):
        # Header line, one indented line per transform, closing parenthesis.
        lines = [self.__class__.__name__ + "("]
        lines.extend(f"    {transform}" for transform in self.transforms)
        lines.append(")")
        return "\n".join(lines)

105

106
class ToTensor:
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript.

    A PIL Image or numpy.ndarray laid out as (H x W x C) with values in the range
    [0, 255] becomes a torch.FloatTensor of shape (C x H x W) with values in [0.0, 1.0]
    when the PIL Image has one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
    or when the numpy.ndarray has dtype = np.uint8.

    For any other input the tensor is returned without scaling.

    .. note::
        Because the input image is rescaled to [0.0, 1.0], do not use this transformation on
        target image masks. See the `references`_ for how to transform image masks.

    .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        converted = F.to_tensor(pic)
        return converted

    def __repr__(self):
        return f"{self.__class__.__name__}()"
135

136

137
class PILToTensor:
    """Convert a ``PIL Image`` to a tensor of the same type. This transform does not support torchscript.

    A PIL Image laid out as (H x W x C) becomes a Tensor of shape (C x H x W).
    """

    def __call__(self, pic):
        """
        .. note::

            A deep copy of the underlying array is performed.

        Args:
            pic (PIL Image): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        converted = F.pil_to_tensor(pic)
        return converted

    def __repr__(self):
        return f"{self.__class__.__name__}()"
159
160


161
class ConvertImageDtype(torch.nn.Module):
    """Convert a tensor image to the given ``dtype`` and scale its values accordingly.
    This transform does not support PIL Image.

    Args:
        dtype (torch.dtype): Desired data type of the output

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """

    def __init__(self, dtype: torch.dtype) -> None:
        super().__init__()
        self.dtype = dtype

    def forward(self, image):
        # All of the conversion and rescaling work is delegated to the functional API.
        target_dtype = self.dtype
        return F.convert_image_dtype(image, target_dtype)


188
class ToPILImage:
    """Convert a tensor or an ndarray to PIL Image. This transform does not support torchscript.

    A torch.*Tensor of shape C x H x W or a numpy ndarray of shape H x W x C
    is turned into a PIL Image while keeping its value range.

    Args:
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
            If ``mode`` is ``None`` (default) there are some assumptions made about the input data:

            - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
            - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
            - If the input has 2 channels, the ``mode`` is assumed to be ``LA``.
            - If the input has 1 channel, the ``mode`` is determined by the data type (i.e ``int``, ``float``,
              ``short``).

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
    """

    def __init__(self, mode=None):
        self.mode = mode

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.

        Returns:
            PIL Image: Image converted to PIL Image.

        """
        return F.to_pil_image(pic, self.mode)

    def __repr__(self):
        # Only mention the mode when one was explicitly given.
        mode_part = "" if self.mode is None else f"mode={self.mode}"
        return f"{self.__class__.__name__}({mode_part})"
226

227

228
class Normalize(torch.nn.Module):
    """Normalize a tensor image with mean and standard deviation.
    This transform does not support PIL Image.

    Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
    channels, each channel of the input ``torch.*Tensor`` is normalized as
    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``

    .. note::
        Unless ``inplace=True``, this transform acts out of place, i.e., it does not mutate the input tensor.

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace(bool,optional): Bool to make this operation in-place.

    """

    def __init__(self, mean, std, inplace=False):
        super().__init__()
        self.mean, self.std = mean, std
        self.inplace = inplace

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Args:
            tensor (Tensor): Tensor image to be normalized.

        Returns:
            Tensor: Normalized Tensor image.
        """
        normalized = F.normalize(tensor, self.mean, self.std, self.inplace)
        return normalized

    def __repr__(self):
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"
264

265

class Resize(torch.nn.Module):
    """Resize the input image to the given size.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    .. warning::
        The output image might be different depending on its type: when downsampling, the interpolation of PIL images
        and tensors is slightly different, because PIL applies antialiasing. This may lead to significant differences
        in the performance of a network. Therefore, it is preferable to train and serve a model with the same input
        types. See also below the ``antialias`` parameter, which can help making the output of PIL images and tensors
        closer.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size).

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and
            ``InterpolationMode.BICUBIC`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
        max_size (int, optional): The maximum allowed for the longer edge of
            the resized image: if the longer edge of the image is greater
            than ``max_size`` after being resized according to ``size``, then
            the image is resized again so that the longer edge is equal to
            ``max_size``. As a result, ``size`` might be overruled, i.e the
            smaller edge may be shorter than ``size``. This is only supported
            if ``size`` is an int (or a sequence of length 1 in torchscript
            mode).
        antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias
            is always used. If ``img`` is Tensor, the flag is False by default and can be set to True for
            ``InterpolationMode.BILINEAR`` only mode. This can help making the output for PIL images and tensors
            closer.

            .. warning::
                There is no autodiff support for ``antialias=True`` option with input ``img`` as Tensor.

    """

    def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias=None):
        super().__init__()
        # Validate ``size`` before storing anything.
        if not isinstance(size, (int, Sequence)):
            raise TypeError(f"Size should be int or sequence. Got {type(size)}")
        if isinstance(size, Sequence) and len(size) not in (1, 2):
            raise ValueError("If size is a sequence, it should have 1 or 2 values")

        # Backward compatibility with integer value
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument interpolation should be of type InterpolationMode instead of int. "
                "Please, use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)

        self.size = size
        self.max_size = max_size
        self.interpolation = interpolation
        self.antialias = antialias

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be scaled.

        Returns:
            PIL Image or Tensor: Rescaled image.
        """
        return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias)

    def __repr__(self):
        fields = (
            f"size={self.size}",
            f"interpolation={self.interpolation.value}",
            f"max_size={self.max_size}",
            f"antialias={self.antialias}",
        )
        return f"{self.__class__.__name__}({', '.join(fields)})"
343

344
345
346
347
348

class Scale(Resize):
    """
    Deprecated alias of :class:`Resize`; emits a warning on construction and otherwise behaves like Resize.
    """

    def __init__(self, *args, **kwargs):
        message = "The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead."
        warnings.warn(message)
        # All argument handling is left to Resize.
        super().__init__(*args, **kwargs)
353
354


class CenterCrop(torch.nn.Module):
    """Crop the given image at its center.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
    If the image is smaller than the output size along any edge, it is padded with 0 and then center cropped.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
    """

    def __init__(self, size):
        super().__init__()
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            PIL Image or Tensor: Cropped image.
        """
        cropped = F.center_crop(img, self.size)
        return cropped

    def __repr__(self):
        return f"{self.__class__.__name__}(size={self.size})"
383

384

385
386
class Pad(torch.nn.Module):
    """Pad the given image on all sides with the given "pad" value.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means at most 2 leading dimensions for mode reflect and symmetric,
    at most 3 leading dimensions for mode edge,
    and an arbitrary number of leading dimensions for mode constant

    Args:
        padding (int or sequence): Padding on each border. If a single int is provided this
            is used to pad all borders. If sequence of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a sequence of length 4 is provided
            this is the padding for the left, top, right and bottom borders respectively.

            .. note::
                In torchscript mode padding as single int is not supported, use a sequence of
                length 1: ``[padding, ]``.
        fill (number or str or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
            Only number is supported for torch Tensor.
            Only int or str or tuple value is supported for PIL Image.
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image.
              If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2

            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]
    """

    def __init__(self, padding, fill=0, padding_mode="constant"):
        super().__init__()
        # Type checks first, then value checks, each raising on the first violation.
        if not isinstance(padding, (numbers.Number, tuple, list)):
            raise TypeError("Got inappropriate padding arg")
        if not isinstance(fill, (numbers.Number, str, tuple)):
            raise TypeError("Got inappropriate fill arg")
        if padding_mode not in ("constant", "edge", "reflect", "symmetric"):
            raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
        if isinstance(padding, Sequence) and len(padding) not in (1, 2, 4):
            raise ValueError(
                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
            )

        self.padding, self.fill, self.padding_mode = padding, fill, padding_mode

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be padded.

        Returns:
            PIL Image or Tensor: Padded image.
        """
        padded = F.pad(img, self.padding, self.fill, self.padding_mode)
        return padded

    def __repr__(self):
        return self.__class__.__name__ + f"(padding={self.padding}, fill={self.fill}, padding_mode={self.padding_mode})"
455

456

457
class Lambda:
    """Wrap a user-supplied callable so it can be used as a transform. This transform does not support torchscript.

    Args:
        lambd (function): Lambda/function to be used for transform.
    """

    def __init__(self, lambd):
        if callable(lambd):
            self.lambd = lambd
        else:
            raise TypeError(f"Argument lambd should be callable, got {repr(type(lambd).__name__)}")

    def __call__(self, img):
        return self.lambd(img)

    def __repr__(self):
        return f"{self.__class__.__name__}()"
474

475

476
class RandomTransforms:
    """Base class for transforms that apply a list of transformations with some randomness.

    Subclasses must implement ``__call__``.

    Args:
        transforms (sequence): list of transformations
    """

    def __init__(self, transforms):
        if isinstance(transforms, Sequence):
            self.transforms = transforms
        else:
            raise TypeError("Argument transforms should be a sequence")

    def __call__(self, *args, **kwargs):
        # How the stored transforms are applied is up to each subclass.
        raise NotImplementedError()

    def __repr__(self):
        lines = [self.__class__.__name__ + "("]
        lines.extend(f"    {transform}" for transform in self.transforms)
        lines.append(")")
        return "\n".join(lines)


500
class RandomApply(torch.nn.Module):
    """Apply a list of transformations as a whole, with a given probability.

    .. note::
        To script the transformation, pass a ``torch.nn.ModuleList`` rather than a list/tuple of
        transforms, as shown below:

        >>> transforms = transforms.RandomApply(torch.nn.ModuleList([
        >>>     transforms.ColorJitter(),
        >>> ]), p=0.3)
        >>> scripted_transforms = torch.jit.script(transforms)

        Only scriptable transformations can be used there, i.e. ones that work with ``torch.Tensor``
        and need neither `lambda` functions nor ``PIL.Image``.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
        p (float): probability
    """

    def __init__(self, transforms, p=0.5):
        super().__init__()
        self.transforms = transforms
        self.p = p

    def forward(self, img):
        # One random draw decides whether the whole list is applied or skipped.
        if self.p < torch.rand(1):
            return img
        out = img
        for transform in self.transforms:
            out = transform(out)
        return out

    def __repr__(self):
        lines = [self.__class__.__name__ + "(", f"    p={self.p}"]
        lines.extend(f"    {transform}" for transform in self.transforms)
        lines.append(")")
        return "\n".join(lines)


class RandomOrder(RandomTransforms):
    """Apply every transformation of a list, in a freshly shuffled order on each call.

    This transform does not support torchscript.
    """

    def __call__(self, img):
        # Shuffle positions rather than the list itself so self.transforms is left untouched.
        indices = list(range(len(self.transforms)))
        random.shuffle(indices)
        for index in indices:
            img = self.transforms[index](img)
        return img


class RandomChoice(RandomTransforms):
    """Apply single transformation randomly picked from a list. This transform does not support torchscript.

    Args:
        transforms (sequence): list of transformations to pick from.
        p (sequence, optional): one relative weight per transformation, passed as ``weights`` to
            :func:`random.choices`. With ``None`` (default) every transformation is equally likely.

    Raises:
        TypeError: If ``p`` is given but is not a sequence.
        ValueError: If ``p`` is given but its length differs from the number of transforms.
    """

    def __init__(self, transforms, p=None):
        super().__init__(transforms)
        if p is not None and not isinstance(p, Sequence):
            raise TypeError("Argument p should be a sequence")
        # Fail fast: without this check the mismatch only surfaces as a ValueError from
        # random.choices on the first call.
        if p is not None and len(p) != len(self.transforms):
            raise ValueError(
                f"Length of p doesn't match the number of transforms: {len(p)} != {len(self.transforms)}"
            )
        self.p = p

    def __call__(self, *args):
        t = random.choices(self.transforms, weights=self.p)[0]
        return t(*args)

    def __repr__(self):
        # The weights are appended after the base-class representation (i.e. after its closing parenthesis).
        format_string = super().__repr__()
        format_string += f"(p={self.p})"
        return format_string
570
571


class RandomCrop(torch.nn.Module):
    """Crop the given image at a random location.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions,
    but if non-constant padding is used, the input is expected to have at most 2 leading dimensions

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is None. If a single int is provided this
            is used to pad all borders. If sequence of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a sequence of length 4 is provided
            this is the padding for the left, top, right and bottom borders respectively.

            .. note::
                In torchscript mode padding as single int is not supported, use a sequence of
                length 1: ``[padding, ]``.
        pad_if_needed (boolean): It will pad the image if smaller than the
            desired size to avoid raising an exception. Since cropping is done
            after padding, the padding seems to be done at a random offset.
        fill (number or str or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
            Only number is supported for torch Tensor.
            Only int or str or tuple value is supported for PIL Image.
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image.
              If input a 5D torch Tensor, the last 3 dimensions will be padded instead of the last 2

            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]
    """

    @staticmethod
    def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image or Tensor): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.

        Raises:
            ValueError: If the requested crop is larger than the image in either dimension.
        """
        w, h = F.get_image_size(img)
        th, tw = output_size

        # The crop has to fit inside the image. If it does not, the randint ranges below would be
        # empty (high <= low) and torch would raise an unhelpful RuntimeError instead.
        if h < th or w < tw:
            raise ValueError(f"Required crop size {(th, tw)} is larger then input image size {(h, w)}")

        if w == tw and h == th:
            return 0, 0, h, w

        # Top-left corner drawn uniformly from all valid positions (randint's upper bound is exclusive).
        i = torch.randint(0, h - th + 1, size=(1,)).item()
        j = torch.randint(0, w - tw + 1, size=(1,)).item()
        return i, j, th, tw

    def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
        super().__init__()

        self.size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))

        self.padding = padding
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        self.padding_mode = padding_mode

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            PIL Image or Tensor: Cropped image.
        """
        if self.padding is not None:
            img = F.pad(img, self.padding, self.fill, self.padding_mode)

        width, height = F.get_image_size(img)
        # A 2-element padding pads both sides (left/right, then top/bottom), so the padded edge grows by
        # twice the deficit; the random crop below then picks the offset inside the enlarged image.
        # pad the width if needed
        if self.pad_if_needed and width < self.size[1]:
            padding = [self.size[1] - width, 0]
            img = F.pad(img, padding, self.fill, self.padding_mode)
        # pad the height if needed
        if self.pad_if_needed and height < self.size[0]:
            padding = [0, self.size[0] - height]
            img = F.pad(img, padding, self.fill, self.padding_mode)

        i, j, h, w = self.get_params(img, self.size)

        return F.crop(img, i, j, h, w)

    def __repr__(self):
        return self.__class__.__name__ + f"(size={self.size}, padding={self.padding})"
677

678

679
680
class RandomHorizontalFlip(torch.nn.Module):
    """Flip the given image horizontally, with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be flipped.

        Returns:
            PIL Image or Tensor: Randomly flipped image.
        """
        should_flip = torch.rand(1) < self.p
        return F.hflip(img) if should_flip else img

    def __repr__(self):
        return f"{self.__class__.__name__}(p={self.p})"
707

708

709
class RandomVerticalFlip(torch.nn.Module):
    """Flip the given image vertically, with a given probability.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be flipped.

        Returns:
            PIL Image or Tensor: Randomly flipped image.
        """
        should_flip = torch.rand(1) < self.p
        return F.vflip(img) if should_flip else img

    def __repr__(self):
        return f"{self.__class__.__name__}(p={self.p})"
737

738

739
740
class RandomPerspective(torch.nn.Module):
    """Apply a random perspective transformation to the given image with probability ``p``.

    A torch Tensor input is expected to have shape ``[..., H, W]``, where ``...``
    stands for any number of leading dimensions.

    Args:
        distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.
            Default is 0.5.
        p (float): probability of the image being transformed. Default is 0.5.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
        fill (sequence or number): Pixel fill value for the area outside the transformed
            image. Default is ``0``. If given a number, the value is used for all bands respectively.

    Raises:
        TypeError: if ``fill`` is neither ``None``, a sequence, nor a number.
    """

    def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationMode.BILINEAR, fill=0):
        super().__init__()
        self.p = p

        # Integer interpolation codes are still accepted, with a warning, for backward compatibility.
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument interpolation should be of type InterpolationMode instead of int. "
                "Please, use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)

        self.interpolation = interpolation
        self.distortion_scale = distortion_scale

        # ``None`` is normalised to 0; anything else must be a number or a sequence.
        if fill is not None and not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")
        self.fill = 0 if fill is None else fill

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be perspectively transformed.

        Returns:
            PIL Image or Tensor: Randomly transformed image (or ``img`` unchanged when
            the random draw does not trigger the transform).
        """
        fill_value = self.fill
        if isinstance(img, Tensor):
            # Tensor kernels expect one float per channel.
            if isinstance(fill_value, (int, float)):
                channels = F.get_image_num_channels(img)
                fill_value = [float(fill_value)] * channels
            else:
                fill_value = [float(v) for v in fill_value]

        if torch.rand(1) < self.p:
            width, height = F.get_image_size(img)
            src, dst = self.get_params(width, height, self.distortion_scale)
            return F.perspective(img, src, dst, self.interpolation, fill_value)
        return img

    @staticmethod
    def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[List[int]], List[List[int]]]:
        """Get parameters for ``perspective`` for a random perspective transform.

        Args:
            width (int): width of the image.
            height (int): height of the image.
            distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.

        Returns:
            List containing [top-left, top-right, bottom-right, bottom-left] of the original image,
            List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image.
        """
        # Largest inward displacement allowed for each corner, per axis.
        max_dx = int(distortion_scale * (width // 2))
        max_dy = int(distortion_scale * (height // 2))

        # Corners are sampled in this exact order (x then y for each corner) so the
        # random stream is consumed identically on every call.
        topleft = [
            int(torch.randint(0, max_dx + 1, size=(1,)).item()),
            int(torch.randint(0, max_dy + 1, size=(1,)).item()),
        ]
        topright = [
            int(torch.randint(width - max_dx - 1, width, size=(1,)).item()),
            int(torch.randint(0, max_dy + 1, size=(1,)).item()),
        ]
        botright = [
            int(torch.randint(width - max_dx - 1, width, size=(1,)).item()),
            int(torch.randint(height - max_dy - 1, height, size=(1,)).item()),
        ]
        botleft = [
            int(torch.randint(0, max_dx + 1, size=(1,)).item()),
            int(torch.randint(height - max_dy - 1, height, size=(1,)).item()),
        ]

        original_corners = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
        distorted_corners = [topleft, topright, botright, botleft]
        return original_corners, distorted_corners

    def __repr__(self):
        return f"{self.__class__.__name__}(p={self.p})"
class RandomResizedCrop(torch.nn.Module):
    """Crop a random portion of an image and resize it to a given size.

    A torch Tensor input is expected to have shape ``[..., H, W]``, where ``...``
    stands for any number of leading dimensions.

    The crop is drawn with a random area (H * W) and a random aspect ratio, then
    resized to ``size``. This is popularly used to train the Inception networks.

    Args:
        size (int or sequence): expected output size of the crop, for each edge. If size is an
            int instead of sequence like (h, w), a square output size ``(size, size)`` is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        scale (tuple of float): Specifies the lower and upper bounds for the random area of the crop,
            before resizing. The scale is defined with respect to the area of the original image.
        ratio (tuple of float): lower and upper bounds for the random aspect ratio of the crop, before
            resizing.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and
            ``InterpolationMode.BICUBIC`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.

    Raises:
        TypeError: if ``scale`` or ``ratio`` is not a sequence.
    """

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0), interpolation=InterpolationMode.BILINEAR):
        super().__init__()
        size_error = "Please provide only two dimensions (h, w) for size."
        self.size = _setup_size(size, error_msg=size_error)

        if not isinstance(scale, Sequence):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, Sequence):
            raise TypeError("Ratio should be a sequence")
        # Reversed bounds are tolerated but reported.
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")

        # Integer interpolation codes are still accepted, with a warning, for backward compatibility.
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument interpolation should be of type InterpolationMode instead of int. "
                "Please, use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)

        self.interpolation = interpolation
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img: Tensor, scale: List[float], ratio: List[float]) -> Tuple[int, int, int, int]:
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image or Tensor): Input image.
            scale (list): range of scale of the origin size cropped
            ratio (list): range of aspect ratio of the origin aspect ratio cropped

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
            sized crop.
        """
        width, height = F.get_image_size(img)
        area = height * width

        # The aspect ratio is sampled uniformly in log space.
        log_ratio = torch.log(torch.tensor(ratio))
        for _ in range(10):
            target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
            aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()

            crop_w = int(round(math.sqrt(target_area * aspect_ratio)))
            crop_h = int(round(math.sqrt(target_area / aspect_ratio)))

            if 0 < crop_w <= width and 0 < crop_h <= height:
                top = torch.randint(0, height - crop_h + 1, size=(1,)).item()
                left = torch.randint(0, width - crop_w + 1, size=(1,)).item()
                return top, left, crop_h, crop_w

        # No valid sample after 10 attempts: fall back to a central crop whose aspect
        # ratio is clamped into [min(ratio), max(ratio)].
        image_ratio = float(width) / float(height)
        lowest_ratio = min(ratio)
        highest_ratio = max(ratio)
        if image_ratio < lowest_ratio:
            out_w = width
            out_h = int(round(out_w / lowest_ratio))
        elif image_ratio > highest_ratio:
            out_h = height
            out_w = int(round(out_h * highest_ratio))
        else:
            # The whole image already satisfies the ratio bounds.
            out_w = width
            out_h = height
        top = (height - out_h) // 2
        left = (width - out_w) // 2
        return top, left, out_h, out_w

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped and resized.

        Returns:
            PIL Image or Tensor: Randomly cropped and resized image.
        """
        top, left, crop_h, crop_w = self.get_params(img, self.scale, self.ratio)
        return F.resized_crop(img, top, left, crop_h, crop_w, self.size, self.interpolation)

    def __repr__(self):
        interpolate_str = self.interpolation.value
        rounded_scale = tuple(round(s, 4) for s in self.scale)
        rounded_ratio = tuple(round(r, 4) for r in self.ratio)
        return (
            f"{self.__class__.__name__}(size={self.size}"
            f", scale={rounded_scale}"
            f", ratio={rounded_ratio}"
            f", interpolation={interpolate_str})"
        )
class RandomSizedCrop(RandomResizedCrop):
    """Deprecated alias of :class:`RandomResizedCrop`.

    Note: This transform is deprecated in favor of RandomResizedCrop. All arguments
    are forwarded unchanged to :class:`RandomResizedCrop`; a warning is emitted on
    construction.
    """

    def __init__(self, *args, **kwargs):
        deprecation_message = (
            "The use of the transforms.RandomSizedCrop transform is deprecated, "
            + "please use transforms.RandomResizedCrop instead."
        )
        warnings.warn(deprecation_message)
        super().__init__(*args, **kwargs)
class FiveCrop(torch.nn.Module):
    """Crop the given image into its four corners and the central crop.

    A torch Tensor input is expected to have shape ``[..., H, W]``, where ``...``
    stands for any number of leading dimensions.

    .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
         inputs and targets your Dataset returns. See below for an example of how to deal with
         this.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an ``int``
            instead of sequence like (h, w), a square crop of size (size, size) is made.
            If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).

    Example:
         >>> transform = Compose([
         >>>    FiveCrop(size), # this is a tuple of PIL Images
         >>>    Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
         >>> ])
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """

    def __init__(self, size):
        super().__init__()
        size_error = "Please provide only two dimensions (h, w) for size."
        self.size = _setup_size(size, error_msg=size_error)

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            tuple of 5 images. Image can be PIL Image or Tensor
        """
        return F.five_crop(img, self.size)

    def __repr__(self):
        return f"{self.__class__.__name__}(size={self.size})"
class TenCrop(torch.nn.Module):
    """Crop the given image into its four corners and the central crop, plus the
    flipped version of each of these (horizontal flipping is used by default).

    A torch Tensor input is expected to have shape ``[..., H, W]``, where ``...``
    stands for any number of leading dimensions.

    .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
         inputs and targets your Dataset returns. See below for an example of how to deal with
         this.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
        vertical_flip (bool): Use vertical flipping instead of horizontal. Default is ``False``.

    Example:
         >>> transform = Compose([
         >>>    TenCrop(size), # this is a tuple of PIL Images
         >>>    Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
         >>> ])
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """

    def __init__(self, size, vertical_flip=False):
        super().__init__()
        size_error = "Please provide only two dimensions (h, w) for size."
        self.size = _setup_size(size, error_msg=size_error)
        self.vertical_flip = vertical_flip

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            tuple of 10 images. Image can be PIL Image or Tensor
        """
        return F.ten_crop(img, self.size, self.vertical_flip)

    def __repr__(self):
        name = self.__class__.__name__
        return f"{name}(size={self.size}, vertical_flip={self.vertical_flip})"
class LinearTransformation(torch.nn.Module):
    """Transform a tensor image with a square transformation matrix and a mean_vector computed
    offline.
    This transform does not support PIL Image.

    Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and
    subtract mean_vector from it which is then followed by computing the dot
    product with the transformation matrix and then reshaping the tensor to its
    original shape.

    Applications:
        whitening transformation: Suppose X is a column vector zero-centered data.
        Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X),
        perform SVD on this matrix and pass it as transformation_matrix.

    Args:
        transformation_matrix (Tensor): tensor [D x D], D = C x H x W
        mean_vector (Tensor): tensor [D], D = C x H x W

    Raises:
        ValueError: if ``transformation_matrix`` is not square, if ``mean_vector``'s
            length differs from the matrix size, or if the two tensors live on
            different devices.
    """

    def __init__(self, transformation_matrix, mean_vector):
        super().__init__()
        if transformation_matrix.size(0) != transformation_matrix.size(1):
            raise ValueError(
                "transformation_matrix should be square. Got "
                f"{tuple(transformation_matrix.size())} rectangular matrix."
            )

        if mean_vector.size(0) != transformation_matrix.size(0):
            raise ValueError(
                f"mean_vector should have the same length {mean_vector.size(0)}"
                f" as any one of the dimensions of the transformation_matrix [{tuple(transformation_matrix.size())}]"
            )

        if transformation_matrix.device != mean_vector.device:
            raise ValueError(
                f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
            )

        self.transformation_matrix = transformation_matrix
        self.mean_vector = mean_vector

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Args:
            tensor (Tensor): Tensor image to be whitened. Its last three dimensions
                (C, H, W) must multiply to D, the size of the transformation matrix.

        Returns:
            Tensor: Transformed image, with the same shape as ``tensor``.

        Raises:
            ValueError: if C x H x W differs from D, or if ``tensor`` is on a different
                device type than ``mean_vector``.
        """
        shape = tensor.shape
        n = shape[-3] * shape[-2] * shape[-1]
        if n != self.transformation_matrix.shape[0]:
            raise ValueError(
                "Input tensor and transformation matrix have incompatible shape. "
                + f"[{shape[-3]} x {shape[-2]} x {shape[-1]}] != "
                + f"{self.transformation_matrix.shape[0]}"
            )

        if tensor.device.type != self.mean_vector.device.type:
            raise ValueError(
                "Input tensor should be on the same device as transformation matrix and mean vector. "
                f"Got {tensor.device} vs {self.mean_vector.device}"
            )

        # ``reshape`` (rather than ``view``) also accepts non-contiguous inputs such as
        # transposed/permuted images; for contiguous inputs it returns the same view.
        flat_tensor = tensor.reshape(-1, n) - self.mean_vector
        transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
        # ``torch.mm`` output is contiguous, so ``view`` back to the input shape is safe.
        return transformed_tensor.view(shape)

    def __repr__(self):
        # Both arguments appear inside a single pair of parentheses, mirroring the
        # constructor call.
        return (
            f"{self.__class__.__name__}(transformation_matrix="
            f"{self.transformation_matrix.tolist()}"
            f", mean_vector={self.mean_vector.tolist()})"
        )
class ColorJitter(torch.nn.Module):
    """Randomly change the brightness, contrast, saturation and hue of an image.

    A torch Tensor input is expected to have shape ``[..., 1 or 3, H, W]``, where ``...``
    stands for any number of leading dimensions.
    If img is PIL Image, mode "1", "I", "F" and modes with transparency (alpha channel) are not supported.

    Args:
        brightness (float or tuple of float (min, max)): How much to jitter brightness.
            brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
            or the given [min, max]. Should be non negative numbers.
        contrast (float or tuple of float (min, max)): How much to jitter contrast.
            contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
            or the given [min, max]. Should be non negative numbers.
        saturation (float or tuple of float (min, max)): How much to jitter saturation.
            saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
            or the given [min, max]. Should be non negative numbers.
        hue (float or tuple of float (min, max)): How much to jitter hue.
            hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
            Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        super().__init__()
        self.brightness = self._check_input(brightness, "brightness")
        self.contrast = self._check_input(contrast, "contrast")
        self.saturation = self._check_input(saturation, "saturation")
        self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    @torch.jit.unused
    def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
        """Normalise one jitter argument into a ``[min, max]`` range, or ``None``.

        A single non-negative number ``v`` becomes ``[center - v, center + v]`` (lower
        end clipped at 0 when ``clip_first_on_zero``). A 2-element tuple/list is
        validated against ``bound`` and returned as-is. A range collapsed onto
        ``center`` means "no jitter" and yields ``None``.

        Raises:
            ValueError: negative single number, or a pair outside ``bound`` / unordered.
            TypeError: anything other than a number or a length-2 tuple/list.
        """
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError(f"If {name} is a single number, it must be non negative.")
            low = center - float(value)
            high = center + float(value)
            if clip_first_on_zero:
                low = max(low, 0.0)
            value = [low, high]
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            if not bound[0] <= value[0] <= value[1] <= bound[1]:
                raise ValueError(f"{name} values should be between {bound}")
        else:
            raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")

        # e.g. 0 or (1., 1.) for brightness/contrast/saturation, (0., 0.) for hue.
        if value[0] == value[1] == center:
            return None
        return value

    @staticmethod
    def get_params(
        brightness: Optional[List[float]],
        contrast: Optional[List[float]],
        saturation: Optional[List[float]],
        hue: Optional[List[float]],
    ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
        """Get the parameters for the randomized transform to be applied on image.

        Args:
            brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
                uniformly. Pass None to turn off the transformation.
            contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
                uniformly. Pass None to turn off the transformation.
            saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
                uniformly. Pass None to turn off the transformation.
            hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
                Pass None to turn off the transformation.

        Returns:
            tuple: A random permutation of ``range(4)`` giving the order in which the
            adjustments are applied, followed by the brightness, contrast, saturation
            and hue factors (``None`` for disabled ones).
        """
        order = torch.randperm(4)

        # Factors are drawn in a fixed order so the random stream is consumed consistently.
        b: Optional[float] = None
        if brightness is not None:
            b = float(torch.empty(1).uniform_(brightness[0], brightness[1]))
        c: Optional[float] = None
        if contrast is not None:
            c = float(torch.empty(1).uniform_(contrast[0], contrast[1]))
        s: Optional[float] = None
        if saturation is not None:
            s = float(torch.empty(1).uniform_(saturation[0], saturation[1]))
        h: Optional[float] = None
        if hue is not None:
            h = float(torch.empty(1).uniform_(hue[0], hue[1]))

        return order, b, c, s, h

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Input image.

        Returns:
            PIL Image or Tensor: Color jittered image.
        """
        order, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
            self.brightness, self.contrast, self.saturation, self.hue
        )

        # Apply each enabled adjustment once, in the randomly permuted order.
        for step in order:
            if step == 0 and brightness_factor is not None:
                img = F.adjust_brightness(img, brightness_factor)
            elif step == 1 and contrast_factor is not None:
                img = F.adjust_contrast(img, contrast_factor)
            elif step == 2 and saturation_factor is not None:
                img = F.adjust_saturation(img, saturation_factor)
            elif step == 3 and hue_factor is not None:
                img = F.adjust_hue(img, hue_factor)

        return img

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"brightness={self.brightness}"
            f", contrast={self.contrast}"
            f", saturation={self.saturation}"
            f", hue={self.hue})"
        )
class RandomRotation(torch.nn.Module):
    """Rotate the image by a randomly chosen angle.

    A torch Tensor input is expected to have shape ``[..., H, W]``, where ``...``
    stands for any number of leading dimensions.

    Args:
        degrees (sequence or number): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees).
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
        expand (bool, optional): Optional expansion flag.
            If true, expands the output to make it large enough to hold the entire rotated image.
            If false or omitted, make the output image the same size as the input image.
            Note that the expand flag assumes rotation around the center and no translation.
        center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
            Default is the center of the image.
        fill (sequence or number): Pixel fill value for the area outside the rotated
            image. Default is ``0``. If given a number, the value is used for all bands respectively.
        resample (int, optional): deprecated argument and will be removed since v0.10.0.
            Please use the ``interpolation`` parameter instead.

    Raises:
        TypeError: if ``fill`` is neither ``None``, a sequence, nor a number.
    """

    def __init__(
        self, degrees, interpolation=InterpolationMode.NEAREST, expand=False, center=None, fill=0, resample=None
    ):
        super().__init__()

        # The deprecated ``resample`` argument, when given, overrides ``interpolation``.
        if resample is not None:
            warnings.warn(
                "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead"
            )
            interpolation = _interpolation_modes_from_int(resample)

        # Integer interpolation codes are still accepted, with a warning, for backward compatibility.
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument interpolation should be of type InterpolationMode instead of int. "
                "Please, use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)

        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))

        if center is not None:
            _check_sequence_input(center, "center", req_sizes=(2,))
        self.center = center

        # ``resample`` is kept as an alias of ``interpolation``; ``forward`` reads it.
        self.resample = interpolation
        self.interpolation = interpolation
        self.expand = expand

        # ``None`` is normalised to 0; anything else must be a number or a sequence.
        if fill is not None and not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")
        self.fill = 0 if fill is None else fill

    @staticmethod
    def get_params(degrees: List[float]) -> float:
        """Get parameters for ``rotate`` for a random rotation.

        Args:
            degrees (list): ``[min, max]`` range in degrees.

        Returns:
            float: angle parameter to be passed to ``rotate`` for random rotation.
        """
        low, high = float(degrees[0]), float(degrees[1])
        return float(torch.empty(1).uniform_(low, high).item())

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be rotated.

        Returns:
            PIL Image or Tensor: Rotated image.
        """
        fill_value = self.fill
        if isinstance(img, Tensor):
            # Tensor kernels expect one float per channel.
            if isinstance(fill_value, (int, float)):
                channels = F.get_image_num_channels(img)
                fill_value = [float(fill_value)] * channels
            else:
                fill_value = [float(v) for v in fill_value]
        angle = self.get_params(self.degrees)

        return F.rotate(img, angle, self.resample, self.expand, self.center, fill_value)

    def __repr__(self):
        interpolate_str = self.interpolation.value
        fields = [
            f"degrees={self.degrees}",
            f"interpolation={interpolate_str}",
            f"expand={self.expand}",
        ]
        if self.center is not None:
            fields.append(f"center={self.center}")
        if self.fill is not None:
            fields.append(f"fill={self.fill}")
        return self.__class__.__name__ + "(" + ", ".join(fields) + ")"
class RandomAffine(torch.nn.Module):
    """Random affine transformation of the image keeping center invariant.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        degrees (sequence or number): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees). Set to 0 to deactivate rotations.
        translate (tuple, optional): tuple of maximum absolute fraction for horizontal
            and vertical translations. For example translate=(a, b), then horizontal shift
            is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
            randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
        scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
            randomly sampled from the range a <= scale <= b. Will keep original scale by default.
        shear (sequence or number, optional): Range of degrees to select from.
            If shear is a number, a shear parallel to the x axis in the range (-shear, +shear)
            will be applied. Else if shear is a sequence of 2 values a shear parallel to the x axis in the
            range (shear[0], shear[1]) will be applied. Else if shear is a sequence of 4 values,
            a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
            Will not apply shear by default.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
        fill (sequence or number): Pixel fill value for the area outside the transformed
            image. Default is ``0``. If given a number, the value is used for all bands respectively.
        fillcolor (sequence or number, optional): deprecated argument and will be removed since v0.10.0.
            Please use the ``fill`` parameter instead.
        resample (int, optional): deprecated argument and will be removed since v0.10.0.
            Please use the ``interpolation`` parameter instead.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """

    def __init__(
        self,
        degrees,
        translate=None,
        scale=None,
        shear=None,
        interpolation=InterpolationMode.NEAREST,
        fill=0,
        fillcolor=None,
        resample=None,
    ):
        super().__init__()
        # Deprecated ``resample`` takes precedence over ``interpolation`` when given.
        if resample is not None:
            warnings.warn(
                "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead"
            )
            interpolation = _interpolation_modes_from_int(resample)

        # Backward compatibility with integer value
        if isinstance(interpolation, int):
            warnings.warn(
                "Argument interpolation should be of type InterpolationMode instead of int. "
                "Please, use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)

        # Deprecated ``fillcolor`` overrides ``fill`` when given.
        if fillcolor is not None:
            warnings.warn(
                "Argument fillcolor is deprecated and will be removed since v0.10.0. Please, use fill instead"
            )
            fill = fillcolor

        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))

        if translate is not None:
            _check_sequence_input(translate, "translate", req_sizes=(2,))
            for t in translate:
                if not (0.0 <= t <= 1.0):
                    raise ValueError("translation values should be between 0 and 1")
        self.translate = translate

        if scale is not None:
            _check_sequence_input(scale, "scale", req_sizes=(2,))
            for s in scale:
                if s <= 0:
                    raise ValueError("scale values should be positive")
        self.scale = scale

        if shear is not None:
            self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4))
        else:
            self.shear = shear

        # ``resample`` / ``fillcolor`` attributes are kept as aliases for backward compatibility.
        self.resample = self.interpolation = interpolation

        if fill is None:
            fill = 0
        elif not isinstance(fill, (Sequence, numbers.Number)):
            raise TypeError("Fill should be either a sequence or a number.")

        self.fillcolor = self.fill = fill

    @staticmethod
    def get_params(
        degrees: List[float],
        translate: Optional[List[float]],
        scale_ranges: Optional[List[float]],
        shears: Optional[List[float]],
        img_size: List[int],
    ) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]:
        """Get parameters for affine transformation

        Args:
            degrees (list of float): ``[min, max]`` range the rotation angle is sampled from.
            translate (list of float, optional): maximum absolute fractions ``[a, b]``; the shifts are
                sampled from ``(-a * img_size[0], a * img_size[0])`` and ``(-b * img_size[1], b * img_size[1])``
                and rounded to integers. ``None`` means no translation.
            scale_ranges (list of float, optional): ``[min, max]`` range the scale factor is sampled from.
                ``None`` means a scale of 1.0.
            shears (list of float, optional): ``[min, max]`` range of the x-axis shear, optionally followed
                by ``[min, max]`` of the y-axis shear (4 values). ``None`` means no shear.
            img_size (list of int): image size; ``img_size[0]`` pairs with the horizontal shift and
                ``img_size[1]`` with the vertical shift.

        Returns:
            params to be passed to the affine transformation:
            ``(angle, (tx, ty), scale, (shear_x, shear_y))``
        """
        angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
        if translate is not None:
            max_dx = float(translate[0] * img_size[0])
            max_dy = float(translate[1] * img_size[1])
            tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
            ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
            translations = (tx, ty)
        else:
            translations = (0, 0)

        if scale_ranges is not None:
            scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item())
        else:
            scale = 1.0

        shear_x = shear_y = 0.0
        if shears is not None:
            shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item())
            # A y-axis shear is only sampled when 4 shear bounds were configured.
            if len(shears) == 4:
                shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item())

        shear = (shear_x, shear_y)

        return angle, translations, scale, shear

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Affine transformed image.
        """
        fill = self.fill
        if isinstance(img, Tensor):
            # Tensor inputs get the fill as a list of floats: a scalar is repeated per channel.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * F.get_image_num_channels(img)
            else:
                fill = [float(f) for f in fill]

        img_size = F.get_image_size(img)

        ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)

        return F.affine(img, *ret, interpolation=self.interpolation, fill=fill)

    def __repr__(self):
        # Only non-default options are reported, matching the previous output exactly.
        s = f"{self.__class__.__name__}(degrees={self.degrees}"
        if self.translate is not None:
            s += f", translate={self.translate}"
        if self.scale is not None:
            s += f", scale={self.scale}"
        if self.shear is not None:
            s += f", shear={self.shear}"
        if self.interpolation != InterpolationMode.NEAREST:
            s += f", interpolation={self.interpolation.value}"
        if self.fill != 0:
            s += f", fill={self.fill}"
        s += ")"
        return s


1530
class Grayscale(torch.nn.Module):
    """Convert image to grayscale.
    If the image is torch Tensor, it is expected
    to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions

    Args:
        num_output_channels (int): (1 or 3) number of channels desired for output image

    Returns:
        PIL Image or Tensor: Grayscale version of the input.

        - If ``num_output_channels == 1`` : returned image is single channel
        - If ``num_output_channels == 3`` : returned image is 3 channel with r == g == b

    """

    def __init__(self, num_output_channels=1):
        super().__init__()
        self.num_output_channels = num_output_channels

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be converted to grayscale.

        Returns:
            PIL Image or Tensor: Grayscaled image.
        """
        return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels)

    def __repr__(self):
        return self.__class__.__name__ + f"(num_output_channels={self.num_output_channels})"
1562

1563

1564
class RandomGrayscale(torch.nn.Module):
    """Convert the image to grayscale with probability ``p`` (0.1 by default).

    Tensor inputs are expected to be shaped [..., 3, H, W], where ... stands for any
    number of leading dimensions.

    Args:
        p (float): probability that the image is converted to grayscale.

    Returns:
        PIL Image or Tensor: with probability ``p`` a grayscale version of the input,
        otherwise (probability ``1 - p``) the input unchanged.

        - 1-channel input: the grayscale version has 1 channel
        - 3-channel input: the grayscale version has 3 channels with r == g == b
    """

    def __init__(self, p=0.1):
        super().__init__()
        self.p = p

    def forward(self, img):
        """Grayscale ``img`` with probability ``self.p``.

        Args:
            img (PIL Image or Tensor): Image that may be converted to grayscale.

        Returns:
            PIL Image or Tensor: Grayscaled image, or ``img`` itself when not selected.
        """
        # The output keeps the input's channel count.
        channels_in = F.get_image_num_channels(img)
        if torch.rand(1) < self.p:
            return F.rgb_to_grayscale(img, num_output_channels=channels_in)
        return img

    def __repr__(self):
        return f"{self.__class__.__name__}(p={self.p})"
1599
1600


1601
class RandomErasing(torch.nn.Module):
    """Randomly selects a rectangle region in a torch Tensor image and erases its pixels.
    This transform does not support PIL Image.
    'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896

    Args:
         p: probability that the random erasing operation will be performed.
         scale: range of proportion of erased area against input image.
         ratio: range of aspect ratio of erased area.
         value: erasing value. Default is 0. If a single number, it is used to
            erase all pixels. If a sequence, it must have either 1 element or one
            element per image channel (e.g. 3 for R, G, B), used per channel.
            If a str of 'random', erasing each pixel with random values.
         inplace: boolean to make this transform inplace. Default set to False.

    Returns:
        Erased Image.

    Example:
        >>> transform = transforms.Compose([
        >>>   transforms.RandomHorizontalFlip(),
        >>>   transforms.PILToTensor(),
        >>>   transforms.ConvertImageDtype(torch.float),
        >>>   transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>>   transforms.RandomErasing(),
        >>> ])
    """

    def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
        super().__init__()
        if not isinstance(value, (numbers.Number, str, tuple, list)):
            raise TypeError("Argument value should be either a number or str or a sequence")
        if isinstance(value, str) and value != "random":
            raise ValueError("If value is str, it should be 'random'")
        if not isinstance(scale, (tuple, list)):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, (tuple, list)):
            raise TypeError("Ratio should be a sequence")
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")
        if scale[0] < 0 or scale[1] > 1:
            raise ValueError("Scale should be between 0 and 1")
        if p < 0 or p > 1:
            raise ValueError("Random erasing probability should be between 0 and 1")

        self.p = p
        self.scale = scale
        self.ratio = ratio
        self.value = value
        self.inplace = inplace

    @staticmethod
    def get_params(
        img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None
    ) -> Tuple[int, int, int, int, Tensor]:
        """Get parameters for ``erase`` for a random erasing.

        Args:
            img (Tensor): Tensor image to be erased.
            scale (sequence): range of proportion of erased area against input image.
            ratio (sequence): range of aspect ratio of erased area.
            value (list, optional): erasing value. If None, it is interpreted as "random"
                (erasing each pixel with random values). If ``len(value)`` is 1, it is interpreted as a number,
                i.e. ``value[0]``.

        Returns:
            tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
            If no region fitting strictly inside the image is found after 10 attempts,
            ``(0, 0, img_h, img_w, img)`` is returned.
        """
        img_c, img_h, img_w = img.shape[-3], img.shape[-2], img.shape[-1]
        area = img_h * img_w

        # Aspect ratios are sampled uniformly in log space.
        log_ratio = torch.log(torch.tensor(ratio))
        for _ in range(10):
            erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
            aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()

            h = int(round(math.sqrt(erase_area * aspect_ratio)))
            w = int(round(math.sqrt(erase_area / aspect_ratio)))
            if not (h < img_h and w < img_w):
                continue

            if value is None:
                v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
            else:
                # Shape (len(value), 1, 1) so it broadcasts over the erased region.
                v = torch.tensor(value)[:, None, None]

            i = torch.randint(0, img_h - h + 1, size=(1,)).item()
            j = torch.randint(0, img_w - w + 1, size=(1,)).item()
            return i, j, h, w, v

        # Return original image
        return 0, 0, img_h, img_w, img

    def forward(self, img):
        """
        Args:
            img (Tensor): Tensor image to be erased.

        Returns:
            img (Tensor): Erased Tensor image.
        """
        if torch.rand(1) < self.p:

            # Cast self.value to a script-acceptable type: ``get_params`` expects
            # ``Optional[List[float]]``, so numeric values are converted to float
            # (an int list would be ``List[int]`` and fail TorchScript typing).
            if isinstance(self.value, (int, float)):
                value = [float(self.value)]
            elif isinstance(self.value, str):
                value = None
            elif isinstance(self.value, (list, tuple)):
                value = [float(v) for v in self.value]
            else:
                value = self.value

            if value is not None and not (len(value) in (1, img.shape[-3])):
                raise ValueError(
                    "If value is a sequence, it should have either a single value or "
                    f"{img.shape[-3]} (number of input channels)"
                )

            x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value)
            return F.erase(img, x, y, h, w, v, self.inplace)
        return img

    def __repr__(self):
        s = f"(p={self.p}, "
        s += f"scale={self.scale}, "
        s += f"ratio={self.ratio}, "
        s += f"value={self.value}, "
        s += f"inplace={self.inplace})"
        return self.__class__.__name__ + s

1734

1735
1736
class GaussianBlur(torch.nn.Module):
    """Blur the image with a Gaussian kernel whose sigma is drawn at random.

    Tensor inputs are expected to be shaped [..., C, H, W], where ... stands for any
    number of leading dimensions.

    Args:
        kernel_size (int or sequence): Size of the Gaussian kernel.
        sigma (float or tuple of float (min, max)): Standard deviation used to build the
            blurring kernel. A single float keeps sigma fixed; a (min, max) tuple makes
            sigma uniformly random within that range.

    Returns:
        PIL Image or Tensor: Gaussian blurred version of the input image.

    """

    def __init__(self, kernel_size, sigma=(0.1, 2.0)):
        super().__init__()
        self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
        if any(ks <= 0 or ks % 2 == 0 for ks in self.kernel_size):
            raise ValueError("Kernel size value should be an odd and positive number.")

        if not isinstance(sigma, numbers.Number):
            if not (isinstance(sigma, Sequence) and len(sigma) == 2):
                raise ValueError("sigma should be a single number or a list/tuple with length 2.")
            if not 0.0 < sigma[0] <= sigma[1]:
                raise ValueError("sigma values should be positive and of the form (min, max).")
        elif sigma <= 0:
            raise ValueError("If sigma is a single number, it must be positive.")
        else:
            # A fixed sigma is stored as a degenerate (min, max) range.
            sigma = (sigma, sigma)

        self.sigma = sigma

    @staticmethod
    def get_params(sigma_min: float, sigma_max: float) -> float:
        """Draw the sigma used for a random Gaussian blur.

        Args:
            sigma_min (float): Lower bound of the standard deviation range.
            sigma_max (float): Upper bound of the standard deviation range.

        Returns:
            float: Standard deviation used to build the Gaussian blurring kernel.
        """
        drawn = torch.empty(1).uniform_(sigma_min, sigma_max)
        return drawn.item()

    def forward(self, img: Tensor) -> Tensor:
        """Blur ``img`` with a freshly sampled sigma.

        Args:
            img (PIL Image or Tensor): image to be blurred.

        Returns:
            PIL Image or Tensor: Gaussian blurred image
        """
        chosen_sigma = self.get_params(self.sigma[0], self.sigma[1])
        # The same sigma is used along both spatial axes.
        return F.gaussian_blur(img, self.kernel_size, [chosen_sigma] * 2)

    def __repr__(self):
        return f"{self.__class__.__name__}(kernel_size={self.kernel_size}, sigma={self.sigma})"


1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
def _setup_size(size, error_msg):
    if isinstance(size, numbers.Number):
        return int(size), int(size)

    if isinstance(size, Sequence) and len(size) == 1:
        return size[0], size[0]

    if len(size) != 2:
        raise ValueError(error_msg)

    return size


def _check_sequence_input(x, name, req_sizes):
    msg = req_sizes[0] if len(req_sizes) < 2 else " or ".join([str(s) for s in req_sizes])
    if not isinstance(x, Sequence):
1817
        raise TypeError(f"{name} should be a sequence of length {msg}.")
1818
    if len(x) not in req_sizes:
1819
        raise ValueError(f"{name} should be sequence of length {msg}.")
1820
1821


1822
def _setup_angle(x, name, req_sizes=(2,)):
1823
1824
    if isinstance(x, numbers.Number):
        if x < 0:
1825
            raise ValueError(f"If {name} is a single number, it must be positive.")
1826
1827
1828
1829
1830
        x = [-x, x]
    else:
        _check_sequence_input(x, name, req_sizes)

    return [float(d) for d in x]
1831
1832
1833
1834


class RandomInvert(torch.nn.Module):
    """Invert the colors of the image with a given probability.

    Tensor inputs are expected in [..., 1 or 3, H, W] layout, where ... stands for any
    number of leading dimensions. PIL inputs are expected in mode "L" or "RGB".

    Args:
        p (float): probability of the image being color inverted. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img):
        """Invert ``img`` with probability ``self.p``.

        Args:
            img (PIL Image or Tensor): Image to be inverted.

        Returns:
            PIL Image or Tensor: Randomly color inverted image.
        """
        apply_invert = torch.rand(1).item() < self.p
        if apply_invert:
            return F.invert(img)
        return img

    def __repr__(self):
        return f"{self.__class__.__name__}(p={self.p})"
1861
1862
1863
1864


class RandomPosterize(torch.nn.Module):
    """Posterize the image randomly with a given probability by reducing the
    number of bits for each color channel. If the image is torch Tensor, it should be of type torch.uint8,
    and it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        bits (int): number of bits to keep for each channel (0-8)
        p (float): probability of the image being posterized. Default value is 0.5
    """

    def __init__(self, bits, p=0.5):
        super().__init__()
        self.bits = bits
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be posterized.

        Returns:
            PIL Image or Tensor: Randomly posterized image.
        """
        if torch.rand(1).item() < self.p:
            return F.posterize(img, self.bits)
        return img

    def __repr__(self):
        return self.__class__.__name__ + f"(bits={self.bits},p={self.p})"
1893
1894
1895
1896


class RandomSolarize(torch.nn.Module):
    """Solarize the image randomly with a given probability by inverting all pixel
    values above a threshold. If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
    where ... means it can have an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        threshold (float): all pixels equal or above this value are inverted.
        p (float): probability of the image being solarized. Default value is 0.5
    """

    def __init__(self, threshold, p=0.5):
        super().__init__()
        self.threshold = threshold
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be solarized.

        Returns:
            PIL Image or Tensor: Randomly solarized image.
        """
        if torch.rand(1).item() < self.p:
            return F.solarize(img, self.threshold)
        return img

    def __repr__(self):
        return self.__class__.__name__ + f"(threshold={self.threshold},p={self.p})"
1925
1926
1927


class RandomAdjustSharpness(torch.nn.Module):
    """Adjust the sharpness of the image randomly with a given probability. If the image is torch Tensor,
    it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        sharpness_factor (float):  How much to adjust the sharpness. Can be
            any non negative number. 0 gives a blurred image, 1 gives the
            original image while 2 increases the sharpness by a factor of 2.
        p (float): probability of the sharpness adjustment being applied. Default value is 0.5
    """

    def __init__(self, sharpness_factor, p=0.5):
        super().__init__()
        self.sharpness_factor = sharpness_factor
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be sharpened.

        Returns:
            PIL Image or Tensor: Randomly sharpened image.
        """
        if torch.rand(1).item() < self.p:
            return F.adjust_sharpness(img, self.sharpness_factor)
        return img

    def __repr__(self):
        return self.__class__.__name__ + f"(sharpness_factor={self.sharpness_factor},p={self.p})"
1957
1958
1959
1960


class RandomAutocontrast(torch.nn.Module):
    """Apply autocontrast to the image with a given probability.

    Tensor inputs are expected to be shaped [..., 1 or 3, H, W], where ... stands for any
    number of leading dimensions. PIL inputs are expected in mode "L" or "RGB".

    Args:
        p (float): probability of the image being autocontrasted. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img):
        """Autocontrast ``img`` with probability ``self.p``.

        Args:
            img (PIL Image or Tensor): Image to be autocontrasted.

        Returns:
            PIL Image or Tensor: Randomly autocontrasted image.
        """
        apply_autocontrast = torch.rand(1).item() < self.p
        if apply_autocontrast:
            return F.autocontrast(img)
        return img

    def __repr__(self):
        return f"{self.__class__.__name__}(p={self.p})"
1987
1988
1989
1990


class RandomEqualize(torch.nn.Module):
    """Histogram-equalize the image with a given probability.

    Tensor inputs are expected to be shaped [..., 1 or 3, H, W], where ... stands for any
    number of leading dimensions. PIL inputs are expected in mode "P", "L" or "RGB".

    Args:
        p (float): probability of the image being equalized. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img):
        """Equalize ``img`` with probability ``self.p``.

        Args:
            img (PIL Image or Tensor): Image to be equalized.

        Returns:
            PIL Image or Tensor: Randomly equalized image.
        """
        apply_equalize = torch.rand(1).item() < self.p
        if apply_equalize:
            return F.equalize(img)
        return img

    def __repr__(self):
        return f"{self.__class__.__name__}(p={self.p})"