"examples/vscode:/vscode.git/clone" did not exist on "b58666acefedc2ef7b38a0f9861dc16e51776af1"
Unverified commit 468917ca authored by QuanluZhang, committed by GitHub

Merge pull request #3155 from microsoft/dev-retiarii

[Do NOT Squash] Merge retiarii dev branch to master
parents f8424a9f d5a551c8
# We will support TensorFlow in a future release.
framework = 'pytorch'
import time
import os
from typing import List
from ..graph import Model, ModelStatus
from .base import BaseExecutionEngine
from .cgo_engine import CGOExecutionEngine
from .interface import AbstractExecutionEngine, WorkerInfo
from .listener import DefaultListener
_execution_engine = None
_default_listener = None
__all__ = ['get_execution_engine', 'get_and_register_default_listener',
'submit_models', 'wait_models', 'query_available_resources']
def get_execution_engine() -> BaseExecutionEngine:
"""
Currently we assume the default execution engine is BaseExecutionEngine.
"""
global _execution_engine
if _execution_engine is None:
if os.environ.get('CGO') == 'true':
_execution_engine = CGOExecutionEngine()
else:
_execution_engine = BaseExecutionEngine()
return _execution_engine
def get_and_register_default_listener(engine: AbstractExecutionEngine) -> DefaultListener:
global _default_listener
if _default_listener is None:
_default_listener = DefaultListener()
engine.register_graph_listener(_default_listener)
return _default_listener
def submit_models(*models: Model) -> None:
engine = get_execution_engine()
get_and_register_default_listener(engine)
engine.submit_models(*models)
def wait_models(*models: Model) -> None:
get_and_register_default_listener(get_execution_engine())
while True:
time.sleep(1)
left_models = [model for model in models if model.status not in (ModelStatus.Trained, ModelStatus.Failed)]
if not left_models:
break
def query_available_resources() -> List[WorkerInfo]:
listener = get_and_register_default_listener(get_execution_engine())
return listener.resources
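# A hedged usage sketch: a minimal synchronous strategy loop built on the APIs
# above. Only `submit_models`, `wait_models` and `query_available_resources`
# come from this module; the candidate-model source and the metric comparison
# are illustrative assumptions.
def _example_synchronous_strategy(candidate_models: List[Model]) -> Model:
    for model in candidate_models:
        while not query_available_resources():  # wait until a worker is idle
            time.sleep(1)
        submit_models(model)  # submission returns immediately
    wait_models(*candidate_models)  # block until each model is Trained or Failed
    return max(candidate_models, key=lambda m: m.metric)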
import logging
from typing import Dict, Any, List
from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
from .. import codegen, utils
from ..graph import Model, ModelStatus, MetricData
from ..integration import send_trial, receive_trial_parameters, get_advisor
_logger = logging.getLogger(__name__)
class BaseGraphData:
def __init__(self, model_script: str, training_module: str, training_kwargs: Dict[str, Any]) -> None:
self.model_script = model_script
self.training_module = training_module
self.training_kwargs = training_kwargs
def dump(self) -> dict:
return {
'model_script': self.model_script,
'training_module': self.training_module,
'training_kwargs': self.training_kwargs
}
@staticmethod
def load(data):
return BaseGraphData(data['model_script'], data['training_module'], data['training_kwargs'])
class BaseExecutionEngine(AbstractExecutionEngine):
"""
The execution engine with no optimization at all.
Resource management is yet to be implemented.
"""
def __init__(self) -> None:
"""
Upon initialization, advisor callbacks need to be registered.
The advisor calls these callbacks when the corresponding events are triggered.
The base execution engine receives the callbacks and broadcasts them to the graph listeners.
"""
self._listeners: List[AbstractGraphListener] = []
# register advisor callbacks
advisor = get_advisor()
advisor.send_trial_callback = self._send_trial_callback
advisor.request_trial_jobs_callback = self._request_trial_jobs_callback
advisor.trial_end_callback = self._trial_end_callback
advisor.intermediate_metric_callback = self._intermediate_metric_callback
advisor.final_metric_callback = self._final_metric_callback
self._running_models: Dict[int, Model] = dict()
def submit_models(self, *models: Model) -> None:
for model in models:
data = BaseGraphData(codegen.model_to_pytorch_script(model),
model.training_config.module, model.training_config.kwargs)
self._running_models[send_trial(data.dump())] = model
def register_graph_listener(self, listener: AbstractGraphListener) -> None:
self._listeners.append(listener)
def _send_trial_callback(self, parameter: dict) -> None:
for listener in self._listeners:
_logger.warning('resources: %s', listener.resources)
if not listener.has_available_resource():
_logger.warning('There is no available resource, but trial is submitted.')
listener.on_resource_used(1)
_logger.warning('on_resource_used: %s', listener.resources)
def _request_trial_jobs_callback(self, num_trials: int) -> None:
for listener in self._listeners:
listener.on_resource_available(1 * num_trials)
_logger.warning('on_resource_available: %s', listener.resources)
def _trial_end_callback(self, trial_id: int, success: bool) -> None:
model = self._running_models[trial_id]
if success:
model.status = ModelStatus.Trained
else:
model.status = ModelStatus.Failed
for listener in self._listeners:
listener.on_training_end(model, success)
def _intermediate_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
model = self._running_models[trial_id]
model.intermediate_metrics.append(metrics)
for listener in self._listeners:
listener.on_intermediate_metric(model, metrics)
def _final_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
model = self._running_models[trial_id]
model.metric = metrics
for listener in self._listeners:
listener.on_metric(model, metrics)
def query_available_resource(self) -> List[WorkerInfo]:
raise NotImplementedError # move the method from listener to here?
@classmethod
def trial_execute_graph(cls) -> None:
"""
Initialize the model, hand it over to trainer.
"""
graph_data = BaseGraphData.load(receive_trial_parameters())
with open('_generated_model.py', 'w') as f:
f.write(graph_data.model_script)
trainer_cls = utils.import_(graph_data.training_module)
model_cls = utils.import_('_generated_model._model')
trainer_instance = trainer_cls(model=model_cls(), **graph_data.training_kwargs)
trainer_instance.fit()
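# A hedged sketch of the trial-side entry (the configured trial command is
# `python3 -m nni.retiarii.trial_entry`; the dispatch below is an assumption,
# not the actual module): inside the trial process, `trial_execute_graph`
# receives the parameters, writes `_generated_model.py` and runs the trainer.
def _example_trial_entry() -> None:
    import os
    if os.environ.get('CGO') == 'true':  # hypothetical engine dispatch
        from .cgo_engine import CGOExecutionEngine
        CGOExecutionEngine.trial_execute_graph()
    else:
        BaseExecutionEngine.trial_execute_graph()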
import logging
from typing import List, Dict, Tuple
from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
from .. import codegen, utils
from ..graph import Model, ModelStatus, MetricData, Node
from ..integration import send_trial, receive_trial_parameters, get_advisor
from .logical_optimizer.logical_plan import LogicalPlan, PhysicalDevice
from .logical_optimizer.opt_dedup_input import DedupInputOptimizer
from .base import BaseGraphData
_logger = logging.getLogger(__name__)
class CGOExecutionEngine(AbstractExecutionEngine):
def __init__(self, n_model_per_graph=4) -> None:
self._listeners: List[AbstractGraphListener] = []
self._running_models: Dict[int, Model] = dict()
self.logical_plan_counter = 0
self.n_model_per_graph = n_model_per_graph
self._optimizers = [DedupInputOptimizer()]
self._original_models = {}
self._original_model_to_multi_model = {}
# register advisor callbacks
advisor = get_advisor()
advisor.send_trial_callback = self._send_trial_callback
advisor.request_trial_jobs_callback = self._request_trial_jobs_callback
advisor.trial_end_callback = self._trial_end_callback
advisor.intermediate_metric_callback = self._intermediate_metric_callback
advisor.final_metric_callback = self._final_metric_callback
def add_optimizer(self, opt):
self._optimizers.append(opt)
def submit_models(self, *models: Model) -> None:
_logger.info('%d models are submitted', len(models))
logical = self._build_logical(models)
for opt in self._optimizers:
opt.convert(logical)
phy_models_and_placements = self._assemble(logical)
for model, placement, grouped_models in phy_models_and_placements:
data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement),
model.training_config.module, model.training_config.kwargs)
for m in grouped_models:
self._original_models[m.model_id] = m
self._original_model_to_multi_model[m.model_id] = model
self._running_models[send_trial(data.dump())] = model
# for model in models:
# data = BaseGraphData(codegen.model_to_pytorch_script(model),
# model.config['trainer_module'], model.config['trainer_kwargs'])
# self._running_models[send_trial(data.dump())] = model
def _assemble(self, logical_plan: LogicalPlan) -> List[Tuple[Model, Dict[Node, PhysicalDevice], List[Model]]]:
# unique_models = set()
# for node in logical_plan.graph.nodes:
# if node.graph.model not in unique_models:
# unique_models.add(node.graph.model)
# return [m for m in unique_models]
grouped_models: List[Dict[Model, PhysicalDevice]] = AssemblePolicy().group(logical_plan)
phy_models_and_placements = []
for multi_model in grouped_models:
model, model_placement = logical_plan.assemble(multi_model)
phy_models_and_placements.append((model, model_placement, multi_model.keys()))
return phy_models_and_placements
def _build_logical(self, models: List[Model]) -> LogicalPlan:
logical_plan = LogicalPlan(plan_id=self.logical_plan_counter)
for model in models:
logical_plan.add_model(model)
self.logical_plan_counter += 1
return logical_plan
def register_graph_listener(self, listener: AbstractGraphListener) -> None:
self._listeners.append(listener)
def _send_trial_callback(self, parameter: dict) -> None:
for listener in self._listeners:
listener.on_resource_used(0) # FIXME: find the real resource id
def _request_trial_jobs_callback(self, num_trials: int) -> None:
for listener in self._listeners:
listener.on_resource_available([0] * num_trials) # FIXME: find the real resource id
def _trial_end_callback(self, trial_id: int, success: bool) -> None:
model = self._running_models[trial_id]
if success:
model.status = ModelStatus.Trained
else:
model.status = ModelStatus.Failed
for model_id in self._original_model_to_multi_model:
if self._original_model_to_multi_model[model_id] == model:
original_model = self._original_models[model_id]
if success:
original_model.status = ModelStatus.Trained
else:
original_model.status = ModelStatus.Failed
for listener in self._listeners:
listener.on_training_end(original_model, success)
def _intermediate_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
# model = self._running_models[trial_id]
merged_metrics = dict(metrics)
for model_id in merged_metrics:
int_model_id = int(model_id)
self._original_models[int_model_id].intermediate_metrics.append(merged_metrics[model_id])
# model.intermediate_metrics.append(metrics)
for listener in self._listeners:
listener.on_intermediate_metric(self._original_models[int_model_id], merged_metrics[model_id])
def _final_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
merged_metrics = dict(metrics)
for model_id in merged_metrics:
int_model_id = int(model_id)
self._original_models[int_model_id].intermediate_metrics.append(merged_metrics[model_id])
# model.intermediate_metrics.append(metrics)
for listener in self._listeners:
listener.on_metric(self._original_models[int_model_id], merged_metrics[model_id])
def query_available_resource(self) -> List[WorkerInfo]:
raise NotImplementedError # move the method from listener to here?
@classmethod
def trial_execute_graph(cls) -> None:
"""
Initialize the model, hand it over to trainer.
"""
graph_data = BaseGraphData.load(receive_trial_parameters())
_logger.info('CGO_ENGINE trial parameters received')
with open('_generated_model.py', 'w') as f:
f.write(graph_data.model_script)
# with open('_debug_graph_data.json', 'w') as f:
# json.dump(graph_data.dump(), f)
trainer_cls = utils.import_(graph_data.training_module)
model_cls = utils.import_(f"_generated_model.{graph_data.training_kwargs['model_cls']}")
trainer_instance = trainer_cls(model_cls(), graph_data.training_kwargs)
trainer_instance.fit()
class AssemblePolicy:
@staticmethod
def group(logical_plan):
group_model = {}
for idx, m in enumerate(logical_plan.models):
group_model[m] = PhysicalDevice('server', f'cuda:{idx}')
return [group_model]
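# A hedged sketch of how this engine is selected: `get_execution_engine()` in
# the execution API module above checks the CGO environment variable. The
# module path `.api` is an assumption about this package's layout.
def _example_enable_cgo():
    import os
    from .api import get_execution_engine  # assumed module name
    os.environ['CGO'] = 'true'  # must be set before the engine singleton is created
    return get_execution_engine()  # returns a CGOExecutionEngine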
from abc import ABC, abstractmethod
from typing import Any, NewType, List
from ..graph import Model, MetricData
__all__ = [
'GraphData', 'WorkerInfo',
'AbstractGraphListener', 'AbstractExecutionEngine'
]
GraphData = NewType('GraphData', Any)
"""
A _serializable_ internal data type defined by execution engine.
Execution engine will submit this kind of data through NNI to worker machine, and train it there.
A `GraphData` object describes a (merged) executable graph.
This is the trial's "hyper-parameter" in NNI's terms and will be transferred in JSON format.
See `AbstractExecutionEngine` for details.
"""
WorkerInfo = NewType('WorkerInfo', Any)
"""
To be designed. Discussion needed.
This describes the properties of a worker machine. (e.g. memory size)
"""
class AbstractGraphListener(ABC):
"""
Abstract listener interface to receive graph events.
Use `AbstractExecutionEngine.register_graph_listener()` to activate a listener.
"""
@abstractmethod
def on_metric(self, model: Model, metric: MetricData) -> None:
"""
Reports the final metric of a graph.
"""
raise NotImplementedError
@abstractmethod
def on_intermediate_metric(self, model: Model, metric: MetricData) -> None:
"""
Reports the latest intermediate metric of a training graph.
"""
pass
@abstractmethod
def on_training_end(self, model: Model, success: bool) -> None:
"""
Reports that a graph is either fully trained or its training process has failed.
"""
pass
@abstractmethod
def on_resource_available(self, resources: List[WorkerInfo]) -> None:
"""
Reports when a worker becomes idle.
"""
pass
class AbstractExecutionEngine(ABC):
"""
The abstract interface of execution engine.
Most of these APIs are used by the strategy, except `trial_execute_graph`, which is invoked by the framework in the trial.
Strategy will get the singleton execution engine object through a global API,
and use it in either a synchronous or an asynchronous manner.
Execution engine is responsible for submitting (maybe-optimized) models to NNI,
and assigning their metrics to the `Model` objects after training.
Execution engine is also responsible for launching the graph in the trial process,
because it is the only one who understands graph data, or "hyper-parameter" in NNI's terms.
Execution engine will leverage NNI Advisor APIs, which are still open for discussion.
In the synchronous use case, the strategy will have a loop that calls `submit_models` and `wait_models` repeatedly,
and will receive metrics from `Model` attributes.
Execution engine could assume that strategy will only submit graphs when there are available resources (for now).
In the asynchronous use case, the strategy will register a listener to receive events,
while still using `submit_models` to train.
There will be a `BaseExecutionEngine` subclass.
Inner-graph optimization is expected to derive from `BaseExecutionEngine`
and override `submit_models` and `trial_execute_graph`.
Cross-graph optimization is expected to derive from `AbstractExecutionEngine` directly,
because in this case APIs like `wait_graph` and `listener.on_training_end` will have unique logic.
There might be util functions that benefit all optimization methods,
but non-mandatory utils should not be covered by the abstract interface.
"""
@abstractmethod
def submit_models(self, *models: Model) -> None:
"""
Submit models to NNI.
This method is supposed to call something like `nni.Advisor.create_trial_job(graph_data)`.
"""
raise NotImplementedError
@abstractmethod
def query_available_resource(self) -> List[WorkerInfo]:
"""
Returns information of all idle workers.
If no details are available, this may return a list of "empty" objects, reporting the number of idle workers.
Could be left unimplemented for first iteration.
"""
raise NotImplementedError
@abstractmethod
def register_graph_listener(self, listener: AbstractGraphListener) -> None:
"""
Register a listener to receive graph events.
Could be left unimplemented for first iteration.
"""
raise NotImplementedError
@classmethod
@abstractmethod
def trial_execute_graph(cls) -> MetricData:
"""
Train the graph and return its metrics, in a separate trial process.
Each call to `nni.Advisor.create_trial_job(graph_data)` will eventually invoke this method.
Because this method is invoked in a trial process on the training platform,
it has a different context from the other methods and has no access to global variables or `self`.
However, util APIs like `.utils.experiment_config()` should still be available.
"""
raise NotImplementedError
from ..graph import Model, ModelStatus
from .interface import MetricData, AbstractGraphListener
class DefaultListener(AbstractGraphListener):
def __init__(self):
self.resources: int = 0 # simply resource count
def has_available_resource(self) -> bool:
return self.resources > 0
def on_metric(self, model: Model, metric: MetricData) -> None:
model.metric = metric
def on_intermediate_metric(self, model: Model, metric: MetricData) -> None:
model.intermediate_metrics.append(metric)
def on_training_end(self, model: Model, success: bool) -> None:
if success:
model.status = ModelStatus.Trained
else:
model.status = ModelStatus.Failed
def on_resource_available(self, resources: int) -> None:
self.resources += resources
def on_resource_used(self, resources: int) -> None:
self.resources -= resources
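# A hedged walkthrough of the resource counter above: the advisor's
# request-trial-jobs callback feeds `on_resource_available` and the send-trial
# callback feeds `on_resource_used` (see BaseExecutionEngine), so `resources`
# approximates the number of idle trial slots.
def _example_listener_bookkeeping() -> None:
    listener = DefaultListener()
    listener.on_resource_available(2)  # two trial slots granted
    listener.on_resource_used(1)       # one trial submitted
    assert listener.resources == 1 and listener.has_available_resource()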
from abc import ABC
from .logical_plan import LogicalPlan
class AbstractOptimizer(ABC):
def __init__(self) -> None:
pass
def convert(self, logical_plan: LogicalPlan) -> None:
raise NotImplementedError
import copy
from typing import Dict, Tuple, List, Any
from nni.retiarii.utils import uid
from ...graph import Cell, Edge, Graph, Model, Node
from ...operation import Operation, _IOPseudoOperation
class PhysicalDevice:
def __init__(self, server: str, device: str):
self.server = server
self.device = device
def __eq__(self, o) -> bool:
return self.server == o.server and self.device == o.device
def __hash__(self) -> int:
return hash(self.server + '_' + self.device)
class AbstractLogicalNode(Node):
def __init__(self, graph, node_id, name, operation, _internal=False):
super().__init__(graph, node_id, name, operation, _internal=_internal)
def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) -> Tuple[Node, PhysicalDevice]:
raise NotImplementedError
def _fork_to(self, graph: Graph):
raise NotImplementedError
class LogicalGraph(Graph):
def __init__(self, model: Model, graph_id: int, name: str = None, _internal: bool = False):
super().__init__(model, graph_id, name='logical_' + name, _internal=_internal)
def _dump(self) -> Any:
nodes_dump = {}
for node in self.hidden_nodes:
if isinstance(node, OriginNode):
nodes_dump[f"{node.original_graph.model.model_id}_{node.name}"] = node._dump(
)
else:
nodes_dump[f"{node.graph.model.model_id}_{node.name}"] = node._dump()
edges_dump = []
for edge in self.edges:
if isinstance(edge.head, OriginNode):
head_info = f'{edge.head.original_graph.model.model_id}_{edge.head.name}'
else:
head_info = edge.head.name
if isinstance(edge.tail, OriginNode):
tail_info = f'{edge.tail.original_graph.model.model_id}_{edge.tail.name}'
else:
tail_info = edge.tail.name
edges_dump.append((head_info, tail_info))
return {
'inputs': self.input_node.operation.io_names,
'outputs': self.output_node.operation.io_names,
'nodes': nodes_dump,
'edges': edges_dump
}
def _fork_to(self, model: Model) -> Graph:
new_graph = Graph(model, self.id, self.name,
_internal=True)._register()
for node in self.hidden_nodes:
if isinstance(node, AbstractLogicalNode):
node._fork_to(new_graph)
else:
Node(new_graph, node.id, node.name,
node.operation, _internal=True)._register()
id_to_new_node = {node.__repr__(): node for node in new_graph.nodes}
for edge in self.edges:
new_head = id_to_new_node[edge.head.__repr__()]
new_tail = id_to_new_node[edge.tail.__repr__()]
Edge((new_head, edge.head_slot),
(new_tail, edge.tail_slot), _internal=True)._register()
return new_graph
class OriginNode(AbstractLogicalNode):
def __init__(self, logical_graph: LogicalGraph,
original_graph: Graph, original_node: Node,
name: str, operation, _internal=False):
super().__init__(logical_graph, original_node.id, name, operation)
self.original_graph = original_graph
self.original_node = original_node
def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) -> Tuple[Node, PhysicalDevice]:
model_id = self.original_node.graph.model.model_id
new_node = Node(self.original_node.graph, self.original_node.id,
f"M_{model_id}_" +
self.original_node.name,
self.original_node.operation)
return new_node, multi_model_placement[self.original_node.graph.model]
def __repr__(self):
return f'OriginNode(id={self.id}, name={self.name}, operation={self.operation}, origin_model_id={self.original_graph.model.model_id})'
def _fork_to(self, graph: Graph):
OriginNode(graph, self.original_graph, self.original_node,
self.name, self.operation)._register()
class LogicalPlan:
def __init__(self, plan_id=0) -> None:
self.lp_model = Model(_internal=True)
self.id = plan_id
self.logical_graph = LogicalGraph(
self.lp_model, self.id, name=f'{self.id}', _internal=True)._register()
self.lp_model._root_graph_name = self.logical_graph.name
self.models = []
def add_model(self, model: Model):
self.models.append(model)
# Only optimize the root graph.
self._merge_graph(model.root_graph)
def _merge_graph(self, from_graph):
to_graph = self.logical_graph
id_to_new_node = {} # old node ID -> new node object
for old_node in from_graph.nodes:
new_node = OriginNode(to_graph, old_node.graph,
old_node, old_node.name,
old_node.operation, _internal=True)._register()
id_to_new_node[old_node.id] = new_node
for edge in from_graph.edges:
new_head = id_to_new_node[edge.head.id]
new_tail = id_to_new_node[edge.tail.id]
Edge((new_head, edge.head_slot), (new_tail,
edge.tail_slot), _internal=True)._register()
def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) -> Tuple[Model, Dict[Node, PhysicalDevice]]:
phy_model = Model(_internal=True) # self.lp_model.fork()
phy_graph = self.lp_model.root_graph._fork_to(phy_model)
# Add a flag to mark multi-model in graph json.
# Multi-model has a list of training configs in kwargs['model_kwargs']
if len(multi_model_placement) > 1:
phy_model.training_config.kwargs['is_multi_model'] = True
phy_model.training_config.kwargs['model_cls'] = phy_graph.name
phy_model.training_config.kwargs['model_kwargs'] = []
# FIXME: allow user to specify
phy_model.training_config.module = 'nni.retiarii.trainer.PyTorchMultiModelTrainer'
# merge sub-graphs
for model in multi_model_placement:
for graph_name in model.graphs:
if graph_name != model._root_graph_name:
model.graphs[graph_name]._fork_to(
phy_model, name_prefix=f'M_{model.model_id}_')
# When replace logical nodes, merge the training configs when
# input/output nodes are replaced.
training_config_slot = {} # Model ID -> Slot ID
input_slot_mapping = {}
output_slot_mapping = {}
# Replace all logical nodes to executable physical nodes
hidden_nodes = phy_graph.hidden_nodes.copy()
node_placements = {}
for node in hidden_nodes:
if isinstance(node, OriginNode):
model_id = node.original_graph.model.model_id
if node.original_graph.model not in multi_model_placement:
for edge in node.incoming_edges:
edge.remove()
for edge in node.outgoing_edges:
edge.remove()
node.remove()
continue
if isinstance(node, AbstractLogicalNode):
new_node, placement = node.assemble(multi_model_placement)
if isinstance(new_node.operation, _IOPseudoOperation):
model_id = new_node.graph.model.model_id
if model_id not in training_config_slot:
phy_model.training_config.kwargs['model_kwargs'].append(new_node.graph.model.training_config.kwargs.copy())
training_config_slot[model_id] = len(phy_model.training_config.kwargs['model_kwargs']) - 1
slot = training_config_slot[model_id]
phy_model.training_config.kwargs['model_kwargs'][slot]['model_id'] = model_id
phy_model.training_config.kwargs['model_kwargs'][slot]['use_input'] = False
phy_model.training_config.kwargs['model_kwargs'][slot]['use_output'] = False
else:
slot = training_config_slot[model_id]
# If a model's inputs/outputs are not used in the multi-model
# the codegen and trainer should not generate and use them
# "use_input" and "use_output" are used to mark whether
# an input/output of a model is used in a multi-model
if new_node.operation.type == '_inputs':
input_slot_mapping[new_node] = slot
phy_model.training_config.kwargs['model_kwargs'][slot]['use_input'] = True
if new_node.operation.type == '_outputs':
output_slot_mapping[new_node] = slot
phy_model.training_config.kwargs['model_kwargs'][slot]['use_output'] = True
self.node_replace(node, new_node)
if isinstance(new_node.operation, Cell):
old_cell_name = new_node.operation.cell_name
new_node.operation = copy.deepcopy(new_node.operation)
new_node.operation.cell_name = f'M_{model_id}_{old_cell_name}'
node_placements[new_node] = placement
node.remove()
# If two nodes are placed on different devices, use ToDevice op to copy the node
existing_edges = phy_graph.edges.copy()
# Avoid copying a node multiple times onto the same device
copied_op: Dict[Tuple[Node, PhysicalDevice], Node] = {}
for edge in existing_edges:
head_placement = node_placements[edge.head]
tail_placement = node_placements[edge.tail]
if head_placement != tail_placement:
if head_placement.server != tail_placement.server:
raise ValueError('Cross-server placement is not supported.')
# Same server different devices
if (edge.head, tail_placement) in copied_op:
to_node = copied_op[(edge.head, tail_placement)]
else:
to_operation = Operation.new('ToDevice', {"device": tail_placement.device})
to_node = Node(phy_graph, uid(), edge.head.name + "_to_" + edge.tail.name, to_operation)._register()
Edge((edge.head, edge.head_slot), (to_node, None), _internal=True)._register()
copied_op[(edge.head, tail_placement)] = to_node
edge.head = to_node
edge.head_slot = None
# merge all input nodes into one with multiple slots
input_nodes = []
for node in phy_graph.hidden_nodes:
if isinstance(node.operation, _IOPseudoOperation) and node.operation.type == '_inputs':
input_nodes.append(node)
for edge in phy_graph.edges:
if edge.head in input_nodes:
edge.head_slot = input_slot_mapping[edge.head]
edge.head = phy_graph.input_node
# merge all output nodes into one with multiple slots
output_nodes = []
for node in phy_graph.hidden_nodes:
if isinstance(node.operation, _IOPseudoOperation) and node.operation.type == '_outputs':
output_nodes.append(node)
for edge in phy_graph.edges:
if edge.tail in output_nodes:
edge.tail_slot = output_slot_mapping[edge.tail]
edge.tail = phy_graph.output_node
for node in input_nodes:
node.remove()
for node in output_nodes:
node.remove()
return phy_model, node_placements
def node_replace(self, old_node: Node, new_node: Node, input_slot_mapping=None, output_slot_mapping=None):
# TODO: currently, only support single input slot and output slot.
if input_slot_mapping is not None or output_slot_mapping is not None:
raise ValueError('Slot mapping is not supported')
phy_graph = old_node.graph
new_node.graph = phy_graph
new_node._register()
for edge in phy_graph.edges:
if edge.head == old_node:
edge.head = new_node
elif edge.tail == old_node:
edge.tail = new_node
# after the replacement, there might be multiple duplicated edges
# with the same input and output nodes, which should be de-duplicated
self._remove_duplicated_edges()
def _remove_duplicated_edges(self):
# TODO: there are no duplicated edges when only dedup-input is supported.
# Duplicated edges appear when a chain of prefix nodes is deduplicated.
pass
from typing import List, Dict, Tuple
from nni.retiarii.utils import uid
from ...graph import Graph, Model, Node
from .interface import AbstractOptimizer
from .logical_plan import (AbstractLogicalNode, LogicalGraph, LogicalPlan,
OriginNode, PhysicalDevice)
_supported_training_modules = ['nni.retiarii.trainer.PyTorchImageClassificationTrainer']
class DedupInputNode(AbstractLogicalNode):
def __init__(self, logical_graph: LogicalGraph, node_id: int,
nodes_to_dedup: List[Node], _internal=False):
super().__init__(logical_graph, node_id,
"Dedup_"+nodes_to_dedup[0].name,
nodes_to_dedup[0].operation)
self.origin_nodes: List[OriginNode] = nodes_to_dedup.copy()
def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) -> Tuple[Node, PhysicalDevice]:
for node in self.origin_nodes:
if node.original_graph.model in multi_model_placement:
new_node = Node(node.original_graph, node.id,
f'M_{node.original_graph.model.model_id}_{node.name}',
node.operation)
return new_node, multi_model_placement[node.original_graph.model]
raise ValueError(f'DedupInputNode {self.name} does not contain nodes from multi_model')
def _fork_to(self, graph: Graph):
DedupInputNode(graph, self.id, self.origin_nodes)._register()
def __repr__(self) -> str:
return f'DedupNode(id={self.id}, name={self.name}, len(nodes_to_dedup)={len(self.origin_nodes)})'
class DedupInputOptimizer(AbstractOptimizer):
def __init__(self) -> None:
pass
def _check_deduplicate_by_node(self, root_node, node_to_check):
if root_node == node_to_check:
return True
if root_node.operation.type == '_inputs' and \
node_to_check.operation.type == '_inputs' and \
isinstance(root_node, OriginNode) and \
isinstance(node_to_check, OriginNode):
if root_node.original_graph.model.training_config.module not in _supported_training_modules:
return False
if root_node.original_graph.model.training_config == node_to_check.original_graph.model.training_config:
return True
else:
return False
else:
return False
def convert(self, logical_plan: LogicalPlan) -> None:
nodes_to_skip = set()
while True: # repeat until the logical_graph converges
input_nodes = logical_plan.logical_graph.get_nodes_by_type("_inputs")
# _PseudoOperation(type_name="_inputs"))
root_node = None
for node in input_nodes:
if node in nodes_to_skip:
continue
root_node = node
break
if root_node is None:
break # end of convert
else:
nodes_to_dedup = []
for node in input_nodes:
if node in nodes_to_skip:
continue
if self._check_deduplicate_by_node(root_node, node):
nodes_to_dedup.append(node)
assert len(nodes_to_dedup) >= 1
if len(nodes_to_dedup) == 1:
assert nodes_to_dedup[0] == root_node
nodes_to_skip.add(root_node)
else:
dedup_node = DedupInputNode(logical_plan.logical_graph, uid(), nodes_to_dedup)._register()
for edge in logical_plan.logical_graph.edges:
if edge.head in nodes_to_dedup:
edge.head = dedup_node
if edge.tail in nodes_to_dedup:
edge.tail = dedup_node
for node in nodes_to_dedup:
node.remove()
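# A hedged sketch of how this optimizer is driven: CGOExecutionEngine (above)
# builds a LogicalPlan from the submitted models and lets every registered
# optimizer rewrite it in place before physical models are assembled.
def _example_dedup(models: List[Model]) -> LogicalPlan:
    plan = LogicalPlan(plan_id=0)
    for model in models:
        plan.add_model(model)            # merge each root graph into the logical graph
    DedupInputOptimizer().convert(plan)  # merge duplicated `_inputs` nodes in place
    return plan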
import logging
import time
from dataclasses import dataclass
from pathlib import Path
from subprocess import Popen
from threading import Thread
from typing import Any, Optional
from ..experiment import Experiment, TrainingServiceConfig, launcher, rest
from ..experiment.config.base import ConfigBase, PathLike
from ..experiment.config import util
from ..experiment.pipe import Pipe
from .graph import Model
from .utils import get_records
from .integration import RetiariiAdvisor
from .converter import convert_to_graph
from .mutator import Mutator, LayerChoiceMutator, InputChoiceMutator
from .trainer.interface import BaseTrainer
from .strategies.strategy import BaseStrategy
_logger = logging.getLogger(__name__)
@dataclass(init=False)
class RetiariiExeConfig(ConfigBase):
experiment_name: Optional[str] = None
search_space: Any = '' # TODO: remove
trial_command: str = 'python3 -m nni.retiarii.trial_entry'
trial_code_directory: PathLike = '.'
trial_concurrency: int
trial_gpu_number: int = 0
max_experiment_duration: Optional[str] = None
max_trial_number: Optional[int] = None
nni_manager_ip: Optional[str] = None
debug: bool = False
log_level: Optional[str] = None
experiment_working_directory: Optional[PathLike] = None
# remove configuration of tuner/assessor/advisor
training_service: TrainingServiceConfig
def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if training_service_platform is not None:
assert 'training_service' not in kwargs
self.training_service = util.training_service_config_factory(training_service_platform)
def validate(self, initialized_tuner: bool = False) -> None:
super().validate()
@property
def _canonical_rules(self):
return _canonical_rules
@property
def _validation_rules(self):
return _validation_rules
_canonical_rules = {
'trial_code_directory': util.canonical_path,
'max_experiment_duration': lambda value: f'{util.parse_time(value)}s' if value is not None else None,
'experiment_working_directory': util.canonical_path
}
_validation_rules = {
'trial_code_directory': lambda value: (Path(value).is_dir(), f'"{value}" does not exist or is not a directory'),
'trial_concurrency': lambda value: value > 0,
'trial_gpu_number': lambda value: value >= 0,
'max_experiment_duration': lambda value: util.parse_time(value) > 0,
'max_trial_number': lambda value: value > 0,
'log_level': lambda value: value in ["trace", "debug", "info", "warning", "error", "fatal"],
'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
}
class RetiariiExperiment(Experiment):
def __init__(self, base_model: Model, trainer: BaseTrainer,
applied_mutators: Mutator, strategy: BaseStrategy):
self.config: RetiariiExeConfig = None
self.port: Optional[int] = None
self.base_model = base_model
self.trainer = trainer
self.applied_mutators = applied_mutators
self.strategy = strategy
self.recorded_module_args = get_records()
self._dispatcher = RetiariiAdvisor()
self._proc: Optional[Popen] = None
self._pipe: Optional[Pipe] = None
def _process_inline_mutation(self, base_model):
"""
The generated mutators are order-independent.
"""
lc_nodes = base_model.get_nodes_by_type('__torch__.nni.retiarii.nn.pytorch.nn.LayerChoice')
ic_nodes = base_model.get_nodes_by_type('__torch__.nni.retiarii.nn.pytorch.nn.InputChoice')
if not lc_nodes and not ic_nodes:
return None
applied_mutators = []
for node in lc_nodes:
mutator = LayerChoiceMutator(node.name, node.operation.parameters['choices'])
applied_mutators.append(mutator)
for node in ic_nodes:
mutator = InputChoiceMutator(node.name, node.operation.parameters['n_chosen'])
applied_mutators.append(mutator)
return applied_mutators
def _start_strategy(self):
import torch
try:
script_module = torch.jit.script(self.base_model)
except Exception as e:
_logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
raise e
base_model = convert_to_graph(script_module, self.base_model, self.recorded_module_args)
assert id(self.trainer) in self.recorded_module_args
trainer_config = self.recorded_module_args[id(self.trainer)]
base_model.apply_trainer(trainer_config['modulename'], trainer_config['args'])
# handle inline mutations
mutators = self._process_inline_mutation(base_model)
if mutators is not None and self.applied_mutators:
raise RuntimeError('Mixed usage of LayerChoice/InputChoice and mutators is not supported yet; '
'do not use mutators when you use LayerChoice/InputChoice')
if mutators is not None:
self.applied_mutators = mutators
_logger.info('Starting strategy...')
Thread(target=self.strategy.run, args=(base_model, self.applied_mutators)).start()
_logger.info('Strategy started!')
def start(self, config: RetiariiExeConfig, port: int = 8080, debug: bool = False) -> None:
"""
Start the experiment in background.
This method will raise exception on failure.
If it returns, the experiment should have been successfully started.
Parameters
----------
config
The experiment configuration.
port
The port of the web UI.
debug
Whether to start in debug mode.
"""
# FIXME:
if debug:
logging.getLogger('nni').setLevel(logging.DEBUG)
self._proc, self._pipe = launcher.start_experiment(config, port, debug)
assert self._proc is not None
assert self._pipe is not None
self.port = port # port will be None if start up failed
# dispatcher must be created after pipe initialized
# the logic to launch dispatcher in background should be refactored into dispatcher api
Thread(target=self._dispatcher.run).start()
self._start_strategy()
# TODO: register experiment management metadata
def stop(self) -> None:
"""
Stop background experiment.
"""
self._proc.kill()
self._pipe.close()
self.port = None
self._proc = None
self._pipe = None
def run(self, config: RetiariiExeConfig, port: int = 8080, debug: bool = False) -> str:
"""
Run the experiment.
This function blocks until the experiment finishes or errors out.
"""
self.config = config
self.start(config, port, debug)
try:
while True:
time.sleep(10)
status = self.get_status()
# TODO: double check the status
if status in ['ERROR', 'STOPPED', 'NO_MORE_TRIAL']:
return status
finally:
self.stop()
def get_status(self) -> str:
if self.port is None:
raise RuntimeError('Experiment is not running')
resp = rest.get(self.port, '/check-status')
return resp['status']
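# A hedged end-to-end sketch: `base_model`, `trainer`, `mutators` and
# `strategy` stand for user-defined components, and 'local' assumes a local
# training service is registered with the config factory.
def _example_run(base_model: Model, trainer: BaseTrainer,
                 mutators, strategy: BaseStrategy) -> str:
    config = RetiariiExeConfig(training_service_platform='local')
    config.trial_concurrency = 2
    exp = RetiariiExperiment(base_model, trainer, mutators, strategy)
    return exp.run(config, port=8080)  # blocks until the experiment finishes or errors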
import logging
from typing import Any, Callable
import json_tricks
import nni
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.runtime.protocol import CommandType, send
from nni.utils import MetricType
from .graph import MetricData
_logger = logging.getLogger(__name__)
class RetiariiAdvisor(MsgDispatcherBase):
"""
This class connects Retiarii components to the NNI backend.
It functions as the main thread when running a Retiarii experiment through NNI.
Strategy is launched as a separate thread, which calls APIs in the execution engine. The execution
engine then finds the advisor singleton and sends payloads to the advisor.
When metrics are sent back, the advisor receives the payloads first and then invokes the callback
functions (member functions of the graph listener).
The conversion the advisor provides is minimal: it is only a send/receive module, and the execution
engine needs to handle all the rest.
FIXME
How does the advisor exit when the strategy exits?
Attributes
----------
send_trial_callback
request_trial_jobs_callback
trial_end_callback
intermediate_metric_callback
final_metric_callback
"""
def __init__(self):
super(RetiariiAdvisor, self).__init__()
register_advisor(self) # register the current advisor as the "global only" advisor
self.search_space = None
self.send_trial_callback: Callable[[dict], None] = None
self.request_trial_jobs_callback: Callable[[int], None] = None
self.trial_end_callback: Callable[[int, bool], None] = None
self.intermediate_metric_callback: Callable[[int, MetricData], None] = None
self.final_metric_callback: Callable[[int, MetricData], None] = None
self.parameters_count = 0
def handle_initialize(self, data):
"""callback for initializing the advisor
Parameters
----------
data: dict
search space
"""
self.handle_update_search_space(data)
send(CommandType.Initialized, '')
def send_trial(self, parameters):
"""
Send parameters to NNI.
Parameters
----------
parameters : Any
Any payload.
Returns
-------
int
Parameter ID that is assigned to this parameter,
which will be used for identification in future.
"""
self.parameters_count += 1
new_trial = {
'parameter_id': self.parameters_count,
'parameters': parameters,
'parameter_source': 'algorithm'
}
_logger.info('New trial sent: %s', new_trial)
send(CommandType.NewTrialJob, json_tricks.dumps(new_trial))
if self.send_trial_callback is not None:
self.send_trial_callback(parameters) # pylint: disable=not-callable
return self.parameters_count
def handle_request_trial_jobs(self, num_trials):
_logger.info('Request trial jobs: %s', num_trials)
if self.request_trial_jobs_callback is not None:
self.request_trial_jobs_callback(num_trials) # pylint: disable=not-callable
def handle_update_search_space(self, data):
_logger.info('Received search space: %s', data)
self.search_space = data
def handle_trial_end(self, data):
_logger.info('Trial end: %s', data)
self.trial_end_callback(json_tricks.loads(data['hyper_params'])['parameter_id'], # pylint: disable=not-callable
data['event'] == 'SUCCEEDED')
def handle_report_metric_data(self, data):
_logger.info('Metric reported: %s', data)
if data['type'] == MetricType.REQUEST_PARAMETER:
raise ValueError('Request parameter not supported')
elif data['type'] == MetricType.PERIODICAL:
self.intermediate_metric_callback(data['parameter_id'], # pylint: disable=not-callable
self._process_value(data['value']))
elif data['type'] == MetricType.FINAL:
self.final_metric_callback(data['parameter_id'], # pylint: disable=not-callable
self._process_value(data['value']))
@staticmethod
def _process_value(value) -> Any: # hopefully a float
value = json_tricks.loads(value)
if isinstance(value, dict):
if 'default' in value:
return value['default']
else:
return value
return value
_advisor: RetiariiAdvisor = None
def get_advisor() -> RetiariiAdvisor:
global _advisor
assert _advisor is not None
return _advisor
def register_advisor(advisor: RetiariiAdvisor):
global _advisor
assert _advisor is None
_advisor = advisor
def send_trial(parameters: dict) -> int:
"""
Send a new trial. Executed on the tuner side.
Returns an ID that uniquely identifies this trial.
"""
return get_advisor().send_trial(parameters)
def receive_trial_parameters() -> dict:
"""
Receive parameters of a new trial. Executed on the trial side.
"""
params = nni.get_next_parameter()
return params
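# A hedged sketch of the send/receive pairing: the execution engine calls
# `send_trial` on the tuner side and keys its running models by the returned
# parameter ID; the trial process later obtains the same payload through
# `receive_trial_parameters`.
def _example_roundtrip(graph_data: dict) -> int:
    trial_id = send_trial(graph_data)  # tuner side
    # In the separate trial process:
    #     params = receive_trial_parameters()  # == graph_data
    return trial_id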
from typing import (Any, Iterable, List, Optional, Tuple)
from .graph import Model
__all__ = ['Sampler', 'Mutator']
Choice = Any
class Sampler:
"""
Handles `Mutator.choice()` calls.
"""
def choice(self, candidates: List[Choice], mutator: 'Mutator', model: Model, index: int) -> Choice:
raise NotImplementedError()
def mutation_start(self, mutator: 'Mutator', model: Model) -> None:
pass
def mutation_end(self, mutator: 'Mutator', model: Model) -> None:
pass
class Mutator:
"""
Mutates graphs in model to generate new model.
`Mutator` class will be used in two places:
1. Inherit `Mutator` to implement graph mutation logic.
2. Use `Mutator` subclass to implement NAS strategy.
In scenario 1, the subclass should implement `Mutator.mutate()` interface with `Mutator.choice()`.
In scenario 2, strategy should use constructor or `Mutator.bind_sampler()` to initialize subclass,
and then use `Mutator.apply()` to mutate model.
For certain mutator subclasses, strategy or sampler can use `Mutator.dry_run()` to predict choice candidates.
# Method names are open for discussion.
"""
def __init__(self, sampler: Optional[Sampler] = None):
self.sampler: Optional[Sampler] = sampler
self._cur_model: Optional[Model] = None
self._cur_choice_idx: Optional[int] = None
def bind_sampler(self, sampler: Sampler) -> 'Mutator':
"""
Set the sampler which will handle `Mutator.choice` calls.
"""
self.sampler = sampler
return self
def apply(self, model: Model) -> Model:
"""
Apply this mutator on a model.
Returns mutated model.
The model will be copied before mutation and the original model will not be modified.
"""
assert self.sampler is not None
copy = model.fork()
self._cur_model = copy
self._cur_choice_idx = 0
self.sampler.mutation_start(self, copy)
self.mutate(copy)
self.sampler.mutation_end(self, copy)
self._cur_model = None
self._cur_choice_idx = None
return copy
def dry_run(self, model: Model) -> Tuple[List[List[Choice]], Model]:
"""
Dry run mutator on a model to collect choice candidates.
If you invoke this method multiple times on the same or different models,
it may or may not return identical results, depending on how the subclass implements `Mutator.mutate()`.
"""
sampler_backup = self.sampler
recorder = _RecorderSampler()
self.sampler = recorder
new_model = self.apply(model)
self.sampler = sampler_backup
return recorder.recorded_candidates, new_model
def mutate(self, model: Model) -> None:
"""
Abstract method to be implemented by subclass.
Mutate a model in place.
"""
raise NotImplementedError()
def choice(self, candidates: Iterable[Choice]) -> Choice:
"""
Ask sampler to make a choice.
"""
assert self.sampler is not None and self._cur_model is not None and self._cur_choice_idx is not None
ret = self.sampler.choice(list(candidates), self, self._cur_model, self._cur_choice_idx)
self._cur_choice_idx += 1
return ret
class _RecorderSampler(Sampler):
def __init__(self):
self.recorded_candidates: List[List[Choice]] = []
def choice(self, candidates: List[Choice], *args) -> Choice:
self.recorded_candidates.append(candidates)
return candidates[0]
# the following is for inline mutation
class LayerChoiceMutator(Mutator):
def __init__(self, node_name: str, candidates: List):
super().__init__()
self.node_name = node_name
self.candidates = candidates
def mutate(self, model):
target = model.get_node_by_name(self.node_name)
indexes = list(range(len(self.candidates)))
chosen_index = self.choice(indexes)
chosen_cand = self.candidates[chosen_index]
target.update_operation(chosen_cand['type'], chosen_cand['parameters'])
class InputChoiceMutator(Mutator):
def __init__(self, node_name: str, n_chosen: int):
super().__init__()
self.node_name = node_name
self.n_chosen = n_chosen
def mutate(self, model):
target = model.get_node_by_name(self.node_name)
candidates = list(range(self.n_chosen))
chosen = self.choice(candidates)
target.update_operation('__torch__.nni.retiarii.nn.pytorch.nn.ChosenInputs',
{'chosen': chosen})
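# A hedged sketch of scenario 1 from the `Mutator` docstring: a subclass
# implements `mutate` using `choice`. The node name and kernel sizes are
# illustrative assumptions.
class _ExampleKernelSizeMutator(Mutator):
    def __init__(self, node_name: str):
        super().__init__()
        self.node_name = node_name

    def mutate(self, model):
        target = model.get_node_by_name(self.node_name)
        kernel_size = self.choice([3, 5, 7])  # recorded by dry_run, sampled by a strategy
        target.update_operation(target.operation.type,
                                {**target.operation.parameters, 'kernel_size': kernel_size})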
import inspect
import logging
from typing import Any, List
import torch
import torch.nn as nn
from ...utils import add_record
_logger = logging.getLogger(__name__)
__all__ = [
'LayerChoice', 'InputChoice', 'Placeholder',
'Module', 'Sequential', 'ModuleList', # TODO: 'ModuleDict', 'ParameterList', 'ParameterDict',
'Identity', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d',
'ConvTranspose2d', 'ConvTranspose3d', 'Threshold', 'ReLU', 'Hardtanh', 'ReLU6',
'Sigmoid', 'Tanh', 'Softmax', 'Softmax2d', 'LogSoftmax', 'ELU', 'SELU', 'CELU', 'GLU', 'GELU', 'Hardshrink',
'LeakyReLU', 'LogSigmoid', 'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Softmin',
'Tanhshrink', 'RReLU', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d',
'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d', "FractionalMaxPool3d",
'LPPool1d', 'LPPool2d', 'LocalResponseNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'InstanceNorm1d',
'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm', 'SyncBatchNorm',
'Dropout', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout',
'ReflectionPad1d', 'ReflectionPad2d', 'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d',
'CrossMapLRN2d', 'Embedding', 'EmbeddingBag', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell',
'LSTMCell', 'GRUCell', 'PixelShuffle', 'Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d',
'PairwiseDistance', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d',
'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d', 'TripletMarginLoss', 'ZeroPad2d', 'ConstantPad1d', 'ConstantPad2d',
'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder',
'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
#'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d',
#'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
#'Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss', 'ChannelShuffle',
'Flatten', 'Hardsigmoid', 'Hardswish'
]
class LayerChoice(nn.Module):
def __init__(self, op_candidates, reduction=None, return_mask=False, key=None):
super(LayerChoice, self).__init__()
self.candidate_ops = op_candidates
self.label = key
if reduction or return_mask:
_logger.warning('input arguments `reduction` and `return_mask` are deprecated!')
def forward(self, x):
return x
class InputChoice(nn.Module):
def __init__(self, n_candidates=None, choose_from=None, n_chosen=1,
reduction="sum", return_mask=False, key=None):
super(InputChoice, self).__init__()
self.n_chosen = n_chosen
self.reduction = reduction
self.label = key
if n_candidates or choose_from or return_mask:
_logger.warning('input arguments `n_candidates`, `choose_from` and `return_mask` are deprecated!')
def forward(self, candidate_inputs: List[torch.Tensor]) -> torch.Tensor:
# fake return
return torch.tensor(candidate_inputs) # pylint: disable=not-callable
class ValueChoice:
"""
An instance of this class can only be used as an input argument
when instantiating a PyTorch module.
TODO: can also be used in training approach
"""
def __init__(self, candidate_values: List[Any]):
self.candidate_values = candidate_values
class Placeholder(nn.Module):
def __init__(self, label, related_info):
add_record(id(self), related_info)
self.label = label
self.related_info = related_info
super(Placeholder, self).__init__()
def forward(self, x):
return x
class ChosenInputs(nn.Module):
def __init__(self, chosen: int):
super().__init__()
self.chosen = chosen
def forward(self, candidate_inputs):
# TODO: support multiple chosen inputs
return candidate_inputs[self.chosen]
# the following are pytorch modules
class Module(nn.Module):
def __init__(self):
super(Module, self).__init__()
class Sequential(nn.Sequential):
def __init__(self, *args):
add_record(id(self), {})
super(Sequential, self).__init__(*args)
class ModuleList(nn.ModuleList):
def __init__(self, *args):
add_record(id(self), {})
super(ModuleList, self).__init__(*args)
def wrap_module(original_class):
orig_init = original_class.__init__
argname_list = list(inspect.signature(original_class).parameters.keys())
# Keep a reference to the original __init__ so we can call it without recursion
def __init__(self, *args, **kws):
full_args = {}
full_args.update(kws)
for i, arg in enumerate(args):
full_args[argname_list[i]] = arg
add_record(id(self), full_args)
orig_init(self, *args, **kws) # Call the original __init__
original_class.__init__ = __init__ # Set the class' __init__ to the new one
return original_class
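# A hedged illustration of the record produced by a wrapped constructor:
# positional arguments are mapped to keyword form before being stored via
# `add_record`, keyed by `id(self)`. `get_records` (used by RetiariiExperiment)
# is assumed to return that record table.
def _example_recorded_args() -> dict:
    from ...utils import get_records  # assumed counterpart of add_record
    layer = Linear(4, 2)  # records {'in_features': 4, 'out_features': 2}
    return get_records()[id(layer)]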
# TODO: support different versions of pytorch
Identity = wrap_module(nn.Identity)
Linear = wrap_module(nn.Linear)
Conv1d = wrap_module(nn.Conv1d)
Conv2d = wrap_module(nn.Conv2d)
Conv3d = wrap_module(nn.Conv3d)
ConvTranspose1d = wrap_module(nn.ConvTranspose1d)
ConvTranspose2d = wrap_module(nn.ConvTranspose2d)
ConvTranspose3d = wrap_module(nn.ConvTranspose3d)
Threshold = wrap_module(nn.Threshold)
ReLU = wrap_module(nn.ReLU)
Hardtanh = wrap_module(nn.Hardtanh)
ReLU6 = wrap_module(nn.ReLU6)
Sigmoid = wrap_module(nn.Sigmoid)
Tanh = wrap_module(nn.Tanh)
Softmax = wrap_module(nn.Softmax)
Softmax2d = wrap_module(nn.Softmax2d)
LogSoftmax = wrap_module(nn.LogSoftmax)
ELU = wrap_module(nn.ELU)
SELU = wrap_module(nn.SELU)
CELU = wrap_module(nn.CELU)
GLU = wrap_module(nn.GLU)
GELU = wrap_module(nn.GELU)
Hardshrink = wrap_module(nn.Hardshrink)
LeakyReLU = wrap_module(nn.LeakyReLU)
LogSigmoid = wrap_module(nn.LogSigmoid)
Softplus = wrap_module(nn.Softplus)
Softshrink = wrap_module(nn.Softshrink)
MultiheadAttention = wrap_module(nn.MultiheadAttention)
PReLU = wrap_module(nn.PReLU)
Softsign = wrap_module(nn.Softsign)
Softmin = wrap_module(nn.Softmin)
Tanhshrink = wrap_module(nn.Tanhshrink)
RReLU = wrap_module(nn.RReLU)
AvgPool1d = wrap_module(nn.AvgPool1d)
AvgPool2d = wrap_module(nn.AvgPool2d)
AvgPool3d = wrap_module(nn.AvgPool3d)
MaxPool1d = wrap_module(nn.MaxPool1d)
MaxPool2d = wrap_module(nn.MaxPool2d)
MaxPool3d = wrap_module(nn.MaxPool3d)
MaxUnpool1d = wrap_module(nn.MaxUnpool1d)
MaxUnpool2d = wrap_module(nn.MaxUnpool2d)
MaxUnpool3d = wrap_module(nn.MaxUnpool3d)
FractionalMaxPool2d = wrap_module(nn.FractionalMaxPool2d)
FractionalMaxPool3d = wrap_module(nn.FractionalMaxPool3d)
LPPool1d = wrap_module(nn.LPPool1d)
LPPool2d = wrap_module(nn.LPPool2d)
LocalResponseNorm = wrap_module(nn.LocalResponseNorm)
BatchNorm1d = wrap_module(nn.BatchNorm1d)
BatchNorm2d = wrap_module(nn.BatchNorm2d)
BatchNorm3d = wrap_module(nn.BatchNorm3d)
InstanceNorm1d = wrap_module(nn.InstanceNorm1d)
InstanceNorm2d = wrap_module(nn.InstanceNorm2d)
InstanceNorm3d = wrap_module(nn.InstanceNorm3d)
LayerNorm = wrap_module(nn.LayerNorm)
GroupNorm = wrap_module(nn.GroupNorm)
SyncBatchNorm = wrap_module(nn.SyncBatchNorm)
Dropout = wrap_module(nn.Dropout)
Dropout2d = wrap_module(nn.Dropout2d)
Dropout3d = wrap_module(nn.Dropout3d)
AlphaDropout = wrap_module(nn.AlphaDropout)
FeatureAlphaDropout = wrap_module(nn.FeatureAlphaDropout)
ReflectionPad1d = wrap_module(nn.ReflectionPad1d)
ReflectionPad2d = wrap_module(nn.ReflectionPad2d)
ReplicationPad2d = wrap_module(nn.ReplicationPad2d)
ReplicationPad1d = wrap_module(nn.ReplicationPad1d)
ReplicationPad3d = wrap_module(nn.ReplicationPad3d)
CrossMapLRN2d = wrap_module(nn.CrossMapLRN2d)
Embedding = wrap_module(nn.Embedding)
EmbeddingBag = wrap_module(nn.EmbeddingBag)
RNNBase = wrap_module(nn.RNNBase)
RNN = wrap_module(nn.RNN)
LSTM = wrap_module(nn.LSTM)
GRU = wrap_module(nn.GRU)
RNNCellBase = wrap_module(nn.RNNCellBase)
RNNCell = wrap_module(nn.RNNCell)
LSTMCell = wrap_module(nn.LSTMCell)
GRUCell = wrap_module(nn.GRUCell)
PixelShuffle = wrap_module(nn.PixelShuffle)
Upsample = wrap_module(nn.Upsample)
UpsamplingNearest2d = wrap_module(nn.UpsamplingNearest2d)
UpsamplingBilinear2d = wrap_module(nn.UpsamplingBilinear2d)
PairwiseDistance = wrap_module(nn.PairwiseDistance)
AdaptiveMaxPool1d = wrap_module(nn.AdaptiveMaxPool1d)
AdaptiveMaxPool2d = wrap_module(nn.AdaptiveMaxPool2d)
AdaptiveMaxPool3d = wrap_module(nn.AdaptiveMaxPool3d)
AdaptiveAvgPool1d = wrap_module(nn.AdaptiveAvgPool1d)
AdaptiveAvgPool2d = wrap_module(nn.AdaptiveAvgPool2d)
AdaptiveAvgPool3d = wrap_module(nn.AdaptiveAvgPool3d)
TripletMarginLoss = wrap_module(nn.TripletMarginLoss)
ZeroPad2d = wrap_module(nn.ZeroPad2d)
ConstantPad1d = wrap_module(nn.ConstantPad1d)
ConstantPad2d = wrap_module(nn.ConstantPad2d)
ConstantPad3d = wrap_module(nn.ConstantPad3d)
Bilinear = wrap_module(nn.Bilinear)
CosineSimilarity = wrap_module(nn.CosineSimilarity)
Unfold = wrap_module(nn.Unfold)
Fold = wrap_module(nn.Fold)
AdaptiveLogSoftmaxWithLoss = wrap_module(nn.AdaptiveLogSoftmaxWithLoss)
TransformerEncoder = wrap_module(nn.TransformerEncoder)
TransformerDecoder = wrap_module(nn.TransformerDecoder)
TransformerEncoderLayer = wrap_module(nn.TransformerEncoderLayer)
TransformerDecoderLayer = wrap_module(nn.TransformerDecoderLayer)
Transformer = wrap_module(nn.Transformer)
#LazyLinear = wrap_module(nn.LazyLinear)
#LazyConv1d = wrap_module(nn.LazyConv1d)
#LazyConv2d = wrap_module(nn.LazyConv2d)
#LazyConv3d = wrap_module(nn.LazyConv3d)
#LazyConvTranspose1d = wrap_module(nn.LazyConvTranspose1d)
#LazyConvTranspose2d = wrap_module(nn.LazyConvTranspose2d)
#LazyConvTranspose3d = wrap_module(nn.LazyConvTranspose3d)
Flatten = wrap_module(nn.Flatten)
#Unflatten = wrap_module(nn.Unflatten)
Hardsigmoid = wrap_module(nn.Hardsigmoid)
Hardswish = wrap_module(nn.Hardswish)
#SiLU = wrap_module(nn.SiLU)
#TripletMarginWithDistanceLoss = wrap_module(nn.TripletMarginWithDistanceLoss)
#ChannelShuffle = wrap_module(nn.ChannelShuffle)
from typing import (Any, Dict, List)
from . import debug_configs
__all__ = ['Operation', 'Cell']
def _convert_name(name: str) -> str:
"""
Convert a name that uses '.' as separator into a valid variable name in generated code
"""
return name.replace('.', '__')
class Operation:
"""
Calculation logic of a graph node.
The constructor is private. Use `Operation.new()` to create an operation object.
`Operation` is a naive record.
Do not "mutate" its attributes or store information related to a specific node.
All complex logic should be implemented in `Node` class.
Attributes
----------
type
Operation type name (e.g. Conv2D).
If it starts with underscore, the "operation" is a special one (e.g. subgraph, input/output).
parameters
Arbitrary key-value parameters (e.g. kernel_size).
"""
def __init__(self, type_name: str, parameters: Dict[str, Any], _internal: bool = False):
assert _internal, '`Operation()` is private, use `Operation.new()` instead'
self.type: str = type_name
self.parameters: Dict[str, Any] = parameters
def to_init_code(self, field: str) -> str:
raise NotImplementedError()
def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
raise NotImplementedError()
def _to_class_name(self) -> str:
raise NotImplementedError()
def __bool__(self) -> bool:
return True
@staticmethod
def new(type_name: str, parameters: Dict[str, Any] = {}, cell_name: str = None) -> 'Operation':
if type_name == '_cell':
# NOTE: cell_name is the same as its Node's name, when the cell is wrapped within the node
return Cell(cell_name, parameters)
else:
if debug_configs.framework.lower() in ('torch', 'pytorch'):
from .operation_def import torch_op_def # pylint: disable=unused-import
cls = PyTorchOperation._find_subclass(type_name)
elif debug_configs.framework.lower() in ('tf', 'tensorflow'):
from .operation_def import tf_op_def # pylint: disable=unused-import
cls = TensorFlowOperation._find_subclass(type_name)
else:
raise ValueError(f'Unsupported framework: {debug_configs.framework}')
return cls(type_name, parameters, _internal=True)
@classmethod
def _find_subclass(cls, subclass_name):
for subclass in cls.__subclasses__():
if subclass.__name__ == subclass_name:
return subclass
return cls
def __repr__(self):
type_name = type(self).__name__
args = [f'{key}={repr(value)}' for key, value in self.parameters.items()]
if type_name != self.type:
args = [f'type="{self.type}"'] + args
return f'{type_name}({", ".join(args)})'
def __eq__(self, other):
return type(other) is type(self) and other.type == self.type and other.parameters == self.parameters
class PyTorchOperation(Operation):
def _to_class_name(self) -> str:
if self.type.startswith('__torch__.'):
return self.type[len('__torch__.'):]
elif self.type.startswith('__mutated__.'):
return self.type[len('__mutated__.'):]
else:
return None
def get_import_pkg(self) -> str:
if self.type.startswith('__torch__.'):
return self.type[len('__torch__.'):].split('.')[0]
elif self.type.startswith('__mutated__.'):
return self.type[len('__mutated__.'):].split('.')[0]
else:
return None
def to_init_code(self, field: str) -> str:
if self._to_class_name() is not None:
assert 'positional_args' not in self.parameters
kw_params = ', '.join(f'{key}={repr(value)}' for key, value in self.parameters.items())
return f'self.{field} = {self._to_class_name()}({kw_params})'
return None
def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
from .converter.op_types import OpTypeName
if self._to_class_name() is not None:
return f'{output} = self.{field}({", ".join(inputs)})'
elif self.type.startswith('Function.'):
func_name = self.type[len('Function.'):]
return f'{output} = F.{func_name}({", ".join(inputs)})'
elif self.type == 'prim::Constant':
if self.parameters:
value = self.parameters['value']
else:
value = None
return f'{output} = {value}'
elif self.type == 'prim::ListConstruct':
return f'{output} = [{", ".join(inputs)}]'
elif self.type == 'aten::mean':
return f'{output} = torch.mean({inputs[0]}, {", ".join(inputs[1:-1])}, out={inputs[-1]})'
elif self.type == 'aten::__getitem__':
assert len(inputs) == 2
return f'{output} = {inputs[0]}[{inputs[1]}]'
elif self.type == 'aten::append':
assert len(inputs) == 2
return f'_, {output} = {inputs[0]}.append({inputs[1]}), {inputs[0]}'
elif self.type == 'aten::cat':
assert len(inputs) == 2
return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'
elif self.type == 'aten::add':
assert len(inputs) == 2
return f'{output} = {inputs[0]} + {inputs[1]}'
elif self.type == OpTypeName.MergedSlice:
assert (len(inputs) - 1) % 4 == 0
slices = []
dim = int((len(inputs) - 1) / 4)
for i in range(dim):
slices.append(f'{inputs[i*4+2]}:{inputs[i*4+3]}:{inputs[i*4+4]}')
slice_str = ','.join(slices)
return f'{output} = {inputs[0]}[{slice_str}]'
elif self.type == 'aten::size':
assert len(inputs) == 2
return f'{output} = {inputs[0]}.size({inputs[1]})'
elif self.type == 'aten::view':
assert len(inputs) == 2
return f'{output} = {inputs[0]}.view({inputs[1]})'
elif self.type == 'aten::slice':
raise RuntimeError('not supposed to have aten::slice operation')
else:
raise RuntimeError(f'unsupported operation type: {self.type} ? {self._to_class_name()}')
class TensorFlowOperation(Operation):
def _to_class_name(self) -> str:
return 'K.layers.' + self.type
class Cell(PyTorchOperation):
"""
TODO: this is pytorch cell
An operation reference to a subgraph.
Example code:
```
def __init__(...):
...
self.cell = CustomCell(...)
self.relu = K.layers.ReLU()
...
def forward(...):
...
x = self.cell(x)
...
```
In the above example, node `self.cell`'s operation is `Cell(cell_name='CustomCell')`.
For comparison, `self.relu`'s operation is `Operation(type='ReLU')`.
TODO: parameters of subgraph (see `Node` class)
Attributes
----------
type
Always "_cell".
parameters
A dict with only one item; the key is "cell" and the value is cell's name.
framework
No real usage. Exists for compatibility with base class.
"""
def __init__(self, cell_name: str, parameters: Dict[str, Any] = {}):
self.type = '_cell'
self.cell_name = cell_name
self.parameters = parameters
def _to_class_name(self):
# TODO: ugly, think about how to refactor this part
return _convert_name(self.cell_name)
class _IOPseudoOperation(Operation):
"""
This is the pseudo operation used by I/O nodes.
The benefit is that users no longer need to verify `Node.operation is not None`,
especially in static type checking.
"""
def __init__(self, type_name: str, io_names: List = None):
assert type_name.startswith('_')
super(_IOPseudoOperation, self).__init__(type_name, {}, True)
self.io_names = io_names
def to_init_code(self, field: str) -> str:
raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')
def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')
def __bool__(self) -> bool:
return False
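# A hedged usage sketch of the factory above with `debug_configs.framework`
# set to 'pytorch' (see the top of this diff); the parameter values are
# illustrative.
def _example_operation_codegen() -> None:
    op = Operation.new('__torch__.torch.nn.Conv2d',
                       {'in_channels': 3, 'out_channels': 8, 'kernel_size': 3})
    assert op.to_init_code('conv') == \
        'self.conv = torch.nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)'
    assert op.to_forward_code('conv', 'x1', ['x0']) == 'x1 = self.conv(x0)'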
"""
Definition of operation types.
These are currently examples for overriding codegen.
Feel free to propose better package name or hierarchy.
"""