OpenDAS / nni, commit f51d985b (unverified)

Add model export for QAT (#3458)

Authored Mar 24, 2021 by lin bin; committed via GitHub on Mar 24, 2021.
Parent: 4635b559
Changes: 3 changed files, with 266 additions and 2 deletions (+266 / -2).
Files changed:
- nni/algorithms/compression/pytorch/quantization/quantizers.py (+156 / -1)
- nni/compression/pytorch/compressor.py (+61 / -1)
- test/ut/sdk/test_compressor_torch.py (+49 / -0)
nni/algorithms/compression/pytorch/quantization/quantizers.py (view file @ f51d985b)
@@ -110,7 +110,6 @@ def get_bits_length(config, quant_type):

    else:
        return config["quant_bits"].get(quant_type)

class QATGrad(QuantGrad):
    @staticmethod
    def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax):
@@ -153,13 +152,26 @@ class QAT_Quantizer(Quantizer):

        for layer, config in modules_to_compress:
            layer.module.register_buffer("zero_point", torch.Tensor([0.0]))
            layer.module.register_buffer("scale", torch.Tensor([1.0]))
            if "weight" in config.get("quant_types", []):
                layer.module.register_buffer('weight_bit', torch.zeros(1))
            if "output" in config.get("quant_types", []):
                layer.module.register_buffer('activation_bit', torch.zeros(1))
                layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
                layer.module.register_buffer('tracked_min_biased', torch.zeros(1))
                layer.module.register_buffer('tracked_min', torch.zeros(1))
                layer.module.register_buffer('tracked_max_biased', torch.zeros(1))
                layer.module.register_buffer('tracked_max', torch.zeros(1))

    def _del_simulated_attr(self, module):
        """
        delete redundant parameters in quantize module
        """
        del_attr_list = ['old_weight', 'ema_decay', 'tracked_min_biased', 'tracked_max_biased', 'tracked_min', \
                         'tracked_max', 'scale', 'zero_point', 'weight_bit', 'activation_bit']
        for attr in del_attr_list:
            if hasattr(module, attr):
                delattr(module, attr)

    def validate_config(self, model, config_list):
        """
        Parameters
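A note on the mechanism here: these statistics are created with register_buffer rather than as plain Python attributes, so they are included in the module's state_dict and follow the module across devices, which is what lets the new export_model further down read them back after training. A minimal sketch of that behavior (the Toy module is illustrative, not part of NNI):

import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # a buffer is saved in state_dict() but is not a trainable parameter
        self.register_buffer('scale', torch.Tensor([1.0]))

m = Toy()
assert 'scale' in m.state_dict()        # exported together with the weights
assert len(list(m.parameters())) == 0   # not visible to the optimizer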
@@ -256,6 +268,7 @@ class QAT_Quantizer(Quantizer):

        module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax)
        weight = self._quantize(weight_bits, module, weight)
        weight = self._dequantize(module, weight)
        module.weight_bit = torch.Tensor([weight_bits])
        wrapper.module.weight = weight
        return weight
@@ -263,6 +276,7 @@ class QAT_Quantizer(Quantizer):

        config = wrapper.config
        module = wrapper.module
        output_bits = get_bits_length(config, 'output')
        module.activation_bit = torch.Tensor([output_bits])
        quant_start_step = config.get('quant_start_step', 0)
        assert output_bits >= 1, "quant bits length should be at least 1"
@@ -282,6 +296,47 @@ class QAT_Quantizer(Quantizer):

        out = self._dequantize(module, out)
        return out

    def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
        """
        Export quantized model weights and calibration parameters (optional)

        Parameters
        ----------
        model_path : str
            path to save quantized model weight
        calibration_path : str
            (optional) path to save quantize parameters after calibration
        onnx_path : str
            (optional) path to save onnx model
        input_shape : list or tuple
            input shape to onnx model
        device : torch.device
            device of the model, used to place the dummy input tensor for exporting onnx file.
            the tensor is placed on cpu if ``device`` is None

        Returns
        -------
        Dict
        """
        assert model_path is not None, 'model_path must be specified'
        self._unwrap_model()
        calibration_config = {}

        for name, module in self.bound_model.named_modules():
            if hasattr(module, 'weight_bit') or hasattr(module, 'activation_bit'):
                calibration_config[name] = {}
            if hasattr(module, 'weight_bit'):
                calibration_config[name]['weight_bit'] = int(module.weight_bit)
            if hasattr(module, 'activation_bit'):
                calibration_config[name]['activation_bit'] = int(module.activation_bit)
                calibration_config[name]['tracked_min'] = float(module.tracked_min_biased)
                calibration_config[name]['tracked_max'] = float(module.tracked_max_biased)
            self._del_simulated_attr(module)

        self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, input_shape, device)

        return calibration_config

    def fold_bn(self, config, **kwargs):
        # TODO simulate folded weight
        pass
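With the buffers above populated during training, the new export path can be driven like this; a minimal sketch, assuming `quantizer` is a QAT_Quantizer whose compress() has already run, with illustrative file names:

# after quantizer.compress() and some training steps
calibration_config = quantizer.export_model(
    model_path="qat_model.pth",              # quantized weights (state_dict)
    calibration_path="qat_calibration.pth",  # optional: per-layer calibration parameters
    onnx_path="qat_model.onnx",              # optional: ONNX export
    input_shape=(1, 1, 28, 28),
    device=torch.device("cpu"))

# the returned dict maps layer names to their recorded settings, e.g.
# {'conv1': {'weight_bit': 8},
#  'relu':  {'activation_bit': 8, 'tracked_min': -0.1, 'tracked_max': 0.9}}

The layer names and the tracked_min/tracked_max values in the comment are hypothetical; they depend on the model and on the data seen during calibration.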
@@ -301,6 +356,19 @@ class DoReFaQuantizer(Quantizer):

    def __init__(self, model, config_list, optimizer=None):
        super().__init__(model, config_list, optimizer)
        modules_to_compress = self.get_modules_to_compress()
        for layer, config in modules_to_compress:
            if "weight" in config.get("quant_types", []):
                layer.module.register_buffer('weight_bit', torch.zeros(1))

    def _del_simulated_attr(self, module):
        """
        delete redundant parameters in quantize module
        """
        del_attr_list = ['old_weight', 'weight_bit']
        for attr in del_attr_list:
            if hasattr(module, attr):
                delattr(module, attr)

    def validate_config(self, model, config_list):
        """
@@ -330,6 +398,7 @@ class DoReFaQuantizer(Quantizer):

        weight = self.quantize(weight, weight_bits)
        weight = 2 * weight - 1
        wrapper.module.weight = weight
        wrapper.module.weight_bit = torch.Tensor([weight_bits])
        # wrapper.module.weight.data = weight
        return weight
@@ -338,6 +407,42 @@ class DoReFaQuantizer(Quantizer):

        output = torch.round(input_ri * scale) / scale
        return output

    def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
        """
        Export quantized model weights and calibration parameters (optional)

        Parameters
        ----------
        model_path : str
            path to save quantized model weight
        calibration_path : str
            (optional) path to save quantize parameters after calibration
        onnx_path : str
            (optional) path to save onnx model
        input_shape : list or tuple
            input shape to onnx model
        device : torch.device
            device of the model, used to place the dummy input tensor for exporting onnx file.
            the tensor is placed on cpu if ``device`` is None

        Returns
        -------
        Dict
        """
        assert model_path is not None, 'model_path must be specified'
        self._unwrap_model()
        calibration_config = {}

        for name, module in self.bound_model.named_modules():
            if hasattr(module, 'weight_bit'):
                calibration_config[name] = {}
                calibration_config[name]['weight_bit'] = int(module.weight_bit)
            self._del_simulated_attr(module)

        self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, input_shape, device)

        return calibration_config

class ClipGrad(QuantGrad):
    @staticmethod
@@ -356,6 +461,19 @@ class BNNQuantizer(Quantizer):

    def __init__(self, model, config_list, optimizer=None):
        super().__init__(model, config_list, optimizer)
        self.quant_grad = ClipGrad
        modules_to_compress = self.get_modules_to_compress()
        for layer, config in modules_to_compress:
            if "weight" in config.get("quant_types", []):
                layer.module.register_buffer('weight_bit', torch.zeros(1))

    def _del_simulated_attr(self, module):
        """
        delete redundant parameters in quantize module
        """
        del_attr_list = ['old_weight', 'weight_bit']
        for attr in del_attr_list:
            if hasattr(module, attr):
                delattr(module, attr)

    def validate_config(self, model, config_list):
        """
@@ -384,6 +502,7 @@ class BNNQuantizer(Quantizer):

        # remove zeros
        weight[weight == 0] = 1
        wrapper.module.weight = weight
        wrapper.module.weight_bit = torch.Tensor([1.0])
        return weight

    def quantize_output(self, output, wrapper, **kwargs):
@@ -391,3 +510,39 @@ class BNNQuantizer(Quantizer):

        # remove zeros
        out[out == 0] = 1
        return out

    def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
        """
        Export quantized model weights and calibration parameters (optional)

        Parameters
        ----------
        model_path : str
            path to save quantized model weight
        calibration_path : str
            (optional) path to save quantize parameters after calibration
        onnx_path : str
            (optional) path to save onnx model
        input_shape : list or tuple
            input shape to onnx model
        device : torch.device
            device of the model, used to place the dummy input tensor for exporting onnx file.
            the tensor is placed on cpu if ``device`` is None

        Returns
        -------
        Dict
        """
        assert model_path is not None, 'model_path must be specified'
        self._unwrap_model()
        calibration_config = {}

        for name, module in self.bound_model.named_modules():
            if hasattr(module, 'weight_bit'):
                calibration_config[name] = {}
                calibration_config[name]['weight_bit'] = int(module.weight_bit)
            self._del_simulated_attr(module)

        self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, input_shape, device)

        return calibration_config
\ No newline at end of file
nni/compression/pytorch/compressor.py (view file @ f51d985b)
@@ -21,7 +21,6 @@ def _setattr(model, name, module):

        model = getattr(model, name)
    setattr(model, name_list[-1], module)

class Compressor:
    """
    Abstract base PyTorch compressor
@@ -573,6 +572,67 @@ class Quantizer(Compressor):

        return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self)

    def export_model_save(self, model, model_path, calibration_config=None, calibration_path=None, onnx_path=None, \
        input_shape=None, device=None):
        """
        This method helps save pytorch model, calibration config, onnx model in quantizer.

        Parameters
        ----------
        model : pytorch model
            pytorch model to be saved
        model_path : str
            path to save pytorch model
        calibration_config : dict
            (optional) config of calibration parameters
        calibration_path : str
            (optional) path to save quantize parameters after calibration
        onnx_path : str
            (optional) path to save onnx model
        input_shape : list or tuple
            input shape to onnx model
        device : torch.device
            device of the model, used to place the dummy input tensor for exporting onnx file.
            the tensor is placed on cpu if ``device`` is None
        """
        torch.save(model.state_dict(), model_path)
        _logger.info('Model state_dict saved to %s', model_path)
        if calibration_path is not None:
            torch.save(calibration_config, calibration_path)
            _logger.info('Mask dict saved to %s', calibration_path)
        if onnx_path is not None:
            assert input_shape is not None, 'input_shape must be specified to export onnx model'
            # input info needed
            if device is None:
                device = torch.device('cpu')
            input_data = torch.Tensor(*input_shape)
            torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
            _logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)

    def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
        """
        Export quantized model weights and calibration parameters

        Parameters
        ----------
        model_path : str
            path to save quantized model weight
        calibration_path : str
            (optional) path to save quantize parameters after calibration
        onnx_path : str
            (optional) path to save onnx model
        input_shape : list or tuple
            input shape to onnx model
        device : torch.device
            device of the model, used to place the dummy input tensor for exporting onnx file.
            the tensor is placed on cpu if ``device`` is None

        Returns
        -------
        Dict
        """
        raise NotImplementedError('Quantizer must overload export_model()')

    def step_with_optimizer(self):
        pass
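For completeness, everything export_model_save writes goes through torch.save, so the artifacts can be read back with torch.load; a short sketch, reusing the illustrative paths from the earlier example:

import torch

state_dict = torch.load("qat_model.pth")          # simulated-quantized weights
model.load_state_dict(state_dict)                 # `model` is the original, unwrapped network
calibration = torch.load("qat_calibration.pth")   # the same dict that export_model returned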
test/ut/sdk/test_compressor_torch.py (view file @ f51d985b)
@@ -274,6 +274,55 @@ class CompressorTestCase(TestCase):

        assert math.isclose(model.relu.module.tracked_min_biased, 0.002, abs_tol=eps)
        assert math.isclose(model.relu.module.tracked_max_biased, 0.00998, abs_tol=eps)

    def test_torch_quantizer_export(self):
        config_list_qat = [{
            'quant_types': ['weight'],
            'quant_bits': 8,
            'op_types': ['Conv2d', 'Linear']
        }, {
            'quant_types': ['output'],
            'quant_bits': 8,
            'quant_start_step': 0,
            'op_types': ['ReLU']
        }]
        config_list_dorefa = [{
            'quant_types': ['weight'],
            'quant_bits': {
                'weight': 8,
            },  # a plain `int` would also work here, since all `quant_types` share the same bit length
            'op_types': ['Conv2d', 'Linear']
        }]
        config_list_bnn = [{
            'quant_types': ['weight'],
            'quant_bits': 1,
            'op_types': ['Conv2d', 'Linear']
        }, {
            'quant_types': ['output'],
            'quant_bits': 1,
            'op_types': ['ReLU']
        }]

        config_set = [config_list_qat, config_list_dorefa, config_list_bnn]
        quantize_algorithm_set = [torch_quantizer.QAT_Quantizer, torch_quantizer.DoReFaQuantizer, torch_quantizer.BNNQuantizer]
        for config, quantize_algorithm in zip(config_set, quantize_algorithm_set):
            model = TorchModel()
            model.relu = torch.nn.ReLU()
            quantizer = quantize_algorithm(model, config)
            quantizer.compress()

            x = torch.rand((1, 1, 28, 28), requires_grad=True)
            y = model(x)
            y.backward(torch.ones_like(y))

            model_path = "test_model.pth"
            calibration_path = "test_calibration.pth"
            onnx_path = "test_model.onnx"
            input_shape = (1, 1, 28, 28)
            device = torch.device("cpu")
            calibration_config = quantizer.export_model(model_path, calibration_path, onnx_path, input_shape, device)
            assert calibration_config is not None

    def test_torch_pruner_validation(self):
        # test bad configuration
        pruner_classes = [torch_pruner.__dict__[x] for x in \