test_extended_models.py 17.7 KB
Newer Older
1
import copy
Philip Meier's avatar
Philip Meier committed
2
import os
3
import pickle
4
5

import pytest
6
import test_models as TM
7
import torch
Nicolas Hug's avatar
Nicolas Hug committed
8
from common_extended_utils import get_file_size_mb, get_ops
9
from torchvision import models
10
from torchvision.models import get_model_weights, Weights, WeightsEnum
11
from torchvision.models._utils import handle_legacy_interface
12
from torchvision.models.detection.backbone_utils import mobilenet_backbone, resnet_fpn_backbone
13

14
15
16
17
# Marker for the expensive "extended" checks below: such tests are skipped
# unless the environment explicitly opts in via PYTORCH_TEST_WITH_EXTENDED=1.
run_if_test_with_extended = pytest.mark.skipif(
    os.getenv("PYTORCH_TEST_WITH_EXTENDED", "0") != "1",
    reason="Extended tests are disabled by default. Set PYTORCH_TEST_WITH_EXTENDED=1 to run them.",
)
18
19


20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
@pytest.mark.parametrize(
    "name, model_class",
    [
        ("resnet50", models.ResNet),
        ("retinanet_resnet50_fpn_v2", models.detection.RetinaNet),
        ("raft_large", models.optical_flow.RAFT),
        ("quantized_resnet50", models.quantization.QuantizableResNet),
        ("lraspp_mobilenet_v3_large", models.segmentation.LRASPP),
        ("mvit_v1_b", models.video.MViT),
    ],
)
def test_get_model(name, model_class):
    """Looking a model up by name must construct the expected architecture."""
    instance = models.get_model(name)
    assert isinstance(instance, model_class)


35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
@pytest.mark.parametrize(
    "name, model_fn",
    [
        ("resnet50", models.resnet50),
        ("retinanet_resnet50_fpn_v2", models.detection.retinanet_resnet50_fpn_v2),
        ("raft_large", models.optical_flow.raft_large),
        ("quantized_resnet50", models.quantization.resnet50),
        ("lraspp_mobilenet_v3_large", models.segmentation.lraspp_mobilenet_v3_large),
        ("mvit_v1_b", models.video.mvit_v1_b),
    ],
)
def test_get_model_builder(name, model_fn):
    """Resolving a name must return the canonical builder function."""
    resolved = models.get_model_builder(name)
    assert resolved == model_fn


50
51
52
53
54
55
56
57
58
59
60
61
62
@pytest.mark.parametrize(
    "name, weight",
    [
        ("resnet50", models.ResNet50_Weights),
        ("retinanet_resnet50_fpn_v2", models.detection.RetinaNet_ResNet50_FPN_V2_Weights),
        ("raft_large", models.optical_flow.Raft_Large_Weights),
        ("quantized_resnet50", models.quantization.ResNet50_QuantizedWeights),
        ("lraspp_mobilenet_v3_large", models.segmentation.LRASPP_MobileNet_V3_Large_Weights),
        ("mvit_v1_b", models.video.MViT_V1_B_Weights),
    ],
)
def test_get_model_weights(name, weight):
    """Weights-enum lookup by model name must return the matching enum class."""
    resolved = models.get_model_weights(name)
    assert resolved == weight
63
64


65
66
67
68
69
70
71
72
73
74
75
76
77
@pytest.mark.parametrize("copy_fn", [copy.copy, copy.deepcopy])
@pytest.mark.parametrize(
    "name",
    [
        "resnet50",
        "retinanet_resnet50_fpn_v2",
        "raft_large",
        "quantized_resnet50",
        "lraspp_mobilenet_v3_large",
        "mvit_v1_b",
    ],
)
def test_weights_copyable(copy_fn, name):
    """(Deep-)copying a weights enum member must return the very same object.

    That copying is an identity operation is the default behavior of enums:
    https://docs.python.org/3/howto/enum.html#enum-members-aka-instances
    Should identity ever have to be dropped, checking equality (`==`) instead
    would still be sufficient -- and even preferable -- for our use case.
    """
    for weights in models.get_model_weights(name):
        duplicate = copy_fn(weights)
        assert duplicate is weights


@pytest.mark.parametrize(
    "name",
    [
        "resnet50",
        "retinanet_resnet50_fpn_v2",
        "raft_large",
        "quantized_resnet50",
        "lraspp_mobilenet_v3_large",
        "mvit_v1_b",
    ],
)
def test_weights_deserializable(name):
    """A pickle round-trip of a weights enum member must return the same object.

    That deserialization is an identity operation is the default behavior of
    enums: https://docs.python.org/3/howto/enum.html#enum-members-aka-instances
    Should identity ever have to be dropped, checking equality (`==`) instead
    would still be sufficient -- and even preferable -- for our use case.
    """
    for weights in models.get_model_weights(name):
        roundtripped = pickle.loads(pickle.dumps(weights))
        assert roundtripped is weights
104
105


106
107
108
109
110
111
112
113
def get_models_from_module(module):
    """Return the `__name__`s of all public model builder callables in `module`.

    A builder is any callable whose attribute name starts with a lowercase
    letter and that is not part of the `models._api` registration helpers.
    """
    builder_names = []
    for attr_name, attr in module.__dict__.items():
        if not callable(attr):
            continue
        # Public builders follow the lowercase naming convention.
        if not attr_name[0].islower() or attr_name[0] == "_":
            continue
        # Skip the registration/lookup helpers re-exported from models._api.
        if attr_name in models._api.__all__:
            continue
        builder_names.append(attr.__name__)
    return builder_names


114
115
116
117
118
119
120
121
122
@pytest.mark.parametrize(
    "module", [models, models.detection, models.quantization, models.segmentation, models.video, models.optical_flow]
)
def test_list_models(module):
    """list_models() must agree with a manual scan of the module's builders."""
    expected = set(get_models_from_module(module))
    # Quantized builders are listed with a "quantized_" prefix; strip it so the
    # names line up with the builder function names.
    listed = {entry.replace("quantized_", "") for entry in models.list_models(module)}

    assert listed
    assert expected == listed
123
124


125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
@pytest.mark.parametrize(
    "include_filters",
    [
        None,
        [],
        (),
        "",
        "*resnet*",
        ["*alexnet*"],
        "*not-existing-model-for-test?",
        ["*resnet*", "*alexnet*"],
        ["*resnet*", "*alexnet*", "*not-existing-model-for-test?"],
        ("*resnet*", "*alexnet*"),
        set(["*resnet*", "*alexnet*"]),
    ],
)
@pytest.mark.parametrize(
    "exclude_filters",
    [
        None,
        [],
        (),
        "",
        "*resnet*",
        ["*alexnet*"],
        ["*not-existing-model-for-test?"],
        ["resnet34", "*not-existing-model-for-test?"],
        ["resnet34", "*resnet1*"],
        ("resnet34", "*resnet1*"),
        set(["resnet34", "*resnet1*"]),
    ],
)
def test_list_models_filters(include_filters, exclude_filters):
    """list_models() include/exclude globbing must match a naive substring filter."""
    actual = set(models.list_models(models, include=include_filters, exclude=exclude_filters))
    classification_models = set(get_models_from_module(models))

    # Normalize bare string filters into one-element lists.
    include_filters = [include_filters] if isinstance(include_filters, str) else include_filters
    exclude_filters = [exclude_filters] if isinstance(exclude_filters, str) else exclude_filters

    if include_filters:
        expected = set()
        for pattern in include_filters:
            # The test patterns only use leading/trailing wildcards, so the
            # stripped core can be matched as a plain substring.
            needle = pattern.strip("*?")
            expected |= {name for name in classification_models if needle in name}
    else:
        expected = classification_models

    if exclude_filters:
        for pattern in exclude_filters:
            needle = pattern.strip("*?")
            if needle:
                expected -= {name for name in classification_models if needle in name}

    assert expected == actual


184
@pytest.mark.parametrize(
    "name, weight",
    [
        ("ResNet50_Weights.IMAGENET1K_V1", models.ResNet50_Weights.IMAGENET1K_V1),
        ("ResNet50_Weights.DEFAULT", models.ResNet50_Weights.IMAGENET1K_V2),
        (
            "ResNet50_QuantizedWeights.DEFAULT",
            models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2,
        ),
        (
            "ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1",
            models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1,
        ),
    ],
)
def test_get_weight(name, weight):
    """A fully qualified weight name (including DEFAULT aliases) must resolve
    to the corresponding enum member."""
    resolved = models.get_weight(name)
    assert resolved == weight
201
202


203
204
@pytest.mark.parametrize(
    "model_fn",
    TM.list_model_fns(models)
    + TM.list_model_fns(models.detection)
    + TM.list_model_fns(models.quantization)
    + TM.list_model_fns(models.segmentation)
    + TM.list_model_fns(models.video)
    + TM.list_model_fns(models.optical_flow),
)
def test_naming_conventions(model_fn):
    """Every builder must expose a weights enum; non-empty enums need DEFAULT."""
    weights_enum = get_model_weights(model_fn)
    assert weights_enum is not None
    if len(weights_enum) > 0:
        assert hasattr(weights_enum, "DEFAULT")
216
217


218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
# (height, width) input resolutions for the detection models. Used by
# test_schema_meta_validation to size the inputs when counting ops via
# get_ops(), since detection models do not use the default input size.
detection_models_input_dims = {
    "fasterrcnn_mobilenet_v3_large_320_fpn": (320, 320),
    "fasterrcnn_mobilenet_v3_large_fpn": (800, 800),
    "fasterrcnn_resnet50_fpn": (800, 800),
    "fasterrcnn_resnet50_fpn_v2": (800, 800),
    "fcos_resnet50_fpn": (800, 800),
    "keypointrcnn_resnet50_fpn": (1333, 1333),
    "maskrcnn_resnet50_fpn": (800, 800),
    "maskrcnn_resnet50_fpn_v2": (800, 800),
    "retinanet_resnet50_fpn": (800, 800),
    "retinanet_resnet50_fpn_v2": (800, 800),
    "ssd300_vgg16": (300, 300),
    "ssdlite320_mobilenet_v3_large": (320, 320),
}


234
235
@pytest.mark.parametrize(
    "model_fn",
    TM.list_model_fns(models)
    + TM.list_model_fns(models.detection)
    + TM.list_model_fns(models.quantization)
    + TM.list_model_fns(models.segmentation)
    + TM.list_model_fns(models.video)
    + TM.list_model_fns(models.optical_flow),
)
@run_if_test_with_extended
def test_schema_meta_validation(model_fn):
    """Validate the meta-data schema of every pre-trained weight of `model_fn`.

    Three kinds of problems are collected and asserted at the end:
    - schema violations (missing mandatory fields / unknown fields),
    - meta values that disagree with measurement (num_params, _ops, _file_size),
    - enum member names that are not uppercase.
    """
    if model_fn.__name__ == "maskrcnn_resnet50_fpn_v2":
        pytest.skip(reason="FIXME https://github.com/pytorch/vision/issues/7349")

    # list of all possible supported high-level fields for weights meta-data
    permitted_fields = {
        "backend",
        "categories",
        "keypoint_names",
        "license",
        "_metrics",
        "min_size",
        "min_temporal_size",
        "num_params",
        "recipe",
        "unquantized",
        "_docs",
        "_ops",
        "_file_size",
    }
    # mandatory fields for each computer vision task
    classification_fields = {"categories", ("_metrics", "ImageNet-1K", "acc@1"), ("_metrics", "ImageNet-1K", "acc@5")}
    defaults = {
        "all": {"_metrics", "min_size", "num_params", "recipe", "_docs", "_file_size", "_ops"},
        "models": classification_fields,
        "detection": {"categories", ("_metrics", "COCO-val2017", "box_map")},
        "quantization": classification_fields | {"backend", "unquantized"},
        "segmentation": {
            "categories",
            ("_metrics", "COCO-val2017-VOC-labels", "miou"),
            ("_metrics", "COCO-val2017-VOC-labels", "pixel_acc"),
        },
        "video": {"categories", ("_metrics", "Kinetics-400", "acc@1"), ("_metrics", "Kinetics-400", "acc@5")},
        "optical_flow": set(),
    }
    model_name = model_fn.__name__
    # e.g. "torchvision.models.detection.retinanet" -> "detection"
    module_name = model_fn.__module__.split(".")[-2]
    expected_fields = defaults["all"] | defaults[module_name]

    weights_enum = get_model_weights(model_fn)
    if len(weights_enum) == 0:
        pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.")

    problematic_weights = {}
    incorrect_meta = []
    bad_names = []
    for w in weights_enum:
        # Flatten nested "_metrics" entries into ("_metrics", dataset, key)
        # triples so they can be compared against the mandatory-field sets.
        actual_fields = set(w.meta.keys())
        actual_fields |= set(
            ("_metrics", dataset, metric_key)
            for dataset in w.meta.get("_metrics", {}).keys()
            for metric_key in w.meta.get("_metrics", {}).get(dataset, {}).keys()
        )
        missing_fields = expected_fields - actual_fields
        unsupported_fields = set(w.meta.keys()) - permitted_fields
        if missing_fields or unsupported_fields:
            problematic_weights[w] = {"missing": missing_fields, "unsupported": unsupported_fields}

        # Only measure weights whose num_params/_ops differ from DEFAULT's (plus
        # DEFAULT itself) -- identical values need no re-measurement.
        if w == weights_enum.DEFAULT or any(w.meta[k] != weights_enum.DEFAULT.meta[k] for k in ["num_params", "_ops"]):
            if module_name == "quantization":
                # parameters() count doesn't work well with quantization, so we check against the non-quantized
                unquantized_w = w.meta.get("unquantized")
                if unquantized_w is not None:
                    if w.meta.get("num_params") != unquantized_w.meta.get("num_params"):
                        incorrect_meta.append((w, "num_params"))

                    # the methodology for quantized ops count doesn't work as well, so we take unquantized FLOPs
                    # instead
                    if w.meta["_ops"] != unquantized_w.meta.get("_ops"):
                        incorrect_meta.append((w, "_ops"))

            else:
                # loading the model and using it for parameter and ops verification
                model = model_fn(weights=w)

                if w.meta.get("num_params") != sum(p.numel() for p in model.parameters()):
                    incorrect_meta.append((w, "num_params"))

                kwargs = {}
                if model_name in detection_models_input_dims:
                    # detection models have non default height and width
                    height, width = detection_models_input_dims[model_name]
                    kwargs = {"height": height, "width": width}

                if not model_fn.__name__.startswith("vit"):
                    # FIXME: https://github.com/pytorch/vision/issues/7871
                    calculated_ops = get_ops(model=model, weight=w, **kwargs)
                    if calculated_ops != w.meta["_ops"]:
                        incorrect_meta.append((w, "_ops"))

        # Enum member names are expected to be uppercase (e.g. IMAGENET1K_V1).
        if not w.name.isupper():
            bad_names.append(w)

        # The recorded checkpoint size must match the actual file size.
        if get_file_size_mb(w) != w.meta.get("_file_size"):
            incorrect_meta.append((w, "_file_size"))

    assert not problematic_weights
    assert not incorrect_meta
    assert not bad_names
343
344


345
@pytest.mark.parametrize(
    "model_fn",
    TM.list_model_fns(models)
    + TM.list_model_fns(models.detection)
    + TM.list_model_fns(models.quantization)
    + TM.list_model_fns(models.segmentation)
    + TM.list_model_fns(models.video)
    + TM.list_model_fns(models.optical_flow),
)
@run_if_test_with_extended
def test_transforms_jit(model_fn):
    """The preprocessing transforms of every weight must be JIT-scriptable."""
    model_name = model_fn.__name__
    weights_enum = get_model_weights(model_fn)
    if len(weights_enum) == 0:
        pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.")

    # Default input shapes per task family; TM._model_params may override them
    # for individual models.
    defaults = {
        "models": {
            "input_shape": (1, 3, 224, 224),
        },
        "detection": {
            "input_shape": (3, 300, 300),
        },
        "quantization": {
            "input_shape": (1, 3, 224, 224),
        },
        "segmentation": {
            "input_shape": (1, 3, 520, 520),
        },
        "video": {
            "input_shape": (1, 3, 4, 112, 112),
        },
        "optical_flow": {
            "input_shape": (1, 3, 128, 128),
        },
    }
    module_name = model_fn.__module__.split(".")[-2]

    kwargs = {**defaults[module_name], **TM._model_params.get(model_name, {})}
    sample = torch.rand(kwargs.pop("input_shape"))
    if module_name == "optical_flow":
        # Optical-flow transforms take a pair of images.
        args = (sample, sample)
    elif module_name == "video":
        # Swap axes 1 and 2 of the random input for the video transforms.
        args = (sample.permute(0, 2, 1, 3, 4),)
    else:
        args = (sample,)

    def scripts_cleanly(weights):
        # True iff the weight's transforms pass the JIT-scriptability check.
        try:
            TM._check_jit_scriptable(weights.transforms(), args)
        except Exception:
            return False
        return True

    problematic_weights = [w for w in weights_enum if not scripts_cleanly(w)]

    assert not problematic_weights
Philip Meier's avatar
Philip Meier committed
402
403
404
405
406


# With this filter, every unexpected warning will be turned into an error
@pytest.mark.filterwarnings("error")
class TestHandleLegacyInterface:
    """Tests for @handle_legacy_interface, which maps the deprecated
    `pretrained` argument onto the newer `weights` argument while warning."""

    class ModelWeights(WeightsEnum):
        # Single dummy member used as the default weights by all tests below.
        Sentinel = Weights(url="https://pytorch.org", transforms=lambda x: x, meta=dict())

    @pytest.mark.parametrize(
        "kwargs",
        [
            pytest.param(dict(), id="empty"),
            pytest.param(dict(weights=None), id="None"),
            pytest.param(dict(weights=ModelWeights.Sentinel), id="Weights"),
        ],
    )
    def test_no_warn(self, kwargs):
        # Using the new `weights` interface (or no arguments at all) must not
        # warn; the class-level filterwarnings("error") makes any unexpected
        # warning fail the test.
        @handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
        def builder(*, weights=None):
            pass

        builder(**kwargs)

    @pytest.mark.parametrize("pretrained", (True, False))
    def test_pretrained_pos(self, pretrained):
        # Passing the legacy flag positionally must warn about positional use.
        @handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
        def builder(*, weights=None):
            pass

        with pytest.warns(UserWarning, match="positional"):
            builder(pretrained)

    @pytest.mark.parametrize("pretrained", (True, False))
    def test_pretrained_kw(self, pretrained):
        # Using the legacy flag must emit a deprecation warning.
        # NOTE(review): despite the test's name, the flag is passed positionally
        # here; presumably `builder(pretrained=pretrained)` was intended to
        # exercise the keyword path -- verify against upstream.
        @handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
        def builder(*, weights=None):
            pass

        with pytest.warns(UserWarning, match="deprecated"):
            builder(pretrained)

    @pytest.mark.parametrize("pretrained", (True, False))
    @pytest.mark.parametrize("positional", (True, False))
    def test_equivalent_behavior_weights(self, pretrained, positional):
        # Positional and keyword use of `pretrained` must resolve to the same
        # `weights` value, which is echoed in the warning message.
        @handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
        def builder(*, weights=None):
            pass

        args, kwargs = ((pretrained,), dict()) if positional else ((), dict(pretrained=pretrained))
        with pytest.warns(UserWarning, match=f"weights={self.ModelWeights.Sentinel if pretrained else None}"):
            builder(*args, **kwargs)

    def test_multi_params(self):
        # The decorator accepts several weights-parameter mappings at once;
        # each legacy flag must independently trigger a deprecation warning.
        weights_params = ("weights", "weights_other")
        pretrained_params = [param.replace("weights", "pretrained") for param in weights_params]

        @handle_legacy_interface(
            **{
                weights_param: (pretrained_param, self.ModelWeights.Sentinel)
                for weights_param, pretrained_param in zip(weights_params, pretrained_params)
            }
        )
        def builder(*, weights=None, weights_other=None):
            pass

        for pretrained_param in pretrained_params:
            with pytest.warns(UserWarning, match="deprecated"):
                builder(**{pretrained_param: True})

    def test_default_callable(self):
        # The default may be a callable computed from the other keyword args:
        # here it yields Sentinel when `flag` is set and None otherwise, in
        # which case `pretrained=True` cannot be honored and must raise.
        @handle_legacy_interface(
            weights=(
                "pretrained",
                lambda kwargs: self.ModelWeights.Sentinel if kwargs["flag"] else None,
            )
        )
        def builder(*, weights=None, flag):
            pass

        with pytest.warns(UserWarning, match="deprecated"):
            builder(pretrained=True, flag=True)

        with pytest.raises(ValueError, match="weights"):
            builder(pretrained=True, flag=False)

    @pytest.mark.parametrize(
        "model_fn",
        [fn for fn in TM.list_model_fns(models) if fn.__name__ not in {"vit_h_14", "regnet_y_128gf"}]
        + TM.list_model_fns(models.detection)
        + TM.list_model_fns(models.quantization)
        + TM.list_model_fns(models.segmentation)
        + TM.list_model_fns(models.video)
        + TM.list_model_fns(models.optical_flow)
        + [
            lambda pretrained: resnet_fpn_backbone(backbone_name="resnet50", pretrained=pretrained),
            lambda pretrained: mobilenet_backbone(backbone_name="mobilenet_v2", fpn=False, pretrained=pretrained),
        ],
    )
    @run_if_test_with_extended
    def test_pretrained_deprecation(self, model_fn):
        # Every real model builder (plus the two backbone helpers) must warn
        # when called with the legacy `pretrained=True`.
        with pytest.warns(UserWarning, match="deprecated"):
            model_fn(pretrained=True)