# test_transforms_tensor.py
import os
2
3
4
import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F
5

# (blame annotation: committed by vfdev)
from PIL.Image import NEAREST, BILINEAR, BICUBIC
7
8
9
10
11

import numpy as np

import unittest

12
from common_utils import TransformsTester, get_tmp_dir, int_dtypes, float_dtypes
13
14


15
class Tester(TransformsTester):
    """Compare tensor transforms with their PIL counterparts and TorchScript-compiled versions.

    Input tensors are created on ``self.device``, which ``setUp`` sets to ``"cpu"``;
    subclasses may override ``setUp`` to run the same cases on another device.
    """

    def setUp(self):
        self.device = "cpu"

    def _test_functional_op(self, func, fn_kwargs=None):
        """Apply ``F.<func>`` to a tensor and to a PIL image built from the same data and compare.

        Args:
            func: name of a function in ``torchvision.transforms.functional``.
            fn_kwargs: keyword arguments forwarded to that function (``None`` means no extra kwargs).
        """
        if fn_kwargs is None:
            fn_kwargs = {}

        f = getattr(F, func)
        tensor, pil_img = self._create_data(height=10, width=10, device=self.device)
        transformed_tensor = f(tensor, **fn_kwargs)
        transformed_pil_img = f(pil_img, **fn_kwargs)
        self.compareTensorToPIL(transformed_tensor, transformed_pil_img)

    def _test_transform_vs_scripted(self, transform, s_transform, tensor, msg=None):
        """Assert that the eager and scripted transforms give equal outputs under the same RNG seed."""
        torch.manual_seed(12)
        out1 = transform(tensor)
        torch.manual_seed(12)
        out2 = s_transform(tensor)
        self.assertTrue(out1.equal(out2), msg=msg)

    def _test_transform_vs_scripted_on_batch(self, transform, s_transform, batch_tensors, msg=None):
        """Check batch support of a transform.

        Two properties are asserted, each run reseeding the RNG with the same value:
        the eager result on each single image equals the matching slice of the eager batched
        result, and the scripted transform reproduces the eager batched result exactly.
        """
        torch.manual_seed(12)
        transformed_batch = transform(batch_tensors)

        for i, img_tensor in enumerate(batch_tensors):
            torch.manual_seed(12)
            transformed_img = transform(img_tensor)
            self.assertTrue(transformed_img.equal(transformed_batch[i, ...]), msg=msg)

        torch.manual_seed(12)
        s_transformed_batch = s_transform(batch_tensors)
        self.assertTrue(transformed_batch.equal(s_transformed_batch), msg=msg)

    def _test_class_op(self, method, meth_kwargs=None, test_exact_match=True, **match_kwargs):
        """Exercise the class interface ``T.<method>(**meth_kwargs)``.

        The tensor output is compared with the PIL output, either exactly or approximately
        using ``match_kwargs``. It is also compared with the scripted transform, checked on a
        batch, and the scripted module is saved to a temporary directory.
        """
        if meth_kwargs is None:
            meth_kwargs = {}

        # test for class interface
        f = getattr(T, method)(**meth_kwargs)
        scripted_fn = torch.jit.script(f)

        tensor, pil_img = self._create_data(26, 34, device=self.device)
        # set seed to reproduce the same transformation for tensor and PIL image
        torch.manual_seed(12)
        transformed_tensor = f(tensor)
        torch.manual_seed(12)
        transformed_pil_img = f(pil_img)
        if test_exact_match:
            self.compareTensorToPIL(transformed_tensor, transformed_pil_img, **match_kwargs)
        else:
            self.approxEqualTensorToPIL(transformed_tensor.float(), transformed_pil_img, **match_kwargs)

        torch.manual_seed(12)
        transformed_tensor_script = scripted_fn(tensor)
        self.assertTrue(transformed_tensor.equal(transformed_tensor_script))

        batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
        self._test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors)

        with get_tmp_dir() as tmp_dir:
            scripted_fn.save(os.path.join(tmp_dir, "t_{}.pt".format(method)))

    def _test_op(self, func, method, fn_kwargs=None, meth_kwargs=None):
        """Run both the functional test (``F.<func>``) and the class test (``T.<method>``)."""
        self._test_functional_op(func, fn_kwargs)
        self._test_class_op(method, meth_kwargs)

    def test_random_horizontal_flip(self):
        self._test_op('hflip', 'RandomHorizontalFlip')

    def test_random_vertical_flip(self):
        self._test_op('vflip', 'RandomVerticalFlip')

    def test_color_jitter(self):
        # Brightness/contrast/saturation are compared approximately: tol just above 1.0,
        # aggregated with "max" (see TransformsTester.approxEqualTensorToPIL).
        tol = 1.0 + 1e-10
        jitter_cases = [
            ("brightness", [0.1, 0.5, 1.0, 1.34, (0.3, 0.7), [0.4, 0.5]]),
            ("contrast", [0.2, 0.5, 1.0, 1.5, (0.3, 0.7), [0.4, 0.5]]),
            ("saturation", [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]),
        ]
        for param_name, values in jitter_cases:
            for f in values:
                meth_kwargs = {param_name: f}
                self._test_class_op(
                    "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
                )

        # Hue is checked with a looser, mean-aggregated tolerance.
        for f in [0.2, 0.5, (-0.2, 0.3), [-0.4, 0.5]]:
            meth_kwargs = {"hue": f}
            self._test_class_op(
                "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=0.1, agg_method="mean"
            )

        # All 4 parameters together
        meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2}
        self._test_class_op(
            "ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=0.1, agg_method="mean"
        )

    def test_pad(self):
        for m in ["constant", "edge", "reflect", "symmetric"]:
            fill = 127 if m == "constant" else 0
            # Negative pad currently unsupported for Tensor and symmetric
            multipliers = [1] if m == "symmetric" else [1, -1]
            for mul in multipliers:
                # Test functional.pad (PIL and Tensor) with padding as single int
                self._test_functional_op(
                    "pad", fn_kwargs={"padding": mul * 2, "fill": fill, "padding_mode": m}
                )
                # Test functional.pad and transforms.Pad with padding as [int, ]
                fn_kwargs = meth_kwargs = {"padding": [mul * 2, ], "fill": fill, "padding_mode": m}
                self._test_op(
                    "pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
                )
                # Test functional.pad and transforms.Pad with padding as list
                fn_kwargs = meth_kwargs = {"padding": [mul * 4, 4], "fill": fill, "padding_mode": m}
                self._test_op(
                    "pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
                )
                # Test functional.pad and transforms.Pad with padding as tuple
                fn_kwargs = meth_kwargs = {"padding": (mul * 2, 2, 2, mul * 2), "fill": fill, "padding_mode": m}
                self._test_op(
                    "pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
                )

    def test_crop(self):
        fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5}
        # Test transforms.RandomCrop with size and padding as tuple
        meth_kwargs = {"size": (4, 5), "padding": (4, 4), "pad_if_needed": True, }
        self._test_op(
            'crop', 'RandomCrop', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )

        sizes = [5, [5, ], [6, 6]]
        padding_configs = [
            {"padding_mode": "constant", "fill": 0},
            {"padding_mode": "constant", "fill": 10},
            {"padding_mode": "constant", "fill": 20},
            {"padding_mode": "edge"},
            {"padding_mode": "reflect"},
        ]

        for size in sizes:
            for padding_config in padding_configs:
                config = dict(padding_config)
                config["size"] = size
                self._test_class_op("RandomCrop", config)

    def test_center_crop(self):
        fn_kwargs = {"output_size": (4, 5)}
        meth_kwargs = {"size": (4, 5), }
        self._test_op(
            "center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )
        fn_kwargs = {"output_size": (5,)}
        meth_kwargs = {"size": (5, )}
        self._test_op(
            "center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )
        tensor = torch.randint(0, 255, (3, 10, 10), dtype=torch.uint8, device=self.device)
        # Test torchscript of transforms.CenterCrop with size as int
        f = T.CenterCrop(size=5)
        scripted_fn = torch.jit.script(f)
        scripted_fn(tensor)

        # Test torchscript of transforms.CenterCrop with size as [int, ]
        f = T.CenterCrop(size=[5, ])
        scripted_fn = torch.jit.script(f)
        scripted_fn(tensor)

        # Test torchscript of transforms.CenterCrop with size as tuple
        f = T.CenterCrop(size=(6, 6))
        scripted_fn = torch.jit.script(f)
        scripted_fn(tensor)

        with get_tmp_dir() as tmp_dir:
            scripted_fn.save(os.path.join(tmp_dir, "t_center_crop.pt"))

    def _test_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None):
        """Test an op that returns a sequence of images (e.g. five_crop / ten_crop).

        The functional op ``F.<func>`` is compared against PIL and against its scripted version.
        The scripted class ``T.<method>`` must produce the same number of outputs. The eager
        class output on a batch must match its per-image outputs.
        """
        if fn_kwargs is None:
            fn_kwargs = {}
        if meth_kwargs is None:
            meth_kwargs = {}

        fn = getattr(F, func)
        scripted_fn = torch.jit.script(fn)

        tensor, pil_img = self._create_data(height=20, width=20, device=self.device)
        transformed_t_list = fn(tensor, **fn_kwargs)
        transformed_p_list = fn(pil_img, **fn_kwargs)
        self.assertEqual(len(transformed_t_list), len(transformed_p_list))
        self.assertEqual(len(transformed_t_list), out_length)
        for transformed_tensor, transformed_pil_img in zip(transformed_t_list, transformed_p_list):
            self.compareTensorToPIL(transformed_tensor, transformed_pil_img)

        transformed_t_list_script = scripted_fn(tensor.detach().clone(), **fn_kwargs)
        self.assertEqual(len(transformed_t_list), len(transformed_t_list_script))
        self.assertEqual(len(transformed_t_list_script), out_length)
        for transformed_tensor, transformed_tensor_script in zip(transformed_t_list, transformed_t_list_script):
            self.assertTrue(transformed_tensor.equal(transformed_tensor_script),
                            msg="{} vs {}".format(transformed_tensor, transformed_tensor_script))

        # test for class interface
        fn = getattr(T, method)(**meth_kwargs)
        scripted_fn = torch.jit.script(fn)
        output = scripted_fn(tensor)
        self.assertEqual(len(output), len(transformed_t_list_script))

        # test on batch of tensors
        batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
        torch.manual_seed(12)
        transformed_batch_list = fn(batch_tensors)

        for i, img_tensor in enumerate(batch_tensors):
            torch.manual_seed(12)
            transformed_img_list = fn(img_tensor)
            for transformed_img, transformed_batch in zip(transformed_img_list, transformed_batch_list):
                self.assertTrue(transformed_img.equal(transformed_batch[i, ...]),
                                msg="{} vs {}".format(transformed_img, transformed_batch[i, ...]))

        with get_tmp_dir() as tmp_dir:
            scripted_fn.save(os.path.join(tmp_dir, "t_op_list_{}.pt".format(method)))

    def test_five_crop(self):
        # size given as 1-tuple, 1-list, 2-tuple and 2-list
        for size in [(5,), [5, ], (4, 5), [4, 5]]:
            fn_kwargs = meth_kwargs = {"size": size}
            self._test_op_list_output(
                "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
            )

    def test_ten_crop(self):
        # size given as 1-tuple, 1-list, 2-tuple and 2-list
        for size in [(5,), [5, ], (4, 5), [4, 5]]:
            fn_kwargs = meth_kwargs = {"size": size}
            self._test_op_list_output(
                "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
            )

    def test_resize(self):
        tensor, _ = self._create_data(height=34, width=36, device=self.device)
        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
        script_fn = torch.jit.script(F.resize)

        for dt in [None, torch.float32, torch.float64]:
            if dt is not None:
                # This is a trivial cast to float of uint8 data to test all cases
                tensor = tensor.to(dt)
            for size in [32, 34, [32, ], [32, 32], (32, 32), [34, 35]]:
                for interpolation in [BILINEAR, BICUBIC, NEAREST]:
                    resized_tensor = F.resize(tensor, size=size, interpolation=interpolation)

                    # the scripted function is called with a list, so wrap a bare int
                    if isinstance(size, int):
                        script_size = [size, ]
                    else:
                        script_size = size

                    s_resized_tensor = script_fn(tensor, size=script_size, interpolation=interpolation)
                    self.assertTrue(s_resized_tensor.equal(resized_tensor))

                    transform = T.Resize(size=script_size, interpolation=interpolation)
                    s_transform = torch.jit.script(transform)
                    self._test_transform_vs_scripted(transform, s_transform, tensor)
                    self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

        with get_tmp_dir() as tmp_dir:
            script_fn.save(os.path.join(tmp_dir, "t_resize.pt"))

    def test_resized_crop(self):
        tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        for scale in [(0.7, 1.2), [0.7, 1.2]]:
            for ratio in [(0.75, 1.333), [0.75, 1.333]]:
                for size in [(32, ), [44, ], [32, ], [32, 32], (32, 32), [44, 55]]:
                    for interpolation in [NEAREST, BILINEAR, BICUBIC]:
                        transform = T.RandomResizedCrop(
                            size=size, scale=scale, ratio=ratio, interpolation=interpolation
                        )
                        s_transform = torch.jit.script(transform)
                        self._test_transform_vs_scripted(transform, s_transform, tensor)
                        self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

        with get_tmp_dir() as tmp_dir:
            s_transform.save(os.path.join(tmp_dir, "t_resized_crop.pt"))

    def test_random_affine(self):
        tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        for shear in [15, 10.0, (5.0, 10.0), [-15, 15], [-10.0, 10.0, -11.0, 11.0]]:
            for scale in [(0.7, 1.2), [0.7, 1.2]]:
                for translate in [(0.1, 0.2), [0.2, 0.1]]:
                    for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]:
                        for interpolation in [NEAREST, BILINEAR]:
                            transform = T.RandomAffine(
                                degrees=degrees, translate=translate,
                                scale=scale, shear=shear, resample=interpolation
                            )
                            s_transform = torch.jit.script(transform)
                            self._test_transform_vs_scripted(transform, s_transform, tensor)
                            self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

        with get_tmp_dir() as tmp_dir:
            s_transform.save(os.path.join(tmp_dir, "t_random_affine.pt"))

    def test_random_rotate(self):
        tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        for center in [(0, 0), [10, 10], None, (56, 44)]:
            for expand in [True, False]:
                for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]:
                    for interpolation in [NEAREST, BILINEAR]:
                        transform = T.RandomRotation(
                            degrees=degrees, resample=interpolation, expand=expand, center=center
                        )
                        s_transform = torch.jit.script(transform)
                        self._test_transform_vs_scripted(transform, s_transform, tensor)
                        self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

        with get_tmp_dir() as tmp_dir:
            s_transform.save(os.path.join(tmp_dir, "t_random_rotate.pt"))

    def test_random_perspective(self):
        tensor = torch.randint(0, 255, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
        batch_tensors = torch.randint(0, 255, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

        for distortion_scale in np.linspace(0.1, 1.0, num=20):
            for interpolation in [NEAREST, BILINEAR]:
                transform = T.RandomPerspective(
                    distortion_scale=distortion_scale,
                    interpolation=interpolation
                )
                s_transform = torch.jit.script(transform)
                self._test_transform_vs_scripted(transform, s_transform, tensor)
                self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

        with get_tmp_dir() as tmp_dir:
            s_transform.save(os.path.join(tmp_dir, "t_perspective.pt"))

    def test_to_grayscale(self):
        tol = 1.0 + 1e-10

        meth_kwargs = {"num_output_channels": 1}
        self._test_class_op(
            "Grayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
        )

        meth_kwargs = {"num_output_channels": 3}
        self._test_class_op(
            "Grayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
        )

        meth_kwargs = {}
        self._test_class_op(
            "RandomGrayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
        )

    def test_normalize(self):
        tensor, _ = self._create_data(26, 34, device=self.device)
        batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)

        tensor = tensor.to(dtype=torch.float32) / 255.0
        # test for class interface
        fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        scripted_fn = torch.jit.script(fn)

        self._test_transform_vs_scripted(fn, scripted_fn, tensor)
        self._test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)

        with get_tmp_dir() as tmp_dir:
            scripted_fn.save(os.path.join(tmp_dir, "t_norm.pt"))

    def test_linear_transformation(self):
        c, h, w = 3, 24, 32

        tensor, _ = self._create_data(h, w, channels=c, device=self.device)

        matrix = torch.rand(c * h * w, c * h * w, device=self.device)
        mean_vector = torch.rand(c * h * w, device=self.device)

        fn = T.LinearTransformation(matrix, mean_vector)
        scripted_fn = torch.jit.script(fn)

        self._test_transform_vs_scripted(fn, scripted_fn, tensor)

        batch_tensors = torch.rand(4, c, h, w, device=self.device)
        # We skip some tests from _test_transform_vs_scripted_on_batch as
        # results for scripted and non-scripted transformations are not exactly the same
        torch.manual_seed(12)
        transformed_batch = fn(batch_tensors)
        torch.manual_seed(12)
        s_transformed_batch = scripted_fn(batch_tensors)
        self.assertTrue(transformed_batch.equal(s_transformed_batch))

        # Named after this transform; previously a copy-paste of test_normalize's "t_norm.pt".
        with get_tmp_dir() as tmp_dir:
            scripted_fn.save(os.path.join(tmp_dir, "t_linear_transformation.pt"))

    def test_compose(self):
        tensor, _ = self._create_data(26, 34, device=self.device)
        tensor = tensor.to(dtype=torch.float32) / 255.0

        transforms = T.Compose([
            T.CenterCrop(10),
            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        # Compose itself is not scriptable; an nn.Sequential of the same transforms is.
        s_transforms = torch.nn.Sequential(*transforms.transforms)

        scripted_fn = torch.jit.script(s_transforms)
        torch.manual_seed(12)
        transformed_tensor = transforms(tensor)
        torch.manual_seed(12)
        transformed_tensor_script = scripted_fn(tensor)
        self.assertTrue(transformed_tensor.equal(transformed_tensor_script), msg="{}".format(transforms))

        t = T.Compose([
            lambda x: x,
        ])
        with self.assertRaisesRegex(RuntimeError, r"Could not get name of python class object"):
            torch.jit.script(t)

    def test_random_erasing(self):
        # Create the probe image on the tested device so device subclasses exercise this path too.
        img = torch.rand(3, 60, 60, device=self.device)

        # Test Set 0: invalid value
        random_erasing = T.RandomErasing(value=(0.1, 0.2, 0.3, 0.4), p=1.0)
        with self.assertRaises(ValueError, msg="If value is a sequence, it should have either a single value or 3"):
            random_erasing(img)

        tensor, _ = self._create_data(24, 32, channels=3, device=self.device)
        batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)

        test_configs = [
            {"value": 0.2},
            {"value": "random"},
            {"value": (0.2, 0.2, 0.2)},
            {"value": "random", "ratio": (0.1, 0.2)},
        ]

        for config in test_configs:
            fn = T.RandomErasing(**config)
            scripted_fn = torch.jit.script(fn)
            self._test_transform_vs_scripted(fn, scripted_fn, tensor)
            self._test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)

    def test_convert_image_dtype(self):
        tensor, _ = self._create_data(26, 34, device=self.device)
        batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)

        for in_dtype in int_dtypes() + float_dtypes():
            in_tensor = tensor.to(in_dtype)
            in_batch_tensors = batch_tensors.to(in_dtype)
            for out_dtype in int_dtypes() + float_dtypes():

                fn = T.ConvertImageDtype(dtype=out_dtype)
                scripted_fn = torch.jit.script(fn)

                # These float -> int conversions are expected to be rejected as unsafe.
                if (in_dtype == torch.float32 and out_dtype in (torch.int32, torch.int64)) or \
                        (in_dtype == torch.float64 and out_dtype == torch.int64):
                    with self.assertRaisesRegex(RuntimeError, r"cannot be performed safely"):
                        self._test_transform_vs_scripted(fn, scripted_fn, in_tensor)
                    with self.assertRaisesRegex(RuntimeError, r"cannot be performed safely"):
                        self._test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)
                    continue

                self._test_transform_vs_scripted(fn, scripted_fn, in_tensor)
                self._test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)

        with get_tmp_dir() as tmp_dir:
            scripted_fn.save(os.path.join(tmp_dir, "t_convert_dtype.pt"))


@unittest.skipUnless(torch.cuda.is_available(), "Skip if no CUDA device")
class CUDATester(Tester):
    """Re-run every inherited ``Tester`` case with ``self.device`` set to ``"cuda"``."""

    def setUp(self):
        # The inherited helpers allocate their inputs on ``self.device``.
        self.device = "cuda"


if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()