from __future__ import division
import torch
import math
import random
from PIL import Image, ImageOps
try:
    import accimage
except ImportError:
    accimage = None
import numpy as np
import numbers
import types
import collections


class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img


class ToTensor(object):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
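
    Example (an illustrative sketch; ``img`` stands for any RGB PIL.Image):
        >>> tensor = ToTensor()(img)
        >>> tensor.size()    # torch.Size([3, H, W]) with values in [0.0, 1.0]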
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        if isinstance(pic, np.ndarray):
            # handle numpy array
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            # backward compatibility
            return img.float().div(255)

        if accimage is not None and isinstance(pic, accimage.Image):
            nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
            pic.copyto(nppic)
            return torch.from_numpy(nppic)

        # handle PIL Image
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)
        # put it from HWC to CHW format
        # yikes, this transpose takes 80% of the loading time/CPU
        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float().div(255)
        else:
            return img


class ToPILImage(object):
    """Convert a tensor to PIL Image.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL.Image while preserving the value range.
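
    Example (an illustrative sketch; ``tensor`` stands for a 3 x H x W
    FloatTensor with values in [0.0, 1.0]):
        >>> pil_img = ToPILImage()(tensor)
        >>> pil_img.mode    # 'RGB'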
    """

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.

        Returns:
            PIL.Image: Image converted to PIL.Image.

        """
        npimg = pic
        mode = None
        if isinstance(pic, torch.FloatTensor):
            pic = pic.mul(255).byte()
        if torch.is_tensor(pic):
            npimg = np.transpose(pic.numpy(), (1, 2, 0))
        assert isinstance(npimg, np.ndarray), 'pic should be Tensor or ndarray'
        if npimg.shape[2] == 1:
            npimg = npimg[:, :, 0]

            if npimg.dtype == np.uint8:
                mode = 'L'
            if npimg.dtype == np.int16:
                mode = 'I;16'
            if npimg.dtype == np.int32:
                mode = 'I'
            elif npimg.dtype == np.float32:
                mode = 'F'
        else:
            if npimg.dtype == np.uint8:
                mode = 'RGB'
        assert mode is not None, '{} is not supported'.format(npimg.dtype)
        return Image.fromarray(npimg, mode=mode)


class Normalize(object):
    """Normalize an tensor image with mean and standard deviation.

    Given mean: (R, G, B) and std: (R, G, B),
    will normalize each channel of the torch.*Tensor, i.e.
    channel = (channel - mean) / std

    Args:
        mean (sequence): Sequence of means for R, G, B channels respectively.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respectively.
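
    Example (an illustrative sketch with made-up statistics; ``tensor`` stands
    for a 3 x H x W image tensor with values in [0.0, 1.0]):
        >>> normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        >>> out = normalize(tensor)    # each channel now lies in [-1.0, 1.0]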
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized image.
        """
        # TODO: make efficient
        for t, m, s in zip(tensor, self.mean, self.std):
            t.sub_(m).div_(s)
        return tensor


class Scale(object):
    """Rescale the input PIL.Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (w, h), the output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number,
            i.e., if height > width, then the image will be rescaled to
            (size, size * height / width)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
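
    Example (an illustrative sketch; ``img`` stands for a 640 x 480 PIL.Image):
        >>> scaled = Scale(256)(img)
        >>> scaled.size    # (341, 256): the smaller edge is matched to 256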
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be scaled.

        Returns:
            PIL.Image: Rescaled image.
        """
        if isinstance(self.size, int):
            w, h = img.size
            if (w <= h and w == self.size) or (h <= w and h == self.size):
                return img
            if w < h:
                ow = self.size
                oh = int(self.size * h / w)
                return img.resize((ow, oh), self.interpolation)
            else:
                oh = self.size
                ow = int(self.size * w / h)
                return img.resize((ow, oh), self.interpolation)
        else:
            return img.resize(self.size, self.interpolation)


class CenterCrop(object):
    """Crops the given PIL.Image at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of a sequence like (h, w), a square crop of
            (size, size) is made.
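
    Example (an illustrative sketch; ``img`` stands for a 32 x 32 PIL.Image):
        >>> cropped = CenterCrop(10)(img)
        >>> cropped.size    # (10, 10), taken from the center of ``img``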
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        w, h = img.size
        th, tw = self.size
        x1 = int(round((w - tw) / 2.))
        y1 = int(round((h - th) / 2.))
        return img.crop((x1, y1, x1 + tw, y1 + th))


class Pad(object):
    """Pad the given PIL.Image on all sides with the given "pad" value.

    Args:
        padding (int or tuple): Padding on each border. If a single int is provided this
            is used to pad all borders. If tuple of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively.
        fill: Pixel fill value. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively.
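
    Example (an illustrative sketch; ``img`` stands for a 32 x 32 PIL.Image):
        >>> padded = Pad((1, 2), fill=0)(img)
        >>> padded.size    # (34, 36): 1 px left/right, 2 px top/bottom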
    """

    def __init__(self, padding, fill=0):
        assert isinstance(padding, (numbers.Number, tuple))
        assert isinstance(fill, (numbers.Number, str, tuple))
        if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
            raise ValueError("Padding must be an int or a 2 or 4 element tuple, not a " +
                             "{} element tuple".format(len(padding)))

        self.padding = padding
        self.fill = fill

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be padded.

        Returns:
            PIL.Image: Padded image.
        """
        return ImageOps.expand(img, border=self.padding, fill=self.fill)


class Lambda(object):
    """Apply a user-defined lambda as a transform.

    Args:
        lambd (function): Lambda/function to be used for transform.
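
    Example (an illustrative sketch; ``img`` stands for any PIL.Image):
        >>> rotate = Lambda(lambda img: img.rotate(90))
        >>> rotated = rotate(img)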
    """

    def __init__(self, lambd):
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)


class RandomCrop(object):
    """Crop the given PIL.Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of a sequence like (h, w), a square crop of
            (size, size) is made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
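
    Example (an illustrative sketch; ``img`` stands for a 32 x 32 PIL.Image):
        >>> cropped = RandomCrop(28, padding=2)(img)
        >>> cropped.size    # (28, 28), cut at a random offset from the padded image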
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        if self.padding > 0:
            img = ImageOps.expand(img, border=self.padding, fill=0)

        w, h = img.size
        th, tw = self.size
        if w == tw and h == th:
            return img

        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)
        return img.crop((x1, y1, x1 + tw, y1 + th))


class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL.Image randomly with a probability of 0.5."""

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.

        Returns:
            PIL.Image: Randomly flipped image.
        """
        if random.random() < 0.5:
            return img.transpose(Image.FLIP_LEFT_RIGHT)
        return img


class RandomSizedCrop(object):
    """Crop the given PIL.Image to random size and aspect ratio.

    A crop of random size (0.08 to 1.0 of the original area) and a random
    aspect ratio (3/4 to 4/3 of the original aspect ratio) is made. This crop
    is finally resized to the given size.
    This is popularly used to train the Inception networks.

    Args:
        size: expected output size of each edge (the output is a square)
        interpolation: Default: PIL.Image.BILINEAR
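
    Example (an illustrative sketch; ``img`` stands for any PIL.Image):
        >>> cropped = RandomSizedCrop(224)(img)
        >>> cropped.size    # always (224, 224), whatever the input size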
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        for attempt in range(10):
            area = img.size[0] * img.size[1]
            target_area = random.uniform(0.08, 1.0) * area
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            if w <= img.size[0] and h <= img.size[1]:
                x1 = random.randint(0, img.size[0] - w)
                y1 = random.randint(0, img.size[1] - h)

                img = img.crop((x1, y1, x1 + w, y1 + h))
                assert img.size == (w, h)

                return img.resize((self.size, self.size), self.interpolation)

        # Fallback
        scale = Scale(self.size, interpolation=self.interpolation)
        crop = CenterCrop(self.size)
        return crop(scale(img))