Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
a6069165
Unverified
Commit
a6069165
authored
Aug 24, 2021
by
chenbohua3
Committed by
GitHub
Aug 24, 2021
Browse files
fix wrong quantization target in weight quantization (#4038)
parent
e9c21fd3
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
14 deletions
+12
-14
nni/algorithms/compression/pytorch/quantization/quantizers.py
...algorithms/compression/pytorch/quantization/quantizers.py
+8
-11
nni/compression/pytorch/compressor.py
nni/compression/pytorch/compressor.py
+1
-0
test/ut/sdk/test_compressor_torch.py
test/ut/sdk/test_compressor_torch.py
+3
-3
No files found.
nni/algorithms/compression/pytorch/quantization/quantizers.py
View file @
a6069165
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
# Licensed under the MIT license.
# Licensed under the MIT license.
import
logging
import
logging
import
copy
from
collections
import
defaultdict
from
collections
import
defaultdict
import
torch
import
torch
from
schema
import
Schema
,
And
,
Or
,
Optional
from
schema
import
Schema
,
And
,
Or
,
Optional
...
@@ -36,7 +35,7 @@ class NaiveQuantizer(Quantizer):
...
@@ -36,7 +35,7 @@ class NaiveQuantizer(Quantizer):
schema
.
validate
(
config_list
)
schema
.
validate
(
config_list
)
def
quantize_weight
(
self
,
wrapper
,
**
kwargs
):
def
quantize_weight
(
self
,
wrapper
,
**
kwargs
):
weight
=
copy
.
deepcopy
(
wrapper
.
module
.
old_
weight
.
data
)
weight
=
wrapper
.
module
.
weight
new_scale
=
weight
.
abs
().
max
()
/
127
new_scale
=
weight
.
abs
().
max
()
/
127
scale
=
max
(
self
.
layer_scale
.
get
(
wrapper
.
name
,
0
),
new_scale
)
scale
=
max
(
self
.
layer_scale
.
get
(
wrapper
.
name
,
0
),
new_scale
)
self
.
layer_scale
[
wrapper
.
name
]
=
scale
self
.
layer_scale
[
wrapper
.
name
]
=
scale
...
@@ -218,10 +217,8 @@ class ObserverQuantizer(Quantizer):
...
@@ -218,10 +217,8 @@ class ObserverQuantizer(Quantizer):
# the Pseudo-quantized one. So there is no need to quantize it
# the Pseudo-quantized one. So there is no need to quantize it
if
self
.
compressed
:
if
self
.
compressed
:
return
return
weight
=
wrapper
.
module
.
weight
module
=
wrapper
.
module
self
.
record
(
wrapper
,
'weight'
,
weight
)
old_weight
=
module
.
weight
self
.
record
(
wrapper
,
'weight'
,
old_weight
)
def
quantize_output
(
self
,
output
,
wrapper
,
**
kwargs
):
def
quantize_output
(
self
,
output
,
wrapper
,
**
kwargs
):
if
self
.
compressed
:
if
self
.
compressed
:
...
@@ -474,8 +471,8 @@ class QAT_Quantizer(Quantizer):
...
@@ -474,8 +471,8 @@ class QAT_Quantizer(Quantizer):
def
quantize_weight
(
self
,
wrapper
,
**
kwargs
):
def
quantize_weight
(
self
,
wrapper
,
**
kwargs
):
config
=
wrapper
.
config
config
=
wrapper
.
config
module
=
wrapper
.
module
module
=
wrapper
.
module
weight
=
module
.
weight
input
=
kwargs
[
'input_tensor'
]
# pylint: disable=redefined-builtin
input
=
kwargs
[
'input_tensor'
]
# pylint: disable=redefined-builtin
weight
=
copy
.
deepcopy
(
wrapper
.
module
.
old_weight
.
data
)
weight_bits
=
get_bits_length
(
config
,
'weight'
)
weight_bits
=
get_bits_length
(
config
,
'weight'
)
quant_start_step
=
config
.
get
(
'quant_start_step'
,
0
)
quant_start_step
=
config
.
get
(
'quant_start_step'
,
0
)
assert
weight_bits
>=
1
,
"quant bits length should be at least 1"
assert
weight_bits
>=
1
,
"quant bits length should be at least 1"
...
@@ -675,7 +672,7 @@ class DoReFaQuantizer(Quantizer):
...
@@ -675,7 +672,7 @@ class DoReFaQuantizer(Quantizer):
schema
.
validate
(
config_list
)
schema
.
validate
(
config_list
)
def
quantize_weight
(
self
,
wrapper
,
**
kwargs
):
def
quantize_weight
(
self
,
wrapper
,
**
kwargs
):
weight
=
copy
.
deepcopy
(
wrapper
.
module
.
old_
weight
.
data
)
weight
=
wrapper
.
module
.
weight
weight_bits
=
get_bits_length
(
wrapper
.
config
,
'weight'
)
weight_bits
=
get_bits_length
(
wrapper
.
config
,
'weight'
)
weight
=
weight
.
tanh
()
weight
=
weight
.
tanh
()
weight
=
weight
/
(
2
*
weight
.
abs
().
max
())
+
0.5
weight
=
weight
/
(
2
*
weight
.
abs
().
max
())
+
0.5
...
@@ -785,7 +782,7 @@ class BNNQuantizer(Quantizer):
...
@@ -785,7 +782,7 @@ class BNNQuantizer(Quantizer):
schema
.
validate
(
config_list
)
schema
.
validate
(
config_list
)
def
quantize_weight
(
self
,
wrapper
,
**
kwargs
):
def
quantize_weight
(
self
,
wrapper
,
**
kwargs
):
weight
=
copy
.
deepcopy
(
wrapper
.
module
.
old_
weight
.
data
)
weight
=
wrapper
.
module
.
weight
weight
=
torch
.
sign
(
weight
)
weight
=
torch
.
sign
(
weight
)
# remove zeros
# remove zeros
weight
[
weight
==
0
]
=
1
weight
[
weight
==
0
]
=
1
...
@@ -944,11 +941,11 @@ class LsqQuantizer(Quantizer):
...
@@ -944,11 +941,11 @@ class LsqQuantizer(Quantizer):
def
quantize_weight
(
self
,
wrapper
,
**
kwargs
):
def
quantize_weight
(
self
,
wrapper
,
**
kwargs
):
module
=
wrapper
.
module
module
=
wrapper
.
module
weight
=
wrapper
.
module
.
weight
# todo: add support for quantize bias. If we use TensorRT as backend, there is no need to quantize
# todo: add support for quantize bias. If we use TensorRT as backend, there is no need to quantize
# bias
# bias
old_weight
=
module
.
old_weight
weight
=
self
.
quantize
(
weight
,
module
.
weight_scale
,
module
.
weight_qmin
,
module
.
weight_qmax
)
weight
=
self
.
quantize
(
old_weight
,
module
.
weight_scale
,
module
.
weight_qmin
,
module
.
weight_qmax
)
module
.
weight
=
weight
module
.
weight
=
weight
return
weight
return
weight
...
...
nni/compression/pytorch/compressor.py
View file @
a6069165
...
@@ -559,6 +559,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
...
@@ -559,6 +559,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
self
.
module
.
weight
=
new_weight
self
.
module
.
weight
=
new_weight
else
:
else
:
new_weight
=
self
.
module
.
old_weight
new_weight
=
self
.
module
.
old_weight
self
.
module
.
weight
=
new_weight
.
data
self
.
quantizer
.
quant_grad
(
self
.
quantizer
.
quant_grad
(
new_weight
,
new_weight
,
...
...
test/ut/sdk/test_compressor_torch.py
View file @
a6069165
...
@@ -328,13 +328,13 @@ class CompressorTestCase(TestCase):
...
@@ -328,13 +328,13 @@ class CompressorTestCase(TestCase):
eps
=
1e-7
eps
=
1e-7
input
=
torch
.
tensor
([[
0
,
4
],
[
2
,
1
]]).
float
()
input
=
torch
.
tensor
([[
0
,
4
],
[
2
,
1
]]).
float
()
weight
=
torch
.
tensor
([[
1
,
2
],
[
3
,
5
]]).
float
()
weight
=
torch
.
tensor
([[
1
,
2
],
[
3
,
5
]]).
float
()
model
.
conv2
.
module
.
old_
weight
.
data
=
weight
model
.
conv2
.
module
.
weight
.
data
=
weight
quantizer
.
quantize_weight
(
model
.
conv2
,
input_tensor
=
input
)
quantizer
.
quantize_weight
(
model
.
conv2
,
input_tensor
=
input
)
assert
math
.
isclose
(
model
.
conv2
.
module
.
scale
,
5
/
255
,
abs_tol
=
eps
)
assert
math
.
isclose
(
model
.
conv2
.
module
.
scale
,
5
/
255
,
abs_tol
=
eps
)
assert
model
.
conv2
.
module
.
zero_point
==
0
assert
model
.
conv2
.
module
.
zero_point
==
0
# range including 0
# range including 0
weight
=
torch
.
tensor
([[
-
1
,
2
],
[
3
,
5
]]).
float
()
weight
=
torch
.
tensor
([[
-
1
,
2
],
[
3
,
5
]]).
float
()
model
.
conv2
.
module
.
old_
weight
.
data
=
weight
model
.
conv2
.
module
.
weight
=
weight
quantizer
.
quantize_weight
(
model
.
conv2
,
input_tensor
=
input
)
quantizer
.
quantize_weight
(
model
.
conv2
,
input_tensor
=
input
)
assert
math
.
isclose
(
model
.
conv2
.
module
.
scale
,
6
/
255
,
abs_tol
=
eps
)
assert
math
.
isclose
(
model
.
conv2
.
module
.
scale
,
6
/
255
,
abs_tol
=
eps
)
assert
model
.
conv2
.
module
.
zero_point
in
(
42
,
43
)
assert
model
.
conv2
.
module
.
zero_point
in
(
42
,
43
)
...
@@ -343,7 +343,7 @@ class CompressorTestCase(TestCase):
...
@@ -343,7 +343,7 @@ class CompressorTestCase(TestCase):
weight_valid
=
torch
.
tensor
([[
1.1242
,
2.3421
],
[
3.7707
,
5.9723
]])
weight_valid
=
torch
.
tensor
([[
1.1242
,
2.3421
],
[
3.7707
,
5.9723
]])
bias
=
torch
.
tensor
([
2.3432
,
3.4342
,
1.3414
,
5.2341
])
bias
=
torch
.
tensor
([
2.3432
,
3.4342
,
1.3414
,
5.2341
])
bias_valid
=
torch
.
tensor
([
2.3432
,
3.4342
,
1.3414
,
5.2341
])
bias_valid
=
torch
.
tensor
([
2.3432
,
3.4342
,
1.3414
,
5.2341
])
model
.
conv2
.
module
.
old_
weight
.
data
=
weight
model
.
conv2
.
module
.
weight
=
weight
model
.
conv2
.
module
.
bias
.
data
=
bias
model
.
conv2
.
module
.
bias
.
data
=
bias
quantizer
.
quantize_weight
(
model
.
conv2
,
input_tensor
=
input
)
quantizer
.
quantize_weight
(
model
.
conv2
,
input_tensor
=
input
)
assert
torch
.
all
(
torch
.
isclose
(
model
.
conv2
.
module
.
weight
.
data
,
weight_valid
,
rtol
=
1e-4
))
assert
torch
.
all
(
torch
.
isclose
(
model
.
conv2
.
module
.
weight
.
data
,
weight_valid
,
rtol
=
1e-4
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment