Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
06e438b7
Unverified
Commit
06e438b7
authored
Feb 03, 2021
by
Ningxin Zheng
Committed by
GitHub
Feb 03, 2021
Browse files
support the scenario that there are duplicate tensors in a same tuple (#3340)
parent
7d6b8b3b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
19 deletions
+21
-19
nni/common/graph_utils.py
nni/common/graph_utils.py
+21
-19
No files found.
nni/common/graph_utils.py
View file @
06e438b7
...
...
@@ -285,8 +285,8 @@ class TorchModuleGraph(TorchGraph):
self
.
global_count
+=
1
op_type
=
node
.
kind
()
node_group
=
[
node
]
inputs
=
set
()
outputs
=
set
()
inputs
=
[]
outputs
=
[]
node_queue
=
queue
.
Queue
()
node_queue
.
put
(
node
)
while
not
node_queue
.
empty
():
...
...
@@ -303,17 +303,17 @@ class TorchModuleGraph(TorchGraph):
node_group
.
append
(
predecessor_node
)
node_queue
.
put
(
predecessor_node
)
else
:
inputs
.
a
d
d
(
input_name
)
inputs
.
a
ppen
d
(
input_name
)
else
:
inputs
.
a
d
d
(
input_name
)
inputs
.
a
ppen
d
(
input_name
)
else
:
inputs
.
a
d
d
(
input_name
)
inputs
.
a
ppen
d
(
input_name
)
for
output
in
node
.
outputs
():
if
output
.
node
().
kind
()
==
CONSTANT_KIND
:
continue
outputs
.
a
d
d
(
output
.
debugName
())
outputs
.
a
ppen
d
(
output
.
debugName
())
nodepy
=
NodePyGroup
(
node_name
,
unique_name
,
module_type
,
op_type
,
node_group
,
inputs
=
list
(
inputs
)
,
outputs
=
list
(
outputs
)
,
key_node
=
node
)
node_group
,
inputs
=
inputs
,
outputs
=
outputs
,
key_node
=
node
)
return
nodepy
def
_expand_module_node
(
self
,
node
,
node_name
,
unique_name
,
op_type
,
nodes
,
...
...
@@ -353,8 +353,8 @@ class TorchModuleGraph(TorchGraph):
if
not
op_type
:
op_type
=
node
.
kind
()
node_group
=
[
node
]
inputs
=
set
()
outputs
=
set
()
inputs
=
[]
outputs
=
[]
node_queue
=
queue
.
Queue
()
node_queue
.
put
(
node
)
visited
=
{
node
}
...
...
@@ -372,9 +372,9 @@ class TorchModuleGraph(TorchGraph):
node_queue
.
put
(
predecessor_node
)
visited
.
add
(
predecessor_node
)
else
:
inputs
.
a
d
d
(
input_name
)
inputs
.
a
ppen
d
(
input_name
)
else
:
inputs
.
a
d
d
(
input_name
)
inputs
.
a
ppen
d
(
input_name
)
for
_output
in
curr_node
.
outputs
():
if
_output
.
node
().
kind
()
==
CONSTANT_KIND
:
continue
...
...
@@ -387,9 +387,9 @@ class TorchModuleGraph(TorchGraph):
node_queue
.
put
(
successor_node
)
visited
.
add
(
successor_node
)
else
:
outputs
.
a
d
d
(
output_name
)
outputs
.
a
ppen
d
(
output_name
)
else
:
outputs
.
a
d
d
(
output_name
)
outputs
.
a
ppen
d
(
output_name
)
nodepy
=
NodePyGroup
(
node_name
,
unique_name
,
module_type
,
op_type
,
node_group
,
inputs
=
list
(
inputs
),
outputs
=
list
(
outputs
))
...
...
@@ -562,10 +562,13 @@ class TorchModuleGraph(TorchGraph):
for
node
in
nodes_op
:
name_to_node
[
node
.
unique_name
]
=
node
for
_input
in
node
.
inputs
:
input_to_node
[
_input
].
append
(
node
)
# inputs may have duplicate tensors
if
node
not
in
input_to_node
[
_input
]:
input_to_node
[
_input
].
append
(
node
)
for
output
in
node
.
outputs
:
assert
not
output
in
output_to_node
,
\
"One output cannot be generated by multiple nodes %s"
%
output
if
output
in
output_to_node
:
assert
output_to_node
[
output
]
==
node
,
\
"One output cannot be generated by multiple nodes %s"
%
output
output_to_node
[
output
]
=
node
return
name_to_node
,
input_to_node
,
output_to_node
...
...
@@ -619,8 +622,6 @@ class TorchModuleGraph(TorchGraph):
assert
len
(
node
.
inputs
)
==
len
(
list
(
last_cpp
.
inputs
())),
errmsg
for
_debug_input
,
_debug_output
in
zip
(
node
.
inputs
,
node
.
outputs
):
# _debug_input = _input.debugName()
# _debug_output = _output.debugName()
if
_debug_input
in
self
.
input_to_node
and
_debug_output
in
self
.
input_to_node
:
# input_to_node[_debug_input] is a list of NodePyGroup, because
# one tensor can be used as input for multiple nodes at the same time.
...
...
@@ -629,7 +630,8 @@ class TorchModuleGraph(TorchGraph):
# will be merged into the same NodePyGroup, so we remove the `node` from
# input_to_node[_debug_input] and directly connect this tensor to the
# input_to_node[_debug_output]
self
.
input_to_node
[
_debug_input
].
remove
(
node
)
if
node
in
self
.
input_to_node
[
_debug_input
]:
self
.
input_to_node
[
_debug_input
].
remove
(
node
)
# add the following nodes of _output into the input_to_node[_debug_input]
self
.
input_to_node
[
_debug_input
].
extend
(
self
.
input_to_node
[
_debug_output
])
# just remove the _debug_output from the graph index. So that we can also skip
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment