Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
2f6a74f1
Unverified
Commit
2f6a74f1
authored
Dec 08, 2020
by
Dalong
Committed by
GitHub
Dec 08, 2020
Browse files
fix quant grad function calculation error (#3160)
parent
78e874f9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
62 additions
and
11 deletions
+62
-11
nni/compression/pytorch/compressor.py
nni/compression/pytorch/compressor.py
+62
-11
No files found.
nni/compression/pytorch/compressor.py
View file @
2f6a74f1
...
...
@@ -580,17 +580,55 @@ class QuantType:
class QuantType:
    """
    Enum class for quantization type.

    The three values name the quantization targets and double as the
    lookup keys used in ``quant_bits`` configuration dicts.
    """
    QUANT_INPUT = 'input'
    QUANT_WEIGHT = 'weight'
    QUANT_OUTPUT = 'output'
class QuantGrad(torch.autograd.Function):
    """
    Base class for overriding backward function of quantization operation.

    Subclasses customize the gradient by overriding ``quant_backward``; the
    default implementation is a Straight-Through Estimator that zeroes the
    gradient wherever the (unclamped) quantized value falls outside
    [qmin, qmax].
    """

    @classmethod
    def _quantize(cls, x, scale, zero_point):
        """
        Reference function for quantizing x -- non-clamped.
        Parameters
        ----------
        x : Tensor
            tensor to be quantized
        scale : Tensor
            scale for quantizing x
        zero_point : Tensor
            zero_point for quantizing x
        Returns
        -------
        tensor
            quantized x without clamped
        """
        return ((x / scale) + zero_point).round()

    @classmethod
    def get_bits_length(cls, config, quant_type):
        """
        Get bit for quantize config
        Parameters
        ----------
        config : Dict
            the configuration for quantization
        quant_type : str
            quant type
        Returns
        -------
        int
            n-bits for quantization configuration
        """
        # "quant_bits" is either a single int (applies to every quant type)
        # or a per-type dict keyed by the QuantType string constants.
        if isinstance(config["quant_bits"], int):
            return config["quant_bits"]
        else:
            return config["quant_bits"].get(quant_type)

    @staticmethod
    def quant_backward(tensor, grad_output, scale, zero_point, qmin, qmax):
        """
        This method should be overrided by subclass to provide customized backward function,
        default implementation is Straight-Through Estimator
        Parameters
        ----------
        tensor : Tensor
            input of quantization operation
        grad_output : Tensor
            gradient of the output of quantization operation
        scale : Tensor
            scale for quantizing tensor
        zero_point : Tensor
            zero_point for quantizing tensor
        qmin : Tensor
            quant_min for quantizing tensor
        qmax : Tensor
            quant_max for quantizing tensor
        Returns
        -------
        tensor
            gradient of the input of quantization operation
        """
        # STE: pass the gradient straight through where the quantized value
        # lies inside [qmin, qmax]; zero it where it saturates.
        tensor_q = QuantGrad._quantize(tensor, scale, zero_point)
        mask = (tensor_q < qmin) | (tensor_q > qmax)
        # FIX: autograd forbids modifying grad_output in place inside
        # backward; the original did `grad_output[mask] = 0` directly, which
        # can corrupt gradients shared with other consumers. Clone first.
        grad_input = grad_output.clone()
        grad_input[mask] = 0
        return grad_input

    @staticmethod
    def forward(ctx, tensor, quant_type, wrapper, **kwargs):
        """
        Run the quantizer's forward for the given quant type and stash the
        tensors needed by ``quant_backward`` (input, scale, zero_point and
        the representable range [qmin, qmax]).
        """
        if quant_type == QuantType.QUANT_INPUT:
            output = wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
        elif quant_type == QuantType.QUANT_WEIGHT:
            output = wrapper.quantizer.quantize_weight(wrapper, **kwargs)
        elif quant_type == QuantType.QUANT_OUTPUT:
            output = wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
        else:
            raise ValueError("unrecognized QuantType.")

        bits = QuantGrad.get_bits_length(wrapper.config, quant_type)
        # FIX: the legacy `torch.Tensor(data, device=...)` constructor rejects
        # a non-CPU device when data is supplied; use the `torch.tensor`
        # factory instead (float values keep the original float32 dtype).
        qmin = torch.tensor([0.], device=tensor.device)
        qmax = torch.tensor([float((1 << bits) - 1)], device=tensor.device)
        # NOTE(review): assumes the quantizer has populated `scale` and
        # `zero_point` tensors on wrapper.module -- confirm against the
        # quantizer implementations elsewhere in this file.
        ctx.save_for_backward(tensor, wrapper.module.scale, wrapper.module.zero_point, qmin, qmax)
        return output

    @classmethod
    def backward(cls, ctx, grad_output):
        """
        Dispatch to (possibly overridden) ``quant_backward`` using the
        tensors saved in ``forward``.
        """
        # FIX: ctx.saved_variables is deprecated; ctx.saved_tensors is the
        # supported accessor and returns the same tuple.
        tensor, scale, zero_point, qmin, qmax = ctx.saved_tensors
        output = cls.quant_backward(tensor, grad_output, scale, zero_point, qmin, qmax)
        # One gradient slot per forward input (tensor, quant_type, wrapper,
        # **kwargs); only `tensor` receives a gradient.
        return output, None, None, None
def
_check_weight
(
module
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment