# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import copy
import torch
from schema import Schema, And, Or, Optional
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.compressor import Quantizer, QuantForward, QuantGrad, QuantType

__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer']

logger = logging.getLogger(__name__)


class NaiveQuantizer(Quantizer):
    """quantize weight to 8 bits
    """

    def __init__(self, model, config_list, optimizer=None):
        super().__init__(model, config_list, optimizer)
        self.layer_scale = {}

    def validate_config(self, model, config_list):
        schema = CompressorSchema([{
            Optional('quant_types'): ['weight'],
            Optional('quant_bits'): Or(8, {'weight': 8}),
            Optional('op_types'): [str],
            Optional('op_names'): [str]
        }], model, logger)

        schema.validate(config_list)

    def quantize_weight(self, wrapper, **kwargs):
        weight = copy.deepcopy(wrapper.module.old_weight.data)
        new_scale = weight.abs().max() / 127
        scale = max(self.layer_scale.get(wrapper.name, 0), new_scale)
        self.layer_scale[wrapper.name] = scale
        orig_type = weight.type()  # TODO: user layer
        weight = weight.div(scale).type(torch.int8).type(orig_type).mul(scale)
        wrapper.module.weight = weight
        return weight
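
# A minimal usage sketch (illustrative only; `MyModel` is a placeholder, not part of this file):
#     >>> quantizer = NaiveQuantizer(MyModel(), [{'quant_types': ['weight'], 'quant_bits': 8, 'op_types': ['Conv2d']}])
#     >>> quantizer.compress()
# Each wrapped layer's weight is then fake-quantized to int8 with a per-layer scale of max(|w|) / 127.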

def update_ema(biased_ema, value, decay):
    """
    calculate the exponential moving average of a statistic in each step

    Parameters
    ----------
    biased_ema : float
        previous stat value
    value : float
        current stat value
    decay : float
        the weight of the previous stat value, a larger decay means a smoother curve

    Returns
    -------
    float
        the updated exponential moving average
    """
    biased_ema = biased_ema * decay + (1 - decay) * value
    return biased_ema
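
# Worked example (illustrative numbers): with decay=0.99, a previous ema of 0.0 and a new
# observation of 10.0, update_ema returns 0.0 * 0.99 + 0.01 * 10.0 = 0.1, i.e. the running
# statistic moves only 1% of the way toward the new value on each step.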


def update_quantization_param(bits, rmin, rmax):
    """
    calculate the `zero_point` and `scale`.

    Parameters
    ----------
    bits : int
        quantization bits length
    rmin : Tensor
        min value of real value
    rmax : Tensor
        max value of real value

    Returns
    -------
    Tensor, Tensor
        scale and zero point
    """
    # extend the [min, max] interval to ensure that it contains 0.
    # Otherwise, we would not meet the requirement that 0 be an exactly
    # representable value.
    rmin = torch.min(rmin, torch.Tensor([0]).to(rmin.device))
    rmax = torch.max(rmax, torch.Tensor([0]).to(rmin.device))
    qmin = torch.Tensor([0]).to(rmin.device)
    qmax = torch.Tensor([(1 << bits) - 1]).to(rmin.device)

    # First determine the scale.
    scale = (rmax - rmin) / (qmax - qmin)

    # Zero-point computation.
    initial_zero_point = qmin - rmin / scale

    # Now we need to nudge the zero point to be an integer
    if initial_zero_point < qmin:
        nudged_zero_point = qmin
    elif initial_zero_point > qmax:
        nudged_zero_point = qmax
    else:
        nudged_zero_point = torch.round(initial_zero_point)

    return scale, nudged_zero_point
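
# Worked example (illustrative numbers): for bits=8, rmin=-1.0 and rmax=3.0 the affine
# parameters are scale = (3.0 - (-1.0)) / 255 ≈ 0.0157 and
# zero_point = round(0 - (-1.0) / scale) = round(63.75) = 64, so the real value 0.0
# maps exactly to the integer 64 as required.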


def get_bits_length(config, quant_type):
    if isinstance(config["quant_bits"], int):
        return config["quant_bits"]
    else:
        return config["quant_bits"].get(quant_type)
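
# For example (illustrative config): with config = {'quant_bits': {'weight': 8, 'output': 4}},
# get_bits_length(config, 'weight') returns 8; with config = {'quant_bits': 8}, every
# quantization type shares the same 8-bit length.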

class QATGrad(QuantGrad):
    @staticmethod
    def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax):
        tensor_q = QuantGrad._quantize(tensor, scale, zero_point)
        mask = (tensor_q < qmin) | (tensor_q > qmax)
        grad_output[mask] = 0
        return grad_output
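
# Note: QATGrad is a straight-through estimator with range clipping -- gradients pass
# through unchanged except where the simulated quantized value falls outside
# [qmin, qmax], where they are zeroed.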


class QAT_Quantizer(Quantizer):
    """Quantizer defined in:
    Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
    http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
    """

    def __init__(self, model, config_list, optimizer=None):
        """
        Parameters
        ----------
        model : torch.nn.Module
            the model to be quantized
        config_list : list of dict
            list of configurations for quantization
            supported keys for dict:
                - quant_types : list of string
                    type of quantization you want to apply, currently support 'weight', 'input', 'output'
                - quant_bits : int or dict of {str : int}
                    bits length of quantization, key is the quantization type, value is the bits length, eg. {'weight': 8},
                    when the type is int, all quantization types share the same bits length
                - quant_start_step : int
                    disable quantization until the model has been run for a certain number of steps, this allows the network to enter a more
                    stable state where activation quantization ranges do not exclude a significant fraction of values, default value is 0
                - op_types : list of string
                    types of nn.Module you want to apply quantization to, eg. 'Conv2d'
        """
        super().__init__(model, config_list, optimizer)
        self.quant_grad = QATGrad.apply
        modules_to_compress = self.get_modules_to_compress()
        self.bound_model.register_buffer("steps", torch.Tensor([1]))
        for layer, config in modules_to_compress:
            layer.module.register_buffer("zero_point", torch.Tensor([0.0]))
            layer.module.register_buffer("scale", torch.Tensor([1.0]))
            layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
            if "weight" in config.get("quant_types", []):
                layer.module.register_buffer('weight_bit', torch.zeros(1))
                layer.module.register_buffer('tracked_min_input', torch.zeros(1))
                layer.module.register_buffer('tracked_max_input', torch.zeros(1))
            if "output" in config.get("quant_types", []):
                layer.module.register_buffer('activation_bit', torch.zeros(1))
                layer.module.register_buffer('tracked_min_activation', torch.zeros(1))
                layer.module.register_buffer('tracked_max_activation', torch.zeros(1))

    def _del_simulated_attr(self, module):
        """
        delete redundant parameters in quantize module
        """
        del_attr_list = ['old_weight', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation', 'tracked_min_input', \
        'tracked_max_input', 'scale', 'zero_point', 'weight_bit', 'activation_bit']
        for attr in del_attr_list:
            if hasattr(module, attr):
                delattr(module, attr)

    def validate_config(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be quantized
        config_list : list of dict
            List of configurations
        """
        schema = CompressorSchema([{
            Optional('quant_types'): Schema([lambda x: x in ['weight', 'output']]),
            Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), Schema({
                Optional('weight'): And(int, lambda n: 0 < n < 32),
                Optional('output'): And(int, lambda n: 0 < n < 32),
            })),
            Optional('quant_start_step'): And(int, lambda n: n >= 0),
            Optional('op_types'): [str],
            Optional('op_names'): [str]
        }], model, logger)

        schema.validate(config_list)

    def _quantize(self, bits, op, real_val):
        """
        quantize real value.

        Parameters
        ----------
        bits : int
            quantization bits length
        op : torch.nn.Module
            target module
        real_val : Tensor
            real value to be quantized

        Returns
        -------
        Tensor
        """
        op.zero_point = op.zero_point.to(real_val.device)
        op.scale = op.scale.to(real_val.device)
        transformed_val = op.zero_point + real_val / op.scale
        qmin = 0
        qmax = (1 << bits) - 1
        clamped_val = torch.clamp(transformed_val, qmin, qmax)
        quantized_val = torch.round(clamped_val)
        return quantized_val

    def _dequantize(self, op, quantized_val):
        """
        dequantize quantized value.
        Because we simulate quantization in the training process, all computations still happen in floating point,
        which means we first quantize tensors and then dequantize them. For more details, please refer to the paper.

        Parameters
        ----------
        op : torch.nn.Module
            target module
        quantized_val : Tensor
            quantized value to be dequantized

        Returns
        -------
        Tensor
        """
        real_val = op.scale * (quantized_val - op.zero_point)
        return real_val
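
    # Worked example (illustrative numbers): with weight_bits=8 and parameters computed by
    # update_quantization_param for rmin=-1.0, rmax=3.0 (scale ≈ 0.0157, zero_point = 64),
    # a real value of 1.0 is quantized to round(64 + 1.0 / scale) = 128 and dequantized back
    # to scale * (128 - 64) ≈ 1.004; the small gap is the simulated quantization error.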

    def quantize_weight(self, wrapper, **kwargs):
        config = wrapper.config
        module = wrapper.module
        input = kwargs['input_tensor']
        weight = copy.deepcopy(wrapper.module.old_weight.data)
        weight_bits = get_bits_length(config, 'weight')
        quant_start_step = config.get('quant_start_step', 0)
        assert weight_bits >= 1, "quant bits length should be at least 1"

        # we don't update weight in the evaluation stage
        if quant_start_step > self.bound_model.steps:
            module.tracked_min_input, module.tracked_max_input = torch.min(input), torch.max(input)
            return weight

        if not wrapper.training:
            return weight

        current_min, current_max = torch.min(input), torch.max(input)
        module.tracked_min_input = update_ema(module.tracked_min_input, current_min,
                                                                    module.ema_decay)
        module.tracked_max_input = update_ema(module.tracked_max_input, current_max,
                                                                    module.ema_decay)

        # if bias exists, quantize bias to uint32
        if hasattr(wrapper.module, 'bias') and wrapper.module.bias is not None:
            bias = wrapper.module.bias.data
            bias_bits = 32
            rmin, rmax = torch.min(bias), torch.max(bias)
            module.scale, module.zero_point = update_quantization_param(bias_bits, rmin, rmax)
            bias = self._quantize(bias_bits, module, bias)
            bias = self._dequantize(module, bias)
            wrapper.module.bias.data = bias


        # quantize weight
        rmin, rmax = torch.min(weight), torch.max(weight)
        module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax)
        weight = self._quantize(weight_bits, module, weight)
        weight = self._dequantize(module, weight)
        module.weight_bit = torch.Tensor([weight_bits])
        wrapper.module.weight = weight
        return weight

    def quantize_output(self, output, wrapper, **kwargs):
        config = wrapper.config
        module = wrapper.module
        output_bits = get_bits_length(config, 'output')
        module.activation_bit = torch.Tensor([output_bits])
        quant_start_step = config.get('quant_start_step', 0)
        assert output_bits >= 1, "quant bits length should be at least 1"

        if quant_start_step > self.bound_model.steps:
            module.tracked_min_activation, module.tracked_max_activation = torch.min(output), torch.max(output)
            return output

        # we don't update output quantization parameters in the evaluation stage
        if wrapper.training:
            current_min, current_max = torch.min(output), torch.max(output)
            module.tracked_min_activation = update_ema(module.tracked_min_activation, current_min,
                                                                       module.ema_decay)
            module.tracked_max_activation = update_ema(module.tracked_max_activation, current_max,
                                                                       module.ema_decay)
            module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min_activation, module.tracked_max_activation)
        out = self._quantize(output_bits, module, output)
        out = self._dequantize(module, out)
        return out

    def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
        """
        Export quantized model weights and calibration parameters (optional)

        Parameters
        ----------
        model_path : str
            path to save the quantized model weight
        calibration_path : str
            (optional) path to save quantization parameters after calibration
        onnx_path : str
            (optional) path to save the onnx model
        input_shape : list or tuple
            input shape to onnx model
        device : torch.device
            device of the model, used to place the dummy input tensor for exporting the onnx file.
            The tensor is placed on cpu if ``device`` is None

        Returns
        -------
        Dict
        """
        assert model_path is not None, 'model_path must be specified'
        self._unwrap_model()
        calibration_config = {}

        for name, module in self.bound_model.named_modules():
            if hasattr(module, 'weight_bit') or hasattr(module, 'activation_bit'):
                calibration_config[name] = {}
            if hasattr(module, 'weight_bit'):
                calibration_config[name]['weight_bit'] = int(module.weight_bit)
                calibration_config[name]['tracked_min_input'] = float(module.tracked_min_input)
                calibration_config[name]['tracked_max_input'] = float(module.tracked_max_input)
            if hasattr(module, 'activation_bit'):
                calibration_config[name]['activation_bit'] = int(module.activation_bit)
                calibration_config[name]['tracked_min_activation'] = float(module.tracked_min_activation)
                calibration_config[name]['tracked_max_activation'] = float(module.tracked_max_activation)
            self._del_simulated_attr(module)

        self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, input_shape, device)

        return calibration_config

    def fold_bn(self, config, **kwargs):
        # TODO simulate folded weight
        pass

    def step_with_optimizer(self):
        """
        override `compressor` `step` method, quantization only happens after a certain number of steps
        """
        self.bound_model.steps += 1


class DoReFaQuantizer(Quantizer):
    """Quantizer using the DoReFa scheme, as defined in:
    Zhou et al., DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients
    (https://arxiv.org/abs/1606.06160)
    """

    def __init__(self, model, config_list, optimizer=None):
        super().__init__(model, config_list, optimizer)
        modules_to_compress = self.get_modules_to_compress()
        for layer, config in modules_to_compress:
            if "weight" in config.get("quant_types", []):
                layer.module.register_buffer('weight_bit', torch.zeros(1))

    def _del_simulated_attr(self, module):
        """
        delete redundant parameters in quantize module
        """
        del_attr_list = ['old_weight', 'weight_bit']
        for attr in del_attr_list:
            if hasattr(module, attr):
                delattr(module, attr)

    def validate_config(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be quantized
        config_list : list of dict
            List of configurations
        """
        schema = CompressorSchema([{
            Optional('quant_types'): Schema([lambda x: x in ['weight']]),
            Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), Schema({
                Optional('weight'): And(int, lambda n: 0 < n < 32)
            })),
            Optional('op_types'): [str],
            Optional('op_names'): [str]
        }], model, logger)

        schema.validate(config_list)

    def quantize_weight(self, wrapper, **kwargs):
        weight = copy.deepcopy(wrapper.module.old_weight.data)
        weight_bits = get_bits_length(wrapper.config, 'weight')
        weight = weight.tanh()
        weight = weight / (2 * weight.abs().max()) + 0.5
        weight = self.quantize(weight, weight_bits)
        weight = 2 * weight - 1
        wrapper.module.weight = weight
        wrapper.module.weight_bit = torch.Tensor([weight_bits])
        # wrapper.module.weight.data = weight
        return weight

    def quantize(self, input_ri, q_bits):
        scale = pow(2, q_bits) - 1
        output = torch.round(input_ri * scale) / scale
        return output
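
    # Worked example (illustrative numbers): quantize_weight first squashes weights into [0, 1]
    # via tanh and max-normalization; quantize(0.4, q_bits=2) then returns
    # round(0.4 * 3) / 3 = 1/3 ≈ 0.333 before the result is mapped back to [-1, 1].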

    def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
        """
        Export quantized model weights and calibration parameters (optional)

        Parameters
        ----------
        model_path : str
            path to save the quantized model weight
        calibration_path : str
            (optional) path to save quantization parameters after calibration
        onnx_path : str
            (optional) path to save the onnx model
        input_shape : list or tuple
            input shape to onnx model
        device : torch.device
            device of the model, used to place the dummy input tensor for exporting the onnx file.
            The tensor is placed on cpu if ``device`` is None

        Returns
        -------
        Dict
        """
        assert model_path is not None, 'model_path must be specified'
        self._unwrap_model()
        calibration_config = {}

        for name, module in self.bound_model.named_modules():
            if hasattr(module, 'weight_bit'):
                calibration_config[name] = {}
                calibration_config[name]['weight_bit'] = int(module.weight_bit)
            self._del_simulated_attr(module)

        self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, input_shape, device)

        return calibration_config


class ClipGrad(QuantGrad):
    @staticmethod
    def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax):
        if quant_type == QuantType.QUANT_OUTPUT:
            grad_output[torch.abs(tensor) > 1] = 0
        return grad_output


class BNNQuantizer(Quantizer):
    """Binarized Neural Networks, as defined in:
    Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1
    (https://arxiv.org/abs/1602.02830)
    """

    def __init__(self, model, config_list, optimizer=None):
        super().__init__(model, config_list, optimizer)
        self.quant_grad = ClipGrad.apply
        modules_to_compress = self.get_modules_to_compress()
        for layer, config in modules_to_compress:
            if "weight" in config.get("quant_types", []):
                layer.module.register_buffer('weight_bit', torch.zeros(1))

    def _del_simulated_attr(self, module):
        """
        delete redundant parameters in quantize module
        """
        del_attr_list = ['old_weight', 'weight_bit']
        for attr in del_attr_list:
            if hasattr(module, attr):
                delattr(module, attr)

    def validate_config(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be quantized
        config_list : list of dict
            List of configurations
        """
        schema = CompressorSchema([{
            Optional('quant_types'): Schema([lambda x: x in ['weight', 'output']]),
            Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), Schema({
                Optional('weight'): And(int, lambda n: 0 < n < 32),
                Optional('output'): And(int, lambda n: 0 < n < 32),
            })),
            Optional('op_types'): [str],
            Optional('op_names'): [str]
        }], model, logger)

        schema.validate(config_list)

    def quantize_weight(self, wrapper, **kwargs):
        weight = copy.deepcopy(wrapper.module.old_weight.data)
        weight = torch.sign(weight)
        # remove zeros
        weight[weight == 0] = 1
        wrapper.module.weight = weight
        wrapper.module.weight_bit = torch.Tensor([1.0])
        return weight

    def quantize_output(self, output, wrapper, **kwargs):
        out = torch.sign(output)
        # remove zeros
        out[out == 0] = 1
        return out
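
    # Worked example (illustrative values): torch.sign maps [-0.3, 0.0, 2.1] to [-1., 0., 1.];
    # the "remove zeros" step then turns the 0 into +1, so every binarized weight and
    # activation is exactly +1 or -1 as the paper requires.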

    def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
        """
        Export quantized model weights and calibration parameters (optional)

        Parameters
        ----------
        model_path : str
            path to save the quantized model weight
        calibration_path : str
            (optional) path to save quantization parameters after calibration
        onnx_path : str
            (optional) path to save the onnx model
        input_shape : list or tuple
            input shape to onnx model
        device : torch.device
            device of the model, used to place the dummy input tensor for exporting the onnx file.
            The tensor is placed on cpu if ``device`` is None

        Returns
        -------
        Dict
        """
        assert model_path is not None, 'model_path must be specified'
        self._unwrap_model()
        calibration_config = {}

        for name, module in self.bound_model.named_modules():
            if hasattr(module, 'weight_bit'):
                calibration_config[name] = {}
                calibration_config[name]['weight_bit'] = int(module.weight_bit)
            self._del_simulated_attr(module)

        self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, input_shape, device)

        return calibration_config


class LsqQuantizer(Quantizer):
    """Quantizer defined in:
       Learned Step Size Quantization (ICLR 2020)
       https://arxiv.org/pdf/1902.08153.pdf
    """

    def __init__(self, model, config_list, optimizer=None):
        """
        Parameters
        ----------
        model : torch.nn.Module
            the model to be quantized
        config_list : list of dict
            list of configurations for quantization
            supported keys for dict:
                - quant_types : list of string
                    type of quantization you want to apply, currently support 'weight', 'input', 'output'
                - quant_bits : int or dict of {str : int}
                    bits length of quantization, key is the quantization type, value is the bits length, eg. {'weight': 8},
                    when the type is int, all quantization types share the same bits length
                - quant_start_step : int
                    disable quantization until the model has been run for a certain number of steps, this allows the network to enter a more
                    stable state where activation quantization ranges do not exclude a significant fraction of values, default value is 0
                - op_types : list of string
                    types of nn.Module you want to apply quantization to, eg. 'Conv2d'
        """
        super().__init__(model, config_list, optimizer)
        self.quant_grad = QuantForward()
        modules_to_compress = self.get_modules_to_compress()
        self.bound_model.register_buffer("steps", torch.Tensor([1]))
        for layer, config in modules_to_compress:
            if "weight" in config.get("quant_types", []):
                layer.module.register_parameter("weight_scale", torch.nn.Parameter(torch.Tensor([1.0])))
                # todo: support per-channel quantization for weight since TensorRT use it for conv weight
                q_bit = get_bits_length(config, "weight")
                layer.module.register_buffer('weight_bit', torch.Tensor([q_bit]))
                qmax = 2 ** (q_bit - 1) - 1
                qmin = -2 ** (q_bit - 1)
                init_weight_scale = layer.module.weight.data.detach().abs().mean() * 2 / (qmax ** 0.5)
                layer.module.weight_scale = torch.nn.Parameter(init_weight_scale)
                layer.module.weight_qmax = qmax
                layer.module.weight_qmin = qmin

                self.optimizer.add_param_group({"params": layer.module.weight_scale})

            if "output" in config.get("quant_types", []):
                # scale of activation will be initialized using the first batch data
                layer.module.register_parameter("output_scale", torch.nn.Parameter(torch.Tensor([1.0])))
                q_bit = get_bits_length(config, "output")
                layer.module.register_buffer('output_bit', torch.Tensor([q_bit]))
                qmax = 2 ** (q_bit - 1) - 1
                qmin = -2 ** (q_bit - 1)
                layer.module.output_qmax = qmax
                layer.module.output_qmin = qmin

                self.optimizer.add_param_group({"params": layer.module.output_scale})

            if "input" in config.get("quant_types", []):
                # scale of input will be initialized using the first batch data
                layer.module.register_parameter("input_scale", torch.nn.Parameter(torch.Tensor([1.0])))
                q_bit = get_bits_length(config, "input")
                layer.module.register_buffer('input_bit', torch.Tensor([q_bit]))
                qmax = 2 ** (q_bit - 1) - 1
                qmin = -2 ** (q_bit - 1)
                layer.module.input_qmax = qmax
                layer.module.input_qmin = qmin

                self.optimizer.add_param_group({"params": layer.module.input_scale})

    @staticmethod
    def grad_scale(x, scale):
        r"""
            Used to scale the gradient. Given tensor `x`, we have `y = grad_scale(x, scale) = x` in the forward pass,
            which means that this function does not change the value of `x`. In the backward pass, we have:

            :math:`\frac{\partial L}{\partial x}=\frac{\partial L}{\partial y}*\frac{\partial y}{\partial x}=scale*\frac{\partial L}{\partial y}`

            This means that the original gradient of `x` is scaled by a factor of `scale`. Applying this function
            to a nn.Parameter will scale its gradient without changing its value.
        """
        y = x
        y_grad = x * scale
        return (y - y_grad).detach() + y_grad

    @staticmethod
    def round_pass(x):
        """
            A simple way to implement the straight-through estimator (STE) for rounding.
        """
        y = x.round()
        y_grad = x
        return (y - y_grad).detach() + y_grad
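
    # A minimal sanity check of the detach trick used by grad_scale/round_pass
    # (illustrative sketch, not part of the class):
    #     >>> x = torch.tensor([1.7], requires_grad=True)
    #     >>> y = LsqQuantizer.round_pass(x)   # forward value is round(1.7) == 2.0
    #     >>> y.sum().backward()
    #     >>> x.grad                           # tensor([1.]) -- round() acts as identity in backward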

    def quantize(self, x, scale, qmin, qmax):
        grad_scale_factor = 1.0 / ((qmax * x.numel()) ** 0.5)
        scale = self.grad_scale(scale, grad_scale_factor)
        x = x / scale
        x = torch.clamp(x, qmin, qmax)
        x = self.round_pass(x)
        x = x * scale
        return x
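
    # In other words, this computes the LSQ fake-quantization x_q = round(clamp(x / s, qmin, qmax)) * s,
    # while grad_scale rescales the gradient of the learned step size s by 1 / sqrt(qmax * x.numel()),
    # as recommended in the LSQ paper.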

    def quantize_weight(self, wrapper, **kwargs):
        module = wrapper.module

        # todo: add support for quantizing bias. If we use TensorRT as the backend,
        # there is no need to quantize bias.
        old_weight = module.old_weight
        weight = self.quantize(old_weight, module.weight_scale, module.weight_qmin, module.weight_qmax)
        module.weight = weight
        return weight

    def quantize_output(self, output, wrapper, **kwargs):
        module = wrapper.module

        # initialize the scale
        if self.bound_model.steps == 1:
            qmax = module.output_qmax
            init_oup_scale = output.data.detach().abs().mean() * 2 / (qmax ** 0.5)
            module.output_scale.data = init_oup_scale

        output = self.quantize(output, module.output_scale, module.output_qmin, module.output_qmax)
        return output

    def quantize_input(self, *inputs, wrapper, **kwargs):
        # This is hacky since it is not recommended to modify a tuple
        # NB: support layers with multi inputs
        module = wrapper.module
        # initialize the scale
        if self.bound_model.steps == 1:
            qmax = module.input_qmax
            init_oup_scale = inputs[0].data.detach().abs().mean() * 2 / (qmax ** 0.5)
            module.input_scale.data = init_oup_scale

        new_input = self.quantize(inputs[0], module.input_scale, module.input_qmin, module.input_qmax)
        list_inp = list(inputs)
        list_inp[0] = new_input
        return tuple(list_inp)

    def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
        """
        Export quantized model weights and calibration parameters (optional)

        Parameters
        ----------
        model_path : str
            path to save the quantized model weight
        calibration_path : str
            (optional) path to save quantization parameters after calibration
        onnx_path : str
            (optional) path to save the onnx model
        input_shape : list or tuple
            input shape to onnx model
        device : torch.device
            device of the model, used to place the dummy input tensor for exporting the onnx file.
            The tensor is placed on cpu if ``device`` is None

        Returns
        -------
        Dict
        """
        assert model_path is not None, 'model_path must be specified'
        self._unwrap_model()
        calibration_config = {}

        for name, module in self.bound_model.named_modules():
            if hasattr(module, 'input_bit') or hasattr(module, 'output_bit'):
                calibration_config[name] = {}
            if hasattr(module, 'weight_bit'):
                calibration_config[name]['weight_bit'] = int(module.weight_bit)
                abs_max_input = float(module.input_scale * module.input_qmax)
                calibration_config[name]['tracked_min_input'] = -abs_max_input
                calibration_config[name]['tracked_max_input'] = abs_max_input
            if hasattr(module, 'output_bit'):
                calibration_config[name]['activation_bit'] = int(module.output_bit)
                abs_max_output = float(module.output_scale * module.output_qmax)
                calibration_config[name]['tracked_min_activation'] = -abs_max_output
                calibration_config[name]['tracked_max_activation'] = abs_max_output
            self._del_simulated_attr(module)

        self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path,
                               input_shape, device)

        return calibration_config

    def _del_simulated_attr(self, module):
        """
        delete redundant parameters in quantize module
        """
        del_attr_list = ['old_weight', 'tracked_min_input', 'tracked_max_input', 'tracked_min_activation', \
        'tracked_max_activation', 'output_scale', 'input_scale', 'weight_scale','weight_bit', 'output_bit', 'input_bit']
        for attr in del_attr_list:
            if hasattr(module, attr):
                delattr(module, attr)

    def step_with_optimizer(self):
        """
        override `compressor` `step` method, quantization only happens after a certain number of steps
        """
        self.bound_model.steps += 1