"vscode:/vscode.git/clone" did not exist on "6d6a8bc278eac424214e73544ae010bde3fb99cb"
test_transforms_v2_consistency.py 25.8 KB
Newer Older
1
2
import importlib.machinery
import importlib.util
3
import inspect
4
import random
5
import re
6
from pathlib import Path
7

8
import numpy as np
9
import pytest
10
11

import torch
12
import torchvision.transforms.v2 as v2_transforms
13
from common_utils import assert_close, assert_equal, set_rng_seed
14
from torch import nn
15
from torchvision import transforms as legacy_transforms, tv_tensors
16
from torchvision._utils import sequence_to_str
17

18
from torchvision.transforms import functional as legacy_F
19
from torchvision.transforms.v2 import functional as prototype_F
Nicolas Hug's avatar
Nicolas Hug committed
20
from torchvision.transforms.v2._utils import _get_fill, query_size
21
from torchvision.transforms.v2.functional import to_pil_image
22
23
24
25
26
27
28
29
from transforms_v2_legacy_utils import (
    ArgsKwargs,
    make_bounding_boxes,
    make_detection_mask,
    make_image,
    make_images,
    make_segmentation_mask,
)
30

31
# Baseline arguments for `make_images` used by most consistency checks: RGB images with one extra batch dim of 4.
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)])
32
33


Nicolas Hug's avatar
Nicolas Hug committed
34
35
36
37
38
39
@pytest.fixture(autouse=True)
def fix_rng_seed():
    # Re-seed the RNG before every test so random transform parameters are reproducible
    # and identical between the legacy and v2 pipelines.
    set_rng_seed(0)
    yield


40
41
42
43
44
45
46
47
48
class NotScriptableArgsKwargs(ArgsKwargs):
    """Marker for parameter sets that render the transform non-scriptable.

    These parameters still work in eager mode and thus will be tested there,
    but the JIT tests skip them.
    """


49
50
class ConsistencyConfig:
    """Pairs a v2 (prototype) transform class with its v1 (legacy) counterpart for the consistency tests.

    Attributes mirror the constructor parameters; falsy ``make_images_kwargs`` / ``closeness_kwargs``
    fall back to module-wide defaults.
    """

    def __init__(
        self,
        prototype_cls,
        legacy_cls,
        # If no args_kwargs is passed, only the signature will be checked
        args_kwargs=(),
        make_images_kwargs=None,
        supports_pil=True,
        removed_params=(),
        closeness_kwargs=None,
    ):
        self.prototype_cls = prototype_cls
        self.legacy_cls = legacy_cls
        self.args_kwargs = args_kwargs
        # Any falsy value (None or an empty dict) selects the module-wide default kwargs.
        self.make_images_kwargs = make_images_kwargs if make_images_kwargs else DEFAULT_MAKE_IMAGES_KWARGS
        self.supports_pil = supports_pil
        self.removed_params = removed_params
        # Default to an exact (bit-for-bit) comparison.
        self.closeness_kwargs = closeness_kwargs if closeness_kwargs else dict(rtol=0, atol=0)
68
69


70
71
72
73
# These are here since both the prototype and legacy transform need to be constructed with the same random parameters
LINEAR_TRANSFORMATION_MEAN = torch.rand(36)
LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2)

# One `ConsistencyConfig` per legacy transform covered by this file's signature/call consistency tests.
CONSISTENCY_CONFIGS = [
    # LinearTransformation is checked for several matrix/image dtype combinations.
    *[
        ConsistencyConfig(
            v2_transforms.LinearTransformation,
            legacy_transforms.LinearTransformation,
            [
                ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX.to(matrix_dtype), LINEAR_TRANSFORMATION_MEAN.to(matrix_dtype)),
            ],
            # Make sure that the product of the height, width and number of channels matches the number of elements in
            # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
            make_images_kwargs=dict(
                DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"], dtypes=[image_dtype]
            ),
            supports_pil=False,
        )
        for matrix_dtype, image_dtype in [
            (torch.float32, torch.float32),
            (torch.float64, torch.float64),
            (torch.float32, torch.uint8),
            (torch.float64, torch.float32),
            (torch.float32, torch.float64),
        ]
    ],
    ConsistencyConfig(
        v2_transforms.ToPILImage,
        legacy_transforms.ToPILImage,
        [NotScriptableArgsKwargs()],
        make_images_kwargs=dict(
            color_spaces=[
                "GRAY",
                "GRAY_ALPHA",
                "RGB",
                "RGBA",
            ],
            extra_dims=[()],
        ),
        supports_pil=False,
    ),
    ConsistencyConfig(
        v2_transforms.Lambda,
        legacy_transforms.Lambda,
        [
            NotScriptableArgsKwargs(lambda image: image / 2),
        ],
        # Technically, this also supports PIL, but it is overkill to write a function here that supports tensor and PIL
        # images given that the transform does nothing but call it anyway.
        supports_pil=False,
    ),
    # The remaining configs pass no `args_kwargs`, so only the constructor signatures are compared; the container
    # transforms additionally get dedicated call tests in `TestContainerTransforms` below.
    ConsistencyConfig(
        v2_transforms.PILToTensor,
        legacy_transforms.PILToTensor,
    ),
    ConsistencyConfig(
        v2_transforms.ToTensor,
        legacy_transforms.ToTensor,
    ),
    ConsistencyConfig(
        v2_transforms.Compose,
        legacy_transforms.Compose,
    ),
    ConsistencyConfig(
        v2_transforms.RandomApply,
        legacy_transforms.RandomApply,
    ),
    ConsistencyConfig(
        v2_transforms.RandomChoice,
        legacy_transforms.RandomChoice,
    ),
    ConsistencyConfig(
        v2_transforms.RandomOrder,
        legacy_transforms.RandomOrder,
    ),
]


149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
@pytest.mark.parametrize("config", CONSISTENCY_CONFIGS, ids=lambda config: config.legacy_cls.__name__)
def test_signature_consistency(config):
    """Check that the v2 transform's constructor signature is backward compatible with the legacy one."""
    legacy_params = dict(inspect.signature(config.legacy_cls).parameters)
    prototype_params = dict(inspect.signature(config.prototype_cls).parameters)

    # Parameters that were intentionally dropped from the v2 transform are exempt from the check.
    for param in config.removed_params:
        legacy_params.pop(param, None)

    missing = legacy_params.keys() - prototype_params.keys()
    if missing:
        raise AssertionError(
            f"The prototype transform does not support the parameters "
            f"{sequence_to_str(sorted(missing), separate_last='and ')}, but the legacy transform does. "
            f"If that is intentional, e.g. pending deprecation, please add the parameters to the `removed_params` on "
            f"the `ConsistencyConfig`."
        )

    extra = prototype_params.keys() - legacy_params.keys()
    # Extra parameters on the v2 transform are fine as long as they have a default (or are *args / **kwargs),
    # since then legacy call sites keep working unchanged.
    extra_without_default = {
        param
        for param in extra
        if prototype_params[param].default is inspect.Parameter.empty
        and prototype_params[param].kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
    }
    if extra_without_default:
        raise AssertionError(
            f"The prototype transform requires the parameters "
            f"{sequence_to_str(sorted(extra_without_default), separate_last='and ')}, but the legacy transform does "
            f"not. Please add a default value."
        )

    legacy_signature = list(legacy_params.keys())
    # Since we made sure that we don't have any extra parameters without default above, we clamp the prototype signature
    # to the same number of parameters as the legacy one
    prototype_signature = list(prototype_params.keys())[: len(legacy_signature)]

    assert prototype_signature == legacy_signature
186
187


188
189
190
def check_call_consistency(
    prototype_transform, legacy_transform, images=None, supports_pil=True, closeness_kwargs=None
):
    """Apply both transforms to the same inputs and assert that the outputs are close.

    Three input flavors are checked per image: a plain tensor, the tv_tensor image itself, and — when
    ``supports_pil`` is set and the image has no extra batch dims (``ndim == 3``) — a PIL image. The RNG is
    re-seeded before every call so both transforms sample the same random parameters. ``closeness_kwargs``
    is forwarded to ``assert_close``.
    """
    if images is None:
        images = make_images(**DEFAULT_MAKE_IMAGES_KWARGS)

    closeness_kwargs = closeness_kwargs or dict()

    for image in images:
        # Short shape/dtype description used in all error messages below.
        image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]"

        # Plain-tensor path: strip the tv_tensor subclass.
        image_tensor = torch.Tensor(image)
        try:
            torch.manual_seed(0)
            output_legacy_tensor = legacy_transform(image_tensor)
        except Exception as exc:
            # A failure in the *legacy* transform means the test inputs are wrong, not the v2 transform.
            raise pytest.UsageError(
                f"Transforming a tensor image {image_repr} failed in the legacy transform with the "
                f"error above. This means that you need to specify the parameters passed to `make_images` through the "
                "`make_images_kwargs` of the `ConsistencyConfig`."
            ) from exc

        try:
            torch.manual_seed(0)
            output_prototype_tensor = prototype_transform(image_tensor)
        except Exception as exc:
            raise AssertionError(
                f"Transforming a tensor image with shape {image_repr} failed in the prototype transform with "
                f"the error above. This means there is a consistency bug either in `_get_params` or in the "
                f"`is_pure_tensor` path in `_transform`."
            ) from exc

        assert_close(
            output_prototype_tensor,
            output_legacy_tensor,
            msg=lambda msg: f"Tensor image consistency check failed with: \n\n{msg}",
            **closeness_kwargs,
        )

        # tv_tensor path: must match the plain-tensor output of the prototype transform.
        try:
            torch.manual_seed(0)
            output_prototype_image = prototype_transform(image)
        except Exception as exc:
            raise AssertionError(
                f"Transforming a image tv_tensor with shape {image_repr} failed in the prototype transform with "
                f"the error above. This means there is a consistency bug either in `_get_params` or in the "
                f"`tv_tensors.Image` path in `_transform`."
            ) from exc

        assert_close(
            output_prototype_image,
            output_prototype_tensor,
            msg=lambda msg: f"Output for tv_tensor and tensor images is not equal: \n\n{msg}",
            **closeness_kwargs,
        )

        # PIL path: only possible for single images, i.e. without extra batch dims.
        if image.ndim == 3 and supports_pil:
            image_pil = to_pil_image(image)

            try:
                torch.manual_seed(0)
                output_legacy_pil = legacy_transform(image_pil)
            except Exception as exc:
                raise pytest.UsageError(
                    f"Transforming a PIL image with shape {image_repr} failed in the legacy transform with the "
                    f"error above. If this transform does not support PIL images, set `supports_pil=False` on the "
                    "`ConsistencyConfig`. "
                ) from exc

            try:
                torch.manual_seed(0)
                output_prototype_pil = prototype_transform(image_pil)
            except Exception as exc:
                raise AssertionError(
                    f"Transforming a PIL image with shape {image_repr} failed in the prototype transform with "
                    f"the error above. This means there is a consistency bug either in `_get_params` or in the "
                    f"`PIL.Image.Image` path in `_transform`."
                ) from exc

            assert_close(
                output_prototype_pil,
                output_legacy_pil,
                msg=lambda msg: f"PIL image consistency check failed with: \n\n{msg}",
                **closeness_kwargs,
            )
273
274


275
@pytest.mark.parametrize(
    ("config", "args_kwargs"),
    [
        pytest.param(
            config, args_kwargs, id=f"{config.legacy_cls.__name__}-{idx:0{len(str(len(config.args_kwargs)))}d}"
        )
        for config in CONSISTENCY_CONFIGS
        for idx, args_kwargs in enumerate(config.args_kwargs)
    ],
)
@pytest.mark.filterwarnings("ignore")
def test_call_consistency(config, args_kwargs):
    """Construct both transforms with the same arguments and compare their call behavior."""
    args, kwargs = args_kwargs

    # A failure to construct the *legacy* transform means the test parametrization itself is broken.
    try:
        legacy_transform = config.legacy_cls(*args, **kwargs)
    except Exception as exc:
        raise pytest.UsageError(
            f"Initializing the legacy transform failed with the error above. "
            f"Please correct the `ArgsKwargs({args_kwargs})` in the `ConsistencyConfig`."
        ) from exc

    # A failure to construct the *prototype* transform with arguments the legacy one accepts is a real bug.
    try:
        prototype_transform = config.prototype_cls(*args, **kwargs)
    except Exception as exc:
        raise AssertionError(
            "Initializing the prototype transform failed with the error above. "
            "This means there is a consistency bug in the constructor."
        ) from exc

    check_call_consistency(
        prototype_transform,
        legacy_transform,
        images=make_images(**config.make_images_kwargs),
        supports_pil=config.supports_pil,
        closeness_kwargs=config.closeness_kwargs,
    )


314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
@pytest.mark.parametrize(
    ("config", "args_kwargs"),
    [
        pytest.param(
            config, args_kwargs, id=f"{config.legacy_cls.__name__}-{idx:0{len(str(len(config.args_kwargs)))}d}"
        )
        for config in CONSISTENCY_CONFIGS
        for idx, args_kwargs in enumerate(config.args_kwargs)
        if not isinstance(args_kwargs, NotScriptableArgsKwargs)
    ],
)
def test_jit_consistency(config, args_kwargs):
    """Script both transforms and compare their outputs on plain tensors."""
    args, kwargs = args_kwargs

    v2_eager = config.prototype_cls(*args, **kwargs)
    v1_eager = config.legacy_cls(*args, **kwargs)

    v1_scripted = torch.jit.script(v1_eager)
    v2_scripted = torch.jit.script(v2_eager)

    for image in config.make_images_kwargs and make_images(**config.make_images_kwargs):
        # Scripted transforms only see plain tensors, so drop the tv_tensor subclass.
        plain_tensor = image.as_subclass(torch.Tensor)

        torch.manual_seed(0)
        v1_output = v1_scripted(plain_tensor)

        torch.manual_seed(0)
        v2_output = v2_scripted(plain_tensor)

        assert_close(v2_output, v1_output, **config.closeness_kwargs)


346
347
348
349
350
351
352
353
354
355
class TestContainerTransforms:
    """
    Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for
    consistency automatically tests the wrapped transforms consistency.

    Instead of complicated mocking or creating custom transforms just for these tests, here we use deterministic ones
    that were already tested for consistency above.
    """

    def test_compose(self):
        prototype_transform = v2_transforms.Compose(
            [
                v2_transforms.Resize(256),
                v2_transforms.CenterCrop(224),
            ]
        )
        legacy_transform = legacy_transforms.Compose(
            [
                legacy_transforms.Resize(256),
                legacy_transforms.CenterCrop(224),
            ]
        )

        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))

    @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
    @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
    def test_random_apply(self, p, sequence_type):
        prototype_transform = v2_transforms.RandomApply(
            sequence_type(
                [
                    v2_transforms.Resize(256),
                    v2_transforms.CenterCrop(224),
                ]
            ),
            p=p,
        )
        legacy_transform = legacy_transforms.RandomApply(
            sequence_type(
                [
                    legacy_transforms.Resize(256),
                    legacy_transforms.CenterCrop(224),
                ]
            ),
            p=p,
        )

        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))

        if sequence_type is nn.ModuleList:
            # quick and dirty test that it is jit-scriptable
            scripted = torch.jit.script(prototype_transform)
            scripted(torch.rand(1, 3, 300, 300))

    # We can't test other values for `p` since the random parameter generation is different
    @pytest.mark.parametrize("probabilities", [(0, 1), (1, 0)])
    def test_random_choice(self, probabilities):
        prototype_transform = v2_transforms.RandomChoice(
            [
                v2_transforms.Resize(256),
                # NOTE(review): this is a *legacy* CenterCrop inside the v2 container — presumably deliberate, to
                # check that v2 containers can also wrap legacy transforms; confirm, otherwise this should be
                # `v2_transforms.CenterCrop(224)` to mirror `test_compose`/`test_random_apply` above.
                legacy_transforms.CenterCrop(224),
            ],
            p=probabilities,
        )
        legacy_transform = legacy_transforms.RandomChoice(
            [
                legacy_transforms.Resize(256),
                legacy_transforms.CenterCrop(224),
            ],
            p=probabilities,
        )

        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
422
423


424
425
class TestToTensorTransforms:
    """Consistency checks for the tensor-conversion transforms (PILToTensor / ToTensor)."""

    def test_pil_to_tensor(self):
        t_v2 = v2_transforms.PILToTensor()
        t_v1 = legacy_transforms.PILToTensor()

        for image in make_images(extra_dims=[()]):
            pil = to_pil_image(image)
            assert_equal(t_v2(pil), t_v1(pil))

    def test_to_tensor(self):
        # Constructing the v2 ToTensor emits a deprecation warning; make sure it does.
        with pytest.warns(UserWarning, match=re.escape("The transform `ToTensor()` is deprecated")):
            t_v2 = v2_transforms.ToTensor()
        t_v1 = legacy_transforms.ToTensor()

        for image in make_images(extra_dims=[()]):
            pil = to_pil_image(image)
            arr = np.array(pil)

            # Both PIL and numpy inputs must behave like the legacy transform.
            assert_equal(t_v2(pil), t_v1(pil))
            assert_equal(t_v2(arr), t_v1(arr))
445
446


447
def import_transforms_from_references(reference):
    """Load ``references/<reference>/transforms.py`` from the repository root as a standalone module."""
    project_root = Path(__file__).parent.parent
    transforms_path = project_root / "references" / reference / "transforms.py"

    loader = importlib.machinery.SourceFileLoader("transforms", str(transforms_path))
    spec = importlib.util.spec_from_loader("transforms", loader)
    module = importlib.util.module_from_spec(spec)
    loader.exec_module(module)
    return module
458
459
460


# Reference implementations from the detection training scripts, compared against below.
det_transforms = import_transforms_from_references("detection")
461
462
463


class TestRefDetTransforms:
    """Compare the v2 transforms against their counterparts from the detection reference scripts."""

    def make_tv_tensors(self, with_mask=True):
        """Yield ``(image, target)`` detection samples as a PIL image, a plain tensor, and a tv_tensor image."""
        size = (600, 800)
        num_objects = 22

        def make_label(extra_dims, categories):
            # Random class indices in [0, categories).
            return torch.randint(categories, extra_dims, dtype=torch.int64)

        pil_image = to_pil_image(make_image(size=size, color_space="RGB"))
        target = {
            "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
        }
        if with_mask:
            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

        yield (pil_image, target)

        tensor_image = torch.Tensor(make_image(size=size, color_space="RGB", dtype=torch.float32))
        target = {
            "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
        }
        if with_mask:
            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

        yield (tensor_image, target)

        tv_tensor_image = make_image(size=size, color_space="RGB", dtype=torch.float32)
        target = {
            "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
        }
        if with_mask:
            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

        yield (tv_tensor_image, target)

    @pytest.mark.parametrize(
        "t_ref, t, data_kwargs",
        [
            (det_transforms.RandomHorizontalFlip(p=1.0), v2_transforms.RandomHorizontalFlip(p=1.0), {}),
            (
                det_transforms.RandomIoUCrop(),
                v2_transforms.Compose(
                    [
                        v2_transforms.RandomIoUCrop(),
                        v2_transforms.SanitizeBoundingBoxes(labels_getter=lambda sample: sample[1]["labels"]),
                    ]
                ),
                {"with_mask": False},
            ),
            (det_transforms.RandomZoomOut(), v2_transforms.RandomZoomOut(), {"with_mask": False}),
            (det_transforms.ScaleJitter((1024, 1024)), v2_transforms.ScaleJitter((1024, 1024), antialias=True), {}),
            (
                det_transforms.RandomShortestSize(
                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
                ),
                v2_transforms.RandomShortestSize(
                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
                ),
                {},
            ),
        ],
    )
    def test_transform(self, t_ref, t, data_kwargs):
        for dp in self.make_tv_tensors(**data_kwargs):

            # We should use prototype transform first as reference transform performs inplace target update
            torch.manual_seed(12)
            output = t(dp)

            torch.manual_seed(12)
            expected_output = t_ref(*dp)

            assert_equal(expected_output, output)
