import inspect
import re

import numpy as np
import PIL.Image
import pytest
import torch

from common_utils import assert_close, cache, cpu_and_cuda, needs_cuda, set_rng_seed
from torch.utils._pytree import tree_map
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2._utils import is_pure_tensor
from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
from transforms_v2_kernel_infos import KERNEL_INFOS
from transforms_v2_legacy_utils import (
    DEFAULT_SQUARE_SPATIAL_SIZE,
    make_multiple_bounding_boxes,
    parametrized_error_message,
)


KERNEL_INFOS_MAP = {info.kernel: info for info in KERNEL_INFOS}
DISPATCHER_INFOS_MAP = {info.dispatcher: info for info in DISPATCHER_INFOS}


@cache
def script(fn):
    try:
        return torch.jit.script(fn)
    except Exception as error:
        raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error
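

# Illustrative usage of the helper above (a sketch that is not executed at
# import time; `F.resize_image` is just one example kernel):
#
#   scripted = script(F.resize_image)
#   scripted(torch.rand(3, 32, 32), size=[16, 16])
#
# Thanks to `@cache`, each function is compiled at most once per test session.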


# Scripting a function often triggers a warning like
# `UserWarning: operator() profile_node %$INT1 : int[] = prim::profile_ivalue($INT2) does not have profile information`
# with varying `INT1` and `INT2`. Since these are uninteresting for us and only clutter the test summary, we ignore
# them.
ignore_jit_warning_no_profile = pytest.mark.filterwarnings(
    f"ignore:{re.escape('operator() profile_node %')}:UserWarning"
)


def make_info_args_kwargs_params(info, *, args_kwargs_fn, test_id=None):
    args_kwargs = list(args_kwargs_fn(info))
    if not args_kwargs:
        raise pytest.UsageError(
            f"Couldn't collect a single `ArgsKwargs` for `{info.id}`{f' in {test_id}' if test_id else ''}"
        )
    idx_field_len = len(str(len(args_kwargs)))
    return [
        pytest.param(
            info,
            args_kwargs_,
            marks=info.get_marks(test_id, args_kwargs_) if test_id else [],
            id=f"{info.id}-{idx:0{idx_field_len}}",
        )
        for idx, args_kwargs_ in enumerate(args_kwargs)
    ]
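

# For example, an info with id `resize_image` that collects twelve `ArgsKwargs`
# produces the test ids `resize_image-00` through `resize_image-11`; the index
# is zero-padded to the width of the largest index so the ids sort naturally.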


def make_info_args_kwargs_parametrization(infos, *, args_kwargs_fn):
    def decorator(test_fn):
        parts = test_fn.__qualname__.split(".")
        if len(parts) == 1:
            test_class_name = None
            test_function_name = parts[0]
        elif len(parts) == 2:
            test_class_name, test_function_name = parts
        else:
            raise pytest.UsageError("Unable to parse the test class and test function names from the test function")

        test_id = (test_class_name, test_function_name)

        argnames = ("info", "args_kwargs")
        argvalues = []
        for info in infos:
            argvalues.extend(make_info_args_kwargs_params(info, args_kwargs_fn=args_kwargs_fn, test_id=test_id))

        return pytest.mark.parametrize(argnames, argvalues)(test_fn)

    return decorator
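

# The decorator returned above is applied like
#
#   @make_info_args_kwargs_parametrization(KERNEL_INFOS, args_kwargs_fn=...)
#   def test_foo(self, info, args_kwargs): ...
#
# and expands into `pytest.mark.parametrize` over `(info, args_kwargs)` with the
# per-sample ids and marks generated by `make_info_args_kwargs_params`.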


@pytest.fixture(autouse=True)
def fix_rng_seed():
    set_rng_seed(0)
    yield


@pytest.fixture()
def test_id(request):
    test_class_name = request.cls.__name__ if request.cls is not None else None
    test_function_name = request.node.originalname
    return test_class_name, test_function_name


class TestKernels:
    sample_inputs = make_info_args_kwargs_parametrization(
        KERNEL_INFOS,
        args_kwargs_fn=lambda kernel_info: kernel_info.sample_inputs_fn(),
    )
    reference_inputs = make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.reference_fn is not None],
        args_kwargs_fn=lambda info: info.reference_inputs_fn(),
    )

    @make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.logs_usage],
        args_kwargs_fn=lambda info: info.sample_inputs_fn(),
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_logging(self, spy_on, info, args_kwargs, device):
        spy = spy_on(torch._C._log_api_usage_once)

        (input, *other_args), kwargs = args_kwargs.load(device)
        info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)

        spy.assert_any_call(f"{info.kernel.__module__}.{info.id}")

    @ignore_jit_warning_no_profile
    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_scripted_vs_eager(self, test_id, info, args_kwargs, device):
        kernel_eager = info.kernel
        kernel_scripted = script(kernel_eager)

        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        actual = kernel_scripted(input, *other_args, **kwargs)
        expected = kernel_eager(input, *other_args, **kwargs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )

    def _unbatch(self, batch, *, data_dims):
        if isinstance(batch, torch.Tensor):
            batched_tensor = batch
            metadata = ()
        else:
            batched_tensor, *metadata = batch

        if batched_tensor.ndim == data_dims:
            return batch

        return [
            self._unbatch(unbatched, data_dims=data_dims)
            for unbatched in (
                batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
            )
        ]
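
    # A worked example for `_unbatch`: with `data_dims=3`, a batched image of
    # shape (2, 3, H, W) unbinds into a list of two (3, H, W) tensors, and a
    # (2, 5, 3, H, W) batch becomes a nested list of lists of them.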

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_batched_vs_single(self, test_id, info, args_kwargs, device):
        (batched_input, *other_args), kwargs = args_kwargs.load(device)

        tv_tensor_type = tv_tensors.Image if is_pure_tensor(batched_input) else type(batched_input)
        # This dictionary contains the number of rightmost dimensions that contain the actual data.
        # Everything to the left is considered a batch dimension.
        data_dims = {
            tv_tensors.Image: 3,
            tv_tensors.BoundingBoxes: 1,
            # `Mask`'s are special in the sense that the data dimensions depend on the type of mask. For detection
            # masks it is 3 `(*, N, H, W)`, but for segmentation masks it is 2 `(*, H, W)`. Since both are grouped
            # under one type, all kernels should work without differentiating between the two. Thus, we go with 2 as
            # the common ground.
            tv_tensors.Mask: 2,
            tv_tensors.Video: 4,
        }.get(tv_tensor_type)
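        # For example, a batched image of shape (7, 3, 16, 16) has one batch
        # dimension in front of the three data dimensions `(C, H, W)`.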
        if data_dims is None:
            raise pytest.UsageError(
                f"The number of data dimensions cannot be determined for input of type {tv_tensor_type.__name__}."
            ) from None
        elif batched_input.ndim <= data_dims:
            pytest.skip("Input is not batched.")
        elif not all(batched_input.shape[:-data_dims]):
            pytest.skip("Input has a degenerate batch shape.")

        batched_input = batched_input.as_subclass(torch.Tensor)

        batched_output = info.kernel(batched_input, *other_args, **kwargs)
        actual = self._unbatch(batched_output, data_dims=data_dims)

        single_inputs = self._unbatch(batched_input, data_dims=data_dims)
        expected = tree_map(lambda single_input: info.kernel(single_input, *other_args, **kwargs), single_inputs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=batched_input.dtype, device=batched_input.device),
            msg=parametrized_error_message(batched_input, *other_args, **kwargs),
        )

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_no_inplace(self, info, args_kwargs, device):
        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        if input.numel() == 0:
            pytest.skip("The input has a degenerate shape.")

        input_version = input._version
        info.kernel(input, *other_args, **kwargs)

        assert input._version == input_version

    @sample_inputs
    @needs_cuda
    def test_cuda_vs_cpu(self, test_id, info, args_kwargs):
        (input_cpu, *other_args), kwargs = args_kwargs.load("cpu")
        input_cpu = input_cpu.as_subclass(torch.Tensor)
        input_cuda = input_cpu.to("cuda")

        output_cpu = info.kernel(input_cpu, *other_args, **kwargs)
        output_cuda = info.kernel(input_cuda, *other_args, **kwargs)

        assert_close(
            output_cuda,
            output_cpu,
            check_device=False,
            **info.get_closeness_kwargs(test_id, dtype=input_cuda.dtype, device=input_cuda.device),
            msg=parametrized_error_message(input_cpu, *other_args, **kwargs),
        )

    @sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_dtype_and_device_consistency(self, info, args_kwargs, device):
        (input, *other_args), kwargs = args_kwargs.load(device)
        input = input.as_subclass(torch.Tensor)

        output = info.kernel(input, *other_args, **kwargs)
        # Most kernels just return a tensor, but some also return some additional metadata
        if not isinstance(output, torch.Tensor):
            output, *_ = output

        assert output.dtype == input.dtype
        assert output.device == input.device

    @reference_inputs
    def test_against_reference(self, test_id, info, args_kwargs):
        (input, *other_args), kwargs = args_kwargs.load("cpu")

        actual = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)
        # We intentionally don't unwrap the input of the reference function in order for it to have access to all
        # metadata regardless of whether the kernel takes it explicitly or not.
        expected = info.reference_fn(input, *other_args, **kwargs)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )

    @make_info_args_kwargs_parametrization(
        [info for info in KERNEL_INFOS if info.float32_vs_uint8],
        args_kwargs_fn=lambda info: info.reference_inputs_fn(),
    )
    def test_float32_vs_uint8(self, test_id, info, args_kwargs):
        (input, *other_args), kwargs = args_kwargs.load("cpu")
        input = input.as_subclass(torch.Tensor)

        if input.dtype != torch.uint8:
            pytest.skip(f"Input dtype is {input.dtype}.")

        adapted_other_args, adapted_kwargs = info.float32_vs_uint8(other_args, kwargs)
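
        # The property under test: running the kernel on a float32 version of the
        # input should match running it on the uint8 input and converting the
        # result to float32 afterwards.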

        actual = info.kernel(
            F.to_dtype_image(input, dtype=torch.float32, scale=True),
            *adapted_other_args,
            **adapted_kwargs,
        )

        expected = F.to_dtype_image(info.kernel(input, *other_args, **kwargs), dtype=torch.float32, scale=True)

        assert_close(
            actual,
            expected,
            **info.get_closeness_kwargs(test_id, dtype=torch.float32, device=input.device),
            msg=parametrized_error_message(input, *other_args, **kwargs),
        )


@pytest.fixture
def spy_on(mocker):
    def make_spy(fn, *, module=None, name=None):
        # TODO: we can probably get rid of the non-default modules and names if we eliminate aliasing
        module = module or fn.__module__
        name = name or fn.__name__
        spy = mocker.patch(f"{module}.{name}", wraps=fn)
        return spy

    return make_spy
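

# Illustrative use of the fixture above (a sketch; the logging tests show the
# real usage):
#
#   def test_logs_usage(spy_on):
#       spy = spy_on(torch._C._log_api_usage_once)
#       F.horizontal_flip(torch.rand(3, 8, 8))
#       spy.assert_any_call(f"{F.horizontal_flip.__module__}.horizontal_flip")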


class TestDispatchers:
    image_sample_inputs = make_info_args_kwargs_parametrization(
        [info for info in DISPATCHER_INFOS if tv_tensors.Image in info.kernels],
        args_kwargs_fn=lambda info: info.sample_inputs(tv_tensors.Image),
    )

    @make_info_args_kwargs_parametrization(
        DISPATCHER_INFOS,
        args_kwargs_fn=lambda info: info.sample_inputs(),
    )
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_logging(self, spy_on, info, args_kwargs, device):
        spy = spy_on(torch._C._log_api_usage_once)

        args, kwargs = args_kwargs.load(device)
        info.dispatcher(*args, **kwargs)

        spy.assert_any_call(f"{info.dispatcher.__module__}.{info.id}")

    @ignore_jit_warning_no_profile
    @image_sample_inputs
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_scripted_smoke(self, info, args_kwargs, device):
        dispatcher = script(info.dispatcher)

        (image_tv_tensor, *other_args), kwargs = args_kwargs.load(device)
        image_pure_tensor = torch.Tensor(image_tv_tensor)

        dispatcher(image_pure_tensor, *other_args, **kwargs)

    # TODO: We need this until the dispatchers below also have `DispatcherInfo`'s. If they do, `test_scripted_smoke`
    #  replaces this test for them.
    @ignore_jit_warning_no_profile
    @pytest.mark.parametrize(
        "dispatcher",
        [
            F.get_dimensions,
            F.get_image_num_channels,
            F.get_image_size,
            F.get_num_channels,
            F.get_num_frames,
            F.get_size,
            F.rgb_to_grayscale,
            F.uniform_temporal_subsample,
        ],
        ids=lambda dispatcher: dispatcher.__name__,
    )
    def test_scriptable(self, dispatcher):
        script(dispatcher)
    @image_sample_inputs
    def test_pure_tensor_output_type(self, info, args_kwargs):
        (image_tv_tensor, *other_args), kwargs = args_kwargs.load()
        image_pure_tensor = image_tv_tensor.as_subclass(torch.Tensor)

        output = info.dispatcher(image_pure_tensor, *other_args, **kwargs)

        # We cannot use `isinstance` here since all tv_tensors are instances of `torch.Tensor` as well
        assert type(output) is torch.Tensor

    @make_info_args_kwargs_parametrization(
        [info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
        args_kwargs_fn=lambda info: info.sample_inputs(tv_tensors.Image),
    )
    def test_pil_output_type(self, info, args_kwargs):
        (image_tv_tensor, *other_args), kwargs = args_kwargs.load()

        if image_tv_tensor.ndim > 3:
            pytest.skip("Input is batched")

        image_pil = F.to_pil_image(image_tv_tensor)

        output = info.dispatcher(image_pil, *other_args, **kwargs)

        assert isinstance(output, PIL.Image.Image)

    @make_info_args_kwargs_parametrization(
        DISPATCHER_INFOS,
        args_kwargs_fn=lambda info: info.sample_inputs(),
    )
    def test_tv_tensor_output_type(self, info, args_kwargs):
        (tv_tensor, *other_args), kwargs = args_kwargs.load()

        output = info.dispatcher(tv_tensor, *other_args, **kwargs)

        assert isinstance(output, type(tv_tensor))

        if isinstance(tv_tensor, tv_tensors.BoundingBoxes) and info.dispatcher is not F.convert_bounding_box_format:
            assert output.format == tv_tensor.format

    @pytest.mark.parametrize(
        ("dispatcher_info", "tv_tensor_type", "kernel_info"),
        [
            pytest.param(
                dispatcher_info, tv_tensor_type, kernel_info, id=f"{dispatcher_info.id}-{tv_tensor_type.__name__}"
            )
            for dispatcher_info in DISPATCHER_INFOS
            for tv_tensor_type, kernel_info in dispatcher_info.kernel_infos.items()
        ],
    )
    def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, tv_tensor_type, kernel_info):
        dispatcher_signature = inspect.signature(dispatcher_info.dispatcher)
        dispatcher_params = list(dispatcher_signature.parameters.values())[1:]

        kernel_signature = inspect.signature(kernel_info.kernel)
        kernel_params = list(kernel_signature.parameters.values())[1:]

        # We filter out metadata that is implicitly passed to the dispatcher through the input tv_tensor, but has to
        # be explicitly passed to the kernel.
        input_type = {v: k for k, v in dispatcher_info.kernels.items()}.get(kernel_info.kernel)
        explicit_metadata = {
            tv_tensors.BoundingBoxes: {"format", "canvas_size"},
        }
        kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]
        dispatcher_params = iter(dispatcher_params)
        for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params):
            try:
                # In general, the dispatcher parameters are a superset of the kernel parameters. Thus, we filter out
                # dispatcher parameters that have no kernel equivalent while keeping the order intact.
                while dispatcher_param.name != kernel_param.name:
                    dispatcher_param = next(dispatcher_params)
            except StopIteration:
                raise AssertionError(
                    f"Parameter `{kernel_param.name}` of kernel `{kernel_info.id}` "
                    f"has no corresponding parameter on the dispatcher `{dispatcher_info.id}`."
                ) from None

            assert dispatcher_param == kernel_param
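
        # To sketch the matching above with a hypothetical pair
        #
        #   dispatcher: resize(inpt, size, interpolation=..., antialias=...)
        #   kernel:     resize_image(image, size, interpolation=..., antialias=...)
        #
        # the leading input parameter is dropped on both sides, dispatcher-only
        # parameters are skipped, and every remaining pair must compare equal in
        # name, kind, default, and annotation.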

    @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
    def test_unkown_type(self, info):
        unkown_input = object()
        (_, *other_args), kwargs = next(iter(info.sample_inputs())).load("cpu")

        with pytest.raises(TypeError, match=re.escape(str(type(unkown_input)))):
            info.dispatcher(unkown_input, *other_args, **kwargs)

    @make_info_args_kwargs_parametrization(
        [
            info
            for info in DISPATCHER_INFOS
            if tv_tensors.BoundingBoxes in info.kernels and info.dispatcher is not F.convert_bounding_box_format
        ],
        args_kwargs_fn=lambda info: info.sample_inputs(tv_tensors.BoundingBoxes),
    )
    def test_bounding_boxes_format_consistency(self, info, args_kwargs):
        (bounding_boxes, *other_args), kwargs = args_kwargs.load()
        format = bounding_boxes.format

        output = info.dispatcher(bounding_boxes, *other_args, **kwargs)

        assert output.format == format


@pytest.mark.parametrize(
    ("alias", "target"),
    [
        pytest.param(alias, target, id=alias.__name__)
        for alias, target in [
            (F.hflip, F.horizontal_flip),
            (F.vflip, F.vertical_flip),
            (F.get_image_num_channels, F.get_num_channels),
            (F.to_pil_image, F.to_pil_image),
            (F.elastic_transform, F.elastic),
            (F.to_grayscale, F.rgb_to_grayscale),
        ]
    ],
)
def test_alias(alias, target):
    assert alias is target


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("num_channels", [1, 3])
def test_normalize_image_tensor_stats(device, num_channels):
    stats = pytest.importorskip("scipy.stats", reason="SciPy is not available")

    def assert_samples_from_standard_normal(t):
        p_value = stats.kstest(t.flatten(), cdf="norm", args=(0, 1)).pvalue
        assert p_value > 1e-4

    image = torch.rand(num_channels, DEFAULT_SQUARE_SPATIAL_SIZE, DEFAULT_SQUARE_SPATIAL_SIZE)
    mean = image.mean(dim=(1, 2)).tolist()
    std = image.std(dim=(1, 2)).tolist()
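
    # `F.normalize_image` computes `(image - mean) / std` per channel, so
    # standardizing the image with its own statistics should produce values that
    # are approximately standard normal, which the KS test in the helper checks.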

    assert_samples_from_standard_normal(F.normalize_image(image, mean, std))


class TestClampBoundingBoxes:
    @pytest.mark.parametrize(
        "metadata",
        [
            dict(),
            dict(format=tv_tensors.BoundingBoxFormat.XYXY),
            dict(canvas_size=(1, 1)),
        ],
    )
    def test_pure_tensor_insufficient_metadata(self, metadata):
        pure_tensor = next(make_multiple_bounding_boxes()).as_subclass(torch.Tensor)

        with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` has to be passed")):
            F.clamp_bounding_boxes(pure_tensor, **metadata)

    @pytest.mark.parametrize(
        "metadata",
        [
            dict(format=tv_tensors.BoundingBoxFormat.XYXY),
            dict(canvas_size=(1, 1)),
            dict(format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=(1, 1)),
        ],
    )
    def test_tv_tensor_explicit_metadata(self, metadata):
        tv_tensor = next(make_multiple_bounding_boxes())

        with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` must not be passed")):
            F.clamp_bounding_boxes(tv_tensor, **metadata)
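
    # For reference, the two valid calling conventions (a sketch):
    #
    #   F.clamp_bounding_boxes(pure_tensor, format=format, canvas_size=canvas_size)
    #   F.clamp_bounding_boxes(tv_tensor)  # metadata is taken from the tv_tensor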


# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
#  `transforms_v2_kernel_infos.py`


@pytest.mark.parametrize(
    "inpt",
    [
        127 * np.ones((32, 32, 3), dtype="uint8"),
        PIL.Image.new("RGB", (32, 32), 122),
    ],
)
def test_to_image(inpt):
    output = F.to_image(inpt)
    assert isinstance(output, torch.Tensor)
    assert output.shape == (3, 32, 32)
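
    # `to_image` only rearranges the data from HWC (numpy / PIL) layout to CHW and
    # wraps it, so the pixel values, and therefore their sum, must be preserved.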
    assert np.asarray(inpt).sum() == output.sum().item()


@pytest.mark.parametrize(
    "inpt",
    [
        torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8),
        127 * np.ones((32, 32, 3), dtype="uint8"),
    ],
)
@pytest.mark.parametrize("mode", [None, "RGB"])
def test_to_pil_image(inpt, mode):
    output = F.to_pil_image(inpt, mode=mode)
    assert isinstance(output, PIL.Image.Image)

    assert np.asarray(inpt).sum() == np.asarray(output).sum()


def test_equalize_image_tensor_edge_cases():
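    # An all-constant image has a degenerate histogram, so equalization must leave
    # it unchanged.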
    inpt = torch.zeros(3, 200, 200, dtype=torch.uint8)
    output = F.equalize_image(inpt)
    torch.testing.assert_close(inpt, output)

    inpt = torch.zeros(5, 3, 200, 200, dtype=torch.uint8)
    inpt[..., 100:, 100:] = 1
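    # A two-valued image should be stretched to the extremes of the uint8 range.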
    output = F.equalize_image(inpt)
    assert output.unique().tolist() == [0, 255]


@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_uniform_temporal_subsample(device):
    video = torch.arange(10, device=device)[:, None, None, None].expand(-1, 3, 8, 8)
    out_video = F.uniform_temporal_subsample(video, 5)
    assert out_video.unique().tolist() == [0, 2, 4, 6, 9]

    out_video = F.uniform_temporal_subsample(video, 8)
    assert out_video.unique().tolist() == [0, 1, 2, 3, 5, 6, 7, 9]
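
    # The expected indices follow from evenly spaced sampling over the temporal
    # dimension, effectively `torch.linspace(0, num_frames - 1, num_samples).long()`:
    # linspace(0, 9, 5) -> [0, 2.25, 4.5, 6.75, 9] -> [0, 2, 4, 6, 9], and
    # linspace(0, 9, 8) truncates to [0, 1, 2, 3, 5, 6, 7, 9].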