Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
06e438b7
Unverified
Commit
06e438b7
authored
Feb 03, 2021
by
Ningxin Zheng
Committed by
GitHub
Feb 03, 2021
Browse files
support the scenario that there are duplicate tensors in a same tuple (#3340)
parent
7d6b8b3b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
19 deletions
+21
-19
nni/common/graph_utils.py
nni/common/graph_utils.py
+21
-19
No files found.
nni/common/graph_utils.py
View file @
06e438b7
...
@@ -285,8 +285,8 @@ class TorchModuleGraph(TorchGraph):
...
@@ -285,8 +285,8 @@ class TorchModuleGraph(TorchGraph):
self
.
global_count
+=
1
self
.
global_count
+=
1
op_type
=
node
.
kind
()
op_type
=
node
.
kind
()
node_group
=
[
node
]
node_group
=
[
node
]
inputs
=
set
()
inputs
=
[]
outputs
=
set
()
outputs
=
[]
node_queue
=
queue
.
Queue
()
node_queue
=
queue
.
Queue
()
node_queue
.
put
(
node
)
node_queue
.
put
(
node
)
while
not
node_queue
.
empty
():
while
not
node_queue
.
empty
():
...
@@ -303,17 +303,17 @@ class TorchModuleGraph(TorchGraph):
...
@@ -303,17 +303,17 @@ class TorchModuleGraph(TorchGraph):
node_group
.
append
(
predecessor_node
)
node_group
.
append
(
predecessor_node
)
node_queue
.
put
(
predecessor_node
)
node_queue
.
put
(
predecessor_node
)
else
:
else
:
inputs
.
a
d
d
(
input_name
)
inputs
.
a
ppen
d
(
input_name
)
else
:
else
:
inputs
.
a
d
d
(
input_name
)
inputs
.
a
ppen
d
(
input_name
)
else
:
else
:
inputs
.
a
d
d
(
input_name
)
inputs
.
a
ppen
d
(
input_name
)
for
output
in
node
.
outputs
():
for
output
in
node
.
outputs
():
if
output
.
node
().
kind
()
==
CONSTANT_KIND
:
if
output
.
node
().
kind
()
==
CONSTANT_KIND
:
continue
continue
outputs
.
a
d
d
(
output
.
debugName
())
outputs
.
a
ppen
d
(
output
.
debugName
())
nodepy
=
NodePyGroup
(
node_name
,
unique_name
,
module_type
,
op_type
,
nodepy
=
NodePyGroup
(
node_name
,
unique_name
,
module_type
,
op_type
,
node_group
,
inputs
=
list
(
inputs
)
,
outputs
=
list
(
outputs
)
,
key_node
=
node
)
node_group
,
inputs
=
inputs
,
outputs
=
outputs
,
key_node
=
node
)
return
nodepy
return
nodepy
def
_expand_module_node
(
self
,
node
,
node_name
,
unique_name
,
op_type
,
nodes
,
def
_expand_module_node
(
self
,
node
,
node_name
,
unique_name
,
op_type
,
nodes
,
...
@@ -353,8 +353,8 @@ class TorchModuleGraph(TorchGraph):
...
@@ -353,8 +353,8 @@ class TorchModuleGraph(TorchGraph):
if
not
op_type
:
if
not
op_type
:
op_type
=
node
.
kind
()
op_type
=
node
.
kind
()
node_group
=
[
node
]
node_group
=
[
node
]
inputs
=
set
()
inputs
=
[]
outputs
=
set
()
outputs
=
[]
node_queue
=
queue
.
Queue
()
node_queue
=
queue
.
Queue
()
node_queue
.
put
(
node
)
node_queue
.
put
(
node
)
visited
=
{
node
}
visited
=
{
node
}
...
@@ -372,9 +372,9 @@ class TorchModuleGraph(TorchGraph):
...
@@ -372,9 +372,9 @@ class TorchModuleGraph(TorchGraph):
node_queue
.
put
(
predecessor_node
)
node_queue
.
put
(
predecessor_node
)
visited
.
add
(
predecessor_node
)
visited
.
add
(
predecessor_node
)
else
:
else
:
inputs
.
a
d
d
(
input_name
)
inputs
.
a
ppen
d
(
input_name
)
else
:
else
:
inputs
.
a
d
d
(
input_name
)
inputs
.
a
ppen
d
(
input_name
)
for
_output
in
curr_node
.
outputs
():
for
_output
in
curr_node
.
outputs
():
if
_output
.
node
().
kind
()
==
CONSTANT_KIND
:
if
_output
.
node
().
kind
()
==
CONSTANT_KIND
:
continue
continue
...
@@ -387,9 +387,9 @@ class TorchModuleGraph(TorchGraph):
...
@@ -387,9 +387,9 @@ class TorchModuleGraph(TorchGraph):
node_queue
.
put
(
successor_node
)
node_queue
.
put
(
successor_node
)
visited
.
add
(
successor_node
)
visited
.
add
(
successor_node
)
else
:
else
:
outputs
.
a
d
d
(
output_name
)
outputs
.
a
ppen
d
(
output_name
)
else
:
else
:
outputs
.
a
d
d
(
output_name
)
outputs
.
a
ppen
d
(
output_name
)
nodepy
=
NodePyGroup
(
node_name
,
unique_name
,
module_type
,
op_type
,
nodepy
=
NodePyGroup
(
node_name
,
unique_name
,
module_type
,
op_type
,
node_group
,
inputs
=
list
(
inputs
),
outputs
=
list
(
outputs
))
node_group
,
inputs
=
list
(
inputs
),
outputs
=
list
(
outputs
))
...
@@ -562,9 +562,12 @@ class TorchModuleGraph(TorchGraph):
...
@@ -562,9 +562,12 @@ class TorchModuleGraph(TorchGraph):
for
node
in
nodes_op
:
for
node
in
nodes_op
:
name_to_node
[
node
.
unique_name
]
=
node
name_to_node
[
node
.
unique_name
]
=
node
for
_input
in
node
.
inputs
:
for
_input
in
node
.
inputs
:
# inputs may have duplicate tensors
if
node
not
in
input_to_node
[
_input
]:
input_to_node
[
_input
].
append
(
node
)
input_to_node
[
_input
].
append
(
node
)
for
output
in
node
.
outputs
:
for
output
in
node
.
outputs
:
assert
not
output
in
output_to_node
,
\
if
output
in
output_to_node
:
assert
output_to_node
[
output
]
==
node
,
\
"One output cannot be generated by multiple nodes %s"
%
output
"One output cannot be generated by multiple nodes %s"
%
output
output_to_node
[
output
]
=
node
output_to_node
[
output
]
=
node
return
name_to_node
,
input_to_node
,
output_to_node
return
name_to_node
,
input_to_node
,
output_to_node
...
@@ -619,8 +622,6 @@ class TorchModuleGraph(TorchGraph):
...
@@ -619,8 +622,6 @@ class TorchModuleGraph(TorchGraph):
assert
len
(
node
.
inputs
)
==
len
(
list
(
last_cpp
.
inputs
())),
errmsg
assert
len
(
node
.
inputs
)
==
len
(
list
(
last_cpp
.
inputs
())),
errmsg
for
_debug_input
,
_debug_output
in
zip
(
node
.
inputs
,
node
.
outputs
):
for
_debug_input
,
_debug_output
in
zip
(
node
.
inputs
,
node
.
outputs
):
# _debug_input = _input.debugName()
# _debug_output = _output.debugName()
if
_debug_input
in
self
.
input_to_node
and
_debug_output
in
self
.
input_to_node
:
if
_debug_input
in
self
.
input_to_node
and
_debug_output
in
self
.
input_to_node
:
# input_to_node[_debug_input] is a list of NodePyGroup, because
# input_to_node[_debug_input] is a list of NodePyGroup, because
# one tensor can be used as input for multiple nodes at the same time.
# one tensor can be used as input for multiple nodes at the same time.
...
@@ -629,6 +630,7 @@ class TorchModuleGraph(TorchGraph):
...
@@ -629,6 +630,7 @@ class TorchModuleGraph(TorchGraph):
# will be merged into the same NodePyGroup, so we remove the `node` from
# will be merged into the same NodePyGroup, so we remove the `node` from
# input_to_node[_debug_input] and directly connect this tensor to the
# input_to_node[_debug_input] and directly connect this tensor to the
# input_to_node[_debug_output]
# input_to_node[_debug_output]
if
node
in
self
.
input_to_node
[
_debug_input
]:
self
.
input_to_node
[
_debug_input
].
remove
(
node
)
self
.
input_to_node
[
_debug_input
].
remove
(
node
)
# add the following nodes of _output into the input_to_node[_debug_input]
# add the following nodes of _output into the input_to_node[_debug_input]
self
.
input_to_node
[
_debug_input
].
extend
(
self
.
input_to_node
[
_debug_output
])
self
.
input_to_node
[
_debug_input
].
extend
(
self
.
input_to_node
[
_debug_output
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment