import inspect
import math
import os
import re
from unittest import mock

import numpy as np
import PIL.Image
import pytest
import torch

from common_utils import (
    assert_close,
    cache,
    cpu_and_cuda,
    DEFAULT_SQUARE_SPATIAL_SIZE,
    make_bounding_boxes,
    needs_cuda,
    parametrized_error_message,
    set_rng_seed,
)
from torch.utils._pytree import tree_map
from torchvision import datapoints
from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes
from torchvision.transforms.v2.functional._utils import _KERNEL_REGISTRY
from torchvision.transforms.v2.utils import is_simple_tensor
from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
from transforms_v2_kernel_infos import KERNEL_INFOS


KERNEL_INFOS_MAP = {info.kernel: info for info in KERNEL_INFOS}
DISPATCHER_INFOS_MAP = {info.dispatcher: info for info in DISPATCHER_INFOS}


@cache
def script(fn):
    try:
        return torch.jit.script(fn)
    except Exception as error:
        raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error
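
# Usage sketch (illustrative): `script(F.resize)` JIT-scripts the dispatcher once; thanks to
# `@cache` above, repeated parametrizations reuse the compiled function instead of re-scripting it.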


# Scripting a function often triggers a warning like
# `UserWarning: operator() profile_node %$INT1 : int[] = prim::profile_ivalue($INT2) does not have profile information`
# with varying `INT1` and `INT2`. Since these are uninteresting for us and only clutter the test summary, we ignore
# them.
ignore_jit_warning_no_profile = pytest.mark.filterwarnings(
    f"ignore:{re.escape('operator() profile_node %')}:UserWarning"
)


def make_info_args_kwargs_params(info, *, args_kwargs_fn, test_id=None):
    args_kwargs = list(args_kwargs_fn(info))
    if not args_kwargs:
        raise pytest.UsageError(
            f"Couldn't collect a single `ArgsKwargs` for `{info.id}`{f' in {test_id}' if test_id else ''}"
        )
    idx_field_len = len(str(len(args_kwargs)))
    return [
        pytest.param(
            info,
            args_kwargs_,
            marks=info.get_marks(test_id, args_kwargs_) if test_id else [],
            id=f"{info.id}-{idx:0{idx_field_len}}",
        )
        for idx, args_kwargs_ in enumerate(args_kwargs)
    ]
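
# Note: the generated pytest IDs follow the pattern `<info.id>-<zero-padded index>`; e.g. an info
# with 12 sample inputs yields IDs `resize_image_tensor-00` through `resize_image_tensor-11`.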


def make_info_args_kwargs_parametrization(infos, *, args_kwargs_fn):
    def decorator(test_fn):
        parts = test_fn.__qualname__.split(".")
        if len(parts) == 1:
            test_class_name = None
            test_function_name = parts[0]
        elif len(parts) == 2:
            test_class_name, test_function_name = parts
        else:
            raise pytest.UsageError("Unable to parse the test class and test function names from the test function's qualified name")
        test_id = (test_class_name, test_function_name)

        argnames = ("info", "args_kwargs")
        argvalues = []
        for info in infos:
            argvalues.extend(make_info_args_kwargs_params(info, args_kwargs_fn=args_kwargs_fn, test_id=test_id))

        return pytest.mark.parametrize(argnames, argvalues)(test_fn)

    return decorator


@pytest.fixture(autouse=True)
def fix_rng_seed():
    set_rng_seed(0)
    yield


@pytest.fixture()
def test_id(request):
    test_class_name = request.cls.__name__ if request.cls is not None else None
    test_function_name = request.node.originalname
    return test_class_name, test_function_name


class TestKernels:
    sample_inputs = make_info_args_kwargs_parametrization(
        KERNEL_INFOS,
        args_kwargs_fn=lambda kernel_info: kernel_info.sample_inputs_fn(),
    )
    reference_inputs = make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.reference_fn is not None],
        args_kwargs_fn=lambda info: info.reference_inputs_fn(),
    )

    @make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.logs_usage],
        args_kwargs_fn=lambda info: info.sample_inputs_fn(),
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_logging(self, spy_on, info, args_kwargs, device):
        spy = spy_on(torch._C._log_api_usage_once)

        (input, *other_args), kwargs = args_kwargs.load(device)
        info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)

        spy.assert_any_call(f"{info.kernel.__module__}.{info.id}")

    @ignore_jit_warning_no_profile
    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_scripted_vs_eager(self, test_id, info, args_kwargs, device):
        kernel_eager = info.kernel
        kernel_scripted = script(kernel_eager)

        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        actual = kernel_scripted(input, *other_args, **kwargs)
        expected = kernel_eager(input, *other_args, **kwargs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
            msg=parametrized_error_message(input, other_args, **kwargs),
        )

    def _unbatch(self, batch, *, data_dims):
        if isinstance(batch, torch.Tensor):
            batched_tensor = batch
            metadata = ()
        else:
            batched_tensor, *metadata = batch

        if batched_tensor.ndim == data_dims:
            return batch

        return [
            self._unbatch(unbatched, data_dims=data_dims)
            for unbatched in (
                batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
            )
        ]
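
    # Illustration (not executed): with `data_dims=3`, `_unbatch` turns a batched image tensor of
    # shape (2, 3, H, W) into a list of two (3, H, W) tensors, and a (2, 2, 3, H, W) tensor into a
    # nested 2x2 list. Metadata from tuple outputs such as `(boxes, canvas_size)` is re-attached to
    # every unbatched element.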

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_batched_vs_single(self, test_id, info, args_kwargs, device):
        (batched_input, *other_args), kwargs = args_kwargs.load(device)

        datapoint_type = datapoints.Image if is_simple_tensor(batched_input) else type(batched_input)
        # This dictionary contains the number of rightmost dimensions that contain the actual data.
        # Everything to the left is considered a batch dimension.
        data_dims = {
            datapoints.Image: 3,
            datapoints.BoundingBoxes: 1,
            # `Mask`'s are special in the sense that the data dimensions depend on the type of mask. For detection
            # masks it is 3 `(*, N, H, W)`, but for segmentation masks it is 2 `(*, H, W)`. Since both are grouped
            # under one type, all kernels should also work without differentiating between the two. Thus, we go with 2
            # here as common ground.
            datapoints.Mask: 2,
            datapoints.Video: 4,
        }.get(datapoint_type)
        if data_dims is None:
            raise pytest.UsageError(
                f"The number of data dimensions cannot be determined for input of type {datapoint_type.__name__}."
            ) from None
        elif batched_input.ndim <= data_dims:
            pytest.skip("Input is not batched.")
        elif not all(batched_input.shape[:-data_dims]):
            pytest.skip("Input has a degenerate batch shape.")

        batched_input = batched_input.as_subclass(torch.Tensor)
        batched_output = info.kernel(batched_input, *other_args, **kwargs)
        actual = self._unbatch(batched_output, data_dims=data_dims)

        single_inputs = self._unbatch(batched_input, data_dims=data_dims)
        expected = tree_map(lambda single_input: info.kernel(single_input, *other_args, **kwargs), single_inputs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=batched_input.dtype, device=batched_input.device),
            msg=parametrized_error_message(batched_input, *other_args, **kwargs),
        )

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_no_inplace(self, info, args_kwargs, device):
        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        if input.numel() == 0:
            pytest.skip("The input has a degenerate shape.")

        input_version = input._version
        info.kernel(input, *other_args, **kwargs)

        assert input._version == input_version

    @sample_inputs
    @needs_cuda
    def test_cuda_vs_cpu(self, test_id, info, args_kwargs):
        (input_cpu, *other_args), kwargs = args_kwargs.load("cpu")
        input_cpu = input_cpu.as_subclass(torch.Tensor)
        input_cuda = input_cpu.to("cuda")

        output_cpu = info.kernel(input_cpu, *other_args, **kwargs)
        output_cuda = info.kernel(input_cuda, *other_args, **kwargs)

        assert_close(
            output_cuda,
            output_cpu,
            check_device=False,
            **info.get_closeness_kwargs(test_id, dtype=input_cuda.dtype, device=input_cuda.device),
            msg=parametrized_error_message(input_cpu, *other_args, **kwargs),
        )

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_dtype_and_device_consistency(self, info, args_kwargs, device):
        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        output = info.kernel(input, *other_args, **kwargs)
        # Most kernels just return a tensor, but some also return some additional metadata
        if not isinstance(output, torch.Tensor):
            output, *_ = output

        assert output.dtype == input.dtype
        assert output.device == input.device

    @reference_inputs
    def test_against_reference(self, test_id, info, args_kwargs):
        (input, *other_args), kwargs = args_kwargs.load("cpu")

        actual = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)
        # We intentionally don't unwrap the input of the reference function in order for it to have access to all
        # metadata regardless of whether the kernel takes it explicitly or not
        expected = info.reference_fn(input, *other_args, **kwargs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )

    @make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.float32_vs_uint8],
        args_kwargs_fn=lambda info: info.reference_inputs_fn(),
    )
    def test_float32_vs_uint8(self, test_id, info, args_kwargs):
        (input, *other_args), kwargs = args_kwargs.load("cpu")
        input = input.as_subclass(torch.Tensor)

        if input.dtype != torch.uint8:
            pytest.skip(f"Input dtype is {input.dtype}.")

        adapted_other_args, adapted_kwargs = info.float32_vs_uint8(other_args, kwargs)

        actual = info.kernel(
            F.to_dtype_image_tensor(input, dtype=torch.float32, scale=True),
            *adapted_other_args,
            **adapted_kwargs,
        )

        expected = F.to_dtype_image_tensor(info.kernel(input, *other_args, **kwargs), dtype=torch.float32, scale=True)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=torch.float32, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )


@pytest.fixture
def spy_on(mocker):
    def make_spy(fn, *, module=None, name=None):
        # TODO: we can probably get rid of the non-default modules and names if we eliminate aliasing
        module = module or fn.__module__
        name = name or fn.__name__
        spy = mocker.patch(f"{module}.{name}", wraps=fn)
        return spy

    return make_spy
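
# Usage sketch (illustrative): `spy = spy_on(torch._C._log_api_usage_once)` patches the function in
# its defining module while still calling through (`wraps=fn`), so tests can assert on calls, e.g.
# `spy.assert_any_call("torchvision.transforms.v2.functional.resize")`, without changing behavior.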


class TestDispatchers:
    image_sample_inputs = make_info_args_kwargs_parametrization(
        [info for info in DISPATCHER_INFOS if datapoints.Image in info.kernels],
        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
    )

    @make_info_args_kwargs_parametrization(
        DISPATCHER_INFOS,
        args_kwargs_fn=lambda info: info.sample_inputs(),
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_logging(self, spy_on, info, args_kwargs, device):
        spy = spy_on(torch._C._log_api_usage_once)

        args, kwargs = args_kwargs.load(device)
        info.dispatcher(*args, **kwargs)

        spy.assert_any_call(f"{info.dispatcher.__module__}.{info.id}")

    @ignore_jit_warning_no_profile
    @image_sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_scripted_smoke(self, info, args_kwargs, device):
        dispatcher = script(info.dispatcher)

        (image_datapoint, *other_args), kwargs = args_kwargs.load(device)
        image_simple_tensor = torch.Tensor(image_datapoint)

        dispatcher(image_simple_tensor, *other_args, **kwargs)

    # TODO: We need this until the dispatchers below also have `DispatcherInfo`'s. If they do, `test_scripted_smoke`
    #  replaces this test for them.
    @ignore_jit_warning_no_profile
    @pytest.mark.parametrize(
        "dispatcher",
        [
            F.get_dimensions,
            F.get_image_num_channels,
            F.get_image_size,
            F.get_num_channels,
            F.get_num_frames,
            F.get_size,
            F.rgb_to_grayscale,
            F.uniform_temporal_subsample,
        ],
        ids=lambda dispatcher: dispatcher.__name__,
    )
    def test_scriptable(self, dispatcher):
        script(dispatcher)

    @image_sample_inputs
    def test_dispatch_simple_tensor(self, info, args_kwargs, spy_on):
        (image_datapoint, *other_args), kwargs = args_kwargs.load()
        image_simple_tensor = torch.Tensor(image_datapoint)

        kernel_info = info.kernel_infos[datapoints.Image]
        spy = spy_on(kernel_info.kernel, module=info.dispatcher.__module__, name=kernel_info.id)

        info.dispatcher(image_simple_tensor, *other_args, **kwargs)

        spy.assert_called_once()

    @image_sample_inputs
    def test_simple_tensor_output_type(self, info, args_kwargs):
        (image_datapoint, *other_args), kwargs = args_kwargs.load()
        image_simple_tensor = image_datapoint.as_subclass(torch.Tensor)

        output = info.dispatcher(image_simple_tensor, *other_args, **kwargs)

        # We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well
        assert type(output) is torch.Tensor

    @make_info_args_kwargs_parametrization(
        [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
    )
    def test_dispatch_pil(self, info, args_kwargs, spy_on):
        (image_datapoint, *other_args), kwargs = args_kwargs.load()

        if image_datapoint.ndim > 3:
            pytest.skip("Input is batched")

        image_pil = F.to_image_pil(image_datapoint)

        pil_kernel_info = info.pil_kernel_info
        spy = spy_on(pil_kernel_info.kernel, module=info.dispatcher.__module__, name=pil_kernel_info.id)

        info.dispatcher(image_pil, *other_args, **kwargs)

        spy.assert_called_once()

    @make_info_args_kwargs_parametrization(
        [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
    )
    def test_pil_output_type(self, info, args_kwargs):
        (image_datapoint, *other_args), kwargs = args_kwargs.load()

        if image_datapoint.ndim > 3:
            pytest.skip("Input is batched")

        image_pil = F.to_image_pil(image_datapoint)

        output = info.dispatcher(image_pil, *other_args, **kwargs)

        assert isinstance(output, PIL.Image.Image)

    @make_info_args_kwargs_parametrization(
        DISPATCHER_INFOS,
        args_kwargs_fn=lambda info: info.sample_inputs(),
    )
    def test_dispatch_datapoint(self, info, args_kwargs, spy_on):
        (datapoint, *other_args), kwargs = args_kwargs.load()

        input_type = type(datapoint)

        wrapped_kernel = _KERNEL_REGISTRY[info.dispatcher][input_type]

        # In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the
        # proper kernel was wrapped
        if hasattr(wrapped_kernel, "__wrapped__"):
            assert wrapped_kernel.__wrapped__ is info.kernels[input_type]

        spy = mock.MagicMock(wraps=wrapped_kernel, name=wrapped_kernel.__name__)
        with mock.patch.dict(_KERNEL_REGISTRY[info.dispatcher], values={input_type: spy}):
            info.dispatcher(datapoint, *other_args, **kwargs)

        spy.assert_called_once()
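
    # For context (an illustrative note): `_KERNEL_REGISTRY` maps each dispatcher to a dict of
    # `{datapoint type: wrapped kernel}`, e.g. `_KERNEL_REGISTRY[F.resize][datapoints.Image]`,
    # which is why the spy above can be swapped into that inner dict.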

    @make_info_args_kwargs_parametrization(
        DISPATCHER_INFOS,
        args_kwargs_fn=lambda info: info.sample_inputs(),
    )
    def test_datapoint_output_type(self, info, args_kwargs):
        (datapoint, *other_args), kwargs = args_kwargs.load()

        output = info.dispatcher(datapoint, *other_args, **kwargs)

        assert isinstance(output, type(datapoint))

    @pytest.mark.parametrize(
        ("dispatcher_info", "datapoint_type", "kernel_info"),
        [
            pytest.param(
                dispatcher_info, datapoint_type, kernel_info, id=f"{dispatcher_info.id}-{datapoint_type.__name__}"
            )
            for dispatcher_info in DISPATCHER_INFOS
            for datapoint_type, kernel_info in dispatcher_info.kernel_infos.items()
        ],
    )
    def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, datapoint_type, kernel_info):
        dispatcher_signature = inspect.signature(dispatcher_info.dispatcher)
        dispatcher_params = list(dispatcher_signature.parameters.values())[1:]

        kernel_signature = inspect.signature(kernel_info.kernel)
        kernel_params = list(kernel_signature.parameters.values())[1:]

        # We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be
        # explicitly passed to the kernel.
        input_type = {v: k for k, v in dispatcher_info.kernels.items()}.get(kernel_info.kernel)
        explicit_metadata = {
            datapoints.BoundingBoxes: {"format", "canvas_size"},
        }
        kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]

        dispatcher_params = iter(dispatcher_params)
        for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params):
            try:
                # In general, the dispatcher parameters are a superset of the kernel parameters. Thus, we filter out
                # dispatcher parameters that have no kernel equivalent while keeping the order intact.
                while dispatcher_param.name != kernel_param.name:
                    dispatcher_param = next(dispatcher_params)
            except StopIteration:
                raise AssertionError(
                    f"Parameter `{kernel_param.name}` of kernel `{kernel_info.id}` "
                    f"has no corresponding parameter on the dispatcher `{dispatcher_info.id}`."
                ) from None

            assert dispatcher_param == kernel_param

    @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
    def test_unknown_type(self, info):
        unknown_input = object()
        (_, *other_args), kwargs = next(iter(info.sample_inputs())).load("cpu")

        with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
            info.dispatcher(unknown_input, *other_args, **kwargs)

    @make_info_args_kwargs_parametrization(
        [
            info
            for info in DISPATCHER_INFOS
            if datapoints.BoundingBoxes in info.kernels and info.dispatcher is not F.convert_format_bounding_boxes
        ],
        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.BoundingBoxes),
    )
    def test_bounding_boxes_format_consistency(self, info, args_kwargs):
        (bounding_boxes, *other_args), kwargs = args_kwargs.load()
        format = bounding_boxes.format

        output = info.dispatcher(bounding_boxes, *other_args, **kwargs)

        assert output.format == format


@pytest.mark.parametrize(
    ("alias", "target"),
    [
        pytest.param(alias, target, id=alias.__name__)
        for alias, target in [
            (F.hflip, F.horizontal_flip),
            (F.vflip, F.vertical_flip),
            (F.get_image_num_channels, F.get_num_channels),
            (F.to_pil_image, F.to_image_pil),
            (F.elastic_transform, F.elastic),
            (F.to_grayscale, F.rgb_to_grayscale),
        ]
    ],
)
def test_alias(alias, target):
    assert alias is target


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("num_channels", [1, 3])
def test_normalize_image_tensor_stats(device, num_channels):
    stats = pytest.importorskip("scipy.stats", reason="SciPy is not available")

    def assert_samples_from_standard_normal(t):
        p_value = stats.kstest(t.flatten(), cdf="norm", args=(0, 1)).pvalue
        assert p_value > 1e-4

    image = torch.rand(num_channels, DEFAULT_SQUARE_SPATIAL_SIZE, DEFAULT_SQUARE_SPATIAL_SIZE)
    mean = image.mean(dim=(1, 2)).tolist()
    std = image.std(dim=(1, 2)).tolist()

    assert_samples_from_standard_normal(F.normalize_image_tensor(image, mean, std))


class TestClampBoundingBoxes:
    @pytest.mark.parametrize(
        "metadata",
        [
            dict(),
            dict(format=datapoints.BoundingBoxFormat.XYXY),
            dict(canvas_size=(1, 1)),
        ],
    )
    def test_simple_tensor_insufficient_metadata(self, metadata):
        simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)

        with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` has to be passed")):
            F.clamp_bounding_boxes(simple_tensor, **metadata)

    @pytest.mark.parametrize(
        "metadata",
        [
            dict(format=datapoints.BoundingBoxFormat.XYXY),
            dict(canvas_size=(1, 1)),
            dict(format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(1, 1)),
        ],
    )
    def test_datapoint_explicit_metadata(self, metadata):
        datapoint = next(make_bounding_boxes())

        with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` must not be passed")):
            F.clamp_bounding_boxes(datapoint, **metadata)


class TestConvertFormatBoundingBoxes:
    @pytest.mark.parametrize(
        ("inpt", "old_format"),
        [
            (next(make_bounding_boxes()), None),
            (next(make_bounding_boxes()).as_subclass(torch.Tensor), datapoints.BoundingBoxFormat.XYXY),
        ],
    )
    def test_missing_new_format(self, inpt, old_format):
        with pytest.raises(TypeError, match=re.escape("missing 1 required argument: 'new_format'")):
            F.convert_format_bounding_boxes(inpt, old_format)

    def test_simple_tensor_insufficient_metadata(self):
        simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)

        with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")):
            F.convert_format_bounding_boxes(simple_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)

    def test_datapoint_explicit_metadata(self):
        datapoint = next(make_bounding_boxes())

        with pytest.raises(ValueError, match=re.escape("`old_format` must not be passed")):
            F.convert_format_bounding_boxes(
                datapoint, old_format=datapoint.format, new_format=datapoints.BoundingBoxFormat.CXCYWH
            )


# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
#  `transforms_v2_kernel_infos.py`


def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_):
    rot = math.radians(angle_)
    cx, cy = center_
    tx, ty = translate_
    sx, sy = [math.radians(sh_) for sh_ in shear_]

    c_matrix = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
    t_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]])
    c_matrix_inv = np.linalg.inv(c_matrix)
    rs_matrix = np.array(
        [
            [scale_ * math.cos(rot), -scale_ * math.sin(rot), 0],
            [scale_ * math.sin(rot), scale_ * math.cos(rot), 0],
            [0, 0, 1],
        ]
    )
    shear_x_matrix = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]])
    shear_y_matrix = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]])
    rss_matrix = np.matmul(rs_matrix, np.matmul(shear_y_matrix, shear_x_matrix))
    true_matrix = np.matmul(t_matrix, np.matmul(c_matrix, np.matmul(rss_matrix, c_matrix_inv)))
    return true_matrix
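

# Illustrative sanity check for the helper above (an added sketch, not part of the original suite):
# with zero angle, translation, and shear and unit scale, every factor collapses to the identity,
# so the composed matrix must be the 3x3 identity regardless of the center.
def test_compute_affine_matrix_identity():
    matrix = _compute_affine_matrix(
        angle_=0.0, translate_=(0.0, 0.0), scale_=1.0, shear_=(0.0, 0.0), center_=(5.0, 7.0)
    )
    np.testing.assert_allclose(matrix, np.eye(3), atol=1e-12)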


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "format",
    [datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.CXCYWH],
)
@pytest.mark.parametrize(
    "top, left, height, width, expected_bboxes",
    [
        [8, 12, 30, 40, [(-2.0, 7.0, 13.0, 27.0), (38.0, -3.0, 58.0, 14.0), (33.0, 38.0, 44.0, 54.0)]],
        [-8, 12, 70, 40, [(-2.0, 23.0, 13.0, 43.0), (38.0, 13.0, 58.0, 30.0), (33.0, 54.0, 44.0, 70.0)]],
    ],
)
def test_correctness_crop_bounding_boxes(device, format, top, left, height, width, expected_bboxes):

    # Expected bboxes computed using Albumentations:
    # import numpy as np
    # from albumentations.augmentations.crops.functional import crop_bbox_by_coords, normalize_bbox, denormalize_bbox
    # expected_bboxes = []
    # for in_box in in_boxes:
    #     n_in_box = normalize_bbox(in_box, *size)
    #     n_out_box = crop_bbox_by_coords(
    #         n_in_box, (left, top, left + width, top + height), height, width, *size
    #     )
    #     out_box = denormalize_bbox(n_out_box, height, width)
    #     expected_bboxes.append(out_box)

    canvas_size = (64, 76)
    in_boxes = [
        [10.0, 15.0, 25.0, 35.0],
        [50.0, 5.0, 70.0, 22.0],
        [45.0, 46.0, 56.0, 62.0],
    ]
    in_boxes = torch.tensor(in_boxes, device=device)
    if format != datapoints.BoundingBoxFormat.XYXY:
        in_boxes = convert_format_bounding_boxes(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)

    expected_bboxes = clamp_bounding_boxes(
        datapoints.BoundingBoxes(expected_bboxes, format="XYXY", canvas_size=canvas_size)
    ).tolist()

    output_boxes, output_canvas_size = F.crop_bounding_boxes(
        in_boxes,
        format,
        top,
        left,
        canvas_size[0],
        canvas_size[1],
    )

    if format != datapoints.BoundingBoxFormat.XYXY:
        output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)

    torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
    torch.testing.assert_close(output_canvas_size, canvas_size)


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
    mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
    mask[:, 0, :] = 1

    out_mask = F.vertical_flip_mask(mask)

    expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
    expected_mask[:, -1, :] = 1
    torch.testing.assert_close(out_mask, expected_mask)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "format",
    [datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.CXCYWH],
)
@pytest.mark.parametrize(
    "top, left, height, width, size",
    [
        [0, 0, 30, 30, (60, 60)],
        [-5, 5, 35, 45, (32, 34)],
    ],
)
def test_correctness_resized_crop_bounding_boxes(device, format, top, left, height, width, size):
    def _compute_expected_bbox(bbox, top_, left_, height_, width_, size_):
        # bbox should be xyxy
        bbox[0] = (bbox[0] - left_) * size_[1] / width_
        bbox[1] = (bbox[1] - top_) * size_[0] / height_
        bbox[2] = (bbox[2] - left_) * size_[1] / width_
        bbox[3] = (bbox[3] - top_) * size_[0] / height_
        return bbox

    canvas_size = (100, 100)
    in_boxes = [
        [10.0, 10.0, 20.0, 20.0],
        [5.0, 10.0, 15.0, 20.0],
    ]
    expected_bboxes = []
    for in_box in in_boxes:
        expected_bboxes.append(_compute_expected_bbox(list(in_box), top, left, height, width, size))
    expected_bboxes = torch.tensor(expected_bboxes, device=device)

    in_boxes = datapoints.BoundingBoxes(
        in_boxes, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=canvas_size, device=device
    )
    if format != datapoints.BoundingBoxFormat.XYXY:
        in_boxes = convert_format_bounding_boxes(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)

    output_boxes, output_canvas_size = F.resized_crop_bounding_boxes(in_boxes, format, top, left, height, width, size)

    if format != datapoints.BoundingBoxFormat.XYXY:
        output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)

    torch.testing.assert_close(output_boxes, expected_bboxes)
    torch.testing.assert_close(output_canvas_size, size)


def _parse_padding(padding):
    if isinstance(padding, int):
        return [padding] * 4
    if isinstance(padding, list):
        if len(padding) == 1:
            return padding * 4
        if len(padding) == 2:
            return padding * 2  # [left, up, right, down]

    return padding
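

# For reference, `_parse_padding` expands the `padding` argument to `[left, up, right, down]`:
#   _parse_padding(1)             -> [1, 1, 1, 1]
#   _parse_padding([1])           -> [1, 1, 1, 1]
#   _parse_padding([1, 2])        -> [1, 2, 1, 2]
#   _parse_padding([1, 2, 3, 4])  -> [1, 2, 3, 4]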


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("padding", [[1], [1, 1], [1, 1, 2, 2]])
def test_correctness_pad_bounding_boxes(device, padding):
    def _compute_expected_bbox(bbox, padding_):
        pad_left, pad_up, _, _ = _parse_padding(padding_)

        dtype = bbox.dtype
        format = bbox.format
        bbox = (
            bbox.clone()
            if format == datapoints.BoundingBoxFormat.XYXY
            else convert_format_bounding_boxes(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
        )

        bbox[0::2] += pad_left
        bbox[1::2] += pad_up

        bbox = convert_format_bounding_boxes(bbox, new_format=format)
        if bbox.dtype != dtype:
            # Temporary cast to original dtype
            # e.g. float32 -> int
            bbox = bbox.to(dtype)
        return bbox

    def _compute_expected_canvas_size(bbox, padding_):
        pad_left, pad_up, pad_right, pad_down = _parse_padding(padding_)
        height, width = bbox.canvas_size
        return height + pad_up + pad_down, width + pad_left + pad_right

    for bboxes in make_bounding_boxes():
        bboxes = bboxes.to(device)
        bboxes_format = bboxes.format
        bboxes_canvas_size = bboxes.canvas_size

        output_boxes, output_canvas_size = F.pad_bounding_boxes(
            bboxes, format=bboxes_format, canvas_size=bboxes_canvas_size, padding=padding
        )

        torch.testing.assert_close(output_canvas_size, _compute_expected_canvas_size(bboxes, padding))

        if bboxes.ndim < 2 or bboxes.shape[0] == 0:
            bboxes = [bboxes]

        expected_bboxes = []
        for bbox in bboxes:
            bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, canvas_size=bboxes_canvas_size)
            expected_bboxes.append(_compute_expected_bbox(bbox, padding))

        if len(expected_bboxes) > 1:
            expected_bboxes = torch.stack(expected_bboxes)
        else:
            expected_bboxes = expected_bboxes[0]
        torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_pad_segmentation_mask_on_fixed_input(device):
    mask = torch.ones((1, 3, 3), dtype=torch.long, device=device)

    out_mask = F.pad_mask(mask, padding=[1, 1, 1, 1])

    expected_mask = torch.zeros((1, 5, 5), dtype=torch.long, device=device)
    expected_mask[:, 1:-1, 1:-1] = 1
    torch.testing.assert_close(out_mask, expected_mask)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "startpoints, endpoints",
    [
        [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
    ],
)
def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
    def _compute_expected_bbox(bbox, pcoeffs_):
        m1 = np.array(
            [
                [pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]],
                [pcoeffs_[3], pcoeffs_[4], pcoeffs_[5]],
            ]
        )
        # Both rows intentionally share the same coefficients: they are the projective
        # denominator w = c6 * x + c7 * y + 1, applied to the x and y coordinates alike.
        m2 = np.array(
            [
                [pcoeffs_[6], pcoeffs_[7], 1.0],
                [pcoeffs_[6], pcoeffs_[7], 1.0],
            ]
        )

        bbox_xyxy = convert_format_bounding_boxes(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
        points = np.array(
            [
                [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
            ]
        )
        numer = np.matmul(points, m1.T)
        denom = np.matmul(points, m2.T)
        transformed_points = numer / denom
        out_bbox = np.array(
            [
                np.min(transformed_points[:, 0]),
                np.min(transformed_points[:, 1]),
                np.max(transformed_points[:, 0]),
                np.max(transformed_points[:, 1]),
            ]
        )
        out_bbox = datapoints.BoundingBoxes(
            out_bbox,
            format=datapoints.BoundingBoxFormat.XYXY,
            canvas_size=bbox.canvas_size,
            dtype=bbox.dtype,
            device=bbox.device,
        )
        return clamp_bounding_boxes(convert_format_bounding_boxes(out_bbox, new_format=bbox.format))

    canvas_size = (32, 38)

    pcoeffs = _get_perspective_coeffs(startpoints, endpoints)
    inv_pcoeffs = _get_perspective_coeffs(endpoints, startpoints)

    for bboxes in make_bounding_boxes(canvas_size=canvas_size, extra_dims=((4,),)):
        bboxes = bboxes.to(device)

        output_bboxes = F.perspective_bounding_boxes(
            bboxes.as_subclass(torch.Tensor),
            format=bboxes.format,
            canvas_size=bboxes.canvas_size,
            startpoints=None,
            endpoints=None,
            coefficients=pcoeffs,
        )

        if bboxes.ndim < 2:
            bboxes = [bboxes]

        expected_bboxes = []
        for bbox in bboxes:
            bbox = datapoints.BoundingBoxes(bbox, format=bboxes.format, canvas_size=bboxes.canvas_size)
            expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs))
        if len(expected_bboxes) > 1:
            expected_bboxes = torch.stack(expected_bboxes)
        else:
            expected_bboxes = expected_bboxes[0]
        torch.testing.assert_close(output_bboxes, expected_bboxes, rtol=0, atol=1)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
    "output_size",
    [(18, 18), [18, 15], (16, 19), [12], [46, 48]],
)
def test_correctness_center_crop_bounding_boxes(device, output_size):
    def _compute_expected_bbox(bbox, output_size_):
        format_ = bbox.format
        canvas_size_ = bbox.canvas_size
        dtype = bbox.dtype
        bbox = convert_format_bounding_boxes(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH)

        if len(output_size_) == 1:
            output_size_.append(output_size_[-1])

        cy = int(round((canvas_size_[0] - output_size_[0]) * 0.5))
        cx = int(round((canvas_size_[1] - output_size_[1]) * 0.5))
        out_bbox = [
            bbox[0].item() - cx,
            bbox[1].item() - cy,
            bbox[2].item(),
            bbox[3].item(),
        ]
        out_bbox = torch.tensor(out_bbox)
        out_bbox = convert_format_bounding_boxes(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_)
        out_bbox = clamp_bounding_boxes(out_bbox, format=format_, canvas_size=output_size)
        return out_bbox.to(dtype=dtype, device=bbox.device)

    for bboxes in make_bounding_boxes(extra_dims=((4,),)):
        bboxes = bboxes.to(device)
        bboxes_format = bboxes.format
        bboxes_canvas_size = bboxes.canvas_size

        output_boxes, output_canvas_size = F.center_crop_bounding_boxes(
            bboxes, bboxes_format, bboxes_canvas_size, output_size
        )

        if bboxes.ndim < 2:
            bboxes = [bboxes]

        expected_bboxes = []
        for bbox in bboxes:
            bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, canvas_size=bboxes_canvas_size)
            expected_bboxes.append(_compute_expected_bbox(bbox, output_size))

        if len(expected_bboxes) > 1:
            expected_bboxes = torch.stack(expected_bboxes)
        else:
            expected_bboxes = expected_bboxes[0]

        torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)
        torch.testing.assert_close(output_canvas_size, output_size)


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("output_size", [[4, 2], [4], [7, 6]])
def test_correctness_center_crop_mask(device, output_size):
    def _compute_expected_mask(mask, output_size):
        crop_height, crop_width = output_size if len(output_size) > 1 else [output_size[0], output_size[0]]

        _, image_height, image_width = mask.shape
        if crop_width > image_width or crop_height > image_height:
            padding = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
            mask = F.pad_image_tensor(mask, padding, fill=0)

        left = round((image_width - crop_width) * 0.5)
        top = round((image_height - crop_height) * 0.5)

        return mask[:, top : top + crop_height, left : left + crop_width]

    mask = torch.randint(0, 2, size=(1, 6, 6), dtype=torch.long, device=device)
    actual = F.center_crop_mask(mask, output_size)

    expected = _compute_expected_mask(mask, output_size)
    torch.testing.assert_close(expected, actual)


# Copied from test/test_functional_tensor.py
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("canvas_size", ("small", "large"))
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)])
@pytest.mark.parametrize("sigma", [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)])
def test_correctness_gaussian_blur_image_tensor(device, canvas_size, dt, ksize, sigma):
    fn = F.gaussian_blur_image_tensor

    # true_cv2_results = {
    #     # np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
    #     # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.8)
    #     "3_3_0.8": ...
    #     # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.5)
    #     "3_3_0.5": ...
    #     # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.8)
    #     "3_5_0.8": ...
    #     # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.5)
    #     "3_5_0.5": ...
    #     # np_img2 = np.arange(26 * 28, dtype="uint8").reshape((26, 28))
    #     # cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7)
    #     "23_23_1.7": ...
    # }
    p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "gaussian_blur_opencv_results.pt")
    true_cv2_results = torch.load(p)

    if canvas_size == "small":
        tensor = (
            torch.from_numpy(np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))).permute(2, 0, 1).to(device)
        )
    else:
        tensor = torch.from_numpy(np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28))).to(device)

    if dt == torch.float16 and device == "cpu":
        # skip float16 on CPU case
        return

    if dt is not None:
        tensor = tensor.to(dtype=dt)

    _ksize = (ksize, ksize) if isinstance(ksize, int) else ksize
    _sigma = sigma[0] if sigma is not None else None
    shape = tensor.shape
    gt_key = f"{shape[-2]}_{shape[-1]}_{shape[-3]}__{_ksize[0]}_{_ksize[1]}_{_sigma}"
    if gt_key not in true_cv2_results:
        return

    true_out = (
        torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)
    )

    image = datapoints.Image(tensor)

    out = fn(image, kernel_size=ksize, sigma=sigma)
    torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")


@pytest.mark.parametrize(
    "inpt",
    [
        127 * np.ones((32, 32, 3), dtype="uint8"),
        PIL.Image.new("RGB", (32, 32), 122),
    ],
)
def test_to_image_tensor(inpt):
    output = F.to_image_tensor(inpt)
    assert isinstance(output, torch.Tensor)
    assert output.shape == (3, 32, 32)

    assert np.asarray(inpt).sum() == output.sum().item()


@pytest.mark.parametrize(
    "inpt",
    [
        torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8),
        127 * np.ones((32, 32, 3), dtype="uint8"),
    ],
)
@pytest.mark.parametrize("mode", [None, "RGB"])
def test_to_image_pil(inpt, mode):
    output = F.to_image_pil(inpt, mode=mode)
    assert isinstance(output, PIL.Image.Image)

    assert np.asarray(inpt).sum() == np.asarray(output).sum()


def test_equalize_image_tensor_edge_cases():
    inpt = torch.zeros(3, 200, 200, dtype=torch.uint8)
    output = F.equalize_image_tensor(inpt)
    torch.testing.assert_close(inpt, output)

    inpt = torch.zeros(5, 3, 200, 200, dtype=torch.uint8)
    inpt[..., 100:, 100:] = 1
    output = F.equalize_image_tensor(inpt)
    assert output.unique().tolist() == [0, 255]


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_uniform_temporal_subsample(device):
    video = torch.arange(10, device=device)[:, None, None, None].expand(-1, 3, 8, 8)
    out_video = F.uniform_temporal_subsample(video, 5)
    assert out_video.unique().tolist() == [0, 2, 4, 6, 9]

    out_video = F.uniform_temporal_subsample(video, 8)
    assert out_video.unique().tolist() == [0, 1, 2, 3, 5, 6, 7, 9]
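
# The expected indices above are consistent with picking evenly spaced frames, e.g. via
# `torch.linspace(0, T - 1, num_samples).long()`: for T=10, linspace(0, 9, 5) gives
# [0, 2.25, 4.5, 6.75, 9] -> [0, 2, 4, 6, 9] after truncation. (Illustrative derivation, not
# necessarily the exact implementation.)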


# TODO: We can remove this test and related torchvision workaround
# once we fixed related pytorch issue: https://github.com/pytorch/pytorch/issues/68430
@make_info_args_kwargs_parametrization(
    [info for info in KERNEL_INFOS if info.kernel is F.resize_image_tensor],
    args_kwargs_fn=lambda info: info.reference_inputs_fn(),
)
def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwargs):
    (input, *other_args), kwargs = args_kwargs.load("cpu")

    output = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)

    error_msg_fn = parametrized_error_message(input, *other_args, **kwargs)
    assert input.ndim == 3, error_msg_fn
    input_stride = input.stride()
    output_stride = output.stride()
    # Here we check output memory format according to the input:
    # if input_stride is (..., 1) then input is most likely channels first and thus
    #   output strides should match channels first strides (H * W, H, 1)
    # if input_stride is (1, ...) then input is most likely channels last and thus
    #   output strides should match channels last strides (1, W * C, C)
    if input_stride[-1] == 1:
        expected_stride = (output.shape[-2] * output.shape[-1], output.shape[-1], 1)
        assert expected_stride == output_stride, error_msg_fn("")
    elif input_stride[0] == 1:
        expected_stride = (1, output.shape[0] * output.shape[-1], output.shape[0])
        assert expected_stride == output_stride, error_msg_fn("")
    else:
        assert False, error_msg_fn("")
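
# Worked example for the stride check above: a channels-last contiguous (3, 10, 12) input has
# strides (1, 36, 3); after resizing to (20, 24) the output should keep channels-last strides
# (1, W * C, C) = (1, 72, 3).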


def test_resize_float16_no_rounding():
    # Make sure Resize() doesn't round float16 images
    # Non-regression test for https://github.com/pytorch/vision/issues/7667

    img = torch.randint(0, 256, size=(1, 3, 100, 100), dtype=torch.float16)
    out = F.resize(img, size=(10, 10))
    assert out.dtype == torch.float16
    assert (out.round() - out).abs().sum() > 0