import inspect
import math
import os
import re

from typing import get_type_hints

import numpy as np
import PIL.Image
import pytest

import torch

from common_utils import (
    assert_close,
    cache,
    cpu_and_cuda,
    DEFAULT_SQUARE_SPATIAL_SIZE,
    make_bounding_boxes,
    needs_cuda,
    parametrized_error_message,
    set_rng_seed,
)
from torch.utils._pytree import tree_map
from torchvision import datapoints
from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
from torchvision.transforms.v2.functional._meta import clamp_bounding_box, convert_format_bounding_box
from torchvision.transforms.v2.utils import is_simple_tensor
from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
from transforms_v2_kernel_infos import KERNEL_INFOS


KERNEL_INFOS_MAP = {info.kernel: info for info in KERNEL_INFOS}
DISPATCHER_INFOS_MAP = {info.dispatcher: info for info in DISPATCHER_INFOS}


@cache
def script(fn):
    try:
        return torch.jit.script(fn)
    except Exception as error:
        raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error
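
# Because `script` is wrapped in `@cache`, scripting happens at most once per function, e.g. repeated
# `script(F.resize_image_tensor)` calls across tests reuse the compiled result (illustrative example; any kernel
# or dispatcher collected in the infos works the same way).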


# Scripting a function often triggers a warning like
# `UserWarning: operator() profile_node %$INT1 : int[] = prim::profile_ivalue($INT2) does not have profile information`
# with varying `INT1` and `INT2`. Since these are uninteresting for us and only clutter the test summary, we ignore
# them.
ignore_jit_warning_no_profile = pytest.mark.filterwarnings(
    f"ignore:{re.escape('operator() profile_node %')}:UserWarning"
)


def make_info_args_kwargs_params(info, *, args_kwargs_fn, test_id=None):
    args_kwargs = list(args_kwargs_fn(info))
    if not args_kwargs:
        raise pytest.UsageError(
            f"Couldn't collect a single `ArgsKwargs` for `{info.id}`{f' in {test_id}' if test_id else ''}"
        )
    idx_field_len = len(str(len(args_kwargs)))
    return [
        pytest.param(
            info,
            args_kwargs_,
            marks=info.get_marks(test_id, args_kwargs_) if test_id else [],
            id=f"{info.id}-{idx:0{idx_field_len}}",
        )
        for idx, args_kwargs_ in enumerate(args_kwargs)
    ]
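

# For example, an info with id "resize_image_tensor" that yields 12 `ArgsKwargs` produces the pytest ids
# "resize_image_tensor-00" ... "resize_image_tensor-11" (zero-padded to the width of the largest index).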


def make_info_args_kwargs_parametrization(infos, *, args_kwargs_fn):
    def decorator(test_fn):
        parts = test_fn.__qualname__.split(".")
        if len(parts) == 1:
            test_class_name = None
            test_function_name = parts[0]
        elif len(parts) == 2:
            test_class_name, test_function_name = parts
        else:
            raise pytest.UsageError("Unable to parse the test class name and test function name from test function")
        test_id = (test_class_name, test_function_name)

        argnames = ("info", "args_kwargs")
        argvalues = []
        for info in infos:
            argvalues.extend(make_info_args_kwargs_params(info, args_kwargs_fn=args_kwargs_fn, test_id=test_id))

        return pytest.mark.parametrize(argnames, argvalues)(test_fn)

    return decorator


@pytest.fixture(autouse=True)
def fix_rng_seed():
    set_rng_seed(0)
    yield


@pytest.fixture()
def test_id(request):
    test_class_name = request.cls.__name__ if request.cls is not None else None
    test_function_name = request.node.originalname
    return test_class_name, test_function_name


class TestKernels:
    sample_inputs = make_info_args_kwargs_parametrization(
        KERNEL_INFOS,
        args_kwargs_fn=lambda kernel_info: kernel_info.sample_inputs_fn(),
    )
    reference_inputs = make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.reference_fn is not None],
        args_kwargs_fn=lambda info: info.reference_inputs_fn(),
    )

    @make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.logs_usage],
        args_kwargs_fn=lambda info: info.sample_inputs_fn(),
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_logging(self, spy_on, info, args_kwargs, device):
        spy = spy_on(torch._C._log_api_usage_once)

        (input, *other_args), kwargs = args_kwargs.load(device)
        info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)

        spy.assert_any_call(f"{info.kernel.__module__}.{info.id}")

    @ignore_jit_warning_no_profile
    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_scripted_vs_eager(self, test_id, info, args_kwargs, device):
        kernel_eager = info.kernel
        kernel_scripted = script(kernel_eager)

        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        actual = kernel_scripted(input, *other_args, **kwargs)
        expected = kernel_eager(input, *other_args, **kwargs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
            msg=parametrized_error_message(input, other_args, **kwargs),
        )

    def _unbatch(self, batch, *, data_dims):
        if isinstance(batch, torch.Tensor):
            batched_tensor = batch
            metadata = ()
        else:
            batched_tensor, *metadata = batch

        if batched_tensor.ndim == data_dims:
            return batch

        return [
            self._unbatch(unbatched, data_dims=data_dims)
            for unbatched in (
                batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
            )
        ]
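
    # For example, `_unbatch(batch, data_dims=3)` on a plain (2, 3, H, W) image batch returns a list of two
    # (3, H, W) tensors; extra leading batch dimensions are unbatched recursively, and metadata from non-tensor
    # inputs is re-attached to every unbatched element.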

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_batched_vs_single(self, test_id, info, args_kwargs, device):
        (batched_input, *other_args), kwargs = args_kwargs.load(device)

        datapoint_type = datapoints.Image if is_simple_tensor(batched_input) else type(batched_input)
        # This dictionary contains the number of rightmost dimensions that contain the actual data.
        # Everything to the left is considered a batch dimension.
        data_dims = {
            datapoints.Image: 3,
            datapoints.BoundingBox: 1,
            # `Mask`'s are special in the sense that the data dimensions depend on the type of mask. For detection
            # masks it is 3 `(*, N, H, W)`, but for segmentation masks it is 2 `(*, H, W)`. Since both are grouped
            # under one type, all kernels should also work without differentiating between the two. Thus, we go with
            # 2 here as common ground.
            datapoints.Mask: 2,
            datapoints.Video: 4,
        }.get(datapoint_type)
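        # For example, a batched `datapoints.Image` of shape (2, 4, 3, 16, 16) has batch dims (2, 4) and
        # data dims (3, 16, 16).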
        if data_dims is None:
            raise pytest.UsageError(
                f"The number of data dimensions cannot be determined for input of type {datapoint_type.__name__}."
            ) from None
        elif batched_input.ndim <= data_dims:
            pytest.skip("Input is not batched.")
        elif not all(batched_input.shape[:-data_dims]):
            pytest.skip("Input has a degenerate batch shape.")

        batched_input = batched_input.as_subclass(torch.Tensor)
        batched_output = info.kernel(batched_input, *other_args, **kwargs)
        actual = self._unbatch(batched_output, data_dims=data_dims)

        single_inputs = self._unbatch(batched_input, data_dims=data_dims)
        expected = tree_map(lambda single_input: info.kernel(single_input, *other_args, **kwargs), single_inputs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=batched_input.dtype, device=batched_input.device),
            msg=parametrized_error_message(batched_input, *other_args, **kwargs),
        )

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_no_inplace(self, info, args_kwargs, device):
        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        if input.numel() == 0:
            pytest.skip("The input has a degenerate shape.")

        input_version = input._version
        info.kernel(input, *other_args, **kwargs)

        assert input._version == input_version

    @sample_inputs
    @needs_cuda
    def test_cuda_vs_cpu(self, test_id, info, args_kwargs):
        (input_cpu, *other_args), kwargs = args_kwargs.load("cpu")
        input_cpu = input_cpu.as_subclass(torch.Tensor)
        input_cuda = input_cpu.to("cuda")

        output_cpu = info.kernel(input_cpu, *other_args, **kwargs)
        output_cuda = info.kernel(input_cuda, *other_args, **kwargs)

        assert_close(
            output_cuda,
            output_cpu,
            check_device=False,
            **info.get_closeness_kwargs(test_id, dtype=input_cuda.dtype, device=input_cuda.device),
            msg=parametrized_error_message(input_cpu, *other_args, **kwargs),
        )

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_dtype_and_device_consistency(self, info, args_kwargs, device):
        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        output = info.kernel(input, *other_args, **kwargs)
        # Most kernels just return a tensor, but some also return some additional metadata
        if not isinstance(output, torch.Tensor):
            output, *_ = output

        assert output.dtype == input.dtype
        assert output.device == input.device

    @reference_inputs
    def test_against_reference(self, test_id, info, args_kwargs):
        (input, *other_args), kwargs = args_kwargs.load("cpu")

        actual = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)
        # We intentionally don't unwrap the input of the reference function in order for it to have access to all
        # metadata regardless of whether the kernel takes it explicitly or not
        expected = info.reference_fn(input, *other_args, **kwargs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )

    @make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.float32_vs_uint8],
        args_kwargs_fn=lambda info: info.reference_inputs_fn(),
    )
    def test_float32_vs_uint8(self, test_id, info, args_kwargs):
        (input, *other_args), kwargs = args_kwargs.load("cpu")
        input = input.as_subclass(torch.Tensor)

        if input.dtype != torch.uint8:
            pytest.skip(f"Input dtype is {input.dtype}.")

        adapted_other_args, adapted_kwargs = info.float32_vs_uint8(other_args, kwargs)

        actual = info.kernel(
            F.convert_dtype_image_tensor(input, dtype=torch.float32),
            *adapted_other_args,
            **adapted_kwargs,
        )

        expected = F.convert_dtype_image_tensor(info.kernel(input, *other_args, **kwargs), dtype=torch.float32)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=torch.float32, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )


@pytest.fixture
def spy_on(mocker):
    def make_spy(fn, *, module=None, name=None):
        # TODO: we can probably get rid of the non-default modules and names if we eliminate aliasing
        module = module or fn.__module__
        name = name or fn.__name__
        spy = mocker.patch(f"{module}.{name}", wraps=fn)
        return spy

    return make_spy


class TestDispatchers:
    image_sample_inputs = make_info_args_kwargs_parametrization(
        [info for info in DISPATCHER_INFOS if datapoints.Image in info.kernels],
        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
    )

    @make_info_args_kwargs_parametrization(
        DISPATCHER_INFOS,
        args_kwargs_fn=lambda info: info.sample_inputs(),
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_logging(self, spy_on, info, args_kwargs, device):
        spy = spy_on(torch._C._log_api_usage_once)

        args, kwargs = args_kwargs.load(device)
        info.dispatcher(*args, **kwargs)

        spy.assert_any_call(f"{info.dispatcher.__module__}.{info.id}")

    @ignore_jit_warning_no_profile
    @image_sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_scripted_smoke(self, info, args_kwargs, device):
        dispatcher = script(info.dispatcher)

        (image_datapoint, *other_args), kwargs = args_kwargs.load(device)
        image_simple_tensor = torch.Tensor(image_datapoint)

        dispatcher(image_simple_tensor, *other_args, **kwargs)

    # TODO: We need this until the dispatchers below also have `DispatcherInfo`'s. If they do, `test_scripted_smoke`
    #  replaces this test for them.
    @ignore_jit_warning_no_profile
    @pytest.mark.parametrize(
        "dispatcher",
        [
            F.get_dimensions,
            F.get_image_num_channels,
            F.get_image_size,
            F.get_num_channels,
            F.get_num_frames,
            F.get_spatial_size,
            F.rgb_to_grayscale,
            F.uniform_temporal_subsample,
        ],
        ids=lambda dispatcher: dispatcher.__name__,
    )
    def test_scriptable(self, dispatcher):
        script(dispatcher)

    @image_sample_inputs
    def test_dispatch_simple_tensor(self, info, args_kwargs, spy_on):
        (image_datapoint, *other_args), kwargs = args_kwargs.load()
        image_simple_tensor = torch.Tensor(image_datapoint)

        kernel_info = info.kernel_infos[datapoints.Image]
        spy = spy_on(kernel_info.kernel, module=info.dispatcher.__module__, name=kernel_info.id)

        info.dispatcher(image_simple_tensor, *other_args, **kwargs)

        spy.assert_called_once()

    @image_sample_inputs
    def test_simple_tensor_output_type(self, info, args_kwargs):
        (image_datapoint, *other_args), kwargs = args_kwargs.load()
        image_simple_tensor = image_datapoint.as_subclass(torch.Tensor)

        output = info.dispatcher(image_simple_tensor, *other_args, **kwargs)

        # We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well
        assert type(output) is torch.Tensor

385
    @make_info_args_kwargs_parametrization(
386
        [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
387
        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
388
389
    )
    def test_dispatch_pil(self, info, args_kwargs, spy_on):
390
        (image_datapoint, *other_args), kwargs = args_kwargs.load()
391

392
        if image_datapoint.ndim > 3:
393
394
            pytest.skip("Input is batched")

395
        image_pil = F.to_image_pil(image_datapoint)
396
397

        pil_kernel_info = info.pil_kernel_info
398
        spy = spy_on(pil_kernel_info.kernel, module=info.dispatcher.__module__, name=pil_kernel_info.id)
399
400
401
402
403

        info.dispatcher(image_pil, *other_args, **kwargs)

        spy.assert_called_once()

    @make_info_args_kwargs_parametrization(
        [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
    )
    def test_pil_output_type(self, info, args_kwargs):
        (image_datapoint, *other_args), kwargs = args_kwargs.load()

        if image_datapoint.ndim > 3:
            pytest.skip("Input is batched")

        image_pil = F.to_image_pil(image_datapoint)

        output = info.dispatcher(image_pil, *other_args, **kwargs)

        assert isinstance(output, PIL.Image.Image)

    @make_info_args_kwargs_parametrization(
        DISPATCHER_INFOS,
        args_kwargs_fn=lambda info: info.sample_inputs(),
    )
    def test_dispatch_datapoint(self, info, args_kwargs, spy_on):
        (datapoint, *other_args), kwargs = args_kwargs.load()

        method_name = info.id
        method = getattr(datapoint, method_name)
        datapoint_type = type(datapoint)
        spy = spy_on(method, module=datapoint_type.__module__, name=f"{datapoint_type.__name__}.{method_name}")

        info.dispatcher(datapoint, *other_args, **kwargs)

        spy.assert_called_once()

    @make_info_args_kwargs_parametrization(
        DISPATCHER_INFOS,
        args_kwargs_fn=lambda info: info.sample_inputs(),
    )
    def test_datapoint_output_type(self, info, args_kwargs):
        (datapoint, *other_args), kwargs = args_kwargs.load()

        output = info.dispatcher(datapoint, *other_args, **kwargs)

        assert isinstance(output, type(datapoint))

    @pytest.mark.parametrize(
        ("dispatcher_info", "datapoint_type", "kernel_info"),
        [
            pytest.param(
                dispatcher_info, datapoint_type, kernel_info, id=f"{dispatcher_info.id}-{datapoint_type.__name__}"
            )
            for dispatcher_info in DISPATCHER_INFOS
            for datapoint_type, kernel_info in dispatcher_info.kernel_infos.items()
        ],
    )
    def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, datapoint_type, kernel_info):
        dispatcher_signature = inspect.signature(dispatcher_info.dispatcher)
        dispatcher_params = list(dispatcher_signature.parameters.values())[1:]

        kernel_signature = inspect.signature(kernel_info.kernel)
        kernel_params = list(kernel_signature.parameters.values())[1:]

        # We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to
        # be explicitly passed to the kernel.
        datapoint_type_metadata = datapoint_type.__annotations__.keys()
        kernel_params = [param for param in kernel_params if param.name not in datapoint_type_metadata]
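        # For example, `spatial_size` is part of the `BoundingBox` metadata: the dispatcher reads it from the
        # datapoint, while the corresponding kernel takes it as an explicit parameter, so it is dropped from the
        # kernel parameters before the comparison below.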

        dispatcher_params = iter(dispatcher_params)
        for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params):
            try:
                # In general, the dispatcher parameters are a superset of the kernel parameters. Thus, we filter out
                # dispatcher parameters that have no kernel equivalent while keeping the order intact.
                while dispatcher_param.name != kernel_param.name:
                    dispatcher_param = next(dispatcher_params)
            except StopIteration:
                raise AssertionError(
                    f"Parameter `{kernel_param.name}` of kernel `{kernel_info.id}` "
                    f"has no corresponding parameter on the dispatcher `{dispatcher_info.id}`."
                ) from None

            assert dispatcher_param == kernel_param

    @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
    def test_dispatcher_datapoint_signatures_consistency(self, info):
        try:
            datapoint_method = getattr(datapoints._datapoint.Datapoint, info.id)
        except AttributeError:
            pytest.skip("Dispatcher doesn't support arbitrary datapoint dispatch.")

        dispatcher_signature = inspect.signature(info.dispatcher)
        dispatcher_params = list(dispatcher_signature.parameters.values())[1:]

        datapoint_signature = inspect.signature(datapoint_method)
        datapoint_params = list(datapoint_signature.parameters.values())[1:]

        # Because we use `from __future__ import annotations` inside the module where `datapoints._datapoint` is
        # defined, the annotations are stored as strings. This makes them concrete again, so they can be compared to the
        # natively concrete dispatcher annotations.
        datapoint_annotations = get_type_hints(datapoint_method)
        for param in datapoint_params:
            param._annotation = datapoint_annotations[param.name]

        assert dispatcher_params == datapoint_params

    @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
    def test_unknown_type(self, info):
        unknown_input = object()
        (_, *other_args), kwargs = next(iter(info.sample_inputs())).load("cpu")

        with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
            info.dispatcher(unknown_input, *other_args, **kwargs)

    @make_info_args_kwargs_parametrization(
        [
            info
            for info in DISPATCHER_INFOS
            if datapoints.BoundingBox in info.kernels and info.dispatcher is not F.convert_format_bounding_box
        ],
        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.BoundingBox),
    )
    def test_bounding_box_format_consistency(self, info, args_kwargs):
        (bounding_box, *other_args), kwargs = args_kwargs.load()
        format = bounding_box.format

        output = info.dispatcher(bounding_box, *other_args, **kwargs)

        assert output.format == format


@pytest.mark.parametrize(
    ("alias", "target"),
    [
        pytest.param(alias, target, id=alias.__name__)
        for alias, target in [
            (F.hflip, F.horizontal_flip),
            (F.vflip, F.vertical_flip),
            (F.get_image_num_channels, F.get_num_channels),
            (F.to_pil_image, F.to_image_pil),
            (F.elastic_transform, F.elastic),
            (F.convert_image_dtype, F.convert_dtype_image_tensor),
        ]
    ],
)
def test_alias(alias, target):
    assert alias is target


@pytest.mark.parametrize(
    ("info", "args_kwargs"),
    make_info_args_kwargs_params(
        KERNEL_INFOS_MAP[F.convert_dtype_image_tensor],
        args_kwargs_fn=lambda info: info.sample_inputs_fn(),
    ),
)
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_convert_dtype_image_tensor_dtype_and_device(info, args_kwargs, device):
    (input, *other_args), kwargs = args_kwargs.load(device)
    dtype = other_args[0] if other_args else kwargs.get("dtype", torch.float32)

    output = info.kernel(input, dtype)

    assert output.dtype == dtype
    assert output.device == input.device


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("num_channels", [1, 3])
def test_normalize_image_tensor_stats(device, num_channels):
    stats = pytest.importorskip("scipy.stats", reason="SciPy is not available")

    def assert_samples_from_standard_normal(t):
        p_value = stats.kstest(t.flatten(), cdf="norm", args=(0, 1)).pvalue
        return p_value > 1e-4

    image = torch.rand(num_channels, DEFAULT_SQUARE_SPATIAL_SIZE, DEFAULT_SQUARE_SPATIAL_SIZE)
    mean = image.mean(dim=(1, 2)).tolist()
    std = image.std(dim=(1, 2)).tolist()

    assert_samples_from_standard_normal(F.normalize_image_tensor(image, mean, std))


class TestClampBoundingBox:
    @pytest.mark.parametrize(
        "metadata",
        [
            dict(),
            dict(format=datapoints.BoundingBoxFormat.XYXY),
            dict(spatial_size=(1, 1)),
        ],
    )
    def test_simple_tensor_insufficient_metadata(self, metadata):
        simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)

        with pytest.raises(ValueError, match=re.escape("`format` and `spatial_size` has to be passed")):
            F.clamp_bounding_box(simple_tensor, **metadata)

    @pytest.mark.parametrize(
        "metadata",
        [
            dict(format=datapoints.BoundingBoxFormat.XYXY),
            dict(spatial_size=(1, 1)),
            dict(format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(1, 1)),
        ],
    )
    def test_datapoint_explicit_metadata(self, metadata):
        datapoint = next(make_bounding_boxes())

        with pytest.raises(ValueError, match=re.escape("`format` and `spatial_size` must not be passed")):
            F.clamp_bounding_box(datapoint, **metadata)


class TestConvertFormatBoundingBox:
    @pytest.mark.parametrize(
        ("inpt", "old_format"),
        [
            (next(make_bounding_boxes()), None),
            (next(make_bounding_boxes()).as_subclass(torch.Tensor), datapoints.BoundingBoxFormat.XYXY),
        ],
    )
    def test_missing_new_format(self, inpt, old_format):
        with pytest.raises(TypeError, match=re.escape("missing 1 required argument: 'new_format'")):
            F.convert_format_bounding_box(inpt, old_format)

    def test_simple_tensor_insufficient_metadata(self):
        simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)

        with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")):
            F.convert_format_bounding_box(simple_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)

    def test_datapoint_explicit_metadata(self):
        datapoint = next(make_bounding_boxes())

        with pytest.raises(ValueError, match=re.escape("`old_format` must not be passed")):
            F.convert_format_bounding_box(
                datapoint, old_format=datapoint.format, new_format=datapoints.BoundingBoxFormat.CXCYWH
            )


# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
#  `transforms_v2_kernel_infos.py`


def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_):
    rot = math.radians(angle_)
    cx, cy = center_
    tx, ty = translate_
    sx, sy = [math.radians(sh_) for sh_ in shear_]

    c_matrix = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
    t_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]])
    c_matrix_inv = np.linalg.inv(c_matrix)
    rs_matrix = np.array(
        [
            [scale_ * math.cos(rot), -scale_ * math.sin(rot), 0],
            [scale_ * math.sin(rot), scale_ * math.cos(rot), 0],
            [0, 0, 1],
        ]
    )
    shear_x_matrix = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]])
    shear_y_matrix = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]])
    rss_matrix = np.matmul(rs_matrix, np.matmul(shear_y_matrix, shear_x_matrix))
    true_matrix = np.matmul(t_matrix, np.matmul(c_matrix, np.matmul(rss_matrix, c_matrix_inv)))
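    # i.e. M = T @ C @ (RS @ ShearY @ ShearX) @ C^{-1}: rotate, scale and shear about `center_`, then translate.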
    return true_matrix


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_affine_bounding_box_on_fixed_input(device):
    # Check transformation against known expected output
    format = datapoints.BoundingBoxFormat.XYXY
    spatial_size = (64, 64)
    in_boxes = [
        [20, 25, 35, 45],
        [50, 5, 70, 22],
        [spatial_size[1] // 2 - 10, spatial_size[0] // 2 - 10, spatial_size[1] // 2 + 10, spatial_size[0] // 2 + 10],
        [1, 1, 5, 5],
    ]
    in_boxes = torch.tensor(in_boxes, dtype=torch.float64, device=device)
    # Tested parameters
    angle = 63
    scale = 0.89
    dx = 0.12
    dy = 0.23

    # Expected bboxes computed using albumentations:
    # from albumentations.augmentations.geometric.functional import bbox_shift_scale_rotate
    # from albumentations.augmentations.geometric.functional import normalize_bbox, denormalize_bbox
    # expected_bboxes = []
    # for in_box in in_boxes:
    #     n_in_box = normalize_bbox(in_box, *spatial_size)
    #     n_out_box = bbox_shift_scale_rotate(n_in_box, -angle, scale, dx, dy, *spatial_size)
    #     out_box = denormalize_bbox(n_out_box, *spatial_size)
    #     expected_bboxes.append(out_box)
    expected_bboxes = [
        (24.522435977922218, 34.375689508290854, 46.443125279998114, 54.3516575015695),
        (54.88288587110401, 50.08453280875634, 76.44484547743795, 72.81332520036864),
        (27.709526487041554, 34.74952648704156, 51.650473512958435, 58.69047351295844),
        (48.56528888843238, 9.611532109828834, 53.35347829361575, 14.39972151501221),
    ]

    expected_bboxes = clamp_bounding_box(
        datapoints.BoundingBox(expected_bboxes, format="XYXY", spatial_size=spatial_size)
    ).tolist()

    output_boxes = F.affine_bounding_box(
        in_boxes,
        format=format,
        spatial_size=spatial_size,
        angle=angle,
        translate=(dx * spatial_size[1], dy * spatial_size[0]),
        scale=scale,
        shear=(0, 0),
    )

    torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_affine_segmentation_mask_on_fixed_input(device):
    # Check transformation against known expected output and CPU/CUDA devices

    # Create a fixed input segmentation mask with 2 square masks
    # in top-left, bottom-left corners
    mask = torch.zeros(1, 32, 32, dtype=torch.long, device=device)
    mask[0, 2:10, 2:10] = 1
    mask[0, 32 - 9 : 32 - 3, 3:9] = 2

    # Rotate 90 degrees and scale
    expected_mask = torch.rot90(mask, k=-1, dims=(-2, -1))
    expected_mask = torch.nn.functional.interpolate(expected_mask[None, :].float(), size=(64, 64), mode="nearest")
    expected_mask = expected_mask[0, :, 16 : 64 - 16, 16 : 64 - 16].long()

    out_mask = F.affine_mask(mask, 90, [0.0, 0.0], 64.0 / 32.0, [0.0, 0.0])

    torch.testing.assert_close(out_mask, expected_mask)


@pytest.mark.parametrize("angle", range(-90, 90, 56))
@pytest.mark.parametrize("expand, center", [(True, None), (False, None), (False, (12, 14))])
def test_correctness_rotate_bounding_box(angle, expand, center):
    def _compute_expected_bbox(bbox, angle_, expand_, center_):
        affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_)
        affine_matrix = affine_matrix[:2, :]

        height, width = bbox.spatial_size
        bbox_xyxy = convert_format_bounding_box(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
        points = np.array(
            [
                [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
                # image frame
                [0.0, 0.0, 1.0],
                [0.0, height, 1.0],
                [width, height, 1.0],
                [width, 0.0, 1.0],
            ]
        )
        transformed_points = np.matmul(points, affine_matrix.T)
        out_bbox = [
            float(np.min(transformed_points[:4, 0])),
            float(np.min(transformed_points[:4, 1])),
            float(np.max(transformed_points[:4, 0])),
            float(np.max(transformed_points[:4, 1])),
        ]
        if expand_:
            tr_x = np.min(transformed_points[4:, 0])
            tr_y = np.min(transformed_points[4:, 1])
            out_bbox[0] -= tr_x
            out_bbox[1] -= tr_y
            out_bbox[2] -= tr_x
            out_bbox[3] -= tr_y

            height = int(height - 2 * tr_y)
            width = int(width - 2 * tr_x)

        out_bbox = datapoints.BoundingBox(
            out_bbox,
            format=datapoints.BoundingBoxFormat.XYXY,
            spatial_size=(height, width),
            dtype=bbox.dtype,
            device=bbox.device,
        )
        out_bbox = clamp_bounding_box(convert_format_bounding_box(out_bbox, new_format=bbox.format))
        return out_bbox, (height, width)

    spatial_size = (32, 38)

    for bboxes in make_bounding_boxes(spatial_size=spatial_size, extra_dims=((4,),)):
        bboxes_format = bboxes.format
        bboxes_spatial_size = bboxes.spatial_size

        output_bboxes, output_spatial_size = F.rotate_bounding_box(
            bboxes.as_subclass(torch.Tensor),
            format=bboxes_format,
            spatial_size=bboxes_spatial_size,
            angle=angle,
            expand=expand,
            center=center,
        )

        center_ = center
        if center_ is None:
            center_ = [s * 0.5 for s in bboxes_spatial_size[::-1]]

        if bboxes.ndim < 2:
            bboxes = [bboxes]

        expected_bboxes = []
        for bbox in bboxes:
            bbox = datapoints.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
            expected_bbox, expected_spatial_size = _compute_expected_bbox(bbox, -angle, expand, center_)
            expected_bboxes.append(expected_bbox)
        if len(expected_bboxes) > 1:
            expected_bboxes = torch.stack(expected_bboxes)
        else:
            expected_bboxes = expected_bboxes[0]
        torch.testing.assert_close(output_bboxes, expected_bboxes, atol=1, rtol=0)
        torch.testing.assert_close(output_spatial_size, expected_spatial_size, atol=1, rtol=0)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("expand", [False])  # expand=True does not match D2
def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
    # Check transformation against known expected output
    format = datapoints.BoundingBoxFormat.XYXY
    spatial_size = (64, 64)
    # xyxy format
    in_boxes = [
        [1, 1, 5, 5],
        [1, spatial_size[0] - 6, 5, spatial_size[0] - 2],
        [spatial_size[1] - 6, spatial_size[0] - 6, spatial_size[1] - 2, spatial_size[0] - 2],
        [spatial_size[1] // 2 - 10, spatial_size[0] // 2 - 10, spatial_size[1] // 2 + 10, spatial_size[0] // 2 + 10],
    ]
    in_boxes = torch.tensor(in_boxes, dtype=torch.float64, device=device)
    # Tested parameters
    angle = 45
    center = None if expand else [12, 23]

    # # Expected bboxes computed using Detectron2:
    # from detectron2.data.transforms import RotationTransform, AugmentationList
    # from detectron2.data.transforms import AugInput
    # import cv2
    # inpt = AugInput(im1, boxes=np.array(in_boxes, dtype="float32"))
    # augs = AugmentationList([RotationTransform(*size, angle, expand=expand, center=center, interp=cv2.INTER_NEAREST), ])
    # out = augs(inpt)
    # print(inpt.boxes)
    if expand:
        expected_bboxes = [
            [1.65937957, 42.67157288, 7.31623382, 48.32842712],
            [41.96446609, 82.9766594, 47.62132034, 88.63351365],
            [82.26955262, 42.67157288, 87.92640687, 48.32842712],
            [31.35786438, 31.35786438, 59.64213562, 59.64213562],
        ]
    else:
        expected_bboxes = [
            [-11.33452378, 12.39339828, -5.67766953, 18.05025253],
            [28.97056275, 52.69848481, 34.627417, 58.35533906],
            [69.27564928, 12.39339828, 74.93250353, 18.05025253],
            [18.36396103, 1.07968978, 46.64823228, 29.36396103],
        ]
        expected_bboxes = clamp_bounding_box(
            datapoints.BoundingBox(expected_bboxes, format="XYXY", spatial_size=spatial_size)
        ).tolist()

    output_boxes, _ = F.rotate_bounding_box(
        in_boxes,
        format=format,
        spatial_size=spatial_size,
        angle=angle,
        expand=expand,
        center=center,
    )

    torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_rotate_segmentation_mask_on_fixed_input(device):
    # Check transformation against known expected output and CPU/CUDA devices

    # Create a fixed input segmentation mask with 2 square masks
    # in top-left, bottom-left corners
    mask = torch.zeros(1, 32, 32, dtype=torch.long, device=device)
    mask[0, 2:10, 2:10] = 1
    mask[0, 32 - 9 : 32 - 3, 3:9] = 2

    # Rotate 90 degrees
    expected_mask = torch.rot90(mask, k=1, dims=(-2, -1))
    out_mask = F.rotate_mask(mask, 90, expand=False)
    torch.testing.assert_close(out_mask, expected_mask)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "format",
    [datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.CXCYWH],
)
@pytest.mark.parametrize(
    "top, left, height, width, expected_bboxes",
    [
        [8, 12, 30, 40, [(-2.0, 7.0, 13.0, 27.0), (38.0, -3.0, 58.0, 14.0), (33.0, 38.0, 44.0, 54.0)]],
        [-8, 12, 70, 40, [(-2.0, 23.0, 13.0, 43.0), (38.0, 13.0, 58.0, 30.0), (33.0, 54.0, 44.0, 70.0)]],
    ],
)
def test_correctness_crop_bounding_box(device, format, top, left, height, width, expected_bboxes):

    # Expected bboxes computed using Albumentations:
    # import numpy as np
    # from albumentations.augmentations.crops.functional import crop_bbox_by_coords, normalize_bbox, denormalize_bbox
    # expected_bboxes = []
    # for in_box in in_boxes:
    #     n_in_box = normalize_bbox(in_box, *size)
    #     n_out_box = crop_bbox_by_coords(
    #         n_in_box, (left, top, left + width, top + height), height, width, *size
    #     )
    #     out_box = denormalize_bbox(n_out_box, height, width)
    #     expected_bboxes.append(out_box)

    format = datapoints.BoundingBoxFormat.XYXY
    spatial_size = (64, 76)
    in_boxes = [
        [10.0, 15.0, 25.0, 35.0],
        [50.0, 5.0, 70.0, 22.0],
        [45.0, 46.0, 56.0, 62.0],
    ]
    in_boxes = torch.tensor(in_boxes, device=device)
    if format != datapoints.BoundingBoxFormat.XYXY:
        in_boxes = convert_format_bounding_box(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)

    expected_bboxes = clamp_bounding_box(
        datapoints.BoundingBox(expected_bboxes, format="XYXY", spatial_size=spatial_size)
    ).tolist()

    output_boxes, output_spatial_size = F.crop_bounding_box(
        in_boxes,
        format,
        top,
        left,
        spatial_size[0],
        spatial_size[1],
    )

    if format != datapoints.BoundingBoxFormat.XYXY:
        output_boxes = convert_format_bounding_box(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)

    torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
    torch.testing.assert_close(output_spatial_size, spatial_size)


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_horizontal_flip_segmentation_mask_on_fixed_input(device):
    mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
    mask[:, :, 0] = 1

    out_mask = F.horizontal_flip_mask(mask)

    expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
    expected_mask[:, :, -1] = 1
    torch.testing.assert_close(out_mask, expected_mask)


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
    mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
    mask[:, 0, :] = 1

    out_mask = F.vertical_flip_mask(mask)

    expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
    expected_mask[:, -1, :] = 1
    torch.testing.assert_close(out_mask, expected_mask)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "format",
    [datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.CXCYWH],
)
@pytest.mark.parametrize(
    "top, left, height, width, size",
    [
        [0, 0, 30, 30, (60, 60)],
        [-5, 5, 35, 45, (32, 34)],
    ],
)
def test_correctness_resized_crop_bounding_box(device, format, top, left, height, width, size):
    def _compute_expected_bbox(bbox, top_, left_, height_, width_, size_):
        # bbox should be xyxy
        bbox[0] = (bbox[0] - left_) * size_[1] / width_
        bbox[1] = (bbox[1] - top_) * size_[0] / height_
        bbox[2] = (bbox[2] - left_) * size_[1] / width_
        bbox[3] = (bbox[3] - top_) * size_[0] / height_
        return bbox

    format = datapoints.BoundingBoxFormat.XYXY
    spatial_size = (100, 100)
    in_boxes = [
        [10.0, 10.0, 20.0, 20.0],
        [5.0, 10.0, 15.0, 20.0],
    ]
    expected_bboxes = []
    for in_box in in_boxes:
        expected_bboxes.append(_compute_expected_bbox(list(in_box), top, left, height, width, size))
    expected_bboxes = torch.tensor(expected_bboxes, device=device)

    in_boxes = datapoints.BoundingBox(
        in_boxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, device=device
    )
    if format != datapoints.BoundingBoxFormat.XYXY:
        in_boxes = convert_format_bounding_box(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)

    output_boxes, output_spatial_size = F.resized_crop_bounding_box(in_boxes, format, top, left, height, width, size)

    if format != datapoints.BoundingBoxFormat.XYXY:
        output_boxes = convert_format_bounding_box(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)

    torch.testing.assert_close(output_boxes, expected_bboxes)
    torch.testing.assert_close(output_spatial_size, size)


def _parse_padding(padding):
    if isinstance(padding, int):
        return [padding] * 4
    if isinstance(padding, list):
        if len(padding) == 1:
            return padding * 4
        if len(padding) == 2:
            return padding * 2  # [left, up, right, down]

    return padding
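
# Examples of the helper above: `1` -> `[1, 1, 1, 1]`; `[2, 3]` -> `[2, 3, 2, 3]` (i.e. left, top, right, bottom);
# a 4-element list is returned unchanged.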


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("padding", [[1], [1, 1], [1, 1, 2, 2]])
def test_correctness_pad_bounding_box(device, padding):
    def _compute_expected_bbox(bbox, padding_):
        pad_left, pad_up, _, _ = _parse_padding(padding_)

        dtype = bbox.dtype
        format = bbox.format
        bbox = (
            bbox.clone()
            if format == datapoints.BoundingBoxFormat.XYXY
            else convert_format_bounding_box(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
        )

        bbox[0::2] += pad_left
        bbox[1::2] += pad_up

        bbox = convert_format_bounding_box(bbox, new_format=format)
        if bbox.dtype != dtype:
            # Temporary cast to original dtype
            # e.g. float32 -> int
            bbox = bbox.to(dtype)
        return bbox

    def _compute_expected_spatial_size(bbox, padding_):
        pad_left, pad_up, pad_right, pad_down = _parse_padding(padding_)
        height, width = bbox.spatial_size
        return height + pad_up + pad_down, width + pad_left + pad_right

    for bboxes in make_bounding_boxes():
        bboxes = bboxes.to(device)
        bboxes_format = bboxes.format
        bboxes_spatial_size = bboxes.spatial_size

        output_boxes, output_spatial_size = F.pad_bounding_box(
            bboxes, format=bboxes_format, spatial_size=bboxes_spatial_size, padding=padding
        )

        torch.testing.assert_close(output_spatial_size, _compute_expected_spatial_size(bboxes, padding))

        if bboxes.ndim < 2 or bboxes.shape[0] == 0:
            bboxes = [bboxes]

        expected_bboxes = []
        for bbox in bboxes:
            bbox = datapoints.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
            expected_bboxes.append(_compute_expected_bbox(bbox, padding))

        if len(expected_bboxes) > 1:
            expected_bboxes = torch.stack(expected_bboxes)
        else:
            expected_bboxes = expected_bboxes[0]
        torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_pad_segmentation_mask_on_fixed_input(device):
    mask = torch.ones((1, 3, 3), dtype=torch.long, device=device)

    out_mask = F.pad_mask(mask, padding=[1, 1, 1, 1])

    expected_mask = torch.zeros((1, 5, 5), dtype=torch.long, device=device)
    expected_mask[:, 1:-1, 1:-1] = 1
    torch.testing.assert_close(out_mask, expected_mask)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "startpoints, endpoints",
    [
        [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
    ],
)
def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
    def _compute_expected_bbox(bbox, pcoeffs_):
        m1 = np.array(
            [
                [pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]],
                [pcoeffs_[3], pcoeffs_[4], pcoeffs_[5]],
            ]
        )
        m2 = np.array(
            [
                [pcoeffs_[6], pcoeffs_[7], 1.0],
                [pcoeffs_[6], pcoeffs_[7], 1.0],
            ]
        )
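        # The two matrices above implement the projective mapping
        #     x' = (c0 * x + c1 * y + c2) / (c6 * x + c7 * y + 1)
        #     y' = (c3 * x + c4 * y + c5) / (c6 * x + c7 * y + 1)
        # with `pcoeffs_ = [c0, ..., c7]`.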

        bbox_xyxy = convert_format_bounding_box(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
        points = np.array(
            [
                [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
            ]
        )
        numer = np.matmul(points, m1.T)
        denom = np.matmul(points, m2.T)
        transformed_points = numer / denom
        out_bbox = np.array(
            [
                np.min(transformed_points[:, 0]),
                np.min(transformed_points[:, 1]),
                np.max(transformed_points[:, 0]),
                np.max(transformed_points[:, 1]),
            ]
        )
        out_bbox = datapoints.BoundingBox(
            out_bbox,
            format=datapoints.BoundingBoxFormat.XYXY,
            spatial_size=bbox.spatial_size,
            dtype=bbox.dtype,
            device=bbox.device,
        )
        return clamp_bounding_box(convert_format_bounding_box(out_bbox, new_format=bbox.format))

    spatial_size = (32, 38)

    pcoeffs = _get_perspective_coeffs(startpoints, endpoints)
    inv_pcoeffs = _get_perspective_coeffs(endpoints, startpoints)

    for bboxes in make_bounding_boxes(spatial_size=spatial_size, extra_dims=((4,),)):
        bboxes = bboxes.to(device)

        output_bboxes = F.perspective_bounding_box(
            bboxes.as_subclass(torch.Tensor),
            format=bboxes.format,
            spatial_size=bboxes.spatial_size,
            startpoints=None,
            endpoints=None,
            coefficients=pcoeffs,
        )

        if bboxes.ndim < 2:
            bboxes = [bboxes]

        expected_bboxes = []
        for bbox in bboxes:
            bbox = datapoints.BoundingBox(bbox, format=bboxes.format, spatial_size=bboxes.spatial_size)
            expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs))
        if len(expected_bboxes) > 1:
            expected_bboxes = torch.stack(expected_bboxes)
        else:
            expected_bboxes = expected_bboxes[0]
        torch.testing.assert_close(output_bboxes, expected_bboxes, rtol=0, atol=1)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "output_size",
    [(18, 18), [18, 15], (16, 19), [12], [46, 48]],
)
def test_correctness_center_crop_bounding_box(device, output_size):
    def _compute_expected_bbox(bbox, output_size_):
        format_ = bbox.format
        spatial_size_ = bbox.spatial_size
        dtype = bbox.dtype
        bbox = convert_format_bounding_box(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH)

        if len(output_size_) == 1:
            output_size_.append(output_size_[-1])

        cy = int(round((spatial_size_[0] - output_size_[0]) * 0.5))
        cx = int(round((spatial_size_[1] - output_size_[1]) * 0.5))
        out_bbox = [
            bbox[0].item() - cx,
            bbox[1].item() - cy,
            bbox[2].item(),
            bbox[3].item(),
        ]
        out_bbox = torch.tensor(out_bbox)
        out_bbox = convert_format_bounding_box(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_)
        out_bbox = clamp_bounding_box(out_bbox, format=format_, spatial_size=output_size)
        return out_bbox.to(dtype=dtype, device=bbox.device)

    for bboxes in make_bounding_boxes(extra_dims=((4,),)):
        bboxes = bboxes.to(device)
        bboxes_format = bboxes.format
        bboxes_spatial_size = bboxes.spatial_size

        output_boxes, output_spatial_size = F.center_crop_bounding_box(
            bboxes, bboxes_format, bboxes_spatial_size, output_size
        )

        if bboxes.ndim < 2:
            bboxes = [bboxes]

        expected_bboxes = []
        for bbox in bboxes:
            bbox = datapoints.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
            expected_bboxes.append(_compute_expected_bbox(bbox, output_size))

        if len(expected_bboxes) > 1:
            expected_bboxes = torch.stack(expected_bboxes)
        else:
            expected_bboxes = expected_bboxes[0]

        torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)
        torch.testing.assert_close(output_spatial_size, output_size)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("output_size", [[4, 2], [4], [7, 6]])
def test_correctness_center_crop_mask(device, output_size):
    def _compute_expected_mask(mask, output_size):
        crop_height, crop_width = output_size if len(output_size) > 1 else [output_size[0], output_size[0]]

        _, image_height, image_width = mask.shape
        if crop_height > image_height or crop_width > image_width:
            padding = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
            mask = F.pad_image_tensor(mask, padding, fill=0)

        left = round((image_width - crop_width) * 0.5)
        top = round((image_height - crop_height) * 0.5)

        return mask[:, top : top + crop_height, left : left + crop_width]

    mask = torch.randint(0, 2, size=(1, 6, 6), dtype=torch.long, device=device)
    actual = F.center_crop_mask(mask, output_size)

    expected = _compute_expected_mask(mask, output_size)
    torch.testing.assert_close(expected, actual)


# Copied from test/test_functional_tensor.py
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("spatial_size", ("small", "large"))
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)])
@pytest.mark.parametrize("sigma", [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)])
def test_correctness_gaussian_blur_image_tensor(device, spatial_size, dt, ksize, sigma):
    fn = F.gaussian_blur_image_tensor

    # true_cv2_results = {
    #     # np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
    #     # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.8)
    #     "3_3_0.8": ...
    #     # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.5)
    #     "3_3_0.5": ...
    #     # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.8)
    #     "3_5_0.8": ...
    #     # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.5)
    #     "3_5_0.5": ...
    #     # np_img2 = np.arange(26 * 28, dtype="uint8").reshape((26, 28))
    #     # cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7)
    #     "23_23_1.7": ...
    # }
    p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "gaussian_blur_opencv_results.pt")
    true_cv2_results = torch.load(p)

    if spatial_size == "small":
        tensor = (
            torch.from_numpy(np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))).permute(2, 0, 1).to(device)
        )
    else:
        tensor = torch.from_numpy(np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28))).to(device)

    if dt == torch.float16 and device == "cpu":
        # skip float16 on CPU case
        return

    if dt is not None:
        tensor = tensor.to(dtype=dt)

    _ksize = (ksize, ksize) if isinstance(ksize, int) else ksize
    _sigma = sigma[0] if sigma is not None else None
    shape = tensor.shape
    gt_key = f"{shape[-2]}_{shape[-1]}_{shape[-3]}__{_ksize[0]}_{_ksize[1]}_{_sigma}"
    if gt_key not in true_cv2_results:
        return

    true_out = (
        torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)
    )

    image = datapoints.Image(tensor)

    out = fn(image, kernel_size=ksize, sigma=sigma)
    torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")


@pytest.mark.parametrize(
    "inpt",
    [
        127 * np.ones((32, 32, 3), dtype="uint8"),
        PIL.Image.new("RGB", (32, 32), 122),
    ],
)
def test_to_image_tensor(inpt):
    output = F.to_image_tensor(inpt)
    assert isinstance(output, torch.Tensor)
    assert output.shape == (3, 32, 32)

    assert np.asarray(inpt).sum() == output.sum().item()


@pytest.mark.parametrize(
    "inpt",
    [
        torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8),
        127 * np.ones((32, 32, 3), dtype="uint8"),
    ],
)
@pytest.mark.parametrize("mode", [None, "RGB"])
def test_to_image_pil(inpt, mode):
    output = F.to_image_pil(inpt, mode=mode)
    assert isinstance(output, PIL.Image.Image)

    assert np.asarray(inpt).sum() == np.asarray(output).sum()


def test_equalize_image_tensor_edge_cases():
    inpt = torch.zeros(3, 200, 200, dtype=torch.uint8)
    output = F.equalize_image_tensor(inpt)
    torch.testing.assert_close(inpt, output)

    inpt = torch.zeros(5, 3, 200, 200, dtype=torch.uint8)
    inpt[..., 100:, 100:] = 1
    output = F.equalize_image_tensor(inpt)
    assert output.unique().tolist() == [0, 255]


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_uniform_temporal_subsample(device):
    video = torch.arange(10, device=device)[:, None, None, None].expand(-1, 3, 8, 8)
    out_video = F.uniform_temporal_subsample(video, 5)
    assert out_video.unique().tolist() == [0, 2, 4, 6, 9]

    out_video = F.uniform_temporal_subsample(video, 8)
    assert out_video.unique().tolist() == [0, 1, 2, 3, 5, 6, 7, 9]


# TODO: We can remove this test and related torchvision workaround
# once we fixed related pytorch issue: https://github.com/pytorch/pytorch/issues/68430
@make_info_args_kwargs_parametrization(
    [info for info in KERNEL_INFOS if info.kernel is F.resize_image_tensor],
    args_kwargs_fn=lambda info: info.reference_inputs_fn(),
)
def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwargs):
    (input, *other_args), kwargs = args_kwargs.load("cpu")

    output = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)

    error_msg_fn = parametrized_error_message(input, *other_args, **kwargs)
    assert input.ndim == 3, error_msg_fn
    input_stride = input.stride()
    output_stride = output.stride()
    # Here we check output memory format according to the input:
    # if input_stride is (..., 1) then input is most likely channels first and thus
    #   output strides should match channels first strides (H * W, W, 1)
    # if input_stride is (1, ...) then input is most likely channels last and thus
    #   output strides should match channels last strides (1, W * C, C)
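    # For example, a contiguous (3, 10, 12) input has strides (120, 12, 1) and is treated as channels first,
    # so the resized output is expected to have matching channels-first strides.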
    if input_stride[-1] == 1:
        expected_stride = (output.shape[-2] * output.shape[-1], output.shape[-1], 1)
        assert expected_stride == output_stride, error_msg_fn("")
    elif input_stride[0] == 1:
        expected_stride = (1, output.shape[0] * output.shape[-1], output.shape[0])
        assert expected_stride == output_stride, error_msg_fn("")
    else:
        assert False, error_msg_fn("")


def test_resize_float16_no_rounding():
    # Make sure Resize() doesn't round float16 images
    # Non-regression test for https://github.com/pytorch/vision/issues/7667

    img = torch.randint(0, 256, size=(1, 3, 100, 100), dtype=torch.float16)
    out = F.resize(img, size=(10, 10))
    assert out.dtype == torch.float16
    assert (out.round() - out).sum() > 0