import os
from typing import Sequence

import numpy as np
import pytest
import torch
from common_utils import (
    get_tmp_dir,
    int_dtypes,
    float_dtypes,
    _create_data,
    _create_data_batch,
    _assert_equal_tensor_to_pil,
    _assert_approx_equal_tensor_to_pil,
    cpu_and_gpu,
    assert_equal,
)
from torchvision import transforms as T
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as F

NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC


def _test_transform_vs_scripted(transform, s_transform, tensor, msg=None):
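    """Check that a transform and its torch.jit scripted version produce the same output for the same seed."""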
    torch.manual_seed(12)
    out1 = transform(tensor)
    torch.manual_seed(12)
    out2 = s_transform(tensor)
    assert_equal(out1, out2, msg=msg)


def _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors, msg=None):
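    """Check that a transform applied per-image matches the batched result, and that the scripted version agrees on the batch."""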
    torch.manual_seed(12)
    transformed_batch = transform(batch_tensors)

    for i in range(len(batch_tensors)):
        img_tensor = batch_tensors[i, ...]
        torch.manual_seed(12)
        transformed_img = transform(img_tensor)
        assert_equal(transformed_img, transformed_batch[i, ...], msg=msg)

    torch.manual_seed(12)
    s_transformed_batch = s_transform(batch_tensors)
    assert_equal(transformed_batch, s_transformed_batch, msg=msg)


def _test_functional_op(f, device, channels=3, fn_kwargs=None, test_exact_match=True, **match_kwargs):
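    """Compare a functional op applied to a tensor against the same op applied to the equivalent PIL image."""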
    fn_kwargs = fn_kwargs or {}

    tensor, pil_img = _create_data(height=10, width=10, channels=channels, device=device)
    transformed_tensor = f(tensor, **fn_kwargs)
    transformed_pil_img = f(pil_img, **fn_kwargs)
    if test_exact_match:
        _assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
    else:
        _assert_approx_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)


def _test_class_op(method, device, channels=3, meth_kwargs=None, test_exact_match=True, **match_kwargs):
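    """Instantiate a transform class and check tensor vs. PIL output, eager vs. scripted execution,
    behaviour on a batch of tensors, and that the scripted module can be saved."""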
    # TODO: change the name: it's not a method, it's a class.
    meth_kwargs = meth_kwargs or {}

    # test for class interface
    f = method(**meth_kwargs)
    scripted_fn = torch.jit.script(f)

    tensor, pil_img = _create_data(26, 34, channels, device=device)
    # set seed to reproduce the same transformation for tensor and PIL image
    torch.manual_seed(12)
    transformed_tensor = f(tensor)
    torch.manual_seed(12)
    transformed_pil_img = f(pil_img)
    if test_exact_match:
        _assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img, **match_kwargs)
    else:
        _assert_approx_equal_tensor_to_pil(transformed_tensor.float(), transformed_pil_img, **match_kwargs)

    torch.manual_seed(12)
    transformed_tensor_script = scripted_fn(tensor)
    assert_equal(transformed_tensor, transformed_tensor_script)

    batch_tensors = _create_data_batch(height=23, width=34, channels=channels, num_samples=4, device=device)
    _test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors)

    with get_tmp_dir() as tmp_dir:
        scripted_fn.save(os.path.join(tmp_dir, f"t_{method.__name__}.pt"))


def _test_op(func, method, device, channels=3, fn_kwargs=None, meth_kwargs=None, test_exact_match=True, **match_kwargs):
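    """Run both the functional and the class-based checks for a functional op / transform pair."""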
    _test_functional_op(func, device, channels, fn_kwargs, test_exact_match=test_exact_match, **match_kwargs)
    _test_class_op(method, device, channels, meth_kwargs, test_exact_match=test_exact_match, **match_kwargs)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
    "func,method,fn_kwargs,match_kwargs",
    [
        (F.hflip, T.RandomHorizontalFlip, None, {}),
        (F.vflip, T.RandomVerticalFlip, None, {}),
        (F.invert, T.RandomInvert, None, {}),
        (F.posterize, T.RandomPosterize, {"bits": 4}, {}),
        (F.solarize, T.RandomSolarize, {"threshold": 192.0}, {}),
        (F.adjust_sharpness, T.RandomAdjustSharpness, {"sharpness_factor": 2.0}, {}),
        (
            F.autocontrast,
            T.RandomAutocontrast,
            None,
            {"test_exact_match": False, "agg_method": "max", "tol": (1 + 1e-5), "allowed_percentage_diff": 0.05},
        ),
        (F.equalize, T.RandomEqualize, None, {}),
    ],
)
@pytest.mark.parametrize("channels", [1, 3])
def test_random(func, method, device, channels, fn_kwargs, match_kwargs):
    _test_op(func, method, device, channels, fn_kwargs, fn_kwargs, **match_kwargs)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("channels", [1, 3])
class TestColorJitter:
    @pytest.mark.parametrize("brightness", [0.1, 0.5, 1.0, 1.34, (0.3, 0.7), [0.4, 0.5]])
    def test_color_jitter_brightness(self, brightness, device, channels):
        tol = 1.0 + 1e-10
        meth_kwargs = {"brightness": brightness}
        _test_class_op(
            T.ColorJitter,
            meth_kwargs=meth_kwargs,
            test_exact_match=False,
            device=device,
            tol=tol,
            agg_method="max",
            channels=channels,
        )

    @pytest.mark.parametrize("contrast", [0.2, 0.5, 1.0, 1.5, (0.3, 0.7), [0.4, 0.5]])
    def test_color_jitter_contrast(self, contrast, device, channels):
        tol = 1.0 + 1e-10
        meth_kwargs = {"contrast": contrast}
        _test_class_op(
            T.ColorJitter,
            meth_kwargs=meth_kwargs,
            test_exact_match=False,
            device=device,
            tol=tol,
            agg_method="max",
            channels=channels,
        )

    @pytest.mark.parametrize("saturation", [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]])
    def test_color_jitter_saturation(self, saturation, device, channels):
        tol = 1.0 + 1e-10
        meth_kwargs = {"saturation": saturation}
        _test_class_op(
            T.ColorJitter,
            meth_kwargs=meth_kwargs,
            test_exact_match=False,
            device=device,
            tol=tol,
            agg_method="max",
            channels=channels,
        )

    @pytest.mark.parametrize("hue", [0.2, 0.5, (-0.2, 0.3), [-0.4, 0.5]])
    def test_color_jitter_hue(self, hue, device, channels):
        meth_kwargs = {"hue": hue}
        _test_class_op(
            T.ColorJitter,
            meth_kwargs=meth_kwargs,
            test_exact_match=False,
            device=device,
            tol=16.1,
            agg_method="max",
            channels=channels,
        )

    def test_color_jitter_all(self, device, channels):
        # All 4 parameters together
        meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2}
        _test_class_op(
            T.ColorJitter,
            meth_kwargs=meth_kwargs,
            test_exact_match=False,
            device=device,
            tol=12.1,
            agg_method="max",
            channels=channels,
        )


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("m", ["constant", "edge", "reflect", "symmetric"])
@pytest.mark.parametrize("mul", [1, -1])
def test_pad(m, mul, device):
    fill = 127 if m == "constant" else 0

    # Test functional.pad (PIL and Tensor) with padding as single int
    _test_functional_op(F.pad, fn_kwargs={"padding": mul * 2, "fill": fill, "padding_mode": m}, device=device)
    # Test functional.pad and transforms.Pad with padding as [int, ]
    fn_kwargs = meth_kwargs = {
        "padding": [
            mul * 2,
        ],
        "fill": fill,
        "padding_mode": m,
    }
    _test_op(F.pad, T.Pad, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)
    # Test functional.pad and transforms.Pad with padding as list
    fn_kwargs = meth_kwargs = {"padding": [mul * 4, 4], "fill": fill, "padding_mode": m}
    _test_op(F.pad, T.Pad, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)
    # Test functional.pad and transforms.Pad with padding as tuple
    fn_kwargs = meth_kwargs = {"padding": (mul * 2, 2, 2, mul * 2), "fill": fill, "padding_mode": m}
    _test_op(F.pad, T.Pad, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)


@pytest.mark.parametrize("device", cpu_and_gpu())
def test_crop(device):
    fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5}
    # Test transforms.RandomCrop with size and padding as tuple
    meth_kwargs = {
        "size": (4, 5),
        "padding": (4, 4),
        "pad_if_needed": True,
    }
    _test_op(F.crop, T.RandomCrop, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)

    # Test transforms.functional.crop including outside the image area
    fn_kwargs = {"top": -2, "left": 3, "height": 4, "width": 5}  # top
    _test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=device)

    fn_kwargs = {"top": 1, "left": -3, "height": 4, "width": 5}  # left
    _test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=device)

    fn_kwargs = {"top": 7, "left": 3, "height": 4, "width": 5}  # bottom
    _test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=device)

    fn_kwargs = {"top": 3, "left": 8, "height": 4, "width": 5}  # right
    _test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=device)

    fn_kwargs = {"top": -3, "left": -3, "height": 15, "width": 15}  # all
    _test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=device)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
    "padding_config",
    [
        {"padding_mode": "constant", "fill": 0},
        {"padding_mode": "constant", "fill": 10},
        {"padding_mode": "constant", "fill": 20},
        {"padding_mode": "edge"},
        {"padding_mode": "reflect"},
    ],
)
@pytest.mark.parametrize(
    "size",
    [
        5,
        [
            5,
        ],
        [6, 6],
    ],
)
def test_crop_pad(size, padding_config, device):
    config = dict(padding_config)
    config["size"] = size
    _test_class_op(T.RandomCrop, device, meth_kwargs=config)


@pytest.mark.parametrize("device", cpu_and_gpu())
def test_center_crop(device, tmpdir):
    fn_kwargs = {"output_size": (4, 5)}
    meth_kwargs = {
        "size": (4, 5),
    }
    _test_op(F.center_crop, T.CenterCrop, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)
    fn_kwargs = {"output_size": (5,)}
    meth_kwargs = {"size": (5,)}
    _test_op(F.center_crop, T.CenterCrop, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs)
    tensor = torch.randint(0, 256, (3, 10, 10), dtype=torch.uint8, device=device)
    # Test torchscript of transforms.CenterCrop with size as int
    f = T.CenterCrop(size=5)
    scripted_fn = torch.jit.script(f)
    scripted_fn(tensor)

    # Test torchscript of transforms.CenterCrop with size as [int, ]
    f = T.CenterCrop(
        size=[
            5,
        ]
    )
    scripted_fn = torch.jit.script(f)
    scripted_fn(tensor)

    # Test torchscript of transforms.CenterCrop with size as tuple
    f = T.CenterCrop(size=(6, 6))
    scripted_fn = torch.jit.script(f)
    scripted_fn(tensor)

    scripted_fn.save(os.path.join(tmpdir, "t_center_crop.pt"))


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
    "fn, method, out_length",
    [
        # test_five_crop
        (F.five_crop, T.FiveCrop, 5),
        # test_ten_crop
        (F.ten_crop, T.TenCrop, 10),
    ],
)
@pytest.mark.parametrize(
    "size",
    [
        (5,),
        [
            5,
        ],
        (4, 5),
        [4, 5],
    ],
)
def test_x_crop(fn, method, out_length, size, device):
    meth_kwargs = fn_kwargs = {"size": size}
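    # Functional interface: the tensor output should match the PIL output, and the scripted
    # functional should match the eager one.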
    scripted_fn = torch.jit.script(fn)

    tensor, pil_img = _create_data(height=20, width=20, device=device)
    transformed_t_list = fn(tensor, **fn_kwargs)
    transformed_p_list = fn(pil_img, **fn_kwargs)
    assert len(transformed_t_list) == len(transformed_p_list)
    assert len(transformed_t_list) == out_length
    for transformed_tensor, transformed_pil_img in zip(transformed_t_list, transformed_p_list):
        _assert_equal_tensor_to_pil(transformed_tensor, transformed_pil_img)

    transformed_t_list_script = scripted_fn(tensor.detach().clone(), **fn_kwargs)
    assert len(transformed_t_list) == len(transformed_t_list_script)
    assert len(transformed_t_list_script) == out_length
    for transformed_tensor, transformed_tensor_script in zip(transformed_t_list, transformed_t_list_script):
        assert_equal(transformed_tensor, transformed_tensor_script)

    # test for class interface
    fn = method(**meth_kwargs)
    scripted_fn = torch.jit.script(fn)
    output = scripted_fn(tensor)
    assert len(output) == len(transformed_t_list_script)

    # test on batch of tensors
    batch_tensors = _create_data_batch(height=23, width=34, channels=3, num_samples=4, device=device)
    torch.manual_seed(12)
    transformed_batch_list = fn(batch_tensors)

    for i in range(len(batch_tensors)):
        img_tensor = batch_tensors[i, ...]
        torch.manual_seed(12)
        transformed_img_list = fn(img_tensor)
        for transformed_img, transformed_batch in zip(transformed_img_list, transformed_batch_list):
            assert_equal(transformed_img, transformed_batch[i, ...])


@pytest.mark.parametrize("method", ["FiveCrop", "TenCrop"])
def test_x_crop_save(method, tmpdir):
    fn = getattr(T, method)(
        size=[
            5,
        ]
    )
    scripted_fn = torch.jit.script(fn)
    scripted_fn.save(os.path.join(tmpdir, "t_op_list_{}.pt".format(method)))


class TestResize:
    @pytest.mark.parametrize("size", [32, 34, 35, 36, 38])
    def test_resize_int(self, size):
        # TODO: Minimal check for bug-fix, improve this later
        x = torch.rand(3, 32, 46)
        t = T.Resize(size=size)
        y = t(x)
        # If size is an int, smaller edge of the image will be matched to this number.
        # i.e., if height > width, then the image will be rescaled to (size * height / width, size).
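        # For the (3, 32, 46) input here, the height (32) is the smaller edge, so the output
        # shape should be (size, int(size * 46 / 32)).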
        assert isinstance(y, torch.Tensor)
        assert y.shape[1] == size
        assert y.shape[2] == int(size * 46 / 32)

    @pytest.mark.parametrize("device", cpu_and_gpu())
    @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64])
    @pytest.mark.parametrize(
        "size",
        [
            [
                32,
            ],
            [32, 32],
            (32, 32),
            [34, 35],
        ],
    )
    @pytest.mark.parametrize("max_size", [None, 35, 1000])
    @pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC, NEAREST])
    def test_resize_scripted(self, dt, size, max_size, interpolation, device):
        tensor, _ = _create_data(height=34, width=36, device=device)
        batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)

        if dt is not None:
            # Trivial cast of the uint8 data to float, just to cover the float dtype cases as well
            tensor = tensor.to(dt)
        if max_size is not None and len(size) != 1:
            pytest.xfail("with max_size, size must be a sequence with 2 elements")

        transform = T.Resize(size=size, interpolation=interpolation, max_size=max_size)
        s_transform = torch.jit.script(transform)
        _test_transform_vs_scripted(transform, s_transform, tensor)
        _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

    def test_resize_save(self, tmpdir):
        transform = T.Resize(
            size=[
                32,
            ]
        )
        s_transform = torch.jit.script(transform)
        s_transform.save(os.path.join(tmpdir, "t_resize.pt"))

    @pytest.mark.parametrize("device", cpu_and_gpu())
    @pytest.mark.parametrize("scale", [(0.7, 1.2), [0.7, 1.2]])
    @pytest.mark.parametrize("ratio", [(0.75, 1.333), [0.75, 1.333]])
    @pytest.mark.parametrize(
        "size",
        [
            (32,),
            [
                44,
            ],
            [
                32,
            ],
            [32, 32],
            (32, 32),
            [44, 55],
        ],
    )
    @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR, BICUBIC])
    def test_resized_crop(self, scale, ratio, size, interpolation, device):
        tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
        batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
        transform = T.RandomResizedCrop(size=size, scale=scale, ratio=ratio, interpolation=interpolation)
        s_transform = torch.jit.script(transform)
        _test_transform_vs_scripted(transform, s_transform, tensor)
        _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

    def test_resized_crop_save(self, tmpdir):
        transform = T.RandomResizedCrop(
            size=[
                32,
            ]
        )
        s_transform = torch.jit.script(transform)
        s_transform.save(os.path.join(tmpdir, "t_resized_crop.pt"))


def _test_random_affine_helper(device, **kwargs):
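    """Build a RandomAffine transform from the given kwargs and check eager vs. scripted parity on an image and a batch."""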
    tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
    batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
    transform = T.RandomAffine(**kwargs)
    s_transform = torch.jit.script(transform)

    _test_transform_vs_scripted(transform, s_transform, tensor)
    _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)


@pytest.mark.parametrize("device", cpu_and_gpu())
def test_random_affine(device, tmpdir):
    transform = T.RandomAffine(degrees=45.0)
    s_transform = torch.jit.script(transform)
    s_transform.save(os.path.join(tmpdir, "t_random_affine.pt"))


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
@pytest.mark.parametrize("shear", [15, 10.0, (5.0, 10.0), [-15, 15], [-10.0, 10.0, -11.0, 11.0]])
def test_random_affine_shear(device, interpolation, shear):
    _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, shear=shear)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
@pytest.mark.parametrize("scale", [(0.7, 1.2), [0.7, 1.2]])
def test_random_affine_scale(device, interpolation, scale):
    _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, scale=scale)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
@pytest.mark.parametrize("translate", [(0.1, 0.2), [0.2, 0.1]])
def test_random_affine_translate(device, interpolation, translate):
    _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, translate=translate)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
@pytest.mark.parametrize("degrees", [45, 35.0, (-45, 45), [-90.0, 90.0]])
def test_random_affine_degrees(device, interpolation, degrees):
    _test_random_affine_helper(device, degrees=degrees, interpolation=interpolation)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
@pytest.mark.parametrize(
    "fill",
    [
        85,
        (10, -10, 10),
        0.7,
        [0.0, 0.0, 0.0],
        [
            1,
        ],
        1,
    ],
)
def test_random_affine_fill(device, interpolation, fill):
    _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, fill=fill)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("center", [(0, 0), [10, 10], None, (56, 44)])
@pytest.mark.parametrize("expand", [True, False])
@pytest.mark.parametrize("degrees", [45, 35.0, (-45, 45), [-90.0, 90.0]])
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
@pytest.mark.parametrize(
    "fill",
    [
        85,
        (10, -10, 10),
        0.7,
        [0.0, 0.0, 0.0],
        [
            1,
        ],
        1,
    ],
)
def test_random_rotate(device, center, expand, degrees, interpolation, fill):
    tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
    batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)

    transform = T.RandomRotation(degrees=degrees, interpolation=interpolation, expand=expand, center=center, fill=fill)
    s_transform = torch.jit.script(transform)

    _test_transform_vs_scripted(transform, s_transform, tensor)
    _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)


def test_random_rotate_save(tmpdir):
    transform = T.RandomRotation(degrees=45.0)
    s_transform = torch.jit.script(transform)
    s_transform.save(os.path.join(tmpdir, "t_random_rotate.pt"))


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("distortion_scale", np.linspace(0.1, 1.0, num=20))
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR])
@pytest.mark.parametrize(
    "fill",
    [
        85,
        (10, -10, 10),
        0.7,
        [0.0, 0.0, 0.0],
        [
            1,
        ],
        1,
    ],
)
def test_random_perspective(device, distortion_scale, interpolation, fill):
    tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
    batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)

    transform = T.RandomPerspective(distortion_scale=distortion_scale, interpolation=interpolation, fill=fill)
    s_transform = torch.jit.script(transform)

    _test_transform_vs_scripted(transform, s_transform, tensor)
    _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)


def test_random_perspective_save(tmpdir):
    transform = T.RandomPerspective()
    s_transform = torch.jit.script(transform)
    s_transform.save(os.path.join(tmpdir, "t_perspective.pt"))


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
    "Klass, meth_kwargs",
    [(T.Grayscale, {"num_output_channels": 1}), (T.Grayscale, {"num_output_channels": 3}), (T.RandomGrayscale, {})],
)
def test_to_grayscale(device, Klass, meth_kwargs):
    tol = 1.0 + 1e-10
    _test_class_op(Klass, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, tol=tol, agg_method="max")


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("in_dtype", int_dtypes() + float_dtypes())
@pytest.mark.parametrize("out_dtype", int_dtypes() + float_dtypes())
def test_convert_image_dtype(device, in_dtype, out_dtype):
    tensor, _ = _create_data(26, 34, device=device)
    batch_tensors = torch.rand(4, 3, 44, 56, device=device)

    in_tensor = tensor.to(in_dtype)
    in_batch_tensors = batch_tensors.to(in_dtype)

    fn = T.ConvertImageDtype(dtype=out_dtype)
    scripted_fn = torch.jit.script(fn)

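    # These float -> int conversions are not safe (the float dtype cannot exactly represent the
    # integer dtype's maximum value), so ConvertImageDtype is expected to raise a RuntimeError.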
    if (in_dtype == torch.float32 and out_dtype in (torch.int32, torch.int64)) or (
        in_dtype == torch.float64 and out_dtype == torch.int64
    ):
        with pytest.raises(RuntimeError, match=r"cannot be performed safely"):
            _test_transform_vs_scripted(fn, scripted_fn, in_tensor)
        with pytest.raises(RuntimeError, match=r"cannot be performed safely"):
            _test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)
        return

    _test_transform_vs_scripted(fn, scripted_fn, in_tensor)
    _test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)


def test_convert_image_dtype_save(tmpdir):
    fn = T.ConvertImageDtype(dtype=torch.uint8)
    scripted_fn = torch.jit.script(fn)
    scripted_fn.save(os.path.join(tmpdir, "t_convert_dtype.pt"))


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("policy", [policy for policy in T.AutoAugmentPolicy])
@pytest.mark.parametrize(
    "fill",
    [
        None,
        85,
        (10, -10, 10),
        0.7,
        [0.0, 0.0, 0.0],
        [
            1,
        ],
        1,
    ],
)
def test_autoaugment(device, policy, fill):
    tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
    batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)

    transform = T.AutoAugment(policy=policy, fill=fill)
    s_transform = torch.jit.script(transform)
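    # Repeat a number of times so that different randomly sampled sub-policies get exercised.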
    for _ in range(25):
        _test_transform_vs_scripted(transform, s_transform, tensor)
        _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("num_ops", [1, 2, 3])
@pytest.mark.parametrize("magnitude", [7, 9, 11])
@pytest.mark.parametrize(
    "fill",
    [
        None,
        85,
        (10, -10, 10),
        0.7,
        [0.0, 0.0, 0.0],
        [
            1,
        ],
        1,
    ],
)
def test_randaugment(device, num_ops, magnitude, fill):
    tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
    batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)

    transform = T.RandAugment(num_ops=num_ops, magnitude=magnitude, fill=fill)
    s_transform = torch.jit.script(transform)
    for _ in range(25):
        _test_transform_vs_scripted(transform, s_transform, tensor)
        _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
    "fill",
    [
        None,
        85,
        (10, -10, 10),
        0.7,
        [0.0, 0.0, 0.0],
        [
            1,
        ],
        1,
    ],
)
def test_trivialaugmentwide(device, fill):
    tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
    batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)

    transform = T.TrivialAugmentWide(fill=fill)
    s_transform = torch.jit.script(transform)
    for _ in range(25):
        _test_transform_vs_scripted(transform, s_transform, tensor)
        _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)


@pytest.mark.parametrize("augmentation", [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide])
def test_autoaugment_save(augmentation, tmpdir):
    transform = augmentation()
    s_transform = torch.jit.script(transform)
    s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt"))


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
    "config",
    [{"value": 0.2}, {"value": "random"}, {"value": (0.2, 0.2, 0.2)}, {"value": "random", "ratio": (0.1, 0.2)}],
)
def test_random_erasing(device, config):
    tensor, _ = _create_data(24, 32, channels=3, device=device)
    batch_tensors = torch.rand(4, 3, 44, 56, device=device)

    fn = T.RandomErasing(**config)
    scripted_fn = torch.jit.script(fn)
    _test_transform_vs_scripted(fn, scripted_fn, tensor)
    _test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)


def test_random_erasing_save(tmpdir):
    fn = T.RandomErasing(value=0.2)
    scripted_fn = torch.jit.script(fn)
    scripted_fn.save(os.path.join(tmpdir, "t_random_erasing.pt"))


def test_random_erasing_with_invalid_data():
    img = torch.rand(3, 60, 60)
    # Test Set 0: invalid value
    random_erasing = T.RandomErasing(value=(0.1, 0.2, 0.3, 0.4), p=1.0)
    with pytest.raises(ValueError, match="If value is a sequence, it should have either a single value or 3"):
        random_erasing(img)


@pytest.mark.parametrize("device", cpu_and_gpu())
def test_normalize(device, tmpdir):
    fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    tensor, _ = _create_data(26, 34, device=device)

    with pytest.raises(TypeError, match="Input tensor should be a float tensor"):
        fn(tensor)

    batch_tensors = torch.rand(4, 3, 44, 56, device=device)
    tensor = tensor.to(dtype=torch.float32) / 255.0
    # test for class interface
    scripted_fn = torch.jit.script(fn)

    _test_transform_vs_scripted(fn, scripted_fn, tensor)
    _test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)

    scripted_fn.save(os.path.join(tmpdir, "t_norm.pt"))


@pytest.mark.parametrize("device", cpu_and_gpu())
def test_linear_transformation(device, tmpdir):
    c, h, w = 3, 24, 32

    tensor, _ = _create_data(h, w, channels=c, device=device)

    matrix = torch.rand(c * h * w, c * h * w, device=device)
    mean_vector = torch.rand(c * h * w, device=device)

    fn = T.LinearTransformation(matrix, mean_vector)
    scripted_fn = torch.jit.script(fn)

    _test_transform_vs_scripted(fn, scripted_fn, tensor)

    batch_tensors = torch.rand(4, c, h, w, device=device)
    # We skip the per-image comparisons from _test_transform_vs_scripted_on_batch, as
    # applying the transformation image-by-image and on the whole batch does not give
    # exactly the same results; only check eager vs. scripted output on the batch.
    torch.manual_seed(12)
    transformed_batch = fn(batch_tensors)
    torch.manual_seed(12)
    s_transformed_batch = scripted_fn(batch_tensors)
    assert_equal(transformed_batch, s_transformed_batch)

    scripted_fn.save(os.path.join(tmpdir, "t_norm.pt"))


@pytest.mark.parametrize("device", cpu_and_gpu())
def test_compose(device):
    tensor, _ = _create_data(26, 34, device=device)
    tensor = tensor.to(dtype=torch.float32) / 255.0
    transforms = T.Compose(
        [
            T.CenterCrop(10),
            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    )
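    # T.Compose itself is not scriptable; an nn.Sequential over the same transforms is the
    # scriptable equivalent.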
    s_transforms = torch.nn.Sequential(*transforms.transforms)

    scripted_fn = torch.jit.script(s_transforms)
    torch.manual_seed(12)
    transformed_tensor = transforms(tensor)
    torch.manual_seed(12)
    transformed_tensor_script = scripted_fn(tensor)
    assert_equal(transformed_tensor, transformed_tensor_script, msg="{}".format(transforms))

    t = T.Compose(
        [
            lambda x: x,
        ]
    )
    with pytest.raises(RuntimeError, match="cannot call a value of type 'Tensor'"):
        torch.jit.script(t)


@pytest.mark.parametrize("device", cpu_and_gpu())
def test_random_apply(device):
    tensor, _ = _create_data(26, 34, device=device)
    tensor = tensor.to(dtype=torch.float32) / 255.0

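    # A plain Python list of transforms works in eager mode, but scripting RandomApply requires
    # the transforms to be wrapped in an nn.ModuleList (see the check at the end of this test).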
    transforms = T.RandomApply(
        [
            T.RandomHorizontalFlip(),
            T.ColorJitter(),
        ],
        p=0.4,
    )
    s_transforms = T.RandomApply(
        torch.nn.ModuleList(
            [
                T.RandomHorizontalFlip(),
                T.ColorJitter(),
            ]
        ),
        p=0.4,
    )

    scripted_fn = torch.jit.script(s_transforms)
    torch.manual_seed(12)
    transformed_tensor = transforms(tensor)
    torch.manual_seed(12)
    transformed_tensor_script = scripted_fn(tensor)
    assert_equal(transformed_tensor, transformed_tensor_script, msg="{}".format(transforms))

    if device == "cpu":
        # Can't check this twice, otherwise
        # "Can't redefine method: forward on class: __torch__.torchvision.transforms.transforms.RandomApply"
        transforms = T.RandomApply(
            [
                T.ColorJitter(),
            ],
            p=0.3,
        )
        with pytest.raises(RuntimeError, match="Module 'RandomApply' has no attribute 'transforms'"):
            torch.jit.script(transforms)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
    "meth_kwargs",
    [
        {"kernel_size": 3, "sigma": 0.75},
        {"kernel_size": 23, "sigma": [0.1, 2.0]},
        {"kernel_size": 23, "sigma": (0.1, 2.0)},
        {"kernel_size": [3, 3], "sigma": (1.0, 1.0)},
        {"kernel_size": (3, 3), "sigma": (0.1, 2.0)},
        {"kernel_size": [23], "sigma": 0.75},
    ],
)
@pytest.mark.parametrize("channels", [1, 3])
def test_gaussian_blur(device, channels, meth_kwargs):
    tol = 1.0 + 1e-10
    torch.manual_seed(12)
    _test_class_op(
        T.GaussianBlur,
        meth_kwargs=meth_kwargs,
        channels=channels,
        test_exact_match=False,
        device=device,
        agg_method="max",
        tol=tol,
    )