Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
5d2a59fd
Unverified
Commit
5d2a59fd
authored
Aug 12, 2020
by
Ningxin Zheng
Committed by
GitHub
Aug 12, 2020
Browse files
Successive unpack (#2768)
parent
e7fccfb4
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
119 additions
and
9 deletions
+119
-9
src/sdk/pynni/nni/_graph_utils.py
src/sdk/pynni/nni/_graph_utils.py
+22
-8
src/sdk/pynni/tests/test_graph_utils.py
src/sdk/pynni/tests/test_graph_utils.py
+97
-1
No files found.
src/sdk/pynni/nni/_graph_utils.py
View file @
5d2a59fd
...
@@ -530,8 +530,15 @@ class TorchModuleGraph(TorchGraph):
...
@@ -530,8 +530,15 @@ class TorchModuleGraph(TorchGraph):
return
True
return
True
if
node_cpp
.
kind
()
in
[
LIST_UNPACK_KIND
,
TUPLE_UNPACK_KIND
]:
if
node_cpp
.
kind
()
in
[
LIST_UNPACK_KIND
,
TUPLE_UNPACK_KIND
]:
# We cannot merge the List/Tuple
# We cannot merge the List/Tuple
#
Construct/
Unpack func into other nodes, else it
# Unpack func into other nodes, else it
# may lead to a graph construction error.
# may lead to a graph construction error.
# The reason why we donnot take the construct node
# also as a key node is that `cat` operation node need
# the last(previous) visited node to infer the mask. If
# we take the Construct node as the important node, the
# predecessor of the `cat` node will always be a construct
# node, which means we cannot infer the mask for the cat
# operation.
return
True
return
True
return
False
return
False
...
@@ -556,9 +563,13 @@ class TorchModuleGraph(TorchGraph):
...
@@ -556,9 +563,13 @@ class TorchModuleGraph(TorchGraph):
_logger
.
debug
(
'List/Tuple Construct Node(cpp) %s'
,
str
(
last_cpp
))
_logger
.
debug
(
'List/Tuple Construct Node(cpp) %s'
,
str
(
last_cpp
))
_logger
.
debug
(
'List/Tuple Unpack Node(cpp) %s'
,
str
(
unpack_cpp
))
_logger
.
debug
(
'List/Tuple Unpack Node(cpp) %s'
,
str
(
unpack_cpp
))
assert
len
(
list
(
unpack_cpp
.
outputs
()))
==
len
(
list
(
last_cpp
.
inputs
()))
assert
len
(
list
(
unpack_cpp
.
outputs
()))
==
len
(
list
(
last_cpp
.
inputs
()))
for
_input
,
_output
in
zip
(
last_cpp
.
inputs
(),
unpack_cpp
.
outputs
()):
errmsg
=
'%s Input number: %d if inconsistent with the output number %d'
%
(
unpack_cpp
,
\
_debug_input
=
_input
.
debugName
()
len
(
node
.
inputs
),
len
(
list
(
last_cpp
.
inputs
())))
_debug_output
=
_output
.
debugName
()
assert
len
(
node
.
inputs
)
==
len
(
list
(
last_cpp
.
inputs
())),
errmsg
for
_debug_input
,
_debug_output
in
zip
(
node
.
inputs
,
node
.
outputs
):
# _debug_input = _input.debugName()
# _debug_output = _output.debugName()
if
_debug_input
in
self
.
input_to_node
and
_debug_output
in
self
.
input_to_node
:
if
_debug_input
in
self
.
input_to_node
and
_debug_output
in
self
.
input_to_node
:
# input_to_node[_debug_input] is a list of NodePyGroup, because
# input_to_node[_debug_input] is a list of NodePyGroup, because
# one tensor can be used as input for multiple nodes at the same time.
# one tensor can be used as input for multiple nodes at the same time.
...
@@ -570,10 +581,13 @@ class TorchModuleGraph(TorchGraph):
...
@@ -570,10 +581,13 @@ class TorchModuleGraph(TorchGraph):
self
.
input_to_node
[
_debug_input
].
remove
(
node
)
self
.
input_to_node
[
_debug_input
].
remove
(
node
)
# add the following nodes of _output into the input_to_node[_debug_input]
# add the following nodes of _output into the input_to_node[_debug_input]
self
.
input_to_node
[
_debug_input
].
extend
(
self
.
input_to_node
[
_debug_output
])
self
.
input_to_node
[
_debug_input
].
extend
(
self
.
input_to_node
[
_debug_output
])
if
_debug_input
in
self
.
output_to_node
and
_debug_output
in
self
.
output_to_node
:
# just remove the _debug_output from the grapgh index. So that we can also skip
# output_to_node[_debug_output] is a NodePyGroup, because one output
# the construct and tuple
# tensor only can be generated by one node.
if
_debug_output
in
self
.
input_to_node
:
self
.
output_to_node
[
_debug_output
]
=
self
.
output_to_node
[
_debug_input
]
for
following_node
in
self
.
input_to_node
[
_debug_output
]:
_tmp_index
=
following_node
.
inputs
.
index
(
_debug_output
)
following_node
.
inputs
[
_tmp_index
]
=
_debug_input
self
.
unpacked
=
True
self
.
unpacked
=
True
...
...
src/sdk/pynni/tests/test_graph_utils.py
View file @
5d2a59fd
...
@@ -15,7 +15,7 @@ from google.protobuf import text_format
...
@@ -15,7 +15,7 @@ from google.protobuf import text_format
import
unittest
import
unittest
from
unittest
import
TestCase
,
main
from
unittest
import
TestCase
,
main
from
nni._graph_utils
import
build_module_graph
,
build_graph
,
TorchModuleGraph
from
nni._graph_utils
import
build_module_graph
,
build_graph
,
TorchModuleGraph
,
TUPLE_UNPACK_KIND
class
BackboneModel1
(
nn
.
Module
):
class
BackboneModel1
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -194,5 +194,101 @@ class GraphUtilsTestCase(TestCase):
...
@@ -194,5 +194,101 @@ class GraphUtilsTestCase(TestCase):
assert
len
(
nodes
)
==
1
assert
len
(
nodes
)
==
1
node
=
nodes
[
0
]
node
=
nodes
[
0
]
@
unittest
.
skipIf
(
torch
.
__version__
<
"1.4.0"
,
"not supported"
)
def
test_module_unpack
(
self
):
"""
test the tuple/list unpack function of TorchModuleGraph.
Following models are from the issue 2756
https://github.com/microsoft/nni/issues/2756.
MyModule will have two successive tuple unpack operations
between the B and C.
"""
class
CBR
(
nn
.
Module
):
def
__init__
(
self
,
i
,
o
):
super
(
CBR
,
self
).
__init__
()
self
.
conv1
=
nn
.
Conv2d
(
i
,
o
,
kernel_size
=
1
)
self
.
bn1
=
nn
.
BatchNorm2d
(
o
)
self
.
act1
=
nn
.
ReLU
()
def
forward
(
self
,
x
):
return
self
.
act1
(
self
.
bn1
(
self
.
conv1
(
x
)))
class
A
(
nn
.
Module
):
def
__init__
(
self
):
super
(
A
,
self
).
__init__
()
self
.
conv1
=
CBR
(
3
,
6
,
)
self
.
conv2
=
CBR
(
6
,
8
,
)
self
.
conv3
=
CBR
(
6
,
12
)
def
forward
(
self
,
x
):
x1
=
self
.
conv1
(
x
)
x2
=
self
.
conv2
(
x1
)
x3
=
self
.
conv3
(
x1
)
return
(
x2
,
x3
)
class
B1
(
nn
.
Module
):
def
__init__
(
self
):
super
(
B1
,
self
).
__init__
()
self
.
conv1
=
CBR
(
12
,
32
)
self
.
conv2
=
CBR
(
32
,
32
)
self
.
conv3
=
CBR
(
32
,
32
)
def
forward
(
self
,
x
):
x1
=
self
.
conv1
(
x
)
x2
=
self
.
conv2
(
x1
)
x3
=
self
.
conv3
(
x2
)
return
(
x1
,
x2
,
x3
)
class
B
(
nn
.
Module
):
def
__init__
(
self
):
super
(
B
,
self
).
__init__
()
self
.
b
=
B1
()
def
forward
(
self
,
x
):
return
self
.
b
(
x
[
-
1
])
class
C
(
nn
.
Module
):
def
__init__
(
self
):
super
(
C
,
self
).
__init__
()
self
.
conv1
=
CBR
(
8
,
32
)
self
.
conv2
=
CBR
(
12
,
32
)
self
.
conv3
=
CBR
(
32
,
32
)
self
.
conv4
=
CBR
(
32
,
32
)
self
.
conv5
=
CBR
(
32
,
32
)
def
forward
(
self
,
x
):
return
(
self
.
conv1
(
x
[
0
]),
self
.
conv2
(
x
[
1
]),
self
.
conv3
(
x
[
2
]),
self
.
conv4
(
x
[
3
]),
self
.
conv5
(
x
[
4
]))
class
MyModule
(
nn
.
Module
):
def
__init__
(
self
):
super
(
MyModule
,
self
).
__init__
()
self
.
a
=
A
()
self
.
b
=
B
()
# self.dummy = Dummy()
self
.
c
=
C
()
def
forward
(
self
,
x
):
x_a
=
self
.
a
(
x
)
x_b
=
self
.
b
(
x_a
)
xc
=
self
.
c
(
x_a
+
x_b
)
return
xc
dummy_input
=
torch
.
rand
(
1
,
3
,
28
,
28
)
model
=
MyModule
()
graph
=
TorchModuleGraph
(
model
,
dummy_input
)
graph
.
unpack_manually
()
for
node
in
graph
.
nodes_py
.
nodes_op
:
# The input of the function nodes should
# not come from the TupleUnpack node, because
# all the TupleUnpack nodes have been removed(unpacked)
# manually
for
_input
in
node
.
inputs
:
if
_input
in
graph
.
output_to_node
:
preprocessor
=
graph
.
output_to_node
[
_input
]
assert
preprocessor
.
op_type
!=
TUPLE_UNPACK_KIND
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
main
()
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment