transforms.py 13.2 KB
Newer Older
1
from __future__ import division

import collections
import collections.abc
import math
import numbers
import random
import types

import numpy as np
import torch
from PIL import Image, ImageOps

try:
    import accimage
except ImportError:
    accimage = None
soumith's avatar
soumith committed
14

15

Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def _is_pil_image(img):
    # Accept accimage.Image as well when the optional accimage backend loaded.
    if accimage is None:
        return isinstance(img, Image.Image)
    return isinstance(img, (Image.Image, accimage.Image))


def _is_tensor_image(img):
    return torch.is_tensor(img) and img.ndimension() == 3


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
31
def to_tensor(pic):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL.Image or numpy.ndarray (H x W x C, or H x W for a single
    channel) in the range [0, 255] to a torch.FloatTensor of shape
    (C x H x W) in the range [0.0, 1.0].

    Args:
        pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.
    """
    assert _is_pil_image(pic) or _is_numpy_image(pic), 'pic should be PIL Image or ndarray'

    if isinstance(pic, np.ndarray):
        # handle numpy array
        if pic.ndim == 2:
            # Fix: _is_numpy_image accepts 2-D (H x W) arrays, but the
            # transpose below requires three axes; add a channel axis.
            pic = pic[:, :, None]
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility
        return img.float().div(255)

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image: integer modes keep their dtype, everything else
    # is read from the raw byte buffer.
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    # Byte data is rescaled to [0, 1]; integer modes are returned as-is
    # for backward compatibility.
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    else:
        return img


def to_pilimage(pic):
    """Convert a tensor or ndarray to a PIL.Image, preserving the value range.

    Args:
        pic (Tensor or numpy.ndarray): Image of shape C x H x W (tensor) or
            H x W x C / H x W (ndarray) to be converted to PIL.Image.

    Returns:
        PIL.Image: Image converted to PIL.Image.

    Raises:
        AssertionError: If the input type or dtype/channel combination is
            not supported.
    """
    assert _is_numpy_image(pic) or _is_tensor_image(pic), 'pic should be Tensor or ndarray'
    npimg = pic
    mode = None
    if isinstance(pic, torch.FloatTensor):
        # Float tensors are assumed to be in [0, 1]; rescale to byte range.
        pic = pic.mul(255).byte()
    if torch.is_tensor(pic):
        npimg = np.transpose(pic.numpy(), (1, 2, 0))
    assert isinstance(npimg, np.ndarray)
    # Fix: _is_numpy_image accepts 2-D arrays, but the original indexed
    # npimg.shape[2] unconditionally; squeeze only when a channel axis exists.
    if npimg.ndim == 3 and npimg.shape[2] == 1:
        npimg = npimg[:, :, 0]
    if npimg.ndim == 2:
        # Single-channel image: mode is determined by dtype.
        if npimg.dtype == np.uint8:
            mode = 'L'
        elif npimg.dtype == np.int16:
            mode = 'I;16'
        elif npimg.dtype == np.int32:
            mode = 'I'
        elif npimg.dtype == np.float32:
            mode = 'F'
    else:
        # Multi-channel images are only supported as 8-bit RGB.
        if npimg.dtype == np.uint8:
            mode = 'RGB'
    assert mode is not None, '{} is not supported'.format(npimg.dtype)
    return Image.fromarray(npimg, mode=mode)


def normalize(tensor, mean, std):
    """Normalize a tensor image in-place with per-channel mean and std.

    Each channel is transformed as ``channel = (channel - mean) / std``.

    Args:
        tensor (Tensor): Tensor image of size (C, H, W); modified in-place.
        mean (sequence): Sequence of means, one per channel.
        std (sequence): Sequence of standard deviations, one per channel.

    Returns:
        Tensor: The same tensor, normalized.
    """
    assert _is_tensor_image(tensor)
    # TODO: make efficient
    for channel, channel_mean, channel_std in zip(tensor, mean, std):
        channel.sub_(channel_mean).div_(channel_std)
    return tensor


def scale(img, size, interpolation=Image.BILINEAR):
    """Rescale the input PIL.Image to the given size.

    Args:
        img (PIL.Image): Image to be scaled.
        size (sequence or int): Desired output size. If size is a sequence
            like (w, h), the output size is matched to it exactly. If size is
            an int, the smaller edge of the image is matched to it, keeping
            the aspect ratio.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``.

    Returns:
        PIL.Image: Rescaled image.
    """
    assert _is_pil_image(img), 'img should be PIL Image'
    # Fix: collections.Iterable was removed in Python 3.10; use collections.abc.
    assert isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)
    if isinstance(size, int):
        w, h = img.size
        # Short edge already matches: nothing to do.
        if (w <= h and w == size) or (h <= w and h == size):
            return img
        if w < h:
            ow = size
            oh = int(size * h / w)
            return img.resize((ow, oh), interpolation)
        else:
            oh = size
            ow = int(size * w / h)
            return img.resize((ow, oh), interpolation)
    else:
        return img.resize(size, interpolation)


def pad(img, padding, fill=0):
    """Pad the given PIL.Image on all sides with the given "pad" value.

    Args:
        img (PIL.Image): Image to be padded.
        padding (int or tuple): Padding on each border. A single int pads all
            borders; a 2-tuple pads left/right and top/bottom; a 4-tuple pads
            left, top, right and bottom respectively.
        fill: Pixel fill value. Default is 0. If a tuple of length 3, it is
            used to fill R, G, B channels respectively.

    Returns:
        PIL.Image: Padded image.

    Raises:
        ValueError: If ``padding`` is a tuple whose length is not 2 or 4.
    """
    assert _is_pil_image(img), 'img should be PIL Image'
    assert isinstance(padding, (numbers.Number, tuple))
    assert isinstance(fill, (numbers.Number, str, tuple))
    # Fix: collections.Sequence was removed in Python 3.10; use collections.abc.
    if isinstance(padding, collections.abc.Sequence) and len(padding) not in [2, 4]:
        raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                         "{} element tuple".format(len(padding)))

    return ImageOps.expand(img, border=padding, fill=fill)


def crop(img, x, y, w, h):
    """Crop the given PIL.Image.

    Args:
        img (PIL.Image): Image to be cropped.
        x, y (int): Top-left corner of the crop box.
        w, h (int): Width and height of the crop box.

    Returns:
        PIL.Image: Cropped image.
    """
    assert _is_pil_image(img), 'img should be PIL Image'
    left, upper, right, lower = x, y, x + w, y + h
    return img.crop((left, upper, right, lower))


def scaled_crop(img, x, y, w, h, size, interpolation=Image.BILINEAR):
    """Crop the given PIL.Image, then rescale the crop to ``size``.

    Args:
        img (PIL.Image): Image to be cropped and scaled.
        x, y (int): Top-left corner of the crop box.
        w, h (int): Width and height of the crop box.
        size (sequence or int): Target output size (see ``scale``).
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``.

    Returns:
        PIL.Image: Cropped and rescaled image.
    """
    assert _is_pil_image(img), 'img should be PIL Image'
    img = crop(img, x, y, w, h)
    # Fix: the original assigned the scaled image but never returned it,
    # so the function always returned None.
    return scale(img, size, interpolation)


def hflip(img):
    """Horizontally flip the given PIL.Image.

    Args:
        img (PIL.Image): Image to be flipped.

    Returns:
        PIL.Image: Horizontally flipped copy of the image.
    """
    assert _is_pil_image(img), 'img should be PIL Image'
    return img.transpose(Image.FLIP_LEFT_RIGHT)


soumith's avatar
soumith committed
151
class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        # Feed the output of each transform into the next, left to right.
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result


class ToTensor(object):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        # Thin class wrapper so the conversion can be used inside Compose;
        # all work is done by the module-level ``to_tensor`` helper.
        return to_tensor(pic)
189

Adam Paszke's avatar
Adam Paszke committed
190

191
class ToPILImage(object):
    """Convert a tensor to PIL Image.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL.Image while preserving the value range.
    """

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.

        Returns:
            PIL.Image: Image converted to PIL.Image.

        """
        # Thin class wrapper so the conversion can be used inside Compose;
        # all work is done by the module-level ``to_pilimage`` helper.
        return to_pilimage(pic)
208

soumith's avatar
soumith committed
209
210

class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.

    Given mean: (R, G, B) and std: (R, G, B),
    will normalize each channel of the torch.*Tensor, i.e.
    channel = (channel - mean) / std

    Args:
        mean (sequence): Sequence of means for R, G, B channels respectively.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respectively.
    """

    def __init__(self, mean, std):
        # Stored as-is; one entry per channel is expected by ``normalize``.
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized image.

        Note:
            ``normalize`` uses in-place ops (sub_/div_), so the input tensor
            itself is modified.
        """
        return normalize(tensor, self.mean, self.std)
soumith's avatar
soumith committed
236
237
238


class Scale(object):
    """Rescale the input PIL.Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (w, h), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # Fix: collections.Iterable was removed in Python 3.10; use collections.abc.
        assert isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be scaled.

        Returns:
            PIL.Image: Rescaled image.
        """
        return scale(img, self.size, self.interpolation)
soumith's avatar
soumith committed
265
266
267


class CenterCrop(object):
    """Crops the given PIL.Image at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            side = int(size)
            self.size = (side, side)
        else:
            self.size = size

    def get_params(self, img):
        """Compute the top-left corner and extent of the centered crop box."""
        width, height = img.size
        crop_height, crop_width = self.size
        left = int(round((width - crop_width) / 2.))
        top = int(round((height - crop_height) / 2.))
        return left, top, crop_width, crop_height

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        left, top, crop_width, crop_height = self.get_params(img)
        return crop(img, left, top, crop_width, crop_height)
soumith's avatar
soumith committed
299
300


301
class Pad(object):
    """Pad the given PIL.Image on all sides with the given "pad" value.

    Args:
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
    """

    def __init__(self, padding, fill=0):
        assert isinstance(padding, (numbers.Number, tuple))
        assert isinstance(fill, (numbers.Number, str, tuple))
        # Fix: collections.Sequence was removed in Python 3.10; use collections.abc.
        if isinstance(padding, collections.abc.Sequence) and len(padding) not in [2, 4]:
            raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                             "{} element tuple".format(len(padding)))

        self.padding = padding
        self.fill = fill

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be padded.

        Returns:
            PIL.Image: Padded image.
        """
        return pad(img, self.padding, self.fill)
333

334

Soumith Chintala's avatar
Soumith Chintala committed
335
class Lambda(object):
    """Apply a user-defined lambda as a transform.

    Args:
        lambd (function): Lambda/function to be used for transform.
    """

    def __init__(self, lambd):
        # Reject anything that is not a lambda/def-style function up front.
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)

349

soumith's avatar
soumith committed
350
class RandomCrop(object):
    """Crop the given PIL.Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def get_params(self, img):
        """Pick a random top-left corner for a crop of ``self.size``.

        Returns:
            tuple: (x, y, width, height) of the crop box.
        """
        w, h = img.size
        th, tw = self.size
        if w == tw and h == th:
            # Fix: the original returned the image itself here, which broke
            # __call__'s unpacking of the result into four coordinates.
            return 0, 0, tw, th

        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)
        return x1, y1, tw, th

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        # Fix: `self.padding > 0` raised TypeError for the documented
        # sequence padding under Python 3; comparing against 0 keeps int
        # behavior and lets tuples through to ``pad``.
        if self.padding != 0:
            img = pad(img, self.padding)

        x1, y1, tw, th = self.get_params(img)
        return crop(img, x1, y1, tw, th)
soumith's avatar
soumith committed
394
395
396


class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL.Image randomly with a probability of 0.5."""

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.

        Returns:
            PIL.Image: Randomly flipped image.
        """
        # Flip on a fair coin toss, otherwise hand the image back untouched.
        return hflip(img) if random.random() < 0.5 else img


class RandomSizedCrop(object):
    """Crop the given PIL.Image to random size and aspect ratio.

    A crop of random size of (0.08 to 1.0) of the original size and a random
    aspect ratio of 3/4 to 4/3 of the original aspect ratio is made. This crop
    is finally resized to given size.
    This is popularly used to train the Inception networks.

    Args:
        size: size of the smaller edge
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def get_params(self, img):
        """Sample a crop box of random area (8%-100%) and aspect ratio (3/4-4/3).

        Returns:
            tuple: (x, y, w, h) of the crop box.
        """
        for attempt in range(10):
            area = img.size[0] * img.size[1]
            target_area = random.uniform(0.08, 1.0) * area
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            if w <= img.size[0] and h <= img.size[1]:
                x = random.randint(0, img.size[0] - w)
                y = random.randint(0, img.size[1] - h)
                return x, y, w, h

        # Fallback: centered square crop of the shorter edge.
        # Fix: the original read ``img.shape`` here, but PIL images only
        # expose ``size``, so this path raised AttributeError.
        w = min(img.size[0], img.size[1])
        x = (img.size[0] - w) // 2
        y = (img.size[1] - w) // 2
        return x, y, w, w

    def __call__(self, img):
        x, y, w, h = self.get_params(img)
        return scaled_crop(img, x, y, w, h, self.size, self.interpolation)