539
540
541
542
543
544
545
546
547


# Reference implementations from the segmentation training scripts, compared against below.
seg_transforms = import_transforms_from_references("segmentation")


# We need this transform for two reasons:
# 1. transforms.RandomCrop uses a different scheme to pad images and masks of insufficient size than its name
#    counterpart in the detection references. Thus, we cannot use it with `pad_if_needed=True`
# 2. transforms.Pad only supports a fixed padding, but the segmentation datasets don't have a fixed image size.
class PadIfSmaller(v2_transforms.Transform):
    """Pad the right/bottom of the input so that both spatial sides are at least ``size`` pixels."""

    def __init__(self, size, fill=0):
        super().__init__()
        self.size = size
        # Normalize `fill` with the v2 internal helper so `_get_fill` below can look it up per input type.
        self.fill = v2_transforms._geometry._setup_fill_arg(fill)

    def _get_params(self, sample):
        height, width = query_size(sample)
        # Only pad the right and bottom edges (left/top padding stays 0).
        padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
        needs_padding = any(padding)
        return dict(padding=padding, needs_padding=needs_padding)

    def _transform(self, inpt, params):
        # Fast path: nothing to do when the input is already large enough.
        if not params["needs_padding"]:
            return inpt

        # Select the fill value registered for this input type (e.g. masks vs. images).
        fill = _get_fill(self.fill, type(inpt))
        return prototype_F.pad(inpt, padding=params["padding"], fill=fill)
566
567
568


class TestRefSegTransforms:
    """Compare the v2 transforms against their counterparts from the segmentation reference scripts."""

    def make_tv_tensors(self, supports_pil=True, image_dtype=torch.uint8):
        """Yield ``(dp, dp_ref)`` pairs: the v2 input sample and the equivalent reference-style sample."""
        size = (256, 460)
        num_categories = 21

        # The v2 input is produced in several flavors: PIL (optional), plain tensor, and tv_tensor.
        conv_fns = []
        if supports_pil:
            conv_fns.append(to_pil_image)
        conv_fns.extend([torch.Tensor, lambda x: x])

        for conv_fn in conv_fns:
            tv_tensor_image = make_image(size=size, color_space="RGB", dtype=image_dtype)
            tv_tensor_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)

            dp = (conv_fn(tv_tensor_image), tv_tensor_mask)
            # The reference transforms consume PIL images (or plain tensors when PIL is unsupported).
            dp_ref = (
                to_pil_image(tv_tensor_image) if supports_pil else tv_tensor_image.as_subclass(torch.Tensor),
                to_pil_image(tv_tensor_mask),
            )

            yield dp, dp_ref

    def set_seed(self, seed=12):
        # The reference transforms draw from both torch and the stdlib RNG, so seed both.
        torch.manual_seed(seed)
        random.seed(seed)

    def check(self, t, t_ref, data_kwargs=None):
        """Run the v2 transform ``t`` and the reference transform ``t_ref`` on equivalent inputs and compare."""
        for dp, dp_ref in self.make_tv_tensors(**data_kwargs or dict()):

            self.set_seed()
            actual = actual_image, actual_mask = t(dp)

            self.set_seed()
            expected_image, expected_mask = t_ref(*dp_ref)
            # The reference pipeline returns PIL outputs; convert them to tensors for the comparison.
            if isinstance(actual_image, torch.Tensor) and not isinstance(expected_image, torch.Tensor):
                expected_image = legacy_F.pil_to_tensor(expected_image)
            expected_mask = legacy_F.pil_to_tensor(expected_mask).squeeze(0)
            expected = (expected_image, expected_mask)

            assert_equal(actual, expected)

    @pytest.mark.parametrize(
        ("t_ref", "t", "data_kwargs"),
        [
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=1.0),
                v2_transforms.RandomHorizontalFlip(p=1.0),
                dict(),
            ),
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=0.0),
                v2_transforms.RandomHorizontalFlip(p=0.0),
                dict(),
            ),
            (
                seg_transforms.RandomCrop(size=480),
                v2_transforms.Compose(
                    [
                        PadIfSmaller(size=480, fill={tv_tensors.Mask: 255, "others": 0}),
                        v2_transforms.RandomCrop(size=480),
                    ]
                ),
                dict(),
            ),
            (
                seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                v2_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                dict(supports_pil=False, image_dtype=torch.float),
            ),
        ],
    )
    def test_common(self, t_ref, t, data_kwargs):
        self.check(t, t_ref, data_kwargs)

642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691

@pytest.mark.parametrize(
    ("legacy_dispatcher", "name_only_params"),
    [
        (legacy_F.get_dimensions, {}),
        (legacy_F.get_image_size, {}),
        (legacy_F.get_image_num_channels, {}),
        (legacy_F.to_tensor, {}),
        (legacy_F.pil_to_tensor, {}),
        (legacy_F.convert_image_dtype, {}),
        (legacy_F.to_pil_image, {}),
        (legacy_F.to_grayscale, {}),
        (legacy_F.rgb_to_grayscale, {}),
        # NOTE: `legacy_F.to_tensor` was listed a second time here; the duplicate entry was removed since it only
        # re-ran the identical check under a disambiguated test id.
    ],
)
def test_dispatcher_signature_consistency(legacy_dispatcher, name_only_params):
    """Check that the v2 functional with the same name is backward compatible with the legacy dispatcher.

    ``name_only_params`` lists parameter names for which only the name is compared, i.e. annotation and default
    value are ignored.
    """
    # Skip the first parameter of both dispatchers; only the remaining ones are compared.
    legacy_signature = inspect.signature(legacy_dispatcher)
    legacy_params = list(legacy_signature.parameters.values())[1:]

    try:
        prototype_dispatcher = getattr(prototype_F, legacy_dispatcher.__name__)
    except AttributeError:
        raise AssertionError(
            f"Legacy dispatcher `F.{legacy_dispatcher.__name__}` has no prototype equivalent"
        ) from None

    prototype_signature = inspect.signature(prototype_dispatcher)
    prototype_params = list(prototype_signature.parameters.values())[1:]

    # Some dispatchers got extra parameters. This makes sure they have a default argument and thus are BC. We don't
    # need to check if parameters were added in the middle rather than at the end, since that will be caught by the
    # regular check below.
    prototype_params, new_prototype_params = (
        prototype_params[: len(legacy_params)],
        prototype_params[len(legacy_params) :],
    )
    for param in new_prototype_params:
        assert param.default is not param.empty

    # Some annotations were changed mostly to supersets of what was there before. Plus, some legacy dispatchers had no
    # annotations. In these cases we simply drop the annotation and default argument from the comparison
    for prototype_param, legacy_param in zip(prototype_params, legacy_params):
        if legacy_param.name in name_only_params:
            # `inspect.Parameter` is immutable through its public API, so the private attributes are poked directly.
            prototype_param._annotation = prototype_param._default = inspect.Parameter.empty
            legacy_param._annotation = legacy_param._default = inspect.Parameter.empty
        elif legacy_param.annotation is inspect.Parameter.empty:
            prototype_param._annotation = inspect.Parameter.empty

    assert prototype_params == legacy_params