Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
3b1d5cd4
Unverified
Commit
3b1d5cd4
authored
Jan 05, 2021
by
lin bin
Committed by
GitHub
Jan 05, 2021
Browse files
Fix dorefa bnn (#3247)
parent
0a20c3fc
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
31 additions
and
13 deletions
+31
-13
nni/algorithms/compression/pytorch/quantization/quantizers.py
...algorithms/compression/pytorch/quantization/quantizers.py
+11
-1
nni/compression/pytorch/compressor.py
nni/compression/pytorch/compressor.py
+20
-12
No files found.
nni/algorithms/compression/pytorch/quantization/quantizers.py
View file @
3b1d5cd4
...
@@ -111,6 +111,15 @@ def get_bits_length(config, quant_type):
...
@@ -111,6 +111,15 @@ def get_bits_length(config, quant_type):
return
config
[
"quant_bits"
].
get
(
quant_type
)
return
config
[
"quant_bits"
].
get
(
quant_type
)
class QATGrad(QuantGrad):
    """Gradient rule for quantization-aware training (QAT).

    Implements the clipped straight-through estimator: the incoming
    gradient is passed through unchanged except at positions where the
    fake-quantized input fell outside the representable range
    ``[qmin, qmax]``, where the gradient is zeroed.
    """

    @staticmethod
    def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax):
        # Re-quantize the saved input to discover which elements were clamped.
        tensor_q = QuantGrad._quantize(tensor, scale, zero_point)
        # Out-of-range positions contribute no gradient.
        out_of_range = (tensor_q < qmin) | (tensor_q > qmax)
        grad_output[out_of_range] = 0
        return grad_output
class
QAT_Quantizer
(
Quantizer
):
class
QAT_Quantizer
(
Quantizer
):
"""Quantizer defined in:
"""Quantizer defined in:
Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
...
@@ -138,6 +147,7 @@ class QAT_Quantizer(Quantizer):
...
@@ -138,6 +147,7 @@ class QAT_Quantizer(Quantizer):
types of nn.module you want to apply quantization, eg. 'Conv2d'
types of nn.module you want to apply quantization, eg. 'Conv2d'
"""
"""
super
().
__init__
(
model
,
config_list
,
optimizer
)
super
().
__init__
(
model
,
config_list
,
optimizer
)
self
.
quant_grad
=
QATGrad
modules_to_compress
=
self
.
get_modules_to_compress
()
modules_to_compress
=
self
.
get_modules_to_compress
()
self
.
bound_model
.
register_buffer
(
"steps"
,
torch
.
Tensor
([
1
]))
self
.
bound_model
.
register_buffer
(
"steps"
,
torch
.
Tensor
([
1
]))
for
layer
,
config
in
modules_to_compress
:
for
layer
,
config
in
modules_to_compress
:
...
@@ -331,7 +341,7 @@ class DoReFaQuantizer(Quantizer):
...
@@ -331,7 +341,7 @@ class DoReFaQuantizer(Quantizer):
class ClipGrad(QuantGrad):
    """Gradient rule used with DoReFa-style quantizers.

    For output (activation) quantization, the gradient is zeroed wherever
    the pre-quantization activation magnitude exceeds 1; for all other
    quantization types the gradient passes straight through.
    """

    @staticmethod
    def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax):
        if quant_type == QuantType.QUANT_OUTPUT:
            # Clip-STE: no gradient flows where |activation| > 1.
            grad_output[torch.abs(tensor) > 1] = 0
        return grad_output
...
...
nni/compression/pytorch/compressor.py
View file @
3b1d5cd4
...
class QuantType:
    """
    Enum-like namespace of quantization-type codes.

    The codes are plain integers (not strings) so a type can be stored
    in a tensor and round-tripped through autograd's save_for_backward.
    """
    QUANT_INPUT = 0
    QUANT_WEIGHT = 1
    QUANT_OUTPUT = 2
# Reverse lookup: integer quantization-type code -> human-readable name.
# Index order matches the QuantType constants (0=input, 1=weight, 2=output).
QType_Dict = dict(enumerate(("input", "weight", "output")))
class
QuantGrad
(
torch
.
autograd
.
Function
):
class
QuantGrad
(
torch
.
autograd
.
Function
):
"""
"""
...
@@ -628,7 +633,7 @@ class QuantGrad(torch.autograd.Function):
...
@@ -628,7 +633,7 @@ class QuantGrad(torch.autograd.Function):
return
config
[
"quant_bits"
].
get
(
quant_type
)
return
config
[
"quant_bits"
].
get
(
quant_type
)
@staticmethod
def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax):
    """
    This method should be overridden by a subclass to provide a customized
    backward function; the default implementation is the Straight-Through
    Estimator (gradient passes through unchanged).

    Parameters
    ----------
    tensor : Tensor
        input of the quantization operation
    grad_output : Tensor
        gradient of the output of the quantization operation
    quant_type : int
        quantization type code (QuantType.QUANT_INPUT / QUANT_WEIGHT /
        QUANT_OUTPUT)
    scale : Tensor
        quantization scale (may be None when the module defines none)
    zero_point : Tensor
        quantization zero point (may be None when the module defines none)
    qmin : Tensor
        lower bound of the quantized value range
    qmax : Tensor
        upper bound of the quantized value range

    Returns
    -------
    tensor
        gradient of the input of quantization operation
    """
    # Straight-Through Estimator: propagate the gradient unchanged.
    return grad_output
@
staticmethod
@
staticmethod
...
@@ -668,15 +670,21 @@ class QuantGrad(torch.autograd.Function):
...
@@ -668,15 +670,21 @@ class QuantGrad(torch.autograd.Function):
else
:
else
:
raise
ValueError
(
"unrecognized QuantType."
)
raise
ValueError
(
"unrecognized QuantType."
)
bits
=
QuantGrad
.
get_bits_length
(
wrapper
.
config
,
quant_type
)
qmin
,
qmax
=
torch
.
Tensor
([
0
]).
to
(
device
=
tensor
.
device
),
torch
.
Tensor
([(
1
<<
bits
)
-
1
]).
to
(
device
=
tensor
.
device
)
bits
=
QuantGrad
.
get_bits_length
(
wrapper
.
config
,
QType_Dict
[
quant_type
])
ctx
.
save_for_backward
(
tensor
,
wrapper
.
module
.
scale
,
wrapper
.
module
.
zero_point
,
qmin
,
qmax
)
qmin
,
qmax
=
torch
.
Tensor
([
0
]).
to
(
tensor
.
device
),
torch
.
Tensor
([(
1
<<
bits
)
-
1
]).
to
(
tensor
.
device
)
if
hasattr
(
wrapper
.
module
,
'scale'
)
and
hasattr
(
wrapper
.
module
,
'zero_point'
):
scale
=
wrapper
.
module
.
scale
zero_point
=
wrapper
.
module
.
zero_point
else
:
scale
,
zero_point
=
None
,
None
ctx
.
save_for_backward
(
tensor
,
torch
.
Tensor
([
quant_type
]),
scale
,
zero_point
,
qmin
,
qmax
)
return
output
return
output
@classmethod
def backward(cls, ctx, grad_output):
    """Autograd backward: delegate to the subclass's quant_backward rule."""
    # Unpack exactly what forward() stashed via ctx.save_for_backward:
    # (tensor, quant_type, scale, zero_point, qmin, qmax).
    tensor, quant_type, scale, zero_point, qmin, qmax = ctx.saved_variables
    output = cls.quant_backward(
        tensor, grad_output, quant_type, scale, zero_point, qmin, qmax)
    # Gradient for the tensor input only; None for the remaining
    # (non-differentiable) forward() arguments.
    return output, None, None, None
def
_check_weight
(
module
):
def
_check_weight
(
module
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment