Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
7ee5036b
Unverified
Commit
7ee5036b
authored
Jun 11, 2020
by
Ningxin Zheng
Committed by
GitHub
Jun 11, 2020
Browse files
Bugfix issue2485 (#2524)
parent
e1e1977c
Changes
2
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
222 additions
and
72 deletions
+222
-72
src/sdk/pynni/nni/_graph_utils.py
src/sdk/pynni/nni/_graph_utils.py
+181
-71
src/sdk/pynni/tests/test_graph_utils.py
src/sdk/pynni/tests/test_graph_utils.py
+41
-1
No files found.
src/sdk/pynni/nni/_graph_utils.py
View file @
7ee5036b
This diff is collapsed.
Click to expand it.
src/sdk/pynni/tests/test_graph_utils.py
View file @
7ee5036b
...
...
@@ -15,7 +15,7 @@ from google.protobuf import text_format
import
unittest
from
unittest
import
TestCase
,
main
from
nni._graph_utils
import
build_module_graph
,
build_graph
from
nni._graph_utils
import
build_module_graph
,
build_graph
,
TorchModuleGraph
class
BackboneModel1
(
nn
.
Module
):
def
__init__
(
self
):
...
...
@@ -154,5 +154,45 @@ class GraphUtilsTestCase(TestCase):
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"expect"
,
"test_graph_module3.expect"
)
)
@
unittest
.
skipIf
(
torch
.
__version__
<
"1.4.0"
,
"not supported"
)
def
test_module_reuse
(
self
):
class
MyModule
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
liner1
=
nn
.
Linear
(
10
,
10
)
self
.
relu
=
nn
.
ReLU
(
inplace
=
True
)
self
.
liner2
=
nn
.
Linear
(
10
,
20
)
self
.
liner3
=
nn
.
Linear
(
20
,
10
)
def
forward
(
self
,
x
):
x
=
self
.
liner1
(
x
)
x
=
self
.
relu
(
x
)
x
=
self
.
liner2
(
x
)
x
=
self
.
relu
(
x
)
x
=
self
.
liner3
(
x
)
x
=
self
.
relu
(
x
)
return
x
data
=
torch
.
rand
(
10
,
10
)
net
=
MyModule
()
traced
=
torch
.
jit
.
trace
(
net
,
data
)
modulegraph
=
TorchModuleGraph
(
traced_model
=
traced
)
# Traverse the TorchModuleGraph, due the resue of the relu module,
# there will be three cpp_nodes corrspoding to the same module.
# During traversing the graph, there should be only one
# successor of each cpp-node (including the cpp_nodes that corresponds
# to the same relu module).
for
name
,
nodeio
in
modulegraph
.
nodes_py
.
nodes_io
.
items
():
if
nodeio
.
input_or_output
==
'input'
:
# Find the first node of the whole graph
start_nodes
=
modulegraph
.
input_to_node
[
name
]
# We have only one single path top-down
assert
len
(
start_nodes
)
==
1
node
=
start_nodes
[
0
].
unique_name
while
modulegraph
.
find_successors
(
node
):
nodes
=
modulegraph
.
find_successors
(
node
)
assert
len
(
nodes
)
==
1
node
=
nodes
[
0
]
if
__name__
==
'__main__'
:
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment