Unverified Commit efa4e31c authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

[Retiarii] refactor convert_name (#3101)

parent 8af73146
import logging
from typing import *
from ..graph import IllegalGraphError, Edge, Graph, Node, Model
from ..operation import Operation, Cell
# TODO: fix: inputs is a list, how to deal with single element list and single element
_logger = logging.getLogger(__name__)
def model_to_pytorch_script(model: Model) -> str:
graphs = []
......@@ -16,17 +18,9 @@ def model_to_pytorch_script(model: Model) -> str:
pkgs_code = '\n'.join(['import {}'.format(pkg) for pkg in total_pkgs])
return _PyTorchScriptTemplate.format(pkgs_code, '\n\n'.join(graphs)).strip()
def _convert_name(name: str) -> str:
"""
Convert the names using separator '.' to valid variable name in code
"""
return name.replace('.', '__')
def _convert_names(names: List[str]) -> List[str]:
return [_convert_name(name) for name in names]
def _sorted_incoming_edges(node: Node) -> List[Edge]:
edges = [edge for edge in node.graph.edges if edge.tail is node]
_logger.info('sorted_incoming_edges: {}'.format(edges))
if not edges:
return []
if all(edge.tail_slot is None for edge in edges):
......@@ -43,9 +37,9 @@ def _format_inputs(node: Node) -> List[str]:
for edge in edges:
if edge.head.name == '_inputs':
assert isinstance(edge.head_slot, int)
if node.graph.input_names is not None:
if edge.head.operation.io_names is not None:
# when input has names, e.g., forward(self, tensor1, tensor2, another_one)
inputs.append(node.graph.input_names[edge.head_slot])
inputs.append(edge.head.operation.io_names[edge.head_slot])
else:
# when input has no name, e.g., forward(*_inputs)
inputs.append('_inputs[{}]'.format(edge.head_slot))
......@@ -59,7 +53,7 @@ def _format_inputs(node: Node) -> List[str]:
return inputs
def graph_to_pytorch_model(graph_name: str, graph: Graph) -> str:
nodes = graph.nodes # FIXME: topological sort is needed here
nodes = graph.nodes
# handle module node and function node differently
# only need to generate code for module here
......@@ -74,11 +68,10 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph) -> str:
if node_code is not None:
node_codes.append(node_code)
if graph.input_names is None:
if graph.input_node.operation.io_names is None:
input_code = '*_inputs'
else:
# TODO: remove _convert_names (after merging input_names and input_node)
input_code = ', '.join(_convert_names(graph.input_names))
input_code = ', '.join(graph.input_node.operation.io_names)
edge_codes = []
sorted_nodes = graph.topo_sort()
......@@ -87,15 +80,14 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph) -> str:
inputs = _format_inputs(node)
edge_codes.append(node.operation.to_forward_code(node.name, node.name, inputs))
# TODO: refactor graph output_node
output_names = _format_inputs(graph.output_node)
output_names = _convert_names(output_names)
if not output_names:
output_names = ['None']
raise RuntimeError('"forward" function should have return value(s): {}, {}, {}'.format(output_names, graph_name, graph.output_node))
output_code = ', '.join(output_names)
linebreak = '\n '
return import_pkgs, _PyTorchModelTemplate.format(
graph_name=('Graph' if graph_name == '_graph' else _convert_name(graph_name)),
graph_name=('Graph' if graph_name == '_graph' else graph_name),
inputs=input_code,
outputs=', '.join(output_names),
nodes=linebreak.join(node_codes),
......
......@@ -7,7 +7,7 @@ from ..operation import Cell, Operation
from ..model_apis.nn import Placeholder
from .op_types import RETIARII_BASE_OPS, MODULE_EXCEPT_LIST, Type
from .utils import build_full_name
from .utils import build_full_name, _convert_name
global_seq = 0
......@@ -149,7 +149,7 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
continue
graph_inputs.append(_input)
# TODO: add scope name
ir_graph._add_input(_input.debugName())
ir_graph._add_input(_convert_name(_input.debugName()))
node_index = {} # graph node to graph ir node
......@@ -315,7 +315,7 @@ def convert_module(script_module, module, module_name, ir_model):
graph_outputs = []
for _output in sm_graph.outputs():
graph_outputs.append(_output) # <class 'torch._C.Value'>
ir_graph._add_output(_output.debugName())
ir_graph._add_output(_convert_name(_output.debugName()))
predecessor_node_outputs = [o for o in _output.node().outputs()]
if len(predecessor_node_outputs) == 1:
src_node_idx = None
......
def build_full_name(prefix, name, seq=None):
if seq is None:
return '{}.{}'.format(prefix, name)
return '{}__{}'.format(prefix, name)
else:
return '{}.{}{}'.format(prefix, name, str(seq))
\ No newline at end of file
return '{}__{}{}'.format(prefix, name, str(seq))
def _convert_name(name: str) -> str:
"""
Convert the names using separator '.' to valid variable name in code
"""
return name.replace('.', '__')
......@@ -13,7 +13,7 @@ class DefaultListener(AbstractGraphListener):
def on_metric(self, model: Model, metric: MetricData) -> None:
model.metric = metric
def on_intermediate_metric(self, model: Model, metric: MetricData) -> None:
model.intermediate_metrics.append(metric)
......
......@@ -7,7 +7,7 @@ from enum import Enum
import json
from typing import (Any, Dict, List, Optional, Tuple, Union, overload)
from .operation import Cell, Operation, _PseudoOperation
from .operation import Cell, Operation, _IOPseudoOperation
__all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'IllegalGraphError', 'MetricData']
......@@ -233,35 +233,33 @@ class Graph:
self.id: int = graph_id
self.name: str = name or f'_generated_{graph_id}'
# TODO: why not merge the names into input_node and output_node???
self.input_names: Optional[List[str]] = None
self.output_names: Optional[List[str]] = None
self.input_node: Node = Node(self, _InputPseudoUid, '_inputs', _PseudoOperation('_inputs'), _internal=True)
self.output_node: Node = Node(self, _OutputPseudoUid, '_outputs', _PseudoOperation('_outputs'), _internal=True)
self.input_node: Node = Node(self, _InputPseudoUid, '_inputs', _IOPseudoOperation('_inputs'), _internal=True)
self.output_node: Node = Node(self, _OutputPseudoUid, '_outputs', _IOPseudoOperation('_outputs'), _internal=True)
self.hidden_nodes: List[Node] = []
self.edges: List[Edge] = []
def __repr__(self):
return f'Graph(id={self.id}, name={self.name}, input_names={self.input_names}, ' + \
f'output_names={self.output_names}, num_hidden_nodes={len(self.hidden_nodes)}, num_edges={len(self.edges)})'
return f'Graph(id={self.id}, name={self.name}, ' + \
f'input_names={self.input_node.operation.io_names}, ' + \
f'output_names={self.output_node.operation.io_names}, ' + \
f'num_hidden_nodes={len(self.hidden_nodes)}, num_edges={len(self.edges)})'
@property
def nodes(self) -> List['Node']:
return [self.input_node, self.output_node] + self.hidden_nodes
def _add_input(self, input_name) -> None:
if self.input_names is None:
self.input_names = [input_name]
if self.input_node.operation.io_names is None:
self.input_node.operation.io_names = [input_name]
else:
self.input_names.append(input_name)
self.input_node.operation.io_names.append(input_name)
def _add_output(self, output_name) -> None:
if self.output_names is None:
self.output_names = [output_name]
if self.output_node.operation.io_names is None:
self.output_node.operation.io_names = [output_name]
else:
self.output_names.append(output_name)
self.output_node.operation.io_names.append(output_name)
@overload
def add_node(self, name: str, operation: Operation) -> 'Node': ...
......@@ -351,8 +349,11 @@ class Graph:
def _fork_to(self, model: Model) -> 'Graph':
new_graph = Graph(model, self.id, self.name, _internal=True)._register()
new_graph.input_names = self.input_names
new_graph.output_names = self.output_names
# TODO: use node copy instead
new_graph.input_node.operation.io_names = self.input_node.operation.io_names
new_graph.output_node.operation.io_names = self.output_node.operation.io_names
new_graph.input_node.update_label(self.input_node.label)
new_graph.output_node.update_label(self.output_node.label)
for node in self.hidden_nodes:
new_node = Node(new_graph, node.id, node.name, node.operation, _internal=True)
......@@ -372,13 +373,16 @@ class Graph:
# Copy this graph inside the model.
# The new graph will have identical topology, but its nodes' name and ID will be different.
new_graph = Graph(self.model, self.model._uid(), _internal=True)._register()
new_graph.input_names = self.input_names
new_graph.output_names = self.output_names
new_graph.input_node.operation.io_names = self.input_node.operation.io_names
new_graph.output_node.operation.io_names = self.output_node.operation.io_names
new_graph.input_node.update_label(self.input_node.label)
new_graph.output_node.update_label(self.output_node.label)
id_to_new_node = {} # old node ID -> new node object
for old_node in self.hidden_nodes:
new_node = Node(new_graph, self.model._uid(), None, old_node.operation, _internal=True)._register()
new_node.update_label(old_node.label)
id_to_new_node[old_node.id] = new_node
for edge in self.edges:
......@@ -395,8 +399,8 @@ class Graph:
@staticmethod
def _load(model: Model, name: str, ir: Any) -> 'Graph':
graph = Graph(model, model._uid(), name, _internal=True)
graph.input_names = ir.get('inputs')
graph.output_names = ir.get('outputs')
graph.input_node.operation.io_names = ir.get('inputs')
graph.output_node.operation.io_names = ir.get('outputs')
for node_name, node_data in ir['nodes'].items():
Node._load(graph, node_name, node_data)._register()
for edge_data in ir['edges']:
......@@ -405,8 +409,8 @@ class Graph:
def _dump(self) -> Any:
return {
'inputs': self.input_names,
'outputs': self.output_names,
'inputs': self.input_node.operation.io_names,
'outputs': self.output_node.operation.io_names,
'nodes': {node.name: node._dump() for node in self.hidden_nodes},
'edges': [edge._dump() for edge in self.edges]
}
......
......@@ -98,7 +98,6 @@ class PyTorchOperation(Operation):
return None
def to_init_code(self, field: str) -> str:
field = _convert_name(field)
if self._to_class_name() is not None:
assert 'positional_args' not in self.parameters
kw_params = ', '.join(f'{key}={repr(value)}' for key, value in self.parameters.items())
......@@ -106,9 +105,6 @@ class PyTorchOperation(Operation):
return None
def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
field = _convert_name(field)
output = _convert_name(output)
inputs = [_convert_name(_input) for _input in inputs]
if self._to_class_name() is not None:
return f'{output} = self.{field}({", ".join(inputs)})'
elif self.type.startswith('Function.'):
......@@ -176,16 +172,16 @@ class Cell(PyTorchOperation):
return _convert_name(self.cell_name)
class _PseudoOperation(Operation):
class _IOPseudoOperation(Operation):
"""
This is the pseudo operation used by I/O nodes.
The benefit is that users no longer need to verify `Node.operation is not None`,
especially in static type checking.
"""
def __init__(self, type_name: str):
def __init__(self, type_name: str, io_names: List = None):
assert type_name.startswith('_')
self.type = type_name
self.parameters = {}
super(_IOPseudoOperation, self).__init__(type_name, {}, True)
self.io_names = io_names
def to_init_code(self, field: str) -> str:
raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment