"...git@developer.sourcefind.cn:modelzoo/solov2-pytorch.git" did not exist on "b6712d4a9abe261b34b6a62f89ed3ed1fb88fae1"
Unverified Commit b3cdee85 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

fix uid duplicate and add type hint alias for edge endpoint (#3188)

parent 192a807b
...@@ -9,15 +9,19 @@ from collections import defaultdict ...@@ -9,15 +9,19 @@ from collections import defaultdict
from typing import (Any, Dict, List, Optional, Tuple, Union, overload) from typing import (Any, Dict, List, Optional, Tuple, Union, overload)
from .operation import Cell, Operation, _IOPseudoOperation from .operation import Cell, Operation, _IOPseudoOperation
from .utils import uid
__all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'IllegalGraphError', 'MetricData'] __all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'IllegalGraphError', 'MetricData']
MetricData = Any MetricData = Any
""" """
Graph metrics like loss, accuracy, etc. Type hint for graph metrics (loss, accuracy, etc).
"""
# Maybe we can assume this is a single float number for first iteration. EdgeEndpoint = Tuple['Node', Optional[int]]
"""
Type hint for edge's endpoint. The int indicates nodes' order.
""" """
...@@ -88,12 +92,10 @@ class Model: ...@@ -88,12 +92,10 @@ class Model:
intermediate_metrics intermediate_metrics
Intermediate training metrics. If the model is not trained, it's an empty list. Intermediate training metrics. If the model is not trained, it's an empty list.
""" """
_cur_model_id = 0
def __init__(self, _internal=False): def __init__(self, _internal=False):
assert _internal, '`Model()` is private, use `model.fork()` instead' assert _internal, '`Model()` is private, use `model.fork()` instead'
Model._cur_model_id += 1 self.model_id: int = uid('model')
self.model_id = Model._cur_model_id
self.status: ModelStatus = ModelStatus.Mutating self.status: ModelStatus = ModelStatus.Mutating
...@@ -106,8 +108,6 @@ class Model: ...@@ -106,8 +108,6 @@ class Model:
self.metric: Optional[MetricData] = None self.metric: Optional[MetricData] = None
self.intermediate_metrics: List[MetricData] = [] self.intermediate_metrics: List[MetricData] = []
self._last_uid: int = 0 # FIXME: this should be global, not model-wise
def __repr__(self): def __repr__(self):
return f'Model(model_id={self.model_id}, status={self.status}, graphs={list(self.graphs.keys())}, ' + \ return f'Model(model_id={self.model_id}, status={self.status}, graphs={list(self.graphs.keys())}, ' + \
f'training_config={self.training_config}, metric={self.metric}, intermediate_metrics={self.intermediate_metrics})' f'training_config={self.training_config}, metric={self.metric}, intermediate_metrics={self.intermediate_metrics})'
...@@ -130,13 +130,8 @@ class Model: ...@@ -130,13 +130,8 @@ class Model:
new_model.graphs = {name: graph._fork_to(new_model) for name, graph in self.graphs.items()} new_model.graphs = {name: graph._fork_to(new_model) for name, graph in self.graphs.items()}
new_model.training_config = copy.deepcopy(self.training_config) new_model.training_config = copy.deepcopy(self.training_config)
new_model.history = self.history + [self] new_model.history = self.history + [self]
new_model._last_uid = self._last_uid
return new_model return new_model
def _uid(self) -> int:
self._last_uid += 1
return self._last_uid
@staticmethod @staticmethod
def _load(ir: Any) -> 'Model': def _load(ir: Any) -> 'Model':
model = Model(_internal=True) model = Model(_internal=True)
...@@ -295,7 +290,7 @@ class Graph: ...@@ -295,7 +290,7 @@ class Graph:
op = operation_or_type op = operation_or_type
else: else:
op = Operation.new(operation_or_type, parameters, name) op = Operation.new(operation_or_type, parameters, name)
return Node(self, self.model._uid(), name, op, _internal=True)._register() return Node(self, uid(), name, op, _internal=True)._register()
@overload @overload
def insert_node_on_edge(self, edge: 'Edge', name: str, operation: Operation) -> 'Node': ... def insert_node_on_edge(self, edge: 'Edge', name: str, operation: Operation) -> 'Node': ...
...@@ -307,7 +302,7 @@ class Graph: ...@@ -307,7 +302,7 @@ class Graph:
op = operation_or_type op = operation_or_type
else: else:
op = Operation.new(operation_or_type, parameters, name) op = Operation.new(operation_or_type, parameters, name)
new_node = Node(self, self.model._uid(), name, op, _internal=True)._register() new_node = Node(self, uid(), name, op, _internal=True)._register()
# update edges # update edges
self.add_edge((edge.head, edge.head_slot), (new_node, None)) self.add_edge((edge.head, edge.head_slot), (new_node, None))
self.add_edge((new_node, None), (edge.tail, edge.tail_slot)) self.add_edge((new_node, None), (edge.tail, edge.tail_slot))
...@@ -315,7 +310,7 @@ class Graph: ...@@ -315,7 +310,7 @@ class Graph:
return new_node return new_node
# mutation # mutation
def add_edge(self, head: Tuple['Node', Optional[int]], tail: Tuple['Node', Optional[int]]) -> 'Edge': def add_edge(self, head: EdgeEndpoint, tail: EdgeEndpoint) -> 'Edge':
assert head[0].graph is self and tail[0].graph is self assert head[0].graph is self and tail[0].graph is self
return Edge(head, tail, _internal=True)._register() return Edge(head, tail, _internal=True)._register()
...@@ -414,7 +409,7 @@ class Graph: ...@@ -414,7 +409,7 @@ class Graph:
def _copy(self) -> 'Graph': def _copy(self) -> 'Graph':
# Copy this graph inside the model. # Copy this graph inside the model.
# The new graph will have identical topology, but its nodes' name and ID will be different. # The new graph will have identical topology, but its nodes' name and ID will be different.
new_graph = Graph(self.model, self.model._uid(), _internal=True)._register() new_graph = Graph(self.model, uid(), _internal=True)._register()
new_graph.input_node.operation.io_names = self.input_node.operation.io_names new_graph.input_node.operation.io_names = self.input_node.operation.io_names
new_graph.output_node.operation.io_names = self.output_node.operation.io_names new_graph.output_node.operation.io_names = self.output_node.operation.io_names
new_graph.input_node.update_label(self.input_node.label) new_graph.input_node.update_label(self.input_node.label)
...@@ -423,7 +418,7 @@ class Graph: ...@@ -423,7 +418,7 @@ class Graph:
id_to_new_node = {} # old node ID -> new node object id_to_new_node = {} # old node ID -> new node object
for old_node in self.hidden_nodes: for old_node in self.hidden_nodes:
new_node = Node(new_graph, self.model._uid(), None, old_node.operation, _internal=True)._register() new_node = Node(new_graph, uid(), None, old_node.operation, _internal=True)._register()
new_node.update_label(old_node.label) new_node.update_label(old_node.label)
id_to_new_node[old_node.id] = new_node id_to_new_node[old_node.id] = new_node
...@@ -440,7 +435,7 @@ class Graph: ...@@ -440,7 +435,7 @@ class Graph:
@staticmethod @staticmethod
def _load(model: Model, name: str, ir: Any) -> 'Graph': def _load(model: Model, name: str, ir: Any) -> 'Graph':
graph = Graph(model, model._uid(), name, _internal=True) graph = Graph(model, uid(), name, _internal=True)
graph.input_node.operation.io_names = ir.get('inputs') graph.input_node.operation.io_names = ir.get('inputs')
graph.output_node.operation.io_names = ir.get('outputs') graph.output_node.operation.io_names = ir.get('outputs')
for node_name, node_data in ir['nodes'].items(): for node_name, node_data in ir['nodes'].items():
...@@ -501,6 +496,8 @@ class Node: ...@@ -501,6 +496,8 @@ class Node:
self.graph: Graph = graph self.graph: Graph = graph
self.id: int = node_id self.id: int = node_id
self.name: str = name or f'_generated_{node_id}' self.name: str = name or f'_generated_{node_id}'
# TODO: the operation is likely to be considered editable by end-user and it will be hard to debug
# maybe we should copy it here or make Operation class immutable, in next release
self.operation: Operation = operation self.operation: Operation = operation
self.label: str = None self.label: str = None
...@@ -577,7 +574,7 @@ class Node: ...@@ -577,7 +574,7 @@ class Node:
op = Cell(ir['operation']['cell_name'], ir['operation'].get('parameters', {})) op = Cell(ir['operation']['cell_name'], ir['operation'].get('parameters', {}))
else: else:
op = Operation.new(ir['operation']['type'], ir['operation'].get('parameters', {})) op = Operation.new(ir['operation']['type'], ir['operation'].get('parameters', {}))
node = Node(graph, graph.model._uid(), name, op) node = Node(graph, uid(), name, op)
if 'label' in ir: if 'label' in ir:
node.update_label(ir['label']) node.update_label(ir['label'])
return node return node
...@@ -626,11 +623,7 @@ class Edge: ...@@ -626,11 +623,7 @@ class Edge:
If the node does not care about order, this can be `-1`. If the node does not care about order, this can be `-1`.
""" """
def __init__( def __init__(self, head: EdgeEndpoint, tail: EdgeEndpoint, _internal: bool = False):
self,
head: Tuple[Node, Optional[int]],
tail: Tuple[Node, Optional[int]],
_internal: bool = False):
assert _internal, '`Edge()` is private' assert _internal, '`Edge()` is private'
self.graph: Graph = head[0].graph self.graph: Graph = head[0].graph
self.head: Node = head[0] self.head: Node = head[0]
......
from collections import defaultdict
import inspect import inspect
def import_(target: str, allow_none: bool = False) -> 'Any': def import_(target: str, allow_none: bool = False) -> 'Any':
...@@ -7,7 +8,6 @@ def import_(target: str, allow_none: bool = False) -> 'Any': ...@@ -7,7 +8,6 @@ def import_(target: str, allow_none: bool = False) -> 'Any':
module = __import__(path, globals(), locals(), [identifier]) module = __import__(path, globals(), locals(), [identifier])
return getattr(module, identifier) return getattr(module, identifier)
_records = {} _records = {}
def get_records(): def get_records():
...@@ -83,3 +83,9 @@ def register_trainer(): ...@@ -83,3 +83,9 @@ def register_trainer():
return m return m
return _register return _register
_last_uid = defaultdict(int)
def uid(namespace: str = 'default') -> int:
_last_uid[namespace] += 1
return _last_uid[namespace]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment