OpenDAS / nni · Commit 370e88df

Add post training observer_quantizer (#3915)

Authored Jul 28, 2021 by chenbohua3; committed by GitHub on Jul 28, 2021. Parent: 3f1e4f55.

Showing 4 changed files with 390 additions and 1 deletion:
- examples/model_compress/quantization/observer_quantizer.py (+117, -0)
- nni/algorithms/compression/pytorch/quantization/observers.py (+3, -0)
- nni/algorithms/compression/pytorch/quantization/quantizers.py (+229, -1)
- test/ut/sdk/test_compressor_torch.py (+41, -0)
examples/model_compress/quantization/observer_quantizer.py (new file, mode 100644)
```python
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.algorithms.compression.pytorch.quantization import ObserverQuantizer
import sys
sys.path.append('../models')
from mnist.naive import NaiveModel


def train(model, device, train_loader, optimizer):
    model.to(device)
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('Loss: {} Accuracy: {}%\n'.format(test_loss, 100 * correct / len(test_loader.dataset)))


def calibration(model, device, test_loader):
    model.eval()
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            model(data)


def main():
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True, transform=trans),
        batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=trans),
        batch_size=1000, shuffle=True)

    model = NaiveModel()
    configure_list = [{
        'quant_types': ['weight', 'input'],
        'quant_bits': {'weight': 8, 'input': 8},
        'op_names': ['conv1'],
    }, {
        'quant_types': ['output'],
        'quant_bits': {'output': 8},
        'op_names': ['relu1'],
    }, {
        'quant_types': ['weight', 'input'],
        'quant_bits': {'weight': 8, 'input': 8},
        'op_names': ['conv2'],
    }, {
        'quant_types': ['output'],
        'quant_bits': {'output': 8},
        'op_names': ['relu2'],
    }, {
        'quant_types': ['output'],
        'quant_bits': {'output': 8},
        'op_names': ['max_pool2'],
    }]
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    # Train the model to get a baseline performance
    for epoch in range(5):
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer)

    test(model, device, test_loader)

    # Construct the ObserverQuantizer. Note that currently ObserverQuantizer
    # only works in evaluation mode.
    quantizer = ObserverQuantizer(model.eval(), configure_list, optimizer)
    # Use the test data set to do calibration; this will not change the model parameters.
    calibration(model, device, test_loader)
    # Obtain the quantization information and switch the model to "accuracy verification" mode.
    quantizer.compress()
    # Measure the accuracy of the quantized model.
    test(model, device, test_loader)

    model_path = "mnist_model.pth"
    calibration_path = "mnist_calibration.pth"
    calibration_config = quantizer.export_model(model_path, calibration_path)
    print("calibration_config: ", calibration_config)

    # For now the quantization settings of ObserverQuantizer do not match TensorRT,
    # so TensorRT conversion is not supported.
    # Current settings:
    #   weight     : per_tensor_symmetric, qint8
    #   activation : per_tensor_affine, quint8, reduce_range=True


if __name__ == '__main__':
    main()
```
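After the script finishes, the two exported artifacts can be reloaded for inspection. A minimal sketch, assuming (as NNI's `export_model_save` helper does) that both files are plain `torch.save` checkpoints:

```python
# Sketch: reload the artifacts written by the example above.
# Assumption: export_model saves the model's state_dict and the calibration
# dict via torch.save; run the example script first so the files exist.
import torch

state_dict = torch.load("mnist_model.pth")                 # pseudo-quantized weights
calibration_config = torch.load("mnist_calibration.pth")   # per-layer scales/ranges
print(sorted(calibration_config.keys()))                   # layer names, e.g. conv1, relu1, ...
```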
nni/algorithms/compression/pytorch/quantization/observers.py (new file, mode 100644)
```python
from torch.quantization import default_weight_observer, default_histogram_observer

__all__ = ["default_weight_observer", "default_histogram_observer"]
```
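This new module simply re-exports PyTorch's built-in observers: `default_weight_observer` is a `MinMaxObserver` configured for `per_tensor_symmetric` qint8, and `default_histogram_observer` is a `HistogramObserver` with `reduce_range=True`, matching the settings listed in the quantizer's comments below. A standalone sketch of the record-then-calibrate cycle an observer implements:

```python
# Sketch: how an observer is used. This mirrors what ObserverQuantizer's
# hooks do during calibration; the random tensor stands in for a weight.
import torch
from torch.quantization import default_weight_observer

observer = default_weight_observer()   # MinMaxObserver: per_tensor_symmetric, qint8
observer(torch.randn(16, 8))           # record min/max statistics, like a hook
scale, zero_point = observer.calculate_qparams()
print(scale, zero_point)               # symmetric scheme, so zero_point == 0
```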
nni/algorithms/compression/pytorch/quantization/quantizers.py
```diff
@@ -3,12 +3,15 @@
 import logging
 import copy
+from collections import defaultdict

 import torch
 from schema import Schema, And, Or, Optional

 from nni.compression.pytorch.utils.config_validation import QuantizerSchema
 from nni.compression.pytorch.compressor import BN_FOLD_TAG, Quantizer, QuantForward, QuantGrad, QuantType

-__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer']
+from .observers import default_weight_observer, default_histogram_observer
+
+__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer', 'ObserverQuantizer']

 logger = logging.getLogger(__name__)
```
@@ -120,6 +123,231 @@ class QATGrad(QuantGrad):

```python
        return grad_output


class ObserverQuantizer(Quantizer):
    """This quantizer uses observers to record weight/activation statistics to get quantization information.

    The whole process can be divided into three steps:
        1. It will register observers to the place where quantization would happen (just like registering hooks).
        2. The observers would record tensors' statistics during calibration.
        3. Scale & zero point would be obtained after calibration.

    Note that the observer type, tensor dtype and quantization qscheme are hard coded for now. Their
    customization is under development and will be ready soon.
    """

    def __init__(self, model, config_list, optimizer=None):
        super().__init__(model, config_list, optimizer)
        # NOTE: this quantizer is experimental for now. The dtype and qscheme of quantization
        # is hard-coded.
        # TODO:
        # 1. support dtype and qscheme customization through config_list. Current settings:
        #    weight observer     : per_tensor_symmetric, qint8
        #    activation observer : per_tensor_affine, quint8, reduce_range=True
        # 2. add more kinds of observers, such as Kullback-Leibler divergence.
        # 3. add batch normalization folding
        assert not model.training, "Currently the observer quantizer only works in evaluation mode."
        self.quant_grad = QuantForward()
        self.device = next(model.parameters()).device
        modules_to_compress = self.get_modules_to_compress()
        all_observers = defaultdict(dict)
        weight_q_min, weight_q_max = -127, 127
        activation_q_min, activation_q_max = 0, 127  # reduce_range is set to True
        self.compressed = False

        for layer, config in modules_to_compress:
            layer_name = layer.name
            module = layer.module
            if "weight" in config.get("quant_types", []):
                all_observers[layer_name]["weight"] = default_weight_observer()
                setattr(module, "weight_qmax", weight_q_max)
                setattr(module, "weight_qmin", weight_q_min)
            if "input" in config.get("quant_types", []):
                all_observers[layer_name]["input"] = default_histogram_observer()
                setattr(module, "input_qmax", activation_q_max)
                setattr(module, "input_qmin", activation_q_min)
            if "output" in config.get("quant_types", []):
                all_observers[layer_name]["output"] = default_histogram_observer()
                setattr(module, "output_qmax", activation_q_max)
                setattr(module, "output_qmin", activation_q_min)
        self.all_observers = all_observers
        self.bound_model.to(self.device)

    def validate_config(self, model, config_list):
        schema = QuantizerSchema([{
            Optional('quant_types'): Schema([lambda x: x in ['weight', 'output', 'input']]),
            Optional('quant_bits'): Or(And(int, lambda n: n == 8), Schema({
                Optional('weight'): And(int, lambda n: n == 8),
                Optional('output'): And(int, lambda n: n == 8),
                Optional('input'): And(int, lambda n: n == 8),
            })),
            Optional('op_types'): [str],
            Optional('op_names'): [str]
        }], model, logger)

        schema.validate(config_list)

    def record(self, wrapper, quant_type, tensor):
        name = wrapper.name
        observer = self.all_observers[name][quant_type]
        if isinstance(tensor, tuple):
            # NB: This only works for single tensor
            tensor = (t.cpu() for t in tensor)
            observer(*tensor)
        else:
            observer(tensor.cpu())

    def calculate_qparams(self, name, quant_type):
        observer = self.all_observers[name][quant_type]
        scale, zero_point = observer.calculate_qparams()
        return scale, zero_point

    def _quantize(self, x, scale, zero_point, qmin, qmax):
        x = x / scale + zero_point
        x = torch.clamp(x, qmin, qmax)
        x = torch.round(x)
        x = (x - zero_point) * scale
        return x

    def quantize_input(self, *inputs, wrapper, **kwargs):
        if self.compressed:
            module = wrapper.module
            new_input = self._quantize(inputs[0],
                                       module.input_scale,
                                       module.input_zero_point,
                                       module.input_qmin,
                                       module.input_qmax)
            list_inp = list(inputs)
            list_inp[0] = new_input
            inputs = tuple(list_inp)
        else:
            self.record(wrapper, 'input', inputs)
        return inputs

    def quantize_weight(self, wrapper, **kwargs):
        # If ObserverQuantizer.compress is executed, the weight will be set to
        # the pseudo-quantized one. So there is no need to quantize it again.
        if self.compressed:
            return
        module = wrapper.module
        old_weight = module.weight
        self.record(wrapper, 'weight', old_weight)

    def quantize_output(self, output, wrapper, **kwargs):
        if self.compressed:
            module = wrapper.module
            new_output = self._quantize(output,
                                        module.output_scale,
                                        module.output_zero_point,
                                        module.output_qmin,
                                        module.output_qmax)
        else:
            self.record(wrapper, 'output', output)
            new_output = output
        return new_output

    def compress(self):
        """
        Calculate quantization information of each tensor. Note that inference of
        the compressed model will no longer update the observers. Instead, the quantization
        process will be simulated, which is used to test the accuracy of the quantization.
        """
        modules_to_compress = self.get_modules_to_compress()
        for layer, config in modules_to_compress:
            module = layer.module
            if "weight" in config.get("quant_types", []):
                scale, zero_point = self.calculate_qparams(layer.name, 'weight')
                module.register_buffer('weight_scale', scale.to(self.device))
                module.register_buffer('weight_zero_point', zero_point.to(self.device))
                weight = module.weight
                quantized_weight = self._quantize(weight,
                                                  module.weight_scale,
                                                  module.weight_zero_point,
                                                  module.weight_qmin,
                                                  module.weight_qmax)
                delattr(module, 'weight')
                module.register_parameter('weight', torch.nn.Parameter(quantized_weight))
            if "input" in config.get("quant_types", []):
                scale, zero_point = self.calculate_qparams(layer.name, 'input')
                module.register_buffer('input_scale', scale.to(self.device))
                module.register_buffer('input_zero_point', zero_point.to(self.device))
            if "output" in config.get("quant_types", []):
                scale, zero_point = self.calculate_qparams(layer.name, 'output')
                module.register_buffer('output_scale', scale.to(self.device))
                module.register_buffer('output_zero_point', zero_point.to(self.device))
        self.compressed = True
        super().compress()

    def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
        """
        Export quantized model weights and calibration parameters (optional).

        Parameters
        ----------
        model_path : str
            path to save quantized model weight
        calibration_path : str
            (optional) path to save quantize parameters after calibration
        onnx_path : str
            (optional) path to save onnx model
        input_shape : list or tuple
            input shape to onnx model
        device : torch.device
            device of the model, used to place the dummy input tensor for exporting onnx file.
            the tensor is placed on cpu if ``device`` is None

        Returns
        -------
        Dict
        """
        assert model_path is not None, 'model_path must be specified'
        self._unwrap_model()
        calibration_config = {}

        for name, module in self.bound_model.named_modules():
            if hasattr(module, 'weight_scale') or hasattr(module, 'input_scale') or hasattr(module, 'output_scale'):
                calibration_config[name] = {}
            if hasattr(module, 'weight_scale'):
                calibration_config[name]['weight_bit'] = 8
                val = float(module.weight_scale * module.weight_qmax)
                calibration_config[name]['tracked_max_weight'] = val
                calibration_config[name]['tracked_min_weight'] = -val
                calibration_config[name]['tracked_weight_qmin'] = -127
                calibration_config[name]['tracked_weight_qmax'] = 127
            # refactor these magic numbers when customizations of dtype and qscheme are ready.
            if hasattr(module, 'input_scale'):
                calibration_config[name]['input_bit'] = 8
                max_input = float(module.input_scale * (module.input_qmax - module.input_zero_point))
                min_input = float(module.input_scale * (module.input_qmin - module.input_zero_point))
                calibration_config[name]['tracked_min_input'] = min_input
                calibration_config[name]['tracked_max_input'] = max_input
                calibration_config[name]['tracked_input_qmin'] = 0
                calibration_config[name]['tracked_input_qmax'] = 127
            if hasattr(module, 'output_scale'):
                calibration_config[name]['activation_bit'] = 8
                max_input = float(module.output_scale * (module.output_qmax - module.output_zero_point))
                min_input = float(module.output_scale * (module.output_qmin - module.output_zero_point))
                calibration_config[name]['tracked_min_activation'] = min_input
                calibration_config[name]['tracked_max_activation'] = max_input
                calibration_config[name]['tracked_activation_qmin'] = 0
                calibration_config[name]['tracked_activation_qmax'] = 127
            self._del_simulated_attr(module)

        self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path,
                               onnx_path, input_shape, device)

        return calibration_config

    def _del_simulated_attr(self, module):
        """
        delete redundant parameters in quantize module
        """
        del_attr_list = ['old_weight', 'steps', 'weight_qmax', 'weight_qmin', 'input_qmax', 'input_qmin',
                         'output_qmax', 'output_qmin', 'weight_scale', 'weight_zero_point', 'input_scale',
                         'input_zero_point', 'output_scale', 'output_zero_point']
        for attr in del_attr_list:
            if hasattr(module, attr):
                delattr(module, attr)


class QAT_Quantizer(Quantizer):
    """Quantizer defined in:
    Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
```
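The `_quantize` method above implements simulated ("fake") quantization: values are mapped to the integer grid, clamped, rounded, and immediately mapped back, so the model stays in floating point while exhibiting the quantization error. A self-contained sketch with made-up numbers:

```python
# Sketch of the fake-quantize round trip performed by ObserverQuantizer._quantize.
import torch

def fake_quantize(x, scale, zero_point, qmin, qmax):
    x = x / scale + zero_point        # map to the integer grid
    x = torch.clamp(x, qmin, qmax)    # saturate to the quantized range
    x = torch.round(x)                # snap to integers
    return (x - zero_point) * scale   # map back to float

x = torch.tensor([-1.0, -0.3, 0.0, 0.42, 1.0])
scale = 1.0 / 127                     # symmetric int8 scale when max|x| = 1.0
y = fake_quantize(x, scale, zero_point=0, qmin=-127, qmax=127)
# in-range values differ from the originals by at most half a quantization step,
# which is exactly the tolerance the unit test below asserts
assert torch.allclose(x, y, atol=0.5 * scale)
```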
test/ut/sdk/test_compressor_torch.py
```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import copy
from unittest import TestCase, main
import numpy as np
import torch
```

@@ -263,6 +264,46 @@ class CompressorTestCase(TestCase):

```python
        assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 0., 0., 0., 25.]))
        assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 125., 125., 125., 0., 0., 0.]))

    def test_torch_observer_quantizer(self):
        model = TorchModel()

        # test invalid config
        # only support 8bit for now
        config_list = [{
            'quant_types': ['weight'],
            'quant_bits': 5,
            'op_types': ['Conv2d', 'Linear']
        }]
        with self.assertRaises(schema.SchemaError):
            torch_quantizer.ObserverQuantizer(model, config_list)

        # weight will not change for now
        model = TorchModel().eval()
        origin_parameters = copy.deepcopy(dict(model.named_parameters()))
        config_list = [{
            'quant_types': ['weight'],
            'quant_bits': 8,
            'op_types': ['Conv2d', 'Linear']
        }]
        quantizer = torch_quantizer.ObserverQuantizer(model, config_list)
        input = torch.randn(1, 1, 28, 28)
        model(input)
        quantizer.compress()
        buffers = dict(model.named_buffers())
        scales = {k: v for k, v in buffers.items() if 'scale' in k}

        model_path = "test_model.pth"
        calibration_path = "test_calibration.pth"
        calibration_config = quantizer.export_model(model_path, calibration_path)
        new_parameters = dict(model.named_parameters())
        for layer_name, v in calibration_config.items():
            scale_name = layer_name + '.module.weight_scale'
            weight_name = layer_name + '.weight'
            s = float(scales[scale_name])
            self.assertTrue(torch.allclose(origin_parameters[weight_name], new_parameters[weight_name], atol=0.5 * s))

        self.assertTrue(calibration_config is not None)
        self.assertTrue(len(calibration_config) == 4)

    def test_torch_QAT_quantizer(self):
        model = TorchModel()
        config_list = [{
```
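The test iterates over `calibration_config` keyed by layer name. For reference, one weight-quantized entry returned by `export_model` has the following shape; the field names come from the `export_model` implementation above, while the numeric values here are invented for illustration:

```python
# Illustrative only: one entry of the dict returned by export_model for a
# weight-quantized layer. Field names match export_model; numbers are made up.
calibration_config = {
    'conv1': {
        'weight_bit': 8,
        'tracked_max_weight': 0.51,   # weight_scale * weight_qmax
        'tracked_min_weight': -0.51,  # symmetric scheme, so the negation
        'tracked_weight_qmin': -127,
        'tracked_weight_qmax': 127,
    },
}
print(calibration_config['conv1']['weight_bit'])
```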