test_functional_tensor.py 16.1 KB
Newer Older
1
import unittest
ekka's avatar
ekka committed
2
import random
3
import colorsys
4
5

from PIL import Image
vfdev's avatar
vfdev committed
6
7
8
9
10
11
12
13
14
from PIL.Image import NEAREST, BILINEAR, BICUBIC

import numpy as np

import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional_tensor as F_t
import torchvision.transforms.functional_pil as F_pil
import torchvision.transforms.functional as F
15
16
17
18


class Tester(unittest.TestCase):

19
20
21
22
23
24
25
26
27
    def _create_data(self, height=3, width=3, channels=3):
        """Build one random uint8 image in two forms: a tensor and a PIL image.

        Args:
            height: size of the second tensor dimension.
            width: size of the third tensor dimension.
            channels: size of the first tensor dimension.

        Returns:
            A ``(tensor, pil_img)`` pair. ``tensor`` is a ``(channels, height, width)`` uint8
            tensor; ``pil_img`` is created by ``Image.fromarray`` from the same values
            permuted to ``(height, width, channels)``.
        """
        # The upper bound of torch.randint is exclusive: 256 is required for 255 to be drawn.
        tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8)
        pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().numpy())
        return tensor, pil_img

    def compareTensorToPIL(self, tensor, pil_image, msg=None):
        pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
        self.assertTrue(tensor.equal(pil_tensor), msg)

vfdev's avatar
vfdev committed
28
29
30
31
32
33
34
35
    def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None):
        pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))).to(tensor)
        mae = torch.abs(tensor - pil_tensor).mean().item()
        self.assertTrue(
            mae < tol,
            msg="{}: mae={}, tol={}: \n{}\nvs\n{}".format(msg, mae, tol, tensor[0, :10, :10], pil_tensor[0, :10, :10])
        )

36
    def test_vflip(self):
        # Checks F_t.vflip on a random float tensor: the shape is kept, applying it twice
        # gives back the input, the input is not modified, and the scripted version agrees.
        scripted = torch.jit.script(F_t.vflip)
        source = torch.randn(3, 16, 16)
        snapshot = source.clone()

        flipped_once = F_t.vflip(source)
        flipped_twice = F_t.vflip(flipped_once)
        self.assertEqual(flipped_once.shape, source.shape)
        self.assertTrue(torch.equal(source, flipped_twice))
        # the eager call must leave its argument untouched
        self.assertTrue(torch.equal(source, snapshot))

        # scriptable function test
        self.assertTrue(torch.equal(flipped_once, scripted(source)))
48
49

    def test_hflip(self):
        # Checks F_t.hflip on a random float tensor: the shape is kept, applying it twice
        # gives back the input, the input is not modified, and the scripted version agrees.
        scripted = torch.jit.script(F_t.hflip)
        source = torch.randn(3, 16, 16)
        snapshot = source.clone()

        flipped_once = F_t.hflip(source)
        flipped_twice = F_t.hflip(flipped_once)
        self.assertEqual(flipped_once.shape, source.shape)
        self.assertTrue(torch.equal(source, flipped_twice))
        # the eager call must leave its argument untouched
        self.assertTrue(torch.equal(source, snapshot))

        # scriptable function test
        self.assertTrue(torch.equal(flipped_once, scripted(source)))
61

ekka's avatar
ekka committed
62
    def test_crop(self):
        # Crops a random uint8 tensor with a random window and compares F_t.crop against
        # F.crop applied to the PIL version of the image and against the scripted F_t.crop.
        scripted = torch.jit.script(F_t.crop)
        source = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
        snapshot = source.clone()

        # origin first, then an extent chosen so the window stays within 16x16
        top = random.randint(0, 15)
        left = random.randint(0, 15)
        height = random.randint(1, 16 - top)
        width = random.randint(1, 16 - left)
        window = (top, left, height, width)

        tensor_crop = F_t.crop(source, *window)
        pil_crop = F.crop(transforms.ToPILImage()(source), *window)
        # the ToTensor output is rescaled by 255 and cast to uint8 before comparing
        expected = (transforms.ToTensor()(pil_crop) * 255).to(torch.uint8)

        self.assertTrue(torch.equal(source, snapshot))
        self.assertTrue(torch.equal(tensor_crop, expected),
                        "functional_tensor crop not working")

        # scriptable function test
        self.assertTrue(torch.equal(tensor_crop, scripted(source, *window)))
ekka's avatar
ekka committed
80

81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
    def test_hsv2rgb(self):
        # Compares F_t._hsv2rgb with colorsys.hsv_to_rgb applied value by value,
        # on 20 random float images; the largest deviation must stay below 1e-5.
        shape = (3, 100, 150)
        for _ in range(20):
            hsv_img = torch.rand(*shape, dtype=torch.float)
            # channels moved last and spatial dims merged: one row per pixel
            converted = F_t._hsv2rgb(hsv_img).permute(1, 2, 0).flatten(0, 1)

            hue, sat, val = (plane.flatten().numpy() for plane in hsv_img.unbind(0))
            reference = torch.tensor(
                [colorsys.hsv_to_rgb(h1, s1, v1) for h1, s1, v1 in zip(hue, sat, val)],
                dtype=torch.float32,
            )

            self.assertLess((converted - reference).abs().max(), 1e-5)

    def test_rgb2hsv(self):
        # Compares F_t._rgb2hsv with colorsys.rgb_to_hsv applied value by value,
        # on 20 random float images; the largest deviation must stay below 1e-5.
        shape = (3, 150, 100)
        for _ in range(20):
            rgb_img = torch.rand(*shape, dtype=torch.float)
            # channels moved last and spatial dims merged: one row per pixel
            converted = F_t._rgb2hsv(rgb_img).permute(1, 2, 0).flatten(0, 1)

            red, green, blue = (plane.flatten().numpy() for plane in rgb_img.unbind(0))
            reference = torch.tensor(
                [colorsys.rgb_to_hsv(r1, g1, b1) for r1, g1, b1 in zip(red, green, blue)],
                dtype=torch.float32,
            )

            self.assertLess((reference - converted).abs().max(), 1e-5)

120
    def test_adjustments(self):
        # Compares the F_t brightness/contrast/saturation adjustments (eager and scripted)
        # with the F versions applied to a PIL image, using random float or uint8 images and
        # a random factor. Afterwards checks that ColorJitter can be scripted and called.
        scripted_brightness = torch.jit.script(F_t.adjust_brightness)
        scripted_contrast = torch.jit.script(F_t.adjust_contrast)
        scripted_saturation = torch.jit.script(F_t.adjust_saturation)

        # each entry: (PIL-path function, eager tensor function, scripted tensor function)
        triples = (
            (F.adjust_brightness, F_t.adjust_brightness, scripted_brightness),
            (F.adjust_contrast, F_t.adjust_contrast, scripted_contrast),
            (F.adjust_saturation, F_t.adjust_saturation, scripted_saturation),
        )

        # F uses uint8 and F_t uses float, so there is a small
        # difference in values caused by (at most 5) truncations.
        max_allowed = 5 / 255 + 1e-5

        for _ in range(20):
            channels = 3
            dims = torch.randint(1, 50, (2,))
            shape = (channels, dims[0], dims[1])

            # randomly choose between a float image and a uint8 image
            if torch.randint(0, 2, (1,)) == 0:
                img = torch.rand(*shape, dtype=torch.float)
            else:
                img = torch.randint(0, 256, shape, dtype=torch.uint8)

            factor = 3 * torch.rand(1)
            img_before = img.clone()
            for pil_op, tensor_op, scripted_op in triples:
                eager_out = tensor_op(img, factor)
                scripted_out = scripted_op(img, factor)
                if not img.dtype.is_floating_point:
                    # integer results are converted to float and divided by 255
                    eager_out = eager_out.to(torch.float) / 255
                    scripted_out = scripted_out.to(torch.float) / 255

                pil_out = pil_op(transforms.ToPILImage()(img), factor)
                reference = transforms.ToTensor()(pil_out)

                self.assertLess((eager_out - reference).abs().max(), max_allowed)
                self.assertLess((scripted_out - reference).abs().max(), max_allowed)
                # none of the calls may modify the input image
                self.assertTrue(torch.equal(img, img_before))

            # test for class interface
            for option in ("brightness", "contrast", "saturation"):
                jitter = transforms.ColorJitter(**{option: factor.item()})
                torch.jit.script(jitter)(img)

        # fixed factor, applied to the image left over from the last iteration
        jitter = transforms.ColorJitter(brightness=1)
        torch.jit.script(jitter)(img)

178
    def test_rgb_to_grayscale(self):
        # Compares F_t.rgb_to_grayscale with F.to_grayscale on the PIL version of a random
        # uint8 image (values may differ by at most 1) and with the scripted function.
        scripted = torch.jit.script(F_t.rgb_to_grayscale)
        source = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
        snapshot = source.clone()

        gray_tensor = F_t.rgb_to_grayscale(source).to(int)
        gray_pil = F.to_grayscale(F.to_pil_image(source))
        gray_reference = torch.tensor(np.array(gray_pil)).to(int)
        self.assertLess((gray_tensor - gray_reference).abs().max(), 1.0001)
        # the input must not have been modified
        self.assertTrue(torch.equal(source, snapshot))

        # scriptable function test
        gray_scripted = scripted(source).to(int)
        self.assertTrue(torch.equal(gray_scripted, gray_tensor))
190

191
    def test_center_crop(self):
        # Compares F_t.center_crop on a random single-channel uint8 tensor with
        # F.center_crop on its PIL version and with the scripted function.
        scripted = torch.jit.script(F_t.center_crop)
        source = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
        snapshot = source.clone()
        crop_size = [10, 10]

        tensor_crop = F_t.center_crop(source, crop_size)
        pil_crop = F.center_crop(transforms.ToPILImage()(source), crop_size)
        # the ToTensor output is rescaled by 255 and cast to uint8 before comparing
        expected = (transforms.ToTensor()(pil_crop) * 255).to(torch.uint8)
        self.assertTrue(torch.equal(tensor_crop, expected))
        # the input must not have been modified
        self.assertTrue(torch.equal(source, snapshot))

        # scriptable function test
        self.assertTrue(torch.equal(scripted(source, [10, 10]), tensor_crop))
203
204

    def test_five_crop(self):
        # Compares the five crops returned by F_t.five_crop with those of F.five_crop on
        # the PIL version of the image, then with the scripted function.
        scripted = torch.jit.script(F_t.five_crop)
        source = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
        snapshot = source.clone()

        tensor_crops = F_t.five_crop(source, [10, 10])
        pil_crops = F.five_crop(transforms.ToPILImage()(source), [10, 10])

        # Position i holds the index of the PIL crop that tensor crop i is compared with;
        # entries 1 and 2 are swapped between the two results.
        pil_index_for = (0, 2, 1, 3, 4)
        for tensor_idx, pil_idx in enumerate(pil_index_for):
            expected = (transforms.ToTensor()(pil_crops[pil_idx]) * 255).to(torch.uint8)
            self.assertTrue(torch.equal(tensor_crops[tensor_idx], expected))

        # the input must not have been modified
        self.assertTrue(torch.equal(source, snapshot))

        # scriptable function test
        scripted_crops = scripted(source, [10, 10])
        for from_script, from_eager in zip(scripted_crops, tensor_crops):
            self.assertTrue(torch.equal(from_script, from_eager))
225
226

    def test_ten_crop(self):
        # Compares the ten crops returned by F_t.ten_crop with those of F.ten_crop on
        # the PIL version of the image, then with the scripted function.
        scripted = torch.jit.script(F_t.ten_crop)
        source = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
        snapshot = source.clone()

        tensor_crops = F_t.ten_crop(source, [10, 10])
        pil_crops = F.ten_crop(transforms.ToPILImage()(source), [10, 10])

        # Position i holds the index of the PIL crop that tensor crop i is compared with;
        # entries 1/2 and 6/7 are swapped between the two results.
        pil_index_for = (0, 2, 1, 3, 4, 5, 7, 6, 8, 9)
        for tensor_idx, pil_idx in enumerate(pil_index_for):
            expected = (transforms.ToTensor()(pil_crops[pil_idx]) * 255).to(torch.uint8)
            self.assertTrue(torch.equal(tensor_crops[tensor_idx], expected))

        # the input must not have been modified
        self.assertTrue(torch.equal(source, snapshot))

        # scriptable function test
        scripted_crops = scripted(source, [10, 10])
        for from_script, from_eager in zip(scripted_crops, tensor_crops):
            self.assertTrue(torch.equal(from_script, from_eager))
257

258
259
260
    def test_pad(self):
        # Pads a 7x8 image with several padding specs and padding modes, comparing F_t.pad
        # with F_pil.pad and with the scripted F_t.pad, for uint8, float32 and float64
        # tensors. Finally checks that negative symmetric padding raises ValueError.
        script_fn = torch.jit.script(F_t.pad)
        tensor, pil_img = self._create_data(7, 8)

        paddings = [2, [3, ], [0, 3], (3, 3), [4, 2, 4, 3]]
        configs = [
            {"padding_mode": "constant", "fill": 0},
            {"padding_mode": "constant", "fill": 10},
            {"padding_mode": "constant", "fill": 20},
            {"padding_mode": "edge"},
            {"padding_mode": "reflect"},
            {"padding_mode": "symmetric"},
        ]

        for dt in [None, torch.float32, torch.float64]:
            if dt is not None:
                # This is a trivial cast to float of uint8 data to test all cases
                tensor = tensor.to(dt)
            for pad in paddings:
                # the scripted function receives a one-element list instead of a bare int
                script_pad = [pad, ] if isinstance(pad, int) else pad
                for kwargs in configs:
                    label = "{}, {}".format(pad, kwargs)
                    pad_tensor = F_t.pad(tensor, pad, **kwargs)
                    pad_pil_img = F_pil.pad(pil_img, pad, **kwargs)

                    # we need to cast to uint8 to compare with PIL image
                    if pad_tensor.dtype == torch.uint8:
                        pad_tensor_8b = pad_tensor
                    else:
                        pad_tensor_8b = pad_tensor.to(torch.uint8)
                    self.compareTensorToPIL(pad_tensor_8b, pad_pil_img, msg=label)

                    pad_tensor_script = script_fn(tensor, script_pad, **kwargs)
                    self.assertTrue(pad_tensor.equal(pad_tensor_script), msg=label)

        with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"):
            F_t.pad(tensor, (-2, -3), padding_mode="symmetric")

vfdev's avatar
vfdev committed
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
    def test_resize(self):
        # Resizes a 26x36 image with several size specs and interpolation modes, comparing
        # F_t.resize with F_pil.resize (output size always, values except for NEAREST) and
        # with the scripted F_t.resize, for uint8, float32 and float64 tensors.
        script_fn = torch.jit.script(F_t.resize)
        tensor, pil_img = self._create_data(26, 36)

        sizes = [32, [32, ], [32, 32], (32, 32), ]
        interpolations = [BILINEAR, BICUBIC, NEAREST]

        for dt in [None, torch.float32, torch.float64]:
            if dt is not None:
                # This is a trivial cast to float of uint8 data to test all cases
                tensor = tensor.to(dt)
            for size in sizes:
                # the scripted function receives a one-element list instead of a bare int
                script_size = [size, ] if isinstance(size, int) else size
                for interpolation in interpolations:
                    label = "{}, {}".format(size, interpolation)
                    resized_tensor = F_t.resize(tensor, size=size, interpolation=interpolation)
                    resized_pil_img = F_pil.resize(pil_img, size=size, interpolation=interpolation)

                    # the PIL size is reversed before comparing with the tensor's trailing dims
                    self.assertEqual(resized_tensor.size()[1:], resized_pil_img.size[::-1], msg=label)

                    if interpolation != NEAREST:
                        # We can not check values if mode = NEAREST, as results are different
                        # E.g. resized_tensor  = [[a, a, b, c, d, d, e, ...]]
                        # E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
                        # uint8 results are cast to float before the error is computed
                        if resized_tensor.dtype == torch.uint8:
                            resized_tensor_f = resized_tensor.to(torch.float)
                        else:
                            resized_tensor_f = resized_tensor

                        # Pay attention to high tolerance for MAE
                        self.approxEqualTensorToPIL(resized_tensor_f, resized_pil_img, tol=8.0, msg=label)

                    resized_tensor_script = script_fn(tensor, size=script_size, interpolation=interpolation)
                    self.assertTrue(resized_tensor.equal(resized_tensor_script), msg=label)

334
335
336

# Run the tests in this file when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()