# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import copy
import math
from unittest import TestCase, main

import numpy as np
import schema
import torch
import torch.nn.functional as F

import nni.algorithms.compression.pytorch.pruning as torch_pruner
import nni.algorithms.compression.pytorch.quantization as torch_quantizer
from nni.compression.pytorch.quantization.utils import calculate_qmin_qmax, get_quant_shape


class TorchModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
        self.bn1 = torch.nn.BatchNorm2d(5)
        self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
        self.bn2 = torch.nn.BatchNorm2d(10)
        self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
        self.fc2 = torch.nn.Linear(100, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 10)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


class CompressorTestCase(TestCase):
    def test_torch_quantizer_modules_detection(self):
        # test if modules can be detected
        model = TorchModel()
        config_list = [{
            'quant_types': ['weight'],
            'quant_bits': 8,
            'op_types': ['Conv2d', 'Linear']
        }, {
            'quant_types': ['output'],
            'quant_bits': 8,
            'quant_start_step': 0,
            'op_types': ['ReLU']
        }]

        model.relu = torch.nn.ReLU()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
        dummy = torch.randn(1, 1, 28, 28)
        quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy)
        quantizer.compress()
        modules_to_compress = quantizer.get_modules_to_compress()
        modules_to_compress_name = [t[0].name for t in modules_to_compress]
        assert "conv1" in modules_to_compress_name
        assert "conv2" in modules_to_compress_name
        assert "fc1" in modules_to_compress_name
        assert "fc2" in modules_to_compress_name
        assert "relu" in modules_to_compress_name
        assert len(modules_to_compress_name) == 5

    def test_torch_level_pruner(self):
        model = TorchModel()
        configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
        torch_pruner.LevelPruner(model, configure_list).compress()

    def test_torch_naive_quantizer(self):
        model = TorchModel()
        configure_list = [{
            'quant_types': ['weight'],
            'quant_bits': {
                'weight': 8,
            },
            'op_types': ['Conv2d', 'Linear']
        }]
        torch_quantizer.NaiveQuantizer(model, configure_list).compress()

    def test_torch_fpgm_pruner(self):
        """
        With the filter (kernel) weights defined below (w), w[4] and w[5] are clearly the Geometric Median filters,
        i.e. the ones that minimize the total geometric distance, by the definition of the Geometric Median in this paper:
        Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration,
        https://arxiv.org/pdf/1811.00250.pdf

        So if sparsity is 0.2, the expected masks should mask out w[4] and w[5], which can be verified through:
        `all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 125., 125., 125., 125.]))`

        If sparsity is 0.6, the expected masks should mask out w[2] - w[7], which can be verified through:
        `all(torch.sum(masks, (1, 2, 3)).numpy() == np.array([125., 125., 0., 0., 0., 0., 0., 0., 125., 125.]))`
        """
        w = np.array([np.ones((5, 5, 5)) * (i+1) for i in range(10)]).astype(np.float32)
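        # Illustrative check, independent of FPGMPruner's implementation: for these synthetic
        # filters the geometric-median criterion reduces to the per-filter sum of pairwise L2
        # distances, which is smallest for the two middle filters w[4] and w[5].
        pairwise_dist_sum = np.array([[np.linalg.norm(a - b) for b in w] for a in w]).sum(axis=1)
        assert set(np.argsort(pairwise_dist_sum)[:2].tolist()) == {4, 5}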

        model = TorchModel()
        config_list = [{'sparsity': 0.6, 'op_types': ['Conv2d']}, {'sparsity': 0.2, 'op_types': ['Conv2d']}]
        pruner = torch_pruner.FPGMPruner(model, config_list)

        model.conv2.module.weight.data = torch.tensor(w).float()
        masks = pruner.calc_mask(model.conv2)
        assert all(torch.sum(masks['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 125., 125., 125., 125.]))

        model.conv2.module.weight.data = torch.tensor(w).float()
        model.conv2.if_calculated = False
        model.conv2.config = config_list[0]
        masks = pruner.calc_mask(model.conv2)
        assert all(torch.sum(masks['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 0., 0., 0., 0., 0., 0., 125., 125.]))

    def test_torch_l1filter_pruner(self):
        """
        Filters with the minimum sum of the weights' L1 norm are pruned in this paper:
        PRUNING FILTERS FOR EFFICIENT CONVNETS,
        https://arxiv.org/abs/1608.08710

        So if sparsity is 0.2 for conv1, the expected masks should mask out filter 0, which can be verified through:
        `all(torch.sum(mask1, (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))`

        If sparsity is 0.6 for conv2, the expected masks should mask out filters 0 through 5, which can be verified through:
        `all(torch.sum(mask2, (1, 2, 3)).numpy() == np.array([0., 0., 0., 0., 0., 0., 125., 125., 125., 125.]))`
        """
        w1 = np.array([np.ones((1, 5, 5))*i for i in range(5)]).astype(np.float32)
        w2 = np.array([np.ones((5, 5, 5))*i for i in range(10)]).astype(np.float32)
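        # Illustrative check, independent of the pruner implementation: the per-filter L1 norm of
        # these synthetic weights grows with the filter index, so the lowest-index filters are the
        # first pruning candidates.
        assert np.abs(w1).sum(axis=(1, 2, 3)).argsort()[:1].tolist() == [0]
        assert np.abs(w2).sum(axis=(1, 2, 3)).argsort()[:6].tolist() == [0, 1, 2, 3, 4, 5]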

        model = TorchModel()
        config_list = [{'sparsity': 0.2, 'op_types': ['Conv2d'], 'op_names': ['conv1']},
                       {'sparsity': 0.6, 'op_types': ['Conv2d'], 'op_names': ['conv2']}]
        pruner = torch_pruner.L1FilterPruner(model, config_list)

        model.conv1.module.weight.data = torch.tensor(w1).float()
        model.conv2.module.weight.data = torch.tensor(w2).float()
        mask1 = pruner.calc_mask(model.conv1)
        mask2 = pruner.calc_mask(model.conv2)
        assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))
        assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 0., 0., 0., 0., 0., 125., 125., 125., 125.]))

    def test_torch_slim_pruner(self):
        """
        Scale factors with the minimum L1 norm in the BN layers are pruned in this paper:
        Learning Efficient Convolutional Networks through Network Slimming,
        https://arxiv.org/pdf/1708.06519.pdf

        So if sparsity is 0.2, the expected masks should mask out channel 0, which can be verified through:
        `all(mask1.numpy() == np.array([0., 1., 1., 1., 1.]))`
        `all(mask2.numpy() == np.array([0., 1., 1., 1., 1.]))`

        If sparsity is 0.6, the expected masks should mask out channels 0, 1 and 2, which can be verified through:
        `all(mask1.numpy() == np.array([0., 0., 0., 1., 1.]))`
        `all(mask2.numpy() == np.array([0., 0., 0., 1., 1.]))`
        """
        w = np.array([0, 1, 2, 3, 4])
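        # Illustrative check, independent of SlimPruner internals: channels are ranked by the
        # absolute value of the BN scale factor, so with gamma = [0, 1, 2, 3, 4] channel 0 goes
        # first at 20% sparsity and channels 0-2 go at 60% sparsity.
        assert np.abs(w).argsort()[:3].tolist() == [0, 1, 2]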
        model = TorchModel()
        config_list = [{'sparsity': 0.2, 'op_types': ['BatchNorm2d']}]
        model.bn1.weight.data = torch.tensor(w).float()
        model.bn2.weight.data = torch.tensor(-w).float()
        pruner = torch_pruner.SlimPruner(model, config_list, optimizer=None, trainer=None, criterion=None)

        mask1 = pruner.calc_mask(model.bn1)
        mask2 = pruner.calc_mask(model.bn2)
        assert all(mask1['weight_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))
        assert all(mask2['weight_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))
        assert all(mask1['bias_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))
        assert all(mask2['bias_mask'].numpy() == np.array([0., 1., 1., 1., 1.]))

        model = TorchModel()
        config_list = [{'sparsity': 0.6, 'op_types': ['BatchNorm2d']}]
        model.bn1.weight.data = torch.tensor(w).float()
        model.bn2.weight.data = torch.tensor(w).float()
        pruner = torch_pruner.SlimPruner(model, config_list, optimizer=None, trainer=None, criterion=None)

        mask1 = pruner.calc_mask(model.bn1)
        mask2 = pruner.calc_mask(model.bn2)
        assert all(mask1['weight_mask'].numpy() == np.array([0., 0., 0., 1., 1.]))
        assert all(mask2['weight_mask'].numpy() == np.array([0., 0., 0., 1., 1.]))
        assert all(mask1['bias_mask'].numpy() == np.array([0., 0., 0., 1., 1.]))
        assert all(mask2['bias_mask'].numpy() == np.array([0., 0., 0., 1., 1.]))

    def test_torch_taylorFOweight_pruner(self):
        """
        Filters with the minimum importance approximation, based on the first-order
        Taylor expansion on the weights ((w * grad)**2), are pruned in this paper:
        Importance Estimation for Neural Network Pruning,
        http://jankautz.com/publications/Importance4NNPruning_CVPR19.pdf

        So if sparsity of conv1 is 0.2, the expected masks should mask out filter 0, which can be verified through:
        `all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))`

        If sparsity of conv2 is 0.6, the expected masks should mask out filters 4 through 9, which can be verified through:
        `all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 0., 0., 0., 0., ]))`
        """

        w1 = np.array([np.zeros((1, 5, 5)), np.ones((1, 5, 5)), np.ones((1, 5, 5)) * 2,
                      np.ones((1, 5, 5)) * 3, np.ones((1, 5, 5)) * 4])
        w2 = np.array([[[[i + 1] * 5] * 5] * 5 for i in range(10)[::-1]])

        grad1 = np.array([np.ones((1, 5, 5)) * -1, np.ones((1, 5, 5)) * 1, np.ones((1, 5, 5)) * -1,
                      np.ones((1, 5, 5)) * 1, np.ones((1, 5, 5)) * -1])

        grad2 = np.array([[[[(-1)**i] * 5] * 5] * 5 for i in range(10)])
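        # Illustrative check that approximates the paper's criterion rather than the pruner's
        # exact code path: first-order Taylor importance grows with the squared weight-gradient
        # product, so filter 0 of conv1 and filters 4-9 of conv2 score lowest.
        importance1 = ((w1 * grad1) ** 2).sum(axis=(1, 2, 3))
        importance2 = ((w2 * grad2) ** 2).sum(axis=(1, 2, 3))
        assert importance1.argmin() == 0
        assert sorted(importance2.argsort()[:6].tolist()) == [4, 5, 6, 7, 8, 9]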

        config_list = [{'sparsity': 0.2, 'op_types': ['Conv2d'], 'op_names': ['conv1']},
                       {'sparsity': 0.6, 'op_types': ['Conv2d'], 'op_names': ['conv2']}]

        model = TorchModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
        pruner = torch_pruner.TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer=None, criterion=None, sparsifying_training_batches=1)

        x = torch.rand((1, 1, 28, 28), requires_grad=True)
        model.conv1.module.weight.data = torch.tensor(w1).float()
        model.conv2.module.weight.data = torch.tensor(w2).float()

        y = model(x)
        y.backward(torch.ones_like(y))

        model.conv1.module.weight.grad.data = torch.tensor(grad1).float()
        model.conv2.module.weight.grad.data = torch.tensor(grad2).float()
        optimizer.step()

        mask1 = pruner.calc_mask(model.conv1)
        mask2 = pruner.calc_mask(model.conv2)
        assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))
        assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 0., 0., 0., 0., ]))

    def test_torch_taylorFOweight_pruner_global_sort(self):
        """
        After enabling global_sort, the TaylorFOWeight pruner calculates contributions and ranks the top-k
        across all of the conv operators, then prunes low-contribution filters based on this global information.

        So if the sparsity of the conv operators is 0.4, the expected masks should mask out the
        lowest-scoring filters across both conv layers together, which can be verified through:
        `all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 0., 0, 0., 25.]))`
        `all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 125., 125., 125., 0., 0., 0.]))`
        """

        w1 = np.array([np.zeros((1, 5, 5)), np.ones((1, 5, 5)), np.ones((1, 5, 5)) * 2,
                      np.ones((1, 5, 5)) * 3, np.ones((1, 5, 5)) * 4])
        w2 = np.array([[[[i + 1] * 5] * 5] * 5 for i in range(10)[::-1]])

        grad1 = np.array([np.ones((1, 5, 5)) * -1, np.ones((1, 5, 5)) * 1, np.ones((1, 5, 5)) * -1,
                      np.ones((1, 5, 5)) * 1, np.ones((1, 5, 5)) * -1])

        grad2 = np.array([[[[(-1)**i] * 5] * 5] * 5 for i in range(10)])

        config_list = [{'sparsity': 0.4, 'op_types': ['Conv2d']}]

        model = TorchModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
        pruner = torch_pruner.TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer=None, criterion=None, sparsifying_training_batches=1, global_sort=True)

        x = torch.rand((1, 1, 28, 28), requires_grad=True)
        model.conv1.module.weight.data = torch.tensor(w1).float()
        model.conv2.module.weight.data = torch.tensor(w2).float()

        y = model(x)
        y.backward(torch.ones_like(y))

        model.conv1.module.weight.grad.data = torch.tensor(grad1).float()
        model.conv2.module.weight.grad.data = torch.tensor(grad2).float()
        optimizer.step()

        mask1 = pruner.calc_mask(model.conv1)
        mask2 = pruner.calc_mask(model.conv2)
        print(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy())
        print(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy())
        assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 0., 0, 0., 25.]))
        assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 125., 125., 125., 0., 0., 0.]))

    def test_torch_observer_quantizer(self):
        model = TorchModel()
        # test invalid config
        # only support 8bit for now
        config_list = [{
            'quant_types': ['weight'],
            'quant_bits': 5,
            'op_types': ['Conv2d', 'Linear']
        }]
        with self.assertRaises(schema.SchemaError):
            torch_quantizer.ObserverQuantizer(model, config_list)

        # weight will not change for now
        model = TorchModel().eval()
        origin_parameters = copy.deepcopy(dict(model.named_parameters()))

        config_list = [{
            'quant_types': ['weight'],
            'quant_bits': 8,
            'op_types': ['Conv2d', 'Linear']
        }]
        quantizer = torch_quantizer.ObserverQuantizer(model, config_list)
        input = torch.randn(1, 1, 28, 28)
        model(input)
        quantizer.compress()
        buffers = dict(model.named_buffers())
        scales = {k: v for k, v in buffers.items() if 'scale' in k}
        model_path = "test_model.pth"
        calibration_path = "test_calibration.pth"
        calibration_config = quantizer.export_model(model_path, calibration_path)
        new_parameters = dict(model.named_parameters())
        for layer_name, v in calibration_config.items():
            scale_name = layer_name + '.module.weight_scale'
            weight_name = layer_name + '.weight'
            s = float(scales[scale_name])
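            # The export path dequantizes with the observed per-layer scale, so each weight should
            # differ from its original value by at most half a quantization step (atol = scale / 2).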
            self.assertTrue(torch.allclose(origin_parameters[weight_name], new_parameters[weight_name], atol=0.5 * s))

        self.assertTrue(calibration_config is not None)
        self.assertTrue(len(calibration_config) == 4)

    def test_torch_quantizer_weight_type(self):
        quantizer_list = [
            torch_quantizer.QAT_Quantizer,
            torch_quantizer.LsqQuantizer,
            torch_quantizer.ObserverQuantizer,
            torch_quantizer.NaiveQuantizer,
            torch_quantizer.DoReFaQuantizer]
        for quantizer_type in quantizer_list:
            model = TorchModel().eval()
            config_list = [{
                'quant_types': ['weight'],
                'quant_bits': 8,
                'op_types': ['Conv2d', 'Linear']
            }]

            optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
            dummy = torch.randn(1, 1, 28, 28)
            if quantizer_type == torch_quantizer.QAT_Quantizer:
                quantizer_type(model, config_list, optimizer, dummy_input=dummy)
            else:
                quantizer_type(model, config_list, optimizer)

            self.assertFalse(isinstance(model.conv1.module.weight, torch.nn.Parameter))
            self.assertFalse(isinstance(model.conv2.module.weight, torch.nn.Parameter))
            self.assertFalse(isinstance(model.fc1.module.weight, torch.nn.Parameter))
            self.assertFalse(isinstance(model.fc2.module.weight, torch.nn.Parameter))

    def test_quantization_dtype_scheme(self):
        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(1, 2, 3, 1)
                self.bn1 = torch.nn.BatchNorm2d(2)

            def forward(self, x):
                x = self.bn1(self.conv1(x))
                return x
        dtypes = ['int', 'uint']
        qschemes = ['per_tensor_affine', 'per_tensor_symmetric', 'per_channel_affine', 'per_channel_symmetric']
        for dtype in dtypes:
            for qscheme in qschemes:
                config_list = [{
                    'quant_types': ['weight', 'input'],
                    'quant_bits': 8,
                    'op_types': ['Conv2d'],
                    'quant_dtype': dtype,
                    'quant_scheme': qscheme
                }]
                model = TestModel()
                optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
                # only QAT_quantizer is supported for now
                dummy = torch.randn(1, 1, 4, 4)
                quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy)

                # test layer setting
                for layer, config in quantizer.modules_to_compress:
                    module = layer.module
                    name = layer.name
                    layer_setting = module.layer_quant_setting
                    qmin, qmax = calculate_qmin_qmax(8, dtype)
                    all_quant_types = ['input', 'weight']
                    for quant_type in all_quant_types:
                        # check for settings
                        tensor_setting = getattr(layer_setting, quant_type)
                        self.assertTrue(tensor_setting is not None)
                        self.assertTrue(tensor_setting.quant_scheme == qscheme)
                        self.assertTrue(tensor_setting.quant_dtype == dtype)
                        self.assertTrue(tensor_setting.qmin == qmin)
                        self.assertTrue(tensor_setting.qmax == qmax)

                        input_shape, output_shape = quantizer.all_shapes[name]

                        shape = input_shape if quant_type == 'input' else module.weight.shape
                        quant_shape = get_quant_shape(shape, quant_type, qscheme)
                        scale_name = quant_type + '_scale'
                        zero_point_name = quant_type + '_zero_point'
                        scale = getattr(module, scale_name)
                        zero_point = getattr(module, zero_point_name)
                        self.assertTrue(list(scale.shape) == quant_shape)
                        self.assertTrue(list(zero_point.shape) == quant_shape)

                    weight = torch.arange(start=1, end=19).view(2, 1, 3, 3)
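                    # The expected values below assume the usual conventions: symmetric schemes
                    # derive the scale from max(|x|) over a signed range (127 or 127.5 steps),
                    # affine schemes from (max - min) over the full range (254 or 255 steps).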
                    if qscheme == 'per_channel_symmetric':
                        if dtype == 'int':
                            target_scale = torch.tensor([9. / 127, 18. / 127]).view([2, 1, 1, 1])
                            target_zero_point = torch.ones([2, 1, 1, 1]) * 0
                        else:
                            target_scale = torch.tensor([9. / 127.5, 18. / 127.5]).view([2, 1, 1, 1])
                            target_zero_point = torch.ones([2, 1, 1, 1]) * 127
                    elif qscheme == 'per_tensor_symmetric':
                        if dtype == 'int':
                            target_scale = torch.tensor([18. / 127])
                            target_zero_point = torch.zeros([1])
                        else:
                            target_scale = torch.tensor([18. / 127.5])
                            target_zero_point = torch.ones([1]) * 127
                    elif qscheme == 'per_channel_affine':
                        min_val = torch.tensor([0., 0.]).view([2, 1, 1, 1])
                        if dtype == 'int':
                            target_scale = torch.tensor([9. / 254, 18. / 254]).view([2, 1, 1, 1])
                            target_zero_point = -127 - torch.round(min_val / target_scale)
                        else:
                            target_scale = torch.tensor([9. / 255, 18. / 255]).view([2, 1, 1, 1])
                            target_zero_point = 0 - torch.round(min_val / target_scale)
                    else:
                        if dtype == 'int':
                            target_scale = torch.tensor([18. / 254])
                            target_zero_point = -127 - torch.round(0 / target_scale)
                        else:
                            target_scale = torch.tensor([18. / 255])
                            target_zero_point = 0 - torch.round(0 / target_scale)
                    wrapper = getattr(model, name)
                    wrapper.module.weight = weight
                    quantizer.quantize_weight(wrapper)
                    self.assertTrue(torch.equal(getattr(model, name).module.weight_scale, target_scale))
                    self.assertTrue(torch.equal(getattr(model, name).module.weight_zero_point, target_zero_point))

                    inp = torch.arange(start=0, end=16).view(1, 1, 4, 4)
                    if qscheme == 'per_channel_symmetric':
                        if dtype == 'int':
                            target_scale = torch.tensor([15. / 127]).view([1, 1, 1, 1])
                            target_zero_point = torch.ones([1, 1, 1, 1]) * 0
                        else:
                            target_scale = torch.tensor([15. / 127.5]).view([1, 1, 1, 1])
                            target_zero_point = torch.ones([1, 1, 1, 1]) * 127
                    elif qscheme == 'per_tensor_symmetric':
                        if dtype == 'int':
                            target_scale = torch.tensor([15. / 127])
                            target_zero_point = torch.zeros([1])
                        else:
                            target_scale = torch.tensor([15. / 127.5])
                            target_zero_point = torch.ones([1]) * 127
                    elif qscheme == 'per_channel_affine':
                        min_val = torch.tensor([0.]).view([1, 1, 1, 1])
                        if dtype == 'int':
                            target_scale = torch.tensor([15. / 254]).view([1, 1, 1, 1])
                            target_zero_point = -127 - torch.round(min_val / target_scale)
                        else:
                            target_scale = torch.tensor([15. / 255]).view([1, 1, 1, 1])
                            target_zero_point = 0 - torch.round(min_val / target_scale)
                    else:
                        if dtype == 'int':
                            target_scale = torch.tensor([15. / 254])
                            target_zero_point = -127 - torch.round(0 / target_scale)
                        else:
                            target_scale = torch.tensor([15. / 255])
                            target_zero_point = 0 - torch.round(0 / target_scale)
                    quantizer.quantize_input(inp, wrapper)
                    self.assertTrue(torch.equal(getattr(model, name).module.input_scale, target_scale))
                    self.assertTrue(torch.equal(getattr(model, name).module.input_zero_point, target_zero_point))

    def test_torch_QAT_quantizer(self):
        model = TorchModel()
        config_list = [{
            'quant_types': ['weight', 'input'],
            'quant_bits': 8,
            'op_types': ['Conv2d', 'Linear']
        }, {
            'quant_types': ['output'],
            'quant_bits': 8,
            'quant_start_step': 0,
            'op_types': ['ReLU']
        }]
        model.relu = torch.nn.ReLU()

        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
        dummy = torch.randn(1, 1, 28, 28)
        quantizer = torch_quantizer.QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy)
        quantizer.compress()

        # test quantize
        # range not including 0
        eps = 1e-7
        input = torch.tensor([[1, 4], [2, 1]])
        weight = torch.tensor([[1, 2], [3, 5]]).float()
        model.conv2.module.weight.data = weight
        quantizer.quantize_weight(model.conv2, input_tensor=input)
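        # Expected values, assuming an unsigned 8-bit range of 255 steps with the observed
        # min/max extended to include zero: scale = (5 - 0) / 255 and zero_point = 0.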
        assert math.isclose(model.conv2.module.weight_scale, 5 / 255, abs_tol=eps)
        assert model.conv2.module.weight_zero_point == 0
        quantizer.quantize_input(input, model.conv2)
        self.assertTrue(torch.allclose(model.conv2.module.input_scale, torch.tensor([4. / 255])))
        self.assertTrue(torch.equal(model.conv2.module.input_zero_point, torch.tensor([0.])))
        # range including 0
        weight = torch.tensor([[-1, 2], [3, 5]]).float()
        model.conv2.module.weight = weight
        quantizer.quantize_weight(model.conv2, input_tensor=input)
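        # With zero inside the observed range: scale = (5 - (-1)) / 255 = 6 / 255, and the zero
        # point rounds from 1 / scale = 42.5, so either 42 or 43 is accepted below.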
        assert math.isclose(model.conv2.module.weight_scale, 6 / 255, abs_tol=eps)
        assert model.conv2.module.weight_zero_point in (42, 43)
        quantizer.quantize_input(input, model.conv2)
        self.assertTrue(torch.allclose(model.conv2.module.input_scale, torch.tensor([4. / 255])))
        self.assertTrue(torch.equal(model.conv2.module.input_zero_point, torch.tensor([0.])))
        # test value of weight and bias after quantization
        weight = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]])
        weight_valid = torch.tensor([[1.1242, 2.3421], [3.7707, 5.9723]])
        bias = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
        bias_valid = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
        model.conv2.module.weight = weight
        model.conv2.module.bias.data = bias
        quantizer.quantize_weight(model.conv2, input_tensor=input)
        assert torch.all(torch.isclose(model.conv2.module.weight.data, weight_valid, rtol=1e-4))
        assert torch.all(torch.isclose(model.conv2.module.bias.data, bias_valid, rtol=1e-7))

        # test ema
        eps = 1e-7
        x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
        model.relu(x)
        self.assertTrue(torch.equal(model.relu.module.tracked_min_output, torch.tensor([0.])))
        self.assertTrue(torch.equal(model.relu.module.tracked_max_output, torch.tensor([0.2])))

        quantizer.step_with_optimizer()
        x = torch.tensor([[0.2, 0.4], [0.6, 0.8]])
        model.relu(x)
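        # Expected EMA update, assuming a smoothing factor of 0.99:
        # tracked_min = 0 + 0.01 * (0.2 - 0) = 0.002, tracked_max = 0.2 + 0.01 * (0.8 - 0.2) = 0.206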
        self.assertTrue(torch.equal(model.relu.module.tracked_min_output, torch.tensor([0.002])))
        self.assertTrue(torch.equal(model.relu.module.tracked_max_output, torch.tensor([0.2060])))

    def test_torch_quantizer_export(self):
        config_list_qat = [{
            'quant_types': ['weight'],
            'quant_bits': 8,
            'op_types': ['Conv2d', 'Linear']
        }, {
            'quant_types': ['output'],
            'quant_bits': 8,
            'quant_start_step': 0,
            'op_types': ['ReLU']
        }]
        config_list_dorefa = [{
            'quant_types': ['weight'],
            'quant_bits': {
                'weight': 8,
            },  # an `int` also works here because all `quant_types` share the same bit width
            'op_types':['Conv2d', 'Linear']
        }]
        config_list_bnn = [{
            'quant_types': ['weight'],
            'quant_bits': 1,
            'op_types': ['Conv2d', 'Linear']
        }, {
            'quant_types': ['output'],
            'quant_bits': 1,
            'op_types': ['ReLU']
        }]
        config_set = [config_list_qat, config_list_dorefa, config_list_bnn]
        quantize_algorithm_set = [torch_quantizer.QAT_Quantizer, torch_quantizer.DoReFaQuantizer, torch_quantizer.BNNQuantizer]
        dummy = torch.randn(1, 1, 28, 28)
        for config, quantize_algorithm in zip(config_set, quantize_algorithm_set):
            model = TorchModel()
            model.relu = torch.nn.ReLU()
            optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
            if quantize_algorithm == torch_quantizer.QAT_Quantizer:
                quantizer = quantize_algorithm(model, config, optimizer, dummy)
            else:
                quantizer = quantize_algorithm(model, config, optimizer)
            quantizer.compress()

            x = torch.rand((1, 1, 28, 28), requires_grad=True)
            y = model(x)
            y.backward(torch.ones_like(y))

            model_path = "test_model.pth"
            calibration_path = "test_calibration.pth"
            onnx_path = "test_model.onnx"
            input_shape = (1, 1, 28, 28)
            device = torch.device("cpu")

            calibration_config = quantizer.export_model(model_path, calibration_path, onnx_path, input_shape, device)
            assert calibration_config is not None

    def test_quantizer_load_calibration_config(self):
        configure_list = [{
            'quant_types': ['weight', 'input'],
            'quant_bits': {'weight': 8, 'input': 8},
            'op_names': ['conv1', 'conv2']
        }, {
            'quant_types': ['output', 'weight', 'input'],
            'quant_bits': {'output': 8, 'weight': 8, 'input': 8},
            'op_names': ['fc1', 'fc2'],
        }]
        quantize_algorithm_set = [torch_quantizer.ObserverQuantizer, torch_quantizer.QAT_Quantizer, torch_quantizer.LsqQuantizer]
        calibration_config = None
        for quantize_algorithm in quantize_algorithm_set:
            model = TorchModel().eval()
            model.relu = torch.nn.ReLU()
            optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
            if quantize_algorithm == torch_quantizer.QAT_Quantizer:
                dummy = torch.randn(1, 1, 28, 28)
                quantizer = quantize_algorithm(model, configure_list, optimizer, dummy_input=dummy)
            else:
                quantizer = quantize_algorithm(model, configure_list, optimizer)
            quantizer.compress()
            if calibration_config is not None:
                quantizer.load_calibration_config(calibration_config)

            model_path = "test_model.pth"
            calibration_path = "test_calibration.pth"
            onnx_path = "test_model.onnx"
            input_shape = (1, 1, 28, 28)
            device = torch.device("cpu")

            calibration_config = quantizer.export_model(model_path, calibration_path, onnx_path, input_shape, device)

    def test_torch_pruner_validation(self):
        # test bad configuration
        pruner_classes = [torch_pruner.__dict__[x] for x in \
            ['LevelPruner', 'SlimPruner', 'FPGMPruner', 'L1FilterPruner', 'L2FilterPruner', 'AGPPruner',\
            'ActivationMeanRankFilterPruner', 'ActivationAPoZRankFilterPruner']]

        bad_configs = [
            [
                {'sparsity': '0.2'},
                {'sparsity': 0.6 }
            ],
            [
                {'sparsity': 0.2},
                {'sparsity': 1.6 }
            ],
            [
                {'sparsity': 0.2, 'op_types': 'default'},
                {'sparsity': 0.6 }
            ],
            [
                {'sparsity': 0.2 },
                {'sparsity': 0.6, 'op_names': 'abc'}
            ]
        ]
        model = TorchModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        for pruner_class in pruner_classes:
            for config_list in bad_configs:
                try:
                    kwargs = {}
                    if pruner_class in (torch_pruner.SlimPruner, torch_pruner.AGPPruner, torch_pruner.ActivationMeanRankFilterPruner, torch_pruner.ActivationAPoZRankFilterPruner):
                        kwargs = {'optimizer': None, 'trainer': None, 'criterion': None}

                    print('kwargs', kwargs)
                    pruner_class(model, config_list, **kwargs)

                    print(config_list)
                    assert False, 'Validation error should be raised for bad configuration'
                except schema.SchemaError:
                    pass
                except:
                    print('FAILED:', pruner_class, config_list)
                    raise

    def test_torch_quantizer_validation(self):
        # test bad configuration
        quantizer_classes = [torch_quantizer.__dict__[x] for x in \
            ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer']]

        bad_configs = [
            [
                {'bad_key': 'abc'}
            ],
            [
                {'quant_types': 'abc'}
            ],
            [
                {'quant_bits': 34}
            ],
            [
                {'op_types': 'default'}
            ],
            [
                {'quant_bits': {'abc': 123}}
            ]
        ]
        model = TorchModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        for quantizer_class in quantizer_classes:
            for config_list in bad_configs:
                try:
                    quantizer_class(model, config_list, optimizer)
                    print(config_list)
                    assert False, 'Validation error should be raised for bad configuration'
                except schema.SchemaError:
                    pass
                except:
                    print('FAILED:', quantizer_class, config_list)
                    raise

if __name__ == '__main__':
    main()