transforms.py 12.2 KB
Newer Older
1
from __future__ import division
soumith's avatar
soumith committed
2
3
4
import torch
import math
import random
5
from PIL import Image, ImageOps
6
7
8
9
try:
    import accimage
except ImportError:
    accimage = None
10
import numpy as np
11
import numbers
Soumith Chintala's avatar
Soumith Chintala committed
12
import types
13
import collections
soumith's avatar
soumith committed
14

15

Sasank Chilamkurthy's avatar
Sasank Chilamkurthy committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
def to_tensor(pic):
    """Convert a ``PIL.Image``, ``accimage.Image`` or ``numpy.ndarray`` to a tensor.

    Args:
        pic (PIL.Image, accimage.Image or numpy.ndarray): Image to convert.
            An ndarray is expected in (H x W x C) layout; a two-dimensional
            (H x W) ndarray is treated as a single-channel image.

    Returns:
        Tensor: Image in (C x H x W) layout. ndarray input and PIL images
        read as bytes are converted to float and divided by 255. PIL modes
        ``I`` and ``I;16`` are returned as integer tensors, and accimage
        input is returned as the float32 data copied out of the image.
    """
    if isinstance(pic, np.ndarray):
        # handle numpy array
        if pic.ndim == 2:
            # Grayscale without an explicit channel axis: add one so the
            # HWC -> CHW transpose below is well defined.
            pic = pic[:, :, None]
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility
        return img.float().div(255)

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        # For the remaining modes the channel count equals the number of
        # characters in the mode string (e.g. 'RGB' -> 3, 'L' -> 1).
        nchannel = len(pic.mode)
    # PIL reports size as (width, height); the buffer is row-major HWC.
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    else:
        return img


def to_pilimage(pic):
    """Convert a tensor or ndarray to a PIL Image.

    Args:
        pic (Tensor or numpy.ndarray): A (C x H x W) tensor or an
            (H x W x C) ndarray. A ``torch.FloatTensor`` is multiplied by
            255 and converted to bytes first.

    Returns:
        PIL.Image: Image built with ``Image.fromarray``. Single-channel
        input maps to mode ``L``, ``I;16``, ``I`` or ``F`` depending on
        dtype; any other channel count must be uint8 and maps to ``RGB``.

    Raises:
        AssertionError: If ``pic`` is neither a tensor nor an ndarray, or
            if its dtype has no matching PIL mode.
    """
    if isinstance(pic, torch.FloatTensor):
        pic = pic.mul(255).byte()

    # Tensors are CHW; PIL wants HWC. ndarrays are taken as they are.
    array = np.transpose(pic.numpy(), (1, 2, 0)) if torch.is_tensor(pic) else pic
    assert isinstance(array, np.ndarray), 'pic should be Tensor or ndarray'

    if array.shape[2] == 1:
        # Drop the channel axis; the PIL mode then depends on the dtype.
        array = array[:, :, 0]
        candidates = (
            (np.uint8, 'L'),
            (np.int16, 'I;16'),
            (np.int32, 'I'),
            (np.float32, 'F'),
        )
    else:
        candidates = ((np.uint8, 'RGB'),)

    mode = next((name for dtype, name in candidates if array.dtype == dtype), None)
    assert mode is not None, '{} is not supported'.format(array.dtype)
    return Image.fromarray(array, mode=mode)


def normalize(tensor, mean, std):
    """Normalize a tensor in place, one leading-dimension slice at a time.

    Each slice ``tensor[i]`` becomes ``(tensor[i] - mean[i]) / std[i]``.
    Iteration stops at the shortest of ``tensor``, ``mean`` and ``std``.

    Args:
        tensor (Tensor): Tensor to normalize; it is modified in place.
        mean (sequence): Value subtracted from each slice.
        std (sequence): Value each slice is divided by.

    Returns:
        Tensor: The same ``tensor`` object, after normalization.
    """
    # TODO: make efficient
    for channel, channel_mean, channel_std in zip(tensor, mean, std):
        channel.sub_(channel_mean)
        channel.div_(channel_std)
    return tensor


def scale(img, size, interpolation=Image.BILINEAR):
    """Rescale ``img`` to the given size.

    Args:
        img (PIL.Image): Image to be scaled.
        size (sequence or int): Desired output size. A two-element sequence
            is passed to ``img.resize`` unchanged. An int is matched to the
            smaller edge of the image and the other edge is scaled by the
            same factor.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``.

    Returns:
        PIL.Image: Rescaled image. When ``size`` is an int and the smaller
        edge already equals it, ``img`` itself is returned.
    """
    # collections.Iterable was removed in Python 3.10; the ABC lives in
    # collections.abc. Fall back to the old location for Python 2.
    try:
        from collections.abc import Iterable
    except ImportError:
        from collections import Iterable

    assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
    if isinstance(size, int):
        w, h = img.size
        if (w <= h and w == size) or (h <= w and h == size):
            return img
        if w < h:
            ow = size
            oh = int(size * h / w)
            return img.resize((ow, oh), interpolation)
        else:
            oh = size
            ow = int(size * w / h)
            return img.resize((ow, oh), interpolation)
    else:
        return img.resize(size, interpolation)


def pad(img, padding, fill=0):
    """Pad ``img`` on its borders with the given fill value.

    Args:
        img (PIL.Image): Image to be padded.
        padding (int or tuple): Border width, forwarded to
            ``ImageOps.expand`` as ``border``. May be a single number, or a
            tuple of length 2 or 4 as accepted by ``Pad``.
        fill (number, str or tuple): Pixel fill value. Default is 0.

    Returns:
        PIL.Image: Padded image.

    Raises:
        AssertionError: If ``padding`` or ``fill`` has an unsupported type.
        ValueError: If ``padding`` is a tuple whose length is not 2 or 4.
    """
    # Accept tuples as well as numbers so that the values validated by
    # Pad.__init__ can actually be forwarded here.
    assert isinstance(padding, (numbers.Number, tuple))
    assert isinstance(fill, (numbers.Number, str, tuple))
    if isinstance(padding, tuple) and len(padding) not in (2, 4):
        raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                         "{} element tuple".format(len(padding)))
    return ImageOps.expand(img, border=padding, fill=fill)


def crop(img, x, y, w, h):
    """Crop a ``w`` x ``h`` region of ``img`` whose top-left corner is (x, y).

    Args:
        img (PIL.Image): Image to be cropped.
        x: Left coordinate of the crop box.
        y: Upper coordinate of the crop box.
        w: Width of the crop box.
        h: Height of the crop box.

    Returns:
        PIL.Image: Result of ``img.crop`` on the box (x, y, x + w, y + h).
    """
    right = x + w
    lower = y + h
    return img.crop((x, y, right, lower))


def scaled_crop(img, x, y, w, h, size, interpolation=Image.BILINEAR):
    """Crop ``img`` to the given box, then rescale the crop to ``size``.

    Args:
        img (PIL.Image): Image to be cropped and scaled.
        x: Left coordinate of the crop box.
        y: Upper coordinate of the crop box.
        w: Width of the crop box.
        h: Height of the crop box.
        size (sequence or int): Target size, interpreted by ``scale``.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``.

    Returns:
        PIL.Image: The cropped and rescaled image.
    """
    img = crop(img, x, y, w, h)
    # The result must be returned; without it callers always receive None.
    return scale(img, size, interpolation)


def hflip(img):
    """Return ``img`` transposed with ``Image.FLIP_LEFT_RIGHT``.

    Args:
        img (PIL.Image): Image to be flipped.

    Returns:
        PIL.Image: Horizontally flipped image.
    """
    flipped = img.transpose(Image.FLIP_LEFT_RIGHT)
    return flipped


soumith's avatar
soumith committed
122
class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        """Apply every transform in order, feeding each result to the next.

        Args:
            img: Input passed to the first transform.

        Returns:
            The output of the last transform, or ``img`` unchanged when the
            list of transforms is empty.
        """
        for t in self.transforms:
            img = t(img)
        return img


class ToTensor(object):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return to_tensor(pic)
160

Adam Paszke's avatar
Adam Paszke committed
161

162
class ToPILImage(object):
    """Convert a tensor to PIL Image.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL.Image while preserving the value range.
    """

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.

        Returns:
            PIL.Image: Image converted to PIL.Image.

        """
        return to_pilimage(pic)
179

soumith's avatar
soumith committed
180
181

class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.

    Given mean: (R, G, B) and std: (R, G, B),
    will normalize each channel of the torch.*Tensor, i.e.
    channel = (channel - mean) / std

    Args:
        mean (sequence): Sequence of means for R, G, B channels respectively.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respectively.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized image.
        """
        return normalize(tensor, self.mean, self.std)
soumith's avatar
soumith committed
207
208
209


class Scale(object):
    """Rescale the input PIL.Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (w, h), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # collections.Iterable was removed in Python 3.10; the ABC lives in
        # collections.abc. Fall back to the old location for Python 2.
        try:
            from collections.abc import Iterable
        except ImportError:
            from collections import Iterable

        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be scaled.

        Returns:
            PIL.Image: Rescaled image.
        """
        return scale(img, self.size, self.interpolation)
soumith's avatar
soumith committed
236
237
238


class CenterCrop(object):
    """Crops the given PIL.Image at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def get_params(self, img):
        """Compute the centered crop box for ``img``.

        Args:
            img (PIL.Image): Image to be cropped; only ``img.size`` is read.

        Returns:
            tuple: (x1, y1, tw, th) -- the left and upper coordinates of the
            crop followed by its width and height.
        """
        w, h = img.size
        # self.size is stored as (height, width).
        th, tw = self.size
        x1 = int(round((w - tw) / 2.))
        y1 = int(round((h - th) / 2.))
        return x1, y1, tw, th

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        x1, y1, tw, th = self.get_params(img)
        return crop(img, x1, y1, tw, th)
soumith's avatar
soumith committed
270
271


272
class Pad(object):
    """Pad the given PIL.Image on all sides with the given "pad" value.

    Args:
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.

    Raises:
        AssertionError: If ``padding`` or ``fill`` has an unsupported type.
        ValueError: If ``padding`` is a sequence whose length is not 2 or 4.
    """

    def __init__(self, padding, fill=0):
        assert isinstance(padding, (numbers.Number, tuple))
        assert isinstance(fill, (numbers.Number, str, tuple))
        # collections.Sequence was removed in Python 3.10. Anything that is
        # not a plain number is a sequence here, so check its length directly.
        if not isinstance(padding, numbers.Number) and len(padding) not in [2, 4]:
            raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                             "{} element tuple".format(len(padding)))

        self.padding = padding
        self.fill = fill

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be padded.

        Returns:
            PIL.Image: Padded image.
        """
        return pad(img, self.padding, self.fill)
304

305

Soumith Chintala's avatar
Soumith Chintala committed
306
class Lambda(object):
    """Apply a user-defined lambda as a transform.

    Args:
        lambd (callable): Lambda/function to be used for transform. Any
            callable is accepted, including builtins, ``functools.partial``
            objects and instances that define ``__call__``.

    Raises:
        AssertionError: If ``lambd`` is not callable.
    """

    def __init__(self, lambd):
        # types.LambdaType only matches plain Python functions; checking
        # callable() also admits builtins, partials and callable objects.
        assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
        self.lambd = lambd

    def __call__(self, img):
        """Return the result of calling the wrapped callable on ``img``."""
        return self.lambd(img)

320

soumith's avatar
soumith committed
321
class RandomCrop(object):
    """Crop the given PIL.Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def get_params(self, img):
        """Pick a random crop box of the configured size inside ``img``.

        Args:
            img (PIL.Image): Image to be cropped; only ``img.size`` is read.

        Returns:
            tuple: (x1, y1, tw, th) -- the left and upper coordinates of the
            crop followed by its width and height.
        """
        w, h = img.size
        # self.size is stored as (height, width).
        th, tw = self.size
        if w == tw and h == th:
            # The image already has the target size. Return the full-image
            # box so callers can unpack four values as in the general case.
            return 0, 0, tw, th

        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)
        return x1, y1, tw, th

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        padding = self.padding
        # A sequence cannot be compared with ``> 0`` directly; pad when any
        # of its entries is positive.
        if isinstance(padding, numbers.Number):
            needs_padding = padding > 0
        else:
            needs_padding = any(p > 0 for p in padding)
        if needs_padding:
            img = pad(img, padding)

        x1, y1, tw, th = self.get_params(img)

        return crop(img, x1, y1, tw, th)
soumith's avatar
soumith committed
365
366
367


class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL.Image randomly with a probability of 0.5."""

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.

        Returns:
            PIL.Image: Randomly flipped image.
        """
        if random.random() < 0.5:
            return hflip(img)
        return img


class RandomSizedCrop(object):
    """Crop the given PIL.Image to random size and aspect ratio.

    A crop of random size of (0.08 to 1.0) of the original size and a random
    aspect ratio of 3/4 to 4/3 of the original aspect ratio is made. This crop
    is finally resized to given size.
    This is popularly used to train the Inception networks.

    Args:
        size: size of the smaller edge
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def get_params(self, img):
        """Choose the crop box for ``img``.

        Up to ten random boxes are tried; the first that fits inside the
        image is used. If none fits, a centered square whose side is the
        shorter image edge is returned.

        Args:
            img (PIL.Image): Image to be cropped; only ``img.size`` is read.

        Returns:
            tuple: (x, y, w, h) -- the left and upper coordinates of the crop
            followed by its width and height.
        """
        width, height = img.size
        area = width * height
        for _ in range(10):
            target_area = random.uniform(0.08, 1.0) * area
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            if w <= width and h <= height:
                x = random.randint(0, width - w)
                y = random.randint(0, height - h)
                return x, y, w, h

        # Fallback: this method reads img.size everywhere else, so the
        # fallback uses it too instead of img.shape.
        w = min(width, height)
        x = (width - w) // 2
        y = (height - w) // 2
        return x, y, w, w

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped and resized.

        Returns:
            The result of ``scaled_crop`` for the chosen box.
        """
        x, y, w, h = self.get_params(img)
        return scaled_crop(img, x, y, w, h, self.size, self.interpolation)