Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
39e3a990
Unverified
Commit
39e3a990
authored
Aug 17, 2021
by
chenbohua3
Committed by
GitHub
Aug 17, 2021
Browse files
change signature of quantize_input (#4039)
parent
86335921
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
23 deletions
+13
-23
nni/algorithms/compression/pytorch/quantization/quantizers.py
...algorithms/compression/pytorch/quantization/quantizers.py
+7
-19
nni/compression/pytorch/compressor.py
nni/compression/pytorch/compressor.py
+6
-4
No files found.
nni/algorithms/compression/pytorch/quantization/quantizers.py
View file @
39e3a990
...
...
@@ -187,12 +187,7 @@ class ObserverQuantizer(Quantizer):
def record(self, wrapper, quant_type, tensor):
    """Feed a tensor to the observer registered for this layer.

    Used during calibration: the observer accumulates statistics (e.g.
    min/max) that ``calculate_qparams`` later turns into scale/zero-point.

    Parameters
    ----------
    wrapper : QuantizerModuleWrapper
        Wrapper of the layer; ``wrapper.name`` keys into ``self.all_observers``.
    quant_type : str
        Which observer to update — e.g. ``'input'`` (see callers); presumably
        also ``'weight'``/``'output'``.
    tensor : torch.Tensor
        Tensor whose statistics should be recorded.
    """
    name = wrapper.name
    observer = self.all_observers[name][quant_type]
    # NB: observers run on CPU, so detach the tensor from its device first.
    observer(tensor.cpu())
def
calculate_qparams
(
self
,
name
,
quant_type
):
observer
=
self
.
all_observers
[
name
][
quant_type
]
...
...
@@ -206,17 +201,14 @@ class ObserverQuantizer(Quantizer):
x
=
(
x
-
zero_point
)
*
scale
return
x
def quantize_input(self, inputs, wrapper, **kwargs):
    """Quantize (or just observe) a layer's input tensor.

    Before compression is applied (``self.compressed`` is falsy) the input
    is only recorded by its observer so quantization parameters can be
    calibrated later. After compression, the input is fake-quantized with
    the layer's calibrated parameters.

    Parameters
    ----------
    inputs : torch.Tensor
        The single input tensor of the wrapped module.
    wrapper : QuantizerModuleWrapper
        Wrapper holding the module and its ``input_scale`` /
        ``input_zero_point`` / ``input_qmin`` / ``input_qmax`` buffers.

    Returns
    -------
    torch.Tensor
        The (possibly quantized) input tensor.
    """
    if self.compressed:
        module = wrapper.module
        inputs = self._quantize(inputs,
                                module.input_scale,
                                module.input_zero_point,
                                module.input_qmin,
                                module.input_qmax)
    else:
        # Calibration phase: only collect statistics, pass input through.
        self.record(wrapper, 'input', inputs)
    return inputs
...
...
@@ -973,20 +965,16 @@ class LsqQuantizer(Quantizer):
output
=
self
.
quantize
(
output
,
module
.
output_scale
,
module
.
output_qmin
,
module
.
output_qmax
)
return
output
def quantize_input(self, inputs, wrapper, **kwargs):
    """Fake-quantize a layer's input with a learned step size (LSQ).

    On the very first optimization step the input scale is initialized from
    the data statistics, ``2 * mean(|x|) / sqrt(qmax)``, as prescribed by
    the LSQ paper; afterwards the scale is learned by gradient descent.

    Parameters
    ----------
    inputs : torch.Tensor
        Input tensor of the wrapped module.
    wrapper : QuantizerModuleWrapper
        Wrapper holding the module and its ``input_scale`` / ``input_qmin``
        / ``input_qmax`` quantization parameters.

    Returns
    -------
    torch.Tensor
        The fake-quantized input tensor.
    """
    module = wrapper.module
    # Initialize the learnable scale from the first batch's statistics.
    if self.bound_model.steps == 1:
        qmax = module.input_qmax
        init_oup_scale = inputs.data.detach().abs().mean() * 2 / (qmax ** 0.5)
        module.input_scale.data = init_oup_scale

    inputs = self.quantize(inputs, module.input_scale, module.input_qmin, module.input_qmax)
    return inputs
def
export_model
(
self
,
model_path
,
calibration_path
=
None
,
onnx_path
=
None
,
input_shape
=
None
,
device
=
None
):
"""
...
...
nni/compression/pytorch/compressor.py
View file @
39e3a990
...
...
@@ -544,10 +544,12 @@ class QuantizerModuleWrapper(torch.nn.Module):
def
forward
(
self
,
*
inputs
):
if
'input'
in
self
.
config
[
'quant_types'
]:
inputs
=
self
.
quantizer
.
quant_grad
(
inputs
,
assert
len
(
inputs
)
==
1
,
"Quantization of input only supports ops with single input."
new_inp
=
self
.
quantizer
.
quant_grad
(
inputs
[
0
],
QuantType
.
QUANT_INPUT
,
self
)
inputs
=
(
new_inp
,)
if
'weight'
in
self
.
config
[
'quant_types'
]
and
_check_weight
(
self
.
module
):
if
self
.
bn_module
is
not
None
:
...
...
@@ -640,7 +642,7 @@ class Quantizer(Compressor):
"""
raise
NotImplementedError
(
'Quantizer must overload quantize_output()'
)
def
quantize_input
(
self
,
*
inputs
,
wrapper
,
**
kwargs
):
def
quantize_input
(
self
,
inputs
,
wrapper
,
**
kwargs
):
"""
quantize should overload this method to quantize input.
This method is effectively hooked to :meth:`forward` of the model.
...
...
@@ -912,7 +914,7 @@ def _check_bias(module):
def
quantize_helper
(
tensor
,
quant_type
,
wrapper
,
input_tensor
=
None
,
**
kwargs
):
if
quant_type
==
QuantType
.
QUANT_INPUT
:
output
=
wrapper
.
quantizer
.
quantize_input
(
*
tensor
,
wrapper
=
wrapper
,
**
kwargs
)
output
=
wrapper
.
quantizer
.
quantize_input
(
tensor
,
wrapper
=
wrapper
,
**
kwargs
)
elif
quant_type
==
QuantType
.
QUANT_WEIGHT
:
output
=
wrapper
.
quantizer
.
quantize_weight
(
wrapper
,
input_tensor
=
input_tensor
,
**
kwargs
)
elif
quant_type
==
QuantType
.
QUANT_OUTPUT
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment