import importlib.machinery
import importlib.util
import inspect
import random
import re
from pathlib import Path

import numpy as np
import pytest

import torch
import torchvision.transforms.v2 as v2_transforms
from common_utils import assert_close, assert_equal, set_rng_seed
from torch import nn
from torchvision import transforms as legacy_transforms, tv_tensors
from torchvision._utils import sequence_to_str

from torchvision.transforms import functional as legacy_F
from torchvision.transforms.v2 import functional as prototype_F
from torchvision.transforms.v2._utils import _get_fill, query_size
from torchvision.transforms.v2.functional import to_pil_image

from transforms_v2_legacy_utils import (
    ArgsKwargs,
    make_bounding_boxes,
    make_detection_mask,
    make_image,
    make_images,
    make_segmentation_mask,
)

# Defaults shared by all `make_images` calls in this module: RGB images with a
# single extra leading (batch-like) dimension of size 4.
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)])
@pytest.fixture(autouse=True)
def fix_rng_seed():
    set_rng_seed(0)
    yield


40
41
42
43
44
45
46
47
48
class NotScriptableArgsKwargs(ArgsKwargs):
    """
    This class is used to mark parameters that render the transform non-scriptable. They still work in eager mode and
    thus will be tested there, but will be skipped by the JIT tests.
    """
class ConsistencyConfig:
    """Bundles a v2/legacy transform pair together with the parameters used to compare them.

    Attributes mirror the constructor arguments; falsy `make_images_kwargs` /
    `closeness_kwargs` fall back to module-wide defaults.
    """

    def __init__(
        self,
        prototype_cls,
        legacy_cls,
        # If no args_kwargs is passed, only the signature will be checked
        args_kwargs=(),
        make_images_kwargs=None,
        supports_pil=True,
        removed_params=(),
        closeness_kwargs=None,
    ):
        self.prototype_cls = prototype_cls
        self.legacy_cls = legacy_cls
        self.args_kwargs = args_kwargs
        # Fall back to the module-wide image defaults when nothing was given.
        self.make_images_kwargs = make_images_kwargs if make_images_kwargs else DEFAULT_MAKE_IMAGES_KWARGS
        self.supports_pil = supports_pil
        self.removed_params = removed_params
        # By default require exact equality between legacy and v2 outputs.
        self.closeness_kwargs = closeness_kwargs if closeness_kwargs else dict(rtol=0, atol=0)
# These are here since both the prototype and legacy transform need to be constructed with the same random parameters
LINEAR_TRANSFORMATION_MEAN = torch.rand(36)
LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2)

74
75
# One entry per v2/legacy transform pair. Entries without `args_kwargs` are
# only checked for signature consistency, not for call consistency.
CONSISTENCY_CONFIGS = [
    ConsistencyConfig(
        v2_transforms.Normalize,
        legacy_transforms.Normalize,
        [
            ArgsKwargs(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ],
        supports_pil=False,
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.float]),
    ),
    ConsistencyConfig(
        v2_transforms.FiveCrop,
        legacy_transforms.FiveCrop,
        [
            ArgsKwargs(18),
            ArgsKwargs((18, 13)),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
    ),
    ConsistencyConfig(
        v2_transforms.TenCrop,
        legacy_transforms.TenCrop,
        [
            ArgsKwargs(18),
            ArgsKwargs((18, 13)),
            ArgsKwargs(18, vertical_flip=True),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
    ),
    *[
        ConsistencyConfig(
            v2_transforms.LinearTransformation,
            legacy_transforms.LinearTransformation,
            [
                ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX.to(matrix_dtype), LINEAR_TRANSFORMATION_MEAN.to(matrix_dtype)),
            ],
            # Make sure that the product of the height, width and number of channels matches the number of elements in
            # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
            make_images_kwargs=dict(
                DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"], dtypes=[image_dtype]
            ),
            supports_pil=False,
        )
        for matrix_dtype, image_dtype in [
            (torch.float32, torch.float32),
            (torch.float64, torch.float64),
            (torch.float32, torch.uint8),
            (torch.float64, torch.float32),
            (torch.float32, torch.float64),
        ]
    ],
    ConsistencyConfig(
        v2_transforms.ToPILImage,
        legacy_transforms.ToPILImage,
        [NotScriptableArgsKwargs()],
        make_images_kwargs=dict(
            color_spaces=[
                "GRAY",
                "GRAY_ALPHA",
                "RGB",
                "RGBA",
            ],
            extra_dims=[()],
        ),
        supports_pil=False,
    ),
    ConsistencyConfig(
        v2_transforms.Lambda,
        legacy_transforms.Lambda,
        [
            NotScriptableArgsKwargs(lambda image: image / 2),
        ],
        # Technically, this also supports PIL, but it is overkill to write a function here that supports tensor and PIL
        # images given that the transform does nothing but call it anyway.
        supports_pil=False,
    ),
    ConsistencyConfig(
        v2_transforms.RandomEqualize,
        legacy_transforms.RandomEqualize,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]),
    ),
    ConsistencyConfig(
        v2_transforms.RandomInvert,
        legacy_transforms.RandomInvert,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
        ],
    ),
    ConsistencyConfig(
        v2_transforms.RandomPosterize,
        legacy_transforms.RandomPosterize,
        [
            ArgsKwargs(p=0, bits=5),
            ArgsKwargs(p=1, bits=1),
            ArgsKwargs(p=1, bits=3),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]),
    ),
    ConsistencyConfig(
        v2_transforms.RandomSolarize,
        legacy_transforms.RandomSolarize,
        [
            ArgsKwargs(p=0, threshold=0.5),
            ArgsKwargs(p=1, threshold=0.3),
            ArgsKwargs(p=1, threshold=0.99),
        ],
    ),
    *[
        ConsistencyConfig(
            v2_transforms.RandomAutocontrast,
            legacy_transforms.RandomAutocontrast,
            [
                ArgsKwargs(p=0),
                ArgsKwargs(p=1),
            ],
            make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[dt]),
            closeness_kwargs=ckw,
        )
        for dt, ckw in [(torch.uint8, dict(atol=1, rtol=0)), (torch.float32, dict(rtol=None, atol=None))]
    ],
    ConsistencyConfig(
        v2_transforms.RandomAdjustSharpness,
        legacy_transforms.RandomAdjustSharpness,
        [
            ArgsKwargs(p=0, sharpness_factor=0.5),
            ArgsKwargs(p=1, sharpness_factor=0.2),
            ArgsKwargs(p=1, sharpness_factor=0.99),
        ],
        closeness_kwargs={"atol": 1e-6, "rtol": 1e-6},
    ),
    # The configs below have no `args_kwargs`, so only their signatures are compared.
    ConsistencyConfig(
        v2_transforms.PILToTensor,
        legacy_transforms.PILToTensor,
    ),
    ConsistencyConfig(
        v2_transforms.ToTensor,
        legacy_transforms.ToTensor,
    ),
    ConsistencyConfig(
        v2_transforms.Compose,
        legacy_transforms.Compose,
    ),
    ConsistencyConfig(
        v2_transforms.RandomApply,
        legacy_transforms.RandomApply,
    ),
    ConsistencyConfig(
        v2_transforms.RandomChoice,
        legacy_transforms.RandomChoice,
    ),
    ConsistencyConfig(
        v2_transforms.RandomOrder,
        legacy_transforms.RandomOrder,
    ),
    ConsistencyConfig(
        v2_transforms.AugMix,
        legacy_transforms.AugMix,
    ),
    ConsistencyConfig(
        v2_transforms.AutoAugment,
        legacy_transforms.AutoAugment,
    ),
    ConsistencyConfig(
        v2_transforms.RandAugment,
        legacy_transforms.RandAugment,
    ),
    ConsistencyConfig(
        v2_transforms.TrivialAugmentWide,
        legacy_transforms.TrivialAugmentWide,
    ),
]


@pytest.mark.parametrize("config", CONSISTENCY_CONFIGS, ids=lambda config: config.legacy_cls.__name__)
def test_signature_consistency(config):
    """The v2 constructor must accept every legacy parameter, in the same order; extra v2 parameters need defaults."""
    legacy_params = dict(inspect.signature(config.legacy_cls).parameters)
    prototype_params = dict(inspect.signature(config.prototype_cls).parameters)

    # Parameters that were deliberately dropped from the v2 API are excluded from the comparison.
    for removed in config.removed_params:
        legacy_params.pop(removed, None)

    missing = legacy_params.keys() - prototype_params.keys()
    if missing:
        raise AssertionError(
            f"The prototype transform does not support the parameters "
            f"{sequence_to_str(sorted(missing), separate_last='and ')}, but the legacy transform does. "
            f"If that is intentional, e.g. pending deprecation, please add the parameters to the `removed_params` on "
            f"the `ConsistencyConfig`."
        )

    # Extra v2-only parameters are fine as long as they have a default (or are *args / **kwargs).
    variadic_kinds = {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
    extra_without_default = {
        name
        for name in prototype_params.keys() - legacy_params.keys()
        if prototype_params[name].default is inspect.Parameter.empty
        and prototype_params[name].kind not in variadic_kinds
    }
    if extra_without_default:
        raise AssertionError(
            f"The prototype transform requires the parameters "
            f"{sequence_to_str(sorted(extra_without_default), separate_last='and ')}, but the legacy transform does "
            f"not. Please add a default value."
        )

    # Since we made sure that we don't have any extra parameters without default above, we clamp the prototype signature
    # to the same number of parameters as the legacy one
    legacy_signature = list(legacy_params)
    prototype_signature = list(prototype_params)[: len(legacy_signature)]

    assert prototype_signature == legacy_signature


291
292
293
def check_call_consistency(
    prototype_transform, legacy_transform, images=None, supports_pil=True, closeness_kwargs=None
):
294
295
    if images is None:
        images = make_images(**DEFAULT_MAKE_IMAGES_KWARGS)
296

297
298
    closeness_kwargs = closeness_kwargs or dict()

299
300
    for image in images:
        image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]"
301
302
303

        image_tensor = torch.Tensor(image)
        try:
304
            torch.manual_seed(0)
305
            output_legacy_tensor = legacy_transform(image_tensor)
306
307
        except Exception as exc:
            raise pytest.UsageError(
308
                f"Transforming a tensor image {image_repr} failed in the legacy transform with the "
309
                f"error above. This means that you need to specify the parameters passed to `make_images` through the "
310
311
312
313
                "`make_images_kwargs` of the `ConsistencyConfig`."
            ) from exc

        try:
314
            torch.manual_seed(0)
315
            output_prototype_tensor = prototype_transform(image_tensor)
316
317
        except Exception as exc:
            raise AssertionError(
318
                f"Transforming a tensor image with shape {image_repr} failed in the prototype transform with "
319
                f"the error above. This means there is a consistency bug either in `_get_params` or in the "
320
                f"`is_pure_tensor` path in `_transform`."
321
322
            ) from exc

323
        assert_close(
324
325
326
            output_prototype_tensor,
            output_legacy_tensor,
            msg=lambda msg: f"Tensor image consistency check failed with: \n\n{msg}",
327
            **closeness_kwargs,
328
329
330
        )

        try:
331
            torch.manual_seed(0)
332
            output_prototype_image = prototype_transform(image)
333
334
        except Exception as exc:
            raise AssertionError(
335
                f"Transforming a image tv_tensor with shape {image_repr} failed in the prototype transform with "
336
                f"the error above. This means there is a consistency bug either in `_get_params` or in the "
337
                f"`tv_tensors.Image` path in `_transform`."
338
339
            ) from exc

340
        assert_close(
341
            output_prototype_image,
342
            output_prototype_tensor,
343
            msg=lambda msg: f"Output for tv_tensor and tensor images is not equal: \n\n{msg}",
344
            **closeness_kwargs,
345
346
        )

347
        if image.ndim == 3 and supports_pil:
348
            image_pil = to_pil_image(image)
349

350
            try:
351
                torch.manual_seed(0)
352
                output_legacy_pil = legacy_transform(image_pil)
353
354
            except Exception as exc:
                raise pytest.UsageError(
355
                    f"Transforming a PIL image with shape {image_repr} failed in the legacy transform with the "
356
357
358
359
360
                    f"error above. If this transform does not support PIL images, set `supports_pil=False` on the "
                    "`ConsistencyConfig`. "
                ) from exc

            try:
361
                torch.manual_seed(0)
362
                output_prototype_pil = prototype_transform(image_pil)
363
364
            except Exception as exc:
                raise AssertionError(
365
                    f"Transforming a PIL image with shape {image_repr} failed in the prototype transform with "
366
367
368
369
                    f"the error above. This means there is a consistency bug either in `_get_params` or in the "
                    f"`PIL.Image.Image` path in `_transform`."
                ) from exc

370
            assert_close(
371
372
                output_prototype_pil,
                output_legacy_pil,
373
                msg=lambda msg: f"PIL image consistency check failed with: \n\n{msg}",
374
                **closeness_kwargs,
375
            )
376
377


@pytest.mark.parametrize(
    ("config", "args_kwargs"),
    [
        pytest.param(
            config, args_kwargs, id=f"{config.legacy_cls.__name__}-{idx:0{len(str(len(config.args_kwargs)))}d}"
        )
        for config in CONSISTENCY_CONFIGS
        for idx, args_kwargs in enumerate(config.args_kwargs)
    ],
)
@pytest.mark.filterwarnings("ignore")
def test_call_consistency(config, args_kwargs):
    """Instantiate both transforms with the same arguments and compare them via `check_call_consistency`."""
    args, kwargs = args_kwargs

    try:
        legacy_transform = config.legacy_cls(*args, **kwargs)
    except Exception as exc:
        # A failing *legacy* constructor means the config itself is broken, not the prototype.
        raise pytest.UsageError(
            f"Initializing the legacy transform failed with the error above. "
            f"Please correct the `ArgsKwargs({args_kwargs})` in the `ConsistencyConfig`."
        ) from exc

    try:
        prototype_transform = config.prototype_cls(*args, **kwargs)
    except Exception as exc:
        raise AssertionError(
            "Initializing the prototype transform failed with the error above. "
            "This means there is a consistency bug in the constructor."
        ) from exc

    check_call_consistency(
        prototype_transform,
        legacy_transform,
        images=make_images(**config.make_images_kwargs),
        supports_pil=config.supports_pil,
        closeness_kwargs=config.closeness_kwargs,
    )
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
@pytest.mark.parametrize(
    ("config", "args_kwargs"),
    [
        pytest.param(
            config, args_kwargs, id=f"{config.legacy_cls.__name__}-{idx:0{len(str(len(config.args_kwargs)))}d}"
        )
        for config in CONSISTENCY_CONFIGS
        for idx, args_kwargs in enumerate(config.args_kwargs)
        if not isinstance(args_kwargs, NotScriptableArgsKwargs)
    ],
)
def test_jit_consistency(config, args_kwargs):
    args, kwargs = args_kwargs

    prototype_transform_eager = config.prototype_cls(*args, **kwargs)
    legacy_transform_eager = config.legacy_cls(*args, **kwargs)

    legacy_transform_scripted = torch.jit.script(legacy_transform_eager)
    prototype_transform_scripted = torch.jit.script(prototype_transform_eager)

    for image in make_images(**config.make_images_kwargs):
        image = image.as_subclass(torch.Tensor)

        torch.manual_seed(0)
        output_legacy_scripted = legacy_transform_scripted(image)

        torch.manual_seed(0)
        output_prototype_scripted = prototype_transform_scripted(image)

        assert_close(output_prototype_scripted, output_legacy_scripted, **config.closeness_kwargs)


class TestContainerTransforms:
    """
    Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for
    consistency automatically tests the wrapped transforms consistency.

    Instead of complicated mocking or creating custom transforms just for these tests, here we use deterministic ones
    that were already tested for consistency above.
    """

    def test_compose(self):
        prototype_transform = v2_transforms.Compose(
            [
                v2_transforms.Resize(256),
                v2_transforms.CenterCrop(224),
            ]
        )
        legacy_transform = legacy_transforms.Compose(
            [
                legacy_transforms.Resize(256),
                legacy_transforms.CenterCrop(224),
            ]
        )

        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))

    @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
    @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
    def test_random_apply(self, p, sequence_type):
        prototype_transform = v2_transforms.RandomApply(
            sequence_type(
                [
                    v2_transforms.Resize(256),
                    v2_transforms.CenterCrop(224),
                ]
            ),
            p=p,
        )
        legacy_transform = legacy_transforms.RandomApply(
            sequence_type(
                [
                    legacy_transforms.Resize(256),
                    legacy_transforms.CenterCrop(224),
                ]
            ),
            p=p,
        )

        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))

        if sequence_type is nn.ModuleList:
            # quick and dirty test that it is jit-scriptable
            scripted = torch.jit.script(prototype_transform)
            scripted(torch.rand(1, 3, 300, 300))

    # We can't test other values for `p` since the random parameter generation is different
    @pytest.mark.parametrize("probabilities", [(0, 1), (1, 0)])
    def test_random_choice(self, probabilities):
        prototype_transform = v2_transforms.RandomChoice(
            [
                v2_transforms.Resize(256),
                # Fixed: this previously wrapped `legacy_transforms.CenterCrop`, which made the
                # "prototype" pipeline partially legacy and defeated the purpose of the comparison.
                # Keep the prototype container pure v2, mirroring `test_compose`/`test_random_apply`.
                v2_transforms.CenterCrop(224),
            ],
            p=probabilities,
        )
        legacy_transform = legacy_transforms.RandomChoice(
            [
                legacy_transforms.Resize(256),
                legacy_transforms.CenterCrop(224),
            ],
            p=probabilities,
        )

        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))


class TestToTensorTransforms:
    """Consistency of the v2 PIL/ndarray-to-tensor conversion transforms against their legacy counterparts."""

    def test_pil_to_tensor(self):
        v2_t = v2_transforms.PILToTensor()
        v1_t = legacy_transforms.PILToTensor()

        for image in make_images(extra_dims=[()]):
            pil = to_pil_image(image)
            assert_equal(v2_t(pil), v1_t(pil))

    def test_to_tensor(self):
        # Constructing the v2 `ToTensor` must emit its deprecation warning.
        with pytest.warns(UserWarning, match=re.escape("The transform `ToTensor()` is deprecated")):
            v2_t = v2_transforms.ToTensor()
        v1_t = legacy_transforms.ToTensor()

        for image in make_images(extra_dims=[()]):
            pil = to_pil_image(image)
            arr = np.array(pil)

            # Both the PIL and the numpy input paths have to agree with the legacy transform.
            assert_equal(v2_t(pil), v1_t(pil))
            assert_equal(v2_t(arr), v1_t(arr))


def import_transforms_from_references(reference):
    """Load ``references/<reference>/transforms.py`` from the repository root as a module object."""
    # This test file lives in <root>/test/, so the references tree is one level up.
    transforms_path = Path(__file__).parent.parent / "references" / reference / "transforms.py"

    loader = importlib.machinery.SourceFileLoader("transforms", str(transforms_path))
    spec = importlib.util.spec_from_loader("transforms", loader)
    module = importlib.util.module_from_spec(spec)
    loader.exec_module(module)
    return module


det_transforms = import_transforms_from_references("detection")
564
565
566


class TestRefDetTransforms:
    """Consistency of v2 transforms against the reference detection transforms."""

    def make_tv_tensors(self, with_mask=True):
        # Yields (image, target) pairs in three input flavors: PIL, plain tensor, tv_tensor Image.
        size = (600, 800)
        num_objects = 22

        def make_label(extra_dims, categories):
            return torch.randint(categories, extra_dims, dtype=torch.int64)

        pil_image = to_pil_image(make_image(size=size, color_space="RGB"))
        target = {
            "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
        }
        if with_mask:
            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

        yield (pil_image, target)

        tensor_image = torch.Tensor(make_image(size=size, color_space="RGB", dtype=torch.float32))
        target = {
            "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
        }
        if with_mask:
            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

        yield (tensor_image, target)

        tv_tensor_image = make_image(size=size, color_space="RGB", dtype=torch.float32)
        target = {
            "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
        }
        if with_mask:
            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

        yield (tv_tensor_image, target)

    @pytest.mark.parametrize(
        "t_ref, t, data_kwargs",
        [
            (det_transforms.RandomHorizontalFlip(p=1.0), v2_transforms.RandomHorizontalFlip(p=1.0), {}),
            (
                det_transforms.RandomIoUCrop(),
                v2_transforms.Compose(
                    [
                        v2_transforms.RandomIoUCrop(),
                        v2_transforms.SanitizeBoundingBoxes(labels_getter=lambda sample: sample[1]["labels"]),
                    ]
                ),
                {"with_mask": False},
            ),
            (det_transforms.RandomZoomOut(), v2_transforms.RandomZoomOut(), {"with_mask": False}),
            (det_transforms.ScaleJitter((1024, 1024)), v2_transforms.ScaleJitter((1024, 1024), antialias=True), {}),
            (
                det_transforms.RandomShortestSize(
                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
                ),
                v2_transforms.RandomShortestSize(
                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
                ),
                {},
            ),
        ],
    )
    def test_transform(self, t_ref, t, data_kwargs):
        for dp in self.make_tv_tensors(**data_kwargs):

            # We should use prototype transform first as reference transform performs inplace target update
            torch.manual_seed(12)
            output = t(dp)

            torch.manual_seed(12)
            expected_output = t_ref(*dp)

            assert_equal(expected_output, output)


# Reference transforms from `references/segmentation/transforms.py`.
seg_transforms = import_transforms_from_references("segmentation")

# We need this transform for two reasons:
# 1. transforms.RandomCrop uses a different scheme to pad images and masks of insufficient size than its name
#    counterpart in the detection references. Thus, we cannot use it with `pad_if_needed=True`
# 2. transforms.Pad only supports a fixed padding, but the segmentation datasets don't have a fixed image size.
class PadIfSmaller(v2_transforms.Transform):
    """Pad the right/bottom of the input so that both spatial dimensions are at least `size`."""

    def __init__(self, size, fill=0):
        super().__init__()
        self.size = size
        # NOTE(review): relies on the private `_setup_fill_arg` helper to normalize the fill
        # argument into the per-type mapping consumed by `_get_fill` below.
        self.fill = v2_transforms._geometry._setup_fill_arg(fill)

    def _get_params(self, sample):
        height, width = query_size(sample)
        # [left, top, right, bottom]: only pad on the right/bottom, and never crop.
        padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
        needs_padding = any(padding)
        return dict(padding=padding, needs_padding=needs_padding)

    def _transform(self, inpt, params):
        if not params["needs_padding"]:
            return inpt

        # Pick the fill value registered for this input type (e.g. a different value for masks).
        fill = _get_fill(self.fill, type(inpt))
        return prototype_F.pad(inpt, padding=params["padding"], fill=fill)


class TestRefSegTransforms:
    """Consistency of v2 transforms against the reference segmentation transforms."""

    def make_tv_tensors(self, supports_pil=True, image_dtype=torch.uint8):
        # Yields ((image, mask), (ref_image, ref_mask)) pairs; the reference pipeline always
        # receives PIL inputs (or a plain tensor when PIL is unsupported).
        size = (256, 460)
        num_categories = 21

        conv_fns = []
        if supports_pil:
            conv_fns.append(to_pil_image)
        conv_fns.extend([torch.Tensor, lambda x: x])

        for conv_fn in conv_fns:
            tv_tensor_image = make_image(size=size, color_space="RGB", dtype=image_dtype)
            tv_tensor_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)

            dp = (conv_fn(tv_tensor_image), tv_tensor_mask)
            dp_ref = (
                to_pil_image(tv_tensor_image) if supports_pil else tv_tensor_image.as_subclass(torch.Tensor),
                to_pil_image(tv_tensor_mask),
            )

            yield dp, dp_ref

    def set_seed(self, seed=12):
        # The reference transforms draw from both torch's and Python's RNG.
        torch.manual_seed(seed)
        random.seed(seed)

    def check(self, t, t_ref, data_kwargs=None):
        for dp, dp_ref in self.make_tv_tensors(**data_kwargs or dict()):

            self.set_seed()
            actual = actual_image, actual_mask = t(dp)

            self.set_seed()
            expected_image, expected_mask = t_ref(*dp_ref)
            # The reference pipeline may return PIL outputs; convert so both sides compare as tensors.
            if isinstance(actual_image, torch.Tensor) and not isinstance(expected_image, torch.Tensor):
                expected_image = legacy_F.pil_to_tensor(expected_image)
            expected_mask = legacy_F.pil_to_tensor(expected_mask).squeeze(0)
            expected = (expected_image, expected_mask)

            assert_equal(actual, expected)

    @pytest.mark.parametrize(
        ("t_ref", "t", "data_kwargs"),
        [
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=1.0),
                v2_transforms.RandomHorizontalFlip(p=1.0),
                dict(),
            ),
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=0.0),
                v2_transforms.RandomHorizontalFlip(p=0.0),
                dict(),
            ),
            (
                seg_transforms.RandomCrop(size=480),
                v2_transforms.Compose(
                    [
                        PadIfSmaller(size=480, fill={tv_tensors.Mask: 255, "others": 0}),
                        v2_transforms.RandomCrop(size=480),
                    ]
                ),
                dict(),
            ),
            (
                seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                v2_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                dict(supports_pil=False, image_dtype=torch.float),
            ),
        ],
    )
    def test_common(self, t_ref, t, data_kwargs):
        self.check(t, t_ref, data_kwargs)

@pytest.mark.parametrize(
    ("legacy_dispatcher", "name_only_params"),
    [
        (legacy_F.get_dimensions, {}),
        (legacy_F.get_image_size, {}),
        (legacy_F.get_image_num_channels, {}),
        (legacy_F.to_tensor, {}),
        (legacy_F.pil_to_tensor, {}),
        (legacy_F.convert_image_dtype, {}),
        (legacy_F.to_pil_image, {}),
        (legacy_F.normalize, {}),
        (legacy_F.resize, {"interpolation"}),
        (legacy_F.pad, {"padding", "fill"}),
        (legacy_F.crop, {}),
        (legacy_F.center_crop, {}),
        (legacy_F.resized_crop, {"interpolation"}),
        (legacy_F.hflip, {}),
        (legacy_F.perspective, {"startpoints", "endpoints", "fill", "interpolation"}),
        (legacy_F.vflip, {}),
        (legacy_F.five_crop, {}),
        (legacy_F.ten_crop, {}),
        (legacy_F.adjust_brightness, {}),
        (legacy_F.adjust_contrast, {}),
        (legacy_F.adjust_saturation, {}),
        (legacy_F.adjust_hue, {}),
        (legacy_F.adjust_gamma, {}),
        (legacy_F.rotate, {"center", "fill", "interpolation"}),
        (legacy_F.affine, {"angle", "translate", "center", "fill", "interpolation"}),
        (legacy_F.to_grayscale, {}),
        (legacy_F.rgb_to_grayscale, {}),
        # NOTE(review): `legacy_F.to_tensor` already appears earlier in this list; this second
        # entry looks redundant — confirm and deduplicate.
        (legacy_F.to_tensor, {}),
        (legacy_F.erase, {}),
        (legacy_F.gaussian_blur, {}),
        (legacy_F.invert, {}),
        (legacy_F.posterize, {}),
        (legacy_F.solarize, {}),
        (legacy_F.adjust_sharpness, {}),
        (legacy_F.autocontrast, {}),
        (legacy_F.equalize, {}),
        (legacy_F.elastic_transform, {"fill", "interpolation"}),
    ],
)
def test_dispatcher_signature_consistency(legacy_dispatcher, name_only_params):
    """Each v2 functional must keep the legacy parameter list (beyond the input); `name_only_params`
    lists parameters whose annotation/default changed on purpose and are compared by name only."""
    # Skip the first parameter (the image/input) on both sides.
    legacy_signature = inspect.signature(legacy_dispatcher)
    legacy_params = list(legacy_signature.parameters.values())[1:]

    try:
        prototype_dispatcher = getattr(prototype_F, legacy_dispatcher.__name__)
    except AttributeError:
        raise AssertionError(
            f"Legacy dispatcher `F.{legacy_dispatcher.__name__}` has no prototype equivalent"
        ) from None

    prototype_signature = inspect.signature(prototype_dispatcher)
    prototype_params = list(prototype_signature.parameters.values())[1:]

    # Some dispatchers got extra parameters. This makes sure they have a default argument and thus are BC. We don't
    # need to check if parameters were added in the middle rather than at the end, since that will be caught by the
    # regular check below.
    prototype_params, new_prototype_params = (
        prototype_params[: len(legacy_params)],
        prototype_params[len(legacy_params) :],
    )
    for param in new_prototype_params:
        assert param.default is not param.empty

    # Some annotations were changed mostly to supersets of what was there before. Plus, some legacy dispatchers had no
    # annotations. In these cases we simply drop the annotation and default argument from the comparison
    # NOTE(review): this mutates the private `_annotation`/`_default` slots of `inspect.Parameter`
    # instead of using `Parameter.replace`; it works but depends on CPython internals.
    for prototype_param, legacy_param in zip(prototype_params, legacy_params):
        if legacy_param.name in name_only_params:
            prototype_param._annotation = prototype_param._default = inspect.Parameter.empty
            legacy_param._annotation = legacy_param._default = inspect.Parameter.empty
        elif legacy_param.annotation is inspect.Parameter.empty:
            prototype_param._annotation = inspect.Parameter.empty

    assert prototype_params == legacy_params