# test_models.py
from collections import OrderedDict
from itertools import product
import torch
from torchvision import models
import unittest
import traceback


def get_available_classification_models():
    """Return the names of the public model builders in ``torchvision.models``.

    A module attribute is kept when its value is callable and the first
    character of its name is neither ``_`` nor changed by ``str.lower`` --
    i.e. private helpers and (by naming convention) classes are skipped.
    """
    # TODO add a registration mechanism to torchvision.models
    found = []
    for attr_name, attr_value in vars(models).items():
        if not callable(attr_value):
            continue
        head = attr_name[0]
        if head != "_" and head == head.lower():
            found.append(attr_name)
    return found


def get_available_segmentation_models():
    """Return the names of the public model builders in ``torchvision.models.segmentation``.

    Keeps callables whose name starts with a character that is neither ``_``
    nor changed by ``str.lower`` (so private helpers and CamelCase classes
    are left out).
    """
    # TODO add a registration mechanism to torchvision.models
    namespace = vars(models.segmentation)
    return [
        attr for attr, obj in namespace.items()
        if callable(obj) and attr[0] == attr[0].lower() and attr[0] != "_"
    ]


def get_available_detection_models():
    """Return the names of the public model builders in ``torchvision.models.detection``.

    Keeps callables whose name starts with a character that is neither ``_``
    nor changed by ``str.lower`` (so private helpers and CamelCase classes
    are left out).
    """
    # TODO add a registration mechanism to torchvision.models
    namespace = vars(models.detection)
    return [
        attr for attr, obj in namespace.items()
        if callable(obj) and attr[0] == attr[0].lower() and attr[0] != "_"
    ]


def get_available_video_models():
    """Return the names of the public model builders in ``torchvision.models.video``.

    Keeps callables whose name starts with a character that is neither ``_``
    nor changed by ``str.lower`` (so private helpers and CamelCase classes
    are left out).
    """
    # TODO add a registration mechanism to torchvision.models
    namespace = vars(models.video)
    return [
        attr for attr, obj in namespace.items()
        if callable(obj) and attr[0] == attr[0].lower() and attr[0] != "_"
    ]


# Model builder name -> whether torch.jit.script is expected to succeed on it.
# Names that are absent from this table are not script-checked by the tests.
torchub_models = dict(
    deeplabv3_resnet101=False,
    mobilenet_v2=True,
    resnext50_32x4d=True,
    fcn_resnet101=False,
    googlenet=False,
    densenet121=True,
    resnet18=True,
    alexnet=True,
    shufflenet_v2_x1_0=True,
    squeezenet1_0=True,
    vgg11=True,
    inception_v3=False,
)


class Tester(unittest.TestCase):
    """Smoke tests for torchvision's model builders.

    Besides the hand-written ``test_*`` methods below, this file attaches one
    ``test_<model_name>`` method per discovered model via ``setattr`` after the
    class definition.
    """

    def check_script(self, model, name):
        """Assert that ``torch.jit.script(model)`` succeeds exactly when expected.

        The expectation comes from the module-level ``torchub_models`` table;
        models that are not listed there are not checked at all.
        """
        if name not in torchub_models:
            return
        scriptable = True
        msg = ""
        try:
            torch.jit.script(model)
        except Exception as e:
            # Some models are expected to fail scripting, so the failure is
            # recorded instead of propagated; the traceback is kept so an
            # unexpected outcome is easy to diagnose.
            scriptable = False
            msg = str(e) + traceback.format_exc()
        self.assertEqual(torchub_models[name], scriptable, msg)

    def _test_classification_model(self, name, input_shape):
        """Build classifier ``name`` with 50 classes and check its output width."""
        # Passing a num_classes other than the default of 1000 makes the test
        # more enforcing: a builder that ignores the argument is caught.
        model = models.__dict__[name](num_classes=50)
        self.check_script(model, name)

        model.eval()
        x = torch.rand(input_shape)
        out = model(x)
        self.assertEqual(out.shape[-1], 50)

    def _test_segmentation_model(self, name):
        """Build segmentation model ``name`` with 50 classes and check ``out["out"]``'s shape."""
        # A num_classes other than the default 1000 makes the shape check
        # meaningful. pretrained_backbone=False presumably avoids fetching
        # backbone weights -- confirm against the builder.
        model = models.segmentation.__dict__[name](num_classes=50, pretrained_backbone=False)
        self.check_script(model, name)

        model.eval()
        input_shape = (1, 3, 300, 300)
        x = torch.rand(input_shape)
        out = model(x)
        # The prediction is expected at the full input resolution.
        self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300))

    def _check_detection_output(self, out):
        """Assert that ``out`` holds a single prediction with boxes, scores and labels."""
        self.assertEqual(len(out), 1)
        for key in ("boxes", "scores", "labels"):
            self.assertIn(key, out[0])

    def _test_detection_model(self, name):
        """Build detection model ``name`` with 50 classes and check its eval-mode output."""
        model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False)
        self.check_script(model, name)

        model.eval()
        input_shape = (3, 300, 300)
        x = torch.rand(input_shape)
        model_input = [x]
        out = model(model_input)
        # The model must not rebind entries of the caller's input list.
        self.assertIs(model_input[0], x)
        self._check_detection_output(out)

    def _test_video_model(self, name):
        """Build video model ``name`` with 50 classes and check its output width."""
        # Input layout: batch x channels x clip_len x height x width.
        input_shape = (1, 3, 4, 112, 112)
        model = models.video.__dict__[name](num_classes=50)
        self.check_script(model, name)

        x = torch.rand(input_shape)
        out = model(x)
        self.assertEqual(out.shape[-1], 50)

    def _make_sliced_model(self, model, stop_layer):
        """Return a ``torch.nn.Sequential`` of ``model``'s children up to and including ``stop_layer``.

        If no child is named ``stop_layer``, every child is kept.
        """
        layers = OrderedDict()
        for name, layer in model.named_children():
            layers[name] = layer
            if name == stop_layer:
                break
        return torch.nn.Sequential(layers)

    def test_memory_efficient_densenet(self):
        """Memory-efficient DenseNets must match the regular ones given identical weights."""
        input_shape = (1, 3, 300, 300)
        x = torch.rand(input_shape)

        for name in ['densenet121', 'densenet169', 'densenet201', 'densenet161']:
            model1 = models.__dict__[name](num_classes=50, memory_efficient=True)
            params = model1.state_dict()
            model1.eval()
            out1 = model1(x)
            # Also exercise the backward pass of the memory-efficient variant.
            out1.sum().backward()

            model2 = models.__dict__[name](num_classes=50, memory_efficient=False)
            model2.load_state_dict(params)
            model2.eval()
            out2 = model2(x)

            max_diff = (out1 - out2).abs().max().item()
            self.assertLess(max_diff, 1e-5, name)

    def test_resnet_dilation(self):
        """Each stage switched to dilation should double layer4's spatial output size."""
        # TODO improve tests to also check that each layer has the right dimensionality
        for dilation in product([False, True], [False, True], [False, True]):
            model = models.__dict__["resnet50"](replace_stride_with_dilation=dilation)
            model = self._make_sliced_model(model, stop_layer="layer4")
            model.eval()
            x = torch.rand(1, 3, 224, 224)
            out = model(x)
            # 7x7 is the expected layer4 resolution when nothing is dilated.
            scale = 2 ** sum(dilation)
            self.assertEqual(out.shape, (1, 2048, 7 * scale, 7 * scale))

    def test_mobilenetv2_residual_setting(self):
        """mobilenet_v2 must accept a custom inverted_residual_setting."""
        model = models.__dict__["mobilenet_v2"](inverted_residual_setting=[[1, 16, 1, 1], [6, 24, 2, 2]])
        model.eval()
        x = torch.rand(1, 3, 224, 224)
        out = model(x)
        self.assertEqual(out.shape[-1], 1000)

    def test_fasterrcnn_double(self):
        """fasterrcnn_resnet50_fpn must run end to end in float64."""
        model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
        model.double()
        model.eval()
        input_shape = (3, 300, 300)
        x = torch.rand(input_shape, dtype=torch.float64)
        model_input = [x]
        out = model(model_input)
        self.assertIs(model_input[0], x)
        self._check_detection_output(out)

for model_name in get_available_classification_models():
    # A for-loop body is not its own scope, so the current name is captured as
    # a default argument; a plain closure would only ever see the last model.
    def do_test(self, model_name=model_name):
        # inception_v3 is fed 299x299 images; every other classifier gets 224x224.
        side = 299 if model_name == 'inception_v3' else 224
        self._test_classification_model(model_name, (1, 3, side, side))

    setattr(Tester, "test_{}".format(model_name), do_test)


for model_name in get_available_segmentation_models():
    # Capture the loop variable as a default argument so each generated test
    # runs against its own model rather than the last one in the list.
    def do_test(self, model_name=model_name):
        return self._test_segmentation_model(model_name)

    setattr(Tester, "test_{}".format(model_name), do_test)


for model_name in get_available_detection_models():
    # Capture the loop variable as a default argument so each generated test
    # runs against its own model rather than the last one in the list.
    def do_test(self, model_name=model_name):
        return self._test_detection_model(model_name)

    setattr(Tester, "test_{}".format(model_name), do_test)

for model_name in get_available_video_models():
    # Capture the loop variable as a default argument so each generated test
    # runs against its own model rather than the last one in the list.
    def do_test(self, model_name=model_name):
        return self._test_video_model(model_name)

    setattr(Tester, "test_{}".format(model_name), do_test)


if __name__ == "__main__":
    # Run the whole suite when this file is executed directly.
    unittest.main()