import inspect
import re

import numpy as np
import PIL.Image
import pytest
import torch

from common_utils import assert_close, cache, cpu_and_cuda, needs_cuda, set_rng_seed
from torch.utils._pytree import tree_map
from torchvision import tv_tensors
from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2._utils import is_pure_tensor
from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_bounding_box_format

from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
from transforms_v2_kernel_infos import KERNEL_INFOS
from transforms_v2_legacy_utils import (
    DEFAULT_SQUARE_SPATIAL_SIZE,
    make_multiple_bounding_boxes,
    parametrized_error_message,
)


KERNEL_INFOS_MAP = {info.kernel: info for info in KERNEL_INFOS}
DISPATCHER_INFOS_MAP = {info.dispatcher: info for info in DISPATCHER_INFOS}


@cache
def script(fn):
    try:
        return torch.jit.script(fn)
    except Exception as error:
        raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error


# Scripting a function often triggers a warning like
# `UserWarning: operator() profile_node %$INT1 : int[] = prim::profile_ivalue($INT2) does not have profile information`
# with varying `INT1` and `INT2`. Since these are uninteresting for us and only clutter the test summary, we ignore
# them.
ignore_jit_warning_no_profile = pytest.mark.filterwarnings(
    f"ignore:{re.escape('operator() profile_node %')}:UserWarning"
)


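# Builds one `pytest.param` per `ArgsKwargs` yielded by `args_kwargs_fn`. The
# index in the generated id is zero-padded to the width of the largest index so
# that the ids sort lexicographically, e.g. `resize_image-007` (hypothetical id).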
def make_info_args_kwargs_params(info, *, args_kwargs_fn, test_id=None):
    args_kwargs = list(args_kwargs_fn(info))
    if not args_kwargs:
        raise pytest.UsageError(
            f"Couldn't collect a single `ArgsKwargs` for `{info.id}`{f' in {test_id}' if test_id else ''}"
        )

    idx_field_len = len(str(len(args_kwargs)))
    return [
        pytest.param(
            info,
            args_kwargs_,
            marks=info.get_marks(test_id, args_kwargs_) if test_id else [],
            id=f"{info.id}-{idx:0{idx_field_len}}",
        )
        for idx, args_kwargs_ in enumerate(args_kwargs)
    ]


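# Decorator factory that parametrizes a test over `(info, args_kwargs)` pairs.
# The test's `__qualname__` is parsed so that both free functions and methods on
# a test class resolve to the correct test id (and thereby the correct marks).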
def make_info_args_kwargs_parametrization(infos, *, args_kwargs_fn):
    def decorator(test_fn):
        parts = test_fn.__qualname__.split(".")
        if len(parts) == 1:
            test_class_name = None
            test_function_name = parts[0]
        elif len(parts) == 2:
            test_class_name, test_function_name = parts
        else:
            raise pytest.UsageError("Unable to parse the test class name and test function name from test function")
        test_id = (test_class_name, test_function_name)

        argnames = ("info", "args_kwargs")
        argvalues = []
        for info in infos:
            argvalues.extend(make_info_args_kwargs_params(info, args_kwargs_fn=args_kwargs_fn, test_id=test_id))

        return pytest.mark.parametrize(argnames, argvalues)(test_fn)

    return decorator


@pytest.fixture(autouse=True)
def fix_rng_seed():
    set_rng_seed(0)
    yield


@pytest.fixture()
def test_id(request):
    test_class_name = request.cls.__name__ if request.cls is not None else None
    test_function_name = request.node.originalname
    return test_class_name, test_function_name


class TestKernels:
    sample_inputs = make_info_args_kwargs_parametrization(
        KERNEL_INFOS,
        args_kwargs_fn=lambda kernel_info: kernel_info.sample_inputs_fn(),
    )
    reference_inputs = make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.reference_fn is not None],
        args_kwargs_fn=lambda info: info.reference_inputs_fn(),
    )

    @make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.logs_usage],
        args_kwargs_fn=lambda info: info.sample_inputs_fn(),
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_logging(self, spy_on, info, args_kwargs, device):
        spy = spy_on(torch._C._log_api_usage_once)

        (input, *other_args), kwargs = args_kwargs.load(device)
        info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)

        spy.assert_any_call(f"{info.kernel.__module__}.{info.id}")

    @ignore_jit_warning_no_profile
    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_scripted_vs_eager(self, test_id, info, args_kwargs, device):
        kernel_eager = info.kernel
        kernel_scripted = script(kernel_eager)

        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        actual = kernel_scripted(input, *other_args, **kwargs)
        expected = kernel_eager(input, *other_args, **kwargs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )

    def _unbatch(self, batch, *, data_dims):
        if isinstance(batch, torch.Tensor):
            batched_tensor = batch
            metadata = ()
        else:
            batched_tensor, *metadata = batch

        if batched_tensor.ndim == data_dims:
            return batch

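        # Peel off one batch dimension per recursion step. For example (a sketch,
        # not an exhaustive spec): a batched image tensor of shape (2, 5, 3, 16, 16)
        # with data_dims=3 becomes a nested list of shape [2][5] whose leaves are
        # (3, 16, 16) tensors.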
        return [
            self._unbatch(unbatched, data_dims=data_dims)
            for unbatched in (
                batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
            )
        ]

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_batched_vs_single(self, test_id, info, args_kwargs, device):
        (batched_input, *other_args), kwargs = args_kwargs.load(device)

        tv_tensor_type = tv_tensors.Image if is_pure_tensor(batched_input) else type(batched_input)
        # This dictionary contains the number of rightmost dimensions that contain the actual data.
        # Everything to the left is considered a batch dimension.
        data_dims = {
            tv_tensors.Image: 3,
            tv_tensors.BoundingBoxes: 1,
            # `Mask`'s are special in the sense that the data dimensions depend on the type of mask. For detection masks
            # it is 3 `(*, N, H, W)`, but for segmentation masks it is 2 `(*, H, W)`. Since both are grouped under one
            # type, all kernels should also work without differentiating between the two. Thus, we go with 2 here as
            # common ground.
            tv_tensors.Mask: 2,
            tv_tensors.Video: 4,
        }.get(tv_tensor_type)
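        # Skip rather than fail when the sample is not actually batched: `ndim <=
        # data_dims` means there are no leading batch dimensions, and a zero in any
        # batch dimension (the `all(...)` check below) means the batch is empty.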
        if data_dims is None:
            raise pytest.UsageError(
                f"The number of data dimensions cannot be determined for input of type {tv_tensor_type.__name__}."
            ) from None
        elif batched_input.ndim <= data_dims:
            pytest.skip("Input is not batched.")
        elif not all(batched_input.shape[:-data_dims]):
            pytest.skip("Input has a degenerate batch shape.")

        batched_input = batched_input.as_subclass(torch.Tensor)
        batched_output = info.kernel(batched_input, *other_args, **kwargs)
        actual = self._unbatch(batched_output, data_dims=data_dims)

        single_inputs = self._unbatch(batched_input, data_dims=data_dims)
        expected = tree_map(lambda single_input: info.kernel(single_input, *other_args, **kwargs), single_inputs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=batched_input.dtype, device=batched_input.device),
            msg=parametrized_error_message(batched_input, *other_args, **kwargs),
        )

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_no_inplace(self, info, args_kwargs, device):
        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        if input.numel() == 0:
            pytest.skip("The input has a degenerate shape.")

        input_version = input._version
        info.kernel(input, *other_args, **kwargs)

        assert input._version == input_version

    @sample_inputs
    @needs_cuda
    def test_cuda_vs_cpu(self, test_id, info, args_kwargs):
        (input_cpu, *other_args), kwargs = args_kwargs.load("cpu")
        input_cpu = input_cpu.as_subclass(torch.Tensor)
        input_cuda = input_cpu.to("cuda")

        output_cpu = info.kernel(input_cpu, *other_args, **kwargs)
        output_cuda = info.kernel(input_cuda, *other_args, **kwargs)

        assert_close(
            output_cuda,
            output_cpu,
            check_device=False,
            **info.get_closeness_kwargs(test_id, dtype=input_cuda.dtype, device=input_cuda.device),
            msg=parametrized_error_message(input_cpu, *other_args, **kwargs),
        )

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_dtype_and_device_consistency(self, info, args_kwargs, device):
        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        output = info.kernel(input, *other_args, **kwargs)
        # Most kernels just return a tensor, but some also return some additional metadata
        if not isinstance(output, torch.Tensor):
            output, *_ = output

        assert output.dtype == input.dtype
        assert output.device == input.device

    @reference_inputs
    def test_against_reference(self, test_id, info, args_kwargs):
        (input, *other_args), kwargs = args_kwargs.load("cpu")

        actual = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)
        # We intentionally don't unwrap the input of the reference function in order for it to have access to all
        # metadata regardless of whether the kernel takes it explicitly or not
        expected = info.reference_fn(input, *other_args, **kwargs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )

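    # Checks that a kernel is consistent across dtypes: running it on a float32
    # image (values scaled to [0, 1]) should match running it on the original
    # uint8 image and converting the result to float32 afterwards.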
    @make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.float32_vs_uint8],
        args_kwargs_fn=lambda info: info.reference_inputs_fn(),
    )
    def test_float32_vs_uint8(self, test_id, info, args_kwargs):
        (input, *other_args), kwargs = args_kwargs.load("cpu")
        input = input.as_subclass(torch.Tensor)

        if input.dtype != torch.uint8:
            pytest.skip(f"Input dtype is {input.dtype}.")

        adapted_other_args, adapted_kwargs = info.float32_vs_uint8(other_args, kwargs)

        actual = info.kernel(
            F.to_dtype_image(input, dtype=torch.float32, scale=True),
            *adapted_other_args,
            **adapted_kwargs,
        )

        expected = F.to_dtype_image(info.kernel(input, *other_args, **kwargs), dtype=torch.float32, scale=True)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=torch.float32, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )


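# A thin wrapper around pytest-mock's `mocker.patch` that patches the target with
# itself (`wraps=fn`), so calls pass through unchanged while still being recorded
# on the spy. Typical use (sketch): `spy = spy_on(torch._C._log_api_usage_once)`
# followed by `spy.assert_any_call(...)`.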
@pytest.fixture
def spy_on(mocker):
    def make_spy(fn, *, module=None, name=None):
        # TODO: we can probably get rid of the non-default modules and names if we eliminate aliasing
        module = module or fn.__module__
        name = name or fn.__name__
        spy = mocker.patch(f"{module}.{name}", wraps=fn)
        return spy

    return make_spy


class TestDispatchers:
    image_sample_inputs = make_info_args_kwargs_parametrization(
        [info for info in DISPATCHER_INFOS if tv_tensors.Image in info.kernels],
        args_kwargs_fn=lambda info: info.sample_inputs(tv_tensors.Image),
    )

    @make_info_args_kwargs_parametrization(
        DISPATCHER_INFOS,
        args_kwargs_fn=lambda info: info.sample_inputs(),
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_logging(self, spy_on, info, args_kwargs, device):
        spy = spy_on(torch._C._log_api_usage_once)

        args, kwargs = args_kwargs.load(device)
        info.dispatcher(*args, **kwargs)

        spy.assert_any_call(f"{info.dispatcher.__module__}.{info.id}")

    @ignore_jit_warning_no_profile
    @image_sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_scripted_smoke(self, info, args_kwargs, device):
        dispatcher = script(info.dispatcher)

        (image_tv_tensor, *other_args), kwargs = args_kwargs.load(device)
        image_pure_tensor = torch.Tensor(image_tv_tensor)

        dispatcher(image_pure_tensor, *other_args, **kwargs)

    # TODO: We need this until the dispatchers below also have `DispatcherInfo`'s. If they do, `test_scripted_smoke`
    #  replaces this test for them.
    @ignore_jit_warning_no_profile
    @pytest.mark.parametrize(
        "dispatcher",
        [
            F.get_dimensions,
            F.get_image_num_channels,
            F.get_image_size,
            F.get_num_channels,
            F.get_num_frames,
            F.get_size,
            F.rgb_to_grayscale,
            F.uniform_temporal_subsample,
        ],
        ids=lambda dispatcher: dispatcher.__name__,
    )
    def test_scriptable(self, dispatcher):
        script(dispatcher)

    @image_sample_inputs
    def test_pure_tensor_output_type(self, info, args_kwargs):
        (image_tv_tensor, *other_args), kwargs = args_kwargs.load()
        image_pure_tensor = image_tv_tensor.as_subclass(torch.Tensor)

        output = info.dispatcher(image_pure_tensor, *other_args, **kwargs)

        # We cannot use `isinstance` here since all tv_tensors are instances of `torch.Tensor` as well
        assert type(output) is torch.Tensor

    @make_info_args_kwargs_parametrization(
        [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
        args_kwargs_fn=lambda info: info.sample_inputs(tv_tensors.Image),
    )
    def test_pil_output_type(self, info, args_kwargs):
        (image_tv_tensor, *other_args), kwargs = args_kwargs.load()

        if image_tv_tensor.ndim > 3:
            pytest.skip("Input is batched")

        image_pil = F.to_pil_image(image_tv_tensor)

        output = info.dispatcher(image_pil, *other_args, **kwargs)

        assert isinstance(output, PIL.Image.Image)

    @make_info_args_kwargs_parametrization(
        DISPATCHER_INFOS,
        args_kwargs_fn=lambda info: info.sample_inputs(),
    )
    def test_tv_tensor_output_type(self, info, args_kwargs):
        (tv_tensor, *other_args), kwargs = args_kwargs.load()

        output = info.dispatcher(tv_tensor, *other_args, **kwargs)

        assert isinstance(output, type(tv_tensor))

        if isinstance(tv_tensor, tv_tensors.BoundingBoxes) and info.dispatcher is not F.convert_bounding_box_format:
            assert output.format == tv_tensor.format

    @pytest.mark.parametrize(
        ("dispatcher_info", "tv_tensor_type", "kernel_info"),
        [
            pytest.param(
                dispatcher_info, tv_tensor_type, kernel_info, id=f"{dispatcher_info.id}-{tv_tensor_type.__name__}"
            )
            for dispatcher_info in DISPATCHER_INFOS
            for tv_tensor_type, kernel_info in dispatcher_info.kernel_infos.items()
        ],
    )
    def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, tv_tensor_type, kernel_info):
        dispatcher_signature = inspect.signature(dispatcher_info.dispatcher)
        dispatcher_params = list(dispatcher_signature.parameters.values())[1:]

        kernel_signature = inspect.signature(kernel_info.kernel)
        kernel_params = list(kernel_signature.parameters.values())[1:]

        # We filter out metadata that is implicitly passed to the dispatcher through the input tv_tensor, but has to be
        # explicitly passed to the kernel.
        input_type = {v: k for k, v in dispatcher_info.kernels.items()}.get(kernel_info.kernel)
        explicit_metadata = {
            tv_tensors.BoundingBoxes: {"format", "canvas_size"},
        }
        kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]

        dispatcher_params = iter(dispatcher_params)
        for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params):
            try:
                # In general, the dispatcher parameters are a superset of the kernel parameters. Thus, we filter out
                # dispatcher parameters that have no kernel equivalent while keeping the order intact.
                while dispatcher_param.name != kernel_param.name:
                    dispatcher_param = next(dispatcher_params)
            except StopIteration:
                raise AssertionError(
                    f"Parameter `{kernel_param.name}` of kernel `{kernel_info.id}` "
                    f"has no corresponding parameter on the dispatcher `{dispatcher_info.id}`."
                ) from None

            assert dispatcher_param == kernel_param

    @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
    def test_unknown_type(self, info):
        unknown_input = object()
        (_, *other_args), kwargs = next(iter(info.sample_inputs())).load("cpu")

        with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
            info.dispatcher(unknown_input, *other_args, **kwargs)

    @make_info_args_kwargs_parametrization(
        [
            info
            for info in DISPATCHER_INFOS
            if tv_tensors.BoundingBoxes in info.kernels and info.dispatcher is not F.convert_bounding_box_format
        ],
        args_kwargs_fn=lambda info: info.sample_inputs(tv_tensors.BoundingBoxes),
    )
    def test_bounding_boxes_format_consistency(self, info, args_kwargs):
        (bounding_boxes, *other_args), kwargs = args_kwargs.load()
        format = bounding_boxes.format

        output = info.dispatcher(bounding_boxes, *other_args, **kwargs)

        assert output.format == format


@pytest.mark.parametrize(
    ("alias", "target"),
    [
        pytest.param(alias, target, id=alias.__name__)
        for alias, target in [
            (F.hflip, F.horizontal_flip),
            (F.vflip, F.vertical_flip),
            (F.get_image_num_channels, F.get_num_channels),
            (F.to_pil_image, F.to_pil_image),
            (F.elastic_transform, F.elastic),
            (F.to_grayscale, F.rgb_to_grayscale),
        ]
    ],
)
def test_alias(alias, target):
    assert alias is target


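# A Kolmogorov-Smirnov goodness-of-fit check: after normalizing an image with its
# own per-channel statistics, the values should be plausible samples from a
# standard normal distribution (we only reject at a very loose p-value threshold).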
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("num_channels", [1, 3])
def test_normalize_image_tensor_stats(device, num_channels):
    stats = pytest.importorskip("scipy.stats", reason="SciPy is not available")

    def assert_samples_from_standard_normal(t):
        p_value = stats.kstest(t.flatten().cpu(), cdf="norm", args=(0, 1)).pvalue
        assert p_value > 1e-4

    image = torch.rand(num_channels, DEFAULT_SQUARE_SPATIAL_SIZE, DEFAULT_SQUARE_SPATIAL_SIZE, device=device)
    mean = image.mean(dim=(1, 2)).tolist()
    std = image.std(dim=(1, 2)).tolist()

    assert_samples_from_standard_normal(F.normalize_image(image, mean, std))


class TestClampBoundingBoxes:
    @pytest.mark.parametrize(
        "metadata",
        [
            dict(),
            dict(format=tv_tensors.BoundingBoxFormat.XYXY),
            dict(canvas_size=(1, 1)),
        ],
    )
    def test_pure_tensor_insufficient_metadata(self, metadata):
        pure_tensor = next(make_multiple_bounding_boxes()).as_subclass(torch.Tensor)

        with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` has to be passed")):
            F.clamp_bounding_boxes(pure_tensor, **metadata)

    @pytest.mark.parametrize(
        "metadata",
        [
            dict(format=tv_tensors.BoundingBoxFormat.XYXY),
            dict(canvas_size=(1, 1)),
            dict(format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=(1, 1)),
        ],
    )
    def test_tv_tensor_explicit_metadata(self, metadata):
        tv_tensor = next(make_multiple_bounding_boxes())

        with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` must not be passed")):
            F.clamp_bounding_boxes(tv_tensor, **metadata)


# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
#  `transforms_v2_kernel_infos.py`


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "startpoints, endpoints",
    [
        [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
    ],
)
def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
    def _compute_expected_bbox(bbox, format_, canvas_size_, pcoeffs_):
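        # The perspective transform maps a point (x, y) to
        #   x' = (a*x + b*y + c) / (g*x + h*y + 1)
        #   y' = (d*x + e*y + f) / (g*x + h*y + 1)
        # with pcoeffs_ = [a, b, c, d, e, f, g, h]. `m1` holds the two numerator
        # rows, and `m2` duplicates the shared denominator row so that numerator
        # and denominator have matching shapes.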
        m1 = np.array(
            [
                [pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]],
                [pcoeffs_[3], pcoeffs_[4], pcoeffs_[5]],
            ]
        )
        m2 = np.array(
            [
                [pcoeffs_[6], pcoeffs_[7], 1.0],
                [pcoeffs_[6], pcoeffs_[7], 1.0],
            ]
        )

        bbox_xyxy = convert_bounding_box_format(bbox, old_format=format_, new_format=tv_tensors.BoundingBoxFormat.XYXY)
        points = np.array(
            [
                [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
            ]
        )
        numer = np.matmul(points, m1.T)
        denom = np.matmul(points, m2.T)
        transformed_points = numer / denom
        out_bbox = np.array(
            [
                np.min(transformed_points[:, 0]),
                np.min(transformed_points[:, 1]),
                np.max(transformed_points[:, 0]),
                np.max(transformed_points[:, 1]),
            ]
        )
        out_bbox = torch.from_numpy(out_bbox)
        out_bbox = convert_bounding_box_format(
            out_bbox, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format_
        )
        return clamp_bounding_boxes(out_bbox, format=format_, canvas_size=canvas_size_).to(bbox)

    canvas_size = (32, 38)

    pcoeffs = _get_perspective_coeffs(startpoints, endpoints)
    inv_pcoeffs = _get_perspective_coeffs(endpoints, startpoints)
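    # The kernel takes the same coefficients as the image kernel, which follow the
    # PIL convention of mapping output coordinates back to input coordinates. The
    # expected boxes are therefore computed from the coefficients of the inverse
    # transform, i.e. the forward point mapping.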

    for bboxes in make_multiple_bounding_boxes(spatial_size=canvas_size, extra_dims=((4,),)):
        bboxes = bboxes.to(device)

        output_bboxes = F.perspective_bounding_boxes(
            bboxes.as_subclass(torch.Tensor),
            format=bboxes.format,
            canvas_size=bboxes.canvas_size,
            startpoints=None,
            endpoints=None,
            coefficients=pcoeffs,
        )

        expected_bboxes = torch.stack(
            [
                _compute_expected_bbox(b, bboxes.format, bboxes.canvas_size, inv_pcoeffs)
                for b in bboxes.reshape(-1, 4).unbind()
            ]
        ).reshape(bboxes.shape)

        torch.testing.assert_close(output_bboxes, expected_bboxes, rtol=0, atol=1)


@pytest.mark.parametrize(
    "inpt",
    [
        127 * np.ones((32, 32, 3), dtype="uint8"),
        PIL.Image.new("RGB", (32, 32), 122),
    ],
)
def test_to_image(inpt):
    output = F.to_image(inpt)
    assert isinstance(output, torch.Tensor)
    assert output.shape == (3, 32, 32)

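    # `to_image` only rearranges the HWC / PIL layout into a CHW tensor without
    # changing any values, so the element sum is preserved.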
    assert np.asarray(inpt).sum() == output.sum().item()


@pytest.mark.parametrize(
    "inpt",
    [
        torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8),
        127 * np.ones((32, 32, 3), dtype="uint8"),
    ],
)
@pytest.mark.parametrize("mode", [None, "RGB"])
def test_to_pil_image(inpt, mode):
    output = F.to_pil_image(inpt, mode=mode)
    assert isinstance(output, PIL.Image.Image)

    assert np.asarray(inpt).sum() == np.asarray(output).sum()


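# Edge cases for histogram equalization: a constant image has a degenerate
# histogram and should be returned unchanged, while a two-valued image should be
# stretched to the extremes of the uint8 range.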
def test_equalize_image_tensor_edge_cases():
    inpt = torch.zeros(3, 200, 200, dtype=torch.uint8)
    output = F.equalize_image(inpt)
    torch.testing.assert_close(inpt, output)

    inpt = torch.zeros(5, 3, 200, 200, dtype=torch.uint8)
    inpt[..., 100:, 100:] = 1
    output = F.equalize_image(inpt)
    assert output.unique().tolist() == [0, 255]


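# `uniform_temporal_subsample` picks (roughly) evenly spaced frame indices that
# always include the first and the last frame, which is why index 9 rather than 8
# shows up in the expected values below.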
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_uniform_temporal_subsample(device):
    video = torch.arange(10, device=device)[:, None, None, None].expand(-1, 3, 8, 8)
    out_video = F.uniform_temporal_subsample(video, 5)
    assert out_video.unique().tolist() == [0, 2, 4, 6, 9]

    out_video = F.uniform_temporal_subsample(video, 8)
    assert out_video.unique().tolist() == [0, 1, 2, 3, 5, 6, 7, 9]