# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from unittest import TestCase, main
import numpy as np
import torch
import torch.nn.functional as F
import schema
import nni.algorithms.compression.pytorch.pruning as torch_pruner
import nni.algorithms.compression.pytorch.quantization as torch_quantizer
import math


class TorchModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
        self.bn1 = torch.nn.BatchNorm2d(5)
        self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
        self.bn2 = torch.nn.BatchNorm2d(10)
        self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
        self.fc2 = torch.nn.Linear(100, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 10)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


class CompressorTestCase(TestCase):
    def test_torch_quantizer_modules_detection(self):
        # test if modules can be detected
        model = TorchModel()
        config_list = [{
            'quant_types': ['weight'],
            'quant_bits': 8,
            'op_types': ['Conv2d', 'Linear']
        }, {
            'quant_types': ['output'],
            'quant_bits': 8,
            'quant_start_step': 0,
            'op_types': ['ReLU']
        }]

        model.relu = torch.nn.ReLU()
        quantizer = torch_quantizer.QAT_Quantizer(model, config_list)
        quantizer.compress()
        modules_to_compress = quantizer.get_modules_to_compress()
        modules_to_compress_name = [t[0].name for t in modules_to_compress]
        assert "conv1" in modules_to_compress_name
        assert "conv2" in modules_to_compress_name
        assert "fc1" in modules_to_compress_name
        assert "fc2" in modules_to_compress_name
        assert "relu" in modules_to_compress_name
        assert len(modules_to_compress_name) == 5

    def test_torch_level_pruner(self):
        model = TorchModel()
        configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
        torch_pruner.LevelPruner(model, configure_list).compress()

    def test_torch_naive_quantizer(self):
        model = TorchModel()
        configure_list = [{
            'quant_types': ['weight'],
            'quant_bits': {
                'weight': 8,
            },
            'op_types': ['Conv2d', 'Linear']
        }]
        torch_quantizer.NaiveQuantizer(model, configure_list).compress()

    def test_torch_fpgm_pruner(self):
        """
chicm-ms's avatar
chicm-ms committed
80
        With filters(kernels) weights defined as above (w), it is obvious that w[4] and w[5] is the Geometric Median
81
82
83
84
        which minimize the total geometric distance by defination of Geometric Median in this paper:
        Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration,
        https://arxiv.org/pdf/1811.00250.pdf

chicm-ms's avatar
chicm-ms committed
85
        So if sparsity is 0.2, the expected masks should mask out w[4] and w[5], this can be verified through:
86
        `all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 125., 125., 125., 125.]))`
87

chicm-ms's avatar
chicm-ms committed
88
        If sparsity is 0.6, the expected masks should mask out w[2] - w[7], this can be verified through:
89
        `all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([125., 125., 0., 0., 0., 0., 0., 0., 125., 125.]))`
90
        """
        w = np.array([np.ones((5, 5, 5)) * (i+1) for i in range(10)]).astype(np.float32)

        model = TorchModel()
        config_list = [{'sparsity': 0.6, 'op_types': ['Conv2d']}, {'sparsity': 0.2, 'op_types': ['Conv2d']}]
        pruner = torch_pruner.FPGMPruner(model, config_list)

        model.conv2.module.weight.data = torch.tensor(w).float()
        masks = pruner.calc_mask(model.conv2)
        assert all(torch.sum(masks['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 125., 125., 125., 125.]))

        # reset the wrapper state and switch to the 0.6-sparsity config before recomputing the mask
        model.conv2.module.weight.data = torch.tensor(w).float()
        model.conv2.if_calculated = False
        model.conv2.config = config_list[0]
        masks = pruner.calc_mask(model.conv2)
        assert all(torch.sum(masks['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 0., 0., 0., 0., 0., 0., 125., 125.]))

    def test_torch_l1filter_pruner(self):
        """
        Filters with the minimum sum of the weights' L1 norm are pruned in this paper:
        PRUNING FILTERS FOR EFFICIENT CONVNETS,
        https://arxiv.org/abs/1608.08710

114
115
        So if sparsity is 0.2 for conv1, the expected masks should mask out filter 0, this can be verified through:
        `all(torch.sum(mask1, (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))`
Tang Lang's avatar
Tang Lang committed
116

117
118
        If sparsity is 0.6 for conv2, the expected masks should mask out filter 0,1,2, this can be verified through:
        `all(torch.sum(mask2, (1, 2, 3)).numpy() == np.array([0., 0., 0., 0., 0., 0., 125., 125., 125., 125.]))`
Tang Lang's avatar
Tang Lang committed
119
        """
        w1 = np.array([np.ones((1, 5, 5))*i for i in range(5)]).astype(np.float32)
        w2 = np.array([np.ones((5, 5, 5))*i for i in range(10)]).astype(np.float32)

        model = TorchModel()
        config_list = [{'sparsity': 0.2, 'op_types': ['Conv2d'], 'op_names': ['conv1']},
                       {'sparsity': 0.6, 'op_types': ['Conv2d'], 'op_names': ['conv2']}]
        pruner = torch_pruner.L1FilterPruner(model, config_list)

        model.conv1.module.weight.data = torch.tensor(w1).float()
        model.conv2.module.weight.data = torch.tensor(w2).float()
        mask1 = pruner.calc_mask(model.conv1)
        mask2 = pruner.calc_mask(model.conv2)
        assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))
        assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 0., 0., 0., 0., 0., 125., 125., 125., 125.]))

    def test_torch_slim_pruner(self):
        """
        Scale factors with minimum l1 norm in the BN layers are pruned in this paper:
        Learning Efficient Convolutional Networks through Network Slimming,
        https://arxiv.org/pdf/1708.06519.pdf

        So if sparsity is 0.2, the expected masks should mask out channel 0, this can be verified through:
        `all(mask1.numpy() == np.array([0., 1., 1., 1., 1.]))`
        `all(mask2.numpy() == np.array([0., 1., 1., 1., 1.]))`

        If sparsity is 0.6, the expected masks should mask out channel 0,1,2, this can be verified through:
        `all(mask1.numpy() == np.array([0., 0., 0., 1., 1.]))`
        `all(mask2.numpy() == np.array([0., 0., 0., 1., 1.]))`
        """
        w = np.array([0, 1, 2, 3, 4])
        model = TorchModel()
        config_list = [{'sparsity': 0.2, 'op_types': ['BatchNorm2d']}]
        model.bn1.weight.data = torch.tensor(w).float()
        model.bn2.weight.data = torch.tensor(-w).float()
        pruner = torch_pruner.SlimPruner(model, config_list, optimizer=None, trainer=None, criterion=None)

        mask1 = pruner.calc_mask(model.bn1)
        mask2 = pruner.calc_mask(model.bn2)
        assert all(mask1['weight_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))
        assert all(mask2['weight_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))
        assert all(mask1['bias_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))
        assert all(mask2['bias_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))

        model = TorchModel()
        config_list = [{'sparsity': 0.6, 'op_types': ['BatchNorm2d']}]
        model.bn1.weight.data = torch.tensor(w).float()
        model.bn2.weight.data = torch.tensor(w).float()
        pruner = torch_pruner.SlimPruner(model, config_list, optimizer=None, trainer=None, criterion=None)

        mask1 = pruner.calc_mask(model.bn1)
        mask2 = pruner.calc_mask(model.bn2)
        assert all(mask1['weight_mask'].numpy() == np.array([0., 0., 0., 1., 1.]))
        assert all(mask2['weight_mask'].numpy() == np.array([0., 0., 0., 1., 1.]))
        assert all(mask1['bias_mask'].numpy() == np.array([0., 0., 0., 1., 1.]))
        assert all(mask2['bias_mask'].numpy() == np.array([0., 0., 0., 1., 1.]))

    def test_torch_taylorFOweight_pruner(self):
        """
        Filters with the minimum importance approxiamtion based on the first order 
        taylor expansion on the weights (w*grad)**2 are pruned in this paper:
        Importance Estimation for Neural Network Pruning,
        http://jankautz.com/publications/Importance4NNPruning_CVPR19.pdf

        So if sparsity of conv1 is 0.2, the expected masks should mask out filter 0, this can be verified through:
        `all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))`

        If sparsity of conv2 is 0.6, the expected masks should mask out filter 4,5,6,7,8,9 this can be verified through:
        `all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 0., 0., 0., 0., ]))`
        """

        w1 = np.array([np.zeros((1, 5, 5)), np.ones((1, 5, 5)), np.ones((1, 5, 5)) * 2,
                      np.ones((1, 5, 5)) * 3, np.ones((1, 5, 5)) * 4])
        w2 = np.array([[[[i + 1] * 5] * 5] * 5 for i in range(10)[::-1]])

        grad1 = np.array([np.ones((1, 5, 5)) * -1, np.ones((1, 5, 5)) * 1, np.ones((1, 5, 5)) * -1,
                      np.ones((1, 5, 5)) * 1, np.ones((1, 5, 5)) * -1])

        grad2 = np.array([[[[(-1)**i] * 5] * 5] * 5 for i in range(10)])

        config_list = [{'sparsity': 0.2, 'op_types': ['Conv2d'], 'op_names': ['conv1']},
                       {'sparsity': 0.6, 'op_types': ['Conv2d'], 'op_names': ['conv2']}]

        model = TorchModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
        pruner = torch_pruner.TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer=None, criterion=None, sparsifying_training_batches=1)

        x = torch.rand((1, 1, 28, 28), requires_grad=True)
        model.conv1.module.weight.data = torch.tensor(w1).float()
        model.conv2.module.weight.data = torch.tensor(w2).float()

        y = model(x)
        y.backward(torch.ones_like(y))

        model.conv1.module.weight.grad.data = torch.tensor(grad1).float()
        model.conv2.module.weight.grad.data = torch.tensor(grad2).float()
        # the pruner presumably patches optimizer.step() to collect the (weight * grad) statistics used by calc_mask below
        optimizer.step()

        mask1 = pruner.calc_mask(model.conv1)
        mask2 = pruner.calc_mask(model.conv2)
        assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))
        assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 0., 0., 0., 0., ]))

    def test_torch_QAT_quantizer(self):
        model = TorchModel()
        config_list = [{
            'quant_types': ['weight'],
            'quant_bits': 8,
            'op_types': ['Conv2d', 'Linear']
        }, {
            'quant_types': ['output'],
            'quant_bits': 8,
            'quant_start_step': 0,
            'op_types': ['ReLU']
        }]
        model.relu = torch.nn.ReLU()
        quantizer = torch_quantizer.QAT_Quantizer(model, config_list)
        quantizer.compress()

        # test quantize
        # range not including 0
        eps = 1e-7
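        # The expected scale and zero_point below assume the asymmetric scheme from the QAT paper
        # (https://arxiv.org/abs/1712.05877), with the range clamped to include 0:
        #   scale = (max(r_max, 0) - min(r_min, 0)) / (2 ** bits - 1)
        #   zero_point = round(-min(r_min, 0) / scale)
        # Case 1 (weights 1..5):  scale = 5 / 255, zero_point = 0
        # Case 2 (weights -1..5): scale = 6 / 255, zero_point = round(255 / 6), i.e. 42 or 43
        # depending on the rounding mode.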
        input = torch.tensor([[0, 4], [2, 1]]).float()
        weight = torch.tensor([[1, 2], [3, 5]]).float()
        model.conv2.module.old_weight.data = weight
        quantizer.quantize_weight(model.conv2, input_tensor=input)
        assert math.isclose(model.conv2.module.scale, 5 / 255, abs_tol=eps)
        assert model.conv2.module.zero_point == 0
        # range including 0
        weight = torch.tensor([[-1, 2], [3, 5]]).float()
        model.conv2.module.old_weight.data = weight
        quantizer.quantize_weight(model.conv2, input_tensor=input)
        assert math.isclose(model.conv2.module.scale, 6 / 255, abs_tol=eps)
        assert model.conv2.module.zero_point in (42, 43)
        # test value of weight and bias after quantization
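        # The expected weight assumes an 8-bit quantize-dequantize round trip with
        # scale = max(weight) / 255 (the range again clamped to include 0):
        # e.g. round(1.1287 / (5.9723 / 255)) = 48 and 48 * (5.9723 / 255) ~= 1.1242.
        # The bias is expected to pass through unchanged under this configuration.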
        weight = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]])
        weight_valid = torch.tensor([[1.1242, 2.3421], [3.7707, 5.9723]])
        bias = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
        bias_valid = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
        model.conv2.module.old_weight.data = weight
        model.conv2.module.bias.data = bias
        quantizer.quantize_weight(model.conv2, input_tensor=input)
        assert torch.all(torch.isclose(model.conv2.module.weight.data, weight_valid, rtol=1e-4))
        assert torch.all(torch.isclose(model.conv2.module.bias.data, bias_valid, rtol=1e-7))

        # test ema
        eps = 1e-7
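        # The tracked activation range is an exponential moving average; assuming the default
        # decay of 0.99, tracked = 0.99 * tracked + 0.01 * current (starting from 0):
        #   batch 1: min = 0.01 * 0   = 0,     max = 0.01 * 0.2                  = 0.002
        #   batch 2: min = 0.01 * 0.2 = 0.002, max = 0.99 * 0.002 + 0.01 * 0.8 = 0.00998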
        x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
        out = model.relu(x)
        assert math.isclose(model.relu.module.tracked_min_activation, 0, abs_tol=eps)
        assert math.isclose(model.relu.module.tracked_max_activation, 0.002, abs_tol=eps)

        quantizer.step_with_optimizer()
        x = torch.tensor([[0.2, 0.4], [0.6, 0.8]])
        out = model.relu(x)
        assert math.isclose(model.relu.module.tracked_min_activation, 0.002, abs_tol=eps)
        assert math.isclose(model.relu.module.tracked_max_activation, 0.00998, abs_tol=eps)

    def test_torch_quantizer_export(self):
        config_list_qat = [{
            'quant_types': ['weight'],
            'quant_bits': 8,
            'op_types': ['Conv2d', 'Linear']
        }, {
            'quant_types': ['output'],
            'quant_bits': 8,
            'quant_start_step': 0,
            'op_types': ['ReLU']
        }]
        config_list_dorefa = [{
            'quant_types': ['weight'],
            'quant_bits': {
                'weight': 8,
            },  # a plain `int` also works here because all `quant_types` share the same bit width
            'op_types': ['Conv2d', 'Linear']
        }]
        config_list_bnn = [{
            'quant_types': ['weight'],
            'quant_bits': 1,
            'op_types': ['Conv2d', 'Linear']
        }, {
            'quant_types': ['output'],
            'quant_bits': 1,
            'op_types': ['ReLU']
        }]
        config_set = [config_list_qat, config_list_dorefa, config_list_bnn]
        quantize_algorithm_set = [torch_quantizer.QAT_Quantizer, torch_quantizer.DoReFaQuantizer, torch_quantizer.BNNQuantizer]

        for config, quantize_algorithm in zip(config_set, quantize_algorithm_set):
            model = TorchModel()
            model.relu = torch.nn.ReLU()
            quantizer = quantize_algorithm(model, config)
            quantizer.compress()

            x = torch.rand((1, 1, 28, 28), requires_grad=True)
            y = model(x)
            y.backward(torch.ones_like(y))

            model_path = "test_model.pth"
            calibration_path = "test_calibration.pth"
            onnx_path = "test_model.onnx"
            input_shape = (1, 1, 28, 28)
            device = torch.device("cpu")

            calibration_config = quantizer.export_model(model_path, calibration_path, onnx_path, input_shape, device)
            assert calibration_config is not None

    def test_torch_pruner_validation(self):
        # test bad configuration
        pruner_classes = [torch_pruner.__dict__[x] for x in \
            ['LevelPruner', 'SlimPruner', 'FPGMPruner', 'L1FilterPruner', 'L2FilterPruner', 'AGPPruner',\
            'ActivationMeanRankFilterPruner', 'ActivationAPoZRankFilterPruner']]

        bad_configs = [
            [
                # sparsity must be a number, not a string
                {'sparsity': '0.2'},
                {'sparsity': 0.6 }
            ],
            [
                # sparsity must be within (0, 1)
                {'sparsity': 0.2},
                {'sparsity': 1.6 }
            ],
            [
                # op_types must be a list of strings
                {'sparsity': 0.2, 'op_types': 'default'},
                {'sparsity': 0.6 }
            ],
            [
                {'sparsity': 0.2 },
                # op_names must be a list of strings
                {'sparsity': 0.6, 'op_names': 'abc'}
            ]
        ]
        model = TorchModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        for pruner_class in pruner_classes:
            for config_list in bad_configs:
                try:
                    kwargs = {}
                    if pruner_class in (torch_pruner.SlimPruner, torch_pruner.AGPPruner, torch_pruner.ActivationMeanRankFilterPruner, torch_pruner.ActivationAPoZRankFilterPruner):
                        kwargs = {'optimizer': None, 'trainer': None, 'criterion': None}

                    print('kwargs', kwargs)
                    pruner_class(model, config_list, **kwargs)      

                    print(config_list)
                    assert False, 'Validation error should be raised for bad configuration'
                except schema.SchemaError:
                    pass
                except:
                    print('FAILED:', pruner_class, config_list)
                    raise

    def test_torch_quantizer_validation(self):
        # test bad configuration
        quantizer_classes = [torch_quantizer.__dict__[x] for x in \
            ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer']]

        bad_configs = [
            [
                # unrecognized key
                {'bad_key': 'abc'}
            ],
            [
                # quant_types must be a list of valid quantization types
                {'quant_types': 'abc'}
            ],
            [
                # bit width outside the supported range
                {'quant_bits': 34}
            ],
            [
                # op_types must be a list of strings
                {'op_types': 'default'}
            ],
            [
                # keys of a quant_bits dict must be valid quantization types
                {'quant_bits': {'abc': 123}}
            ]
        ]
        model = TorchModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        for quantizer_class in quantizer_classes:
            for config_list in bad_configs:
                try:
                    quantizer_class(model, config_list, optimizer)
                    print(config_list)
                    assert False, 'Validation error should be raised for bad configuration'
                except schema.SchemaError:
                    pass
                except:
                    print('FAILED:', quantizer_class, config_list)
                    raise

if __name__ == '__main__':
    main()