transforms.py 11.6 KB
Newer Older
1
from __future__ import division
soumith's avatar
soumith committed
2
3
4
import torch
import math
import random
5
from PIL import Image, ImageOps
6
7
8
9
try:
    import accimage
except ImportError:
    accimage = None
10
import numpy as np
11
import numbers
Soumith Chintala's avatar
Soumith Chintala committed
12
import types
13
import collections
soumith's avatar
soumith committed
14

15

soumith's avatar
soumith committed
16
class Compose(object):
    """Chain several transforms into one callable.

    The transforms are applied in order: the output of each one is fed as
    the input of the next.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        # The pipeline is stored as-is and applied lazily in __call__.
        self.transforms = transforms

    def __call__(self, img):
        """Run every transform on ``img`` in sequence and return the result."""
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result


class ToTensor(object):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        if isinstance(pic, np.ndarray):
            # handle numpy array: HWC -> CHW, then scale bytes to [0, 1].
            # NOTE(review): assumes a 3-dimensional array; a 2-D (H x W)
            # grayscale array would make transpose((2, 0, 1)) fail -- confirm
            # callers always pass H x W x C.
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            # backward compatibility
            return img.float().div(255)

        if accimage is not None and isinstance(pic, accimage.Image):
            # accimage path: decode directly into a pre-allocated float32
            # CHW buffer (accimage already returns values in float form).
            nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
            pic.copyto(nppic)
            return torch.from_numpy(nppic)

        # handle PIL Image
        if pic.mode == 'I':
            # 32-bit signed integer pixels.
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            # 16-bit pixels. NOTE(review): PIL's 'I;16' is unsigned; reading
            # as np.int16 wraps values above 32767 -- confirm this is intended.
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            # Remaining modes are one byte per channel: view the raw buffer.
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        # Reshape the flat buffer to H x W x C; note pic.size is (W, H).
        img = img.view(pic.size[1], pic.size[0], nchannel)
        # put it from HWC to CHW format
        # yikes, this transpose takes 80% of the loading time/CPU
        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            # Only byte images are rescaled to [0.0, 1.0]; 'I'/'I;16' images
            # are returned with their integer values unchanged.
            return img.float().div(255)
        else:
            return img
86

Adam Paszke's avatar
Adam Paszke committed
87

88
class ToPILImage(object):
    """Convert a tensor to PIL Image.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL.Image while preserving the value range.
    """

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.

        Returns:
            PIL.Image: Image converted to PIL.Image.

        """
        npimg = pic
        mode = None
        # Float tensors are assumed to be in [0, 1] and rescaled to bytes.
        # NOTE(review): only torch.FloatTensor is rescaled here; a
        # DoubleTensor would fall through unscaled and then fail the mode
        # assert below -- confirm whether other float types should be handled.
        if isinstance(pic, torch.FloatTensor):
            pic = pic.mul(255).byte()
        if torch.is_tensor(pic):
            # CHW tensor -> HWC ndarray, the layout PIL expects.
            npimg = np.transpose(pic.numpy(), (1, 2, 0))
        assert isinstance(npimg, np.ndarray), 'pic should be Tensor or ndarray'
        if npimg.shape[2] == 1:
            # Single channel: drop the channel axis and choose a grayscale
            # mode from the dtype.
            npimg = npimg[:, :, 0]

            if npimg.dtype == np.uint8:
                mode = 'L'
            if npimg.dtype == np.int16:
                mode = 'I;16'
            if npimg.dtype == np.int32:
                mode = 'I'
            elif npimg.dtype == np.float32:
                mode = 'F'
        elif npimg.shape[2] == 4:
            # Four channels: only uint8 RGBA is supported.
            if npimg.dtype == np.uint8:
                mode = 'RGBA'
        else:
            # Any other channel count is treated as RGB; only uint8 gets a
            # mode, anything else trips the assert below.
            if npimg.dtype == np.uint8:
                mode = 'RGB'
        assert mode is not None, '{} is not supported'.format(npimg.dtype)
        return Image.fromarray(npimg, mode=mode)

soumith's avatar
soumith committed
131
132

class Normalize(object):
    """Normalize a tensor image with mean and standard deviation, in place.

    Given mean: (R, G, B) and std: (R, G, B), each channel of the input
    torch.*Tensor is transformed as channel = (channel - mean) / std.

    Args:
        mean (sequence): Sequence of means for R, G, B channels respectively.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respectively.
    """

    def __init__(self, mean, std):
        # Per-channel statistics, consumed pairwise with the tensor's channels.
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: The same tensor, normalized in place.
        """
        # TODO: make efficient
        for channel, channel_mean, channel_std in zip(tensor, self.mean, self.std):
            channel.sub_(channel_mean).div_(channel_std)
        return tensor


class Scale(object):
    """Rescale the input PIL.Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (w, h), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # Bug fix: ``collections.Iterable`` was deprecated in Python 3.3 and
        # removed in 3.10; use the ABC from its canonical home instead.
        from collections.abc import Iterable
        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be scaled.

        Returns:
            PIL.Image: Rescaled image.
        """
        if isinstance(self.size, int):
            w, h = img.size
            # Smaller edge already matches the target: return unchanged.
            if (w <= h and w == self.size) or (h <= w and h == self.size):
                return img
            if w < h:
                # Width is the smaller edge; scale height proportionally.
                ow = self.size
                oh = int(self.size * h / w)
                return img.resize((ow, oh), self.interpolation)
            else:
                # Height is the smaller edge; scale width proportionally.
                oh = self.size
                ow = int(self.size * w / h)
                return img.resize((ow, oh), self.interpolation)
        else:
            # Explicit (w, h) requested: resize directly to it.
            return img.resize(self.size, self.interpolation)
soumith's avatar
soumith committed
203
204
205


class CenterCrop(object):
    """Crop the given PIL.Image so the crop box is centered in the image.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (w, h), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        # An int means a square crop; normalize to a (size, size) pair.
        self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        width, height = img.size
        target_height, target_width = self.size
        # Center the crop box; rounding keeps any off-by-one split consistent.
        left = int(round((width - target_width) / 2.))
        upper = int(round((height - target_height) / 2.))
        right = left + target_width
        lower = upper + target_height
        return img.crop((left, upper, right, lower))
soumith's avatar
soumith committed
233
234


235
class Pad(object):
    """Pad the given PIL.Image on all sides with the given "pad" value.

    Args:
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.

    Raises:
        ValueError: If ``padding`` is a sequence whose length is not 2 or 4.
    """

    def __init__(self, padding, fill=0):
        # Bug fix: ``collections.Sequence`` was deprecated in Python 3.3 and
        # removed in 3.10; use the ABC from ``collections.abc`` instead.
        from collections.abc import Sequence
        assert isinstance(padding, (numbers.Number, tuple))
        assert isinstance(fill, (numbers.Number, str, tuple))
        if isinstance(padding, Sequence) and len(padding) not in [2, 4]:
            raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                             "{} element tuple".format(len(padding)))

        self.padding = padding
        self.fill = fill

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be padded.

        Returns:
            PIL.Image: Padded image.
        """
        return ImageOps.expand(img, border=self.padding, fill=self.fill)

268

Soumith Chintala's avatar
Soumith Chintala committed
269
class Lambda(object):
    """Wrap a user-supplied function so it can be used as a transform.

    Args:
        lambd (function): Lambda/function to be used for transform.
    """

    def __init__(self, lambd):
        # types.LambdaType is an alias of types.FunctionType, so any plain
        # Python function (not only a lambda) passes this check.
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def __call__(self, img):
        """Apply the wrapped function to ``img`` and return its result."""
        return self.lambd(img)

283

soumith's avatar
soumith committed
284
class RandomCrop(object):
    """Crop the given PIL.Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (w, h), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
    """

    def __init__(self, size, padding=0):
        # An int means a square crop; normalize to a (size, size) pair.
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        # Bug fix: the previous check ``self.padding > 0`` raised TypeError on
        # Python 3 when padding was the sequence form the docstring documents;
        # ``!= 0`` accepts both an int and a sequence.
        if self.padding != 0:
            img = ImageOps.expand(img, border=self.padding, fill=0)

        w, h = img.size
        th, tw = self.size
        # Crop exactly matches the (padded) image: nothing to do.
        if w == tw and h == th:
            return img

        # Pick the top-left corner uniformly among all valid positions.
        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)
        return img.crop((x1, y1, x1 + tw, y1 + th))
soumith's avatar
soumith committed
323
324
325


class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL.Image randomly with a probability of 0.5."""

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.

        Returns:
            PIL.Image: The original image, or its left-right mirror.
        """
        # A single uniform draw decides the flip; ``< 0.5`` makes the flip
        # probability exactly one half.
        should_flip = random.random() < 0.5
        if should_flip:
            return img.transpose(Image.FLIP_LEFT_RIGHT)
        return img


class RandomSizedCrop(object):
    """Crop the given PIL.Image to random size and aspect ratio.

    A crop of random size of (0.08 to 1.0) of the original size and a random
    aspect ratio of 3/4 to 4/3 of the original aspect ratio is made. This crop
    is finally resized to given size.
    This is popularly used to train the Inception networks.

    Args:
        size: size of the smaller edge
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped and resized.

        Returns:
            PIL.Image: Randomly cropped image, resized to (size, size).
        """
        # Perf fix: the input area never changes between attempts, so hoist it
        # out of the loop (it used to be recomputed on every iteration).
        area = img.size[0] * img.size[1]
        for attempt in range(10):
            target_area = random.uniform(0.08, 1.0) * area
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            # Swap w/h half the time so both orientations are sampled.
            if random.random() < 0.5:
                w, h = h, w

            if w <= img.size[0] and h <= img.size[1]:
                x1 = random.randint(0, img.size[0] - w)
                y1 = random.randint(0, img.size[1] - h)

                img = img.crop((x1, y1, x1 + w, y1 + h))
                assert(img.size == (w, h))

                return img.resize((self.size, self.size), self.interpolation)

        # Fallback: no valid crop found after 10 attempts; scale the smaller
        # edge to ``size`` and take a center crop instead.
        scale = Scale(self.size, interpolation=self.interpolation)
        crop = CenterCrop(self.size)
        return crop(scale(img))