transforms.py 59.8 KB
Newer Older
1
import math
vfdev's avatar
vfdev committed
2
import numbers
3
import random
vfdev's avatar
vfdev committed
4
import warnings
vfdev's avatar
vfdev committed
5
from collections.abc import Sequence
6
from typing import Tuple, List, Optional
vfdev's avatar
vfdev committed
7
8

import torch
9
from PIL import Image
vfdev's avatar
vfdev committed
10
11
from torch import Tensor

12
13
14
15
16
17
18
try:
    import accimage
except ImportError:
    accimage = None

from . import functional as F

# Public transforms exported by ``from ... import *``.
__all__ = [
    "Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize",
    "Resize", "Scale", "CenterCrop", "Pad", "Lambda",
    "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
    "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop",
    "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation",
    "RandomAffine", "Grayscale", "RandomGrayscale", "RandomPerspective", "RandomErasing",
]


# Maps each PIL resampling constant to its fully-qualified display name
# (e.g. ``Image.BILINEAR -> 'PIL.Image.BILINEAR'``); used by ``Resize.__repr__``.
_pil_interpolation_to_str = {
    getattr(Image, _filter_name): 'PIL.Image.' + _filter_name
    for _filter_name in ('NEAREST', 'BILINEAR', 'BICUBIC', 'LANCZOS', 'HAMMING', 'BOX')
}


class Compose(object):
    """Chain several transforms so they run one after another.

    Args:
        transforms (list of ``Transform`` objects): transforms to apply, in order.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        # Feed the output of each transform into the next one.
        for transform in self.transforms:
            img = transform(img)
        return img

    def __repr__(self):
        # One transform per line, indented by four spaces, between "Name(" and ")".
        lines = [self.__class__.__name__ + '(']
        lines.extend('    {0}'.format(transform) for transform in self.transforms)
        lines.append(')')
        return '\n'.join(lines)


class ToTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
    or if the numpy.ndarray has dtype = np.uint8

    In the other cases, tensors are returned without scaling.

    .. note::
        Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
        transforming target image masks. See the `references`_ for implementing the transforms for image masks.

    .. _references: https://github.com/pytorch/vision/tree/master/references/segmentation
    """

    def __call__(self, pic):
        """Delegate the conversion to ``F.to_tensor``.

        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.to_tensor(pic)

    def __repr__(self):
        return '{}()'.format(type(self).__name__)


class PILToTensor(object):
    """Convert a ``PIL Image`` to a tensor of the same type.

    Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W).
    """

    def __call__(self, pic):
        """Delegate the conversion to ``F.pil_to_tensor``.

        Args:
            pic (PIL Image): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.pil_to_tensor(pic)

    def __repr__(self):
        return '{}()'.format(type(self).__name__)


class ConvertImageDtype(object):
    """Convert a tensor image to the given ``dtype`` and scale the values accordingly

    Args:
        dtype (torch.dtype): Desired data type of the output

    .. note::

        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
        If converted back and forth, this mismatch has no effect.

    Raises:
        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
            of the integer ``dtype``.
    """

    def __init__(self, dtype: torch.dtype) -> None:
        self.dtype = dtype

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        """
        Args:
            image (Tensor): Tensor image to be converted.

        Returns:
            Tensor: ``image`` converted to ``self.dtype`` by ``F.convert_image_dtype``.
        """
        return F.convert_image_dtype(image, self.dtype)

    def __repr__(self):
        # Every other transform in this module defines a repr; without one this
        # transform shows up as a bare object address inside ``Compose``'s repr.
        return self.__class__.__name__ + '(dtype={0})'.format(self.dtype)


class ToPILImage(object):
    """Convert a tensor or an ndarray to PIL Image.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL Image while preserving the value range.

    Args:
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
            If ``mode`` is ``None`` (default) there are some assumptions made about the input data:

            - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
            - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
            - If the input has 2 channels, the ``mode`` is assumed to be ``LA``.
            - If the input has 1 channel, the ``mode`` is determined by the data type (i.e ``int``, ``float``,
              ``short``).

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
    """
    def __init__(self, mode=None):
        self.mode = mode

    def __call__(self, pic):
        """Delegate the conversion to ``F.to_pil_image`` using the configured mode.

        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.

        Returns:
            PIL Image: Image converted to PIL Image.
        """
        return F.to_pil_image(pic, self.mode)

    def __repr__(self):
        # The mode is only shown when one was explicitly given.
        details = '' if self.mode is None else 'mode={0}'.format(self.mode)
        return '{0}({1})'.format(self.__class__.__name__, details)


class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.

    Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
    channels, this transform will normalize each channel of the input
    ``torch.*Tensor`` i.e.,
    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``

    .. note::
        By default this transform acts out of place, i.e., it does not mutate the input tensor.
        Pass ``inplace=True`` to request an in-place operation.

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace (bool, optional): Bool to make this operation in-place. Default is ``False``.
    """

    def __init__(self, mean, std, inplace=False):
        self.mean, self.std, self.inplace = mean, std, inplace

    def __call__(self, tensor):
        """Delegate the normalization to ``F.normalize``.

        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized Tensor image.
        """
        return F.normalize(tensor, self.mean, self.std, self.inplace)

    def __repr__(self):
        return '{}(mean={}, std={})'.format(self.__class__.__name__, self.mean, self.std)


class Resize(torch.nn.Module):
    """Resize the input image to the given size.
    The image can be a PIL Image or a torch Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size).
            In torchscript mode size as single int is not supported, use a tuple or
            list of length 1: ``[size, ]``.
        interpolation (int, optional): Desired interpolation enum defined by `filters`_.
            Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR``
            and ``PIL.Image.BICUBIC`` are supported.

    Raises:
        TypeError: If ``size`` is neither an int nor a sequence.
        ValueError: If ``size`` is a sequence whose length is not 1 or 2.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        super().__init__()
        if not isinstance(size, (int, Sequence)):
            raise TypeError("Size should be int or sequence. Got {}".format(type(size)))
        if isinstance(size, Sequence) and len(size) not in (1, 2):
            raise ValueError("If size is a sequence, it should have 1 or 2 values")
        self.size = size
        self.interpolation = interpolation

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be scaled.

        Returns:
            PIL Image or Tensor: Rescaled image.
        """
        return F.resize(img, self.size, self.interpolation)

    def __repr__(self):
        # ``interpolation`` is not validated in ``__init__``, so fall back to the raw
        # value instead of raising KeyError from inside repr (which would also break
        # printing any ``Compose`` that contains this transform).
        interpolate_str = _pil_interpolation_to_str.get(self.interpolation, self.interpolation)
        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)


class Scale(Resize):
    """Deprecated alias of :class:`Resize`.

    Accepts exactly the same arguments as ``Resize`` and emits a warning
    every time it is constructed.
    """
    def __init__(self, *args, **kwargs):
        warnings.warn("The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead.")
        super().__init__(*args, **kwargs)


class CenterCrop(torch.nn.Module):
    """Crops the given image at the center.
    The image can be a PIL Image or a torch Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).

    Raises:
        ValueError: If ``size`` is a sequence whose length is neither 1 nor 2.
    """

    def __init__(self, size):
        super().__init__()
        # Normalise ``size`` to an (h, w) pair; a 2-element sequence is kept as given.
        if isinstance(size, numbers.Number):
            size = (int(size), int(size))
        elif isinstance(size, Sequence) and len(size) == 1:
            size = (size[0], size[0])
        elif len(size) != 2:
            raise ValueError("Please provide only two dimensions (h, w) for size.")
        self.size = size

    def forward(self, img):
        """Crop ``img`` around its center via ``F.center_crop``.

        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            PIL Image or Tensor: Cropped image.
        """
        return F.center_crop(img, self.size)

    def __repr__(self):
        return '{}(size={})'.format(self.__class__.__name__, self.size)


class Pad(torch.nn.Module):
    """Pad the given image on all sides with the given "pad" value.
    The image can be a PIL Image or a torch Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    Args:
        padding (int or tuple or list): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders respectively.
            In torchscript mode padding as single int is not supported, use a tuple or
            list of length 1: ``[padding, ]``.
        fill (int or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is constant. Mode symmetric is not yet supported for Tensor inputs.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image

            - reflect: pads with reflection of image without repeating the last value on the edge

                For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
                will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge

                For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
                will result in [2, 1, 1, 2, 3, 4, 4, 3]

    Raises:
        TypeError: If ``padding`` is not a number, tuple or list, or ``fill`` is not a number, str or tuple.
        ValueError: If ``padding_mode`` is unknown, or ``padding`` is a sequence whose length is not 1, 2 or 4.
    """

    def __init__(self, padding, fill=0, padding_mode="constant"):
        super().__init__()
        # Validation order matters: earlier checks win when several arguments are bad.
        if not isinstance(padding, (numbers.Number, tuple, list)):
            raise TypeError("Got inappropriate padding arg")
        if not isinstance(fill, (numbers.Number, str, tuple)):
            raise TypeError("Got inappropriate fill arg")
        if padding_mode not in ("constant", "edge", "reflect", "symmetric"):
            raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
        if isinstance(padding, Sequence) and len(padding) not in (1, 2, 4):
            raise ValueError(
                "Padding must be an int or a 1, 2, or 4 element tuple, not a {} element tuple".format(len(padding))
            )

        self.padding = padding
        self.fill = fill
        self.padding_mode = padding_mode

    def forward(self, img):
        """Pad ``img`` via ``F.pad`` using the configured settings.

        Args:
            img (PIL Image or Tensor): Image to be padded.

        Returns:
            PIL Image or Tensor: Padded image.
        """
        return F.pad(img, self.padding, self.fill, self.padding_mode)

    def __repr__(self):
        return '{}(padding={}, fill={}, padding_mode={})'.format(
            self.__class__.__name__, self.padding, self.fill, self.padding_mode)


class Lambda(object):
    """Apply a user-defined lambda as a transform.

    Args:
        lambd (function): Lambda/function to be used for transform.

    Raises:
        TypeError: If ``lambd`` is not callable.
    """

    def __init__(self, lambd):
        # Explicit check instead of ``assert``: assert statements are stripped when
        # Python runs with ``-O``, which would silently skip this validation.
        if not callable(lambd):
            raise TypeError(repr(type(lambd).__name__) + " object is not callable")
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)

    def __repr__(self):
        return self.__class__.__name__ + '()'


class RandomTransforms(object):
    """Base class for a list of transformations with randomness

    Args:
        transforms (list or tuple): list of transformations

    Raises:
        TypeError: If ``transforms`` is not a list or tuple.
    """

    def __init__(self, transforms):
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if not isinstance(transforms, (list, tuple)):
            raise TypeError(
                "Argument transforms should be a list or tuple, got {}".format(type(transforms).__name__))
        self.transforms = transforms

    def __call__(self, *args, **kwargs):
        raise NotImplementedError()

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string


class RandomApply(RandomTransforms):
    """Apply the whole list of transformations, or none of them, with probability ``p``.

    Args:
        transforms (list or tuple): list of transformations
        p (float): probability
    """

    def __init__(self, transforms, p=0.5):
        super().__init__(transforms)
        self.p = p

    def __call__(self, img):
        # Draw once per call; the transforms are skipped only when p < draw.
        skip = self.p < random.random()
        if not skip:
            for transform in self.transforms:
                img = transform(img)
        return img

    def __repr__(self):
        lines = [self.__class__.__name__ + '(', '    p={}'.format(self.p)]
        lines.extend('    {0}'.format(transform) for transform in self.transforms)
        lines.append(')')
        return '\n'.join(lines)


class RandomOrder(RandomTransforms):
    """Apply every transformation in the list, in a freshly shuffled order on each call.
    """
    def __call__(self, img):
        # Shuffle a list of indices rather than ``self.transforms`` itself so the
        # stored order is left untouched between calls.
        positions = list(range(len(self.transforms)))
        random.shuffle(positions)
        for pos in positions:
            img = self.transforms[pos](img)
        return img


class RandomChoice(RandomTransforms):
    """Apply exactly one transformation, picked uniformly at random from the list.
    """
    def __call__(self, img):
        return random.choice(self.transforms)(img)


class RandomCrop(torch.nn.Module):
    """Crop the given image at a random location.
    The image can be a PIL Image or a Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is None. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders respectively.
            In torchscript mode padding as single int is not supported, use a tuple or
            list of length 1: ``[padding, ]``.
        pad_if_needed (boolean): It will pad the image if smaller than the
            desired size to avoid raising an exception. Since cropping is done
            after padding, the padding seems to be done at a random offset.
        fill (int or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
            Mode symmetric is not yet supported for Tensor inputs.

             - constant: pads with a constant value, this value is specified with fill

             - edge: pads with the last value on the edge of the image

             - reflect: pads with reflection of image (without repeating the last value on the edge)

                padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
                will result in [3, 2, 1, 2, 3, 4, 3, 2]

             - symmetric: pads with reflection of image (repeating the last value on the edge)

                padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
                will result in [2, 1, 1, 2, 3, 4, 4, 3]

    Raises:
        ValueError: If ``size`` is a sequence whose length is neither 1 nor 2.
    """

    @staticmethod
    def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image or Tensor): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.

        Raises:
            ValueError: If the requested crop is larger than the image in either dimension.
        """
        w, h = F._get_image_size(img)
        th, tw = output_size

        # torch.randint(0, high) needs high > 0. Without this check an oversized
        # crop fails inside torch with an opaque RuntimeError about 'from'/'to'.
        if h < th or w < tw:
            raise ValueError(
                "Required crop size {} is larger than input image size {}".format((th, tw), (h, w))
            )

        if w == tw and h == th:
            return 0, 0, h, w

        i = torch.randint(0, h - th + 1, size=(1, )).item()
        j = torch.randint(0, w - tw + 1, size=(1, )).item()
        return i, j, th, tw

    def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
        super().__init__()
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        elif isinstance(size, Sequence) and len(size) == 1:
            self.size = (size[0], size[0])
        else:
            if len(size) != 2:
                raise ValueError("Please provide only two dimensions (h, w) for size.")

            # cast to tuple for torchscript
            self.size = tuple(size)
        self.padding = padding
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        self.padding_mode = padding_mode

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            PIL Image or Tensor: Cropped image.
        """
        if self.padding is not None:
            img = F.pad(img, self.padding, self.fill, self.padding_mode)

        width, height = F._get_image_size(img)
        # pad the width if needed (2-element padding pads left and right by the deficit)
        if self.pad_if_needed and width < self.size[1]:
            padding = [self.size[1] - width, 0]
            img = F.pad(img, padding, self.fill, self.padding_mode)
        # pad the height if needed (pads top and bottom by the deficit)
        if self.pad_if_needed and height < self.size[0]:
            padding = [0, self.size[0] - height]
            img = F.pad(img, padding, self.fill, self.padding_mode)

        i, j, h, w = self.get_params(img, self.size)

        return F.crop(img, i, j, h, w)

    def __repr__(self):
        return self.__class__.__name__ + "(size={0}, padding={1})".format(self.size, self.padding)


class RandomHorizontalFlip(torch.nn.Module):
    """Horizontally flip the given image randomly with a given probability.
    The image can be a PIL Image or a torch Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img):
        """Flip ``img`` with ``F.hflip`` when a uniform draw from ``torch.rand`` is below ``p``.

        Args:
            img (PIL Image or Tensor): Image to be flipped.

        Returns:
            PIL Image or Tensor: Randomly flipped image.
        """
        flip = torch.rand(1) < self.p
        return F.hflip(img) if flip else img

    def __repr__(self):
        return '{}(p={})'.format(self.__class__.__name__, self.p)


class RandomVerticalFlip(torch.nn.Module):
    """Vertically flip the given image randomly with a given probability.
    The image can be a PIL Image or a torch Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img):
        """Flip ``img`` with ``F.vflip`` when a uniform draw from ``torch.rand`` is below ``p``.

        Args:
            img (PIL Image or Tensor): Image to be flipped.

        Returns:
            PIL Image or Tensor: Randomly flipped image.
        """
        flip = torch.rand(1) < self.p
        return F.vflip(img) if flip else img

    def __repr__(self):
        return '{}(p={})'.format(self.__class__.__name__, self.p)


class RandomPerspective(object):
    """Performs Perspective transformation of the given PIL Image randomly with a given probability.

    Args:
        distortion_scale (float): it controls the degree of distortion and ranges from 0 to 1. Default value is 0.5.

        p (float): probability of the image being perspectively transformed. Default value is 0.5

        interpolation (int): Interpolation filter passed to ``F.perspective``. Default is ``PIL.Image.BICUBIC``.

        fill (3-tuple or int): RGB pixel fill value for the area outside the transformed image.
            If int, it is used for all channels respectively. Default value is 0.
    """

    def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BICUBIC, fill=0):
        self.p = p
        self.interpolation = interpolation
        self.distortion_scale = distortion_scale
        self.fill = fill

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be perspectively transformed.

        Returns:
            PIL Image: Randomly perspective-transformed image, or ``img`` unchanged.

        Raises:
            TypeError: If ``img`` is not a PIL Image.
        """
        if not F._is_pil_image(img):
            raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

        # Leave the image untouched unless the draw falls below p.
        if not random.random() < self.p:
            return img

        width, height = img.size
        startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
        return F.perspective(img, startpoints, endpoints, self.interpolation, self.fill)

    @staticmethod
    def get_params(width, height, distortion_scale):
        """Get parameters for ``perspective`` for a random perspective transform.

        Args:
            width : width of the image.
            height : height of the image.
            distortion_scale : fraction of half the width/height by which each corner
                may be moved towards the image center.

        Returns:
            Two lists of four (x, y) points ordered [top-left, top-right, bottom-right, bottom-left]:
            the corners of the original image, and the randomly displaced corners.
        """
        half_height = int(height / 2)
        half_width = int(width / 2)
        # Largest inward displacement allowed for a corner along each axis.
        max_dx = int(distortion_scale * half_width)
        max_dy = int(distortion_scale * half_height)
        right = width - 1
        bottom = height - 1
        # The random draws below happen in the same order as x/y per corner, corner by corner.
        topleft = (random.randint(0, max_dx), random.randint(0, max_dy))
        topright = (random.randint(right - max_dx, right), random.randint(0, max_dy))
        botright = (random.randint(right - max_dx, right), random.randint(bottom - max_dy, bottom))
        botleft = (random.randint(0, max_dx), random.randint(bottom - max_dy, bottom))
        startpoints = [(0, 0), (right, 0), (right, bottom), (0, bottom)]
        return startpoints, [topleft, topright, botright, botleft]

    def __repr__(self):
        return '{}(p={})'.format(self.__class__.__name__, self.p)


class RandomResizedCrop(torch.nn.Module):
    """Crop the given image to random size and aspect ratio.
    The image can be a PIL Image or a Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
701

702
703
    A crop of random size (default: of 0.08 to 1.0) of the original size and a random
    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
704
705
706
707
    is finally resized to given size.
    This is popularly used to train the Inception networks.

    Args:
708
709
710
711
712
        size (int or sequence): expected output size of each edge. If size is an
            int instead of sequence like (h, w), a square output size ``(size, size)`` is
            made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).
        scale (tuple of float): range of size of the origin size cropped
        ratio (tuple of float): range of aspect ratio of the origin aspect ratio cropped.
vfdev's avatar
vfdev committed
713
714
715
        interpolation (int): Desired interpolation enum defined by `filters`_.
            Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR``
            and ``PIL.Image.BICUBIC`` are supported.
716
717
    """

718
    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
719
720
721
722
723
        super().__init__()
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        elif isinstance(size, Sequence) and len(size) == 1:
            self.size = (size[0], size[0])
724
        else:
725
726
727
728
            if len(size) != 2:
                raise ValueError("Please provide only two dimensions (h, w) for size.")
            self.size = size

729
        if not isinstance(scale, Sequence):
730
            raise TypeError("Scale should be a sequence")
731
        if not isinstance(ratio, Sequence):
732
            raise TypeError("Ratio should be a sequence")
733
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
734
            warnings.warn("Scale and ratio should be of kind (min, max)")
735

736
        self.interpolation = interpolation
737
738
        self.scale = scale
        self.ratio = ratio
739
740

    @staticmethod
741
    def get_params(
742
            img: Tensor, scale: List[float], ratio: List[float]
743
    ) -> Tuple[int, int, int, int]:
744
745
746
        """Get parameters for ``crop`` for a random sized crop.

        Args:
747
            img (PIL Image or Tensor): Input image.
748
749
            scale (list): range of scale of the origin size cropped
            ratio (list): range of aspect ratio of the origin aspect ratio cropped
750
751
752
753
754

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop.
        """
vfdev's avatar
vfdev committed
755
        width, height = F._get_image_size(img)
Zhicheng Yan's avatar
Zhicheng Yan committed
756
        area = height * width
757

758
        for _ in range(10):
759
            target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
760
761
762
763
            log_ratio = torch.log(torch.tensor(ratio))
            aspect_ratio = torch.exp(
                torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
            ).item()
764
765
766
767

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

Zhicheng Yan's avatar
Zhicheng Yan committed
768
            if 0 < w <= width and 0 < h <= height:
769
770
                i = torch.randint(0, height - h + 1, size=(1,)).item()
                j = torch.randint(0, width - w + 1, size=(1,)).item()
771
772
                return i, j, h, w

773
        # Fallback to central crop
Zhicheng Yan's avatar
Zhicheng Yan committed
774
        in_ratio = float(width) / float(height)
775
        if in_ratio < min(ratio):
Zhicheng Yan's avatar
Zhicheng Yan committed
776
            w = width
777
            h = int(round(w / min(ratio)))
778
        elif in_ratio > max(ratio):
Zhicheng Yan's avatar
Zhicheng Yan committed
779
            h = height
780
            w = int(round(h * max(ratio)))
781
        else:  # whole image
Zhicheng Yan's avatar
Zhicheng Yan committed
782
783
784
785
            w = width
            h = height
        i = (height - h) // 2
        j = (width - w) // 2
786
        return i, j, h, w
787

788
    def forward(self, img):
789
790
        """
        Args:
791
            img (PIL Image or Tensor): Image to be cropped and resized.
792
793

        Returns:
794
            PIL Image or Tensor: Randomly cropped and resized image.
795
        """
796
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
797
798
        return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)

799
    def __repr__(self):
800
801
        interpolate_str = _pil_interpolation_to_str[self.interpolation]
        format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
802
803
        format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
        format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
804
805
        format_string += ', interpolation={0})'.format(interpolate_str)
        return format_string
806

807
808
809
810
811
812
813
814
815
816
817

class RandomSizedCrop(RandomResizedCrop):
    """
    Note: Deprecated alias of RandomResizedCrop. Construction emits a warning and
    then forwards every argument unchanged to RandomResizedCrop.
    """
    def __init__(self, *args, **kwargs):
        deprecation_message = ("The use of the transforms.RandomSizedCrop transform is deprecated, "
                               "please use transforms.RandomResizedCrop instead.")
        warnings.warn(deprecation_message)
        super().__init__(*args, **kwargs)


class FiveCrop(torch.nn.Module):
    """Crop the given image into four corners and the central crop.
    The image can be a PIL Image or a Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
         inputs and targets your Dataset returns. See below for an example of how to deal with
         this.

    Args:
         size (sequence or int): Desired output size of the crop. If size is an ``int``
            instead of sequence like (h, w), a square crop of size (size, size) is made.
            If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).

    Example:
         >>> transform = Compose([
         >>>    FiveCrop(size), # this is a list of PIL Images
         >>>    Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
         >>> ])
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """

    def __init__(self, size):
        super().__init__()
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        elif isinstance(size, Sequence) and len(size) == 1:
            self.size = (size[0], size[0])
        else:
            if len(size) != 2:
                raise ValueError("Please provide only two dimensions (h, w) for size.")

            self.size = size

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            tuple of 5 images. Image can be PIL Image or Tensor
        """
        return F.five_crop(img, self.size)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)

class TenCrop(torch.nn.Module):
    """Crop the given image into four corners and the central crop plus the flipped version of
    these (horizontal flipping is used by default).
    The image can be a PIL Image or a Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
         inputs and targets your Dataset returns. See below for an example of how to deal with
         this.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).
        vertical_flip (bool): Use vertical flipping instead of horizontal

    Example:
         >>> transform = Compose([
         >>>    TenCrop(size), # this is a list of PIL Images
         >>>    Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
         >>> ])
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """

    def __init__(self, size, vertical_flip=False):
        super().__init__()
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        elif isinstance(size, Sequence) and len(size) == 1:
            self.size = (size[0], size[0])
        else:
            if len(size) != 2:
                raise ValueError("Please provide only two dimensions (h, w) for size.")

            self.size = size
        self.vertical_flip = vertical_flip

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            tuple of 10 images. Image can be PIL Image or Tensor
        """
        return F.ten_crop(img, self.size, self.vertical_flip)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip)

class LinearTransformation(object):
    """Transform a tensor image with a square transformation matrix and a mean_vector computed
    offline.
    Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and
    subtract mean_vector from it which is then followed by computing the dot
    product with the transformation matrix and then reshaping the tensor to its
    original shape.

    Applications:
        whitening transformation: Suppose X is a column vector zero-centered data.
        Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X),
        perform SVD on this matrix and pass it as transformation_matrix.

    Args:
        transformation_matrix (Tensor): tensor [D x D], D = C x H x W
        mean_vector (Tensor): tensor [D], D = C x H x W

    Raises:
        ValueError: if transformation_matrix is not square, or if mean_vector's
            length differs from the matrix dimension.
    """

    def __init__(self, transformation_matrix, mean_vector):
        if transformation_matrix.size(0) != transformation_matrix.size(1):
            raise ValueError("transformation_matrix should be square. Got " +
                             "[{} x {}] rectangular matrix.".format(*transformation_matrix.size()))

        if mean_vector.size(0) != transformation_matrix.size(0):
            raise ValueError("mean_vector should have the same length {}".format(mean_vector.size(0)) +
                             " as any one of the dimensions of the transformation_matrix [{}]"
                             .format(tuple(transformation_matrix.size())))

        self.transformation_matrix = transformation_matrix
        self.mean_vector = mean_vector

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be whitened.

        Returns:
            Tensor: Transformed image.

        Raises:
            ValueError: if C * H * W does not match the transformation matrix size.
        """
        if tensor.size(0) * tensor.size(1) * tensor.size(2) != self.transformation_matrix.size(0):
            raise ValueError("tensor and transformation matrix have incompatible shape." +
                             "[{} x {} x {}] != ".format(*tensor.size()) +
                             "{}".format(self.transformation_matrix.size(0)))
        # reshape (not view) so non-contiguous inputs, e.g. permuted tensors, are accepted.
        flat_tensor = tensor.reshape(1, -1) - self.mean_vector
        transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
        # The matmul result is freshly allocated and contiguous, so view is safe here.
        return transformed_tensor.view(tensor.size())

    def __repr__(self):
        format_string = self.__class__.__name__ + '(transformation_matrix='
        format_string += str(self.transformation_matrix.tolist())
        format_string += ', mean_vector=' + str(self.mean_vector.tolist()) + ')'
        return format_string

class ColorJitter(torch.nn.Module):
    """Randomly change the brightness, contrast, saturation and hue of an image.

    Args:
        brightness (float or tuple of float (min, max)): How much to jitter brightness.
            brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
            or the given [min, max]. Should be non negative numbers.
        contrast (float or tuple of float (min, max)): How much to jitter contrast.
            contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
            or the given [min, max]. Should be non negative numbers.
        saturation (float or tuple of float (min, max)): How much to jitter saturation.
            saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
            or the given [min, max]. Should be non negative numbers.
        hue (float or tuple of float (min, max)): How much to jitter hue.
            hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
            Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.

    Raises:
        ValueError: if a single number is negative, or a resulting range is not
            ordered or lies outside its bound ([0, inf) for brightness, contrast and
            saturation; [-0.5, 0.5] for hue).
        TypeError: if an argument is neither a number nor a list/tuple of length 2.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        super().__init__()
        self.brightness = self._check_input(brightness, 'brightness')
        self.contrast = self._check_input(contrast, 'contrast')
        self.saturation = self._check_input(saturation, 'saturation')
        self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
                                     clip_first_on_zero=False)

    @torch.jit.unused
    def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
        """Normalize one jitter argument to a ``[min, max]`` range, or ``None`` for a no-op.

        A single non-negative number ``v`` becomes ``[center - v, center + v]``. The
        lower end is clipped at 0 when ``clip_first_on_zero`` is set. A 2-element
        tuple/list is used as given. In both cases the range must satisfy
        ``bound[0] <= min <= max <= bound[1]``.
        """
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError("If {} is a single number, it must be non negative.".format(name))
            value = [center - float(value), center + float(value)]
            if clip_first_on_zero:
                value[0] = max(value[0], 0.0)
        elif not (isinstance(value, (tuple, list)) and len(value) == 2):
            raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))

        # Validate both input forms. Previously only explicit ranges were checked,
        # so a scalar such as hue=0.7 silently produced a range outside ``bound``.
        if not bound[0] <= value[0] <= value[1] <= bound[1]:
            raise ValueError("{} values should be between {}".format(name, bound))

        # if value is 0 or (1., 1.) for brightness/contrast/saturation
        # or (0., 0.) for hue, do nothing
        if value[0] == value[1] == center:
            value = None
        return value

    @staticmethod
    @torch.jit.unused
    def get_params(brightness, contrast, saturation, hue):
        """Get a randomized transform to be applied on image.

        Arguments are same as that of __init__.

        Returns:
            Transform which randomly adjusts brightness, contrast and
            saturation in a random order.
        """
        transforms = []

        if brightness is not None:
            brightness_factor = random.uniform(brightness[0], brightness[1])
            transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))

        if contrast is not None:
            contrast_factor = random.uniform(contrast[0], contrast[1])
            transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))

        if saturation is not None:
            saturation_factor = random.uniform(saturation[0], saturation[1])
            transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))

        if hue is not None:
            hue_factor = random.uniform(hue[0], hue[1])
            transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))

        random.shuffle(transforms)
        transform = Compose(transforms)

        return transform

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Input image.

        Returns:
            PIL Image or Tensor: Color jittered image.
        """
        # Visit the four adjustments in a random order. Each is skipped when its
        # range is None, i.e. the corresponding constructor argument was a no-op.
        fn_idx = torch.randperm(4)
        for fn_id in fn_idx:
            if fn_id == 0 and self.brightness is not None:
                brightness = self.brightness
                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
                img = F.adjust_brightness(img, brightness_factor)

            if fn_id == 1 and self.contrast is not None:
                contrast = self.contrast
                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
                img = F.adjust_contrast(img, contrast_factor)

            if fn_id == 2 and self.saturation is not None:
                saturation = self.saturation
                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
                img = F.adjust_saturation(img, saturation_factor)

            if fn_id == 3 and self.hue is not None:
                hue = self.hue
                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
                img = F.adjust_hue(img, hue_factor)

        return img

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        format_string += 'brightness={0}'.format(self.brightness)
        format_string += ', contrast={0}'.format(self.contrast)
        format_string += ', saturation={0}'.format(self.saturation)
        format_string += ', hue={0})'.format(self.hue)
        return format_string

class RandomRotation(torch.nn.Module):
    """Rotate the image by angle.
    The image can be a PIL Image or a Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        degrees (sequence or float or int): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees).
        resample (int, optional): An optional resampling filter. See `filters`_ for more information.
            If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
            If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported.
        expand (bool, optional): Optional expansion flag.
            If true, expands the output to make it large enough to hold the entire rotated image.
            If false or omitted, make the output image the same size as the input image.
            Note that the expand flag assumes rotation around the center and no translation.
        center (list or tuple, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
            Default is the center of the image.
        fill (n-tuple or int or float): Pixel fill value for area outside the rotated
            image. If int or float, the value is used for all bands respectively.
            Defaults to 0 for all bands. This option is only available for Pillow>=5.2.0.
            This option is not supported for Tensor input. Fill value for the area outside the transform in the output
            image is always 0.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """

    def __init__(self, degrees, resample=False, expand=False, center=None, fill=None):
        super().__init__()
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError("If degrees is a single number, it must be positive.")
            degrees = [-degrees, degrees]
        else:
            if not isinstance(degrees, Sequence):
                raise TypeError("degrees should be a sequence of length 2.")
            if len(degrees) != 2:
                raise ValueError("If degrees is a sequence, it must be of len 2.")

        self.degrees = [float(d) for d in degrees]

        if center is not None:
            if not isinstance(center, Sequence):
                raise TypeError("center should be a sequence of length 2.")
            if len(center) != 2:
                raise ValueError("center should be a sequence of length 2.")

        self.center = center

        self.resample = resample
        self.expand = expand
        self.fill = fill

    @staticmethod
    def get_params(degrees: List[float]) -> float:
        """Get parameters for ``rotate`` for a random rotation.

        Args:
            degrees (list of float): ``[min, max]`` range the angle is sampled from uniformly.

        Returns:
            float: angle parameter to be passed to ``rotate`` for random rotation.
        """
        angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
        return angle

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be rotated.

        Returns:
            PIL Image or Tensor: Rotated image.
        """
        angle = self.get_params(self.degrees)
        return F.rotate(img, angle, self.resample, self.expand, self.center, self.fill)

    def __repr__(self):
        format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees)
        format_string += ', resample={0}'.format(self.resample)
        format_string += ', expand={0}'.format(self.expand)
        if self.center is not None:
            format_string += ', center={0}'.format(self.center)
        if self.fill is not None:
            format_string += ', fill={0}'.format(self.fill)
        format_string += ')'
        return format_string

class RandomAffine(torch.nn.Module):
    """Random affine transformation of the image keeping center invariant.
    The image can be a PIL Image or a Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        degrees (sequence or float or int): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees). Set to 0 to deactivate rotations.
        translate (tuple, optional): tuple of maximum absolute fraction for horizontal
            and vertical translations. For example translate=(a, b), then horizontal shift
            is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
            randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
        scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
            randomly sampled from the range a <= scale <= b. Will keep original scale by default.
        shear (sequence or float or int, optional): Range of degrees to select from.
            If shear is a number, a shear parallel to the x axis in the range (-shear, +shear)
            will be applied. Else if shear is a tuple or list of 2 values a shear parallel to the x axis in the
            range (shear[0], shear[1]) will be applied. Else if shear is a tuple or list of 4 values,
            a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
            Will not apply shear by default.
        resample (int, optional): An optional resampling filter. See `filters`_ for more information.
            If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
            If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported.
        fillcolor (tuple or int): Optional fill color (Tuple for RGB Image and int for grayscale) for the area
            outside the transform in the output image (Pillow>=5.0.0). This option is not supported for Tensor
            input. Fill value for the area outside the transform in the output image is always 0.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """

    def __init__(self, degrees, translate=None, scale=None, shear=None, resample=0, fillcolor=0):
        super().__init__()
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError("If degrees is a single number, it must be positive.")
            degrees = [-degrees, degrees]
        else:
            if not isinstance(degrees, Sequence):
                raise TypeError("degrees should be a sequence of length 2.")
            if len(degrees) != 2:
                raise ValueError("degrees should be sequence of length 2.")

        self.degrees = [float(d) for d in degrees]

        if translate is not None:
            if not isinstance(translate, Sequence):
                raise TypeError("translate should be a sequence of length 2.")
            if len(translate) != 2:
                raise ValueError("translate should be sequence of length 2.")
            for t in translate:
                if not (0.0 <= t <= 1.0):
                    raise ValueError("translation values should be between 0 and 1")
        self.translate = translate

        if scale is not None:
            if not isinstance(scale, Sequence):
                raise TypeError("scale should be a sequence of length 2.")
            if len(scale) != 2:
                raise ValueError("scale should be sequence of length 2.")

            for s in scale:
                if s <= 0:
                    raise ValueError("scale values should be positive")
        self.scale = scale

        if shear is not None:
            if isinstance(shear, numbers.Number):
                if shear < 0:
                    raise ValueError("If shear is a single number, it must be positive.")
                shear = [-shear, shear]
            else:
                if not isinstance(shear, Sequence):
                    raise TypeError("shear should be a sequence of length 2 or 4.")
                if len(shear) not in (2, 4):
                    raise ValueError("shear should be sequence of length 2 or 4.")

            self.shear = [float(s) for s in shear]
        else:
            self.shear = shear

        self.resample = resample
        self.fillcolor = fillcolor

    @staticmethod
    def get_params(
            degrees: List[float],
            translate: Optional[List[float]],
            scale_ranges: Optional[List[float]],
            shears: Optional[List[float]],
            img_size: List[int]
    ) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]:
        """Get parameters for affine transformation

        Args:
            degrees (list of float): ``[min, max]`` rotation range.
            translate (list of float, optional): maximum absolute translation fractions
                ``(a, b)``; ``a`` scales ``img_size[0]`` and ``b`` scales ``img_size[1]``.
            scale_ranges (list of float, optional): ``[min, max]`` scale range.
            shears (list of float, optional): ``[min_x, max_x]`` or
                ``[min_x, max_x, min_y, max_y]`` shear ranges.
            img_size (list of int): image size as returned by ``F._get_image_size``.

        Returns:
            tuple: ``(angle, (tx, ty), scale, (shear_x, shear_y))``, the params to be
            passed to the affine transformation
        """
        angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item())
        if translate is not None:
            max_dx = float(translate[0] * img_size[0])
            max_dy = float(translate[1] * img_size[1])
            tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
            ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
            translations = (tx, ty)
        else:
            translations = (0, 0)

        if scale_ranges is not None:
            scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item())
        else:
            scale = 1.0

        shear_x = shear_y = 0.0
        if shears is not None:
            shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item())
            # A y-axis shear is sampled only when a 4-value range was given.
            if len(shears) == 4:
                shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item())

        shear = (shear_x, shear_y)

        return angle, translations, scale, shear

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Affine transformed image.
        """

        img_size = F._get_image_size(img)

        ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)
        return F.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor)

    def __repr__(self):
        s = '{name}(degrees={degrees}'
        if self.translate is not None:
            s += ', translate={translate}'
        if self.scale is not None:
            s += ', scale={scale}'
        if self.shear is not None:
            s += ', shear={shear}'
        if self.resample > 0:
            s += ', resample={resample}'
        if self.fillcolor != 0:
            s += ', fillcolor={fillcolor}'
        s += ')'
        d = dict(self.__dict__)
        d['resample'] = _pil_interpolation_to_str[d['resample']]
        return s.format(name=self.__class__.__name__, **d)


class Grayscale(object):
    """Convert image to grayscale.

    Args:
        num_output_channels (int): (1 or 3) number of channels desired for output image

    Returns:
        PIL Image: Grayscale version of the input.
         - If ``num_output_channels == 1`` : returned image is single channel
         - If ``num_output_channels == 3`` : returned image is 3 channel with r == g == b

    """

    def __init__(self, num_output_channels=1):
        self.num_output_channels = num_output_channels

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be converted to grayscale.

        Returns:
            PIL Image: Grayscaled image.
        """
        return F.to_grayscale(img, num_output_channels=self.num_output_channels)

    def __repr__(self):
        return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels)

