Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
396ae65c
Unverified
Commit
396ae65c
authored
Sep 08, 2021
by
chenbohua3
Committed by
GitHub
Sep 08, 2021
Browse files
support dp multi-gpu training for QAT quantizer (#4127)
parent
29b4d46c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
63 additions
and
32 deletions
+63
-32
nni/algorithms/compression/pytorch/quantization/quantizers.py
...algorithms/compression/pytorch/quantization/quantizers.py
+50
-30
nni/compression/pytorch/compressor.py
nni/compression/pytorch/compressor.py
+13
-2
No files found.
nni/algorithms/compression/pytorch/quantization/quantizers.py
View file @
396ae65c
...
...
@@ -373,21 +373,22 @@ class QAT_Quantizer(Quantizer):
self
.
quant_grad
=
QATGrad
.
apply
modules_to_compress
=
self
.
get_modules_to_compress
()
device
=
next
(
model
.
parameters
()).
device
self
.
bound_model
.
register_buffer
(
"steps"
,
torch
.
T
ensor
(
[
1
]
))
self
.
bound_model
.
register_buffer
(
"steps"
,
torch
.
t
ensor
(
1
))
for
layer
,
config
in
modules_to_compress
:
layer
.
module
.
register_buffer
(
"zero_point"
,
torch
.
Tensor
([
0.0
]))
layer
.
module
.
register_buffer
(
"scale"
,
torch
.
Tensor
([
1.0
]))
layer
.
module
.
register_buffer
(
'ema_decay'
,
torch
.
Tensor
([
0.99
]))
module
=
layer
.
module
module
.
register_buffer
(
"zero_point"
,
torch
.
tensor
([
0.0
]))
module
.
register_buffer
(
"scale"
,
torch
.
tensor
([
1.0
]))
module
.
register_buffer
(
'ema_decay'
,
torch
.
tensor
([
0.99
]))
if
"weight"
in
config
.
get
(
"quant_types"
,
[]):
layer
.
module
.
register_buffer
(
'weight_bits'
,
torch
.
zeros
(
1
))
module
.
register_buffer
(
'weight_bits'
,
torch
.
zeros
(
1
))
if
"input"
in
config
.
get
(
"quant_types"
,
[]):
layer
.
module
.
register_buffer
(
'
tracked_min_
input'
,
torch
.
zeros
(
1
))
layer
.
module
.
register_buffer
(
'tracked_m
ax
_input'
,
torch
.
zeros
(
1
))
layer
.
module
.
register_buffer
(
'input
_bits
'
,
torch
.
zeros
(
1
))
module
.
register_buffer
(
'input
_bits
'
,
torch
.
zeros
(
1
))
module
.
register_buffer
(
'tracked_m
in
_input'
,
torch
.
zeros
(
1
))
module
.
register_buffer
(
'
tracked_max_
input'
,
torch
.
zeros
(
1
))
if
"output"
in
config
.
get
(
"quant_types"
,
[]):
layer
.
module
.
register_buffer
(
'output_bits'
,
torch
.
zeros
(
1
))
layer
.
module
.
register_buffer
(
'tracked_min_output'
,
torch
.
zeros
(
1
))
layer
.
module
.
register_buffer
(
'tracked_max_output'
,
torch
.
zeros
(
1
))
module
.
register_buffer
(
'output_bits'
,
torch
.
zeros
(
1
))
module
.
register_buffer
(
'tracked_min_output'
,
torch
.
zeros
(
1
))
module
.
register_buffer
(
'tracked_max_output'
,
torch
.
zeros
(
1
))
self
.
bound_model
.
to
(
device
)
def
_del_simulated_attr
(
self
,
module
):
...
...
@@ -479,8 +480,7 @@ class QAT_Quantizer(Quantizer):
quant_start_step
=
config
.
get
(
'quant_start_step'
,
0
)
assert
weight_bits
>=
1
,
"quant bits length should be at least 1"
# we dont update weight in evaluation stage
if
quant_start_step
>
self
.
bound_model
.
steps
:
if
quant_start_step
>
int
(
self
.
bound_model
.
steps
):
return
weight
if
not
wrapper
.
training
:
...
...
@@ -488,10 +488,16 @@ class QAT_Quantizer(Quantizer):
# quantize weight
rmin
,
rmax
=
torch
.
min
(
weight
),
torch
.
max
(
weight
)
module
.
scale
,
module
.
zero_point
=
update_quantization_param
(
weight_bits
,
rmin
,
rmax
)
scale
,
zero_point
=
update_quantization_param
(
weight_bits
,
rmin
,
rmax
)
module
.
scale
.
copy_
(
scale
)
module
.
zero_point
.
copy_
(
zero_point
)
weight
=
self
.
_quantize
(
weight_bits
,
module
,
weight
)
weight
=
self
.
_dequantize
(
module
,
weight
)
module
.
weight_bits
=
torch
.
Tensor
([
weight_bits
])
# Weight can not be in-place modified, so when use torch.nn.DataParallel, this update
# will be lost after each forward process. However, this update takes effect on each
# replicated module during each forward process, which will make the quantized weight
# be used correctly.
wrapper
.
module
.
weight
=
weight
return
weight
...
...
@@ -499,23 +505,30 @@ class QAT_Quantizer(Quantizer):
config
=
wrapper
.
config
module
=
wrapper
.
module
input_bits
=
get_bits_length
(
config
,
'input'
)
module
.
input_bits
=
torch
.
Tensor
([
input_bits
])
module
.
input_bit
=
torch
.
tensor
([
input_bits
])
quant_start_step
=
config
.
get
(
'quant_start_step'
,
0
)
assert
input_bits
>=
1
,
"quant bits length should be at least 1"
if
quant_start_step
>
self
.
bound_model
.
steps
:
module
.
tracked_min_input
,
module
.
tracked_max_input
=
torch
.
min
(
inputs
),
torch
.
max
(
inputs
)
if
quant_start_step
>
int
(
self
.
bound_model
.
steps
):
current_min
,
current_max
=
torch
.
min
(
inputs
),
torch
.
max
(
inputs
)
module
.
tracked_min_input
.
copy_
(
current_min
)
module
.
tracked_max_input
.
copy_
(
current_max
)
return
inputs
# we dont update output quantization parameters in evaluation stage
if
wrapper
.
training
:
current_min
,
current_max
=
torch
.
min
(
inputs
),
torch
.
max
(
inputs
)
module
.
tracked_min_input
=
update_ema
(
module
.
tracked_min_input
,
current_min
,
module
.
ema_decay
)
module
.
tracked_max_input
=
update_ema
(
module
.
tracked_max_input
,
current_max
,
module
.
ema_decay
)
module
.
scale
,
module
.
zero_point
=
update_quantization_param
(
current_min
=
update_ema
(
module
.
tracked_min_input
,
current_min
,
module
.
ema_decay
)
current_max
=
update_ema
(
module
.
tracked_max_input
,
current_max
,
module
.
ema_decay
)
module
.
tracked_min_input
.
copy_
(
current_min
)
module
.
tracked_max_input
.
copy_
(
current_max
)
scale
,
zero_point
=
update_quantization_param
(
input_bits
,
module
.
tracked_min_input
,
module
.
tracked_max_input
)
module
.
scale
.
copy_
(
scale
)
module
.
zero_point
.
copy_
(
zero_point
)
inp
=
self
.
_quantize
(
input_bits
,
module
,
inputs
)
inp
=
self
.
_dequantize
(
module
,
inp
)
return
inp
...
...
@@ -528,19 +541,26 @@ class QAT_Quantizer(Quantizer):
quant_start_step
=
config
.
get
(
'quant_start_step'
,
0
)
assert
output_bits
>=
1
,
"quant bits length should be at least 1"
if
quant_start_step
>
self
.
bound_model
.
steps
:
module
.
tracked_min_output
,
module
.
tracked_max_output
=
torch
.
min
(
output
),
torch
.
max
(
output
)
if
quant_start_step
>
int
(
self
.
bound_model
.
steps
):
current_min
,
current_max
=
torch
.
min
(
output
),
torch
.
max
(
output
)
module
.
tracked_min_output
.
copy_
(
current_min
)
module
.
tracked_max_output
.
copy_
(
current_max
)
return
output
# we dont update output quantization parameters in evaluation stage
if
wrapper
.
training
:
current_min
,
current_max
=
torch
.
min
(
output
),
torch
.
max
(
output
)
module
.
tracked_min_output
=
update_ema
(
module
.
tracked_min_output
,
current_min
,
module
.
ema_decay
)
module
.
tracked_max_output
=
update_ema
(
module
.
tracked_max_output
,
current_max
,
module
.
ema_decay
)
module
.
scale
,
module
.
zero_point
=
update_quantization_param
(
tracked_min_output
=
update_ema
(
module
.
tracked_min_output
,
current_min
,
module
.
ema_decay
)
tracked_max_output
=
update_ema
(
module
.
tracked_max_output
,
current_max
,
module
.
ema_decay
)
module
.
tracked_min_output
.
copy_
(
tracked_min_output
)
module
.
tracked_max_output
.
copy_
(
tracked_max_output
)
scale
,
zero_point
=
update_quantization_param
(
output_bits
,
module
.
tracked_min_output
,
module
.
tracked_max_output
)
module
.
scale
.
copy_
(
scale
)
module
.
zero_point
.
copy_
(
zero_point
)
out
=
self
.
_quantize
(
output_bits
,
module
,
output
)
out
=
self
.
_dequantize
(
module
,
out
)
...
...
@@ -645,7 +665,7 @@ class QAT_Quantizer(Quantizer):
"""
override `compressor` `step` method, quantization only happens after certain number of steps
"""
self
.
bound_model
.
steps
+=
1
self
.
bound_model
.
steps
.
add_
(
1
)
class
DoReFaQuantizer
(
Quantizer
):
...
...
nni/compression/pytorch/compressor.py
View file @
396ae65c
...
...
@@ -602,6 +602,8 @@ class Quantizer(Compressor):
"""
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
,
dummy_input
=
None
):
if
isinstance
(
model
,
torch
.
nn
.
DataParallel
):
model
=
model
.
module
self
.
identity_wrappers
=
[]
self
.
conv_bn_patterns
=
{}
self
.
find_conv_bn_patterns
(
model
,
dummy_input
)
...
...
@@ -892,12 +894,21 @@ class QuantGrad(torch.autograd.Function):
zero_point
=
wrapper
.
module
.
zero_point
else
:
scale
,
zero_point
=
None
,
None
ctx
.
save_for_backward
(
tensor
,
torch
.
Tensor
([
quant_type
]),
scale
,
zero_point
,
qmin
,
qmax
)
ctx
.
save_for_backward
(
tensor
)
# Only tensors have gradients flowing back needs to be saved by save_for_backward.
# Others should directly assign to ctx.
ctx
.
scale
=
scale
ctx
.
zero_point
=
zero_point
ctx
.
quant_type
=
quant_type
ctx
.
qmin
,
ctx
.
qmax
=
qmin
,
qmax
return
output
@
classmethod
def
backward
(
cls
,
ctx
,
grad_output
):
tensor
,
quant_type
,
scale
,
zero_point
,
qmin
,
qmax
=
ctx
.
saved_variables
tensor
=
ctx
.
saved_variables
[
0
]
scale
,
zero_point
=
ctx
.
scale
,
ctx
.
zero_point
qmin
,
qmax
=
ctx
.
qmin
,
ctx
.
qmax
quant_type
=
ctx
.
quant_type
output
=
cls
.
quant_backward
(
tensor
,
grad_output
,
quant_type
,
scale
,
zero_point
,
qmin
,
qmax
)
return
output
,
None
,
None
,
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment