OpenDAS / nni / Commits / 0a6c234a

Unverified commit 0a6c234a, authored Oct 10, 2020 by lin bin, committed by GitHub on Oct 10, 2020

Add bias quantization in QAT and refactor the code of weight quantization (#2914)

parent 6126960c
Showing 3 changed files with 69 additions and 26 deletions (+69 -26):

  src/sdk/pynni/nni/compression/torch/compressor.py                 +2   -3
  src/sdk/pynni/nni/compression/torch/quantization/quantizers.py    +51  -21
  src/sdk/pynni/tests/test_compressor_torch.py                      +16  -2
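The core refactor is visible in the summary above: quantize_weight no longer receives the weight tensor and returns a quantized copy for the wrapper to assign; each quantizer now reads the full-precision weight from wrapper.module.old_weight, quantizes it, and writes wrapper.module.weight itself. A minimal sketch of a custom quantizer written against the new contract (the MyQuantizer class, its toy rounding scheme, and the nni.compression.torch import path are illustrative assumptions, not part of this commit):

import copy
import torch
from nni.compression.torch import Quantizer  # assumed export path for the base class

class MyQuantizer(Quantizer):
    """Illustrative quantizer following the refactored quantize_weight contract."""
    def quantize_weight(self, wrapper, **kwargs):
        # read the full-precision copy kept by the wrapper
        weight = copy.deepcopy(wrapper.module.old_weight.data)
        # toy quantization step, stands in for a real scheme
        weight = torch.round(weight)
        # the quantizer, not the wrapper, now stores the simulated-quantized weight
        wrapper.module.weight = weight
        return weight

Forgetting the wrapper.module.weight assignment under the new contract would leave the wrapped module running on stale weights, since QuantizerModuleWrapper.forward (changed below) no longer assigns the return value.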
src/sdk/pynni/nni/compression/torch/compressor.py
...
...
@@ -481,11 +481,10 @@ class QuantizerModuleWrapper(torch.nn.Module):
                 self)
 
         if 'weight' in self.config['quant_types'] and _check_weight(self.module):
-            new_weight = self.quantizer.quant_grad.apply(
+            self.quantizer.quant_grad.apply(
                 self.module.old_weight,
                 QuantType.QUANT_WEIGHT,
                 self)
-            self.module.weight = new_weight
             result = self.module(*inputs)
         else:
             result = self.module(*inputs)
...
...
@@ -617,7 +616,7 @@ class QuantGrad(torch.autograd.Function):
         if quant_type == QuantType.QUANT_INPUT:
             return wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
         elif quant_type == QuantType.QUANT_WEIGHT:
-            return wrapper.quantizer.quantize_weight(tensor, wrapper, **kwargs)
+            return wrapper.quantizer.quantize_weight(wrapper, **kwargs)
         elif quant_type == QuantType.QUANT_OUTPUT:
             return wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
         else:
...
...
src/sdk/pynni/nni/compression/torch/quantization/quantizers.py
...
...
@@ -2,6 +2,7 @@
 # Licensed under the MIT license.
 
 import logging
+import copy
 import torch
 from schema import Schema, And, Or, Optional
 from ..utils.config_validation import CompressorSchema
...
...
@@ -15,6 +16,7 @@ logger = logging.getLogger(__name__)
 class NaiveQuantizer(Quantizer):
     """quantize weight to 8 bits
     """
+
     def __init__(self, model, config_list, optimizer=None):
         super().__init__(model, config_list, optimizer)
         self.layer_scale = {}
...
...
@@ -29,13 +31,15 @@ class NaiveQuantizer(Quantizer):
 
         schema.validate(config_list)
 
-    def quantize_weight(self, weight, wrapper, **kwargs):
+    def quantize_weight(self, wrapper, **kwargs):
+        weight = copy.deepcopy(wrapper.module.old_weight.data)
         new_scale = weight.abs().max() / 127
         scale = max(self.layer_scale.get(wrapper.name, 0), new_scale)
         self.layer_scale[wrapper.name] = scale
         orig_type = weight.type()  # TODO: user layer
-        return weight.div(scale).type(torch.int8).type(orig_type).mul(scale)
+        weight = weight.div(scale).type(torch.int8).type(orig_type).mul(scale)
+        wrapper.module.weight = weight
+        return weight
 
 
 def update_ema(biased_ema, value, decay, step):
     """
...
...
@@ -60,6 +64,7 @@ def update_ema(biased_ema, value, decay, step):
     unbiased_ema = biased_ema / (1 - decay ** step)  # Bias correction
     return biased_ema, unbiased_ema
 
+
 def update_quantization_param(bits, rmin, rmax):
     """
     calculate the `zero_point` and `scale`.
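The `zero_point` and `scale` computed here follow the usual asymmetric scheme from the QAT paper: the observed range is first extended to contain 0, then scale = (rmax - rmin) / (qmax - qmin) and zero_point = qmin - rmin / scale, rounded and clamped into [qmin, qmax]. A standalone numeric sketch (the affine_params helper is illustrative, not the library function; the printed values match the assertions added to test_compressor_torch.py below):

def affine_params(bits, rmin, rmax):
    # illustrative re-statement of the formulas described above
    qmin, qmax = 0, (1 << bits) - 1
    rmin, rmax = min(rmin, 0.0), max(rmax, 0.0)   # range must contain 0
    scale = (rmax - rmin) / (qmax - qmin)
    zero_point = min(max(round(qmin - rmin / scale), qmin), qmax)
    return scale, zero_point

print(affine_params(8, 1.0, 5.0))   # (5/255 ~= 0.0196, 0)
print(affine_params(8, -1.0, 5.0))  # (6/255 ~= 0.0235, 42); the test accepts 42 or 43 for this rounding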
...
...
@@ -116,6 +121,7 @@ class QAT_Quantizer(Quantizer):
     Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
     http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
     """
+
     def __init__(self, model, config_list, optimizer=None):
         """
         Parameters
...
...
@@ -215,20 +221,35 @@ class QAT_Quantizer(Quantizer):
         real_val = op.scale * (quantized_val - op.zero_point)
         return real_val
 
-    def quantize_weight(self, weight, wrapper, **kwargs):
+    def quantize_weight(self, wrapper, **kwargs):
         config = wrapper.config
         module = wrapper.module
+        weight = copy.deepcopy(wrapper.module.old_weight.data)
         weight_bits = get_bits_length(config, 'weight')
         quant_start_step = config.get('quant_start_step', 0)
         assert weight_bits >= 1, "quant bits length should be at least 1"
 
         if quant_start_step > self.steps:
             return weight
 
+        # if bias exists, quantize bias to uint32
+        if hasattr(wrapper.module, 'bias') and wrapper.module.bias is not None:
+            bias = wrapper.module.bias.data
+            bias_bits = 32
+            rmin, rmax = torch.min(bias), torch.max(bias)
+            module.scale, module.zero_point = update_quantization_param(bias_bits, rmin, rmax)
+            bias = self._quantize(bias_bits, module, bias)
+            bias = self._dequantize(module, bias)
+            wrapper.module.bias.data = bias
+
+        # quantize weight
         rmin, rmax = torch.min(weight), torch.max(weight)
         module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax)
-        out = self._quantize(weight_bits, module, weight)
-        out = self._dequantize(module, out)
-        return out
+        weight = self._quantize(weight_bits, module, weight)
+        weight = self._dequantize(module, weight)
+        wrapper.module.weight = weight
+        return weight
 
     def quantize_output(self, output, wrapper, **kwargs):
         config = wrapper.config
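The new branch above quantizes an existing bias to 32 bits with the same affine round trip used for weights, temporarily reusing module.scale and module.zero_point before they are overwritten by the weight statistics; with 2^32 - 1 levels the round trip is effectively lossless, which is why the bias is expected to come back unchanged in the new test below. A standalone sketch of that simulated quantize/dequantize round trip (simulate_affine_roundtrip is an illustrative stand-in for the _quantize/_dequantize pair, not the library code):

import torch

def simulate_affine_roundtrip(x, bits):
    # quantize to an integer grid, then immediately dequantize back to floats
    qmin, qmax = 0, (1 << bits) - 1
    rmin, rmax = min(float(x.min()), 0.0), max(float(x.max()), 0.0)
    scale = (rmax - rmin) / (qmax - qmin)
    zero_point = round(qmin - rmin / scale)
    q = torch.clamp(torch.round(x / scale + zero_point), qmin, qmax)
    return (q - zero_point) * scale

print(simulate_affine_roundtrip(torch.tensor([2.3432, 3.4342, 1.3414, 5.2341]), 32))
# ~unchanged: the 32-bit grid is far finer than float32 resolution
print(simulate_affine_roundtrip(torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]]), 8))
# ~[[1.1242, 2.3421], [3.7707, 5.9723]], the weight_valid values asserted in the new test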
...
...
@@ -241,8 +262,10 @@ class QAT_Quantizer(Quantizer):
             return output
 
         current_min, current_max = torch.min(output), torch.max(output)
-        module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min, module.ema_decay, self.steps)
-        module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max, module.ema_decay, self.steps)
+        module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min,
+                                                                   module.ema_decay, self.steps)
+        module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max,
+                                                                   module.ema_decay, self.steps)
         module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min, module.tracked_max)
         out = self._quantize(output_bits, module, output)
         out = self._dequantize(module, out)
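quantize_output keeps exponential moving averages of the activation range and feeds the bias-corrected values into update_quantization_param. Only the correction step (unbiased_ema = biased_ema / (1 - decay ** step)) appears in this diff; the sketch below assumes the standard EMA update for the biased term and is a standalone illustration, not the library's update_ema:

def ema_step(biased_ema, value, decay, step):
    biased_ema = biased_ema * decay + (1 - decay) * value  # assumed standard EMA update
    unbiased_ema = biased_ema / (1 - decay ** step)        # bias correction, as in the diff
    return biased_ema, unbiased_ema

tracked_min_biased, decay = 0.0, 0.99
for step, current_min in enumerate([-0.5, -0.4, -0.45], start=1):
    tracked_min_biased, tracked_min = ema_step(tracked_min_biased, current_min, decay, step)
print(tracked_min)  # bias-corrected running minimum, the value fed to scale/zero_point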
...
...
@@ -264,6 +287,7 @@ class DoReFaQuantizer(Quantizer):
     Zhou et al., DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients
     (https://arxiv.org/abs/1606.06160)
     """
+
     def __init__(self, model, config_list, optimizer=None):
         super().__init__(model, config_list, optimizer)
...
...
@@ -287,17 +311,20 @@ class DoReFaQuantizer(Quantizer):
 
         schema.validate(config_list)
 
-    def quantize_weight(self, weight, wrapper, **kwargs):
+    def quantize_weight(self, wrapper, **kwargs):
+        weight = copy.deepcopy(wrapper.module.old_weight.data)
         weight_bits = get_bits_length(wrapper.config, 'weight')
-        out = weight.tanh()
-        out = out / (2 * out.abs().max()) + 0.5
-        out = self.quantize(out, weight_bits)
-        out = 2 * out - 1
-        return out
+        weight = weight.tanh()
+        weight = weight / (2 * weight.abs().max()) + 0.5
+        weight = self.quantize(weight, weight_bits)
+        weight = 2 * weight - 1
+        wrapper.module.weight = weight
+        # wrapper.module.weight.data = weight
+        return weight
 
     def quantize(self, input_ri, q_bits):
-        scale = pow(2, q_bits) - 1
-        output = torch.round(input_ri * scale) / scale
+        scale = pow(2, q_bits) - 1
+        output = torch.round(input_ri * scale) / scale
         return output
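For reference, the DoReFa weight path shown above squashes weights with tanh, rescales them into [0, 1], applies k-bit uniform quantization, and stretches the result back to [-1, 1]. A small standalone numeric sketch of that pipeline (the dorefa_weight helper and the 2-bit example values are illustrative):

import torch

def dorefa_weight(w, k):
    # mirrors the quantize_weight/quantize pair above as one standalone function
    t = w.tanh()
    n = t / (2 * t.abs().max()) + 0.5   # map into [0, 1]
    scale = pow(2, k) - 1
    q = torch.round(n * scale) / scale  # k-bit uniform quantization
    return 2 * q - 1                    # stretch back to [-1, 1]

print(dorefa_weight(torch.tensor([-1.0, 0.5]), 2))  # tensor([-1.0000, 0.3333])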
...
...
@@ -314,6 +341,7 @@ class BNNQuantizer(Quantizer):
     Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1
     (https://arxiv.org/abs/1602.02830)
     """
+
     def __init__(self, model, config_list, optimizer=None):
         super().__init__(model, config_list, optimizer)
         self.quant_grad = ClipGrad
...
...
@@ -339,11 +367,13 @@ class BNNQuantizer(Quantizer):
 
         schema.validate(config_list)
 
-    def quantize_weight(self, weight, wrapper, **kwargs):
-        out = torch.sign(weight)
+    def quantize_weight(self, wrapper, **kwargs):
+        weight = copy.deepcopy(wrapper.module.old_weight.data)
+        weight = torch.sign(weight)
         # remove zeros
-        out[out == 0] = 1
-        return out
+        weight[weight == 0] = 1
+        wrapper.module.weight = weight
+        return weight
 
     def quantize_output(self, output, wrapper, **kwargs):
         out = torch.sign(output)
...
...
src/sdk/pynni/tests/test_compressor_torch.py
...
...
@@ -234,20 +234,34 @@ class CompressorTestCase(TestCase):
         model.relu = torch.nn.ReLU()
         quantizer = torch_compressor.QAT_Quantizer(model, config_list)
         quantizer.compress()
 
         # test quantize
         # range not including 0
         eps = 1e-7
         weight = torch.tensor([[1, 2], [3, 5]]).float()
-        quantize_weight = quantizer.quantize_weight(weight, model.conv2)
+        model.conv2.module.old_weight.data = weight
+        quantizer.quantize_weight(model.conv2)
         assert math.isclose(model.conv2.module.scale, 5 / 255, abs_tol=eps)
         assert model.conv2.module.zero_point == 0
         # range including 0
         weight = torch.tensor([[-1, 2], [3, 5]]).float()
-        quantize_weight = quantizer.quantize_weight(weight, model.conv2)
+        model.conv2.module.old_weight.data = weight
+        quantizer.quantize_weight(model.conv2)
         assert math.isclose(model.conv2.module.scale, 6 / 255, abs_tol=eps)
         assert model.conv2.module.zero_point in (42, 43)
+
+        # test value of weight and bias after quantization
+        weight = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]])
+        weight_valid = torch.tensor([[1.1242, 2.3421], [3.7707, 5.9723]])
+        bias = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
+        bias_valid = torch.tensor([2.3432, 3.4342, 1.3414, 5.2341])
+        model.conv2.module.old_weight.data = weight
+        model.conv2.module.bias.data = bias
+        quantizer.quantize_weight(model.conv2)
+        assert torch.all(torch.isclose(model.conv2.module.weight.data, weight_valid, rtol=1e-4))
+        assert torch.all(torch.isclose(model.conv2.module.bias.data, bias_valid, rtol=1e-7))
 
         # test ema
         eps = 1e-7
         x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
         out = model.relu(x)
         assert math.isclose(model.relu.module.tracked_min_biased, 0, abs_tol=eps)
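The test drives quantize_weight directly by setting old_weight and calling it by hand; during training the same call is made from QuantizerModuleWrapper.forward (see the compressor.py change above) on every forward pass. A hedged end-to-end sketch, assuming the usual nni.compression.torch import path and the documented config_list keys; the model and values are illustrative:

import torch
from nni.compression.torch import QAT_Quantizer  # assumed export path

model = torch.nn.Sequential(
    torch.nn.Conv2d(1, 8, 3),
    torch.nn.ReLU(),
    torch.nn.Flatten(),
    torch.nn.Linear(8 * 26 * 26, 10),
)
config_list = [{
    'quant_types': ['weight'],
    'quant_bits': {'weight': 8},
    'op_types': ['Conv2d', 'Linear'],
}]
quantizer = QAT_Quantizer(model, config_list)
quantizer.compress()

# each forward pass now simulates 8-bit weights and, after this commit,
# also runs the 32-bit bias quantize/dequantize round trip where a bias exists
out = model(torch.randn(4, 1, 28, 28))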
...
...