OpenDAS / nni · Commit 5d2a59fd (unverified)

Successive unpack (#2768)

Authored Aug 12, 2020 by Ningxin Zheng; committed by GitHub on Aug 12, 2020.
parent e7fccfb4

Showing 2 changed files with 119 additions and 9 deletions:

    src/sdk/pynni/nni/_graph_utils.py        +22 -8
    src/sdk/pynni/tests/test_graph_utils.py  +97 -1
src/sdk/pynni/nni/_graph_utils.py

@@ -530,8 +530,15 @@ class TorchModuleGraph(TorchGraph):
             return True
         if node_cpp.kind() in [LIST_UNPACK_KIND, TUPLE_UNPACK_KIND]:
             # We cannot merge the List/Tuple
             # Construct/Unpack func into other nodes, else it
             # may lead to a graph construction error.
+            # The reason why we do not take the construct node
+            # as a key node as well is that the `cat` operation node needs
+            # the last (previous) visited node to infer the mask. If
+            # we took the Construct node as the important node, the
+            # predecessor of the `cat` node would always be a construct
+            # node, which means we could not infer the mask for the `cat`
+            # operation.
             return True
         return False
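For readers unfamiliar with these kind constants: every TorchScript graph node reports a kind string, and NNI's LIST_UNPACK_KIND / TUPLE_UNPACK_KIND presumably hold the strings 'prim::ListUnpack' / 'prim::TupleUnpack'. Below is a minimal, self-contained sketch (not part of this commit; Inner and Outer are illustrative names) of how such construct/unpack pairs show up when a traced submodule returns a tuple:

```python
# Sketch only: a submodule returning a tuple makes the traced, inlined graph
# contain a prim::TupleConstruct / prim::TupleUnpack pair at the call boundary.
import torch
import torch.nn as nn

class Inner(nn.Module):
    def forward(self, x):
        return x * 2, x + 1          # tuple return -> TupleConstruct in the trace

class Outer(nn.Module):
    def __init__(self):
        super(Outer, self).__init__()
        self.inner = Inner()

    def forward(self, x):
        a, b = self.inner(x)         # consuming the tuple -> TupleUnpack
        return a + b

traced = torch.jit.trace(Outer(), torch.rand(2, 2))
# inlined_graph resolves submodule calls; exact node kinds vary by torch version.
for node in traced.inlined_graph.nodes():
    print(node.kind())               # expect prim::TupleConstruct, prim::TupleUnpack, aten::...
```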
@@ -556,9 +563,13 @@ class TorchModuleGraph(TorchGraph):
             _logger.debug('List/Tuple Construct Node(cpp) %s', str(last_cpp))
             _logger.debug('List/Tuple Unpack Node(cpp) %s', str(unpack_cpp))
             assert len(list(unpack_cpp.outputs())) == len(list(last_cpp.inputs()))
-            for _input, _output in zip(last_cpp.inputs(), unpack_cpp.outputs()):
-                _debug_input = _input.debugName()
-                _debug_output = _output.debugName()
+            errmsg = '%s Input number: %d is inconsistent with the output number %d' % (unpack_cpp, \
+                len(node.inputs), len(list(last_cpp.inputs())))
+
+            assert len(node.inputs) == len(list(last_cpp.inputs())), errmsg
+            for _debug_input, _debug_output in zip(node.inputs, node.outputs):
+                # _debug_input = _input.debugName()
+                # _debug_output = _output.debugName()
                 if _debug_input in self.input_to_node and _debug_output in self.input_to_node:
                     # input_to_node[_debug_input] is a list of NodePyGroup, because
                     # one tensor can be used as input for multiple nodes at the same time.
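The new loop walks the debug names stored on the flattened NodePyGroup rather than the C++ values, and the added assert protects the positional assumption it relies on: the i-th construct input must correspond to the i-th unpack output. A toy illustration (the debug names are hypothetical, not NNI's):

```python
# Toy illustration of the positional pairing the assert above protects:
# construct-node inputs and unpack-node outputs must line up one-to-one.
inputs = ['conv1.out', 'conv2.out']          # hypothetical debug names
outputs = ['unpack.out.0', 'unpack.out.1']

errmsg = 'Input number: %d is inconsistent with the output number %d' % (
    len(inputs), len(outputs))
assert len(inputs) == len(outputs), errmsg

for _in, _out in zip(inputs, outputs):
    print(_in, '->', _out)   # conv1.out -> unpack.out.0, conv2.out -> unpack.out.1
```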
@@ -570,10 +581,13 @@ class TorchModuleGraph(TorchGraph):
                     self.input_to_node[_debug_input].remove(node)
                     # add the following nodes of _output into the input_to_node[_debug_input]
                     self.input_to_node[_debug_input].extend(self.input_to_node[_debug_output])
-                if _debug_input in self.output_to_node and _debug_output in self.output_to_node:
-                    # output_to_node[_debug_output] is a NodePyGroup, because one output
-                    # tensor only can be generated by one node.
-                    self.output_to_node[_debug_output] = self.output_to_node[_debug_input]
+                # just remove the _debug_output from the graph index, so that we can also skip
+                # the construct and tuple
+                if _debug_output in self.input_to_node:
+                    for following_node in self.input_to_node[_debug_output]:
+                        _tmp_index = following_node.inputs.index(_debug_output)
+                        following_node.inputs[_tmp_index] = _debug_input
         self.unpacked = True
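To make the rewiring concrete, here is a self-contained toy version of the index update above (ToyNode and the tensor names are illustrative, not NNI's API): consumers of the unpack node's outputs get re-pointed at the tensors that fed the construct node, so the construct/unpack pair simply drops out of any traversal over the index.

```python
# Toy sketch of the rewiring: after the loop, 'relu' and 'relu2' read t1/t2
# directly, and the unpack node no longer appears in input_to_node.
class ToyNode:
    def __init__(self, name, inputs, outputs):
        self.name = name
        self.inputs = inputs      # debug names of input tensors
        self.outputs = outputs    # debug names of output tensors

unpack = ToyNode('unpack', ['t1', 't2'], ['u1', 'u2'])
relu = ToyNode('relu', ['u1'], ['r1'])
relu2 = ToyNode('relu2', ['u2'], ['r2'])
input_to_node = {'t1': [unpack], 't2': [unpack], 'u1': [relu], 'u2': [relu2]}

for _debug_input, _debug_output in zip(unpack.inputs, unpack.outputs):
    if _debug_input in input_to_node and _debug_output in input_to_node:
        # re-point the producer-side index entry at the real consumers
        input_to_node[_debug_input].remove(unpack)
        input_to_node[_debug_input].extend(input_to_node[_debug_output])
    if _debug_output in input_to_node:
        # rewrite each consumer's input list to skip the unpack output
        for following_node in input_to_node[_debug_output]:
            _tmp_index = following_node.inputs.index(_debug_output)
            following_node.inputs[_tmp_index] = _debug_input

print(relu.inputs, relu2.inputs)               # ['t1'] ['t2']
print([n.name for n in input_to_node['t1']])   # ['relu']
```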
src/sdk/pynni/tests/test_graph_utils.py

@@ -15,7 +15,7 @@ from google.protobuf import text_format
 import unittest
 from unittest import TestCase, main
-from nni._graph_utils import build_module_graph, build_graph, TorchModuleGraph
+from nni._graph_utils import build_module_graph, build_graph, TorchModuleGraph, TUPLE_UNPACK_KIND

 class BackboneModel1(nn.Module):
     def __init__(self):
@@ -194,5 +194,101 @@ class GraphUtilsTestCase(TestCase):
         assert len(nodes) == 1
         node = nodes[0]
+
+    @unittest.skipIf(torch.__version__ < "1.4.0", "not supported")
+    def test_module_unpack(self):
+        """
+        Test the tuple/list unpack function of TorchModuleGraph.
+        The following models are from issue 2756
+        https://github.com/microsoft/nni/issues/2756.
+        MyModule will have two successive tuple unpack operations
+        between B and C.
+        """
+        class CBR(nn.Module):
+            def __init__(self, i, o):
+                super(CBR, self).__init__()
+                self.conv1 = nn.Conv2d(i, o, kernel_size=1)
+                self.bn1 = nn.BatchNorm2d(o)
+                self.act1 = nn.ReLU()
+
+            def forward(self, x):
+                return self.act1(self.bn1(self.conv1(x)))
+
+        class A(nn.Module):
+            def __init__(self):
+                super(A, self).__init__()
+                self.conv1 = CBR(3, 6)
+                self.conv2 = CBR(6, 8)
+                self.conv3 = CBR(6, 12)
+
+            def forward(self, x):
+                x1 = self.conv1(x)
+                x2 = self.conv2(x1)
+                x3 = self.conv3(x1)
+                return (x2, x3)
+
+        class B1(nn.Module):
+            def __init__(self):
+                super(B1, self).__init__()
+                self.conv1 = CBR(12, 32)
+                self.conv2 = CBR(32, 32)
+                self.conv3 = CBR(32, 32)
+
+            def forward(self, x):
+                x1 = self.conv1(x)
+                x2 = self.conv2(x1)
+                x3 = self.conv3(x2)
+                return (x1, x2, x3)
+
+        class B(nn.Module):
+            def __init__(self):
+                super(B, self).__init__()
+                self.b = B1()
+
+            def forward(self, x):
+                return self.b(x[-1])
+
+        class C(nn.Module):
+            def __init__(self):
+                super(C, self).__init__()
+                self.conv1 = CBR(8, 32)
+                self.conv2 = CBR(12, 32)
+                self.conv3 = CBR(32, 32)
+                self.conv4 = CBR(32, 32)
+                self.conv5 = CBR(32, 32)
+
+            def forward(self, x):
+                return (self.conv1(x[0]), self.conv2(x[1]), self.conv3(x[2]),
+                        self.conv4(x[3]), self.conv5(x[4]))
+
+        class MyModule(nn.Module):
+            def __init__(self):
+                super(MyModule, self).__init__()
+                self.a = A()
+                self.b = B()
+                # self.dummy = Dummy()
+                self.c = C()
+
+            def forward(self, x):
+                x_a = self.a(x)
+                x_b = self.b(x_a)
+                xc = self.c(x_a + x_b)
+                return xc
+
+        dummy_input = torch.rand(1, 3, 28, 28)
+        model = MyModule()
+        graph = TorchModuleGraph(model, dummy_input)
+        graph.unpack_manually()
+        for node in graph.nodes_py.nodes_op:
+            # The input of the function nodes should
+            # not come from the TupleUnpack node, because
+            # all the TupleUnpack nodes have been removed (unpacked)
+            # manually
+            for _input in node.inputs:
+                if _input in graph.output_to_node:
+                    preprocessor = graph.output_to_node[_input]
+                    assert preprocessor.op_type != TUPLE_UNPACK_KIND
+
 
 if __name__ == '__main__':
     main()
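Since the test module keeps its main() guard, the new case can presumably be run directly, e.g. with `python tests/test_graph_utils.py` from src/sdk/pynni (so that the nni package is importable), or through unittest's own CLI; the skipIf decorator limits it to torch 1.4.0 or newer.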