transforms.py 48.5 KB
Newer Older
1
2
3
from __future__ import division
import torch
import math
Tongzhou Wang's avatar
Tongzhou Wang committed
4
import sys
5
import random
6
from PIL import Image
7
8
9
10
11
12
13
14
15
16
17
18
try:
    import accimage
except ImportError:
    accimage = None
import numpy as np
import numbers
import types
import collections
import warnings

from . import functional as F

Tongzhou Wang's avatar
Tongzhou Wang committed
19
20
21
22
23
24
25
26
# ``Sequence`` / ``Iterable`` ABC aliases used for argument validation below.
# Import the ``collections.abc`` submodule explicitly: a bare ``import collections``
# does not guarantee that the ``collections.abc`` attribute exists. Interpreters that
# predate ``collections.abc`` (Python < 3.3) expose the ABCs on ``collections``.
try:
    from collections.abc import Sequence, Iterable
except ImportError:
    from collections import Sequence, Iterable


27
# Public transforms exported by ``from ... import *``.
__all__ = [
    "Compose",
    "ToTensor",
    "ToPILImage",
    "Normalize",
    "Resize",
    "Scale",
    "CenterCrop",
    "Pad",
    "Lambda",
    "RandomApply",
    "RandomChoice",
    "RandomOrder",
    "RandomCrop",
    "RandomHorizontalFlip",
    "RandomVerticalFlip",
    "RandomResizedCrop",
    "RandomSizedCrop",
    "FiveCrop",
    "TenCrop",
    "LinearTransformation",
    "ColorJitter",
    "RandomRotation",
    "RandomAffine",
    "Grayscale",
    "RandomGrayscale",
    "RandomPerspective",
    "RandomErasing",
]
32

33
34
35
36
37
# Maps each supported PIL resampling constant to its human-readable name; used by
# the ``__repr__`` of transforms that take an ``interpolation`` argument.
_pil_interpolation_to_str = {
    getattr(Image, _mode_name): 'PIL.Image.' + _mode_name
    for _mode_name in ('NEAREST', 'BILINEAR', 'BICUBIC', 'LANCZOS', 'HAMMING', 'BOX')
}

42

Zhicheng Yan's avatar
Zhicheng Yan committed
43
44
45
46
47
48
49
50
51
def _get_image_size(img):
    if F._is_pil_image(img):
        return img.size
    elif isinstance(img, torch.Tensor) and img.dim() > 2:
        return img.shape[-2:][::-1]
    else:
        raise TypeError("Unexpected type {}".format(type(img)))


52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
class Compose(object):
    """Chain several transforms so they run one after another.

    Each transform receives the output of the previous one; the result of the last
    transform is returned.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result

    def __repr__(self):
        # One indented line per wrapped transform, closed by ")" on its own line.
        lines = [self.__class__.__name__ + '(']
        lines.extend('    {0}'.format(transform) for transform in self.transforms)
        return '\n'.join(lines) + '\n)'

81
82
83
84
85

class ToTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
    or if the numpy.ndarray has dtype = np.uint8

    In the other cases, tensors are returned without scaling.

    The conversion itself is delegated to ``functional.to_tensor``.
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        converted = F.to_tensor(pic)
        return converted

    def __repr__(self):
        return '{0}()'.format(self.__class__.__name__)

106
107
108
109
110
111
112
113
114
115

class ToPILImage(object):
    """Convert a tensor or an ndarray to PIL Image.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL Image while preserving the value range.

    Args:
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
            If ``mode`` is ``None`` (default) there are some assumptions made about the input data:
             - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
             - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
             - If the input has 2 channels, the ``mode`` is assumed to be ``LA``.
             - If the input has 1 channel, the ``mode`` is determined by the data type (i.e ``int``, ``float``,
               ``short``).

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
    """

    def __init__(self, mode=None):
        self.mode = mode

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.

        Returns:
            PIL Image: Image converted to PIL Image.

        """
        return F.to_pil_image(pic, self.mode)

    def __repr__(self):
        # The mode is only shown when one was explicitly configured.
        mode_part = '' if self.mode is None else 'mode={0}'.format(self.mode)
        return '{0}({1})'.format(self.__class__.__name__, mode_part)
144

145
146

class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.

    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
    will normalize each channel of the input ``torch.*Tensor`` i.e.
    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``

    .. note::
        By default (``inplace=False``) this transform acts out of place, i.e., it does
        not mutate the input tensor. With ``inplace=True`` the operation is requested
        to be performed in-place on the input tensor.

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace (bool, optional): Bool to make this operation in-place. Default: ``False``.
    """

    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized Tensor image.
        """
        return F.normalize(tensor, self.mean, self.std, self.inplace)

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

180
181
182
183
184
185
186
187
188
189
190
191
192
193
194

class Resize(object):
    """Resize the input PIL Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # Accept either a single int or a 2-element iterable (h, w).
        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be scaled.

        Returns:
            PIL Image: Rescaled image.
        """
        return F.resize(img, self.size, self.interpolation)

    def __repr__(self):
        return '{0}(size={1}, interpolation={2})'.format(
            self.__class__.__name__, self.size, _pil_interpolation_to_str[self.interpolation])
212

213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248

class Scale(Resize):
    """Deprecated alias of ``Resize``.

    Note: This transform is deprecated in favor of Resize. Constructing it emits a
    warning and then forwards every argument to ``Resize``.
    """

    def __init__(self, *args, **kwargs):
        warnings.warn("The use of the transforms.Scale transform is deprecated, "
                      "please use transforms.Resize instead.")
        super(Scale, self).__init__(*args, **kwargs)


class CenterCrop(object):
    """Crops the given PIL Image at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is a number
            instead of a sequence like (h, w), a square crop ``(int(size), int(size))``
            is made.
    """

    def __init__(self, size):
        self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        return F.center_crop(img, self.size)

    def __repr__(self):
        return '{0}(size={1})'.format(self.__class__.__name__, self.size)

252
253
254
255
256
257
258
259
260
261

class Pad(object):
    """Pad the given PIL Image on all sides with the given "pad" value.

    Args:
        padding (int or tuple): Padding on each border. A single int pads all four
            borders by that amount. A tuple of length 2 gives the padding for
            left/right and top/bottom respectively. A tuple of length 4 gives the
            padding for the left, top, right and bottom borders respectively.
        fill (number, str or tuple): Pixel fill value for constant fill. Default is 0.
            If a tuple of length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant.
        padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image

            - reflect: pads with reflection of image without repeating the last value on the edge

                For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
                will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge

                For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
                will result in [2, 1, 1, 2, 3, 4, 4, 3]

    Raises:
        ValueError: If ``padding`` is a sequence whose length is neither 2 nor 4.
    """

    def __init__(self, padding, fill=0, padding_mode='constant'):
        assert isinstance(padding, (numbers.Number, tuple))
        assert isinstance(fill, (numbers.Number, str, tuple))
        assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
        if isinstance(padding, Sequence):
            n_values = len(padding)
            if n_values not in (2, 4):
                raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                                 "{} element tuple".format(n_values))

        self.padding, self.fill, self.padding_mode = padding, fill, padding_mode

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be padded.

        Returns:
            PIL Image: Padded image.
        """
        return F.pad(img, self.padding, self.fill, self.padding_mode)

    def __repr__(self):
        return '{0}(padding={1}, fill={2}, padding_mode={3})'.format(
            self.__class__.__name__, self.padding, self.fill, self.padding_mode)
308

309
310
311
312
313
314
315
316
317

class Lambda(object):
    """Wrap an arbitrary user-supplied callable so it can be used as a transform.

    Args:
        lambd (function): Lambda/function to be used for transform. Must be callable;
            a non-callable triggers an ``AssertionError`` naming its type.
    """

    def __init__(self, lambd):
        assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
        self.lambd = lambd

    def __call__(self, img):
        func = self.lambd
        return func(img)

    def __repr__(self):
        return '{0}()'.format(self.__class__.__name__)

327

328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
class RandomTransforms(object):
    """Base class for a list of transformations with randomness.

    Subclasses override ``__call__`` to decide which of the stored transforms run.

    Args:
        transforms (list or tuple): list of transformations
    """

    def __init__(self, transforms):
        assert isinstance(transforms, (list, tuple))
        self.transforms = transforms

    def __call__(self, *args, **kwargs):
        raise NotImplementedError()

    def __repr__(self):
        lines = [self.__class__.__name__ + '(']
        lines.extend('    {0}'.format(t) for t in self.transforms)
        return '\n'.join(lines) + '\n)'


class RandomApply(RandomTransforms):
    """Apply randomly a list of transformations with a given probability.

    With probability ``p`` every transform in the list is applied in order;
    otherwise the input is returned unchanged.

    Args:
        transforms (list or tuple): list of transformations
        p (float): probability of applying the transformations. Default: 0.5
    """

    def __init__(self, transforms, p=0.5):
        super(RandomApply, self).__init__(transforms)
        self.p = p

    def __call__(self, img):
        # random.random() is uniform on [0, 1), so the strict ``<`` test applies the
        # transforms with probability exactly ``p`` and never when ``p == 0``
        # (the same convention RandomHorizontalFlip/RandomVerticalFlip use).
        if random.random() >= self.p:
            return img
        for t in self.transforms:
            img = t(img)
        return img

    def __repr__(self):
        lines = [self.__class__.__name__ + '(', '    p={}'.format(self.p)]
        lines.extend('    {0}'.format(t) for t in self.transforms)
        return '\n'.join(lines) + '\n)'


class RandomOrder(RandomTransforms):
    """Apply a list of transformations in a random order.
    """
    def __call__(self, img):
        # Shuffle indices (not the stored list) so ``self.transforms`` stays untouched.
        order = list(range(len(self.transforms)))
        random.shuffle(order)
        for i in order:
            img = self.transforms[i](img)
        return img


class RandomChoice(RandomTransforms):
    """Apply single transformation randomly picked from a list.
    """
    def __call__(self, img):
        t = random.choice(self.transforms)
        return t(img)


399
400
401
402
403
404
405
406
class RandomCrop(object):
    """Crop the given PIL Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is None, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively. If a sequence of length 2 is provided, it is used to
            pad left/right, top/bottom borders, respectively.
        pad_if_needed (boolean): It will pad the image if smaller than the
            desired size to avoid raising an exception. Since cropping is done
            after padding, the padding seems to be done at a random offset.
        fill: Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
            This value is only used when the padding_mode is constant
        padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.

             - constant: pads with a constant value, this value is specified with fill

             - edge: pads with the last value on the edge of the image

             - reflect: pads with reflection of image (without repeating the last value on the edge)

                padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
                will result in [3, 2, 1, 2, 3, 4, 3, 2]

             - symmetric: pads with reflection of image (repeating the last value on the edge)

                padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
                will result in [2, 1, 1, 2, 3, 4, 4, 3]

    Raises:
        ValueError: (from ``get_params``) if, after any padding, the image is still
            smaller than the requested crop size.
    """

    def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        self.padding_mode = padding_mode

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop, as (h, w).

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.

        Raises:
            ValueError: If the image is smaller than ``output_size`` in either dimension.
        """
        w, h = _get_image_size(img)
        th, tw = output_size
        # Without this check random.randint fails with an opaque "empty range" error.
        if h < th or w < tw:
            raise ValueError("Required crop size {} is larger than input image size {}"
                             .format((th, tw), (h, w)))
        if w == tw and h == th:
            return 0, 0, h, w

        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        if self.padding is not None:
            img = F.pad(img, self.padding, self.fill, self.padding_mode)

        # pad the width if needed; a 2-element padding pads left and right by the
        # same amount, so the crop position then appears randomly offset.
        width, _ = _get_image_size(img)
        if self.pad_if_needed and width < self.size[1]:
            img = F.pad(img, (self.size[1] - width, 0), self.fill, self.padding_mode)
        # pad the height if needed
        _, height = _get_image_size(img)
        if self.pad_if_needed and height < self.size[0]:
            img = F.pad(img, (0, self.size[0] - height), self.fill, self.padding_mode)

        i, j, h, w = self.get_params(img, self.size)

        return F.crop(img, i, j, h, w)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
489

490
491

class RandomHorizontalFlip(object):
    """Randomly mirror the given PIL Image left-to-right.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly flipped image.
        """
        should_flip = random.random() < self.p
        return F.hflip(img) if should_flip else img

    def __repr__(self):
        return '{0}(p={1})'.format(self.__class__.__name__, self.p)
515

516
517

class RandomVerticalFlip(object):
    """Randomly mirror the given PIL Image top-to-bottom.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly flipped image.
        """
        should_flip = random.random() < self.p
        return F.vflip(img) if should_flip else img

    def __repr__(self):
        return '{0}(p={1})'.format(self.__class__.__name__, self.p)
541

542

543
544
545
546
547
548
549
550
551
552
class RandomPerspective(object):
    """Performs a perspective transformation of the given PIL Image randomly with a given probability.

    Args:
        distortion_scale (float): controls the degree of distortion and ranges from 0 to 1.
            Default value is 0.5.

        p (float): probability of the image being perspectively transformed. Default value is 0.5

        interpolation: Default: ``Image.BICUBIC``

        fill (3-tuple or int): RGB pixel fill value for the area outside the transformed image.
            If int, it is used for all channels respectively. Default value is 0.
    """

    def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BICUBIC, fill=0):
        self.p = p
        self.interpolation = interpolation
        self.distortion_scale = distortion_scale
        self.fill = fill

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be perspectively transformed.

        Returns:
            PIL Image: Randomly perspectively transformed image.

        Raises:
            TypeError: If ``img`` is not a PIL Image.
        """
        if not F._is_pil_image(img):
            raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

        apply_transform = random.random() < self.p
        if not apply_transform:
            return img
        width, height = img.size
        startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
        return F.perspective(img, startpoints, endpoints, self.interpolation, self.fill)

    @staticmethod
    def get_params(width, height, distortion_scale):
        """Get parameters for ``perspective`` for a random perspective transform.

        Each corner is moved inwards by a random amount of at most
        ``distortion_scale`` times half the image width (horizontally) or height
        (vertically).

        Args:
            width : width of the image.
            height : height of the image.
            distortion_scale : fraction of half the width/height a corner may move.

        Returns:
            tuple: ``(startpoints, endpoints)``. Both are lists ordered
            [top-left, top-right, bottom-right, bottom-left]; ``startpoints`` are the
            corners of the original image and ``endpoints`` those of the transformed image.
        """
        half_height = int(height / 2)
        half_width = int(width / 2)
        max_dx = int(distortion_scale * half_width)
        max_dy = int(distortion_scale * half_height)
        right_low = width - max_dx - 1
        bottom_low = height - max_dy - 1
        # The random draws below must stay in this order: (x, y) per corner,
        # top-left, top-right, bottom-right, bottom-left.
        topleft = (random.randint(0, max_dx), random.randint(0, max_dy))
        topright = (random.randint(right_low, width - 1), random.randint(0, max_dy))
        botright = (random.randint(right_low, width - 1), random.randint(bottom_low, height - 1))
        botleft = (random.randint(0, max_dx), random.randint(bottom_low, height - 1))
        startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)]
        return startpoints, [topleft, topright, botright, botleft]

    def __repr__(self):
        return '{0}(p={1})'.format(self.__class__.__name__, self.p)


610
611
612
class RandomResizedCrop(object):
    """Crop the given PIL Image to random size and aspect ratio.

    A crop of random size (default: of 0.08 to 1.0) of the original size and a random
    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
    is finally resized to given size.
    This is popularly used to train the Inception networks.

    Args:
        size: expected output size of each edge. A tuple or list is used as-is;
            anything else ``s`` becomes ``(s, s)``.
        scale: range of size of the origin size cropped
        ratio: range of aspect ratio of the origin aspect ratio cropped
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
        if isinstance(size, (tuple, list)):
            self.size = size
        else:
            self.size = (size, size)
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("range should be of kind (min, max)")

        self.interpolation = interpolation
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img, scale, ratio):
        """Get parameters for ``crop`` for a random sized crop.

        Up to 10 random (area, aspect ratio) pairs are tried. If none of them fits
        inside the image, a central crop with the aspect ratio clamped to ``ratio``
        is returned instead.

        Args:
            img (PIL Image): Image to be cropped.
            scale (tuple): range of size of the origin size cropped
            ratio (tuple): range of aspect ratio of the origin aspect ratio cropped

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop.
        """
        width, height = _get_image_size(img)
        area = height * width
        # Aspect ratios are sampled uniformly in log space; the bounds are invariant.
        log_ratio = (math.log(ratio[0]), math.log(ratio[1]))

        for _ in range(10):
            target_area = random.uniform(*scale) * area
            aspect_ratio = math.exp(random.uniform(*log_ratio))

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if 0 < w <= width and 0 < h <= height:
                i = random.randint(0, height - h)
                j = random.randint(0, width - w)
                return i, j, h, w

        # Fallback to central crop
        in_ratio = float(width) / float(height)
        if in_ratio < min(ratio):
            w = width
            h = int(round(w / min(ratio)))
        elif in_ratio > max(ratio):
            h = height
            w = int(round(h * max(ratio)))
        else:  # whole image
            w = width
            h = height
        i = (height - h) // 2
        j = (width - w) // 2
        return i, j, h, w

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped and resized.

        Returns:
            PIL Image: Randomly cropped and resized image.
        """
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)

    def __repr__(self):
        interpolate_str = _pil_interpolation_to_str[self.interpolation]
        format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
        format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
        format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
        format_string += ', interpolation={0})'.format(interpolate_str)
        return format_string
699

700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745

class RandomSizedCrop(RandomResizedCrop):
    """Deprecated alias of ``RandomResizedCrop``.

    Note: This transform is deprecated in favor of RandomResizedCrop. Constructing it
    emits a warning and then forwards every argument to ``RandomResizedCrop``.
    """

    def __init__(self, *args, **kwargs):
        warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, "
                      "please use transforms.RandomResizedCrop instead.")
        super(RandomSizedCrop, self).__init__(*args, **kwargs)


class FiveCrop(object):
    """Crop the given PIL Image into four corners and the central crop

    .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
         inputs and targets your Dataset returns. See below for an example of how to deal with
         this.

    Args:
         size (sequence or int): Desired output size of the crop. If size is an ``int``
            instead of sequence like (h, w), a square crop of size (size, size) is made.
            A sequence must have exactly two elements.

    Example:
         >>> transform = Compose([
         >>>    FiveCrop(size), # this is a list of PIL Images
         >>>    Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
         >>> ])
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            size = (int(size), int(size))
        else:
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
        self.size = size

    def __call__(self, img):
        return F.five_crop(img, self.size)

    def __repr__(self):
        return '{0}(size={1})'.format(self.__class__.__name__, self.size)

749
750
751
752
753
754
755
756
757
758
759
760
761
762

class TenCrop(object):
    """Crop the given PIL Image into four corners and the central crop plus the flipped version of
    these (horizontal flipping is used by default)

    .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
         inputs and targets your Dataset returns. See below for an example of how to deal with
         this.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. A sequence must have exactly two elements.
        vertical_flip (bool): Use vertical flipping instead of horizontal

    Example:
         >>> transform = Compose([
         >>>    TenCrop(size), # this is a list of PIL Images
         >>>    Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
         >>> ])
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """

    def __init__(self, size, vertical_flip=False):
        if isinstance(size, numbers.Number):
            size = (int(size), int(size))
        else:
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
        self.size = size
        self.vertical_flip = vertical_flip

    def __call__(self, img):
        return F.ten_crop(img, self.size, self.vertical_flip)

    def __repr__(self):
        return '{0}(size={1}, vertical_flip={2})'.format(
            self.__class__.__name__, self.size, self.vertical_flip)
791

792

793
class LinearTransformation(object):
    """Transform a tensor image with a square transformation matrix and a mean_vector computed
    offline.

    Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and
    subtract mean_vector from it which is then followed by computing the dot
    product with the transformation matrix and then reshaping the tensor to its
    original shape.

    Applications:
        whitening transformation: Suppose X is a column vector zero-centered data.
        Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X),
        perform SVD on this matrix and pass it as transformation_matrix.

    Args:
        transformation_matrix (Tensor): tensor [D x D], D = C x H x W
        mean_vector (Tensor): tensor [D], D = C x H x W

    Raises:
        ValueError: if transformation_matrix is not square, or if the length of
            mean_vector differs from the side length of transformation_matrix.
    """

    def __init__(self, transformation_matrix, mean_vector):
        if transformation_matrix.size(0) != transformation_matrix.size(1):
            raise ValueError("transformation_matrix should be square. Got " +
                             "[{} x {}] rectangular matrix.".format(*transformation_matrix.size()))

        if mean_vector.size(0) != transformation_matrix.size(0):
            # The size must be unpacked so both placeholders are filled; passing
            # the torch.Size as one argument made .format() raise IndexError
            # instead of this ValueError.
            raise ValueError("mean_vector should have the same length {}".format(mean_vector.size(0)) +
                             " as any one of the dimensions of the transformation_matrix [{} x {}]"
                             .format(*transformation_matrix.size()))

        self.transformation_matrix = transformation_matrix
        self.mean_vector = mean_vector

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be whitened.

        Returns:
            Tensor: Transformed image, with the same size as ``tensor``.

        Raises:
            ValueError: if C * H * W differs from the side length of the
                transformation matrix.
        """
        if tensor.size(0) * tensor.size(1) * tensor.size(2) != self.transformation_matrix.size(0):
            raise ValueError("tensor and transformation matrix have incompatible shape." +
                             "[{} x {} x {}] != ".format(*tensor.size()) +
                             "{}".format(self.transformation_matrix.size(0)))
        # reshape (unlike view) also accepts non-contiguous inputs such as
        # permuted/transposed tensors, copying only when necessary.
        flat_tensor = tensor.reshape(1, -1) - self.mean_vector
        transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
        tensor = transformed_tensor.view(tensor.size())
        return tensor

    def __repr__(self):
        format_string = self.__class__.__name__ + '(transformation_matrix='
        format_string += str(self.transformation_matrix.tolist())
        format_string += ', mean_vector=' + str(self.mean_vector.tolist()) + ')'
        return format_string

847
848
849
850
851

class ColorJitter(object):
    """Randomly change the brightness, contrast, saturation and hue of an image.

    Args:
        brightness (float or tuple of float (min, max)): How much to jitter brightness.
            brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
            or the given [min, max]. Should be non negative numbers.
        contrast (float or tuple of float (min, max)): How much to jitter contrast.
            contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
            or the given [min, max]. Should be non negative numbers.
        saturation (float or tuple of float (min, max)): How much to jitter saturation.
            saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
            or the given [min, max]. Should be non negative numbers.
        hue (float or tuple of float (min, max)): How much to jitter hue.
            hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
            Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.

    Raises:
        ValueError: if a single number is negative, or if the resulting
            [min, max] range falls outside the allowed bounds (e.g. hue > 0.5).
        TypeError: if an argument is neither a number nor a list/tuple of length 2.
    """
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = self._check_input(brightness, 'brightness')
        self.contrast = self._check_input(contrast, 'contrast')
        self.saturation = self._check_input(saturation, 'saturation')
        self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
                                     clip_first_on_zero=False)

    def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
        """Normalize one jitter argument into a [min, max] range, or None.

        A single number ``v`` becomes ``[center - v, center + v]`` (with the
        lower end clipped at 0 when ``clip_first_on_zero``); a 2-element
        list/tuple is used as given. In both cases the range must satisfy
        ``bound[0] <= min <= max <= bound[1]``. Returns None when the range is
        exactly ``[center, center]``, meaning this adjustment is skipped.
        """
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError("If {} is a single number, it must be non negative.".format(name))
            value = [center - value, center + value]
            if clip_first_on_zero:
                value[0] = max(value[0], 0)
        elif not (isinstance(value, (tuple, list)) and len(value) == 2):
            raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))

        # Checked for scalar inputs too: previously a single hue > 0.5 slipped
        # past the documented 0 <= hue <= 0.5 limit.
        if not bound[0] <= value[0] <= value[1] <= bound[1]:
            raise ValueError("{} values should be between {}".format(name, bound))

        # if value is 0 or (1., 1.) for brightness/contrast/saturation
        # or (0., 0.) for hue, do nothing
        if value[0] == value[1] == center:
            value = None
        return value

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Get a randomized transform to be applied on image.

        Arguments are same as that of __init__.

        Returns:
            Transform which randomly adjusts brightness, contrast,
            saturation and hue in a random order.
        """
        transforms = []

        if brightness is not None:
            brightness_factor = random.uniform(brightness[0], brightness[1])
            transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))

        if contrast is not None:
            contrast_factor = random.uniform(contrast[0], contrast[1])
            transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))

        if saturation is not None:
            saturation_factor = random.uniform(saturation[0], saturation[1])
            transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))

        if hue is not None:
            hue_factor = random.uniform(hue[0], hue[1])
            transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))

        # Apply the enabled adjustments in a random order.
        random.shuffle(transforms)
        transform = Compose(transforms)

        return transform

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Input image.

        Returns:
            PIL Image: Color jittered image.
        """
        transform = self.get_params(self.brightness, self.contrast,
                                    self.saturation, self.hue)
        return transform(img)

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        format_string += 'brightness={0}'.format(self.brightness)
        format_string += ', contrast={0}'.format(self.contrast)
        format_string += ', saturation={0}'.format(self.saturation)
        format_string += ', hue={0})'.format(self.hue)
        return format_string
943

944
945
946
947
948
949
950
951
952

class RandomRotation(object):
    """Rotate the image by angle.

    Args:
        degrees (sequence or float or int): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees).
        resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
            An optional resampling filter. See `filters`_ for more information.
            If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
        expand (bool, optional): Optional expansion flag.
            If true, expands the output to make it large enough to hold the entire rotated image.
            If false or omitted, make the output image the same size as the input image.
            Note that the expand flag assumes rotation around the center and no translation.
        center (2-tuple, optional): Optional center of rotation.
            Origin is the upper left corner.
            Default is the center of the image.
        fill (n-tuple or int or float): Pixel fill value for area outside the rotated
            image. If int or float, the value is used for all bands respectively.
            Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``.

    Raises:
        ValueError: if degrees is a negative number or a sequence whose length is not 2.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """

    def __init__(self, degrees, resample=False, expand=False, center=None, fill=None):
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError("If degrees is a single number, it must be positive.")
            self.degrees = (-degrees, degrees)
        else:
            if len(degrees) != 2:
                raise ValueError("If degrees is a sequence, it must be of len 2.")
            self.degrees = degrees

        self.resample = resample
        self.expand = expand
        self.center = center
        self.fill = fill

    @staticmethod
    def get_params(degrees):
        """Get parameters for ``rotate`` for a random rotation.

        Returns:
            sequence: params to be passed to ``rotate`` for random rotation.
        """
        angle = random.uniform(degrees[0], degrees[1])

        return angle

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be rotated.

        Returns:
            PIL Image: Rotated image.
        """

        angle = self.get_params(self.degrees)

        return F.rotate(img, angle, self.resample, self.expand, self.center, self.fill)

    def __repr__(self):
        format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees)
        format_string += ', resample={0}'.format(self.resample)
        format_string += ', expand={0}'.format(self.expand)
        if self.center is not None:
            format_string += ', center={0}'.format(self.center)
        # fill is forwarded to F.rotate in __call__, so show it like center.
        if self.fill is not None:
            format_string += ', fill={0}'.format(self.fill)
        format_string += ')'
        return format_string
1017

1018

1019
1020
1021
1022
1023
1024
class RandomAffine(object):
    """Random affine transformation of the image keeping center invariant

    Args:
        degrees (sequence or float or int): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range of degrees
            will be (-degrees, +degrees). Set to 0 to deactivate rotations.
        translate (tuple, optional): tuple of maximum absolute fraction for horizontal
            and vertical translations. For example translate=(a, b), then horizontal shift
            is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
            randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
        scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
            randomly sampled from the range a <= scale <= b. Will keep original scale by default.
        shear (sequence or float or int, optional): Range of degrees to select from.
            If shear is a number, a shear parallel to the x axis in the range (-shear, +shear)
            will be applied. Else if shear is a tuple or list of 2 values a shear parallel to the x axis in the
            range (shear[0], shear[1]) will be applied. Else if shear is a tuple or list of 4 values,
            a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
            Will not apply shear by default
        resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
            An optional resampling filter. See `filters`_ for more information.
            If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
        fillcolor (tuple or int): Optional fill color (Tuple for RGB Image And int for grayscale) for the area
            outside the transform in the output image.(Pillow>=5.0.0)

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """

    def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0):
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError("If degrees is a single number, it must be positive.")
            self.degrees = (-degrees, degrees)
        else:
            assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
                "degrees should be a list or tuple and it must be of length 2."
            self.degrees = degrees

        if translate is not None:
            assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
                "translate should be a list or tuple and it must be of length 2."
            for t in translate:
                if not (0.0 <= t <= 1.0):
                    raise ValueError("translation values should be between 0 and 1")
        self.translate = translate

        if scale is not None:
            assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
                "scale should be a list or tuple and it must be of length 2."
            for s in scale:
                if s <= 0:
                    raise ValueError("scale values should be positive")
        self.scale = scale

        if shear is not None:
            if isinstance(shear, numbers.Number):
                if shear < 0:
                    raise ValueError("If shear is a single number, it must be positive.")
                self.shear = (-shear, shear)
            else:
                assert isinstance(shear, (tuple, list)) and \
                    (len(shear) == 2 or len(shear) == 4), \
                    "shear should be a list or tuple and it must be of length 2 or 4."
                # X-Axis shear with [min, max]; the y-axis range is padded with zeros.
                if len(shear) == 2:
                    self.shear = [shear[0], shear[1], 0., 0.]
                elif len(shear) == 4:
                    self.shear = list(shear)
        else:
            self.shear = shear

        self.resample = resample
        self.fillcolor = fillcolor

    @staticmethod
    def get_params(degrees, translate, scale_ranges, shears, img_size):
        """Get parameters for affine transformation

        Args:
            degrees: (min, max) rotation range.
            translate: (a, b) maximum translation fractions, or None.
            scale_ranges: (min, max) scale range, or None.
            shears: shear range of length 2 (x only) or 4 (x and y), or None.
            img_size: (width, height) used to turn translate fractions into pixels.

        Returns:
            sequence: (angle, translations, scale, shear) to be passed to the
            affine transformation
        """
        angle = random.uniform(degrees[0], degrees[1])
        if translate is not None:
            max_dx = translate[0] * img_size[0]
            max_dy = translate[1] * img_size[1]
            translations = (np.round(random.uniform(-max_dx, max_dx)),
                            np.round(random.uniform(-max_dy, max_dy)))
        else:
            translations = (0, 0)

        if scale_ranges is not None:
            scale = random.uniform(scale_ranges[0], scale_ranges[1])
        else:
            scale = 1.0

        if shears is not None:
            if len(shears) == 2:
                shear = [random.uniform(shears[0], shears[1]), 0.]
            elif len(shears) == 4:
                shear = [random.uniform(shears[0], shears[1]),
                         random.uniform(shears[2], shears[3])]
        else:
            shear = 0.0

        return angle, translations, scale, shear

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be transformed.

        Returns:
            PIL Image: Affine transformed image.
        """
        ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size)
        return F.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor)

    def __repr__(self):
        s = '{name}(degrees={degrees}'
        if self.translate is not None:
            s += ', translate={translate}'
        if self.scale is not None:
            s += ', scale={scale}'
        if self.shear is not None:
            s += ', shear={shear}'
        if self.resample > 0:
            s += ', resample={resample}'
        if self.fillcolor != 0:
            s += ', fillcolor={fillcolor}'
        s += ')'
        d = dict(self.__dict__)
        # Fall back to the raw value for filters missing from the lookup
        # table instead of letting __repr__ raise KeyError.
        d['resample'] = _pil_interpolation_to_str.get(d['resample'], d['resample'])
        return s.format(name=self.__class__.__name__, **d)


1154
1155
class Grayscale(object):
    """Convert image to grayscale.

    Args:
        num_output_channels (int): (1 or 3) number of channels desired for output image

    Returns:
        PIL Image: Grayscale version of the input.
        - If num_output_channels == 1 : returned image is single channel
        - If num_output_channels == 3 : returned image is 3 channel with r == g == b

    """

    def __init__(self, num_output_channels=1):
        self.num_output_channels = num_output_channels

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be converted to grayscale.

        Returns:
            PIL Image: Grayscale version of ``img`` with ``num_output_channels``
            channels (the conversion is always applied, not randomly).
        """
        return F.to_grayscale(img, num_output_channels=self.num_output_channels)

    def __repr__(self):
        return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels)
1182

1183
1184
1185

class RandomGrayscale(object):
    """Convert an image to grayscale with probability ``p`` (default 0.1).

    Args:
        p (float): chance that the image is converted to grayscale.

    Returns:
        PIL Image: Grayscale version of the input image with probability p and unchanged
        with probability (1-p).
        - If input image is 1 channel: grayscale version is 1 channel
        - If input image is 3 channel: grayscale version is 3 channel with r == g == b

    """

    def __init__(self, p=0.1):
        self.p = p

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image that may be converted to grayscale.

        Returns:
            PIL Image: Either ``img`` itself or its grayscale version.
        """
        # Single-band 'L' images stay single-channel; every other mode is
        # rendered as three identical channels.
        channels = 3
        if img.mode == 'L':
            channels = 1
        if not random.random() < self.p:
            return img
        return F.to_grayscale(img, num_output_channels=channels)

    def __repr__(self):
        return '{0}(p={1})'.format(type(self).__name__, self.p)
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229


class RandomErasing(object):
    """Randomly selects a rectangle region in an image and erases its pixels.

    'Random Erasing Data Augmentation' by Zhong et al.
    See https://arxiv.org/pdf/1708.04896.pdf

    Args:
         p: probability that the random erasing operation will be performed.
         scale: range of proportion of erased area against input image.
         ratio: range of aspect ratio of erased area.
         value: erasing value. Default is 0. If a single int, it is used to
            erase all pixels. If a tuple of length 3, it is used to erase
            R, G, B channels respectively.
            If a str of 'random', erasing each pixel with random values.
         inplace: boolean to make this transform inplace. Default set to False.

    Raises:
        ValueError: if ``scale`` or ``p`` lies outside [0, 1], or if ``value``
            is a string other than 'random'.

    Returns:
        Erased Image.

    Example:
        >>> transform = transforms.Compose([
        >>>   transforms.RandomHorizontalFlip(),
        >>>   transforms.ToTensor(),
        >>>   transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>>   transforms.RandomErasing(),
        >>> ])
    """

    def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
        assert isinstance(value, (numbers.Number, str, tuple, list))
        # Only the literal 'random' is meaningful; any other string used to be
        # silently treated as 'random' by get_params.
        if isinstance(value, str) and value != 'random':
            raise ValueError("If value is str, it should be 'random'")
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("range should be of kind (min, max)")
        if scale[0] < 0 or scale[1] > 1:
            raise ValueError("range of scale should be between 0 and 1")
        if p < 0 or p > 1:
            raise ValueError("range of random erasing probability should be between 0 and 1")

        self.p = p
        self.scale = scale
        self.ratio = ratio
        self.value = value
        self.inplace = inplace

    @staticmethod
    def get_params(img, scale, ratio, value=0):
        """Get parameters for ``erase`` for a random erasing.

        Args:
            img (Tensor): Tensor image of size (..., C, H, W) to be erased.
            scale: range of proportion of erased area against input image.
            ratio: range of aspect ratio of erased area.
            value: erasing value: a number, the string 'random', or a
                per-channel list/tuple.

        Returns:
            tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
            If no region fitting inside the image is drawn within 10 attempts,
            returns (0, 0, H, W, img), i.e. the image itself as the value.
        """
        # Only the trailing (C, H, W) dims are read, so leading batch dims are accepted.
        img_c, img_h, img_w = img.shape[-3:]
        area = img_h * img_w

        for _ in range(10):
            erase_area = random.uniform(scale[0], scale[1]) * area
            aspect_ratio = random.uniform(ratio[0], ratio[1])

            h = int(round(math.sqrt(erase_area * aspect_ratio)))
            w = int(round(math.sqrt(erase_area / aspect_ratio)))

            if h < img_h and w < img_w:
                i = random.randint(0, img_h - h)
                j = random.randint(0, img_w - w)
                if isinstance(value, numbers.Number):
                    v = value
                elif isinstance(value, str):
                    # 'random': fill the region with standard-normal noise.
                    v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
                elif isinstance(value, (list, tuple)):
                    # Per-channel constant, broadcast over the h x w region.
                    v = torch.tensor(value, dtype=torch.float32).view(-1, 1, 1).expand(-1, h, w)
                return i, j, h, w, v

        # Return original image
        return 0, 0, img_h, img_w, img

    def __call__(self, img):
        """
        Args:
            img (Tensor): Tensor image of size (C, H, W) to be erased.

        Returns:
            img (Tensor): Erased Tensor image.
        """
        if random.uniform(0, 1) < self.p:
            x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=self.value)
            return F.erase(img, x, y, h, w, v, self.inplace)
        return img