transforms.py 8.69 KB
Newer Older
1
from __future__ import division

import collections
import collections.abc
import math
import numbers
import random
import types

import numpy as np
import torch
from PIL import Image, ImageOps

try:
    import accimage
except ImportError:
    accimage = None
soumith's avatar
soumith committed
14

15

soumith's avatar
soumith committed
16
class Compose(object):
    """Composes several transforms together.

    Each transform is applied in order; the output of one becomes the
    input of the next.

    Args:
        transforms (List[Transform]): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result



class ToTensor(object):
    """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """

    def __call__(self, pic):
        if isinstance(pic, np.ndarray):
            # numpy array: HWC -> CHW, then scale to [0, 1]
            # (always divided by 255 for backward compatibility).
            return torch.from_numpy(pic.transpose((2, 0, 1))).float().div(255)

        if accimage is not None and isinstance(pic, accimage.Image):
            # accimage holds CHW float data; copy it into a numpy buffer.
            buf = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
            pic.copyto(buf)
            return torch.from_numpy(buf)

        # handle PIL Image: integer modes go through numpy, everything
        # else is read from the raw byte buffer.
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))

        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        nchannel = {'YCbCr': 3, 'I;16': 1}.get(pic.mode, len(pic.mode))
        img = img.view(pic.size[1], pic.size[0], nchannel)
        # put it from HWC to CHW format
        # (this transpose dominates the loading time/CPU)
        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float().div(255)
        return img


class ToPILImage(object):
    """Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of
    shape H x W x C to a PIL.Image.

    Float tensors are assumed to be in [0, 1] and are scaled to byte range;
    ndarray values are used as-is, with the PIL mode picked from the dtype.
    """

    def __call__(self, pic):
        mode = None
        npimg = pic
        # Float tensors in [0, 1] become byte tensors in [0, 255].
        if isinstance(pic, torch.FloatTensor):
            pic = pic.mul(255).byte()
        # Any tensor is converted CHW -> HWC for PIL.
        if torch.is_tensor(pic):
            npimg = np.transpose(pic.numpy(), (1, 2, 0))
        assert isinstance(npimg, np.ndarray), 'pic should be Tensor or ndarray'

        if npimg.shape[2] == 1:
            # Single channel: drop the channel axis and pick a mode by dtype.
            npimg = npimg[:, :, 0]
            if npimg.dtype == np.uint8:
                mode = 'L'
            elif npimg.dtype == np.int16:
                mode = 'I;16'
            elif npimg.dtype == np.int32:
                mode = 'I'
            elif npimg.dtype == np.float32:
                mode = 'F'
        elif npimg.dtype == np.uint8:
            # Multi-channel data is only supported as 8-bit RGB.
            mode = 'RGB'
        assert mode is not None, '{} is not supported'.format(npimg.dtype)
        return Image.fromarray(npimg, mode=mode)


class Normalize(object):
    """Given mean: (R, G, B) and std: (R, G, B), normalizes each channel
    of the torch.*Tensor in place, i.e. channel = (channel - mean) / std.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        # TODO: make efficient
        for channel, mu, sigma in zip(tensor, self.mean, self.std):
            channel.sub_(mu).div_(sigma)
        return tensor


class Scale(object):
    """Rescales the input PIL.Image to the given 'size'.

    If 'size' is a 2-element tuple or list in the order of (width, height),
    it will be the exact size to scale to.  If 'size' is a number, it
    indicates the size of the smaller edge and the aspect ratio is kept.
    For example, if height > width, then the image will be rescaled to
    (size, size * height / width).

    Args:
        size: the exact (width, height) size, or the smaller-edge length.
        interpolation: PIL interpolation filter. Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # `collections.Iterable` was a deprecated alias removed in
        # Python 3.10; the ABC lives in `collections.abc`.
        assert isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        if isinstance(self.size, int):
            w, h = img.size
            # Smaller edge already matches the target: nothing to do.
            if (w <= h and w == self.size) or (h <= w and h == self.size):
                return img
            if w < h:
                ow = self.size
                oh = int(self.size * h / w)
                return img.resize((ow, oh), self.interpolation)
            else:
                oh = self.size
                ow = int(self.size * w / h)
                return img.resize((ow, oh), self.interpolation)
        else:
            return img.resize(self.size, self.interpolation)


class CenterCrop(object):
    """Crops the given PIL.Image at the center to have a region of
    the given size. size can be a tuple (target_height, target_width)
    or an integer, in which case the target is a square (size, size).
    """

    def __init__(self, size):
        # An int means a square crop; otherwise keep the (th, tw) pair.
        self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size

    def __call__(self, img):
        w, h = img.size
        th, tw = self.size
        left = int(round((w - tw) / 2.))
        top = int(round((h - th) / 2.))
        return img.crop((left, top, left + tw, top + th))


class Pad(object):
    """Pads the given PIL.Image on all sides with the given "pad" value."""

    def __init__(self, padding, fill=0):
        assert isinstance(padding, numbers.Number)
        assert isinstance(fill, (numbers.Number, str, tuple))
        self.padding = padding
        self.fill = fill

    def __call__(self, img):
        return ImageOps.expand(img, border=self.padding, fill=self.fill)


class Lambda(object):
    """Applies a user-supplied lambda as a transform."""

    def __init__(self, lambd):
        # Only genuine lambdas/functions are accepted.
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)


class RandomCrop(object):
    """Crops the given PIL.Image at a random location to have a region of
    the given size. size can be a tuple (target_height, target_width)
    or an integer, in which case the target is a square (size, size).
    An optional zero-fill padding is applied on all sides first.
    """

    def __init__(self, size, padding=0):
        # An int means a square crop; otherwise keep the (th, tw) pair.
        self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size
        self.padding = padding

    def __call__(self, img):
        if self.padding > 0:
            img = ImageOps.expand(img, border=self.padding, fill=0)

        w, h = img.size
        th, tw = self.size
        # Already the requested size: return the image untouched.
        if w == tw and h == th:
            return img

        left = random.randint(0, w - tw)
        top = random.randint(0, h - th)
        return img.crop((left, top, left + tw, top + th))


class RandomHorizontalFlip(object):
    """Randomly horizontally flips the given PIL.Image with a probability of 0.5
    """

    def __call__(self, img):
        # Draw once; flip only on the lower half of the unit interval.
        if random.random() >= 0.5:
            return img
        return img.transpose(Image.FLIP_LEFT_RIGHT)


class RandomSizedCrop(object):
    """Randomly crops the given PIL.Image to a random size of (0.08 to 1.0)
    of the original area and a random aspect ratio of 3/4 to 4/3 of the
    original aspect ratio, then resizes to a square of side `size`.
    This is popularly used to train the Inception networks.

    size: size of the smaller edge
    interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        in_w, in_h = img.size
        area = in_w * in_h
        # Try up to ten random (area, aspect-ratio) proposals.
        for _ in range(10):
            target_area = random.uniform(0.08, 1.0) * area
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            # Proposal must fit inside the image.
            if w > in_w or h > in_h:
                continue

            x1 = random.randint(0, in_w - w)
            y1 = random.randint(0, in_h - h)
            patch = img.crop((x1, y1, x1 + w, y1 + h))
            assert (patch.size == (w, h))
            return patch.resize((self.size, self.size), self.interpolation)

        # Fallback: rescale the smaller edge, then center-crop.
        return CenterCrop(self.size)(Scale(self.size, interpolation=self.interpolation)(img))