Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
7b16db17
Unverified
Commit
7b16db17
authored
Sep 08, 2021
by
lin bin
Committed by
GitHub
Sep 08, 2021
Browse files
[Quantization] support bn-folding for lsq (#4148)
parent
19914055
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
54 additions
and
37 deletions
+54
-37
nni/algorithms/compression/pytorch/quantization/quantizers.py
...algorithms/compression/pytorch/quantization/quantizers.py
+20
-37
nni/compression/pytorch/compressor.py
nni/compression/pytorch/compressor.py
+34
-0
No files found.
nni/algorithms/compression/pytorch/quantization/quantizers.py
View file @
7b16db17
...
@@ -635,40 +635,6 @@ class QAT_Quantizer(Quantizer):
...
@@ -635,40 +635,6 @@ class QAT_Quantizer(Quantizer):
return
calibration_config
return
calibration_config
def
fold_bn
(
self
,
*
inputs
,
wrapper
):
"""
Simulate batch normalization folding in the training graph. Folded weight and bias are
returned for the following operations.
Parameters
----------
inputs : tuple of torch.Tensor
inputs for the module
wrapper : QuantizerModuleWrapper
the wrapper for origin module
Returns
-------
Tuple of torch.Tensor
"""
module
=
wrapper
.
module
bn_module
=
wrapper
.
bn_module
with
torch
.
no_grad
():
output
=
module
(
*
inputs
)
_
=
bn_module
(
output
)
running_mean
=
bn_module
.
running_mean
running_var
=
torch
.
sqrt
(
bn_module
.
running_var
+
bn_module
.
eps
)
bn_weight
=
bn_module
.
weight
bn_bias
=
bn_module
.
bias
dimensions
=
len
(
module
.
weight
.
shape
)
shape
=
[
-
1
]
+
[
1
]
*
(
dimensions
-
1
)
new_weight
=
module
.
old_weight
*
bn_weight
.
reshape
(
shape
)
/
running_var
.
reshape
(
shape
)
if
hasattr
(
module
,
'old_bias'
):
new_bias
=
bn_bias
+
(
module
.
old_bias
-
running_mean
)
/
running_var
*
bn_weight
else
:
new_bias
=
bn_bias
-
running_mean
/
running_var
*
bn_weight
return
new_weight
,
new_bias
def
step_with_optimizer
(
self
):
def
step_with_optimizer
(
self
):
"""
"""
override `compressor` `step` method, quantization only happens after certain number of steps
override `compressor` `step` method, quantization only happens after certain number of steps
...
@@ -890,7 +856,7 @@ class LsqQuantizer(Quantizer):
...
@@ -890,7 +856,7 @@ class LsqQuantizer(Quantizer):
https://arxiv.org/pdf/1902.08153.pdf
https://arxiv.org/pdf/1902.08153.pdf
"""
"""
def
__init__
(
self
,
model
,
config_list
,
optimizer
):
def
__init__
(
self
,
model
,
config_list
,
optimizer
,
dummy_input
=
None
):
"""
"""
Parameters
Parameters
----------
----------
...
@@ -909,9 +875,13 @@ class LsqQuantizer(Quantizer):
...
@@ -909,9 +875,13 @@ class LsqQuantizer(Quantizer):
state where output quantization ranges do not exclude a significant fraction of values, default value is 0
state where output quantization ranges do not exclude a significant fraction of values, default value is 0
- op_types : list of string
- op_types : list of string
types of nn.module you want to apply quantization, eg. 'Conv2d'
types of nn.module you want to apply quantization, eg. 'Conv2d'
- dummy_input : tuple of tensor
inputs to the model, which are used to get the graph of the module. The graph is used to find
Conv-Bn patterns. And then the batch normalization folding would be enabled. If dummy_input is not
given, the batch normalization folding would be disabled.
"""
"""
assert
isinstance
(
optimizer
,
torch
.
optim
.
Optimizer
),
"unrecognized optimizer type"
assert
isinstance
(
optimizer
,
torch
.
optim
.
Optimizer
),
"unrecognized optimizer type"
super
().
__init__
(
model
,
config_list
,
optimizer
)
super
().
__init__
(
model
,
config_list
,
optimizer
,
dummy_input
)
device
=
next
(
model
.
parameters
()).
device
device
=
next
(
model
.
parameters
()).
device
self
.
quant_grad
=
QuantForward
()
self
.
quant_grad
=
QuantForward
()
modules_to_compress
=
self
.
get_modules_to_compress
()
modules_to_compress
=
self
.
get_modules_to_compress
()
...
@@ -1057,6 +1027,19 @@ class LsqQuantizer(Quantizer):
...
@@ -1057,6 +1027,19 @@ class LsqQuantizer(Quantizer):
abs_max_input
=
float
(
module
.
input_scale
*
module
.
input_qmax
)
abs_max_input
=
float
(
module
.
input_scale
*
module
.
input_qmax
)
calibration_config
[
name
][
'tracked_min_input'
]
=
-
abs_max_input
calibration_config
[
name
][
'tracked_min_input'
]
=
-
abs_max_input
calibration_config
[
name
][
'tracked_max_input'
]
=
abs_max_input
calibration_config
[
name
][
'tracked_max_input'
]
=
abs_max_input
actual_weight
=
getattr
(
module
,
'old_weight'
,
None
)
if
actual_weight
is
None
:
logger
.
warning
(
"Can not recover weight for layer %s. "
"This may lead to a wrong accuracy performance on the backend."
,
name
)
delattr
(
module
,
'weight'
)
module
.
register_parameter
(
'weight'
,
actual_weight
)
if
hasattr
(
module
,
BN_FOLD_TAG
):
actual_bias
=
getattr
(
module
,
'old_bias'
,
None
)
delattr
(
module
,
'bias'
)
if
actual_bias
is
not
None
:
module
.
register_parameter
(
'bias'
,
actual_bias
)
else
:
setattr
(
module
,
'bias'
,
None
)
if
hasattr
(
module
,
'output_bits'
):
if
hasattr
(
module
,
'output_bits'
):
calibration_config
[
name
][
'output_bits'
]
=
int
(
module
.
output_bits
)
calibration_config
[
name
][
'output_bits'
]
=
int
(
module
.
output_bits
)
abs_max_output
=
float
(
module
.
output_scale
*
module
.
output_qmax
)
abs_max_output
=
float
(
module
.
output_scale
*
module
.
output_qmax
)
...
@@ -1074,7 +1057,7 @@ class LsqQuantizer(Quantizer):
...
@@ -1074,7 +1057,7 @@ class LsqQuantizer(Quantizer):
delete redundant parameters in quantize module
delete redundant parameters in quantize module
"""
"""
del_attr_list
=
[
'old_weight'
,
'tracked_min_input'
,
'tracked_max_input'
,
'tracked_min_output'
,
\
del_attr_list
=
[
'old_weight'
,
'tracked_min_input'
,
'tracked_max_input'
,
'tracked_min_output'
,
\
'tracked_max_output'
,
'output_scale'
,
'input_scale'
,
'weight_scale'
,
'weight_bits'
,
'output_bits'
,
'input_bits'
]
'tracked_max_output'
,
'output_scale'
,
'input_scale'
,
'weight_scale'
,
'weight_bits'
,
'output_bits'
,
'input_bits'
,
'BN_FOLD_TAG'
]
for
attr
in
del_attr_list
:
for
attr
in
del_attr_list
:
if
hasattr
(
module
,
attr
):
if
hasattr
(
module
,
attr
):
delattr
(
module
,
attr
)
delattr
(
module
,
attr
)
...
...
nni/compression/pytorch/compressor.py
View file @
7b16db17
...
@@ -658,6 +658,40 @@ class Quantizer(Compressor):
...
@@ -658,6 +658,40 @@ class Quantizer(Compressor):
"""
"""
raise
NotImplementedError
(
'Quantizer must overload quantize_input()'
)
raise
NotImplementedError
(
'Quantizer must overload quantize_input()'
)
def
fold_bn
(
self
,
*
inputs
,
wrapper
):
"""
Simulate batch normalization folding in the training graph. Folded weight and bias are
returned for the following operations.
Parameters
----------
inputs : tuple of torch.Tensor
inputs for the module
wrapper : QuantizerModuleWrapper
the wrapper for origin module
Returns
-------
Tuple of torch.Tensor
"""
module
=
wrapper
.
module
bn_module
=
wrapper
.
bn_module
with
torch
.
no_grad
():
output
=
module
(
*
inputs
)
_
=
bn_module
(
output
)
running_mean
=
bn_module
.
running_mean
running_var
=
torch
.
sqrt
(
bn_module
.
running_var
+
bn_module
.
eps
)
bn_weight
=
bn_module
.
weight
bn_bias
=
bn_module
.
bias
dimensions
=
len
(
module
.
weight
.
shape
)
shape
=
[
-
1
]
+
[
1
]
*
(
dimensions
-
1
)
new_weight
=
module
.
old_weight
*
bn_weight
.
reshape
(
shape
)
/
running_var
.
reshape
(
shape
)
if
hasattr
(
module
,
'old_bias'
):
new_bias
=
bn_bias
+
(
module
.
old_bias
-
running_mean
)
/
running_var
*
bn_weight
else
:
new_bias
=
bn_bias
-
running_mean
/
running_var
*
bn_weight
return
new_weight
,
new_bias
def
_wrap_modules
(
self
,
layer
,
config
):
def
_wrap_modules
(
self
,
layer
,
config
):
"""
"""
Create a wrapper forward function to replace the original one.
Create a wrapper forward function to replace the original one.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment