Unverified commit 39ec21ca authored by Frandium, committed by GitHub

Multi-GPU support of one-shot NAS (#4603)

parent b4559f60
@@ -4,7 +4,7 @@
 import os
 import warnings
 from pathlib import Path
-from typing import Dict, Union, Optional, List, Callable, Type
+from typing import Any, Dict, Union, Optional, List, Callable, Type
 import pytorch_lightning as pl
 import torch.nn as nn
@@ -22,6 +22,7 @@ except ImportError:
     cgo_import_failed = True

 from nni.retiarii.graph import Evaluator
+from nni.typehint import Literal

 __all__ = ['LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classification', 'Regression']
@@ -36,6 +37,11 @@ class LightningModule(pl.LightningModule):
     See https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html
     """

+    running_mode: Literal['multi', 'oneshot'] = 'multi'
+    """An indicator of whether the current module is running in a multi-trial experiment or a one-shot one.
+    This flag should be automatically set by experiments when they start to run.
+    """
+
     def set_model(self, model: Union[Callable[[], nn.Module], nn.Module]) -> None:
         """Set the inner model (architecture) to train / evaluate.
@@ -59,6 +65,7 @@ DataLoader.__doc__ = """
 Traced version of ``torch.utils.data.DataLoader``. See https://pytorch.org/docs/stable/data.html
 """

 @nni.trace
 class Lightning(Evaluator):
     """
@@ -74,51 +81,67 @@ class Lightning(Evaluator):
     Parameters
     ----------
-    lightning_module : LightningModule
+    lightning_module
         Lightning module that defines the training logic.
-    trainer : Trainer
+    trainer
         Lightning trainer that handles the training.
-    train_dataloders : DataLoader
+    train_dataloaders
         Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
         If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
+        It can be `any type of dataloader supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
-    val_dataloaders : DataLoader or List of DataLoader
+    val_dataloaders
         Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
         If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
+        It can be `any type of dataloader supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
     """

     def __init__(self, lightning_module: LightningModule, trainer: Trainer,
-                 train_dataloader: Optional[DataLoader] = None,
-                 val_dataloaders: Union[DataLoader, List[DataLoader], None] = None):
+                 train_dataloaders: Optional[Any] = None,
+                 val_dataloaders: Optional[Any] = None,
+                 train_dataloader: Optional[Any] = None):
         assert isinstance(lightning_module, LightningModule), f'Lightning module must be an instance of {__name__}.LightningModule.'
+        if train_dataloader is not None:
+            warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
+            train_dataloaders = train_dataloader
         if cgo_import_failed:
             assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), f'Trainer must be imported from {__name__}'
         else:
             # this is not isinstance(trainer, Trainer) because with a different trace call, it can be different
             assert (isinstance(trainer, pl.Trainer) and is_traceable(trainer)) or isinstance(trainer, cgo_trainer.Trainer), \
                 f'Trainer must be imported from {__name__} or nni.retiarii.evaluator.pytorch.cgo.trainer'
-        assert _check_dataloader(train_dataloader), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
-        assert _check_dataloader(val_dataloaders), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
+        if not _check_dataloader(train_dataloaders):
+            warnings.warn(f'Please try to wrap PyTorch DataLoader with nni.trace or '
+                          f'import DataLoader from {__name__}: {train_dataloaders}',
+                          RuntimeWarning)
+        if not _check_dataloader(val_dataloaders):
+            warnings.warn(f'Please try to wrap PyTorch DataLoader with nni.trace or '
+                          f'import DataLoader from {__name__}: {val_dataloaders}',
+                          RuntimeWarning)
         self.module = lightning_module
         self.trainer = trainer
-        self.train_dataloader = train_dataloader
+        self.train_dataloaders = train_dataloaders
         self.val_dataloaders = val_dataloaders

     @staticmethod
     def _load(ir):
-        return Lightning(ir['module'], ir['trainer'], ir['train_dataloader'], ir['val_dataloaders'])
+        return Lightning(ir['module'], ir['trainer'], ir['train_dataloaders'], ir['val_dataloaders'])

     def _dump(self):
         return {
             'type': self.__class__,
             'module': self.module,
             'trainer': self.trainer,
-            'train_dataloader': self.train_dataloader,
+            'train_dataloaders': self.train_dataloaders,
             'val_dataloaders': self.val_dataloaders
         }

     def _execute(self, model_cls):
         return self.fit(model_cls)

+    @property
+    def train_dataloader(self):
+        warnings.warn('train_dataloader is deprecated, please use `train_dataloaders`.', DeprecationWarning)
+        return self.train_dataloaders
+
     def __eq__(self, other):
         eq_func = False
         eq_args = False
@@ -146,15 +169,18 @@ class Lightning(Evaluator):
             The model to fit.
         """
         self.module.set_model(model)
-        return self.trainer.fit(self.module, self.train_dataloader, self.val_dataloaders)
+        return self.trainer.fit(self.module, self.train_dataloaders, self.val_dataloaders)


 def _check_dataloader(dataloader):
-    if dataloader is None:
-        return True
+    # Check the type of dataloader recursively.
     if isinstance(dataloader, list):
         return all([_check_dataloader(d) for d in dataloader])
-    return isinstance(dataloader, torch_data.DataLoader) and is_traceable(dataloader)
+    if isinstance(dataloader, dict):
+        return all([_check_dataloader(v) for v in dataloader.values()])
+    if isinstance(dataloader, torch_data.DataLoader):
+        return is_traceable(dataloader)
+    return True


 ### The following are some commonly used Lightning modules ###
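With the asserts relaxed into warnings, a DataLoader only passes `_check_dataloader` silently when it is traceable. A minimal editor-added sketch of the two usual ways to obtain such a loader (the MNIST dataset here is only an illustration, not part of this diff):

import nni
from torchvision import transforms
from torchvision.datasets import MNIST

# Option 1: use the pre-traced DataLoader re-exported by the evaluator package.
from nni.retiarii.evaluator.pytorch import DataLoader

transform = transforms.Compose([transforms.ToTensor()])
train_data = MNIST('data/mnist', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)

# Option 2: wrap the vanilla DataLoader class with nni.trace before instantiating it.
import torch.utils.data
val_loader = nni.trace(torch.utils.data.DataLoader)(train_data, batch_size=64)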
@@ -176,7 +202,6 @@ class _SupervisedLearningModule(LightningModule):
         if export_onnx is None or export_onnx is True:
             self.export_onnx = Path(os.environ.get('NNI_OUTPUT_DIR', '.')) / 'model.onnx'
-            self.export_onnx.parent.mkdir(exist_ok=True)
         elif export_onnx:
             self.export_onnx = Path(export_onnx)
         else:
@@ -199,7 +224,8 @@
         x, y = batch
         y_hat = self(x)
-        if self.export_onnx is not None:
+        if self.running_mode == 'multi' and self.export_onnx is not None:
+            self.export_onnx.parent.mkdir(exist_ok=True)
             try:
                 self.to_onnx(self.export_onnx, x, export_params=True)
             except RuntimeError as e:
@@ -221,10 +247,12 @@
         return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)  # type: ignore

     def on_validation_epoch_end(self):
-        nni.report_intermediate_result(self._get_validation_metrics())
+        if self.running_mode == 'multi':
+            nni.report_intermediate_result(self._get_validation_metrics())

     def on_fit_end(self):
-        nni.report_final_result(self._get_validation_metrics())
+        if self.running_mode == 'multi':
+            nni.report_final_result(self._get_validation_metrics())

     def _get_validation_metrics(self):
         if len(self.metrics) == 1:
@@ -283,14 +311,18 @@ class Classification(Lightning):
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
                  optimizer: Type[optim.Optimizer] = optim.Adam,
-                 train_dataloader: Optional[DataLoader] = None,
+                 train_dataloaders: Optional[DataLoader] = None,
                  val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                  export_onnx: bool = True,
+                 train_dataloader: Optional[DataLoader] = None,
                  **trainer_kwargs):
+        if train_dataloader is not None:
+            warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
+            train_dataloaders = train_dataloader
         module = _ClassificationModule(criterion=criterion, learning_rate=learning_rate,
                                        weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
         super().__init__(module, Trainer(**trainer_kwargs),
-                         train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
+                         train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)


 @nni.trace
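A hedged, editor-added sketch of how an end user now constructs this evaluator (the toy dataset is purely illustrative; the extra keyword arguments are forwarded to pytorch_lightning.Trainer and are what enables the multi-GPU one-shot run this PR targets):

import torch
from torch.utils.data import TensorDataset
from nni.retiarii.evaluator.pytorch import Classification, DataLoader

# Hypothetical toy data: 128 samples of dimension 32, 10 classes.
dataset = TensorDataset(torch.randn(128, 32), torch.randint(0, 10, (128,)))
train_loader = DataLoader(dataset, batch_size=16)
val_loader = DataLoader(dataset, batch_size=16)

evaluator = Classification(
    train_dataloaders=train_loader,   # the old `train_dataloader=` still works but warns
    val_dataloaders=val_loader,
    max_epochs=1,
    # Forwarded to pytorch_lightning.Trainer; only meaningful on a multi-GPU machine.
    strategy='ddp',
    accelerator='gpu',
    devices=torch.cuda.device_count(),
)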
@@ -336,11 +368,15 @@ class Regression(Lightning):
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
                  optimizer: Type[optim.Optimizer] = optim.Adam,
-                 train_dataloader: Optional[DataLoader] = None,
+                 train_dataloaders: Optional[DataLoader] = None,
                  val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                  export_onnx: bool = True,
+                 train_dataloader: Optional[DataLoader] = None,
                  **trainer_kwargs):
+        if train_dataloader is not None:
+            warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
+            train_dataloaders = train_dataloader
         module = _RegressionModule(criterion=criterion, learning_rate=learning_rate,
                                    weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
         super().__init__(module, Trainer(**trainer_kwargs),
-                         train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
+                         train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)
@@ -18,6 +18,7 @@ import nni.retiarii.nn.pytorch as nas_nn
 from nni.common.hpo_utils import ParameterSpec
 from nni.common.serializer import is_traceable
 from nni.retiarii.nn.pytorch.api import ValueChoiceX
+from nni.typehint import Literal

 from .supermodule.base import BaseSuperNetModule

 __all__ = ['MutationHook', 'BaseSuperNetModule', 'BaseOneShotLightningModule', 'traverse_and_mutate_submodules']
@@ -334,21 +335,21 @@ class BaseOneShotLightningModule(pl.LightningModule):
         return arc_optimizers + w_optimizers, lr_schedulers

     def on_train_start(self):
-        # redirect the access to trainer/log to this module
-        # but note that we might be missing other attributes,
-        # which could potentially be a problem
-        self.model.trainer = self.trainer  # type: ignore
-        self.model.log = self.log
         return self.model.on_train_start()

     def on_train_end(self):
         return self.model.on_train_end()

     def on_fit_start(self):
-        return self.model.on_train_start()
+        # redirect the access to trainer/log to this module
+        # but note that we might be missing other attributes,
+        # which could potentially be a problem
+        self.model.trainer = self.trainer  # type: ignore
+        self.model.log = self.log
+        return self.model.on_fit_start()

     def on_fit_end(self):
-        return self.model.on_train_end()
+        return self.model.on_fit_end()

     def on_train_batch_start(self, batch, batch_idx, unused=0):
         return self.model.on_train_batch_start(batch, batch_idx, unused)
@@ -356,6 +357,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
     def on_train_batch_end(self, outputs, batch, batch_idx, unused=0):
         return self.model.on_train_batch_end(outputs, batch, batch_idx, unused)

+    # Deprecated hooks in pytorch-lightning
     def on_epoch_start(self):
         return self.model.on_epoch_start()
@@ -427,7 +429,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
         else:
             apply(lr_schedulers)

-    def call_weight_optimizers(self, method):
+    def call_weight_optimizers(self, method: Literal['step', 'zero_grad']):
         """
         Function that imitates lightning trainer's behavior of calling user's optimizers. Since auto_optimization is turned off by this
         class, you can use this function to make user optimizers behave as they were automatically handled by the lightning trainer.
...
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

from typing import Any

from pytorch_lightning.trainer.supporters import CombinedLoader, CombinedLoaderIterator


class ConcatLoader(CombinedLoader):
    """This loader is the same as CombinedLoader in PyTorch-Lightning,
    but it concatenates the sub-loaders instead of loading them in parallel.

    Parameters
    ----------
    loaders
        For example, ::

            {
                "train": DataLoader(train_dataset),
                "val": DataLoader(val_dataset)
            }

        In this example, the loader will first produce the batches from "train", then "val".
    mode
        Only "min_size" is supported for now.
    """

    def __init__(self, loaders: dict[str, Any], mode: str = 'min_size'):
        # FIXME: max_cycle will make dataloaders cycle iterators,
        # causing extra problems.
        if mode != 'min_size':
            raise ValueError('Only min_size mode is supported now.')
        super().__init__(loaders, mode)

    def __iter__(self) -> Any:
        """Replace the super-class iterator with ours."""
        self._try_to_patch_pytorch_dataloader()
        iterator = ConcatLoaderIterator(self.loaders)
        # handle fault tolerant restart.
        self.on_restart(iterator)
        self._iterator = iterator
        return iterator

    @staticmethod
    def _try_to_patch_pytorch_dataloader():
        """Copied from CombinedLoader."""
        from torch.utils.data.dataloader import _BaseDataLoaderIter

        # prevent `NotImplementedError` from PyTorch:
        # https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/dataloader.py#L541
        def __getstate__patch__(*_):
            return {}

        _BaseDataLoaderIter.__getstate__ = __getstate__patch__  # type: ignore

    def __len__(self) -> int:
        return int(sum(self._calc_num_batches(loader) for loader in self.loaders.values()))


class ConcatLoaderIterator(CombinedLoaderIterator):
    """Similar to CombinedLoaderIterator in Lightning, but in a concat manner."""

    def __next__(self) -> Any:
        """Fetches the next batch from multiple data loaders,
        by looking for the first iterator that isn't exhausted yet.
        """
        if not len(self.loader_iters) == len(self.loaders):
            raise RuntimeError('loader_iters must have the same length as loaders.')
        for i, (loader_name, iterator) in enumerate(self.loader_iters.items()):
            try:
                return (self.request_next_batch(iterator), loader_name)
            except StopIteration:
                if i + 1 == len(self.loader_iters):
                    raise
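A quick usage sketch (editor-added, mirroring the unit test further below): the iterator yields (batch, loader_name) pairs, draining the "train" loader before moving on to "val".

from torch.utils.data import DataLoader
from nni.retiarii.oneshot.pytorch.dataloader import ConcatLoader

loader = ConcatLoader({
    'train': DataLoader(range(10), batch_size=4),  # 3 batches
    'val': DataLoader(range(20), batch_size=5),    # 4 batches
})
assert len(loader) == 7
for batch, source in loader:
    # source is 'train' for the first 3 iterations, then 'val'
    print(source, batch)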
@@ -75,8 +75,9 @@ class DartsLightningModule(BaseOneShotLightningModule):
         if not isinstance(arc_optim, optim.Optimizer):
             raise TypeError(f'Expect arc_optim to be a single Optimizer, but found: {arc_optim}')

-        # The InterleavedTrainValDataLoader yields both train and val data in a batch
-        trn_batch, val_batch = batch
+        # DARTS strategy makes sure that ``train`` and ``val`` must be in the batch
+        trn_batch = batch['train']
+        val_batch = batch['val']

         # phase 1: architecture step
         # The _resample hook is kept for some darts-based NAS methods like proxyless.
...
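For context, an editor-added sketch (not part of the diff): when the strategy returns a dict of loaders, Lightning wraps it into a CombinedLoader, so each training batch is itself a dict, which is exactly what the unpacking above relies on.

from pytorch_lightning.trainer.supporters import CombinedLoader
from torch.utils.data import DataLoader

combined = CombinedLoader({
    'train': DataLoader(range(8), batch_size=2),
    'val': DataLoader(range(8), batch_size=2),
}, mode='min_size')

for batch in combined:
    trn_batch, val_batch = batch['train'], batch['val']  # same structure as in training_step above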
@@ -133,29 +133,30 @@ class EnasLightningModule(RandomSamplingLightningModule):
     def configure_architecture_optimizers(self):
         return optim.Adam(self.controller.parameters(), lr=3.5e-4)

-    def training_step(self, batch, batch_idx):
-        # The ConcatenateTrainValDataloader yields both data and which dataloader it comes from.
-        batch, source = batch
+    def training_step(self, batch_packed, batch_idx):
+        batch, mode = batch_packed

-        if source == 'train':
-            # step 1: train model params
-            self.resample()
+        if mode == 'train':
+            # train model params
+            with torch.no_grad():
+                self.resample()
             self.call_weight_optimizers('zero_grad')
-            loss_and_metrics = self.model.training_step(batch, batch_idx)
-            w_step_loss = loss_and_metrics['loss'] \
-                if isinstance(loss_and_metrics, dict) else loss_and_metrics
+            step_output = self.model.training_step(batch, batch_idx)
+            w_step_loss = step_output['loss'] \
+                if isinstance(step_output, dict) else step_output
             self.manual_backward(w_step_loss)
             self.call_weight_optimizers('step')
-            return loss_and_metrics

-        if source == 'val':
-            # step 2: train ENAS agent
+        else:
+            # train ENAS agent
             arc_opt = self.architecture_optimizers()
             if not isinstance(arc_opt, optim.Optimizer):
                 raise TypeError(f'Expect arc_opt to be a single Optimizer, but found: {arc_opt}')
             arc_opt.zero_grad()
             self.resample()
-            self.model.validation_step(batch, batch_idx)
+            step_output = self.model.validation_step(batch, batch_idx)

             # use the default metric of self.model as reward function
             if len(self.trainer.callback_metrics) == 1:
                 _, metric = next(iter(self.trainer.callback_metrics.items()))
@@ -163,7 +164,9 @@ class EnasLightningModule(RandomSamplingLightningModule):
                 metric_name = self.reward_metric_name or 'default'
                 if metric_name not in self.trainer.callback_metrics:
                     raise KeyError(f'Model reported metrics should contain a ``{metric_name}`` key but '
-                                   f'found multiple metrics without default: {self.trainer.callback_metrics.keys()}')
+                                   f'found multiple (or zero) metrics without default: {list(self.trainer.callback_metrics.keys())}. '
+                                   f'Try to use self.log to report metrics with the specified key ``{metric_name}`` in validation_step, '
+                                   'and remember to set on_step=True.')
                 metric = self.trainer.callback_metrics[metric_name]
             reward: float = metric.item()
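The new error message asks users to log the reward metric themselves. A minimal editor-added sketch (``val_acc`` is a hypothetical ``reward_metric_name``) of what that looks like inside the user's evaluator module:

def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    acc = (y_hat.argmax(dim=-1) == y).float().mean()
    # ENAS reads the reward from trainer.callback_metrics during training,
    # so the metric must be logged with on_step=True.
    self.log('val_acc', acc, on_step=True, on_epoch=True)
    return acc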
@@ -183,6 +186,8 @@ class EnasLightningModule(RandomSamplingLightningModule):
             arc_opt.step()
             arc_opt.zero_grad()

+        return step_output
+
     def resample(self):
         """Resample the architecture with ENAS controller."""
         sample = self.controller.resample()
...
@@ -16,7 +16,6 @@ import warnings
 from typing import Any, Type

 import torch.nn as nn
-from torch.utils.data import DataLoader

 from nni.retiarii.graph import Model
 from nni.retiarii.strategy.base import BaseStrategy
@@ -25,7 +24,6 @@ from nni.retiarii.evaluator.pytorch.lightning import Lightning, LightningModule
 from .base_lightning import BaseOneShotLightningModule
 from .differentiable import DartsLightningModule, ProxylessLightningModule, GumbelDartsLightningModule
 from .sampling import EnasLightningModule, RandomSamplingLightningModule
-from .utils import InterleavedTrainValDataLoader, ConcatenateTrainValDataLoader


 class OneShotStrategy(BaseStrategy):
@@ -37,15 +35,18 @@ class OneShotStrategy(BaseStrategy):
         self.model: BaseOneShotLightningModule | None = None

-    def _get_dataloader(self, train_dataloader: DataLoader, val_dataloaders: DataLoader | list[DataLoader]) \
-        -> DataLoader | tuple[DataLoader, DataLoader]:
+    def preprocess_dataloader(self, train_dataloaders: Any, val_dataloaders: Any) -> tuple[Any, Any]:
         """
-        One-shot strategy typically requires a customized dataloader.
-
-        If only train dataloader is produced, return one dataloader.
-        Otherwise, return train dataloader and valid loader as a tuple.
+        One-shot strategy typically requires fusing train and validation dataloaders in an ad-hoc way.
+        As the one-shot strategy doesn't try to open the black box of a batch, these dataloaders can theoretically be
+        `any dataloader types supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
+
+        Returns
+        -------
+        A tuple of preprocessed train dataloaders and validation dataloaders.
         """
-        raise NotImplementedError()
+        return train_dataloaders, val_dataloaders

     def run(self, base_model: Model, applied_mutators):
         # one-shot strategy doesn't use ``applied_mutators``
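For a subclass, overriding this hook is all that is needed to change how the two loaders are fused. A hypothetical editor-added example (names are illustrative, mirroring the DARTS override further below):

from __future__ import annotations
from typing import Any

class MyDictFusionStrategy(OneShotStrategy):
    def preprocess_dataloader(self, train_dataloaders: Any, val_dataloaders: Any) -> tuple[Any, Any]:
        # Returning a dict makes Lightning build a CombinedLoader, so training_step
        # receives batch['train'] and batch['val']; no separate validation loader is used.
        return {'train': train_dataloaders, 'val': val_dataloaders}, None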
@@ -64,18 +65,15 @@
             raise TypeError('Evaluator needs to be a lightning evaluator to make one-shot strategy work.')
         evaluator_module: LightningModule = base_model.evaluator.module
+        evaluator_module.running_mode = 'oneshot'
         evaluator_module.set_model(py_model)
         self.model = self.oneshot_module(evaluator_module, **self.oneshot_kwargs)
         evaluator: Lightning = base_model.evaluator
-        if evaluator.train_dataloader is None or evaluator.val_dataloaders is None:
-            raise TypeError('Train or val dataloader is not set.')
-        dataloader = self._get_dataloader(evaluator.train_dataloader, evaluator.val_dataloaders)
-        if isinstance(dataloader, tuple):
-            dataloader, val_loader = dataloader
-            evaluator.trainer.fit(self.model, dataloader, val_loader)
-        else:
-            evaluator.trainer.fit(self.model, dataloader)
+        if evaluator.train_dataloaders is None or evaluator.val_dataloaders is None:
+            raise TypeError('Training and validation dataloaders are both required to be set in the evaluator for one-shot strategy.')
+        train_loader, val_loader = self.preprocess_dataloader(evaluator.train_dataloaders, evaluator.val_dataloaders)
+        evaluator.trainer.fit(self.model, train_loader, val_loader)

     def export_top_models(self, top_k: int = 1) -> list[Any]:
         if self.model is None:
@@ -91,8 +89,12 @@ class DARTS(OneShotStrategy):
     def __init__(self, **kwargs):
         super().__init__(DartsLightningModule, **kwargs)

-    def _get_dataloader(self, train_dataloader, val_dataloaders):
-        return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders)
+    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
+        # By returning a dict, we make a CombinedLoader (in Lightning)
+        return {
+            'train': train_dataloaders,
+            'val': val_dataloaders
+        }, None


 class Proxyless(OneShotStrategy):
@@ -101,8 +103,11 @@ class Proxyless(OneShotStrategy):
     def __init__(self, **kwargs):
         super().__init__(ProxylessLightningModule, **kwargs)

-    def _get_dataloader(self, train_dataloader, val_dataloaders):
-        return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders)
+    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
+        return {
+            'train': train_dataloaders,
+            'val': val_dataloaders
+        }, None


 class GumbelDARTS(OneShotStrategy):
@@ -111,8 +116,11 @@ class GumbelDARTS(OneShotStrategy):
     def __init__(self, **kwargs):
         super().__init__(GumbelDartsLightningModule, **kwargs)

-    def _get_dataloader(self, train_dataloader, val_dataloaders):
-        return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders)
+    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
+        return {
+            'train': train_dataloaders,
+            'val': val_dataloaders
+        }, None


 class ENAS(OneShotStrategy):
@@ -121,8 +129,13 @@ class ENAS(OneShotStrategy):
     def __init__(self, **kwargs):
         super().__init__(EnasLightningModule, **kwargs)

-    def _get_dataloader(self, train_dataloader, val_dataloaders):
-        return ConcatenateTrainValDataLoader(train_dataloader, val_dataloaders)
+    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
+        # Import locally to avoid import error on legacy PL version
+        from .dataloader import ConcatLoader
+        return ConcatLoader({
+            'train': train_dataloaders,
+            'val': val_dataloaders
+        }), None


 class RandomOneShot(OneShotStrategy):
@@ -130,6 +143,3 @@ class RandomOneShot(OneShotStrategy):
     def __init__(self, **kwargs):
         super().__init__(RandomSamplingLightningModule, **kwargs)
-
-    def _get_dataloader(self, train_dataloader, val_dataloaders):
-        return train_dataloader, val_dataloaders
@@ -132,6 +132,7 @@ class AverageMeter:
 def _replace_module_with_type(root_module, init_fn, type_name, modules):
     if modules is None:
         modules = []

     def apply(m):
         for name, child in m.named_children():
             if isinstance(child, type_name):
...
@@ -5,8 +5,6 @@ import logging
 import time
 from typing import Optional

-from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner
-
 from .. import Sampler, submit_models, query_available_resources, is_stopped_exec, budget_exhausted
 from .base import BaseStrategy

@@ -15,6 +13,9 @@ _logger = logging.getLogger(__name__)

 class TPESampler(Sampler):
     def __init__(self, optimize_mode='minimize'):
+        # Move import here to eliminate some warning messages about dill.
+        from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner
+
         self.tpe_tuner = HyperoptTuner('tpe', optimize_mode)
         self.cur_sample: Optional[dict] = None
         self.index: Optional[int] = None
...
@@ -15,6 +15,9 @@ from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, ValueChoice
 from nni.retiarii.strategy import BaseStrategy

+pytestmark = pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
+

 class DepthwiseSeparableConv(nn.Module):
     def __init__(self, in_ch, out_ch):
         super().__init__()
@@ -171,7 +174,7 @@ class CustomOpValueChoiceNet(nn.Module):
         return F.log_softmax(x, dim=1)


-def _mnist_net(type_):
+def _mnist_net(type_, evaluator_kwargs):
     if type_ == 'simple':
         base_model = SimpleNet(False)
     elif type_ == 'simple_value_choice':
@@ -187,17 +190,18 @@ def _mnist_net(type_):
     transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
     train_dataset = MNIST('data/mnist', train=True, download=True, transform=transform)
+    # Multi-GPU combined dataloader will break this subset sampler. Expected though.
     train_random_sampler = RandomSampler(train_dataset, True, int(len(train_dataset) / 20))
     train_loader = DataLoader(train_dataset, 64, sampler=train_random_sampler)
     valid_dataset = MNIST('data/mnist', train=False, download=True, transform=transform)
     valid_random_sampler = RandomSampler(valid_dataset, True, int(len(valid_dataset) / 20))
     valid_loader = DataLoader(valid_dataset, 64, sampler=valid_random_sampler)
-    evaluator = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, max_epochs=1)
+    evaluator = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **evaluator_kwargs)

     return base_model, evaluator


-def _multihead_attention_net():
+def _multihead_attention_net(evaluator_kwargs):
     base_model = MultiHeadAttentionNet(1)

     class AttentionRandDataset(Dataset):
@@ -222,19 +226,29 @@ def _multihead_attention_net():
     train_loader = DataLoader(train_set, batch_size=32)
     val_loader = DataLoader(val_set, batch_size=32)

-    evaluator = Regression(train_dataloader=train_loader, val_dataloaders=val_loader, max_epochs=1)
+    evaluator = Regression(train_dataloader=train_loader, val_dataloaders=val_loader, **evaluator_kwargs)

     return base_model, evaluator


-def _test_strategy(strategy_, support_value_choice=True):
+def _test_strategy(strategy_, support_value_choice=True, multi_gpu=False):
+    evaluator_kwargs = {
+        'max_epochs': 1
+    }
+    if multi_gpu:
+        evaluator_kwargs.update(
+            strategy='ddp',
+            accelerator='gpu',
+            devices=torch.cuda.device_count()
+        )
+
     to_test = [
         # (model, evaluator), support_or_net
-        (_mnist_net('simple'), True),
-        (_mnist_net('simple_value_choice'), support_value_choice),
-        (_mnist_net('value_choice'), support_value_choice),
-        (_mnist_net('repeat'), False),  # no strategy supports repeat currently
-        (_mnist_net('custom_op'), False),  # this is definitely a NO
-        (_multihead_attention_net(), support_value_choice),
+        (_mnist_net('simple', evaluator_kwargs), True),
+        (_mnist_net('simple_value_choice', evaluator_kwargs), support_value_choice),
+        (_mnist_net('value_choice', evaluator_kwargs), support_value_choice),
+        (_mnist_net('repeat', evaluator_kwargs), False),  # no strategy supports repeat currently
+        (_mnist_net('custom_op', evaluator_kwargs), False),  # this is definitely a NO
+        (_multihead_attention_net(evaluator_kwargs), support_value_choice),
     ]

     for (base_model, evaluator), support_or_not in to_test:
@@ -256,17 +270,19 @@ def _test_strategy(strategy_, support_value_choice=True):
         experiment.run(config)


-@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
 def test_darts():
     _test_strategy(strategy.DARTS())


-@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
+@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() <= 1, reason='Must have multiple GPUs.')
+def test_darts_multi_gpu():
+    _test_strategy(strategy.DARTS(), multi_gpu=True)
+
+
 def test_proxyless():
     _test_strategy(strategy.Proxyless(), False)


-@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
 def test_enas():
     def strategy_fn(base_model, evaluator):
         if isinstance(base_model, MultiHeadAttentionNet):
@@ -276,12 +292,20 @@ def test_enas():
     _test_strategy(strategy_fn)


-@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
+@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() <= 1, reason='Must have multiple GPUs.')
+def test_enas_multi_gpu():
+    def strategy_fn(base_model, evaluator):
+        if isinstance(base_model, MultiHeadAttentionNet):
+            return strategy.ENAS(reward_metric_name='val_mse')
+        return strategy.ENAS(reward_metric_name='val_acc')
+
+    _test_strategy(strategy_fn, multi_gpu=True)
+
+
 def test_random():
     _test_strategy(strategy.RandomOneShot())


-@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
 def test_gumbel_darts():
     _test_strategy(strategy.GumbelDARTS())
...
import math
from typing import Union

import pytest
import torch
import pytorch_lightning
from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset

pytestmark = pytest.mark.skipif(pytorch_lightning.__version__ < '1.0', reason='Incompatible APIs')


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len
class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log('train_loss', loss)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log('valid_loss', loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log('test_loss', loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)
def test_concat_loader():
    from nni.retiarii.oneshot.pytorch.dataloader import ConcatLoader

    loaders = {
        'a': DataLoader(range(10), batch_size=4),
        'b': DataLoader(range(20), batch_size=5),
    }
    dataloader = ConcatLoader(loaders)
    assert len(dataloader) == 7
    for i, (data, label) in enumerate(dataloader):
        if i < 3:
            assert len(data) <= 4
            assert label == 'a'
        else:
            assert len(data) <= 5
            assert label == 'b'
def test_concat_loader_nested():
    from nni.retiarii.oneshot.pytorch.dataloader import ConcatLoader

    loaders = {
        'a': [DataLoader(range(10), batch_size=4), DataLoader(range(20), batch_size=6)],
        'b': DataLoader(range(20), batch_size=5),
    }
    dataloader = ConcatLoader(loaders)
    assert len(dataloader) == 7
    for i, (data, label) in enumerate(dataloader):
        if i < 3:
            assert isinstance(data, list) and len(data) == 2
            assert label == 'a'
        else:
            assert label == 'b'
@pytest.mark.parametrize('replace_sampler_ddp', [False, True])
@pytest.mark.parametrize('is_min_size_mode', [True])
@pytest.mark.parametrize('num_devices', ['auto', 1, 3, 10])
def test_concat_loader_with_ddp(
    replace_sampler_ddp: bool, is_min_size_mode: bool, num_devices: Union[int, str]
):
    """Inspired by tests/trainer/test_supporters.py in lightning."""
    from nni.retiarii.oneshot.pytorch.dataloader import ConcatLoader

    mode = 'min_size' if is_min_size_mode else 'max_size_cycle'
    dim = 3
    n1 = 8
    n2 = 6
    n3 = 9
    dataloader = ConcatLoader({
        'a': {
            'a1': DataLoader(RandomDataset(dim, n1), batch_size=1),
            'a2': DataLoader(RandomDataset(dim, n2), batch_size=1),
        },
        'b': DataLoader(RandomDataset(dim, n3), batch_size=1),
    }, mode=mode)
    expected_length_before_ddp = n3 + (min(n1, n2) if is_min_size_mode else max(n1, n2))
    print(len(dataloader))
    assert len(dataloader) == expected_length_before_ddp
    model = BoringModel()
    trainer = Trainer(
        strategy='ddp',
        accelerator='auto',
        devices=num_devices,
        replace_sampler_ddp=replace_sampler_ddp,
    )
    trainer._data_connector.attach_data(
        model=model, train_dataloaders=dataloader, val_dataloaders=None, datamodule=None
    )
    expected_length_after_ddp = (
        math.ceil(n3 / trainer.num_devices) +
        math.ceil((min(n1, n2) if is_min_size_mode else max(n1, n2)) / trainer.num_devices)
        if replace_sampler_ddp
        else expected_length_before_ddp
    )
    print('Num devices =', trainer.num_devices)
    trainer.reset_train_dataloader(model=model)
    assert trainer.train_dataloader is not None
    assert trainer.train_dataloader.mode == mode
    assert trainer.num_training_batches == expected_length_after_ddp