from common_utils import TestCase, map_nested_tensor_object
from collections import OrderedDict
from itertools import product
import torch
import numpy as np
from torchvision import models
import unittest
import traceback
import random
import inspect


STANDARD_NUM_CLASSES = 50
STANDARD_INPUT_SHAPE = (1, 3, 224, 224)
STANDARD_SEED = 1729


def set_rng_seed(seed=STANDARD_SEED):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)


def subsample_tensor(tensor, num_samples=20):
    num_elems = tensor.numel()
    if num_elems <= num_samples:
        return tensor

    flat_tensor = tensor.flatten()
    ith_index = num_elems // num_samples
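    # e.g. with 100 elements and the default 20 samples, ith_index is 5 and we
    # take indices 4, 9, ..., 99 -- exactly 20 evenly spaced entries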
    return flat_tensor[ith_index - 1::ith_index]


def get_available_segmentation_models():
    # TODO add a registration mechanism to torchvision.models
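    # NOTE this discovery heuristic (shared with the helpers below) keeps
    # lowercase callables not prefixed with an underscore, i.e. factory
    # functions such as "fcn_resnet101", while skipping classes like "FCN";
    # the exact names returned depend on the installed torchvision version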
    return [k for k, v in models.segmentation.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]


def get_available_detection_models():
    # TODO add a registration mechanism to torchvision.models
    return [k for k, v in models.detection.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]


def get_available_video_models():
    # TODO add a registration mechanism to torchvision.models
    return [k for k, v in models.video.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]


# models that are published on Torch Hub, plus r3d_18. we tried testing all
# models, but the test was too slow. detection models are not included because
# they are not yet supported in JIT.
script_test_models = [
    "deeplabv3_resnet101",
    "fcn_resnet101",
    "r3d_18",
]


class ModelTester(TestCase):

    # create a random tensor with the given shape using a synced RNG state;
    # cached, because these tests already take pretty long (instantiating models and all)
    TEST_INPUTS = {}
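    # (a class attribute: the cache is shared by every ModelTester subclass,
    # since _get_test_input mutates this inherited dict in place)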

    def _get_test_input(self, shape=STANDARD_INPUT_SHAPE):
        # NOTE not thread-safe, but a race would at worst recompute the same
        # seeded value; returning consistent results is the point of this helper
        if shape not in self.TEST_INPUTS:
            set_rng_seed(STANDARD_SEED)
            self.TEST_INPUTS[shape] = torch.rand(shape)
        return self.TEST_INPUTS[shape]

    # create a randomly-weighted model with a synced RNG state
    def _get_test_model(self, model_fn, **kwargs):
        set_rng_seed(STANDARD_SEED)
        model = model_fn(**kwargs)
        model.eval()
        return model

    def check_script(self, model, name):
        if name not in script_test_models:
            return
        scriptable = True
        msg = ""
        try:
            torch.jit.script(model)
        except Exception as e:
            tb = traceback.format_exc()
            scriptable = False
            msg = str(e) + str(tb)
        self.assertTrue(scriptable, msg)

    def _check_scriptable(self, model, expected):
        if expected is None:  # we don't check scriptability for all models
            return

        actual = True
        msg = ''
        try:
            torch.jit.script(model)
        except Exception as e:
            tb = traceback.format_exc()
            actual = False
            msg = str(e) + str(tb)
        self.assertEqual(actual, expected, msg)


class ClassificationCoverageTester(TestCase):

    # Find all models exposed by torchvision.models factory methods (with assumptions)
    def get_available_classification_models(self):
        # TODO add a registration mechanism to torchvision.models
        return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]

    # Recursively gather test methods from all classification testers
    def get_test_methods_for_class(self, klass):
        all_methods = inspect.getmembers(klass, predicate=inspect.isfunction)
        test_methods = {method[0] for method in all_methods if method[0].startswith('test_')}
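        # e.g. for ResnetTester this yields names like "test_resnet18"; the
        # recursion below folds in the tests of every subclass as well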
        for child in klass.__subclasses__():
            test_methods = test_methods.union(self.get_test_methods_for_class(child))
        return test_methods

    # Verify that all models exposed by torchvision.models factory methods
    #    have corresponding test methods
    # NOTE This does not include some of the extra tests (such as Resnet
    #    dilation) and says nothing about the correctness of the test, nor
    #    of the model. It just enforces a naming scheme on the tests, and
    #    verifies that all models have a corresponding test name.
    def test_classification_model_coverage(self):
        model_names = self.get_available_classification_models()
        test_names = self.get_test_methods_for_class(ClassificationModelTester)

        for model_name in model_names:
            test_name = 'test_' + model_name
            self.assertIn(test_name, test_names)


class ClassificationModelTester(ModelTester):
    def _infer_for_test_with(self, model, test_input):
        return model(test_input)

    def _check_classification_output_shape(self, test_output, num_classes):
        self.assertEqual(test_output.shape, (1, num_classes))

    # NOTE Depends on presence of test data fixture. See common_utils.py for
    #    details on creating fixtures.
    def _check_model_correctness(self, model, test_input, num_classes=STANDARD_NUM_CLASSES):
        test_output = self._infer_for_test_with(model, test_input)
        self._check_classification_output_shape(test_output, num_classes)
        self.assertExpected(test_output, rtol=1e-5, atol=1e-5)
        return test_output

    # NOTE override this in a child class
    def _get_input_shape(self):
        return STANDARD_INPUT_SHAPE

    def _test_classification_model(self, model_callable, num_classes=STANDARD_NUM_CLASSES, **kwargs):
        model = self._get_test_model(model_callable, num_classes=num_classes, **kwargs)
        self._check_scriptable(model, True)  # currently, all expected to be scriptable
        test_input = self._get_test_input(shape=self._get_input_shape())
        test_output = self._check_model_correctness(model, test_input)
        return model, test_input, test_output


class AlexnetTester(ClassificationModelTester):
    def test_alexnet(self):
        self._test_classification_model(models.alexnet)


# TODO add test for aux_logits arg to factory method
# TODO add test for transform_input arg to factory method
class InceptionV3Tester(ClassificationModelTester):
    def _get_input_shape(self):
        return (1, 3, 299, 299)

    def test_inception_v3(self):
        self._test_classification_model(models.inception_v3)


class SqueezenetTester(ClassificationModelTester):
    def test_squeezenet1_0(self):
        self._test_classification_model(models.squeezenet1_0)

    def test_squeezenet1_1(self):
        self._test_classification_model(models.squeezenet1_1)


# TODO add test for width_mult arg to factory method
class MobilenetTester(ClassificationModelTester):
    def test_mobilenet_v2(self):
        self._test_classification_model(models.mobilenet_v2)

    def test_mobilenetv2_residual_setting(self):
        self._test_classification_model(models.mobilenet_v2, inverted_residual_setting=[[1, 16, 1, 1], [6, 24, 2, 2]])


# TODO add test for aux_logits arg to factory method
# TODO add test for transform_input arg to factory method
class GooglenetTester(ClassificationModelTester):
    def test_googlenet(self):
        self._test_classification_model(models.googlenet)


class VGGNetTester(ClassificationModelTester):
    def test_vgg11(self):
        self._test_classification_model(models.vgg11)

    def test_vgg11_bn(self):
        self._test_classification_model(models.vgg11_bn)

    def test_vgg13(self):
        self._test_classification_model(models.vgg13)

    def test_vgg13_bn(self):
        self._test_classification_model(models.vgg13_bn)

    def test_vgg16(self):
        self._test_classification_model(models.vgg16)

    def test_vgg16_bn(self):
        self._test_classification_model(models.vgg16_bn)

    def test_vgg19(self):
        self._test_classification_model(models.vgg19)

    def test_vgg19_bn(self):
        self._test_classification_model(models.vgg19_bn)


# TODO add test for dropout arg to factory method
class MNASNetTester(ClassificationModelTester):
    def test_mnasnet0_5(self):
        self._test_classification_model(models.mnasnet0_5)

    def test_mnasnet0_75(self):
        self._test_classification_model(models.mnasnet0_75)

    def test_mnasnet1_0(self):
        self._test_classification_model(models.mnasnet1_0)

    def test_mnasnet1_3(self):
        self._test_classification_model(models.mnasnet1_3)


# TODO add test for bn_size arg to factory method
# TODO add test for drop_rate arg to factory method
class DensenetTester(ClassificationModelTester):
    def _test_densenet_plus_mem_eff(self, model_callable):
        # run the standard correctness test against the fixture first, keeping
        # the model and input around for the second half
        model, test_input, _ = self._test_classification_model(model_callable)

        # now check that the memory-efficient (but slower, checkpointing-based)
        # DenseNet implementation behaves like the standard one
        me_model = self._get_test_model(model_callable, num_classes=STANDARD_NUM_CLASSES, memory_efficient=True)
        me_model.load_state_dict(model.state_dict())  # transfer the weights over
        me_output = self._infer_for_test_with(me_model, test_input)
        # NOTE testing against the same fixture as the non-mem-efficient version
        self.assertExpected(me_output, rtol=1e-5, atol=1e-5)

    def test_densenet121(self):
        self._test_densenet_plus_mem_eff(models.densenet121)

    def test_densenet161(self):
        self._test_densenet_plus_mem_eff(models.densenet161)

    def test_densenet169(self):
        self._test_densenet_plus_mem_eff(models.densenet169)

    def test_densenet201(self):
        self._test_densenet_plus_mem_eff(models.densenet201)


class ShufflenetTester(ClassificationModelTester):
    def test_shufflenet_v2_x0_5(self):
        self._test_classification_model(models.shufflenet_v2_x0_5)

    def test_shufflenet_v2_x1_0(self):
        self._test_classification_model(models.shufflenet_v2_x1_0)

    def test_shufflenet_v2_x1_5(self):
        self._test_classification_model(models.shufflenet_v2_x1_5)

    def test_shufflenet_v2_x2_0(self):
        self._test_classification_model(models.shufflenet_v2_x2_0)


# TODO add test for zero_init_residual arg to factory method
# TODO add test for norm_layer arg to factory method
class ResnetTester(ClassificationModelTester):
    def _get_scriptability_value(self):
        return True

    def test_resnet18(self):
        self._test_classification_model(models.resnet18)

    def test_resnet34(self):
        self._test_classification_model(models.resnet34)

    def test_resnet50(self):
        self._test_classification_model(models.resnet50)

    def test_resnet101(self):
        self._test_classification_model(models.resnet101)

    def test_resnet152(self):
        self._test_classification_model(models.resnet152)

    def test_resnext50_32x4d(self):
        self._test_classification_model(models.resnext50_32x4d)

    def test_resnext101_32x8d(self):
        self._test_classification_model(models.resnext101_32x8d)

    def test_wide_resnet50_2(self):
        self._test_classification_model(models.wide_resnet50_2)

    def test_wide_resnet101_2(self):
        self._test_classification_model(models.wide_resnet101_2)

    def _make_sliced_model(self, model, stop_layer):
        layers = OrderedDict()
        for name, layer in model.named_children():
            layers[name] = layer
            if name == stop_layer:
                break
        new_model = torch.nn.Sequential(layers)
        return new_model

    def test_resnet_dilation(self):
        # TODO improve tests to also check that each layer has the right dimensionality
        for i in product([False, True], [False, True], [False, True]):
            model = models.__dict__["resnet50"](replace_stride_with_dilation=i)
            model = self._make_sliced_model(model, stop_layer="layer4")
            model.eval()
            x = self._get_test_input(STANDARD_INPUT_SHAPE)
            out = model(x)
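            # each True flag converts one stride-2 stage into dilation, halving
            # the overall 32x downsampling, so the 7x7 output doubles per flag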
            f = 2 ** sum(i)
            self.assertEqual(out.shape, (1, 2048, 7 * f, 7 * f))


class SegmentationModelTester(ModelTester):
    def _test_segmentation_model(self, name):
        # passing a num_classes other than the default makes the test
        # stricter
        model = models.segmentation.__dict__[name](num_classes=50, pretrained_backbone=False)
        self.check_script(model, name)
        model.eval()
        input_shape = (1, 3, 300, 300)
        x = torch.rand(input_shape)
        out = model(x)
        self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300))


class DetectionModelTester(ModelTester):
    def _test_detection_model(self, name):
        set_rng_seed(0)
        model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False)
        self.check_script(model, name)
        model.eval()
        input_shape = (3, 300, 300)
        x = torch.rand(input_shape)
        model_input = [x]
        out = model(model_input)
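        # the model must not mutate the caller's input list in place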
        self.assertIs(model_input[0], x)
        self.assertEqual(len(out), 1)

        def compute_mean_std(tensor):
            # can't compute mean of integral tensor
            tensor = tensor.to(torch.double)
            mean = torch.mean(tensor)
            std = torch.std(tensor)
            return {"mean": mean, "std": std}

        # maskrcnn_resnet50_fpn is numerically unstable across platforms, so for
        # now we compare the results' mean and std instead
        if name == "maskrcnn_resnet50_fpn":
            test_value = map_nested_tensor_object(out, tensor_map_fn=compute_mean_std)
            # mean values are small, use large rtol
            self.assertExpected(test_value, rtol=.01, atol=.01)
        else:
            self.assertExpected(map_nested_tensor_object(out, tensor_map_fn=subsample_tensor))

        self.assertTrue("boxes" in out[0])
        self.assertTrue("scores" in out[0])
        self.assertTrue("labels" in out[0])

    def test_fasterrcnn_double(self):
        model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
        model.double()
        model.eval()
        input_shape = (3, 300, 300)
        x = torch.rand(input_shape, dtype=torch.float64)
        model_input = [x]
        out = model(model_input)
        self.assertIs(model_input[0], x)
        self.assertEqual(len(out), 1)
        self.assertTrue("boxes" in out[0])
        self.assertTrue("scores" in out[0])
        self.assertTrue("labels" in out[0])


class VideoModelTester(ModelTester):
    def _test_video_model(self, name):
        # the default input shape is bs x num_channels x clip_len x h x w
        input_shape = (1, 3, 4, 112, 112)
        # the video models cover both the BasicBlock and Bottleneck variants
        model = models.video.__dict__[name](num_classes=50)
        self.check_script(model, name)
        x = torch.rand(input_shape)
        out = model(x)
        self.assertEqual(out.shape[-1], 50)


for model_name in get_available_segmentation_models():
    # for-loop bodies don't define scopes, so we have to save the variables
    # we want to close over in some way
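    # e.g. the naive version would be broken -- the closure is evaluated
    # lazily, so every generated test would run against the *last* model name:
    #
    #     def do_test(self):
    #         self._test_segmentation_model(model_name)  # late binding!
    #
    # binding model_name as a default argument freezes its current value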
    def do_test(self, model_name=model_name):
        self._test_segmentation_model(model_name)

    setattr(SegmentationModelTester, "test_" + model_name, do_test)


for model_name in get_available_detection_models():
    # for-loop bodies don't define scopes, so we have to save the variables
    # we want to close over in some way
    def do_test(self, model_name=model_name):
        self._test_detection_model(model_name)

    setattr(DetectionModelTester, "test_" + model_name, do_test)


for model_name in get_available_video_models():

    def do_test(self, model_name=model_name):
        self._test_video_model(model_name)

    setattr(VideoModelTester, "test_" + model_name, do_test)

if __name__ == '__main__':
    unittest.main()