test_backbone_utils.py 12.8 KB
Newer Older
1
import random
2
from itertools import chain
3
from typing import Mapping, Sequence
4

5
import pytest
6
import torch
7
8
9
from common_utils import set_rng_seed
from torchvision import models
from torchvision.models._utils import IntermediateLayerGetter
10
11
from torchvision.models.detection.backbone_utils import mobilenet_backbone, resnet_fpn_backbone
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
12
13
14
15


def get_available_models():
    """Return the names in ``torchvision.models`` that look like model builders.

    A name is kept when its value in ``models.__dict__`` is callable, its first
    character is unchanged by ``str.lower()``, and that character is not an
    underscore.
    """
    # TODO add a registration mechanism to torchvision.models
    return [
        name
        for name, obj in models.__dict__.items()
        if callable(obj) and name[0].lower() == name[0] and name[0] != "_"
    ]
17

18

19
@pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50"))
def test_resnet_fpn_backbone(backbone_name):
    """Check the output keys and the argument validation of ``resnet_fpn_backbone``."""
    x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu")
    y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x)
    assert list(y.keys()) == ["0", "1", "2", "3", "pool"]

    # Out-of-range arguments must be rejected with a descriptive ValueError.
    with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
        resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False, trainable_layers=6)
    with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
        resnet_fpn_backbone(backbone_name, False, returned_layers=[0, 1, 2, 3])
    with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
        resnet_fpn_backbone(backbone_name, False, returned_layers=[2, 3, 4, 5])


@pytest.mark.parametrize("backbone_name", ("mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"))
def test_mobilenet_backbone(backbone_name):
    """Check that ``mobilenet_backbone`` rejects out-of-range layer arguments."""
    trainable_msg = r"Trainable layers should be in the range"
    returned_msg = r"Each returned layer should be in the range"

    with pytest.raises(ValueError, match=trainable_msg):
        mobilenet_backbone(backbone_name=backbone_name, pretrained=False, fpn=False, trainable_layers=-1)

    # One selection starting below the accepted range, one running past it.
    for bad_layers in ([-1, 0, 1, 2], [3, 4, 5, 6]):
        with pytest.raises(ValueError, match=returned_msg):
            mobilenet_backbone(backbone_name, False, fpn=True, returned_layers=bad_layers)

42
43
44
45
46
47

# Needed by TestFxFeatureExtraction.test_leaf_module_and_function
def leaf_function(x):
    """Return ``x`` converted with the builtin ``int``."""
    as_int = int(x)
    return as_int


48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Needed by TestFXFeatureExtraction. Checking that node naming conventions
# are respected. Particularly the index postfix of repeated node names
class TestSubModule(torch.nn.Module):
    """Toy module that applies ``+ 1`` twice and then the same ReLU twice.

    The repetition is deliberate: it produces repeated node names when the
    module is traced.
    """

    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        # Keep the two additions as separate operations (not ``x + 2``).
        shifted = (x + 1) + 1
        # Likewise call the ReLU module twice, inner call first.
        return self.relu(self.relu(shifted))


class TestModule(torch.nn.Module):
    """Toy module: a ``TestSubModule`` followed by the same add/add/ReLU/ReLU pattern."""

    def __init__(self):
        super().__init__()
        self.submodule = TestSubModule()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        inner = self.submodule(x)
        # Two separate additions and two separate ReLU calls, in this order.
        shifted = (inner + 1) + 1
        return self.relu(self.relu(shifted))


# Node names, in order, that test_node_name_conventions compares against the
# train-mode node names reported for TestModule.
test_module_nodes = [
    "x",
    "submodule.add",
    "submodule.add_1",
    "submodule.relu",
    "submodule.relu_1",
    "add",
    "add_1",
    "relu",
    "relu_1",
]
89
90


91
class TestFxFeatureExtraction:
    """Tests for ``create_feature_extractor`` and ``get_graph_node_names``."""

    # Input batch shared by the tests that run a forward pass.
    inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device="cpu")
    # Keyword arguments used to instantiate every torchvision model under test.
    model_defaults = {"num_classes": 1, "pretrained": False}
    # Passed to the tracer as ``leaf_modules`` unless a test overrides ``tracer_kwargs``.
    leaf_modules = []

    def _create_feature_extractor(self, *args, **kwargs):
        """Call ``create_feature_extractor`` with this class's default tracer settings.

        When the caller does not supply ``tracer_kwargs``,
        ``{"leaf_modules": self.leaf_modules}`` is used.
        ``suppress_diff_warning=True`` is always passed.
        """
        kwargs.setdefault("tracer_kwargs", {"leaf_modules": self.leaf_modules})
        return create_feature_extractor(*args, **kwargs, suppress_diff_warning=True)

    def _get_return_nodes(self, model):
        """Return 10 randomly sampled train-mode and 10 eval-mode node names of ``model``."""
        set_rng_seed(0)
        exclude_nodes_filter = [
            "getitem",
            "floordiv",
            "size",
            "chunk",
            "_assert",
            "eq",
            "dim",
            "getattr",
        ]
        train_nodes, eval_nodes = get_graph_node_names(
            model, tracer_kwargs={"leaf_modules": self.leaf_modules}, suppress_diff_warning=True
        )
        # Get rid of any nodes that don't return tensors as they cause issues
        # when testing backward pass.
        train_nodes = [n for n in train_nodes if not any(x in n for x in exclude_nodes_filter)]
        eval_nodes = [n for n in eval_nodes if not any(x in n for x in exclude_nodes_filter)]
        return random.sample(train_nodes, 10), random.sample(eval_nodes, 10)

    @staticmethod
    def _sum_output_means(outputs):
        """Reduce a feature-extractor output dict to one value by summing element means.

        Each value may be a tensor, a sequence of tensors or a mapping of
        tensors; ``None`` entries inside sequences and mappings are skipped.
        """
        total = 0
        for node_out in outputs.values():
            if isinstance(node_out, Sequence):
                total += sum(o.mean() for o in node_out if o is not None)
            elif isinstance(node_out, Mapping):
                total += sum(o.mean() for o in node_out.values() if o is not None)
            else:
                # Assume that the only other alternative at this point is a Tensor
                total += node_out.mean()
        return total

    @pytest.mark.parametrize("model_name", get_available_models())
    def test_build_fx_feature_extractor(self, model_name):
        """Build extractors for every model and check the argument validation."""
        set_rng_seed(0)
        model = models.__dict__[model_name](**self.model_defaults).eval()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        # Check that it works with both a list and dict for return nodes
        self._create_feature_extractor(
            model, train_return_nodes={v: v for v in train_return_nodes}, eval_return_nodes=eval_return_nodes
        )
        self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check must specify return nodes
        with pytest.raises(AssertionError):
            self._create_feature_extractor(model)
        # Check return_nodes and train_return_nodes / eval_return nodes
        # mutual exclusivity
        with pytest.raises(AssertionError):
            self._create_feature_extractor(
                model, return_nodes=train_return_nodes, train_return_nodes=train_return_nodes
            )
        # Check train_return_nodes / eval_return nodes must both be specified
        with pytest.raises(AssertionError):
            self._create_feature_extractor(model, train_return_nodes=train_return_nodes)
        # Check invalid node name raises ValueError. First double check that
        # none of the sampled nodes starts with "l"; otherwise skip this check.
        if not any(n.startswith("l") for n in chain(train_return_nodes, eval_return_nodes)):
            with pytest.raises(ValueError):
                self._create_feature_extractor(model, train_return_nodes=["l"], eval_return_nodes=["l"])

    def test_node_name_conventions(self):
        """Check the node names reported for ``TestModule`` against ``test_module_nodes``."""
        model = TestModule()
        train_nodes, _ = get_graph_node_names(model)
        # Compare whole lists: a zip-based comparison would silently pass if
        # either side were shorter (or empty).
        assert list(train_nodes) == test_module_nodes

    @pytest.mark.parametrize("model_name", get_available_models())
    def test_forward_backward(self, model_name):
        """Run a forward and backward pass through the extractor in train mode."""
        model = models.__dict__[model_name](**self.model_defaults).train()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        out = model(self.inp)
        self._sum_output_means(out).backward()

    def test_feature_extraction_methods_equivalence(self):
        """Compare the FX extractor with ``IntermediateLayerGetter`` on resnet18."""
        model = models.resnet18(**self.model_defaults).eval()
        return_layers = {"layer1": "layer1", "layer2": "layer2", "layer3": "layer3", "layer4": "layer4"}

        ilg_model = IntermediateLayerGetter(model, return_layers).eval()
        fx_model = self._create_feature_extractor(model, return_layers)

        # Check that we have same parameters
        for (n1, p1), (n2, p2) in zip(ilg_model.named_parameters(), fx_model.named_parameters()):
            assert n1 == n2
            assert p1.equal(p2)

        # And that outputs match
        with torch.no_grad():
            ilg_out = ilg_model(self.inp)
            fgn_out = fx_model(self.inp)
        # Compare full key lists so a missing or extra key is not hidden by zip truncation.
        assert list(ilg_out.keys()) == list(fgn_out.keys())
        for k in ilg_out.keys():
            assert ilg_out[k].equal(fgn_out[k])

    @pytest.mark.parametrize("model_name", get_available_models())
    def test_jit_forward_backward(self, model_name):
        """Script the extractor with TorchScript, then run forward and backward."""
        set_rng_seed(0)
        model = models.__dict__[model_name](**self.model_defaults).train()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        model = torch.jit.script(model)
        fgn_out = model(self.inp)
        self._sum_output_means(fgn_out).backward()

    def test_train_eval(self):
        """Check that the extractor follows train/eval mode for dropout and control flow."""

        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dropout = torch.nn.Dropout(p=1.0)

            def forward(self, x):
                x = x.mean()
                x = self.dropout(x)  # dropout
                if self.training:
                    x += 100  # add
                else:
                    x *= 0  # mul
                x -= 0  # sub
                return x

        model = TestModel()

        train_return_nodes = ["dropout", "add", "sub"]
        eval_return_nodes = ["dropout", "mul", "sub"]

        def checks(model, mode):
            with torch.no_grad():
                out = model(torch.ones(10, 10))
            if mode == "train":
                # Check that dropout is respected
                assert out["dropout"].item() == 0
                # Check that control flow dependent on training_mode is respected
                assert out["sub"].item() == 100
                assert "add" in out
                assert "mul" not in out
            elif mode == "eval":
                # Check that dropout is respected
                assert out["dropout"].item() == 1
                # Check that control flow dependent on training_mode is respected
                assert out["sub"].item() == 0
                assert "mul" in out
                assert "add" not in out

        # Starting from train mode
        model.train()
        fx_model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check that the models stay in their original training state
        assert model.training
        assert fx_model.training
        # Check outputs
        checks(fx_model, "train")
        # Check outputs after switching to eval mode
        fx_model.eval()
        checks(fx_model, "eval")

        # Starting from eval mode
        model.eval()
        fx_model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check that the models stay in their original training state
        assert not model.training
        assert not fx_model.training
        # Check outputs
        checks(fx_model, "eval")
        # Check outputs after switching to train mode
        fx_model.train()
        checks(fx_model, "train")

    def test_leaf_module_and_function(self):
        """Check that ``leaf_modules`` and ``autowrap_functions`` are honoured by the tracer."""

        class LeafModule(torch.nn.Module):
            def forward(self, x):
                # This would raise a TypeError if it were not in a leaf module
                int(x.shape[0])
                return torch.nn.functional.relu(x + 4)

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 1, 3)
                self.leaf_module = LeafModule()

            def forward(self, x):
                leaf_function(x.shape[0])
                x = self.conv(x)
                return self.leaf_module(x)

        model = self._create_feature_extractor(
            TestModule(),
            return_nodes=["leaf_module"],
            tracer_kwargs={"leaf_modules": [LeafModule], "autowrap_functions": [leaf_function]},
        ).train()

        # Check that LeafModule shows up as a single node and that the relu
        # inside it is not in the list of nodes
        node_names = [str(n) for n in model.graph.nodes]
        assert "relu" not in node_names
        assert "leaf_module" in node_names

        # Check forward
        out = model(self.inp)
        # And backward
        out["leaf_module"].mean().backward()