Commit 7fc5af07 (unverified)
Authored Jul 26, 2021 by chenbohua3; committed by GitHub on Jul 26, 2021
Parent: 441c5da5

Add batch normalization folding to QAT quantizer (#3911)
Showing 3 changed files with 197 additions and 18 deletions:

    docs/en_US/Compression/Quantizer.rst                              +18   -3
    nni/algorithms/compression/pytorch/quantization/quantizers.py    +61   -8
    nni/compression/pytorch/compressor.py                             +118  -7
docs/en_US/Compression/Quantizer.rst

@@ -82,10 +82,25 @@ configuration needed by this algorithm :

     disable quantization until model are run by certain number of steps, this allows the network to enter a more stable
     state where activation quantization ranges do not exclude a significant fraction of values, default value is 0
 
-.. note::
-
-    batch normalization folding is currently not supported.
+Batch normalization folding
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Batch normalization folding is supported in QAT quantizer. It can be easily enabled by passing an argument `dummy_input` to
+the quantizer, like:
+
+.. code-block:: python
+
+    # assume your model takes an input of shape (1, 1, 28, 28)
+    # and dummy_input must be on the same device as the model
+    dummy_input = torch.randn(1, 1, 28, 28)
+
+    # pass the dummy_input to the quantizer
+    quantizer = QAT_Quantizer(model, config_list, dummy_input=dummy_input)
+
+The quantizer will automatically detect Conv-BN patterns and simulate batch normalization folding process in the training
+graph. Note that when the quantization aware training process is finished, the folded weight/bias would be restored after
+calling `quantizer.export_model`.
 
 ----
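To put the documented workflow together, here is a minimal end-to-end sketch of QAT with batch normalization folding. The NaiveModel definition, the config_list values, the training-loop placeholder and the output file paths are illustrative assumptions, not part of this commit.

    import torch
    import torch.nn.functional as F
    from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer

    class NaiveModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(1, 20, 5)
            self.bn1 = torch.nn.BatchNorm2d(20)          # Conv-BN pattern eligible for folding
            self.fc = torch.nn.Linear(20 * 12 * 12, 10)

        def forward(self, x):
            x = F.relu(self.bn1(self.conv1(x)))
            x = F.max_pool2d(x, 2)
            return self.fc(x.flatten(1))

    model = NaiveModel()
    config_list = [{
        'quant_types': ['weight', 'output'],
        'quant_bits': {'weight': 8, 'output': 8},
        'op_types': ['Conv2d', 'Linear'],
    }]
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # dummy_input (on the same device as the model) turns on Conv-BN detection and folding
    dummy_input = torch.randn(1, 1, 28, 28)
    quantizer = QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy_input)
    quantizer.compress()

    # ... run the usual quantization aware training loop on the wrapped model ...

    # export_model restores the original (unfolded) weight/bias and returns the calibration data
    calibration_config = quantizer.export_model('model.pth', 'calibration.pth')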
nni/algorithms/compression/pytorch/quantization/quantizers.py

@@ -6,7 +6,7 @@ import copy
 import torch
 from schema import Schema, And, Or, Optional
 from nni.compression.pytorch.utils.config_validation import QuantizerSchema
-from nni.compression.pytorch.compressor import Quantizer, QuantForward, QuantGrad, QuantType
+from nni.compression.pytorch.compressor import BN_FOLD_TAG, Quantizer, QuantForward, QuantGrad, QuantType
 
 __all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer']

@@ -126,7 +126,7 @@ class QAT_Quantizer(Quantizer):
     http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
     """
 
-    def __init__(self, model, config_list, optimizer=None):
+    def __init__(self, model, config_list, optimizer=None, dummy_input=None):
         """
         Parameters
         ----------

@@ -145,8 +145,13 @@ class QAT_Quantizer(Quantizer):
                 state where activation quantization ranges do not exclude a significant fraction of values, default value is 0
             - op_types : list of string
                 types of nn.module you want to apply quantization, eg. 'Conv2d'
+        dummy_input : tuple of tensor
+            inputs to the model, which are used to get the graph of the module. The graph is used to find
+            Conv-Bn patterns. And then the batch normalization folding would be enabled. If dummy_input is not
+            given, the batch normalization folding would be disabled.
         """
-        super().__init__(model, config_list, optimizer)
+        super().__init__(model, config_list, optimizer, dummy_input)
         self.quant_grad = QATGrad.apply
         modules_to_compress = self.get_modules_to_compress()
         device = next(model.parameters()).device

@@ -169,8 +174,9 @@ class QAT_Quantizer(Quantizer):
         """
         delete redundant parameters in quantize module
         """
-        del_attr_list = ['old_weight', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation', 'tracked_min_input', \
-            'tracked_max_input', 'scale', 'zero_point', 'weight_bit', 'activation_bit']
+        del_attr_list = ['old_weight', 'old_bias', 'ema_decay', 'tracked_min_activation', 'tracked_max_activation',
+            'tracked_min_input', 'tracked_max_input', 'scale', 'zero_point', 'weight_bit', 'activation_bit', 'BN_FOLD_TAG']
         for attr in del_attr_list:
             if hasattr(module, attr):
                 delattr(module, attr)

@@ -334,6 +340,23 @@ class QAT_Quantizer(Quantizer):
                 calibration_config[name]['weight_bit'] = int(module.weight_bit)
                 calibration_config[name]['tracked_min_input'] = float(module.tracked_min_input)
                 calibration_config[name]['tracked_max_input'] = float(module.tracked_max_input)
+
+                # Recover weight/bias for batch normalization folding
+                if hasattr(module, BN_FOLD_TAG):
+                    actual_weight = getattr(module, 'old_weight', None)
+                    if actual_weight is None:
+                        logger.warning("Can not recover weight for layer %s. "
+                                       "This may lead to a wrong accuracy performance on the backend.", name)
+                    delattr(module, 'weight')
+                    module.register_parameter('weight', actual_weight)
+
+                    actual_bias = getattr(module, 'old_bias', None)
+                    delattr(module, 'bias')
+                    if actual_bias is not None:
+                        module.register_parameter('bias', actual_bias)
+                    else:
+                        setattr(module, 'bias', None)
             if hasattr(module, 'activation_bit'):
                 calibration_config[name]['activation_bit'] = int(module.activation_bit)
                 calibration_config[name]['tracked_min_activation'] = float(module.tracked_min_activation)

@@ -344,9 +367,39 @@ class QAT_Quantizer(Quantizer):
         return calibration_config
 
-    def fold_bn(self, config, **kwargs):
-        # TODO simulate folded weight
-        pass
+    def fold_bn(self, *inputs, wrapper):
+        """
+        Simulate batch normalization folding in the training graph. Folded weight and bias are
+        returned for the following operations.
+
+        Parameters
+        ----------
+        inputs : tuple of torch.Tensor
+            inputs for the module
+        wrapper : QuantizerModuleWrapper
+            the wrapper for origin module
+
+        Returns
+        -------
+        Tuple of torch.Tensor
+        """
+        module = wrapper.module
+        bn_module = wrapper.bn_module
+        with torch.no_grad():
+            output = module(*inputs)
+            _ = bn_module(output)
+        running_mean = bn_module.running_mean
+        running_var = torch.sqrt(bn_module.running_var + bn_module.eps)
+        bn_weight = bn_module.weight
+        bn_bias = bn_module.bias
+        dimensions = len(module.weight.shape)
+        shape = [-1] + [1] * (dimensions - 1)
+        new_weight = module.old_weight * bn_weight.reshape(shape) / running_var.reshape(shape)
+        if hasattr(module, 'old_bias'):
+            new_bias = bn_bias + (module.old_bias - running_mean) / running_var * bn_weight
+        else:
+            new_bias = bn_bias - running_mean / running_var * bn_weight
+        return new_weight, new_bias
 
     def step_with_optimizer(self):
         """
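For reference, the algebra implemented by fold_bn above can be checked in isolation: folding the BN statistics into a convolution's weight and bias reproduces conv followed by BN in eval mode. The following self-contained sketch uses arbitrary layer sizes and is only an illustration of that identity, not part of the commit.

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    conv = nn.Conv2d(3, 8, 3, bias=True)
    bn = nn.BatchNorm2d(8)

    # accumulate some non-trivial running statistics in the BN layer
    bn.train()
    for _ in range(10):
        bn(conv(torch.randn(4, 3, 16, 16)))
    conv.eval()
    bn.eval()

    # fold BN into the conv parameters, same algebra as fold_bn:
    #   w_fold = w * gamma / sqrt(running_var + eps)
    #   b_fold = beta + (b - running_mean) * gamma / sqrt(running_var + eps)
    std = torch.sqrt(bn.running_var + bn.eps)
    w_fold = conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1)
    b_fold = bn.bias + (conv.bias - bn.running_mean) * bn.weight / std

    folded = nn.Conv2d(3, 8, 3, bias=True)
    with torch.no_grad():
        folded.weight.copy_(w_fold)
        folded.bias.copy_(b_fold)

        x = torch.randn(2, 3, 16, 16)
        # the folded conv matches conv followed by eval-mode BN
        print(torch.allclose(bn(conv(x)), folded(x), atol=1e-4))   # expected: True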
nni/compression/pytorch/compressor.py

@@ -4,6 +4,7 @@
 import types
 import logging
 import torch
+from nni.common.graph_utils import build_module_graph
 from . import default_layers
 
 _logger = logging.getLogger(__name__)

@@ -463,7 +464,7 @@ class Pruner(Compressor):

 class QuantizerModuleWrapper(torch.nn.Module):
-    def __init__(self, module, module_name, module_type, config, quantizer):
+    def __init__(self, module, module_name, module_type, config, quantizer, bn_module=None):
         """
         Wrap an module to enable data parallel, forward method customization and buffer registeration.

@@ -479,6 +480,8 @@ class QuantizerModuleWrapper(torch.nn.Module):
             the type of the module to compress
         quantizer : quantizer
             the quantizer used to calculate mask
+        bn_module : torch.nn.Module
+            batch norm layer corresponding to current module, used for simulating batch normalization folding
         """
         super().__init__()
         # origin layer information

@@ -488,6 +491,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
         # config and pruner
         self.config = config
         self.quantizer = quantizer
+        self.bn_module = bn_module
 
         # register buffer and parameter
         # old_weight is used to store origin weight and weight is used to store quantized weight

@@ -501,6 +505,17 @@ class QuantizerModuleWrapper(torch.nn.Module):
             delattr(self.module, 'weight')
             self.module.register_buffer('weight', self.module.old_weight)
 
+            # for batch normalization folding
+            if self.bn_module is not None:
+                if _check_bias(self.module):
+                    self.module.register_parameter('old_bias', torch.nn.Parameter(self.module.bias))
+                    init_tensor = self.module.old_bias
+                else:
+                    init_tensor = torch.zeros_like(self.bn_module.weight)
+                delattr(self.module, 'bias')
+                self.module.register_buffer('bias', init_tensor)
+                setattr(module, BN_FOLD_TAG, True)
+
     def forward(self, *inputs):
         if 'input' in self.config['quant_types']:
             inputs = self.quantizer.quant_grad(

@@ -509,12 +524,19 @@ class QuantizerModuleWrapper(torch.nn.Module):
                 self)
 
         if 'weight' in self.config['quant_types'] and _check_weight(self.module):
+            if self.bn_module is not None:
+                # simulate batch normalization folding
+                new_weight, new_bias = self.quantizer.fold_bn(*inputs, wrapper=self)
+                self.module.bias = new_bias
+                self.module.weight = new_weight
+            else:
+                new_weight = self.module.old_weight
             self.quantizer.quant_grad(
-                self.module.old_weight,
+                new_weight,
                 QuantType.QUANT_WEIGHT,
                 self, inputs[0])
             result = self.module(*inputs)
         else:
             result = self.module(*inputs)
 
         if 'output' in self.config['quant_types']:

@@ -525,12 +547,35 @@ class QuantizerModuleWrapper(torch.nn.Module):
         return result
 
+class QuantizerIdentityWrapper(torch.nn.Module):
+    def __init__(self, module, module_name):
+        """
+        Used to wrap modules that should be treated as torch.Identity
+
+        Parameters
+        ----------
+        module : pytorch module
+            the module to be wrapped
+        module_name : str
+            the name of the module to wrapped, wrapper module shares same name
+        """
+        super().__init__()
+        self.module = module
+        self.module_name = module_name
+
+    def forward(self, x):
+        return x
+
+
 class Quantizer(Compressor):
     """
     Base quantizer for pytorch quantizer
     """
 
-    def __init__(self, model, config_list, optimizer=None):
+    def __init__(self, model, config_list, optimizer=None, dummy_input=None):
+        self.identity_wrappers = []
+        self.conv_bn_patterns = {}
+        self.find_conv_bn_patterns(model, dummy_input)
         super().__init__(model, config_list, optimizer)
         self.quant_grad = QuantGrad.apply
         if self.optimizer is not None:

@@ -540,6 +585,10 @@ class Quantizer(Compressor):
                 # old_weight is registered to keep track of weight before quantization
                 # and it is trainable, therefore, it should be added to optimizer.
                 self.optimizer.add_param_group({"params": wrapper.module.old_weight})
+                # This is for conv with bias + bn. Although this situation is relatively rare,
+                # we still need to deal with the old_bias when it occurs
+                if hasattr(wrapper.module, "old_bias"):
+                    self.optimizer.add_param_group({"params": getattr(wrapper.module, "old_bias")})
 
     def quantize_weight(self, wrapper, **kwargs):
         """

@@ -597,7 +646,36 @@ class Quantizer(Compressor):
         for quant_type in config['quant_types']:
             assert quant_type in config['quant_bits'], 'bits length for %s must be specified in quant_bits dict' % quant_type
 
-        return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self)
+        # bound bn module to corresponding conv module
+        bn_module = None
+        if layer.name in self.conv_bn_patterns:
+            bn_module_name = self.conv_bn_patterns[layer.name]
+            for name, module in self.bound_model.named_modules():
+                if name == bn_module_name:
+                    bn_module = module
+                    break
+            assert bn_module is not None, "BN module corresponding to layer {} is not found".format(layer.name)
+            self.identity_wrappers.append(QuantizerIdentityWrapper(bn_module, bn_module_name))
+        return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self, bn_module)
+
+    def _wrap_model(self):
+        """
+        wrap all modules that needed to be compressed
+        """
+        # wrap folded bn in order to bypass its forward process
+        for wrapper in reversed(self.identity_wrappers):
+            _setattr(self.bound_model, wrapper.module_name, wrapper)
+        super()._wrap_model()
+
+    def _unwrap_model(self):
+        """
+        unwrap all modules that needed to be compressed
+        """
+        for wrapper in self.identity_wrappers:
+            _setattr(self.bound_model, wrapper.module_name, wrapper.module)
+        super()._unwrap_model()
 
     def export_model_save(self, model, model_path, calibration_config=None, calibration_path=None, onnx_path=None,
                           input_shape=None, device=None):
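The identity wrappers registered above are what keep a folded BatchNorm from being applied twice: during _wrap_model each detected BN module is replaced by a wrapper whose forward simply returns its input, and _unwrap_model puts the original module back. A minimal standalone sketch of that bypass mechanism follows; the DemoBlock module is an illustrative assumption, not nni code.

    import torch
    import torch.nn as nn

    class IdentityWrapper(nn.Module):
        def __init__(self, module, module_name):
            super().__init__()
            self.module = module            # keep the original BN so it can be restored later
            self.module_name = module_name

        def forward(self, x):
            return x                        # bypass the wrapped module entirely

    class DemoBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 8, 3, padding=1)
            self.bn = nn.BatchNorm2d(8)

        def forward(self, x):
            return self.bn(self.conv(x))

    block = DemoBlock()
    block.bn = IdentityWrapper(block.bn, 'bn')   # what _wrap_model does via _setattr
    out = block(torch.randn(1, 3, 8, 8))         # BN is now a no-op in the forward pass
    block.bn = block.bn.module                   # what _unwrap_model does to restore the BN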
@@ -660,6 +738,30 @@ class Quantizer(Compressor):
         """
         raise NotImplementedError('Quantizer must overload export_model()')
 
+    def find_conv_bn_patterns(self, model, dummy_input):
+        """
+        Find all Conv-BN patterns, used for batch normalization folding
+
+        Parameters
+        ----------
+        model : torch.nn.Module
+            model to be analyzed.
+        dummy_input : tupel of torch.tensor
+            inputs to the model, used for generating the torchscript
+        """
+        if dummy_input is None:
+            _logger.debug("Model inputs are not given, batch normalization folding is disabled")
+            return
+
+        graph = build_module_graph(model, dummy_input)
+        for node_group in graph.nodes_py.nodes_op:
+            if node_group.op_type in BN_FOLD_OP:
+                successors = graph.find_successors(node_group.unique_name)
+                successors = [graph.name_to_node[x] for x in successors]
+                for successor in successors:
+                    if successor.op_type == 'BatchNorm2d':
+                        self.conv_bn_patterns[node_group.name] = successor.name
+
     def step_with_optimizer(self):
         pass
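find_conv_bn_patterns relies on nni's build_module_graph to trace the model with the dummy input. The same Conv-BN detection idea can be illustrated with torch.fx, which is not what the commit uses but makes the pattern explicit: trace the model, then record every Conv2d call whose direct consumer is a BatchNorm2d. The TinyNet module below is an illustrative assumption.

    import torch
    import torch.nn as nn
    from torch.fx import symbolic_trace

    def find_conv_bn_pairs(model):
        """Map each Conv2d submodule name to the BatchNorm2d that directly consumes its output."""
        traced = symbolic_trace(model)
        modules = dict(traced.named_modules())
        pairs = {}
        for node in traced.graph.nodes:
            if node.op == 'call_module' and isinstance(modules[node.target], nn.Conv2d):
                for user in node.users:
                    if user.op == 'call_module' and isinstance(modules[user.target], nn.BatchNorm2d):
                        pairs[node.target] = user.target
        return pairs

    class TinyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(1, 8, 3)
            self.bn = nn.BatchNorm2d(8)

        def forward(self, x):
            return torch.relu(self.bn(self.conv(x)))

    print(find_conv_bn_pairs(TinyNet()))   # expected: {'conv': 'bn'}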
@@ -677,6 +779,9 @@ QType_Dict = {
     2: "output"
 }
 
+BN_FOLD_OP = ["Conv2d"]
+BN_FOLD_TAG = 'BN_FOLD_TAG'
+
 class QuantGrad(torch.autograd.Function):
     """
     Base class for overriding backward function of quantization operation.

@@ -773,6 +878,12 @@ def _check_weight(module):
     except AttributeError:
         return False
 
+def _check_bias(module):
+    try:
+        return isinstance(module.bias.data, torch.Tensor)
+    except AttributeError:
+        return False
+
 def quantize_helper(tensor, quant_type, wrapper, input_tensor=None, **kwargs):
     if quant_type == QuantType.QUANT_INPUT:
         output = wrapper.quantizer.quantize_input(*tensor, wrapper=wrapper, **kwargs)