from __future__ import division
import torch
import math
import random
from PIL import Image, ImageOps
try:
    import accimage
except ImportError:
    accimage = None
import numpy as np
import numbers
import types
import collections


class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img
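
# Illustrative pipeline sketch (the crop size and normalization values here are
# hypothetical, not recommendations):
#   >>> preprocess = Compose([
#   >>>     Scale(256),
#   >>>     CenterCrop(224),
#   >>>     ToTensor(),
#   >>>     Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
#   >>> ])
#   >>> tensor = preprocess(img)  # `img` is assumed to be a PIL.Image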


class ToTensor(object):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        if isinstance(pic, np.ndarray):
            # handle numpy array
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            # backward compatibility
            return img.float().div(255)

        if accimage is not None and isinstance(pic, accimage.Image):
            nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
            pic.copyto(nppic)
            return torch.from_numpy(nppic)

        # handle PIL Image
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)
        # put it from HWC to CHW format
        # yikes, this transpose takes 80% of the loading time/CPU
        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float().div(255)
        else:
            return img
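
# Usage sketch (illustrative; assumes `img` is an RGB PIL.Image):
#   >>> tensor = ToTensor()(img)
#   >>> tensor.size()  # torch.Size([3, H, W]), values scaled to [0.0, 1.0]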


class ToPILImage(object):
    """Convert a tensor to PIL Image.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL.Image while preserving the value range.
    """

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.

        Returns:
            PIL.Image: Image converted to PIL.Image.
        """
        npimg = pic
        mode = None
        if isinstance(pic, torch.FloatTensor):
            pic = pic.mul(255).byte()
        if torch.is_tensor(pic):
            npimg = np.transpose(pic.numpy(), (1, 2, 0))
        assert isinstance(npimg, np.ndarray), 'pic should be Tensor or ndarray'
        if npimg.shape[2] == 1:
            npimg = npimg[:, :, 0]

            if npimg.dtype == np.uint8:
                mode = 'L'
            if npimg.dtype == np.int16:
                mode = 'I;16'
            if npimg.dtype == np.int32:
                mode = 'I'
            elif npimg.dtype == np.float32:
                mode = 'F'
        else:
            if npimg.dtype == np.uint8:
                mode = 'RGB'
        assert mode is not None, '{} is not supported'.format(npimg.dtype)
        return Image.fromarray(npimg, mode=mode)
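
# Usage sketch (illustrative; round-trips a random tensor to a PIL.Image):
#   >>> pil_img = ToPILImage()(torch.rand(3, 32, 32))  # FloatTensor -> 'RGB' image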


class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.

    Given mean: (R, G, B) and std: (R, G, B),
    will normalize each channel of the torch.*Tensor, i.e.
    channel = (channel - mean) / std

    Args:
        mean (sequence): Sequence of means for R, G, B channels respectively.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respectively.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized image.
        """
        # TODO: make efficient
        for t, m, s in zip(tensor, self.mean, self.std):
            t.sub_(m).div_(s)
        return tensor
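
# Usage sketch (the mean/std values below are hypothetical, not dataset statistics):
#   >>> norm = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
#   >>> out = norm(tensor)  # note: modifies `tensor` in place and returns it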


class Scale(object):
    """Rescale the input PIL.Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (w, h), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e., if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be scaled.

        Returns:
            PIL.Image: Rescaled image.
        """
        if isinstance(self.size, int):
            w, h = img.size
            if (w <= h and w == self.size) or (h <= w and h == self.size):
                return img
            if w < h:
                ow = self.size
                oh = int(self.size * h / w)
                return img.resize((ow, oh), self.interpolation)
            else:
                oh = self.size
                ow = int(self.size * w / h)
                return img.resize((ow, oh), self.interpolation)
        else:
            return img.resize(self.size, self.interpolation)
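
# Usage sketch (illustrative): an int matches the smaller edge, a pair is exact.
#   >>> Scale(256)(img).size        # e.g. (256, 341) for a 480x640 input
#   >>> Scale((128, 64))(img).size  # exactly (128, 64); aspect ratio not preserved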


class CenterCrop(object):
    """Crops the given PIL.Image at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        w, h = img.size
        th, tw = self.size
        x1 = int(round((w - tw) / 2.))
        y1 = int(round((h - th) / 2.))
        return img.crop((x1, y1, x1 + tw, y1 + th))
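
# Usage sketch (illustrative; assumes `img` is at least 224x224):
#   >>> CenterCrop(224)(img).size  # (224, 224), taken from the image center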


class Pad(object):
    """Pad the given PIL.Image on all sides with the given "pad" value.

    Args:
        padding (int or sequence): Padding on each border. If a sequence of
            length 4, it is used to pad left, top, right and bottom borders respectively.
        fill: Pixel fill value. Default is 0. If a sequence of
            length 3, it is used to fill R, G, B channels respectively.
    """

    def __init__(self, padding, fill=0):
        assert isinstance(padding, numbers.Number) or isinstance(padding, tuple)
        assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple)
        self.padding = padding
        self.fill = fill

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be padded.

        Returns:
            PIL.Image: Padded image.
        """
        return ImageOps.expand(img, border=self.padding, fill=self.fill)
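
# Usage sketch (illustrative):
#   >>> Pad(4)(img).size  # each border grows by 4 pixels: (w + 8, h + 8)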


class Lambda(object):
    """Apply a user-defined lambda as a transform.

    Args:
        lambd (function): Lambda/function to be used for transform.
    """

    def __init__(self, lambd):
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)
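
# Usage sketch (illustrative; any plain function or lambda of one image works):
#   >>> grayscale = Lambda(lambda img: img.convert('L'))
#   >>> gray = grayscale(img)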


class RandomCrop(object):
    """Crop the given PIL.Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e. no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        if self.padding > 0:
            img = ImageOps.expand(img, border=self.padding, fill=0)

        w, h = img.size
        th, tw = self.size
        if w == tw and h == th:
            return img

        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)
        return img.crop((x1, y1, x1 + tw, y1 + th))
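
# Usage sketch (illustrative; pads by 4 on each side, then takes a random 32x32 patch):
#   >>> crop = RandomCrop(32, padding=4)
#   >>> patch = crop(img)  # patch.size == (32, 32)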


class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL.Image randomly with a probability of 0.5."""

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.

        Returns:
            PIL.Image: Randomly flipped image.
        """
        if random.random() < 0.5:
            return img.transpose(Image.FLIP_LEFT_RIGHT)
        return img
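
# Usage sketch (illustrative): each call flips independently with probability 0.5.
#   >>> flip = RandomHorizontalFlip()
#   >>> maybe_flipped = flip(img)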


class RandomSizedCrop(object):
    """Crop the given PIL.Image to random size and aspect ratio.

    A crop of random size (0.08 to 1.0 of the original area) and a random
    aspect ratio (3/4 to 4/3 of the original aspect ratio) is made. This crop
    is finally resized to the given size.
    This is popularly used to train the Inception networks.

    Args:
        size: expected output size of each edge
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        for attempt in range(10):
            area = img.size[0] * img.size[1]
            target_area = random.uniform(0.08, 1.0) * area
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            if w <= img.size[0] and h <= img.size[1]:
                x1 = random.randint(0, img.size[0] - w)
                y1 = random.randint(0, img.size[1] - h)

                img = img.crop((x1, y1, x1 + w, y1 + h))
                assert(img.size == (w, h))

                return img.resize((self.size, self.size), self.interpolation)

        # Fallback
        scale = Scale(self.size, interpolation=self.interpolation)
        crop = CenterCrop(self.size)
        return crop(scale(img))
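
# Usage sketch (illustrative; the Inception-style training crop):
#   >>> crop = RandomSizedCrop(224)
#   >>> patch = crop(img)  # always (224, 224) after the final resize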