test_transforms_v2_consistency.py 49.5 KB
Newer Older
1
import enum
2
3
import importlib.machinery
import importlib.util
4
import inspect
5
import random
6
import re
7
from pathlib import Path
8

9
import numpy as np
10
import PIL.Image
11
import pytest
12
13

import torch
14
import torchvision.transforms.v2 as v2_transforms
15
from common_utils import (
16
    ArgsKwargs,
17
    assert_close,
18
19
20
21
22
    assert_equal,
    make_bounding_box,
    make_detection_mask,
    make_image,
    make_images,
23
    make_segmentation_mask,
Nicolas Hug's avatar
Nicolas Hug committed
24
    set_rng_seed,
25
)
26
from torch import nn
27
from torchvision import datapoints, transforms as legacy_transforms
28
from torchvision._utils import sequence_to_str
29

30
from torchvision.transforms import functional as legacy_F
31
from torchvision.transforms.v2 import functional as prototype_F
32
from torchvision.transforms.v2._utils import _get_fill
33
from torchvision.transforms.v2.functional import to_image_pil
Philip Meier's avatar
Philip Meier committed
34
from torchvision.transforms.v2.utils import query_size
35

36
# Baseline arguments for `make_images` shared by most configs below; individual configs
# override/extend these via `dict(DEFAULT_MAKE_IMAGES_KWARGS, ...)`.
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)])
37
38


Nicolas Hug's avatar
Nicolas Hug committed
39
40
41
42
43
44
@pytest.fixture(autouse=True)
def fix_rng_seed():
    # Reset the global RNG before every test so randomly parametrized transforms
    # draw the same values in the legacy and the v2 run.
    set_rng_seed(0)
    yield


45
46
47
48
49
50
51
52
53
class NotScriptableArgsKwargs(ArgsKwargs):
    """
    This class is used to mark parameters that render the transform non-scriptable. They still work in eager mode and
    thus will be tested there, but will be skipped by the JIT tests.
    """

    pass


54
55
class ConsistencyConfig:
    """Pairs a v2 (prototype) transform class with its legacy counterpart for consistency testing.

    If ``args_kwargs`` is empty, only the constructor signatures are compared; otherwise each
    entry is used to instantiate both transforms and compare their outputs.
    """

    def __init__(
        self,
        prototype_cls,
        legacy_cls,
        # If no args_kwargs is passed, only the signature will be checked
        args_kwargs=(),
        make_images_kwargs=None,
        supports_pil=True,
        removed_params=(),
        closeness_kwargs=None,
    ):
        self.prototype_cls = prototype_cls
        self.legacy_cls = legacy_cls
        self.args_kwargs = args_kwargs
        self.supports_pil = supports_pil
        self.removed_params = removed_params
        # `or` (rather than an `is None` check) intentionally also replaces falsy values,
        # matching how the defaults have always been applied.
        self.make_images_kwargs = make_images_kwargs or DEFAULT_MAKE_IMAGES_KWARGS
        self.closeness_kwargs = closeness_kwargs or dict(rtol=0, atol=0)
73
74


75
76
77
78
# These are here since both the prototype and legacy transform need to be constructed with the same random parameters
# NOTE: 36 matches the flattened image sizes used by the LinearTransformation configs below (e.g. 2 * 6 * 3 == 36).
LINEAR_TRANSFORMATION_MEAN = torch.rand(36)
LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2)

79
80
# One config per legacy transform. Configs without `args_kwargs` (at the bottom) are only
# checked for signature consistency, not for call consistency.
CONSISTENCY_CONFIGS = [
    ConsistencyConfig(
        v2_transforms.Normalize,
        legacy_transforms.Normalize,
        [
            ArgsKwargs(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ],
        supports_pil=False,
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.float]),
    ),
    ConsistencyConfig(
        v2_transforms.Resize,
        legacy_transforms.Resize,
        [
            NotScriptableArgsKwargs(32),
            ArgsKwargs([32]),
            ArgsKwargs((32, 29)),
            ArgsKwargs((31, 28), interpolation=v2_transforms.InterpolationMode.NEAREST),
            ArgsKwargs((30, 27), interpolation=PIL.Image.NEAREST),
            ArgsKwargs((35, 29), interpolation=PIL.Image.BILINEAR),
            NotScriptableArgsKwargs(31, max_size=32),
            ArgsKwargs([31], max_size=32),
            NotScriptableArgsKwargs(30, max_size=100),
            ArgsKwargs([31], max_size=32),
            ArgsKwargs((29, 32), antialias=False),
            ArgsKwargs((28, 31), antialias=True),
        ],
        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
        closeness_kwargs=dict(rtol=0, atol=1),
    ),
    # BICUBIC is split into its own config because it needs a much larger tolerance.
    ConsistencyConfig(
        v2_transforms.Resize,
        legacy_transforms.Resize,
        [
            ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC, antialias=True),
            ArgsKwargs((34, 25), interpolation=PIL.Image.BICUBIC, antialias=True),
        ],
        closeness_kwargs=dict(rtol=0, atol=21),
    ),
    ConsistencyConfig(
        v2_transforms.CenterCrop,
        legacy_transforms.CenterCrop,
        [
            ArgsKwargs(18),
            ArgsKwargs((18, 13)),
        ],
    ),
    ConsistencyConfig(
        v2_transforms.FiveCrop,
        legacy_transforms.FiveCrop,
        [
            ArgsKwargs(18),
            ArgsKwargs((18, 13)),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
    ),
    ConsistencyConfig(
        v2_transforms.TenCrop,
        legacy_transforms.TenCrop,
        [
            ArgsKwargs(18),
            ArgsKwargs((18, 13)),
            ArgsKwargs(18, vertical_flip=True),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
    ),
    ConsistencyConfig(
        v2_transforms.Pad,
        legacy_transforms.Pad,
        [
            NotScriptableArgsKwargs(3),
            ArgsKwargs([3]),
            ArgsKwargs([2, 3]),
            ArgsKwargs([3, 2, 1, 4]),
            NotScriptableArgsKwargs(5, fill=1, padding_mode="constant"),
            ArgsKwargs([5], fill=1, padding_mode="constant"),
            NotScriptableArgsKwargs(5, padding_mode="edge"),
            NotScriptableArgsKwargs(5, padding_mode="reflect"),
            NotScriptableArgsKwargs(5, padding_mode="symmetric"),
        ],
    ),
    *[
        ConsistencyConfig(
            v2_transforms.LinearTransformation,
            legacy_transforms.LinearTransformation,
            [
                ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX.to(matrix_dtype), LINEAR_TRANSFORMATION_MEAN.to(matrix_dtype)),
            ],
            # Make sure that the product of the height, width and number of channels matches the number of elements in
            # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
            make_images_kwargs=dict(
                DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"], dtypes=[image_dtype]
            ),
            supports_pil=False,
        )
        for matrix_dtype, image_dtype in [
            (torch.float32, torch.float32),
            (torch.float64, torch.float64),
            (torch.float32, torch.uint8),
            (torch.float64, torch.float32),
            (torch.float32, torch.float64),
        ]
    ],
    ConsistencyConfig(
        v2_transforms.Grayscale,
        legacy_transforms.Grayscale,
        [
            ArgsKwargs(num_output_channels=1),
            ArgsKwargs(num_output_channels=3),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
        # Use default tolerances of `torch.testing.assert_close`
        closeness_kwargs=dict(rtol=None, atol=None),
    ),
    ConsistencyConfig(
        v2_transforms.ConvertImageDtype,
        legacy_transforms.ConvertImageDtype,
        [
            ArgsKwargs(torch.float16),
            ArgsKwargs(torch.bfloat16),
            ArgsKwargs(torch.float32),
            ArgsKwargs(torch.float64),
            ArgsKwargs(torch.uint8),
        ],
        supports_pil=False,
        # Use default tolerances of `torch.testing.assert_close`
        closeness_kwargs=dict(rtol=None, atol=None),
    ),
    ConsistencyConfig(
        v2_transforms.ToPILImage,
        legacy_transforms.ToPILImage,
        [NotScriptableArgsKwargs()],
        make_images_kwargs=dict(
            color_spaces=[
                "GRAY",
                "GRAY_ALPHA",
                "RGB",
                "RGBA",
            ],
            extra_dims=[()],
        ),
        supports_pil=False,
    ),
    ConsistencyConfig(
        v2_transforms.Lambda,
        legacy_transforms.Lambda,
        [
            NotScriptableArgsKwargs(lambda image: image / 2),
        ],
        # Technically, this also supports PIL, but it is overkill to write a function here that supports tensor and PIL
        # images given that the transform does nothing but call it anyway.
        supports_pil=False,
    ),
    ConsistencyConfig(
        v2_transforms.RandomHorizontalFlip,
        legacy_transforms.RandomHorizontalFlip,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
        ],
    ),
    ConsistencyConfig(
        v2_transforms.RandomVerticalFlip,
        legacy_transforms.RandomVerticalFlip,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
        ],
    ),
    ConsistencyConfig(
        v2_transforms.RandomEqualize,
        legacy_transforms.RandomEqualize,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]),
    ),
    ConsistencyConfig(
        v2_transforms.RandomInvert,
        legacy_transforms.RandomInvert,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
        ],
    ),
    ConsistencyConfig(
        v2_transforms.RandomPosterize,
        legacy_transforms.RandomPosterize,
        [
            ArgsKwargs(p=0, bits=5),
            ArgsKwargs(p=1, bits=1),
            ArgsKwargs(p=1, bits=3),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]),
    ),
    ConsistencyConfig(
        v2_transforms.RandomSolarize,
        legacy_transforms.RandomSolarize,
        [
            ArgsKwargs(p=0, threshold=0.5),
            ArgsKwargs(p=1, threshold=0.3),
            ArgsKwargs(p=1, threshold=0.99),
        ],
    ),
    *[
        ConsistencyConfig(
            v2_transforms.RandomAutocontrast,
            legacy_transforms.RandomAutocontrast,
            [
                ArgsKwargs(p=0),
                ArgsKwargs(p=1),
            ],
            make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[dt]),
            closeness_kwargs=ckw,
        )
        for dt, ckw in [(torch.uint8, dict(atol=1, rtol=0)), (torch.float32, dict(rtol=None, atol=None))]
    ],
    ConsistencyConfig(
        v2_transforms.RandomAdjustSharpness,
        legacy_transforms.RandomAdjustSharpness,
        [
            ArgsKwargs(p=0, sharpness_factor=0.5),
            ArgsKwargs(p=1, sharpness_factor=0.2),
            ArgsKwargs(p=1, sharpness_factor=0.99),
        ],
        closeness_kwargs={"atol": 1e-6, "rtol": 1e-6},
    ),
    ConsistencyConfig(
        v2_transforms.RandomGrayscale,
        legacy_transforms.RandomGrayscale,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
        # Use default tolerances of `torch.testing.assert_close`
        closeness_kwargs=dict(rtol=None, atol=None),
    ),
    ConsistencyConfig(
        v2_transforms.RandomResizedCrop,
        legacy_transforms.RandomResizedCrop,
        [
            ArgsKwargs(16),
            ArgsKwargs(17, scale=(0.3, 0.7)),
            ArgsKwargs(25, ratio=(0.5, 1.5)),
            ArgsKwargs((31, 28), interpolation=v2_transforms.InterpolationMode.NEAREST),
            ArgsKwargs((31, 28), interpolation=PIL.Image.NEAREST),
            ArgsKwargs((29, 32), antialias=False),
            ArgsKwargs((28, 31), antialias=True),
        ],
        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
        closeness_kwargs=dict(rtol=0, atol=1),
    ),
    # BICUBIC is split into its own config because it needs a much larger tolerance.
    ConsistencyConfig(
        v2_transforms.RandomResizedCrop,
        legacy_transforms.RandomResizedCrop,
        [
            ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC, antialias=True),
            ArgsKwargs((33, 26), interpolation=PIL.Image.BICUBIC, antialias=True),
        ],
        closeness_kwargs=dict(rtol=0, atol=21),
    ),
    ConsistencyConfig(
        v2_transforms.RandomErasing,
        legacy_transforms.RandomErasing,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
            ArgsKwargs(p=1, scale=(0.3, 0.7)),
            ArgsKwargs(p=1, ratio=(0.5, 1.5)),
            ArgsKwargs(p=1, value=1),
            ArgsKwargs(p=1, value=(1, 2, 3)),
            ArgsKwargs(p=1, value="random"),
        ],
        supports_pil=False,
    ),
    ConsistencyConfig(
        v2_transforms.ColorJitter,
        legacy_transforms.ColorJitter,
        [
            ArgsKwargs(),
            ArgsKwargs(brightness=0.1),
            ArgsKwargs(brightness=(0.2, 0.3)),
            ArgsKwargs(contrast=0.4),
            ArgsKwargs(contrast=(0.5, 0.6)),
            ArgsKwargs(saturation=0.7),
            ArgsKwargs(saturation=(0.8, 0.9)),
            ArgsKwargs(hue=0.3),
            ArgsKwargs(hue=(-0.1, 0.2)),
            ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.5, hue=0.3),
        ],
        closeness_kwargs={"atol": 1e-5, "rtol": 1e-5},
    ),
    *[
        ConsistencyConfig(
            v2_transforms.ElasticTransform,
            legacy_transforms.ElasticTransform,
            [
                ArgsKwargs(),
                ArgsKwargs(alpha=20.0),
                ArgsKwargs(alpha=(15.3, 27.2)),
                ArgsKwargs(sigma=3.0),
                ArgsKwargs(sigma=(2.5, 3.9)),
                ArgsKwargs(interpolation=v2_transforms.InterpolationMode.NEAREST),
                ArgsKwargs(interpolation=v2_transforms.InterpolationMode.BICUBIC),
                ArgsKwargs(interpolation=PIL.Image.NEAREST),
                ArgsKwargs(interpolation=PIL.Image.BICUBIC),
                ArgsKwargs(fill=1),
            ],
            # ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
            make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(163, 163), (72, 333), (313, 95)], dtypes=[dt]),
            # We updated gaussian blur kernel generation with a faster and numerically more stable version
            # This brings float32 accumulation visible in elastic transform -> we need to relax consistency tolerance
            closeness_kwargs=ckw,
        )
        for dt, ckw in [(torch.uint8, {"rtol": 1e-1, "atol": 1}), (torch.float32, {"rtol": 1e-2, "atol": 1e-3})]
    ],
    ConsistencyConfig(
        v2_transforms.GaussianBlur,
        legacy_transforms.GaussianBlur,
        [
            ArgsKwargs(kernel_size=3),
            ArgsKwargs(kernel_size=(1, 5)),
            ArgsKwargs(kernel_size=3, sigma=0.7),
            ArgsKwargs(kernel_size=5, sigma=(0.3, 1.4)),
        ],
        closeness_kwargs={"rtol": 1e-5, "atol": 1e-5},
    ),
    ConsistencyConfig(
        v2_transforms.RandomAffine,
        legacy_transforms.RandomAffine,
        [
            ArgsKwargs(degrees=30.0),
            ArgsKwargs(degrees=(-20.0, 10.0)),
            ArgsKwargs(degrees=0.0, translate=(0.4, 0.6)),
            ArgsKwargs(degrees=0.0, scale=(0.3, 0.8)),
            ArgsKwargs(degrees=0.0, shear=13),
            ArgsKwargs(degrees=0.0, shear=(8, 17)),
            ArgsKwargs(degrees=0.0, shear=(4, 5, 4, 13)),
            ArgsKwargs(degrees=(-20.0, 10.0), translate=(0.4, 0.6), scale=(0.3, 0.8), shear=(4, 5, 4, 13)),
            ArgsKwargs(degrees=30.0, interpolation=v2_transforms.InterpolationMode.NEAREST),
            ArgsKwargs(degrees=30.0, interpolation=PIL.Image.NEAREST),
            ArgsKwargs(degrees=30.0, fill=1),
            ArgsKwargs(degrees=30.0, fill=(2, 3, 4)),
            ArgsKwargs(degrees=30.0, center=(0, 0)),
        ],
        removed_params=["fillcolor", "resample"],
    ),
    ConsistencyConfig(
        v2_transforms.RandomCrop,
        legacy_transforms.RandomCrop,
        [
            ArgsKwargs(12),
            ArgsKwargs((15, 17)),
            NotScriptableArgsKwargs(11, padding=1),
            ArgsKwargs(11, padding=[1]),
            ArgsKwargs((8, 13), padding=(2, 3)),
            ArgsKwargs((14, 9), padding=(0, 2, 1, 0)),
            ArgsKwargs(36, pad_if_needed=True),
            ArgsKwargs((7, 8), fill=1),
            NotScriptableArgsKwargs(5, fill=(1, 2, 3)),
            ArgsKwargs(12),
            NotScriptableArgsKwargs(15, padding=2, padding_mode="edge"),
            ArgsKwargs(17, padding=(1, 0), padding_mode="reflect"),
            ArgsKwargs(8, padding=(3, 0, 0, 1), padding_mode="symmetric"),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(26, 26), (18, 33), (29, 22)]),
    ),
    ConsistencyConfig(
        v2_transforms.RandomPerspective,
        legacy_transforms.RandomPerspective,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
            ArgsKwargs(p=1, distortion_scale=0.3),
            ArgsKwargs(p=1, distortion_scale=0.2, interpolation=v2_transforms.InterpolationMode.NEAREST),
            ArgsKwargs(p=1, distortion_scale=0.2, interpolation=PIL.Image.NEAREST),
            ArgsKwargs(p=1, distortion_scale=0.1, fill=1),
            ArgsKwargs(p=1, distortion_scale=0.4, fill=(1, 2, 3)),
        ],
        closeness_kwargs={"atol": None, "rtol": None},
    ),
    ConsistencyConfig(
        v2_transforms.RandomRotation,
        legacy_transforms.RandomRotation,
        [
            ArgsKwargs(degrees=30.0),
            ArgsKwargs(degrees=(-20.0, 10.0)),
            ArgsKwargs(degrees=30.0, interpolation=v2_transforms.InterpolationMode.BILINEAR),
            ArgsKwargs(degrees=30.0, interpolation=PIL.Image.BILINEAR),
            ArgsKwargs(degrees=30.0, expand=True),
            ArgsKwargs(degrees=30.0, center=(0, 0)),
            ArgsKwargs(degrees=30.0, fill=1),
            ArgsKwargs(degrees=30.0, fill=(1, 2, 3)),
        ],
        removed_params=["resample"],
    ),
    # The configs below have no `args_kwargs` and thus only get the signature consistency check.
    ConsistencyConfig(
        v2_transforms.PILToTensor,
        legacy_transforms.PILToTensor,
    ),
    ConsistencyConfig(
        v2_transforms.ToTensor,
        legacy_transforms.ToTensor,
    ),
    ConsistencyConfig(
        v2_transforms.Compose,
        legacy_transforms.Compose,
    ),
    ConsistencyConfig(
        v2_transforms.RandomApply,
        legacy_transforms.RandomApply,
    ),
    ConsistencyConfig(
        v2_transforms.RandomChoice,
        legacy_transforms.RandomChoice,
    ),
    ConsistencyConfig(
        v2_transforms.RandomOrder,
        legacy_transforms.RandomOrder,
    ),
    ConsistencyConfig(
        v2_transforms.AugMix,
        legacy_transforms.AugMix,
    ),
    ConsistencyConfig(
        v2_transforms.AutoAugment,
        legacy_transforms.AutoAugment,
    ),
    ConsistencyConfig(
        v2_transforms.RandAugment,
        legacy_transforms.RandAugment,
    ),
    ConsistencyConfig(
        v2_transforms.TrivialAugmentWide,
        legacy_transforms.TrivialAugmentWide,
    ),
]


520
521
def test_automatic_coverage():
    """Fail if a public legacy transform class has no entry in `CONSISTENCY_CONFIGS`."""
    available = set()
    for name, obj in legacy_transforms.__dict__.items():
        # Only public classes count; enums (e.g. InterpolationMode) are not transforms.
        if name.startswith("_") or not isinstance(obj, type) or issubclass(obj, enum.Enum):
            continue
        available.add(name)

    checked = {config.legacy_cls.__name__ for config in CONSISTENCY_CONFIGS}

    missing = available - checked
    if missing:
        raise AssertionError(
            f"The prototype transformations {sequence_to_str(sorted(missing), separate_last='and ')} "
            f"are not checked for consistency although a legacy counterpart exists."
        )


537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
@pytest.mark.parametrize("config", CONSISTENCY_CONFIGS, ids=lambda config: config.legacy_cls.__name__)
def test_signature_consistency(config):
    """The v2 constructor must accept at least the legacy parameters, in the same order."""
    legacy_params = dict(inspect.signature(config.legacy_cls).parameters)
    prototype_params = dict(inspect.signature(config.prototype_cls).parameters)

    # Parameters deliberately dropped in v2 are exempt from the check.
    for removed in config.removed_params:
        legacy_params.pop(removed, None)

    missing = legacy_params.keys() - prototype_params.keys()
    if missing:
        raise AssertionError(
            f"The prototype transform does not support the parameters "
            f"{sequence_to_str(sorted(missing), separate_last='and ')}, but the legacy transform does. "
            f"If that is intentional, e.g. pending deprecation, please add the parameters to the `removed_params` on "
            f"the `ConsistencyConfig`."
        )

    # Extra v2 parameters are fine only if they are optional, i.e. have a default or are variadic.
    extra = prototype_params.keys() - legacy_params.keys()
    variadic_kinds = {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
    extra_without_default = set()
    for name in extra:
        param = prototype_params[name]
        if param.default is inspect.Parameter.empty and param.kind not in variadic_kinds:
            extra_without_default.add(name)
    if extra_without_default:
        raise AssertionError(
            f"The prototype transform requires the parameters "
            f"{sequence_to_str(sorted(extra_without_default), separate_last='and ')}, but the legacy transform does "
            f"not. Please add a default value."
        )

    # All extra parameters have defaults (checked above), so clamping the prototype signature to the length of the
    # legacy one is enough to verify that the shared prefix is ordered identically.
    legacy_signature = list(legacy_params.keys())
    prototype_signature = list(prototype_params.keys())[: len(legacy_signature)]

    assert prototype_signature == legacy_signature
574
575


576
577
578
def check_call_consistency(
    prototype_transform, legacy_transform, images=None, supports_pil=True, closeness_kwargs=None
):
    """Apply both transforms to the same images and assert the outputs match.

    For every image, three input flavors are checked: a plain tensor (compared against the
    legacy output), an image datapoint (compared against the plain-tensor output), and — for
    3-dim images when ``supports_pil`` is true — a PIL image (compared against the legacy PIL
    output). `torch.manual_seed(0)` is called before every single transform invocation so that
    randomly parametrized transforms draw identical values in each run.

    ``closeness_kwargs`` is forwarded to `assert_close` (empty dict by default).
    """
    if images is None:
        images = make_images(**DEFAULT_MAKE_IMAGES_KWARGS)

    closeness_kwargs = closeness_kwargs or dict()

    for image in images:
        # Human-readable identifier used in the error messages below.
        image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]"

        # Re-wrap as a plain tensor so the datapoint subclass does not influence dispatch.
        image_tensor = torch.Tensor(image)
        try:
            torch.manual_seed(0)
            output_legacy_tensor = legacy_transform(image_tensor)
        except Exception as exc:
            # A legacy failure is a mistake in the test setup, not a consistency bug.
            raise pytest.UsageError(
                f"Transforming a tensor image {image_repr} failed in the legacy transform with the "
                f"error above. This means that you need to specify the parameters passed to `make_images` through the "
                "`make_images_kwargs` of the `ConsistencyConfig`."
            ) from exc

        try:
            torch.manual_seed(0)
            output_prototype_tensor = prototype_transform(image_tensor)
        except Exception as exc:
            raise AssertionError(
                f"Transforming a tensor image with shape {image_repr} failed in the prototype transform with "
                f"the error above. This means there is a consistency bug either in `_get_params` or in the "
                f"`is_simple_tensor` path in `_transform`."
            ) from exc

        assert_close(
            output_prototype_tensor,
            output_legacy_tensor,
            msg=lambda msg: f"Tensor image consistency check failed with: \n\n{msg}",
            **closeness_kwargs,
        )

        try:
            torch.manual_seed(0)
            output_prototype_image = prototype_transform(image)
        except Exception as exc:
            raise AssertionError(
                f"Transforming a image datapoint with shape {image_repr} failed in the prototype transform with "
                f"the error above. This means there is a consistency bug either in `_get_params` or in the "
                f"`datapoints.Image` path in `_transform`."
            ) from exc

        # The datapoint path must agree with the plain-tensor path of the same transform.
        assert_close(
            output_prototype_image,
            output_prototype_tensor,
            msg=lambda msg: f"Output for datapoint and tensor images is not equal: \n\n{msg}",
            **closeness_kwargs,
        )

        # PIL only supports single images, i.e. no batch/extra dims.
        if image.ndim == 3 and supports_pil:
            image_pil = to_image_pil(image)

            try:
                torch.manual_seed(0)
                output_legacy_pil = legacy_transform(image_pil)
            except Exception as exc:
                raise pytest.UsageError(
                    f"Transforming a PIL image with shape {image_repr} failed in the legacy transform with the "
                    f"error above. If this transform does not support PIL images, set `supports_pil=False` on the "
                    "`ConsistencyConfig`. "
                ) from exc

            try:
                torch.manual_seed(0)
                output_prototype_pil = prototype_transform(image_pil)
            except Exception as exc:
                raise AssertionError(
                    f"Transforming a PIL image with shape {image_repr} failed in the prototype transform with "
                    f"the error above. This means there is a consistency bug either in `_get_params` or in the "
                    f"`PIL.Image.Image` path in `_transform`."
                ) from exc

            assert_close(
                output_prototype_pil,
                output_legacy_pil,
                msg=lambda msg: f"PIL image consistency check failed with: \n\n{msg}",
                **closeness_kwargs,
            )
661
662


663
@pytest.mark.parametrize(
    ("config", "args_kwargs"),
    [
        pytest.param(
            config, args_kwargs, id=f"{config.legacy_cls.__name__}-{idx:0{len(str(len(config.args_kwargs)))}d}"
        )
        for config in CONSISTENCY_CONFIGS
        for idx, args_kwargs in enumerate(config.args_kwargs)
    ],
)
@pytest.mark.filterwarnings("ignore")
def test_call_consistency(config, args_kwargs):
    """Instantiate both transforms with the same arguments and compare their outputs."""
    args, kwargs = args_kwargs

    try:
        legacy_transform = config.legacy_cls(*args, **kwargs)
    except Exception as exc:
        # A legacy constructor failure means the config itself is broken, not the prototype.
        raise pytest.UsageError(
            f"Initializing the legacy transform failed with the error above. "
            f"Please correct the `ArgsKwargs({args_kwargs})` in the `ConsistencyConfig`."
        ) from exc

    try:
        prototype_transform = config.prototype_cls(*args, **kwargs)
    except Exception as exc:
        raise AssertionError(
            "Initializing the prototype transform failed with the error above. "
            "This means there is a consistency bug in the constructor."
        ) from exc

    check_call_consistency(
        prototype_transform,
        legacy_transform,
        images=make_images(**config.make_images_kwargs),
        supports_pil=config.supports_pil,
        closeness_kwargs=config.closeness_kwargs,
    )


702
703
704
705
706
707
708
709
710
# Shared parametrization for the `get_params` tests below: pairs each listed v2 transform class
# (looked up in CONSISTENCY_CONFIGS via `next`) with concrete arguments for its static
# `get_params` method.
get_params_parametrization = pytest.mark.parametrize(
    ("config", "get_params_args_kwargs"),
    [
        pytest.param(
            next(config for config in CONSISTENCY_CONFIGS if config.prototype_cls is transform_cls),
            get_params_args_kwargs,
            id=transform_cls.__name__,
        )
        for transform_cls, get_params_args_kwargs in [
            (v2_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])),
            (v2_transforms.RandomErasing, ArgsKwargs(make_image(), scale=(0.3, 0.7), ratio=(0.5, 1.5))),
            (v2_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)),
            (v2_transforms.ElasticTransform, ArgsKwargs(alpha=[15.3, 27.2], sigma=[2.5, 3.9], size=[17, 31])),
            (v2_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)),
            (
                v2_transforms.RandomAffine,
                ArgsKwargs(degrees=[-20.0, 10.0], translate=None, scale_ranges=None, shears=None, img_size=[15, 29]),
            ),
            (v2_transforms.RandomCrop, ArgsKwargs(make_image(size=(61, 47)), output_size=(19, 25))),
            (v2_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)),
            (v2_transforms.RandomRotation, ArgsKwargs(degrees=[-20.0, 10.0])),
            (v2_transforms.AutoAugment, ArgsKwargs(5)),
        ]
    ],
)
727
728


729
@get_params_parametrization
def test_get_params_alias(config, get_params_args_kwargs):
    """`get_params` of a v2 transform must be the very same object as the legacy one."""
    prototype_cls, legacy_cls = config.prototype_cls, config.legacy_cls

    # Class-level alias.
    assert prototype_cls.get_params is legacy_cls.get_params

    if config.args_kwargs:
        # Instance-level alias, using the first set of constructor arguments of the config.
        ctor_args, ctor_kwargs = config.args_kwargs[0]
        legacy_instance = legacy_cls(*ctor_args, **ctor_kwargs)
        prototype_instance = prototype_cls(*ctor_args, **ctor_kwargs)
        assert prototype_instance.get_params is legacy_instance.get_params


742
@get_params_parametrization
def test_get_params_jit(config, get_params_args_kwargs):
    """`get_params` must be scriptable both as a class attribute and on an instance."""
    gp_args, gp_kwargs = get_params_args_kwargs

    scripted = torch.jit.script(config.prototype_cls.get_params)
    scripted(*gp_args, **gp_kwargs)

    if config.args_kwargs:
        ctor_args, ctor_kwargs = config.args_kwargs[0]
        instance = config.prototype_cls(*ctor_args, **ctor_kwargs)
        torch.jit.script(instance.get_params)(*gp_args, **gp_kwargs)
754
755


756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
@pytest.mark.parametrize(
    ("config", "args_kwargs"),
    [
        pytest.param(
            config, args_kwargs, id=f"{config.legacy_cls.__name__}-{idx:0{len(str(len(config.args_kwargs)))}d}"
        )
        for config in CONSISTENCY_CONFIGS
        for idx, args_kwargs in enumerate(config.args_kwargs)
        if not isinstance(args_kwargs, NotScriptableArgsKwargs)
    ],
)
def test_jit_consistency(config, args_kwargs):
    """Scripted v2 and scripted legacy transforms must produce matching outputs."""
    args, kwargs = args_kwargs

    # Construction order (v2 first) and scripting order (legacy first) are kept as-is so that
    # any RNG consumption lines up with the original test.
    v2_eager = config.prototype_cls(*args, **kwargs)
    v1_eager = config.legacy_cls(*args, **kwargs)

    v1_scripted = torch.jit.script(v1_eager)
    v2_scripted = torch.jit.script(v2_eager)

    for image in make_images(**config.make_images_kwargs):
        plain_tensor = image.as_subclass(torch.Tensor)

        torch.manual_seed(0)
        expected = v1_scripted(plain_tensor)

        torch.manual_seed(0)
        actual = v2_scripted(plain_tensor)

        assert_close(actual, expected, **config.closeness_kwargs)


788
789
790
791
792
793
794
795
796
797
class TestContainerTransforms:
    """
    Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for
    consistency automatically tests the wrapped transforms consistency.

    Instead of complicated mocking or creating custom transforms just for these tests, here we use deterministic ones
    that were already tested for consistency above.
    """

    def test_compose(self):
        """v2 Compose wrapping v2 transforms must match v1 Compose wrapping v1 transforms."""
        prototype_transform = v2_transforms.Compose(
            [
                v2_transforms.Resize(256),
                v2_transforms.CenterCrop(224),
            ]
        )
        legacy_transform = legacy_transforms.Compose(
            [
                legacy_transforms.Resize(256),
                legacy_transforms.CenterCrop(224),
            ]
        )

        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))

    @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
    @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
    def test_random_apply(self, p, sequence_type):
        """RandomApply consistency for both plain lists and nn.ModuleList wrappers."""
        prototype_transform = v2_transforms.RandomApply(
            sequence_type(
                [
                    v2_transforms.Resize(256),
                    v2_transforms.CenterCrop(224),
                ]
            ),
            p=p,
        )
        legacy_transform = legacy_transforms.RandomApply(
            sequence_type(
                [
                    legacy_transforms.Resize(256),
                    legacy_transforms.CenterCrop(224),
                ]
            ),
            p=p,
        )

        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))

        if sequence_type is nn.ModuleList:
            # quick and dirty test that it is jit-scriptable
            scripted = torch.jit.script(prototype_transform)
            scripted(torch.rand(1, 3, 300, 300))

    # We can't test other values for `p` since the random parameter generation is different
    @pytest.mark.parametrize("probabilities", [(0, 1), (1, 0)])
    def test_random_choice(self, probabilities):
        """RandomChoice consistency with degenerate probability vectors (deterministic choice)."""
        prototype_transform = v2_transforms.RandomChoice(
            [
                v2_transforms.Resize(256),
                # BUGFIX: this previously wrapped `legacy_transforms.CenterCrop`, so the prototype
                # container was partially exercising the legacy implementation and the comparison
                # below was weaker than intended. The prototype container should wrap v2 transforms.
                v2_transforms.CenterCrop(224),
            ],
            p=probabilities,
        )
        legacy_transform = legacy_transforms.RandomChoice(
            [
                legacy_transforms.Resize(256),
                legacy_transforms.CenterCrop(224),
            ],
            p=probabilities,
        )

        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
864
865


866
867
class TestToTensorTransforms:
    """Consistency checks for the tensor-conversion transforms (PILToTensor, ToTensor)."""

    def test_pil_to_tensor(self):
        new_t = v2_transforms.PILToTensor()
        old_t = legacy_transforms.PILToTensor()

        for img in make_images(extra_dims=[()]):
            pil_img = to_image_pil(img)
            assert_equal(new_t(pil_img), old_t(pil_img))

    def test_to_tensor(self):
        # Constructing the v2 ToTensor must emit a deprecation warning.
        with pytest.warns(UserWarning, match=re.escape("The transform `ToTensor()` is deprecated")):
            new_t = v2_transforms.ToTensor()
        old_t = legacy_transforms.ToTensor()

        for img in make_images(extra_dims=[()]):
            pil_img = to_image_pil(img)
            np_img = np.array(pil_img)

            # Both PIL and numpy inputs have to round-trip identically.
            assert_equal(new_t(pil_img), old_t(pil_img))
            assert_equal(new_t(np_img), old_t(np_img))
887
888
889
890
891
892
893
894


class TestAATransforms:
    """Consistency tests for the AutoAugment family (RandAugment, TrivialAugmentWide, AugMix,
    AutoAugment).

    The v1 ("stable") and v2 ("new") implementations draw their random parameters in a different
    order, so most tests below mock ``torch.randint`` / ``torch.rand`` with a pre-computed value
    sequence that makes both APIs pick the same op with the same magnitude on every call.
    """

    @pytest.mark.parametrize(
        "inpt",
        [
            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
            PIL.Image.new("RGB", (256, 256), 123),
            datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
        ],
    )
    @pytest.mark.parametrize(
        "interpolation",
        [
            v2_transforms.InterpolationMode.NEAREST,
            v2_transforms.InterpolationMode.BILINEAR,
            PIL.Image.NEAREST,
        ],
    )
    def test_randaug(self, inpt, interpolation, mocker):
        t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1)
        t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1)

        # Pre-compute the sequence the mocked torch.randint will yield. Iteration i of the loop
        # below exercises the i-th op of the augmentation space in both APIs.
        le = len(t._AUGMENTATION_SPACE)
        keys = list(t._AUGMENTATION_SPACE.keys())
        randint_values = []
        for i in range(le):
            # Stable API, op_index random call
            randint_values.append(i)
            # Stable API, if signed there is another random call
            if t._AUGMENTATION_SPACE[keys[i]][1]:
                randint_values.append(0)
            # New API, _get_random_item
            randint_values.append(i)
        randint_values = iter(randint_values)

        mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
        # NOTE(review): rand is pinned to 1.0 so probabilistic branches behave deterministically
        # and identically in both implementations — confirm against the transform internals.
        mocker.patch("torch.rand", return_value=1.0)

        for i in range(le):
            expected_output = t_ref(inpt)
            output = t(inpt)

            # Loose tolerances: interpolation backends may differ by one intensity level.
            assert_close(expected_output, output, atol=1, rtol=0.1)

    @pytest.mark.parametrize(
        "inpt",
        [
            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
            PIL.Image.new("RGB", (256, 256), 123),
            datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
        ],
    )
    @pytest.mark.parametrize(
        "interpolation",
        [
            v2_transforms.InterpolationMode.NEAREST,
            v2_transforms.InterpolationMode.BILINEAR,
            PIL.Image.NEAREST,
        ],
    )
    def test_trivial_aug(self, inpt, interpolation, mocker):
        t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation)
        t = v2_transforms.TrivialAugmentWide(interpolation=interpolation)

        # Same mocking scheme as test_randaug, but TrivialAugmentWide additionally draws a random
        # magnitude index, so one more randint value is queued per op that has magnitudes.
        le = len(t._AUGMENTATION_SPACE)
        keys = list(t._AUGMENTATION_SPACE.keys())
        randint_values = []
        for i in range(le):
            # Stable API, op_index random call
            randint_values.append(i)
            key = keys[i]
            # Stable API, random magnitude
            aug_op = t._AUGMENTATION_SPACE[key]
            magnitudes = aug_op[0](2, 0, 0)
            if magnitudes is not None:
                randint_values.append(5)
            # Stable API, if signed there is another random call
            if aug_op[1]:
                randint_values.append(0)
            # New API, _get_random_item
            randint_values.append(i)
            # New API, random magnitude
            if magnitudes is not None:
                randint_values.append(5)

        randint_values = iter(randint_values)

        mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
        mocker.patch("torch.rand", return_value=1.0)

        for _ in range(le):
            expected_output = t_ref(inpt)
            output = t(inpt)

            assert_close(expected_output, output, atol=1, rtol=0.1)

    @pytest.mark.parametrize(
        "inpt",
        [
            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
            PIL.Image.new("RGB", (256, 256), 123),
            datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
        ],
    )
    @pytest.mark.parametrize(
        "interpolation",
        [
            v2_transforms.InterpolationMode.NEAREST,
            v2_transforms.InterpolationMode.BILINEAR,
            PIL.Image.NEAREST,
        ],
    )
    def test_augmix(self, inpt, interpolation, mocker):
        t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
        # Replace the Dirichlet sampling with a deterministic softmax so mixing weights match.
        t_ref._sample_dirichlet = lambda t: t.softmax(dim=-1)
        t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
        t._sample_dirichlet = lambda t: t.softmax(dim=-1)

        # Same randint-sequence mocking as in test_trivial_aug.
        le = len(t._AUGMENTATION_SPACE)
        keys = list(t._AUGMENTATION_SPACE.keys())
        randint_values = []
        for i in range(le):
            # Stable API, op_index random call
            randint_values.append(i)
            key = keys[i]
            # Stable API, random magnitude
            aug_op = t._AUGMENTATION_SPACE[key]
            magnitudes = aug_op[0](2, 0, 0)
            if magnitudes is not None:
                randint_values.append(5)
            # Stable API, if signed there is another random call
            if aug_op[1]:
                randint_values.append(0)
            # New API, _get_random_item
            randint_values.append(i)
            # New API, random magnitude
            if magnitudes is not None:
                randint_values.append(5)

        randint_values = iter(randint_values)

        mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
        mocker.patch("torch.rand", return_value=1.0)

        expected_output = t_ref(inpt)
        output = t(inpt)

        assert_equal(expected_output, output)

    @pytest.mark.parametrize(
        "inpt",
        [
            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
            PIL.Image.new("RGB", (256, 256), 123),
            datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
        ],
    )
    @pytest.mark.parametrize(
        "interpolation",
        [
            v2_transforms.InterpolationMode.NEAREST,
            v2_transforms.InterpolationMode.BILINEAR,
            PIL.Image.NEAREST,
        ],
    )
    def test_aa(self, inpt, interpolation):
        # AutoAugment draws parameters in the same order in both APIs, so no mocking is needed —
        # seeding the global RNG identically before each call is sufficient.
        aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet")
        t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation)
        t = v2_transforms.AutoAugment(aa_policy, interpolation=interpolation)

        torch.manual_seed(12)
        expected_output = t_ref(inpt)

        torch.manual_seed(12)
        output = t(inpt)

        assert_equal(expected_output, output)
1065
1066


1067
def import_transforms_from_references(reference):
    """Load ``references/<reference>/transforms.py`` from the repository root as a module.

    The references are not an installable package, so the file is executed directly via a
    ``SourceFileLoader`` and the resulting module object (named ``transforms``) is returned.
    """
    transforms_path = Path(__file__).parent.parent / "references" / reference / "transforms.py"

    loader = importlib.machinery.SourceFileLoader("transforms", str(transforms_path))
    spec = importlib.util.spec_from_loader("transforms", loader)
    module = importlib.util.module_from_spec(spec)
    loader.exec_module(module)
    return module
1078
1079
1080


# Reference detection transforms, loaded straight from the training references so the
# comparisons below run against the exact implementations used there.
det_transforms = import_transforms_from_references("detection")
1081
1082
1083
1084
1085
1086
1087


class TestRefDetTransforms:
    """Check v2 transforms against the hand-written transforms from the detection references."""

    def make_datapoints(self, with_mask=True):
        """Yield (image, target) pairs with the image as PIL, plain tensor, and datapoint in turn.

        The target dict mirrors the detection-reference format: "boxes", "labels" and, when
        ``with_mask`` is set, "masks".
        """
        size = (600, 800)
        num_objects = 22

        def make_label(extra_dims, categories):
            # Random class indices in [0, categories).
            return torch.randint(categories, extra_dims, dtype=torch.int64)

        pil_image = to_image_pil(make_image(size=size, color_space="RGB"))
        target = {
            "boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
        }
        if with_mask:
            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

        yield (pil_image, target)

        tensor_image = torch.Tensor(make_image(size=size, color_space="RGB", dtype=torch.float32))
        target = {
            "boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
        }
        if with_mask:
            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

        yield (tensor_image, target)

        datapoint_image = make_image(size=size, color_space="RGB", dtype=torch.float32)
        target = {
            "boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
        }
        if with_mask:
            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

        yield (datapoint_image, target)

    @pytest.mark.parametrize(
        "t_ref, t, data_kwargs",
        [
            (det_transforms.RandomHorizontalFlip(p=1.0), v2_transforms.RandomHorizontalFlip(p=1.0), {}),
            (
                det_transforms.RandomIoUCrop(),
                v2_transforms.Compose(
                    [
                        v2_transforms.RandomIoUCrop(),
                        v2_transforms.SanitizeBoundingBoxes(labels_getter=lambda sample: sample[1]["labels"]),
                    ]
                ),
                {"with_mask": False},
            ),
            (det_transforms.RandomZoomOut(), v2_transforms.RandomZoomOut(), {"with_mask": False}),
            (det_transforms.ScaleJitter((1024, 1024)), v2_transforms.ScaleJitter((1024, 1024), antialias=True), {}),
            (
                det_transforms.RandomShortestSize(
                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
                ),
                v2_transforms.RandomShortestSize(
                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
                ),
                {},
            ),
        ],
    )
    def test_transform(self, t_ref, t, data_kwargs):
        for dp in self.make_datapoints(**data_kwargs):

            # We should use prototype transform first as reference transform performs inplace target update
            torch.manual_seed(12)
            output = t(dp)

            torch.manual_seed(12)
            expected_output = t_ref(*dp)

            assert_equal(expected_output, output)
1159
1160
1161
1162
1163
1164
1165
1166
1167


# Reference segmentation transforms, loaded straight from the training references so the
# comparisons below run against the exact implementations used there.
seg_transforms = import_transforms_from_references("segmentation")


# We need this transform for two reasons:
# 1. transforms.RandomCrop uses a different scheme to pad images and masks of insufficient size than its name
#    counterpart in the detection references. Thus, we cannot use it with `pad_if_needed=True`
# 2. transforms.Pad only supports a fixed padding, but the segmentation datasets don't have a fixed image size.
1168
class PadIfSmaller(v2_transforms.Transform):
    """Pad inputs on the right/bottom so that both spatial dimensions are at least ``size``.

    ``fill`` follows the usual v2 fill convention and may be a mapping from input type to fill
    value (e.g. a different fill for masks than for images).
    """

    def __init__(self, size, fill=0):
        super().__init__()
        self.size = size
        # Normalize the user-provided fill into the per-type mapping the v2 machinery expects.
        self.fill = v2_transforms._geometry._setup_fill_arg(fill)

    def _get_params(self, sample):
        height, width = query_size(sample)
        pad_right = max(self.size - width, 0)
        pad_bottom = max(self.size - height, 0)
        return dict(
            padding=[0, 0, pad_right, pad_bottom],
            needs_padding=pad_right > 0 or pad_bottom > 0,
        )

    def _transform(self, inpt, params):
        if not params["needs_padding"]:
            # Already large enough — pass through untouched.
            return inpt

        fill = _get_fill(self.fill, type(inpt))
        return prototype_F.pad(inpt, padding=params["padding"], fill=fill)
1186
1187
1188
1189


class TestRefSegTransforms:
    """Check v2 transforms against the hand-written transforms from the segmentation references."""

    def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8):
        """Yield ``(v2_input, reference_input)`` pairs for the same underlying image/mask.

        The v2 input cycles the image through PIL (optional), plain tensor, and datapoint
        representations; the reference pipeline always receives PIL inputs (or a plain tensor
        image when ``supports_pil`` is False).
        """
        size = (256, 460)
        num_categories = 21

        conv_fns = []
        if supports_pil:
            conv_fns.append(to_image_pil)
        conv_fns.extend([torch.Tensor, lambda x: x])

        for conv_fn in conv_fns:
            datapoint_image = make_image(size=size, color_space="RGB", dtype=image_dtype)
            datapoint_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)

            dp = (conv_fn(datapoint_image), datapoint_mask)
            dp_ref = (
                to_image_pil(datapoint_image) if supports_pil else datapoint_image.as_subclass(torch.Tensor),
                to_image_pil(datapoint_mask),
            )

            yield dp, dp_ref

    def set_seed(self, seed=12):
        # Both the torch and the stdlib RNG are seeded: the reference transforms use `random`.
        torch.manual_seed(seed)
        random.seed(seed)

    def check(self, t, t_ref, data_kwargs=None):
        """Run the v2 transform ``t`` and the reference transform ``t_ref`` on equivalent inputs
        under the same seed and assert identical outputs."""
        for dp, dp_ref in self.make_datapoints(**data_kwargs or dict()):

            self.set_seed()
            actual = actual_image, actual_mask = t(dp)

            self.set_seed()
            expected_image, expected_mask = t_ref(*dp_ref)
            # The reference pipeline emits PIL outputs; convert them to tensors for comparison
            # whenever the v2 pipeline produced tensors.
            if isinstance(actual_image, torch.Tensor) and not isinstance(expected_image, torch.Tensor):
                expected_image = legacy_F.pil_to_tensor(expected_image)
            expected_mask = legacy_F.pil_to_tensor(expected_mask).squeeze(0)
            expected = (expected_image, expected_mask)

            assert_equal(actual, expected)

    @pytest.mark.parametrize(
        ("t_ref", "t", "data_kwargs"),
        [
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=1.0),
                v2_transforms.RandomHorizontalFlip(p=1.0),
                dict(),
            ),
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=0.0),
                v2_transforms.RandomHorizontalFlip(p=0.0),
                dict(),
            ),
            (
                seg_transforms.RandomCrop(size=480),
                v2_transforms.Compose(
                    [
                        PadIfSmaller(size=480, fill={datapoints.Mask: 255, "others": 0}),
                        v2_transforms.RandomCrop(size=480),
                    ]
                ),
                dict(),
            ),
            (
                seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                v2_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                dict(supports_pil=False, image_dtype=torch.float),
            ),
        ],
    )
    def test_common(self, t_ref, t, data_kwargs):
        self.check(t, t_ref, data_kwargs)

1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273

@pytest.mark.parametrize(
    ("legacy_dispatcher", "name_only_params"),
    [
        (legacy_F.get_dimensions, {}),
        (legacy_F.get_image_size, {}),
        (legacy_F.get_image_num_channels, {}),
        (legacy_F.to_tensor, {}),
        (legacy_F.pil_to_tensor, {}),
        (legacy_F.convert_image_dtype, {}),
        (legacy_F.to_pil_image, {}),
        (legacy_F.normalize, {}),
        (legacy_F.resize, {"interpolation"}),
        (legacy_F.pad, {"padding", "fill"}),
        (legacy_F.crop, {}),
        (legacy_F.center_crop, {}),
        (legacy_F.resized_crop, {"interpolation"}),
        (legacy_F.hflip, {}),
        (legacy_F.perspective, {"startpoints", "endpoints", "fill", "interpolation"}),
        (legacy_F.vflip, {}),
        (legacy_F.five_crop, {}),
        (legacy_F.ten_crop, {}),
        (legacy_F.adjust_brightness, {}),
        (legacy_F.adjust_contrast, {}),
        (legacy_F.adjust_saturation, {}),
        (legacy_F.adjust_hue, {}),
        (legacy_F.adjust_gamma, {}),
        (legacy_F.rotate, {"center", "fill", "interpolation"}),
        (legacy_F.affine, {"angle", "translate", "center", "fill", "interpolation"}),
        (legacy_F.to_grayscale, {}),
        (legacy_F.rgb_to_grayscale, {}),
        # BUGFIX: a second `(legacy_F.to_tensor, {})` entry used to appear here, duplicating the
        # parametrization above and running the exact same test case twice. It was removed.
        (legacy_F.erase, {}),
        (legacy_F.gaussian_blur, {}),
        (legacy_F.invert, {}),
        (legacy_F.posterize, {}),
        (legacy_F.solarize, {}),
        (legacy_F.adjust_sharpness, {}),
        (legacy_F.autocontrast, {}),
        (legacy_F.equalize, {}),
        (legacy_F.elastic_transform, {"fill", "interpolation"}),
    ],
)
def test_dispatcher_signature_consistency(legacy_dispatcher, name_only_params):
    """Every legacy functional must have a v2 equivalent with a backward-compatible signature.

    ``name_only_params`` lists parameters whose annotation/default deliberately changed in v2
    (usually widened to a superset); for those only the parameter *name* is compared.
    """
    legacy_signature = inspect.signature(legacy_dispatcher)
    # Drop the first parameter (the image/input) — it is the dispatch argument and is allowed to
    # differ in name/annotation between the two APIs.
    legacy_params = list(legacy_signature.parameters.values())[1:]

    try:
        prototype_dispatcher = getattr(prototype_F, legacy_dispatcher.__name__)
    except AttributeError:
        raise AssertionError(
            f"Legacy dispatcher `F.{legacy_dispatcher.__name__}` has no prototype equivalent"
        ) from None

    prototype_signature = inspect.signature(prototype_dispatcher)
    prototype_params = list(prototype_signature.parameters.values())[1:]

    # Some dispatchers got extra parameters. This makes sure they have a default argument and thus are BC. We don't
    # need to check if parameters were added in the middle rather than at the end, since that will be caught by the
    # regular check below.
    prototype_params, new_prototype_params = (
        prototype_params[: len(legacy_params)],
        prototype_params[len(legacy_params) :],
    )
    for param in new_prototype_params:
        assert param.default is not param.empty

    # Some annotations were changed mostly to supersets of what was there before. Plus, some legacy dispatchers had no
    # annotations. In these cases we simply drop the annotation and default argument from the comparison
    for prototype_param, legacy_param in zip(prototype_params, legacy_params):
        if legacy_param.name in name_only_params:
            # Mutating Parameter internals (`_annotation`/`_default`) is a deliberate hack so the
            # `==` comparison below only considers the parameter name and kind for these entries.
            prototype_param._annotation = prototype_param._default = inspect.Parameter.empty
            legacy_param._annotation = legacy_param._default = inspect.Parameter.empty
        elif legacy_param.annotation is inspect.Parameter.empty:
            prototype_param._annotation = inspect.Parameter.empty

    assert prototype_params == legacy_params