"test/vscode:/vscode.git/clone" did not exist on "7d101f83711703b27996a3c6fc64dd6cb101ec7d"
Unverified commit efa4e31c authored by QuanluZhang, committed by GitHub

[Retiarii] refactor convert_name (#3101)

parent 8af73146
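Note: this commit moves the dot-to-underscore name conversion into the converter's utils module, applies it to TorchScript debug names, and switches build_full_name to the same '__' separator. The snippet below is a standalone illustration reproducing only the two helpers visible in the diff (the example names 'stem.conv1', 'model', 'relu' are made up for demonstration):

def _convert_name(name: str) -> str:
    # Replace the '.' separators of a scoped/debug name so it becomes a valid Python identifier.
    return name.replace('.', '__')

def build_full_name(prefix, name, seq=None):
    # Join a scope prefix and a local name with '__'; an optional sequence number disambiguates repeats.
    if seq is None:
        return '{}__{}'.format(prefix, name)
    else:
        return '{}__{}{}'.format(prefix, name, str(seq))

# Hypothetical names, for illustration only:
assert _convert_name('stem.conv1') == 'stem__conv1'
assert build_full_name('model', 'relu', 2) == 'model__relu2'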
+import logging
 from typing import *
 from ..graph import IllegalGraphError, Edge, Graph, Node, Model
 from ..operation import Operation, Cell
-# TODO: fix: inputs is a list, how to deal with single element list and single element
+_logger = logging.getLogger(__name__)
 def model_to_pytorch_script(model: Model) -> str:
     graphs = []
@@ -16,17 +18,9 @@ def model_to_pytorch_script(model: Model) -> str:
     pkgs_code = '\n'.join(['import {}'.format(pkg) for pkg in total_pkgs])
     return _PyTorchScriptTemplate.format(pkgs_code, '\n\n'.join(graphs)).strip()
-def _convert_name(name: str) -> str:
-    """
-    Convert the names using separator '.' to valid variable name in code
-    """
-    return name.replace('.', '__')
-def _convert_names(names: List[str]) -> List[str]:
-    return [_convert_name(name) for name in names]
 def _sorted_incoming_edges(node: Node) -> List[Edge]:
     edges = [edge for edge in node.graph.edges if edge.tail is node]
+    _logger.info('sorted_incoming_edges: {}'.format(edges))
     if not edges:
         return []
     if all(edge.tail_slot is None for edge in edges):
@@ -43,9 +37,9 @@ def _format_inputs(node: Node) -> List[str]:
     for edge in edges:
         if edge.head.name == '_inputs':
             assert isinstance(edge.head_slot, int)
-            if node.graph.input_names is not None:
+            if edge.head.operation.io_names is not None:
                 # when input has names, e.g., forward(self, tensor1, tensor2, another_one)
-                inputs.append(node.graph.input_names[edge.head_slot])
+                inputs.append(edge.head.operation.io_names[edge.head_slot])
             else:
                 # when input has no name, e.g., forward(*_inputs)
                 inputs.append('_inputs[{}]'.format(edge.head_slot))
@@ -59,7 +53,7 @@ def _format_inputs(node: Node) -> List[str]:
     return inputs
 def graph_to_pytorch_model(graph_name: str, graph: Graph) -> str:
-    nodes = graph.nodes  # FIXME: topological sort is needed here
+    nodes = graph.nodes
     # handle module node and function node differently
     # only need to generate code for module here
@@ -74,11 +68,10 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph) -> str:
         if node_code is not None:
             node_codes.append(node_code)
-    if graph.input_names is None:
+    if graph.input_node.operation.io_names is None:
         input_code = '*_inputs'
     else:
-        # TODO: remove _convert_names (after merging input_names and input_node)
-        input_code = ', '.join(_convert_names(graph.input_names))
+        input_code = ', '.join(graph.input_node.operation.io_names)
     edge_codes = []
     sorted_nodes = graph.topo_sort()
@@ -87,15 +80,14 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph) -> str:
             inputs = _format_inputs(node)
             edge_codes.append(node.operation.to_forward_code(node.name, node.name, inputs))
+    # TODO: refactor graph output_node
     output_names = _format_inputs(graph.output_node)
-    output_names = _convert_names(output_names)
     if not output_names:
-        output_names = ['None']
-    output_code = ', '.join(output_names)
+        raise RuntimeError('"forward" function should have return value(s): {}, {}, {}'.format(output_names, graph_name, graph.output_node))
     linebreak = '\n        '
     return import_pkgs, _PyTorchModelTemplate.format(
-        graph_name=('Graph' if graph_name == '_graph' else _convert_name(graph_name)),
+        graph_name=('Graph' if graph_name == '_graph' else graph_name),
         inputs=input_code,
         outputs=', '.join(output_names),
         nodes=linebreak.join(node_codes),
...
@@ -7,7 +7,7 @@ from ..operation import Cell, Operation
 from ..model_apis.nn import Placeholder
 from .op_types import RETIARII_BASE_OPS, MODULE_EXCEPT_LIST, Type
-from .utils import build_full_name
+from .utils import build_full_name, _convert_name
 global_seq = 0
@@ -149,7 +149,7 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
             continue
         graph_inputs.append(_input)
         # TODO: add scope name
-        ir_graph._add_input(_input.debugName())
+        ir_graph._add_input(_convert_name(_input.debugName()))
     node_index = {}  # graph node to graph ir node
@@ -315,7 +315,7 @@ def convert_module(script_module, module, module_name, ir_model):
     graph_outputs = []
     for _output in sm_graph.outputs():
         graph_outputs.append(_output)  # <class 'torch._C.Value'>
-        ir_graph._add_output(_output.debugName())
+        ir_graph._add_output(_convert_name(_output.debugName()))
         predecessor_node_outputs = [o for o in _output.node().outputs()]
         if len(predecessor_node_outputs) == 1:
             src_node_idx = None
...
 def build_full_name(prefix, name, seq=None):
     if seq is None:
-        return '{}.{}'.format(prefix, name)
+        return '{}__{}'.format(prefix, name)
     else:
-        return '{}.{}{}'.format(prefix, name, str(seq))
\ No newline at end of file
+        return '{}__{}{}'.format(prefix, name, str(seq))
+def _convert_name(name: str) -> str:
+    """
+    Convert the names using separator '.' to valid variable name in code
+    """
+    return name.replace('.', '__')
...
@@ -7,7 +7,7 @@ from enum import Enum
 import json
 from typing import (Any, Dict, List, Optional, Tuple, Union, overload)
-from .operation import Cell, Operation, _PseudoOperation
+from .operation import Cell, Operation, _IOPseudoOperation
 __all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'IllegalGraphError', 'MetricData']
@@ -233,35 +233,33 @@ class Graph:
         self.id: int = graph_id
         self.name: str = name or f'_generated_{graph_id}'
-        # TODO: why not merge the names into input_node and output_node???
-        self.input_names: Optional[List[str]] = None
-        self.output_names: Optional[List[str]] = None
-        self.input_node: Node = Node(self, _InputPseudoUid, '_inputs', _PseudoOperation('_inputs'), _internal=True)
-        self.output_node: Node = Node(self, _OutputPseudoUid, '_outputs', _PseudoOperation('_outputs'), _internal=True)
+        self.input_node: Node = Node(self, _InputPseudoUid, '_inputs', _IOPseudoOperation('_inputs'), _internal=True)
+        self.output_node: Node = Node(self, _OutputPseudoUid, '_outputs', _IOPseudoOperation('_outputs'), _internal=True)
         self.hidden_nodes: List[Node] = []
         self.edges: List[Edge] = []
     def __repr__(self):
-        return f'Graph(id={self.id}, name={self.name}, input_names={self.input_names}, ' + \
-            f'output_names={self.output_names}, num_hidden_nodes={len(self.hidden_nodes)}, num_edges={len(self.edges)})'
+        return f'Graph(id={self.id}, name={self.name}, ' + \
+            f'input_names={self.input_node.operation.io_names}, ' + \
+            f'output_names={self.output_node.operation.io_names}, ' + \
+            f'num_hidden_nodes={len(self.hidden_nodes)}, num_edges={len(self.edges)})'
     @property
     def nodes(self) -> List['Node']:
         return [self.input_node, self.output_node] + self.hidden_nodes
     def _add_input(self, input_name) -> None:
-        if self.input_names is None:
-            self.input_names = [input_name]
+        if self.input_node.operation.io_names is None:
+            self.input_node.operation.io_names = [input_name]
         else:
-            self.input_names.append(input_name)
+            self.input_node.operation.io_names.append(input_name)
     def _add_output(self, output_name) -> None:
-        if self.output_names is None:
-            self.output_names = [output_name]
+        if self.output_node.operation.io_names is None:
+            self.output_node.operation.io_names = [output_name]
         else:
-            self.output_names.append(output_name)
+            self.output_node.operation.io_names.append(output_name)
     @overload
     def add_node(self, name: str, operation: Operation) -> 'Node': ...
@@ -351,8 +349,11 @@ class Graph:
     def _fork_to(self, model: Model) -> 'Graph':
         new_graph = Graph(model, self.id, self.name, _internal=True)._register()
-        new_graph.input_names = self.input_names
-        new_graph.output_names = self.output_names
+        # TODO: use node copy instead
+        new_graph.input_node.operation.io_names = self.input_node.operation.io_names
+        new_graph.output_node.operation.io_names = self.output_node.operation.io_names
+        new_graph.input_node.update_label(self.input_node.label)
+        new_graph.output_node.update_label(self.output_node.label)
         for node in self.hidden_nodes:
             new_node = Node(new_graph, node.id, node.name, node.operation, _internal=True)
@@ -372,13 +373,16 @@ class Graph:
         # Copy this graph inside the model.
         # The new graph will have identical topology, but its nodes' name and ID will be different.
         new_graph = Graph(self.model, self.model._uid(), _internal=True)._register()
-        new_graph.input_names = self.input_names
-        new_graph.output_names = self.output_names
+        new_graph.input_node.operation.io_names = self.input_node.operation.io_names
+        new_graph.output_node.operation.io_names = self.output_node.operation.io_names
+        new_graph.input_node.update_label(self.input_node.label)
+        new_graph.output_node.update_label(self.output_node.label)
         id_to_new_node = {}  # old node ID -> new node object
         for old_node in self.hidden_nodes:
             new_node = Node(new_graph, self.model._uid(), None, old_node.operation, _internal=True)._register()
+            new_node.update_label(old_node.label)
             id_to_new_node[old_node.id] = new_node
         for edge in self.edges:
@@ -395,8 +399,8 @@ class Graph:
     @staticmethod
     def _load(model: Model, name: str, ir: Any) -> 'Graph':
         graph = Graph(model, model._uid(), name, _internal=True)
-        graph.input_names = ir.get('inputs')
-        graph.output_names = ir.get('outputs')
+        graph.input_node.operation.io_names = ir.get('inputs')
+        graph.output_node.operation.io_names = ir.get('outputs')
         for node_name, node_data in ir['nodes'].items():
             Node._load(graph, node_name, node_data)._register()
         for edge_data in ir['edges']:
@@ -405,8 +409,8 @@ class Graph:
     def _dump(self) -> Any:
         return {
-            'inputs': self.input_names,
-            'outputs': self.output_names,
+            'inputs': self.input_node.operation.io_names,
+            'outputs': self.output_node.operation.io_names,
             'nodes': {node.name: node._dump() for node in self.hidden_nodes},
             'edges': [edge._dump() for edge in self.edges]
         }
...
@@ -98,7 +98,6 @@ class PyTorchOperation(Operation):
         return None
     def to_init_code(self, field: str) -> str:
-        field = _convert_name(field)
         if self._to_class_name() is not None:
             assert 'positional_args' not in self.parameters
             kw_params = ', '.join(f'{key}={repr(value)}' for key, value in self.parameters.items())
@@ -106,9 +105,6 @@ class PyTorchOperation(Operation):
         return None
     def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
-        field = _convert_name(field)
-        output = _convert_name(output)
-        inputs = [_convert_name(_input) for _input in inputs]
        if self._to_class_name() is not None:
             return f'{output} = self.{field}({", ".join(inputs)})'
         elif self.type.startswith('Function.'):
@@ -176,16 +172,16 @@ class Cell(PyTorchOperation):
         return _convert_name(self.cell_name)
-class _PseudoOperation(Operation):
+class _IOPseudoOperation(Operation):
     """
     This is the pseudo operation used by I/O nodes.
     The benefit is that users no longer need to verify `Node.operation is not None`,
     especially in static type checking.
     """
-    def __init__(self, type_name: str):
+    def __init__(self, type_name: str, io_names: List = None):
         assert type_name.startswith('_')
-        self.type = type_name
-        self.parameters = {}
+        super(_IOPseudoOperation, self).__init__(type_name, {}, True)
+        self.io_names = io_names
     def to_init_code(self, field: str) -> str:
         raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')
...
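For readers following the io_names change above: the input/output names that previously lived on Graph (input_names / output_names) are now stored on the _IOPseudoOperation held by the graph's input and output nodes. The following is a simplified sketch of that pattern using stripped-down stand-in classes, not the real Graph/Node/Operation API:

from typing import List, Optional

class _IOPseudoOperation:
    # Stand-in for the pseudo operation attached to a graph's input/output nodes.
    def __init__(self, type_name: str, io_names: Optional[List[str]] = None):
        assert type_name.startswith('_')
        self.type = type_name
        self.io_names = io_names

class _Node:
    # Minimal node wrapper holding an operation.
    def __init__(self, operation):
        self.operation = operation

class _Graph:
    # Sketch of how Graph._add_input now stores names on the input node's operation.
    def __init__(self):
        self.input_node = _Node(_IOPseudoOperation('_inputs'))
        self.output_node = _Node(_IOPseudoOperation('_outputs'))

    def _add_input(self, input_name: str) -> None:
        op = self.input_node.operation
        if op.io_names is None:
            op.io_names = [input_name]
        else:
            op.io_names.append(input_name)

g = _Graph()
g._add_input('tensor1')
g._add_input('tensor2')
# The code generator can now read the forward() argument names directly from the input node's operation.
assert g.input_node.operation.io_names == ['tensor1', 'tensor2']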