# test_transforms_v2_consistency.py
1
import enum
2
3
import importlib.machinery
import importlib.util
4
import inspect
5
import random
6
import re
7
from pathlib import Path
8

9
import numpy as np
10
import PIL.Image
11
import pytest
12
13

import torch
14
import torchvision.transforms.v2 as v2_transforms
15
from common_utils import assert_close, assert_equal, set_rng_seed
16
from torch import nn
17
from torchvision import datapoints, transforms as legacy_transforms
18
from torchvision._utils import sequence_to_str
19

20
from torchvision.transforms import functional as legacy_F
21
from torchvision.transforms.v2 import functional as prototype_F
22
from torchvision.transforms.v2._utils import _get_fill
23
from torchvision.transforms.v2.functional import to_pil_image
Philip Meier's avatar
Philip Meier committed
24
from torchvision.transforms.v2.utils import query_size
25
26
27
28
29
30
31
32
from transforms_v2_legacy_utils import (
    ArgsKwargs,
    make_bounding_boxes,
    make_detection_mask,
    make_image,
    make_images,
    make_segmentation_mask,
)
33

34
# Baseline arguments for make_images(); individual configs extend/override these via
# dict(DEFAULT_MAKE_IMAGES_KWARGS, ...).
DEFAULT_MAKE_IMAGES_KWARGS = {"color_spaces": ["RGB"], "extra_dims": [(4,)]}
35
36


Nicolas Hug's avatar
Nicolas Hug committed
37
38
39
40
41
42
@pytest.fixture(autouse=True)
def fix_rng_seed():
    """Reseed all RNGs before every test so legacy/v2 comparisons are deterministic."""
    set_rng_seed(0)
    yield


43
44
45
46
47
48
49
50
51
class NotScriptableArgsKwargs(ArgsKwargs):
    """Marker for parameter sets that render the transform non-scriptable.

    Such parametrizations still work in eager mode and thus will be tested there, but they are
    skipped by the JIT tests.
    """


52
53
class ConsistencyConfig:
    """Pairs a v2 (prototype) transform with its v1 (legacy) counterpart for consistency testing.

    Attributes:
        prototype_cls: the v2 transform class.
        legacy_cls: the corresponding v1 transform class.
        args_kwargs: constructor parametrizations to test. If empty, only the signature is checked.
        make_images_kwargs: arguments for ``make_images`` used to create test inputs; defaults to
            ``DEFAULT_MAKE_IMAGES_KWARGS``.
        supports_pil: whether the transform pair should also be checked on PIL images.
        removed_params: legacy constructor parameters intentionally absent from the v2 transform.
        closeness_kwargs: tolerances forwarded to ``assert_close``; defaults to exact equality.
    """

    def __init__(
        self,
        prototype_cls,
        legacy_cls,
        # If no args_kwargs is passed, only the signature will be checked
        args_kwargs=(),
        make_images_kwargs=None,
        supports_pil=True,
        removed_params=(),
        closeness_kwargs=None,
    ):
        self.prototype_cls = prototype_cls
        self.legacy_cls = legacy_cls
        self.args_kwargs = args_kwargs
        self.make_images_kwargs = make_images_kwargs or DEFAULT_MAKE_IMAGES_KWARGS
        self.supports_pil = supports_pil
        self.removed_params = removed_params
        # rtol=0 / atol=0 means bit-exact comparison unless a config relaxes it.
        self.closeness_kwargs = closeness_kwargs or dict(rtol=0, atol=0)
71
72


73
74
75
76
# These are here since both the prototype and legacy transform need to be constructed with the same random parameters
LINEAR_TRANSFORMATION_MEAN = torch.rand(36)
LINEAR_TRANSFORMATION_MATRIX = torch.rand(2 * [LINEAR_TRANSFORMATION_MEAN.numel()])

77
78
# One ConsistencyConfig per legacy transform. test_automatic_coverage() below enforces that every
# public legacy transform appears here.
CONSISTENCY_CONFIGS = [
    ConsistencyConfig(
        v2_transforms.Normalize,
        legacy_transforms.Normalize,
        [
            ArgsKwargs(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ],
        supports_pil=False,
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.float]),
    ),
    ConsistencyConfig(
        v2_transforms.Resize,
        legacy_transforms.Resize,
        [
            NotScriptableArgsKwargs(32),
            ArgsKwargs([32]),
            ArgsKwargs((32, 29)),
            ArgsKwargs((31, 28), interpolation=v2_transforms.InterpolationMode.NEAREST),
            ArgsKwargs((30, 27), interpolation=PIL.Image.NEAREST),
            ArgsKwargs((35, 29), interpolation=PIL.Image.BILINEAR),
            NotScriptableArgsKwargs(31, max_size=32),
            ArgsKwargs([31], max_size=32),
            NotScriptableArgsKwargs(30, max_size=100),
            ArgsKwargs([31], max_size=32),
            ArgsKwargs((29, 32), antialias=False),
            ArgsKwargs((28, 31), antialias=True),
        ],
        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
        closeness_kwargs=dict(rtol=0, atol=1),
    ),
    ConsistencyConfig(
        v2_transforms.Resize,
        legacy_transforms.Resize,
        [
            ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC, antialias=True),
            ArgsKwargs((34, 25), interpolation=PIL.Image.BICUBIC, antialias=True),
        ],
        closeness_kwargs=dict(rtol=0, atol=21),
    ),
    ConsistencyConfig(
        v2_transforms.CenterCrop,
        legacy_transforms.CenterCrop,
        [
            ArgsKwargs(18),
            ArgsKwargs((18, 13)),
        ],
    ),
    ConsistencyConfig(
        v2_transforms.FiveCrop,
        legacy_transforms.FiveCrop,
        [
            ArgsKwargs(18),
            ArgsKwargs((18, 13)),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
    ),
    ConsistencyConfig(
        v2_transforms.TenCrop,
        legacy_transforms.TenCrop,
        [
            ArgsKwargs(18),
            ArgsKwargs((18, 13)),
            ArgsKwargs(18, vertical_flip=True),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
    ),
    ConsistencyConfig(
        v2_transforms.Pad,
        legacy_transforms.Pad,
        [
            NotScriptableArgsKwargs(3),
            ArgsKwargs([3]),
            ArgsKwargs([2, 3]),
            ArgsKwargs([3, 2, 1, 4]),
            NotScriptableArgsKwargs(5, fill=1, padding_mode="constant"),
            ArgsKwargs([5], fill=1, padding_mode="constant"),
            NotScriptableArgsKwargs(5, padding_mode="edge"),
            NotScriptableArgsKwargs(5, padding_mode="reflect"),
            NotScriptableArgsKwargs(5, padding_mode="symmetric"),
        ],
    ),
    *[
        ConsistencyConfig(
            v2_transforms.LinearTransformation,
            legacy_transforms.LinearTransformation,
            [
                ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX.to(matrix_dtype), LINEAR_TRANSFORMATION_MEAN.to(matrix_dtype)),
            ],
            # Make sure that the product of the height, width and number of channels matches the number of elements in
            # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
            make_images_kwargs=dict(
                DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"], dtypes=[image_dtype]
            ),
            supports_pil=False,
        )
        for matrix_dtype, image_dtype in [
            (torch.float32, torch.float32),
            (torch.float64, torch.float64),
            (torch.float32, torch.uint8),
            (torch.float64, torch.float32),
            (torch.float32, torch.float64),
        ]
    ],
    ConsistencyConfig(
        v2_transforms.Grayscale,
        legacy_transforms.Grayscale,
        [
            ArgsKwargs(num_output_channels=1),
            ArgsKwargs(num_output_channels=3),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
        # Use default tolerances of `torch.testing.assert_close`
        closeness_kwargs=dict(rtol=None, atol=None),
    ),
    ConsistencyConfig(
        v2_transforms.ConvertImageDtype,
        legacy_transforms.ConvertImageDtype,
        [
            ArgsKwargs(torch.float16),
            ArgsKwargs(torch.bfloat16),
            ArgsKwargs(torch.float32),
            ArgsKwargs(torch.float64),
            ArgsKwargs(torch.uint8),
        ],
        supports_pil=False,
        # Use default tolerances of `torch.testing.assert_close`
        closeness_kwargs=dict(rtol=None, atol=None),
    ),
    ConsistencyConfig(
        v2_transforms.ToPILImage,
        legacy_transforms.ToPILImage,
        [NotScriptableArgsKwargs()],
        make_images_kwargs=dict(
            color_spaces=[
                "GRAY",
                "GRAY_ALPHA",
                "RGB",
                "RGBA",
            ],
            extra_dims=[()],
        ),
        supports_pil=False,
    ),
    ConsistencyConfig(
        v2_transforms.Lambda,
        legacy_transforms.Lambda,
        [
            NotScriptableArgsKwargs(lambda image: image / 2),
        ],
        # Technically, this also supports PIL, but it is overkill to write a function here that supports tensor and PIL
        # images given that the transform does nothing but call it anyway.
        supports_pil=False,
    ),
    ConsistencyConfig(
        v2_transforms.RandomHorizontalFlip,
        legacy_transforms.RandomHorizontalFlip,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
        ],
    ),
    ConsistencyConfig(
        v2_transforms.RandomVerticalFlip,
        legacy_transforms.RandomVerticalFlip,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
        ],
    ),
    ConsistencyConfig(
        v2_transforms.RandomEqualize,
        legacy_transforms.RandomEqualize,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]),
    ),
    ConsistencyConfig(
        v2_transforms.RandomInvert,
        legacy_transforms.RandomInvert,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
        ],
    ),
    ConsistencyConfig(
        v2_transforms.RandomPosterize,
        legacy_transforms.RandomPosterize,
        [
            ArgsKwargs(p=0, bits=5),
            ArgsKwargs(p=1, bits=1),
            ArgsKwargs(p=1, bits=3),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]),
    ),
    ConsistencyConfig(
        v2_transforms.RandomSolarize,
        legacy_transforms.RandomSolarize,
        [
            ArgsKwargs(p=0, threshold=0.5),
            ArgsKwargs(p=1, threshold=0.3),
            ArgsKwargs(p=1, threshold=0.99),
        ],
    ),
    *[
        ConsistencyConfig(
            v2_transforms.RandomAutocontrast,
            legacy_transforms.RandomAutocontrast,
            [
                ArgsKwargs(p=0),
                ArgsKwargs(p=1),
            ],
            make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[dt]),
            closeness_kwargs=ckw,
        )
        for dt, ckw in [(torch.uint8, dict(atol=1, rtol=0)), (torch.float32, dict(rtol=None, atol=None))]
    ],
    ConsistencyConfig(
        v2_transforms.RandomAdjustSharpness,
        legacy_transforms.RandomAdjustSharpness,
        [
            ArgsKwargs(p=0, sharpness_factor=0.5),
            ArgsKwargs(p=1, sharpness_factor=0.2),
            ArgsKwargs(p=1, sharpness_factor=0.99),
        ],
        closeness_kwargs={"atol": 1e-6, "rtol": 1e-6},
    ),
    ConsistencyConfig(
        v2_transforms.RandomGrayscale,
        legacy_transforms.RandomGrayscale,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
        # Use default tolerances of `torch.testing.assert_close`
        closeness_kwargs=dict(rtol=None, atol=None),
    ),
    ConsistencyConfig(
        v2_transforms.RandomResizedCrop,
        legacy_transforms.RandomResizedCrop,
        [
            ArgsKwargs(16),
            ArgsKwargs(17, scale=(0.3, 0.7)),
            ArgsKwargs(25, ratio=(0.5, 1.5)),
            ArgsKwargs((31, 28), interpolation=v2_transforms.InterpolationMode.NEAREST),
            ArgsKwargs((31, 28), interpolation=PIL.Image.NEAREST),
            ArgsKwargs((29, 32), antialias=False),
            ArgsKwargs((28, 31), antialias=True),
        ],
        # atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
        closeness_kwargs=dict(rtol=0, atol=1),
    ),
    ConsistencyConfig(
        v2_transforms.RandomResizedCrop,
        legacy_transforms.RandomResizedCrop,
        [
            ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC, antialias=True),
            ArgsKwargs((33, 26), interpolation=PIL.Image.BICUBIC, antialias=True),
        ],
        closeness_kwargs=dict(rtol=0, atol=21),
    ),
    ConsistencyConfig(
        v2_transforms.RandomErasing,
        legacy_transforms.RandomErasing,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
            ArgsKwargs(p=1, scale=(0.3, 0.7)),
            ArgsKwargs(p=1, ratio=(0.5, 1.5)),
            ArgsKwargs(p=1, value=1),
            ArgsKwargs(p=1, value=(1, 2, 3)),
            ArgsKwargs(p=1, value="random"),
        ],
        supports_pil=False,
    ),
    ConsistencyConfig(
        v2_transforms.ColorJitter,
        legacy_transforms.ColorJitter,
        [
            ArgsKwargs(),
            ArgsKwargs(brightness=0.1),
            ArgsKwargs(brightness=(0.2, 0.3)),
            ArgsKwargs(contrast=0.4),
            ArgsKwargs(contrast=(0.5, 0.6)),
            ArgsKwargs(saturation=0.7),
            ArgsKwargs(saturation=(0.8, 0.9)),
            ArgsKwargs(hue=0.3),
            ArgsKwargs(hue=(-0.1, 0.2)),
            ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.5, hue=0.3),
        ],
        closeness_kwargs={"atol": 1e-5, "rtol": 1e-5},
    ),
    *[
        ConsistencyConfig(
            v2_transforms.ElasticTransform,
            legacy_transforms.ElasticTransform,
            [
                ArgsKwargs(),
                ArgsKwargs(alpha=20.0),
                ArgsKwargs(alpha=(15.3, 27.2)),
                ArgsKwargs(sigma=3.0),
                ArgsKwargs(sigma=(2.5, 3.9)),
                ArgsKwargs(interpolation=v2_transforms.InterpolationMode.NEAREST),
                ArgsKwargs(interpolation=v2_transforms.InterpolationMode.BICUBIC),
                ArgsKwargs(interpolation=PIL.Image.NEAREST),
                ArgsKwargs(interpolation=PIL.Image.BICUBIC),
                ArgsKwargs(fill=1),
            ],
            # ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
            make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(163, 163), (72, 333), (313, 95)], dtypes=[dt]),
            # We updated gaussian blur kernel generation with a faster and numerically more stable version
            # This brings float32 accumulation visible in elastic transform -> we need to relax consistency tolerance
            closeness_kwargs=ckw,
        )
        for dt, ckw in [(torch.uint8, {"rtol": 1e-1, "atol": 1}), (torch.float32, {"rtol": 1e-2, "atol": 1e-3})]
    ],
    ConsistencyConfig(
        v2_transforms.GaussianBlur,
        legacy_transforms.GaussianBlur,
        [
            ArgsKwargs(kernel_size=3),
            ArgsKwargs(kernel_size=(1, 5)),
            ArgsKwargs(kernel_size=3, sigma=0.7),
            ArgsKwargs(kernel_size=5, sigma=(0.3, 1.4)),
        ],
        closeness_kwargs={"rtol": 1e-5, "atol": 1e-5},
    ),
    ConsistencyConfig(
        v2_transforms.RandomAffine,
        legacy_transforms.RandomAffine,
        [
            ArgsKwargs(degrees=30.0),
            ArgsKwargs(degrees=(-20.0, 10.0)),
            ArgsKwargs(degrees=0.0, translate=(0.4, 0.6)),
            ArgsKwargs(degrees=0.0, scale=(0.3, 0.8)),
            ArgsKwargs(degrees=0.0, shear=13),
            ArgsKwargs(degrees=0.0, shear=(8, 17)),
            ArgsKwargs(degrees=0.0, shear=(4, 5, 4, 13)),
            ArgsKwargs(degrees=(-20.0, 10.0), translate=(0.4, 0.6), scale=(0.3, 0.8), shear=(4, 5, 4, 13)),
            ArgsKwargs(degrees=30.0, interpolation=v2_transforms.InterpolationMode.NEAREST),
            ArgsKwargs(degrees=30.0, interpolation=PIL.Image.NEAREST),
            ArgsKwargs(degrees=30.0, fill=1),
            ArgsKwargs(degrees=30.0, fill=(2, 3, 4)),
            ArgsKwargs(degrees=30.0, center=(0, 0)),
        ],
        removed_params=["fillcolor", "resample"],
    ),
    ConsistencyConfig(
        v2_transforms.RandomCrop,
        legacy_transforms.RandomCrop,
        [
            ArgsKwargs(12),
            ArgsKwargs((15, 17)),
            NotScriptableArgsKwargs(11, padding=1),
            ArgsKwargs(11, padding=[1]),
            ArgsKwargs((8, 13), padding=(2, 3)),
            ArgsKwargs((14, 9), padding=(0, 2, 1, 0)),
            ArgsKwargs(36, pad_if_needed=True),
            ArgsKwargs((7, 8), fill=1),
            NotScriptableArgsKwargs(5, fill=(1, 2, 3)),
            ArgsKwargs(12),
            NotScriptableArgsKwargs(15, padding=2, padding_mode="edge"),
            ArgsKwargs(17, padding=(1, 0), padding_mode="reflect"),
            ArgsKwargs(8, padding=(3, 0, 0, 1), padding_mode="symmetric"),
        ],
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(26, 26), (18, 33), (29, 22)]),
    ),
    ConsistencyConfig(
        v2_transforms.RandomPerspective,
        legacy_transforms.RandomPerspective,
        [
            ArgsKwargs(p=0),
            ArgsKwargs(p=1),
            ArgsKwargs(p=1, distortion_scale=0.3),
            ArgsKwargs(p=1, distortion_scale=0.2, interpolation=v2_transforms.InterpolationMode.NEAREST),
            ArgsKwargs(p=1, distortion_scale=0.2, interpolation=PIL.Image.NEAREST),
            ArgsKwargs(p=1, distortion_scale=0.1, fill=1),
            ArgsKwargs(p=1, distortion_scale=0.4, fill=(1, 2, 3)),
        ],
        closeness_kwargs={"atol": None, "rtol": None},
    ),
    ConsistencyConfig(
        v2_transforms.RandomRotation,
        legacy_transforms.RandomRotation,
        [
            ArgsKwargs(degrees=30.0),
            ArgsKwargs(degrees=(-20.0, 10.0)),
            ArgsKwargs(degrees=30.0, interpolation=v2_transforms.InterpolationMode.BILINEAR),
            ArgsKwargs(degrees=30.0, interpolation=PIL.Image.BILINEAR),
            ArgsKwargs(degrees=30.0, expand=True),
            ArgsKwargs(degrees=30.0, center=(0, 0)),
            ArgsKwargs(degrees=30.0, fill=1),
            ArgsKwargs(degrees=30.0, fill=(1, 2, 3)),
        ],
        removed_params=["resample"],
    ),
    ConsistencyConfig(
        v2_transforms.PILToTensor,
        legacy_transforms.PILToTensor,
    ),
    ConsistencyConfig(
        v2_transforms.ToTensor,
        legacy_transforms.ToTensor,
    ),
    ConsistencyConfig(
        v2_transforms.Compose,
        legacy_transforms.Compose,
    ),
    ConsistencyConfig(
        v2_transforms.RandomApply,
        legacy_transforms.RandomApply,
    ),
    ConsistencyConfig(
        v2_transforms.RandomChoice,
        legacy_transforms.RandomChoice,
    ),
    ConsistencyConfig(
        v2_transforms.RandomOrder,
        legacy_transforms.RandomOrder,
    ),
    ConsistencyConfig(
        v2_transforms.AugMix,
        legacy_transforms.AugMix,
    ),
    ConsistencyConfig(
        v2_transforms.AutoAugment,
        legacy_transforms.AutoAugment,
    ),
    ConsistencyConfig(
        v2_transforms.RandAugment,
        legacy_transforms.RandAugment,
    ),
    ConsistencyConfig(
        v2_transforms.TrivialAugmentWide,
        legacy_transforms.TrivialAugmentWide,
    ),
]


518
519
def test_automatic_coverage():
    """Fail if a public legacy transform class has no entry in CONSISTENCY_CONFIGS."""
    available = {
        name
        for name, obj in legacy_transforms.__dict__.items()
        if not name.startswith("_") and isinstance(obj, type) and not issubclass(obj, enum.Enum)
    }

    checked = {config.legacy_cls.__name__ for config in CONSISTENCY_CONFIGS}

    missing = available - checked
    if missing:
        raise AssertionError(
            f"The prototype transformations {sequence_to_str(sorted(missing), separate_last='and ')} "
            f"are not checked for consistency although a legacy counterpart exists."
        )


535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
@pytest.mark.parametrize("config", CONSISTENCY_CONFIGS, ids=lambda config: config.legacy_cls.__name__)
def test_signature_consistency(config):
    """Check that the v2 constructor signature is a backward-compatible superset of the legacy one."""
    legacy_params = dict(inspect.signature(config.legacy_cls).parameters)
    prototype_params = dict(inspect.signature(config.prototype_cls).parameters)

    # Parameters that were intentionally dropped in v2 are excluded from the comparison.
    for param in config.removed_params:
        legacy_params.pop(param, None)

    missing = legacy_params.keys() - prototype_params.keys()
    if missing:
        raise AssertionError(
            f"The prototype transform does not support the parameters "
            f"{sequence_to_str(sorted(missing), separate_last='and ')}, but the legacy transform does. "
            f"If that is intentional, e.g. pending deprecation, please add the parameters to the `removed_params` on "
            f"the `ConsistencyConfig`."
        )

    extra = prototype_params.keys() - legacy_params.keys()
    # Extra v2-only parameters are fine as long as they have a default, i.e. legacy call sites keep working.
    extra_without_default = {
        param
        for param in extra
        if prototype_params[param].default is inspect.Parameter.empty
        and prototype_params[param].kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
    }
    if extra_without_default:
        raise AssertionError(
            f"The prototype transform requires the parameters "
            f"{sequence_to_str(sorted(extra_without_default), separate_last='and ')}, but the legacy transform does "
            f"not. Please add a default value."
        )

    legacy_signature = list(legacy_params.keys())
    # Since we made sure that we don't have any extra parameters without default above, we clamp the prototype signature
    # to the same number of parameters as the legacy one
    prototype_signature = list(prototype_params.keys())[: len(legacy_signature)]

    assert prototype_signature == legacy_signature
572
573


574
575
576
def check_call_consistency(
    prototype_transform, legacy_transform, images=None, supports_pil=True, closeness_kwargs=None
):
    """Apply both transforms to the same inputs and assert their outputs are close.

    For every image the consistency is checked on three input kinds: plain tensors, ``datapoints.Image``,
    and (for 3D images when ``supports_pil``) PIL images. The RNG is reseeded before each call so random
    transforms draw the same parameters.
    """
    if images is None:
        images = make_images(**DEFAULT_MAKE_IMAGES_KWARGS)

    closeness_kwargs = closeness_kwargs or dict()

    for image in images:
        image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]"

        # Strip the datapoint subclass so the legacy transform sees a plain tensor.
        image_tensor = torch.Tensor(image)
        try:
            torch.manual_seed(0)
            output_legacy_tensor = legacy_transform(image_tensor)
        except Exception as exc:
            raise pytest.UsageError(
                f"Transforming a tensor image {image_repr} failed in the legacy transform with the "
                f"error above. This means that you need to specify the parameters passed to `make_images` through the "
                "`make_images_kwargs` of the `ConsistencyConfig`."
            ) from exc

        try:
            torch.manual_seed(0)
            output_prototype_tensor = prototype_transform(image_tensor)
        except Exception as exc:
            raise AssertionError(
                f"Transforming a tensor image with shape {image_repr} failed in the prototype transform with "
                f"the error above. This means there is a consistency bug either in `_get_params` or in the "
                f"`is_pure_tensor` path in `_transform`."
            ) from exc

        assert_close(
            output_prototype_tensor,
            output_legacy_tensor,
            msg=lambda msg: f"Tensor image consistency check failed with: \n\n{msg}",
            **closeness_kwargs,
        )

        try:
            torch.manual_seed(0)
            output_prototype_image = prototype_transform(image)
        except Exception as exc:
            raise AssertionError(
                f"Transforming a image datapoint with shape {image_repr} failed in the prototype transform with "
                f"the error above. This means there is a consistency bug either in `_get_params` or in the "
                f"`datapoints.Image` path in `_transform`."
            ) from exc

        assert_close(
            output_prototype_image,
            output_prototype_tensor,
            msg=lambda msg: f"Output for datapoint and tensor images is not equal: \n\n{msg}",
            **closeness_kwargs,
        )

        # PIL can only represent single images, i.e. 3D inputs without extra batch dims.
        if image.ndim == 3 and supports_pil:
            image_pil = to_pil_image(image)

            try:
                torch.manual_seed(0)
                output_legacy_pil = legacy_transform(image_pil)
            except Exception as exc:
                raise pytest.UsageError(
                    f"Transforming a PIL image with shape {image_repr} failed in the legacy transform with the "
                    f"error above. If this transform does not support PIL images, set `supports_pil=False` on the "
                    "`ConsistencyConfig`. "
                ) from exc

            try:
                torch.manual_seed(0)
                output_prototype_pil = prototype_transform(image_pil)
            except Exception as exc:
                raise AssertionError(
                    f"Transforming a PIL image with shape {image_repr} failed in the prototype transform with "
                    f"the error above. This means there is a consistency bug either in `_get_params` or in the "
                    f"`PIL.Image.Image` path in `_transform`."
                ) from exc

            assert_close(
                output_prototype_pil,
                output_legacy_pil,
                msg=lambda msg: f"PIL image consistency check failed with: \n\n{msg}",
                **closeness_kwargs,
            )
659
660


661
@pytest.mark.parametrize(
    ("config", "args_kwargs"),
    [
        pytest.param(
            config, args_kwargs, id=f"{config.legacy_cls.__name__}-{idx:0{len(str(len(config.args_kwargs)))}d}"
        )
        for config in CONSISTENCY_CONFIGS
        for idx, args_kwargs in enumerate(config.args_kwargs)
    ],
)
@pytest.mark.filterwarnings("ignore")
def test_call_consistency(config, args_kwargs):
    """Construct both transforms with the same parameters and compare their call behavior."""
    args, kwargs = args_kwargs

    try:
        legacy_transform = config.legacy_cls(*args, **kwargs)
    except Exception as exc:
        # A failing legacy constructor is a broken test parametrization, not a v2 bug.
        raise pytest.UsageError(
            f"Initializing the legacy transform failed with the error above. "
            f"Please correct the `ArgsKwargs({args_kwargs})` in the `ConsistencyConfig`."
        ) from exc

    try:
        prototype_transform = config.prototype_cls(*args, **kwargs)
    except Exception as exc:
        raise AssertionError(
            "Initializing the prototype transform failed with the error above. "
            "This means there is a consistency bug in the constructor."
        ) from exc

    check_call_consistency(
        prototype_transform,
        legacy_transform,
        images=make_images(**config.make_images_kwargs),
        supports_pil=config.supports_pil,
        closeness_kwargs=config.closeness_kwargs,
    )


700
701
702
703
704
705
706
707
708
# Shared parametrization for the `get_params` tests below: for each transform class with a
# `get_params` static method, look up its ConsistencyConfig and supply sample arguments.
get_params_parametrization = pytest.mark.parametrize(
    ("config", "get_params_args_kwargs"),
    [
        pytest.param(
            next(config for config in CONSISTENCY_CONFIGS if config.prototype_cls is transform_cls),
            get_params_args_kwargs,
            id=transform_cls.__name__,
        )
        for transform_cls, get_params_args_kwargs in [
            (v2_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])),
            (v2_transforms.RandomErasing, ArgsKwargs(make_image(), scale=(0.3, 0.7), ratio=(0.5, 1.5))),
            (v2_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)),
            (v2_transforms.ElasticTransform, ArgsKwargs(alpha=[15.3, 27.2], sigma=[2.5, 3.9], size=[17, 31])),
            (v2_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)),
            (
                v2_transforms.RandomAffine,
                ArgsKwargs(degrees=[-20.0, 10.0], translate=None, scale_ranges=None, shears=None, img_size=[15, 29]),
            ),
            (v2_transforms.RandomCrop, ArgsKwargs(make_image(size=(61, 47)), output_size=(19, 25))),
            (v2_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)),
            (v2_transforms.RandomRotation, ArgsKwargs(degrees=[-20.0, 10.0])),
            (v2_transforms.AutoAugment, ArgsKwargs(5)),
        ]
    ],
)
725
726


727
@get_params_parametrization
def test_get_params_alias(config, get_params_args_kwargs):
    """The v2 `get_params` must be the very same function object as the legacy one."""
    assert config.prototype_cls.get_params is config.legacy_cls.get_params

    if not config.args_kwargs:
        return
    args, kwargs = config.args_kwargs[0]
    legacy_transform = config.legacy_cls(*args, **kwargs)
    prototype_transform = config.prototype_cls(*args, **kwargs)

    # The alias must also hold when accessed through instances.
    assert prototype_transform.get_params is legacy_transform.get_params


740
@get_params_parametrization
def test_get_params_jit(config, get_params_args_kwargs):
    """`get_params` must be scriptable both as a class attribute and through an instance."""
    get_params_args, get_params_kwargs = get_params_args_kwargs

    torch.jit.script(config.prototype_cls.get_params)(*get_params_args, **get_params_kwargs)

    if not config.args_kwargs:
        return
    args, kwargs = config.args_kwargs[0]
    transform = config.prototype_cls(*args, **kwargs)

    torch.jit.script(transform.get_params)(*get_params_args, **get_params_kwargs)
752
753


754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
@pytest.mark.parametrize(
    ("config", "args_kwargs"),
    [
        pytest.param(
            config, args_kwargs, id=f"{config.legacy_cls.__name__}-{idx:0{len(str(len(config.args_kwargs)))}d}"
        )
        for config in CONSISTENCY_CONFIGS
        for idx, args_kwargs in enumerate(config.args_kwargs)
        if not isinstance(args_kwargs, NotScriptableArgsKwargs)
    ],
)
def test_jit_consistency(config, args_kwargs):
    """Script both transforms and compare their outputs on plain tensors."""
    args, kwargs = args_kwargs

    prototype_transform_eager = config.prototype_cls(*args, **kwargs)
    legacy_transform_eager = config.legacy_cls(*args, **kwargs)

    legacy_transform_scripted = torch.jit.script(legacy_transform_eager)
    prototype_transform_scripted = torch.jit.script(prototype_transform_eager)

    for test_image in make_images(**config.make_images_kwargs):
        # Drop the datapoint subclass; the scripted modules operate on plain tensors.
        plain_tensor = test_image.as_subclass(torch.Tensor)

        # Reseed before each call so random transforms draw identical parameters.
        torch.manual_seed(0)
        legacy_output = legacy_transform_scripted(plain_tensor)

        torch.manual_seed(0)
        prototype_output = prototype_transform_scripted(plain_tensor)

        assert_close(prototype_output, legacy_output, **config.closeness_kwargs)


786
787
788
789
790
791
792
793
794
795
class TestContainerTransforms:
    """
    Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for
    consistency automatically tests the wrapped transforms consistency.

    Instead of complicated mocking or creating custom transforms just for these tests, here we use deterministic ones
    that were already tested for consistency above.
    """

    def test_compose(self):
        t_v2 = v2_transforms.Compose([v2_transforms.Resize(256), v2_transforms.CenterCrop(224)])
        t_legacy = legacy_transforms.Compose([legacy_transforms.Resize(256), legacy_transforms.CenterCrop(224)])

        # atol=1: Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes.
        check_call_consistency(t_v2, t_legacy, closeness_kwargs=dict(rtol=0, atol=1))

    @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
    @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
    def test_random_apply(self, p, sequence_type):
        t_v2 = v2_transforms.RandomApply(
            sequence_type([v2_transforms.Resize(256), v2_transforms.CenterCrop(224)]),
            p=p,
        )
        t_legacy = legacy_transforms.RandomApply(
            sequence_type([legacy_transforms.Resize(256), legacy_transforms.CenterCrop(224)]),
            p=p,
        )

        # atol=1: Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes.
        check_call_consistency(t_v2, t_legacy, closeness_kwargs=dict(rtol=0, atol=1))

        if sequence_type is nn.ModuleList:
            # quick and dirty test that it is jit-scriptable
            scripted = torch.jit.script(t_v2)
            scripted(torch.rand(1, 3, 300, 300))

    # We can't test other values for `p` since the random parameter generation is different
    @pytest.mark.parametrize("probabilities", [(0, 1), (1, 0)])
    def test_random_choice(self, probabilities):
        # NOTE(review): the v2 container deliberately wraps a *legacy* CenterCrop below — presumably to
        # exercise that v2 containers accept arbitrary callables; confirm it wasn't meant to be the v2 one.
        t_v2 = v2_transforms.RandomChoice(
            [
                v2_transforms.Resize(256),
                legacy_transforms.CenterCrop(224),
            ],
            p=probabilities,
        )
        t_legacy = legacy_transforms.RandomChoice(
            [
                legacy_transforms.Resize(256),
                legacy_transforms.CenterCrop(224),
            ],
            p=probabilities,
        )

        # atol=1: Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes.
        check_call_consistency(t_v2, t_legacy, closeness_kwargs=dict(rtol=0, atol=1))
862
863


864
865
class TestToTensorTransforms:
    """Consistency of the tensor-conversion transforms against their legacy counterparts."""

    def test_pil_to_tensor(self):
        t_v2 = v2_transforms.PILToTensor()
        t_legacy = legacy_transforms.PILToTensor()

        for image in make_images(extra_dims=[()]):
            image_pil = to_pil_image(image)

            assert_equal(t_v2(image_pil), t_legacy(image_pil))

    def test_to_tensor(self):
        # ToTensor is deprecated in v2; constructing it must emit the deprecation warning.
        with pytest.warns(UserWarning, match=re.escape("The transform `ToTensor()` is deprecated")):
            t_v2 = v2_transforms.ToTensor()
        t_legacy = legacy_transforms.ToTensor()

        for image in make_images(extra_dims=[()]):
            image_pil = to_pil_image(image)
            image_numpy = np.array(image_pil)

            # Both PIL and numpy inputs must round-trip identically to the legacy transform.
            assert_equal(t_v2(image_pil), t_legacy(image_pil))
            assert_equal(t_v2(image_numpy), t_legacy(image_numpy))
885
886
887
888
889
890
891
892


class TestAATransforms:
    """Consistency of the auto-augmentation transforms (RandAugment, TrivialAugmentWide, AugMix, AutoAugment).

    The legacy and v2 implementations draw random numbers in a different order, so for the mocked tests we
    pre-compute a single sequence of values for a patched ``torch.randint`` that makes both implementations
    pick the same op (and, where applicable, the same magnitude and sign) on every iteration.
    """

    @pytest.mark.parametrize(
        "inpt",
        [
            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
            PIL.Image.new("RGB", (256, 256), 123),
            datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
        ],
    )
    @pytest.mark.parametrize(
        "interpolation",
        [
            v2_transforms.InterpolationMode.NEAREST,
            v2_transforms.InterpolationMode.BILINEAR,
            PIL.Image.NEAREST,
        ],
    )
    def test_randaug(self, inpt, interpolation, mocker):
        t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1)
        t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1)

        space_size = len(t._AUGMENTATION_SPACE)
        op_names = list(t._AUGMENTATION_SPACE.keys())
        randint_values = []
        for op_index in range(space_size):
            # Stable API, op_index random call
            randint_values.append(op_index)
            # Stable API, if signed there is another random call
            if t._AUGMENTATION_SPACE[op_names[op_index]][1]:
                randint_values.append(0)
            # New API, _get_random_item
            randint_values.append(op_index)
        randint_values = iter(randint_values)

        mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
        mocker.patch("torch.rand", return_value=1.0)

        for _ in range(space_size):
            expected_output = t_ref(inpt)
            output = t(inpt)

            assert_close(expected_output, output, atol=1, rtol=0.1)

    @pytest.mark.parametrize(
        "inpt",
        [
            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
            PIL.Image.new("RGB", (256, 256), 123),
            datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
        ],
    )
    @pytest.mark.parametrize(
        "interpolation",
        [
            v2_transforms.InterpolationMode.NEAREST,
            v2_transforms.InterpolationMode.BILINEAR,
            PIL.Image.NEAREST,
        ],
    )
    def test_trivial_aug(self, inpt, interpolation, mocker):
        t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation)
        t = v2_transforms.TrivialAugmentWide(interpolation=interpolation)

        space_size = len(t._AUGMENTATION_SPACE)
        op_names = list(t._AUGMENTATION_SPACE.keys())
        randint_values = []
        for op_index in range(space_size):
            # Stable API, op_index random call
            randint_values.append(op_index)
            # Stable API, random magnitude
            aug_op = t._AUGMENTATION_SPACE[op_names[op_index]]
            magnitudes = aug_op[0](2, 0, 0)
            if magnitudes is not None:
                randint_values.append(5)
            # Stable API, if signed there is another random call
            if aug_op[1]:
                randint_values.append(0)
            # New API, _get_random_item
            randint_values.append(op_index)
            # New API, random magnitude
            if magnitudes is not None:
                randint_values.append(5)

        randint_values = iter(randint_values)

        mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
        mocker.patch("torch.rand", return_value=1.0)

        for _ in range(space_size):
            expected_output = t_ref(inpt)
            output = t(inpt)

            assert_close(expected_output, output, atol=1, rtol=0.1)

    @pytest.mark.parametrize(
        "inpt",
        [
            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
            PIL.Image.new("RGB", (256, 256), 123),
            datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
        ],
    )
    @pytest.mark.parametrize(
        "interpolation",
        [
            v2_transforms.InterpolationMode.NEAREST,
            v2_transforms.InterpolationMode.BILINEAR,
            PIL.Image.NEAREST,
        ],
    )
    def test_augmix(self, inpt, interpolation, mocker):
        t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
        # Replace the dirichlet sampling with a deterministic softmax so both sides mix identically.
        t_ref._sample_dirichlet = lambda t: t.softmax(dim=-1)
        t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
        t._sample_dirichlet = lambda t: t.softmax(dim=-1)

        space_size = len(t._AUGMENTATION_SPACE)
        op_names = list(t._AUGMENTATION_SPACE.keys())
        randint_values = []
        for op_index in range(space_size):
            # Stable API, op_index random call
            randint_values.append(op_index)
            # Stable API, random magnitude
            aug_op = t._AUGMENTATION_SPACE[op_names[op_index]]
            magnitudes = aug_op[0](2, 0, 0)
            if magnitudes is not None:
                randint_values.append(5)
            # Stable API, if signed there is another random call
            if aug_op[1]:
                randint_values.append(0)
            # New API, _get_random_item
            randint_values.append(op_index)
            # New API, random magnitude
            if magnitudes is not None:
                randint_values.append(5)

        randint_values = iter(randint_values)

        mocker.patch("torch.randint", side_effect=lambda *arg, **kwargs: torch.tensor(next(randint_values)))
        mocker.patch("torch.rand", return_value=1.0)

        expected_output = t_ref(inpt)
        output = t(inpt)

        assert_equal(expected_output, output)

    @pytest.mark.parametrize(
        "inpt",
        [
            torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
            PIL.Image.new("RGB", (256, 256), 123),
            datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
        ],
    )
    @pytest.mark.parametrize(
        "interpolation",
        [
            v2_transforms.InterpolationMode.NEAREST,
            v2_transforms.InterpolationMode.BILINEAR,
            PIL.Image.NEAREST,
        ],
    )
    def test_aa(self, inpt, interpolation):
        # AutoAugment draws its random parameters in the same order in both APIs,
        # so seeding the global RNG identically is enough here — no mocking needed.
        aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet")
        t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation)
        t = v2_transforms.AutoAugment(aa_policy, interpolation=interpolation)

        torch.manual_seed(12)
        expected_output = t_ref(inpt)

        torch.manual_seed(12)
        output = t(inpt)

        assert_equal(expected_output, output)
1063
1064


1065
def import_transforms_from_references(reference):
    """Load ``references/<reference>/transforms.py`` from the repository root as a standalone module.

    The reference training scripts are not installed as a package, so we load the file directly
    with a ``SourceFileLoader`` instead of a regular import.
    """
    here = Path(__file__).parent
    project_root = here.parent

    transforms_path = project_root / "references" / reference / "transforms.py"
    loader = importlib.machinery.SourceFileLoader("transforms", str(transforms_path))
    spec = importlib.util.spec_from_loader("transforms", loader)
    module = importlib.util.module_from_spec(spec)
    loader.exec_module(module)
    return module
1076
1077
1078


# Reference transforms from the detection training scripts (references/detection/transforms.py),
# loaded dynamically so the tests below can compare the v2 transforms against them.
det_transforms = import_transforms_from_references("detection")
1079
1080
1081
1082
1083
1084
1085


class TestRefDetTransforms:
    """Compare v2 transforms against the detection reference-script transforms."""

    def make_datapoints(self, with_mask=True):
        """Yield (image, target) samples with the image as PIL, plain tensor, and datapoint."""
        size = (600, 800)
        num_objects = 22

        def make_label(extra_dims, categories):
            return torch.randint(categories, extra_dims, dtype=torch.int64)

        def make_target():
            # Fresh random target per sample; RNG call order (boxes, labels, masks) matches
            # one target construction per yielded image.
            target = {
                "boxes": make_bounding_boxes(
                    canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float
                ),
                "labels": make_label(extra_dims=(num_objects,), categories=80),
            }
            if with_mask:
                target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)
            return target

        pil_image = to_pil_image(make_image(size=size, color_space="RGB"))
        yield (pil_image, make_target())

        tensor_image = torch.Tensor(make_image(size=size, color_space="RGB", dtype=torch.float32))
        yield (tensor_image, make_target())

        datapoint_image = make_image(size=size, color_space="RGB", dtype=torch.float32)
        yield (datapoint_image, make_target())

    @pytest.mark.parametrize(
        "t_ref, t, data_kwargs",
        [
            (det_transforms.RandomHorizontalFlip(p=1.0), v2_transforms.RandomHorizontalFlip(p=1.0), {}),
            (
                det_transforms.RandomIoUCrop(),
                v2_transforms.Compose(
                    [
                        v2_transforms.RandomIoUCrop(),
                        v2_transforms.SanitizeBoundingBoxes(labels_getter=lambda sample: sample[1]["labels"]),
                    ]
                ),
                {"with_mask": False},
            ),
            (det_transforms.RandomZoomOut(), v2_transforms.RandomZoomOut(), {"with_mask": False}),
            (det_transforms.ScaleJitter((1024, 1024)), v2_transforms.ScaleJitter((1024, 1024), antialias=True), {}),
            (
                det_transforms.RandomShortestSize(
                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
                ),
                v2_transforms.RandomShortestSize(
                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
                ),
                {},
            ),
        ],
    )
    def test_transform(self, t_ref, t, data_kwargs):
        for dp in self.make_datapoints(**data_kwargs):

            # We should use prototype transform first as reference transform performs inplace target update
            torch.manual_seed(12)
            output = t(dp)

            torch.manual_seed(12)
            expected_output = t_ref(*dp)

            assert_equal(expected_output, output)
1157
1158
1159
1160
1161
1162
1163
1164
1165


# Reference transforms from the segmentation training scripts (references/segmentation/transforms.py),
# loaded dynamically so the tests below can compare the v2 transforms against them.
seg_transforms = import_transforms_from_references("segmentation")


# We need this transform for two reasons:
# 1. transforms.RandomCrop uses a different scheme to pad images and masks of insufficient size than its name
#    counterpart in the detection references. Thus, we cannot use it with `pad_if_needed=True`
# 2. transforms.Pad only supports a fixed padding, but the segmentation datasets don't have a fixed image size.
1166
class PadIfSmaller(v2_transforms.Transform):
    """Pad an input on the right/bottom up to a square of ``size`` if it is smaller.

    Mirrors the padding scheme of the segmentation reference scripts (see the comment above
    this class for why ``RandomCrop(pad_if_needed=True)`` / ``Pad`` cannot be used instead).
    """

    def __init__(self, size, fill=0):
        super().__init__()
        self.size = size
        # Normalize the fill argument into the per-type mapping the v2 machinery expects.
        self.fill = v2_transforms._geometry._setup_fill_arg(fill)

    def _get_params(self, sample):
        height, width = query_size(sample)
        # Right/bottom padding only, clamped at zero when the sample is already large enough.
        padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
        return dict(padding=padding, needs_padding=any(padding))

    def _transform(self, inpt, params):
        if not params["needs_padding"]:
            return inpt

        fill = _get_fill(self.fill, type(inpt))
        return prototype_F.pad(inpt, padding=params["padding"], fill=fill)
1184
1185
1186
1187


class TestRefSegTransforms:
    """Compare v2 transforms against the segmentation reference-script transforms."""

    def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8):
        """Yield ((image, mask), (ref_image, ref_mask)) pairs for each supported image container."""
        size = (256, 460)
        num_categories = 21

        conv_fns = []
        if supports_pil:
            conv_fns.append(to_pil_image)
        conv_fns.extend([torch.Tensor, lambda x: x])

        for conv_fn in conv_fns:
            datapoint_image = make_image(size=size, color_space="RGB", dtype=image_dtype)
            datapoint_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)

            dp = (conv_fn(datapoint_image), datapoint_mask)
            # The reference transforms expect PIL inputs (when supported); masks are always PIL.
            dp_ref = (
                to_pil_image(datapoint_image) if supports_pil else datapoint_image.as_subclass(torch.Tensor),
                to_pil_image(datapoint_mask),
            )

            yield dp, dp_ref

    def set_seed(self, seed=12):
        # Both torch and python RNGs are used by the transforms under test.
        torch.manual_seed(seed)
        random.seed(seed)

    def check(self, t, t_ref, data_kwargs=None):
        for dp, dp_ref in self.make_datapoints(**data_kwargs or dict()):

            self.set_seed()
            actual = actual_image, actual_mask = t(dp)

            self.set_seed()
            expected_image, expected_mask = t_ref(*dp_ref)
            # Bring the reference outputs into tensor form before comparing.
            if isinstance(actual_image, torch.Tensor) and not isinstance(expected_image, torch.Tensor):
                expected_image = legacy_F.pil_to_tensor(expected_image)
            expected_mask = legacy_F.pil_to_tensor(expected_mask).squeeze(0)
            expected = (expected_image, expected_mask)

            assert_equal(actual, expected)

    @pytest.mark.parametrize(
        ("t_ref", "t", "data_kwargs"),
        [
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=1.0),
                v2_transforms.RandomHorizontalFlip(p=1.0),
                dict(),
            ),
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=0.0),
                v2_transforms.RandomHorizontalFlip(p=0.0),
                dict(),
            ),
            (
                seg_transforms.RandomCrop(size=480),
                v2_transforms.Compose(
                    [
                        PadIfSmaller(size=480, fill={datapoints.Mask: 255, "others": 0}),
                        v2_transforms.RandomCrop(size=480),
                    ]
                ),
                dict(),
            ),
            (
                seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                v2_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                dict(supports_pil=False, image_dtype=torch.float),
            ),
        ],
    )
    def test_common(self, t_ref, t, data_kwargs):
        self.check(t, t_ref, data_kwargs)

1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271

@pytest.mark.parametrize(
    ("legacy_dispatcher", "name_only_params"),
    [
        (legacy_F.get_dimensions, {}),
        (legacy_F.get_image_size, {}),
        (legacy_F.get_image_num_channels, {}),
        (legacy_F.to_tensor, {}),
        (legacy_F.pil_to_tensor, {}),
        (legacy_F.convert_image_dtype, {}),
        (legacy_F.to_pil_image, {}),
        (legacy_F.normalize, {}),
        (legacy_F.resize, {"interpolation"}),
        (legacy_F.pad, {"padding", "fill"}),
        (legacy_F.crop, {}),
        (legacy_F.center_crop, {}),
        (legacy_F.resized_crop, {"interpolation"}),
        (legacy_F.hflip, {}),
        (legacy_F.perspective, {"startpoints", "endpoints", "fill", "interpolation"}),
        (legacy_F.vflip, {}),
        (legacy_F.five_crop, {}),
        (legacy_F.ten_crop, {}),
        (legacy_F.adjust_brightness, {}),
        (legacy_F.adjust_contrast, {}),
        (legacy_F.adjust_saturation, {}),
        (legacy_F.adjust_hue, {}),
        (legacy_F.adjust_gamma, {}),
        (legacy_F.rotate, {"center", "fill", "interpolation"}),
        (legacy_F.affine, {"angle", "translate", "center", "fill", "interpolation"}),
        (legacy_F.to_grayscale, {}),
        (legacy_F.rgb_to_grayscale, {}),
        (legacy_F.to_tensor, {}),
        (legacy_F.erase, {}),
        (legacy_F.gaussian_blur, {}),
        (legacy_F.invert, {}),
        (legacy_F.posterize, {}),
        (legacy_F.solarize, {}),
        (legacy_F.adjust_sharpness, {}),
        (legacy_F.autocontrast, {}),
        (legacy_F.equalize, {}),
        (legacy_F.elastic_transform, {"fill", "interpolation"}),
    ],
)
def test_dispatcher_signature_consistency(legacy_dispatcher, name_only_params):
    """Every legacy functional must have a v2 equivalent with a backward-compatible signature.

    ``name_only_params`` lists parameters whose annotation/default deliberately changed in v2
    (e.g. widened types); those are compared by name only.
    """
    legacy_signature = inspect.signature(legacy_dispatcher)
    # Drop the first (input) parameter — its name/annotation differs between the two APIs by design.
    legacy_params = list(legacy_signature.parameters.values())[1:]

    try:
        prototype_dispatcher = getattr(prototype_F, legacy_dispatcher.__name__)
    except AttributeError:
        raise AssertionError(
            f"Legacy dispatcher `F.{legacy_dispatcher.__name__}` has no prototype equivalent"
        ) from None

    prototype_signature = inspect.signature(prototype_dispatcher)
    prototype_params = list(prototype_signature.parameters.values())[1:]

    # Some dispatchers got extra parameters. This makes sure they have a default argument and thus are BC. We don't
    # need to check if parameters were added in the middle rather than at the end, since that will be caught by the
    # regular check below.
    prototype_params, new_prototype_params = (
        prototype_params[: len(legacy_params)],
        prototype_params[len(legacy_params) :],
    )
    for param in new_prototype_params:
        assert param.default is not param.empty

    # Some annotations were changed mostly to supersets of what was there before. Plus, some legacy dispatchers had no
    # annotations. In these cases we simply drop the annotation and default argument from the comparison
    for prototype_param, legacy_param in zip(prototype_params, legacy_params):
        if legacy_param.name in name_only_params:
            prototype_param._annotation = prototype_param._default = inspect.Parameter.empty
            legacy_param._annotation = legacy_param._default = inspect.Parameter.empty
        elif legacy_param.annotation is inspect.Parameter.empty:
            prototype_param._annotation = inspect.Parameter.empty

    assert prototype_params == legacy_params