Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
68818a33
Unverified
Commit
68818a33
authored
Jul 25, 2021
by
Bill Wu
Committed by
GitHub
Jul 26, 2021
Browse files
[Model Compression] Expand export_model arguments: dummy input and onnx opset_version (#3968)
parent
deef0c42
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
36 additions
and
10 deletions
+36
-10
nni/compression/pytorch/compressor.py
nni/compression/pytorch/compressor.py
+36
-10
No files found.
nni/compression/pytorch/compressor.py
View file @
68818a33
...
@@ -375,7 +375,8 @@ class Pruner(Compressor):
...
@@ -375,7 +375,8 @@ class Pruner(Compressor):
wrapper
.
to
(
layer
.
module
.
weight
.
device
)
wrapper
.
to
(
layer
.
module
.
weight
.
device
)
return
wrapper
return
wrapper
def
export_model
(
self
,
model_path
,
mask_path
=
None
,
onnx_path
=
None
,
input_shape
=
None
,
device
=
None
):
def
export_model
(
self
,
model_path
,
mask_path
=
None
,
onnx_path
=
None
,
input_shape
=
None
,
device
=
None
,
dummy_input
=
None
,
opset_version
=
None
):
"""
"""
Export pruned model weights, masks and onnx model(optional)
Export pruned model weights, masks and onnx model(optional)
...
@@ -388,10 +389,21 @@ class Pruner(Compressor):
...
@@ -388,10 +389,21 @@ class Pruner(Compressor):
onnx_path : str
onnx_path : str
(optional) path to save onnx model
(optional) path to save onnx model
input_shape : list or tuple
input_shape : list or tuple
input shape to onnx model
input shape to onnx model, used for creating a dummy input tensor for torch.onnx.export
if the input has a complex structure (e.g., a tuple), please directly create the input and
pass it to dummy_input instead
note: this argument is deprecated and will be removed; please use dummy_input instead
device : torch.device
device : torch.device
device of the model,
used
to place the dummy input tensor for exporting onnx file
.
device of the model,
where
to place the dummy input tensor for exporting onnx file
;
the tensor is placed on cpu if ```device``` is None
the tensor is placed on cpu if ```device``` is None
only useful when both onnx_path and input_shape are passed
note: this argument is deprecated and will be removed; please use dummy_input instead
dummy_input: torch.Tensor or tuple
dummy input to the onnx model; used when input_shape is not enough to specify dummy input
user should ensure that the dummy_input is on the same device as the model
opset_version: int
opset_version parameter for torch.onnx.export; only useful when onnx_path is not None
if not passed, torch.onnx.export will use its default opset_version
"""
"""
assert
model_path
is
not
None
,
'model_path must be specified'
assert
model_path
is
not
None
,
'model_path must be specified'
mask_dict
=
{}
mask_dict
=
{}
...
@@ -412,17 +424,31 @@ class Pruner(Compressor):
...
@@ -412,17 +424,31 @@ class Pruner(Compressor):
torch
.
save
(
self
.
bound_model
.
state_dict
(),
model_path
)
torch
.
save
(
self
.
bound_model
.
state_dict
(),
model_path
)
_logger
.
info
(
'Model state_dict saved to %s'
,
model_path
)
_logger
.
info
(
'Model state_dict saved to %s'
,
model_path
)
if
mask_path
is
not
None
:
if
mask_path
is
not
None
:
torch
.
save
(
mask_dict
,
mask_path
)
torch
.
save
(
mask_dict
,
mask_path
)
_logger
.
info
(
'Mask dict saved to %s'
,
mask_path
)
_logger
.
info
(
'Mask dict saved to %s'
,
mask_path
)
if
onnx_path
is
not
None
:
if
onnx_path
is
not
None
:
assert
input_shape
is
not
None
,
'input_shape must be specified to export onnx model'
assert
input_shape
is
not
None
or
dummy_input
is
not
None
,
\
# input info needed
'input_shape or dummy_input must be specified to export onnx model'
if
device
is
None
:
# create dummy_input using input_shape if input_shape is not passed
device
=
torch
.
device
(
'cpu'
)
if
dummy_input
is
None
:
input_data
=
torch
.
Tensor
(
*
input_shape
)
_logger
.
warning
(
"""The argument input_shape and device will be removed in the future.
torch
.
onnx
.
export
(
self
.
bound_model
,
input_data
.
to
(
device
),
onnx_path
)
Please create a dummy input and pass it to dummy_input instead."""
)
_logger
.
info
(
'Model in onnx with input shape %s saved to %s'
,
input_data
.
shape
,
onnx_path
)
if
device
is
None
:
device
=
torch
.
device
(
'cpu'
)
input_data
=
torch
.
Tensor
(
*
input_shape
).
to
(
device
)
else
:
input_data
=
dummy_input
if
opset_version
is
not
None
:
torch
.
onnx
.
export
(
self
.
bound_model
,
input_data
,
onnx_path
,
opset_version
=
opset_version
)
else
:
torch
.
onnx
.
export
(
self
.
bound_model
,
input_data
,
onnx_path
)
if
dummy_input
is
None
:
_logger
.
info
(
'Model in onnx with input shape %s saved to %s'
,
input_data
.
shape
,
onnx_path
)
else
:
_logger
.
info
(
'Model in onnx saved to %s'
,
onnx_path
)
self
.
_wrap_model
()
self
.
_wrap_model
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment