OpenDAS / nni · Commit 7fc5af07 (unverified)
authored Jul 26, 2021 by chenbohua3, committed by GitHub, Jul 26, 2021
parent 441c5da5

Add batch normalization folding to QAT quantizer (#3911)
Showing 3 changed files with 197 additions and 18 deletions (+197 −18)
docs/en_US/Compression/Quantizer.rst (+18 −3)
nni/algorithms/compression/pytorch/quantization/quantizers.py (+61 −8)
nni/compression/pytorch/compressor.py (+118 −7)
docs/en_US/Compression/Quantizer.rst (view file @ 7fc5af07)
@@ -82,10 +82,25 @@ configuration needed by this algorithm :
  disable quantization until the model has run for a certain number of steps; this allows the network to enter a more
  stable state where activation quantization ranges do not exclude a significant fraction of values. Default value is 0.
-note
-^^^^
-batch normalization folding is currently not supported.
+Batch normalization folding
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Batch normalization folding is supported in the QAT quantizer. It can be easily enabled by passing an argument
+``dummy_input`` to the quantizer, like:
+.. code-block:: python
+
+    # assume your model takes an input of shape (1, 1, 28, 28)
+    # and dummy_input must be on the same device as the model
+    dummy_input = torch.randn(1, 1, 28, 28)
+
+    # pass the dummy_input to the quantizer
+    quantizer = QAT_Quantizer(model, config_list, dummy_input=dummy_input)
+The quantizer will automatically detect Conv-BN patterns and simulate the batch normalization folding process in the
+training graph. Note that when the quantization-aware training process is finished, the folded weight/bias are
+restored after calling ``quantizer.export_model``.
----
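For orientation, a minimal end-to-end sketch of the workflow described above (the toy network, optimizer settings, and file paths are illustrative assumptions, not part of this commit):

    import torch
    import torch.nn as nn
    from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(1, 8, 3)
            self.bn = nn.BatchNorm2d(8)

        def forward(self, x):
            return torch.relu(self.bn(self.conv(x)))

    model = Net()
    config_list = [{'quant_types': ['weight', 'output'],
                    'quant_bits': {'weight': 8, 'output': 8},
                    'op_types': ['Conv2d']}]
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # dummy_input triggers Conv-BN pattern detection and folding simulation
    dummy_input = torch.randn(1, 1, 28, 28)
    quantizer = QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy_input)
    quantizer.compress()

    # ... the usual quantization-aware training loop goes here ...

    # export_model restores the unfolded weight/bias on the wrapped modules
    calibration_config = quantizer.export_model('model.pth', 'calibration.pth')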
nni/algorithms/compression/pytorch/quantization/quantizers.py (view file @ 7fc5af07)
@@ -6,7 +6,7 @@ import copy
import torch
from schema import Schema, And, Or, Optional
from nni.compression.pytorch.utils.config_validation import QuantizerSchema
-from nni.compression.pytorch.compressor import Quantizer, QuantForward, QuantGrad, QuantType
+from nni.compression.pytorch.compressor import BN_FOLD_TAG, Quantizer, QuantForward, QuantGrad, QuantType

__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer']
@@ -126,7 +126,7 @@ class QAT_Quantizer(Quantizer):
    http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
    """

-    def __init__(self, model, config_list, optimizer=None):
+    def __init__(self, model, config_list, optimizer=None, dummy_input=None):
        """
        Parameters
        ----------
@@ -145,8 +145,13 @@ class QAT_Quantizer(Quantizer):
            state where activation quantization ranges do not exclude a significant fraction of values, default value is 0
        - op_types : list of string
            types of nn.module you want to apply quantization, eg. 'Conv2d'
+        - dummy_input : tuple of tensor
+            inputs to the model, which are used to trace the graph of the module. The graph is used to find
+            Conv-BN patterns, and batch normalization folding is then enabled for those patterns. If dummy_input is
+            not given, batch normalization folding is disabled.
        """
-        super().__init__(model, config_list, optimizer)
+        super().__init__(model, config_list, optimizer, dummy_input)
        self.quant_grad = QATGrad.apply
        modules_to_compress = self.get_modules_to_compress()
        device = next(model.parameters()).device
@@ -169,8 +174,9 @@ class QAT_Quantizer(Quantizer):
"""
delete redundant parameters in quantize module
"""
del_attr_list
=
[
'old_weight'
,
'ema_decay'
,
'tracked_min_activation'
,
'tracked_max_activation'
,
'tracked_min_input'
,
\
'tracked_max_input'
,
'scale'
,
'zero_point'
,
'weight_bit'
,
'activation_bit'
]
del_attr_list
=
[
'old_weight'
,
'old_bias'
,
'ema_decay'
,
'tracked_min_activation'
,
'tracked_max_activation'
,
'tracked_min_input'
,
'tracked_max_input'
,
'scale'
,
'zero_point'
,
'weight_bit'
,
'activation_bit'
,
'BN_FOLD_TAG'
]
for
attr
in
del_attr_list
:
if
hasattr
(
module
,
attr
):
delattr
(
module
,
attr
)
@@ -334,6 +340,23 @@ class QAT_Quantizer(Quantizer):
                calibration_config[name]['weight_bit'] = int(module.weight_bit)
                calibration_config[name]['tracked_min_input'] = float(module.tracked_min_input)
                calibration_config[name]['tracked_max_input'] = float(module.tracked_max_input)
+                # Recover weight/bias for batch normalization folding
+                if hasattr(module, BN_FOLD_TAG):
+                    actual_weight = getattr(module, 'old_weight', None)
+                    if actual_weight is None:
+                        logger.warning("Can not recover weight for layer %s. "
+                                       "This may lead to a wrong accuracy performance on the backend.", name)
+                    delattr(module, 'weight')
+                    module.register_parameter('weight', actual_weight)
+                    actual_bias = getattr(module, 'old_bias', None)
+                    delattr(module, 'bias')
+                    if actual_bias is not None:
+                        module.register_parameter('bias', actual_bias)
+                    else:
+                        setattr(module, 'bias', None)
            if hasattr(module, 'activation_bit'):
                calibration_config[name]['activation_bit'] = int(module.activation_bit)
                calibration_config[name]['tracked_min_activation'] = float(module.tracked_min_activation)
@@ -344,9 +367,39 @@ class QAT_Quantizer(Quantizer):
        return calibration_config

-    def fold_bn(self, config, **kwargs):
-        # TODO simulate folded weight
-        pass
+    def fold_bn(self, *inputs, wrapper):
+        """
+        Simulate batch normalization folding in the training graph. Folded weight and bias are
+        returned for the following operations.
+
+        Parameters
+        ----------
+        inputs : tuple of torch.Tensor
+            inputs for the module
+        wrapper : QuantizerModuleWrapper
+            the wrapper for origin module
+
+        Returns
+        -------
+        Tuple of torch.Tensor
+        """
+        module = wrapper.module
+        bn_module = wrapper.bn_module
+        with torch.no_grad():
+            output = module(*inputs)
+            _ = bn_module(output)
+        running_mean = bn_module.running_mean
+        running_var = torch.sqrt(bn_module.running_var + bn_module.eps)
+        bn_weight = bn_module.weight
+        bn_bias = bn_module.bias
+        dimensions = len(module.weight.shape)
+        shape = [-1] + [1] * (dimensions - 1)
+        new_weight = module.old_weight * bn_weight.reshape(shape) / running_var.reshape(shape)
+        if hasattr(module, 'old_bias'):
+            new_bias = bn_bias + (module.old_bias - running_mean) / running_var * bn_weight
+        else:
+            new_bias = bn_bias - running_mean / running_var * bn_weight
+        return new_weight, new_bias
    def step_with_optimizer(self):
        """
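The folding arithmetic in `fold_bn` can be sanity-checked outside NNI: folding a `BatchNorm2d` into the preceding `Conv2d` must reproduce the conv+BN output in eval mode. A small standalone sketch (layer sizes and shapes are illustrative, not part of the commit):

    import torch
    import torch.nn as nn

    conv = nn.Conv2d(3, 8, 3, bias=True)
    bn = nn.BatchNorm2d(8)
    with torch.no_grad():
        for _ in range(10):                   # accumulate BN running statistics
            bn(conv(torch.randn(4, 3, 16, 16)))
    conv.eval(); bn.eval()

    std = torch.sqrt(bn.running_var + bn.eps)            # what fold_bn calls 'running_var'
    shape = [-1] + [1] * (len(conv.weight.shape) - 1)    # broadcast over output channels
    w_fold = conv.weight * (bn.weight / std).reshape(shape)
    b_fold = bn.bias + (conv.bias - bn.running_mean) / std * bn.weight

    folded = nn.Conv2d(3, 8, 3, bias=True).eval()
    with torch.no_grad():
        folded.weight.copy_(w_fold)
        folded.bias.copy_(b_fold)

    x = torch.randn(2, 3, 16, 16)
    assert torch.allclose(bn(conv(x)), folded(x), atol=1e-5)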
nni/compression/pytorch/compressor.py (view file @ 7fc5af07)
@@ -4,6 +4,7 @@
import types
import logging
import torch
+from nni.common.graph_utils import build_module_graph
from . import default_layers

_logger = logging.getLogger(__name__)
@@ -463,7 +464,7 @@ class Pruner(Compressor):
class QuantizerModuleWrapper(torch.nn.Module):
-    def __init__(self, module, module_name, module_type, config, quantizer):
+    def __init__(self, module, module_name, module_type, config, quantizer, bn_module=None):
        """
        Wrap a module to enable data parallelism, forward-method customization and buffer registration.
@@ -479,6 +480,8 @@ class QuantizerModuleWrapper(torch.nn.Module):
            the type of the module to compress
        quantizer : quantizer
            the quantizer used to calculate mask
+        bn_module : torch.nn.Module
+            batch norm layer corresponding to the current module, used for simulating batch normalization folding
        """
        super().__init__()
        # origin layer information
@@ -488,6 +491,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
        # config and pruner
        self.config = config
        self.quantizer = quantizer
+        self.bn_module = bn_module

        # register buffer and parameter
        # old_weight is used to store origin weight and weight is used to store quantized weight
@@ -501,6 +505,17 @@ class QuantizerModuleWrapper(torch.nn.Module):
            delattr(self.module, 'weight')
            self.module.register_buffer('weight', self.module.old_weight)

+            # for batch normalization folding
+            if self.bn_module is not None:
+                if _check_bias(self.module):
+                    self.module.register_parameter('old_bias', torch.nn.Parameter(self.module.bias))
+                    init_tensor = self.module.old_bias
+                else:
+                    init_tensor = torch.zeros_like(self.bn_module.weight)
+                delattr(self.module, 'bias')
+                self.module.register_buffer('bias', init_tensor)
+                setattr(module, BN_FOLD_TAG, True)

    def forward(self, *inputs):
        if 'input' in self.config['quant_types']:
            inputs = self.quantizer.quant_grad(
@@ -509,12 +524,19 @@ class QuantizerModuleWrapper(torch.nn.Module):
                self)

        if 'weight' in self.config['quant_types'] and _check_weight(self.module):
+            if self.bn_module is not None:
+                # simulate batch normalization folding
+                new_weight, new_bias = self.quantizer.fold_bn(*inputs, wrapper=self)
+                self.module.bias = new_bias
+                self.module.weight = new_weight
+            else:
+                new_weight = self.module.old_weight
            self.quantizer.quant_grad(
-                self.module.old_weight,
+                new_weight,
                QuantType.QUANT_WEIGHT,
                self, inputs[0])
            result = self.module(*inputs)
        else:
            result = self.module(*inputs)

        if 'output' in self.config['quant_types']:
@@ -525,12 +547,35 @@ class QuantizerModuleWrapper(torch.nn.Module):
        return result


+class QuantizerIdentityWrapper(torch.nn.Module):
+    def __init__(self, module, module_name):
+        """
+        Used to wrap modules that should be treated as torch.Identity
+
+        Parameters
+        ----------
+        module : pytorch module
+            the module to be wrapped
+        module_name : str
+            the name of the module to be wrapped; the wrapper shares the same name
+        """
+        super().__init__()
+        self.module = module
+        self.module_name = module_name
+
+    def forward(self, x):
+        return x


class Quantizer(Compressor):
    """
    Base quantizer for pytorch quantizer
    """
-    def __init__(self, model, config_list, optimizer=None):
+    def __init__(self, model, config_list, optimizer=None, dummy_input=None):
+        self.identity_wrappers = []
+        self.conv_bn_patterns = {}
+        self.find_conv_bn_patterns(model, dummy_input)
        super().__init__(model, config_list, optimizer)
        self.quant_grad = QuantGrad.apply
        if self.optimizer is not None:
@@ -540,6 +585,10 @@ class Quantizer(Compressor):
                # old_weight is registered to keep track of weight before quantization
                # and it is trainable, therefore, it should be added to optimizer.
                self.optimizer.add_param_group({"params": wrapper.module.old_weight})
+                # This is for conv with bias + bn. Although this situation is relatively rare,
+                # we still need to deal with the old_bias when it occurs
+                if hasattr(wrapper.module, "old_bias"):
+                    self.optimizer.add_param_group({"params": getattr(wrapper.module, "old_bias")})
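The `add_param_group` calls above matter because parameters created after the optimizer was constructed (here `old_weight`, and `old_bias` in the conv-with-bias + BN case) are otherwise invisible to it. A generic PyTorch illustration, not from this commit:

    import torch

    w = torch.nn.Parameter(torch.randn(3))
    opt = torch.optim.SGD([w], lr=0.1)

    # a parameter created after the optimizer, e.g. old_bias registered during wrapping
    b = torch.nn.Parameter(torch.zeros(3))
    opt.add_param_group({'params': b})

    loss = (w.sum() - b.sum()) ** 2
    loss.backward()
    opt.step()   # both w and b are updated, because b was added explicitly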
    def quantize_weight(self, wrapper, **kwargs):
        """
@@ -597,7 +646,36 @@ class Quantizer(Compressor):
        for quant_type in config['quant_types']:
            assert quant_type in config['quant_bits'], 'bits length for %s must be specified in quant_bits dict' % quant_type

-        return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self)
+        # bind bn module to corresponding conv module
+        bn_module = None
+        if layer.name in self.conv_bn_patterns:
+            bn_module_name = self.conv_bn_patterns[layer.name]
+            for name, module in self.bound_model.named_modules():
+                if name == bn_module_name:
+                    bn_module = module
+                    break
+            assert bn_module is not None, "BN module corresponding to layer {} is not found".format(layer.name)
+            self.identity_wrappers.append(QuantizerIdentityWrapper(bn_module, bn_module_name))
+        return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self, bn_module)
+
+    def _wrap_model(self):
+        """
+        wrap all modules that need to be compressed
+        """
+        # wrap folded bn in order to bypass its forward process
+        for wrapper in reversed(self.identity_wrappers):
+            _setattr(self.bound_model, wrapper.module_name, wrapper)
+        super()._wrap_model()
+
+    def _unwrap_model(self):
+        """
+        unwrap all modules that need to be compressed
+        """
+        for wrapper in self.identity_wrappers:
+            _setattr(self.bound_model, wrapper.module_name, wrapper.module)
+        super()._unwrap_model()

    def export_model_save(self, model, model_path, calibration_config=None, calibration_path=None, onnx_path=None,
                          input_shape=None, device=None):
@@ -660,6 +738,30 @@ class Quantizer(Compressor):
"""
raise
NotImplementedError
(
'Quantizer must overload export_model()'
)
def
find_conv_bn_patterns
(
self
,
model
,
dummy_input
):
"""
Find all Conv-BN patterns, used for batch normalization folding
Parameters
----------
model : torch.nn.Module
model to be analyzed.
dummy_input : tupel of torch.tensor
inputs to the model, used for generating the torchscript
"""
if
dummy_input
is
None
:
_logger
.
debug
(
"Model inputs are not given, batch normalization folding is disabled"
)
return
graph
=
build_module_graph
(
model
,
dummy_input
)
for
node_group
in
graph
.
nodes_py
.
nodes_op
:
if
node_group
.
op_type
in
BN_FOLD_OP
:
successors
=
graph
.
find_successors
(
node_group
.
unique_name
)
successors
=
[
graph
.
name_to_node
[
x
]
for
x
in
successors
]
for
successor
in
successors
:
if
successor
.
op_type
==
'BatchNorm2d'
:
self
.
conv_bn_patterns
[
node_group
.
name
]
=
successor
.
name
def
step_with_optimizer
(
self
):
pass
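To make the detected mapping concrete: for a toy network whose `conv` layer feeds directly into `bn`, `find_conv_bn_patterns` should record `{'conv': 'bn'}`. A hedged sketch (the toy model and config below are illustrative, not from this commit):

    import torch
    import torch.nn as nn
    from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 8, 3)   # op_type 'Conv2d' is in BN_FOLD_OP
            self.bn = nn.BatchNorm2d(8)      # direct successor of self.conv

        def forward(self, x):
            return self.bn(self.conv(x))

    config_list = [{'quant_types': ['weight'], 'quant_bits': {'weight': 8}, 'op_types': ['Conv2d']}]
    quantizer = QAT_Quantizer(Toy(), config_list, dummy_input=torch.randn(1, 3, 16, 16))
    print(quantizer.conv_bn_patterns)        # expected: {'conv': 'bn'}

The `bn` module is then wrapped by `QuantizerIdentityWrapper`, so its own forward pass is bypassed while folding is simulated inside the conv wrapper.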
@@ -677,6 +779,9 @@ QType_Dict = {
    2: "output"
}

+BN_FOLD_OP = ["Conv2d"]
+BN_FOLD_TAG = 'BN_FOLD_TAG'


class QuantGrad(torch.autograd.Function):
    """
    Base class for overriding backward function of quantization operation.
@@ -773,6 +878,12 @@ def _check_weight(module):
    except AttributeError:
        return False

+def _check_bias(module):
+    try:
+        return isinstance(module.bias.data, torch.Tensor)
+    except AttributeError:
+        return False

def quantize_helper(tensor, quant_type, wrapper, input_tensor=None, **kwargs):
    if quant_type == QuantType.QUANT_INPUT:
        output = wrapper.quantizer.quantize_input(*tensor, wrapper=wrapper, **kwargs)