test_transforms_tensor.py 30.1 KB
Newer Older
1
import os
2
3
4
import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F
5
from torchvision.transforms import InterpolationMode
6
7
8
9

import numpy as np

import unittest
10
from typing import Sequence
11

Nicolas Hug's avatar
Nicolas Hug committed
12
13
14
15
16
17
18
19
20
from common_utils import (
    get_tmp_dir,
    int_dtypes,
    float_dtypes,
    _create_data,
    _create_data_batch,
    _assert_equal_tensor_to_pil,
    _assert_approx_equal_tensor_to_pil,
)
21
from _assert_utils import assert_equal
22
23


24
NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
25
26


27
28
29
30
31
32
def _test_transform_vs_scripted(transform, s_transform, tensor, msg=None):
    """Check that an eager transform and its torchscripted version produce the
    same output on `tensor`, using a fixed RNG seed for both runs."""
    torch.manual_seed(12)
    eager_out = transform(tensor)

    torch.manual_seed(12)
    scripted_out = s_transform(tensor)

    assert_equal(eager_out, scripted_out, msg=msg)
33

34

35
36
37
def _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors, msg=None):
    """Batch consistency checks: each sample transformed alone must match its
    slice of the transformed batch, and the scripted transform must reproduce
    the eager batched result (same fixed seed everywhere)."""
    torch.manual_seed(12)
    batch_out = transform(batch_tensors)

    # Per-sample results must agree with the corresponding batch slice.
    for idx, sample in enumerate(batch_tensors):
        torch.manual_seed(12)
        sample_out = transform(sample)
        assert_equal(sample_out, batch_out[idx, ...], msg=msg)

    # Scripted batched result must agree with the eager batched result.
    torch.manual_seed(12)
    scripted_batch_out = s_transform(batch_tensors)
    assert_equal(batch_out, scripted_batch_out, msg=msg)
48
49


50
51
def _test_functional_op(f, device, fn_kwargs=None, test_exact_match=True, **match_kwargs):
    """Apply functional op `f` to equivalent tensor and PIL inputs and compare
    the two outputs, exactly or approximately depending on `test_exact_match`."""
    if fn_kwargs is None:
        fn_kwargs = {}

    tensor, pil_img = _create_data(height=10, width=10, device=device)
    out_tensor = f(tensor, **fn_kwargs)
    out_pil = f(pil_img, **fn_kwargs)

    # Pick the comparison helper once, then run it with the caller's tolerances.
    compare = _assert_equal_tensor_to_pil if test_exact_match else _assert_approx_equal_tensor_to_pil
    compare(out_tensor, out_pil, **match_kwargs)
vfdev's avatar
vfdev committed
60
61


62
63
64
def _test_class_op(method, device, meth_kwargs=None, test_exact_match=True, **match_kwargs):
    """Exercise a transform *class* (constructed with `meth_kwargs`) on tensor,
    PIL, scripted and batched inputs, and check all results are consistent.

    Despite its name, `method` is the transform class itself (see TODO below);
    `match_kwargs` (e.g. tol/agg_method) are forwarded to the comparison helper.
    """
    # TODO: change the name: it's not a method, it's a class.
    meth_kwargs = meth_kwargs or {}

    # test for class interface
    f = method(**meth_kwargs)
    scripted_fn = torch.jit.script(f)

    tensor, pil_img = _create_data(26, 34, device=device)
    # set seed to reproduce the same transformation for tensor and PIL image
    torch.manual_seed(12)
    transformed_tensor = f(tensor)
    torch.manual_seed(12)
    transformed_pil_img = f(pil_img)
    if test_exact_match:
        _assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
    else:
        _assert_approx_equal_tensor_to_pil(transformed_tensor.float(), transformed_pil_img, **match_kwargs)

    # Scripted transform must reproduce the eager tensor result under the same seed.
    torch.manual_seed(12)
    transformed_tensor_script = scripted_fn(tensor)
    assert_equal(transformed_tensor, transformed_tensor_script)

    # Batched input consistency (per-sample vs batch, eager vs scripted).
    batch_tensors = _create_data_batch(height=23, width=34, channels=3, num_samples=4, device=device)
    _test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors)

    # Scripted module must be serializable to disk.
    with get_tmp_dir() as tmp_dir:
        scripted_fn.save(os.path.join(tmp_dir, f"t_{method.__name__}.pt"))
90

91
92
93
94
95
96
97
98
99
100

def _test_op(func, method, device, fn_kwargs=None, meth_kwargs=None, test_exact_match=True, **match_kwargs):
    """Run both the functional-interface and class-interface checks for one transform."""
    _test_functional_op(func, device, fn_kwargs=fn_kwargs, test_exact_match=test_exact_match, **match_kwargs)
    _test_class_op(method, device, meth_kwargs=meth_kwargs, test_exact_match=test_exact_match, **match_kwargs)


class Tester(unittest.TestCase):
    """Tensor-transform test suite; `CUDATester` re-runs it with device='cuda'."""

    def setUp(self):
        # Default to CPU; overridden by the CUDA subclass.
        self.device = "cpu"
101
102

    def test_random_horizontal_flip(self):
        # Functional hflip and the RandomHorizontalFlip class, tensor vs PIL vs scripted.
        _test_op(F.hflip, T.RandomHorizontalFlip, device=self.device)
104
105

    def test_random_vertical_flip(self):
        # Functional vflip and the RandomVerticalFlip class, tensor vs PIL vs scripted.
        _test_op(F.vflip, T.RandomVerticalFlip, device=self.device)
107

108
    def test_random_invert(self):
        # Functional invert and the RandomInvert class, tensor vs PIL vs scripted.
        _test_op(F.invert, T.RandomInvert, device=self.device)
110
111
112

    def test_random_posterize(self):
        # Posterize to 4 bits via both the functional and class interfaces.
        fn_kwargs = meth_kwargs = {"bits": 4}
        _test_op(
            F.posterize, T.RandomPosterize, device=self.device, fn_kwargs=fn_kwargs,
            meth_kwargs=meth_kwargs
        )

    def test_random_solarize(self):
        # Solarize above threshold 192 via both the functional and class interfaces.
        fn_kwargs = meth_kwargs = {"threshold": 192.0}
        _test_op(
            F.solarize, T.RandomSolarize, device=self.device, fn_kwargs=fn_kwargs,
            meth_kwargs=meth_kwargs
        )

    def test_random_adjust_sharpness(self):
        # Sharpness factor 2.0 via both the functional and class interfaces.
        fn_kwargs = meth_kwargs = {"sharpness_factor": 2.0}
        _test_op(
            F.adjust_sharpness, T.RandomAdjustSharpness, device=self.device, fn_kwargs=fn_kwargs,
            meth_kwargs=meth_kwargs
        )

    def test_random_autocontrast(self):
        # We check the max abs difference because on some (very rare) pixels, the actual value may be different
        # between PIL and tensors due to floating approximations.
        _test_op(
            F.autocontrast, T.RandomAutocontrast, device=self.device, test_exact_match=False,
            agg_method='max', tol=(1 + 1e-5), allowed_percentage_diff=.05
        )
139
140

    def test_random_equalize(self):
        # Functional equalize and the RandomEqualize class, tensor vs PIL vs scripted.
        _test_op(F.equalize, T.RandomEqualize, device=self.device)
142

vfdev's avatar
vfdev committed
143
144
145
    def test_color_jitter(self):
        # Exercise ColorJitter for each parameter separately (scalar, tuple and
        # list ranges), then all four together. Comparisons vs PIL are approximate.
        tol = 1.0 + 1e-10
        for f in [0.1, 0.5, 1.0, 1.34, (0.3, 0.7), [0.4, 0.5]]:
            meth_kwargs = {"brightness": f}
            _test_class_op(
                T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=self.device,
                tol=tol, agg_method="max"
            )

        for f in [0.2, 0.5, 1.0, 1.5, (0.3, 0.7), [0.4, 0.5]]:
            meth_kwargs = {"contrast": f}
            _test_class_op(
                T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=self.device,
                tol=tol, agg_method="max"
            )

        for f in [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]:
            meth_kwargs = {"saturation": f}
            _test_class_op(
                T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=self.device,
                tol=tol, agg_method="max"
            )

        # Hue is compared with a much looser tolerance than the other parameters.
        for f in [0.2, 0.5, (-0.2, 0.3), [-0.4, 0.5]]:
            meth_kwargs = {"hue": f}
            _test_class_op(
                T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=self.device,
                tol=16.1, agg_method="max"
            )

        # All 4 parameters together
        meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2}
        _test_class_op(
            T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=self.device,
            tol=12.1, agg_method="max"
        )

181
    def test_pad(self):
        # Every padding mode, with positive and negative padding (mul = -1 crops),
        # and padding given as int, [int], [int, int] and a 4-tuple.
        for m in ["constant", "edge", "reflect", "symmetric"]:
            fill = 127 if m == "constant" else 0
            for mul in [1, -1]:
                # Test functional.pad (PIL and Tensor) with padding as single int
                _test_functional_op(
                    F.pad, fn_kwargs={"padding": mul * 2, "fill": fill, "padding_mode": m},
                    device=self.device
                )
                # Test functional.pad and transforms.Pad with padding as [int, ]
                fn_kwargs = meth_kwargs = {"padding": [mul * 2, ], "fill": fill, "padding_mode": m}
                _test_op(
                    F.pad, T.Pad, device=self.device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
                )
                # Test functional.pad and transforms.Pad with padding as list
                fn_kwargs = meth_kwargs = {"padding": [mul * 4, 4], "fill": fill, "padding_mode": m}
                _test_op(
                    F.pad, T.Pad, device=self.device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
                )
                # Test functional.pad and transforms.Pad with padding as tuple
                fn_kwargs = meth_kwargs = {"padding": (mul * 2, 2, 2, mul * 2), "fill": fill, "padding_mode": m}
                _test_op(
                    F.pad, T.Pad, device=self.device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
                )
205

vfdev's avatar
vfdev committed
206
207
208
209
    def test_crop(self):
        # Plain crop vs RandomCrop, then crops that extend past every image edge,
        # then RandomCrop with various size specs and padding configurations.
        fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5}
        # Test transforms.RandomCrop with size and padding as tuple
        meth_kwargs = {"size": (4, 5), "padding": (4, 4), "pad_if_needed": True, }
        _test_op(
            F.crop, T.RandomCrop, device=self.device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )

        # Test transforms.functional.crop including outside the image area
        fn_kwargs = {"top": -2, "left": 3, "height": 4, "width": 5}  # top
        _test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=self.device)

        fn_kwargs = {"top": 1, "left": -3, "height": 4, "width": 5}  # left
        _test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=self.device)

        fn_kwargs = {"top": 7, "left": 3, "height": 4, "width": 5}  # bottom
        _test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=self.device)

        fn_kwargs = {"top": 3, "left": 8, "height": 4, "width": 5}  # right
        _test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=self.device)

        fn_kwargs = {"top": -3, "left": -3, "height": 15, "width": 15}  # all
        _test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=self.device)

        sizes = [5, [5, ], [6, 6]]
        padding_configs = [
            {"padding_mode": "constant", "fill": 0},
            {"padding_mode": "constant", "fill": 10},
            {"padding_mode": "constant", "fill": 20},
            {"padding_mode": "edge"},
            {"padding_mode": "reflect"},
        ]

        for size in sizes:
            for padding_config in padding_configs:
                # Copy so the shared config dict is not mutated across iterations.
                config = dict(padding_config)
                config["size"] = size
                _test_class_op(T.RandomCrop, self.device, config)
vfdev's avatar
vfdev committed
244
245
246
247

    def test_center_crop(self):
        # Functional center_crop vs CenterCrop class with 2-tuple and 1-tuple sizes,
        # then scripting CenterCrop with int / [int] / tuple sizes and saving it.
        fn_kwargs = {"output_size": (4, 5)}
        meth_kwargs = {"size": (4, 5), }
        _test_op(
            F.center_crop, T.CenterCrop, device=self.device, fn_kwargs=fn_kwargs,
            meth_kwargs=meth_kwargs
        )
        fn_kwargs = {"output_size": (5,)}
        meth_kwargs = {"size": (5, )}
        _test_op(
            F.center_crop, T.CenterCrop, device=self.device, fn_kwargs=fn_kwargs,
            meth_kwargs=meth_kwargs
        )
        tensor = torch.randint(0, 256, (3, 10, 10), dtype=torch.uint8, device=self.device)
        # Test torchscript of transforms.CenterCrop with size as int
        f = T.CenterCrop(size=5)
        scripted_fn = torch.jit.script(f)
        scripted_fn(tensor)

        # Test torchscript of transforms.CenterCrop with size as [int, ]
        f = T.CenterCrop(size=[5, ])
        scripted_fn = torch.jit.script(f)
        scripted_fn(tensor)

        # Test torchscript of transforms.CenterCrop with size as tuple
        f = T.CenterCrop(size=(6, 6))
        scripted_fn = torch.jit.script(f)
        scripted_fn(tensor)

        with get_tmp_dir() as tmp_dir:
            scripted_fn.save(os.path.join(tmp_dir, "t_center_crop.pt"))

277
    def _test_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None):
        """Shared driver for ops that return a *list* of images (five_crop/ten_crop).

        `func`/`method` are attribute names looked up on F and T respectively;
        `out_length` is the expected number of returned crops. Checks tensor vs
        PIL, eager vs scripted, per-sample vs batch, and scripted serialization.
        """
        if fn_kwargs is None:
            fn_kwargs = {}
        if meth_kwargs is None:
            meth_kwargs = {}

        fn = getattr(F, func)
        scripted_fn = torch.jit.script(fn)

        tensor, pil_img = _create_data(height=20, width=20, device=self.device)
        transformed_t_list = fn(tensor, **fn_kwargs)
        transformed_p_list = fn(pil_img, **fn_kwargs)
        # Tensor and PIL paths must return the same number of crops, all equal.
        self.assertEqual(len(transformed_t_list), len(transformed_p_list))
        self.assertEqual(len(transformed_t_list), out_length)
        for transformed_tensor, transformed_pil_img in zip(transformed_t_list, transformed_p_list):
            _assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img)

        # Scripted functional op must match the eager results element-wise.
        transformed_t_list_script = scripted_fn(tensor.detach().clone(), **fn_kwargs)
        self.assertEqual(len(transformed_t_list), len(transformed_t_list_script))
        self.assertEqual(len(transformed_t_list_script), out_length)
        for transformed_tensor, transformed_tensor_script in zip(transformed_t_list, transformed_t_list_script):
            assert_equal(
                transformed_tensor,
                transformed_tensor_script,
                msg="{} vs {}".format(transformed_tensor, transformed_tensor_script),
            )

        # test for class interface
        fn = getattr(T, method)(**meth_kwargs)
        scripted_fn = torch.jit.script(fn)
        output = scripted_fn(tensor)
        self.assertEqual(len(output), len(transformed_t_list_script))

        # test on batch of tensors
        batch_tensors = _create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
        torch.manual_seed(12)
        transformed_batch_list = fn(batch_tensors)

        for i in range(len(batch_tensors)):
            img_tensor = batch_tensors[i, ...]
            torch.manual_seed(12)
            transformed_img_list = fn(img_tensor)
            for transformed_img, transformed_batch in zip(transformed_img_list, transformed_batch_list):
                assert_equal(
                    transformed_img,
                    transformed_batch[i, ...],
                    msg="{} vs {}".format(transformed_img, transformed_batch[i, ...]),
                )

        with get_tmp_dir() as tmp_dir:
            scripted_fn.save(os.path.join(tmp_dir, "t_op_list_{}.pt".format(method)))

vfdev's avatar
vfdev committed
329
330
    def test_five_crop(self):
        """Run the list-output driver for five_crop/FiveCrop over all size specs."""
        # Same check for every accepted `size` spec: 1-element tuple/list and
        # 2-element tuple/list.
        for size in [(5,), [5, ], (4, 5), [4, 5]]:
            kwargs = {"size": size}
            self._test_op_list_output(
                "five_crop", "FiveCrop", out_length=5, fn_kwargs=kwargs, meth_kwargs=kwargs
            )

    def test_ten_crop(self):
        """Run the list-output driver for ten_crop/TenCrop over all size specs."""
        # Same check for every accepted `size` spec: 1-element tuple/list and
        # 2-element tuple/list.
        for size in [(5,), [5, ], (4, 5), [4, 5]]:
            kwargs = {"size": size}
            self._test_op_list_output(
                "ten_crop", "TenCrop", out_length=10, fn_kwargs=kwargs, meth_kwargs=kwargs
            )

vfdev's avatar
vfdev committed
365
    def test_resize(self):
        # Minimal shape check for int-size resize, then eager-vs-scripted sweeps
        # over dtype, size spec, max_size and interpolation mode.

        # TODO: Minimal check for bug-fix, improve this later
        x = torch.rand(3, 32, 46)
        t = T.Resize(size=38)
        y = t(x)
        # If size is an int, smaller edge of the image will be matched to this number.
        # i.e, if height > width, then image will be rescaled to (size * height / width, size).
        self.assertTrue(isinstance(y, torch.Tensor))
        self.assertEqual(y.shape[1], 38)
        self.assertEqual(y.shape[2], int(38 * 46 / 32))

        tensor, _ = _create_data(height=34, width=36, device=self.device)
        batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        for dt in [None, torch.float32, torch.float64]:
            if dt is not None:
                # This is a trivial cast to float of uint8 data to test all cases
                tensor = tensor.to(dt)
            for size in [32, 34, [32, ], [32, 32], (32, 32), [34, 35]]:
                for max_size in (None, 35, 1000):
                    if max_size is not None and isinstance(size, Sequence) and len(size) != 1:
                        continue  # Not supported
                    for interpolation in [BILINEAR, BICUBIC, NEAREST]:
                        # Torchscript requires size as a list, so normalize ints.
                        if isinstance(size, int):
                            script_size = [size, ]
                        else:
                            script_size = size

                        transform = T.Resize(size=script_size, interpolation=interpolation, max_size=max_size)
                        s_transform = torch.jit.script(transform)
                        _test_transform_vs_scripted(transform, s_transform, tensor)
                        _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

        with get_tmp_dir() as tmp_dir:
            s_transform.save(os.path.join(tmp_dir, "t_resize.pt"))
402

403
    def test_resized_crop(self):
        # RandomResizedCrop eager vs scripted over scale/ratio/size/interpolation combos.
        tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
        batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        for scale in [(0.7, 1.2), [0.7, 1.2]]:
            for ratio in [(0.75, 1.333), [0.75, 1.333]]:
                for size in [(32, ), [44, ], [32, ], [32, 32], (32, 32), [44, 55]]:
                    for interpolation in [NEAREST, BILINEAR, BICUBIC]:
                        transform = T.RandomResizedCrop(
                            size=size, scale=scale, ratio=ratio, interpolation=interpolation
                        )
                        s_transform = torch.jit.script(transform)
                        _test_transform_vs_scripted(transform, s_transform, tensor)
                        _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

        with get_tmp_dir() as tmp_dir:
            s_transform.save(os.path.join(tmp_dir, "t_resized_crop.pt"))

421
    def test_random_affine(self):
        # RandomAffine eager vs scripted for shear/scale/translate/degrees/fill variants.
        tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
        batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        def _test(**kwargs):
            # Build the transform, compare eager vs scripted (single and batched),
            # and return the scripted module so the caller can save it.
            transform = T.RandomAffine(**kwargs)
            s_transform = torch.jit.script(transform)

            _test_transform_vs_scripted(transform, s_transform, tensor)
            _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

            return s_transform

        for interpolation in [NEAREST, BILINEAR]:
            for shear in [15, 10.0, (5.0, 10.0), [-15, 15], [-10.0, 10.0, -11.0, 11.0]]:
                _test(degrees=0.0, interpolation=interpolation, shear=shear)

            for scale in [(0.7, 1.2), [0.7, 1.2]]:
                _test(degrees=0.0, interpolation=interpolation, scale=scale)

            for translate in [(0.1, 0.2), [0.2, 0.1]]:
                _test(degrees=0.0, interpolation=interpolation, translate=translate)

            for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]:
                _test(degrees=degrees, interpolation=interpolation)

            for fill in [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
                _test(degrees=0.0, interpolation=interpolation, fill=fill)

        s_transform = _test(degrees=0.0)
        with get_tmp_dir() as tmp_dir:
            s_transform.save(os.path.join(tmp_dir, "t_random_affine.pt"))

454
    def test_random_rotate(self):
        # RandomRotation eager vs scripted for center/expand/degrees/interpolation/fill combos.
        tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
        batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        for center in [(0, 0), [10, 10], None, (56, 44)]:
            for expand in [True, False]:
                for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]:
                    for interpolation in [NEAREST, BILINEAR]:
                        for fill in [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
                            transform = T.RandomRotation(
                                degrees=degrees, interpolation=interpolation, expand=expand, center=center, fill=fill
                            )
                            s_transform = torch.jit.script(transform)

                            _test_transform_vs_scripted(transform, s_transform, tensor)
                            _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

        with get_tmp_dir() as tmp_dir:
            s_transform.save(os.path.join(tmp_dir, "t_random_rotate.pt"))

474
    def test_random_perspective(self):
        # RandomPerspective eager vs scripted over distortion scales, interpolations and fills.
        tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
        batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        for distortion_scale in np.linspace(0.1, 1.0, num=20):
            for interpolation in [NEAREST, BILINEAR]:
                for fill in [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
                    transform = T.RandomPerspective(
                        distortion_scale=distortion_scale,
                        interpolation=interpolation,
                        fill=fill
                    )
                    s_transform = torch.jit.script(transform)

                    _test_transform_vs_scripted(transform, s_transform, tensor)
                    _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

        with get_tmp_dir() as tmp_dir:
            s_transform.save(os.path.join(tmp_dir, "t_perspective.pt"))

494
495
496
497
    def test_to_grayscale(self):
        # Grayscale with 1 and 3 output channels, then RandomGrayscale with defaults.
        # Comparisons vs PIL are approximate (max abs diff within tol).
        meth_kwargs = {"num_output_channels": 1}
        tol = 1.0 + 1e-10
        _test_class_op(
            T.Grayscale, meth_kwargs=meth_kwargs, test_exact_match=False, device=self.device,
            tol=tol, agg_method="max"
        )

        meth_kwargs = {"num_output_channels": 3}
        _test_class_op(
            T.Grayscale, meth_kwargs=meth_kwargs, test_exact_match=False, device=self.device,
            tol=tol, agg_method="max"
        )

        meth_kwargs = {}
        _test_class_op(
            T.RandomGrayscale, meth_kwargs=meth_kwargs, test_exact_match=False, device=self.device,
            tol=tol, agg_method="max"
        )

515
    def test_normalize(self):
        # Normalize must reject non-float input, then match its scripted version
        # on single and batched float tensors.
        fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        tensor, _ = _create_data(26, 34, device=self.device)

        # uint8 input is rejected.
        with self.assertRaisesRegex(TypeError, r"Input tensor should be a float tensor"):
            fn(tensor)

        batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)
        tensor = tensor.to(dtype=torch.float32) / 255.0
        # test for class interface
        scripted_fn = torch.jit.script(fn)

        _test_transform_vs_scripted(fn, scripted_fn, tensor)
        _test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)

        with get_tmp_dir() as tmp_dir:
            scripted_fn.save(os.path.join(tmp_dir, "t_norm.pt"))

533
534
535
    def test_linear_transformation(self):
        """LinearTransformation: eager and scripted outputs must match on single
        images and on batches, and the scripted module must be saveable."""
        c, h, w = 3, 24, 32

        tensor, _ = _create_data(h, w, channels=c, device=self.device)

        matrix = torch.rand(c * h * w, c * h * w, device=self.device)
        mean_vector = torch.rand(c * h * w, device=self.device)

        fn = T.LinearTransformation(matrix, mean_vector)
        scripted_fn = torch.jit.script(fn)

        _test_transform_vs_scripted(fn, scripted_fn, tensor)

        batch_tensors = torch.rand(4, c, h, w, device=self.device)
        # We skip some tests from _test_transform_vs_scripted_on_batch as
        # results for scripted and non-scripted transformations are not exactly the same
        torch.manual_seed(12)
        transformed_batch = fn(batch_tensors)
        torch.manual_seed(12)
        s_transformed_batch = scripted_fn(batch_tensors)
        assert_equal(transformed_batch, s_transformed_batch)

        with get_tmp_dir() as tmp_dir:
            # Fix: was saved as "t_norm.pt", colliding with test_normalize's artifact.
            scripted_fn.save(os.path.join(tmp_dir, "t_linear_transformation.pt"))

558
    def test_compose(self):
        # Compose itself is not scriptable; an equivalent nn.Sequential of the same
        # transforms is scripted and must match. Scripting a Compose holding a
        # lambda must raise.
        tensor, _ = _create_data(26, 34, device=self.device)
        tensor = tensor.to(dtype=torch.float32) / 255.0

        transforms = T.Compose([
            T.CenterCrop(10),
            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        s_transforms = torch.nn.Sequential(*transforms.transforms)

        scripted_fn = torch.jit.script(s_transforms)
        torch.manual_seed(12)
        transformed_tensor = transforms(tensor)
        torch.manual_seed(12)
        transformed_tensor_script = scripted_fn(tensor)
        assert_equal(transformed_tensor, transformed_tensor_script, msg="{}".format(transforms))

        t = T.Compose([
            lambda x: x,
        ])
        with self.assertRaisesRegex(RuntimeError, r"Could not get name of python class object"):
            torch.jit.script(t)

581
    def test_random_apply(self):
        # RandomApply built from an nn.ModuleList is scriptable and must match the
        # plain-list version under the same seed; the plain-list version itself
        # must fail to script.
        tensor, _ = _create_data(26, 34, device=self.device)
        tensor = tensor.to(dtype=torch.float32) / 255.0

        transforms = T.RandomApply([
            T.RandomHorizontalFlip(),
            T.ColorJitter(),
        ], p=0.4)
        s_transforms = T.RandomApply(torch.nn.ModuleList([
            T.RandomHorizontalFlip(),
            T.ColorJitter(),
        ]), p=0.4)

        scripted_fn = torch.jit.script(s_transforms)
        torch.manual_seed(12)
        transformed_tensor = transforms(tensor)
        torch.manual_seed(12)
        transformed_tensor_script = scripted_fn(tensor)
        assert_equal(transformed_tensor, transformed_tensor_script, msg="{}".format(transforms))

        if torch.device(self.device).type == "cpu":
            # Can't check this twice, otherwise
            # "Can't redefine method: forward on class: __torch__.torchvision.transforms.transforms.RandomApply"
            transforms = T.RandomApply([
                T.ColorJitter(),
            ], p=0.3)
            with self.assertRaisesRegex(RuntimeError, r"Module 'RandomApply' has no attribute 'transforms'"):
                torch.jit.script(transforms)

610
611
    def test_gaussian_blur(self):
        """GaussianBlur class vs PIL/scripted for every kernel_size/sigma spec."""
        tol = 1.0 + 1e-10
        # Each entry exercises one accepted spelling of kernel_size (int, list,
        # tuple) and sigma (float, list, tuple).
        blur_configs = [
            {"kernel_size": 3, "sigma": 0.75},
            {"kernel_size": 23, "sigma": [0.1, 2.0]},
            {"kernel_size": 23, "sigma": (0.1, 2.0)},
            {"kernel_size": [3, 3], "sigma": (1.0, 1.0)},
            {"kernel_size": (3, 3), "sigma": (0.1, 2.0)},
            {"kernel_size": [23], "sigma": 0.75},
        ]
        for meth_kwargs in blur_configs:
            _test_class_op(
                T.GaussianBlur, meth_kwargs=meth_kwargs,
                test_exact_match=False, device=self.device, agg_method="max", tol=tol
            )

vfdev's avatar
vfdev committed
642
643
644
645
646
647
648
649
    def test_random_erasing(self):
        # Invalid 4-element value must raise; then eager vs scripted for each
        # valid value/ratio configuration, and the scripted module is saveable.
        img = torch.rand(3, 60, 60)

        # Test Set 0: invalid value
        random_erasing = T.RandomErasing(value=(0.1, 0.2, 0.3, 0.4), p=1.0)
        with self.assertRaises(ValueError, msg="If value is a sequence, it should have either a single value or 3"):
            random_erasing(img)

        tensor, _ = _create_data(24, 32, channels=3, device=self.device)
        batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)

        test_configs = [
            {"value": 0.2},
            {"value": "random"},
            {"value": (0.2, 0.2, 0.2)},
            {"value": "random", "ratio": (0.1, 0.2)},
        ]

        for config in test_configs:
            fn = T.RandomErasing(**config)
            scripted_fn = torch.jit.script(fn)
            _test_transform_vs_scripted(fn, scripted_fn, tensor)
            _test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)

        with get_tmp_dir() as tmp_dir:
            scripted_fn.save(os.path.join(tmp_dir, "t_random_erasing.pt"))

