OpenDAS / nni · Commit a6069165 (unverified)
Authored Aug 24, 2021 by chenbohua3; committed by GitHub on Aug 24, 2021

fix wrong quantization target in weight quantization (#4038)
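In short: each quantizer previously deep-copied `wrapper.module.old_weight.data` as its quantization target, while the wrapper only set `module.weight` on one branch of its forward pass. The diffs below make the quantizers read `wrapper.module.weight` directly and make `QuantizerModuleWrapper` populate `module.weight` from `old_weight` on both branches before calling `quant_grad`, so every quantizer sees the current full-precision weight.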
Parent: e9c21fd3

Showing 3 changed files with 12 additions and 14 deletions:

  nni/algorithms/compression/pytorch/quantization/quantizers.py   +8  -11
  nni/compression/pytorch/compressor.py                            +1  -0
  test/ut/sdk/test_compressor_torch.py                             +3  -3
nni/algorithms/compression/pytorch/quantization/quantizers.py

@@ -2,7 +2,6 @@
 # Licensed under the MIT license.

 import logging
-import copy
 from collections import defaultdict
 import torch
 from schema import Schema, And, Or, Optional
@@ -36,7 +35,7 @@ class NaiveQuantizer(Quantizer):
         schema.validate(config_list)

     def quantize_weight(self, wrapper, **kwargs):
-        weight = copy.deepcopy(wrapper.module.old_weight.data)
+        weight = wrapper.module.weight
         new_scale = weight.abs().max() / 127
         scale = max(self.layer_scale.get(wrapper.name, 0), new_scale)
         self.layer_scale[wrapper.name] = scale
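For reference, the symmetric scheme NaiveQuantizer applies is visible in this hunk: the largest weight magnitude is mapped to 127, and the per-layer scale only grows across calls. A minimal standalone sketch (the rounding/clamping step is an assumption; it sits outside this hunk):

    import torch

    def naive_quantize(weight: torch.Tensor, prev_scale: float = 0.0):
        new_scale = weight.abs().max().item() / 127   # map max |w| to the int8 limit
        scale = max(prev_scale, new_scale)            # per-layer running maximum
        q = torch.clamp(torch.round(weight / scale), -128, 127)  # assumed rounding
        return q * scale, scale                       # dequantized weight and scale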
@@ -218,10 +217,8 @@ class ObserverQuantizer(Quantizer):
         # the Pseudo-quantized one. So there is no need to quantize it
         if self.compressed:
             return
-        module = wrapper.module
-        old_weight = module.weight
-        self.record(wrapper, 'weight', old_weight)
+        weight = wrapper.module.weight
+        self.record(wrapper, 'weight', weight)

     def quantize_output(self, output, wrapper, **kwargs):
         if self.compressed:
@@ -474,8 +471,8 @@ class QAT_Quantizer(Quantizer):
     def quantize_weight(self, wrapper, **kwargs):
         config = wrapper.config
         module = wrapper.module
+        weight = module.weight
         input = kwargs['input_tensor']  # pylint: disable=redefined-builtin
-        weight = copy.deepcopy(wrapper.module.old_weight.data)
         weight_bits = get_bits_length(config, 'weight')
         quant_start_step = config.get('quant_start_step', 0)
         assert weight_bits >= 1, "quant bits length should be at least 1"
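The two settings read here come from the layer's compression config. A hypothetical `config_list` entry illustrating them (field names follow NNI's quantization config schema; the concrete values are examples only):

    config_list = [{
        'quant_types': ['weight'],
        'quant_bits': {'weight': 8},   # consumed via get_bits_length(config, 'weight')
        'quant_start_step': 100,       # steps before this return the weight unquantized
        'op_types': ['Conv2d'],
    }]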
@@ -675,7 +672,7 @@ class DoReFaQuantizer(Quantizer):
         schema.validate(config_list)

     def quantize_weight(self, wrapper, **kwargs):
-        weight = copy.deepcopy(wrapper.module.old_weight.data)
+        weight = wrapper.module.weight
         weight_bits = get_bits_length(wrapper.config, 'weight')
         weight = weight.tanh()
         weight = weight / (2 * weight.abs().max()) + 0.5
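The transform above squashes weights through `tanh` into [0, 1] before rounding. A hedged end-to-end sketch following the DoReFa-Net paper (the k-bit uniform rounding and the final rescale to [-1, 1] are assumptions; they sit outside this hunk):

    import torch

    def dorefa_weight(weight: torch.Tensor, weight_bits: int) -> torch.Tensor:
        w = weight.tanh()
        w = w / (2 * w.abs().max()) + 0.5   # shown above: squash into [0, 1]
        n = 2 ** weight_bits - 1
        w = torch.round(w * n) / n          # assumed k-bit uniform rounding
        return 2 * w - 1                    # assumed rescale back to [-1, 1]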
@@ -785,7 +782,7 @@ class BNNQuantizer(Quantizer):
         schema.validate(config_list)

     def quantize_weight(self, wrapper, **kwargs):
-        weight = copy.deepcopy(wrapper.module.old_weight.data)
+        weight = wrapper.module.weight
         weight = torch.sign(weight)
         # remove zeros
         weight[weight == 0] = 1
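A tiny illustration of the binarization above, which maps every weight to ±1 and reassigns exact zeros to +1 (since `torch.sign(0)` returns 0):

    import torch

    w = torch.tensor([-0.3, 0.0, 1.7])
    b = torch.sign(w)   # tensor([-1., 0., 1.])
    b[b == 0] = 1       # tensor([-1., 1., 1.])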
@@ -944,11 +941,11 @@ class LsqQuantizer(Quantizer):
     def quantize_weight(self, wrapper, **kwargs):
         module = wrapper.module
+        weight = wrapper.module.weight
         # todo: add support for quantize bias. If we use TensorRT as backend, there is no need to quantize
         # bias
-        old_weight = module.old_weight
-        weight = self.quantize(old_weight, module.weight_scale, module.weight_qmin, module.weight_qmax)
+        weight = self.quantize(weight, module.weight_scale, module.weight_qmin, module.weight_qmax)
         module.weight = weight
         return weight
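`self.quantize` itself is not shown in this hunk. A hedged sketch of what it computes, following the LSQ paper (Esser et al., 2020), using straight-through rounding so gradients flow to both the weight and the learned scale (the paper's gradient-scale factor is omitted here):

    import torch

    def lsq_quantize(x: torch.Tensor, scale: torch.Tensor, qmin: int, qmax: int):
        q = torch.clamp(x / scale, qmin, qmax)
        q = (q.round() - q).detach() + q   # round in forward, identity gradient
        return q * scale                   # fake-quantized (dequantized) output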
nni/compression/pytorch/compressor.py
@@ -559,6 +559,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
                 self.module.weight = new_weight
             else:
                 new_weight = self.module.old_weight
+                self.module.weight = new_weight.data

             self.quantizer.quant_grad(
                 new_weight,
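With the added line, both branches leave `self.module.weight` populated before `quant_grad` runs: the first branch assigns the parameter itself, and the `else` branch now exposes `old_weight.data` as `module.weight`. This is the invariant the quantizer-side changes rely on when they read `wrapper.module.weight` instead of deep-copying `old_weight.data`.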
test/ut/sdk/test_compressor_torch.py
@@ -328,13 +328,13 @@ class CompressorTestCase(TestCase):
         eps = 1e-7
         input = torch.tensor([[0, 4], [2, 1]]).float()
         weight = torch.tensor([[1, 2], [3, 5]]).float()
-        model.conv2.module.old_weight.data = weight
+        model.conv2.module.weight.data = weight
         quantizer.quantize_weight(model.conv2, input_tensor=input)
         assert math.isclose(model.conv2.module.scale, 5 / 255, abs_tol=eps)
         assert model.conv2.module.zero_point == 0
         # range including 0
         weight = torch.tensor([[-1, 2], [3, 5]]).float()
-        model.conv2.module.old_weight.data = weight
+        model.conv2.module.weight = weight
         quantizer.quantize_weight(model.conv2, input_tensor=input)
         assert math.isclose(model.conv2.module.scale, 6 / 255, abs_tol=eps)
         assert model.conv2.module.zero_point in (42, 43)
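The asserted constants are consistent with standard 8-bit affine quantization over [0, 255], with the observed range widened to include zero; a sanity-check derivation (this arithmetic is an assumption, not code from the commit):

    # first tensor spans [1, 5]; the range is widened to [0, 5]
    scale = (5.0 - 0.0) / 255               # 5 / 255
    zero_point = round(0 - 0.0 / scale)     # 0

    # second tensor spans [-1, 5]
    scale = (5.0 - (-1.0)) / 255            # 6 / 255
    zero_point = round(0 - (-1.0) / scale)  # 255 / 6 = 42.5 -> 42 or 43 by rounding mode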
@@ -343,7 +343,7 @@ class CompressorTestCase(TestCase):
         weight_valid = torch.tensor([[1.1242, 2.3421], [3.7707, 5.9723]])
         bias = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
         bias_valid = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
-        model.conv2.module.old_weight.data = weight
+        model.conv2.module.weight = weight
         model.conv2.module.bias.data = bias
         quantizer.quantize_weight(model.conv2, input_tensor=input)
         assert torch.all(torch.isclose(model.conv2.module.weight.data, weight_valid, rtol=1e-4))