"examples/model_compress/pruning/model_speedup.py" did not exist on "ed121315f4c5df8ed8c00e28ad12f00418fa3bb1"
test_model_speedup.py 13.2 KB
Newer Older
chicm-ms's avatar
chicm-ms committed
1
2
3
4
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import sys
import numpy as np
import torch
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.vgg import vgg16
from torchvision.models.resnet import resnet18
import unittest
from unittest import TestCase, main

from nni.compression.pytorch import ModelSpeedup, apply_compression_results
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner
from nni.algorithms.compression.pytorch.pruning.weight_masker import WeightMasker
from nni.algorithms.compression.pytorch.pruning.one_shot import _StructuredFilterPruner

torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 2
# the relative distance threshold
RELATIVE_THRESHOLD = 0.01
# Because of floating-point precision, small differences between the output
# tensors of the original model (without speedup) and the speedup model are
# normal. When the output tensor itself is small, such differences may exceed
# the relative threshold, so an absolute threshold is added as well.
# The error should meet either RELATIVE_THRESHOLD or ABSOLUTE_THRESHOLD.
ABSOLUTE_THRESHOLD = 0.0001
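

def outputs_match(ori_sum, speeded_sum):
    # Illustrative helper (a sketch of the acceptance criterion the assertions
    # below spell out inline; it is not called by the original test flow):
    # a speedup result is accepted when either the relative error or the
    # absolute error stays within its threshold.
    diff = abs(ori_sum - speeded_sum)
    return diff / abs(ori_sum) < RELATIVE_THRESHOLD or diff < ABSOLUTE_THRESHOLD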


class BackboneModel1(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 1, 1)

    def forward(self, x):
        return self.conv1(x)


class BackboneModel2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.bn1 = nn.BatchNorm2d(self.conv1.out_channels)
        self.bn2 = nn.BatchNorm2d(self.conv2.out_channels)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class BigModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone1 = BackboneModel1()
        self.backbone2 = BackboneModel2()
        self.fc3 = nn.Sequential(
            nn.Linear(10, 10),
            nn.BatchNorm1d(10),
            nn.ReLU(inplace=True),
            nn.Linear(10, 2)
        )

    def forward(self, x):
        x = self.backbone1(x)
        x = self.backbone2(x)
        x = self.fc3(x)
        return x


class TransposeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 20, 5)
        self.conv2 = nn.ConvTranspose2d(20, 50, 5, groups=2)
        self.bn1 = nn.BatchNorm2d(self.conv1.out_channels)
        self.bn2 = nn.BatchNorm2d(self.conv2.out_channels)
        self.fc1 = nn.Linear(8 * 8 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        # x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        # x = F.max_pool2d(x, 2, 2)
        x = x.view(x.size(0), -1)

        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


dummy_input = torch.randn(2, 1, 28, 28)
SPARSITY = 0.5
MODEL_FILE, MASK_FILE = './11_model.pth', './l1_mask.pth'


def prune_model_l1(model):
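    """Prune 50% of the filters of every Conv2d layer with L1FilterPruner and
    export the pruned weights and masks to MODEL_FILE / MASK_FILE."""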
    config_list = [{
        'sparsity': SPARSITY,
        'op_types': ['Conv2d']
    }]
    pruner = L1FilterPruner(model, config_list)
    pruner.compress()
    pruner.export_model(model_path=MODEL_FILE, mask_path=MASK_FILE)


def generate_random_sparsity(model):
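    """Build a pruner config list that assigns every Conv2d layer in the model
    a random sparsity drawn uniformly from [0.5, 0.99)."""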
    cfg_list = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            sparsity = np.random.uniform(0.5, 0.99)
            cfg_list.append({'op_types': ['Conv2d'], 'op_names': [name],
                             'sparsity': sparsity})
    return cfg_list


def zero_bn_bias(model):
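    """Zero the bias and running mean of every BatchNorm layer; this should make
    masked (zeroed) channels contribute nothing to the output, so the pruned
    model and the speedup model can be compared numerically."""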
    with torch.no_grad():
        for name, module in model.named_modules():
            if isinstance(module, nn.BatchNorm2d) \
                    or isinstance(module, nn.BatchNorm3d) \
                    or isinstance(module, nn.BatchNorm1d):
                shape = module.bias.data.size()
                device = module.bias.device
                module.bias.data = torch.zeros(shape).to(device)
                shape = module.running_mean.data.size()
                module.running_mean = torch.zeros(shape).to(device)


class L1ChannelMasker(WeightMasker):
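    """Masker that prunes *input* channels (dimension 1 of the weight) by their
    summed L1 magnitude, unlike L1FilterPruner, which prunes output filters.
    It is plugged into L1ChannelPruner in channel_prune() below."""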
    def __init__(self, model, pruner):
        self.model = model
        self.pruner = pruner

    def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
        msg = 'module type {} is not supported!'.format(wrapper.type)
        #assert wrapper.type == 'Conv2d', msg
        weight = wrapper.module.weight.data
        bias = None
        if hasattr(wrapper.module, 'bias') and wrapper.module.bias is not None:
            bias = wrapper.module.bias.data

        if wrapper.weight_mask is None:
            mask_weight = torch.ones(weight.size()).type_as(weight).detach()
        else:
            mask_weight = wrapper.weight_mask.clone()
        if bias is not None:
            if wrapper.bias_mask is None:
                mask_bias = torch.ones(bias.size()).type_as(bias).detach()
            else:
                mask_bias = wrapper.bias_mask.clone()
        else:
            mask_bias = None
        base_mask = {'weight_mask': mask_weight, 'bias_mask': mask_bias}

        num_total = weight.size(1)
        num_prune = int(num_total * sparsity)

        if num_total < 2 or num_prune < 1:
            return base_mask
        w_abs = weight.abs()
        if wrapper.type == 'Conv2d':
            w_abs_structured = w_abs.sum((0, 2, 3))
            threshold = torch.topk(
                w_abs_structured, num_prune, largest=False)[0].max()
            mask_weight = torch.gt(w_abs_structured, threshold)[
                None, :, None, None].expand_as(weight).type_as(weight)
            return {'weight_mask': mask_weight.detach()}
        else:
            # Linear
            assert wrapper.type == 'Linear'
            w_abs_structured = w_abs.sum((0))
            threshold = torch.topk(
                w_abs_structured, num_prune, largest=False)[0].max()
            mask_weight = torch.gt(w_abs_structured, threshold)[
                None, :].expand_as(weight).type_as(weight)
            return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias}


class L1ChannelPruner(_StructuredFilterPruner):
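    """Minimal structured pruner; its default 'l1' masker is swapped for
    L1ChannelMasker in channel_prune() below."""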
    def __init__(self, model, config_list, optimizer=None, dependency_aware=False, dummy_input=None):
        super().__init__(model, config_list, pruning_algorithm='l1', optimizer=optimizer,
                         dependency_aware=dependency_aware, dummy_input=dummy_input)

    def validate_config(self, model, config_list):
        pass


def channel_prune(model):
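    """Prune 50% of the input channels of every Conv2d and Linear layer except
    'conv1', which is excluded because its input channels correspond to the
    raw image channels."""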
    config_list = [{
        'sparsity': SPARSITY,
        'op_types': ['Conv2d', 'Linear']
    }, {
        'op_names': ['conv1'],
        'exclude': True
    }]

    pruner = L1ChannelPruner(model, config_list)
    masker = L1ChannelMasker(model, pruner)
    pruner.masker = masker
    pruner.compress()
    pruner.export_model(model_path=MODEL_FILE, mask_path=MASK_FILE)


class SpeedupTestCase(TestCase):
    def test_speedup_vgg16(self):
        prune_model_l1(vgg16())
        model = vgg16()
        model.train()
        ms = ModelSpeedup(model, torch.randn(2, 3, 32, 32), MASK_FILE)
        ms.speedup_model()

        orig_model = vgg16()
        assert model.training
        assert model.features[2].out_channels == int(
            orig_model.features[2].out_channels * SPARSITY)
        assert model.classifier[0].in_features == int(
            orig_model.classifier[0].in_features * SPARSITY)

    def test_speedup_bigmodel(self):
        prune_model_l1(BigModel())
        model = BigModel()
        apply_compression_results(model, MASK_FILE, 'cpu')
        model.eval()
        mask_out = model(dummy_input)

        model.train()
        ms = ModelSpeedup(model, dummy_input, MASK_FILE)
        ms.speedup_model()
        assert model.training

        model.eval()
        speedup_out = model(dummy_input)
        if not torch.allclose(mask_out, speedup_out, atol=1e-07):
            print('input:', dummy_input.size(),
                  torch.abs(dummy_input).sum((2, 3)))
            print('mask_out:', mask_out)
            print('speedup_out:', speedup_out)
            raise RuntimeError('model speedup inference result is incorrect!')

        orig_model = BigModel()

        assert model.backbone2.conv1.out_channels == int(
            orig_model.backbone2.conv1.out_channels * SPARSITY)
        assert model.backbone2.conv2.in_channels == int(
            orig_model.backbone2.conv2.in_channels * SPARSITY)
        assert model.backbone2.conv2.out_channels == int(
            orig_model.backbone2.conv2.out_channels * SPARSITY)
        assert model.backbone2.fc1.in_features == int(
            orig_model.backbone2.fc1.in_features * SPARSITY)

    def test_convtranspose_model(self):
        ori_model = TransposeModel()
        dummy_input = torch.rand(1, 3, 8, 8)
        config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
        pruner = L1FilterPruner(ori_model, config_list)
        pruner.compress()
        ori_model(dummy_input)
        pruner.export_model(MODEL_FILE, MASK_FILE)
        pruner._unwrap_model()
        new_model = TransposeModel()
        state_dict = torch.load(MODEL_FILE)
        new_model.load_state_dict(state_dict)
        ms = ModelSpeedup(new_model, dummy_input, MASK_FILE)
        ms.speedup_model()
        zero_bn_bias(ori_model)
        zero_bn_bias(new_model)
        ori_out = ori_model(dummy_input)
        new_out = new_model(dummy_input)
        ori_sum = torch.sum(ori_out)
        speeded_sum = torch.sum(new_out)
        print('Transpose Speedup Test: ori_sum={} speedup_sum={}'.format(ori_sum, speeded_sum))
        assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
                (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)

    # FIXME: This test case might fail randomly, no idea why
    # Example: https://msrasrg.visualstudio.com/NNIOpenSource/_build/results?buildId=16282
    def test_speedup_integration(self):
        for model_name in ['resnet18', 'squeezenet1_1',
                           'mobilenet_v2', 'densenet121',
                           # 'inception_v3' is skipped: it is too large and may fail the pipeline
                           'densenet169', 'resnet50']:
            kwargs = {
                'pretrained': True
            }
            if model_name == 'resnet50':
                # testing multiple groups
                kwargs = {
                    'pretrained': False,
                    'groups': 4
                }

            Model = getattr(models, model_name)
            net = Model(**kwargs).to(device)
            speedup_model = Model(**kwargs).to(device)
            net.eval()  # this line is necessary
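            # presumably required so that BatchNorm/Dropout run in inference mode
            # and the outputs compared at the end of the loop are deterministic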
            speedup_model.eval()
            # randomly generate the pruning config for the pruner
            cfgs = generate_random_sparsity(net)
            pruner = L1FilterPruner(net, cfgs)
            pruner.compress()
            pruner.export_model(MODEL_FILE, MASK_FILE)
            pruner._unwrap_model()
            state_dict = torch.load(MODEL_FILE)
            speedup_model.load_state_dict(state_dict)
            zero_bn_bias(net)
            zero_bn_bias(speedup_model)

            data = torch.ones(BATCH_SIZE, 3, 128, 128).to(device)
            ms = ModelSpeedup(speedup_model, data, MASK_FILE)
            ms.speedup_model()

            speedup_model.eval()

            ori_out = net(data)
            speeded_out = speedup_model(data)
            ori_sum = torch.sum(ori_out).item()
            speeded_sum = torch.sum(speeded_out).item()
            print('Sum of the output of %s (before speedup):' %
                  model_name, ori_sum)
            print('Sum of the output of %s (after speedup):' %
                  model_name, speeded_sum)
            assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
                   (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)

    def test_channel_prune(self):
        orig_net = resnet18(num_classes=10).to(device)
        channel_prune(orig_net)
        state_dict = torch.load(MODEL_FILE)

        orig_net = resnet18(num_classes=10).to(device)
        orig_net.load_state_dict(state_dict)
        apply_compression_results(orig_net, MASK_FILE)
        orig_net.eval()

        net = resnet18(num_classes=10).to(device)

        net.load_state_dict(state_dict)
        net.eval()

        data = torch.randn(BATCH_SIZE, 3, 128, 128).to(device)
        ms = ModelSpeedup(net, data, MASK_FILE)
        ms.speedup_model()
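        # ms.bound_model is the compacted model itself; calling it here runs a
        # sanity forward pass after speedup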
        ms.bound_model(data)

        net.eval()

        ori_sum = orig_net(data).abs().sum().item()
        speeded_sum = net(data).abs().sum().item()

        print(ori_sum, speeded_sum)
        assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
            (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)

    def tearDown(self):
        os.remove(MODEL_FILE)
        os.remove(MASK_FILE)


if __name__ == '__main__':
    main()