OpenDAS / nni · Commits

Commit 7b16db17 (unverified)
Authored Sep 08, 2021 by lin bin; committed by GitHub on Sep 08, 2021
[Quantization] support bn-folding for lsq (#4148)
Parent: 19914055

Changes: 2 changed files, +54 −37

    nni/algorithms/compression/pytorch/quantization/quantizers.py  (+20 −37)
    nni/compression/pytorch/compressor.py                          (+34 −0)
nni/algorithms/compression/pytorch/quantization/quantizers.py
@@ -635,40 +635,6 @@ class QAT_Quantizer(Quantizer):
         return calibration_config
 
-    def fold_bn(self, *inputs, wrapper):
-        """
-        Simulate batch normalization folding in the training graph. Folded weight and bias are
-        returned for the following operations.
-
-        Parameters
-        ----------
-        inputs : tuple of torch.Tensor
-            inputs for the module
-        wrapper : QuantizerModuleWrapper
-            the wrapper for origin module
-
-        Returns
-        -------
-        Tuple of torch.Tensor
-        """
-        module = wrapper.module
-        bn_module = wrapper.bn_module
-        with torch.no_grad():
-            output = module(*inputs)
-            _ = bn_module(output)
-        running_mean = bn_module.running_mean
-        running_var = torch.sqrt(bn_module.running_var + bn_module.eps)
-        bn_weight = bn_module.weight
-        bn_bias = bn_module.bias
-        dimensions = len(module.weight.shape)
-        shape = [-1] + [1] * (dimensions - 1)
-        new_weight = module.old_weight * bn_weight.reshape(shape) / running_var.reshape(shape)
-        if hasattr(module, 'old_bias'):
-            new_bias = bn_bias + (module.old_bias - running_mean) / running_var * bn_weight
-        else:
-            new_bias = bn_bias - running_mean / running_var * bn_weight
-        return new_weight, new_bias
-
     def step_with_optimizer(self):
         """
         override `compressor` `step` method, quantization only happens after certain number of steps
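For reference (a note, not part of the commit): the arithmetic in fold_bn, here moved out of QAT_Quantizer, follows the standard BN-folding identities. With BN scale \gamma, shift \beta, running mean \mu, running variance \sigma^2, and a preceding layer computing W u + b, the folded parameters are

    W_{fold} = W \cdot \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}, \qquad
    b_{fold} = \beta + \frac{\gamma (b - \mu)}{\sqrt{\sigma^2 + \epsilon}}

and, for a layer without bias, b_{fold} = \beta - \frac{\gamma \mu}{\sqrt{\sigma^2 + \epsilon}}. In the code, the variable running_var already holds \sqrt{\sigma^2 + \epsilon}, and shape reshapes the per-channel factor so it broadcasts over the weight's output-channel dimension.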
@@ -890,7 +856,7 @@ class LsqQuantizer(Quantizer):
     https://arxiv.org/pdf/1902.08153.pdf
     """
 
-    def __init__(self, model, config_list, optimizer):
+    def __init__(self, model, config_list, optimizer, dummy_input=None):
         """
         Parameters
         ----------
@@ -909,9 +875,13 @@ class LsqQuantizer(Quantizer):
                 state where output quantization ranges do not exclude a significant fraction of values, default value is 0
             - op_types : list of string
                 types of nn.module you want to apply quantization, eg. 'Conv2d'
+        - dummy_input : tuple of tensor
+            inputs to the model, which are used to get the graph of the module. The graph is used to find
+            Conv-BN patterns so that batch normalization folding can be enabled. If dummy_input is not
+            given, batch normalization folding is disabled.
         """
 
         assert isinstance(optimizer, torch.optim.Optimizer), "unrecognized optimizer type"
-        super().__init__(model, config_list, optimizer)
+        super().__init__(model, config_list, optimizer, dummy_input)
         device = next(model.parameters()).device
         self.quant_grad = QuantForward()
         modules_to_compress = self.get_modules_to_compress()
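A minimal usage sketch (not from the commit) of how the new dummy_input argument could be passed; the model, optimizer, and config_list values here are illustrative assumptions, and the config keys follow the usual NNI quantizer config schema:

    # Hypothetical usage: enabling BN folding in LsqQuantizer via dummy_input.
    import torch
    import torch.nn as nn
    from nni.algorithms.compression.pytorch.quantization import LsqQuantizer

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    config_list = [{
        'quant_types': ['weight', 'output'],
        'quant_bits': {'weight': 8, 'output': 8},
        'op_types': ['Conv2d'],
    }]

    # With dummy_input the quantizer can trace the graph and locate Conv-BN
    # pairs to fold; without it, BN folding stays disabled. The docstring
    # describes a tuple of tensors; a single tensor is assumed to work for a
    # single-input model.
    dummy_input = torch.randn(1, 3, 32, 32)
    quantizer = LsqQuantizer(model, config_list, optimizer, dummy_input=dummy_input)
    quantizer.compress()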
@@ -1057,6 +1027,19 @@ class LsqQuantizer(Quantizer):
                 abs_max_input = float(module.input_scale * module.input_qmax)
                 calibration_config[name]['tracked_min_input'] = -abs_max_input
                 calibration_config[name]['tracked_max_input'] = abs_max_input
+                actual_weight = getattr(module, 'old_weight', None)
+                if actual_weight is None:
+                    logger.warning("Can not recover weight for layer %s. "
+                                   "This may lead to a wrong accuracy performance on the backend.", name)
+                delattr(module, 'weight')
+                module.register_parameter('weight', actual_weight)
+                if hasattr(module, BN_FOLD_TAG):
+                    actual_bias = getattr(module, 'old_bias', None)
+                    delattr(module, 'bias')
+                    if actual_bias is not None:
+                        module.register_parameter('bias', actual_bias)
+                    else:
+                        setattr(module, 'bias', None)
             if hasattr(module, 'output_bits'):
                 calibration_config[name]['output_bits'] = int(module.output_bits)
                 abs_max_output = float(module.output_scale * module.output_qmax)
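As context for the recovery logic above (an illustration, not from the commit): during training the wrapper keeps the trainable parameter under old_weight and exposes a derived plain tensor under weight, so export must delete the plain attribute and re-register the parameter. A standalone sketch of that pattern:

    import torch
    import torch.nn as nn

    m = nn.Linear(4, 2)

    # Wrapper-style setup: stash the real parameter, expose a derived tensor.
    m.register_parameter('old_weight', m.weight)   # same Parameter, second name
    delattr(m, 'weight')                           # drop the 'weight' parameter slot
    m.weight = m.old_weight.detach() * 0.5         # plain tensor stand-in (e.g. quantized)

    # Export-style recovery, mirroring the diff above.
    actual_weight = getattr(m, 'old_weight', None)
    delattr(m, 'weight')                           # remove the plain tensor attribute
    m.register_parameter('weight', actual_weight)  # 'weight' is a Parameter again
    assert isinstance(m.weight, nn.Parameter)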
@@ -1074,7 +1057,7 @@ class LsqQuantizer(Quantizer):
         delete redundant parameters in quantize module
         """
         del_attr_list = ['old_weight', 'tracked_min_input', 'tracked_max_input', 'tracked_min_output', \
-            'tracked_max_output', 'output_scale', 'input_scale', 'weight_scale', 'weight_bits', 'output_bits', 'input_bits']
+            'tracked_max_output', 'output_scale', 'input_scale', 'weight_scale', 'weight_bits', 'output_bits', 'input_bits', 'BN_FOLD_TAG']
         for attr in del_attr_list:
             if hasattr(module, attr):
                 delattr(module, attr)
nni/compression/pytorch/compressor.py
@@ -658,6 +658,40 @@ class Quantizer(Compressor):
         """
         raise NotImplementedError('Quantizer must overload quantize_input()')
 
+    def fold_bn(self, *inputs, wrapper):
+        """
+        Simulate batch normalization folding in the training graph. Folded weight and bias are
+        returned for the following operations.
+
+        Parameters
+        ----------
+        inputs : tuple of torch.Tensor
+            inputs for the module
+        wrapper : QuantizerModuleWrapper
+            the wrapper for origin module
+
+        Returns
+        -------
+        Tuple of torch.Tensor
+        """
+        module = wrapper.module
+        bn_module = wrapper.bn_module
+        with torch.no_grad():
+            output = module(*inputs)
+            _ = bn_module(output)
+        running_mean = bn_module.running_mean
+        running_var = torch.sqrt(bn_module.running_var + bn_module.eps)
+        bn_weight = bn_module.weight
+        bn_bias = bn_module.bias
+        dimensions = len(module.weight.shape)
+        shape = [-1] + [1] * (dimensions - 1)
+        new_weight = module.old_weight * bn_weight.reshape(shape) / running_var.reshape(shape)
+        if hasattr(module, 'old_bias'):
+            new_bias = bn_bias + (module.old_bias - running_mean) / running_var * bn_weight
+        else:
+            new_bias = bn_bias - running_mean / running_var * bn_weight
+        return new_weight, new_bias
+
     def _wrap_modules(self, layer, config):
         """
         Create a wrapper forward function to replace the original one.
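A standalone sanity check (a sketch, not part of the commit) that the formulas in fold_bn reproduce a Conv2d followed by BatchNorm2d when BN is evaluated with its running statistics; the layer sizes and randomized statistics below are arbitrary:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    torch.manual_seed(0)
    conv = nn.Conv2d(3, 8, 3)
    bn = nn.BatchNorm2d(8)
    with torch.no_grad():               # randomize BN stats so the check is non-trivial
        bn.running_mean.uniform_(-1, 1)
        bn.running_var.uniform_(0.5, 2.0)
        bn.weight.uniform_(0.5, 1.5)
        bn.bias.uniform_(-0.5, 0.5)
    bn.eval()                           # use running statistics, as a deployed graph would
    x = torch.randn(2, 3, 16, 16)
    ref = bn(conv(x))

    # Same arithmetic as fold_bn above.
    std = torch.sqrt(bn.running_var + bn.eps)
    shape = [-1] + [1] * (len(conv.weight.shape) - 1)   # broadcast over output channels
    folded_w = conv.weight * (bn.weight / std).reshape(shape)
    folded_b = bn.bias + (conv.bias - bn.running_mean) / std * bn.weight

    out = F.conv2d(x, folded_w, folded_b, stride=conv.stride,
                   padding=conv.padding, dilation=conv.dilation)
    print(torch.allclose(ref, out, atol=1e-5))          # True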