Unverified commit 445e7e0b authored by Yuge Zhang, committed by GitHub

[Retiarii] Rewrite trainer with PyTorch Lightning (#3359)

parent 137830df
@@ -2,13 +2,11 @@
 -f https://download.pytorch.org/whl/torch_stable.html
 tensorflow
 # PyTorch 1.7 has compatibility issue with model compression.
 # Check for MacOS because this file is used on all platforms.
 torch == 1.6.0+cpu ; sys_platform != "darwin"
 torch == 1.6.0 ; sys_platform == "darwin"
 torchvision == 0.7.0+cpu ; sys_platform != "darwin"
 torchvision == 0.7.0 ; sys_platform == "darwin"
+pytorch-lightning
 onnx
 peewee
 graphviz
@@ -2,6 +2,11 @@
 tensorflow == 1.15.4
 torch == 1.5.1+cpu
 torchvision == 0.6.1+cpu
+# It will install pytorch-lightning 0.8.x and unit tests won't work.
+# Latest version has conflict with tensorboard and tensorflow 1.x.
+pytorch-lightning
 keras == 2.1.6
 onnx
 peewee
...
@@ -42,10 +42,16 @@ Graph Mutation APIs

Trainers
--------

-..  autoclass:: nni.retiarii.trainer.pytorch.PyTorchImageClassificationTrainer
+..  autoclass:: nni.retiarii.trainer.FunctionalTrainer
    :members:

-..  autoclass:: nni.retiarii.trainer.pytorch.PyTorchMultiModelTrainer
+..  autoclass:: nni.retiarii.trainer.pytorch.lightning.LightningModule
+   :members:
+
+..  autoclass:: nni.retiarii.trainer.pytorch.lightning.Classification
+   :members:
+
+..  autoclass:: nni.retiarii.trainer.pytorch.lightning.Regression
    :members:

Oneshot Trainers

@@ -75,8 +81,8 @@ Strategies

Retiarii Experiments
--------------------

-..  autoclass:: nni.retiarii.experiment.RetiariiExperiment
+..  autoclass:: nni.retiarii.experiment.pytorch.RetiariiExperiment
    :members:

-..  autoclass:: nni.retiarii.experiment.RetiariiExeConfig
+..  autoclass:: nni.retiarii.experiment.pytorch.RetiariiExeConfig
    :members:
@@ -149,7 +149,7 @@ Create a Trainer and Exploration Strategy
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

**Classic search approach:**

-In this approach, trainer is for training each explored model, while strategy is for sampling the models. Both trainer and strategy are required to explore the model space.
+In this approach, the trainer is for training each explored model, while the strategy is for sampling the models. Both a trainer and a strategy are required to explore the model space. We recommend PyTorch Lightning to write the full training process.

**Oneshot (weight-sharing) search approach:**

In this approach, users only need a one-shot trainer, because this trainer takes charge of both search and training.

@@ -163,10 +163,10 @@ In the following table, we listed the available trainers and strategies.

   * - Trainer
     - Strategy
     - Oneshot Trainer
-  * - PyTorchImageClassificationTrainer
+  * - Classification
     - TPEStrategy
     - DartsTrainer
-  * - PyTorchMultiModelTrainer
+  * - Regression
     - RandomStrategy
     - EnasTrainer
   * -

@@ -182,15 +182,20 @@ Here is a simple example of using trainer and strategy.

.. code-block:: python

-    trainer = PyTorchImageClassificationTrainer(base_model,
-                                                dataset_cls="MNIST",
-                                                dataset_kwargs={"root": "data/mnist", "download": True},
-                                                dataloader_kwargs={"batch_size": 32},
-                                                optimizer_kwargs={"lr": 1e-3},
-                                                trainer_kwargs={"max_epochs": 1})
-    simple_startegy = RandomStrategy()
+    import nni.retiarii.trainer.pytorch.lightning as pl
+    from nni.retiarii import blackbox
+    from torchvision import transforms
+
+    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
+    train_dataset = blackbox(MNIST, root='data/mnist', train=True, download=True, transform=transform)
+    test_dataset = blackbox(MNIST, root='data/mnist', train=False, download=True, transform=transform)
+    lightning = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
+                                  val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
+                                  max_epochs=10)

+.. Note:: For NNI to capture the dataset and dataloader and distribute them across different runs, please wrap your dataset with ``blackbox`` and use ``pl.DataLoader`` instead of ``torch.utils.data.DataLoader``. See the ``blackbox_module`` section below for details.
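As a minimal sketch of the same requirement for a custom dataset class (``MyDataset`` below is hypothetical and used only for illustration, not part of NNI), decorating the class with ``blackbox_module`` records its constructor arguments so the dataset can be re-created inside each trial:

.. code-block:: python

    import torch
    from torch.utils.data import Dataset
    from nni.retiarii import blackbox_module
    import nni.retiarii.trainer.pytorch.lightning as pl

    @blackbox_module
    class MyDataset(Dataset):          # hypothetical dataset, for illustration only
        def __init__(self, size=100):
            self.size = size

        def __len__(self):
            return self.size

        def __getitem__(self, index):
            return torch.zeros(3, 32, 32), 0

    train_loader = pl.DataLoader(MyDataset(size=100), batch_size=32)  # not torch.utils.data.DataLoader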
-Users can refer to `this document <./WriteTrainer.rst>`__ for how to write a new trainer, and refer to `this document <./WriteStrategy.rst>`__ for how to write a new strategy.
+Users can refer to the `API reference <./ApiReference.rst>`__ for detailed usage of trainers, to "`write a trainer <./WriteTrainer.rst>`__" for how to write a new trainer, and to `this document <./WriteStrategy.rst>`__ for how to write a new strategy.

Set up an Experiment
^^^^^^^^^^^^^^^^^^^^

...
@@ -3,51 +3,106 @@ Customize A New Trainer

Trainers are necessary to evaluate the performance of newly explored models. In the NAS scenario, this further divides into two use cases:

-1. **Classic trainers**: trainers that are used to train and evaluate one single model.
+1. **Single-arch trainers**: trainers that are used to train and evaluate one single model.
2. **One-shot trainers**: trainers that handle training and searching simultaneously, from an end-to-end perspective.

-Classic trainers
-----------------
+Single-arch trainers
+--------------------

-All classic trainers need to inherit ``nni.retiarii.trainer.BaseTrainer``, implement the ``fit`` method and decorated with ``@register_trainer`` if it is intended to be used together with Retiarii. The decorator serialize the trainer that is used and its argument to fit for the requirements of NNI.
-
-The init function of trainer should take model as its first argument, and the rest of the arguments should be named (``*args`` and ``**kwargs`` may not work as expected) and JSON serializable. This means, currently, passing a complex object like ``torchvision.datasets.ImageNet()`` is not supported. Trainer should use NNI standard API to communicate with tuning algorithms. This includes ``nni.report_intermediate_result`` for periodical metrics and ``nni.report_final_result`` for final metrics.
+With PyTorch-Lightning
+^^^^^^^^^^^^^^^^^^^^^^
+
+It is recommended to write training code in PyTorch-Lightning style: write a LightningModule that defines all the elements needed for training (e.g., loss function, optimizer) and define a trainer that takes (optional) dataloaders to execute the training. Before that, please read the `PyTorch-Lightning documentation <https://pytorch-lightning.readthedocs.io/>`__ to learn the basic concepts and components it provides.
+
+In practice, a new training module written for NNI should inherit ``nni.retiarii.trainer.pytorch.lightning.LightningModule``, which has a ``set_model`` method that will be called after ``__init__`` to save the candidate model (generated by a strategy) as ``self.model``. The rest of the process (such as ``training_step``) is the same as writing any other Lightning module. Trainers should also communicate with strategies via two API calls (``nni.report_intermediate_result`` for periodical metrics and ``nni.report_final_result`` for final metrics), added in ``on_validation_epoch_end`` and ``teardown`` respectively.

An example is as follows:

.. code-block:: python

-    from nni.retiarii import register_trainer
-    from nni.retiarii.trainer import BaseTrainer
-
-    @register_trainer
-    class MnistTrainer(BaseTrainer):
-        def __init__(self, model, optimizer_class_name='SGD', learning_rate=0.1):
-            super().__init__()
-            self.model = model
-            self.criterion = nn.CrossEntropyLoss()
-            self.train_dataset = MNIST(train=True)
-            self.valid_dataset = MNIST(train=False)
-            self.optimizer = getattr(torch.optim, optimizer_class_name)(lr=learning_rate)
-
-        def validate():
-            pass
-
-        def fit(self) -> None:
-            for i in range(10):  # number of epochs:
-                for x, y in DataLoader(self.dataset):
-                    self.optimizer.zero_grad()
-                    pred = self.model(x)
-                    loss = self.criterion(pred, y)
-                    loss.backward()
-                    self.optimizer.step()
-            acc = self.validate()  # get validation accuracy
-            nni.report_final_result(acc)
+    from nni.retiarii.trainer.pytorch.lightning import LightningModule  # please import this one
+
+    @blackbox_module
+    class AutoEncoder(LightningModule):
+        def __init__(self):
+            super().__init__()
+            self.decoder = nn.Sequential(
+                nn.Linear(3, 64),
+                nn.ReLU(),
+                nn.Linear(64, 28*28)
+            )
+
+        def forward(self, x):
+            embedding = self.model(x)  # let's search for the encoder
+            return embedding
+
+        def training_step(self, batch, batch_idx):
+            # training_step defines the train loop.
+            # It is independent of forward.
+            x, y = batch
+            x = x.view(x.size(0), -1)
+            z = self.model(x)  # self.model is the architecture being searched for
+            x_hat = self.decoder(z)
+            loss = F.mse_loss(x_hat, x)
+            # Logging to TensorBoard by default
+            self.log('train_loss', loss)
+            return loss
+
+        def validation_step(self, batch, batch_idx):
+            x, y = batch
+            x = x.view(x.size(0), -1)
+            z = self.model(x)
+            x_hat = self.decoder(z)
+            loss = F.mse_loss(x_hat, x)
+            self.log('val_loss', loss)
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
+            return optimizer
+
+        def on_validation_epoch_end(self):
+            nni.report_intermediate_result(self.trainer.callback_metrics['val_loss'].item())
+
+        def teardown(self, stage):
+            if stage == 'fit':
+                nni.report_final_result(self.trainer.callback_metrics['val_loss'].item())
+
+Then, users need to wrap everything (including the LightningModule, trainer and dataloaders) into a ``Lightning`` object and pass this object into a Retiarii experiment.
+
+.. code-block:: python
+
+    import nni.retiarii.trainer.pytorch.lightning as pl
+    from nni.retiarii.experiment.pytorch import RetiariiExperiment
+
+    lightning = pl.Lightning(AutoEncoder(),
+                             pl.Trainer(max_epochs=10),
+                             train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
+                             val_dataloaders=pl.DataLoader(test_dataset, batch_size=100))
+    experiment = RetiariiExperiment(base_model, lightning, mutators, strategy)
+
+With FunctionalTrainer
+^^^^^^^^^^^^^^^^^^^^^^
+
+There is another way to customize a new trainer with functional APIs, which provides more flexibility. Users only need to write a fit function that wraps everything. This function takes one positional argument (the model) and possibly keyword arguments. In this way, users get everything under their control, but expose less information to the framework and thus fewer opportunities for possible optimization. An example is as below:
+
+.. code-block:: python
+
+    from nni.retiarii.trainer import FunctionalTrainer
+    from nni.retiarii.experiment.pytorch import RetiariiExperiment
+
+    def fit(model, dataloader):
+        train(model, dataloader)
+        acc = test(model, dataloader)
+        nni.report_final_result(acc)
+
+    trainer = FunctionalTrainer(fit, dataloader=DataLoader(foo, bar))
+    experiment = RetiariiExperiment(base_model, trainer, mutators, strategy)

One-shot trainers
-----------------

-One-shot trainers should inherit ``nni.retiarii.trainer.BaseOneShotTrainer``, which is basically the same as ``BaseTrainer``, but with one extra method ``export()``, which is expected to return the searched best architecture.
+One-shot trainers should inherit ``nni.retiarii.trainer.BaseOneShotTrainer``, and need to implement ``fit()`` (used to conduct the fitting and searching process) and the ``export()`` method (used to return the searched best architecture).

Writing a one-shot trainer is very different from writing classic trainers. First of all, there are no restrictions on init method arguments; any Python arguments are acceptable. Secondly, the model fed into one-shot trainers might be a model with Retiarii-specific modules, such as LayerChoice and InputChoice. Such a model cannot directly forward-propagate, and trainers need to decide how to handle those modules.
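To make that interface concrete, here is a minimal skeleton that satisfies it. This is only a sketch: the model, the dataloader and the training logic inside ``fit`` are placeholders to be filled in by the user, not an NNI-provided algorithm.

.. code-block:: python

    import torch.nn as nn
    from nni.retiarii.trainer.pytorch import BaseOneShotTrainer

    class MyOneShotTrainer(BaseOneShotTrainer):      # illustrative skeleton only
        def __init__(self, model: nn.Module, dataloader, num_epochs: int = 1):
            self.model = model                       # may contain LayerChoice / InputChoice modules
            self.dataloader = dataloader
            self.num_epochs = num_epochs
            self.best_architecture = None

        def fit(self):
            for _ in range(self.num_epochs):
                for batch in self.dataloader:
                    # decide how the choice modules are handled and train the supernet here
                    pass

        def export(self):
            # return an object describing the best architecture found during fit()
            return self.best_architecture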
@@ -55,7 +110,7 @@ A typical example is DartsTrainer, where learnable parameters are used to combin

.. code-block:: python

-    from nni.retiarii.trainer import BaseOneShotTrainer
+    from nni.retiarii.trainer.pytorch import BaseOneShotTrainer
     from nni.retiarii.trainer.pytorch.utils import replace_layer_choice, replace_input_choice

...
@@ -2,4 +2,4 @@ from .operation import Operation
 from .graph import *
 from .execution import *
 from .mutator import *
-from .utils import blackbox, blackbox_module, register_trainer
+from .utils import blackbox, blackbox_module, json_dump, json_dumps, json_load, json_loads, register_trainer
@@ -2,31 +2,29 @@ import logging
 import os
 import random
 import string
-from typing import Dict, Any, List
+from typing import Dict, List

 from .interface import AbstractExecutionEngine, AbstractGraphListener
 from .. import codegen, utils
-from ..graph import Model, ModelStatus, MetricData
+from ..graph import Model, ModelStatus, MetricData, TrainingConfig
 from ..integration_api import send_trial, receive_trial_parameters, get_advisor

 _logger = logging.getLogger(__name__)


 class BaseGraphData:
-    def __init__(self, model_script: str, training_module: str, training_kwargs: Dict[str, Any]) -> None:
+    def __init__(self, model_script: str, training_config: TrainingConfig) -> None:
         self.model_script = model_script
-        self.training_module = training_module
-        self.training_kwargs = training_kwargs
+        self.training_config = training_config

     def dump(self) -> dict:
         return {
             'model_script': self.model_script,
-            'training_module': self.training_module,
-            'training_kwargs': self.training_kwargs
+            'training_config': self.training_config
         }

     @staticmethod
-    def load(data):
-        return BaseGraphData(data['model_script'], data['training_module'], data['training_kwargs'])
+    def load(data) -> 'BaseGraphData':
+        return BaseGraphData(data['model_script'], data['training_config'])


 class BaseExecutionEngine(AbstractExecutionEngine):

@@ -57,8 +55,7 @@ class BaseExecutionEngine(AbstractExecutionEngine):
     def submit_models(self, *models: Model) -> None:
         for model in models:
-            data = BaseGraphData(codegen.model_to_pytorch_script(model),
-                                 model.training_config.module, model.training_config.kwargs)
+            data = BaseGraphData(codegen.model_to_pytorch_script(model), model.training_config)
             self._running_models[send_trial(data.dump())] = model

     def register_graph_listener(self, listener: AbstractGraphListener) -> None:

@@ -105,11 +102,10 @@ class BaseExecutionEngine(AbstractExecutionEngine):
         """
         graph_data = BaseGraphData.load(receive_trial_parameters())
         random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6))
-        file_name = f'_generated_model_{random_str}.py'
+        file_name = f'_generated_model/{random_str}.py'
+        os.makedirs(os.path.dirname(file_name), exist_ok=True)
         with open(file_name, 'w') as f:
             f.write(graph_data.model_script)
-        trainer_cls = utils.import_(graph_data.training_module)
-        model_cls = utils.import_(f'_generated_model_{random_str}._model')
-        trainer_instance = trainer_cls(model=model_cls(), **graph_data.training_kwargs)
-        trainer_instance.fit()
+        model_cls = utils.import_(f'_generated_model.{random_str}._model')
+        graph_data.training_config._execute(model_cls)
         os.remove(file_name)
\ No newline at end of file
@@ -44,7 +44,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
         phy_models_and_placements = self._assemble(logical)
         for model, placement, grouped_models in phy_models_and_placements:
             data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement),
-                                 model.training_config.module, model.training_config.kwargs)
+                                 model.training_config)
             for m in grouped_models:
                 self._original_models[m.model_id] = m
                 self._original_model_to_multi_model[m.model_id] = model

...
 import logging
 from dataclasses import dataclass
 from pathlib import Path
 from subprocess import Popen
 from threading import Thread
-from typing import Any, Optional
+from typing import Any, List, Optional, Union

-from ..experiment import Experiment, TrainingServiceConfig
-from ..experiment.config.base import ConfigBase, PathLike
-from ..experiment.config import util
-from ..experiment.pipe import Pipe
-
-from .graph import Model
-from .utils import get_records
-from .integration import RetiariiAdvisor
-from .converter import convert_to_graph
-from .mutator import Mutator
-from .trainer.interface import BaseTrainer, BaseOneShotTrainer
-from .strategies.strategy import BaseStrategy
-from .trainer import BaseOneShotTrainer
+import torch
+import torch.nn as nn
+
+from nni.experiment import Experiment, TrainingServiceConfig
+from nni.experiment.config import util
+from nni.experiment.config.base import ConfigBase, PathLike
+from nni.experiment.pipe import Pipe
+
+from ..converter import convert_to_graph
+from ..graph import Model, TrainingConfig
+from ..integration import RetiariiAdvisor
+from ..mutator import Mutator
+from ..nn.pytorch.mutator import process_inline_mutation
+from ..strategies.strategy import BaseStrategy
+from ..trainer.interface import BaseOneShotTrainer, BaseTrainer
+from ..utils import get_records

 _logger = logging.getLogger(__name__)

@@ -76,8 +77,9 @@ _validation_rules = {
 class RetiariiExperiment(Experiment):
-    def __init__(self, base_model: Model, trainer: BaseTrainer,
-                 applied_mutators: Mutator = None, strategy: BaseStrategy = None):
+    def __init__(self, base_model: nn.Module, trainer: Union[TrainingConfig, BaseOneShotTrainer],
+                 applied_mutators: List[Mutator] = None, strategy: BaseStrategy = None):
+        # TODO: The current design of init interface of Retiarii experiment needs to be reviewed.
         self.config: RetiariiExeConfig = None
         self.port: Optional[int] = None

@@ -93,28 +95,19 @@ class RetiariiExperiment(Experiment):
         self._pipe: Optional[Pipe] = None

     def _start_strategy(self):
-        import torch
-        from .nn.pytorch.mutator import process_inline_mutation
         try:
             script_module = torch.jit.script(self.base_model)
         except Exception as e:
             _logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
             raise e
         base_model_ir = convert_to_graph(script_module, self.base_model)
+        base_model_ir.training_config = self.trainer
-        recorded_module_args = get_records()
-        if id(self.trainer) not in recorded_module_args:
-            raise KeyError('Your trainer is not found in registered classes. You might have forgotten to \
-register your customized trainer with @register_trainer decorator.')
-        trainer_config = recorded_module_args[id(self.trainer)]
-        base_model_ir.apply_trainer(trainer_config['modulename'], trainer_config['args'])

         # handle inline mutations
         mutators = process_inline_mutation(base_model_ir)
         if mutators is not None and self.applied_mutators:
-            raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, \
-do not use mutators when you use LayerChoice/InputChoice')
+            raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, '
+                               'do not use mutators when you use LayerChoice/InputChoice')
         if mutators is not None:
             self.applied_mutators = mutators

...
@@ -2,13 +2,14 @@
 Model representation.
 """
+import abc
 import copy
-from enum import Enum
 import json
+from enum import Enum
 from typing import (Any, Dict, List, Optional, Tuple, Union, overload)

 from .operation import Cell, Operation, _IOPseudoOperation
-from .utils import uid
+from .utils import get_full_class_name, import_, uid

 __all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'IllegalGraphError', 'MetricData']

@@ -24,40 +25,43 @@ Type hint for edge's endpoint. The int indicates nodes' order.
 """


-class TrainingConfig:
+class TrainingConfig(abc.ABC):
     """
-    Training training_config of a model.
-
-    Module will be imported, initialized with generated model and arguments in ``kwargs``.
-
-    Attributes
-    ----------
-    module
-        Trainer module
-    kwargs
-        Trainer keyword arguments
+    Training config of a model. A training config should define where the training code is and the configuration of
+    the training code. The configuration includes basic runtime information the trainer needs to know (such as the
+    number of GPUs) or tunable parameters (such as learning rate), depending on the implementation of the training code.
+
+    Each config should define how it is interpreted in ``_execute()``, which takes only one argument: the mutated model class.
+    For example, a functional training config might directly import the function and call it.
     """

-    def __init__(self, module: str, kwargs: Dict[str, Any]):
-        self.module = module
-        self.kwargs = kwargs
-
     def __repr__(self):
-        return f'TrainingConfig(module={self.module}, kwargs={self.kwargs})'
+        items = ', '.join(['%s=%r' % (k, v) for k, v in self.__dict__.items()])
+        return f'{self.__class__.__name__}({items})'

-    @staticmethod
-    def _load(ir: Any) -> 'TrainingConfig':
-        return TrainingConfig(ir['module'], ir.get('kwargs', {}))
+    @abc.abstractstaticmethod
+    def _load(ir: Any) -> 'TrainingConfig':
+        pass
+
+    @staticmethod
+    def _load_with_type(type_name: str, ir: Any) -> 'Optional[TrainingConfig]':
+        if type_name == '_debug_no_trainer':
+            return DebugTraining()
+        config_cls = import_(type_name)
+        assert issubclass(config_cls, TrainingConfig)
+        return config_cls._load(ir)

+    @abc.abstractmethod
     def _dump(self) -> Any:
-        return {
-            'module': self.module,
-            'kwargs': self.kwargs
-        }
+        pass

-    def __eq__(self, other):
-        return self.module == other.module and \
-            self.kwargs == other.kwargs
+    @abc.abstractmethod
+    def _execute(self, model_cls: type) -> Any:
+        pass
+
+    @abc.abstractmethod
+    def __eq__(self, other) -> bool:
+        pass


 class Model:

@@ -100,7 +104,7 @@ class Model:
         self._root_graph_name: str = '_model'
         self.graphs: Dict[str, Graph] = {}
-        self.training_config: TrainingConfig = TrainingConfig('foo', {})
+        self.training_config: Optional[TrainingConfig] = None

         self.history: List[Model] = []

@@ -137,18 +141,17 @@ class Model:
         for graph_name, graph_data in ir.items():
             if graph_name != '_training_config':
                 Graph._load(model, graph_name, graph_data)._register()
-        model.training_config = TrainingConfig._load(ir['_training_config'])
+        model.training_config = TrainingConfig._load_with_type(ir['_training_config']['__type__'], ir['_training_config'])
         return model

     def _dump(self) -> Any:
         ret = {name: graph._dump() for name, graph in self.graphs.items()}
-        ret['_training_config'] = self.training_config._dump()
+        ret['_training_config'] = {
+            '__type__': get_full_class_name(self.training_config.__class__),
+            **self.training_config._dump()
+        }
         return ret

-    def apply_trainer(self, module, args) -> None:
-        # TODO: rethink the way of specifying a trainer
-        self.training_config = TrainingConfig(module, args)
-
     def get_nodes_by_label(self, label: str) -> List['Node']:
         """
         Traverse all the nodes to find the matched node(s) with the given name.

@@ -668,3 +671,18 @@ class IllegalGraphError(ValueError):
             graph = graph._dump()
         with open('generated/debug.json', 'w') as dump_file:
             json.dump(graph, dump_file, indent=4)
+
+
+class DebugTraining(TrainingConfig):
+    @staticmethod
+    def _load(ir: Any) -> 'DebugTraining':
+        return DebugTraining()
+
+    def _dump(self) -> Any:
+        return {'__type__': '_debug_no_trainer'}
+
+    def _execute(self, model_cls: type) -> Any:
+        pass
+
+    def __eq__(self, other) -> bool:
+        return True
@@ -2,7 +2,6 @@ import logging
 import os
 from typing import Any, Callable

-import json_tricks
 from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
 from nni.runtime.protocol import CommandType, send
 from nni.utils import MetricType

@@ -12,6 +11,7 @@ from .execution.base import BaseExecutionEngine
 from .execution.cgo_engine import CGOExecutionEngine
 from .execution.api import set_execution_engine
 from .integration_api import register_advisor
+from .utils import json_dumps, json_loads

 _logger = logging.getLogger(__name__)

@@ -100,7 +100,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
             'parameter_source': 'algorithm'
         }
         _logger.info('New trial sent: %s', new_trial)
-        send(CommandType.NewTrialJob, json_tricks.dumps(new_trial))
+        send(CommandType.NewTrialJob, json_dumps(new_trial))
         if self.send_trial_callback is not None:
             self.send_trial_callback(parameters)  # pylint: disable=not-callable
         return self.parameters_count

@@ -116,7 +116,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
     def handle_trial_end(self, data):
         _logger.info('Trial end: %s', data)
-        self.trial_end_callback(json_tricks.loads(data['hyper_params'])['parameter_id'],  # pylint: disable=not-callable
+        self.trial_end_callback(json_loads(data['hyper_params'])['parameter_id'],  # pylint: disable=not-callable
                                 data['event'] == 'SUCCEEDED')

     def handle_report_metric_data(self, data):

@@ -132,7 +132,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
     @staticmethod
     def _process_value(value) -> Any:  # hopefully a float
-        value = json_tricks.loads(value)
+        value = json_loads(value)
         if isinstance(value, dict):
             if 'default' in value:
                 return value['default']

...
+import json
 from typing import NewType, Any

 import nni

+from .utils import json_loads
+
 # NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor
 # because it would induce cycled import
 RetiariiAdvisor = NewType('RetiariiAdvisor', Any)

@@ -31,6 +35,12 @@ def send_trial(parameters: dict) -> int:
 def receive_trial_parameters() -> dict:
     """
     Received a new trial. Executed on trial end.
+    Reload with our json loads because NNI didn't use the Retiarii serializer to load the data.
     """
     params = nni.get_next_parameter()
+    params = json_loads(json.dumps(params))
     return params
+
+
+def get_experiment_id() -> str:
+    return nni.get_experiment_id()
-from .interface import BaseTrainer, BaseOneShotTrainer
+from .functional import FunctionalTrainer
+from .interface import BaseOneShotTrainer
from ..graph import TrainingConfig


class FunctionalTrainer(TrainingConfig):
    """
    Functional training config that directly takes a function and thus should be general.

    Attributes
    ----------
    function
        The full name of the function.
    arguments
        Keyword arguments for the function other than model.
    """

    def __init__(self, function, **kwargs):
        self.function = function
        self.arguments = kwargs

    @staticmethod
    def _load(ir):
        return FunctionalTrainer(ir['function'], **ir['arguments'])

    def _dump(self):
        return {
            'function': self.function,
            'arguments': self.arguments
        }

    def _execute(self, model_cls):
        return self.function(model_cls, **self.arguments)

    def __eq__(self, other):
        return self.function == other.function and self.arguments == other.arguments
@@ -3,33 +3,25 @@ from typing import Any

 class BaseTrainer(abc.ABC):
-    """
-    In this version, we plan to write our own trainers instead of using PyTorch-lightning, to
-    ease the burden to integrate our optmization with PyTorch-lightning, a large part of which is
-    opaque to us.
-
-    We will try to align with PyTorch-lightning name conversions so that we can easily migrate to
-    PyTorch-lightning in the future.
-
-    Currently, our trainer = LightningModule + LightningTrainer. We might want to separate these two things
-    in future.
-
-    Trainer has a ``fit`` function with no return value. Intermediate results and final results should be
-    directly sent via ``nni.report_intermediate_result()`` and ``nni.report_final_result()`` functions.
-    """
+    # Deprecated class

     @abc.abstractmethod
     def fit(self) -> None:
         pass


-class BaseOneShotTrainer(BaseTrainer):
-    """
-    Build many (possibly all) architectures into a full graph, search (with train) and export the best.
-
-    It has an extra ``export`` function that exports an object representing the final searched architecture.
-    """
+class BaseOneShotTrainer(abc.ABC):
+    """
+    Build many (possibly all) architectures into a full graph, search (with train) and export the best.
+
+    One-shot trainer has a ``fit`` function with no return value. Trainers should fit and search for the best architecture.
+    Currently, all the inputs of the trainer need to be manually set before fit (including the search space, the data loader
+    to use, training epochs, etc.).
+
+    It has an extra ``export`` function that exports an object representing the final searched architecture.
+    """

     @abc.abstractmethod
     def export(self) -> Any:
         pass
+# This file is deprecated.
 from typing import Any, List, Dict, Tuple

 import numpy as np

...
import warnings
from typing import Dict, Union, Optional, List

import pytorch_lightning as pl
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import nni

from ...graph import TrainingConfig
from ...utils import blackbox_module

__all__ = ['LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classification', 'Regression']


class LightningModule(pl.LightningModule):
    def set_model(self, model):
        if isinstance(model, type):
            self.model = model()
        else:
            self.model = model


Trainer = blackbox_module(pl.Trainer)
DataLoader = blackbox_module(DataLoader)
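# ``Trainer`` and ``DataLoader`` are wrapped with ``blackbox_module`` so that the arguments passed to their
# constructors are recorded (as ``__init_parameters__``) and can be serialized and re-created when the
# training config is dispatched to a trial.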

class Lightning(TrainingConfig):
    """
    Delegate the whole training to PyTorch Lightning.

    Since the arguments passed to the initialization need to be serialized, ``LightningModule``, ``Trainer`` and
    ``DataLoader`` from this file should be used. Another option is to hide the dataloaders in the Lightning module,
    in which case dataloaders are not required for this class to work.

    Following the programming style of Lightning, metrics sent to NNI should be obtained from ``callback_metrics``
    in the trainer. Two hooks are added at the end of the validation epoch and at the end of ``fit``, respectively.
    The metric name and type depend on the specific task.

    Parameters
    ----------
    lightning_module : LightningModule
        Lightning module that defines the training logic.
    trainer : Trainer
        Lightning trainer that handles the training.
    train_dataloader : DataLoader
        Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
        If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
    val_dataloaders : DataLoader or List of DataLoader
        Used in ``trainer.fit()``. Either a single PyTorch DataLoader or a list of them, specifying validation samples.
        If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
    """

    def __init__(self, lightning_module: LightningModule, trainer: Trainer,
                 train_dataloader: Optional[DataLoader] = None,
                 val_dataloaders: Union[DataLoader, List[DataLoader], None] = None):
        assert isinstance(lightning_module, LightningModule), f'Lightning module must be an instance of {__name__}.LightningModule.'
        assert isinstance(trainer, Trainer), f'Trainer must be imported from {__name__}.'
        assert _check_dataloader(train_dataloader), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
        assert _check_dataloader(val_dataloaders), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
        self.module = lightning_module
        self.trainer = trainer
        self.train_dataloader = train_dataloader
        self.val_dataloaders = val_dataloaders

    @staticmethod
    def _load(ir):
        return Lightning(ir['module'], ir['trainer'], ir['train_dataloader'], ir['val_dataloaders'])

    def _dump(self):
        return {
            'module': self.module,
            'trainer': self.trainer,
            'train_dataloader': self.train_dataloader,
            'val_dataloaders': self.val_dataloaders
        }

    def _execute(self, model_cls):
        return self.fit(model_cls)

    def __eq__(self, other):
        return self.module == other.module and self.trainer == other.trainer and \
            self.train_dataloader == other.train_dataloader and self.val_dataloaders == other.val_dataloaders

    def fit(self, model):
        """
        Fit the model with the provided dataloaders, with the Lightning trainer.

        Parameters
        ----------
        model : nn.Module
            The model to fit.
        """
        self.module.set_model(model)
        return self.trainer.fit(self.module, self.train_dataloader, self.val_dataloaders)


def _check_dataloader(dataloader):
    if dataloader is None:
        return True
    if isinstance(dataloader, list):
        return all([_check_dataloader(d) for d in dataloader])
    return isinstance(dataloader, DataLoader)


### The following are some commonly used Lightning modules ###

class _SupervisedLearningModule(LightningModule):
    def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: optim.Optimizer = optim.Adam):
        super().__init__()
        self.save_hyperparameters('criterion', 'optimizer', 'learning_rate', 'weight_decay')
        self.criterion = criterion()
        self.optimizer = optimizer
        self.metrics = nn.ModuleDict({name: cls() for name, cls in metrics.items()})

    def forward(self, x):
        y_hat = self.model(x)
        return y_hat

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log('train_loss', loss, prog_bar=True)
        for name, metric in self.metrics.items():
            self.log('train_' + name, metric(y_hat, y), prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        self.log('val_loss', self.criterion(y_hat, y), prog_bar=True)
        for name, metric in self.metrics.items():
            self.log('val_' + name, metric(y_hat, y), prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        self.log('test_loss', self.criterion(y_hat, y), prog_bar=True)
        for name, metric in self.metrics.items():
            self.log('test_' + name, metric(y_hat, y), prog_bar=True)

    def configure_optimizers(self):
        return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)

    def on_validation_epoch_end(self):
        nni.report_intermediate_result(self._get_validation_metrics())

    def teardown(self, stage):
        if stage == 'fit':
            nni.report_final_result(self._get_validation_metrics())

    def _get_validation_metrics(self):
        if len(self.metrics) == 1:
            metric_name = next(iter(self.metrics))
            return self.trainer.callback_metrics['val_' + metric_name].item()
        else:
            warnings.warn('Multiple metrics without "default" is not supported by current framework.')
            return {name: self.trainer.callback_metrics['val_' + name].item() for name in self.metrics}


@blackbox_module
class _ClassificationModule(_SupervisedLearningModule):
    def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: optim.Optimizer = optim.Adam):
        super().__init__(criterion, {'acc': pl.metrics.Accuracy},
                         learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)


class Classification(Lightning):
    """
    Trainer that is used for classification.

    Parameters
    ----------
    criterion : nn.Module
        Class for criterion module (not an instance). default: ``nn.CrossEntropyLoss``
    learning_rate : float
        Learning rate. default: 0.001
    weight_decay : float
        L2 weight decay. default: 0
    optimizer : Optimizer
        Class for optimizer (not an instance). default: ``Adam``
    train_dataloader : DataLoader
        Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
        If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
    val_dataloaders : DataLoader or List of DataLoader
        Used in ``trainer.fit()``. Either a single PyTorch DataLoader or a list of them, specifying validation samples.
        If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
    trainer_kwargs : dict
        Optional keyword arguments passed to trainer. See
        `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/trainer.html>`__ for details.
    """

    def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: optim.Optimizer = optim.Adam,
                 train_dataloader: Optional[DataLoader] = None,
                 val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                 **trainer_kwargs):
        module = _ClassificationModule(criterion=criterion, learning_rate=learning_rate,
                                       weight_decay=weight_decay, optimizer=optimizer)
        super().__init__(module, Trainer(**trainer_kwargs),
                         train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
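# Example usage (mirrors the tutorial): only dataloaders and trainer keyword arguments need to be supplied,
# e.g. Classification(train_dataloader=DataLoader(train_set, batch_size=100),
#                     val_dataloaders=DataLoader(test_set, batch_size=100), max_epochs=10)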

@blackbox_module
class _RegressionModule(_SupervisedLearningModule):
    def __init__(self, criterion: nn.Module = nn.MSELoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: optim.Optimizer = optim.Adam):
        super().__init__(criterion, {'mse': pl.metrics.MeanSquaredError},
                         learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)


class Regression(Lightning):
    """
    Trainer that is used for regression.

    Parameters
    ----------
    criterion : nn.Module
        Class for criterion module (not an instance). default: ``nn.MSELoss``
    learning_rate : float
        Learning rate. default: 0.001
    weight_decay : float
        L2 weight decay. default: 0
    optimizer : Optimizer
        Class for optimizer (not an instance). default: ``Adam``
    train_dataloader : DataLoader
        Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
        If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
    val_dataloaders : DataLoader or List of DataLoader
        Used in ``trainer.fit()``. Either a single PyTorch DataLoader or a list of them, specifying validation samples.
        If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
    trainer_kwargs : dict
        Optional keyword arguments passed to trainer. See
        `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/trainer.html>`__ for details.
    """

    def __init__(self, criterion: nn.Module = nn.MSELoss,
                 learning_rate: float = 0.001,
                 weight_decay: float = 0.,
                 optimizer: optim.Optimizer = optim.Adam,
                 train_dataloader: Optional[DataLoader] = None,
                 val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                 **trainer_kwargs):
        module = _RegressionModule(criterion=criterion, learning_rate=learning_rate,
                                   weight_decay=weight_decay, optimizer=optimizer)
        super().__init__(module, Trainer(**trainer_kwargs),
                         train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
+import functools
 import inspect
-import warnings
 from collections import defaultdict
 from typing import Any
 from pathlib import Path

+import json_tricks
+

 def import_(target: str, allow_none: bool = False) -> Any:
     if target is None:

@@ -20,6 +22,45 @@ def version_larger_equal(a: str, b: str) -> bool:
     return tuple(map(int, a.split('.'))) >= tuple(map(int, b.split('.')))


+### This is a patch of json-tricks to make it more useful to us ###
+
+def _blackbox_class_instance_encode(obj, primitives=False):
+    assert not primitives, 'Encoding with primitives is not supported.'
+    if hasattr(obj, '__class__') and hasattr(obj, '__init_parameters__'):
+        return {
+            '__type__': get_full_class_name(obj.__class__),
+            'arguments': obj.__init_parameters__
+        }
+    return obj
+
+
+def _blackbox_class_instance_decode(obj):
+    if isinstance(obj, dict) and '__type__' in obj and 'arguments' in obj:
+        return import_(obj['__type__'])(**obj['arguments'])
+    return obj
+
+
+def _type_encode(obj, primitives=False):
+    assert not primitives, 'Encoding with primitives is not supported.'
+    if isinstance(obj, type):
+        return {'__typename__': get_full_class_name(obj, relocate_module=True)}
+    return obj
+
+
+def _type_decode(obj):
+    if isinstance(obj, dict) and '__typename__' in obj:
+        return import_(obj['__typename__'])
+    return obj
+
+
+json_loads = functools.partial(json_tricks.loads, extra_obj_pairs_hooks=[_blackbox_class_instance_decode, _type_decode])
+json_dumps = functools.partial(json_tricks.dumps, extra_obj_encoders=[_blackbox_class_instance_encode, _type_encode])
+json_load = functools.partial(json_tricks.load, extra_obj_pairs_hooks=[_blackbox_class_instance_decode, _type_decode])
+json_dump = functools.partial(json_tricks.dump, extra_obj_encoders=[_blackbox_class_instance_encode, _type_encode])
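# Typical round trip: an object created through ``blackbox``/``blackbox_module`` carries ``__init_parameters__``,
# so ``json_dumps`` records its class name and constructor arguments, and ``json_loads`` re-imports the class
# and re-instantiates it with the same arguments. Plain classes (types) are handled by the ``__typename__`` hooks.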
+
+### End of json-tricks patch ###

 _records = {}

@@ -48,10 +89,10 @@ def del_record(key):
     _records.pop(key, None)


-def _blackbox_cls(cls, module_name, register_format=None):
+def _blackbox_cls(cls):
     class wrapper(cls):
         def __init__(self, *args, **kwargs):
-            argname_list = list(inspect.signature(cls).parameters.keys())
+            argname_list = list(inspect.signature(cls.__init__).parameters.keys())[1:]
             full_args = {}
             full_args.update(kwargs)

@@ -59,30 +100,16 @@ def _blackbox_cls(cls, module_name, register_format=None):
             for argname, value in zip(argname_list, args):
                 full_args[argname] = value

-            # eject un-serializable arguments
-            for k in list(full_args.keys()):
-                # The list is not complete and does not support nested cases.
-                if not isinstance(full_args[k], (int, float, str, dict, list, tuple)):
-                    if not (register_format == 'full' and k == 'model'):
-                        # no warning if it is base model in trainer
-                        warnings.warn(f'{cls} has un-serializable arguments {k} whose value is {full_args[k]}. \
-This is not supported. You can ignore this warning if you are passing the model to trainer.')
-                    full_args.pop(k)
-
-            if register_format == 'args':
-                add_record(id(self), full_args)
-            elif register_format == 'full':
-                full_class_name = cls.__module__ + '.' + cls.__name__
-                add_record(id(self), {'modulename': full_class_name, 'args': full_args})
+            add_record(id(self), full_args)  # for compatibility. Will remove soon.
+
+            self.__init_parameters__ = full_args

             super().__init__(*args, **kwargs)

         def __del__(self):
             del_record(id(self))

-    # using module_name instead of cls.__module__ because it's more natural to see where the module gets wrapped
-    # instead of simply putting torch.nn or etc.
-    wrapper.__module__ = module_name
+    wrapper.__module__ = _get_module_name(cls)
     wrapper.__name__ = cls.__name__
     wrapper.__qualname__ = cls.__qualname__
     wrapper.__init__.__doc__ = cls.__init__.__doc__

@@ -97,41 +124,21 @@ def blackbox(cls, *args, **kwargs):

     .. code-block:: python

         self.op = blackbox(MyCustomOp, hidden_units=128)
     """
-    # get caller module name
-    frm = inspect.stack()[1]
-    module_name = inspect.getmodule(frm[0]).__name__
-    return _blackbox_cls(cls, module_name, 'args')(*args, **kwargs)
+    return _blackbox_cls(cls)(*args, **kwargs)


 def blackbox_module(cls):
     """
     Register a module. Use it as a decorator.
     """
-    frm = inspect.stack()[1]
-    assert (inspect.getmodule(frm[0]) is not None), ('unable to locate the definition of the given black box module, '
-                                                     'please define it explicitly in a .py file.')
-    module_name = inspect.getmodule(frm[0]).__name__
-    if module_name == '__main__':
-        main_file_path = Path(inspect.getsourcefile(frm[0]))
-        if main_file_path.parents[0] != Path('.'):
-            raise RuntimeError(f'you are using "{main_file_path}" to launch your experiment, '
-                               f'please launch the experiment under the directory where "{main_file_path.name}" is located.')
-        module_name = main_file_path.stem
-    return _blackbox_cls(cls, module_name, 'args')
+    return _blackbox_cls(cls)


 def register_trainer(cls):
     """
     Register a trainer. Use it as a decorator.
     """
-    frm = inspect.stack()[1]
-    assert (inspect.getmodule(frm[0]) is not None), ('unable to locate the definition of the given trainer, '
-                                                     'please define it explicitly in a .py file.')
-    module_name = inspect.getmodule(frm[0]).__name__
-    return _blackbox_cls(cls, module_name, 'full')
+    return _blackbox_cls(cls)


 _last_uid = defaultdict(int)

@@ -140,3 +147,24 @@ _last_uid = defaultdict(int)
 def uid(namespace: str = 'default') -> int:
     _last_uid[namespace] += 1
     return _last_uid[namespace]
+
+
+def _get_module_name(cls):
+    module_name = cls.__module__
+    if module_name == '__main__':
+        # infer the module name with inspect
+        for frm in inspect.stack():
+            if inspect.getmodule(frm[0]).__name__ == '__main__':
+                # main module found
+                main_file_path = Path(inspect.getsourcefile(frm[0]))
+                if main_file_path.parents[0] != Path('.'):
+                    raise RuntimeError(f'You are using "{main_file_path}" to launch your experiment, '
+                                       f'please launch the experiment under the directory where "{main_file_path.name}" is located.')
+                module_name = main_file_path.stem
+                break
+    return module_name
+
+
+def get_full_class_name(cls, relocate_module=False):
+    module_name = _get_module_name(cls) if relocate_module else cls.__module__
+    return module_name + '.' + cls.__name__
@@ -6,6 +6,7 @@ assessor_result.txt
 _generated_model.py
 _generated_model_*.py
+_generated_model
 data
 generated
+lightning_logs
@@ -4,21 +4,35 @@ import sys
 import torch
 from pathlib import Path

-from nni.retiarii.experiment import RetiariiExperiment, RetiariiExeConfig
+import nni.retiarii.trainer.pytorch.lightning as pl
+from nni.retiarii import blackbox_module as bm
+from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
 from nni.retiarii.strategies import TPEStrategy, RandomStrategy
-from nni.retiarii.trainer.pytorch import PyTorchImageClassificationTrainer
+from torchvision import transforms
+from torchvision.datasets import CIFAR10

 from darts_model import CNN

 if __name__ == '__main__':
     base_model = CNN(32, 3, 16, 10, 8)
-    trainer = PyTorchImageClassificationTrainer(base_model, dataset_cls="CIFAR10",
-                                                dataset_kwargs={"root": "data/cifar10", "download": True},
-                                                dataloader_kwargs={"batch_size": 32},
-                                                optimizer_kwargs={"lr": 1e-3},
-                                                trainer_kwargs={"max_epochs": 1})
-    #simple_startegy = TPEStrategy()
+
+    train_transform = transforms.Compose([
+        transforms.RandomCrop(32, padding=4),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+    ])
+    valid_transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+    ])
+    train_dataset = bm(CIFAR10)(root='data/cifar10', train=True, download=True, transform=train_transform)
+    test_dataset = bm(CIFAR10)(root='data/cifar10', train=False, download=True, transform=valid_transform)
+    trainer = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
+                                val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
+                                max_epochs=1, limit_train_batches=0.2)
     simple_strategy = RandomStrategy()

     exp = RetiariiExperiment(base_model, trainer, [], simple_strategy)

...