# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import copy
import math
from unittest import TestCase, main
import numpy as np
import torch
import torch.nn.functional as F
import schema
import nni.algorithms.compression.pytorch.pruning as torch_pruner
import nni.algorithms.compression.pytorch.quantization as torch_quantizer


class TorchModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
        self.bn1 = torch.nn.BatchNorm2d(5)
        self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
        self.bn2 = torch.nn.BatchNorm2d(10)
        self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
        self.fc2 = torch.nn.Linear(100, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 10)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


class CompressorTestCase(TestCase):
    def test_torch_quantizer_modules_detection(self):
        # test if modules can be detected
        model = TorchModel()
        config_list = [{
            'quant_types': ['weight'],
            'quant_bits': 8,
            'op_types': ['Conv2d', 'Linear']
        }, {
            'quant_types': ['output'],
            'quant_bits': 8,
            'quant_start_step': 0,
            'op_types': ['ReLU']
        }]

        model.relu = torch.nn.ReLU()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
        quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer)
        quantizer.compress()
        modules_to_compress = quantizer.get_modules_to_compress()
        modules_to_compress_name = [t[0].name for t in modules_to_compress]
        assert "conv1" in modules_to_compress_name
        assert "conv2" in modules_to_compress_name
        assert "fc1" in modules_to_compress_name
        assert "fc2" in modules_to_compress_name
        assert "relu" in modules_to_compress_name
        assert len(modules_to_compress_name) == 5

    def test_torch_level_pruner(self):
        model = TorchModel()
        configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
        torch_pruner.LevelPruner(model, configure_list).compress()

    def test_torch_naive_quantizer(self):
        model = TorchModel()
        configure_list = [{
            'quant_types': ['weight'],
            'quant_bits': {
                'weight': 8,
            },
            'op_types': ['Conv2d', 'Linear']
        }]
        torch_quantizer.NaiveQuantizer(model, configure_list).compress()

    def test_torch_fpgm_pruner(self):
        """
        With the filter (kernel) weights defined below as w, it is clear that w[4] and w[5] are the
        geometric medians which minimize the total geometric distance, by the definition of the
        geometric median in this paper:
        Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration,
        https://arxiv.org/pdf/1811.00250.pdf

        So if sparsity is 0.2, the expected masks should mask out w[4] and w[5]; this can be verified with:
        `all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 125., 125., 125., 125.]))`

        If sparsity is 0.6, the expected masks should mask out w[2] through w[7]; this can be verified with:
        `all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([125., 125., 0., 0., 0., 0., 0., 0., 125., 125.]))`
        """
        w = np.array([np.ones((5, 5, 5)) * (i+1) for i in range(10)]).astype(np.float32)
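        # A minimal numpy sketch of the criterion (assuming a filter's importance
        # is its total distance to all other filters): each w[i] is constant
        # (i + 1), so w[4] and w[5] sit in the middle of the sequence and have
        # the smallest total distance.
        flat = w.reshape(10, -1)
        distance = np.linalg.norm(flat[:, None, :] - flat[None, :, :], axis=-1)
        assert set(np.argsort(distance.sum(axis=0))[:2]) == {4, 5}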

        model = TorchModel()
        config_list = [{'sparsity': 0.6, 'op_types': ['Conv2d']}, {'sparsity': 0.2, 'op_types': ['Conv2d']}]
        pruner = torch_pruner.FPGMPruner(model, config_list)

        model.conv2.module.weight.data = torch.tensor(w).float()
        masks = pruner.calc_mask(model.conv2)
        assert all(torch.sum(masks['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 125., 125., 125., 125.]))

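        # Reset the wrapper so calc_mask recomputes: clear the cached-mask flag
        # and point conv2 at the sparsity-0.6 entry of the config list.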
        model.conv2.module.weight.data = torch.tensor(w).float()
        model.conv2.if_calculated = False
        model.conv2.config = config_list[0]
        masks = pruner.calc_mask(model.conv2)
        assert all(torch.sum(masks['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 0., 0., 0., 0., 0., 0., 125., 125.]))

    def test_torch_l1filter_pruner(self):
        """
        Filters with the minimum sum of the weights' L1 norm are pruned in this paper:
        PRUNING FILTERS FOR EFFICIENT CONVNETS,
        https://arxiv.org/abs/1608.08710

        So if sparsity is 0.2 for conv1, the expected masks should mask out filter 0; this can be verified with:
        `all(torch.sum(mask1, (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))`

        If sparsity is 0.6 for conv2, the expected masks should mask out filters 0-5; this can be verified with:
        `all(torch.sum(mask2, (1, 2, 3)).numpy() == np.array([0., 0., 0., 0., 0., 0., 125., 125., 125., 125.]))`
        """
        w1 = np.array([np.ones((1, 5, 5))*i for i in range(5)]).astype(np.float32)
        w2 = np.array([np.ones((5, 5, 5))*i for i in range(10)]).astype(np.float32)
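        # A minimal numpy sketch of the criterion: each filter is ranked by the
        # sum of its absolute weights, so the six smallest filters of conv2 are
        # 0-5, matching the sparsity-0.6 mask asserted below.
        l1_norms = np.abs(w2).sum(axis=(1, 2, 3))
        assert list(np.argsort(l1_norms)[:6]) == [0, 1, 2, 3, 4, 5]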

        model = TorchModel()
        config_list = [{'sparsity': 0.2, 'op_types': ['Conv2d'], 'op_names': ['conv1']},
                       {'sparsity': 0.6, 'op_types': ['Conv2d'], 'op_names': ['conv2']}]
        pruner = torch_pruner.L1FilterPruner(model, config_list)

        model.conv1.module.weight.data = torch.tensor(w1).float()
        model.conv2.module.weight.data = torch.tensor(w2).float()
        mask1 = pruner.calc_mask(model.conv1)
        mask2 = pruner.calc_mask(model.conv2)
        assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))
        assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 0., 0., 0., 0., 0., 125., 125., 125., 125.]))

    def test_torch_slim_pruner(self):
        """
        Scale factors with the minimum L1 norm in the BN layers are pruned in this paper:
        Learning Efficient Convolutional Networks through Network Slimming,
        https://arxiv.org/pdf/1708.06519.pdf

        So if sparsity is 0.2, the expected masks should mask out channel 0; this can be verified with:
        `all(mask1.numpy() == np.array([0., 1., 1., 1., 1.]))`
        `all(mask2.numpy() == np.array([0., 1., 1., 1., 1.]))`

        If sparsity is 0.6, the expected masks should mask out channels 0, 1, 2; this can be verified with:
        `all(mask1.numpy() == np.array([0., 0., 0., 1., 1.]))`
        `all(mask2.numpy() == np.array([0., 0., 0., 1., 1.]))`
        """
        w = np.array([0, 1, 2, 3, 4])
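        # A minimal numpy sketch of the criterion (assuming channels are ranked
        # by |gamma| over the BN layers covered by the config): only the two
        # zero-valued scale factors (channel 0 of bn1 and of bn2) fall at or
        # below the sparsity-0.2 threshold.
        gammas = np.abs(np.concatenate([w, -w]))
        threshold = np.sort(gammas)[int(len(gammas) * 0.2) - 1]
        assert (gammas <= threshold).sum() == 2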
        model = TorchModel()
        config_list = [{'sparsity': 0.2, 'op_types': ['BatchNorm2d']}]
        model.bn1.weight.data = torch.tensor(w).float()
        model.bn2.weight.data = torch.tensor(-w).float()
        pruner = torch_pruner.SlimPruner(model, config_list, optimizer=None, trainer=None, criterion=None)

        mask1 = pruner.calc_mask(model.bn1)
        mask2 = pruner.calc_mask(model.bn2)
        assert all(mask1['weight_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))
        assert all(mask2['weight_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))
        assert all(mask1['bias_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))
        assert all(mask2['bias_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))

        model = TorchModel()
        config_list = [{'sparsity': 0.6, 'op_types': ['BatchNorm2d']}]
        model.bn1.weight.data = torch.tensor(w).float()
        model.bn2.weight.data = torch.tensor(w).float()
        pruner = torch_pruner.SlimPruner(model, config_list, optimizer=None, trainer=None, criterion=None)

        mask1 = pruner.calc_mask(model.bn1)
        mask2 = pruner.calc_mask(model.bn2)
        assert all(mask1['weight_mask'].numpy() == np.array([0., 0., 0., 1., 1.]))
        assert all(mask2['weight_mask'].numpy() == np.array([0., 0., 0., 1., 1.]))
        assert all(mask1['bias_mask'].numpy() == np.array([0., 0., 0., 1., 1.]))
        assert all(mask2['bias_mask'].numpy() == np.array([0., 0., 0., 1., 1.]))

    def test_torch_taylorFOweight_pruner(self):
        """
        Filters with the minimum importance approxiamtion based on the first order 
        taylor expansion on the weights (w*grad)**2 are pruned in this paper:
        Importance Estimation for Neural Network Pruning,
        http://jankautz.com/publications/Importance4NNPruning_CVPR19.pdf

        So if sparsity of conv1 is 0.2, the expected masks should mask out filter 0, this can be verified through:
        `all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))`

        If sparsity of conv2 is 0.6, the expected masks should mask out filter 4,5,6,7,8,9 this can be verified through:
        `all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 0., 0., 0., 0., ]))`
        """

        w1 = np.array([np.zeros((1, 5, 5)), np.ones((1, 5, 5)), np.ones((1, 5, 5)) * 2,
                      np.ones((1, 5, 5)) * 3, np.ones((1, 5, 5)) * 4])
        w2 = np.array([[[[i + 1] * 5] * 5] * 5 for i in range(10)[::-1]])

        grad1 = np.array([np.ones((1, 5, 5)) * -1, np.ones((1, 5, 5)) * 1, np.ones((1, 5, 5)) * -1,
                      np.ones((1, 5, 5)) * 1, np.ones((1, 5, 5)) * -1])

        grad2 = np.array([[[[(-1)**i] * 5] * 5] * 5 for i in range(10)])
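        # A minimal numpy sketch of the criterion: a filter's importance is the
        # sum over its weights of (weight * grad) ** 2, so filter 0 of conv1
        # (all zeros) ranks lowest, and conv2's importance decreases with the
        # filter index, leaving filters 4-9 as the six cheapest.
        score1 = ((w1 * grad1) ** 2).sum(axis=(1, 2, 3))
        score2 = ((w2 * grad2) ** 2).sum(axis=(1, 2, 3))
        assert np.argmin(score1) == 0
        assert list(np.argsort(score2)[:6]) == [9, 8, 7, 6, 5, 4]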

        config_list = [{'sparsity': 0.2, 'op_types': ['Conv2d'], 'op_names': ['conv1']},
                       {'sparsity': 0.6, 'op_types': ['Conv2d'], 'op_names': ['conv2']}]

        model = TorchModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
        pruner = torch_pruner.TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer=None, criterion=None, sparsifying_training_batches=1)

        x = torch.rand((1, 1, 28, 28), requires_grad=True)
        model.conv1.module.weight.data = torch.tensor(w1).float()
        model.conv2.module.weight.data = torch.tensor(w2).float()

        y = model(x)
        y.backward(torch.ones_like(y))

        model.conv1.module.weight.grad.data = torch.tensor(grad1).float()
        model.conv2.module.weight.grad.data = torch.tensor(grad2).float()
        optimizer.step()

        mask1 = pruner.calc_mask(model.conv1)
        mask2 = pruner.calc_mask(model.conv2)
        assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))
        assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 0., 0., 0., 0.]))

    def test_torch_taylorFOweight_pruner_global_sort(self):
        """
        After enabling global_sort, taylorFOweight pruner will calculate contributions and rank topk from all
        of the conv operators. Then it will prune low contribution filters depends on the global information.

        So if sparsity of conv operator is 0.4, the expected masks should mask out filter 0 and filter 1 together, 
        this can be verified through:
        `all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 0., 0, 0., 25.]))`
        `all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 125., 125., 125., 0., 0., 0.]))`
        """

        w1 = np.array([np.zeros((1, 5, 5)), np.ones((1, 5, 5)), np.ones((1, 5, 5)) * 2,
                      np.ones((1, 5, 5)) * 3, np.ones((1, 5, 5)) * 4])
        w2 = np.array([[[[i + 1] * 5] * 5] * 5 for i in range(10)[::-1]])

        grad1 = np.array([np.ones((1, 5, 5)) * -1, np.ones((1, 5, 5)) * 1, np.ones((1, 5, 5)) * -1,
                      np.ones((1, 5, 5)) * 1, np.ones((1, 5, 5)) * -1])

        grad2 = np.array([[[[(-1)**i] * 5] * 5] * 5 for i in range(10)])
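        # A rough numpy sketch of the global ranking (assuming the same
        # per-filter (weight * grad) ** 2 score as in the previous test): scores
        # from conv1 and conv2 are pooled into one ordering, so a layer whose
        # filters all score low can lose more than 40% of its filters while the
        # other layer loses less.
        score1 = ((w1 * grad1) ** 2).sum(axis=(1, 2, 3))
        score2 = ((w2 * grad2) ** 2).sum(axis=(1, 2, 3))
        global_order = np.argsort(np.concatenate([score1, score2]))
        assert global_order[0] == 0  # conv1's all-zero filter is globally cheapest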

        config_list = [{'sparsity': 0.4, 'op_types': ['Conv2d']}]

        model = TorchModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
        pruner = torch_pruner.TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer=None, criterion=None, sparsifying_training_batches=1, global_sort=True)

        x = torch.rand((1, 1, 28, 28), requires_grad=True)
        model.conv1.module.weight.data = torch.tensor(w1).float()
        model.conv2.module.weight.data = torch.tensor(w2).float()

        y = model(x)
        y.backward(torch.ones_like(y))

        model.conv1.module.weight.grad.data = torch.tensor(grad1).float()
        model.conv2.module.weight.grad.data = torch.tensor(grad2).float()
        optimizer.step()

        mask1 = pruner.calc_mask(model.conv1)
        mask2 = pruner.calc_mask(model.conv2)
        print(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy())
        print(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy())
        assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 0., 0., 0., 25.]))
        assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 125., 125., 125., 0., 0., 0.]))

    def test_torch_observer_quantizer(self):
        model = TorchModel()
        # test invalid config
        # only support 8bit for now
        config_list = [{
            'quant_types': ['weight'],
            'quant_bits': 5,
            'op_types': ['Conv2d', 'Linear']
        }]
        with self.assertRaises(schema.SchemaError):
            torch_quantizer.ObserverQuantizer(model, config_list)

        # weight will not change for now
        model = TorchModel().eval()
        origin_parameters = copy.deepcopy(dict(model.named_parameters()))

        config_list = [{
            'quant_types': ['weight'],
            'quant_bits': 8,
            'op_types': ['Conv2d', 'Linear']
        }]
        quantizer = torch_quantizer.ObserverQuantizer(model, config_list)
        input = torch.randn(1, 1, 28, 28)
        model(input)
        quantizer.compress()
        buffers = dict(model.named_buffers())
        scales = {k: v for k, v in buffers.items() if 'scale' in k}
        model_path = "test_model.pth"
        calibration_path = "test_calibration.pth"
        calibration_config = quantizer.export_model(model_path, calibration_path)
        new_parameters = dict(model.named_parameters())
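        # Quantize-dequantize moves each weight to the nearest grid point, so the
        # round-trip error per element should stay within half a step (0.5 * scale).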
        for layer_name, v in calibration_config.items():
            scale_name = layer_name + '.module.weight_scale'
            weight_name = layer_name + '.weight'
            s = float(scales[scale_name])
            self.assertTrue(torch.allclose(origin_parameters[weight_name], new_parameters[weight_name], atol=0.5 * s))

        self.assertTrue(calibration_config is not None)
        self.assertTrue(len(calibration_config) == 4)

    def test_torch_quantizer_weight_type(self):
        quantizer_list = [
            torch_quantizer.QAT_Quantizer,
            torch_quantizer.LsqQuantizer,
            torch_quantizer.ObserverQuantizer,
            torch_quantizer.NaiveQuantizer,
            torch_quantizer.DoReFaQuantizer]
        for quantizer_type in quantizer_list:
            model = TorchModel().eval()
            config_list = [{
                'quant_types': ['weight'],
                'quant_bits': 8,
                'op_types': ['Conv2d', 'Linear']
            }]

            optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
            dummy = torch.randn(1, 1, 28, 28)
            if quantizer_type == torch_quantizer.QAT_Quantizer:
                quantizer_type(model, config_list, optimizer, dummy_input=dummy)
            else:
                quantizer_type(model, config_list, optimizer)

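            # After wrapping, `weight` on each module should be a plain tensor
            # rather than an nn.Parameter (the trainable copy is assumed to live
            # on the wrapper), which is what the checks below verify.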
            self.assertFalse(isinstance(model.conv1.module.weight, torch.nn.Parameter))
            self.assertFalse(isinstance(model.conv2.module.weight, torch.nn.Parameter))
            self.assertFalse(isinstance(model.fc1.module.weight, torch.nn.Parameter))
            self.assertFalse(isinstance(model.fc2.module.weight, torch.nn.Parameter))

    def test_torch_QAT_quantizer(self):
        model = TorchModel()
        config_list = [{
            'quant_types': ['weight', 'input'],
            'quant_bits': 8,
            'op_types': ['Conv2d', 'Linear']
        }, {
            'quant_types': ['output'],
            'quant_bits': 8,
            'quant_start_step': 0,
            'op_types': ['ReLU']
        }]
        model.relu = torch.nn.ReLU()

        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
        quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer)
        quantizer.compress()

        # test quantize
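        # Expected arithmetic (a sketch of asymmetric quantization): the weight
        # range is first extended to include 0, then with qmin = 0 and qmax = 255
        #   scale = (max - min) / (qmax - qmin)
        #   zero_point = round(qmin - min / scale), clamped to [qmin, qmax]
        # For weights in [1, 5] the range becomes [0, 5]: scale = 5 / 255 and
        # zero_point = 0; for [-1, 5]: scale = 6 / 255 and zero_point is 42 or 43
        # depending on the rounding mode.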
        # range not including 0
        eps = 1e-7
        input = torch.tensor([[1, 4], [2, 1]])
        weight = torch.tensor([[1, 2], [3, 5]]).float()
        model.conv2.module.weight.data = weight
        quantizer.quantize_weight(model.conv2, input_tensor=input)
        assert math.isclose(model.conv2.module.scale, 5 / 255, abs_tol=eps)
        assert model.conv2.module.zero_point == 0
        quantizer.quantize_input(input, model.conv2)
        self.assertTrue(torch.allclose(model.conv2.module.scale, torch.tensor([0.04 / 255])))
        self.assertTrue(torch.equal(model.conv2.module.zero_point, torch.tensor([0.])))
        # range including 0
        weight = torch.tensor([[-1, 2], [3, 5]]).float()
        model.conv2.module.weight = weight
        quantizer.quantize_weight(model.conv2, input_tensor=input)
        assert math.isclose(model.conv2.module.scale, 6 / 255, abs_tol=eps)
        assert model.conv2.module.zero_point in (42, 43)
        quantizer.quantize_input(input, model.conv2)
        self.assertTrue(torch.allclose(model.conv2.module.scale, torch.tensor([0.0796 / 255])))
        self.assertTrue(torch.equal(model.conv2.module.zero_point, torch.tensor([0.])))
        # test value of weight and bias after quantization
        weight = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]])
        weight_valid = torch.tensor([[1.1242, 2.3421], [3.7707, 5.9723]])
        bias = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
        bias_valid = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
        model.conv2.module.weight = weight
        model.conv2.module.bias.data = bias
        quantizer.quantize_weight(model.conv2, input_tensor=input)
        assert torch.all(torch.isclose(model.conv2.module.weight.data, weight_valid, rtol=1e-4))
        assert torch.all(torch.isclose(model.conv2.module.bias.data, bias_valid, rtol=1e-7))

        # test ema
        eps = 1e-7
        x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
        out = model.relu(x)
        assert math.isclose(model.relu.module.tracked_min_output, 0, abs_tol=eps)
        assert math.isclose(model.relu.module.tracked_max_output, 0.002, abs_tol=eps)

        quantizer.step_with_optimizer()
        x = torch.tensor([[0.2, 0.4], [0.6, 0.8]])
        out = model.relu(x)
        assert math.isclose(model.relu.module.tracked_min_output, 0.002, abs_tol=eps)
        assert math.isclose(model.relu.module.tracked_max_output, 0.00998, abs_tol=eps)

    def test_torch_quantizer_export(self):
        config_list_qat = [{
            'quant_types': ['weight'],
            'quant_bits': 8,
            'op_types': ['Conv2d', 'Linear']
        }, {
            'quant_types': ['output'],
            'quant_bits': 8,
            'quant_start_step': 0,
            'op_types': ['ReLU']
        }]
        config_list_dorefa = [{
            'quant_types': ['weight'],
            'quant_bits': {
                'weight': 8,
            }, # a plain `int` also works here, since all `quant_types` share the same bit width
            'op_types':['Conv2d', 'Linear']
        }]
        config_list_bnn = [{
            'quant_types': ['weight'],
            'quant_bits': 1,
            'op_types': ['Conv2d', 'Linear']
        }, {
            'quant_types': ['output'],
            'quant_bits': 1,
            'op_types': ['ReLU']
        }]
        config_set = [config_list_qat, config_list_dorefa, config_list_bnn]
        quantize_algorithm_set = [torch_quantizer.QAT_Quantizer, torch_quantizer.DoReFaQuantizer, torch_quantizer.BNNQuantizer]

        for config, quantize_algorithm in zip(config_set, quantize_algorithm_set):
            model = TorchModel()
            model.relu = torch.nn.ReLU()
            optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
            quantizer = quantize_algorithm(model, config, optimizer)
            quantizer.compress()

            x = torch.rand((1, 1, 28, 28), requires_grad=True)
            y = model(x)
            y.backward(torch.ones_like(y))

            model_path = "test_model.pth"
            calibration_path = "test_calibration.pth"
            onnx_path = "test_model.onnx"
            input_shape = (1, 1, 28, 28)
            device = torch.device("cpu")

            calibration_config = quantizer.export_model(model_path, calibration_path, onnx_path, input_shape, device)
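            # export_model is expected to save the quantized state dict, emit an
            # ONNX graph, and return the per-layer calibration data it recorded.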
            assert calibration_config is not None

    def test_torch_pruner_validation(self):
        # test bad configuration
        pruner_classes = [torch_pruner.__dict__[x] for x in \
            ['LevelPruner', 'SlimPruner', 'FPGMPruner', 'L1FilterPruner', 'L2FilterPruner', 'AGPPruner',\
            'ActivationMeanRankFilterPruner', 'ActivationAPoZRankFilterPruner']]

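        # Each entry below breaks the config schema in one way: sparsity given as
        # a string, sparsity out of (0, 1), op_types not a list, op_names not a list.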
        bad_configs = [
            [
                {'sparsity': '0.2'},
                {'sparsity': 0.6 }
            ],
            [
                {'sparsity': 0.2},
                {'sparsity': 1.6 }
            ],
            [
                {'sparsity': 0.2, 'op_types': 'default'},
                {'sparsity': 0.6 }
            ],
            [
                {'sparsity': 0.2 },
                {'sparsity': 0.6, 'op_names': 'abc'}
            ]
        ]
        model = TorchModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        for pruner_class in pruner_classes:
            for config_list in bad_configs:
                try:
                    kwargs = {}
                    if pruner_class in (torch_pruner.SlimPruner, torch_pruner.AGPPruner, torch_pruner.ActivationMeanRankFilterPruner, torch_pruner.ActivationAPoZRankFilterPruner):
                        kwargs = {'optimizer': None, 'trainer': None, 'criterion': None}

                    print('kwargs', kwargs)
                    pruner_class(model, config_list, **kwargs)

                    print(config_list)
                    assert False, 'Validation error should be raised for bad configuration'
                except schema.SchemaError:
                    pass
                except:
                    print('FAILED:', pruner_class, config_list)
                    raise

    def test_torch_quantizer_validation(self):
        # test bad configuration
        quantizer_classes = [torch_quantizer.__dict__[x] for x in \
            ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer']]

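        # Each entry below breaks the quantizer schema in one way: unknown key,
        # quant_types not a list, unsupported bit width, op_types not a list,
        # and a quant_bits dict keyed by an unknown quant type.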
        bad_configs = [
            [
                {'bad_key': 'abc'}
            ],
            [
                {'quant_types': 'abc'}
            ],
            [
                {'quant_bits': 34}
            ],
            [
                {'op_types': 'default'}
            ],
            [
                {'quant_bits': {'abc': 123}}
            ]
        ]
        model = TorchModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        for quantizer_class in quantizer_classes:
            for config_list in bad_configs:
                try:
                    quantizer_class(model, config_list, optimizer)
                    print(config_list)
                    assert False, 'Validation error should be raised for bad configuration'
                except schema.SchemaError:
                    pass
                except:
                    print('FAILED:', quantizer_class, config_list)
                    raise

if __name__ == '__main__':
    main()