from __future__ import division
import torch
import math
import random
from PIL import Image, ImageOps, ImageEnhance
try:
    import accimage
except ImportError:
    accimage = None
import numpy as np
import numbers
import types
import collections
import warnings


def _is_pil_image(img):
    if accimage is not None:
        return isinstance(img, (Image.Image, accimage.Image))
    else:
        return isinstance(img, Image.Image)


def _is_tensor_image(img):
    return torch.is_tensor(img) and img.ndimension() == 3


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


def to_tensor(pic):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.

    See ``ToTensor`` for more details.

    Args:
        pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.
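
    Example (illustrative; uses a synthetic placeholder image):
        >>> img = Image.new('RGB', (64, 32))  # placeholder PIL image, width 64, height 32
        >>> t = to_tensor(img)                # FloatTensor of shape (3, 32, 64) with values in [0, 1]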
    """
    if not(_is_pil_image(pic) or _is_numpy_image(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

    if isinstance(pic, np.ndarray):
        # handle numpy array
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility
        return img.float().div(255)

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    else:
        return img


def to_pil_image(pic):
    """Convert a tensor or an ndarray to PIL Image.

    See ``ToPILImage`` for more details.

    Args:
        pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.

    Returns:
        PIL.Image: Image converted to PIL.Image.
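
    Example (illustrative; a random tensor stands in for a real image):
        >>> t = torch.rand(3, 32, 64)  # CHW float tensor with values in [0, 1]
        >>> img = to_pil_image(t)      # 64x32 RGB PIL.Image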
    """
    if not(_is_numpy_image(pic) or _is_tensor_image(pic)):
        raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))

    npimg = pic
    mode = None
    if isinstance(pic, torch.FloatTensor):
        pic = pic.mul(255).byte()
    if torch.is_tensor(pic):
        npimg = np.transpose(pic.numpy(), (1, 2, 0))
    assert isinstance(npimg, np.ndarray)
    if npimg.shape[2] == 1:
        npimg = npimg[:, :, 0]
        if npimg.dtype == np.uint8:
            mode = 'L'
        elif npimg.dtype == np.int16:
            mode = 'I;16'
        elif npimg.dtype == np.int32:
            mode = 'I'
        elif npimg.dtype == np.float32:
            mode = 'F'
    elif npimg.shape[2] == 4:
        if npimg.dtype == np.uint8:
            mode = 'RGBA'
    else:
        if npimg.dtype == np.uint8:
            mode = 'RGB'
    assert mode is not None, '{} is not supported'.format(npimg.dtype)
    return Image.fromarray(npimg, mode=mode)


def normalize(tensor, mean, std):
    """Normalize a tensor image with mean and standard deviation.

    See ``Normalize`` for more details.

    Args:
        tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        mean (sequence): Sequence of means for R, G, B channels respectively.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respectively.

    Returns:
        Tensor: Normalized image.
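
    Example (illustrative; the mean/std values are arbitrary placeholders):
        >>> t = torch.rand(3, 8, 8)
        >>> out = normalize(t, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # normalizes t in-place and returns it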
    """
    if not _is_tensor_image(tensor):
        raise TypeError('tensor is not a torch image.')
    # TODO: make efficient
    for t, m, s in zip(tensor, mean, std):
        t.sub_(m).div_(s)
    return tensor


def resize(img, size, interpolation=Image.BILINEAR):
    """Resize the input PIL.Image to the given size.

    Args:
        img (PIL.Image): Image to be resized.
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), the output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number maintaining
            the aspect ratio. i.e., if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``

    Returns:
        PIL.Image: Resized image.
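
    Example (illustrative; the input is a synthetic placeholder image):
        >>> img = Image.new('RGB', (640, 480))  # placeholder, width 640, height 480
        >>> out = resize(img, 256)              # smaller edge becomes 256; out.size == (341, 256)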
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)):
        raise TypeError('Got inappropriate size arg: {}'.format(size))

    if isinstance(size, int):
        w, h = img.size
        if (w <= h and w == size) or (h <= w and h == size):
            return img
        if w < h:
            ow = size
            oh = int(size * h / w)
            return img.resize((ow, oh), interpolation)
        else:
            oh = size
            ow = int(size * w / h)
            return img.resize((ow, oh), interpolation)
    else:
        return img.resize(size[::-1], interpolation)


def scale(*args, **kwargs):
    warnings.warn("The use of the transforms.Scale transform is deprecated, " +
                  "please use transforms.Resize instead.")
    return resize(*args, **kwargs)


def pad(img, padding, fill=0):
    """Pad the given PIL.Image on all sides with the given "pad" value.

    Args:
        img (PIL.Image): Image to be padded.
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.

    Returns:
        PIL.Image: Padded image.
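
    Example (illustrative; the padding and fill values are arbitrary placeholders):
        >>> img = Image.new('RGB', (32, 32))
        >>> out = pad(img, padding=4, fill=0)  # 4 pixels added on every border; out.size == (40, 40)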
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if not isinstance(padding, (numbers.Number, tuple)):
        raise TypeError('Got inappropriate padding arg')
    if not isinstance(fill, (numbers.Number, str, tuple)):
        raise TypeError('Got inappropriate fill arg')

    if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
        raise ValueError("Padding must be an int or a 2- or 4-element tuple, not a " +
                         "{} element tuple".format(len(padding)))

    return ImageOps.expand(img, border=padding, fill=fill)


def crop(img, i, j, h, w):
    """Crop the given PIL.Image.

    Args:
        img (PIL.Image): Image to be cropped.
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.

    Returns:
        PIL.Image: Cropped image.
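
    Example (illustrative; the coordinates are arbitrary placeholders):
        >>> img = Image.new('RGB', (100, 100))
        >>> patch = crop(img, 10, 20, 30, 40)  # 30 (h) x 40 (w) region whose top-left corner is at row 10, column 20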
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.crop((j, i, j + w, i + h))


def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
    """Crop the given PIL.Image and resize it to desired size.

    Notably used in RandomResizedCrop.

    Args:
        img (PIL.Image): Image to be cropped.
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.
        size (sequence or int): Desired output size. Same semantics as ``resize``.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``.
    Returns:
        PIL.Image: Cropped image.
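
    Example (illustrative; the coordinates and sizes are arbitrary placeholders):
        >>> img = Image.new('RGB', (256, 256))
        >>> out = resized_crop(img, 0, 0, 128, 128, size=64)  # crop the top-left 128x128 region, then resize to 64x64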
    """
    assert _is_pil_image(img), 'img should be PIL Image'
    img = crop(img, i, j, h, w)
    img = resize(img, size, interpolation)
    return img


def hflip(img):
    """Horizontally flip the given PIL.Image.

    Args:
        img (PIL.Image): Image to be flipped.

    Returns:
        PIL.Image: Horizontally flipped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.transpose(Image.FLIP_LEFT_RIGHT)


def vflip(img):
    """Vertically flip the given PIL.Image.

    Args:
        img (PIL.Image): Image to be flipped.

    Returns:
        PIL.Image:  Vertically flipped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.transpose(Image.FLIP_TOP_BOTTOM)


def five_crop(img, size):
    """Crop the given PIL.Image into four corners and the central crop.

    Note: this transform returns a tuple of images and there may be a mismatch in the number of
    inputs and targets your `Dataset` returns.

    Args:
        img (PIL.Image): Image to be cropped.
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    Returns:
        tuple: tuple (tl, tr, bl, br, center) corresponding top left,
            top right, bottom left, bottom right and center crop.
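
    Example (illustrative; the input is a synthetic placeholder image):
        >>> img = Image.new('RGB', (256, 256))
        >>> tl, tr, bl, br, center = five_crop(img, 224)  # each crop is a 224x224 PIL.Image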
    """
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    else:
        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

    w, h = img.size
    crop_h, crop_w = size
    if crop_w > w or crop_h > h:
        raise ValueError("Requested crop size {} is bigger than input size {}".format(size,
                                                                                      (h, w)))
    tl = img.crop((0, 0, crop_w, crop_h))
    tr = img.crop((w - crop_w, 0, w, crop_h))
    bl = img.crop((0, h - crop_h, crop_w, h))
    br = img.crop((w - crop_w, h - crop_h, w, h))
    center = CenterCrop((crop_h, crop_w))(img)
    return (tl, tr, bl, br, center)


def ten_crop(img, size, vertical_flip=False):
    """Crop the given PIL.Image into four corners and the central crop plus the
       flipped version of these (horizontal flipping is used by default).

       Note: this transform returns a tuple of images and there may be a mismatch in the number of
       inputs and targets your `Dataset` returns.

       Args:
           size (sequence or int): Desired output size of the crop. If size is an
               int instead of sequence like (h, w), a square crop (size, size) is
               made.
           vertical_flip (bool): Use vertical flipping instead of horizontal

        Returns:
            tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip,
                br_flip, center_flip) corresponding top left, top right,
                bottom left, bottom right and center crop and same for the
                flipped image.
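
    Example (illustrative; the input is a synthetic placeholder image):
        >>> img = Image.new('RGB', (256, 256))
        >>> crops = ten_crop(img, 224)  # tuple of 10 PIL.Images: the five crops plus their mirrored versions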
    """
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    else:
        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

    first_five = five_crop(img, size)

    if vertical_flip:
        img = vflip(img)
    else:
        img = hflip(img)

    second_five = five_crop(img, size)
    return first_five + second_five


def adjust_brightness(img, brightness_factor):
    """Adjust brightness of an Image.

    Args:
        img (PIL.Image): PIL Image to be adjusted.
        brightness_factor (float):  How much to adjust the brightness. Can be
            any non negative number. 0 gives a black image, 1 gives the
            original image while 2 increases the brightness by a factor of 2.

    Returns:
        PIL.Image: Brightness adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Brightness(img)
    img = enhancer.enhance(brightness_factor)
    return img


def adjust_contrast(img, contrast_factor):
    """Adjust contrast of an Image.

    Args:
        img (PIL.Image): PIL Image to be adjusted.
        contrast_factor (float): How much to adjust the contrast. Can be any
            non negative number. 0 gives a solid gray image, 1 gives the
            original image while 2 increases the contrast by a factor of 2.

    Returns:
        PIL.Image: Contrast adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Contrast(img)
    img = enhancer.enhance(contrast_factor)
    return img


def adjust_saturation(img, saturation_factor):
    """Adjust color saturation of an image.

    Args:
        img (PIL.Image): PIL Image to be adjusted.
        saturation_factor (float):  How much to adjust the saturation. 0 will
            give a black and white image, 1 will give the original image while
            2 will enhance the saturation by a factor of 2.

    Returns:
        PIL.Image: Saturation adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Color(img)
    img = enhancer.enhance(saturation_factor)
    return img


def adjust_hue(img, hue_factor):
    """Adjust hue of an image.

    The image hue is adjusted by converting the image to HSV and
    cyclically shifting the intensities in the hue channel (H).
    The image is then converted back to original image mode.

    `hue_factor` is the amount of shift in H channel and must be in the
    interval `[-0.5, 0.5]`.

    See https://en.wikipedia.org/wiki/Hue for more details on Hue.

    Args:
        img (PIL.Image): PIL Image to be adjusted.
        hue_factor (float):  How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
            HSV space in positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.

    Returns:
        PIL.Image: Hue adjusted image.
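
    Example (illustrative; the shift amount is an arbitrary placeholder):
        >>> img = Image.new('RGB', (32, 32), color=(255, 0, 0))  # solid red placeholder
        >>> shifted = adjust_hue(img, 0.3)                       # hue channel shifted cyclically in HSV space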
    """
    if not(-0.5 <= hue_factor <= 0.5):
        raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))

    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    input_mode = img.mode
    if input_mode in {'L', '1', 'I', 'F'}:
        return img

    h, s, v = img.convert('HSV').split()

    np_h = np.array(h, dtype=np.uint8)
    # uint8 addition takes care of rotation across boundaries
    with np.errstate(over='ignore'):
        np_h += np.uint8(hue_factor * 255)
    h = Image.fromarray(np_h, 'L')

    img = Image.merge('HSV', (h, s, v)).convert(input_mode)
    return img


def adjust_gamma(img, gamma, gain=1):
    """Perform gamma correction on an image.

    Also known as Power Law Transform. Intensities in RGB mode are adjusted
    based on the following equation:

        I_out = 255 * gain * ((I_in / 255) ** gamma)

    See https://en.wikipedia.org/wiki/Gamma_correction for more details.

    Args:
        img (PIL.Image): PIL Image to be adjusted.
        gamma (float): Non-negative real number. gamma larger than 1 makes the
            shadows darker, while gamma smaller than 1 makes dark regions
            lighter.
        gain (float): The constant multiplier.
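
    Example (illustrative; the gamma value is an arbitrary placeholder):
        >>> img = Image.new('RGB', (32, 32), color=(128, 128, 128))  # mid-gray placeholder
        >>> darker = adjust_gamma(img, gamma=2.0)                    # gamma > 1 darkens the image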
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if gamma < 0:
        raise ValueError('Gamma should be a non-negative real number')

    input_mode = img.mode
    img = img.convert('RGB')

    np_img = np.array(img, dtype=np.float32)
    np_img = 255 * gain * ((np_img / 255) ** gamma)
    np_img = np.uint8(np.clip(np_img, 0, 255))

    img = Image.fromarray(np_img, 'RGB').convert(input_mode)
    return img


class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img


class ToTensor(object):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return to_tensor(pic)


class ToPILImage(object):
    """Convert a tensor or an ndarray to PIL Image.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL.Image while preserving the value range.
    """

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.

        Returns:
            PIL.Image: Image converted to PIL.Image.

        """
        return to_pil_image(pic)


class Normalize(object):
    """Normalize an tensor image with mean and standard deviation.

    Given mean: (R, G, B) and std: (R, G, B),
    will normalize each channel of the torch.*Tensor, i.e.
    channel = (channel - mean) / std

    Args:
        mean (sequence): Sequence of means for R, G, B channels respectively.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respectively.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized image.
        """
        return normalize(tensor, self.mean, self.std)


class Resize(object):
    """Resize the input PIL.Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be scaled.

        Returns:
            PIL.Image: Rescaled image.
        """
        return resize(img, self.size, self.interpolation)


class Scale(Resize):
    def __init__(self, *args, **kwargs):
        warnings.warn("The use of the transforms.Scale transform is deprecated, " +
                      "please use transforms.Resize instead.")
        super(Scale, self).__init__(*args, **kwargs)


class CenterCrop(object):
    """Crops the given PIL.Image at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for center crop.

        Args:
            img (PIL.Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for center crop.
        """
        w, h = img.size
        th, tw = output_size
        i = int(round((h - th) / 2.))
        j = int(round((w - tw) / 2.))
        return i, j, th, tw

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        i, j, h, w = self.get_params(img, self.size)
        return crop(img, i, j, h, w)


class Pad(object):
    """Pad the given PIL.Image on all sides with the given "pad" value.

    Args:
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
    """

    def __init__(self, padding, fill=0):
        assert isinstance(padding, (numbers.Number, tuple))
        assert isinstance(fill, (numbers.Number, str, tuple))
        if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
            raise ValueError("Padding must be an int or a 2- or 4-element tuple, not a " +
                             "{} element tuple".format(len(padding)))

        self.padding = padding
        self.fill = fill

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be padded.

        Returns:
            PIL.Image: Padded image.
        """
        return pad(img, self.padding, self.fill)


class Lambda(object):
    """Apply a user-defined lambda as a transform.

    Args:
        lambd (function): Lambda/function to be used for transform.
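
    Example (illustrative; wraps this module's ``hflip`` function):
        >>> flip = Lambda(lambda img: hflip(img))
        >>> # flip(img) now acts as a transform that mirrors a PIL image horizontally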
    """

    def __init__(self, lambd):
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)


class RandomCrop(object):
    """Crop the given PIL.Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL.Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w

        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        if self.padding > 0:
            img = pad(img, self.padding)

        i, j, h, w = self.get_params(img, self.size)

        return crop(img, i, j, h, w)


class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL.Image randomly with a probability of 0.5."""

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.

        Returns:
            PIL.Image: Randomly flipped image.
        """
        if random.random() < 0.5:
            return hflip(img)
        return img


class RandomVerticalFlip(object):
    """Vertically flip the given PIL.Image randomly with a probability of 0.5."""

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.

        Returns:
            PIL.Image: Randomly flipped image.
        """
        if random.random() < 0.5:
            return vflip(img)
        return img


class RandomResizedCrop(object):
    """Crop the given PIL.Image to random size and aspect ratio.

    A crop of random size (0.08 to 1.0 of the original size) and a random
    aspect ratio (3/4 to 4/3 of the original aspect ratio) is made. This crop
    is finally resized to given size.
    This is popularly used to train the Inception networks.

    Args:
        size: expected output size of each edge
        interpolation: Default: PIL.Image.BILINEAR
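
    Example (illustrative; the input is a synthetic placeholder image):
        >>> transform = RandomResizedCrop(224)
        >>> img = Image.new('RGB', (256, 256))
        >>> out = transform(img)  # 224x224 PIL.Image resized from a random crop of img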
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = (size, size)
        self.interpolation = interpolation

    @staticmethod
    def get_params(img):
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop.
        """
        for attempt in range(10):
            area = img.size[0] * img.size[1]
            target_area = random.uniform(0.08, 1.0) * area
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            if w <= img.size[0] and h <= img.size[1]:
                i = random.randint(0, img.size[1] - h)
                j = random.randint(0, img.size[0] - w)
                return i, j, h, w

        # Fallback
        w = min(img.size[0], img.size[1])
        i = (img.size[1] - w) // 2
        j = (img.size[0] - w) // 2
        return i, j, w, w

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped and resized.

        Returns:
            PIL.Image: Randomly cropped and resized image.
        """
        i, j, h, w = self.get_params(img)
        return resized_crop(img, i, j, h, w, self.size, self.interpolation)


class RandomSizedCrop(RandomResizedCrop):
    def __init__(self, *args, **kwargs):
        warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " +
                      "please use transforms.RandomResizedCrop instead.")
        super(RandomSizedCrop, self).__init__(*args, **kwargs)


class FiveCrop(object):
    """Crop the given PIL.Image into four corners and the central crop.abs

       Note: this transform returns a tuple of images and there may be a mismatch in the number of
       inputs and targets your `Dataset` returns.

       Args:
           size (sequence or int): Desired output size of the crop. If size is an
               int instead of sequence like (h, w), a square crop (size, size) is
               made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
            self.size = size

    def __call__(self, img):
        return five_crop(img, self.size)


class TenCrop(object):
    """Crop the given PIL.Image into four corners and the central crop plus the
       flipped version of these (horizontal flipping is used by default)

       Note: this transform returns a tuple of images and there may be a mismatch in the number of
       inputs and targets your `Dataset` returns.

       Args:
           size (sequence or int): Desired output size of the crop. If size is an
               int instead of sequence like (h, w), a square crop (size, size) is
               made.
           vertical_flip (bool): Use vertical flipping instead of horizontal
    """

    def __init__(self, size, vertical_flip=False):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
            self.size = size
        self.vertical_flip = vertical_flip

    def __call__(self, img):
        return ten_crop(img, self.size, self.vertical_flip)


class ColorJitter(object):
    """Randomly change the brightness, contrast and saturation of an image.

    Args:
        brightness (float): How much to jitter brightness. brightness_factor
            is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
        contrast (float): How much to jitter contrast. contrast_factor
            is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
        saturation (float): How much to jitter saturation. saturation_factor
            is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
        hue (float): How much to jitter hue. hue_factor is chosen uniformly from
            [-hue, hue]. Should be >=0 and <= 0.5.
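
    Example (illustrative; the jitter strengths are arbitrary placeholders):
        >>> jitter = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        >>> img = Image.new('RGB', (64, 64), color=(120, 60, 200))
        >>> out = jitter(img)  # same size, randomly perturbed colors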
    """
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Get a randomized transform to be applied on image.

        Arguments are same as that of __init__.

        Returns:
            Transform which randomly adjusts brightness, contrast and
            saturation in a random order.
        """
        transforms = []
        if brightness > 0:
            brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness)
            transforms.append(Lambda(lambda img: adjust_brightness(img, brightness_factor)))

        if contrast > 0:
            contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast)
            transforms.append(Lambda(lambda img: adjust_contrast(img, contrast_factor)))

        if saturation > 0:
            saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation)
            transforms.append(Lambda(lambda img: adjust_saturation(img, saturation_factor)))

        if hue > 0:
            hue_factor = np.random.uniform(-hue, hue)
            transforms.append(Lambda(lambda img: adjust_hue(img, hue_factor)))

        np.random.shuffle(transforms)
        transform = Compose(transforms)

        return transform

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Input image.

        Returns:
            PIL.Image: Color jittered image.
        """
        transform = self.get_params(self.brightness, self.contrast,
                                    self.saturation, self.hue)
        return transform(img)