Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
f8633ac9
"cuda/tests/Makefile" did not exist on "450549a4f52fd2d116cfa2652609deceda41dc8f"
Unverified
Commit
f8633ac9
authored
Jul 24, 2020
by
Ningxin Zheng
Committed by
GitHub
Jul 24, 2020
Browse files
Support the List/Tuple Construct/Unpack operation for TorchModuleGraph (#2609)
parent
66f2777f
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
80 additions
and
9 deletions
+80
-9
src/sdk/pynni/nni/_graph_utils.py
src/sdk/pynni/nni/_graph_utils.py
+77
-9
src/sdk/pynni/nni/compression/torch/utils/shape_dependency.py
...sdk/pynni/nni/compression/torch/utils/shape_dependency.py
+3
-0
No files found.
src/sdk/pynni/nni/_graph_utils.py
View file @
f8633ac9
...
...
@@ -11,6 +11,10 @@ from torch.utils.tensorboard._pytorch_graph import NodePy, NodePyIO, NodePyOP, G
CLASSTYPE_KIND
=
'ClassType'
GETATTR_KIND
=
'prim::GetAttr'
CAT_KIND
=
'aten::cat'
LIST_CONSTRUCT_KIND
=
'prim::ListConstruct'
LIST_UNPACK_KIND
=
'prim::ListUnpack'
TUPLE_CONSTRUCT_KIND
=
'prim::TupleConstruct'
TUPLE_UNPACK_KIND
=
'prim::TupleUnpack'
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -177,7 +181,7 @@ class NodePyGroup(NodePy):
represent the torch.nn.Module object. We also group some functional call trace nodes together to form a new node.
"""
def
__init__
(
self
,
name
,
unique_name
,
node_type
,
op_type
,
node_cpps
,
inputs
=
None
,
outputs
=
None
):
def
__init__
(
self
,
name
,
unique_name
,
node_type
,
op_type
,
node_cpps
,
inputs
=
None
,
outputs
=
None
,
key_node
=
None
):
"""
Parameters:
-----------
...
...
@@ -199,6 +203,8 @@ class NodePyGroup(NodePy):
All the inputs of this node, each element is debugName of one input
outputs: list of str
All the outputs of this node, each element is debugName of one output
key_node: torch._C.Node
The key node of this NodePyGroup.
"""
super
(
NodePyGroup
,
self
).
__init__
(
name
,
[])
self
.
node_cpps
=
node_cpps
...
...
@@ -211,6 +217,8 @@ class NodePyGroup(NodePy):
self
.
add_nodes
(
node_cpps
)
self
.
inputs
=
inputs
self
.
outputs
=
outputs
# The core node in this NodePyGroup
self
.
key_node
=
key_node
def
add_nodes
(
self
,
node_cpps
):
for
node_cpp
in
node_cpps
:
...
...
@@ -239,7 +247,7 @@ class TorchModuleGraph(TorchGraph):
self
.
name_to_node
,
self
.
input_to_node
,
self
.
output_to_node
=
self
.
_build_graph
()
self
.
_extract_auxiliary_info
()
def
_expand_
non_prim
_node
(
self
,
node
,
nodes
,
input_to_node
,
output_to_node
,
def
_expand_
key_func
_node
(
self
,
node
,
nodes
,
input_to_node
,
output_to_node
,
module_type
):
"""
For trace graph nodes, some nodes are not in modules, these nodes are usually generated by
...
...
@@ -284,7 +292,7 @@ class TorchModuleGraph(TorchGraph):
input_name
=
_input
.
debugName
()
if
input_name
in
output_to_node
and
output_to_node
[
input_name
]
in
nodes
:
predecessor_node
=
output_to_node
[
input_name
]
if
predecessor_node
.
kind
().
startswith
(
'prim::'
):
if
not
self
.
_is_key_func
(
predecessor_node
):
node_group
.
append
(
predecessor_node
)
node_queue
.
put
(
predecessor_node
)
else
:
...
...
@@ -294,7 +302,7 @@ class TorchModuleGraph(TorchGraph):
for
output
in
node
.
outputs
():
outputs
.
append
(
output
.
debugName
())
nodepy
=
NodePyGroup
(
node_name
,
unique_name
,
module_type
,
op_type
,
node_group
,
inputs
=
inputs
,
outputs
=
outputs
)
node_group
,
inputs
=
inputs
,
outputs
=
outputs
,
key_node
=
node
)
return
nodepy
def
_expand_module_node
(
self
,
node
,
node_name
,
unique_name
,
op_type
,
nodes
,
...
...
@@ -510,6 +518,65 @@ class TorchModuleGraph(TorchGraph):
output_to_node
[
output
]
=
node
return
name_to_node
,
input_to_node
,
output_to_node
def
_is_key_func
(
self
,
node_cpp
):
"""
Judge if a cpp node is a key function node.
If so, we should not merge this node into the
adjacent node.
"""
if
node_cpp
.
kind
().
startswith
(
'aten::'
):
# the nodes that start with 'aten' are key function
# nodes
return
True
if
node_cpp
.
kind
()
in
[
LIST_UNPACK_KIND
,
TUPLE_UNPACK_KIND
]:
# We cannot merge the List/Tuple
# Construct/Unpack func into other nodes, else it
# may lead to a graph construction error.
return
True
return
False
def
unpack_manually
(
self
):
"""
Unpack the tensor tuple or tensor list manually,
and remove the ListUnpack/TupleUnpack node from
the graph. Note: this function will change the
graph structure.
"""
if
hasattr
(
self
,
'unpacked'
):
# if already unpacked the tuple/list manually
return
for
node
in
self
.
nodes_py
.
nodes_op
:
if
node
.
op_type
in
[
TUPLE_UNPACK_KIND
,
LIST_UNPACK_KIND
]:
unpack_cpp
=
node
.
key_node
last_cpp
=
list
(
unpack_cpp
.
inputs
())[
0
].
node
()
if
last_cpp
.
kind
()
in
[
TUPLE_CONSTRUCT_KIND
,
LIST_CONSTRUCT_KIND
]:
# we need check if the tensor tuple or tensor list is produced
# by a list/tuple construct node. If so, we can unpack the tuple
# or list manunally.
_logger
.
debug
(
'List/Tuple Construct Node(cpp) %s'
,
str
(
last_cpp
))
_logger
.
debug
(
'List/Tuple Unpack Node(cpp) %s'
,
str
(
unpack_cpp
))
assert
len
(
list
(
unpack_cpp
.
outputs
()))
==
len
(
list
(
last_cpp
.
inputs
()))
for
_input
,
_output
in
zip
(
last_cpp
.
inputs
(),
unpack_cpp
.
outputs
()):
_debug_input
=
_input
.
debugName
()
_debug_output
=
_output
.
debugName
()
if
_debug_input
in
self
.
input_to_node
and
_debug_output
in
self
.
input_to_node
:
# input_to_node[_debug_input] is a list of NodePyGroup, because
# one tensor can be used as input for multiple nodes at the same time.
# note that, in this case, the construct cpp node and unpack cpp node
# will be merged into the same NodePyGroup, so we remove the `node` from
# input_to_node[_debug_input] and directly connect this tensor to the
# input_to_node[_debug_output]
self
.
input_to_node
[
_debug_input
].
remove
(
node
)
# add the following nodes of _output into the input_to_node[_debug_input]
self
.
input_to_node
[
_debug_input
].
extend
(
self
.
input_to_node
[
_debug_output
])
if
_debug_input
in
self
.
output_to_node
and
_debug_output
in
self
.
output_to_node
:
# output_to_node[_debug_output] is a NodePyGroup, because one output
# tensor only can be generated by one node.
self
.
output_to_node
[
_debug_output
]
=
self
.
output_to_node
[
_debug_input
]
self
.
unpacked
=
True
def
_build_graph
(
self
):
"""
Build graph using our defined format from jit trace.
...
...
@@ -585,13 +652,14 @@ class TorchModuleGraph(TorchGraph):
# build node group for torch.nn.functional
for
_
,
nodes
in
func_to_nodes
.
items
():
# extract non prim:: nodes
non_prim
_nodes
=
list
()
key_func
_nodes
=
list
()
for
node
in
nodes
:
if
not
node
.
kind
().
startswith
(
'prim::'
):
non_prim_nodes
.
append
(
node
)
if
self
.
_is_key_func
(
node
):
# find the key function nodes
key_func_nodes
.
append
(
node
)
# for each non prim node, expand it
for
node
in
non_prim
_nodes
:
node_group
=
self
.
_expand_
non_prim
_node
(
for
node
in
key_func
_nodes
:
node_group
=
self
.
_expand_
key_func
_node
(
node
,
nodes
,
input_to_node
,
output_to_node
,
'func'
)
nodes_py
.
nodes_op
.
append
(
node_group
)
# get shape infor for view (aten::view) func
...
...
src/sdk/pynni/nni/compression/torch/utils/shape_dependency.py
View file @
f8633ac9
...
...
@@ -86,6 +86,9 @@ class ChannelDependency(Dependency):
Build the channel dependency for the conv layers
in the model.
"""
# unpack the tuple/list manually before analyze the
# channel dependency
self
.
graph
.
unpack_manually
()
for
node
in
self
.
graph
.
nodes_py
.
nodes_op
:
parent_layers
=
[]
# find the node that contains aten::add
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment