"vscode:/vscode.git/clone" did not exist on "b9e2f886cd6e9182f1bf1bf7421c6363956f94c5"
test_backbone_utils.py 10.7 KB
Newer Older
1
import random
from functools import partial
from itertools import chain

import pytest
import torch
import torchvision
from common_utils import set_rng_seed
from torchvision import models
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.feature_extraction import get_graph_node_names
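
# For reference, the feature extraction API exercised below is typically used
# along these lines (an illustrative sketch only; the tests themselves are the
# authoritative usage):
#
#     model = models.resnet18()
#     extractor = create_feature_extractor(model, return_nodes={"layer1": "feat1"})
#     feats = extractor(torch.rand(1, 3, 224, 224))  # {"feat1": Tensor}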


def get_available_models():
    # TODO add a registration mechanism to torchvision.models
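    # Heuristic: public (non-underscore), lowercase-named callables in
    # models.__dict__ are assumed to be model builder functions.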
    return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]


@pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50"))
def test_resnet_fpn_backbone(backbone_name):
    x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu")
    y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x)
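    # The FPN backbone returns an OrderedDict of feature maps: one entry per
    # pyramid level plus the final pooled level.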
    assert list(y.keys()) == ["0", "1", "2", "3", "pool"]


# Needed by TestFxFeatureExtraction.test_leaf_module_and_function
def leaf_function(x):
    return int(x)


# Needed by TestFxFeatureExtraction. Checks that node naming conventions are
# respected, particularly the index postfix appended to repeated node names.
class TestSubModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
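        # The repeated ops below should be disambiguated by the tracer with an
        # index postfix: add / add_1 and relu / relu_1.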
        x = x + 1
        x = x + 1
        x = self.relu(x)
        x = self.relu(x)
        return x


class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.submodule = TestSubModule()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.submodule(x)
        x = x + 1
        x = x + 1
        x = self.relu(x)
        x = self.relu(x)
        return x


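# Expected train-mode node names for TestModule, in execution order. Nodes
# originating inside the submodule are prefixed with its attribute name.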
test_module_nodes = [
    "x",
    "submodule.add",
    "submodule.add_1",
    "submodule.relu",
    "submodule.relu_1",
    "add",
    "add_1",
    "relu",
    "relu_1",
]


class TestFxFeatureExtraction:
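    # Shared fixtures: a fixed input, cheap model defaults (single output
    # class, no pretrained weights), and modules to treat as tracing leaves.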
    inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device="cpu")
    model_defaults = {"num_classes": 1, "pretrained": False}
    leaf_modules = []

    def _create_feature_extractor(self, *args, **kwargs):
        """
        Build a feature extractor, applying the class-level leaf modules
        unless the caller supplies its own tracer_kwargs.
        """
        if "tracer_kwargs" not in kwargs:
            tracer_kwargs = {"leaf_modules": self.leaf_modules}
        else:
            tracer_kwargs = kwargs.pop("tracer_kwargs")
        return create_feature_extractor(*args, **kwargs, tracer_kwargs=tracer_kwargs, suppress_diff_warning=True)

    def _get_return_nodes(self, model):
        set_rng_seed(0)
        exclude_nodes_filter = ["getitem", "floordiv", "size", "chunk"]
        train_nodes, eval_nodes = get_graph_node_names(
            model, tracer_kwargs={"leaf_modules": self.leaf_modules}, suppress_diff_warning=True
        )
        # Get rid of any nodes that don't return tensors as they cause issues
        # when testing backward pass.
        train_nodes = [n for n in train_nodes if not any(x in n for x in exclude_nodes_filter)]
        eval_nodes = [n for n in eval_nodes if not any(x in n for x in exclude_nodes_filter)]
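        # Sample a subset of the remaining nodes to keep the runtime of the
        # parametrized tests manageable.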
        return random.sample(train_nodes, 10), random.sample(eval_nodes, 10)

    @pytest.mark.parametrize("model_name", get_available_models())
    def test_build_fx_feature_extractor(self, model_name):
        set_rng_seed(0)
        model = models.__dict__[model_name](**self.model_defaults).eval()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        # Check that it works with both a list and dict for return nodes
        self._create_feature_extractor(
            model, train_return_nodes={v: v for v in train_return_nodes}, eval_return_nodes=eval_return_nodes
        )
        self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check must specify return nodes
        with pytest.raises(AssertionError):
            self._create_feature_extractor(model)
        # Check mutual exclusivity between return_nodes and
        # train_return_nodes / eval_return_nodes
        with pytest.raises(AssertionError):
            self._create_feature_extractor(
                model, return_nodes=train_return_nodes, train_return_nodes=train_return_nodes
            )
        # Check train_return_nodes / eval_return_nodes must both be specified
        with pytest.raises(AssertionError):
            self._create_feature_extractor(model, train_return_nodes=train_return_nodes)
        # Check invalid node name raises ValueError
        with pytest.raises(ValueError):
            # First just double check that this node really doesn't exist
            if not any(n.startswith("l") or n.startswith("l.") for n in chain(train_return_nodes, eval_return_nodes)):
                self._create_feature_extractor(model, train_return_nodes=["l"], eval_return_nodes=["l"])
            else:  # otherwise skip this check
                raise ValueError

    def test_node_name_conventions(self):
        model = TestModule()
        train_nodes, _ = get_graph_node_names(model)
        assert all(a == b for a, b in zip(train_nodes, test_module_nodes))

    @pytest.mark.parametrize("model_name", get_available_models())
    def test_forward_backward(self, model_name):
        model = models.__dict__[model_name](**self.model_defaults).train()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        out = model(self.inp)
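        # Reduce all returned features to a single scalar so that backward()
        # exercises every output.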
        sum([o.mean() for o in out.values()]).backward()

    def test_feature_extraction_methods_equivalence(self):
        model = models.resnet18(**self.model_defaults).eval()
        return_layers = {"layer1": "layer1", "layer2": "layer2", "layer3": "layer3", "layer4": "layer4"}

        ilg_model = IntermediateLayerGetter(model, return_layers).eval()
        fx_model = self._create_feature_extractor(model, return_layers)

        # Check that we have same parameters
        for (n1, p1), (n2, p2) in zip(ilg_model.named_parameters(), fx_model.named_parameters()):
            assert n1 == n2
            assert p1.equal(p2)

        # And that outputs match
        with torch.no_grad():
            ilg_out = ilg_model(self.inp)
            fgn_out = fx_model(self.inp)
        assert all(k1 == k2 for k1, k2 in zip(ilg_out.keys(), fgn_out.keys()))
        for k in ilg_out.keys():
            assert ilg_out[k].equal(fgn_out[k])

    @pytest.mark.parametrize("model_name", get_available_models())
    def test_jit_forward_backward(self, model_name):
        set_rng_seed(0)
        model = models.__dict__[model_name](**self.model_defaults).train()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
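        # The FX-derived GraphModule should remain scriptable and support
        # autograd after scripting.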
        model = torch.jit.script(model)
        fgn_out = model(self.inp)
        sum([o.mean() for o in fgn_out.values()]).backward()

    def test_train_eval(self):
        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dropout = torch.nn.Dropout(p=1.0)

            def forward(self, x):
                x = x.mean()
                x = self.dropout(x)  # dropout
                if self.training:
                    x += 100  # add
                else:
                    x *= 0  # mul
                x -= 0  # sub
                return x

        model = TestModel()

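        # The requested node names correspond to the ops labelled in
        # TestModel.forward above.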
        train_return_nodes = ["dropout", "add", "sub"]
        eval_return_nodes = ["dropout", "mul", "sub"]

        def checks(model, mode):
            with torch.no_grad():
                out = model(torch.ones(10, 10))
            if mode == "train":
                # Check that dropout is respected
                assert out["dropout"].item() == 0
                # Check that control flow dependent on the training mode is respected
                assert out["sub"].item() == 100
                assert "add" in out
                assert "mul" not in out
            elif mode == "eval":
                # Check that dropout is respected
                assert out["dropout"].item() == 1
                # Check that control flow dependent on the training mode is respected
                assert out["sub"].item() == 0
                assert "mul" in out
                assert "add" not in out

        # Starting from train mode
        model.train()
        fx_model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check that the models stay in their original training state
        assert model.training
        assert fx_model.training
        # Check outputs
        checks(fx_model, "train")
        # Check outputs after switching to eval mode
        fx_model.eval()
        checks(fx_model, "eval")

        # Starting from eval mode
        model.eval()
        fx_model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check that the models stay in their original training state
        assert not model.training
        assert not fx_model.training
        # Check outputs
        checks(fx_model, "eval")
        # Check outputs after switching to train mode
        fx_model.train()
        checks(fx_model, "train")

    def test_leaf_module_and_function(self):
        class LeafModule(torch.nn.Module):
            def forward(self, x):
                # This would raise a TypeError if it were not in a leaf module
                int(x.shape[0])
                return torch.nn.functional.relu(x + 4)

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 1, 3)
                self.leaf_module = LeafModule()

            def forward(self, x):
                leaf_function(x.shape[0])
                x = self.conv(x)
                return self.leaf_module(x)

        model = self._create_feature_extractor(
            TestModule(),
            return_nodes=["leaf_module"],
            tracer_kwargs={"leaf_modules": [LeafModule], "autowrap_functions": [leaf_function]},
        ).train()

        # Check that LeafModule was kept as a single node: its internal relu
        # must not appear in the graph, while the module call itself should
        assert "relu" not in [str(n) for n in model.graph.nodes]
        assert "leaf_module" in [str(n) for n in model.graph.nodes]

        # Check forward
        out = model(self.inp)
        # And backward
        out["leaf_module"].mean().backward()