Commit a0fd0036 authored by Yuge Zhang, committed by GitHub

Merge pull request #5036 from microsoft/promote-retiarii-to-nas

[DO NOT SQUASH] Promote retiarii to NAS
parents d6dcb483 bc6d8796
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import nni
from nni.retiarii.graph import Evaluator
# pylint: disable=wildcard-import,unused-wildcard-import
@nni.trace
class FunctionalEvaluator(Evaluator):
"""
Functional evaluator that directly takes a function and thus should be general.
Attributes
----------
function
The function to be evaluated. It receives the model class as its first argument, followed by ``arguments`` as keyword arguments (see ``_execute``).
arguments
Keyword arguments for the function other than model.
"""
def __init__(self, function, **kwargs):
self.function = function
self.arguments = kwargs
@staticmethod
def _load(ir):
return FunctionalEvaluator(ir['function'], **ir['arguments'])
def _dump(self):
return {
'type': self.__class__,
'function': self.function,
'arguments': self.arguments
}
def _execute(self, model_cls):
return self.function(model_cls, **self.arguments)
def __eq__(self, other):
return self.function == other.function and self.arguments == other.arguments
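# Illustrative usage sketch (not part of this diff). `evaluate_model` and its
# `learning_rate` argument are hypothetical; the only contract shown in this file is that
# `function` receives the model class first, followed by `arguments` as keyword arguments
# (see `_execute` above).
def _functional_evaluator_example():
    def evaluate_model(model_cls, learning_rate=0.1):
        model = model_cls()            # instantiate the sampled architecture
        # ... train / validate `model` here, then report the score back to NNI
        nni.report_final_result(0.0)   # placeholder metric for the sketch
    # At trial time the framework calls `evaluator._execute(model_cls)`, which is
    # equivalent to `evaluate_model(model_cls, learning_rate=0.1)`.
    return FunctionalEvaluator(evaluate_model, learning_rate=0.1)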
from nni.nas.evaluator.functional import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import warnings
from typing import Dict, List, Optional, Union, Type
# pylint: disable=wildcard-import,unused-wildcard-import
import torch.nn as nn
import torch.optim as optim
import torchmetrics
from torch.utils.data import DataLoader
import nni
from ..lightning import LightningModule, _AccuracyWithLogits, Lightning
from .trainer import Trainer
@nni.trace
class _MultiModelSupervisedLearningModule(LightningModule):
def __init__(self, criterion: Type[nn.Module], metrics: Dict[str, torchmetrics.Metric],
n_models: int = 0,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
super().__init__()
self.save_hyperparameters('criterion', 'optimizer', 'learning_rate', 'weight_decay')
self.criterion = criterion()
self.criterion_cls = criterion
self.optimizer = optimizer
self.metrics = nn.ModuleDict({name: cls() for name, cls in metrics.items()})
self.metrics_args = metrics
self.n_models = n_models
def dump_kwargs(self):
kwargs = {}
kwargs['criterion'] = self.criterion_cls
kwargs['metrics'] = self.metrics_args
kwargs['n_models'] = self.n_models
kwargs['learning_rate'] = self.hparams['learning_rate']
kwargs['weight_decay'] = self.hparams['weight_decay']
kwargs['optimizer'] = self.optimizer
return kwargs
def forward(self, x):
y_hat = self.model(x)
return y_hat
def training_step(self, batch, batch_idx):
x, y = batch
multi_y_hat = self(x)
if isinstance(multi_y_hat, tuple):
assert len(multi_y_hat) == self.n_models
else:
assert self.n_models == 1
multi_y_hat = [multi_y_hat]
multi_loss = []
for idx, y_hat in enumerate(multi_y_hat):
loss = self.criterion(y_hat.to("cpu"), y.to("cpu"))
self.log(f'train_loss_{idx}', loss, prog_bar=True)
for name, metric in self.metrics.items():
self.log(f'train_{idx}_' + name, metric(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
multi_loss.append(loss)
return sum(multi_loss)
def validation_step(self, batch, batch_idx):
x, y = batch
multi_y_hat = self(x)
if isinstance(multi_y_hat, tuple):
assert len(multi_y_hat) == self.n_models
else:
assert self.n_models == 1
multi_y_hat = [multi_y_hat]
for idx, y_hat in enumerate(multi_y_hat):
self.log(f'val_loss_{idx}', self.criterion(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
for name, metric in self.metrics.items():
self.log(f'val_{idx}_' + name, metric(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
def test_step(self, batch, batch_idx):
x, y = batch
multi_y_hat = self(x)
if isinstance(multi_y_hat, tuple):
assert len(multi_y_hat) == self.n_models
else:
assert self.n_models == 1
multi_y_hat = [multi_y_hat]
for idx, y_hat in enumerate(multi_y_hat):
self.log(f'test_loss_{idx}', self.criterion(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
for name, metric in self.metrics.items():
self.log(f'test_{idx}_' + name, metric(y_hat.to("cpu"), y.to("cpu")), prog_bar=True)
def configure_optimizers(self):
return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)
def on_validation_epoch_end(self):
nni.report_intermediate_result(self._get_validation_metrics())
def teardown(self, stage):
if stage == 'fit':
nni.report_final_result(self._get_validation_metrics())
def _get_validation_metrics(self):
# TODO: split metric of multiple models?
if len(self.metrics) == 1:
metric_name = next(iter(self.metrics))
ret = []
for idx in range(self.n_models):
ret.append(self.trainer.callback_metrics[f'val_{idx}_' + metric_name].item())
return ret
else:
warnings.warn('Multiple metrics without "default" is not supported by current framework.')
return {name: self.trainer.callback_metrics['val_' + name].item() for name in self.metrics}
class MultiModelSupervisedLearningModule(_MultiModelSupervisedLearningModule):
"""
Lightning module of supervised learning for cross-graph optimization.
Users who need cross-graph optimization should use this module.
Parameters
----------
criterion : nn.Module
Class for criterion module (not an instance).
metrics : dict of str to torchmetrics.Metric
Class (not an instance) for each metric, keyed by metric name.
learning_rate : float
Learning rate. default: 0.001
weight_decay : float
L2 weight decay. default: 0
optimizer : Optimizer
Class for optimizer (not an instance). default: ``Adam``
"""
def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
super().__init__(criterion, metrics, learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
class _ClassificationModule(_MultiModelSupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
super().__init__(criterion, {'acc': _AccuracyWithLogits},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
class Classification(Lightning):
"""
Evaluator that is used for classification, with cross-graph optimization enabled.
Parameters
----------
criterion : nn.Module
Class for criterion module (not an instance). default: ``nn.CrossEntropyLoss``
learning_rate : float
Learning rate. default: 0.001
weight_decay : float
L2 weight decay. default: 0
optimizer : Optimizer
Class for optimizer (not an instance). default: ``Adam``
train_dataloader : DataLoader
Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
val_dataloaders : DataLoader or List of DataLoader
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
trainer_kwargs : dict
Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
"""
def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
**trainer_kwargs):
module = _ClassificationModule(criterion=criterion, learning_rate=learning_rate,
weight_decay=weight_decay, optimizer=optimizer)
super().__init__(module, Trainer(use_cgo=True, **trainer_kwargs),
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
class _RegressionModule(_MultiModelSupervisedLearningModule):
def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
super().__init__(criterion, {'mse': torchmetrics.MeanSquaredError},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
class Regression(Lightning):
"""
Evaluator that is used for regression, with cross-graph optimization enabled.
Parameters
----------
criterion : nn.Module
Class for criterion module (not an instance). default: ``nn.MSELoss``
learning_rate : float
Learning rate. default: 0.001
weight_decay : float
L2 weight decay. default: 0
optimizer : Optimizer
Class for optimizer (not an instance). default: ``Adam``
train_dataloader : DataLoader
Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
val_dataloaders : DataLoader or List of DataLoader
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
trainer_kwargs : dict
Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
"""
def __init__(self, criterion: nn.Module = nn.MSELoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
**trainer_kwargs):
module = _RegressionModule(criterion=criterion, learning_rate=learning_rate,
weight_decay=weight_decay, optimizer=optimizer)
super().__init__(module, Trainer(use_cgo=True, **trainer_kwargs),
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
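# Illustrative usage sketch (not part of this diff). The dataset arguments are hypothetical
# placeholders; the point is that these evaluators build `Trainer(use_cgo=True, ...)` internally,
# so device placement is handled by the cross-graph optimization engine rather than by
# PyTorch Lightning.
def _cgo_classification_example(train_dataset, valid_dataset):
    from ..lightning import DataLoader as TracedDataLoader  # traced DataLoader, serializable by nni
    return Classification(
        criterion=nn.CrossEntropyLoss,
        learning_rate=1e-3,
        train_dataloader=TracedDataLoader(train_dataset, batch_size=64, shuffle=True),
        val_dataloaders=TracedDataLoader(valid_dataset, batch_size=64),
        max_epochs=10,   # extra keyword arguments are forwarded to the underlying Trainer
    )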
from nni.nas.evaluator.pytorch.cgo.evaluator import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import pytorch_lightning as pl
from pytorch_lightning.strategies import SingleDeviceStrategy
# pylint: disable=wildcard-import,unused-wildcard-import
class BypassStrategy(SingleDeviceStrategy):
strategy_name = "single_device"
def model_to_device(self) -> None:
pass
class Trainer(pl.Trainer):
"""
Trainer for cross-graph optimization.
Parameters
----------
use_cgo : bool
Whether cross-graph optimization (CGO) is used.
If it is True, CGO will manage device placement.
Any device placement from pytorch lightning will be bypassed.
default: False
trainer_kwargs : dict
Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
"""
def __init__(self, use_cgo=False, **trainer_kwargs):
if use_cgo:
if "accelerator" in trainer_kwargs:
raise ValueError("accelerator should not be set when cross-graph optimization is enabled.")
if 'strategy' in trainer_kwargs:
raise ValueError("cgo.trainer does not support specifying strategy")
trainer_kwargs['strategy'] = BypassStrategy()
super().__init__(**trainer_kwargs)
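# Illustrative sketch (not part of this diff): with `use_cgo=True` the trainer installs
# `BypassStrategy`, whose `model_to_device` is a no-op, and it rejects any user-supplied
# `accelerator` or `strategy`; device placement is left to the CGO execution engine.
def _cgo_trainer_example():
    trainer = Trainer(use_cgo=True, max_epochs=10)   # only generic Lightning arguments allowed
    # Trainer(use_cgo=True, accelerator='gpu')       # would raise ValueError (see __init__ above)
    return trainer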
from nni.nas.evaluator.pytorch.cgo.trainer import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import warnings
from pathlib import Path
from typing import Any, Dict, Union, Optional, List, Callable, Type
# pylint: disable=wildcard-import,unused-wildcard-import
import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as nn_functional
import torch.optim as optim
import torchmetrics
import torch.utils.data as torch_data
import nni
from nni.common.serializer import is_traceable
try:
from .cgo import trainer as cgo_trainer
cgo_import_failed = False
except ImportError:
cgo_import_failed = True
from nni.retiarii.graph import Evaluator
from nni.typehint import Literal
__all__ = ['LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classification', 'Regression']
class LightningModule(pl.LightningModule):
"""
Basic wrapper of generated model.
Lightning modules used in NNI should inherit this class.
It's a subclass of ``pytorch_lightning.LightningModule``.
See https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html
"""
running_mode: Literal['multi', 'oneshot'] = 'multi'
"""An indicator of whether current module is running in a multi-trial experiment or an one-shot.
This flag should be automatically set by experiments when they start to run.
"""
def set_model(self, model: Union[Callable[[], nn.Module], nn.Module]) -> None:
"""Set the inner model (architecture) to train / evaluate.
Parameters
----------
model : callable or nn.Module
Can be a callable returning nn.Module or nn.Module.
"""
if isinstance(model, nn.Module):
self.model = model
else:
self.model = model()
Trainer = nni.trace(pl.Trainer)
Trainer.__doc__ = """
Traced version of ``pytorch_lightning.Trainer``. See https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
"""
DataLoader = nni.trace(torch_data.DataLoader)
DataLoader.__doc__ = """
Traced version of ``torch.utils.data.DataLoader``. See https://pytorch.org/docs/stable/data.html
"""
@nni.trace
class Lightning(Evaluator):
"""
Delegate the whole training to PyTorch Lightning.
Since the arguments passed to the initialization need to be serialized, ``LightningModule``, ``Trainer`` or
``DataLoader`` in this file should be used. Another option is to hide the dataloader in the Lightning module, in
which case dataloaders are not required for this class to work.
Following the programming style of Lightning, metrics sent to NNI should be obtained from ``callback_metrics``
in the trainer. Two hooks are added at the end of the validation epoch and the end of ``fit``, respectively. The metric name
and type depend on the specific task.
.. warning::
The Lightning evaluator is stateful. If you reuse a previous Lightning evaluator,
please note that the inner ``lightning_module`` and ``trainer`` will be reused.
Parameters
----------
lightning_module
Lightning module that defines the training logic.
trainer
Lightning trainer that handles the training.
train_dataloaders
Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
It can be `any types of dataloader supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
val_dataloaders
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
It can be `any types of dataloader supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
"""
def __init__(self, lightning_module: LightningModule, trainer: Trainer,
train_dataloaders: Optional[Any] = None,
val_dataloaders: Optional[Any] = None,
train_dataloader: Optional[Any] = None):
assert isinstance(lightning_module, LightningModule), f'Lightning module must be an instance of {__name__}.LightningModule.'
if train_dataloader is not None:
warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
train_dataloaders = train_dataloader
if cgo_import_failed:
assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), f'Trainer must be imported from {__name__}'
else:
# this is not isinstance(trainer, Trainer) because with a different trace call, it can be different
assert (isinstance(trainer, pl.Trainer) and is_traceable(trainer)) or isinstance(trainer, cgo_trainer.Trainer), \
f'Trainer must be imported from {__name__} or nni.retiarii.evaluator.pytorch.cgo.trainer'
if not _check_dataloader(train_dataloaders):
warnings.warn(f'Please try to wrap PyTorch DataLoader with nni.trace or '
f'import DataLoader from {__name__}: {train_dataloaders}',
RuntimeWarning)
if not _check_dataloader(val_dataloaders):
warnings.warn(f'Please try to wrap PyTorch DataLoader with nni.trace or '
f'import DataLoader from {__name__}: {val_dataloaders}',
RuntimeWarning)
self.module = lightning_module
self.trainer = trainer
self.train_dataloaders = train_dataloaders
self.val_dataloaders = val_dataloaders
@staticmethod
def _load(ir):
return Lightning(ir['module'], ir['trainer'], ir['train_dataloaders'], ir['val_dataloaders'])
def _dump(self):
return {
'type': self.__class__,
'module': self.module,
'trainer': self.trainer,
'train_dataloaders': self.train_dataloaders,
'val_dataloaders': self.val_dataloaders
}
def _execute(self, model_cls):
return self.fit(model_cls)
@property
def train_dataloader(self):
warnings.warn('train_dataloader is deprecated, please use `train_dataloaders`.', DeprecationWarning)
return self.train_dataloaders
def __eq__(self, other):
eq_func = False
eq_args = False
if other is None:
return False
if hasattr(self, "function") and hasattr(other, "function"):
eq_func = getattr(self, "function") == getattr(other, "function")
elif not (hasattr(self, "function") or hasattr(other, "function")):
eq_func = True
if hasattr(self, "arguments") and hasattr(other, "arguments"):
eq_args = getattr(self, "arguments") == getattr(other, "arguments")
elif not (hasattr(self, "arguments") or hasattr(other, "arguments")):
eq_args = True
return eq_func and eq_args
def fit(self, model):
"""
Fit the model with provided dataloader, with Lightning trainer.
Parameters
----------
model : nn.Module
The model to fit.
"""
self.module.set_model(model)
return self.trainer.fit(self.module, self.train_dataloaders, self.val_dataloaders)
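# Illustrative usage sketch (not part of this diff). `MyLightningModule` and the datasets are
# hypothetical; the key point is that the module must subclass the `LightningModule` defined in
# this file, and that the traced `Trainer` / `DataLoader` defined above should be used so the
# whole evaluator can be serialized.
def _lightning_evaluator_example(MyLightningModule, train_dataset, valid_dataset):
    evaluator = Lightning(
        MyLightningModule(),                        # subclass of the LightningModule above
        Trainer(max_epochs=10),                     # traced Trainer
        train_dataloaders=DataLoader(train_dataset, batch_size=64, shuffle=True),
        val_dataloaders=DataLoader(valid_dataset, batch_size=64),
    )
    # The framework later calls `evaluator.fit(model)`, which injects the sampled
    # architecture via `set_model` and delegates to `trainer.fit(...)`.
    return evaluator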
def _check_dataloader(dataloader):
# Check the type of dataloader recursively.
if isinstance(dataloader, list):
return all([_check_dataloader(d) for d in dataloader])
if isinstance(dataloader, dict):
return all([_check_dataloader(v) for v in dataloader.values()])
if isinstance(dataloader, torch_data.DataLoader):
return is_traceable(dataloader)
return True
### The following are some commonly used Lightning modules ###
class _SupervisedLearningModule(LightningModule):
trainer: pl.Trainer
def __init__(self, criterion: Type[nn.Module], metrics: Dict[str, Type[torchmetrics.Metric]],
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: Type[optim.Optimizer] = optim.Adam,
export_onnx: Union[Path, str, bool, None] = None):
super().__init__()
self.save_hyperparameters('criterion', 'optimizer', 'learning_rate', 'weight_decay')
self.criterion = criterion()
self.optimizer = optimizer
self.metrics = nn.ModuleDict({name: cls() for name, cls in metrics.items()})
if export_onnx is None or export_onnx is True:
self.export_onnx = Path(os.environ.get('NNI_OUTPUT_DIR', '.')) / 'model.onnx'
elif export_onnx:
self.export_onnx = Path(export_onnx)
else:
self.export_onnx = None
def forward(self, x):
y_hat = self.model(x)
return y_hat
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = self.criterion(y_hat, y)
self.log('train_loss', loss, prog_bar=True)
for name, metric in self.metrics.items():
self.log('train_' + name, metric(y_hat, y), prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
if self.running_mode == 'multi' and self.export_onnx is not None:
self.export_onnx.parent.mkdir(exist_ok=True)
try:
self.to_onnx(self.export_onnx, x, export_params=True)
except RuntimeError as e:
warnings.warn(f'ONNX conversion failed. As a result, you might not be able to use visualization. Error message: {e}')
self.export_onnx = None
self.log('val_loss', self.criterion(y_hat, y), prog_bar=True)
for name, metric in self.metrics.items():
self.log('val_' + name, metric(y_hat, y), prog_bar=True)
def test_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
self.log('test_loss', self.criterion(y_hat, y), prog_bar=True)
for name, metric in self.metrics.items():
self.log('test_' + name, metric(y_hat, y), prog_bar=True)
def configure_optimizers(self):
return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay) # type: ignore
def on_validation_epoch_end(self):
if not self.trainer.sanity_checking and self.running_mode == 'multi':
# Don't report metric when sanity checking
nni.report_intermediate_result(self._get_validation_metrics())
def on_fit_end(self):
if self.running_mode == 'multi':
nni.report_final_result(self._get_validation_metrics())
def _get_validation_metrics(self):
if len(self.metrics) == 1:
metric_name = next(iter(self.metrics))
return self.trainer.callback_metrics['val_' + metric_name].item()
else:
warnings.warn('Multiple metrics without "default" is not supported by current framework.')
return {name: self.trainer.callback_metrics['val_' + name].item() for name in self.metrics}
class _AccuracyWithLogits(torchmetrics.Accuracy):
def update(self, pred, target):
return super().update(nn_functional.softmax(pred, dim=-1), target)
@nni.trace
class _ClassificationModule(_SupervisedLearningModule):
def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: Type[optim.Optimizer] = optim.Adam,
export_onnx: bool = True):
super().__init__(criterion, {'acc': _AccuracyWithLogits},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
export_onnx=export_onnx)
class Classification(Lightning):
"""
Evaluator that is used for classification.
Parameters
----------
criterion : nn.Module
Class for criterion module (not an instance). default: ``nn.CrossEntropyLoss``
learning_rate : float
Learning rate. default: 0.001
weight_decay : float
L2 weight decay. default: 0
optimizer : Optimizer
Class for optimizer (not an instance). default: ``Adam``
train_dataloaders : DataLoader
Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
val_dataloaders : DataLoader or List of DataLoader
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
export_onnx : bool
If true, the model will be exported to ``model.onnx`` before training starts. default: true
trainer_kwargs : dict
Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
Examples
--------
>>> evaluator = Classification()
To use customized criterion and optimizer:
>>> evaluator = Classification(nn.LabelSmoothingCrossEntropy, optimizer=torch.optim.SGD)
Extra keyword arguments will be passed to trainer, some of which might be necessary to enable GPU acceleration:
>>> evaluator = Classification(accelerator='gpu', devices=2, strategy='ddp')
"""
def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: Type[optim.Optimizer] = optim.Adam,
train_dataloaders: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
export_onnx: bool = True,
train_dataloader: Optional[DataLoader] = None,
**trainer_kwargs):
if train_dataloader is not None:
warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
train_dataloaders = train_dataloader
module = _ClassificationModule(criterion=criterion, learning_rate=learning_rate,
weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
super().__init__(module, Trainer(**trainer_kwargs),
train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)
@nni.trace
class _RegressionModule(_SupervisedLearningModule):
def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: Type[optim.Optimizer] = optim.Adam,
export_onnx: bool = True):
super().__init__(criterion, {'mse': torchmetrics.MeanSquaredError},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
export_onnx=export_onnx)
class Regression(Lightning):
"""
Evaluator that is used for regression.
Parameters
----------
criterion : nn.Module
Class for criterion module (not an instance). default: ``nn.MSELoss``
learning_rate : float
Learning rate. default: 0.001
weight_decay : float
L2 weight decay. default: 0
optimizer : Optimizer
Class for optimizer (not an instance). default: ``Adam``
train_dataloaders : DataLoader
Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
val_dataloaders : DataLoader or List of DataLoader
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
export_onnx : bool
If true, model will be exported to ``model.onnx`` before training starts. default: true
trainer_kwargs : dict
Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
Examples
--------
>>> evaluator = Regression()
Extra keyword arguments will be passed to trainer, some of which might be necessary to enable GPU acceleration:
>>> evaluator = Regression(accelerator='gpu', devices=1)
"""
def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: Type[optim.Optimizer] = optim.Adam,
train_dataloaders: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
export_onnx: bool = True,
train_dataloader: Optional[DataLoader] = None,
**trainer_kwargs):
if train_dataloader is not None:
warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
train_dataloaders = train_dataloader
module = _RegressionModule(criterion=criterion, learning_rate=learning_rate,
weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
super().__init__(module, Trainer(**trainer_kwargs),
train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)
from nni.nas.evaluator.pytorch.lightning import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import time
import warnings
from typing import Iterable
# pylint: disable=wildcard-import,unused-wildcard-import
from ..graph import Model, ModelStatus
from .interface import AbstractExecutionEngine
from .listener import DefaultListener
_execution_engine = None
_default_listener = None
__all__ = ['get_execution_engine', 'get_and_register_default_listener',
'list_models', 'submit_models', 'wait_models', 'query_available_resources',
'set_execution_engine', 'is_stopped_exec', 'budget_exhausted']
def set_execution_engine(engine: AbstractExecutionEngine) -> None:
global _execution_engine
if _execution_engine is not None:
warnings.warn('Execution engine is already set. '
'You should avoid instantiating RetiariiExperiment twice in one process. '
'If you are running in a Jupyter notebook, please restart the kernel.',
RuntimeWarning)
_execution_engine = engine
def get_execution_engine() -> AbstractExecutionEngine:
global _execution_engine
assert _execution_engine is not None, 'You need to set the execution engine before using it.'
return _execution_engine
def get_and_register_default_listener(engine: AbstractExecutionEngine) -> DefaultListener:
global _default_listener
if _default_listener is None:
_default_listener = DefaultListener()
engine.register_graph_listener(_default_listener)
return _default_listener
def submit_models(*models: Model) -> None:
engine = get_execution_engine()
get_and_register_default_listener(engine)
engine.submit_models(*models)
def list_models(*models: Model) -> Iterable[Model]:
engine = get_execution_engine()
get_and_register_default_listener(engine)
return engine.list_models()
def wait_models(*models: Model) -> None:
get_and_register_default_listener(get_execution_engine())
while True:
time.sleep(1)
left_models = [g for g in models if g.status not in (ModelStatus.Trained, ModelStatus.Failed)]
if not left_models:
break
def query_available_resources() -> int:
engine = get_execution_engine()
resources = engine.query_available_resource()
return resources if isinstance(resources, int) else len(resources)
def is_stopped_exec(model: Model) -> bool:
return model.status in (ModelStatus.Trained, ModelStatus.Failed)
def budget_exhausted() -> bool:
engine = get_execution_engine()
return engine.budget_exhausted()
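# Illustrative sketch (not part of this diff) of a minimal synchronous loop built on this API.
# `next_model` is a hypothetical callable producing `Model` objects to evaluate; a real
# strategy would typically submit several models per iteration.
def _strategy_loop_example(next_model):
    results = []
    while not budget_exhausted():
        if query_available_resources() <= 0:
            time.sleep(5)
            continue
        model = next_model()
        submit_models(model)
        wait_models(model)                 # blocks until the model is Trained or Failed
        results.append(model.metric)       # final metric is attached by the default listener
    return results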
from nni.nas.execution.api import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
# pylint: disable=wildcard-import,unused-wildcard-import
import logging
import os
import random
import string
from typing import Any, Dict, Iterable, List
from nni.experiment import rest
from nni.retiarii.integration import RetiariiAdvisor
from .interface import AbstractExecutionEngine, AbstractGraphListener
from .utils import get_mutation_summary
from .. import codegen, utils
from ..graph import Model, ModelStatus, MetricData, Evaluator
from ..integration_api import send_trial, receive_trial_parameters, get_advisor
_logger = logging.getLogger(__name__)
class BaseGraphData:
"""
Data sent between strategy and trial, in graph-based execution engine.
Attributes
----------
model_script
code of an instantiated PyTorch model
evaluator
training approach for model_script
mutation_summary
a dict of all the choices during mutations in the HPO search space format
"""
def __init__(self, model_script: str, evaluator: Evaluator, mutation_summary: dict) -> None:
self.model_script = model_script
self.evaluator = evaluator
self.mutation_summary = mutation_summary
def dump(self) -> dict:
return {
'model_script': self.model_script,
# engine needs to call dump here,
# otherwise, evaluator will become binary
# also, evaluator can be none in tests
'evaluator': self.evaluator._dump() if self.evaluator is not None else None,
'mutation_summary': self.mutation_summary
}
@staticmethod
def load(data) -> 'BaseGraphData':
return BaseGraphData(data['model_script'], Evaluator._load(data['evaluator']), data['mutation_summary'])
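# Illustrative sketch (not part of this diff): `dump()` / `load()` round-trip the payload that is
# sent to the trial as its "hyper-parameter". The model script string and mutation summary below
# are hypothetical placeholders; `evaluator` must be an `Evaluator` implementing `_dump()` / `_load()`.
def _graph_data_roundtrip_example(evaluator: Evaluator):
    data = BaseGraphData('# generated model code ...', evaluator, {'layer1': 'conv3x3'})
    payload = data.dump()               # plain dict: model_script, dumped evaluator, mutation_summary
    return BaseGraphData.load(payload)  # reconstructs the evaluator via Evaluator._load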
class BaseExecutionEngine(AbstractExecutionEngine):
"""
The execution engine with no optimization at all.
Resource management is implemented in this class.
"""
def __init__(self, rest_port: int | None = None, rest_url_prefix: str | None = None) -> None:
"""
Upon initialization, advisor callbacks need to be registered.
Advisor will call the callbacks when the corresponding event has been triggered.
Base execution engine will get those callbacks and broadcast them to graph listener.
Parameters
----------
rest_port
The port of the experiment's rest server
rest_url_prefix
The url prefix of the experiment's rest entry
"""
self.port = rest_port
self.url_prefix = rest_url_prefix
self._listeners: List[AbstractGraphListener] = []
self._running_models: Dict[int, Model] = dict()
self._history: List[Model] = []
self.resources = 0
# register advisor callbacks
advisor: RetiariiAdvisor = get_advisor()
advisor.register_callbacks({
'send_trial': self._send_trial_callback,
'request_trial_jobs': self._request_trial_jobs_callback,
'trial_end': self._trial_end_callback,
'intermediate_metric': self._intermediate_metric_callback,
'final_metric': self._final_metric_callback
})
def submit_models(self, *models: Model) -> None:
for model in models:
data = self.pack_model_data(model)
self._running_models[send_trial(data.dump())] = model
self._history.append(model)
def list_models(self) -> Iterable[Model]:
return self._history
def register_graph_listener(self, listener: AbstractGraphListener) -> None:
self._listeners.append(listener)
def _send_trial_callback(self, parameter: dict) -> None:
if self.resources <= 0:
# FIXME: should be a warning message here
_logger.debug('There is no available resource, but trial is submitted.')
self.resources -= 1
_logger.debug('Resource used. Remaining: %d', self.resources)
def _request_trial_jobs_callback(self, num_trials: int) -> None:
self.resources += num_trials
_logger.debug('New resource available. Remaining: %d', self.resources)
def _trial_end_callback(self, trial_id: int, success: bool) -> None:
model = self._running_models[trial_id]
if success:
model.status = ModelStatus.Trained
else:
model.status = ModelStatus.Failed
for listener in self._listeners:
listener.on_training_end(model, success)
def _intermediate_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
model = self._running_models[trial_id]
model.intermediate_metrics.append(metrics)
for listener in self._listeners:
listener.on_intermediate_metric(model, metrics)
def _final_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
model = self._running_models[trial_id]
model.metric = metrics
for listener in self._listeners:
listener.on_metric(model, metrics)
def query_available_resource(self) -> int:
return self.resources
def budget_exhausted(self) -> bool:
resp = rest.get(self.port, '/check-status', self.url_prefix)
return resp['status'] == 'DONE'
@classmethod
def pack_model_data(cls, model: Model) -> Any:
mutation_summary = get_mutation_summary(model)
assert model.evaluator is not None, 'Model evaluator can not be None'
return BaseGraphData(codegen.pytorch.model_to_pytorch_script(model), model.evaluator, mutation_summary) # type: ignore
@classmethod
def trial_execute_graph(cls) -> None:
"""
Initialize the model, hand it over to trainer.
"""
graph_data = BaseGraphData.load(receive_trial_parameters())
random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6))
file_name = f'_generated_model/{random_str}.py'
os.makedirs(os.path.dirname(file_name), exist_ok=True)
with open(file_name, 'w') as f:
f.write(graph_data.model_script)
model_cls = utils.import_(f'_generated_model.{random_str}._model')
graph_data.evaluator._execute(model_cls)
os.remove(file_name)
from nni.nas.execution.pytorch.graph import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import random
from typing import Dict, Any, List, Optional, Union, Tuple, Callable, Iterable, cast
# pylint: disable=wildcard-import,unused-wildcard-import
from ..graph import Model
from ..integration_api import receive_trial_parameters
from .base import BaseExecutionEngine
from .utils import get_mutation_dict
class BenchmarkGraphData:
SUPPORTED_BENCHMARK_LIST = [
'nasbench101',
'nasbench201-cifar10',
'nasbench201-cifar100',
'nasbench201-imagenet16',
'nds-cifar10',
'nds-imagenet',
'nlp'
]
def __init__(self, mutation: Dict[str, Any], benchmark: str,
metric_name: Optional[str] = None,
db_path: Optional[str] = None) -> None:
self.mutation = mutation          # mutation dict, e.g., {'layer1': 'conv3x3', ...}
self.benchmark = benchmark        # e.g., nasbench101, nasbench201, ...
self.metric_name = metric_name    # which metric to query; None for the default
self.db_path = db_path            # path to the directory of the benchmark database
def dump(self) -> dict:
from nni.nas.benchmarks.constants import DATABASE_DIR
return {
'mutation': self.mutation,
'benchmark': self.benchmark,
'metric_name': self.metric_name,  # needed by `load` on the worker side
'db_path': self.db_path or DATABASE_DIR  # database path needs to be passed from manager to worker
}
@staticmethod
def load(data) -> 'BenchmarkGraphData':
return BenchmarkGraphData(data['mutation'], data['benchmark'], data['metric_name'], data['db_path'])
def __repr__(self) -> str:
return f"BenchmarkGraphData({self.mutation}, {self.benchmark}, {self.db_path})"
class BenchmarkExecutionEngine(BaseExecutionEngine):
"""
Execution engine that does not actually run any trial, but queries the database for results.
The database query is done at trial end to make sure intermediate metrics are available.
It will also support an accelerated mode that returns metrics immediately without even going through the NNI manager
(not implemented yet).
"""
def __init__(self, benchmark: Union[str, Callable[[BenchmarkGraphData], Tuple[float, List[float]]]], acceleration: bool = False):
super().__init__()
if isinstance(benchmark, str):
    assert benchmark in BenchmarkGraphData.SUPPORTED_BENCHMARK_LIST, \
        f'{benchmark} is not one of the supported benchmarks: {BenchmarkGraphData.SUPPORTED_BENCHMARK_LIST}'
self.benchmark = benchmark
self.acceleration = acceleration
def pack_model_data(self, model: Model) -> Any:
# called when a new model is submitted to backend.
# convert a Model into a data that is acceptable by trial end.
mutation = get_mutation_dict(model)
graph_data = BenchmarkGraphData(mutation, self.benchmark)
return graph_data
@classmethod
def trial_execute_graph(cls) -> None:
graph_data = BenchmarkGraphData.load(receive_trial_parameters())
assert graph_data.db_path is not None, f'Invalid graph data because db_path is None: {graph_data}'
os.environ['NASBENCHMARK_DIR'] = graph_data.db_path
final, intermediates = cls.query_in_benchmark(graph_data)
import nni
for i in intermediates:
nni.report_intermediate_result(i)
nni.report_final_result(final)
@staticmethod
def query_in_benchmark(graph_data: BenchmarkGraphData) -> Tuple[float, List[float]]:
if not isinstance(graph_data.benchmark, str):
return graph_data.benchmark(graph_data)
# built-in benchmarks with default query setting
if graph_data.benchmark == 'nasbench101':
from nni.nas.benchmarks.nasbench101 import query_nb101_trial_stats
arch = None
for t in graph_data.mutation.values():
if isinstance(t, dict):
arch = t
if arch is None:
raise ValueError(f'Cannot identify architecture from mutation dict: {graph_data.mutation}')
return _convert_to_final_and_intermediates(
query_nb101_trial_stats(arch, 108, include_intermediates=True),
'valid_acc'
)
elif graph_data.benchmark.startswith('nasbench201'):
from nni.nas.benchmarks.nasbench201 import query_nb201_trial_stats
dataset = graph_data.benchmark.split('-')[-1]
return _convert_to_final_and_intermediates(
query_nb201_trial_stats(_flatten_architecture(graph_data.mutation), 200, dataset, include_intermediates=True),
'valid_acc',
)
elif graph_data.benchmark.startswith('nds'):
# FIXME: not tested yet
from nni.nas.benchmarks.nds import query_nds_trial_stats
dataset = graph_data.benchmark.split('-')[-1]
return _convert_to_final_and_intermediates(
query_nds_trial_stats(None, None, None, None, _flatten_architecture(graph_data.mutation),
dataset, include_intermediates=True),
'valid_acc'
)
elif graph_data.benchmark.startswith('nlp'):
# FIXME: not tested yet
from nni.nas.benchmarks.nlp import query_nlp_trial_stats
# TODO: unsure of the available datasets in this benchmark; the docs are missing.
return _convert_to_final_and_intermediates(
query_nlp_trial_stats(_flatten_architecture(graph_data.mutation), 'ptb', include_intermediates=True),
'valid_acc'
)
else:
raise ValueError(f'{graph_data.benchmark} is not a supported benchmark.')
def _flatten_architecture(mutation: Dict[str, Any], benchmark: Optional[str] = None):
# STRONG ASSUMPTION HERE!
# This assumes that the benchmarked search space is a one-level search space.
# This means that it is either ONE cell or ONE network.
# A two-cell search space like NDS is not supported for now.
# Some benchmarks even need special handling to pop out invalid keys; this is not a good design.
# support double underscore to be compatible with naming convention in base engine
ret = {k.split('/')[-1].split('__')[-1]: v for k, v in mutation.items()}
if benchmark == 'nasbench101':
ret = {k: v for k, v in ret.items() if k.startswith('op') or k.startswith('input')}
ret = {k: v if k.startswith('op') or isinstance(v, list) else [v] for k, v in ret.items()}
return ret
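# Illustrative sketch (not part of this diff): how the key flattening behaves. The mutation keys
# below are hypothetical entries in the "path/name__suffix" style handled above.
def _flatten_architecture_example():
    mutation = {'model/cell__op_1': 'conv3x3', 'model/cell__input_1': 2}
    # Prefixes are stripped; for nasbench101, keys not starting with 'op' or 'input' are dropped
    # and scalar non-'op' values are wrapped in a list.
    return _flatten_architecture(mutation, benchmark='nasbench101')   # -> {'op_1': 'conv3x3', 'input_1': [2]}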
def _convert_to_final_and_intermediates(benchmark_result: Iterable[Any], metric_name: str) -> Tuple[float, List[float]]:
# convert benchmark results from database to
# final result (float) and intermediate results (list of floats)
benchmark_result = list(benchmark_result)
assert len(benchmark_result) > 0, 'Invalid query. Results from benchmark are empty.'
if len(benchmark_result) > 1:
benchmark_result = random.choice(benchmark_result)
else:
benchmark_result = benchmark_result[0]
benchmark_result = cast(dict, benchmark_result)
return benchmark_result[metric_name], [i[metric_name] for i in benchmark_result['intermediates'] if i[metric_name] is not None]
from nni.nas.execution.pytorch.benchmark import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
# pylint: disable=wildcard-import,unused-wildcard-import
import logging
import os
import random
import string
import time
import threading
from typing import Iterable, List, Dict, Tuple, cast
from dataclasses import dataclass
from nni.common.device import GPUDevice, Device
from nni.experiment.config.training_services import RemoteConfig
from nni.retiarii.integration import RetiariiAdvisor
from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
from .. import codegen, utils
from ..graph import Model, ModelStatus, MetricData, Node
from ..integration_api import send_trial, receive_trial_parameters, get_advisor
from .logical_optimizer.logical_plan import LogicalPlan, AbstractLogicalNode
from .logical_optimizer.opt_dedup_input import DedupInputOptimizer
from ..evaluator.pytorch.lightning import Lightning
from ..evaluator.pytorch.cgo.evaluator import _MultiModelSupervisedLearningModule
from .base import BaseGraphData
_logger = logging.getLogger(__name__)
def _noop(*args, **kwargs):
pass
@dataclass
class TrialSubmission:
model: Model
placement: Dict[Node, Device]
grouped_models: List[Model]
class CGOExecutionEngine(AbstractExecutionEngine):
"""
The execution engine with Cross-Graph Optimization (CGO).
Only models using PyTorch Lightning and MultiModelSupervisedLearningModule as the evaluator can be optimized.
Otherwise, a model will be submitted independently without any cross-graph optimization.
Parameters
----------
training_service
The remote training service config.
max_concurrency
The maximum number of trials to run concurrently.
batch_waiting_time
Seconds to wait for each batch of trial submission.
The trials within one batch could apply cross-graph optimization.
rest_port
The port of the experiment's rest server
rest_url_prefix
The url prefix of the experiment's rest entry
"""
def __init__(self, training_service: RemoteConfig,
max_concurrency: int | None = None,
batch_waiting_time: int = 60,
rest_port: int | None = None,
rest_url_prefix: str | None = None
) -> None:
self.port = rest_port
self.url_prefix = rest_url_prefix
self._listeners: List[AbstractGraphListener] = []
self._running_models: Dict[int, Model] = dict()
self.logical_plan_counter = 0
self.available_devices: List[Device] = []
self.max_concurrency: int = max_concurrency
devices = self._construct_devices(training_service)
for device in devices:
self.available_devices.append(device)
self.all_devices = self.available_devices.copy()
self._batch_waiting_time = batch_waiting_time # seconds to wait for all models in a batch to do cross-graph optimization
self._optimizers = [DedupInputOptimizer()]
self._original_models = {}
self._original_model_to_multi_model = {}
self._trial_to_original_models = {}
self._trial_used_devices: Dict[int, List[Device]] = {}
self._history: List[Model] = []
self._queuing_models: List[Model] = []
self._models_to_retry: List[Model] = []
self._queue_lock = threading.Lock()
# register advisor callbacks
advisor: RetiariiAdvisor = get_advisor()
advisor.register_callbacks({
'send_trial': _noop,
'request_trial_jobs': _noop,
'trial_end': self._trial_end_callback,
'intermediate_metric': self._intermediate_metric_callback,
'final_metric': self._final_metric_callback
})
self._stopped = False
self._consumer_thread = threading.Thread(target=self._consume_models)
self._consumer_thread.start()
def _construct_devices(self, training_service):
devices = []
if hasattr(training_service, 'machine_list'):
for machine in cast(RemoteConfig, training_service).machine_list:
assert machine.gpu_indices is not None, \
'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
assert isinstance(machine.gpu_indices, list), 'gpu_indices must be a list'
for gpu_idx in machine.gpu_indices:
devices.append(GPUDevice(machine.host, gpu_idx))
return devices
def join(self):
self._stopped = True
self._consumer_thread.join()
def add_optimizer(self, opt):
self._optimizers.append(opt)
def submit_models(self, *models: Model) -> None:
curr_time = time.time()
_logger.info('%d models are submitted', len(models))
self._queue_lock.acquire()
self._queuing_models.extend([(curr_time, _) for _ in models])
self._queue_lock.release()
def _submit_retry_models(self, models: List[Model]) -> None:
_logger.info('%d models are retried', len(models))
self._queue_lock.acquire()
self._models_to_retry.extend(models)
self._queue_lock.release()
def _consume_models(self):
# a thread to monitor self._models_to_retry and self._queuing_models to consume them in batch
while not self._stopped:
if len(self._models_to_retry) > 0:
self._queue_lock.acquire()
# retrying jobs should be first scheduled.
for m in self._models_to_retry:
if len(self.available_devices) > 0:
self._submit_models_in_batch(m) # submit the single model to avoid cross-graph optimization.
self._models_to_retry = self._models_to_retry[1:]
self._queue_lock.release()
if len(self._queuing_models) > 0:
self._queue_lock.acquire()
curr_time = time.time()
num_models_to_submit = len(self.available_devices)
if self.max_concurrency:
num_models_to_submit = min(num_models_to_submit, self.max_concurrency)
if curr_time - self._queuing_models[0][0] > self._batch_waiting_time:
num_models_to_submit = min(num_models_to_submit, len(self._queuing_models))
if num_models_to_submit > 0:
self._submit_models_in_batch(*[_[1] for _ in self._queuing_models[:num_models_to_submit]])
self._queuing_models = self._queuing_models[num_models_to_submit:]
self._queue_lock.release()
time.sleep(1)
def _extract_placement_constaint(self, placement_mapping: Dict[Node, Device]):
unique_gpus = sorted(list(set([e for e in placement_mapping.values() if isinstance(e, GPUDevice)])))
placement_constraint = None
if len(unique_gpus) > 0:
placement_constraint = {}
placement_constraint['type'] = 'Device'
placement_constraint['gpus'] = [(e.node_id, e.gpu_id) for e in unique_gpus]
return placement_constraint
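# Illustrative note (not part of this diff): for a merged model placed on two GPUs of a
# hypothetical host 'worker-0', the constraint returned above would look like
#     {'type': 'Device', 'gpus': [('worker-0', 0), ('worker-0', 1)]}
# and None is returned when no GPUDevice appears in the placement mapping.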
def _submit_models_in_batch(self, *models: Model) -> None:
_logger.info('%d models are submitted in batch', len(models))
_logger.debug('model id: %s', str([m.model_id for m in models]))
logical = self._build_logical(models)
for opt in self._optimizers:
opt.convert(logical)
phy_models_and_placements = self._assemble(logical)
for model, placement, grouped_models in phy_models_and_placements:
data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement), model.evaluator, {})
placement_constraint = self._extract_placement_constaint(placement)
trial_id = send_trial(data.dump(), placement_constraint=placement_constraint)
# unique non-cpu devices used by the trial
self._trial_used_devices[trial_id] = list(set([_ for _ in placement.values() if isinstance(_, GPUDevice)]))
# currently, it is impossible for search strategy to submit models more than the number of available devices
for used_device in self._trial_used_devices[trial_id]:
self.available_devices.remove(used_device) # used_device must be in self.available_devices
self._running_models[trial_id] = model
self._trial_to_original_models[trial_id] = []
for m in grouped_models:
self._original_models[m.model_id] = m
self._original_model_to_multi_model[m.model_id] = model
self._trial_to_original_models[trial_id].append(m.model_id)
self._history.append(m)
def list_models(self) -> Iterable[Model]:
return self._history
def _assemble(self, logical_plan: LogicalPlan) -> List[Tuple[Model, Dict[Node, Device], List[Model]]]:
"""
Return the assembled models as a list of tuple.
Each tuple contains the assembled model, the device placement of graph nodes, and the original models.
"""
# try to use the available_devices first so that it can be launched as early as possible
# if free devices are not enough to assemble all models in one trial, try all devices
if len(self.available_devices) > 0:
grouped_models: List[Dict[Model, Device]] = AssemblePolicy().group(logical_plan, self.available_devices)
if len(self.available_devices) == 0 or len(grouped_models) > 1:
grouped_models: List[Dict[Model, Device]] = AssemblePolicy().group(logical_plan, self.all_devices)
phy_models_and_placements = []
for multi_model in grouped_models:
model, model_placement = logical_plan.assemble(multi_model)
assert isinstance(model.evaluator, Lightning), \
"cross-graph optimization only supports PyTorch Lightning as the evaluator"
assert isinstance(model.evaluator.module, _MultiModelSupervisedLearningModule), \
"cross-graph optimization only supports MultiModelSupervisedLearningModule"
# replace the module with a new instance whose n_models is set
# n_models must be set in __init__, otherwise it cannot be captured by serialize_cls
new_module_init_params = model.evaluator.module.dump_kwargs().copy()
# MultiModelSupervisedLearningModule hides n_models of _MultiModelSupervisedLearningModule from users
new_module_init_params['n_models'] = len(multi_model)
new_module = _MultiModelSupervisedLearningModule(**new_module_init_params)
model.evaluator.module = new_module
phy_models_and_placements.append((model, model_placement, multi_model.keys()))
return phy_models_and_placements
def _build_logical(self, models: List[Model]) -> LogicalPlan:
logical_plan = LogicalPlan(plan_id=self.logical_plan_counter)
for model in models:
logical_plan.add_model(model)
self.logical_plan_counter += 1
return logical_plan
def register_graph_listener(self, listener: AbstractGraphListener) -> None:
self._listeners.append(listener)
# def _send_trial_callback(self, paramater: dict) -> None:
# if len(self.available_devices) == 0:
# _logger.warning('There is no available devices, but trial is submitted.')
# _logger.debug('Resource used. Remaining: %d', len(self.available_devices))
# def _request_trial_jobs_callback(self, num_trials: int) -> None:
# self.resources += num_trials
# _logger.info('on_resource_available: %d', self.resources)
def _trial_end_callback(self, trial_id: int, success: bool) -> None:
model = self._running_models[trial_id]
if success:
model.status = ModelStatus.Trained
else:
model.status = ModelStatus.Failed
models_to_retry = []
for model_id in self._original_model_to_multi_model:
if self._original_model_to_multi_model[model_id] == model:
original_model = self._original_models[model_id]
if success:
original_model.status = ModelStatus.Trained
else:
original_model.status = ModelStatus.Failed
# the failed models in a multi-model will be retried one by one w/o CGO
if len(self._trial_to_original_models[trial_id]) > 1:
models_to_retry.append(original_model)
for listener in self._listeners:
listener.on_training_end(original_model, success)
if len(models_to_retry) > 0:
self._submit_retry_models(models_to_retry)
self.available_devices.extend(self._trial_used_devices[trial_id])
self.available_devices = sorted(list(set(self.available_devices)))
del self._running_models[trial_id]
def _intermediate_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
merged_metrics = {}
for idx, _ in enumerate(metrics):
merged_metrics[self._trial_to_original_models[trial_id][idx]] = metrics[idx]
for model_id in merged_metrics:
self._original_models[model_id].intermediate_metrics.append(merged_metrics[model_id])
for listener in self._listeners:
listener.on_intermediate_metric(self._original_models[model_id], merged_metrics[model_id])
def _final_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
_logger.debug(metrics)
if isinstance(metrics, float):
self._listeners[0].on_metric(self._running_models[trial_id], metrics)
else:
merged_metrics = {}
for idx, _ in enumerate(metrics):
merged_metrics[self._trial_to_original_models[trial_id][idx]] = metrics[idx]
for model_id in merged_metrics:
self._original_models[model_id].metric = merged_metrics[model_id]
for listener in self._listeners:
listener.on_metric(self._original_models[model_id], merged_metrics[model_id])
def query_available_resource(self) -> List[WorkerInfo]:
# the _queuing_models need to use available_devices first
self._queue_lock.acquire()
available_for_more_models = len(self.available_devices) - len(self._queuing_models) - len(self._models_to_retry)
self._queue_lock.release()
return available_for_more_models
def budget_exhausted(self) -> bool:
advisor = get_advisor()
return advisor.stopping
@classmethod
def trial_execute_graph(cls) -> None:
"""
Initialize the model, hand it over to trainer.
"""
graph_data = BaseGraphData.load(receive_trial_parameters())
_logger.info('CGO_ENGINE trial parameters received')
random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6))
file_name = f'_generated_model/{random_str}.py'
os.makedirs(os.path.dirname(file_name), exist_ok=True)
with open(file_name, 'w') as f:
f.write(graph_data.model_script)
trainer_instance = graph_data.evaluator
model_cls = utils.import_(f'_generated_model.{random_str}._model')
trainer_instance.fit(model_cls())
os.remove(file_name)
class AssemblePolicy:
@staticmethod
def _is_related_node(model: Model, node: Node):
if isinstance(node, AbstractLogicalNode):
if model in node.related_models:
return True
else:
if model == node.graph.model:
return True
return False
@staticmethod
def _check_graph_connectivity(model: Model,
group_model: Dict[Model, Device],
logical_plan: LogicalPlan) -> bool:
for edge in logical_plan.logical_graph.edges:
if AssemblePolicy._is_related_node(model, edge.head) or \
AssemblePolicy._is_related_node(model, edge.tail):
for grouped_model in group_model:
if AssemblePolicy._is_related_node(grouped_model, edge.head) or \
AssemblePolicy._is_related_node(grouped_model, edge.tail):
return True
return False
@staticmethod
def _check_evaluator(new_model: Model, group_model: Dict[Model, Device]) -> bool:
if not (isinstance(new_model.evaluator, Lightning)
and isinstance(new_model.evaluator.module, _MultiModelSupervisedLearningModule)):
return False
for m in group_model:
if not m.evaluator == new_model.evaluator:
return False
return True
@staticmethod
def group(logical_plan, available_devices):
# TODO: packing multiple models in one GPU
# Currently, we only support one model per GPU
all_grouped_models = []
group_model = {}
assert len(available_devices) > 0  # there should be at least 1 device, set in CGO_DEVICES
for idx, m in enumerate(logical_plan.models):
# models in one group should
# (1) not use more GPUs than available_devices
# (2) be connected in the logical plan (independent models should be assembled in multiple groups)
# (3) use same MultiModelSupervisedLearningModule
if len(group_model) > 0 and \
(not AssemblePolicy._check_graph_connectivity(m, group_model, logical_plan) or
not AssemblePolicy._check_evaluator(m, group_model)):
all_grouped_models.append(group_model)
group_model = {}
group_model[m] = available_devices[idx % len(available_devices)]
if len(group_model) == len(available_devices) or \
idx == len(logical_plan.models) - 1:
all_grouped_models.append(group_model)
group_model = {}
return all_grouped_models
from nni.nas.execution.pytorch.cgo import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from abc import ABC, abstractmethod, abstractclassmethod
from typing import Any, Iterable, NewType, List, Union, Type
# pylint: disable=wildcard-import,unused-wildcard-import
from ..graph import Model, MetricData
__all__ = [
'GraphData', 'WorkerInfo',
'AbstractGraphListener', 'AbstractExecutionEngine'
]
GraphData: Type[Any] = NewType('GraphData', Any)
"""
A _serializable_ internal data type defined by execution engine.
Execution engine will submit this kind of data through NNI to worker machine, and train it there.
A `GraphData` object describes a (merged) executable graph.
This is the trial's "hyper-parameter" in NNI's term and will be transferred in JSON format.
See `AbstractExecutionEngine` for details.
"""
WorkerInfo: Type[Any] = NewType('WorkerInfo', Any)
"""
To be designed. Discussion needed.
This describes the properties of a worker machine. (e.g. memory size)
"""
class AbstractGraphListener(ABC):
"""
Abstract listener interface to receive graph events.
Use `AbstractExecutionEngine.register_graph_listener()` to activate a listener.
"""
@abstractmethod
def on_metric(self, model: Model, metric: MetricData) -> None:
"""
Reports the final metric of a graph.
"""
raise NotImplementedError
@abstractmethod
def on_intermediate_metric(self, model: Model, metric: MetricData) -> None:
"""
Reports the latest intermediate metric of a training graph.
"""
pass
@abstractmethod
def on_training_end(self, model: Model, success: bool) -> None:
"""
Reports either a graph is fully trained or the training process has failed.
"""
pass
class AbstractExecutionEngine(ABC):
"""
    The abstract interface of an execution engine.
    Most of these APIs are used by the strategy, except `trial_execute_graph`, which is invoked by the framework in the trial.
    The strategy will get the singleton execution engine object through a global API,
    and use it in either a sync or an async manner.
    The execution engine is responsible for submitting (maybe-optimized) models to NNI,
    and assigning their metrics to the `Model` objects after training.
    The execution engine is also responsible for launching the graph in the trial process,
    because it is the only component that understands graph data, or "hyper-parameter" in NNI's terms.
    The execution engine will leverage NNI Advisor APIs, which are yet open for discussion.
    In the synchronous use case, the strategy will have a loop that calls `submit_models` and `wait_models` repeatedly,
    and will receive metrics from `Model` attributes.
    The execution engine can assume that the strategy will only submit graphs when there are available resources (for now).
    In the asynchronous use case, the strategy will register a listener to receive events,
    while still using `submit_models` to train.
    There will be a `BaseExecutionEngine` subclass.
    Inner-graph optimization is supposed to derive from `BaseExecutionEngine`,
    and override `submit_models` and `trial_execute_graph`.
    Cross-graph optimization is supposed to derive from `AbstractExecutionEngine` directly,
    because in this case APIs like `wait_graph` and `listener.on_training_end` will have unique logic.
    There might be some util functions that benefit all optimization methods,
    but non-mandatory utils should not be covered in the abstract interface.
"""
@abstractmethod
def submit_models(self, *models: Model) -> None:
"""
Submit models to NNI.
This method is supposed to call something like `nni.Advisor.create_trial_job(graph_data)`.
"""
raise NotImplementedError
@abstractmethod
def list_models(self) -> Iterable[Model]:
"""
        Get all models that have been submitted.
        The execution engine should store a copy of the models that have been submitted and return a list of copies in this method.
"""
raise NotImplementedError
@abstractmethod
def query_available_resource(self) -> Union[List[WorkerInfo], int]: # type: ignore
"""
Returns information of all idle workers.
        If no details are available, this may return a list of "empty" objects, reporting only the number of idle workers.
        Could be left unimplemented for the first iteration.
"""
raise NotImplementedError
@abstractmethod
def budget_exhausted(self) -> bool:
"""
        Check whether the user-configured max trial number or max execution duration has been reached.
"""
raise NotImplementedError
@abstractmethod
def register_graph_listener(self, listener: AbstractGraphListener) -> None:
"""
Register a listener to receive graph events.
Could be left unimplemented for first iteration.
"""
raise NotImplementedError
@abstractclassmethod
def trial_execute_graph(cls) -> MetricData:
"""
        Train the graph and return its metrics, in a separate trial process.
        Each call to `nni.Advisor.create_trial_job(graph_data)` will eventually invoke this method.
        Because this method will be invoked in a trial process on the training platform,
        it has a different context from other methods and has no access to global variables or `self`.
        However, util APIs like `.utils.experiment_config()` should still be available.
"""
raise NotImplementedError
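
# A minimal sketch of the synchronous usage pattern described in the class docstring above,
# assuming the strategy already holds a list of candidate models; the polling interval and
# stopping logic are simplified, and `wait_models` (mentioned in the docstring) is omitted.
def _example_synchronous_strategy_loop(engine: AbstractExecutionEngine, candidates: List[Model]) -> None:
    import time
    for model in candidates:
        if engine.budget_exhausted():
            break                                    # max trial number / duration reached
        while not engine.query_available_resource():
            time.sleep(1)                            # wait until an idle worker is reported
        engine.submit_models(model)                  # metrics are later written to model.metric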
from nni.nas.execution.common.engine import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from ..graph import Model, ModelStatus
from .interface import MetricData, AbstractGraphListener
# pylint: disable=wildcard-import,unused-wildcard-import
class DefaultListener(AbstractGraphListener):
def on_metric(self, model: Model, metric: MetricData) -> None:
model.metric = metric
def on_intermediate_metric(self, model: Model, metric: MetricData) -> None:
model.intermediate_metrics.append(metric)
def on_training_end(self, model: Model, success: bool) -> None:
if success:
model.status = ModelStatus.Trained
else:
model.status = ModelStatus.Failed
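
# A minimal sketch of wiring a DefaultListener to an execution engine, assuming `engine`
# implements AbstractExecutionEngine; once registered, metric and status events are
# written back onto the submitted Model objects.
def _example_register_listener(engine) -> DefaultListener:
    listener = DefaultListener()
    engine.register_graph_listener(listener)   # on_metric / on_training_end now update the models
    return listener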
from nni.nas.execution.common.listener import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from abc import ABC
# pylint: disable=wildcard-import,unused-wildcard-import
from .logical_plan import LogicalPlan
class AbstractOptimizer(ABC):
def __init__(self) -> None:
pass
def convert(self, logical_plan: LogicalPlan) -> None:
raise NotImplementedError
from nni.nas.execution.pytorch.cgo.logical_optimizer.interface import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
from typing import Dict, Tuple, Any
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.retiarii.utils import uid
from nni.common.device import Device, CPUDevice
from ...graph import Cell, Edge, Graph, Model, Node
from ...operation import Operation, _IOPseudoOperation
class AbstractLogicalNode(Node):
def __init__(self, graph, node_id, name, operation, _internal=False):
super().__init__(graph, node_id, name, operation, _internal=_internal)
self.related_models = []
def assemble(self, multi_model_placement: Dict[Model, Device]) -> Tuple[Node, Device]:
"""
        Given a set of models to be formed into a physical model and their device placement,
this function replaces the logical node with an executable physical node for the physical model.
Parameters
----------
multi_model_placement : dict
a dict of models and device placement.
These models will be assembled into the same physical model to run.
Returns
-------
node : Node
the physical node to replace the logical node in the physical model
placement : Device
the device placement of the returned physical node
"""
raise NotImplementedError
def _fork_to(self, graph: Graph):
raise NotImplementedError
class LogicalGraph(Graph):
def __init__(self, model: Model, graph_id: int, name: str = None, _internal: bool = False):
super().__init__(model, graph_id, name='logical_' + name, _internal=_internal)
def _dump(self) -> Any:
nodes_dump = {}
for node in self.hidden_nodes:
if isinstance(node, OriginNode):
nodes_dump[f"{node.original_graph.model.model_id}_{node.name}"] = node._dump()
else:
nodes_dump[f"{node.graph.model.model_id}_{node.name}"] = node._dump()
edges_dump = []
for edge in self.edges:
if isinstance(edge.head, OriginNode):
head_info = f'{edge.head.original_graph.model.model_id}_{edge.head.name}'
else:
head_info = edge.head.name
if isinstance(edge.tail, OriginNode):
tail_info = f'{edge.tail.original_graph.model.model_id}_{edge.tail.name}'
else:
tail_info = edge.tail.name
edges_dump.append((head_info, tail_info))
return {
'inputs': self.input_node.operation.io_names,
'outputs': self.output_node.operation.io_names,
'nodes': nodes_dump,
'edges': edges_dump
}
def _fork_to(self, model: Model) -> Graph:
new_graph = Graph(model, self.id, self.name,
_internal=True)._register()
for node in self.hidden_nodes:
if isinstance(node, AbstractLogicalNode):
node._fork_to(new_graph)
else:
Node(new_graph, node.id, node.name,
node.operation, _internal=True)._register()
id_to_new_node = {node.__repr__(): node for node in new_graph.nodes}
for edge in self.edges:
new_head = id_to_new_node[edge.head.__repr__()]
new_tail = id_to_new_node[edge.tail.__repr__()]
Edge((new_head, edge.head_slot),
(new_tail, edge.tail_slot), _internal=True)._register()
return new_graph
class OriginNode(AbstractLogicalNode):
"""
    This is a logical node representing the original node without any modification.
    In `assemble`, it just returns the original node along with the physical placement given by `multi_model_placement`.
"""
def __init__(self, logical_graph: LogicalGraph,
original_graph: Graph, original_node: Node,
name: str, operation, _internal=False):
super().__init__(logical_graph, original_node.id, name, operation)
self.original_graph = original_graph
self.original_node = original_node
def assemble(self, multi_model_placement: Dict[Model, Device]) -> Tuple[Node, Device]:
model_id = self.original_node.graph.model.model_id
new_node = Node(self.original_node.graph, self.original_node.id,
f"M_{model_id}_" +
self.original_node.name,
self.original_node.operation)
return new_node, multi_model_placement[self.original_node.graph.model]
def __repr__(self):
return f'OriginNode(id={self.id}, name={self.name}, \
operation={self.operation}, origin_model_id={self.original_graph.model.model_id})'
def _fork_to(self, graph: Graph):
OriginNode(graph, self.original_graph, self.original_node,
self.name, self.operation)._register()
class LogicalPlan:
def __init__(self, plan_id=0) -> None:
self.lp_model = Model(_internal=True)
self.id = plan_id
self.logical_graph = LogicalGraph(
self.lp_model, self.id, name=f'{self.id}', _internal=True)._register()
self.lp_model._root_graph_name = self.logical_graph.name
self.models = []
def add_model(self, model: Model):
self.models.append(model)
# Only optimize the root graph.
self._merge_graph(model.root_graph)
def _merge_graph(self, from_graph):
to_graph = self.logical_graph
id_to_new_node = {} # old node ID -> new node object
for old_node in from_graph.nodes:
new_node = OriginNode(to_graph, old_node.graph,
old_node, old_node.name,
old_node.operation, _internal=True)._register()
id_to_new_node[old_node.id] = new_node
for edge in from_graph.edges:
new_head = id_to_new_node[edge.head.id]
new_tail = id_to_new_node[edge.tail.id]
Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()
def assemble(self, multi_model_placement: Dict[Model, Device]) \
-> Tuple[Model, Dict[Node, Device]]:
"""
        Given a set of models to be formed into a physical model and their device placement,
        this function replaces all the logical nodes in this LogicalPlan with executable physical nodes
for the physical model.
Parameters
----------
multi_model_placement : dict
a dict of models and device placement.
These models will be assembled into the same physical model to run.
Returns
-------
phy_model : Model
            the physical model formed by the models in `multi_model_placement`;
            all logical nodes are replaced by physical nodes
node_placements : dict
the device placement of the nodes in `phy_model`
"""
phy_model = Model(_internal=True)
phy_graph = self.lp_model.root_graph._fork_to(phy_model)
phy_graph._rename_graph(phy_graph.name, "_model")
# merge sub-graphs
for model in multi_model_placement:
if phy_model.evaluator is None and model.evaluator is not None:
phy_model.evaluator = model.evaluator
for graph_name in model.graphs:
if graph_name != model._root_graph_name:
new_graph = model.graphs[graph_name]._fork_to(
phy_model, name_prefix=f'M_{model.model_id}_')
                    # the M_ name prefix for hidden_nodes in non-root graphs is added here
for new_node in new_graph.hidden_nodes:
if isinstance(new_node.operation, Cell):
old_cell_name = new_node.operation.cell_name
new_node.operation = copy.deepcopy(new_node.operation)
new_node.operation.cell_name = f'M_{model.model_id}_{old_cell_name}'
assert(phy_model.evaluator is not None)
        # When replacing logical nodes, merge the training configs when
# input/output nodes are replaced.
evaluator_slot = {} # Model ID -> Slot ID
input_slot_mapping = {}
output_slot_mapping = {}
# Replace all logical nodes to executable physical nodes
hidden_nodes = phy_graph.hidden_nodes.copy()
node_placements = {}
added_models = []
for node in hidden_nodes:
if isinstance(node, OriginNode):
model_id = node.original_graph.model.model_id
if node.original_graph.model not in multi_model_placement:
for edge in node.incoming_edges:
edge.remove()
for edge in node.outgoing_edges:
edge.remove()
node.remove()
continue
if isinstance(node, AbstractLogicalNode):
new_node, placement = node.assemble(multi_model_placement)
if isinstance(new_node.operation, _IOPseudoOperation):
model_id = new_node.graph.model.model_id
                    if model_id not in evaluator_slot:
                        added_models.append(model_id)
                        evaluator_slot[model_id] = len(added_models) - 1
                    slot = evaluator_slot[model_id]
                    # If a model's inputs/outputs are not used in the multi-model,
                    # the codegen and trainer should not generate or use them.
                    # "use_input" and "use_output" are used to mark whether
                    # an input/output of a model is used in a multi-model.
if new_node.operation.type == '_inputs':
input_slot_mapping[new_node] = slot
if new_node.operation.type == '_outputs':
output_slot_mapping[new_node] = slot
self.node_replace(node, new_node)
                # the M_ name prefix for cells in hidden_nodes of the root graph is added here
                # FIXME: merge this rename with the non-root graph renaming, and only do it once.
if isinstance(new_node.operation, Cell):
old_cell_name = new_node.operation.cell_name
new_node.operation = copy.deepcopy(new_node.operation)
new_node.operation.cell_name = f'M_{model_id}_{old_cell_name}'
# input should be at CPU, move it to GPU first if necessary
if isinstance(new_node.operation, _IOPseudoOperation) and new_node.operation.type == '_inputs':
# hack: only support single_server
node_placements[new_node] = CPUDevice(node_id=placement.node_id)
else:
node_placements[new_node] = placement
node.remove()
# If two nodes are placed on different devices, use ToDevice op to copy the node
# TODO: when copying one node to multiple devices, broadcast is more efficient than P2P communication
existing_edges = phy_graph.edges.copy()
        # Avoid copying a node multiple times to the same device
        copied_op: Dict[Tuple[Node, Device], Node] = {}
for edge in existing_edges:
head_placement = node_placements[edge.head]
tail_placement = node_placements[edge.tail]
if head_placement != tail_placement:
if head_placement.node_id != tail_placement.node_id:
raise ValueError('Cross-server placement is not supported.')
# Same server different devices
if (edge.head, tail_placement) in copied_op:
to_node = copied_op[(edge.head, tail_placement)]
else:
dst_name = edge.head.name + "_to_" + edge.tail.name
to_operation = Operation.new(
'ToDevice', {
"device": tail_placement, "src": (
edge.head.name, edge.head_slot), "dst": dst_name})
to_node = Node(phy_graph, uid(), dst_name, to_operation)._register()
Edge((edge.head, edge.head_slot), (to_node, None), _internal=True)._register()
copied_op[(edge.head, tail_placement)] = to_node
node_placements[to_node] = head_placement
edge.head = to_node
edge.head_slot = None
# merge all input nodes into one with multiple slots
input_nodes = []
for node in phy_graph.hidden_nodes:
if isinstance(node.operation, _IOPseudoOperation) and node.operation.type == '_inputs':
input_nodes.append(node)
for edge in phy_graph.edges:
if edge.head in input_nodes:
edge.head_slot = input_slot_mapping[edge.head]
edge.head = phy_graph.input_node
# merge all output nodes into one with multiple slots
output_nodes = []
for node in phy_graph.hidden_nodes:
if isinstance(node.operation, _IOPseudoOperation) and node.operation.type == '_outputs':
output_nodes.append(node)
for edge in phy_graph.edges:
if edge.tail in output_nodes:
edge.tail_slot = output_slot_mapping[edge.tail]
edge.tail = phy_graph.output_node
for node in input_nodes:
node.remove()
for node in output_nodes:
node.remove()
return phy_model, node_placements
def node_replace(self, old_node: Node, new_node: Node, input_slot_mapping=None, output_slot_mapping=None):
        # TODO: currently, only a single input slot and a single output slot are supported.
if input_slot_mapping is not None or output_slot_mapping is not None:
raise ValueError('Slot mapping is not supported')
phy_graph = old_node.graph
new_node.graph = phy_graph
new_node._register()
for edge in phy_graph.edges:
if edge.head == old_node:
edge.head = new_node
elif edge.tail == old_node:
edge.tail = new_node
# after the replacement, there might be multiple duplicated edges
# with the same input and output nodes, which should be de-duplicated
self._remove_duplicated_edges()
def _remove_duplicated_edges(self):
        # TODO: there would be no duplicated edges if only input deduplication is supported
        # Duplicated edges appear when a chain of prefix nodes is deduplicated
pass
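
# A minimal sketch of building and assembling a LogicalPlan, assuming `model_a`/`model_b`
# and the device objects are supplied by the caller (in practice the CGO engine and its
# AssemblePolicy drive this).
def _example_build_and_assemble(model_a: Model, model_b: Model,
                                device_a: Device, device_b: Device):
    plan = LogicalPlan(plan_id=0)
    plan.add_model(model_a)        # merges model_a's root graph into the logical graph
    plan.add_model(model_b)
    placement = {model_a: device_a, model_b: device_b}
    phy_model, node_placements = plan.assemble(placement)
    # phy_model is a single executable Model; node_placements maps each physical node to a Device
    return phy_model, node_placements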
from nni.nas.execution.pytorch.cgo.logical_optimizer.logical_plan import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import List, Dict, Tuple
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.retiarii.utils import uid
from nni.retiarii.evaluator.pytorch.cgo.evaluator import MultiModelSupervisedLearningModule
from nni.common.device import GPUDevice
from ...graph import Graph, Model, Node
from .interface import AbstractOptimizer
from .logical_plan import (AbstractLogicalNode, LogicalGraph, LogicalPlan,
OriginNode)
_supported_evaluators = [MultiModelSupervisedLearningModule]
class DedupInputNode(AbstractLogicalNode):
"""
    This is a logical node representing deduplicated input nodes.
    In `assemble`, only one copy of the original node is returned when multiple models are assembled.
    These models will share the result of a single calculation.
"""
def __init__(self, logical_graph: LogicalGraph, node_id: int,
nodes_to_dedup: List[Node], _internal=False):
super().__init__(logical_graph, node_id,
"Dedup_" + nodes_to_dedup[0].name,
nodes_to_dedup[0].operation)
self.origin_nodes: List[OriginNode] = nodes_to_dedup.copy()
self.related_models = [_.original_graph.model for _ in self.origin_nodes]
def assemble(self, multi_model_placement: Dict[Model, GPUDevice]) -> Tuple[Node, GPUDevice]:
for node in self.origin_nodes:
if node.original_graph.model in multi_model_placement:
new_node = Node(node.original_graph, node.id,
f'M_{node.original_graph.model.model_id}_{node.name}',
node.operation)
return new_node, multi_model_placement[node.original_graph.model]
raise ValueError(f'DedupInputNode {self.name} does not contain nodes from multi_model')
def _fork_to(self, graph: Graph):
DedupInputNode(graph, self.id, self.origin_nodes)._register()
def __repr__(self) -> str:
        return f'DedupNode(id={self.id}, name={self.name}, \
            len(nodes_to_dedup)={len(self.origin_nodes)})'
class DedupInputOptimizer(AbstractOptimizer):
def __init__(self) -> None:
pass
def _check_supported_evaluator(self, evaluator):
for e in _supported_evaluators:
if isinstance(evaluator, e):
return True
return False
def _check_deduplicate_by_node(self, root_node, node_to_check):
if root_node == node_to_check:
return True
if root_node.operation.type == '_inputs' and \
node_to_check.operation.type == '_inputs' and \
isinstance(root_node, OriginNode) and \
isinstance(node_to_check, OriginNode):
if self._check_supported_evaluator(root_node.original_graph.model.evaluator):
return False
if root_node.original_graph.model.evaluator == node_to_check.original_graph.model.evaluator:
return True
else:
return False
else:
return False
def convert(self, logical_plan: LogicalPlan) -> None:
nodes_to_skip = set()
while True: # repeat until the logical_graph converges
input_nodes = logical_plan.logical_graph.get_nodes_by_type("_inputs")
# _PseudoOperation(type_name="_inputs"))
root_node = None
for node in input_nodes:
if node in nodes_to_skip:
continue
root_node = node
break
if root_node is None:
break # end of convert
else:
nodes_to_dedup = []
for node in input_nodes:
if node in nodes_to_skip:
continue
if self._check_deduplicate_by_node(root_node, node):
nodes_to_dedup.append(node)
assert(len(nodes_to_dedup) >= 1)
if len(nodes_to_dedup) == 1:
assert(nodes_to_dedup[0] == root_node)
nodes_to_skip.add(root_node)
else:
dedup_node = DedupInputNode(logical_plan.logical_graph, uid(), nodes_to_dedup)._register()
for edge in logical_plan.logical_graph.edges:
if edge.head in nodes_to_dedup:
edge.head = dedup_node
if edge.tail in nodes_to_dedup:
edge.tail = dedup_node
for node in nodes_to_dedup:
node.remove()
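
# A minimal sketch of applying the optimizer, assuming `plan` is a LogicalPlan with several
# models already added; `_inputs` nodes whose models' evaluators compare equal are merged
# into a single DedupInputNode, so the assembled physical model computes the shared input
# only once.
def _example_dedup_inputs(plan: LogicalPlan) -> LogicalPlan:
    optimizer = DedupInputOptimizer()
    optimizer.convert(plan)        # the logical graph is modified in place
    return plan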
from nni.nas.execution.pytorch.cgo.logical_optimizer.opt_dedup_input import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Dict, Any, Type, cast
# pylint: disable=wildcard-import,unused-wildcard-import
import torch.nn as nn
from ..graph import Evaluator, Model
from ..integration_api import receive_trial_parameters
from ..utils import ContextStack
from .base import BaseExecutionEngine
from .utils import get_mutation_dict, mutation_dict_to_summary
class PythonGraphData:
def __init__(self, class_: Type[nn.Module], init_parameters: Dict[str, Any],
mutation: Dict[str, Any], evaluator: Evaluator) -> None:
self.class_ = class_
self.init_parameters = init_parameters
self.mutation = mutation
self.evaluator = evaluator
self.mutation_summary = mutation_dict_to_summary(mutation)
def dump(self) -> dict:
return {
'class': self.class_,
'init_parameters': self.init_parameters,
'mutation': self.mutation,
# engine needs to call dump here,
# otherwise, evaluator will become binary
# also, evaluator can be none in tests
'evaluator': self.evaluator._dump() if self.evaluator is not None else None,
'mutation_summary': self.mutation_summary
}
@staticmethod
def load(data) -> 'PythonGraphData':
return PythonGraphData(data['class'], data['init_parameters'], data['mutation'], Evaluator._load(data['evaluator']))
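
# A minimal dump/load round-trip sketch for the payload shipped to the trial, assuming
# `model_cls`, the init parameters and the mutation dict below are placeholders.
def _example_graph_data_roundtrip(model_cls: Type[nn.Module], evaluator: Evaluator) -> PythonGraphData:
    data = PythonGraphData(model_cls, {'num_classes': 10}, {'layer_choice': 'conv3x3'}, evaluator)
    payload = data.dump()                 # dict submitted through NNI as the trial's "hyper-parameter"
    return PythonGraphData.load(payload)  # reconstructed on the trial side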
class PurePythonExecutionEngine(BaseExecutionEngine):
"""
    This is the execution engine that doesn't rely on the Python-to-IR converter.
    We didn't explicitly state this independence for now. The front-end needs to decide which converter (or no converter)
    to use depending on the execution type. In the future, that logic may be moved into this execution engine.
    The execution engine needs to store the class path of the base model and its init parameters, so that the model can be
    re-initialized with the mutation dict in the context and the mutable modules are created as fixed instances on the fly.
"""
@classmethod
def pack_model_data(cls, model: Model) -> Any:
mutation = get_mutation_dict(model)
assert model.evaluator is not None, 'Model evaluator is not available.'
graph_data = PythonGraphData(
cast(Type[nn.Module], model.python_class),
model.python_init_params or {}, mutation, model.evaluator
)
return graph_data
@classmethod
def trial_execute_graph(cls) -> None:
graph_data = PythonGraphData.load(receive_trial_parameters())
def _model():
return graph_data.class_(**graph_data.init_parameters)
with ContextStack('fixed', graph_data.mutation):
graph_data.evaluator._execute(_model)
from nni.nas.execution.pytorch.simplified import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Any, List
from ..graph import Model
# pylint: disable=wildcard-import,unused-wildcard-import
def _unpack_if_only_one(ele: List[Any]):
if len(ele) == 1:
return ele[0]
return ele
def get_mutation_dict(model: Model):
return {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model.history}
def mutation_dict_to_summary(mutation: dict) -> dict:
mutation_summary = {}
for label, samples in mutation.items():
# FIXME: this check might be wrong
if not isinstance(samples, list):
mutation_summary[label] = samples
else:
for i, sample in enumerate(samples):
mutation_summary[f'{label}_{i}'] = sample
return mutation_summary
def get_mutation_summary(model: Model) -> dict:
mutation = get_mutation_dict(model)
return mutation_dict_to_summary(mutation)
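
# A minimal sketch of how mutation_dict_to_summary flattens list-valued samples into
# indexed keys, while scalar samples are kept as-is.
def _example_mutation_summary() -> dict:
    summary = mutation_dict_to_summary({'conv_size': 3, 'skip': [True, False]})
    # summary == {'conv_size': 3, 'skip_0': True, 'skip_1': False}
    return summary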
from nni.nas.execution.common.utils import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from dataclasses import dataclass
from typing import Optional, List
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.experiment.config.base import ConfigBase
__all__ = ['ExecutionEngineConfig', 'BaseEngineConfig', 'OneshotEngineConfig',
'PyEngineConfig', 'CgoEngineConfig', 'BenchmarkEngineConfig']
@dataclass(init=False)
class ExecutionEngineConfig(ConfigBase):
name: str
@dataclass(init=False)
class PyEngineConfig(ExecutionEngineConfig):
name: str = 'py'
@dataclass(init=False)
class OneshotEngineConfig(ExecutionEngineConfig):
name: str = 'oneshot'
@dataclass(init=False)
class BaseEngineConfig(ExecutionEngineConfig):
name: str = 'base'
    # input used in GraphConverterWithShape. Currently only a shape tuple is supported.
dummy_input: Optional[List[int]] = None
@dataclass(init=False)
class CgoEngineConfig(ExecutionEngineConfig):
name: str = 'cgo'
max_concurrency_cgo: Optional[int] = None
batch_waiting_time: Optional[int] = None
    # input used in GraphConverterWithShape. Currently only a shape tuple is supported.
dummy_input: Optional[List[int]] = None
@dataclass(init=False)
class BenchmarkEngineConfig(ExecutionEngineConfig):
name: str = 'benchmark'
benchmark: Optional[str] = None
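
# A minimal sketch of configuring the CGO engine; the concrete values below are
# placeholders for demonstration, not recommended defaults.
def _example_cgo_engine_config() -> CgoEngineConfig:
    config = CgoEngineConfig()
    config.max_concurrency_cgo = 3          # merge at most 3 models per physical trial (assumed value)
    config.batch_waiting_time = 60          # seconds to wait before dispatching a partial batch (assumed value)
    config.dummy_input = [1, 3, 224, 224]   # shape tuple used by GraphConverterWithShape (assumed value)
    return config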
from nni.nas.experiment.config.engine_config import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import sys
from dataclasses import dataclass
from typing import Any, Dict, Union, Optional
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.experiment.config import utils, ExperimentConfig
from .engine_config import ExecutionEngineConfig
__all__ = ['RetiariiExeConfig']
def execution_engine_config_factory(engine_name):
# FIXME: may move this function to experiment utils in future
cls = _get_ee_config_class(engine_name)
if cls is None:
raise ValueError(f'Invalid execution engine name: {engine_name}')
return cls()
def _get_ee_config_class(engine_name):
for cls in ExecutionEngineConfig.__subclasses__():
if cls.name == engine_name:
return cls
return None
@dataclass(init=False)
class RetiariiExeConfig(ExperimentConfig):
# FIXME: refactor this class to inherit from a new common base class with HPO config
search_space: Any = ''
trial_code_directory: utils.PathLike = '.'
trial_command: str = '_reserved'
# new config field for NAS
execution_engine: Union[str, ExecutionEngineConfig]
# Internal: to support customized fields in trial command
# Useful when customized python / environment variables are needed
_trial_command_params: Optional[Dict[str, Any]] = None
def __init__(self, training_service_platform: Union[str, None] = None,
execution_engine: Union[str, ExecutionEngineConfig] = 'py',
**kwargs):
super().__init__(training_service_platform, **kwargs)
self.execution_engine = execution_engine
def _canonicalize(self, _parents):
msg = '{} is not supposed to be set in Retiarii experiment by users, your config is {}.'
if self.search_space != '':
raise ValueError(msg.format('search_space', self.search_space))
# TODO: maybe we should also allow users to specify trial_code_directory
if str(self.trial_code_directory) != '.' and not os.path.isabs(self.trial_code_directory):
raise ValueError(msg.format('trial_code_directory', self.trial_code_directory))
trial_command_tmpl = '{envs} {python} -m nni.retiarii.trial_entry {execution_engine}'
if self.trial_command != '_reserved' and '-m nni.retiarii.trial_entry' not in self.trial_command:
raise ValueError(msg.format('trial_command', self.trial_command))
if isinstance(self.execution_engine, str):
self.execution_engine = execution_engine_config_factory(self.execution_engine)
_trial_command_params = {
# Default variables
'envs': '',
# TODO: maybe use sys.executable rendered in trial side (e.g., trial_runner)
'python': sys.executable,
'execution_engine': self.execution_engine.name,
# This should override the parameters above.
**(self._trial_command_params or {})
}
self.trial_command = trial_command_tmpl.format(**_trial_command_params).strip()
super()._canonicalize([self])
from nni.nas.experiment.config.experiment_config import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
# pylint: disable=wildcard-import,unused-wildcard-import
import logging
import warnings
from threading import Thread
from typing import Any, List, Union, cast
import colorama
import torch
import torch.nn as nn
from nni.experiment import Experiment, RunMode
from nni.experiment.config.training_services import RemoteConfig
from .config import (
RetiariiExeConfig, OneshotEngineConfig, BaseEngineConfig,
PyEngineConfig, CgoEngineConfig, BenchmarkEngineConfig
)
from ..codegen.pytorch import model_to_pytorch_script
from ..converter import convert_to_graph
from ..converter.graph_gen import GraphConverterWithShape
from ..execution import list_models, set_execution_engine
from ..execution.utils import get_mutation_dict
from ..graph import Evaluator
from ..integration import RetiariiAdvisor
from ..mutator import Mutator
from ..nn.pytorch.mutator import (
extract_mutation_from_pt_module, process_inline_mutation, process_evaluator_mutations, process_oneshot_mutations
)
from ..oneshot.interface import BaseOneShotTrainer
from ..serializer import is_model_wrapped
from ..strategy import BaseStrategy
from ..strategy.utils import dry_run_for_formatted_search_space
_logger = logging.getLogger(__name__)
__all__ = ['RetiariiExperiment']
def preprocess_model(base_model, evaluator, applied_mutators, full_ir=True, dummy_input=None, oneshot=False):
# TODO: this logic might need to be refactored into execution engine
if oneshot:
base_model_ir, mutators = process_oneshot_mutations(base_model, evaluator)
elif full_ir:
try:
script_module = torch.jit.script(base_model)
except Exception as e:
_logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
raise e
if dummy_input is not None:
# FIXME: this is a workaround as full tensor is not supported in configs
dummy_input = torch.randn(*dummy_input)
converter = GraphConverterWithShape()
base_model_ir = convert_to_graph(script_module, base_model, converter, dummy_input=dummy_input)
else:
base_model_ir = convert_to_graph(script_module, base_model)
# handle inline mutations
mutators = process_inline_mutation(base_model_ir)
else:
base_model_ir, mutators = extract_mutation_from_pt_module(base_model)
base_model_ir.evaluator = evaluator
if mutators is not None and applied_mutators:
        raise RuntimeError('Mixed usage of LayerChoice/InputChoice and mutators is not supported yet, '
'do not use mutators when you use LayerChoice/InputChoice')
if mutators is not None:
applied_mutators = mutators
# Add mutations on evaluators
applied_mutators += process_evaluator_mutations(evaluator, applied_mutators)
return base_model_ir, applied_mutators
def debug_mutated_model(base_model, evaluator, applied_mutators):
"""
    Locally run only one trial without launching an experiment, for debugging purposes, then exit.
    For example, it can be used to quickly check shape mismatch.
    Specifically, it applies the mutators (by default choosing the first candidate for each choice)
    to generate a new model, then runs this model locally.
    The model will be parsed with the graph execution engine.
Parameters
----------
base_model : nni.retiarii.nn.pytorch.nn.Module
the base model
evaluator : nni.retiarii.graph.Evaluator
the training class of the generated models
applied_mutators : list
a list of mutators that will be applied on the base model for generating a new model
"""
base_model_ir, applied_mutators = preprocess_model(base_model, evaluator, applied_mutators)
from ..strategy.local_debug_strategy import _LocalDebugStrategy
strategy = _LocalDebugStrategy()
strategy.run(base_model_ir, applied_mutators)
_logger.info('local debug completed!')
class RetiariiExperiment(Experiment):
"""
The entry for a NAS experiment.
Users can use this class to start/stop or inspect an experiment, like exporting the results.
    Experiment is a sub-class of :class:`nni.experiment.Experiment`, so there are many similarities, such as a
    configurable training service to run the experiment distributed on remote servers.
    But unlike :class:`nni.experiment.Experiment`, RetiariiExperiment doesn't support configuring:
    - ``trial_code_directory``, which can only be the current working directory.
    - ``search_space``, which is auto-generated in NAS.
    - ``trial_command``, which must be ``python -m nni.retiarii.trial_entry`` to launch the modularized trial code.
    RetiariiExperiment also doesn't have tuner/assessor/advisor, because they are implemented in the strategy as well.
    Also, unlike :class:`nni.experiment.Experiment`, which is bound to a node server,
    RetiariiExperiment optionally starts a node server to schedule the trials, when the strategy is a multi-trial strategy.
When the strategy is one-shot, the step of launching node server is omitted, and the experiment is run locally by default.
Configurations of experiments, such as execution engine, number of GPUs allocated,
should be put into a :class:`RetiariiExeConfig` and used as an argument of :meth:`RetiariiExperiment.run`.
Parameters
----------
base_model : nn.Module
The model defining the search space / base skeleton without mutation.
It should be wrapped by decorator ``nni.retiarii.model_wrapper``.
evaluator : nni.retiarii.Evaluator, default = None
Evaluator for the experiment.
If you are using a one-shot trainer, it should be placed here, although this usage is deprecated.
applied_mutators : list of nni.retiarii.Mutator, default = None
        Mutators to mutate the base model. If None, mutators are skipped.
Note that when ``base_model`` uses inline mutations (e.g., LayerChoice), ``applied_mutators`` must be empty / none.
strategy : nni.retiarii.strategy.BaseStrategy, default = None
Exploration strategy. Can be multi-trial or one-shot.
trainer : BaseOneShotTrainer
Kept for compatibility purposes.
Examples
--------
Multi-trial NAS:
>>> base_model = Net()
>>> search_strategy = strategy.Random()
>>> model_evaluator = FunctionalEvaluator(evaluate_model)
>>> exp = RetiariiExperiment(base_model, model_evaluator, [], search_strategy)
>>> exp_config = RetiariiExeConfig('local')
>>> exp_config.trial_concurrency = 2
>>> exp_config.max_trial_number = 20
>>> exp_config.training_service.use_active_gpu = False
>>> exp.run(exp_config, 8081)
One-shot NAS:
>>> base_model = Net()
>>> search_strategy = strategy.DARTS()
>>> evaluator = pl.Classification(train_dataloader=train_loader, val_dataloaders=valid_loader)
>>> exp = RetiariiExperiment(base_model, evaluator, [], search_strategy)
>>> exp_config = RetiariiExeConfig()
    >>> exp_config.execution_engine = 'oneshot'  # must be set for one-shot strategy
>>> exp.run(exp_config)
Export top models:
>>> for model_dict in exp.export_top_models(formatter='dict'):
... print(model_dict)
    >>> with nni.retiarii.fixed_arch(model_dict):
... final_model = Net()
"""
def __init__(self, base_model: nn.Module,
evaluator: Union[BaseOneShotTrainer, Evaluator] = cast(Evaluator, None),
applied_mutators: List[Mutator] = cast(List[Mutator], None),
strategy: BaseStrategy = cast(BaseStrategy, None),
trainer: BaseOneShotTrainer = cast(BaseOneShotTrainer, None)):
super().__init__(None)
self.config: RetiariiExeConfig = cast(RetiariiExeConfig, None)
if trainer is not None:
warnings.warn('Usage of `trainer` in RetiariiExperiment is deprecated and will be removed soon. '
'Please consider specifying it as a positional argument, or use `evaluator`.', DeprecationWarning)
evaluator = trainer
if evaluator is None:
raise ValueError('Evaluator should not be none.')
self.base_model = base_model
self.evaluator: Union[Evaluator, BaseOneShotTrainer] = evaluator
self.applied_mutators = applied_mutators
self.strategy = strategy
self._dispatcher = None
self._dispatcher_thread = None
# check for sanity
if not is_model_wrapped(base_model):
warnings.warn(colorama.Style.BRIGHT + colorama.Fore.RED +
'`@model_wrapper` is missing for the base model. The experiment might still be able to run, '
'but it may cause inconsistent behavior compared to the time when you add it.' + colorama.Style.RESET_ALL,
RuntimeWarning)
def _run_strategy(self, config: RetiariiExeConfig):
base_model_ir, self.applied_mutators = preprocess_model(
self.base_model, self.evaluator, self.applied_mutators,
full_ir=not isinstance(config.execution_engine, (PyEngineConfig, BenchmarkEngineConfig)),
dummy_input=config.execution_engine.dummy_input
if isinstance(config.execution_engine, (BaseEngineConfig, CgoEngineConfig)) else None
)
_logger.info('Start strategy...')
search_space = dry_run_for_formatted_search_space(base_model_ir, self.applied_mutators)
self.update_search_space(search_space)
self.strategy.run(base_model_ir, self.applied_mutators)
_logger.info('Strategy exit')
# TODO: find out a proper way to show no more trial message on WebUI
def _create_execution_engine(self, config: RetiariiExeConfig) -> None:
        # TODO: we will probably need an execution engine factory to make this clean and elegant
if isinstance(config.execution_engine, BaseEngineConfig):
from ..execution.base import BaseExecutionEngine
engine = BaseExecutionEngine(self.port, self.url_prefix)
elif isinstance(config.execution_engine, CgoEngineConfig):
from ..execution.cgo_engine import CGOExecutionEngine
assert not isinstance(config.training_service, list) \
and config.training_service.platform == 'remote', \
"CGO execution engine currently only supports remote training service"
assert config.execution_engine.batch_waiting_time is not None \
and config.execution_engine.max_concurrency_cgo is not None
engine = CGOExecutionEngine(cast(RemoteConfig, config.training_service),
max_concurrency=config.execution_engine.max_concurrency_cgo,
batch_waiting_time=config.execution_engine.batch_waiting_time,
rest_port=self.port,
rest_url_prefix=self.url_prefix)
elif isinstance(config.execution_engine, PyEngineConfig):
from ..execution.python import PurePythonExecutionEngine
engine = PurePythonExecutionEngine(self.port, self.url_prefix)
elif isinstance(config.execution_engine, BenchmarkEngineConfig):
from ..execution.benchmark import BenchmarkExecutionEngine
assert config.execution_engine.benchmark is not None, \
'"benchmark" must be set when benchmark execution engine is used.'
engine = BenchmarkExecutionEngine(config.execution_engine.benchmark)
else:
raise ValueError(f'Unsupported engine type: {config.execution_engine}')
set_execution_engine(engine)
def start(self, *args, **kwargs) -> None:
"""
        By design, the only difference between `start` and `run` is that `start` is asynchronous,
        while `run` waits for the experiment to complete. RetiariiExperiment always waits for the experiment
        to complete as the strategy runs in the foreground.
"""
raise NotImplementedError('RetiariiExperiment is not supposed to provide `start` method')
def run(self,
config: RetiariiExeConfig | None = None,
port: int = 8080,
debug: bool = False) -> None:
"""
Run the experiment.
        This function will block until the experiment finishes or errors.
"""
if isinstance(self.evaluator, BaseOneShotTrainer):
# TODO: will throw a deprecation warning soon
# warnings.warn('You are using the old implementation of one-shot algos based on One-shot trainer. '
# 'We will try to convert this trainer to our new implementation to run the algorithm. '
# 'In case you want to stick to the old implementation, '
# 'please consider using ``trainer.fit()`` instead of experiment.', DeprecationWarning)
self.evaluator.fit()
return
if config is None:
            warnings.warn('config = None is deprecated and will be removed in the future. If you are running a one-shot experiment, '
'please consider creating a config and set execution engine to `oneshot`.', DeprecationWarning)
self.config = RetiariiExeConfig()
self.config.execution_engine = OneshotEngineConfig()
else:
self.config = config
if isinstance(self.config.execution_engine, OneshotEngineConfig) \
or (isinstance(self.config.execution_engine, str) and self.config.execution_engine == 'oneshot'):
# this is hacky, will be refactored when oneshot can run on training services
base_model_ir, self.applied_mutators = preprocess_model(self.base_model, self.evaluator, self.applied_mutators, oneshot=True)
self.strategy.run(base_model_ir, self.applied_mutators)
else:
ws_url = f'ws://localhost:{port}/tuner'
canonicalized_config = self._start_impl(port, debug, RunMode.Background, ws_url, ['retiarii'])
canonicalized_config = cast(RetiariiExeConfig, canonicalized_config)
self._dispatcher = RetiariiAdvisor(ws_url)
self._dispatcher_thread = Thread(target=self._dispatcher.run, daemon=True)
self._dispatcher_thread.start()
# FIXME: engine cannot be created twice
self._create_execution_engine(canonicalized_config)
try:
self._run_strategy(canonicalized_config)
# FIXME: move this logic to strategy with a new API provided by execution engine
self._wait_completion()
except KeyboardInterrupt:
_logger.warning('KeyboardInterrupt detected')
self.stop()
_logger.info('Search process is done, the experiment is still alive, `stop()` can terminate the experiment.')
def stop(self) -> None:
"""
Stop background experiment.
"""
_logger.info('Stopping experiment, please wait...')
self._stop_impl()
if self._dispatcher_thread:
self._dispatcher_thread.join()
self._dispatcher = cast(RetiariiAdvisor, None)
self._dispatcher_thread = None
_logger.info('Experiment stopped')
def export_top_models(self, top_k: int = 1, optimize_mode: str = 'maximize', formatter: str = 'dict') -> Any:
"""
Export several top performing models.
For one-shot algorithms, only top-1 is supported. For others, ``optimize_mode`` and ``formatter`` are
available for customization.
The concrete behavior of export depends on each strategy.
See the documentation of each strategy for detailed specifications.
Parameters
----------
top_k : int
How many models are intended to be exported.
optimize_mode : str
``maximize`` or ``minimize``. Not supported by one-shot algorithms.
``optimize_mode`` is likely to be removed and defined in strategy in future.
formatter : str
Support ``code`` and ``dict``. Not supported by one-shot algorithms.
If ``code``, the python code of model will be returned.
If ``dict``, the mutation history will be returned.
"""
# TODO: the base class may also need this method
if formatter == 'code':
config = self.config.canonical_copy()
assert not isinstance(config.execution_engine, PyEngineConfig), \
'You should use `dict` formatter when using Python execution engine.'
if isinstance(self.evaluator, BaseOneShotTrainer):
            assert top_k == 1, 'Only top_k == 1 is supported for now.'
return self.evaluator.export()
try:
# this currently works for one-shot algorithms
return self.strategy.export_top_models(top_k=top_k)
except NotImplementedError:
# when strategy hasn't implemented its own export logic
all_models = filter(lambda m: m.metric is not None, list_models())
assert optimize_mode in ['maximize', 'minimize']
all_models = sorted(all_models, key=lambda m: cast(float, m.metric), reverse=optimize_mode == 'maximize')
assert formatter in ['code', 'dict'], 'Export formatter other than "code" and "dict" is not supported yet.'
if formatter == 'code':
return [model_to_pytorch_script(model) for model in all_models[:top_k]]
elif formatter == 'dict':
return [get_mutation_dict(model) for model in all_models[:top_k]]
from nni.nas.experiment.pytorch import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
from pathlib import Path
from typing import Union, Dict, Any
# pylint: disable=wildcard-import,unused-wildcard-import
from .utils import ContextStack
_logger = logging.getLogger(__name__)
def fixed_arch(fixed_arch: Union[str, Path, Dict[str, Any]], verbose=True):
"""
    Load an architecture from ``fixed_arch`` and apply it to the model. This should be used as a context manager. For example,
.. code-block:: python
with fixed_arch('/path/to/export.json'):
model = Model(3, 224, 224)
Parameters
----------
    fixed_arch : str, Path or dict
Path to the JSON that stores the architecture, or dict that stores the exported architecture.
verbose : bool
Print log messages if set to True
Returns
-------
ContextStack
        Context manager that provides a fixed architecture when creating the model.
"""
if isinstance(fixed_arch, (str, Path)):
with open(fixed_arch) as f:
fixed_arch = json.load(f)
if verbose:
        _logger.info('Fixed architecture: %s', fixed_arch)
return ContextStack('fixed', fixed_arch)
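
# A minimal sketch of passing an architecture dict directly (e.g. one entry returned by
# ``exp.export_top_models(formatter='dict')``), assuming ``net_cls`` and the choice keys
# below are placeholders.
def _example_fixed_arch_from_dict(net_cls):
    arch = {'layer_1': 'conv3x3', 'skip_connect': True}   # hypothetical exported choices
    with fixed_arch(arch, verbose=False):
        return net_cls()   # mutable modules are instantiated with the fixed choices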
from nni.nas.fixed import *