Quantizer on NNI Compressor
===
## Naive Quantizer

We provide the Naive Quantizer, which quantizes weights to 8 bits by default. You can use it to test a quantization algorithm without any configuration.

### Usage
TensorFlow code
```python
import nni.compressors.tensorflow
nni.compressors.tensorflow.NaiveQuantizer(model_graph).compress()  # model_graph: your TensorFlow model graph
```
PyTorch code
```python
import nni.compressors.torch
nni.compressors.torch.NaiveQuantizer(model).compress()  # model: your torch.nn.Module
```
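
As a rough illustration of what 8-bit weight quantization means, here is a minimal sketch of symmetric per-tensor quantization in plain Python. It assumes a single scale of max|w| / 127 and round-to-nearest; `quantize_symmetric` and `dequantize` are hypothetical helpers, not the NaiveQuantizer's actual code.

```python
def quantize_symmetric(weights, bits=8):
    """Map a list of floats to signed integers sharing one per-tensor scale (hypothetical helper)."""
    qmax = 2 ** (bits - 1) - 1  # 127 for 8 bits
    max_abs = max((abs(w) for w in weights), default=0.0)
    # Assumed scale: the largest magnitude maps to qmax; fall back to 1.0 for an all-zero tensor
    scale = max_abs / qmax if max_abs > 0 else 1.0
    # Python's round() is round-half-to-even; the clamp guards against floating-point overshoot
    q = [max(-qmax, min(qmax, round(w / scale))) for w in weights]
    return q, scale


def dequantize(q, scale):
    """Map integers back to floats; the gap to the original values is the quantization error."""
    return [v * scale for v in q]


weights = [0.5, -1.27, 0.0, 1.27]
q, scale = quantize_symmetric(weights)
print(q)                     # integers in [-127, 127]
print(dequantize(q, scale))  # close to the original weights
```

Each dequantized value differs from its original by at most half a quantization step (`scale / 2`), which is the accuracy cost the quantizer trades for 8-bit storage.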

***

## QAT Quantizer
In [Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference](http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf), the authors Benoit Jacob, Skirmantas Kligys et al. provide an algorithm that quantizes the model during training.

>We propose an approach that simulates quantization effects in the forward pass of training. Backpropagation still happens as usual, and all weights and biases are stored in floating point so that they can be easily nudged by small amounts. The forward propagation pass however simulates quantized inference as it will happen in the inference engine, by implementing in floating-point arithmetic the rounding behavior of the quantization scheme
>* Weights are quantized before they are convolved with the input. If batch normalization (see [17]) is used for the layer, the batch normalization parameters are “folded into” the weights before quantization.
>* Activations are quantized at points where they would be during inference, e.g. after the activation function is applied to a convolutional or fully connected layer’s output, or after a bypass connection adds or concatenates the outputs of several layers together such as in ResNets.

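The "simulated quantization" described above amounts to quantizing a value and immediately dequantizing it in the forward pass, using the paper's affine mapping r = scale * (q - zero_point). The plain-Python sketch below shows only that forward computation, under the assumption of unsigned 8-bit integers and a known [rmin, rmax] range; `affine_qparams` and `fake_quantize` are hypothetical names, not NNI's QAT_Quantizer code.

```python
def affine_qparams(rmin, rmax, bits=8):
    """Scale and zero point for unsigned `bits`-bit affine quantization r = scale * (q - zero_point)."""
    qmin, qmax = 0, 2 ** bits - 1
    # Widen the range to include 0 so that real zero is exactly representable
    rmin, rmax = min(rmin, 0.0), max(rmax, 0.0)
    if rmax == rmin:  # degenerate all-zero range
        return 1.0, 0
    scale = (rmax - rmin) / (qmax - qmin)
    zero_point = max(qmin, min(qmax, round(qmin - rmin / scale)))
    return scale, zero_point


def fake_quantize(values, rmin, rmax, bits=8):
    """Quantize and immediately dequantize: outputs stay float but carry the rounding error.

    During training the backward pass treats this op as identity (straight-through);
    this forward-only sketch does not model gradients.
    """
    scale, zero_point = affine_qparams(rmin, rmax, bits)
    qmin, qmax = 0, 2 ** bits - 1
    out = []
    for r in values:
        q = max(qmin, min(qmax, round(r / scale) + zero_point))
        out.append(scale * (q - zero_point))
    return out


acts = [-1.0, -0.3, 0.0, 0.42, 1.0]
print(fake_quantize(acts, rmin=min(acts), rmax=max(acts)))
```

Because the output is still a float tensor, the rest of the network and the optimizer run unchanged, which is what lets backpropagation "happen as usual" while the forward pass sees quantization error.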

### Usage
Add the code below before your training code to quantize your model to 8 bits.

PyTorch code
```python
from nni.compressors.torch import QAT_Quantizer

# Mnist stands for your own torch.nn.Module subclass; substitute your model here
model = Mnist()

config_list = [{
    'quant_types': ['weight'],
    'quant_bits': {
        'weight': 8,
    }, # an int also works here; all `quant_types` then share the same bit length (see the `ReLU6` config below)
    'op_types':['Conv2d', 'Linear']
}, {
    'quant_types': ['output'],
    'quant_bits': 8,
    'quant_start_step': 7000,
    'op_types':['ReLU6']
}]
quantizer = QAT_Quantizer(model, config_list)
quantizer.compress()
```

See the example code for more information.

#### User configuration for QAT Quantizer
* **quant_types:** list of string

Types of quantization to apply. Currently supported: 'weight', 'input' and 'output'.

* **op_types:** list of string

Types of modules to quantize, e.g. 'Conv2d'.

* **op_names:** list of string

Names of modules to quantize, e.g. 'conv1'.

* **quant_bits:** int or dict of {str : int}

Bit length of quantization. As a dict, each key is a quantization type and each value is its bit length, e.g. {'weight': 8}. As an int, all quantization types share the same bit length.

* **quant_start_step:** int

Quantization is disabled until the model has run for this many steps. This lets the network reach a more stable state in which the activation quantization ranges do not exclude a significant fraction of values. Default value: 0.

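The semantics of `quant_bits` and `quant_start_step` can be illustrated with two small helpers. `resolve_quant_bits` and `quantization_enabled` are hypothetical and not part of NNI; the sketch assumes quantization switches on once the step counter reaches `quant_start_step`.

```python
SUPPORTED_QUANT_TYPES = ("weight", "input", "output")


def resolve_quant_bits(config):
    """Return {quant_type: bits} for every entry in config['quant_types'] (hypothetical helper)."""
    types = config["quant_types"]
    bits = config["quant_bits"]
    for t in types:
        if t not in SUPPORTED_QUANT_TYPES:
            raise ValueError(f"unsupported quant_type: {t!r}")
    if isinstance(bits, int):
        # An int applies the same bit length to every quantization type
        return {t: bits for t in types}
    # A dict gives one bit length per type; a missing type raises KeyError
    return {t: bits[t] for t in types}


def quantization_enabled(config, step):
    """Assumed rule: quantization is active once `step` reaches quant_start_step (default 0)."""
    return step >= config.get("quant_start_step", 0)


relu6_cfg = {'quant_types': ['output'], 'quant_bits': 8, 'quant_start_step': 7000, 'op_types': ['ReLU6']}
print(resolve_quant_bits(relu6_cfg))          # {'output': 8}
print(quantization_enabled(relu6_cfg, 6999))  # False
print(quantization_enabled(relu6_cfg, 7000))  # True
```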
### Note
Batch normalization folding is currently not supported.
***

## DoReFa Quantizer
In [DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients](https://arxiv.org/abs/1606.06160), the authors Shuchang Zhou, Yuxin Wu et al. provide an algorithm named DoReFa that quantizes weights, activations and gradients during training.

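The core DoReFa formulas for k-bit weights and activations can be sketched in plain Python. This follows the paper's definitions for k > 1 (the paper handles 1-bit weights separately with a scaled sign); it is a forward-only illustration with hypothetical function names, not NNI's DoReFaQuantizer code, and it omits gradient quantization and the straight-through estimator used in training.

```python
import math


def quantize_k(x, k):
    """Uniform k-bit quantizer for x in [0, 1]: round(x * (2^k - 1)) / (2^k - 1)."""
    n = 2 ** k - 1
    return round(x * n) / n  # Python's round() is round-half-to-even


def dorefa_weights(weights, k):
    """k-bit DoReFa weight quantization for k > 1; outputs lie in [-1, 1]."""
    t = [math.tanh(w) for w in weights]
    max_abs = max((abs(v) for v in t), default=0.0)
    if max_abs == 0.0:  # all-zero tensor: guard against division by zero (sketch-specific)
        return [0.0 for _ in weights]
    # Squash to [0, 1], quantize, then map back to [-1, 1]
    return [2 * quantize_k(v / (2 * max_abs) + 0.5, k) - 1 for v in t]


def dorefa_activations(acts, k):
    """k-bit DoReFa activation quantization: clip to [0, 1], then quantize."""
    return [quantize_k(min(max(a, 0.0), 1.0), k) for a in acts]


print(dorefa_weights([-2.0, -0.1, 0.1, 2.0], k=2))       # approximately [-1.0, -0.333, 0.333, 1.0]
print(dorefa_activations([-0.5, 0.2, 0.45, 1.7], k=2))   # approximately [0.0, 0.333, 0.333, 1.0]
```

The `tanh` normalization bounds the weights before quantization, so a few large weights cannot stretch the range and waste quantization levels.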
### Usage
To use the DoReFa Quantizer, add the code below before your training code.

PyTorch code
```python
from nni.compressors.torch import DoReFaQuantizer

# model is your PyTorch model (a torch.nn.Module)
config_list = [{
    'quant_types': ['weight'],
    'quant_bits': 8,
    'op_types': ['default']
}]
quantizer = DoReFaQuantizer(model, config_list)
quantizer.compress()
```

See the example code for more information.

#### User configuration for DoReFa Quantizer
* **quant_types:** list of string

Types of quantization to apply. Currently supported: 'weight', 'input' and 'output'.

* **op_types:** list of string

Types of modules to quantize, e.g. 'Conv2d'.

* **op_names:** list of string

Names of modules to quantize, e.g. 'conv1'.

* **quant_bits:** int or dict of {str : int}

Bit length of quantization. As a dict, each key is a quantization type and each value is its bit length, e.g. {'weight': 8}. As an int, all quantization types share the same bit length.

***

## BNN Quantizer
In [Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1](https://arxiv.org/abs/1602.02830), the authors describe the method as follows:

>We introduce a method to train Binarized Neural Networks (BNNs) - neural networks with binary weights and activations at run-time. At training-time the binary weights and activations are used for computing the parameters gradients. During the forward pass, BNNs drastically reduce memory size and accesses, and replace most arithmetic operations with bit-wise operations, which is expected to substantially improve power-efficiency.

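In the paper's deterministic scheme, the forward pass replaces each value by its sign, the backward pass uses a straight-through estimator that cancels the gradient where |x| > 1, and the latent real-valued weights are clipped to [-1, 1] after each update. The plain-Python sketch below illustrates these three rules; the function names are hypothetical and this is not NNI's BNNQuantizer code.

```python
def binarize(x):
    """Deterministic binarization used in the forward pass: +1 if x >= 0, else -1."""
    return 1.0 if x >= 0 else -1.0


def binarize_grad(x, upstream_grad):
    """Straight-through estimator: pass the gradient through sign() only where |x| <= 1."""
    return upstream_grad if abs(x) <= 1.0 else 0.0


def clip_weight(w):
    """Latent real-valued weights are kept in [-1, 1] after each update."""
    return max(-1.0, min(1.0, w))


latent = [-0.3, 0.0, 0.7, 1.4]
print([binarize(w) for w in latent])            # [-1.0, 1.0, 1.0, 1.0]
print([binarize_grad(w, 0.5) for w in latent])  # [0.5, 0.5, 0.5, 0.0]
print([clip_weight(w) for w in latent])         # [-0.3, 0.0, 0.7, 1.0]
```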

### Usage

PyTorch code
```python
from nni.compression.torch import BNNQuantizer
# VGG_Cifar10 is the model class defined in the example script linked below
model = VGG_Cifar10(num_classes=10)

configure_list = [{
    'quant_types': ['weight'],
    'quant_bits': 1,
    'op_types': ['Conv2d', 'Linear'],
    'op_names': ['features.0', 'features.3', 'features.7', 'features.10', 'features.14', 'features.17', 'classifier.0', 'classifier.3']
}, {
    'quant_types': ['output'],
    'quant_bits': 1,
    'op_types': ['Hardtanh'],
    'op_names': ['features.6', 'features.9', 'features.13', 'features.16', 'features.20', 'classifier.2', 'classifier.5']
}]

quantizer = BNNQuantizer(model, configure_list)
model = quantizer.compress()
```

See the example [examples/model_compress/BNN_quantizer_cifar10.py](https://github.com/microsoft/nni/tree/master/examples/model_compress/BNN_quantizer_cifar10.py) for more information.

#### User configuration for BNN Quantizer
* **quant_types:** list of string

Types of quantization to apply. Currently supported: 'weight', 'input' and 'output'.

* **op_types:** list of string

Types of modules to quantize, e.g. 'Conv2d'.

* **op_names:** list of string

Names of modules to quantize, e.g. 'conv1'.

* **quant_bits:** int or dict of {str : int}

Bit length of quantization. As a dict, each key is a quantization type and each value is its bit length, e.g. {'weight': 8}. As an int, all quantization types share the same bit length.

### Experiment
We implemented one of the experiments in [Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1](https://arxiv.org/abs/1602.02830): quantizing **VGGNet** on CIFAR-10. Our results are as follows:

| Model         | Accuracy  | 
| ------------- | --------- | 
| VGGNet        | 86.93%    |


The experiment code can be found at [examples/model_compress/BNN_quantizer_cifar10.py](https://github.com/microsoft/nni/tree/master/examples/model_compress/BNN_quantizer_cifar10.py).