test_transforms_v2_functional.py 40.4 KB
Newer Older
1
import inspect
2
import math
3
import os
4
import re
5

6
import numpy as np
7
import PIL.Image
8
import pytest
9
import torch
10

11
from common_utils import (
12
    assert_close,
13
    cache,
14
    cpu_and_cuda,
15
16
    DEFAULT_SQUARE_SPATIAL_SIZE,
    make_bounding_boxes,
17
    needs_cuda,
18
    parametrized_error_message,
19
    set_rng_seed,
20
)
21
from torch.utils._pytree import tree_map
22
from torchvision import datapoints
23
from torchvision.transforms.functional import _get_perspective_coeffs
24
25
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
26
from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes
27
from torchvision.transforms.v2.utils import is_simple_tensor
28
29
from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
from transforms_v2_kernel_infos import KERNEL_INFOS
30
31


32
33
34
35
# Reverse lookups: kernel / dispatcher callable -> the info object describing it.
KERNEL_INFOS_MAP = {kernel_info.kernel: kernel_info for kernel_info in KERNEL_INFOS}
DISPATCHER_INFOS_MAP = {dispatcher_info.dispatcher: dispatcher_info for dispatcher_info in DISPATCHER_INFOS}


36
37
38
39
40
41
@cache
def script(fn):
    """Return ``torch.jit.script(fn)``.

    The function is wrapped with ``cache`` from ``common_utils`` (presumably memoizing per ``fn`` — confirm there).
    Any exception raised while scripting is re-raised as an ``AssertionError`` chained to the original error.
    """
    try:
        scripted = torch.jit.script(fn)
    except Exception as exc:
        raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from exc
    return scripted
42
43


44
45
46
47
48
49
50
51
52
# Scripting a function often triggers a warning like
# `UserWarning: operator() profile_node %$INT1 : int[] = prim::profile_ivalue($INT2) does not have profile information`
# with varying `INT1` and `INT2`. These carry no information for us and only clutter the test summary, so tests marked
# with this filter ignore them. Only the constant message prefix is matched (escaped, since it is used as a regex).
ignore_jit_warning_no_profile = pytest.mark.filterwarnings(
    "ignore:" + re.escape("operator() profile_node %") + ":UserWarning"
)


53
54
def make_info_args_kwargs_params(info, *, args_kwargs_fn, test_id=None):
    """Build one ``pytest.param(info, args_kwargs)`` per ``ArgsKwargs`` produced by ``args_kwargs_fn(info)``.

    Args:
        info: Kernel or dispatcher info object. Its ``id`` prefixes every parameter id and, when ``test_id`` is given,
            its ``get_marks(test_id, args_kwargs)`` supplies the marks of each parameter.
        args_kwargs_fn: Callable taking ``info`` and returning an iterable of ``ArgsKwargs``.
        test_id: Optional ``(test_class_name, test_function_name)`` tuple. Without it no marks are attached.

    Returns:
        A list of ``pytest.param`` objects with ids of the form ``"{info.id}-{index}"``.

    Raises:
        pytest.UsageError: If ``args_kwargs_fn`` does not yield a single ``ArgsKwargs``.
    """
    all_args_kwargs = list(args_kwargs_fn(info))
    if not all_args_kwargs:
        location = f" in {test_id}" if test_id else ""
        raise pytest.UsageError(f"Couldn't collect a single `ArgsKwargs` for `{info.id}`{location}")

    # Zero-pad the index to a common width so the ids have equal length and sort in collection order.
    width = len(str(len(all_args_kwargs)))
    params = []
    for idx, args_kwargs in enumerate(all_args_kwargs):
        marks = info.get_marks(test_id, args_kwargs) if test_id else []
        params.append(pytest.param(info, args_kwargs, marks=marks, id=f"{info.id}-{idx:0{width}}"))
    return params


71
def make_info_args_kwargs_parametrization(infos, *, args_kwargs_fn):
    """Return a decorator that parametrizes a test over ``("info", "args_kwargs")``.

    For every ``info`` in ``infos``, one parameter per ``ArgsKwargs`` from ``args_kwargs_fn(info)`` is generated via
    ``make_info_args_kwargs_params``. The decorated test must be a module-level function or a method of a
    module-level class: its ``(test_class_name, test_function_name)`` id is parsed from ``__qualname__`` and a
    ``pytest.UsageError`` is raised if that qualified name has more than two parts.
    """

    def decorator(test_fn):
        *owners, test_function_name = test_fn.__qualname__.split(".")
        if len(owners) > 1:
            raise pytest.UsageError("Unable to parse the test class name and test function name from test function")
        test_class_name = owners[0] if owners else None
        test_id = (test_class_name, test_function_name)

        argvalues = [
            param
            for info in infos
            for param in make_info_args_kwargs_params(info, args_kwargs_fn=args_kwargs_fn, test_id=test_id)
        ]
        return pytest.mark.parametrize(("info", "args_kwargs"), argvalues)(test_fn)

    return decorator
91
92


Philip Meier's avatar
Philip Meier committed
93
94
95
96
97
98
@pytest.fixture(autouse=True)
def fix_rng_seed():
    """Call ``set_rng_seed(0)`` before every test of this module (autouse) so random test data is reproducible."""
    set_rng_seed(0)
    yield


99
100
101
102
103
104
105
@pytest.fixture()
def test_id(request):
    """Return ``(test_class_name, test_function_name)`` of the requesting test.

    ``test_class_name`` is ``None`` for module-level tests. ``originalname`` is the function name without the
    parametrization suffix, which matches the id tuple built by ``make_info_args_kwargs_parametrization``.
    """
    owner = request.cls
    class_name = None if owner is None else owner.__name__
    return class_name, request.node.originalname


106
class TestKernels:
    """Generic checks run against every kernel described in ``KERNEL_INFOS``."""

    # Parametrizes a test over every `KernelInfo` and each of its sample inputs.
    sample_inputs = make_info_args_kwargs_parametrization(
        KERNEL_INFOS,
        args_kwargs_fn=lambda kernel_info: kernel_info.sample_inputs_fn(),
    )
    # Parametrizes a test over every `KernelInfo` that has a `reference_fn` and each of its reference inputs.
    reference_inputs = make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.reference_fn is not None],
        args_kwargs_fn=lambda info: info.reference_inputs_fn(),
    )

    @make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.logs_usage],
        args_kwargs_fn=lambda info: info.sample_inputs_fn(),
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_logging(self, spy_on, info, args_kwargs, device):
        """Kernels flagged with ``logs_usage`` must call ``torch._C._log_api_usage_once`` with their qualified name."""
        spy = spy_on(torch._C._log_api_usage_once)

        (input, *other_args), kwargs = args_kwargs.load(device)
        info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)

        spy.assert_any_call(f"{info.kernel.__module__}.{info.id}")

    @ignore_jit_warning_no_profile
    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_scripted_vs_eager(self, test_id, info, args_kwargs, device):
        """The scripted kernel must produce the same output as the eager kernel."""
        kernel_eager = info.kernel
        kernel_scripted = script(kernel_eager)

        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        actual = kernel_scripted(input, *other_args, **kwargs)
        expected = kernel_eager(input, *other_args, **kwargs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
            # Unpack `other_args` like every other test in this class, so each argument is reported individually.
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )

    def _unbatch(self, batch, *, data_dims):
        """Recursively split ``batch`` along dim 0 until each leaf tensor has exactly ``data_dims`` dimensions.

        ``batch`` is either a tensor or a ``(tensor, *metadata)`` sequence (for kernels that return extra metadata);
        in the latter case the metadata is attached unchanged to every split.
        """
        if isinstance(batch, torch.Tensor):
            batched_tensor = batch
            metadata = ()
        else:
            batched_tensor, *metadata = batch

        if batched_tensor.ndim == data_dims:
            return batch

        return [
            self._unbatch(unbatched, data_dims=data_dims)
            for unbatched in (
                batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
            )
        ]

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_batched_vs_single(self, test_id, info, args_kwargs, device):
        """Running the kernel on a batch must match running it on every unbatched sample individually."""
        (batched_input, *other_args), kwargs = args_kwargs.load(device)

        datapoint_type = datapoints.Image if is_simple_tensor(batched_input) else type(batched_input)
        # This dictionary contains the number of rightmost dimensions that contain the actual data.
        # Everything to the left is considered a batch dimension.
        data_dims = {
            datapoints.Image: 3,
            datapoints.BoundingBoxes: 1,
            # `Mask`'s are special in the sense that the data dimensions depend on the type of mask. For detection masks
            # it is 3 `(*, N, H, W)`, but for segmentation masks it is 2 `(*, H, W)`. Since both are grouped under one
            # type all kernels should also work without differentiating between the two. Thus, we go with 2 here as
            # common ground.
            datapoints.Mask: 2,
            datapoints.Video: 4,
        }.get(datapoint_type)
        if data_dims is None:
            raise pytest.UsageError(
                f"The number of data dimensions cannot be determined for input of type {datapoint_type.__name__}."
            ) from None
        elif batched_input.ndim <= data_dims:
            pytest.skip("Input is not batched.")
        elif not all(batched_input.shape[:-data_dims]):
            pytest.skip("Input has a degenerate batch shape.")

        batched_input = batched_input.as_subclass(torch.Tensor)
        batched_output = info.kernel(batched_input, *other_args, **kwargs)
        actual = self._unbatch(batched_output, data_dims=data_dims)

        single_inputs = self._unbatch(batched_input, data_dims=data_dims)
        expected = tree_map(lambda single_input: info.kernel(single_input, *other_args, **kwargs), single_inputs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=batched_input.dtype, device=batched_input.device),
            msg=parametrized_error_message(batched_input, *other_args, **kwargs),
        )

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_no_inplace(self, info, args_kwargs, device):
        """The kernel must not modify its input in place (the input's ``_version`` counter must stay unchanged)."""
        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        if input.numel() == 0:
            pytest.skip("The input has a degenerate shape.")

        input_version = input._version
        info.kernel(input, *other_args, **kwargs)

        assert input._version == input_version

    @sample_inputs
    @needs_cuda
    def test_cuda_vs_cpu(self, test_id, info, args_kwargs):
        """The kernel must produce the same result on CUDA as on CPU."""
        (input_cpu, *other_args), kwargs = args_kwargs.load("cpu")
        input_cpu = input_cpu.as_subclass(torch.Tensor)
        input_cuda = input_cpu.to("cuda")

        output_cpu = info.kernel(input_cpu, *other_args, **kwargs)
        output_cuda = info.kernel(input_cuda, *other_args, **kwargs)

        assert_close(
            output_cuda,
            output_cpu,
            check_device=False,
            **info.get_closeness_kwargs(test_id, dtype=input_cuda.dtype, device=input_cuda.device),
            msg=parametrized_error_message(input_cpu, *other_args, **kwargs),
        )

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_dtype_and_device_consistency(self, info, args_kwargs, device):
        """The (first) output tensor must keep the dtype and device of the input."""
        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        output = info.kernel(input, *other_args, **kwargs)
        # Most kernels just return a tensor, but some also return some additional metadata
        if not isinstance(output, torch.Tensor):
            output, *_ = output

        assert output.dtype == input.dtype
        assert output.device == input.device

    @reference_inputs
    def test_against_reference(self, test_id, info, args_kwargs):
        """The kernel must match ``info.reference_fn`` on the reference inputs."""
        (input, *other_args), kwargs = args_kwargs.load("cpu")

        actual = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)
        # We intentionally don't unwrap the input of the reference function in order for it to have access to all
        # metadata regardless of whether the kernel takes it explicitly or not
        expected = info.reference_fn(input, *other_args, **kwargs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )

    @make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.float32_vs_uint8],
        args_kwargs_fn=lambda info: info.reference_inputs_fn(),
    )
    def test_float32_vs_uint8(self, test_id, info, args_kwargs):
        """Applying the kernel to the float32 version of a uint8 input must match converting the uint8 output.

        The arguments for the float32 run are adapted by ``info.float32_vs_uint8``.
        """
        (input, *other_args), kwargs = args_kwargs.load("cpu")
        input = input.as_subclass(torch.Tensor)

        if input.dtype != torch.uint8:
            pytest.skip(f"Input dtype is {input.dtype}.")

        adapted_other_args, adapted_kwargs = info.float32_vs_uint8(other_args, kwargs)

        actual = info.kernel(
            F.to_dtype_image_tensor(input, dtype=torch.float32, scale=True),
            *adapted_other_args,
            **adapted_kwargs,
        )

        expected = F.to_dtype_image_tensor(info.kernel(input, *other_args, **kwargs), dtype=torch.float32, scale=True)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=torch.float32, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )
296
297


298
299
300
301
302
303
304
305
306
307
308
309
@pytest.fixture
def spy_on(mocker):
    """Provide a factory that patches ``module.name`` with a ``mocker`` mock wrapping ``fn``.

    ``module`` and ``name`` default to ``fn.__module__`` and ``fn.__name__``. Because of ``wraps=fn`` the patched
    object still calls through to ``fn`` while recording its calls.
    """

    def make_spy(fn, *, module=None, name=None):
        # TODO: we can probably get rid of the non-default modules and names if we eliminate aliasing
        target = f"{module or fn.__module__}.{name or fn.__name__}"
        return mocker.patch(target, wraps=fn)

    return make_spy


310
class TestDispatchers:
    """Generic checks run against every dispatcher described in ``DISPATCHER_INFOS``."""

    # Parametrizes a test over every dispatcher that handles `datapoints.Image` and each of its image sample inputs.
    image_sample_inputs = make_info_args_kwargs_parametrization(
        [info for info in DISPATCHER_INFOS if datapoints.Image in info.kernels],
        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
    )

    @make_info_args_kwargs_parametrization(
        DISPATCHER_INFOS,
        args_kwargs_fn=lambda info: info.sample_inputs(),
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_logging(self, spy_on, info, args_kwargs, device):
        """Dispatchers must call ``torch._C._log_api_usage_once`` with their qualified name."""
        spy = spy_on(torch._C._log_api_usage_once)

        args, kwargs = args_kwargs.load(device)
        info.dispatcher(*args, **kwargs)

        spy.assert_any_call(f"{info.dispatcher.__module__}.{info.id}")

    @ignore_jit_warning_no_profile
    @image_sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_scripted_smoke(self, info, args_kwargs, device):
        """The scripted dispatcher must run on a simple (non-datapoint) image tensor."""
        dispatcher = script(info.dispatcher)

        (image_datapoint, *other_args), kwargs = args_kwargs.load(device)
        # Unwrap the same way as the rest of this file instead of going through the legacy `torch.Tensor(...)` ctor.
        image_simple_tensor = image_datapoint.as_subclass(torch.Tensor)

        dispatcher(image_simple_tensor, *other_args, **kwargs)

    # TODO: We need this until the dispatchers below also have `DispatcherInfo`'s. If they do, `test_scripted_smoke`
    #  replaces this test for them.
    @ignore_jit_warning_no_profile
    @pytest.mark.parametrize(
        "dispatcher",
        [
            F.get_dimensions,
            F.get_image_num_channels,
            F.get_image_size,
            F.get_num_channels,
            F.get_num_frames,
            F.get_size,
            F.rgb_to_grayscale,
            F.uniform_temporal_subsample,
        ],
        ids=lambda dispatcher: dispatcher.__name__,
    )
    def test_scriptable(self, dispatcher):
        """The dispatchers without a `DispatcherInfo` must at least be scriptable."""
        script(dispatcher)

    @image_sample_inputs
    def test_simple_tensor_output_type(self, info, args_kwargs):
        """A simple tensor input must produce a plain ``torch.Tensor`` output."""
        (image_datapoint, *other_args), kwargs = args_kwargs.load()
        image_simple_tensor = image_datapoint.as_subclass(torch.Tensor)

        output = info.dispatcher(image_simple_tensor, *other_args, **kwargs)

        # We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well
        assert type(output) is torch.Tensor

    @make_info_args_kwargs_parametrization(
        [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
    )
    def test_pil_output_type(self, info, args_kwargs):
        """A PIL image input must produce a PIL image output."""
        (image_datapoint, *other_args), kwargs = args_kwargs.load()

        if image_datapoint.ndim > 3:
            pytest.skip("Input is batched")

        image_pil = F.to_image_pil(image_datapoint)

        output = info.dispatcher(image_pil, *other_args, **kwargs)

        assert isinstance(output, PIL.Image.Image)

    @make_info_args_kwargs_parametrization(
        DISPATCHER_INFOS,
        args_kwargs_fn=lambda info: info.sample_inputs(),
    )
    def test_datapoint_output_type(self, info, args_kwargs):
        """A datapoint input must produce a datapoint of the same type (and, for boxes, the same format)."""
        (datapoint, *other_args), kwargs = args_kwargs.load()

        output = info.dispatcher(datapoint, *other_args, **kwargs)

        assert isinstance(output, type(datapoint))

        if isinstance(datapoint, datapoints.BoundingBoxes) and info.dispatcher is not F.convert_format_bounding_boxes:
            assert output.format == datapoint.format

    @pytest.mark.parametrize(
        ("dispatcher_info", "datapoint_type", "kernel_info"),
        [
            pytest.param(
                dispatcher_info, datapoint_type, kernel_info, id=f"{dispatcher_info.id}-{datapoint_type.__name__}"
            )
            for dispatcher_info in DISPATCHER_INFOS
            for datapoint_type, kernel_info in dispatcher_info.kernel_infos.items()
        ],
    )
    def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, datapoint_type, kernel_info):
        """Every kernel parameter (after the input) must appear on the dispatcher, in order and with equal spec."""
        dispatcher_signature = inspect.signature(dispatcher_info.dispatcher)
        dispatcher_params = list(dispatcher_signature.parameters.values())[1:]

        kernel_signature = inspect.signature(kernel_info.kernel)
        kernel_params = list(kernel_signature.parameters.values())[1:]

        # We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be
        # explicitly passed to the kernel.
        input_type = {v: k for k, v in dispatcher_info.kernels.items()}.get(kernel_info.kernel)
        explicit_metadata = {
            datapoints.BoundingBoxes: {"format", "canvas_size"},
        }
        kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]

        # In general, the dispatcher parameters are a superset of the kernel parameters. Thus, for every kernel
        # parameter we skip dispatcher parameters that have no kernel equivalent while keeping the order intact.
        # Driving the loop by the kernel parameters (instead of `zip`-ing both sequences) guarantees that trailing
        # kernel parameters without a dispatcher counterpart are reported rather than silently dropped by `zip`.
        remaining_dispatcher_params = iter(dispatcher_params)
        for kernel_param in kernel_params:
            for dispatcher_param in remaining_dispatcher_params:
                if dispatcher_param.name == kernel_param.name:
                    break
            else:
                raise AssertionError(
                    f"Parameter `{kernel_param.name}` of kernel `{kernel_info.id}` "
                    f"has no corresponding parameter on the dispatcher `{dispatcher_info.id}`."
                )

            assert dispatcher_param == kernel_param

    @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
    def test_unkown_type(self, info):
        """Dispatching an unsupported input type must raise a ``TypeError`` naming that type."""
        unknown_input = object()
        (_, *other_args), kwargs = next(iter(info.sample_inputs())).load("cpu")

        with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
            info.dispatcher(unknown_input, *other_args, **kwargs)

    @make_info_args_kwargs_parametrization(
        [
            info
            for info in DISPATCHER_INFOS
            if datapoints.BoundingBoxes in info.kernels and info.dispatcher is not F.convert_format_bounding_boxes
        ],
        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.BoundingBoxes),
    )
    def test_bounding_boxes_format_consistency(self, info, args_kwargs):
        """Dispatchers other than ``convert_format_bounding_boxes`` must preserve the bounding box format."""
        (bounding_boxes, *other_args), kwargs = args_kwargs.load()
        format = bounding_boxes.format

        output = info.dispatcher(bounding_boxes, *other_args, **kwargs)

        assert output.format == format

464

465
@pytest.mark.parametrize(
    ("alias", "target"),
    [
        pytest.param(alias, target, id=alias.__name__)
        for alias, target in (
            (F.hflip, F.horizontal_flip),
            (F.vflip, F.vertical_flip),
            (F.get_image_num_channels, F.get_num_channels),
            (F.to_pil_image, F.to_image_pil),
            (F.elastic_transform, F.elastic),
            (F.to_grayscale, F.rgb_to_grayscale),
        )
    ],
)
def test_alias(alias, target):
    """Each alias must be the very same object as its target, not merely an equivalent wrapper."""
    assert target is alias
481
482


483
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("num_channels", [1, 3])
def test_normalize_image_tensor_stats(device, num_channels):
    """Normalizing an image with its own per-channel mean and std must yield standard normal samples.

    The input is drawn from a (non-standard) normal distribution, so that after normalization with the per-channel
    sample statistics a Kolmogorov-Smirnov test against N(0, 1) is a meaningful check.
    """
    stats = pytest.importorskip("scipy.stats", reason="SciPy is not available")

    def assert_samples_from_standard_normal(t):
        # SciPy cannot consume CUDA tensors, so move the samples to the host first.
        p_value = stats.kstest(t.flatten().cpu(), cdf="norm", args=(0, 1)).pvalue
        # Previously this was `return p_value > 1e-4`, and the caller never checked the result.
        assert p_value > 1e-4

    # A uniformly distributed input (as from `torch.rand`) is not normal after standardization, so a normal input is
    # used instead. Shift and scale it so that the normalization has actual work to do.
    image = (
        torch.randn(num_channels, DEFAULT_SQUARE_SPATIAL_SIZE, DEFAULT_SQUARE_SPATIAL_SIZE, device=device) * 0.25 + 0.5
    )
    mean = image.mean(dim=(1, 2)).tolist()
    std = image.std(dim=(1, 2)).tolist()

    assert_samples_from_standard_normal(F.normalize_image_tensor(image, mean, std))


499
class TestClampBoundingBoxes:
    """Argument validation of ``F.clamp_bounding_boxes`` for simple tensors and ``BoundingBoxes`` datapoints."""

    @pytest.mark.parametrize(
        "metadata",
        [
            {},
            {"format": datapoints.BoundingBoxFormat.XYXY},
            {"canvas_size": (1, 1)},
        ],
    )
    def test_simple_tensor_insufficient_metadata(self, metadata):
        # A simple tensor carries no metadata, so `format` and `canvas_size` must both be passed explicitly.
        plain_boxes = next(make_bounding_boxes()).as_subclass(torch.Tensor)

        with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` has to be passed")):
            F.clamp_bounding_boxes(plain_boxes, **metadata)

    @pytest.mark.parametrize(
        "metadata",
        [
            {"format": datapoints.BoundingBoxFormat.XYXY},
            {"canvas_size": (1, 1)},
            {"format": datapoints.BoundingBoxFormat.XYXY, "canvas_size": (1, 1)},
        ],
    )
    def test_datapoint_explicit_metadata(self, metadata):
        # A datapoint already carries its metadata, so passing any of it explicitly is rejected.
        boxes = next(make_bounding_boxes())

        with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` must not be passed")):
            F.clamp_bounding_boxes(boxes, **metadata)
527
528


529
class TestConvertFormatBoundingBoxes:
    """Argument validation of ``F.convert_format_bounding_boxes``."""

    @pytest.mark.parametrize(
        ("inpt", "old_format"),
        [
            (next(make_bounding_boxes()), None),
            (next(make_bounding_boxes()).as_subclass(torch.Tensor), datapoints.BoundingBoxFormat.XYXY),
        ],
    )
    def test_missing_new_format(self, inpt, old_format):
        # `new_format` is required regardless of whether the input is a datapoint or a simple tensor.
        expected_message = re.escape("missing 1 required argument: 'new_format'")
        with pytest.raises(TypeError, match=expected_message):
            F.convert_format_bounding_boxes(inpt, old_format)

    def test_simple_tensor_insufficient_metadata(self):
        # Without a datapoint to read it from, the source format has to be given explicitly.
        plain_boxes = next(make_bounding_boxes()).as_subclass(torch.Tensor)
        target_format = datapoints.BoundingBoxFormat.CXCYWH

        with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")):
            F.convert_format_bounding_boxes(plain_boxes, new_format=target_format)

    def test_datapoint_explicit_metadata(self):
        # A datapoint already knows its format, so passing `old_format` as well is rejected.
        boxes = next(make_bounding_boxes())
        target_format = datapoints.BoundingBoxFormat.CXCYWH

        with pytest.raises(ValueError, match=re.escape("`old_format` must not be passed")):
            F.convert_format_bounding_boxes(boxes, old_format=boxes.format, new_format=target_format)


556
# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
557
#  `transforms_v2_kernel_infos.py`
558
559


560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_):
    """Build a 3x3 homogeneous affine matrix as ``T @ (C @ (RSS @ C^-1))``.

    ``C`` translates by ``center_``, ``T`` by ``translate_``, and ``RSS = RS @ (ShearY @ ShearX)``, where ``RS`` is a
    rotation by ``angle_`` scaled by ``scale_``. ``angle_`` and both entries of ``shear_`` are in degrees.
    """
    theta = math.radians(angle_)
    center_x, center_y = center_
    shift_x, shift_y = translate_
    shear_x, shear_y = (math.radians(value) for value in shear_)

    to_center = np.array([[1, 0, center_x], [0, 1, center_y], [0, 0, 1]])
    translation = np.array([[1, 0, shift_x], [0, 1, shift_y], [0, 0, 1]])
    from_center = np.linalg.inv(to_center)

    cos_term = scale_ * math.cos(theta)
    sin_term = scale_ * math.sin(theta)
    rotate_scale = np.array([[cos_term, -sin_term, 0], [sin_term, cos_term, 0], [0, 0, 1]])

    shear_x_matrix = np.array([[1, -math.tan(shear_x), 0], [0, 1, 0], [0, 0, 1]])
    shear_y_matrix = np.array([[1, 0, 0], [-math.tan(shear_y), 1, 0], [0, 0, 1]])
    rotate_scale_shear = rotate_scale @ (shear_y_matrix @ shear_x_matrix)

    return translation @ (to_center @ (rotate_scale_shear @ from_center))


583
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "format",
    [datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.CXCYWH],
)
@pytest.mark.parametrize(
    "top, left, height, width, expected_bboxes",
    [
        [8, 12, 30, 40, [(-2.0, 7.0, 13.0, 27.0), (38.0, -3.0, 58.0, 14.0), (33.0, 38.0, 44.0, 54.0)]],
        [-8, 12, 70, 40, [(-2.0, 23.0, 13.0, 43.0), (38.0, 13.0, 58.0, 30.0), (33.0, 54.0, 44.0, 70.0)]],
    ],
)
def test_correctness_crop_bounding_boxes(device, format, top, left, height, width, expected_bboxes):
    """Check ``F.crop_bounding_boxes`` against precomputed boxes for every parametrized ``format``.

    The input boxes are defined in XYXY, converted to ``format`` before cropping, and the output is converted back to
    XYXY for the comparison. Previously ``format`` was overwritten with XYXY and the crop size was taken from a fixed
    canvas instead of the parametrized ``height`` / ``width``, so neither parameter was exercised.
    """
    # Expected bboxes computed using Albumentations:
    # import numpy as np
    # from albumentations.augmentations.crops.functional import crop_bbox_by_coords, normalize_bbox, denormalize_bbox
    # expected_bboxes = []
    # for in_box in in_boxes:
    #     n_in_box = normalize_bbox(in_box, *size)
    #     n_out_box = crop_bbox_by_coords(
    #         n_in_box, (left, top, left + width, top + height), height, width, *size
    #     )
    #     out_box = denormalize_bbox(n_out_box, height, width)
    #     expected_bboxes.append(out_box)

    in_boxes = torch.tensor(
        [
            [10.0, 15.0, 25.0, 35.0],
            [50.0, 5.0, 70.0, 22.0],
            [45.0, 46.0, 56.0, 62.0],
        ],
        device=device,
    )
    if format != datapoints.BoundingBoxFormat.XYXY:
        in_boxes = convert_format_bounding_boxes(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)

    # The crop area of size `(height, width)` becomes the new canvas: the expected boxes are clamped to it and it is
    # also the canvas size the kernel has to report.
    expected_canvas_size = (height, width)
    expected_bboxes = clamp_bounding_boxes(
        datapoints.BoundingBoxes(expected_bboxes, format="XYXY", canvas_size=expected_canvas_size)
    ).tolist()

    output_boxes, output_canvas_size = F.crop_bounding_boxes(in_boxes, format, top, left, height, width)

    if format != datapoints.BoundingBoxFormat.XYXY:
        output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)

    torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
    torch.testing.assert_close(output_canvas_size, expected_canvas_size)
638
639


640
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
    """Vertically flipping a mask whose first row is set must move that row to the last position."""
    shape = (3, 3, 3)

    mask = torch.zeros(shape, dtype=torch.long, device=device)
    mask[:, 0, :] = 1

    expected_mask = torch.zeros(shape, dtype=torch.long, device=device)
    expected_mask[:, -1, :] = 1

    torch.testing.assert_close(F.vertical_flip_mask(mask), expected_mask)
650
651


652
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "format",
    [datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.CXCYWH],
)
@pytest.mark.parametrize(
    "top, left, height, width, size",
    [
        [0, 0, 30, 30, (60, 60)],
        [-5, 5, 35, 45, (32, 34)],
    ],
)
def test_correctness_resized_crop_bounding_boxes(device, format, top, left, height, width, size):
    """Check ``F.resized_crop_bounding_boxes`` against a closed-form reference for every parametrized ``format``.

    Previously ``format`` was overwritten with XYXY, so the XYWH / CXCYWH branches never ran.
    """

    def _compute_expected_bbox(bbox, top_, left_, height_, width_, size_):
        # bbox should be xyxy. Shift by the crop origin, then scale the crop area onto `size_`.
        bbox[0] = (bbox[0] - left_) * size_[1] / width_
        bbox[1] = (bbox[1] - top_) * size_[0] / height_
        bbox[2] = (bbox[2] - left_) * size_[1] / width_
        bbox[3] = (bbox[3] - top_) * size_[0] / height_
        return bbox

    in_boxes = [
        [10.0, 10.0, 20.0, 20.0],
        [5.0, 10.0, 15.0, 20.0],
    ]
    expected_bboxes = torch.tensor(
        [_compute_expected_bbox(list(in_box), top, left, height, width, size) for in_box in in_boxes],
        device=device,
    )

    # A simple tensor (not a `BoundingBoxes` datapoint) is required here: the conversion below passes the formats
    # explicitly, which the conversion function rejects for datapoints.
    in_boxes = torch.tensor(in_boxes, device=device)
    if format != datapoints.BoundingBoxFormat.XYXY:
        in_boxes = convert_format_bounding_boxes(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)

    output_boxes, output_canvas_size = F.resized_crop_bounding_boxes(in_boxes, format, top, left, height, width, size)

    if format != datapoints.BoundingBoxFormat.XYXY:
        output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)

    torch.testing.assert_close(output_boxes, expected_bboxes)
    torch.testing.assert_close(output_canvas_size, size)
697
698


699
700
701
702
703
704
705
706
707
708
709
710
def _parse_padding(padding):
    """Expand a torchvision-style padding spec into ``[left, top, right, bottom]``.

    * ``int`` ``p`` -> ``[p, p, p, p]``
    * 1-element list/tuple ``[p]`` -> ``[p, p, p, p]``
    * 2-element list/tuple ``[lr, tb]`` -> ``[lr, tb, lr, tb]``
    * anything else (e.g. an explicit 4-element spec) is returned unchanged.

    Tuples are accepted in addition to lists. Previously a 1- or 2-element tuple was returned unexpanded, which
    broke the 4-way unpacking done by the callers.
    """
    if isinstance(padding, int):
        return [padding] * 4
    if isinstance(padding, (list, tuple)):
        if len(padding) == 1:
            return list(padding) * 4
        if len(padding) == 2:
            return list(padding) * 2  # [left, up, right, down]

    return padding


711
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("padding", [[1], [1, 1], [1, 1, 2, 2]])
def test_correctness_pad_bounding_boxes(device, padding):
    """Compare ``F.pad_bounding_boxes`` with a per-box reference implementation.

    The reference shifts every box by the left/top padding and grows the canvas by the
    padding on each side.
    """

    def _compute_expected_bbox(bbox, format_, padding_):
        # Reference for a single 4-element box: shift it by the left/top padding in XYXY space.
        pad_left, pad_up, _, _ = _parse_padding(padding_)

        dtype = bbox.dtype
        bbox = (
            bbox.clone()
            if format_ == datapoints.BoundingBoxFormat.XYXY
            else convert_format_bounding_boxes(bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY)
        )

        # In XYXY, even indices are x coordinates and odd indices are y coordinates.
        bbox[0::2] += pad_left
        bbox[1::2] += pad_up

        bbox = convert_format_bounding_boxes(bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_)
        if bbox.dtype != dtype:
            # Temporary cast to original dtype
            # e.g. float32 -> int
            bbox = bbox.to(dtype)
        return bbox

    def _compute_expected_canvas_size(bbox, padding_):
        # The canvas grows by the padding on each side.
        pad_left, pad_up, pad_right, pad_down = _parse_padding(padding_)
        height, width = bbox.canvas_size
        return height + pad_up + pad_down, width + pad_left + pad_right

    for bboxes in make_bounding_boxes(extra_dims=((4,),)):
        bboxes = bboxes.to(device)
        bboxes_format = bboxes.format
        bboxes_canvas_size = bboxes.canvas_size

        output_boxes, output_canvas_size = F.pad_bounding_boxes(
            bboxes, format=bboxes_format, canvas_size=bboxes_canvas_size, padding=padding
        )

        torch.testing.assert_close(output_canvas_size, _compute_expected_canvas_size(bboxes, padding))

        expected_bboxes = torch.stack(
            [_compute_expected_bbox(b, bboxes_format, padding) for b in bboxes.reshape(-1, 4).unbind()]
        ).reshape(bboxes.shape)

        torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)
755
756


757
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_pad_segmentation_mask_on_fixed_input(device):
    """Padding an all-ones 3x3 mask by 1 on every side yields a 5x5 mask with a zero border."""
    mask = torch.ones((1, 3, 3), dtype=torch.long, device=device)

    out_mask = F.pad_mask(mask, padding=[1, 1, 1, 1])

    expected_mask = torch.zeros((1, 5, 5), dtype=torch.long, device=device)
    expected_mask[:, 1:-1, 1:-1] = 1
    torch.testing.assert_close(out_mask, expected_mask)


768
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "startpoints, endpoints",
    [
        [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
    ],
)
def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
    """Compare ``F.perspective_bounding_boxes`` with a NumPy reference.

    The reference maps the four corners of each box through a perspective transform and
    takes the axis-aligned hull of the result, clamped to the canvas. It uses the
    coefficients computed from ``endpoints`` to ``startpoints``, i.e. the inverse of the
    coefficients passed to the kernel.
    """

    def _compute_expected_bbox(bbox, format_, canvas_size_, pcoeffs_):
        # Perspective mapping (x, y) -> ((a*x + b*y + c) / (g*x + h*y + 1), (d*x + e*y + f) / (g*x + h*y + 1)).
        # `m1` holds the two numerator rows; `m2` repeats the denominator row so that
        # `numer / denom` divides both coordinates by the same value.
        m1 = np.array(
            [
                [pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]],
                [pcoeffs_[3], pcoeffs_[4], pcoeffs_[5]],
            ]
        )
        m2 = np.array(
            [
                [pcoeffs_[6], pcoeffs_[7], 1.0],
                [pcoeffs_[6], pcoeffs_[7], 1.0],
            ]
        )

        bbox_xyxy = convert_format_bounding_boxes(
            bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY
        )
        # The four box corners in homogeneous coordinates.
        points = np.array(
            [
                [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
            ]
        )
        numer = np.matmul(points, m1.T)
        denom = np.matmul(points, m2.T)
        transformed_points = numer / denom

        # Axis-aligned hull of the transformed corners.
        out_bbox = np.array(
            [
                np.min(transformed_points[:, 0]),
                np.min(transformed_points[:, 1]),
                np.max(transformed_points[:, 0]),
                np.max(transformed_points[:, 1]),
            ]
        )

        out_bbox = torch.from_numpy(out_bbox)
        out_bbox = convert_format_bounding_boxes(
            out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_
        )
        return clamp_bounding_boxes(out_bbox, format=format_, canvas_size=canvas_size_).to(bbox)

    canvas_size = (32, 38)

    pcoeffs = _get_perspective_coeffs(startpoints, endpoints)
    inv_pcoeffs = _get_perspective_coeffs(endpoints, startpoints)

    for bboxes in make_bounding_boxes(spatial_size=canvas_size, extra_dims=((4,),)):
        bboxes = bboxes.to(device)

        output_bboxes = F.perspective_bounding_boxes(
            bboxes.as_subclass(torch.Tensor),
            format=bboxes.format,
            canvas_size=bboxes.canvas_size,
            startpoints=None,
            endpoints=None,
            coefficients=pcoeffs,
        )

        expected_bboxes = torch.stack(
            [
                _compute_expected_bbox(b, bboxes.format, bboxes.canvas_size, inv_pcoeffs)
                for b in bboxes.reshape(-1, 4).unbind()
            ]
        ).reshape(bboxes.shape)

        torch.testing.assert_close(output_bboxes, expected_bboxes, rtol=0, atol=1)
845
846


847
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "output_size",
    [(18, 18), [18, 15], (16, 19), [12], [46, 48]],
)
def test_correctness_center_crop_bounding_boxes(device, output_size):
    """Compare ``F.center_crop_bounding_boxes`` with a per-box reference implementation.

    The reference shifts each box by the top-left corner of the centered crop window and
    clamps it to the crop size, which is also the expected output canvas size.
    """
    # A single-element `output_size` means a square crop. Normalize it into a new tuple:
    # the parametrized list object is shared between test instances, so mutating it would
    # change the argument later calls (and later devices) pass to the kernel.
    expected_canvas_size = tuple(output_size) * 2 if len(output_size) == 1 else tuple(output_size)

    def _compute_expected_bbox(bbox, format_, canvas_size_, output_size_):
        # `output_size_` must already be a 2-element (height, width) sequence.
        dtype = bbox.dtype
        bbox = convert_format_bounding_boxes(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH)

        # Top-left corner of the centered crop window.
        cy = int(round((canvas_size_[0] - output_size_[0]) * 0.5))
        cx = int(round((canvas_size_[1] - output_size_[1]) * 0.5))
        out_bbox = [
            bbox[0].item() - cx,
            bbox[1].item() - cy,
            bbox[2].item(),
            bbox[3].item(),
        ]
        out_bbox = torch.tensor(out_bbox)
        out_bbox = convert_format_bounding_boxes(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_)
        out_bbox = clamp_bounding_boxes(out_bbox, format=format_, canvas_size=output_size_)
        return out_bbox.to(dtype=dtype, device=bbox.device)

    for bboxes in make_bounding_boxes(extra_dims=((4,),)):
        bboxes = bboxes.to(device)
        bboxes_format = bboxes.format
        bboxes_canvas_size = bboxes.canvas_size

        output_boxes, output_canvas_size = F.center_crop_bounding_boxes(
            bboxes, bboxes_format, bboxes_canvas_size, output_size
        )

        expected_bboxes = torch.stack(
            [
                _compute_expected_bbox(b, bboxes_format, bboxes_canvas_size, expected_canvas_size)
                for b in bboxes.reshape(-1, 4).unbind()
            ]
        ).reshape(bboxes.shape)

        torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)
        torch.testing.assert_close(output_canvas_size, expected_canvas_size)
891
892


893
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("output_size", [[4, 2], [4], [7, 6]])
def test_correctness_center_crop_mask(device, output_size):
    """Compare ``F.center_crop_mask`` with a slicing-based reference implementation."""

    def _compute_expected_mask(mask, output_size_):
        # A single-element size means a square crop.
        crop_height, crop_width = output_size_ if len(output_size_) > 1 else [output_size_[0], output_size_[0]]

        _, image_height, image_width = mask.shape
        # Zero-pad first if the crop is larger than the mask along either axis.
        if crop_width > image_width or crop_height > image_height:
            padding = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
            mask = F.pad_image_tensor(mask, padding, fill=0)
            # The crop offsets below must be computed on the padded mask.
            _, image_height, image_width = mask.shape

        left = round((image_width - crop_width) * 0.5)
        top = round((image_height - crop_height) * 0.5)

        return mask[:, top : top + crop_height, left : left + crop_width]

    mask = torch.randint(0, 2, size=(1, 6, 6), dtype=torch.long, device=device)
    actual = F.center_crop_mask(mask, output_size)

    expected = _compute_expected_mask(mask, output_size)
    torch.testing.assert_close(expected, actual)
914
915
916


# Copied from test/test_functional_tensor.py
917
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("canvas_size", ("small", "large"))
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)])
@pytest.mark.parametrize("sigma", [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)])
def test_correctness_gaussian_blur_image_tensor(device, canvas_size, dt, ksize, sigma):
    """Compare ``F.gaussian_blur_image_tensor`` with pre-computed OpenCV results.

    float16 on CPU, and parameter combinations without a stored reference result, are
    reported as skipped rather than silently passing.
    """
    if dt == torch.float16 and device == "cpu":
        pytest.skip("float16 is not tested on CPU")

    fn = F.gaussian_blur_image_tensor

    # true_cv2_results = {
    #     # np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
    #     # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.8)
    #     "3_3_0.8": ...
    #     # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.5)
    #     "3_3_0.5": ...
    #     # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.8)
    #     "3_5_0.8": ...
    #     # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.5)
    #     "3_5_0.5": ...
    #     # np_img2 = np.arange(26 * 28, dtype="uint8").reshape((26, 28))
    #     # cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7)
    #     "23_23_1.7": ...
    # }
    p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "gaussian_blur_opencv_results.pt")
    true_cv2_results = torch.load(p)

    if canvas_size == "small":
        tensor = (
            torch.from_numpy(np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))).permute(2, 0, 1).to(device)
        )
    else:
        tensor = torch.from_numpy(np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28))).to(device)

    if dt is not None:
        tensor = tensor.to(dtype=dt)

    _ksize = (ksize, ksize) if isinstance(ksize, int) else ksize
    _sigma = sigma[0] if sigma is not None else None
    shape = tensor.shape
    # Reference results are keyed by "<H>_<W>_<C>__<kernel h>_<kernel w>_<sigma>".
    gt_key = f"{shape[-2]}_{shape[-1]}_{shape[-3]}__{_ksize[0]}_{_ksize[1]}_{_sigma}"
    if gt_key not in true_cv2_results:
        pytest.skip(f"No OpenCV reference result stored for {gt_key}")

    true_out = (
        torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)
    )

    image = datapoints.Image(tensor)

    out = fn(image, kernel_size=ksize, sigma=sigma)
    torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")


973
974
975
976
977
978
979
@pytest.mark.parametrize(
    "inpt",
    [
        127 * np.ones((32, 32, 3), dtype="uint8"),
        PIL.Image.new("RGB", (32, 32), 122),
    ],
)
def test_to_image_tensor(inpt):
    """``F.to_image_tensor`` turns a 32x32 RGB HWC input into a (3, 32, 32) tensor with the same pixels."""
    output = F.to_image_tensor(inpt)
    assert isinstance(output, torch.Tensor)
    assert output.shape == (3, 32, 32)

    # Layout changes, but the total of all pixel values must not.
    assert np.asarray(inpt).sum() == output.sum().item()


@pytest.mark.parametrize(
    "inpt",
    [
        torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8),
        127 * np.ones((32, 32, 3), dtype="uint8"),
    ],
)
@pytest.mark.parametrize("mode", [None, "RGB"])
def test_to_image_pil(inpt, mode):
    """``F.to_image_pil`` returns a PIL image that holds the same pixel values as its input."""
    pil_image = F.to_image_pil(inpt, mode=mode)
    assert isinstance(pil_image, PIL.Image.Image)

    # Compare totals so the check does not depend on the (CHW vs. HWC) layout of the input.
    expected_total = np.asarray(inpt).sum()
    actual_total = np.asarray(pil_image).sum()
    assert expected_total == actual_total
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011


def test_equalize_image_tensor_edge_cases():
    """Edge cases of ``F.equalize_image_tensor``: constant images and two-level batches."""
    # An all-zero image comes back unchanged.
    all_zero = torch.zeros(3, 200, 200, dtype=torch.uint8)
    torch.testing.assert_close(all_zero, F.equalize_image_tensor(all_zero))

    # A batch with only the values 0 and 1 is stretched to exactly {0, 255}.
    two_level = torch.zeros(5, 3, 200, 200, dtype=torch.uint8)
    two_level[..., 100:, 100:] = 1
    equalized = F.equalize_image_tensor(two_level)
    assert equalized.unique().tolist() == [0, 255]
1012
1013


1014
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_uniform_temporal_subsample(device):
    """``F.uniform_temporal_subsample`` picks the expected frame indices from a 10-frame video."""
    # Frame `i` is filled with the value `i`, so the selected frame indices can be read back
    # from the unique values of the output.
    video = torch.arange(10, device=device)[:, None, None, None].expand(-1, 3, 8, 8)
    out_video = F.uniform_temporal_subsample(video, 5)
    assert out_video.unique().tolist() == [0, 2, 4, 6, 9]

    out_video = F.uniform_temporal_subsample(video, 8)
    assert out_video.unique().tolist() == [0, 1, 2, 3, 5, 6, 7, 9]
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051


# TODO: We can remove this test and related torchvision workaround
# once we fixed related pytorch issue: https://github.com/pytorch/pytorch/issues/68430
@make_info_args_kwargs_parametrization(
    [info for info in KERNEL_INFOS if info.kernel is F.resize_image_tensor],
    args_kwargs_fn=lambda info: info.reference_inputs_fn(),
)
def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwargs):
    """Check that ``F.resize_image_tensor`` keeps the memory format of a 3D input.

    The input's memory format is inferred from its strides, and the output strides are
    compared with the dense strides of that format for the output shape.
    """
    (image, *other_args), kwargs = args_kwargs.load("cpu")

    output = info.kernel(image.as_subclass(torch.Tensor), *other_args, **kwargs)

    error_msg_fn = parametrized_error_message(image, *other_args, **kwargs)
    # `error_msg_fn` must be called to build the message; passing the function object
    # itself would only show its repr on failure.
    assert image.ndim == 3, error_msg_fn(f"Expected a 3D input, but got {image.ndim} dimensions.")

    input_stride = image.stride()
    output_stride = output.stride()
    # Infer the input memory format from its strides (shape assumed to be (C, H, W)):
    # - innermost stride 1: channels first, so the output strides should be (H * W, W, 1)
    # - outermost stride 1: channels last, so the output strides should be (1, W * C, C)
    if input_stride[-1] == 1:
        expected_stride = (output.shape[-2] * output.shape[-1], output.shape[-1], 1)
    elif input_stride[0] == 1:
        expected_stride = (1, output.shape[0] * output.shape[-1], output.shape[0])
    else:
        # Raise explicitly instead of `assert False`, which is stripped under `python -O`.
        raise AssertionError(error_msg_fn(f"Cannot infer the memory format from the input strides {input_stride}."))
    assert expected_stride == output_stride, error_msg_fn(
        f"Expected output strides {expected_stride}, but got {output_stride}."
    )
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061


def test_resize_float16_no_rounding():
    # Make sure Resize() doesn't round float16 images
    # Non-regression test for https://github.com/pytorch/vision/issues/7667

    img = torch.randint(0, 256, size=(1, 3, 100, 100), dtype=torch.float16)
    out = F.resize(img, size=(10, 10))
    assert out.dtype == torch.float16
    # Use the absolute difference: signed differences to the nearest integer lie in
    # [-0.5, 0.5] and can cancel out or sum to a negative value even when `out` is not rounded.
    assert (out.round() - out).abs().sum() > 0