OpenDAS / nni · commit 06a98372

add new QAT_quantization (#1732)

Authored Nov 25, 2019 by Cjkkkk; committed Nov 25, 2019 by chicm-ms.
Parent: a03570a0
Showing 4 changed files with 335 additions and 27 deletions (+335 −27).
* docs/en_US/Compressor/Quantizer.md (+25 −11)
* examples/model_compress/QAT_torch_quantizer.py (+98 −0)
* src/sdk/pynni/nni/compression/torch/builtin_quantizers.py (+183 −13)
* src/sdk/pynni/nni/compression/torch/compressor.py (+29 −3)
docs/en_US/Compressor/Quantizer.md

@@ -28,17 +28,23 @@ In [Quantization and Training of Neural Networks for Efficient Integer-Arithmeti
### Usage
You can quantize your model to 8 bits with the code below before your training code.

Tensorflow code
```python
from nni.compressors.tensorflow import QAT_Quantizer
config_list = [{ 'q_bits': 8, 'op_types': ['default'] }]
quantizer = QAT_Quantizer(tf.get_default_graph(), config_list)
quantizer.compress()
```
PyTorch code
```python
from nni.compressors.torch import QAT_Quantizer
model = Mnist()

config_list = [{
    'quant_types': ['weight'],
    'quant_bits': {
        'weight': 8,
    },  # you can just use `int` here because all `quant_types` share the same bit length, see the config for `ReLU6` below.
    'op_types': ['Conv2d', 'Linear']
}, {
    'quant_types': ['output'],
    'quant_bits': 8,
    'quant_start_step': 7000,
    'op_types': ['ReLU6']
}]
quantizer = QAT_Quantizer(model, config_list)
quantizer.compress()
```
@@ -46,9 +52,17 @@ quantizer.compress()
You can view the example for more information.

#### User configuration for QAT Quantizer
* **quant_types:** list of string
type of quantization you want to apply; currently 'weight', 'input' and 'output' are supported
* **quant_bits:** int or dict of {str : int}
bit length of quantization; the key is the quantization type and the value is the bit length, e.g. {'weight': 8}; when an int is given, all quantization types share the same bit length
* **quant_start_step:** int
disable quantization until the model has been run for a certain number of steps; this allows the network to reach a more stable state in which activation quantization ranges do not exclude a significant fraction of values; the default is 0
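
As a combined sketch of these keys (the layer types, bit widths and step count below are illustrative choices, not recommendations), a configuration mixing the dict and int forms of `quant_bits` could look like this:

```python
config_list = [{
    'quant_types': ['weight'],
    'quant_bits': {'weight': 8},   # dict form: one bit length per quantization type
    'op_types': ['Conv2d', 'Linear']
}, {
    'quant_types': ['output'],
    'quant_bits': 8,               # int form: every listed quant_type shares 8 bits
    'quant_start_step': 7000,      # keep quantization disabled for the first 7000 steps
    'op_types': ['ReLU6']
}]
```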
### Note
Batch normalization folding is currently not supported.

***

## DoReFa Quantizer
examples/model_compress/QAT_torch_quantizer.py (new file, 0 → 100644)
```python
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.compression.torch import QAT_Quantizer


class Mnist(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
        self.fc2 = torch.nn.Linear(500, 10)
        self.relu1 = torch.nn.ReLU6()
        self.relu2 = torch.nn.ReLU6()
        self.relu3 = torch.nn.ReLU6()

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = self.relu2(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def train(model, quantizer, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        quantizer.step()
        if batch_idx % 100 == 0:
            print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('Loss: {} Accuracy: {}%\n'.format(test_loss, 100 * correct / len(test_loader.dataset)))


def main():
    torch.manual_seed(0)
    device = torch.device('cpu')
    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, transform=trans),
        batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=trans),
        batch_size=1000, shuffle=True)

    model = Mnist()
    '''you can change this to DoReFaQuantizer to implement it
    DoReFaQuantizer(configure_list).compress(model)
    '''
    configure_list = [{
        'quant_types': ['weight'],
        'quant_bits': {
            'weight': 8,
        },  # you can just use `int` here because all `quant_types` share the same bit length, see the config for `ReLU6` below.
        'op_types': ['Conv2d', 'Linear']
    }, {
        'quant_types': ['output'],
        'quant_bits': 8,
        'quant_start_step': 7000,
        'op_types': ['ReLU6']
    }]
    quantizer = QAT_Quantizer(model, configure_list)
    quantizer.compress()

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    for epoch in range(10):
        print('# Epoch {} #'.format(epoch))
        train(model, quantizer, device, train_loader, optimizer)
        test(model, device, test_loader)


if __name__ == '__main__':
    main()
```
src/sdk/pynni/nni/compression/torch/builtin_quantizers.py

@@ -22,6 +22,80 @@ class NaiveQuantizer(Quantizer):
```python
        return weight.div(scale).type(torch.int8).type(orig_type).mul(scale)


def update_ema(biased_ema, value, decay, step):
    """
    calculate biased stat and unbiased stat in each step using exponential moving average method

    Parameters
    ----------
    biased_ema : float
        previous stat value
    value : float
        current stat value
    decay : float
        the weight of previous stat value, larger means smoother curve
    step : int
        current step

    Returns
    -------
    float, float
    """
    biased_ema = biased_ema * decay + (1 - decay) * value
    unbiased_ema = biased_ema / (1 - decay ** step)  # Bias correction
    return biased_ema, unbiased_ema
```
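
As a rough illustration of how this helper behaves (the observed values and decay below are made up), each call blends the previous biased estimate with the new observation and then divides out the zero-initialisation bias, in the style of Adam's bias correction:

```python
# Hypothetical walk-through of update_ema with decay=0.99 and a zero-initialised EMA.
biased = 0.0
for step, observed_max in enumerate([1.0, 1.2, 0.9], start=1):
    biased, unbiased = update_ema(biased, observed_max, 0.99, step)
    # On step 1: biased = 0.01, but unbiased = 0.01 / (1 - 0.99**1) = 1.0,
    # so the unbiased estimate tracks the observed range from the very first step.
    print(step, biased, unbiased)
```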
```python
def update_quantization_param(bits, rmin, rmax):
    """
    calculate the `zero_point` and `scale`.

    Parameters
    ----------
    bits : int
        quantization bits length
    rmin : float
        min value of real value
    rmax : float
        max value of real value

    Returns
    -------
    float, float
    """
    # extend the [min, max] interval to ensure that it contains 0.
    # Otherwise, we would not meet the requirement that 0 be an exactly
    # representable value.
    rmin = min(rmin, 0)
    rmax = max(rmax, 0)

    # the min and max quantized values, as floating-point values
    qmin = 0
    qmax = (1 << bits) - 1
    # First determine the scale.
    scale = (rmax - rmin) / (qmax - qmin)

    # Zero-point computation.
    initial_zero_point = qmin - rmin / scale

    # Now we need to nudge the zero point to be an integer
    nudged_zero_point = 0
    if initial_zero_point < qmin:
        nudged_zero_point = qmin
    elif initial_zero_point > qmax:
        nudged_zero_point = qmax
    else:
        nudged_zero_point = torch.round(initial_zero_point)

    return scale, nudged_zero_point
```
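
For a quick sanity check (an illustrative calculation, not taken from the source), 8 bits and an observed range of [-1.0, 2.0] map onto the quantized range [0, 255]:

```python
import torch

# scale = (2.0 - (-1.0)) / 255 ≈ 0.01176
# initial_zero_point = 0 - (-1.0) / scale = 85.0, which already lies in [0, 255],
# so the nudged zero point is round(85.0) = 85.
scale, zero_point = update_quantization_param(8, torch.tensor(-1.0), torch.tensor(2.0))
```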
```python
def get_bits_length(config, quant_type):
    if isinstance(config["quant_bits"], int):
        return config["quant_bits"]
    else:
        return config["quant_bits"].get(quant_type)


class QAT_Quantizer(Quantizer):
    """Quantizer using the DoReFa scheme, as defined in:
    Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
```
@@ -29,23 +103,119 @@ class QAT_Quantizer(Quantizer):
```python
    """

    def __init__(self, model, config_list):
        """
        Parameters
        ----------
        layer : LayerInfo
            the layer to quantize
        config_list : list of dict
            list of configurations for quantization
            supported keys for dict:
                - quant_types : list of string
                    type of quantization you want to apply, currently support 'weight', 'input', 'output'
                - quant_bits : int or dict of {str : int}
                    bits length of quantization, key is the quantization type, value is the length, eg. {'weight': 8},
                    when the type is int, all quantization types share same bits length
                - quant_start_step : int
                    disable quantization until model are run by certain number of steps, this allows the network to enter a more stable
                    state where activation quantization ranges do not exclude a significant fraction of values, default value is 0
                - op_types : list of string
                    types of nn.module you want to apply quantization, eg. 'Conv2d'
        """
        super().__init__(model, config_list)
        self.steps = 1
        modules_to_compress = self.detect_modules_to_compress()
        for layer, config in modules_to_compress:
            layer.module.register_buffer("zero_point", None)
            layer.module.register_buffer("scale", None)
            if "output" in config.get("quant_types", []):
                layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
                layer.module.register_buffer('tracked_min_biased', torch.zeros(1))
                layer.module.register_buffer('tracked_min', torch.zeros(1))
                layer.module.register_buffer('tracked_max_biased', torch.zeros(1))
                layer.module.register_buffer('tracked_max', torch.zeros(1))
```
```python
    def _quantize(self, bits, op, real_val):
        """
        quantize real value.

        Parameters
        ----------
        bits : int
            quantization bits length
        op : torch.nn.module
            target module
        real_val : float
            real value to be quantized

        Returns
        -------
        float
        """
        transformed_val = op.zero_point + real_val / op.scale
        qmin = 0
        qmax = (1 << bits) - 1
        clamped_val = torch.clamp(transformed_val, qmin, qmax)
        quantized_val = torch.round(clamped_val)
        return quantized_val

    def _dequantize(self, op, quantized_val):
        """
        dequantize quantized value.
        Because we simulate quantization in training process, all the computations still happen as float point computations, which means we
        first quantize tensors then dequantize them. For more details, please refer to the paper.

        Parameters
        ----------
        op : torch.nn.Module
            target module
        quantized_val : float
            quantized value to be dequantized

        Returns
        -------
        float
        """
        real_val = op.scale * (quantized_val - op.zero_point)
        return real_val
```
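
To make the simulated ("fake") quantization concrete, here is an illustrative round trip that mirrors the arithmetic of `_quantize` and `_dequantize` above (the tensor values are made up; in the quantizer itself `scale` and `zero_point` live on the wrapped module):

```python
import torch

scale, zero_point = update_quantization_param(8, torch.tensor(-1.0), torch.tensor(2.0))

x = torch.tensor([-1.0, 0.0, 0.7, 2.0])
q = torch.clamp(torch.round(zero_point + x / scale), 0, 255)  # what _quantize produces
x_hat = scale * (q - zero_point)                              # what _dequantize returns
# x_hat ≈ x up to one quantization step (scale ≈ 0.0118): training sees the rounding
# error while all computation stays in floating point.
```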
```python
    def quantize_weight(self, weight, config, op, **kwargs):
        weight_bits = get_bits_length(config, 'weight')
        quant_start_step = config.get('quant_start_step', 0)
        assert weight_bits >= 1, "quant bits length should be at least 1"
        if quant_start_step > self.steps:
            return weight

        rmin, rmax = torch.min(weight), torch.max(weight)
        op.scale, op.zero_point = update_quantization_param(weight_bits, rmin, rmax)
        out = self._quantize(weight_bits, op, weight)
        out = self._dequantize(op, out)
        return out
```
```python
    def quantize_output(self, output, config, op, **kwargs):
        output_bits = get_bits_length(config, 'output')
        quant_start_step = config.get('quant_start_step', 0)
        assert output_bits >= 1, "quant bits length should be at least 1"
        if quant_start_step > self.steps:
            return output

        current_min, current_max = torch.min(output), torch.max(output)
        op.tracked_min_biased, op.tracked_min = update_ema(op.tracked_min_biased, current_min, op.ema_decay, self.steps)
        op.tracked_max_biased, op.tracked_max = update_ema(op.tracked_max_biased, current_max, op.ema_decay, self.steps)
        op.scale, op.zero_point = update_quantization_param(output_bits, op.tracked_min, op.tracked_max)
        out = self._quantize(output_bits, op, output)
        out = self._dequantize(op, out)
        return out
```
```python
    def fold_bn(self, config, **kwargs):
        # TODO simulate folded weight
        pass

    def step(self):
        """
        override `compressor` `step` method, quantization only happens after certain number of steps
        """
        self.steps += 1


class DoReFaQuantizer(Quantizer):
```
src/sdk/pynni/nni/compression/torch/compressor.py

@@ -304,6 +304,12 @@ class Quantizer(Compressor):
```python
        assert layer._forward is None, 'Each model can only be compressed once'
        assert "quant_types" in config, 'must provide quant_types in config'
        assert isinstance(config["quant_types"], list), 'quant_types must be list type'
        assert "quant_bits" in config, 'must provide quant_bits in config'
        assert isinstance(config["quant_bits"], int) or isinstance(config["quant_bits"], dict), 'quant_bits must be dict type or int type'
        if isinstance(config["quant_bits"], dict):
            for quant_type in config["quant_types"]:
                assert quant_type in config["quant_bits"], 'bits length for %s must be specified in quant_bits dict' % quant_type

        if 'weight' in config["quant_types"]:
            if not _check_weight(layer.module):
```
@@ -312,7 +318,7 @@ class Quantizer(Compressor):
```python
        def new_forward(*inputs):
            if 'input' in config["quant_types"]:
                inputs = straight_through_quantize_input.apply(inputs, self, config, layer)

            if 'weight' in config["quant_types"] and _check_weight(layer.module):
                weight = layer.module.weight.data
```
@@ -324,12 +330,32 @@ class Quantizer(Compressor):
```python
            result = layer._forward(*inputs)

            if 'output' in config["quant_types"]:
                result = straight_through_quantize_output.apply(result, self, config, layer)
            return result

        layer.module.forward = new_forward
```
```python
class straight_through_quantize_output(torch.autograd.Function):
    @staticmethod
    def forward(ctx, output, quantizer, config, layer):
        return quantizer.quantize_output(output, config, op=layer.module, op_type=layer.type, op_name=layer.name)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator
        return grad_output, None, None, None


class straight_through_quantize_input(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, quantizer, config, layer):
        return quantizer.quantize_input(inputs, config, op=layer.module, op_type=layer.type, op_name=layer.name)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator
        return grad_output, None, None, None
```
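
Both autograd Functions implement the straight-through estimator: the forward pass applies the non-differentiable quantize/dequantize step, while the backward pass returns the incoming gradient unchanged for the tensor argument and `None` for the quantizer, config and layer arguments. A minimal self-contained sketch of the same idiom (a toy rounding op, not part of this codebase) looks like this:

```python
import torch

class _RoundSTE(torch.autograd.Function):
    """Toy example: round in the forward pass, pass gradients straight through in backward."""
    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: treat the forward op as the identity for gradients.
        return grad_output

x = torch.tensor([0.2, 1.7], requires_grad=True)
_RoundSTE.apply(x).sum().backward()
print(x.grad)  # tensor([1., 1.]) even though round() has zero gradient almost everywhere
```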
```python
def _check_weight(module):
    try:
        return isinstance(module.weight, torch.nn.Parameter) and isinstance(module.weight.data, torch.Tensor)
```