from __future__ import division
import torch
import math
import random
from PIL import Image
import torch
import math
import random
from PIL import Image
import numpy as np
import numbers

class Compose(object):
    """ Composes several transforms together.
    For example:
    >>> transforms.Compose([
    >>>     transforms.CenterCrop(10),
    >>>     transforms.ToTensor(),
    >>>  ])
    """
    def __init__(self, transforms):
        # Stored as-is; every element must be a callable that accepts
        # and returns an image-like object.
        self.transforms = transforms

    def __call__(self, img):
        # Thread the image through each transform, in order.
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result


class ToTensor(object):
    """ Converts a PIL.Image or numpy.ndarray (H x W x C) in the range [0, 255]
    to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
    def __call__(self, pic):
        """Convert ``pic`` to a float tensor scaled to [0.0, 1.0].

        pic: PIL.Image, or uint8 numpy.ndarray of shape (H x W x C).
        Returns a torch.FloatTensor of shape (C x H x W).
        """
        if isinstance(pic, np.ndarray):
            # handle numpy array: move channels first (HWC -> CHW) so the
            # output layout matches the PIL path and the class docstring.
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
        else:
            # handle PIL Image
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
            # PIL raw bytes are row-major H x W x C, and pic.size is (W, H),
            # so the view must be (H, W, C).  len(pic.mode) is the channel
            # count (3 for RGB, 1 for L, 4 for RGBA, ...) instead of a
            # hard-coded 3.
            img = img.view(pic.size[1], pic.size[0], len(pic.mode))
            # put it from HWC to CHW format: (H, W, C) -> (W, H, C) -> (C, H, W)
            # yikes, this transpose takes 80% of the loading time/CPU
            img = img.transpose(0, 1).transpose(0, 2).contiguous()
        return img.float().div(255)

class ToPILImage(object):
    """ Converts a torch.*Tensor of range [0, 1] and shape C x H x W 
    or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C
    to a PIL.Image of range [0, 255]
    """
    def __call__(self, pic):
        if isinstance(pic, np.ndarray):
            # ndarray is already uint8 H x W x C: hand it to PIL directly.
            return Image.fromarray(pic)
        # Tensor path: scale up to [0, 255], move channels last, convert.
        npimg = pic.mul(255).byte().numpy()
        npimg = np.transpose(npimg, (1,2,0))
        return Image.fromarray(npimg)
soumith's avatar
soumith committed
56
57

class Normalize(object):
    """ Given mean: (R, G, B) and std: (R, G, B),
    will normalize each channel of the torch.*Tensor, i.e.
    channel = (channel - mean) / std
    """
    def __init__(self, mean, std):
        # Per-channel statistics; one entry per leading dim of the tensor.
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        # TODO: make efficient
        # Normalizes in place and also returns the same tensor.
        for channel, channel_mean, channel_std in zip(tensor, self.mean, self.std):
            channel.sub_(channel_mean).div_(channel_std)
        return tensor


class Scale(object):
    """ Rescales the input PIL.Image to the given 'size'.
    'size' will be the size of the smaller edge.
    For example, if height > width, then image will be
    rescaled to (size * height / width, size)
    size: size of the smaller edge
    interpolation: Default: PIL.Image.BILINEAR
    """
    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        w, h = img.size
        # Smaller edge already equals the target: nothing to do.
        if (w <= h and w == self.size) or (h <= w and h == self.size):
            return img
        if w < h:
            # Width is the smaller edge; scale height proportionally.
            new_w = self.size
            new_h = int(self.size * h / w)
        else:
            # Height is the smaller edge; scale width proportionally.
            new_h = self.size
            new_w = int(self.size * w / h)
        return img.resize((new_w, new_h), self.interpolation)


class CenterCrop(object):
    """Crops the given PIL.Image at the center to have a region of
    the given size. size can be a tuple (target_height, target_width)
    or an integer, in which case the target will be of a square shape (size, size)
    """
    def __init__(self, size):
        # A bare number means a square crop.
        if isinstance(size, numbers.Number):
            size = (int(size), int(size))
        self.size = size

    def __call__(self, img):
        w, h = img.size
        th, tw = self.size
        # Top-left corner that centers a (tw x th) box inside (w x h).
        left = int(round((w - tw) / 2))
        top = int(round((h - th) / 2))
        return img.crop((left, top, left + tw, top + th))


class RandomCrop(object):
    """Crops the given PIL.Image at a random location to have a region of
    the given size. size can be a tuple (target_height, target_width)
    or an integer, in which case the target will be of a square shape (size, size)
    """
    def __init__(self, size, padding=0):
        # A bare number means a square crop.
        if isinstance(size, numbers.Number):
            size = (int(size), int(size))
        self.size = size
        # Reserved for future use; must currently stay 0.
        self.padding = padding

    def __call__(self, img):
        if self.padding > 0:
            raise NotImplementedError()

        w, h = img.size
        th, tw = self.size
        # The crop would cover the whole image: return it unchanged.
        if (w, h) == (tw, th):
            return img

        left = random.randint(0, w - tw)
        top = random.randint(0, h - th)
        return img.crop((left, top, left + tw, top + th))


class RandomHorizontalFlip(object):
    """Randomly horizontally flips the given PIL.Image with a probability of 0.5
    """
    def __call__(self, img):
        # One coin toss per call; flip on the low half of [0, 1).
        flip = random.random() < 0.5
        if not flip:
            return img
        return img.transpose(Image.FLIP_LEFT_RIGHT)


class RandomSizedCrop(object):
    """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size
    and and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio
    This is popularly used to train the Inception networks
    size: size of the smaller edge
    interpolation: Default: PIL.Image.BILINEAR
    """
    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        # Try up to ten random (area, aspect-ratio) proposals before
        # giving up and falling back to a deterministic crop.
        for _ in range(10):
            area = img.size[0] * img.size[1]
            target_area = random.uniform(0.08, 1.0) * area
            aspect_ratio = random.uniform(3 / 4, 4 / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            # Swap w/h half the time so tall and wide crops are equally likely.
            if random.random() < 0.5:
                w, h = h, w

            if w <= img.size[0] and h <= img.size[1]:
                x1 = random.randint(0, img.size[0] - w)
                y1 = random.randint(0, img.size[1] - h)

                cropped = img.crop((x1, y1, x1 + w, y1 + h))
                assert(cropped.size == (w, h))

                return cropped.resize((self.size, self.size), self.interpolation)

        # Fallback: no proposal fit inside the image.
        scale = Scale(self.size, interpolation=self.interpolation)
        crop = CenterCrop(self.size)
        return crop(scale(img))