"""Tests for torchvision tensor transforms.

Checks that tensor-based transforms (functional and class interfaces) agree
with their PIL counterparts and with their torchscripted versions, on single
images and on batches.
"""
import unittest

import numpy as np
import torch
from PIL.Image import NEAREST, BILINEAR, BICUBIC
from torchvision import transforms as T
from torchvision.transforms import functional as F

from common_utils import TransformsTester
class Tester(TransformsTester):
    """Compares tensor transforms against PIL transforms and torchscript.

    Each ``test_*`` method exercises one transform three ways:
      * functional op on a tensor vs. the same op on a PIL image,
      * class (``T.*``) op eager vs. ``torch.jit.script``-ed,
      * single image vs. per-sample results on a batch tensor.
    ``self.device`` is "cpu" here; the CUDA subclass overrides it.
    """

    def setUp(self):
        self.device = "cpu"

    def _test_functional_op(self, func, fn_kwargs):
        """Run functional op ``func`` on a tensor and a PIL image; compare exactly."""
        if fn_kwargs is None:
            fn_kwargs = {}

        f = getattr(F, func)
        tensor, pil_img = self._create_data(height=10, width=10, device=self.device)
        transformed_tensor = f(tensor, **fn_kwargs)
        transformed_pil_img = f(pil_img, **fn_kwargs)
        self.compareTensorToPIL(transformed_tensor, transformed_pil_img)

    def _test_transform_vs_scripted(self, transform, s_transform, tensor):
        """Check eager and scripted transforms produce identical output.

        The seed is reset before each call so random transforms draw the
        same parameters both times.
        """
        torch.manual_seed(12)
        out1 = transform(tensor)
        torch.manual_seed(12)
        out2 = s_transform(tensor)
        self.assertTrue(out1.equal(out2))

    def _test_transform_vs_scripted_on_batch(self, transform, s_transform, batch_tensors):
        """Check batch output matches per-image output and scripted batch output."""
        torch.manual_seed(12)
        transformed_batch = transform(batch_tensors)

        # Each sample of the batch must equal the transform applied to that
        # sample alone (same seed => same random parameters).
        for i in range(len(batch_tensors)):
            img_tensor = batch_tensors[i, ...]
            torch.manual_seed(12)
            transformed_img = transform(img_tensor)
            self.assertTrue(transformed_img.equal(transformed_batch[i, ...]))

        torch.manual_seed(12)
        s_transformed_batch = s_transform(batch_tensors)
        self.assertTrue(transformed_batch.equal(s_transformed_batch))

    def _test_class_op(self, method, meth_kwargs=None, test_exact_match=True, **match_kwargs):
        """Test class-interface transform ``T.<method>``: PIL parity, scripting, batches.

        When ``test_exact_match`` is False, compare approximately using
        ``match_kwargs`` (e.g. ``tol`` / ``agg_method``).
        """
        if meth_kwargs is None:
            meth_kwargs = {}

        # test for class interface
        f = getattr(T, method)(**meth_kwargs)
        scripted_fn = torch.jit.script(f)

        tensor, pil_img = self._create_data(26, 34, device=self.device)
        # set seed to reproduce the same transformation for tensor and PIL image
        torch.manual_seed(12)
        transformed_tensor = f(tensor)
        torch.manual_seed(12)
        transformed_pil_img = f(pil_img)
        if test_exact_match:
            self.compareTensorToPIL(transformed_tensor, transformed_pil_img, **match_kwargs)
        else:
            self.approxEqualTensorToPIL(transformed_tensor.float(), transformed_pil_img, **match_kwargs)

        torch.manual_seed(12)
        transformed_tensor_script = scripted_fn(tensor)
        self.assertTrue(transformed_tensor.equal(transformed_tensor_script))

        batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
        self._test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors)

    def _test_op(self, func, method, fn_kwargs=None, meth_kwargs=None):
        """Test both the functional op and the corresponding class transform."""
        self._test_functional_op(func, fn_kwargs)
        self._test_class_op(method, meth_kwargs)

    def test_random_horizontal_flip(self):
        self._test_op('hflip', 'RandomHorizontalFlip')

    def test_random_vertical_flip(self):
        self._test_op('vflip', 'RandomVerticalFlip')

    def test_color_jitter(self):

        # Exact match is not expected vs. PIL; allow up to 1 level of
        # per-pixel difference (max aggregation) for brightness/contrast/saturation.
        tol = 1.0 + 1e-10
        for f in [0.1, 0.5, 1.0, 1.34, (0.3, 0.7), [0.4, 0.5]]:
            meth_kwargs = {"brightness": f}
            self._test_class_op(
                "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
            )

        for f in [0.2, 0.5, 1.0, 1.5, (0.3, 0.7), [0.4, 0.5]]:
            meth_kwargs = {"contrast": f}
            self._test_class_op(
                "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
            )

        for f in [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]:
            meth_kwargs = {"saturation": f}
            self._test_class_op(
                "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
            )

        # Hue is compared with a looser, mean-aggregated tolerance.
        for f in [0.2, 0.5, (-0.2, 0.3), [-0.4, 0.5]]:
            meth_kwargs = {"hue": f}
            self._test_class_op(
                "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=0.1, agg_method="mean"
            )

        # All 4 parameters together
        meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2}
        self._test_class_op(
            "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=0.1, agg_method="mean"
        )

    def test_pad(self):

        # Test functional.pad (PIL and Tensor) with padding as single int
        self._test_functional_op(
            "pad", fn_kwargs={"padding": 2, "fill": 0, "padding_mode": "constant"}
        )
        # Test functional.pad and transforms.Pad with padding as [int, ]
        fn_kwargs = meth_kwargs = {"padding": [2, ], "fill": 0, "padding_mode": "constant"}
        self._test_op(
            "pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )
        # Test functional.pad and transforms.Pad with padding as list
        fn_kwargs = meth_kwargs = {"padding": [4, 4], "fill": 0, "padding_mode": "constant"}
        self._test_op(
            "pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )
        # Test functional.pad and transforms.Pad with padding as tuple
        fn_kwargs = meth_kwargs = {"padding": (2, 2, 2, 2), "fill": 127, "padding_mode": "constant"}
        self._test_op(
            "pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )

    def test_crop(self):
        fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5}
        # Test transforms.RandomCrop with size and padding as tuple
        meth_kwargs = {"size": (4, 5), "padding": (4, 4), "pad_if_needed": True, }
        self._test_op(
            'crop', 'RandomCrop', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )

        # Cross product of size specifications and padding modes/fills.
        sizes = [5, [5, ], [6, 6]]
        padding_configs = [
            {"padding_mode": "constant", "fill": 0},
            {"padding_mode": "constant", "fill": 10},
            {"padding_mode": "constant", "fill": 20},
            {"padding_mode": "edge"},
            {"padding_mode": "reflect"},
        ]

        for size in sizes:
            for padding_config in padding_configs:
                config = dict(padding_config)
                config["size"] = size
                self._test_class_op("RandomCrop", config)

    def test_center_crop(self):
        fn_kwargs = {"output_size": (4, 5)}
        meth_kwargs = {"size": (4, 5), }
        self._test_op(
            "center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )
        fn_kwargs = {"output_size": (5,)}
        meth_kwargs = {"size": (5, )}
        self._test_op(
            "center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )
        tensor = torch.randint(0, 255, (3, 10, 10), dtype=torch.uint8, device=self.device)
        # Test torchscript of transforms.CenterCrop with size as int
        f = T.CenterCrop(size=5)
        scripted_fn = torch.jit.script(f)
        scripted_fn(tensor)

        # Test torchscript of transforms.CenterCrop with size as [int, ]
        f = T.CenterCrop(size=[5, ])
        scripted_fn = torch.jit.script(f)
        scripted_fn(tensor)

        # Test torchscript of transforms.CenterCrop with size as tuple
        f = T.CenterCrop(size=(6, 6))
        scripted_fn = torch.jit.script(f)
        scripted_fn(tensor)

    def _test_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None):
        """Test an op returning a list of images (e.g. five_crop/ten_crop).

        Checks tensor/PIL parity element-wise, scripted parity, the expected
        list length ``out_length``, and per-sample consistency on a batch.
        """
        if fn_kwargs is None:
            fn_kwargs = {}
        if meth_kwargs is None:
            meth_kwargs = {}

        fn = getattr(F, func)
        scripted_fn = torch.jit.script(fn)

        tensor, pil_img = self._create_data(height=20, width=20, device=self.device)
        transformed_t_list = fn(tensor, **fn_kwargs)
        transformed_p_list = fn(pil_img, **fn_kwargs)
        self.assertEqual(len(transformed_t_list), len(transformed_p_list))
        self.assertEqual(len(transformed_t_list), out_length)
        for transformed_tensor, transformed_pil_img in zip(transformed_t_list, transformed_p_list):
            self.compareTensorToPIL(transformed_tensor, transformed_pil_img)

        transformed_t_list_script = scripted_fn(tensor.detach().clone(), **fn_kwargs)
        self.assertEqual(len(transformed_t_list), len(transformed_t_list_script))
        self.assertEqual(len(transformed_t_list_script), out_length)
        for transformed_tensor, transformed_tensor_script in zip(transformed_t_list, transformed_t_list_script):
            self.assertTrue(transformed_tensor.equal(transformed_tensor_script),
                            msg="{} vs {}".format(transformed_tensor, transformed_tensor_script))

        # test for class interface
        fn = getattr(T, method)(**meth_kwargs)
        scripted_fn = torch.jit.script(fn)
        output = scripted_fn(tensor)
        self.assertEqual(len(output), len(transformed_t_list_script))

        # test on batch of tensors
        batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
        torch.manual_seed(12)
        transformed_batch_list = fn(batch_tensors)

        for i in range(len(batch_tensors)):
            img_tensor = batch_tensors[i, ...]
            torch.manual_seed(12)
            transformed_img_list = fn(img_tensor)
            for transformed_img, transformed_batch in zip(transformed_img_list, transformed_batch_list):
                self.assertTrue(transformed_img.equal(transformed_batch[i, ...]),
                                msg="{} vs {}".format(transformed_img, transformed_batch[i, ...]))

    def test_five_crop(self):
        fn_kwargs = meth_kwargs = {"size": (5,)}
        self._test_op_list_output(
            "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )
        fn_kwargs = meth_kwargs = {"size": [5, ]}
        self._test_op_list_output(
            "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )
        fn_kwargs = meth_kwargs = {"size": (4, 5)}
        self._test_op_list_output(
            "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )
        fn_kwargs = meth_kwargs = {"size": [4, 5]}
        self._test_op_list_output(
            "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )

    def test_ten_crop(self):
        fn_kwargs = meth_kwargs = {"size": (5,)}
        self._test_op_list_output(
            "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )
        fn_kwargs = meth_kwargs = {"size": [5, ]}
        self._test_op_list_output(
            "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )
        fn_kwargs = meth_kwargs = {"size": (4, 5)}
        self._test_op_list_output(
            "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )
        fn_kwargs = meth_kwargs = {"size": [4, 5]}
        self._test_op_list_output(
            "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )

    def test_resize(self):
        tensor, _ = self._create_data(height=34, width=36, device=self.device)
        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
        script_fn = torch.jit.script(F.resize)

        for dt in [None, torch.float32, torch.float64]:
            if dt is not None:
                # This is a trivial cast to float of uint8 data to test all cases
                tensor = tensor.to(dt)
            for size in [32, 34, [32, ], [32, 32], (32, 32), [34, 35]]:
                for interpolation in [BILINEAR, BICUBIC, NEAREST]:

                    resized_tensor = F.resize(tensor, size=size, interpolation=interpolation)

                    # Scripted resize requires a List[int] size argument.
                    if isinstance(size, int):
                        script_size = [size, ]
                    else:
                        script_size = size

                    s_resized_tensor = script_fn(tensor, size=script_size, interpolation=interpolation)
                    self.assertTrue(s_resized_tensor.equal(resized_tensor))

                    transform = T.Resize(size=script_size, interpolation=interpolation)
                    s_transform = torch.jit.script(transform)
                    self._test_transform_vs_scripted(transform, s_transform, tensor)
                    self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

    def test_resized_crop(self):
        tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        for scale in [(0.7, 1.2), [0.7, 1.2]]:
            for ratio in [(0.75, 1.333), [0.75, 1.333]]:
                for size in [(32, ), [44, ], [32, ], [32, 32], (32, 32), [44, 55]]:
                    for interpolation in [NEAREST, BILINEAR, BICUBIC]:
                        transform = T.RandomResizedCrop(
                            size=size, scale=scale, ratio=ratio, interpolation=interpolation
                        )
                        s_transform = torch.jit.script(transform)
                        self._test_transform_vs_scripted(transform, s_transform, tensor)
                        self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

    def test_random_affine(self):
        tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        for shear in [15, 10.0, (5.0, 10.0), [-15, 15], [-10.0, 10.0, -11.0, 11.0]]:
            for scale in [(0.7, 1.2), [0.7, 1.2]]:
                for translate in [(0.1, 0.2), [0.2, 0.1]]:
                    for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]:
                        for interpolation in [NEAREST, BILINEAR]:
                            transform = T.RandomAffine(
                                degrees=degrees, translate=translate,
                                scale=scale, shear=shear, resample=interpolation
                            )
                            s_transform = torch.jit.script(transform)

                            self._test_transform_vs_scripted(transform, s_transform, tensor)
                            self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

    def test_random_rotate(self):
        tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        for center in [(0, 0), [10, 10], None, (56, 44)]:
            for expand in [True, False]:
                for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]:
                    for interpolation in [NEAREST, BILINEAR]:
                        transform = T.RandomRotation(
                            degrees=degrees, resample=interpolation, expand=expand, center=center
                        )
                        s_transform = torch.jit.script(transform)

                        self._test_transform_vs_scripted(transform, s_transform, tensor)
                        self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

    def test_random_perspective(self):
        tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        for distortion_scale in np.linspace(0.1, 1.0, num=20):
            for interpolation in [NEAREST, BILINEAR]:
                transform = T.RandomPerspective(
                    distortion_scale=distortion_scale,
                    interpolation=interpolation
                )
                s_transform = torch.jit.script(transform)

                self._test_transform_vs_scripted(transform, s_transform, tensor)
                self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

    def test_to_grayscale(self):

        # Grayscale conversion differs from PIL by at most one intensity level.
        meth_kwargs = {"num_output_channels": 1}
        tol = 1.0 + 1e-10
        self._test_class_op(
            "Grayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
        )

        meth_kwargs = {"num_output_channels": 3}
        self._test_class_op(
            "Grayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
        )

        meth_kwargs = {}
        self._test_class_op(
            "RandomGrayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
        )

    def test_normalize(self):
        tensor, _ = self._create_data(26, 34, device=self.device)
        batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)

        # Normalize expects float input in [0, 1].
        tensor = tensor.to(dtype=torch.float32) / 255.0
        # test for class interface
        fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        scripted_fn = torch.jit.script(fn)

        self._test_transform_vs_scripted(fn, scripted_fn, tensor)
        self._test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)

    def test_linear_transformation(self):
        c, h, w = 3, 24, 32

        tensor, _ = self._create_data(h, w, channels=c, device=self.device)

        matrix = torch.rand(c * h * w, c * h * w, device=self.device)
        mean_vector = torch.rand(c * h * w, device=self.device)

        fn = T.LinearTransformation(matrix, mean_vector)
        scripted_fn = torch.jit.script(fn)

        self._test_transform_vs_scripted(fn, scripted_fn, tensor)

        batch_tensors = torch.rand(4, c, h, w, device=self.device)
        # We skip some tests from _test_transform_vs_scripted_on_batch as
        # results for scripted and non-scripted transformations are not exactly the same
        torch.manual_seed(12)
        transformed_batch = fn(batch_tensors)
        torch.manual_seed(12)
        s_transformed_batch = scripted_fn(batch_tensors)
        self.assertTrue(transformed_batch.equal(s_transformed_batch))

    def test_compose(self):
        tensor, _ = self._create_data(26, 34, device=self.device)
        tensor = tensor.to(dtype=torch.float32) / 255.0

        # Compose itself is not scriptable, but an equivalent nn.Sequential
        # of the same transforms is.
        transforms = T.Compose([
            T.CenterCrop(10),
            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        s_transforms = torch.nn.Sequential(*transforms.transforms)

        scripted_fn = torch.jit.script(s_transforms)
        torch.manual_seed(12)
        transformed_tensor = transforms(tensor)
        torch.manual_seed(12)
        transformed_tensor_script = scripted_fn(tensor)
        self.assertTrue(transformed_tensor.equal(transformed_tensor_script), msg="{}".format(transforms))

        # A Compose containing a lambda cannot be scripted.
        t = T.Compose([
            lambda x: x,
        ])
        with self.assertRaisesRegex(RuntimeError, r"Could not get name of python class object"):
            torch.jit.script(t)
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
    """Re-runs every test from Tester on a CUDA device (skipped without CUDA)."""

    def setUp(self):
        self.device = "cuda"
# Run the full test suite when executed as a script.
if __name__ == '__main__':
    unittest.main()