transforms.py 13.8 KB
Newer Older
1
from __future__ import division

import collections
import collections.abc
import math
import numbers
import random
import types

import numpy as np
import torch
from PIL import Image, ImageOps

try:
    import accimage
except ImportError:
    accimage = None
soumith's avatar
soumith committed
14

15

Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def _is_pil_image(img):
    """Return True when ``img`` is a PIL Image (or an accimage Image, if accimage is installed)."""
    pil_classes = (Image.Image, accimage.Image) if accimage is not None else (Image.Image,)
    return isinstance(img, pil_classes)


def _is_tensor_image(img):
    return torch.is_tensor(img) and img.ndimension() == 3


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
31
def to_tensor(pic):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to a tensor.

    Converts an image of shape (H x W x C) in the range [0, 255] to a
    torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    (Byte inputs are rescaled; 'I'/'I;16' PIL modes keep their integer range.)

    Args:
        pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.

    Raises:
        TypeError: If ``pic`` is neither a PIL Image nor an ndarray.
    """
    if not(_is_pil_image(pic) or _is_numpy_image(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

    if isinstance(pic, np.ndarray):
        # handle numpy array; _is_numpy_image accepts 2-D (grayscale) arrays,
        # so add a trailing channel axis first -- the previous code crashed on
        # 2-D input because transpose((2, 0, 1)) needs three axes.
        if pic.ndim == 2:
            pic = pic[:, :, None]
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility
        return img.float().div(255)

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image: integer modes keep their dtype, everything else is bytes
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    # only byte images are rescaled to [0, 1]; int images keep their range
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    else:
        return img


def to_pilimage(pic):
    """Convert a tensor or ndarray to a PIL Image, preserving the value range.

    Args:
        pic (Tensor or numpy.ndarray): Image of shape C x H x W (tensor) or
            H x W x C / H x W (ndarray) to be converted to PIL.Image.

    Returns:
        PIL.Image: Image converted to PIL.Image.

    Raises:
        TypeError: If ``pic`` is neither a torch image tensor nor an ndarray.
        AssertionError: If the dtype/shape combination is unsupported.
    """
    if not(_is_numpy_image(pic) or _is_tensor_image(pic)):
        raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))

    npimg = pic
    mode = None
    if isinstance(pic, torch.FloatTensor):
        pic = pic.mul(255).byte()
    if torch.is_tensor(pic):
        npimg = np.transpose(pic.numpy(), (1, 2, 0))
    assert isinstance(npimg, np.ndarray)
    # _is_numpy_image accepts 2-D arrays; the old code indexed shape[2]
    # unconditionally and raised IndexError for them.
    if npimg.ndim == 2 or npimg.shape[2] == 1:
        if npimg.ndim == 3:
            npimg = npimg[:, :, 0]

        # single-channel modes, one per supported dtype
        if npimg.dtype == np.uint8:
            mode = 'L'
        elif npimg.dtype == np.int16:
            mode = 'I;16'
        elif npimg.dtype == np.int32:
            mode = 'I'
        elif npimg.dtype == np.float32:
            mode = 'F'
    else:
        if npimg.dtype == np.uint8:
            mode = 'RGB'
    assert mode is not None, '{} is not supported'.format(npimg.dtype)
    return Image.fromarray(npimg, mode=mode)


def normalize(tensor, mean, std):
    """Normalize a tensor image in-place with mean and standard deviation.

    Args:
        tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        mean (sequence): Per-channel means.
        std (sequence): Per-channel standard deviations.

    Returns:
        Tensor: The input tensor, normalized in-place channel by channel.

    Raises:
        TypeError: If ``tensor`` is not a 3-D torch tensor.
    """
    if not _is_tensor_image(tensor):
        raise TypeError('tensor is not a torch image.')
    # TODO: make efficient
    for channel, channel_mean, channel_std in zip(tensor, mean, std):
        channel.sub_(channel_mean).div_(channel_std)
    return tensor


def scale(img, size, interpolation=Image.BILINEAR):
    """Rescale the input PIL.Image to the given size.

    Args:
        img (PIL.Image): Image to be scaled.
        size (sequence or int): Desired output size. If size is a sequence like
            (w, h), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``.

    Returns:
        PIL.Image: Rescaled image.

    Raises:
        TypeError: If ``img`` is not a PIL Image or ``size`` is malformed.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    # collections.Iterable was removed in Python 3.10; use collections.abc
    if not (isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)):
        raise TypeError('Got inappropriate size arg: {}'.format(size))

    if isinstance(size, int):
        w, h = img.size
        # smaller edge already matches the requested size: nothing to do
        if (w <= h and w == size) or (h <= w and h == size):
            return img
        if w < h:
            ow = size
            oh = int(size * h / w)
            return img.resize((ow, oh), interpolation)
        else:
            oh = size
            ow = int(size * w / h)
            return img.resize((ow, oh), interpolation)
    else:
        return img.resize(size, interpolation)


def pad(img, padding, fill=0):
    """Pad the given PIL.Image on all sides with the given ``fill`` value.

    Args:
        img (PIL.Image): Image to be padded.
        padding (int or tuple): Padding on each border. A single int pads all
            borders; a 2-tuple pads left/right and top/bottom; a 4-tuple pads
            left, top, right and bottom respectively.
        fill: Pixel fill value. Default is 0. If a tuple of length 3, it is
            used to fill R, G, B channels respectively.

    Returns:
        PIL.Image: Padded image.

    Raises:
        TypeError: If ``img``, ``padding`` or ``fill`` has the wrong type.
        ValueError: If ``padding`` is a sequence of unsupported length.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if not isinstance(padding, (numbers.Number, tuple)):
        raise TypeError('Got inappropriate padding arg')
    if not isinstance(fill, (numbers.Number, str, tuple)):
        raise TypeError('Got inappropriate fill arg')

    # collections.Sequence was removed in Python 3.10; use collections.abc
    if isinstance(padding, collections.abc.Sequence) and len(padding) not in [2, 4]:
        raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                         "{} element tuple".format(len(padding)))

    return ImageOps.expand(img, border=padding, fill=fill)


def crop(img, x, y, w, h):
    """Crop the given PIL Image.

    Args:
        img (PIL.Image): Image to be cropped.
        x: Left edge of the crop box.
        y: Top edge of the crop box.
        w: Width of the crop box.
        h: Height of the crop box.

    Returns:
        PIL.Image: Cropped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    box = (x, y, x + w, y + h)
    return img.crop(box)


def scaled_crop(img, x, y, w, h, size, interpolation=Image.BILINEAR):
    """Crop the given PIL Image and rescale the result to ``size``.

    Args:
        img (PIL.Image): Image to be cropped and scaled.
        x, y: Top-left corner of the crop box.
        w, h: Width and height of the crop box.
        size (sequence or int): Target size passed to :func:`scale`.
        interpolation (int, optional): Default is ``PIL.Image.BILINEAR``.

    Returns:
        PIL.Image: Cropped and rescaled image.
    """
    assert _is_pil_image(img), 'img should be PIL Image'
    img = crop(img, x, y, w, h)
    # the previous version dropped the result -- callers always got None
    return scale(img, size, interpolation)


def hflip(img):
    """Horizontally flip the given PIL Image.

    Args:
        img (PIL.Image): Image to be flipped.

    Returns:
        PIL.Image: Mirrored copy of the input image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.transpose(Image.FLIP_LEFT_RIGHT)


soumith's avatar
soumith committed
166
class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        """Apply each stored transform to ``img`` in order and return the result."""
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result


class ToTensor(object):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """

    def __call__(self, pic):
        """Delegate to :func:`to_tensor`.

        Args:
            pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return to_tensor(pic)
204

Adam Paszke's avatar
Adam Paszke committed
205

206
class ToPILImage(object):
    """Convert a tensor to PIL Image.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL.Image while preserving the value range.
    """

    def __call__(self, pic):
        """Delegate to :func:`to_pilimage`.

        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.

        Returns:
            PIL.Image: Image converted to PIL.Image.
        """
        return to_pilimage(pic)
223

soumith's avatar
soumith committed
224
225

class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.

    Given mean: (R, G, B) and std: (R, G, B), will normalize each channel
    of the torch.*Tensor, i.e. channel = (channel - mean) / std.

    Args:
        mean (sequence): Sequence of means for R, G, B channels respectively.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respectively.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """Normalize ``tensor`` using the stored per-channel statistics.

        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized image.
        """
        return normalize(tensor, self.mean, self.std)
soumith's avatar
soumith committed
251
252
253


class Scale(object):
    """Rescale the input PIL.Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (w, h), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # collections.Iterable was removed in Python 3.10; use collections.abc
        assert isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be scaled.

        Returns:
            PIL.Image: Rescaled image.
        """
        return scale(img, self.size, self.interpolation)
soumith's avatar
soumith committed
280
281
282


class CenterCrop(object):
    """Crops the given PIL.Image at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            side = int(size)
            self.size = (side, side)
        else:
            self.size = size

    def get_params(self, img):
        """Compute the centered crop rectangle (x1, y1, tw, th) for ``img``."""
        img_w, img_h = img.size
        crop_h, crop_w = self.size
        left = int(round((img_w - crop_w) / 2.))
        top = int(round((img_h - crop_h) / 2.))
        return left, top, crop_w, crop_h

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        left, top, crop_w, crop_h = self.get_params(img)
        return crop(img, left, top, crop_w, crop_h)
soumith's avatar
soumith committed
314
315


316
class Pad(object):
    """Pad the given PIL.Image on all sides with the given "pad" value.

    Args:
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
    """

    def __init__(self, padding, fill=0):
        assert isinstance(padding, (numbers.Number, tuple))
        assert isinstance(fill, (numbers.Number, str, tuple))
        # collections.Sequence was removed in Python 3.10; use collections.abc
        if isinstance(padding, collections.abc.Sequence) and len(padding) not in [2, 4]:
            raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                             "{} element tuple".format(len(padding)))

        self.padding = padding
        self.fill = fill

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be padded.

        Returns:
            PIL.Image: Padded image.
        """
        return pad(img, self.padding, self.fill)
348

349

Soumith Chintala's avatar
Soumith Chintala committed
350
class Lambda(object):
    """Apply a user-defined lambda as a transform.

    Args:
        lambd (function): Lambda/function to be used for transform.
    """

    def __init__(self, lambd):
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def __call__(self, img):
        """Apply the stored callable to ``img`` and return its result."""
        transform = self.lambd
        return transform(img)

364

soumith's avatar
soumith committed
365
class RandomCrop(object):
    """Crop the given PIL.Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def get_params(self, img):
        """Sample a random crop rectangle for ``img``.

        Returns:
            tuple: (x1, y1, tw, th) -- top-left corner and crop width/height.
        """
        w, h = img.size
        th, tw = self.size
        if w == tw and h == th:
            # crop covers the whole image -- the old code returned the image
            # itself here, which broke the 4-tuple unpacking in __call__
            return 0, 0, tw, th

        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)
        return x1, y1, tw, th

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        if self.padding > 0:
            img = pad(img, self.padding)

        x1, y1, tw, th = self.get_params(img)

        return crop(img, x1, y1, tw, th)
soumith's avatar
soumith committed
409
410
411


class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL.Image randomly with a probability of 0.5."""

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.

        Returns:
            PIL.Image: Randomly flipped image.
        """
        flip = random.random() < 0.5
        if flip:
            return hflip(img)
        return img


class RandomSizedCrop(object):
    """Crop the given PIL.Image to random size and aspect ratio.

    A crop of random size of (0.08 to 1.0) of the original size and a random
    aspect ratio of 3/4 to 4/3 of the original aspect ratio is made. This crop
    is finally resized to given size.
    This is popularly used to train the Inception networks.

    Args:
        size: size of the smaller edge
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def get_params(self, img):
        """Sample a crop rectangle (x, y, w, h) for ``img``.

        Tries up to 10 random area/aspect-ratio combinations, then falls
        back to a centered square crop of the smaller edge.
        """
        for attempt in range(10):
            area = img.size[0] * img.size[1]
            target_area = random.uniform(0.08, 1.0) * area
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            if w <= img.size[0] and h <= img.size[1]:
                x = random.randint(0, img.size[0] - w)
                y = random.randint(0, img.size[1] - h)
                return x, y, w, h

        # Fallback: centered square crop of the smaller edge.
        # PIL images have no .shape attribute -- the old code used img.shape
        # here and raised AttributeError whenever the fallback was reached.
        w = min(img.size[0], img.size[1])
        x = (img.size[0] - w) // 2
        y = (img.size[1] - w) // 2
        return x, y, w, w

    def __call__(self, img):
        """Crop a random region of ``img`` and rescale it to ``self.size``."""
        x, y, w, h = self.get_params(img)
        return scaled_crop(img, x, y, w, h, self.size, self.interpolation)