import os
import unittest
import colorsys
import math

import numpy as np

import torch
import torchvision.transforms.functional_tensor as F_t
import torchvision.transforms.functional_pil as F_pil
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode

from common_utils import TransformsTester

from typing import Dict, List, Tuple

NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC


class Tester(TransformsTester):
    def setUp(self):
        self.device = "cpu"

    def _test_fn_on_batch(self, batch_tensors, fn, **fn_kwargs):
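        """Check that ``fn`` applied to a batch matches ``fn`` applied to each
        sample individually, and that the scripted ``fn`` stays close to the
        eager result on the batch."""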
        transformed_batch = fn(batch_tensors, **fn_kwargs)
        for i in range(len(batch_tensors)):
            img_tensor = batch_tensors[i, ...]
            transformed_img = fn(img_tensor, **fn_kwargs)
            self.assertTrue(transformed_img.equal(transformed_batch[i, ...]))

        scripted_fn = torch.jit.script(fn)
        # scriptable function test
        s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
        self.assertTrue(transformed_batch.allclose(s_transformed_batch))

    def test_assert_image_tensor(self):
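        # A 1d tensor is not a valid image, so every functional op below is
        # expected to raise; the error message is asserted at the end.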
        shape = (100,)
        tensor = torch.rand(*shape, dtype=torch.float, device=self.device)

        list_of_methods = [(F_t._get_image_size, (tensor, )), (F_t.vflip, (tensor, )),
                           (F_t.hflip, (tensor, )), (F_t.crop, (tensor, 1, 2, 4, 5)),
                           (F_t.adjust_brightness, (tensor, 0.)), (F_t.adjust_contrast, (tensor, 1.)),
                           (F_t.adjust_hue, (tensor, -0.5)), (F_t.adjust_saturation, (tensor, 2.)),
                           (F_t.center_crop, (tensor, [10, 11])), (F_t.five_crop, (tensor, [10, 11])),
                           (F_t.ten_crop, (tensor, [10, 11])), (F_t.pad, (tensor, [2, ], 2, "constant")),
                           (F_t.resize, (tensor, [10, 11])), (F_t.perspective, (tensor, [0.2, ])),
                           (F_t.gaussian_blur, (tensor, (2, 2), (0.7, 0.5))),
                           (F_t.invert, (tensor, )), (F_t.posterize, (tensor, 0)),
                           (F_t.solarize, (tensor, 0.3)), (F_t.adjust_sharpness, (tensor, 0.3)),
                           (F_t.autocontrast, (tensor, )), (F_t.equalize, (tensor, ))]

        for func, args in list_of_methods:
            with self.assertRaises(Exception) as context:
                func(*args)

            self.assertTrue('Tensor is not a torch image.' in str(context.exception))

    def test_vflip(self):
        script_vflip = torch.jit.script(F.vflip)

        img_tensor, pil_img = self._create_data(16, 18, device=self.device)
        vflipped_img = F.vflip(img_tensor)
        vflipped_pil_img = F.vflip(pil_img)
        self.compareTensorToPIL(vflipped_img, vflipped_pil_img)

        # scriptable function test
        vflipped_img_script = script_vflip(img_tensor)
        self.assertTrue(vflipped_img.equal(vflipped_img_script))

        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
        self._test_fn_on_batch(batch_tensors, F.vflip)

    def test_hflip(self):
        script_hflip = torch.jit.script(F.hflip)

        img_tensor, pil_img = self._create_data(16, 18, device=self.device)
        hflipped_img = F.hflip(img_tensor)
        hflipped_pil_img = F.hflip(pil_img)
        self.compareTensorToPIL(hflipped_img, hflipped_pil_img)

        # scriptable function test
        hflipped_img_script = script_hflip(img_tensor)
        self.assertTrue(hflipped_img.equal(hflipped_img_script))

        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
        self._test_fn_on_batch(batch_tensors, F.hflip)

    def test_crop(self):
        script_crop = torch.jit.script(F.crop)

        img_tensor, pil_img = self._create_data(16, 18, device=self.device)

        test_configs = [
            (1, 2, 4, 5),   # crop inside top-left corner
            (2, 12, 3, 4),  # crop inside top-right corner
            (8, 3, 5, 6),   # crop inside bottom-left corner
            (8, 11, 4, 3),  # crop inside bottom-right corner
        ]

        for top, left, height, width in test_configs:
            pil_img_cropped = F.crop(pil_img, top, left, height, width)

            img_tensor_cropped = F.crop(img_tensor, top, left, height, width)
            self.compareTensorToPIL(img_tensor_cropped, pil_img_cropped)

            img_tensor_cropped = script_crop(img_tensor, top, left, height, width)
            self.compareTensorToPIL(img_tensor_cropped, pil_img_cropped)
            batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
            self._test_fn_on_batch(batch_tensors, F.crop, top=top, left=left, height=height, width=width)

    def test_hsv2rgb(self):
        scripted_fn = torch.jit.script(F_t._hsv2rgb)
        shape = (3, 100, 150)
        for _ in range(10):
            hsv_img = torch.rand(*shape, dtype=torch.float, device=self.device)
            rgb_img = F_t._hsv2rgb(hsv_img)
            ft_img = rgb_img.permute(1, 2, 0).flatten(0, 1)

            h, s, v = hsv_img.unbind(0)
            h = h.flatten().cpu().numpy()
            s = s.flatten().cpu().numpy()
            v = v.flatten().cpu().numpy()
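
            # colorsys converts one pixel at a time, so build the reference
            # image pixel by pixel from the flattened channels.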

            rgb = []
            for h1, s1, v1 in zip(h, s, v):
                rgb.append(colorsys.hsv_to_rgb(h1, s1, v1))
            colorsys_img = torch.tensor(rgb, dtype=torch.float32, device=self.device)
            max_diff = (ft_img - colorsys_img).abs().max()
            self.assertLess(max_diff, 1e-5)

            s_rgb_img = scripted_fn(hsv_img)
            self.assertTrue(rgb_img.allclose(s_rgb_img))

        batch_tensors = self._create_data_batch(120, 100, num_samples=4, device=self.device).float()
        self._test_fn_on_batch(batch_tensors, F_t._hsv2rgb)

    def test_rgb2hsv(self):
        scripted_fn = torch.jit.script(F_t._rgb2hsv)
        shape = (3, 150, 100)
        for _ in range(10):
            rgb_img = torch.rand(*shape, dtype=torch.float, device=self.device)
            hsv_img = F_t._rgb2hsv(rgb_img)
            ft_hsv_img = hsv_img.permute(1, 2, 0).flatten(0, 1)

            r, g, b = rgb_img.unbind(dim=-3)
            r = r.flatten().cpu().numpy()
            g = g.flatten().cpu().numpy()
            b = b.flatten().cpu().numpy()

            hsv = []
            for r1, g1, b1 in zip(r, g, b):
                hsv.append(colorsys.rgb_to_hsv(r1, g1, b1))

            colorsys_img = torch.tensor(hsv, dtype=torch.float32, device=self.device)
            ft_hsv_img_h, ft_hsv_img_sv = torch.split(ft_hsv_img, [1, 2], dim=1)
            colorsys_img_h, colorsys_img_sv = torch.split(colorsys_img, [1, 2], dim=1)
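
            # Hue wraps around at 1.0, so compare hues through sin(2 * pi * h)
            # rather than directly; this tolerates the 0/1 boundary.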

            max_diff_h = ((colorsys_img_h * 2 * math.pi).sin() - (ft_hsv_img_h * 2 * math.pi).sin()).abs().max()
            max_diff_sv = (colorsys_img_sv - ft_hsv_img_sv).abs().max()
            max_diff = max(max_diff_h, max_diff_sv)
            self.assertLess(max_diff, 1e-5)

            s_hsv_img = scripted_fn(rgb_img)
            self.assertTrue(hsv_img.allclose(s_hsv_img))

        batch_tensors = self._create_data_batch(120, 100, num_samples=4, device=self.device).float()
        self._test_fn_on_batch(batch_tensors, F_t._rgb2hsv)

    def test_rgb_to_grayscale(self):
        script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)

        img_tensor, pil_img = self._create_data(32, 34, device=self.device)

        for num_output_channels in (3, 1):
            gray_pil_image = F.rgb_to_grayscale(pil_img, num_output_channels=num_output_channels)
            gray_tensor = F.rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)

            self.approxEqualTensorToPIL(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max")

            s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
            self.assertTrue(s_gray_tensor.equal(gray_tensor))

            batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
            self._test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels)

    def test_center_crop(self):
        script_center_crop = torch.jit.script(F.center_crop)

        img_tensor, pil_img = self._create_data(32, 34, device=self.device)

        cropped_pil_image = F.center_crop(pil_img, [10, 11])

        cropped_tensor = F.center_crop(img_tensor, [10, 11])
        self.compareTensorToPIL(cropped_tensor, cropped_pil_image)

        cropped_tensor = script_center_crop(img_tensor, [10, 11])
        self.compareTensorToPIL(cropped_tensor, cropped_pil_image)
        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
        self._test_fn_on_batch(batch_tensors, F.center_crop, output_size=[10, 11])

    def test_five_crop(self):
        script_five_crop = torch.jit.script(F.five_crop)

        img_tensor, pil_img = self._create_data(32, 34, device=self.device)

        cropped_pil_images = F.five_crop(pil_img, [10, 11])

        cropped_tensors = F.five_crop(img_tensor, [10, 11])
        for i in range(5):
            self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i])

        cropped_tensors = script_five_crop(img_tensor, [10, 11])
        for i in range(5):
            self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i])
        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
        tuple_transformed_batches = F.five_crop(batch_tensors, [10, 11])
        for i in range(len(batch_tensors)):
            img_tensor = batch_tensors[i, ...]
            tuple_transformed_imgs = F.five_crop(img_tensor, [10, 11])
            self.assertEqual(len(tuple_transformed_imgs), len(tuple_transformed_batches))

            for j in range(len(tuple_transformed_imgs)):
                true_transformed_img = tuple_transformed_imgs[j]
                transformed_img = tuple_transformed_batches[j][i, ...]
                self.assertTrue(true_transformed_img.equal(transformed_img))

        # scriptable function test
        s_tuple_transformed_batches = script_five_crop(batch_tensors, [10, 11])
        for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
            self.assertTrue(transformed_batch.equal(s_transformed_batch))

    def test_ten_crop(self):
        script_ten_crop = torch.jit.script(F.ten_crop)

        img_tensor, pil_img = self._create_data(32, 34, device=self.device)

        cropped_pil_images = F.ten_crop(pil_img, [10, 11])

        cropped_tensors = F.ten_crop(img_tensor, [10, 11])
        for i in range(10):
            self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i])

        cropped_tensors = script_ten_crop(img_tensor, [10, 11])
        for i in range(10):
            self.compareTensorToPIL(cropped_tensors[i], cropped_pil_images[i])
        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
        tuple_transformed_batches = F.ten_crop(batch_tensors, [10, 11])
        for i in range(len(batch_tensors)):
            img_tensor = batch_tensors[i, ...]
            tuple_transformed_imgs = F.ten_crop(img_tensor, [10, 11])
            self.assertEqual(len(tuple_transformed_imgs), len(tuple_transformed_batches))

            for j in range(len(tuple_transformed_imgs)):
                true_transformed_img = tuple_transformed_imgs[j]
                transformed_img = tuple_transformed_batches[j][i, ...]
                self.assertTrue(true_transformed_img.equal(transformed_img))

        # scriptable function test
        s_tuple_transformed_batches = script_ten_crop(batch_tensors, [10, 11])
        for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches):
            self.assertTrue(transformed_batch.equal(s_transformed_batch))

    def test_pad(self):
        script_fn = torch.jit.script(F.pad)
        tensor, pil_img = self._create_data(7, 8, device=self.device)
        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)

        for dt in [None, torch.float32, torch.float64, torch.float16]:

            if dt == torch.float16 and torch.device(self.device).type == "cpu":
                # skip float16 on CPU case
                continue

            if dt is not None:
                # This is a trivial cast to float of uint8 data to test all cases
                tensor = tensor.to(dt)
                batch_tensors = batch_tensors.to(dt)

            for pad in [2, [3, ], [0, 3], (3, 3), [4, 2, 4, 3]]:
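                # Each padding spec (int, 1-, 2- or 4-element sequence) is
                # checked against every padding mode below.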
                configs = [
                    {"padding_mode": "constant", "fill": 0},
                    {"padding_mode": "constant", "fill": 10},
                    {"padding_mode": "constant", "fill": 20},
                    {"padding_mode": "edge"},
                    {"padding_mode": "reflect"},
                    {"padding_mode": "symmetric"},
                ]
                for kwargs in configs:
                    pad_tensor = F_t.pad(tensor, pad, **kwargs)
                    pad_pil_img = F_pil.pad(pil_img, pad, **kwargs)

                    pad_tensor_8b = pad_tensor
                    # we need to cast to uint8 to compare with PIL image
                    if pad_tensor_8b.dtype != torch.uint8:
                        pad_tensor_8b = pad_tensor_8b.to(torch.uint8)

                    self.compareTensorToPIL(pad_tensor_8b, pad_pil_img, msg="{}, {}".format(pad, kwargs))

                    if isinstance(pad, int):
                        script_pad = [pad, ]
                    else:
                        script_pad = pad
                    pad_tensor_script = script_fn(tensor, script_pad, **kwargs)
                    self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, kwargs))
                    self._test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **kwargs)

    def _test_adjust_fn(self, fn, fn_pil, fn_t, configs, tol=2.0 + 1e-10, agg_method="max",
                        dts=(None, torch.float32, torch.float64)):
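        # Shared harness for the adjust_* tests: compares fn_t and the scripted
        # fn against the PIL reference fn_pil for every dtype in dts and every
        # kwargs config, then re-checks fn on a batch.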
        script_fn = torch.jit.script(fn)
        torch.manual_seed(15)
        tensor, pil_img = self._create_data(26, 34, device=self.device)
        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)

        for dt in dts:

            if dt is not None:
                tensor = F.convert_image_dtype(tensor, dt)
                batch_tensors = F.convert_image_dtype(batch_tensors, dt)

            for config in configs:
                adjusted_tensor = fn_t(tensor, **config)
                adjusted_pil = fn_pil(pil_img, **config)
                scripted_result = script_fn(tensor, **config)
                msg = "{}, {}".format(dt, config)
                self.assertEqual(adjusted_tensor.dtype, scripted_result.dtype, msg=msg)
                self.assertEqual(adjusted_tensor.size()[1:], adjusted_pil.size[::-1], msg=msg)

                rgb_tensor = adjusted_tensor

                if adjusted_tensor.dtype != torch.uint8:
                    rgb_tensor = F.convert_image_dtype(adjusted_tensor, torch.uint8)

                # Check that max difference does not exceed 2 in [0, 255] range
                # Exact matching is not possible due to incompatibility between
                # convert_image_dtype and PIL results
                self.approxEqualTensorToPIL(rgb_tensor.float(), adjusted_pil, tol=tol, msg=msg, agg_method=agg_method)

                atol = 1e-6
                if adjusted_tensor.dtype == torch.uint8 and "cuda" in torch.device(self.device).type:
                    atol = 1.0
                self.assertTrue(adjusted_tensor.allclose(scripted_result, atol=atol), msg=msg)
                self._test_fn_on_batch(batch_tensors, fn, **config)

    def test_adjust_brightness(self):
        self._test_adjust_fn(
            F.adjust_brightness,
            F_pil.adjust_brightness,
            F_t.adjust_brightness,
            [{"brightness_factor": f} for f in [0.1, 0.5, 1.0, 1.34, 2.5]]
        )

    def test_adjust_contrast(self):
        self._test_adjust_fn(
            F.adjust_contrast,
            F_pil.adjust_contrast,
            F_t.adjust_contrast,
            [{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]
        )

    def test_adjust_saturation(self):
        self._test_adjust_fn(
            F.adjust_saturation,
            F_pil.adjust_saturation,
            F_t.adjust_saturation,
            [{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]]
        )
    def test_adjust_hue(self):
        self._test_adjust_fn(
            F.adjust_hue,
            F_pil.adjust_hue,
            F_t.adjust_hue,
            [{"hue_factor": f} for f in [-0.45, -0.25, 0.0, 0.25, 0.45]],
            tol=16.1,
            agg_method="max"
        )

    def test_adjust_gamma(self):
        self._test_adjust_fn(
            F.adjust_gamma,
            F_pil.adjust_gamma,
            F_t.adjust_gamma,
            [{"gamma": g1, "gain": g2} for g1, g2 in zip([0.8, 1.0, 1.2], [0.7, 1.0, 1.3])]
        )

    def test_resize(self):
        script_fn = torch.jit.script(F.resize)
        tensor, pil_img = self._create_data(26, 36, device=self.device)
        batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)

        for dt in [None, torch.float32, torch.float64, torch.float16]:

            if dt == torch.float16 and torch.device(self.device).type == "cpu":
                # skip float16 on CPU case
                continue

            if dt is not None:
                # This is a trivial cast to float of uint8 data to test all cases
                tensor = tensor.to(dt)
                batch_tensors = batch_tensors.to(dt)
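
            # Note: an int size matches the smaller image edge, while a
            # 2-element list/tuple is an exact (h, w) target; a 1-element list
            # behaves like an int.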

            for size in [32, 26, [32, ], [32, 32], (32, 32), [26, 35]]:
                for interpolation in [BILINEAR, BICUBIC, NEAREST]:
                    resized_tensor = F.resize(tensor, size=size, interpolation=interpolation)
                    resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation)

                    self.assertEqual(
                        resized_tensor.size()[1:], resized_pil_img.size[::-1], msg="{}, {}".format(size, interpolation)
                    )

                    if interpolation not in [NEAREST, ]:
                        # We cannot check values if mode = NEAREST, as results are different
                        # E.g. resized_tensor  = [[a, a, b, c, d, d, e, ...]]
                        # E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
                        resized_tensor_f = resized_tensor
                        # we need to cast to uint8 to compare with PIL image
                        if resized_tensor_f.dtype == torch.uint8:
                            resized_tensor_f = resized_tensor_f.to(torch.float)

                        # Pay attention to high tolerance for MAE
                        self.approxEqualTensorToPIL(
                            resized_tensor_f, resized_pil_img, tol=8.0, msg="{}, {}".format(size, interpolation)
                        )

                    if isinstance(size, int):
                        script_size = [size, ]
                    else:
                        script_size = size

                    resize_result = script_fn(tensor, size=script_size, interpolation=interpolation)
                    self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation))

                    self._test_fn_on_batch(
                        batch_tensors, F.resize, size=script_size, interpolation=interpolation
                    )

        # assert changed type warning
        with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
            res1 = F.resize(tensor, size=32, interpolation=2)
            res2 = F.resize(tensor, size=32, interpolation=BILINEAR)
            self.assertTrue(res1.equal(res2))

    def test_resized_crop(self):
        # test values of F.resized_crop in several cases:
        # 1) resize to the same size, crop to the same size => should be identity
        tensor, _ = self._create_data(26, 36, device=self.device)

        for mode in [NEAREST, BILINEAR, BICUBIC]:
            out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode)
            self.assertTrue(tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]))

        # 2) resize by half and crop a TL corner
        tensor, _ = self._create_data(26, 36, device=self.device)
        out_tensor = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=NEAREST)
        expected_out_tensor = tensor[:, :20:2, :30:2]
        self.assertTrue(
            expected_out_tensor.equal(out_tensor),
            msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10])
        )

        batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
        self._test_fn_on_batch(
            batch_tensors, F.resized_crop, top=1, left=2, height=20, width=30, size=[10, 15], interpolation=NEAREST
        )

    def _test_affine_identity_map(self, tensor, scripted_affine):
        # 1) identity map
        out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST)
        self.assertTrue(
            tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
        )
        out_tensor = scripted_affine(
            tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
        )
        self.assertTrue(
            tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
        )

    def _test_affine_square_rotations(self, tensor, pil_img, scripted_affine):
        # 2) Test rotation
        test_configs = [
            (90, torch.rot90(tensor, k=1, dims=(-1, -2))),
            (45, None),
            (30, None),
            (-30, None),
            (-45, None),
            (-90, torch.rot90(tensor, k=-1, dims=(-1, -2))),
            (180, torch.rot90(tensor, k=2, dims=(-1, -2))),
        ]
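
        # For multiples of 90 degrees the expected result is exact (torch.rot90);
        # other angles are only compared against PIL via a pixel-diff ratio below.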
        for a, true_tensor in test_configs:
            out_pil_img = F.affine(
                pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
            )
            out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))).to(self.device)

            for fn in [F.affine, scripted_affine]:
                out_tensor = fn(
                    tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
                )
                if true_tensor is not None:
                    self.assertTrue(
                        true_tensor.equal(out_tensor),
                        msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5])
                    )
                if out_tensor.dtype != torch.uint8:
                    out_tensor = out_tensor.to(torch.uint8)

                num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
                ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
                # Tolerance : less than 6% of different pixels
                self.assertLess(
                    ratio_diff_pixels,
                    0.06,
                    msg="{}\n{} vs \n{}".format(
                        ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
                    )
                )

    def _test_affine_rect_rotations(self, tensor, pil_img, scripted_affine):
        test_configs = [
            90, 45, 15, -30, -60, -120
        ]
        for a in test_configs:

            out_pil_img = F.affine(
                pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
            )
            out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

            for fn in [F.affine, scripted_affine]:
                out_tensor = fn(
                    tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
                ).cpu()

                if out_tensor.dtype != torch.uint8:
                    out_tensor = out_tensor.to(torch.uint8)

                num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
                ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
                # Tolerance : less than 3% of different pixels
                self.assertLess(
                    ratio_diff_pixels,
                    0.03,
                    msg="{}: {}\n{} vs \n{}".format(
                        a, ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
                    )
                )
    def _test_affine_translations(self, tensor, pil_img, scripted_affine):
        # 3) Test translation
        test_configs = [
            [10, 12], (-12, -13)
        ]
        for t in test_configs:

            out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST)

            for fn in [F.affine, scripted_affine]:
                out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST)

                if out_tensor.dtype != torch.uint8:
                    out_tensor = out_tensor.to(torch.uint8)

                self.compareTensorToPIL(out_tensor, out_pil_img)

    def _test_affine_all_ops(self, tensor, pil_img, scripted_affine):
        # 4) Test rotation + translation + scale + shear
        test_configs = [
            (45.5, [5, 6], 1.0, [0.0, 0.0], None),
            (33, (5, -4), 1.0, [0.0, 0.0], [0, 0, 0]),
            (45, [-5, 4], 1.2, [0.0, 0.0], (1, 2, 3)),
            (33, (-4, -8), 2.0, [0.0, 0.0], [255, 255, 255]),
            (85, (10, -10), 0.7, [0.0, 0.0], [1, ]),
            (0, [0, 0], 1.0, [35.0, ], (2.0, )),
            (-25, [0, 0], 1.2, [0.0, 15.0], None),
            (-45, [-10, 0], 0.7, [2.0, 5.0], None),
            (-45, [-10, -10], 1.2, [4.0, 5.0], None),
            (-90, [0, 0], 1.0, [0.0, 0.0], None),
        ]
        for r in [NEAREST, ]:
            for a, t, s, sh, f in test_configs:
                f_pil = int(f[0]) if f is not None and len(f) == 1 else f
                out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, interpolation=r, fill=f_pil)
                out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

                for fn in [F.affine, scripted_affine]:
                    out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, interpolation=r, fill=f).cpu()

                    if out_tensor.dtype != torch.uint8:
                        out_tensor = out_tensor.to(torch.uint8)

                    num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
                    ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
                    # Tolerance : less than 5% (cpu), 6% (cuda) of different pixels
                    tol = 0.06 if self.device == "cuda" else 0.05
                    self.assertLess(
                        ratio_diff_pixels,
                        tol,
                        msg="{}: {}\n{} vs \n{}".format(
                            (r, a, t, s, sh, f), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
                        )
                    )

    def test_affine(self):
        # Tests on square and rectangular images
        scripted_affine = torch.jit.script(F.affine)

        data = [self._create_data(26, 26, device=self.device), self._create_data(32, 26, device=self.device)]
        for tensor, pil_img in data:

            for dt in [None, torch.float32, torch.float64, torch.float16]:

                if dt == torch.float16 and torch.device(self.device).type == "cpu":
                    # skip float16 on CPU case
                    continue

                if dt is not None:
                    tensor = tensor.to(dtype=dt)

                self._test_affine_identity_map(tensor, scripted_affine)
                if pil_img.size[0] == pil_img.size[1]:
                    self._test_affine_square_rotations(tensor, pil_img, scripted_affine)
                else:
                    self._test_affine_rect_rotations(tensor, pil_img, scripted_affine)
                self._test_affine_translations(tensor, pil_img, scripted_affine)
                self._test_affine_all_ops(tensor, pil_img, scripted_affine)

                batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
                if dt is not None:
                    batch_tensors = batch_tensors.to(dtype=dt)

                self._test_fn_on_batch(
                    batch_tensors, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0]
                )

        tensor, pil_img = data[0]
        # assert deprecation warning and non-BC
        with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
            res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=2)
            res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR)
            self.assertTrue(res1.equal(res2))

        # assert changed type warning
        with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
            res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=2)
            res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR)
            self.assertTrue(res1.equal(res2))

        with self.assertWarnsRegex(UserWarning, r"Argument fillcolor is deprecated and will be removed"):
            res1 = F.affine(pil_img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], fillcolor=10)
            res2 = F.affine(pil_img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], fill=10)
            self.assertEqual(res1, res2)

    def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers):
        img_size = pil_img.size
        dt = tensor.dtype
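
        # Sweep angle, expand, center and fill, comparing tensor results to PIL
        # both by shape and by a pixel-difference ratio.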
        for r in [NEAREST, ]:
            for a in range(-180, 180, 17):
                for e in [True, False]:
                    for c in centers:
                        for f in [None, [0, 0, 0], (1, 2, 3), [255, 255, 255], [1, ], (2.0, )]:
                            f_pil = int(f[0]) if f is not None and len(f) == 1 else f
                            out_pil_img = F.rotate(pil_img, angle=a, interpolation=r, expand=e, center=c, fill=f_pil)
                            out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
                            for fn in [F.rotate, scripted_rotate]:
                                out_tensor = fn(tensor, angle=a, interpolation=r, expand=e, center=c, fill=f).cpu()

                                if out_tensor.dtype != torch.uint8:
                                    out_tensor = out_tensor.to(torch.uint8)

                                self.assertEqual(
                                    out_tensor.shape,
                                    out_pil_tensor.shape,
                                    msg="{}: {} vs {}".format(
                                        (img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape
                                    ))

                                num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
                                ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
                                # Tolerance : less than 3% of different pixels
                                self.assertLess(
                                    ratio_diff_pixels,
                                    0.03,
                                    msg="{}: {}\n{} vs \n{}".format(
                                        (img_size, r, dt, a, e, c, f),
                                        ratio_diff_pixels,
                                        out_tensor[0, :7, :7],
                                        out_pil_tensor[0, :7, :7]
                                    )
                                )

    def test_rotate(self):
        # Tests on square and rectangular images
        scripted_rotate = torch.jit.script(F.rotate)

        data = [self._create_data(26, 26, device=self.device), self._create_data(32, 26, device=self.device)]
        for tensor, pil_img in data:

            img_size = pil_img.size
            centers = [
                None,
                (int(img_size[0] * 0.3), int(img_size[0] * 0.4)),
                [int(img_size[0] * 0.5), int(img_size[0] * 0.6)]
            ]

            for dt in [None, torch.float32, torch.float64, torch.float16]:

                if dt == torch.float16 and torch.device(self.device).type == "cpu":
                    # skip float16 on CPU case
                    continue

                if dt is not None:
                    tensor = tensor.to(dtype=dt)

                self._test_rotate_all_options(tensor, pil_img, scripted_rotate, centers)

                batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
                if dt is not None:
                    batch_tensors = batch_tensors.to(dtype=dt)

                center = (20, 22)
                self._test_fn_on_batch(
                    batch_tensors, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=center
                )

        tensor, pil_img = data[0]
        # assert deprecation warning and non-BC
        with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"):
            res1 = F.rotate(tensor, 45, resample=2)
            res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
            self.assertTrue(res1.equal(res2))

        # assert changed type warning
        with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
            res1 = F.rotate(tensor, 45, interpolation=2)
            res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
            self.assertTrue(res1.equal(res2))
    def _test_perspective(self, tensor, pil_img, scripted_transform, test_configs):
        dt = tensor.dtype
        for f in [None, [0, 0, 0], [1, 2, 3], [255, 255, 255], [1, ], (2.0, )]:
            for r in [NEAREST, ]:
                for spoints, epoints in test_configs:
                    f_pil = int(f[0]) if f is not None and len(f) == 1 else f
                    out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r,
                                                fill=f_pil)
                    out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

                    for fn in [F.perspective, scripted_transform]:
                        out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r, fill=f).cpu()

                        if out_tensor.dtype != torch.uint8:
                            out_tensor = out_tensor.to(torch.uint8)

                        num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
                        ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
                        # Tolerance : less than 5% of different pixels
                        self.assertLess(
                            ratio_diff_pixels,
                            0.05,
                            msg="{}: {}\n{} vs \n{}".format(
                                (f, r, dt, spoints, epoints),
                                ratio_diff_pixels,
                                out_tensor[0, :7, :7],
                                out_pil_tensor[0, :7, :7]
                            )
                        )

    def test_perspective(self):

        from torchvision.transforms import RandomPerspective

        data = [self._create_data(26, 34, device=self.device), self._create_data(26, 26, device=self.device)]
        scripted_transform = torch.jit.script(F.perspective)

        for tensor, pil_img in data:

            test_configs = [
                [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
                [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
                [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
            ]
            n = 10
            test_configs += [
                RandomPerspective.get_params(pil_img.size[0], pil_img.size[1], i / n) for i in range(n)
            ]
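
            # Three hand-written start/end point pairings above, plus n random
            # ones from RandomPerspective.get_params.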

            for dt in [None, torch.float32, torch.float64, torch.float16]:

                if dt == torch.float16 and torch.device(self.device).type == "cpu":
                    # skip float16 on CPU case
                    continue

                if dt is not None:
                    tensor = tensor.to(dtype=dt)

                self._test_perspective(tensor, pil_img, scripted_transform, test_configs)

                batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device)
                if dt is not None:
                    batch_tensors = batch_tensors.to(dtype=dt)
                for spoints, epoints in test_configs:
                    self._test_fn_on_batch(
                        batch_tensors, F.perspective, startpoints=spoints, endpoints=epoints, interpolation=NEAREST
                    )

        # assert changed type warning
        spoints = [[0, 0], [33, 0], [33, 25], [0, 25]]
        epoints = [[3, 2], [32, 3], [30, 24], [2, 25]]
        with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
            res1 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=2)
            res2 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=BILINEAR)
            self.assertTrue(res1.equal(res2))

    def test_gaussian_blur(self):
        small_image_tensor = torch.from_numpy(
            np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
        ).permute(2, 0, 1).to(self.device)

        large_image_tensor = torch.from_numpy(
            np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28))
        ).to(self.device)

        scripted_transform = torch.jit.script(F.gaussian_blur)

        # true_cv2_results = {
        #     # np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
        #     # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.8)
        #     "3_3_0.8": ...
        #     # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.5)
        #     "3_3_0.5": ...
        #     # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.8)
        #     "3_5_0.8": ...
        #     # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.5)
        #     "3_5_0.5": ...
        #     # np_img2 = np.arange(26 * 28, dtype="uint8").reshape((26, 28))
        #     # cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7)
        #     "23_23_1.7": ...
        # }
        p = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'gaussian_blur_opencv_results.pt')
        true_cv2_results = torch.load(p)
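
        # Reference outputs were precomputed with OpenCV (see the commented-out
        # recipe above); entries are looked up by image shape, kernel size and
        # sigma via the gt_key built below.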

        for tensor in [small_image_tensor, large_image_tensor]:

            for dt in [None, torch.float32, torch.float64, torch.float16]:
                if dt == torch.float16 and torch.device(self.device).type == "cpu":
                    # skip float16 on CPU case
                    continue

                if dt is not None:
                    tensor = tensor.to(dtype=dt)

                for ksize in [(3, 3), [3, 5], (23, 23)]:
                    for sigma in [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)]:

                        _ksize = (ksize, ksize) if isinstance(ksize, int) else ksize
                        _sigma = sigma[0] if sigma is not None else None
                        shape = tensor.shape
                        gt_key = "{}_{}_{}__{}_{}_{}".format(
                            shape[-2], shape[-1], shape[-3],
                            _ksize[0], _ksize[1], _sigma
                        )
                        if gt_key not in true_cv2_results:
                            continue

                        true_out = torch.tensor(
                            true_cv2_results[gt_key]
                        ).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)

                        for fn in [F.gaussian_blur, scripted_transform]:
                            out = fn(tensor, kernel_size=ksize, sigma=sigma)
                            self.assertEqual(true_out.shape, out.shape, msg="{}, {}".format(ksize, sigma))
                            self.assertLessEqual(
                                torch.max(true_out.float() - out.float()),
                                1.0,
                                msg="{}, {}".format(ksize, sigma)
                            )

    def test_invert(self):
        self._test_adjust_fn(
            F.invert,
            F_pil.invert,
            F_t.invert,
            [{}],
            tol=1.0,
            agg_method="max"
        )

    def test_posterize(self):
        self._test_adjust_fn(
            F.posterize,
            F_pil.posterize,
            F_t.posterize,
            [{"bits": bits} for bits in range(0, 8)],
            tol=1.0,
            agg_method="max",
            dts=(None,)
        )

    def test_solarize(self):
        self._test_adjust_fn(
            F.solarize,
            F_pil.solarize,
            F_t.solarize,
            [{"threshold": threshold} for threshold in [0, 64, 128, 192, 255]],
            tol=1.0,
            agg_method="max",
            dts=(None,)
        )
        self._test_adjust_fn(
            F.solarize,
            lambda img, threshold: F_pil.solarize(img, 255 * threshold),
            F_t.solarize,
            [{"threshold": threshold} for threshold in [0.0, 0.25, 0.5, 0.75, 1.0]],
            tol=1.0,
            agg_method="max",
            dts=(torch.float32, torch.float64)
        )

    def test_adjust_sharpness(self):
        self._test_adjust_fn(
            F.adjust_sharpness,
            F_pil.adjust_sharpness,
            F_t.adjust_sharpness,
            [{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]
        )

    def test_autocontrast(self):
        self._test_adjust_fn(
            F.autocontrast,
            F_pil.autocontrast,
            F_t.autocontrast,
            [{}],
            tol=1.0,
            agg_method="max"
        )

    def test_equalize(self):
        torch.set_deterministic(False)
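        # Presumably equalize relies on ops without deterministic
        # implementations (notably on CUDA), hence deterministic mode is
        # disabled above.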
        self._test_adjust_fn(
            F.equalize,
            F_pil.equalize,
            F_t.equalize,
            [{}],
            tol=1.0,
            agg_method="max",
            dts=(None,)
        )

@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
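    # Re-runs every test in Tester with self.device = "cuda".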

    def setUp(self):
        self.device = "cuda"

if __name__ == '__main__':
    unittest.main()