import functools
import itertools

import PIL.Image
import pytest
import torch.testing
import torchvision.transforms.v2.functional as F
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from transforms_v2_legacy_utils import (
    ArgsKwargs,
    DEFAULT_PORTRAIT_SPATIAL_SIZE,
    get_num_channels,
    ImageLoader,
    InfoBase,
    make_bounding_box_loaders,
    make_image_loader,
    make_image_loaders,
    make_image_loaders_for_interpolation,
    make_mask_loaders,
    make_video_loaders,
    mark_framework_limitation,
    TestMark,
)

__all__ = ["KernelInfo", "KERNEL_INFOS"]


class KernelInfo(InfoBase):
    def __init__(
        self,
        kernel,
        *,
        # Defaults to `kernel.__name__`. Should be set if the function is exposed under a different name
        # TODO: This can probably be removed after roll-out since we shouldn't have any aliasing then
        kernel_name=None,
        # Most common tests use these inputs to check the kernel. As such, they should cover all valid code paths, but
        # should not include extensive parameter combinations, to keep the overall test count moderate.
        sample_inputs_fn,
        # This function should mirror the kernel. It should have the same signature as the `kernel` and as such also
        # take tensors as inputs. Any conversion into another object type, e.g. PIL images or numpy arrays, should
        # happen inside the function. It should return a tensor or, more precisely, an object that can be compared to
        # a tensor by `assert_close`. If omitted, no reference test will be performed.
        reference_fn=None,
        # These inputs are only used for the reference tests and thus can be comprehensive with regard to the parameter
        # values to be tested. If not specified, `sample_inputs_fn` will be used.
        reference_inputs_fn=None,
        # If true-ish, triggers a test that checks the kernel for consistency between uint8 and float32 inputs with the
        # reference inputs. This is usually used whenever we use a PIL kernel as reference.
        # Can be a callable in which case it will be called with `other_args, kwargs`. It should return the same
        # structure, but with adapted parameters. This is useful in case a parameter value is closely tied to the input
        # dtype.
        float32_vs_uint8=False,
        # Some kernels don't have dispatchers that would handle logging the usage. Thus, the kernel has to do it
        # manually. If set, triggers a test that makes sure this happens.
        logs_usage=False,
        # See InfoBase
        test_marks=None,
        # See InfoBase
        closeness_kwargs=None,
    ):
        super().__init__(id=kernel_name or kernel.__name__, test_marks=test_marks, closeness_kwargs=closeness_kwargs)
        self.kernel = kernel
        self.sample_inputs_fn = sample_inputs_fn
        self.reference_fn = reference_fn
        self.reference_inputs_fn = reference_inputs_fn

        if float32_vs_uint8 and not callable(float32_vs_uint8):
            float32_vs_uint8 = lambda other_args, kwargs: (other_args, kwargs)  # noqa: E731
        self.float32_vs_uint8 = float32_vs_uint8
        self.logs_usage = logs_usage
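# A typical entry, mirroring the definitions below, looks roughly like this (only `kernel` and `sample_inputs_fn`
# are required):
#
#   KernelInfo(
#       F.invert_image,
#       kernel_name="invert_image_tensor",
#       sample_inputs_fn=sample_inputs_invert_image_tensor,
#       reference_fn=pil_reference_wrapper(F._invert_image_pil),
#       reference_inputs_fn=reference_inputs_invert_image_tensor,
#       float32_vs_uint8=True,
#   )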


def pixel_difference_closeness_kwargs(uint8_atol, *, dtype=torch.uint8, mae=False):
    return dict(atol=uint8_atol / 255 * get_max_value(dtype), rtol=0, mae=mae)
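# Note: tolerances are specified on the uint8 scale and rescaled to the value range of the actual dtype. Assuming
# `_max_value` returns 1.0 for float dtypes, `pixel_difference_closeness_kwargs(2, dtype=torch.float32)` amounts to
# `atol=2 / 255`, since float images are expected to live in [0, 1].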


def cuda_vs_cpu_pixel_difference(atol=1):
    return {
        (("TestKernels", "test_cuda_vs_cpu"), dtype, "cuda"): pixel_difference_closeness_kwargs(atol, dtype=dtype)
        for dtype in [torch.uint8, torch.float32]
    }


def pil_reference_pixel_difference(atol=1, mae=False):
    return {
        (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): pixel_difference_closeness_kwargs(
            atol, mae=mae
        )
    }


def float32_vs_uint8_pixel_difference(atol=1, mae=False):
    return {
        (
            ("TestKernels", "test_float32_vs_uint8"),
            torch.float32,
            "cpu",
        ): pixel_difference_closeness_kwargs(atol, dtype=torch.float32, mae=mae)
    }


def scripted_vs_eager_float64_tolerances(device, atol=1e-6, rtol=1e-6):
    return {
        (("TestKernels", "test_scripted_vs_eager"), torch.float64, device): {"atol": atol, "rtol": rtol, "mae": False},
    }
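# Note how all closeness kwargs are keyed by ((test class, test name), dtype, device), so each test can look up the
# tolerances that apply to its specific parametrization.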


def pil_reference_wrapper(pil_kernel):
    @functools.wraps(pil_kernel)
    def wrapper(input_tensor, *other_args, **kwargs):
        if input_tensor.dtype != torch.uint8:
            raise pytest.UsageError(f"Can only test uint8 tensor images against PIL, but input is {input_tensor.dtype}")
        if input_tensor.ndim > 3:
            raise pytest.UsageError(
                f"Can only test single tensor images against PIL, but input has shape {input_tensor.shape}"
            )

        input_pil = F.to_pil_image(input_tensor)
        output_pil = pil_kernel(input_pil, *other_args, **kwargs)
        if not isinstance(output_pil, PIL.Image.Image):
            return output_pil

        output_tensor = F.to_image(output_pil)

        # PIL may drop or add the channel dimension for 2D masks; align the output's dimensionality with the input
        if output_tensor.ndim == 2 and input_tensor.ndim == 3:
            output_tensor = output_tensor.unsqueeze(0)
        elif output_tensor.ndim == 3 and input_tensor.ndim == 2:
            output_tensor = output_tensor.squeeze(0)

        return output_tensor

    return wrapper
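# In short: the wrapped PIL kernel can be used as a `reference_fn`, since it takes and returns tensors like the
# kernel under test, with the PIL round-tripping handled internally.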


def xfail_jit(reason, *, condition=None):
    return TestMark(("TestKernels", "test_scripted_vs_eager"), pytest.mark.xfail(reason=reason), condition=condition)


def xfail_jit_python_scalar_arg(name, *, reason=None):
    return xfail_jit(
        reason or f"Python scalar int or float for `{name}` is not supported when scripting",
        condition=lambda args_kwargs: isinstance(args_kwargs.kwargs.get(name), (int, float)),
    )
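# E.g. `xfail_jit_python_scalar_arg("fill")` only xfails samples that were generated with a scalar fill like
# `fill=0.5`; list or tuple fills are still expected to script correctly.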


KERNEL_INFOS = []


def get_fills(*, num_channels, dtype):
    yield None

    int_value = get_max_value(dtype)
    float_value = int_value / 2
    yield int_value
    yield float_value

    for vector_type in [list, tuple]:
        yield vector_type([int_value])
        yield vector_type([float_value])

        if num_channels > 1:
            yield vector_type(float_value * c / 10 for c in range(num_channels))
            yield vector_type(int_value if c % 2 == 0 else 0 for c in range(num_channels))
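# For reference: with num_channels=3 and dtype=torch.uint8 the generator above yields, among others, None, 255,
# 127.5, [255], (127.5,), and the per-channel fills [0.0, 12.75, 25.5] and [255, 0, 255].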


def float32_vs_uint8_fill_adapter(other_args, kwargs):
    fill = kwargs.get("fill")
    if fill is None:
        return other_args, kwargs

    if isinstance(fill, (int, float)):
        fill /= 255
    else:
        fill = type(fill)(fill_ / 255 for fill_ in fill)

    return other_args, dict(kwargs, fill=fill)
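# E.g. a fill of 255 generated for uint8 images becomes 1.0 when the same sample is re-run as float32, and
# `fill=[255, 0]` becomes `fill=[1.0, 0.0]`, matching the [0, 1] value range expected for float images.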


def _get_elastic_displacement(canvas_size):
    return torch.rand(1, *canvas_size, 2)


def sample_inputs_elastic_image_tensor():
    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE]):
        displacement = _get_elastic_displacement(image_loader.canvas_size)
        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
            yield ArgsKwargs(image_loader, displacement=displacement, fill=fill)


def reference_inputs_elastic_image_tensor():
    for image_loader, interpolation in itertools.product(
        make_image_loaders_for_interpolation(),
        [
            F.InterpolationMode.NEAREST,
            F.InterpolationMode.BILINEAR,
            F.InterpolationMode.BICUBIC,
        ],
    ):
        displacement = _get_elastic_displacement(image_loader.canvas_size)
        for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
            yield ArgsKwargs(image_loader, interpolation=interpolation, displacement=displacement, fill=fill)


def sample_inputs_elastic_bounding_boxes():
    for bounding_boxes_loader in make_bounding_box_loaders():
        displacement = _get_elastic_displacement(bounding_boxes_loader.canvas_size)
        yield ArgsKwargs(
            bounding_boxes_loader,
            format=bounding_boxes_loader.format,
            canvas_size=bounding_boxes_loader.canvas_size,
            displacement=displacement,
        )


def sample_inputs_elastic_mask():
    for mask_loader in make_mask_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE]):
        displacement = _get_elastic_displacement(mask_loader.shape[-2:])
        yield ArgsKwargs(mask_loader, displacement=displacement)


def sample_inputs_elastic_video():
    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        displacement = _get_elastic_displacement(video_loader.shape[-2:])
        yield ArgsKwargs(video_loader, displacement=displacement)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.elastic_image,
            sample_inputs_fn=sample_inputs_elastic_image_tensor,
            reference_inputs_fn=reference_inputs_elastic_image_tensor,
            float32_vs_uint8=float32_vs_uint8_fill_adapter,
            closeness_kwargs={
                **float32_vs_uint8_pixel_difference(6, mae=True),
                **cuda_vs_cpu_pixel_difference(),
            },
            test_marks=[xfail_jit_python_scalar_arg("fill")],
        ),
        KernelInfo(
            F.elastic_bounding_boxes,
            sample_inputs_fn=sample_inputs_elastic_bounding_boxes,
        ),
        KernelInfo(
            F.elastic_mask,
            sample_inputs_fn=sample_inputs_elastic_mask,
        ),
        KernelInfo(
            F.elastic_video,
            sample_inputs_fn=sample_inputs_elastic_video,
            closeness_kwargs=cuda_vs_cpu_pixel_difference(),
        ),
    ]
)


def sample_inputs_equalize_image_tensor():
    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader)


def reference_inputs_equalize_image_tensor():
    # We are not using `make_image_loaders` here since that uniformly samples the values over the whole value range.
    # Since the whole point of this kernel is to transform an arbitrary distribution of values into a uniform one,
    # the information gain is low if we already provide something really close to the expected value.
    def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor, memory_format):
        if dtype.is_floating_point:
            low = low_factor
            high = high_factor
        else:
            max_value = torch.iinfo(dtype).max
            low = int(low_factor * max_value)
            high = int(high_factor * max_value)
        return torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high).to(
            memory_format=memory_format, copy=True
        )

    def make_beta_distributed_image(shape, dtype, device, *, alpha, beta, memory_format):
        image = torch.distributions.Beta(alpha, beta).sample(shape)
        if not dtype.is_floating_point:
            image.mul_(torch.iinfo(dtype).max).round_()
        return image.to(dtype=dtype, device=device, memory_format=memory_format, copy=True)

    canvas_size = (256, 256)
    for dtype, color_space, fn in itertools.product(
        [torch.uint8],
        ["GRAY", "RGB"],
        [
            lambda shape, dtype, device, memory_format: torch.zeros(shape, dtype=dtype, device=device).to(
                memory_format=memory_format, copy=True
            ),
            lambda shape, dtype, device, memory_format: torch.full(
                shape, 1.0 if dtype.is_floating_point else torch.iinfo(dtype).max, dtype=dtype, device=device
            ).to(memory_format=memory_format, copy=True),
            *[
                functools.partial(make_uniform_band_image, low_factor=low_factor, high_factor=high_factor)
                for low_factor, high_factor in [
                    (0.0, 0.25),
                    (0.25, 0.75),
                    (0.75, 1.0),
                ]
            ],
            *[
                functools.partial(make_beta_distributed_image, alpha=alpha, beta=beta)
                for alpha, beta in [
                    (0.5, 0.5),
                    (2, 2),
                    (2, 5),
                    (5, 2),
                ]
            ],
        ],
    ):
        image_loader = ImageLoader(fn, shape=(get_num_channels(color_space), *canvas_size), dtype=dtype)
        yield ArgsKwargs(image_loader)


def sample_inputs_equalize_video():
    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.equalize_image,
            kernel_name="equalize_image_tensor",
            sample_inputs_fn=sample_inputs_equalize_image_tensor,
            reference_fn=pil_reference_wrapper(F._equalize_image_pil),
            float32_vs_uint8=True,
            reference_inputs_fn=reference_inputs_equalize_image_tensor,
        ),
        KernelInfo(
            F.equalize_video,
            sample_inputs_fn=sample_inputs_equalize_video,
        ),
    ]
)


def sample_inputs_invert_image_tensor():
    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader)


def reference_inputs_invert_image_tensor():
    for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
        yield ArgsKwargs(image_loader)


def sample_inputs_invert_video():
    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.invert_image,
            kernel_name="invert_image_tensor",
            sample_inputs_fn=sample_inputs_invert_image_tensor,
            reference_fn=pil_reference_wrapper(F._invert_image_pil),
            reference_inputs_fn=reference_inputs_invert_image_tensor,
            float32_vs_uint8=True,
        ),
        KernelInfo(
            F.invert_video,
            sample_inputs_fn=sample_inputs_invert_video,
        ),
    ]
)


_POSTERIZE_BITS = [1, 4, 8]


def sample_inputs_posterize_image_tensor():
    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, bits=_POSTERIZE_BITS[0])


def reference_inputs_posterize_image_tensor():
    for image_loader, bits in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _POSTERIZE_BITS,
    ):
        yield ArgsKwargs(image_loader, bits=bits)


def sample_inputs_posterize_video():
    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader, bits=_POSTERIZE_BITS[0])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.posterize_image,
            kernel_name="posterize_image_tensor",
            sample_inputs_fn=sample_inputs_posterize_image_tensor,
            reference_fn=pil_reference_wrapper(F._posterize_image_pil),
            reference_inputs_fn=reference_inputs_posterize_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs=float32_vs_uint8_pixel_difference(),
        ),
        KernelInfo(
            F.posterize_video,
            sample_inputs_fn=sample_inputs_posterize_video,
        ),
    ]
)


def _get_solarize_thresholds(dtype):
    for factor in [0.1, 0.5]:
        max_value = get_max_value(dtype)
        yield (float if dtype.is_floating_point else int)(max_value * factor)


def sample_inputs_solarize_image_tensor():
    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, threshold=next(_get_solarize_thresholds(image_loader.dtype)))


def reference_inputs_solarize_image_tensor():
    for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
        for threshold in _get_solarize_thresholds(image_loader.dtype):
            yield ArgsKwargs(image_loader, threshold=threshold)


def uint8_to_float32_threshold_adapter(other_args, kwargs):
    return other_args, dict(threshold=kwargs["threshold"] / 255)
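# E.g. a threshold of 128 generated for uint8 inputs becomes 128 / 255 ≈ 0.502 for float32 inputs, which are
# expected to live in [0, 1].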


def sample_inputs_solarize_video():
    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader, threshold=next(_get_solarize_thresholds(video_loader.dtype)))


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.solarize_image,
            kernel_name="solarize_image_tensor",
            sample_inputs_fn=sample_inputs_solarize_image_tensor,
            reference_fn=pil_reference_wrapper(F._solarize_image_pil),
            reference_inputs_fn=reference_inputs_solarize_image_tensor,
            float32_vs_uint8=uint8_to_float32_threshold_adapter,
            closeness_kwargs=float32_vs_uint8_pixel_difference(),
        ),
        KernelInfo(
            F.solarize_video,
            sample_inputs_fn=sample_inputs_solarize_video,
        ),
    ]
)


def sample_inputs_autocontrast_image_tensor():
    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader)


def reference_inputs_autocontrast_image_tensor():
    for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
        yield ArgsKwargs(image_loader)


def sample_inputs_autocontrast_video():
    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.autocontrast_image,
            kernel_name="autocontrast_image_tensor",
            sample_inputs_fn=sample_inputs_autocontrast_image_tensor,
            reference_fn=pil_reference_wrapper(F._autocontrast_image_pil),
            reference_inputs_fn=reference_inputs_autocontrast_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **pil_reference_pixel_difference(),
                **float32_vs_uint8_pixel_difference(),
            },
        ),
        KernelInfo(
            F.autocontrast_video,
            sample_inputs_fn=sample_inputs_autocontrast_video,
        ),
    ]
)

_ADJUST_SHARPNESS_FACTORS = [0.1, 0.5]


def sample_inputs_adjust_sharpness_image_tensor():
    for image_loader in make_image_loaders(
        sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE, (2, 2)],
        color_spaces=("GRAY", "RGB"),
    ):
        yield ArgsKwargs(image_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0])


def reference_inputs_adjust_sharpness_image_tensor():
    for image_loader, sharpness_factor in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _ADJUST_SHARPNESS_FACTORS,
    ):
        yield ArgsKwargs(image_loader, sharpness_factor=sharpness_factor)


def sample_inputs_adjust_sharpness_video():
    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.adjust_sharpness_image,
            kernel_name="adjust_sharpness_image_tensor",
            sample_inputs_fn=sample_inputs_adjust_sharpness_image_tensor,
            reference_fn=pil_reference_wrapper(F._adjust_sharpness_image_pil),
            reference_inputs_fn=reference_inputs_adjust_sharpness_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs=float32_vs_uint8_pixel_difference(2),
        ),
        KernelInfo(
            F.adjust_sharpness_video,
            sample_inputs_fn=sample_inputs_adjust_sharpness_video,
        ),
    ]
)


_ADJUST_CONTRAST_FACTORS = [0.1, 0.5]


def sample_inputs_adjust_contrast_image_tensor():
    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0])


def reference_inputs_adjust_contrast_image_tensor():
    for image_loader, contrast_factor in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _ADJUST_CONTRAST_FACTORS,
    ):
        yield ArgsKwargs(image_loader, contrast_factor=contrast_factor)


def sample_inputs_adjust_contrast_video():
    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.adjust_contrast_image,
            kernel_name="adjust_contrast_image_tensor",
            sample_inputs_fn=sample_inputs_adjust_contrast_image_tensor,
            reference_fn=pil_reference_wrapper(F._adjust_contrast_image_pil),
            reference_inputs_fn=reference_inputs_adjust_contrast_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **pil_reference_pixel_difference(),
                **float32_vs_uint8_pixel_difference(2),
                **cuda_vs_cpu_pixel_difference(),
                (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): pixel_difference_closeness_kwargs(1),
            },
        ),
        KernelInfo(
            F.adjust_contrast_video,
            sample_inputs_fn=sample_inputs_adjust_contrast_video,
            closeness_kwargs={
                **cuda_vs_cpu_pixel_difference(),
                (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): pixel_difference_closeness_kwargs(1),
            },
        ),
    ]
)

_ADJUST_GAMMA_GAMMAS_GAINS = [
    (0.5, 2.0),
    (0.0, 1.0),
]


def sample_inputs_adjust_gamma_image_tensor():
    gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0]
    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, gamma=gamma, gain=gain)


def reference_inputs_adjust_gamma_image_tensor():
    for image_loader, (gamma, gain) in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _ADJUST_GAMMA_GAMMAS_GAINS,
    ):
        yield ArgsKwargs(image_loader, gamma=gamma, gain=gain)


def sample_inputs_adjust_gamma_video():
    gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0]
    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader, gamma=gamma, gain=gain)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.adjust_gamma_image,
            kernel_name="adjust_gamma_image_tensor",
            sample_inputs_fn=sample_inputs_adjust_gamma_image_tensor,
            reference_fn=pil_reference_wrapper(F._adjust_gamma_image_pil),
            reference_inputs_fn=reference_inputs_adjust_gamma_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **pil_reference_pixel_difference(),
                **float32_vs_uint8_pixel_difference(),
            },
        ),
        KernelInfo(
            F.adjust_gamma_video,
            sample_inputs_fn=sample_inputs_adjust_gamma_video,
        ),
    ]
)


_ADJUST_HUE_FACTORS = [-0.1, 0.5]


def sample_inputs_adjust_hue_image_tensor():
    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, hue_factor=_ADJUST_HUE_FACTORS[0])


def reference_inputs_adjust_hue_image_tensor():
    for image_loader, hue_factor in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _ADJUST_HUE_FACTORS,
    ):
        yield ArgsKwargs(image_loader, hue_factor=hue_factor)


def sample_inputs_adjust_hue_video():
    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader, hue_factor=_ADJUST_HUE_FACTORS[0])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.adjust_hue_image,
            kernel_name="adjust_hue_image_tensor",
            sample_inputs_fn=sample_inputs_adjust_hue_image_tensor,
            reference_fn=pil_reference_wrapper(F._adjust_hue_image_pil),
            reference_inputs_fn=reference_inputs_adjust_hue_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **pil_reference_pixel_difference(2, mae=True),
                **float32_vs_uint8_pixel_difference(),
            },
        ),
        KernelInfo(
            F.adjust_hue_video,
            sample_inputs_fn=sample_inputs_adjust_hue_video,
        ),
    ]
)

_ADJUST_SATURATION_FACTORS = [0.1, 0.5]


def sample_inputs_adjust_saturation_image_tensor():
    for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
        yield ArgsKwargs(image_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0])


def reference_inputs_adjust_saturation_image_tensor():
    for image_loader, saturation_factor in itertools.product(
        make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
        _ADJUST_SATURATION_FACTORS,
    ):
        yield ArgsKwargs(image_loader, saturation_factor=saturation_factor)


def sample_inputs_adjust_saturation_video():
    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
        yield ArgsKwargs(video_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0])


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.adjust_saturation_image,
            kernel_name="adjust_saturation_image_tensor",
            sample_inputs_fn=sample_inputs_adjust_saturation_image_tensor,
            reference_fn=pil_reference_wrapper(F._adjust_saturation_image_pil),
            reference_inputs_fn=reference_inputs_adjust_saturation_image_tensor,
            float32_vs_uint8=True,
            closeness_kwargs={
                **pil_reference_pixel_difference(),
                **float32_vs_uint8_pixel_difference(2),
                **cuda_vs_cpu_pixel_difference(),
            },
        ),
        KernelInfo(
            F.adjust_saturation_video,
            sample_inputs_fn=sample_inputs_adjust_saturation_video,
            closeness_kwargs=cuda_vs_cpu_pixel_difference(),
        ),
    ]
)


def sample_inputs_clamp_bounding_boxes():
    for bounding_boxes_loader in make_bounding_box_loaders():
        yield ArgsKwargs(
            bounding_boxes_loader,
            format=bounding_boxes_loader.format,
            canvas_size=bounding_boxes_loader.canvas_size,
        )


KERNEL_INFOS.append(
    KernelInfo(
        F.clamp_bounding_boxes,
        sample_inputs_fn=sample_inputs_clamp_bounding_boxes,
        logs_usage=True,
    )
)

_FIVE_TEN_CROP_SIZES = [7, (6,), [5], (6, 5), [7, 6]]


def _get_five_ten_crop_canvas_size(size):
    if isinstance(size, int):
        crop_height = crop_width = size
    elif len(size) == 1:
        crop_height = crop_width = size[0]
    else:
        crop_height, crop_width = size
    return 2 * crop_height, 2 * crop_width
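# E.g. `size=(6, 5)` results in a (12, 10) canvas, guaranteeing that all five / ten crops fit into the image.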


def sample_inputs_five_crop_image_tensor():
    for size in _FIVE_TEN_CROP_SIZES:
        for image_loader in make_image_loaders(
            sizes=[_get_five_ten_crop_canvas_size(size)],
            color_spaces=["RGB"],
            dtypes=[torch.float32],
        ):
            yield ArgsKwargs(image_loader, size=size)


def reference_inputs_five_crop_image_tensor():
    for size in _FIVE_TEN_CROP_SIZES:
        for image_loader in make_image_loaders(
            sizes=[_get_five_ten_crop_canvas_size(size)], extra_dims=[()], dtypes=[torch.uint8]
        ):
            yield ArgsKwargs(image_loader, size=size)


def sample_inputs_five_crop_video():
    size = _FIVE_TEN_CROP_SIZES[0]
    for video_loader in make_video_loaders(sizes=[_get_five_ten_crop_canvas_size(size)]):
        yield ArgsKwargs(video_loader, size=size)


def sample_inputs_ten_crop_image_tensor():
    for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
        for image_loader in make_image_loaders(
            sizes=[_get_five_ten_crop_canvas_size(size)],
            color_spaces=["RGB"],
            dtypes=[torch.float32],
        ):
            yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip)


def reference_inputs_ten_crop_image_tensor():
    for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
        for image_loader in make_image_loaders(
            sizes=[_get_five_ten_crop_canvas_size(size)], extra_dims=[()], dtypes=[torch.uint8]
        ):
            yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip)


def sample_inputs_ten_crop_video():
    size = _FIVE_TEN_CROP_SIZES[0]
    for video_loader in make_video_loaders(sizes=[_get_five_ten_crop_canvas_size(size)]):
        yield ArgsKwargs(video_loader, size=size)


def multi_crop_pil_reference_wrapper(pil_kernel):
    def wrapper(input_tensor, *other_args, **kwargs):
        output = pil_reference_wrapper(pil_kernel)(input_tensor, *other_args, **kwargs)
        return type(output)(
            F.to_dtype_image(F.to_image(output_pil), dtype=input_tensor.dtype, scale=True) for output_pil in output
        )

    return wrapper
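# Five/ten crop kernels return a sequence of images rather than a single one, so each PIL output is converted back
# and cast to the input dtype individually before comparison.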


_common_five_ten_crop_marks = [
    xfail_jit_python_scalar_arg("size"),
    mark_framework_limitation(("TestKernels", "test_batched_vs_single"), "Custom batching needed."),
]

KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.five_crop_image,
            sample_inputs_fn=sample_inputs_five_crop_image_tensor,
            reference_fn=multi_crop_pil_reference_wrapper(F._five_crop_image_pil),
            reference_inputs_fn=reference_inputs_five_crop_image_tensor,
            test_marks=_common_five_ten_crop_marks,
        ),
        KernelInfo(
            F.five_crop_video,
            sample_inputs_fn=sample_inputs_five_crop_video,
            test_marks=_common_five_ten_crop_marks,
        ),
        KernelInfo(
            F.ten_crop_image,
            sample_inputs_fn=sample_inputs_ten_crop_image_tensor,
            reference_fn=multi_crop_pil_reference_wrapper(F._ten_crop_image_pil),
            reference_inputs_fn=reference_inputs_ten_crop_image_tensor,
            test_marks=_common_five_ten_crop_marks,
        ),
        KernelInfo(
            F.ten_crop_video,
            sample_inputs_fn=sample_inputs_ten_crop_video,
            test_marks=_common_five_ten_crop_marks,
        ),
    ]
)

_NORMALIZE_MEANS_STDS = [
    ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
    (0.5, 2.0),
]


def sample_inputs_normalize_image_tensor():
    for image_loader, (mean, std) in itertools.product(
        make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=["RGB"], dtypes=[torch.float32]),
        _NORMALIZE_MEANS_STDS,
    ):
        yield ArgsKwargs(image_loader, mean=mean, std=std)


def reference_normalize_image_tensor(image, mean, std, inplace=False):
    mean = torch.tensor(mean).view(-1, 1, 1)
    std = torch.tensor(std).view(-1, 1, 1)

    sub = torch.Tensor.sub_ if inplace else torch.Tensor.sub
    return sub(image, mean).div_(std)
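# This is the plain per-channel normalization `(image - mean) / std`; e.g. a pixel value of 0.5 with mean 0.5 and
# std 2.0 maps to (0.5 - 0.5) / 2.0 = 0.0.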


def reference_inputs_normalize_image_tensor():
    yield ArgsKwargs(
        make_image_loader(size=(32, 32), color_space="RGB", extra_dims=[1]),
        mean=[0.5, 0.5, 0.5],
        std=[1.0, 1.0, 1.0],
    )


def sample_inputs_normalize_video():
    mean, std = _NORMALIZE_MEANS_STDS[0]
    for video_loader in make_video_loaders(
        sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=["RGB"], num_frames=[3], dtypes=[torch.float32]
    ):
        yield ArgsKwargs(video_loader, mean=mean, std=std)


KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.normalize_image,
            kernel_name="normalize_image_tensor",
            sample_inputs_fn=sample_inputs_normalize_image_tensor,
            reference_fn=reference_normalize_image_tensor,
            reference_inputs_fn=reference_inputs_normalize_image_tensor,
            test_marks=[
                xfail_jit_python_scalar_arg("mean"),
                xfail_jit_python_scalar_arg("std"),
            ],
        ),
        KernelInfo(
            F.normalize_video,
            sample_inputs_fn=sample_inputs_normalize_video,
        ),
    ]
)


def sample_inputs_uniform_temporal_subsample_video():
    for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[4]):
        yield ArgsKwargs(video_loader, num_samples=2)


def reference_uniform_temporal_subsample_video(x, num_samples):
    # Copy-pasted from
    # https://github.com/facebookresearch/pytorchvideo/blob/c8d23d8b7e597586a9e2d18f6ed31ad8aa379a7a/pytorchvideo/transforms/functional.py#L19
    t = x.shape[-4]
    assert num_samples > 0 and t > 0
    # Sample by nearest neighbor interpolation if num_samples > t.
    indices = torch.linspace(0, t - 1, num_samples)
    indices = torch.clamp(indices, 0, t - 1).long()
    return torch.index_select(x, -4, indices)
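# E.g. for t = 10 frames and num_samples = 4, `torch.linspace(0, 9, 4)` gives the indices [0, 3, 6, 9], i.e. frames
# sampled at evenly spaced temporal positions.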


def reference_inputs_uniform_temporal_subsample_video():
    for video_loader in make_video_loaders(
        sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=["RGB"], num_frames=[10]
    ):
        for num_samples in range(1, video_loader.shape[-4] + 1):
            yield ArgsKwargs(video_loader, num_samples)


KERNEL_INFOS.append(
    KernelInfo(
        F.uniform_temporal_subsample_video,
        sample_inputs_fn=sample_inputs_uniform_temporal_subsample_video,
        reference_fn=reference_uniform_temporal_subsample_video,
        reference_inputs_fn=reference_inputs_uniform_temporal_subsample_video,
    )
)