# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from unittest import TestCase, main
import numpy as np
import torch
import torch.nn.functional as F
import schema
import nni.algorithms.compression.pytorch.pruning as torch_pruner
import nni.algorithms.compression.pytorch.quantization as torch_quantizer
import math


class TorchModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
        self.bn1 = torch.nn.BatchNorm2d(5)
        self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
        self.bn2 = torch.nn.BatchNorm2d(10)
        self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
        self.fc2 = torch.nn.Linear(100, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 10)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


class CompressorTestCase(TestCase):
    def test_torch_quantizer_modules_detection(self):
        # test if modules can be detected
        model = TorchModel()
        config_list = [{
            'quant_types': ['weight'],
            'quant_bits': 8,
            'op_types': ['Conv2d', 'Linear']
        }, {
            'quant_types': ['output'],
            'quant_bits': 8,
            'quant_start_step': 0,
            'op_types': ['ReLU']
        }]

        model.relu = torch.nn.ReLU()
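        # the ReLU attached above gives the second config entry (op_types: ['ReLU']) a module to match during detection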
        quantizer = torch_quantizer.QAT_Quantizer(model, config_list)
        quantizer.compress()
        modules_to_compress = quantizer.get_modules_to_compress()
        modules_to_compress_name = [t[0].name for t in modules_to_compress]
        assert "conv1" in modules_to_compress_name
        assert "conv2" in modules_to_compress_name
        assert "fc1" in modules_to_compress_name
        assert "fc2" in modules_to_compress_name
        assert "relu" in modules_to_compress_name
        assert len(modules_to_compress_name) == 5

    def test_torch_level_pruner(self):
        model = TorchModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
        configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
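        # 'default' expands to the pruner's built-in default set of weighted layer types (e.g. Conv2d and Linear)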
        torch_pruner.LevelPruner(model, configure_list, optimizer).compress()

    def test_torch_naive_quantizer(self):
        model = TorchModel()
        configure_list = [{
            'quant_types': ['weight'],
            'quant_bits': {
                'weight': 8,
            },
            'op_types': ['Conv2d', 'Linear']
        }]
        torch_quantizer.NaiveQuantizer(model, configure_list).compress()

    def test_torch_fpgm_pruner(self):
        """
        With the filter (kernel) weights defined below as `w`, it is clear that w[4] and w[5] are the Geometric Median filters,
        which minimize the total geometric distance, by the definition of the Geometric Median in this paper:
        Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration,
        https://arxiv.org/pdf/1811.00250.pdf

        So if sparsity is 0.2, the expected masks should mask out w[4] and w[5]; this can be verified through:
        `all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 125., 125., 125., 125.]))`

        If sparsity is 0.6, the expected masks should mask out w[2] - w[7]; this can be verified through:
        `all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([125., 125., 0., 0., 0., 0., 0., 0., 125., 125.]))`
        """
        w = np.array([np.ones((5, 5, 5)) * (i+1) for i in range(10)]).astype(np.float32)
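        # each w[i] is a 5x5x5 filter filled with the constant i + 1, so a kept filter contributes 5 * 5 * 5 = 125 to the per-filter mask sums asserted below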

        model = TorchModel()
        config_list = [{'sparsity': 0.6, 'op_types': ['Conv2d']}, {'sparsity': 0.2, 'op_types': ['Conv2d']}]
        pruner = torch_pruner.FPGMPruner(model, config_list, torch.optim.SGD(model.parameters(), lr=0.01))

        model.conv2.module.weight.data = torch.tensor(w).float()
        masks = pruner.calc_mask(model.conv2)
        assert all(torch.sum(masks['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 125., 125., 125., 125.]))

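        # calc_mask caches its result via if_calculated; reset it and switch to the 0.6-sparsity config (config_list[0]) before recomputing the mask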
        model.conv2.module.weight.data = torch.tensor(w).float()
        model.conv2.if_calculated = False
        model.conv2.config = config_list[0]
        masks = pruner.calc_mask(model.conv2)
        assert all(torch.sum(masks['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 0., 0., 0., 0., 0., 0., 125., 125.]))

    def test_torch_l1filter_pruner(self):
        """
        Filters with the minimum L1 norm of their weights are pruned in this paper:
        PRUNING FILTERS FOR EFFICIENT CONVNETS,
        https://arxiv.org/abs/1608.08710

        So if sparsity is 0.2 for conv1, the expected masks should mask out filter 0; this can be verified through:
        `all(torch.sum(mask1, (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))`

        If sparsity is 0.6 for conv2, the expected masks should mask out filters 0-5; this can be verified through:
        `all(torch.sum(mask2, (1, 2, 3)).numpy() == np.array([0., 0., 0., 0., 0., 0., 125., 125., 125., 125.]))`
        """
        w1 = np.array([np.ones((1, 5, 5))*i for i in range(5)]).astype(np.float32)
        w2 = np.array([np.ones((5, 5, 5))*i for i in range(10)]).astype(np.float32)
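        # filter i of conv1 has L1 norm 25 * i and filter i of conv2 has L1 norm 125 * i, so the lowest-indexed filters are pruned first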

        model = TorchModel()
        config_list = [{'sparsity': 0.2, 'op_types': ['Conv2d'], 'op_names': ['conv1']},
                       {'sparsity': 0.6, 'op_types': ['Conv2d'], 'op_names': ['conv2']}]
        pruner = torch_pruner.L1FilterPruner(model, config_list)

        model.conv1.module.weight.data = torch.tensor(w1).float()
        model.conv2.module.weight.data = torch.tensor(w2).float()
        mask1 = pruner.calc_mask(model.conv1)
        mask2 = pruner.calc_mask(model.conv2)
        assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))
        assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 0., 0., 0., 0., 0., 125., 125., 125., 125.]))

    def test_torch_slim_pruner(self):
        """
        Scale factors with the minimum L1 norm in the BN layers are pruned in this paper:
        Learning Efficient Convolutional Networks through Network Slimming,
        https://arxiv.org/pdf/1708.06519.pdf

        So if sparsity is 0.2, the expected masks should mask out channel 0; this can be verified through:
        `all(mask1.numpy() == np.array([0., 1., 1., 1., 1.]))`
        `all(mask2.numpy() == np.array([0., 1., 1., 1., 1.]))`

        If sparsity is 0.6, the expected masks should mask out channels 0, 1 and 2; this can be verified through:
        `all(mask1.numpy() == np.array([0., 0., 0., 1., 1.]))`
        `all(mask2.numpy() == np.array([0., 0., 0., 1., 1.]))`
        """
        w = np.array([0, 1, 2, 3, 4])
        model = TorchModel()
        config_list = [{'sparsity': 0.2, 'op_types': ['BatchNorm2d']}]
        model.bn1.weight.data = torch.tensor(w).float()
        model.bn2.weight.data = torch.tensor(-w).float()
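        # channels are ranked by the absolute value (L1 norm) of the BN scale factors, so bn2 with negated weights should yield the same mask as bn1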
        pruner = torch_pruner.SlimPruner(model, config_list)

        mask1 = pruner.calc_mask(model.bn1)
        mask2 = pruner.calc_mask(model.bn2)
        assert all(mask1['weight_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))
        assert all(mask2['weight_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))
        assert all(mask1['bias_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))
        assert all(mask2['bias_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))

        model = TorchModel()
        config_list = [{'sparsity': 0.6, 'op_types': ['BatchNorm2d']}]
        model.bn1.weight.data = torch.tensor(w).float()
        model.bn2.weight.data = torch.tensor(w).float()
        pruner = torch_pruner.SlimPruner(model, config_list)

        mask1 = pruner.calc_mask(model.bn1)
        mask2 = pruner.calc_mask(model.bn2)
        assert all(mask1['weight_mask'].numpy() == np.array([0., 0., 0., 1., 1.]))
        assert all(mask2['weight_mask'].numpy() == np.array([0., 0., 0., 1., 1.]))
        assert all(mask1['bias_mask'].numpy() == np.array([0., 0., 0., 1., 1.]))
        assert all(mask2['bias_mask'].numpy() == np.array([0., 0., 0., 1., 1.]))

    def test_torch_taylorFOweight_pruner(self):
        """
        Filters with the minimum importance approximation based on the first-order
        Taylor expansion on the weights, (w * grad)**2, are pruned in this paper:
        Importance Estimation for Neural Network Pruning,
        http://jankautz.com/publications/Importance4NNPruning_CVPR19.pdf

        So if sparsity of conv1 is 0.2, the expected masks should mask out filter 0; this can be verified through:
        `all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))`

        If sparsity of conv2 is 0.6, the expected masks should mask out filters 4-9; this can be verified through:
        `all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 0., 0., 0., 0., ]))`
        """

        w1 = np.array([np.zeros((1, 5, 5)), np.ones((1, 5, 5)), np.ones((1, 5, 5)) * 2,
                      np.ones((1, 5, 5)) * 3, np.ones((1, 5, 5)) * 4])
        w2 = np.array([[[[i + 1] * 5] * 5] * 5 for i in range(10)[::-1]])

        grad1 = np.array([np.ones((1, 5, 5)) * -1, np.ones((1, 5, 5)) * 1, np.ones((1, 5, 5)) * -1,
                      np.ones((1, 5, 5)) * 1, np.ones((1, 5, 5)) * -1])

        grad2 = np.array([[[[(-1)**i] * 5] * 5] * 5 for i in range(10)])
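        # per-filter importance is the sum of (weight * grad) ** 2: conv1 filter 0 has all-zero weights (importance 0) and is pruned at sparsity 0.2;
        # conv2 weights run from 10 down to 1 while |grad| == 1, so importance decreases with the filter index and filters 4-9 are pruned at sparsity 0.6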

        config_list = [{'sparsity': 0.2, 'op_types': ['Conv2d'], 'op_names': ['conv1']},
                       {'sparsity': 0.6, 'op_types': ['Conv2d'], 'op_names': ['conv2']}]

        model = TorchModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
        pruner = torch_pruner.TaylorFOWeightFilterPruner(model, config_list, optimizer, statistics_batch_num=1)
        
        x = torch.rand((1, 1, 28, 28), requires_grad=True)
        model.conv1.module.weight.data = torch.tensor(w1).float()
        model.conv2.module.weight.data = torch.tensor(w2).float()

        y = model(x)
        y.backward(torch.ones_like(y))

        model.conv1.module.weight.grad.data = torch.tensor(grad1).float()
        model.conv2.module.weight.grad.data = torch.tensor(grad2).float()
        optimizer.step()

        mask1 = pruner.calc_mask(model.conv1)
        mask2 = pruner.calc_mask(model.conv2)
        assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))
        assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 0., 0., 0., 0., ]))

    def test_torch_QAT_quantizer(self):
        model = TorchModel()
        config_list = [{
            'quant_types': ['weight'],
            'quant_bits': 8,
            'op_types': ['Conv2d', 'Linear']
        }, {
            'quant_types': ['output'],
            'quant_bits': 8,
            'quant_start_step': 0,
            'op_types': ['ReLU']
        }]
        model.relu = torch.nn.ReLU()
        quantizer = torch_quantizer.QAT_Quantizer(model, config_list)
        quantizer.compress()

        # test quantize
        # range not including 0
        eps = 1e-7
        weight = torch.tensor([[1, 2], [3, 5]]).float()
        model.conv2.module.old_weight.data = weight
        quantizer.quantize_weight(model.conv2)
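        # with all weights in [1, 5] the quantization range is extended to include 0, so scale == (5 - 0) / (2 ** 8 - 1) == 5 / 255 and zero_point == 0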
        assert math.isclose(model.conv2.module.scale, 5 / 255, abs_tol=eps)
        assert model.conv2.module.zero_point == 0
        # range including 0
        weight = torch.tensor([[-1, 2], [3, 5]]).float()
        model.conv2.module.old_weight.data = weight
        quantizer.quantize_weight(model.conv2)
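        # with weights in [-1, 5] the range already spans 0, so scale == (5 - (-1)) / 255 == 6 / 255 and zero_point == round(255 * 1 / 6), i.e. 42 or 43 depending on rounding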
        assert math.isclose(model.conv2.module.scale, 6 / 255, abs_tol=eps)
        assert model.conv2.module.zero_point in (42, 43)
        # test value of weight and bias after quantization
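        # weight_valid is the expected quantize-dequantize round trip of `weight` (scale derived from its maximum value, as asserted above); the bias is expected to pass through unchanged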
        weight = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]])
        weight_valid = torch.tensor([[1.1242, 2.3421], [3.7707, 5.9723]])
        bias = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
        bias_valid = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
        model.conv2.module.old_weight.data = weight
        model.conv2.module.bias.data = bias
        quantizer.quantize_weight(model.conv2)
        assert torch.all(torch.isclose(model.conv2.module.weight.data, weight_valid, rtol=1e-4))
        assert torch.all(torch.isclose(model.conv2.module.bias.data, bias_valid, rtol=1e-7))

        # test ema (exponential moving average) tracking of activation statistics
        eps = 1e-7
        x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
        out = model.relu(x)
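        # tracked_min/max_biased are exponential moving averages of the batch min/max of the ReLU output;
        # relu(x) here spans [0, 0.2], and with the assumed EMA smoothing factor of 0.01 the biased max becomes 0.2 * 0.01 == 0.002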
        assert math.isclose(model.relu.module.tracked_min_biased, 0, abs_tol=eps)
        assert math.isclose(model.relu.module.tracked_max_biased, 0.002, abs_tol=eps)

        quantizer.step_with_optimizer()
        x = torch.tensor([[0.2, 0.4], [0.6, 0.8]])
        out = model.relu(x)
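        # second batch: relu(x) spans [0.2, 0.8], so the biased EMAs update to 0 + 0.01 * (0.2 - 0) == 0.002 (min) and 0.002 + 0.01 * (0.8 - 0.002) == 0.00998 (max)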
        assert math.isclose(model.relu.module.tracked_min_biased, 0.002, abs_tol=eps)
        assert math.isclose(model.relu.module.tracked_max_biased, 0.00998, abs_tol=eps)

    def test_torch_pruner_validation(self):
        # test bad configuration
        pruner_classes = [torch_pruner.__dict__[x] for x in \
            ['LevelPruner', 'SlimPruner', 'FPGMPruner', 'L1FilterPruner', 'L2FilterPruner', 'AGPPruner',\
            'ActivationMeanRankFilterPruner', 'ActivationAPoZRankFilterPruner']]
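        # every config list in bad_configs below is intentionally invalid: sparsity given as a string, sparsity greater than 1,
        # op_types not given as a list, and op_names not given as a list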

        bad_configs = [
            [
                {'sparsity': '0.2'},
                {'sparsity': 0.6 }
            ],
            [
                {'sparsity': 0.2},
                {'sparsity': 1.6 }
            ],
            [
                {'sparsity': 0.2, 'op_types': 'default'},
                {'sparsity': 0.6 }
            ],
            [
                {'sparsity': 0.2 },
                {'sparsity': 0.6, 'op_names': 'abc' }
            ]
        ]
        model = TorchModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        for pruner_class in pruner_classes:
            for config_list in bad_configs:
                try:
                    pruner_class(model, config_list, optimizer)
                    print(config_list)
                    assert False, 'Validation error should be raised for bad configuration'
                except schema.SchemaError:
                    pass
                except:
                    print('FAILED:', pruner_class, config_list)
                    raise

    def test_torch_quantizer_validation(self):
        # test bad configuration
        quantizer_classes = [torch_quantizer.__dict__[x] for x in \
            ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer']]
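        # every config list in bad_configs below is intentionally invalid, e.g. an unknown key, quant_types not given as a list,
        # an unsupported quant_bits value, op_types not given as a list, or quant_bits keyed by an unknown quant type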

        bad_configs = [
            [
                {'bad_key': 'abc'}
            ],
            [
                {'quant_types': 'abc'}
            ],
            [
                {'quant_bits': 34}
            ],
            [
                {'op_types': 'default'}
            ],
            [
                {'quant_bits': {'abc': 123}}
            ]
        ]
        model = TorchModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        for quantizer_class in quantizer_classes:
            for config_list in bad_configs:
                try:
                    quantizer_class(model, config_list, optimizer)
                    print(config_list)
                    assert False, 'Validation error should be raised for bad configuration'
                except schema.SchemaError:
                    pass
                except:
                    print('FAILED:', quantizer_class, config_list)
                    raise

if __name__ == '__main__':
    main()