import decimal
import functools
import itertools

import numpy as np
import PIL.Image
import pytest
import torch.testing
import torchvision.ops
import torchvision.transforms.v2.functional as F
from common_utils import (
    ArgsKwargs,
    combinations_grid,
    get_num_channels,
    ImageLoader,
    InfoBase,
    make_bounding_box_loader,
    make_bounding_box_loaders,
    make_detection_mask_loader,
    make_image_loader,
    make_image_loaders,
    make_image_loaders_for_interpolation,
    make_mask_loaders,
    make_video_loader,
    make_video_loaders,
    mark_framework_limitation,
    TestMark,
)
from torch.utils._pytree import tree_map
from torchvision import datapoints
from torchvision.transforms._functional_tensor import _max_value as get_max_value, _parse_pad_padding

__all__ = ["KernelInfo", "KERNEL_INFOS"]


class KernelInfo(InfoBase):
    def __init__(
        self,
        kernel,
        *,
        # Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name.
        # TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then.
        kernel_name=None,
        # Most common tests use these inputs to check the kernel. As such, they should cover all valid code paths,
        # but should not include extensive parameter combinations, to keep the overall test count moderate.
        sample_inputs_fn,
        # This function should mirror the kernel. It should have the same signature as the `kernel` and as such also
        # take tensors as inputs. Any conversion into another object type, e.g. PIL images or numpy arrays, should
        # happen inside the function. It should return a tensor or, more precisely, an object that can be compared
        # to a tensor by `assert_close`. If omitted, no reference test will be performed.
        reference_fn=None,
        # These inputs are only used for the reference tests and thus can be comprehensive with regard to the
        # parameter values to be tested. If not specified, `sample_inputs_fn` will be used.
        reference_inputs_fn=None,
        # If true-ish, triggers a test that checks the kernel for consistency between uint8 and float32 inputs with
        # the reference inputs. This is usually used whenever we use a PIL kernel as reference.
        # Can be a callable, in which case it will be called with `other_args, kwargs`. It should return the same
        # structure, but with adapted parameters. This is useful in case a parameter value is closely tied to the
        # input dtype.
        float32_vs_uint8=False,
        # Some kernels don't have dispatchers that would handle logging the usage. Thus, the kernel has to do it
        # manually. If set, triggers a test that makes sure this happens.
        logs_usage=False,
        # See InfoBase
        test_marks=None,
        # See InfoBase
        closeness_kwargs=None,
    ):
        super().__init__(id=kernel_name or kernel.__name__, test_marks=test_marks, closeness_kwargs=closeness_kwargs)
        self.kernel = kernel
        self.sample_inputs_fn = sample_inputs_fn
        self.reference_fn = reference_fn
        self.reference_inputs_fn = reference_inputs_fn

        if float32_vs_uint8 and not callable(float32_vs_uint8):
            float32_vs_uint8 = lambda other_args, kwargs: (other_args, kwargs)  # noqa: E731
        self.float32_vs_uint8 = float32_vs_uint8
        self.logs_usage = logs_usage
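
# A minimal sketch of how the entries below are assembled (`F.my_kernel` and
# `sample_inputs_my_kernel` are placeholders, not real kernels):
#
#   def sample_inputs_my_kernel():
#       for image_loader in make_image_loaders(sizes=["random"]):
#           yield ArgsKwargs(image_loader)
#
#   KERNEL_INFOS.append(KernelInfo(F.my_kernel, sample_inputs_fn=sample_inputs_my_kernel))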


def pixel_difference_closeness_kwargs(uint8_atol, *, dtype=torch.uint8, mae=False):
    return dict(atol=uint8_atol / 255 * get_max_value(dtype), rtol=0, mae=mae)
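
# The tolerance is given in uint8 pixel units and rescaled to the value range of `dtype`. For example,
#
#   pixel_difference_closeness_kwargs(2)                       # -> dict(atol=2.0, rtol=0, mae=False)
#   pixel_difference_closeness_kwargs(2, dtype=torch.float32)  # -> dict(atol=2 / 255, rtol=0, mae=False)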


def cuda_vs_cpu_pixel_difference(atol=1):
    return {
        (("TestKernels", "test_cuda_vs_cpu"), dtype, "cuda"): pixel_difference_closeness_kwargs(atol, dtype=dtype)
        for dtype in [torch.uint8, torch.float32]
    }
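
# Closeness kwargs are keyed by ((test class, test name), dtype, device). For example, the default
# `cuda_vs_cpu_pixel_difference()` expands to
#
#   {
#       (("TestKernels", "test_cuda_vs_cpu"), torch.uint8, "cuda"): dict(atol=1.0, rtol=0, mae=False),
#       (("TestKernels", "test_cuda_vs_cpu"), torch.float32, "cuda"): dict(atol=1 / 255, rtol=0, mae=False),
#   }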


def pil_reference_pixel_difference(atol=1, mae=False):
    return {
        (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): pixel_difference_closeness_kwargs(
            atol, mae=mae
        )
    }


def float32_vs_uint8_pixel_difference(atol=1, mae=False):
    return {
        (
            ("TestKernels", "test_float32_vs_uint8"),
            torch.float32,
            "cpu",
        ): pixel_difference_closeness_kwargs(atol, dtype=torch.float32, mae=mae)
    }


def scripted_vs_eager_float64_tolerances(device, atol=1e-6, rtol=1e-6):
    return {
        (("TestKernels", "test_scripted_vs_eager"), torch.float64, device): {"atol": atol, "rtol": rtol, "mae": False},
    }


def pil_reference_wrapper(pil_kernel):
    @functools.wraps(pil_kernel)
    def wrapper(input_tensor, *other_args, **kwargs):
        if input_tensor.dtype != torch.uint8:
            raise pytest.UsageError(f"Can only test uint8 tensor images against PIL, but input is {input_tensor.dtype}")
        if input_tensor.ndim > 3:
            raise pytest.UsageError(
                f"Can only test single tensor images against PIL, but input has shape {input_tensor.shape}"
            )

        input_pil = F.to_image_pil(input_tensor)
        output_pil = pil_kernel(input_pil, *other_args, **kwargs)
        if not isinstance(output_pil, PIL.Image.Image):
            return output_pil

        output_tensor = F.to_image_tensor(output_pil)

        # 2D mask shenanigans
        if output_tensor.ndim == 2 and input_tensor.ndim == 3:
            output_tensor = output_tensor.unsqueeze(0)
        elif output_tensor.ndim == 3 and input_tensor.ndim == 2:
            output_tensor = output_tensor.squeeze(0)

        return output_tensor

    return wrapper
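
# Typical use below, e.g. `reference_fn=pil_reference_wrapper(F.crop_image_pil)`: the wrapper converts the
# uint8 tensor input to a PIL image, runs the PIL kernel, and converts the result back so that `assert_close`
# can compare it against the tensor kernel's output.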


def xfail_jit(reason, *, condition=None):
    return TestMark(("TestKernels", "test_scripted_vs_eager"), pytest.mark.xfail(reason=reason), condition=condition)


def xfail_jit_python_scalar_arg(name, *, reason=None):
    return xfail_jit(
        reason or f"Python scalar int or float for `{name}` is not supported when scripting",
        condition=lambda args_kwargs: isinstance(args_kwargs.kwargs.get(name), (int, float)),
    )


KERNEL_INFOS = []


def get_fills(*, num_channels, dtype):
    yield None

    int_value = get_max_value(dtype)
    float_value = int_value / 2
    yield int_value
    yield float_value

    for vector_type in [list, tuple]:
        yield vector_type([int_value])
        yield vector_type([float_value])

        if num_channels > 1:
            yield vector_type(float_value * c / 10 for c in range(num_channels))
            yield vector_type(int_value if c % 2 == 0 else 0 for c in range(num_channels))
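
# For example, `get_fills(num_channels=3, dtype=torch.uint8)` yields None, 255, 127.5, [255], [127.5],
# [0.0, 12.75, 25.5], [255, 0, 255], and then the same four vectors again as tuples.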


def float32_vs_uint8_fill_adapter(other_args, kwargs):
    fill = kwargs.get("fill")
    if fill is None:
        return other_args, kwargs

    if isinstance(fill, (int, float)):
        fill /= 255
    else:
        fill = type(fill)(fill_ / 255 for fill_ in fill)

    return other_args, dict(kwargs, fill=fill)
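
# For example, a uint8 fill of [255, 0, 255] becomes [1.0, 0.0, 1.0] for the float32 run.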


def reference_affine_bounding_box_helper(bounding_box, *, format, spatial_size, affine_matrix):
    def transform(bbox, affine_matrix_, format_, spatial_size_):
        # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
        in_dtype = bbox.dtype
        if not torch.is_floating_point(bbox):
            bbox = bbox.float()
        bbox_xyxy = F.convert_format_bounding_box(
            bbox.as_subclass(torch.Tensor),
            old_format=format_,
            new_format=datapoints.BoundingBoxFormat.XYXY,
            inplace=True,
        )
        points = np.array(
            [
                [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
            ]
        )
        transformed_points = np.matmul(points, affine_matrix_.T)
        out_bbox = torch.tensor(
            [
                np.min(transformed_points[:, 0]).item(),
                np.min(transformed_points[:, 1]).item(),
                np.max(transformed_points[:, 0]).item(),
                np.max(transformed_points[:, 1]).item(),
            ],
            dtype=bbox_xyxy.dtype,
        )
        out_bbox = F.convert_format_bounding_box(
            out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
        )
        # It is important to clamp before casting, especially for CXCYWH format, dtype=int64
        out_bbox = F.clamp_bounding_box(out_bbox, format=format_, spatial_size=spatial_size_)
        out_bbox = out_bbox.to(dtype=in_dtype)
        return out_bbox

    if bounding_box.ndim < 2:
        bounding_box = [bounding_box]

    expected_bboxes = [transform(bbox, affine_matrix, format, spatial_size) for bbox in bounding_box]
    if len(expected_bboxes) > 1:
        expected_bboxes = torch.stack(expected_bboxes)
    else:
        expected_bboxes = expected_bboxes[0]

    return expected_bboxes
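
# Worked example of the convention above: boxes are expanded to their four corner points [x, y, 1] and
# multiplied by `affine_matrix.T`. With the translation matrix
#
#   [[1, 0, -3],
#    [0, 1, -4]]
#
# i.e. a crop with left=3 and top=4, the XYXY box (3, 4, 10, 12) maps to (0, 0, 7, 8); the result is then
# clamped to `spatial_size` and cast back to the input dtype.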


def sample_inputs_convert_format_bounding_box():
    formats = list(datapoints.BoundingBoxFormat)
    for bounding_box_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
        yield ArgsKwargs(bounding_box_loader, old_format=bounding_box_loader.format, new_format=new_format)


def reference_convert_format_bounding_box(bounding_box, old_format, new_format):
    return torchvision.ops.box_convert(
        bounding_box, in_fmt=old_format.name.lower(), out_fmt=new_format.name.lower()
    ).to(bounding_box.dtype)


def reference_inputs_convert_format_bounding_box():
    for args_kwargs in sample_inputs_convert_format_bounding_box():
        if len(args_kwargs.args[0].shape) == 2:
            yield args_kwargs


KERNEL_INFOS.append(
    KernelInfo(
        F.convert_format_bounding_box,
        sample_inputs_fn=sample_inputs_convert_format_bounding_box,
        reference_fn=reference_convert_format_bounding_box,
        reference_inputs_fn=reference_inputs_convert_format_bounding_box,
        logs_usage=True,
    ),
)


_CROP_PARAMS = combinations_grid(top=[-8, 0, 9], left=[-8, 0, 9], height=[12, 20], width=[12, 20])


def sample_inputs_crop_image_tensor():
    for image_loader, params in itertools.product(
        make_image_loaders(sizes=[(16, 17)], color_spaces=["RGB"], dtypes=[torch.float32]),
        [
            dict(top=4, left=3, height=7, width=8),
            dict(top=-1, left=3, height=7, width=8),
            dict(top=4, left=-1, height=7, width=8),
            dict(top=4, left=3, height=17, width=8),
            dict(top=4, left=3, height=7, width=18),
        ],
    ):
        yield ArgsKwargs(image_loader, **params)


def reference_inputs_crop_image_tensor():
    for image_loader, params in itertools.product(
        make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]), _CROP_PARAMS
    ):
        yield ArgsKwargs(image_loader, **params)


def sample_inputs_crop_bounding_box():
    for bounding_box_loader, params in itertools.product(
        make_bounding_box_loaders(), [_CROP_PARAMS[0], _CROP_PARAMS[-1]]
    ):
        yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, **params)


def sample_inputs_crop_mask():
    for mask_loader in make_mask_loaders(sizes=[(16, 17)], num_categories=["random"], num_objects=["random"]):
        yield ArgsKwargs(mask_loader, top=4, left=3, height=7, width=8)


def reference_inputs_crop_mask():
    for mask_loader, params in itertools.product(make_mask_loaders(extra_dims=[()], num_objects=[1]), _CROP_PARAMS):
        yield ArgsKwargs(mask_loader, **params)


def sample_inputs_crop_video():
    for video_loader in make_video_loaders(sizes=[(16, 17)], num_frames=["random"]):
        yield ArgsKwargs(video_loader, top=4, left=3, height=7, width=8)


def reference_crop_bounding_box(bounding_box, *, format, top, left, height, width):
    affine_matrix = np.array(
        [
            [1, 0, -left],
            [0, 1, -top],
        ],
        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
    )

    spatial_size = (height, width)
    expected_bboxes = reference_affine_bounding_box_helper(
        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
    )
    return expected_bboxes, spatial_size


def reference_inputs_crop_bounding_box():
    for bounding_box_loader, params in itertools.product(
        make_bounding_box_loaders(extra_dims=((), (4,))), [_CROP_PARAMS[0], _CROP_PARAMS[-1]]
    ):
        yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, **params)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.crop_image_tensor,
            kernel_name="crop_image_tensor",
            sample_inputs_fn=sample_inputs_crop_image_tensor,
            reference_fn=pil_reference_wrapper(F.crop_image_pil),
            reference_inputs_fn=reference_inputs_crop_image_tensor,
            float32_vs_uint8=True,
        ),
        KernelInfo(
            F.crop_bounding_box,
            sample_inputs_fn=sample_inputs_crop_bounding_box,
            reference_fn=reference_crop_bounding_box,
            reference_inputs_fn=reference_inputs_crop_bounding_box,
        ),
        KernelInfo(
            F.crop_mask,
            sample_inputs_fn=sample_inputs_crop_mask,
            reference_fn=pil_reference_wrapper(F.crop_image_pil),
            reference_inputs_fn=reference_inputs_crop_mask,
            float32_vs_uint8=True,
        ),
        KernelInfo(
            F.crop_video,
            sample_inputs_fn=sample_inputs_crop_video,
        ),
    ]
)

_RESIZED_CROP_PARAMS = combinations_grid(top=[-8, 9], left=[-8, 9], height=[12], width=[12], size=[(16, 18)])


def sample_inputs_resized_crop_image_tensor():
    for image_loader in make_image_loaders():
        yield ArgsKwargs(image_loader, **_RESIZED_CROP_PARAMS[0])


@pil_reference_wrapper
def reference_resized_crop_image_tensor(*args, **kwargs):
    if not kwargs.pop("antialias", False) and kwargs.get("interpolation", F.InterpolationMode.BILINEAR) in {
        F.InterpolationMode.BILINEAR,
        F.InterpolationMode.BICUBIC,
    }:
        raise pytest.UsageError("Anti-aliasing is always active in PIL")
    return F.resized_crop_image_pil(*args, **kwargs)


def reference_inputs_resized_crop_image_tensor():
    for image_loader, interpolation, params in itertools.product(
        make_image_loaders_for_interpolation(),
        [
            F.InterpolationMode.NEAREST,
            F.InterpolationMode.NEAREST_EXACT,
            F.InterpolationMode.BILINEAR,
            F.InterpolationMode.BICUBIC,
        ],
        _RESIZED_CROP_PARAMS,
    ):
        yield ArgsKwargs(
            image_loader,
            interpolation=interpolation,
            antialias=interpolation
            in {
                F.InterpolationMode.BILINEAR,
                F.InterpolationMode.BICUBIC,
            },
            **params,
        )


def sample_inputs_resized_crop_bounding_box():
    for bounding_box_loader in make_bounding_box_loaders():
        yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, **_RESIZED_CROP_PARAMS[0])


def sample_inputs_resized_crop_mask():
    for mask_loader in make_mask_loaders():
        yield ArgsKwargs(mask_loader, **_RESIZED_CROP_PARAMS[0])


def sample_inputs_resized_crop_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, **_RESIZED_CROP_PARAMS[0])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.resized_crop_image_tensor,
            sample_inputs_fn=sample_inputs_resized_crop_image_tensor,
            reference_fn=reference_resized_crop_image_tensor,
            reference_inputs_fn=reference_inputs_resized_crop_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **cuda_vs_cpu_pixel_difference(),
                **pil_reference_pixel_difference(3, mae=True),
                **float32_vs_uint8_pixel_difference(3, mae=True),
            },
        ),
        KernelInfo(
            F.resized_crop_bounding_box,
            sample_inputs_fn=sample_inputs_resized_crop_bounding_box,
        ),
        KernelInfo(
            F.resized_crop_mask,
            sample_inputs_fn=sample_inputs_resized_crop_mask,
        ),
        KernelInfo(
            F.resized_crop_video,
            sample_inputs_fn=sample_inputs_resized_crop_video,
            closeness_kwargs=cuda_vs_cpu_pixel_difference(),
        ),
    ]
)

_PAD_PARAMS = combinations_grid(
    padding=[[1], [1, 1], [1, 1, 2, 2]],
    padding_mode=["constant", "symmetric", "edge", "reflect"],
)


def sample_inputs_pad_image_tensor():
    make_pad_image_loaders = functools.partial(
        make_image_loaders, sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]
    )

    for image_loader, padding in itertools.product(
        make_pad_image_loaders(),
        [1, (1,), (1, 2), (1, 2, 3, 4), [1], [1, 2], [1, 2, 3, 4]],
    ):
        yield ArgsKwargs(image_loader, padding=padding)

    for image_loader in make_pad_image_loaders():
        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
            yield ArgsKwargs(image_loader, padding=[1], fill=fill)

    for image_loader, padding_mode in itertools.product(
        # We branch for non-constant padding and integer inputs
        make_pad_image_loaders(dtypes=[torch.uint8]),
        ["constant", "symmetric", "edge", "reflect"],
    ):
        yield ArgsKwargs(image_loader, padding=[1], padding_mode=padding_mode)

    # `torch.nn.functional.pad` does not support symmetric padding, and thus we have a custom implementation. Besides
    # negative padding, this is already handled by the inputs above.
    for image_loader in make_pad_image_loaders():
        yield ArgsKwargs(image_loader, padding=[-1], padding_mode="symmetric")


def reference_inputs_pad_image_tensor():
    for image_loader, params in itertools.product(
        make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]), _PAD_PARAMS
    ):
        for fill in get_fills(
            num_channels=image_loader.num_channels,
            dtype=image_loader.dtype,
        ):
            # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
            if isinstance(fill, (list, tuple)):
                continue

            yield ArgsKwargs(image_loader, fill=fill, **params)


def sample_inputs_pad_bounding_box():
    for bounding_box_loader, padding in itertools.product(
        make_bounding_box_loaders(), [1, (1,), (1, 2), (1, 2, 3, 4), [1], [1, 2], [1, 2, 3, 4]]
    ):
        yield ArgsKwargs(
            bounding_box_loader,
            format=bounding_box_loader.format,
            spatial_size=bounding_box_loader.spatial_size,
            padding=padding,
            padding_mode="constant",
        )


def sample_inputs_pad_mask():
    for mask_loader in make_mask_loaders(sizes=["random"], num_categories=["random"], num_objects=["random"]):
        yield ArgsKwargs(mask_loader, padding=[1])


def reference_inputs_pad_mask():
    for mask_loader, fill, params in itertools.product(
        make_mask_loaders(num_objects=[1], extra_dims=[()]), [None, 127], _PAD_PARAMS
    ):
        yield ArgsKwargs(mask_loader, fill=fill, **params)


def sample_inputs_pad_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, padding=[1])


def reference_pad_bounding_box(bounding_box, *, format, spatial_size, padding, padding_mode):
    left, right, top, bottom = _parse_pad_padding(padding)

    affine_matrix = np.array(
        [
            [1, 0, left],
            [0, 1, top],
        ],
        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
    )

    height = spatial_size[0] + top + bottom
    width = spatial_size[1] + left + right

    expected_bboxes = reference_affine_bounding_box_helper(
        bounding_box, format=format, spatial_size=(height, width), affine_matrix=affine_matrix
    )
    return expected_bboxes, (height, width)


def reference_inputs_pad_bounding_box():
    for bounding_box_loader, padding in itertools.product(
        make_bounding_box_loaders(extra_dims=((), (4,))), [1, (1,), (1, 2), (1, 2, 3, 4), [1], [1, 2], [1, 2, 3, 4]]
    ):
        yield ArgsKwargs(
            bounding_box_loader,
            format=bounding_box_loader.format,
            spatial_size=bounding_box_loader.spatial_size,
            padding=padding,
            padding_mode="constant",
        )


def pad_xfail_jit_fill_condition(args_kwargs):
    fill = args_kwargs.kwargs.get("fill")
    if not isinstance(fill, (list, tuple)):
        return False
    elif isinstance(fill, tuple):
        return True
    else:  # isinstance(fill, list):
        return all(isinstance(f, int) for f in fill)
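
# That is, tuple fills and all-int list fills are expected to fail under scripting, while scalar fills and
# lists that contain floats are expected to work.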


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.pad_image_tensor,
            sample_inputs_fn=sample_inputs_pad_image_tensor,
            reference_fn=pil_reference_wrapper(F.pad_image_pil),
            reference_inputs_fn=reference_inputs_pad_image_tensor,
            float32_vs_uint8=float32_vs_uint8_fill_adapter,
            closeness_kwargs=float32_vs_uint8_pixel_difference(),
            test_marks=[
                xfail_jit_python_scalar_arg("padding"),
                xfail_jit(
                    "F.pad only supports vector fills for list of floats", condition=pad_xfail_jit_fill_condition
                ),
            ],
        ),
        KernelInfo(
            F.pad_bounding_box,
            sample_inputs_fn=sample_inputs_pad_bounding_box,
            reference_fn=reference_pad_bounding_box,
            reference_inputs_fn=reference_inputs_pad_bounding_box,
            test_marks=[
                xfail_jit_python_scalar_arg("padding"),
            ],
        ),
        KernelInfo(
            F.pad_mask,
            sample_inputs_fn=sample_inputs_pad_mask,
            reference_fn=pil_reference_wrapper(F.pad_image_pil),
            reference_inputs_fn=reference_inputs_pad_mask,
            float32_vs_uint8=float32_vs_uint8_fill_adapter,
        ),
        KernelInfo(
            F.pad_video,
            sample_inputs_fn=sample_inputs_pad_video,
        ),
    ]
)

_PERSPECTIVE_COEFFS = [
    [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
    [0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
]
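# The eight coefficients above follow PIL's PERSPECTIVE convention: for coefficients (a, b, c, d, e, f, g, h),
# each output pixel (x, y) samples the input at ((a*x + b*y + c) / (g*x + h*y + 1), (d*x + e*y + f) / (g*x + h*y + 1)).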
_STARTPOINTS = [[0, 1], [2, 3], [4, 5], [6, 7]]
_ENDPOINTS = [[9, 8], [7, 6], [5, 4], [3, 2]]


def sample_inputs_perspective_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"]):
        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
            yield ArgsKwargs(
                image_loader, startpoints=None, endpoints=None, fill=fill, coefficients=_PERSPECTIVE_COEFFS[0]
            )

    yield ArgsKwargs(make_image_loader(), startpoints=_STARTPOINTS, endpoints=_ENDPOINTS)


def reference_inputs_perspective_image_tensor():
    for image_loader, coefficients, interpolation in itertools.product(
        make_image_loaders_for_interpolation(),
        _PERSPECTIVE_COEFFS,
        [
            F.InterpolationMode.NEAREST,
            F.InterpolationMode.BILINEAR,
        ],
    ):
        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
            # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
            if isinstance(fill, (list, tuple)):
                continue

            yield ArgsKwargs(
                image_loader,
                startpoints=None,
                endpoints=None,
                interpolation=interpolation,
                fill=fill,
                coefficients=coefficients,
            )


def sample_inputs_perspective_bounding_box():
    for bounding_box_loader in make_bounding_box_loaders():
        yield ArgsKwargs(
            bounding_box_loader,
            format=bounding_box_loader.format,
            spatial_size=bounding_box_loader.spatial_size,
            startpoints=None,
            endpoints=None,
            coefficients=_PERSPECTIVE_COEFFS[0],
        )

    format = datapoints.BoundingBoxFormat.XYXY
    loader = make_bounding_box_loader(format=format)
    yield ArgsKwargs(
        loader, format=format, spatial_size=loader.spatial_size, startpoints=_STARTPOINTS, endpoints=_ENDPOINTS
    )


def sample_inputs_perspective_mask():
    for mask_loader in make_mask_loaders(sizes=["random"]):
        yield ArgsKwargs(mask_loader, startpoints=None, endpoints=None, coefficients=_PERSPECTIVE_COEFFS[0])

    yield ArgsKwargs(make_detection_mask_loader(), startpoints=_STARTPOINTS, endpoints=_ENDPOINTS)


def reference_inputs_perspective_mask():
    for mask_loader, perspective_coeffs in itertools.product(
        make_mask_loaders(extra_dims=[()], num_objects=[1]), _PERSPECTIVE_COEFFS
    ):
        yield ArgsKwargs(mask_loader, startpoints=None, endpoints=None, coefficients=perspective_coeffs)


def sample_inputs_perspective_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, startpoints=None, endpoints=None, coefficients=_PERSPECTIVE_COEFFS[0])

    yield ArgsKwargs(make_video_loader(), startpoints=_STARTPOINTS, endpoints=_ENDPOINTS)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.perspective_image_tensor,
            sample_inputs_fn=sample_inputs_perspective_image_tensor,
            reference_fn=pil_reference_wrapper(F.perspective_image_pil),
            reference_inputs_fn=reference_inputs_perspective_image_tensor,
            float32_vs_uint8=float32_vs_uint8_fill_adapter,
            closeness_kwargs={
                **pil_reference_pixel_difference(2, mae=True),
                **cuda_vs_cpu_pixel_difference(),
                **float32_vs_uint8_pixel_difference(),
                **scripted_vs_eager_float64_tolerances("cpu", atol=1e-5, rtol=1e-5),
                **scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
            },
            test_marks=[xfail_jit_python_scalar_arg("fill")],
        ),
        KernelInfo(
            F.perspective_bounding_box,
            sample_inputs_fn=sample_inputs_perspective_bounding_box,
            closeness_kwargs={
                **scripted_vs_eager_float64_tolerances("cpu", atol=1e-6, rtol=1e-6),
                **scripted_vs_eager_float64_tolerances("cuda", atol=1e-6, rtol=1e-6),
            },
        ),
        KernelInfo(
            F.perspective_mask,
            sample_inputs_fn=sample_inputs_perspective_mask,
            reference_fn=pil_reference_wrapper(F.perspective_image_pil),
            reference_inputs_fn=reference_inputs_perspective_mask,
            float32_vs_uint8=True,
            closeness_kwargs={
                (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): dict(atol=10, rtol=0),
            },
        ),
        KernelInfo(
            F.perspective_video,
            sample_inputs_fn=sample_inputs_perspective_video,
            closeness_kwargs={
                **cuda_vs_cpu_pixel_difference(),
                **scripted_vs_eager_float64_tolerances("cpu", atol=1e-5, rtol=1e-5),
                **scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
            },
        ),
    ]
)


def _get_elastic_displacement(spatial_size):
    return torch.rand(1, *spatial_size, 2)
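
# i.e. `displacement` holds one (dx, dy) offset per output pixel, hence the shape (1, H, W, 2) for an (H, W) input.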


def sample_inputs_elastic_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"]):
        displacement = _get_elastic_displacement(image_loader.spatial_size)
        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
            yield ArgsKwargs(image_loader, displacement=displacement, fill=fill)


def reference_inputs_elastic_image_tensor():
    for image_loader, interpolation in itertools.product(
        make_image_loaders_for_interpolation(),
        [
            F.InterpolationMode.NEAREST,
            F.InterpolationMode.BILINEAR,
            F.InterpolationMode.BICUBIC,
        ],
    ):
        displacement = _get_elastic_displacement(image_loader.spatial_size)
        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
            yield ArgsKwargs(image_loader, interpolation=interpolation, displacement=displacement, fill=fill)


def sample_inputs_elastic_bounding_box():
    for bounding_box_loader in make_bounding_box_loaders():
        displacement = _get_elastic_displacement(bounding_box_loader.spatial_size)
        yield ArgsKwargs(
            bounding_box_loader,
            format=bounding_box_loader.format,
            spatial_size=bounding_box_loader.spatial_size,
            displacement=displacement,
        )


def sample_inputs_elastic_mask():
    for mask_loader in make_mask_loaders(sizes=["random"]):
        displacement = _get_elastic_displacement(mask_loader.shape[-2:])
        yield ArgsKwargs(mask_loader, displacement=displacement)


def sample_inputs_elastic_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        displacement = _get_elastic_displacement(video_loader.shape[-2:])
        yield ArgsKwargs(video_loader, displacement=displacement)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.elastic_image_tensor,
            sample_inputs_fn=sample_inputs_elastic_image_tensor,
            reference_inputs_fn=reference_inputs_elastic_image_tensor,
            float32_vs_uint8=float32_vs_uint8_fill_adapter,
            closeness_kwargs={
                **float32_vs_uint8_pixel_difference(6, mae=True),
                **cuda_vs_cpu_pixel_difference(),
            },
            test_marks=[xfail_jit_python_scalar_arg("fill")],
        ),
        KernelInfo(
            F.elastic_bounding_box,
            sample_inputs_fn=sample_inputs_elastic_bounding_box,
        ),
        KernelInfo(
            F.elastic_mask,
            sample_inputs_fn=sample_inputs_elastic_mask,
        ),
        KernelInfo(
            F.elastic_video,
            sample_inputs_fn=sample_inputs_elastic_video,
            closeness_kwargs=cuda_vs_cpu_pixel_difference(),
        ),
    ]
)


_CENTER_CROP_SPATIAL_SIZES = [(16, 16), (7, 33), (31, 9)]
_CENTER_CROP_OUTPUT_SIZES = [[4, 3], [42, 70], [4], 3, (5, 2), (6,)]


def sample_inputs_center_crop_image_tensor():
    for image_loader, output_size in itertools.product(
        make_image_loaders(sizes=[(16, 17)], color_spaces=["RGB"], dtypes=[torch.float32]),
        [
            # valid `output_size` types for which cropping is applied to both dimensions
            *[5, (4,), (2, 3), [6], [3, 2]],
            # `output_size`'s for which at least one dimension needs to be padded
            *[[4, 18], [17, 5], [17, 18]],
        ],
    ):
        yield ArgsKwargs(image_loader, output_size=output_size)


def reference_inputs_center_crop_image_tensor():
    for image_loader, output_size in itertools.product(
        make_image_loaders(sizes=_CENTER_CROP_SPATIAL_SIZES, extra_dims=[()], dtypes=[torch.uint8]),
        _CENTER_CROP_OUTPUT_SIZES,
    ):
        yield ArgsKwargs(image_loader, output_size=output_size)


def sample_inputs_center_crop_bounding_box():
    for bounding_box_loader, output_size in itertools.product(make_bounding_box_loaders(), _CENTER_CROP_OUTPUT_SIZES):
        yield ArgsKwargs(
            bounding_box_loader,
            format=bounding_box_loader.format,
            spatial_size=bounding_box_loader.spatial_size,
            output_size=output_size,
        )


def sample_inputs_center_crop_mask():
    for mask_loader in make_mask_loaders(sizes=["random"], num_categories=["random"], num_objects=["random"]):
        height, width = mask_loader.shape[-2:]
        yield ArgsKwargs(mask_loader, output_size=(height // 2, width // 2))


def reference_inputs_center_crop_mask():
    for mask_loader, output_size in itertools.product(
        make_mask_loaders(sizes=_CENTER_CROP_SPATIAL_SIZES, extra_dims=[()], num_objects=[1]), _CENTER_CROP_OUTPUT_SIZES
    ):
        yield ArgsKwargs(mask_loader, output_size=output_size)


def sample_inputs_center_crop_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        height, width = video_loader.shape[-2:]
        yield ArgsKwargs(video_loader, output_size=(height // 2, width // 2))


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.center_crop_image_tensor,
            sample_inputs_fn=sample_inputs_center_crop_image_tensor,
            reference_fn=pil_reference_wrapper(F.center_crop_image_pil),
            reference_inputs_fn=reference_inputs_center_crop_image_tensor,
            float32_vs_uint8=True,
            test_marks=[
                xfail_jit_python_scalar_arg("output_size"),
            ],
        ),
        KernelInfo(
            F.center_crop_bounding_box,
            sample_inputs_fn=sample_inputs_center_crop_bounding_box,
            test_marks=[
                xfail_jit_python_scalar_arg("output_size"),
            ],
        ),
        KernelInfo(
            F.center_crop_mask,
            sample_inputs_fn=sample_inputs_center_crop_mask,
            reference_fn=pil_reference_wrapper(F.center_crop_image_pil),
            reference_inputs_fn=reference_inputs_center_crop_mask,
            float32_vs_uint8=True,
            test_marks=[
                xfail_jit_python_scalar_arg("output_size"),
            ],
        ),
        KernelInfo(
            F.center_crop_video,
            sample_inputs_fn=sample_inputs_center_crop_video,
        ),
    ]
)


def sample_inputs_gaussian_blur_image_tensor():
    make_gaussian_blur_image_loaders = functools.partial(make_image_loaders, sizes=[(7, 33)], color_spaces=["RGB"])

    for image_loader, kernel_size in itertools.product(make_gaussian_blur_image_loaders(), [5, (3, 3), [3, 3]]):
        yield ArgsKwargs(image_loader, kernel_size=kernel_size)

    for image_loader, sigma in itertools.product(
        make_gaussian_blur_image_loaders(), [None, (3.0, 3.0), [2.0, 2.0], 4.0, [1.5], (3.14,)]
    ):
        yield ArgsKwargs(image_loader, kernel_size=5, sigma=sigma)


def sample_inputs_gaussian_blur_video():
    for video_loader in make_video_loaders(sizes=[(7, 33)], num_frames=[5]):
        yield ArgsKwargs(video_loader, kernel_size=[3, 3])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.gaussian_blur_image_tensor,
            sample_inputs_fn=sample_inputs_gaussian_blur_image_tensor,
            closeness_kwargs=cuda_vs_cpu_pixel_difference(),
            test_marks=[
                xfail_jit_python_scalar_arg("kernel_size"),
                xfail_jit_python_scalar_arg("sigma"),
            ],
        ),
        KernelInfo(
            F.gaussian_blur_video,
            sample_inputs_fn=sample_inputs_gaussian_blur_video,
            closeness_kwargs=cuda_vs_cpu_pixel_difference(),
        ),
    ]
)


def sample_inputs_equalize_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader)


def reference_inputs_equalize_image_tensor():
    # We are not using `make_image_loaders` here since that uniformly samples the values over the whole value range.
    # Since the whole point of this kernel is to transform an arbitrary distribution of values into a uniform one,
    # the information gain is low if we already provide something really close to the expected value.
    def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor, memory_format):
        if dtype.is_floating_point:
            low = low_factor
            high = high_factor
        else:
            max_value = torch.iinfo(dtype).max
            low = int(low_factor * max_value)
            high = int(high_factor * max_value)
        return torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high).to(
            memory_format=memory_format, copy=True
        )

    def make_beta_distributed_image(shape, dtype, device, *, alpha, beta, memory_format):
        image = torch.distributions.Beta(alpha, beta).sample(shape)
        if not dtype.is_floating_point:
            image.mul_(torch.iinfo(dtype).max).round_()
        return image.to(dtype=dtype, device=device, memory_format=memory_format, copy=True)

    spatial_size = (256, 256)
    for dtype, color_space, fn in itertools.product(
        [torch.uint8],
        ["GRAY", "RGB"],
        [
            lambda shape, dtype, device, memory_format: torch.zeros(shape, dtype=dtype, device=device).to(
                memory_format=memory_format, copy=True
            ),
            lambda shape, dtype, device, memory_format: torch.full(
                shape, 1.0 if dtype.is_floating_point else torch.iinfo(dtype).max, dtype=dtype, device=device
            ).to(memory_format=memory_format, copy=True),
            *[
                functools.partial(make_uniform_band_image, low_factor=low_factor, high_factor=high_factor)
                for low_factor, high_factor in [
                    (0.0, 0.25),
                    (0.25, 0.75),
                    (0.75, 1.0),
                ]
            ],
            *[
                functools.partial(make_beta_distributed_image, alpha=alpha, beta=beta)
                for alpha, beta in [
                    (0.5, 0.5),
                    (2, 2),
                    (2, 5),
                    (5, 2),
                ]
            ],
        ],
    ):
        image_loader = ImageLoader(fn, shape=(get_num_channels(color_space), *spatial_size), dtype=dtype)
        yield ArgsKwargs(image_loader)


def sample_inputs_equalize_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.equalize_image_tensor,
            kernel_name="equalize_image_tensor",
            sample_inputs_fn=sample_inputs_equalize_image_tensor,
            reference_fn=pil_reference_wrapper(F.equalize_image_pil),
            float32_vs_uint8=True,
            reference_inputs_fn=reference_inputs_equalize_image_tensor,
        ),
        KernelInfo(
            F.equalize_video,
            sample_inputs_fn=sample_inputs_equalize_video,
        ),
    ]
)


def sample_inputs_invert_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader)


def reference_inputs_invert_image_tensor():
    for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
        yield ArgsKwargs(image_loader)


def sample_inputs_invert_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.invert_image_tensor,
            kernel_name="invert_image_tensor",
            sample_inputs_fn=sample_inputs_invert_image_tensor,
            reference_fn=pil_reference_wrapper(F.invert_image_pil),
            reference_inputs_fn=reference_inputs_invert_image_tensor,
            float32_vs_uint8=True,
        ),
        KernelInfo(
            F.invert_video,
            sample_inputs_fn=sample_inputs_invert_video,
        ),
    ]
)


_POSTERIZE_BITS = [1, 4, 8]


def sample_inputs_posterize_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, bits=_POSTERIZE_BITS[0])


def reference_inputs_posterize_image_tensor():
    for image_loader, bits in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _POSTERIZE_BITS,
    ):
        yield ArgsKwargs(image_loader, bits=bits)


def sample_inputs_posterize_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, bits=_POSTERIZE_BITS[0])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.posterize_image_tensor,
            kernel_name="posterize_image_tensor",
            sample_inputs_fn=sample_inputs_posterize_image_tensor,
            reference_fn=pil_reference_wrapper(F.posterize_image_pil),
            reference_inputs_fn=reference_inputs_posterize_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs=float32_vs_uint8_pixel_difference(),
        ),
        KernelInfo(
            F.posterize_video,
            sample_inputs_fn=sample_inputs_posterize_video,
        ),
    ]
)


def _get_solarize_thresholds(dtype):
    for factor in [0.1, 0.5]:
        max_value = get_max_value(dtype)
        yield (float if dtype.is_floating_point else int)(max_value * factor)
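
# For example, torch.uint8 yields the int thresholds 25 and 127, while torch.float32 yields 0.1 and 0.5.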


def sample_inputs_solarize_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, threshold=next(_get_solarize_thresholds(image_loader.dtype)))


def reference_inputs_solarize_image_tensor():
    for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
        for threshold in _get_solarize_thresholds(image_loader.dtype):
            yield ArgsKwargs(image_loader, threshold=threshold)


def uint8_to_float32_threshold_adapter(other_args, kwargs):
    return other_args, dict(threshold=kwargs["threshold"] / 255)
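
# For example, a uint8 threshold of 127 becomes 127 / 255 ≈ 0.498 for the float32 run. `threshold` is the
# only kwarg the solarize kernels take, so dropping the rest of `kwargs` here is fine.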


def sample_inputs_solarize_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, threshold=next(_get_solarize_thresholds(video_loader.dtype)))


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.solarize_image_tensor,
            kernel_name="solarize_image_tensor",
            sample_inputs_fn=sample_inputs_solarize_image_tensor,
            reference_fn=pil_reference_wrapper(F.solarize_image_pil),
            reference_inputs_fn=reference_inputs_solarize_image_tensor,
            float32_vs_uint8=uint8_to_float32_threshold_adapter,
            closeness_kwargs=float32_vs_uint8_pixel_difference(),
        ),
        KernelInfo(
            F.solarize_video,
            sample_inputs_fn=sample_inputs_solarize_video,
        ),
    ]
)


def sample_inputs_autocontrast_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader)


def reference_inputs_autocontrast_image_tensor():
    for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
        yield ArgsKwargs(image_loader)


def sample_inputs_autocontrast_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.autocontrast_image_tensor,
            kernel_name="autocontrast_image_tensor",
            sample_inputs_fn=sample_inputs_autocontrast_image_tensor,
            reference_fn=pil_reference_wrapper(F.autocontrast_image_pil),
            reference_inputs_fn=reference_inputs_autocontrast_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **pil_reference_pixel_difference(),
                **float32_vs_uint8_pixel_difference(),
            },
        ),
        KernelInfo(
            F.autocontrast_video,
            sample_inputs_fn=sample_inputs_autocontrast_video,
        ),
    ]
)

_ADJUST_SHARPNESS_FACTORS = [0.1, 0.5]


def sample_inputs_adjust_sharpness_image_tensor():
    for image_loader in make_image_loaders(
        sizes=["random", (2, 2)],
        color_spaces=("GRAY", "RGB"),
    ):
        yield ArgsKwargs(image_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0])


def reference_inputs_adjust_sharpness_image_tensor():
    for image_loader, sharpness_factor in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _ADJUST_SHARPNESS_FACTORS,
    ):
        yield ArgsKwargs(image_loader, sharpness_factor=sharpness_factor)


def sample_inputs_adjust_sharpness_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.adjust_sharpness_image_tensor,
            kernel_name="adjust_sharpness_image_tensor",
            sample_inputs_fn=sample_inputs_adjust_sharpness_image_tensor,
            reference_fn=pil_reference_wrapper(F.adjust_sharpness_image_pil),
            reference_inputs_fn=reference_inputs_adjust_sharpness_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs=float32_vs_uint8_pixel_difference(2),
        ),
        KernelInfo(
            F.adjust_sharpness_video,
            sample_inputs_fn=sample_inputs_adjust_sharpness_video,
        ),
    ]
)


def sample_inputs_erase_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"]):
        # FIXME: make the parameters more diverse
        h, w = 6, 7
        v = torch.rand(image_loader.num_channels, h, w)
        yield ArgsKwargs(image_loader, i=1, j=2, h=h, w=w, v=v)


def sample_inputs_erase_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        # FIXME: make the parameters more diverse
        h, w = 6, 7
        v = torch.rand(video_loader.num_channels, h, w)
        yield ArgsKwargs(video_loader, i=1, j=2, h=h, w=w, v=v)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.erase_image_tensor,
            kernel_name="erase_image_tensor",
            sample_inputs_fn=sample_inputs_erase_image_tensor,
        ),
        KernelInfo(
            F.erase_video,
            sample_inputs_fn=sample_inputs_erase_video,
        ),
    ]
)

_ADJUST_BRIGHTNESS_FACTORS = [0.1, 0.5]


def sample_inputs_adjust_brightness_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0])


def reference_inputs_adjust_brightness_image_tensor():
    for image_loader, brightness_factor in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _ADJUST_BRIGHTNESS_FACTORS,
    ):
        yield ArgsKwargs(image_loader, brightness_factor=brightness_factor)


def sample_inputs_adjust_brightness_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.adjust_brightness_image_tensor,
            kernel_name="adjust_brightness_image_tensor",
            sample_inputs_fn=sample_inputs_adjust_brightness_image_tensor,
            reference_fn=pil_reference_wrapper(F.adjust_brightness_image_pil),
            reference_inputs_fn=reference_inputs_adjust_brightness_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs=float32_vs_uint8_pixel_difference(),
        ),
        KernelInfo(
            F.adjust_brightness_video,
            sample_inputs_fn=sample_inputs_adjust_brightness_video,
        ),
    ]
)


_ADJUST_CONTRAST_FACTORS = [0.1, 0.5]


def sample_inputs_adjust_contrast_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0])


def reference_inputs_adjust_contrast_image_tensor():
    for image_loader, contrast_factor in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _ADJUST_CONTRAST_FACTORS,
    ):
        yield ArgsKwargs(image_loader, contrast_factor=contrast_factor)


def sample_inputs_adjust_contrast_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.adjust_contrast_image_tensor,
            kernel_name="adjust_contrast_image_tensor",
            sample_inputs_fn=sample_inputs_adjust_contrast_image_tensor,
            reference_fn=pil_reference_wrapper(F.adjust_contrast_image_pil),
            reference_inputs_fn=reference_inputs_adjust_contrast_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **pil_reference_pixel_difference(),
                **float32_vs_uint8_pixel_difference(2),
                **cuda_vs_cpu_pixel_difference(),
                (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): pixel_difference_closeness_kwargs(1),
            },
        ),
        KernelInfo(
            F.adjust_contrast_video,
            sample_inputs_fn=sample_inputs_adjust_contrast_video,
            closeness_kwargs={
                **cuda_vs_cpu_pixel_difference(),
                (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): pixel_difference_closeness_kwargs(1),
            },
        ),
    ]
)

_ADJUST_GAMMA_GAMMAS_GAINS = [
    (0.5, 2.0),
    (0.0, 1.0),
]


def sample_inputs_adjust_gamma_image_tensor():
    gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0]
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, gamma=gamma, gain=gain)


def reference_inputs_adjust_gamma_image_tensor():
    for image_loader, (gamma, gain) in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _ADJUST_GAMMA_GAMMAS_GAINS,
    ):
        yield ArgsKwargs(image_loader, gamma=gamma, gain=gain)


def sample_inputs_adjust_gamma_video():
    gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0]
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, gamma=gamma, gain=gain)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.adjust_gamma_image_tensor,
            kernel_name="adjust_gamma_image_tensor",
            sample_inputs_fn=sample_inputs_adjust_gamma_image_tensor,
            reference_fn=pil_reference_wrapper(F.adjust_gamma_image_pil),
            reference_inputs_fn=reference_inputs_adjust_gamma_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **pil_reference_pixel_difference(),
                **float32_vs_uint8_pixel_difference(),
            },
        ),
        KernelInfo(
            F.adjust_gamma_video,
            sample_inputs_fn=sample_inputs_adjust_gamma_video,
        ),
    ]
)


_ADJUST_HUE_FACTORS = [-0.1, 0.5]


def sample_inputs_adjust_hue_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, hue_factor=_ADJUST_HUE_FACTORS[0])


def reference_inputs_adjust_hue_image_tensor():
    for image_loader, hue_factor in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _ADJUST_HUE_FACTORS,
    ):
        yield ArgsKwargs(image_loader, hue_factor=hue_factor)


def sample_inputs_adjust_hue_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, hue_factor=_ADJUST_HUE_FACTORS[0])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.adjust_hue_image_tensor,
            kernel_name="adjust_hue_image_tensor",
            sample_inputs_fn=sample_inputs_adjust_hue_image_tensor,
            reference_fn=pil_reference_wrapper(F.adjust_hue_image_pil),
            reference_inputs_fn=reference_inputs_adjust_hue_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **pil_reference_pixel_difference(2, mae=True),
                **float32_vs_uint8_pixel_difference(),
            },
        ),
        KernelInfo(
            F.adjust_hue_video,
            sample_inputs_fn=sample_inputs_adjust_hue_video,
        ),
    ]
)

_ADJUST_SATURATION_FACTORS = [0.1, 0.5]


def sample_inputs_adjust_saturation_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0])


def reference_inputs_adjust_saturation_image_tensor():
    for image_loader, saturation_factor in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _ADJUST_SATURATION_FACTORS,
    ):
        yield ArgsKwargs(image_loader, saturation_factor=saturation_factor)


def sample_inputs_adjust_saturation_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.adjust_saturation_image_tensor,
            kernel_name="adjust_saturation_image_tensor",
            sample_inputs_fn=sample_inputs_adjust_saturation_image_tensor,
            reference_fn=pil_reference_wrapper(F.adjust_saturation_image_pil),
            reference_inputs_fn=reference_inputs_adjust_saturation_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **pil_reference_pixel_difference(),
                **float32_vs_uint8_pixel_difference(2),
                **cuda_vs_cpu_pixel_difference(),
            },
        ),
        KernelInfo(
            F.adjust_saturation_video,
            sample_inputs_fn=sample_inputs_adjust_saturation_video,
            closeness_kwargs=cuda_vs_cpu_pixel_difference(),
        ),
    ]
)


def sample_inputs_clamp_bounding_box():
    for bounding_box_loader in make_bounding_box_loaders():
        yield ArgsKwargs(
            bounding_box_loader,
            format=bounding_box_loader.format,
            spatial_size=bounding_box_loader.spatial_size,
        )


KERNEL_INFOS.append(
    KernelInfo(
        F.clamp_bounding_box,
        sample_inputs_fn=sample_inputs_clamp_bounding_box,
        logs_usage=True,
    )
)

_FIVE_TEN_CROP_SIZES = [7, (6,), [5], (6, 5), [7, 6]]
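# Covers every accepted `size` spec: a bare int plus one- and two-element sequences,
# both as tuples and as lists.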


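# Maps a five/ten-crop `size` argument to a spatial size twice as large, so that all
# crops fit inside the generated input. For example,
# _get_five_ten_crop_spatial_size(7) == (14, 14) and
# _get_five_ten_crop_spatial_size((6, 5)) == (12, 10).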
def _get_five_ten_crop_spatial_size(size):
    if isinstance(size, int):
        crop_height = crop_width = size
    elif len(size) == 1:
        crop_height = crop_width = size[0]
    else:
        crop_height, crop_width = size
    return 2 * crop_height, 2 * crop_width


def sample_inputs_five_crop_image_tensor():
    for size in _FIVE_TEN_CROP_SIZES:
        for image_loader in make_image_loaders(
            sizes=[_get_five_ten_crop_spatial_size(size)],
            color_spaces=["RGB"],
            dtypes=[torch.float32],
        ):
            yield ArgsKwargs(image_loader, size=size)


def reference_inputs_five_crop_image_tensor():
    for size in _FIVE_TEN_CROP_SIZES:
        for image_loader in make_image_loaders(
            sizes=[_get_five_ten_crop_spatial_size(size)], extra_dims=[()], dtypes=[torch.uint8]
        ):
            yield ArgsKwargs(image_loader, size=size)


def sample_inputs_five_crop_video():
    size = _FIVE_TEN_CROP_SIZES[0]
    for video_loader in make_video_loaders(sizes=[_get_five_ten_crop_spatial_size(size)]):
        yield ArgsKwargs(video_loader, size=size)


def sample_inputs_ten_crop_image_tensor():
    for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
        for image_loader in make_image_loaders(
            sizes=[_get_five_ten_crop_spatial_size(size)],
            color_spaces=["RGB"],
            dtypes=[torch.float32],
        ):
            yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip)


def reference_inputs_ten_crop_image_tensor():
    for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
        for image_loader in make_image_loaders(
            sizes=[_get_five_ten_crop_spatial_size(size)], extra_dims=[()], dtypes=[torch.uint8]
        ):
            yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip)


def sample_inputs_ten_crop_video():
    size = _FIVE_TEN_CROP_SIZES[0]
    for video_loader in make_video_loaders(sizes=[_get_five_ten_crop_spatial_size(size)]):
        yield ArgsKwargs(video_loader, size=size)


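# five_crop / ten_crop kernels return a tuple of crops rather than a single image.
# Thus, the PIL reference output has to be converted back element-wise: each PIL
# crop becomes a tensor with the dtype of the original input.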
def multi_crop_pil_reference_wrapper(pil_kernel):
    def wrapper(input_tensor, *other_args, **kwargs):
        output = pil_reference_wrapper(pil_kernel)(input_tensor, *other_args, **kwargs)
        return type(output)(
            F.convert_dtype_image_tensor(F.to_image_tensor(output_pil), dtype=input_tensor.dtype)
            for output_pil in output
        )

    return wrapper


_common_five_ten_crop_marks = [
    xfail_jit_python_scalar_arg("size"),
    mark_framework_limitation(("TestKernels", "test_batched_vs_single"), "Custom batching needed."),
]
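# `size` may be passed as a plain Python scalar, which torchscript cannot handle,
# and the tuple-of-crops output does not fit the generic batching logic of
# test_batched_vs_single.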

KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.five_crop_image_tensor,
            sample_inputs_fn=sample_inputs_five_crop_image_tensor,
            reference_fn=multi_crop_pil_reference_wrapper(F.five_crop_image_pil),
            reference_inputs_fn=reference_inputs_five_crop_image_tensor,
            test_marks=_common_five_ten_crop_marks,
        ),
        KernelInfo(
            F.five_crop_video,
            sample_inputs_fn=sample_inputs_five_crop_video,
            test_marks=_common_five_ten_crop_marks,
        ),
        KernelInfo(
            F.ten_crop_image_tensor,
            sample_inputs_fn=sample_inputs_ten_crop_image_tensor,
            reference_fn=multi_crop_pil_reference_wrapper(F.ten_crop_image_pil),
            reference_inputs_fn=reference_inputs_ten_crop_image_tensor,
            test_marks=_common_five_ten_crop_marks,
        ),
        KernelInfo(
            F.ten_crop_video,
            sample_inputs_fn=sample_inputs_ten_crop_video,
            test_marks=_common_five_ten_crop_marks,
        ),
    ]
)

_NORMALIZE_MEANS_STDS = [
    ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
    (0.5, 2.0),
]
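# The last entry passes scalar mean / std to check that a single value is broadcast
# over all channels.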


def sample_inputs_normalize_image_tensor():
    for image_loader, (mean, std) in itertools.product(
        make_image_loaders(sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]),
        _NORMALIZE_MEANS_STDS,
    ):
        yield ArgsKwargs(image_loader, mean=mean, std=std)


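# Reference for normalize: channel-wise (image[c] - mean[c]) / std[c], with mean and
# std reshaped to (C, 1, 1) so that they broadcast over the spatial dimensions.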
def reference_normalize_image_tensor(image, mean, std, inplace=False):
    mean = torch.tensor(mean).view(-1, 1, 1)
    std = torch.tensor(std).view(-1, 1, 1)

    sub = torch.Tensor.sub_ if inplace else torch.Tensor.sub
    return sub(image, mean).div_(std)


def reference_inputs_normalize_image_tensor():
    yield ArgsKwargs(
        make_image_loader(size=(32, 32), color_space="RGB", extra_dims=[1]),
        mean=[0.5, 0.5, 0.5],
        std=[1.0, 1.0, 1.0],
    )


def sample_inputs_normalize_video():
    mean, std = _NORMALIZE_MEANS_STDS[0]
    for video_loader in make_video_loaders(
        sizes=["random"], color_spaces=["RGB"], num_frames=["random"], dtypes=[torch.float32]
    ):
        yield ArgsKwargs(video_loader, mean=mean, std=std)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.normalize_image_tensor,
            kernel_name="normalize_image_tensor",
            sample_inputs_fn=sample_inputs_normalize_image_tensor,
            reference_fn=reference_normalize_image_tensor,
            reference_inputs_fn=reference_inputs_normalize_image_tensor,
            test_marks=[
                xfail_jit_python_scalar_arg("mean"),
                xfail_jit_python_scalar_arg("std"),
            ],
        ),
        KernelInfo(
            F.normalize_video,
            sample_inputs_fn=sample_inputs_normalize_video,
        ),
    ]
)


def sample_inputs_convert_dtype_image_tensor():
    for input_dtype, output_dtype in itertools.product(
        [torch.uint8, torch.int64, torch.float32, torch.float64], repeat=2
    ):
        if input_dtype.is_floating_point and output_dtype == torch.int64:
            # conversion cannot be performed safely
            continue

        for image_loader in make_image_loaders(sizes=["random"], color_spaces=["RGB"], dtypes=[input_dtype]):
            yield ArgsKwargs(image_loader, dtype=output_dtype)


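# Element-wise reference for dtype conversion. decimal.Decimal avoids float rounding
# issues: float -> int multiplies by the output integer maximum, int -> float
# divides by the input integer maximum, and int -> int rescales by a power-of-two
# factor derived from the ratio of the (max + 1) values. For example, uint8 -> int16
# uses factor = (32767 + 1) // (255 + 1) = 128, so 255 maps to 32640.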
def reference_convert_dtype_image_tensor(image, dtype=torch.float):
    input_dtype = image.dtype
    output_dtype = dtype

    if output_dtype == input_dtype:
        return image

    def fn(value):
        if input_dtype.is_floating_point:
            if output_dtype.is_floating_point:
                return value
            else:
                return int(decimal.Decimal(value) * torch.iinfo(output_dtype).max)
        else:
            input_max_value = torch.iinfo(input_dtype).max

            if output_dtype.is_floating_point:
                return float(decimal.Decimal(value) / input_max_value)
            else:
                output_max_value = torch.iinfo(output_dtype).max

                if input_max_value > output_max_value:
                    factor = (input_max_value + 1) // (output_max_value + 1)
                    return value // factor
                else:
                    factor = (output_max_value + 1) // (input_max_value + 1)
                    return value * factor

    return torch.tensor(tree_map(fn, image.tolist()), dtype=dtype)


def reference_inputs_convert_dtype_image_tensor():
    for input_dtype, output_dtype in itertools.product(
        [
            torch.uint8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.float16,
            torch.float32,
            torch.float64,
            torch.bfloat16,
        ],
        repeat=2,
    ):
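        # As in sample_inputs_convert_dtype_image_tensor above, these casts cannot
        # be performed safely and are therefore excluded.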
        if (input_dtype == torch.float32 and output_dtype in {torch.int32, torch.int64}) or (
            input_dtype == torch.float64 and output_dtype == torch.int64
        ):
            continue

        if input_dtype.is_floating_point:
            data = [0.0, 0.5, 1.0]
        else:
            max_value = torch.iinfo(input_dtype).max
            data = [0, max_value // 2, max_value]
        image = torch.tensor(data, dtype=input_dtype)

        yield ArgsKwargs(image, dtype=output_dtype)


def sample_inputs_convert_dtype_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader)


skip_dtype_consistency = TestMark(
    ("TestKernels", "test_dtype_and_device_consistency"),
    pytest.mark.skip(reason="`convert_dtype_*` kernels convert the dtype by design"),
    condition=lambda args_kwargs: args_kwargs.args[0].dtype != args_kwargs.kwargs.get("dtype", torch.float32),
)
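# The condition above limits the skip to cases where the kernel is actually asked to
# change the dtype; identity conversions still have to preserve dtype and device.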

KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.convert_dtype_image_tensor,
            sample_inputs_fn=sample_inputs_convert_dtype_image_tensor,
            reference_fn=reference_convert_dtype_image_tensor,
            reference_inputs_fn=reference_inputs_convert_dtype_image_tensor,
            test_marks=[
                skip_dtype_consistency,
                TestMark(
                    ("TestKernels", "test_against_reference"),
                    pytest.mark.xfail(reason="Conversion overflows"),
                    condition=lambda args_kwargs: (
                        args_kwargs.args[0].dtype in {torch.float16, torch.bfloat16}
                        and not args_kwargs.kwargs["dtype"].is_floating_point
                    )
                    or (
                        args_kwargs.args[0].dtype in {torch.int32, torch.int64}
                        and args_kwargs.kwargs["dtype"] == torch.float16
                    ),
                ),
            ],
        ),
        KernelInfo(
            F.convert_dtype_video,
            sample_inputs_fn=sample_inputs_convert_dtype_video,
            test_marks=[
                skip_dtype_consistency,
            ],
        ),
    ]
)


def sample_inputs_uniform_temporal_subsample_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=[4]):
        yield ArgsKwargs(video_loader, num_samples=2)


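# Reference for uniform temporal subsampling: with t frames and num_samples samples,
# the selected frame indices are torch.linspace(0, t - 1, num_samples) truncated to
# integers, e.g. frames [0, 3] for t = 4 and num_samples = 2.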
def reference_uniform_temporal_subsample_video(x, num_samples):
    # Copy-pasted from
    # https://github.com/facebookresearch/pytorchvideo/blob/c8d23d8b7e597586a9e2d18f6ed31ad8aa379a7a/pytorchvideo/transforms/functional.py#L19
    t = x.shape[-4]
    assert num_samples > 0 and t > 0
    # Sample by nearest neighbor interpolation if num_samples > t.
    indices = torch.linspace(0, t - 1, num_samples)
    indices = torch.clamp(indices, 0, t - 1).long()
    return torch.index_select(x, -4, indices)


def reference_inputs_uniform_temporal_subsample_video():
    for video_loader in make_video_loaders(sizes=["random"], color_spaces=["RGB"], num_frames=[10]):
        for num_samples in range(1, video_loader.shape[-4] + 1):
            yield ArgsKwargs(video_loader, num_samples)


KERNEL_INFOS.append(
    KernelInfo(
        F.uniform_temporal_subsample_video,
        sample_inputs_fn=sample_inputs_uniform_temporal_subsample_video,
        reference_fn=reference_uniform_temporal_subsample_video,
        reference_inputs_fn=reference_inputs_uniform_temporal_subsample_video,
    )
)