test_transforms_tensor.py 30.2 KB
Newer Older
1
import os
2
3
4
import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F
5
from torchvision.transforms import InterpolationMode
6
7
8
9

import numpy as np

import unittest
10
import pytest
11
from typing import Sequence
12

Nicolas Hug's avatar
Nicolas Hug committed
13
14
15
16
17
18
19
20
from common_utils import (
    get_tmp_dir,
    int_dtypes,
    float_dtypes,
    _create_data,
    _create_data_batch,
    _assert_equal_tensor_to_pil,
    _assert_approx_equal_tensor_to_pil,
21
    cpu_and_gpu
Nicolas Hug's avatar
Nicolas Hug committed
22
)
23
from _assert_utils import assert_equal
24
25


26
# Short aliases for the interpolation modes iterated over by the resize, affine,
# rotation and perspective tests in this module.
NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
27
28


29
30
31
32
33
34
def _test_transform_vs_scripted(transform, s_transform, tensor, msg=None):
    """Assert that ``transform`` and ``s_transform`` give equal outputs on ``tensor``.

    The torch RNG is re-seeded with the same value right before each call so
    that both callables start from the same random state.

    Args:
        transform: the eager callable.
        s_transform: the callable compared against it (callers pass the scripted version).
        tensor: input handed unchanged to both callables.
        msg: optional message forwarded to ``assert_equal``.
    """
    outputs = []
    for fn in (transform, s_transform):
        torch.manual_seed(12)
        outputs.append(fn(tensor))
    eager_out, scripted_out = outputs
    assert_equal(eager_out, scripted_out, msg=msg)
35

36

37
38
39
def _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors, msg=None):
    """Check a transform on a batch against per-sample and scripted results.

    Two properties are asserted:

    * applying ``transform`` to the whole batch equals applying it to each
      sample separately (with the RNG re-seeded identically before every call);
    * ``s_transform`` applied to the batch equals ``transform`` applied to the batch.

    Args:
        transform: the eager callable.
        s_transform: the callable compared against it (callers pass the scripted version).
        batch_tensors: batched input; it is iterated and indexed along its first dimension.
        msg: optional message forwarded to ``assert_equal``.
    """
    torch.manual_seed(12)
    transformed_batch = transform(batch_tensors)

    for i, img_tensor in enumerate(batch_tensors):
        # Same seed as for the batch call so random parameters are drawn identically.
        torch.manual_seed(12)
        transformed_img = transform(img_tensor)
        assert_equal(transformed_img, transformed_batch[i, ...], msg=msg)

    torch.manual_seed(12)
    s_transformed_batch = s_transform(batch_tensors)
    assert_equal(transformed_batch, s_transformed_batch, msg=msg)
50
51


52
53
def _test_functional_op(f, device, fn_kwargs=None, test_exact_match=True, **match_kwargs):
    """Apply ``f`` to a tensor and to its PIL counterpart and compare the results.

    Args:
        f: callable that is applied to both the tensor and the PIL image.
        device: device passed to ``_create_data``.
        fn_kwargs: keyword arguments for ``f``; ``None`` means no extra arguments.
        test_exact_match: if true use ``_assert_equal_tensor_to_pil``, otherwise
            ``_assert_approx_equal_tensor_to_pil``.
        **match_kwargs: forwarded to the selected comparison helper.
    """
    fn_kwargs = fn_kwargs or {}

    tensor, pil_img = _create_data(height=10, width=10, device=device)
    transformed_tensor = f(tensor, **fn_kwargs)
    transformed_pil_img = f(pil_img, **fn_kwargs)
    if test_exact_match:
        _assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
    else:
        _assert_approx_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
vfdev's avatar
vfdev committed
62
63


64
65
66
def _test_class_op(method, device, meth_kwargs=None, test_exact_match=True, **match_kwargs):
    """Exercise a transform class: tensor vs PIL, eager vs scripted, batch, and save.

    Args:
        method: the transform class to instantiate (a class, despite the name).
        device: device passed to the data-creation helpers.
        meth_kwargs: constructor keyword arguments; ``None`` means no arguments.
        test_exact_match: if true use ``_assert_equal_tensor_to_pil``, otherwise
            ``_assert_approx_equal_tensor_to_pil`` (on the tensor cast to float).
        **match_kwargs: forwarded to the selected comparison helper.
    """
    # TODO: change the name: it's not a method, it's a class.
    meth_kwargs = meth_kwargs or {}

    # test for class interface
    f = method(**meth_kwargs)
    scripted_fn = torch.jit.script(f)

    tensor, pil_img = _create_data(26, 34, device=device)
    # set seed to reproduce the same transformation for tensor and PIL image
    torch.manual_seed(12)
    transformed_tensor = f(tensor)
    torch.manual_seed(12)
    transformed_pil_img = f(pil_img)
    if test_exact_match:
        _assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
    else:
        _assert_approx_equal_tensor_to_pil(transformed_tensor.float(), transformed_pil_img, **match_kwargs)

    # The scripted transform must reproduce the eager result from the same seed.
    torch.manual_seed(12)
    transformed_tensor_script = scripted_fn(tensor)
    assert_equal(transformed_tensor, transformed_tensor_script)

    batch_tensors = _create_data_batch(height=23, width=34, channels=3, num_samples=4, device=device)
    _test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors)

    # Saving checks that the scripted module can be serialized.
    with get_tmp_dir() as tmp_dir:
        scripted_fn.save(os.path.join(tmp_dir, f"t_{method.__name__}.pt"))
92

93
94
95
96
97
98
99
100
101
102

def _test_op(func, method, device, fn_kwargs=None, meth_kwargs=None, test_exact_match=True, **match_kwargs):
    """Run the functional check for ``func`` followed by the class check for ``method``.

    ``test_exact_match`` and ``match_kwargs`` are passed on unchanged to both
    ``_test_functional_op`` and ``_test_class_op``.
    """
    shared = dict(match_kwargs)
    shared["test_exact_match"] = test_exact_match
    _test_functional_op(func, device, fn_kwargs, **shared)
    _test_class_op(method, device, meth_kwargs, **shared)


class Tester(unittest.TestCase):

    def setUp(self):
        self.device = "cpu"
103
104

    def test_random_horizontal_flip(self):
        """Run the shared functional/class checks for ``F.hflip`` and ``T.RandomHorizontalFlip``."""
        _test_op(F.hflip, T.RandomHorizontalFlip, device=self.device)
106
107

    def test_random_vertical_flip(self):
        """Run the shared functional/class checks for ``F.vflip`` and ``T.RandomVerticalFlip``."""
        _test_op(F.vflip, T.RandomVerticalFlip, device=self.device)
109

110
    def test_random_invert(self):
        """Run the shared functional/class checks for ``F.invert`` and ``T.RandomInvert``."""
        _test_op(F.invert, T.RandomInvert, device=self.device)
112
113
114

    def test_random_posterize(self):
        """Run the shared checks for ``F.posterize`` / ``T.RandomPosterize`` with ``bits=4``."""
        # The same kwargs are used for the functional op and the class constructor.
        fn_kwargs = meth_kwargs = {"bits": 4}
        _test_op(
            F.posterize, T.RandomPosterize, device=self.device, fn_kwargs=fn_kwargs,
            meth_kwargs=meth_kwargs
        )

    def test_random_solarize(self):
        """Run the shared checks for ``F.solarize`` / ``T.RandomSolarize`` with ``threshold=192.0``."""
        # The same kwargs are used for the functional op and the class constructor.
        fn_kwargs = meth_kwargs = {"threshold": 192.0}
        _test_op(
            F.solarize, T.RandomSolarize, device=self.device, fn_kwargs=fn_kwargs,
            meth_kwargs=meth_kwargs
        )

    def test_random_adjust_sharpness(self):
        """Run the shared checks for ``F.adjust_sharpness`` / ``T.RandomAdjustSharpness``."""
        # The same kwargs are used for the functional op and the class constructor.
        fn_kwargs = meth_kwargs = {"sharpness_factor": 2.0}
        _test_op(
            F.adjust_sharpness, T.RandomAdjustSharpness, device=self.device, fn_kwargs=fn_kwargs,
            meth_kwargs=meth_kwargs
        )

    def test_random_autocontrast(self):
        """Run the shared checks for ``F.autocontrast`` / ``T.RandomAutocontrast`` with a tolerance."""
        # We check the max abs difference because on some (very rare) pixels, the actual value may be different
        # between PIL and tensors due to floating approximations.
        _test_op(
            F.autocontrast, T.RandomAutocontrast, device=self.device, test_exact_match=False,
            agg_method='max', tol=(1 + 1e-5), allowed_percentage_diff=.05
        )
141
142

    def test_random_equalize(self):
        """Run the shared functional/class checks for ``F.equalize`` and ``T.RandomEqualize``."""
        _test_op(F.equalize, T.RandomEqualize, device=self.device)
144

145
    def _test_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None):
        """Check an op whose result is a sequence of images.

        The functional op is compared tensor vs PIL and eager vs scripted; the
        transform class is scripted, run on a single image and on a batch, and saved.

        Args:
            func: name of the function looked up on ``F``.
            method: name of the transform class looked up on ``T``.
            out_length: expected number of returned images.
            fn_kwargs: keyword arguments for the functional op (``None`` means none).
            meth_kwargs: keyword arguments for the class constructor (``None`` means none).
        """
        if fn_kwargs is None:
            fn_kwargs = {}
        if meth_kwargs is None:
            meth_kwargs = {}

        fn = getattr(F, func)
        scripted_fn = torch.jit.script(fn)

        tensor, pil_img = _create_data(height=20, width=20, device=self.device)
        transformed_t_list = fn(tensor, **fn_kwargs)
        transformed_p_list = fn(pil_img, **fn_kwargs)
        self.assertEqual(len(transformed_t_list), len(transformed_p_list))
        self.assertEqual(len(transformed_t_list), out_length)
        for transformed_tensor, transformed_pil_img in zip(transformed_t_list, transformed_p_list):
            _assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img)

        # The scripted functional op gets its own copy of the input tensor.
        transformed_t_list_script = scripted_fn(tensor.detach().clone(), **fn_kwargs)
        self.assertEqual(len(transformed_t_list), len(transformed_t_list_script))
        self.assertEqual(len(transformed_t_list_script), out_length)
        for transformed_tensor, transformed_tensor_script in zip(transformed_t_list, transformed_t_list_script):
            assert_equal(
                transformed_tensor,
                transformed_tensor_script,
                msg="{} vs {}".format(transformed_tensor, transformed_tensor_script),
            )

        # test for class interface
        fn = getattr(T, method)(**meth_kwargs)
        scripted_fn = torch.jit.script(fn)
        output = scripted_fn(tensor)
        self.assertEqual(len(output), len(transformed_t_list_script))

        # test on batch of tensors
        batch_tensors = _create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
        torch.manual_seed(12)
        transformed_batch_list = fn(batch_tensors)

        for i in range(len(batch_tensors)):
            img_tensor = batch_tensors[i, ...]
            torch.manual_seed(12)
            transformed_img_list = fn(img_tensor)
            # Each output of the batched call must contain, at index i, the output for sample i.
            for transformed_img, transformed_batch in zip(transformed_img_list, transformed_batch_list):
                assert_equal(
                    transformed_img,
                    transformed_batch[i, ...],
                    msg="{} vs {}".format(transformed_img, transformed_batch[i, ...]),
                )

        with get_tmp_dir() as tmp_dir:
            scripted_fn.save(os.path.join(tmp_dir, "t_op_list_{}.pt".format(method)))

vfdev's avatar
vfdev committed
197
198
    def test_five_crop(self):
        """Check ``five_crop`` / ``FiveCrop`` for tuple and list sizes of length one and two."""
        for size in [(5,), [5, ], (4, 5), [4, 5]]:
            # The same kwargs are used for the functional op and the class constructor.
            fn_kwargs = meth_kwargs = {"size": size}
            self._test_op_list_output(
                "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
            )

    def test_ten_crop(self):
        """Check ``ten_crop`` / ``TenCrop`` for tuple and list sizes of length one and two."""
        for size in [(5,), [5, ], (4, 5), [4, 5]]:
            # The same kwargs are used for the functional op and the class constructor.
            fn_kwargs = meth_kwargs = {"size": size}
            self._test_op_list_output(
                "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
            )

vfdev's avatar
vfdev committed
233
    def test_resize(self):
        """Check ``T.Resize`` output shape for an int size and eager-vs-scripted agreement.

        The second part sweeps input dtype, ``size``, ``max_size`` and
        interpolation mode, then saves the last scripted transform.
        """
        # TODO: Minimal check for bug-fix, improve this later
        x = torch.rand(3, 32, 46)
        t = T.Resize(size=38)
        y = t(x)
        # If size is an int, smaller edge of the image will be matched to this number.
        # i.e, if height > width, then image will be rescaled to (size * height / width, size).
        self.assertIsInstance(y, torch.Tensor)
        self.assertEqual(y.shape[1], 38)
        self.assertEqual(y.shape[2], int(38 * 46 / 32))

        tensor, _ = _create_data(height=34, width=36, device=self.device)
        batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        for dt in [None, torch.float32, torch.float64]:
            if dt is not None:
                # This is a trivial cast to float of uint8 data to test all cases
                tensor = tensor.to(dt)
            for size in [32, 34, [32, ], [32, 32], (32, 32), [34, 35]]:
                for max_size in (None, 35, 1000):
                    if max_size is not None and isinstance(size, Sequence) and len(size) != 1:
                        continue  # Not supported
                    for interpolation in [BILINEAR, BICUBIC, NEAREST]:
                        # An int size is wrapped in a one-element list before building the transform.
                        if isinstance(size, int):
                            script_size = [size, ]
                        else:
                            script_size = size

                        transform = T.Resize(size=script_size, interpolation=interpolation, max_size=max_size)
                        s_transform = torch.jit.script(transform)
                        _test_transform_vs_scripted(transform, s_transform, tensor)
                        _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

        with get_tmp_dir() as tmp_dir:
            s_transform.save(os.path.join(tmp_dir, "t_resize.pt"))
270

271
    def test_resized_crop(self):
        """Compare eager and scripted ``T.RandomResizedCrop`` over scale, ratio, size and interpolation."""
        tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
        batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        for scale in [(0.7, 1.2), [0.7, 1.2]]:
            for ratio in [(0.75, 1.333), [0.75, 1.333]]:
                for size in [(32, ), [44, ], [32, ], [32, 32], (32, 32), [44, 55]]:
                    for interpolation in [NEAREST, BILINEAR, BICUBIC]:
                        transform = T.RandomResizedCrop(
                            size=size, scale=scale, ratio=ratio, interpolation=interpolation
                        )
                        s_transform = torch.jit.script(transform)
                        _test_transform_vs_scripted(transform, s_transform, tensor)
                        _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

        # Only the scripted transform from the last loop iteration is saved.
        with get_tmp_dir() as tmp_dir:
            s_transform.save(os.path.join(tmp_dir, "t_resized_crop.pt"))

289
    def test_random_affine(self):
        """Compare eager and scripted ``T.RandomAffine`` over shear, scale, translate, degrees and fill."""
        tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
        batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        def _test(**kwargs):
            # Build, script and compare one RandomAffine configuration; return the scripted module.
            transform = T.RandomAffine(**kwargs)
            s_transform = torch.jit.script(transform)

            _test_transform_vs_scripted(transform, s_transform, tensor)
            _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

            return s_transform

        for interpolation in [NEAREST, BILINEAR]:
            for shear in [15, 10.0, (5.0, 10.0), [-15, 15], [-10.0, 10.0, -11.0, 11.0]]:
                _test(degrees=0.0, interpolation=interpolation, shear=shear)

            for scale in [(0.7, 1.2), [0.7, 1.2]]:
                _test(degrees=0.0, interpolation=interpolation, scale=scale)

            for translate in [(0.1, 0.2), [0.2, 0.1]]:
                _test(degrees=0.0, interpolation=interpolation, translate=translate)

            for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]:
                _test(degrees=degrees, interpolation=interpolation)

            for fill in [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
                _test(degrees=0.0, interpolation=interpolation, fill=fill)

        # One more configuration with only ``degrees`` set; its scripted module is the one saved.
        s_transform = _test(degrees=0.0)
        with get_tmp_dir() as tmp_dir:
            s_transform.save(os.path.join(tmp_dir, "t_random_affine.pt"))

322
    def test_random_rotate(self):
        """Compare eager and scripted ``T.RandomRotation`` over center, expand, degrees, interpolation and fill."""
        tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
        batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        for center in [(0, 0), [10, 10], None, (56, 44)]:
            for expand in [True, False]:
                for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]:
                    for interpolation in [NEAREST, BILINEAR]:
                        for fill in [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
                            transform = T.RandomRotation(
                                degrees=degrees, interpolation=interpolation, expand=expand, center=center, fill=fill
                            )
                            s_transform = torch.jit.script(transform)

                            _test_transform_vs_scripted(transform, s_transform, tensor)
                            _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

        # Only the scripted transform from the last loop iteration is saved.
        with get_tmp_dir() as tmp_dir:
            s_transform.save(os.path.join(tmp_dir, "t_random_rotate.pt"))

342
    def test_random_perspective(self):
        """Compare eager and scripted ``T.RandomPerspective`` over distortion scale, interpolation and fill."""
        tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
        batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        for distortion_scale in np.linspace(0.1, 1.0, num=20):
            for interpolation in [NEAREST, BILINEAR]:
                for fill in [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
                    transform = T.RandomPerspective(
                        distortion_scale=distortion_scale,
                        interpolation=interpolation,
                        fill=fill
                    )
                    s_transform = torch.jit.script(transform)

                    _test_transform_vs_scripted(transform, s_transform, tensor)
                    _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

        # Only the scripted transform from the last loop iteration is saved.
        with get_tmp_dir() as tmp_dir:
            s_transform.save(os.path.join(tmp_dir, "t_perspective.pt"))

362
363
364
365
    def test_to_grayscale(self):
        """Run the class checks for ``T.Grayscale`` (1 and 3 output channels) and ``T.RandomGrayscale``."""
        tol = 1.0 + 1e-10
        cases = [
            (T.Grayscale, {"num_output_channels": 1}),
            (T.Grayscale, {"num_output_channels": 3}),
            (T.RandomGrayscale, {}),
        ]
        for transform_cls, meth_kwargs in cases:
            _test_class_op(
                transform_cls, meth_kwargs=meth_kwargs, test_exact_match=False, device=self.device,
                tol=tol, agg_method="max"
            )

383
    def test_normalize(self):
        """Check that ``T.Normalize`` rejects the raw ``_create_data`` tensor and scripts consistently."""
        fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        tensor, _ = _create_data(26, 34, device=self.device)

        # The tensor as returned by _create_data must be refused with this TypeError.
        with self.assertRaisesRegex(TypeError, r"Input tensor should be a float tensor"):
            fn(tensor)

        batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)
        tensor = tensor.to(dtype=torch.float32) / 255.0
        # test for class interface
        scripted_fn = torch.jit.script(fn)

        _test_transform_vs_scripted(fn, scripted_fn, tensor)
        _test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)

        with get_tmp_dir() as tmp_dir:
            scripted_fn.save(os.path.join(tmp_dir, "t_norm.pt"))

401
402
403
    def test_linear_transformation(self):
        """Compare eager and scripted ``T.LinearTransformation`` on a single image and on a batch."""
        c, h, w = 3, 24, 32

        tensor, _ = _create_data(h, w, channels=c, device=self.device)

        matrix = torch.rand(c * h * w, c * h * w, device=self.device)
        mean_vector = torch.rand(c * h * w, device=self.device)

        fn = T.LinearTransformation(matrix, mean_vector)
        scripted_fn = torch.jit.script(fn)

        _test_transform_vs_scripted(fn, scripted_fn, tensor)

        batch_tensors = torch.rand(4, c, h, w, device=self.device)
        # We skip some tests from _test_transform_vs_scripted_on_batch as
        # results for scripted and non-scripted transformations are not exactly the same
        torch.manual_seed(12)
        transformed_batch = fn(batch_tensors)
        torch.manual_seed(12)
        s_transformed_batch = scripted_fn(batch_tensors)
        assert_equal(transformed_batch, s_transformed_batch)

        # Use a file name specific to this test (it was previously "t_norm.pt",
        # the name used by test_normalize).
        with get_tmp_dir() as tmp_dir:
            scripted_fn.save(os.path.join(tmp_dir, "t_linear_transformation.pt"))

426
    def test_compose(self):
        """Compare ``T.Compose`` with a scripted ``nn.Sequential`` of the same transforms.

        Also checks that scripting a ``Compose`` holding a lambda raises ``RuntimeError``.
        """
        tensor, _ = _create_data(26, 34, device=self.device)
        tensor = tensor.to(dtype=torch.float32) / 255.0

        transforms = T.Compose([
            T.CenterCrop(10),
            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        # The scripted counterpart wraps the same transform instances in nn.Sequential.
        s_transforms = torch.nn.Sequential(*transforms.transforms)

        scripted_fn = torch.jit.script(s_transforms)
        torch.manual_seed(12)
        transformed_tensor = transforms(tensor)
        torch.manual_seed(12)
        transformed_tensor_script = scripted_fn(tensor)
        assert_equal(transformed_tensor, transformed_tensor_script, msg="{}".format(transforms))

        t = T.Compose([
            lambda x: x,
        ])
        with self.assertRaisesRegex(RuntimeError, r"Could not get name of python class object"):
            torch.jit.script(t)

449
    def test_random_apply(self):
        """Compare eager ``T.RandomApply`` (list) with its scripted ``nn.ModuleList`` counterpart.

        On CPU it additionally checks that scripting a ``RandomApply`` built from
        a plain list raises ``RuntimeError``.
        """
        tensor, _ = _create_data(26, 34, device=self.device)
        tensor = tensor.to(dtype=torch.float32) / 255.0

        transforms = T.RandomApply([
            T.RandomHorizontalFlip(),
            T.ColorJitter(),
        ], p=0.4)
        s_transforms = T.RandomApply(torch.nn.ModuleList([
            T.RandomHorizontalFlip(),
            T.ColorJitter(),
        ]), p=0.4)

        scripted_fn = torch.jit.script(s_transforms)
        torch.manual_seed(12)
        transformed_tensor = transforms(tensor)
        torch.manual_seed(12)
        transformed_tensor_script = scripted_fn(tensor)
        assert_equal(transformed_tensor, transformed_tensor_script, msg="{}".format(transforms))

        if torch.device(self.device).type == "cpu":
            # Can't check this twice, otherwise
            # "Can't redefine method: forward on class: __torch__.torchvision.transforms.transforms.RandomApply"
            transforms = T.RandomApply([
                T.ColorJitter(),
            ], p=0.3)
            with self.assertRaisesRegex(RuntimeError, r"Module 'RandomApply' has no attribute 'transforms'"):
                torch.jit.script(transforms)

478
479
    def test_gaussian_blur(self):
        """Run the class checks for ``T.GaussianBlur`` over several kernel_size/sigma forms."""
        tol = 1.0 + 1e-10
        configs = [
            {"kernel_size": 3, "sigma": 0.75},
            {"kernel_size": 23, "sigma": [0.1, 2.0]},
            {"kernel_size": 23, "sigma": (0.1, 2.0)},
            {"kernel_size": [3, 3], "sigma": (1.0, 1.0)},
            {"kernel_size": (3, 3), "sigma": (0.1, 2.0)},
            {"kernel_size": [23], "sigma": 0.75},
        ]
        for meth_kwargs in configs:
            _test_class_op(
                T.GaussianBlur, meth_kwargs=meth_kwargs,
                test_exact_match=False, device=self.device, agg_method="max", tol=tol
            )

vfdev's avatar
vfdev committed
510
511
512
513
514
515
516
517
    def test_random_erasing(self):
        """Check ``T.RandomErasing`` input validation and eager-vs-scripted agreement."""
        img = torch.rand(3, 60, 60)

        # Test Set 0: invalid value
        random_erasing = T.RandomErasing(value=(0.1, 0.2, 0.3, 0.4), p=1.0)
        # assertRaisesRegex is needed to check the exception text: the ``msg``
        # keyword of assertRaises is only the message shown when the assertion fails.
        with self.assertRaisesRegex(
            ValueError, r"If value is a sequence, it should have either a single value or 3"
        ):
            random_erasing(img)

        tensor, _ = _create_data(24, 32, channels=3, device=self.device)
        batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)

        test_configs = [
            {"value": 0.2},
            {"value": "random"},
            {"value": (0.2, 0.2, 0.2)},
            {"value": "random", "ratio": (0.1, 0.2)},
        ]

        for config in test_configs:
            fn = T.RandomErasing(**config)
            scripted_fn = torch.jit.script(fn)
            _test_transform_vs_scripted(fn, scripted_fn, tensor)
            _test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)

        # Only the scripted transform from the last config is saved.
        with get_tmp_dir() as tmp_dir:
            scripted_fn.save(os.path.join(tmp_dir, "t_random_erasing.pt"))

