"docs/vscode:/vscode.git/clone" did not exist on "1d3dcdbd70fa53cf45a55334a70f9fc46c7e4bdb"
test_transforms_v2_functional.py 41.1 KB
Newer Older
1
import inspect
2
import math
3
import os
4
import re
5

6
import numpy as np
7
import PIL.Image
8
import pytest
9
import torch
10

11
from common_utils import (
12
    assert_close,
13
    cache,
14
    cpu_and_cuda,
15
16
    DEFAULT_SQUARE_SPATIAL_SIZE,
    make_bounding_boxes,
17
    needs_cuda,
18
    parametrized_error_message,
19
    set_rng_seed,
20
)
21
from torch.utils._pytree import tree_map
22
from torchvision import datapoints
23
from torchvision.transforms.functional import _get_perspective_coeffs
24
25
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
26
from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes
27
from torchvision.transforms.v2.utils import is_simple_tensor
28
29
from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
from transforms_v2_kernel_infos import KERNEL_INFOS
30
31


KERNEL_INFOS_MAP = {info.kernel: info for info in KERNEL_INFOS}
DISPATCHER_INFOS_MAP = {info.dispatcher: info for info in DISPATCHER_INFOS}


@cache
def script(fn):
    try:
        return torch.jit.script(fn)
    except Exception as error:
        raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error


# Scripting a function often triggers a warning like
# `UserWarning: operator() profile_node %$INT1 : int[] = prim::profile_ivalue($INT2) does not have profile information`
# with varying `INT1` and `INT2`. Since these are uninteresting for us and only clutter the test summary, we ignore
# them.
ignore_jit_warning_no_profile = pytest.mark.filterwarnings(
    f"ignore:{re.escape('operator() profile_node %')}:UserWarning"
)


def make_info_args_kwargs_params(info, *, args_kwargs_fn, test_id=None):
    args_kwargs = list(args_kwargs_fn(info))
    if not args_kwargs:
        raise pytest.UsageError(
            f"Couldn't collect a single `ArgsKwargs` for `{info.id}`{f' in {test_id}' if test_id else ''}"
        )
    idx_field_len = len(str(len(args_kwargs)))
    return [
        pytest.param(
            info,
            args_kwargs_,
            marks=info.get_marks(test_id, args_kwargs_) if test_id else [],
            id=f"{info.id}-{idx:0{idx_field_len}}",
        )
        for idx, args_kwargs_ in enumerate(args_kwargs)
    ]


def make_info_args_kwargs_parametrization(infos, *, args_kwargs_fn):
    def decorator(test_fn):
        parts = test_fn.__qualname__.split(".")
        if len(parts) == 1:
            test_class_name = None
            test_function_name = parts[0]
        elif len(parts) == 2:
            test_class_name, test_function_name = parts
        else:
            raise pytest.UsageError("Unable to parse the test class name and test function name from test function")
        test_id = (test_class_name, test_function_name)

        argnames = ("info", "args_kwargs")
        argvalues = []
        for info in infos:
            argvalues.extend(make_info_args_kwargs_params(info, args_kwargs_fn=args_kwargs_fn, test_id=test_id))

        return pytest.mark.parametrize(argnames, argvalues)(test_fn)

    return decorator


@pytest.fixture(autouse=True)
def fix_rng_seed():
    set_rng_seed(0)
    yield


@pytest.fixture()
def test_id(request):
    test_class_name = request.cls.__name__ if request.cls is not None else None
    test_function_name = request.node.originalname
    return test_class_name, test_function_name


class TestKernels:
    sample_inputs = make_info_args_kwargs_parametrization(
        KERNEL_INFOS,
        args_kwargs_fn=lambda kernel_info: kernel_info.sample_inputs_fn(),
    )
    reference_inputs = make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.reference_fn is not None],
        args_kwargs_fn=lambda info: info.reference_inputs_fn(),
    )

    @make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.logs_usage],
        args_kwargs_fn=lambda info: info.sample_inputs_fn(),
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_logging(self, spy_on, info, args_kwargs, device):
        spy = spy_on(torch._C._log_api_usage_once)

        (input, *other_args), kwargs = args_kwargs.load(device)
        info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)

        spy.assert_any_call(f"{info.kernel.__module__}.{info.id}")

    @ignore_jit_warning_no_profile
    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_scripted_vs_eager(self, test_id, info, args_kwargs, device):
        kernel_eager = info.kernel
        kernel_scripted = script(kernel_eager)

        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        actual = kernel_scripted(input, *other_args, **kwargs)
        expected = kernel_eager(input, *other_args, **kwargs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )

    def _unbatch(self, batch, *, data_dims):
        if isinstance(batch, torch.Tensor):
            batched_tensor = batch
            metadata = ()
        else:
            batched_tensor, *metadata = batch

        if batched_tensor.ndim == data_dims:
            return batch

        return [
            self._unbatch(unbatched, data_dims=data_dims)
            for unbatched in (
                batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
            )
        ]
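
    # `_unbatch` splits every batch dimension recursively: e.g. a batched image of shape (2, 4, 3, 32, 32)
    # with data_dims=3 yields a nested list of shape [2][4] whose leaves are (3, 32, 32) tensors.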

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_batched_vs_single(self, test_id, info, args_kwargs, device):
        (batched_input, *other_args), kwargs = args_kwargs.load(device)

        datapoint_type = datapoints.Image if is_simple_tensor(batched_input) else type(batched_input)
        # This dictionary contains the number of rightmost dimensions that contain the actual data.
        # Everything to the left is considered a batch dimension.
        data_dims = {
            datapoints.Image: 3,
            datapoints.BoundingBoxes: 1,
            # `Mask`'s are special in the sense that the data dimensions depend on the type of mask. For detection
            # masks it is 3 `(*, N, H, W)`, but for segmentation masks it is 2 `(*, H, W)`. Since both are grouped
            # under one type, all kernels should work without differentiating between the two. Thus, we go with 2
            # here as common ground.
            datapoints.Mask: 2,
            datapoints.Video: 4,
        }.get(datapoint_type)
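        # For example, a batched image of shape (2, 4, 3, 32, 32) has data_dims=3 (C, H, W) and thus
        # batch dimensions (2, 4).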
        if data_dims is None:
            raise pytest.UsageError(
                f"The number of data dimensions cannot be determined for input of type {datapoint_type.__name__}."
            ) from None
        elif batched_input.ndim <= data_dims:
            pytest.skip("Input is not batched.")
        elif not all(batched_input.shape[:-data_dims]):
            pytest.skip("Input has a degenerate batch shape.")

        batched_input = batched_input.as_subclass(torch.Tensor)
        batched_output = info.kernel(batched_input, *other_args, **kwargs)
        actual = self._unbatch(batched_output, data_dims=data_dims)

        single_inputs = self._unbatch(batched_input, data_dims=data_dims)
        expected = tree_map(lambda single_input: info.kernel(single_input, *other_args, **kwargs), single_inputs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=batched_input.dtype, device=batched_input.device),
            msg=parametrized_error_message(batched_input, *other_args, **kwargs),
        )

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_no_inplace(self, info, args_kwargs, device):
        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        if input.numel() == 0:
            pytest.skip("The input has a degenerate shape.")

        input_version = input._version
        info.kernel(input, *other_args, **kwargs)

        assert input._version == input_version

    @sample_inputs
    @needs_cuda
    def test_cuda_vs_cpu(self, test_id, info, args_kwargs):
        (input_cpu, *other_args), kwargs = args_kwargs.load("cpu")
        input_cpu = input_cpu.as_subclass(torch.Tensor)
        input_cuda = input_cpu.to("cuda")

        output_cpu = info.kernel(input_cpu, *other_args, **kwargs)
        output_cuda = info.kernel(input_cuda, *other_args, **kwargs)

        assert_close(
            output_cuda,
            output_cpu,
            check_device=False,
            **info.get_closeness_kwargs(test_id, dtype=input_cuda.dtype, device=input_cuda.device),
            msg=parametrized_error_message(input_cpu, *other_args, **kwargs),
        )

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_dtype_and_device_consistency(self, info, args_kwargs, device):
        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        output = info.kernel(input, *other_args, **kwargs)
        # Most kernels just return a tensor, but some also return some additional metadata
        if not isinstance(output, torch.Tensor):
            output, *_ = output

        assert output.dtype == input.dtype
        assert output.device == input.device

    @reference_inputs
    def test_against_reference(self, test_id, info, args_kwargs):
        (input, *other_args), kwargs = args_kwargs.load("cpu")

        actual = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)
        # We intentionally don't unwrap the input of the reference function in order for it to have access to all
        # metadata regardless of whether the kernel takes it explicitly or not
        expected = info.reference_fn(input, *other_args, **kwargs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )

    @make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.float32_vs_uint8],
        args_kwargs_fn=lambda info: info.reference_inputs_fn(),
    )
    def test_float32_vs_uint8(self, test_id, info, args_kwargs):
        (input, *other_args), kwargs = args_kwargs.load("cpu")
        input = input.as_subclass(torch.Tensor)

        if input.dtype != torch.uint8:
            pytest.skip(f"Input dtype is {input.dtype}.")

        adapted_other_args, adapted_kwargs = info.float32_vs_uint8(other_args, kwargs)

        actual = info.kernel(
            F.to_dtype_image_tensor(input, dtype=torch.float32, scale=True),
            *adapted_other_args,
            **adapted_kwargs,
        )

        expected = F.to_dtype_image_tensor(info.kernel(input, *other_args, **kwargs), dtype=torch.float32, scale=True)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=torch.float32, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )


@pytest.fixture
def spy_on(mocker):
    def make_spy(fn, *, module=None, name=None):
        # TODO: we can probably get rid of the non-default modules and names if we eliminate aliasing
        module = module or fn.__module__
        name = name or fn.__name__
        spy = mocker.patch(f"{module}.{name}", wraps=fn)
        return spy

    return make_spy


class TestDispatchers:
    image_sample_inputs = make_info_args_kwargs_parametrization(
        [info for info in DISPATCHER_INFOS if datapoints.Image in info.kernels],
        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
    )

    @make_info_args_kwargs_parametrization(
        DISPATCHER_INFOS,
        args_kwargs_fn=lambda info: info.sample_inputs(),
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_logging(self, spy_on, info, args_kwargs, device):
        spy = spy_on(torch._C._log_api_usage_once)

        args, kwargs = args_kwargs.load(device)
        info.dispatcher(*args, **kwargs)

        spy.assert_any_call(f"{info.dispatcher.__module__}.{info.id}")

    @ignore_jit_warning_no_profile
    @image_sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_scripted_smoke(self, info, args_kwargs, device):
        dispatcher = script(info.dispatcher)

        (image_datapoint, *other_args), kwargs = args_kwargs.load(device)
        image_simple_tensor = torch.Tensor(image_datapoint)

        dispatcher(image_simple_tensor, *other_args, **kwargs)

    # TODO: We need this until the dispatchers below also have `DispatcherInfo`'s. If they do, `test_scripted_smoke`
    #  replaces this test for them.
    @ignore_jit_warning_no_profile
    @pytest.mark.parametrize(
        "dispatcher",
        [
            F.get_dimensions,
            F.get_image_num_channels,
            F.get_image_size,
            F.get_num_channels,
            F.get_num_frames,
            F.get_size,
            F.rgb_to_grayscale,
            F.uniform_temporal_subsample,
        ],
        ids=lambda dispatcher: dispatcher.__name__,
    )
    def test_scriptable(self, dispatcher):
        script(dispatcher)

    @image_sample_inputs
    def test_simple_tensor_output_type(self, info, args_kwargs):
        (image_datapoint, *other_args), kwargs = args_kwargs.load()
        image_simple_tensor = image_datapoint.as_subclass(torch.Tensor)

        output = info.dispatcher(image_simple_tensor, *other_args, **kwargs)

        # We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well
        assert type(output) is torch.Tensor

    @make_info_args_kwargs_parametrization(
        [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
    )
    def test_pil_output_type(self, info, args_kwargs):
        (image_datapoint, *other_args), kwargs = args_kwargs.load()

        if image_datapoint.ndim > 3:
            pytest.skip("Input is batched")

        image_pil = F.to_image_pil(image_datapoint)

        output = info.dispatcher(image_pil, *other_args, **kwargs)

        assert isinstance(output, PIL.Image.Image)

    @make_info_args_kwargs_parametrization(
        DISPATCHER_INFOS,
        args_kwargs_fn=lambda info: info.sample_inputs(),
    )
    def test_datapoint_output_type(self, info, args_kwargs):
        (datapoint, *other_args), kwargs = args_kwargs.load()

        output = info.dispatcher(datapoint, *other_args, **kwargs)

        assert isinstance(output, type(datapoint))

        if isinstance(datapoint, datapoints.BoundingBoxes) and info.dispatcher is not F.convert_format_bounding_boxes:
            assert output.format == datapoint.format

    @pytest.mark.parametrize(
        ("dispatcher_info", "datapoint_type", "kernel_info"),
        [
            pytest.param(
                dispatcher_info, datapoint_type, kernel_info, id=f"{dispatcher_info.id}-{datapoint_type.__name__}"
            )
            for dispatcher_info in DISPATCHER_INFOS
            for datapoint_type, kernel_info in dispatcher_info.kernel_infos.items()
        ],
    )
    def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, datapoint_type, kernel_info):
        dispatcher_signature = inspect.signature(dispatcher_info.dispatcher)
        dispatcher_params = list(dispatcher_signature.parameters.values())[1:]

        kernel_signature = inspect.signature(kernel_info.kernel)
        kernel_params = list(kernel_signature.parameters.values())[1:]

        # We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be
        # explicitly passed to the kernel.
        input_type = {v: k for k, v in dispatcher_info.kernels.items()}.get(kernel_info.kernel)
        explicit_metadata = {
            datapoints.BoundingBoxes: {"format", "canvas_size"},
        }
        kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]

        dispatcher_params = iter(dispatcher_params)
        for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params):
            try:
                # In general, the dispatcher parameters are a superset of the kernel parameters. Thus, we filter out
                # dispatcher parameters that have no kernel equivalent while keeping the order intact.
                while dispatcher_param.name != kernel_param.name:
                    dispatcher_param = next(dispatcher_params)
            except StopIteration:
                raise AssertionError(
                    f"Parameter `{kernel_param.name}` of kernel `{kernel_info.id}` "
                    f"has no corresponding parameter on the dispatcher `{dispatcher_info.id}`."
                ) from None

            assert dispatcher_param == kernel_param

    @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
    def test_unknown_type(self, info):
        unknown_input = object()
        (_, *other_args), kwargs = next(iter(info.sample_inputs())).load("cpu")

        with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
            info.dispatcher(unknown_input, *other_args, **kwargs)

    @make_info_args_kwargs_parametrization(
        [
            info
            for info in DISPATCHER_INFOS
            if datapoints.BoundingBoxes in info.kernels and info.dispatcher is not F.convert_format_bounding_boxes
        ],
        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.BoundingBoxes),
    )
    def test_bounding_boxes_format_consistency(self, info, args_kwargs):
        (bounding_boxes, *other_args), kwargs = args_kwargs.load()
        format = bounding_boxes.format

        output = info.dispatcher(bounding_boxes, *other_args, **kwargs)

        assert output.format == format


@pytest.mark.parametrize(
    ("alias", "target"),
    [
        pytest.param(alias, target, id=alias.__name__)
        for alias, target in [
            (F.hflip, F.horizontal_flip),
            (F.vflip, F.vertical_flip),
            (F.get_image_num_channels, F.get_num_channels),
            (F.to_pil_image, F.to_image_pil),
            (F.elastic_transform, F.elastic),
            (F.to_grayscale, F.rgb_to_grayscale),
        ]
    ],
)
def test_alias(alias, target):
    assert alias is target


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("num_channels", [1, 3])
def test_normalize_image_tensor_stats(device, num_channels):
    stats = pytest.importorskip("scipy.stats", reason="SciPy is not available")

    def assert_samples_from_standard_normal(t):
        p_value = stats.kstest(t.flatten(), cdf="norm", args=(0, 1)).pvalue
        assert p_value > 1e-4

    image = torch.rand(num_channels, DEFAULT_SQUARE_SPATIAL_SIZE, DEFAULT_SQUARE_SPATIAL_SIZE)
    mean = image.mean(dim=(1, 2)).tolist()
    std = image.std(dim=(1, 2)).tolist()

    assert_samples_from_standard_normal(F.normalize_image_tensor(image, mean, std))


class TestClampBoundingBoxes:
    @pytest.mark.parametrize(
        "metadata",
        [
            dict(),
            dict(format=datapoints.BoundingBoxFormat.XYXY),
            dict(canvas_size=(1, 1)),
        ],
    )
    def test_simple_tensor_insufficient_metadata(self, metadata):
        simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)

        with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` has to be passed")):
            F.clamp_bounding_boxes(simple_tensor, **metadata)

    @pytest.mark.parametrize(
        "metadata",
        [
            dict(format=datapoints.BoundingBoxFormat.XYXY),
            dict(canvas_size=(1, 1)),
            dict(format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(1, 1)),
        ],
    )
    def test_datapoint_explicit_metadata(self, metadata):
        datapoint = next(make_bounding_boxes())

        with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` must not be passed")):
            F.clamp_bounding_boxes(datapoint, **metadata)


class TestConvertFormatBoundingBoxes:
    @pytest.mark.parametrize(
        ("inpt", "old_format"),
        [
            (next(make_bounding_boxes()), None),
            (next(make_bounding_boxes()).as_subclass(torch.Tensor), datapoints.BoundingBoxFormat.XYXY),
        ],
    )
    def test_missing_new_format(self, inpt, old_format):
        with pytest.raises(TypeError, match=re.escape("missing 1 required argument: 'new_format'")):
            F.convert_format_bounding_boxes(inpt, old_format)

    def test_simple_tensor_insufficient_metadata(self):
        simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)

        with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")):
            F.convert_format_bounding_boxes(simple_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)

    def test_datapoint_explicit_metadata(self):
        datapoint = next(make_bounding_boxes())

        with pytest.raises(ValueError, match=re.escape("`old_format` must not be passed")):
            F.convert_format_bounding_boxes(
                datapoint, old_format=datapoint.format, new_format=datapoints.BoundingBoxFormat.CXCYWH
            )


# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
#  `transforms_v2_kernel_infos.py`


def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_):
    rot = math.radians(angle_)
    cx, cy = center_
    tx, ty = translate_
    sx, sy = [math.radians(sh_) for sh_ in shear_]

    c_matrix = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
    t_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]])
    c_matrix_inv = np.linalg.inv(c_matrix)
    rs_matrix = np.array(
        [
            [scale_ * math.cos(rot), -scale_ * math.sin(rot), 0],
            [scale_ * math.sin(rot), scale_ * math.cos(rot), 0],
            [0, 0, 1],
        ]
    )
    shear_x_matrix = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]])
    shear_y_matrix = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]])
    rss_matrix = np.matmul(rs_matrix, np.matmul(shear_y_matrix, shear_x_matrix))
    true_matrix = np.matmul(t_matrix, np.matmul(c_matrix, np.matmul(rss_matrix, c_matrix_inv)))
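    # The composition is T @ C @ RSS @ C^-1, i.e. rotation, scale, and shear are applied around `center_`.
    # For the identity parameters (angle=0, translate=(0, 0), scale=1, shear=(0, 0)) this reduces to the
    # 3x3 identity matrix.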
    return true_matrix


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "format",
    [datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.CXCYWH],
)
@pytest.mark.parametrize(
    "top, left, height, width, expected_bboxes",
    [
        [8, 12, 30, 40, [(-2.0, 7.0, 13.0, 27.0), (38.0, -3.0, 58.0, 14.0), (33.0, 38.0, 44.0, 54.0)]],
        [-8, 12, 70, 40, [(-2.0, 23.0, 13.0, 43.0), (38.0, 13.0, 58.0, 30.0), (33.0, 54.0, 44.0, 70.0)]],
    ],
)
def test_correctness_crop_bounding_boxes(device, format, top, left, height, width, expected_bboxes):

    # Expected bboxes computed using Albumentations:
    # import numpy as np
    # from albumentations.augmentations.crops.functional import crop_bbox_by_coords, normalize_bbox, denormalize_bbox
    # expected_bboxes = []
    # for in_box in in_boxes:
    #     n_in_box = normalize_bbox(in_box, *size)
    #     n_out_box = crop_bbox_by_coords(
    #         n_in_box, (left, top, left + width, top + height), height, width, *size
    #     )
    #     out_box = denormalize_bbox(n_out_box, height, width)
    #     expected_bboxes.append(out_box)

    format = datapoints.BoundingBoxFormat.XYXY
    canvas_size = (64, 76)
    in_boxes = [
        [10.0, 15.0, 25.0, 35.0],
        [50.0, 5.0, 70.0, 22.0],
        [45.0, 46.0, 56.0, 62.0],
    ]
    in_boxes = torch.tensor(in_boxes, device=device)
    if format != datapoints.BoundingBoxFormat.XYXY:
        in_boxes = convert_format_bounding_boxes(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)

    expected_bboxes = clamp_bounding_boxes(
        datapoints.BoundingBoxes(expected_bboxes, format="XYXY", canvas_size=canvas_size)
    ).tolist()

    output_boxes, output_canvas_size = F.crop_bounding_boxes(
        in_boxes,
        format,
        top,
        left,
        canvas_size[0],
        canvas_size[1],
    )

    if format != datapoints.BoundingBoxFormat.XYXY:
        output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)

    torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
    torch.testing.assert_close(output_canvas_size, canvas_size)


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
    mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
    mask[:, 0, :] = 1

    out_mask = F.vertical_flip_mask(mask)

    expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
    expected_mask[:, -1, :] = 1
    torch.testing.assert_close(out_mask, expected_mask)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "format",
    [datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.CXCYWH],
)
@pytest.mark.parametrize(
    "top, left, height, width, size",
    [
        [0, 0, 30, 30, (60, 60)],
        [-5, 5, 35, 45, (32, 34)],
    ],
)
def test_correctness_resized_crop_bounding_boxes(device, format, top, left, height, width, size):
    def _compute_expected_bbox(bbox, top_, left_, height_, width_, size_):
        # bbox should be xyxy
        bbox[0] = (bbox[0] - left_) * size_[1] / width_
        bbox[1] = (bbox[1] - top_) * size_[0] / height_
        bbox[2] = (bbox[2] - left_) * size_[1] / width_
        bbox[3] = (bbox[3] - top_) * size_[0] / height_
        return bbox

    format = datapoints.BoundingBoxFormat.XYXY
    canvas_size = (100, 100)
    in_boxes = [
        [10.0, 10.0, 20.0, 20.0],
        [5.0, 10.0, 15.0, 20.0],
    ]
    expected_bboxes = []
    for in_box in in_boxes:
        expected_bboxes.append(_compute_expected_bbox(list(in_box), top, left, height, width, size))
    expected_bboxes = torch.tensor(expected_bboxes, device=device)

    in_boxes = datapoints.BoundingBoxes(
        in_boxes, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=canvas_size, device=device
    )
    if format != datapoints.BoundingBoxFormat.XYXY:
        in_boxes = convert_format_bounding_boxes(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)

    output_boxes, output_canvas_size = F.resized_crop_bounding_boxes(in_boxes, format, top, left, height, width, size)

    if format != datapoints.BoundingBoxFormat.XYXY:
        output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)

    torch.testing.assert_close(output_boxes, expected_bboxes)
    torch.testing.assert_close(output_canvas_size, size)


def _parse_padding(padding):
    if isinstance(padding, int):
        return [padding] * 4
    if isinstance(padding, list):
        if len(padding) == 1:
            return padding * 4
        if len(padding) == 2:
            return padding * 2  # [left, up, right, down]

    return padding
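    # Overall: 1 -> [1, 1, 1, 1]; [1] -> [1, 1, 1, 1]; [1, 2] -> [1, 2, 1, 2] (left/up repeated as
    # right/down); a 4-element list passes through unchanged.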


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("padding", [[1], [1, 1], [1, 1, 2, 2]])
def test_correctness_pad_bounding_boxes(device, padding):
    def _compute_expected_bbox(bbox, padding_):
        pad_left, pad_up, _, _ = _parse_padding(padding_)

        dtype = bbox.dtype
        format = bbox.format
        bbox = (
            bbox.clone()
            if format == datapoints.BoundingBoxFormat.XYXY
            else convert_format_bounding_boxes(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
        )

        bbox[0::2] += pad_left
        bbox[1::2] += pad_up

        bbox = convert_format_bounding_boxes(bbox, new_format=format)
        if bbox.dtype != dtype:
            # Cast back to the original dtype, e.g. float32 -> int
            bbox = bbox.to(dtype)
        return bbox

    def _compute_expected_canvas_size(bbox, padding_):
        pad_left, pad_up, pad_right, pad_down = _parse_padding(padding_)
        height, width = bbox.canvas_size
        return height + pad_up + pad_down, width + pad_left + pad_right

    for bboxes in make_bounding_boxes():
        bboxes = bboxes.to(device)
        bboxes_format = bboxes.format
        bboxes_canvas_size = bboxes.canvas_size

        output_boxes, output_canvas_size = F.pad_bounding_boxes(
            bboxes, format=bboxes_format, canvas_size=bboxes_canvas_size, padding=padding
        )

        torch.testing.assert_close(output_canvas_size, _compute_expected_canvas_size(bboxes, padding))

        if bboxes.ndim < 2 or bboxes.shape[0] == 0:
            bboxes = [bboxes]

        expected_bboxes = []
        for bbox in bboxes:
            bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, canvas_size=bboxes_canvas_size)
            expected_bboxes.append(_compute_expected_bbox(bbox, padding))

        if len(expected_bboxes) > 1:
            expected_bboxes = torch.stack(expected_bboxes)
        else:
            expected_bboxes = expected_bboxes[0]
        torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_pad_segmentation_mask_on_fixed_input(device):
    mask = torch.ones((1, 3, 3), dtype=torch.long, device=device)

    out_mask = F.pad_mask(mask, padding=[1, 1, 1, 1])

    expected_mask = torch.zeros((1, 5, 5), dtype=torch.long, device=device)
    expected_mask[:, 1:-1, 1:-1] = 1
    torch.testing.assert_close(out_mask, expected_mask)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "startpoints, endpoints",
    [
        [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
    ],
)
def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
    def _compute_expected_bbox(bbox, pcoeffs_):
        m1 = np.array(
            [
                [pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]],
                [pcoeffs_[3], pcoeffs_[4], pcoeffs_[5]],
            ]
        )
        m2 = np.array(
            [
                [pcoeffs_[6], pcoeffs_[7], 1.0],
                [pcoeffs_[6], pcoeffs_[7], 1.0],
            ]
        )

        bbox_xyxy = convert_format_bounding_boxes(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
        points = np.array(
            [
                [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
            ]
        )
        numer = np.matmul(points, m1.T)
        denom = np.matmul(points, m2.T)
        transformed_points = numer / denom
        out_bbox = np.array(
            [
                np.min(transformed_points[:, 0]),
                np.min(transformed_points[:, 1]),
                np.max(transformed_points[:, 0]),
                np.max(transformed_points[:, 1]),
            ]
        )
        out_bbox = datapoints.BoundingBoxes(
            out_bbox,
            format=datapoints.BoundingBoxFormat.XYXY,
            canvas_size=bbox.canvas_size,
            dtype=bbox.dtype,
            device=bbox.device,
        )
        return clamp_bounding_boxes(convert_format_bounding_boxes(out_bbox, new_format=bbox.format))

    canvas_size = (32, 38)

    pcoeffs = _get_perspective_coeffs(startpoints, endpoints)
    inv_pcoeffs = _get_perspective_coeffs(endpoints, startpoints)

    for bboxes in make_bounding_boxes(spatial_size=canvas_size, extra_dims=((4,),)):
        bboxes = bboxes.to(device)

        output_bboxes = F.perspective_bounding_boxes(
            bboxes.as_subclass(torch.Tensor),
            format=bboxes.format,
            canvas_size=bboxes.canvas_size,
            startpoints=None,
            endpoints=None,
            coefficients=pcoeffs,
        )

        if bboxes.ndim < 2:
            bboxes = [bboxes]

        expected_bboxes = []
        for bbox in bboxes:
            bbox = datapoints.BoundingBoxes(bbox, format=bboxes.format, canvas_size=bboxes.canvas_size)
            expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs))
        if len(expected_bboxes) > 1:
            expected_bboxes = torch.stack(expected_bboxes)
        else:
            expected_bboxes = expected_bboxes[0]
        torch.testing.assert_close(output_bboxes, expected_bboxes, rtol=0, atol=1)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "output_size",
    [(18, 18), [18, 15], (16, 19), [12], [46, 48]],
)
def test_correctness_center_crop_bounding_boxes(device, output_size):
867
868
    def _compute_expected_bbox(bbox, output_size_):
        format_ = bbox.format
Philip Meier's avatar
Philip Meier committed
869
        canvas_size_ = bbox.canvas_size
870
        dtype = bbox.dtype
871
        bbox = convert_format_bounding_boxes(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH)
872
873
874
875

        if len(output_size_) == 1:
            output_size_.append(output_size_[-1])

        cy = int(round((canvas_size_[0] - output_size_[0]) * 0.5))
        cx = int(round((canvas_size_[1] - output_size_[1]) * 0.5))
        out_bbox = [
            bbox[0].item() - cx,
            bbox[1].item() - cy,
            bbox[2].item(),
            bbox[3].item(),
        ]
        out_bbox = torch.tensor(out_bbox)
        out_bbox = convert_format_bounding_boxes(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_)
        out_bbox = clamp_bounding_boxes(out_bbox, format=format_, canvas_size=output_size)
        return out_bbox.to(dtype=dtype, device=bbox.device)

    for bboxes in make_bounding_boxes(extra_dims=((4,),)):
        bboxes = bboxes.to(device)
        bboxes_format = bboxes.format
        bboxes_canvas_size = bboxes.canvas_size

        output_boxes, output_canvas_size = F.center_crop_bounding_boxes(
            bboxes, bboxes_format, bboxes_canvas_size, output_size
        )

        if bboxes.ndim < 2:
            bboxes = [bboxes]

        expected_bboxes = []
        for bbox in bboxes:
            bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, canvas_size=bboxes_canvas_size)
            expected_bboxes.append(_compute_expected_bbox(bbox, output_size))

        if len(expected_bboxes) > 1:
            expected_bboxes = torch.stack(expected_bboxes)
        else:
            expected_bboxes = expected_bboxes[0]

        torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)
        torch.testing.assert_close(output_canvas_size, output_size)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("output_size", [[4, 2], [4], [7, 6]])
def test_correctness_center_crop_mask(device, output_size):
    def _compute_expected_mask(mask, output_size):
        crop_height, crop_width = output_size if len(output_size) > 1 else [output_size[0], output_size[0]]

        _, image_height, image_width = mask.shape
        if crop_height > image_height or crop_width > image_width:
            padding = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
            mask = F.pad_image_tensor(mask, padding, fill=0)

        left = round((image_width - crop_width) * 0.5)
        top = round((image_height - crop_height) * 0.5)

        return mask[:, top : top + crop_height, left : left + crop_width]

    mask = torch.randint(0, 2, size=(1, 6, 6), dtype=torch.long, device=device)
    actual = F.center_crop_mask(mask, output_size)

    expected = _compute_expected_mask(mask, output_size)
    torch.testing.assert_close(expected, actual)


# Copied from test/test_functional_tensor.py
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("canvas_size", ("small", "large"))
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)])
@pytest.mark.parametrize("sigma", [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)])
def test_correctness_gaussian_blur_image_tensor(device, canvas_size, dt, ksize, sigma):
    fn = F.gaussian_blur_image_tensor

    # true_cv2_results = {
    #     # np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
    #     # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.8)
    #     "3_3_0.8": ...
    #     # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.5)
    #     "3_3_0.5": ...
    #     # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.8)
    #     "3_5_0.8": ...
    #     # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.5)
    #     "3_5_0.5": ...
    #     # np_img2 = np.arange(26 * 28, dtype="uint8").reshape((26, 28))
    #     # cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7)
    #     "23_23_1.7": ...
    # }
    p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "gaussian_blur_opencv_results.pt")
    true_cv2_results = torch.load(p)

    if canvas_size == "small":
        tensor = (
            torch.from_numpy(np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))).permute(2, 0, 1).to(device)
        )
    else:
        tensor = torch.from_numpy(np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28))).to(device)

    if dt == torch.float16 and device == "cpu":
        # skip float16 on CPU case
        return

    if dt is not None:
        tensor = tensor.to(dtype=dt)

    _ksize = (ksize, ksize) if isinstance(ksize, int) else ksize
    _sigma = sigma[0] if sigma is not None else None
    shape = tensor.shape
    gt_key = f"{shape[-2]}_{shape[-1]}_{shape[-3]}__{_ksize[0]}_{_ksize[1]}_{_sigma}"
    if gt_key not in true_cv2_results:
        return

    true_out = (
        torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)
    )

    image = datapoints.Image(tensor)

    out = fn(image, kernel_size=ksize, sigma=sigma)
    torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")


@pytest.mark.parametrize(
    "inpt",
    [
        127 * np.ones((32, 32, 3), dtype="uint8"),
        PIL.Image.new("RGB", (32, 32), 122),
    ],
)
def test_to_image_tensor(inpt):
    output = F.to_image_tensor(inpt)
    assert isinstance(output, torch.Tensor)
    assert output.shape == (3, 32, 32)

    assert np.asarray(inpt).sum() == output.sum().item()


@pytest.mark.parametrize(
    "inpt",
    [
        torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8),
        127 * np.ones((32, 32, 3), dtype="uint8"),
    ],
)
@pytest.mark.parametrize("mode", [None, "RGB"])
def test_to_image_pil(inpt, mode):
    output = F.to_image_pil(inpt, mode=mode)
    assert isinstance(output, PIL.Image.Image)

    assert np.asarray(inpt).sum() == np.asarray(output).sum()


def test_equalize_image_tensor_edge_cases():
    inpt = torch.zeros(3, 200, 200, dtype=torch.uint8)
    output = F.equalize_image_tensor(inpt)
    torch.testing.assert_close(inpt, output)

    inpt = torch.zeros(5, 3, 200, 200, dtype=torch.uint8)
    inpt[..., 100:, 100:] = 1
    output = F.equalize_image_tensor(inpt)
    assert output.unique().tolist() == [0, 255]


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_uniform_temporal_subsample(device):
    video = torch.arange(10, device=device)[:, None, None, None].expand(-1, 3, 8, 8)
    out_video = F.uniform_temporal_subsample(video, 5)
    assert out_video.unique().tolist() == [0, 2, 4, 6, 9]

    out_video = F.uniform_temporal_subsample(video, 8)
    assert out_video.unique().tolist() == [0, 1, 2, 3, 5, 6, 7, 9]
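    # Both expectations match evenly spaced sampling over the temporal dimension:
    # torch.linspace(0, 9, 5).long() is [0, 2, 4, 6, 9] and torch.linspace(0, 9, 8).long() is
    # [0, 1, 2, 3, 5, 6, 7, 9].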


# TODO: We can remove this test and related torchvision workaround
# once we fixed related pytorch issue: https://github.com/pytorch/pytorch/issues/68430
@make_info_args_kwargs_parametrization(
    [info for info in KERNEL_INFOS if info.kernel is F.resize_image_tensor],
    args_kwargs_fn=lambda info: info.reference_inputs_fn(),
)
def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwargs):
    (input, *other_args), kwargs = args_kwargs.load("cpu")

    output = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)

    error_msg_fn = parametrized_error_message(input, *other_args, **kwargs)
    assert input.ndim == 3, error_msg_fn
    input_stride = input.stride()
    output_stride = output.stride()
    # Here we check output memory format according to the input:
    # if input_stride is (..., 1) then input is most likely channels first and thus
    #   output strides should match channels first strides (H * W, H, 1)
    # if input_stride is (1, ...) then input is most likely channels last and thus
    #   output strides should match channels last strides (1, W * C, C)
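    # E.g. a contiguous (3, 10, 12) image has strides (120, 12, 1), while its channels-last counterpart
    # has strides (1, 36, 3).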
    if input_stride[-1] == 1:
        expected_stride = (output.shape[-2] * output.shape[-1], output.shape[-1], 1)
        assert expected_stride == output_stride, error_msg_fn("")
    elif input_stride[0] == 1:
        expected_stride = (1, output.shape[0] * output.shape[-1], output.shape[0])
        assert expected_stride == output_stride, error_msg_fn("")
    else:
        assert False, error_msg_fn("")


def test_resize_float16_no_rounding():
    # Make sure Resize() doesn't round float16 images
    # Non-regression test for https://github.com/pytorch/vision/issues/7667

    img = torch.randint(0, 256, size=(1, 3, 100, 100), dtype=torch.float16)
    out = F.resize(img, size=(10, 10))
    assert out.dtype == torch.float16
    assert (out.round() - out).sum() > 0