Unverified Commit a0fd0036 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Merge pull request #5036 from microsoft/promote-retiarii-to-nas

[DO NOT SQUASH] Promote retiarii to NAS
parents d6dcb483 bc6d8796
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import time
import warnings
from typing import Iterable
from nni.nas.execution.common import (
Model, ModelStatus,
AbstractExecutionEngine,
DefaultListener
)
# Process-wide singletons. `_execution_engine` is assigned by `set_execution_engine()`;
# `_default_listener` is created on first use by `get_and_register_default_listener()`.
_execution_engine = None
_default_listener = None
__all__ = ['get_execution_engine', 'get_and_register_default_listener',
'list_models', 'submit_models', 'wait_models', 'query_available_resources',
'set_execution_engine', 'is_stopped_exec', 'budget_exhausted']
def set_execution_engine(engine: AbstractExecutionEngine) -> None:
    """
    Install ``engine`` as the process-wide execution engine.

    If an engine has already been installed, a ``RuntimeWarning`` is emitted
    and the previous engine is replaced anyway.
    """
    global _execution_engine
    already_set = _execution_engine is not None
    if already_set:
        message = (
            'Execution engine is already set. '
            'You should avoid instantiating RetiariiExperiment twice in one process. '
            'If you are running in a Jupyter notebook, please restart the kernel.'
        )
        warnings.warn(message, RuntimeWarning)
    _execution_engine = engine
def get_execution_engine() -> AbstractExecutionEngine:
    """
    Return the process-wide execution engine.

    Fails with an ``AssertionError`` when no engine has been installed yet.
    """
    engine = _execution_engine
    assert engine is not None, 'You need to set execution engine, before using it.'
    return engine
def get_and_register_default_listener(engine: AbstractExecutionEngine) -> DefaultListener:
    """
    Return the shared ``DefaultListener``, creating it on first use.

    The listener is registered on ``engine`` only at the moment it is created;
    later calls return the cached instance without registering it again.
    """
    global _default_listener
    if _default_listener is not None:
        return _default_listener
    # Publish the listener globally before registering it, as the order matters
    # if registration raises.
    _default_listener = DefaultListener()
    engine.register_graph_listener(_default_listener)
    return _default_listener
def submit_models(*models: Model) -> None:
    """
    Hand ``models`` over to the current execution engine.

    The default listener is made sure to exist (and be registered) first.
    """
    active_engine = get_execution_engine()
    get_and_register_default_listener(active_engine)
    active_engine.submit_models(*models)
def list_models(*models: Model) -> Iterable[Model]:
    """
    Return whatever the current execution engine reports from ``list_models()``.

    The positional ``models`` are accepted but not used by this function.
    """
    active_engine = get_execution_engine()
    get_and_register_default_listener(active_engine)
    known_models = active_engine.list_models()
    return known_models
def wait_models(*models: Model) -> None:
    """
    Block until every model in ``models`` has stopped executing.

    A model counts as stopped once its status is ``Trained`` or ``Failed``
    (see :func:`is_stopped_exec`). Statuses are polled once per second.

    The status check happens *before* sleeping, so the call returns
    immediately when there is nothing left to wait for (including when it is
    called with no models) instead of always paying a one-second delay.
    """
    get_and_register_default_listener(get_execution_engine())
    while not all(is_stopped_exec(model) for model in models):
        time.sleep(1)
def query_available_resources() -> int:
    """
    Return the number of resources the current engine reports as available.

    The engine may answer with either a plain count or a sized collection;
    in the latter case its length is returned.
    """
    reported = get_execution_engine().query_available_resource()
    if isinstance(reported, int):
        return reported
    return len(reported)
def is_stopped_exec(model: Model) -> bool:
    """Tell whether ``model`` has finished executing, i.e. is ``Trained`` or ``Failed``."""
    terminal_statuses = (ModelStatus.Trained, ModelStatus.Failed)
    return model.status in terminal_statuses
def budget_exhausted() -> bool:
    """Forward to the current execution engine's ``budget_exhausted()`` and return its answer."""
    return get_execution_engine().budget_exhausted()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .engine import *
from .graph_op import *
from .graph import *
from .integration_api import *
from .integration import *
from .listener import *
from .utils import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from abc import ABC, abstractmethod, abstractclassmethod
from typing import Any, Iterable, NewType, List, Union, Type
from .graph import Model, MetricData
__all__ = [
'GraphData', 'WorkerInfo', 'MetricData',
'AbstractGraphListener', 'AbstractExecutionEngine'
]
# `NewType` over `Any`: gives static type checkers a distinct name for the
# engine-defined payload; no runtime validation is performed.
GraphData: Type[Any] = NewType('GraphData', Any)
"""
A _serializable_ internal data type defined by execution engine.
Execution engine will submit this kind of data through NNI to worker machine, and train it there.
A `GraphData` object describes a (merged) executable graph.
This is trial's "hyper-parameter" in NNI's term and will be transfered in JSON format.
See `AbstractExecutionEngine` for details.
"""
# Placeholder type for worker descriptions; appears in the return annotation of
# `AbstractExecutionEngine.query_available_resource()`.
WorkerInfo: Type[Any] = NewType('WorkerInfo', Any)
"""
To be designed. Discussion needed.
This describes the properties of a worker machine. (e.g. memory size)
"""
class AbstractGraphListener(ABC):
    """
    Interface for objects that want to be notified about graph events.

    A listener becomes active once it is passed to
    `AbstractExecutionEngine.register_graph_listener()`.
    """

    @abstractmethod
    def on_metric(self, model: Model, metric: MetricData) -> None:
        """
        Called with the final metric of a graph.
        """
        raise NotImplementedError

    @abstractmethod
    def on_intermediate_metric(self, model: Model, metric: MetricData) -> None:
        """
        Called with the most recent intermediate metric of a graph that is being trained.
        """

    @abstractmethod
    def on_training_end(self, model: Model, success: bool) -> None:
        """
        Called when a graph has either been fully trained (``success`` is true)
        or its training process has failed.
        """
class AbstractExecutionEngine(ABC):
    """
    The abstract interface of execution engine.

    Most of these APIs are used by strategy, except `trial_execute_graph`, which is invoked by framework in trial.
    Strategy will get the singleton execution engine object through a global API,
    and use it in either sync or async manner.

    Execution engine is responsible for submitting (maybe-optimized) models to NNI,
    and assigning their metrics to the `Model` object after training.
    Execution engine is also responsible to launch the graph in trial process,
    because it's the only one who understands graph data, or "hyper-parameter" in NNI's term.

    Execution engine will leverage NNI Advisor APIs, which are yet open for discussion.

    In synchronized use case, the strategy will have a loop to call `submit_models` and `wait_models` repeatedly,
    and will receive metrics from `Model` attributes.
    Execution engine could assume that strategy will only submit graph when there are available resources (for now).

    In asynchronized use case, the strategy will register a listener to receive events,
    while still using `submit_models` to train.

    There will be a `BaseExecutionEngine` subclass.
    Inner-graph optimizing is supposed to derive `BaseExecutionEngine`,
    while overrides `submit_models` and `trial_execute_graph`.
    Cross-graph optimizing is supposed to derive `AbstractExecutionEngine` directly,
    because in this case APIs like `wait_graph` and `listener.on_training_end` will have unique logic.

    There might be some util functions benefit all optimizing methods,
    but non-mandatory utils should not be covered in abstract interface.
    """

    @abstractmethod
    def submit_models(self, *models: Model) -> None:
        """
        Submit models to NNI.

        This method is supposed to call something like `nni.Advisor.create_trial_job(graph_data)`.
        """
        raise NotImplementedError

    @abstractmethod
    def list_models(self) -> Iterable[Model]:
        """
        Get all models that have been submitted.

        Execution engine should store a copy of models that have been submitted and return a list of copies in this method.
        """
        raise NotImplementedError

    @abstractmethod
    def query_available_resource(self) -> Union[List[WorkerInfo], int]:  # type: ignore
        """
        Returns information of all idle workers.
        If no details are available, this may return a list of "empty" objects, reporting the number of idle workers.

        Could be left unimplemented for first iteration.
        """
        raise NotImplementedError

    @abstractmethod
    def budget_exhausted(self) -> bool:
        """
        Check whether user configured max trial number or max execution duration has been reached.
        """
        raise NotImplementedError

    @abstractmethod
    def register_graph_listener(self, listener: AbstractGraphListener) -> None:
        """
        Register a listener to receive graph events.

        Could be left unimplemented for first iteration.
        """
        raise NotImplementedError

    # `abc.abstractclassmethod` is deprecated; stacking `classmethod` on top of
    # `abstractmethod` is the supported spelling and behaves the same.
    @classmethod
    @abstractmethod
    def trial_execute_graph(cls) -> MetricData:
        """
        Train graph and returns its metrics, in a separate trial process.

        Each call to `nni.Advisor.create_trial_job(graph_data)` will eventually invoke this method.

        Because this method will be invoked in trial process on training platform,
        it has different context from other methods and has no access to global variable or `self`.
        However util APIs like `.utils.experiment_config()` should still be available.
        """
        raise NotImplementedError
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Model representation for engines based on graph.
"""
from __future__ import annotations
import json
from enum import Enum
from typing import (TYPE_CHECKING, Any, Dict, Iterable, List,
Optional, Set, Tuple, Type, Union, cast, overload)
if TYPE_CHECKING:
from .mutator import Mutator
from nni.nas.evaluator import Evaluator
from nni.nas.utils import uid
from .graph_op import Cell, Operation, _IOPseudoOperation
__all__ = [
'Evaluator', 'Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'Mutation', 'IllegalGraphError', 'MetricData',
'DebugEvaluator',
]
# Plain alias of `Any`: metrics are not constrained to a particular type here.
MetricData = Any
"""
Type hint for graph metrics (loss, accuracy, etc).
"""
# A (node, slot) pair; `Edge.__init__` unpacks it as `head[0]` / `head[1]`.
EdgeEndpoint = Tuple['Node', Optional[int]]
"""
Type hint for edge's endpoint. The int indicates nodes' order.
"""
class Model:
    """
    Represents a neural network model.

    During mutation, one :class:`Model` object is created for each trainable snapshot.
    For example, consider a mutator that insert a node at an edge for each iteration.
    In one iteration, the mutator invokes 4 primitives: add node, remove edge, add edge to head, add edge to tail.
    These 4 primitives operates in one :class:`Model` object.
    When they are all done the model will be set to "frozen" (trainable) status and be submitted to execution engine.
    And then a new iteration starts, and a new :class:`Model` object is created by forking last model.

    Attributes
    ----------
    python_object
        Python object of base model. It will be none when the base model is not available.
    python_class
        Python class that base model is converted from.
    python_init_params
        Initialization parameters of python class.
    status
        See :class:`ModelStatus`.
    root_graph
        The outermost graph which usually takes dataset as input and feeds output to loss function.
    graphs
        All graphs (subgraphs) in this model.
    evaluator
        Model evaluator
    history
        Mutation history.
        ``self`` is directly mutated from ``self.history[-1]``;
        ``self.history[-1]`` is mutated from ``self.history[-2]``, and so on.
        ``self.history[0]`` is the base graph.
    metric
        Training result of the model, or ``None`` if it's not yet trained or has failed to train.
    intermediate_metrics
        Intermediate training metrics. If the model is not trained, it's an empty list.
    """

    def __init__(self, _internal=False):
        assert _internal, '`Model()` is private, use `model.fork()` instead'
        self.model_id: int = uid('model')
        self.python_object: Optional[Any] = None  # type is uncertain because it could differ between DL frameworks
        self.python_class: Optional[Type] = None
        self.python_init_params: Optional[Dict[str, Any]] = None

        self.status: ModelStatus = ModelStatus.Mutating

        self._root_graph_name: str = '_model'
        self.graphs: Dict[str, Graph] = {}
        self.evaluator: Optional[Evaluator] = None

        self.history: List['Mutation'] = []

        self.metric: Optional[MetricData] = None
        self.intermediate_metrics: List[MetricData] = []

    def __repr__(self):
        return f'Model(model_id={self.model_id}, status={self.status}, graphs={list(self.graphs.keys())}, ' + \
            f'evaluator={self.evaluator}, metric={self.metric}, intermediate_metrics={self.intermediate_metrics}, ' + \
            f'python_class={self.python_class})'

    @property
    def root_graph(self) -> 'Graph':
        """The graph registered under the root graph name."""
        return self.graphs[self._root_graph_name]

    def fork(self) -> 'Model':
        """
        Create a new model which has same topology, names, and IDs to current one.

        Can only be invoked on a frozen model.
        The new model will be in `Mutating` state.

        This API is used in mutator base class.
        """
        new_model = Model(_internal=True)
        new_model._root_graph_name = self._root_graph_name
        new_model.python_class = self.python_class
        new_model.python_init_params = self.python_init_params
        new_model.graphs = {name: graph._fork_to(new_model) for name, graph in self.graphs.items()}
        new_model.evaluator = self.evaluator  # TODO this needs a clever copy (not deepcopy) if we need mutation
        # Shallow copy, so appending to the fork's history leaves this model's history untouched.
        new_model.history = [*self.history]
        # Note: the history is not updated. It will be updated when the model is changed, that is in mutator.
        return new_model

    @staticmethod
    def _load(ir: Any) -> 'Model':
        """
        Build a model from its dumped form.

        Every key of ``ir`` other than ``'_evaluator'`` is loaded as a graph of that name;
        ``'_evaluator'``, when present, is loaded through ``Evaluator._load``.
        """
        model = Model(_internal=True)
        for graph_name, graph_data in ir.items():
            if graph_name != '_evaluator':
                Graph._load(model, graph_name, graph_data)._register()
        if '_evaluator' in ir:
            model.evaluator = Evaluator._load(ir['_evaluator'])
        return model

    def _dump(self) -> Any:
        """Dump all graphs by name, plus the evaluator under ``'_evaluator'`` when one is set."""
        ret = {name: graph._dump() for name, graph in self.graphs.items()}
        if self.evaluator is not None:
            ret['_evaluator'] = self.evaluator._dump()
        return ret

    def get_nodes(self) -> Iterable['Node']:
        """
        Traverse through all the nodes.
        """
        for graph in self.graphs.values():
            yield from graph.nodes

    def get_nodes_by_label(self, label: str) -> List['Node']:
        """
        Traverse all the nodes to find the matched node(s) with the given label.
        There could be multiple nodes with the same label. Name space name can uniquely
        identify a graph or node.

        NOTE: the implementation does not support the class abstraction
        """
        matched_nodes = []
        for graph in self.graphs.values():
            matched_nodes.extend(graph.get_nodes_by_label(label))
        return matched_nodes

    def get_nodes_by_type(self, type_name: str) -> List['Node']:
        """
        Traverse all the nodes to find the matched node(s) with the given type.
        """
        matched_nodes = []
        for graph in self.graphs.values():
            matched_nodes.extend(graph.get_nodes_by_type(type_name))
        return matched_nodes

    def get_node_by_name(self, node_name: str) -> Optional['Node']:
        """
        Traverse all the nodes to find the matched node with the given name.

        Returns ``None`` when no node matches. More than one match is treated as
        an internal error (assertion).
        """
        matched_nodes = []
        for graph in self.graphs.values():
            matched_nodes.extend(graph.get_nodes_by_name(node_name))
        assert len(matched_nodes) <= 1
        return matched_nodes[0] if matched_nodes else None

    def get_node_by_python_name(self, python_name: str) -> Optional['Node']:
        """
        Traverse all the nodes to find the matched node with the given python_name.

        When several nodes match, the first one found is returned.
        """
        matched_nodes = []
        for graph in self.graphs.values():
            matched_nodes.extend(graph.get_nodes_by_python_name(python_name))
        # assert len(matched_nodes) <= 1
        return matched_nodes[0] if matched_nodes else None

    def get_cell_nodes(self) -> List['Node']:
        """Return every node, across all graphs, whose operation is a ``Cell``."""
        matched_nodes = []
        for graph in self.graphs.values():
            matched_nodes.extend(node for node in graph.nodes if isinstance(node.operation, Cell))
        return matched_nodes
class ModelStatus(Enum):
    """
    Life-cycle states of a model.

    * ``Mutating``: the state a model is created in.
    * ``Frozen``: mutation is finished and the model is ready to be trained.
    * ``Training``: training has started.
    * ``Trained``: training ended successfully; the model's ``metric`` attribute has been set.
    * ``Failed``: training failed.
    """

    Mutating = 'mutating'
    Frozen = 'frozen'
    Training = 'training'
    Trained = 'trained'
    Failed = 'failed'
# Fixed node IDs that `Graph.__init__` assigns to every graph's `_inputs` and
# `_outputs` pseudo nodes, respectively.
_InputPseudoUid = -1
_OutputPseudoUid = -2
class Graph:
    """
    Graph topology.

    This class simply represents the topology, with no semantic meaning.
    All other information like metric, non-graph functions, mutation history, etc should go to :class:`Model`.

    Each graph belongs to and only belongs to one :class:`Model`.

    Attributes
    ----------
    model
        The model containing (and owning) this graph.
    id
        Unique ID in the model.
        If two models have graphs of identical ID, they are semantically the same graph.
        Typically this means one graph is mutated from another, or they are both mutated from one ancestor.
    name
        Mnemonic name of this graph. It should have an one-to-one mapping with ID.
    input_names
        Optional mnemonic names of input parameters.
    output_names
        Optional mnemonic names of output values.
    input_node
        Incoming node.
    output_node
        Output node.
    hidden_nodes
        Hidden nodes
    nodes
        All input/output/hidden nodes.
    edges
        Edges.
    python_name
        The name of torch.nn.Module, should have one-to-one mapping with items in python model.
    """

    def __init__(self, model: Model, graph_id: int, name: str = cast(str, None), _internal: bool = False):
        assert _internal, '`Graph()` is private'

        self.model: Model = model
        self.id: int = graph_id
        self.name: str = name or f'_generated_{graph_id}'

        # `python_name` is `None` by default. It should be set after initialization if it is needed.
        self.python_name: Optional[str] = None

        self.input_node: Node = Node(self, _InputPseudoUid, '_inputs', _IOPseudoOperation('_inputs'), _internal=True)
        self.output_node: Node = Node(self, _OutputPseudoUid, '_outputs', _IOPseudoOperation('_outputs'), _internal=True)
        self.hidden_nodes: List[Node] = []

        self.edges: List[Edge] = []

    def __repr__(self):
        return f'Graph(id={self.id}, name={self.name}, ' + \
            f'input_names={self.input_node.operation.io_names}, ' + \
            f'output_names={self.output_node.operation.io_names}, ' + \
            f'num_hidden_nodes={len(self.hidden_nodes)}, num_edges={len(self.edges)})'

    @property
    def nodes(self) -> List['Node']:
        """A new list: input node, output node, then the hidden nodes in registration order."""
        return [self.input_node, self.output_node] + self.hidden_nodes

    def _add_input(self, input_name) -> None:
        """Append ``input_name`` to the input node's ``io_names``, creating the list if it is ``None``."""
        if self.input_node.operation.io_names is None:
            self.input_node.operation.io_names = [input_name]
        else:
            self.input_node.operation.io_names.append(input_name)

    def _add_output(self, output_name) -> None:
        """Append ``output_name`` to the output node's ``io_names``, creating the list if it is ``None``."""
        if self.output_node.operation.io_names is None:
            self.output_node.operation.io_names = [output_name]
        else:
            self.output_node.operation.io_names.append(output_name)

    @overload
    def add_node(self, name: str, operation: Operation) -> 'Node': ...
    @overload
    def add_node(self, name: str, type_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Node': ...

    def add_node(self, name, operation_or_type, parameters=None):  # type: ignore
        """
        Create a hidden node and register it in this graph.

        ``operation_or_type`` is either a ready-made ``Operation`` or an operation type name;
        in the latter case the operation is built with ``Operation.new`` from the type name and ``parameters``.
        """
        if isinstance(operation_or_type, Operation):
            op = operation_or_type
        else:
            op = Operation.new(operation_or_type, cast(dict, parameters), name)
        return Node(self, uid(), name, op, _internal=True)._register()

    @overload
    def insert_node_on_edge(self, edge: 'Edge', name: str, operation: Operation) -> 'Node': ...
    @overload
    def insert_node_on_edge(self, edge: 'Edge', name: str, type_name: str,
                            parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Node': ...

    def insert_node_on_edge(self, edge, name, operation_or_type, parameters=None) -> 'Node':  # type: ignore
        """
        Create a hidden node and splice it into ``edge``.

        ``edge`` is replaced by two edges: old head -> new node and new node -> old tail.
        The slots of the original edge are kept on the outer ends.
        """
        if isinstance(operation_or_type, Operation):
            op = operation_or_type
        else:
            op = Operation.new(operation_or_type, cast(dict, parameters), name)
        new_node = Node(self, uid(), name, op, _internal=True)._register()
        # update edges
        self.add_edge((edge.head, edge.head_slot), (new_node, None))
        self.add_edge((new_node, None), (edge.tail, edge.tail_slot))
        self.del_edge(edge)
        return new_node

    # mutation
    def add_edge(self, head: EdgeEndpoint, tail: EdgeEndpoint) -> 'Edge':
        """Create and register an edge; both endpoints must belong to this graph."""
        assert head[0].graph is self and tail[0].graph is self
        return Edge(head, tail, _internal=True)._register()

    def del_edge(self, edge: 'Edge') -> None:
        """Remove ``edge`` from this graph's edge list."""
        self.edges.remove(edge)

    def get_node_by_name(self, name: str) -> Optional['Node']:
        """
        Returns the node which has specified name; or returns `None` if no node has this name.
        """
        found = [node for node in self.nodes if node.name == name]
        return found[0] if found else None

    def get_node_by_python_name(self, python_name: str) -> Optional['Node']:
        """
        Returns the node which has specified python_name; or returns `None` if no node has this python_name.
        """
        found = [node for node in self.nodes if node.python_name == python_name]
        return found[0] if found else None

    def get_nodes_by_type(self, operation_type: str) -> List['Node']:
        """
        Returns hidden nodes whose operation is of the specified type.
        """
        return [node for node in self.hidden_nodes if node.operation.type == operation_type]

    def get_node_by_id(self, node_id: int) -> Optional['Node']:
        """
        Returns the node which has specified ID; or returns `None` if no node has this ID.
        """
        found = [node for node in self.nodes if node.id == node_id]
        return found[0] if found else None

    def get_nodes_by_label(self, label: str) -> List['Node']:
        """Returns hidden nodes carrying the given label."""
        return [node for node in self.hidden_nodes if node.label == label]

    def get_nodes_by_name(self, name: str) -> List['Node']:
        """Returns hidden nodes with the given name (input/output nodes are not searched)."""
        return [node for node in self.hidden_nodes if node.name == name]

    def get_nodes_by_python_name(self, python_name: str) -> List['Node']:
        """Returns all nodes (including input/output) with the given python_name."""
        return [node for node in self.nodes if node.python_name == python_name]

    def topo_sort(self) -> List['Node']:
        """
        Return all nodes in a topological order.

        Fan-in is counted per incoming edge, and it is also decremented per outgoing edge,
        so several edges between the same pair of nodes (even into the same tail slot)
        are handled. Fails with an assertion if some node can never be reached with
        fan-in zero, e.g. when the graph has a cycle.
        """
        node_to_fanin = {}
        sorted_nodes = []
        for node in self.nodes:
            fanin = len(node.incoming_edges)
            node_to_fanin[node] = fanin
            if fanin == 0:
                sorted_nodes.append(node)

        # `sorted_nodes` doubles as the FIFO work queue: everything before `cursor` is done.
        cursor = 0
        while cursor < len(sorted_nodes):
            curr_node = sorted_nodes[cursor]
            cursor += 1
            # Walk edges rather than unique successors: one decrement per edge mirrors
            # how the fan-in was counted above.
            for edge in curr_node.outgoing_edges:
                successor = edge.tail
                node_to_fanin[successor] -= 1
                if node_to_fanin[successor] == 0:
                    sorted_nodes.append(successor)

        for key in node_to_fanin:
            assert node_to_fanin[key] == 0, '{}, fanin: {}, predecessor: {}, edges: {}, fanin: {}, keys: {}'.format(
                key,
                node_to_fanin[key],
                key.predecessors[0],
                self.edges,
                node_to_fanin.values(),
                node_to_fanin.keys())
        return sorted_nodes

    def fork(self) -> 'Graph':
        """
        Fork the model and returns corresponding graph in new model.
        This shortcut might be helpful because many algorithms only cares about "stem" subgraph instead of whole model.
        """
        return self.model.fork().graphs[self.name]

    def __eq__(self, other: object) -> bool:
        return self is other

    def __hash__(self) -> int:
        # Defining `__eq__` alone would make graphs unhashable; keep hashing
        # identity-based, matching `__eq__` (and `Node.__hash__`).
        return hash(id(self))

    def _fork_to(self, model: Model, name_prefix='') -> 'Graph':
        """
        Re-create this graph inside ``model``, keeping graph/node IDs and names
        (graph name optionally prefixed with ``name_prefix``).
        """
        new_graph = Graph(model, self.id, name_prefix + self.name, _internal=True)._register()
        # TODO: use node copy instead
        # NOTE(review): the `io_names` list objects are shared with the source graph, not copied.
        new_graph.input_node.operation.io_names = self.input_node.operation.io_names
        new_graph.output_node.operation.io_names = self.output_node.operation.io_names
        new_graph.input_node.update_label(self.input_node.label)
        new_graph.output_node.update_label(self.output_node.label)
        new_graph.python_name = self.python_name

        for node in self.hidden_nodes:
            new_node = Node(new_graph, node.id, node.name, node.operation, _internal=True)
            new_node.python_name = node.python_name
            new_node.update_label(node.label)
            new_node._register()

        id_to_new_node = {node.id: node for node in new_graph.nodes}

        for edge in self.edges:
            new_head = id_to_new_node[edge.head.id]
            new_tail = id_to_new_node[edge.tail.id]
            Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()

        return new_graph

    def _copy(self) -> 'Graph':
        # Copy this graph inside the model.
        # The new graph will have identical topology, but its nodes' name and ID will be different.
        new_graph = Graph(self.model, uid(), _internal=True)._register()
        new_graph.input_node.operation.io_names = self.input_node.operation.io_names
        new_graph.output_node.operation.io_names = self.output_node.operation.io_names
        new_graph.input_node.update_label(self.input_node.label)
        new_graph.output_node.update_label(self.output_node.label)
        new_graph.python_name = self.python_name

        # old node ID -> new node object
        # Seed with the input/output pseudo nodes so that edges attached to them
        # can be re-created too (they are not part of `hidden_nodes`).
        id_to_new_node = {
            self.input_node.id: new_graph.input_node,
            self.output_node.id: new_graph.output_node,
        }

        for old_node in self.hidden_nodes:
            new_node = Node(new_graph, uid(), None, old_node.operation, _internal=True)._register()
            new_node.python_name = old_node.python_name
            new_node.update_label(old_node.label)
            id_to_new_node[old_node.id] = new_node

        for edge in self.edges:
            new_head = id_to_new_node[edge.head.id]
            new_tail = id_to_new_node[edge.tail.id]
            Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()

        return new_graph

    def _register(self) -> 'Graph':
        """Store this graph in its model under ``self.name`` and return ``self``."""
        self.model.graphs[self.name] = self
        return self

    def _rename_graph(self, old_name, new_name):
        """Rename the model's graph registered as ``old_name`` to ``new_name`` and re-key it."""
        self.model.graphs[old_name].name = new_name
        self.model.graphs[new_name] = self.model.graphs[old_name]
        del self.model.graphs[old_name]

    @staticmethod
    def _load(model: Model, name: str, ir: Any) -> 'Graph':
        """
        Build a graph named ``name`` for ``model`` from its dumped form.

        The returned graph is *not* registered in the model; nodes and edges are
        registered in the graph.
        """
        graph = Graph(model, uid(), name, _internal=True)
        graph.input_node.operation.io_names = ir.get('inputs')
        graph.output_node.operation.io_names = ir.get('outputs')
        for node_name, node_data in ir['nodes'].items():
            Node._load(graph, node_name, node_data)._register()
        for edge_data in ir['edges']:
            Edge._load(graph, edge_data)._register()
        return graph

    def _dump(self) -> Any:
        """Dump input/output names, hidden nodes keyed by name, and the edge list."""
        return {
            'inputs': self.input_node.operation.io_names,
            'outputs': self.output_node.operation.io_names,
            'nodes': {node.name: node._dump() for node in self.hidden_nodes},
            'edges': [edge._dump() for edge in self.edges]
        }
class Node:
    """
    A single operation, or an opaque subgraph, placed inside a graph.

    Every node is owned by exactly one :class:`Graph`.
    Do not call the constructor directly; go through :meth:`Graph.add_node`.

    A node only describes topology.
    Anything about the actual tensor computation lives in the ``operation`` attribute.

    TODO: parameter of subgraph (cell)
    It's easy to assign parameters on cell node, but it's hard to "use" them.
    We need to design a way to reference stored cell parameters in inner node operations.
    e.g. ``self.fc = Linear(self.units)``  <-  how to express ``self.units`` in IR?

    Attributes
    ----------
    graph
        The graph containing this node.
    id
        Unique ID in the model.
        If two models have nodes with same ID, they are semantically the same node.
    name
        Mnemonic name. It should have an one-to-one mapping with ID.
    python_name
        The name of torch.nn.Module, should have one-to-one mapping with items in python model.
    label
        Optional. If two nodes have the same label, they are considered same by the mutator.
    operation
        Operation.
    cell
        Read only shortcut to get the referenced subgraph.
        If this node is not a subgraph (is a primitive operation), accessing ``cell`` will raise an error.
    predecessors
        Predecessor nodes of this node in the graph. This is an optional mutation helper.
    successors
        Successor nodes of this node in the graph. This is an optional mutation helper.
    incoming_edges
        Incoming edges of this node in the graph. This is an optional mutation helper.
    outgoing_edges
        Outgoing edges of this node in the graph. This is an optional mutation helper.
    """

    def __init__(self, graph, node_id, name, operation, _internal=False):
        self.graph: Graph = graph
        self.id: int = node_id
        # A falsy name (e.g. ``None``) is replaced by one derived from the ID.
        self.name: str = name or f'_generated_{node_id}'
        # `python_name` starts as `None`; callers assign it afterwards when they need it.
        self.python_name: Optional[str] = None
        # TODO: the operation is likely to be considered editable by end-user and it will be hard to debug
        # maybe we should copy it here or make Operation class immutable, in next release
        self.operation: Operation = operation
        self.label: Optional[str] = None

    def __repr__(self):
        return f'Node(id={self.id}, name={self.name}, python_name={self.python_name}, label={self.label}, operation={self.operation})'

    @property
    def predecessors(self) -> List['Node']:
        """Distinct head nodes of the incoming edges, ordered by node ID."""
        heads = {edge.head for edge in self.incoming_edges}
        return sorted(heads, key=lambda node: node.id)

    @property
    def successors(self) -> List['Node']:
        """Distinct tail nodes of the outgoing edges, ordered by node ID."""
        tails = {edge.tail for edge in self.outgoing_edges}
        return sorted(tails, key=lambda node: node.id)

    @property
    def successor_slots(self) -> Set[Tuple['Node', Union[int, None]]]:
        """Distinct ``(tail node, tail slot)`` pairs of the outgoing edges."""
        return {(edge.tail, edge.tail_slot) for edge in self.outgoing_edges}

    @property
    def incoming_edges(self) -> List['Edge']:
        """Edges of the owning graph whose tail is this very node (identity comparison)."""
        return [edge for edge in self.graph.edges if edge.tail is self]

    @property
    def outgoing_edges(self) -> List['Edge']:
        """Edges of the owning graph whose head is this very node (identity comparison)."""
        return [edge for edge in self.graph.edges if edge.head is self]

    @property
    def cell(self) -> Graph:
        """The subgraph referenced by this node's ``Cell`` operation."""
        assert isinstance(self.operation, Cell)
        cell_graph_name = self.operation.parameters['cell']
        return self.graph.model.graphs[cell_graph_name]

    def update_label(self, label: Optional[str]) -> None:
        """Set (or clear, with ``None``) the label of this node."""
        self.label = label

    @overload
    def update_operation(self, operation: Operation) -> None: ...
    @overload
    def update_operation(self, type_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> None: ...

    def update_operation(self, operation_or_type, parameters=None):  # type: ignore
        """
        Replace this node's operation.

        Accepts either an ``Operation`` instance, used as is, or a type name
        plus parameters, from which a new operation is built via ``Operation.new``.
        """
        if not isinstance(operation_or_type, Operation):
            self.operation = Operation.new(operation_or_type, cast(dict, parameters))
            return
        self.operation = operation_or_type

    # mutation
    def remove(self) -> None:
        """Drop this node from its graph's hidden nodes; it must have no edges attached."""
        assert not self.incoming_edges and not self.outgoing_edges
        self.graph.hidden_nodes.remove(self)

    # mutation
    def specialize_cell(self) -> Graph:
        """
        Only available if the operation is a cell.
        Duplicate the cell template and let this node reference to newly created copy.
        """
        duplicated = self.cell._copy()._register()
        self.operation = Cell(duplicated.name)
        return duplicated

    def __eq__(self, other: object) -> bool:
        return self is other

    def __hash__(self) -> int:
        return hash(id(self))

    def _register(self) -> 'Node':
        """Append this node to the owning graph's hidden nodes and return ``self``."""
        self.graph.hidden_nodes.append(self)
        return self

    @staticmethod
    def _load(graph: Graph, name: str, ir: Any) -> 'Node':
        """
        Build a node for ``graph`` from its dumped form. The node is not registered.

        NOTE(review): `_dump` may emit a ``'python_name'`` key, which is not read
        back here; confirm whether that is intended.
        """
        op_ir = ir['operation']
        if op_ir['type'] == '_cell':
            op = Cell(op_ir['cell_name'], op_ir.get('parameters', {}), attributes=op_ir.get('attributes', {}))
        else:
            op = Operation.new(op_ir['type'],
                               op_ir.get('parameters', {}),
                               attributes=op_ir.get('attributes', {}))
        node = Node(graph, uid(), name, op)
        if 'label' in ir:
            node.update_label(ir['label'])
        return node

    def _dump(self) -> Any:
        """Dump the operation, plus ``label`` / ``python_name`` when they are set."""
        operation_ir: Dict[str, Any] = {
            'type': self.operation.type,
            'parameters': self.operation.parameters,
            'attributes': self.operation.attributes
        }
        if isinstance(self.operation, Cell):
            operation_ir['cell_name'] = self.operation.cell_name
        ret: Dict[str, Any] = {'operation': operation_ir}
        if self.label is not None:
            ret['label'] = self.label
        if self.python_name is not None:
            ret['python_name'] = self.python_name
        return ret
class Edge:
    """
    A tensor, or "data flow", between two nodes.

    Example forward code snippet: ::

        a, b, c = split(x)
        p = concat(a, c)
        q = sum(b, p)
        z = relu(q)

    Edges in above snippet: ::

        + head: (split, 0), tail: (concat, 0)  # a in concat
        + head: (split, 2), tail: (concat, 1)  # c in concat
        + head: (split, 1), tail: (sum, -1 or 0)  # b in sum
        + head: (concat, null), tail: (sum, -1 or 1)  # p in sum
        + head: (sum, null), tail: (relu, null)  # q in relu

    Attributes
    ----------
    graph
        Graph.
    head
        Head node.
    tail
        Tail node.
    head_slot
        Index of outputs in head node.
        If the node has only one output, this should be ``null``.
    tail_slot
        Index of inputs in tail node.
        If the node has only one input, this should be ``null``.
        If the node does not care about order, this can be ``-1``.
    """

    def __init__(self, head: EdgeEndpoint, tail: EdgeEndpoint, _internal: bool = False):
        assert _internal, '`Edge()` is private'
        head_node, head_slot = head[0], head[1]
        tail_node, tail_slot = tail[0], tail[1]
        # The owning graph is taken from the head node.
        self.graph: Graph = head_node.graph
        self.head: Node = head_node
        self.tail: Node = tail_node
        self.head_slot: Optional[int] = head_slot
        self.tail_slot: Optional[int] = tail_slot

    def __repr__(self):
        return f'Edge(head=({self.head}, {self.head_slot}), tail=({self.tail}, {self.tail_slot}))'

    # mutation
    def remove(self) -> None:
        """Take this edge out of its graph's edge list."""
        self.graph.edges.remove(self)

    def _register(self) -> 'Edge':
        """Append this edge to its graph's edge list and return ``self``."""
        self.graph.edges.append(self)
        return self

    @staticmethod
    def _load(graph: Graph, ir: Any) -> 'Edge':
        """
        Build an edge from its dumped form, resolving both endpoints by node name in ``graph``.
        The edge is not registered. Both nodes must exist (assertion).
        """
        head_name, head_slot = ir['head'][0], ir['head'][1]
        tail_name, tail_slot = ir['tail'][0], ir['tail'][1]
        head = graph.get_node_by_name(head_name)
        tail = graph.get_node_by_name(tail_name)
        assert head is not None and tail is not None
        return Edge((head, head_slot), (tail, tail_slot), _internal=True)

    def _dump(self) -> Any:
        """Dump both endpoints as ``[node name, slot]`` pairs."""
        head_ir = [self.head.name, self.head_slot]
        tail_ir = [self.tail.name, self.tail_slot]
        return {'head': head_ir, 'tail': tail_ir}
class Mutation:
    """
    An execution of mutation, which consists of four parts: a mutator, a list of decisions (choices),
    the model that it comes from, and the model that it becomes.

    In general cases, the mutation logs are not reliable and should not be replayed as the mutators can
    be arbitrarily complex. However, for inline mutations, the labels correspond to mutator labels here,
    this can be useful for metadata visualization and python execution mode.

    Attributes
    ----------
    mutator
        Mutator.
    samples
        Decisions/choices.
    from_
        Model that it comes from.
    to
        Model that it becomes.
    """

    def __init__(self, mutator: 'Mutator', samples: List[Any], from_: Model, to: Model):  # noqa: F821
        self.mutator: 'Mutator' = mutator  # noqa: F821
        self.samples: List[Any] = samples
        self.from_: Model = from_
        self.to: Model = to

    def __repr__(self):
        # Report the actual class name; this used to print `Edge(...)`.
        return f'Mutation(mutator={self.mutator}, samples={self.samples}, from={self.from_}, to={self.to})'
class IllegalGraphError(ValueError):
    """
    Raised for a malformed graph.

    Constructing the error writes a JSON dump of the offending graph to
    ``generated/debug.json`` (relative to the current working directory)
    before the usual ``ValueError`` initialization with ``*args``.
    """

    def __init__(self, graph, *args):
        self._debug_dump_graph(graph)
        super().__init__(*args)

    @staticmethod
    def _debug_dump_graph(graph):
        """
        Write ``graph`` to ``generated/debug.json``.

        A :class:`Graph` is converted with its ``_dump()`` first; anything else is
        passed to ``json.dump`` as is.
        """
        import os  # local import: only needed for this debugging aid

        if isinstance(graph, Graph):
            graph = graph._dump()
        dump_path = 'generated/debug.json'
        # Without this, a missing `generated/` directory raises FileNotFoundError
        # here and hides the graph error that is being reported.
        os.makedirs(os.path.dirname(dump_path), exist_ok=True)
        with open(dump_path, 'w') as dump_file:
            json.dump(graph, dump_file, indent=4)
class DebugEvaluator(Evaluator):
    """
    A do-nothing evaluator.

    Loading ignores its input, executing does nothing, and an instance
    compares equal to any object.
    """

    @staticmethod
    def _load(ir: Any) -> 'DebugEvaluator':
        """Return a fresh instance; ``ir`` is not inspected."""
        return DebugEvaluator()

    def _dump(self) -> Any:
        """Return a dict whose ``'type'`` entry is the ``DebugEvaluator`` class itself."""
        dumped = {'type': DebugEvaluator}
        return dumped

    def _execute(self, model_cls: type) -> Any:
        """Do nothing and return ``None``."""
        return None

    def __eq__(self, other) -> bool:
        return True
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Operations used in graph-based engine.
"""
from typing import (Any, Dict, List, Optional, cast)
from nni.common.framework import get_default_framework
__all__ = ['Operation', 'Cell', 'PyTorchOperation', 'TensorFlowOperation']
def _convert_name(name: str) -> str:
"""
Convert the names using separator '.' to valid variable name in code
"""
return name.replace('.', '__')
class Operation:
    """
    Calculation logic of a graph node.

    The constructor is private. Use `Operation.new()` to create operation object.

    `Operation` is a naive record.
    Do not "mutate" its attributes or store information relate to specific node.
    All complex logic should be implemented in `Node` class.

    Attributes
    ----------
    type
        Operation type name (e.g. Conv2D).
        If it starts with underscore, the "operation" is a special one (e.g. subgraph, input/output).
    parameters
        Arbitrary key-value parameters (e.g. kernel_size).
    attributes
        Arbitrary key-value attributes; stored as given and not interpreted by this class.
    """

    io_names: List[str] = []

    def __init__(self, type_name: str, parameters: Optional[Dict[str, Any]] = None, _internal: bool = False,
                 attributes: Optional[Dict[str, Any]] = None):
        assert _internal, '`Operation()` is private, use `Operation.new()` instead'
        self.type: str = type_name
        # A ``{}`` default in the signature would be one dict shared by every instance
        # created without arguments, so a fresh dict is built per instance instead.
        # Dicts passed by the caller are still stored by reference, as before.
        self.parameters: Dict[str, Any] = {} if parameters is None else parameters
        self.attributes: Dict[str, Any] = {} if attributes is None else attributes

    def to_init_code(self, field: str) -> str:
        """Generate the initialization code line; must be provided by subclasses."""
        raise NotImplementedError()

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        """Generate the forward code line; must be provided by subclasses."""
        raise NotImplementedError()

    def _to_class_name(self) -> str:
        raise NotImplementedError()

    def __bool__(self) -> bool:
        return True

    @staticmethod
    def new(type_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None), cell_name: str = cast(str, None),
            attributes: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Operation':
        """
        Create an operation object for ``type_name``.

        ``'_cell'`` yields a ``Cell``; any other type is resolved to a subclass of the
        operation base class of the current default framework.

        Raises
        ------
        ValueError
            If the default framework is neither PyTorch nor TensorFlow.
        """
        parameters = parameters or {}
        attributes = attributes or {}
        if type_name == '_cell':
            # NOTE: cell_name is the same as its Node's name, when the cell is wrapped within the node
            return Cell(cell_name, parameters)
        else:
            if get_default_framework() in ('torch', 'pytorch'):
                # Imported for its side effect of defining the operation subclasses.
                from nni.nas.execution.pytorch import op_def  # pylint: disable=unused-import
                cls = PyTorchOperation._find_subclass(type_name)
            elif get_default_framework() in ('tf', 'tensorflow'):
                from nni.nas.execution.tensorflow import op_def  # pylint: disable=unused-import
                cls = TensorFlowOperation._find_subclass(type_name)
            else:
                raise ValueError(f'Unsupported framework: {get_default_framework()}')
            return cls(type_name, parameters, _internal=True, attributes=attributes)

    @classmethod
    def _find_subclass(cls, subclass_name):
        """Return the direct subclass named ``subclass_name``, or ``cls`` when there is none."""
        for subclass in cls.__subclasses__():
            if subclass.__name__ == subclass_name:
                return subclass
        return cls

    def __repr__(self):
        type_name = type(self).__name__
        args = [f'{key}={repr(value)}' for key, value in self.parameters.items()]
        # Show the operation type only when the class name does not already convey it.
        if type_name != self.type:
            args = [f'type="{self.type}"'] + args
        return f'{type_name}({", ".join(args)})'

    def __eq__(self, other):
        # Equality looks at class, type and parameters; ``attributes`` is not compared.
        return type(other) is type(self) and other.type == self.type and other.parameters == self.parameters
class PyTorchOperation(Operation):
    """
    Operation base class for the PyTorch framework.

    Type names starting with ``__torch__.`` or ``__mutated__.`` are treated as module
    classes; names starting with ``Function.`` are treated as functional operators.
    """

    # Prefixes that mark a type name as a module class path, checked in this order.
    _MODULE_PREFIXES = ('__torch__.', '__mutated__.')

    @staticmethod
    def _strip_module_prefix(type_name) -> Optional[str]:
        # Shared implementation of the prefix handling: the text after the first
        # matching prefix, or None when no prefix matches.
        for prefix in PyTorchOperation._MODULE_PREFIXES:
            if type_name.startswith(prefix):
                return type_name[len(prefix):]
        return None

    @classmethod
    def _find_subclass(cls, subclass_name):
        """
        Pick the direct subclass responsible for ``subclass_name``.

        Subclasses are matched through ``_ori_type_name`` first and ``_artificial_op_name``
        second; ``cls`` itself is returned when nothing matches.
        """
        if cls.to_class_name(subclass_name) is not None:
            subclass_name = 'ModuleOperator'
        if cls.is_functional(subclass_name):
            subclass_name = 'FunctionalOperator'
        for marker in ('_ori_type_name', '_artificial_op_name'):
            for candidate in cls.__subclasses__():
                if hasattr(candidate, marker) and subclass_name in getattr(candidate, marker):
                    return candidate
        return cls

    @classmethod
    def to_class_name(cls, type_name) -> Optional[str]:
        """Return ``type_name`` without its module prefix, or ``None`` if it has none."""
        return PyTorchOperation._strip_module_prefix(type_name)

    @classmethod
    def is_functional(cls, type_name) -> bool:
        """Tell whether ``type_name`` starts with ``Function.``."""
        return type_name.startswith('Function.')

    def _to_class_name(self) -> Optional[str]:
        return PyTorchOperation._strip_module_prefix(self.type)

    def get_import_pkg(self) -> Optional[str]:
        """Return the first dotted component after the module prefix, or ``None``."""
        stripped = PyTorchOperation._strip_module_prefix(self.type)
        if stripped is None:
            return None
        return stripped.split('.')[0]

    def to_init_code(self, field: str) -> Optional[str]:
        """
        Build the ``self.<field> = <Class>(<kwargs>)`` line for a module operation.

        Returns ``None`` when the operation type carries no module prefix.
        """
        class_name = self._to_class_name()
        if class_name is None:
            return None
        assert 'positional_args' not in self.parameters
        kw_params = ', '.join(f'{key}={repr(value)}' for key, value in self.parameters.items())
        return f'self.{field} = {class_name}({kw_params})'

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        """
        Parameters
        ----------
        field : str
            the name of member submodule
        output : str
            the output name (lvalue) of this line of code
        inputs : List[str]
            variables used in this line of code
        inputs_value : List[Any]
            some variables are actually constant, their real values are recorded in ```inputs_value```.
            if not constant, we simply put None at the corresponding index

        Returns
        -------
        str
            generated code line

        Raises
        ------
        RuntimeError
            Always; this base implementation generates no code.
        """
        if self.type == 'aten::slice':
            raise RuntimeError('not supposed to have aten::slice operation')
        raise RuntimeError(f'unsupported operation type: {self.type} ? {self._to_class_name()}')
class TensorFlowOperation(Operation):
    """Operation base class for the TensorFlow framework."""

    def _to_class_name(self) -> str:
        # The class name is the operation type placed under the ``K.layers`` namespace.
        namespace = 'K.layers.'
        return namespace + self.type
class Cell(PyTorchOperation):
    """
    An operation that refers to a subgraph.

    TODO: this is pytorch cell

    Example code:

    ```
        def __init__(...):
            ...
            self.cell = CustomCell(...)
            self.relu = K.layers.ReLU()
            ...

        def forward(...):
            ...
            x = self.cell(x)
            ...
    ```

    In above example, node `self.cell`'s operation is `Cell(cell_name='CustomCell')`.
    For comparison, `self.relu`'s operation is `Operation(type='ReLU')`.

    TODO: parameters of subgraph (see `Node` class)

    Attributes
    ----------
    type
        Always "_cell".
    cell_name
        Name of the cell this operation refers to.
    parameters
        The dict given at construction, or an empty dict when none (or an empty one) is given.
    attributes
        The dict given at construction, or an empty dict when none (or an empty one) is given.
    """

    def __init__(self, cell_name: str,
                 parameters: Dict[str, Any] = cast(Dict[str, Any], None),
                 attributes: Dict[str, Any] = cast(Dict[str, Any], None)):
        # The base-class initializer is not invoked; every field is assigned here directly.
        self.type = '_cell'
        self.cell_name = cell_name
        self.parameters = parameters if parameters else {}
        self.attributes = attributes if attributes else {}

    def _to_class_name(self):
        # TODO: ugly, think about how to refactor this part
        class_name = _convert_name(self.cell_name)
        return class_name

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        """Build ``<output> = self.<field>(<inputs...>)``; ``inputs_value`` is not used."""
        call_args = ", ".join(inputs)
        return f'{output} = self.{field}({call_args})'
class _IOPseudoOperation(Operation):
    """
    Pseudo operation attached to I/O nodes.

    With it, users no longer need to verify `Node.operation is not None`,
    which helps static type checking in particular. The object is falsy, and asking
    it to generate code raises ``ValueError``.
    """

    def __init__(self, type_name: str, io_names: List[str] = cast(List[str], None)):
        # Pseudo operation types are required to start with an underscore.
        assert type_name.startswith('_')
        super().__init__(type_name, {}, _internal=True)
        self.io_names = io_names

    def to_init_code(self, field: str) -> str:
        raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')

    def __bool__(self) -> bool:
        # Falsy on purpose, unlike a regular Operation.
        return False
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
__all__ = ['RetiariiAdvisor']
import logging
import os
from typing import Any, Callable, Optional, Dict, List, Tuple
import nni
from nni.common.serializer import PayloadTooLarge
from nni.common.version import version_dump
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.runtime.tuner_command_channel import CommandType
from nni.utils import MetricType
from .graph import MetricData
from .integration_api import register_advisor
_logger = logging.getLogger(__name__)
class RetiariiAdvisor(MsgDispatcherBase):
    """
    The class is to connect Retiarii components to NNI backend.
    It can be considered as a Python wrapper of NNI manager.

    It will function as the main thread when running a Retiarii experiment through NNI.
    Strategy will be launched as its thread, who will call APIs in execution engine. Execution
    engine will then find the advisor singleton and send payloads to advisor.

    When metrics are sent back, advisor will first receive the payloads, and then call the callback
    function (that is a member function in graph listener).

    The conversion the advisor provides is minimal. It is only a send/receive module, and execution engine
    needs to handle all the rest.

    Attributes
    ----------
    send_trial_callback
    request_trial_jobs_callback
    trial_end_callback
    intermediate_metric_callback
    final_metric_callback
    """

    # Names accepted by ``invoke_callback``; each one is served by the attribute ``<name>_callback``.
    _CALLBACK_NAMES = ('send_trial', 'request_trial_jobs', 'trial_end', 'intermediate_metric', 'final_metric')

    def __init__(self, url: str):
        super().__init__(url)

        register_advisor(self)  # register the current advisor as the "global only" advisor
        self.search_space = None

        self.send_trial_callback: Optional[Callable[[dict], None]] = None
        self.request_trial_jobs_callback: Optional[Callable[[int], None]] = None
        self.trial_end_callback: Optional[Callable[[int, bool], None]] = None
        self.intermediate_metric_callback: Optional[Callable[[int, MetricData], None]] = None
        self.final_metric_callback: Optional[Callable[[int, MetricData], None]] = None

        self.parameters_count = 0

        # Sometimes messages arrive first before the callbacks get registered.
        # Or in case that we allow engine to be absent during the experiment.
        # Here we need to store the messages and invoke them later.
        self.call_queue: List[Tuple[str, list]] = []

    def register_callbacks(self, callbacks: Dict[str, Callable[..., None]]):
        """
        Register callbacks for NNI backend.

        Callbacks missing from the dictionary are reset to ``None``. Messages queued so far
        are replayed against the new callbacks right away.

        Parameters
        ----------
        callbacks
            A dictionary of callbacks.
            The key is the name of the callback. The value is the callback function.
        """
        self.send_trial_callback = callbacks.get('send_trial')
        self.request_trial_jobs_callback = callbacks.get('request_trial_jobs')
        self.trial_end_callback = callbacks.get('trial_end')
        self.intermediate_metric_callback = callbacks.get('intermediate_metric')
        self.final_metric_callback = callbacks.get('final_metric')

        self.process_queued_callbacks()

    def process_queued_callbacks(self) -> None:
        """
        Process callbacks in queue.
        Consume the messages that haven't been handled previously.

        A message stays in the queue while its callback is not registered.
        """
        processed_idx = []
        for queue_idx, (call_name, call_args) in enumerate(self.call_queue):
            if call_name not in self._CALLBACK_NAMES:
                # Unknown names are never consumed; they simply stay queued.
                continue
            # Looked up per message so that the currently registered callback is used.
            callback = getattr(self, f'{call_name}_callback')
            if callback is not None:
                callback(*call_args)  # pylint: disable=not-callable
                processed_idx.append(queue_idx)

        # Remove processed messages (back to front, so the remaining indices stay valid)
        for idx in reversed(processed_idx):
            self.call_queue.pop(idx)

    def invoke_callback(self, name: str, *args: Any) -> None:
        """
        Invoke callback.

        The call is queued first and the queue is then processed, so the call is kept for
        later if the matching callback has not been registered yet.
        """
        self.call_queue.append((name, list(args)))
        self.process_queued_callbacks()

    def handle_initialize(self, data):
        """callback for initializing the advisor

        Parameters
        ----------
        data: dict
            search space
        """
        self.handle_update_search_space(data)
        self.send(CommandType.Initialized, '')

    def _validate_placement_constraint(self, placement_constraint):
        """
        Check the structure of a placement constraint.

        Raises
        ------
        ValueError
            If the constraint is ``None``, lacks ``type`` / ``gpus``, has an unknown type,
            or its ``gpus`` entries do not match the declared type.
        """
        if placement_constraint is None:
            raise ValueError('placement_constraint is None')
        if 'type' not in placement_constraint:
            raise ValueError('placement_constraint must have `type`')
        if 'gpus' not in placement_constraint:
            raise ValueError('placement_constraint must have `gpus`')
        if placement_constraint['type'] not in ['None', 'GPUNumber', 'Device']:
            raise ValueError('placement_constraint.type must be either `None`, `GPUNumber` or `Device`')
        if placement_constraint['type'] == 'None' and len(placement_constraint['gpus']) > 0:
            raise ValueError('placement_constraint.gpus must be an empty list when type == None')
        if placement_constraint['type'] == 'GPUNumber':
            if len(placement_constraint['gpus']) != 1:
                raise ValueError('placement_constraint.gpus currently only support one host when type == GPUNumber')
            for e in placement_constraint['gpus']:
                if not isinstance(e, int):
                    raise ValueError('placement_constraint.gpus must be a list of number when type == GPUNumber')
        if placement_constraint['type'] == 'Device':
            for e in placement_constraint['gpus']:
                if not isinstance(e, tuple):
                    raise ValueError('placement_constraint.gpus must be a list of tuple when type == Device')
                if not (len(e) == 2 and isinstance(e[0], str) and isinstance(e[1], int)):
                    raise ValueError('placement_constraint.gpus`s tuple must be (str, int)')

    def send_trial(self, parameters, placement_constraint=None):
        """
        Send parameters to NNI.

        Parameters
        ----------
        parameters : Any
            Any payload.
        placement_constraint : dict, optional
            Dict with ``type`` and ``gpus``; defaults to ``{'type': 'None', 'gpus': []}``.

        Returns
        -------
        int
            Parameter ID that is assigned to this parameter,
            which will be used for identification in future.

        Raises
        ------
        ValueError
            If the placement constraint is malformed, or the serialized trial exceeds the
            size limit (``PICKLE_SIZE_LIMIT`` environment variable, 64 KB by default).
        """
        if placement_constraint is None:
            placement_constraint = {
                'type': 'None',
                'gpus': []
            }
        # Validate before taking an ID, so that a rejected constraint does not consume one.
        self._validate_placement_constraint(placement_constraint)
        self.parameters_count += 1
        new_trial = {
            'parameter_id': self.parameters_count,
            'parameters': parameters,
            'parameter_source': 'algorithm',
            'placement_constraint': placement_constraint,
            'version_info': version_dump()
        }
        _logger.debug('New trial sent: %s', new_trial)

        size_limit = int(os.getenv('PICKLE_SIZE_LIMIT', 64 * 1024))
        try:
            send_payload = nni.dump(new_trial, pickle_size_limit=size_limit)
        except PayloadTooLarge as err:
            # Report the limit actually in effect and keep the original error as the cause.
            raise ValueError(
                f'Serialization failed when trying to dump the model because payload too large (larger than {size_limit} bytes). '
                'This is usually caused by pickling large objects (like datasets) by mistake. '
                'See the full error traceback for details and https://nni.readthedocs.io/en/stable/NAS/Serialization.html '
                'for how to resolve such issue. '
            ) from err

        # NOTE(review): an earlier comment here warned that large payloads may still be
        # blocked by the pipe / nni-manager even after passing the check above.
        self.send(CommandType.NewTrialJob, send_payload)

        self.invoke_callback('send_trial', parameters)
        return self.parameters_count

    def mark_experiment_as_ending(self):
        """Send ``NoMoreTrialJobs`` to the backend."""
        self.send(CommandType.NoMoreTrialJobs, '')

    def handle_request_trial_jobs(self, num_trials):
        _logger.debug('Request trial jobs: %s', num_trials)
        self.invoke_callback('request_trial_jobs', num_trials)

    def handle_update_search_space(self, data):
        _logger.debug('Received search space: %s', data)
        self.search_space = data

    def handle_trial_end(self, data):
        _logger.debug('Trial end: %s', data)
        # The parameter ID is read back from the serialized hyper-parameters of the trial.
        self.invoke_callback('trial_end', nni.load(data['hyper_params'])['parameter_id'], data['event'] == 'SUCCEEDED')

    def handle_report_metric_data(self, data):
        _logger.debug('Metric reported: %s', data)
        if data['type'] == MetricType.REQUEST_PARAMETER:
            raise ValueError('Request parameter not supported')
        elif data['type'] == MetricType.PERIODICAL:
            self.invoke_callback('intermediate_metric', data['parameter_id'], self._process_value(data['value']))
        elif data['type'] == MetricType.FINAL:
            self.invoke_callback('final_metric', data['parameter_id'], self._process_value(data['value']))

    @staticmethod
    def _process_value(value) -> Any:  # hopefully a float
        """Deserialize a metric; for a dict containing ``'default'``, return that entry."""
        value = nni.load(value)
        if isinstance(value, dict) and 'default' in value:
            return value['default']
        return value
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
__all__ = [
'get_advisor', 'register_advisor', 'send_trial', 'receive_trial_parameters', 'get_experiment_id',
'_advisor' # FIXME: hack to make it importable for tests
]
import warnings
from typing import NewType, Any
import nni
from nni.common.version import version_check
# NOTE: this placeholder type only exists to satisfy flake8. The real RetiariiAdvisor
# cannot be imported here because that would create a circular import.
RetiariiAdvisor = NewType('RetiariiAdvisor', Any)
_advisor = None  # type is RetiariiAdvisor


def get_advisor():
    """
    Return the advisor registered through ``register_advisor``.

    Raises
    ------
    AssertionError
        If no advisor has been registered yet.
    """
    # return type: RetiariiAdvisor
    global _advisor
    assert _advisor is not None
    return _advisor


def register_advisor(advisor):
    """
    Make ``advisor`` the process-wide advisor.

    A warning is issued when another advisor is already registered; the new advisor
    replaces it in either case.
    """
    # type of advisor: RetiariiAdvisor
    global _advisor
    if _advisor is not None:
        # Adjacent string literals are joined as-is, so each sentence ends with a space.
        warnings.warn('Advisor is already set. '
                      'You should avoid instantiating RetiariiExperiment twice in one process. '
                      'If you are running in a Jupyter notebook, please restart the kernel.')
    _advisor = advisor
def send_trial(parameters: dict, placement_constraint=None) -> int:
    """
    Submit a new trial through the registered advisor (runs on the tuner end).

    Returns the ID that uniquely identifies this trial.
    """
    advisor = get_advisor()
    trial_id = advisor.send_trial(parameters, placement_constraint)
    return trial_id
def receive_trial_parameters() -> dict:
    """
    Fetch the parameters of the current trial (runs on the trial end).

    Reload with our json loads because NNI didn't use Retiarii serializer to load the data.
    A version check runs when the raw parameters carry ``version_info``; otherwise a
    warning is issued. The parameters are returned in both cases.
    """
    params = nni.get_next_parameter()

    # version check, optional
    raw_params = nni.trial._params
    has_version_info = raw_params is not None and 'version_info' in raw_params
    if not has_version_info:
        warnings.warn('Version check failed because `version_info` is not found.')
    else:
        version_check(raw_params['version_info'])

    return params
def get_experiment_id() -> str:
    """Return the experiment ID reported by ``nni.get_experiment_id()``."""
    experiment_id = nni.get_experiment_id()
    return experiment_id
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
__all__ = ['DefaultListener']
from .graph import Model, ModelStatus, MetricData
from .engine import AbstractGraphListener
class DefaultListener(AbstractGraphListener):
    """Listener that records reported metrics and the final training status on the model itself."""

    def on_metric(self, model: Model, metric: MetricData) -> None:
        # The final metric overwrites whatever was stored before.
        model.metric = metric

    def on_intermediate_metric(self, model: Model, metric: MetricData) -> None:
        # Intermediate metrics accumulate in arrival order.
        history = model.intermediate_metrics
        history.append(metric)

    def on_training_end(self, model: Model, success: bool) -> None:
        model.status = ModelStatus.Trained if success else ModelStatus.Failed
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
__all__ = ['unpack_if_only_one', 'get_mutation_dict', 'mutation_dict_to_summary', 'get_mutation_summary']
from typing import Any, List
from .graph import Model
def unpack_if_only_one(ele: List[Any]):
    """Return the sole element of a one-element list; return any other list unchanged."""
    return ele[0] if len(ele) == 1 else ele
def get_mutation_dict(model: Model):
    """
    Map each mutator label in ``model.history`` to its samples.

    Single-element sample lists are unpacked; a label that occurs more than once keeps
    the samples of its last occurrence.
    """
    mutation = {}
    for record in model.history:
        label = record.mutator.label
        mutation[label] = unpack_if_only_one(record.samples)
    return mutation
def mutation_dict_to_summary(mutation: dict) -> dict:
    """
    Flatten a mutation dict into a single-level summary.

    Non-list values are kept under their label; a list value is spread over the keys
    ``<label>_0``, ``<label>_1``, ... in order.
    """
    summary = {}
    for label, samples in mutation.items():
        # FIXME: this check might be wrong
        if isinstance(samples, list):
            summary.update({f'{label}_{position}': item for position, item in enumerate(samples)})
        else:
            summary[label] = samples
    return summary
def get_mutation_summary(model: Model) -> dict:
    """Build the mutation dict of ``model`` and flatten it into a summary."""
    return mutation_dict_to_summary(get_mutation_dict(model))
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import random
from typing import Dict, Any, List, Optional, Union, Tuple, Callable, Iterable, cast
from nni.nas.execution.common import Model, receive_trial_parameters, get_mutation_dict
from .graph import BaseExecutionEngine
class BenchmarkGraphData:
    """
    Payload describing one benchmark query: the mutation dict, the benchmark name,
    an optional metric name and the database directory.
    """

    SUPPORTED_BENCHMARK_LIST = [
        'nasbench101',
        'nasbench201-cifar10',
        'nasbench201-cifar100',
        'nasbench201-imagenet16',
        'nds-cifar10',
        'nds-imagenet',
        'nlp'
    ]

    def __init__(self, mutation: Dict[str, Any], benchmark: str,
                 metric_name: Optional[str] = None,
                 db_path: Optional[str] = None) -> None:
        self.mutation = mutation        # mutation dict. e.g., {'layer1': 'conv3x3', ...}
        self.benchmark = benchmark      # e.g., nasbench101, nasbench201, ...
        self.metric_name = metric_name  # kept so that it survives a dump/load round trip
        self.db_path = db_path          # path to directory of database

    def dump(self) -> dict:
        """
        Convert to a plain dict that ``load`` accepts.

        When ``db_path`` is unset, the default ``DATABASE_DIR`` is filled in, because the
        database path needs to be passed from manager to worker.
        """
        db_path = self.db_path
        if not db_path:
            # Imported lazily: only needed when falling back to the default directory.
            from nni.nas.benchmarks.constants import DATABASE_DIR
            db_path = DATABASE_DIR
        return {
            'mutation': self.mutation,
            'benchmark': self.benchmark,
            'metric_name': self.metric_name,
            'db_path': db_path
        }

    @staticmethod
    def load(data) -> 'BenchmarkGraphData':
        """Rebuild the object from a dict produced by ``dump``."""
        # ``metric_name`` is read leniently: payloads dumped before the key existed lack it,
        # and indexing it directly made every load of such a payload fail with KeyError.
        return BenchmarkGraphData(data['mutation'], data['benchmark'], data.get('metric_name'), data['db_path'])

    def __repr__(self) -> str:
        return f"BenchmarkGraphData({self.mutation}, {self.benchmark}, {self.db_path})"
class BenchmarkExecutionEngine(BaseExecutionEngine):
    """
    Execution engine that does not actually run any trial, but query the database for results.

    The database query is done on the trial end to make sure intermediate metrics are available.
    It will also support an accelerated mode that returns metric immediately without even running into NNI manager
    (not implemented yet).
    """

    def __init__(self, benchmark: Union[str, Callable[[BenchmarkGraphData], Tuple[float, List[float]]]], acceleration: bool = False):
        super().__init__()
        supported = BenchmarkGraphData.SUPPORTED_BENCHMARK_LIST
        assert benchmark in supported, \
            f'{benchmark} is not one of the supported benchmarks: {supported}'
        self.benchmark = benchmark
        self.acceleration = acceleration

    def pack_model_data(self, model: Model) -> Any:
        """
        Convert a submitted model into the payload handed to the trial end:
        its mutation dict plus the benchmark configured on this engine.
        """
        return BenchmarkGraphData(get_mutation_dict(model), self.benchmark)

    @classmethod
    def trial_execute_graph(cls) -> None:
        """
        Trial-side entry point: load the payload, query the benchmark, then report every
        intermediate result followed by the final result.
        """
        graph_data = BenchmarkGraphData.load(receive_trial_parameters())
        assert graph_data.db_path is not None, f'Invalid graph data because db_path is None: {graph_data}'
        # The database directory is handed over through this environment variable.
        os.environ['NASBENCHMARK_DIR'] = graph_data.db_path
        final_metric, intermediate_metrics = cls.query_in_benchmark(graph_data)

        import nni
        for metric in intermediate_metrics:
            nni.report_intermediate_result(metric)
        nni.report_final_result(final_metric)

    @staticmethod
    def query_in_benchmark(graph_data: BenchmarkGraphData) -> Tuple[float, List[float]]:
        """
        Look up ``graph_data`` in its benchmark.

        A non-string ``benchmark`` is called with ``graph_data`` and its result returned as is.
        Built-in benchmark names are queried with default settings and the ``valid_acc`` metric.

        Raises
        ------
        ValueError
            If the benchmark name is not recognized, or (nasbench101) no architecture dict
            is found among the mutation values.
        """
        benchmark = graph_data.benchmark
        if not isinstance(benchmark, str):
            return benchmark(graph_data)

        # built-in benchmarks with default query setting
        if benchmark == 'nasbench101':
            from nni.nas.benchmarks.nasbench101 import query_nb101_trial_stats
            # The architecture is the last dict-typed value in the mutation dict.
            arch = None
            for candidate in graph_data.mutation.values():
                if isinstance(candidate, dict):
                    arch = candidate
            if arch is None:
                raise ValueError(f'Cannot identify architecture from mutation dict: {graph_data.mutation}')
            stats = query_nb101_trial_stats(arch, 108, include_intermediates=True)
            return _convert_to_final_and_intermediates(stats, 'valid_acc')

        if benchmark.startswith('nasbench201'):
            from nni.nas.benchmarks.nasbench201 import query_nb201_trial_stats
            # The dataset is the part after the last dash, e.g. "nasbench201-cifar10".
            dataset = benchmark.split('-')[-1]
            stats = query_nb201_trial_stats(_flatten_architecture(graph_data.mutation), 200, dataset,
                                            include_intermediates=True)
            return _convert_to_final_and_intermediates(stats, 'valid_acc')

        if benchmark.startswith('nds'):
            # FIXME: not tested yet
            from nni.nas.benchmarks.nds import query_nds_trial_stats
            dataset = benchmark.split('-')[-1]
            stats = query_nds_trial_stats(None, None, None, None, _flatten_architecture(graph_data.mutation),
                                          dataset, include_intermediates=True)
            return _convert_to_final_and_intermediates(stats, 'valid_acc')

        if benchmark.startswith('nlp'):
            # FIXME: not tested yet
            from nni.nas.benchmarks.nlp import query_nlp_trial_stats
            # TODO: I'm not sure of the available datasets in this benchmark. and the docs are missing.
            stats = query_nlp_trial_stats(_flatten_architecture(graph_data.mutation), 'ptb',
                                          include_intermediates=True)
            return _convert_to_final_and_intermediates(stats, 'valid_acc')

        raise ValueError(f'{graph_data.benchmark} is not a supported benchmark.')
def _flatten_architecture(mutation: Dict[str, Any], benchmark: Optional[str] = None):
# STRONG ASSUMPTION HERE!
# This assumes that the benchmarked search space is a one-level search space.
# This means that it is either ONE cell or ONE network.
# Two cell search space like NDS is not supported yet for now.
# Some benchmark even needs special handling to pop out invalid keys. I don't think this is a good design.
# support double underscore to be compatible with naming convention in base engine
ret = {k.split('/')[-1].split('__')[-1]: v for k, v in mutation.items()}
if benchmark == 'nasbench101':
ret = {k: v for k, v in ret.items() if k.startswith('op') or k.startswith('input')}
ret = {k: v if k.startswith('op') or isinstance(v, list) else [v] for k, v in ret.items()}
return ret
def _convert_to_final_and_intermediates(benchmark_result: Iterable[Any], metric_name: str) -> Tuple[float, List[float]]:
# convert benchmark results from database to
# final result (float) and intermediate results (list of floats)
benchmark_result = list(benchmark_result)
assert len(benchmark_result) > 0, 'Invalid query. Results from benchmark is empty.'
if len(benchmark_result) > 1:
benchmark_result = random.choice(benchmark_result)
else:
benchmark_result = benchmark_result[0]
benchmark_result = cast(dict, benchmark_result)
return benchmark_result[metric_name], [i[metric_name] for i in benchmark_result['intermediates'] if i[metric_name] is not None]
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .trainer import PdartsTrainer
from .engine import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
__all__ = ['CGOExecutionEngine', 'TrialSubmission']
import logging
import os
import random
import string
import time
import threading
from typing import Iterable, List, Dict, Tuple, cast
from dataclasses import dataclass
from nni.common.device import GPUDevice, Device
from nni.experiment.config.training_services import RemoteConfig
from nni.nas import utils
from nni.nas.execution.common import (
AbstractExecutionEngine, AbstractGraphListener, WorkerInfo,
Model, ModelStatus, MetricData, Node,
RetiariiAdvisor, send_trial, receive_trial_parameters, get_advisor,
)
from nni.nas.execution.pytorch import codegen
from nni.nas.evaluator.pytorch.lightning import Lightning
from nni.nas.evaluator.pytorch.cgo.evaluator import _MultiModelSupervisedLearningModule
from nni.nas.execution.pytorch.graph import BaseGraphData
from .logical_optimizer.logical_plan import LogicalPlan, AbstractLogicalNode
from .logical_optimizer.opt_dedup_input import DedupInputOptimizer
_logger = logging.getLogger(__name__)
def _noop(*args, **kwargs):
pass
@dataclass
class TrialSubmission:
    """
    Plain record grouping a model with a node-to-device placement and a list of models.

    NOTE(review): the field meanings below are inferred from names and types only —
    confirm against the code that constructs this record.
    """
    # presumably the model that is submitted as one trial
    model: Model
    # device assigned to each graph node
    placement: Dict[Node, Device]
    # presumably the original models that were grouped into ``model``
    grouped_models: List[Model]
class CGOExecutionEngine(AbstractExecutionEngine):
"""
The execution engine with Cross-Graph Optimization (CGO).
Only models using PyTorch Lightning and MultiModelSupervisedLearningModule as the evaluator can be optimized.
Otherwise, a model will be submitted independently without any cross-graph optimization.
Parameters
----------
training_service
The remote training service config.
max_concurrency
The maximum number of trials to run concurrently.
batch_waiting_time
Seconds to wait for each batch of trial submission.
The trials within one batch could apply cross-graph optimization.
rest_port
The port of the experiment's rest server
rest_url_prefix
The url prefix of the experiment's rest entry
"""
def __init__(self, training_service: RemoteConfig,
             max_concurrency: int = None,
             batch_waiting_time: int = 60,
             rest_port: int | None = None,
             rest_url_prefix: str | None = None
             ) -> None:
    """
    Set up the bookkeeping structures, register the advisor callbacks and start the
    consumer thread that submits queued models.
    """
    self.port = rest_port
    self.url_prefix = rest_url_prefix

    self._listeners: List[AbstractGraphListener] = []
    self._running_models: Dict[int, Model] = {}
    self.logical_plan_counter = 0
    self.available_devices: List[Device] = []
    self.max_concurrency: int = max_concurrency

    # Every device found in the training service config starts out as available;
    # ``all_devices`` keeps an independent copy of that full list.
    self.available_devices.extend(self._construct_devices(training_service))
    self.all_devices = self.available_devices.copy()

    self._batch_waiting_time = batch_waiting_time  # seconds to wait for all models in a batch to do cross-graph optimization
    self._optimizers = [DedupInputOptimizer()]

    # Bookkeeping between original models, assembled multi-models and trials.
    self._original_models = {}
    self._original_model_to_multi_model = {}
    self._trial_to_original_models = {}
    self._trial_used_devices: Dict[int, List[Device]] = {}

    self._history: List[Model] = []

    # Entries of _queuing_models are (submission time, model) pairs, see submit_models.
    self._queuing_models: List[Model] = []
    self._models_to_retry: List[Model] = []
    self._queue_lock = threading.Lock()

    # register advisor callbacks
    advisor: RetiariiAdvisor = get_advisor()
    callbacks = {
        'send_trial': _noop,
        'request_trial_jobs': _noop,
        'trial_end': self._trial_end_callback,
        'intermediate_metric': self._intermediate_metric_callback,
        'final_metric': self._final_metric_callback
    }
    advisor.register_callbacks(callbacks)

    # The consumer thread is started last, once all state it reads exists.
    self._stopped = False
    self._consumer_thread = threading.Thread(target=self._consume_models)
    self._consumer_thread.start()
def _construct_devices(self, training_service):
devices = []
if hasattr(training_service, 'machine_list'):
for machine in cast(RemoteConfig, training_service).machine_list:
assert machine.gpu_indices is not None, \
'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
assert isinstance(machine.gpu_indices, list), 'gpu_indices must be a list'
for gpu_idx in machine.gpu_indices:
devices.append(GPUDevice(machine.host, gpu_idx))
return devices
def join(self):
    """Ask the consumer thread to stop and block until it has finished."""
    # The flag is raised first; the consumer loop checks it on every iteration.
    self._stopped = True
    consumer = self._consumer_thread
    consumer.join()
def add_optimizer(self, opt):
    """Append ``opt`` to the optimizers applied to each logical plan."""
    optimizers = self._optimizers
    optimizers.append(opt)
def submit_models(self, *models: Model) -> None:
    """
    Queue models for submission.

    Each model is stored together with the current time; the consumer thread picks
    the queued entries up later.
    """
    curr_time = time.time()
    _logger.info('%d models are submitted', len(models))
    # ``with`` guarantees that the lock is released even if the body raises.
    with self._queue_lock:
        self._queuing_models.extend((curr_time, model) for model in models)
def _submit_retry_models(self, models: List[Model]) -> None:
    """Queue ``models`` to be retried by the consumer thread."""
    _logger.info('%d models are retried', len(models))
    # ``with`` guarantees that the lock is released even if the body raises.
    with self._queue_lock:
        self._models_to_retry.extend(models)
def _consume_models(self):
    """
    Body of the consumer thread.

    Once per second, until ``_stopped`` is set, it submits models from
    ``_models_to_retry`` and then from ``_queuing_models``.
    """
    # a thread to monitor self._models_to_retry and self._queuing_models to consume them in batch
    while not self._stopped:
        if len(self._models_to_retry) > 0:
            # ``with`` releases the lock even when a submission raises, so producers
            # calling submit_models are not blocked forever afterwards.
            with self._queue_lock:
                # retrying jobs should be first scheduled.
                # Models that cannot be submitted now are collected explicitly. Trimming
                # the head of the list per submission would drop the wrong model
                # whenever the submitted models are not a prefix of the list.
                still_waiting = []
                for m in self._models_to_retry:
                    if len(self.available_devices) > 0:
                        self._submit_models_in_batch(m)  # submit the single model to avoid cross-graph optimization.
                    else:
                        still_waiting.append(m)
                self._models_to_retry = still_waiting

        if len(self._queuing_models) > 0:
            with self._queue_lock:
                curr_time = time.time()

                # At most one model per available device, further capped by max_concurrency.
                num_models_to_submit = len(self.available_devices)
                if self.max_concurrency:
                    num_models_to_submit = min(num_models_to_submit, self.max_concurrency)

                # NOTE(review): the waiting time only caps the count here; the slice below
                # already stops at the queue length, so models are submitted without
                # waiting whenever a device is free — confirm this is intended.
                if curr_time - self._queuing_models[0][0] > self._batch_waiting_time:
                    num_models_to_submit = min(num_models_to_submit, len(self._queuing_models))
                if num_models_to_submit > 0:
                    batch = [entry[1] for entry in self._queuing_models[:num_models_to_submit]]
                    self._submit_models_in_batch(*batch)
                    self._queuing_models = self._queuing_models[num_models_to_submit:]
        time.sleep(1)
def _extract_placement_constaint(self, placement_mapping: Dict[Node, Device]):
unique_gpus = sorted(list(set([e for e in placement_mapping.values() if isinstance(e, GPUDevice)])))
placement_constraint = None
if len(unique_gpus) > 0:
placement_constraint = {}
placement_constraint['type'] = 'Device'
placement_constraint['gpus'] = [(e.node_id, e.gpu_id) for e in unique_gpus]
return placement_constraint
def _submit_models_in_batch(self, *models: List[Model]) -> None:
    """
    Build a logical plan from ``models``, run every optimizer on it, assemble the
    physical models and send one trial per assembled model.

    For each trial the used GPU devices are removed from ``available_devices`` and the
    mappings between the trial, the assembled model and the original models are recorded.
    """
    _logger.info('%d models are submitted in batch', len(models))
    _logger.debug('model id: %s', str([m.model_id for m in models]))

    logical = self._build_logical(models)
    for optimizer in self._optimizers:
        optimizer.convert(logical)

    for model, placement, grouped_models in self._assemble(logical):
        script = codegen.model_to_pytorch_script(model, placement=placement)
        data = BaseGraphData(script, model.evaluator, {})
        placement_constraint = self._extract_placement_constaint(placement)
        trial_id = send_trial(data.dump(), placement_constraint=placement_constraint)

        # unique non-cpu devices used by the trial
        used_devices = list({dev for dev in placement.values() if isinstance(dev, GPUDevice)})
        self._trial_used_devices[trial_id] = used_devices

        # currently, it is impossible for search strategy to submit models more than the number of available devices
        for used_device in used_devices:
            self.available_devices.remove(used_device)  # used_device must be in self.available_devices

        self._running_models[trial_id] = model

        original_ids = []
        self._trial_to_original_models[trial_id] = original_ids
        for original in grouped_models:
            self._original_models[original.model_id] = original
            self._original_model_to_multi_model[original.model_id] = model
            original_ids.append(original.model_id)
            self._history.append(original)
def list_models(self) -> Iterable[Model]:
    """Return the history list that ``_submit_models_in_batch`` appends original models to."""
    history = self._history
    return history
def _assemble(self, logical_plan: LogicalPlan) -> List[Tuple[Model, Dict[Node, Device], List[Model]]]:
    """
    Return the assembled models as a list of tuple.
    Each tuple contains the assembled model, the device placement of graph nodes, and the original models.
    """
    # try to use the available_devices first so that it can be launched as early as possible
    # if free devices are not enough to assemble all models in one trial, try all devices
    has_free_devices = len(self.available_devices) > 0
    if has_free_devices:
        grouped_models: List[Dict[Model, Device]] = AssemblePolicy().group(logical_plan, self.available_devices)
    if not has_free_devices or len(grouped_models) > 1:
        grouped_models: List[Dict[Model, Device]] = AssemblePolicy().group(logical_plan, self.all_devices)

    assembled = []
    for multi_model in grouped_models:
        phy_model, node_placement = logical_plan.assemble(multi_model)
        evaluator = phy_model.evaluator
        assert isinstance(evaluator, Lightning), \
            "cross-graph optimization only supports pytorch lighting as evaluator"
        assert isinstance(evaluator.module, _MultiModelSupervisedLearningModule), \
            "cross-graph optimization only support MultiModelSupervisedLearningModule"

        # replace the module with a new instance whose n_models is set
        # n_models must be set in __init__, otherwise it cannot be captured by serialize_cls
        init_kwargs = evaluator.module.dump_kwargs().copy()
        # MultiModelSupervisedLearningModule hides n_models of _MultiModelSupervisedLearningModule from users
        init_kwargs['n_models'] = len(multi_model)
        evaluator.module = _MultiModelSupervisedLearningModule(**init_kwargs)

        assembled.append((phy_model, node_placement, multi_model.keys()))
    return assembled
def _build_logical(self, models: List[Model]) -> LogicalPlan:
    """
    Create a ``LogicalPlan`` identified by the current ``logical_plan_counter``,
    add every model to it, then advance the counter.
    """
    plan = LogicalPlan(plan_id=self.logical_plan_counter)
    for submitted in models:
        plan.add_model(submitted)
    self.logical_plan_counter = self.logical_plan_counter + 1
    return plan
def register_graph_listener(self, listener: AbstractGraphListener) -> None:
    """Append ``listener`` to the listeners notified by the trial callbacks."""
    listeners = self._listeners
    listeners.append(listener)
# def _send_trial_callback(self, paramater: dict) -> None:
# if len(self.available_devices) == 0:
# _logger.warning('There is no available devices, but trial is submitted.')
# _logger.debug('Resource used. Remaining: %d', len(self.available_devices))
# def _request_trial_jobs_callback(self, num_trials: int) -> None:
# self.resources += num_trials
# _logger.info('on_resource_available: %d', self.resources)
def _trial_end_callback(self, trial_id: int, success: bool) -> None:
    """
    Handle the end of a trial.

    Marks the physical model and every original model mapped to it as
    ``Trained`` or ``Failed``, notifies the listeners, resubmits the members of
    a failed multi-model for retry, and returns the trial's devices to
    ``available_devices``.
    """
    final_status = ModelStatus.Trained if success else ModelStatus.Failed
    phy_model = self._running_models[trial_id]
    phy_model.status = final_status

    models_to_retry = []
    for model_id, multi_model in self._original_model_to_multi_model.items():
        if not multi_model == phy_model:
            continue
        original_model = self._original_models[model_id]
        original_model.status = final_status
        # the failed models in a multi-model will be retried one by one w/o CGO
        if not success and len(self._trial_to_original_models[trial_id]) > 1:
            models_to_retry.append(original_model)
        for listener in self._listeners:
            listener.on_training_end(original_model, success)

    if models_to_retry:
        self._submit_retry_models(models_to_retry)

    # give the devices back; de-duplicate and keep the list sorted
    self.available_devices.extend(self._trial_used_devices[trial_id])
    self.available_devices = sorted(set(self.available_devices))
    del self._running_models[trial_id]
def _intermediate_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
    """
    Dispatch the intermediate metrics of a trial to its original models.

    ``metrics[idx]`` is attributed to the ``idx``-th original model of the trial;
    each value is appended to that model's ``intermediate_metrics`` and reported
    to every listener.
    """
    merged_metrics = {
        self._trial_to_original_models[trial_id][idx]: metrics[idx]
        for idx, _ in enumerate(metrics)
    }
    for model_id, metric in merged_metrics.items():
        self._original_models[model_id].intermediate_metrics.append(metric)
        for listener in self._listeners:
            listener.on_intermediate_metric(self._original_models[model_id], metric)
def _final_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
    """
    Dispatch the final metrics of a trial.

    A single ``float`` is reported for the running model to the first listener only.
    Otherwise ``metrics[idx]`` is stored as the ``metric`` of the ``idx``-th original
    model of the trial and reported to every listener.
    """
    _logger.debug(metrics)
    if isinstance(metrics, float):
        self._listeners[0].on_metric(self._running_models[trial_id], metrics)
        return

    merged_metrics = {
        self._trial_to_original_models[trial_id][idx]: metrics[idx]
        for idx, _ in enumerate(metrics)
    }
    for model_id, metric in merged_metrics.items():
        self._original_models[model_id].metric = metric
        for listener in self._listeners:
            listener.on_metric(self._original_models[model_id], metric)
def query_available_resource(self) -> int:
    """
    Return how many more models can be accepted right now.

    Returns
    -------
    int
        Number of available devices minus the models already queued and the
        models waiting for retry. The value may be negative.
    """
    # the _queuing_models need to use available_devices first
    # ``with`` guarantees the lock is released even if one of the reads raises
    with self._queue_lock:
        return len(self.available_devices) - len(self._queuing_models) - len(self._models_to_retry)
def budget_exhausted(self) -> bool:
    """Return the ``stopping`` flag of the advisor obtained from ``get_advisor()``."""
    return get_advisor().stopping
@classmethod
def trial_execute_graph(cls) -> None:
    """
    Initialize the model, hand it over to trainer.

    The received model script is written to ``_generated_model/<random>.py``,
    imported to obtain ``_model``, fitted by the received evaluator, and the
    file is removed afterwards.
    """
    graph_data = BaseGraphData.load(receive_trial_parameters())
    _logger.info('CGO_ENGINE trial parameters received')

    alphabet = string.ascii_uppercase + string.digits
    random_str = ''.join(random.choice(alphabet) for _ in range(6))
    file_name = f'_generated_model/{random_str}.py'
    os.makedirs(os.path.dirname(file_name), exist_ok=True)
    with open(file_name, 'w') as script_file:
        script_file.write(graph_data.model_script)

    evaluator = graph_data.evaluator
    model_cls = utils.import_(f'_generated_model.{random_str}._model')
    evaluator.fit(model_cls())
    os.remove(file_name)
class AssemblePolicy:
    """Static helpers that decide which models of a logical plan are assembled together."""

    @staticmethod
    def _is_related_node(model: Model, node: Node):
        """
        Tell whether ``node`` belongs to ``model``.

        A logical node is related if ``model`` is in its ``related_models``;
        any other node is related if the model of its graph equals ``model``.
        """
        if isinstance(node, AbstractLogicalNode):
            return model in node.related_models
        return bool(model == node.graph.model)

    @staticmethod
    def _check_graph_connectivity(model: Model,
                                  group_model: Dict[Model, Device],
                                  logical_plan: LogicalPlan) -> bool:
        """
        Return True if some edge of the logical graph touches both ``model``
        and at least one model already in ``group_model``.
        """
        touches = AssemblePolicy._is_related_node
        for edge in logical_plan.logical_graph.edges:
            if not (touches(model, edge.head) or touches(model, edge.tail)):
                continue
            for grouped_model in group_model:
                if touches(grouped_model, edge.head) or touches(grouped_model, edge.tail):
                    return True
        return False

    @staticmethod
    def _check_evaluator(new_model: Model, group_model: Dict[Model, Device]) -> bool:
        """
        Return True if ``new_model`` uses a ``Lightning`` evaluator wrapping a
        ``_MultiModelSupervisedLearningModule`` that equals the evaluator of
        every model in ``group_model``.
        """
        is_multi_model_evaluator = (
            isinstance(new_model.evaluator, Lightning)
            and isinstance(new_model.evaluator.module, _MultiModelSupervisedLearningModule)
        )
        if not is_multi_model_evaluator:
            return False
        return all(m.evaluator == new_model.evaluator for m in group_model)

    @staticmethod
    def group(logical_plan, available_devices):
        """
        Split the models of ``logical_plan`` into groups, each a dict mapping
        a model to the device it is placed on.
        """
        # TODO: Packing multiple model in one GPU
        # Currently, we only support one model per GPU
        assert(len(available_devices) > 0)  # There should be at least 1 device, set in CGO_DEVICES
        all_grouped_models = []
        current_group = {}
        last_idx = len(logical_plan.models) - 1
        for idx, m in enumerate(logical_plan.models):
            # models in one group should
            # (1) not use more GPUs than available_devices
            # (2) be connected in the logical plan (independent models should be assembled in multiple groups)
            # (3) use same MultiModelSupervisedLearningModule
            if current_group and not (
                    AssemblePolicy._check_graph_connectivity(m, current_group, logical_plan)
                    and AssemblePolicy._check_evaluator(m, current_group)):
                all_grouped_models.append(current_group)
                current_group = {}
            current_group[m] = available_devices[idx % len(available_devices)]
            if len(current_group) == len(available_devices) or idx == last_idx:
                all_grouped_models.append(current_group)
                current_group = {}
        return all_grouped_models
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from abc import ABC
from .logical_plan import LogicalPlan
class AbstractOptimizer(ABC):
    """Base class of optimizers that transform a ``LogicalPlan``."""

    def __init__(self) -> None:
        """Nothing to set up in the base class."""

    def convert(self, logical_plan: LogicalPlan) -> None:
        """Transform ``logical_plan``; the base implementation always raises."""
        raise NotImplementedError
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
from typing import Dict, Tuple, Any
from nni.retiarii.utils import uid
from nni.common.device import Device, CPUDevice
from nni.nas.execution.common.graph import Cell, Edge, Graph, Model, Node
from nni.nas.execution.common.graph_op import Operation, _IOPseudoOperation
class AbstractLogicalNode(Node):
    """
    Base class of nodes living in a logical graph.

    ``related_models`` starts empty; subclasses list the models the node belongs to.
    """

    def __init__(self, graph, node_id, name, operation, _internal=False):
        super().__init__(graph, node_id, name, operation, _internal=_internal)
        self.related_models = []

    def assemble(self, multi_model_placement: Dict[Model, Device]) -> Tuple[Node, Device]:
        """
        Given a set of models to be formed in a physical model and their device placement,
        this function replaces the logical node with an executable physical node for the physical model.

        Parameters
        ----------
        multi_model_placement : dict
            a dict of models and device placement.
            These models will be assembled into the same physical model to run.

        Returns
        -------
        node : Node
            the physical node to replace the logical node in the physical model
        placement : Device
            the device placement of the returned physical node
        """
        raise NotImplementedError

    def _fork_to(self, graph: Graph):
        """Register a copy of this node in ``graph``; must be provided by subclasses."""
        raise NotImplementedError
class LogicalGraph(Graph):
    """Graph whose name is prefixed with ``logical_`` and that may hold logical nodes."""

    def __init__(self, model: Model, graph_id: int, name: str = None, _internal: bool = False):
        super().__init__(model, graph_id, name='logical_' + name, _internal=_internal)

    def _dump(self) -> Any:
        """
        Dump the graph as a dict with ``inputs``, ``outputs``, ``nodes`` and ``edges``.

        Node keys are ``<model_id>_<node name>``; for an ``OriginNode`` the model id
        is the one of its original graph. Edge endpoints of other node kinds are
        dumped by plain node name.
        """
        def owner_model_id(node):
            graph = node.original_graph if isinstance(node, OriginNode) else node.graph
            return graph.model.model_id

        def endpoint(node):
            if isinstance(node, OriginNode):
                return f'{node.original_graph.model.model_id}_{node.name}'
            return node.name

        nodes_dump = {
            f"{owner_model_id(node)}_{node.name}": node._dump()
            for node in self.hidden_nodes
        }
        edges_dump = [(endpoint(edge.head), endpoint(edge.tail)) for edge in self.edges]
        return {
            'inputs': self.input_node.operation.io_names,
            'outputs': self.output_node.operation.io_names,
            'nodes': nodes_dump,
            'edges': edges_dump
        }

    def _fork_to(self, model: Model) -> Graph:
        """
        Copy this graph into ``model`` as a plain ``Graph``.

        Logical nodes fork themselves; other hidden nodes are re-created.
        Edges are rebuilt by matching nodes through their ``repr``.
        """
        new_graph = Graph(model, self.id, self.name, _internal=True)._register()

        for node in self.hidden_nodes:
            if isinstance(node, AbstractLogicalNode):
                node._fork_to(new_graph)
                continue
            Node(new_graph, node.id, node.name, node.operation, _internal=True)._register()

        # nodes of the new graph are looked up by their textual representation
        repr_to_new_node = {repr(node): node for node in new_graph.nodes}
        for edge in self.edges:
            forked_head = repr_to_new_node[repr(edge.head)]
            forked_tail = repr_to_new_node[repr(edge.tail)]
            Edge((forked_head, edge.head_slot), (forked_tail, edge.tail_slot), _internal=True)._register()
        return new_graph
class OriginNode(AbstractLogicalNode):
    """
    This is logical node representing the original node without any modification.
    In assemble, just return the original node along with the physical placement given by multi_model_placement.
    """

    def __init__(self, logical_graph: LogicalGraph,
                 original_graph: Graph, original_node: Node,
                 name: str, operation, _internal=False):
        # the logical node reuses the id of the node it wraps
        super().__init__(logical_graph, original_node.id, name, operation)
        self.original_graph = original_graph
        self.original_node = original_node

    def assemble(self, multi_model_placement: Dict[Model, Device]) -> Tuple[Node, Device]:
        """
        Return a new node named ``M_<model_id>_<original name>`` in the original
        node's graph, together with the device assigned to the owning model.
        """
        source = self.original_node
        owner = source.graph.model
        physical_node = Node(source.graph, source.id,
                             f"M_{owner.model_id}_" + source.name,
                             source.operation)
        return physical_node, multi_model_placement[source.graph.model]

    def __repr__(self):
        return (
            f'OriginNode(id={self.id}, name={self.name}, '
            f'operation={self.operation}, origin_model_id={self.original_graph.model.model_id})'
        )

    def _fork_to(self, graph: Graph):
        """Register an equivalent ``OriginNode`` in ``graph``."""
        forked = OriginNode(graph, self.original_graph, self.original_node,
                            self.name, self.operation)
        forked._register()
class LogicalPlan:
    """
    Several models merged into one logical graph, from which physical
    multi-models are assembled for a chosen subset of the models.
    """

    def __init__(self, plan_id=0) -> None:
        self.lp_model = Model(_internal=True)
        self.id = plan_id
        self.logical_graph = LogicalGraph(
            self.lp_model, self.id, name=f'{self.id}', _internal=True)._register()
        self.lp_model._root_graph_name = self.logical_graph.name
        self.models = []

    def add_model(self, model: Model):
        """Record ``model`` and merge its root graph into the logical graph."""
        self.models.append(model)
        # Only optimize the root graph.
        self._merge_graph(model.root_graph)

    def _merge_graph(self, from_graph):
        """Wrap every node of ``from_graph`` in an ``OriginNode`` and replicate its edges."""
        to_graph = self.logical_graph
        id_to_new_node = {}  # old node ID -> new node object

        for old_node in from_graph.nodes:
            new_node = OriginNode(to_graph, old_node.graph,
                                  old_node, old_node.name,
                                  old_node.operation, _internal=True)._register()
            id_to_new_node[old_node.id] = new_node

        for edge in from_graph.edges:
            new_head = id_to_new_node[edge.head.id]
            new_tail = id_to_new_node[edge.tail.id]
            Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()

    def assemble(self, multi_model_placement: Dict[Model, Device]) \
            -> Tuple[Model, Dict[Node, Device]]:
        """
        Given a set of models to be formed in a physical model and their device placement,
        this function replaces all the logical node in this LogicalPlan with executable physical nodes
        for the physical model.

        Parameters
        ----------
        multi_model_placement : dict
            a dict of models and device placement.
            These models will be assembled into the same physical model to run.

        Returns
        -------
        phy_model : Model
            the physical model formed by models in `multi_model_placement`
            all logical node are replaced by physical nodes
        node_placements : dict
            the device placement of the nodes in `phy_model`

        Raises
        ------
        ValueError
            if two connected nodes are placed on devices with different ``node_id``
        """
        phy_model = Model(_internal=True)
        phy_graph = self.lp_model.root_graph._fork_to(phy_model)
        phy_graph._rename_graph(phy_graph.name, "_model")

        # merge sub-graphs
        for model in multi_model_placement:
            if phy_model.evaluator is None and model.evaluator is not None:
                phy_model.evaluator = model.evaluator
            for graph_name in model.graphs:
                if graph_name != model._root_graph_name:
                    new_graph = model.graphs[graph_name]._fork_to(
                        phy_model, name_prefix=f'M_{model.model_id}_')

                    # prefix of M_ of hidden_nodes name in non-root graphs is added here
                    for new_node in new_graph.hidden_nodes:
                        if isinstance(new_node.operation, Cell):
                            old_cell_name = new_node.operation.cell_name
                            new_node.operation = copy.deepcopy(new_node.operation)
                            new_node.operation.cell_name = f'M_{model.model_id}_{old_cell_name}'

        assert(phy_model.evaluator is not None)

        # When replace logical nodes, merge the training configs when
        # input/output nodes are replaced.
        evaluator_slot = {}  # Model ID -> Slot ID
        input_slot_mapping = {}
        output_slot_mapping = {}
        # Replace all logical nodes to executable physical nodes
        hidden_nodes = phy_graph.hidden_nodes.copy()
        node_placements = {}

        added_models = []

        for node in hidden_nodes:
            if isinstance(node, OriginNode):
                model_id = node.original_graph.model.model_id
                if node.original_graph.model not in multi_model_placement:
                    # the node belongs to a model outside this multi-model: drop it with its edges
                    for edge in node.incoming_edges:
                        edge.remove()
                    for edge in node.outgoing_edges:
                        edge.remove()
                    node.remove()
                    continue

            if isinstance(node, AbstractLogicalNode):
                new_node, placement = node.assemble(multi_model_placement)
                if isinstance(new_node.operation, _IOPseudoOperation):
                    model_id = new_node.graph.model.model_id
                    # slots are handed out in the order models are first seen
                    if model_id not in evaluator_slot:
                        added_models.append(model_id)
                        evaluator_slot[model_id] = len(added_models) - 1
                    slot = evaluator_slot[model_id]
                    # If a model's inputs/outputs are not used in the multi-model
                    # the codegen and trainer should not generate and use them
                    # "use_input" and "use_output" are used to mark whether
                    # an input/output of a model is used in a multi-model
                    if new_node.operation.type == '_inputs':
                        input_slot_mapping[new_node] = slot
                    if new_node.operation.type == '_outputs':
                        output_slot_mapping[new_node] = slot

                self.node_replace(node, new_node)

                # name prefix of M_ of cells in hidden_nodes of root graphs is added here
                # FIXME: merge this rename with non-root graph, only do once.
                if isinstance(new_node.operation, Cell):
                    old_cell_name = new_node.operation.cell_name
                    new_node.operation = copy.deepcopy(new_node.operation)
                    new_node.operation.cell_name = f'M_{model_id}_{old_cell_name}'

                # input should be at CPU, move it to GPU first if necessary
                if isinstance(new_node.operation, _IOPseudoOperation) and new_node.operation.type == '_inputs':
                    # hack: only support single_server
                    node_placements[new_node] = CPUDevice(node_id=placement.node_id)
                else:
                    node_placements[new_node] = placement

                node.remove()

        # If two nodes are placed on different devices, use ToDevice op to copy the node
        # TODO: when copying one node to multiple devices, broadcast is more efficient than P2P communication
        existing_edges = phy_graph.edges.copy()
        # Avoid a node is copied multiple times on the same device
        copied_op: Dict[Tuple[Node, Device], Node] = {}
        for edge in existing_edges:
            head_placement = node_placements[edge.head]
            tail_placement = node_placements[edge.tail]
            if head_placement != tail_placement:
                if head_placement.node_id != tail_placement.node_id:
                    raise ValueError('Cross-server placement is not supported.')
                # Same server different devices
                if (edge.head, tail_placement) in copied_op:
                    to_node = copied_op[(edge.head, tail_placement)]
                else:
                    dst_name = edge.head.name + "_to_" + edge.tail.name
                    to_operation = Operation.new(
                        'ToDevice', {
                            "device": tail_placement, "src": (
                                edge.head.name, edge.head_slot), "dst": dst_name})
                    to_node = Node(phy_graph, uid(), dst_name, to_operation)._register()
                    Edge((edge.head, edge.head_slot), (to_node, None), _internal=True)._register()
                    copied_op[(edge.head, tail_placement)] = to_node
                    node_placements[to_node] = head_placement
                edge.head = to_node
                edge.head_slot = None

        # merge all input nodes into one with multiple slots
        input_nodes = [
            node for node in phy_graph.hidden_nodes
            if isinstance(node.operation, _IOPseudoOperation) and node.operation.type == '_inputs'
        ]
        for edge in phy_graph.edges:
            if edge.head in input_nodes:
                edge.head_slot = input_slot_mapping[edge.head]
                edge.head = phy_graph.input_node

        # merge all output nodes into one with multiple slots
        output_nodes = [
            node for node in phy_graph.hidden_nodes
            if isinstance(node.operation, _IOPseudoOperation) and node.operation.type == '_outputs'
        ]
        for edge in phy_graph.edges:
            if edge.tail in output_nodes:
                edge.tail_slot = output_slot_mapping[edge.tail]
                edge.tail = phy_graph.output_node

        for node in input_nodes:
            node.remove()
        for node in output_nodes:
            node.remove()

        return phy_model, node_placements

    def node_replace(self, old_node: Node, new_node: Node, input_slot_mapping=None, output_slot_mapping=None):
        """
        Register ``new_node`` in the graph of ``old_node`` and redirect every edge
        that starts or ends at ``old_node`` to it.

        Raises
        ------
        ValueError
            if a slot mapping is given; slot mappings are not supported
        """
        # TODO: currently, only support single input slot and output slot.
        if input_slot_mapping is not None or output_slot_mapping is not None:
            raise ValueError('Slot mapping is not supported')

        phy_graph = old_node.graph
        new_node.graph = phy_graph

        new_node._register()

        for edge in phy_graph.edges:
            if edge.head == old_node:
                edge.head = new_node
            elif edge.tail == old_node:
                edge.tail = new_node

        # after the replacement, there might be multiple duplicated edges
        # with the same input and output nodes, which should be de-duplicated
        self._remove_duplicated_edges()

    def _remove_duplicated_edges(self):
        # TODO: it does not have duplicated edges if only supporting dedup input
        # Duplicated edges appear when a chain of prefix nodes are deduplicated
        pass
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import List, Dict, Tuple
from nni.nas.utils import uid
from nni.nas.evaluator.pytorch.cgo.evaluator import MultiModelSupervisedLearningModule
from nni.common.device import GPUDevice
from nni.nas.execution.common.graph import Graph, Model, Node
from .interface import AbstractOptimizer
from .logical_plan import (AbstractLogicalNode, LogicalGraph, LogicalPlan,
OriginNode)
_supported_evaluators = [MultiModelSupervisedLearningModule]
class DedupInputNode(AbstractLogicalNode):
    """
    This is logical node representing the node for deduplication.
    In assemble, just return one copy of the original node when multiple models are assembled.
    These models will share the result of once calculation.
    """

    def __init__(self, logical_graph: LogicalGraph, node_id: int,
                 nodes_to_dedup: List[Node], _internal=False):
        # name and operation are borrowed from the first de-duplicated node
        representative = nodes_to_dedup[0]
        super().__init__(logical_graph, node_id,
                         "Dedup_" + representative.name,
                         representative.operation)
        self.origin_nodes: List[OriginNode] = nodes_to_dedup.copy()
        self.related_models = [origin.original_graph.model for origin in self.origin_nodes]

    def assemble(self, multi_model_placement: Dict[Model, GPUDevice]) -> Tuple[Node, GPUDevice]:
        """
        Return a physical node for the first origin node whose model is part of
        ``multi_model_placement``, with that model's device.

        Raises
        ------
        ValueError
            if none of the origin nodes belongs to a model in ``multi_model_placement``
        """
        for origin in self.origin_nodes:
            owner = origin.original_graph.model
            if owner not in multi_model_placement:
                continue
            physical_node = Node(origin.original_graph, origin.id,
                                 f'M_{owner.model_id}_{origin.name}',
                                 origin.operation)
            return physical_node, multi_model_placement[owner]
        raise ValueError(f'DedupInputNode {self.name} does not contain nodes from multi_model')

    def _fork_to(self, graph: Graph):
        """Register an equivalent ``DedupInputNode`` in ``graph``."""
        forked = DedupInputNode(graph, self.id, self.origin_nodes)
        forked._register()

    def __repr__(self) -> str:
        return (
            f'DedupNode(id={self.id}, name={self.name}, '
            f'len(nodes_to_dedup)={len(self.origin_nodes)}'
        )
class DedupInputOptimizer(AbstractOptimizer):
    """Optimizer that merges equivalent ``_inputs`` nodes of a logical plan into ``DedupInputNode``."""

    def __init__(self) -> None:
        """Nothing to set up."""

    def _check_supported_evaluator(self, evaluator):
        """Return True if ``evaluator`` is an instance of any type in ``_supported_evaluators``."""
        return any(isinstance(evaluator, supported) for supported in _supported_evaluators)

    def _check_deduplicate_by_node(self, root_node, node_to_check):
        """
        Tell whether ``node_to_check`` can be merged with ``root_node``.

        A node always matches itself. Two distinct nodes match only if both are
        ``OriginNode`` with ``_inputs`` operations and their models have equal evaluators.
        """
        if root_node == node_to_check:
            return True

        both_original_inputs = (
            root_node.operation.type == '_inputs'
            and node_to_check.operation.type == '_inputs'
            and isinstance(root_node, OriginNode)
            and isinstance(node_to_check, OriginNode)
        )
        if not both_original_inputs:
            return False

        # NOTE(review): this rejects the pair when the evaluator IS of a supported
        # type, which reads as inverted -- confirm the intent before changing it.
        if self._check_supported_evaluator(root_node.original_graph.model.evaluator):
            return False
        return bool(root_node.original_graph.model.evaluator == node_to_check.original_graph.model.evaluator)

    def convert(self, logical_plan: LogicalPlan) -> None:
        """
        Repeatedly pick an unprocessed ``_inputs`` node, collect the nodes that can be
        merged with it, and replace them by one ``DedupInputNode`` until every
        ``_inputs`` node has been processed.
        """
        nodes_to_skip = set()
        while True:  # repeat until the logical_graph converges
            input_nodes = logical_plan.logical_graph.get_nodes_by_type("_inputs")
            # _PseudoOperation(type_name="_inputs"))
            pending = [node for node in input_nodes if node not in nodes_to_skip]
            if not pending:
                break  # end of convert
            root_node = pending[0]

            nodes_to_dedup = [node for node in pending if self._check_deduplicate_by_node(root_node, node)]
            assert(len(nodes_to_dedup) >= 1)
            if len(nodes_to_dedup) == 1:
                # nothing to merge with: never consider this node again
                assert(nodes_to_dedup[0] == root_node)
                nodes_to_skip.add(root_node)
                continue

            dedup_node = DedupInputNode(logical_plan.logical_graph, uid(), nodes_to_dedup)._register()
            for edge in logical_plan.logical_graph.edges:
                if edge.head in nodes_to_dedup:
                    edge.head = dedup_node
                if edge.tail in nodes_to_dedup:
                    edge.tail = dedup_node
            for node in nodes_to_dedup:
                node.remove()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
__all__ = ['model_to_pytorch_script']
import logging
import re
from typing import Dict, List, Tuple, Any, cast
from nni.common.device import Device, GPUDevice
from nni.nas.execution.common.graph import IllegalGraphError, Edge, Graph, Node, Model
from nni.nas.execution.common.graph_op import PyTorchOperation
from nni.nas.utils import STATE_DICT_PY_MAPPING
from .op_def import ToDevice
_logger = logging.getLogger(__name__)
def model_to_pytorch_script(model: Model, placement=None) -> str:
    """
    Generate the PyTorch script of ``model``.

    Every graph of the model is converted with ``graph_to_pytorch_model``; the
    packages they require are emitted as ``import`` lines ahead of the graph code.
    """
    graph_codes = []
    required_pkgs = set()
    for graph_name, graph in model.graphs.items():
        pkgs, code = graph_to_pytorch_model(graph_name, graph, placement=placement)
        graph_codes.append(code)
        required_pkgs.update(pkgs)
    import_lines = [f'import {pkg}' for pkg in required_pkgs]
    script = _PyTorchScriptTemplate.format('\n'.join(import_lines), '\n\n'.join(graph_codes))
    return script.strip()
def _sorted_incoming_edges(node: Node) -> List[Edge]:
    """
    Return the edges of ``node.graph`` whose tail is ``node``, ordered by tail slot.

    The edges are returned as found when every tail slot is ``None``, and sorted
    when the slots are exactly the integers ``0 .. n-1``. Any other combination
    raises ``IllegalGraphError``.
    """
    edges = [edge for edge in node.graph.edges if edge.tail is node]
    _logger.debug('sorted_incoming_edges: %s', str(edges))
    if not edges:
        return []

    tail_slots = [edge.tail_slot for edge in edges]
    _logger.debug('all tail_slots are None: %s', str(tail_slots))
    if all(slot is None for slot in tail_slots):
        return edges
    if all(isinstance(slot, int) for slot in tail_slots):
        edges = sorted(edges, key=(lambda edge: cast(int, edge.tail_slot)))
        if [edge.tail_slot for edge in edges] == list(range(len(edges))):
            return edges
    raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))
def _format_inputs(node: Node, graph_name: str) -> Tuple[List[str], List[Any]]:
    """
    Format the inputs of a given node.
    Inputs will be formatted with ``_format_variable_name``

    Parameters
    ----------
    node : Node
        a graph node, get and format its inputs
    graph_name : str
        subgraph name, to format variable names

    Returns
    -------
    list
        the list of input names
    list
        the list of input values, if an input is simple type, record its value,
        otherwise the value is None
    """
    inputs = []
    inputs_value = []
    for edge in _sorted_incoming_edges(node):
        head = edge.head
        value = None
        if head.name == '_inputs':
            assert isinstance(edge.head_slot, int)
            if head.operation.io_names is not None:
                # when input has names, e.g., forward(self, tensor1, tensor2, another_one)
                text = _format_variable_name(head.operation.io_names[edge.head_slot], graph_name)
            else:
                # when input has no name, e.g., forward(*_inputs)
                text = '_inputs[{}]'.format(edge.head_slot)
        elif edge.head_slot is None:
            # when the input comes from a single-output operator
            text = _format_variable_name(head.name, graph_name)
            if head.operation.type in ('prim::Constant', 'prim::GetAttr') and \
                    'value' in head.operation.parameters:
                value = head.operation.parameters['value']
        else:
            # when the input comes from a multi-output operator: needs to know which one it comes from
            text = '{}[{}]'.format(_format_variable_name(head.name, graph_name), edge.head_slot)
        inputs.append(text)
        inputs_value.append(value)
    return inputs, inputs_value
def _format_variable_name(name: str, graph_name: str) -> str:
    """
    1. replace invalid characters in node name
    2. variables name (full name space) is too long, shorten the name by removing the prefix ```graph_name```
    """
    if name.startswith(graph_name):
        name = name[len(graph_name):]
    # https://stackoverflow.com/questions/3303312/how-do-i-convert-a-string-to-a-valid-variable-name-in-python
    name = re.sub(r'\W|^(?=\d)', '_', name.replace('/', '__'))

    if name.startswith('__') and len(name) > 2 and name[2] != '_':
        # name can't start with double underscore
        # it's reserved in Python: https://stackoverflow.com/a/1301409/6837658
        # but it's actually very common in our generated code
        return name[1:]
    if name.startswith('_'):
        # to avoid conflicts between '_' and '__'
        return 'i' + name
    return name
def generate_cuda_mapping(placement: Dict[Node, Device]) -> Dict[Device, int]:
    '''
    Since CUDA_VISIBLE_DEVICES will be set to the list of real GPU ID,
    we need to remap the GPU ID when generating code to match them correctly.
    For example, when CUDA_VISIBLE_DEVICES="0,3", we need to use "cuda:0", "cuda:1" in the generated code.
    '''
    unique_devices = sorted({e for e in placement.values() if isinstance(e, GPUDevice)})
    gpus_seen_per_server = {}
    cuda_remapped_id = {}
    for device in unique_devices:
        # the next free local index on this server is the number of its GPUs seen so far
        local_index = gpus_seen_per_server.get(device.node_id, 0)
        cuda_remapped_id[device] = local_index
        gpus_seen_per_server[device.node_id] = local_index + 1
    return cuda_remapped_id
def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> Tuple[set, str]:
    """
    Generate the source of one ``nn.Module`` class for ``graph``.

    Parameters
    ----------
    graph_name : str
        name of the graph; ``_graph`` is emitted as class ``Graph``
    graph : Graph
        the graph to convert
    placement : dict, optional
        device placement of nodes; placed modules get a ``.to(...)`` suffix

    Returns
    -------
    set
        the packages the generated code needs to import
    str
        the generated class source

    Raises
    ------
    RuntimeError
        if the output node of the graph has no incoming value
    """
    nodes = graph.topo_sort()

    # handle module node and function node differently
    # only need to generate code for module here
    import_pkgs = set()
    node_codes = []
    node_python_mappings = {}
    cuda_remapped_id = None
    if placement:
        cuda_remapped_id = generate_cuda_mapping(placement)
    for node in nodes:
        if node.operation:
            if placement and isinstance(node.operation, ToDevice):
                cuda_remapped_id = cast(dict, cuda_remapped_id)
                node.operation.override_device_repr("cuda:%d" % cuda_remapped_id[node.operation.device])

            if node.operation.type == 'shared':
                continue
            pkg_name = cast(PyTorchOperation, node.operation).get_import_pkg()
            if pkg_name is not None:
                import_pkgs.add(pkg_name)

            py_variable_name = _format_variable_name(node.name, graph_name)
            node_code = node.operation.to_init_code(py_variable_name)
            if node_code is not None:
                if placement and node in placement and len(node_code) > 0:
                    if isinstance(placement[node], GPUDevice):
                        assert cuda_remapped_id is not None
                        device_repr = "cuda:%d" % cuda_remapped_id[placement[node]]
                    else:
                        device_repr = placement[node].device_repr()
                    node_codes.append(f"{node_code}.to('{device_repr}')")
                else:
                    node_codes.append(node_code)

                # Map to module hierarchies in original search space python code
                node_python_mappings[py_variable_name] = node.python_name

    node_codes.append(f'self.{STATE_DICT_PY_MAPPING} = {node_python_mappings}')

    if graph.input_node.operation.io_names is None:
        input_code = '*_inputs'
    else:
        for name in graph.input_node.operation.io_names:
            assert not name.startswith(graph_name)
        input_code = ', '.join(graph.input_node.operation.io_names)

    # the graph is not modified above, so the topological order computed at the top is reused
    edge_codes = []
    for node in nodes:
        if node.operation:
            inputs, inputs_value = _format_inputs(node, graph_name)
            node_name = _format_variable_name(node.name, graph_name)
            submodule_name = node_name
            if node.operation.type == 'shared':
                submodule_name = _format_variable_name(node.operation.parameters['reference'], graph_name)
            edge_codes.append(node.operation.to_forward_code(submodule_name, node_name, inputs, inputs_value))

    output_names, _ = _format_inputs(graph.output_node, graph_name)
    if not output_names:
        raise RuntimeError('"forward" function should have return value(s): {}, {}, {}'.format(output_names, graph_name, graph.output_node))
    output_code = ', '.join(output_names)

    # statements sit two levels deep (class -> method) in the template below
    linebreak = '\n        '
    return import_pkgs, _PyTorchModelTemplate.format(
        graph_name=('Graph' if graph_name == '_graph' else graph_name),
        inputs=input_code,
        outputs=output_code,
        nodes=linebreak.join(node_codes),
        edges=linebreak.join(edge_codes)
    )


# TODO: handle imports

# Skeleton of the generated script: extra imports first, then the graph classes.
_PyTorchScriptTemplate = '''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import nni.nas.nn.pytorch
{}
{}
'''

# Skeleton of one generated module. The indentation inside this literal is part of
# the generated code and must match ``linebreak`` in ``graph_to_pytorch_model``.
_PyTorchModelTemplate = '''
class {graph_name}(nn.Module):
    def __init__(self):
        super().__init__()
        {nodes}

    def forward(self, {inputs}):
        {edges}
        return {outputs}
'''
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .trainer import CreamSupernetTrainer
from .graph_gen import convert_to_graph
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import re
import torch
from nni.nas.execution.common import Graph, Model, Node, Cell, Operation
from nni.nas.nn.pytorch import InputChoice, Placeholder, LayerChoice
from nni.nas.utils import get_init_parameters_or_fail, get_importable_name
from .op_types import MODULE_EXCEPT_LIST, OpTypeName
from .utils import (
_convert_name, build_full_name, _without_shape_info,
_extract_info_from_trace_node, get_full_name_by_scope_name,
is_layerchoice_node, match_node, build_cand_name,
build_python_name
)
class GraphConverter:
def __init__(self):
    """Start both running counters, ``global_seq`` and ``global_graph_id``, at zero."""
    self.global_seq, self.global_graph_id = 0, 0
def _add_edge_handle_source_node(self, _input, graph_inputs, ir_graph, output_remap, node_index):
    """
    Find the IR node (and output slot) that produces ``_input``.

    Returns
    -------
    tuple
        ``(src_node, src_node_idx)``. The slot is ``None`` for a remapped value or a
        single-output producer, the position in ``graph_inputs`` for a graph
        input, and the position among the producer's outputs otherwise.
    """
    if _input in output_remap:
        predecessor_node = output_remap[_input]
        assert predecessor_node.kind() == 'aten::append'
        assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
        src_node = node_index[predecessor_node]
        assert isinstance(src_node, Node)
        return src_node, None

    if _input in graph_inputs:
        return ir_graph.input_node, graph_inputs.index(_input)

    predecessor_node = _input.node()
    assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
    # find out the index of _input in the outputs of predecessor_node
    predecessor_outputs = list(predecessor_node.outputs())
    src_node_idx = None if len(predecessor_outputs) == 1 else predecessor_outputs.index(_input)
    src_node = node_index[predecessor_node]
    assert isinstance(src_node, Node)
    return src_node, src_node_idx
def _add_edge(self, ir_graph, node, graph_inputs, node_index, new_node, output_remap, ignore_first=False):
"""
Parameters
----------
ir_graph : Graph
node : torch._C.Node
graph_inputs : List[torch._C.Value]
a list of a script graph's inputs
node_index : Dict
new_node : Node
newly created ir node corresponding to `node`
output_remap : Dict
ignore_first : bool
if it is true, skip the first input
"""
is_single_input = (len([_input for _input in node.inputs()]) - (1 if ignore_first else 0)) == 1
new_node_input_idx = 0
for _input in node.inputs():
if ignore_first:
ignore_first = False
continue
# handle source node
src_node, src_node_idx = self._add_edge_handle_source_node(_input, graph_inputs, ir_graph, output_remap, node_index)
# handle destination node
dst_node = new_node
if is_single_input:
dst_node_idx = None
else:
dst_node_idx = new_node_input_idx
# create edge
ir_graph.add_edge(head=(src_node, src_node_idx), tail=(dst_node, dst_node_idx))
new_node_input_idx += 1
def create_prim_constant_node(self, ir_graph, node, module_name):
# NOTE: compare with string not type, because the type is defined in pytorch C code.
# `.kind()` can also be used here
if node.outputsAt(0).type().str() == 'None':
attrs = {'type': 'None'}
else:
attrs = {'type': node.outputsAt(0).type().str(), 'value': node.outputsAt(0).toIValue()}
self.global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Constant, self.global_seq),
node.kind(), attrs)
return new_node
def handle_prim_attr_node(self, node, module):
assert node.hasAttribute('name')
value = None
if node.inputsAt(0).debugName() == 'self':
_val = getattr(module, node.s('name'))
# TODO: serialize complex data type, and output proper error message
if isinstance(_val, (int, float, str, bool)):
value = _val
attrs = {'name': node.s('name'), 'input': node.inputsAt(0).debugName(), 'value': value}
return node.kind(), attrs
def _remove_mangle(self, module_type_str):
return re.sub('\\.___torch_mangle_\\d+', '', module_type_str)
def remove_unconnected_nodes(self, ir_graph, targeted_type=None):
"""
Parameters
----------
ir_graph : Graph
our ir graph representation
targeted_type : str
nodes with ```targeted_type``` will be removed from graph if their fanout is 0.
```None``` means removing all the nodes whose fanout is 0.
"""
# build index of outputs of Node(s)
node_fanout = set()
for edge in ir_graph.edges:
if edge.head.id not in node_fanout:
node_fanout.add(edge.head.id)
to_removes = []
for hidden_node in ir_graph.hidden_nodes:
if hidden_node.id not in node_fanout:
assert isinstance(hidden_node, Node)
if targeted_type is None:
to_removes.append(hidden_node)
elif hidden_node.operation.type == targeted_type:
to_removes.append(hidden_node)
for hidden_node in to_removes:
hidden_node.remove()
def handle_graph_nodes(self, script_module, sm_graph,
module, module_name, module_python_name,
ir_model, ir_graph,
shared_module_index=None):
"""
Convert torch script node to our node ir, and build our graph ir
Parameters
----------
script_module : torch.jit.RecursiveScriptModule
the torch script of ```module```
sm_graph : torch._C.Graph
the graph in torch script
module : nn.Module
the targeted pytorch module
module_name : str
```module```'s name
ir_model : Model
the whole graph ir
ir_graph : Graph
the graph ir of ```module```
shared_module_index : dict
it is used for knowing which module has been created an ir node,
if created and invoked again, then the new ir node can simply reference that ir node.
this way we can identify shared modules (i.e., one module invoked multiple times in `forward` function)
Returns
-------
dict
the mapping from graph node to our graph ir node
"""
# handle inputs
graph_inputs = []
for _input in sm_graph.inputs():
if _input.debugName() == 'self':
assert _input.unique() == 0
continue
graph_inputs.append(_input)
# TODO: add scope name
ir_graph._add_input(_convert_name(_input.debugName()))
node_index = {} # graph node to graph ir node
if shared_module_index is None:
shared_module_index = {}
# some node does not have output but it modifies a variable, for example aten::append
# %17 : Tensor[] = aten::append(%out.1, %16)
# %out.1 is updated, and %17 is None
# we add output to this type of node and connect it to the following node which uses %out.1
# key: tensor (%out.1), value: node (this node)
output_remap = {}
# ===================handle control flow: if===================
def handle_if_condition(cond_tensor):
"""
to calculate the condition, we only deal with the following op types by tracing back
`prim::GetAttr`, `aten::__getitem__`, `prim::Constant`, `aten::eq`
generate the expression using recursive calls
NOTE: do not support dynamic graph
"""
def _generate_expr(tensor):
if tensor.node().kind() == 'prim::GetAttr':
return f'({getattr(module, tensor.node().s("name"))})'
elif tensor.node().kind() == 'aten::__getitem__':
t = _generate_expr(tensor.node().inputsAt(0))
idx = _generate_expr(tensor.node().inputsAt(1))
return f'({t}[{idx}])'
elif tensor.node().kind() == 'prim::Constant':
return f'{tensor.toIValue()}'
elif tensor.node().kind() == 'aten::eq':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} == {right})'
elif tensor.node().kind() == 'aten::le':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} <= {right})'
elif tensor.node().kind() == 'aten::ge':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} >= {right})'
elif tensor.node().kind() == 'aten::__not__':
value = _generate_expr(tensor.node().inputsAt(0))
return f'(not {value})'
elif tensor.node().kind() == 'aten::Bool':
value = _generate_expr(tensor.node().inputsAt(0))
return f'bool({value})'
elif tensor.node().kind() == 'aten::__is__':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} is {right})'
elif tensor.node().kind() == 'aten::__isnot__':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} is not {right})'
elif tensor.node().kind() == 'aten::ne':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} != {right})'
elif tensor.node().kind() == 'aten::gt':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} > {right})'
elif tensor.node().kind() == 'aten::lt':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} < {right})'
elif tensor.node().kind() == 'prim::If':
raise RuntimeError('Have not supported `if A and/or B`, please use two `if` statements instead.')
elif tensor.node().kind() == 'aten::abs':
value = _generate_expr(tensor.node().inputsAt(0))
return f'(torch.abs({value}))'
elif tensor.node().kind() == 'aten::sum':
value = _generate_expr(tensor.node().inputsAt(0))
return f'(torch.sum({value}))'
elif tensor.node().kind() == 'aten::item':
value = _generate_expr(tensor.node().inputsAt(0))
return f'({value}.item())'
else:
raise RuntimeError(f'Unsupported op type {tensor.node().kind()} in if condition, '
'you are suggested to decorate the corresponding class with "@basic_unit".')
expr = _generate_expr(cond_tensor)
return eval(expr)
def handle_if_node(node):
"""
Parameters
----------
node : torch._C.Node
the node from TorchScript graph
Returns
-------
Node
the created node ir
"""
# only deal with input of prim::If is constant or attribute for now
# will support constant expression in future
inputs = [i for i in node.inputs()]
assert len(inputs) == 1
cond = handle_if_condition(inputs[0])
chosen_block = 0 if cond else 1
blocks = [block for block in node.blocks()]
assert len(blocks) == 2
last_block_node = None
for node in blocks[chosen_block].nodes():
last_block_node = handle_single_node(node)
self.global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, 'noop_identity', self.global_seq), 'noop_identity')
self._add_edge(ir_graph, blocks[chosen_block].returnNode(), graph_inputs, node_index, new_node, output_remap)
last_block_node = new_node
return last_block_node
# ===================handle function call===================
def handle_function_callmethod(node):
# get and handle the first input, which should be an nn.Module
assert node.hasAttribute('name')
# NOTE: "forward__0" is hacky, LSTM instance is parsed to call forward__0 in torchscript
if node.s('name') in ['forward', 'forward__0']:
# node.inputsAt(0).type() is <class 'torch._C.ClassType'>
submodule_type_str = self._remove_mangle(node.inputsAt(0).type().str())
submodule = node.inputsAt(0).node()
assert submodule.kind() == 'prim::GetAttr'
assert submodule.hasAttribute('name')
submodule_name = submodule.s('name')
if submodule.inputsAt(0).debugName() == 'self':
# module is usually instantiated in __init__.
# when calling a module in forward,
# prim::GetAttr is used to obtain the module in torch script.
# therefore, we do this check for a module. example below:
# %25 : __torch__.xxx = prim::GetAttr[name="input_switch"](%self)
# %27 : Tensor = prim::CallMethod[name="forward"](%25, %out.1)
assert submodule_name in script_module._modules, "submodule_name: {} not in script_module {}".format(
submodule_name, script_module._modules.keys())
submodule_full_name = build_full_name(module_name, submodule_name)
submodule_python_name = build_python_name(module_python_name, submodule_name)
submodule_obj = getattr(module, submodule_name)
subgraph, sub_m_attrs = self._convert_module(script_module._modules[submodule_name],
submodule_obj,
submodule_full_name, submodule_python_name,
ir_model)
else:
# %8 : __torch__.nni.retiarii.model_apis.nn.___torch_mangle_37.ModuleList = prim::GetAttr[name="cells"](%self)
# %10 : __torch__.darts_model.Cell = prim::GetAttr[name="0"](%8)
# %s1.4 : Tensor = prim::CallMethod[name="forward"](%10, %4, %4)
if submodule.inputsAt(0).type().name() == 'ModuleList':
# handle ModuleList
predecessor = submodule.inputsAt(0).node()
module_name_space = [submodule_name]
while predecessor.inputsAt(0).debugName() != 'self':
# this is for dealing with nested ModuleList. below is an example
# %3 : __torch__.torch.nn.modules.container.___torch_mangle_0.ModuleList = prim::GetAttr[name="ops"](%self)
# %5 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="0"](%3)
# %7 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="1"](%3)
# %9 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="2"](%3)
# %11 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="3"](%3)
# %14 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="0"](%5)
# %16 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="1"](%5)
# %state.2 : Tensor = prim::CallMethod[name="forward"](%14, %x.1) # modulelist.py:18:24
# %state.4 : Tensor = prim::CallMethod[name="forward"](%16, %state.2) # modulelist.py:18:24
assert predecessor.kind() == 'prim::GetAttr'
module_name_space.append(predecessor.s('name'))
predecessor = predecessor.inputsAt(0).node()
assert predecessor.kind() == 'prim::GetAttr'
assert predecessor.hasAttribute('name')
module_name_space.append(predecessor.s('name'))
submodule_full_name = build_full_name(module_name, list(reversed(module_name_space)))
submodule_python_name = build_python_name(module_python_name, list(reversed(module_name_space)))
submodule_obj = module
script_submodule = script_module
for each_name in list(reversed(module_name_space)):
submodule_obj = getattr(submodule_obj, each_name)
script_submodule = script_submodule._modules[each_name]
subgraph, sub_m_attrs = self._convert_module(script_submodule, submodule_obj, submodule_full_name,
submodule_python_name, ir_model)
else:
raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str()))
if submodule_full_name in shared_module_index:
# this module is invoked more than once, the ir node has already been created
# create a reference node for it.
# example: {"name": "conv2", "operation": {"type": "shared", "parameters": {"reference": "conv1"}}}
self.global_seq += 1
shared_node_name = build_full_name(submodule_full_name, '', self.global_seq)
shared_node_python_name = build_python_name(submodule_python_name, self.global_seq)
shared_type_operation = Operation.new('shared', {'reference': submodule_full_name})
subcell = ir_graph.add_node(shared_node_name, shared_type_operation)
subcell.python_name = shared_node_python_name
else:
# this module is processed for the first time, build cell for it
if subgraph is None:
# if we do not parse this module's graph, we create Node for this module
subcell = ir_graph.add_node(submodule_full_name, submodule_type_str, sub_m_attrs)
subcell.python_name = submodule_python_name
if isinstance(submodule_obj, Placeholder):
subcell.update_label(submodule_obj.label)
elif isinstance(submodule_obj, InputChoice):
subcell.update_label(sub_m_attrs['label'])
else:
# Graph already created, create Cell for it
new_cell = Cell(cell_name=submodule_full_name, parameters=sub_m_attrs)
subcell = ir_graph.add_node(submodule_full_name, new_cell)
subcell.python_name = submodule_python_name
shared_module_index[submodule_full_name] = subcell
node_index[node] = subcell
# connect the cell into graph
self._add_edge(ir_graph, node, graph_inputs, node_index, subcell, output_remap, ignore_first=True)
else:
# handle normal member function
assert hasattr(script_module, node.s('name'))
# TODO: support non member functions
assert node.inputsAt(0).debugName() == 'self'
script_method = getattr(script_module, node.s('name')) # <class 'torch._C.ScriptMethod'>
# step #1: generate graph ir for this method
method_ir_graph = Graph(model=ir_model, graph_id=-100, name='temp_graph', _internal=True)
self.handle_graph_nodes(script_module, script_method.graph, module,
module_name, module_python_name, ir_model, method_ir_graph, shared_module_index)
self.refine_graph(method_ir_graph)
# step #2: merge this graph to its module graph
for h_node in method_ir_graph.hidden_nodes:
h_node.graph = ir_graph
ir_graph.hidden_nodes.append(h_node)
for edge in method_ir_graph.edges:
edge.graph = ir_graph
if edge.head == method_ir_graph.input_node:
# this is a member method, 'self' is the first argument, thus +1
assert edge.head_slot is not None
_input = node.inputsAt(edge.head_slot + 1)
src_node, src_node_idx = self._add_edge_handle_source_node(_input, graph_inputs, ir_graph, output_remap, node_index)
edge.head = src_node
edge.head_slot = src_node_idx
if edge.tail == method_ir_graph.output_node:
# since the following nodes have not been created, skip this edge
# edge.head is the output node of this method
# TODO: check whether there could be multiple output nodes???
node_index[node] = edge.head
continue
ir_graph.edges.append(edge)
# ===================handle each single node===================
def handle_single_node(node):
"""
Parameters
----------
node : torch._C.Node
the node from TorchScript graph
Returns
-------
Node
the created node ir
"""
if node.kind() == 'prim::CallMethod':
handle_function_callmethod(node)
elif node.kind() == 'prim::CallFunction':
func_type_str = self._remove_mangle(node.inputsAt(0).type().str())
func = node.inputsAt(0).node()
assert func.kind() == 'prim::Constant'
assert func.hasAttribute('name')
func_name = func.s('name')
# create node for func
self.global_seq += 1
func_node = ir_graph.add_node(build_full_name(module_name, func_name, self.global_seq),
'{}.{}'.format(func_type_str, func_name))
func_python_name = build_python_name(module_python_name, func_name)
func_node.python_name = func_python_name
node_index[node] = func_node
self._add_edge(ir_graph, node, graph_inputs, node_index, func_node, output_remap, ignore_first=True)
elif node.kind() == 'prim::Constant':
new_node = self.create_prim_constant_node(ir_graph, node, module_name)
node_index[node] = new_node
elif node.kind() in ['prim::ListConstruct', 'prim::ListUnpack', 'prim::TupleConstruct', 'prim::TupleUnpack']:
self.global_seq += 1
prim_op_name = node.kind().split('::')[-1]
new_node = ir_graph.add_node(build_full_name(module_name, prim_op_name, self.global_seq), node.kind())
node_index[node] = new_node
self._add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap)
elif node.kind() == 'prim::GetAttr':
node_type, attrs = self.handle_prim_attr_node(node, module)
self.global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Attr, self.global_seq),
node_type, attrs)
node_index[node] = new_node
elif node.kind() == 'prim::If':
last_block_node = handle_if_node(node)
# last_block_node is None means no node in the branch block
node_index[node] = last_block_node
elif node.kind() == 'prim::Loop':
# refer to https://gist.github.com/liuzhe-lz/90c35d9dd6fd7f3f32544940151ab186
raise RuntimeError('Loop has not been supported yet!')
elif node.kind().startswith('prim::'):
self.global_seq += 1
prim_op_name = node.kind().replace('::', '__')
prim_node = ir_graph.add_node(build_full_name(module_name, prim_op_name, self.global_seq), node.kind())
node_index[node] = prim_node
self._add_edge(ir_graph, node, graph_inputs, node_index, prim_node, output_remap)
elif node.kind() == 'aten::append':
self.global_seq += 1
aten_op_name = node.kind().replace('::', '__')
aten_node = ir_graph.add_node(build_full_name(module_name, aten_op_name, self.global_seq), node.kind())
node_index[node] = aten_node
self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
output_remap[node.inputsAt(0)] = node
elif node.kind().startswith('aten::'):
# handle aten::XXX
self.global_seq += 1
aten_op_name = node.kind().replace('::', '__')
aten_op_python_name = node.kind().replace('aten::', '')
aten_node = ir_graph.add_node(build_full_name(module_name, aten_op_name, self.global_seq), node.kind())
aten_python_name = build_python_name(module_python_name, aten_op_python_name)
aten_node.python_name = aten_python_name
node_index[node] = aten_node
self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
else:
raise RuntimeError('Unsupported kind: {}'.format(node.kind()))
return node_index[node]
for node in sm_graph.nodes():
handle_single_node(node)
if node_index != {}:
for _output in sm_graph.outputs():
ir_graph._add_output(_convert_name(_output.debugName()))
predecessor_node_outputs = [o for o in _output.node().outputs()]
if len(predecessor_node_outputs) == 1:
src_node_idx = None
else:
src_node_idx = predecessor_node_outputs.index(_output)
ir_graph.add_edge(head=(node_index[_output.node()], src_node_idx),
tail=(ir_graph.output_node, None))
else:
# here is an example that the ir_graph and node_index is empty
# graph(%self : __torch__.torchmodels.googlenet.GoogLeNet,
# %x.1 : Tensor): return (%x.1)
# add an edge from head to tail to handle this situation
ir_graph.add_edge(head=(ir_graph.input_node, 0), tail=(ir_graph.output_node, None))
def merge_aten_slices(self, ir_graph):
"""
if there is aten::slice node, merge the consecutive ones together.
```x[:, :, 1:, 1:]``` in python code will be converted into 4 node in torch script,
each node has 5 inputs: tensor, dim, x, y, z (i.e., x:y:z)
"""
head_slice_nodes = []
has_slice_node = False
for node in ir_graph.hidden_nodes:
if node.operation.type == 'aten::slice':
has_slice_node = True
for pred in node.predecessors:
if pred.operation.type not in ['aten::slice', 'prim::Constant']:
head_slice_nodes.append(node)
break
if has_slice_node:
assert head_slice_nodes
for head_node in head_slice_nodes:
slot = 0
new_slice_node = ir_graph.add_node(build_full_name(head_node.name, 'merged'), OpTypeName.MergedSlice)
if len(head_node.incoming_edges) == 4:
# when slice is for one dimension list, there are only 4 inputs, thus merge is not needed
for edge in head_node.incoming_edges:
edge.tail = new_slice_node
for edge in head_node.outgoing_edges:
edge.head = new_slice_node
ir_graph.hidden_nodes.remove(head_node)
break
assert len(head_node.incoming_edges) == 5
for edge in head_node.incoming_edges:
edge.tail = new_slice_node
slot += 5
node = head_node
while len(node.successors) == 1 and node.successors[0].operation.type == 'aten::slice':
suc_node = node.successors[0]
assert len(suc_node.incoming_edges) == 5
for edge in suc_node.incoming_edges:
if edge.tail_slot == 0:
edge.remove()
else:
edge.tail = new_slice_node
edge.tail_slot = slot + edge.tail_slot - 1
slot += 4
ir_graph.hidden_nodes.remove(node)
node = suc_node
for edge in node.outgoing_edges:
edge.head = new_slice_node
ir_graph.hidden_nodes.remove(node)
def refine_graph(self, ir_graph):
"""
Do the following process to simplify graph:
1. remove unconnected constant node
2. remove unconnected getattr node
"""
# some constant is not used, for example, function name as prim::Constant
self.remove_unconnected_nodes(ir_graph, targeted_type='prim::Constant')
self.remove_unconnected_nodes(ir_graph, targeted_type='prim::GetAttr')
self.merge_aten_slices(ir_graph)
def _handle_inputchoice(self, module):
return {
'n_candidates': module.n_candidates,
'n_chosen': module.n_chosen,
'reduction': module.reduction,
'label': module.label
}
def _handle_valuechoice(self, module):
return {
'candidates': module.candidates,
'label': module.label,
}
def _convert_module(self, script_module, module, module_name, module_python_name, ir_model):
# NOTE: have not supported nested LayerChoice, i.e., a candidate module
# also has LayerChoice or InputChoice or ValueChoice
original_type_name = script_module.original_name
m_attrs = None
if original_type_name == OpTypeName.LayerChoice:
graph = Graph(ir_model, -100, module_name, _internal=True) # graph_id is not used now
graph.python_name = module_python_name
candidate_name_list = []
for cand_name in module.names:
cand = module[cand_name]
script_cand = script_module._modules[cand_name]
cand_full_name = build_cand_name(cand_name, module.label)
cand_python_name = build_python_name(module_python_name, cand_name)
candidate_name_list.append(cand_full_name)
subgraph, attrs = self._convert_module(script_cand, cand, cand_full_name, cand_python_name, ir_model)
if subgraph is not None:
cand_node = graph.add_node(subgraph.name, Cell(cell_name=subgraph.name, parameters=attrs))
cand_node.python_name = cand_python_name
else:
cand_type = '__torch__.' + get_importable_name(cand.__class__)
cand_node = graph.add_node(cand_full_name, cand_type, attrs)
cand_node.python_name = cand_python_name
graph._register()
return graph, {'mutation': 'layerchoice', 'label': module.label, 'candidates': candidate_name_list}
elif original_type_name == OpTypeName.InputChoice:
m_attrs = self._handle_inputchoice(module)
elif original_type_name == OpTypeName.ValueChoice:
m_attrs = self._handle_valuechoice(module)
elif original_type_name == OpTypeName.Placeholder:
m_attrs = get_init_parameters_or_fail(module)
elif module.__class__.__module__.startswith('torch.nn') and \
original_type_name in torch.nn.__dict__ and \
original_type_name not in MODULE_EXCEPT_LIST:
# this is a basic module from pytorch, no need to parse its graph
m_attrs = get_init_parameters_or_fail(module)
elif getattr(module, '_nni_basic_unit', False):
# this module is marked as serialize, won't continue to parse
m_attrs = get_init_parameters_or_fail(module)
if m_attrs is not None:
return None, m_attrs
# handle TorchScript graph
sm_graph = script_module.graph
self.global_graph_id += 1
ir_graph = Graph(model=ir_model, graph_id=self.global_graph_id, name=module_name, _internal=True)
ir_graph.python_name = module_python_name
# handle graph nodes
self.handle_graph_nodes(script_module, sm_graph, module,
module_name, module_python_name, ir_model, ir_graph)
self.refine_graph(ir_graph)
ir_graph._register()
# add mutation signal for special modules
if original_type_name == OpTypeName.Repeat:
attrs = {
'mutation': 'repeat',
'label': module.label,
'depth': module.depth_choice,
'max_depth': module.max_depth,
'min_depth': module.min_depth,
}
return ir_graph, attrs
return ir_graph, {}
def convert_module(self, script_module, module, module_name, ir_model):
"""
Convert a module to its graph ir (i.e., Graph) along with its input arguments
Parameters
----------
script_module : torch.jit.RecursiveScriptModule
the script module of ```module``` obtained with torch.jit.script
module : nn.Module
the targeted module instance
module_name : str
the constructed name space of ```module```
ir_model : Model
the whole graph ir
Returns
-------
Graph
the built graph ir from module, ```None``` means do not further parse the module
dict
the input arguments of this module
"""
return self._convert_module(script_module, module, module_name, None, ir_model)
class GraphConverterWithShape(GraphConverter):
    """
    Convert a pytorch model to nni ir along with input/output shape info.
    Based ir acquired through ``torch.jit.script``
    and shape info acquired through ``torch.jit.trace``.

    .. warning::

        Known issues:

        1. ``InputChoice`` and ``ValueChoice`` not supported yet.
        2. Currently random inputs are fed while tracing layerchoice.
           If forward path of candidates depends on input data, then wrong path will be traced.
           This will result in incomplete shape info.
    """

    def convert_module(self, script_module, module, module_name, ir_model, dummy_input):
        """
        Convert ``module`` to graph ir, then trace it with ``dummy_input`` to attach shape info.

        The module is switched to eval mode first. Returns the ``(graph, attrs)`` pair
        produced by ``_convert_module``.
        """
        module.eval()

        graph, module_attrs = self._convert_module(script_module, module, module_name, None, ir_model)
        self.remove_dummy_nodes(ir_model)
        self._initialize_parameters(ir_model)
        self._trace_module(module, module_name, ir_model, dummy_input)
        return graph, module_attrs

    def _initialize_parameters(self, ir_model: 'Model'):
        """Give every node a parameter dict and empty ``input_shape`` / ``output_shape`` attributes."""
        for ir_node in ir_model.get_nodes():
            if ir_node.operation.parameters is None:
                ir_node.operation.parameters = {}
            for shape_key in ('input_shape', 'output_shape'):
                # existing shape info is kept untouched
                ir_node.operation.attributes.setdefault(shape_key, [])

    def _trace_module(self, module, module_name, ir_model: 'Model', dummy_input):
        """
        Trace ``module`` and copy the traced shape info onto the matching ir nodes,
        then do the same for every candidate of each ``LayerChoice`` found in ``module``.
        """
        # First, trace the whole graph
        traced_graph = self._trace(module, dummy_input)

        for traced_node in traced_graph.nodes():
            shape_parameters, parameters = _extract_info_from_trace_node(traced_node)
            # '__module.convpool/__module.convpool.1/__module.convpool.1.conv'
            matched = match_node(ir_model, traced_node, module_name)
            if matched is None:
                continue
            matched.operation.attributes.update(shape_parameters)
            if parameters:
                matched.operation.parameters.update(parameters)

        self.propagate_shape(ir_model)

        # trace each layerchoice
        for scope_name, submodule in module.named_modules():
            # TODO: support InputChoice and ValueChoice
            if not isinstance(submodule, LayerChoice):
                continue
            full_name = get_full_name_by_scope_name(ir_model, scope_name.split('.'), module_name)
            lc_node = ir_model.get_node_by_name(full_name)
            assert lc_node is not None, f'Cannot find a node with name {full_name}'

            for cand_key in submodule.names:
                candidate = submodule[cand_key]
                candidate_name = build_cand_name(cand_key, submodule.label)
                # TODO: Feed the exact input tensor if user provides input,
                # in case the path changes according to input data.
                random_inputs = [torch.randn(shape) for shape in lc_node.operation.attributes['input_shape']]
                self._trace_module(candidate, candidate_name, ir_model, random_inputs)

    def propagate_shape(self, ir_model: 'Model'):
        """Fill in missing shape info of the nodes that stand for a graph, using the nodes inside that graph."""

        def propagate_shape_for_graph(graph: 'Graph'):
            # the root graph has no enclosing node to update
            if graph == ir_model.root_graph:
                return

            graph_node = ir_model.get_node_by_name(graph.name)
            assert graph_node is not None, f'Cannot find a node with name {graph.name}'
            # nothing to do when the node already carries shape info
            if not _without_shape_info(graph_node):
                return

            target_attrs = graph_node.operation.attributes
            if is_layerchoice_node(graph_node):
                # a layerchoice borrows the shapes of its first candidate
                first_candidate = graph_node.operation.parameters['candidates'][0]
                cand_node = ir_model.get_node_by_name(first_candidate)
                assert cand_node is not None, f'Cannot find a node with name {first_candidate}'
                if _without_shape_info(cand_node):
                    propagate_shape_for_graph(ir_model.graphs[first_candidate])
                target_attrs['input_shape'] = cand_node.operation.attributes['input_shape']
                target_attrs['output_shape'] = cand_node.operation.attributes['output_shape']
            else:
                input_shape = [[]] * len(graph.input_node.operation.io_names or [])
                output_shape = [[]] * len(graph.output_node.operation.io_names or [])

                # input shapes come from the consumers of the graph's input node
                for edge in graph.input_node.outgoing_edges:
                    consumer = edge.tail
                    if _without_shape_info(consumer) and consumer.name in ir_model.graphs:
                        propagate_shape_for_graph(ir_model.graphs[consumer.name])
                    consumer_shapes = consumer.operation.attributes['input_shape']
                    if consumer_shapes:
                        input_shape[edge.head_slot or 0] = consumer_shapes[edge.tail_slot or 0]
                target_attrs['input_shape'] = input_shape

                # output shapes come from the producers feeding the graph's output node
                for edge in graph.output_node.incoming_edges:
                    producer = edge.head
                    if _without_shape_info(producer) and producer.name in ir_model.graphs:
                        propagate_shape_for_graph(ir_model.graphs[producer.name])
                    producer_shapes = producer.operation.attributes['output_shape']
                    if producer_shapes:
                        output_shape[edge.tail_slot or 0] = producer_shapes[edge.head_slot or 0]
                target_attrs['output_shape'] = output_shape

            # continue with the graph that contains the node just updated
            propagate_shape_for_graph(graph_node.graph)

        # propagate from node to graph
        for ir_node in ir_model.get_nodes():
            propagate_shape_for_graph(ir_node.graph)

    def _trace(self, module, dummy_input):
        """Trace ``module`` with ``dummy_input``, inline the traced graph and return it."""
        traced = torch.jit.trace(module, dummy_input)
        torch._C._jit_pass_inline(traced.graph)
        return traced.graph

    def remove_dummy_nodes(self, ir_model: 'Model'):
        """Remove every ``noop_identity`` node, reconnecting its inputs to its consumers by matching slots."""
        # remove identity nodes
        for identity in ir_model.get_nodes_by_type('noop_identity'):
            owner = identity.graph
            for incoming in identity.incoming_edges:
                for outgoing in identity.outgoing_edges:
                    if incoming.tail_slot != outgoing.head_slot:
                        continue
                    # bypass the identity node: producer -> consumer
                    owner.add_edge(head=(incoming.head, incoming.head_slot), tail=(outgoing.tail, outgoing.tail_slot))
                    owner.del_edge(incoming)
                    owner.del_edge(outgoing)
                    break
            identity.remove()
def convert_to_graph(script_module, module, converter=None, **kwargs):
    """
    Build a :class:`Model` (our graph ir) from a scripted module.

    Parameters
    ----------
    script_module : torch.jit.RecursiveScriptModule
        the script module obtained with torch.jit.script
    module : nn.Module
        the targeted module instance
    converter : `TorchConverter`
        the converter to use; a fresh `GraphConverter` is created when it is ``None``
    kwargs:
        forwarded unchanged to `converter.convert_module()`

    Returns
    -------
    Model
        the constructed IR model
    """
    ir_model = Model(_internal=True)
    active_converter = GraphConverter() if converter is None else converter
    # '_model' is the name given to the top-level module
    active_converter.convert_module(script_module, module, '_model', ir_model, **kwargs)
    return ir_model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment