Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
efa4e31c
Unverified
Commit
efa4e31c
authored
Nov 20, 2020
by
QuanluZhang
Committed by
GitHub
Nov 20, 2020
Browse files
[Retiarii] refactor convert_name (#3101)
parent
8af73146
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
55 additions
and
57 deletions
+55
-57
nni/retiarii/codegen/pytorch.py
nni/retiarii/codegen/pytorch.py
+12
-20
nni/retiarii/converter/graph_gen.py
nni/retiarii/converter/graph_gen.py
+3
-3
nni/retiarii/converter/utils.py
nni/retiarii/converter/utils.py
+8
-2
nni/retiarii/execution/listener.py
nni/retiarii/execution/listener.py
+1
-1
nni/retiarii/graph.py
nni/retiarii/graph.py
+27
-23
nni/retiarii/operation.py
nni/retiarii/operation.py
+4
-8
No files found.
nni/retiarii/codegen/pytorch.py
View file @
efa4e31c
import
logging
from
typing
import
*
from
typing
import
*
from
..graph
import
IllegalGraphError
,
Edge
,
Graph
,
Node
,
Model
from
..graph
import
IllegalGraphError
,
Edge
,
Graph
,
Node
,
Model
from
..operation
import
Operation
,
Cell
from
..operation
import
Operation
,
Cell
# TODO: fix: inputs is a list, how to deal with single element list and single element
_logger
=
logging
.
getLogger
(
__name__
)
def
model_to_pytorch_script
(
model
:
Model
)
->
str
:
def
model_to_pytorch_script
(
model
:
Model
)
->
str
:
graphs
=
[]
graphs
=
[]
...
@@ -16,17 +18,9 @@ def model_to_pytorch_script(model: Model) -> str:
...
@@ -16,17 +18,9 @@ def model_to_pytorch_script(model: Model) -> str:
pkgs_code
=
'
\n
'
.
join
([
'import {}'
.
format
(
pkg
)
for
pkg
in
total_pkgs
])
pkgs_code
=
'
\n
'
.
join
([
'import {}'
.
format
(
pkg
)
for
pkg
in
total_pkgs
])
return
_PyTorchScriptTemplate
.
format
(
pkgs_code
,
'
\n\n
'
.
join
(
graphs
)).
strip
()
return
_PyTorchScriptTemplate
.
format
(
pkgs_code
,
'
\n\n
'
.
join
(
graphs
)).
strip
()
def
_convert_name
(
name
:
str
)
->
str
:
"""
Convert the names using separator '.' to valid variable name in code
"""
return
name
.
replace
(
'.'
,
'__'
)
def
_convert_names
(
names
:
List
[
str
])
->
List
[
str
]:
return
[
_convert_name
(
name
)
for
name
in
names
]
def
_sorted_incoming_edges
(
node
:
Node
)
->
List
[
Edge
]:
def
_sorted_incoming_edges
(
node
:
Node
)
->
List
[
Edge
]:
edges
=
[
edge
for
edge
in
node
.
graph
.
edges
if
edge
.
tail
is
node
]
edges
=
[
edge
for
edge
in
node
.
graph
.
edges
if
edge
.
tail
is
node
]
_logger
.
info
(
'sorted_incoming_edges: {}'
.
format
(
edges
))
if
not
edges
:
if
not
edges
:
return
[]
return
[]
if
all
(
edge
.
tail_slot
is
None
for
edge
in
edges
):
if
all
(
edge
.
tail_slot
is
None
for
edge
in
edges
):
...
@@ -43,9 +37,9 @@ def _format_inputs(node: Node) -> List[str]:
...
@@ -43,9 +37,9 @@ def _format_inputs(node: Node) -> List[str]:
for
edge
in
edges
:
for
edge
in
edges
:
if
edge
.
head
.
name
==
'_inputs'
:
if
edge
.
head
.
name
==
'_inputs'
:
assert
isinstance
(
edge
.
head_slot
,
int
)
assert
isinstance
(
edge
.
head_slot
,
int
)
if
node
.
graph
.
input
_names
is
not
None
:
if
edge
.
head
.
operation
.
io
_names
is
not
None
:
# when input has names, e.g., forward(self, tensor1, tensor2, another_one)
# when input has names, e.g., forward(self, tensor1, tensor2, another_one)
inputs
.
append
(
node
.
graph
.
input
_names
[
edge
.
head_slot
])
inputs
.
append
(
edge
.
head
.
operation
.
io
_names
[
edge
.
head_slot
])
else
:
else
:
# when input has no name, e.g., forward(*_inputs)
# when input has no name, e.g., forward(*_inputs)
inputs
.
append
(
'_inputs[{}]'
.
format
(
edge
.
head_slot
))
inputs
.
append
(
'_inputs[{}]'
.
format
(
edge
.
head_slot
))
...
@@ -59,7 +53,7 @@ def _format_inputs(node: Node) -> List[str]:
...
@@ -59,7 +53,7 @@ def _format_inputs(node: Node) -> List[str]:
return
inputs
return
inputs
def
graph_to_pytorch_model
(
graph_name
:
str
,
graph
:
Graph
)
->
str
:
def
graph_to_pytorch_model
(
graph_name
:
str
,
graph
:
Graph
)
->
str
:
nodes
=
graph
.
nodes
# FIXME: topological sort is needed here
nodes
=
graph
.
nodes
# handle module node and function node differently
# handle module node and function node differently
# only need to generate code for module here
# only need to generate code for module here
...
@@ -74,11 +68,10 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph) -> str:
...
@@ -74,11 +68,10 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph) -> str:
if
node_code
is
not
None
:
if
node_code
is
not
None
:
node_codes
.
append
(
node_code
)
node_codes
.
append
(
node_code
)
if
graph
.
input_names
is
None
:
if
graph
.
input_
node
.
operation
.
io_
names
is
None
:
input_code
=
'*_inputs'
input_code
=
'*_inputs'
else
:
else
:
# TODO: remove _convert_names (after merging input_names and input_node)
input_code
=
', '
.
join
(
graph
.
input_node
.
operation
.
io_names
)
input_code
=
', '
.
join
(
_convert_names
(
graph
.
input_names
))
edge_codes
=
[]
edge_codes
=
[]
sorted_nodes
=
graph
.
topo_sort
()
sorted_nodes
=
graph
.
topo_sort
()
...
@@ -87,15 +80,14 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph) -> str:
...
@@ -87,15 +80,14 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph) -> str:
inputs
=
_format_inputs
(
node
)
inputs
=
_format_inputs
(
node
)
edge_codes
.
append
(
node
.
operation
.
to_forward_code
(
node
.
name
,
node
.
name
,
inputs
))
edge_codes
.
append
(
node
.
operation
.
to_forward_code
(
node
.
name
,
node
.
name
,
inputs
))
# TODO: refactor graph output_node
output_names
=
_format_inputs
(
graph
.
output_node
)
output_names
=
_format_inputs
(
graph
.
output_node
)
output_names
=
_convert_names
(
output_names
)
if
not
output_names
:
if
not
output_names
:
output_names
=
[
'None'
]
raise
RuntimeError
(
'"forward" function should have return value(s): {}, {}, {}'
.
format
(
output_names
,
graph_name
,
graph
.
output_node
))
output_code
=
', '
.
join
(
output_names
)
linebreak
=
'
\n
'
linebreak
=
'
\n
'
return
import_pkgs
,
_PyTorchModelTemplate
.
format
(
return
import_pkgs
,
_PyTorchModelTemplate
.
format
(
graph_name
=
(
'Graph'
if
graph_name
==
'_graph'
else
_convert_name
(
graph_name
)
)
,
graph_name
=
(
'Graph'
if
graph_name
==
'_graph'
else
graph_name
),
inputs
=
input_code
,
inputs
=
input_code
,
outputs
=
', '
.
join
(
output_names
),
outputs
=
', '
.
join
(
output_names
),
nodes
=
linebreak
.
join
(
node_codes
),
nodes
=
linebreak
.
join
(
node_codes
),
...
...
nni/retiarii/converter/graph_gen.py
View file @
efa4e31c
...
@@ -7,7 +7,7 @@ from ..operation import Cell, Operation
...
@@ -7,7 +7,7 @@ from ..operation import Cell, Operation
from
..model_apis.nn
import
Placeholder
from
..model_apis.nn
import
Placeholder
from
.op_types
import
RETIARII_BASE_OPS
,
MODULE_EXCEPT_LIST
,
Type
from
.op_types
import
RETIARII_BASE_OPS
,
MODULE_EXCEPT_LIST
,
Type
from
.utils
import
build_full_name
from
.utils
import
build_full_name
,
_convert_name
global_seq
=
0
global_seq
=
0
...
@@ -149,7 +149,7 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
...
@@ -149,7 +149,7 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
continue
continue
graph_inputs
.
append
(
_input
)
graph_inputs
.
append
(
_input
)
# TODO: add scope name
# TODO: add scope name
ir_graph
.
_add_input
(
_input
.
debugName
())
ir_graph
.
_add_input
(
_
convert_name
(
_
input
.
debugName
())
)
node_index
=
{}
# graph node to graph ir node
node_index
=
{}
# graph node to graph ir node
...
@@ -315,7 +315,7 @@ def convert_module(script_module, module, module_name, ir_model):
...
@@ -315,7 +315,7 @@ def convert_module(script_module, module, module_name, ir_model):
graph_outputs
=
[]
graph_outputs
=
[]
for
_output
in
sm_graph
.
outputs
():
for
_output
in
sm_graph
.
outputs
():
graph_outputs
.
append
(
_output
)
# <class 'torch._C.Value'>
graph_outputs
.
append
(
_output
)
# <class 'torch._C.Value'>
ir_graph
.
_add_output
(
_output
.
debugName
())
ir_graph
.
_add_output
(
_
convert_name
(
_
output
.
debugName
())
)
predecessor_node_outputs
=
[
o
for
o
in
_output
.
node
().
outputs
()]
predecessor_node_outputs
=
[
o
for
o
in
_output
.
node
().
outputs
()]
if
len
(
predecessor_node_outputs
)
==
1
:
if
len
(
predecessor_node_outputs
)
==
1
:
src_node_idx
=
None
src_node_idx
=
None
...
...
nni/retiarii/converter/utils.py
View file @
efa4e31c
def
build_full_name
(
prefix
,
name
,
seq
=
None
):
def
build_full_name
(
prefix
,
name
,
seq
=
None
):
if
seq
is
None
:
if
seq
is
None
:
return
'{}
.
{}'
.
format
(
prefix
,
name
)
return
'{}
__
{}'
.
format
(
prefix
,
name
)
else
:
else
:
return
'{}.{}{}'
.
format
(
prefix
,
name
,
str
(
seq
))
return
'{}__{}{}'
.
format
(
prefix
,
name
,
str
(
seq
))
\ No newline at end of file
def
_convert_name
(
name
:
str
)
->
str
:
"""
Convert the names using separator '.' to valid variable name in code
"""
return
name
.
replace
(
'.'
,
'__'
)
nni/retiarii/execution/listener.py
View file @
efa4e31c
...
@@ -13,7 +13,7 @@ class DefaultListener(AbstractGraphListener):
...
@@ -13,7 +13,7 @@ class DefaultListener(AbstractGraphListener):
def
on_metric
(
self
,
model
:
Model
,
metric
:
MetricData
)
->
None
:
def
on_metric
(
self
,
model
:
Model
,
metric
:
MetricData
)
->
None
:
model
.
metric
=
metric
model
.
metric
=
metric
def
on_intermediate_metric
(
self
,
model
:
Model
,
metric
:
MetricData
)
->
None
:
def
on_intermediate_metric
(
self
,
model
:
Model
,
metric
:
MetricData
)
->
None
:
model
.
intermediate_metrics
.
append
(
metric
)
model
.
intermediate_metrics
.
append
(
metric
)
...
...
nni/retiarii/graph.py
View file @
efa4e31c
...
@@ -7,7 +7,7 @@ from enum import Enum
...
@@ -7,7 +7,7 @@ from enum import Enum
import
json
import
json
from
typing
import
(
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
,
overload
)
from
typing
import
(
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
,
overload
)
from
.operation
import
Cell
,
Operation
,
_PseudoOperation
from
.operation
import
Cell
,
Operation
,
_
IO
PseudoOperation
__all__
=
[
'Model'
,
'ModelStatus'
,
'Graph'
,
'Node'
,
'Edge'
,
'IllegalGraphError'
,
'MetricData'
]
__all__
=
[
'Model'
,
'ModelStatus'
,
'Graph'
,
'Node'
,
'Edge'
,
'IllegalGraphError'
,
'MetricData'
]
...
@@ -233,35 +233,33 @@ class Graph:
...
@@ -233,35 +233,33 @@ class Graph:
self
.
id
:
int
=
graph_id
self
.
id
:
int
=
graph_id
self
.
name
:
str
=
name
or
f
'_generated_
{
graph_id
}
'
self
.
name
:
str
=
name
or
f
'_generated_
{
graph_id
}
'
# TODO: why not merge the names into input_node and output_node???
self
.
input_node
:
Node
=
Node
(
self
,
_InputPseudoUid
,
'_inputs'
,
_IOPseudoOperation
(
'_inputs'
),
_internal
=
True
)
self
.
input_names
:
Optional
[
List
[
str
]]
=
None
self
.
output_node
:
Node
=
Node
(
self
,
_OutputPseudoUid
,
'_outputs'
,
_IOPseudoOperation
(
'_outputs'
),
_internal
=
True
)
self
.
output_names
:
Optional
[
List
[
str
]]
=
None
self
.
input_node
:
Node
=
Node
(
self
,
_InputPseudoUid
,
'_inputs'
,
_PseudoOperation
(
'_inputs'
),
_internal
=
True
)
self
.
output_node
:
Node
=
Node
(
self
,
_OutputPseudoUid
,
'_outputs'
,
_PseudoOperation
(
'_outputs'
),
_internal
=
True
)
self
.
hidden_nodes
:
List
[
Node
]
=
[]
self
.
hidden_nodes
:
List
[
Node
]
=
[]
self
.
edges
:
List
[
Edge
]
=
[]
self
.
edges
:
List
[
Edge
]
=
[]
def
__repr__
(
self
):
def
__repr__
(
self
):
return
f
'Graph(id=
{
self
.
id
}
, name=
{
self
.
name
}
, input_names=
{
self
.
input_names
}
, '
+
\
return
f
'Graph(id=
{
self
.
id
}
, name=
{
self
.
name
}
, '
+
\
f
'output_names=
{
self
.
output_names
}
, num_hidden_nodes=
{
len
(
self
.
hidden_nodes
)
}
, num_edges=
{
len
(
self
.
edges
)
}
)'
f
'input_names=
{
self
.
input_node
.
operation
.
io_names
}
, '
+
\
f
'output_names=
{
self
.
output_node
.
operation
.
io_names
}
, '
+
\
f
'num_hidden_nodes=
{
len
(
self
.
hidden_nodes
)
}
, num_edges=
{
len
(
self
.
edges
)
}
)'
@
property
@
property
def
nodes
(
self
)
->
List
[
'Node'
]:
def
nodes
(
self
)
->
List
[
'Node'
]:
return
[
self
.
input_node
,
self
.
output_node
]
+
self
.
hidden_nodes
return
[
self
.
input_node
,
self
.
output_node
]
+
self
.
hidden_nodes
def
_add_input
(
self
,
input_name
)
->
None
:
def
_add_input
(
self
,
input_name
)
->
None
:
if
self
.
input_names
is
None
:
if
self
.
input_
node
.
operation
.
io_
names
is
None
:
self
.
input_names
=
[
input_name
]
self
.
input_
node
.
operation
.
io_
names
=
[
input_name
]
else
:
else
:
self
.
input_names
.
append
(
input_name
)
self
.
input_
node
.
operation
.
io_
names
.
append
(
input_name
)
def
_add_output
(
self
,
output_name
)
->
None
:
def
_add_output
(
self
,
output_name
)
->
None
:
if
self
.
output_names
is
None
:
if
self
.
output_
node
.
operation
.
io_
names
is
None
:
self
.
output_names
=
[
output_name
]
self
.
output_
node
.
operation
.
io_
names
=
[
output_name
]
else
:
else
:
self
.
output_names
.
append
(
output_name
)
self
.
output_
node
.
operation
.
io_
names
.
append
(
output_name
)
@
overload
@
overload
def
add_node
(
self
,
name
:
str
,
operation
:
Operation
)
->
'Node'
:
...
def
add_node
(
self
,
name
:
str
,
operation
:
Operation
)
->
'Node'
:
...
...
@@ -351,8 +349,11 @@ class Graph:
...
@@ -351,8 +349,11 @@ class Graph:
def
_fork_to
(
self
,
model
:
Model
)
->
'Graph'
:
def
_fork_to
(
self
,
model
:
Model
)
->
'Graph'
:
new_graph
=
Graph
(
model
,
self
.
id
,
self
.
name
,
_internal
=
True
).
_register
()
new_graph
=
Graph
(
model
,
self
.
id
,
self
.
name
,
_internal
=
True
).
_register
()
new_graph
.
input_names
=
self
.
input_names
# TODO: use node copy instead
new_graph
.
output_names
=
self
.
output_names
new_graph
.
input_node
.
operation
.
io_names
=
self
.
input_node
.
operation
.
io_names
new_graph
.
output_node
.
operation
.
io_names
=
self
.
output_node
.
operation
.
io_names
new_graph
.
input_node
.
update_label
(
self
.
input_node
.
label
)
new_graph
.
output_node
.
update_label
(
self
.
output_node
.
label
)
for
node
in
self
.
hidden_nodes
:
for
node
in
self
.
hidden_nodes
:
new_node
=
Node
(
new_graph
,
node
.
id
,
node
.
name
,
node
.
operation
,
_internal
=
True
)
new_node
=
Node
(
new_graph
,
node
.
id
,
node
.
name
,
node
.
operation
,
_internal
=
True
)
...
@@ -372,13 +373,16 @@ class Graph:
...
@@ -372,13 +373,16 @@ class Graph:
# Copy this graph inside the model.
# Copy this graph inside the model.
# The new graph will have identical topology, but its nodes' name and ID will be different.
# The new graph will have identical topology, but its nodes' name and ID will be different.
new_graph
=
Graph
(
self
.
model
,
self
.
model
.
_uid
(),
_internal
=
True
).
_register
()
new_graph
=
Graph
(
self
.
model
,
self
.
model
.
_uid
(),
_internal
=
True
).
_register
()
new_graph
.
input_names
=
self
.
input_names
new_graph
.
input_node
.
operation
.
io_names
=
self
.
input_node
.
operation
.
io_names
new_graph
.
output_names
=
self
.
output_names
new_graph
.
output_node
.
operation
.
io_names
=
self
.
output_node
.
operation
.
io_names
new_graph
.
input_node
.
update_label
(
self
.
input_node
.
label
)
new_graph
.
output_node
.
update_label
(
self
.
output_node
.
label
)
id_to_new_node
=
{}
# old node ID -> new node object
id_to_new_node
=
{}
# old node ID -> new node object
for
old_node
in
self
.
hidden_nodes
:
for
old_node
in
self
.
hidden_nodes
:
new_node
=
Node
(
new_graph
,
self
.
model
.
_uid
(),
None
,
old_node
.
operation
,
_internal
=
True
).
_register
()
new_node
=
Node
(
new_graph
,
self
.
model
.
_uid
(),
None
,
old_node
.
operation
,
_internal
=
True
).
_register
()
new_node
.
update_label
(
old_node
.
label
)
id_to_new_node
[
old_node
.
id
]
=
new_node
id_to_new_node
[
old_node
.
id
]
=
new_node
for
edge
in
self
.
edges
:
for
edge
in
self
.
edges
:
...
@@ -395,8 +399,8 @@ class Graph:
...
@@ -395,8 +399,8 @@ class Graph:
@
staticmethod
@
staticmethod
def
_load
(
model
:
Model
,
name
:
str
,
ir
:
Any
)
->
'Graph'
:
def
_load
(
model
:
Model
,
name
:
str
,
ir
:
Any
)
->
'Graph'
:
graph
=
Graph
(
model
,
model
.
_uid
(),
name
,
_internal
=
True
)
graph
=
Graph
(
model
,
model
.
_uid
(),
name
,
_internal
=
True
)
graph
.
input_names
=
ir
.
get
(
'inputs'
)
graph
.
input_
node
.
operation
.
io_
names
=
ir
.
get
(
'inputs'
)
graph
.
output_names
=
ir
.
get
(
'outputs'
)
graph
.
output_
node
.
operation
.
io_
names
=
ir
.
get
(
'outputs'
)
for
node_name
,
node_data
in
ir
[
'nodes'
].
items
():
for
node_name
,
node_data
in
ir
[
'nodes'
].
items
():
Node
.
_load
(
graph
,
node_name
,
node_data
).
_register
()
Node
.
_load
(
graph
,
node_name
,
node_data
).
_register
()
for
edge_data
in
ir
[
'edges'
]:
for
edge_data
in
ir
[
'edges'
]:
...
@@ -405,8 +409,8 @@ class Graph:
...
@@ -405,8 +409,8 @@ class Graph:
def
_dump
(
self
)
->
Any
:
def
_dump
(
self
)
->
Any
:
return
{
return
{
'inputs'
:
self
.
input_names
,
'inputs'
:
self
.
input_
node
.
operation
.
io_
names
,
'outputs'
:
self
.
output_names
,
'outputs'
:
self
.
output_
node
.
operation
.
io_
names
,
'nodes'
:
{
node
.
name
:
node
.
_dump
()
for
node
in
self
.
hidden_nodes
},
'nodes'
:
{
node
.
name
:
node
.
_dump
()
for
node
in
self
.
hidden_nodes
},
'edges'
:
[
edge
.
_dump
()
for
edge
in
self
.
edges
]
'edges'
:
[
edge
.
_dump
()
for
edge
in
self
.
edges
]
}
}
...
...
nni/retiarii/operation.py
View file @
efa4e31c
...
@@ -98,7 +98,6 @@ class PyTorchOperation(Operation):
...
@@ -98,7 +98,6 @@ class PyTorchOperation(Operation):
return
None
return
None
def
to_init_code
(
self
,
field
:
str
)
->
str
:
def
to_init_code
(
self
,
field
:
str
)
->
str
:
field
=
_convert_name
(
field
)
if
self
.
_to_class_name
()
is
not
None
:
if
self
.
_to_class_name
()
is
not
None
:
assert
'positional_args'
not
in
self
.
parameters
assert
'positional_args'
not
in
self
.
parameters
kw_params
=
', '
.
join
(
f
'
{
key
}
=
{
repr
(
value
)
}
'
for
key
,
value
in
self
.
parameters
.
items
())
kw_params
=
', '
.
join
(
f
'
{
key
}
=
{
repr
(
value
)
}
'
for
key
,
value
in
self
.
parameters
.
items
())
...
@@ -106,9 +105,6 @@ class PyTorchOperation(Operation):
...
@@ -106,9 +105,6 @@ class PyTorchOperation(Operation):
return
None
return
None
def
to_forward_code
(
self
,
field
:
str
,
output
:
str
,
inputs
:
List
[
str
])
->
str
:
def
to_forward_code
(
self
,
field
:
str
,
output
:
str
,
inputs
:
List
[
str
])
->
str
:
field
=
_convert_name
(
field
)
output
=
_convert_name
(
output
)
inputs
=
[
_convert_name
(
_input
)
for
_input
in
inputs
]
if
self
.
_to_class_name
()
is
not
None
:
if
self
.
_to_class_name
()
is
not
None
:
return
f
'
{
output
}
= self.
{
field
}
(
{
", "
.
join
(
inputs
)
}
)'
return
f
'
{
output
}
= self.
{
field
}
(
{
", "
.
join
(
inputs
)
}
)'
elif
self
.
type
.
startswith
(
'Function.'
):
elif
self
.
type
.
startswith
(
'Function.'
):
...
@@ -176,16 +172,16 @@ class Cell(PyTorchOperation):
...
@@ -176,16 +172,16 @@ class Cell(PyTorchOperation):
return
_convert_name
(
self
.
cell_name
)
return
_convert_name
(
self
.
cell_name
)
class
_PseudoOperation
(
Operation
):
class
_
IO
PseudoOperation
(
Operation
):
"""
"""
This is the pseudo operation used by I/O nodes.
This is the pseudo operation used by I/O nodes.
The benefit is that users no longer need to verify `Node.operation is not None`,
The benefit is that users no longer need to verify `Node.operation is not None`,
especially in static type checking.
especially in static type checking.
"""
"""
def
__init__
(
self
,
type_name
:
str
):
def
__init__
(
self
,
type_name
:
str
,
io_names
:
List
=
None
):
assert
type_name
.
startswith
(
'_'
)
assert
type_name
.
startswith
(
'_'
)
s
elf
.
type
=
type_name
s
uper
(
_IOPseudoOperation
,
self
).
__init__
(
type_name
,
{},
True
)
self
.
parameters
=
{}
self
.
io_names
=
io_names
def
to_init_code
(
self
,
field
:
str
)
->
str
:
def
to_init_code
(
self
,
field
:
str
)
->
str
:
raise
ValueError
(
f
'Cannot generate code for pseudo operation "
{
self
.
type
}
"'
)
raise
ValueError
(
f
'Cannot generate code for pseudo operation "
{
self
.
type
}
"'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment