Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
0c13ea49
Unverified
Commit
0c13ea49
authored
Dec 25, 2020
by
lin bin
Committed by
GitHub
Dec 25, 2020
Browse files
fix QAT ema issue and tensor type error (#3219)
parent
cc58a81d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
11 deletions
+9
-11
nni/algorithms/compression/pytorch/quantization/quantizers.py
...algorithms/compression/pytorch/quantization/quantizers.py
+8
-10
nni/compression/pytorch/compressor.py
nni/compression/pytorch/compressor.py
+1
-1
No files found.
nni/algorithms/compression/pytorch/quantization/quantizers.py
View file @
0c13ea49
...
...
@@ -41,7 +41,7 @@ class NaiveQuantizer(Quantizer):
wrapper
.
module
.
weight
=
weight
return
weight
def update_ema(biased_ema, value, decay):
    """
    Calculate the biased exponential-moving-average stat for the current step.

    This is the post-fix form of the function: the earlier version also took a
    ``step`` argument and returned a bias-corrected (unbiased) EMA alongside the
    biased one; the commit removed the bias-correction path, so only the biased
    EMA is maintained and returned.

    Parameters
    ----------
    biased_ema : float
        previous stat value
    value : float
        current stat value
    decay : float
        the weight of previous stat value, larger means smoother curve

    Returns
    -------
    float
        the updated biased EMA value
    """
    # Standard EMA update: new = old * decay + (1 - decay) * current
    biased_ema = biased_ema * decay + (1 - decay) * value
    return biased_ema
def
update_quantization_param
(
bits
,
rmin
,
rmax
):
...
...
@@ -260,16 +257,17 @@ class QAT_Quantizer(Quantizer):
assert
output_bits
>=
1
,
"quant bits length should be at least 1"
if
quant_start_step
>
self
.
bound_model
.
steps
:
module
.
tracked_min_biased
,
module
.
tracked_max_biased
=
torch
.
min
(
output
),
torch
.
max
(
output
)
return
output
# we dont update output quantization parameters in evaluation stage
if
wrapper
.
training
:
current_min
,
current_max
=
torch
.
min
(
output
),
torch
.
max
(
output
)
module
.
tracked_min_biased
,
module
.
tracked_min
=
update_ema
(
module
.
tracked_min_biased
,
current_min
,
module
.
ema_decay
,
self
.
bound_model
.
steps
)
module
.
tracked_max_biased
,
module
.
tracked_max
=
update_ema
(
module
.
tracked_max_biased
,
current_max
,
module
.
ema_decay
,
self
.
bound_model
.
steps
)
module
.
scale
,
module
.
zero_point
=
update_quantization_param
(
output_bits
,
module
.
tracked_min
,
module
.
tracked_max
)
module
.
tracked_min_biased
=
update_ema
(
module
.
tracked_min_biased
,
current_min
,
module
.
ema_decay
)
module
.
tracked_max_biased
=
update_ema
(
module
.
tracked_max_biased
,
current_max
,
module
.
ema_decay
)
module
.
scale
,
module
.
zero_point
=
update_quantization_param
(
output_bits
,
module
.
tracked_min
_biased
,
module
.
tracked_max
_biased
)
out
=
self
.
_quantize
(
output_bits
,
module
,
output
)
out
=
self
.
_dequantize
(
module
,
out
)
return
out
...
...
nni/compression/pytorch/compressor.py
View file @
0c13ea49
...
...
@@ -669,7 +669,7 @@ class QuantGrad(torch.autograd.Function):
raise
ValueError
(
"unrecognized QuantType."
)
bits
=
QuantGrad
.
get_bits_length
(
wrapper
.
config
,
quant_type
)
qmin
,
qmax
=
torch
.
Tensor
([
0
]
,
device
=
tensor
.
device
),
torch
.
Tensor
([(
1
<<
bits
)
-
1
],
device
=
tensor
.
device
)
qmin
,
qmax
=
torch
.
Tensor
([
0
]
).
to
(
device
=
tensor
.
device
),
torch
.
Tensor
([(
1
<<
bits
)
-
1
]).
to
(
device
=
tensor
.
device
)
ctx
.
save_for_backward
(
tensor
,
wrapper
.
module
.
scale
,
wrapper
.
module
.
zero_point
,
qmin
,
qmax
)
return
output
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment