import importlib
import os

import pytest
import test_models as TM
import torch

from torchvision import models
from torchvision.models._api import WeightsEnum, Weights
from torchvision.models._utils import handle_legacy_interface


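# Gate the expensive checks below behind the PYTORCH_TEST_WITH_EXTENDED env var.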
run_if_test_with_extended = pytest.mark.skipif(
    os.getenv("PYTORCH_TEST_WITH_EXTENDED", "0") != "1",
    reason="Extended tests are disabled by default. Set PYTORCH_TEST_WITH_EXTENDED=1 to run them.",
)


def _get_parent_module(model_fn):
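    # Resolve the package that defines the builder; e.g. a builder defined in
    # torchvision.models.resnet resolves to torchvision.models.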
    parent_module_name = ".".join(model_fn.__module__.split(".")[:-1])
    module = importlib.import_module(parent_module_name)
    return module


def _get_model_weights(model_fn):
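    # Look up the weights enum that sits next to the builder and follows the
    # "<ModelName>_Weights" naming convention (e.g. ResNet50_Weights for resnet50);
    # the quantization subpackage uses "<ModelName>_QuantizedWeights" instead.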
    module = _get_parent_module(model_fn)
    weights_name = "_QuantizedWeights" if module.__name__.split(".")[-1] == "quantization" else "_Weights"
    try:
        return next(
            v
            for k, v in module.__dict__.items()
            if k.endswith(weights_name) and k.replace(weights_name, "").lower() == model_fn.__name__
        )
    except StopIteration:
        return None


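# models.get_weight() resolves a fully-qualified enum string to the Weights instance,
# so e.g. models.get_weight("ResNet50_Weights.IMAGENET1K_V1") returns
# models.ResNet50_Weights.IMAGENET1K_V1, and "DEFAULT" resolves to its alias target.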
@pytest.mark.parametrize(
    "name, weight",
    [
        ("ResNet50_Weights.IMAGENET1K_V1", models.ResNet50_Weights.IMAGENET1K_V1),
        ("ResNet50_Weights.DEFAULT", models.ResNet50_Weights.IMAGENET1K_V2),
        (
            "ResNet50_QuantizedWeights.DEFAULT",
            models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2,
        ),
        (
            "ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1",
            models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1,
        ),
    ],
)
def test_get_weight(name, weight):
    assert models.get_weight(name) == weight


@pytest.mark.parametrize(
    "model_fn",
    TM.get_models_from_module(models)
    + TM.get_models_from_module(models.detection)
    + TM.get_models_from_module(models.quantization)
    + TM.get_models_from_module(models.segmentation)
    + TM.get_models_from_module(models.video)
    + TM.get_models_from_module(models.optical_flow),
)
def test_naming_conventions(model_fn):
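    # Every model builder must ship a matching weights enum, and any enum that has
    # entries must also expose a DEFAULT alias.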
    weights_enum = _get_model_weights(model_fn)
    assert weights_enum is not None
    assert len(weights_enum) == 0 or hasattr(weights_enum, "DEFAULT")


@pytest.mark.parametrize(
    "model_fn",
    TM.get_models_from_module(models)
    + TM.get_models_from_module(models.detection)
    + TM.get_models_from_module(models.quantization)
    + TM.get_models_from_module(models.segmentation)
    + TM.get_models_from_module(models.video)
    + TM.get_models_from_module(models.optical_flow),
)
@run_if_test_with_extended
def test_schema_meta_validation(model_fn):
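    # Check that every weight entry's meta dict sticks to the schema below and that
    # the recorded num_params matches the instantiated model's parameter count.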
    # list of all possible supported high-level fields for weights meta-data
    permitted_fields = {
        "backend",
        "categories",
        "keypoint_names",
        "license",
        "metrics",
        "min_size",
        "num_params",
        "recipe",
        "unquantized",
    }
    # mandatory fields for each computer vision task
    classification_fields = {"categories", ("metrics", "acc@1"), ("metrics", "acc@5")}
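    # e.g. a classification entry's meta is expected to look roughly like:
    # {"categories": [...], "metrics": {"acc@1": ..., "acc@5": ...},
    #  "num_params": ..., "min_size": ..., "recipe": "..."}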
    defaults = {
        "all": {"metrics", "min_size", "num_params", "recipe"},
        "models": classification_fields,
        "detection": {"categories", ("metrics", "box_map")},
        "quantization": classification_fields | {"backend", "unquantized"},
        "segmentation": {"categories", ("metrics", "miou"), ("metrics", "pixel_acc")},
        "video": classification_fields,
        "optical_flow": set(),
    }
    model_name = model_fn.__name__
    module_name = model_fn.__module__.split(".")[-2]
    fields = defaults["all"] | defaults[module_name]

    weights_enum = _get_model_weights(model_fn)
    if len(weights_enum) == 0:
        pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.")

    problematic_weights = {}
    incorrect_params = []
    bad_names = []
    for w in weights_enum:
        missing_fields = fields - (set(w.meta.keys()) | set(("metrics", x) for x in w.meta.get("metrics", {}).keys()))
        unsupported_fields = set(w.meta.keys()) - permitted_fields
        if missing_fields or unsupported_fields:
            problematic_weights[w] = {"missing": missing_fields, "unsupported": unsupported_fields}
        if w == weights_enum.DEFAULT:
            if module_name == "quantization":
                # parameters() count doesn't work well with quantization, so we check
                # against the non-quantized variant instead
                unquantized_w = w.meta.get("unquantized")
                if unquantized_w is not None and w.meta.get("num_params") != unquantized_w.meta.get("num_params"):
                    incorrect_params.append(w)
            else:
                if w.meta.get("num_params") != sum(p.numel() for p in model_fn(weights=w).parameters()):
                    incorrect_params.append(w)
        else:
            # Non-DEFAULT weights: only instantiate the (expensive) model when the
            # recorded num_params differs from DEFAULT's.
            if w.meta.get("num_params") != weights_enum.DEFAULT.meta.get("num_params"):
                if w.meta.get("num_params") != sum(p.numel() for p in model_fn(weights=w).parameters()):
                    incorrect_params.append(w)
        if not w.name.isupper():
            bad_names.append(w)

    assert not problematic_weights
    assert not incorrect_params
    assert not bad_names


@pytest.mark.parametrize(
    "model_fn",
    TM.get_models_from_module(models)
    + TM.get_models_from_module(models.detection)
    + TM.get_models_from_module(models.quantization)
    + TM.get_models_from_module(models.segmentation)
    + TM.get_models_from_module(models.video)
    + TM.get_models_from_module(models.optical_flow),
)
@run_if_test_with_extended
def test_transforms_jit(model_fn):
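    # The preset transforms bundled with every weight entry must be TorchScript-able.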
    model_name = model_fn.__name__
    weights_enum = _get_model_weights(model_fn)
    if len(weights_enum) == 0:
        pytest.skip(f"Model '{model_name}' doesn't have any pre-trained weights.")

    # Per-task dummy input shapes used to exercise the preset transforms;
    # per-model overrides come from TM._model_params.
    defaults = {
        "models": {
            "input_shape": (1, 3, 224, 224),
        },
        "detection": {
            "input_shape": (3, 300, 300),
        },
        "quantization": {
            "input_shape": (1, 3, 224, 224),
        },
        "segmentation": {
            "input_shape": (1, 3, 520, 520),
        },
        "video": {
            "input_shape": (1, 4, 112, 112, 3),
        },
        "optical_flow": {
            "input_shape": (1, 3, 128, 128),
        },
    }
    module_name = model_fn.__module__.split(".")[-2]

    kwargs = {**defaults[module_name], **TM._model_params.get(model_name, {})}
    input_shape = kwargs.pop("input_shape")
    x = torch.rand(input_shape)
    if module_name == "optical_flow":
        # optical flow models take a pair of frames
        args = (x, x)
    else:
        args = (x,)

    problematic_weights = []
    for w in weights_enum:
        transforms = w.transforms()
        try:
            TM._check_jit_scriptable(transforms, args)
        except Exception:
            problematic_weights.append(w)

    assert not problematic_weights


# With this filter, every unexpected warning will be turned into an error
@pytest.mark.filterwarnings("error")
class TestHandleLegacyInterface:
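    # handle_legacy_interface() maps the deprecated pretrained=True/False argument
    # onto the new weights= keyword, warning on every legacy use; the Sentinel entry
    # below stands in for a real weights enum.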
    class ModelWeights(WeightsEnum):
        Sentinel = Weights(url="https://pytorch.org", transforms=lambda x: x, meta=dict())

    @pytest.mark.parametrize(
        "kwargs",
        [
            pytest.param(dict(), id="empty"),
            pytest.param(dict(weights=None), id="None"),
            pytest.param(dict(weights=ModelWeights.Sentinel), id="Weights"),
        ],
    )
    def test_no_warn(self, kwargs):
        @handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
        def builder(*, weights=None):
            pass

        builder(**kwargs)

    @pytest.mark.parametrize("pretrained", (True, False))
    def test_pretrained_pos(self, pretrained):
        @handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
        def builder(*, weights=None):
            pass

        with pytest.warns(UserWarning, match="positional"):
            builder(pretrained)

    @pytest.mark.parametrize("pretrained", (True, False))
    def test_pretrained_kw(self, pretrained):
        @handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
        def builder(*, weights=None):
            pass

        with pytest.warns(UserWarning, match="deprecated"):
            builder(pretrained)

    @pytest.mark.parametrize("pretrained", (True, False))
    @pytest.mark.parametrize("positional", (True, False))
    def test_equivalent_behavior_weights(self, pretrained, positional):
        @handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
        def builder(*, weights=None):
            pass

        args, kwargs = ((pretrained,), dict()) if positional else ((), dict(pretrained=pretrained))
        with pytest.warns(UserWarning, match=f"weights={self.ModelWeights.Sentinel if pretrained else None}"):
            builder(*args, **kwargs)

    def test_multi_params(self):
        weights_params = ("weights", "weights_other")
        pretrained_params = [param.replace("weights", "pretrained") for param in weights_params]

        @handle_legacy_interface(
            **{
                weights_param: (pretrained_param, self.ModelWeights.Sentinel)
                for weights_param, pretrained_param in zip(weights_params, pretrained_params)
            }
        )
        def builder(*, weights=None, weights_other=None):
            pass

        for pretrained_param in pretrained_params:
            with pytest.warns(UserWarning, match="deprecated"):
                builder(**{pretrained_param: True})

    def test_default_callable(self):
        @handle_legacy_interface(
            weights=(
                "pretrained",
                lambda kwargs: self.ModelWeights.Sentinel if kwargs["flag"] else None,
            )
        )
        def builder(*, weights=None, flag):
            pass

        with pytest.warns(UserWarning, match="deprecated"):
            builder(pretrained=True, flag=True)

        with pytest.raises(ValueError, match="weights"):
            builder(pretrained=True, flag=False)