Unverified Commit b3cdee85 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

fix uid duplicate and add type hint alias for edge endpoint (#3188)

parent 192a807b
......@@ -9,15 +9,19 @@ from collections import defaultdict
from typing import (Any, Dict, List, Optional, Tuple, Union, overload)
from .operation import Cell, Operation, _IOPseudoOperation
from .utils import uid
__all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'IllegalGraphError', 'MetricData']
MetricData = Any
"""
Graph metrics like loss, accuracy, etc.
Type hint for graph metrics (loss, accuracy, etc).
"""
# Maybe we can assume this is a single float number for first iteration.
EdgeEndpoint = Tuple['Node', Optional[int]]
"""
Type hint for edge's endpoint. The int indicates nodes' order.
"""
......@@ -88,12 +92,10 @@ class Model:
intermediate_metrics
Intermediate training metrics. If the model is not trained, it's an empty list.
"""
_cur_model_id = 0
def __init__(self, _internal=False):
assert _internal, '`Model()` is private, use `model.fork()` instead'
Model._cur_model_id += 1
self.model_id = Model._cur_model_id
self.model_id: int = uid('model')
self.status: ModelStatus = ModelStatus.Mutating
......@@ -106,8 +108,6 @@ class Model:
self.metric: Optional[MetricData] = None
self.intermediate_metrics: List[MetricData] = []
self._last_uid: int = 0 # FIXME: this should be global, not model-wise
def __repr__(self):
return f'Model(model_id={self.model_id}, status={self.status}, graphs={list(self.graphs.keys())}, ' + \
f'training_config={self.training_config}, metric={self.metric}, intermediate_metrics={self.intermediate_metrics})'
......@@ -130,13 +130,8 @@ class Model:
new_model.graphs = {name: graph._fork_to(new_model) for name, graph in self.graphs.items()}
new_model.training_config = copy.deepcopy(self.training_config)
new_model.history = self.history + [self]
new_model._last_uid = self._last_uid
return new_model
def _uid(self) -> int:
self._last_uid += 1
return self._last_uid
@staticmethod
def _load(ir: Any) -> 'Model':
model = Model(_internal=True)
......@@ -295,7 +290,7 @@ class Graph:
op = operation_or_type
else:
op = Operation.new(operation_or_type, parameters, name)
return Node(self, self.model._uid(), name, op, _internal=True)._register()
return Node(self, uid(), name, op, _internal=True)._register()
@overload
def insert_node_on_edge(self, edge: 'Edge', name: str, operation: Operation) -> 'Node': ...
......@@ -307,7 +302,7 @@ class Graph:
op = operation_or_type
else:
op = Operation.new(operation_or_type, parameters, name)
new_node = Node(self, self.model._uid(), name, op, _internal=True)._register()
new_node = Node(self, uid(), name, op, _internal=True)._register()
# update edges
self.add_edge((edge.head, edge.head_slot), (new_node, None))
self.add_edge((new_node, None), (edge.tail, edge.tail_slot))
......@@ -315,7 +310,7 @@ class Graph:
return new_node
# mutation
def add_edge(self, head: Tuple['Node', Optional[int]], tail: Tuple['Node', Optional[int]]) -> 'Edge':
def add_edge(self, head: EdgeEndpoint, tail: EdgeEndpoint) -> 'Edge':
assert head[0].graph is self and tail[0].graph is self
return Edge(head, tail, _internal=True)._register()
......@@ -414,7 +409,7 @@ class Graph:
def _copy(self) -> 'Graph':
# Copy this graph inside the model.
# The new graph will have identical topology, but its nodes' name and ID will be different.
new_graph = Graph(self.model, self.model._uid(), _internal=True)._register()
new_graph = Graph(self.model, uid(), _internal=True)._register()
new_graph.input_node.operation.io_names = self.input_node.operation.io_names
new_graph.output_node.operation.io_names = self.output_node.operation.io_names
new_graph.input_node.update_label(self.input_node.label)
......@@ -423,7 +418,7 @@ class Graph:
id_to_new_node = {} # old node ID -> new node object
for old_node in self.hidden_nodes:
new_node = Node(new_graph, self.model._uid(), None, old_node.operation, _internal=True)._register()
new_node = Node(new_graph, uid(), None, old_node.operation, _internal=True)._register()
new_node.update_label(old_node.label)
id_to_new_node[old_node.id] = new_node
......@@ -440,7 +435,7 @@ class Graph:
@staticmethod
def _load(model: Model, name: str, ir: Any) -> 'Graph':
graph = Graph(model, model._uid(), name, _internal=True)
graph = Graph(model, uid(), name, _internal=True)
graph.input_node.operation.io_names = ir.get('inputs')
graph.output_node.operation.io_names = ir.get('outputs')
for node_name, node_data in ir['nodes'].items():
......@@ -501,6 +496,8 @@ class Node:
self.graph: Graph = graph
self.id: int = node_id
self.name: str = name or f'_generated_{node_id}'
# TODO: the operation is likely to be considered editable by end-user and it will be hard to debug
# maybe we should copy it here or make Operation class immutable, in next release
self.operation: Operation = operation
self.label: str = None
......@@ -577,7 +574,7 @@ class Node:
op = Cell(ir['operation']['cell_name'], ir['operation'].get('parameters', {}))
else:
op = Operation.new(ir['operation']['type'], ir['operation'].get('parameters', {}))
node = Node(graph, graph.model._uid(), name, op)
node = Node(graph, uid(), name, op)
if 'label' in ir:
node.update_label(ir['label'])
return node
......@@ -626,11 +623,7 @@ class Edge:
If the node does not care about order, this can be `-1`.
"""
def __init__(
self,
head: Tuple[Node, Optional[int]],
tail: Tuple[Node, Optional[int]],
_internal: bool = False):
def __init__(self, head: EdgeEndpoint, tail: EdgeEndpoint, _internal: bool = False):
assert _internal, '`Edge()` is private'
self.graph: Graph = head[0].graph
self.head: Node = head[0]
......
from collections import defaultdict
import inspect
def import_(target: str, allow_none: bool = False) -> 'Any':
......@@ -7,7 +8,6 @@ def import_(target: str, allow_none: bool = False) -> 'Any':
module = __import__(path, globals(), locals(), [identifier])
return getattr(module, identifier)
_records = {}
def get_records():
......@@ -83,3 +83,9 @@ def register_trainer():
return m
return _register
_last_uid = defaultdict(int)
def uid(namespace: str = 'default') -> int:
_last_uid[namespace] += 1
return _last_uid[namespace]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment