from __future__ import division

import collections
import collections.abc
import math
import numbers
import random
import types

import numpy as np
import torch
from PIL import Image, ImageOps

try:
    import accimage
except ImportError:
    accimage = None


class Compose(object):
    """Chain several transforms into one callable applied in sequence.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        # Feed the output of each transform into the next one.
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result


class ToTensor(object):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.

    A PIL.Image or numpy.ndarray of shape (H x W x C) in the range
    [0, 255] becomes a torch.FloatTensor of shape (C x H x W) in the
    range [0.0, 1.0].
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        # numpy array: reorder HWC -> CHW, then scale bytes into [0, 1].
        if isinstance(pic, np.ndarray):
            chw = torch.from_numpy(pic.transpose((2, 0, 1)))
            # backward compatibility: always return a float tensor in [0, 1]
            return chw.float().div(255)

        # accimage decodes straight into a pre-allocated float32 buffer.
        if accimage is not None and isinstance(pic, accimage.Image):
            buf = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
            pic.copyto(buf)
            return torch.from_numpy(buf)

        # PIL image: integer modes keep their dtype, anything else is raw bytes.
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))

        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        nchannel = {'YCbCr': 3, 'I;16': 1}.get(pic.mode, len(pic.mode))
        img = img.view(pic.size[1], pic.size[0], nchannel)
        # put it from HWC to CHW format
        # yikes, this transpose takes 80% of the loading time/CPU
        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        # Byte images are rescaled to [0, 1]; integer modes are returned as-is.
        if isinstance(img, torch.ByteTensor):
            return img.float().div(255)
        return img


class ToPILImage(object):
    """Convert a tensor or ndarray to a PIL Image.

    Accepts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C and builds a PIL.Image, preserving the value range.
    """

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL.Image.

        Returns:
            PIL.Image: Image converted to PIL.Image.

        """
        npimg = pic
        mode = None
        # Float tensors are scaled by 255 and truncated to bytes.
        if isinstance(pic, torch.FloatTensor):
            pic = pic.mul(255).byte()
        if torch.is_tensor(pic):
            # CHW tensor -> HWC ndarray, the layout PIL expects.
            npimg = np.transpose(pic.numpy(), (1, 2, 0))
        assert isinstance(npimg, np.ndarray), 'pic should be Tensor or ndarray'

        if npimg.shape[2] == 1:
            # Single channel: drop the channel axis and pick a mode by dtype.
            # The dtype tests are mutually exclusive, so a plain elif chain
            # is equivalent to the original scattered ifs.
            npimg = npimg[:, :, 0]
            if npimg.dtype == np.uint8:
                mode = 'L'
            elif npimg.dtype == np.int16:
                mode = 'I;16'
            elif npimg.dtype == np.int32:
                mode = 'I'
            elif npimg.dtype == np.float32:
                mode = 'F'
        elif npimg.dtype == np.uint8:
            # Multi-channel input is only supported as 8-bit RGB.
            mode = 'RGB'
        assert mode is not None, '{} is not supported'.format(npimg.dtype)
        return Image.fromarray(npimg, mode=mode)


class Normalize(object):
    """Normalize a tensor image with mean and standard deviation, in place.

    Given mean: (R, G, B) and std: (R, G, B), each channel of the
    torch.*Tensor is transformed as channel = (channel - mean) / std.

    Args:
        mean (sequence): Sequence of means for R, G, B channels respecitvely.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respecitvely.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized image (the same tensor, modified in place).
        """
        # TODO: make efficient
        for channel, channel_mean, channel_std in zip(tensor, self.mean, self.std):
            channel.sub_(channel_mean).div_(channel_std)
        return tensor


class Scale(object):
    """Rescale the input PIL.Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (w, h), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # collections.Iterable was removed in Python 3.10; the ABC lives in
        # collections.abc (available since Python 3.3).
        assert isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be scaled.

        Returns:
            PIL.Image: Rescaled image.
        """
        if isinstance(self.size, int):
            w, h = img.size
            # Already matching on the short edge: nothing to do.
            if (w <= h and w == self.size) or (h <= w and h == self.size):
                return img
            if w < h:
                ow = self.size
                oh = int(self.size * h / w)
                return img.resize((ow, oh), self.interpolation)
            else:
                oh = self.size
                ow = int(self.size * w / h)
                return img.resize((ow, oh), self.interpolation)
        else:
            # Explicit (w, h) target: resize directly.
            return img.resize(self.size, self.interpolation)


class CenterCrop(object):
    """Crops the given PIL.Image at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (w, h), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        # A bare number means a square crop.
        self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        w, h = img.size
        th, tw = self.size
        # Round the offsets so the crop is as centered as possible.
        left = int(round((w - tw) / 2.))
        top = int(round((h - th) / 2.))
        return img.crop((left, top, left + tw, top + th))


class Pad(object):
    """Pad the given PIL.Image on all sides with the given "pad" value.

    Args:
        padding (int or tuple): Padding on each border. An int pads all four
            borders equally; a tuple of length 4 pads left, top, right and
            bottom borders respectively (a tuple of length 2 pads left/right
            and top/bottom).
        fill: Pixel fill value. Default is 0.
    """

    def __init__(self, padding, fill=0):
        # The docstring always promised sequence padding; ImageOps.expand
        # natively supports 2- and 4-tuple borders, so accept tuples too.
        assert isinstance(padding, (numbers.Number, tuple))
        assert isinstance(fill, (numbers.Number, str, tuple))
        self.padding = padding
        self.fill = fill

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be padded.

        Returns:
            PIL.Image: Padded image.
        """
        return ImageOps.expand(img, border=self.padding, fill=self.fill)


class Lambda(object):
    """Apply a user-defined lambda as a transform.

    Args:
        lambd (function): Lambda/function to be used for transform.
    """

    def __init__(self, lambd):
        # types.LambdaType is FunctionType, so any plain function qualifies.
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)


class RandomCrop(object):
    """Crop the given PIL.Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (w, h), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
    """

    def __init__(self, size, padding=0):
        # A bare number means a square crop.
        self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size
        self.padding = padding

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.

        Returns:
            PIL.Image: Cropped image.
        """
        if self.padding > 0:
            img = ImageOps.expand(img, border=self.padding, fill=0)

        w, h = img.size
        th, tw = self.size
        # Already at the target size: no crop needed.
        if w == tw and h == th:
            return img

        left = random.randint(0, w - tw)
        top = random.randint(0, h - th)
        return img.crop((left, top, left + tw, top + th))


class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL.Image randomly with a probability of 0.5."""

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.

        Returns:
            PIL.Image: Randomly flipped image.
        """
        flip = random.random() < 0.5
        return img.transpose(Image.FLIP_LEFT_RIGHT) if flip else img


class RandomSizedCrop(object):
    """Crop the given PIL.Image to random size and aspect ratio.

    A crop of random size of (0.08 to 1.0) of the original size and a random
    aspect ratio of 3/4 to 4/3 of the original aspect ratio is made. This crop
    is finally resized to given size.
    This is popularly used to train the Inception networks.

    Args:
        size: size of the smaller edge
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        # Try up to ten random crops before falling back to a center crop.
        for _ in range(10):
            area = img.size[0] * img.size[1]
            target_area = random.uniform(0.08, 1.0) * area
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            crop_w = int(round(math.sqrt(target_area * aspect_ratio)))
            crop_h = int(round(math.sqrt(target_area / aspect_ratio)))

            # Swap with probability 0.5 so both orientations are sampled.
            if random.random() < 0.5:
                crop_w, crop_h = crop_h, crop_w

            if crop_w <= img.size[0] and crop_h <= img.size[1]:
                left = random.randint(0, img.size[0] - crop_w)
                top = random.randint(0, img.size[1] - crop_h)

                img = img.crop((left, top, left + crop_w, top + crop_h))
                assert(img.size == (crop_w, crop_h))

                return img.resize((self.size, self.size), self.interpolation)

        # Fallback: scale the short edge to `size`, then take a center crop.
        scale = Scale(self.size, interpolation=self.interpolation)
        crop = CenterCrop(self.size)
        return crop(scale(img))