Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
e76f196d
Unverified
Commit
e76f196d
authored
Aug 31, 2021
by
chenbohua3
Committed by
GitHub
Aug 31, 2021
Browse files
add quantize_input to QAT quantizer (#4084)
parent
5fc73ba6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
79 additions
and
34 deletions
+79
-34
examples/model_compress/quantization/QAT_torch_quantizer.py
examples/model_compress/quantization/QAT_torch_quantizer.py
+34
-18
nni/algorithms/compression/pytorch/quantization/quantizers.py
...algorithms/compression/pytorch/quantization/quantizers.py
+37
-14
test/ut/sdk/test_compressor_torch.py
test/ut/sdk/test_compressor_torch.py
+8
-2
No files found.
examples/model_compress/quantization/QAT_torch_quantizer.py
View file @
e76f196d
...
...
@@ -7,7 +7,7 @@ import sys
sys
.
path
.
append
(
'../models'
)
from
mnist.naive
import
NaiveModel
def
train
(
model
,
quantizer
,
device
,
train_loader
,
optimizer
):
def
train
(
model
,
device
,
train_loader
,
optimizer
):
model
.
train
()
for
batch_idx
,
(
data
,
target
)
in
enumerate
(
train_loader
):
data
,
target
=
data
.
to
(
device
),
target
.
to
(
device
)
...
...
@@ -19,6 +19,7 @@ def train(model, quantizer, device, train_loader, optimizer):
if
batch_idx
%
100
==
0
:
print
(
'{:2.0f}% Loss {}'
.
format
(
100
*
batch_idx
/
len
(
train_loader
),
loss
.
item
()))
def
test
(
model
,
device
,
test_loader
):
model
.
eval
()
test_loss
=
0
...
...
@@ -47,30 +48,45 @@ def main():
datasets
.
MNIST
(
'data'
,
train
=
False
,
transform
=
trans
),
batch_size
=
1000
,
shuffle
=
True
)
model
=
NaiveModel
()
'''you can change this to DoReFaQuantizer to implement it
DoReFaQuantizer(configure_list).compress(model)
'''
# Two things should be kept in mind when set this configure_list:
# 1. When deploying model on backend, some layers will be fused into one layer. For example, the consecutive
# conv + bn + relu layers will be fused into one big layer. If we want to execute the big layer in quantization
# mode, we should tell the backend the quantization information of the input, output, and the weight tensor of
# the big layer, which correspond to conv's input, conv's weight and relu's output.
# 2. Same tensor should be quantized only once. For example, if a tensor is the output of layer A and the input
# of the layer B, you should configure either {'quant_types': ['output'], 'op_names': ['a']} or
# {'quant_types': ['input'], 'op_names': ['b']} in the configure_list.
configure_list
=
[{
'quant_types'
:
[
'weight'
],
'quant_bits'
:
{
'weight'
:
8
,
},
# you can just use `int` here because all `quan_types` share same bits length, see config for `ReLu6` below.
'op_types'
:[
'Conv2d'
,
'Linear'
]
},
{
'quant_types'
:
[
'output'
],
'quant_bits'
:
8
,
'quant_start_step'
:
1000
,
'op_types'
:[
'ReLU6'
]
}]
'quant_types'
:
[
'weight'
,
'input'
],
'quant_bits'
:
{
'weight'
:
8
,
'input'
:
8
},
'op_names'
:
[
'conv1'
,
'conv2'
]
},
{
'quant_types'
:
[
'output'
],
'quant_bits'
:
{
'output'
:
8
,
},
'op_names'
:
[
'relu1'
,
'relu2'
]
},
{
'quant_types'
:
[
'output'
,
'weight'
,
'input'
],
'quant_bits'
:
{
'output'
:
8
,
'weight'
:
8
,
'input'
:
8
},
'op_names'
:
[
'fc1'
],
},
{
'quant_types'
:
[
'output'
,
'weight'
,
'input'
],
'quant_bits'
:
{
'output'
:
8
,
'weight'
:
8
,
'input'
:
8
},
'op_names'
:
[
'fc2'
],
}]
model
=
NaiveModel
().
to
(
device
)
dummy_input
=
torch
.
randn
(
1
,
1
,
28
,
28
).
to
(
device
)
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
,
momentum
=
0.5
)
quantizer
=
QAT_Quantizer
(
model
,
configure_list
,
optimizer
)
# To enable batch normalization folding in the training process, you should
# pass dummy_input to the QAT_Quantizer.
quantizer
=
QAT_Quantizer
(
model
,
configure_list
,
optimizer
,
dummy_input
=
dummy_input
)
quantizer
.
compress
()
model
.
to
(
device
)
for
epoch
in
range
(
40
):
print
(
'# Epoch {} #'
.
format
(
epoch
))
train
(
model
,
quantizer
,
device
,
train_loader
,
optimizer
)
train
(
model
,
device
,
train_loader
,
optimizer
)
test
(
model
,
device
,
test_loader
)
model_path
=
"mnist_model.pth"
...
...
nni/algorithms/compression/pytorch/quantization/quantizers.py
View file @
e76f196d
...
...
@@ -380,8 +380,10 @@ class QAT_Quantizer(Quantizer):
layer
.
module
.
register_buffer
(
'ema_decay'
,
torch
.
Tensor
([
0.99
]))
if
"weight"
in
config
.
get
(
"quant_types"
,
[]):
layer
.
module
.
register_buffer
(
'weight_bits'
,
torch
.
zeros
(
1
))
if
"input"
in
config
.
get
(
"quant_types"
,
[]):
layer
.
module
.
register_buffer
(
'tracked_min_input'
,
torch
.
zeros
(
1
))
layer
.
module
.
register_buffer
(
'tracked_max_input'
,
torch
.
zeros
(
1
))
layer
.
module
.
register_buffer
(
'input_bits'
,
torch
.
zeros
(
1
))
if
"output"
in
config
.
get
(
"quant_types"
,
[]):
layer
.
module
.
register_buffer
(
'output_bits'
,
torch
.
zeros
(
1
))
layer
.
module
.
register_buffer
(
'tracked_min_output'
,
torch
.
zeros
(
1
))
...
...
@@ -394,7 +396,7 @@ class QAT_Quantizer(Quantizer):
"""
del_attr_list
=
[
'old_weight'
,
'old_bias'
,
'ema_decay'
,
'tracked_min_output'
,
'tracked_max_output'
,
'tracked_min_input'
,
'tracked_max_input'
,
'scale'
,
'zero_point'
,
'weight_bits'
,
'output_bits'
,
'BN_FOLD_TAG'
]
'output_bits'
,
'BN_FOLD_TAG'
,
'input_bits'
]
for
attr
in
del_attr_list
:
if
hasattr
(
module
,
attr
):
delattr
(
module
,
attr
)
...
...
@@ -409,8 +411,9 @@ class QAT_Quantizer(Quantizer):
List of configurations
"""
schema
=
QuantizerSchema
([{
Optional
(
'quant_types'
):
Schema
([
lambda
x
:
x
in
[
'weight'
,
'output'
]]),
Optional
(
'quant_types'
):
Schema
([
lambda
x
:
x
in
[
'weight'
,
'output'
,
'input'
]]),
Optional
(
'quant_bits'
):
Or
(
And
(
int
,
lambda
n
:
0
<
n
<
32
),
Schema
({
Optional
(
'input'
):
And
(
int
,
lambda
n
:
0
<
n
<
32
),
Optional
(
'weight'
):
And
(
int
,
lambda
n
:
0
<
n
<
32
),
Optional
(
'output'
):
And
(
int
,
lambda
n
:
0
<
n
<
32
),
})),
...
...
@@ -472,25 +475,17 @@ class QAT_Quantizer(Quantizer):
config
=
wrapper
.
config
module
=
wrapper
.
module
weight
=
module
.
weight
input
=
kwargs
[
'input_tensor'
]
# pylint: disable=redefined-builtin
weight_bits
=
get_bits_length
(
config
,
'weight'
)
quant_start_step
=
config
.
get
(
'quant_start_step'
,
0
)
assert
weight_bits
>=
1
,
"quant bits length should be at least 1"
# we dont update weight in evaluation stage
if
quant_start_step
>
self
.
bound_model
.
steps
:
module
.
tracked_min_input
,
module
.
tracked_max_input
=
torch
.
min
(
input
),
torch
.
max
(
input
)
return
weight
if
not
wrapper
.
training
:
return
weight
current_min
,
current_max
=
torch
.
min
(
input
),
torch
.
max
(
input
)
module
.
tracked_min_input
=
update_ema
(
module
.
tracked_min_input
,
current_min
,
module
.
ema_decay
)
module
.
tracked_max_input
=
update_ema
(
module
.
tracked_max_input
,
current_max
,
module
.
ema_decay
)
# quantize weight
rmin
,
rmax
=
torch
.
min
(
weight
),
torch
.
max
(
weight
)
module
.
scale
,
module
.
zero_point
=
update_quantization_param
(
weight_bits
,
rmin
,
rmax
)
...
...
@@ -500,6 +495,31 @@ class QAT_Quantizer(Quantizer):
wrapper
.
module
.
weight
=
weight
return
weight
def
quantize_input
(
self
,
inputs
,
wrapper
,
**
kwargs
):
config
=
wrapper
.
config
module
=
wrapper
.
module
input_bits
=
get_bits_length
(
config
,
'input'
)
module
.
input_bits
=
torch
.
Tensor
([
input_bits
])
quant_start_step
=
config
.
get
(
'quant_start_step'
,
0
)
assert
input_bits
>=
1
,
"quant bits length should be at least 1"
if
quant_start_step
>
self
.
bound_model
.
steps
:
module
.
tracked_min_input
,
module
.
tracked_max_input
=
torch
.
min
(
inputs
),
torch
.
max
(
inputs
)
return
inputs
# we dont update output quantization parameters in evaluation stage
if
wrapper
.
training
:
current_min
,
current_max
=
torch
.
min
(
inputs
),
torch
.
max
(
inputs
)
module
.
tracked_min_input
=
update_ema
(
module
.
tracked_min_input
,
current_min
,
module
.
ema_decay
)
module
.
tracked_max_input
=
update_ema
(
module
.
tracked_max_input
,
current_max
,
module
.
ema_decay
)
module
.
scale
,
module
.
zero_point
=
update_quantization_param
(
input_bits
,
module
.
tracked_min_input
,
module
.
tracked_max_input
)
inp
=
self
.
_quantize
(
input_bits
,
module
,
inputs
)
inp
=
self
.
_dequantize
(
module
,
inp
)
return
inp
def
quantize_output
(
self
,
output
,
wrapper
,
**
kwargs
):
config
=
wrapper
.
config
module
=
wrapper
.
module
...
...
@@ -519,8 +539,9 @@ class QAT_Quantizer(Quantizer):
module
.
ema_decay
)
module
.
tracked_max_output
=
update_ema
(
module
.
tracked_max_output
,
current_max
,
module
.
ema_decay
)
module
.
scale
,
module
.
zero_point
=
update_quantization_param
(
output_bits
,
module
.
tracked_min_output
,
module
.
tracked_max_output
)
module
.
scale
,
module
.
zero_point
=
update_quantization_param
(
output_bits
,
module
.
tracked_min_output
,
module
.
tracked_max_output
)
out
=
self
.
_quantize
(
output_bits
,
module
,
output
)
out
=
self
.
_dequantize
(
module
,
out
)
return
out
...
...
@@ -556,8 +577,6 @@ class QAT_Quantizer(Quantizer):
calibration_config
[
name
]
=
{}
if
hasattr
(
module
,
'weight_bits'
):
calibration_config
[
name
][
'weight_bits'
]
=
int
(
module
.
weight_bits
)
calibration_config
[
name
][
'tracked_min_input'
]
=
float
(
module
.
tracked_min_input
)
calibration_config
[
name
][
'tracked_max_input'
]
=
float
(
module
.
tracked_max_input
)
# Recover weight/bias for batch normalization folding
actual_weight
=
getattr
(
module
,
'old_weight'
,
None
)
...
...
@@ -573,6 +592,10 @@ class QAT_Quantizer(Quantizer):
module
.
register_parameter
(
'bias'
,
actual_bias
)
else
:
setattr
(
module
,
'bias'
,
None
)
if
hasattr
(
module
,
'input_bit'
):
calibration_config
[
name
][
'input_bits'
]
=
int
(
module
.
input_bit
)
calibration_config
[
name
][
'tracked_min_input'
]
=
float
(
module
.
tracked_min_input
)
calibration_config
[
name
][
'tracked_max_input'
]
=
float
(
module
.
tracked_max_input
)
if
hasattr
(
module
,
'output_bits'
):
calibration_config
[
name
][
'output_bits'
]
=
int
(
module
.
output_bits
)
...
...
test/ut/sdk/test_compressor_torch.py
View file @
e76f196d
...
...
@@ -308,7 +308,7 @@ class CompressorTestCase(TestCase):
def
test_torch_QAT_quantizer
(
self
):
model
=
TorchModel
()
config_list
=
[{
'quant_types'
:
[
'weight'
],
'quant_types'
:
[
'weight'
,
'input'
],
'quant_bits'
:
8
,
'op_types'
:
[
'Conv2d'
,
'Linear'
]
},
{
...
...
@@ -326,18 +326,24 @@ class CompressorTestCase(TestCase):
# test quantize
# range not including 0
eps
=
1e-7
input
=
torch
.
tensor
([[
0
,
4
],
[
2
,
1
]])
.
float
()
input
=
torch
.
tensor
([[
1
,
4
],
[
2
,
1
]])
weight
=
torch
.
tensor
([[
1
,
2
],
[
3
,
5
]]).
float
()
model
.
conv2
.
module
.
weight
.
data
=
weight
quantizer
.
quantize_weight
(
model
.
conv2
,
input_tensor
=
input
)
assert
math
.
isclose
(
model
.
conv2
.
module
.
scale
,
5
/
255
,
abs_tol
=
eps
)
assert
model
.
conv2
.
module
.
zero_point
==
0
quantizer
.
quantize_input
(
input
,
model
.
conv2
)
self
.
assertTrue
(
torch
.
allclose
(
model
.
conv2
.
module
.
scale
,
torch
.
tensor
([
0.04
/
255
])))
self
.
assertTrue
(
torch
.
equal
(
model
.
conv2
.
module
.
zero_point
,
torch
.
tensor
([
0.
])))
# range including 0
weight
=
torch
.
tensor
([[
-
1
,
2
],
[
3
,
5
]]).
float
()
model
.
conv2
.
module
.
weight
=
weight
quantizer
.
quantize_weight
(
model
.
conv2
,
input_tensor
=
input
)
assert
math
.
isclose
(
model
.
conv2
.
module
.
scale
,
6
/
255
,
abs_tol
=
eps
)
assert
model
.
conv2
.
module
.
zero_point
in
(
42
,
43
)
quantizer
.
quantize_input
(
input
,
model
.
conv2
)
self
.
assertTrue
(
torch
.
allclose
(
model
.
conv2
.
module
.
scale
,
torch
.
tensor
([
0.0796
/
255
])))
self
.
assertTrue
(
torch
.
equal
(
model
.
conv2
.
module
.
zero_point
,
torch
.
tensor
([
0.
])))
# test value of weight and bias after quantization
weight
=
torch
.
tensor
([[
1.1287
,
2.3456
],
[
3.7814
,
5.9723
]])
weight_valid
=
torch
.
tensor
([[
1.1242
,
2.3421
],
[
3.7707
,
5.9723
]])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment