"vscode:/vscode.git/clone" did not exist on "0fe7c13be18d1edd8682747ce558b430a1aa1c9e"
transforms.py 8.4 KB
Newer Older
1
from __future__ import division

import collections
import collections.abc
import math
import numbers
import random
import types

import numpy as np
import torch
from PIL import Image, ImageOps
soumith's avatar
soumith committed
10

11

soumith's avatar
soumith committed
12
class Compose(object):
    """Chains several transforms and applies them one after another.

    Args:
        transforms (List[Transform]): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        # Feed the output of each transform into the next one.
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result


class ToTensor(object):
    """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].

    Integer PIL modes 'I' and 'I;16' and float mode 'F' keep their native
    dtype and value range instead of being scaled to [0.0, 1.0].
    """

    def __call__(self, pic):
        """Convert ``pic`` to a CHW tensor.

        pic: PIL.Image or numpy.ndarray of shape (H, W, C).
        Returns a torch.FloatTensor scaled by 1/255 for 8-bit inputs,
        otherwise a tensor of the source dtype.
        """
        if isinstance(pic, np.ndarray):
            # handle numpy array: HWC -> CHW
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            # backward compatibility: always return floats in [0, 1]
            return img.float().div(255)
        # handle PIL Image
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            # NOTE(review): 'I;16' is unsigned 16-bit in PIL; int16 is kept
            # here for backward compatibility but overflows above 32767.
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        elif pic.mode == 'F':
            # 32-bit float images must not go through the raw-byte path
            # below: tobytes() would yield 4 bytes per pixel and the view()
            # further down would misinterpret them as 8-bit channels.
            img = torch.from_numpy(np.array(pic, np.float32, copy=False))
        else:
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        # pic.size is (width, height); reshape flat buffer to HWC first.
        img = img.view(pic.size[1], pic.size[0], nchannel)
        # put it from HWC to CHW format
        # yikes, this transpose takes 80% of the loading time/CPU
        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float().div(255)
        else:
            return img
67

Adam Paszke's avatar
Adam Paszke committed
68

69
class ToPILImage(object):
    """Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL.Image while preserving value range.

    FloatTensors are assumed to hold values in [0.0, 1.0] and are scaled to
    8-bit; integer tensors/arrays keep their values and get an integer PIL
    mode ('L', 'I;16', 'I') chosen from the dtype.
    """

    def __call__(self, pic):
        # Keep a reference in case pic is already an ndarray.
        npimg = pic
        mode = None
        # Float tensors are in [0, 1]: rescale to [0, 255] bytes first.
        # NOTE(review): this scaling happens only for torch.FloatTensor;
        # a float32 ndarray input is passed through unscaled (mode 'F').
        if isinstance(pic, torch.FloatTensor):
            pic = pic.mul(255).byte()
        if torch.is_tensor(pic):
            # CHW tensor -> HWC ndarray, the layout PIL expects.
            npimg = np.transpose(pic.numpy(), (1, 2, 0))
        assert isinstance(npimg, np.ndarray), 'pic should be Tensor or ndarray'
        if npimg.shape[2] == 1:
            # Single channel: drop the channel axis and pick a mode by dtype.
            npimg = npimg[:, :, 0]

            if npimg.dtype == np.uint8:
                mode = 'L'
            if npimg.dtype == np.int16:
                mode = 'I;16'
            if npimg.dtype == np.int32:
                mode = 'I'
            elif npimg.dtype == np.float32:
                mode = 'F'
        else:
            # Multi-channel output is only supported for 8-bit RGB here.
            if npimg.dtype == np.uint8:
                mode = 'RGB'
        assert mode is not None, '{} is not supported'.format(npimg.dtype)
        return Image.fromarray(npimg, mode=mode)

soumith's avatar
soumith committed
99
100

class Normalize(object):
    """Given mean: (R, G, B) and std: (R, G, B), normalizes each channel of
    the torch.*Tensor in place: channel = (channel - mean) / std.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        # TODO: make efficient
        # Iterating a CHW tensor yields one channel view per step; the
        # in-place ops below therefore modify (and return) the input.
        for channel, mu, sigma in zip(tensor, self.mean, self.std):
            channel.sub_(mu).div_(sigma)
        return tensor


class Scale(object):
    """Rescales the input PIL.Image to the given 'size'.

    If 'size' is a 2-element tuple or list in the order of (width, height),
    it will be the exact size to scale to. If 'size' is a number, it gives
    the length of the smaller edge and the aspect ratio is preserved: for
    example, if height > width, the image is rescaled to
    (size, size * height / width).

    size: exact (width, height) size, or the target size of the smaller edge
    interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # collections.Iterable was an alias removed in Python 3.10;
        # collections.abc.Iterable has been the canonical name since 3.3.
        assert isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        if isinstance(self.size, int):
            w, h = img.size
            # Smaller edge already matches the target: nothing to do.
            if (w <= h and w == self.size) or (h <= w and h == self.size):
                return img
            if w < h:
                ow = self.size
                oh = int(self.size * h / w)
                return img.resize((ow, oh), self.interpolation)
            else:
                oh = self.size
                ow = int(self.size * w / h)
                return img.resize((ow, oh), self.interpolation)
        else:
            # Exact (width, height) resize.
            return img.resize(self.size, self.interpolation)
soumith's avatar
soumith committed
147
148
149


class CenterCrop(object):
    """Crops the given PIL.Image at the center.

    size may be a tuple (target_height, target_width) or a single integer,
    in which case a square crop of shape (size, size) is taken.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            side = int(size)
            self.size = (side, side)
        else:
            self.size = size

    def __call__(self, img):
        width, height = img.size  # PIL convention: (width, height)
        target_h, target_w = self.size
        left = int(round((width - target_w) / 2.))
        top = int(round((height - target_h) / 2.))
        return img.crop((left, top, left + target_w, top + target_h))
soumith's avatar
soumith committed
167
168


169
170
class Pad(object):
    """Pads the given PIL.Image on all sides with the given "pad" value"""

    def __init__(self, padding, fill=0):
        assert isinstance(padding, numbers.Number)
        # fill may be a gray level, a color name, or an RGB tuple.
        assert isinstance(fill, (numbers.Number, str, tuple))
        self.padding = padding
        self.fill = fill

    def __call__(self, img):
        return ImageOps.expand(img, border=self.padding, fill=self.fill)

181

Soumith Chintala's avatar
Soumith Chintala committed
182
class Lambda(object):
    """Wraps a user-supplied lambda and applies it as a transform."""

    def __init__(self, lambd):
        # Only plain functions/lambdas are accepted as the callable.
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)

192

soumith's avatar
soumith committed
193
class RandomCrop(object):
    """Crops the given PIL.Image at a random position.

    size may be a tuple (target_height, target_width) or a single integer,
    in which case a square crop of shape (size, size) is taken. If padding
    is positive, the image is first zero-padded on every side.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            side = int(size)
            self.size = (side, side)
        else:
            self.size = size
        self.padding = padding

    def __call__(self, img):
        if self.padding > 0:
            img = ImageOps.expand(img, border=self.padding, fill=0)

        width, height = img.size  # PIL convention: (width, height)
        target_h, target_w = self.size
        # Already the requested size: skip the crop entirely.
        if width == target_w and height == target_h:
            return img

        left = random.randint(0, width - target_w)
        top = random.randint(0, height - target_h)
        return img.crop((left, top, left + target_w, top + target_h))
soumith's avatar
soumith committed
218
219
220


class RandomHorizontalFlip(object):
    """Randomly horizontally flips the given PIL.Image with a probability of 0.5
    """

    def __call__(self, img):
        # One uniform draw decides; >= 0.5 leaves the image untouched.
        if random.random() >= 0.5:
            return img
        return img.transpose(Image.FLIP_LEFT_RIGHT)


class RandomSizedCrop(object):
    """Randomly crops the given PIL.Image to between 0.08 and 1.0 of its
    original area, with aspect ratio sampled between 3/4 and 4/3, then
    rescales the crop to a square of the requested size. This is popularly
    used to train the Inception networks.
    size: size of the smaller edge
    interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        # Try up to ten random (area, aspect-ratio) draws before giving up.
        for _ in range(10):
            src_w, src_h = img.size
            target_area = random.uniform(0.08, 1.0) * (src_w * src_h)
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            # Swap with probability 0.5 so tall and wide crops are equally likely.
            if random.random() < 0.5:
                w, h = h, w

            if w <= src_w and h <= src_h:
                x1 = random.randint(0, src_w - w)
                y1 = random.randint(0, src_h - h)

                img = img.crop((x1, y1, x1 + w, y1 + h))
                assert(img.size == (w, h))

                return img.resize((self.size, self.size), self.interpolation)

        # Fallback: scale the smaller edge, then take a center crop.
        scale = Scale(self.size, interpolation=self.interpolation)
        crop = CenterCrop(self.size)
        return crop(scale(img))