OpenDAS / nni · Commits

Commit f13a9cd4 (unverified)
[Quantization] fix QAT export param (#4252)
Authored Oct 20, 2021 by lin bin; committed by GitHub on Oct 20, 2021.
Parent: cdb65dac

Showing 3 changed files, with 41 additions and 36 deletions (+41, -36):
nni/algorithms/compression/pytorch/quantization/quantizers.py   +21 -16
nni/compression/pytorch/quantization/utils.py                   +1  -1
test/ut/sdk/test_compressor_torch.py                            +19 -19
nni/algorithms/compression/pytorch/quantization/quantizers.py

```diff
@@ -26,7 +26,6 @@ __all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer',
 logger = logging.getLogger(__name__)


 class NaiveQuantizer(Quantizer):
     """quantize weight to 8 bits
     """
```
```diff
@@ -676,17 +675,20 @@ class QAT_Quantizer(Quantizer):
         for layer, _ in modules_to_compress:
             name, module = layer.name, layer.module
             if name not in calibration_config:
-                if hasattr(module, 'weight_bits') or hasattr(module, 'output_bits') or hasattr(module, 'input_bits'):
+                if module.layer_quant_setting.weight or module.layer_quant_setting.input or module.layer_quant_setting.output:
                     logger.warning(f"Can not find module {name}'s parameter in input config.")
                 continue

-            if hasattr(module, 'weight_bits'):
-                assert calibration_config[name]['weight_bits'] == module.weight_bits, f"weight bits of module {name} fail to match"
-            if hasattr(module, 'input_bits'):
-                assert calibration_config[name]['input_bits'] == module.input_bits, f"input bits of module {name} fail to match"
+            if module.layer_quant_setting.weight:
+                assert calibration_config[name]['weight_bits'] == module.layer_quant_setting.weight.bits, \
+                    f"weight bits of module {name} fail to match"
+            if module.layer_quant_setting.input:
+                assert calibration_config[name]['input_bits'] == module.layer_quant_setting.input.bits, \
+                    f"input bits of module {name} fail to match"
                 module.tracked_min_input.data = torch.tensor([calibration_config[name]['tracked_min_input']])
                 module.tracked_max_input.data = torch.tensor([calibration_config[name]['tracked_max_input']])
-            if hasattr(module, 'output_bits'):
-                assert calibration_config[name]['output_bits'] == module.output_bits, f"output bits of module {name} fail to match"
+            if module.layer_quant_setting.output:
+                assert calibration_config[name]['output_bits'] == module.layer_quant_setting.output.bits, \
+                    f"output bits of module {name} fail to match"
                 module.tracked_min_output.data = torch.tensor([calibration_config[name]['tracked_min_output']])
                 module.tracked_max_output.data = torch.tensor([calibration_config[name]['tracked_max_output']])
```
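The hunk above swaps `hasattr` checks on simulated attributes for truthiness checks on the per-layer `layer_quant_setting` object. A minimal sketch of why that guard works (the `SimpleNamespace` stand-ins below are illustrative, not NNI's real classes): a quantization type that is not configured is represented by a falsy value, so its branch is simply skipped.

```python
# Sketch only: an unconfigured quant type is falsy (None), so its branch
# is skipped instead of relying on hasattr() of simulated attributes.
from types import SimpleNamespace

# Hypothetical per-layer setting: weight quantized to 8 bits, input/output off.
layer_quant_setting = SimpleNamespace(
    weight=SimpleNamespace(bits=8),  # configured -> truthy
    input=None,                      # not configured -> falsy
    output=None,
)

calibration_config = {'conv1': {'weight_bits': 8}}
name = 'conv1'

if layer_quant_setting.weight:
    assert calibration_config[name]['weight_bits'] == layer_quant_setting.weight.bits, \
        f"weight bits of module {name} fail to match"
if layer_quant_setting.input:  # falsy: input quantization is off, branch skipped
    raise AssertionError("unreachable in this configuration")
print("weight bits match")
```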
```diff
@@ -716,11 +718,13 @@ class QAT_Quantizer(Quantizer):
         self._unwrap_model()
         calibration_config = {}

-        for name, module in self.bound_model.named_modules():
-            if hasattr(module, 'weight_bits') or hasattr(module, 'output_bits'):
+        modules_to_compress = self.get_modules_to_compress()
+        for layer, _ in modules_to_compress:
+            name, module = layer.name, layer.module
+            if hasattr(module.layer_quant_setting, 'weight') or hasattr(module.layer_quant_setting, 'output'):
                 calibration_config[name] = {}
-            if hasattr(module, 'weight_bits'):
-                calibration_config[name]['weight_bits'] = int(module.weight_bits)
+            if module.layer_quant_setting.weight:
+                calibration_config[name]['weight_bits'] = int(module.layer_quant_setting.weight.bits)
                 calibration_config[name]['weight_scale'] = module.weight_scale
                 calibration_config[name]['weight_zero_point'] = module.weight_zero_point
```
```diff
@@ -738,13 +742,14 @@ class QAT_Quantizer(Quantizer):
                     module.register_parameter('bias', actual_bias)
                 else:
                     setattr(module, 'bias', None)
-            if hasattr(module, 'input_bits'):
-                calibration_config[name]['input_bits'] = int(module.input_bits)
+            if module.layer_quant_setting.input:
+                calibration_config[name]['input_bits'] = int(module.layer_quant_setting.input.bits)
                 calibration_config[name]['tracked_min_input'] = float(module.tracked_min_input)
                 calibration_config[name]['tracked_max_input'] = float(module.tracked_max_input)
-            if hasattr(module, 'output_bits'):
-                calibration_config[name]['output_bits'] = int(module.output_bits)
+            if module.layer_quant_setting.output:
+                calibration_config[name]['output_bits'] = int(module.layer_quant_setting.output.bits)
                 calibration_config[name]['tracked_min_output'] = float(module.tracked_min_output)
                 calibration_config[name]['tracked_max_output'] = float(module.tracked_max_output)
             self._del_simulated_attr(module)
```
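Both export hunks follow the same pattern: iterate only the layers the quantizer actually wraps, and read bit-widths from `layer_quant_setting` instead of loose module attributes. The sketch below contrasts the two iteration strategies under that assumed semantics; the toy model is ours, not from the commit.

```python
# Sketch: why iterating get_modules_to_compress() beats named_modules().
# named_modules() walks every submodule (including containers), so the old
# code needed hasattr() guards to skip unwrapped layers; the compressor's
# own module list contains exactly the wrapped layers.
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Linear(8, 2))

# The old-style walk visits the container itself plus all three children:
print([name for name, _ in model.named_modules()])  # ['', '0', '1', '2']
```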
nni/compression/pytorch/quantization/utils.py

```diff
@@ -79,5 +79,5 @@ def get_quant_shape(shape, quant_type, quant_scheme):
     if is_per_channel(quant_scheme):
         quant_shape = [1 if idx != default_idx else s for idx, s in enumerate(shape)]
     else:
-        quant_shape = []
+        quant_shape = [1]
     return quant_shape
```
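This one-line fix is the root of the test churn below: for per-tensor schemes, `get_quant_shape` now returns `[1]` instead of `[]`, so scales and zero points become one-element tensors rather than 0-dim scalars. A quick demonstration of the practical difference (plain PyTorch, nothing NNI-specific):

```python
# Why [1] instead of []: torch.zeros([]) is a 0-dim scalar tensor, while
# torch.zeros([1]) has shape (1,). They broadcast identically against any
# tensor, but torch.equal() distinguishes them -- which is exactly why the
# test expectations change from torch.tensor(x) to torch.tensor([x]).
import torch

scalar = torch.zeros([])      # shape: torch.Size([])
one_elem = torch.zeros([1])   # shape: torch.Size([1])

x = torch.randn(2, 3)
assert torch.equal(x * scalar, x * one_elem)  # broadcasting result is identical
assert not torch.equal(scalar, one_elem)      # but shapes differ
print(scalar.shape, one_elem.shape)           # torch.Size([]) torch.Size([1])
```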
test/ut/sdk/test_compressor_torch.py

```diff
@@ -9,7 +9,7 @@ import torch.nn.functional as F
 import schema
 import nni.algorithms.compression.pytorch.pruning as torch_pruner
 import nni.algorithms.compression.pytorch.quantization as torch_quantizer
-from nni.compression.pytorch.quantization.utils import calculate_qmin_qmax, get_quant_shape, get_min_max_value
+from nni.compression.pytorch.quantization.utils import calculate_qmin_qmax, get_quant_shape
 import math
```
```diff
@@ -398,11 +398,11 @@ class CompressorTestCase(TestCase):
             target_zero_point = torch.ones([2, 1, 1, 1]) * 127
         elif qscheme == 'per_tensor_symmetric':
             if dtype == 'int':
-                target_scale = torch.tensor(18. / 127)
-                target_zero_point = torch.zeros([])
+                target_scale = torch.tensor([18. / 127])
+                target_zero_point = torch.zeros([1])
             else:
-                target_scale = torch.tensor(18. / 127.5)
-                target_zero_point = torch.ones([]) * 127
+                target_scale = torch.tensor([18. / 127.5])
+                target_zero_point = torch.ones([1]) * 127
         elif qscheme == 'per_channel_affine':
             min_val = torch.tensor([0., 0.]).view([2, 1, 1, 1])
             if dtype == 'int':
```
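The expected values themselves are unchanged; only their shape gains a `[...]`. For the record, a hedged reconstruction of where `18. / 127` comes from, inferred from the test constants rather than quoted from NNI internals: per-tensor symmetric int8 divides the tracked max-absolute value by the quantized range bound.

```python
# Reconstruction (assumption, not a quote of NNI's code): for
# per_tensor_symmetric with dtype 'int', scale = amax / 127 and
# zero_point = 0; the other branch uses 127.5 with zero_point = 127.
import torch

amax = 18.0                                # assumed tracked max |weight|
target_scale = torch.tensor([amax / 127])  # shape [1] after this commit
target_zero_point = torch.zeros([1])
print(target_scale, target_zero_point)     # tensor([0.1417]) tensor([0.])
```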
```diff
@@ -413,10 +413,10 @@ class CompressorTestCase(TestCase):
                 target_zero_point = 0 - torch.round(min_val / target_scale)
         else:
             if dtype == 'int':
-                target_scale = torch.tensor(18. / 254)
+                target_scale = torch.tensor([18. / 254])
                 target_zero_point = -127 - torch.round(0 / target_scale)
             else:
-                target_scale = torch.tensor(18. / 255)
+                target_scale = torch.tensor([18. / 255])
                 target_zero_point = 0 - torch.round(0 / target_scale)
         wrapper = getattr(model, name)
         wrapper.module.weight = weight
```
```diff
@@ -434,11 +434,11 @@ class CompressorTestCase(TestCase):
             target_zero_point = torch.ones([1, 1, 1, 1]) * 127
         elif qscheme == 'per_tensor_symmetric':
             if dtype == 'int':
-                target_scale = torch.tensor(15. / 127)
-                target_zero_point = torch.zeros([])
+                target_scale = torch.tensor([15. / 127])
+                target_zero_point = torch.zeros([1])
             else:
-                target_scale = torch.tensor(15. / 127.5)
-                target_zero_point = torch.ones([]) * 127
+                target_scale = torch.tensor([15. / 127.5])
+                target_zero_point = torch.ones([1]) * 127
         elif qscheme == 'per_channel_affine':
             min_val = torch.tensor([0.]).view([1, 1, 1, 1])
             if dtype == 'int':
```
```diff
@@ -449,10 +449,10 @@ class CompressorTestCase(TestCase):
                 target_zero_point = 0 - torch.round(min_val / target_scale)
         else:
             if dtype == 'int':
-                target_scale = torch.tensor(15. / 254)
+                target_scale = torch.tensor([15. / 254])
                 target_zero_point = -127 - torch.round(0 / target_scale)
             else:
-                target_scale = torch.tensor(15. / 255)
+                target_scale = torch.tensor([15. / 255])
                 target_zero_point = 0 - torch.round(0 / target_scale)
         quantizer.quantize_input(inp, wrapper)
         self.assertTrue(torch.equal(getattr(model, name).module.input_scale, target_scale))
```
```diff
@@ -488,7 +488,7 @@ class CompressorTestCase(TestCase):
         assert model.conv2.module.weight_zero_point == 0
         quantizer.quantize_input(input, model.conv2)
         self.assertTrue(torch.allclose(model.conv2.module.input_scale, torch.tensor([4. / 255])))
-        self.assertTrue(torch.equal(model.conv2.module.input_zero_point, torch.tensor(0.)))
+        self.assertTrue(torch.equal(model.conv2.module.input_zero_point, torch.tensor([0.])))
         # range including 0
         weight = torch.tensor([[-1, 2], [3, 5]]).float()
         model.conv2.module.weight = weight
```
```diff
@@ -497,7 +497,7 @@ class CompressorTestCase(TestCase):
         assert model.conv2.module.weight_zero_point in (42, 43)
         quantizer.quantize_input(input, model.conv2)
         self.assertTrue(torch.allclose(model.conv2.module.input_scale, torch.tensor([4. / 255])))
-        self.assertTrue(torch.equal(model.conv2.module.input_zero_point, torch.tensor(0.)))
+        self.assertTrue(torch.equal(model.conv2.module.input_zero_point, torch.tensor([0.])))
         # test value of weight and bias after quantization
         weight = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]])
         weight_valid = torch.tensor([[1.1242, 2.3421], [3.7707, 5.9723]])
```
```diff
@@ -513,14 +513,14 @@ class CompressorTestCase(TestCase):
         eps = 1e-7
         x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
         model.relu(x)
-        self.assertTrue(torch.equal(model.relu.module.tracked_min_output, torch.tensor(0.)))
-        self.assertTrue(torch.equal(model.relu.module.tracked_max_output, torch.tensor(0.2)))
+        self.assertTrue(torch.equal(model.relu.module.tracked_min_output, torch.tensor([0.])))
+        self.assertTrue(torch.equal(model.relu.module.tracked_max_output, torch.tensor([0.2])))
         quantizer.step_with_optimizer()
         x = torch.tensor([[0.2, 0.4], [0.6, 0.8]])
         model.relu(x)
-        self.assertTrue(torch.equal(model.relu.module.tracked_min_output, torch.tensor(0.002)))
-        self.assertTrue(torch.equal(model.relu.module.tracked_max_output, torch.tensor(0.2060)))
+        self.assertTrue(torch.equal(model.relu.module.tracked_min_output, torch.tensor([0.002])))
+        self.assertTrue(torch.equal(model.relu.module.tracked_max_output, torch.tensor([0.2060])))

     def test_torch_quantizer_export(self):
         config_list_qat = [{
```