# test_transforms_v2_consistency.py
1
import enum
2
3
import importlib.machinery
import importlib.util
4
import inspect
5
import random
6
import re
7
from pathlib import Path
8

9
import numpy as np
10
import PIL.Image
11
import pytest
12
13

import torch
14
import torchvision.transforms.v2 as v2_transforms
15
from common_utils import assert_close, assert_equal, set_rng_seed
16
from torch import nn
17
from torchvision import datapoints, transforms as legacy_transforms
18
from torchvision._utils import sequence_to_str
19

20
from torchvision.transforms import functional as legacy_F
21
from torchvision.transforms.v2 import functional as prototype_F
22
from torchvision.transforms.v2._utils import _get_fill
23
from torchvision.transforms.v2.functional import to_pil_image
Philip Meier's avatar
Philip Meier committed
24
from torchvision.transforms.v2.utils import query_size
25
26
27
28
29
30
31
32
from transforms_v2_legacy_utils import (
    ArgsKwargs,
    make_bounding_boxes,
    make_detection_mask,
    make_image,
    make_images,
    make_segmentation_mask,
)
33

34
# Baseline arguments for `make_images` shared by most consistency configs below.
DEFAULT_MAKE_IMAGES_KWARGS = {"color_spaces": ["RGB"], "extra_dims": [(4,)]}


Nicolas Hug's avatar
Nicolas Hug committed
37
38
39
40
41
42
@pytest.fixture(autouse=True)
def fix_rng_seed():
    """Seed all RNGs before every test so legacy/v2 comparisons draw identical random parameters."""
    set_rng_seed(0)
    yield


43
44
45
46
47
48
49
50
51
class NotScriptableArgsKwargs(ArgsKwargs):
    """Marker for parameter sets that make the transform non-scriptable.

    Such parameters still work in eager mode and are tested there, but the JIT tests skip them.
    """


52
53
class ConsistencyConfig:
    """Pairing of a v2 (prototype) transform class with its legacy (v1) counterpart.

    Each instance drives the signature-, call-, and JIT-consistency tests below.
    """

    def __init__(
        self,
        prototype_cls,
        legacy_cls,
        # If no args_kwargs is passed, only the signature will be checked
        args_kwargs=(),
        make_images_kwargs=None,
        supports_pil=True,
        removed_params=(),
        closeness_kwargs=None,
    ):
        # v2 transform class under test
        self.prototype_cls = prototype_cls
        # v1 transform class used as the reference
        self.legacy_cls = legacy_cls
        self.args_kwargs = args_kwargs
        # Fall back to the module-wide defaults when no kwargs are given
        self.make_images_kwargs = make_images_kwargs or DEFAULT_MAKE_IMAGES_KWARGS
        self.supports_pil = supports_pil
        # Legacy-only parameters that were intentionally dropped in v2
        self.removed_params = removed_params
        # Default to exact (bit-wise) comparison
        self.closeness_kwargs = closeness_kwargs or dict(rtol=0, atol=0)
71
72


73
74
75
76
# These are here since both the prototype and legacy transform need to be constructed with the same random parameters
LINEAR_TRANSFORMATION_MEAN = torch.rand(36)
LINEAR_TRANSFORMATION_MATRIX = torch.rand(
    LINEAR_TRANSFORMATION_MEAN.numel(), LINEAR_TRANSFORMATION_MEAN.numel()
)

77
78
# One entry per legacy transform class. Entries without `args_kwargs` only get their signature
# checked; the rest are also exercised by the call- and JIT-consistency tests below.
CONSISTENCY_CONFIGS = [
    ConsistencyConfig(
        v2_transforms.Normalize,
        legacy_transforms.Normalize,
        [
            ArgsKwargs(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ],
        supports_pil=False,
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.float]),
    ),
    ConsistencyConfig(
        v2_transforms.Resize,
        legacy_transforms.Resize,
        [
            NotScriptableArgsKwargs(32),
            ArgsKwargs([32]),
            ArgsKwargs((32, 29)),
            ArgsKwargs((31, 28), interpolation=v2_transforms.InterpolationMode.NEAREST),
            ArgsKwargs((30, 27), interpolation=PIL.Image.NEAREST),
            ArgsKwargs((35, 29), interpolation=PIL.Image.BILINEAR),
            NotScriptableArgsKwargs(31, max_size=32),
            ArgsKwargs([31], max_size=32),
            NotScriptableArgsKwargs(30, max_size=100),
            ArgsKwargs([31], max_size=32),
            ArgsKwargs((29, 32), antialias=False),
            ArgsKwargs((28, 31), antialias=True),
        ],
        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
        closeness_kwargs=dict(rtol=0, atol=1),
    ),
    ConsistencyConfig(
        v2_transforms.Resize,
        legacy_transforms.Resize,
        [
            ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC, antialias=True),
            ArgsKwargs((34, 25), interpolation=PIL.Image.BICUBIC, antialias=True),
        ],
        closeness_kwargs=dict(rtol=0, atol=21),
    ),
    ConsistencyConfig(
        v2_transforms.CenterCrop,
        legacy_transforms.CenterCrop,
        [
            ArgsKwargs(18),
            ArgsKwargs((18, 13)),
        ],
    ),
    ConsistencyConfig(
        v2_transforms.FiveCrop,
        legacy_transforms.FiveCrop,
        [
            ArgsKwargs(18),
            ArgsKwargs((18, 13)),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
    ),
    ConsistencyConfig(
        v2_transforms.TenCrop,
        legacy_transforms.TenCrop,
        [
            ArgsKwargs(18),
            ArgsKwargs((18, 13)),
            ArgsKwargs(18, vertical_flip=True),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
    ),
    ConsistencyConfig(
        v2_transforms.Pad,
        legacy_transforms.Pad,
        [
            NotScriptableArgsKwargs(3),
            ArgsKwargs([3]),
            ArgsKwargs([2, 3]),
            ArgsKwargs([3, 2, 1, 4]),
            NotScriptableArgsKwargs(5, fill=1, padding_mode="constant"),
            ArgsKwargs([5], fill=1, padding_mode="constant"),
            NotScriptableArgsKwargs(5, padding_mode="edge"),
            NotScriptableArgsKwargs(5, padding_mode="reflect"),
            NotScriptableArgsKwargs(5, padding_mode="symmetric"),
        ],
    ),
    *[
        ConsistencyConfig(
            v2_transforms.LinearTransformation,
            legacy_transforms.LinearTransformation,
            [
                ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX.to(matrix_dtype), LINEAR_TRANSFORMATION_MEAN.to(matrix_dtype)),
            ],
            # Make sure that the product of the height, width and number of channels matches the number of elements in
            # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
            make_images_kwargs=dict(
                DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"], dtypes=[image_dtype]
            ),
            supports_pil=False,
        )
        for matrix_dtype, image_dtype in [
            (torch.float32, torch.float32),
            (torch.float64, torch.float64),
            (torch.float32, torch.uint8),
            (torch.float64, torch.float32),
            (torch.float32, torch.float64),
        ]
    ],
    ConsistencyConfig(
        v2_transforms.Grayscale,
        legacy_transforms.Grayscale,
        [
            ArgsKwargs(num_output_channels=1),
            ArgsKwargs(num_output_channels=3),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
        # Use default tolerances of `torch.testing.assert_close`
        closeness_kwargs=dict(rtol=None, atol=None),
    ),
    ConsistencyConfig(
        v2_transforms.ConvertImageDtype,
        legacy_transforms.ConvertImageDtype,
        [
            ArgsKwargs(torch.float16),
            ArgsKwargs(torch.bfloat16),
            ArgsKwargs(torch.float32),
            ArgsKwargs(torch.float64),
            ArgsKwargs(torch.uint8),
        ],
        supports_pil=False,
        # Use default tolerances of `torch.testing.assert_close`
        closeness_kwargs=dict(rtol=None, atol=None),
    ),
    ConsistencyConfig(
        v2_transforms.ToPILImage,
        legacy_transforms.ToPILImage,
        [NotScriptableArgsKwargs()],
        make_images_kwargs=dict(
            color_spaces=[
                "GRAY",
                "GRAY_ALPHA",
                "RGB",
                "RGBA",
            ],
            extra_dims=[()],
        ),
        supports_pil=False,
    ),
    ConsistencyConfig(
        v2_transforms.Lambda,
        legacy_transforms.Lambda,
        [
            NotScriptableArgsKwargs(lambda image: image / 2),
        ],
        # Technically, this also supports PIL, but it is overkill to write a function here that supports tensor and PIL
        # images given that the transform does nothing but call it anyway.
        supports_pil=False,
    ),
    ConsistencyConfig(
        v2_transforms.RandomHorizontalFlip,
        legacy_transforms.RandomHorizontalFlip,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
        ],
    ),
    ConsistencyConfig(
        v2_transforms.RandomVerticalFlip,
        legacy_transforms.RandomVerticalFlip,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
        ],
    ),
    ConsistencyConfig(
        v2_transforms.RandomEqualize,
        legacy_transforms.RandomEqualize,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]),
    ),
    ConsistencyConfig(
        v2_transforms.RandomInvert,
        legacy_transforms.RandomInvert,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
        ],
    ),
    ConsistencyConfig(
        v2_transforms.RandomPosterize,
        legacy_transforms.RandomPosterize,
        [
            ArgsKwargs(p=0, bits=5),
            ArgsKwargs(p=1, bits=1),
            ArgsKwargs(p=1, bits=3),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]),
    ),
    ConsistencyConfig(
        v2_transforms.RandomSolarize,
        legacy_transforms.RandomSolarize,
        [
            ArgsKwargs(p=0, threshold=0.5),
            ArgsKwargs(p=1, threshold=0.3),
            ArgsKwargs(p=1, threshold=0.99),
        ],
    ),
    *[
        ConsistencyConfig(
            v2_transforms.RandomAutocontrast,
            legacy_transforms.RandomAutocontrast,
            [
                ArgsKwargs(p=0),
                ArgsKwargs(p=1),
            ],
            make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[dt]),
            closeness_kwargs=ckw,
        )
        for dt, ckw in [(torch.uint8, dict(atol=1, rtol=0)), (torch.float32, dict(rtol=None, atol=None))]
    ],
    ConsistencyConfig(
        v2_transforms.RandomAdjustSharpness,
        legacy_transforms.RandomAdjustSharpness,
        [
            ArgsKwargs(p=0, sharpness_factor=0.5),
            ArgsKwargs(p=1, sharpness_factor=0.2),
            ArgsKwargs(p=1, sharpness_factor=0.99),
        ],
        closeness_kwargs={"atol": 1e-6, "rtol": 1e-6},
    ),
    ConsistencyConfig(
        v2_transforms.RandomGrayscale,
        legacy_transforms.RandomGrayscale,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
        # Use default tolerances of `torch.testing.assert_close`
        closeness_kwargs=dict(rtol=None, atol=None),
    ),
    ConsistencyConfig(
        v2_transforms.RandomResizedCrop,
        legacy_transforms.RandomResizedCrop,
        [
            ArgsKwargs(16),
            ArgsKwargs(17, scale=(0.3, 0.7)),
            ArgsKwargs(25, ratio=(0.5, 1.5)),
            ArgsKwargs((31, 28), interpolation=v2_transforms.InterpolationMode.NEAREST),
            ArgsKwargs((31, 28), interpolation=PIL.Image.NEAREST),
            ArgsKwargs((29, 32), antialias=False),
            ArgsKwargs((28, 31), antialias=True),
        ],
        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
        closeness_kwargs=dict(rtol=0, atol=1),
    ),
    ConsistencyConfig(
        v2_transforms.RandomResizedCrop,
        legacy_transforms.RandomResizedCrop,
        [
            ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC, antialias=True),
            ArgsKwargs((33, 26), interpolation=PIL.Image.BICUBIC, antialias=True),
        ],
        closeness_kwargs=dict(rtol=0, atol=21),
    ),
    ConsistencyConfig(
        v2_transforms.RandomErasing,
        legacy_transforms.RandomErasing,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
            ArgsKwargs(p=1, scale=(0.3, 0.7)),
            ArgsKwargs(p=1, ratio=(0.5, 1.5)),
            ArgsKwargs(p=1, value=1),
            ArgsKwargs(p=1, value=(1, 2, 3)),
            ArgsKwargs(p=1, value="random"),
        ],
        supports_pil=False,
    ),
    ConsistencyConfig(
        v2_transforms.ColorJitter,
        legacy_transforms.ColorJitter,
        [
            ArgsKwargs(),
            ArgsKwargs(brightness=0.1),
            ArgsKwargs(brightness=(0.2, 0.3)),
            ArgsKwargs(contrast=0.4),
            ArgsKwargs(contrast=(0.5, 0.6)),
            ArgsKwargs(saturation=0.7),
            ArgsKwargs(saturation=(0.8, 0.9)),
            ArgsKwargs(hue=0.3),
            ArgsKwargs(hue=(-0.1, 0.2)),
            ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.5, hue=0.3),
        ],
        closeness_kwargs={"atol": 1e-5, "rtol": 1e-5},
    ),
    *[
        ConsistencyConfig(
            v2_transforms.ElasticTransform,
            legacy_transforms.ElasticTransform,
            [
                ArgsKwargs(),
                ArgsKwargs(alpha=20.0),
                ArgsKwargs(alpha=(15.3, 27.2)),
                ArgsKwargs(sigma=3.0),
                ArgsKwargs(sigma=(2.5, 3.9)),
                ArgsKwargs(interpolation=v2_transforms.InterpolationMode.NEAREST),
                ArgsKwargs(interpolation=v2_transforms.InterpolationMode.BICUBIC),
                ArgsKwargs(interpolation=PIL.Image.NEAREST),
                ArgsKwargs(interpolation=PIL.Image.BICUBIC),
                ArgsKwargs(fill=1),
            ],
            # ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
            make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(163, 163), (72, 333), (313, 95)], dtypes=[dt]),
            # We updated gaussian blur kernel generation with a faster and numerically more stable version
            # This brings float32 accumulation visible in elastic transform -> we need to relax consistency tolerance
            closeness_kwargs=ckw,
        )
        for dt, ckw in [(torch.uint8, {"rtol": 1e-1, "atol": 1}), (torch.float32, {"rtol": 1e-2, "atol": 1e-3})]
    ],
    ConsistencyConfig(
        v2_transforms.GaussianBlur,
        legacy_transforms.GaussianBlur,
        [
            ArgsKwargs(kernel_size=3),
            ArgsKwargs(kernel_size=(1, 5)),
            ArgsKwargs(kernel_size=3, sigma=0.7),
            ArgsKwargs(kernel_size=5, sigma=(0.3, 1.4)),
        ],
        closeness_kwargs={"rtol": 1e-5, "atol": 1e-5},
    ),
    ConsistencyConfig(
        v2_transforms.RandomAffine,
        legacy_transforms.RandomAffine,
        [
            ArgsKwargs(degrees=30.0),
            ArgsKwargs(degrees=(-20.0, 10.0)),
            ArgsKwargs(degrees=0.0, translate=(0.4, 0.6)),
            ArgsKwargs(degrees=0.0, scale=(0.3, 0.8)),
            ArgsKwargs(degrees=0.0, shear=13),
            ArgsKwargs(degrees=0.0, shear=(8, 17)),
            ArgsKwargs(degrees=0.0, shear=(4, 5, 4, 13)),
            ArgsKwargs(degrees=(-20.0, 10.0), translate=(0.4, 0.6), scale=(0.3, 0.8), shear=(4, 5, 4, 13)),
            ArgsKwargs(degrees=30.0, interpolation=v2_transforms.InterpolationMode.NEAREST),
            ArgsKwargs(degrees=30.0, interpolation=PIL.Image.NEAREST),
            ArgsKwargs(degrees=30.0, fill=1),
            ArgsKwargs(degrees=30.0, fill=(2, 3, 4)),
            ArgsKwargs(degrees=30.0, center=(0, 0)),
        ],
        removed_params=["fillcolor", "resample"],
    ),
    ConsistencyConfig(
        v2_transforms.RandomCrop,
        legacy_transforms.RandomCrop,
        [
            ArgsKwargs(12),
            ArgsKwargs((15, 17)),
            NotScriptableArgsKwargs(11, padding=1),
            ArgsKwargs(11, padding=[1]),
            ArgsKwargs((8, 13), padding=(2, 3)),
            ArgsKwargs((14, 9), padding=(0, 2, 1, 0)),
            ArgsKwargs(36, pad_if_needed=True),
            ArgsKwargs((7, 8), fill=1),
            NotScriptableArgsKwargs(5, fill=(1, 2, 3)),
            ArgsKwargs(12),
            NotScriptableArgsKwargs(15, padding=2, padding_mode="edge"),
            ArgsKwargs(17, padding=(1, 0), padding_mode="reflect"),
            ArgsKwargs(8, padding=(3, 0, 0, 1), padding_mode="symmetric"),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(26, 26), (18, 33), (29, 22)]),
    ),
    ConsistencyConfig(
        v2_transforms.RandomPerspective,
        legacy_transforms.RandomPerspective,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
            ArgsKwargs(p=1, distortion_scale=0.3),
            ArgsKwargs(p=1, distortion_scale=0.2, interpolation=v2_transforms.InterpolationMode.NEAREST),
            ArgsKwargs(p=1, distortion_scale=0.2, interpolation=PIL.Image.NEAREST),
            ArgsKwargs(p=1, distortion_scale=0.1, fill=1),
            ArgsKwargs(p=1, distortion_scale=0.4, fill=(1, 2, 3)),
        ],
        closeness_kwargs={"atol": None, "rtol": None},
    ),
    ConsistencyConfig(
        v2_transforms.RandomRotation,
        legacy_transforms.RandomRotation,
        [
            ArgsKwargs(degrees=30.0),
            ArgsKwargs(degrees=(-20.0, 10.0)),
            ArgsKwargs(degrees=30.0, interpolation=v2_transforms.InterpolationMode.BILINEAR),
            ArgsKwargs(degrees=30.0, interpolation=PIL.Image.BILINEAR),
            ArgsKwargs(degrees=30.0, expand=True),
            ArgsKwargs(degrees=30.0, center=(0, 0)),
            ArgsKwargs(degrees=30.0, fill=1),
            ArgsKwargs(degrees=30.0, fill=(1, 2, 3)),
        ],
        removed_params=["resample"],
    ),
    # The remaining configs have no `args_kwargs`, so only their signatures are checked.
    ConsistencyConfig(
        v2_transforms.PILToTensor,
        legacy_transforms.PILToTensor,
    ),
    ConsistencyConfig(
        v2_transforms.ToTensor,
        legacy_transforms.ToTensor,
    ),
    ConsistencyConfig(
        v2_transforms.Compose,
        legacy_transforms.Compose,
    ),
    ConsistencyConfig(
        v2_transforms.RandomApply,
        legacy_transforms.RandomApply,
    ),
    ConsistencyConfig(
        v2_transforms.RandomChoice,
        legacy_transforms.RandomChoice,
    ),
    ConsistencyConfig(
        v2_transforms.RandomOrder,
        legacy_transforms.RandomOrder,
    ),
    ConsistencyConfig(
        v2_transforms.AugMix,
        legacy_transforms.AugMix,
    ),
    ConsistencyConfig(
        v2_transforms.AutoAugment,
        legacy_transforms.AutoAugment,
    ),
    ConsistencyConfig(
        v2_transforms.RandAugment,
        legacy_transforms.RandAugment,
    ),
    ConsistencyConfig(
        v2_transforms.TrivialAugmentWide,
        legacy_transforms.TrivialAugmentWide,
    ),
]


518
519
def test_automatic_coverage():
    """Every public legacy transform class must have an entry in ``CONSISTENCY_CONFIGS``."""
    legacy_classes = {
        name
        for name, obj in legacy_transforms.__dict__.items()
        if not name.startswith("_") and isinstance(obj, type) and not issubclass(obj, enum.Enum)
    }
    covered = {config.legacy_cls.__name__ for config in CONSISTENCY_CONFIGS}

    uncovered = legacy_classes - covered
    if not uncovered:
        return

    raise AssertionError(
        f"The prototype transformations {sequence_to_str(sorted(uncovered), separate_last='and ')} "
        f"are not checked for consistency although a legacy counterpart exists."
    )


535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
@pytest.mark.parametrize("config", CONSISTENCY_CONFIGS, ids=lambda config: config.legacy_cls.__name__)
def test_signature_consistency(config):
    """The v2 signature must start with every legacy parameter, in order; extras need defaults."""
    legacy_params = dict(inspect.signature(config.legacy_cls).parameters)
    prototype_params = dict(inspect.signature(config.prototype_cls).parameters)

    # Parameters that were deliberately dropped in v2 are exempt from the check.
    for removed in config.removed_params:
        legacy_params.pop(removed, None)

    missing = legacy_params.keys() - prototype_params.keys()
    if missing:
        raise AssertionError(
            f"The prototype transform does not support the parameters "
            f"{sequence_to_str(sorted(missing), separate_last='and ')}, but the legacy transform does. "
            f"If that is intentional, e.g. pending deprecation, please add the parameters to the `removed_params` on "
            f"the `ConsistencyConfig`."
        )

    # v2-only parameters are fine as long as they are optional (or variadic).
    variadic_kinds = {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
    extra_without_default = {
        name
        for name in prototype_params.keys() - legacy_params.keys()
        if prototype_params[name].default is inspect.Parameter.empty
        and prototype_params[name].kind not in variadic_kinds
    }
    if extra_without_default:
        raise AssertionError(
            f"The prototype transform requires the parameters "
            f"{sequence_to_str(sorted(extra_without_default), separate_last='and ')}, but the legacy transform does "
            f"not. Please add a default value."
        )

    legacy_signature = list(legacy_params)
    # Since we made sure that we don't have any extra parameters without default above, we clamp the prototype
    # signature to the same number of parameters as the legacy one
    prototype_signature = list(prototype_params)[: len(legacy_signature)]

    assert prototype_signature == legacy_signature
572
573


574
575
576
def check_call_consistency(
    prototype_transform, legacy_transform, images=None, supports_pil=True, closeness_kwargs=None
):
    """Assert that a v2 transform and its legacy counterpart produce close outputs.

    For every image, three input paths are compared: plain tensor, image datapoint, and — for
    3D images when ``supports_pil`` is true — PIL image. ``torch.manual_seed(0)`` is called
    before each invocation so random transforms draw identical parameters on both sides.
    ``closeness_kwargs`` are forwarded to ``assert_close``.
    """
    if images is None:
        images = make_images(**DEFAULT_MAKE_IMAGES_KWARGS)

    closeness_kwargs = closeness_kwargs or dict()

    for image in images:
        # Human-readable shape/dtype tag used in the failure messages below
        image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]"

        # Strip the datapoint subclass so the legacy transform sees a plain tensor
        image_tensor = torch.Tensor(image)
        try:
            torch.manual_seed(0)
            output_legacy_tensor = legacy_transform(image_tensor)
        except Exception as exc:
            # Failure on the legacy side means the test inputs are wrong, not the v2 code
            raise pytest.UsageError(
                f"Transforming a tensor image {image_repr} failed in the legacy transform with the "
                f"error above. This means that you need to specify the parameters passed to `make_images` through the "
                "`make_images_kwargs` of the `ConsistencyConfig`."
            ) from exc

        try:
            torch.manual_seed(0)
            output_prototype_tensor = prototype_transform(image_tensor)
        except Exception as exc:
            raise AssertionError(
                f"Transforming a tensor image with shape {image_repr} failed in the prototype transform with "
                f"the error above. This means there is a consistency bug either in `_get_params` or in the "
                f"`is_pure_tensor` path in `_transform`."
            ) from exc

        assert_close(
            output_prototype_tensor,
            output_legacy_tensor,
            msg=lambda msg: f"Tensor image consistency check failed with: \n\n{msg}",
            **closeness_kwargs,
        )

        try:
            torch.manual_seed(0)
            output_prototype_image = prototype_transform(image)
        except Exception as exc:
            raise AssertionError(
                f"Transforming a image datapoint with shape {image_repr} failed in the prototype transform with "
                f"the error above. This means there is a consistency bug either in `_get_params` or in the "
                f"`datapoints.Image` path in `_transform`."
            ) from exc

        # The datapoint path must match the plain-tensor path of the same v2 transform
        assert_close(
            output_prototype_image,
            output_prototype_tensor,
            msg=lambda msg: f"Output for datapoint and tensor images is not equal: \n\n{msg}",
            **closeness_kwargs,
        )

        # PIL only supports single images without extra batch dims
        if image.ndim == 3 and supports_pil:
            image_pil = to_pil_image(image)

            try:
                torch.manual_seed(0)
                output_legacy_pil = legacy_transform(image_pil)
            except Exception as exc:
                raise pytest.UsageError(
                    f"Transforming a PIL image with shape {image_repr} failed in the legacy transform with the "
                    f"error above. If this transform does not support PIL images, set `supports_pil=False` on the "
                    "`ConsistencyConfig`. "
                ) from exc

            try:
                torch.manual_seed(0)
                output_prototype_pil = prototype_transform(image_pil)
            except Exception as exc:
                raise AssertionError(
                    f"Transforming a PIL image with shape {image_repr} failed in the prototype transform with "
                    f"the error above. This means there is a consistency bug either in `_get_params` or in the "
                    f"`PIL.Image.Image` path in `_transform`."
                ) from exc

            assert_close(
                output_prototype_pil,
                output_legacy_pil,
                msg=lambda msg: f"PIL image consistency check failed with: \n\n{msg}",
                **closeness_kwargs,
            )
659
660


661
@pytest.mark.parametrize(
    ("config", "args_kwargs"),
    [
        pytest.param(
            config, args_kwargs, id=f"{config.legacy_cls.__name__}-{idx:0{len(str(len(config.args_kwargs)))}d}"
        )
        for config in CONSISTENCY_CONFIGS
        for idx, args_kwargs in enumerate(config.args_kwargs)
    ],
)
@pytest.mark.filterwarnings("ignore")
def test_call_consistency(config, args_kwargs):
    """Instantiate both transforms with the same arguments and compare their outputs."""
    args, kwargs = args_kwargs

    try:
        v1_transform = config.legacy_cls(*args, **kwargs)
    except Exception as exc:
        # A failure on the legacy side is a broken test setup, not a v2 bug.
        raise pytest.UsageError(
            f"Initializing the legacy transform failed with the error above. "
            f"Please correct the `ArgsKwargs({args_kwargs})` in the `ConsistencyConfig`."
        ) from exc

    try:
        v2_transform = config.prototype_cls(*args, **kwargs)
    except Exception as exc:
        raise AssertionError(
            "Initializing the prototype transform failed with the error above. "
            "This means there is a consistency bug in the constructor."
        ) from exc

    check_call_consistency(
        v2_transform,
        v1_transform,
        images=make_images(**config.make_images_kwargs),
        supports_pil=config.supports_pil,
        closeness_kwargs=config.closeness_kwargs,
    )


700
701
702
703
704
705
706
707
708
# Shared parametrization for the `get_params` tests below: pairs each v2 transform class with
# example arguments for its `get_params` method. The matching `ConsistencyConfig` is looked up
# from `CONSISTENCY_CONFIGS` by prototype class identity.
get_params_parametrization = pytest.mark.parametrize(
    ("config", "get_params_args_kwargs"),
    [
        pytest.param(
            next(config for config in CONSISTENCY_CONFIGS if config.prototype_cls is transform_cls),
            get_params_args_kwargs,
            id=transform_cls.__name__,
        )
        for transform_cls, get_params_args_kwargs in [
            (v2_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])),
            (v2_transforms.RandomErasing, ArgsKwargs(make_image(), scale=(0.3, 0.7), ratio=(0.5, 1.5))),
            (v2_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)),
            (v2_transforms.ElasticTransform, ArgsKwargs(alpha=[15.3, 27.2], sigma=[2.5, 3.9], size=[17, 31])),
            (v2_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)),
            (
                v2_transforms.RandomAffine,
                ArgsKwargs(degrees=[-20.0, 10.0], translate=None, scale_ranges=None, shears=None, img_size=[15, 29]),
            ),
            (v2_transforms.RandomCrop, ArgsKwargs(make_image(size=(61, 47)), output_size=(19, 25))),
            (v2_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)),
            (v2_transforms.RandomRotation, ArgsKwargs(degrees=[-20.0, 10.0])),
            (v2_transforms.AutoAugment, ArgsKwargs(5)),
        ]
    ],
)
725
726


727
@get_params_parametrization
def test_get_params_alias(config, get_params_args_kwargs):
    """The v2 class must reuse the legacy ``get_params``, both on the class and on instances."""
    assert config.prototype_cls.get_params is config.legacy_cls.get_params

    if not config.args_kwargs:
        return

    ctor_args, ctor_kwargs = config.args_kwargs[0]
    v1_instance = config.legacy_cls(*ctor_args, **ctor_kwargs)
    v2_instance = config.prototype_cls(*ctor_args, **ctor_kwargs)

    assert v2_instance.get_params is v1_instance.get_params


740
@get_params_parametrization
def test_get_params_jit(config, get_params_args_kwargs):
    """``get_params`` must be scriptable both when accessed on the class and on an instance."""
    gp_args, gp_kwargs = get_params_args_kwargs

    scripted_unbound = torch.jit.script(config.prototype_cls.get_params)
    scripted_unbound(*gp_args, **gp_kwargs)

    if not config.args_kwargs:
        return

    ctor_args, ctor_kwargs = config.args_kwargs[0]
    instance = config.prototype_cls(*ctor_args, **ctor_kwargs)

    scripted_bound = torch.jit.script(instance.get_params)
    scripted_bound(*gp_args, **gp_kwargs)
752
753


754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
@pytest.mark.parametrize(
    ("config", "args_kwargs"),
    [
        pytest.param(
            config, args_kwargs, id=f"{config.legacy_cls.__name__}-{idx:0{len(str(len(config.args_kwargs)))}d}"
        )
        for config in CONSISTENCY_CONFIGS
        for idx, args_kwargs in enumerate(config.args_kwargs)
        if not isinstance(args_kwargs, NotScriptableArgsKwargs)
    ],
)
def test_jit_consistency(config, args_kwargs):
    """Scripted v2 and scripted legacy transforms must produce close outputs for the same seed."""
    args, kwargs = args_kwargs

    v2_eager = config.prototype_cls(*args, **kwargs)
    v1_eager = config.legacy_cls(*args, **kwargs)

    v1_scripted = torch.jit.script(v1_eager)
    v2_scripted = torch.jit.script(v2_eager)

    for image in make_images(**config.make_images_kwargs):
        # Scripted transforms expect plain tensors, not datapoint subclasses
        plain_tensor = image.as_subclass(torch.Tensor)

        torch.manual_seed(0)
        v1_output = v1_scripted(plain_tensor)

        torch.manual_seed(0)
        v2_output = v2_scripted(plain_tensor)

        assert_close(v2_output, v1_output, **config.closeness_kwargs)


786
787
788
789
790
791
792
793
794
795
class TestContainerTransforms:
    """
    Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for
    consistency automatically tests the wrapped transforms consistency.

    Instead of complicated mocking or creating custom transforms just for these tests, here we use deterministic ones
    that were already tested for consistency above.
    """

    def test_compose(self):
        prototype_transform = v2_transforms.Compose(
            [
                v2_transforms.Resize(256),
                v2_transforms.CenterCrop(224),
            ]
        )
        legacy_transform = legacy_transforms.Compose(
            [
                legacy_transforms.Resize(256),
                legacy_transforms.CenterCrop(224),
            ]
        )

        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))

    @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
    @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
    def test_random_apply(self, p, sequence_type):
        prototype_transform = v2_transforms.RandomApply(
            sequence_type(
                [
                    v2_transforms.Resize(256),
                    v2_transforms.CenterCrop(224),
                ]
            ),
            p=p,
        )
        legacy_transform = legacy_transforms.RandomApply(
            sequence_type(
                [
                    legacy_transforms.Resize(256),
                    legacy_transforms.CenterCrop(224),
                ]
            ),
            p=p,
        )

        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))

        if sequence_type is nn.ModuleList:
            # quick and dirty test that it is jit-scriptable
            scripted = torch.jit.script(prototype_transform)
            scripted(torch.rand(1, 3, 300, 300))

    # We can't test other values for `p` since the random parameter generation is different
    @pytest.mark.parametrize("probabilities", [(0, 1), (1, 0)])
    def test_random_choice(self, probabilities):
        prototype_transform = v2_transforms.RandomChoice(
            [
                v2_transforms.Resize(256),
                # Consistency fix: this previously wrapped `legacy_transforms.CenterCrop`. The v2 container
                # tests above wrap v2 transforms only, so use the v2 CenterCrop here as well. Both CenterCrop
                # implementations behave identically, so the test outcome is unchanged.
                v2_transforms.CenterCrop(224),
            ],
            p=probabilities,
        )
        legacy_transform = legacy_transforms.RandomChoice(
            [
                legacy_transforms.Resize(256),
                legacy_transforms.CenterCrop(224),
            ],
            p=probabilities,
        )

        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
        check_call_consistency(prototype_transform, legacy_transform, closeness_kwargs=dict(rtol=0, atol=1))
862
863


864
865
class TestToTensorTransforms:
    """Consistency of the tensor-conversion transforms between the legacy and v2 namespaces."""

    def test_pil_to_tensor(self):
        v2_convert = v2_transforms.PILToTensor()
        legacy_convert = legacy_transforms.PILToTensor()

        for image in make_images(extra_dims=[()]):
            pil_image = to_pil_image(image)

            assert_equal(v2_convert(pil_image), legacy_convert(pil_image))

    def test_to_tensor(self):
        # The v2 ToTensor is deprecated and must warn on construction.
        with pytest.warns(UserWarning, match=re.escape("The transform `ToTensor()` is deprecated")):
            v2_convert = v2_transforms.ToTensor()
        legacy_convert = legacy_transforms.ToTensor()

        for image in make_images(extra_dims=[()]):
            pil_image = to_pil_image(image)
            numpy_image = np.array(pil_image)

            # Both PIL and numpy inputs must be converted identically.
            assert_equal(v2_convert(pil_image), legacy_convert(pil_image))
            assert_equal(v2_convert(numpy_image), legacy_convert(numpy_image))
885
886
887
888
889
890
891
892


class TestAATransforms:
    """
    Consistency tests for the AutoAugment family (RandAugment, TrivialAugmentWide, AugMix, AutoAugment).

    The eager tests below patch ``torch.randint`` with a pre-computed sequence of values so that the stable
    (legacy) and v2 implementations — which draw their random numbers in a different order — end up picking
    the same operation (and, where applicable, the same magnitude and sign) on every call. The ``*_jit``
    tests instead rely on identical manual seeding, which is sufficient for scripted modules.
    """

    @pytest.mark.parametrize(
        "inpt",
        [
            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
            PIL.Image.new("RGB", (256, 256), 123),
            datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
        ],
    )
    @pytest.mark.parametrize(
        "interpolation",
        [
            v2_transforms.InterpolationMode.NEAREST,
            v2_transforms.InterpolationMode.BILINEAR,
            PIL.Image.NEAREST,
        ],
    )
    def test_randaug(self, inpt, interpolation, mocker):
        t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1)
        t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1)

        # Build the randint sequence so that iteration i selects augmentation op i in both APIs.
        le = len(t._AUGMENTATION_SPACE)
        keys = list(t._AUGMENTATION_SPACE.keys())
        randint_values = []
        for i in range(le):
            # Stable API, op_index random call
            randint_values.append(i)
            # Stable API, if signed there is another random call
            if t._AUGMENTATION_SPACE[keys[i]][1]:
                randint_values.append(0)
            # New API, _get_random_item
            randint_values.append(i)
        randint_values = iter(randint_values)

        # Every torch.randint call consumes the next pre-computed value; torch.rand is pinned so that
        # probability gates always fire.
        mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
        mocker.patch("torch.rand", return_value=1.0)

        for i in range(le):
            expected_output = t_ref(inpt)
            output = t(inpt)

            assert_close(expected_output, output, atol=1, rtol=0.1)

    @pytest.mark.parametrize(
        "interpolation",
        [
            v2_transforms.InterpolationMode.NEAREST,
            v2_transforms.InterpolationMode.BILINEAR,
        ],
    )
    def test_randaug_jit(self, interpolation):
        # Scripted modules share the global RNG, so identical seeding is enough for exact equality.
        inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
        t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1)
        t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1)

        tt_ref = torch.jit.script(t_ref)
        tt = torch.jit.script(t)

        torch.manual_seed(12)
        expected_output = tt_ref(inpt)

        torch.manual_seed(12)
        scripted_output = tt(inpt)

        assert_equal(scripted_output, expected_output)

    @pytest.mark.parametrize(
        "inpt",
        [
            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
            PIL.Image.new("RGB", (256, 256), 123),
            datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
        ],
    )
    @pytest.mark.parametrize(
        "interpolation",
        [
            v2_transforms.InterpolationMode.NEAREST,
            v2_transforms.InterpolationMode.BILINEAR,
            PIL.Image.NEAREST,
        ],
    )
    def test_trivial_aug(self, inpt, interpolation, mocker):
        t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation)
        t = v2_transforms.TrivialAugmentWide(interpolation=interpolation)

        # Unlike RandAugment, TrivialAugmentWide also samples a random magnitude, so both APIs get the
        # same magnitude index (5) injected whenever the op defines a magnitude range.
        le = len(t._AUGMENTATION_SPACE)
        keys = list(t._AUGMENTATION_SPACE.keys())
        randint_values = []
        for i in range(le):
            # Stable API, op_index random call
            randint_values.append(i)
            key = keys[i]
            # Stable API, random magnitude
            aug_op = t._AUGMENTATION_SPACE[key]
            magnitudes = aug_op[0](2, 0, 0)
            if magnitudes is not None:
                randint_values.append(5)
            # Stable API, if signed there is another random call
            if aug_op[1]:
                randint_values.append(0)
            # New API, _get_random_item
            randint_values.append(i)
            # New API, random magnitude
            if magnitudes is not None:
                randint_values.append(5)

        randint_values = iter(randint_values)

        mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
        mocker.patch("torch.rand", return_value=1.0)

        for _ in range(le):
            expected_output = t_ref(inpt)
            output = t(inpt)

            assert_close(expected_output, output, atol=1, rtol=0.1)

    @pytest.mark.parametrize(
        "interpolation",
        [
            v2_transforms.InterpolationMode.NEAREST,
            v2_transforms.InterpolationMode.BILINEAR,
        ],
    )
    def test_trivial_aug_jit(self, interpolation):
        # Scripted modules share the global RNG, so identical seeding is enough for exact equality.
        inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
        t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation)
        t = v2_transforms.TrivialAugmentWide(interpolation=interpolation)

        tt_ref = torch.jit.script(t_ref)
        tt = torch.jit.script(t)

        torch.manual_seed(12)
        expected_output = tt_ref(inpt)

        torch.manual_seed(12)
        scripted_output = tt(inpt)

        assert_equal(scripted_output, expected_output)

    @pytest.mark.parametrize(
        "inpt",
        [
            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
            PIL.Image.new("RGB", (256, 256), 123),
            datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
        ],
    )
    @pytest.mark.parametrize(
        "interpolation",
        [
            v2_transforms.InterpolationMode.NEAREST,
            v2_transforms.InterpolationMode.BILINEAR,
            PIL.Image.NEAREST,
        ],
    )
    def test_augmix(self, inpt, interpolation, mocker):
        t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
        # Replace the Dirichlet sampling with a deterministic softmax on both sides so the mixing
        # weights are identical.
        t_ref._sample_dirichlet = lambda t: t.softmax(dim=-1)
        t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
        t._sample_dirichlet = lambda t: t.softmax(dim=-1)

        # Same randint-sequence construction as in test_trivial_aug: op index, optional magnitude,
        # optional sign for the stable API, then op index and optional magnitude for the v2 API.
        le = len(t._AUGMENTATION_SPACE)
        keys = list(t._AUGMENTATION_SPACE.keys())
        randint_values = []
        for i in range(le):
            # Stable API, op_index random call
            randint_values.append(i)
            key = keys[i]
            # Stable API, random magnitude
            aug_op = t._AUGMENTATION_SPACE[key]
            magnitudes = aug_op[0](2, 0, 0)
            if magnitudes is not None:
                randint_values.append(5)
            # Stable API, if signed there is another random call
            if aug_op[1]:
                randint_values.append(0)
            # New API, _get_random_item
            randint_values.append(i)
            # New API, random magnitude
            if magnitudes is not None:
                randint_values.append(5)

        randint_values = iter(randint_values)

        mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
        mocker.patch("torch.rand", return_value=1.0)

        expected_output = t_ref(inpt)
        output = t(inpt)

        assert_equal(expected_output, output)

    @pytest.mark.parametrize(
        "interpolation",
        [
            v2_transforms.InterpolationMode.NEAREST,
            v2_transforms.InterpolationMode.BILINEAR,
        ],
    )
    def test_augmix_jit(self, interpolation):
        # Scripted modules share the global RNG, so identical seeding is enough for exact equality.
        inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)

        t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
        t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)

        tt_ref = torch.jit.script(t_ref)
        tt = torch.jit.script(t)

        torch.manual_seed(12)
        expected_output = tt_ref(inpt)

        torch.manual_seed(12)
        scripted_output = tt(inpt)

        assert_equal(scripted_output, expected_output)

    @pytest.mark.parametrize(
        "inpt",
        [
            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
            PIL.Image.new("RGB", (256, 256), 123),
            datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
        ],
    )
    @pytest.mark.parametrize(
        "interpolation",
        [
            v2_transforms.InterpolationMode.NEAREST,
            v2_transforms.InterpolationMode.BILINEAR,
            PIL.Image.NEAREST,
        ],
    )
    def test_aa(self, inpt, interpolation):
        # AutoAugment draws its random numbers in the same order in both APIs, so seeding suffices.
        aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet")
        t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation)
        t = v2_transforms.AutoAugment(aa_policy, interpolation=interpolation)

        torch.manual_seed(12)
        expected_output = t_ref(inpt)

        torch.manual_seed(12)
        output = t(inpt)

        assert_equal(expected_output, output)

    @pytest.mark.parametrize(
        "interpolation",
        [
            v2_transforms.InterpolationMode.NEAREST,
            v2_transforms.InterpolationMode.BILINEAR,
        ],
    )
    def test_aa_jit(self, interpolation):
        # Scripted modules share the global RNG, so identical seeding is enough for exact equality.
        inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
        aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet")
        t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation)
        t = v2_transforms.AutoAugment(aa_policy, interpolation=interpolation)

        tt_ref = torch.jit.script(t_ref)
        tt = torch.jit.script(t)

        torch.manual_seed(12)
        expected_output = tt_ref(inpt)

        torch.manual_seed(12)
        scripted_output = tt(inpt)

        assert_equal(scripted_output, expected_output)

1158

1159
def import_transforms_from_references(reference):
    """Load ``references/<reference>/transforms.py`` from the repository checkout as a module.

    The references folder lives next to the test folder, so we resolve it relative to this file and
    execute the source directly via an explicit loader (it is not an installed package).
    """
    project_root = Path(__file__).parent.parent
    source_path = project_root / "references" / reference / "transforms.py"

    loader = importlib.machinery.SourceFileLoader("transforms", str(source_path))
    spec = importlib.util.spec_from_loader("transforms", loader)
    module = importlib.util.module_from_spec(spec)
    loader.exec_module(module)
    return module
1170
1171
1172


# Transform definitions from `references/detection/transforms.py` in this repository checkout.
det_transforms = import_transforms_from_references("detection")
1173
1174
1175
1176
1177
1178
1179


class TestRefDetTransforms:
    """Consistency of v2 transforms against the detection reference-script transforms."""

    def make_datapoints(self, with_mask=True):
        """Yield (image, target) pairs with the image as PIL, plain tensor, and datapoint flavor."""
        size = (600, 800)
        num_objects = 22

        def make_label(extra_dims, categories):
            return torch.randint(categories, extra_dims, dtype=torch.int64)

        # Flavor 1: PIL image input.
        pil_image = to_pil_image(make_image(size=size, color_space="RGB"))
        target = {
            "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
        }
        if with_mask:
            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

        yield (pil_image, target)

        # Flavor 2: plain float tensor input.
        tensor_image = torch.Tensor(make_image(size=size, color_space="RGB", dtype=torch.float32))
        target = {
            "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
        }
        if with_mask:
            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

        yield (tensor_image, target)

        # Flavor 3: datapoint (Image subclass) input.
        datapoint_image = make_image(size=size, color_space="RGB", dtype=torch.float32)
        target = {
            "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
        }
        if with_mask:
            target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

        yield (datapoint_image, target)

    @pytest.mark.parametrize(
        "t_ref, t, data_kwargs",
        [
            (det_transforms.RandomHorizontalFlip(p=1.0), v2_transforms.RandomHorizontalFlip(p=1.0), {}),
            (
                det_transforms.RandomIoUCrop(),
                v2_transforms.Compose(
                    [
                        v2_transforms.RandomIoUCrop(),
                        # The reference implementation drops invalid boxes internally; v2 needs an
                        # explicit SanitizeBoundingBoxes step to match.
                        v2_transforms.SanitizeBoundingBoxes(labels_getter=lambda sample: sample[1]["labels"]),
                    ]
                ),
                {"with_mask": False},
            ),
            (det_transforms.RandomZoomOut(), v2_transforms.RandomZoomOut(), {"with_mask": False}),
            (det_transforms.ScaleJitter((1024, 1024)), v2_transforms.ScaleJitter((1024, 1024), antialias=True), {}),
            (
                det_transforms.RandomShortestSize(
                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
                ),
                v2_transforms.RandomShortestSize(
                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
                ),
                {},
            ),
        ],
    )
    def test_transform(self, t_ref, t, data_kwargs):
        for dp in self.make_datapoints(**data_kwargs):

            # We should use prototype transform first as reference transform performs inplace target update
            torch.manual_seed(12)
            output = t(dp)

            torch.manual_seed(12)
            expected_output = t_ref(*dp)

            assert_equal(expected_output, output)
1251
1252
1253
1254
1255
1256
1257
1258
1259


# Transform definitions from `references/segmentation/transforms.py` in this repository checkout.
seg_transforms = import_transforms_from_references("segmentation")


# We need this transform for two reasons:
# 1. transforms.RandomCrop uses a different scheme to pad images and masks of insufficient size than its name
#    counterpart in the detection references. Thus, we cannot use it with `pad_if_needed=True`
# 2. transforms.Pad only supports a fixed padding, but the segmentation datasets don't have a fixed image size.
1260
class PadIfSmaller(v2_transforms.Transform):
    """Pad the right/bottom of the input so both spatial dimensions reach at least ``size``.

    Inputs already at least ``size`` in both dimensions are returned untouched.
    """

    def __init__(self, size, fill=0):
        super().__init__()
        self.size = size
        # Normalize the user-supplied fill into the per-type mapping the v2 machinery expects.
        self.fill = v2_transforms._geometry._setup_fill_arg(fill)

    def _get_params(self, sample):
        height, width = query_size(sample)
        pad_right = max(self.size - width, 0)
        pad_bottom = max(self.size - height, 0)
        # Padding layout is [left, top, right, bottom]; we only ever grow right/bottom.
        return dict(
            padding=[0, 0, pad_right, pad_bottom],
            needs_padding=pad_right > 0 or pad_bottom > 0,
        )

    def _transform(self, inpt, params):
        if not params["needs_padding"]:
            return inpt
        return prototype_F.pad(inpt, padding=params["padding"], fill=_get_fill(self.fill, type(inpt)))
1278
1279
1280
1281


class TestRefSegTransforms:
    """Consistency of v2 transforms against the segmentation reference-script transforms."""

    def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8):
        """Yield ((image, mask), (ref_image, ref_mask)) pairs for each supported input flavor.

        The reference transforms operate on PIL inputs (or a plain tensor when PIL is unsupported),
        while the v2 side is exercised with PIL, plain tensor, and datapoint images.
        """
        size = (256, 460)
        num_categories = 21

        conv_fns = []
        if supports_pil:
            conv_fns.append(to_pil_image)
        conv_fns.extend([torch.Tensor, lambda x: x])

        for conv_fn in conv_fns:
            datapoint_image = make_image(size=size, color_space="RGB", dtype=image_dtype)
            datapoint_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)

            dp = (conv_fn(datapoint_image), datapoint_mask)
            dp_ref = (
                to_pil_image(datapoint_image) if supports_pil else datapoint_image.as_subclass(torch.Tensor),
                to_pil_image(datapoint_mask),
            )

            yield dp, dp_ref

    def set_seed(self, seed=12):
        # The reference transforms draw from both torch's and Python's RNG.
        torch.manual_seed(seed)
        random.seed(seed)

    def check(self, t, t_ref, data_kwargs=None):
        """Run v2 transform ``t`` and reference transform ``t_ref`` under the same seed and compare."""
        for dp, dp_ref in self.make_datapoints(**data_kwargs or dict()):

            self.set_seed()
            actual = actual_image, actual_mask = t(dp)

            self.set_seed()
            expected_image, expected_mask = t_ref(*dp_ref)
            # The reference pipeline may return PIL outputs; convert them to tensors for comparison.
            if isinstance(actual_image, torch.Tensor) and not isinstance(expected_image, torch.Tensor):
                expected_image = legacy_F.pil_to_tensor(expected_image)
            expected_mask = legacy_F.pil_to_tensor(expected_mask).squeeze(0)
            expected = (expected_image, expected_mask)

            assert_equal(actual, expected)

    @pytest.mark.parametrize(
        ("t_ref", "t", "data_kwargs"),
        [
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=1.0),
                v2_transforms.RandomHorizontalFlip(p=1.0),
                dict(),
            ),
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=0.0),
                v2_transforms.RandomHorizontalFlip(p=0.0),
                dict(),
            ),
            (
                seg_transforms.RandomCrop(size=480),
                # The reference RandomCrop pads undersized inputs itself; the v2 equivalent is an
                # explicit PadIfSmaller followed by RandomCrop.
                v2_transforms.Compose(
                    [
                        PadIfSmaller(size=480, fill={datapoints.Mask: 255, "others": 0}),
                        v2_transforms.RandomCrop(size=480),
                    ]
                ),
                dict(),
            ),
            (
                seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                v2_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                dict(supports_pil=False, image_dtype=torch.float),
            ),
        ],
    )
    def test_common(self, t_ref, t, data_kwargs):
        self.check(t, t_ref, data_kwargs)

1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365

@pytest.mark.parametrize(
    ("legacy_dispatcher", "name_only_params"),
    [
        (legacy_F.get_dimensions, {}),
        (legacy_F.get_image_size, {}),
        (legacy_F.get_image_num_channels, {}),
        (legacy_F.to_tensor, {}),
        (legacy_F.pil_to_tensor, {}),
        (legacy_F.convert_image_dtype, {}),
        (legacy_F.to_pil_image, {}),
        (legacy_F.normalize, {}),
        (legacy_F.resize, {"interpolation"}),
        (legacy_F.pad, {"padding", "fill"}),
        (legacy_F.crop, {}),
        (legacy_F.center_crop, {}),
        (legacy_F.resized_crop, {"interpolation"}),
        (legacy_F.hflip, {}),
        (legacy_F.perspective, {"startpoints", "endpoints", "fill", "interpolation"}),
        (legacy_F.vflip, {}),
        (legacy_F.five_crop, {}),
        (legacy_F.ten_crop, {}),
        (legacy_F.adjust_brightness, {}),
        (legacy_F.adjust_contrast, {}),
        (legacy_F.adjust_saturation, {}),
        (legacy_F.adjust_hue, {}),
        (legacy_F.adjust_gamma, {}),
        (legacy_F.rotate, {"center", "fill", "interpolation"}),
        (legacy_F.affine, {"angle", "translate", "center", "fill", "interpolation"}),
        (legacy_F.to_grayscale, {}),
        (legacy_F.rgb_to_grayscale, {}),
        # NOTE(review): a second `(legacy_F.to_tensor, {})` entry used to appear here; it duplicated the
        # one near the top of the list and only produced a redundant test case, so it was removed.
        (legacy_F.erase, {}),
        (legacy_F.gaussian_blur, {}),
        (legacy_F.invert, {}),
        (legacy_F.posterize, {}),
        (legacy_F.solarize, {}),
        (legacy_F.adjust_sharpness, {}),
        (legacy_F.autocontrast, {}),
        (legacy_F.equalize, {}),
        (legacy_F.elastic_transform, {"fill", "interpolation"}),
    ],
)
def test_dispatcher_signature_consistency(legacy_dispatcher, name_only_params):
    """Every legacy functional dispatcher must have a backward-compatible v2 counterpart.

    `name_only_params` lists parameters whose annotation/default were deliberately changed in v2
    (usually widened to a superset); for those only the parameter name is compared.
    """
    legacy_signature = inspect.signature(legacy_dispatcher)
    # Drop the first (input) parameter — its name/annotation legitimately differs between the APIs.
    legacy_params = list(legacy_signature.parameters.values())[1:]

    try:
        prototype_dispatcher = getattr(prototype_F, legacy_dispatcher.__name__)
    except AttributeError:
        raise AssertionError(
            f"Legacy dispatcher `F.{legacy_dispatcher.__name__}` has no prototype equivalent"
        ) from None

    prototype_signature = inspect.signature(prototype_dispatcher)
    prototype_params = list(prototype_signature.parameters.values())[1:]

    # Some dispatchers got extra parameters. This makes sure they have a default argument and thus are BC. We don't
    # need to check if parameters were added in the middle rather than at the end, since that will be caught by the
    # regular check below.
    prototype_params, new_prototype_params = (
        prototype_params[: len(legacy_params)],
        prototype_params[len(legacy_params) :],
    )
    for param in new_prototype_params:
        assert param.default is not param.empty

    # Some annotations were changed mostly to supersets of what was there before. Plus, some legacy dispatchers had no
    # annotations. In these cases we simply drop the annotation and default argument from the comparison
    for prototype_param, legacy_param in zip(prototype_params, legacy_params):
        if legacy_param.name in name_only_params:
            # Mutating the private attributes neutralizes annotation/default in the equality check below.
            prototype_param._annotation = prototype_param._default = inspect.Parameter.empty
            legacy_param._annotation = legacy_param._default = inspect.Parameter.empty
        elif legacy_param.annotation is inspect.Parameter.empty:
            prototype_param._annotation = inspect.Parameter.empty

    assert prototype_params == legacy_params