Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
efa4e31c
Unverified
Commit
efa4e31c
authored
Nov 20, 2020
by
QuanluZhang
Committed by
GitHub
Nov 20, 2020
Browse files
[Retiarii] refactor convert_name (#3101)
parent
8af73146
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
55 additions
and
57 deletions
+55
-57
nni/retiarii/codegen/pytorch.py
nni/retiarii/codegen/pytorch.py
+12
-20
nni/retiarii/converter/graph_gen.py
nni/retiarii/converter/graph_gen.py
+3
-3
nni/retiarii/converter/utils.py
nni/retiarii/converter/utils.py
+8
-2
nni/retiarii/execution/listener.py
nni/retiarii/execution/listener.py
+1
-1
nni/retiarii/graph.py
nni/retiarii/graph.py
+27
-23
nni/retiarii/operation.py
nni/retiarii/operation.py
+4
-8
No files found.
nni/retiarii/codegen/pytorch.py
View file @
efa4e31c
import
logging
from
typing
import
*
from
..graph
import
IllegalGraphError
,
Edge
,
Graph
,
Node
,
Model
from
..operation
import
Operation
,
Cell
# TODO: fix: inputs is a list, how to deal with single element list and single element
_logger
=
logging
.
getLogger
(
__name__
)
def
model_to_pytorch_script
(
model
:
Model
)
->
str
:
graphs
=
[]
...
...
@@ -16,17 +18,9 @@ def model_to_pytorch_script(model: Model) -> str:
pkgs_code
=
'
\n
'
.
join
([
'import {}'
.
format
(
pkg
)
for
pkg
in
total_pkgs
])
return
_PyTorchScriptTemplate
.
format
(
pkgs_code
,
'
\n\n
'
.
join
(
graphs
)).
strip
()
def
_convert_name
(
name
:
str
)
->
str
:
"""
Convert the names using separator '.' to valid variable name in code
"""
return
name
.
replace
(
'.'
,
'__'
)
def
_convert_names
(
names
:
List
[
str
])
->
List
[
str
]:
return
[
_convert_name
(
name
)
for
name
in
names
]
def
_sorted_incoming_edges
(
node
:
Node
)
->
List
[
Edge
]:
edges
=
[
edge
for
edge
in
node
.
graph
.
edges
if
edge
.
tail
is
node
]
_logger
.
info
(
'sorted_incoming_edges: {}'
.
format
(
edges
))
if
not
edges
:
return
[]
if
all
(
edge
.
tail_slot
is
None
for
edge
in
edges
):
...
...
@@ -43,9 +37,9 @@ def _format_inputs(node: Node) -> List[str]:
for
edge
in
edges
:
if
edge
.
head
.
name
==
'_inputs'
:
assert
isinstance
(
edge
.
head_slot
,
int
)
if
node
.
graph
.
input
_names
is
not
None
:
if
edge
.
head
.
operation
.
io
_names
is
not
None
:
# when input has names, e.g., forward(self, tensor1, tensor2, another_one)
inputs
.
append
(
node
.
graph
.
input
_names
[
edge
.
head_slot
])
inputs
.
append
(
edge
.
head
.
operation
.
io
_names
[
edge
.
head_slot
])
else
:
# when input has no name, e.g., forward(*_inputs)
inputs
.
append
(
'_inputs[{}]'
.
format
(
edge
.
head_slot
))
...
...
@@ -59,7 +53,7 @@ def _format_inputs(node: Node) -> List[str]:
return
inputs
def
graph_to_pytorch_model
(
graph_name
:
str
,
graph
:
Graph
)
->
str
:
nodes
=
graph
.
nodes
# FIXME: topological sort is needed here
nodes
=
graph
.
nodes
# handle module node and function node differently
# only need to generate code for module here
...
...
@@ -74,11 +68,10 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph) -> str:
if
node_code
is
not
None
:
node_codes
.
append
(
node_code
)
if
graph
.
input_names
is
None
:
if
graph
.
input_
node
.
operation
.
io_
names
is
None
:
input_code
=
'*_inputs'
else
:
# TODO: remove _convert_names (after merging input_names and input_node)
input_code
=
', '
.
join
(
_convert_names
(
graph
.
input_names
))
input_code
=
', '
.
join
(
graph
.
input_node
.
operation
.
io_names
)
edge_codes
=
[]
sorted_nodes
=
graph
.
topo_sort
()
...
...
@@ -87,15 +80,14 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph) -> str:
inputs
=
_format_inputs
(
node
)
edge_codes
.
append
(
node
.
operation
.
to_forward_code
(
node
.
name
,
node
.
name
,
inputs
))
# TODO: refactor graph output_node
output_names
=
_format_inputs
(
graph
.
output_node
)
output_names
=
_convert_names
(
output_names
)
if
not
output_names
:
output_names
=
[
'None'
]
raise
RuntimeError
(
'"forward" function should have return value(s): {}, {}, {}'
.
format
(
output_names
,
graph_name
,
graph
.
output_node
))
output_code
=
', '
.
join
(
output_names
)
linebreak
=
'
\n
'
return
import_pkgs
,
_PyTorchModelTemplate
.
format
(
graph_name
=
(
'Graph'
if
graph_name
==
'_graph'
else
_convert_name
(
graph_name
)
)
,
graph_name
=
(
'Graph'
if
graph_name
==
'_graph'
else
graph_name
),
inputs
=
input_code
,
outputs
=
', '
.
join
(
output_names
),
nodes
=
linebreak
.
join
(
node_codes
),
...
...
nni/retiarii/converter/graph_gen.py
View file @
efa4e31c
...
...
@@ -7,7 +7,7 @@ from ..operation import Cell, Operation
from
..model_apis.nn
import
Placeholder
from
.op_types
import
RETIARII_BASE_OPS
,
MODULE_EXCEPT_LIST
,
Type
from
.utils
import
build_full_name
from
.utils
import
build_full_name
,
_convert_name
global_seq
=
0
...
...
@@ -149,7 +149,7 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
continue
graph_inputs
.
append
(
_input
)
# TODO: add scope name
ir_graph
.
_add_input
(
_input
.
debugName
())
ir_graph
.
_add_input
(
_
convert_name
(
_
input
.
debugName
())
)
node_index
=
{}
# graph node to graph ir node
...
...
@@ -315,7 +315,7 @@ def convert_module(script_module, module, module_name, ir_model):
graph_outputs
=
[]
for
_output
in
sm_graph
.
outputs
():
graph_outputs
.
append
(
_output
)
# <class 'torch._C.Value'>
ir_graph
.
_add_output
(
_output
.
debugName
())
ir_graph
.
_add_output
(
_
convert_name
(
_
output
.
debugName
())
)
predecessor_node_outputs
=
[
o
for
o
in
_output
.
node
().
outputs
()]
if
len
(
predecessor_node_outputs
)
==
1
:
src_node_idx
=
None
...
...
nni/retiarii/converter/utils.py
View file @
efa4e31c
def
build_full_name
(
prefix
,
name
,
seq
=
None
):
if
seq
is
None
:
return
'{}
.
{}'
.
format
(
prefix
,
name
)
return
'{}
__
{}'
.
format
(
prefix
,
name
)
else
:
return
'{}.{}{}'
.
format
(
prefix
,
name
,
str
(
seq
))
\ No newline at end of file
return
'{}__{}{}'
.
format
(
prefix
,
name
,
str
(
seq
))
def
_convert_name
(
name
:
str
)
->
str
:
"""
Convert the names using separator '.' to valid variable name in code
"""
return
name
.
replace
(
'.'
,
'__'
)
nni/retiarii/execution/listener.py
View file @
efa4e31c
...
...
@@ -13,7 +13,7 @@ class DefaultListener(AbstractGraphListener):
def
on_metric
(
self
,
model
:
Model
,
metric
:
MetricData
)
->
None
:
model
.
metric
=
metric
def
on_intermediate_metric
(
self
,
model
:
Model
,
metric
:
MetricData
)
->
None
:
model
.
intermediate_metrics
.
append
(
metric
)
...
...
nni/retiarii/graph.py
View file @
efa4e31c
...
...
@@ -7,7 +7,7 @@ from enum import Enum
import
json
from
typing
import
(
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
,
overload
)
from
.operation
import
Cell
,
Operation
,
_PseudoOperation
from
.operation
import
Cell
,
Operation
,
_
IO
PseudoOperation
__all__
=
[
'Model'
,
'ModelStatus'
,
'Graph'
,
'Node'
,
'Edge'
,
'IllegalGraphError'
,
'MetricData'
]
...
...
@@ -233,35 +233,33 @@ class Graph:
self
.
id
:
int
=
graph_id
self
.
name
:
str
=
name
or
f
'_generated_
{
graph_id
}
'
# TODO: why not merge the names into input_node and output_node???
self
.
input_names
:
Optional
[
List
[
str
]]
=
None
self
.
output_names
:
Optional
[
List
[
str
]]
=
None
self
.
input_node
:
Node
=
Node
(
self
,
_InputPseudoUid
,
'_inputs'
,
_PseudoOperation
(
'_inputs'
),
_internal
=
True
)
self
.
output_node
:
Node
=
Node
(
self
,
_OutputPseudoUid
,
'_outputs'
,
_PseudoOperation
(
'_outputs'
),
_internal
=
True
)
self
.
input_node
:
Node
=
Node
(
self
,
_InputPseudoUid
,
'_inputs'
,
_IOPseudoOperation
(
'_inputs'
),
_internal
=
True
)
self
.
output_node
:
Node
=
Node
(
self
,
_OutputPseudoUid
,
'_outputs'
,
_IOPseudoOperation
(
'_outputs'
),
_internal
=
True
)
self
.
hidden_nodes
:
List
[
Node
]
=
[]
self
.
edges
:
List
[
Edge
]
=
[]
def
__repr__
(
self
):
return
f
'Graph(id=
{
self
.
id
}
, name=
{
self
.
name
}
, input_names=
{
self
.
input_names
}
, '
+
\
f
'output_names=
{
self
.
output_names
}
, num_hidden_nodes=
{
len
(
self
.
hidden_nodes
)
}
, num_edges=
{
len
(
self
.
edges
)
}
)'
return
f
'Graph(id=
{
self
.
id
}
, name=
{
self
.
name
}
, '
+
\
f
'input_names=
{
self
.
input_node
.
operation
.
io_names
}
, '
+
\
f
'output_names=
{
self
.
output_node
.
operation
.
io_names
}
, '
+
\
f
'num_hidden_nodes=
{
len
(
self
.
hidden_nodes
)
}
, num_edges=
{
len
(
self
.
edges
)
}
)'
@
property
def
nodes
(
self
)
->
List
[
'Node'
]:
return
[
self
.
input_node
,
self
.
output_node
]
+
self
.
hidden_nodes
def
_add_input
(
self
,
input_name
)
->
None
:
if
self
.
input_names
is
None
:
self
.
input_names
=
[
input_name
]
if
self
.
input_
node
.
operation
.
io_
names
is
None
:
self
.
input_
node
.
operation
.
io_
names
=
[
input_name
]
else
:
self
.
input_names
.
append
(
input_name
)
self
.
input_
node
.
operation
.
io_
names
.
append
(
input_name
)
def
_add_output
(
self
,
output_name
)
->
None
:
if
self
.
output_names
is
None
:
self
.
output_names
=
[
output_name
]
if
self
.
output_
node
.
operation
.
io_
names
is
None
:
self
.
output_
node
.
operation
.
io_
names
=
[
output_name
]
else
:
self
.
output_names
.
append
(
output_name
)
self
.
output_
node
.
operation
.
io_
names
.
append
(
output_name
)
@
overload
def
add_node
(
self
,
name
:
str
,
operation
:
Operation
)
->
'Node'
:
...
...
...
@@ -351,8 +349,11 @@ class Graph:
def
_fork_to
(
self
,
model
:
Model
)
->
'Graph'
:
new_graph
=
Graph
(
model
,
self
.
id
,
self
.
name
,
_internal
=
True
).
_register
()
new_graph
.
input_names
=
self
.
input_names
new_graph
.
output_names
=
self
.
output_names
# TODO: use node copy instead
new_graph
.
input_node
.
operation
.
io_names
=
self
.
input_node
.
operation
.
io_names
new_graph
.
output_node
.
operation
.
io_names
=
self
.
output_node
.
operation
.
io_names
new_graph
.
input_node
.
update_label
(
self
.
input_node
.
label
)
new_graph
.
output_node
.
update_label
(
self
.
output_node
.
label
)
for
node
in
self
.
hidden_nodes
:
new_node
=
Node
(
new_graph
,
node
.
id
,
node
.
name
,
node
.
operation
,
_internal
=
True
)
...
...
@@ -372,13 +373,16 @@ class Graph:
# Copy this graph inside the model.
# The new graph will have identical topology, but its nodes' name and ID will be different.
new_graph
=
Graph
(
self
.
model
,
self
.
model
.
_uid
(),
_internal
=
True
).
_register
()
new_graph
.
input_names
=
self
.
input_names
new_graph
.
output_names
=
self
.
output_names
new_graph
.
input_node
.
operation
.
io_names
=
self
.
input_node
.
operation
.
io_names
new_graph
.
output_node
.
operation
.
io_names
=
self
.
output_node
.
operation
.
io_names
new_graph
.
input_node
.
update_label
(
self
.
input_node
.
label
)
new_graph
.
output_node
.
update_label
(
self
.
output_node
.
label
)
id_to_new_node
=
{}
# old node ID -> new node object
for
old_node
in
self
.
hidden_nodes
:
new_node
=
Node
(
new_graph
,
self
.
model
.
_uid
(),
None
,
old_node
.
operation
,
_internal
=
True
).
_register
()
new_node
.
update_label
(
old_node
.
label
)
id_to_new_node
[
old_node
.
id
]
=
new_node
for
edge
in
self
.
edges
:
...
...
@@ -395,8 +399,8 @@ class Graph:
@
staticmethod
def
_load
(
model
:
Model
,
name
:
str
,
ir
:
Any
)
->
'Graph'
:
graph
=
Graph
(
model
,
model
.
_uid
(),
name
,
_internal
=
True
)
graph
.
input_names
=
ir
.
get
(
'inputs'
)
graph
.
output_names
=
ir
.
get
(
'outputs'
)
graph
.
input_
node
.
operation
.
io_
names
=
ir
.
get
(
'inputs'
)
graph
.
output_
node
.
operation
.
io_
names
=
ir
.
get
(
'outputs'
)
for
node_name
,
node_data
in
ir
[
'nodes'
].
items
():
Node
.
_load
(
graph
,
node_name
,
node_data
).
_register
()
for
edge_data
in
ir
[
'edges'
]:
...
...
@@ -405,8 +409,8 @@ class Graph:
def
_dump
(
self
)
->
Any
:
return
{
'inputs'
:
self
.
input_names
,
'outputs'
:
self
.
output_names
,
'inputs'
:
self
.
input_
node
.
operation
.
io_
names
,
'outputs'
:
self
.
output_
node
.
operation
.
io_
names
,
'nodes'
:
{
node
.
name
:
node
.
_dump
()
for
node
in
self
.
hidden_nodes
},
'edges'
:
[
edge
.
_dump
()
for
edge
in
self
.
edges
]
}
...
...
nni/retiarii/operation.py
View file @
efa4e31c
...
...
@@ -98,7 +98,6 @@ class PyTorchOperation(Operation):
return
None
def
to_init_code
(
self
,
field
:
str
)
->
str
:
field
=
_convert_name
(
field
)
if
self
.
_to_class_name
()
is
not
None
:
assert
'positional_args'
not
in
self
.
parameters
kw_params
=
', '
.
join
(
f
'
{
key
}
=
{
repr
(
value
)
}
'
for
key
,
value
in
self
.
parameters
.
items
())
...
...
@@ -106,9 +105,6 @@ class PyTorchOperation(Operation):
return
None
def
to_forward_code
(
self
,
field
:
str
,
output
:
str
,
inputs
:
List
[
str
])
->
str
:
field
=
_convert_name
(
field
)
output
=
_convert_name
(
output
)
inputs
=
[
_convert_name
(
_input
)
for
_input
in
inputs
]
if
self
.
_to_class_name
()
is
not
None
:
return
f
'
{
output
}
= self.
{
field
}
(
{
", "
.
join
(
inputs
)
}
)'
elif
self
.
type
.
startswith
(
'Function.'
):
...
...
@@ -176,16 +172,16 @@ class Cell(PyTorchOperation):
return
_convert_name
(
self
.
cell_name
)
class
_PseudoOperation
(
Operation
):
class
_
IO
PseudoOperation
(
Operation
):
"""
This is the pseudo operation used by I/O nodes.
The benefit is that users no longer need to verify `Node.operation is not None`,
especially in static type checking.
"""
def
__init__
(
self
,
type_name
:
str
):
def
__init__
(
self
,
type_name
:
str
,
io_names
:
List
=
None
):
assert
type_name
.
startswith
(
'_'
)
s
elf
.
type
=
type_name
self
.
parameters
=
{}
s
uper
(
_IOPseudoOperation
,
self
).
__init__
(
type_name
,
{},
True
)
self
.
io_names
=
io_names
def
to_init_code
(
self
,
field
:
str
)
->
str
:
raise
ValueError
(
f
'Cannot generate code for pseudo operation "
{
self
.
type
}
"'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment