"examples/vscode:/vscode.git/clone" did not exist on "34e0883db09c4a66827126aeac5cf0dd66b5f1ef"
Unverified commit 39ec21ca authored by Frandium, committed by GitHub

Multi-GPU support of one-shot NAS (#4603)

parent b4559f60
......@@ -4,7 +4,7 @@
import os
import warnings
from pathlib import Path
from typing import Dict, Union, Optional, List, Callable, Type
from typing import Any, Dict, Union, Optional, List, Callable, Type
import pytorch_lightning as pl
import torch.nn as nn
......@@ -22,6 +22,7 @@ except ImportError:
cgo_import_failed = True
from nni.retiarii.graph import Evaluator
from nni.typehint import Literal
__all__ = ['LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classification', 'Regression']
......@@ -36,6 +37,11 @@ class LightningModule(pl.LightningModule):
See https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html
"""
running_mode: Literal['multi', 'oneshot'] = 'multi'
"""An indicator of whether current module is running in a multi-trial experiment or an one-shot.
This flag should be automatically set by experiments when they start to run.
"""
def set_model(self, model: Union[Callable[[], nn.Module], nn.Module]) -> None:
"""Set the inner model (architecture) to train / evaluate.
......@@ -59,6 +65,7 @@ DataLoader.__doc__ = """
Traced version of ``torch.utils.data.DataLoader``. See https://pytorch.org/docs/stable/data.html
"""
@nni.trace
class Lightning(Evaluator):
"""
......@@ -74,51 +81,67 @@ class Lightning(Evaluator):
Parameters
----------
lightning_module : LightningModule
lightning_module
Lightning module that defines the training logic.
trainer : Trainer
trainer
Lightning trainer that handles the training.
train_dataloders : DataLoader
train_dataloaders
Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
If the ``lightning_module`` has a predefined train_dataloader method, this will be skipped.
val_dataloaders : DataLoader or List of DataLoader
It can be `any type of dataloader supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
val_dataloaders
Used in ``trainer.fit()``. Either a single PyTorch DataLoader or a list of them, specifying validation samples.
If the ``lightning_module`` has a predefined val_dataloaders method, this will be skipped.
It can be `any type of dataloader supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
"""
def __init__(self, lightning_module: LightningModule, trainer: Trainer,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None):
train_dataloaders: Optional[Any] = None,
val_dataloaders: Optional[Any] = None,
train_dataloader: Optional[Any] = None):
assert isinstance(lightning_module, LightningModule), f'Lightning module must be an instance of {__name__}.LightningModule.'
if train_dataloader is not None:
warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
train_dataloaders = train_dataloader
if cgo_import_failed:
assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), f'Trainer must be imported from {__name__}'
else:
# not isinstance(trainer, Trainer), because ``Trainer`` comes from a trace call, and a different trace call can produce a different class
assert (isinstance(trainer, pl.Trainer) and is_traceable(trainer)) or isinstance(trainer, cgo_trainer.Trainer), \
f'Trainer must be imported from {__name__} or nni.retiarii.evaluator.pytorch.cgo.trainer'
assert _check_dataloader(train_dataloader), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
assert _check_dataloader(val_dataloaders), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
if not _check_dataloader(train_dataloaders):
warnings.warn(f'Please try to wrap PyTorch DataLoader with nni.trace or '
f'import DataLoader from {__name__}: {train_dataloaders}',
RuntimeWarning)
if not _check_dataloader(val_dataloaders):
warnings.warn(f'Please try to wrap PyTorch DataLoader with nni.trace or '
f'import DataLoader from {__name__}: {val_dataloaders}',
RuntimeWarning)
self.module = lightning_module
self.trainer = trainer
self.train_dataloader = train_dataloader
self.train_dataloaders = train_dataloaders
self.val_dataloaders = val_dataloaders
@staticmethod
def _load(ir):
return Lightning(ir['module'], ir['trainer'], ir['train_dataloader'], ir['val_dataloaders'])
return Lightning(ir['module'], ir['trainer'], ir['train_dataloaders'], ir['val_dataloaders'])
def _dump(self):
return {
'type': self.__class__,
'module': self.module,
'trainer': self.trainer,
'train_dataloader': self.train_dataloader,
'train_dataloaders': self.train_dataloaders,
'val_dataloaders': self.val_dataloaders
}
def _execute(self, model_cls):
return self.fit(model_cls)
@property
def train_dataloader(self):
warnings.warn('train_dataloader is deprecated, please use `train_dataloaders`.', DeprecationWarning)
return self.train_dataloaders
def __eq__(self, other):
eq_func = False
eq_args = False
......@@ -146,15 +169,18 @@ class Lightning(Evaluator):
The model to fit.
"""
self.module.set_model(model)
return self.trainer.fit(self.module, self.train_dataloader, self.val_dataloaders)
return self.trainer.fit(self.module, self.train_dataloaders, self.val_dataloaders)
def _check_dataloader(dataloader):
if dataloader is None:
return True
# Check the type of dataloader recursively.
if isinstance(dataloader, list):
return all([_check_dataloader(d) for d in dataloader])
return isinstance(dataloader, torch_data.DataLoader) and is_traceable(dataloader)
if isinstance(dataloader, dict):
return all([_check_dataloader(v) for v in dataloader.values()])
if isinstance(dataloader, torch_data.DataLoader):
return is_traceable(dataloader)
return True
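For reference, a minimal sketch (not part of this diff) of what the recursive check above accepts; wrapping a DataLoader with ``nni.trace`` is what makes it traceable:

import nni
from torch.utils.data import DataLoader as RawDataLoader

traced = nni.trace(RawDataLoader)(range(10), batch_size=2)   # traceable: passes silently
plain = RawDataLoader(range(10), batch_size=2)               # not traceable: triggers the warning in Lightning.__init__
nested = {'train': [traced, traced]}                         # lists and dicts are validated recursively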
### The following are some commonly used Lightning modules ###
......@@ -176,7 +202,6 @@ class _SupervisedLearningModule(LightningModule):
if export_onnx is None or export_onnx is True:
self.export_onnx = Path(os.environ.get('NNI_OUTPUT_DIR', '.')) / 'model.onnx'
self.export_onnx.parent.mkdir(exist_ok=True)
elif export_onnx:
self.export_onnx = Path(export_onnx)
else:
......@@ -199,7 +224,8 @@ class _SupervisedLearningModule(LightningModule):
x, y = batch
y_hat = self(x)
if self.export_onnx is not None:
if self.running_mode == 'multi' and self.export_onnx is not None:
self.export_onnx.parent.mkdir(exist_ok=True)
try:
self.to_onnx(self.export_onnx, x, export_params=True)
except RuntimeError as e:
......@@ -221,9 +247,11 @@ class _SupervisedLearningModule(LightningModule):
return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay) # type: ignore
def on_validation_epoch_end(self):
if self.running_mode == 'multi':
nni.report_intermediate_result(self._get_validation_metrics())
def on_fit_end(self):
if self.running_mode == 'multi':
nni.report_final_result(self._get_validation_metrics())
def _get_validation_metrics(self):
......@@ -283,14 +311,18 @@ class Classification(Lightning):
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: Type[optim.Optimizer] = optim.Adam,
train_dataloader: Optional[DataLoader] = None,
train_dataloaders: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
export_onnx: bool = True,
train_dataloader: Optional[DataLoader] = None,
**trainer_kwargs):
if train_dataloader is not None:
warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
train_dataloaders = train_dataloader
module = _ClassificationModule(criterion=criterion, learning_rate=learning_rate,
weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
super().__init__(module, Trainer(**trainer_kwargs),
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)
@nni.trace
......@@ -336,11 +368,15 @@ class Regression(Lightning):
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: Type[optim.Optimizer] = optim.Adam,
train_dataloader: Optional[DataLoader] = None,
train_dataloaders: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
export_onnx: bool = True,
train_dataloader: Optional[DataLoader] = None,
**trainer_kwargs):
if train_dataloader is not None:
warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
train_dataloaders = train_dataloader
module = _RegressionModule(criterion=criterion, learning_rate=learning_rate,
weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
super().__init__(module, Trainer(**trainer_kwargs),
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)
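To illustrate the renamed argument, here is a minimal usage sketch (not part of this diff, assuming an MNIST-style setup): the new ``train_dataloaders`` keyword is preferred, while the old ``train_dataloader`` keyword still works but emits a DeprecationWarning.

from torchvision import transforms
from torchvision.datasets import MNIST
from nni.retiarii.evaluator.pytorch.lightning import Classification, DataLoader

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_set = MNIST('data/mnist', train=True, download=True, transform=transform)
val_set = MNIST('data/mnist', train=False, download=True, transform=transform)

# DataLoader is the traced version re-exported by this module, so the
# dataloader check above passes without warnings.
evaluator = Classification(
    train_dataloaders=DataLoader(train_set, batch_size=64),
    val_dataloaders=DataLoader(val_set, batch_size=64),
    max_epochs=1,
)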
......@@ -18,6 +18,7 @@ import nni.retiarii.nn.pytorch as nas_nn
from nni.common.hpo_utils import ParameterSpec
from nni.common.serializer import is_traceable
from nni.retiarii.nn.pytorch.api import ValueChoiceX
from nni.typehint import Literal
from .supermodule.base import BaseSuperNetModule
__all__ = ['MutationHook', 'BaseSuperNetModule', 'BaseOneShotLightningModule', 'traverse_and_mutate_submodules']
......@@ -334,21 +335,21 @@ class BaseOneShotLightningModule(pl.LightningModule):
return arc_optimizers + w_optimizers, lr_schedulers
def on_train_start(self):
# redirect the access to trainer/log to this module
# but note that we might be missing other attributes,
# which could potentially be a problem
self.model.trainer = self.trainer # type: ignore
self.model.log = self.log
return self.model.on_train_start()
def on_train_end(self):
return self.model.on_train_end()
def on_fit_start(self):
return self.model.on_train_start()
# redirect the access to trainer/log to this module
# but note that we might be missing other attributes,
# which could potentially be a problem
self.model.trainer = self.trainer # type: ignore
self.model.log = self.log
return self.model.on_fit_start()
def on_fit_end(self):
return self.model.on_train_end()
return self.model.on_fit_end()
def on_train_batch_start(self, batch, batch_idx, unused=0):
return self.model.on_train_batch_start(batch, batch_idx, unused)
......@@ -356,6 +357,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
def on_train_batch_end(self, outputs, batch, batch_idx, unused=0):
return self.model.on_train_batch_end(outputs, batch, batch_idx, unused)
# Deprecated hooks in pytorch-lightning
def on_epoch_start(self):
return self.model.on_epoch_start()
......@@ -427,7 +429,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
else:
apply(lr_schedulers)
def call_weight_optimizers(self, method):
def call_weight_optimizers(self, method: Literal['step', 'zero_grad']):
"""
Imitates the Lightning trainer's behavior of calling the user's optimizers. Since automatic optimization is turned off by this
class, you can use this function to make user optimizers behave as if they were automatically handled by the Lightning trainer.
......
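As a hedged sketch (not from this commit) of how ``call_weight_optimizers`` is typically used inside a manual-optimization ``training_step``, mirroring the ENAS module further below; the loss handling is illustrative:

def training_step(self, batch, batch_idx):
    # automatic optimization is off, so weight optimizers are driven manually
    self.call_weight_optimizers('zero_grad')
    step_output = self.model.training_step(batch, batch_idx)
    loss = step_output['loss'] if isinstance(step_output, dict) else step_output
    self.manual_backward(loss)
    self.call_weight_optimizers('step')
    return step_output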
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from typing import Any
from pytorch_lightning.trainer.supporters import CombinedLoader, CombinedLoaderIterator
class ConcatLoader(CombinedLoader):
"""This loader is same as CombinedLoader in PyTorch-Lightning, but concatenate sub-loaders
instead of loading them in parallel.
Parameters
----------
loaders
For example, ::
{
"train": DataLoader(train_dataset),
"val": DataLoader(val_dataset)
}
In this example, the loader will first produce the batches from "train", then "val".
mode
Only support "min_size" for now.
"""
def __init__(self, loaders: dict[str, Any], mode: str = 'min_size'):
# FIXME: max_cycle will make dataloaders cycle iterators,
# causing extra problems.
if mode != 'min_size':
raise ValueError('Only min_size mode is supported now.')
super().__init__(loaders, mode)
def __iter__(self) -> Any:
"""Replace the super-class iterator with ours."""
self._try_to_patch_pytorch_dataloader()
iterator = ConcatLoaderIterator(self.loaders)
# handle fault tolerant restart.
self.on_restart(iterator)
self._iterator = iterator
return iterator
@staticmethod
def _try_to_patch_pytorch_dataloader():
"""Copied from CombinedLoader."""
from torch.utils.data.dataloader import _BaseDataLoaderIter
# prevent `NotImplementedError` from PyTorch:
# https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/dataloader.py#L541
def __getstate__patch__(*_):
return {}
_BaseDataLoaderIter.__getstate__ = __getstate__patch__ # type: ignore
def __len__(self) -> int:
return int(sum(self._calc_num_batches(loader) for loader in self.loaders.values()))
class ConcatLoaderIterator(CombinedLoaderIterator):
"""Similar to CombinedLoaderIterator in Lightning, but in a concat manner."""
def __next__(self) -> Any:
"""Fetches the next batch from multiple data loaders,
by looking for the first iterator that isn't exhausted yet.
"""
if len(self.loader_iters) != len(self.loaders):
raise RuntimeError('loader_iters must have the same length as loaders.')
for i, (loader_name, iterator) in enumerate(self.loader_iters.items()):
try:
return (self.request_next_batch(iterator), loader_name)
except StopIteration:
if i + 1 == len(self.loader_iters):
raise
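For reference, a short sketch (mirroring the unit test added later in this commit) of what iterating a ``ConcatLoader`` yields: each item is a ``(batch, loader_name)`` tuple, and the sub-loaders are exhausted one after another instead of in parallel.

from torch.utils.data import DataLoader
from nni.retiarii.oneshot.pytorch.dataloader import ConcatLoader

loader = ConcatLoader({
    'train': DataLoader(range(10), batch_size=4),
    'val': DataLoader(range(20), batch_size=5),
})
assert len(loader) == 3 + 4  # batches are concatenated, not zipped
for batch, name in loader:
    # all 'train' batches come first, followed by all 'val' batches
    ...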
......@@ -75,8 +75,9 @@ class DartsLightningModule(BaseOneShotLightningModule):
if not isinstance(arc_optim, optim.Optimizer):
raise TypeError(f'Expect arc_optim to be a single Optimizer, but found: {arc_optim}')
# The InterleavedTrainValDataLoader yields both train and val data in a batch
trn_batch, val_batch = batch
# The DARTS strategy ensures that both ``train`` and ``val`` are present in the batch.
trn_batch = batch['train']
val_batch = batch['val']
# phase 1: architecture step
# The _resample hook is kept for some darts-based NAS methods like proxyless.
......
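For context, a minimal illustration (outside this diff) of why the dict lookup above works: when a strategy returns a dict of loaders (see ``preprocess_dataloader`` later in this commit), Lightning wraps it in a ``CombinedLoader``, and every batch arrives as a dict with the same keys.

from pytorch_lightning.trainer.supporters import CombinedLoader
from torch.utils.data import DataLoader

combined = CombinedLoader({
    'train': DataLoader(range(8), batch_size=2),
    'val': DataLoader(range(8), batch_size=2),
}, mode='min_size')

for batch in combined:
    trn_batch, val_batch = batch['train'], batch['val']  # both sub-batches arrive together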
......@@ -133,29 +133,30 @@ class EnasLightningModule(RandomSamplingLightningModule):
def configure_architecture_optimizers(self):
return optim.Adam(self.controller.parameters(), lr=3.5e-4)
def training_step(self, batch, batch_idx):
# The ConcatenateTrainValDataloader yields both data and which dataloader it comes from.
batch, source = batch
def training_step(self, batch_packed, batch_idx):
batch, mode = batch_packed
if source == 'train':
# step 1: train model params
if mode == 'train':
# train model params
with torch.no_grad():
self.resample()
self.call_weight_optimizers('zero_grad')
loss_and_metrics = self.model.training_step(batch, batch_idx)
w_step_loss = loss_and_metrics['loss'] \
if isinstance(loss_and_metrics, dict) else loss_and_metrics
step_output = self.model.training_step(batch, batch_idx)
w_step_loss = step_output['loss'] \
if isinstance(step_output, dict) else step_output
self.manual_backward(w_step_loss)
self.call_weight_optimizers('step')
return loss_and_metrics
if source == 'val':
# step 2: train ENAS agent
else:
# train ENAS agent
arc_opt = self.architecture_optimizers()
if not isinstance(arc_opt, optim.Optimizer):
raise TypeError(f'Expect arc_opt to be a single Optimizer, but found: {arc_opt}')
arc_opt.zero_grad()
self.resample()
self.model.validation_step(batch, batch_idx)
step_output = self.model.validation_step(batch, batch_idx)
# use the default metric of self.model as reward function
if len(self.trainer.callback_metrics) == 1:
_, metric = next(iter(self.trainer.callback_metrics.items()))
......@@ -163,7 +164,9 @@ class EnasLightningModule(RandomSamplingLightningModule):
metric_name = self.reward_metric_name or 'default'
if metric_name not in self.trainer.callback_metrics:
raise KeyError(f'Model reported metrics should contain a ``{metric_name}`` key but '
f'found multiple metrics without default: {self.trainer.callback_metrics.keys()}')
f'found multiple (or zero) metrics without a default: {list(self.trainer.callback_metrics.keys())}. '
f'Try using self.log to report a metric with the key ``{metric_name}`` in validation_step, '
'and remember to set on_step=True.')
metric = self.trainer.callback_metrics[metric_name]
reward: float = metric.item()
......@@ -183,6 +186,8 @@ class EnasLightningModule(RandomSamplingLightningModule):
arc_opt.step()
arc_opt.zero_grad()
return step_output
def resample(self):
"""Resample the architecture with ENAS controller."""
sample = self.controller.resample()
......
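A hedged sketch of a user ``validation_step`` that satisfies the requirement described in the error message above; the metric name ``val_acc`` is an assumption matching ``reward_metric_name='val_acc'``, and ``on_step=True`` makes the value visible in ``trainer.callback_metrics`` right after each validation batch, which is when the ENAS controller reads its reward:

def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    acc = (y_hat.argmax(dim=-1) == y).float().mean()
    # on_step=True exposes the value per batch; prog_bar is optional
    self.log('val_acc', acc, on_step=True, prog_bar=True)
    return acc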
......@@ -16,7 +16,6 @@ import warnings
from typing import Any, Type
import torch.nn as nn
from torch.utils.data import DataLoader
from nni.retiarii.graph import Model
from nni.retiarii.strategy.base import BaseStrategy
......@@ -25,7 +24,6 @@ from nni.retiarii.evaluator.pytorch.lightning import Lightning, LightningModule
from .base_lightning import BaseOneShotLightningModule
from .differentiable import DartsLightningModule, ProxylessLightningModule, GumbelDartsLightningModule
from .sampling import EnasLightningModule, RandomSamplingLightningModule
from .utils import InterleavedTrainValDataLoader, ConcatenateTrainValDataLoader
class OneShotStrategy(BaseStrategy):
......@@ -37,15 +35,18 @@ class OneShotStrategy(BaseStrategy):
self.model: BaseOneShotLightningModule | None = None
def _get_dataloader(self, train_dataloader: DataLoader, val_dataloaders: DataLoader | list[DataLoader]) \
-> DataLoader | tuple[DataLoader, DataLoader]:
def preprocess_dataloader(self, train_dataloaders: Any, val_dataloaders: Any) -> tuple[Any, Any]:
"""
One-shot strategy typically requires a customized dataloader.
If only train dataloader is produced, return one dataloader.
Otherwise, return train dataloader and valid loader as a tuple.
One-shot strategies typically require fusing the train and validation dataloaders in an ad-hoc way.
As the one-shot strategy doesn't try to open the black box of a batch,
these dataloaders can theoretically be
`any dataloader types supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
Returns
-------
A tuple of preprocessed train dataloaders and validation dataloaders.
"""
raise NotImplementedError()
return train_dataloaders, val_dataloaders
def run(self, base_model: Model, applied_mutators):
# one-shot strategy doesn't use ``applied_mutators``
......@@ -64,18 +65,15 @@ class OneShotStrategy(BaseStrategy):
raise TypeError('Evaluator needs to be a lightning evaluator to make one-shot strategy work.')
evaluator_module: LightningModule = base_model.evaluator.module
evaluator_module.running_mode = 'oneshot'
evaluator_module.set_model(py_model)
self.model = self.oneshot_module(evaluator_module, **self.oneshot_kwargs)
evaluator: Lightning = base_model.evaluator
if evaluator.train_dataloader is None or evaluator.val_dataloaders is None:
raise TypeError('Train or val dataloader is not set.')
dataloader = self._get_dataloader(evaluator.train_dataloader, evaluator.val_dataloaders)
if isinstance(dataloader, tuple):
dataloader, val_loader = dataloader
evaluator.trainer.fit(self.model, dataloader, val_loader)
else:
evaluator.trainer.fit(self.model, dataloader)
if evaluator.train_dataloaders is None or evaluator.val_dataloaders is None:
raise TypeError('Both training and validation dataloaders must be set in the evaluator for one-shot strategy.')
train_loader, val_loader = self.preprocess_dataloader(evaluator.train_dataloaders, evaluator.val_dataloaders)
evaluator.trainer.fit(self.model, train_loader, val_loader)
def export_top_models(self, top_k: int = 1) -> list[Any]:
if self.model is None:
......@@ -91,8 +89,12 @@ class DARTS(OneShotStrategy):
def __init__(self, **kwargs):
super().__init__(DartsLightningModule, **kwargs)
def _get_dataloader(self, train_dataloader, val_dataloaders):
return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders)
def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
# Returning a dict makes Lightning wrap the loaders in a CombinedLoader.
return {
'train': train_dataloaders,
'val': val_dataloaders
}, None
class Proxyless(OneShotStrategy):
......@@ -101,8 +103,11 @@ class Proxyless(OneShotStrategy):
def __init__(self, **kwargs):
super().__init__(ProxylessLightningModule, **kwargs)
def _get_dataloader(self, train_dataloader, val_dataloaders):
return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders)
def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
return {
'train': train_dataloaders,
'val': val_dataloaders
}, None
class GumbelDARTS(OneShotStrategy):
......@@ -111,8 +116,11 @@ class GumbelDARTS(OneShotStrategy):
def __init__(self, **kwargs):
super().__init__(GumbelDartsLightningModule, **kwargs)
def _get_dataloader(self, train_dataloader, val_dataloaders):
return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders)
def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
return {
'train': train_dataloaders,
'val': val_dataloaders
}, None
class ENAS(OneShotStrategy):
......@@ -121,8 +129,13 @@ class ENAS(OneShotStrategy):
def __init__(self, **kwargs):
super().__init__(EnasLightningModule, **kwargs)
def _get_dataloader(self, train_dataloader, val_dataloaders):
return ConcatenateTrainValDataLoader(train_dataloader, val_dataloaders)
def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
# Import locally to avoid import error on legacy PL version
from .dataloader import ConcatLoader
return ConcatLoader({
'train': train_dataloaders,
'val': val_dataloaders
}), None
class RandomOneShot(OneShotStrategy):
......@@ -130,6 +143,3 @@ class RandomOneShot(OneShotStrategy):
def __init__(self, **kwargs):
super().__init__(RandomSamplingLightningModule, **kwargs)
def _get_dataloader(self, train_dataloader, val_dataloaders):
return train_dataloader, val_dataloaders
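Putting the pieces together, a hedged end-to-end sketch of what this PR enables: a one-shot strategy can now run on multiple GPUs simply by passing DDP-related trainer arguments to the evaluator (the random tensors stand in for a real dataset, and the public ``nni.retiarii.strategy`` import path is assumed).

import torch
from torch.utils.data import TensorDataset
from nni.retiarii.evaluator.pytorch.lightning import Classification, DataLoader
from nni.retiarii.strategy import DARTS

train_set = TensorDataset(torch.randn(256, 1, 28, 28), torch.randint(0, 10, (256,)))
val_set = TensorDataset(torch.randn(64, 1, 28, 28), torch.randint(0, 10, (64,)))

evaluator = Classification(
    train_dataloaders=DataLoader(train_set, batch_size=64),
    val_dataloaders=DataLoader(val_set, batch_size=64),
    max_epochs=10,
    # forwarded to pytorch_lightning.Trainer, as in the multi-GPU tests below
    strategy='ddp',
    accelerator='gpu',
    devices=torch.cuda.device_count(),
)
oneshot_strategy = DARTS()  # fuses the two dataloaders via preprocess_dataloader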
......@@ -132,6 +132,7 @@ class AverageMeter:
def _replace_module_with_type(root_module, init_fn, type_name, modules):
if modules is None:
modules = []
def apply(m):
for name, child in m.named_children():
if isinstance(child, type_name):
......
......@@ -5,8 +5,6 @@ import logging
import time
from typing import Optional
from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner
from .. import Sampler, submit_models, query_available_resources, is_stopped_exec, budget_exhausted
from .base import BaseStrategy
......@@ -15,6 +13,9 @@ _logger = logging.getLogger(__name__)
class TPESampler(Sampler):
def __init__(self, optimize_mode='minimize'):
# Move import here to eliminate some warning messages about dill.
from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner
self.tpe_tuner = HyperoptTuner('tpe', optimize_mode)
self.cur_sample: Optional[dict] = None
self.index: Optional[int] = None
......
......@@ -15,6 +15,9 @@ from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, ValueChoice
from nni.retiarii.strategy import BaseStrategy
pytestmark = pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
class DepthwiseSeparableConv(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
......@@ -171,7 +174,7 @@ class CustomOpValueChoiceNet(nn.Module):
return F.log_softmax(x, dim=1)
def _mnist_net(type_):
def _mnist_net(type_, evaluator_kwargs):
if type_ == 'simple':
base_model = SimpleNet(False)
elif type_ == 'simple_value_choice':
......@@ -187,17 +190,18 @@ def _mnist_net(type_):
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = MNIST('data/mnist', train=True, download=True, transform=transform)
# The multi-GPU combined dataloader will break this subset sampler. This is expected, though.
train_random_sampler = RandomSampler(train_dataset, True, int(len(train_dataset) / 20))
train_loader = DataLoader(train_dataset, 64, sampler=train_random_sampler)
valid_dataset = MNIST('data/mnist', train=False, download=True, transform=transform)
valid_random_sampler = RandomSampler(valid_dataset, True, int(len(valid_dataset) / 20))
valid_loader = DataLoader(valid_dataset, 64, sampler=valid_random_sampler)
evaluator = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, max_epochs=1)
evaluator = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **evaluator_kwargs)
return base_model, evaluator
def _multihead_attention_net():
def _multihead_attention_net(evaluator_kwargs):
base_model = MultiHeadAttentionNet(1)
class AttentionRandDataset(Dataset):
......@@ -222,19 +226,29 @@ def _multihead_attention_net():
train_loader = DataLoader(train_set, batch_size=32)
val_loader = DataLoader(val_set, batch_size=32)
evaluator = Regression(train_dataloader=train_loader, val_dataloaders=val_loader, max_epochs=1)
evaluator = Regression(train_dataloader=train_loader, val_dataloaders=val_loader, **evaluator_kwargs)
return base_model, evaluator
def _test_strategy(strategy_, support_value_choice=True):
def _test_strategy(strategy_, support_value_choice=True, multi_gpu=False):
evaluator_kwargs = {
'max_epochs': 1
}
if multi_gpu:
evaluator_kwargs.update(
strategy='ddp',
accelerator='gpu',
devices=torch.cuda.device_count()
)
to_test = [
# (model, evaluator), support_or_net
(_mnist_net('simple'), True),
(_mnist_net('simple_value_choice'), support_value_choice),
(_mnist_net('value_choice'), support_value_choice),
(_mnist_net('repeat'), False), # no strategy supports repeat currently
(_mnist_net('custom_op'), False), # this is definitely a NO
(_multihead_attention_net(), support_value_choice),
(_mnist_net('simple', evaluator_kwargs), True),
(_mnist_net('simple_value_choice', evaluator_kwargs), support_value_choice),
(_mnist_net('value_choice', evaluator_kwargs), support_value_choice),
(_mnist_net('repeat', evaluator_kwargs), False), # no strategy supports repeat currently
(_mnist_net('custom_op', evaluator_kwargs), False), # this is definitely a NO
(_multihead_attention_net(evaluator_kwargs), support_value_choice),
]
for (base_model, evaluator), support_or_not in to_test:
......@@ -256,17 +270,19 @@ def _test_strategy(strategy_, support_value_choice=True):
experiment.run(config)
@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_darts():
_test_strategy(strategy.DARTS())
@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() <= 1, reason='Must have multiple GPUs.')
def test_darts_multi_gpu():
_test_strategy(strategy.DARTS(), multi_gpu=True)
def test_proxyless():
_test_strategy(strategy.Proxyless(), False)
@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_enas():
def strategy_fn(base_model, evaluator):
if isinstance(base_model, MultiHeadAttentionNet):
......@@ -276,12 +292,20 @@ def test_enas():
_test_strategy(strategy_fn)
@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() <= 1, reason='Must have multiple GPUs.')
def test_enas_multi_gpu():
def strategy_fn(base_model, evaluator):
if isinstance(base_model, MultiHeadAttentionNet):
return strategy.ENAS(reward_metric_name='val_mse')
return strategy.ENAS(reward_metric_name='val_acc')
_test_strategy(strategy_fn, multi_gpu=True)
def test_random():
_test_strategy(strategy.RandomOneShot())
@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_gumbel_darts():
_test_strategy(strategy.GumbelDARTS())
......
import math
from typing import Union
import pytest
import torch
import pytorch_lightning
from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset
pytestmark = pytest.mark.skipif(pytorch_lightning.__version__ < '1.0', reason='Incompatible APIs')
class RandomDataset(Dataset):
def __init__(self, size, length):
self.len = length
self.data = torch.randn(length, size)
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return self.len
class BoringModel(LightningModule):
def __init__(self):
super().__init__()
self.layer = torch.nn.Linear(32, 2)
def forward(self, x):
return self.layer(x)
def training_step(self, batch, batch_idx):
loss = self(batch).sum()
self.log('train_loss', loss)
return {'loss': loss}
def validation_step(self, batch, batch_idx):
loss = self(batch).sum()
self.log('valid_loss', loss)
def test_step(self, batch, batch_idx):
loss = self(batch).sum()
self.log('test_loss', loss)
def configure_optimizers(self):
return torch.optim.SGD(self.layer.parameters(), lr=0.1)
def test_concat_loader():
from nni.retiarii.oneshot.pytorch.dataloader import ConcatLoader
loaders = {
'a': DataLoader(range(10), batch_size=4),
'b': DataLoader(range(20), batch_size=5),
}
dataloader = ConcatLoader(loaders)
assert len(dataloader) == 7
for i, (data, label) in enumerate(dataloader):
if i < 3:
assert len(data) <= 4
assert label == 'a'
else:
assert len(data) <= 5
assert label == 'b'
def test_concat_loader_nested():
from nni.retiarii.oneshot.pytorch.dataloader import ConcatLoader
loaders = {
'a': [DataLoader(range(10), batch_size=4), DataLoader(range(20), batch_size=6)],
'b': DataLoader(range(20), batch_size=5),
}
dataloader = ConcatLoader(loaders)
assert len(dataloader) == 7
for i, (data, label) in enumerate(dataloader):
if i < 3:
assert isinstance(data, list) and len(data) == 2
assert label == 'a'
else:
assert label == 'b'
@pytest.mark.parametrize('replace_sampler_ddp', [False, True])
@pytest.mark.parametrize('is_min_size_mode', [True])
@pytest.mark.parametrize('num_devices', ['auto', 1, 3, 10])
def test_concat_loader_with_ddp(
replace_sampler_ddp: bool, is_min_size_mode: bool, num_devices: Union[int, str]
):
"""Inspired by tests/trainer/test_supporters.py in lightning."""
from nni.retiarii.oneshot.pytorch.dataloader import ConcatLoader
mode = 'min_size' if is_min_size_mode else 'max_size_cycle'
dim = 3
n1 = 8
n2 = 6
n3 = 9
dataloader = ConcatLoader({
'a': {
'a1': DataLoader(RandomDataset(dim, n1), batch_size=1),
'a2': DataLoader(RandomDataset(dim, n2), batch_size=1),
},
'b': DataLoader(RandomDataset(dim, n3), batch_size=1),
}, mode=mode)
expected_length_before_ddp = n3 + (min(n1, n2) if is_min_size_mode else max(n1, n2))
print(len(dataloader))
assert len(dataloader) == expected_length_before_ddp
model = BoringModel()
trainer = Trainer(
strategy='ddp',
accelerator='auto',
devices=num_devices,
replace_sampler_ddp=replace_sampler_ddp,
)
trainer._data_connector.attach_data(
model=model, train_dataloaders=dataloader, val_dataloaders=None, datamodule=None
)
expected_length_after_ddp = (
math.ceil(n3 / trainer.num_devices) + \
math.ceil((min(n1, n2) if is_min_size_mode else max(n1, n2)) / trainer.num_devices)
if replace_sampler_ddp
else expected_length_before_ddp
)
print('Num devices =', trainer.num_devices)
trainer.reset_train_dataloader(model=model)
assert trainer.train_dataloader is not None
assert trainer.train_dataloader.mode == mode
assert trainer.num_training_batches == expected_length_after_ddp
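As a worked instance of the formula above in min_size mode with replace_sampler_ddp=True: for n1=8, n2=6, n3=9 the combined loader has 9 + min(8, 6) = 15 batches before sharding, and with 3 devices each process sees ceil(9 / 3) + ceil(6 / 3) = 3 + 2 = 5 batches.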