import torchvision
from common_utils import TestCase, map_nested_tensor_object
from collections import OrderedDict
from itertools import product
import torch
import numpy as np
from torchvision import models
import unittest
import traceback
import random


def set_rng_seed(seed):
    """Seed every RNG source used by the tests (torch, random, numpy)."""
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(seed)


def get_available_quantizable_models():
    """Return the names of public, callable model constructors exposed by
    ``torchvision.models.quantization``.

    A name qualifies when it is callable, does not start with an underscore,
    and its first character is not upper-case (i.e. it is a factory function,
    not a class).
    """
    # TODO add a registration mechanism to torchvision.models
    names = []
    for name, obj in models.quantization.__dict__.items():
        if not callable(obj):
            continue
        if name[0] == "_" or name[0].lower() != name[0]:
            continue
        names.append(name)
    return names


# Model names excluded from the torch.jit.script check below (currently none).
scriptable_quantizable_models_blacklist = []


@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines and
                     'qnnpack' in torch.backends.quantized.supported_engines,
                     "This Pytorch Build has not been built with fbgemm and qnnpack")
class ModelTester(TestCase):
    """Smoke tests for torchvision quantized classification models.

    Test methods (``test_<model_name>``) are attached dynamically at module
    import time for every model returned by
    ``get_available_quantizable_models``.
    """

    def check_quantized_model(self, model, input_shape):
        """Run one forward pass with random input; fails if it raises."""
        x = torch.rand(input_shape)
        model(x)

    def check_script(self, model, name):
        """Assert that ``model`` is torch.jit scriptable.

        Models listed in ``scriptable_quantizable_models_blacklist`` are
        skipped. On failure the assertion message includes the exception
        text and the full traceback for easier debugging.
        """
        if name in scriptable_quantizable_models_blacklist:
            return
        scriptable = True
        msg = ""
        try:
            torch.jit.script(model)
        except Exception as e:
            tb = traceback.format_exc()
            scriptable = False
            msg = str(e) + str(tb)
        self.assertTrue(scriptable, msg)

    def _test_classification_model(self, name, input_shape):
        """Exercise the quantization paths for one classification model.

        Checks that (1) the pre-quantized model (``quantize=True``) can run
        a forward pass, (2) the float model can be quantized via both the
        post-training (eval) and QAT (train) workflows, and (3) the final
        converted model is scriptable.
        """
        # First check if quantize=True provides models that can run with input data
        model = torchvision.models.quantization.__dict__[name](pretrained=False, quantize=True)
        self.check_quantized_model(model, input_shape)

        for eval_mode in [True, False]:
            model = torchvision.models.quantization.__dict__[name](pretrained=False, quantize=False)
            if eval_mode:
                # Post-training static quantization path.
                model.eval()
                model.qconfig = torch.quantization.default_qconfig
            else:
                # Quantization-aware training path.
                model.train()
                model.qconfig = torch.quantization.default_qat_qconfig

            model.fuse_model()
            if eval_mode:
                torch.quantization.prepare(model, inplace=True)
            else:
                torch.quantization.prepare_qat(model, inplace=True)
                # Switch to eval before convert, as required for QAT models.
                model.eval()

            torch.quantization.convert(model, inplace=True)

        self.check_script(model, name)


for model_name in get_available_quantizable_models():
    # Loop bodies share one scope: bind model_name via a default argument so
    # each generated test closes over its own value, not the last one.
    def do_test(self, model_name=model_name):
        shape = (1, 3, 299, 299) if model_name == 'inception_v3' else (1, 3, 224, 224)
        self._test_classification_model(model_name, shape)

    setattr(ModelTester, "test_" + model_name, do_test)


# Run the dynamically generated tests when invoked as a script.
if __name__ == '__main__':
    unittest.main()