Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
68818a33
Unverified
Commit
68818a33
authored
Jul 25, 2021
by
Bill Wu
Committed by
GitHub
Jul 26, 2021
Browse files
[Model Compression] Expand export_model arguments: dummy input and onnx opset_version (#3968)
parent
deef0c42
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
36 additions
and
10 deletions
+36
-10
nni/compression/pytorch/compressor.py
nni/compression/pytorch/compressor.py
+36
-10
No files found.
nni/compression/pytorch/compressor.py
View file @
68818a33
...
@@ -375,7 +375,8 @@ class Pruner(Compressor):
...
@@ -375,7 +375,8 @@ class Pruner(Compressor):
wrapper
.
to
(
layer
.
module
.
weight
.
device
)
wrapper
.
to
(
layer
.
module
.
weight
.
device
)
return
wrapper
return
wrapper
def
export_model
(
self
,
model_path
,
mask_path
=
None
,
onnx_path
=
None
,
input_shape
=
None
,
device
=
None
):
def
export_model
(
self
,
model_path
,
mask_path
=
None
,
onnx_path
=
None
,
input_shape
=
None
,
device
=
None
,
dummy_input
=
None
,
opset_version
=
None
):
"""
"""
Export pruned model weights, masks and onnx model(optional)
Export pruned model weights, masks and onnx model(optional)
...
@@ -388,10 +389,21 @@ class Pruner(Compressor):
...
@@ -388,10 +389,21 @@ class Pruner(Compressor):
onnx_path : str
onnx_path : str
(optional) path to save onnx model
(optional) path to save onnx model
input_shape : list or tuple
input_shape : list or tuple
input shape to onnx model
input shape to onnx model, used for creating a dummy input tensor for torch.onnx.export
if the input has a complex structure (e.g., a tuple), please directly create the input and
pass it to dummy_input instead
note: this argument is deprecated and will be removed; please use dummy_input instead
device : torch.device
device : torch.device
device of the model,
used
to place the dummy input tensor for exporting onnx file
.
device of the model,
where
to place the dummy input tensor for exporting onnx file
;
the tensor is placed on cpu if ```device``` is None
the tensor is placed on cpu if ```device``` is None
only useful when both onnx_path and input_shape are passed
note: this argument is deprecated and will be removed; please use dummy_input instead
dummy_input: torch.Tensor or tuple
dummy input to the onnx model; used when input_shape is not enough to specify dummy input
user should ensure that the dummy_input is on the same device as the model
opset_version: int
opset_version parameter for torch.onnx.export; only useful when onnx_path is not None
if not passed, torch.onnx.export will use its default opset_version
"""
"""
assert
model_path
is
not
None
,
'model_path must be specified'
assert
model_path
is
not
None
,
'model_path must be specified'
mask_dict
=
{}
mask_dict
=
{}
...
@@ -412,17 +424,31 @@ class Pruner(Compressor):
...
@@ -412,17 +424,31 @@ class Pruner(Compressor):
torch
.
save
(
self
.
bound_model
.
state_dict
(),
model_path
)
torch
.
save
(
self
.
bound_model
.
state_dict
(),
model_path
)
_logger
.
info
(
'Model state_dict saved to %s'
,
model_path
)
_logger
.
info
(
'Model state_dict saved to %s'
,
model_path
)
if
mask_path
is
not
None
:
if
mask_path
is
not
None
:
torch
.
save
(
mask_dict
,
mask_path
)
torch
.
save
(
mask_dict
,
mask_path
)
_logger
.
info
(
'Mask dict saved to %s'
,
mask_path
)
_logger
.
info
(
'Mask dict saved to %s'
,
mask_path
)
if
onnx_path
is
not
None
:
if
onnx_path
is
not
None
:
assert
input_shape
is
not
None
,
'input_shape must be specified to export onnx model'
assert
input_shape
is
not
None
or
dummy_input
is
not
None
,
\
# input info needed
'input_shape or dummy_input must be specified to export onnx model'
# create dummy_input using input_shape if input_shape is not passed
if
dummy_input
is
None
:
_logger
.
warning
(
"""The argument input_shape and device will be removed in the future.
Please create a dummy input and pass it to dummy_input instead."""
)
if
device
is
None
:
if
device
is
None
:
device
=
torch
.
device
(
'cpu'
)
device
=
torch
.
device
(
'cpu'
)
input_data
=
torch
.
Tensor
(
*
input_shape
)
input_data
=
torch
.
Tensor
(
*
input_shape
).
to
(
device
)
torch
.
onnx
.
export
(
self
.
bound_model
,
input_data
.
to
(
device
),
onnx_path
)
else
:
input_data
=
dummy_input
if
opset_version
is
not
None
:
torch
.
onnx
.
export
(
self
.
bound_model
,
input_data
,
onnx_path
,
opset_version
=
opset_version
)
else
:
torch
.
onnx
.
export
(
self
.
bound_model
,
input_data
,
onnx_path
)
if
dummy_input
is
None
:
_logger
.
info
(
'Model in onnx with input shape %s saved to %s'
,
input_data
.
shape
,
onnx_path
)
_logger
.
info
(
'Model in onnx with input shape %s saved to %s'
,
input_data
.
shape
,
onnx_path
)
else
:
_logger
.
info
(
'Model in onnx saved to %s'
,
onnx_path
)
self
.
_wrap_model
()
self
.
_wrap_model
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment