# test_transforms.py: unit tests for torchvision.transforms.
import torch
import torchvision.transforms as transforms
import unittest
import random
5
import numpy as np
6

7

8
class Tester(unittest.TestCase):
    """Tests for ``torchvision.transforms``.

    Several tests draw image and crop sizes with the (unseeded) ``random``
    module, so each run exercises different sizes; the crop assertions carry
    the drawn sizes in their messages so a failure can be reproduced by hand.
    """

    @staticmethod
    def _center_crop(img, size):
        """Convert ``img`` to PIL, ``CenterCrop`` it to ``size``, and convert back to a tensor."""
        return transforms.Compose([
            transforms.ToPILImage(),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
        ])(img)

    @staticmethod
    def _sizes_msg(height, width, oheight, owidth):
        """Describe the randomly drawn image/crop sizes for an assertion message."""
        return "height: {} width: {} oheight: {} owidth: {}".format(
            height, width, oheight, owidth)

    def test_crop(self):
        """CenterCrop selects the centred window of the requested size.

        The centred ``oheight x owidth`` region of an all-ones image is zeroed.
        Cropping exactly that size must give an all-zero result; growing the
        crop by one, then two, pixels in each dimension must pull in
        progressively more of the surrounding ones.
        """
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        # Floor division: randint needs integers, and `/` yields a float on
        # Python 3 (deprecated in 3.10, a TypeError from 3.12).
        oheight = random.randint(5, (height - 2) // 2) * 2
        owidth = random.randint(5, (width - 2) // 2) * 2
        img = torch.ones(3, height, width)

        oh1 = (height - oheight) // 2
        ow1 = (width - owidth) // 2
        # Slicing yields a view, so fill_ zeroes the centre of ``img`` itself.
        imgnarrow = img[:, oh1:oh1 + oheight, ow1:ow1 + owidth]
        imgnarrow.fill_(0)
        result = self._center_crop(img, (oheight, owidth))
        assert result.sum() == 0, self._sizes_msg(height, width, oheight, owidth)

        oheight += 1
        owidth += 1
        result = self._center_crop(img, (oheight, owidth))
        sum1 = result.sum()
        assert sum1 > 1, self._sizes_msg(height, width, oheight, owidth)

        oheight += 1
        owidth += 1
        result = self._center_crop(img, (oheight, owidth))
        sum2 = result.sum()
        msg = self._sizes_msg(height, width, oheight, owidth)
        assert sum2 > 0, msg
        assert sum2 > sum1, msg

    def test_scale(self):
        """Scale(int) makes one output dimension ``osize`` and keeps the aspect orientation."""
        height = random.randint(24, 32) * 2
        width = random.randint(24, 32) * 2
        osize = random.randint(5, 12) * 2

        img = torch.ones(3, height, width)
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Scale(osize),
            transforms.ToTensor(),
        ])(img)
        assert osize in result.size()
        if height < width:
            assert result.size(1) <= result.size(2)
        elif width < height:
            assert result.size(1) >= result.size(2)

    def test_random_crop(self):
        """RandomCrop returns the requested size, with and without padding."""
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        # Floor division keeps the randint bounds integral on Python 3.
        oheight = random.randint(5, (height - 2) // 2) * 2
        owidth = random.randint(5, (width - 2) // 2) * 2
        img = torch.ones(3, height, width)

        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomCrop((oheight, owidth)),
            transforms.ToTensor(),
        ])(img)
        assert result.size(1) == oheight
        assert result.size(2) == owidth

        padding = random.randint(1, 20)
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomCrop((oheight, owidth), padding=padding),
            transforms.ToTensor(),
        ])(img)
        assert result.size(1) == oheight
        assert result.size(2) == owidth

    def test_pad(self):
        """Pad(n) grows each spatial dimension by ``2 * n``."""
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        img = torch.ones(3, height, width)
        padding = random.randint(1, 20)
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Pad(padding),
            transforms.ToTensor(),
        ])(img)
        assert result.size(1) == height + 2 * padding
        assert result.size(2) == width + 2 * padding

    def test_lambda(self):
        """Lambda applies the wrapped callable, both out-of-place and in-place."""
        trans = transforms.Lambda(lambda x: x.add(10))
        x = torch.randn(10)
        y = trans(x)
        assert (y.equal(torch.add(x, 10)))

        # add_ mutates x and returns it, so the output equals the modified input.
        trans = transforms.Lambda(lambda x: x.add_(10))
        x = torch.randn(10)
        y = trans(x)
        assert (y.equal(x))

    def test_to_tensor(self):
        """ToTensor rescales 8-bit data to [0, 1] and converts HWC ndarrays to CHW."""
        channels = 3
        height, width = 4, 4
        trans = transforms.ToTensor()
        input_data = torch.ByteTensor(channels, height, width).random_(0, 255).float().div_(255)
        img = transforms.ToPILImage()(input_data)
        output = trans(img)
        assert np.allclose(input_data.numpy(), output.numpy())

        ndarray = np.random.randint(low=0, high=255, size=(height, width, channels))
        output = trans(ndarray)
        expected_output = ndarray.transpose((2, 0, 1)) / 255.0
        assert np.allclose(output.numpy(), expected_output)

    def test_tensor_to_pil_image(self):
        """ToPILImage maps 3-channel float tensors to RGB and 1-channel ones to L."""
        trans = transforms.ToPILImage()
        to_tensor = transforms.ToTensor()

        img_data = torch.Tensor(3, 4, 4).uniform_()
        img = trans(img_data)
        assert img.getbands() == ('R', 'G', 'B')
        r, g, b = img.split()

        # Float inputs are quantised to 8 bits (truncated) on the way to PIL.
        expected_output = img_data.mul(255).int().float().div(255)
        assert np.allclose(expected_output[0].numpy(), to_tensor(r).numpy())
        assert np.allclose(expected_output[1].numpy(), to_tensor(g).numpy())
        assert np.allclose(expected_output[2].numpy(), to_tensor(b).numpy())

        # single channel image
        img_data = torch.Tensor(1, 4, 4).uniform_()
        img = trans(img_data)
        assert img.getbands() == ('L',)
        l, = img.split()
        expected_output = img_data.mul(255).int().float().div(255)
        assert np.allclose(expected_output[0].numpy(), to_tensor(l).numpy())

    def test_tensor_gray_to_pil_image(self):
        """Single-channel integer tensors map to modes L / I;16 / I and round-trip."""
        trans = transforms.ToPILImage()
        to_tensor = transforms.ToTensor()

        img_data_byte = torch.ByteTensor(1, 4, 4).random_(0, 255)
        img_data_short = torch.ShortTensor(1, 4, 4).random_()
        img_data_int = torch.IntTensor(1, 4, 4).random_()

        img_byte = trans(img_data_byte)
        img_short = trans(img_data_short)
        img_int = trans(img_data_int)
        assert img_byte.mode == 'L'
        assert img_short.mode == 'I;16'
        assert img_int.mode == 'I'

        # 'L' images come back from ToTensor rescaled into [0, 1], as the
        # single-channel case of test_tensor_to_pil_image also relies on.
        assert np.allclose(img_data_byte.float().div(255).numpy(), to_tensor(img_byte).numpy())
        assert np.allclose(img_data_short.numpy(), to_tensor(img_short).numpy())
        assert np.allclose(img_data_int.numpy(), to_tensor(img_int).numpy())

    def test_ndarray_to_pil_image(self):
        """HWC uint8 ndarrays map to RGB (3 channels) or L (1 channel) unchanged."""
        trans = transforms.ToPILImage()
        img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy()
        img = trans(img_data)
        assert img.getbands() == ('R', 'G', 'B')
        r, g, b = img.split()

        assert np.allclose(r, img_data[:, :, 0])
        assert np.allclose(g, img_data[:, :, 1])
        assert np.allclose(b, img_data[:, :, 2])

        # single channel image
        img_data = torch.ByteTensor(4, 4, 1).random_(0, 255).numpy()
        img = trans(img_data)
        assert img.getbands() == ('L',)
        l, = img.split()
        assert np.allclose(l, img_data[:, :, 0])

    def test_ndarray_bad_types_to_pil_image(self):
        """ToPILImage rejects single-channel ndarrays of unsupported dtypes."""
        trans = transforms.ToPILImage()
        # One context per dtype: inside a single assertRaises block only the
        # first call would run, leaving the remaining dtypes untested.
        for dtype in (np.int64, np.uint16, np.uint32, np.float64):
            with self.assertRaises(AssertionError):
                trans(np.ones([4, 4, 1], dtype))

    def test_ndarray_gray_float32_to_pil_image(self):
        """A single-channel float32 ndarray becomes a mode-F image with equal values."""
        trans = transforms.ToPILImage()
        img_data = torch.FloatTensor(4, 4, 1).random_().numpy()
        img = trans(img_data)
        assert img.mode == 'F'
        assert np.allclose(img, img_data[:, :, 0])

    def test_ndarray_gray_int16_to_pil_image(self):
        """A single-channel int16 ndarray becomes a mode-I;16 image with equal values."""
        trans = transforms.ToPILImage()
        img_data = torch.ShortTensor(4, 4, 1).random_().numpy()
        img = trans(img_data)
        assert img.mode == 'I;16'
        assert np.allclose(img, img_data[:, :, 0])

    def test_ndarray_gray_int32_to_pil_image(self):
        """A single-channel int32 ndarray becomes a mode-I image with equal values."""
        trans = transforms.ToPILImage()
        img_data = torch.IntTensor(4, 4, 1).random_().numpy()
        img = trans(img_data)
        assert img.mode == 'I'
        assert np.allclose(img, img_data[:, :, 0])
if __name__ == "__main__":
    # Executed as a script: hand the TestCase classes above to the unittest CLI.
    unittest.main()