# test_backbone_utils.py
1
import random
2
from itertools import chain
3
from typing import Mapping, Sequence
4

5
import pytest
6
import torch
7
8
9
from common_utils import set_rng_seed
from torchvision import models
from torchvision.models._utils import IntermediateLayerGetter
10
from torchvision.models.detection.backbone_utils import BackboneWithFPN, mobilenet_backbone, resnet_fpn_backbone
11
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
12
13
14
15


def get_available_models():
    """Return the names of the model builder callables exposed by ``torchvision.models``.

    A name is kept when its value is callable, its first character is not
    uppercase (this filters out CamelCase class names), it is not private
    (does not start with ``_``), and it is not ``get_weight``, which is
    explicitly excluded.
    """
    # TODO add a registration mechanism to torchvision.models
    return [
        k
        for k, v in models.__dict__.items()
        if callable(v) and k[0].lower() == k[0] and k[0] != "_" and k != "get_weight"
    ]
21

22

23
@pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50"))
def test_resnet_fpn_backbone(backbone_name):
    """Build a ResNet FPN backbone, run it once, and check argument validation.

    The built model must be a ``BackboneWithFPN`` whose output dict has the
    keys ``"0"``-``"3"`` plus ``"pool"``. Out-of-range ``trainable_layers`` or
    ``returned_layers`` values must raise ``ValueError``.
    """
    x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu")
    model = resnet_fpn_backbone(backbone_name=backbone_name, weights=None)
    assert isinstance(model, BackboneWithFPN)
    y = model(x)
    assert list(y.keys()) == ["0", "1", "2", "3", "pool"]

    # trainable_layers=6 is expected to be outside the accepted range.
    with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
        resnet_fpn_backbone(backbone_name=backbone_name, weights=None, trainable_layers=6)
    # Layer index 0 (lower bound) and 5 (upper bound) are expected to be rejected.
    with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
        resnet_fpn_backbone(backbone_name=backbone_name, weights=None, returned_layers=[0, 1, 2, 3])
    with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
        resnet_fpn_backbone(backbone_name=backbone_name, weights=None, returned_layers=[2, 3, 4, 5])
37
38
39
40
41


@pytest.mark.parametrize("backbone_name", ("mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"))
def test_mobilenet_backbone(backbone_name):
    """Check argument validation and return types of ``mobilenet_backbone``.

    Out-of-range ``trainable_layers`` / ``returned_layers`` must raise
    ``ValueError``. With ``fpn=True`` the result must be a ``BackboneWithFPN``;
    with ``fpn=False`` it must be a plain ``torch.nn.Sequential``.
    """
    with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
        mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=False, trainable_layers=-1)
    with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
        mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True, returned_layers=[-1, 0, 1, 2])
    with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
        mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True, returned_layers=[3, 4, 5, 6])

    model_fpn = mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True)
    assert isinstance(model_fpn, BackboneWithFPN)
    model = mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=False)
    assert isinstance(model, torch.nn.Sequential)
51

52
53
54
55
56
57

# Needed by TestFxFeatureExtraction.test_leaf_module_and_function
def leaf_function(x):
    """Return ``x`` converted to a Python ``int``.

    ``test_leaf_module_and_function`` registers this function through the
    tracer's ``autowrap_functions`` so it is recorded as a single call rather
    than traced into.
    """
    as_int = int(x)
    return as_int


58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# Needed by TestFxFeatureExtraction. Checks that FX node naming conventions
# are respected, in particular the index suffix appended to repeated node names.
class TestSubModule(torch.nn.Module):
    """Helper module that applies ``+ 1`` twice and its ReLU twice.

    The repeated ops give the traced graph duplicate node names, which the
    node-naming test expects to be disambiguated with an index suffix.
    """

    # Not a pytest test class despite the ``Test`` prefix: without this, pytest
    # tries to collect it and warns because it defines ``__init__``.
    __test__ = False

    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = x + 1
        x = x + 1
        x = self.relu(x)
        x = self.relu(x)
        return x


class TestModule(torch.nn.Module):
    """Helper module wrapping ``TestSubModule``, then repeating ``+ 1`` and ReLU twice.

    Together with the submodule it exercises FX node naming both inside a
    child module (``submodule.*``) and at the top level.
    """

    # Not a pytest test class despite the ``Test`` prefix: without this, pytest
    # tries to collect it and warns because it defines ``__init__``.
    __test__ = False

    def __init__(self):
        super().__init__()
        self.submodule = TestSubModule()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.submodule(x)
        x = x + 1
        x = x + 1
        x = self.relu(x)
        x = self.relu(x)
        return x


# Expected train-mode node names for ``TestModule``, in graph order. Repeated
# ops in the same scope are disambiguated with an ``_1`` suffix.
test_module_nodes = [
    "x",
    "submodule.add",
    "submodule.add_1",
    "submodule.relu",
    "submodule.relu_1",
    "add",
    "add_1",
    "relu",
    "relu_1",
]
99
100


101
class TestFxFeatureExtraction:
    """Tests for FX-based feature extraction (``create_feature_extractor`` / ``get_graph_node_names``)."""

    # Shared input batch fed to the models under test.
    inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device="cpu")
    # Keyword arguments passed to every model builder.
    model_defaults = {"num_classes": 1}
    # Default ``leaf_modules`` handed to the FX tracer (none by default).
    leaf_modules = []

    def _create_feature_extractor(self, *args, **kwargs):
        """Call ``create_feature_extractor`` with this class's tracer defaults.

        If the caller passes ``tracer_kwargs`` it is used as-is; otherwise
        ``{"leaf_modules": self.leaf_modules}`` is used. The train/eval graph
        difference warning is always suppressed.
        """
        tracer_kwargs = kwargs.pop("tracer_kwargs", {"leaf_modules": self.leaf_modules})
        return create_feature_extractor(*args, **kwargs, tracer_kwargs=tracer_kwargs, suppress_diff_warning=True)

    def _get_return_nodes(self, model, num_nodes=10):
        """Randomly pick node names from ``model``'s train and eval graphs.

        Nodes whose names contain any of the excluded substrings are dropped
        first. The RNG is reseeded with ``set_rng_seed(0)`` before sampling
        (assumed to also seed Python's ``random``) so the pick is repeatable.

        Returns:
            A ``(train_nodes, eval_nodes)`` pair of lists, each holding at most
            ``num_nodes`` names (fewer if the graph has fewer eligible nodes).
        """
        set_rng_seed(0)
        exclude_nodes_filter = [
            "getitem",
            "floordiv",
            "size",
            "chunk",
            "_assert",
            "eq",
            "dim",
            "getattr",
        ]
        train_nodes, eval_nodes = get_graph_node_names(
            model, tracer_kwargs={"leaf_modules": self.leaf_modules}, suppress_diff_warning=True
        )
        # Get rid of any nodes that don't return tensors as they cause issues
        # when testing backward pass.
        train_nodes = [n for n in train_nodes if not any(x in n for x in exclude_nodes_filter)]
        eval_nodes = [n for n in eval_nodes if not any(x in n for x in exclude_nodes_filter)]
        # Cap the sample size so small graphs don't make random.sample raise.
        return (
            random.sample(train_nodes, min(num_nodes, len(train_nodes))),
            random.sample(eval_nodes, min(num_nodes, len(eval_nodes))),
        )

    @staticmethod
    def _aggregate_outputs(out):
        """Reduce a feature-extractor output dict to one scalar for ``backward()``.

        Each value may be a sequence of tensors, a mapping of tensors, or a
        tensor. ``None`` entries inside containers are skipped, and every tensor
        contributes its float mean to the sum.
        """
        out_agg = 0
        for node_out in out.values():
            if isinstance(node_out, Sequence):
                out_agg += sum(o.float().mean() for o in node_out if o is not None)
            elif isinstance(node_out, Mapping):
                out_agg += sum(o.float().mean() for o in node_out.values() if o is not None)
            else:
                # Assume that the only other alternative at this point is a Tensor
                out_agg += node_out.float().mean()
        return out_agg

    @pytest.mark.parametrize("model_name", get_available_models())
    def test_build_fx_feature_extractor(self, model_name):
        set_rng_seed(0)
        model = models.__dict__[model_name](**self.model_defaults).eval()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        # Check that it works with both a list and dict for return nodes
        self._create_feature_extractor(
            model, train_return_nodes={v: v for v in train_return_nodes}, eval_return_nodes=eval_return_nodes
        )
        self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check must specify return nodes
        with pytest.raises(ValueError):
            self._create_feature_extractor(model)
        # Check return_nodes and train_return_nodes / eval_return_nodes
        # mutual exclusivity
        with pytest.raises(ValueError):
            self._create_feature_extractor(
                model, return_nodes=train_return_nodes, train_return_nodes=train_return_nodes
            )
        # Check train_return_nodes / eval_return_nodes must both be specified
        with pytest.raises(ValueError):
            self._create_feature_extractor(model, train_return_nodes=train_return_nodes)
        # Check an invalid node name raises ValueError. "l" would be a valid
        # query if some node were named exactly "l" or lived under an "l."
        # module prefix, so only run the check when no node in the full graphs
        # matches. A bare startswith("l") would also match e.g. "layer1" and
        # skip the check for most models.
        all_train_nodes, all_eval_nodes = get_graph_node_names(
            model, tracer_kwargs={"leaf_modules": self.leaf_modules}, suppress_diff_warning=True
        )
        if not any(n == "l" or n.startswith("l.") for n in chain(all_train_nodes, all_eval_nodes)):
            with pytest.raises(ValueError):
                self._create_feature_extractor(model, train_return_nodes=["l"], eval_return_nodes=["l"])

    def test_node_name_conventions(self):
        model = TestModule()
        train_nodes, _ = get_graph_node_names(model)
        # Compare whole lists: zip() would silently ignore missing or extra nodes.
        assert list(train_nodes) == test_module_nodes

    @pytest.mark.parametrize("model_name", get_available_models())
    def test_forward_backward(self, model_name):
        set_rng_seed(0)
        model = models.__dict__[model_name](**self.model_defaults).train()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        out = model(self.inp)
        self._aggregate_outputs(out).backward()

    def test_feature_extraction_methods_equivalence(self):
        model = models.resnet18(**self.model_defaults).eval()
        return_layers = {"layer1": "layer1", "layer2": "layer2", "layer3": "layer3", "layer4": "layer4"}

        ilg_model = IntermediateLayerGetter(model, return_layers).eval()
        fx_model = self._create_feature_extractor(model, return_layers)

        # Check that we have the same parameters: same count, names and values
        ilg_params = list(ilg_model.named_parameters())
        fx_params = list(fx_model.named_parameters())
        assert len(ilg_params) == len(fx_params)
        for (n1, p1), (n2, p2) in zip(ilg_params, fx_params):
            assert n1 == n2
            assert p1.equal(p2)

        # And that outputs match
        with torch.no_grad():
            ilg_out = ilg_model(self.inp)
            fgn_out = fx_model(self.inp)
        assert list(ilg_out.keys()) == list(fgn_out.keys())
        for k in ilg_out.keys():
            assert ilg_out[k].equal(fgn_out[k])

    @pytest.mark.parametrize("model_name", get_available_models())
    def test_jit_forward_backward(self, model_name):
        set_rng_seed(0)
        model = models.__dict__[model_name](**self.model_defaults).train()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        model = torch.jit.script(model)
        fgn_out = model(self.inp)
        self._aggregate_outputs(fgn_out).backward()

    def test_train_eval(self):
        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dropout = torch.nn.Dropout(p=1.0)

            def forward(self, x):
                x = x.float().mean()
                x = self.dropout(x)  # dropout
                if self.training:
                    x += 100  # add
                else:
                    x *= 0  # mul
                x -= 0  # sub
                return x

        model = TestModel()

        train_return_nodes = ["dropout", "add", "sub"]
        eval_return_nodes = ["dropout", "mul", "sub"]

        def checks(model, mode):
            with torch.no_grad():
                out = model(torch.ones(10, 10))
            if mode == "train":
                # Check that dropout is respected
                assert out["dropout"].item() == 0
                # Check that control flow dependent on training_mode is respected
                assert out["sub"].item() == 100
                assert "add" in out
                assert "mul" not in out
            elif mode == "eval":
                # Check that dropout is respected
                assert out["dropout"].item() == 1
                # Check that control flow dependent on training_mode is respected
                assert out["sub"].item() == 0
                assert "mul" in out
                assert "add" not in out

        # Starting from train mode
        model.train()
        fx_model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check that the models stay in their original training state
        assert model.training
        assert fx_model.training
        # Check outputs
        checks(fx_model, "train")
        # Check outputs after switching to eval mode
        fx_model.eval()
        checks(fx_model, "eval")

        # Starting from eval mode
        model.eval()
        fx_model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check that the models stay in their original training state
        assert not model.training
        assert not fx_model.training
        # Check outputs
        checks(fx_model, "eval")
        # Check outputs after switching to train mode
        fx_model.train()
        checks(fx_model, "train")

    def test_leaf_module_and_function(self):
        class LeafModule(torch.nn.Module):
            def forward(self, x):
                # This would raise a TypeError if it were not in a leaf module
                int(x.shape[0])
                return torch.nn.functional.relu(x + 4)

        # Named differently from the module-level TestModule to avoid shadowing it.
        class ModelWithLeaves(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 1, 3)
                self.leaf_module = LeafModule()

            def forward(self, x):
                leaf_function(x.shape[0])
                x = self.conv(x)
                return self.leaf_module(x)

        model = self._create_feature_extractor(
            ModelWithLeaves(),
            return_nodes=["leaf_module"],
            tracer_kwargs={"leaf_modules": [LeafModule], "autowrap_functions": [leaf_function]},
        ).train()

        # Check that LeafModule was not traced through: its internal relu must
        # not appear as a node, while the leaf module call itself must.
        assert "relu" not in [str(n) for n in model.graph.nodes]
        assert "leaf_module" in [str(n) for n in model.graph.nodes]

        # Check forward
        out = model(self.inp)
        # And backward
        out["leaf_module"].float().mean().backward()