import inspect
import re

import numpy as np
import PIL.Image
import pytest
import torch

from common_utils import assert_close, cache, cpu_and_cuda, needs_cuda, set_rng_seed
from torch.utils._pytree import tree_map
from torchvision import tv_tensors
from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2._utils import is_pure_tensor
from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_bounding_box_format
from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
from transforms_v2_kernel_infos import KERNEL_INFOS
from transforms_v2_legacy_utils import (
    DEFAULT_SQUARE_SPATIAL_SIZE,
    make_multiple_bounding_boxes,
    parametrized_error_message,
)


KERNEL_INFOS_MAP = {info.kernel: info for info in KERNEL_INFOS}
DISPATCHER_INFOS_MAP = {info.dispatcher: info for info in DISPATCHER_INFOS}


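# Scripted kernels are cached so that each kernel is only compiled once per session. A scripting failure is re-raised
# as an `AssertionError`, so it surfaces as a regular test failure.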
@cache
def script(fn):
    try:
        return torch.jit.script(fn)
    except Exception as error:
        raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error


# Scripting a function often triggers a warning like
# `UserWarning: operator() profile_node %$INT1 : int[] = prim::profile_ivalue($INT2) does not have profile information`
# with varying `INT1` and `INT2`. Since these are uninteresting for us and only clutter the test summary, we ignore
# them.
ignore_jit_warning_no_profile = pytest.mark.filterwarnings(
    f"ignore:{re.escape('operator() profile_node %')}:UserWarning"
)


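# Expands the `ArgsKwargs` of an info object into individual `pytest.param`'s with stable, zero-padded ids, e.g.
# `<info.id>-03`, and attaches the per-configuration marks (skips, xfails) defined on the info.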
def make_info_args_kwargs_params(info, *, args_kwargs_fn, test_id=None):
    args_kwargs = list(args_kwargs_fn(info))
    if not args_kwargs:
        raise pytest.UsageError(
            f"Couldn't collect a single `ArgsKwargs` for `{info.id}`{f' in {test_id}' if test_id else ''}"
        )
    idx_field_len = len(str(len(args_kwargs)))
    return [
        pytest.param(
            info,
            args_kwargs_,
            marks=info.get_marks(test_id, args_kwargs_) if test_id else [],
            id=f"{info.id}-{idx:0{idx_field_len}}",
        )
        for idx, args_kwargs_ in enumerate(args_kwargs)
    ]


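# Decorator factory: builds a single `@pytest.mark.parametrize` over `(info, args_kwargs)` pairs for all given infos.
# The test id is parsed from the decorated function's qualified name so that info-level marks can be resolved per test.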
def make_info_args_kwargs_parametrization(infos, *, args_kwargs_fn):
    def decorator(test_fn):
        parts = test_fn.__qualname__.split(".")
        if len(parts) == 1:
            test_class_name = None
            test_function_name = parts[0]
        elif len(parts) == 2:
            test_class_name, test_function_name = parts
        else:
            raise pytest.UsageError("Unable to parse the test class name and test function name from test function")
        test_id = (test_class_name, test_function_name)

        argnames = ("info", "args_kwargs")
        argvalues = []
        for info in infos:
            argvalues.extend(make_info_args_kwargs_params(info, args_kwargs_fn=args_kwargs_fn, test_id=test_id))

        return pytest.mark.parametrize(argnames, argvalues)(test_fn)

    return decorator


@pytest.fixture(autouse=True)
def fix_rng_seed():
    set_rng_seed(0)
    yield


@pytest.fixture()
def test_id(request):
    test_class_name = request.cls.__name__ if request.cls is not None else None
    test_function_name = request.node.originalname
    return test_class_name, test_function_name


class TestKernels:
    sample_inputs = make_info_args_kwargs_parametrization(
        KERNEL_INFOS,
        args_kwargs_fn=lambda kernel_info: kernel_info.sample_inputs_fn(),
    )
    reference_inputs = make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.reference_fn is not None],
        args_kwargs_fn=lambda info: info.reference_inputs_fn(),
    )

    @make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.logs_usage],
        args_kwargs_fn=lambda info: info.sample_inputs_fn(),
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_logging(self, spy_on, info, args_kwargs, device):
        spy = spy_on(torch._C._log_api_usage_once)

        (input, *other_args), kwargs = args_kwargs.load(device)
        info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)

        spy.assert_any_call(f"{info.kernel.__module__}.{info.id}")

    @ignore_jit_warning_no_profile
    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_scripted_vs_eager(self, test_id, info, args_kwargs, device):
        kernel_eager = info.kernel
        kernel_scripted = script(kernel_eager)

        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        actual = kernel_scripted(input, *other_args, **kwargs)
        expected = kernel_eager(input, *other_args, **kwargs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )

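    # Recursively splits a batched output into a (nested) list of single samples by unbinding the leading batch
    # dimensions, carrying any non-tensor metadata (e.g. the canvas size returned by bounding box kernels) along
    # unchanged.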
    def _unbatch(self, batch, *, data_dims):
        if isinstance(batch, torch.Tensor):
            batched_tensor = batch
            metadata = ()
        else:
            batched_tensor, *metadata = batch

        if batched_tensor.ndim == data_dims:
            return batch

        return [
            self._unbatch(unbatched, data_dims=data_dims)
            for unbatched in (
                batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
            )
        ]

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_batched_vs_single(self, test_id, info, args_kwargs, device):
        (batched_input, *other_args), kwargs = args_kwargs.load(device)

        tv_tensor_type = tv_tensors.Image if is_pure_tensor(batched_input) else type(batched_input)
        # This dictionary contains the number of rightmost dimensions that contain the actual data.
        # Everything to the left is considered a batch dimension.
        data_dims = {
            tv_tensors.Image: 3,
            tv_tensors.BoundingBoxes: 1,
            # `Mask`'s are special in the sense that the data dimensions depend on the type of mask. For detection
            # masks it is 3 `(*, N, H, W)`, but for segmentation masks it is 2 `(*, H, W)`. Since both are grouped
            # under one type, all kernels should also work without differentiating between the two. Thus, we go with 2
            # here as the common ground.
            tv_tensors.Mask: 2,
            tv_tensors.Video: 4,
        }.get(tv_tensor_type)
        if data_dims is None:
            raise pytest.UsageError(
                f"The number of data dimensions cannot be determined for input of type {tv_tensor_type.__name__}."
            ) from None
        elif batched_input.ndim <= data_dims:
            pytest.skip("Input is not batched.")
        elif not all(batched_input.shape[:-data_dims]):
            pytest.skip("Input has a degenerate batch shape.")

        batched_input = batched_input.as_subclass(torch.Tensor)
        batched_output = info.kernel(batched_input, *other_args, **kwargs)
        actual = self._unbatch(batched_output, data_dims=data_dims)

        single_inputs = self._unbatch(batched_input, data_dims=data_dims)
        expected = tree_map(lambda single_input: info.kernel(single_input, *other_args, **kwargs), single_inputs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=batched_input.dtype, device=batched_input.device),
            msg=parametrized_error_message(batched_input, *other_args, **kwargs),
        )

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_no_inplace(self, info, args_kwargs, device):
        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        if input.numel() == 0:
            pytest.skip("The input has a degenerate shape.")

        input_version = input._version
        info.kernel(input, *other_args, **kwargs)

        assert input._version == input_version

    @sample_inputs
    @needs_cuda
    def test_cuda_vs_cpu(self, test_id, info, args_kwargs):
        (input_cpu, *other_args), kwargs = args_kwargs.load("cpu")
        input_cpu = input_cpu.as_subclass(torch.Tensor)
        input_cuda = input_cpu.to("cuda")

        output_cpu = info.kernel(input_cpu, *other_args, **kwargs)
        output_cuda = info.kernel(input_cuda, *other_args, **kwargs)

        assert_close(
            output_cuda,
            output_cpu,
            check_device=False,
            **info.get_closeness_kwargs(test_id, dtype=input_cuda.dtype, device=input_cuda.device),
            msg=parametrized_error_message(input_cpu, *other_args, **kwargs),
        )

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_dtype_and_device_consistency(self, info, args_kwargs, device):
        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        output = info.kernel(input, *other_args, **kwargs)
        # Most kernels just return a tensor, but some also return some additional metadata
        if not isinstance(output, torch.Tensor):
            output, *_ = output

        assert output.dtype == input.dtype
        assert output.device == input.device

    @reference_inputs
    def test_against_reference(self, test_id, info, args_kwargs):
        (input, *other_args), kwargs = args_kwargs.load("cpu")

        actual = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)
        # We intentionally don't unwrap the input of the reference function in order for it to have access to all
        # metadata regardless of whether the kernel takes it explicitly or not
        expected = info.reference_fn(input, *other_args, **kwargs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )

    @make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.float32_vs_uint8],
        args_kwargs_fn=lambda info: info.reference_inputs_fn(),
    )
    def test_float32_vs_uint8(self, test_id, info, args_kwargs):
        (input, *other_args), kwargs = args_kwargs.load("cpu")
        input = input.as_subclass(torch.Tensor)

        if input.dtype != torch.uint8:
            pytest.skip(f"Input dtype is {input.dtype}.")

        adapted_other_args, adapted_kwargs = info.float32_vs_uint8(other_args, kwargs)

        actual = info.kernel(
            F.to_dtype_image(input, dtype=torch.float32, scale=True),
            *adapted_other_args,
            **adapted_kwargs,
        )

        expected = F.to_dtype_image(info.kernel(input, *other_args, **kwargs), dtype=torch.float32, scale=True)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=torch.float32, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )


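# Thin wrapper around `mocker.patch` from pytest-mock that keeps the original behavior (`wraps=fn`) while recording
# calls, so tests can assert on the calls without changing semantics.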
@pytest.fixture
def spy_on(mocker):
    def make_spy(fn, *, module=None, name=None):
        # TODO: we can probably get rid of the non-default modules and names if we eliminate aliasing
        module = module or fn.__module__
        name = name or fn.__name__
        spy = mocker.patch(f"{module}.{name}", wraps=fn)
        return spy

    return make_spy


class TestDispatchers:
    image_sample_inputs = make_info_args_kwargs_parametrization(
        [info for info in DISPATCHER_INFOS if tv_tensors.Image in info.kernels],
        args_kwargs_fn=lambda info: info.sample_inputs(tv_tensors.Image),
    )

    @make_info_args_kwargs_parametrization(
        DISPATCHER_INFOS,
        args_kwargs_fn=lambda info: info.sample_inputs(),
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_logging(self, spy_on, info, args_kwargs, device):
        spy = spy_on(torch._C._log_api_usage_once)

        args, kwargs = args_kwargs.load(device)
        info.dispatcher(*args, **kwargs)

        spy.assert_any_call(f"{info.dispatcher.__module__}.{info.id}")

    @ignore_jit_warning_no_profile
    @image_sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_scripted_smoke(self, info, args_kwargs, device):
        dispatcher = script(info.dispatcher)

        (image_tv_tensor, *other_args), kwargs = args_kwargs.load(device)
        image_pure_tensor = torch.Tensor(image_tv_tensor)

        dispatcher(image_pure_tensor, *other_args, **kwargs)

    # TODO: We need this until the dispatchers below also have `DispatcherInfo`'s. If they do, `test_scripted_smoke`
    #  replaces this test for them.
    @ignore_jit_warning_no_profile
    @pytest.mark.parametrize(
        "dispatcher",
        [
            F.get_dimensions,
            F.get_image_num_channels,
            F.get_image_size,
            F.get_num_channels,
            F.get_num_frames,
            F.get_size,
            F.rgb_to_grayscale,
            F.uniform_temporal_subsample,
        ],
        ids=lambda dispatcher: dispatcher.__name__,
    )
    def test_scriptable(self, dispatcher):
        script(dispatcher)

    @image_sample_inputs
    def test_pure_tensor_output_type(self, info, args_kwargs):
        (image_tv_tensor, *other_args), kwargs = args_kwargs.load()
        image_pure_tensor = image_tv_tensor.as_subclass(torch.Tensor)

        output = info.dispatcher(image_pure_tensor, *other_args, **kwargs)

        # We cannot use `isinstance` here since all tv_tensors are instances of `torch.Tensor` as well
        assert type(output) is torch.Tensor

    @make_info_args_kwargs_parametrization(
        [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
        args_kwargs_fn=lambda info: info.sample_inputs(tv_tensors.Image),
    )
    def test_pil_output_type(self, info, args_kwargs):
        (image_tv_tensor, *other_args), kwargs = args_kwargs.load()

        if image_tv_tensor.ndim > 3:
            pytest.skip("Input is batched")

        image_pil = F.to_pil_image(image_tv_tensor)

        output = info.dispatcher(image_pil, *other_args, **kwargs)

        assert isinstance(output, PIL.Image.Image)

    @make_info_args_kwargs_parametrization(
        DISPATCHER_INFOS,
        args_kwargs_fn=lambda info: info.sample_inputs(),
    )
    def test_tv_tensor_output_type(self, info, args_kwargs):
        (tv_tensor, *other_args), kwargs = args_kwargs.load()

        output = info.dispatcher(tv_tensor, *other_args, **kwargs)

        assert isinstance(output, type(tv_tensor))

        if isinstance(tv_tensor, tv_tensors.BoundingBoxes) and info.dispatcher is not F.convert_bounding_box_format:
            assert output.format == tv_tensor.format

    @pytest.mark.parametrize(
        ("dispatcher_info", "tv_tensor_type", "kernel_info"),
        [
            pytest.param(
                dispatcher_info, tv_tensor_type, kernel_info, id=f"{dispatcher_info.id}-{tv_tensor_type.__name__}"
            )
            for dispatcher_info in DISPATCHER_INFOS
            for tv_tensor_type, kernel_info in dispatcher_info.kernel_infos.items()
        ],
    )
    def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, tv_tensor_type, kernel_info):
        dispatcher_signature = inspect.signature(dispatcher_info.dispatcher)
        dispatcher_params = list(dispatcher_signature.parameters.values())[1:]

        kernel_signature = inspect.signature(kernel_info.kernel)
        kernel_params = list(kernel_signature.parameters.values())[1:]

        # We filter out metadata that is implicitly passed to the dispatcher through the input tv_tensor, but has to be
        # explicitly passed to the kernel.
        input_type = {v: k for k, v in dispatcher_info.kernels.items()}.get(kernel_info.kernel)
        explicit_metadata = {
            tv_tensors.BoundingBoxes: {"format", "canvas_size"},
        }
        kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]

        dispatcher_params = iter(dispatcher_params)
        for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params):
            try:
                # In general, the dispatcher parameters are a superset of the kernel parameters. Thus, we filter out
                # dispatcher parameters that have no kernel equivalent while keeping the order intact.
                while dispatcher_param.name != kernel_param.name:
                    dispatcher_param = next(dispatcher_params)
            except StopIteration:
                raise AssertionError(
                    f"Parameter `{kernel_param.name}` of kernel `{kernel_info.id}` "
                    f"has no corresponding parameter on the dispatcher `{dispatcher_info.id}`."
                ) from None

            assert dispatcher_param == kernel_param

    @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
    def test_unknown_type(self, info):
        unknown_input = object()
        (_, *other_args), kwargs = next(iter(info.sample_inputs())).load("cpu")

        with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
            info.dispatcher(unknown_input, *other_args, **kwargs)

    @make_info_args_kwargs_parametrization(
        [
            info
            for info in DISPATCHER_INFOS
            if tv_tensors.BoundingBoxes in info.kernels and info.dispatcher is not F.convert_bounding_box_format
        ],
        args_kwargs_fn=lambda info: info.sample_inputs(tv_tensors.BoundingBoxes),
    )
    def test_bounding_boxes_format_consistency(self, info, args_kwargs):
        (bounding_boxes, *other_args), kwargs = args_kwargs.load()
        format = bounding_boxes.format

        output = info.dispatcher(bounding_boxes, *other_args, **kwargs)

        assert output.format == format


@pytest.mark.parametrize(
    ("alias", "target"),
    [
        pytest.param(alias, target, id=alias.__name__)
        for alias, target in [
            (F.hflip, F.horizontal_flip),
            (F.vflip, F.vertical_flip),
            (F.get_image_num_channels, F.get_num_channels),
            (F.to_pil_image, F.to_pil_image),
            (F.elastic_transform, F.elastic),
            (F.to_grayscale, F.rgb_to_grayscale),
        ]
    ],
)
def test_alias(alias, target):
    assert alias is target


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("num_channels", [1, 3])
def test_normalize_image_tensor_stats(device, num_channels):
    stats = pytest.importorskip("scipy.stats", reason="SciPy is not available")

    def assert_samples_from_standard_normal(t):
        # Kolmogorov-Smirnov test against the standard normal CDF; only reject on very strong evidence.
        p_value = stats.kstest(t.flatten(), cdf="norm", args=(0, 1)).pvalue
        assert p_value > 1e-4

    image = torch.rand(num_channels, DEFAULT_SQUARE_SPATIAL_SIZE, DEFAULT_SQUARE_SPATIAL_SIZE)
    mean = image.mean(dim=(1, 2)).tolist()
    std = image.std(dim=(1, 2)).tolist()

    assert_samples_from_standard_normal(F.normalize_image(image, mean, std))


class TestClampBoundingBoxes:
    @pytest.mark.parametrize(
        "metadata",
        [
            dict(),
            dict(format=tv_tensors.BoundingBoxFormat.XYXY),
            dict(canvas_size=(1, 1)),
        ],
    )
    def test_pure_tensor_insufficient_metadata(self, metadata):
        pure_tensor = next(make_multiple_bounding_boxes()).as_subclass(torch.Tensor)

        with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` has to be passed")):
            F.clamp_bounding_boxes(pure_tensor, **metadata)

    @pytest.mark.parametrize(
        "metadata",
        [
            dict(format=tv_tensors.BoundingBoxFormat.XYXY),
            dict(canvas_size=(1, 1)),
            dict(format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=(1, 1)),
        ],
    )
    def test_tv_tensor_explicit_metadata(self, metadata):
        tv_tensor = next(make_multiple_bounding_boxes())

        with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` must not be passed")):
            F.clamp_bounding_boxes(tv_tensor, **metadata)


# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
#  `transforms_v2_kernel_infos.py`


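# Normalizes the `padding` argument to the canonical 4-element `[left, up, right, down]` form, mirroring how the
# padding kernels interpret 1- and 2-element lists.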
def _parse_padding(padding):
    if isinstance(padding, int):
        return [padding] * 4
    if isinstance(padding, list):
        if len(padding) == 1:
            return padding * 4
        if len(padding) == 2:
            return padding * 2  # [left, up, right, down]

    return padding


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("padding", [[1], [1, 1], [1, 1, 2, 2]])
def test_correctness_pad_bounding_boxes(device, padding):
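    # Padding only translates the boxes: the expected box is the input box shifted by the left/top padding, computed
    # on XYXY coordinates and converted back to the original format afterwards.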
    def _compute_expected_bbox(bbox, format, padding_):
        pad_left, pad_up, _, _ = _parse_padding(padding_)

        dtype = bbox.dtype
        bbox = (
            bbox.clone()
            if format == tv_tensors.BoundingBoxFormat.XYXY
            else convert_bounding_box_format(bbox, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY)
        )

        bbox[0::2] += pad_left
        bbox[1::2] += pad_up

        bbox = convert_bounding_box_format(bbox, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format)
        if bbox.dtype != dtype:
            # Temporary cast to original dtype
            # e.g. float32 -> int
            bbox = bbox.to(dtype)
        return bbox

    def _compute_expected_canvas_size(bbox, padding_):
        pad_left, pad_up, pad_right, pad_down = _parse_padding(padding_)
        height, width = bbox.canvas_size
        return height + pad_up + pad_down, width + pad_left + pad_right

    for bboxes in make_multiple_bounding_boxes(extra_dims=((4,),)):
        bboxes = bboxes.to(device)
        bboxes_format = bboxes.format
        bboxes_canvas_size = bboxes.canvas_size

        output_boxes, output_canvas_size = F.pad_bounding_boxes(
            bboxes, format=bboxes_format, canvas_size=bboxes_canvas_size, padding=padding
        )

        torch.testing.assert_close(output_canvas_size, _compute_expected_canvas_size(bboxes, padding))

        expected_bboxes = torch.stack(
            [_compute_expected_bbox(b, bboxes_format, padding) for b in bboxes.reshape(-1, 4).unbind()]
        ).reshape(bboxes.shape)

        torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_pad_segmentation_mask_on_fixed_input(device):
    mask = torch.ones((1, 3, 3), dtype=torch.long, device=device)

    out_mask = F.pad_mask(mask, padding=[1, 1, 1, 1])

    expected_mask = torch.zeros((1, 5, 5), dtype=torch.long, device=device)
    expected_mask[:, 1:-1, 1:-1] = 1
    torch.testing.assert_close(out_mask, expected_mask)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "startpoints, endpoints",
    [
        [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
    ],
)
def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
    def _compute_expected_bbox(bbox, format_, canvas_size_, pcoeffs_):
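        # A perspective transform maps (x, y) to
        #   ((c0*x + c1*y + c2) / (c6*x + c7*y + 1), (c3*x + c4*y + c5) / (c6*x + c7*y + 1)).
        # `m1` holds the two numerator rows, `m2` the (shared) denominator row, both applied to homogeneous points.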
        m1 = np.array(
            [
                [pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]],
                [pcoeffs_[3], pcoeffs_[4], pcoeffs_[5]],
            ]
        )
        m2 = np.array(
            [
                [pcoeffs_[6], pcoeffs_[7], 1.0],
                [pcoeffs_[6], pcoeffs_[7], 1.0],
            ]
        )

        bbox_xyxy = convert_bounding_box_format(bbox, old_format=format_, new_format=tv_tensors.BoundingBoxFormat.XYXY)
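        # The four corners of the box in homogeneous coordinates; the expected output box is the axis-aligned hull of
        # the transformed corners.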
        points = np.array(
            [
                [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
            ]
        )
        numer = np.matmul(points, m1.T)
        denom = np.matmul(points, m2.T)
        transformed_points = numer / denom
        out_bbox = np.array(
            [
                np.min(transformed_points[:, 0]),
                np.min(transformed_points[:, 1]),
                np.max(transformed_points[:, 0]),
                np.max(transformed_points[:, 1]),
            ]
        )
        out_bbox = torch.from_numpy(out_bbox)
        out_bbox = convert_bounding_box_format(
            out_bbox, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format_
        )
        return clamp_bounding_boxes(out_bbox, format=format_, canvas_size=canvas_size_).to(bbox)

    canvas_size = (32, 38)

    pcoeffs = _get_perspective_coeffs(startpoints, endpoints)
    inv_pcoeffs = _get_perspective_coeffs(endpoints, startpoints)

    for bboxes in make_multiple_bounding_boxes(spatial_size=canvas_size, extra_dims=((4,),)):
        bboxes = bboxes.to(device)

        output_bboxes = F.perspective_bounding_boxes(
            bboxes.as_subclass(torch.Tensor),
            format=bboxes.format,
            canvas_size=bboxes.canvas_size,
            startpoints=None,
            endpoints=None,
            coefficients=pcoeffs,
        )

        expected_bboxes = torch.stack(
            [
                _compute_expected_bbox(b, bboxes.format, bboxes.canvas_size, inv_pcoeffs)
                for b in bboxes.reshape(-1, 4).unbind()
            ]
        ).reshape(bboxes.shape)

        torch.testing.assert_close(output_bboxes, expected_bboxes, rtol=0, atol=1)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "output_size",
    [(18, 18), [18, 15], (16, 19), [12], [46, 48]],
)
def test_correctness_center_crop_bounding_boxes(device, output_size):
    def _compute_expected_bbox(bbox, format_, canvas_size_, output_size_):
        dtype = bbox.dtype
        bbox = convert_bounding_box_format(bbox.float(), format_, tv_tensors.BoundingBoxFormat.XYWH)

        if len(output_size_) == 1:
            output_size_.append(output_size_[-1])

        cy = int(round((canvas_size_[0] - output_size_[0]) * 0.5))
        cx = int(round((canvas_size_[1] - output_size_[1]) * 0.5))
        out_bbox = [
            bbox[0].item() - cx,
            bbox[1].item() - cy,
            bbox[2].item(),
            bbox[3].item(),
        ]
        out_bbox = torch.tensor(out_bbox)
        out_bbox = convert_bounding_box_format(out_bbox, tv_tensors.BoundingBoxFormat.XYWH, format_)
        out_bbox = clamp_bounding_boxes(out_bbox, format=format_, canvas_size=output_size_)
        return out_bbox.to(dtype=dtype, device=bbox.device)

    for bboxes in make_multiple_bounding_boxes(extra_dims=((4,),)):
        bboxes = bboxes.to(device)
        bboxes_format = bboxes.format
        bboxes_canvas_size = bboxes.canvas_size

        output_boxes, output_canvas_size = F.center_crop_bounding_boxes(
            bboxes, bboxes_format, bboxes_canvas_size, output_size
        )

        expected_bboxes = torch.stack(
            [
                _compute_expected_bbox(b, bboxes_format, bboxes_canvas_size, output_size)
                for b in bboxes.reshape(-1, 4).unbind()
            ]
        ).reshape(bboxes.shape)

        torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)
        torch.testing.assert_close(output_canvas_size, output_size)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("output_size", [[4, 2], [4], [7, 6]])
def test_correctness_center_crop_mask(device, output_size):
    def _compute_expected_mask(mask, output_size):
        crop_height, crop_width = output_size if len(output_size) > 1 else [output_size[0], output_size[0]]

        _, image_height, image_width = mask.shape
        # Pad first when the requested crop is larger than the image.
        if crop_height > image_height or crop_width > image_width:
            padding = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
            mask = F.pad_image(mask, padding, fill=0)

        left = round((image_width - crop_width) * 0.5)
        top = round((image_height - crop_height) * 0.5)

        return mask[:, top : top + crop_height, left : left + crop_width]

    mask = torch.randint(0, 2, size=(1, 6, 6), dtype=torch.long, device=device)
    actual = F.center_crop_mask(mask, output_size)

    expected = _compute_expected_mask(mask, output_size)
    torch.testing.assert_close(expected, actual)


@pytest.mark.parametrize(
    "inpt",
    [
        127 * np.ones((32, 32, 3), dtype="uint8"),
        PIL.Image.new("RGB", (32, 32), 122),
    ],
)
def test_to_image(inpt):
    output = F.to_image(inpt)
    assert isinstance(output, torch.Tensor)
    assert output.shape == (3, 32, 32)

    assert np.asarray(inpt).sum() == output.sum().item()


@pytest.mark.parametrize(
    "inpt",
    [
        torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8),
        127 * np.ones((32, 32, 3), dtype="uint8"),
    ],
)
@pytest.mark.parametrize("mode", [None, "RGB"])
def test_to_pil_image(inpt, mode):
    output = F.to_pil_image(inpt, mode=mode)
    assert isinstance(output, PIL.Image.Image)

    assert np.asarray(inpt).sum() == np.asarray(output).sum()


def test_equalize_image_tensor_edge_cases():
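    # Equalizing a constant image is a no-op, while a two-valued image should be stretched to the full uint8 range.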
    inpt = torch.zeros(3, 200, 200, dtype=torch.uint8)
    output = F.equalize_image(inpt)
    torch.testing.assert_close(inpt, output)

    inpt = torch.zeros(5, 3, 200, 200, dtype=torch.uint8)
    inpt[..., 100:, 100:] = 1
    output = F.equalize_image(inpt)
    assert output.unique().tolist() == [0, 255]


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_uniform_temporal_subsample(device):
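    # `uniform_temporal_subsample` picks evenly spaced frame indices, keeping the first and the last frame: for 10
    # frames it selects indices [0, 2, 4, 6, 9] with 5 samples and [0, 1, 2, 3, 5, 6, 7, 9] with 8 samples.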
    video = torch.arange(10, device=device)[:, None, None, None].expand(-1, 3, 8, 8)
    out_video = F.uniform_temporal_subsample(video, 5)
    assert out_video.unique().tolist() == [0, 2, 4, 6, 9]

    out_video = F.uniform_temporal_subsample(video, 8)
    assert out_video.unique().tolist() == [0, 1, 2, 3, 5, 6, 7, 9]