transforms_v2_kernel_infos.py 55.8 KB
Newer Older
1
2
3
4
import functools
import itertools

import numpy as np
5
import PIL.Image
6
7
import pytest
import torch.testing
8
import torchvision.ops
9
import torchvision.transforms.v2.functional as F
10
from torchvision import tv_tensors
11
12
from torchvision.transforms._functional_tensor import _max_value as get_max_value, _parse_pad_padding
from transforms_v2_legacy_utils import (
13
    ArgsKwargs,
14
    combinations_grid,
15
    DEFAULT_PORTRAIT_SPATIAL_SIZE,
16
17
    get_num_channels,
    ImageLoader,
18
    InfoBase,
19
    make_bounding_box_loader,
20
    make_bounding_box_loaders,
21
    make_detection_mask_loader,
22
23
    make_image_loader,
    make_image_loaders,
24
    make_image_loaders_for_interpolation,
25
    make_mask_loaders,
26
    make_video_loader,
27
    make_video_loaders,
28
29
    mark_framework_limitation,
    TestMark,
30
)
31
32
33
34

__all__ = ["KernelInfo", "KERNEL_INFOS"]


35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
class KernelInfo(InfoBase):
    """Metadata describing how a single low-level kernel should be tested."""

    def __init__(
        self,
        kernel,
        *,
        # Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
        # TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
        kernel_name=None,
        # Most common tests use these inputs to check the kernel. As such it should cover all valid code paths, but
        # should not include extensive parameter combinations to keep to overall test count moderate.
        sample_inputs_fn,
        # This function should mirror the kernel. It should have the same signature as the `kernel` and as such also
        # take tensors as inputs. Any conversion into another object type, e.g. PIL images or numpy arrays, should
        # happen inside the function. It should return a tensor or to be more precise an object that can be compared to
        # a tensor by `assert_close`. If omitted, no reference test will be performed.
        reference_fn=None,
        # These inputs are only used for the reference tests and thus can be comprehensive with regard to the parameter
        # values to be tested. If not specified, `sample_inputs_fn` will be used.
        reference_inputs_fn=None,
        # If true-ish, triggers a test that checks the kernel for consistency between uint8 and float32 inputs with the
        # reference inputs. This is usually used whenever we use a PIL kernel as reference.
        # Can be a callable in which case it will be called with `other_args, kwargs`. It should return the same
        # structure, but with adapted parameters. This is useful in case a parameter value is closely tied to the input
        # dtype.
        float32_vs_uint8=False,
        # Some kernels don't have dispatchers that would handle logging the usage. Thus, the kernel has to do it
        # manually. If set, triggers a test that makes sure this happens.
        logs_usage=False,
        # See InfoBase
        test_marks=None,
        # See InfoBase
        closeness_kwargs=None,
    ):
        super().__init__(id=kernel_name or kernel.__name__, test_marks=test_marks, closeness_kwargs=closeness_kwargs)
        self.kernel = kernel
        self.sample_inputs_fn = sample_inputs_fn
        self.reference_fn = reference_fn
        self.reference_inputs_fn = reference_inputs_fn
        self.logs_usage = logs_usage

        if float32_vs_uint8 and not callable(float32_vs_uint8):
            # Normalize a truthy flag into a no-op adapter so downstream code can always call it.
            def _identity_adapter(other_args, kwargs):
                return other_args, kwargs

            float32_vs_uint8 = _identity_adapter
        self.float32_vs_uint8 = float32_vs_uint8
78
79


80
def pixel_difference_closeness_kwargs(uint8_atol, *, dtype=torch.uint8, mae=False):
    """Scale an absolute tolerance given in uint8 units to the value range of ``dtype``."""
    scaled_atol = uint8_atol / 255 * get_max_value(dtype)
    return dict(atol=scaled_atol, rtol=0, mae=mae)
82
83
84
85


def cuda_vs_cpu_pixel_difference(atol=1):
    """Closeness kwargs for the CUDA-vs-CPU consistency test, covering uint8 and float32 inputs."""
    test_id = ("TestKernels", "test_cuda_vs_cpu")
    return {
        (test_id, dtype, "cuda"): pixel_difference_closeness_kwargs(atol, dtype=dtype)
        for dtype in (torch.uint8, torch.float32)
    }


91
def pil_reference_pixel_difference(atol=1, mae=False):
    """Closeness kwargs for comparing uint8 CPU outputs against the PIL reference."""
    key = (("TestKernels", "test_against_reference"), torch.uint8, "cpu")
    return {key: pixel_difference_closeness_kwargs(atol, mae=mae)}


99
def float32_vs_uint8_pixel_difference(atol=1, mae=False):
    """Closeness kwargs for the float32-vs-uint8 consistency test on CPU."""
    key = (("TestKernels", "test_float32_vs_uint8"), torch.float32, "cpu")
    return {key: pixel_difference_closeness_kwargs(atol, dtype=torch.float32, mae=mae)}
107

108

109
def scripted_vs_eager_float64_tolerances(device, atol=1e-6, rtol=1e-6):
    """Tolerances for the scripted-vs-eager test with float64 inputs on the given device."""
    tolerances = {"atol": atol, "rtol": rtol, "mae": False}
    return {(("TestKernels", "test_scripted_vs_eager"), torch.float64, device): tolerances}


115
116
def pil_reference_wrapper(pil_kernel):
    """Wrap a PIL kernel so it accepts and returns tensors, for use as a reference function."""

    @functools.wraps(pil_kernel)
    def wrapper(input_tensor, *other_args, **kwargs):
        if input_tensor.dtype != torch.uint8:
            raise pytest.UsageError(f"Can only test uint8 tensor images against PIL, but input is {input_tensor.dtype}")
        if input_tensor.ndim > 3:
            raise pytest.UsageError(
                f"Can only test single tensor images against PIL, but input has shape {input_tensor.shape}"
            )

        output = pil_kernel(F.to_pil_image(input_tensor), *other_args, **kwargs)
        # Some PIL kernels return non-image results; pass those through untouched.
        if not isinstance(output, PIL.Image.Image):
            return output

        output_tensor = F.to_image(output)

        # 2D mask shenanigans: align the output rank with the input rank.
        if output_tensor.ndim == 2 and input_tensor.ndim == 3:
            output_tensor = output_tensor.unsqueeze(0)
        elif output_tensor.ndim == 3 and input_tensor.ndim == 2:
            output_tensor = output_tensor.squeeze(0)

        return output_tensor

    return wrapper


143
144
145
146
def xfail_jit(reason, *, condition=None):
    """Build a test mark that xfails the scripted-vs-eager test with the given reason."""
    mark = pytest.mark.xfail(reason=reason)
    return TestMark(("TestKernels", "test_scripted_vs_eager"), mark, condition=condition)


147
def xfail_jit_python_scalar_arg(name, *, reason=None):
    """xfail scripting whenever the kwarg ``name`` is passed as a plain Python int or float."""

    def is_python_scalar(args_kwargs):
        return isinstance(args_kwargs.kwargs.get(name), (int, float))

    return xfail_jit(
        reason or f"Python scalar int or float for `{name}` is not supported when scripting",
        condition=is_python_scalar,
    )


154
155
156
# Aggregated list of KernelInfo entries; populated throughout this module and exported via __all__.
KERNEL_INFOS = []


157
def get_fills(*, num_channels, dtype):
    """Yield representative ``fill`` values: None, int/float scalars, and list/tuple sequences."""
    yield None

    max_int = get_max_value(dtype)
    half_max = max_int / 2
    yield max_int
    yield half_max

    for container in (list, tuple):
        yield container([max_int])
        yield container([half_max])

        if num_channels > 1:
            # Per-channel fills with varying values, only meaningful for multi-channel inputs.
            yield container(half_max * channel / 10 for channel in range(num_channels))
            yield container(max_int if channel % 2 == 0 else 0 for channel in range(num_channels))
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186


def float32_vs_uint8_fill_adapter(other_args, kwargs):
    fill = kwargs.get("fill")
    if fill is None:
        return other_args, kwargs

    if isinstance(fill, (int, float)):
        fill /= 255
    else:
        fill = type(fill)(fill_ / 255 for fill_ in fill)

    return other_args, dict(kwargs, fill=fill)


Philip Meier's avatar
Philip Meier committed
187
188
def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_size, affine_matrix):
    """Apply a 2x3 affine matrix to every box and return clamped boxes in the original format and dtype."""

    def transform(bbox, affine_matrix_, format_, canvas_size_):
        # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
        in_dtype = bbox.dtype
        if not torch.is_floating_point(bbox):
            bbox = bbox.float()
        bbox_xyxy = F.convert_bounding_box_format(
            bbox.as_subclass(torch.Tensor),
            old_format=format_,
            new_format=tv_tensors.BoundingBoxFormat.XYXY,
            inplace=True,
        )
        # The four box corners in homogeneous coordinates, so the translation column of the matrix applies.
        points = np.array(
            [
                [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
                [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
                [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
            ]
        )
        transformed_points = np.matmul(points, affine_matrix_.T)
        # The transformed box is the axis-aligned envelope of the four transformed corners.
        out_bbox = torch.tensor(
            [
                np.min(transformed_points[:, 0]).item(),
                np.min(transformed_points[:, 1]).item(),
                np.max(transformed_points[:, 0]).item(),
                np.max(transformed_points[:, 1]).item(),
            ],
            dtype=bbox_xyxy.dtype,
        )
        out_bbox = F.convert_bounding_box_format(
            out_bbox, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
        )
        # It is important to clamp before casting, especially for CXCYWH format, dtype=int64
        out_bbox = F.clamp_bounding_boxes(out_bbox, format=format_, canvas_size=canvas_size_)
        out_bbox = out_bbox.to(dtype=in_dtype)
        return out_bbox

    # Transform the boxes one by one and restore the original (possibly batched) shape afterwards.
    return torch.stack(
        [transform(b, affine_matrix, format, canvas_size) for b in bounding_boxes.reshape(-1, 4).unbind()]
    ).reshape(bounding_boxes.shape)
228
229


Nicolas Hug's avatar
Nicolas Hug committed
230
def sample_inputs_convert_bounding_box_format():
    """Yield every bounding box format combined with every target format."""
    formats = list(tv_tensors.BoundingBoxFormat)
    for loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
        yield ArgsKwargs(loader, old_format=loader.format, new_format=new_format)
234
235


Nicolas Hug's avatar
Nicolas Hug committed
236
def reference_convert_bounding_box_format(bounding_boxes, old_format, new_format):
    """Reference via ``torchvision.ops.box_convert``, cast back to the input dtype."""
    converted = torchvision.ops.box_convert(
        bounding_boxes, in_fmt=old_format.name.lower(), out_fmt=new_format.name.lower()
    )
    return converted.to(bounding_boxes.dtype)
240
241


Nicolas Hug's avatar
Nicolas Hug committed
242
243
def reference_inputs_convert_bounding_box_format():
    """Restrict the reference test to 2D (unbatched) bounding box inputs."""
    for args_kwargs in sample_inputs_convert_bounding_box_format():
        if len(args_kwargs.args[0].shape) == 2:
            yield args_kwargs
246
247
248
249


# Register the bounding box format conversion kernel.
KERNEL_INFOS.append(
    KernelInfo(
        F.convert_bounding_box_format,
        sample_inputs_fn=sample_inputs_convert_bounding_box_format,
        reference_fn=reference_convert_bounding_box_format,
        reference_inputs_fn=reference_inputs_convert_bounding_box_format,
        logs_usage=True,
        closeness_kwargs={
            # Allow an off-by-one difference against box_convert for integer boxes.
            (("TestKernels", "test_against_reference"), torch.int64, "cpu"): dict(atol=1, rtol=0),
        },
    ),
)


262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
# Shared parameter grid for all resized-crop sample and reference inputs.
_RESIZED_CROP_PARAMS = combinations_grid(top=[-8, 9], left=[-8, 9], height=[12], width=[12], size=[(16, 18)])


def sample_inputs_resized_crop_image_tensor():
    """One parameter combination per image loader keeps the sample set small."""
    for loader in make_image_loaders():
        yield ArgsKwargs(loader, **_RESIZED_CROP_PARAMS[0])


@pil_reference_wrapper
def reference_resized_crop_image_tensor(*args, **kwargs):
    """PIL reference for resized crop; rejects configurations PIL cannot reproduce."""
    antialias = kwargs.pop("antialias", False)
    interpolation = kwargs.get("interpolation", F.InterpolationMode.BILINEAR)
    if not antialias and interpolation in {F.InterpolationMode.BILINEAR, F.InterpolationMode.BICUBIC}:
        raise pytest.UsageError("Anti-aliasing is always active in PIL")
    return F._resized_crop_image_pil(*args, **kwargs)
278
279
280
281


def reference_inputs_resized_crop_image_tensor():
    """Comprehensive resized-crop inputs; antialias mirrors the modes PIL always antialiases."""
    interpolations = [
        F.InterpolationMode.NEAREST,
        F.InterpolationMode.NEAREST_EXACT,
        F.InterpolationMode.BILINEAR,
        F.InterpolationMode.BICUBIC,
    ]
    antialiased_modes = {F.InterpolationMode.BILINEAR, F.InterpolationMode.BICUBIC}
    for image_loader, interpolation, params in itertools.product(
        make_image_loaders_for_interpolation(), interpolations, _RESIZED_CROP_PARAMS
    ):
        yield ArgsKwargs(
            image_loader,
            interpolation=interpolation,
            antialias=interpolation in antialiased_modes,
            **params,
        )


303
304
305
def sample_inputs_resized_crop_bounding_boxes():
    """Sample inputs for the bounding box kernel, which additionally takes the box format."""
    for loader in make_bounding_box_loaders():
        yield ArgsKwargs(loader, format=loader.format, **_RESIZED_CROP_PARAMS[0])
306
307
308
309
310
311
312


def sample_inputs_resized_crop_mask():
    """Sample inputs for the resized-crop mask kernel."""
    for loader in make_mask_loaders():
        yield ArgsKwargs(loader, **_RESIZED_CROP_PARAMS[0])


313
def sample_inputs_resized_crop_video():
    """Sample inputs for the resized-crop video kernel, kept small via one size and few frames."""
    for loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(loader, **_RESIZED_CROP_PARAMS[0])


318
319
320
# Register the resized-crop kernels for all datapoint kinds.
KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.resized_crop_image,
            sample_inputs_fn=sample_inputs_resized_crop_image_tensor,
            reference_fn=reference_resized_crop_image_tensor,
            reference_inputs_fn=reference_inputs_resized_crop_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **cuda_vs_cpu_pixel_difference(),
                **pil_reference_pixel_difference(3, mae=True),
                **float32_vs_uint8_pixel_difference(3, mae=True),
            },
        ),
        KernelInfo(
            F.resized_crop_bounding_boxes,
            sample_inputs_fn=sample_inputs_resized_crop_bounding_boxes,
        ),
        KernelInfo(
            F.resized_crop_mask,
            sample_inputs_fn=sample_inputs_resized_crop_mask,
        ),
        KernelInfo(
            F.resized_crop_video,
            sample_inputs_fn=sample_inputs_resized_crop_video,
            closeness_kwargs=cuda_vs_cpu_pixel_difference(),
        ),
    ]
)

# Parameter grid shared by the pad reference inputs.
_PAD_PARAMS = combinations_grid(
    padding=[[1], [1, 1], [1, 1, 2, 2]],
    padding_mode=["constant", "symmetric", "edge", "reflect"],
)


def sample_inputs_pad_image_tensor():
    """Sample pad inputs covering padding types, fill values, padding modes, and negative padding."""
    make_pad_image_loaders = functools.partial(
        make_image_loaders, sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=["RGB"], dtypes=[torch.float32]
    )

    paddings = [1, (1,), (1, 2), (1, 2, 3, 4), [1], [1, 2], [1, 2, 3, 4]]
    for image_loader, padding in itertools.product(make_pad_image_loaders(), paddings):
        yield ArgsKwargs(image_loader, padding=padding)

    for image_loader in make_pad_image_loaders():
        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
            yield ArgsKwargs(image_loader, padding=[1], fill=fill)

    # We branch for non-constant padding and integer inputs
    for image_loader, padding_mode in itertools.product(
        make_pad_image_loaders(dtypes=[torch.uint8]),
        ["constant", "symmetric", "edge", "reflect"],
    ):
        yield ArgsKwargs(image_loader, padding=[1], padding_mode=padding_mode)

    # `torch.nn.functional.pad` does not support symmetric padding, and thus we have a custom implementation. Besides
    # negative padding, this is already handled by the inputs above.
    for image_loader in make_pad_image_loaders():
        yield ArgsKwargs(image_loader, padding=[-1], padding_mode="symmetric")
380
381
382


def reference_inputs_pad_image_tensor():
    """Comprehensive pad inputs for the PIL reference test (uint8 images without batch dims)."""
    loaders = make_image_loaders(extra_dims=[()], dtypes=[torch.uint8])
    for image_loader, params in itertools.product(loaders, _PAD_PARAMS):
        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
            # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
            if not isinstance(fill, (list, tuple)):
                yield ArgsKwargs(image_loader, fill=fill, **params)


397
398
def sample_inputs_pad_bounding_boxes():
    """Sample pad inputs for bounding boxes; the padding mode is fixed to constant."""
    paddings = [1, (1,), (1, 2), (1, 2, 3, 4), [1], [1, 2], [1, 2, 3, 4]]
    for loader, padding in itertools.product(make_bounding_box_loaders(), paddings):
        yield ArgsKwargs(
            loader,
            format=loader.format,
            canvas_size=loader.canvas_size,
            padding=padding,
            padding_mode="constant",
        )
408
409
410


def sample_inputs_pad_mask():
    """Sample pad inputs for masks."""
    for loader in make_mask_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_categories=[10], num_objects=[5]):
        yield ArgsKwargs(loader, padding=[1])
413
414
415


def reference_inputs_pad_mask():
    """Comprehensive pad inputs for the mask reference test."""
    loaders = make_mask_loaders(num_objects=[1], extra_dims=[()])
    for mask_loader, fill, params in itertools.product(loaders, [None, 127], _PAD_PARAMS):
        yield ArgsKwargs(mask_loader, fill=fill, **params)
420
421


422
def sample_inputs_pad_video():
    """Sample pad inputs for videos."""
    for loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(loader, padding=[1])


Philip Meier's avatar
Philip Meier committed
427
def reference_pad_bounding_boxes(bounding_boxes, *, format, canvas_size, padding, padding_mode):
    """Reference for pad on bounding boxes: a pure translation by the left/top padding.

    Returns the translated boxes together with the enlarged canvas size.
    """
    left, right, top, bottom = _parse_pad_padding(padding)

    # Padding shifts the image content, so each box is translated by (left, top).
    translation_matrix = np.array(
        [
            [1, 0, left],
            [0, 1, top],
        ],
        dtype="float64" if bounding_boxes.dtype == torch.float64 else "float32",
    )

    new_height = canvas_size[0] + top + bottom
    new_width = canvas_size[1] + left + right
    new_canvas_size = (new_height, new_width)

    expected_bboxes = reference_affine_bounding_boxes_helper(
        bounding_boxes, format=format, canvas_size=new_canvas_size, affine_matrix=translation_matrix
    )
    return expected_bboxes, new_canvas_size


448
449
def reference_inputs_pad_bounding_boxes():
    """Comprehensive pad inputs for the bounding box reference test."""
    paddings = [1, (1,), (1, 2), (1, 2, 3, 4), [1], [1, 2], [1, 2, 3, 4]]
    for loader, padding in itertools.product(make_bounding_box_loaders(extra_dims=((), (4,))), paddings):
        yield ArgsKwargs(
            loader,
            format=loader.format,
            canvas_size=loader.canvas_size,
            padding=padding,
            padding_mode="constant",
        )


461
462
463
464
465
466
467
468
469
470
def pad_xfail_jit_fill_condition(args_kwargs):
    """Return True for fill values that scripted `F.pad` rejects: tuples, and lists of ints."""
    fill = args_kwargs.kwargs.get("fill")
    if isinstance(fill, tuple):
        return True
    if isinstance(fill, list):
        return all(isinstance(f, int) for f in fill)
    return False


471
472
473
# Register the pad kernels for all datapoint kinds.
KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.pad_image,
            sample_inputs_fn=sample_inputs_pad_image_tensor,
            reference_fn=pil_reference_wrapper(F._pad_image_pil),
            reference_inputs_fn=reference_inputs_pad_image_tensor,
            float32_vs_uint8=float32_vs_uint8_fill_adapter,
            closeness_kwargs=float32_vs_uint8_pixel_difference(),
            test_marks=[
                xfail_jit_python_scalar_arg("padding"),
                xfail_jit(
                    "F.pad only supports vector fills for list of floats", condition=pad_xfail_jit_fill_condition
                ),
            ],
        ),
        KernelInfo(
            F.pad_bounding_boxes,
            sample_inputs_fn=sample_inputs_pad_bounding_boxes,
            reference_fn=reference_pad_bounding_boxes,
            reference_inputs_fn=reference_inputs_pad_bounding_boxes,
            test_marks=[
                xfail_jit_python_scalar_arg("padding"),
            ],
        ),
        KernelInfo(
            F.pad_mask,
            sample_inputs_fn=sample_inputs_pad_mask,
            reference_fn=pil_reference_wrapper(F._pad_image_pil),
            reference_inputs_fn=reference_inputs_pad_mask,
            float32_vs_uint8=float32_vs_uint8_fill_adapter,
        ),
        KernelInfo(
            F.pad_video,
            sample_inputs_fn=sample_inputs_pad_video,
        ),
    ]
)

_PERSPECTIVE_COEFFS = [
    [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
    [0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
]
514
515
_STARTPOINTS = [[0, 1], [2, 3], [4, 5], [6, 7]]
_ENDPOINTS = [[9, 8], [7, 6], [5, 4], [3, 2]]
516
517
518


def sample_inputs_perspective_image_tensor():
    """Coefficient-based calls with all fill variants, plus one startpoints/endpoints call."""
    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE]):
        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
            yield ArgsKwargs(
                image_loader, startpoints=None, endpoints=None, fill=fill, coefficients=_PERSPECTIVE_COEFFS[0]
            )

    yield ArgsKwargs(make_image_loader(), startpoints=_STARTPOINTS, endpoints=_ENDPOINTS)
526
527
528


def reference_inputs_perspective_image_tensor():
    """Comprehensive perspective inputs for the PIL reference test."""
    interpolations = [
        F.InterpolationMode.NEAREST,
        F.InterpolationMode.BILINEAR,
    ]
    for image_loader, coefficients, interpolation in itertools.product(
        make_image_loaders_for_interpolation(), _PERSPECTIVE_COEFFS, interpolations
    ):
        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
            # FIXME: PIL kernel doesn't support sequences of length 1 if the number of channels is larger. Shouldn't it?
            if isinstance(fill, (list, tuple)):
                continue

            yield ArgsKwargs(
                image_loader,
                startpoints=None,
                endpoints=None,
                interpolation=interpolation,
                fill=fill,
                coefficients=coefficients,
            )
550
551


552
553
def sample_inputs_perspective_bounding_boxes():
    """Sample perspective inputs for bounding boxes, plus one startpoints/endpoints call."""
    for loader in make_bounding_box_loaders():
        yield ArgsKwargs(
            loader,
            format=loader.format,
            canvas_size=loader.canvas_size,
            startpoints=None,
            endpoints=None,
            coefficients=_PERSPECTIVE_COEFFS[0],
        )

    box_format = tv_tensors.BoundingBoxFormat.XYXY
    loader = make_bounding_box_loader(format=box_format)
    yield ArgsKwargs(
        loader, format=box_format, canvas_size=loader.canvas_size, startpoints=_STARTPOINTS, endpoints=_ENDPOINTS
    )

569
570

def sample_inputs_perspective_mask():
    """Sample perspective inputs for masks, plus one startpoints/endpoints call."""
    for mask_loader in make_mask_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE]):
        yield ArgsKwargs(mask_loader, startpoints=None, endpoints=None, coefficients=_PERSPECTIVE_COEFFS[0])

    yield ArgsKwargs(make_detection_mask_loader(), startpoints=_STARTPOINTS, endpoints=_ENDPOINTS)
575
576
577
578
579
580


def reference_inputs_perspective_mask():
    """Comprehensive perspective inputs for the mask reference test."""
    loaders = make_mask_loaders(extra_dims=[()], num_objects=[1])
    for mask_loader, coefficients in itertools.product(loaders, _PERSPECTIVE_COEFFS):
        yield ArgsKwargs(mask_loader, startpoints=None, endpoints=None, coefficients=coefficients)
582
583


584
def sample_inputs_perspective_video():
    """Sample perspective inputs for videos, plus one startpoints/endpoints call."""
    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader, startpoints=None, endpoints=None, coefficients=_PERSPECTIVE_COEFFS[0])

    yield ArgsKwargs(make_video_loader(), startpoints=_STARTPOINTS, endpoints=_ENDPOINTS)
589
590


591
592
593
# Register the perspective kernels for all datapoint kinds.
KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.perspective_image,
            sample_inputs_fn=sample_inputs_perspective_image_tensor,
            reference_fn=pil_reference_wrapper(F._perspective_image_pil),
            reference_inputs_fn=reference_inputs_perspective_image_tensor,
            float32_vs_uint8=float32_vs_uint8_fill_adapter,
            closeness_kwargs={
                **pil_reference_pixel_difference(2, mae=True),
                **cuda_vs_cpu_pixel_difference(),
                **float32_vs_uint8_pixel_difference(),
                **scripted_vs_eager_float64_tolerances("cpu", atol=1e-5, rtol=1e-5),
                **scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
            },
            test_marks=[xfail_jit_python_scalar_arg("fill")],
        ),
        KernelInfo(
            F.perspective_bounding_boxes,
            sample_inputs_fn=sample_inputs_perspective_bounding_boxes,
            closeness_kwargs={
                **scripted_vs_eager_float64_tolerances("cpu", atol=1e-6, rtol=1e-6),
                **scripted_vs_eager_float64_tolerances("cuda", atol=1e-6, rtol=1e-6),
            },
        ),
        KernelInfo(
            F.perspective_mask,
            sample_inputs_fn=sample_inputs_perspective_mask,
            reference_fn=pil_reference_wrapper(F._perspective_image_pil),
            reference_inputs_fn=reference_inputs_perspective_mask,
            float32_vs_uint8=True,
            closeness_kwargs={
                (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): dict(atol=10, rtol=0),
            },
        ),
        KernelInfo(
            F.perspective_video,
            sample_inputs_fn=sample_inputs_perspective_video,
            closeness_kwargs={
                **cuda_vs_cpu_pixel_difference(),
                **scripted_vs_eager_float64_tolerances("cpu", atol=1e-5, rtol=1e-5),
                **scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
            },
        ),
    ]
)


Philip Meier's avatar
Philip Meier committed
639
640
def _get_elastic_displacement(canvas_size):
    return torch.rand(1, *canvas_size, 2)
641
642
643


def sample_inputs_elastic_image_tensor():
    """Sample elastic inputs for images: one displacement per loader, all fill variants."""
    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE]):
        displacement = _get_elastic_displacement(image_loader.canvas_size)
        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
            yield ArgsKwargs(image_loader, displacement=displacement, fill=fill)


def reference_inputs_elastic_image_tensor():
    """Comprehensive elastic inputs across interpolation modes and fills."""
    interpolations = [
        F.InterpolationMode.NEAREST,
        F.InterpolationMode.BILINEAR,
        F.InterpolationMode.BICUBIC,
    ]
    for image_loader, interpolation in itertools.product(make_image_loaders_for_interpolation(), interpolations):
        displacement = _get_elastic_displacement(image_loader.canvas_size)
        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
            yield ArgsKwargs(image_loader, interpolation=interpolation, displacement=displacement, fill=fill)


664
665
def sample_inputs_elastic_bounding_boxes():
    """Sample elastic inputs for bounding boxes."""
    for loader in make_bounding_box_loaders():
        displacement = _get_elastic_displacement(loader.canvas_size)
        yield ArgsKwargs(
            loader,
            format=loader.format,
            canvas_size=loader.canvas_size,
            displacement=displacement,
        )


def sample_inputs_elastic_mask():
    """Sample elastic inputs for masks."""
    for mask_loader in make_mask_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE]):
        yield ArgsKwargs(mask_loader, displacement=_get_elastic_displacement(mask_loader.shape[-2:]))


681
def sample_inputs_elastic_video():
    """Sample elastic inputs for videos."""
    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader, displacement=_get_elastic_displacement(video_loader.shape[-2:]))


687
688
689
# Register the elastic kernels for all datapoint kinds.
KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.elastic_image,
            sample_inputs_fn=sample_inputs_elastic_image_tensor,
            reference_inputs_fn=reference_inputs_elastic_image_tensor,
            float32_vs_uint8=float32_vs_uint8_fill_adapter,
            closeness_kwargs={
                **float32_vs_uint8_pixel_difference(6, mae=True),
                **cuda_vs_cpu_pixel_difference(),
            },
            test_marks=[xfail_jit_python_scalar_arg("fill")],
        ),
        KernelInfo(
            F.elastic_bounding_boxes,
            sample_inputs_fn=sample_inputs_elastic_bounding_boxes,
        ),
        KernelInfo(
            F.elastic_mask,
            sample_inputs_fn=sample_inputs_elastic_mask,
        ),
        KernelInfo(
            F.elastic_video,
            sample_inputs_fn=sample_inputs_elastic_video,
            closeness_kwargs=cuda_vs_cpu_pixel_difference(),
        ),
    ]
)


717
_CENTER_CROP_SPATIAL_SIZES = [(16, 16), (7, 33), (31, 9)]
718
_CENTER_CROP_OUTPUT_SIZES = [[4, 3], [42, 70], [4], 3, (5, 2), (6,)]
719
720
721
722


def sample_inputs_center_crop_image_tensor():
    """Sample center crop inputs for images, covering both cropping and padding code paths."""
    output_sizes = [
        # valid `output_size` types for which cropping is applied to both dimensions
        *[5, (4,), (2, 3), [6], [3, 2]],
        # `output_size`'s for which at least one dimension needs to be padded
        *[[4, 18], [17, 5], [17, 18]],
    ]
    for image_loader, output_size in itertools.product(
        make_image_loaders(sizes=[(16, 17)], color_spaces=["RGB"], dtypes=[torch.float32]), output_sizes
    ):
        yield ArgsKwargs(image_loader, output_size=output_size)


def reference_inputs_center_crop_image_tensor():
    """Comprehensive center crop inputs for the PIL reference test (uint8, unbatched)."""
    loaders = make_image_loaders(sizes=_CENTER_CROP_SPATIAL_SIZES, extra_dims=[()], dtypes=[torch.uint8])
    for image_loader, output_size in itertools.product(loaders, _CENTER_CROP_OUTPUT_SIZES):
        yield ArgsKwargs(image_loader, output_size=output_size)


def sample_inputs_center_crop_bounding_boxes():
    """Yield bounding boxes (with their format and canvas size) against every output size."""
    for bounding_boxes_loader, output_size in itertools.product(make_bounding_box_loaders(), _CENTER_CROP_OUTPUT_SIZES):
        yield ArgsKwargs(
            bounding_boxes_loader,
            format=bounding_boxes_loader.format,
            canvas_size=bounding_boxes_loader.canvas_size,
            output_size=output_size,
        )


def sample_inputs_center_crop_mask():
    """Yield detection masks cropped to half their height and width."""
    for mask_loader in make_mask_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_categories=[10], num_objects=[5]):
        height, width = mask_loader.shape[-2:]
        yield ArgsKwargs(mask_loader, output_size=(height // 2, width // 2))


def reference_inputs_center_crop_mask():
    """Yield single-object, no-batch-dim masks against the full output-size grid for the PIL reference."""
    for mask_loader, output_size in itertools.product(
        make_mask_loaders(sizes=_CENTER_CROP_SPATIAL_SIZES, extra_dims=[()], num_objects=[1]), _CENTER_CROP_OUTPUT_SIZES
    ):
        yield ArgsKwargs(mask_loader, output_size=output_size)


def sample_inputs_center_crop_video():
    """Yield videos cropped to half their height and width."""
    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        height, width = video_loader.shape[-2:]
        yield ArgsKwargs(video_loader, output_size=(height // 2, width // 2))


# Kernel infos for the center_crop family. `output_size` as a Python scalar is not scriptable,
# hence the JIT xfail marks.
KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.center_crop_image,
            sample_inputs_fn=sample_inputs_center_crop_image_tensor,
            reference_fn=pil_reference_wrapper(F._center_crop_image_pil),
            reference_inputs_fn=reference_inputs_center_crop_image_tensor,
            float32_vs_uint8=True,
            test_marks=[
                xfail_jit_python_scalar_arg("output_size"),
            ],
        ),
        KernelInfo(
            F.center_crop_bounding_boxes,
            sample_inputs_fn=sample_inputs_center_crop_bounding_boxes,
            test_marks=[
                xfail_jit_python_scalar_arg("output_size"),
            ],
        ),
        KernelInfo(
            F.center_crop_mask,
            sample_inputs_fn=sample_inputs_center_crop_mask,
            reference_fn=pil_reference_wrapper(F._center_crop_image_pil),
            reference_inputs_fn=reference_inputs_center_crop_mask,
            float32_vs_uint8=True,
            test_marks=[
                xfail_jit_python_scalar_arg("output_size"),
            ],
        ),
        KernelInfo(
            F.center_crop_video,
            sample_inputs_fn=sample_inputs_center_crop_video,
        ),
    ]
)


def sample_inputs_gaussian_blur_image_tensor():
    """Yield images with varying `kernel_size` types, then varying `sigma` types at a fixed kernel size."""
    make_gaussian_blur_image_loaders = functools.partial(make_image_loaders, sizes=[(7, 33)], color_spaces=["RGB"])

    for image_loader, kernel_size in itertools.product(make_gaussian_blur_image_loaders(), [5, (3, 3), [3, 3]]):
        yield ArgsKwargs(image_loader, kernel_size=kernel_size)

    for image_loader, sigma in itertools.product(
        make_gaussian_blur_image_loaders(), [None, (3.0, 3.0), [2.0, 2.0], 4.0, [1.5], (3.14,)]
    ):
        yield ArgsKwargs(image_loader, kernel_size=5, sigma=sigma)


def sample_inputs_gaussian_blur_video():
    """Yield videos blurred with a fixed 3x3 kernel."""
    for video_loader in make_video_loaders(sizes=[(7, 33)], num_frames=[5]):
        yield ArgsKwargs(video_loader, kernel_size=[3, 3])


# Kernel infos for gaussian_blur. CUDA and CPU convolutions round slightly differently,
# hence the relaxed closeness kwargs.
KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.gaussian_blur_image,
            sample_inputs_fn=sample_inputs_gaussian_blur_image_tensor,
            closeness_kwargs=cuda_vs_cpu_pixel_difference(),
            test_marks=[
                xfail_jit_python_scalar_arg("kernel_size"),
                xfail_jit_python_scalar_arg("sigma"),
            ],
        ),
        KernelInfo(
            F.gaussian_blur_video,
            sample_inputs_fn=sample_inputs_gaussian_blur_video,
            closeness_kwargs=cuda_vs_cpu_pixel_difference(),
        ),
    ]
)


def sample_inputs_equalize_image_tensor():
    """Yield grayscale and RGB images for the equalize kernel."""
    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader)


def reference_inputs_equalize_image_tensor():
    """Yield images with deliberately non-uniform value distributions for the PIL equalize reference."""

    # We are not using `make_image_loaders` here since that uniformly samples the values over the whole value range.
    # Since the whole point of this kernel is to transform an arbitrary distribution of values into a uniform one,
    # the information gain is low if we already provide something really close to the expected value.
    def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor, memory_format):
        # Uniform values restricted to the [low_factor, high_factor] band of the dtype's value range.
        if dtype.is_floating_point:
            low = low_factor
            high = high_factor
        else:
            max_value = torch.iinfo(dtype).max
            low = int(low_factor * max_value)
            high = int(high_factor * max_value)
        return torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high).to(
            memory_format=memory_format, copy=True
        )

    def make_beta_distributed_image(shape, dtype, device, *, alpha, beta, memory_format):
        # Values drawn from a Beta(alpha, beta) distribution, scaled up for integer dtypes.
        image = torch.distributions.Beta(alpha, beta).sample(shape)
        if not dtype.is_floating_point:
            image.mul_(torch.iinfo(dtype).max).round_()
        return image.to(dtype=dtype, device=device, memory_format=memory_format, copy=True)

    canvas_size = (256, 256)
    for dtype, color_space, fn in itertools.product(
        [torch.uint8],
        ["GRAY", "RGB"],
        [
            # all-zeros and all-max images exercise the degenerate single-value histograms
            lambda shape, dtype, device, memory_format: torch.zeros(shape, dtype=dtype, device=device).to(
                memory_format=memory_format, copy=True
            ),
            lambda shape, dtype, device, memory_format: torch.full(
                shape, 1.0 if dtype.is_floating_point else torch.iinfo(dtype).max, dtype=dtype, device=device
            ).to(memory_format=memory_format, copy=True),
            *[
                functools.partial(make_uniform_band_image, low_factor=low_factor, high_factor=high_factor)
                for low_factor, high_factor in [
                    (0.0, 0.25),
                    (0.25, 0.75),
                    (0.75, 1.0),
                ]
            ],
            *[
                functools.partial(make_beta_distributed_image, alpha=alpha, beta=beta)
                for alpha, beta in [
                    (0.5, 0.5),
                    (2, 2),
                    (2, 5),
                    (5, 2),
                ]
            ],
        ],
    ):
        image_loader = ImageLoader(fn, shape=(get_num_channels(color_space), *canvas_size), dtype=dtype)
        yield ArgsKwargs(image_loader)


def sample_inputs_equalize_video():
    """Yield short videos for the equalize kernel."""
    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader)


# Kernel infos for equalize.
KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.equalize_image,
            kernel_name="equalize_image_tensor",
            sample_inputs_fn=sample_inputs_equalize_image_tensor,
            reference_fn=pil_reference_wrapper(F._equalize_image_pil),
            float32_vs_uint8=True,
            reference_inputs_fn=reference_inputs_equalize_image_tensor,
        ),
        KernelInfo(
            F.equalize_video,
            sample_inputs_fn=sample_inputs_equalize_video,
        ),
    ]
)


def sample_inputs_invert_image_tensor():
    """Yield grayscale and RGB images for the invert kernel."""
    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader)


def reference_inputs_invert_image_tensor():
    """Yield uint8, no-batch-dim images for the PIL invert reference."""
    for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
        yield ArgsKwargs(image_loader)


def sample_inputs_invert_video():
    """Yield short videos for the invert kernel."""
    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader)


# Kernel infos for invert.
KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.invert_image,
            kernel_name="invert_image_tensor",
            sample_inputs_fn=sample_inputs_invert_image_tensor,
            reference_fn=pil_reference_wrapper(F._invert_image_pil),
            reference_inputs_fn=reference_inputs_invert_image_tensor,
            float32_vs_uint8=True,
        ),
        KernelInfo(
            F.invert_video,
            sample_inputs_fn=sample_inputs_invert_video,
        ),
    ]
)


_POSTERIZE_BITS = [1, 4, 8]


def sample_inputs_posterize_image_tensor():
    """Yield grayscale and RGB images with a single representative `bits` value."""
    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, bits=_POSTERIZE_BITS[0])


def reference_inputs_posterize_image_tensor():
    """Yield uint8, no-batch-dim images against every `bits` value for the PIL reference."""
    for image_loader, bits in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _POSTERIZE_BITS,
    ):
        yield ArgsKwargs(image_loader, bits=bits)


def sample_inputs_posterize_video():
    """Yield short videos with a single representative `bits` value."""
    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader, bits=_POSTERIZE_BITS[0])


# Kernel infos for posterize.
KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.posterize_image,
            kernel_name="posterize_image_tensor",
            sample_inputs_fn=sample_inputs_posterize_image_tensor,
            reference_fn=pil_reference_wrapper(F._posterize_image_pil),
            reference_inputs_fn=reference_inputs_posterize_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs=float32_vs_uint8_pixel_difference(),
        ),
        KernelInfo(
            F.posterize_video,
            sample_inputs_fn=sample_inputs_posterize_video,
        ),
    ]
)


def _get_solarize_thresholds(dtype):
    """Yield solarize thresholds at 10% and 50% of the value range of `dtype`."""
    for factor in [0.1, 0.5]:
        max_value = get_max_value(dtype)
        # Integer dtypes need integer thresholds; float dtypes keep the fractional value.
        yield (float if dtype.is_floating_point else int)(max_value * factor)


def sample_inputs_solarize_image_tensor():
    """Yield grayscale and RGB images with a single dtype-appropriate threshold."""
    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, threshold=next(_get_solarize_thresholds(image_loader.dtype)))


def reference_inputs_solarize_image_tensor():
    """Yield uint8, no-batch-dim images against every threshold for the PIL reference."""
    for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
        for threshold in _get_solarize_thresholds(image_loader.dtype):
            yield ArgsKwargs(image_loader, threshold=threshold)


def uint8_to_float32_threshold_adapter(other_args, kwargs):
    """Rescale a uint8 `threshold` from [0, 255] to [0, 1] for the float32-vs-uint8 comparison."""
    return other_args, dict(threshold=kwargs["threshold"] / 255)


def sample_inputs_solarize_video():
    """Yield short videos with a single dtype-appropriate threshold."""
    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader, threshold=next(_get_solarize_thresholds(video_loader.dtype)))


# Kernel infos for solarize. The threshold adapter rescales the uint8 threshold for the float32 run.
KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.solarize_image,
            kernel_name="solarize_image_tensor",
            sample_inputs_fn=sample_inputs_solarize_image_tensor,
            reference_fn=pil_reference_wrapper(F._solarize_image_pil),
            reference_inputs_fn=reference_inputs_solarize_image_tensor,
            float32_vs_uint8=uint8_to_float32_threshold_adapter,
            closeness_kwargs=float32_vs_uint8_pixel_difference(),
        ),
        KernelInfo(
            F.solarize_video,
            sample_inputs_fn=sample_inputs_solarize_video,
        ),
    ]
)


def sample_inputs_autocontrast_image_tensor():
    """Yield grayscale and RGB images for the autocontrast kernel."""
    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader)


def reference_inputs_autocontrast_image_tensor():
    """Yield uint8, no-batch-dim images for the PIL autocontrast reference."""
    for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
        yield ArgsKwargs(image_loader)


def sample_inputs_autocontrast_video():
    """Yield short videos for the autocontrast kernel."""
    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader)


# Kernel infos for autocontrast.
KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.autocontrast_image,
            kernel_name="autocontrast_image_tensor",
            sample_inputs_fn=sample_inputs_autocontrast_image_tensor,
            reference_fn=pil_reference_wrapper(F._autocontrast_image_pil),
            reference_inputs_fn=reference_inputs_autocontrast_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **pil_reference_pixel_difference(),
                **float32_vs_uint8_pixel_difference(),
            },
        ),
        KernelInfo(
            F.autocontrast_video,
            sample_inputs_fn=sample_inputs_autocontrast_video,
        ),
    ]
)

_ADJUST_SHARPNESS_FACTORS = [0.1, 0.5]


def sample_inputs_adjust_sharpness_image_tensor():
    """Yield images, including a tiny (2, 2) one, with a representative sharpness factor."""
    for image_loader in make_image_loaders(
        sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE, (2, 2)],
        color_spaces=("GRAY", "RGB"),
    ):
        yield ArgsKwargs(image_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0])


def reference_inputs_adjust_sharpness_image_tensor():
    """Yield uint8, no-batch-dim images against every sharpness factor for the PIL reference."""
    for image_loader, sharpness_factor in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _ADJUST_SHARPNESS_FACTORS,
    ):
        yield ArgsKwargs(image_loader, sharpness_factor=sharpness_factor)


def sample_inputs_adjust_sharpness_video():
    """Yield short videos with a representative sharpness factor."""
    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0])


# Kernel infos for adjust_sharpness.
KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.adjust_sharpness_image,
            kernel_name="adjust_sharpness_image_tensor",
            sample_inputs_fn=sample_inputs_adjust_sharpness_image_tensor,
            reference_fn=pil_reference_wrapper(F._adjust_sharpness_image_pil),
            reference_inputs_fn=reference_inputs_adjust_sharpness_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs=float32_vs_uint8_pixel_difference(2),
        ),
        KernelInfo(
            F.adjust_sharpness_video,
            sample_inputs_fn=sample_inputs_adjust_sharpness_video,
        ),
    ]
)


def sample_inputs_erase_image_tensor():
    """Yield images with a fixed erase region and random fill values."""
    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE]):
        # FIXME: make the parameters more diverse
        h, w = 6, 7
        v = torch.rand(image_loader.num_channels, h, w)
        yield ArgsKwargs(image_loader, i=1, j=2, h=h, w=w, v=v)


def sample_inputs_erase_video():
    """Yield videos with a fixed erase region and random fill values."""
    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        # FIXME: make the parameters more diverse
        h, w = 6, 7
        v = torch.rand(video_loader.num_channels, h, w)
        yield ArgsKwargs(video_loader, i=1, j=2, h=h, w=w, v=v)


# Kernel infos for erase.
KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.erase_image,
            kernel_name="erase_image_tensor",
            sample_inputs_fn=sample_inputs_erase_image_tensor,
        ),
        KernelInfo(
            F.erase_video,
            sample_inputs_fn=sample_inputs_erase_video,
        ),
    ]
)

_ADJUST_CONTRAST_FACTORS = [0.1, 0.5]


def sample_inputs_adjust_contrast_image_tensor():
    """Yield grayscale and RGB images with a representative contrast factor."""
    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0])


def reference_inputs_adjust_contrast_image_tensor():
    """Yield uint8, no-batch-dim images against every contrast factor for the PIL reference."""
    for image_loader, contrast_factor in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _ADJUST_CONTRAST_FACTORS,
    ):
        yield ArgsKwargs(image_loader, contrast_factor=contrast_factor)


def sample_inputs_adjust_contrast_video():
    """Yield short videos with a representative contrast factor."""
    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0])


# Kernel infos for adjust_contrast. The explicit uint8/cpu closeness entry relaxes the
# reference comparison by one pixel value.
KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.adjust_contrast_image,
            kernel_name="adjust_contrast_image_tensor",
            sample_inputs_fn=sample_inputs_adjust_contrast_image_tensor,
            reference_fn=pil_reference_wrapper(F._adjust_contrast_image_pil),
            reference_inputs_fn=reference_inputs_adjust_contrast_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **pil_reference_pixel_difference(),
                **float32_vs_uint8_pixel_difference(2),
                **cuda_vs_cpu_pixel_difference(),
                (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): pixel_difference_closeness_kwargs(1),
            },
        ),
        KernelInfo(
            F.adjust_contrast_video,
            sample_inputs_fn=sample_inputs_adjust_contrast_video,
            closeness_kwargs={
                **cuda_vs_cpu_pixel_difference(),
                (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): pixel_difference_closeness_kwargs(1),
            },
        ),
    ]
)

_ADJUST_GAMMA_GAMMAS_GAINS = [
    (0.5, 2.0),
    (0.0, 1.0),
]


def sample_inputs_adjust_gamma_image_tensor():
    """Yield grayscale and RGB images with a representative (gamma, gain) pair."""
    gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0]
    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, gamma=gamma, gain=gain)


def reference_inputs_adjust_gamma_image_tensor():
    """Yield uint8, no-batch-dim images against every (gamma, gain) pair for the PIL reference."""
    for image_loader, (gamma, gain) in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _ADJUST_GAMMA_GAMMAS_GAINS,
    ):
        yield ArgsKwargs(image_loader, gamma=gamma, gain=gain)


def sample_inputs_adjust_gamma_video():
    """Yield short videos with a representative (gamma, gain) pair."""
    gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0]
    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader, gamma=gamma, gain=gain)


# Kernel infos for adjust_gamma.
KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.adjust_gamma_image,
            kernel_name="adjust_gamma_image_tensor",
            sample_inputs_fn=sample_inputs_adjust_gamma_image_tensor,
            reference_fn=pil_reference_wrapper(F._adjust_gamma_image_pil),
            reference_inputs_fn=reference_inputs_adjust_gamma_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **pil_reference_pixel_difference(),
                **float32_vs_uint8_pixel_difference(),
            },
        ),
        KernelInfo(
            F.adjust_gamma_video,
            sample_inputs_fn=sample_inputs_adjust_gamma_video,
        ),
    ]
)


_ADJUST_HUE_FACTORS = [-0.1, 0.5]


def sample_inputs_adjust_hue_image_tensor():
    """Yield grayscale and RGB images with a representative hue factor."""
    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, hue_factor=_ADJUST_HUE_FACTORS[0])


def reference_inputs_adjust_hue_image_tensor():
    """Yield uint8, no-batch-dim images against every hue factor for the PIL reference."""
    for image_loader, hue_factor in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _ADJUST_HUE_FACTORS,
    ):
        yield ArgsKwargs(image_loader, hue_factor=hue_factor)


def sample_inputs_adjust_hue_video():
    """Yield short videos with a representative hue factor."""
    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader, hue_factor=_ADJUST_HUE_FACTORS[0])


# Kernel infos for adjust_hue. The PIL reference comparison uses MAE with a tolerance of 2,
# since the HSV round-trip differs slightly between implementations.
KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.adjust_hue_image,
            kernel_name="adjust_hue_image_tensor",
            sample_inputs_fn=sample_inputs_adjust_hue_image_tensor,
            reference_fn=pil_reference_wrapper(F._adjust_hue_image_pil),
            reference_inputs_fn=reference_inputs_adjust_hue_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **pil_reference_pixel_difference(2, mae=True),
                **float32_vs_uint8_pixel_difference(),
            },
        ),
        KernelInfo(
            F.adjust_hue_video,
            sample_inputs_fn=sample_inputs_adjust_hue_video,
        ),
    ]
)

_ADJUST_SATURATION_FACTORS = [0.1, 0.5]


def sample_inputs_adjust_saturation_image_tensor():
    """Yield grayscale and RGB images with a representative saturation factor."""
    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0])


def reference_inputs_adjust_saturation_image_tensor():
    """Yield uint8, no-batch-dim images against every saturation factor for the PIL reference."""
    for image_loader, saturation_factor in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _ADJUST_SATURATION_FACTORS,
    ):
        yield ArgsKwargs(image_loader, saturation_factor=saturation_factor)


def sample_inputs_adjust_saturation_video():
    """Yield short videos with a representative saturation factor."""
    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0])


# Kernel infos for adjust_saturation.
KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.adjust_saturation_image,
            kernel_name="adjust_saturation_image_tensor",
            sample_inputs_fn=sample_inputs_adjust_saturation_image_tensor,
            reference_fn=pil_reference_wrapper(F._adjust_saturation_image_pil),
            reference_inputs_fn=reference_inputs_adjust_saturation_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **pil_reference_pixel_difference(),
                **float32_vs_uint8_pixel_difference(2),
                **cuda_vs_cpu_pixel_difference(),
            },
        ),
        KernelInfo(
            F.adjust_saturation_video,
            sample_inputs_fn=sample_inputs_adjust_saturation_video,
            closeness_kwargs=cuda_vs_cpu_pixel_difference(),
        ),
    ]
)


def sample_inputs_clamp_bounding_boxes():
    """Yield bounding boxes together with their format and canvas size."""
    for bounding_boxes_loader in make_bounding_box_loaders():
        yield ArgsKwargs(
            bounding_boxes_loader,
            format=bounding_boxes_loader.format,
            canvas_size=bounding_boxes_loader.canvas_size,
        )


# Kernel info for clamp_bounding_boxes. `logs_usage` enables the API-usage-logging test.
KERNEL_INFOS.append(
    KernelInfo(
        F.clamp_bounding_boxes,
        sample_inputs_fn=sample_inputs_clamp_bounding_boxes,
        logs_usage=True,
    )
)

_FIVE_TEN_CROP_SIZES = [7, (6,), [5], (6, 5), [7, 6]]


Philip Meier's avatar
Philip Meier committed
1359
def _get_five_ten_crop_canvas_size(size):
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
    if isinstance(size, int):
        crop_height = crop_width = size
    elif len(size) == 1:
        crop_height = crop_width = size[0]
    else:
        crop_height, crop_width = size
    return 2 * crop_height, 2 * crop_width


def sample_inputs_five_crop_image_tensor():
    """Yield images sized to fit each crop `size` with room to spare."""
    for size in _FIVE_TEN_CROP_SIZES:
        for image_loader in make_image_loaders(
            sizes=[_get_five_ten_crop_canvas_size(size)],
            color_spaces=["RGB"],
            dtypes=[torch.float32],
        ):
            yield ArgsKwargs(image_loader, size=size)


def reference_inputs_five_crop_image_tensor():
    """Yield uint8, no-batch-dim images for the PIL five-crop reference."""
    for size in _FIVE_TEN_CROP_SIZES:
        for image_loader in make_image_loaders(
            sizes=[_get_five_ten_crop_canvas_size(size)], extra_dims=[()], dtypes=[torch.uint8]
        ):
            yield ArgsKwargs(image_loader, size=size)


def sample_inputs_five_crop_video():
    """Yield videos sized to fit a single representative crop size."""
    size = _FIVE_TEN_CROP_SIZES[0]
    for video_loader in make_video_loaders(sizes=[_get_five_ten_crop_canvas_size(size)]):
        yield ArgsKwargs(video_loader, size=size)


def sample_inputs_ten_crop_image_tensor():
    """Yield images against every crop size, with and without vertical flipping."""
    for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
        for image_loader in make_image_loaders(
            sizes=[_get_five_ten_crop_canvas_size(size)],
            color_spaces=["RGB"],
            dtypes=[torch.float32],
        ):
            yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip)


def reference_inputs_ten_crop_image_tensor():
    """Yield uint8, no-batch-dim images for the PIL ten-crop reference."""
    for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
        for image_loader in make_image_loaders(
            sizes=[_get_five_ten_crop_canvas_size(size)], extra_dims=[()], dtypes=[torch.uint8]
        ):
            yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip)


def sample_inputs_ten_crop_video():
    """Yield videos sized to fit a single representative crop size."""
    size = _FIVE_TEN_CROP_SIZES[0]
    for video_loader in make_video_loaders(sizes=[_get_five_ten_crop_canvas_size(size)]):
        yield ArgsKwargs(video_loader, size=size)


def multi_crop_pil_reference_wrapper(pil_kernel):
    """Like `pil_reference_wrapper`, but for kernels that return a tuple of crops.

    Each PIL crop in the output is converted back to a tensor matching the input dtype.
    """

    def wrapper(input_tensor, *other_args, **kwargs):
        output = pil_reference_wrapper(pil_kernel)(input_tensor, *other_args, **kwargs)
        return type(output)(
            F.to_dtype_image(F.to_image(output_pil), dtype=input_tensor.dtype, scale=True) for output_pil in output
        )

    return wrapper


# Marks shared by all five/ten crop kernel infos: scalar `size` is not scriptable, and the
# tuple-of-crops output does not fit the generic batched-vs-single test.
_common_five_ten_crop_marks = [
    xfail_jit_python_scalar_arg("size"),
    mark_framework_limitation(("TestKernels", "test_batched_vs_single"), "Custom batching needed."),
]

# Kernel infos for the five_crop / ten_crop family.
KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.five_crop_image,
            sample_inputs_fn=sample_inputs_five_crop_image_tensor,
            reference_fn=multi_crop_pil_reference_wrapper(F._five_crop_image_pil),
            reference_inputs_fn=reference_inputs_five_crop_image_tensor,
            test_marks=_common_five_ten_crop_marks,
        ),
        KernelInfo(
            F.five_crop_video,
            sample_inputs_fn=sample_inputs_five_crop_video,
            test_marks=_common_five_ten_crop_marks,
        ),
        KernelInfo(
            F.ten_crop_image,
            sample_inputs_fn=sample_inputs_ten_crop_image_tensor,
            reference_fn=multi_crop_pil_reference_wrapper(F._ten_crop_image_pil),
            reference_inputs_fn=reference_inputs_ten_crop_image_tensor,
            test_marks=_common_five_ten_crop_marks,
        ),
        KernelInfo(
            F.ten_crop_video,
            sample_inputs_fn=sample_inputs_ten_crop_video,
            test_marks=_common_five_ten_crop_marks,
        ),
    ]
)

_NORMALIZE_MEANS_STDS = [
    ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
1464
    (0.5, 2.0),
1465
1466
1467
1468
1469
]


def sample_inputs_normalize_image_tensor():
    """Yield float32 RGB images against every (mean, std) pair."""
    for image_loader, (mean, std) in itertools.product(
        make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=["RGB"], dtypes=[torch.float32]),
        _NORMALIZE_MEANS_STDS,
    ):
        yield ArgsKwargs(image_loader, mean=mean, std=std)


def reference_normalize_image_tensor(image, mean, std, inplace=False):
    """Straightforward re-implementation of normalize: `(image - mean) / std` per channel.

    When `inplace` is true, `image` itself is mutated; otherwise a new tensor is returned.
    """
    mean = torch.tensor(mean).view(-1, 1, 1)
    std = torch.tensor(std).view(-1, 1, 1)

    # `sub` returns a fresh tensor in the non-inplace case, so the trailing `div_` never
    # touches the caller's input there.
    sub = torch.Tensor.sub_ if inplace else torch.Tensor.sub
    return sub(image, mean).div_(std)


def reference_inputs_normalize_image_tensor():
    """Yield a single batched RGB image with simple statistics for the reference check."""
    yield ArgsKwargs(
        make_image_loader(size=(32, 32), color_space="RGB", extra_dims=[1]),
        mean=[0.5, 0.5, 0.5],
        std=[1.0, 1.0, 1.0],
    )


def sample_inputs_normalize_video():
    """Yield float32 RGB videos with a representative (mean, std) pair."""
    mean, std = _NORMALIZE_MEANS_STDS[0]
    for video_loader in make_video_loaders(
        sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=["RGB"], num_frames=[3], dtypes=[torch.float32]
    ):
        yield ArgsKwargs(video_loader, mean=mean, std=std)


# Kernel infos for normalize.
KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.normalize_image,
            kernel_name="normalize_image_tensor",
            sample_inputs_fn=sample_inputs_normalize_image_tensor,
            reference_fn=reference_normalize_image_tensor,
            reference_inputs_fn=reference_inputs_normalize_image_tensor,
            test_marks=[
                xfail_jit_python_scalar_arg("mean"),
                xfail_jit_python_scalar_arg("std"),
            ],
        ),
        KernelInfo(
            F.normalize_video,
            sample_inputs_fn=sample_inputs_normalize_video,
        ),
    ]
)
1519
1520


def sample_inputs_uniform_temporal_subsample_video():
    """Yield 4-frame videos subsampled down to 2 frames."""
    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[4]):
        yield ArgsKwargs(video_loader, num_samples=2)


def reference_uniform_temporal_subsample_video(x, num_samples):
    """Select `num_samples` frames evenly spaced along the temporal (-4) dimension of `x`."""
    # Copy-pasted from
    # https://github.com/facebookresearch/pytorchvideo/blob/c8d23d8b7e597586a9e2d18f6ed31ad8aa379a7a/pytorchvideo/transforms/functional.py#L19
    t = x.shape[-4]
    assert num_samples > 0 and t > 0
    # Sample by nearest neighbor interpolation if num_samples > t.
    indices = torch.linspace(0, t - 1, num_samples)
    indices = torch.clamp(indices, 0, t - 1).long()
    return torch.index_select(x, -4, indices)


def reference_inputs_uniform_temporal_subsample_video():
    """Yield a 10-frame video against every valid `num_samples` value."""
    for video_loader in make_video_loaders(
        sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=["RGB"], num_frames=[10]
    ):
        for num_samples in range(1, video_loader.shape[-4] + 1):
            yield ArgsKwargs(video_loader, num_samples)


# Kernel info for uniform_temporal_subsample, checked against the pytorchvideo reference.
KERNEL_INFOS.append(
    KernelInfo(
        F.uniform_temporal_subsample_video,
        sample_inputs_fn=sample_inputs_uniform_temporal_subsample_video,
        reference_fn=reference_uniform_temporal_subsample_video,
        reference_inputs_fn=reference_inputs_uniform_temporal_subsample_video,
    )
)