# test_transforms_v2_functional.py
1
import inspect
import math
import os
import re

import numpy as np
import PIL.Image
import pytest
import torch

from common_utils import (
    assert_close,
    cache,
    cpu_and_cuda,
    DEFAULT_SQUARE_SPATIAL_SIZE,
    make_bounding_boxes,
    needs_cuda,
    parametrized_error_message,
    set_rng_seed,
)
from torch.utils._pytree import tree_map
from torchvision import datapoints
from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes
from torchvision.transforms.v2.utils import is_simple_tensor

from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
from transforms_v2_kernel_infos import KERNEL_INFOS
30
31


32
33
34
35
# Lookup tables from a kernel / dispatcher callable to the info object that describes it.
KERNEL_INFOS_MAP = {kernel_info.kernel: kernel_info for kernel_info in KERNEL_INFOS}
DISPATCHER_INFOS_MAP = {dispatcher_info.dispatcher: dispatcher_info for dispatcher_info in DISPATCHER_INFOS}


36
37
38
39
40
41
@cache
def script(fn):
    """Return ``torch.jit.script(fn)``.

    Decorated with ``cache`` from ``common_utils`` (presumably memoizing per ``fn`` — confirm there).

    Raises:
        AssertionError: if scripting fails; the original error is chained as the cause.
    """
    try:
        scripted_fn = torch.jit.script(fn)
    except Exception as error:
        raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error
    return scripted_fn
42
43


44
45
46
47
48
49
50
51
52
# `torch.jit.script` frequently triggers warnings such as
# `UserWarning: operator() profile_node %$INT1 : int[] = prim::profile_ivalue($INT2) does not have profile information`
# where `INT1` and `INT2` vary from call to call. They are irrelevant for these tests and only clutter the test
# summary, so this mark filters every warning starting with that prefix.
ignore_jit_warning_no_profile = pytest.mark.filterwarnings(
    "ignore:" + re.escape("operator() profile_node %") + ":UserWarning"
)


53
54
def make_info_args_kwargs_params(info, *, args_kwargs_fn, test_id=None):
    """Build one ``pytest.param(info, args_kwargs)`` per ``ArgsKwargs`` produced by ``args_kwargs_fn(info)``.

    Args:
        info: Kernel or dispatcher info object; its ``id`` prefixes every param id.
        args_kwargs_fn: Callable returning an iterable of ``ArgsKwargs`` for ``info``.
        test_id: Optional ``(test_class_name, test_function_name)``. If given, ``info.get_marks`` is consulted for
            the marks of each param; otherwise no marks are attached.

    Raises:
        pytest.UsageError: if ``args_kwargs_fn`` yields nothing.
    """
    all_args_kwargs = list(args_kwargs_fn(info))
    if not all_args_kwargs:
        location = f" in {test_id}" if test_id else ""
        raise pytest.UsageError(f"Couldn't collect a single `ArgsKwargs` for `{info.id}`{location}")

    # Zero-pad the index so that the generated ids sort in collection order.
    index_width = len(str(len(all_args_kwargs)))
    params = []
    for index, args_kwargs in enumerate(all_args_kwargs):
        marks = info.get_marks(test_id, args_kwargs) if test_id else []
        params.append(pytest.param(info, args_kwargs, marks=marks, id=f"{info.id}-{index:0{index_width}}"))
    return params


71
def make_info_args_kwargs_parametrization(infos, *, args_kwargs_fn):
    """Return a decorator that parametrizes a test over ``("info", "args_kwargs")`` for all ``infos``.

    The decorated test's ``__qualname__`` is parsed into ``(test_class_name, test_function_name)``; the class name is
    ``None`` for module-level test functions. That id is forwarded to ``make_info_args_kwargs_params`` so that
    per-test marks can be looked up.

    Raises (at decoration time):
        pytest.UsageError: if the qualname is nested deeper than ``Class.function``.
    """

    def decorator(test_fn):
        qualname_parts = test_fn.__qualname__.split(".")
        if len(qualname_parts) not in (1, 2):
            raise pytest.UsageError("Unable to parse the test class name and test function name from test function")
        # Left-pad with `None` so that module-level test functions get `None` as class name.
        test_id = tuple([None] * (2 - len(qualname_parts)) + qualname_parts)

        argvalues = [
            param
            for info in infos
            for param in make_info_args_kwargs_params(info, args_kwargs_fn=args_kwargs_fn, test_id=test_id)
        ]
        return pytest.mark.parametrize(("info", "args_kwargs"), argvalues)(test_fn)

    return decorator
91
92


# (blame annotation residue: Philip Meier committed)
93
94
95
96
97
98
@pytest.fixture(autouse=True)
def fix_rng_seed():
    """Seed the RNGs through ``common_utils.set_rng_seed`` before every test so random inputs are reproducible."""
    seed = 0
    set_rng_seed(seed)
    yield


99
100
101
102
103
104
105
@pytest.fixture()
def test_id(request):
    """Return ``(test_class_name, test_function_name)`` of the requesting test.

    The class name is ``None`` for module-level tests; the function name is the original (un-parametrized) name.
    """
    test_cls = request.cls
    return (None if test_cls is None else test_cls.__name__), request.node.originalname


106
class TestKernels:
    """Generic checks that every kernel registered in ``KERNEL_INFOS`` has to pass.

    Inputs loaded from the infos may be datapoints; the tests unwrap them with ``.as_subclass(torch.Tensor)`` before
    calling the kernel.
    """

    # Parametrizes a test over every kernel and each of its sample inputs.
    sample_inputs = make_info_args_kwargs_parametrization(
        KERNEL_INFOS,
        args_kwargs_fn=lambda kernel_info: kernel_info.sample_inputs_fn(),
    )
    # Parametrizes a test over every kernel that has a reference implementation and each of its reference inputs.
    reference_inputs = make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.reference_fn is not None],
        args_kwargs_fn=lambda info: info.reference_inputs_fn(),
    )

    @make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.logs_usage],
        args_kwargs_fn=lambda info: info.sample_inputs_fn(),
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_logging(self, spy_on, info, args_kwargs, device):
        spy = spy_on(torch._C._log_api_usage_once)

        (input, *other_args), kwargs = args_kwargs.load(device)
        info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)

        spy.assert_any_call(f"{info.kernel.__module__}.{info.id}")

    @ignore_jit_warning_no_profile
    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_scripted_vs_eager(self, test_id, info, args_kwargs, device):
        kernel_eager = info.kernel
        kernel_scripted = script(kernel_eager)

        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        actual = kernel_scripted(input, *other_args, **kwargs)
        expected = kernel_eager(input, *other_args, **kwargs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
            # Unpack `other_args` like every other test in this class does, instead of passing the list as one
            # positional argument.
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )

    def _unbatch(self, batch, *, data_dims):
        """Recursively split ``batch`` along its leading dimension until only ``data_dims`` dimensions remain.

        ``batch`` is either a tensor or a sequence whose first element is the tensor and whose remaining elements are
        metadata. The metadata is attached unchanged to every unbatched element.
        """
        if isinstance(batch, torch.Tensor):
            batched_tensor = batch
            metadata = ()
        else:
            batched_tensor, *metadata = batch

        if batched_tensor.ndim == data_dims:
            return batch

        return [
            self._unbatch(unbatched, data_dims=data_dims)
            for unbatched in (
                batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
            )
        ]

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_batched_vs_single(self, test_id, info, args_kwargs, device):
        (batched_input, *other_args), kwargs = args_kwargs.load(device)

        datapoint_type = datapoints.Image if is_simple_tensor(batched_input) else type(batched_input)
        # This dictionary contains the number of rightmost dimensions that contain the actual data.
        # Everything to the left is considered a batch dimension.
        data_dims = {
            datapoints.Image: 3,
            datapoints.BoundingBoxes: 1,
            # `Mask`'s are special in the sense that the data dimensions depend on the type of mask. For detection masks
            # it is 3 `(*, N, H, W)`, but for segmentation masks it is 2 `(*, H, W)`. Since both are grouped under one
            # type all kernels should also work without differentiating between the two. Thus, we go with 2 here as
            # common ground.
            datapoints.Mask: 2,
            datapoints.Video: 4,
        }.get(datapoint_type)
        if data_dims is None:
            raise pytest.UsageError(
                f"The number of data dimensions cannot be determined for input of type {datapoint_type.__name__}."
            )
        elif batched_input.ndim <= data_dims:
            pytest.skip("Input is not batched.")
        elif not all(batched_input.shape[:-data_dims]):
            pytest.skip("Input has a degenerate batch shape.")

        batched_input = batched_input.as_subclass(torch.Tensor)
        batched_output = info.kernel(batched_input, *other_args, **kwargs)
        actual = self._unbatch(batched_output, data_dims=data_dims)

        single_inputs = self._unbatch(batched_input, data_dims=data_dims)
        expected = tree_map(lambda single_input: info.kernel(single_input, *other_args, **kwargs), single_inputs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=batched_input.dtype, device=batched_input.device),
            msg=parametrized_error_message(batched_input, *other_args, **kwargs),
        )

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_no_inplace(self, info, args_kwargs, device):
        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        if input.numel() == 0:
            pytest.skip("The input has a degenerate shape.")

        # `_version` is bumped by in-place operations, so an unchanged value means the input was not modified.
        input_version = input._version
        info.kernel(input, *other_args, **kwargs)

        assert input._version == input_version

    @sample_inputs
    @needs_cuda
    def test_cuda_vs_cpu(self, test_id, info, args_kwargs):
        (input_cpu, *other_args), kwargs = args_kwargs.load("cpu")
        input_cpu = input_cpu.as_subclass(torch.Tensor)
        input_cuda = input_cpu.to("cuda")

        output_cpu = info.kernel(input_cpu, *other_args, **kwargs)
        output_cuda = info.kernel(input_cuda, *other_args, **kwargs)

        assert_close(
            output_cuda,
            output_cpu,
            check_device=False,
            **info.get_closeness_kwargs(test_id, dtype=input_cuda.dtype, device=input_cuda.device),
            msg=parametrized_error_message(input_cpu, *other_args, **kwargs),
        )

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_dtype_and_device_consistency(self, info, args_kwargs, device):
        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        output = info.kernel(input, *other_args, **kwargs)
        # Most kernels just return a tensor, but some also return some additional metadata
        if not isinstance(output, torch.Tensor):
            output, *_ = output

        assert output.dtype == input.dtype
        assert output.device == input.device

    @reference_inputs
    def test_against_reference(self, test_id, info, args_kwargs):
        (input, *other_args), kwargs = args_kwargs.load("cpu")

        actual = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)
        # We intentionally don't unwrap the input of the reference function in order for it to have access to all
        # metadata regardless of whether the kernel takes it explicitly or not
        expected = info.reference_fn(input, *other_args, **kwargs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )

    @make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.float32_vs_uint8],
        args_kwargs_fn=lambda info: info.reference_inputs_fn(),
    )
    def test_float32_vs_uint8(self, test_id, info, args_kwargs):
        (input, *other_args), kwargs = args_kwargs.load("cpu")
        input = input.as_subclass(torch.Tensor)

        if input.dtype != torch.uint8:
            pytest.skip(f"Input dtype is {input.dtype}.")

        # `float32_vs_uint8` adapts the arguments (e.g. fill values) to the float32 value range.
        adapted_other_args, adapted_kwargs = info.float32_vs_uint8(other_args, kwargs)

        actual = info.kernel(
            F.to_dtype_image(input, dtype=torch.float32, scale=True),
            *adapted_other_args,
            **adapted_kwargs,
        )

        expected = F.to_dtype_image(info.kernel(input, *other_args, **kwargs), dtype=torch.float32, scale=True)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=torch.float32, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )
296
297


298
299
300
301
302
303
304
305
306
307
308
309
@pytest.fixture
def spy_on(mocker):
    """Return a factory that patches a function with a spy that still calls through to the original.

    The patch target defaults to ``fn.__module__`` / ``fn.__name__`` and can be overridden via ``module`` / ``name``.
    """

    def make_spy(fn, *, module=None, name=None):
        # TODO: we can probably get rid of the non-default modules and names if we eliminate aliasing
        target = f"{module or fn.__module__}.{name or fn.__name__}"
        return mocker.patch(target, wraps=fn)

    return make_spy


310
class TestDispatchers:
    """Generic checks that every dispatcher registered in ``DISPATCHER_INFOS`` has to pass."""

    # Parametrizes a test over every dispatcher with an `Image` kernel and each of its image sample inputs.
    image_sample_inputs = make_info_args_kwargs_parametrization(
        [info for info in DISPATCHER_INFOS if datapoints.Image in info.kernels],
        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
    )

    @make_info_args_kwargs_parametrization(
        DISPATCHER_INFOS,
        args_kwargs_fn=lambda info: info.sample_inputs(),
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_logging(self, spy_on, info, args_kwargs, device):
        spy = spy_on(torch._C._log_api_usage_once)

        args, kwargs = args_kwargs.load(device)
        info.dispatcher(*args, **kwargs)

        spy.assert_any_call(f"{info.dispatcher.__module__}.{info.id}")

    @ignore_jit_warning_no_profile
    @image_sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_scripted_smoke(self, info, args_kwargs, device):
        dispatcher = script(info.dispatcher)

        (image_datapoint, *other_args), kwargs = args_kwargs.load(device)
        # Unwrap the same way as everywhere else in this file rather than going through the legacy
        # `torch.Tensor(...)` constructor.
        image_simple_tensor = image_datapoint.as_subclass(torch.Tensor)

        dispatcher(image_simple_tensor, *other_args, **kwargs)

    # TODO: We need this until the dispatchers below also have `DispatcherInfo`'s. If they do, `test_scripted_smoke`
    #  replaces this test for them.
    @ignore_jit_warning_no_profile
    @pytest.mark.parametrize(
        "dispatcher",
        [
            F.get_dimensions,
            F.get_image_num_channels,
            F.get_image_size,
            F.get_num_channels,
            F.get_num_frames,
            F.get_size,
            F.rgb_to_grayscale,
            F.uniform_temporal_subsample,
        ],
        ids=lambda dispatcher: dispatcher.__name__,
    )
    def test_scriptable(self, dispatcher):
        script(dispatcher)

    @image_sample_inputs
    def test_simple_tensor_output_type(self, info, args_kwargs):
        (image_datapoint, *other_args), kwargs = args_kwargs.load()
        image_simple_tensor = image_datapoint.as_subclass(torch.Tensor)

        output = info.dispatcher(image_simple_tensor, *other_args, **kwargs)

        # We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well
        assert type(output) is torch.Tensor

    @make_info_args_kwargs_parametrization(
        [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
    )
    def test_pil_output_type(self, info, args_kwargs):
        (image_datapoint, *other_args), kwargs = args_kwargs.load()

        if image_datapoint.ndim > 3:
            pytest.skip("Input is batched")

        image_pil = F.to_pil_image(image_datapoint)

        output = info.dispatcher(image_pil, *other_args, **kwargs)

        assert isinstance(output, PIL.Image.Image)

    @make_info_args_kwargs_parametrization(
        DISPATCHER_INFOS,
        args_kwargs_fn=lambda info: info.sample_inputs(),
    )
    def test_datapoint_output_type(self, info, args_kwargs):
        (datapoint, *other_args), kwargs = args_kwargs.load()

        output = info.dispatcher(datapoint, *other_args, **kwargs)

        assert isinstance(output, type(datapoint))

        if isinstance(datapoint, datapoints.BoundingBoxes) and info.dispatcher is not F.convert_format_bounding_boxes:
            assert output.format == datapoint.format

    @pytest.mark.parametrize(
        ("dispatcher_info", "datapoint_type", "kernel_info"),
        [
            pytest.param(
                dispatcher_info, datapoint_type, kernel_info, id=f"{dispatcher_info.id}-{datapoint_type.__name__}"
            )
            for dispatcher_info in DISPATCHER_INFOS
            for datapoint_type, kernel_info in dispatcher_info.kernel_infos.items()
        ],
    )
    def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, datapoint_type, kernel_info):
        # Skip the first parameter of both: it is the input (datapoint or tensor) itself.
        dispatcher_signature = inspect.signature(dispatcher_info.dispatcher)
        dispatcher_params = list(dispatcher_signature.parameters.values())[1:]

        kernel_signature = inspect.signature(kernel_info.kernel)
        kernel_params = list(kernel_signature.parameters.values())[1:]

        # We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be
        # explicitly passed to the kernel.
        input_type = {v: k for k, v in dispatcher_info.kernels.items()}.get(kernel_info.kernel)
        explicit_metadata = {
            datapoints.BoundingBoxes: {"format", "canvas_size"},
        }
        kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]

        dispatcher_params = iter(dispatcher_params)
        for kernel_param in kernel_params:
            # In general, the dispatcher parameters are a superset of the kernel parameters. Thus, we skip dispatcher
            # parameters that have no kernel equivalent while keeping the order intact. Driving the loop by the kernel
            # parameters (instead of zipping both sequences) guarantees that running out of dispatcher parameters is
            # reported rather than silently ending the check early.
            for dispatcher_param in dispatcher_params:
                if dispatcher_param.name == kernel_param.name:
                    break
            else:
                raise AssertionError(
                    f"Parameter `{kernel_param.name}` of kernel `{kernel_info.id}` "
                    f"has no corresponding parameter on the dispatcher `{dispatcher_info.id}`."
                )

            assert dispatcher_param == kernel_param

    @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
    def test_unkown_type(self, info):
        unknown_input = object()
        (_, *other_args), kwargs = next(iter(info.sample_inputs())).load("cpu")

        with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
            info.dispatcher(unknown_input, *other_args, **kwargs)

    @make_info_args_kwargs_parametrization(
        [
            info
            for info in DISPATCHER_INFOS
            if datapoints.BoundingBoxes in info.kernels and info.dispatcher is not F.convert_format_bounding_boxes
        ],
        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.BoundingBoxes),
    )
    def test_bounding_boxes_format_consistency(self, info, args_kwargs):
        (bounding_boxes, *other_args), kwargs = args_kwargs.load()
        format = bounding_boxes.format

        output = info.dispatcher(bounding_boxes, *other_args, **kwargs)

        assert output.format == format

464

465
@pytest.mark.parametrize(
    ("alias", "target"),
    [
        pytest.param(alias, target, id=alias.__name__)
        for alias, target in [
            (F.hflip, F.horizontal_flip),
            (F.vflip, F.vertical_flip),
            (F.get_image_num_channels, F.get_num_channels),
            # `(F.to_pil_image, F.to_pil_image)` used to be listed here; comparing a function with itself can never
            # fail, so that entry tested nothing and was removed.
            (F.elastic_transform, F.elastic),
            (F.to_grayscale, F.rgb_to_grayscale),
        ]
    ],
)
def test_alias(alias, target):
    """Each alias in ``F`` must be the very same object as its canonical function."""
    assert alias is target
481
482


483
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("num_channels", [1, 3])
def test_normalize_image_tensor_stats(device, num_channels):
    """Normalizing an image with its own per-channel mean and std should yield standard normal samples.

    Checked with a Kolmogorov-Smirnov test against N(0, 1); skipped if SciPy is not installed.
    """
    stats = pytest.importorskip("scipy.stats", reason="SciPy is not available")

    def assert_samples_from_standard_normal(t):
        # SciPy works on host memory, so move the samples to the CPU first.
        p_value = stats.kstest(t.flatten().cpu(), cdf="norm", args=(0, 1)).pvalue
        # Assert here rather than returning the comparison, so that a rejected KS test actually fails the test.
        assert p_value > 1e-4, f"KS test rejects a standard normal distribution (p-value {p_value:.2e})"

    # Draw the pixels from a normal distribution with non-standard location and scale. Standardizing such data with its
    # empirical statistics yields (approximately) standard normal samples. Uniformly distributed pixels (e.g. from
    # `torch.rand`) stay uniform after standardization and would only pass the KS test for lack of statistical power.
    image = (
        torch.randn(num_channels, DEFAULT_SQUARE_SPATIAL_SIZE, DEFAULT_SQUARE_SPATIAL_SIZE, device=device) * 0.25 + 0.5
    )
    mean = image.mean(dim=(1, 2)).tolist()
    std = image.std(dim=(1, 2)).tolist()

    assert_samples_from_standard_normal(F.normalize_image(image, mean, std))
497
498


499
class TestClampBoundingBoxes:
    """Argument validation of the ``F.clamp_bounding_boxes`` dispatcher."""

    @pytest.mark.parametrize(
        "metadata",
        [
            {},
            {"format": datapoints.BoundingBoxFormat.XYXY},
            {"canvas_size": (1, 1)},
        ],
    )
    def test_simple_tensor_insufficient_metadata(self, metadata):
        # A plain tensor carries no metadata, so omitting either piece has to be rejected.
        plain_boxes = next(make_bounding_boxes()).as_subclass(torch.Tensor)
        expected_message = re.escape("`format` and `canvas_size` has to be passed")

        with pytest.raises(ValueError, match=expected_message):
            F.clamp_bounding_boxes(plain_boxes, **metadata)

    @pytest.mark.parametrize(
        "metadata",
        [
            {"format": datapoints.BoundingBoxFormat.XYXY},
            {"canvas_size": (1, 1)},
            {"format": datapoints.BoundingBoxFormat.XYXY, "canvas_size": (1, 1)},
        ],
    )
    def test_datapoint_explicit_metadata(self, metadata):
        # A datapoint already carries its metadata, so passing any of it explicitly has to be rejected.
        boxes_datapoint = next(make_bounding_boxes())
        expected_message = re.escape("`format` and `canvas_size` must not be passed")

        with pytest.raises(ValueError, match=expected_message):
            F.clamp_bounding_boxes(boxes_datapoint, **metadata)
527
528


529
class TestConvertFormatBoundingBoxes:
    """Argument validation of the ``F.convert_format_bounding_boxes`` dispatcher."""

    @pytest.mark.parametrize(
        ("inpt", "old_format"),
        [
            (next(make_bounding_boxes()), None),
            (next(make_bounding_boxes()).as_subclass(torch.Tensor), datapoints.BoundingBoxFormat.XYXY),
        ],
    )
    def test_missing_new_format(self, inpt, old_format):
        expected_message = re.escape("missing 1 required argument: 'new_format'")
        with pytest.raises(TypeError, match=expected_message):
            F.convert_format_bounding_boxes(inpt, old_format)

    def test_simple_tensor_insufficient_metadata(self):
        # A plain tensor does not know its own format, so `old_format` is mandatory.
        plain_boxes = next(make_bounding_boxes()).as_subclass(torch.Tensor)
        expected_message = re.escape("`old_format` has to be passed")

        with pytest.raises(ValueError, match=expected_message):
            F.convert_format_bounding_boxes(plain_boxes, new_format=datapoints.BoundingBoxFormat.CXCYWH)

    def test_datapoint_explicit_metadata(self):
        # A datapoint already knows its format, so passing `old_format` explicitly has to be rejected.
        boxes_datapoint = next(make_bounding_boxes())
        expected_message = re.escape("`old_format` must not be passed")

        with pytest.raises(ValueError, match=expected_message):
            F.convert_format_bounding_boxes(
                boxes_datapoint,
                old_format=boxes_datapoint.format,
                new_format=datapoints.BoundingBoxFormat.CXCYWH,
            )


556
# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
#  `transforms_v2_kernel_infos.py`
558
559


560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_):
    """Build the 3x3 homogeneous matrix of an affine transform as a NumPy reference.

    Angles (``angle_`` and both entries of ``shear_``) are given in degrees. The result is composed as
    ``T @ (C @ ((RS @ (ShY @ ShX)) @ C^-1))``, i.e. shear, then rotate/scale around ``center_``, then translate.
    """
    angle_rad = math.radians(angle_)
    center_x, center_y = center_
    translate_x, translate_y = translate_
    shear_x_rad, shear_y_rad = [math.radians(value) for value in shear_]

    to_center = np.array([[1, 0, center_x], [0, 1, center_y], [0, 0, 1]])
    translation = np.array([[1, 0, translate_x], [0, 1, translate_y], [0, 0, 1]])
    from_center = np.linalg.inv(to_center)

    cos_a = math.cos(angle_rad)
    sin_a = math.sin(angle_rad)
    rotation_scale = np.array(
        [
            [scale_ * cos_a, -scale_ * sin_a, 0],
            [scale_ * sin_a, scale_ * cos_a, 0],
            [0, 0, 1],
        ]
    )
    shear_x = np.array([[1, -math.tan(shear_x_rad), 0], [0, 1, 0], [0, 0, 1]])
    shear_y = np.array([[1, 0, 0], [-math.tan(shear_y_rad), 1, 0], [0, 0, 1]])

    rotation_scale_shear = rotation_scale @ (shear_y @ shear_x)
    return translation @ (to_center @ (rotation_scale_shear @ from_center))


583
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "format",
    [datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.CXCYWH],
)
@pytest.mark.parametrize(
    "top, left, height, width, expected_bboxes",
    [
        [8, 12, 30, 40, [(-2.0, 7.0, 13.0, 27.0), (38.0, -3.0, 58.0, 14.0), (33.0, 38.0, 44.0, 54.0)]],
        [-8, 12, 70, 40, [(-2.0, 23.0, 13.0, 43.0), (38.0, 13.0, 58.0, 30.0), (33.0, 54.0, 44.0, 70.0)]],
    ],
)
def test_correctness_crop_bounding_boxes(device, format, top, left, height, width, expected_bboxes):
    """Check ``F.crop_bounding_boxes`` against precomputed results for every parametrized box ``format``.

    ``expected_bboxes`` are the unclamped XYXY results of the crop; they are clamped to the crop canvas below.
    """
    # Expected bboxes computed using Albumentations:
    # import numpy as np
    # from albumentations.augmentations.crops.functional import crop_bbox_by_coords, normalize_bbox, denormalize_bbox
    # expected_bboxes = []
    # for in_box in in_boxes:
    #     n_in_box = normalize_bbox(in_box, *size)
    #     n_out_box = crop_bbox_by_coords(
    #         n_in_box, (left, top, left + width, top + height), height, width, *size
    #     )
    #     out_box = denormalize_bbox(n_out_box, height, width)
    #     expected_bboxes.append(out_box)

    # Input boxes in XYXY format on a (64, 76) canvas. The parametrized `format` must not be overwritten here;
    # otherwise only XYXY is ever exercised.
    in_boxes = [
        [10.0, 15.0, 25.0, 35.0],
        [50.0, 5.0, 70.0, 22.0],
        [45.0, 46.0, 56.0, 62.0],
    ]
    in_boxes = torch.tensor(in_boxes, device=device)
    if format != datapoints.BoundingBoxFormat.XYXY:
        in_boxes = convert_format_bounding_boxes(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)

    # The crop yields a canvas of the requested crop size; the expected boxes are clamped to that canvas.
    cropped_canvas_size = (height, width)
    expected_bboxes = clamp_bounding_boxes(
        datapoints.BoundingBoxes(expected_bboxes, format="XYXY", canvas_size=cropped_canvas_size)
    ).tolist()

    # Pass the parametrized crop size, not the size of the input canvas.
    output_boxes, output_canvas_size = F.crop_bounding_boxes(in_boxes, format, top, left, height, width)

    if format != datapoints.BoundingBoxFormat.XYXY:
        output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)

    torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
    torch.testing.assert_close(output_canvas_size, cropped_canvas_size)
638
639


640
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
    """A mask whose first row is set must have its last row set after ``F.vertical_flip_mask``."""
    mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
    mask[:, 0, :] = 1

    expected_mask = torch.zeros_like(mask)
    expected_mask[:, -1, :] = 1

    torch.testing.assert_close(F.vertical_flip_mask(mask), expected_mask)
650
651


652
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "format",
    [datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.CXCYWH],
)
@pytest.mark.parametrize(
    "top, left, height, width, size",
    [
        [0, 0, 30, 30, (60, 60)],
        [-5, 5, 35, 45, (32, 34)],
    ],
)
def test_correctness_resized_crop_bounding_boxes(device, format, top, left, height, width, size):
    """Check ``F.resized_crop_bounding_boxes`` against a manual computation for every parametrized box ``format``."""

    def _compute_expected_bbox(bbox, top_, left_, height_, width_, size_):
        # bbox should be xyxy; shift by the crop origin, then scale from the crop size to the output `size_`.
        bbox[0] = (bbox[0] - left_) * size_[1] / width_
        bbox[1] = (bbox[1] - top_) * size_[0] / height_
        bbox[2] = (bbox[2] - left_) * size_[1] / width_
        bbox[3] = (bbox[3] - top_) * size_[0] / height_
        return bbox

    # Input boxes in XYXY format on a (100, 100) canvas. The parametrized `format` must not be overwritten here;
    # otherwise only XYXY is ever exercised.
    in_boxes = [
        [10.0, 10.0, 20.0, 20.0],
        [5.0, 10.0, 15.0, 20.0],
    ]
    expected_bboxes = torch.tensor(
        [_compute_expected_bbox(list(in_box), top, left, height, width, size) for in_box in in_boxes],
        device=device,
    )

    # Keep the boxes as a plain tensor: the kernel operates on plain tensors, and converting a `BoundingBoxes`
    # datapoint with an explicit `old_format` is rejected.
    in_boxes = torch.tensor(in_boxes, device=device)
    if format != datapoints.BoundingBoxFormat.XYXY:
        in_boxes = convert_format_bounding_boxes(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)

    output_boxes, output_canvas_size = F.resized_crop_bounding_boxes(in_boxes, format, top, left, height, width, size)

    if format != datapoints.BoundingBoxFormat.XYXY:
        output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)

    torch.testing.assert_close(output_boxes, expected_bboxes)
    torch.testing.assert_close(output_canvas_size, size)
697
698


699
700
701
702
703
704
705
706
707
708
709
710
def _parse_padding(padding):
    """Expand ``padding`` into an explicit ``[left, top, right, bottom]`` list.

    Args:
        padding: An ``int`` (same padding on all sides) or a list/tuple with 1 (all sides), 2 (left/right,
            top/bottom) or 4 (left, top, right, bottom) elements. Any other type is returned unchanged.

    Raises:
        ValueError: if a list/tuple has a length other than 1, 2 or 4.
    """
    if isinstance(padding, int):
        return [padding] * 4
    if isinstance(padding, (list, tuple)):
        # Normalize tuples to lists so that `* 4` / `* 2` repeat the values instead of silently returning an
        # unexpanded tuple.
        padding = list(padding)
        if len(padding) == 1:
            return padding * 4
        if len(padding) == 2:
            return padding * 2  # [left, up, right, down]
        if len(padding) == 4:
            return padding
        raise ValueError(f"Padding must have 1, 2, or 4 elements, but got {len(padding)}.")

    return padding


711
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("padding", [[1], [1, 1], [1, 1, 2, 2]])
def test_correctness_pad_bounding_boxes(device, padding):
    """Check ``F.pad_bounding_boxes`` against a reference that shifts every box by the left/top padding."""

    def _compute_expected_bbox(bbox, format_, padding_):
        # Reference for a single box: go to XYXY, shift x-coordinates by the left and y-coordinates by the top
        # padding, then convert back to the box's own format.
        pad_left, pad_up, _, _ = _parse_padding(padding_)

        dtype = bbox.dtype
        bbox = (
            bbox.clone()
            if format_ == datapoints.BoundingBoxFormat.XYXY
            else convert_format_bounding_boxes(bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY)
        )

        bbox[0::2] += pad_left
        bbox[1::2] += pad_up

        bbox = convert_format_bounding_boxes(bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_)
        if bbox.dtype != dtype:
            # Temporary cast to original dtype
            # e.g. float32 -> int
            bbox = bbox.to(dtype)
        return bbox

    def _compute_expected_canvas_size(bbox, padding_):
        # Padding grows each canvas dimension by the sum of the paddings on its two opposite sides.
        pad_left, pad_up, pad_right, pad_down = _parse_padding(padding_)
        height, width = bbox.canvas_size
        return height + pad_up + pad_down, width + pad_left + pad_right

    for bboxes in make_bounding_boxes(extra_dims=((4,),)):
        bboxes = bboxes.to(device)
        bboxes_format = bboxes.format
        bboxes_canvas_size = bboxes.canvas_size

        output_boxes, output_canvas_size = F.pad_bounding_boxes(
            bboxes, format=bboxes_format, canvas_size=bboxes_canvas_size, padding=padding
        )

        torch.testing.assert_close(output_canvas_size, _compute_expected_canvas_size(bboxes, padding))

        # The reference works box by box, so flatten the leading dimensions and restore them afterwards.
        expected_bboxes = torch.stack(
            [_compute_expected_bbox(b, bboxes_format, padding) for b in bboxes.reshape(-1, 4).unbind()]
        ).reshape(bboxes.shape)

        torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)
755
756


757
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_pad_segmentation_mask_on_fixed_input(device):
    """Padding a 3x3 all-ones mask by one pixel on every side must yield a 5x5 mask with a zero border."""
    mask = torch.ones((1, 3, 3), dtype=torch.long, device=device)

    out_mask = F.pad_mask(mask, padding=[1, 1, 1, 1])

    expected_mask = torch.zeros((1, 5, 5), dtype=torch.long, device=device)
    expected_mask[:, 1:-1, 1:-1] = 1
    torch.testing.assert_close(out_mask, expected_mask)


768
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "startpoints, endpoints",
    [
        [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
    ],
)
def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
    """Check ``F.perspective_bounding_boxes`` against a NumPy reference.

    The reference projects the four corners of every box, takes the axis-aligned hull of the projected points and
    clamps it to the canvas.
    """

    def _compute_expected_bbox(bbox, format_, canvas_size_, pcoeffs_):
        # With coefficients (a, b, c, d, e, f, g, h) a point (x, y) is mapped to
        #   ((a*x + b*y + c) / (g*x + h*y + 1), (d*x + e*y + f) / (g*x + h*y + 1)).
        # `m1` holds both numerators; `m2` repeats the shared denominator once per output coordinate so that a
        # single element-wise division produces both coordinates.
        m1 = np.array(
            [
                [pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]],
                [pcoeffs_[3], pcoeffs_[4], pcoeffs_[5]],
            ]
        )
        m2 = np.array(
            [
                [pcoeffs_[6], pcoeffs_[7], 1.0],
                [pcoeffs_[6], pcoeffs_[7], 1.0],
            ]
        )

        bbox_xyxy = convert_format_bounding_boxes(
            bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY
        )
        # The four box corners in homogeneous coordinates.
        points = np.array(
            [
                [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
            ]
        )
        numer = np.matmul(points, m1.T)
        denom = np.matmul(points, m2.T)
        transformed_points = numer / denom
        # Axis-aligned hull of the projected corners.
        out_bbox = np.array(
            [
                np.min(transformed_points[:, 0]),
                np.min(transformed_points[:, 1]),
                np.max(transformed_points[:, 0]),
                np.max(transformed_points[:, 1]),
            ]
        )
        out_bbox = torch.from_numpy(out_bbox)
        out_bbox = convert_format_bounding_boxes(
            out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_
        )
        return clamp_bounding_boxes(out_bbox, format=format_, canvas_size=canvas_size_).to(bbox)

    canvas_size = (32, 38)

    pcoeffs = _get_perspective_coeffs(startpoints, endpoints)
    # The reference maps box corners with the coefficients computed in the opposite direction
    # (endpoints -> startpoints) from the ones handed to the kernel.
    # NOTE(review): presumably because `pcoeffs` describes the output->input mapping used to resample images; confirm.
    inv_pcoeffs = _get_perspective_coeffs(endpoints, startpoints)

    for bboxes in make_bounding_boxes(spatial_size=canvas_size, extra_dims=((4,),)):
        bboxes = bboxes.to(device)

        output_bboxes = F.perspective_bounding_boxes(
            bboxes.as_subclass(torch.Tensor),
            format=bboxes.format,
            canvas_size=bboxes.canvas_size,
            startpoints=None,
            endpoints=None,
            coefficients=pcoeffs,
        )

        # The reference works box by box, so flatten the leading dimensions and restore them afterwards.
        expected_bboxes = torch.stack(
            [
                _compute_expected_bbox(b, bboxes.format, bboxes.canvas_size, inv_pcoeffs)
                for b in bboxes.reshape(-1, 4).unbind()
            ]
        ).reshape(bboxes.shape)

        torch.testing.assert_close(output_bboxes, expected_bboxes, rtol=0, atol=1)
845
846


847
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "output_size",
    [(18, 18), [18, 15], (16, 19), [12], [46, 48]],
)
def test_correctness_center_crop_bounding_boxes(device, output_size):
    """Check ``F.center_crop_bounding_boxes``: boxes are shifted by the crop offset and clamped to the crop."""
    # A one-element `output_size` requests a square crop. Resolve it into a local tuple instead of appending to
    # `output_size`: pytest passes parametrized values as-is, so mutating the list would leak into other cases.
    crop_height, crop_width = (output_size[0], output_size[0]) if len(output_size) == 1 else output_size
    expected_canvas_size = (crop_height, crop_width)

    def _compute_expected_bbox(bbox, format_, canvas_size_, crop_size_):
        dtype = bbox.dtype
        bbox = convert_format_bounding_boxes(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH)

        # Offset of the crop's top-left corner inside the original canvas.
        cy = int(round((canvas_size_[0] - crop_size_[0]) * 0.5))
        cx = int(round((canvas_size_[1] - crop_size_[1]) * 0.5))
        # In XYWH only the top-left corner moves; width and height are kept before clamping.
        out_bbox = [
            bbox[0].item() - cx,
            bbox[1].item() - cy,
            bbox[2].item(),
            bbox[3].item(),
        ]
        out_bbox = torch.tensor(out_bbox)
        out_bbox = convert_format_bounding_boxes(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_)
        out_bbox = clamp_bounding_boxes(out_bbox, format=format_, canvas_size=crop_size_)
        return out_bbox.to(dtype=dtype, device=bbox.device)

    for bboxes in make_bounding_boxes(extra_dims=((4,),)):
        bboxes = bboxes.to(device)
        bboxes_format = bboxes.format
        bboxes_canvas_size = bboxes.canvas_size

        output_boxes, output_canvas_size = F.center_crop_bounding_boxes(
            bboxes, bboxes_format, bboxes_canvas_size, output_size
        )

        # The reference works box by box, so flatten the leading dimensions and restore them afterwards.
        expected_bboxes = torch.stack(
            [
                _compute_expected_bbox(b, bboxes_format, bboxes_canvas_size, expected_canvas_size)
                for b in bboxes.reshape(-1, 4).unbind()
            ]
        ).reshape(bboxes.shape)

        torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)
        torch.testing.assert_close(output_canvas_size, expected_canvas_size)
891
892


893
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("output_size", [[4, 2], [4], [7, 6]])
def test_correctness_center_crop_mask(device, output_size):
    """Check ``F.center_crop_mask`` against a reference that zero-pads when needed and then slices the center."""

    def _compute_expected_mask(mask, output_size_):
        crop_height, crop_width = output_size_ if len(output_size_) > 1 else [output_size_[0], output_size_[0]]

        _, image_height, image_width = mask.shape
        # Compare each crop dimension with the matching image dimension.
        if crop_height > image_height or crop_width > image_width:
            padding = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
            mask = F.pad_image(mask, padding, fill=0)
            # The crop anchor must be computed on the padded mask; the stale sizes would yield negative offsets.
            _, image_height, image_width = mask.shape

        left = round((image_width - crop_width) * 0.5)
        top = round((image_height - crop_height) * 0.5)

        return mask[:, top : top + crop_height, left : left + crop_width]

    mask = torch.randint(0, 2, size=(1, 6, 6), dtype=torch.long, device=device)
    actual = F.center_crop_mask(mask, output_size)

    expected = _compute_expected_mask(mask, output_size)
    torch.testing.assert_close(expected, actual)
914
915
916


# Copied from test/test_functional_tensor.py
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("canvas_size", ("small", "large"))
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)])
@pytest.mark.parametrize("sigma", [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)])
def test_correctness_gaussian_blur_image_tensor(device, canvas_size, dt, ksize, sigma):
    """Compare ``F.gaussian_blur_image`` with precomputed ``cv2.GaussianBlur`` results stored in the test assets.

    Parameter combinations without a stored reference (and float16 on CPU) are reported as skipped rather than
    silently passing.
    """
    if dt == torch.float16 and device == "cpu":
        pytest.skip("float16 is not tested on CPU")

    fn = F.gaussian_blur_image

    # The reference results were generated with OpenCV as follows:
    # true_cv2_results = {
    #     # np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
    #     # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.8)
    #     "3_3_0.8": ...
    #     # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.5)
    #     "3_3_0.5": ...
    #     # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.8)
    #     "3_5_0.8": ...
    #     # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.5)
    #     "3_5_0.5": ...
    #     # np_img2 = np.arange(26 * 28, dtype="uint8").reshape((26, 28))
    #     # cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7)
    #     "23_23_1.7": ...
    # }
    p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "gaussian_blur_opencv_results.pt")
    true_cv2_results = torch.load(p)

    if canvas_size == "small":
        tensor = (
            torch.from_numpy(np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))).permute(2, 0, 1).to(device)
        )
    else:
        tensor = torch.from_numpy(np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28))).to(device)

    if dt is not None:
        tensor = tensor.to(dtype=dt)

    _ksize = (ksize, ksize) if isinstance(ksize, int) else ksize
    _sigma = sigma[0] if sigma is not None else None
    shape = tensor.shape
    # Reference keys encode the image size (H_W_C), the kernel size and sigma.
    gt_key = f"{shape[-2]}_{shape[-1]}_{shape[-3]}__{_ksize[0]}_{_ksize[1]}_{_sigma}"
    if gt_key not in true_cv2_results:
        pytest.skip(f"No OpenCV reference result stored for '{gt_key}'")

    # The stored results are laid out HWC; bring them to CHW and to the dtype/device of the input.
    true_out = (
        torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)
    )

    image = datapoints.Image(tensor)

    out = fn(image, kernel_size=ksize, sigma=sigma)
    torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")


973
974
975
976
977
978
979
@pytest.mark.parametrize(
    "inpt",
    [
        127 * np.ones((32, 32, 3), dtype="uint8"),
        PIL.Image.new("RGB", (32, 32), 122),
    ],
)
def test_to_image(inpt):
    """``F.to_image`` turns a 32x32 RGB numpy array or PIL image into a ``(3, 32, 32)`` tensor."""
    output = F.to_image(inpt)
    assert isinstance(output, torch.Tensor)
    assert output.shape == (3, 32, 32)

    # A matching sum is a cheap sanity check that the pixel values survived the conversion.
    assert np.asarray(inpt).sum() == output.sum().item()


@pytest.mark.parametrize(
    "inpt",
    [
        torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8),
        127 * np.ones((32, 32, 3), dtype="uint8"),
    ],
)
@pytest.mark.parametrize("mode", [None, "RGB"])
def test_to_pil_image(inpt, mode):
    """``F.to_pil_image`` accepts a uint8 tensor or numpy array and returns a PIL image with the same pixels."""
    output = F.to_pil_image(inpt, mode=mode)
    assert isinstance(output, PIL.Image.Image)

    # A matching sum is a cheap sanity check that the pixel values survived the conversion.
    assert np.asarray(inpt).sum() == np.asarray(output).sum()
1001
1002
1003
1004


def test_equalize_image_tensor_edge_cases():
    """Edge cases of ``F.equalize_image`` on uint8 tensors.

    An all-zero image must come back unchanged, and a batch whose pixels take only the values 0 and 1 must be
    stretched to the values 0 and 255.
    """
    inpt = torch.zeros(3, 200, 200, dtype=torch.uint8)
    output = F.equalize_image(inpt)
    torch.testing.assert_close(inpt, output)

    inpt = torch.zeros(5, 3, 200, 200, dtype=torch.uint8)
    inpt[..., 100:, 100:] = 1
    output = F.equalize_image(inpt)
    assert output.unique().tolist() == [0, 255]
1012
1013


1014
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_uniform_temporal_subsample(device):
    """Check which frame indices ``F.uniform_temporal_subsample`` keeps from a 10-frame video."""
    # Every frame along dim 0 is filled with its own index, so the values that remain identify the sampled frames.
    video = torch.arange(10, device=device)[:, None, None, None].expand(-1, 3, 8, 8)
    out_video = F.uniform_temporal_subsample(video, 5)
    assert out_video.unique().tolist() == [0, 2, 4, 6, 9]

    out_video = F.uniform_temporal_subsample(video, 8)
    assert out_video.unique().tolist() == [0, 1, 2, 3, 5, 6, 7, 9]
1022
1023
1024
1025
1026


# TODO: We can remove this test and related torchvision workaround
# once we fixed related pytorch issue: https://github.com/pytorch/pytorch/issues/68430
@make_info_args_kwargs_parametrization(
    [info for info in KERNEL_INFOS if info.kernel is F.resize_image],
    args_kwargs_fn=lambda info: info.reference_inputs_fn(),
)
def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwargs):
    """Check that ``F.resize_image`` keeps the memory format (channels first vs. channels last) of its input."""
    (image, *other_args), kwargs = args_kwargs.load("cpu")

    output = info.kernel(image.as_subclass(torch.Tensor), *other_args, **kwargs)

    error_msg_fn = parametrized_error_message(image, *other_args, **kwargs)
    assert image.ndim == 3, error_msg_fn("")
    image_stride = image.stride()
    output_stride = output.stride()
    # Here we check output memory format according to the input:
    # if image_stride is (..., 1) then input is most likely channels first and thus
    #   output strides should match channels first strides (H * W, W, 1)
    # if image_stride is (1, ...) then input is most likely channels last and thus
    #   output strides should match channels last strides (1, W * C, C)
    if image_stride[-1] == 1:
        expected_stride = (output.shape[-2] * output.shape[-1], output.shape[-1], 1)
    elif image_stride[0] == 1:
        expected_stride = (1, output.shape[0] * output.shape[-1], output.shape[0])
    else:
        # Neither layout could be recognized from the input strides.
        pytest.fail(error_msg_fn(""))
    assert expected_stride == output_stride, error_msg_fn("")
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061


def test_resize_float16_no_rounding():
    """Make sure ``F.resize`` doesn't round float16 images.

    Non-regression test for https://github.com/pytorch/vision/issues/7667
    """
    img = torch.randint(0, 256, size=(1, 3, 100, 100), dtype=torch.float16)
    out = F.resize(img, size=(10, 10))
    assert out.dtype == torch.float16
    # `round(x) - x` is positive for some values and negative for others, so a plain sum can be <= 0 even when
    # the output does have fractional parts; taking the absolute value makes the check sign-independent.
    assert (out.round() - out).abs().sum() > 0