# test_transforms_tensor.py
import os
2
3
4
import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F
5
from torchvision.transforms import InterpolationMode
6
7
8
9

import numpy as np

import unittest
10
from typing import Sequence
11

from common_utils import (
    get_tmp_dir,
    int_dtypes,
    float_dtypes,
    _create_data,
    _create_data_batch,
    _assert_equal_tensor_to_pil,
    _assert_approx_equal_tensor_to_pil,
)
21
from _assert_utils import assert_equal
22
23


# Short aliases for the interpolation modes exercised throughout the tests.
NEAREST = InterpolationMode.NEAREST
BILINEAR = InterpolationMode.BILINEAR
BICUBIC = InterpolationMode.BICUBIC


class Tester(unittest.TestCase):
    """Consistency tests for torchvision transforms applied to tensors.

    Depending on the test, this checks tensor output against PIL output,
    eager output against ``torch.jit.script`` output, per-image output against
    batched output, and that scripted modules can be saved. ``self.device``
    (set in ``setUp``) decides where input tensors are created.
    """

    def setUp(self):
        self.device = "cpu"

    def _test_functional_op(self, func, fn_kwargs=None, test_exact_match=True, **match_kwargs):
        """Apply ``F.<func>`` to a tensor and to a PIL image and compare the results.

        Args:
            func: name of the function in ``torchvision.transforms.functional``.
            fn_kwargs: keyword arguments forwarded to the function; ``None`` means no kwargs.
            test_exact_match: compare exactly if True, approximately otherwise.
            **match_kwargs: forwarded to the comparison helper (e.g. ``tol``, ``agg_method``).
        """
        if fn_kwargs is None:
            fn_kwargs = {}

        f = getattr(F, func)
        tensor, pil_img = _create_data(height=10, width=10, device=self.device)
        transformed_tensor = f(tensor, **fn_kwargs)
        transformed_pil_img = f(pil_img, **fn_kwargs)
        if test_exact_match:
            _assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
        else:
            _assert_approx_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)

    def _test_transform_vs_scripted(self, transform, s_transform, tensor, msg=None):
        """Assert that an eager transform and its scripted version agree on one input.

        The RNG is reseeded before each call so random transforms draw the same
        parameters in both runs.
        """
        torch.manual_seed(12)
        out1 = transform(tensor)
        torch.manual_seed(12)
        out2 = s_transform(tensor)
        assert_equal(out1, out2, msg=msg)

    def _test_transform_vs_scripted_on_batch(self, transform, s_transform, batch_tensors, msg=None):
        """Assert batch consistency of a transform and of its scripted version.

        With the same seed, applying ``transform`` to the whole batch must equal
        applying it to each image separately. The scripted transform must also
        produce the same batch output as the eager one.
        """
        torch.manual_seed(12)
        transformed_batch = transform(batch_tensors)

        for i in range(len(batch_tensors)):
            img_tensor = batch_tensors[i, ...]
            torch.manual_seed(12)
            transformed_img = transform(img_tensor)
            assert_equal(transformed_img, transformed_batch[i, ...], msg=msg)

        torch.manual_seed(12)
        s_transformed_batch = s_transform(batch_tensors)
        assert_equal(transformed_batch, s_transformed_batch, msg=msg)

    def _test_class_op(self, method, meth_kwargs=None, test_exact_match=True, **match_kwargs):
        """Check ``T.<method>(**meth_kwargs)`` against PIL, TorchScript and batching.

        Also saves the scripted module into a temporary directory to make sure
        serialization works.
        """
        if meth_kwargs is None:
            meth_kwargs = {}

        # test for class interface
        f = getattr(T, method)(**meth_kwargs)
        scripted_fn = torch.jit.script(f)

        tensor, pil_img = _create_data(26, 34, device=self.device)
        # set seed to reproduce the same transformation for tensor and PIL image
        torch.manual_seed(12)
        transformed_tensor = f(tensor)
        torch.manual_seed(12)
        transformed_pil_img = f(pil_img)
        if test_exact_match:
            _assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
        else:
            _assert_approx_equal_tensor_to_pil(transformed_tensor.float(), transformed_pil_img, **match_kwargs)

        torch.manual_seed(12)
        transformed_tensor_script = scripted_fn(tensor)
        assert_equal(transformed_tensor, transformed_tensor_script)

        batch_tensors = _create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
        self._test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors)

        with get_tmp_dir() as tmp_dir:
            scripted_fn.save(os.path.join(tmp_dir, "t_{}.pt".format(method)))

    def _test_op(self, func, method, fn_kwargs=None, meth_kwargs=None, test_exact_match=True, **match_kwargs):
        """Run both the functional (``F.<func>``) and the class (``T.<method>``) checks."""
        self._test_functional_op(func, fn_kwargs, test_exact_match=test_exact_match, **match_kwargs)
        self._test_class_op(method, meth_kwargs, test_exact_match=test_exact_match, **match_kwargs)

    def test_random_horizontal_flip(self):
        self._test_op('hflip', 'RandomHorizontalFlip')

    def test_random_vertical_flip(self):
        self._test_op('vflip', 'RandomVerticalFlip')

    def test_random_invert(self):
        self._test_op('invert', 'RandomInvert')

    def test_random_posterize(self):
        fn_kwargs = meth_kwargs = {"bits": 4}
        self._test_op(
            'posterize', 'RandomPosterize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )

    def test_random_solarize(self):
        fn_kwargs = meth_kwargs = {"threshold": 192.0}
        self._test_op(
            'solarize', 'RandomSolarize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )

    def test_random_adjust_sharpness(self):
        fn_kwargs = meth_kwargs = {"sharpness_factor": 2.0}
        self._test_op(
            'adjust_sharpness', 'RandomAdjustSharpness', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )

    def test_random_autocontrast(self):
        # We check the max abs difference because on some (very rare) pixels, the actual value may be different
        # between PIL and tensors due to floating approximations.
        self._test_op('autocontrast', 'RandomAutocontrast', test_exact_match=False, agg_method='max',
                      tol=(1 + 1e-5), allowed_percentage_diff=.05)

    def test_random_equalize(self):
        self._test_op('equalize', 'RandomEqualize')

    def test_color_jitter(self):

        tol = 1.0 + 1e-10
        for f in [0.1, 0.5, 1.0, 1.34, (0.3, 0.7), [0.4, 0.5]]:
            meth_kwargs = {"brightness": f}
            self._test_class_op(
                "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
            )

        for f in [0.2, 0.5, 1.0, 1.5, (0.3, 0.7), [0.4, 0.5]]:
            meth_kwargs = {"contrast": f}
            self._test_class_op(
                "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
            )

        for f in [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]:
            meth_kwargs = {"saturation": f}
            self._test_class_op(
                "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
            )

        for f in [0.2, 0.5, (-0.2, 0.3), [-0.4, 0.5]]:
            meth_kwargs = {"hue": f}
            self._test_class_op(
                "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=16.1, agg_method="max"
            )

        # All 4 parameters together
        meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2}
        self._test_class_op(
            "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=12.1, agg_method="max"
        )

    def test_pad(self):
        for m in ["constant", "edge", "reflect", "symmetric"]:
            fill = 127 if m == "constant" else 0
            for mul in [1, -1]:
                # Test functional.pad (PIL and Tensor) with padding as single int
                self._test_functional_op(
                    "pad", fn_kwargs={"padding": mul * 2, "fill": fill, "padding_mode": m}
                )
                # Test functional.pad and transforms.Pad with padding as [int, ]
                fn_kwargs = meth_kwargs = {"padding": [mul * 2, ], "fill": fill, "padding_mode": m}
                self._test_op(
                    "pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
                )
                # Test functional.pad and transforms.Pad with padding as list
                fn_kwargs = meth_kwargs = {"padding": [mul * 4, 4], "fill": fill, "padding_mode": m}
                self._test_op(
                    "pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
                )
                # Test functional.pad and transforms.Pad with padding as tuple
                fn_kwargs = meth_kwargs = {"padding": (mul * 2, 2, 2, mul * 2), "fill": fill, "padding_mode": m}
                self._test_op(
                    "pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
                )

    def test_crop(self):
        fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5}
        # Test transforms.RandomCrop with size and padding as tuple
        meth_kwargs = {"size": (4, 5), "padding": (4, 4), "pad_if_needed": True, }
        self._test_op(
            'crop', 'RandomCrop', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )

        # Test transforms.functional.crop including outside the image area
        fn_kwargs = {"top": -2, "left": 3, "height": 4, "width": 5}  # top
        self._test_functional_op('crop', fn_kwargs=fn_kwargs)

        fn_kwargs = {"top": 1, "left": -3, "height": 4, "width": 5}  # left
        self._test_functional_op('crop', fn_kwargs=fn_kwargs)

        fn_kwargs = {"top": 7, "left": 3, "height": 4, "width": 5}  # bottom
        self._test_functional_op('crop', fn_kwargs=fn_kwargs)

        fn_kwargs = {"top": 3, "left": 8, "height": 4, "width": 5}  # right
        self._test_functional_op('crop', fn_kwargs=fn_kwargs)

        fn_kwargs = {"top": -3, "left": -3, "height": 15, "width": 15}  # all
        self._test_functional_op('crop', fn_kwargs=fn_kwargs)

        sizes = [5, [5, ], [6, 6]]
        padding_configs = [
            {"padding_mode": "constant", "fill": 0},
            {"padding_mode": "constant", "fill": 10},
            {"padding_mode": "constant", "fill": 20},
            {"padding_mode": "edge"},
            {"padding_mode": "reflect"},
        ]

        for size in sizes:
            for padding_config in padding_configs:
                config = dict(padding_config)
                config["size"] = size
                self._test_class_op("RandomCrop", config)

    def test_center_crop(self):
        fn_kwargs = {"output_size": (4, 5)}
        meth_kwargs = {"size": (4, 5), }
        self._test_op(
            "center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )
        fn_kwargs = {"output_size": (5,)}
        meth_kwargs = {"size": (5, )}
        self._test_op(
            "center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )
        tensor = torch.randint(0, 256, (3, 10, 10), dtype=torch.uint8, device=self.device)
        # Test torchscript of transforms.CenterCrop with size as int
        f = T.CenterCrop(size=5)
        scripted_fn = torch.jit.script(f)
        scripted_fn(tensor)

        # Test torchscript of transforms.CenterCrop with size as [int, ]
        f = T.CenterCrop(size=[5, ])
        scripted_fn = torch.jit.script(f)
        scripted_fn(tensor)

        # Test torchscript of transforms.CenterCrop with size as tuple
        f = T.CenterCrop(size=(6, 6))
        scripted_fn = torch.jit.script(f)
        scripted_fn(tensor)

        with get_tmp_dir() as tmp_dir:
            scripted_fn.save(os.path.join(tmp_dir, "t_center_crop.pt"))

    def _test_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None):
        """Check a transform that returns a list of images (e.g. five/ten crop).

        Compares tensor vs. PIL outputs of ``F.<func>``, eager vs. scripted
        functional outputs, the scripted class ``T.<method>`` output length, and
        per-image vs. batched outputs. Each list must have ``out_length`` items.
        """
        if fn_kwargs is None:
            fn_kwargs = {}
        if meth_kwargs is None:
            meth_kwargs = {}

        fn = getattr(F, func)
        scripted_fn = torch.jit.script(fn)

        tensor, pil_img = _create_data(height=20, width=20, device=self.device)
        transformed_t_list = fn(tensor, **fn_kwargs)
        transformed_p_list = fn(pil_img, **fn_kwargs)
        self.assertEqual(len(transformed_t_list), len(transformed_p_list))
        self.assertEqual(len(transformed_t_list), out_length)
        for transformed_tensor, transformed_pil_img in zip(transformed_t_list, transformed_p_list):
            _assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img)

        transformed_t_list_script = scripted_fn(tensor.detach().clone(), **fn_kwargs)
        self.assertEqual(len(transformed_t_list), len(transformed_t_list_script))
        self.assertEqual(len(transformed_t_list_script), out_length)
        for transformed_tensor, transformed_tensor_script in zip(transformed_t_list, transformed_t_list_script):
            assert_equal(
                transformed_tensor,
                transformed_tensor_script,
                msg="{} vs {}".format(transformed_tensor, transformed_tensor_script),
            )

        # test for class interface
        fn = getattr(T, method)(**meth_kwargs)
        scripted_fn = torch.jit.script(fn)
        output = scripted_fn(tensor)
        self.assertEqual(len(output), len(transformed_t_list_script))

        # test on batch of tensors
        batch_tensors = _create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
        torch.manual_seed(12)
        transformed_batch_list = fn(batch_tensors)

        for i in range(len(batch_tensors)):
            img_tensor = batch_tensors[i, ...]
            torch.manual_seed(12)
            transformed_img_list = fn(img_tensor)
            for transformed_img, transformed_batch in zip(transformed_img_list, transformed_batch_list):
                assert_equal(
                    transformed_img,
                    transformed_batch[i, ...],
                    msg="{} vs {}".format(transformed_img, transformed_batch[i, ...]),
                )

        with get_tmp_dir() as tmp_dir:
            scripted_fn.save(os.path.join(tmp_dir, "t_op_list_{}.pt".format(method)))

    def test_five_crop(self):
        # Same size in four spellings: 1-tuple, 1-list, 2-tuple, 2-list.
        for size in [(5,), [5, ], (4, 5), [4, 5]]:
            fn_kwargs = meth_kwargs = {"size": size}
            self._test_op_list_output(
                "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
            )

    def test_ten_crop(self):
        # Same size in four spellings: 1-tuple, 1-list, 2-tuple, 2-list.
        for size in [(5,), [5, ], (4, 5), [4, 5]]:
            fn_kwargs = meth_kwargs = {"size": size}
            self._test_op_list_output(
                "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
            )

    def test_resize(self):

        # TODO: Minimal check for bug-fix, improve this later
        x = torch.rand(3, 32, 46)
        t = T.Resize(size=38)
        y = t(x)
        # If size is an int, smaller edge of the image will be matched to this number.
        # i.e, if height > width, then image will be rescaled to (size * height / width, size).
        self.assertTrue(isinstance(y, torch.Tensor))
        self.assertEqual(y.shape[1], 38)
        self.assertEqual(y.shape[2], int(38 * 46 / 32))

        tensor, _ = _create_data(height=34, width=36, device=self.device)
        batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        for dt in [None, torch.float32, torch.float64]:
            if dt is not None:
                # This is a trivial cast to float of uint8 data to test all cases
                tensor = tensor.to(dt)
            for size in [32, 34, [32, ], [32, 32], (32, 32), [34, 35]]:
                for max_size in (None, 35, 1000):
                    if max_size is not None and isinstance(size, Sequence) and len(size) != 1:
                        continue  # Not supported
                    for interpolation in [BILINEAR, BICUBIC, NEAREST]:
                        if isinstance(size, int):
                            script_size = [size, ]
                        else:
                            script_size = size

                        transform = T.Resize(size=script_size, interpolation=interpolation, max_size=max_size)
                        s_transform = torch.jit.script(transform)
                        self._test_transform_vs_scripted(transform, s_transform, tensor)
                        self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

        with get_tmp_dir() as tmp_dir:
            s_transform.save(os.path.join(tmp_dir, "t_resize.pt"))

    def test_resized_crop(self):
        tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
        batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        for scale in [(0.7, 1.2), [0.7, 1.2]]:
            for ratio in [(0.75, 1.333), [0.75, 1.333]]:
                for size in [(32, ), [44, ], [32, ], [32, 32], (32, 32), [44, 55]]:
                    for interpolation in [NEAREST, BILINEAR, BICUBIC]:
                        transform = T.RandomResizedCrop(
                            size=size, scale=scale, ratio=ratio, interpolation=interpolation
                        )
                        s_transform = torch.jit.script(transform)
                        self._test_transform_vs_scripted(transform, s_transform, tensor)
                        self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

        with get_tmp_dir() as tmp_dir:
            s_transform.save(os.path.join(tmp_dir, "t_resized_crop.pt"))

    def test_random_affine(self):
        tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
        batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        def _test(**kwargs):
            # Build RandomAffine(**kwargs), compare eager vs scripted, return the scripted module.
            transform = T.RandomAffine(**kwargs)
            s_transform = torch.jit.script(transform)

            self._test_transform_vs_scripted(transform, s_transform, tensor)
            self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

            return s_transform

        for interpolation in [NEAREST, BILINEAR]:
            for shear in [15, 10.0, (5.0, 10.0), [-15, 15], [-10.0, 10.0, -11.0, 11.0]]:
                _test(degrees=0.0, interpolation=interpolation, shear=shear)

            for scale in [(0.7, 1.2), [0.7, 1.2]]:
                _test(degrees=0.0, interpolation=interpolation, scale=scale)

            for translate in [(0.1, 0.2), [0.2, 0.1]]:
                _test(degrees=0.0, interpolation=interpolation, translate=translate)

            for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]:
                _test(degrees=degrees, interpolation=interpolation)

            for fill in [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
                _test(degrees=0.0, interpolation=interpolation, fill=fill)

        s_transform = _test(degrees=0.0)
        with get_tmp_dir() as tmp_dir:
            s_transform.save(os.path.join(tmp_dir, "t_random_affine.pt"))

    def test_random_rotate(self):
        tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
        batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        for center in [(0, 0), [10, 10], None, (56, 44)]:
            for expand in [True, False]:
                for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]:
                    for interpolation in [NEAREST, BILINEAR]:
                        for fill in [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
                            transform = T.RandomRotation(
                                degrees=degrees, interpolation=interpolation, expand=expand, center=center, fill=fill
                            )
                            s_transform = torch.jit.script(transform)

                            self._test_transform_vs_scripted(transform, s_transform, tensor)
                            self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

        with get_tmp_dir() as tmp_dir:
            s_transform.save(os.path.join(tmp_dir, "t_random_rotate.pt"))

    def test_random_perspective(self):
        tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
        batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        for distortion_scale in np.linspace(0.1, 1.0, num=20):
            for interpolation in [NEAREST, BILINEAR]:
                for fill in [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
                    transform = T.RandomPerspective(
                        distortion_scale=distortion_scale,
                        interpolation=interpolation,
                        fill=fill
                    )
                    s_transform = torch.jit.script(transform)

                    self._test_transform_vs_scripted(transform, s_transform, tensor)
                    self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

        with get_tmp_dir() as tmp_dir:
            s_transform.save(os.path.join(tmp_dir, "t_perspective.pt"))

    def test_to_grayscale(self):

        meth_kwargs = {"num_output_channels": 1}
        tol = 1.0 + 1e-10
        self._test_class_op(
            "Grayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
        )

        meth_kwargs = {"num_output_channels": 3}
        self._test_class_op(
            "Grayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
        )

        meth_kwargs = {}
        self._test_class_op(
            "RandomGrayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
        )

    def test_normalize(self):
        fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        tensor, _ = _create_data(26, 34, device=self.device)

        with self.assertRaisesRegex(TypeError, r"Input tensor should be a float tensor"):
            fn(tensor)

        batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)
        tensor = tensor.to(dtype=torch.float32) / 255.0
        # test for class interface
        scripted_fn = torch.jit.script(fn)

        self._test_transform_vs_scripted(fn, scripted_fn, tensor)
        self._test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)

        with get_tmp_dir() as tmp_dir:
            scripted_fn.save(os.path.join(tmp_dir, "t_norm.pt"))

    def test_linear_transformation(self):
        c, h, w = 3, 24, 32

        tensor, _ = _create_data(h, w, channels=c, device=self.device)

        matrix = torch.rand(c * h * w, c * h * w, device=self.device)
        mean_vector = torch.rand(c * h * w, device=self.device)

        fn = T.LinearTransformation(matrix, mean_vector)
        scripted_fn = torch.jit.script(fn)

        self._test_transform_vs_scripted(fn, scripted_fn, tensor)

        batch_tensors = torch.rand(4, c, h, w, device=self.device)
        # We skip some tests from _test_transform_vs_scripted_on_batch as
        # results for scripted and non-scripted transformations are not exactly the same
        torch.manual_seed(12)
        transformed_batch = fn(batch_tensors)
        torch.manual_seed(12)
        s_transformed_batch = scripted_fn(batch_tensors)
        assert_equal(transformed_batch, s_transformed_batch)

        with get_tmp_dir() as tmp_dir:
            # Own file name (was a copy-pasted "t_norm.pt" from test_normalize).
            scripted_fn.save(os.path.join(tmp_dir, "t_linear_transformation.pt"))

    def test_compose(self):
        tensor, _ = _create_data(26, 34, device=self.device)
        tensor = tensor.to(dtype=torch.float32) / 255.0

        transforms = T.Compose([
            T.CenterCrop(10),
            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        s_transforms = torch.nn.Sequential(*transforms.transforms)

        scripted_fn = torch.jit.script(s_transforms)
        torch.manual_seed(12)
        transformed_tensor = transforms(tensor)
        torch.manual_seed(12)
        transformed_tensor_script = scripted_fn(tensor)
        assert_equal(transformed_tensor, transformed_tensor_script, msg="{}".format(transforms))

        t = T.Compose([
            lambda x: x,
        ])
        with self.assertRaisesRegex(RuntimeError, r"Could not get name of python class object"):
            torch.jit.script(t)

    def test_random_apply(self):
        tensor, _ = _create_data(26, 34, device=self.device)
        tensor = tensor.to(dtype=torch.float32) / 255.0

        transforms = T.RandomApply([
            T.RandomHorizontalFlip(),
            T.ColorJitter(),
        ], p=0.4)
        s_transforms = T.RandomApply(torch.nn.ModuleList([
            T.RandomHorizontalFlip(),
            T.ColorJitter(),
        ]), p=0.4)

        scripted_fn = torch.jit.script(s_transforms)
        torch.manual_seed(12)
        transformed_tensor = transforms(tensor)
        torch.manual_seed(12)
        transformed_tensor_script = scripted_fn(tensor)
        assert_equal(transformed_tensor, transformed_tensor_script, msg="{}".format(transforms))

        if torch.device(self.device).type == "cpu":
            # Can't check this twice, otherwise
            # "Can't redefine method: forward on class: __torch__.torchvision.transforms.transforms.RandomApply"
            transforms = T.RandomApply([
                T.ColorJitter(),
            ], p=0.3)
            with self.assertRaisesRegex(RuntimeError, r"Module 'RandomApply' has no attribute 'transforms'"):
                torch.jit.script(transforms)

    def test_gaussian_blur(self):
        tol = 1.0 + 1e-10
        # kernel_size / sigma given as scalars, lists and tuples.
        configs = [
            {"kernel_size": 3, "sigma": 0.75},
            {"kernel_size": 23, "sigma": [0.1, 2.0]},
            {"kernel_size": 23, "sigma": (0.1, 2.0)},
            {"kernel_size": [3, 3], "sigma": (1.0, 1.0)},
            {"kernel_size": (3, 3), "sigma": (0.1, 2.0)},
            {"kernel_size": [23], "sigma": 0.75},
        ]
        for meth_kwargs in configs:
            self._test_class_op(
                "GaussianBlur", meth_kwargs=meth_kwargs,
                test_exact_match=False, agg_method="max", tol=tol
            )

    def test_random_erasing(self):
        img = torch.rand(3, 60, 60)

        # Test Set 0: invalid value
        random_erasing = T.RandomErasing(value=(0.1, 0.2, 0.3, 0.4), p=1.0)
        # assertRaises(..., msg=...) would only set the failure text; match the actual error message instead.
        with self.assertRaisesRegex(ValueError, r"If value is a sequence, it should have either a single value or 3"):
            random_erasing(img)

        tensor, _ = _create_data(24, 32, channels=3, device=self.device)
        batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)

        test_configs = [
            {"value": 0.2},
            {"value": "random"},
            {"value": (0.2, 0.2, 0.2)},
            {"value": "random", "ratio": (0.1, 0.2)},
        ]

        for config in test_configs:
            fn = T.RandomErasing(**config)
            scripted_fn = torch.jit.script(fn)
            self._test_transform_vs_scripted(fn, scripted_fn, tensor)
            self._test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)

        with get_tmp_dir() as tmp_dir:
            scripted_fn.save(os.path.join(tmp_dir, "t_random_erasing.pt"))

    def test_convert_image_dtype(self):
        tensor, _ = _create_data(26, 34, device=self.device)
        batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)

        for in_dtype in int_dtypes() + float_dtypes():
            in_tensor = tensor.to(in_dtype)
            in_batch_tensors = batch_tensors.to(in_dtype)
            for out_dtype in int_dtypes() + float_dtypes():

                fn = T.ConvertImageDtype(dtype=out_dtype)
                scripted_fn = torch.jit.script(fn)

                # These float -> int conversions are rejected by the transform.
                if (in_dtype == torch.float32 and out_dtype in (torch.int32, torch.int64)) or \
                        (in_dtype == torch.float64 and out_dtype == torch.int64):
                    with self.assertRaisesRegex(RuntimeError, r"cannot be performed safely"):
                        self._test_transform_vs_scripted(fn, scripted_fn, in_tensor)
                    with self.assertRaisesRegex(RuntimeError, r"cannot be performed safely"):
                        self._test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)
                    continue

                self._test_transform_vs_scripted(fn, scripted_fn, in_tensor)
                self._test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)

        with get_tmp_dir() as tmp_dir:
            scripted_fn.save(os.path.join(tmp_dir, "t_convert_dtype.pt"))

    def test_autoaugment(self):
        tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
        batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        s_transform = None
        for policy in T.AutoAugmentPolicy:
            for fill in [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
                transform = T.AutoAugment(policy=policy, fill=fill)
                s_transform = torch.jit.script(transform)
                # Repeat so that many randomly chosen sub-policies get exercised.
                for _ in range(25):
                    self._test_transform_vs_scripted(transform, s_transform, tensor)
                    self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

        if s_transform is not None:
            with get_tmp_dir() as tmp_dir:
                s_transform.save(os.path.join(tmp_dir, "t_autoaugment.pt"))

@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
    """Re-runs every ``Tester`` case with input tensors created on ``"cuda"``."""

    def setUp(self):
        # Disable PyTorch's global deterministic-algorithms mode for the CUDA run.
        # ``torch.set_deterministic`` is deprecated (and removed in later releases);
        # ``torch.use_deterministic_algorithms`` is its replacement.
        torch.use_deterministic_algorithms(False)
        self.device = "cuda"


if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()