test_transforms_v2_consistency.py 30.8 KB
Newer Older
1
2
import importlib.machinery
import importlib.util
3
import inspect
4
import random
5
import re
6
from pathlib import Path
7

8
import numpy as np
9
import pytest
10
11

import torch
12
import torchvision.transforms.v2 as v2_transforms
13
from common_utils import assert_close, assert_equal, set_rng_seed
14
from torch import nn
15
from torchvision import transforms as legacy_transforms, tv_tensors
16
from torchvision._utils import sequence_to_str
17

18
from torchvision.transforms import functional as legacy_F
19
from torchvision.transforms.v2 import functional as prototype_F
Nicolas Hug's avatar
Nicolas Hug committed
20
from torchvision.transforms.v2._utils import _get_fill, query_size
21
from torchvision.transforms.v2.functional import to_pil_image
22
23
24
25
26
27
28
29
from transforms_v2_legacy_utils import (
    ArgsKwargs,
    make_bounding_boxes,
    make_detection_mask,
    make_image,
    make_images,
    make_segmentation_mask,
)
30

31
# Default arguments for `make_images` used by most configs below: RGB images with a leading batch dim of 4.
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)])
32
33


Nicolas Hug's avatar
Nicolas Hug committed
34
35
36
37
38
39
@pytest.fixture(autouse=True)
def fix_rng_seed():
    """Reseed all RNGs before every test so randomly sampled transform parameters are reproducible."""
    set_rng_seed(0)
    yield


40
41
42
43
44
45
46
47
48
class NotScriptableArgsKwargs(ArgsKwargs):
    """Marker subclass for parameter sets that make a transform non-scriptable.

    Configs wrapped in this class are still exercised in eager mode, but the JIT consistency
    tests filter them out via an ``isinstance`` check.
    """


49
50
class ConsistencyConfig:
    """Pairs a v2 (prototype) transform class with its v1 (legacy) counterpart.

    Holds everything the consistency tests need: the constructor argument sets to try,
    how to build the input images, whether PIL inputs are supported, legacy-only
    parameters that were intentionally removed, and the comparison tolerances.
    """

    def __init__(
        self,
        prototype_cls,
        legacy_cls,
        # With an empty `args_kwargs`, only the signatures are compared, not the calls.
        args_kwargs=(),
        make_images_kwargs=None,
        supports_pil=True,
        removed_params=(),
        closeness_kwargs=None,
    ):
        self.prototype_cls = prototype_cls
        self.legacy_cls = legacy_cls
        self.args_kwargs = args_kwargs
        self.supports_pil = supports_pil
        self.removed_params = removed_params
        # Fall back to the module-level defaults when nothing (or something falsy) was passed.
        self.make_images_kwargs = make_images_kwargs or DEFAULT_MAKE_IMAGES_KWARGS
        # Exact equality by default; individual configs may loosen this.
        self.closeness_kwargs = closeness_kwargs or {"rtol": 0, "atol": 0}
68
69


70
71
72
73
# These are here since both the prototype and legacy transform need to be constructed with the same random parameters
LINEAR_TRANSFORMATION_MEAN = torch.rand(36)
LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2)

# Each entry pairs a v2 transform with its v1 counterpart plus the argument sets and image
# parameters used to compare them. Entries without `args_kwargs` only get a signature check.
CONSISTENCY_CONFIGS = [
    ConsistencyConfig(
        v2_transforms.Normalize,
        legacy_transforms.Normalize,
        [
            ArgsKwargs(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ],
        supports_pil=False,
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.float]),
    ),
    ConsistencyConfig(
        v2_transforms.FiveCrop,
        legacy_transforms.FiveCrop,
        [
            ArgsKwargs(18),
            ArgsKwargs((18, 13)),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
    ),
    ConsistencyConfig(
        v2_transforms.TenCrop,
        legacy_transforms.TenCrop,
        [
            ArgsKwargs(18),
            ArgsKwargs((18, 13)),
            ArgsKwargs(18, vertical_flip=True),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
    ),
    # One config per (matrix dtype, image dtype) combination to cover mixed-dtype inputs.
    *[
        ConsistencyConfig(
            v2_transforms.LinearTransformation,
            legacy_transforms.LinearTransformation,
            [
                ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX.to(matrix_dtype), LINEAR_TRANSFORMATION_MEAN.to(matrix_dtype)),
            ],
            # Make sure that the product of the height, width and number of channels matches the number of elements in
            # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
            make_images_kwargs=dict(
                DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"], dtypes=[image_dtype]
            ),
            supports_pil=False,
        )
        for matrix_dtype, image_dtype in [
            (torch.float32, torch.float32),
            (torch.float64, torch.float64),
            (torch.float32, torch.uint8),
            (torch.float64, torch.float32),
            (torch.float32, torch.float64),
        ]
    ],
    ConsistencyConfig(
        v2_transforms.Grayscale,
        legacy_transforms.Grayscale,
        [
            ArgsKwargs(num_output_channels=1),
            ArgsKwargs(num_output_channels=3),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
        # Use default tolerances of `torch.testing.assert_close`
        closeness_kwargs=dict(rtol=None, atol=None),
    ),
    ConsistencyConfig(
        v2_transforms.ToPILImage,
        legacy_transforms.ToPILImage,
        [NotScriptableArgsKwargs()],
        make_images_kwargs=dict(
            color_spaces=[
                "GRAY",
                "GRAY_ALPHA",
                "RGB",
                "RGBA",
            ],
            extra_dims=[()],
        ),
        supports_pil=False,
    ),
    ConsistencyConfig(
        v2_transforms.Lambda,
        legacy_transforms.Lambda,
        [
            NotScriptableArgsKwargs(lambda image: image / 2),
        ],
        # Technically, this also supports PIL, but it is overkill to write a function here that supports tensor and PIL
        # images given that the transform does nothing but call it anyway.
        supports_pil=False,
    ),
    ConsistencyConfig(
        v2_transforms.RandomEqualize,
        legacy_transforms.RandomEqualize,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]),
    ),
    ConsistencyConfig(
        v2_transforms.RandomInvert,
        legacy_transforms.RandomInvert,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
        ],
    ),
    ConsistencyConfig(
        v2_transforms.RandomPosterize,
        legacy_transforms.RandomPosterize,
        [
            ArgsKwargs(p=0, bits=5),
            ArgsKwargs(p=1, bits=1),
            ArgsKwargs(p=1, bits=3),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]),
    ),
    ConsistencyConfig(
        v2_transforms.RandomSolarize,
        legacy_transforms.RandomSolarize,
        [
            ArgsKwargs(p=0, threshold=0.5),
            ArgsKwargs(p=1, threshold=0.3),
            ArgsKwargs(p=1, threshold=0.99),
        ],
    ),
    # uint8 gets a 1-off tolerance; float uses `assert_close` defaults.
    *[
        ConsistencyConfig(
            v2_transforms.RandomAutocontrast,
            legacy_transforms.RandomAutocontrast,
            [
                ArgsKwargs(p=0),
                ArgsKwargs(p=1),
            ],
            make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[dt]),
            closeness_kwargs=ckw,
        )
        for dt, ckw in [(torch.uint8, dict(atol=1, rtol=0)), (torch.float32, dict(rtol=None, atol=None))]
    ],
    ConsistencyConfig(
        v2_transforms.RandomAdjustSharpness,
        legacy_transforms.RandomAdjustSharpness,
        [
            ArgsKwargs(p=0, sharpness_factor=0.5),
            ArgsKwargs(p=1, sharpness_factor=0.2),
            ArgsKwargs(p=1, sharpness_factor=0.99),
        ],
        closeness_kwargs={"atol": 1e-6, "rtol": 1e-6},
    ),
    ConsistencyConfig(
        v2_transforms.RandomGrayscale,
        legacy_transforms.RandomGrayscale,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
        # Use default tolerances of `torch.testing.assert_close`
        closeness_kwargs=dict(rtol=None, atol=None),
    ),
    # The following configs have no `args_kwargs`, so only signatures are compared.
    ConsistencyConfig(
        v2_transforms.PILToTensor,
        legacy_transforms.PILToTensor,
    ),
    ConsistencyConfig(
        v2_transforms.ToTensor,
        legacy_transforms.ToTensor,
    ),
    ConsistencyConfig(
        v2_transforms.Compose,
        legacy_transforms.Compose,
    ),
    ConsistencyConfig(
        v2_transforms.RandomApply,
        legacy_transforms.RandomApply,
    ),
    ConsistencyConfig(
        v2_transforms.RandomChoice,
        legacy_transforms.RandomChoice,
    ),
    ConsistencyConfig(
        v2_transforms.RandomOrder,
        legacy_transforms.RandomOrder,
    ),
    ConsistencyConfig(
        v2_transforms.AugMix,
        legacy_transforms.AugMix,
    ),
    ConsistencyConfig(
        v2_transforms.AutoAugment,
        legacy_transforms.AutoAugment,
    ),
    ConsistencyConfig(
        v2_transforms.RandAugment,
        legacy_transforms.RandAugment,
    ),
    ConsistencyConfig(
        v2_transforms.TrivialAugmentWide,
        legacy_transforms.TrivialAugmentWide,
    ),
]


274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
@pytest.mark.parametrize("config", CONSISTENCY_CONFIGS, ids=lambda config: config.legacy_cls.__name__)
def test_signature_consistency(config):
    """Check that the v2 transform's constructor signature is backward compatible with the v1 one."""
    legacy_params = dict(inspect.signature(config.legacy_cls).parameters)
    prototype_params = dict(inspect.signature(config.prototype_cls).parameters)

    # Parameters that were intentionally dropped from the v2 transform are exempt from the check.
    for param in config.removed_params:
        legacy_params.pop(param, None)

    missing = legacy_params.keys() - prototype_params.keys()
    if missing:
        raise AssertionError(
            f"The prototype transform does not support the parameters "
            f"{sequence_to_str(sorted(missing), separate_last='and ')}, but the legacy transform does. "
            f"If that is intentional, e.g. pending deprecation, please add the parameters to the `removed_params` on "
            f"the `ConsistencyConfig`."
        )

    extra = prototype_params.keys() - legacy_params.keys()
    # New v2-only parameters are fine as long as they are optional (or *args / **kwargs), since
    # old call sites would otherwise break.
    extra_without_default = {
        param
        for param in extra
        if prototype_params[param].default is inspect.Parameter.empty
        and prototype_params[param].kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
    }
    if extra_without_default:
        raise AssertionError(
            f"The prototype transform requires the parameters "
            f"{sequence_to_str(sorted(extra_without_default), separate_last='and ')}, but the legacy transform does "
            f"not. Please add a default value."
        )

    legacy_signature = list(legacy_params.keys())
    # Since we made sure that we don't have any extra parameters without default above, we clamp the prototype signature
    # to the same number of parameters as the legacy one
    prototype_signature = list(prototype_params.keys())[: len(legacy_signature)]

    assert prototype_signature == legacy_signature
311
312


313
314
315
def check_call_consistency(
    prototype_transform, legacy_transform, images=None, supports_pil=True, closeness_kwargs=None
):
    """Apply both transforms to the same images and assert the outputs are close.

    For every image, three (or four) paths are compared with identical RNG seeding:
    legacy on plain tensor, prototype on plain tensor, prototype on ``tv_tensors.Image``,
    and — for unbatched images when ``supports_pil`` — legacy vs. prototype on PIL input.
    Failures in the *legacy* transform are reported as ``pytest.UsageError`` (bad test
    config), failures in the *prototype* transform as ``AssertionError`` (real bug).
    """
    if images is None:
        images = make_images(**DEFAULT_MAKE_IMAGES_KWARGS)

    closeness_kwargs = closeness_kwargs or dict()

    for image in images:
        image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]"

        # Strip the tv_tensor subclass so the legacy transform sees a plain tensor.
        image_tensor = torch.Tensor(image)
        try:
            # Reseed before each call so random transforms draw identical parameters.
            torch.manual_seed(0)
            output_legacy_tensor = legacy_transform(image_tensor)
        except Exception as exc:
            raise pytest.UsageError(
                f"Transforming a tensor image {image_repr} failed in the legacy transform with the "
                f"error above. This means that you need to specify the parameters passed to `make_images` through the "
                "`make_images_kwargs` of the `ConsistencyConfig`."
            ) from exc

        try:
            torch.manual_seed(0)
            output_prototype_tensor = prototype_transform(image_tensor)
        except Exception as exc:
            raise AssertionError(
                f"Transforming a tensor image with shape {image_repr} failed in the prototype transform with "
                f"the error above. This means there is a consistency bug either in `_get_params` or in the "
                f"`is_pure_tensor` path in `_transform`."
            ) from exc

        assert_close(
            output_prototype_tensor,
            output_legacy_tensor,
            msg=lambda msg: f"Tensor image consistency check failed with: \n\n{msg}",
            **closeness_kwargs,
        )

        try:
            torch.manual_seed(0)
            output_prototype_image = prototype_transform(image)
        except Exception as exc:
            raise AssertionError(
                f"Transforming a image tv_tensor with shape {image_repr} failed in the prototype transform with "
                f"the error above. This means there is a consistency bug either in `_get_params` or in the "
                f"`tv_tensors.Image` path in `_transform`."
            ) from exc

        assert_close(
            output_prototype_image,
            output_prototype_tensor,
            msg=lambda msg: f"Output for tv_tensor and tensor images is not equal: \n\n{msg}",
            **closeness_kwargs,
        )

        # PIL only supports unbatched (C, H, W) images, hence the ndim check.
        if image.ndim == 3 and supports_pil:
            image_pil = to_pil_image(image)

            try:
                torch.manual_seed(0)
                output_legacy_pil = legacy_transform(image_pil)
            except Exception as exc:
                raise pytest.UsageError(
                    f"Transforming a PIL image with shape {image_repr} failed in the legacy transform with the "
                    f"error above. If this transform does not support PIL images, set `supports_pil=False` on the "
                    "`ConsistencyConfig`. "
                ) from exc

            try:
                torch.manual_seed(0)
                output_prototype_pil = prototype_transform(image_pil)
            except Exception as exc:
                raise AssertionError(
                    f"Transforming a PIL image with shape {image_repr} failed in the prototype transform with "
                    f"the error above. This means there is a consistency bug either in `_get_params` or in the "
                    f"`PIL.Image.Image` path in `_transform`."
                ) from exc

            assert_close(
                output_prototype_pil,
                output_legacy_pil,
                msg=lambda msg: f"PIL image consistency check failed with: \n\n{msg}",
                **closeness_kwargs,
            )
398
399


400
@pytest.mark.parametrize(
    ("config", "args_kwargs"),
    [
        pytest.param(
            config, args_kwargs, id=f"{config.legacy_cls.__name__}-{idx:0{len(str(len(config.args_kwargs)))}d}"
        )
        for config in CONSISTENCY_CONFIGS
        for idx, args_kwargs in enumerate(config.args_kwargs)
    ],
)
@pytest.mark.filterwarnings("ignore")
def test_call_consistency(config, args_kwargs):
    """Construct both transforms with the same arguments and compare their call behavior."""
    args, kwargs = args_kwargs

    try:
        legacy_transform = config.legacy_cls(*args, **kwargs)
    except Exception as exc:
        # A failing legacy constructor means the test config itself is wrong.
        raise pytest.UsageError(
            f"Initializing the legacy transform failed with the error above. "
            f"Please correct the `ArgsKwargs({args_kwargs})` in the `ConsistencyConfig`."
        ) from exc

    try:
        prototype_transform = config.prototype_cls(*args, **kwargs)
    except Exception as exc:
        # The legacy constructor accepted the same arguments, so this is a real v2 bug.
        raise AssertionError(
            "Initializing the prototype transform failed with the error above. "
            "This means there is a consistency bug in the constructor."
        ) from exc

    check_call_consistency(
        prototype_transform,
        legacy_transform,
        images=make_images(**config.make_images_kwargs),
        supports_pil=config.supports_pil,
        closeness_kwargs=config.closeness_kwargs,
    )


439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
@pytest.mark.parametrize(
    ("config", "args_kwargs"),
    [
        pytest.param(
            config, args_kwargs, id=f"{config.legacy_cls.__name__}-{idx:0{len(str(len(config.args_kwargs)))}d}"
        )
        for config in CONSISTENCY_CONFIGS
        for idx, args_kwargs in enumerate(config.args_kwargs)
        # Parameter sets marked non-scriptable are only tested in eager mode above.
        if not isinstance(args_kwargs, NotScriptableArgsKwargs)
    ],
)
def test_jit_consistency(config, args_kwargs):
    """Check that the scripted v2 transform matches the scripted v1 transform on tensors."""
    args, kwargs = args_kwargs

    prototype_transform_eager = config.prototype_cls(*args, **kwargs)
    legacy_transform_eager = config.legacy_cls(*args, **kwargs)

    legacy_transform_scripted = torch.jit.script(legacy_transform_eager)
    prototype_transform_scripted = torch.jit.script(prototype_transform_eager)

    for image in make_images(**config.make_images_kwargs):
        # Scripted transforms only accept plain tensors, not tv_tensor subclasses.
        image = image.as_subclass(torch.Tensor)

        torch.manual_seed(0)
        output_legacy_scripted = legacy_transform_scripted(image)

        torch.manual_seed(0)
        output_prototype_scripted = prototype_transform_scripted(image)

        assert_close(output_prototype_scripted, output_legacy_scripted, **config.closeness_kwargs)


471
472
473
474
475
476
477
478
479
480
class TestContainerTransforms:
    """
    Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for
    consistency automatically tests the wrapped transforms consistency.

    Instead of complicated mocking or creating custom transforms just for these tests, here we use deterministic ones
    that were already tested for consistency above.
    """

    def test_compose(self):
        prototype_transform = v2_transforms.Compose(
            [
                v2_transforms.Resize(256),
                v2_transforms.CenterCrop(224),
            ]
        )
        legacy_transform = legacy_transforms.Compose(
            [
                legacy_transforms.Resize(256),
                legacy_transforms.CenterCrop(224),
            ]
        )

        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))

    @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
    @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
    def test_random_apply(self, p, sequence_type):
        prototype_transform = v2_transforms.RandomApply(
            sequence_type(
                [
                    v2_transforms.Resize(256),
                    v2_transforms.CenterCrop(224),
                ]
            ),
            p=p,
        )
        legacy_transform = legacy_transforms.RandomApply(
            sequence_type(
                [
                    legacy_transforms.Resize(256),
                    legacy_transforms.CenterCrop(224),
                ]
            ),
            p=p,
        )

        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))

        if sequence_type is nn.ModuleList:
            # quick and dirty test that it is jit-scriptable
            scripted = torch.jit.script(prototype_transform)
            scripted(torch.rand(1, 3, 300, 300))

    # We can't test other values for `p` since the random parameter generation is different
    @pytest.mark.parametrize("probabilities", [(0, 1), (1, 0)])
    def test_random_choice(self, probabilities):
        prototype_transform = v2_transforms.RandomChoice(
            [
                v2_transforms.Resize(256),
                # NOTE(review): this is the *v1* CenterCrop inside the v2 container — presumably
                # intentional to show v2 containers accept v1 transforms; confirm it is not a typo.
                legacy_transforms.CenterCrop(224),
            ],
            p=probabilities,
        )
        legacy_transform = legacy_transforms.RandomChoice(
            [
                legacy_transforms.Resize(256),
                legacy_transforms.CenterCrop(224),
            ],
            p=probabilities,
        )

        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
547
548


549
550
class TestToTensorTransforms:
    """Consistency checks for the PIL/numpy-to-tensor conversion transforms."""

    def test_pil_to_tensor(self):
        prototype_transform = v2_transforms.PILToTensor()
        legacy_transform = legacy_transforms.PILToTensor()

        # extra_dims=[()] yields unbatched images, which is all PIL supports.
        for image in make_images(extra_dims=[()]):
            image_pil = to_pil_image(image)

            assert_equal(prototype_transform(image_pil), legacy_transform(image_pil))

    def test_to_tensor(self):
        # The v2 ToTensor is deprecated; make sure instantiation warns accordingly.
        with pytest.warns(UserWarning, match=re.escape("The transform `ToTensor()` is deprecated")):
            prototype_transform = v2_transforms.ToTensor()
        legacy_transform = legacy_transforms.ToTensor()

        for image in make_images(extra_dims=[()]):
            image_pil = to_pil_image(image)
            image_numpy = np.array(image_pil)

            assert_equal(prototype_transform(image_pil), legacy_transform(image_pil))
            assert_equal(prototype_transform(image_numpy), legacy_transform(image_numpy))
570
571


572
def import_transforms_from_references(reference):
    """Load ``references/<reference>/transforms.py`` from the repository root as a module.

    The references folder is not an importable package, so the file is loaded directly
    via a ``SourceFileLoader`` under the module name ``"transforms"``.
    """
    project_root = Path(__file__).parent.parent
    transforms_file = project_root / "references" / reference / "transforms.py"

    loader = importlib.machinery.SourceFileLoader("transforms", str(transforms_file))
    spec = importlib.util.spec_from_loader("transforms", loader)
    module = importlib.util.module_from_spec(spec)
    loader.exec_module(module)
    return module
583
584
585


# Reference detection transforms live in `references/detection/transforms.py`, outside the package.
det_transforms = import_transforms_from_references("detection")
586
587
588


class TestRefDetTransforms:
    """Compare v2 transforms against the detection reference-script transforms."""

    def make_tv_tensors(self, with_mask=True):
        """Yield (image, target) samples with the image as PIL, plain tensor, and tv_tensor in turn."""
        size = (600, 800)
        num_objects = 22

        def make_label(extra_dims, categories):
            return torch.randint(categories, extra_dims, dtype=torch.int64)

        pil_image = to_pil_image(make_image(size=size, color_space="RGB"))
        target = {
            "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
        }
        if with_mask:
            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

        yield (pil_image, target)

        # Plain tensor variant: strip the tv_tensor subclass.
        tensor_image = torch.Tensor(make_image(size=size, color_space="RGB", dtype=torch.float32))
        target = {
            "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
        }
        if with_mask:
            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

        yield (tensor_image, target)

        tv_tensor_image = make_image(size=size, color_space="RGB", dtype=torch.float32)
        target = {
            "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
        }
        if with_mask:
            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

        yield (tv_tensor_image, target)

    @pytest.mark.parametrize(
        "t_ref, t, data_kwargs",
        [
            (det_transforms.RandomHorizontalFlip(p=1.0), v2_transforms.RandomHorizontalFlip(p=1.0), {}),
            (
                det_transforms.RandomIoUCrop(),
                v2_transforms.Compose(
                    [
                        v2_transforms.RandomIoUCrop(),
                        v2_transforms.SanitizeBoundingBoxes(labels_getter=lambda sample: sample[1]["labels"]),
                    ]
                ),
                {"with_mask": False},
            ),
            (det_transforms.RandomZoomOut(), v2_transforms.RandomZoomOut(), {"with_mask": False}),
            (det_transforms.ScaleJitter((1024, 1024)), v2_transforms.ScaleJitter((1024, 1024), antialias=True), {}),
            (
                det_transforms.RandomShortestSize(
                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
                ),
                v2_transforms.RandomShortestSize(
                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
                ),
                {},
            ),
        ],
    )
    def test_transform(self, t_ref, t, data_kwargs):
        for dp in self.make_tv_tensors(**data_kwargs):

            # We should use prototype transform first as reference transform performs inplace target update
            torch.manual_seed(12)
            output = t(dp)

            torch.manual_seed(12)
            expected_output = t_ref(*dp)

            assert_equal(expected_output, output)
664
665
666
667
668
669
670
671
672


# Reference segmentation transforms live in `references/segmentation/transforms.py`, outside the package.
seg_transforms = import_transforms_from_references("segmentation")


# We need this transform for two reasons:
# 1. transforms.RandomCrop uses a different scheme to pad images and masks of insufficient size than its name
#    counterpart in the detection references. Thus, we cannot use it with `pad_if_needed=True`
# 2. transforms.Pad only supports a fixed padding, but the segmentation datasets don't have a fixed image size.
673
class PadIfSmaller(v2_transforms.Transform):
    """Pad the right/bottom of the input so both spatial dims are at least ``size``.

    Mirrors the implicit padding behavior of the segmentation reference `RandomCrop`,
    which the built-in v2 `RandomCrop(pad_if_needed=True)` does not reproduce.
    """

    def __init__(self, size, fill=0):
        super().__init__()
        self.size = size
        # Normalize `fill` into the per-type mapping the v2 machinery expects.
        self.fill = v2_transforms._geometry._setup_fill_arg(fill)

    def _get_params(self, sample):
        height, width = query_size(sample)
        # Padding order is (left, top, right, bottom); only right/bottom are ever padded.
        padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
        needs_padding = any(padding)
        return dict(padding=padding, needs_padding=needs_padding)

    def _transform(self, inpt, params):
        if not params["needs_padding"]:
            return inpt

        # Resolve the fill value for this input's concrete type (e.g. Mask vs. Image).
        fill = _get_fill(self.fill, type(inpt))
        return prototype_F.pad(inpt, padding=params["padding"], fill=fill)
691
692
693


class TestRefSegTransforms:
    """Compare v2 transforms against the segmentation reference-script transforms."""

    def make_tv_tensors(self, supports_pil=True, image_dtype=torch.uint8):
        """Yield (v2 sample, reference sample) pairs over PIL / tensor / tv_tensor image inputs."""
        size = (256, 460)
        num_categories = 21

        conv_fns = []
        if supports_pil:
            conv_fns.append(to_pil_image)
        conv_fns.extend([torch.Tensor, lambda x: x])

        for conv_fn in conv_fns:
            tv_tensor_image = make_image(size=size, color_space="RGB", dtype=image_dtype)
            tv_tensor_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)

            dp = (conv_fn(tv_tensor_image), tv_tensor_mask)
            # The reference transforms expect PIL inputs (or a plain tensor when PIL is unsupported).
            dp_ref = (
                to_pil_image(tv_tensor_image) if supports_pil else tv_tensor_image.as_subclass(torch.Tensor),
                to_pil_image(tv_tensor_mask),
            )

            yield dp, dp_ref

    def set_seed(self, seed=12):
        # The reference transforms draw from `random`, the v2 ones from torch's RNG — seed both.
        torch.manual_seed(seed)
        random.seed(seed)

    def check(self, t, t_ref, data_kwargs=None):
        for dp, dp_ref in self.make_tv_tensors(**data_kwargs or dict()):

            self.set_seed()
            actual = actual_image, actual_mask = t(dp)

            self.set_seed()
            expected_image, expected_mask = t_ref(*dp_ref)
            # Convert the reference PIL outputs back to tensors for comparison.
            if isinstance(actual_image, torch.Tensor) and not isinstance(expected_image, torch.Tensor):
                expected_image = legacy_F.pil_to_tensor(expected_image)
            expected_mask = legacy_F.pil_to_tensor(expected_mask).squeeze(0)
            expected = (expected_image, expected_mask)

            assert_equal(actual, expected)

    @pytest.mark.parametrize(
        ("t_ref", "t", "data_kwargs"),
        [
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=1.0),
                v2_transforms.RandomHorizontalFlip(p=1.0),
                dict(),
            ),
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=0.0),
                v2_transforms.RandomHorizontalFlip(p=0.0),
                dict(),
            ),
            (
                seg_transforms.RandomCrop(size=480),
                v2_transforms.Compose(
                    [
                        PadIfSmaller(size=480, fill={tv_tensors.Mask: 255, "others": 0}),
                        v2_transforms.RandomCrop(size=480),
                    ]
                ),
                dict(),
            ),
            (
                seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                v2_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                dict(supports_pil=False, image_dtype=torch.float),
            ),
        ],
    )
    def test_common(self, t_ref, t, data_kwargs):
        self.check(t, t_ref, data_kwargs)

767
768
769
770
771
772
773
774
775
776
777
778

@pytest.mark.parametrize(
    ("legacy_dispatcher", "name_only_params"),
    [
        # `name_only_params` lists parameters whose annotation/default changed on purpose;
        # for those only the parameter *name* is compared.
        (legacy_F.get_dimensions, {}),
        (legacy_F.get_image_size, {}),
        (legacy_F.get_image_num_channels, {}),
        (legacy_F.to_tensor, {}),
        (legacy_F.pil_to_tensor, {}),
        (legacy_F.convert_image_dtype, {}),
        (legacy_F.to_pil_image, {}),
        (legacy_F.normalize, {}),
        (legacy_F.resize, {"interpolation"}),
        (legacy_F.pad, {"padding", "fill"}),
        (legacy_F.crop, {}),
        (legacy_F.center_crop, {}),
        (legacy_F.resized_crop, {"interpolation"}),
        (legacy_F.hflip, {}),
        (legacy_F.perspective, {"startpoints", "endpoints", "fill", "interpolation"}),
        (legacy_F.vflip, {}),
        (legacy_F.five_crop, {}),
        (legacy_F.ten_crop, {}),
        (legacy_F.adjust_brightness, {}),
        (legacy_F.adjust_contrast, {}),
        (legacy_F.adjust_saturation, {}),
        (legacy_F.adjust_hue, {}),
        (legacy_F.adjust_gamma, {}),
        (legacy_F.rotate, {"center", "fill", "interpolation"}),
        (legacy_F.affine, {"angle", "translate", "center", "fill", "interpolation"}),
        (legacy_F.to_grayscale, {}),
        (legacy_F.rgb_to_grayscale, {}),
        (legacy_F.to_tensor, {}),
        (legacy_F.erase, {}),
        (legacy_F.gaussian_blur, {}),
        (legacy_F.invert, {}),
        (legacy_F.posterize, {}),
        (legacy_F.solarize, {}),
        (legacy_F.adjust_sharpness, {}),
        (legacy_F.autocontrast, {}),
        (legacy_F.equalize, {}),
        (legacy_F.elastic_transform, {"fill", "interpolation"}),
    ],
)
def test_dispatcher_signature_consistency(legacy_dispatcher, name_only_params):
    """Check that each v2 functional dispatcher's signature is backward compatible with v1."""
    legacy_signature = inspect.signature(legacy_dispatcher)
    # Skip the first parameter (the input image) — it changed name/annotation by design.
    legacy_params = list(legacy_signature.parameters.values())[1:]

    try:
        prototype_dispatcher = getattr(prototype_F, legacy_dispatcher.__name__)
    except AttributeError:
        raise AssertionError(
            f"Legacy dispatcher `F.{legacy_dispatcher.__name__}` has no prototype equivalent"
        ) from None

    prototype_signature = inspect.signature(prototype_dispatcher)
    prototype_params = list(prototype_signature.parameters.values())[1:]

    # Some dispatchers got extra parameters. This makes sure they have a default argument and thus are BC. We don't
    # need to check if parameters were added in the middle rather than at the end, since that will be caught by the
    # regular check below.
    prototype_params, new_prototype_params = (
        prototype_params[: len(legacy_params)],
        prototype_params[len(legacy_params) :],
    )
    for param in new_prototype_params:
        assert param.default is not param.empty

    # Some annotations were changed mostly to supersets of what was there before. Plus, some legacy dispatchers had no
    # annotations. In these cases we simply drop the annotation and default argument from the comparison
    # NOTE: this mutates private `inspect.Parameter` fields (`_annotation`, `_default`) in place.
    for prototype_param, legacy_param in zip(prototype_params, legacy_params):
        if legacy_param.name in name_only_params:
            prototype_param._annotation = prototype_param._default = inspect.Parameter.empty
            legacy_param._annotation = legacy_param._default = inspect.Parameter.empty
        elif legacy_param.annotation is inspect.Parameter.empty:
            prototype_param._annotation = inspect.Parameter.empty

    assert prototype_params == legacy_params