"graphbolt/CMakeLists.txt" did not exist on "7a2f4943736f6b6ccd93b069d2b945c5b4ec9d9f"
test_backbone_utils.py 10.6 KB
Newer Older
1
import random
from itertools import chain

import pytest
import torch
from common_utils import set_rng_seed
from torchvision import models
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.feature_extraction import get_graph_node_names


def get_available_models():
    # TODO add a registration mechanism to torchvision.models
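    # Heuristic: model builder functions are the public (non-underscore),
    # lowercase-named callables in the models namespace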
    return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]


@pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50"))
def test_resnet_fpn_backbone(backbone_name):
    x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu")
    y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x)
    assert list(y.keys()) == ["0", "1", "2", "3", "pool"]
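    # resnet_fpn_backbone returns an OrderedDict with one FPN feature map per
    # returned ResNet stage ("0"-"3") plus an extra "pool" level from the
    # default LastLevelMaxPool extra block, hence the five keys checked above.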


# Needed by TestFxFeatureExtraction.test_leaf_module_and_function
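# The int() call would break FX symbolic tracing, which is why that test
# passes this function through the tracer's autowrap_functions.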
def leaf_function(x):
    return int(x)


# Needed by TestFxFeatureExtraction. Checks that node naming conventions
# are respected, particularly the index postfix for repeated node names.
class TestSubModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = x + 1
        x = x + 1
        x = self.relu(x)
        x = self.relu(x)
        return x


class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.submodule = TestSubModule()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.submodule(x)
        x = x + 1
        x = x + 1
        x = self.relu(x)
        x = self.relu(x)
        return x


test_module_nodes = [
    "x",
    "submodule.add",
    "submodule.add_1",
    "submodule.relu",
    "submodule.relu_1",
    "add",
    "add_1",
    "relu",
    "relu_1",
]
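# Node names are qualified by submodule path (e.g. "submodule.add"), and
# repeated calls to the same op or module get an incrementing "_1", "_2", ...
# postfix, as listed above.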


class TestFxFeatureExtraction:
    inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device="cpu")
    model_defaults = {"num_classes": 1, "pretrained": False}
    leaf_modules = []

    def _create_feature_extractor(self, *args, **kwargs):
        """
        Wrapper around create_feature_extractor that applies this class's
        default leaf modules unless tracer_kwargs is passed explicitly.
        """
        if "tracer_kwargs" not in kwargs:
            tracer_kwargs = {"leaf_modules": self.leaf_modules}
        else:
            tracer_kwargs = kwargs.pop("tracer_kwargs")
        return create_feature_extractor(*args, **kwargs, tracer_kwargs=tracer_kwargs, suppress_diff_warning=True)

    def _get_return_nodes(self, model):
        set_rng_seed(0)
        exclude_nodes_filter = ["getitem", "floordiv", "size", "chunk"]
        train_nodes, eval_nodes = get_graph_node_names(
            model, tracer_kwargs={"leaf_modules": self.leaf_modules}, suppress_diff_warning=True
        )
        # Get rid of any nodes that don't return tensors as they cause issues
        # when testing backward pass.
        train_nodes = [n for n in train_nodes if not any(x in n for x in exclude_nodes_filter)]
        eval_nodes = [n for n in eval_nodes if not any(x in n for x in exclude_nodes_filter)]
        return random.sample(train_nodes, 10), random.sample(eval_nodes, 10)

    @pytest.mark.parametrize("model_name", get_available_models())
    def test_build_fx_feature_extractor(self, model_name):
        set_rng_seed(0)
        model = models.__dict__[model_name](**self.model_defaults).eval()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        # Check that it works with both a list and dict for return nodes
        self._create_feature_extractor(
            model, train_return_nodes={v: v for v in train_return_nodes}, eval_return_nodes=eval_return_nodes
        )
        self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check must specify return nodes
        with pytest.raises(AssertionError):
            self._create_feature_extractor(model)
        # Check return_nodes and train_return_nodes / eval_return_nodes
        # mutual exclusivity
        with pytest.raises(AssertionError):
            self._create_feature_extractor(
                model, return_nodes=train_return_nodes, train_return_nodes=train_return_nodes
            )
        # Check train_return_nodes / eval_return_nodes must both be specified
        with pytest.raises(AssertionError):
            self._create_feature_extractor(model, train_return_nodes=train_return_nodes)
        # Check invalid node name raises ValueError
        with pytest.raises(ValueError):
            # First just double check that this node really doesn't exist
            if not any(n.startswith("l") or n.startswith("l.") for n in chain(train_return_nodes, eval_return_nodes)):
                self._create_feature_extractor(model, train_return_nodes=["l"], eval_return_nodes=["l"])
            else:  # otherwise skip this check
                raise ValueError

    def test_node_name_conventions(self):
        model = TestModule()
        train_nodes, _ = get_graph_node_names(model)
        assert all(a == b for a, b in zip(train_nodes, test_module_nodes))

    @pytest.mark.parametrize("model_name", get_available_models())
    def test_forward_backward(self, model_name):
        model = models.__dict__[model_name](**self.model_defaults).train()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        out = model(self.inp)
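        # Summing the means of all returned feature maps yields a single
        # scalar to call backward() on, exercising every output's graph.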
        sum([o.mean() for o in out.values()]).backward()

    def test_feature_extraction_methods_equivalence(self):
        model = models.resnet18(**self.model_defaults).eval()
        return_layers = {"layer1": "layer1", "layer2": "layer2", "layer3": "layer3", "layer4": "layer4"}

        ilg_model = IntermediateLayerGetter(model, return_layers).eval()
        fx_model = self._create_feature_extractor(model, return_layers)
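        # For plain top-level layers like these, the FX-based extractor should
        # behave identically to IntermediateLayerGetter.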

        # Check that we have same parameters
        for (n1, p1), (n2, p2) in zip(ilg_model.named_parameters(), fx_model.named_parameters()):
            assert n1 == n2
            assert p1.equal(p2)

        # And that outputs match
        with torch.no_grad():
            ilg_out = ilg_model(self.inp)
            fgn_out = fx_model(self.inp)
        assert all(k1 == k2 for k1, k2 in zip(ilg_out.keys(), fgn_out.keys()))
        for k in ilg_out.keys():
            assert ilg_out[k].equal(fgn_out[k])

    @pytest.mark.parametrize("model_name", get_available_models())
    def test_jit_forward_backward(self, model_name):
        set_rng_seed(0)
        model = models.__dict__[model_name](**self.model_defaults).train()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
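        # The extracted module is expected to remain TorchScript-scriptable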
        model = torch.jit.script(model)
        fgn_out = model(self.inp)
        sum([o.mean() for o in fgn_out.values()]).backward()

    def test_train_eval(self):
        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
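                # With p=1.0, dropout zeroes its input in train mode and is
                # the identity in eval mode, which the checks below rely on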
                self.dropout = torch.nn.Dropout(p=1.0)

            def forward(self, x):
                x = x.mean()
                x = self.dropout(x)  # dropout
                if self.training:
                    x += 100  # add
                else:
                    x *= 0  # mul
                x -= 0  # sub
                return x

        model = TestModel()

        train_return_nodes = ["dropout", "add", "sub"]
        eval_return_nodes = ["dropout", "mul", "sub"]

        def checks(model, mode):
            with torch.no_grad():
                out = model(torch.ones(10, 10))
            if mode == "train":
                # Check that dropout is respected
                assert out["dropout"].item() == 0
                # Check that control flow dependent on training mode is respected
                assert out["sub"].item() == 100
                assert "add" in out
                assert "mul" not in out
            elif mode == "eval":
                # Check that dropout is respected
                assert out["dropout"].item() == 1
                # Check that control flow dependent on training mode is respected
                assert out["sub"].item() == 0
                assert "mul" in out
                assert "add" not in out

        # Starting from train mode
        model.train()
        fx_model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check that the models stay in their original training state
        assert model.training
        assert fx_model.training
        # Check outputs
        checks(fx_model, "train")
        # Check outputs after switching to eval mode
        fx_model.eval()
        checks(fx_model, "eval")

        # Starting from eval mode
        model.eval()
        fx_model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check that the models stay in their original training state
        assert not model.training
        assert not fx_model.training
        # Check outputs
        checks(fx_model, "eval")
        # Check outputs after switching to train mode
        fx_model.train()
        checks(fx_model, "train")

    def test_leaf_module_and_function(self):
        class LeafModule(torch.nn.Module):
            def forward(self, x):
                # This would raise a TypeError if it were not in a leaf module
                int(x.shape[0])
                return torch.nn.functional.relu(x + 4)

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 1, 3)
                self.leaf_module = LeafModule()

            def forward(self, x):
                leaf_function(x.shape[0])
                x = self.conv(x)
                return self.leaf_module(x)

        model = self._create_feature_extractor(
            TestModule(),
            return_nodes=["leaf_module"],
            tracer_kwargs={"leaf_modules": [LeafModule], "autowrap_functions": [leaf_function]},
        ).train()

        # Check that tracing did not descend into LeafModule: its internal
        # relu is not a graph node, but the leaf_module call itself is
        assert "relu" not in [str(n) for n in model.graph.nodes]
        assert "leaf_module" in [str(n) for n in model.graph.nodes]

        # Check forward
        out = model(self.inp)
        # And backward
        out["leaf_module"].mean().backward()