OpenDAS / nni / Commits / 396ae65c

Unverified commit 396ae65c, authored Sep 08, 2021 by chenbohua3, committed by GitHub on Sep 08, 2021
support dp multi-gpu training for QAT quantizer (#4127)
parent 29b4d46c
Showing 2 changed files with 63 additions and 32 deletions (+63 −32)

nni/algorithms/compression/pytorch/quantization/quantizers.py   +50 −30
nni/compression/pytorch/compressor.py                           +13 −2
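The commit's purpose, in context: before this change, QAT_Quantizer kept quantization state (scale, zero_point, tracked min/max, step counter) as plain tensor attributes, which torch.nn.DataParallel neither broadcasts to replicas nor syncs back, so data-parallel training silently used stale values. The sketch below shows the workflow the commit enables; it is an editor's illustration rather than part of the commit — the toy model, config values, and loader are assumed, and the NNI calls follow the 2.x compression API.

```python
import torch
import torch.nn.functional as F
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer

model = torch.nn.Sequential(
    torch.nn.Linear(128, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
config_list = [{
    'quant_types': ['weight', 'output'],
    'quant_bits': {'weight': 8, 'output': 8},
    'op_types': ['Linear'],
    'quant_start_step': 100,
}]

quantizer = QAT_Quantizer(model, config_list, optimizer)
quantizer.compress()

# After this commit the compressed model can be replicated across GPUs:
# all quantization state now lives in registered buffers, which
# DataParallel broadcasts to every replica on each forward pass.
parallel_model = torch.nn.DataParallel(model.cuda())

for data, target in loader:        # `loader` is assumed to exist
    optimizer.zero_grad()
    loss = F.cross_entropy(parallel_model(data.cuda()), target.cuda())
    loss.backward()
    optimizer.step()               # NNI's patched step also advances `steps`
```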
nni/algorithms/compression/pytorch/quantization/quantizers.py (view file @ 396ae65c)
@@ -373,21 +373,22 @@ class QAT_Quantizer(Quantizer):
         self.quant_grad = QATGrad.apply
         modules_to_compress = self.get_modules_to_compress()
         device = next(model.parameters()).device
-        self.bound_model.register_buffer("steps", torch.Tensor([1]))
+        self.bound_model.register_buffer("steps", torch.tensor(1))
         for layer, config in modules_to_compress:
-            layer.module.register_buffer("zero_point", torch.Tensor([0.0]))
-            layer.module.register_buffer("scale", torch.Tensor([1.0]))
-            layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
+            module = layer.module
+            module.register_buffer("zero_point", torch.tensor([0.0]))
+            module.register_buffer("scale", torch.tensor([1.0]))
+            module.register_buffer('ema_decay', torch.tensor([0.99]))
             if "weight" in config.get("quant_types", []):
-                layer.module.register_buffer('weight_bits', torch.zeros(1))
+                module.register_buffer('weight_bits', torch.zeros(1))
             if "input" in config.get("quant_types", []):
-                layer.module.register_buffer('tracked_min_input', torch.zeros(1))
-                layer.module.register_buffer('tracked_max_input', torch.zeros(1))
-                layer.module.register_buffer('input_bits', torch.zeros(1))
+                module.register_buffer('input_bits', torch.zeros(1))
+                module.register_buffer('tracked_min_input', torch.zeros(1))
+                module.register_buffer('tracked_max_input', torch.zeros(1))
             if "output" in config.get("quant_types", []):
-                layer.module.register_buffer('output_bits', torch.zeros(1))
-                layer.module.register_buffer('tracked_min_output', torch.zeros(1))
-                layer.module.register_buffer('tracked_max_output', torch.zeros(1))
+                module.register_buffer('output_bits', torch.zeros(1))
+                module.register_buffer('tracked_min_output', torch.zeros(1))
+                module.register_buffer('tracked_max_output', torch.zeros(1))
         self.bound_model.to(device)

     def _del_simulated_attr(self, module):
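Two details of this hunk are easy to miss. First, every piece of quantization state becomes a registered buffer, so it is carried by `state_dict`, moved by `.to(device)`, and broadcast to each replica by DataParallel; a plain attribute would be copied once at wrap time and never synchronized. Second, `steps` changes from `torch.Tensor([1])` (a 1-element float tensor) to `torch.tensor(1)` (a 0-dim integer tensor), which is why later hunks compare against `int(self.bound_model.steps)`. A minimal illustration of the buffer behaviour (editor's sketch, not NNI code):

```python
import torch

class Layer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('scale', torch.tensor([1.0]))  # tracked by the module
        self.plain_scale = torch.tensor([1.0])              # invisible to module machinery

layer = Layer()
print([n for n, _ in layer.named_buffers()])  # ['scale']
print('scale' in layer.state_dict())          # True
print('plain_scale' in layer.state_dict())    # False: not saved, moved, or broadcast
```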
@@ -479,8 +480,7 @@ class QAT_Quantizer(Quantizer):
         quant_start_step = config.get('quant_start_step', 0)
         assert weight_bits >= 1, "quant bits length should be at least 1"

-        # we dont update weight in evaluation stage
-        if quant_start_step > self.bound_model.steps:
+        if quant_start_step > int(self.bound_model.steps):
             return weight

         if not wrapper.training:
@@ -488,10 +488,16 @@ class QAT_Quantizer(Quantizer):
         # quantize weight
         rmin, rmax = torch.min(weight), torch.max(weight)
-        module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax)
+        scale, zero_point = update_quantization_param(weight_bits, rmin, rmax)
+        module.scale.copy_(scale)
+        module.zero_point.copy_(zero_point)
         weight = self._quantize(weight_bits, module, weight)
         weight = self._dequantize(module, weight)
         module.weight_bits = torch.Tensor([weight_bits])
+        # Weight can not be in-place modified, so when use torch.nn.DataParallel, this update
+        # will be lost after each forward process. However, this update takes effect on each
+        # replicated module during each forward process, which will make the quantized weight
+        # be used correctly.
         wrapper.module.weight = weight
         return weight
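The switch from rebinding `module.scale` to `copy_` into the existing buffer is what makes the update visible outside a DataParallel forward: PyTorch guarantees that the replica on `device[0]` shares parameter and buffer storage with the base module, so in-place writes made during forward are recorded, while attribute rebinding dies with the replica (exactly the caveat the new comment spells out for `weight`). A runnable illustration of that guarantee (editor's sketch; assumes at least one CUDA device):

```python
import torch

class Counter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('hits_inplace', torch.zeros(1))
        self.register_buffer('hits_rebound', torch.zeros(1))

    def forward(self, x):
        self.hits_inplace.add_(1)                  # shared storage on device[0]: survives
        self.hits_rebound = self.hits_rebound + 1  # rebinding on a replica: lost
        return x

if torch.cuda.is_available():
    dp = torch.nn.DataParallel(Counter().cuda())
    dp(torch.ones(4, 1, device='cuda'))
    print(dp.module.hits_inplace.item())  # 1.0
    print(dp.module.hits_rebound.item())  # 0.0
```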
@@ -499,23 +505,30 @@ class QAT_Quantizer(Quantizer):
         config = wrapper.config
         module = wrapper.module
         input_bits = get_bits_length(config, 'input')
-        module.input_bits = torch.Tensor([input_bits])
+        module.input_bit = torch.tensor([input_bits])

         quant_start_step = config.get('quant_start_step', 0)
         assert input_bits >= 1, "quant bits length should be at least 1"

-        if quant_start_step > self.bound_model.steps:
-            module.tracked_min_input, module.tracked_max_input = torch.min(inputs), torch.max(inputs)
+        if quant_start_step > int(self.bound_model.steps):
+            current_min, current_max = torch.min(inputs), torch.max(inputs)
+            module.tracked_min_input.copy_(current_min)
+            module.tracked_max_input.copy_(current_max)
             return inputs

         # we dont update output quantization parameters in evaluation stage
         if wrapper.training:
             current_min, current_max = torch.min(inputs), torch.max(inputs)
-            module.tracked_min_input = update_ema(module.tracked_min_input, current_min, module.ema_decay)
-            module.tracked_max_input = update_ema(module.tracked_max_input, current_max, module.ema_decay)
-        module.scale, module.zero_point = update_quantization_param(
+            current_min = update_ema(module.tracked_min_input, current_min, module.ema_decay)
+            current_max = update_ema(module.tracked_max_input, current_max, module.ema_decay)
+            module.tracked_min_input.copy_(current_min)
+            module.tracked_max_input.copy_(current_max)
+        scale, zero_point = update_quantization_param(
             input_bits, module.tracked_min_input, module.tracked_max_input)
+        module.scale.copy_(scale)
+        module.zero_point.copy_(zero_point)

         inp = self._quantize(input_bits, module, inputs)
         inp = self._dequantize(module, inp)
         return inp
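The EMA update is now computed into locals and then copied into the tracked buffers, preserving buffer identity. For reference, `update_ema` is expected to be the standard exponential moving average along these lines (editor's sketch, not the NNI source):

```python
def update_ema(biased_ema, value, decay):
    # new_ema = decay * old_ema + (1 - decay) * current_value
    return biased_ema * decay + (1 - decay) * value
```

With `ema_decay` at 0.99 the tracked range reacts slowly to outlier batches, which is why the warm-up branch before `quant_start_step` seeds `tracked_min_input`/`tracked_max_input` directly from the current batch instead.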
@@ -528,19 +541,26 @@ class QAT_Quantizer(Quantizer):
         quant_start_step = config.get('quant_start_step', 0)
         assert output_bits >= 1, "quant bits length should be at least 1"

-        if quant_start_step > self.bound_model.steps:
-            module.tracked_min_output, module.tracked_max_output = torch.min(output), torch.max(output)
+        if quant_start_step > int(self.bound_model.steps):
+            current_min, current_max = torch.min(output), torch.max(output)
+            module.tracked_min_output.copy_(current_min)
+            module.tracked_max_output.copy_(current_max)
             return output

         # we dont update output quantization parameters in evaluation stage
         if wrapper.training:
             current_min, current_max = torch.min(output), torch.max(output)
-            module.tracked_min_output = update_ema(module.tracked_min_output, current_min, module.ema_decay)
-            module.tracked_max_output = update_ema(module.tracked_max_output, current_max, module.ema_decay)
-        module.scale, module.zero_point = update_quantization_param(
+            tracked_min_output = update_ema(module.tracked_min_output, current_min, module.ema_decay)
+            tracked_max_output = update_ema(module.tracked_max_output, current_max, module.ema_decay)
+            module.tracked_min_output.copy_(tracked_min_output)
+            module.tracked_max_output.copy_(tracked_max_output)
+        scale, zero_point = update_quantization_param(
             output_bits, module.tracked_min_output, module.tracked_max_output)
+        module.scale.copy_(scale)
+        module.zero_point.copy_(zero_point)

         out = self._quantize(output_bits, module, output)
         out = self._dequantize(module, out)
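`quantize_output` mirrors `quantize_input`. In both, `update_quantization_param` maps the tracked real range onto the integer grid; for asymmetric affine quantization it plausibly looks like the following (editor's sketch in the style of Jacob et al., not the NNI source):

```python
import torch

def update_quantization_param(bits, rmin, rmax):
    qmin, qmax = 0, (1 << bits) - 1
    # QAT requires zero to be exactly representable, so widen the range to include it
    rmin = torch.min(rmin, torch.zeros_like(rmin))
    rmax = torch.max(rmax, torch.zeros_like(rmax))
    scale = (rmax - rmin) / (qmax - qmin)
    zero_point = torch.clamp(torch.round(qmin - rmin / scale), qmin, qmax)
    return scale, zero_point
```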
@@ -645,7 +665,7 @@ class QAT_Quantizer(Quantizer):
         """
         override `compressor` `step` method, quantization only happens after certain number of steps
         """
-        self.bound_model.steps += 1
+        self.bound_model.steps.add_(1)


 class DoReFaQuantizer(Quantizer):
nni/compression/pytorch/compressor.py (view file @ 396ae65c)
@@ -602,6 +602,8 @@ class Quantizer(Compressor):
     """

     def __init__(self, model, config_list, optimizer=None, dummy_input=None):
+        if isinstance(model, torch.nn.DataParallel):
+            model = model.module
         self.identity_wrappers = []
         self.conv_bn_patterns = {}
         self.find_conv_bn_patterns(model, dummy_input)
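The companion change on the framework side: `Quantizer.__init__` now unwraps an incoming `torch.nn.DataParallel` so compression always targets the underlying module. In use (editor's sketch; the toy model and config are assumed):

```python
import torch
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer

net = torch.nn.Linear(16, 4)
dp_net = torch.nn.DataParallel(net)
config_list = [{'quant_types': ['weight'],
                'quant_bits': {'weight': 8},
                'op_types': ['Linear']}]

# Previously this required passing `dp_net.module` by hand; the guard above
# now accepts the wrapped model and compresses the underlying module directly.
quantizer = QAT_Quantizer(dp_net, config_list)
quantizer.compress()
```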
@@ -892,12 +894,21 @@ class QuantGrad(torch.autograd.Function):
             zero_point = wrapper.module.zero_point
         else:
             scale, zero_point = None, None
-        ctx.save_for_backward(tensor, torch.Tensor([quant_type]), scale, zero_point, qmin, qmax)
+        ctx.save_for_backward(tensor)
+        # Only tensors have gradients flowing back needs to be saved by save_for_backward.
+        # Others should directly assign to ctx.
+        ctx.scale = scale
+        ctx.zero_point = zero_point
+        ctx.quant_type = quant_type
+        ctx.qmin, ctx.qmax = qmin, qmax
         return output

     @classmethod
     def backward(cls, ctx, grad_output):
-        tensor, quant_type, scale, zero_point, qmin, qmax = ctx.saved_variables
+        tensor = ctx.saved_variables[0]
+        scale, zero_point = ctx.scale, ctx.zero_point
+        qmin, qmax = ctx.qmin, ctx.qmax
+        quant_type = ctx.quant_type
         output = cls.quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax)
         return output, None, None, None
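The `QuantGrad` hunk applies a general rule for custom autograd functions: `ctx.save_for_backward` is for tensors that gradients flow through (it participates in autograd's reference and version tracking), while constants and metadata such as `scale`, `zero_point`, `quant_type`, `qmin`, and `qmax` belong directly on `ctx`. It also avoids materializing throwaway tensors like `torch.Tensor([quant_type])`. The same pattern in a self-contained straight-through fake-quantization function (editor's sketch, not NNI code):

```python
import torch

class FakeQuant(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, scale, zero_point, qmin, qmax):
        ctx.save_for_backward(tensor)                  # gradient flows through `tensor`
        ctx.scale, ctx.zero_point = scale, zero_point  # metadata goes straight on ctx
        ctx.qmin, ctx.qmax = qmin, qmax
        q = torch.clamp(torch.round(tensor / scale) + zero_point, qmin, qmax)
        return (q - zero_point) * scale

    @staticmethod
    def backward(ctx, grad_output):
        tensor, = ctx.saved_tensors
        # straight-through estimator: pass gradient only where quantization did not clip
        q = torch.round(tensor / ctx.scale) + ctx.zero_point
        mask = ((q >= ctx.qmin) & (q <= ctx.qmax)).to(grad_output.dtype)
        return grad_output * mask, None, None, None, None

x = torch.randn(8, requires_grad=True)
y = FakeQuant.apply(x, torch.tensor(0.1), torch.tensor(0), 0, 255)
y.sum().backward()
```

(`ctx.saved_variables`, still used in the diff, is the deprecated spelling of `ctx.saved_tensors`.)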