from __future__ import division
import torch
import math
import random
from PIL import Image, ImageOps
try:
    import accimage
except ImportError:
    accimage = None
import numpy as np
import numbers
import types
import collections


def _is_pil_image(img):
    if accimage is not None:
        return isinstance(img, (Image.Image, accimage.Image))
    else:
        return isinstance(img, Image.Image)


def _is_tensor_image(img):
    return torch.is_tensor(img) and img.ndimension() == 3


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


def to_tensor(pic):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.

    See ``ToTensor`` for more details.

    Args:
        pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.
    """
    if not(_is_pil_image(pic) or _is_numpy_image(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

    if isinstance(pic, np.ndarray):
        # handle numpy array; add a channel dimension to 2D (grayscale) arrays,
        # which _is_numpy_image accepts but transpose((2, 0, 1)) would reject
        if pic.ndim == 2:
            pic = pic[:, :, None]
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility
        return img.float().div(255)

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    else:
        return img
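
# Usage sketch for ``to_tensor`` (the file name here is hypothetical):
#
#   >>> img = Image.open('example.jpg')   # any RGB PIL image
#   >>> t = to_tensor(img)
#   >>> t.size()                          # torch.Size([3, H, W]), values in [0.0, 1.0]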


def to_pil_image(pic):
    """Convert a tensor or an ndarray to PIL Image.

    See ``ToPILImage`` for more details.

    Args:
        pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.

    Returns:
        PIL.Image: Image converted to PIL.Image.
    """
    if not(_is_numpy_image(pic) or _is_tensor_image(pic)):
        raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))

    npimg = pic
    mode = None
    if isinstance(pic, torch.FloatTensor):
        pic = pic.mul(255).byte()
    if torch.is_tensor(pic):
        npimg = np.transpose(pic.numpy(), (1, 2, 0))
    assert isinstance(npimg, np.ndarray)
    if npimg.shape[2] == 1:
        npimg = npimg[:, :, 0]

        if npimg.dtype == np.uint8:
            mode = 'L'
        elif npimg.dtype == np.int16:
            mode = 'I;16'
        elif npimg.dtype == np.int32:
            mode = 'I'
        elif npimg.dtype == np.float32:
            mode = 'F'
    elif npimg.shape[2] == 4:
        if npimg.dtype == np.uint8:
            mode = 'RGBA'
    else:
        if npimg.dtype == np.uint8:
            mode = 'RGB'
    assert mode is not None, '{} is not supported'.format(npimg.dtype)
    return Image.fromarray(npimg, mode=mode)
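
# Usage sketch for ``to_pil_image``: converting a float tensor back to a PIL image.
#
#   >>> t = torch.rand(3, 32, 32)    # FloatTensor with values in [0, 1]
#   >>> pil = to_pil_image(t)        # 32 x 32 'RGB' image, values rescaled to [0, 255]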


def normalize(tensor, mean, std):
    """Normalize a tensor image with mean and standard deviation.

    See ``Normalize`` for more details.

    Args:
        tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        mean (sequence): Sequence of means for R, G, B channels respectively.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respectively.

    Returns:
        Tensor: Normalized image.
    """
    if not _is_tensor_image(tensor):
        raise TypeError('tensor is not a torch image.')
    # TODO: make efficient
    for t, m, s in zip(tensor, mean, std):
        t.sub_(m).div_(s)
    return tensor
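
# Usage sketch for ``normalize``; the statistics shown are the widely used
# ImageNet channel means and stds. Note that ``normalize`` modifies its input in place.
#
#   >>> t = to_tensor(img)
#   >>> t = normalize(t, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])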


def scale(img, size, interpolation=Image.BILINEAR):
    """Rescale the input PIL.Image to the given size.

    Args:
        img (PIL.Image): Image to be scaled.
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e., if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``

    Returns:
        PIL.Image: Rescaled image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)):
        raise TypeError('Got inappropriate size arg: {}'.format(size))

    if isinstance(size, int):
        w, h = img.size
        if (w <= h and w == size) or (h <= w and h == size):
            return img
        if w < h:
            ow = size
            oh = int(size * h / w)
            return img.resize((ow, oh), interpolation)
        else:
            oh = size
            ow = int(size * w / h)
            return img.resize((ow, oh), interpolation)
    else:
        return img.resize(size[::-1], interpolation)
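
# Usage sketch for ``scale``: an int matches the smaller edge and preserves the
# aspect ratio, while a (h, w) sequence resizes to exactly that size.
#
#   >>> small = scale(img, 256)           # smaller edge becomes 256
#   >>> exact = scale(img, (224, 224))    # exactly 224 x 224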


def pad(img, padding, fill=0):
    """Pad the given PIL.Image on all sides with the given "pad" value.

    Args:
        img (PIL.Image): Image to be padded.
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.

    Returns:
        PIL.Image: Padded image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if not isinstance(padding, (numbers.Number, tuple)):
        raise TypeError('Got inappropriate padding arg')
    if not isinstance(fill, (numbers.Number, str, tuple)):
        raise TypeError('Got inappropriate fill arg')

    if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
        raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                         "{} element tuple".format(len(padding)))

    return ImageOps.expand(img, border=padding, fill=fill)
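
# Usage sketch for ``pad``:
#
#   >>> padded = pad(img, 4)                          # 4 pixels on every border
#   >>> padded = pad(img, (2, 4), fill=(255, 0, 0))   # 2 left/right, 4 top/bottom, red fill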


def crop(img, i, j, h, w):
    """Crop the given PIL.Image.

    Args:
        img (PIL.Image): Image to be cropped.
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.

    Returns:
        PIL.Image: Cropped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.crop((j, i, j + w, i + h))


def scaled_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
    """Crop the given PIL.Image and scale it to desired size.

    Notably used in RandomSizedCrop.

    Args:
        img (PIL.Image): Image to be cropped.
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.
        size (sequence or int): Desired output size. Same semantics as ``scale``.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``.

    Returns:
        PIL.Image: Cropped image.
    """
    assert _is_pil_image(img), 'img should be PIL Image'
    img = crop(img, i, j, h, w)
    img = scale(img, size, interpolation)
    return img
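
# Usage sketch for ``crop`` and ``scaled_crop``:
#
#   >>> patch = crop(img, i=10, j=20, h=100, w=100)            # 100 x 100 patch
#   >>> patch = scaled_crop(img, 10, 20, 100, 100, size=224)   # same patch, resized to 224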


def hflip(img):
    """Horizontally flip the given PIL.Image.

    Args:
        img (PIL.Image): Image to be flipped.

    Returns:
        PIL.Image: Horizontally flipped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.transpose(Image.FLIP_LEFT_RIGHT)


def vflip(img):
    """Vertically flip the given PIL.Image.

    Args:
        img (PIL.Image): Image to be flipped.

    Returns:
        PIL.Image: Vertically flipped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.transpose(Image.FLIP_TOP_BOTTOM)


def five_crop(img, size):
    """Crop the given PIL.Image into four corners and the central crop.

    Note: this transform returns a tuple of images and there may be a mismatch in the number of
    inputs and targets your `Dataset` returns.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.

    Returns:
        tuple: tuple (tl, tr, bl, br, center) corresponding to the top left,
            top right, bottom left, bottom right and center crop.
    """
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    else:
        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

    w, h = img.size
    crop_h, crop_w = size
    if crop_w > w or crop_h > h:
        raise ValueError("Requested crop size {} is bigger than input size {}".format(size,
                                                                                      (h, w)))
    tl = img.crop((0, 0, crop_w, crop_h))
    tr = img.crop((w - crop_w, 0, w, crop_h))
    bl = img.crop((0, h - crop_h, crop_w, h))
    br = img.crop((w - crop_w, h - crop_h, w, h))
    center = CenterCrop((crop_h, crop_w))(img)
    return (tl, tr, bl, br, center)


def ten_crop(img, size, vertical_flip=False):
    """Crop the given PIL.Image into four corners and the central crop plus the
       flipped version of these (horizontal flipping is used by default).

       Note: this transform returns a tuple of images and there may be a mismatch in the number of
       inputs and targets your `Dataset` returns.

       Args:
           size (sequence or int): Desired output size of the crop. If size is an
               int instead of sequence like (h, w), a square crop (size, size) is
               made.
           vertical_flip (bool): Use vertical flipping instead of horizontal

        Returns:
            tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip,
                br_flip, center_flip) corresponding top left, top right,
                bottom left, bottom right and center crop and same for the
                flipped image.
    """
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    else:
        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

    first_five = five_crop(img, size)

    if vertical_flip:
        img = vflip(img)
    else:
        img = hflip(img)

    second_five = five_crop(img, size)
    return first_five + second_five
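
# Usage sketch for ``five_crop`` / ``ten_crop`` (both return tuples of PIL images):
#
#   >>> tl, tr, bl, br, center = five_crop(img, 224)
#   >>> crops = ten_crop(img, 224)    # the five crops above plus their flipped versions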


class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img
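
# A typical preprocessing pipeline built from the transforms defined below; the
# normalization statistics are the commonly used ImageNet values.
#
#   >>> preprocess = Compose([
#   ...     Scale(256),
#   ...     CenterCrop(224),
#   ...     ToTensor(),
#   ...     Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#   ... ])
#   >>> tensor = preprocess(img)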


class ToTensor(object):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return to_tensor(pic)


class ToPILImage(object):
    """Convert a tensor or an ndarray to PIL Image.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL.Image while preserving the value range.
    """

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.

        Returns:
            PIL.Image: Image converted to PIL.Image.

        """
        return to_pil_image(pic)


class Normalize(object):
    """Normalize an tensor image with mean and standard deviation.

    Given mean: (R, G, B) and std: (R, G, B),
    will normalize each channel of the torch.*Tensor, i.e.
    channel = (channel - mean) / std

    Args:
        mean (sequence): Sequence of means for R, G, B channels respectively.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respectively.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized image.
        """
        return normalize(tensor, self.mean, self.std)


class Scale(object):
    """Rescale the input PIL.Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e., if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be scaled.

        Returns:
            PIL.Image: Rescaled image.
        """
        return scale(img, self.size, self.interpolation)


class CenterCrop(object):
    """Crops the given PIL.Image at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for center crop.

        Args:
            img (PIL.Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for center crop.
        """
        w, h = img.size
        th, tw = output_size
        i = int(round((h - th) / 2.))
        j = int(round((w - tw) / 2.))
        return i, j, th, tw

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        i, j, h, w = self.get_params(img, self.size)
        return crop(img, i, j, h, w)


class Pad(object):
    """Pad the given PIL.Image on all sides with the given "pad" value.

    Args:
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
    """

    def __init__(self, padding, fill=0):
        assert isinstance(padding, (numbers.Number, tuple))
        assert isinstance(fill, (numbers.Number, str, tuple))
        if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
            raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                             "{} element tuple".format(len(padding)))

        self.padding = padding
        self.fill = fill

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be padded.

        Returns:
            PIL.Image: Padded image.
        """
        return pad(img, self.padding, self.fill)


class Lambda(object):
    """Apply a user-defined lambda as a transform.

    Args:
        lambd (function): Lambda/function to be used for transform.
    """

    def __init__(self, lambd):
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)
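
# Usage sketch for ``Lambda``: wrapping an arbitrary PIL operation as a transform.
#
#   >>> rotate90 = Lambda(lambda im: im.rotate(90))
#   >>> rotated = rotate90(img)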


class RandomCrop(object):
    """Crop the given PIL.Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e. no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL.Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w

        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        if self.padding > 0:
            img = pad(img, self.padding)

        i, j, h, w = self.get_params(img, self.size)

        return crop(img, i, j, h, w)
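
# Usage sketch for ``RandomCrop``; padding then cropping back to the original size
# is a common augmentation (e.g. 32 x 32 CIFAR images with 4 pixels of padding):
#
#   >>> cropper = RandomCrop(32, padding=4)
#   >>> out = cropper(img)    # random 32 x 32 crop of the padded image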


class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL.Image randomly with a probability of 0.5."""

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.

        Returns:
            PIL.Image: Randomly flipped image.
        """
        if random.random() < 0.5:
            return hflip(img)
        return img


class RandomVerticalFlip(object):
    """Vertically flip the given PIL.Image randomly with a probability of 0.5."""

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.

        Returns:
            PIL.Image: Randomly flipped image.
        """
        if random.random() < 0.5:
            return vflip(img)
        return img


class RandomSizedCrop(object):
    """Crop the given PIL.Image to random size and aspect ratio.

    A crop of random size of (0.08 to 1.0) of the original size and a random
    aspect ratio of 3/4 to 4/3 of the original aspect ratio is made. This crop
    is finally resized to given size.
    This is popularly used to train the Inception networks.

    Args:
        size: expected output size of each edge
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = (size, size)
        self.interpolation = interpolation

    @staticmethod
    def get_params(img):
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop.
        """
        for attempt in range(10):
            area = img.size[0] * img.size[1]
            target_area = random.uniform(0.08, 1.0) * area
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            if w <= img.size[0] and h <= img.size[1]:
                i = random.randint(0, img.size[1] - h)
                j = random.randint(0, img.size[0] - w)
                return i, j, h, w

        # Fallback
        w = min(img.size[0], img.size[1])
        i = (img.size[1] - w) // 2
        j = (img.size[0] - w) // 2
        return i, j, w, w

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped and scaled.

        Returns:
            PIL.Image: Randomly cropped and scaled image.
        """
        i, j, h, w = self.get_params(img)
        return scaled_crop(img, i, j, h, w, self.size, self.interpolation)
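
# Usage sketch for ``RandomSizedCrop``, the Inception-style training crop:
#
#   >>> train_crop = RandomSizedCrop(224)
#   >>> out = train_crop(img)    # random crop of the input, resized to 224 x 224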


class FiveCrop(object):
    """Crop the given PIL.Image into four corners and the central crop.abs

       Note: this transform returns a tuple of images and there may be a mismatch in the number of
       inputs and targets your `Dataset` returns.

       Args:
           size (sequence or int): Desired output size of the crop. If size is an
               int instead of sequence like (h, w), a square crop (size, size) is
               made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
            self.size = size

    def __call__(self, img):
        return five_crop(img, self.size)


class TenCrop(object):
    """Crop the given PIL.Image into four corners and the central crop plus the
       flipped version of these (horizontal flipping is used by default)

       Note: this transform returns a tuple of images and there may be a mismatch in the number of
       inputs and targets your `Dataset` returns.

       Args:
           size (sequence or int): Desired output size of the crop. If size is an
               int instead of sequence like (h, w), a square crop (size, size) is
               made.
745
           vertical_flip(bool): Use vertical flipping instead of horizontal
746
747
    """

    def __init__(self, size, vertical_flip=False):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
            self.size = size
        self.vertical_flip = vertical_flip

    def __call__(self, img):
        return ten_crop(img, self.size, self.vertical_flip)