OpenDAS / nni · Commits

Commit fc0ff8ce (unverified)
Authored Nov 30, 2020 by Dalong, committed by GitHub on Nov 30, 2020
Parent: 62d5812d

fix checkpoint load error and stop updating parameters in evaluation stage (#3124)

Showing 1 changed file with 33 additions and 21 deletions:
nni/algorithms/compression/pytorch/quantization/quantizers.py (+33, -21)
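The checkpoint load error fixed here comes from buffers that were registered with a value of None: PyTorch omits None buffers from state_dict(), so a checkpoint saved after the quantizer has filled them in no longer round-trips into a freshly constructed model. A minimal sketch of the failure mode (the Broken/Fixed module names are illustrative, not from the diff):

    import torch

    class Broken(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("scale", None)  # None buffers are skipped by state_dict()

    class Fixed(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("scale", torch.Tensor([1.0]))  # concrete default, always serialized

    m = Broken()
    m.scale = torch.Tensor([0.5])    # quantizer fills the buffer during training
    ckpt = m.state_dict()            # now contains 'scale'
    Broken().load_state_dict(ckpt)   # RuntimeError: unexpected key 'scale' (fresh buffer is still None)
    Fixed().load_state_dict(Fixed().state_dict())  # round-trips cleanly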
@@ -73,9 +73,9 @@ def update_quantization_param(bits, rmin, rmax):
     ----------
     bits : int
         quantization bits length
-    rmin : float
+    rmin : Tensor
         min value of real value
-    rmax : float
+    rmax : Tensor
         max value of real value
     Returns
@@ -85,12 +85,17 @@ def update_quantization_param(bits, rmin, rmax):
     # extend the [min, max] interval to ensure that it contains 0.
     # Otherwise, we would not meet the requirement that 0 be an exactly
     # representable value.
-    rmin = min(rmin, 0)
-    rmax = max(rmax, 0)
-
-    # the min and max quantized values, as floating-point values
-    qmin = 0
-    qmax = (1 << bits) - 1
+    if rmin.is_cuda:
+        rmin = torch.min(rmin, torch.Tensor([0]).cuda())
+        rmax = torch.max(rmax, torch.Tensor([0]).cuda())
+        qmin = torch.Tensor([0]).cuda()
+        qmax = torch.Tensor([(1 << bits) - 1]).cuda()
+    else:
+        rmin = torch.min(rmin, torch.Tensor([0]))
+        rmax = torch.max(rmax, torch.Tensor([0]))
+        qmin = torch.Tensor([0])
+        qmax = torch.Tensor([(1 << bits) - 1])
     # First determine the scale.
     scale = (rmax - rmin) / (qmax - qmin)
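rmin and rmax are now Tensors (per the docstring change above), possibly on the GPU, which is why the built-in min/max of Python scalars no longer applies. For intuition, with bits = 8, rmin = -1.0 and rmax = 3.0 this hunk yields scale = 4 / 255 ≈ 0.0157. The zero-point computation sits outside the hunk; a sketch of the usual asymmetric scheme this code follows (the zero_point formula is an assumption, not shown in the diff):

    import torch

    def quant_params(bits, rmin, rmax):
        # extend [rmin, rmax] to contain 0, as in the hunk above
        rmin = torch.min(rmin, torch.zeros_like(rmin))
        rmax = torch.max(rmax, torch.zeros_like(rmax))
        qmin, qmax = 0, (1 << bits) - 1
        scale = (rmax - rmin) / (qmax - qmin)
        # assumption: zero_point is round(qmin - rmin / scale), clamped into [qmin, qmax]
        zero_point = torch.clamp(torch.round(qmin - rmin / scale), qmin, qmax)
        return scale, zero_point

    scale, zp = quant_params(8, torch.Tensor([-1.0]), torch.Tensor([3.0]))
    print(scale, zp)   # tensor([0.0157]) tensor([64.])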
@@ -143,11 +148,11 @@ class QAT_Quantizer(Quantizer):
             types of nn.module you want to apply quantization, eg. 'Conv2d'
         """
         super().__init__(model, config_list, optimizer)
-        self.steps = 1
         modules_to_compress = self.get_modules_to_compress()
+        self.bound_model.register_buffer("steps", torch.Tensor([1]))
         for layer, config in modules_to_compress:
-            layer.module.register_buffer("zero_point", None)
-            layer.module.register_buffer("scale", None)
+            layer.module.register_buffer("zero_point", torch.Tensor([0.0]))
+            layer.module.register_buffer("scale", torch.Tensor([1.0]))
             if "output" in config.get("quant_types", []):
                 layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
                 layer.module.register_buffer('tracked_min_biased', torch.zeros(1))
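Registering steps on the bound model, instead of keeping self.steps = 1 as a plain Python attribute on the quantizer, means the counter is serialized with the checkpoint, so a resumed run continues counting rather than restarting at 1. A quick demonstration of that buffer semantics:

    import torch

    m = torch.nn.Module()
    m.register_buffer("steps", torch.Tensor([1]))
    m.steps += 1                  # counter advances during training
    ckpt = m.state_dict()         # 'steps' travels with the checkpoint

    m2 = torch.nn.Module()
    m2.register_buffer("steps", torch.Tensor([1]))
    m2.load_state_dict(ckpt)
    print(m2.steps)               # tensor([2.]) -- resumes where it left off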
@@ -187,13 +192,17 @@ class QAT_Quantizer(Quantizer):
             quantization bits length
         op : torch.nn.Module
             target module
-        real_val : float
+        real_val : Tensor
             real value to be quantized
         Returns
        -------
-        float
+        Tensor
         """
+        if real_val.is_cuda:
+            op.zero_point = op.zero_point.cuda()
+            op.scale = op.scale.cuda()
         transformed_val = op.zero_point + real_val / op.scale
         qmin = 0
         qmax = (1 << bits) - 1
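The hunk ends just before the rounding step, but the transform above is the standard asymmetric quantize; a self-contained sketch of the full mapping (the clamp-and-round tail is an assumption about the elided lines, not part of the diff):

    import torch

    def quantize(real_val, scale, zero_point, bits):
        qmin, qmax = 0, (1 << bits) - 1
        transformed_val = zero_point + real_val / scale   # same affine map as above
        return torch.clamp(torch.round(transformed_val), qmin, qmax)

    x = torch.tensor([-1.0, 0.0, 2.5])
    print(quantize(x, scale=torch.tensor([4 / 255]), zero_point=torch.tensor([64.0]), bits=8))
    # tensor([  0.,  64., 223.])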
@@ -229,7 +238,8 @@ class QAT_Quantizer(Quantizer):
         quant_start_step = config.get('quant_start_step', 0)
         assert weight_bits >= 1, "quant bits length should be at least 1"
-        if quant_start_step > self.steps:
+        # we dont update weight in evaluation stage
+        if quant_start_step > self.bound_model.steps or not wrapper.training:
             return weight
         # if bias exists, quantize bias to uint32
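The new `not wrapper.training` guard relies on nn.Module.train() and .eval() setting the training flag recursively, so the wrapper around each compressed layer sees whatever mode the user set on the top-level model:

    import torch

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3))
    model.eval()
    print(model[0].training)   # False -- weight quantization is skipped in this mode
    model.train()
    print(model[0].training)   # True  -- quantization-aware updates resume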
@@ -258,14 +268,16 @@ class QAT_Quantizer(Quantizer):
         quant_start_step = config.get('quant_start_step', 0)
         assert output_bits >= 1, "quant bits length should be at least 1"
-        if quant_start_step > self.steps:
+        if quant_start_step > self.bound_model.steps:
             return output
-        current_min, current_max = torch.min(output), torch.max(output)
-        module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min,
-                                                                   module.ema_decay, self.steps)
-        module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max,
-                                                                   module.ema_decay, self.steps)
-        module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min, module.tracked_max)
+
+        # we dont update output quantization parameters in evaluation stage
+        if wrapper.training:
+            current_min, current_max = torch.min(output), torch.max(output)
+            module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min,
+                                                                       module.ema_decay, self.bound_model.steps)
+            module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max,
+                                                                       module.ema_decay, self.bound_model.steps)
+            module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min, module.tracked_max)
         out = self._quantize(output_bits, module, output)
         out = self._dequantize(module, out)
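update_ema itself is outside this hunk; its call signature is consistent with a bias-corrected exponential moving average of the observed output range. A sketch of that pattern (an assumption about the helper, shown for context only):

    import torch

    def update_ema(biased_ema, value, decay, step):
        # running average of the observed min/max, plus bias correction for early steps
        biased_ema = biased_ema * decay + (1 - decay) * value
        corrected = biased_ema / (1 - decay ** step)
        return biased_ema, corrected

    biased = torch.Tensor([0.0])
    for step in range(1, 4):
        biased, corrected = update_ema(biased, torch.Tensor([2.0]), torch.Tensor([0.99]), step)
    print(corrected)   # tensor([2.]) -- bias correction recovers the true value immediately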
@@ -279,7 +291,7 @@ class QAT_Quantizer(Quantizer):
         """
         override `compressor` `step` method, quantization only happens after certain number of steps
         """
-        self.steps += 1
+        self.bound_model.steps += 1


 class DoReFaQuantizer(Quantizer):
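Because the counter now lives on bound_model, the quant_start_step warm-up survives checkpointing: the step count that this threshold is compared against is saved and restored with the model. A hypothetical config showing how the warm-up is expressed (field values are illustrative; see the NNI QAT_Quantizer docs for the authoritative schema):

    # quantization only kicks in after 1000 optimizer steps
    config_list = [{
        'quant_types': ['weight', 'output'],
        'quant_bits': {'weight': 8, 'output': 8},
        'quant_start_step': 1000,
        'op_types': ['Conv2d', 'Linear']
    }]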