test_transforms_v2_consistency.py 18.7 KB
Newer Older
1
2
import importlib.machinery
import importlib.util
3
import inspect
4
import random
5
import re
6
from pathlib import Path
7

8
import numpy as np
9
import pytest
10
11

import torch
12
import torchvision.transforms.v2 as v2_transforms
13
from common_utils import assert_close, assert_equal, set_rng_seed
14
from torchvision import transforms as legacy_transforms, tv_tensors
15

16
from torchvision.transforms import functional as legacy_F
17
from torchvision.transforms.v2 import functional as prototype_F
Nicolas Hug's avatar
Nicolas Hug committed
18
from torchvision.transforms.v2._utils import _get_fill, query_size
19
from torchvision.transforms.v2.functional import to_pil_image
20
21
22
23
24
25
26
27
from transforms_v2_legacy_utils import (
    ArgsKwargs,
    make_bounding_boxes,
    make_detection_mask,
    make_image,
    make_images,
    make_segmentation_mask,
)
28

29
# Keyword arguments forwarded to `make_images` whenever a caller does not supply its own: `check_call_consistency`
# uses them when `images` is None and `ConsistencyConfig` uses them when no `make_images_kwargs` is given.
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)])
30
31


Nicolas Hug's avatar
Nicolas Hug committed
32
33
34
35
36
37
@pytest.fixture(autouse=True)
def fix_rng_seed():
    """Call ``set_rng_seed(0)`` ahead of every test; ``autouse=True`` applies the fixture without it being requested."""
    set_rng_seed(0)
    # There is nothing to tear down: the bare yield only hands control over to the test.
    yield


38
39
40
41
42
43
44
45
46
class NotScriptableArgsKwargs(ArgsKwargs):
    """Marker subclass of ``ArgsKwargs`` for parameters that make a transform non-scriptable.

    Such parameters still work in eager mode and are therefore exercised by the eager tests; the JIT test filters
    them out with an ``isinstance`` check when it builds its parametrization.
    """


47
48
class ConsistencyConfig:
    """Pair a v2 ("prototype") transform class with its legacy counterpart plus the settings used to compare them.

    Attributes:
        prototype_cls: The v2 transform class.
        legacy_cls: The legacy transform class; its ``__name__`` is used in the generated test ids.
        args_kwargs: Iterable of ``(args, kwargs)`` pairs; every entry is unpacked and used to instantiate both
            classes in the consistency tests.
        make_images_kwargs: Keyword arguments forwarded to ``make_images``. A falsy value selects
            ``DEFAULT_MAKE_IMAGES_KWARGS``.
        supports_pil: Whether the PIL comparison should be attempted.
        removed_params: Stored as given; not consumed by any code visible in this module.
        closeness_kwargs: Keyword arguments forwarded to ``assert_close``. A falsy value selects
            ``dict(rtol=0, atol=0)``.
    """

    def __init__(
        self,
        prototype_cls,
        legacy_cls,
        # If no args_kwargs is passed, only the signature will be checked
        args_kwargs=(),
        make_images_kwargs=None,
        supports_pil=True,
        removed_params=(),
        closeness_kwargs=None,
    ):
        # Resolve the falsy-means-default options first so the attribute assignments below stay uniform.
        if not make_images_kwargs:
            make_images_kwargs = DEFAULT_MAKE_IMAGES_KWARGS
        if not closeness_kwargs:
            closeness_kwargs = dict(rtol=0, atol=0)

        self.prototype_cls = prototype_cls
        self.legacy_cls = legacy_cls
        self.args_kwargs = args_kwargs
        self.make_images_kwargs = make_images_kwargs
        self.supports_pil = supports_pil
        self.removed_params = removed_params
        self.closeness_kwargs = closeness_kwargs
66
67


68
69
70
71
# The prototype and the legacy transform must be constructed from the very same random parameters, hence the values
# are drawn once at import time and shared through these module-level constants.
LINEAR_TRANSFORMATION_MEAN = torch.rand(36)
# Square matrix whose side length equals the number of elements of the mean vector.
LINEAR_TRANSFORMATION_MATRIX = torch.rand(
    LINEAR_TRANSFORMATION_MEAN.numel(),
    LINEAR_TRANSFORMATION_MEAN.numel(),
)

# Collection the parametrized consistency tests below iterate over; it starts out empty here.
CONSISTENCY_CONFIGS = list()
73
74


75
76
77
def check_call_consistency(
    prototype_transform, legacy_transform, images=None, supports_pil=True, closeness_kwargs=None
):
    """Compare the outputs of a v2 ("prototype") transform with those of its legacy counterpart.

    For every image, and with ``torch.manual_seed(0)`` issued right before each transform call, this checks that

    1. the legacy and the prototype transform agree on ``torch.Tensor(image)``,
    2. the prototype transform gives the same result for the original ``image`` and for the tensor version, and
    3. the legacy and the prototype transform agree on the PIL version of the image. This step only runs for images
       with ``ndim == 3`` and when ``supports_pil`` is true.

    Args:
        prototype_transform: Callable v2 transform.
        legacy_transform: Callable legacy transform.
        images: Iterable of images. ``None`` means ``make_images(**DEFAULT_MAKE_IMAGES_KWARGS)``.
        supports_pil: Whether the PIL comparison should be attempted.
        closeness_kwargs: Extra keyword arguments for ``assert_close``; a falsy value means none.

    Raises:
        pytest.UsageError: If the legacy transform fails, which points at a problem in the test configuration.
        AssertionError: If the prototype transform fails. Output mismatches are reported by ``assert_close``.
    """

    def run_seeded(transform, inpt, error_cls, error_msg):
        # Seeding directly before the call puts the torch RNG into the same state for every invocation. Any failure
        # is re-raised as `error_cls`, chained to the original exception.
        try:
            torch.manual_seed(0)
            return transform(inpt)
        except Exception as exc:
            raise error_cls(error_msg) from exc

    if images is None:
        images = make_images(**DEFAULT_MAKE_IMAGES_KWARGS)
    if not closeness_kwargs:
        closeness_kwargs = {}

    for image in images:
        dtype_name = str(image.dtype).rsplit(".")[-1]
        image_repr = f"[{tuple(image.shape)}, {dtype_name}]"

        # --- 1. tensor input: legacy vs. prototype ---
        image_tensor = torch.Tensor(image)
        output_legacy_tensor = run_seeded(
            legacy_transform,
            image_tensor,
            pytest.UsageError,
            f"Transforming a tensor image {image_repr} failed in the legacy transform with the "
            "error above. This means that you need to specify the parameters passed to `make_images` through the "
            "`make_images_kwargs` of the `ConsistencyConfig`.",
        )
        output_prototype_tensor = run_seeded(
            prototype_transform,
            image_tensor,
            AssertionError,
            f"Transforming a tensor image with shape {image_repr} failed in the prototype transform with "
            "the error above. This means there is a consistency bug either in `_get_params` or in the "
            "`is_pure_tensor` path in `_transform`.",
        )
        assert_close(
            output_prototype_tensor,
            output_legacy_tensor,
            msg=lambda msg: f"Tensor image consistency check failed with: \n\n{msg}",
            **closeness_kwargs,
        )

        # --- 2. original image input vs. tensor input, prototype only ---
        output_prototype_image = run_seeded(
            prototype_transform,
            image,
            AssertionError,
            f"Transforming a image tv_tensor with shape {image_repr} failed in the prototype transform with "
            "the error above. This means there is a consistency bug either in `_get_params` or in the "
            "`tv_tensors.Image` path in `_transform`.",
        )
        assert_close(
            output_prototype_image,
            output_prototype_tensor,
            msg=lambda msg: f"Output for tv_tensor and tensor images is not equal: \n\n{msg}",
            **closeness_kwargs,
        )

        # --- 3. PIL input: legacy vs. prototype ---
        if image.ndim != 3 or not supports_pil:
            continue

        image_pil = to_pil_image(image)
        output_legacy_pil = run_seeded(
            legacy_transform,
            image_pil,
            pytest.UsageError,
            f"Transforming a PIL image with shape {image_repr} failed in the legacy transform with the "
            "error above. If this transform does not support PIL images, set `supports_pil=False` on the "
            "`ConsistencyConfig`. ",
        )
        output_prototype_pil = run_seeded(
            prototype_transform,
            image_pil,
            AssertionError,
            f"Transforming a PIL image with shape {image_repr} failed in the prototype transform with "
            "the error above. This means there is a consistency bug either in `_get_params` or in the "
            "`PIL.Image.Image` path in `_transform`.",
        )
        assert_close(
            output_prototype_pil,
            output_legacy_pil,
            msg=lambda msg: f"PIL image consistency check failed with: \n\n{msg}",
            **closeness_kwargs,
        )
160
161


162
@pytest.mark.parametrize(
    ("config", "args_kwargs"),
    [
        pytest.param(
            config,
            args_kwargs,
            # Zero-pad the index to the number of digits of the entry count of this config.
            id=f"{config.legacy_cls.__name__}-{str(idx).zfill(len(str(len(config.args_kwargs))))}",
        )
        for config in CONSISTENCY_CONFIGS
        for idx, args_kwargs in enumerate(config.args_kwargs)
    ],
)
@pytest.mark.filterwarnings("ignore")
def test_call_consistency(config, args_kwargs):
    """Instantiate both transforms of ``config`` with ``args_kwargs`` and compare their outputs in eager mode.

    A failing legacy constructor is reported as ``pytest.UsageError`` (the configuration needs fixing), whereas a
    failing prototype constructor is reported as ``AssertionError``.
    """
    positional, keyword = args_kwargs

    try:
        legacy_transform = config.legacy_cls(*positional, **keyword)
    except Exception as exc:
        legacy_msg = (
            f"Initializing the legacy transform failed with the error above. "
            f"Please correct the `ArgsKwargs({args_kwargs})` in the `ConsistencyConfig`."
        )
        raise pytest.UsageError(legacy_msg) from exc

    try:
        prototype_transform = config.prototype_cls(*positional, **keyword)
    except Exception as exc:
        prototype_msg = (
            "Initializing the prototype transform failed with the error above. "
            "This means there is a consistency bug in the constructor."
        )
        raise AssertionError(prototype_msg) from exc

    images = make_images(**config.make_images_kwargs)
    check_call_consistency(
        prototype_transform,
        legacy_transform,
        images=images,
        supports_pil=config.supports_pil,
        closeness_kwargs=config.closeness_kwargs,
    )


201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
@pytest.mark.parametrize(
    ("config", "args_kwargs"),
    [
        pytest.param(
            config,
            args_kwargs,
            # Zero-pad the index to the number of digits of the entry count of this config.
            id=f"{config.legacy_cls.__name__}-{str(idx).zfill(len(str(len(config.args_kwargs))))}",
        )
        for config in CONSISTENCY_CONFIGS
        for idx, args_kwargs in enumerate(config.args_kwargs)
        # Entries flagged as non-scriptable are only covered by the eager test.
        if not isinstance(args_kwargs, NotScriptableArgsKwargs)
    ],
)
def test_jit_consistency(config, args_kwargs):
    """Script both transforms of ``config`` and compare their outputs on plain tensors.

    Each scripted call is preceded by ``torch.manual_seed(0)``.
    """
    positional, keyword = args_kwargs

    # Construction and scripting order is kept as is: prototype before legacy for the eager instances, legacy before
    # prototype for scripting.
    eager_prototype = config.prototype_cls(*positional, **keyword)
    eager_legacy = config.legacy_cls(*positional, **keyword)

    scripted_legacy = torch.jit.script(eager_legacy)
    scripted_prototype = torch.jit.script(eager_prototype)

    for image in make_images(**config.make_images_kwargs):
        plain_tensor = image.as_subclass(torch.Tensor)

        torch.manual_seed(0)
        legacy_output = scripted_legacy(plain_tensor)

        torch.manual_seed(0)
        prototype_output = scripted_prototype(plain_tensor)

        assert_close(prototype_output, legacy_output, **config.closeness_kwargs)


233
234
class TestToTensorTransforms:
    """Compare the v2 ``PILToTensor`` / ``ToTensor`` transforms with their legacy counterparts."""

    def test_pil_to_tensor(self):
        """Both ``PILToTensor`` variants must return equal results for PIL inputs."""
        v2_transform = v2_transforms.PILToTensor()
        v1_transform = legacy_transforms.PILToTensor()

        for image in make_images(extra_dims=[()]):
            pil_input = to_pil_image(image)
            assert_equal(v2_transform(pil_input), v1_transform(pil_input))

    def test_to_tensor(self):
        """Both ``ToTensor`` variants must return equal results for PIL and NumPy inputs."""
        # Constructing the v2 variant is expected to emit this deprecation warning.
        expected_warning = re.escape("The transform `ToTensor()` is deprecated")
        with pytest.warns(UserWarning, match=expected_warning):
            v2_transform = v2_transforms.ToTensor()
        v1_transform = legacy_transforms.ToTensor()

        for image in make_images(extra_dims=[()]):
            pil_input = to_pil_image(image)
            numpy_input = np.array(pil_input)

            for inpt in (pil_input, numpy_input):
                assert_equal(v2_transform(inpt), v1_transform(inpt))
254
255


256
def import_transforms_from_references(reference):
    """Load ``references/<reference>/transforms.py`` as a module and return it.

    The ``references`` directory is looked up in the parent of the directory that contains this file. The module is
    executed from its source path under the name ``"transforms"``.

    Args:
        reference: Name of the sub-directory of ``references`` to load from, e.g. ``"detection"``.

    Returns:
        The executed module object.
    """
    project_root = Path(__file__).parent.parent
    source_path = project_root / "references" / reference / "transforms.py"

    module_name = "transforms"
    loader = importlib.machinery.SourceFileLoader(module_name, str(source_path))
    spec = importlib.util.spec_from_loader(module_name, loader)
    module = importlib.util.module_from_spec(spec)
    loader.exec_module(module)
    return module
267
268
269


# Detection reference transforms, loaded at import time from `references/detection/transforms.py`.
det_transforms = import_transforms_from_references("detection")
270
271
272


class TestRefDetTransforms:
    """Compare v2 transforms with the transforms of the detection references."""

    def make_tv_tensors(self, with_mask=True):
        """Yield three ``(image, target)`` samples: a PIL image, a ``torch.Tensor`` image and an unconverted image.

        Every target holds ``"boxes"`` and ``"labels"`` and, if ``with_mask`` is true, also ``"masks"``. A fresh
        image and a fresh target are created for each sample, the image always first.
        """
        size = (600, 800)
        num_objects = 22

        def make_label(extra_dims, categories):
            return torch.randint(categories, extra_dims, dtype=torch.int64)

        def make_target():
            target = {
                "boxes": make_bounding_boxes(
                    canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float
                ),
                "labels": make_label(extra_dims=(num_objects,), categories=80),
            }
            if with_mask:
                target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)
            return target

        # One factory per input flavour; each is only invoked when the generator reaches it.
        image_factories = (
            lambda: to_pil_image(make_image(size=size, color_space="RGB")),
            lambda: torch.Tensor(make_image(size=size, color_space="RGB", dtype=torch.float32)),
            lambda: make_image(size=size, color_space="RGB", dtype=torch.float32),
        )
        for make_input in image_factories:
            image = make_input()
            yield (image, make_target())

    @pytest.mark.parametrize(
        "t_ref, t, data_kwargs",
        [
            (det_transforms.RandomHorizontalFlip(p=1.0), v2_transforms.RandomHorizontalFlip(p=1.0), {}),
            (
                det_transforms.RandomIoUCrop(),
                v2_transforms.Compose(
                    [
                        v2_transforms.RandomIoUCrop(),
                        v2_transforms.SanitizeBoundingBoxes(labels_getter=lambda sample: sample[1]["labels"]),
                    ]
                ),
                {"with_mask": False},
            ),
            (det_transforms.RandomZoomOut(), v2_transforms.RandomZoomOut(), {"with_mask": False}),
            (det_transforms.ScaleJitter((1024, 1024)), v2_transforms.ScaleJitter((1024, 1024), antialias=True), {}),
            (
                det_transforms.RandomShortestSize(
                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
                ),
                v2_transforms.RandomShortestSize(
                    min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
                ),
                {},
            ),
        ],
    )
    def test_transform(self, t_ref, t, data_kwargs):
        """The v2 transform ``t`` and the reference transform ``t_ref`` must agree on every generated sample."""
        for sample in self.make_tv_tensors(**data_kwargs):
            image, target = sample

            # We should use prototype transform first as reference transform performs inplace target update
            torch.manual_seed(12)
            actual = t(sample)

            torch.manual_seed(12)
            expected = t_ref(image, target)

            assert_equal(expected, actual)
348
349
350
351
352
353
354
355
356


# Segmentation reference transforms, loaded at import time from `references/segmentation/transforms.py`.
seg_transforms = import_transforms_from_references("segmentation")


# We need this transform for two reasons:
# 1. transforms.RandomCrop uses a different scheme to pad images and masks of insufficient size than its name
#    counterpart in the detection references. Thus, we cannot use it with `pad_if_needed=True`
# 2. transforms.Pad only supports a fixed padding, but the segmentation datasets don't have a fixed image size.
357
class PadIfSmaller(v2_transforms.Transform):
    """Pad inputs whose height or width is below ``size``; larger inputs pass through unchanged.

    Args:
        size: Minimum height and width.
        fill: Fill specification, normalized through ``_setup_fill_arg`` and resolved per input type in
            ``_transform``.
    """

    def __init__(self, size, fill=0):
        super().__init__()
        self.size = size
        self.fill = v2_transforms._geometry._setup_fill_arg(fill)

    def _get_params(self, sample):
        """Compute the padding for ``sample`` and whether any padding is required at all."""
        height, width = query_size(sample)
        missing_width = max(self.size - width, 0)
        missing_height = max(self.size - height, 0)
        # The first two entries are always zero; only the last two can grow the input.
        return {
            "padding": [0, 0, missing_width, missing_height],
            "needs_padding": bool(missing_width or missing_height),
        }

    def _transform(self, inpt, params):
        """Pad ``inpt`` according to ``params``, or return it untouched when no padding is needed."""
        if params["needs_padding"]:
            fill = _get_fill(self.fill, type(inpt))
            return prototype_F.pad(inpt, padding=params["padding"], fill=fill)
        return inpt
375
376
377


class TestRefSegTransforms:
    """Compare v2 transforms with the transforms of the segmentation references."""

    def make_tv_tensors(self, supports_pil=True, image_dtype=torch.uint8):
        """Yield ``(dp, dp_ref)`` pairs, one per input flavour of the image.

        ``dp`` is the sample for the v2 transform: the converted image together with the mask. ``dp_ref`` is the
        sample for the reference transform: a PIL image (a plain tensor if ``supports_pil`` is false) together with
        a PIL mask.
        """
        size = (256, 460)
        num_categories = 21

        # PIL conversion comes first when supported, followed by plain tensor and the unconverted image.
        conv_fns = [to_pil_image] if supports_pil else []
        conv_fns += [torch.Tensor, lambda x: x]

        for conv_fn in conv_fns:
            tv_tensor_image = make_image(size=size, color_space="RGB", dtype=image_dtype)
            tv_tensor_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)

            dp = (conv_fn(tv_tensor_image), tv_tensor_mask)

            if supports_pil:
                reference_image = to_pil_image(tv_tensor_image)
            else:
                reference_image = tv_tensor_image.as_subclass(torch.Tensor)
            dp_ref = (reference_image, to_pil_image(tv_tensor_mask))

            yield dp, dp_ref

    def set_seed(self, seed=12):
        """Seed both ``torch`` and the stdlib ``random`` module with ``seed``."""
        torch.manual_seed(seed)
        random.seed(seed)

    def check(self, t, t_ref, data_kwargs=None):
        """Run ``t`` and ``t_ref`` on every generated sample under the same seed and compare the results."""
        for dp, dp_ref in self.make_tv_tensors(**(data_kwargs or {})):
            self.set_seed()
            actual = t(dp)
            # Unpacking also ensures the v2 output is a pair.
            actual_image, _ = actual

            self.set_seed()
            expected_image, expected_mask = t_ref(*dp_ref)
            # Bring the reference output into tensor form where the v2 output is a tensor.
            reference_is_tensor = isinstance(expected_image, torch.Tensor)
            if isinstance(actual_image, torch.Tensor) and not reference_is_tensor:
                expected_image = legacy_F.pil_to_tensor(expected_image)
            expected_mask = legacy_F.pil_to_tensor(expected_mask).squeeze(0)

            assert_equal(actual, (expected_image, expected_mask))

    @pytest.mark.parametrize(
        ("t_ref", "t", "data_kwargs"),
        [
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=1.0),
                v2_transforms.RandomHorizontalFlip(p=1.0),
                {},
            ),
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=0.0),
                v2_transforms.RandomHorizontalFlip(p=0.0),
                {},
            ),
            (
                seg_transforms.RandomCrop(size=480),
                v2_transforms.Compose(
                    [
                        PadIfSmaller(size=480, fill={tv_tensors.Mask: 255, "others": 0}),
                        v2_transforms.RandomCrop(size=480),
                    ]
                ),
                {},
            ),
            (
                seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                v2_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                {"supports_pil": False, "image_dtype": torch.float},
            ),
        ],
    )
    def test_common(self, t_ref, t, data_kwargs):
        """Parametrized entry point that delegates to ``check``."""
        self.check(t, t_ref, data_kwargs)

451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500

@pytest.mark.parametrize(
    ("legacy_dispatcher", "name_only_params"),
    [
        (legacy_F.get_dimensions, {}),
        (legacy_F.get_image_size, {}),
        (legacy_F.get_image_num_channels, {}),
        (legacy_F.to_tensor, {}),
        (legacy_F.pil_to_tensor, {}),
        (legacy_F.convert_image_dtype, {}),
        (legacy_F.to_pil_image, {}),
        (legacy_F.to_grayscale, {}),
        (legacy_F.rgb_to_grayscale, {}),
    ],
)
def test_dispatcher_signature_consistency(legacy_dispatcher, name_only_params):
    """Check that the v2 dispatcher named like ``legacy_dispatcher`` has a backward-compatible signature.

    The first parameter of both signatures is excluded from the comparison. Parameters the v2 dispatcher has beyond
    the legacy ones must carry a default value; the remaining ones must be equal to the legacy parameters.

    Args:
        legacy_dispatcher: Function from the legacy functional namespace.
        name_only_params: Names of parameters for which annotation and default are ignored, i.e. only name and kind
            are compared.
    """
    legacy_signature = inspect.signature(legacy_dispatcher)
    legacy_params = list(legacy_signature.parameters.values())[1:]

    try:
        prototype_dispatcher = getattr(prototype_F, legacy_dispatcher.__name__)
    except AttributeError:
        raise AssertionError(
            f"Legacy dispatcher `F.{legacy_dispatcher.__name__}` has no prototype equivalent"
        ) from None

    prototype_signature = inspect.signature(prototype_dispatcher)
    prototype_params = list(prototype_signature.parameters.values())[1:]

    # Some dispatchers got extra parameters. This makes sure they have a default argument and thus are BC. We don't
    # need to check if parameters were added in the middle rather than at the end, since that will be caught by the
    # regular check below.
    prototype_params, new_prototype_params = (
        prototype_params[: len(legacy_params)],
        prototype_params[len(legacy_params) :],
    )
    for param in new_prototype_params:
        assert param.default is not param.empty

    # Some annotations were changed mostly to supersets of what was there before. Plus, some legacy dispatchers had no
    # annotations. In these cases we simply drop the annotation and default argument from the comparison.
    # The entries of the two local lists are swapped for normalized copies created through the public
    # `Parameter.replace`, rather than writing to private attributes of the `Parameter` objects. Replacing by index
    # keeps the list lengths intact, so a length mismatch still fails the final assertion.
    empty = inspect.Parameter.empty
    for idx, (prototype_param, legacy_param) in enumerate(zip(prototype_params, legacy_params)):
        if legacy_param.name in name_only_params:
            prototype_params[idx] = prototype_param.replace(annotation=empty, default=empty)
            legacy_params[idx] = legacy_param.replace(annotation=empty, default=empty)
        elif legacy_param.annotation is empty:
            prototype_params[idx] = prototype_param.replace(annotation=empty)

    assert prototype_params == legacy_params