from __future__ import division
import torch
import math
import random
from PIL import Image, ImageOps, ImageEnhance
try:
    import accimage
except ImportError:
    accimage = None
import numpy as np
import numbers
import types
import collections
import warnings


def _is_pil_image(img):
    if accimage is not None:
        return isinstance(img, (Image.Image, accimage.Image))
    else:
        return isinstance(img, Image.Image)


def _is_tensor_image(img):
    return torch.is_tensor(img) and img.ndimension() == 3


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


def to_tensor(pic):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    See ``ToTensor`` for more details.

    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.
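
    Example (an illustrative sketch; ``np`` is the module-level numpy import):
        >>> arr = np.zeros((32, 32, 3), dtype=np.uint8)
        >>> t = to_tensor(arr)   # FloatTensor of shape (3, 32, 32) with values in [0.0, 1.0]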
    """
    if not(_is_pil_image(pic) or _is_numpy_image(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

    if isinstance(pic, np.ndarray):
        # handle numpy array
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility
        return img.float().div(255)

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    else:
        return img


def to_pil_image(pic, mode=None):
    """Convert a tensor or an ndarray to PIL Image.

    See :class:`~torchvision.transforms.ToPILImage` for more details.

    Args:
        pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).

    .. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes

    Returns:
        PIL Image: Image converted to PIL Image.
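
    Example (an illustrative sketch):
        >>> t = torch.rand(3, 32, 32)          # float values in [0, 1]
        >>> pil_img = to_pil_image(t)          # mode inferred as 'RGB'
        >>> pil_img = to_pil_image(t, 'RGB')   # or pass the mode explicitly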
    """
    if not(_is_numpy_image(pic) or _is_tensor_image(pic)):
        raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))

    npimg = pic
    if isinstance(pic, torch.FloatTensor):
        pic = pic.mul(255).byte()
    if torch.is_tensor(pic):
        npimg = np.transpose(pic.numpy(), (1, 2, 0))

    if not isinstance(npimg, np.ndarray):
        raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' +
                        'not {}'.format(type(npimg)))

    if npimg.shape[2] == 1:
        expected_mode = None
        npimg = npimg[:, :, 0]
        if npimg.dtype == np.uint8:
            expected_mode = 'L'
        elif npimg.dtype == np.int16:
            expected_mode = 'I;16'
        elif npimg.dtype == np.int32:
            expected_mode = 'I'
        elif npimg.dtype == np.float32:
            expected_mode = 'F'
        if mode is not None and mode != expected_mode:
            raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}"
                             .format(mode, npimg.dtype, expected_mode))
        mode = expected_mode

    elif npimg.shape[2] == 4:
        permitted_4_channel_modes = ['RGBA', 'CMYK']
        if mode is not None and mode not in permitted_4_channel_modes:
            raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes))

        if mode is None and npimg.dtype == np.uint8:
            mode = 'RGBA'
    else:
        permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV']
        if mode is not None and mode not in permitted_3_channel_modes:
            raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes))
        if mode is None and npimg.dtype == np.uint8:
            mode = 'RGB'

    if mode is None:
        raise TypeError('Input type {} is not supported'.format(npimg.dtype))

    return Image.fromarray(npimg, mode=mode)


def normalize(tensor, mean, std):
    """Normalize a tensor image with mean and standard deviation.

    See ``Normalize`` for more details.

    Args:
        tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.

    Returns:
        Tensor: Normalized Tensor image.
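
    Example (an illustrative sketch; note that the input tensor is modified in place):
        >>> t = torch.rand(3, 32, 32)
        >>> out = normalize(t, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])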
    """
    if not _is_tensor_image(tensor):
        raise TypeError('tensor is not a torch image.')
    # TODO: make efficient
    for t, m, s in zip(tensor, mean, std):
        t.sub_(m).div_(s)
    return tensor


def resize(img, size, interpolation=Image.BILINEAR):
    """Resize the input PIL Image to the given size.

    Args:
        img (PIL Image): Image to be resized.
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), the output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number maintaining
            the aspect ratio, i.e., if height > width, then the image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``

    Returns:
        PIL Image: Resized image.
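
    Example (an illustrative sketch):
        >>> img = Image.new('RGB', (640, 480))
        >>> resize(img, 256).size          # smaller edge matched to 256
        (341, 256)
        >>> resize(img, (240, 320)).size   # (h, w) = (240, 320); PIL reports size as (w, h)
        (320, 240)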
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)):
        raise TypeError('Got inappropriate size arg: {}'.format(size))

    if isinstance(size, int):
        w, h = img.size
        if (w <= h and w == size) or (h <= w and h == size):
            return img
        if w < h:
            ow = size
            oh = int(size * h / w)
            return img.resize((ow, oh), interpolation)
        else:
            oh = size
            ow = int(size * w / h)
            return img.resize((ow, oh), interpolation)
    else:
        return img.resize(size[::-1], interpolation)


def scale(*args, **kwargs):
    warnings.warn("The use of the transforms.Scale transform is deprecated, " +
                  "please use transforms.Resize instead.")
    return resize(*args, **kwargs)


def pad(img, padding, fill=0):
    """Pad the given PIL Image on all sides with the given "pad" value.

    Args:
        img (PIL Image): Image to be padded.
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.

    Returns:
        PIL Image: Padded image.
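
    Example (an illustrative sketch):
        >>> img = Image.new('RGB', (100, 100))
        >>> pad(img, 4).size        # 4 pixels added on every side
        (108, 108)
        >>> pad(img, (2, 4)).size   # 2 pixels left/right, 4 pixels top/bottom
        (104, 108)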
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if not isinstance(padding, (numbers.Number, tuple)):
        raise TypeError('Got inappropriate padding arg')
    if not isinstance(fill, (numbers.Number, str, tuple)):
        raise TypeError('Got inappropriate fill arg')

    if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
        raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                         "{} element tuple".format(len(padding)))

    return ImageOps.expand(img, border=padding, fill=fill)


def crop(img, i, j, h, w):
    """Crop the given PIL Image.

    Args:
        img (PIL Image): Image to be cropped.
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.

    Returns:
        PIL Image: Cropped image.
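
    Example (an illustrative sketch):
        >>> img = Image.new('RGB', (200, 200))
        >>> patch = crop(img, 10, 20, 100, 100)   # top-left corner at row 10, column 20
        >>> patch.size
        (100, 100)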
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.crop((j, i, j + w, i + h))


def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
    """Crop the given PIL Image and resize it to desired size.

    Notably used in RandomResizedCrop.

    Args:
        img (PIL Image): Image to be cropped.
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.
        size (sequence or int): Desired output size. Same semantics as ``resize``.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``.
    Returns:
        PIL Image: Cropped image.
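
    Example (an illustrative sketch):
        >>> img = Image.new('RGB', (200, 200))
        >>> out = resized_crop(img, 0, 0, 100, 100, size=224)   # crop top-left 100x100, then resize
        >>> out.size
        (224, 224)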
    """
    assert _is_pil_image(img), 'img should be PIL Image'
    img = crop(img, i, j, h, w)
    img = resize(img, size, interpolation)
    return img


def hflip(img):
    """Horizontally flip the given PIL Image.

    Args:
        img (PIL Image): Image to be flipped.

    Returns:
        PIL Image: Horizontally flipped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.transpose(Image.FLIP_LEFT_RIGHT)


def vflip(img):
    """Vertically flip the given PIL Image.

    Args:
        img (PIL Image): Image to be flipped.

    Returns:
        PIL Image: Vertically flipped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.transpose(Image.FLIP_TOP_BOTTOM)


def five_crop(img, size):
    """Crop the given PIL Image into four corners and the central crop.

    .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
         inputs and targets your Dataset returns. See below for an example of how to deal with
         this.

    Args:
         img (PIL Image): Image to be cropped.
         size (sequence or int): Desired output size of the crop. If size is an ``int``
            instead of sequence like (h, w), a square crop of size (size, size) is made.

    Returns:
        tuple: tuple (tl, tr, bl, br, center) corresponding to the top left,
            top right, bottom left, bottom right and center crop.

    Example:
         >>> def transform(img):
         >>>    crops = five_crop(img, size) # this is a list of PIL Images
         >>>    return torch.stack([to_tensor(crop) for crop in crops]) # returns a 4D tensor
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    else:
        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

    w, h = img.size
    crop_h, crop_w = size
    if crop_w > w or crop_h > h:
        raise ValueError("Requested crop size {} is bigger than input size {}".format(size,
                                                                                      (h, w)))
    tl = img.crop((0, 0, crop_w, crop_h))
    tr = img.crop((w - crop_w, 0, w, crop_h))
    bl = img.crop((0, h - crop_h, crop_w, h))
    br = img.crop((w - crop_w, h - crop_h, w, h))
    center = CenterCrop((crop_h, crop_w))(img)
    return (tl, tr, bl, br, center)


def ten_crop(img, size, vertical_flip=False):
    """Crop the given PIL Image into four corners and the central crop plus the flipped version of
    these (horizontal flipping is used by default).

    .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
         inputs and targets your Dataset returns. See below for an example of how to deal with
         this.

    Args:
         img (PIL Image): Image to be cropped.
         size (sequence or int): Desired output size of the crop. If size is an ``int``
            instead of sequence like (h, w), a square crop of size (size, size) is made.
         vertical_flip (bool): Use vertical flipping instead of horizontal.

    Returns:
        tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip,
            br_flip, center_flip) corresponding to the top left, top right,
            bottom left, bottom right and center crop and the same for the
            flipped image.

    Example:
         >>> def transform(img):
         >>>    crops = ten_crop(img, size) # this is a list of PIL Images
         >>>    return torch.stack([to_tensor(crop) for crop in crops]) # returns a 4D tensor
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    else:
        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

    first_five = five_crop(img, size)

    if vertical_flip:
        img = vflip(img)
    else:
        img = hflip(img)

    second_five = five_crop(img, size)
    return first_five + second_five


def adjust_brightness(img, brightness_factor):
    """Adjust brightness of an Image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        brightness_factor (float):  How much to adjust the brightness. Can be
            any non negative number. 0 gives a black image, 1 gives the
            original image while 2 increases the brightness by a factor of 2.

    Returns:
        PIL Image: Brightness adjusted image.
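
    Example (an illustrative sketch; ``img`` is an assumed PIL Image):
        >>> darker = adjust_brightness(img, 0.5)     # halve the brightness
        >>> brighter = adjust_brightness(img, 1.5)   # 50% brighter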
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Brightness(img)
    img = enhancer.enhance(brightness_factor)
    return img


def adjust_contrast(img, contrast_factor):
    """Adjust contrast of an Image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        contrast_factor (float): How much to adjust the contrast. Can be any
            non negative number. 0 gives a solid gray image, 1 gives the
            original image while 2 increases the contrast by a factor of 2.

    Returns:
        PIL Image: Contrast adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Contrast(img)
    img = enhancer.enhance(contrast_factor)
    return img


def adjust_saturation(img, saturation_factor):
    """Adjust color saturation of an image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        saturation_factor (float):  How much to adjust the saturation. 0 will
            give a black and white image, 1 will give the original image while
            2 will enhance the saturation by a factor of 2.

    Returns:
        PIL Image: Saturation adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Color(img)
    img = enhancer.enhance(saturation_factor)
    return img


def adjust_hue(img, hue_factor):
    """Adjust hue of an image.

    The image hue is adjusted by converting the image to HSV and
    cyclically shifting the intensities in the hue channel (H).
    The image is then converted back to original image mode.

    `hue_factor` is the amount of shift in H channel and must be in the
    interval `[-0.5, 0.5]`.

    See https://en.wikipedia.org/wiki/Hue for more details on Hue.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        hue_factor (float):  How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
            HSV space in positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.

    Returns:
        PIL Image: Hue adjusted image.
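
    Example (an illustrative sketch; ``img`` is an assumed RGB PIL Image):
        >>> shifted = adjust_hue(img, 0.2)   # rotate the hue channel by 0.2 of the full cycle
        >>> same = adjust_hue(img, 0.0)      # a factor of 0 leaves the hue unchanged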
    """
    if not(-0.5 <= hue_factor <= 0.5):
        raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))

    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    input_mode = img.mode
    if input_mode in {'L', '1', 'I', 'F'}:
        return img

    h, s, v = img.convert('HSV').split()

    np_h = np.array(h, dtype=np.uint8)
    # uint8 addition takes care of rotation across boundaries
    with np.errstate(over='ignore'):
        np_h += np.uint8(hue_factor * 255)
    h = Image.fromarray(np_h, 'L')

    img = Image.merge('HSV', (h, s, v)).convert(input_mode)
    return img


def adjust_gamma(img, gamma, gain=1):
    """Perform gamma correction on an image.

    Also known as Power Law Transform. Intensities in RGB mode are adjusted
    based on the following equation:

        I_out = 255 * gain * ((I_in / 255) ** gamma)

    See https://en.wikipedia.org/wiki/Gamma_correction for more details.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        gamma (float): Non negative real number. gamma larger than 1 makes the
            shadows darker, while gamma smaller than 1 makes dark regions
            lighter.
        gain (float): The constant multiplier.
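
    Example (an illustrative sketch; ``img`` is an assumed PIL Image):
        >>> darker = adjust_gamma(img, 2.0)    # gamma > 1 darkens the image
        >>> lighter = adjust_gamma(img, 0.5)   # gamma < 1 brightens the image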
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if gamma < 0:
        raise ValueError('Gamma should be a non-negative real number')

    input_mode = img.mode
    img = img.convert('RGB')

    np_img = np.array(img, dtype=np.float32)
    np_img = 255 * gain * ((np_img / 255) ** gamma)
    np_img = np.uint8(np.clip(np_img, 0, 255))

    img = Image.fromarray(np_img, 'RGB').convert(input_mode)
    return img


class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img


class ToTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return to_tensor(pic)


class ToPILImage(object):
    """Convert a tensor or an ndarray to PIL Image.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL Image while preserving the value range.

    Args:
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
            If ``mode`` is ``None`` (default) there are some assumptions made about the input data:
            1. If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
            2. If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
            3. If the input has 1 channel, the ``mode`` is determined by the data type (i.e.,
            ``int``, ``float``, ``short``).

    .. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes
    """
    def __init__(self, mode=None):
        self.mode = mode

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.

        Returns:
            PIL Image: Image converted to PIL Image.

        """
        return to_pil_image(pic, self.mode)


class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.
    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
    will normalize each channel of the input ``torch.*Tensor`` i.e.
    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
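
    Example (an illustrative sketch; the statistics below are the commonly used ImageNet values):
        >>> normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        >>> output = normalize(tensor)   # ``tensor`` is an assumed (3, H, W) float tensor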
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized Tensor image.
        """
        return normalize(tensor, self.mean, self.std)


class Resize(object):
    """Resize the input PIL Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), the output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number,
            i.e., if height > width, then the image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be scaled.

        Returns:
            PIL Image: Rescaled image.
        """
        return resize(img, self.size, self.interpolation)


class Scale(Resize):
    """
    Note: This transform is deprecated in favor of Resize.
    """
    def __init__(self, *args, **kwargs):
        warnings.warn("The use of the transforms.Scale transform is deprecated, " +
                      "please use transforms.Resize instead.")
        super(Scale, self).__init__(*args, **kwargs)


class CenterCrop(object):
    """Crops the given PIL Image at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for center crop.

        Args:
            img (PIL Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for center crop.
        """
        w, h = img.size
        th, tw = output_size
        i = int(round((h - th) / 2.))
        j = int(round((w - tw) / 2.))
        return i, j, th, tw

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        i, j, h, w = self.get_params(img, self.size)
        return crop(img, i, j, h, w)


class Pad(object):
    """Pad the given PIL Image on all sides with the given "pad" value.

    Args:
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
    """

    def __init__(self, padding, fill=0):
        assert isinstance(padding, (numbers.Number, tuple))
        assert isinstance(fill, (numbers.Number, str, tuple))
        if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
            raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                             "{} element tuple".format(len(padding)))

        self.padding = padding
        self.fill = fill

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be padded.

        Returns:
            PIL Image: Padded image.
        """
        return pad(img, self.padding, self.fill)


class Lambda(object):
    """Apply a user-defined lambda as a transform.

    Args:
        lambd (function): Lambda/function to be used for transform.
    """

    def __init__(self, lambd):
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)


class RandomCrop(object):
    """Crop the given PIL Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w

        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
817
            PIL Image: Cropped image.
        """
        if self.padding > 0:
            img = pad(img, self.padding)

        i, j, h, w = self.get_params(img, self.size)

        return crop(img, i, j, h, w)


class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL Image randomly with a probability of 0.5."""

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly flipped image.
        """
        if random.random() < 0.5:
            return hflip(img)
        return img


class RandomVerticalFlip(object):
    """Vertically flip the given PIL Image randomly with a probability of 0.5."""

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly flipped image.
        """
        if random.random() < 0.5:
            return vflip(img)
        return img


class RandomResizedCrop(object):
    """Crop the given PIL Image to random size and aspect ratio.

    A crop of random size (0.08 to 1.0 of the original area) and a random
    aspect ratio (3/4 to 4/3 of the original aspect ratio) is made. This crop
    is finally resized to the given size.
    This is popularly used to train the Inception networks.

    Args:
        size: expected output size of each edge
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = (size, size)
        self.interpolation = interpolation

    @staticmethod
    def get_params(img):
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop.
        """
        for attempt in range(10):
            area = img.size[0] * img.size[1]
            target_area = random.uniform(0.08, 1.0) * area
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            if w <= img.size[0] and h <= img.size[1]:
                i = random.randint(0, img.size[1] - h)
                j = random.randint(0, img.size[0] - w)
                return i, j, h, w

        # Fallback
        w = min(img.size[0], img.size[1])
        i = (img.size[1] - w) // 2
        j = (img.size[0] - w) // 2
        return i, j, w, w

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped and resized.

        Returns:
            PIL Image: Randomly cropped and resized image.
        """
        i, j, h, w = self.get_params(img)
        return resized_crop(img, i, j, h, w, self.size, self.interpolation)


class RandomSizedCrop(RandomResizedCrop):
    """
    Note: This transform is deprecated in favor of RandomResizedCrop.
    """
    def __init__(self, *args, **kwargs):
        warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " +
                      "please use transforms.RandomResizedCrop instead.")
        super(RandomSizedCrop, self).__init__(*args, **kwargs)


class FiveCrop(object):
    """Crop the given PIL Image into four corners and the central crop

    .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
         inputs and targets your Dataset returns. See below for an example of how to deal with
         this.

    Args:
         size (sequence or int): Desired output size of the crop. If size is an ``int``
            instead of sequence like (h, w), a square crop of size (size, size) is made.

    Example:
         >>> transform = Compose([
         >>>    FiveCrop(size), # this is a list of PIL Images
         >>>    Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
         >>> ])
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
            self.size = size

    def __call__(self, img):
        return five_crop(img, self.size)


class TenCrop(object):
    """Crop the given PIL Image into four corners and the central crop plus the flipped version of
    these (horizontal flipping is used by default)

    .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
         inputs and targets your Dataset returns. See below for an example of how to deal with
         this.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        vertical_flip(bool): Use vertical flipping instead of horizontal

    Example:
         >>> transform = Compose([
         >>>    TenCrop(size), # this is a list of PIL Images
         >>>    Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
         >>> ])
         >>> #In your test loop you can do the following:
         >>> input, target = batch # input is a 5d tensor, target is 2d
         >>> bs, ncrops, c, h, w = input.size()
         >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
         >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    """

    def __init__(self, size, vertical_flip=False):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
            self.size = size
        self.vertical_flip = vertical_flip

    def __call__(self, img):
        return ten_crop(img, self.size, self.vertical_flip)


class LinearTransformation(object):
    """Transform a tensor image with a square transformation matrix computed
    offline.

    Given transformation_matrix, will flatten the torch.*Tensor, compute the dot
    product with the transformation matrix and reshape the tensor to its
    original shape.

    Applications:
    - whitening: zero-center the data, compute the data covariance matrix
                 [D x D] with np.dot(X.T, X), perform SVD on this matrix and
                 pass it as transformation_matrix.

    Args:
        transformation_matrix (Tensor): tensor [D x D], D = C x H x W
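
    Example (an illustrative sketch; an identity matrix leaves the tensor unchanged):
        >>> t = torch.rand(3, 8, 8)
        >>> identity = torch.eye(3 * 8 * 8)
        >>> out = LinearTransformation(identity)(t)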
    """

    def __init__(self, transformation_matrix):
        if transformation_matrix.size(0) != transformation_matrix.size(1):
            raise ValueError("transformation_matrix should be square. Got " +
                             "[{} x {}] rectangular matrix.".format(*transformation_matrix.size()))
        self.transformation_matrix = transformation_matrix

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be whitened.

        Returns:
            Tensor: Transformed image.
        """
        if tensor.size(0) * tensor.size(1) * tensor.size(2) != self.transformation_matrix.size(0):
            raise ValueError("tensor and transformation matrix have incompatible shape." +
                             "[{} x {} x {}] != ".format(*tensor.size()) +
                             "{}".format(self.transformation_matrix.size(0)))
        flat_tensor = tensor.view(1, -1)
        transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
        tensor = transformed_tensor.view(tensor.size())
        return tensor


class ColorJitter(object):
    """Randomly change the brightness, contrast and saturation of an image.

    Args:
        brightness (float): How much to jitter brightness. brightness_factor
            is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
        contrast (float): How much to jitter contrast. contrast_factor
            is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
        saturation (float): How much to jitter saturation. saturation_factor
            is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
        hue(float): How much to jitter hue. hue_factor is chosen uniformly from
            [-hue, hue]. Should be >=0 and <= 0.5.
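
    Example (an illustrative sketch; ``img`` is an assumed PIL Image):
        >>> jitter = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        >>> jittered = jitter(img)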
    """
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Get a randomized transform to be applied on image.

        Arguments are same as that of __init__.

        Returns:
            Transform which randomly adjusts brightness, contrast and
            saturation in a random order.
        """
        transforms = []
        if brightness > 0:
            brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness)
            transforms.append(Lambda(lambda img: adjust_brightness(img, brightness_factor)))

        if contrast > 0:
            contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast)
            transforms.append(Lambda(lambda img: adjust_contrast(img, contrast_factor)))

        if saturation > 0:
            saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation)
            transforms.append(Lambda(lambda img: adjust_saturation(img, saturation_factor)))

        if hue > 0:
            hue_factor = np.random.uniform(-hue, hue)
            transforms.append(Lambda(lambda img: adjust_hue(img, hue_factor)))

        np.random.shuffle(transforms)
        transform = Compose(transforms)

        return transform

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Input image.

        Returns:
            PIL Image: Color jittered image.
        """
        transform = self.get_params(self.brightness, self.contrast,
                                    self.saturation, self.hue)
        return transform(img)