class RandomGrayscale(object):
    """Randomly convert an image to grayscale with probability ``p`` (default 0.1).

    Args:
        p (float): probability that image should be converted to grayscale.

    Returns:
        PIL Image: Grayscale version of the input image with probability p and unchanged
        with probability (1-p).
        - If input image is 1 channel: grayscale version is 1 channel
        - If input image is 3 channel: grayscale version is 3 channel with r == g == b

    """

    def __init__(self, p=0.1):
        self.p = p

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be converted to grayscale.

        Returns:
            PIL Image: Randomly grayscaled image.
        """
        # Keep the channel count: single-band 'L' images stay 1 channel,
        # every other mode is converted with 3 output channels.
        out_channels = 3
        if img.mode == 'L':
            out_channels = 1
        # Leave the image untouched unless the draw falls below p.
        if not random.random() < self.p:
            return img
        return F.to_grayscale(img, num_output_channels=out_channels)

    def __repr__(self):
        return '{0}(p={1})'.format(self.__class__.__name__, self.p)


class RandomErasing(torch.nn.Module):
    """Randomly selects a rectangle region in an image and erases its pixels.

    'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/pdf/1708.04896.pdf

    Args:
         p: probability that the random erasing operation will be performed.
         scale: range of proportion of erased area against input image.
         ratio: range of aspect ratio (height / width) of erased area.
         value: erasing value. Default is 0. If a single int, it is used to
            erase all pixels. If a tuple of length 3, it is used to erase
            R, G, B channels respectively. In general a sequence must have
            either a single element or one element per input channel.
            If a str of 'random', erasing each pixel with random values.
         inplace: boolean to make this transform inplace. Default set to False.

    Returns:
        Erased Image.

    Example:
        >>> transform = transforms.Compose([
        >>>   transforms.RandomHorizontalFlip(),
        >>>   transforms.ToTensor(),
        >>>   transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>>   transforms.RandomErasing(),
        >>> ])
    """

    def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
        super().__init__()
        if not isinstance(value, (numbers.Number, str, tuple, list)):
            raise TypeError("Argument value should be either a number or str or a sequence")
        if isinstance(value, str) and value != "random":
            raise ValueError("If value is str, it should be 'random'")
        if not isinstance(scale, (tuple, list)):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, (tuple, list)):
            raise TypeError("Ratio should be a sequence")
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            # Only a warning: construction continues with the bounds as given.
            warnings.warn("Scale and ratio should be of kind (min, max)")
        if scale[0] < 0 or scale[1] > 1:
            raise ValueError("Scale should be between 0 and 1")
        if p < 0 or p > 1:
            raise ValueError("Random erasing probability should be between 0 and 1")

        self.p = p
        self.scale = scale
        self.ratio = ratio
        self.value = value
        self.inplace = inplace

    @staticmethod
    def get_params(
            img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None
    ) -> Tuple[int, int, int, int, Tensor]:
        """Get parameters for ``erase`` for a random erasing.

        Args:
            img (Tensor): Tensor image of size (..., C, H, W) to be erased. Leading
                dimensions (e.g. a batch dimension) are allowed; only the last three
                are inspected.
            scale (tuple or list): range of proportion of erased area against input image.
            ratio (tuple or list): range of aspect ratio (height / width) of erased area.
            value (list, optional): erasing value. If None, it is interpreted as "random"
                (erasing each pixel with random values). If ``len(value)`` is 1, it is interpreted as a number,
                i.e. ``value[0]``.

        Returns:
            tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
            If no rectangle strictly smaller than the image is drawn within 10
            attempts, ``(0, 0, H, W, img)`` is returned instead.
        """
        # Read only the trailing (C, H, W) dims so inputs with leading batch dims
        # work; forward() already validates ``value`` against ``img.shape[-3]``.
        img_c, img_h, img_w = img.shape[-3:]
        area = img_h * img_w

        for _ in range(10):
            erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
            aspect_ratio = torch.empty(1).uniform_(ratio[0], ratio[1]).item()

            h = int(round(math.sqrt(erase_area * aspect_ratio)))
            w = int(round(math.sqrt(erase_area / aspect_ratio)))
            # The rectangle must fit strictly inside the image; otherwise resample.
            if not (h < img_h and w < img_w):
                continue

            if value is None:
                v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
            else:
                # Shape (len(value), 1, 1): one value per channel (or one overall).
                v = torch.tensor(value)[:, None, None]

            i = torch.randint(0, img_h - h + 1, size=(1, )).item()
            j = torch.randint(0, img_w - w + 1, size=(1, )).item()
            return i, j, h, w, v

        # Return original image
        return 0, 0, img_h, img_w, img

    def forward(self, img):
        """
        Args:
            img (Tensor): Tensor image of size (C, H, W) to be erased.

        Returns:
            img (Tensor): Erased Tensor image.
        """
        if torch.rand(1) < self.p:

            # cast self.value to script acceptable type
            if isinstance(self.value, (int, float)):
                value = [self.value, ]
            elif isinstance(self.value, str):
                value = None
            elif isinstance(self.value, tuple):
                value = list(self.value)
            else:
                value = self.value

            if value is not None and not (len(value) in (1, img.shape[-3])):
                raise ValueError(
                    "If value is a sequence, it should have either a single value or "
                    "{} (number of input channels)".format(img.shape[-3])
                )

            x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value)
            return F.erase(img, x, y, h, w, v, self.inplace)
        return img