from __future__ import division
import torch
import math
import random
from PIL import Image, ImageOps, ImageEnhance
try:
    import accimage
except ImportError:
    accimage = None
import numpy as np
import numbers
import types
import collections
import warnings


def _is_pil_image(img):
    if accimage is not None:
        return isinstance(img, (Image.Image, accimage.Image))
    else:
        return isinstance(img, Image.Image)


def _is_tensor_image(img):
    return torch.is_tensor(img) and img.ndimension() == 3


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


def to_tensor(pic):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    See ``ToTensor`` for more details.

    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.
    """
    if not (_is_pil_image(pic) or _is_numpy_image(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

    if isinstance(pic, np.ndarray):
        # handle numpy array; give 2D (grayscale) arrays a channel dimension
        # so the HWC -> CHW transpose below also works for them
        if pic.ndim == 2:
            pic = pic[:, :, None]
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility
        return img.float().div(255)

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    else:
        return img
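
# Usage sketch for ``to_tensor`` (the file name is illustrative):
#
#     >>> img = Image.open('example.jpg').convert('RGB')
#     >>> t = to_tensor(img)              # FloatTensor of shape (3, H, W)
#     >>> 0.0 <= t.min() and t.max() <= 1.0
#     True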


def to_pil_image(pic):
    """Convert a tensor or an ndarray to PIL Image.

    See ``ToPILImage`` for more details.

    Args:
        pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.

    Returns:
        PIL Image: Image converted to PIL Image.
    """
    if not (_is_numpy_image(pic) or _is_tensor_image(pic)):
        raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))

    npimg = pic
    mode = None
    if isinstance(pic, torch.FloatTensor):
        pic = pic.mul(255).byte()
    if torch.is_tensor(pic):
        npimg = np.transpose(pic.numpy(), (1, 2, 0))
    assert isinstance(npimg, np.ndarray)
    if npimg.shape[2] == 1:
        npimg = npimg[:, :, 0]

        if npimg.dtype == np.uint8:
            mode = 'L'
        elif npimg.dtype == np.int16:
            mode = 'I;16'
        elif npimg.dtype == np.int32:
            mode = 'I'
        elif npimg.dtype == np.float32:
            mode = 'F'
    elif npimg.shape[2] == 4:
        if npimg.dtype == np.uint8:
            mode = 'RGBA'
    else:
        if npimg.dtype == np.uint8:
            mode = 'RGB'
    assert mode is not None, '{} is not supported'.format(npimg.dtype)
    return Image.fromarray(npimg, mode=mode)
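
# Round-trip sketch: for an 8-bit RGB input, ``to_pil_image`` undoes
# ``to_tensor`` up to the byte quantization, since the FloatTensor values
# are scaled back by 255:
#
#     >>> t = to_tensor(img)              # img: any RGB PIL Image
#     >>> img2 = to_pil_image(t)
#     >>> img2.mode
#     'RGB'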


def normalize(tensor, mean, std):
    """Normalize a tensor image with mean and standard deviation.

    See ``Normalize`` for more details.

    Args:
        tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        mean (sequence): Sequence of means for R, G, B channels respectively.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respectively.

    Returns:
        Tensor: Normalized image.
    """
    if not _is_tensor_image(tensor):
        raise TypeError('tensor is not a torch image.')
    # TODO: make efficient
    for t, m, s in zip(tensor, mean, std):
        t.sub_(m).div_(s)
    return tensor
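
# Sketch: normalizing with the per-channel statistics commonly used for
# ImageNet-trained models (the values are a popular convention, not part of
# this API). Note the operation modifies ``tensor`` in place:
#
#     >>> t = to_tensor(img)
#     >>> t = normalize(t, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])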


def resize(img, size, interpolation=Image.BILINEAR):
    """Resize the input PIL Image to the given size.

    Args:
        img (PIL Image): Image to be resized.
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), the output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number maintaining
            the aspect ratio. i.e., if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``

    Returns:
        PIL Image: Resized image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)):
        raise TypeError('Got inappropriate size arg: {}'.format(size))

    if isinstance(size, int):
        w, h = img.size
        if (w <= h and w == size) or (h <= w and h == size):
            return img
        if w < h:
            ow = size
            oh = int(size * h / w)
            return img.resize((ow, oh), interpolation)
        else:
            oh = size
            ow = int(size * w / h)
            return img.resize((ow, oh), interpolation)
    else:
        return img.resize(size[::-1], interpolation)
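
# Sketch of the two ``size`` semantics, assuming a hypothetical 640x480
# (w x h) input:
#
#     >>> out = resize(img, 256)          # shorter edge -> 256: 341x256 (w x h)
#     >>> out = resize(img, (240, 320))   # exact (h, w) target: 320x240 (w x h)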


def scale(*args, **kwargs):
    warnings.warn("The use of the transforms.Scale transform is deprecated, " +
                  "please use transforms.Resize instead.")
    return resize(*args, **kwargs)


def pad(img, padding, fill=0):
    """Pad the given PIL Image on all sides with the given "pad" value.

    Args:
        img (PIL Image): Image to be padded.
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.

    Returns:
        PIL Image: Padded image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if not isinstance(padding, (numbers.Number, tuple)):
        raise TypeError('Got inappropriate padding arg')
    if not isinstance(fill, (numbers.Number, str, tuple)):
        raise TypeError('Got inappropriate fill arg')

    if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
        raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                         "{} element tuple".format(len(padding)))

    return ImageOps.expand(img, border=padding, fill=fill)
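
# Sketch of the three ``padding`` forms (fill values are arbitrary examples):
#
#     >>> out = pad(img, 4)                              # 4 px on every border
#     >>> out = pad(img, (4, 8))                         # 4 left/right, 8 top/bottom
#     >>> out = pad(img, (1, 2, 3, 4), fill=(255, 0, 0)) # left, top, right, bottom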


def crop(img, i, j, h, w):
    """Crop the given PIL Image.

    Args:
        img (PIL Image): Image to be cropped.
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.

    Returns:
        PIL Image: Cropped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.crop((j, i, j + w, i + h))


def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
    """Crop the given PIL Image and resize it to desired size.

    Notably used in RandomResizedCrop.

    Args:
        img (PIL Image): Image to be cropped.
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.
        size (sequence or int): Desired output size. Same semantics as ``resize``.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``.

    Returns:
        PIL Image: Cropped image.
    """
    assert _is_pil_image(img), 'img should be PIL Image'
    img = crop(img, i, j, h, w)
    img = resize(img, size, interpolation)
    return img
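
# Sketch: cut a 100x100 patch with top-left corner at (i, j) = (10, 20) and
# rescale it to 224x224 (all numbers are illustrative):
#
#     >>> out = resized_crop(img, 10, 20, 100, 100, size=(224, 224))
#     >>> out.size
#     (224, 224)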


def hflip(img):
    """Horizontally flip the given PIL Image.

    Args:
        img (PIL Image): Image to be flipped.

    Returns:
        PIL Image: Horizontally flipped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.transpose(Image.FLIP_LEFT_RIGHT)


def vflip(img):
    """Vertically flip the given PIL Image.

    Args:
        img (PIL Image): Image to be flipped.

    Returns:
        PIL Image: Vertically flipped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.transpose(Image.FLIP_TOP_BOTTOM)


def five_crop(img, size):
    """Crop the given PIL Image into four corners and the central crop.

    .. Note::
        This transform returns a tuple of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
       img (PIL Image): Image to be cropped.
       size (sequence or int): Desired output size of the crop. If size is an
           int instead of sequence like (h, w), a square crop (size, size) is
           made.

    Returns:
        tuple: tuple (tl, tr, bl, br, center) corresponding to the top left,
            top right, bottom left, bottom right and center crop.
    """
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    else:
        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

    w, h = img.size
    crop_h, crop_w = size
    if crop_w > w or crop_h > h:
        raise ValueError("Requested crop size {} is bigger than input size {}".format(size,
                                                                                      (h, w)))
    tl = img.crop((0, 0, crop_w, crop_h))
    tr = img.crop((w - crop_w, 0, w, crop_h))
    bl = img.crop((0, h - crop_h, crop_w, h))
    br = img.crop((w - crop_w, h - crop_h, w, h))
    center = CenterCrop((crop_h, crop_w))(img)
    return (tl, tr, bl, br, center)
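
# Sketch: five 224x224 views of one image (the input must be at least
# 224x224, otherwise ``five_crop`` raises ValueError):
#
#     >>> tl, tr, bl, br, center = five_crop(img, 224)
#     >>> center.size
#     (224, 224)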


def ten_crop(img, size, vertical_flip=False):
    """Crop the given PIL Image into four corners and the central crop plus the
       flipped version of these (horizontal flipping is used by default).

    .. Note::
        This transform returns a tuple of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

       Args:
           img (PIL Image): Image to be cropped.
           size (sequence or int): Desired output size of the crop. If size is an
               int instead of sequence like (h, w), a square crop (size, size) is
               made.
           vertical_flip (bool): Use vertical flipping instead of horizontal.

       Returns:
           tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip,
               br_flip, center_flip) corresponding to the top left, top right,
               bottom left, bottom right and center crop and the same for the
               flipped image.
    """
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    else:
        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

    first_five = five_crop(img, size)

    if vertical_flip:
        img = vflip(img)
    else:
        img = hflip(img)

    second_five = five_crop(img, size)
    return first_five + second_five
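
# Sketch of the usual ten-crop evaluation trick: run a model on all ten views
# and average the predictions (only the batching is shown; the model itself
# is assumed):
#
#     >>> crops = ten_crop(img, 224)                          # 5 crops + 5 flipped
#     >>> batch = torch.stack([to_tensor(c) for c in crops])  # (10, 3, 224, 224)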


def adjust_brightness(img, brightness_factor):
    """Adjust brightness of an Image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        brightness_factor (float):  How much to adjust the brightness. Can be
            any non-negative number. 0 gives a black image, 1 gives the
            original image while 2 increases the brightness by a factor of 2.

    Returns:
        PIL Image: Brightness adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Brightness(img)
    img = enhancer.enhance(brightness_factor)
    return img
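
# Sketch of the factor semantics shared by the brightness, contrast and
# saturation adjusters:
#
#     >>> dark = adjust_brightness(img, 0.5)      # halve the brightness
#     >>> same = adjust_brightness(img, 1.0)      # identity
#     >>> bright = adjust_brightness(img, 2.0)    # double the brightness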


def adjust_contrast(img, contrast_factor):
    """Adjust contrast of an Image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        contrast_factor (float): How much to adjust the contrast. Can be any
            non-negative number. 0 gives a solid gray image, 1 gives the
            original image while 2 increases the contrast by a factor of 2.

    Returns:
        PIL Image: Contrast adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Contrast(img)
    img = enhancer.enhance(contrast_factor)
    return img


def adjust_saturation(img, saturation_factor):
    """Adjust color saturation of an image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        saturation_factor (float):  How much to adjust the saturation. 0 will
            give a black and white image, 1 will give the original image while
            2 will enhance the saturation by a factor of 2.

    Returns:
        PIL Image: Saturation adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Color(img)
    img = enhancer.enhance(saturation_factor)
    return img


def adjust_hue(img, hue_factor):
    """Adjust hue of an image.

    The image hue is adjusted by converting the image to HSV and
    cyclically shifting the intensities in the hue channel (H).
    The image is then converted back to original image mode.

    `hue_factor` is the amount of shift in H channel and must be in the
    interval `[-0.5, 0.5]`.

    See https://en.wikipedia.org/wiki/Hue for more details on Hue.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        hue_factor (float):  How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
            HSV space in positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.

    Returns:
        PIL Image: Hue adjusted image.
    """
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError('hue_factor {} is not in [-0.5, 0.5].'.format(hue_factor))

    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    input_mode = img.mode
    if input_mode in {'L', '1', 'I', 'F'}:
        return img

    h, s, v = img.convert('HSV').split()

    np_h = np.array(h, dtype=np.uint8)
    # uint8 addition takes care of rotation across boundaries
    with np.errstate(over='ignore'):
        np_h += np.uint8(hue_factor * 255)
    h = Image.fromarray(np_h, 'L')

    img = Image.merge('HSV', (h, s, v)).convert(input_mode)
    return img
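
# Worked example of the shift: hue_factor=0.25 adds uint8(0.25 * 255) = 63 to
# every H value, wrapping modulo 256, i.e. roughly a quarter turn around the
# color wheel:
#
#     >>> shifted = adjust_hue(img, 0.25)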


def adjust_gamma(img, gamma, gain=1):
    """Perform gamma correction on an image.

    Also known as Power Law Transform. Intensities in RGB mode are adjusted
    based on the following equation:

        I_out = 255 * gain * ((I_in / 255) ** gamma)

    See https://en.wikipedia.org/wiki/Gamma_correction for more details.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        gamma (float): Non-negative real number. gamma larger than 1 makes the
            shadows darker, while gamma smaller than 1 makes dark regions
            lighter.
        gain (float): The constant multiplier.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if gamma < 0:
        raise ValueError('Gamma should be a non-negative real number')

    input_mode = img.mode
    img = img.convert('RGB')

    np_img = np.array(img, dtype=np.float32)
    np_img = 255 * gain * ((np_img / 255) ** gamma)
    np_img = np.uint8(np.clip(np_img, 0, 255))

    img = Image.fromarray(np_img, 'RGB').convert(input_mode)
    return img
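
# Worked example of the power law: with gamma=2 and gain=1, a mid-gray input
# I_in = 128 maps to 255 * (128 / 255) ** 2 ~= 64, darkening the shadows:
#
#     >>> darker = adjust_gamma(img, 2)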


class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img


class ToTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return to_tensor(pic)


class ToPILImage(object):
    """Convert a tensor or an ndarray to PIL Image.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL Image while preserving the value range.
    """

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.

        Returns:
            PIL Image: Image converted to PIL Image.

        """
        return to_pil_image(pic)


class Normalize(object):
    """Normalize an tensor image with mean and standard deviation.

    Given mean: (R, G, B) and std: (R, G, B),
    will normalize each channel of the torch.*Tensor, i.e.
    channel = (channel - mean) / std

    Args:
        mean (sequence): Sequence of means for R, G, B channels respectively.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respectively.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized image.
        """
        return normalize(tensor, self.mean, self.std)


class Resize(object):
    """Resize the input PIL Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e., if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be scaled.

        Returns:
            PIL Image: Rescaled image.
        """
        return resize(img, self.size, self.interpolation)


class Scale(Resize):
    """
    Note: This transform is deprecated in favor of Resize.
    """
    def __init__(self, *args, **kwargs):
        warnings.warn("The use of the transforms.Scale transform is deprecated, " +
                      "please use transforms.Resize instead.")
        super(Scale, self).__init__(*args, **kwargs)


class CenterCrop(object):
    """Crops the given PIL Image at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for center crop.

        Args:
            img (PIL Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for center crop.
        """
        w, h = img.size
        th, tw = output_size
        i = int(round((h - th) / 2.))
        j = int(round((w - tw) / 2.))
        return i, j, th, tw

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        i, j, h, w = self.get_params(img, self.size)
        return crop(img, i, j, h, w)


class Pad(object):
    """Pad the given PIL Image on all sides with the given "pad" value.

    Args:
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
    """

    def __init__(self, padding, fill=0):
        assert isinstance(padding, (numbers.Number, tuple))
        assert isinstance(fill, (numbers.Number, str, tuple))
        if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
            raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                             "{} element tuple".format(len(padding)))

        self.padding = padding
        self.fill = fill

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be padded.

        Returns:
            PIL Image: Padded image.
        """
        return pad(img, self.padding, self.fill)


class Lambda(object):
    """Apply a user-defined lambda as a transform.

    Args:
        lambd (function): Lambda/function to be used for transform.
    """

    def __init__(self, lambd):
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)


class RandomCrop(object):
    """Crop the given PIL Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e. no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w

        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        if self.padding > 0:
            img = pad(img, self.padding)

        i, j, h, w = self.get_params(img, self.size)

        return crop(img, i, j, h, w)
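
# ``get_params`` is a staticmethod so the same random crop can be applied to
# paired inputs, e.g. an image and its segmentation mask (``mask`` is
# hypothetical):
#
#     >>> i, j, h, w = RandomCrop.get_params(img, (224, 224))
#     >>> img_crop = crop(img, i, j, h, w)
#     >>> mask_crop = crop(mask, i, j, h, w)   # identical region from the mask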


class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL Image randomly with a probability of 0.5."""

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly flipped image.
        """
        if random.random() < 0.5:
            return hflip(img)
        return img


class RandomVerticalFlip(object):
    """Vertically flip the given PIL Image randomly with a probability of 0.5."""

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly flipped image.
        """
        if random.random() < 0.5:
            return vflip(img)
        return img


class RandomResizedCrop(object):
    """Crop the given PIL Image to random size and aspect ratio.

    A crop of random size (0.08 to 1.0 of the original area) and a random
    aspect ratio (3/4 to 4/3 of the original aspect ratio) is made. This crop
    is finally resized to the given size.
    This is popularly used to train the Inception networks.

    Args:
        size: expected output size of each edge
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = (size, size)
        self.interpolation = interpolation

    @staticmethod
    def get_params(img):
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop.
        """
        for attempt in range(10):
            area = img.size[0] * img.size[1]
            target_area = random.uniform(0.08, 1.0) * area
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            if w <= img.size[0] and h <= img.size[1]:
                i = random.randint(0, img.size[1] - h)
                j = random.randint(0, img.size[0] - w)
                return i, j, h, w

        # Fallback
        w = min(img.size[0], img.size[1])
        i = (img.size[1] - w) // 2
        j = (img.size[0] - w) // 2
        return i, j, w, w

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped and resized.

        Returns:
            PIL Image: Randomly cropped and resized image.
        """
        i, j, h, w = self.get_params(img)
        return resized_crop(img, i, j, h, w, self.size, self.interpolation)


class RandomSizedCrop(RandomResizedCrop):
    """
    Note: This transform is deprecated in favor of RandomResizedCrop.
    """
    def __init__(self, *args, **kwargs):
        warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " +
                      "please use transforms.RandomResizedCrop instead.")
        super(RandomSizedCrop, self).__init__(*args, **kwargs)


class FiveCrop(object):
    """Crop the given PIL Image into four corners and the central crop.

       Note: this transform returns a tuple of images and there may be a mismatch in the number of
       inputs and targets your `Dataset` returns.

       Args:
           size (sequence or int): Desired output size of the crop. If size is an
               int instead of sequence like (h, w), a square crop (size, size) is
               made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
            self.size = size

    def __call__(self, img):
        return five_crop(img, self.size)


class TenCrop(object):
    """Crop the given PIL Image into four corners and the central crop plus the
       flipped version of these (horizontal flipping is used by default)

       Note: this transform returns a tuple of images and there may be a mismatch in the number of
       inputs and targets your `Dataset` returns.

       Args:
           size (sequence or int): Desired output size of the crop. If size is an
               int instead of sequence like (h, w), a square crop (size, size) is
               made.
           vertical_flip (bool): Use vertical flipping instead of horizontal.
    """

    def __init__(self, size, vertical_flip=False):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
            self.size = size
        self.vertical_flip = vertical_flip

    def __call__(self, img):
        return ten_crop(img, self.size, self.vertical_flip)


class LinearTransformation(object):
    """Transform a tensor image with a square transformation matrix computed
    offline.

    Given transformation_matrix, will flatten the torch.*Tensor, compute the dot
    product with the transformation matrix and reshape the tensor to its
    original shape.

    Applications:
    - whitening: zero-center the data, compute the data covariance matrix
                 [D x D] with np.dot(X.T, X), perform SVD on this matrix and
                 pass it as transformation_matrix.

    Args:
        transformation_matrix (Tensor): tensor [D x D], D = C x H x W
    """

    def __init__(self, transformation_matrix):
        if transformation_matrix.size(0) != transformation_matrix.size(1):
            raise ValueError("transformation_matrix should be square. Got " +
                             "[{} x {}] rectangular matrix.".format(*transformation_matrix.size()))
        self.transformation_matrix = transformation_matrix

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be whitened.

        Returns:
            Tensor: Transformed image.
        """
        if tensor.size(0) * tensor.size(1) * tensor.size(2) != self.transformation_matrix.size(0):
            raise ValueError("tensor and transformation matrix have incompatible shape." +
                             "[{} x {} x {}] != ".format(*tensor.size()) +
                             "{}".format(self.transformation_matrix.size(0)))
        flat_tensor = tensor.view(1, -1)
        transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
        tensor = transformed_tensor.view(tensor.size())
        return tensor
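
# Sketch of computing a whitening matrix offline, following the docstring
# recipe (``X`` is a hypothetical (N, D) array of zero-centered, flattened
# training images):
#
#     >>> sigma = np.dot(X.T, X) / X.shape[0]        # (D, D) covariance
#     >>> U, S, _ = np.linalg.svd(sigma)
#     >>> W = np.dot(U / np.sqrt(S + 1e-5), U.T)     # ZCA whitening matrix
#     >>> whiten = LinearTransformation(torch.from_numpy(W).float())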


class ColorJitter(object):
    """Randomly change the brightness, contrast and saturation of an image.

    Args:
        brightness (float): How much to jitter brightness. brightness_factor
            is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
        contrast (float): How much to jitter contrast. contrast_factor
            is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
        saturation (float): How much to jitter saturation. saturation_factor
            is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
        hue (float): How much to jitter hue. hue_factor is chosen uniformly from
            [-hue, hue]. Should be >= 0 and <= 0.5.
    """
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Get a randomized transform to be applied on image.

        Arguments are same as that of __init__.

        Returns:
            Transform which randomly adjusts brightness, contrast and
            saturation in a random order.
        """
        transforms = []
        if brightness > 0:
            brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness)
            transforms.append(Lambda(lambda img: adjust_brightness(img, brightness_factor)))

        if contrast > 0:
            contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast)
            transforms.append(Lambda(lambda img: adjust_contrast(img, contrast_factor)))

        if saturation > 0:
            saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation)
            transforms.append(Lambda(lambda img: adjust_saturation(img, saturation_factor)))

        if hue > 0:
            hue_factor = np.random.uniform(-hue, hue)
            transforms.append(Lambda(lambda img: adjust_hue(img, hue_factor)))

        np.random.shuffle(transforms)
        transform = Compose(transforms)

        return transform

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Input image.

        Returns:
            PIL Image: Color jittered image.
        """
        transform = self.get_params(self.brightness, self.contrast,
                                    self.saturation, self.hue)
        return transform(img)
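
# Usage sketch: a mild jitter for training-time augmentation (the strengths
# are arbitrary example values, not recommendations):
#
#     >>> jitter = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
#     >>> out = jitter(img)   # a fresh random adjustment on every call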