test_transforms_v2_consistency.py 23.9 KB
Newer Older
1
2
import importlib.machinery
import importlib.util
3
import inspect
4
import random
5
import re
6
from pathlib import Path
7

8
import numpy as np
9
import pytest
10
11

import torch
12
import torchvision.transforms.v2 as v2_transforms
13
from common_utils import assert_close, assert_equal, set_rng_seed
14
from torch import nn
15
from torchvision import transforms as legacy_transforms, tv_tensors
16
from torchvision._utils import sequence_to_str
17

18
from torchvision.transforms import functional as legacy_F
19
from torchvision.transforms.v2 import functional as prototype_F
Nicolas Hug's avatar
Nicolas Hug committed
20
from torchvision.transforms.v2._utils import _get_fill, query_size
21
from torchvision.transforms.v2.functional import to_pil_image
22
23
24
25
26
27
28
29
from transforms_v2_legacy_utils import (
    ArgsKwargs,
    make_bounding_boxes,
    make_detection_mask,
    make_image,
    make_images,
    make_segmentation_mask,
)
30

31
# Default keyword arguments for `make_images`, used by `ConsistencyConfig` and `check_call_consistency` when no
# explicit image configuration is given.
DEFAULT_MAKE_IMAGES_KWARGS = {"color_spaces": ["RGB"], "extra_dims": [(4,)]}
32
33


Nicolas Hug's avatar
Nicolas Hug committed
34
35
36
37
38
39
@pytest.fixture(autouse=True)
def fix_rng_seed():
    """Autouse fixture: call ``set_rng_seed(0)`` before every test in this module, then hand control to the test."""
    set_rng_seed(0)
    yield None


40
41
42
43
44
45
46
47
48
class NotScriptableArgsKwargs(ArgsKwargs):
    """Marker subclass of ``ArgsKwargs`` for parameter sets that make a transform non-scriptable.

    Such parameter sets still run in the eager-mode tests, but ``test_jit_consistency`` filters them out with an
    ``isinstance`` check.
    """


49
50
class ConsistencyConfig:
    """Pair a v2 (prototype) transform class with its legacy counterpart for the consistency tests.

    Args:
        prototype_cls: The ``torchvision.transforms.v2`` transform class.
        legacy_cls: The legacy ``torchvision.transforms`` transform class.
        args_kwargs: Constructor arguments to test with. If empty, only the signatures are compared.
        make_images_kwargs: Keyword arguments for ``make_images``; ``DEFAULT_MAKE_IMAGES_KWARGS`` is used when falsy.
        supports_pil: Forwarded to ``check_call_consistency`` to enable or disable the PIL comparison.
        removed_params: Legacy constructor parameters the prototype intentionally does not support.
        closeness_kwargs: Keyword arguments for ``assert_close``; ``dict(rtol=0, atol=0)`` is used when falsy.
    """

    def __init__(
        self,
        prototype_cls,
        legacy_cls,
        # If no args_kwargs is passed, only the signature will be checked
        args_kwargs=(),
        make_images_kwargs=None,
        supports_pil=True,
        removed_params=(),
        closeness_kwargs=None,
    ):
        # Falsy values (None, {}) fall back to the defaults.
        if not make_images_kwargs:
            make_images_kwargs = DEFAULT_MAKE_IMAGES_KWARGS
        if not closeness_kwargs:
            closeness_kwargs = dict(rtol=0, atol=0)

        self.prototype_cls, self.legacy_cls = prototype_cls, legacy_cls
        self.args_kwargs = args_kwargs
        self.make_images_kwargs = make_images_kwargs
        self.supports_pil = supports_pil
        self.removed_params = removed_params
        self.closeness_kwargs = closeness_kwargs
68
69


70
71
72
73
# The prototype and the legacy transform must be constructed from the *same* random parameters, so these are drawn
# once at import time and shared by both.
LINEAR_TRANSFORMATION_MEAN = torch.rand(36)
LINEAR_TRANSFORMATION_MATRIX = torch.rand(LINEAR_TRANSFORMATION_MEAN.numel(), LINEAR_TRANSFORMATION_MEAN.numel())

74
# Container transforms compared between the v2 and the legacy API. None of them define `args_kwargs`, so only their
# constructor signatures are checked (see `test_signature_consistency`).
CONSISTENCY_CONFIGS = [
    ConsistencyConfig(prototype_cls, legacy_cls)
    for prototype_cls, legacy_cls in (
        (v2_transforms.Compose, legacy_transforms.Compose),
        (v2_transforms.RandomApply, legacy_transforms.RandomApply),
        (v2_transforms.RandomChoice, legacy_transforms.RandomChoice),
        (v2_transforms.RandomOrder, legacy_transforms.RandomOrder),
    )
]


94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
@pytest.mark.parametrize("config", CONSISTENCY_CONFIGS, ids=lambda config: config.legacy_cls.__name__)
def test_signature_consistency(config):
    """Check that the v2 constructor signature is a backward-compatible superset of the legacy one.

    Every legacy parameter (except those in ``config.removed_params``) must exist on the prototype. Every extra
    prototype parameter must have a default or be ``*args``/``**kwargs``. The shared parameters must come first, in the
    same order.
    """
    legacy_params = dict(inspect.signature(config.legacy_cls).parameters)
    prototype_params = dict(inspect.signature(config.prototype_cls).parameters)

    for removed in config.removed_params:
        legacy_params.pop(removed, None)

    missing = legacy_params.keys() - prototype_params.keys()
    if missing:
        raise AssertionError(
            f"The prototype transform does not support the parameters "
            f"{sequence_to_str(sorted(missing), separate_last='and ')}, but the legacy transform does. "
            f"If that is intentional, e.g. pending deprecation, please add the parameters to the `removed_params` on "
            f"the `ConsistencyConfig`."
        )

    variadic_kinds = {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
    extra_without_default = set()
    for name in prototype_params.keys() - legacy_params.keys():
        param = prototype_params[name]
        if param.default is inspect.Parameter.empty and param.kind not in variadic_kinds:
            extra_without_default.add(name)

    if extra_without_default:
        raise AssertionError(
            f"The prototype transform requires the parameters "
            f"{sequence_to_str(sorted(extra_without_default), separate_last='and ')}, but the legacy transform does "
            f"not. Please add a default value."
        )

    # Every extra prototype parameter has a default (checked above), so only the leading prototype parameters, as many
    # as the legacy transform has, need to line up with the legacy signature.
    num_legacy = len(legacy_params)
    assert list(prototype_params)[:num_legacy] == list(legacy_params)
131
132


133
134
135
def _call_seeded(transform, inpt, *, error_cls, error_msg):
    """Seed torch's RNG with ``0`` and return ``transform(inpt)``.

    Seeding right before every call makes the legacy and the prototype transform start from the same RNG state. If the
    seeding or the call raises, the exception is re-raised as ``error_cls(error_msg)`` chained to the original one, so
    the root cause stays visible in the traceback.
    """
    try:
        torch.manual_seed(0)
        return transform(inpt)
    except Exception as exc:
        raise error_cls(error_msg) from exc


def check_call_consistency(
    prototype_transform, legacy_transform, images=None, supports_pil=True, closeness_kwargs=None
):
    """Assert that ``prototype_transform`` and ``legacy_transform`` produce matching outputs.

    For every image this compares:

    1. legacy vs. prototype output on a plain ``torch.Tensor`` copy of the image,
    2. prototype output on the original image object vs. the plain-tensor prototype output, and
    3. if ``supports_pil`` is set and the image is 3-dimensional, legacy vs. prototype output on a PIL image.

    Args:
        prototype_transform: The v2 transform under test.
        legacy_transform: The legacy reference transform.
        images: Iterable of images; defaults to ``make_images(**DEFAULT_MAKE_IMAGES_KWARGS)``.
        supports_pil: Whether to run the PIL comparison for 3-dimensional images.
        closeness_kwargs: Extra keyword arguments forwarded to ``assert_close`` (e.g. ``rtol``/``atol``).

    Raises:
        pytest.UsageError: If the legacy transform fails, which points to a wrong test setup.
        AssertionError: If the prototype transform fails. Output mismatches are reported by ``assert_close``.
    """
    if images is None:
        images = make_images(**DEFAULT_MAKE_IMAGES_KWARGS)

    closeness_kwargs = closeness_kwargs or dict()

    for image in images:
        image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.', 1)[-1]}]"

        # 1. Plain tensor: legacy vs. prototype.
        image_tensor = torch.Tensor(image)
        output_legacy_tensor = _call_seeded(
            legacy_transform,
            image_tensor,
            error_cls=pytest.UsageError,
            error_msg=(
                f"Transforming a tensor image with shape {image_repr} failed in the legacy transform with the "
                f"error above. This means that you need to specify the parameters passed to `make_images` through "
                "the `make_images_kwargs` of the `ConsistencyConfig`."
            ),
        )
        output_prototype_tensor = _call_seeded(
            prototype_transform,
            image_tensor,
            error_cls=AssertionError,
            error_msg=(
                f"Transforming a tensor image with shape {image_repr} failed in the prototype transform with "
                f"the error above. This means there is a consistency bug either in `_get_params` or in the "
                f"`is_pure_tensor` path in `_transform`."
            ),
        )
        assert_close(
            output_prototype_tensor,
            output_legacy_tensor,
            msg=lambda msg: f"Tensor image consistency check failed with: \n\n{msg}",
            **closeness_kwargs,
        )

        # 2. Original image object: the prototype must give the same result as on the plain tensor.
        output_prototype_image = _call_seeded(
            prototype_transform,
            image,
            error_cls=AssertionError,
            error_msg=(
                f"Transforming an image tv_tensor with shape {image_repr} failed in the prototype transform with "
                f"the error above. This means there is a consistency bug either in `_get_params` or in the "
                f"`tv_tensors.Image` path in `_transform`."
            ),
        )
        assert_close(
            output_prototype_image,
            output_prototype_tensor,
            msg=lambda msg: f"Output for tv_tensor and tensor images is not equal: \n\n{msg}",
            **closeness_kwargs,
        )

        # 3. PIL image: only 3-dimensional images are converted.
        if image.ndim != 3 or not supports_pil:
            continue

        image_pil = to_pil_image(image)
        output_legacy_pil = _call_seeded(
            legacy_transform,
            image_pil,
            error_cls=pytest.UsageError,
            error_msg=(
                f"Transforming a PIL image with shape {image_repr} failed in the legacy transform with the "
                f"error above. If this transform does not support PIL images, set `supports_pil=False` on the "
                "`ConsistencyConfig`."
            ),
        )
        output_prototype_pil = _call_seeded(
            prototype_transform,
            image_pil,
            error_cls=AssertionError,
            error_msg=(
                f"Transforming a PIL image with shape {image_repr} failed in the prototype transform with "
                f"the error above. This means there is a consistency bug either in `_get_params` or in the "
                f"`PIL.Image.Image` path in `_transform`."
            ),
        )
        assert_close(
            output_prototype_pil,
            output_legacy_pil,
            msg=lambda msg: f"PIL image consistency check failed with: \n\n{msg}",
            **closeness_kwargs,
        )
218
219


220
@pytest.mark.parametrize(
    ("config", "args_kwargs"),
    [
        pytest.param(cfg, ak, id=f"{cfg.legacy_cls.__name__}-{str(i).zfill(len(str(len(cfg.args_kwargs))))}")
        for cfg in CONSISTENCY_CONFIGS
        for i, ak in enumerate(cfg.args_kwargs)
    ],
)
@pytest.mark.filterwarnings("ignore")
def test_call_consistency(config, args_kwargs):
    """Build both transforms from the same ``args_kwargs`` and compare their outputs with ``check_call_consistency``.

    A failing legacy constructor means the test parameters are wrong (``pytest.UsageError``). A failing prototype
    constructor is reported as a consistency bug (``AssertionError``).
    """
    args, kwargs = args_kwargs

    try:
        legacy_transform = config.legacy_cls(*args, **kwargs)
    except Exception as exc:
        raise pytest.UsageError(
            f"Initializing the legacy transform failed with the error above. "
            f"Please correct the `ArgsKwargs({args_kwargs})` in the `ConsistencyConfig`."
        ) from exc

    try:
        prototype_transform = config.prototype_cls(*args, **kwargs)
    except Exception as exc:
        raise AssertionError(
            "Initializing the prototype transform failed with the error above. "
            "This means there is a consistency bug in the constructor."
        ) from exc

    images = make_images(**config.make_images_kwargs)
    check_call_consistency(
        prototype_transform,
        legacy_transform,
        images=images,
        supports_pil=config.supports_pil,
        closeness_kwargs=config.closeness_kwargs,
    )


259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
@pytest.mark.parametrize(
    ("config", "args_kwargs"),
    [
        pytest.param(cfg, ak, id=f"{cfg.legacy_cls.__name__}-{str(i).zfill(len(str(len(cfg.args_kwargs))))}")
        for cfg in CONSISTENCY_CONFIGS
        for i, ak in enumerate(cfg.args_kwargs)
        if not isinstance(ak, NotScriptableArgsKwargs)
    ],
)
def test_jit_consistency(config, args_kwargs):
    """Script both transforms with TorchScript and compare their outputs on plain tensors from the same seed.

    Parameter sets wrapped in ``NotScriptableArgsKwargs`` are excluded by the parametrization.
    """
    args, kwargs = args_kwargs

    eager_prototype = config.prototype_cls(*args, **kwargs)
    eager_legacy = config.legacy_cls(*args, **kwargs)

    scripted_legacy = torch.jit.script(eager_legacy)
    scripted_prototype = torch.jit.script(eager_prototype)

    for tv_image in make_images(**config.make_images_kwargs):
        # Scripted transforms are called with a plain tensor, not the tv_tensors subclass.
        plain_image = tv_image.as_subclass(torch.Tensor)

        torch.manual_seed(0)
        expected = scripted_legacy(plain_image)

        torch.manual_seed(0)
        actual = scripted_prototype(plain_image)

        assert_close(actual, expected, **config.closeness_kwargs)


291
292
293
294
295
296
297
298
299
300
class TestContainerTransforms:
    """
    Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for
    consistency automatically tests the wrapped transforms' consistency.

    Instead of complicated mocking or creating custom transforms just for these tests, the containers wrap the
    deterministic ``Resize`` and ``CenterCrop`` transforms. The prototype side is built purely from v2 transforms and
    the legacy side purely from legacy transforms.
    """

    def test_compose(self):
        prototype_transform = v2_transforms.Compose(
            [
                v2_transforms.Resize(256),
                v2_transforms.CenterCrop(224),
            ]
        )
        legacy_transform = legacy_transforms.Compose(
            [
                legacy_transforms.Resize(256),
                legacy_transforms.CenterCrop(224),
            ]
        )

        # atol=1 because v2 Resize uses the native uint8 interpolate path for bilinear and nearest modes
        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))

    @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
    @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
    def test_random_apply(self, p, sequence_type):
        prototype_transform = v2_transforms.RandomApply(
            sequence_type(
                [
                    v2_transforms.Resize(256),
                    v2_transforms.CenterCrop(224),
                ]
            ),
            p=p,
        )
        legacy_transform = legacy_transforms.RandomApply(
            sequence_type(
                [
                    legacy_transforms.Resize(256),
                    legacy_transforms.CenterCrop(224),
                ]
            ),
            p=p,
        )

        # atol=1 because v2 Resize uses the native uint8 interpolate path for bilinear and nearest modes
        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))

        if sequence_type is nn.ModuleList:
            # quick and dirty test that it is jit-scriptable
            scripted = torch.jit.script(prototype_transform)
            scripted(torch.rand(1, 3, 300, 300))

    # We can't test other values for `p` since the random parameter generation is different
    @pytest.mark.parametrize("probabilities", [(0, 1), (1, 0)])
    def test_random_choice(self, probabilities):
        prototype_transform = v2_transforms.RandomChoice(
            [
                v2_transforms.Resize(256),
                # Must be the v2 transform; a legacy CenterCrop here would leave the v2 path untested.
                v2_transforms.CenterCrop(224),
            ],
            p=probabilities,
        )
        legacy_transform = legacy_transforms.RandomChoice(
            [
                legacy_transforms.Resize(256),
                legacy_transforms.CenterCrop(224),
            ],
            p=probabilities,
        )

        # atol=1 because v2 Resize uses the native uint8 interpolate path for bilinear and nearest modes
        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
367
368


369
370
class TestToTensorTransforms:
    """Compare v2 ``PILToTensor``/``ToTensor`` with their legacy counterparts on PIL (and, for ``ToTensor``, NumPy)
    inputs."""

    def test_pil_to_tensor(self):
        legacy_transform = legacy_transforms.PILToTensor()
        prototype_transform = v2_transforms.PILToTensor()

        for tensor_image in make_images(extra_dims=[()]):
            pil_image = to_pil_image(tensor_image)
            assert_equal(prototype_transform(pil_image), legacy_transform(pil_image))

    def test_to_tensor(self):
        legacy_transform = legacy_transforms.ToTensor()
        # Constructing the v2 ToTensor is expected to emit a deprecation warning.
        with pytest.warns(UserWarning, match=re.escape("The transform `ToTensor()` is deprecated")):
            prototype_transform = v2_transforms.ToTensor()

        for tensor_image in make_images(extra_dims=[()]):
            pil_image = to_pil_image(tensor_image)
            numpy_image = np.array(pil_image)

            for inpt in (pil_image, numpy_image):
                assert_equal(prototype_transform(inpt), legacy_transform(inpt))
390
391


392
def import_transforms_from_references(reference):
    """Load ``references/<reference>/transforms.py`` relative to the parent of this file's directory.

    The file is executed as a module named ``transforms``. It is not added to ``sys.modules``, so every call returns a
    fresh, independent module object.
    """
    transforms_path = Path(__file__).parent.parent / "references" / reference / "transforms.py"

    loader = importlib.machinery.SourceFileLoader("transforms", str(transforms_path))
    module = importlib.util.module_from_spec(importlib.util.spec_from_loader("transforms", loader))
    loader.exec_module(module)
    return module
403
404
405


# Transforms from `references/detection/transforms.py`, used as the ground truth in `TestRefDetTransforms`.
det_transforms = import_transforms_from_references(reference="detection")
406
407
408


class TestRefDetTransforms:
    """Compare v2 transforms with the ones in ``references/detection/transforms.py``."""

    def make_tv_tensors(self, with_mask=True):
        """Yield three ``(image, target)`` samples whose image is a PIL image, a plain tensor, and a tv_tensors image.

        Each target holds 22 ``XYXY`` float boxes on a (600, 800) canvas and 22 int64 labels in ``[0, 80)``. When
        ``with_mask`` is set, it also holds detection masks from ``make_detection_mask``.
        """
        size = (600, 800)
        num_objects = 22

        def make_target():
            # Build a fresh target; keys are created in the order boxes -> labels -> masks.
            target = {
                "boxes": make_bounding_boxes(
                    canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float
                ),
                "labels": torch.randint(80, (num_objects,), dtype=torch.int64),
            }
            if with_mask:
                target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)
            return target

        pil_image = to_pil_image(make_image(size=size, color_space="RGB"))
        yield pil_image, make_target()

        tensor_image = torch.Tensor(make_image(size=size, color_space="RGB", dtype=torch.float32))
        yield tensor_image, make_target()

        tv_tensor_image = make_image(size=size, color_space="RGB", dtype=torch.float32)
        yield tv_tensor_image, make_target()

    @pytest.mark.parametrize(
        "t_ref, t, data_kwargs",
        [
            (det_transforms.RandomHorizontalFlip(p=1.0), v2_transforms.RandomHorizontalFlip(p=1.0), {}),
            (
                det_transforms.RandomIoUCrop(),
                v2_transforms.Compose(
                    [
                        v2_transforms.RandomIoUCrop(),
                        v2_transforms.SanitizeBoundingBoxes(labels_getter=lambda sample: sample[1]["labels"]),
                    ]
                ),
                {"with_mask": False},
            ),
            (det_transforms.RandomZoomOut(), v2_transforms.RandomZoomOut(), {"with_mask": False}),
            (det_transforms.ScaleJitter((1024, 1024)), v2_transforms.ScaleJitter((1024, 1024), antialias=True), {}),
            (
                det_transforms.RandomShortestSize(
                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
                ),
                v2_transforms.RandomShortestSize(
                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
                ),
                {},
            ),
        ],
    )
    def test_transform(self, t_ref, t, data_kwargs):
        for sample in self.make_tv_tensors(**data_kwargs):
            # The v2 transform must run first, because the reference transform updates the target in place.
            torch.manual_seed(12)
            actual = t(sample)

            torch.manual_seed(12)
            expected = t_ref(*sample)

            assert_equal(expected, actual)
484
485
486
487
488
489
490
491
492


# Transforms from `references/segmentation/transforms.py`, used as the ground truth in `TestRefSegTransforms`.
seg_transforms = import_transforms_from_references(reference="segmentation")


# We need this transform for two reasons:
# 1. transforms.RandomCrop pads images and masks of insufficient size with a different scheme than its namesake in
#    the segmentation references. Thus, we cannot use it with `pad_if_needed=True`.
# 2. transforms.Pad only supports a fixed padding, but the segmentation datasets don't have a fixed image size.
493
class PadIfSmaller(v2_transforms.Transform):
    """Pad on the right and bottom so that both spatial sides of the sample are at least ``size``.

    ``fill`` may be a single value or a per-type mapping (e.g. ``{tv_tensors.Mask: 255, "others": 0}``). It is
    normalized once in ``__init__`` and resolved for each input type in ``_transform``.
    """

    def __init__(self, size, fill=0):
        super().__init__()
        self.size = size
        self.fill = v2_transforms._geometry._setup_fill_arg(fill)

    def _get_params(self, sample):
        height, width = query_size(sample)
        pad_right = max(self.size - width, 0)
        pad_bottom = max(self.size - height, 0)
        # Padding order is [left, top, right, bottom].
        return dict(padding=[0, 0, pad_right, pad_bottom], needs_padding=bool(pad_right or pad_bottom))

    def _transform(self, inpt, params):
        if params["needs_padding"]:
            fill = _get_fill(self.fill, type(inpt))
            return prototype_F.pad(inpt, padding=params["padding"], fill=fill)
        return inpt
511
512
513


class TestRefSegTransforms:
    """Compare v2 transforms with the ones in ``references/segmentation/transforms.py``."""

    def make_tv_tensors(self, supports_pil=True, image_dtype=torch.uint8):
        """Yield ``(dp, dp_ref)`` pairs holding the same random image and segmentation mask.

        ``dp`` is fed to the v2 transform. Its image is converted to PIL (only if ``supports_pil``), to a plain tensor,
        or left unchanged, and it carries the mask from ``make_segmentation_mask``. ``dp_ref`` is fed to the reference
        transform. Its image is a PIL image (or a plain tensor when ``supports_pil`` is false) and its mask is a PIL
        image.
        """
        size = (256, 460)
        num_categories = 21

        converters = ([to_pil_image] if supports_pil else []) + [torch.Tensor, lambda x: x]

        for convert in converters:
            image = make_image(size=size, color_space="RGB", dtype=image_dtype)
            mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)

            dp = (convert(image), mask)
            ref_image = to_pil_image(image) if supports_pil else image.as_subclass(torch.Tensor)
            dp_ref = (ref_image, to_pil_image(mask))

            yield dp, dp_ref

    def set_seed(self, seed=12):
        """Seed both ``torch`` and Python's ``random`` module with ``seed``."""
        for seeder in (torch.manual_seed, random.seed):
            seeder(seed)

    def check(self, t, t_ref, data_kwargs=None):
        """Run ``t`` and ``t_ref`` from the same seed on every sample and assert equal outputs.

        A non-tensor reference image is converted with ``pil_to_tensor`` when the v2 image output is a tensor. The
        reference mask is always converted and squeezed along dim 0.
        """
        for dp, dp_ref in self.make_tv_tensors(**(data_kwargs or dict())):
            self.set_seed()
            actual = t(dp)
            actual_image, _ = actual

            self.set_seed()
            expected_image, expected_mask = t_ref(*dp_ref)
            if isinstance(actual_image, torch.Tensor) and not isinstance(expected_image, torch.Tensor):
                expected_image = legacy_F.pil_to_tensor(expected_image)
            expected_mask = legacy_F.pil_to_tensor(expected_mask).squeeze(0)

            assert_equal(actual, (expected_image, expected_mask))

    @pytest.mark.parametrize(
        ("t_ref", "t", "data_kwargs"),
        [
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=1.0),
                v2_transforms.RandomHorizontalFlip(p=1.0),
                dict(),
            ),
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=0.0),
                v2_transforms.RandomHorizontalFlip(p=0.0),
                dict(),
            ),
            (
                seg_transforms.RandomCrop(size=480),
                v2_transforms.Compose(
                    [
                        PadIfSmaller(size=480, fill={tv_tensors.Mask: 255, "others": 0}),
                        v2_transforms.RandomCrop(size=480),
                    ]
                ),
                dict(),
            ),
            (
                seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                v2_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                dict(supports_pil=False, image_dtype=torch.float),
            ),
        ],
    )
    def test_common(self, t_ref, t, data_kwargs):
        self.check(t, t_ref, data_kwargs)

587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636

@pytest.mark.parametrize(
    ("legacy_dispatcher", "name_only_params"),
    [
        (legacy_F.get_dimensions, {}),
        (legacy_F.get_image_size, {}),
        (legacy_F.get_image_num_channels, {}),
        (legacy_F.to_tensor, {}),
        (legacy_F.pil_to_tensor, {}),
        (legacy_F.convert_image_dtype, {}),
        (legacy_F.to_pil_image, {}),
        (legacy_F.to_grayscale, {}),
        (legacy_F.rgb_to_grayscale, {}),
    ],
)
def test_dispatcher_signature_consistency(legacy_dispatcher, name_only_params):
    """Check that each v2 functional dispatcher is backward compatible with its legacy counterpart.

    The first parameter (the input) is skipped on both sides. Prototype parameters beyond the legacy ones must have a
    default. The remaining parameters must compare equal as ``inspect.Parameter`` objects, with two exceptions:
    annotations are ignored where the legacy parameter has none, and both annotation and default are ignored for
    parameter names listed in ``name_only_params``.
    """
    legacy_params = list(inspect.signature(legacy_dispatcher).parameters.values())[1:]

    try:
        prototype_dispatcher = getattr(prototype_F, legacy_dispatcher.__name__)
    except AttributeError:
        raise AssertionError(
            f"Legacy dispatcher `F.{legacy_dispatcher.__name__}` has no prototype equivalent"
        ) from None

    prototype_params = list(inspect.signature(prototype_dispatcher).parameters.values())[1:]

    # Some dispatchers got extra parameters. This makes sure they have a default argument and thus are BC. We don't
    # need to check if parameters were added in the middle rather than at the end, since that will be caught by the
    # regular check below.
    prototype_params, new_prototype_params = (
        prototype_params[: len(legacy_params)],
        prototype_params[len(legacy_params) :],
    )
    for param in new_prototype_params:
        assert param.default is not param.empty, (
            f"New parameter `{param.name}` of `F.{legacy_dispatcher.__name__}` has no default value"
        )

    # Some annotations were changed mostly to supersets of what was there before. Plus, some legacy dispatchers had no
    # annotations. In these cases we simply drop the annotation and default argument from the comparison.
    # `Parameter.replace` returns new objects instead of mutating private slots. The lists are updated in place by
    # index, so a length mismatch between them still fails the final comparison.
    empty = inspect.Parameter.empty
    for idx, (prototype_param, legacy_param) in enumerate(zip(prototype_params, legacy_params)):
        if legacy_param.name in name_only_params:
            prototype_params[idx] = prototype_param.replace(annotation=empty, default=empty)
            legacy_params[idx] = legacy_param.replace(annotation=empty, default=empty)
        elif legacy_param.annotation is empty:
            prototype_params[idx] = prototype_param.replace(annotation=empty)

    assert prototype_params == legacy_params