import inspect
import math
import os
import re

from typing import get_type_hints

import numpy as np
import PIL.Image
import pytest

import torch

from common_utils import (
    assert_close,
    cache,
    cpu_and_cuda,
    DEFAULT_SQUARE_SPATIAL_SIZE,
    make_bounding_boxes,
    needs_cuda,
    parametrized_error_message,
    set_rng_seed,
)
from torch.utils._pytree import tree_map
from torchvision import datapoints
from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes
from torchvision.transforms.v2.utils import is_simple_tensor

from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
from transforms_v2_kernel_infos import KERNEL_INFOS


KERNEL_INFOS_MAP = {info.kernel: info for info in KERNEL_INFOS}
DISPATCHER_INFOS_MAP = {info.dispatcher: info for info in DISPATCHER_INFOS}


@cache
def script(fn):
    try:
        return torch.jit.script(fn)
    except Exception as error:
        raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error


# Scripting a function often triggers a warning like
# `UserWarning: operator() profile_node %$INT1 : int[] = prim::profile_ivalue($INT2) does not have profile information`
# with varying `INT1` and `INT2`. Since these are uninteresting for us and only clutter the test summary, we ignore
# them.
ignore_jit_warning_no_profile = pytest.mark.filterwarnings(
    f"ignore:{re.escape('operator() profile_node %')}:UserWarning"
)


def make_info_args_kwargs_params(info, *, args_kwargs_fn, test_id=None):
    args_kwargs = list(args_kwargs_fn(info))
    if not args_kwargs:
        raise pytest.UsageError(
            f"Couldn't collect a single `ArgsKwargs` for `{info.id}`{f' in {test_id}' if test_id else ''}"
        )
    idx_field_len = len(str(len(args_kwargs)))
    return [
        pytest.param(
            info,
            args_kwargs_,
            marks=info.get_marks(test_id, args_kwargs_) if test_id else [],
            id=f"{info.id}-{idx:0{idx_field_len}}",
        )
        for idx, args_kwargs_ in enumerate(args_kwargs)
    ]


def make_info_args_kwargs_parametrization(infos, *, args_kwargs_fn):
    def decorator(test_fn):
        parts = test_fn.__qualname__.split(".")
        if len(parts) == 1:
            test_class_name = None
            test_function_name = parts[0]
        elif len(parts) == 2:
            test_class_name, test_function_name = parts
        else:
            raise pytest.UsageError("Unable to parse the test class name and test function name from the test function")
        test_id = (test_class_name, test_function_name)

        argnames = ("info", "args_kwargs")
        argvalues = []
        for info in infos:
            argvalues.extend(make_info_args_kwargs_params(info, args_kwargs_fn=args_kwargs_fn, test_id=test_id))

        return pytest.mark.parametrize(argnames, argvalues)(test_fn)

    return decorator


@pytest.fixture(autouse=True)
def fix_rng_seed():
    set_rng_seed(0)
    yield


@pytest.fixture()
def test_id(request):
    test_class_name = request.cls.__name__ if request.cls is not None else None
    test_function_name = request.node.originalname
    return test_class_name, test_function_name


class TestKernels:
    sample_inputs = make_info_args_kwargs_parametrization(
        KERNEL_INFOS,
        args_kwargs_fn=lambda kernel_info: kernel_info.sample_inputs_fn(),
    )
    reference_inputs = make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.reference_fn is not None],
        args_kwargs_fn=lambda info: info.reference_inputs_fn(),
    )

    @make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.logs_usage],
        args_kwargs_fn=lambda info: info.sample_inputs_fn(),
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_logging(self, spy_on, info, args_kwargs, device):
        spy = spy_on(torch._C._log_api_usage_once)

        (input, *other_args), kwargs = args_kwargs.load(device)
        info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)

        spy.assert_any_call(f"{info.kernel.__module__}.{info.id}")

    @ignore_jit_warning_no_profile
    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_scripted_vs_eager(self, test_id, info, args_kwargs, device):
        kernel_eager = info.kernel
        kernel_scripted = script(kernel_eager)

        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        actual = kernel_scripted(input, *other_args, **kwargs)
        expected = kernel_eager(input, *other_args, **kwargs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )

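    # A hedged reading aid for `_unbatch` (the shapes are illustrative assumptions,
    # not taken from the suite): with data_dims=3, a batched image tensor of shape
    # (2, 5, 3, H, W) unbatches into a list of two lists of five (3, H, W) tensors,
    # and any non-tensor metadata is re-attached to every unbatched element.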
    def _unbatch(self, batch, *, data_dims):
        if isinstance(batch, torch.Tensor):
            batched_tensor = batch
            metadata = ()
        else:
            batched_tensor, *metadata = batch

        if batched_tensor.ndim == data_dims:
            return batch

        return [
            self._unbatch(unbatched, data_dims=data_dims)
            for unbatched in (
                batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
            )
        ]

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_batched_vs_single(self, test_id, info, args_kwargs, device):
        (batched_input, *other_args), kwargs = args_kwargs.load(device)

        datapoint_type = datapoints.Image if is_simple_tensor(batched_input) else type(batched_input)
        # This dictionary contains the number of rightmost dimensions that contain the actual data.
        # Everything to the left is considered a batch dimension.
        data_dims = {
            datapoints.Image: 3,
            datapoints.BoundingBoxes: 1,
            # `Mask`'s are special in the sense that the data dimensions depend on the type of mask. For detection
            # masks it is 3 `(*, N, H, W)`, but for segmentation masks it is 2 `(*, H, W)`. Since both are grouped
            # under one type, all kernels should also work without differentiating between the two. Thus, we go with
            # 2 here as common ground.
            datapoints.Mask: 2,
            datapoints.Video: 4,
        }.get(datapoint_type)
        if data_dims is None:
            raise pytest.UsageError(
                f"The number of data dimensions cannot be determined for input of type {datapoint_type.__name__}."
            ) from None
        elif batched_input.ndim <= data_dims:
            pytest.skip("Input is not batched.")
        elif not all(batched_input.shape[:-data_dims]):
            pytest.skip("Input has a degenerate batch shape.")

        batched_input = batched_input.as_subclass(torch.Tensor)
        batched_output = info.kernel(batched_input, *other_args, **kwargs)
        actual = self._unbatch(batched_output, data_dims=data_dims)

        single_inputs = self._unbatch(batched_input, data_dims=data_dims)
        expected = tree_map(lambda single_input: info.kernel(single_input, *other_args, **kwargs), single_inputs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=batched_input.dtype, device=batched_input.device),
            msg=parametrized_error_message(batched_input, *other_args, **kwargs),
        )

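    # Reader note (an assumption about PyTorch internals, not original text):
    # `Tensor._version` is the in-place mutation counter that every in-place op
    # bumps, so an unchanged value means the kernel did not write to its input.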
    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_no_inplace(self, info, args_kwargs, device):
        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        if input.numel() == 0:
            pytest.skip("The input has a degenerate shape.")

        input_version = input._version
        info.kernel(input, *other_args, **kwargs)

        assert input._version == input_version

    @sample_inputs
    @needs_cuda
    def test_cuda_vs_cpu(self, test_id, info, args_kwargs):
        (input_cpu, *other_args), kwargs = args_kwargs.load("cpu")
        input_cpu = input_cpu.as_subclass(torch.Tensor)
        input_cuda = input_cpu.to("cuda")

        output_cpu = info.kernel(input_cpu, *other_args, **kwargs)
        output_cuda = info.kernel(input_cuda, *other_args, **kwargs)

        assert_close(
            output_cuda,
            output_cpu,
            check_device=False,
            **info.get_closeness_kwargs(test_id, dtype=input_cuda.dtype, device=input_cuda.device),
            msg=parametrized_error_message(input_cpu, *other_args, **kwargs),
        )

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_dtype_and_device_consistency(self, info, args_kwargs, device):
        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        output = info.kernel(input, *other_args, **kwargs)
        # Most kernels just return a tensor, but some also return some additional metadata
        if not isinstance(output, torch.Tensor):
            output, *_ = output

        assert output.dtype == input.dtype
        assert output.device == input.device

    @reference_inputs
    def test_against_reference(self, test_id, info, args_kwargs):
        (input, *other_args), kwargs = args_kwargs.load("cpu")

        actual = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)
        # We intentionally don't unwrap the input of the reference function in order for it to have access to all
        # metadata regardless of whether the kernel takes it explicitly or not
        expected = info.reference_fn(input, *other_args, **kwargs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )

    @make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.float32_vs_uint8],
        args_kwargs_fn=lambda info: info.reference_inputs_fn(),
    )
    def test_float32_vs_uint8(self, test_id, info, args_kwargs):
        (input, *other_args), kwargs = args_kwargs.load("cpu")
        input = input.as_subclass(torch.Tensor)

        if input.dtype != torch.uint8:
            pytest.skip(f"Input dtype is {input.dtype}.")

        adapted_other_args, adapted_kwargs = info.float32_vs_uint8(other_args, kwargs)

        actual = info.kernel(
            F.to_dtype_image_tensor(input, dtype=torch.float32, scale=True),
            *adapted_other_args,
            **adapted_kwargs,
        )

        expected = F.to_dtype_image_tensor(info.kernel(input, *other_args, **kwargs), dtype=torch.float32, scale=True)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=torch.float32, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )


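# Reader note (describing pytest-mock behavior, an assumption rather than original
# text): `mocker.patch(..., wraps=fn)` replaces the attribute with a spy that still
# delegates to `fn`, so the patched call keeps its behavior while the test can
# assert on how it was called.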
@pytest.fixture
def spy_on(mocker):
    def make_spy(fn, *, module=None, name=None):
        # TODO: we can probably get rid of the non-default modules and names if we eliminate aliasing
        module = module or fn.__module__
        name = name or fn.__name__
        spy = mocker.patch(f"{module}.{name}", wraps=fn)
        return spy

    return make_spy


class TestDispatchers:
    image_sample_inputs = make_info_args_kwargs_parametrization(
        [info for info in DISPATCHER_INFOS if datapoints.Image in info.kernels],
        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
    )

    @make_info_args_kwargs_parametrization(
        DISPATCHER_INFOS,
        args_kwargs_fn=lambda info: info.sample_inputs(),
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_logging(self, spy_on, info, args_kwargs, device):
        spy = spy_on(torch._C._log_api_usage_once)

        args, kwargs = args_kwargs.load(device)
        info.dispatcher(*args, **kwargs)

        spy.assert_any_call(f"{info.dispatcher.__module__}.{info.id}")

    @ignore_jit_warning_no_profile
    @image_sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_scripted_smoke(self, info, args_kwargs, device):
        dispatcher = script(info.dispatcher)

        (image_datapoint, *other_args), kwargs = args_kwargs.load(device)
        image_simple_tensor = torch.Tensor(image_datapoint)

        dispatcher(image_simple_tensor, *other_args, **kwargs)

    # TODO: We need this until the dispatchers below also have `DispatcherInfo`'s. If they do, `test_scripted_smoke`
    #  replaces this test for them.
    @ignore_jit_warning_no_profile
    @pytest.mark.parametrize(
        "dispatcher",
        [
            F.get_dimensions,
            F.get_image_num_channels,
            F.get_image_size,
            F.get_num_channels,
            F.get_num_frames,
            F.get_size,
            F.rgb_to_grayscale,
            F.uniform_temporal_subsample,
        ],
        ids=lambda dispatcher: dispatcher.__name__,
    )
    def test_scriptable(self, dispatcher):
        script(dispatcher)

    @image_sample_inputs
    def test_dispatch_simple_tensor(self, info, args_kwargs, spy_on):
        (image_datapoint, *other_args), kwargs = args_kwargs.load()
        image_simple_tensor = torch.Tensor(image_datapoint)

        kernel_info = info.kernel_infos[datapoints.Image]
        spy = spy_on(kernel_info.kernel, module=info.dispatcher.__module__, name=kernel_info.id)

        info.dispatcher(image_simple_tensor, *other_args, **kwargs)

        spy.assert_called_once()

    @image_sample_inputs
    def test_simple_tensor_output_type(self, info, args_kwargs):
        (image_datapoint, *other_args), kwargs = args_kwargs.load()
        image_simple_tensor = image_datapoint.as_subclass(torch.Tensor)

        output = info.dispatcher(image_simple_tensor, *other_args, **kwargs)

        # We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well
        assert type(output) is torch.Tensor

    @make_info_args_kwargs_parametrization(
        [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
    )
    def test_dispatch_pil(self, info, args_kwargs, spy_on):
        (image_datapoint, *other_args), kwargs = args_kwargs.load()

        if image_datapoint.ndim > 3:
            pytest.skip("Input is batched")

        image_pil = F.to_image_pil(image_datapoint)

        pil_kernel_info = info.pil_kernel_info
        spy = spy_on(pil_kernel_info.kernel, module=info.dispatcher.__module__, name=pil_kernel_info.id)

        info.dispatcher(image_pil, *other_args, **kwargs)

        spy.assert_called_once()

    @make_info_args_kwargs_parametrization(
        [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
    )
    def test_pil_output_type(self, info, args_kwargs):
        (image_datapoint, *other_args), kwargs = args_kwargs.load()

        if image_datapoint.ndim > 3:
            pytest.skip("Input is batched")

        image_pil = F.to_image_pil(image_datapoint)

        output = info.dispatcher(image_pil, *other_args, **kwargs)

        assert isinstance(output, PIL.Image.Image)

    @make_info_args_kwargs_parametrization(
        DISPATCHER_INFOS,
        args_kwargs_fn=lambda info: info.sample_inputs(),
    )
    def test_dispatch_datapoint(self, info, args_kwargs, spy_on):
        (datapoint, *other_args), kwargs = args_kwargs.load()

        method_name = info.id
        method = getattr(datapoint, method_name)
        datapoint_type = type(datapoint)
        spy = spy_on(method, module=datapoint_type.__module__, name=f"{datapoint_type.__name__}.{method_name}")

        info.dispatcher(datapoint, *other_args, **kwargs)

        spy.assert_called_once()

    @make_info_args_kwargs_parametrization(
        DISPATCHER_INFOS,
        args_kwargs_fn=lambda info: info.sample_inputs(),
    )
    def test_datapoint_output_type(self, info, args_kwargs):
        (datapoint, *other_args), kwargs = args_kwargs.load()

        output = info.dispatcher(datapoint, *other_args, **kwargs)

        assert isinstance(output, type(datapoint))

    @pytest.mark.parametrize(
        ("dispatcher_info", "datapoint_type", "kernel_info"),
        [
            pytest.param(
                dispatcher_info, datapoint_type, kernel_info, id=f"{dispatcher_info.id}-{datapoint_type.__name__}"
            )
            for dispatcher_info in DISPATCHER_INFOS
            for datapoint_type, kernel_info in dispatcher_info.kernel_infos.items()
        ],
    )
    def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, datapoint_type, kernel_info):
        dispatcher_signature = inspect.signature(dispatcher_info.dispatcher)
        dispatcher_params = list(dispatcher_signature.parameters.values())[1:]

        kernel_signature = inspect.signature(kernel_info.kernel)
        kernel_params = list(kernel_signature.parameters.values())[1:]

        # We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to
        # be explicitly passed to the kernel.
        datapoint_type_metadata = datapoint_type.__annotations__.keys()
        kernel_params = [param for param in kernel_params if param.name not in datapoint_type_metadata]

        dispatcher_params = iter(dispatcher_params)
        for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params):
            try:
                # In general, the dispatcher parameters are a superset of the kernel parameters. Thus, we filter out
                # dispatcher parameters that have no kernel equivalent while keeping the order intact.
                while dispatcher_param.name != kernel_param.name:
                    dispatcher_param = next(dispatcher_params)
            except StopIteration:
                raise AssertionError(
                    f"Parameter `{kernel_param.name}` of kernel `{kernel_info.id}` "
                    f"has no corresponding parameter on the dispatcher `{dispatcher_info.id}`."
                ) from None

            assert dispatcher_param == kernel_param

    @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
    def test_dispatcher_datapoint_signatures_consistency(self, info):
        try:
            datapoint_method = getattr(datapoints._datapoint.Datapoint, info.id)
        except AttributeError:
            pytest.skip("Dispatcher doesn't support arbitrary datapoint dispatch.")

        dispatcher_signature = inspect.signature(info.dispatcher)
        dispatcher_params = list(dispatcher_signature.parameters.values())[1:]

        datapoint_signature = inspect.signature(datapoint_method)
        datapoint_params = list(datapoint_signature.parameters.values())[1:]

        # Because we use `from __future__ import annotations` inside the module where `datapoints._datapoint` is
        # defined, the annotations are stored as strings. `get_type_hints` resolves them into concrete types again,
        # so they can be compared to the natively concrete dispatcher annotations.
        datapoint_annotations = get_type_hints(datapoint_method)
        for param in datapoint_params:
            param._annotation = datapoint_annotations[param.name]

        assert dispatcher_params == datapoint_params

    @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
    def test_unknown_type(self, info):
        unknown_input = object()
        (_, *other_args), kwargs = next(iter(info.sample_inputs())).load("cpu")

        with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
            info.dispatcher(unknown_input, *other_args, **kwargs)

    @make_info_args_kwargs_parametrization(
        [
            info
            for info in DISPATCHER_INFOS
            if datapoints.BoundingBoxes in info.kernels and info.dispatcher is not F.convert_format_bounding_boxes
        ],
        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.BoundingBoxes),
    )
    def test_bounding_boxes_format_consistency(self, info, args_kwargs):
        (bounding_boxes, *other_args), kwargs = args_kwargs.load()
        format = bounding_boxes.format

        output = info.dispatcher(bounding_boxes, *other_args, **kwargs)

        assert output.format == format


@pytest.mark.parametrize(
    ("alias", "target"),
    [
        pytest.param(alias, target, id=alias.__name__)
        for alias, target in [
            (F.hflip, F.horizontal_flip),
            (F.vflip, F.vertical_flip),
            (F.get_image_num_channels, F.get_num_channels),
            (F.to_pil_image, F.to_image_pil),
            (F.elastic_transform, F.elastic),
            (F.to_grayscale, F.rgb_to_grayscale),
        ]
    ],
)
def test_alias(alias, target):
    assert alias is target


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("num_channels", [1, 3])
def test_normalize_image_tensor_stats(device, num_channels):
    stats = pytest.importorskip("scipy.stats", reason="SciPy is not available")

    def assert_samples_from_standard_normal(t):
        p_value = stats.kstest(t.flatten(), cdf="norm", args=(0, 1)).pvalue
        assert p_value > 1e-4

    image = torch.rand(num_channels, DEFAULT_SQUARE_SPATIAL_SIZE, DEFAULT_SQUARE_SPATIAL_SIZE)
    mean = image.mean(dim=(1, 2)).tolist()
    std = image.std(dim=(1, 2)).tolist()

    assert_samples_from_standard_normal(F.normalize_image_tensor(image, mean, std))


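# The two test classes below pin down the metadata-passing contract (summary as a
# reader aid): plain tensors carry no metadata, so `format` / `canvas_size` (or
# `old_format`) must be passed explicitly, while datapoints already carry it and
# must not receive it again.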
class TestClampBoundingBoxes:
    @pytest.mark.parametrize(
        "metadata",
        [
            dict(),
            dict(format=datapoints.BoundingBoxFormat.XYXY),
            dict(canvas_size=(1, 1)),
        ],
    )
    def test_simple_tensor_insufficient_metadata(self, metadata):
        simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)

        with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` has to be passed")):
            F.clamp_bounding_boxes(simple_tensor, **metadata)

    @pytest.mark.parametrize(
        "metadata",
        [
            dict(format=datapoints.BoundingBoxFormat.XYXY),
            dict(canvas_size=(1, 1)),
            dict(format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(1, 1)),
        ],
    )
    def test_datapoint_explicit_metadata(self, metadata):
        datapoint = next(make_bounding_boxes())

        with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` must not be passed")):
            F.clamp_bounding_boxes(datapoint, **metadata)


class TestConvertFormatBoundingBoxes:
    @pytest.mark.parametrize(
        ("inpt", "old_format"),
        [
            (next(make_bounding_boxes()), None),
            (next(make_bounding_boxes()).as_subclass(torch.Tensor), datapoints.BoundingBoxFormat.XYXY),
        ],
    )
    def test_missing_new_format(self, inpt, old_format):
        with pytest.raises(TypeError, match=re.escape("missing 1 required argument: 'new_format'")):
            F.convert_format_bounding_boxes(inpt, old_format)

    def test_simple_tensor_insufficient_metadata(self):
        simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)

        with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")):
            F.convert_format_bounding_boxes(simple_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)

    def test_datapoint_explicit_metadata(self):
        datapoint = next(make_bounding_boxes())

        with pytest.raises(ValueError, match=re.escape("`old_format` must not be passed")):
            F.convert_format_bounding_boxes(
                datapoint, old_format=datapoint.format, new_format=datapoints.BoundingBoxFormat.CXCYWH
            )


# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
#  `transforms_v2_kernel_infos.py`


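# Reader aid restating the composition implemented below: in homogeneous
# coordinates, true_matrix = T @ C @ (RS @ ShearY @ ShearX) @ C^-1, i.e. shear
# and rotate-scale about `center_`, then translate by `translate_`.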
def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_):
    rot = math.radians(angle_)
    cx, cy = center_
    tx, ty = translate_
    sx, sy = [math.radians(sh_) for sh_ in shear_]

    c_matrix = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
    t_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]])
    c_matrix_inv = np.linalg.inv(c_matrix)
    rs_matrix = np.array(
        [
            [scale_ * math.cos(rot), -scale_ * math.sin(rot), 0],
            [scale_ * math.sin(rot), scale_ * math.cos(rot), 0],
            [0, 0, 1],
        ]
    )
    shear_x_matrix = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]])
    shear_y_matrix = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]])
    rss_matrix = np.matmul(rs_matrix, np.matmul(shear_y_matrix, shear_x_matrix))
    true_matrix = np.matmul(t_matrix, np.matmul(c_matrix, np.matmul(rss_matrix, c_matrix_inv)))
    return true_matrix


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "format",
    [datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.CXCYWH],
)
@pytest.mark.parametrize(
    "top, left, height, width, expected_bboxes",
    [
        [8, 12, 30, 40, [(-2.0, 7.0, 13.0, 27.0), (38.0, -3.0, 58.0, 14.0), (33.0, 38.0, 44.0, 54.0)]],
        [-8, 12, 70, 40, [(-2.0, 23.0, 13.0, 43.0), (38.0, 13.0, 58.0, 30.0), (33.0, 54.0, 44.0, 70.0)]],
    ],
)
def test_correctness_crop_bounding_boxes(device, format, top, left, height, width, expected_bboxes):

    # Expected bboxes computed using Albumentations:
    # import numpy as np
    # from albumentations.augmentations.crops.functional import crop_bbox_by_coords, normalize_bbox, denormalize_bbox
    # expected_bboxes = []
    # for in_box in in_boxes:
    #     n_in_box = normalize_bbox(in_box, *size)
    #     n_out_box = crop_bbox_by_coords(
    #         n_in_box, (left, top, left + width, top + height), height, width, *size
    #     )
    #     out_box = denormalize_bbox(n_out_box, height, width)
    #     expected_bboxes.append(out_box)

    canvas_size = (64, 76)
    in_boxes = [
        [10.0, 15.0, 25.0, 35.0],
        [50.0, 5.0, 70.0, 22.0],
        [45.0, 46.0, 56.0, 62.0],
    ]
    in_boxes = torch.tensor(in_boxes, device=device)
    if format != datapoints.BoundingBoxFormat.XYXY:
        in_boxes = convert_format_bounding_boxes(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)

    expected_bboxes = clamp_bounding_boxes(
        datapoints.BoundingBoxes(expected_bboxes, format="XYXY", canvas_size=canvas_size)
    ).tolist()

    output_boxes, output_canvas_size = F.crop_bounding_boxes(
        in_boxes,
        format,
        top,
        left,
        canvas_size[0],
        canvas_size[1],
    )

    if format != datapoints.BoundingBoxFormat.XYXY:
        output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)

    torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
    torch.testing.assert_close(output_canvas_size, canvas_size)


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
    mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
    mask[:, 0, :] = 1

    out_mask = F.vertical_flip_mask(mask)

    expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
    expected_mask[:, -1, :] = 1
    torch.testing.assert_close(out_mask, expected_mask)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "format",
    [datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.CXCYWH],
)
@pytest.mark.parametrize(
    "top, left, height, width, size",
    [
        [0, 0, 30, 30, (60, 60)],
        [-5, 5, 35, 45, (32, 34)],
    ],
)
def test_correctness_resized_crop_bounding_boxes(device, format, top, left, height, width, size):
    def _compute_expected_bbox(bbox, top_, left_, height_, width_, size_):
        # bbox should be xyxy
        bbox[0] = (bbox[0] - left_) * size_[1] / width_
        bbox[1] = (bbox[1] - top_) * size_[0] / height_
        bbox[2] = (bbox[2] - left_) * size_[1] / width_
        bbox[3] = (bbox[3] - top_) * size_[0] / height_
        return bbox

    canvas_size = (100, 100)
    in_boxes = [
        [10.0, 10.0, 20.0, 20.0],
        [5.0, 10.0, 15.0, 20.0],
    ]
    expected_bboxes = []
    for in_box in in_boxes:
        expected_bboxes.append(_compute_expected_bbox(list(in_box), top, left, height, width, size))
    expected_bboxes = torch.tensor(expected_bboxes, device=device)

    in_boxes = datapoints.BoundingBoxes(
        in_boxes, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=canvas_size, device=device
    )
    if format != datapoints.BoundingBoxFormat.XYXY:
        in_boxes = convert_format_bounding_boxes(in_boxes, new_format=format)

    output_boxes, output_canvas_size = F.resized_crop_bounding_boxes(in_boxes, format, top, left, height, width, size)

    if format != datapoints.BoundingBoxFormat.XYXY:
        output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)

    torch.testing.assert_close(output_boxes, expected_bboxes)
    torch.testing.assert_close(output_canvas_size, size)


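# Illustrative expansions for `_parse_padding` (reader aid; the values are made up):
#   1            -> [1, 1, 1, 1]
#   [1]          -> [1, 1, 1, 1]
#   [1, 2]       -> [1, 2, 1, 2]
#   [1, 2, 3, 4] -> unchanged, interpreted as [left, up, right, down]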
def _parse_padding(padding):
    if isinstance(padding, int):
        return [padding] * 4
    if isinstance(padding, list):
        if len(padding) == 1:
            return padding * 4
        if len(padding) == 2:
            return padding * 2  # [left, up, right, down]

    return padding


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("padding", [[1], [1, 1], [1, 1, 2, 2]])
def test_correctness_pad_bounding_boxes(device, padding):
    def _compute_expected_bbox(bbox, padding_):
        pad_left, pad_up, _, _ = _parse_padding(padding_)

        dtype = bbox.dtype
        format = bbox.format
        bbox = (
            bbox.clone()
            if format == datapoints.BoundingBoxFormat.XYXY
            else convert_format_bounding_boxes(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
        )

        bbox[0::2] += pad_left
        bbox[1::2] += pad_up

        bbox = convert_format_bounding_boxes(bbox, new_format=format)
        if bbox.dtype != dtype:
            # Temporary cast to original dtype
            # e.g. float32 -> int
            bbox = bbox.to(dtype)
        return bbox

    def _compute_expected_canvas_size(bbox, padding_):
        pad_left, pad_up, pad_right, pad_down = _parse_padding(padding_)
        height, width = bbox.canvas_size
        return height + pad_up + pad_down, width + pad_left + pad_right

    for bboxes in make_bounding_boxes():
        bboxes = bboxes.to(device)
        bboxes_format = bboxes.format
        bboxes_canvas_size = bboxes.canvas_size

        output_boxes, output_canvas_size = F.pad_bounding_boxes(
            bboxes, format=bboxes_format, canvas_size=bboxes_canvas_size, padding=padding
        )

        torch.testing.assert_close(output_canvas_size, _compute_expected_canvas_size(bboxes, padding))

        if bboxes.ndim < 2 or bboxes.shape[0] == 0:
            bboxes = [bboxes]

        expected_bboxes = []
        for bbox in bboxes:
            bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, canvas_size=bboxes_canvas_size)
            expected_bboxes.append(_compute_expected_bbox(bbox, padding))

        if len(expected_bboxes) > 1:
            expected_bboxes = torch.stack(expected_bboxes)
        else:
            expected_bboxes = expected_bboxes[0]
        torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_pad_segmentation_mask_on_fixed_input(device):
    mask = torch.ones((1, 3, 3), dtype=torch.long, device=device)

    out_mask = F.pad_mask(mask, padding=[1, 1, 1, 1])

    expected_mask = torch.zeros((1, 5, 5), dtype=torch.long, device=device)
    expected_mask[:, 1:-1, 1:-1] = 1
    torch.testing.assert_close(out_mask, expected_mask)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "startpoints, endpoints",
    [
        [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
    ],
)
def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
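    # Reader aid: the eight coefficients (a, ..., h) used below define the standard
    # perspective map x' = (a*x + b*y + c) / (g*x + h*y + 1) and
    # y' = (d*x + e*y + f) / (g*x + h*y + 1); each box corner is mapped and a new
    # axis-aligned box is fitted around the transformed corners.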
    def _compute_expected_bbox(bbox, pcoeffs_):
        m1 = np.array(
            [
                [pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]],
                [pcoeffs_[3], pcoeffs_[4], pcoeffs_[5]],
            ]
        )
        m2 = np.array(
            [
                [pcoeffs_[6], pcoeffs_[7], 1.0],
                [pcoeffs_[6], pcoeffs_[7], 1.0],
            ]
        )

        bbox_xyxy = convert_format_bounding_boxes(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
        points = np.array(
            [
                [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
            ]
        )
        numer = np.matmul(points, m1.T)
        denom = np.matmul(points, m2.T)
        transformed_points = numer / denom
        out_bbox = np.array(
            [
                np.min(transformed_points[:, 0]),
                np.min(transformed_points[:, 1]),
                np.max(transformed_points[:, 0]),
                np.max(transformed_points[:, 1]),
            ]
        )
        out_bbox = datapoints.BoundingBoxes(
            out_bbox,
            format=datapoints.BoundingBoxFormat.XYXY,
            canvas_size=bbox.canvas_size,
            dtype=bbox.dtype,
            device=bbox.device,
        )
        return clamp_bounding_boxes(convert_format_bounding_boxes(out_bbox, new_format=bbox.format))

    canvas_size = (32, 38)

    pcoeffs = _get_perspective_coeffs(startpoints, endpoints)
    inv_pcoeffs = _get_perspective_coeffs(endpoints, startpoints)

    for bboxes in make_bounding_boxes(canvas_size=canvas_size, extra_dims=((4,),)):
        bboxes = bboxes.to(device)

        output_bboxes = F.perspective_bounding_boxes(
            bboxes.as_subclass(torch.Tensor),
            format=bboxes.format,
            canvas_size=bboxes.canvas_size,
            startpoints=None,
            endpoints=None,
            coefficients=pcoeffs,
        )

        if bboxes.ndim < 2:
            bboxes = [bboxes]

        expected_bboxes = []
        for bbox in bboxes:
            bbox = datapoints.BoundingBoxes(bbox, format=bboxes.format, canvas_size=bboxes.canvas_size)
            expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs))
        if len(expected_bboxes) > 1:
            expected_bboxes = torch.stack(expected_bboxes)
        else:
            expected_bboxes = expected_bboxes[0]
        torch.testing.assert_close(output_bboxes, expected_bboxes, rtol=0, atol=1)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "output_size",
    [(18, 18), [18, 15], (16, 19), [12], [46, 48]],
)
def test_correctness_center_crop_bounding_boxes(device, output_size):
    def _compute_expected_bbox(bbox, output_size_):
        format_ = bbox.format
        canvas_size_ = bbox.canvas_size
        dtype = bbox.dtype
        bbox = convert_format_bounding_boxes(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH)

        if len(output_size_) == 1:
            output_size_.append(output_size_[-1])

        cy = int(round((canvas_size_[0] - output_size_[0]) * 0.5))
        cx = int(round((canvas_size_[1] - output_size_[1]) * 0.5))
        out_bbox = [
            bbox[0].item() - cx,
            bbox[1].item() - cy,
            bbox[2].item(),
            bbox[3].item(),
        ]
        out_bbox = torch.tensor(out_bbox)
        out_bbox = convert_format_bounding_boxes(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_)
        out_bbox = clamp_bounding_boxes(out_bbox, format=format_, canvas_size=output_size_)
        return out_bbox.to(dtype=dtype, device=bbox.device)

    for bboxes in make_bounding_boxes(extra_dims=((4,),)):
        bboxes = bboxes.to(device)
        bboxes_format = bboxes.format
        bboxes_canvas_size = bboxes.canvas_size

        output_boxes, output_canvas_size = F.center_crop_bounding_boxes(
            bboxes, bboxes_format, bboxes_canvas_size, output_size
        )

        if bboxes.ndim < 2:
            bboxes = [bboxes]

        expected_bboxes = []
        for bbox in bboxes:
            bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, canvas_size=bboxes_canvas_size)
            expected_bboxes.append(_compute_expected_bbox(bbox, output_size))

        if len(expected_bboxes) > 1:
            expected_bboxes = torch.stack(expected_bboxes)
        else:
            expected_bboxes = expected_bboxes[0]

        torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)
        torch.testing.assert_close(output_canvas_size, output_size)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("output_size", [[4, 2], [4], [7, 6]])
def test_correctness_center_crop_mask(device, output_size):
    def _compute_expected_mask(mask, output_size):
        crop_height, crop_width = output_size if len(output_size) > 1 else [output_size[0], output_size[0]]

        _, image_height, image_width = mask.shape
        if crop_width > image_width or crop_height > image_height:
            padding = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
            mask = F.pad_image_tensor(mask, padding, fill=0)

        left = round((image_width - crop_width) * 0.5)
        top = round((image_height - crop_height) * 0.5)

        return mask[:, top : top + crop_height, left : left + crop_width]

    mask = torch.randint(0, 2, size=(1, 6, 6), dtype=torch.long, device=device)
    actual = F.center_crop_mask(mask, output_size)

    expected = _compute_expected_mask(mask, output_size)
    torch.testing.assert_close(expected, actual)


# Copied from test/test_functional_tensor.py
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("canvas_size", ("small", "large"))
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)])
@pytest.mark.parametrize("sigma", [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)])
def test_correctness_gaussian_blur_image_tensor(device, canvas_size, dt, ksize, sigma):
    fn = F.gaussian_blur_image_tensor

    # true_cv2_results = {
    #     # np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
    #     # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.8)
    #     "3_3_0.8": ...
    #     # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.5)
    #     "3_3_0.5": ...
    #     # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.8)
    #     "3_5_0.8": ...
    #     # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.5)
    #     "3_5_0.5": ...
    #     # np_img2 = np.arange(26 * 28, dtype="uint8").reshape((26, 28))
    #     # cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7)
    #     "23_23_1.7": ...
    # }
    p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "gaussian_blur_opencv_results.pt")
    true_cv2_results = torch.load(p)

    if canvas_size == "small":
        tensor = (
            torch.from_numpy(np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))).permute(2, 0, 1).to(device)
        )
    else:
        tensor = torch.from_numpy(np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28))).to(device)

    if dt == torch.float16 and device == "cpu":
        # skip float16 on CPU case
        return

    if dt is not None:
        tensor = tensor.to(dtype=dt)

    _ksize = (ksize, ksize) if isinstance(ksize, int) else ksize
    _sigma = sigma[0] if sigma is not None else None
    shape = tensor.shape
    gt_key = f"{shape[-2]}_{shape[-1]}_{shape[-3]}__{_ksize[0]}_{_ksize[1]}_{_sigma}"
    if gt_key not in true_cv2_results:
        return

    true_out = (
        torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)
    )

    image = datapoints.Image(tensor)

    out = fn(image, kernel_size=ksize, sigma=sigma)
    torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")


@pytest.mark.parametrize(
    "inpt",
    [
        127 * np.ones((32, 32, 3), dtype="uint8"),
        PIL.Image.new("RGB", (32, 32), 122),
    ],
)
def test_to_image_tensor(inpt):
    output = F.to_image_tensor(inpt)
    assert isinstance(output, torch.Tensor)
    assert output.shape == (3, 32, 32)

    assert np.asarray(inpt).sum() == output.sum().item()


@pytest.mark.parametrize(
    "inpt",
    [
        torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8),
        127 * np.ones((32, 32, 3), dtype="uint8"),
    ],
)
@pytest.mark.parametrize("mode", [None, "RGB"])
def test_to_image_pil(inpt, mode):
    output = F.to_image_pil(inpt, mode=mode)
    assert isinstance(output, PIL.Image.Image)

    assert np.asarray(inpt).sum() == np.asarray(output).sum()


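# Reader note (an assumption spelling out the two edge cases below): a constant
# image has a degenerate histogram that equalization must leave unchanged, while
# a two-valued image should be stretched to the extremes of the uint8 range.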
def test_equalize_image_tensor_edge_cases():
    inpt = torch.zeros(3, 200, 200, dtype=torch.uint8)
    output = F.equalize_image_tensor(inpt)
    torch.testing.assert_close(inpt, output)

    inpt = torch.zeros(5, 3, 200, 200, dtype=torch.uint8)
    inpt[..., 100:, 100:] = 1
    output = F.equalize_image_tensor(inpt)
    assert output.unique().tolist() == [0, 255]


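# Reader aid (assumes even index spacing via linspace, which matches the expected
# values below): picking 5 of 10 frames takes indices linspace(0, 9, 5) =
# [0, 2.25, 4.5, 6.75, 9], truncated to [0, 2, 4, 6, 9].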
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_uniform_temporal_subsample(device):
    video = torch.arange(10, device=device)[:, None, None, None].expand(-1, 3, 8, 8)
    out_video = F.uniform_temporal_subsample(video, 5)
    assert out_video.unique().tolist() == [0, 2, 4, 6, 9]

    out_video = F.uniform_temporal_subsample(video, 8)
    assert out_video.unique().tolist() == [0, 1, 2, 3, 5, 6, 7, 9]


# TODO: We can remove this test and related torchvision workaround
# once we fix the related pytorch issue: https://github.com/pytorch/pytorch/issues/68430
@make_info_args_kwargs_parametrization(
    [info for info in KERNEL_INFOS if info.kernel is F.resize_image_tensor],
    args_kwargs_fn=lambda info: info.reference_inputs_fn(),
)
def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwargs):
    (input, *other_args), kwargs = args_kwargs.load("cpu")

    output = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)

    error_msg_fn = parametrized_error_message(input, *other_args, **kwargs)
    assert input.ndim == 3, error_msg_fn
    input_stride = input.stride()
    output_stride = output.stride()
    # Here we check output memory format according to the input:
    # if input_stride is (..., 1) then input is most likely channels first and thus
    #   output strides should match channels first strides (H * W, H, 1)
    # if input_stride is (1, ...) then input is most likely channels last and thus
    #   output strides should match channels last strides (1, W * C, C)
    if input_stride[-1] == 1:
        expected_stride = (output.shape[-2] * output.shape[-1], output.shape[-1], 1)
        assert expected_stride == output_stride, error_msg_fn("")
    elif input_stride[0] == 1:
        expected_stride = (1, output.shape[0] * output.shape[-1], output.shape[0])
        assert expected_stride == output_stride, error_msg_fn("")
    else:
        assert False, error_msg_fn("")


def test_resize_float16_no_rounding():
    # Make sure Resize() doesn't round float16 images
    # Non-regression test for https://github.com/pytorch/vision/issues/7667

    img = torch.randint(0, 256, size=(1, 3, 100, 100), dtype=torch.float16)
    out = F.resize(img, size=(10, 10))
    assert out.dtype == torch.float16
    assert (out.round() - out).sum() > 0