"vscode:/vscode.git/clone" did not exist on "f652494df16ef9fa0fac998ddf63961aee0849d4"
transforms_v2_kernel_infos.py 68.4 KB
Newer Older
1
import decimal
import functools
import itertools

import numpy as np
import PIL.Image
import pytest
import torch.testing
import torchvision.ops
import torchvision.transforms.v2.functional as F
from common_utils import (
    ArgsKwargs,
    combinations_grid,
    get_num_channels,
    ImageLoader,
    InfoBase,
    make_bounding_box_loader,
    make_bounding_box_loaders,
    make_detection_mask_loader,
    make_image_loader,
    make_image_loaders,
    make_image_loaders_for_interpolation,
    make_mask_loaders,
    make_video_loader,
    make_video_loaders,
    mark_framework_limitation,
    TestMark,
)
from torch.utils._pytree import tree_map
from torchvision import datapoints
from torchvision.transforms._functional_tensor import _max_value as get_max_value, _parse_pad_padding

__all__ = ["KernelInfo", "KERNEL_INFOS"]


class KernelInfo(InfoBase):
    def __init__(
        self,
        kernel,
        *,
        # Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
        # TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
        kernel_name=None,
        # Most common tests use these inputs to check the kernel. As such, they should cover all valid code paths, but
        # should not include extensive parameter combinations, to keep the overall test count moderate.
        sample_inputs_fn,
        # This function should mirror the kernel. It should have the same signature as `kernel` and as such also
        # take tensors as inputs. Any conversion into another object type, e.g. PIL images or numpy arrays, should
        # happen inside the function. It should return a tensor, or more precisely an object that can be compared to
        # a tensor by `assert_close`. If omitted, no reference test will be performed.
        reference_fn=None,
        # These inputs are only used for the reference tests and thus can be comprehensive with regard to the parameter
        # values to be tested. If not specified, `sample_inputs_fn` will be used.
        reference_inputs_fn=None,
        # If truthy, triggers a test that checks the kernel for consistency between uint8 and float32 inputs, using
        # the reference inputs. This is usually used whenever we use a PIL kernel as reference.
        # Can be a callable, in which case it will be called with `other_args, kwargs`. It should return the same
        # structure, but with adapted parameters. This is useful in case a parameter value is closely tied to the input
        # dtype.
        float32_vs_uint8=False,
        # Some kernels don't have dispatchers that would handle logging the usage. Thus, the kernel has to do it
        # manually. If set, triggers a test that makes sure this happens.
        logs_usage=False,
        # See InfoBase
        test_marks=None,
        # See InfoBase
        closeness_kwargs=None,
    ):
        super().__init__(id=kernel_name or kernel.__name__, test_marks=test_marks, closeness_kwargs=closeness_kwargs)
        self.kernel = kernel
        self.sample_inputs_fn = sample_inputs_fn
        self.reference_fn = reference_fn
        self.reference_inputs_fn = reference_inputs_fn

        if float32_vs_uint8 and not callable(float32_vs_uint8):
            float32_vs_uint8 = lambda other_args, kwargs: (other_args, kwargs)  # noqa: E731
        self.float32_vs_uint8 = float32_vs_uint8
        self.logs_usage = logs_usage


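# Illustrative only, not registered in KERNEL_INFOS: a minimal `KernelInfo` for a
# hypothetical kernel. `my_kernel`, `my_kernel_pil`, and the `*_my_kernel` helpers
# are placeholders; `pil_reference_wrapper` and `pil_reference_pixel_difference`
# are defined below.
#
#     KernelInfo(
#         my_kernel,
#         sample_inputs_fn=sample_inputs_my_kernel,
#         reference_fn=pil_reference_wrapper(my_kernel_pil),
#         reference_inputs_fn=reference_inputs_my_kernel,
#         closeness_kwargs=pil_reference_pixel_difference(1),
#     )
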
def pixel_difference_closeness_kwargs(uint8_atol, *, dtype=torch.uint8, mae=False):
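    # Scales a tolerance given in uint8 steps to the value range of `dtype`; e.g.
    # uint8_atol=1 with dtype=torch.float32 gives atol=1 / 255, i.e. one uint8 step
    # mapped into the [0, 1] float range.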
    return dict(atol=uint8_atol / 255 * get_max_value(dtype), rtol=0, mae=mae)


def cuda_vs_cpu_pixel_difference(atol=1):
    return {
        (("TestKernels", "test_cuda_vs_cpu"), dtype, "cuda"): pixel_difference_closeness_kwargs(atol, dtype=dtype)
        for dtype in [torch.uint8, torch.float32]
    }


def pil_reference_pixel_difference(atol=1, mae=False):
    return {
        (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): pixel_difference_closeness_kwargs(
            atol, mae=mae
        )
    }


def float32_vs_uint8_pixel_difference(atol=1, mae=False):
    return {
        (
            ("TestKernels", "test_float32_vs_uint8"),
            torch.float32,
            "cpu",
        ): pixel_difference_closeness_kwargs(atol, dtype=torch.float32, mae=mae)
    }


def scripted_vs_eager_float64_tolerances(device, atol=1e-6, rtol=1e-6):
    return {
        (("TestKernels", "test_scripted_vs_eager"), torch.float64, device): {"atol": atol, "rtol": rtol, "mae": False},
    }


def pil_reference_wrapper(pil_kernel):
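    # Adapts a PIL kernel for use as `reference_fn`: the (single, uint8) tensor
    # input is converted to a PIL image, the PIL kernel is applied, and a PIL
    # result is converted back to a tensor for comparison with `assert_close`.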
    @functools.wraps(pil_kernel)
    def wrapper(input_tensor, *other_args, **kwargs):
        if input_tensor.dtype != torch.uint8:
            raise pytest.UsageError(f"Can only test uint8 tensor images against PIL, but input is {input_tensor.dtype}")
        if input_tensor.ndim > 3:
            raise pytest.UsageError(
                f"Can only test single tensor images against PIL, but input has shape {input_tensor.shape}"
            )

        input_pil = F.to_image_pil(input_tensor)
        output_pil = pil_kernel(input_pil, *other_args, **kwargs)
        if not isinstance(output_pil, PIL.Image.Image):
            return output_pil

        output_tensor = F.to_image_tensor(output_pil)

        # 2D mask shenanigans
        if output_tensor.ndim == 2 and input_tensor.ndim == 3:
            output_tensor = output_tensor.unsqueeze(0)
        elif output_tensor.ndim == 3 and input_tensor.ndim == 2:
            output_tensor = output_tensor.squeeze(0)

        return output_tensor

    return wrapper


def xfail_jit(reason, *, condition=None):
    return TestMark(("TestKernels", "test_scripted_vs_eager"), pytest.mark.xfail(reason=reason), condition=condition)


def xfail_jit_python_scalar_arg(name, *, reason=None):
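    # E.g. xfail_jit_python_scalar_arg("fill") xfails the scripted-vs-eager test
    # for exactly those sample inputs whose `fill` is a bare Python int or float.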
    return xfail_jit(
        reason or f"Python scalar int or float for `{name}` is not supported when scripting",
        condition=lambda args_kwargs: isinstance(args_kwargs.kwargs.get(name), (int, float)),
    )


KERNEL_INFOS = []


def get_fills(*, num_channels, dtype):
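    # For example, num_channels=3 with dtype=torch.uint8 yields None, the scalars
    # 255 and 127.5, and list/tuple fills such as [255], (127.5,), and the
    # per-channel variants [0.0, 12.75, 25.5] and [255, 0, 255].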
    yield None

    int_value = get_max_value(dtype)
    float_value = int_value / 2
    yield int_value
    yield float_value

    for vector_type in [list, tuple]:
        yield vector_type([int_value])
        yield vector_type([float_value])

        if num_channels > 1:
            yield vector_type(float_value * c / 10 for c in range(num_channels))
            yield vector_type(int_value if c % 2 == 0 else 0 for c in range(num_channels))


def float32_vs_uint8_fill_adapter(other_args, kwargs):
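    # Rescales uint8 fills into the [0, 1] float32 range, e.g. fill=255 -> 1.0 and
    # fill=[128, 0] -> [128 / 255, 0.0], mirroring the image conversion itself.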
    fill = kwargs.get("fill")
    if fill is None:
        return other_args, kwargs

    if isinstance(fill, (int, float)):
        fill /= 255
    else:
        fill = type(fill)(fill_ / 255 for fill_ in fill)

    return other_args, dict(kwargs, fill=fill)


def reference_affine_bounding_box_helper(bounding_box, *, format, spatial_size, affine_matrix):
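    # `affine_matrix` is a 2x3 matrix applied to homogeneous points [x, y, 1], so
    # e.g. a pure translation by (tx, ty) is [[1, 0, tx], [0, 1, ty]]. Each box is
    # converted to XYXY, its four corners are transformed, and the axis-aligned
    # bounding box of the result is kept.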
    def transform(bbox, affine_matrix_, format_, spatial_size_):
        # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
        in_dtype = bbox.dtype
        if not torch.is_floating_point(bbox):
            bbox = bbox.float()
        bbox_xyxy = F.convert_format_bounding_box(
            bbox.as_subclass(torch.Tensor),
            old_format=format_,
            new_format=datapoints.BoundingBoxFormat.XYXY,
            inplace=True,
        )
        points = np.array(
            [
                [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
            ]
        )
        transformed_points = np.matmul(points, affine_matrix_.T)
        out_bbox = torch.tensor(
            [
                np.min(transformed_points[:, 0]).item(),
                np.min(transformed_points[:, 1]).item(),
                np.max(transformed_points[:, 0]).item(),
                np.max(transformed_points[:, 1]).item(),
            ],
            dtype=bbox_xyxy.dtype,
        )
        out_bbox = F.convert_format_bounding_box(
            out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
        )
        # It is important to clamp before casting, especially for CXCYWH format, dtype=int64
        out_bbox = F.clamp_bounding_box(out_bbox, format=format_, spatial_size=spatial_size_)
        out_bbox = out_bbox.to(dtype=in_dtype)
        return out_bbox

    if bounding_box.ndim < 2:
        bounding_box = [bounding_box]

    expected_bboxes = [transform(bbox, affine_matrix, format, spatial_size) for bbox in bounding_box]
    if len(expected_bboxes) > 1:
        expected_bboxes = torch.stack(expected_bboxes)
    else:
        expected_bboxes = expected_bboxes[0]

    return expected_bboxes


def sample_inputs_convert_format_bounding_box():
    formats = list(datapoints.BoundingBoxFormat)
    for bounding_box_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
        yield ArgsKwargs(bounding_box_loader, old_format=bounding_box_loader.format, new_format=new_format)


def reference_convert_format_bounding_box(bounding_box, old_format, new_format):
    return torchvision.ops.box_convert(
        bounding_box, in_fmt=old_format.name.lower(), out_fmt=new_format.name.lower()
    ).to(bounding_box.dtype)


def reference_inputs_convert_format_bounding_box():
    for args_kwargs in sample_inputs_convert_format_bounding_box():
        if len(args_kwargs.args[0].shape) == 2:
            yield args_kwargs


KERNEL_INFOS.append(
    KernelInfo(
        F.convert_format_bounding_box,
        sample_inputs_fn=sample_inputs_convert_format_bounding_box,
        reference_fn=reference_convert_format_bounding_box,
        reference_inputs_fn=reference_inputs_convert_format_bounding_box,
        logs_usage=True,
    ),
)


_ROTATE_ANGLES = [-87, 15, 90]


def sample_inputs_rotate_image_tensor():
    make_rotate_image_loaders = functools.partial(
        make_image_loaders, sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]
    )

    for image_loader in make_rotate_image_loaders():
        yield ArgsKwargs(image_loader, angle=15.0, expand=True)

    for image_loader, center in itertools.product(
        make_rotate_image_loaders(), [None, [1.0, 0.5], [1, 2], (1.0, 0.5), (1, 2)]
    ):
        yield ArgsKwargs(image_loader, angle=15.0, center=center)

    for image_loader in make_rotate_image_loaders():
        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
            yield ArgsKwargs(image_loader, angle=15.0, fill=fill)

    for image_loader, interpolation in itertools.product(
        make_rotate_image_loaders(),
        [F.InterpolationMode.NEAREST, F.InterpolationMode.BILINEAR],
    ):
        yield ArgsKwargs(image_loader, angle=15.0, interpolation=interpolation, fill=0)


def reference_inputs_rotate_image_tensor():
    for image_loader, angle in itertools.product(make_image_loaders_for_interpolation(), _ROTATE_ANGLES):
        yield ArgsKwargs(image_loader, angle=angle)


def sample_inputs_rotate_bounding_box():
    for bounding_box_loader in make_bounding_box_loaders():
        yield ArgsKwargs(
            bounding_box_loader,
            format=bounding_box_loader.format,
            spatial_size=bounding_box_loader.spatial_size,
            angle=_ROTATE_ANGLES[0],
        )


def reference_inputs_rotate_bounding_box():
    for bounding_box_loader, angle in itertools.product(
        make_bounding_box_loaders(extra_dims=((), (4,))), _ROTATE_ANGLES
    ):
        yield ArgsKwargs(
            bounding_box_loader,
            format=bounding_box_loader.format,
            spatial_size=bounding_box_loader.spatial_size,
            angle=angle,
        )

    # TODO: add samples with expand=True and center


def reference_rotate_bounding_box(bounding_box, *, format, spatial_size, angle, expand=False, center=None):

    if center is None:
        center = [spatial_size[1] * 0.5, spatial_size[0] * 0.5]

    a = np.cos(angle * np.pi / 180.0)
    b = np.sin(angle * np.pi / 180.0)
    cx = center[0]
    cy = center[1]
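    # 2x3 matrix for a rotation by `angle` around (cx, cy), acting on homogeneous
    # points [x, y, 1]: translate the center to the origin, rotate, translate back.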
    affine_matrix = np.array(
        [
            [a, b, cx - cx * a - b * cy],
            [-b, a, cy + cx * b - a * cy],
        ],
        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
    )

    expected_bboxes = reference_affine_bounding_box_helper(
        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
    )
    return expected_bboxes, spatial_size


def sample_inputs_rotate_mask():
    for mask_loader in make_mask_loaders(sizes=["random"], num_categories=["random"], num_objects=["random"]):
        yield ArgsKwargs(mask_loader, angle=15.0)


def sample_inputs_rotate_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, angle=15.0)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.rotate_image_tensor,
            sample_inputs_fn=sample_inputs_rotate_image_tensor,
            reference_fn=pil_reference_wrapper(F.rotate_image_pil),
            reference_inputs_fn=reference_inputs_rotate_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs=pil_reference_pixel_difference(1, mae=True),
            test_marks=[
                xfail_jit_python_scalar_arg("fill"),
            ],
        ),
        KernelInfo(
            F.rotate_bounding_box,
            sample_inputs_fn=sample_inputs_rotate_bounding_box,
            reference_fn=reference_rotate_bounding_box,
            reference_inputs_fn=reference_inputs_rotate_bounding_box,
            closeness_kwargs={
                **scripted_vs_eager_float64_tolerances("cpu", atol=1e-4, rtol=1e-4),
                **scripted_vs_eager_float64_tolerances("cuda", atol=1e-4, rtol=1e-4),
            },
        ),
        KernelInfo(
            F.rotate_mask,
            sample_inputs_fn=sample_inputs_rotate_mask,
        ),
        KernelInfo(
            F.rotate_video,
            sample_inputs_fn=sample_inputs_rotate_video,
        ),
    ]
)

_CROP_PARAMS = combinations_grid(top=[-8, 0, 9], left=[-8, 0, 9], height=[12, 20], width=[12, 20])


def sample_inputs_crop_image_tensor():
    for image_loader, params in itertools.product(
        make_image_loaders(sizes=[(16, 17)], color_spaces=["RGB"], dtypes=[torch.float32]),
        [
            dict(top=4, left=3, height=7, width=8),
            dict(top=-1, left=3, height=7, width=8),
            dict(top=4, left=-1, height=7, width=8),
            dict(top=4, left=3, height=17, width=8),
            dict(top=4, left=3, height=7, width=18),
        ],
    ):
        yield ArgsKwargs(image_loader, **params)


def reference_inputs_crop_image_tensor():
    for image_loader, params in itertools.product(
        make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]), _CROP_PARAMS
    ):
        yield ArgsKwargs(image_loader, **params)


def sample_inputs_crop_bounding_box():
    for bounding_box_loader, params in itertools.product(
        make_bounding_box_loaders(), [_CROP_PARAMS[0], _CROP_PARAMS[-1]]
    ):
        yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, **params)


def sample_inputs_crop_mask():
    for mask_loader in make_mask_loaders(sizes=[(16, 17)], num_categories=["random"], num_objects=["random"]):
        yield ArgsKwargs(mask_loader, top=4, left=3, height=7, width=8)


def reference_inputs_crop_mask():
    for mask_loader, params in itertools.product(make_mask_loaders(extra_dims=[()], num_objects=[1]), _CROP_PARAMS):
        yield ArgsKwargs(mask_loader, **params)


def sample_inputs_crop_video():
    for video_loader in make_video_loaders(sizes=[(16, 17)], num_frames=["random"]):
        yield ArgsKwargs(video_loader, top=4, left=3, height=7, width=8)


def reference_crop_bounding_box(bounding_box, *, format, top, left, height, width):
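    # Cropping drops `left` columns and `top` rows, so boxes are translated by
    # (-left, -top) and (height, width) of the crop becomes the new spatial size.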
    affine_matrix = np.array(
        [
            [1, 0, -left],
            [0, 1, -top],
        ],
        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
    )

    spatial_size = (height, width)
    expected_bboxes = reference_affine_bounding_box_helper(
        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
    )
    return expected_bboxes, spatial_size


def reference_inputs_crop_bounding_box():
    for bounding_box_loader, params in itertools.product(
        make_bounding_box_loaders(extra_dims=((), (4,))), [_CROP_PARAMS[0], _CROP_PARAMS[-1]]
    ):
        yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, **params)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.crop_image_tensor,
            kernel_name="crop_image_tensor",
            sample_inputs_fn=sample_inputs_crop_image_tensor,
            reference_fn=pil_reference_wrapper(F.crop_image_pil),
            reference_inputs_fn=reference_inputs_crop_image_tensor,
            float32_vs_uint8=True,
        ),
        KernelInfo(
            F.crop_bounding_box,
            sample_inputs_fn=sample_inputs_crop_bounding_box,
            reference_fn=reference_crop_bounding_box,
            reference_inputs_fn=reference_inputs_crop_bounding_box,
        ),
        KernelInfo(
            F.crop_mask,
            sample_inputs_fn=sample_inputs_crop_mask,
            reference_fn=pil_reference_wrapper(F.crop_image_pil),
            reference_inputs_fn=reference_inputs_crop_mask,
            float32_vs_uint8=True,
        ),
        KernelInfo(
            F.crop_video,
            sample_inputs_fn=sample_inputs_crop_video,
        ),
    ]
)

_RESIZED_CROP_PARAMS = combinations_grid(top=[-8, 9], left=[-8, 9], height=[12], width=[12], size=[(16, 18)])


def sample_inputs_resized_crop_image_tensor():
    for image_loader in make_image_loaders():
        yield ArgsKwargs(image_loader, **_RESIZED_CROP_PARAMS[0])


@pil_reference_wrapper
def reference_resized_crop_image_tensor(*args, **kwargs):
    if not kwargs.pop("antialias", False) and kwargs.get("interpolation", F.InterpolationMode.BILINEAR) in {
        F.InterpolationMode.BILINEAR,
        F.InterpolationMode.BICUBIC,
    }:
        raise pytest.UsageError("Anti-aliasing is always active in PIL")
    return F.resized_crop_image_pil(*args, **kwargs)


def reference_inputs_resized_crop_image_tensor():
    for image_loader, interpolation, params in itertools.product(
        make_image_loaders_for_interpolation(),
        [
            F.InterpolationMode.NEAREST,
            F.InterpolationMode.NEAREST_EXACT,
            F.InterpolationMode.BILINEAR,
            F.InterpolationMode.BICUBIC,
        ],
        _RESIZED_CROP_PARAMS,
    ):
        yield ArgsKwargs(
            image_loader,
            interpolation=interpolation,
            antialias=interpolation
            in {
                F.InterpolationMode.BILINEAR,
                F.InterpolationMode.BICUBIC,
            },
            **params,
        )


def sample_inputs_resized_crop_bounding_box():
    for bounding_box_loader in make_bounding_box_loaders():
        yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, **_RESIZED_CROP_PARAMS[0])


def sample_inputs_resized_crop_mask():
    for mask_loader in make_mask_loaders():
        yield ArgsKwargs(mask_loader, **_RESIZED_CROP_PARAMS[0])


def sample_inputs_resized_crop_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, **_RESIZED_CROP_PARAMS[0])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.resized_crop_image_tensor,
            sample_inputs_fn=sample_inputs_resized_crop_image_tensor,
            reference_fn=reference_resized_crop_image_tensor,
            reference_inputs_fn=reference_inputs_resized_crop_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **cuda_vs_cpu_pixel_difference(),
                **pil_reference_pixel_difference(3, mae=True),
                **float32_vs_uint8_pixel_difference(3, mae=True),
            },
        ),
        KernelInfo(
            F.resized_crop_bounding_box,
            sample_inputs_fn=sample_inputs_resized_crop_bounding_box,
        ),
        KernelInfo(
            F.resized_crop_mask,
            sample_inputs_fn=sample_inputs_resized_crop_mask,
        ),
        KernelInfo(
            F.resized_crop_video,
            sample_inputs_fn=sample_inputs_resized_crop_video,
            closeness_kwargs=cuda_vs_cpu_pixel_difference(),
        ),
    ]
)

_PAD_PARAMS = combinations_grid(
    padding=[[1], [1, 1], [1, 1, 2, 2]],
    padding_mode=["constant", "symmetric", "edge", "reflect"],
)


def sample_inputs_pad_image_tensor():
    make_pad_image_loaders = functools.partial(
        make_image_loaders, sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]
    )

    for image_loader, padding in itertools.product(
        make_pad_image_loaders(),
        [1, (1,), (1, 2), (1, 2, 3, 4), [1], [1, 2], [1, 2, 3, 4]],
    ):
        yield ArgsKwargs(image_loader, padding=padding)

    for image_loader in make_pad_image_loaders():
        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
            yield ArgsKwargs(image_loader, padding=[1], fill=fill)

    for image_loader, padding_mode in itertools.product(
        # We branch for non-constant padding and integer inputs
        make_pad_image_loaders(dtypes=[torch.uint8]),
        ["constant", "symmetric", "edge", "reflect"],
    ):
        yield ArgsKwargs(image_loader, padding=[1], padding_mode=padding_mode)

    # `torch.nn.functional.pad` does not support symmetric padding, and thus we have a custom implementation. Apart
    # from negative padding, it is already covered by the inputs above.
    for image_loader in make_pad_image_loaders():
        yield ArgsKwargs(image_loader, padding=[-1], padding_mode="symmetric")


def reference_inputs_pad_image_tensor():
    for image_loader, params in itertools.product(
        make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]), _PAD_PARAMS
    ):
        for fill in get_fills(
            num_channels=image_loader.num_channels,
            dtype=image_loader.dtype,
        ):
            # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
            if isinstance(fill, (list, tuple)):
                continue

            yield ArgsKwargs(image_loader, fill=fill, **params)


def sample_inputs_pad_bounding_box():
    for bounding_box_loader, padding in itertools.product(
        make_bounding_box_loaders(), [1, (1,), (1, 2), (1, 2, 3, 4), [1], [1, 2], [1, 2, 3, 4]]
    ):
        yield ArgsKwargs(
            bounding_box_loader,
            format=bounding_box_loader.format,
            spatial_size=bounding_box_loader.spatial_size,
            padding=padding,
            padding_mode="constant",
        )


def sample_inputs_pad_mask():
    for mask_loader in make_mask_loaders(sizes=["random"], num_categories=["random"], num_objects=["random"]):
        yield ArgsKwargs(mask_loader, padding=[1])


def reference_inputs_pad_mask():
    for mask_loader, fill, params in itertools.product(
        make_mask_loaders(num_objects=[1], extra_dims=[()]), [None, 127], _PAD_PARAMS
    ):
        yield ArgsKwargs(mask_loader, fill=fill, **params)


def sample_inputs_pad_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, padding=[1])


def reference_pad_bounding_box(bounding_box, *, format, spatial_size, padding, padding_mode):

    left, right, top, bottom = _parse_pad_padding(padding)

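    # Padding shifts the image content by (left, top), so boxes are translated by
    # that offset while the spatial size grows by the total padding on each axis.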
    affine_matrix = np.array(
        [
            [1, 0, left],
            [0, 1, top],
        ],
        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
    )

    height = spatial_size[0] + top + bottom
    width = spatial_size[1] + left + right

    expected_bboxes = reference_affine_bounding_box_helper(
        bounding_box, format=format, spatial_size=(height, width), affine_matrix=affine_matrix
    )
    return expected_bboxes, (height, width)


def reference_inputs_pad_bounding_box():
    for bounding_box_loader, padding in itertools.product(
        make_bounding_box_loaders(extra_dims=((), (4,))), [1, (1,), (1, 2), (1, 2, 3, 4), [1], [1, 2], [1, 2, 3, 4]]
    ):
        yield ArgsKwargs(
            bounding_box_loader,
            format=bounding_box_loader.format,
            spatial_size=bounding_box_loader.spatial_size,
            padding=padding,
            padding_mode="constant",
        )


def pad_xfail_jit_fill_condition(args_kwargs):
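    # Scripted F.pad only supports scalar fills and lists of floats: flag tuple
    # fills and all-int list fills as expected JIT failures.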
    fill = args_kwargs.kwargs.get("fill")
    if not isinstance(fill, (list, tuple)):
        return False
    elif isinstance(fill, tuple):
        return True
    else:  # isinstance(fill, list):
        return all(isinstance(f, int) for f in fill)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.pad_image_tensor,
            sample_inputs_fn=sample_inputs_pad_image_tensor,
            reference_fn=pil_reference_wrapper(F.pad_image_pil),
            reference_inputs_fn=reference_inputs_pad_image_tensor,
            float32_vs_uint8=float32_vs_uint8_fill_adapter,
            closeness_kwargs=float32_vs_uint8_pixel_difference(),
            test_marks=[
                xfail_jit_python_scalar_arg("padding"),
                xfail_jit(
                    "F.pad only supports vector fills for list of floats", condition=pad_xfail_jit_fill_condition
                ),
            ],
        ),
        KernelInfo(
            F.pad_bounding_box,
            sample_inputs_fn=sample_inputs_pad_bounding_box,
            reference_fn=reference_pad_bounding_box,
            reference_inputs_fn=reference_inputs_pad_bounding_box,
            test_marks=[
                xfail_jit_python_scalar_arg("padding"),
            ],
        ),
        KernelInfo(
            F.pad_mask,
            sample_inputs_fn=sample_inputs_pad_mask,
            reference_fn=pil_reference_wrapper(F.pad_image_pil),
            reference_inputs_fn=reference_inputs_pad_mask,
            float32_vs_uint8=float32_vs_uint8_fill_adapter,
        ),
        KernelInfo(
            F.pad_video,
            sample_inputs_fn=sample_inputs_pad_video,
        ),
    ]
)

_PERSPECTIVE_COEFFS = [
    [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
    [0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
]
_STARTPOINTS = [[0, 1], [2, 3], [4, 5], [6, 7]]
_ENDPOINTS = [[9, 8], [7, 6], [5, 4], [3, 2]]


def sample_inputs_perspective_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"]):
        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
            yield ArgsKwargs(
                image_loader, startpoints=None, endpoints=None, fill=fill, coefficients=_PERSPECTIVE_COEFFS[0]
            )

    yield ArgsKwargs(make_image_loader(), startpoints=_STARTPOINTS, endpoints=_ENDPOINTS)


def reference_inputs_perspective_image_tensor():
    for image_loader, coefficients, interpolation in itertools.product(
        make_image_loaders_for_interpolation(),
        _PERSPECTIVE_COEFFS,
        [
            F.InterpolationMode.NEAREST,
            F.InterpolationMode.BILINEAR,
        ],
    ):
        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
            # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
            if isinstance(fill, (list, tuple)):
                continue

            yield ArgsKwargs(
                image_loader,
                startpoints=None,
                endpoints=None,
                interpolation=interpolation,
                fill=fill,
                coefficients=coefficients,
            )


def sample_inputs_perspective_bounding_box():
    for bounding_box_loader in make_bounding_box_loaders():
        yield ArgsKwargs(
            bounding_box_loader,
            format=bounding_box_loader.format,
            spatial_size=bounding_box_loader.spatial_size,
            startpoints=None,
            endpoints=None,
            coefficients=_PERSPECTIVE_COEFFS[0],
        )

    format = datapoints.BoundingBoxFormat.XYXY
    loader = make_bounding_box_loader(format=format)
    yield ArgsKwargs(
        loader, format=format, spatial_size=loader.spatial_size, startpoints=_STARTPOINTS, endpoints=_ENDPOINTS
    )


def sample_inputs_perspective_mask():
    for mask_loader in make_mask_loaders(sizes=["random"]):
        yield ArgsKwargs(mask_loader, startpoints=None, endpoints=None, coefficients=_PERSPECTIVE_COEFFS[0])

    yield ArgsKwargs(make_detection_mask_loader(), startpoints=_STARTPOINTS, endpoints=_ENDPOINTS)


def reference_inputs_perspective_mask():
    for mask_loader, perspective_coeffs in itertools.product(
        make_mask_loaders(extra_dims=[()], num_objects=[1]), _PERSPECTIVE_COEFFS
    ):
        yield ArgsKwargs(mask_loader, startpoints=None, endpoints=None, coefficients=perspective_coeffs)


def sample_inputs_perspective_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, startpoints=None, endpoints=None, coefficients=_PERSPECTIVE_COEFFS[0])

    yield ArgsKwargs(make_video_loader(), startpoints=_STARTPOINTS, endpoints=_ENDPOINTS)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.perspective_image_tensor,
            sample_inputs_fn=sample_inputs_perspective_image_tensor,
            reference_fn=pil_reference_wrapper(F.perspective_image_pil),
            reference_inputs_fn=reference_inputs_perspective_image_tensor,
            float32_vs_uint8=float32_vs_uint8_fill_adapter,
            closeness_kwargs={
                **pil_reference_pixel_difference(2, mae=True),
                **cuda_vs_cpu_pixel_difference(),
                **float32_vs_uint8_pixel_difference(),
                **scripted_vs_eager_float64_tolerances("cpu", atol=1e-5, rtol=1e-5),
                **scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
            },
            test_marks=[xfail_jit_python_scalar_arg("fill")],
        ),
        KernelInfo(
            F.perspective_bounding_box,
            sample_inputs_fn=sample_inputs_perspective_bounding_box,
            closeness_kwargs={
                **scripted_vs_eager_float64_tolerances("cpu", atol=1e-6, rtol=1e-6),
                **scripted_vs_eager_float64_tolerances("cuda", atol=1e-6, rtol=1e-6),
            },
        ),
        KernelInfo(
            F.perspective_mask,
            sample_inputs_fn=sample_inputs_perspective_mask,
            reference_fn=pil_reference_wrapper(F.perspective_image_pil),
            reference_inputs_fn=reference_inputs_perspective_mask,
            float32_vs_uint8=True,
            closeness_kwargs={
                (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): dict(atol=10, rtol=0),
            },
        ),
        KernelInfo(
            F.perspective_video,
            sample_inputs_fn=sample_inputs_perspective_video,
            closeness_kwargs={
                **cuda_vs_cpu_pixel_difference(),
                **scripted_vs_eager_float64_tolerances("cpu", atol=1e-5, rtol=1e-5),
                **scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
            },
        ),
    ]
)


def _get_elastic_displacement(spatial_size):
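    # Random displacement field of shape (1, H, W, 2), i.e. one (dx, dy) offset
    # per output location, as expected by the elastic kernels.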
    return torch.rand(1, *spatial_size, 2)


def sample_inputs_elastic_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"]):
        displacement = _get_elastic_displacement(image_loader.spatial_size)
        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
            yield ArgsKwargs(image_loader, displacement=displacement, fill=fill)


def reference_inputs_elastic_image_tensor():
    for image_loader, interpolation in itertools.product(
        make_image_loaders_for_interpolation(),
        [
            F.InterpolationMode.NEAREST,
            F.InterpolationMode.BILINEAR,
            F.InterpolationMode.BICUBIC,
        ],
    ):
        displacement = _get_elastic_displacement(image_loader.spatial_size)
        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
            yield ArgsKwargs(image_loader, interpolation=interpolation, displacement=displacement, fill=fill)


def sample_inputs_elastic_bounding_box():
    for bounding_box_loader in make_bounding_box_loaders():
        displacement = _get_elastic_displacement(bounding_box_loader.spatial_size)
        yield ArgsKwargs(
            bounding_box_loader,
            format=bounding_box_loader.format,
            spatial_size=bounding_box_loader.spatial_size,
            displacement=displacement,
        )


def sample_inputs_elastic_mask():
    for mask_loader in make_mask_loaders(sizes=["random"]):
        displacement = _get_elastic_displacement(mask_loader.shape[-2:])
        yield ArgsKwargs(mask_loader, displacement=displacement)


def sample_inputs_elastic_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        displacement = _get_elastic_displacement(video_loader.shape[-2:])
        yield ArgsKwargs(video_loader, displacement=displacement)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.elastic_image_tensor,
            sample_inputs_fn=sample_inputs_elastic_image_tensor,
            reference_inputs_fn=reference_inputs_elastic_image_tensor,
            float32_vs_uint8=float32_vs_uint8_fill_adapter,
            closeness_kwargs={
                **float32_vs_uint8_pixel_difference(6, mae=True),
                **cuda_vs_cpu_pixel_difference(),
            },
            test_marks=[xfail_jit_python_scalar_arg("fill")],
        ),
        KernelInfo(
            F.elastic_bounding_box,
            sample_inputs_fn=sample_inputs_elastic_bounding_box,
        ),
        KernelInfo(
            F.elastic_mask,
            sample_inputs_fn=sample_inputs_elastic_mask,
        ),
        KernelInfo(
            F.elastic_video,
            sample_inputs_fn=sample_inputs_elastic_video,
            closeness_kwargs=cuda_vs_cpu_pixel_difference(),
        ),
    ]
)


_CENTER_CROP_SPATIAL_SIZES = [(16, 16), (7, 33), (31, 9)]
_CENTER_CROP_OUTPUT_SIZES = [[4, 3], [42, 70], [4], 3, (5, 2), (6,)]


def sample_inputs_center_crop_image_tensor():
    for image_loader, output_size in itertools.product(
        make_image_loaders(sizes=[(16, 17)], color_spaces=["RGB"], dtypes=[torch.float32]),
        [
            # valid `output_size` types for which cropping is applied to both dimensions
            *[5, (4,), (2, 3), [6], [3, 2]],
            # `output_size`'s for which at least one dimension needs to be padded
            *[[4, 18], [17, 5], [17, 18]],
        ],
    ):
        yield ArgsKwargs(image_loader, output_size=output_size)


def reference_inputs_center_crop_image_tensor():
    for image_loader, output_size in itertools.product(
        make_image_loaders(sizes=_CENTER_CROP_SPATIAL_SIZES, extra_dims=[()], dtypes=[torch.uint8]),
        _CENTER_CROP_OUTPUT_SIZES,
    ):
        yield ArgsKwargs(image_loader, output_size=output_size)


def sample_inputs_center_crop_bounding_box():
    for bounding_box_loader, output_size in itertools.product(make_bounding_box_loaders(), _CENTER_CROP_OUTPUT_SIZES):
        yield ArgsKwargs(
            bounding_box_loader,
            format=bounding_box_loader.format,
            spatial_size=bounding_box_loader.spatial_size,
            output_size=output_size,
        )


def sample_inputs_center_crop_mask():
    for mask_loader in make_mask_loaders(sizes=["random"], num_categories=["random"], num_objects=["random"]):
        height, width = mask_loader.shape[-2:]
        yield ArgsKwargs(mask_loader, output_size=(height // 2, width // 2))


def reference_inputs_center_crop_mask():
    for mask_loader, output_size in itertools.product(
        make_mask_loaders(sizes=_CENTER_CROP_SPATIAL_SIZES, extra_dims=[()], num_objects=[1]), _CENTER_CROP_OUTPUT_SIZES
    ):
        yield ArgsKwargs(mask_loader, output_size=output_size)


def sample_inputs_center_crop_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        height, width = video_loader.shape[-2:]
        yield ArgsKwargs(video_loader, output_size=(height // 2, width // 2))


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.center_crop_image_tensor,
            sample_inputs_fn=sample_inputs_center_crop_image_tensor,
            reference_fn=pil_reference_wrapper(F.center_crop_image_pil),
            reference_inputs_fn=reference_inputs_center_crop_image_tensor,
            float32_vs_uint8=True,
            test_marks=[
                xfail_jit_python_scalar_arg("output_size"),
            ],
        ),
        KernelInfo(
            F.center_crop_bounding_box,
            sample_inputs_fn=sample_inputs_center_crop_bounding_box,
            test_marks=[
                xfail_jit_python_scalar_arg("output_size"),
            ],
        ),
        KernelInfo(
            F.center_crop_mask,
            sample_inputs_fn=sample_inputs_center_crop_mask,
            reference_fn=pil_reference_wrapper(F.center_crop_image_pil),
            reference_inputs_fn=reference_inputs_center_crop_mask,
            float32_vs_uint8=True,
            test_marks=[
                xfail_jit_python_scalar_arg("output_size"),
            ],
        ),
        KernelInfo(
            F.center_crop_video,
            sample_inputs_fn=sample_inputs_center_crop_video,
        ),
    ]
)


def sample_inputs_gaussian_blur_image_tensor():
    make_gaussian_blur_image_loaders = functools.partial(make_image_loaders, sizes=[(7, 33)], color_spaces=["RGB"])

    for image_loader, kernel_size in itertools.product(make_gaussian_blur_image_loaders(), [5, (3, 3), [3, 3]]):
        yield ArgsKwargs(image_loader, kernel_size=kernel_size)

    for image_loader, sigma in itertools.product(
        make_gaussian_blur_image_loaders(), [None, (3.0, 3.0), [2.0, 2.0], 4.0, [1.5], (3.14,)]
    ):
        yield ArgsKwargs(image_loader, kernel_size=5, sigma=sigma)


def sample_inputs_gaussian_blur_video():
    for video_loader in make_video_loaders(sizes=[(7, 33)], num_frames=[5]):
        yield ArgsKwargs(video_loader, kernel_size=[3, 3])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.gaussian_blur_image_tensor,
            sample_inputs_fn=sample_inputs_gaussian_blur_image_tensor,
            closeness_kwargs=cuda_vs_cpu_pixel_difference(),
            test_marks=[
                xfail_jit_python_scalar_arg("kernel_size"),
                xfail_jit_python_scalar_arg("sigma"),
            ],
        ),
        KernelInfo(
            F.gaussian_blur_video,
            sample_inputs_fn=sample_inputs_gaussian_blur_video,
            closeness_kwargs=cuda_vs_cpu_pixel_difference(),
        ),
    ]
)


def sample_inputs_equalize_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader)


def reference_inputs_equalize_image_tensor():
    # We are not using `make_image_loaders` here since that uniformly samples the values over the whole value range.
    # Since the whole point of this kernel is to transform an arbitrary distribution of values into a uniform one,
    # the information gain is low if we already provide something really close to the expected value.
    def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor, memory_format):
        if dtype.is_floating_point:
            low = low_factor
            high = high_factor
        else:
            max_value = torch.iinfo(dtype).max
            low = int(low_factor * max_value)
            high = int(high_factor * max_value)
        return torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high).to(
            memory_format=memory_format, copy=True
        )

    def make_beta_distributed_image(shape, dtype, device, *, alpha, beta, memory_format):
        image = torch.distributions.Beta(alpha, beta).sample(shape)
        if not dtype.is_floating_point:
            image.mul_(torch.iinfo(dtype).max).round_()
        return image.to(dtype=dtype, device=device, memory_format=memory_format, copy=True)

    spatial_size = (256, 256)
    for dtype, color_space, fn in itertools.product(
        [torch.uint8],
        ["GRAY", "RGB"],
        [
            lambda shape, dtype, device, memory_format: torch.zeros(shape, dtype=dtype, device=device).to(
                memory_format=memory_format, copy=True
            ),
            lambda shape, dtype, device, memory_format: torch.full(
                shape, 1.0 if dtype.is_floating_point else torch.iinfo(dtype).max, dtype=dtype, device=device
            ).to(memory_format=memory_format, copy=True),
            *[
                functools.partial(make_uniform_band_image, low_factor=low_factor, high_factor=high_factor)
                for low_factor, high_factor in [
                    (0.0, 0.25),
                    (0.25, 0.75),
                    (0.75, 1.0),
                ]
            ],
            *[
                functools.partial(make_beta_distributed_image, alpha=alpha, beta=beta)
                for alpha, beta in [
                    (0.5, 0.5),
                    (2, 2),
                    (2, 5),
                    (5, 2),
                ]
            ],
        ],
    ):
        image_loader = ImageLoader(fn, shape=(get_num_channels(color_space), *spatial_size), dtype=dtype)
        yield ArgsKwargs(image_loader)


def sample_inputs_equalize_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.equalize_image_tensor,
            kernel_name="equalize_image_tensor",
            sample_inputs_fn=sample_inputs_equalize_image_tensor,
            reference_fn=pil_reference_wrapper(F.equalize_image_pil),
            float32_vs_uint8=True,
            reference_inputs_fn=reference_inputs_equalize_image_tensor,
        ),
        KernelInfo(
            F.equalize_video,
            sample_inputs_fn=sample_inputs_equalize_video,
        ),
    ]
)


def sample_inputs_invert_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader)


def reference_inputs_invert_image_tensor():
    for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
        yield ArgsKwargs(image_loader)


def sample_inputs_invert_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.invert_image_tensor,
            kernel_name="invert_image_tensor",
            sample_inputs_fn=sample_inputs_invert_image_tensor,
            reference_fn=pil_reference_wrapper(F.invert_image_pil),
            reference_inputs_fn=reference_inputs_invert_image_tensor,
            float32_vs_uint8=True,
        ),
        KernelInfo(
            F.invert_video,
            sample_inputs_fn=sample_inputs_invert_video,
        ),
    ]
)


_POSTERIZE_BITS = [1, 4, 8]


def sample_inputs_posterize_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, bits=_POSTERIZE_BITS[0])


def reference_inputs_posterize_image_tensor():
    for image_loader, bits in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _POSTERIZE_BITS,
    ):
        yield ArgsKwargs(image_loader, bits=bits)


def sample_inputs_posterize_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, bits=_POSTERIZE_BITS[0])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.posterize_image_tensor,
            kernel_name="posterize_image_tensor",
            sample_inputs_fn=sample_inputs_posterize_image_tensor,
            reference_fn=pil_reference_wrapper(F.posterize_image_pil),
            reference_inputs_fn=reference_inputs_posterize_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs=float32_vs_uint8_pixel_difference(),
        ),
        KernelInfo(
            F.posterize_video,
            sample_inputs_fn=sample_inputs_posterize_video,
        ),
    ]
)


def _get_solarize_thresholds(dtype):
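    # E.g. torch.uint8 yields the thresholds 25 and 127 (10% and 50% of 255, cast
    # to int); float dtypes yield 0.1 and 0.5, since their nominal max is 1.0.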
    for factor in [0.1, 0.5]:
        max_value = get_max_value(dtype)
        yield (float if dtype.is_floating_point else int)(max_value * factor)


def sample_inputs_solarize_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, threshold=next(_get_solarize_thresholds(image_loader.dtype)))


def reference_inputs_solarize_image_tensor():
    for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
        for threshold in _get_solarize_thresholds(image_loader.dtype):
            yield ArgsKwargs(image_loader, threshold=threshold)


def uint8_to_float32_threshold_adapter(other_args, kwargs):
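    # E.g. a uint8 threshold of 127 becomes 127 / 255 (~0.498) for the float32 run.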
    return other_args, dict(threshold=kwargs["threshold"] / 255)


def sample_inputs_solarize_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, threshold=next(_get_solarize_thresholds(video_loader.dtype)))


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.solarize_image_tensor,
            kernel_name="solarize_image_tensor",
            sample_inputs_fn=sample_inputs_solarize_image_tensor,
            reference_fn=pil_reference_wrapper(F.solarize_image_pil),
            reference_inputs_fn=reference_inputs_solarize_image_tensor,
            float32_vs_uint8=uint8_to_float32_threshold_adapter,
            closeness_kwargs=float32_vs_uint8_pixel_difference(),
        ),
        KernelInfo(
            F.solarize_video,
            sample_inputs_fn=sample_inputs_solarize_video,
        ),
    ]
)


def sample_inputs_autocontrast_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader)


def reference_inputs_autocontrast_image_tensor():
    for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
        yield ArgsKwargs(image_loader)


def sample_inputs_autocontrast_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.autocontrast_image_tensor,
            kernel_name="autocontrast_image_tensor",
            sample_inputs_fn=sample_inputs_autocontrast_image_tensor,
            reference_fn=pil_reference_wrapper(F.autocontrast_image_pil),
            reference_inputs_fn=reference_inputs_autocontrast_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **pil_reference_pixel_difference(),
                **float32_vs_uint8_pixel_difference(),
            },
        ),
        KernelInfo(
            F.autocontrast_video,
            sample_inputs_fn=sample_inputs_autocontrast_video,
        ),
    ]
)

_ADJUST_SHARPNESS_FACTORS = [0.1, 0.5]


def sample_inputs_adjust_sharpness_image_tensor():
    for image_loader in make_image_loaders(
        sizes=["random", (2, 2)],
        color_spaces=("GRAY", "RGB"),
    ):
        yield ArgsKwargs(image_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0])


def reference_inputs_adjust_sharpness_image_tensor():
    for image_loader, sharpness_factor in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _ADJUST_SHARPNESS_FACTORS,
    ):
        yield ArgsKwargs(image_loader, sharpness_factor=sharpness_factor)


def sample_inputs_adjust_sharpness_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.adjust_sharpness_image_tensor,
            kernel_name="adjust_sharpness_image_tensor",
            sample_inputs_fn=sample_inputs_adjust_sharpness_image_tensor,
            reference_fn=pil_reference_wrapper(F.adjust_sharpness_image_pil),
            reference_inputs_fn=reference_inputs_adjust_sharpness_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs=float32_vs_uint8_pixel_difference(2),
        ),
        KernelInfo(
            F.adjust_sharpness_video,
            sample_inputs_fn=sample_inputs_adjust_sharpness_video,
        ),
    ]
)


def sample_inputs_erase_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"]):
        # FIXME: make the parameters more diverse
        h, w = 6, 7
        v = torch.rand(image_loader.num_channels, h, w)
        yield ArgsKwargs(image_loader, i=1, j=2, h=h, w=w, v=v)


def sample_inputs_erase_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        # FIXME: make the parameters more diverse
        h, w = 6, 7
        v = torch.rand(video_loader.num_channels, h, w)
        yield ArgsKwargs(video_loader, i=1, j=2, h=h, w=w, v=v)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.erase_image_tensor,
            kernel_name="erase_image_tensor",
            sample_inputs_fn=sample_inputs_erase_image_tensor,
        ),
        KernelInfo(
            F.erase_video,
            sample_inputs_fn=sample_inputs_erase_video,
        ),
    ]
)

_ADJUST_BRIGHTNESS_FACTORS = [0.1, 0.5]


def sample_inputs_adjust_brightness_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0])


def reference_inputs_adjust_brightness_image_tensor():
    for image_loader, brightness_factor in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _ADJUST_BRIGHTNESS_FACTORS,
    ):
        yield ArgsKwargs(image_loader, brightness_factor=brightness_factor)


def sample_inputs_adjust_brightness_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.adjust_brightness_image_tensor,
            kernel_name="adjust_brightness_image_tensor",
            sample_inputs_fn=sample_inputs_adjust_brightness_image_tensor,
            reference_fn=pil_reference_wrapper(F.adjust_brightness_image_pil),
            reference_inputs_fn=reference_inputs_adjust_brightness_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs=float32_vs_uint8_pixel_difference(),
        ),
        KernelInfo(
            F.adjust_brightness_video,
            sample_inputs_fn=sample_inputs_adjust_brightness_video,
        ),
    ]
)


_ADJUST_CONTRAST_FACTORS = [0.1, 0.5]


def sample_inputs_adjust_contrast_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0])


def reference_inputs_adjust_contrast_image_tensor():
    for image_loader, contrast_factor in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _ADJUST_CONTRAST_FACTORS,
    ):
        yield ArgsKwargs(image_loader, contrast_factor=contrast_factor)


def sample_inputs_adjust_contrast_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.adjust_contrast_image_tensor,
            kernel_name="adjust_contrast_image_tensor",
            sample_inputs_fn=sample_inputs_adjust_contrast_image_tensor,
            reference_fn=pil_reference_wrapper(F.adjust_contrast_image_pil),
            reference_inputs_fn=reference_inputs_adjust_contrast_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **pil_reference_pixel_difference(),
                **float32_vs_uint8_pixel_difference(2),
                **cuda_vs_cpu_pixel_difference(),
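                # Closeness kwargs can also be keyed per configuration as
                # ((test class, test name), dtype, device) to override the tolerance for a
                # single combination, as for uint8 on CPU below: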
                (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): pixel_difference_closeness_kwargs(1),
            },
        ),
        KernelInfo(
            F.adjust_contrast_video,
            sample_inputs_fn=sample_inputs_adjust_contrast_video,
            closeness_kwargs={
                **cuda_vs_cpu_pixel_difference(),
                (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): pixel_difference_closeness_kwargs(1),
            },
        ),
    ]
)

_ADJUST_GAMMA_GAMMAS_GAINS = [
    (0.5, 2.0),
    (0.0, 1.0),
]
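# Note on (0.0, 1.0): since x**0 == 1, gamma == 0.0 presumably maps every pixel to the
# gain (i.e. the maximum value after rescaling), which makes it a useful degenerate case.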


def sample_inputs_adjust_gamma_image_tensor():
    gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0]
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, gamma=gamma, gain=gain)


def reference_inputs_adjust_gamma_image_tensor():
    for image_loader, (gamma, gain) in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _ADJUST_GAMMA_GAMMAS_GAINS,
    ):
        yield ArgsKwargs(image_loader, gamma=gamma, gain=gain)


def sample_inputs_adjust_gamma_video():
    gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0]
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, gamma=gamma, gain=gain)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.adjust_gamma_image_tensor,
            kernel_name="adjust_gamma_image_tensor",
            sample_inputs_fn=sample_inputs_adjust_gamma_image_tensor,
            reference_fn=pil_reference_wrapper(F.adjust_gamma_image_pil),
            reference_inputs_fn=reference_inputs_adjust_gamma_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **pil_reference_pixel_difference(),
                **float32_vs_uint8_pixel_difference(),
            },
        ),
        KernelInfo(
            F.adjust_gamma_video,
            sample_inputs_fn=sample_inputs_adjust_gamma_video,
        ),
    ]
)


_ADJUST_HUE_FACTORS = [-0.1, 0.5]
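# adjust_hue requires hue_factor in [-0.5, 0.5]; the values above cover a small negative
# shift and the maximum positive one.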


def sample_inputs_adjust_hue_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, hue_factor=_ADJUST_HUE_FACTORS[0])


def reference_inputs_adjust_hue_image_tensor():
    for image_loader, hue_factor in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _ADJUST_HUE_FACTORS,
    ):
        yield ArgsKwargs(image_loader, hue_factor=hue_factor)


def sample_inputs_adjust_hue_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, hue_factor=_ADJUST_HUE_FACTORS[0])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.adjust_hue_image_tensor,
            kernel_name="adjust_hue_image_tensor",
            sample_inputs_fn=sample_inputs_adjust_hue_image_tensor,
            reference_fn=pil_reference_wrapper(F.adjust_hue_image_pil),
            reference_inputs_fn=reference_inputs_adjust_hue_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **pil_reference_pixel_difference(2, mae=True),
                **float32_vs_uint8_pixel_difference(),
            },
        ),
        KernelInfo(
            F.adjust_hue_video,
            sample_inputs_fn=sample_inputs_adjust_hue_video,
        ),
    ]
)

_ADJUST_SATURATION_FACTORS = [0.1, 0.5]


def sample_inputs_adjust_saturation_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0])


def reference_inputs_adjust_saturation_image_tensor():
    for image_loader, saturation_factor in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _ADJUST_SATURATION_FACTORS,
    ):
        yield ArgsKwargs(image_loader, saturation_factor=saturation_factor)


def sample_inputs_adjust_saturation_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.adjust_saturation_image_tensor,
            kernel_name="adjust_saturation_image_tensor",
            sample_inputs_fn=sample_inputs_adjust_saturation_image_tensor,
            reference_fn=pil_reference_wrapper(F.adjust_saturation_image_pil),
            reference_inputs_fn=reference_inputs_adjust_saturation_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **pil_reference_pixel_difference(),
                **float32_vs_uint8_pixel_difference(2),
                **cuda_vs_cpu_pixel_difference(),
            },
        ),
        KernelInfo(
            F.adjust_saturation_video,
            sample_inputs_fn=sample_inputs_adjust_saturation_video,
            closeness_kwargs=cuda_vs_cpu_pixel_difference(),
        ),
    ]
)


def sample_inputs_clamp_bounding_box():
    for bounding_box_loader in make_bounding_box_loaders():
        yield ArgsKwargs(
            bounding_box_loader,
            format=bounding_box_loader.format,
            spatial_size=bounding_box_loader.spatial_size,
        )


KERNEL_INFOS.append(
    KernelInfo(
        F.clamp_bounding_box,
        sample_inputs_fn=sample_inputs_clamp_bounding_box,
        logs_usage=True,
    )
)

_FIVE_TEN_CROP_SIZES = [7, (6,), [5], (6, 5), [7, 6]]


def _get_five_ten_crop_spatial_size(size):
    if isinstance(size, int):
        crop_height = crop_width = size
    elif len(size) == 1:
        crop_height = crop_width = size[0]
    else:
        crop_height, crop_width = size
    return 2 * crop_height, 2 * crop_width
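# Worked examples of the mapping above: the sampled image is always twice the requested
# crop size, so all five / ten crops are well-defined:
#   _get_five_ten_crop_spatial_size(7)      -> (14, 14)
#   _get_five_ten_crop_spatial_size((6,))   -> (12, 12)
#   _get_five_ten_crop_spatial_size((6, 5)) -> (12, 10)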


def sample_inputs_five_crop_image_tensor():
    for size in _FIVE_TEN_CROP_SIZES:
        for image_loader in make_image_loaders(
            sizes=[_get_five_ten_crop_spatial_size(size)],
            color_spaces=["RGB"],
            dtypes=[torch.float32],
        ):
            yield ArgsKwargs(image_loader, size=size)


def reference_inputs_five_crop_image_tensor():
    for size in _FIVE_TEN_CROP_SIZES:
        for image_loader in make_image_loaders(
            sizes=[_get_five_ten_crop_spatial_size(size)], extra_dims=[()], dtypes=[torch.uint8]
        ):
            yield ArgsKwargs(image_loader, size=size)


def sample_inputs_five_crop_video():
    size = _FIVE_TEN_CROP_SIZES[0]
    for video_loader in make_video_loaders(sizes=[_get_five_ten_crop_spatial_size(size)]):
        yield ArgsKwargs(video_loader, size=size)


def sample_inputs_ten_crop_image_tensor():
    for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
        for image_loader in make_image_loaders(
            sizes=[_get_five_ten_crop_spatial_size(size)],
            color_spaces=["RGB"],
            dtypes=[torch.float32],
        ):
            yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip)


def reference_inputs_ten_crop_image_tensor():
    for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
        for image_loader in make_image_loaders(
            sizes=[_get_five_ten_crop_spatial_size(size)], extra_dims=[()], dtypes=[torch.uint8]
        ):
            yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip)


def sample_inputs_ten_crop_video():
    size = _FIVE_TEN_CROP_SIZES[0]
    for video_loader in make_video_loaders(sizes=[_get_five_ten_crop_spatial_size(size)]):
        yield ArgsKwargs(video_loader, size=size)


def multi_crop_pil_reference_wrapper(pil_kernel):
    def wrapper(input_tensor, *other_args, **kwargs):
        output = pil_reference_wrapper(pil_kernel)(input_tensor, *other_args, **kwargs)
        return type(output)(
            F.convert_dtype_image_tensor(F.to_image_tensor(output_pil), dtype=input_tensor.dtype)
            for output_pil in output
        )

    return wrapper
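# five_crop / ten_crop return a tuple of crops rather than a single tensor, so the PIL
# reference output is converted crop-by-crop back to tensors of the input dtype above,
# allowing assert_close to compare the tuples element-wise.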


_common_five_ten_crop_marks = [
    xfail_jit_python_scalar_arg("size"),
    mark_framework_limitation(("TestKernels", "test_batched_vs_single"), "Custom batching needed."),
]
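# A note on the framework limitation above: these kernels return a tuple of five (or
# ten) crops per input, so batching the input presumably does not produce a single
# stacked output that the generic batched-vs-single test could compare against.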

KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.five_crop_image_tensor,
            sample_inputs_fn=sample_inputs_five_crop_image_tensor,
            reference_fn=multi_crop_pil_reference_wrapper(F.five_crop_image_pil),
            reference_inputs_fn=reference_inputs_five_crop_image_tensor,
            test_marks=_common_five_ten_crop_marks,
        ),
        KernelInfo(
            F.five_crop_video,
            sample_inputs_fn=sample_inputs_five_crop_video,
            test_marks=_common_five_ten_crop_marks,
        ),
        KernelInfo(
            F.ten_crop_image_tensor,
            sample_inputs_fn=sample_inputs_ten_crop_image_tensor,
            reference_fn=multi_crop_pil_reference_wrapper(F.ten_crop_image_pil),
            reference_inputs_fn=reference_inputs_ten_crop_image_tensor,
            test_marks=_common_five_ten_crop_marks,
        ),
        KernelInfo(
            F.ten_crop_video,
            sample_inputs_fn=sample_inputs_ten_crop_video,
            test_marks=_common_five_ten_crop_marks,
        ),
    ]
)

_NORMALIZE_MEANS_STDS = [
    ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
    (0.5, 2.0),
]
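# The scalar pair (0.5, 2.0) exercises broadcasting: a single mean / std value is
# applied to all channels instead of one value per channel.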


def sample_inputs_normalize_image_tensor():
    for image_loader, (mean, std) in itertools.product(
        make_image_loaders(sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]),
        _NORMALIZE_MEANS_STDS,
    ):
        yield ArgsKwargs(image_loader, mean=mean, std=std)


def reference_normalize_image_tensor(image, mean, std, inplace=False):
    mean = torch.tensor(mean).view(-1, 1, 1)
    std = torch.tensor(std).view(-1, 1, 1)

    sub = torch.Tensor.sub_ if inplace else torch.Tensor.sub
    return sub(image, mean).div_(std)
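# For a single pixel this reduces to (x - mean) / std per channel, e.g. with mean=0.5
# and std=2.0 a value of 1.0 maps to (1.0 - 0.5) / 2.0 = 0.25.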


def reference_inputs_normalize_image_tensor():
    yield ArgsKwargs(
        make_image_loader(size=(32, 32), color_space="RGB", extra_dims=[1]),
        mean=[0.5, 0.5, 0.5],
        std=[1.0, 1.0, 1.0],
    )


def sample_inputs_normalize_video():
    mean, std = _NORMALIZE_MEANS_STDS[0]
    for video_loader in make_video_loaders(
        sizes=["random"], color_spaces=["RGB"], num_frames=["random"], dtypes=[torch.float32]
    ):
        yield ArgsKwargs(video_loader, mean=mean, std=std)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.normalize_image_tensor,
            kernel_name="normalize_image_tensor",
            sample_inputs_fn=sample_inputs_normalize_image_tensor,
            reference_fn=reference_normalize_image_tensor,
            reference_inputs_fn=reference_inputs_normalize_image_tensor,
            test_marks=[
                xfail_jit_python_scalar_arg("mean"),
                xfail_jit_python_scalar_arg("std"),
            ],
        ),
        KernelInfo(
            F.normalize_video,
            sample_inputs_fn=sample_inputs_normalize_video,
        ),
    ]
)


def sample_inputs_convert_dtype_image_tensor():
    for input_dtype, output_dtype in itertools.product(
        [torch.uint8, torch.int64, torch.float32, torch.float64], repeat=2
    ):
        if input_dtype.is_floating_point and output_dtype == torch.int64:
            # conversion cannot be performed safely: scaling by the int64 maximum
            # overflows, since 2**63 - 1 is not exactly representable as a float
            continue

        for image_loader in make_image_loaders(sizes=["random"], color_spaces=["RGB"], dtypes=[input_dtype]):
            yield ArgsKwargs(image_loader, dtype=output_dtype)


def reference_convert_dtype_image_tensor(image, dtype=torch.float):
    input_dtype = image.dtype
    output_dtype = dtype

    if output_dtype == input_dtype:
        return image

    def fn(value):
        if input_dtype.is_floating_point:
            if output_dtype.is_floating_point:
                return value
            else:
                return int(decimal.Decimal(value) * torch.iinfo(output_dtype).max)
        else:
            input_max_value = torch.iinfo(input_dtype).max

            if output_dtype.is_floating_point:
                return float(decimal.Decimal(value) / input_max_value)
            else:
                output_max_value = torch.iinfo(output_dtype).max

                if input_max_value > output_max_value:
                    factor = (input_max_value + 1) // (output_max_value + 1)
                    return value // factor
                else:
                    factor = (output_max_value + 1) // (input_max_value + 1)
                    return value * factor

    return torch.tensor(tree_map(fn, image.tolist()), dtype=dtype)
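# Worked examples of the scalar mapping in `fn` above (derived by hand):
#   uint8 -> float32: 255 / 255 -> 1.0
#   uint8 -> int16:   factor = (32767 + 1) // (255 + 1) = 128, so 255 -> 32640
#   int16 -> uint8:   factor = 128 as well, so 32767 // 128 -> 255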


def reference_inputs_convert_dtype_image_tensor():
    for input_dtype, output_dtype in itertools.product(
        [
            torch.uint8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.float16,
            torch.float32,
            torch.float64,
            torch.bfloat16,
        ],
        repeat=2,
    ):
        if (input_dtype == torch.float32 and output_dtype in {torch.int32, torch.int64}) or (
            input_dtype == torch.float64 and output_dtype == torch.int64
        ):
            # skip conversions that overflow: the integer maximum is not exactly
            # representable in the input floating point dtype, so scaling 1.0 rounds past it
            continue

        if input_dtype.is_floating_point:
            data = [0.0, 0.5, 1.0]
        else:
            max_value = torch.iinfo(input_dtype).max
            data = [0, max_value // 2, max_value]
        image = torch.tensor(data, dtype=input_dtype)

        yield ArgsKwargs(image, dtype=output_dtype)


def sample_inputs_convert_dtype_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader)


skip_dtype_consistency = TestMark(
    ("TestKernels", "test_dtype_and_device_consistency"),
    pytest.mark.skip(reason="`convert_dtype_*` kernels convert the dtype by design"),
    condition=lambda args_kwargs: args_kwargs.args[0].dtype != args_kwargs.kwargs.get("dtype", torch.float32),
)
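# e.g. converting a uint8 image to the default float32 is skipped by the mark above,
# while a float32 input converted to float32 (a no-op) is still checked for consistency.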

KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.convert_dtype_image_tensor,
            sample_inputs_fn=sample_inputs_convert_dtype_image_tensor,
            reference_fn=reference_convert_dtype_image_tensor,
            reference_inputs_fn=reference_inputs_convert_dtype_image_tensor,
            test_marks=[
                skip_dtype_consistency,
                TestMark(
                    ("TestKernels", "test_against_reference"),
                    pytest.mark.xfail(reason="Conversion overflows"),
                    condition=lambda args_kwargs: (
                        args_kwargs.args[0].dtype in {torch.float16, torch.bfloat16}
                        and not args_kwargs.kwargs["dtype"].is_floating_point
                    )
                    or (
                        args_kwargs.args[0].dtype in {torch.int32, torch.int64}
                        and args_kwargs.kwargs["dtype"] == torch.float16
                    ),
                ),
            ],
        ),
        KernelInfo(
            F.convert_dtype_video,
            sample_inputs_fn=sample_inputs_convert_dtype_video,
            test_marks=[
                skip_dtype_consistency,
            ],
        ),
    ]
)


def sample_inputs_uniform_temporal_subsample_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=[4]):
        yield ArgsKwargs(video_loader, num_samples=2)


def reference_uniform_temporal_subsample_video(x, num_samples):
    # Copy-pasted from
    # https://github.com/facebookresearch/pytorchvideo/blob/c8d23d8b7e597586a9e2d18f6ed31ad8aa379a7a/pytorchvideo/transforms/functional.py#L19
    t = x.shape[-4]
    assert num_samples > 0 and t > 0
    # Sample by nearest neighbor interpolation if num_samples > t.
    indices = torch.linspace(0, t - 1, num_samples)
    indices = torch.clamp(indices, 0, t - 1).long()
    return torch.index_select(x, -4, indices)
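# e.g. for t == 10 frames and num_samples == 3 the selected frame indices are
# torch.linspace(0, 9, 3).long() == tensor([0, 4, 9]).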


def reference_inputs_uniform_temporal_subsample_video():
    for video_loader in make_video_loaders(sizes=["random"], color_spaces=["RGB"], num_frames=[10]):
        for num_samples in range(1, video_loader.shape[-4] + 1):
            yield ArgsKwargs(video_loader, num_samples)


KERNEL_INFOS.append(
    KernelInfo(
        F.uniform_temporal_subsample_video,
        sample_inputs_fn=sample_inputs_uniform_temporal_subsample_video,
        reference_fn=reference_uniform_temporal_subsample_video,
        reference_inputs_fn=reference_inputs_uniform_temporal_subsample_video,
    )
)