OpenDAS / nni · Commits

Unverified commit e76f196d, authored Aug 31, 2021 by chenbohua3, committed by GitHub Aug 31, 2021
parent 5fc73ba6

add quantize_input to QAT quantizer (#4084)

Showing 3 changed files with 79 additions and 34 deletions:
  examples/model_compress/quantization/QAT_torch_quantizer.py    (+34, -18)
  nni/algorithms/compression/pytorch/quantization/quantizers.py  (+37, -14)
  test/ut/sdk/test_compressor_torch.py                           (+8, -2)
examples/model_compress/quantization/QAT_torch_quantizer.py
@@ -7,7 +7,7 @@ import sys
 sys.path.append('../models')
 from mnist.naive import NaiveModel
 
-def train(model, quantizer, device, train_loader, optimizer):
+def train(model, device, train_loader, optimizer):
     model.train()
     for batch_idx, (data, target) in enumerate(train_loader):
         data, target = data.to(device), target.to(device)
@@ -19,6 +19,7 @@ def train(model, quantizer, device, train_loader, optimizer):
         if batch_idx % 100 == 0:
             print('{:2.0f}%  Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))
 
 def test(model, device, test_loader):
     model.eval()
     test_loss = 0
@@ -47,30 +48,45 @@ def main():
         datasets.MNIST('data', train=False, transform=trans),
         batch_size=1000, shuffle=True)
 
-    model = NaiveModel()
     '''you can change this to DoReFaQuantizer to implement it
         DoReFaQuantizer(configure_list).compress(model)
     '''
+    # Two things should be kept in mind when setting this configure_list:
+    # 1. When deploying the model on a backend, some layers will be fused into one layer. For example, consecutive
+    # conv + bn + relu layers will be fused into one big layer. If we want to execute the big layer in quantization
+    # mode, we should tell the backend the quantization information of the input, output, and weight tensors of
+    # the big layer, which correspond to conv's input, conv's weight, and relu's output.
+    # 2. The same tensor should be quantized only once. For example, if a tensor is the output of layer A and the input
+    # of layer B, you should configure either {'quant_types': ['output'], 'op_names': ['a']} or
+    # {'quant_types': ['input'], 'op_names': ['b']} in the configure_list, but not both.
     configure_list = [{
-        'quant_types': ['weight'],
-        'quant_bits': {
-            'weight': 8,
-        }, # you can just use `int` here because all `quant_types` share the same bit length, see config for `ReLU6` below.
-        'op_types':['Conv2d', 'Linear']
-    }, {
-        'quant_types': ['output'],
-        'quant_bits': 8,
-        'quant_start_step': 1000,
-        'op_types':['ReLU6']
-    }]
+        'quant_types': ['weight', 'input'],
+        'quant_bits': {'weight': 8, 'input': 8},
+        'op_names': ['conv1', 'conv2']
+    }, {
+        'quant_types': ['output'],
+        'quant_bits': {'output': 8, },
+        'op_names': ['relu1', 'relu2']
+    }, {
+        'quant_types': ['output', 'weight', 'input'],
+        'quant_bits': {'output': 8, 'weight': 8, 'input': 8},
+        'op_names': ['fc1'],
+    }, {
+        'quant_types': ['output', 'weight', 'input'],
+        'quant_bits': {'output': 8, 'weight': 8, 'input': 8},
+        'op_names': ['fc2'],
+    }]
 
+    model = NaiveModel().to(device)
+    dummy_input = torch.randn(1, 1, 28, 28).to(device)
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
-    quantizer = QAT_Quantizer(model, configure_list, optimizer)
+    # To enable batch normalization folding in the training process, you should
+    # pass dummy_input to the QAT_Quantizer.
+    quantizer = QAT_Quantizer(model, configure_list, optimizer, dummy_input=dummy_input)
     quantizer.compress()
-    model.to(device)
 
     for epoch in range(40):
         print('# Epoch {} #'.format(epoch))
-        train(model, quantizer, device, train_loader, optimizer)
+        train(model, device, train_loader, optimizer)
         test(model, device, test_loader)
 
     model_path = "mnist_model.pth"
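The second rule in the comment block above is worth spelling out: a tensor that is both layer A's output and layer B's input must appear in only one config entry, otherwise it is fake-quantized twice. A minimal sketch of the two alternatives (the op names 'a' and 'b' are hypothetical placeholders, as in the comment):

    # Either quantize the tensor where it leaves layer a ...
    config_output_side = [{'quant_types': ['output'], 'quant_bits': 8, 'op_names': ['a']}]
    # ... or where it enters layer b -- pick exactly one of the two.
    config_input_side = [{'quant_types': ['input'], 'quant_bits': 8, 'op_names': ['b']}]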
nni/algorithms/compression/pytorch/quantization/quantizers.py
@@ -380,8 +380,10 @@ class QAT_Quantizer(Quantizer):
         layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
         if "weight" in config.get("quant_types", []):
             layer.module.register_buffer('weight_bits', torch.zeros(1))
-            layer.module.register_buffer('tracked_min_input', torch.zeros(1))
-            layer.module.register_buffer('tracked_max_input', torch.zeros(1))
+        if "input" in config.get("quant_types", []):
+            layer.module.register_buffer('tracked_min_input', torch.zeros(1))
+            layer.module.register_buffer('tracked_max_input', torch.zeros(1))
+            layer.module.register_buffer('input_bits', torch.zeros(1))
         if "output" in config.get("quant_types", []):
             layer.module.register_buffer('output_bits', torch.zeros(1))
             layer.module.register_buffer('tracked_min_output', torch.zeros(1))
@@ -394,7 +396,7 @@ class QAT_Quantizer(Quantizer):
         """
         del_attr_list = ['old_weight', 'old_bias', 'ema_decay', 'tracked_min_output', 'tracked_max_output',
                          'tracked_min_input', 'tracked_max_input', 'scale', 'zero_point', 'weight_bits',
-                         'output_bits', 'BN_FOLD_TAG']
+                         'output_bits', 'BN_FOLD_TAG', 'input_bits']
         for attr in del_attr_list:
             if hasattr(module, attr):
                 delattr(module, attr)
@@ -409,8 +411,9 @@ class QAT_Quantizer(Quantizer):
             List of configurations
         """
         schema = QuantizerSchema([{
-            Optional('quant_types'): Schema([lambda x: x in ['weight', 'output']]),
+            Optional('quant_types'): Schema([lambda x: x in ['weight', 'output', 'input']]),
             Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), Schema({
+                Optional('input'): And(int, lambda n: 0 < n < 32),
                 Optional('weight'): And(int, lambda n: 0 < n < 32),
                 Optional('output'): And(int, lambda n: 0 < n < 32),
             })),
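With 'input' added to the schema, a per-type quant_bits dict may now carry an 'input' key (1-31 bits), alongside the existing bare-int form. An entry like the following should validate under the new schema (the op name 'conv1' is borrowed from the example script above):

    {
        'quant_types': ['weight', 'input'],
        'quant_bits': {'weight': 8, 'input': 8},   # or a bare int shared by all quant_types
        'op_names': ['conv1']
    }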
@@ -472,25 +475,17 @@ class QAT_Quantizer(Quantizer):
         config = wrapper.config
         module = wrapper.module
         weight = module.weight
-        input = kwargs['input_tensor']  # pylint: disable=redefined-builtin
         weight_bits = get_bits_length(config, 'weight')
         quant_start_step = config.get('quant_start_step', 0)
         assert weight_bits >= 1, "quant bits length should be at least 1"
 
         # we dont update weight in evaluation stage
         if quant_start_step > self.bound_model.steps:
-            module.tracked_min_input, module.tracked_max_input = torch.min(input), torch.max(input)
             return weight
 
         if not wrapper.training:
             return weight
 
-        current_min, current_max = torch.min(input), torch.max(input)
-        module.tracked_min_input = update_ema(module.tracked_min_input, current_min, module.ema_decay)
-        module.tracked_max_input = update_ema(module.tracked_max_input, current_max, module.ema_decay)
-
         # quantize weight
         rmin, rmax = torch.min(weight), torch.max(weight)
         module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax)
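The net effect of this hunk: quantize_weight no longer reads kwargs['input_tensor'] or maintains tracked_min_input/tracked_max_input; input-range tracking now belongs to the new quantize_input callback below. The unit test changes at the bottom of this page reflect the split (calls as in the test diff; x stands for the input batch):

    quantizer.quantize_weight(model.conv2, input_tensor=x)  # weight range only
    quantizer.quantize_input(x, model.conv2)                # input range, tracked via EMA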
@@ -500,6 +495,31 @@ class QAT_Quantizer(Quantizer):
         wrapper.module.weight = weight
         return weight
 
+    def quantize_input(self, inputs, wrapper, **kwargs):
+        config = wrapper.config
+        module = wrapper.module
+        input_bits = get_bits_length(config, 'input')
+        module.input_bits = torch.Tensor([input_bits])
+        quant_start_step = config.get('quant_start_step', 0)
+        assert input_bits >= 1, "quant bits length should be at least 1"
+
+        if quant_start_step > self.bound_model.steps:
+            module.tracked_min_input, module.tracked_max_input = torch.min(inputs), torch.max(inputs)
+            return inputs
+
+        # we dont update output quantization parameters in evaluation stage
+        if wrapper.training:
+            current_min, current_max = torch.min(inputs), torch.max(inputs)
+            module.tracked_min_input = update_ema(module.tracked_min_input, current_min, module.ema_decay)
+            module.tracked_max_input = update_ema(module.tracked_max_input, current_max, module.ema_decay)
+        module.scale, module.zero_point = update_quantization_param(input_bits, module.tracked_min_input, module.tracked_max_input)
+        inp = self._quantize(input_bits, module, inputs)
+        inp = self._dequantize(module, inp)
+        return inp
+
     def quantize_output(self, output, wrapper, **kwargs):
         config = wrapper.config
         module = wrapper.module
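quantize_input is a textbook fake-quantization pass: track the input range with an exponential moving average, derive affine quantization parameters, then quantize and immediately dequantize. The helpers update_ema, update_quantization_param, _quantize, and _dequantize are not shown in this diff; the sketch below is an assumption of their behavior, chosen so that it reproduces the expected values in the unit test changes further down (scale 0.04/255, then 0.0796/255, with ema_decay = 0.99):

    import torch

    def update_ema(biased_ema, value, decay):
        # assumed EMA form: new = decay * old + (1 - decay) * current
        return biased_ema * decay + (1 - decay) * value

    def update_quantization_param(bits, rmin, rmax):
        # assumed affine-parameter derivation for unsigned `bits`-bit quantization;
        # the range is widened to include 0 so that zero is exactly representable
        qmin, qmax = 0, (1 << bits) - 1
        rmin = torch.min(rmin, torch.zeros_like(rmin))
        rmax = torch.max(rmax, torch.zeros_like(rmax))
        scale = (rmax - rmin) / (qmax - qmin)
        zero_point = qmin - torch.round(rmin / scale)
        return scale, zero_point

    def fake_quantize(x, bits, scale, zero_point):
        # quantize then dequantize, mirroring the _quantize/_dequantize pair above
        qmin, qmax = 0, (1 << bits) - 1
        q = torch.clamp(torch.round(x / scale + zero_point), qmin, qmax)
        return (q - zero_point) * scale

    # worked check against the test below: two EMA updates of input [[1, 4], [2, 1]]
    t_min, t_max = torch.tensor([0.]), torch.tensor([0.])
    for _ in range(2):
        t_min = update_ema(t_min, torch.tensor([1.]), 0.99)
        t_max = update_ema(t_max, torch.tensor([4.]), 0.99)
    scale, zp = update_quantization_param(8, t_min, t_max)
    print(scale, zp)  # ~0.0796 / 255, zero_point 0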
@@ -519,8 +539,9 @@ class QAT_Quantizer(Quantizer):
             module.tracked_min_output = update_ema(module.tracked_min_output, current_min, module.ema_decay)
             module.tracked_max_output = update_ema(module.tracked_max_output, current_max, module.ema_decay)
-            module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min_output, module.tracked_max_output)
+        module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min_output, module.tracked_max_output)
         out = self._quantize(output_bits, module, output)
         out = self._dequantize(module, out)
         return out
@@ -556,8 +577,6 @@ class QAT_Quantizer(Quantizer):
             calibration_config[name] = {}
             if hasattr(module, 'weight_bits'):
                 calibration_config[name]['weight_bits'] = int(module.weight_bits)
-                calibration_config[name]['tracked_min_input'] = float(module.tracked_min_input)
-                calibration_config[name]['tracked_max_input'] = float(module.tracked_max_input)
             # Recover weight/bias for batch normalization folding
             actual_weight = getattr(module, 'old_weight', None)
@@ -573,6 +592,10 @@ class QAT_Quantizer(Quantizer):
                     module.register_parameter('bias', actual_bias)
                 else:
                     setattr(module, 'bias', None)
+            if hasattr(module, 'input_bit'):
+                calibration_config[name]['input_bits'] = int(module.input_bit)
+                calibration_config[name]['tracked_min_input'] = float(module.tracked_min_input)
+                calibration_config[name]['tracked_max_input'] = float(module.tracked_max_input)
             if hasattr(module, 'output_bits'):
                 calibration_config[name]['output_bits'] = int(module.output_bits)
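One inconsistency worth flagging in the export hunk above: the buffer registered (and later deleted via del_attr_list) elsewhere in this diff is named input_bits, but the export checks hasattr(module, 'input_bit'), so as committed the input entries would likely never reach the calibration config. Assuming the check is meant to read input_bits, the exported entry for a layer would look like this sketch (the layer name and values are illustrative, not from the source):

    calibration_config['conv1'] = {
        'weight_bits': 8,
        'input_bits': 8,
        'tracked_min_input': 0.01,   # EMA-tracked input range, illustrative values
        'tracked_max_input': 0.04,
    }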
test/ut/sdk/test_compressor_torch.py
@@ -308,7 +308,7 @@ class CompressorTestCase(TestCase):
     def test_torch_QAT_quantizer(self):
         model = TorchModel()
         config_list = [{
-            'quant_types': ['weight'],
+            'quant_types': ['weight', 'input'],
             'quant_bits': 8,
             'op_types': ['Conv2d', 'Linear']
         }, {
@@ -326,18 +326,24 @@ class CompressorTestCase(TestCase):
         # test quantize
         # range not including 0
         eps = 1e-7
-        input = torch.tensor([[0, 4], [2, 1]]).float()
+        input = torch.tensor([[1, 4], [2, 1]]).float()
         weight = torch.tensor([[1, 2], [3, 5]]).float()
         model.conv2.module.weight.data = weight
         quantizer.quantize_weight(model.conv2, input_tensor=input)
         assert math.isclose(model.conv2.module.scale, 5 / 255, abs_tol=eps)
         assert model.conv2.module.zero_point == 0
+        quantizer.quantize_input(input, model.conv2)
+        self.assertTrue(torch.allclose(model.conv2.module.scale, torch.tensor([0.04 / 255])))
+        self.assertTrue(torch.equal(model.conv2.module.zero_point, torch.tensor([0.])))
 
         # range including 0
         weight = torch.tensor([[-1, 2], [3, 5]]).float()
         model.conv2.module.weight = weight
         quantizer.quantize_weight(model.conv2, input_tensor=input)
         assert math.isclose(model.conv2.module.scale, 6 / 255, abs_tol=eps)
         assert model.conv2.module.zero_point in (42, 43)
+        quantizer.quantize_input(input, model.conv2)
+        self.assertTrue(torch.allclose(model.conv2.module.scale, torch.tensor([0.0796 / 255])))
+        self.assertTrue(torch.equal(model.conv2.module.zero_point, torch.tensor([0.])))
 
         # test value of weight and bias after quantization
         weight = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]])
         weight_valid = torch.tensor([[1.1242, 2.3421], [3.7707, 5.9723]])
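The new expected scales follow from the EMA bookkeeping, assuming update_ema computes decay * old + (1 - decay) * current with ema_decay = 0.99 and tracked stats starting at 0, and assuming the minimum is clamped to 0 before the scale is derived. For input [[1, 4], [2, 1]] (max 4), the first quantize_input call gives tracked_max = 0.99 * 0 + 0.01 * 4 = 0.04, hence scale = 0.04 / 255; the second call gives tracked_max = 0.99 * 0.04 + 0.01 * 4 = 0.0796, hence 0.0796 / 255. This also explains why the input was changed from [[0, 4], [2, 1]]: with a 0 in the tensor, the "range not including 0" case would no longer hold. A quick check:

    decay = 0.99
    tracked_max = 0.0
    for step in (1, 2):
        tracked_max = decay * tracked_max + (1 - decay) * 4   # current max of the input
        print(step, tracked_max / 255)   # 0.04/255 on step 1, then 0.0796/255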