Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
f8633ac9
Unverified
Commit
f8633ac9
authored
Jul 24, 2020
by
Ningxin Zheng
Committed by
GitHub
Jul 24, 2020
Browse files
Support the List/Tuple Construct/Unpack operation for TorchModuleGraph (#2609)
parent
66f2777f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
80 additions
and
9 deletions
+80
-9
src/sdk/pynni/nni/_graph_utils.py
src/sdk/pynni/nni/_graph_utils.py
+77
-9
src/sdk/pynni/nni/compression/torch/utils/shape_dependency.py
...sdk/pynni/nni/compression/torch/utils/shape_dependency.py
+3
-0
No files found.
src/sdk/pynni/nni/_graph_utils.py
View file @
f8633ac9
...
@@ -11,6 +11,10 @@ from torch.utils.tensorboard._pytorch_graph import NodePy, NodePyIO, NodePyOP, G
...
@@ -11,6 +11,10 @@ from torch.utils.tensorboard._pytorch_graph import NodePy, NodePyIO, NodePyOP, G
CLASSTYPE_KIND
=
'ClassType'
CLASSTYPE_KIND
=
'ClassType'
GETATTR_KIND
=
'prim::GetAttr'
GETATTR_KIND
=
'prim::GetAttr'
CAT_KIND
=
'aten::cat'
CAT_KIND
=
'aten::cat'
LIST_CONSTRUCT_KIND
=
'prim::ListConstruct'
LIST_UNPACK_KIND
=
'prim::ListUnpack'
TUPLE_CONSTRUCT_KIND
=
'prim::TupleConstruct'
TUPLE_UNPACK_KIND
=
'prim::TupleUnpack'
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -177,7 +181,7 @@ class NodePyGroup(NodePy):
...
@@ -177,7 +181,7 @@ class NodePyGroup(NodePy):
represent the torch.nn.Module object. We also group some functional call trace nodes together to form a new node.
represent the torch.nn.Module object. We also group some functional call trace nodes together to form a new node.
"""
"""
def
__init__
(
self
,
name
,
unique_name
,
node_type
,
op_type
,
node_cpps
,
inputs
=
None
,
outputs
=
None
):
def
__init__
(
self
,
name
,
unique_name
,
node_type
,
op_type
,
node_cpps
,
inputs
=
None
,
outputs
=
None
,
key_node
=
None
):
"""
"""
Parameters:
Parameters:
-----------
-----------
...
@@ -199,6 +203,8 @@ class NodePyGroup(NodePy):
...
@@ -199,6 +203,8 @@ class NodePyGroup(NodePy):
All the inputs of this node, each element is debugName of one input
All the inputs of this node, each element is debugName of one input
outputs: list of str
outputs: list of str
All the outputs of this node, each element is debugName of one output
All the outputs of this node, each element is debugName of one output
key_node: torch._C.Node
The key node of this NodePyGroup.
"""
"""
super
(
NodePyGroup
,
self
).
__init__
(
name
,
[])
super
(
NodePyGroup
,
self
).
__init__
(
name
,
[])
self
.
node_cpps
=
node_cpps
self
.
node_cpps
=
node_cpps
...
@@ -211,6 +217,8 @@ class NodePyGroup(NodePy):
...
@@ -211,6 +217,8 @@ class NodePyGroup(NodePy):
self
.
add_nodes
(
node_cpps
)
self
.
add_nodes
(
node_cpps
)
self
.
inputs
=
inputs
self
.
inputs
=
inputs
self
.
outputs
=
outputs
self
.
outputs
=
outputs
# The core node in this NodePyGroup
self
.
key_node
=
key_node
def
add_nodes
(
self
,
node_cpps
):
def
add_nodes
(
self
,
node_cpps
):
for
node_cpp
in
node_cpps
:
for
node_cpp
in
node_cpps
:
...
@@ -239,7 +247,7 @@ class TorchModuleGraph(TorchGraph):
...
@@ -239,7 +247,7 @@ class TorchModuleGraph(TorchGraph):
self
.
name_to_node
,
self
.
input_to_node
,
self
.
output_to_node
=
self
.
_build_graph
()
self
.
name_to_node
,
self
.
input_to_node
,
self
.
output_to_node
=
self
.
_build_graph
()
self
.
_extract_auxiliary_info
()
self
.
_extract_auxiliary_info
()
def
_expand_
non_prim
_node
(
self
,
node
,
nodes
,
input_to_node
,
output_to_node
,
def
_expand_
key_func
_node
(
self
,
node
,
nodes
,
input_to_node
,
output_to_node
,
module_type
):
module_type
):
"""
"""
For trace graph nodes, some nodes are not in modules, these nodes are usually generated by
For trace graph nodes, some nodes are not in modules, these nodes are usually generated by
...
@@ -284,7 +292,7 @@ class TorchModuleGraph(TorchGraph):
...
@@ -284,7 +292,7 @@ class TorchModuleGraph(TorchGraph):
input_name
=
_input
.
debugName
()
input_name
=
_input
.
debugName
()
if
input_name
in
output_to_node
and
output_to_node
[
input_name
]
in
nodes
:
if
input_name
in
output_to_node
and
output_to_node
[
input_name
]
in
nodes
:
predecessor_node
=
output_to_node
[
input_name
]
predecessor_node
=
output_to_node
[
input_name
]
if
predecessor_node
.
kind
().
startswith
(
'prim::'
):
if
not
self
.
_is_key_func
(
predecessor_node
):
node_group
.
append
(
predecessor_node
)
node_group
.
append
(
predecessor_node
)
node_queue
.
put
(
predecessor_node
)
node_queue
.
put
(
predecessor_node
)
else
:
else
:
...
@@ -294,7 +302,7 @@ class TorchModuleGraph(TorchGraph):
...
@@ -294,7 +302,7 @@ class TorchModuleGraph(TorchGraph):
for
output
in
node
.
outputs
():
for
output
in
node
.
outputs
():
outputs
.
append
(
output
.
debugName
())
outputs
.
append
(
output
.
debugName
())
nodepy
=
NodePyGroup
(
node_name
,
unique_name
,
module_type
,
op_type
,
nodepy
=
NodePyGroup
(
node_name
,
unique_name
,
module_type
,
op_type
,
node_group
,
inputs
=
inputs
,
outputs
=
outputs
)
node_group
,
inputs
=
inputs
,
outputs
=
outputs
,
key_node
=
node
)
return
nodepy
return
nodepy
def
_expand_module_node
(
self
,
node
,
node_name
,
unique_name
,
op_type
,
nodes
,
def
_expand_module_node
(
self
,
node
,
node_name
,
unique_name
,
op_type
,
nodes
,
...
@@ -510,6 +518,65 @@ class TorchModuleGraph(TorchGraph):
...
@@ -510,6 +518,65 @@ class TorchModuleGraph(TorchGraph):
output_to_node
[
output
]
=
node
output_to_node
[
output
]
=
node
return
name_to_node
,
input_to_node
,
output_to_node
return
name_to_node
,
input_to_node
,
output_to_node
def
_is_key_func
(
self
,
node_cpp
):
"""
Judge if a cpp node is a key function node.
If so, we should not merge this node into the
adjacent node.
"""
if
node_cpp
.
kind
().
startswith
(
'aten::'
):
# the nodes that start with 'aten' are key function
# nodes
return
True
if
node_cpp
.
kind
()
in
[
LIST_UNPACK_KIND
,
TUPLE_UNPACK_KIND
]:
# We cannot merge the List/Tuple
# Construct/Unpack func into other nodes, else it
# may lead to a graph construction error.
return
True
return
False
def
unpack_manually
(
self
):
"""
Unpack the tensor tuple or tensor list manually,
and remove the ListUnpack/TupleUnpack node from
the graph. Note: this function will change the
graph structure.
"""
if
hasattr
(
self
,
'unpacked'
):
# if already unpacked the tuple/list manually
return
for
node
in
self
.
nodes_py
.
nodes_op
:
if
node
.
op_type
in
[
TUPLE_UNPACK_KIND
,
LIST_UNPACK_KIND
]:
unpack_cpp
=
node
.
key_node
last_cpp
=
list
(
unpack_cpp
.
inputs
())[
0
].
node
()
if
last_cpp
.
kind
()
in
[
TUPLE_CONSTRUCT_KIND
,
LIST_CONSTRUCT_KIND
]:
# we need check if the tensor tuple or tensor list is produced
# by a list/tuple construct node. If so, we can unpack the tuple
# or list manunally.
_logger
.
debug
(
'List/Tuple Construct Node(cpp) %s'
,
str
(
last_cpp
))
_logger
.
debug
(
'List/Tuple Unpack Node(cpp) %s'
,
str
(
unpack_cpp
))
assert
len
(
list
(
unpack_cpp
.
outputs
()))
==
len
(
list
(
last_cpp
.
inputs
()))
for
_input
,
_output
in
zip
(
last_cpp
.
inputs
(),
unpack_cpp
.
outputs
()):
_debug_input
=
_input
.
debugName
()
_debug_output
=
_output
.
debugName
()
if
_debug_input
in
self
.
input_to_node
and
_debug_output
in
self
.
input_to_node
:
# input_to_node[_debug_input] is a list of NodePyGroup, because
# one tensor can be used as input for multiple nodes at the same time.
# note that, in this case, the construct cpp node and unpack cpp node
# will be merged into the same NodePyGroup, so we remove the `node` from
# input_to_node[_debug_input] and directly connect this tensor to the
# input_to_node[_debug_output]
self
.
input_to_node
[
_debug_input
].
remove
(
node
)
# add the following nodes of _output into the input_to_node[_debug_input]
self
.
input_to_node
[
_debug_input
].
extend
(
self
.
input_to_node
[
_debug_output
])
if
_debug_input
in
self
.
output_to_node
and
_debug_output
in
self
.
output_to_node
:
# output_to_node[_debug_output] is a NodePyGroup, because one output
# tensor only can be generated by one node.
self
.
output_to_node
[
_debug_output
]
=
self
.
output_to_node
[
_debug_input
]
self
.
unpacked
=
True
def
_build_graph
(
self
):
def
_build_graph
(
self
):
"""
"""
Build graph using our defined format from jit trace.
Build graph using our defined format from jit trace.
...
@@ -585,13 +652,14 @@ class TorchModuleGraph(TorchGraph):
...
@@ -585,13 +652,14 @@ class TorchModuleGraph(TorchGraph):
# build node group for torch.nn.functional
# build node group for torch.nn.functional
for
_
,
nodes
in
func_to_nodes
.
items
():
for
_
,
nodes
in
func_to_nodes
.
items
():
# extract non prim:: nodes
# extract non prim:: nodes
non_prim
_nodes
=
list
()
key_func
_nodes
=
list
()
for
node
in
nodes
:
for
node
in
nodes
:
if
not
node
.
kind
().
startswith
(
'prim::'
):
if
self
.
_is_key_func
(
node
):
non_prim_nodes
.
append
(
node
)
# find the key function nodes
key_func_nodes
.
append
(
node
)
# for each non prim node, expand it
# for each non prim node, expand it
for
node
in
non_prim
_nodes
:
for
node
in
key_func
_nodes
:
node_group
=
self
.
_expand_
non_prim
_node
(
node_group
=
self
.
_expand_
key_func
_node
(
node
,
nodes
,
input_to_node
,
output_to_node
,
'func'
)
node
,
nodes
,
input_to_node
,
output_to_node
,
'func'
)
nodes_py
.
nodes_op
.
append
(
node_group
)
nodes_py
.
nodes_op
.
append
(
node_group
)
# get shape infor for view (aten::view) func
# get shape infor for view (aten::view) func
...
...
src/sdk/pynni/nni/compression/torch/utils/shape_dependency.py
View file @
f8633ac9
...
@@ -86,6 +86,9 @@ class ChannelDependency(Dependency):
...
@@ -86,6 +86,9 @@ class ChannelDependency(Dependency):
Build the channel dependency for the conv layers
Build the channel dependency for the conv layers
in the model.
in the model.
"""
"""
# unpack the tuple/list manually before analyze the
# channel dependency
self
.
graph
.
unpack_manually
()
for
node
in
self
.
graph
.
nodes_py
.
nodes_op
:
for
node
in
self
.
graph
.
nodes_py
.
nodes_op
:
parent_layers
=
[]
parent_layers
=
[]
# find the node that contains aten::add
# find the node that contains aten::add
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment