test_backbone_utils.py 13.1 KB
Newer Older
1
import random
2
from itertools import chain
3
from typing import Mapping, Sequence
4

5
import pytest
6
import torch
7
8
9
from common_utils import set_rng_seed
from torchvision import models
from torchvision.models._utils import IntermediateLayerGetter
10
from torchvision.models.detection.backbone_utils import BackboneWithFPN, mobilenet_backbone, resnet_fpn_backbone
11
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
12
13


14
@pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50"))
def test_resnet_fpn_backbone(backbone_name):
    """resnet_fpn_backbone builds a BackboneWithFPN and rejects out-of-range layer arguments."""
    x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu")
    model = resnet_fpn_backbone(backbone_name=backbone_name, weights=None)
    assert isinstance(model, BackboneWithFPN)

    # With default arguments the FPN output dict is keyed by these level names.
    y = model(x)
    assert list(y.keys()) == ["0", "1", "2", "3", "pool"]

    # Invalid trainable_layers / returned_layers values must raise ValueError.
    with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
        resnet_fpn_backbone(backbone_name=backbone_name, weights=None, trainable_layers=6)
    with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
        resnet_fpn_backbone(backbone_name=backbone_name, weights=None, returned_layers=[0, 1, 2, 3])
    with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
        resnet_fpn_backbone(backbone_name=backbone_name, weights=None, returned_layers=[2, 3, 4, 5])


@pytest.mark.parametrize("backbone_name", ("mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"))
def test_mobilenet_backbone(backbone_name):
    """mobilenet_backbone validates its layer arguments and returns the expected module type."""
    # Invalid trainable_layers / returned_layers values must raise ValueError.
    with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
        mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=False, trainable_layers=-1)
    with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
        mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True, returned_layers=[-1, 0, 1, 2])
    with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
        mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True, returned_layers=[3, 4, 5, 6])

    # fpn=True wraps the features in a BackboneWithFPN; fpn=False yields a plain Sequential.
    model_fpn = mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True)
    assert isinstance(model_fpn, BackboneWithFPN)
    model = mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=False)
    assert isinstance(model, torch.nn.Sequential)

# Needed by TestFxFeatureExtraction.test_leaf_module_and_function
def leaf_function(x):
    """Return ``x`` converted with the builtin ``int``.

    The test registers this function with the FX tracer as an autowrapped
    function, so it is recorded as a single call instead of being traced into.
    """
    as_int = int(x)
    return as_int


49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Needed by TestFxFeatureExtraction.test_node_name_conventions to check that node
# naming conventions are respected, particularly the index postfix of repeated
# node names.
class TestSubModule(torch.nn.Module):
    # Helper module, not a test class: stop pytest from trying to collect it
    # (and warning that it has an __init__) because of its ``Test`` prefix.
    __test__ = False

    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        # Each op is deliberately repeated so the traced graph contains indexed
        # duplicates; the exact statement sequence determines the node names.
        x = x + 1
        x = x + 1
        x = self.relu(x)
        x = self.relu(x)
        return x


class TestModule(torch.nn.Module):
    """Wraps TestSubModule and repeats the same op pattern at the top level."""

    # Helper module, not a test class: stop pytest from trying to collect it
    # (and warning that it has an __init__) because of its ``Test`` prefix.
    __test__ = False

    def __init__(self):
        super().__init__()
        self.submodule = TestSubModule()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        # The statement sequence is significant: it fixes the traced node names
        # (submodule.* first, then the top-level add/add_1/relu/relu_1).
        x = self.submodule(x)
        x = x + 1
        x = x + 1
        x = self.relu(x)
        x = self.relu(x)
        return x


# Expected train-mode node names produced by get_graph_node_names(TestModule()),
# in graph order; checked by TestFxFeatureExtraction.test_node_name_conventions.
test_module_nodes = [
    "x",
    "submodule.add",
    "submodule.add_1",
    "submodule.relu",
    "submodule.relu_1",
    "add",
    "add_1",
    "relu",
    "relu_1",
]


92
class TestFxFeatureExtraction:
    """Tests for FX-based feature extraction (create_feature_extractor / get_graph_node_names)."""

    # Single random input batch of shape (1, 3, 224, 224) shared by the tests.
    inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device="cpu")
    # Keyword arguments forwarded to every model constructor.
    model_defaults = {"num_classes": 1}
    # Leaf modules passed to the FX tracer by default (none by default).
    leaf_modules = []

    def _create_feature_extractor(self, *args, **kwargs):
        """Call create_feature_extractor, applying ``self.leaf_modules`` by default.

        An explicitly passed ``tracer_kwargs`` is used unchanged; otherwise the
        tracer is configured with this class's leaf modules. The train/eval
        graph-difference warning is always suppressed.
        """
        if "tracer_kwargs" in kwargs:
            tracer_kwargs = kwargs.pop("tracer_kwargs")
        else:
            tracer_kwargs = {"leaf_modules": self.leaf_modules}
        return create_feature_extractor(*args, **kwargs, tracer_kwargs=tracer_kwargs, suppress_diff_warning=True)

    def _get_return_nodes(self, model):
        """Sample 10 train-mode and 10 eval-mode node names from ``model``'s graph.

        Calls set_rng_seed(0) first, presumably so the sampled nodes are
        reproducible. random.sample raises ValueError if fewer than 10 nodes
        survive the filter below.
        """
        set_rng_seed(0)
        exclude_nodes_filter = [
            "getitem",
            "floordiv",
            "size",
            "chunk",
            "_assert",
            "eq",
            "dim",
            "getattr",
        ]
        train_nodes, eval_nodes = get_graph_node_names(
            model, tracer_kwargs={"leaf_modules": self.leaf_modules}, suppress_diff_warning=True
        )
        # Get rid of any nodes that don't return tensors as they cause issues
        # when testing backward pass.
        train_nodes = [n for n in train_nodes if not any(x in n for x in exclude_nodes_filter)]
        eval_nodes = [n for n in eval_nodes if not any(x in n for x in exclude_nodes_filter)]
        return random.sample(train_nodes, 10), random.sample(eval_nodes, 10)

    @staticmethod
    def _sum_of_means(outputs):
        """Reduce a dict of extracted features to one scalar for a backward pass.

        Each value may be a sequence of tensors, a mapping of tensors, or a
        tensor; ``None`` entries inside sequences/mappings are skipped.
        """
        out_agg = 0
        for node_out in outputs.values():
            if isinstance(node_out, Sequence):
                out_agg += sum(o.float().mean() for o in node_out if o is not None)
            elif isinstance(node_out, Mapping):
                out_agg += sum(o.float().mean() for o in node_out.values() if o is not None)
            else:
                # Assume that the only other alternative at this point is a Tensor
                out_agg += node_out.float().mean()
        return out_agg

    @pytest.mark.parametrize("model_name", models.list_models(models))
    def test_build_fx_feature_extractor(self, model_name):
        set_rng_seed(0)
        model = models.get_model(model_name, **self.model_defaults).eval()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        # Check that it works with both a list and dict for return nodes
        self._create_feature_extractor(
            model, train_return_nodes={v: v for v in train_return_nodes}, eval_return_nodes=eval_return_nodes
        )
        self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check must specify return nodes
        with pytest.raises(ValueError):
            self._create_feature_extractor(model)
        # Check return_nodes and train_return_nodes / eval_return nodes
        # mutual exclusivity
        with pytest.raises(ValueError):
            self._create_feature_extractor(
                model, return_nodes=train_return_nodes, train_return_nodes=train_return_nodes
            )
        # Check train_return_nodes / eval_return nodes must both be specified
        with pytest.raises(ValueError):
            self._create_feature_extractor(model, train_return_nodes=train_return_nodes)
        # Check invalid node name raises ValueError
        with pytest.raises(ValueError):
            # First just double check that this node really doesn't exist
            # (startswith("l") already covers names beginning with "l.").
            if not any(n.startswith("l") for n in chain(train_return_nodes, eval_return_nodes)):
                self._create_feature_extractor(model, train_return_nodes=["l"], eval_return_nodes=["l"])
            else:  # otherwise skip this check
                raise ValueError

    def test_node_name_conventions(self):
        model = TestModule()
        train_nodes, _ = get_graph_node_names(model)
        # Exact comparison: a zip-based check would silently pass if nodes
        # were missing or extra nodes were produced.
        assert train_nodes == test_module_nodes

    @pytest.mark.parametrize("model_name", models.list_models(models))
    def test_forward_backward(self, model_name):
        model = models.get_model(model_name, **self.model_defaults).train()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        out = model(self.inp)
        self._sum_of_means(out).backward()

    def test_feature_extraction_methods_equivalence(self):
        model = models.resnet18(**self.model_defaults).eval()
        return_layers = {"layer1": "layer1", "layer2": "layer2", "layer3": "layer3", "layer4": "layer4"}

        ilg_model = IntermediateLayerGetter(model, return_layers).eval()
        fx_model = self._create_feature_extractor(model, return_layers)

        # Check that we have same parameters (and the same number of them,
        # which a bare zip would not detect)
        ilg_params = list(ilg_model.named_parameters())
        fx_params = list(fx_model.named_parameters())
        assert len(ilg_params) == len(fx_params)
        for (n1, p1), (n2, p2) in zip(ilg_params, fx_params):
            assert n1 == n2
            assert p1.equal(p2)

        # And that outputs match
        with torch.no_grad():
            ilg_out = ilg_model(self.inp)
            fgn_out = fx_model(self.inp)
        assert list(ilg_out.keys()) == list(fgn_out.keys())
        for k in ilg_out.keys():
            assert ilg_out[k].equal(fgn_out[k])

    @pytest.mark.parametrize("model_name", models.list_models(models))
    def test_jit_forward_backward(self, model_name):
        set_rng_seed(0)
        model = models.get_model(model_name, **self.model_defaults).train()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        model = torch.jit.script(model)
        fgn_out = model(self.inp)
        self._sum_of_means(fgn_out).backward()

    def test_train_eval(self):
        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dropout = torch.nn.Dropout(p=1.0)

            def forward(self, x):
                x = x.float().mean()
                x = self.dropout(x)  # dropout
                if self.training:
                    x += 100  # add
                else:
                    x *= 0  # mul
                x -= 0  # sub
                return x

        model = TestModel()

        train_return_nodes = ["dropout", "add", "sub"]
        eval_return_nodes = ["dropout", "mul", "sub"]

        def checks(model, mode):
            with torch.no_grad():
                out = model(torch.ones(10, 10))
            if mode == "train":
                # Check that dropout is respected
                assert out["dropout"].item() == 0
                # Check that control flow dependent on training_mode is respected
                assert out["sub"].item() == 100
                assert "add" in out
                assert "mul" not in out
            elif mode == "eval":
                # Check that dropout is respected
                assert out["dropout"].item() == 1
                # Check that control flow dependent on training_mode is respected
                assert out["sub"].item() == 0
                assert "mul" in out
                assert "add" not in out

        # Starting from train mode
        model.train()
        fx_model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check that the models stay in their original training state
        assert model.training
        assert fx_model.training
        # Check outputs
        checks(fx_model, "train")
        # Check outputs after switching to eval mode
        fx_model.eval()
        checks(fx_model, "eval")

        # Starting from eval mode
        model.eval()
        fx_model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check that the models stay in their original training state
        assert not model.training
        assert not fx_model.training
        # Check outputs
        checks(fx_model, "eval")
        # Check outputs after switching to train mode
        fx_model.train()
        checks(fx_model, "train")

    def test_leaf_module_and_function(self):
        class LeafModule(torch.nn.Module):
            def forward(self, x):
                # This would raise a TypeError if it were not in a leaf module
                int(x.shape[0])
                return torch.nn.functional.relu(x + 4)

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 1, 3)
                self.leaf_module = LeafModule()

            def forward(self, x):
                leaf_function(x.shape[0])
                x = self.conv(x)
                return self.leaf_module(x)

        model = self._create_feature_extractor(
            TestModule(),
            return_nodes=["leaf_module"],
            tracer_kwargs={"leaf_modules": [LeafModule], "autowrap_functions": [leaf_function]},
        ).train()

        # Check that LeafModule is not in the list of nodes
        assert "relu" not in [str(n) for n in model.graph.nodes]
        assert "leaf_module" in [str(n) for n in model.graph.nodes]

        # Check forward
        out = model(self.inp)
        # And backward
        out["leaf_module"].float().mean().backward()