# test_transforms_v2_consistency.py -- consistency checks between torchvision transforms v1 (legacy) and v2.
import importlib.machinery
import importlib.util
3
import inspect
4
import random
5
import re
6
from pathlib import Path
7

8
import numpy as np
9
import PIL.Image
10
import pytest
11
12

import torch
13
import torchvision.transforms.v2 as v2_transforms
14
from common_utils import assert_close, assert_equal, set_rng_seed
15
from torch import nn
16
from torchvision import transforms as legacy_transforms, tv_tensors
17
from torchvision._utils import sequence_to_str
18

19
from torchvision.transforms import functional as legacy_F
20
from torchvision.transforms.v2 import functional as prototype_F
Nicolas Hug's avatar
Nicolas Hug committed
21
from torchvision.transforms.v2._utils import _get_fill, query_size
22
from torchvision.transforms.v2.functional import to_pil_image
23
24
25
26
27
28
29
30
from transforms_v2_legacy_utils import (
    ArgsKwargs,
    make_bounding_boxes,
    make_detection_mask,
    make_image,
    make_images,
    make_segmentation_mask,
)
31

32
# Baseline keyword arguments for `make_images` shared by most consistency configs below; individual
# configs override entries via `dict(DEFAULT_MAKE_IMAGES_KWARGS, ...)`.
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)])
33
34


Nicolas Hug's avatar
Nicolas Hug committed
35
36
37
38
39
40
@pytest.fixture(autouse=True)
def fix_rng_seed():
    # Reset the RNG before every test so randomly generated transform parameters are reproducible.
    set_rng_seed(0)
    yield


41
42
43
44
45
46
47
48
49
class NotScriptableArgsKwargs(ArgsKwargs):
    """
    This class is used to mark parameters that render the transform non-scriptable. They still work in eager mode and
    thus will be tested there, but will be skipped by the JIT tests.
    """

    # `test_jit_consistency` filters these out with an `isinstance` check.
    pass


50
51
class ConsistencyConfig:
    """Specification of one v1/v2 transform pair plus everything needed to compare them.

    Attributes mirror the constructor parameters one-to-one; falsy ``make_images_kwargs`` and
    ``closeness_kwargs`` fall back to module-level defaults.
    """

    def __init__(
        self,
        prototype_cls,
        legacy_cls,
        # If no args_kwargs is passed, only the signature will be checked
        args_kwargs=(),
        make_images_kwargs=None,
        supports_pil=True,
        removed_params=(),
        closeness_kwargs=None,
    ):
        self.prototype_cls = prototype_cls
        self.legacy_cls = legacy_cls
        self.args_kwargs = args_kwargs
        self.supports_pil = supports_pil
        self.removed_params = removed_params
        # Same truthiness-based fallback as `x or default`, spelled out as conditionals.
        self.make_images_kwargs = make_images_kwargs if make_images_kwargs else DEFAULT_MAKE_IMAGES_KWARGS
        # Default is exact equality.
        self.closeness_kwargs = closeness_kwargs if closeness_kwargs else dict(rtol=0, atol=0)
69
70


71
72
73
74
# These are here since both the prototype and legacy transform need to be constructed with the same random parameters
# The 36 elements match the H * W * C products used by the LinearTransformation configs below (2 * 6 * 3 == 4 * 3 * 3 == 36).
LINEAR_TRANSFORMATION_MEAN = torch.rand(36)
LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2)

75
76
# One entry per v1/v2 transform pair to check. Entries without any `ArgsKwargs` only get their
# constructor signatures compared (see `test_signature_consistency`), not their call behavior.
CONSISTENCY_CONFIGS = [
    ConsistencyConfig(
        v2_transforms.Normalize,
        legacy_transforms.Normalize,
        [
            ArgsKwargs(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ],
        supports_pil=False,
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.float]),
    ),
    ConsistencyConfig(
        v2_transforms.CenterCrop,
        legacy_transforms.CenterCrop,
        [
            ArgsKwargs(18),
            ArgsKwargs((18, 13)),
        ],
    ),
    ConsistencyConfig(
        v2_transforms.FiveCrop,
        legacy_transforms.FiveCrop,
        [
            ArgsKwargs(18),
            ArgsKwargs((18, 13)),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
    ),
    ConsistencyConfig(
        v2_transforms.TenCrop,
        legacy_transforms.TenCrop,
        [
            ArgsKwargs(18),
            ArgsKwargs((18, 13)),
            ArgsKwargs(18, vertical_flip=True),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
    ),
    *[
        ConsistencyConfig(
            v2_transforms.LinearTransformation,
            legacy_transforms.LinearTransformation,
            [
                ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX.to(matrix_dtype), LINEAR_TRANSFORMATION_MEAN.to(matrix_dtype)),
            ],
            # Make sure that the product of the height, width and number of channels matches the number of elements in
            # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
            make_images_kwargs=dict(
                DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"], dtypes=[image_dtype]
            ),
            supports_pil=False,
        )
        for matrix_dtype, image_dtype in [
            (torch.float32, torch.float32),
            (torch.float64, torch.float64),
            (torch.float32, torch.uint8),
            (torch.float64, torch.float32),
            (torch.float32, torch.float64),
        ]
    ],
    ConsistencyConfig(
        v2_transforms.Grayscale,
        legacy_transforms.Grayscale,
        [
            ArgsKwargs(num_output_channels=1),
            ArgsKwargs(num_output_channels=3),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
        # Use default tolerances of `torch.testing.assert_close`
        closeness_kwargs=dict(rtol=None, atol=None),
    ),
    ConsistencyConfig(
        v2_transforms.ToPILImage,
        legacy_transforms.ToPILImage,
        [NotScriptableArgsKwargs()],
        make_images_kwargs=dict(
            color_spaces=[
                "GRAY",
                "GRAY_ALPHA",
                "RGB",
                "RGBA",
            ],
            extra_dims=[()],
        ),
        supports_pil=False,
    ),
    ConsistencyConfig(
        v2_transforms.Lambda,
        legacy_transforms.Lambda,
        [
            NotScriptableArgsKwargs(lambda image: image / 2),
        ],
        # Technically, this also supports PIL, but it is overkill to write a function here that supports tensor and PIL
        # images given that the transform does nothing but call it anyway.
        supports_pil=False,
    ),
    ConsistencyConfig(
        v2_transforms.RandomEqualize,
        legacy_transforms.RandomEqualize,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]),
    ),
    ConsistencyConfig(
        v2_transforms.RandomInvert,
        legacy_transforms.RandomInvert,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
        ],
    ),
    ConsistencyConfig(
        v2_transforms.RandomPosterize,
        legacy_transforms.RandomPosterize,
        [
            ArgsKwargs(p=0, bits=5),
            ArgsKwargs(p=1, bits=1),
            ArgsKwargs(p=1, bits=3),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]),
    ),
    ConsistencyConfig(
        v2_transforms.RandomSolarize,
        legacy_transforms.RandomSolarize,
        [
            ArgsKwargs(p=0, threshold=0.5),
            ArgsKwargs(p=1, threshold=0.3),
            ArgsKwargs(p=1, threshold=0.99),
        ],
    ),
    *[
        ConsistencyConfig(
            v2_transforms.RandomAutocontrast,
            legacy_transforms.RandomAutocontrast,
            [
                ArgsKwargs(p=0),
                ArgsKwargs(p=1),
            ],
            make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[dt]),
            closeness_kwargs=ckw,
        )
        for dt, ckw in [(torch.uint8, dict(atol=1, rtol=0)), (torch.float32, dict(rtol=None, atol=None))]
    ],
    ConsistencyConfig(
        v2_transforms.RandomAdjustSharpness,
        legacy_transforms.RandomAdjustSharpness,
        [
            ArgsKwargs(p=0, sharpness_factor=0.5),
            ArgsKwargs(p=1, sharpness_factor=0.2),
            ArgsKwargs(p=1, sharpness_factor=0.99),
        ],
        closeness_kwargs={"atol": 1e-6, "rtol": 1e-6},
    ),
    ConsistencyConfig(
        v2_transforms.RandomGrayscale,
        legacy_transforms.RandomGrayscale,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
        # Use default tolerances of `torch.testing.assert_close`
        closeness_kwargs=dict(rtol=None, atol=None),
    ),
    ConsistencyConfig(
        v2_transforms.ColorJitter,
        legacy_transforms.ColorJitter,
        [
            ArgsKwargs(),
            ArgsKwargs(brightness=0.1),
            ArgsKwargs(brightness=(0.2, 0.3)),
            ArgsKwargs(contrast=0.4),
            ArgsKwargs(contrast=(0.5, 0.6)),
            ArgsKwargs(saturation=0.7),
            ArgsKwargs(saturation=(0.8, 0.9)),
            ArgsKwargs(hue=0.3),
            ArgsKwargs(hue=(-0.1, 0.2)),
            ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.5, hue=0.3),
        ],
        closeness_kwargs={"atol": 1e-5, "rtol": 1e-5},
    ),
    ConsistencyConfig(
        v2_transforms.RandomPerspective,
        legacy_transforms.RandomPerspective,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
            ArgsKwargs(p=1, distortion_scale=0.3),
            ArgsKwargs(p=1, distortion_scale=0.2, interpolation=v2_transforms.InterpolationMode.NEAREST),
            ArgsKwargs(p=1, distortion_scale=0.2, interpolation=PIL.Image.NEAREST),
            ArgsKwargs(p=1, distortion_scale=0.1, fill=1),
            ArgsKwargs(p=1, distortion_scale=0.4, fill=(1, 2, 3)),
        ],
        closeness_kwargs={"atol": None, "rtol": None},
    ),
    # The following configs have no `args_kwargs`, i.e. only their constructor signatures are compared.
    ConsistencyConfig(
        v2_transforms.PILToTensor,
        legacy_transforms.PILToTensor,
    ),
    ConsistencyConfig(
        v2_transforms.ToTensor,
        legacy_transforms.ToTensor,
    ),
    ConsistencyConfig(
        v2_transforms.Compose,
        legacy_transforms.Compose,
    ),
    ConsistencyConfig(
        v2_transforms.RandomApply,
        legacy_transforms.RandomApply,
    ),
    ConsistencyConfig(
        v2_transforms.RandomChoice,
        legacy_transforms.RandomChoice,
    ),
    ConsistencyConfig(
        v2_transforms.RandomOrder,
        legacy_transforms.RandomOrder,
    ),
    ConsistencyConfig(
        v2_transforms.AugMix,
        legacy_transforms.AugMix,
    ),
    ConsistencyConfig(
        v2_transforms.AutoAugment,
        legacy_transforms.AutoAugment,
    ),
    ConsistencyConfig(
        v2_transforms.RandAugment,
        legacy_transforms.RandAugment,
    ),
    ConsistencyConfig(
        v2_transforms.TrivialAugmentWide,
        legacy_transforms.TrivialAugmentWide,
    ),
]


