# test_transforms_v2_consistency.py — consistency tests between the legacy (v1) and v2 transforms APIs.
import importlib.machinery
import importlib.util
import inspect
import random
import re
from pathlib import Path

import numpy as np
import pytest

import torch
import torchvision.transforms.v2 as v2_transforms
from common_utils import assert_close, assert_equal, set_rng_seed
from torch import nn
from torchvision import transforms as legacy_transforms, tv_tensors
from torchvision._utils import sequence_to_str

from torchvision.transforms import functional as legacy_F
from torchvision.transforms.v2 import functional as prototype_F
from torchvision.transforms.v2._utils import _get_fill, query_size
from torchvision.transforms.v2.functional import to_pil_image

from transforms_v2_legacy_utils import (
    ArgsKwargs,
    make_bounding_boxes,
    make_detection_mask,
    make_image,
    make_images,
    make_segmentation_mask,
)
30

31
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)])
32
33


Nicolas Hug's avatar
Nicolas Hug committed
34
35
36
37
38
39
@pytest.fixture(autouse=True)
def fix_rng_seed():
    set_rng_seed(0)
    yield


40
41
42
43
44
45
46
47
48
class NotScriptableArgsKwargs(ArgsKwargs):
    """
    This class is used to mark parameters that render the transform non-scriptable. They still work in eager mode and
    thus will be tested there, but will be skipped by the JIT tests.
    """

    pass


49
50
class ConsistencyConfig:
    def __init__(
51
52
53
        self,
        prototype_cls,
        legacy_cls,
54
55
        # If no args_kwargs is passed, only the signature will be checked
        args_kwargs=(),
56
57
58
        make_images_kwargs=None,
        supports_pil=True,
        removed_params=(),
59
        closeness_kwargs=None,
60
61
62
    ):
        self.prototype_cls = prototype_cls
        self.legacy_cls = legacy_cls
63
        self.args_kwargs = args_kwargs
64
65
        self.make_images_kwargs = make_images_kwargs or DEFAULT_MAKE_IMAGES_KWARGS
        self.supports_pil = supports_pil
66
        self.removed_params = removed_params
67
        self.closeness_kwargs = closeness_kwargs or dict(rtol=0, atol=0)
68
69


70
71
72
73
# These are here since both the prototype and legacy transform need to be constructed with the same random parameters
LINEAR_TRANSFORMATION_MEAN = torch.rand(36)
LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2)

74
CONSISTENCY_CONFIGS = [
75
    ConsistencyConfig(
76
        v2_transforms.ToPILImage,
77
        legacy_transforms.ToPILImage,
78
        [NotScriptableArgsKwargs()],
79
80
        make_images_kwargs=dict(
            color_spaces=[
81
82
83
84
                "GRAY",
                "GRAY_ALPHA",
                "RGB",
                "RGBA",
85
86
87
88
89
90
            ],
            extra_dims=[()],
        ),
        supports_pil=False,
    ),
    ConsistencyConfig(
91
        v2_transforms.Lambda,
92
93
        legacy_transforms.Lambda,
        [
94
            NotScriptableArgsKwargs(lambda image: image / 2),
95
96
97
98
99
        ],
        # Technically, this also supports PIL, but it is overkill to write a function here that supports tensor and PIL
        # images given that the transform does nothing but call it anyway.
        supports_pil=False,
    ),
100
    ConsistencyConfig(
101
        v2_transforms.PILToTensor,
102
103
104
        legacy_transforms.PILToTensor,
    ),
    ConsistencyConfig(
105
        v2_transforms.ToTensor,
106
107
108
        legacy_transforms.ToTensor,
    ),
    ConsistencyConfig(
109
        v2_transforms.Compose,
110
111
112
        legacy_transforms.Compose,
    ),
    ConsistencyConfig(
113
        v2_transforms.RandomApply,
114
115
116
        legacy_transforms.RandomApply,
    ),
    ConsistencyConfig(
117
        v2_transforms.RandomChoice,
118
119
120
        legacy_transforms.RandomChoice,
    ),
    ConsistencyConfig(
121
        v2_transforms.RandomOrder,
122
123
        legacy_transforms.RandomOrder,
    ),
124
125
126
]


127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
@pytest.mark.parametrize("config", CONSISTENCY_CONFIGS, ids=lambda config: config.legacy_cls.__name__)
def test_signature_consistency(config):
    legacy_params = dict(inspect.signature(config.legacy_cls).parameters)
    prototype_params = dict(inspect.signature(config.prototype_cls).parameters)

    for param in config.removed_params:
        legacy_params.pop(param, None)

    missing = legacy_params.keys() - prototype_params.keys()
    if missing:
        raise AssertionError(
            f"The prototype transform does not support the parameters "
            f"{sequence_to_str(sorted(missing), separate_last='and ')}, but the legacy transform does. "
            f"If that is intentional, e.g. pending deprecation, please add the parameters to the `removed_params` on "
            f"the `ConsistencyConfig`."
        )

    extra = prototype_params.keys() - legacy_params.keys()
145
146
147
148
149
150
    extra_without_default = {
        param
        for param in extra
        if prototype_params[param].default is inspect.Parameter.empty
        and prototype_params[param].kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
    }
151
152
    if extra_without_default:
        raise AssertionError(
153
154
155
            f"The prototype transform requires the parameters "
            f"{sequence_to_str(sorted(extra_without_default), separate_last='and ')}, but the legacy transform does "
            f"not. Please add a default value."
156
157
        )

158
159
160
161
162
163
    legacy_signature = list(legacy_params.keys())
    # Since we made sure that we don't have any extra parameters without default above, we clamp the prototype signature
    # to the same number of parameters as the legacy one
    prototype_signature = list(prototype_params.keys())[: len(legacy_signature)]

    assert prototype_signature == legacy_signature
164
165


166
167
168
def check_call_consistency(
    prototype_transform, legacy_transform, images=None, supports_pil=True, closeness_kwargs=None
):
169
170
    if images is None:
        images = make_images(**DEFAULT_MAKE_IMAGES_KWARGS)
171

172
173
    closeness_kwargs = closeness_kwargs or dict()

174
175
    for image in images:
        image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]"
176
177
178

        image_tensor = torch.Tensor(image)
        try:
179
            torch.manual_seed(0)
180
            output_legacy_tensor = legacy_transform(image_tensor)
181
182
        except Exception as exc:
            raise pytest.UsageError(
183
                f"Transforming a tensor image {image_repr} failed in the legacy transform with the "
184
                f"error above. This means that you need to specify the parameters passed to `make_images` through the "
185
186
187
188
                "`make_images_kwargs` of the `ConsistencyConfig`."
            ) from exc

        try:
189
            torch.manual_seed(0)
190
            output_prototype_tensor = prototype_transform(image_tensor)
191
192
        except Exception as exc:
            raise AssertionError(
193
                f"Transforming a tensor image with shape {image_repr} failed in the prototype transform with "
194
                f"the error above. This means there is a consistency bug either in `_get_params` or in the "
195
                f"`is_pure_tensor` path in `_transform`."
196
197
            ) from exc

198
        assert_close(
199
200
201
            output_prototype_tensor,
            output_legacy_tensor,
            msg=lambda msg: f"Tensor image consistency check failed with: \n\n{msg}",
202
            **closeness_kwargs,
203
204
205
        )

        try:
206
            torch.manual_seed(0)
207
            output_prototype_image = prototype_transform(image)
208
209
        except Exception as exc:
            raise AssertionError(
210
                f"Transforming a image tv_tensor with shape {image_repr} failed in the prototype transform with "
211
                f"the error above. This means there is a consistency bug either in `_get_params` or in the "
212
                f"`tv_tensors.Image` path in `_transform`."
213
214
            ) from exc

215
        assert_close(
216
            output_prototype_image,
217
            output_prototype_tensor,
218
            msg=lambda msg: f"Output for tv_tensor and tensor images is not equal: \n\n{msg}",
219
            **closeness_kwargs,
220
221
        )

222
        if image.ndim == 3 and supports_pil:
223
            image_pil = to_pil_image(image)
224

225
            try:
226
                torch.manual_seed(0)
227
                output_legacy_pil = legacy_transform(image_pil)
228
229
            except Exception as exc:
                raise pytest.UsageError(
230
                    f"Transforming a PIL image with shape {image_repr} failed in the legacy transform with the "
231
232
233
234
235
                    f"error above. If this transform does not support PIL images, set `supports_pil=False` on the "
                    "`ConsistencyConfig`. "
                ) from exc

            try:
236
                torch.manual_seed(0)
237
                output_prototype_pil = prototype_transform(image_pil)
238
239
            except Exception as exc:
                raise AssertionError(
240
                    f"Transforming a PIL image with shape {image_repr} failed in the prototype transform with "
241
242
243
244
                    f"the error above. This means there is a consistency bug either in `_get_params` or in the "
                    f"`PIL.Image.Image` path in `_transform`."
                ) from exc

245
            assert_close(
246
247
                output_prototype_pil,
                output_legacy_pil,
248
                msg=lambda msg: f"PIL image consistency check failed with: \n\n{msg}",
249
                **closeness_kwargs,
250
            )
251
252


253
@pytest.mark.parametrize(
254
255
    ("config", "args_kwargs"),
    [
256
257
258
        pytest.param(
            config, args_kwargs, id=f"{config.legacy_cls.__name__}-{idx:0{len(str(len(config.args_kwargs)))}d}"
        )
259
        for config in CONSISTENCY_CONFIGS
260
        for idx, args_kwargs in enumerate(config.args_kwargs)
261
    ],
262
)
263
@pytest.mark.filterwarnings("ignore")
264
def test_call_consistency(config, args_kwargs):
265
266
267
    args, kwargs = args_kwargs

    try:
268
        legacy_transform = config.legacy_cls(*args, **kwargs)
269
270
271
272
273
274
275
    except Exception as exc:
        raise pytest.UsageError(
            f"Initializing the legacy transform failed with the error above. "
            f"Please correct the `ArgsKwargs({args_kwargs})` in the `ConsistencyConfig`."
        ) from exc

    try:
276
        prototype_transform = config.prototype_cls(*args, **kwargs)
277
278
279
280
281
282
    except Exception as exc:
        raise AssertionError(
            "Initializing the prototype transform failed with the error above. "
            "This means there is a consistency bug in the constructor."
        ) from exc

283
284
285
286
287
    check_call_consistency(
        prototype_transform,
        legacy_transform,
        images=make_images(**config.make_images_kwargs),
        supports_pil=config.supports_pil,
288
        closeness_kwargs=config.closeness_kwargs,
289
290
291
    )


292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
@pytest.mark.parametrize(
    ("config", "args_kwargs"),
    [
        pytest.param(
            config, args_kwargs, id=f"{config.legacy_cls.__name__}-{idx:0{len(str(len(config.args_kwargs)))}d}"
        )
        for config in CONSISTENCY_CONFIGS
        for idx, args_kwargs in enumerate(config.args_kwargs)
        if not isinstance(args_kwargs, NotScriptableArgsKwargs)
    ],
)
def test_jit_consistency(config, args_kwargs):
    args, kwargs = args_kwargs

    prototype_transform_eager = config.prototype_cls(*args, **kwargs)
    legacy_transform_eager = config.legacy_cls(*args, **kwargs)

    legacy_transform_scripted = torch.jit.script(legacy_transform_eager)
    prototype_transform_scripted = torch.jit.script(prototype_transform_eager)

    for image in make_images(**config.make_images_kwargs):
        image = image.as_subclass(torch.Tensor)

        torch.manual_seed(0)
        output_legacy_scripted = legacy_transform_scripted(image)

        torch.manual_seed(0)
        output_prototype_scripted = prototype_transform_scripted(image)

        assert_close(output_prototype_scripted, output_legacy_scripted, **config.closeness_kwargs)


324
325
326
327
328
329
330
331
332
333
class TestContainerTransforms:
    """
    Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for
    consistency automatically tests the wrapped transforms consistency.

    Instead of complicated mocking or creating custom transforms just for these tests, here we use deterministic ones
    that were already tested for consistency above.
    """

    def test_compose(self):
334
        prototype_transform = v2_transforms.Compose(
335
            [
336
337
                v2_transforms.Resize(256),
                v2_transforms.CenterCrop(224),
338
339
340
341
342
343
344
345
346
            ]
        )
        legacy_transform = legacy_transforms.Compose(
            [
                legacy_transforms.Resize(256),
                legacy_transforms.CenterCrop(224),
            ]
        )

347
348
        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
349
350

    @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
351
352
    @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
    def test_random_apply(self, p, sequence_type):
353
        prototype_transform = v2_transforms.RandomApply(
354
355
            sequence_type(
                [
356
357
                    v2_transforms.Resize(256),
                    v2_transforms.CenterCrop(224),
358
359
                ]
            ),
360
361
362
            p=p,
        )
        legacy_transform = legacy_transforms.RandomApply(
363
364
365
366
367
368
            sequence_type(
                [
                    legacy_transforms.Resize(256),
                    legacy_transforms.CenterCrop(224),
                ]
            ),
369
370
371
            p=p,
        )

372
373
        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
374

375
376
377
378
379
        if sequence_type is nn.ModuleList:
            # quick and dirty test that it is jit-scriptable
            scripted = torch.jit.script(prototype_transform)
            scripted(torch.rand(1, 3, 300, 300))

380
    # We can't test other values for `p` since the random parameter generation is different
381
382
    @pytest.mark.parametrize("probabilities", [(0, 1), (1, 0)])
    def test_random_choice(self, probabilities):
383
        prototype_transform = v2_transforms.RandomChoice(
384
            [
385
                v2_transforms.Resize(256),
386
387
                legacy_transforms.CenterCrop(224),
            ],
388
            p=probabilities,
389
390
391
392
393
394
        )
        legacy_transform = legacy_transforms.RandomChoice(
            [
                legacy_transforms.Resize(256),
                legacy_transforms.CenterCrop(224),
            ],
395
            p=probabilities,
396
397
        )

398
399
        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
400
401


402
403
class TestToTensorTransforms:
    def test_pil_to_tensor(self):
404
        prototype_transform = v2_transforms.PILToTensor()
405
406
        legacy_transform = legacy_transforms.PILToTensor()

407
        for image in make_images(extra_dims=[()]):
408
            image_pil = to_pil_image(image)
409
410
411
412

            assert_equal(prototype_transform(image_pil), legacy_transform(image_pil))

    def test_to_tensor(self):
413
        with pytest.warns(UserWarning, match=re.escape("The transform `ToTensor()` is deprecated")):
414
            prototype_transform = v2_transforms.ToTensor()
415
416
        legacy_transform = legacy_transforms.ToTensor()

417
        for image in make_images(extra_dims=[()]):
418
            image_pil = to_pil_image(image)
419
420
421
422
            image_numpy = np.array(image_pil)

            assert_equal(prototype_transform(image_pil), legacy_transform(image_pil))
            assert_equal(prototype_transform(image_numpy), legacy_transform(image_numpy))
423
424


425
def import_transforms_from_references(reference):
426
427
428
429
430
431
432
433
434
435
    HERE = Path(__file__).parent
    PROJECT_ROOT = HERE.parent

    loader = importlib.machinery.SourceFileLoader(
        "transforms", str(PROJECT_ROOT / "references" / reference / "transforms.py")
    )
    spec = importlib.util.spec_from_loader("transforms", loader)
    module = importlib.util.module_from_spec(spec)
    loader.exec_module(module)
    return module
436
437
438


det_transforms = import_transforms_from_references("detection")
439
440
441


class TestRefDetTransforms:
442
    def make_tv_tensors(self, with_mask=True):
443
444
445
        size = (600, 800)
        num_objects = 22

446
447
448
        def make_label(extra_dims, categories):
            return torch.randint(categories, extra_dims, dtype=torch.int64)

449
        pil_image = to_pil_image(make_image(size=size, color_space="RGB"))
450
        target = {
451
            "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
452
453
454
455
456
457
458
            "labels": make_label(extra_dims=(num_objects,), categories=80),
        }
        if with_mask:
            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

        yield (pil_image, target)

459
        tensor_image = torch.Tensor(make_image(size=size, color_space="RGB", dtype=torch.float32))
460
        target = {
461
            "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
462
463
464
465
466
467
468
            "labels": make_label(extra_dims=(num_objects,), categories=80),
        }
        if with_mask:
            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

        yield (tensor_image, target)

469
        tv_tensor_image = make_image(size=size, color_space="RGB", dtype=torch.float32)
470
        target = {
471
            "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
472
473
474
475
476
            "labels": make_label(extra_dims=(num_objects,), categories=80),
        }
        if with_mask:
            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

477
        yield (tv_tensor_image, target)
478
479
480
481

    @pytest.mark.parametrize(
        "t_ref, t, data_kwargs",
        [
482
            (det_transforms.RandomHorizontalFlip(p=1.0), v2_transforms.RandomHorizontalFlip(p=1.0), {}),
483
484
485
486
487
            (
                det_transforms.RandomIoUCrop(),
                v2_transforms.Compose(
                    [
                        v2_transforms.RandomIoUCrop(),
488
                        v2_transforms.SanitizeBoundingBoxes(labels_getter=lambda sample: sample[1]["labels"]),
489
490
491
492
                    ]
                ),
                {"with_mask": False},
            ),
493
            (det_transforms.RandomZoomOut(), v2_transforms.RandomZoomOut(), {"with_mask": False}),
494
            (det_transforms.ScaleJitter((1024, 1024)), v2_transforms.ScaleJitter((1024, 1024), antialias=True), {}),
495
496
497
498
            (
                det_transforms.RandomShortestSize(
                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
                ),
499
                v2_transforms.RandomShortestSize(
500
501
502
503
504
505
506
                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
                ),
                {},
            ),
        ],
    )
    def test_transform(self, t_ref, t, data_kwargs):
507
        for dp in self.make_tv_tensors(**data_kwargs):
508
509
510
511
512
513
514
515
516

            # We should use prototype transform first as reference transform performs inplace target update
            torch.manual_seed(12)
            output = t(dp)

            torch.manual_seed(12)
            expected_output = t_ref(*dp)

            assert_equal(expected_output, output)
517
518
519
520
521
522
523
524
525


seg_transforms = import_transforms_from_references("segmentation")


# We need this transform for two reasons:
# 1. transforms.RandomCrop uses a different scheme to pad images and masks of insufficient size than its name
#    counterpart in the detection references. Thus, we cannot use it with `pad_if_needed=True`
# 2. transforms.Pad only supports a fixed padding, but the segmentation datasets don't have a fixed image size.
526
class PadIfSmaller(v2_transforms.Transform):
527
528
529
    def __init__(self, size, fill=0):
        super().__init__()
        self.size = size
530
        self.fill = v2_transforms._geometry._setup_fill_arg(fill)
531
532

    def _get_params(self, sample):
Philip Meier's avatar
Philip Meier committed
533
        height, width = query_size(sample)
534
535
536
537
538
539
540
541
        padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
        needs_padding = any(padding)
        return dict(padding=padding, needs_padding=needs_padding)

    def _transform(self, inpt, params):
        if not params["needs_padding"]:
            return inpt

542
        fill = _get_fill(self.fill, type(inpt))
543
        return prototype_F.pad(inpt, padding=params["padding"], fill=fill)
544
545
546


class TestRefSegTransforms:
547
    def make_tv_tensors(self, supports_pil=True, image_dtype=torch.uint8):
548
        size = (256, 460)
549
550
551
552
        num_categories = 21

        conv_fns = []
        if supports_pil:
553
            conv_fns.append(to_pil_image)
554
555
556
        conv_fns.extend([torch.Tensor, lambda x: x])

        for conv_fn in conv_fns:
557
558
            tv_tensor_image = make_image(size=size, color_space="RGB", dtype=image_dtype)
            tv_tensor_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)
559

560
            dp = (conv_fn(tv_tensor_image), tv_tensor_mask)
561
            dp_ref = (
562
563
                to_pil_image(tv_tensor_image) if supports_pil else tv_tensor_image.as_subclass(torch.Tensor),
                to_pil_image(tv_tensor_mask),
564
565
566
567
568
569
570
571
572
            )

            yield dp, dp_ref

    def set_seed(self, seed=12):
        torch.manual_seed(seed)
        random.seed(seed)

    def check(self, t, t_ref, data_kwargs=None):
573
        for dp, dp_ref in self.make_tv_tensors(**data_kwargs or dict()):
574
575

            self.set_seed()
576
            actual = actual_image, actual_mask = t(dp)
577
578

            self.set_seed()
579
580
581
582
583
            expected_image, expected_mask = t_ref(*dp_ref)
            if isinstance(actual_image, torch.Tensor) and not isinstance(expected_image, torch.Tensor):
                expected_image = legacy_F.pil_to_tensor(expected_image)
            expected_mask = legacy_F.pil_to_tensor(expected_mask).squeeze(0)
            expected = (expected_image, expected_mask)
584

585
            assert_equal(actual, expected)
586
587
588
589
590
591

    @pytest.mark.parametrize(
        ("t_ref", "t", "data_kwargs"),
        [
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=1.0),
592
                v2_transforms.RandomHorizontalFlip(p=1.0),
593
594
595
596
                dict(),
            ),
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=0.0),
597
                v2_transforms.RandomHorizontalFlip(p=0.0),
598
599
600
601
                dict(),
            ),
            (
                seg_transforms.RandomCrop(size=480),
602
                v2_transforms.Compose(
603
                    [
604
                        PadIfSmaller(size=480, fill={tv_tensors.Mask: 255, "others": 0}),
605
                        v2_transforms.RandomCrop(size=480),
606
607
608
609
610
611
                    ]
                ),
                dict(),
            ),
            (
                seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
612
                v2_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
613
614
615
616
617
618
619
                dict(supports_pil=False, image_dtype=torch.float),
            ),
        ],
    )
    def test_common(self, t_ref, t, data_kwargs):
        self.check(t, t_ref, data_kwargs)

620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669

@pytest.mark.parametrize(
    ("legacy_dispatcher", "name_only_params"),
    [
        (legacy_F.get_dimensions, {}),
        (legacy_F.get_image_size, {}),
        (legacy_F.get_image_num_channels, {}),
        (legacy_F.to_tensor, {}),
        (legacy_F.pil_to_tensor, {}),
        (legacy_F.convert_image_dtype, {}),
        (legacy_F.to_pil_image, {}),
        (legacy_F.to_grayscale, {}),
        (legacy_F.rgb_to_grayscale, {}),
        (legacy_F.to_tensor, {}),
    ],
)
def test_dispatcher_signature_consistency(legacy_dispatcher, name_only_params):
    legacy_signature = inspect.signature(legacy_dispatcher)
    legacy_params = list(legacy_signature.parameters.values())[1:]

    try:
        prototype_dispatcher = getattr(prototype_F, legacy_dispatcher.__name__)
    except AttributeError:
        raise AssertionError(
            f"Legacy dispatcher `F.{legacy_dispatcher.__name__}` has no prototype equivalent"
        ) from None

    prototype_signature = inspect.signature(prototype_dispatcher)
    prototype_params = list(prototype_signature.parameters.values())[1:]

    # Some dispatchers got extra parameters. This makes sure they have a default argument and thus are BC. We don't
    # need to check if parameters were added in the middle rather than at the end, since that will be caught by the
    # regular check below.
    prototype_params, new_prototype_params = (
        prototype_params[: len(legacy_params)],
        prototype_params[len(legacy_params) :],
    )
    for param in new_prototype_params:
        assert param.default is not param.empty

    # Some annotations were changed mostly to supersets of what was there before. Plus, some legacy dispatchers had no
    # annotations. In these cases we simply drop the annotation and default argument from the comparison
    for prototype_param, legacy_param in zip(prototype_params, legacy_params):
        if legacy_param.name in name_only_params:
            prototype_param._annotation = prototype_param._default = inspect.Parameter.empty
            legacy_param._annotation = legacy_param._default = inspect.Parameter.empty
        elif legacy_param.annotation is inspect.Parameter.empty:
            prototype_param._annotation = inspect.Parameter.empty

    assert prototype_params == legacy_params