"torchvision/transforms/v2/_misc.py" did not exist on "54a2d4e8f7a4568823532d4342f6ba13e7339dce"
test_transforms_v2_functional.py 38.5 KB
Newer Older
import inspect
import math
import os
import re

import numpy as np
import PIL.Image
import pytest
import torch

from common_utils import assert_close, cache, cpu_and_cuda, needs_cuda, set_rng_seed
from torch.utils._pytree import tree_map
from torchvision import datapoints
from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2._utils import is_pure_tensor
from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_bounding_box_format

from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
from transforms_v2_kernel_infos import KERNEL_INFOS
from transforms_v2_legacy_utils import (
    DEFAULT_SQUARE_SPATIAL_SIZE,
    make_multiple_bounding_boxes,
    parametrized_error_message,
)

KERNEL_INFOS_MAP = {info.kernel: info for info in KERNEL_INFOS}
DISPATCHER_INFOS_MAP = {info.dispatcher: info for info in DISPATCHER_INFOS}


@cache
def script(fn):
    try:
        return torch.jit.script(fn)
    except Exception as error:
        raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error


# Scripting a function often triggers a warning like
# `UserWarning: operator() profile_node %$INT1 : int[] = prim::profile_ivalue($INT2) does not have profile information`
# with varying `INT1` and `INT2`. Since these are uninteresting for us and only clutter the test summary, we ignore
# them.
ignore_jit_warning_no_profile = pytest.mark.filterwarnings(
    f"ignore:{re.escape('operator() profile_node %')}:UserWarning"
)


def make_info_args_kwargs_params(info, *, args_kwargs_fn, test_id=None):
    args_kwargs = list(args_kwargs_fn(info))
    if not args_kwargs:
        raise pytest.UsageError(
            f"Couldn't collect a single `ArgsKwargs` for `{info.id}`{f' in {test_id}' if test_id else ''}"
        )
    idx_field_len = len(str(len(args_kwargs)))
    return [
        pytest.param(
            info,
            args_kwargs_,
            marks=info.get_marks(test_id, args_kwargs_) if test_id else [],
            id=f"{info.id}-{idx:0{idx_field_len}}",
        )
        for idx, args_kwargs_ in enumerate(args_kwargs)
    ]


def make_info_args_kwargs_parametrization(infos, *, args_kwargs_fn):
    def decorator(test_fn):
        parts = test_fn.__qualname__.split(".")
        if len(parts) == 1:
            test_class_name = None
            test_function_name = parts[0]
        elif len(parts) == 2:
            test_class_name, test_function_name = parts
        else:
            raise pytest.UsageError(
                "Unable to parse the test class and function names from the test function's `__qualname__`"
            )

        test_id = (test_class_name, test_function_name)

        argnames = ("info", "args_kwargs")
        argvalues = []
        for info in infos:
            argvalues.extend(make_info_args_kwargs_params(info, args_kwargs_fn=args_kwargs_fn, test_id=test_id))

        return pytest.mark.parametrize(argnames, argvalues)(test_fn)

    return decorator
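

# A typical use of the helper above, as `TestKernels.sample_inputs` below shows: each info's sample inputs are
# expanded into individual `pytest.param`s, so the decorated test receives one (info, args_kwargs) pair per
# sample input, e.g.
#
#   @make_info_args_kwargs_parametrization(
#       KERNEL_INFOS,
#       args_kwargs_fn=lambda kernel_info: kernel_info.sample_inputs_fn(),
#   )
#   def test_foo(info, args_kwargs): ...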


@pytest.fixture(autouse=True)
def fix_rng_seed():
    set_rng_seed(0)
    yield


@pytest.fixture()
def test_id(request):
    test_class_name = request.cls.__name__ if request.cls is not None else None
    test_function_name = request.node.originalname
    return test_class_name, test_function_name


class TestKernels:
    sample_inputs = make_info_args_kwargs_parametrization(
        KERNEL_INFOS,
        args_kwargs_fn=lambda kernel_info: kernel_info.sample_inputs_fn(),
    )
    reference_inputs = make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.reference_fn is not None],
        args_kwargs_fn=lambda info: info.reference_inputs_fn(),
    )

    @make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.logs_usage],
        args_kwargs_fn=lambda info: info.sample_inputs_fn(),
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_logging(self, spy_on, info, args_kwargs, device):
        spy = spy_on(torch._C._log_api_usage_once)

        (input, *other_args), kwargs = args_kwargs.load(device)
        info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)

        spy.assert_any_call(f"{info.kernel.__module__}.{info.id}")

    @ignore_jit_warning_no_profile
    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_scripted_vs_eager(self, test_id, info, args_kwargs, device):
        kernel_eager = info.kernel
        kernel_scripted = script(kernel_eager)

        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        actual = kernel_scripted(input, *other_args, **kwargs)
        expected = kernel_eager(input, *other_args, **kwargs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )

    def _unbatch(self, batch, *, data_dims):
        if isinstance(batch, torch.Tensor):
            batched_tensor = batch
            metadata = ()
        else:
            batched_tensor, *metadata = batch

        if batched_tensor.ndim == data_dims:
            return batch

        return [
            self._unbatch(unbatched, data_dims=data_dims)
            for unbatched in (
                batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
            )
        ]
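
    # With `data_dims=3`, for example, a batched image tensor of shape (2, 3, H, W)
    # unbatches into a list of two (3, H, W) tensors, while shape (2, 5, 3, H, W)
    # yields a nested 2-by-5 list of (3, H, W) tensors.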

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_batched_vs_single(self, test_id, info, args_kwargs, device):
        (batched_input, *other_args), kwargs = args_kwargs.load(device)

        datapoint_type = datapoints.Image if is_pure_tensor(batched_input) else type(batched_input)
        # This dictionary contains the number of rightmost dimensions that contain the actual data.
        # Everything to the left is considered a batch dimension.
        data_dims = {
            datapoints.Image: 3,
            datapoints.BoundingBoxes: 1,
            # `Mask`'s are special in the sense that the data dimensions depend on the type of mask. For detection
            # masks it is 3 `(*, N, H, W)`, but for segmentation masks it is 2 `(*, H, W)`. Since both are grouped
            # under one type, all kernels should also work without differentiating between the two. Thus, we go
            # with 2 here as common ground.
            datapoints.Mask: 2,
            datapoints.Video: 4,
        }.get(datapoint_type)
        if data_dims is None:
            raise pytest.UsageError(
                f"The number of data dimensions cannot be determined for input of type {datapoint_type.__name__}."
            ) from None
        elif batched_input.ndim <= data_dims:
            pytest.skip("Input is not batched.")
        elif not all(batched_input.shape[:-data_dims]):
            pytest.skip("Input has a degenerate batch shape.")

        batched_input = batched_input.as_subclass(torch.Tensor)
        batched_output = info.kernel(batched_input, *other_args, **kwargs)
        actual = self._unbatch(batched_output, data_dims=data_dims)

        single_inputs = self._unbatch(batched_input, data_dims=data_dims)
        expected = tree_map(lambda single_input: info.kernel(single_input, *other_args, **kwargs), single_inputs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=batched_input.dtype, device=batched_input.device),
            msg=parametrized_error_message(batched_input, *other_args, **kwargs),
        )

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_no_inplace(self, info, args_kwargs, device):
        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        if input.numel() == 0:
            pytest.skip("The input has a degenerate shape.")

        input_version = input._version
        info.kernel(input, *other_args, **kwargs)

        assert input._version == input_version

    @sample_inputs
    @needs_cuda
    def test_cuda_vs_cpu(self, test_id, info, args_kwargs):
        (input_cpu, *other_args), kwargs = args_kwargs.load("cpu")
        input_cpu = input_cpu.as_subclass(torch.Tensor)
        input_cuda = input_cpu.to("cuda")

        output_cpu = info.kernel(input_cpu, *other_args, **kwargs)
        output_cuda = info.kernel(input_cuda, *other_args, **kwargs)

        assert_close(
            output_cuda,
            output_cpu,
            check_device=False,
            **info.get_closeness_kwargs(test_id, dtype=input_cuda.dtype, device=input_cuda.device),
            msg=parametrized_error_message(input_cpu, *other_args, **kwargs),
        )

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_dtype_and_device_consistency(self, info, args_kwargs, device):
        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        output = info.kernel(input, *other_args, **kwargs)
        # Most kernels just return a tensor, but some also return some additional metadata
        if not isinstance(output, torch.Tensor):
            output, *_ = output

        assert output.dtype == input.dtype
        assert output.device == input.device

    @reference_inputs
    def test_against_reference(self, test_id, info, args_kwargs):
        (input, *other_args), kwargs = args_kwargs.load("cpu")

        actual = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)
        # We intentionally don't unwrap the input of the reference function in order for it to have access to all
        # metadata regardless of whether the kernel takes it explicitly or not
        expected = info.reference_fn(input, *other_args, **kwargs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )

    @make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.float32_vs_uint8],
        args_kwargs_fn=lambda info: info.reference_inputs_fn(),
    )
    def test_float32_vs_uint8(self, test_id, info, args_kwargs):
        (input, *other_args), kwargs = args_kwargs.load("cpu")
        input = input.as_subclass(torch.Tensor)

        if input.dtype != torch.uint8:
            pytest.skip(f"Input dtype is {input.dtype}.")

        adapted_other_args, adapted_kwargs = info.float32_vs_uint8(other_args, kwargs)

        actual = info.kernel(
            F.to_dtype_image(input, dtype=torch.float32, scale=True),
            *adapted_other_args,
            **adapted_kwargs,
        )

        expected = F.to_dtype_image(info.kernel(input, *other_args, **kwargs), dtype=torch.float32, scale=True)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=torch.float32, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )


@pytest.fixture
def spy_on(mocker):
    def make_spy(fn, *, module=None, name=None):
        # TODO: we can probably get rid of the non-default modules and names if we eliminate aliasing
        module = module or fn.__module__
        name = name or fn.__name__
        spy = mocker.patch(f"{module}.{name}", wraps=fn)
        return spy

    return make_spy


class TestDispatchers:
    image_sample_inputs = make_info_args_kwargs_parametrization(
        [info for info in DISPATCHER_INFOS if datapoints.Image in info.kernels],
        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
    )

    @make_info_args_kwargs_parametrization(
        DISPATCHER_INFOS,
        args_kwargs_fn=lambda info: info.sample_inputs(),
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_logging(self, spy_on, info, args_kwargs, device):
        spy = spy_on(torch._C._log_api_usage_once)

        args, kwargs = args_kwargs.load(device)
        info.dispatcher(*args, **kwargs)

        spy.assert_any_call(f"{info.dispatcher.__module__}.{info.id}")

    @ignore_jit_warning_no_profile
    @image_sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_scripted_smoke(self, info, args_kwargs, device):
        dispatcher = script(info.dispatcher)

        (image_datapoint, *other_args), kwargs = args_kwargs.load(device)
        image_pure_tensor = torch.Tensor(image_datapoint)

        dispatcher(image_pure_tensor, *other_args, **kwargs)

    # TODO: We need this until the dispatchers below also have `DispatcherInfo`s. Once they do,
    #  `test_scripted_smoke` replaces this test for them.
    @ignore_jit_warning_no_profile
    @pytest.mark.parametrize(
        "dispatcher",
        [
            F.get_dimensions,
            F.get_image_num_channels,
            F.get_image_size,
            F.get_num_channels,
            F.get_num_frames,
            F.get_size,
            F.rgb_to_grayscale,
            F.uniform_temporal_subsample,
        ],
        ids=lambda dispatcher: dispatcher.__name__,
    )
    def test_scriptable(self, dispatcher):
        script(dispatcher)

    @image_sample_inputs
    def test_pure_tensor_output_type(self, info, args_kwargs):
        (image_datapoint, *other_args), kwargs = args_kwargs.load()
        image_pure_tensor = image_datapoint.as_subclass(torch.Tensor)

        output = info.dispatcher(image_pure_tensor, *other_args, **kwargs)

        # We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well
        assert type(output) is torch.Tensor

    @make_info_args_kwargs_parametrization(
        [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
    )
    def test_pil_output_type(self, info, args_kwargs):
        (image_datapoint, *other_args), kwargs = args_kwargs.load()

        if image_datapoint.ndim > 3:
            pytest.skip("Input is batched")

        image_pil = F.to_pil_image(image_datapoint)

        output = info.dispatcher(image_pil, *other_args, **kwargs)

        assert isinstance(output, PIL.Image.Image)

    @make_info_args_kwargs_parametrization(
        DISPATCHER_INFOS,
        args_kwargs_fn=lambda info: info.sample_inputs(),
    )
    def test_datapoint_output_type(self, info, args_kwargs):
        (datapoint, *other_args), kwargs = args_kwargs.load()

        output = info.dispatcher(datapoint, *other_args, **kwargs)

        assert isinstance(output, type(datapoint))

        if isinstance(datapoint, datapoints.BoundingBoxes) and info.dispatcher is not F.convert_bounding_box_format:
            assert output.format == datapoint.format

    @pytest.mark.parametrize(
        ("dispatcher_info", "datapoint_type", "kernel_info"),
        [
            pytest.param(
                dispatcher_info, datapoint_type, kernel_info, id=f"{dispatcher_info.id}-{datapoint_type.__name__}"
            )
            for dispatcher_info in DISPATCHER_INFOS
            for datapoint_type, kernel_info in dispatcher_info.kernel_infos.items()
        ],
    )
    def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, datapoint_type, kernel_info):
        dispatcher_signature = inspect.signature(dispatcher_info.dispatcher)
        dispatcher_params = list(dispatcher_signature.parameters.values())[1:]

        kernel_signature = inspect.signature(kernel_info.kernel)
        kernel_params = list(kernel_signature.parameters.values())[1:]

        # We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be
        # explicitly passed to the kernel.
        input_type = {v: k for k, v in dispatcher_info.kernels.items()}.get(kernel_info.kernel)
        explicit_metadata = {
            datapoints.BoundingBoxes: {"format", "canvas_size"},
        }
        kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]

        dispatcher_params = iter(dispatcher_params)
        for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params):
            try:
                # In general, the dispatcher parameters are a superset of the kernel parameters. Thus, we filter out
                # dispatcher parameters that have no kernel equivalent while keeping the order intact.
                while dispatcher_param.name != kernel_param.name:
                    dispatcher_param = next(dispatcher_params)
            except StopIteration:
                raise AssertionError(
                    f"Parameter `{kernel_param.name}` of kernel `{kernel_info.id}` "
                    f"has no corresponding parameter on the dispatcher `{dispatcher_info.id}`."
                ) from None

            assert dispatcher_param == kernel_param

    @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
    def test_unkown_type(self, info):
        unkown_input = object()
        (_, *other_args), kwargs = next(iter(info.sample_inputs())).load("cpu")

        with pytest.raises(TypeError, match=re.escape(str(type(unkown_input)))):
            info.dispatcher(unkown_input, *other_args, **kwargs)

    @make_info_args_kwargs_parametrization(
        [
            info
            for info in DISPATCHER_INFOS
            if datapoints.BoundingBoxes in info.kernels and info.dispatcher is not F.convert_bounding_box_format
        ],
        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.BoundingBoxes),
    )
    def test_bounding_boxes_format_consistency(self, info, args_kwargs):
        (bounding_boxes, *other_args), kwargs = args_kwargs.load()
        format = bounding_boxes.format

        output = info.dispatcher(bounding_boxes, *other_args, **kwargs)

        assert output.format == format


@pytest.mark.parametrize(
    ("alias", "target"),
    [
        pytest.param(alias, target, id=alias.__name__)
        for alias, target in [
            (F.hflip, F.horizontal_flip),
            (F.vflip, F.vertical_flip),
            (F.get_image_num_channels, F.get_num_channels),
            (F.to_pil_image, F.to_pil_image),
            (F.elastic_transform, F.elastic),
            (F.to_grayscale, F.rgb_to_grayscale),
        ]
    ],
)
def test_alias(alias, target):
    assert alias is target


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("num_channels", [1, 3])
def test_normalize_image_tensor_stats(device, num_channels):
    stats = pytest.importorskip("scipy.stats", reason="SciPy is not available")
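
    # A Kolmogorov-Smirnov test against N(0, 1) with a deliberately lenient
    # p-value cutoff: normalizing with the image's own mean and std should
    # produce standardized samples.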

    def assert_samples_from_standard_normal(t):
        p_value = stats.kstest(t.flatten(), cdf="norm", args=(0, 1)).pvalue
        assert p_value > 1e-4

    image = torch.rand(num_channels, DEFAULT_SQUARE_SPATIAL_SIZE, DEFAULT_SQUARE_SPATIAL_SIZE)
    mean = image.mean(dim=(1, 2)).tolist()
    std = image.std(dim=(1, 2)).tolist()

    assert_samples_from_standard_normal(F.normalize_image(image, mean, std))


class TestClampBoundingBoxes:
    @pytest.mark.parametrize(
        "metadata",
        [
            dict(),
            dict(format=datapoints.BoundingBoxFormat.XYXY),
            dict(canvas_size=(1, 1)),
        ],
    )
    def test_pure_tensor_insufficient_metadata(self, metadata):
        pure_tensor = next(make_multiple_bounding_boxes()).as_subclass(torch.Tensor)

        with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` has to be passed")):
            F.clamp_bounding_boxes(pure_tensor, **metadata)

    @pytest.mark.parametrize(
        "metadata",
        [
            dict(format=datapoints.BoundingBoxFormat.XYXY),
            dict(canvas_size=(1, 1)),
            dict(format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(1, 1)),
        ],
    )
    def test_datapoint_explicit_metadata(self, metadata):
        datapoint = next(make_multiple_bounding_boxes())

        with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` must not be passed")):
            F.clamp_bounding_boxes(datapoint, **metadata)


class TestConvertFormatBoundingBoxes:
    @pytest.mark.parametrize(
        ("inpt", "old_format"),
        [
            (next(make_multiple_bounding_boxes()), None),
            (next(make_multiple_bounding_boxes()).as_subclass(torch.Tensor), datapoints.BoundingBoxFormat.XYXY),
        ],
    )
    def test_missing_new_format(self, inpt, old_format):
        with pytest.raises(TypeError, match=re.escape("missing 1 required argument: 'new_format'")):
            F.convert_bounding_box_format(inpt, old_format)

    def test_pure_tensor_insufficient_metadata(self):
        pure_tensor = next(make_multiple_bounding_boxes()).as_subclass(torch.Tensor)

        with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")):
            F.convert_bounding_box_format(pure_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)

    def test_datapoint_explicit_metadata(self):
        datapoint = next(make_multiple_bounding_boxes())

        with pytest.raises(ValueError, match=re.escape("`old_format` must not be passed")):
            F.convert_bounding_box_format(
                datapoint, old_format=datapoint.format, new_format=datapoints.BoundingBoxFormat.CXCYWH
            )


# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
#  `transforms_v2_kernel_infos.py`


def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_):
    rot = math.radians(angle_)
    cx, cy = center_
    tx, ty = translate_
    sx, sy = [math.radians(sh_) for sh_ in shear_]

    c_matrix = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
    t_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]])
    c_matrix_inv = np.linalg.inv(c_matrix)
    rs_matrix = np.array(
        [
            [scale_ * math.cos(rot), -scale_ * math.sin(rot), 0],
            [scale_ * math.sin(rot), scale_ * math.cos(rot), 0],
            [0, 0, 1],
        ]
    )
    shear_x_matrix = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]])
    shear_y_matrix = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]])
    rss_matrix = np.matmul(rs_matrix, np.matmul(shear_y_matrix, shear_x_matrix))
    true_matrix = np.matmul(t_matrix, np.matmul(c_matrix, np.matmul(rss_matrix, c_matrix_inv)))
    return true_matrix
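

# A minimal sanity check of the helper above: identity parameters (no rotation,
# translation, or shear, scale 1, center at the origin) must compose to the
# 3x3 identity matrix.
def test_compute_affine_matrix_identity():
    matrix = _compute_affine_matrix(0.0, (0.0, 0.0), 1.0, (0.0, 0.0), (0.0, 0.0))
    np.testing.assert_allclose(matrix, np.eye(3))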


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "format",
    [datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.CXCYWH],
)
@pytest.mark.parametrize(
    "top, left, height, width, expected_bboxes",
    [
        [8, 12, 30, 40, [(-2.0, 7.0, 13.0, 27.0), (38.0, -3.0, 58.0, 14.0), (33.0, 38.0, 44.0, 54.0)]],
        [-8, 12, 70, 40, [(-2.0, 23.0, 13.0, 43.0), (38.0, 13.0, 58.0, 30.0), (33.0, 54.0, 44.0, 70.0)]],
    ],
)
def test_correctness_crop_bounding_boxes(device, format, top, left, height, width, expected_bboxes):

    # Expected bboxes computed using Albumentations:
    # import numpy as np
    # from albumentations.augmentations.crops.functional import crop_bbox_by_coords, normalize_bbox, denormalize_bbox
    # expected_bboxes = []
    # for in_box in in_boxes:
    #     n_in_box = normalize_bbox(in_box, *size)
    #     n_out_box = crop_bbox_by_coords(
    #         n_in_box, (left, top, left + width, top + height), height, width, *size
    #     )
    #     out_box = denormalize_bbox(n_out_box, height, width)
    #     expected_bboxes.append(out_box)

    canvas_size = (64, 76)
    in_boxes = [
        [10.0, 15.0, 25.0, 35.0],
        [50.0, 5.0, 70.0, 22.0],
        [45.0, 46.0, 56.0, 62.0],
    ]
    in_boxes = torch.tensor(in_boxes, device=device)
    if format != datapoints.BoundingBoxFormat.XYXY:
        in_boxes = convert_bounding_box_format(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)

    expected_bboxes = clamp_bounding_boxes(
        datapoints.BoundingBoxes(expected_bboxes, format="XYXY", canvas_size=canvas_size)
    ).tolist()

    output_boxes, output_canvas_size = F.crop_bounding_boxes(
        in_boxes,
        format,
        top,
        left,
        canvas_size[0],
        canvas_size[1],
    )

    if format != datapoints.BoundingBoxFormat.XYXY:
        output_boxes = convert_bounding_box_format(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)

    torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
    torch.testing.assert_close(output_canvas_size, canvas_size)


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
    mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
    mask[:, 0, :] = 1

    out_mask = F.vertical_flip_mask(mask)

    expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
    expected_mask[:, -1, :] = 1
    torch.testing.assert_close(out_mask, expected_mask)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "format",
    [datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.CXCYWH],
)
@pytest.mark.parametrize(
    "top, left, height, width, size",
    [
        [0, 0, 30, 30, (60, 60)],
        [-5, 5, 35, 45, (32, 34)],
    ],
)
def test_correctness_resized_crop_bounding_boxes(device, format, top, left, height, width, size):
    def _compute_expected_bbox(bbox, top_, left_, height_, width_, size_):
        # bbox should be xyxy
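        # A resized crop is a crop to (top_, left_, height_, width_) followed by a
        # resize to `size_`: each coordinate shifts by the crop origin and scales
        # by the output/crop ratio per axis.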
        bbox[0] = (bbox[0] - left_) * size_[1] / width_
        bbox[1] = (bbox[1] - top_) * size_[0] / height_
        bbox[2] = (bbox[2] - left_) * size_[1] / width_
        bbox[3] = (bbox[3] - top_) * size_[0] / height_
        return bbox

    canvas_size = (100, 100)
    in_boxes = [
        [10.0, 10.0, 20.0, 20.0],
        [5.0, 10.0, 15.0, 20.0],
    ]
    expected_bboxes = []
    for in_box in in_boxes:
        expected_bboxes.append(_compute_expected_bbox(list(in_box), top, left, height, width, size))
    expected_bboxes = torch.tensor(expected_bboxes, device=device)

    in_boxes = datapoints.BoundingBoxes(
        in_boxes, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=canvas_size, device=device
    )
    if format != datapoints.BoundingBoxFormat.XYXY:
        in_boxes = convert_bounding_box_format(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)

    output_boxes, output_canvas_size = F.resized_crop_bounding_boxes(in_boxes, format, top, left, height, width, size)

    if format != datapoints.BoundingBoxFormat.XYXY:
        output_boxes = convert_bounding_box_format(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)

    torch.testing.assert_close(output_boxes, expected_bboxes)
    torch.testing.assert_close(output_canvas_size, size)


def _parse_padding(padding):
    if isinstance(padding, int):
        return [padding] * 4
    if isinstance(padding, list):
        if len(padding) == 1:
            return padding * 4
        if len(padding) == 2:
            return padding * 2  # [left, up, right, down]

    return padding
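

# Expansion rules: 1 -> [1, 1, 1, 1], [1] -> [1, 1, 1, 1], and
# [left, up] -> [left, up, left, up], i.e. a two-element padding is
# applied symmetrically to both axes.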


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("padding", [[1], [1, 1], [1, 1, 2, 2]])
def test_correctness_pad_bounding_boxes(device, padding):
    def _compute_expected_bbox(bbox, format, padding_):
        pad_left, pad_up, _, _ = _parse_padding(padding_)

        dtype = bbox.dtype
        bbox = (
            bbox.clone()
            if format == datapoints.BoundingBoxFormat.XYXY
            else convert_bounding_box_format(bbox, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
        )

        bbox[0::2] += pad_left
        bbox[1::2] += pad_up

        bbox = convert_bounding_box_format(bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format)
        if bbox.dtype != dtype:
            # Temporary cast to original dtype
            # e.g. float32 -> int
            bbox = bbox.to(dtype)
        return bbox

    def _compute_expected_canvas_size(bbox, padding_):
        pad_left, pad_up, pad_right, pad_down = _parse_padding(padding_)
        height, width = bbox.canvas_size
        return height + pad_up + pad_down, width + pad_left + pad_right

    for bboxes in make_multiple_bounding_boxes(extra_dims=((4,),)):
        bboxes = bboxes.to(device)
        bboxes_format = bboxes.format
        bboxes_canvas_size = bboxes.canvas_size

        output_boxes, output_canvas_size = F.pad_bounding_boxes(
            bboxes, format=bboxes_format, canvas_size=bboxes_canvas_size, padding=padding
        )

        torch.testing.assert_close(output_canvas_size, _compute_expected_canvas_size(bboxes, padding))

        expected_bboxes = torch.stack(
            [_compute_expected_bbox(b, bboxes_format, padding) for b in bboxes.reshape(-1, 4).unbind()]
        ).reshape(bboxes.shape)

        torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_pad_segmentation_mask_on_fixed_input(device):
    mask = torch.ones((1, 3, 3), dtype=torch.long, device=device)

    out_mask = F.pad_mask(mask, padding=[1, 1, 1, 1])

    expected_mask = torch.zeros((1, 5, 5), dtype=torch.long, device=device)
    expected_mask[:, 1:-1, 1:-1] = 1
    torch.testing.assert_close(out_mask, expected_mask)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "startpoints, endpoints",
    [
        [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
    ],
)
def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
    def _compute_expected_bbox(bbox, format_, canvas_size_, pcoeffs_):
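        # The eight coefficients (a, b, c, d, e, f, g, h) describe the projective
        # mapping
        #   x' = (a*x + b*y + c) / (g*x + h*y + 1)
        #   y' = (d*x + e*y + f) / (g*x + h*y + 1)
        # `m1` holds the two numerator rows, `m2` repeats the shared denominator
        # row, and `numer / denom` below applies the mapping to each corner.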
        m1 = np.array(
            [
                [pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]],
                [pcoeffs_[3], pcoeffs_[4], pcoeffs_[5]],
            ]
        )
        m2 = np.array(
            [
                [pcoeffs_[6], pcoeffs_[7], 1.0],
                [pcoeffs_[6], pcoeffs_[7], 1.0],
            ]
        )

        bbox_xyxy = convert_bounding_box_format(bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY)
        points = np.array(
            [
                [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
            ]
        )
        numer = np.matmul(points, m1.T)
        denom = np.matmul(points, m2.T)
        transformed_points = numer / denom
        out_bbox = np.array(
            [
                np.min(transformed_points[:, 0]),
                np.min(transformed_points[:, 1]),
                np.max(transformed_points[:, 0]),
                np.max(transformed_points[:, 1]),
            ]
        )
        out_bbox = torch.from_numpy(out_bbox)
        out_bbox = convert_bounding_box_format(
            out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_
        )
        return clamp_bounding_boxes(out_bbox, format=format_, canvas_size=canvas_size_).to(bbox)

    canvas_size = (32, 38)

    pcoeffs = _get_perspective_coeffs(startpoints, endpoints)
    inv_pcoeffs = _get_perspective_coeffs(endpoints, startpoints)

    for bboxes in make_multiple_bounding_boxes(spatial_size=canvas_size, extra_dims=((4,),)):
        bboxes = bboxes.to(device)

        output_bboxes = F.perspective_bounding_boxes(
            bboxes.as_subclass(torch.Tensor),
            format=bboxes.format,
            canvas_size=bboxes.canvas_size,
            startpoints=None,
            endpoints=None,
            coefficients=pcoeffs,
        )

        expected_bboxes = torch.stack(
            [
                _compute_expected_bbox(b, bboxes.format, bboxes.canvas_size, inv_pcoeffs)
                for b in bboxes.reshape(-1, 4).unbind()
            ]
        ).reshape(bboxes.shape)

        torch.testing.assert_close(output_bboxes, expected_bboxes, rtol=0, atol=1)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "output_size",
    [(18, 18), [18, 15], (16, 19), [12], [46, 48]],
)
def test_correctness_center_crop_bounding_boxes(device, output_size):
    def _compute_expected_bbox(bbox, format_, canvas_size_, output_size_):
        dtype = bbox.dtype
        bbox = convert_bounding_box_format(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH)

        if len(output_size_) == 1:
            output_size_.append(output_size_[-1])
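
        # The center crop removes a margin of (canvas - output) / 2 on each side,
        # so every box simply shifts by that per-axis offset.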

        cy = int(round((canvas_size_[0] - output_size_[0]) * 0.5))
        cx = int(round((canvas_size_[1] - output_size_[1]) * 0.5))
        out_bbox = [
            bbox[0].item() - cx,
            bbox[1].item() - cy,
            bbox[2].item(),
            bbox[3].item(),
        ]
        out_bbox = torch.tensor(out_bbox)
        out_bbox = convert_bounding_box_format(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_)
        out_bbox = clamp_bounding_boxes(out_bbox, format=format_, canvas_size=output_size_)
        return out_bbox.to(dtype=dtype, device=bbox.device)

    for bboxes in make_multiple_bounding_boxes(extra_dims=((4,),)):
        bboxes = bboxes.to(device)
        bboxes_format = bboxes.format
        bboxes_canvas_size = bboxes.canvas_size

        output_boxes, output_canvas_size = F.center_crop_bounding_boxes(
            bboxes, bboxes_format, bboxes_canvas_size, output_size
        )
        expected_bboxes = torch.stack(
            [
                _compute_expected_bbox(b, bboxes_format, bboxes_canvas_size, output_size)
                for b in bboxes.reshape(-1, 4).unbind()
            ]
        ).reshape(bboxes.shape)

        torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)
        torch.testing.assert_close(output_canvas_size, output_size)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("output_size", [[4, 2], [4], [7, 6]])
def test_correctness_center_crop_mask(device, output_size):
    def _compute_expected_mask(mask, output_size):
        crop_height, crop_width = output_size if len(output_size) > 1 else [output_size[0], output_size[0]]

        _, image_height, image_width = mask.shape
        if crop_width > image_width or crop_height > image_height:
            padding = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
            mask = F.pad_image(mask, padding, fill=0)

        left = round((image_width - crop_width) * 0.5)
        top = round((image_height - crop_height) * 0.5)

        return mask[:, top : top + crop_height, left : left + crop_width]

    mask = torch.randint(0, 2, size=(1, 6, 6), dtype=torch.long, device=device)
    actual = F.center_crop_mask(mask, output_size)

    expected = _compute_expected_mask(mask, output_size)
    torch.testing.assert_close(expected, actual)


# Copied from test/test_functional_tensor.py
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("canvas_size", ("small", "large"))
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)])
@pytest.mark.parametrize("sigma", [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)])
def test_correctness_gaussian_blur_image_tensor(device, canvas_size, dt, ksize, sigma):
    fn = F.gaussian_blur_image

    # true_cv2_results = {
    #     # np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
    #     # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.8)
    #     "3_3_0.8": ...
    #     # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.5)
    #     "3_3_0.5": ...
    #     # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.8)
    #     "3_5_0.8": ...
    #     # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.5)
    #     "3_5_0.5": ...
    #     # np_img2 = np.arange(26 * 28, dtype="uint8").reshape((26, 28))
    #     # cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7)
    #     "23_23_1.7": ...
    # }
    p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "gaussian_blur_opencv_results.pt")
    true_cv2_results = torch.load(p)

    if canvas_size == "small":
        tensor = (
            torch.from_numpy(np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))).permute(2, 0, 1).to(device)
        )
    else:
        tensor = torch.from_numpy(np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28))).to(device)

    if dt == torch.float16 and device == "cpu":
        # skip float16 on CPU case
        return

    if dt is not None:
        tensor = tensor.to(dtype=dt)

    _ksize = (ksize, ksize) if isinstance(ksize, int) else ksize
    _sigma = sigma[0] if sigma is not None else None
    shape = tensor.shape
    gt_key = f"{shape[-2]}_{shape[-1]}_{shape[-3]}__{_ksize[0]}_{_ksize[1]}_{_sigma}"
    if gt_key not in true_cv2_results:
        return

    true_out = (
        torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)
    )

    image = datapoints.Image(tensor)

    out = fn(image, kernel_size=ksize, sigma=sigma)
    torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")


@pytest.mark.parametrize(
    "inpt",
    [
        127 * np.ones((32, 32, 3), dtype="uint8"),
        PIL.Image.new("RGB", (32, 32), 122),
    ],
)
def test_to_image(inpt):
    output = F.to_image(inpt)
    assert isinstance(output, torch.Tensor)
    assert output.shape == (3, 32, 32)

    assert np.asarray(inpt).sum() == output.sum().item()


@pytest.mark.parametrize(
    "inpt",
    [
        torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8),
        127 * np.ones((32, 32, 3), dtype="uint8"),
    ],
)
@pytest.mark.parametrize("mode", [None, "RGB"])
def test_to_pil_image(inpt, mode):
    output = F.to_pil_image(inpt, mode=mode)
    assert isinstance(output, PIL.Image.Image)

    assert np.asarray(inpt).sum() == np.asarray(output).sum()


def test_equalize_image_tensor_edge_cases():
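    # Equalizing a constant image must leave it unchanged, and an image with
    # exactly two gray levels must be stretched to the full [0, 255] range.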
    inpt = torch.zeros(3, 200, 200, dtype=torch.uint8)
    output = F.equalize_image(inpt)
    torch.testing.assert_close(inpt, output)

    inpt = torch.zeros(5, 3, 200, 200, dtype=torch.uint8)
    inpt[..., 100:, 100:] = 1
    output = F.equalize_image(inpt)
    assert output.unique().tolist() == [0, 255]


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_uniform_temporal_subsample(device):
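    # `uniform_temporal_subsample(video, num_samples)` picks `num_samples` frame
    # indices evenly spaced over the temporal dimension (including the first and
    # last frame here), which yields the index sets asserted below.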
    video = torch.arange(10, device=device)[:, None, None, None].expand(-1, 3, 8, 8)
    out_video = F.uniform_temporal_subsample(video, 5)
    assert out_video.unique().tolist() == [0, 2, 4, 6, 9]

    out_video = F.uniform_temporal_subsample(video, 8)
    assert out_video.unique().tolist() == [0, 1, 2, 3, 5, 6, 7, 9]