# test_models.py
import torch
from torchvision import models
import unittest


def get_available_models(module=None):
    """Return the names of the public model constructors exposed by a module.

    A name qualifies when its value is callable, its first character is not
    uppercase (which skips class names such as ``ResNet``), and it does not
    start with an underscore (which skips private helpers).

    Args:
        module: namespace to scan; defaults to ``torchvision.models``.

    Returns:
        A list of attribute names, in the module namespace's insertion order.
    """
    # TODO add a registration mechanism to torchvision.models
    if module is None:
        module = models
    return [
        name
        for name, obj in module.__dict__.items()
        if callable(obj) and name[0].lower() == name[0] and not name.startswith("_")
    ]


class Tester(unittest.TestCase):
    """Smoke tests that build torchvision models and run one forward pass.

    Per-model ``test_<name>`` methods are attached to this class after its
    definition; each one delegates to :meth:`_test_model`.
    """

    def _test_model(self, name, input_shape, num_classes=50):
        """Instantiate ``models.<name>`` and check the width of its output.

        Args:
            name: attribute name of a model constructor in ``torchvision.models``.
            input_shape: shape of the random input tensor fed to the model.
            num_classes: class count passed to the constructor. A value other
                than 1000 makes the test more enforcing in nature, because a
                constructor that ignores the argument would still emit 1000.
        """
        model = models.__dict__[name](num_classes=num_classes)
        # Put the model in inference mode before the forward pass.
        model.eval()
        x = torch.rand(input_shape)
        # Only the output shape is inspected, so skip autograd bookkeeping to
        # save memory on the larger models.
        with torch.no_grad():
            out = model(x)
        self.assertEqual(out.shape[-1], num_classes)


for model_name in get_available_models():
    # A for-loop body does not open a new scope, so a plain closure would see
    # only the final model_name. Binding it as a default argument freezes the
    # value for each generated test method.
    def do_test(self, model_name=model_name):
        # inception_v3 is fed 299x299 images; every other model gets 224x224.
        side = 299 if model_name == 'inception_v3' else 224
        self._test_model(model_name, (1, 3, side, side))

    setattr(Tester, "test_" + model_name, do_test)


if __name__ == "__main__":
    # When executed as a script, discover and run the TestCase classes in this
    # module (Tester, including its generated test_* methods).
    unittest.main()