Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
5dc80762
Unverified
Commit
5dc80762
authored
May 13, 2022
by
J-shang
Committed by
GitHub
May 13, 2022
Browse files
[Compression] pruning speedup support RecursiveScriptModule (#4801)
* support RecursiveScriptModule
parent
9644cf69
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
2 deletions
+2
-2
nni/common/graph_utils.py
nni/common/graph_utils.py
+2
-2
No files found.
nni/common/graph_utils.py
View file @
5dc80762
...
@@ -57,7 +57,7 @@ class TorchGraph:
...
@@ -57,7 +57,7 @@ class TorchGraph:
assert
torch
.
__version__
>=
'1.3.1'
assert
torch
.
__version__
>=
'1.3.1'
# check if the input is legal
# check if the input is legal
if
traced_model
is
not
None
:
if
traced_model
is
not
None
:
assert
isinstance
(
traced_model
,
torch
.
jit
.
TopLevelTracedModule
)
assert
isinstance
(
traced_model
,
torch
.
jit
.
TopLevelTracedModule
)
or
isinstance
(
traced_model
,
torch
.
jit
.
RecursiveScriptModule
)
self
.
trace
=
traced_model
self
.
trace
=
traced_model
# it's ok if the graph is already unpacked
# it's ok if the graph is already unpacked
torch
.
_C
.
_jit_pass_inline
(
self
.
trace
.
graph
)
torch
.
_C
.
_jit_pass_inline
(
self
.
trace
.
graph
)
...
@@ -709,7 +709,7 @@ class TorchModuleGraph(TorchGraph):
...
@@ -709,7 +709,7 @@ class TorchModuleGraph(TorchGraph):
self
.
leaf_modules
=
self
.
_extract_leaf_modules
()
self
.
leaf_modules
=
self
.
_extract_leaf_modules
()
module_to_type
=
{
name
:
parse_traced_name
(
module_to_type
=
{
name
:
parse_traced_name
(
module
.
_name
)
for
name
,
module
in
self
.
trace
.
named_modules
()}
module
.
_name
if
hasattr
(
module
,
'_name'
)
else
module
.
original_name
)
for
name
,
module
in
self
.
trace
.
named_modules
()}
# associate module name with their trace graph nodes
# associate module name with their trace graph nodes
for
node
in
graph
.
nodes
():
for
node
in
graph
.
nodes
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment