from __future__ import division
import torch
import math
import random
from PIL import Image, ImageOps, ImageEnhance
try:
    import accimage
except ImportError:
    accimage = None
import numpy as np
import numbers
import types
import collections
import warnings


def _is_pil_image(img):
    if accimage is not None:
        return isinstance(img, (Image.Image, accimage.Image))
    else:
        return isinstance(img, Image.Image)


def _is_tensor_image(img):
    return torch.is_tensor(img) and img.ndimension() == 3


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


def to_tensor(pic):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    See ``ToTensor`` for more details.

    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.
    """
    if not(_is_pil_image(pic) or _is_numpy_image(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

    if isinstance(pic, np.ndarray):
        # handle numpy array
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility
        return img.float().div(255)

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    else:
        return img
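
# Usage sketch (illustrative, not part of the original file): ``to_tensor``
# maps an HWC uint8 array to a CHW float tensor scaled to [0, 1].
#
#     >>> arr = np.full((4, 6, 3), 255, dtype=np.uint8)
#     >>> t = to_tensor(arr)
#     >>> t.size(), float(t.max())
#     (torch.Size([3, 4, 6]), 1.0)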


def to_pil_image(pic):
    """Convert a tensor or an ndarray to PIL Image.

    See ``ToPILImage`` for more details.

    Args:
        pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.

    Returns:
        PIL Image: Image converted to PIL Image.
    """
    if not(_is_numpy_image(pic) or _is_tensor_image(pic)):
        raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))

    npimg = pic
    mode = None
    if isinstance(pic, torch.FloatTensor):
        pic = pic.mul(255).byte()
    if torch.is_tensor(pic):
        npimg = np.transpose(pic.numpy(), (1, 2, 0))
    assert isinstance(npimg, np.ndarray)
    if npimg.shape[2] == 1:
        npimg = npimg[:, :, 0]

        if npimg.dtype == np.uint8:
            mode = 'L'
        elif npimg.dtype == np.int16:
            mode = 'I;16'
        elif npimg.dtype == np.int32:
            mode = 'I'
        elif npimg.dtype == np.float32:
            mode = 'F'
    elif npimg.shape[2] == 4:
        if npimg.dtype == np.uint8:
            mode = 'RGBA'
    else:
        if npimg.dtype == np.uint8:
            mode = 'RGB'
    assert mode is not None, '{} is not supported'.format(npimg.dtype)
    return Image.fromarray(npimg, mode=mode)
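
# Illustrative round trip (a sketch, not from the original docs): a float CHW
# tensor in [0, 1] comes back as a uint8 RGB image of PIL size (W, H).
#
#     >>> to_pil_image(torch.rand(3, 8, 8)).mode
#     'RGB'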


def normalize(tensor, mean, std):
    """Normalize a tensor image with mean and standard deviation.

    See ``Normalize`` for more details.

    Args:
        tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.

    Returns:
        Tensor: Normalized Tensor image.
    """
    if not _is_tensor_image(tensor):
        raise TypeError('tensor is not a torch image.')

    # TODO: make efficient
    for t, m, s in zip(tensor, mean, std):
        t.sub_(m).div_(s)
    return tensor
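
# A minimal sketch (the statistics are illustrative): with mean 0.5 and std
# 0.5 per channel, values in [0, 1] are mapped in place to [-1, 1].
#
#     >>> t = torch.zeros(3, 2, 2)
#     >>> normalize(t, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])[0, 0, 0]
#     -1.0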


def resize(img, size, interpolation=Image.BILINEAR):
    """Resize the input PIL Image to the given size.

    Args:
        img (PIL Image): Image to be resized.
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), the output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number maintaining
            the aspect ratio. i.e., if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``

    Returns:
        PIL Image: Resized image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)):
        raise TypeError('Got inappropriate size arg: {}'.format(size))

    if isinstance(size, int):
        w, h = img.size
        if (w <= h and w == size) or (h <= w and h == size):
            return img
        if w < h:
            ow = size
            oh = int(size * h / w)
            return img.resize((ow, oh), interpolation)
        else:
            oh = size
            ow = int(size * w / h)
            return img.resize((ow, oh), interpolation)
    else:
        return img.resize(size[::-1], interpolation)
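
# Size semantics sketch (the 400x300 image is hypothetical): an int matches
# the smaller edge and keeps the aspect ratio; an (h, w) pair is used as-is.
#
#     >>> im = Image.new('RGB', (400, 300))    # PIL size is (width, height)
#     >>> resize(im, 150).size                 # smaller edge -> 150
#     (200, 150)
#     >>> resize(im, (100, 120)).size          # (h, w) -> PIL (w, h)
#     (120, 100)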


def scale(*args, **kwargs):
    warnings.warn("The use of the transforms.Scale transform is deprecated, " +
                  "please use transforms.Resize instead.")
    return resize(*args, **kwargs)


def pad(img, padding, fill=0):
    """Pad the given PIL Image on all sides with the given "pad" value.

    Args:
        img (PIL Image): Image to be padded.
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If a tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.

    Returns:
        PIL Image: Padded image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if not isinstance(padding, (numbers.Number, tuple)):
        raise TypeError('Got inappropriate padding arg')
    if not isinstance(fill, (numbers.Number, str, tuple)):
        raise TypeError('Got inappropriate fill arg')

    if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
        raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                         "{} element tuple".format(len(padding)))

    return ImageOps.expand(img, border=padding, fill=fill)
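
# Padding sketch (sizes are illustrative): a single int pads all four
# borders, so a 10x10 image padded by 2 becomes 14x14.
#
#     >>> pad(Image.new('RGB', (10, 10)), 2).size
#     (14, 14)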


def crop(img, i, j, h, w):
    """Crop the given PIL Image.

    Args:
        img (PIL Image): Image to be cropped.
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.

    Returns:
        PIL Image: Cropped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.crop((j, i, j + w, i + h))


def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
    """Crop the given PIL Image and resize it to desired size.

    Notably used in RandomResizedCrop.

    Args:
        img (PIL Image): Image to be cropped.
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.
        size (sequence or int): Desired output size. Same semantics as ``resize``.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``.

    Returns:
        PIL Image: Cropped image.
    """
    assert _is_pil_image(img), 'img should be PIL Image'
    img = crop(img, i, j, h, w)
    img = resize(img, size, interpolation)
    return img
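
# Sketch of the crop-then-resize composition (all parameters illustrative):
# a 50x50 region starting at row 10, column 20 is resized to 32x32.
#
#     >>> im = Image.new('RGB', (100, 100))
#     >>> resized_crop(im, 10, 20, 50, 50, (32, 32)).size
#     (32, 32)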


def hflip(img):
    """Horizontally flip the given PIL Image.

    Args:
        img (PIL Image): Image to be flipped.

    Returns:
        PIL Image: Horizontally flipped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.transpose(Image.FLIP_LEFT_RIGHT)


def vflip(img):
    """Vertically flip the given PIL Image.

    Args:
        img (PIL Image): Image to be flipped.

    Returns:
        PIL Image: Vertically flipped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.transpose(Image.FLIP_TOP_BOTTOM)


def five_crop(img, size):
    """Crop the given PIL Image into four corners and the central crop.

    .. Note::
        This transform returns a tuple of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
        img (PIL Image): Image to be cropped.
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.

    Returns:
        tuple: tuple (tl, tr, bl, br, center) corresponding to top left,
            top right, bottom left, bottom right and center crop.
    """
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    else:
        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

    w, h = img.size
    crop_h, crop_w = size
    if crop_w > w or crop_h > h:
        raise ValueError("Requested crop size {} is bigger than input size {}".format(size,
                                                                                      (h, w)))
    tl = img.crop((0, 0, crop_w, crop_h))
    tr = img.crop((w - crop_w, 0, w, crop_h))
    bl = img.crop((0, h - crop_h, crop_w, h))
    br = img.crop((w - crop_w, h - crop_h, w, h))
    center = CenterCrop((crop_h, crop_w))(img)
    return (tl, tr, bl, br, center)


def ten_crop(img, size, vertical_flip=False):
    """Crop the given PIL Image into four corners and the central crop plus the
    flipped version of these (horizontal flipping is used by default).

    .. Note::
        This transform returns a tuple of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
        img (PIL Image): Image to be cropped.
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        vertical_flip (bool): Use vertical flipping instead of horizontal.

    Returns:
        tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip,
            br_flip, center_flip) corresponding to top left, top right,
            bottom left, bottom right and center crop and the same for the
            flipped image.
    """
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    else:
        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

    first_five = five_crop(img, size)

    if vertical_flip:
        img = vflip(img)
    else:
        img = hflip(img)

    second_five = five_crop(img, size)
    return first_five + second_five
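
# Ten-crop sketch (sizes are illustrative): five crops of the original plus
# five of its mirror image.
#
#     >>> crops = ten_crop(Image.new('RGB', (64, 64)), 32)
#     >>> len(crops), crops[0].size
#     (10, (32, 32))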


def adjust_brightness(img, brightness_factor):
    """Adjust brightness of an Image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        brightness_factor (float): How much to adjust the brightness. Can be
            any non-negative number. 0 gives a black image, 1 gives the
            original image while 2 increases the brightness by a factor of 2.

    Returns:
        PIL Image: Brightness adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Brightness(img)
    img = enhancer.enhance(brightness_factor)
    return img


def adjust_contrast(img, contrast_factor):
    """Adjust contrast of an Image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        contrast_factor (float): How much to adjust the contrast. Can be any
            non-negative number. 0 gives a solid gray image, 1 gives the
            original image while 2 increases the contrast by a factor of 2.

    Returns:
        PIL Image: Contrast adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Contrast(img)
    img = enhancer.enhance(contrast_factor)
    return img


def adjust_saturation(img, saturation_factor):
    """Adjust color saturation of an image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        saturation_factor (float): How much to adjust the saturation. 0 will
            give a black and white image, 1 will give the original image while
            2 will enhance the saturation by a factor of 2.

    Returns:
        PIL Image: Saturation adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Color(img)
    img = enhancer.enhance(saturation_factor)
    return img


def adjust_hue(img, hue_factor):
    """Adjust hue of an image.

    The image hue is adjusted by converting the image to HSV and
    cyclically shifting the intensities in the hue channel (H).
    The image is then converted back to original image mode.

    `hue_factor` is the amount of shift in H channel and must be in the
    interval `[-0.5, 0.5]`.

    See https://en.wikipedia.org/wiki/Hue for more details on Hue.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        hue_factor (float): How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
            HSV space in positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.

    Returns:
        PIL Image: Hue adjusted image.
    """
    if not(-0.5 <= hue_factor <= 0.5):
        raise ValueError('hue_factor {} is not in [-0.5, 0.5].'.format(hue_factor))

    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    input_mode = img.mode
    if input_mode in {'L', '1', 'I', 'F'}:
        return img

    h, s, v = img.convert('HSV').split()

    np_h = np.array(h, dtype=np.uint8)
    # uint8 addition takes care of rotation across boundaries
    with np.errstate(over='ignore'):
        np_h += np.uint8(hue_factor * 255)
    h = Image.fromarray(np_h, 'L')

    img = Image.merge('HSV', (h, s, v)).convert(input_mode)
    return img
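
# Hue-shift sketch (the colors are illustrative): a hue_factor of 0.5 rotates
# the hue channel half way around the HSV circle, so red shifts to roughly
# its complement (cyan).
#
#     >>> im = Image.new('RGB', (8, 8), (255, 0, 0))
#     >>> shifted = adjust_hue(im, 0.5)    # complementary colours
#     >>> unchanged = adjust_hue(im, 0.0)  # zero shift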


def adjust_gamma(img, gamma, gain=1):
    """Perform gamma correction on an image.

    Also known as Power Law Transform. Intensities in RGB mode are adjusted
    based on the following equation:

        I_out = 255 * gain * ((I_in / 255) ** gamma)

    See https://en.wikipedia.org/wiki/Gamma_correction for more details.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        gamma (float): Non-negative real number. gamma larger than 1 makes the
            shadows darker, while gamma smaller than 1 makes dark regions
            lighter.
        gain (float): The constant multiplier.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if gamma < 0:
        raise ValueError('Gamma should be a non-negative real number')

    input_mode = img.mode
    img = img.convert('RGB')

    np_img = np.array(img, dtype=np.float32)
    np_img = 255 * gain * ((np_img / 255) ** gamma)
    np_img = np.uint8(np.clip(np_img, 0, 255))

    img = Image.fromarray(np_img, 'RGB').convert(input_mode)
    return img
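
# Gamma sketch (values are illustrative): with gain=1 and gamma=2 a mid-grey
# input of 128 maps to 255 * (128 / 255) ** 2, which truncates to 64.
#
#     >>> adjust_gamma(Image.new('L', (1, 1), 128), 2).getpixel((0, 0))
#     64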


class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img
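
# Typical pipeline sketch (the crop size and the ImageNet statistics are
# illustrative): transforms run in list order, so geometric ops come before
# ToTensor and Normalize.
#
#     >>> preprocess = Compose([
#     ...     Resize(256),
#     ...     CenterCrop(224),
#     ...     ToTensor(),
#     ...     Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#     ... ])
#     >>> # tensor = preprocess(pil_image)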


class ToTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return to_tensor(pic)


class ToPILImage(object):
    """Convert a tensor or an ndarray to PIL Image.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL Image while preserving the value range.
    """

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.

        Returns:
            PIL Image: Image converted to PIL Image.

        """
        return to_pil_image(pic)


class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.

    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
    will normalize each channel of the input ``torch.*Tensor`` i.e.
    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized Tensor image.
        """
        return normalize(tensor, self.mean, self.std)


class Resize(object):
    """Resize the input PIL Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e., if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be scaled.

        Returns:
            PIL Image: Rescaled image.
        """
        return resize(img, self.size, self.interpolation)


class Scale(Resize):
    """
    Note: This transform is deprecated in favor of Resize.
    """

    def __init__(self, *args, **kwargs):
        warnings.warn("The use of the transforms.Scale transform is deprecated, " +
                      "please use transforms.Resize instead.")
        super(Scale, self).__init__(*args, **kwargs)


class CenterCrop(object):
    """Crops the given PIL Image at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    @staticmethod
    def get_params(img, output_size):
Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
639
640
641
        """Get parameters for ``crop`` for center crop.

        Args:
            img (PIL Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for center crop.
        """
        w, h = img.size
        th, tw = output_size
        i = int(round((h - th) / 2.))
        j = int(round((w - tw) / 2.))
        return i, j, th, tw

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        i, j, h, w = self.get_params(img, self.size)
        return crop(img, i, j, h, w)


class Pad(object):
    """Pad the given PIL Image on all sides with the given "pad" value.

    Args:
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If a tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
    """

    def __init__(self, padding, fill=0):
        assert isinstance(padding, (numbers.Number, tuple))
        assert isinstance(fill, (numbers.Number, str, tuple))
        if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
            raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                             "{} element tuple".format(len(padding)))

        self.padding = padding
        self.fill = fill

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be padded.

        Returns:
            PIL Image: Padded image.
        """
        return pad(img, self.padding, self.fill)


class Lambda(object):
    """Apply a user-defined lambda as a transform.

    Args:
        lambd (function): Lambda/function to be used for transform.
    """

    def __init__(self, lambd):
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)


class RandomCrop(object):
    """Crop the given PIL Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e., no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w

        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        if self.padding > 0:
            img = pad(img, self.padding)

        i, j, h, w = self.get_params(img, self.size)

        return crop(img, i, j, h, w)


class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL Image randomly with a probability of 0.5."""

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly flipped image.
        """
        if random.random() < 0.5:
            return hflip(img)
        return img


class RandomVerticalFlip(object):
    """Vertically flip the given PIL Image randomly with a probability of 0.5."""

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly flipped image.
        """
        if random.random() < 0.5:
            return vflip(img)
        return img


class RandomResizedCrop(object):
    """Crop the given PIL Image to random size and aspect ratio.

    A crop of random size (0.08 to 1.0 of the original area) and a random
    aspect ratio (3/4 to 4/3) is made. This crop is finally resized to the
    given size.
    This is popularly used to train the Inception networks.

    Args:
        size: expected output size of each edge
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = (size, size)
        self.interpolation = interpolation

    @staticmethod
    def get_params(img):
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop.
        """
        for attempt in range(10):
            area = img.size[0] * img.size[1]
            target_area = random.uniform(0.08, 1.0) * area
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            if w <= img.size[0] and h <= img.size[1]:
                i = random.randint(0, img.size[1] - h)
                j = random.randint(0, img.size[0] - w)
                return i, j, h, w

        # Fallback
        w = min(img.size[0], img.size[1])
        i = (img.size[1] - w) // 2
        j = (img.size[0] - w) // 2
        return i, j, w, w

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped and resized.

        Returns:
            PIL Image: Randomly cropped and resized image.
        """
        i, j, h, w = self.get_params(img)
        return resized_crop(img, i, j, h, w, self.size, self.interpolation)


class RandomSizedCrop(RandomResizedCrop):
    """
    Note: This transform is deprecated in favor of RandomResizedCrop.
    """
    def __init__(self, *args, **kwargs):
        warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " +
                      "please use transforms.RandomResizedCrop instead.")
        super(RandomSizedCrop, self).__init__(*args, **kwargs)


class FiveCrop(object):
    """Crop the given PIL Image into four corners and the central crop.

    Note: this transform returns a tuple of images and there may be a mismatch
    in the number of inputs and targets your ``Dataset`` returns.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
            self.size = size

    def __call__(self, img):
        return five_crop(img, self.size)


class TenCrop(object):
    """Crop the given PIL Image into four corners and the central crop plus the
    flipped version of these (horizontal flipping is used by default).

    Note: this transform returns a tuple of images and there may be a mismatch
    in the number of inputs and targets your ``Dataset`` returns.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        vertical_flip (bool): Use vertical flipping instead of horizontal.
    """

    def __init__(self, size, vertical_flip=False):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
            self.size = size
        self.vertical_flip = vertical_flip

    def __call__(self, img):
        return ten_crop(img, self.size, self.vertical_flip)


class LinearTransformation(object):
    """Transform a tensor image with a square transformation matrix computed
    offline.

    Given transformation_matrix, will flatten the torch.*Tensor, compute the dot
    product with the transformation matrix and reshape the tensor to its
    original shape.

    Applications:
    - whitening: zero-center the data, compute the data covariance matrix
                 [D x D] with np.dot(X.T, X), perform SVD on this matrix and
                 pass it as transformation_matrix.

    Args:
        transformation_matrix (Tensor): tensor [D x D], D = C x H x W
    """

    def __init__(self, transformation_matrix):
        if transformation_matrix.size(0) != transformation_matrix.size(1):
            raise ValueError("transformation_matrix should be square. Got " +
                             "[{} x {}] rectangular matrix.".format(*transformation_matrix.size()))
        self.transformation_matrix = transformation_matrix

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be whitened.

        Returns:
            Tensor: Transformed image.
        """
        if tensor.size(0) * tensor.size(1) * tensor.size(2) != self.transformation_matrix.size(0):
            raise ValueError("tensor and transformation matrix have incompatible shape." +
                             "[{} x {} x {}] != ".format(*tensor.size()) +
                             "{}".format(self.transformation_matrix.size(0)))
        flat_tensor = tensor.view(1, -1)
        transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
        tensor = transformed_tensor.view(tensor.size())
        return tensor
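
# Whitening sketch (a minimal check, not the full recipe from the docstring):
# the identity matrix leaves the flattened tensor unchanged.
#
#     >>> t = torch.rand(3, 4, 4)
#     >>> torch.equal(LinearTransformation(torch.eye(3 * 4 * 4))(t), t)
#     True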


class ColorJitter(object):
    """Randomly change the brightness, contrast and saturation of an image.

    Args:
        brightness (float): How much to jitter brightness. brightness_factor
            is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
        contrast (float): How much to jitter contrast. contrast_factor
            is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
        saturation (float): How much to jitter saturation. saturation_factor
            is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
        hue (float): How much to jitter hue. hue_factor is chosen uniformly from
            [-hue, hue]. Should be >=0 and <= 0.5.
    """
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Get a randomized transform to be applied on image.

        Arguments are the same as those of __init__.

        Returns:
            Transform which randomly adjusts brightness, contrast,
            saturation and hue in a random order.
        """
        transforms = []
        if brightness > 0:
            brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness)
            transforms.append(Lambda(lambda img: adjust_brightness(img, brightness_factor)))

        if contrast > 0:
            contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast)
            transforms.append(Lambda(lambda img: adjust_contrast(img, contrast_factor)))

        if saturation > 0:
            saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation)
            transforms.append(Lambda(lambda img: adjust_saturation(img, saturation_factor)))

        if hue > 0:
            hue_factor = np.random.uniform(-hue, hue)
            transforms.append(Lambda(lambda img: adjust_hue(img, hue_factor)))

        np.random.shuffle(transforms)
        transform = Compose(transforms)

        return transform

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Input image.

        Returns:
            PIL Image: Color jittered image.
        """
        transform = self.get_params(self.brightness, self.contrast,
                                    self.saturation, self.hue)
        return transform(img)
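
# Jitter sketch (the factors are illustrative): each argument bounds a
# uniformly sampled factor, so this transform perturbs brightness and
# contrast by up to 20% and hue by up to 0.1, differently on every call.
#
#     >>> jitter = ColorJitter(brightness=0.2, contrast=0.2, hue=0.1)
#     >>> # augmented = jitter(pil_image)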