314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
@pytest.mark.parametrize("config", CONSISTENCY_CONFIGS, ids=lambda config: config.legacy_cls.__name__)
def test_signature_consistency(config):
    """The v2 transform must accept every legacy parameter, with the same names in the same order."""
    legacy_params = dict(inspect.signature(config.legacy_cls).parameters)
    prototype_params = dict(inspect.signature(config.prototype_cls).parameters)

    # Parameters intentionally dropped from v2 are excluded from the comparison.
    for removed in config.removed_params:
        legacy_params.pop(removed, None)

    missing = legacy_params.keys() - prototype_params.keys()
    if missing:
        raise AssertionError(
            f"The prototype transform does not support the parameters "
            f"{sequence_to_str(sorted(missing), separate_last='and ')}, but the legacy transform does. "
            f"If that is intentional, e.g. pending deprecation, please add the parameters to the `removed_params` on "
            f"the `ConsistencyConfig`."
        )

    # New v2-only parameters are fine, but they must carry a default so legacy call sites keep working.
    extra = prototype_params.keys() - legacy_params.keys()
    var_kinds = {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
    extra_without_default = {
        name
        for name in extra
        if prototype_params[name].default is inspect.Parameter.empty and prototype_params[name].kind not in var_kinds
    }
    if extra_without_default:
        raise AssertionError(
            f"The prototype transform requires the parameters "
            f"{sequence_to_str(sorted(extra_without_default), separate_last='and ')}, but the legacy transform does "
            f"not. Please add a default value."
        )

    legacy_signature = list(legacy_params.keys())
    # Since we made sure that we don't have any extra parameters without default above, we clamp the prototype signature
    # to the same number of parameters as the legacy one
    prototype_signature = list(prototype_params.keys())[: len(legacy_signature)]

    assert prototype_signature == legacy_signature
351
352


353
354
355
def check_call_consistency(
    prototype_transform, legacy_transform, images=None, supports_pil=True, closeness_kwargs=None
):
    """Apply both transforms, under identical RNG seeds, to tensor / tv_tensor / PIL images and compare outputs.

    Legacy-side failures are reported as `pytest.UsageError` (bad test setup), prototype-side
    failures as `AssertionError` (actual consistency bug). `closeness_kwargs` is forwarded to
    `assert_close`.
    """
    if images is None:
        images = make_images(**DEFAULT_MAKE_IMAGES_KWARGS)

    closeness_kwargs = closeness_kwargs or dict()

    for image in images:
        image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]"

        # Strip the tv_tensor subclass so the plain-tensor code path is exercised.
        image_tensor = torch.Tensor(image)
        try:
            # Re-seeding before every call keeps random parameter sampling identical across transforms.
            torch.manual_seed(0)
            output_legacy_tensor = legacy_transform(image_tensor)
        except Exception as exc:
            raise pytest.UsageError(
                f"Transforming a tensor image {image_repr} failed in the legacy transform with the "
                f"error above. This means that you need to specify the parameters passed to `make_images` through the "
                "`make_images_kwargs` of the `ConsistencyConfig`."
            ) from exc

        try:
            torch.manual_seed(0)
            output_prototype_tensor = prototype_transform(image_tensor)
        except Exception as exc:
            raise AssertionError(
                f"Transforming a tensor image with shape {image_repr} failed in the prototype transform with "
                f"the error above. This means there is a consistency bug either in `_get_params` or in the "
                f"`is_pure_tensor` path in `_transform`."
            ) from exc

        assert_close(
            output_prototype_tensor,
            output_legacy_tensor,
            msg=lambda msg: f"Tensor image consistency check failed with: \n\n{msg}",
            **closeness_kwargs,
        )

        # Same transform applied to the tv_tensor subclass must match the plain-tensor output.
        try:
            torch.manual_seed(0)
            output_prototype_image = prototype_transform(image)
        except Exception as exc:
            raise AssertionError(
                f"Transforming a image tv_tensor with shape {image_repr} failed in the prototype transform with "
                f"the error above. This means there is a consistency bug either in `_get_params` or in the "
                f"`tv_tensors.Image` path in `_transform`."
            ) from exc

        assert_close(
            output_prototype_image,
            output_prototype_tensor,
            msg=lambda msg: f"Output for tv_tensor and tensor images is not equal: \n\n{msg}",
            **closeness_kwargs,
        )

        # PIL only supports single images (no batch dims), hence the ndim check.
        if image.ndim == 3 and supports_pil:
            image_pil = to_pil_image(image)

            try:
                torch.manual_seed(0)
                output_legacy_pil = legacy_transform(image_pil)
            except Exception as exc:
                raise pytest.UsageError(
                    f"Transforming a PIL image with shape {image_repr} failed in the legacy transform with the "
                    f"error above. If this transform does not support PIL images, set `supports_pil=False` on the "
                    "`ConsistencyConfig`. "
                ) from exc

            try:
                torch.manual_seed(0)
                output_prototype_pil = prototype_transform(image_pil)
            except Exception as exc:
                raise AssertionError(
                    f"Transforming a PIL image with shape {image_repr} failed in the prototype transform with "
                    f"the error above. This means there is a consistency bug either in `_get_params` or in the "
                    f"`PIL.Image.Image` path in `_transform`."
                ) from exc

            assert_close(
                output_prototype_pil,
                output_legacy_pil,
                msg=lambda msg: f"PIL image consistency check failed with: \n\n{msg}",
                **closeness_kwargs,
            )
438
439


440
@pytest.mark.parametrize(
    ("config", "args_kwargs"),
    [
        pytest.param(
            # The zero-padded index keeps pytest ids aligned, e.g. "ColorJitter-03".
            config, args_kwargs, id=f"{config.legacy_cls.__name__}-{idx:0{len(str(len(config.args_kwargs)))}d}"
        )
        for config in CONSISTENCY_CONFIGS
        for idx, args_kwargs in enumerate(config.args_kwargs)
    ],
)
@pytest.mark.filterwarnings("ignore")
def test_call_consistency(config, args_kwargs):
    """Construct both transforms with identical arguments and compare their call behavior."""
    args, kwargs = args_kwargs

    try:
        legacy_transform = config.legacy_cls(*args, **kwargs)
    except Exception as exc:
        # A failing legacy constructor means the config itself is wrong, not the prototype.
        raise pytest.UsageError(
            f"Initializing the legacy transform failed with the error above. "
            f"Please correct the `ArgsKwargs({args_kwargs})` in the `ConsistencyConfig`."
        ) from exc

    try:
        prototype_transform = config.prototype_cls(*args, **kwargs)
    except Exception as exc:
        raise AssertionError(
            "Initializing the prototype transform failed with the error above. "
            "This means there is a consistency bug in the constructor."
        ) from exc

    check_call_consistency(
        prototype_transform,
        legacy_transform,
        images=make_images(**config.make_images_kwargs),
        supports_pil=config.supports_pil,
        closeness_kwargs=config.closeness_kwargs,
    )


479
480
481
482
483
484
485
486
487
# Shared parametrization for the `get_params` tests below: pairs a transform class (whose config is
# looked up in `CONSISTENCY_CONFIGS`) with the arguments its static `get_params` is called with.
get_params_parametrization = pytest.mark.parametrize(
    ("config", "get_params_args_kwargs"),
    [
        pytest.param(
            next(config for config in CONSISTENCY_CONFIGS if config.prototype_cls is transform_cls),
            get_params_args_kwargs,
            id=transform_cls.__name__,
        )
        for transform_cls, get_params_args_kwargs in [
            (v2_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)),
            (v2_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)),
            (v2_transforms.AutoAugment, ArgsKwargs(5)),
        ]
    ],
)
494
495


496
@get_params_parametrization
def test_get_params_alias(config, get_params_args_kwargs):
    """`get_params` of the v2 transform must be the very same function object as the legacy one."""
    # Class-level check: the static methods must be aliased.
    assert config.prototype_cls.get_params is config.legacy_cls.get_params

    if config.args_kwargs:
        # Repeat the check on instances, built from the first recorded constructor arguments.
        ctor_args, ctor_kwargs = config.args_kwargs[0]
        legacy_instance = config.legacy_cls(*ctor_args, **ctor_kwargs)
        prototype_instance = config.prototype_cls(*ctor_args, **ctor_kwargs)
        assert prototype_instance.get_params is legacy_instance.get_params


509
@get_params_parametrization
def test_get_params_jit(config, get_params_args_kwargs):
    """`get_params` must stay torch.jit-scriptable, both as unbound and as bound method."""
    call_args, call_kwargs = get_params_args_kwargs

    # Scripting the unbound static method must succeed and the scripted version must be callable.
    torch.jit.script(config.prototype_cls.get_params)(*call_args, **call_kwargs)

    if config.args_kwargs:
        # Same check on an instance built from the first recorded constructor arguments.
        ctor_args, ctor_kwargs = config.args_kwargs[0]
        instance = config.prototype_cls(*ctor_args, **ctor_kwargs)
        torch.jit.script(instance.get_params)(*call_args, **call_kwargs)
521
522


523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
@pytest.mark.parametrize(
    ("config", "args_kwargs"),
    [
        pytest.param(
            config, args_kwargs, id=f"{config.legacy_cls.__name__}-{idx:0{len(str(len(config.args_kwargs)))}d}"
        )
        for config in CONSISTENCY_CONFIGS
        for idx, args_kwargs in enumerate(config.args_kwargs)
        # Parameters marked non-scriptable are only exercised in eager mode.
        if not isinstance(args_kwargs, NotScriptableArgsKwargs)
    ],
)
def test_jit_consistency(config, args_kwargs):
    """Scripted v1 and v2 transforms must produce the same outputs under identical seeds."""
    args, kwargs = args_kwargs

    prototype_transform_eager = config.prototype_cls(*args, **kwargs)
    legacy_transform_eager = config.legacy_cls(*args, **kwargs)

    legacy_transform_scripted = torch.jit.script(legacy_transform_eager)
    prototype_transform_scripted = torch.jit.script(prototype_transform_eager)

    for image in make_images(**config.make_images_kwargs):
        # Scripted transforms only accept plain tensors, so drop the tv_tensor subclass.
        image = image.as_subclass(torch.Tensor)

        torch.manual_seed(0)
        output_legacy_scripted = legacy_transform_scripted(image)

        torch.manual_seed(0)
        output_prototype_scripted = prototype_transform_scripted(image)

        assert_close(output_prototype_scripted, output_legacy_scripted, **config.closeness_kwargs)


555
556
557
558
559
560
561
562
563
564
class TestContainerTransforms:
    """
    Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for
    consistency automatically tests the wrapped transforms consistency.

    Instead of complicated mocking or creating custom transforms just for these tests, here we use deterministic ones
    that were already tested for consistency above.
    """

    def test_compose(self):
        prototype_transform = v2_transforms.Compose(
            [
                v2_transforms.Resize(256),
                v2_transforms.CenterCrop(224),
            ]
        )
        legacy_transform = legacy_transforms.Compose(
            [
                legacy_transforms.Resize(256),
                legacy_transforms.CenterCrop(224),
            ]
        )

        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))

    @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
    @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
    def test_random_apply(self, p, sequence_type):
        prototype_transform = v2_transforms.RandomApply(
            sequence_type(
                [
                    v2_transforms.Resize(256),
                    v2_transforms.CenterCrop(224),
                ]
            ),
            p=p,
        )
        legacy_transform = legacy_transforms.RandomApply(
            sequence_type(
                [
                    legacy_transforms.Resize(256),
                    legacy_transforms.CenterCrop(224),
                ]
            ),
            p=p,
        )

        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))

        if sequence_type is nn.ModuleList:
            # quick and dirty test that it is jit-scriptable
            scripted = torch.jit.script(prototype_transform)
            scripted(torch.rand(1, 3, 300, 300))

    # We can't test other values for `p` since the random parameter generation is different
    @pytest.mark.parametrize("probabilities", [(0, 1), (1, 0)])
    def test_random_choice(self, probabilities):
        prototype_transform = v2_transforms.RandomChoice(
            [
                v2_transforms.Resize(256),
                # NOTE(review): this wraps the *legacy* CenterCrop inside the v2 container — presumably checking that
                # v2 containers accept arbitrary callables, but confirm it is not a typo for `v2_transforms.CenterCrop`.
                legacy_transforms.CenterCrop(224),
            ],
            p=probabilities,
        )
        legacy_transform = legacy_transforms.RandomChoice(
            [
                legacy_transforms.Resize(256),
                legacy_transforms.CenterCrop(224),
            ],
            p=probabilities,
        )

        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
631
632


633
634
class TestToTensorTransforms:
    """Consistency checks for the conversion transforms that turn PIL / ndarray inputs into tensors."""

    def test_pil_to_tensor(self):
        v2_t = v2_transforms.PILToTensor()
        v1_t = legacy_transforms.PILToTensor()

        for image in make_images(extra_dims=[()]):
            pil = to_pil_image(image)
            assert_equal(v2_t(pil), v1_t(pil))

    def test_to_tensor(self):
        # Constructing the v2 `ToTensor` must emit the deprecation warning.
        with pytest.warns(UserWarning, match=re.escape("The transform `ToTensor()` is deprecated")):
            v2_t = v2_transforms.ToTensor()
        v1_t = legacy_transforms.ToTensor()

        for image in make_images(extra_dims=[()]):
            pil = to_pil_image(image)
            as_numpy = np.array(pil)

            # Check both input kinds the transform accepts: PIL images and numpy arrays.
            assert_equal(v2_t(pil), v1_t(pil))
            assert_equal(v2_t(as_numpy), v1_t(as_numpy))
654
655


656
def import_transforms_from_references(reference):
    """Load ``references/<reference>/transforms.py`` from the repository root as a module.

    The references folder is not an installable package, so the module is loaded straight
    from its file path under the module name ``"transforms"``.
    """
    project_root = Path(__file__).parent.parent
    transforms_path = project_root / "references" / reference / "transforms.py"

    spec = importlib.util.spec_from_file_location("transforms", str(transforms_path))
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
667
668
669


det_transforms = import_transforms_from_references("detection")
670
671
672


class TestRefDetTransforms:
    """Check v2 transforms against the detection reference implementations in `references/detection`."""

    def make_tv_tensors(self, with_mask=True):
        # Yields (image, target) pairs in the three input flavors the v2 transforms must handle:
        # PIL image, plain tensor, and tv_tensor Image.
        size = (600, 800)
        num_objects = 22

        def make_label(extra_dims, categories):
            return torch.randint(categories, extra_dims, dtype=torch.int64)

        pil_image = to_pil_image(make_image(size=size, color_space="RGB"))
        target = {
            "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
        }
        if with_mask:
            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

        yield (pil_image, target)

        tensor_image = torch.Tensor(make_image(size=size, color_space="RGB", dtype=torch.float32))
        target = {
            "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
        }
        if with_mask:
            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

        yield (tensor_image, target)

        tv_tensor_image = make_image(size=size, color_space="RGB", dtype=torch.float32)
        target = {
            "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
        }
        if with_mask:
            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

        yield (tv_tensor_image, target)

    @pytest.mark.parametrize(
        "t_ref, t, data_kwargs",
        [
            (det_transforms.RandomHorizontalFlip(p=1.0), v2_transforms.RandomHorizontalFlip(p=1.0), {}),
            (
                det_transforms.RandomIoUCrop(),
                v2_transforms.Compose(
                    [
                        v2_transforms.RandomIoUCrop(),
                        v2_transforms.SanitizeBoundingBoxes(labels_getter=lambda sample: sample[1]["labels"]),
                    ]
                ),
                {"with_mask": False},
            ),
            (det_transforms.RandomZoomOut(), v2_transforms.RandomZoomOut(), {"with_mask": False}),
            (det_transforms.ScaleJitter((1024, 1024)), v2_transforms.ScaleJitter((1024, 1024), antialias=True), {}),
            (
                det_transforms.RandomShortestSize(
                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
                ),
                v2_transforms.RandomShortestSize(
                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
                ),
                {},
            ),
        ],
    )
    def test_transform(self, t_ref, t, data_kwargs):
        for dp in self.make_tv_tensors(**data_kwargs):

            # We should use prototype transform first as reference transform performs inplace target update
            torch.manual_seed(12)
            output = t(dp)

            torch.manual_seed(12)
            expected_output = t_ref(*dp)

            assert_equal(expected_output, output)
748
749
750
751
752
753
754
755
756


seg_transforms = import_transforms_from_references("segmentation")


# We need this transform for two reasons:
# 1. transforms.RandomCrop uses a different scheme to pad images and masks of insufficient size than its name
#    counterpart in the detection references. Thus, we cannot use it with `pad_if_needed=True`
# 2. transforms.Pad only supports a fixed padding, but the segmentation datasets don't have a fixed image size.
757
class PadIfSmaller(v2_transforms.Transform):
    """Pad the right/bottom of the input up to a square of ``size`` if it is smaller, mirroring the
    segmentation references' padding scheme (see comment above)."""

    def __init__(self, size, fill=0):
        super().__init__()
        self.size = size
        # Normalize `fill` once, via the same private helper the built-in geometry transforms use.
        self.fill = v2_transforms._geometry._setup_fill_arg(fill)

    def _get_params(self, sample):
        height, width = query_size(sample)
        pad_right = max(self.size - width, 0)
        pad_bottom = max(self.size - height, 0)
        # Padding order is (left, top, right, bottom); only right/bottom may be non-zero.
        padding = [0, 0, pad_right, pad_bottom]
        return dict(padding=padding, needs_padding=any(padding))

    def _transform(self, inpt, params):
        if not params["needs_padding"]:
            return inpt

        fill = _get_fill(self.fill, type(inpt))
        return prototype_F.pad(inpt, padding=params["padding"], fill=fill)
775
776
777


class TestRefSegTransforms:
    """Consistency checks of transforms.v2 against the segmentation reference transforms."""

    def make_tv_tensors(self, supports_pil=True, image_dtype=torch.uint8):
        """Yield ``(v2_input, reference_input)`` pairs of (image, segmentation mask) samples."""
        size = (256, 460)
        num_categories = 21

        conv_fns = []
        if supports_pil:
            conv_fns.append(to_pil_image)
        conv_fns.extend([torch.Tensor, lambda x: x])

        for conv_fn in conv_fns:
            tv_tensor_image = make_image(size=size, color_space="RGB", dtype=image_dtype)
            tv_tensor_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)

            dp = (conv_fn(tv_tensor_image), tv_tensor_mask)
            # The reference transforms expect PIL inputs (or a plain tensor image when PIL is unsupported)
            dp_ref = (
                to_pil_image(tv_tensor_image) if supports_pil else tv_tensor_image.as_subclass(torch.Tensor),
                to_pil_image(tv_tensor_mask),
            )

            yield dp, dp_ref

    def set_seed(self, seed=12):
        # Seed both RNGs: v2 transforms draw from `torch`, the reference transforms from `random`
        torch.manual_seed(seed)
        random.seed(seed)

    def check(self, t, t_ref, data_kwargs=None):
        """Assert that transform ``t`` and reference transform ``t_ref`` produce equal outputs."""
        for dp, dp_ref in self.make_tv_tensors(**data_kwargs or dict()):

            self.set_seed()
            actual = actual_image, actual_mask = t(dp)

            self.set_seed()
            expected_image, expected_mask = t_ref(*dp_ref)
            # Bring the reference outputs into tensor form for comparison
            if isinstance(actual_image, torch.Tensor) and not isinstance(expected_image, torch.Tensor):
                expected_image = legacy_F.pil_to_tensor(expected_image)
            expected_mask = legacy_F.pil_to_tensor(expected_mask).squeeze(0)
            expected = (expected_image, expected_mask)

            assert_equal(actual, expected)

    @pytest.mark.parametrize(
        ("t_ref", "t", "data_kwargs"),
        [
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=1.0),
                v2_transforms.RandomHorizontalFlip(p=1.0),
                dict(),
            ),
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=0.0),
                v2_transforms.RandomHorizontalFlip(p=0.0),
                dict(),
            ),
            (
                seg_transforms.RandomCrop(size=480),
                v2_transforms.Compose(
                    [
                        PadIfSmaller(size=480, fill={tv_tensors.Mask: 255, "others": 0}),
                        v2_transforms.RandomCrop(size=480),
                    ]
                ),
                dict(),
            ),
            (
                seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                v2_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                dict(supports_pil=False, image_dtype=torch.float),
            ),
        ],
    )
    def test_common(self, t_ref, t, data_kwargs):
        self.check(t, t_ref, data_kwargs)

@pytest.mark.parametrize(
    ("legacy_dispatcher", "name_only_params"),
    [
        (legacy_F.get_dimensions, {}),
        (legacy_F.get_image_size, {}),
        (legacy_F.get_image_num_channels, {}),
        (legacy_F.to_tensor, {}),
        (legacy_F.pil_to_tensor, {}),
        (legacy_F.convert_image_dtype, {}),
        (legacy_F.to_pil_image, {}),
        (legacy_F.normalize, {}),
        (legacy_F.resize, {"interpolation"}),
        (legacy_F.pad, {"padding", "fill"}),
        (legacy_F.crop, {}),
        (legacy_F.center_crop, {}),
        (legacy_F.resized_crop, {"interpolation"}),
        (legacy_F.hflip, {}),
        (legacy_F.perspective, {"startpoints", "endpoints", "fill", "interpolation"}),
        (legacy_F.vflip, {}),
        (legacy_F.five_crop, {}),
        (legacy_F.ten_crop, {}),
        (legacy_F.adjust_brightness, {}),
        (legacy_F.adjust_contrast, {}),
        (legacy_F.adjust_saturation, {}),
        (legacy_F.adjust_hue, {}),
        (legacy_F.adjust_gamma, {}),
        (legacy_F.rotate, {"center", "fill", "interpolation"}),
        (legacy_F.affine, {"angle", "translate", "center", "fill", "interpolation"}),
        (legacy_F.to_grayscale, {}),
        (legacy_F.rgb_to_grayscale, {}),
        (legacy_F.to_tensor, {}),
        (legacy_F.erase, {}),
        (legacy_F.gaussian_blur, {}),
        (legacy_F.invert, {}),
        (legacy_F.posterize, {}),
        (legacy_F.solarize, {}),
        (legacy_F.adjust_sharpness, {}),
        (legacy_F.autocontrast, {}),
        (legacy_F.equalize, {}),
        (legacy_F.elastic_transform, {"fill", "interpolation"}),
    ],
)
def test_dispatcher_signature_consistency(legacy_dispatcher, name_only_params):
    """Check that each v2 functional has a backward compatible signature w.r.t. its v1 counterpart.

    ``name_only_params`` lists parameters for which only the name and position are compared, since
    their annotation and/or default intentionally changed between the two versions.
    """
    legacy_signature = inspect.signature(legacy_dispatcher)
    # Skip the first parameter (the image input); only the remaining parameters are compared.
    legacy_params = list(legacy_signature.parameters.values())[1:]

    try:
        prototype_dispatcher = getattr(prototype_F, legacy_dispatcher.__name__)
    except AttributeError:
        raise AssertionError(
            f"Legacy dispatcher `F.{legacy_dispatcher.__name__}` has no prototype equivalent"
        ) from None

    prototype_signature = inspect.signature(prototype_dispatcher)
    prototype_params = list(prototype_signature.parameters.values())[1:]

    # Some dispatchers got extra parameters. This makes sure they have a default argument and thus are BC. We don't
    # need to check if parameters were added in the middle rather than at the end, since that will be caught by the
    # regular check below.
    prototype_params, new_prototype_params = (
        prototype_params[: len(legacy_params)],
        prototype_params[len(legacy_params) :],
    )
    for param in new_prototype_params:
        assert param.default is not param.empty

    # Some annotations were changed mostly to supersets of what was there before. Plus, some legacy dispatchers had no
    # annotations. In these cases we simply drop the annotation and default argument from the comparison
    for prototype_param, legacy_param in zip(prototype_params, legacy_params):
        if legacy_param.name in name_only_params:
            prototype_param._annotation = prototype_param._default = inspect.Parameter.empty
            legacy_param._annotation = legacy_param._default = inspect.Parameter.empty
        elif legacy_param.annotation is inspect.Parameter.empty:
            prototype_param._annotation = inspect.Parameter.empty

    assert prototype_params == legacy_params