# transforms.py
from __future__ import division
import torch
import math
import random
from PIL import Image, ImageOps
import numpy as np
import numbers

class Compose(object):
    """ Composes several transforms together.

    The transforms are applied in the order given; each one receives the
    value returned by the previous one.

    For example:
    >>> transforms.Compose([
    >>>     transforms.CenterCrop(10),
    >>>     transforms.ToTensor(),
    >>>  ])

    transforms: iterable of callables, each taking a single argument
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        # thread the value through every transform, left to right
        for t in self.transforms:
            img = t(img)
        return img


class ToTensor(object):
    """ Converts a PIL.Image or numpy.ndarray (H x W x C) in the range [0, 255]
    to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].

    PIL images are read as one byte per band (e.g. L, RGB, RGBA).
    numpy arrays that are not 3-dimensional are converted without
    reordering their axes.
    """

    def __call__(self, pic):
        if isinstance(pic, np.ndarray):
            # handle numpy array
            if pic.ndim == 3:
                # put it from HWC to CHW format, as the PIL branch does
                pic = pic.transpose((2, 0, 1))
            img = torch.from_numpy(pic)
        else:
            # handle PIL Image
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
            # one channel per band rather than a hard-coded 3;
            # assumes 8 bits per band -- TODO confirm for exotic modes
            img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
            # put it from HWC to CHW format
            # yikes, this transpose takes 80% of the loading time/CPU
            img = img.transpose(0, 1).transpose(0, 2)
        # make the memory layout match the C x H x W shape before scaling
        return img.contiguous().float().div(255)

class ToPILImage(object):
    """ Turns either a torch.*Tensor (C x H x W, values in [0, 1]) or a
    numpy ndarray (H x W x C, dtype=uint8, values in [0, 255])
    into a PIL.Image with values in [0, 255]
    """
    def __call__(self, pic):
        if not isinstance(pic, np.ndarray):
            # tensor input: scale to bytes, then reorder CHW -> HWC
            as_bytes = pic.mul(255).byte().numpy()
            pic = np.transpose(as_bytes, (1, 2, 0))
        # at this point pic is always a numpy array
        return Image.fromarray(pic)

class Normalize(object):
    """ Given mean: (R, G, B) and std: (R, G, B),
    will normalize each channel of the torch.*Tensor, i.e.
    channel = (channel - mean) / std

    The tensor is modified in place and the same tensor is returned.
    Channels are paired with mean/std entries positionally; pairing stops
    at the shortest of the three sequences.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        # TODO: make efficient
        # iterating the tensor yields per-channel views, so the in-place
        # ops below write straight into `tensor`
        for t, m, s in zip(tensor, self.mean, self.std):
            t.sub_(m).div_(s)
        return tensor


class Scale(object):
    """ Rescales the input PIL.Image to the given 'size'.
    'size' will be the size of the smaller edge; the other edge is scaled
    by the same factor (truncated to an integer).
    For example, if height > width, then image will be
    rescaled to (size * height / width, size)
    An image whose smaller edge already equals 'size' is returned unchanged.
    size: size of the smaller edge
    interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        w, h = img.size
        # nothing to do when the smaller edge already has the target size
        if (w <= h and w == self.size) or (h <= w and h == self.size):
            return img
        if w < h:
            # width is the smaller edge
            ow = self.size
            oh = int(self.size * h / w)
        else:
            # height is the smaller edge (or the image is square)
            oh = self.size
            ow = int(self.size * w / h)
        return img.resize((ow, oh), self.interpolation)


class CenterCrop(object):
    """Crops the given PIL.Image at the center to have a region of
    the given size. size can be a tuple (target_height, target_width)
    or an integer, in which case the target will be of a square shape (size, size)
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            # a single number means a square crop
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img):
        # PIL reports (width, height) while self.size is (height, width)
        w, h = img.size
        th, tw = self.size
        x1 = int(round((w - tw) / 2))
        y1 = int(round((h - th) / 2))
        # crop box is (left, upper, right, lower)
        return img.crop((x1, y1, x1 + tw, y1 + th))


class Pad(object):
    """Surrounds the given PIL.Image with a border of "padding" pixels
    on every side, filled with the "fill" value"""

    def __init__(self, padding, fill=0):
        # both settings must be plain numbers
        for setting in (padding, fill):
            assert isinstance(setting, numbers.Number)
        self.fill = fill
        self.padding = padding

    def __call__(self, img):
        padded = ImageOps.expand(img, border=self.padding, fill=self.fill)
        return padded


class RandomCrop(object):
    """Crops the given PIL.Image at a random location to have a region of
    the given size. size can be a tuple (target_height, target_width)
    or an integer, in which case the target will be of a square shape (size, size)
    padding: if greater than 0, the image is first expanded by a border of
    this many pixels (filled with 0) on every side
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            # a single number means a square crop
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def __call__(self, img):
        if self.padding > 0:
            img = ImageOps.expand(img, border=self.padding, fill=0)

        # PIL reports (width, height) while self.size is (height, width)
        w, h = img.size
        th, tw = self.size
        if w == tw and h == th:
            # already the requested size: nothing to crop
            return img

        # randint is inclusive on both ends, so the crop can touch any edge
        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)
        return img.crop((x1, y1, x1 + tw, y1 + th))


class RandomHorizontalFlip(object):
    """Randomly horizontally flips the given PIL.Image with a probability of 0.5

    When no flip happens the input image object itself is returned.
    """

    def __call__(self, img):
        # random.random() is uniform on [0, 1), so this branch is taken half the time
        if random.random() < 0.5:
            return img.transpose(Image.FLIP_LEFT_RIGHT)
        return img


class RandomSizedCrop(object):
    """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size
    and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio.
    The crop is then resized to (size, size).
    This is popularly used to train the Inception networks
    size: edge length of the square output
    interpolation: Default: PIL.Image.BILINEAR

    Up to 10 random crops are tried; if none fits inside the image, the
    image is scaled and center-cropped instead.
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        # the image is not modified until a crop is accepted, so its area
        # is the same for every attempt
        area = img.size[0] * img.size[1]

        for attempt in range(10):
            target_area = random.uniform(0.08, 1.0) * area
            # float literals keep these bounds fractional under any division rules
            aspect_ratio = random.uniform(3.0 / 4.0, 4.0 / 3.0)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            # half of the time use the transposed box
            if random.random() < 0.5:
                w, h = h, w

            if w <= img.size[0] and h <= img.size[1]:
                x1 = random.randint(0, img.size[0] - w)
                y1 = random.randint(0, img.size[1] - h)

                img = img.crop((x1, y1, x1 + w, y1 + h))
                assert img.size == (w, h)

                return img.resize((self.size, self.size), self.interpolation)

        # Fallback
        scale = Scale(self.size, interpolation=self.interpolation)
        crop = CenterCrop(self.size)
        return crop(scale(img))