Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
f13a9cd4
Unverified
Commit
f13a9cd4
authored
Oct 20, 2021
by
lin bin
Committed by
GitHub
Oct 20, 2021
Browse files
[Quantization] fix QAT export param (#4252)
parent
cdb65dac
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
41 additions
and
36 deletions
+41
-36
nni/algorithms/compression/pytorch/quantization/quantizers.py
...algorithms/compression/pytorch/quantization/quantizers.py
+21
-16
nni/compression/pytorch/quantization/utils.py
nni/compression/pytorch/quantization/utils.py
+1
-1
test/ut/sdk/test_compressor_torch.py
test/ut/sdk/test_compressor_torch.py
+19
-19
No files found.
nni/algorithms/compression/pytorch/quantization/quantizers.py
View file @
f13a9cd4
...
@@ -26,7 +26,6 @@ __all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer',
...
@@ -26,7 +26,6 @@ __all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer',
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
class
NaiveQuantizer
(
Quantizer
):
class
NaiveQuantizer
(
Quantizer
):
"""quantize weight to 8 bits
"""quantize weight to 8 bits
"""
"""
...
@@ -676,17 +675,20 @@ class QAT_Quantizer(Quantizer):
...
@@ -676,17 +675,20 @@ class QAT_Quantizer(Quantizer):
for
layer
,
_
in
modules_to_compress
:
for
layer
,
_
in
modules_to_compress
:
name
,
module
=
layer
.
name
,
layer
.
module
name
,
module
=
layer
.
name
,
layer
.
module
if
name
not
in
calibration_config
:
if
name
not
in
calibration_config
:
if
hasattr
(
module
,
'weight_bits'
)
or
hasattr
(
module
,
'output_bits'
)
or
hasattr
(
module
,
'input_bits'
)
:
if
module
.
layer_quant_setting
.
weight
or
module
.
layer_quant_setting
.
input
or
module
.
layer_quant_setting
.
output
:
logger
.
warning
(
f
"Can not find module
{
name
}
's parameter in input config."
)
logger
.
warning
(
f
"Can not find module
{
name
}
's parameter in input config."
)
continue
continue
if
hasattr
(
module
,
'weight_bits'
):
if
module
.
layer_quant_setting
.
weight
:
assert
calibration_config
[
name
][
'weight_bits'
]
==
module
.
weight_bits
,
f
"weight bits of module
{
name
}
fail to match"
assert
calibration_config
[
name
][
'weight_bits'
]
==
module
.
layer_quant_setting
.
weight
.
bits
,
\
if
hasattr
(
module
,
'input_bits'
):
f
"weight bits of module
{
name
}
fail to match"
assert
calibration_config
[
name
][
'input_bits'
]
==
module
.
input_bits
,
f
"input bits of module
{
name
}
fail to match"
if
module
.
layer_quant_setting
.
input
:
assert
calibration_config
[
name
][
'input_bits'
]
==
module
.
layer_quant_setting
.
input
.
bits
,
\
f
"input bits of module
{
name
}
fail to match"
module
.
tracked_min_input
.
data
=
torch
.
tensor
([
calibration_config
[
name
][
'tracked_min_input'
]])
module
.
tracked_min_input
.
data
=
torch
.
tensor
([
calibration_config
[
name
][
'tracked_min_input'
]])
module
.
tracked_max_input
.
data
=
torch
.
tensor
([
calibration_config
[
name
][
'tracked_max_input'
]])
module
.
tracked_max_input
.
data
=
torch
.
tensor
([
calibration_config
[
name
][
'tracked_max_input'
]])
if
hasattr
(
module
,
'output_bits'
):
if
module
.
layer_quant_setting
.
output
:
assert
calibration_config
[
name
][
'output_bits'
]
==
module
.
output_bits
,
f
"output bits of module
{
name
}
fail to match"
assert
calibration_config
[
name
][
'output_bits'
]
==
module
.
layer_quant_setting
.
output
.
bits
,
\
f
"output bits of module
{
name
}
fail to match"
module
.
tracked_min_output
.
data
=
torch
.
tensor
([
calibration_config
[
name
][
'tracked_min_output'
]])
module
.
tracked_min_output
.
data
=
torch
.
tensor
([
calibration_config
[
name
][
'tracked_min_output'
]])
module
.
tracked_max_output
.
data
=
torch
.
tensor
([
calibration_config
[
name
][
'tracked_max_output'
]])
module
.
tracked_max_output
.
data
=
torch
.
tensor
([
calibration_config
[
name
][
'tracked_max_output'
]])
...
@@ -716,11 +718,13 @@ class QAT_Quantizer(Quantizer):
...
@@ -716,11 +718,13 @@ class QAT_Quantizer(Quantizer):
self
.
_unwrap_model
()
self
.
_unwrap_model
()
calibration_config
=
{}
calibration_config
=
{}
for
name
,
module
in
self
.
bound_model
.
named_modules
():
modules_to_compress
=
self
.
get_modules_to_compress
()
if
hasattr
(
module
,
'weight_bits'
)
or
hasattr
(
module
,
'output_bits'
):
for
layer
,
_
in
modules_to_compress
:
name
,
module
=
layer
.
name
,
layer
.
module
if
hasattr
(
module
.
layer_quant_setting
,
'weight'
)
or
hasattr
(
module
.
layer_quant_setting
,
'output'
):
calibration_config
[
name
]
=
{}
calibration_config
[
name
]
=
{}
if
hasattr
(
module
,
'weight_bits'
)
:
if
module
.
layer_quant_setting
.
weight
:
calibration_config
[
name
][
'weight_bits'
]
=
int
(
module
.
weight
_
bits
)
calibration_config
[
name
][
'weight_bits'
]
=
int
(
module
.
layer_quant_setting
.
weight
.
bits
)
calibration_config
[
name
][
'weight_scale'
]
=
module
.
weight_scale
calibration_config
[
name
][
'weight_scale'
]
=
module
.
weight_scale
calibration_config
[
name
][
'weight_zero_point'
]
=
module
.
weight_zero_point
calibration_config
[
name
][
'weight_zero_point'
]
=
module
.
weight_zero_point
...
@@ -738,13 +742,14 @@ class QAT_Quantizer(Quantizer):
...
@@ -738,13 +742,14 @@ class QAT_Quantizer(Quantizer):
module
.
register_parameter
(
'bias'
,
actual_bias
)
module
.
register_parameter
(
'bias'
,
actual_bias
)
else
:
else
:
setattr
(
module
,
'bias'
,
None
)
setattr
(
module
,
'bias'
,
None
)
if
hasattr
(
module
,
'input_bits'
):
calibration_config
[
name
][
'input_bits'
]
=
int
(
module
.
input_bits
)
if
module
.
layer_quant_setting
.
input
:
calibration_config
[
name
][
'input_bits'
]
=
int
(
module
.
layer_quant_setting
.
input
.
bits
)
calibration_config
[
name
][
'tracked_min_input'
]
=
float
(
module
.
tracked_min_input
)
calibration_config
[
name
][
'tracked_min_input'
]
=
float
(
module
.
tracked_min_input
)
calibration_config
[
name
][
'tracked_max_input'
]
=
float
(
module
.
tracked_max_input
)
calibration_config
[
name
][
'tracked_max_input'
]
=
float
(
module
.
tracked_max_input
)
if
hasattr
(
module
,
'output_bits'
)
:
if
module
.
layer_quant_setting
.
output
:
calibration_config
[
name
][
'output_bits'
]
=
int
(
module
.
output
_
bits
)
calibration_config
[
name
][
'output_bits'
]
=
int
(
module
.
layer_quant_setting
.
output
.
bits
)
calibration_config
[
name
][
'tracked_min_output'
]
=
float
(
module
.
tracked_min_output
)
calibration_config
[
name
][
'tracked_min_output'
]
=
float
(
module
.
tracked_min_output
)
calibration_config
[
name
][
'tracked_max_output'
]
=
float
(
module
.
tracked_max_output
)
calibration_config
[
name
][
'tracked_max_output'
]
=
float
(
module
.
tracked_max_output
)
self
.
_del_simulated_attr
(
module
)
self
.
_del_simulated_attr
(
module
)
...
...
nni/compression/pytorch/quantization/utils.py
View file @
f13a9cd4
...
@@ -79,5 +79,5 @@ def get_quant_shape(shape, quant_type, quant_scheme):
...
@@ -79,5 +79,5 @@ def get_quant_shape(shape, quant_type, quant_scheme):
if
is_per_channel
(
quant_scheme
):
if
is_per_channel
(
quant_scheme
):
quant_shape
=
[
1
if
idx
!=
default_idx
else
s
for
idx
,
s
in
enumerate
(
shape
)]
quant_shape
=
[
1
if
idx
!=
default_idx
else
s
for
idx
,
s
in
enumerate
(
shape
)]
else
:
else
:
quant_shape
=
[]
quant_shape
=
[
1
]
return
quant_shape
return
quant_shape
test/ut/sdk/test_compressor_torch.py
View file @
f13a9cd4
...
@@ -9,7 +9,7 @@ import torch.nn.functional as F
...
@@ -9,7 +9,7 @@ import torch.nn.functional as F
import
schema
import
schema
import
nni.algorithms.compression.pytorch.pruning
as
torch_pruner
import
nni.algorithms.compression.pytorch.pruning
as
torch_pruner
import
nni.algorithms.compression.pytorch.quantization
as
torch_quantizer
import
nni.algorithms.compression.pytorch.quantization
as
torch_quantizer
from
nni.compression.pytorch.quantization.utils
import
calculate_qmin_qmax
,
get_quant_shape
,
get_min_max_value
from
nni.compression.pytorch.quantization.utils
import
calculate_qmin_qmax
,
get_quant_shape
import
math
import
math
...
@@ -398,11 +398,11 @@ class CompressorTestCase(TestCase):
...
@@ -398,11 +398,11 @@ class CompressorTestCase(TestCase):
target_zero_point
=
torch
.
ones
([
2
,
1
,
1
,
1
])
*
127
target_zero_point
=
torch
.
ones
([
2
,
1
,
1
,
1
])
*
127
elif
qscheme
==
'per_tensor_symmetric'
:
elif
qscheme
==
'per_tensor_symmetric'
:
if
dtype
==
'int'
:
if
dtype
==
'int'
:
target_scale
=
torch
.
tensor
(
18.
/
127
)
target_scale
=
torch
.
tensor
(
[
18.
/
127
]
)
target_zero_point
=
torch
.
zeros
([])
target_zero_point
=
torch
.
zeros
([
1
])
else
:
else
:
target_scale
=
torch
.
tensor
(
18.
/
127.5
)
target_scale
=
torch
.
tensor
(
[
18.
/
127.5
]
)
target_zero_point
=
torch
.
ones
([])
*
127
target_zero_point
=
torch
.
ones
([
1
])
*
127
elif
qscheme
==
'per_channel_affine'
:
elif
qscheme
==
'per_channel_affine'
:
min_val
=
torch
.
tensor
([
0.
,
0.
]).
view
([
2
,
1
,
1
,
1
])
min_val
=
torch
.
tensor
([
0.
,
0.
]).
view
([
2
,
1
,
1
,
1
])
if
dtype
==
'int'
:
if
dtype
==
'int'
:
...
@@ -413,10 +413,10 @@ class CompressorTestCase(TestCase):
...
@@ -413,10 +413,10 @@ class CompressorTestCase(TestCase):
target_zero_point
=
0
-
torch
.
round
(
min_val
/
target_scale
)
target_zero_point
=
0
-
torch
.
round
(
min_val
/
target_scale
)
else
:
else
:
if
dtype
==
'int'
:
if
dtype
==
'int'
:
target_scale
=
torch
.
tensor
(
18.
/
254
)
target_scale
=
torch
.
tensor
(
[
18.
/
254
]
)
target_zero_point
=
-
127
-
torch
.
round
(
0
/
target_scale
)
target_zero_point
=
-
127
-
torch
.
round
(
0
/
target_scale
)
else
:
else
:
target_scale
=
torch
.
tensor
(
18.
/
255
)
target_scale
=
torch
.
tensor
(
[
18.
/
255
]
)
target_zero_point
=
0
-
torch
.
round
(
0
/
target_scale
)
target_zero_point
=
0
-
torch
.
round
(
0
/
target_scale
)
wrapper
=
getattr
(
model
,
name
)
wrapper
=
getattr
(
model
,
name
)
wrapper
.
module
.
weight
=
weight
wrapper
.
module
.
weight
=
weight
...
@@ -434,11 +434,11 @@ class CompressorTestCase(TestCase):
...
@@ -434,11 +434,11 @@ class CompressorTestCase(TestCase):
target_zero_point
=
torch
.
ones
([
1
,
1
,
1
,
1
])
*
127
target_zero_point
=
torch
.
ones
([
1
,
1
,
1
,
1
])
*
127
elif
qscheme
==
'per_tensor_symmetric'
:
elif
qscheme
==
'per_tensor_symmetric'
:
if
dtype
==
'int'
:
if
dtype
==
'int'
:
target_scale
=
torch
.
tensor
(
15.
/
127
)
target_scale
=
torch
.
tensor
(
[
15.
/
127
]
)
target_zero_point
=
torch
.
zeros
([])
target_zero_point
=
torch
.
zeros
([
1
])
else
:
else
:
target_scale
=
torch
.
tensor
(
15.
/
127.5
)
target_scale
=
torch
.
tensor
(
[
15.
/
127.5
]
)
target_zero_point
=
torch
.
ones
([])
*
127
target_zero_point
=
torch
.
ones
([
1
])
*
127
elif
qscheme
==
'per_channel_affine'
:
elif
qscheme
==
'per_channel_affine'
:
min_val
=
torch
.
tensor
([
0.
]).
view
([
1
,
1
,
1
,
1
])
min_val
=
torch
.
tensor
([
0.
]).
view
([
1
,
1
,
1
,
1
])
if
dtype
==
'int'
:
if
dtype
==
'int'
:
...
@@ -449,10 +449,10 @@ class CompressorTestCase(TestCase):
...
@@ -449,10 +449,10 @@ class CompressorTestCase(TestCase):
target_zero_point
=
0
-
torch
.
round
(
min_val
/
target_scale
)
target_zero_point
=
0
-
torch
.
round
(
min_val
/
target_scale
)
else
:
else
:
if
dtype
==
'int'
:
if
dtype
==
'int'
:
target_scale
=
torch
.
tensor
(
15.
/
254
)
target_scale
=
torch
.
tensor
(
[
15.
/
254
]
)
target_zero_point
=
-
127
-
torch
.
round
(
0
/
target_scale
)
target_zero_point
=
-
127
-
torch
.
round
(
0
/
target_scale
)
else
:
else
:
target_scale
=
torch
.
tensor
(
15.
/
255
)
target_scale
=
torch
.
tensor
(
[
15.
/
255
]
)
target_zero_point
=
0
-
torch
.
round
(
0
/
target_scale
)
target_zero_point
=
0
-
torch
.
round
(
0
/
target_scale
)
quantizer
.
quantize_input
(
inp
,
wrapper
)
quantizer
.
quantize_input
(
inp
,
wrapper
)
self
.
assertTrue
(
torch
.
equal
(
getattr
(
model
,
name
).
module
.
input_scale
,
target_scale
))
self
.
assertTrue
(
torch
.
equal
(
getattr
(
model
,
name
).
module
.
input_scale
,
target_scale
))
...
@@ -488,7 +488,7 @@ class CompressorTestCase(TestCase):
...
@@ -488,7 +488,7 @@ class CompressorTestCase(TestCase):
assert
model
.
conv2
.
module
.
weight_zero_point
==
0
assert
model
.
conv2
.
module
.
weight_zero_point
==
0
quantizer
.
quantize_input
(
input
,
model
.
conv2
)
quantizer
.
quantize_input
(
input
,
model
.
conv2
)
self
.
assertTrue
(
torch
.
allclose
(
model
.
conv2
.
module
.
input_scale
,
torch
.
tensor
([
4.
/
255
])))
self
.
assertTrue
(
torch
.
allclose
(
model
.
conv2
.
module
.
input_scale
,
torch
.
tensor
([
4.
/
255
])))
self
.
assertTrue
(
torch
.
equal
(
model
.
conv2
.
module
.
input_zero_point
,
torch
.
tensor
(
0.
)))
self
.
assertTrue
(
torch
.
equal
(
model
.
conv2
.
module
.
input_zero_point
,
torch
.
tensor
(
[
0.
]
)))
# range including 0
# range including 0
weight
=
torch
.
tensor
([[
-
1
,
2
],
[
3
,
5
]]).
float
()
weight
=
torch
.
tensor
([[
-
1
,
2
],
[
3
,
5
]]).
float
()
model
.
conv2
.
module
.
weight
=
weight
model
.
conv2
.
module
.
weight
=
weight
...
@@ -497,7 +497,7 @@ class CompressorTestCase(TestCase):
...
@@ -497,7 +497,7 @@ class CompressorTestCase(TestCase):
assert
model
.
conv2
.
module
.
weight_zero_point
in
(
42
,
43
)
assert
model
.
conv2
.
module
.
weight_zero_point
in
(
42
,
43
)
quantizer
.
quantize_input
(
input
,
model
.
conv2
)
quantizer
.
quantize_input
(
input
,
model
.
conv2
)
self
.
assertTrue
(
torch
.
allclose
(
model
.
conv2
.
module
.
input_scale
,
torch
.
tensor
([
4.
/
255
])))
self
.
assertTrue
(
torch
.
allclose
(
model
.
conv2
.
module
.
input_scale
,
torch
.
tensor
([
4.
/
255
])))
self
.
assertTrue
(
torch
.
equal
(
model
.
conv2
.
module
.
input_zero_point
,
torch
.
tensor
(
0.
)))
self
.
assertTrue
(
torch
.
equal
(
model
.
conv2
.
module
.
input_zero_point
,
torch
.
tensor
(
[
0.
]
)))
# test value of weight and bias after quantization
# test value of weight and bias after quantization
weight
=
torch
.
tensor
([[
1.1287
,
2.3456
],
[
3.7814
,
5.9723
]])
weight
=
torch
.
tensor
([[
1.1287
,
2.3456
],
[
3.7814
,
5.9723
]])
weight_valid
=
torch
.
tensor
([[
1.1242
,
2.3421
],
[
3.7707
,
5.9723
]])
weight_valid
=
torch
.
tensor
([[
1.1242
,
2.3421
],
[
3.7707
,
5.9723
]])
...
@@ -513,14 +513,14 @@ class CompressorTestCase(TestCase):
...
@@ -513,14 +513,14 @@ class CompressorTestCase(TestCase):
eps
=
1e-7
eps
=
1e-7
x
=
torch
.
tensor
([[
-
0.2
,
0
],
[
0.1
,
0.2
]])
x
=
torch
.
tensor
([[
-
0.2
,
0
],
[
0.1
,
0.2
]])
model
.
relu
(
x
)
model
.
relu
(
x
)
self
.
assertTrue
(
torch
.
equal
(
model
.
relu
.
module
.
tracked_min_output
,
torch
.
tensor
(
0.
)))
self
.
assertTrue
(
torch
.
equal
(
model
.
relu
.
module
.
tracked_min_output
,
torch
.
tensor
(
[
0.
]
)))
self
.
assertTrue
(
torch
.
equal
(
model
.
relu
.
module
.
tracked_max_output
,
torch
.
tensor
(
0.2
)))
self
.
assertTrue
(
torch
.
equal
(
model
.
relu
.
module
.
tracked_max_output
,
torch
.
tensor
(
[
0.2
]
)))
quantizer
.
step_with_optimizer
()
quantizer
.
step_with_optimizer
()
x
=
torch
.
tensor
([[
0.2
,
0.4
],
[
0.6
,
0.8
]])
x
=
torch
.
tensor
([[
0.2
,
0.4
],
[
0.6
,
0.8
]])
model
.
relu
(
x
)
model
.
relu
(
x
)
self
.
assertTrue
(
torch
.
equal
(
model
.
relu
.
module
.
tracked_min_output
,
torch
.
tensor
(
0.002
)))
self
.
assertTrue
(
torch
.
equal
(
model
.
relu
.
module
.
tracked_min_output
,
torch
.
tensor
(
[
0.002
]
)))
self
.
assertTrue
(
torch
.
equal
(
model
.
relu
.
module
.
tracked_max_output
,
torch
.
tensor
(
0.2060
)))
self
.
assertTrue
(
torch
.
equal
(
model
.
relu
.
module
.
tracked_max_output
,
torch
.
tensor
(
[
0.2060
]
)))
def
test_torch_quantizer_export
(
self
):
def
test_torch_quantizer_export
(
self
):
config_list_qat
=
[{
config_list_qat
=
[{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment