"tests/vscode:/vscode.git/clone" did not exist on "73acebb8cfbd1d2954cabe1af4185f9994e61917"
Unverified Commit 802650ff authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Miscellaneous fixes of NAS (v2.9) (#5051)

parent cd98c48f
...@@ -114,7 +114,7 @@ class NasBench101TrainingModule(pl.LightningModule): ...@@ -114,7 +114,7 @@ class NasBench101TrainingModule(pl.LightningModule):
momentum=0.9, alpha=0.9, eps=1.0) momentum=0.9, alpha=0.9, eps=1.0)
return { return {
'optimizer': optimizer, 'optimizer': optimizer,
'scheduler': CosineAnnealingLR(optimizer, self.hparams.max_epochs) 'lr_scheduler': CosineAnnealingLR(optimizer, self.hparams.max_epochs)
} }
def on_validation_epoch_end(self): def on_validation_epoch_end(self):
......
...@@ -103,7 +103,7 @@ class NasBench201TrainingModule(pl.LightningModule): ...@@ -103,7 +103,7 @@ class NasBench201TrainingModule(pl.LightningModule):
momentum=0.9, alpha=0.9, eps=1.0) momentum=0.9, alpha=0.9, eps=1.0)
return { return {
'optimizer': optimizer, 'optimizer': optimizer,
'scheduler': CosineAnnealingLR(optimizer, self.hparams.max_epochs) 'lr_scheduler': CosineAnnealingLR(optimizer, self.hparams.max_epochs)
} }
def on_validation_epoch_end(self): def on_validation_epoch_end(self):
......
...@@ -31,7 +31,12 @@ def load_benchmark(benchmark: str) -> SqliteExtDatabase: ...@@ -31,7 +31,12 @@ def load_benchmark(benchmark: str) -> SqliteExtDatabase:
return _loaded_benchmarks[benchmark] return _loaded_benchmarks[benchmark]
url = DB_URLS[benchmark] url = DB_URLS[benchmark]
local_path = os.path.join(DATABASE_DIR, os.path.basename(url)) local_path = os.path.join(DATABASE_DIR, os.path.basename(url))
load_or_download_file(local_path, url) try:
load_or_download_file(local_path, url)
except FileNotFoundError:
raise FileNotFoundError(
f'Please use `nni.nas.benchmarks.download_benchmark("{benchmark}")` to setup the benchmark first before using it.'
)
_loaded_benchmarks[benchmark] = SqliteExtDatabase(local_path, autoconnect=True) _loaded_benchmarks[benchmark] = SqliteExtDatabase(local_path, autoconnect=True)
return _loaded_benchmarks[benchmark] return _loaded_benchmarks[benchmark]
......
...@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader ...@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
import nni import nni
from ..lightning import LightningModule, _AccuracyWithLogits, Lightning from ..lightning import LightningModule, AccuracyWithLogits, Lightning
from .trainer import Trainer from .trainer import Trainer
__all__ = [ __all__ = [
...@@ -148,7 +148,7 @@ class _ClassificationModule(_MultiModelSupervisedLearningModule): ...@@ -148,7 +148,7 @@ class _ClassificationModule(_MultiModelSupervisedLearningModule):
learning_rate: float = 0.001, learning_rate: float = 0.001,
weight_decay: float = 0., weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam): optimizer: optim.Optimizer = optim.Adam):
super().__init__(criterion, {'acc': _AccuracyWithLogits}, super().__init__(criterion, {'acc': AccuracyWithLogits},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer) learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
......
...@@ -27,7 +27,7 @@ from nni.typehint import Literal ...@@ -27,7 +27,7 @@ from nni.typehint import Literal
__all__ = [ __all__ = [
'LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classification', 'Regression', 'LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classification', 'Regression',
'_AccuracyWithLogits', '_SupervisedLearningModule', '_ClassificationModule', '_RegressionModule', 'SupervisedLearningModule', 'ClassificationModule', 'RegressionModule', 'AccuracyWithLogits',
# FIXME: hack to make it importable for tests # FIXME: hack to make it importable for tests
] ]
...@@ -102,12 +102,15 @@ class Lightning(Evaluator): ...@@ -102,12 +102,15 @@ class Lightning(Evaluator):
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples. Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped. If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
It can be `any types of dataloader supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__. It can be `any types of dataloader supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
fit_kwargs
Keyword arguments passed to ``trainer.fit()``.
""" """
def __init__(self, lightning_module: LightningModule, trainer: Trainer, def __init__(self, lightning_module: LightningModule, trainer: Trainer,
train_dataloaders: Optional[Any] = None, train_dataloaders: Optional[Any] = None,
val_dataloaders: Optional[Any] = None, val_dataloaders: Optional[Any] = None,
train_dataloader: Optional[Any] = None): train_dataloader: Optional[Any] = None,
fit_kwargs: Optional[Dict[str, Any]] = None):
assert isinstance(lightning_module, LightningModule), f'Lightning module must be an instance of {__name__}.LightningModule.' assert isinstance(lightning_module, LightningModule), f'Lightning module must be an instance of {__name__}.LightningModule.'
if train_dataloader is not None: if train_dataloader is not None:
warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning) warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
...@@ -117,7 +120,7 @@ class Lightning(Evaluator): ...@@ -117,7 +120,7 @@ class Lightning(Evaluator):
else: else:
# this is not isinstance(trainer, Trainer) because with a different trace call, it can be different # this is not isinstance(trainer, Trainer) because with a different trace call, it can be different
assert (isinstance(trainer, pl.Trainer) and is_traceable(trainer)) or isinstance(trainer, cgo_trainer.Trainer), \ assert (isinstance(trainer, pl.Trainer) and is_traceable(trainer)) or isinstance(trainer, cgo_trainer.Trainer), \
f'Trainer must be imported from {__name__} or nni.nas.evaluator.pytorch.cgo.trainer' f'Trainer must be imported from {__name__} or nni.retiarii.evaluator.pytorch.cgo.trainer'
if not _check_dataloader(train_dataloaders): if not _check_dataloader(train_dataloaders):
warnings.warn(f'Please try to wrap PyTorch DataLoader with nni.trace or ' warnings.warn(f'Please try to wrap PyTorch DataLoader with nni.trace or '
f'import DataLoader from {__name__}: {train_dataloaders}', f'import DataLoader from {__name__}: {train_dataloaders}',
...@@ -130,6 +133,7 @@ class Lightning(Evaluator): ...@@ -130,6 +133,7 @@ class Lightning(Evaluator):
self.trainer = trainer self.trainer = trainer
self.train_dataloaders = train_dataloaders self.train_dataloaders = train_dataloaders
self.val_dataloaders = val_dataloaders self.val_dataloaders = val_dataloaders
self.fit_kwargs = fit_kwargs or {}
@staticmethod @staticmethod
def _load(ir): def _load(ir):
...@@ -178,7 +182,7 @@ class Lightning(Evaluator): ...@@ -178,7 +182,7 @@ class Lightning(Evaluator):
The model to fit. The model to fit.
""" """
self.module.set_model(model) self.module.set_model(model)
return self.trainer.fit(self.module, self.train_dataloaders, self.val_dataloaders) return self.trainer.fit(self.module, self.train_dataloaders, self.val_dataloaders, **self.fit_kwargs)
def _check_dataloader(dataloader): def _check_dataloader(dataloader):
...@@ -194,7 +198,7 @@ def _check_dataloader(dataloader): ...@@ -194,7 +198,7 @@ def _check_dataloader(dataloader):
### The following are some commonly used Lightning modules ### ### The following are some commonly used Lightning modules ###
class _SupervisedLearningModule(LightningModule): class SupervisedLearningModule(LightningModule):
trainer: pl.Trainer trainer: pl.Trainer
...@@ -273,19 +277,19 @@ class _SupervisedLearningModule(LightningModule): ...@@ -273,19 +277,19 @@ class _SupervisedLearningModule(LightningModule):
return {name: self.trainer.callback_metrics['val_' + name].item() for name in self.metrics} return {name: self.trainer.callback_metrics['val_' + name].item() for name in self.metrics}
class _AccuracyWithLogits(torchmetrics.Accuracy): class AccuracyWithLogits(torchmetrics.Accuracy):
def update(self, pred, target): def update(self, pred, target):
return super().update(nn_functional.softmax(pred, dim=-1), target) return super().update(nn_functional.softmax(pred, dim=-1), target)
@nni.trace @nni.trace
class _ClassificationModule(_SupervisedLearningModule): class ClassificationModule(SupervisedLearningModule):
def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss, def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
learning_rate: float = 0.001, learning_rate: float = 0.001,
weight_decay: float = 0., weight_decay: float = 0.,
optimizer: Type[optim.Optimizer] = optim.Adam, optimizer: Type[optim.Optimizer] = optim.Adam,
export_onnx: bool = True): export_onnx: bool = True):
super().__init__(criterion, {'acc': _AccuracyWithLogits}, super().__init__(criterion, {'acc': AccuracyWithLogits},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer, learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
export_onnx=export_onnx) export_onnx=export_onnx)
...@@ -341,14 +345,14 @@ class Classification(Lightning): ...@@ -341,14 +345,14 @@ class Classification(Lightning):
if train_dataloader is not None: if train_dataloader is not None:
warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning) warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
train_dataloaders = train_dataloader train_dataloaders = train_dataloader
module = _ClassificationModule(criterion=criterion, learning_rate=learning_rate, module = ClassificationModule(criterion=criterion, learning_rate=learning_rate,
weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx) weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
super().__init__(module, Trainer(**trainer_kwargs), super().__init__(module, Trainer(**trainer_kwargs),
train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders) train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)
@nni.trace @nni.trace
class _RegressionModule(_SupervisedLearningModule): class RegressionModule(SupervisedLearningModule):
def __init__(self, criterion: Type[nn.Module] = nn.MSELoss, def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
learning_rate: float = 0.001, learning_rate: float = 0.001,
weight_decay: float = 0., weight_decay: float = 0.,
...@@ -406,7 +410,14 @@ class Regression(Lightning): ...@@ -406,7 +410,14 @@ class Regression(Lightning):
if train_dataloader is not None: if train_dataloader is not None:
warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning) warnings.warn('`train_dataloader` is deprecated and replaced with `train_dataloaders`.', DeprecationWarning)
train_dataloaders = train_dataloader train_dataloaders = train_dataloader
module = _RegressionModule(criterion=criterion, learning_rate=learning_rate, module = RegressionModule(criterion=criterion, learning_rate=learning_rate,
weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx) weight_decay=weight_decay, optimizer=optimizer, export_onnx=export_onnx)
super().__init__(module, Trainer(**trainer_kwargs), super().__init__(module, Trainer(**trainer_kwargs),
train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders) train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)
# Alias for backwards compatibility
_SupervisedLearningModule = SupervisedLearningModule
_AccuracyWithLogits = AccuracyWithLogits
_ClassificationModule = ClassificationModule
_RegressionModule = RegressionModule
...@@ -13,7 +13,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, ...@@ -13,7 +13,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Iterable, List,
Optional, Set, Tuple, Type, Union, cast, overload) Optional, Set, Tuple, Type, Union, cast, overload)
if TYPE_CHECKING: if TYPE_CHECKING:
from .mutator import Mutator from nni.nas.mutable import Mutator
from nni.nas.evaluator import Evaluator from nni.nas.evaluator import Evaluator
from nni.nas.utils import uid from nni.nas.utils import uid
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
# If you've seen lint errors like `"Sequential" is not a known member of module`,
# please run `python test/vso_tools/trigger_import.py` to generate `_layers.py`.
from pathlib import Path from pathlib import Path
# To make auto-completion happy, we generate a _layers.py that lists out all the classes. # To make auto-completion happy, we generate a _layers.py that lists out all the classes.
......
...@@ -152,7 +152,7 @@ def _new_trainer(): ...@@ -152,7 +152,7 @@ def _new_trainer():
train_dataset = serialize(MNIST, root='data/mnist', train=True, download=True, transform=transform) train_dataset = serialize(MNIST, root='data/mnist', train=True, download=True, transform=transform)
test_dataset = serialize(MNIST, root='data/mnist', train=False, download=True, transform=transform) test_dataset = serialize(MNIST, root='data/mnist', train=False, download=True, transform=transform)
multi_module = _MultiModelSupervisedLearningModule(nn.CrossEntropyLoss, {'acc': pl._AccuracyWithLogits}) multi_module = _MultiModelSupervisedLearningModule(nn.CrossEntropyLoss, {'acc': pl.AccuracyWithLogits})
lightning = pl.Lightning(multi_module, cgo_trainer.Trainer(use_cgo=True, lightning = pl.Lightning(multi_module, cgo_trainer.Trainer(use_cgo=True,
max_epochs=1, max_epochs=1,
...@@ -201,7 +201,7 @@ class CGOEngineTest(unittest.TestCase): ...@@ -201,7 +201,7 @@ class CGOEngineTest(unittest.TestCase):
train_dataset = serialize(MNIST, root='data/mnist', train=True, download=True, transform=transform) train_dataset = serialize(MNIST, root='data/mnist', train=True, download=True, transform=transform)
test_dataset = serialize(MNIST, root='data/mnist', train=False, download=True, transform=transform) test_dataset = serialize(MNIST, root='data/mnist', train=False, download=True, transform=transform)
multi_module = _MultiModelSupervisedLearningModule(nn.CrossEntropyLoss, {'acc': pl._AccuracyWithLogits}, n_models=2) multi_module = _MultiModelSupervisedLearningModule(nn.CrossEntropyLoss, {'acc': pl.AccuracyWithLogits}, n_models=2)
lightning = pl.Lightning(multi_module, cgo_trainer.Trainer(use_cgo=True, lightning = pl.Lightning(multi_module, cgo_trainer.Trainer(use_cgo=True,
max_epochs=1, max_epochs=1,
...@@ -225,7 +225,7 @@ class CGOEngineTest(unittest.TestCase): ...@@ -225,7 +225,7 @@ class CGOEngineTest(unittest.TestCase):
train_dataset = serialize(MNIST, root='data/mnist', train=True, download=True, transform=transform) train_dataset = serialize(MNIST, root='data/mnist', train=True, download=True, transform=transform)
test_dataset = serialize(MNIST, root='data/mnist', train=False, download=True, transform=transform) test_dataset = serialize(MNIST, root='data/mnist', train=False, download=True, transform=transform)
multi_module = _MultiModelSupervisedLearningModule(nn.CrossEntropyLoss, {'acc': pl._AccuracyWithLogits}, n_models=2) multi_module = _MultiModelSupervisedLearningModule(nn.CrossEntropyLoss, {'acc': pl.AccuracyWithLogits}, n_models=2)
lightning = pl.Lightning(multi_module, cgo_trainer.Trainer(use_cgo=True, lightning = pl.Lightning(multi_module, cgo_trainer.Trainer(use_cgo=True,
max_epochs=1, max_epochs=1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment