transforms.py 12.4 KB
Newer Older
1
from __future__ import division

import collections
import math
import numbers
import random
import types

import numpy as np
import torch
from PIL import Image, ImageOps

try:
    import accimage
except ImportError:
    accimage = None

try:
    # The container ABCs live in ``collections.abc`` since Python 3.3; the
    # aliases in ``collections`` were removed in Python 3.10.
    from collections.abc import Iterable, Sequence
except ImportError:  # Python 2
    from collections import Iterable, Sequence
soumith's avatar
soumith committed
14

15

Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
def to_tensor(pic):
    """Convert a ``PIL.Image``, ``accimage.Image`` or ``numpy.ndarray`` to a tensor.

    Args:
        pic (PIL.Image, accimage.Image or numpy.ndarray): Image to convert.
            An ndarray must be laid out as H x W x C; a 2-D H x W array is
            treated as a single-channel image.

    Returns:
        Tensor: Image in C x H x W layout. ndarray input and byte-valued PIL
        images are converted to float and divided by 255; accimage input is
        returned as the float32 data accimage copies out; PIL modes ``I``,
        ``I;16`` and ``F`` keep their int32 / int16 / float32 values.
    """
    if isinstance(pic, np.ndarray):
        # A grayscale H x W array has no channel axis; add one so the
        # HWC -> CHW transpose below is valid.
        if pic.ndim == 2:
            pic = pic[:, :, None]
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility
        return img.float().div(255)

    if accimage is not None and isinstance(pic, accimage.Image):
        # accimage writes into a preallocated C x H x W float32 buffer.
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image
    # ``copy=False`` is deliberately not passed: NumPy 2 raises when a copy
    # cannot be avoided, and a fresh writable array is what from_numpy wants.
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16))
    elif pic.mode == 'F':
        # 32-bit float pixels cannot go through the byte buffer path below.
        img = torch.from_numpy(np.array(pic, np.float32))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    return img


def to_pilimage(pic):
    """Convert a tensor or ndarray to a ``PIL.Image``.

    Args:
        pic (Tensor or numpy.ndarray): A C x H x W tensor or an H x W x C
            ndarray. A ``torch.FloatTensor`` is multiplied by 255 and cast to
            bytes first.

    Returns:
        PIL.Image: Single-channel input becomes mode ``L``, ``I;16``, ``I``
        or ``F`` depending on dtype; other uint8 input becomes ``RGB``.

    Raises:
        AssertionError: If ``pic`` is neither tensor nor ndarray, or its
            dtype has no matching PIL mode.
    """
    source = pic
    if isinstance(source, torch.FloatTensor):
        source = source.mul(255).byte()

    if torch.is_tensor(source):
        # CHW -> HWC, the layout PIL expects.
        npimg = np.transpose(source.numpy(), (1, 2, 0))
    else:
        npimg = source
    assert isinstance(npimg, np.ndarray), 'pic should be Tensor or ndarray'

    mode = None
    if npimg.shape[2] == 1:
        # Drop the channel axis; the PIL mode is chosen from the dtype alone.
        npimg = npimg[:, :, 0]
        single_channel_modes = (
            (np.uint8, 'L'),
            (np.int16, 'I;16'),
            (np.int32, 'I'),
            (np.float32, 'F'),
        )
        for candidate, pil_mode in single_channel_modes:
            if npimg.dtype == candidate:
                mode = pil_mode
                break
    elif npimg.dtype == np.uint8:
        mode = 'RGB'
    assert mode is not None, '{} is not supported'.format(npimg.dtype)
    return Image.fromarray(npimg, mode=mode)


def normalize(tensor, mean, std):
    """Normalize ``tensor`` in place, one leading-axis slice at a time.

    Each slice ``tensor[i]`` becomes ``(tensor[i] - mean[i]) / std[i]``.
    Iteration stops at the shortest of ``tensor``, ``mean`` and ``std``.

    Args:
        tensor (Tensor): Tensor to normalize; it is modified in place.
        mean (sequence): Value subtracted from each slice.
        std (sequence): Value each slice is divided by.

    Returns:
        Tensor: The same ``tensor`` object, normalized.
    """
    # TODO: make efficient
    for channel, channel_mean, channel_std in zip(tensor, mean, std):
        channel.sub_(channel_mean)
        channel.div_(channel_std)
    return tensor


def scale(img, size, interpolation=Image.BILINEAR):
    """Resize a PIL image.

    Args:
        img (PIL.Image): Image to resize.
        size (int or 2-element iterable): If an int, the smaller edge of the
            image is resized to ``size`` and the other edge is scaled by the
            same factor (truncated to int). Otherwise ``size`` is passed to
            ``img.resize`` unchanged.
        interpolation (int, optional): Resampling filter passed to
            ``img.resize``. Default is ``PIL.Image.BILINEAR``.

    Returns:
        PIL.Image: The resized image, or ``img`` itself when its smaller edge
        already equals ``size``.
    """
    # ``Iterable`` comes from ``collections.abc``; the ``collections.Iterable``
    # alias no longer exists on Python 3.10+.
    assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
    if not isinstance(size, int):
        return img.resize(size, interpolation)

    w, h = img.size
    if (w <= h and w == size) or (h <= w and h == size):
        return img
    if w < h:
        ow = size
        oh = int(size * h / w)
    else:
        oh = size
        ow = int(size * w / h)
    return img.resize((ow, oh), interpolation)


def pad(img, padding, fill=0):
    """Pad a PIL image on its borders.

    Args:
        img (PIL.Image): Image to pad.
        padding (number or tuple): Border width passed to
            ``ImageOps.expand``; a tuple must have 2 or 4 elements.
        fill (number, str or tuple): Fill value passed to
            ``ImageOps.expand``. Default is 0.

    Returns:
        PIL.Image: The padded image.

    Raises:
        ValueError: If ``padding`` is a tuple whose length is not 2 or 4.
    """
    assert isinstance(padding, (numbers.Number, tuple))
    assert isinstance(fill, (numbers.Number, str, tuple))
    # ``Sequence`` comes from ``collections.abc``; the ``collections.Sequence``
    # alias no longer exists on Python 3.10+.
    if isinstance(padding, Sequence) and len(padding) not in [2, 4]:
        raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                         "{} element tuple".format(len(padding)))

    return ImageOps.expand(img, border=padding, fill=fill)


def crop(img, x, y, w, h):
    """Crop the ``w`` x ``h`` box whose upper-left corner is ``(x, y)``.

    Args:
        img (PIL.Image): Image to crop.
        x, y: Left and upper coordinates of the box.
        w, h: Width and height of the box.

    Returns:
        The result of ``img.crop`` on ``(x, y, x + w, y + h)``.
    """
    right = x + w
    lower = y + h
    box = (x, y, right, lower)
    return img.crop(box)


def scaled_crop(img, x, y, w, h, size, interpolation=Image.BILINEAR):
    """Crop a box out of ``img`` and resize the crop.

    Args:
        img (PIL.Image): Image to crop.
        x, y, w, h: Box passed to ``crop``.
        size: Target size passed to ``scale``.
        interpolation (int, optional): Resampling filter passed to ``scale``.
            Default is ``PIL.Image.BILINEAR``.

    Returns:
        PIL.Image: The cropped and rescaled image.
    """
    img = crop(img, x, y, w, h)
    # The result must be returned; without it every caller received None.
    return scale(img, size, interpolation)


def hflip(img):
    """Return ``img`` flipped left-to-right.

    Args:
        img (PIL.Image): Image to flip.

    Returns:
        PIL.Image: Result of ``img.transpose(Image.FLIP_LEFT_RIGHT)``.
    """
    flipped = img.transpose(Image.FLIP_LEFT_RIGHT)
    return flipped


soumith's avatar
soumith committed
126
class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        ...     transforms.CenterCrop(10),
        ...     transforms.ToTensor(),
        ... ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        """Apply every transform in order, feeding each one the previous result.

        Args:
            img: Input handed to the first transform.

        Returns:
            The output of the last transform, or ``img`` unchanged when the
            list is empty.
        """
        for t in self.transforms:
            img = t(img)
        return img


class ToTensor(object):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return to_tensor(pic)
164

Adam Paszke's avatar
Adam Paszke committed
165

166
class ToPILImage(object):
    """Convert a tensor to PIL Image.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL.Image while preserving the value range.
    """

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.

        Returns:
            PIL.Image: Image converted to PIL.Image.
        """
        return to_pilimage(pic)
183

soumith's avatar
soumith committed
184
185

class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.

    Given mean: (R, G, B) and std: (R, G, B),
    will normalize each channel of the torch.*Tensor, i.e.
    channel = (channel - mean) / std

    Args:
        mean (sequence): Sequence of means for R, G, B channels respectively.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respectively.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized image.
        """
        return normalize(tensor, self.mean, self.std)
soumith's avatar
soumith committed
211
212
213


class Scale(object):
    """Rescale the input PIL.Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (w, h), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            width ``size`` and height ``size * height / width``.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # ``Iterable`` comes from ``collections.abc``; the
        # ``collections.Iterable`` alias no longer exists on Python 3.10+.
        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be scaled.

        Returns:
            PIL.Image: Rescaled image.
        """
        return scale(img, self.size, self.interpolation)
soumith's avatar
soumith committed
240
241
242


class CenterCrop(object):
    """Crops the given PIL.Image at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def get_params(self, img):
        """Compute the centered crop box for ``img``.

        Args:
            img (PIL.Image): Image whose ``size`` is ``(width, height)``.

        Returns:
            tuple: ``(x1, y1, tw, th)`` - upper-left corner, width and height
            of the crop.
        """
        w, h = img.size
        th, tw = self.size  # stored as (height, width)
        x1 = int(round((w - tw) / 2.))
        y1 = int(round((h - th) / 2.))
        return x1, y1, tw, th

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        x1, y1, tw, th = self.get_params(img)
        return crop(img, x1, y1, tw, th)
soumith's avatar
soumith committed
274
275


276
class Pad(object):
    """Pad the given PIL.Image on all sides with the given "pad" value.

    Args:
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.

    Raises:
        ValueError: If ``padding`` is a tuple whose length is not 2 or 4.
    """

    def __init__(self, padding, fill=0):
        assert isinstance(padding, (numbers.Number, tuple))
        assert isinstance(fill, (numbers.Number, str, tuple))
        # ``Sequence`` comes from ``collections.abc``; the
        # ``collections.Sequence`` alias no longer exists on Python 3.10+.
        if isinstance(padding, Sequence) and len(padding) not in [2, 4]:
            raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                             "{} element tuple".format(len(padding)))

        self.padding = padding
        self.fill = fill

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be padded.

        Returns:
            PIL.Image: Padded image.
        """
        return pad(img, self.padding, self.fill)
308

309

Soumith Chintala's avatar
Soumith Chintala committed
310
class Lambda(object):
    """Apply a user-defined lambda as a transform.

    Args:
        lambd (function): Lambda/function to be used for transform.
    """

    def __init__(self, lambd):
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def __call__(self, img):
        """Return ``self.lambd(img)``."""
        return self.lambd(img)

324

soumith's avatar
soumith committed
325
class RandomCrop(object):
    """Crop the given PIL.Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def get_params(self, img):
        """Pick a random crop box inside ``img``.

        Args:
            img (PIL.Image): Image whose ``size`` is ``(width, height)``.

        Returns:
            tuple: ``(x1, y1, tw, th)`` - upper-left corner, width and height
            of the crop. When the image already has the target size the only
            possible box, ``(0, 0, tw, th)``, is returned.
        """
        w, h = img.size
        th, tw = self.size  # stored as (height, width)
        # Always return a 4-tuple: the caller unpacks it, so returning the
        # image itself for the equal-size case would break ``__call__``.
        if w == tw and h == th:
            return 0, 0, tw, th

        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)
        return x1, y1, tw, th

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        # ``padding`` may be a number or a sequence; a sequence cannot be
        # compared with ``> 0`` directly on Python 3.
        if isinstance(self.padding, numbers.Number):
            needs_padding = self.padding > 0
        else:
            needs_padding = any(p > 0 for p in self.padding)
        if needs_padding:
            img = pad(img, self.padding)

        x1, y1, tw, th = self.get_params(img)

        return crop(img, x1, y1, tw, th)
soumith's avatar
soumith committed
369
370
371


class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL.Image randomly with a given probability.

    Args:
        p (float, optional): Probability of the image being flipped.
            Default is 0.5.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.

        Returns:
            PIL.Image: Randomly flipped image.
        """
        if random.random() < self.p:
            return hflip(img)
        return img


class RandomSizedCrop(object):
    """Crop the given PIL.Image to random size and aspect ratio.

    A crop of random size of (0.08 to 1.0) of the original size and a random
    aspect ratio of 3/4 to 4/3 of the original aspect ratio is made. This crop
    is finally resized to given size.
    This is popularly used to train the Inception networks.

    Args:
        size: size of the smaller edge
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def get_params(self, img):
        """Pick a random crop box inside ``img``.

        Up to 10 random area / aspect-ratio combinations are tried; if none
        fits inside the image, a centered square crop is used instead.

        Args:
            img (PIL.Image): Image whose ``size`` is ``(width, height)``.

        Returns:
            tuple: ``(x, y, w, h)`` - upper-left corner, width and height of
            the crop.
        """
        img_w, img_h = img.size[0], img.size[1]
        area = img_w * img_h  # loop-invariant
        for attempt in range(10):
            target_area = random.uniform(0.08, 1.0) * area
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            if w <= img_w and h <= img_h:
                x = random.randint(0, img_w - w)
                y = random.randint(0, img_h - h)
                return x, y, w, h

        # Fallback: centered square crop with side equal to the shorter edge.
        # PIL images expose ``size``, not ``shape``.
        w = min(img_w, img_h)
        x = (img_w - w) // 2
        y = (img_h - w) // 2
        return x, y, w, w

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped and resized.

        Returns:
            PIL.Image: Result of ``scaled_crop`` on the randomly chosen box.
        """
        x, y, w, h = self.get_params(img)
        return scaled_crop(img, x, y, w, h, self.size, self.interpolation)