Unverified Commit 445e7e0b authored by Yuge Zhang, committed by GitHub

[Retiarii] Rewrite trainer with PyTorch Lightning (#3359)

parent 137830df
......@@ -2,13 +2,11 @@
-f https://download.pytorch.org/whl/torch_stable.html
tensorflow
# PyTorch 1.7 has a compatibility issue with model compression.
# Check for MacOS because this file is used on all platforms.
torch == 1.6.0+cpu ; sys_platform != "darwin"
torch == 1.6.0 ; sys_platform == "darwin"
torchvision == 0.7.0+cpu ; sys_platform != "darwin"
torchvision == 0.7.0 ; sys_platform == "darwin"
pytorch-lightning
onnx
peewee
graphviz
......@@ -2,6 +2,11 @@
tensorflow == 1.15.4
torch == 1.5.1+cpu
torchvision == 0.6.1+cpu
# It will install pytorch-lightning 0.8.x and unit tests won't work.
# Latest version has conflict with tensorboard and tensorflow 1.x.
pytorch-lightning
keras == 2.1.6
onnx
peewee
......
......@@ -42,10 +42,16 @@ Graph Mutation APIs
Trainers
--------
.. autoclass:: nni.retiarii.trainer.pytorch.PyTorchImageClassificationTrainer
.. autoclass:: nni.retiarii.trainer.FunctionalTrainer
:members:
.. autoclass:: nni.retiarii.trainer.pytorch.PyTorchMultiModelTrainer
.. autoclass:: nni.retiarii.trainer.pytorch.lightning.LightningModule
:members:
.. autoclass:: nni.retiarii.trainer.pytorch.lightning.Classification
:members:
.. autoclass:: nni.retiarii.trainer.pytorch.lightning.Regression
:members:
Oneshot Trainers
......@@ -75,8 +81,8 @@ Strategies
Retiarii Experiments
--------------------
.. autoclass:: nni.retiarii.experiment.RetiariiExperiment
.. autoclass:: nni.retiarii.experiment.pytorch.RetiariiExperiment
:members:
.. autoclass:: nni.retiarii.experiment.RetiariiExeConfig
.. autoclass:: nni.retiarii.experiment.pytorch.RetiariiExeConfig
:members:
......@@ -149,7 +149,7 @@ Create a Trainer and Exploration Strategy
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
**Classic search approach:**
In this approach, the trainer is for training each explored model, while the strategy is for sampling the models. Both a trainer and a strategy are required to explore the model space.
In this approach, the trainer is for training each explored model, while the strategy is for sampling the models. Both a trainer and a strategy are required to explore the model space. We recommend using PyTorch-Lightning to write the full training process.
**Oneshot (weight-sharing) search approach:**
In this approach, users only need a one-shot trainer, because this trainer takes charge of both search and training.
......@@ -163,10 +163,10 @@ In the following table, we listed the available trainers and strategies.
* - Trainer
- Strategy
- Oneshot Trainer
* - PyTorchImageClassificationTrainer
* - Classification
- TPEStrategy
- DartsTrainer
* - PyTorchMultiModelTrainer
* - Regression
- RandomStrategy
- EnasTrainer
* -
......@@ -182,15 +182,20 @@ Here is a simple example of using trainer and strategy.
.. code-block:: python
trainer = PyTorchImageClassificationTrainer(base_model,
dataset_cls="MNIST",
dataset_kwargs={"root": "data/mnist", "download": True},
dataloader_kwargs={"batch_size": 32},
optimizer_kwargs={"lr": 1e-3},
trainer_kwargs={"max_epochs": 1})
simple_startegy = RandomStrategy()
import nni.retiarii.trainer.pytorch.lightning as pl
from nni.retiarii import blackbox
from torchvision import transforms
from torchvision.datasets import MNIST
Users can refer to `this document <./WriteTrainer.rst>`__ for how to write a new trainer, and refer to `this document <./WriteStrategy.rst>`__ for how to write a new strategy.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = blackbox(MNIST, root='data/mnist', train=True, download=True, transform=transform)
test_dataset = blackbox(MNIST, root='data/mnist', train=False, download=True, transform=transform)
lightning = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
max_epochs=10)
.. Note:: For NNI to capture the dataset and dataloader and distribute them across different runs, please wrap your dataset with ``blackbox`` and use ``pl.DataLoader`` instead of ``torch.utils.data.DataLoader``. See the ``blackbox_module`` section below for details.
Users can refer to the `API reference <./ApiReference.rst>`__ for detailed usage of trainers, to "`write a trainer <./WriteTrainer.rst>`__" for how to write a new trainer, and to `this document <./WriteStrategy.rst>`__ for how to write a new strategy.
Set up an Experiment
^^^^^^^^^^^^^^^^^^^^
......
......@@ -3,51 +3,106 @@ Customize A New Trainer
Trainers are necessary to evaluate the performance of newly explored models. In the NAS scenario, this further divides into two use cases:
1. **Classic trainers**: trainers that are used to train and evaluate one single model.
1. **Single-arch trainers**: trainers that are used to train and evaluate one single model.
2. **One-shot trainers**: trainers that handle training and searching simultaneously, from an end-to-end perspective.
Classic trainers
----------------
Single-arch trainers
--------------------
All classic trainers need to inherit ``nni.retiarii.trainer.BaseTrainer``, implement the ``fit`` method, and be decorated with ``@register_trainer`` if they are intended to be used together with Retiarii. The decorator serializes the trainer and its arguments to meet the requirements of NNI.
With PyTorch-Lightning
^^^^^^^^^^^^^^^^^^^^^^
The init function of the trainer should take the model as its first argument, and the rest of the arguments should be named (``*args`` and ``**kwargs`` may not work as expected) and JSON-serializable. This means that, currently, passing a complex object like ``torchvision.datasets.ImageNet()`` is not supported. The trainer should use the standard NNI API to communicate with tuning algorithms: ``nni.report_intermediate_result`` for periodical metrics and ``nni.report_final_result`` for final metrics.
It's recommended to write training code in PyTorch-Lightning style, that is, to write a LightningModule that defines all elements needed for training (e.g., loss function, optimizer) and to define a trainer that takes (optional) dataloaders to execute the training. Before that, please read the `PyTorch Lightning documentation <https://pytorch-lightning.readthedocs.io/>`__ to learn the basic concepts and components it provides.
In practice, a new training module in NNI should inherit ``nni.retiarii.trainer.pytorch.lightning.LightningModule``, which has a ``set_model`` method that is called after ``__init__`` to save the candidate model (generated by a strategy) as ``self.model``. The rest of the process (such as ``training_step``) is the same as writing any other Lightning module. Trainers should also communicate with strategies via two API calls (``nni.report_intermediate_result`` for periodical metrics and ``nni.report_final_result`` for final metrics), added in ``on_validation_epoch_end`` and ``teardown`` respectively.
An example is as follows:
.. code-block:: python
from nni.retiarii import register_trainer
from nni.retiarii.trainer import BaseTrainer
import nni
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.retiarii import blackbox_module
from nni.retiarii.trainer.pytorch.lightning import LightningModule  # please import this one
@register_trainer
class MnistTrainer(BaseTrainer):
def __init__(self, model, optimizer_class_name='SGD', learning_rate=0.1):
@blackbox_module
class AutoEncoder(LightningModule):
def __init__(self):
super().__init__()
self.model = model
self.criterion = nn.CrossEntropyLoss()
self.train_dataset = MNIST(train=True)
self.valid_dataset = MNIST(train=False)
self.optimizer = getattr(torch.optim, optimizer_class_name)(lr=learning_rate)
def validate():
pass
def fit(self) -> None:
for i in range(10): # number of epochs:
for x, y in DataLoader(self.dataset):
self.optimizer.zero_grad()
pred = self.model(x)
loss = self.criterion(pred, y)
loss.backward()
self.optimizer.step()
acc = self.validate() # get validation accuracy
self.decoder = nn.Sequential(
nn.Linear(3, 64),
nn.ReLU(),
nn.Linear(64, 28*28)
)
def forward(self, x):
embedding = self.model(x) # let's search for encoder
return embedding
def training_step(self, batch, batch_idx):
# training_step defined the train loop.
# It is independent of forward
x, y = batch
x = x.view(x.size(0), -1)
z = self.model(x) # model is the one that is searched for
x_hat = self.decoder(z)
loss = F.mse_loss(x_hat, x)
# Logging to TensorBoard by default
self.log('train_loss', loss)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
x = x.view(x.size(0), -1)
z = self.model(x)
x_hat = self.decoder(z)
loss = F.mse_loss(x_hat, x)
self.log('val_loss', loss)
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer
def on_validation_epoch_end(self):
nni.report_intermediate_result(self.trainer.callback_metrics['val_loss'].item())
def teardown(self, stage):
if stage == 'fit':
nni.report_final_result(self.trainer.callback_metrics['val_loss'].item())
Then, users need to wrap everything (including LightningModule, trainer and dataloaders) into a ``Lightning`` object, and pass this object into a Retiarii experiment.
.. code-block:: python
import nni.retiarii.trainer.pytorch.lightning as pl
from nni.retiarii.experiment.pytorch import RetiariiExperiment
lightning = pl.Lightning(AutoEncoder(),
pl.Trainer(max_epochs=10),
train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100))
experiment = RetiariiExperiment(base_model, lightning, mutators, strategy)
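After that, the experiment can be configured and launched. The snippet below is a minimal, hedged sketch: the fields follow the standard NNI experiment configuration, and the concrete values (experiment name, concurrency, trial number, port) are illustrative assumptions rather than part of this commit.

.. code-block:: python

    from nni.retiarii.experiment.pytorch import RetiariiExeConfig

    exp_config = RetiariiExeConfig('local')
    exp_config.experiment_name = 'autoencoder search'  # illustrative values
    exp_config.trial_concurrency = 2
    exp_config.max_trial_number = 10
    exp_config.training_service.use_active_gpu = False
    experiment.run(exp_config, 8081)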
With FunctionalTrainer
^^^^^^^^^^^^^^^^^^^^^^
There is another way to customize a new trainer with functional APIs, which provides more flexibility. Users only need to write a fit function that wraps everything. This function takes one positional argument (the model) and possibly keyword arguments. In this way, users get everything under their control, but expose less information to the framework and thus lose some opportunities for possible optimization. An example is as follows:
.. code-block:: python
import nni
from nni.retiarii.trainer import FunctionalTrainer
from nni.retiarii.experiment.pytorch import RetiariiExperiment
def fit(model, dataloader):
train(model, dataloader)
acc = test(model, dataloader)
nni.report_final_result(acc)
trainer = FunctionalTrainer(fit, dataloader=DataLoader(foo, bar))
experiment = RetiariiExperiment(base_model, trainer, mutators, strategy)
One-shot trainers
-----------------
One-shot trainers should inherit ``nni.retiarii.trainer.BaseOneShotTrainer``, which is basically the same as ``BaseTrainer``, but with one extra method ``export()``, which is expected to return the searched best architecture.
One-shot trainers should inherit ``nni.retiarii.trainer.BaseOneShotTrainer``, and need to implement the ``fit()`` method (used to conduct the fitting and searching process) and the ``export()`` method (used to return the best searched architecture).
Writing a one-shot trainer is very different from writing classic trainers. First of all, there are no more restrictions on init method arguments; any Python arguments are acceptable. Secondly, the model fed into one-shot trainers might be a model with Retiarii-specific modules, such as LayerChoice and InputChoice. Such a model cannot directly forward-propagate, and trainers need to decide how to handle those modules.
......@@ -55,7 +110,7 @@ A typical example is DartsTrainer, where learnable-parameters are used to combin
.. code-block:: python
from nni.retiarii.trainer import BaseOneShotTrainer
from nni.retiarii.trainer.pytorch import BaseOneShotTrainer
from nni.retiarii.trainer.pytorch.utils import replace_layer_choice, replace_input_choice
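Below is a minimal, hedged sketch of what a custom one-shot trainer might look like. Only the ``BaseOneShotTrainer`` interface (``fit()`` and ``export()``) comes from the source above; the class name, constructor arguments, and training loop are illustrative assumptions, and a real trainer (such as DartsTrainer) must additionally handle LayerChoice/InputChoice modules (e.g., via ``replace_layer_choice``), which is omitted here.

.. code-block:: python

    from nni.retiarii.trainer.pytorch import BaseOneShotTrainer

    class NaiveOneShotTrainer(BaseOneShotTrainer):
        # Hypothetical example: trains the super-network as-is and exports a placeholder.
        def __init__(self, model, loss, optimizer, dataloader, num_epochs=10):
            self.model = model
            self.loss = loss
            self.optimizer = optimizer
            self.dataloader = dataloader
            self.num_epochs = num_epochs

        def fit(self):
            # Train the super-network that contains all candidate architectures.
            for _ in range(self.num_epochs):
                for x, y in self.dataloader:
                    self.optimizer.zero_grad()
                    self.loss(self.model(x), y).backward()
                    self.optimizer.step()

        def export(self):
            # Return a description of the best architecture found during fit().
            # A real trainer derives this from learnable architecture parameters;
            # here it is left as a placeholder.
            return {}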
......
......@@ -2,4 +2,4 @@ from .operation import Operation
from .graph import *
from .execution import *
from .mutator import *
from .utils import blackbox, blackbox_module, register_trainer
from .utils import blackbox, blackbox_module, json_dump, json_dumps, json_load, json_loads, register_trainer
......@@ -2,31 +2,29 @@ import logging
import os
import random
import string
from typing import Dict, Any, List
from typing import Dict, List
from .interface import AbstractExecutionEngine, AbstractGraphListener
from .. import codegen, utils
from ..graph import Model, ModelStatus, MetricData
from ..graph import Model, ModelStatus, MetricData, TrainingConfig
from ..integration_api import send_trial, receive_trial_parameters, get_advisor
_logger = logging.getLogger(__name__)
class BaseGraphData:
def __init__(self, model_script: str, training_module: str, training_kwargs: Dict[str, Any]) -> None:
def __init__(self, model_script: str, training_config: TrainingConfig) -> None:
self.model_script = model_script
self.training_module = training_module
self.training_kwargs = training_kwargs
self.training_config = training_config
def dump(self) -> dict:
return {
'model_script': self.model_script,
'training_module': self.training_module,
'training_kwargs': self.training_kwargs
'training_config': self.training_config
}
@staticmethod
def load(data):
return BaseGraphData(data['model_script'], data['training_module'], data['training_kwargs'])
def load(data) -> 'BaseGraphData':
return BaseGraphData(data['model_script'], data['training_config'])
class BaseExecutionEngine(AbstractExecutionEngine):
......@@ -57,8 +55,7 @@ class BaseExecutionEngine(AbstractExecutionEngine):
def submit_models(self, *models: Model) -> None:
for model in models:
data = BaseGraphData(codegen.model_to_pytorch_script(model),
model.training_config.module, model.training_config.kwargs)
data = BaseGraphData(codegen.model_to_pytorch_script(model), model.training_config)
self._running_models[send_trial(data.dump())] = model
def register_graph_listener(self, listener: AbstractGraphListener) -> None:
......@@ -105,11 +102,10 @@ class BaseExecutionEngine(AbstractExecutionEngine):
"""
graph_data = BaseGraphData.load(receive_trial_parameters())
random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6))
file_name = f'_generated_model_{random_str}.py'
file_name = f'_generated_model/{random_str}.py'
os.makedirs(os.path.dirname(file_name), exist_ok=True)
with open(file_name, 'w') as f:
f.write(graph_data.model_script)
trainer_cls = utils.import_(graph_data.training_module)
model_cls = utils.import_(f'_generated_model_{random_str}._model')
trainer_instance = trainer_cls(model=model_cls(), **graph_data.training_kwargs)
trainer_instance.fit()
model_cls = utils.import_(f'_generated_model.{random_str}._model')
graph_data.training_config._execute(model_cls)
os.remove(file_name)
......@@ -44,7 +44,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
phy_models_and_placements = self._assemble(logical)
for model, placement, grouped_models in phy_models_and_placements:
data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement),
model.training_config.module, model.training_config.kwargs)
model.training_config)
for m in grouped_models:
self._original_models[m.model_id] = m
self._original_model_to_multi_model[m.model_id] = model
......
import logging
from dataclasses import dataclass
from pathlib import Path
from subprocess import Popen
from threading import Thread
from typing import Any, Optional
from ..experiment import Experiment, TrainingServiceConfig
from ..experiment.config.base import ConfigBase, PathLike
from ..experiment.config import util
from ..experiment.pipe import Pipe
from .graph import Model
from .utils import get_records
from .integration import RetiariiAdvisor
from .converter import convert_to_graph
from .mutator import Mutator
from .trainer.interface import BaseTrainer, BaseOneShotTrainer
from .strategies.strategy import BaseStrategy
from .trainer import BaseOneShotTrainer
from typing import Any, List, Optional, Union
import torch
import torch.nn as nn
from nni.experiment import Experiment, TrainingServiceConfig
from nni.experiment.config import util
from nni.experiment.config.base import ConfigBase, PathLike
from nni.experiment.pipe import Pipe
from ..converter import convert_to_graph
from ..graph import Model, TrainingConfig
from ..integration import RetiariiAdvisor
from ..mutator import Mutator
from ..nn.pytorch.mutator import process_inline_mutation
from ..strategies.strategy import BaseStrategy
from ..trainer.interface import BaseOneShotTrainer, BaseTrainer
from ..utils import get_records
_logger = logging.getLogger(__name__)
......@@ -76,8 +77,9 @@ _validation_rules = {
class RetiariiExperiment(Experiment):
def __init__(self, base_model: Model, trainer: BaseTrainer,
applied_mutators: Mutator = None, strategy: BaseStrategy = None):
def __init__(self, base_model: nn.Module, trainer: Union[TrainingConfig, BaseOneShotTrainer],
applied_mutators: List[Mutator] = None, strategy: BaseStrategy = None):
# TODO: The current design of init interface of Retiarii experiment needs to be reviewed.
self.config: RetiariiExeConfig = None
self.port: Optional[int] = None
......@@ -93,28 +95,19 @@ class RetiariiExperiment(Experiment):
self._pipe: Optional[Pipe] = None
def _start_strategy(self):
import torch
from .nn.pytorch.mutator import process_inline_mutation
try:
script_module = torch.jit.script(self.base_model)
except Exception as e:
_logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
raise e
base_model_ir = convert_to_graph(script_module, self.base_model)
recorded_module_args = get_records()
if id(self.trainer) not in recorded_module_args:
raise KeyError('Your trainer is not found in registered classes. You might have forgotten to \
register your customized trainer with @register_trainer decorator.')
trainer_config = recorded_module_args[id(self.trainer)]
base_model_ir.apply_trainer(trainer_config['modulename'], trainer_config['args'])
base_model_ir.training_config = self.trainer
# handle inline mutations
mutators = process_inline_mutation(base_model_ir)
if mutators is not None and self.applied_mutators:
raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, \
do not use mutators when you use LayerChoice/InputChoice')
raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, '
'do not use mutators when you use LayerChoice/InputChoice')
if mutators is not None:
self.applied_mutators = mutators
......
......@@ -2,13 +2,14 @@
Model representation.
"""
import abc
import copy
from enum import Enum
import json
from enum import Enum
from typing import (Any, Dict, List, Optional, Tuple, Union, overload)
from .operation import Cell, Operation, _IOPseudoOperation
from .utils import uid
from .utils import get_full_class_name, import_, uid
__all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'IllegalGraphError', 'MetricData']
......@@ -24,40 +25,43 @@ Type hint for edge's endpoint. The int indicates nodes' order.
"""
class TrainingConfig:
class TrainingConfig(abc.ABC):
"""
Training config of a model.
Module will be imported, initialized with generated model and arguments in ``kwargs``.
Training config of a model. A training config should define where the training code is and the configuration of the training code.
The configuration includes basic runtime information the trainer needs to know (such as the number of GPUs)
or tunable parameters (such as learning rate), depending on the implementation of the training code.
Attributes
----------
module
Trainer module
kwargs
Trainer keyword arguments
Each config should define how it is interpreted in ``_execute()``, taking only one argument, which is the mutated model class.
For example, a functional training config might directly import the function and call it.
"""
def __init__(self, module: str, kwargs: Dict[str, Any]):
self.module = module
self.kwargs = kwargs
def __repr__(self):
return f'TrainingConfig(module={self.module}, kwargs={self.kwargs})'
items = ', '.join(['%s=%r' % (k, v) for k, v in self.__dict__.items()])
return f'{self.__class__.__name__}({items})'
@staticmethod
@abc.abstractstaticmethod
def _load(ir: Any) -> 'TrainingConfig':
return TrainingConfig(ir['module'], ir.get('kwargs', {}))
pass
@staticmethod
def _load_with_type(type_name: str, ir: Any) -> 'Optional[TrainingConfig]':
if type_name == '_debug_no_trainer':
return DebugTraining()
config_cls = import_(type_name)
assert issubclass(config_cls, TrainingConfig)
return config_cls._load(ir)
@abc.abstractmethod
def _dump(self) -> Any:
return {
'module': self.module,
'kwargs': self.kwargs
}
pass
def __eq__(self, other):
return self.module == other.module and \
self.kwargs == other.kwargs
@abc.abstractmethod
def _execute(self, model_cls: type) -> Any:
pass
@abc.abstractmethod
def __eq__(self, other) -> bool:
pass
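# Hedged illustration (not part of this commit): a concrete TrainingConfig subclass is
# expected to round-trip through _load()/_dump() and to run training in _execute().
# A minimal sketch, with a hypothetical subclass and a user-defined train() function:
#
#     class EpochsOnlyConfig(TrainingConfig):
#         def __init__(self, epochs=10):
#             self.epochs = epochs
#
#         @staticmethod
#         def _load(ir):
#             return EpochsOnlyConfig(ir['epochs'])
#
#         def _dump(self):
#             return {'epochs': self.epochs}
#
#         def _execute(self, model_cls):
#             train(model_cls(), epochs=self.epochs)
#
#         def __eq__(self, other):
#             return self.epochs == other.epochs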
class Model:
......@@ -100,7 +104,7 @@ class Model:
self._root_graph_name: str = '_model'
self.graphs: Dict[str, Graph] = {}
self.training_config: TrainingConfig = TrainingConfig('foo', {})
self.training_config: Optional[TrainingConfig] = None
self.history: List[Model] = []
......@@ -137,18 +141,17 @@ class Model:
for graph_name, graph_data in ir.items():
if graph_name != '_training_config':
Graph._load(model, graph_name, graph_data)._register()
model.training_config = TrainingConfig._load(ir['_training_config'])
model.training_config = TrainingConfig._load_with_type(ir['_training_config']['__type__'], ir['_training_config'])
return model
def _dump(self) -> Any:
ret = {name: graph._dump() for name, graph in self.graphs.items()}
ret['_training_config'] = self.training_config._dump()
ret['_training_config'] = {
'__type__': get_full_class_name(self.training_config.__class__),
**self.training_config._dump()
}
return ret
def apply_trainer(self, module, args) -> None:
# TODO: rethink the way of specifying a trainer
self.training_config = TrainingConfig(module, args)
def get_nodes_by_label(self, label: str) -> List['Node']:
"""
Traverse all the nodes to find the matched node(s) with the given name.
......@@ -668,3 +671,18 @@ class IllegalGraphError(ValueError):
graph = graph._dump()
with open('generated/debug.json', 'w') as dump_file:
json.dump(graph, dump_file, indent=4)
class DebugTraining(TrainingConfig):
@staticmethod
def _load(ir: Any) -> 'DebugTraining':
return DebugTraining()
def _dump(self) -> Any:
return {'__type__': '_debug_no_trainer'}
def _execute(self, model_cls: type) -> Any:
pass
def __eq__(self, other) -> bool:
return True
......@@ -2,7 +2,6 @@ import logging
import os
from typing import Any, Callable
import json_tricks
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.runtime.protocol import CommandType, send
from nni.utils import MetricType
......@@ -12,6 +11,7 @@ from .execution.base import BaseExecutionEngine
from .execution.cgo_engine import CGOExecutionEngine
from .execution.api import set_execution_engine
from .integration_api import register_advisor
from .utils import json_dumps, json_loads
_logger = logging.getLogger(__name__)
......@@ -100,7 +100,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
'parameter_source': 'algorithm'
}
_logger.info('New trial sent: %s', new_trial)
send(CommandType.NewTrialJob, json_tricks.dumps(new_trial))
send(CommandType.NewTrialJob, json_dumps(new_trial))
if self.send_trial_callback is not None:
self.send_trial_callback(parameters) # pylint: disable=not-callable
return self.parameters_count
......@@ -116,7 +116,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
def handle_trial_end(self, data):
_logger.info('Trial end: %s', data)
self.trial_end_callback(json_tricks.loads(data['hyper_params'])['parameter_id'], # pylint: disable=not-callable
self.trial_end_callback(json_loads(data['hyper_params'])['parameter_id'], # pylint: disable=not-callable
data['event'] == 'SUCCEEDED')
def handle_report_metric_data(self, data):
......@@ -132,7 +132,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
@staticmethod
def _process_value(value) -> Any: # hopefully a float
value = json_tricks.loads(value)
value = json_loads(value)
if isinstance(value, dict):
if 'default' in value:
return value['default']
......
import json
from typing import NewType, Any
import nni
from .utils import json_loads
# NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor
# because it would induce cycled import
RetiariiAdvisor = NewType('RetiariiAdvisor', Any)
......@@ -31,6 +35,12 @@ def send_trial(parameters: dict) -> int:
def receive_trial_parameters() -> dict:
"""
Received a new trial. Executed on trial end.
Reload with our json_loads because NNI didn't use the Retiarii serializer to load the data.
"""
params = nni.get_next_parameter()
params = json_loads(json.dumps(params))
return params
def get_experiment_id() -> str:
return nni.get_experiment_id()
from .interface import BaseTrainer, BaseOneShotTrainer
from .functional import FunctionalTrainer
from .interface import BaseOneShotTrainer
from ..graph import TrainingConfig
class FunctionalTrainer(TrainingConfig):
"""
Functional training config that directly takes a function and thus should be general.
Attributes
----------
function
The full name of the function.
arguments
Keyword arguments for the function other than model.
"""
def __init__(self, function, **kwargs):
self.function = function
self.arguments = kwargs
@staticmethod
def _load(ir):
return FunctionalTrainer(ir['function'], **ir['arguments'])
def _dump(self):
return {
'function': self.function,
'arguments': self.arguments
}
def _execute(self, model_cls):
return self.function(model_cls, **self.arguments)
def __eq__(self, other):
return self.function == other.function and self.arguments == other.arguments
......@@ -3,33 +3,25 @@ from typing import Any
class BaseTrainer(abc.ABC):
"""
In this version, we plan to write our own trainers instead of using PyTorch-lightning, to
ease the burden to integrate our optmization with PyTorch-lightning, a large part of which is
opaque to us.
We will try to align with PyTorch-lightning name conversions so that we can easily migrate to
PyTorch-lightning in the future.
Currently, our trainer = LightningModule + LightningTrainer. We might want to separate these two things
in future.
Trainer has a ``fit`` function with no return value. Intermediate results and final results should be
directly sent via ``nni.report_intermediate_result()`` and ``nni.report_final_result()`` functions.
"""
@abc.abstractmethod
def fit(self) -> None:
# Deprecated class
pass
class BaseOneShotTrainer(BaseTrainer):
class BaseOneShotTrainer(abc.ABC):
"""
Build many (possibly all) architectures into a full graph, search (with train) and export the best.
One-shot trainer has a ``fit`` function with no return value. Trainers should fit and search for the best architecture.
Currently, all the inputs of the trainer need to be manually set before ``fit`` (including the search space, dataloader to use, training epochs, etc.).
It has an extra ``export`` function that exports an object representing the final searched architecture.
"""
@abc.abstractmethod
def fit(self) -> None:
pass
@abc.abstractmethod
def export(self) -> Any:
pass
# This file is deprecated.
from typing import Any, List, Dict, Tuple
import numpy as np
......
import warnings
from typing import Dict, Union, Optional, List
import pytorch_lightning as pl
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import nni
from ...graph import TrainingConfig
from ...utils import blackbox_module
__all__ = ['LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classification', 'Regression']
class LightningModule(pl.LightningModule):
def set_model(self, model):
if isinstance(model, type):
self.model = model()
else:
self.model = model
Trainer = blackbox_module(pl.Trainer)
DataLoader = blackbox_module(DataLoader)
class Lightning(TrainingConfig):
"""
Delegate the whole training to PyTorch Lightning.
Since the arguments passed to the initialization need to be serialized, the ``LightningModule``, ``Trainer`` and
``DataLoader`` in this file should be used. Another option is to hide the dataloaders inside the Lightning module, in
which case dataloaders are not required for this class to work.
Following the programming style of Lightning, metrics sent to NNI should be obtained from ``callback_metrics``
in the trainer. Two hooks are added at the end of each validation epoch and at the end of ``fit``, respectively. The metric name
and type depend on the specific task.
Parameters
----------
lightning_module : LightningModule
Lightning module that defines the training logic.
trainer : Trainer
Lightning trainer that handles the training.
train_dataloader : DataLoader
Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
val_dataloaders : DataLoader or List of DataLoader
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
"""
def __init__(self, lightning_module: LightningModule, trainer: Trainer,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None):
assert isinstance(lightning_module, LightningModule), f'Lightning module must be an instance of {__name__}.LightningModule.'
assert isinstance(trainer, Trainer), f'Trainer must be imported from {__name__}.'
assert _check_dataloader(train_dataloader), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
assert _check_dataloader(val_dataloaders), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
self.module = lightning_module
self.trainer = trainer
self.train_dataloader = train_dataloader
self.val_dataloaders = val_dataloaders
@staticmethod
def _load(ir):
return Lightning(ir['module'], ir['trainer'], ir['train_dataloader'], ir['val_dataloaders'])
def _dump(self):
return {
'module': self.module,
'trainer': self.trainer,
'train_dataloader': self.train_dataloader,
'val_dataloaders': self.val_dataloaders
}
def _execute(self, model_cls):
return self.fit(model_cls)
def __eq__(self, other):
return self.module == other.module and self.trainer == other.trainer and \
self.train_dataloader == other.train_dataloader and self.val_dataloaders == other.val_dataloaders
def fit(self, model):
"""
Fit the model with provided dataloader, with Lightning trainer.
Parameters
----------
model : nn.Module
The model to fit.
"""
self.module.set_model(model)
return self.trainer.fit(self.module, self.train_dataloader, self.val_dataloaders)
def _check_dataloader(dataloader):
if dataloader is None:
return True
if isinstance(dataloader, list):
return all([_check_dataloader(d) for d in dataloader])
return isinstance(dataloader, DataLoader)
### The following are some commonly used Lightning modules ###
class _SupervisedLearningModule(LightningModule):
def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
super().__init__()
self.save_hyperparameters('criterion', 'optimizer', 'learning_rate', 'weight_decay')
self.criterion = criterion()
self.optimizer = optimizer
self.metrics = nn.ModuleDict({name: cls() for name, cls in metrics.items()})
def forward(self, x):
y_hat = self.model(x)
return y_hat
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = self.criterion(y_hat, y)
self.log('train_loss', loss, prog_bar=True)
for name, metric in self.metrics.items():
self.log('train_' + name, metric(y_hat, y), prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
self.log('val_loss', self.criterion(y_hat, y), prog_bar=True)
for name, metric in self.metrics.items():
self.log('val_' + name, metric(y_hat, y), prog_bar=True)
def test_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
self.log('test_loss', self.criterion(y_hat, y), prog_bar=True)
for name, metric in self.metrics.items():
self.log('test_' + name, metric(y_hat, y), prog_bar=True)
def configure_optimizers(self):
return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)
def on_validation_epoch_end(self):
nni.report_intermediate_result(self._get_validation_metrics())
def teardown(self, stage):
if stage == 'fit':
nni.report_final_result(self._get_validation_metrics())
def _get_validation_metrics(self):
if len(self.metrics) == 1:
metric_name = next(iter(self.metrics))
return self.trainer.callback_metrics['val_' + metric_name].item()
else:
warnings.warn('Multiple metrics without "default" is not supported by current framework.')
return {name: self.trainer.callback_metrics['val_' + name].item() for name in self.metrics}
@blackbox_module
class _ClassificationModule(_SupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
super().__init__(criterion, {'acc': pl.metrics.Accuracy},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
class Classification(Lightning):
"""
Trainer that is used for classification.
Parameters
----------
criterion : nn.Module
Class for criterion module (not an instance). default: ``nn.CrossEntropyLoss``
learning_rate : float
Learning rate. default: 0.001
weight_decay : float
L2 weight decay. default: 0
optimizer : Optimizer
Class for optimizer (not an instance). default: ``Adam``
train_dataloader : DataLoader
Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
val_dataloaders : DataLoader or List of DataLoader
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
trainer_kwargs : dict
Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/trainer.html>`__ for details.
"""
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
**trainer_kwargs):
module = _ClassificationModule(criterion=criterion, learning_rate=learning_rate,
weight_decay=weight_decay, optimizer=optimizer)
super().__init__(module, Trainer(**trainer_kwargs),
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
@blackbox_module
class _RegressionModule(_SupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.MSELoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
super().__init__(criterion, {'mse': pl.metrics.MeanSquaredError},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
class Regression(Lightning):
"""
Trainer that is used for regression.
Parameters
----------
criterion : nn.Module
Class for criterion module (not an instance). default: ``nn.MSELoss``
learning_rate : float
Learning rate. default: 0.001
weight_decay : float
L2 weight decay. default: 0
optimizer : Optimizer
Class for optimizer (not an instance). default: ``Adam``
train_dataloader : DataLoader
Used in ``trainer.fit()``. A PyTorch DataLoader with training samples.
If the ``lightning_module`` has a predefined train_dataloader method this will be skipped.
val_dataloaders : DataLoader or List of DataLoader
Used in ``trainer.fit()``. Either a single PyTorch Dataloader or a list of them, specifying validation samples.
If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
trainer_kwargs : dict
Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/trainer.html>`__ for details.
"""
def __init__(self, criterion: nn.Module = nn.MSELoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
**trainer_kwargs):
module = _RegressionModule(criterion=criterion, learning_rate=learning_rate,
weight_decay=weight_decay, optimizer=optimizer)
super().__init__(module, Trainer(**trainer_kwargs),
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
import functools
import inspect
import warnings
from collections import defaultdict
from typing import Any
from pathlib import Path
import json_tricks
def import_(target: str, allow_none: bool = False) -> Any:
if target is None:
......@@ -20,6 +22,45 @@ def version_larger_equal(a: str, b: str) -> bool:
return tuple(map(int, a.split('.'))) >= tuple(map(int, b.split('.')))
### This is a patch of json-tricks to make it more useful to us ###
def _blackbox_class_instance_encode(obj, primitives=False):
assert not primitives, 'Encoding with primitives is not supported.'
if hasattr(obj, '__class__') and hasattr(obj, '__init_parameters__'):
return {
'__type__': get_full_class_name(obj.__class__),
'arguments': obj.__init_parameters__
}
return obj
def _blackbox_class_instance_decode(obj):
if isinstance(obj, dict) and '__type__' in obj and 'arguments' in obj:
return import_(obj['__type__'])(**obj['arguments'])
return obj
def _type_encode(obj, primitives=False):
assert not primitives, 'Encoding with primitives is not supported.'
if isinstance(obj, type):
return {'__typename__': get_full_class_name(obj, relocate_module=True)}
return obj
def _type_decode(obj):
if isinstance(obj, dict) and '__typename__' in obj:
return import_(obj['__typename__'])
return obj
json_loads = functools.partial(json_tricks.loads, extra_obj_pairs_hooks=[_blackbox_class_instance_decode, _type_decode])
json_dumps = functools.partial(json_tricks.dumps, extra_obj_encoders=[_blackbox_class_instance_encode, _type_encode])
json_load = functools.partial(json_tricks.load, extra_obj_pairs_hooks=[_blackbox_class_instance_decode, _type_decode])
json_dump = functools.partial(json_tricks.dump, extra_obj_encoders=[_blackbox_class_instance_encode, _type_encode])
### End of json-tricks patch ###
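# Hedged usage sketch (not part of this commit): the patched helpers above are meant to
# round-trip instances of blackbox-wrapped classes, whose __init_parameters__ records the
# constructor arguments. Assuming Net is a user class decorated with @blackbox_module,
# something like the following should hold:
#
#     net = Net(hidden_units=128)
#     restored = json_loads(json_dumps(net))  # re-created from the recorded arguments
#     assert isinstance(restored, Net)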
_records = {}
......@@ -48,10 +89,10 @@ def del_record(key):
_records.pop(key, None)
def _blackbox_cls(cls, module_name, register_format=None):
def _blackbox_cls(cls):
class wrapper(cls):
def __init__(self, *args, **kwargs):
argname_list = list(inspect.signature(cls).parameters.keys())
argname_list = list(inspect.signature(cls.__init__).parameters.keys())[1:]
full_args = {}
full_args.update(kwargs)
......@@ -59,30 +100,16 @@ def _blackbox_cls(cls, module_name, register_format=None):
for argname, value in zip(argname_list, args):
full_args[argname] = value
# eject un-serializable arguments
for k in list(full_args.keys()):
# The list is not complete and does not support nested cases.
if not isinstance(full_args[k], (int, float, str, dict, list, tuple)):
if not (register_format == 'full' and k == 'model'):
# no warning if it is base model in trainer
warnings.warn(f'{cls} has un-serializable arguments {k} whose value is {full_args[k]}. \
This is not supported. You can ignore this warning if you are passing the model to trainer.')
full_args.pop(k)
if register_format == 'args':
add_record(id(self), full_args)
elif register_format == 'full':
full_class_name = cls.__module__ + '.' + cls.__name__
add_record(id(self), {'modulename': full_class_name, 'args': full_args})
add_record(id(self), full_args) # for compatibility. Will remove soon.
self.__init_parameters__ = full_args
super().__init__(*args, **kwargs)
def __del__(self):
del_record(id(self))
# using module_name instead of cls.__module__ because it's more natural to see where the module gets wrapped
# instead of simply putting torch.nn or etc.
wrapper.__module__ = module_name
wrapper.__module__ = _get_module_name(cls)
wrapper.__name__ = cls.__name__
wrapper.__qualname__ = cls.__qualname__
wrapper.__init__.__doc__ = cls.__init__.__doc__
......@@ -97,41 +124,21 @@ def blackbox(cls, *args, **kwargs):
.. code-block:: python
self.op = blackbox(MyCustomOp, hidden_units=128)
"""
# get caller module name
frm = inspect.stack()[1]
module_name = inspect.getmodule(frm[0]).__name__
return _blackbox_cls(cls, module_name, 'args')(*args, **kwargs)
return _blackbox_cls(cls)(*args, **kwargs)
def blackbox_module(cls):
"""
Register a module. Use it as a decorator.
"""
frm = inspect.stack()[1]
assert (inspect.getmodule(frm[0]) is not None), ('unable to locate the definition of the given black box module, '
'please define it explicitly in a .py file.')
module_name = inspect.getmodule(frm[0]).__name__
if module_name == '__main__':
main_file_path = Path(inspect.getsourcefile(frm[0]))
if main_file_path.parents[0] != Path('.'):
raise RuntimeError(f'you are using "{main_file_path}" to launch your experiment, '
f'please launch the experiment under the directory where "{main_file_path.name}" is located.')
module_name = main_file_path.stem
return _blackbox_cls(cls, module_name, 'args')
return _blackbox_cls(cls)
def register_trainer(cls):
"""
Register a trainer. Use it as a decorator.
"""
frm = inspect.stack()[1]
assert (inspect.getmodule(frm[0]) is not None), ('unable to locate the definition of the given trainer, '
'please define it explicitly in a .py file.')
module_name = inspect.getmodule(frm[0]).__name__
return _blackbox_cls(cls, module_name, 'full')
return _blackbox_cls(cls)
_last_uid = defaultdict(int)
......@@ -140,3 +147,24 @@ _last_uid = defaultdict(int)
def uid(namespace: str = 'default') -> int:
_last_uid[namespace] += 1
return _last_uid[namespace]
def _get_module_name(cls):
module_name = cls.__module__
if module_name == '__main__':
# infer the module name with inspect
for frm in inspect.stack():
if inspect.getmodule(frm[0]).__name__ == '__main__':
# main module found
main_file_path = Path(inspect.getsourcefile(frm[0]))
if main_file_path.parents[0] != Path('.'):
raise RuntimeError(f'You are using "{main_file_path}" to launch your experiment, '
f'please launch the experiment under the directory where "{main_file_path.name}" is located.')
module_name = main_file_path.stem
break
return module_name
def get_full_class_name(cls, relocate_module=False):
module_name = _get_module_name(cls) if relocate_module else cls.__module__
return module_name + '.' + cls.__name__
......@@ -6,6 +6,7 @@ assessor_result.txt
_generated_model.py
_generated_model_*.py
_generated_model
data
generated
lightning_logs
......@@ -4,21 +4,35 @@ import sys
import torch
from pathlib import Path
from nni.retiarii.experiment import RetiariiExperiment, RetiariiExeConfig
import nni.retiarii.trainer.pytorch.lightning as pl
from nni.retiarii import blackbox_module as bm
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.strategies import TPEStrategy, RandomStrategy
from nni.retiarii.trainer.pytorch import PyTorchImageClassificationTrainer
from torchvision import transforms
from torchvision.datasets import CIFAR10
from darts_model import CNN
if __name__ == '__main__':
base_model = CNN(32, 3, 16, 10, 8)
trainer = PyTorchImageClassificationTrainer(base_model, dataset_cls="CIFAR10",
dataset_kwargs={"root": "data/cifar10", "download": True},
dataloader_kwargs={"batch_size": 32},
optimizer_kwargs={"lr": 1e-3},
trainer_kwargs={"max_epochs": 1})
#simple_startegy = TPEStrategy()
train_transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
valid_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train_dataset = bm(CIFAR10)(root='data/cifar10', train=True, download=True, transform=train_transform)
test_dataset = bm(CIFAR10)(root='data/cifar10', train=False, download=True, transform=valid_transform)
trainer = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
max_epochs=1, limit_train_batches=0.2)
simple_strategy = RandomStrategy()
exp = RetiariiExperiment(base_model, trainer, [], simple_strategy)
......