Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
7ee5036b
Unverified
Commit
7ee5036b
authored
Jun 11, 2020
by
Ningxin Zheng
Committed by
GitHub
Jun 11, 2020
Browse files
Bugfix issue2485 (#2524)
parent
e1e1977c
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
222 additions
and
72 deletions
+222
-72
src/sdk/pynni/nni/_graph_utils.py
src/sdk/pynni/nni/_graph_utils.py
+181
-71
src/sdk/pynni/tests/test_graph_utils.py
src/sdk/pynni/tests/test_graph_utils.py
+41
-1
No files found.
src/sdk/pynni/nni/_graph_utils.py
View file @
7ee5036b
...
@@ -8,19 +8,21 @@ import re
...
@@ -8,19 +8,21 @@ import re
from
collections
import
defaultdict
from
collections
import
defaultdict
import
torch
import
torch
from
torch.utils.tensorboard._pytorch_graph
import
NodePy
,
NodePyIO
,
NodePyOP
,
GraphPy
from
torch.utils.tensorboard._pytorch_graph
import
NodePy
,
NodePyIO
,
NodePyOP
,
GraphPy
CLASSTYPE_KIND
=
'ClassType'
CLASSTYPE_KIND
=
'ClassType'
GETATTR_KIND
=
'prim::GetAttr'
GETATTR_KIND
=
'prim::GetAttr'
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
def
build_module_graph
(
model
,
dummy_input
):
def
build_module_graph
(
model
,
dummy_input
):
return
TorchModuleGraph
(
model
,
dummy_input
)
return
TorchModuleGraph
(
model
,
dummy_input
)
def
build_graph
(
model
,
dummy_input
,
verbose
=
False
):
def
build_graph
(
model
,
dummy_input
,
verbose
=
False
):
g
=
TorchProtoGraph
(
model
,
dummy_input
,
verbose
)
g
=
TorchProtoGraph
(
model
,
dummy_input
,
verbose
)
return
g
.
graph_def
,
g
.
stepstats
return
g
.
graph_def
,
g
.
stepstats
def
parse_traced_name
(
module_name
):
def
parse_traced_name
(
module_name
):
prefix
=
'TracedModule['
prefix
=
'TracedModule['
suffix
=
']'
suffix
=
']'
...
@@ -28,11 +30,13 @@ def parse_traced_name(module_name):
...
@@ -28,11 +30,13 @@ def parse_traced_name(module_name):
module_name
=
module_name
[
len
(
prefix
):
-
len
(
suffix
)]
module_name
=
module_name
[
len
(
prefix
):
-
len
(
suffix
)]
return
module_name
return
module_name
class
TorchGraph
:
class
TorchGraph
:
"""
"""
This class is to extract pytorch model topology graph by tracing
This class is to extract pytorch model topology graph by tracing
"""
"""
def
__init__
(
self
,
model
,
dummy_input
):
def
__init__
(
self
,
model
=
None
,
dummy_input
=
None
,
traced_model
=
None
):
"""
"""
Parameters
Parameters
----------
----------
...
@@ -40,25 +44,39 @@ class TorchGraph:
...
@@ -40,25 +44,39 @@ class TorchGraph:
The model user wants to speed up
The model user wants to speed up
dummy_input : pytorch tensor
dummy_input : pytorch tensor
The dummy input for ```jit.trace```, users should put it on right device before pass in
The dummy input for ```jit.trace```, users should put it on right device before pass in
traced_model : torch._C.torch.jit.TopLevelTracedModule
An alredy traced model, if traced_model is not None, then TorchGraph will build the graph
based on this traced model and won't trace the model again.
"""
"""
assert
torch
.
__version__
>=
'1.3.1'
assert
torch
.
__version__
>=
'1.3.1'
# check if the input is legal
if
traced_model
is
not
None
:
assert
isinstance
(
traced_model
,
torch
.
jit
.
TopLevelTracedModule
)
self
.
trace
=
traced_model
# it's ok if the graph is already unpacked
torch
.
_C
.
_jit_pass_inline
(
self
.
trace
.
graph
)
elif
model
is
not
None
and
dummy_input
is
not
None
:
self
.
bound_model
=
model
self
.
bound_model
=
model
self
.
_trace
(
model
,
dummy_input
)
self
.
_trace
(
model
,
dummy_input
)
else
:
raise
Exception
(
'Please provide model & dummy_input or the traced_model as inputs'
)
def
_trace
(
self
,
model
,
dummy_input
):
def
_trace
(
self
,
model
,
dummy_input
):
with
torch
.
onnx
.
set_training
(
model
,
False
):
with
torch
.
onnx
.
set_training
(
model
,
False
):
self
.
trace
=
torch
.
jit
.
trace
(
model
,
dummy_input
)
self
.
trace
=
torch
.
jit
.
trace
(
model
,
dummy_input
)
torch
.
_C
.
_jit_pass_inline
(
self
.
trace
.
graph
)
torch
.
_C
.
_jit_pass_inline
(
self
.
trace
.
graph
)
class
TorchProtoGraph
(
TorchGraph
):
class
TorchProtoGraph
(
TorchGraph
):
"""
"""
Generates model graph for pytorch models in protobuf, this implementation
is borrowed from pytorch v1.4.0,
Generates model graph for pytorch models in protobuf, this implementation
and fixed following issues:
is borrowed from pytorch v1.4.0,
and fixed following issues:
https://github.com/pytorch/pytorch/issues/33691
https://github.com/pytorch/pytorch/issues/33691
https://github.com/pytorch/pytorch/issues/33670
https://github.com/pytorch/pytorch/issues/33670
"""
"""
def
__init__
(
self
,
model
,
dummy_input
,
verbose
=
False
):
def
__init__
(
self
,
model
,
dummy_input
,
verbose
=
False
):
super
().
__init__
(
model
,
dummy_input
)
super
().
__init__
(
model
,
dummy_input
)
...
@@ -70,8 +88,10 @@ class TorchProtoGraph(TorchGraph):
...
@@ -70,8 +88,10 @@ class TorchProtoGraph(TorchGraph):
list_of_nodes
=
self
.
parse
(
self
.
trace
.
graph
,
self
.
trace
,
dummy_input
)
list_of_nodes
=
self
.
parse
(
self
.
trace
.
graph
,
self
.
trace
,
dummy_input
)
if
verbose
:
if
verbose
:
print
(
self
.
trace
.
graph
)
print
(
self
.
trace
.
graph
)
self
.
stepstats
=
RunMetadata
(
step_stats
=
StepStats
(
dev_stats
=
[
DeviceStepStats
(
device
=
"/device:CPU:0"
)]))
self
.
stepstats
=
RunMetadata
(
step_stats
=
StepStats
(
self
.
graph_def
=
GraphDef
(
node
=
list_of_nodes
,
versions
=
VersionDef
(
producer
=
22
))
dev_stats
=
[
DeviceStepStats
(
device
=
"/device:CPU:0"
)]))
self
.
graph_def
=
GraphDef
(
node
=
list_of_nodes
,
versions
=
VersionDef
(
producer
=
22
))
def
parse
(
self
,
graph
,
trace
,
args
=
None
,
omit_useless_nodes
=
True
):
def
parse
(
self
,
graph
,
trace
,
args
=
None
,
omit_useless_nodes
=
True
):
"""This method parses an optimized PyTorch model graph and produces
"""This method parses an optimized PyTorch model graph and produces
...
@@ -94,16 +114,20 @@ class TorchProtoGraph(TorchGraph):
...
@@ -94,16 +114,20 @@ class TorchProtoGraph(TorchGraph):
nodes_py
.
append
(
NodePyIO
(
node
,
'input'
))
nodes_py
.
append
(
NodePyIO
(
node
,
'input'
))
attr_to_scope
=
dict
()
attr_to_scope
=
dict
()
node_to_name
=
lambda
d
:
str
(
d
).
split
(
":"
)[
0
].
strip
()
def
node_to_name
(
d
):
return
str
(
d
).
split
(
":"
)[
0
].
strip
()
for
node
in
graph
.
nodes
():
for
node
in
graph
.
nodes
():
if
node
.
kind
()
==
GETATTR_KIND
:
if
node
.
kind
()
==
GETATTR_KIND
:
attr_name
=
node
.
s
(
'name'
)
attr_name
=
node
.
s
(
'name'
)
node_name
=
node_to_name
(
node
)
node_name
=
node_to_name
(
node
)
parent
=
node
.
input
().
node
()
parent
=
node
.
input
().
node
()
if
parent
.
kind
()
==
GETATTR_KIND
:
# If the parent node is not the top-level "self" node
# If the parent node is not the top-level "self" node
if
parent
.
kind
()
==
GETATTR_KIND
:
parent_scope
=
attr_to_scope
[
node_to_name
(
parent
)]
parent_scope
=
attr_to_scope
[
node_to_name
(
parent
)]
attr_scope
=
parent_scope
.
split
(
'/'
)[
-
1
]
attr_scope
=
parent_scope
.
split
(
'/'
)[
-
1
]
attr_to_scope
[
node_name
]
=
'{}/{}.{}'
.
format
(
parent_scope
,
attr_scope
,
attr_name
)
attr_to_scope
[
node_name
]
=
'{}/{}.{}'
.
format
(
parent_scope
,
attr_scope
,
attr_name
)
else
:
else
:
attr_to_scope
[
node_name
]
=
'__module.{}'
.
format
(
attr_name
)
attr_to_scope
[
node_name
]
=
'__module.{}'
.
format
(
attr_name
)
# We don't need classtype nodes; scope will provide this information
# We don't need classtype nodes; scope will provide this information
...
@@ -114,7 +138,8 @@ class TorchProtoGraph(TorchGraph):
...
@@ -114,7 +138,8 @@ class TorchProtoGraph(TorchGraph):
else
:
else
:
nodes_py
.
append
(
NodePyOP
(
node
))
nodes_py
.
append
(
NodePyOP
(
node
))
for
i
,
node
in
enumerate
(
graph
.
outputs
()):
# Create sink nodes for output ops
# Create sink nodes for output ops
for
i
,
node
in
enumerate
(
graph
.
outputs
()):
node_py
=
NodePyIO
(
node
,
'output'
)
node_py
=
NodePyIO
(
node
,
'output'
)
node_py
.
debugName
=
"output.{}"
.
format
(
i
+
1
)
node_py
.
debugName
=
"output.{}"
.
format
(
i
+
1
)
node_py
.
inputs
=
[
node
.
debugName
()]
node_py
.
inputs
=
[
node
.
debugName
()]
...
@@ -136,23 +161,33 @@ class TorchProtoGraph(TorchGraph):
...
@@ -136,23 +161,33 @@ class TorchProtoGraph(TorchGraph):
node
.
scopeName
=
base_name
node
.
scopeName
=
base_name
else
:
else
:
module_name
+=
'.'
+
alias
module_name
+=
'.'
+
alias
node
.
scopeName
+=
'/'
+
(
alias_to_name
[
module_name
]
if
module_name
in
alias_to_name
else
alias
)
node
.
scopeName
+=
'/'
+
\
(
alias_to_name
[
module_name
]
if
module_name
in
alias_to_name
else
alias
)
nodes_py
.
populate_namespace_from_OP_to_IO
()
nodes_py
.
populate_namespace_from_OP_to_IO
()
return
nodes_py
.
to_proto
()
return
nodes_py
.
to_proto
()
class
NodePyGroup
(
NodePy
):
class
NodePyGroup
(
NodePy
):
"""
"""
This class is used to represent a graph node which consists of multiple jit traced nodes. In a pytorch trace graph,
This class is used to represent a graph node which consists of multiple jit traced nodes. In a pytorch trace graph,
there are multiple nodes are traced for one torch.nn.Module object, we group them together to form a single node to
there are multiple nodes are traced for one torch.nn.Module object, we group them together to form a single node to
represent the torch.nn.Module object. We also group some functional call trace nodes together to form a new node.
represent the torch.nn.Module object. We also group some functional call trace nodes together to form a new node.
"""
"""
def
__init__
(
self
,
name
,
node_type
,
op_type
,
node_cpps
,
inputs
=
None
,
outputs
=
None
):
def
__init__
(
self
,
name
,
unique_name
,
node_type
,
op_type
,
node_cpps
,
inputs
=
None
,
outputs
=
None
):
"""
"""
Parameters:
Parameters:
-----------
-----------
name: str
name: str
node name, such as `conv1`, `backbone.classifier`
node name, such as `conv1`, `backbone.classifier`
unique_name: str
A global unique name for current node. Due to some modules,
such as relu, may be reused several times, so the scopename
is not suitable as the global unique identifier, so we add a
unique_name for each node as the global unique identifier.
We should use the unique_name to traverset the module graph.
node_type: str
node_type: str
`module` or `func`
`module` or `func`
op_type: str
op_type: str
...
@@ -167,6 +202,7 @@ class NodePyGroup(NodePy):
...
@@ -167,6 +202,7 @@ class NodePyGroup(NodePy):
super
(
NodePyGroup
,
self
).
__init__
(
name
,
[])
super
(
NodePyGroup
,
self
).
__init__
(
name
,
[])
self
.
node_cpps
=
node_cpps
self
.
node_cpps
=
node_cpps
self
.
name
=
name
self
.
name
=
name
self
.
unique_name
=
unique_name
self
.
op_type
=
op_type
self
.
op_type
=
op_type
self
.
type
=
node_type
self
.
type
=
node_type
self
.
nodes
=
[]
self
.
nodes
=
[]
...
@@ -178,7 +214,7 @@ class NodePyGroup(NodePy):
...
@@ -178,7 +214,7 @@ class NodePyGroup(NodePy):
def
add_nodes
(
self
,
node_cpps
):
def
add_nodes
(
self
,
node_cpps
):
for
node_cpp
in
node_cpps
:
for
node_cpp
in
node_cpps
:
nodepy
=
NodePyOP
(
node_cpp
)
nodepy
=
NodePyOP
(
node_cpp
)
nodepy
.
name
=
str
(
node_cpp
)
.
s
plit
(
':'
)[
0
].
strip
().
replace
(
'%'
,
''
)
nodepy
.
name
=
node_cpp
.
s
copeName
()
+
'_'
+
node_cpp
.
kind
(
)
self
.
nodes
.
append
(
nodepy
)
self
.
nodes
.
append
(
nodepy
)
def
sub_node_names
(
self
):
def
sub_node_names
(
self
):
...
@@ -186,7 +222,8 @@ class NodePyGroup(NodePy):
...
@@ -186,7 +222,8 @@ class NodePyGroup(NodePy):
def
__repr__
(
self
):
def
__repr__
(
self
):
return
'name: {}, type: {}, op_type: {}, sub_nodes: {}, inputs: {}, outputs: {}, aux: {}'
.
format
(
return
'name: {}, type: {}, op_type: {}, sub_nodes: {}, inputs: {}, outputs: {}, aux: {}'
.
format
(
self
.
name
,
self
.
type
,
self
.
op_type
,
self
.
sub_node_names
(),
self
.
inputs
,
self
.
outputs
,
self
.
auxiliary
self
.
name
,
self
.
type
,
self
.
op_type
,
self
.
sub_node_names
(),
self
.
inputs
,
self
.
outputs
,
self
.
auxiliary
)
)
...
@@ -194,12 +231,14 @@ class TorchModuleGraph(TorchGraph):
...
@@ -194,12 +231,14 @@ class TorchModuleGraph(TorchGraph):
"""
"""
Generates model graph, each node is created from single or multiple jit trace nodes.
Generates model graph, each node is created from single or multiple jit trace nodes.
"""
"""
def
__init__
(
self
,
model
,
dummy_input
):
super
().
__init__
(
model
,
dummy_input
)
def
__init__
(
self
,
model
=
None
,
dummy_input
=
None
,
traced_model
=
None
):
super
().
__init__
(
model
,
dummy_input
,
traced_model
)
self
.
global_count
=
0
self
.
global_count
=
0
self
.
name_to_node
,
self
.
input_to_node
,
self
.
output_to_node
=
self
.
_build_graph
()
self
.
name_to_node
,
self
.
input_to_node
,
self
.
output_to_node
=
self
.
_build_graph
()
def
_expand_non_prim_node
(
self
,
node
,
nodes
,
input_to_node
,
output_to_node
):
def
_expand_non_prim_node
(
self
,
node
,
nodes
,
input_to_node
,
output_to_node
,
module_type
):
"""
"""
For trace graph nodes, some nodes are not in modules, these nodes are usually generated by
For trace graph nodes, some nodes are not in modules, these nodes are usually generated by
the functions directly called in module ```forward```. For such nodes, some of them are
the functions directly called in module ```forward```. For such nodes, some of them are
...
@@ -217,6 +256,8 @@ class TorchModuleGraph(TorchGraph):
...
@@ -217,6 +256,8 @@ class TorchModuleGraph(TorchGraph):
key: input name, value: a node that uses this input
key: input name, value: a node that uses this input
output_to_node : dict
output_to_node : dict
key: output name, value: a node that generates this output
key: output name, value: a node that generates this output
module_type : str
can be 'module' or 'func'
Returns
Returns
-------
-------
...
@@ -224,11 +265,12 @@ class TorchModuleGraph(TorchGraph):
...
@@ -224,11 +265,12 @@ class TorchModuleGraph(TorchGraph):
the expanded non-prim node
the expanded non-prim node
"""
"""
# TODO: scope name could be empty
# TODO: scope name could be empty
node_name
=
'.'
.
join
([
self
.
_get_module_name
(
node
.
scopeName
()),
node
.
kind
(),
str
(
self
.
global_count
)])
node_name
=
'.'
.
join
([
self
.
_get_module_name
(
node
.
scopeName
()),
node
.
kind
(),
str
(
self
.
global_count
)])
unique_name
=
node_name
_logger
.
debug
(
"expand non-prim node, node name: %s"
,
node_name
)
_logger
.
debug
(
"expand non-prim node, node name: %s"
,
node_name
)
self
.
global_count
+=
1
self
.
global_count
+=
1
op_type
=
node
.
kind
()
op_type
=
node
.
kind
()
node_group
=
[
node
]
node_group
=
[
node
]
inputs
=
list
()
inputs
=
list
()
outputs
=
list
()
outputs
=
list
()
...
@@ -249,28 +291,78 @@ class TorchModuleGraph(TorchGraph):
...
@@ -249,28 +291,78 @@ class TorchModuleGraph(TorchGraph):
inputs
.
append
(
input_name
)
inputs
.
append
(
input_name
)
for
output
in
node
.
outputs
():
for
output
in
node
.
outputs
():
outputs
.
append
(
output
.
debugName
())
outputs
.
append
(
output
.
debugName
())
nodepy
=
NodePyGroup
(
node_name
,
'func'
,
op_type
,
node_group
,
inputs
=
inputs
,
outputs
=
outputs
)
nodepy
=
NodePyGroup
(
node_name
,
unique_name
,
module_type
,
op_type
,
node_group
,
inputs
=
inputs
,
outputs
=
outputs
)
return
nodepy
return
nodepy
def
_build_module_node_group
(
self
,
module_name
,
op_type
,
node_cpps
,
input_to_node
,
output_to_node
):
def
_expand_module_node
(
self
,
node
,
node_name
,
unique_name
,
op_type
,
nodes
,
graph
=
self
.
trace
.
graph
input_to_node
,
output_to_node
,
module_type
):
inputs
,
outputs
=
[],
[]
"""
for
n
in
node_cpps
:
merge the adjacent nodes of the module. The difference between the
for
i
in
n
.
inputs
():
_expand_module_node and _expand_non_prim_node is that, the _expand_non_prim_node
name
=
i
.
debugName
()
only merge the prim:: nodes into the aten:: node, in contrast,the _expand_module_node
if
not
name
in
output_to_node
and
i
in
graph
.
inputs
():
will merge all adjacent nodes into a same nodepy group.
inputs
.
append
(
name
)
elif
output_to_node
[
name
]
not
in
node_cpps
:
inputs
.
append
(
name
)
for
o
in
n
.
outputs
():
name
=
o
.
debugName
()
if
not
name
in
input_to_node
and
o
in
graph
.
outputs
():
outputs
.
append
(
name
)
elif
input_to_node
[
name
]
not
in
node_cpps
:
outputs
.
append
(
name
)
return
NodePyGroup
(
module_name
,
'module'
,
op_type
,
node_cpps
,
inputs
,
outputs
)
Parameters
----------
node : trace graph node
The non-prim node to expand
node_name : str
specify the node_name for NodePyGroup
unique_name : str
unique_name for the NodePyGroup
op_type : str
specify the op_type for the NodePyGroup
nodes : list of trace graph node
All the trace graph nodes within the same scope as the non-prim node
input_to_node : dict
key: input name, value: a node that uses this input
output_to_node : dict
key: output name, value: a node that generates this output
module_type : str
can be 'module' or 'func'
Returns
-------
node
the expanded non-prim node
"""
_logger
.
debug
(
"expand module node, node name: %s"
,
node_name
)
self
.
global_count
+=
1
if
not
op_type
:
op_type
=
node
.
kind
()
node_group
=
[
node
]
inputs
=
list
()
outputs
=
list
()
node_queue
=
queue
.
Queue
()
node_queue
.
put
(
node
)
visited
=
{
node
}
while
not
node_queue
.
empty
():
curr_node
=
node_queue
.
get
()
for
_input
in
curr_node
.
inputs
():
input_name
=
_input
.
debugName
()
if
input_name
in
output_to_node
and
output_to_node
[
input_name
]
in
nodes
:
predecessor_node
=
output_to_node
[
input_name
]
if
predecessor_node
not
in
visited
:
node_group
.
append
(
predecessor_node
)
node_queue
.
put
(
predecessor_node
)
visited
.
add
(
predecessor_node
)
else
:
inputs
.
append
(
input_name
)
for
_output
in
curr_node
.
outputs
():
output_name
=
_output
.
debugName
()
if
output_name
in
input_to_node
and
input_to_node
[
output_name
]
in
nodes
:
successor_node
=
input_to_node
[
output_name
]
if
successor_node
not
in
visited
:
node_group
.
append
(
successor_node
)
node_queue
.
put
(
successor_node
)
visited
.
add
(
successor_node
)
else
:
outputs
.
append
(
output_name
)
nodepy
=
NodePyGroup
(
node_name
,
unique_name
,
module_type
,
op_type
,
node_group
,
inputs
=
inputs
,
outputs
=
outputs
)
return
nodepy
def
_extract_shape_info
(
self
,
node
):
def
_extract_shape_info
(
self
,
node
):
"""
"""
...
@@ -318,11 +410,12 @@ class TorchModuleGraph(TorchGraph):
...
@@ -318,11 +410,12 @@ class TorchModuleGraph(TorchGraph):
parts1
,
parts2
=
name1
.
split
(
'.'
),
name2
.
split
(
'.'
)
parts1
,
parts2
=
name1
.
split
(
'.'
),
name2
.
split
(
'.'
)
if
len
(
parts1
)
>=
len
(
parts2
):
if
len
(
parts1
)
>=
len
(
parts2
):
return
False
return
False
for
i
in
range
(
len
(
parts1
)
)
:
for
i
,
_
in
enumerate
(
parts1
):
if
parts2
[
i
]
!=
parts1
[
i
]:
if
parts2
[
i
]
!=
parts1
[
i
]:
return
False
return
False
return
True
return
True
module_names
=
sorted
([
x
[
0
]
for
x
in
self
.
trace
.
named_modules
()
if
x
[
0
]])
module_names
=
sorted
([
x
[
0
]
for
x
in
self
.
trace
.
named_modules
()
if
x
[
0
]])
leaf_nodes
=
[]
leaf_nodes
=
[]
for
i
,
name
in
enumerate
(
module_names
):
for
i
,
name
in
enumerate
(
module_names
):
if
i
+
1
>=
len
(
module_names
)
or
not
is_parent
(
name
,
module_names
[
i
+
1
]):
if
i
+
1
>=
len
(
module_names
)
or
not
is_parent
(
name
,
module_names
[
i
+
1
]):
...
@@ -354,7 +447,7 @@ class TorchModuleGraph(TorchGraph):
...
@@ -354,7 +447,7 @@ class TorchModuleGraph(TorchGraph):
input_to_node
=
defaultdict
(
list
)
input_to_node
=
defaultdict
(
list
)
output_to_node
=
dict
()
output_to_node
=
dict
()
for
node
in
nodes_op
:
for
node
in
nodes_op
:
name_to_node
[
node
.
name
]
=
node
name_to_node
[
node
.
unique_
name
]
=
node
for
_input
in
node
.
inputs
:
for
_input
in
node
.
inputs
:
input_to_node
[
_input
].
append
(
node
)
input_to_node
[
_input
].
append
(
node
)
for
output
in
node
.
outputs
:
for
output
in
node
.
outputs
:
...
@@ -385,9 +478,11 @@ class TorchModuleGraph(TorchGraph):
...
@@ -385,9 +478,11 @@ class TorchModuleGraph(TorchGraph):
graph
=
self
.
trace
.
graph
graph
=
self
.
trace
.
graph
_logger
.
debug
(
graph
)
_logger
.
debug
(
graph
)
# build output mapping, from output debugName to its node
# build output mapping, from output debugName to its node
output_to_node
=
{
x
.
debugName
():
n
for
n
in
graph
.
nodes
()
for
x
in
n
.
outputs
()}
output_to_node
=
{
x
.
debugName
():
n
for
n
in
graph
.
nodes
()
for
x
in
n
.
outputs
()}
# build input mapping, from input debugName to its node
# build input mapping, from input debugName to its node
input_to_node
=
{
x
.
debugName
():
n
for
n
in
graph
.
nodes
()
for
x
in
n
.
inputs
()}
input_to_node
=
{
x
.
debugName
():
n
for
n
in
graph
.
nodes
()
for
x
in
n
.
inputs
()}
# build module mapping, from module name to all nodes (as list) under this module scope
# build module mapping, from module name to all nodes (as list) under this module scope
module_to_nodes
=
defaultdict
(
list
)
module_to_nodes
=
defaultdict
(
list
)
# the mapping of function (non-module in forward) to nodes, key is scope name
# the mapping of function (non-module in forward) to nodes, key is scope name
...
@@ -403,7 +498,8 @@ class TorchModuleGraph(TorchGraph):
...
@@ -403,7 +498,8 @@ class TorchModuleGraph(TorchGraph):
nodes_py
.
append
(
NodePyIO
(
node
,
'input'
))
nodes_py
.
append
(
NodePyIO
(
node
,
'input'
))
self
.
leaf_modules
=
self
.
_extract_leaf_modules
()
self
.
leaf_modules
=
self
.
_extract_leaf_modules
()
module_to_type
=
{
name
:
parse_traced_name
(
module
.
_name
)
for
name
,
module
in
self
.
trace
.
named_modules
()}
module_to_type
=
{
name
:
parse_traced_name
(
module
.
_name
)
for
name
,
module
in
self
.
trace
.
named_modules
()}
# associate module name with their trace graph nodes
# associate module name with their trace graph nodes
for
node
in
graph
.
nodes
():
for
node
in
graph
.
nodes
():
...
@@ -412,14 +508,24 @@ class TorchModuleGraph(TorchGraph):
...
@@ -412,14 +508,24 @@ class TorchModuleGraph(TorchGraph):
module_to_nodes
[
module_name
].
append
(
node
)
module_to_nodes
[
module_name
].
append
(
node
)
else
:
else
:
func_to_nodes
[
node
.
scopeName
()].
append
(
node
)
func_to_nodes
[
node
.
scopeName
()].
append
(
node
)
# build node group for module
# build node group for module
for
module_name
,
node_cpps
in
module_to_nodes
.
items
():
for
module_name
,
node_cpps
in
module_to_nodes
.
items
():
node_group
=
self
.
_build_module_node_group
(
use_count
=
0
module_name
,
module_to_type
[
module_name
],
node_cpps
,
input_to_node
,
output_to_node
merged
=
set
()
)
for
node
in
node_cpps
:
_logger
.
debug
(
'node_group: %s'
,
node_group
)
if
node
not
in
merged
:
# modules that have same scope name may have different locations in the
# graph. Futhermore, there are also lots of prim:: nodes that in node_cpps,
# so we also need to call the expand_module_node.
unique_name
=
module_name
if
use_count
>
0
:
unique_name
=
module_name
+
'.%d'
%
use_count
node_group
=
self
.
_expand_module_node
(
node
,
module_name
,
unique_name
,
module_to_type
[
module_name
],
node_cpps
,
input_to_node
,
output_to_node
,
'module'
)
nodes_py
.
nodes_op
.
append
(
node_group
)
nodes_py
.
nodes_op
.
append
(
node_group
)
use_count
+=
1
merged
.
update
(
node_group
.
node_cpps
)
# each scope_name may have multiple funcs, we split them and create node for each of them
# each scope_name may have multiple funcs, we split them and create node for each of them
# build node group for torch.nn.functional
# build node group for torch.nn.functional
...
@@ -431,11 +537,13 @@ class TorchModuleGraph(TorchGraph):
...
@@ -431,11 +537,13 @@ class TorchModuleGraph(TorchGraph):
non_prim_nodes
.
append
(
node
)
non_prim_nodes
.
append
(
node
)
# for each non prim node, expand it
# for each non prim node, expand it
for
node
in
non_prim_nodes
:
for
node
in
non_prim_nodes
:
node_group
=
self
.
_expand_non_prim_node
(
node
,
nodes
,
input_to_node
,
output_to_node
)
node_group
=
self
.
_expand_non_prim_node
(
node
,
nodes
,
input_to_node
,
output_to_node
,
'func'
)
nodes_py
.
nodes_op
.
append
(
node_group
)
nodes_py
.
nodes_op
.
append
(
node_group
)
# get shape infor for view (aten::view) func
# get shape infor for view (aten::view) func
if
node_group
.
op_type
in
[
'aten::view'
,
'aten::flatten'
]:
if
node_group
.
op_type
in
[
'aten::view'
,
'aten::flatten'
]:
node_group
.
auxiliary
=
self
.
_extract_shape_info
(
node
)
node_group
.
auxiliary
=
self
.
_extract_shape_info
(
node
)
for
node
in
graph
.
outputs
():
# Create sink nodes for output ops
for
node
in
graph
.
outputs
():
# Create sink nodes for output ops
node_py
=
NodePyIO
(
node
,
'output'
)
node_py
=
NodePyIO
(
node
,
'output'
)
nodes_py
.
append
(
node_py
)
nodes_py
.
append
(
node_py
)
...
@@ -444,14 +552,14 @@ class TorchModuleGraph(TorchGraph):
...
@@ -444,14 +552,14 @@ class TorchModuleGraph(TorchGraph):
# build index
# build index
return
self
.
_build_index
(
self
.
nodes_py
.
nodes_op
)
return
self
.
_build_index
(
self
.
nodes_py
.
nodes_op
)
def
find_predecessors
(
self
,
modul
e_name
):
def
find_predecessors
(
self
,
uniqu
e_name
):
"""
"""
Find predecessor node of the given node
Find predecessor node of the given node
Parameters
Parameters
----------
----------
modul
e_name : str
uniqu
e_name : str
The name of the node
The
unique
name of the node
Returns
Returns
-------
-------
...
@@ -459,22 +567,22 @@ class TorchModuleGraph(TorchGraph):
...
@@ -459,22 +567,22 @@ class TorchModuleGraph(TorchGraph):
a list of nodes who are the given node's predecessor
a list of nodes who are the given node's predecessor
"""
"""
predecessors
=
[]
predecessors
=
[]
for
_input
in
self
.
name_to_node
[
modul
e_name
].
inputs
:
for
_input
in
self
.
name_to_node
[
uniqu
e_name
].
inputs
:
if
not
_input
in
self
.
output_to_node
:
if
not
_input
in
self
.
output_to_node
:
_logger
.
debug
(
"cannot find node with %s as its output"
,
_input
)
_logger
.
debug
(
"cannot find node with %s as its output"
,
_input
)
else
:
else
:
node_py
=
self
.
output_to_node
[
_input
]
node_py
=
self
.
output_to_node
[
_input
]
predecessors
.
append
(
node_py
.
name
)
predecessors
.
append
(
node_py
.
unique_
name
)
return
predecessors
return
predecessors
def
find_successors
(
self
,
modul
e_name
):
def
find_successors
(
self
,
uniqu
e_name
):
"""
"""
Find successor nodes of the given node
Find successor nodes of the given node
Parameters
Parameters
----------
----------
modul
e_name : str
uniqu
e_name : str
The name of the node
The
unique
name of the node
Returns
Returns
-------
-------
...
@@ -482,9 +590,11 @@ class TorchModuleGraph(TorchGraph):
...
@@ -482,9 +590,11 @@ class TorchModuleGraph(TorchGraph):
a list of nodes who are the given node's successor
a list of nodes who are the given node's successor
"""
"""
successors
=
[]
successors
=
[]
for
output
in
self
.
name_to_node
[
module_name
].
outputs
:
for
output
in
self
.
name_to_node
[
unique_name
].
outputs
:
assert
output
in
self
.
input_to_node
,
"No node with input {}"
.
format
(
output
)
if
output
not
in
self
.
input_to_node
:
# may reach the output of the whole graph
continue
nodes_py
=
self
.
input_to_node
[
output
]
nodes_py
=
self
.
input_to_node
[
output
]
for
node_py
in
nodes_py
:
for
node_py
in
nodes_py
:
successors
.
append
(
node_py
.
name
)
successors
.
append
(
node_py
.
unique_
name
)
return
successors
return
successors
src/sdk/pynni/tests/test_graph_utils.py
View file @
7ee5036b
...
@@ -15,7 +15,7 @@ from google.protobuf import text_format
...
@@ -15,7 +15,7 @@ from google.protobuf import text_format
import
unittest
import
unittest
from
unittest
import
TestCase
,
main
from
unittest
import
TestCase
,
main
from
nni._graph_utils
import
build_module_graph
,
build_graph
from
nni._graph_utils
import
build_module_graph
,
build_graph
,
TorchModuleGraph
class
BackboneModel1
(
nn
.
Module
):
class
BackboneModel1
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -154,5 +154,45 @@ class GraphUtilsTestCase(TestCase):
...
@@ -154,5 +154,45 @@ class GraphUtilsTestCase(TestCase):
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"expect"
,
"test_graph_module3.expect"
)
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"expect"
,
"test_graph_module3.expect"
)
)
)
@
unittest
.
skipIf
(
torch
.
__version__
<
"1.4.0"
,
"not supported"
)
def
test_module_reuse
(
self
):
class
MyModule
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
liner1
=
nn
.
Linear
(
10
,
10
)
self
.
relu
=
nn
.
ReLU
(
inplace
=
True
)
self
.
liner2
=
nn
.
Linear
(
10
,
20
)
self
.
liner3
=
nn
.
Linear
(
20
,
10
)
def
forward
(
self
,
x
):
x
=
self
.
liner1
(
x
)
x
=
self
.
relu
(
x
)
x
=
self
.
liner2
(
x
)
x
=
self
.
relu
(
x
)
x
=
self
.
liner3
(
x
)
x
=
self
.
relu
(
x
)
return
x
data
=
torch
.
rand
(
10
,
10
)
net
=
MyModule
()
traced
=
torch
.
jit
.
trace
(
net
,
data
)
modulegraph
=
TorchModuleGraph
(
traced_model
=
traced
)
# Traverse the TorchModuleGraph, due the resue of the relu module,
# there will be three cpp_nodes corrspoding to the same module.
# During traversing the graph, there should be only one
# successor of each cpp-node (including the cpp_nodes that corresponds
# to the same relu module).
for
name
,
nodeio
in
modulegraph
.
nodes_py
.
nodes_io
.
items
():
if
nodeio
.
input_or_output
==
'input'
:
# Find the first node of the whole graph
start_nodes
=
modulegraph
.
input_to_node
[
name
]
# We have only one single path top-down
assert
len
(
start_nodes
)
==
1
node
=
start_nodes
[
0
].
unique_name
while
modulegraph
.
find_successors
(
node
):
nodes
=
modulegraph
.
find_successors
(
node
)
assert
len
(
nodes
)
==
1
node
=
nodes
[
0
]
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
main
()
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment