import decimal
import functools
import itertools

import numpy as np
import PIL.Image
import pytest
import torch.testing
import torchvision.ops
import torchvision.transforms.v2.functional as F
from common_utils import (
    ArgsKwargs,
    combinations_grid,
    get_num_channels,
    ImageLoader,
    InfoBase,
    make_bounding_box_loader,
    make_bounding_box_loaders,
    make_detection_mask_loader,
    make_image_loader,
    make_image_loaders,
    make_image_loaders_for_interpolation,
    make_mask_loaders,
    make_video_loader,
    make_video_loaders,
    mark_framework_limitation,
    TestMark,
)
from torch.utils._pytree import tree_map
from torchvision import datapoints
from torchvision.transforms._functional_tensor import _max_value as get_max_value, _parse_pad_padding

__all__ = ["KernelInfo", "KERNEL_INFOS"]


class KernelInfo(InfoBase):
    def __init__(
        self,
        kernel,
        *,
        # Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
        # TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
        kernel_name=None,
        # Most common tests use these inputs to check the kernel. As such, they should cover all valid code paths, but
        # should not include extensive parameter combinations, to keep the overall test count moderate.
        sample_inputs_fn,
        # This function should mirror the kernel. It should have the same signature as the `kernel` and as such also
        # take tensors as inputs. Any conversion into another object type, e.g. PIL images or numpy arrays, should
        # happen inside the function. It should return a tensor or, more precisely, an object that can be compared to
        # a tensor by `assert_close`. If omitted, no reference test will be performed.
        reference_fn=None,
        # These inputs are only used for the reference tests and thus can be comprehensive with regard to the parameter
        # values to be tested. If not specified, `sample_inputs_fn` will be used.
        reference_inputs_fn=None,
        # If truthy, triggers a test that checks the kernel for consistency between uint8 and float32 inputs, using the
        # reference inputs. This is usually used whenever we use a PIL kernel as reference.
        # Can also be a callable, in which case it is called with `other_args, kwargs`. It should return the same
        # structure, but with adapted parameters. This is useful in case a parameter value is closely tied to the input
        # dtype.
        float32_vs_uint8=False,
        # Some kernels don't have dispatchers that would handle logging the usage. Thus, the kernel has to do it
        # manually. If set, triggers a test that makes sure this happens.
        logs_usage=False,
        # See InfoBase
        test_marks=None,
        # See InfoBase
        closeness_kwargs=None,
    ):
        super().__init__(id=kernel_name or kernel.__name__, test_marks=test_marks, closeness_kwargs=closeness_kwargs)
        self.kernel = kernel
        self.sample_inputs_fn = sample_inputs_fn
        self.reference_fn = reference_fn
        self.reference_inputs_fn = reference_inputs_fn

        if float32_vs_uint8 and not callable(float32_vs_uint8):
            float32_vs_uint8 = lambda other_args, kwargs: (other_args, kwargs)  # noqa: E731
        self.float32_vs_uint8 = float32_vs_uint8
        self.logs_usage = logs_usage
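
# Illustrative sketch only (hypothetical, not part of the test collection): the TestKernels suite
# consumes a KernelInfo roughly like this, after loading the sample inputs onto an actual device:
#
#   info = KernelInfo(F.vertical_flip_image_tensor, sample_inputs_fn=sample_inputs_vertical_flip_image_tensor)
#   for args_kwargs in info.sample_inputs_fn():
#       output = info.kernel(*args_kwargs.args, **args_kwargs.kwargs)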


def pixel_difference_closeness_kwargs(uint8_atol, *, dtype=torch.uint8, mae=False):
    return dict(atol=uint8_atol / 255 * get_max_value(dtype), rtol=0, mae=mae)
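
# Worked example (assuming `get_max_value(torch.float32) == 1.0`): pixel_difference_closeness_kwargs(1,
# dtype=torch.float32) rescales a one-uint8-step tolerance to the float range: atol = 1 / 255 * 1.0 ~ 0.0039.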


def cuda_vs_cpu_pixel_difference(atol=1):
    return {
        (("TestKernels", "test_cuda_vs_cpu"), dtype, "cuda"): pixel_difference_closeness_kwargs(atol, dtype=dtype)
        for dtype in [torch.uint8, torch.float32]
    }


def pil_reference_pixel_difference(atol=1, mae=False):
    return {
        (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): pixel_difference_closeness_kwargs(
            atol, mae=mae
        )
    }


def float32_vs_uint8_pixel_difference(atol=1, mae=False):
    return {
        (
            ("TestKernels", "test_float32_vs_uint8"),
            torch.float32,
            "cpu",
        ): pixel_difference_closeness_kwargs(atol, dtype=torch.float32, mae=mae)
    }


def scripted_vs_eager_float64_tolerances(device, atol=1e-6, rtol=1e-6):
    return {
        (("TestKernels", "test_scripted_vs_eager"), torch.float64, device): {"atol": atol, "rtol": rtol, "mae": False},
    }


def pil_reference_wrapper(pil_kernel):
    @functools.wraps(pil_kernel)
    def wrapper(input_tensor, *other_args, **kwargs):
        if input_tensor.dtype != torch.uint8:
            raise pytest.UsageError(f"Can only test uint8 tensor images against PIL, but input is {input_tensor.dtype}")
        if input_tensor.ndim > 3:
            raise pytest.UsageError(
                f"Can only test single tensor images against PIL, but input has shape {input_tensor.shape}"
            )

        input_pil = F.to_image_pil(input_tensor)
        output_pil = pil_kernel(input_pil, *other_args, **kwargs)
        if not isinstance(output_pil, PIL.Image.Image):
            return output_pil

        output_tensor = F.to_image_tensor(output_pil)

        # 2D mask handling: PIL may drop or add the channel dimension, so re-align the output's ndim with the input's
        if output_tensor.ndim == 2 and input_tensor.ndim == 3:
            output_tensor = output_tensor.unsqueeze(0)
        elif output_tensor.ndim == 3 and input_tensor.ndim == 2:
            output_tensor = output_tensor.squeeze(0)

        return output_tensor

    return wrapper


def xfail_jit(reason, *, condition=None):
    return TestMark(("TestKernels", "test_scripted_vs_eager"), pytest.mark.xfail(reason=reason), condition=condition)


def xfail_jit_python_scalar_arg(name, *, reason=None):
    return xfail_jit(
        reason or f"Python scalar int or float for `{name}` is not supported when scripting",
        condition=lambda args_kwargs: isinstance(args_kwargs.kwargs.get(name), (int, float)),
    )


KERNEL_INFOS = []


def get_fills(*, num_channels, dtype):
    yield None

    int_value = get_max_value(dtype)
    float_value = int_value / 2
    yield int_value
    yield float_value

    for vector_type in [list, tuple]:
        yield vector_type([int_value])
        yield vector_type([float_value])

        if num_channels > 1:
            yield vector_type(float_value * c / 10 for c in range(num_channels))
            yield vector_type(int_value if c % 2 == 0 else 0 for c in range(num_channels))
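
# For example, get_fills(num_channels=3, dtype=torch.uint8) yields, in order: None, 255, 127.5,
# [255], [127.5], [0.0, 12.75, 25.5], [255, 0, 255], and then the same vector fills as tuples.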


def float32_vs_uint8_fill_adapter(other_args, kwargs):
    fill = kwargs.get("fill")
    if fill is None:
        return other_args, kwargs

    if isinstance(fill, (int, float)):
        fill /= 255
    else:
        fill = type(fill)(fill_ / 255 for fill_ in fill)

    return other_args, dict(kwargs, fill=fill)
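
# For example, a uint8 fill of 255 becomes 1.0 and [255, 0, 255] becomes [1.0, 0.0, 1.0], matching
# how the image data itself is rescaled for the float32 vs. uint8 consistency test.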


def reference_affine_bounding_box_helper(bounding_box, *, format, spatial_size, affine_matrix):
    def transform(bbox, affine_matrix_, format_, spatial_size_):
        # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
        in_dtype = bbox.dtype
        if not torch.is_floating_point(bbox):
            bbox = bbox.float()
        bbox_xyxy = F.convert_format_bounding_box(
            bbox.as_subclass(torch.Tensor),
            old_format=format_,
            new_format=datapoints.BoundingBoxFormat.XYXY,
            inplace=True,
        )
        points = np.array(
            [
                [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
            ]
        )
        transformed_points = np.matmul(points, affine_matrix_.T)
        out_bbox = torch.tensor(
            [
                np.min(transformed_points[:, 0]).item(),
                np.min(transformed_points[:, 1]).item(),
                np.max(transformed_points[:, 0]).item(),
                np.max(transformed_points[:, 1]).item(),
            ],
            dtype=bbox_xyxy.dtype,
        )
        out_bbox = F.convert_format_bounding_box(
            out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
        )
        # It is important to clamp before casting, especially for CXCYWH format, dtype=int64
        out_bbox = F.clamp_bounding_box(out_bbox, format=format_, spatial_size=spatial_size_)
        out_bbox = out_bbox.to(dtype=in_dtype)
        return out_bbox

    if bounding_box.ndim < 2:
        bounding_box = [bounding_box]

    expected_bboxes = [transform(bbox, affine_matrix, format, spatial_size) for bbox in bounding_box]
    if len(expected_bboxes) > 1:
        expected_bboxes = torch.stack(expected_bboxes)
    else:
        expected_bboxes = expected_bboxes[0]

    return expected_bboxes
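
# For example, with the pure-translation matrix [[1, 0, 2], [0, 1, 0]] an XYXY box (1, 1, 3, 3) is
# mapped corner-by-corner and re-boxed to (3, 1, 5, 3), before clamping to `spatial_size`.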


def sample_inputs_convert_format_bounding_box():
    formats = list(datapoints.BoundingBoxFormat)
    for bounding_box_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
        yield ArgsKwargs(bounding_box_loader, old_format=bounding_box_loader.format, new_format=new_format)


def reference_convert_format_bounding_box(bounding_box, old_format, new_format):
    return torchvision.ops.box_convert(
        bounding_box, in_fmt=old_format.name.lower(), out_fmt=new_format.name.lower()
    ).to(bounding_box.dtype)


def reference_inputs_convert_format_bounding_box():
    for args_kwargs in sample_inputs_convert_format_bounding_box():
        if len(args_kwargs.args[0].shape) == 2:
            yield args_kwargs


KERNEL_INFOS.append(
    KernelInfo(
        F.convert_format_bounding_box,
        sample_inputs_fn=sample_inputs_convert_format_bounding_box,
        reference_fn=reference_convert_format_bounding_box,
        reference_inputs_fn=reference_inputs_convert_format_bounding_box,
        logs_usage=True,
    ),
)


def sample_inputs_vertical_flip_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"], dtypes=[torch.float32]):
        yield ArgsKwargs(image_loader)


def reference_inputs_vertical_flip_image_tensor():
    for image_loader in make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]):
        yield ArgsKwargs(image_loader)


def sample_inputs_vertical_flip_bounding_box():
    for bounding_box_loader in make_bounding_box_loaders(
        formats=[datapoints.BoundingBoxFormat.XYXY], dtypes=[torch.float32]
    ):
        yield ArgsKwargs(
            bounding_box_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
        )


def sample_inputs_vertical_flip_mask():
    for image_loader in make_mask_loaders(sizes=["random"], dtypes=[torch.uint8]):
        yield ArgsKwargs(image_loader)


def sample_inputs_vertical_flip_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader)


def reference_vertical_flip_bounding_box(bounding_box, *, format, spatial_size):
    affine_matrix = np.array(
        [
            [1, 0, 0],
            [0, -1, spatial_size[0]],
        ],
        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
    )

    expected_bboxes = reference_affine_bounding_box_helper(
        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
    )

    return expected_bboxes
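
# The affine matrix above maps y to spatial_size[0] - y while leaving x untouched. For example, with
# spatial_size=(10, W), an XYXY box (x1, 2, x2, 4) flips to (x1, 6, x2, 8).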


def reference_inputs_vertical_flip_bounding_box():
    for bounding_box_loader in make_bounding_box_loaders(extra_dims=[()]):
        yield ArgsKwargs(
            bounding_box_loader,
            format=bounding_box_loader.format,
            spatial_size=bounding_box_loader.spatial_size,
        )


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.vertical_flip_image_tensor,
            kernel_name="vertical_flip_image_tensor",
            sample_inputs_fn=sample_inputs_vertical_flip_image_tensor,
            reference_fn=pil_reference_wrapper(F.vertical_flip_image_pil),
            reference_inputs_fn=reference_inputs_vertical_flip_image_tensor,
            float32_vs_uint8=True,
        ),
        KernelInfo(
            F.vertical_flip_bounding_box,
            sample_inputs_fn=sample_inputs_vertical_flip_bounding_box,
            reference_fn=reference_vertical_flip_bounding_box,
            reference_inputs_fn=reference_inputs_vertical_flip_bounding_box,
        ),
        KernelInfo(
            F.vertical_flip_mask,
            sample_inputs_fn=sample_inputs_vertical_flip_mask,
        ),
        KernelInfo(
            F.vertical_flip_video,
            sample_inputs_fn=sample_inputs_vertical_flip_video,
        ),
    ]
)

_ROTATE_ANGLES = [-87, 15, 90]


def sample_inputs_rotate_image_tensor():
    make_rotate_image_loaders = functools.partial(
        make_image_loaders, sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]
    )

    for image_loader in make_rotate_image_loaders():
        yield ArgsKwargs(image_loader, angle=15.0, expand=True)

    for image_loader, center in itertools.product(
        make_rotate_image_loaders(), [None, [1.0, 0.5], [1, 2], (1.0, 0.5), (1, 2)]
    ):
        yield ArgsKwargs(image_loader, angle=15.0, center=center)

    for image_loader in make_rotate_image_loaders():
        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
            yield ArgsKwargs(image_loader, angle=15.0, fill=fill)

    for image_loader, interpolation in itertools.product(
        make_rotate_image_loaders(),
        [F.InterpolationMode.NEAREST, F.InterpolationMode.BILINEAR],
    ):
        yield ArgsKwargs(image_loader, angle=15.0, interpolation=interpolation)


def reference_inputs_rotate_image_tensor():
    for image_loader, angle in itertools.product(make_image_loaders_for_interpolation(), _ROTATE_ANGLES):
        yield ArgsKwargs(image_loader, angle=angle)


def sample_inputs_rotate_bounding_box():
    for bounding_box_loader in make_bounding_box_loaders():
        yield ArgsKwargs(
            bounding_box_loader,
            format=bounding_box_loader.format,
            spatial_size=bounding_box_loader.spatial_size,
            angle=_ROTATE_ANGLES[0],
        )


def reference_inputs_rotate_bounding_box():
    for bounding_box_loader, angle in itertools.product(
        make_bounding_box_loaders(extra_dims=((), (4,))), _ROTATE_ANGLES
    ):
        yield ArgsKwargs(
            bounding_box_loader,
            format=bounding_box_loader.format,
            spatial_size=bounding_box_loader.spatial_size,
            angle=angle,
        )

    # TODO: add samples with expand=True and center


def reference_rotate_bounding_box(bounding_box, *, format, spatial_size, angle, expand=False, center=None):
    if center is None:
        center = [spatial_size[1] * 0.5, spatial_size[0] * 0.5]

    a = np.cos(angle * np.pi / 180.0)
    b = np.sin(angle * np.pi / 180.0)
    cx = center[0]
    cy = center[1]
    affine_matrix = np.array(
        [
            [a, b, cx - cx * a - b * cy],
            [-b, a, cy + cx * b - a * cy],
        ],
        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
    )

    expected_bboxes = reference_affine_bounding_box_helper(
        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
    )
    return expected_bboxes, spatial_size
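
# The affine matrix above rotates points around `center` in image coordinates (y pointing down). For
# example, for angle=90 we get a=0, b=1, and a point (x, y) maps to (y + cx - cy, cx + cy - x).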


def sample_inputs_rotate_mask():
    for mask_loader in make_mask_loaders(sizes=["random"], num_categories=["random"], num_objects=["random"]):
        yield ArgsKwargs(mask_loader, angle=15.0)


def sample_inputs_rotate_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, angle=15.0)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.rotate_image_tensor,
            sample_inputs_fn=sample_inputs_rotate_image_tensor,
            reference_fn=pil_reference_wrapper(F.rotate_image_pil),
            reference_inputs_fn=reference_inputs_rotate_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs=pil_reference_pixel_difference(1, mae=True),
            test_marks=[
                xfail_jit_python_scalar_arg("fill"),
            ],
        ),
        KernelInfo(
            F.rotate_bounding_box,
            sample_inputs_fn=sample_inputs_rotate_bounding_box,
            reference_fn=reference_rotate_bounding_box,
            reference_inputs_fn=reference_inputs_rotate_bounding_box,
            closeness_kwargs={
                **scripted_vs_eager_float64_tolerances("cpu", atol=1e-4, rtol=1e-4),
                **scripted_vs_eager_float64_tolerances("cuda", atol=1e-4, rtol=1e-4),
            },
        ),
        KernelInfo(
            F.rotate_mask,
            sample_inputs_fn=sample_inputs_rotate_mask,
        ),
        KernelInfo(
            F.rotate_video,
            sample_inputs_fn=sample_inputs_rotate_video,
        ),
    ]
)

_CROP_PARAMS = combinations_grid(top=[-8, 0, 9], left=[-8, 0, 9], height=[12, 20], width=[12, 20])


def sample_inputs_crop_image_tensor():
    for image_loader, params in itertools.product(
        make_image_loaders(sizes=[(16, 17)], color_spaces=["RGB"], dtypes=[torch.float32]),
        [
            dict(top=4, left=3, height=7, width=8),
            dict(top=-1, left=3, height=7, width=8),
            dict(top=4, left=-1, height=7, width=8),
            dict(top=4, left=3, height=17, width=8),
            dict(top=4, left=3, height=7, width=18),
        ],
    ):
        yield ArgsKwargs(image_loader, **params)


def reference_inputs_crop_image_tensor():
    for image_loader, params in itertools.product(
        make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]), _CROP_PARAMS
    ):
        yield ArgsKwargs(image_loader, **params)


def sample_inputs_crop_bounding_box():
    for bounding_box_loader, params in itertools.product(
        make_bounding_box_loaders(), [_CROP_PARAMS[0], _CROP_PARAMS[-1]]
    ):
        yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, **params)


def sample_inputs_crop_mask():
    for mask_loader in make_mask_loaders(sizes=[(16, 17)], num_categories=["random"], num_objects=["random"]):
        yield ArgsKwargs(mask_loader, top=4, left=3, height=7, width=8)


def reference_inputs_crop_mask():
    for mask_loader, params in itertools.product(make_mask_loaders(extra_dims=[()], num_objects=[1]), _CROP_PARAMS):
        yield ArgsKwargs(mask_loader, **params)


def sample_inputs_crop_video():
    for video_loader in make_video_loaders(sizes=[(16, 17)], num_frames=["random"]):
        yield ArgsKwargs(video_loader, top=4, left=3, height=7, width=8)


def reference_crop_bounding_box(bounding_box, *, format, top, left, height, width):
    affine_matrix = np.array(
        [
            [1, 0, -left],
            [0, 1, -top],
        ],
        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
    )

    spatial_size = (height, width)
    expected_bboxes = reference_affine_bounding_box_helper(
        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
    )
    return expected_bboxes, spatial_size


def reference_inputs_crop_bounding_box():
    for bounding_box_loader, params in itertools.product(
        make_bounding_box_loaders(extra_dims=((), (4,))), [_CROP_PARAMS[0], _CROP_PARAMS[-1]]
    ):
        yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, **params)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.crop_image_tensor,
            kernel_name="crop_image_tensor",
            sample_inputs_fn=sample_inputs_crop_image_tensor,
            reference_fn=pil_reference_wrapper(F.crop_image_pil),
            reference_inputs_fn=reference_inputs_crop_image_tensor,
            float32_vs_uint8=True,
        ),
        KernelInfo(
            F.crop_bounding_box,
            sample_inputs_fn=sample_inputs_crop_bounding_box,
            reference_fn=reference_crop_bounding_box,
            reference_inputs_fn=reference_inputs_crop_bounding_box,
        ),
        KernelInfo(
            F.crop_mask,
            sample_inputs_fn=sample_inputs_crop_mask,
            reference_fn=pil_reference_wrapper(F.crop_image_pil),
            reference_inputs_fn=reference_inputs_crop_mask,
            float32_vs_uint8=True,
        ),
        KernelInfo(
            F.crop_video,
            sample_inputs_fn=sample_inputs_crop_video,
        ),
    ]
)

_RESIZED_CROP_PARAMS = combinations_grid(top=[-8, 9], left=[-8, 9], height=[12], width=[12], size=[(16, 18)])


def sample_inputs_resized_crop_image_tensor():
    for image_loader in make_image_loaders():
        yield ArgsKwargs(image_loader, **_RESIZED_CROP_PARAMS[0])


@pil_reference_wrapper
def reference_resized_crop_image_tensor(*args, **kwargs):
    if not kwargs.pop("antialias", False) and kwargs.get("interpolation", F.InterpolationMode.BILINEAR) in {
        F.InterpolationMode.BILINEAR,
        F.InterpolationMode.BICUBIC,
    }:
        raise pytest.UsageError("Anti-aliasing is always active in PIL")
    return F.resized_crop_image_pil(*args, **kwargs)


def reference_inputs_resized_crop_image_tensor():
    for image_loader, interpolation, params in itertools.product(
        make_image_loaders_for_interpolation(),
        [
            F.InterpolationMode.NEAREST,
            F.InterpolationMode.NEAREST_EXACT,
            F.InterpolationMode.BILINEAR,
            F.InterpolationMode.BICUBIC,
        ],
        _RESIZED_CROP_PARAMS,
    ):
        yield ArgsKwargs(
            image_loader,
            interpolation=interpolation,
            antialias=interpolation
            in {
                F.InterpolationMode.BILINEAR,
                F.InterpolationMode.BICUBIC,
            },
            **params,
        )


def sample_inputs_resized_crop_bounding_box():
    for bounding_box_loader in make_bounding_box_loaders():
        yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, **_RESIZED_CROP_PARAMS[0])


def sample_inputs_resized_crop_mask():
    for mask_loader in make_mask_loaders():
        yield ArgsKwargs(mask_loader, **_RESIZED_CROP_PARAMS[0])


def sample_inputs_resized_crop_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, **_RESIZED_CROP_PARAMS[0])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.resized_crop_image_tensor,
            sample_inputs_fn=sample_inputs_resized_crop_image_tensor,
            reference_fn=reference_resized_crop_image_tensor,
            reference_inputs_fn=reference_inputs_resized_crop_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **cuda_vs_cpu_pixel_difference(),
                **pil_reference_pixel_difference(3, mae=True),
                **float32_vs_uint8_pixel_difference(3, mae=True),
            },
        ),
        KernelInfo(
            F.resized_crop_bounding_box,
            sample_inputs_fn=sample_inputs_resized_crop_bounding_box,
        ),
        KernelInfo(
            F.resized_crop_mask,
            sample_inputs_fn=sample_inputs_resized_crop_mask,
        ),
        KernelInfo(
            F.resized_crop_video,
            sample_inputs_fn=sample_inputs_resized_crop_video,
            closeness_kwargs=cuda_vs_cpu_pixel_difference(),
        ),
    ]
)

_PAD_PARAMS = combinations_grid(
    padding=[[1], [1, 1], [1, 1, 2, 2]],
    padding_mode=["constant", "symmetric", "edge", "reflect"],
)


def sample_inputs_pad_image_tensor():
    make_pad_image_loaders = functools.partial(
        make_image_loaders, sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]
    )

    for image_loader, padding in itertools.product(
        make_pad_image_loaders(),
        [1, (1,), (1, 2), (1, 2, 3, 4), [1], [1, 2], [1, 2, 3, 4]],
    ):
        yield ArgsKwargs(image_loader, padding=padding)

    for image_loader in make_pad_image_loaders():
        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
            yield ArgsKwargs(image_loader, padding=[1], fill=fill)

    for image_loader, padding_mode in itertools.product(
        # We branch for non-constant padding and integer inputs
        make_pad_image_loaders(dtypes=[torch.uint8]),
        ["constant", "symmetric", "edge", "reflect"],
    ):
        yield ArgsKwargs(image_loader, padding=[1], padding_mode=padding_mode)

    # `torch.nn.functional.pad` does not support symmetric padding, and thus we have a custom implementation. Apart
    # from negative padding, it is already exercised by the inputs above.
    for image_loader in make_pad_image_loaders():
        yield ArgsKwargs(image_loader, padding=[-1], padding_mode="symmetric")


def reference_inputs_pad_image_tensor():
    for image_loader, params in itertools.product(
        make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]), _PAD_PARAMS
    ):
        for fill in get_fills(
            num_channels=image_loader.num_channels,
            dtype=image_loader.dtype,
        ):
            # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
            if isinstance(fill, (list, tuple)):
                continue

            yield ArgsKwargs(image_loader, fill=fill, **params)


def sample_inputs_pad_bounding_box():
    for bounding_box_loader, padding in itertools.product(
        make_bounding_box_loaders(), [1, (1,), (1, 2), (1, 2, 3, 4), [1], [1, 2], [1, 2, 3, 4]]
    ):
        yield ArgsKwargs(
            bounding_box_loader,
            format=bounding_box_loader.format,
            spatial_size=bounding_box_loader.spatial_size,
            padding=padding,
            padding_mode="constant",
        )


def sample_inputs_pad_mask():
    for mask_loader in make_mask_loaders(sizes=["random"], num_categories=["random"], num_objects=["random"]):
        yield ArgsKwargs(mask_loader, padding=[1])


def reference_inputs_pad_mask():
    for mask_loader, fill, params in itertools.product(
        make_mask_loaders(num_objects=[1], extra_dims=[()]), [None, 127], _PAD_PARAMS
    ):
        yield ArgsKwargs(mask_loader, fill=fill, **params)


def sample_inputs_pad_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, padding=[1])


def reference_pad_bounding_box(bounding_box, *, format, spatial_size, padding, padding_mode):
    left, right, top, bottom = _parse_pad_padding(padding)

    affine_matrix = np.array(
        [
            [1, 0, left],
            [0, 1, top],
        ],
        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
    )

    height = spatial_size[0] + top + bottom
    width = spatial_size[1] + left + right

    expected_bboxes = reference_affine_bounding_box_helper(
        bounding_box, format=format, spatial_size=(height, width), affine_matrix=affine_matrix
    )
    return expected_bboxes, (height, width)
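
# For example, padding=[1, 2] parses to left=right=1 and top=bottom=2, so every box is shifted by
# (+1, +2) and the canvas grows from (H, W) to (H + 4, W + 2).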


def reference_inputs_pad_bounding_box():
    for bounding_box_loader, padding in itertools.product(
        make_bounding_box_loaders(extra_dims=((), (4,))), [1, (1,), (1, 2), (1, 2, 3, 4), [1], [1, 2], [1, 2, 3, 4]]
    ):
        yield ArgsKwargs(
            bounding_box_loader,
            format=bounding_box_loader.format,
            spatial_size=bounding_box_loader.spatial_size,
            padding=padding,
            padding_mode="constant",
        )


def pad_xfail_jit_fill_condition(args_kwargs):
    fill = args_kwargs.kwargs.get("fill")
    if not isinstance(fill, (list, tuple)):
        return False
    elif isinstance(fill, tuple):
        return True
    else:  # isinstance(fill, list):
        return all(isinstance(f, int) for f in fill)
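
# In other words: scalar and None fills never xfail (False), tuple fills always xfail under JIT (True),
# and list fills only xfail when all entries are ints (True for [1, 2], False for [1.0, 2.0]).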


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.pad_image_tensor,
            sample_inputs_fn=sample_inputs_pad_image_tensor,
            reference_fn=pil_reference_wrapper(F.pad_image_pil),
            reference_inputs_fn=reference_inputs_pad_image_tensor,
            float32_vs_uint8=float32_vs_uint8_fill_adapter,
            closeness_kwargs=float32_vs_uint8_pixel_difference(),
            test_marks=[
                xfail_jit_python_scalar_arg("padding"),
                xfail_jit(
                    "F.pad only supports vector fills for list of floats", condition=pad_xfail_jit_fill_condition
                ),
            ],
        ),
        KernelInfo(
            F.pad_bounding_box,
            sample_inputs_fn=sample_inputs_pad_bounding_box,
            reference_fn=reference_pad_bounding_box,
            reference_inputs_fn=reference_inputs_pad_bounding_box,
            test_marks=[
                xfail_jit_python_scalar_arg("padding"),
            ],
        ),
        KernelInfo(
            F.pad_mask,
            sample_inputs_fn=sample_inputs_pad_mask,
            reference_fn=pil_reference_wrapper(F.pad_image_pil),
            reference_inputs_fn=reference_inputs_pad_mask,
            float32_vs_uint8=float32_vs_uint8_fill_adapter,
        ),
        KernelInfo(
            F.pad_video,
            sample_inputs_fn=sample_inputs_pad_video,
        ),
    ]
)

_PERSPECTIVE_COEFFS = [
    [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
    [0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
]
_STARTPOINTS = [[0, 1], [2, 3], [4, 5], [6, 7]]
_ENDPOINTS = [[9, 8], [7, 6], [5, 4], [3, 2]]


def sample_inputs_perspective_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"]):
        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
            yield ArgsKwargs(
                image_loader, startpoints=None, endpoints=None, fill=fill, coefficients=_PERSPECTIVE_COEFFS[0]
            )

    yield ArgsKwargs(make_image_loader(), startpoints=_STARTPOINTS, endpoints=_ENDPOINTS)


def reference_inputs_perspective_image_tensor():
    for image_loader, coefficients, interpolation in itertools.product(
        make_image_loaders_for_interpolation(),
        _PERSPECTIVE_COEFFS,
        [
            F.InterpolationMode.NEAREST,
            F.InterpolationMode.BILINEAR,
        ],
    ):
        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
            # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
            if isinstance(fill, (list, tuple)):
                continue

            yield ArgsKwargs(
                image_loader,
                startpoints=None,
                endpoints=None,
                interpolation=interpolation,
                fill=fill,
                coefficients=coefficients,
            )


def sample_inputs_perspective_bounding_box():
    for bounding_box_loader in make_bounding_box_loaders():
        yield ArgsKwargs(
            bounding_box_loader,
            format=bounding_box_loader.format,
            spatial_size=bounding_box_loader.spatial_size,
            startpoints=None,
            endpoints=None,
            coefficients=_PERSPECTIVE_COEFFS[0],
        )

    format = datapoints.BoundingBoxFormat.XYXY
    loader = make_bounding_box_loader(format=format)
    yield ArgsKwargs(
        loader, format=format, spatial_size=loader.spatial_size, startpoints=_STARTPOINTS, endpoints=_ENDPOINTS
    )


def sample_inputs_perspective_mask():
    for mask_loader in make_mask_loaders(sizes=["random"]):
        yield ArgsKwargs(mask_loader, startpoints=None, endpoints=None, coefficients=_PERSPECTIVE_COEFFS[0])

    yield ArgsKwargs(make_detection_mask_loader(), startpoints=_STARTPOINTS, endpoints=_ENDPOINTS)


def reference_inputs_perspective_mask():
    for mask_loader, perspective_coeffs in itertools.product(
        make_mask_loaders(extra_dims=[()], num_objects=[1]), _PERSPECTIVE_COEFFS
    ):
        yield ArgsKwargs(mask_loader, startpoints=None, endpoints=None, coefficients=perspective_coeffs)


def sample_inputs_perspective_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, startpoints=None, endpoints=None, coefficients=_PERSPECTIVE_COEFFS[0])

    yield ArgsKwargs(make_video_loader(), startpoints=_STARTPOINTS, endpoints=_ENDPOINTS)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.perspective_image_tensor,
            sample_inputs_fn=sample_inputs_perspective_image_tensor,
            reference_fn=pil_reference_wrapper(F.perspective_image_pil),
            reference_inputs_fn=reference_inputs_perspective_image_tensor,
            float32_vs_uint8=float32_vs_uint8_fill_adapter,
            closeness_kwargs={
                **pil_reference_pixel_difference(2, mae=True),
                **cuda_vs_cpu_pixel_difference(),
                **float32_vs_uint8_pixel_difference(),
                **scripted_vs_eager_float64_tolerances("cpu", atol=1e-5, rtol=1e-5),
                **scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
            },
            test_marks=[xfail_jit_python_scalar_arg("fill")],
        ),
        KernelInfo(
            F.perspective_bounding_box,
            sample_inputs_fn=sample_inputs_perspective_bounding_box,
            closeness_kwargs={
                **scripted_vs_eager_float64_tolerances("cpu", atol=1e-6, rtol=1e-6),
                **scripted_vs_eager_float64_tolerances("cuda", atol=1e-6, rtol=1e-6),
            },
        ),
        KernelInfo(
            F.perspective_mask,
            sample_inputs_fn=sample_inputs_perspective_mask,
            reference_fn=pil_reference_wrapper(F.perspective_image_pil),
            reference_inputs_fn=reference_inputs_perspective_mask,
            float32_vs_uint8=True,
            closeness_kwargs={
                (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): dict(atol=10, rtol=0),
            },
        ),
        KernelInfo(
            F.perspective_video,
            sample_inputs_fn=sample_inputs_perspective_video,
            closeness_kwargs={
                **cuda_vs_cpu_pixel_difference(),
                **scripted_vs_eager_float64_tolerances("cpu", atol=1e-5, rtol=1e-5),
                **scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
            },
        ),
    ]
)


def _get_elastic_displacement(spatial_size):
    return torch.rand(1, *spatial_size, 2)


def sample_inputs_elastic_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"]):
        displacement = _get_elastic_displacement(image_loader.spatial_size)
        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
            yield ArgsKwargs(image_loader, displacement=displacement, fill=fill)


def reference_inputs_elastic_image_tensor():
    for image_loader, interpolation in itertools.product(
        make_image_loaders_for_interpolation(),
        [
            F.InterpolationMode.NEAREST,
            F.InterpolationMode.BILINEAR,
            F.InterpolationMode.BICUBIC,
        ],
    ):
        displacement = _get_elastic_displacement(image_loader.spatial_size)
        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
            yield ArgsKwargs(image_loader, interpolation=interpolation, displacement=displacement, fill=fill)


def sample_inputs_elastic_bounding_box():
    for bounding_box_loader in make_bounding_box_loaders():
        displacement = _get_elastic_displacement(bounding_box_loader.spatial_size)
        yield ArgsKwargs(
            bounding_box_loader,
            format=bounding_box_loader.format,
            spatial_size=bounding_box_loader.spatial_size,
            displacement=displacement,
        )


def sample_inputs_elastic_mask():
    for mask_loader in make_mask_loaders(sizes=["random"]):
        displacement = _get_elastic_displacement(mask_loader.shape[-2:])
        yield ArgsKwargs(mask_loader, displacement=displacement)


def sample_inputs_elastic_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        displacement = _get_elastic_displacement(video_loader.shape[-2:])
        yield ArgsKwargs(video_loader, displacement=displacement)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.elastic_image_tensor,
            sample_inputs_fn=sample_inputs_elastic_image_tensor,
            reference_inputs_fn=reference_inputs_elastic_image_tensor,
            float32_vs_uint8=float32_vs_uint8_fill_adapter,
            closeness_kwargs={
                **float32_vs_uint8_pixel_difference(6, mae=True),
                **cuda_vs_cpu_pixel_difference(),
            },
            test_marks=[xfail_jit_python_scalar_arg("fill")],
        ),
        KernelInfo(
            F.elastic_bounding_box,
            sample_inputs_fn=sample_inputs_elastic_bounding_box,
        ),
        KernelInfo(
            F.elastic_mask,
            sample_inputs_fn=sample_inputs_elastic_mask,
        ),
        KernelInfo(
            F.elastic_video,
            sample_inputs_fn=sample_inputs_elastic_video,
            closeness_kwargs=cuda_vs_cpu_pixel_difference(),
        ),
    ]
)


_CENTER_CROP_SPATIAL_SIZES = [(16, 16), (7, 33), (31, 9)]
_CENTER_CROP_OUTPUT_SIZES = [[4, 3], [42, 70], [4], 3, (5, 2), (6,)]


def sample_inputs_center_crop_image_tensor():
    for image_loader, output_size in itertools.product(
        make_image_loaders(sizes=[(16, 17)], color_spaces=["RGB"], dtypes=[torch.float32]),
        [
            # valid `output_size` types for which cropping is applied to both dimensions
            *[5, (4,), (2, 3), [6], [3, 2]],
            # `output_size` values for which at least one dimension needs to be padded
            *[[4, 18], [17, 5], [17, 18]],
        ],
    ):
        yield ArgsKwargs(image_loader, output_size=output_size)


def reference_inputs_center_crop_image_tensor():
    for image_loader, output_size in itertools.product(
        make_image_loaders(sizes=_CENTER_CROP_SPATIAL_SIZES, extra_dims=[()], dtypes=[torch.uint8]),
        _CENTER_CROP_OUTPUT_SIZES,
    ):
        yield ArgsKwargs(image_loader, output_size=output_size)


def sample_inputs_center_crop_bounding_box():
    for bounding_box_loader, output_size in itertools.product(make_bounding_box_loaders(), _CENTER_CROP_OUTPUT_SIZES):
        yield ArgsKwargs(
            bounding_box_loader,
            format=bounding_box_loader.format,
            spatial_size=bounding_box_loader.spatial_size,
            output_size=output_size,
        )


def sample_inputs_center_crop_mask():
    for mask_loader in make_mask_loaders(sizes=["random"], num_categories=["random"], num_objects=["random"]):
        height, width = mask_loader.shape[-2:]
        yield ArgsKwargs(mask_loader, output_size=(height // 2, width // 2))


def reference_inputs_center_crop_mask():
    for mask_loader, output_size in itertools.product(
        make_mask_loaders(sizes=_CENTER_CROP_SPATIAL_SIZES, extra_dims=[()], num_objects=[1]), _CENTER_CROP_OUTPUT_SIZES
    ):
        yield ArgsKwargs(mask_loader, output_size=output_size)


def sample_inputs_center_crop_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        height, width = video_loader.shape[-2:]
        yield ArgsKwargs(video_loader, output_size=(height // 2, width // 2))


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.center_crop_image_tensor,
            sample_inputs_fn=sample_inputs_center_crop_image_tensor,
            reference_fn=pil_reference_wrapper(F.center_crop_image_pil),
            reference_inputs_fn=reference_inputs_center_crop_image_tensor,
            float32_vs_uint8=True,
            test_marks=[
                xfail_jit_python_scalar_arg("output_size"),
            ],
        ),
        KernelInfo(
            F.center_crop_bounding_box,
            sample_inputs_fn=sample_inputs_center_crop_bounding_box,
            test_marks=[
                xfail_jit_python_scalar_arg("output_size"),
            ],
        ),
        KernelInfo(
            F.center_crop_mask,
            sample_inputs_fn=sample_inputs_center_crop_mask,
            reference_fn=pil_reference_wrapper(F.center_crop_image_pil),
            reference_inputs_fn=reference_inputs_center_crop_mask,
            float32_vs_uint8=True,
            test_marks=[
                xfail_jit_python_scalar_arg("output_size"),
            ],
        ),
        KernelInfo(
            F.center_crop_video,
            sample_inputs_fn=sample_inputs_center_crop_video,
        ),
    ]
)


def sample_inputs_gaussian_blur_image_tensor():
    make_gaussian_blur_image_loaders = functools.partial(make_image_loaders, sizes=[(7, 33)], color_spaces=["RGB"])

    for image_loader, kernel_size in itertools.product(make_gaussian_blur_image_loaders(), [5, (3, 3), [3, 3]]):
        yield ArgsKwargs(image_loader, kernel_size=kernel_size)

    for image_loader, sigma in itertools.product(
        make_gaussian_blur_image_loaders(), [None, (3.0, 3.0), [2.0, 2.0], 4.0, [1.5], (3.14,)]
    ):
        yield ArgsKwargs(image_loader, kernel_size=5, sigma=sigma)


def sample_inputs_gaussian_blur_video():
    for video_loader in make_video_loaders(sizes=[(7, 33)], num_frames=[5]):
        yield ArgsKwargs(video_loader, kernel_size=[3, 3])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.gaussian_blur_image_tensor,
            sample_inputs_fn=sample_inputs_gaussian_blur_image_tensor,
            closeness_kwargs=cuda_vs_cpu_pixel_difference(),
            test_marks=[
                xfail_jit_python_scalar_arg("kernel_size"),
                xfail_jit_python_scalar_arg("sigma"),
            ],
        ),
        KernelInfo(
            F.gaussian_blur_video,
            sample_inputs_fn=sample_inputs_gaussian_blur_video,
            closeness_kwargs=cuda_vs_cpu_pixel_difference(),
        ),
    ]
)


def sample_inputs_equalize_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader)


def reference_inputs_equalize_image_tensor():
    # We are not using `make_image_loaders` here since that uniformly samples the values over the whole value range.
    # Since the whole point of this kernel is to transform an arbitrary distribution of values into a uniform one,
    # the information gain is low if we already provide something really close to the expected value.
    def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor, memory_format):
        if dtype.is_floating_point:
            low = low_factor
            high = high_factor
        else:
            max_value = torch.iinfo(dtype).max
            low = int(low_factor * max_value)
            high = int(high_factor * max_value)
        return torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high).to(
            memory_format=memory_format, copy=True
        )

    def make_beta_distributed_image(shape, dtype, device, *, alpha, beta, memory_format):
        image = torch.distributions.Beta(alpha, beta).sample(shape)
        if not dtype.is_floating_point:
            image.mul_(torch.iinfo(dtype).max).round_()
        return image.to(dtype=dtype, device=device, memory_format=memory_format, copy=True)

    spatial_size = (256, 256)
    for dtype, color_space, fn in itertools.product(
        [torch.uint8],
        ["GRAY", "RGB"],
        [
            lambda shape, dtype, device, memory_format: torch.zeros(shape, dtype=dtype, device=device).to(
                memory_format=memory_format, copy=True
            ),
            lambda shape, dtype, device, memory_format: torch.full(
                shape, 1.0 if dtype.is_floating_point else torch.iinfo(dtype).max, dtype=dtype, device=device
            ).to(memory_format=memory_format, copy=True),
            *[
                functools.partial(make_uniform_band_image, low_factor=low_factor, high_factor=high_factor)
                for low_factor, high_factor in [
                    (0.0, 0.25),
                    (0.25, 0.75),
                    (0.75, 1.0),
                ]
            ],
            *[
                functools.partial(make_beta_distributed_image, alpha=alpha, beta=beta)
                for alpha, beta in [
                    (0.5, 0.5),
                    (2, 2),
                    (2, 5),
                    (5, 2),
                ]
            ],
        ],
    ):
        image_loader = ImageLoader(fn, shape=(get_num_channels(color_space), *spatial_size), dtype=dtype)
        yield ArgsKwargs(image_loader)


def sample_inputs_equalize_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.equalize_image_tensor,
            kernel_name="equalize_image_tensor",
            sample_inputs_fn=sample_inputs_equalize_image_tensor,
            reference_fn=pil_reference_wrapper(F.equalize_image_pil),
            float32_vs_uint8=True,
            reference_inputs_fn=reference_inputs_equalize_image_tensor,
        ),
        KernelInfo(
            F.equalize_video,
            sample_inputs_fn=sample_inputs_equalize_video,
        ),
    ]
)


def sample_inputs_invert_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader)


def reference_inputs_invert_image_tensor():
    for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
        yield ArgsKwargs(image_loader)


def sample_inputs_invert_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.invert_image_tensor,
            kernel_name="invert_image_tensor",
            sample_inputs_fn=sample_inputs_invert_image_tensor,
            reference_fn=pil_reference_wrapper(F.invert_image_pil),
            reference_inputs_fn=reference_inputs_invert_image_tensor,
            float32_vs_uint8=True,
        ),
        KernelInfo(
            F.invert_video,
            sample_inputs_fn=sample_inputs_invert_video,
        ),
    ]
)


_POSTERIZE_BITS = [1, 4, 8]


def sample_inputs_posterize_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, bits=_POSTERIZE_BITS[0])


def reference_inputs_posterize_image_tensor():
    for image_loader, bits in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _POSTERIZE_BITS,
    ):
        yield ArgsKwargs(image_loader, bits=bits)


def sample_inputs_posterize_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, bits=_POSTERIZE_BITS[0])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.posterize_image_tensor,
            kernel_name="posterize_image_tensor",
            sample_inputs_fn=sample_inputs_posterize_image_tensor,
            reference_fn=pil_reference_wrapper(F.posterize_image_pil),
            reference_inputs_fn=reference_inputs_posterize_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs=float32_vs_uint8_pixel_difference(),
        ),
        KernelInfo(
            F.posterize_video,
            sample_inputs_fn=sample_inputs_posterize_video,
        ),
    ]
)


def _get_solarize_thresholds(dtype):
    for factor in [0.1, 0.5]:
        max_value = get_max_value(dtype)
        yield (float if dtype.is_floating_point else int)(max_value * factor)
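
# For example, for torch.uint8 this yields the integer thresholds 25 and 127, while for torch.float32
# it yields 0.1 and 0.5.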


def sample_inputs_solarize_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, threshold=next(_get_solarize_thresholds(image_loader.dtype)))


def reference_inputs_solarize_image_tensor():
    for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
        for threshold in _get_solarize_thresholds(image_loader.dtype):
            yield ArgsKwargs(image_loader, threshold=threshold)


def uint8_to_float32_threshold_adapter(other_args, kwargs):
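    # Rescale the uint8 threshold into the float range, mirroring float32_vs_uint8_fill_adapter. Note
    # that this rebuilds kwargs from scratch, which assumes `threshold` is the only keyword argument.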
    return other_args, dict(threshold=kwargs["threshold"] / 255)


def sample_inputs_solarize_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, threshold=next(_get_solarize_thresholds(video_loader.dtype)))


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.solarize_image_tensor,
            kernel_name="solarize_image_tensor",
            sample_inputs_fn=sample_inputs_solarize_image_tensor,
            reference_fn=pil_reference_wrapper(F.solarize_image_pil),
            reference_inputs_fn=reference_inputs_solarize_image_tensor,
            float32_vs_uint8=uint8_to_float32_threshold_adapter,
            closeness_kwargs=float32_vs_uint8_pixel_difference(),
        ),
        KernelInfo(
            F.solarize_video,
            sample_inputs_fn=sample_inputs_solarize_video,
        ),
    ]
)


def sample_inputs_autocontrast_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader)


def reference_inputs_autocontrast_image_tensor():
    for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
        yield ArgsKwargs(image_loader)


def sample_inputs_autocontrast_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.autocontrast_image_tensor,
            kernel_name="autocontrast_image_tensor",
            sample_inputs_fn=sample_inputs_autocontrast_image_tensor,
            reference_fn=pil_reference_wrapper(F.autocontrast_image_pil),
            reference_inputs_fn=reference_inputs_autocontrast_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **pil_reference_pixel_difference(),
                **float32_vs_uint8_pixel_difference(),
            },
        ),
        KernelInfo(
            F.autocontrast_video,
            sample_inputs_fn=sample_inputs_autocontrast_video,
        ),
    ]
)

_ADJUST_SHARPNESS_FACTORS = [0.1, 0.5]


def sample_inputs_adjust_sharpness_image_tensor():
    for image_loader in make_image_loaders(
        sizes=["random", (2, 2)],
        color_spaces=("GRAY", "RGB"),
    ):
        yield ArgsKwargs(image_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0])


def reference_inputs_adjust_sharpness_image_tensor():
    for image_loader, sharpness_factor in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _ADJUST_SHARPNESS_FACTORS,
    ):
        yield ArgsKwargs(image_loader, sharpness_factor=sharpness_factor)


def sample_inputs_adjust_sharpness_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.adjust_sharpness_image_tensor,
            kernel_name="adjust_sharpness_image_tensor",
            sample_inputs_fn=sample_inputs_adjust_sharpness_image_tensor,
            reference_fn=pil_reference_wrapper(F.adjust_sharpness_image_pil),
            reference_inputs_fn=reference_inputs_adjust_sharpness_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs=float32_vs_uint8_pixel_difference(2),
        ),
        KernelInfo(
            F.adjust_sharpness_video,
            sample_inputs_fn=sample_inputs_adjust_sharpness_video,
        ),
    ]
)


def sample_inputs_erase_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"]):
        # FIXME: make the parameters more diverse
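        # `erase` fills the h x w patch whose top-left corner is at (i, j) with
        # the values in v, so v must provide one value per channel and pixel.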
        h, w = 6, 7
        v = torch.rand(image_loader.num_channels, h, w)
        yield ArgsKwargs(image_loader, i=1, j=2, h=h, w=w, v=v)


def sample_inputs_erase_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        # FIXME: make the parameters more diverse
        h, w = 6, 7
        v = torch.rand(video_loader.num_channels, h, w)
        yield ArgsKwargs(video_loader, i=1, j=2, h=h, w=w, v=v)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.erase_image_tensor,
            kernel_name="erase_image_tensor",
            sample_inputs_fn=sample_inputs_erase_image_tensor,
        ),
        KernelInfo(
            F.erase_video,
            sample_inputs_fn=sample_inputs_erase_video,
        ),
    ]
)

_ADJUST_BRIGHTNESS_FACTORS = [0.1, 0.5]


def sample_inputs_adjust_brightness_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0])


def reference_inputs_adjust_brightness_image_tensor():
    for image_loader, brightness_factor in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _ADJUST_BRIGHTNESS_FACTORS,
    ):
        yield ArgsKwargs(image_loader, brightness_factor=brightness_factor)


def sample_inputs_adjust_brightness_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.adjust_brightness_image_tensor,
            kernel_name="adjust_brightness_image_tensor",
            sample_inputs_fn=sample_inputs_adjust_brightness_image_tensor,
            reference_fn=pil_reference_wrapper(F.adjust_brightness_image_pil),
            reference_inputs_fn=reference_inputs_adjust_brightness_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs=float32_vs_uint8_pixel_difference(),
        ),
        KernelInfo(
            F.adjust_brightness_video,
            sample_inputs_fn=sample_inputs_adjust_brightness_video,
        ),
    ]
)


_ADJUST_CONTRAST_FACTORS = [0.1, 0.5]


def sample_inputs_adjust_contrast_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0])


def reference_inputs_adjust_contrast_image_tensor():
    for image_loader, contrast_factor in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _ADJUST_CONTRAST_FACTORS,
    ):
        yield ArgsKwargs(image_loader, contrast_factor=contrast_factor)


def sample_inputs_adjust_contrast_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.adjust_contrast_image_tensor,
            kernel_name="adjust_contrast_image_tensor",
            sample_inputs_fn=sample_inputs_adjust_contrast_image_tensor,
            reference_fn=pil_reference_wrapper(F.adjust_contrast_image_pil),
            reference_inputs_fn=reference_inputs_adjust_contrast_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **pil_reference_pixel_difference(),
                **float32_vs_uint8_pixel_difference(2),
                **cuda_vs_cpu_pixel_difference(),
                (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): pixel_difference_closeness_kwargs(1),
            },
        ),
        KernelInfo(
            F.adjust_contrast_video,
            sample_inputs_fn=sample_inputs_adjust_contrast_video,
            closeness_kwargs={
                **cuda_vs_cpu_pixel_difference(),
                (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): pixel_difference_closeness_kwargs(1),
            },
        ),
    ]
)

_ADJUST_GAMMA_GAMMAS_GAINS = [
    (0.5, 2.0),
    (0.0, 1.0),
]


def sample_inputs_adjust_gamma_image_tensor():
    gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0]
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, gamma=gamma, gain=gain)


def reference_inputs_adjust_gamma_image_tensor():
    for image_loader, (gamma, gain) in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _ADJUST_GAMMA_GAMMAS_GAINS,
    ):
        yield ArgsKwargs(image_loader, gamma=gamma, gain=gain)


def sample_inputs_adjust_gamma_video():
    gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0]
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, gamma=gamma, gain=gain)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.adjust_gamma_image_tensor,
            kernel_name="adjust_gamma_image_tensor",
            sample_inputs_fn=sample_inputs_adjust_gamma_image_tensor,
            reference_fn=pil_reference_wrapper(F.adjust_gamma_image_pil),
            reference_inputs_fn=reference_inputs_adjust_gamma_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **pil_reference_pixel_difference(),
                **float32_vs_uint8_pixel_difference(),
            },
        ),
        KernelInfo(
            F.adjust_gamma_video,
            sample_inputs_fn=sample_inputs_adjust_gamma_video,
        ),
    ]
)


_ADJUST_HUE_FACTORS = [-0.1, 0.5]


def sample_inputs_adjust_hue_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, hue_factor=_ADJUST_HUE_FACTORS[0])


def reference_inputs_adjust_hue_image_tensor():
    for image_loader, hue_factor in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _ADJUST_HUE_FACTORS,
    ):
        yield ArgsKwargs(image_loader, hue_factor=hue_factor)


def sample_inputs_adjust_hue_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, hue_factor=_ADJUST_HUE_FACTORS[0])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.adjust_hue_image_tensor,
            kernel_name="adjust_hue_image_tensor",
            sample_inputs_fn=sample_inputs_adjust_hue_image_tensor,
            reference_fn=pil_reference_wrapper(F.adjust_hue_image_pil),
            reference_inputs_fn=reference_inputs_adjust_hue_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **pil_reference_pixel_difference(2, mae=True),
                **float32_vs_uint8_pixel_difference(),
            },
        ),
        KernelInfo(
            F.adjust_hue_video,
            sample_inputs_fn=sample_inputs_adjust_hue_video,
        ),
    ]
)

_ADJUST_SATURATION_FACTORS = [0.1, 0.5]


def sample_inputs_adjust_saturation_image_tensor():
    for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0])


def reference_inputs_adjust_saturation_image_tensor():
    for image_loader, saturation_factor in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _ADJUST_SATURATION_FACTORS,
    ):
        yield ArgsKwargs(image_loader, saturation_factor=saturation_factor)


def sample_inputs_adjust_saturation_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.adjust_saturation_image_tensor,
            kernel_name="adjust_saturation_image_tensor",
            sample_inputs_fn=sample_inputs_adjust_saturation_image_tensor,
            reference_fn=pil_reference_wrapper(F.adjust_saturation_image_pil),
            reference_inputs_fn=reference_inputs_adjust_saturation_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **pil_reference_pixel_difference(),
                **float32_vs_uint8_pixel_difference(2),
                **cuda_vs_cpu_pixel_difference(),
            },
        ),
        KernelInfo(
            F.adjust_saturation_video,
            sample_inputs_fn=sample_inputs_adjust_saturation_video,
            closeness_kwargs=cuda_vs_cpu_pixel_difference(),
        ),
    ]
)


def sample_inputs_clamp_bounding_box():
    for bounding_box_loader in make_bounding_box_loaders():
        yield ArgsKwargs(
            bounding_box_loader,
            format=bounding_box_loader.format,
            spatial_size=bounding_box_loader.spatial_size,
        )


KERNEL_INFOS.append(
    KernelInfo(
        F.clamp_bounding_box,
        sample_inputs_fn=sample_inputs_clamp_bounding_box,
        logs_usage=True,
    )
)

_FIVE_TEN_CROP_SIZES = [7, (6,), [5], (6, 5), [7, 6]]
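# Covers all accepted forms for `size`: a bare int, single-element sequences,
# and explicit (height, width) pairs.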


def _get_five_ten_crop_spatial_size(size):
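    # Returns an image size twice the crop size in each dimension,
    # e.g. 7 -> (14, 14) and (6, 5) -> (12, 10).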
    if isinstance(size, int):
        crop_height = crop_width = size
    elif len(size) == 1:
        crop_height = crop_width = size[0]
    else:
        crop_height, crop_width = size
    return 2 * crop_height, 2 * crop_width


def sample_inputs_five_crop_image_tensor():
    for size in _FIVE_TEN_CROP_SIZES:
        for image_loader in make_image_loaders(
            sizes=[_get_five_ten_crop_spatial_size(size)],
            color_spaces=["RGB"],
            dtypes=[torch.float32],
        ):
            yield ArgsKwargs(image_loader, size=size)


def reference_inputs_five_crop_image_tensor():
    for size in _FIVE_TEN_CROP_SIZES:
        for image_loader in make_image_loaders(
            sizes=[_get_five_ten_crop_spatial_size(size)], extra_dims=[()], dtypes=[torch.uint8]
        ):
            yield ArgsKwargs(image_loader, size=size)


def sample_inputs_five_crop_video():
    size = _FIVE_TEN_CROP_SIZES[0]
    for video_loader in make_video_loaders(sizes=[_get_five_ten_crop_spatial_size(size)]):
        yield ArgsKwargs(video_loader, size=size)


def sample_inputs_ten_crop_image_tensor():
    for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
        for image_loader in make_image_loaders(
            sizes=[_get_five_ten_crop_spatial_size(size)],
            color_spaces=["RGB"],
            dtypes=[torch.float32],
        ):
            yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip)


def reference_inputs_ten_crop_image_tensor():
    for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
        for image_loader in make_image_loaders(
            sizes=[_get_five_ten_crop_spatial_size(size)], extra_dims=[()], dtypes=[torch.uint8]
        ):
            yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip)


def sample_inputs_ten_crop_video():
    size = _FIVE_TEN_CROP_SIZES[0]
    for video_loader in make_video_loaders(sizes=[_get_five_ten_crop_spatial_size(size)]):
        yield ArgsKwargs(video_loader, size=size)


def multi_crop_pil_reference_wrapper(pil_kernel):
    def wrapper(input_tensor, *other_args, **kwargs):
        output = pil_reference_wrapper(pil_kernel)(input_tensor, *other_args, **kwargs)
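        # The PIL five / ten crop kernels return a sequence of images, so each
        # element is converted back to a tensor with the input's dtype before
        # the comparison.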
        return type(output)(
            F.convert_dtype_image_tensor(F.to_image_tensor(output_pil), dtype=input_tensor.dtype)
            for output_pil in output
        )

    return wrapper


_common_five_ten_crop_marks = [
    xfail_jit_python_scalar_arg("size"),
    mark_framework_limitation(("TestKernels", "test_batched_vs_single"), "Custom batching needed."),
]

KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.five_crop_image_tensor,
            sample_inputs_fn=sample_inputs_five_crop_image_tensor,
            reference_fn=multi_crop_pil_reference_wrapper(F.five_crop_image_pil),
            reference_inputs_fn=reference_inputs_five_crop_image_tensor,
            test_marks=_common_five_ten_crop_marks,
        ),
        KernelInfo(
            F.five_crop_video,
            sample_inputs_fn=sample_inputs_five_crop_video,
            test_marks=_common_five_ten_crop_marks,
        ),
        KernelInfo(
            F.ten_crop_image_tensor,
            sample_inputs_fn=sample_inputs_ten_crop_image_tensor,
            reference_fn=multi_crop_pil_reference_wrapper(F.ten_crop_image_pil),
            reference_inputs_fn=reference_inputs_ten_crop_image_tensor,
            test_marks=_common_five_ten_crop_marks,
        ),
        KernelInfo(
            F.ten_crop_video,
            sample_inputs_fn=sample_inputs_ten_crop_video,
            test_marks=_common_five_ten_crop_marks,
        ),
    ]
)

_NORMALIZE_MEANS_STDS = [
    ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
    (0.5, 2.0),
]
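# The scalar entry above additionally checks that plain numbers are broadcast
# over all channels.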


def sample_inputs_normalize_image_tensor():
    for image_loader, (mean, std) in itertools.product(
        make_image_loaders(sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]),
        _NORMALIZE_MEANS_STDS,
    ):
        yield ArgsKwargs(image_loader, mean=mean, std=std)


def reference_normalize_image_tensor(image, mean, std, inplace=False):
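    # Reference implemented with plain tensor arithmetic:
    # output = (image - mean) / std, with mean and std broadcast per channel.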
    mean = torch.tensor(mean).view(-1, 1, 1)
    std = torch.tensor(std).view(-1, 1, 1)

    sub = torch.Tensor.sub_ if inplace else torch.Tensor.sub
    return sub(image, mean).div_(std)


def reference_inputs_normalize_image_tensor():
    yield ArgsKwargs(
        make_image_loader(size=(32, 32), color_space="RGB", extra_dims=[1]),
        mean=[0.5, 0.5, 0.5],
        std=[1.0, 1.0, 1.0],
    )


def sample_inputs_normalize_video():
    mean, std = _NORMALIZE_MEANS_STDS[0]
    for video_loader in make_video_loaders(
        sizes=["random"], color_spaces=["RGB"], num_frames=["random"], dtypes=[torch.float32]
    ):
        yield ArgsKwargs(video_loader, mean=mean, std=std)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.normalize_image_tensor,
            kernel_name="normalize_image_tensor",
            sample_inputs_fn=sample_inputs_normalize_image_tensor,
            reference_fn=reference_normalize_image_tensor,
            reference_inputs_fn=reference_inputs_normalize_image_tensor,
            test_marks=[
                xfail_jit_python_scalar_arg("mean"),
                xfail_jit_python_scalar_arg("std"),
            ],
        ),
        KernelInfo(
            F.normalize_video,
            sample_inputs_fn=sample_inputs_normalize_video,
        ),
    ]
)


def sample_inputs_convert_dtype_image_tensor():
    for input_dtype, output_dtype in itertools.product(
        [torch.uint8, torch.int64, torch.float32, torch.float64], repeat=2
    ):
        if input_dtype.is_floating_point and output_dtype == torch.int64:
            # conversion cannot be performed safely
            continue

        for image_loader in make_image_loaders(sizes=["random"], color_spaces=["RGB"], dtypes=[input_dtype]):
            yield ArgsKwargs(image_loader, dtype=output_dtype)


def reference_convert_dtype_image_tensor(image, dtype=torch.float):
    input_dtype = image.dtype
    output_dtype = dtype

    if output_dtype == input_dtype:
        return image

    def fn(value):
        if input_dtype.is_floating_point:
            if output_dtype.is_floating_point:
                return value
            else:
                return int(decimal.Decimal(value) * torch.iinfo(output_dtype).max)
        else:
            input_max_value = torch.iinfo(input_dtype).max

            if output_dtype.is_floating_point:
                return float(decimal.Decimal(value) / input_max_value)
            else:
                output_max_value = torch.iinfo(output_dtype).max

                if input_max_value > output_max_value:
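                    # Integer-to-integer conversion scales by the ratio of the
                    # two value ranges, e.g. int16 -> uint8 divides by
                    # (32767 + 1) // (255 + 1) == 128, and uint8 -> int16
                    # multiplies by the same factor.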
                    factor = (input_max_value + 1) // (output_max_value + 1)
                    return value // factor
                else:
                    factor = (output_max_value + 1) // (input_max_value + 1)
                    return value * factor

    return torch.tensor(tree_map(fn, image.tolist()), dtype=dtype)


def reference_inputs_convert_dtype_image_tensor():
    for input_dtype, output_dtype in itertools.product(
        [
            torch.uint8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.float16,
            torch.float32,
            torch.float64,
            torch.bfloat16,
        ],
        repeat=2,
    ):
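        # Skipped below: float32 (24-bit mantissa) cannot represent the int32 /
        # int64 maximum exactly, and float64 (53-bit mantissa) cannot represent
        # the int64 maximum, so these conversions are lossy.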
        if (input_dtype == torch.float32 and output_dtype in {torch.int32, torch.int64}) or (
            input_dtype == torch.float64 and output_dtype == torch.int64
        ):
            continue

        if input_dtype.is_floating_point:
            data = [0.0, 0.5, 1.0]
        else:
            max_value = torch.iinfo(input_dtype).max
            data = [0, max_value // 2, max_value]
        image = torch.tensor(data, dtype=input_dtype)

        yield ArgsKwargs(image, dtype=output_dtype)


def sample_inputs_convert_dtype_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
        yield ArgsKwargs(video_loader)


skip_dtype_consistency = TestMark(
    ("TestKernels", "test_dtype_and_device_consistency"),
    pytest.mark.skip(reason="`convert_dtype_*` kernels convert the dtype by design"),
    condition=lambda args_kwargs: args_kwargs.args[0].dtype != args_kwargs.kwargs.get("dtype", torch.float32),
)
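# Shared by both `convert_dtype_*` kernels below: changing the dtype is the
# kernels' entire purpose, so the generic consistency check is only meaningful
# when the requested dtype already matches the input dtype.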

KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.convert_dtype_image_tensor,
            sample_inputs_fn=sample_inputs_convert_dtype_image_tensor,
            reference_fn=reference_convert_dtype_image_tensor,
            reference_inputs_fn=reference_inputs_convert_dtype_image_tensor,
            test_marks=[
                skip_dtype_consistency,
                TestMark(
                    ("TestKernels", "test_against_reference"),
                    pytest.mark.xfail(reason="Conversion overflows"),
                    condition=lambda args_kwargs: (
                        args_kwargs.args[0].dtype in {torch.float16, torch.bfloat16}
                        and not args_kwargs.kwargs["dtype"].is_floating_point
                    )
                    or (
                        args_kwargs.args[0].dtype in {torch.int32, torch.int64}
                        and args_kwargs.kwargs["dtype"] == torch.float16
                    ),
                ),
            ],
        ),
        KernelInfo(
            F.convert_dtype_video,
            sample_inputs_fn=sample_inputs_convert_dtype_video,
            test_marks=[
                skip_dtype_consistency,
            ],
        ),
    ]
)


def sample_inputs_uniform_temporal_subsample_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=[4]):
        yield ArgsKwargs(video_loader, num_samples=2)


def reference_uniform_temporal_subsample_video(x, num_samples):
    # Copy-pasted from
    # https://github.com/facebookresearch/pytorchvideo/blob/c8d23d8b7e597586a9e2d18f6ed31ad8aa379a7a/pytorchvideo/transforms/functional.py#L19
    t = x.shape[-4]
    assert num_samples > 0 and t > 0
    # Sample by nearest neighbor interpolation if num_samples > t.
    indices = torch.linspace(0, t - 1, num_samples)
    indices = torch.clamp(indices, 0, t - 1).long()
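    # e.g. t == 10 and num_samples == 5 selects the frames [0, 2, 4, 6, 9].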
    return torch.index_select(x, -4, indices)


def reference_inputs_uniform_temporal_subsample_video():
    for video_loader in make_video_loaders(sizes=["random"], color_spaces=["RGB"], num_frames=[10]):
        for num_samples in range(1, video_loader.shape[-4] + 1):
            yield ArgsKwargs(video_loader, num_samples)


KERNEL_INFOS.append(
    KernelInfo(
        F.uniform_temporal_subsample_video,
        sample_inputs_fn=sample_inputs_uniform_temporal_subsample_video,
        reference_fn=reference_uniform_temporal_subsample_video,
        reference_inputs_fn=reference_inputs_uniform_temporal_subsample_video,
    )
)