669
    def test_convert_image_dtype(self):
        # ConvertImageDtype over every (in_dtype, out_dtype) pair; conversions that
        # would overflow must raise "cannot be performed safely".
        tensor, _ = _create_data(26, 34, device=self.device)
        batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)

        for in_dtype in int_dtypes() + float_dtypes():
            in_tensor = tensor.to(in_dtype)
            in_batch_tensors = batch_tensors.to(in_dtype)
            for out_dtype in int_dtypes() + float_dtypes():

                fn = T.ConvertImageDtype(dtype=out_dtype)
                scripted_fn = torch.jit.script(fn)

                # These float->int conversions cannot be represented safely and must raise.
                if (in_dtype == torch.float32 and out_dtype in (torch.int32, torch.int64)) or \
                        (in_dtype == torch.float64 and out_dtype == torch.int64):
                    with self.assertRaisesRegex(RuntimeError, r"cannot be performed safely"):
                        _test_transform_vs_scripted(fn, scripted_fn, in_tensor)
                    with self.assertRaisesRegex(RuntimeError, r"cannot be performed safely"):
                        _test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)
                    continue

                _test_transform_vs_scripted(fn, scripted_fn, in_tensor)
                _test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)

        with get_tmp_dir() as tmp_dir:
            scripted_fn.save(os.path.join(tmp_dir, "t_convert_dtype.pt"))

695
696
697
698
    def test_autoaugment(self):
        # AutoAugment eager vs scripted for every policy/fill pair; repeated 25
        # times per pair since the transform picks sub-policies randomly.
        tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
        batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        s_transform = None
        for policy in T.AutoAugmentPolicy:
            for fill in [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
                transform = T.AutoAugment(policy=policy, fill=fill)
                s_transform = torch.jit.script(transform)
                for _ in range(25):
                    _test_transform_vs_scripted(transform, s_transform, tensor)
                    _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

        if s_transform is not None:
            with get_tmp_dir() as tmp_dir:
                s_transform.save(os.path.join(tmp_dir, "t_autoaugment.pt"))
711

712

713
714
715
716
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
    """Re-runs every Tester case with device='cuda'."""

    def setUp(self):
        # torch.set_deterministic was deprecated (and later removed) in favor of
        # torch.use_deterministic_algorithms; prefer the new API when available
        # so this suite keeps working across torch versions.
        if hasattr(torch, "use_deterministic_algorithms"):
            torch.use_deterministic_algorithms(False)
        else:
            torch.set_deterministic(False)
        self.device = "cuda"


721
722
if __name__ == '__main__':
    # Allow running this test module directly.
    unittest.main()