transforms.py 24 KB
Newer Older
1
from __future__ import division
soumith's avatar
soumith committed
2
3
4
import torch
import math
import random
5
from PIL import Image, ImageOps
6
7
8
9
try:
    import accimage
except ImportError:
    accimage = None
10
import numpy as np
11
import numbers
Soumith Chintala's avatar
Soumith Chintala committed
12
import types
13
import collections
14
import warnings
soumith's avatar
soumith committed
15

16

Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
def _is_pil_image(img):
    """Return True if ``img`` is a PIL image (or an accimage image when accimage is available)."""
    if accimage is None:
        return isinstance(img, Image.Image)
    return isinstance(img, (Image.Image, accimage.Image))


def _is_tensor_image(img):
    return torch.is_tensor(img) and img.ndimension() == 3


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
32
def to_tensor(pic):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.

    See ``ToTensor`` for more details.

    Args:
        pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.

    Raises:
        TypeError: If ``pic`` is neither a PIL image nor a 2-D/3-D ndarray.
    """
    if not(_is_pil_image(pic) or _is_numpy_image(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

    if isinstance(pic, np.ndarray):
        # handle numpy array
        # Fix: _is_numpy_image accepts 2-D (H x W) arrays, but the
        # transpose((2, 0, 1)) below requires 3 dimensions and used to crash.
        # Treat a 2-D array as a single-channel image.
        if pic.ndim == 2:
            pic = pic[:, :, None]
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility
        return img.float().div(255)

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    # Byte data (most PIL modes) is rescaled to [0, 1]; int32/int16 images
    # (modes 'I' / 'I;16') are returned unscaled, matching prior behavior.
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    else:
        return img


81
def to_pil_image(pic):
    """Convert a tensor or an ndarray to PIL Image.

    See ``ToPIlImage`` for more details.

    Args:
        pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.

    Returns:
        PIL.Image: Image converted to PIL.Image.
    """
    if not(_is_numpy_image(pic) or _is_tensor_image(pic)):
        raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))

    npimg = pic
    mode = None
    # Float tensors are assumed to hold values in [0, 1]; rescale to byte range.
    if isinstance(pic, torch.FloatTensor):
        pic = pic.mul(255).byte()
    if torch.is_tensor(pic):
        # CHW tensor -> HWC ndarray for PIL.
        npimg = np.transpose(pic.numpy(), (1, 2, 0))
    assert isinstance(npimg, np.ndarray)

    channels = npimg.shape[2]
    if channels == 1:
        npimg = npimg[:, :, 0]
        # Pick the PIL mode from the dtype (dtypes are mutually exclusive,
        # so this chain is equivalent to the separate checks it replaces).
        if npimg.dtype == np.uint8:
            mode = 'L'
        elif npimg.dtype == np.int16:
            mode = 'I;16'
        elif npimg.dtype == np.int32:
            mode = 'I'
        elif npimg.dtype == np.float32:
            mode = 'F'
    elif channels == 4:
        if npimg.dtype == np.uint8:
            mode = 'RGBA'
    else:
        if npimg.dtype == np.uint8:
            mode = 'RGB'
    assert mode is not None, '{} is not supported'.format(npimg.dtype)
    return Image.fromarray(npimg, mode=mode)


def normalize(tensor, mean, std):
    """Normalize a tensor image with mean and standard deviation, in place.

    See ``Normalize`` for more details.

    Args:
        tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        mean (sequence): Sequence of means for R, G, B channels respectively.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respectively.

    Returns:
        Tensor: Normalized image (the same object as the input ``tensor``).
    """
    if not _is_tensor_image(tensor):
        raise TypeError('tensor is not a torch image.')
    # TODO: make efficient
    for channel, channel_mean, channel_std in zip(tensor, mean, std):
        channel.sub_(channel_mean).div_(channel_std)
    return tensor


145
146
def resize(img, size, interpolation=Image.BILINEAR):
    """Resize the input PIL.Image to the given size.

    Args:
        img (PIL.Image): Image to be resized.
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), the output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number maintaing
            the aspect ratio. i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``

    Returns:
        PIL.Image: Resized image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)):
        raise TypeError('Got inappropriate size arg: {}'.format(size))

    if not isinstance(size, int):
        # ``size`` is (h, w) but PIL expects (w, h), hence the reversal.
        return img.resize(size[::-1], interpolation)

    w, h = img.size
    shorter = w if w <= h else h
    if shorter == size:
        # The short edge already matches; nothing to do.
        return img
    if w < h:
        return img.resize((size, int(size * h / w)), interpolation)
    return img.resize((int(size * w / h), size), interpolation)
Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
180
181


182
183
184
185
186
187
def scale(*args, **kwargs):
    """Deprecated alias for :func:`resize`; warns and forwards all arguments."""
    warnings.warn("The use of the transforms.Scale transform is deprecated, " +
                  "please use transforms.Resize instead.")
    return resize(*args, **kwargs)


Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
188
def pad(img, padding, fill=0):
    """Pad the given PIL.Image on all sides with the given "pad" value.

    Args:
        img (PIL.Image): Image to be padded.
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.

    Returns:
        PIL.Image: Padded image.

    Raises:
        TypeError: If ``img``, ``padding`` or ``fill`` has an unsupported type.
        ValueError: If ``padding`` is a sequence whose length is not 2 or 4.
    """
    # Validation order is preserved: image type first, then padding, then fill.
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    if not isinstance(padding, (numbers.Number, tuple)):
        raise TypeError('Got inappropriate padding arg')
    if not isinstance(fill, (numbers.Number, str, tuple)):
        raise TypeError('Got inappropriate fill arg')
    if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
        raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                         "{} element tuple".format(len(padding)))

    return ImageOps.expand(img, border=padding, fill=fill)


219
def crop(img, i, j, h, w):
    """Crop the given PIL.Image.

    Args:
        img (PIL.Image): Image to be cropped.
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.

    Returns:
        PIL.Image: Cropped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    # PIL's crop box is (left, upper, right, lower).
    left, upper = j, i
    return img.crop((left, upper, left + w, upper + h))
Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
236
237


238
239
def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
    """Crop the given PIL.Image and resize it to desired size.

    Notably used in RandomResizedCrop.

    Args:
        img (PIL.Image): Image to be cropped.
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.
        size (sequence or int): Desired output size. Same semantics as ``scale``.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``.

    Returns:
        PIL.Image: Cropped image.

    Raises:
        TypeError: If ``img`` is not a PIL image.
    """
    # Raise TypeError like every other functional in this module, instead of
    # the previous ``assert``: asserts are stripped under ``python -O`` and
    # also raised an inconsistent exception type for callers.
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    img = crop(img, i, j, h, w)
    img = resize(img, size, interpolation)
    return img
Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
259
260
261


def hflip(img):
    """Horizontally flip the given PIL.Image.

    Args:
        img (PIL.Image): Image to be flipped.

    Returns:
        PIL.Image:  Horizontall flipped image.
    """
    if _is_pil_image(img):
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))


276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
def vflip(img):
    """Vertically flip the given PIL.Image.

    Args:
        img (PIL.Image): Image to be flipped.

    Returns:
        PIL.Image:  Vertically flipped image.
    """
    if _is_pil_image(img):
        return img.transpose(Image.FLIP_TOP_BOTTOM)
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))


291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
def five_crop(img, size):
    """Crop the given PIL.Image into four corners and the central crop.

    Note: this transform returns a tuple of images and there may be a mismatch in the number of
    inputs and targets your `Dataset` returns.

    Args:
       size (sequence or int): Desired output size of the crop. If size is an
           int instead of sequence like (h, w), a square crop (size, size) is
           made.
    Returns:
        tuple: tuple (tl, tr, bl, br, center) corresponding top left,
            top right, bottom left, bottom right and center crop.
    """
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    else:
        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

    w, h = img.size
    crop_h, crop_w = size
    if crop_w > w or crop_h > h:
        raise ValueError("Requested crop size {} is bigger than input size {}".format(size,
                                                                                      (h, w)))

    top_left = img.crop((0, 0, crop_w, crop_h))
    top_right = img.crop((w - crop_w, 0, w, crop_h))
    bottom_left = img.crop((0, h - crop_h, crop_w, h))
    bottom_right = img.crop((w - crop_w, h - crop_h, w, h))
    # Center crop, inlined from CenterCrop.get_params + the crop functional.
    top = int(round((h - crop_h) / 2.))
    left = int(round((w - crop_w) / 2.))
    center = crop(img, top, left, crop_h, crop_w)
    return (top_left, top_right, bottom_left, bottom_right, center)


def ten_crop(img, size, vertical_flip=False):
    """Crop the given PIL.Image into four corners and the central crop plus the
       flipped version of these (horizontal flipping is used by default).

       Note: this transform returns a tuple of images and there may be a mismatch in the number of
       inputs and targets your `Dataset` returns.

       Args:
           size (sequence or int): Desired output size of the crop. If size is an
               int instead of sequence like (h, w), a square crop (size, size) is
               made.
           vertical_flip (bool): Use vertical flipping instead of horizontal

        Returns:
            tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip,
                br_flip, center_flip) corresponding top left, top right,
                bottom left, bottom right and center crop and same for the
                flipped image.
    """
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    else:
        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

    first_five = five_crop(img, size)
    # Crop the flipped image the same way and append those five crops.
    flipped = vflip(img) if vertical_flip else hflip(img)
    second_five = five_crop(flipped, size)
    return first_five + second_five


soumith's avatar
soumith committed
358
class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        # Callables are applied in list order.
        self.transforms = transforms

    def __call__(self, img):
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result


class ToTensor(object):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        # Stateless wrapper: all conversion logic lives in the module-level
        # ``to_tensor`` functional.
        return to_tensor(pic)
396

Adam Paszke's avatar
Adam Paszke committed
397

398
class ToPILImage(object):
    """Convert a tensor or an ndarray to PIL Image.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL.Image while preserving the value range.
    """

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.

        Returns:
            PIL.Image: Image converted to PIL.Image.

        """
        # Stateless wrapper: all conversion logic lives in the module-level
        # ``to_pil_image`` functional.
        return to_pil_image(pic)
415

soumith's avatar
soumith committed
416
417

class Normalize(object):
    """Normalize an tensor image with mean and standard deviation.

    Given mean: (R, G, B) and std: (R, G, B),
    will normalize each channel of the torch.*Tensor, i.e.
    channel = (channel - mean) / std

    Args:
        mean (sequence): Sequence of means for R, G, B channels respectively.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respectively.
    """

    def __init__(self, mean, std):
        # Stored as given; consumed channel-by-channel by ``normalize``.
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized image.
        """
        return normalize(tensor, self.mean, self.std)
soumith's avatar
soumith committed
443
444


445
446
class Resize(object):
    """Resize the input PIL.Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # ``size`` must be an int or an (h, w) pair.
        is_pair = isinstance(size, collections.Iterable) and len(size) == 2
        assert isinstance(size, int) or is_pair
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be scaled.

        Returns:
            PIL.Image: Rescaled image.
        """
        return resize(img, self.size, self.interpolation)


class Scale(Resize):
    """Deprecated alias kept for backward compatibility; use ``Resize``."""

    def __init__(self, *args, **kwargs):
        warnings.warn("The use of the transforms.Scale transform is deprecated, " +
                      "please use transforms.Resize instead.")
        super(Scale, self).__init__(*args, **kwargs)
soumith's avatar
soumith committed
479
480
481


class CenterCrop(object):
    """Crops the given PIL.Image at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        # An int is expanded to a square (size, size) crop.
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for center crop.

        Args:
            img (PIL.Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for center crop.
        """
        width, height = img.size
        target_h, target_w = output_size
        top = int(round((height - target_h) / 2.))
        left = int(round((width - target_w) / 2.))
        return top, left, target_h, target_w

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        return crop(img, *self.get_params(img, self.size))
soumith's avatar
soumith committed
523
524


525
class Pad(object):
    """Pad the given PIL.Image on all sides with the given "pad" value.

    Args:
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
    """

    def __init__(self, padding, fill=0):
        assert isinstance(padding, (numbers.Number, tuple))
        assert isinstance(fill, (numbers.Number, str, tuple))
        # A padding sequence must have exactly 2 or 4 elements.
        bad_length = (isinstance(padding, collections.Sequence) and
                      len(padding) not in [2, 4])
        if bad_length:
            raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                             "{} element tuple".format(len(padding)))
        self.padding = padding
        self.fill = fill

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be padded.

        Returns:
            PIL.Image: Padded image.
        """
        return pad(img, self.padding, self.fill)
557

558

Soumith Chintala's avatar
Soumith Chintala committed
559
class Lambda(object):
    """Apply a user-defined lambda as a transform.

    Args:
        lambd (function): Lambda/function to be used for transform.
    """

    def __init__(self, lambd):
        # Only plain functions/lambdas are accepted, matching LambdaType.
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)

573

soumith's avatar
soumith committed
574
class RandomCrop(object):
    """Crop the given PIL.Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL.Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            # Bug fix: this used to ``return img``, which broke the caller's
            # ``i, j, h, w = self.get_params(...)`` unpacking whenever the
            # requested crop size matched the image size exactly. Return the
            # only valid crop parameters instead.
            return 0, 0, h, w

        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        if self.padding > 0:
            img = pad(img, self.padding)

        i, j, h, w = self.get_params(img, self.size)

        return crop(img, i, j, h, w)
soumith's avatar
soumith committed
628
629
630


class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL.Image randomly with a probability of 0.5."""

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.

        Returns:
            PIL.Image: Randomly flipped image.
        """
        # Flip with probability 0.5, otherwise return the input untouched.
        should_flip = random.random() < 0.5
        return hflip(img) if should_flip else img


646
class RandomVerticalFlip(object):
    """Vertically flip the given PIL.Image randomly with a probability of 0.5."""

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.

        Returns:
            PIL.Image: Randomly flipped image.
        """
        # Flip with probability 0.5, otherwise return the input untouched.
        should_flip = random.random() < 0.5
        return vflip(img) if should_flip else img


662
class RandomResizedCrop(object):
    """Crop the given PIL.Image to random size and aspect ratio.

    A crop of random size of (0.08 to 1.0) of the original size and a random
    aspect ratio of 3/4 to 4/3 of the original aspect ratio is made. This crop
    is finally resized to given size.
    This is popularly used to train the Inception networks.

    Args:
        size: expected output size of each edge
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = (size, size)
        self.interpolation = interpolation

    @staticmethod
    def get_params(img):
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop.
        """
        width, height = img.size
        area = width * height

        # Try up to ten random (scale, aspect-ratio) draws; the order of the
        # random calls is unchanged, so the sampling distribution is identical.
        for _ in range(10):
            target_area = random.uniform(0.08, 1.0) * area
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            if w <= width and h <= height:
                i = random.randint(0, height - h)
                j = random.randint(0, width - w)
                return i, j, h, w

        # Fallback: centered square crop of the short side.
        side = min(width, height)
        i = (height - side) // 2
        j = (width - side) // 2
        return i, j, side, side

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.

        Returns:
            PIL.Image: Randomly cropped and resize image.
        """
        i, j, h, w = self.get_params(img)
        return resized_crop(img, i, j, h, w, self.size, self.interpolation)


class RandomSizedCrop(RandomResizedCrop):
    """Deprecated alias kept for backward compatibility; use ``RandomResizedCrop``."""

    def __init__(self, *args, **kwargs):
        warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " +
                      "please use transforms.RandomResizedCrop instead.")
        super(RandomSizedCrop, self).__init__(*args, **kwargs)
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751


class FiveCrop(object):
    """Crop the given PIL.Image into four corners and the central crop.

       Note: this transform returns a tuple of images and there may be a mismatch in the number of
       inputs and targets your `Dataset` returns.

       Args:
           size (sequence or int): Desired output size of the crop. If size is an
               int instead of sequence like (h, w), a square crop (size, size) is
               made.
    """

    def __init__(self, size):
        # An int becomes a square (size, size) crop; sequences must be pairs.
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
            self.size = size

    def __call__(self, img):
        return five_crop(img, self.size)
753
754
755
756
757
758
759
760
761
762
763
764
765


class TenCrop(object):
    """Crop the given PIL.Image into four corners and the central crop plus the
       flipped version of these (horizontal flipping is used by default)

       Note: this transform returns a tuple of images and there may be a mismatch in the number of
       inputs and targets your `Dataset` returns.

       Args:
           size (sequence or int): Desired output size of the crop. If size is an
               int instead of sequence like (h, w), a square crop (size, size) is
               made.
           vertical_flip(bool): Use vertical flipping instead of horizontal
    """

    def __init__(self, size, vertical_flip=False):
        # An int becomes a square (size, size) crop; sequences must be pairs.
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
            self.size = size
        self.vertical_flip = vertical_flip

    def __call__(self, img):
        return ten_crop(img, self.size, self.vertical_flip)