537
    def test_convert_image_dtype(self):
        """Compare eager and scripted ``T.ConvertImageDtype`` for every input/output dtype pair.

        For the pairs float32 -> int32/int64 and float64 -> int64 the comparison
        helpers are expected to raise ``RuntimeError`` instead.
        """
        tensor, _ = _create_data(26, 34, device=self.device)
        batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)

        for in_dtype in int_dtypes() + float_dtypes():
            in_tensor = tensor.to(in_dtype)
            in_batch_tensors = batch_tensors.to(in_dtype)
            for out_dtype in int_dtypes() + float_dtypes():

                fn = T.ConvertImageDtype(dtype=out_dtype)
                scripted_fn = torch.jit.script(fn)

                if (in_dtype == torch.float32 and out_dtype in (torch.int32, torch.int64)) or \
                        (in_dtype == torch.float64 and out_dtype == torch.int64):
                    with self.assertRaisesRegex(RuntimeError, r"cannot be performed safely"):
                        _test_transform_vs_scripted(fn, scripted_fn, in_tensor)
                    with self.assertRaisesRegex(RuntimeError, r"cannot be performed safely"):
                        _test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)
                    continue

                _test_transform_vs_scripted(fn, scripted_fn, in_tensor)
                _test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)

        # Only the scripted transform from the last dtype pair is saved.
        with get_tmp_dir() as tmp_dir:
            scripted_fn.save(os.path.join(tmp_dir, "t_convert_dtype.pt"))

563
564
565
566
    def test_autoaugment(self):
        """Compare eager and scripted ``T.AutoAugment`` for every policy and several fill values."""
        tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
        batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        s_transform = None
        for policy in T.AutoAugmentPolicy:
            for fill in [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
                transform = T.AutoAugment(policy=policy, fill=fill)
                s_transform = torch.jit.script(transform)
                # Each configuration is compared 25 times.
                for _ in range(25):
                    _test_transform_vs_scripted(transform, s_transform, tensor)
                    _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

        # Guard against an empty policy enum, in which case nothing was scripted.
        if s_transform is not None:
            with get_tmp_dir() as tmp_dir:
                s_transform.save(os.path.join(tmp_dir, "t_autoaugment.pt"))
579

580

581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
@pytest.mark.parametrize('device', cpu_and_gpu())
class TestColorJitter:
    """Class checks for ``T.ColorJitter``, one parameter at a time and all together.

    Every test delegates to ``_test_class_op`` with the approximate comparison
    and ``agg_method="max"``; only the constructor kwargs and tolerance differ.
    """

    @pytest.mark.parametrize('brightness', [0.1, 0.5, 1.0, 1.34, (0.3, 0.7), [0.4, 0.5]])
    def test_color_jitter_brightness(self, brightness, device):
        """ColorJitter with only ``brightness`` set."""
        _test_class_op(
            T.ColorJitter,
            meth_kwargs={"brightness": brightness},
            test_exact_match=False,
            device=device,
            tol=1.0 + 1e-10,
            agg_method="max",
        )

    @pytest.mark.parametrize('contrast', [0.2, 0.5, 1.0, 1.5, (0.3, 0.7), [0.4, 0.5]])
    def test_color_jitter_contrast(self, contrast, device):
        """ColorJitter with only ``contrast`` set."""
        _test_class_op(
            T.ColorJitter,
            meth_kwargs={"contrast": contrast},
            test_exact_match=False,
            device=device,
            tol=1.0 + 1e-10,
            agg_method="max",
        )

    @pytest.mark.parametrize('saturation', [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]])
    def test_color_jitter_saturation(self, saturation, device):
        """ColorJitter with only ``saturation`` set."""
        _test_class_op(
            T.ColorJitter,
            meth_kwargs={"saturation": saturation},
            test_exact_match=False,
            device=device,
            tol=1.0 + 1e-10,
            agg_method="max",
        )

    @pytest.mark.parametrize('hue', [0.2, 0.5, (-0.2, 0.3), [-0.4, 0.5]])
    def test_color_jitter_hue(self, hue, device):
        """ColorJitter with only ``hue`` set; uses a larger tolerance than the other tests."""
        _test_class_op(
            T.ColorJitter,
            meth_kwargs={"hue": hue},
            test_exact_match=False,
            device=device,
            tol=16.1,
            agg_method="max",
        )

    def test_color_jitter_all(self, device):
        """ColorJitter with all four parameters set together."""
        all_params = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2}
        _test_class_op(
            T.ColorJitter,
            meth_kwargs=all_params,
            test_exact_match=False,
            device=device,
            tol=12.1,
            agg_method="max",
        )


@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('m', ["constant", "edge", "reflect", "symmetric"])
@pytest.mark.parametrize('mul', [1, -1])
def test_pad(m, mul, device):
    """Check ``F.pad`` and ``T.Pad`` for int, one-element list, two-element list and tuple paddings.

    ``mul`` flips the sign of some padding values; ``fill`` is 127 for the
    "constant" mode and 0 for every other mode.
    """
    if m == "constant":
        fill = 127
    else:
        fill = 0

    # Functional op only (PIL and Tensor): padding given as a single int.
    int_padding_kwargs = {"padding": mul * 2, "fill": fill, "padding_mode": m}
    _test_functional_op(F.pad, fn_kwargs=int_padding_kwargs, device=device)

    # Functional op and transform class: padding as [int, ], as a list, then as a tuple.
    sequence_paddings = (
        [mul * 2, ],
        [mul * 4, 4],
        (mul * 2, 2, 2, mul * 2),
    )
    for padding in sequence_paddings:
        kwargs = {"padding": padding, "fill": fill, "padding_mode": m}
        _test_op(F.pad, T.Pad, device=device, fn_kwargs=kwargs, meth_kwargs=kwargs)


@pytest.mark.parametrize('device', cpu_and_gpu())
def test_crop(device):
    """Check ``F.crop`` / ``T.RandomCrop``, then ``F.crop`` with windows reaching outside the image."""
    # Test transforms.RandomCrop with size and padding as tuple
    _test_op(
        F.crop,
        T.RandomCrop,
        device=device,
        fn_kwargs={"top": 2, "left": 3, "height": 4, "width": 5},
        meth_kwargs={"size": (4, 5), "padding": (4, 4), "pad_if_needed": True, },
    )

    # Test transforms.functional.crop including outside the image area
    outside_windows = [
        {"top": -2, "left": 3, "height": 4, "width": 5},  # top
        {"top": 1, "left": -3, "height": 4, "width": 5},  # left
        {"top": 7, "left": 3, "height": 4, "width": 5},  # bottom
        {"top": 3, "left": 8, "height": 4, "width": 5},  # right
        {"top": -3, "left": -3, "height": 15, "width": 15},  # all
    ]
    for window in outside_windows:
        _test_functional_op(F.crop, fn_kwargs=window, device=device)


@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('padding_config', [
    {"padding_mode": "constant", "fill": 0},
    {"padding_mode": "constant", "fill": 10},
    {"padding_mode": "constant", "fill": 20},
    {"padding_mode": "edge"},
    {"padding_mode": "reflect"}
])
@pytest.mark.parametrize('size', [5, [5, ], [6, 6]])
def test_crop_pad(size, padding_config, device):
    """Run the class checks for ``T.RandomCrop`` with each padding configuration and size."""
    # Build a fresh dict so the parametrized ``padding_config`` is never mutated.
    crop_kwargs = {**padding_config, "size": size}
    _test_class_op(T.RandomCrop, device, crop_kwargs)


@pytest.mark.parametrize('device', cpu_and_gpu())
def test_center_crop(device):
    """Check ``F.center_crop`` / ``T.CenterCrop`` and scripting of ``CenterCrop`` for several size forms."""
    # Functional op and transform class with a two-element and a one-element tuple size.
    for crop_size in [(4, 5), (5,)]:
        _test_op(
            F.center_crop,
            T.CenterCrop,
            device=device,
            fn_kwargs={"output_size": crop_size},
            meth_kwargs={"size": crop_size},
        )

    tensor = torch.randint(0, 256, (3, 10, 10), dtype=torch.uint8, device=device)

    # Test torchscript of transforms.CenterCrop with size as int, as [int, ] and as tuple
    for script_size in (5, [5, ], (6, 6)):
        f = T.CenterCrop(size=script_size)
        scripted_fn = torch.jit.script(f)
        scripted_fn(tensor)

    # The scripted module from the last size is the one saved.
    with get_tmp_dir() as tmp_dir:
        target = os.path.join(tmp_dir, "t_center_crop.pt")
        scripted_fn.save(target)


731
732
733
734
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
    """Re-run every ``Tester`` test with ``self.device`` set to ``"cuda"``."""

    def setUp(self):
        # torch.set_deterministic is deprecated in favour of
        # torch.use_deterministic_algorithms; prefer the newer function and fall
        # back to the old one on builds that do not provide it.
        set_deterministic = getattr(torch, "use_deterministic_algorithms", None)
        if set_deterministic is None:
            set_deterministic = torch.set_deterministic
        set_deterministic(False)
        self.device = "cuda"


739
740
# Allow running this file directly; unittest.main() runs the unittest.TestCase classes defined above.
if __name__ == '__main__':
    unittest.main()