# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import numpy as np
import torch
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.vgg import vgg16
from torchvision.models.resnet import resnet18
from unittest import TestCase, main

from nni.compression.torch import L1FilterPruner, apply_compression_results, ModelSpeedup

torch.manual_seed(0)
17
18
19
20
21
22
23
24
25
26
27
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 2
# the relative distance
RELATIVE_THRESHOLD = 0.01
# Because of the precision of floating-point numbers, some errors
# between the original output tensors(without speedup) and the output
# tensors of the speedup model are normal. When the output tensor itself
# is small, such errors may exceed the relative threshold, so we also add
# an absolute threshold to determine whether the final result is correct.
# The error should meet the RELATIVE_THREHOLD or the ABSOLUTE_THRESHOLD.
ABSOLUTE_THRESHOLD = 0.0001
class BackboneModel1(nn.Module):
    """Minimal backbone: a single 1x1 convolution with one in/out channel."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 1, 1)

    def forward(self, x):
        out = self.conv1(x)
        return out

class BackboneModel2(torch.nn.Module):
    """LeNet-style backbone: two conv+BN+ReLU+pool stages, then two FC layers.

    Takes 1x28x28 inputs and produces 10-dimensional logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.bn1 = nn.BatchNorm2d(self.conv1.out_channels)
        self.bn2 = nn.BatchNorm2d(self.conv2.out_channels)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        # Stage 1: 28x28 -> conv(5x5) -> 24x24 -> pool -> 12x12
        h = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), 2, 2)
        # Stage 2: 12x12 -> conv(5x5) -> 8x8 -> pool -> 4x4
        h = F.max_pool2d(F.relu(self.bn2(self.conv2(h))), 2, 2)
        # Flatten per sample (4*4*50 features) and classify.
        h = h.flatten(1)
        h = self.fc2(F.relu(self.fc1(h)))
        return h

class BigModel(torch.nn.Module):
    """Composite test model: chains both backbones and a small MLP head."""

    def __init__(self):
        super().__init__()
        self.backbone1 = BackboneModel1()
        self.backbone2 = BackboneModel2()
        self.fc3 = nn.Sequential(
            nn.Linear(10, 10),
            nn.BatchNorm1d(10),
            nn.ReLU(inplace=True),
            nn.Linear(10, 2)
        )

    def forward(self, x):
        # Backbone 1 -> backbone 2 -> classification head.
        return self.fc3(self.backbone2(self.backbone1(x)))

# Shared fixture input for the speedup tests (batch of 2 MNIST-sized images).
dummy_input = torch.randn(2, 1, 28, 28)
# Fraction of filters the L1 pruner removes from each configured layer.
SPARSITY = 0.5
# Paths the pruner exports to and the tests read back (removed in tearDown).
# Fixed the '11_model.pth' typo so the checkpoint name matches the
# 'l1_mask.pth' mask file; both are only ever referenced via these constants.
MODEL_FILE, MASK_FILE = './l1_model.pth', './l1_mask.pth'

def prune_model_l1(model):
    """Prune every Conv2d layer of *model* with the L1 filter pruner at the
    module-level SPARSITY, then export the weights/masks to
    MODEL_FILE / MASK_FILE for the speedup tests to consume."""
    pruner = L1FilterPruner(model, [{
        'sparsity': SPARSITY,
        'op_types': ['Conv2d'],
    }])
    pruner.compress()
    pruner.export_model(model_path=MODEL_FILE, mask_path=MASK_FILE)

def generate_random_sparsity(model):
    """Build a pruner config list that assigns every Conv2d layer in *model*
    a sparsity drawn uniformly from [0.5, 0.99)."""
    return [
        {'op_types': ['Conv2d'],
         'op_names': [name],
         'sparsity': np.random.uniform(0.5, 0.99)}
        for name, module in model.named_modules()
        if isinstance(module, nn.Conv2d)
    ]

def zero_bn_bias(model):
    """Zero the bias and running mean of every 1d/2d/3d batch-norm layer in
    *model* (in place, without tracking gradients).

    The running variance and weight are deliberately left untouched.
    """
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    with torch.no_grad():
        for _, module in model.named_modules():
            if isinstance(module, bn_types):
                dev = module.bias.device
                module.bias.data = torch.zeros(module.bias.data.size()).to(dev)
                module.running_mean = torch.zeros(module.running_mean.data.size()).to(dev)

class SpeedupTestCase(TestCase):
    """End-to-end tests for ModelSpeedup: prune a model, apply the masks,
    run speedup, then check the compacted model's structure and outputs."""

    def test_speedup_vgg16(self):
        # Prune a throwaway vgg16 just to produce MASK_FILE, then run
        # speedup on a fresh (unpruned) instance from that mask.
        prune_model_l1(vgg16())
        model = vgg16()
        model.train()
        ms = ModelSpeedup(model, torch.randn(2, 3, 32, 32), MASK_FILE)
        ms.speedup_model()

        orig_model = vgg16()
        # Speedup must preserve the model's train/eval mode.
        assert model.training
        # Pruned conv out-channels and the first classifier's in-features
        # shrink by the pruned fraction.
        assert model.features[2].out_channels == int(orig_model.features[2].out_channels * SPARSITY)
        assert model.classifier[0].in_features == int(orig_model.classifier[0].in_features * SPARSITY)

    def test_speedup_bigmodel(self):
        prune_model_l1(BigModel())
        model = BigModel()
        # Reference output: masks applied (channels zeroed) but layers not
        # physically shrunk yet.
        apply_compression_results(model, MASK_FILE, 'cpu')
        model.eval()
        mask_out = model(dummy_input)

        model.train()
        ms = ModelSpeedup(model, dummy_input, MASK_FILE)
        ms.speedup_model()
        # Speedup must preserve the model's train/eval mode.
        assert model.training

        model.eval()
        speedup_out = model(dummy_input)
        # The physically shrunk model must reproduce the masked model's
        # output up to floating-point error.
        if not torch.allclose(mask_out, speedup_out, atol=1e-07):
            print('input:', dummy_input.size(), torch.abs(dummy_input).sum((2,3)))
            print('mask_out:', mask_out)
            print('speedup_out:', speedup_out)
            raise RuntimeError('model speedup inference result is incorrect!')

        orig_model = BigModel()

        # Structural checks: the pruned layers, and the layers consuming
        # their outputs, shrink by the pruned fraction.
        assert model.backbone2.conv1.out_channels == int(orig_model.backbone2.conv1.out_channels * SPARSITY)
        assert model.backbone2.conv2.in_channels == int(orig_model.backbone2.conv2.in_channels * SPARSITY)
        assert model.backbone2.conv2.out_channels == int(orig_model.backbone2.conv2.out_channels * SPARSITY)
        assert model.backbone2.fc1.in_features == int(orig_model.backbone2.fc1.in_features * SPARSITY)

    def test_speedup_integration(self):
        # NOTE: downloads pretrained torchvision weights on first run.
        for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2', 'densenet121', 'densenet169', 'inception_v3']:
            Model = getattr(models, model_name)
            net = Model(pretrained=True, progress=False).to(device)
            speedup_model = Model().to(device)
            net.eval() # this line is necessary
            speedup_model.eval()
            # random generate the prune config for the pruner
            cfgs = generate_random_sparsity(net)
            pruner = L1FilterPruner(net, cfgs)
            pruner.compress()
            pruner.export_model(MODEL_FILE, MASK_FILE)
            pruner._unwrap_model()
            # Load the pruned weights into the fresh model, then zero the
            # batch-norm biases / running means in both networks so that
            # masked channels contribute nothing and outputs are comparable.
            state_dict = torch.load(MODEL_FILE)
            speedup_model.load_state_dict(state_dict)
            zero_bn_bias(net)
            zero_bn_bias(speedup_model)

            data = torch.ones(BATCH_SIZE, 3, 224, 224).to(device)
            ms = ModelSpeedup(speedup_model, data, MASK_FILE)
            ms.speedup_model()
            ori_out = net(data)
            speeded_out = speedup_model(data)
            # Compare scalar output sums under either the relative or the
            # absolute threshold defined at module level.
            ori_sum = torch.sum(ori_out).item()
            speeded_sum = torch.sum(speeded_out).item()
            print('Sum of the output of %s (before speedup):'%model_name, ori_sum)
            print('Sum of the output of %s (after speedup):'%model_name, speeded_sum)
            assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
                   (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)

    def tearDown(self):
        # Every test exports these files; clean them up after each test.
        os.remove(MODEL_FILE)
        os.remove(MASK_FILE)

if __name__ == '__main__':
    main()  # run the unittest test runner over this module