import os

import pytest
import test_models as TM
import torch
from torchvision import models
from torchvision.models._api import get_model_weights, Weights, WeightsEnum
from torchvision.models._utils import handle_legacy_interface


run_if_test_with_extended = pytest.mark.skipif(
    os.getenv("PYTORCH_TEST_WITH_EXTENDED", "0") != "1",
    reason="Extended tests are disabled by default. Set PYTORCH_TEST_WITH_EXTENDED=1 to run them.",
)
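
# To actually run the extended tests below, flip the gate, e.g.:
#
#   PYTORCH_TEST_WITH_EXTENDED=1 pytest test_extended_models.py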


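# models.get_model looks a builder up by its lower-case registry name and
# instantiates it. A minimal sketch of the API under test:
#
#   model = models.get_model("resnet50", weights=None)
#   assert isinstance(model, models.ResNet)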
@pytest.mark.parametrize(
    "name, model_class",
    [
        ("resnet50", models.ResNet),
        ("retinanet_resnet50_fpn_v2", models.detection.RetinaNet),
        ("raft_large", models.optical_flow.RAFT),
        ("quantized_resnet50", models.quantization.QuantizableResNet),
        ("lraspp_mobilenet_v3_large", models.segmentation.LRASPP),
        ("mvit_v1_b", models.video.MViT),
    ],
)
def test_get_model(name, model_class):
    assert isinstance(models.get_model(name), model_class)


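# models.get_model_weights maps a model name (or the builder function itself)
# to its weights enum. A minimal sketch:
#
#   assert models.get_model_weights("resnet50") == models.ResNet50_Weights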
@pytest.mark.parametrize(
    "name, weight",
    [
        ("resnet50", models.ResNet50_Weights),
        ("retinanet_resnet50_fpn_v2", models.detection.RetinaNet_ResNet50_FPN_V2_Weights),
        ("raft_large", models.optical_flow.Raft_Large_Weights),
        ("quantized_resnet50", models.quantization.ResNet50_QuantizedWeights),
        ("lraspp_mobilenet_v3_large", models.segmentation.LRASPP_MobileNet_V3_Large_Weights),
        ("mvit_v1_b", models.video.MViT_V1_B_Weights),
    ],
)
def test_get_model_weights(name, weight):
    assert models.get_model_weights(name) == weight


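# models.list_models enumerates the registered builder names under a module.
# Quantized models are registered with a "quantized_" prefix while their
# builder functions are not, which is why the test strips it before comparing.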
@pytest.mark.parametrize(
    "module", [models, models.detection, models.quantization, models.segmentation, models.video, models.optical_flow]
)
def test_list_models(module):
    def get_models_from_module(module):
        return [
            v.__name__
            for k, v in module.__dict__.items()
            if callable(v) and k[0].islower() and k[0] != "_" and k not in models._api.__all__
        ]

    a = set(get_models_from_module(module))
    b = set(x.replace("quantized_", "") for x in models.list_models(module))

    assert len(b) > 0
    assert a == b


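# models.get_weight resolves the string form "<WeightsEnum>.<MEMBER>" to the
# corresponding enum value; "DEFAULT" is an alias for the current best weights
# (here, ResNet50_Weights.DEFAULT resolves to IMAGENET1K_V2).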
@pytest.mark.parametrize(
    "name, weight",
    [
        ("ResNet50_Weights.IMAGENET1K_V1", models.ResNet50_Weights.IMAGENET1K_V1),
        ("ResNet50_Weights.DEFAULT", models.ResNet50_Weights.IMAGENET1K_V2),
        (
            "ResNet50_QuantizedWeights.DEFAULT",
            models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2,
        ),
        (
            "ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1",
            models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1,
        ),
    ],
)
def test_get_weight(name, weight):
    assert models.get_weight(name) == weight


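# Naming conventions under test: every builder resolves to a weights enum via
# get_model_weights, and any non-empty enum must define a DEFAULT alias.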
@pytest.mark.parametrize(
    "model_fn",
    TM.list_model_fns(models)
    + TM.list_model_fns(models.detection)
    + TM.list_model_fns(models.quantization)
    + TM.list_model_fns(models.segmentation)
    + TM.list_model_fns(models.video)
    + TM.list_model_fns(models.optical_flow),
)
def test_naming_conventions(model_fn):
    weights_enum = get_model_weights(model_fn)
    assert weights_enum is not None
    assert len(weights_enum) == 0 or hasattr(weights_enum, "DEFAULT")


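# Rough shape of the meta dict each Weights entry must carry (values below are
# illustrative placeholders, not real numbers):
#
#   meta = {
#       "num_params": ...,
#       "min_size": (1, 1),
#       "recipe": "https://github.com/pytorch/vision/...",
#       "_docs": "...",
#       "_metrics": {"ImageNet-1K": {"acc@1": ..., "acc@5": ...}},
#   }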
@pytest.mark.parametrize(
    "model_fn",
    TM.list_model_fns(models)
    + TM.list_model_fns(models.detection)
    + TM.list_model_fns(models.quantization)
    + TM.list_model_fns(models.segmentation)
    + TM.list_model_fns(models.video)
    + TM.list_model_fns(models.optical_flow),
)
@run_if_test_with_extended
def test_schema_meta_validation(model_fn):
    # list of all possible supported high-level fields for weights meta-data
    permitted_fields = {
        "backend",
        "categories",
        "keypoint_names",
        "license",
        "_metrics",
        "min_size",
        "min_temporal_size",
        "num_params",
        "recipe",
        "unquantized",
        "_docs",
    }
    # mandatory fields for each computer vision task
    classification_fields = {"categories", ("_metrics", "ImageNet-1K", "acc@1"), ("_metrics", "ImageNet-1K", "acc@5")}
    defaults = {
        "all": {"_metrics", "min_size", "num_params", "recipe", "_docs"},
        "models": classification_fields,
        "detection": {"categories", ("_metrics", "COCO-val2017", "box_map")},
        "quantization": classification_fields | {"backend", "unquantized"},
        "segmentation": {
            "categories",
            ("_metrics", "COCO-val2017-VOC-labels", "miou"),
            ("_metrics", "COCO-val2017-VOC-labels", "pixel_acc"),
        },
        "video": {"categories", ("_metrics", "Kinetics-400", "acc@1"), ("_metrics", "Kinetics-400", "acc@5")},
        "optical_flow": set(),
    }
    model_name = model_fn.__name__
    module_name = model_fn.__module__.split(".")[-2]
    expected_fields = defaults["all"] | defaults[module_name]

    weights_enum = get_model_weights(model_fn)
    if len(weights_enum) == 0:
        pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.")

    problematic_weights = {}
    incorrect_params = []
    bad_names = []
    for w in weights_enum:
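        # flatten the nested "_metrics" dict into ("_metrics", dataset, metric)
        # tuples so it can be compared against the expected-field tuples above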
        actual_fields = set(w.meta.keys())
        actual_fields |= set(
            ("_metrics", dataset, metric_key)
            for dataset in w.meta.get("_metrics", {}).keys()
            for metric_key in w.meta.get("_metrics", {}).get(dataset, {}).keys()
        )
        missing_fields = expected_fields - actual_fields
        unsupported_fields = set(w.meta.keys()) - permitted_fields
        if missing_fields or unsupported_fields:
            problematic_weights[w] = {"missing": missing_fields, "unsupported": unsupported_fields}
        if w == weights_enum.DEFAULT:
            if module_name == "quantization":
                # parameters() counts don't work well with quantization, so we check against the non-quantized variant
                unquantized_w = w.meta.get("unquantized")
                if unquantized_w is not None and w.meta.get("num_params") != unquantized_w.meta.get("num_params"):
                    incorrect_params.append(w)
            else:
                if w.meta.get("num_params") != sum(p.numel() for p in model_fn(weights=w).parameters()):
                    incorrect_params.append(w)
        else:
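            # non-default weights usually share the architecture (and thus the
            # parameter count) with DEFAULT; only instantiate the model when
            # the recorded counts disagree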
            if w.meta.get("num_params") != weights_enum.DEFAULT.meta.get("num_params"):
                if w.meta.get("num_params") != sum(p.numel() for p in model_fn(weights=w).parameters()):
                    incorrect_params.append(w)
        if not w.name.isupper():
            bad_names.append(w)

    assert not problematic_weights
    assert not incorrect_params
    assert not bad_names


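# Each Weights entry bundles its inference transforms, and those presets must
# stay TorchScript-scriptable so they can ship with a scripted model. Roughly
# what the check below boils down to:
#
#   transforms = models.ResNet50_Weights.DEFAULT.transforms()
#   torch.jit.script(transforms)  # must not raise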
@pytest.mark.parametrize(
    "model_fn",
    TM.list_model_fns(models)
    + TM.list_model_fns(models.detection)
    + TM.list_model_fns(models.quantization)
    + TM.list_model_fns(models.segmentation)
    + TM.list_model_fns(models.video)
    + TM.list_model_fns(models.optical_flow),
)
@run_if_test_with_extended
def test_transforms_jit(model_fn):
    model_name = model_fn.__name__
    weights_enum = get_model_weights(model_fn)
    if len(weights_enum) == 0:
        pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.")

    defaults = {
        "models": {
            "input_shape": (1, 3, 224, 224),
        },
        "detection": {
            "input_shape": (3, 300, 300),
        },
        "quantization": {
            "input_shape": (1, 3, 224, 224),
        },
        "segmentation": {
            "input_shape": (1, 3, 520, 520),
        },
        "video": {
            "input_shape": (1, 3, 4, 112, 112),
        },
        "optical_flow": {
            "input_shape": (1, 3, 128, 128),
        },
    }
    module_name = model_fn.__module__.split(".")[-2]

    kwargs = {**defaults[module_name], **TM._model_params.get(model_name, {})}
    input_shape = kwargs.pop("input_shape")
    x = torch.rand(input_shape)
    if module_name == "optical_flow":
        args = (x, x)
    else:
        if module_name == "video":
            x = x.permute(0, 2, 1, 3, 4)
        args = (x,)

    problematic_weights = []
    for w in weights_enum:
        transforms = w.transforms()
        try:
            TM._check_jit_scriptable(transforms, args)
        except Exception:
            problematic_weights.append(w)

    assert not problematic_weights


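# handle_legacy_interface bridges the deprecated pretrained=True/False argument
# to the new weights= argument, emitting a warning when the old name is used.
# A rough sketch of the decorated-builder pattern exercised below (SomeWeights
# is a stand-in name):
#
#   @handle_legacy_interface(weights=("pretrained", SomeWeights.DEFAULT))
#   def resnet50(*, weights=None):
#       ...
#
#   resnet50(pretrained=True)  # warns; behaves like weights=SomeWeights.DEFAULT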
# With this filter, every unexpected warning will be turned into an error
@pytest.mark.filterwarnings("error")
class TestHandleLegacyInterface:
    class ModelWeights(WeightsEnum):
        Sentinel = Weights(url="https://pytorch.org", transforms=lambda x: x, meta=dict())

    @pytest.mark.parametrize(
        "kwargs",
        [
            pytest.param(dict(), id="empty"),
            pytest.param(dict(weights=None), id="None"),
            pytest.param(dict(weights=ModelWeights.Sentinel), id="Weights"),
        ],
    )
    def test_no_warn(self, kwargs):
        @handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
        def builder(*, weights=None):
            pass

        builder(**kwargs)

    @pytest.mark.parametrize("pretrained", (True, False))
    def test_pretrained_pos(self, pretrained):
        @handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
        def builder(*, weights=None):
            pass

        with pytest.warns(UserWarning, match="positional"):
            builder(pretrained)

    @pytest.mark.parametrize("pretrained", (True, False))
    def test_pretrained_kw(self, pretrained):
        @handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
        def builder(*, weights=None):
            pass

        with pytest.warns(UserWarning, match="deprecated"):
            builder(pretrained)

    @pytest.mark.parametrize("pretrained", (True, False))
    @pytest.mark.parametrize("positional", (True, False))
    def test_equivalent_behavior_weights(self, pretrained, positional):
        @handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
        def builder(*, weights=None):
            pass

        args, kwargs = ((pretrained,), dict()) if positional else ((), dict(pretrained=pretrained))
        with pytest.warns(UserWarning, match=f"weights={self.ModelWeights.Sentinel if pretrained else None}"):
            builder(*args, **kwargs)

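    # a builder can map several weights parameters at once (in torchvision,
    # e.g. weights= plus weights_backbone=), each paired with its own legacy
    # pretrained_* flag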
    def test_multi_params(self):
        weights_params = ("weights", "weights_other")
        pretrained_params = [param.replace("weights", "pretrained") for param in weights_params]

        @handle_legacy_interface(
            **{
                weights_param: (pretrained_param, self.ModelWeights.Sentinel)
                for weights_param, pretrained_param in zip(weights_params, pretrained_params)
            }
        )
        def builder(*, weights=None, weights_other=None):
            pass

        for pretrained_param in pretrained_params:
            with pytest.warns(UserWarning, match="deprecated"):
                builder(**{pretrained_param: True})

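    # the default may also be a callable receiving the keyword arguments and
    # choosing the weights for pretrained=True; if it returns None in that
    # case, the decorator raises because there is nothing to load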
    def test_default_callable(self):
        @handle_legacy_interface(
            weights=(
                "pretrained",
                lambda kwargs: self.ModelWeights.Sentinel if kwargs["flag"] else None,
            )
        )
        def builder(*, weights=None, flag):
            pass

        with pytest.warns(UserWarning, match="deprecated"):
            builder(pretrained=True, flag=True)

        with pytest.raises(ValueError, match="weights"):
            builder(pretrained=True, flag=False)