# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import copy
from unittest import TestCase, main
import numpy as np
import torch
import torch.nn.functional as F
import schema
import nni.algorithms.compression.pytorch.pruning as torch_pruner
import nni.algorithms.compression.pytorch.quantization as torch_quantizer
import math


class TorchModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
        self.bn1 = torch.nn.BatchNorm2d(5)
        self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
        self.bn2 = torch.nn.BatchNorm2d(10)
        self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
        self.fc2 = torch.nn.Linear(100, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 10)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


class CompressorTestCase(TestCase):
    def test_torch_quantizer_modules_detection(self):
        # test if modules can be detected
        model = TorchModel()
        config_list = [{
            'quant_types': ['weight'],
            'quant_bits': 8,
            'op_types': ['Conv2d', 'Linear']
        }, {
            'quant_types': ['output'],
            'quant_bits': 8,
            'quant_start_step': 0,
            'op_types': ['ReLU']
        }]

        model.relu = torch.nn.ReLU()
        quantizer = torch_quantizer.QAT_Quantizer(model, config_list)
        quantizer.compress()
        modules_to_compress = quantizer.get_modules_to_compress()
        modules_to_compress_name = [t[0].name for t in modules_to_compress]
        assert "conv1" in modules_to_compress_name
        assert "conv2" in modules_to_compress_name
        assert "fc1" in modules_to_compress_name
        assert "fc2" in modules_to_compress_name
        assert "relu" in modules_to_compress_name
        assert len(modules_to_compress_name) == 5

    def test_torch_level_pruner(self):
        model = TorchModel()
        configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
        torch_pruner.LevelPruner(model, configure_list).compress()

    def test_torch_naive_quantizer(self):
        model = TorchModel()
        configure_list = [{
            'quant_types': ['weight'],
            'quant_bits': {
                'weight': 8,
            },
            'op_types': ['Conv2d', 'Linear']
        }]
        torch_quantizer.NaiveQuantizer(model, configure_list).compress()

    def test_torch_fpgm_pruner(self):
        """
        With the filter (kernel) weights defined below (w), w[4] and w[5] are the Geometric Median
        which minimize the total geometric distance by the definition of Geometric Median in this paper:
        Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration,
        https://arxiv.org/pdf/1811.00250.pdf

        So if sparsity is 0.2, the expected masks should mask out w[4] and w[5], this can be verified through:
        `all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 125., 125., 125., 125.]))`

        If sparsity is 0.6, the expected masks should mask out w[2] - w[7], this can be verified through:
        `all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([125., 125., 0., 0., 0., 0., 0., 0., 125., 125.]))`
        """
        w = np.array([np.ones((5, 5, 5)) * (i+1) for i in range(10)]).astype(np.float32)
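        # Worked arithmetic (sketch): each of the 10 filters is constant-valued, filter i holding
        # 125 copies of (i + 1), so the distance between filters i and j is |i - j| * sqrt(125).
        # The summed distance to all other filters is minimized by the middle values 5 and 6,
        # i.e. by w[4] and w[5], which is why they are the first to be masked.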

        model = TorchModel()
        config_list = [{'sparsity': 0.6, 'op_types': ['Conv2d']}, {'sparsity': 0.2, 'op_types': ['Conv2d']}]
        pruner = torch_pruner.FPGMPruner(model, config_list)

        model.conv2.module.weight.data = torch.tensor(w).float()
        masks = pruner.calc_mask(model.conv2)
        assert all(torch.sum(masks['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 125., 125., 125., 125.]))

        model.conv2.module.weight.data = torch.tensor(w).float()
        model.conv2.if_calculated = False
        model.conv2.config = config_list[0]
        masks = pruner.calc_mask(model.conv2)
        assert all(torch.sum(masks['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 0., 0., 0., 0., 0., 0., 125., 125.]))

    def test_torch_l1filter_pruner(self):
        """
        Filters with the smallest L1 norm of their weights are pruned in this paper:
        PRUNING FILTERS FOR EFFICIENT CONVNETS,
        https://arxiv.org/abs/1608.08710

        So if sparsity is 0.2 for conv1, the expected masks should mask out filter 0, this can be verified through:
        `all(torch.sum(mask1, (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))`

        If sparsity is 0.6 for conv2, the expected masks should mask out filters 0-5, this can be verified through:
        `all(torch.sum(mask2, (1, 2, 3)).numpy() == np.array([0., 0., 0., 0., 0., 0., 125., 125., 125., 125.]))`
        """
        w1 = np.array([np.ones((1, 5, 5))*i for i in range(5)]).astype(np.float32)
        w2 = np.array([np.ones((5, 5, 5))*i for i in range(10)]).astype(np.float32)
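        # Worked arithmetic (sketch): filter i of conv1 has 25 weights all equal to i, so its L1
        # norm is 25 * i -> [0, 25, 50, 75, 100]; pruning 20% of 5 filters removes filter 0.
        # Filter i of conv2 has 125 weights equal to i, so its L1 norm is 125 * i; pruning 60%
        # of 10 filters removes the six smallest, filters 0-5.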

        model = TorchModel()
        config_list = [{'sparsity': 0.2, 'op_types': ['Conv2d'], 'op_names': ['conv1']},
                       {'sparsity': 0.6, 'op_types': ['Conv2d'], 'op_names': ['conv2']}]
        pruner = torch_pruner.L1FilterPruner(model, config_list)

        model.conv1.module.weight.data = torch.tensor(w1).float()
        model.conv2.module.weight.data = torch.tensor(w2).float()
        mask1 = pruner.calc_mask(model.conv1)
        mask2 = pruner.calc_mask(model.conv2)
        assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))
        assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 0., 0., 0., 0., 0., 125., 125., 125., 125.]))

    def test_torch_slim_pruner(self):
        """
        Scale factors with the smallest L1 norm in the BN layers are pruned in this paper:
        Learning Efficient Convolutional Networks through Network Slimming,
        https://arxiv.org/pdf/1708.06519.pdf

        So if sparsity is 0.2, the expected masks should mask out channel 0, this can be verified through:
        `all(mask1.numpy() == np.array([0., 1., 1., 1., 1.]))`
        `all(mask2.numpy() == np.array([0., 1., 1., 1., 1.]))`

        If sparsity is 0.6, the expected masks should mask out channel 0,1,2, this can be verified through:
        `all(mask1.numpy() == np.array([0., 0., 0., 1., 1.]))`
        `all(mask2.numpy() == np.array([0., 0., 0., 1., 1.]))`
        """
        w = np.array([0, 1, 2, 3, 4])
        model = TorchModel()
        config_list = [{'sparsity': 0.2, 'op_types': ['BatchNorm2d']}]
        model.bn1.weight.data = torch.tensor(w).float()
        model.bn2.weight.data = torch.tensor(-w).float()
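        # Worked arithmetic (sketch): both BN layers have scale magnitudes |weight| = [0, 1, 2, 3, 4].
        # With sparsity 0.2 only the smallest scale in each layer (channel 0) is expected to be masked,
        # and the bias masks are expected to follow the weight masks channel by channel.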
        pruner = torch_pruner.SlimPruner(model, config_list, optimizer=None, trainer=None, criterion=None)

        mask1 = pruner.calc_mask(model.bn1)
        mask2 = pruner.calc_mask(model.bn2)
        assert all(mask1['weight_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))
        assert all(mask2['weight_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))
        assert all(mask1['bias_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))
        assert all(mask2['bias_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))

        model = TorchModel()
        config_list = [{'sparsity': 0.6, 'op_types': ['BatchNorm2d']}]
        model.bn1.weight.data = torch.tensor(w).float()
        model.bn2.weight.data = torch.tensor(w).float()
        pruner = torch_pruner.SlimPruner(model, config_list, optimizer=None, trainer=None, criterion=None)

        mask1 = pruner.calc_mask(model.bn1)
        mask2 = pruner.calc_mask(model.bn2)
        assert all(mask1['weight_mask'].numpy() == np.array([0., 0., 0., 1., 1.]))
        assert all(mask2['weight_mask'].numpy() == np.array([0., 0., 0., 1., 1.]))
        assert all(mask1['bias_mask'].numpy() == np.array([0., 0., 0., 1., 1.]))
        assert all(mask2['bias_mask'].numpy() == np.array([0., 0., 0., 1., 1.]))

    def test_torch_taylorFOweight_pruner(self):
        """
        Filters with the minimum importance approximation based on the first-order
        Taylor expansion on the weights, (w*grad)**2, are pruned in this paper:
        Importance Estimation for Neural Network Pruning,
        http://jankautz.com/publications/Importance4NNPruning_CVPR19.pdf

        So if sparsity of conv1 is 0.2, the expected masks should mask out filter 0, this can be verified through:
        `all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))`

        If sparsity of conv2 is 0.6, the expected masks should mask out filters 4,5,6,7,8,9, this can be verified through:
        `all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 0., 0., 0., 0., ]))`
        """

        w1 = np.array([np.zeros((1, 5, 5)), np.ones((1, 5, 5)), np.ones((1, 5, 5)) * 2,
                      np.ones((1, 5, 5)) * 3, np.ones((1, 5, 5)) * 4])
        w2 = np.array([[[[i + 1] * 5] * 5] * 5 for i in range(10)[::-1]])

        grad1 = np.array([np.ones((1, 5, 5)) * -1, np.ones((1, 5, 5)) * 1, np.ones((1, 5, 5)) * -1,
                      np.ones((1, 5, 5)) * 1, np.ones((1, 5, 5)) * -1])

        grad2 = np.array([[[[(-1)**i] * 5] * 5] * 5 for i in range(10)])
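        # Worked arithmetic (sketch, per the (w * grad)**2 importance above): every weight in
        # filter i of conv1 equals i and its gradient is +/-1, so the filter's contribution grows
        # with i and filter 0 ranks lowest (pruned at sparsity 0.2). conv2's filter values run
        # from 10 down to 1, so its contributions shrink with the filter index and the six
        # smallest are filters 4-9 (pruned at sparsity 0.6).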

        config_list = [{'sparsity': 0.2, 'op_types': ['Conv2d'], 'op_names': ['conv1']},
                       {'sparsity': 0.6, 'op_types': ['Conv2d'], 'op_names': ['conv2']}]

        model = TorchModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
        pruner = torch_pruner.TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer=None, criterion=None, sparsifying_training_batches=1)

        x = torch.rand((1, 1, 28, 28), requires_grad=True)
        model.conv1.module.weight.data = torch.tensor(w1).float()
        model.conv2.module.weight.data = torch.tensor(w2).float()

        y = model(x)
        y.backward(torch.ones_like(y))

        model.conv1.module.weight.grad.data = torch.tensor(grad1).float()
        model.conv2.module.weight.grad.data = torch.tensor(grad2).float()
        optimizer.step()

        mask1 = pruner.calc_mask(model.conv1)
        mask2 = pruner.calc_mask(model.conv2)
        assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))
        assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 0., 0., 0., 0., ]))

    def test_torch_taylorFOweight_pruner_global_sort(self):
        """
        After enabling global_sort, the TaylorFOWeight pruner calculates contributions and ranks the top-k
        filters across all of the conv operators, then prunes low-contribution filters based on that global information.

        So if the sparsity of the conv operators is 0.4, the expected masks should mask out filters 0-3 of conv1
        and filters 7-9 of conv2, this can be verified through:
        `all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 0., 0, 0., 25.]))`
        `all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 125., 125., 125., 0., 0., 0.]))`
        """

        w1 = np.array([np.zeros((1, 5, 5)), np.ones((1, 5, 5)), np.ones((1, 5, 5)) * 2,
                      np.ones((1, 5, 5)) * 3, np.ones((1, 5, 5)) * 4])
        w2 = np.array([[[[i + 1] * 5] * 5] * 5 for i in range(10)[::-1]])

        grad1 = np.array([np.ones((1, 5, 5)) * -1, np.ones((1, 5, 5)) * 1, np.ones((1, 5, 5)) * -1,
                      np.ones((1, 5, 5)) * 1, np.ones((1, 5, 5)) * -1])

        grad2 = np.array([[[[(-1)**i] * 5] * 5] * 5 for i in range(10)])

        config_list = [{'sparsity': 0.4, 'op_types': ['Conv2d']}]
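        # Note: with global_sort the per-filter contributions of conv1 and conv2 are ranked jointly,
        # so the number of filters pruned in each layer follows from that global ranking rather than
        # from applying 0.4 to each layer separately (the expected masks below drop 4 of conv1's 5
        # filters but only 3 of conv2's 10).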

        model = TorchModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
        pruner = torch_pruner.TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer=None, criterion=None, sparsifying_training_batches=1, global_sort=True)

        x = torch.rand((1, 1, 28, 28), requires_grad=True)
        model.conv1.module.weight.data = torch.tensor(w1).float()
        model.conv2.module.weight.data = torch.tensor(w2).float()

        y = model(x)
        y.backward(torch.ones_like(y))

        model.conv1.module.weight.grad.data = torch.tensor(grad1).float()
        model.conv2.module.weight.grad.data = torch.tensor(grad2).float()
        optimizer.step()

        mask1 = pruner.calc_mask(model.conv1)
        mask2 = pruner.calc_mask(model.conv2)
        print(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy())
        print(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy())
        assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 0., 0, 0., 25.]))
        assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 125., 125., 125., 0., 0., 0.]))

    def test_torch_observer_quantizer(self):
        model = TorchModel()
        # test invalid config
        # only 8-bit is supported for now
        config_list = [{
            'quant_types': ['weight'],
            'quant_bits': 5,
            'op_types': ['Conv2d', 'Linear']
        }]
        with self.assertRaises(schema.SchemaError):
            torch_quantizer.ObserverQuantizer(model, config_list)

        # weight will not change for now
        model = TorchModel().eval()
        origin_parameters = copy.deepcopy(dict(model.named_parameters()))

        config_list = [{
            'quant_types': ['weight'],
            'quant_bits': 8,
            'op_types': ['Conv2d', 'Linear']
        }]
        quantizer = torch_quantizer.ObserverQuantizer(model, config_list)
        input = torch.randn(1, 1, 28, 28)
        model(input)
        quantizer.compress()
        buffers = dict(model.named_buffers())
        scales = {k: v for k, v in buffers.items() if 'scale' in k}
        model_path = "test_model.pth"
        calibration_path = "test_calibration.pth"
        calibration_config = quantizer.export_model(model_path, calibration_path)
        new_parameters = dict(model.named_parameters())
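        # The exported weights may be replaced by their quantize-dequantize values; such a round trip
        # moves each weight by at most half a quantization step, hence atol = 0.5 * scale below.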
        for layer_name, v in calibration_config.items():
            scale_name = layer_name + '.module.weight_scale'
            weight_name = layer_name + '.weight'
            s = float(scales[scale_name])
            self.assertTrue(torch.allclose(origin_parameters[weight_name], new_parameters[weight_name], atol=0.5 * s))

        self.assertTrue(calibration_config is not None)
        self.assertTrue(len(calibration_config) == 4)

    def test_torch_QAT_quantizer(self):
        model = TorchModel()
        config_list = [{
            'quant_types': ['weight'],
            'quant_bits': 8,
            'op_types': ['Conv2d', 'Linear']
        }, {
            'quant_types': ['output'],
            'quant_bits': 8,
            'quant_start_step': 0,
            'op_types': ['ReLU']
        }]
        model.relu = torch.nn.ReLU()
        quantizer = torch_quantizer.QAT_Quantizer(model, config_list)
        quantizer.compress()

        # test quantize
        # range not including 0
        eps = 1e-7
        input = torch.tensor([[0, 4], [2, 1]]).float()
        weight = torch.tensor([[1, 2], [3, 5]]).float()
        model.conv2.module.old_weight.data = weight
        quantizer.quantize_weight(model.conv2, input_tensor=input)
        assert math.isclose(model.conv2.module.scale, 5 / 255, abs_tol=eps)
        assert model.conv2.module.zero_point == 0
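        # Sketch of the expected arithmetic, assuming asymmetric affine quantization with the range
        # clamped to include 0: min = min(weight.min(), 0) = 0, max = 5, so scale = (5 - 0) / 255
        # and zero_point = round(0 - min / scale) = 0.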
        # range including 0
        weight = torch.tensor([[-1, 2], [3, 5]]).float()
        model.conv2.module.old_weight.data = weight
        quantizer.quantize_weight(model.conv2, input_tensor=input)
        assert math.isclose(model.conv2.module.scale, 6 / 255, abs_tol=eps)
        assert model.conv2.module.zero_point in (42, 43)
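        # Sketch: now min = -1 and max = 5, so scale = (5 - (-1)) / 255 = 6 / 255 and
        # zero_point = round(0 - min / scale) = round(42.5), i.e. 42 or 43 depending on the rounding mode.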
        # test value of weight and bias after quantization
        weight = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]])
        weight_valid = torch.tensor([[1.1242, 2.3421], [3.7707, 5.9723]])
        bias = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
        bias_valid = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
        model.conv2.module.old_weight.data = weight
        model.conv2.module.bias.data = bias
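        # Sketch of where weight_valid comes from, assuming the same affine scheme: min clamps to 0,
        # max = 5.9723, so scale ~= 5.9723 / 255 ~= 0.02342 and each weight maps to
        # round(w / scale) * scale, e.g. 1.1287 -> 48 * 0.02342 ~= 1.1242; the bias is expected to
        # stay unchanged here.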
        quantizer.quantize_weight(model.conv2, input_tensor=input)
        assert torch.all(torch.isclose(model.conv2.module.weight.data, weight_valid, rtol=1e-4))
        assert torch.all(torch.isclose(model.conv2.module.bias.data, bias_valid, rtol=1e-7))

        # test ema
        eps = 1e-7
        x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
        out = model.relu(x)
        assert math.isclose(model.relu.module.tracked_min_activation, 0, abs_tol=eps)
        assert math.isclose(model.relu.module.tracked_max_activation, 0.002, abs_tol=eps)

        quantizer.step_with_optimizer()
        x = torch.tensor([[0.2, 0.4], [0.6, 0.8]])
        out = model.relu(x)
        assert math.isclose(model.relu.module.tracked_min_activation, 0.002, abs_tol=eps)
        assert math.isclose(model.relu.module.tracked_max_activation, 0.00998, abs_tol=eps)
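        # Sketch: assuming an EMA update tracked = 0.99 * tracked + 0.01 * current, the max goes
        # 0 -> 0.01 * 0.2 = 0.002 after the first batch, then 0.99 * 0.002 + 0.01 * 0.8 = 0.00998
        # after the second, matching the asserted values.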

    def test_torch_quantizer_export(self):
        config_list_qat = [{
            'quant_types': ['weight'],
            'quant_bits': 8,
            'op_types': ['Conv2d', 'Linear']
        }, {
            'quant_types': ['output'],
            'quant_bits': 8,
            'quant_start_step': 0,
            'op_types': ['ReLU']
        }]
        config_list_dorefa = [{
            'quant_types': ['weight'],
            'quant_bits': {
                'weight': 8,
            },  # a plain int could be used here instead, since all quant_types share the same bit width
            'op_types': ['Conv2d', 'Linear']
        }]
        config_list_bnn = [{
            'quant_types': ['weight'],
            'quant_bits': 1,
            'op_types': ['Conv2d', 'Linear']
        }, {
            'quant_types': ['output'],
            'quant_bits': 1,
            'op_types': ['ReLU']
        }]
        config_set = [config_list_qat, config_list_dorefa, config_list_bnn]
        quantize_algorithm_set = [torch_quantizer.QAT_Quantizer, torch_quantizer.DoReFaQuantizer, torch_quantizer.BNNQuantizer]

        for config, quantize_algorithm in zip(config_set, quantize_algorithm_set):
            model = TorchModel()
            model.relu = torch.nn.ReLU()
            quantizer = quantize_algorithm(model, config)
            quantizer.compress()

            x = torch.rand((1, 1, 28, 28), requires_grad=True)
            y = model(x)
            y.backward(torch.ones_like(y))

            model_path = "test_model.pth"
            calibration_path = "test_calibration.pth"
            onnx_path = "test_model.onnx"
            input_shape = (1, 1, 28, 28)
            device = torch.device("cpu")

            calibration_config = quantizer.export_model(model_path, calibration_path, onnx_path, input_shape, device)
            assert calibration_config is not None

    def test_torch_pruner_validation(self):
        # test bad configuration
        pruner_classes = [torch_pruner.__dict__[x] for x in \
            ['LevelPruner', 'SlimPruner', 'FPGMPruner', 'L1FilterPruner', 'L2FilterPruner', 'AGPPruner',\
            'ActivationMeanRankFilterPruner', 'ActivationAPoZRankFilterPruner']]

        bad_configs = [
            [
                {'sparsity': '0.2'},
                {'sparsity': 0.6 }
            ],
            [
                {'sparsity': 0.2},
                {'sparsity': 1.6 }
            ],
            [
                {'sparsity': 0.2, 'op_types': 'default'},
                {'sparsity': 0.6 }
            ],
            [
                {'sparsity': 0.2 },
                {'sparsity': 0.6, 'op_names': 'abc'}
            ]
        ]
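        # Each config above should be rejected by the schema: sparsity given as a string,
        # sparsity greater than 1, and op_types / op_names given as a bare string instead of a list.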
        model = TorchModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        for pruner_class in pruner_classes:
            for config_list in bad_configs:
                try:
                    kwargs = {}
                    if pruner_class in (torch_pruner.SlimPruner, torch_pruner.AGPPruner, torch_pruner.ActivationMeanRankFilterPruner, torch_pruner.ActivationAPoZRankFilterPruner):
                        kwargs = {'optimizer': None, 'trainer': None, 'criterion': None}

                    print('kwargs', kwargs)
                    pruner_class(model, config_list, **kwargs)

                    print(config_list)
                    assert False, 'Validation error should be raised for bad configuration'
                except schema.SchemaError:
                    pass
                except:
                    print('FAILED:', pruner_class, config_list)
                    raise

    def test_torch_quantizer_validation(self):
        # test bad configuration
        quantizer_classes = [torch_quantizer.__dict__[x] for x in \
            ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer']]

        bad_configs = [
            [
                {'bad_key': 'abc'}
            ],
            [
                {'quant_types': 'abc'}
            ],
            [
                {'quant_bits': 34}
            ],
            [
                {'op_types': 'default'}
            ],
            [
                {'quant_bits': {'abc': 123}}
            ]
        ]
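        # Each config above should be rejected by the schema: an unknown key, quant_types as a bare
        # string, a presumably unsupported bit width, op_types as a bare string, and a quant_bits
        # dict keyed by an invalid quantization type.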
        model = TorchModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        for quantizer_class in quantizer_classes:
            for config_list in bad_configs:
                try:
                    quantizer_class(model, config_list, optimizer)
                    print(config_list)
                    assert False, 'Validation error should be raised for bad configuration'
                except schema.SchemaError:
                    pass
                except:
                    print('FAILED:', quantizer_class, config_list)
                    raise

if __name__ == '__main__':
    main()