"docs/vscode:/vscode.git/clone" did not exist on "c85b80c2b64d0f420aaca59679e5f38f71a8a53e"
test_compressor.py 16.7 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
3
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from unittest import TestCase, main
import numpy as np
import tensorflow as tf
import torch
import torch.nn.functional as F
import schema
import nni.compression.torch as torch_compressor
import math

if tf.__version__ >= '2.0':
    import nni.compression.tensorflow as tf_compressor


def get_tf_model():
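    """Build a small Keras CNN for 28x28 single-channel inputs, used as the TensorFlow compression target in these tests."""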
    model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(filters=5, kernel_size=7, input_shape=[28, 28, 1], activation='relu', padding="SAME"),
        tf.keras.layers.MaxPooling2D(pool_size=2),
        tf.keras.layers.Conv2D(filters=10, kernel_size=3, activation='relu', padding="SAME"),
        tf.keras.layers.MaxPooling2D(pool_size=2),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(units=128, activation='relu'),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(units=10, activation='softmax'),
    ])
    model.compile(loss="sparse_categorical_crossentropy",
                  optimizer=tf.keras.optimizers.SGD(lr=1e-3),
                  metrics=["accuracy"])
    return model


class TorchModel(torch.nn.Module):
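    """A small LeNet-style CNN for 28x28 single-channel inputs, used as the PyTorch compression target in these tests."""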
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
        self.bn1 = torch.nn.BatchNorm2d(5)
        self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
        self.bn2 = torch.nn.BatchNorm2d(10)
        self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
        self.fc2 = torch.nn.Linear(100, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 10)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def tf2(func):
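    # Decorator: run the wrapped test only when TensorFlow 2.x is available, otherwise make it a no-op.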
    def test_tf2_func(*args):
        if tf.__version__ >= '2.0':
            func(*args)

    return test_tf2_func

class CompressorTestCase(TestCase):
    def test_torch_quantizer_modules_detection(self):
        # test if modules can be detected
        model = TorchModel()
        config_list = [{
            'quant_types': ['weight'],
            'quant_bits': 8,
            'op_types': ['Conv2d', 'Linear']
        }, {
            'quant_types': ['output'],
            'quant_bits': 8,
            'quant_start_step': 0,
            'op_types': ['ReLU']
        }]

        model.relu = torch.nn.ReLU()
        quantizer = torch_compressor.QAT_Quantizer(model, config_list)
        quantizer.compress()
        modules_to_compress = quantizer.get_modules_to_compress()
        modules_to_compress_name = [t[0].name for t in modules_to_compress]
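        # The two config entries cover Conv2d/Linear weights and ReLU outputs, so conv1,
        # conv2, fc1, fc2 and the relu added above should all be detected.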
        assert "conv1" in modules_to_compress_name
        assert "conv2" in modules_to_compress_name
        assert "fc1" in modules_to_compress_name
        assert "fc2" in modules_to_compress_name
        assert "relu" in modules_to_compress_name
        assert len(modules_to_compress_name) == 5

    def test_torch_level_pruner(self):
        model = TorchModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
        configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
        torch_compressor.LevelPruner(model, configure_list, optimizer).compress()

    @tf2
    def test_tf_level_pruner(self):
        configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
        tf_compressor.LevelPruner(get_tf_model(), configure_list).compress()

    def test_torch_naive_quantizer(self):
        model = TorchModel()
        configure_list = [{
            'quant_types': ['weight'],
            'quant_bits': {
                'weight': 8,
            },
            'op_types': ['Conv2d', 'Linear']
        }]
        torch_compressor.NaiveQuantizer(model, configure_list).compress()

    @tf2
    def test_tf_naive_quantizer(self):
        tf_compressor.NaiveQuantizer(get_tf_model(), [{'op_types': ['default']}]).compress()

    def test_torch_fpgm_pruner(self):
        """
        With filter (kernel) weights defined as below (w), it is obvious that w[4] and w[5] are the geometric medians
        which minimize the total geometric distance by the definition of Geometric Median in this paper:
        Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration,
        https://arxiv.org/pdf/1811.00250.pdf

        So if sparsity is 0.2, the expected masks should mask out w[4] and w[5], this can be verified through:
        `all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 125., 125., 125., 125.]))`

        If sparsity is 0.6, the expected masks should mask out w[2] - w[7], this can be verified through:
        `all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([125., 125., 0., 0., 0., 0., 0., 0., 125., 125.]))`
        """
        w = np.array([np.ones((5, 5, 5)) * (i+1) for i in range(10)]).astype(np.float32)
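        # Each filter w[i] is a constant block of value i + 1, so the filters lie on a
        # line in weight space and the two middle ones, w[4] and w[5], minimize the
        # summed distance to all the others.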

        model = TorchModel()
        config_list = [{'sparsity': 0.6, 'op_types': ['Conv2d']}, {'sparsity': 0.2, 'op_types': ['Conv2d']}]
        pruner = torch_compressor.FPGMPruner(model, config_list, torch.optim.SGD(model.parameters(), lr=0.01))

        model.conv2.module.weight.data = torch.tensor(w).float()
        masks = pruner.calc_mask(model.conv2)
        assert all(torch.sum(masks['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 125., 125., 125., 125.]))

        model.conv2.module.weight.data = torch.tensor(w).float()
        model.conv2.if_calculated = False
        model.conv2.config = config_list[0]
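        # With if_calculated reset and the 0.6-sparsity config applied, calc_mask recomputes the mask.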
        masks = pruner.calc_mask(model.conv2)
        assert all(torch.sum(masks['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 0., 0., 0., 0., 0., 0., 125., 125.]))

    @tf2
    def test_tf_fpgm_pruner(self):
        w = np.array([np.ones((5, 5, 5)) * (i+1) for i in range(10)]).astype(np.float32)
        model = get_tf_model()
        config_list = [{'sparsity': 0.2, 'op_types': ['Conv2D']}]

        pruner = tf_compressor.FPGMPruner(model, config_list)
        weights = model.layers[2].weights
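        # Reorder w from (filters, channels, h, w) to Keras' (h, w, channels, filters) kernel layout.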
        weights[0] = np.array(w).astype(np.float32).transpose([2, 3, 0, 1]).transpose([0, 1, 3, 2])
        model.layers[2].set_weights([weights[0], weights[1].numpy()])

        layer = tf_compressor.compressor.LayerInfo(model.layers[2])
        masks = pruner.calc_mask(layer, config_list[0]).numpy()
        masks = masks.reshape((-1, masks.shape[-1])).transpose([1, 0])

        assert all(masks.sum((1)) == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
        
    def test_torch_l1filter_pruner(self):
        """
        Filters with the minimum sum of the weights' L1 norm are pruned in this paper:
        PRUNING FILTERS FOR EFFICIENT CONVNETS,
        https://arxiv.org/abs/1608.08710

        So if sparsity is 0.2 for conv1, the expected masks should mask out filter 0, this can be verified through:
        `all(torch.sum(mask1, (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))`

        If sparsity is 0.6 for conv2, the expected masks should mask out filters 0,1,2,3,4,5, this can be verified through:
        `all(torch.sum(mask2, (1, 2, 3)).numpy() == np.array([0., 0., 0., 0., 0., 0., 125., 125., 125., 125.]))`
        """
        w1 = np.array([np.ones((1, 5, 5))*i for i in range(5)]).astype(np.float32)
        w2 = np.array([np.ones((5, 5, 5))*i for i in range(10)]).astype(np.float32)

        model = TorchModel()
        config_list = [{'sparsity': 0.2, 'op_types': ['Conv2d'], 'op_names': ['conv1']},
                       {'sparsity': 0.6, 'op_types': ['Conv2d'], 'op_names': ['conv2']}]
        pruner = torch_compressor.L1FilterPruner(model, config_list)

        model.conv1.module.weight.data = torch.tensor(w1).float()
        model.conv2.module.weight.data = torch.tensor(w2).float()
        mask1 = pruner.calc_mask(model.conv1)
        mask2 = pruner.calc_mask(model.conv2)
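        # conv1 filters have L1 norms 0, 25, 50, 75, 100 and conv2 filters 0, 125, ..., 1125,
        # so the smallest-norm filters (filter 0 of conv1, filters 0-5 of conv2) are masked.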
        assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))
        assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 0., 0., 0., 0., 0., 125., 125., 125., 125.]))

    def test_torch_slim_pruner(self):
        """
        Scale factors with minimum l1 norm in the BN layers are pruned in this paper:
        Learning Efficient Convolutional Networks through Network Slimming,
        https://arxiv.org/pdf/1708.06519.pdf

        So if sparsity is 0.2, the expected masks should mask out channel 0, this can be verified through:
        `all(mask1.numpy() == np.array([0., 1., 1., 1., 1.]))`
        `all(mask2.numpy() == np.array([0., 1., 1., 1., 1.]))`

        If sparsity is 0.6, the expected masks should mask out channel 0,1,2, this can be verified through:
        `all(mask1.numpy() == np.array([0., 0., 0., 1., 1.]))`
        `all(mask2.numpy() == np.array([0., 0., 0., 1., 1.]))`
        """
        w = np.array([0, 1, 2, 3, 4])
        model = TorchModel()
        config_list = [{'sparsity': 0.2, 'op_types': ['BatchNorm2d']}]
        model.bn1.weight.data = torch.tensor(w).float()
        model.bn2.weight.data = torch.tensor(-w).float()
        pruner = torch_compressor.SlimPruner(model, config_list)

        mask1 = pruner.calc_mask(model.bn1)
        mask2 = pruner.calc_mask(model.bn2)
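        # Channels are ranked by the L1 norm (absolute value) of their BN scale factor, so
        # bn1 ([0, 1, 2, 3, 4]) and bn2 ([0, -1, -2, -3, -4]) both prune channel 0 at 20% sparsity.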
        assert all(mask1['weight_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))
        assert all(mask2['weight_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))
        assert all(mask1['bias_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))
        assert all(mask2['bias_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))

        model = TorchModel()
        config_list = [{'sparsity': 0.6, 'op_types': ['BatchNorm2d']}]
        model.bn1.weight.data = torch.tensor(w).float()
        model.bn2.weight.data = torch.tensor(w).float()
        pruner = torch_compressor.SlimPruner(model, config_list)

        mask1 = pruner.calc_mask(model.bn1)
        mask2 = pruner.calc_mask(model.bn2)
        assert all(mask1['weight_mask'].numpy() == np.array([0., 0., 0., 1., 1.]))
        assert all(mask2['weight_mask'].numpy() == np.array([0., 0., 0., 1., 1.]))
        assert all(mask1['bias_mask'].numpy() == np.array([0., 0., 0., 1., 1.]))
        assert all(mask2['bias_mask'].numpy() == np.array([0., 0., 0., 1., 1.]))

    def test_torch_taylorFOweight_pruner(self):
        """
        Filters with the minimum importance approximation based on the first-order
        Taylor expansion on the weights, (w*grad)**2, are pruned in this paper:
        Importance Estimation for Neural Network Pruning,
        http://jankautz.com/publications/Importance4NNPruning_CVPR19.pdf

        So if sparsity of conv1 is 0.2, the expected masks should mask out filter 0, this can be verified through:
        `all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))`

        If sparsity of conv2 is 0.6, the expected masks should mask out filters 4,5,6,7,8,9, this can be verified through:
        `all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 0., 0., 0., 0., ]))`
        """

        w1 = np.array([np.zeros((1, 5, 5)), np.ones((1, 5, 5)), np.ones((1, 5, 5)) * 2,
                      np.ones((1, 5, 5)) * 3, np.ones((1, 5, 5)) * 4])
        w2 = np.array([[[[i + 1] * 5] * 5] * 5 for i in range(10)[::-1]])

        grad1 = np.array([np.ones((1, 5, 5)) * -1, np.ones((1, 5, 5)) * 1, np.ones((1, 5, 5)) * -1,
                      np.ones((1, 5, 5)) * 1, np.ones((1, 5, 5)) * -1])

        grad2 = np.array([[[[(-1)**i] * 5] * 5] * 5 for i in range(10)])

        config_list = [{'sparsity': 0.2, 'op_types': ['Conv2d'], 'op_names': ['conv1']},
                       {'sparsity': 0.6, 'op_types': ['Conv2d'], 'op_names': ['conv2']}]

        model = TorchModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
        pruner = torch_compressor.TaylorFOWeightFilterPruner(model, config_list, optimizer, statistics_batch_num=1)
        
        x = torch.rand((1, 1, 28, 28), requires_grad=True)
        model.conv1.module.weight.data = torch.tensor(w1).float()
        model.conv2.module.weight.data = torch.tensor(w2).float()

        y = model(x)
        y.backward(torch.ones_like(y))

        model.conv1.module.weight.grad.data = torch.tensor(grad1).float()
        model.conv2.module.weight.grad.data = torch.tensor(grad2).float()
        optimizer.step()

        mask1 = pruner.calc_mask(model.conv1)
        mask2 = pruner.calc_mask(model.conv2)
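        # Importance is approximated by sum((w * grad) ** 2) per filter: conv1's all-zero
        # filter 0 and conv2's six smallest-weight filters (its values count down from 10
        # to 1) get the lowest scores and are the ones masked out below.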
        assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))
        assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 0., 0., 0., 0., ]))

    def test_torch_QAT_quantizer(self):
        model = TorchModel()
        config_list = [{
            'quant_types': ['weight'],
            'quant_bits': 8,
            'op_types': ['Conv2d', 'Linear']
        }, {
            'quant_types': ['output'],
            'quant_bits': 8,
            'quant_start_step': 0,
            'op_types': ['ReLU']
        }]
        model.relu = torch.nn.ReLU()
        quantizer = torch_compressor.QAT_Quantizer(model, config_list)
        quantizer.compress()
        # test quantize
        # range not including 0
        eps = 1e-7
        weight = torch.tensor([[1, 2], [3, 5]]).float()
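        # The asserted values are consistent with the quantization range being extended to
        # include 0: for weights in [1, 5], scale = (5 - 0) / 255 and zero_point = 0.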
        quantize_weight = quantizer.quantize_weight(weight, model.conv2)
        assert math.isclose(model.conv2.module.scale, 5 / 255, abs_tol=eps)
        assert model.conv2.module.zero_point == 0
        # range including 0
        weight = torch.tensor([[-1, 2], [3, 5]]).float()
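        # With min < 0 < max, scale = (5 - (-1)) / 255 = 6 / 255 and the zero point is
        # round(1 / (6 / 255)) = round(42.5), so the assertion accepts 42 or 43.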
        quantize_weight = quantizer.quantize_weight(weight, model.conv2)
        assert math.isclose(model.conv2.module.scale, 6 / 255, abs_tol=eps)
        assert model.conv2.module.zero_point in (42, 43)

        # test ema
        x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
        out = model.relu(x)
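        # The tracked statistics follow a biased EMA; the asserted values are consistent
        # with new = 0.99 * old + 0.01 * current (e.g. max: 0.99 * 0 + 0.01 * 0.2 = 0.002).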
        assert math.isclose(model.relu.module.tracked_min_biased, 0, abs_tol=eps)
        assert math.isclose(model.relu.module.tracked_max_biased, 0.002, abs_tol=eps)

        quantizer.step_with_optimizer()
        x = torch.tensor([[0.2, 0.4], [0.6, 0.8]])
        out = model.relu(x)
        assert math.isclose(model.relu.module.tracked_min_biased, 0.002, abs_tol=eps)
        assert math.isclose(model.relu.module.tracked_max_biased, 0.00998, abs_tol=eps)

    def test_torch_pruner_validation(self):
        # test bad configuration
        pruner_classes = [torch_compressor.__dict__[x] for x in \
            ['LevelPruner', 'SlimPruner', 'FPGMPruner', 'L1FilterPruner', 'L2FilterPruner', 'AGP_Pruner', \
            'ActivationMeanRankFilterPruner', 'ActivationAPoZRankFilterPruner']]

        bad_configs = [
            [
                {'sparsity': '0.2'},
                {'sparsity': 0.6 }
            ],
            [
                {'sparsity': 0.2},
                {'sparsity': 1.6 }
            ],
            [
                {'sparsity': 0.2, 'op_types': 'default'},
                {'sparsity': 0.6 }
            ],
            [
                {'sparsity': 0.2 },
                {'sparsity': 0.6, 'op_names': 'abc' }
            ]
        ]
        model = TorchModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        for pruner_class in pruner_classes:
            for config_list in bad_configs:
                try:
                    pruner_class(model, config_list, optimizer)
                    print(config_list)
                    assert False, 'Validation error should be raised for bad configuration'
                except schema.SchemaError:
                    pass
                except:
                    print('FAILED:', pruner_class, config_list)
                    raise

    def test_torch_quantizer_validation(self):
        # test bad configuration
        quantizer_classes = [torch_compressor.__dict__[x] for x in \
            ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer']]

        bad_configs = [
            [
                {'bad_key': 'abc'}
            ],
            [
                {'quant_types': 'abc'}
            ],
            [
                {'quant_bits': 34}
            ],
            [
                {'op_types': 'default'}
            ],
            [
                {'quant_bits': {'abc': 123}}
            ]
        ]
        model = TorchModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        for quantizer_class in quantizer_classes:
            for config_list in bad_configs:
                try:
                    quantizer_class(model, config_list, optimizer)
                    print(config_list)
                    assert False, 'Validation error should be raised for bad configuration'
                except schema.SchemaError:
                    pass
                except:
                    print('FAILED:', quantizer_class, config_list)
                    raise

if __name__ == '__main__':
    main()