OpenDAS / nni / Commits

Unverified commit 39e3a990
change signature of quantize_input (#4039)

Authored Aug 17, 2021 by chenbohua3; committed by GitHub on Aug 17, 2021
Parent: 86335921
Showing 2 changed files with 13 additions and 23 deletions

nni/algorithms/compression/pytorch/quantization/quantizers.py  +7 -19
nni/compression/pytorch/compressor.py  +6 -4
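The upshot of the signature change, for anyone maintaining a custom quantizer, is that quantize_input now receives a single input tensor rather than the wrapped module's full argument tuple, and returns a single tensor. A minimal sketch of an override under the new convention (MyQuantizer and its hard-coded quantization parameters are hypothetical, for illustration only; the import path is inferred from the compressor.py file shown below):

import torch
from nni.compression.pytorch.compressor import Quantizer

class MyQuantizer(Quantizer):
    """Hypothetical subclass, shown only to illustrate the new hook signature."""

    def quantize_input(self, inputs, wrapper, **kwargs):
        # New convention: `inputs` is one tensor, not the call's argument tuple,
        # so no inputs[0] unpacking or tuple rebuilding is needed anymore.
        scale, zero_point = 0.1, 0      # placeholders; a real quantizer would
        qmin, qmax = 0, 255             # calibrate these per layer
        q = torch.clamp(torch.round(inputs / scale) + zero_point, qmin, qmax)
        return (q - zero_point) * scale  # dequantized (fake-quant) result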
nni/algorithms/compression/pytorch/quantization/quantizers.py

@@ -187,11 +187,6 @@ class ObserverQuantizer(Quantizer):
     def record(self, wrapper, quant_type, tensor):
         name = wrapper.name
         observer = self.all_observers[name][quant_type]
-        if isinstance(tensor, tuple):
-            # NB: This only works for single tensor
-            tensor = (t.cpu() for t in tensor)
-            observer(*tensor)
-        else:
-            observer(tensor.cpu())
+        observer(tensor.cpu())

     def calculate_qparams(self, name, quant_type):
@@ -206,17 +201,14 @@ class ObserverQuantizer(Quantizer):
         x = (x - zero_point) * scale
         return x

-    def quantize_input(self, *inputs, wrapper, **kwargs):
+    def quantize_input(self, inputs, wrapper, **kwargs):
         if self.compressed:
             module = wrapper.module
-            new_input = self._quantize(inputs[0],
+            inputs = self._quantize(inputs,
                                        module.input_scale,
                                        module.input_zero_point,
                                        module.input_qmin,
                                        module.input_qmax)
-            list_inp = list(inputs)
-            list_inp[0] = new_input
-            inputs = tuple(list_inp)
         else:
             self.record(wrapper, 'input', inputs)
         return inputs
@@ -973,20 +965,16 @@ class LsqQuantizer(Quantizer):
             output = self.quantize(output, module.output_scale, module.output_qmin, module.output_qmax)
         return output

-    def quantize_input(self, *inputs, wrapper, **kwargs):
-        # This is hacky since it is not recommended to modify a tuple
-        # NB: support layers with multi inputs
+    def quantize_input(self, inputs, wrapper, **kwargs):
         module = wrapper.module
         # initialize the scale
         if self.bound_model.steps == 1:
             qmax = module.input_qmax
-            init_oup_scale = inputs[0].data.detach().abs().mean() * 2 / (qmax ** 0.5)
+            init_oup_scale = inputs.data.detach().abs().mean() * 2 / (qmax ** 0.5)
             module.input_scale.data = init_oup_scale
-        new_input = self.quantize(inputs[0], module.input_scale, module.input_qmin, module.input_qmax)
-        list_inp = list(inputs)
-        list_inp[0] = new_input
-        return tuple(list_inp)
+        inputs = self.quantize(inputs, module.input_scale, module.input_qmin, module.input_qmax)
+        return inputs

     def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
         """
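The LsqQuantizer hunk above also documents how the per-layer input scale is initialized on the first training step: twice the mean absolute value of the input divided by the square root of qmax, now computed directly on the single tensor the hook receives. A standalone sketch of that initialization followed by a generic round-and-clamp fake-quant step (fake_quantize here is a stand-in of my own; LsqQuantizer.quantize itself is not part of this diff):

import torch

def init_lsq_scale(x: torch.Tensor, qmax: int) -> torch.Tensor:
    # Same formula as in the hunk above: 2 * mean(|x|) / sqrt(qmax).
    return x.detach().abs().mean() * 2 / (qmax ** 0.5)

def fake_quantize(x: torch.Tensor, scale: torch.Tensor, qmin: int, qmax: int) -> torch.Tensor:
    # Generic stand-in for the quantize step: scale, round, clamp, rescale.
    return torch.clamp(torch.round(x / scale), qmin, qmax) * scale

x = torch.randn(4, 8)
qmin, qmax = -128, 127
scale = init_lsq_scale(x, qmax)
print(fake_quantize(x, scale, qmin, qmax))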
nni/compression/pytorch/compressor.py

@@ -544,10 +544,12 @@ class QuantizerModuleWrapper(torch.nn.Module):
     def forward(self, *inputs):
         if 'input' in self.config['quant_types']:
-            inputs = self.quantizer.quant_grad(
-                inputs,
+            assert len(inputs) == 1, "Quantization of input only supports ops with single input."
+            new_inp = self.quantizer.quant_grad(
+                inputs[0],
                 QuantType.QUANT_INPUT,
                 self)
+            inputs = (new_inp,)

         if 'weight' in self.config['quant_types'] and _check_weight(self.module):
             if self.bn_module is not None:
@@ -640,7 +642,7 @@ class Quantizer(Compressor):
         """
         raise NotImplementedError('Quantizer must overload quantize_output()')

-    def quantize_input(self, *inputs, wrapper, **kwargs):
+    def quantize_input(self, inputs, wrapper, **kwargs):
         """
         quantize should overload this method to quantize input.
         This method is effectively hooked to :meth:`forward` of the model.
@@ -912,7 +914,7 @@ def _check_bias(module):
 def quantize_helper(tensor, quant_type, wrapper, input_tensor=None, **kwargs):
     if quant_type == QuantType.QUANT_INPUT:
-        output = wrapper.quantizer.quantize_input(*tensor, wrapper=wrapper, **kwargs)
+        output = wrapper.quantizer.quantize_input(tensor, wrapper=wrapper, **kwargs)
     elif quant_type == QuantType.QUANT_WEIGHT:
         output = wrapper.quantizer.quantize_weight(wrapper, input_tensor=input_tensor, **kwargs)
     elif quant_type == QuantType.QUANT_OUTPUT:
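Taken together, the compressor-side changes make the input-quantization call path explicit: QuantizerModuleWrapper.forward asserts that the wrapped op takes a single positional input, hands inputs[0] to quant_grad, and quantize_helper then forwards that one tensor to quantize_input(tensor, wrapper=wrapper). A simplified, self-contained mock of that dispatch order (toy classes only; it skips the quant_grad indirection and NNI's QuantType enum, using the plain string 'input' instead):

import torch

class ToyQuantizer:
    def quantize_input(self, inputs, wrapper, **kwargs):
        # Under the new signature this receives exactly one tensor.
        return torch.clamp(torch.round(inputs / 0.1), -128, 127) * 0.1

def quantize_helper(tensor, quant_type, wrapper, **kwargs):
    # Mirrors the QUANT_INPUT branch in the diff: the tensor is forwarded
    # as-is instead of being splatted with *tensor.
    if quant_type == 'input':
        return wrapper.quantizer.quantize_input(tensor, wrapper=wrapper, **kwargs)
    raise ValueError(quant_type)

class ToyWrapper(torch.nn.Module):
    def __init__(self, module, quantizer):
        super().__init__()
        self.module, self.quantizer = module, quantizer

    def forward(self, *inputs):
        # As in the diff: input quantization only supports single-input ops.
        assert len(inputs) == 1, "Quantization of input only supports ops with single input."
        new_inp = quantize_helper(inputs[0], 'input', self)
        return self.module(new_inp)

wrapper = ToyWrapper(torch.nn.Linear(8, 4), ToyQuantizer())
print(wrapper(torch.randn(2, 8)).shape)  # torch.Size([2, 4])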