from __future__ import division

import math
import numbers
import random
import types

import numpy as np
import torch
from PIL import Image, ImageOps


class Compose(object):
    """Chain several transforms into a single callable.

    The transforms are applied in list order; the output of each one is
    passed as the input of the next, and the last output is returned.

    Example:
        >>> transforms.Compose([
        ...     transforms.CenterCrop(10),
        ...     transforms.ToTensor(),
        ... ])

    Args:
        transforms (list): callables applied one after another.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result


class ToTensor(object):
    """Convert a PIL.Image or numpy.ndarray to a torch.FloatTensor.

    The input is expected in the range [0, 255] with an H x W x C layout
    (a 2-D H x W numpy array is treated as a single channel). The result
    always has shape C x H x W and is divided by 255, mapping [0, 255]
    to [0.0, 1.0].
    """

    def __call__(self, pic):
        if isinstance(pic, np.ndarray):
            # handle numpy array: promote H x W to H x W x 1, then move the
            # channel axis first so numpy and PIL inputs share the CHW layout.
            if pic.ndim == 2:
                pic = pic[:, :, None]
            # ascontiguousarray materialises the transposed view (and any
            # negative strides from e.g. arr[:, ::-1]) before torch sees it.
            img = torch.from_numpy(np.ascontiguousarray(pic.transpose((2, 0, 1))))
        else:
            # handle PIL Image: reinterpret the raw pixel bytes as uint8.
            # len(pic.mode) is used as the channel count, which matches
            # one-byte-per-band modes such as 'L', 'RGB' and 'RGBA'.
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
            img = img.view(pic.size[1], pic.size[0], len(pic.mode))
            # put it from HWC to CHW format
            # yikes, this transpose takes 80% of the loading time/CPU
            img = img.transpose(0, 1).transpose(0, 2).contiguous()
        return img.float().div(255)

class ToPILImage(object):
    """Convert a tensor or numpy.ndarray to a PIL.Image.

    Accepts either:
      * a torch.*Tensor of shape C x H x W with values in [0, 1]; it is
        multiplied by 255, converted to uint8 (truncating) and transposed
        to H x W x C, or
      * a numpy.ndarray of dtype uint8 in [0, 255] with shape H x W x C
        (or H x W), passed through as-is.
    A trailing channel axis of size 1 is dropped, so single-channel input
    produces a 2-D (grayscale) image instead of failing.
    """

    def __call__(self, pic):
        if isinstance(pic, np.ndarray):
            # handle numpy array (already H x W x C or H x W)
            npimg = pic
        else:
            # handle tensor: [0, 1] -> [0, 255] uint8, then CHW -> HWC,
            # the layout Image.fromarray expects
            npimg = pic.mul(255).byte().numpy()
            npimg = np.transpose(npimg, (1, 2, 0))
        if npimg.ndim == 3 and npimg.shape[2] == 1:
            # Image.fromarray rejects H x W x 1 uint8 arrays (at least in
            # the Pillow versions this targets); a plain H x W uint8 array
            # maps to an 8-bit grayscale image.
            npimg = npimg[:, :, 0]
        return Image.fromarray(npimg)

class Normalize(object):
    """Normalize a tensor channel by channel with a given mean and std.

    For every channel c along the first dimension:
    ``channel = (channel - mean[c]) / std[c]``.
    The input tensor is modified in place and also returned. Pairing is
    done with ``zip``, so channels beyond ``len(mean)`` / ``len(std)`` are
    left untouched.

    Args:
        mean (sequence): per-channel means, e.g. (R, G, B).
        std (sequence): per-channel standard deviations, e.g. (R, G, B).
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        # TODO: make efficient (currently one Python-level op per channel)
        for channel, mu, sigma in zip(tensor, self.mean, self.std):
            channel.sub_(mu)
            channel.div_(sigma)
        return tensor


class Scale(object):
    """Resize the input PIL.Image so that its shorter edge equals ``size``.

    The aspect ratio is kept: the longer edge becomes
    ``int(size * long_edge / short_edge)``. For example, when
    height > width the result has width ``size`` and height
    ``int(size * height / width)``. An image whose shorter edge already
    equals ``size`` is returned unchanged.

    Args:
        size (int): target length of the shorter edge.
        interpolation: resampling filter passed to ``resize``.
            Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        width, height = img.size
        target = self.size
        if min(width, height) == target:
            return img
        if width < height:
            new_size = (target, int(target * height / width))
        else:
            new_size = (int(target * width / height), target)
        return img.resize(new_size, self.interpolation)


class CenterCrop(object):
    """Crop the given PIL.Image around its center.

    Args:
        size (int or tuple): output size as (target_height, target_width);
            an int ``n`` is expanded to the square ``(n, n)``.
    """

    def __init__(self, size):
        self.size = (
            (int(size), int(size)) if isinstance(size, numbers.Number) else size
        )

    def __call__(self, img):
        width, height = img.size
        crop_h, crop_w = self.size
        left = int(round((width - crop_w) / 2.))
        top = int(round((height - crop_h) / 2.))
        box = (left, top, left + crop_w, top + crop_h)
        return img.crop(box)


class Pad(object):
    """Pad the given PIL.Image on all sides via ``ImageOps.expand``.

    Args:
        padding (number): border width added to every side.
        fill (number): value for the new border pixels. Default: 0
    """

    def __init__(self, padding, fill=0):
        for value in (padding, fill):
            assert isinstance(value, numbers.Number)
        self.padding = padding
        self.fill = fill

    def __call__(self, img):
        return ImageOps.expand(img, border=self.padding, fill=self.fill)


class Lambda(object):
    """Apply a user-supplied callable as a transform.

    Args:
        lambd (callable): any one-argument callable (a lambda, a def
            function, a builtin, a bound method, a functools.partial, ...);
            its return value is the transform's output.
    """

    def __init__(self, lambd):
        # Accept any callable: requiring types.LambdaType would reject
        # builtins, bound methods and functools.partial objects.
        assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
        self.lambd = lambd

    def __call__(self, img):
        return self.lambd(img)


class RandomCrop(object):
    """Crop the given PIL.Image at a random location.

    Args:
        size (int or tuple): output size as (target_height, target_width);
            an int ``n`` is expanded to the square ``(n, n)``.
        padding (int): if > 0, the image is first padded on every side by
            this many pixels with fill value 0. Default: 0

    Raises:
        ValueError: if the (padded) image is smaller than the crop size.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def __call__(self, img):
        if self.padding > 0:
            img = ImageOps.expand(img, border=self.padding, fill=0)

        w, h = img.size
        th, tw = self.size
        if w == tw and h == th:
            return img
        if w < tw or h < th:
            # random.randint would raise an opaque "empty range" ValueError
            # here; report the actual sizes instead.
            raise ValueError(
                "crop size (h={}, w={}) is larger than image size (h={}, w={})".format(
                    th, tw, h, w))

        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)
        return img.crop((x1, y1, x1 + tw, y1 + th))


class RandomHorizontalFlip(object):
    """Mirror the given PIL.Image left-to-right with probability 0.5.

    Draws one value from ``random.random()`` per call; the image is
    flipped when that value is below 0.5 and returned unchanged otherwise.
    """

    def __call__(self, img):
        should_flip = random.random() < 0.5
        if not should_flip:
            return img
        return img.transpose(Image.FLIP_LEFT_RIGHT)


class RandomSizedCrop(object):
    """Crop a random region of the given PIL.Image and resize it to a square.

    The region's area is a random fraction of the original image area
    (drawn uniformly from ``scale``, by default 0.08 to 1.0). Its aspect
    ratio w / h is drawn uniformly from ``ratio``, by default 3/4 to 4/3,
    and width and height are swapped with probability 0.5. The crop is
    then resized to ``size`` x ``size``. This is popularly used to train
    the Inception networks.

    If no region fitting inside the image is found in 10 attempts, the
    image is instead scaled so its shorter edge is ``size`` and then
    center-cropped to ``size`` x ``size``.

    Args:
        size (int): edge length of the square output.
        interpolation: resampling filter. Default: PIL.Image.BILINEAR
        scale (tuple): (min, max) fraction of the original area to crop.
            Default: (0.08, 1.0)
        ratio (tuple): (min, max) aspect ratio (w / h) of the crop before
            the random swap. Default: (3/4, 4/3)
    """

    def __init__(self, size, interpolation=Image.BILINEAR,
                 scale=(0.08, 1.0), ratio=(3. / 4, 4. / 3)):
        self.size = size
        self.interpolation = interpolation
        self.scale = scale
        self.ratio = ratio

    def __call__(self, img):
        # The source area is the same for every attempt.
        area = img.size[0] * img.size[1]
        for _ in range(10):
            target_area = random.uniform(*self.scale) * area
            aspect_ratio = random.uniform(*self.ratio)

            # w / h == aspect_ratio and w * h ~= target_area
            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            if w <= img.size[0] and h <= img.size[1]:
                x1 = random.randint(0, img.size[0] - w)
                y1 = random.randint(0, img.size[1] - h)

                img = img.crop((x1, y1, x1 + w, y1 + h))
                assert img.size == (w, h)

                return img.resize((self.size, self.size), self.interpolation)

        # Fallback: no fitting region found in 10 attempts
        rescale = Scale(self.size, interpolation=self.interpolation)
        center = CenterCrop(self.size)
        return center(rescale(img))