Unverified commit 002af91f, authored by Yuge Zhang, committed by GitHub

[Retiarii] Base execution engine, codegen and trainer (#3059)

parent 60b2a7a3
from .execution import *
from .graph import *
from .mutator import *
from .operation import *

from .pytorch import model_to_pytorch_script
from typing import *

from ..graph import IllegalGraphError, Edge, Graph, Node, Model
from ..operation import Operation, Cell


def model_to_pytorch_script(model: Model) -> str:
    graphs = [graph_to_pytorch_model(name, cell) for name, cell in model.graphs.items()]
    return _PyTorchScriptTemplate.format('\n\n'.join(graphs)).strip()


def _sorted_incoming_edges(node: Node) -> List[Edge]:
    edges = [edge for edge in node.graph.edges if edge.tail is node]
    if not edges:
        return []
    if all(edge.tail_slot is None for edge in edges):
        return edges
    if all(isinstance(edge.tail_slot, int) for edge in edges):
        edges = sorted(edges, key=(lambda edge: edge.tail_slot))
        if [edge.tail_slot for edge in edges] == list(range(len(edges))):
            return edges
    raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))


def _format_inputs(node: Node) -> str:
    edges = _sorted_incoming_edges(node)
    inputs = []
    for edge in edges:
        if edge.head.name == '_inputs':
            assert isinstance(edge.head_slot, int)
            if node.graph.input_names is not None:
                # when inputs are named, e.g., forward(self, tensor1, tensor2, another_one)
                inputs.append(node.graph.input_names[edge.head_slot])
            else:
                # when inputs are unnamed, e.g., forward(*_inputs)
                inputs.append('_inputs[{}]'.format(edge.head_slot))
        else:
            if edge.head_slot is None:
                # when the input comes from a single-output operator
                inputs.append('{}'.format(edge.head.name))
            else:
                # when the input comes from a multi-output operator: needs to know which output it comes from
                inputs.append('{}[{}]'.format(edge.head.name, edge.head_slot))
    return ', '.join(inputs)


def graph_to_pytorch_model(graph_name: str, graph: Graph) -> str:
    nodes = graph.nodes  # FIXME: topological sort is needed here

    # handle module nodes and function nodes differently;
    # only module nodes need code generated here
    node_codes = []
    for node in nodes:
        if node.operation:
            node_codes.append(node.operation.to_init_code(node.name))

    if graph.input_names is None:
        input_code = '*_inputs'
    else:
        input_code = ', '.join(graph.input_names)

    edge_codes = []
    for node in nodes:
        if node.operation:
            inputs = _format_inputs(node)
            edge_codes.append(node.operation.to_forward_code(node.name, node.name, inputs))

    output_code = _format_inputs(graph.output_node)
    if not output_code:
        output_code = 'None'

    linebreak = '\n        '
    return _PyTorchModelTemplate.format(
        graph_name=('Graph' if graph_name == '_graph' else graph_name),
        inputs=input_code,
        outputs=output_code,
        nodes=linebreak.join(node_codes),
        edges=linebreak.join(edge_codes)
    )


# TODO: handle imports

_PyTorchScriptTemplate = '''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

{}
'''

_PyTorchModelTemplate = '''
class {graph_name}(nn.Module):
    def __init__(self):
        super().__init__()
        {nodes}

    def forward(self, {inputs}):
        {edges}
        return {outputs}
'''
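For orientation, this is how the generated script is meant to be consumed downstream (a sketch mirroring what `BaseExecutionEngine.trial_execute_graph` does later in this commit; `model` is assumed to be an already-built `Model` instance):

script = model_to_pytorch_script(model)
with open('_generated_model.py', 'w') as f:
    f.write(script)  # defines one nn.Module per graph, e.g. `_model` in the sample at the end of this commit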
# pylint: skip-file

"""
FIXME
This file is inherited from the last version.
I expect it to work with a few modifications to incorporate the latest API, but it hasn't
been tested and I'm not sure.
"""

from ..graph_v2 import IllegalGraphError, Cell, Edge, Graph, Node
from ..operations_tf import Operation
from ..type_utils import *


def graph_to_tensorflow_script(graph: Graph) -> str:
    graphs = [graph_to_tensorflow_model(name, cell) for name, cell in graph.cell_templates.items()]
    return _TensorFlowScriptTemplate.format('\n\n'.join(graphs)).strip()


def _sort_incoming_edges(node: Node) -> List[Edge]:
    edges = [edge for edge in node.graph.edges if edge.tail is node]
    if not edges:
        return []
    if all(edge.tail_idx is None for edge in edges):
        return edges
    if all(isinstance(edge.tail_idx, int) for edge in edges):
        edges = sorted(edges, key=(lambda edge: edge.tail_idx))
        if [edge.tail_idx for edge in edges] == list(range(len(edges))):
            return edges
    raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))


def _format_inputs(node: Node) -> str:
    edges = _sort_incoming_edges(node)
    inputs = []
    for edge in edges:
        if edge.head.name == '_inputs':
            assert isinstance(edge.head_idx, int)
            if node.graph.input_names is not None:
                inputs.append(node.graph.input_names[edge.head_idx])
            else:
                inputs.append('_inputs[{}]'.format(edge.head_idx))
        else:
            if edge.head_idx is None:
                inputs.append('{}'.format(edge.head.name))
            else:
                inputs.append('{}[{}]'.format(edge.head.name, edge.head_idx))
    return ', '.join(inputs)


def graph_to_tensorflow_model(graph_name: str, graph: Graph) -> str:
    nodes = graph.topo_sort()

    # handle module nodes and function nodes differently;
    # only module nodes need code generated here
    node_codes = []
    for node in nodes:
        if isinstance(node, Cell):
            node_codes.append('self.{} = {}()'.format(node.name, node.template_name))
        else:
            node_codes.append('self.{} = {}'.format(node.name, cast(Operation, node.operation).to_tensorflow_init()))

    edge_codes = []
    for node in nodes:
        inputs = _format_inputs(node)
        edge_codes.append('{} = self.{}({})'.format(node.name, node.name, inputs))

    output_code = _format_inputs(graph.output_node)
    if not output_code:
        output_code = 'None'

    if graph.input_names is None:
        input_code = '*_inputs'
    else:
        input_code = ', '.join(graph.input_names)

    linebreak = '\n        '
    return _TensorFlowModelTemplate.format(
        graph_name=('Graph' if graph_name == '_graph' else graph_name),
        inputs=input_code,
        outputs=output_code,
        nodes=linebreak.join(node_codes),
        edges=linebreak.join(edge_codes)
    )


_TensorFlowScriptTemplate = '''
import tensorflow as tf
import tensorflow.keras as K

import sdk.custom_ops_tf as CUSTOM

{}
'''

_TensorFlowModelTemplate = '''
class {graph_name}(K.Model):
    def __init__(self):
        super().__init__()
        {nodes}

    def call(self, {inputs}):
        {edges}
        return {outputs}
'''
import time
from typing import *

from ..graph import Model, ModelStatus
from .base import BaseExecutionEngine
from .interface import *
from .listener import DefaultListener

_execution_engine = None
_default_listener = None

__all__ = ['get_execution_engine', 'get_and_register_default_listener',
           'submit_models', 'wait_models', 'query_available_resources']


def get_execution_engine() -> BaseExecutionEngine:
    """
    Currently we assume the default execution engine is BaseExecutionEngine.
    """
    global _execution_engine
    if _execution_engine is None:
        _execution_engine = BaseExecutionEngine()
    return _execution_engine


def get_and_register_default_listener(engine: AbstractExecutionEngine) -> DefaultListener:
    global _default_listener
    if _default_listener is None:
        _default_listener = DefaultListener()
        engine.register_graph_listener(_default_listener)
    return _default_listener


def submit_models(*models: Model) -> None:
    engine = get_execution_engine()
    get_and_register_default_listener(engine)
    engine.submit_models(*models)


def wait_models(*models: Model) -> None:
    get_and_register_default_listener(get_execution_engine())
    while True:
        time.sleep(1)
        left_models = [g for g in models if g.status not in (ModelStatus.Trained, ModelStatus.Failed)]
        if not left_models:
            break


def query_available_resources() -> List[WorkerInfo]:
    listener = get_and_register_default_listener(get_execution_engine())
    return listener.resources
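For reference, a sketch of the synchronous usage pattern these helpers enable; `produce_candidate_models()` is a hypothetical stand-in for a strategy's model generator:

def naive_strategy():
    for model in produce_candidate_models():  # hypothetical generator of Model objects
        submit_models(model)
        wait_models(model)  # blocks until the model is Trained or Failed
        print(model.status, model.metric)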
from typing import *

from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
from .. import codegen, utils
from ..graph import Model, ModelStatus, MetricData
from ..integration import send_trial, receive_trial_parameters, get_advisor


class BaseGraphData:
    def __init__(self, model_script: str, training_module: str, training_kwargs: Dict[str, Any]) -> None:
        self.model_script = model_script
        self.training_module = training_module
        self.training_kwargs = training_kwargs

    def dump(self) -> dict:
        return {
            'model_script': self.model_script,
            'training_module': self.training_module,
            'training_kwargs': self.training_kwargs
        }

    @staticmethod
    def load(data):
        return BaseGraphData(data['model_script'], data['training_module'], data['training_kwargs'])


class BaseExecutionEngine(AbstractExecutionEngine):
    """
    The execution engine with no optimization at all.
    Resource management is yet to be implemented.
    """

    def __init__(self) -> None:
        """
        Upon initialization, advisor callbacks need to be registered.
        The advisor invokes these callbacks when the corresponding events are triggered.
        The base execution engine receives the callbacks and broadcasts them to graph listeners.
        """
        self._listeners: List[AbstractGraphListener] = []

        # register advisor callbacks
        advisor = get_advisor()
        advisor.send_trial_callback = self._send_trial_callback
        advisor.request_trial_jobs_callback = self._request_trial_jobs_callback
        advisor.trial_end_callback = self._trial_end_callback
        advisor.intermediate_metric_callback = self._intermediate_metric_callback
        advisor.final_metric_callback = self._final_metric_callback

        self._running_models: Dict[int, Model] = dict()

    def submit_models(self, *models: Model) -> None:
        for model in models:
            data = BaseGraphData(codegen.model_to_pytorch_script(model),
                                 model.training_config.module, model.training_config.kwargs)
            self._running_models[send_trial(data.dump())] = model

    def register_graph_listener(self, listener: AbstractGraphListener) -> None:
        self._listeners.append(listener)

    def _send_trial_callback(self, parameter: dict) -> None:
        for listener in self._listeners:
            listener.on_resource_used(0)  # FIXME: find the real resource id

    def _request_trial_jobs_callback(self, num_trials: int) -> None:
        for listener in self._listeners:
            listener.on_resource_available([0] * num_trials)  # FIXME: find the real resource id

    def _trial_end_callback(self, trial_id: int, success: bool) -> None:
        model = self._running_models[trial_id]
        if success:
            model.status = ModelStatus.Trained
        else:
            model.status = ModelStatus.Failed
        for listener in self._listeners:
            listener.on_training_end(model, success)

    def _intermediate_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
        model = self._running_models[trial_id]
        model.intermediate_metrics.append(metrics)
        for listener in self._listeners:
            listener.on_intermediate_metric(model, metrics)

    def _final_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
        model = self._running_models[trial_id]
        model.metric = metrics
        for listener in self._listeners:
            listener.on_metric(model, metrics)

    def query_available_resource(self) -> List[WorkerInfo]:
        raise NotImplementedError  # move the method from listener to here?

    @classmethod
    def trial_execute_graph(cls) -> None:
        """
        Initialize the model and hand it over to the trainer.
        """
        graph_data = BaseGraphData.load(receive_trial_parameters())
        with open('_generated_model.py', 'w') as f:
            f.write(graph_data.model_script)
        trainer_cls = utils.import_(graph_data.training_module)
        model_cls = utils.import_('_generated_model._model')
        trainer_instance = trainer_cls(model=model_cls(), **graph_data.training_kwargs)
        trainer_instance.fit()
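Since the trial payload is a plain dict, it survives NNI's JSON transport unchanged; a quick round-trip sketch of `BaseGraphData` (the module string follows the trainer added in this commit):

data = BaseGraphData('print("fake model script")',
                     'nni.retiarii.trainer.PyTorchImageClassificationTrainer',
                     {'trainer_kwargs': {'max_epochs': 1}})
assert BaseGraphData.load(data.dump()).model_script == data.model_script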
from abc import *
from typing import *

from ..graph import Model, MetricData

__all__ = [
    'GraphData', 'WorkerInfo',
    'AbstractGraphListener', 'AbstractExecutionEngine'
]


GraphData = NewType('GraphData', Any)
"""
A _serializable_ internal data type defined by the execution engine.

The execution engine submits this kind of data through NNI to the worker machine, and trains it there.

A `GraphData` object describes a (merged) executable graph.

This is a trial's "hyper-parameter" in NNI's terms and will be transferred in JSON format.

See `AbstractExecutionEngine` for details.
"""

WorkerInfo = NewType('WorkerInfo', Any)
"""
To be designed. Discussion needed.

This describes the properties of a worker machine (e.g., memory size).
"""


class AbstractGraphListener(ABC):
    """
    Abstract listener interface to receive graph events.

    Use `AbstractExecutionEngine.register_graph_listener()` to activate a listener.
    """

    @abstractmethod
    def on_metric(self, model: Model, metric: MetricData) -> None:
        """
        Reports the final metric of a graph.
        """
        raise NotImplementedError

    @abstractmethod
    def on_intermediate_metric(self, model: Model, metric: MetricData) -> None:
        """
        Reports the latest intermediate metric of a training graph.
        """
        pass

    @abstractmethod
    def on_training_end(self, model: Model, success: bool) -> None:
        """
        Reports whether a graph is fully trained or the training process has failed.
        """
        pass

    @abstractmethod
    def on_resource_available(self, resources: List[WorkerInfo]) -> None:
        """
        Reports when a worker becomes idle.
        """
        pass


class AbstractExecutionEngine(ABC):
    """
    The abstract interface of the execution engine.

    Most of these APIs are used by the strategy, except `trial_execute_graph`, which is invoked by the framework in a trial.
    The strategy gets the singleton execution engine object through a global API,
    and uses it in either a sync or an async manner.

    The execution engine is responsible for submitting (maybe-optimized) models to NNI,
    and for assigning their metrics to the `Model` objects after training.
    The execution engine is also responsible for launching the graph in the trial process,
    because it is the only component that understands graph data, or "hyper-parameters" in NNI's terms.
    The execution engine will leverage NNI Advisor APIs, which are still open for discussion.

    In the synchronous use case, the strategy runs a loop that calls `submit_models` and `wait_models` repeatedly,
    and receives metrics from `Model` attributes.
    The execution engine can assume that the strategy only submits graphs when there are available resources (for now).

    In the asynchronous use case, the strategy registers a listener to receive events,
    while still using `submit_models` to train.

    There will be a `BaseExecutionEngine` subclass.
    Inner-graph optimization is supposed to derive from `BaseExecutionEngine`
    and override `submit_models` and `trial_execute_graph`.
    Cross-graph optimization is supposed to derive from `AbstractExecutionEngine` directly,
    because in that case APIs like `wait_graph` and `listener.on_training_end` will have unique logic.

    There might be some util functions that benefit all optimizing methods,
    but non-mandatory utils should not be covered in the abstract interface.
    """

    @abstractmethod
    def submit_models(self, *models: Model) -> None:
        """
        Submit models to NNI.

        This method is supposed to call something like `nni.Advisor.create_trial_job(graph_data)`.
        """
        raise NotImplementedError

    @abstractmethod
    def query_available_resource(self) -> List[WorkerInfo]:
        """
        Returns information on all idle workers.
        If no details are available, this may return a list of "empty" objects, reporting the number of idle workers.

        Could be left unimplemented for the first iteration.
        """
        raise NotImplementedError

    @abstractmethod
    def register_graph_listener(self, listener: AbstractGraphListener) -> None:
        """
        Register a listener to receive graph events.

        Could be left unimplemented for the first iteration.
        """
        raise NotImplementedError

    @abstractclassmethod
    def trial_execute_graph(cls) -> MetricData:
        """
        Train a graph and return its metrics, in a separate trial process.

        Each call to `nni.Advisor.create_trial_job(graph_data)` will eventually invoke this method.

        Because this method is invoked in the trial process on the training platform,
        it has a different context from the other methods and has no access to global variables or `self`.
        However, util APIs like `.utils.experiment_config()` should still be available.
        """
        raise NotImplementedError

from typing import *

from ..graph import *
from .interface import *


class DefaultListener(AbstractGraphListener):
    def __init__(self):
        self.resources: List[WorkerInfo] = []

    def on_metric(self, model: Model, metric: MetricData) -> None:
        model.metric = metric

    def on_intermediate_metric(self, model: Model, metric: MetricData) -> None:
        model.intermediate_metrics.append(metric)

    def on_training_end(self, model: Model, success: bool) -> None:
        if success:
            model.status = ModelStatus.Trained
        else:
            model.status = ModelStatus.Failed

    def on_resource_available(self, resources: List[WorkerInfo]) -> None:
        self.resources += resources

    def on_resource_used(self, resources: List[WorkerInfo]) -> None:
        self.resources = [r for r in self.resources if r not in resources]
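A minimal sketch exercising the listener bookkeeping above; `FakeModel` is a hypothetical stand-in for `..graph.Model`, which is outside this commit:

class FakeModel:
    def __init__(self):
        self.status = None
        self.metric = None
        self.intermediate_metrics = []

model = FakeModel()
listener = DefaultListener()
listener.on_intermediate_metric(model, 0.5)
listener.on_metric(model, 0.9)
listener.on_training_end(model, success=True)
assert model.intermediate_metrics == [0.5] and model.metric == 0.9
assert model.status == ModelStatus.Trained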
import logging
import threading
from typing import *

import json_tricks

import nni
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.runtime.protocol import send, CommandType
from nni.utils import MetricType

from . import utils
from .graph import MetricData

_logger = logging.getLogger('nni.msg_dispatcher_base')


class RetiariiAdvisor(MsgDispatcherBase):
    """
    The class connects Retiarii components to the NNI backend.
    It functions as the main thread when running a Retiarii experiment through NNI.

    The strategy is launched as a separate thread, which calls APIs in the execution engine. The execution
    engine then finds the advisor singleton and sends payloads to the advisor.

    When metrics are sent back, the advisor receives the payloads first and invokes the callback
    functions (which are member functions of the graph listener).

    The conversion the advisor provides is minimal. It is only a send/receive module, and the execution engine
    needs to handle all the rest.

    FIXME
    How does the advisor exit when the strategy exits?

    Attributes
    ----------
    send_trial_callback
    request_trial_jobs_callback
    trial_end_callback
    intermediate_metric_callback
    final_metric_callback
    """

    def __init__(self, strategy: Union[str, Callable]):
        super(RetiariiAdvisor, self).__init__()
        register_advisor(self)  # register the current advisor as the "global only" advisor

        self.send_trial_callback: Callable[[dict], None] = None
        self.request_trial_jobs_callback: Callable[[int], None] = None
        self.trial_end_callback: Callable[[int, bool], None] = None
        self.intermediate_metric_callback: Callable[[int, MetricData], None] = None
        self.final_metric_callback: Callable[[int, MetricData], None] = None

        self.strategy = utils.import_(strategy) if isinstance(strategy, str) else strategy
        self.parameters_count = 0

        threading.Thread(target=self.strategy).start()

    def handle_initialize(self, data):
        pass

    def send_trial(self, parameters):
        """
        Send parameters to NNI.

        Parameters
        ----------
        parameters : Any
            Any payload.

        Returns
        -------
        int
            Parameter ID that is assigned to this parameter,
            which will be used for identification in the future.
        """
        self.parameters_count += 1
        new_trial = {
            'parameter_id': self.parameters_count,
            'parameters': parameters,
            'parameter_source': 'algorithm'
        }
        _logger.info('New trial sent: %s', new_trial)
        send(CommandType.NewTrialJob, json_tricks.dumps(new_trial))
        if self.send_trial_callback is not None:
            self.send_trial_callback(parameters)  # pylint: disable=not-callable
        return self.parameters_count

    def handle_request_trial_jobs(self, num_trials):
        _logger.info('Request trial jobs: %s', num_trials)
        if self.request_trial_jobs_callback is not None:
            self.request_trial_jobs_callback(num_trials)  # pylint: disable=not-callable

    def handle_update_search_space(self, data):
        pass

    def handle_trial_end(self, data):
        _logger.info('Trial end: %s', data)
        self.trial_end_callback(json_tricks.loads(data['hyper_params'])['parameter_id'],  # pylint: disable=not-callable
                                data['event'] == 'SUCCEEDED')

    def handle_report_metric_data(self, data):
        _logger.info('Metric reported: %s', data)
        if data['type'] == MetricType.REQUEST_PARAMETER:
            raise ValueError('Request parameter not supported')
        elif data['type'] == MetricType.PERIODICAL:
            self.intermediate_metric_callback(data['parameter_id'],  # pylint: disable=not-callable
                                              self._process_value(data['value']))
        elif data['type'] == MetricType.FINAL:
            self.final_metric_callback(data['parameter_id'],  # pylint: disable=not-callable
                                       self._process_value(data['value']))

    @staticmethod
    def _process_value(value) -> Any:  # hopefully a float
        if isinstance(value, dict):
            return value['default']
        return value


_advisor: RetiariiAdvisor = None


def get_advisor() -> RetiariiAdvisor:
    global _advisor
    assert _advisor is not None
    return _advisor


def register_advisor(advisor: RetiariiAdvisor):
    global _advisor
    assert _advisor is None
    _advisor = advisor


def send_trial(parameters: dict) -> int:
    """
    Send a new trial. Executed on the tuner side.
    Returns an ID that uniquely identifies this trial.
    """
    return get_advisor().send_trial(parameters)


def receive_trial_parameters() -> dict:
    """
    Receive a new trial's parameters. Executed on the trial side.
    """
    params = nni.get_next_parameter()
    return params
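A quick sanity check of `_process_value`, which unwraps NNI's dict-style metrics while letting scalars pass through (a sketch):

assert RetiariiAdvisor._process_value({'default': 0.93, 'other': 1.0}) == 0.93
assert RetiariiAdvisor._process_value(0.93) == 0.93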
@@ -49,10 +49,10 @@ class Operation:
         else:
             if debug_configs.framework.lower() in ('torch', 'pytorch'):
                 from .operation_def import torch_op_def  # pylint: disable=unused-import
-                cls = PyTorchOperation._find_subclass(type)
+                cls = PyTorchOperation._find_subclass(type_name)
             elif debug_configs.framework.lower() in ('tf', 'tensorflow'):
                 from .operation_def import tf_op_def  # pylint: disable=unused-import
-                cls = TensorFlowOperation._find_subclass(type)
+                cls = TensorFlowOperation._find_subclass(type_name)
             else:
                 raise ValueError(f'Unsupported framework: {debug_configs.framework}')
         return cls(type_name, parameters, _internal=True)
from .interface import BaseTrainer
from .pytorch import PyTorchImageClassificationTrainer

import abc


class BaseTrainer(abc.ABC):
    """
    In this version, we plan to write our own trainers instead of using PyTorch Lightning, to
    ease the burden of integrating our optimization with PyTorch Lightning, a large part of which is
    opaque to us.

    We will try to align with PyTorch Lightning naming conventions so that we can easily migrate to
    PyTorch Lightning in the future.

    Currently, our trainer = LightningModule + LightningTrainer. We might want to separate these two
    in the future.

    A trainer has a ``fit`` function with no return value. Intermediate and final results should be
    reported directly via the ``nni.report_intermediate_result()`` and ``nni.report_final_result()`` functions.
    """

    @abc.abstractmethod
    def fit(self) -> None:
        pass
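A minimal sketch of a concrete trainer under this contract, assuming only what the interface requires (a ``fit`` method plus NNI reporting); the class and its metric values are illustrative:

import nni
from .interface import BaseTrainer  # assuming the package layout of this commit


class ConstantTrainer(BaseTrainer):
    """Illustrative trainer that trains nothing and reports fixed metrics."""

    def __init__(self, model, max_epochs=2):
        self.model = model
        self.max_epochs = max_epochs

    def fit(self) -> None:
        for epoch in range(self.max_epochs):
            nni.report_intermediate_result(0.1 * (epoch + 1))  # fake accuracy
        nni.report_final_result(0.1 * self.max_epochs)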
import abc
from typing import *

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import nni

from .interface import BaseTrainer


def get_default_transform(dataset: str) -> Any:
    """
    Get the default image transform for a specific dataset.
    This is needed because transform objects cannot be directly passed as arguments.

    Parameters
    ----------
    dataset : str
        Dataset class name.

    Returns
    -------
    transform object
    """
    if dataset == 'MNIST':
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
    # unsupported dataset, return None
    return None


class PyTorchImageClassificationTrainer(BaseTrainer):
    """
    Image classification trainer for PyTorch.

    A model, along with the corresponding dataset and optimizer config, is used to initialize the trainer.
    The trainer will run for a fixed number of epochs (by default 10), and report the final result.

    TODO
    Support scheduler, validate every n epochs, train/valid dataset.

    Limitation induced by NNI: kwargs must be serializable, so they can be put into a JSON packed in parameters.
    """

    def __init__(self, model,
                 dataset_cls='MNIST', dataset_kwargs=None, dataloader_kwargs=None,
                 optimizer_cls='SGD', optimizer_kwargs=None, trainer_kwargs=None):
        """Initialization of the image classification trainer.

        Parameters
        ----------
        model : nn.Module
            Model to train.
        dataset_cls : str, optional
            Dataset class name that is available in ``torchvision.datasets``, by default 'MNIST'
        dataset_kwargs : dict, optional
            Keyword arguments passed to initialization of the dataset class, by default None
        dataloader_kwargs : dict, optional
            Keyword arguments passed to ``torch.utils.data.DataLoader``, by default None
        optimizer_cls : str, optional
            Optimizer class name that is available in ``torch.optim``, by default 'SGD'
        optimizer_kwargs : dict, optional
            Keyword arguments passed to initialization of the optimizer class, by default None
        trainer_kwargs : dict, optional
            Keyword arguments passed to the trainer. Will be passed to a Trainer class in the future. Currently,
            only the key ``max_epochs`` is used.
        """
        self._use_cuda = torch.cuda.is_available()
        self.model = model
        if self._use_cuda:
            self.model.cuda()
        self._loss_fn = nn.CrossEntropyLoss()
        self._dataset = getattr(datasets, dataset_cls)(transform=get_default_transform(dataset_cls),
                                                       **(dataset_kwargs or {}))
        self._optimizer = getattr(torch.optim, optimizer_cls)(model.parameters(), **(optimizer_kwargs or {}))
        self._trainer_kwargs = trainer_kwargs or {'max_epochs': 10}

        # TODO: we will need at least two (maybe three) data loaders in the future.
        self._dataloader = DataLoader(self._dataset, **(dataloader_kwargs or {}))

    def _accuracy(self, input, target):  # pylint: disable=redefined-builtin
        _, predict = torch.max(input.data, 1)
        correct = predict.eq(target.data).cpu().sum().item()
        return correct / input.size(0)

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        x, y = batch
        if self._use_cuda:
            x, y = x.cuda(), y.cuda()
        y_hat = self.model(x)
        loss = self._loss_fn(y_hat, y)
        return loss

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
        x, y = batch
        if self._use_cuda:
            x, y = x.cuda(), y.cuda()
        y_hat = self.model(x)
        acc = self._accuracy(y_hat, y)
        return {'val_acc': acc}

    def validation_epoch_end(self, outputs: List[Dict[str, Any]]) -> Dict[str, Any]:
        # We might need dict metrics in the future?
        avg_acc = np.mean([x['val_acc'] for x in outputs]).item()
        nni.report_intermediate_result(avg_acc)
        return {'val_acc': avg_acc}

    def _validate(self):
        validation_outputs = []
        for i, batch in enumerate(self._dataloader):
            validation_outputs.append(self.validation_step(batch, i))
        return self.validation_epoch_end(validation_outputs)

    def _train(self):
        for i, batch in enumerate(self._dataloader):
            self._optimizer.zero_grad()
            loss = self.training_step(batch, i)
            loss.backward()  # back-propagate and update parameters; otherwise the model never learns
            self._optimizer.step()

    def fit(self) -> None:
        for _ in range(self._trainer_kwargs['max_epochs']):
            self._train()
        nni.report_final_result(self._validate()['val_acc'])  # assuming val_acc here
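A usage sketch matching the constructor above; `Net` is a hypothetical `nn.Module`, and the kwargs are arbitrary but JSON-serializable, as the limitation above requires:

trainer = PyTorchImageClassificationTrainer(
    Net(),
    dataset_kwargs={'root': 'data/mnist', 'download': True},
    dataloader_kwargs={'batch_size': 32, 'shuffle': True},
    optimizer_kwargs={'lr': 1e-3},
    trainer_kwargs={'max_epochs': 1})
trainer.fit()  # reports intermediate/final accuracy through nni.report_*_result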
"""
Entrypoint for trials.
Assuming execution engine is BaseExecutionEngine.
"""
from .execution.base import BaseExecutionEngine
if __name__ == '__main__':
BaseExecutionEngine.trial_execute_graph()

from typing import Any


def import_(target: str, allow_none: bool = False) -> Any:
    if target is None:
        return None
    path, identifier = target.rsplit('.', 1)
    module = __import__(path, globals(), locals(), [identifier])
    return getattr(module, identifier)
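`import_` resolves a dotted path to an attribute; this is how the engine turns `training_module` strings and `'_generated_model._model'` into classes. For example (a sketch):

sgd_cls = import_('torch.optim.SGD')  # equivalent to `from torch.optim import SGD`
trainer_cls = import_('nni.retiarii.trainer.PyTorchImageClassificationTrainer')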
@@ -3,3 +3,7 @@ __pycache__
 tuner_search_space.json
 tuner_result.txt
 assessor_result.txt
+_generated_model.py
+data
+generated

import os
import sys

from nni.retiarii.integration import RetiariiAdvisor

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class _model(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = stem()
        self.fc1 = nn.Linear(1024, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, image):
        stem = self.stem(image)
        flatten = stem.view(stem.size(0), -1)
        fc1 = self.fc1(flatten)
        fc2 = self.fc2(fc1)
        softmax = F.softmax(fc2, -1)
        return softmax


class stem(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(out_channels=32, in_channels=1, kernel_size=5)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(out_channels=64, in_channels=32, kernel_size=5)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

    def forward(self, *_inputs):
        conv1 = self.conv1(_inputs[0])
        pool1 = self.pool1(conv1)
        conv2 = self.conv2(pool1)
        pool2 = self.pool2(conv2)
        return pool2
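A quick shape check for the sample generated model above (a sketch, assuming an MNIST-sized input): 28x28 -> conv5 -> 24x24 -> pool2 -> 12x12 -> conv5 -> 8x8 -> pool2 -> 4x4, so the flattened size is 64 * 4 * 4 = 1024, matching `fc1`.

out = _model()(torch.randn(1, 1, 28, 28))
assert out.shape == (1, 10)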