OpenDAS / nni · Commits

Commit 4f3ee9cb, authored Dec 24, 2019 by Cjkkkk, committed by chicm-ms on Dec 24, 2019
add quantization backward support (#1854)
parent 7a558113
Showing 3 changed files with 74 additions and 43 deletions:
examples/model_compress/QAT_torch_quantizer.py  (+0, -1)
src/sdk/pynni/nni/compression/torch/builtin_quantizers.py  (+4, -7)
src/sdk/pynni/nni/compression/torch/compressor.py  (+70, -35)
examples/model_compress/QAT_torch_quantizer.py
@@ -35,7 +35,6 @@ def train(model, quantizer, device, train_loader, optimizer):
         loss = F.nll_loss(output, target)
         loss.backward()
         optimizer.step()
-        quantizer.step()
         if batch_idx % 100 == 0:
             print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))
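For context, the example script attaches the quantizer to the model before entering this training loop. A minimal sketch of that setup, not taken from this diff, assuming the NNI compression API of this release; the configure_list keys follow the assertions added in compressor.py further down, and the op_types entry follows NNI's usual config_list convention:

    import torch
    from nni.compression.torch.builtin_quantizers import QAT_Quantizer

    # a toy model, purely for illustration
    model = torch.nn.Sequential(torch.nn.Conv2d(1, 8, 3), torch.nn.Flatten(), torch.nn.Linear(8 * 26 * 26, 10))

    # illustrative configuration: quantize weights of conv/linear layers to 8 bits
    configure_list = [{
        'quant_types': ['weight'],
        'quant_bits': {'weight': 8},
        'op_types': ['Conv2d', 'Linear']
    }]
    quantizer = QAT_Quantizer(model, configure_list)
    quantizer.compress()

With the autograd-based weight quantization added in this commit (old_weight is registered as the trainable parameter while weight becomes a buffer), weight updates flow through the normal backward pass, which appears to be why the explicit quantizer.step() call is removed from the loop above.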
src/sdk/pynni/nni/compression/torch/builtin_quantizers.py
@@ -100,7 +100,7 @@ def get_bits_length(config, quant_type):
 class QAT_Quantizer(Quantizer):
-    """Quantizer using the DoReFa scheme, as defined in:
+    """Quantizer defined in:
     Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
     http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
     """
@@ -227,16 +227,13 @@ class DoReFaQuantizer(Quantizer):
     (https://arxiv.org/abs/1606.06160)
     """
     def __init__(self, model, config_list):
-        """
-        config_list: supported keys:
-            - q_bits
-        """
         super().__init__(model, config_list)

     def quantize_weight(self, weight, config, **kwargs):
+        weight_bits = get_bits_length(config, 'weight')
         out = weight.tanh()
         out = out / (2 * out.abs().max()) + 0.5
-        out = self.quantize(out, config['q_bits'])
+        out = self.quantize(out, weight_bits)
         out = 2 * out - 1
         return out
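To make the DoReFa weight transform above concrete: tanh(w) is rescaled into [0, 1], uniformly quantized with weight_bits bits, then mapped back into [-1, 1]. A standalone sketch of the same computation, assuming self.quantize performs the usual round-to-(2^k - 1)-levels uniform quantization (that helper is not shown in this diff):

    import torch

    def dorefa_quantize_weight(weight, weight_bits=8):
        out = weight.tanh()
        out = out / (2 * out.abs().max()) + 0.5   # map into [0, 1]
        scale = 2 ** weight_bits - 1              # assumed behavior of self.quantize
        out = torch.round(out * scale) / scale    # uniform quantization
        return 2 * out - 1                        # map back into [-1, 1]

    print(dorefa_quantize_weight(torch.randn(4, 4), weight_bits=2))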
src/sdk/pynni/nni/compression/torch/compressor.py
@@ -250,6 +250,10 @@ class Quantizer(Compressor):
     Base quantizer for pytorch quantizer
     """
+    def __init__(self, model, config_list):
+        super().__init__(model, config_list)
+        self.quant_grad = QuantGrad
+
     def quantize_weight(self, weight, config, op, op_type, op_name):
         """
         quantize should overload this method to quantize weight.
@@ -262,7 +266,7 @@ class Quantizer(Compressor):
         config : dict
             the configuration for weight quantization
         """
-        raise NotImplementedError("Quantizer must overload quantize_weight()")
+        raise NotImplementedError('Quantizer must overload quantize_weight()')

     def quantize_output(self, output, config, op, op_type, op_name):
         """

@@ -276,7 +280,7 @@ class Quantizer(Compressor):
         config : dict
             the configuration for output quantization
         """
-        raise NotImplementedError("Quantizer must overload quantize_output()")
+        raise NotImplementedError('Quantizer must overload quantize_output()')

     def quantize_input(self, *inputs, config, op, op_type, op_name):
         """

@@ -290,7 +294,7 @@ class Quantizer(Compressor):
         config : dict
             the configuration for inputs quantization
         """
-        raise NotImplementedError("Quantizer must overload quantize_input()")
+        raise NotImplementedError('Quantizer must overload quantize_input()')

     def _instrument_layer(self, layer, config):
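The three hooks above (quantize_weight, quantize_output, quantize_input) are the interface a concrete quantizer implements, as DoReFaQuantizer.quantize_weight does in builtin_quantizers.py. A minimal sketch of a custom quantizer implementing only the weight hook, with a naive symmetric rounding scheme invented for illustration (not code from this commit):

    import torch
    from nni.compression.torch.compressor import Quantizer

    class NaiveSymmetricQuantizer(Quantizer):
        def quantize_weight(self, weight, config, **kwargs):
            # purely illustrative: symmetric per-tensor rounding to 255 levels
            scale = weight.abs().max() / 127
            return torch.round(weight / scale) * scale

Attaching such a quantizer to a model would follow the same compress() flow sketched after the training-loop hunk above.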
@@ -305,62 +309,93 @@ class Quantizer(Compressor):
             the configuration for quantization
         """
         assert layer._forward is None, 'Each model can only be compressed once'
-        assert "quant_types" in config, 'must provide quant_types in config'
-        assert isinstance(config["quant_types"], list), 'quant_types must be list type'
-        assert "quant_bits" in config, 'must provide quant_bits in config'
-        assert isinstance(config["quant_bits"], int) or isinstance(config["quant_bits"], dict), 'quant_bits must be dict type or int type'
+        assert 'quant_types' in config, 'must provide quant_types in config'
+        assert isinstance(config['quant_types'], list), 'quant_types must be list type'
+        assert 'quant_bits' in config, 'must provide quant_bits in config'
+        assert isinstance(config['quant_bits'], int) or isinstance(config['quant_bits'], dict), 'quant_bits must be dict type or int type'

-        if isinstance(config["quant_bits"], dict):
-            for quant_type in config["quant_types"]:
-                assert quant_type in config["quant_bits"], 'bits length for %s must be specified in quant_bits dict' % quant_type
+        if isinstance(config['quant_bits'], dict):
+            for quant_type in config['quant_types']:
+                assert quant_type in config['quant_bits'], 'bits length for %s must be specified in quant_bits dict' % quant_type

-        if 'weight' in config["quant_types"]:
+        if 'weight' in config['quant_types']:
             if not _check_weight(layer.module):
                 _logger.warning('Module %s does not have parameter "weight"', layer.name)
+            else:
+                # old_weight is used to store origin weight and weight is used to store quantized weight
+                # the reason why weight is buffer instead of parameter is because in pytorch parameter is used as leaf
+                # if weight is leaf , then old_weight can not be updated.
+                layer.module.register_parameter('old_weight', torch.nn.Parameter(layer.module.weight))
+                delattr(layer.module, 'weight')
+                layer.module.register_buffer('weight', layer.module.old_weight)

         layer._forward = layer.module.forward

         def new_forward(*inputs):
-            if 'input' in config["quant_types"]:
-                inputs = straight_through_quantize_input.apply(inputs, self, config, layer)
+            if 'input' in config['quant_types']:
+                inputs = self.quant_grad.apply(inputs, QuantType.QUANT_INPUT, self.quantize_input, config, layer)

-            if 'weight' in config["quant_types"] and _check_weight(layer.module):
-                weight = layer.module.weight.data
-                new_weight = self.quantize_weight(weight, config, op=layer.module, op_type=layer.type, op_name=layer.name)
-                layer.module.weight.data = new_weight
+            if 'weight' in config['quant_types'] and _check_weight(layer.module):
+                new_weight = self.quant_grad.apply(layer.module.old_weight, QuantType.QUANT_WEIGHT, self.quantize_weight, config, layer)
+                layer.module.weight = new_weight
                 result = layer._forward(*inputs)
-                layer.module.weight.data = weight
             else:
                 result = layer._forward(*inputs)

-            if 'output' in config["quant_types"]:
-                result = straight_through_quantize_output.apply(result, self, config, layer)
+            if 'output' in config['quant_types']:
+                result = self.quant_grad.apply(result, QuantType.QUANT_OUTPUT, self.quantize_output, config, layer)
             return result

         layer.module.forward = new_forward

+class QuantType:
+    """
+    Enum class for quantization type.
+    """
+    QUANT_INPUT = 0
+    QUANT_WEIGHT = 1
+    QUANT_OUTPUT = 2
+
-class straight_through_quantize_output(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, output, quantizer, config, layer):
-        return quantizer.quantize_output(output, config, op=layer.module, op_type=layer.type, op_name=layer.name)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        # Straight-through estimator
-        return grad_output, None, None, None
-
-class straight_through_quantize_input(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, inputs, quantizer, config, layer):
-        return quantizer.quantize_input(inputs, config, op=layer.module, op_type=layer.type, op_name=layer.name)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        # Straight-through estimator
-        return grad_output, None, None, None
+class QuantGrad(torch.autograd.Function):
+    """
+    Base class for overriding backward function of quantization operation.
+    """
+    @staticmethod
+    def quant_backward(tensor, grad_output, quant_type):
+        """
+        This method should be overrided by subclass to provide customized backward function,
+        default implementation is Straight-Through Estimator
+
+        Parameters
+        ----------
+        tensor : Tensor
+            input of quantization operation
+        grad_output : Tensor
+            gradient of the output of quantization operation
+        quant_type : QuantType
+            the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, `QuantType.QUANT_OUTPUT`,
+            you can define different behavior for different types.
+
+        Returns
+        -------
+        tensor
+            gradient of the input of quantization operation
+        """
+        return grad_output
+
+    @staticmethod
+    def forward(ctx, tensor, quant_type, quant_func, config, layer):
+        ctx.save_for_backward(tensor, torch.Tensor([quant_type]))
+        return quant_func(tensor, config, op=layer.module, op_type=layer.type, op_name=layer.name)
+
+    @classmethod
+    def backward(cls, ctx, grad_output):
+        tensor, quant_type = ctx.saved_variables
+        output = cls.quant_backward(tensor, grad_output, quant_type)
+        return output, None, None, None, None

 def _check_weight(module):
     try:
-        return isinstance(module.weight, torch.nn.Parameter) and isinstance(module.weight.data, torch.Tensor)
+        return isinstance(module.weight.data, torch.Tensor)
     except AttributeError:
         return False
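The quant_backward docstring above says subclasses supply their own backward rule and can branch on quant_type, and the Quantizer.__init__ change earlier in this diff stores the chosen autograd function on self.quant_grad. A hedged sketch of wiring in a custom rule; the clipping behavior and class names are invented for illustration, not part of this commit:

    from nni.compression.torch.compressor import QuantGrad, QuantType
    from nni.compression.torch.builtin_quantizers import QAT_Quantizer

    class ClippedGrad(QuantGrad):
        @staticmethod
        def quant_backward(tensor, grad_output, quant_type):
            # quant_type arrives as the one-element tensor saved in forward()
            if int(quant_type) == QuantType.QUANT_WEIGHT:
                return grad_output  # plain straight-through for weights
            # clipped straight-through elsewhere: zero the gradient where |x| > 1
            return grad_output * (tensor.abs() <= 1).type_as(grad_output)

    class ClippedQATQuantizer(QAT_Quantizer):
        def __init__(self, model, config_list):
            super().__init__(model, config_list)
            self.quant_grad = ClippedGrad  # replace the default straight-through estimator

Because backward() is a classmethod that dispatches to cls.quant_backward, the instrumented forward pass picks up the custom gradient rule without any other change to the quantizer.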