Unverified Commit cae4308f authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub

[Retiarii] Rename APIs and refine documentation (#3404)

parent d047d6f4
......@@ -42,31 +42,31 @@ Graph Mutation APIs
Trainers
--------
.. autoclass:: nni.retiarii.trainer.FunctionalTrainer
.. autoclass:: nni.retiarii.evaluator.FunctionalEvaluator
:members:
.. autoclass:: nni.retiarii.trainer.pytorch.lightning.LightningModule
.. autoclass:: nni.retiarii.evaluator.pytorch.lightning.LightningModule
:members:
.. autoclass:: nni.retiarii.trainer.pytorch.lightning.Classification
.. autoclass:: nni.retiarii.evaluator.pytorch.lightning.Classification
:members:
.. autoclass:: nni.retiarii.trainer.pytorch.lightning.Regression
.. autoclass:: nni.retiarii.evaluator.pytorch.lightning.Regression
:members:
Oneshot Trainers
----------------
.. autoclass:: nni.retiarii.trainer.pytorch.DartsTrainer
.. autoclass:: nni.retiarii.oneshot.pytorch.DartsTrainer
:members:
.. autoclass:: nni.retiarii.trainer.pytorch.EnasTrainer
.. autoclass:: nni.retiarii.oneshot.pytorch.EnasTrainer
:members:
.. autoclass:: nni.retiarii.trainer.pytorch.ProxylessTrainer
.. autoclass:: nni.retiarii.oneshot.pytorch.ProxylessTrainer
:members:
.. autoclass:: nni.retiarii.trainer.pytorch.SinglePathTrainer
.. autoclass:: nni.retiarii.oneshot.pytorch.SinglePathTrainer
:members:
Strategies
......
......@@ -24,7 +24,7 @@ Define Base Model
Defining a base model is almost the same as defining a PyTorch (or TensorFlow) model. There are only two small differences.
* Replace the code ``import torch.nn as nn`` with ``import nni.retiarii.nn.pytorch as nn`` for PyTorch modules, such as ``nn.Conv2d``, ``nn.ReLU``.
* Some **user-defined** modules should be decorated with ``@blackbox_module``. For example, user-defined module used in ``LayerChoice`` should be decorated. Users can refer to `here <#blackbox-module>`__ for detailed usage instruction of ``@blackbox_module``.
* Some **user-defined** modules should be decorated with ``@basic_unit``. For example, user-defined module used in ``LayerChoice`` should be decorated. Users can refer to `here <#serialize-module>`__ for detailed usage instruction of ``@basic_unit``.
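Conceptually, ``@basic_unit`` keeps a module as a single graph node, which requires its init arguments to be recorded so the module can be serialized and re-instantiated later. A minimal pure-Python sketch of that recording idea (``record_init_args`` and ``MyConv`` are illustrative names only, not NNI's implementation):

```python
import functools

def record_init_args(cls):
    """Toy decorator: capture the keyword arguments a class was constructed
    with, loosely mimicking how a serializing decorator could record them."""
    orig_init = cls.__init__

    @functools.wraps(orig_init)
    def init(self, *args, **kwargs):
        self._init_parameters = dict(kwargs)  # captured for later serialization
        orig_init(self, *args, **kwargs)

    cls.__init__ = init
    return cls

@record_init_args
class MyConv:
    def __init__(self, kernel_size=3):
        self.kernel_size = kernel_size

m = MyConv(kernel_size=5)
print(m._init_parameters)  # {'kernel_size': 5}
```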
Below is a very simple example of defining a base model; it is almost the same as defining a PyTorch model.
......@@ -59,7 +59,7 @@ A base model is only one concrete model not a model space. We provide APIs and p
For ease of use and backward compatibility, we provide some APIs for users to easily express possible mutations after defining a base model. These APIs can be used just like PyTorch modules.
* ``nn.LayerChoice``. It allows users to put several candidate operations (e.g., PyTorch modules), one of them is chosen in each explored model. *Note that if the candidate is a user-defined module, it should be decorated as `blackbox module <#blackbox-module>`__. In the following example, ``ops.PoolBN`` and ``ops.SepConv`` should be decorated.*
* ``nn.LayerChoice``. It allows users to put several candidate operations (e.g., PyTorch modules); one of them is chosen in each explored model. *Note that if the candidate is a user-defined module, it should be decorated as a `serializable module <#serialize-module>`__. In the following example, ``ops.PoolBN`` and ``ops.SepConv`` should be decorated.*
.. code-block:: python
......@@ -83,7 +83,7 @@ For easy usability and also backward compatibility, we provide some APIs for use
# invoked in `forward` function, choose one from the three
out = self.input_switch([tensor1, tensor2, tensor3])
* ``nn.ValueChoice``. It is for choosing one value from some candidate values. It can only be used as input argument of the modules in ``nn.modules`` and ``@blackbox_module`` decorated user-defined modules.
* ``nn.ValueChoice``. It is for choosing one value from some candidate values. It can only be used as an input argument of the modules in ``nn.modules`` and of ``@basic_unit``-decorated user-defined modules.
.. code-block:: python
......@@ -129,38 +129,37 @@ Use placehoder to make mutation easier: ``nn.Placeholder``. If you want to mutat
.. code-block:: python
ph = nn.Placeholder(label='mutable_0',
related_info={
'kernel_size_options': [1, 3, 5],
'n_layer_options': [1, 2, 3, 4],
'exp_ratio': exp_ratio,
'stride': stride
}
ph = nn.Placeholder(
label='mutable_0',
kernel_size_options=[1, 3, 5],
n_layer_options=[1, 2, 3, 4],
exp_ratio=exp_ratio,
stride=stride
)
``label`` is used by mutator to identify this placeholder, ``related_info`` is the information that are required by mutator. As ``related_info`` is a dict, it could include any information that users want to put to pass it to user defined mutator. The complete example code can be found in :githublink:`Mnasnet base model <test/retiarii_test/mnasnet/base_mnasnet.py>`.
``label`` is used by the mutator to identify this placeholder. The other parameters are the information required by the mutator; they can be accessed from ``node.operation.parameters`` as a dict and can include any information that users want to pass to a user-defined mutator. The complete example code can be found in :githublink:`Mnasnet base model <test/retiarii_test/mnasnet/base_mnasnet.py>`.
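Since ``node.operation.parameters`` is a plain dict, a user-defined mutator can read the placeholder's information with ordinary dict access. A toy stand-in (the ``node_parameters`` dict and ``pick_first_option`` helper are hypothetical, mirroring the keyword arguments of the ``nn.Placeholder`` snippet above):

```python
# Stand-in for node.operation.parameters of the placeholder above.
node_parameters = {
    'label': 'mutable_0',
    'kernel_size_options': [1, 3, 5],
    'n_layer_options': [1, 2, 3, 4],
}

def pick_first_option(params):
    """Toy 'mutator' decision: choose the first candidate of each *_options key."""
    return {k[:-len('_options')]: v[0]
            for k, v in params.items() if k.endswith('_options')}

print(pick_first_option(node_parameters))  # {'kernel_size': 1, 'n_layer': 1}
```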
Explore the Defined Model Space
-------------------------------
After model space is defined, it is time to explore this model space. Users can choose proper search and training approach to explore the model space.
After the model space is defined, it is time to explore it. Users can choose a proper exploration strategy and model evaluator for the exploration.
Create a Trainer and Exploration Strategy
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Create an Evaluator and Exploration Strategy
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
**Classic search approach:**
In this approach, trainer is for training each explored model, while strategy is for sampling the models. Both trainer and strategy are required to explore the model space. We recommend PyTorch-Lightning to write the full training process.
In this approach, the model evaluator trains and tests each explored model, while the strategy samples models from the model space. Both evaluator and strategy are required to explore the model space. We recommend using PyTorch-Lightning to write the full evaluation process.
**Oneshot (weight-sharing) search approach:**
In this approach, users only need a oneshot trainer, because this trainer takes charge of both search and training.
In this approach, users only need a oneshot trainer, because it takes charge of search, training, and testing.
In the following table, we listed the available trainers and strategies.
In the following table, we list the available evaluators and strategies.
.. list-table::
:header-rows: 1
:widths: auto
* - Trainer
* - Evaluator
- Strategy
- Oneshot Trainer
* - Classification
......@@ -178,24 +177,24 @@ In the following table, we listed the available trainers and strategies.
Their usage and API documentation can be found `here <./ApiReference>`__.
Here is a simple example of using trainer and strategy.
Here is a simple example of using evaluator and strategy.
.. code-block:: python
import nni.retiarii.trainer.pytorch.lightning as pl
from nni.retiarii import blackbox
import nni.retiarii.evaluator.pytorch.lightning as pl
from nni.retiarii import serialize
from torchvision import transforms
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = blackbox(MNIST, root='data/mnist', train=True, download=True, transform=transform)
test_dataset = blackbox(MNIST, root='data/mnist', train=False, download=True, transform=transform)
train_dataset = serialize(MNIST, root='data/mnist', train=True, download=True, transform=transform)
test_dataset = serialize(MNIST, root='data/mnist', train=False, download=True, transform=transform)
lightning = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
max_epochs=10)
.. Note:: For NNI to capture the dataset and dataloader and distribute it across different runs, please wrap your dataset with ``blackbox`` and use ``pl.DataLoader`` instead of ``torch.utils.data.DataLoader``. See ``blackbox_module`` section below for details.
.. Note:: For NNI to capture the dataset and dataloader and distribute it across different runs, please wrap your dataset with ``serialize`` and use ``pl.DataLoader`` instead of ``torch.utils.data.DataLoader``. See ``basic_unit`` section below for details.
Users can refer to `API reference <./ApiReference.rst>`__ on detailed usage of trainer. "`write a trainer <./WriteTrainer.rst>`__" for how to write a new trainer, and refer to `this document <./WriteStrategy.rst>`__ for how to write a new strategy.
Users can refer to the `API reference <./ApiReference.rst>`__ for detailed usage of evaluators, to "`write a trainer <./WriteTrainer.rst>`__" for how to write a new trainer, and to `this document <./WriteStrategy.rst>`__ for how to write a new strategy.
Set up an Experiment
^^^^^^^^^^^^^^^^^^^^
......@@ -231,17 +230,17 @@ If you are using *oneshot (weight-sharing) search approach*, you can invole ``ex
Advanced and FAQ
----------------
.. _blackbox-module:
.. _serialize-module:
**Blackbox Module**
**Serialize Module**
To understand the decorator ``blackbox_module``, we first briefly explain how our framework works: it converts user-defined model to a graph representation (called graph IR), each instantiated module is converted to a subgraph. Then user-defined mutations are applied to the graph to generate new graphs. Each new graph is then converted back to PyTorch code and executed. ``@blackbox_module`` here means the module will not be converted to a subgraph but is converted to a single graph node. That is, the module will not be unfolded anymore. Users should/can decorate a user-defined module class in the following cases:
To understand the decorator ``basic_unit``, we first briefly explain how our framework works: it converts a user-defined model to a graph representation (called graph IR), where each instantiated module is converted to a subgraph. Then user-defined mutations are applied to the graph to generate new graphs. Each new graph is then converted back to PyTorch code and executed. ``@basic_unit`` means the module will not be converted to a subgraph but to a single graph node; that is, the module will not be unfolded. Users should/can decorate a user-defined module class in the following cases:
* When a module class cannot be successfully converted to a subgraph due to some implementation issues. For example, currently our framework does not support adhoc loop, if there is adhoc loop in a module's forward, this class should be decorated as blackbox module. The following ``MyModule`` should be decorated.
* When a module class cannot be successfully converted to a subgraph due to some implementation issues. For example, currently our framework does not support ad-hoc loops; if there is an ad-hoc loop in a module's ``forward``, this class should be decorated as a serializable module. The following ``MyModule`` should be decorated.
.. code-block:: python
@blackbox_module
@basic_unit
class MyModule(nn.Module):
def __init__(self):
...
......@@ -249,6 +248,6 @@ To understand the decorator ``blackbox_module``, we first briefly explain how ou
for i in range(10): # <- adhoc loop
...
* The candidate ops in ``LayerChoice`` should be decorated as blackbox module. For example, ``self.op = nn.LayerChoice([Op1(...), Op2(...), Op3(...)])``, where ``Op1``, ``Op2``, ``Op3`` should be decorated if they are user defined modules.
* When users want to use ``ValueChoice`` in a module's input argument, the module should be decorated as blackbox module. For example, ``self.conv = MyConv(kernel_size=nn.ValueChoice([1, 3, 5]))``, where ``MyConv`` should be decorated.
* If no mutation is targeted on a module, this module *can be* decorated as a blackbox module.
\ No newline at end of file
* The candidate ops in ``LayerChoice`` should be decorated as serializable module. For example, ``self.op = nn.LayerChoice([Op1(...), Op2(...), Op3(...)])``, where ``Op1``, ``Op2``, ``Op3`` should be decorated if they are user defined modules.
* When users want to use ``ValueChoice`` in a module's input argument, the module should be decorated as serializable module. For example, ``self.conv = MyConv(kernel_size=nn.ValueChoice([1, 3, 5]))``, where ``MyConv`` should be decorated.
* If no mutation is targeted on a module, this module *can be* decorated as a serializable module.
Customize A New Trainer
=======================
Customize A New Evaluator/Trainer
=================================
Trainers are necessary to evaluate the performance of new explored models. In NAS scenario, this further divides into two use cases:
Evaluators/Trainers are necessary to evaluate the performance of newly explored models. In the NAS scenario, this further divides into two use cases:
1. **Single-arch trainers**: trainers that are used to train and evaluate one single model.
1. **Single-arch evaluators**: evaluators that are used to train and evaluate one single model.
2. **One-shot trainers**: trainers that handle training and searching simultaneously, from an end-to-end perspective.
Single-arch trainers
--------------------
Single-arch evaluators
----------------------
With FunctionalEvaluator
^^^^^^^^^^^^^^^^^^^^^^^^
The simplest way to customize a new evaluator is with functional APIs, which is very easy when training code is already available. Users only need to write a fit function that wraps everything. This function takes one positional argument (``model``) and possibly keyword arguments. In this way, users get everything under their control, but expose less information to the framework and thus fewer opportunities for possible optimization. An example is as follows:
.. code-block:: python
from nni.retiarii.evaluator import FunctionalEvaluator
from nni.retiarii.experiment.pytorch import RetiariiExperiment
def fit(model, dataloader):
train(model, dataloader)
acc = test(model, dataloader)
nni.report_final_result(acc)
evaluator = FunctionalEvaluator(fit, dataloader=DataLoader(foo, bar))
experiment = RetiariiExperiment(base_model, evaluator, mutators, strategy)
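The contract behind ``FunctionalEvaluator`` — store the fit function plus its keyword arguments, then call it on each sampled model — can be sketched in plain Python (``ToyFunctionalEvaluator`` is an illustrative stand-in, not the NNI class):

```python
class ToyFunctionalEvaluator:
    """Minimal sketch of the functional-evaluator pattern: keep a function
    and its keyword arguments, apply them to a model later."""
    def __init__(self, function, **kwargs):
        self.function = function
        self.arguments = kwargs

    def execute(self, model):
        # called once per sampled model by the framework
        return self.function(model, **self.arguments)

def fit(model, dataloader):
    # a real fit function would train/test and report metrics via nni
    return f'trained {model} on {dataloader}'

evaluator = ToyFunctionalEvaluator(fit, dataloader='mnist')
print(evaluator.execute('lenet'))  # trained lenet on mnist
```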
With PyTorch-Lightning
^^^^^^^^^^^^^^^^^^^^^^
It's recommended to write training code in PyTorch-Lightning style, that is, to write a LightningModule that defines all elements needed for training (e.g., loss function, optimizer) and to define a trainer that takes (optional) dataloaders to execute the training. Before that, please read the `document of PyTorch-lightning <https://pytorch-lightning.readthedocs.io/>`__ to learn the basic concepts and components provided by PyTorch-lightning.
In pratice, writing a new training module in NNI should inherit ``nni.retiarii.trainer.pytorch.lightning.LightningModule``, which has a ``set_model`` that will be called after ``__init__`` to save the candidate model (generated by strategy) as ``self.model``. The rest of the process (like ``training_step``) should be the same as writing any other lightning module. Trainers should also communicate with strategies via two API calls (``nni.report_intermediate_result`` for periodical metrics and ``nni.report_final_result`` for final metrics), added in ``on_validation_epoch_end`` and ``teardown`` respectively.
In practice, a new training module in NNI should inherit ``nni.retiarii.evaluator.pytorch.lightning.LightningModule``, which has a ``set_model`` method that will be called after ``__init__`` to save the candidate model (generated by strategy) as ``self.model``. The rest of the process (like ``training_step``) should be the same as writing any other lightning module. Evaluators should also communicate with strategies via two API calls (``nni.report_intermediate_result`` for periodical metrics and ``nni.report_final_result`` for final metrics), added in ``on_validation_epoch_end`` and ``teardown`` respectively.
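The ``set_model`` hook itself is a small protocol: construct the module first, attach the candidate model afterwards. A pure-Python sketch of that shape (``ToyLightningModule`` is illustrative only):

```python
class ToyLightningModule:
    """Sketch of the set_model hook described above."""
    def __init__(self):
        self.model = None  # no candidate model yet at construction time

    def set_model(self, model):
        # called by the framework after __init__, once a strategy
        # has generated a candidate model
        self.model = model

m = ToyLightningModule()
m.set_model('candidate_0')
print(m.model)  # candidate_0
```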
An example is as follows:
.. code-block::python
.. code-block:: python
from nni.retiarii.trainer.pytorch.lightning import LightningModule # please import this one
from nni.retiarii.evaluator.pytorch.lightning import LightningModule # please import this one
@blackbox_module
@basic_unit
class AutoEncoder(LightningModule):
def __init__(self):
super().__init__()
......@@ -69,9 +87,9 @@ An example is as follows:
Then, users need to wrap everything (including LightningModule, trainer and dataloaders) into a ``Lightning`` object, and pass this object into a Retiarii experiment.
.. code-block::python
.. code-block:: python
import nni.retiarii.trainer.pytorch.lightning as pl
import nni.retiarii.evaluator.pytorch.lightning as pl
from nni.retiarii.experiment.pytorch import RetiariiExperiment
lightning = pl.Lightning(AutoEncoder(),
......@@ -80,38 +98,20 @@ Then, users need to wrap everything (including LightningModule, trainer and data
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100))
experiment = RetiariiExperiment(base_model, lightning, mutators, strategy)
With FunctionalTrainer
^^^^^^^^^^^^^^^^^^^^^^
There is another way to customize a new trainer with functional APIs, which provides more flexibility. Users only need to write a fit function that wraps everything. This function takes one positional arguments (model) and possible keyword arguments. In this way, users get everything under their control, but exposes less information to the framework and thus fewer opportunities for possible optimization. An example is as belows:
.. code-block::python
from nni.retiarii.trainer import FunctionalTrainer
from nni.retiarii.experiment.pytorch import RetiariiExperiment
def fit(model, dataloader):
train(model, dataloader)
acc = test(model, dataloader)
nni.report_final_result(acc)
trainer = FunctionalTrainer(fit, dataloader=DataLoader(foo, bar))
experiment = RetiariiExperiment(base_model, trainer, mutators, strategy)
One-shot trainers
-----------------
One-shot trainers should inheirt ``nni.retiarii.trainer.BaseOneShotTrainer``, and need to implement ``fit()`` (used to conduct the fitting and searching process) and ``export()`` method (used to return the searched best architecture).
One-shot trainers should inherit ``nni.retiarii.oneshot.BaseOneShotTrainer``, and need to implement ``fit()`` (used to conduct the fitting and searching process) and ``export()`` (used to return the searched best architecture).
Writing a one-shot trainer is very different to classic trainers. First of all, there are no more restrictions on init method arguments, any Python arguments are acceptable. Secondly, the model feeded into one-shot trainers might be a model with Retiarii-specific modules, such as LayerChoice and InputChoice. Such model cannot directly forward-propagate and trainers need to decide how to handle those modules.
Writing a one-shot trainer is very different from writing classic evaluators. First of all, there are no more restrictions on init method arguments; any Python arguments are acceptable. Secondly, the model fed into one-shot trainers might be a model with Retiarii-specific modules, such as LayerChoice and InputChoice. Such a model cannot directly forward-propagate, and trainers need to decide how to handle those modules.
A typical example is DartsTrainer, where learnable parameters are used to combine multiple choices in LayerChoice. Retiarii provides easy-to-use utility functions for module-replace purposes, namely ``replace_layer_choice``, ``replace_input_choice``. A simplified example is as follows:
.. code-block::python
.. code-block:: python
from nni.retiarii.trainer.pytorch import BaseOneShotTrainer
from nni.retiarii.trainer.pytorch.utils import replace_layer_choice, replace_input_choice
from nni.retiarii.oneshot import BaseOneShotTrainer
from nni.retiarii.oneshot.pytorch import replace_layer_choice, replace_input_choice
class DartsLayerChoice(nn.Module):
......
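The weighted-combination idea behind a DARTS-style layer choice can be sketched without torch: keep one architecture parameter per candidate, softmax them into weights, and sum the candidate outputs. Names here (``combine_choices``, ``alpha``) are illustrative, not the DartsTrainer API:

```python
import math

def softmax(xs):
    # numerically stable softmax over a list of floats
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def combine_choices(candidate_outputs, alpha):
    """Softmax-weighted sum of candidate outputs, as in a DARTS-style
    layer choice (scalar outputs here for simplicity)."""
    weights = softmax(alpha)
    return sum(w * o for w, o in zip(weights, candidate_outputs))

# equal architecture parameters -> plain average of candidate outputs
print(combine_choices([1.0, 3.0], [0.0, 0.0]))  # 2.0
```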
......@@ -55,7 +55,7 @@ if __name__ == "__main__":
trainer.train()
else:
from nni.retiarii.trainer.pytorch import DartsTrainer
from nni.retiarii.oneshot.pytorch import DartsTrainer
trainer = DartsTrainer(
model=model,
loss=criterion,
......
......@@ -48,7 +48,7 @@ class Cell(nn.Module):
], key=cell_name + "_op")
def forward(self, prev_layers):
from nni.retiarii.trainer.pytorch.random import PathSamplingInputChoice
from nni.retiarii.oneshot.pytorch.random import PathSamplingInputChoice
out = self.input_choice(prev_layers)
if isinstance(self.input_choice, PathSamplingInputChoice):
# Retiarii pattern
......
......@@ -66,7 +66,7 @@ if __name__ == "__main__":
trainer.enable_visualization()
trainer.train()
else:
from nni.retiarii.trainer.pytorch.enas import EnasTrainer
from nni.retiarii.oneshot.pytorch.enas import EnasTrainer
trainer = EnasTrainer(model,
loss=criterion,
metrics=accuracy,
......
......@@ -84,7 +84,7 @@ if __name__ == "__main__":
optimizer = torch.optim.SGD(get_parameters(model), lr=0.05, momentum=momentum, nesterov=nesterov, weight_decay=4e-5)
if args.train_mode == 'search':
from nni.retiarii.trainer.pytorch import ProxylessTrainer
from nni.retiarii.oneshot.pytorch import ProxylessTrainer
from torchvision.datasets import ImageNet
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
......
......@@ -2,4 +2,4 @@ from .operation import Operation
from .graph import *
from .execution import *
from .mutator import *
from .utils import blackbox, blackbox_module, json_dump, json_dumps, json_load, json_loads, register_trainer
from .serializer import basic_unit, json_dump, json_dumps, json_load, json_loads, serialize, serialize_cls
import logging
import re
import torch
......@@ -6,18 +5,16 @@ import torch
from ..graph import Graph, Model, Node
from ..nn.pytorch import InputChoice, LayerChoice, Placeholder
from ..operation import Cell, Operation
from ..utils import get_records
from ..serializer import get_init_parameters_or_fail
from ..utils import get_full_class_name
from .op_types import MODULE_EXCEPT_LIST, OpTypeName
from .utils import _convert_name, build_full_name
_logger = logging.getLogger(__name__)
class GraphConverter:
def __init__(self):
self.global_seq = 0
self.global_graph_id = 0
self.modules_arg = get_records()
def _add_edge_handle_source_node(self, _input, graph_inputs, ir_graph, output_remap, node_index):
if _input in graph_inputs:
......@@ -247,7 +244,7 @@ class GraphConverter:
raise RuntimeError('Have not supported `if A and/or B`, please use two `if` statements instead.')
else:
raise RuntimeError(f'Unsupported op type {tensor.node().kind()} in if condition, '
'you are suggested to decorate the corresponding class with "@blackbox_module".')
'you are suggested to decorate the corresponding class with "@basic_unit".')
expr = _generate_expr(cond_tensor)
return eval(expr)
......@@ -539,13 +536,8 @@ class GraphConverter:
def _handle_layerchoice(self, module):
choices = []
for cand in list(module):
assert id(cand) in self.modules_arg, \
f'Module not recorded: {id(cand)}. ' \
'Try to import from `retiarii.nn` if you are using torch.nn module or ' \
'annotate your customized module with @blackbox_module.'
assert isinstance(self.modules_arg[id(cand)], dict)
cand_type = '__torch__.' + cand.__class__.__module__ + '.' + cand.__class__.__name__
choices.append({'type': cand_type, 'parameters': self.modules_arg[id(cand)]})
cand_type = '__torch__.' + get_full_class_name(cand.__class__)
choices.append({'type': cand_type, 'parameters': get_init_parameters_or_fail(cand)})
return {
'candidates': choices,
'label': module.label
......@@ -601,14 +593,13 @@ class GraphConverter:
elif original_type_name == OpTypeName.ValueChoice:
m_attrs = self._handle_valuechoice(module)
elif original_type_name == OpTypeName.Placeholder:
m_attrs = self.modules_arg[id(module)]
m_attrs = get_init_parameters_or_fail(module)
elif module.__class__.__module__.startswith('torch.nn') and original_type_name in torch.nn.__dict__:
# this is a basic module from pytorch, no need to parse its graph
assert id(module) in self.modules_arg, f'{original_type_name} arguments are not recorded'
m_attrs = self.modules_arg[id(module)]
elif id(module) in self.modules_arg:
# this module is marked as blackbox, won't continue to parse
m_attrs = self.modules_arg[id(module)]
m_attrs = get_init_parameters_or_fail(module)
else:
# this module is marked as serialize, won't continue to parse
m_attrs = get_init_parameters_or_fail(module, silently=True)
if m_attrs is not None:
return None, m_attrs
......
......@@ -3,7 +3,7 @@ import graphviz
def convert_to_visualize(graph_ir, vgraph):
for name, graph in graph_ir.items():
if name == '_training_config':
if name == '_evaluator':
continue
with vgraph.subgraph(name='cluster'+name) as subgraph:
subgraph.attr(color='blue')
......
from .functional import FunctionalEvaluator
from ..graph import TrainingConfig
from ..graph import Evaluator
class FunctionalTrainer(TrainingConfig):
class FunctionalEvaluator(Evaluator):
"""
Functional training config that directly takes a function and thus should be general.
Functional evaluator that directly takes a function and thus should be general.
Attributes
----------
......@@ -19,7 +19,7 @@ class FunctionalTrainer(TrainingConfig):
@staticmethod
def _load(ir):
return FunctionalTrainer(ir['function'], **ir['arguments'])
return FunctionalEvaluator(ir['function'], **ir['arguments'])
def _dump(self):
return {
......
from .base import PyTorchImageClassificationTrainer, PyTorchMultiModelTrainer
from .lightning import *
# This file is deprecated.
import abc
from typing import Any, List, Dict, Tuple
import numpy as np
......@@ -10,8 +11,10 @@ from torchvision import datasets, transforms
import nni
from ..interface import BaseTrainer
from ...utils import register_trainer
class BaseTrainer(abc.ABC):
@abc.abstractmethod
def fit(self) -> None:
pass
def get_default_transform(dataset: str) -> Any:
......@@ -45,7 +48,6 @@ def get_default_transform(dataset: str) -> Any:
return None
@register_trainer
class PyTorchImageClassificationTrainer(BaseTrainer):
"""
Image classification trainer for PyTorch.
......
......@@ -7,8 +7,8 @@ import torch.optim as optim
from torch.utils.data import DataLoader
import nni
from ...graph import TrainingConfig
from ...utils import blackbox_module
from ...graph import Evaluator
from ...serializer import serialize_cls
__all__ = ['LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classification', 'Regression']
......@@ -22,11 +22,11 @@ class LightningModule(pl.LightningModule):
self.model = model
Trainer = blackbox_module(pl.Trainer)
DataLoader = blackbox_module(DataLoader)
Trainer = serialize_cls(pl.Trainer)
DataLoader = serialize_cls(DataLoader)
class Lightning(TrainingConfig):
class Lightning(Evaluator):
"""
Delegate the whole training to PyTorch Lightning.
......@@ -162,7 +162,7 @@ class _SupervisedLearningModule(LightningModule):
return {name: self.trainer.callback_metrics['val_' + name].item() for name in self.metrics}
@blackbox_module
@serialize_cls
class _ClassificationModule(_SupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
......@@ -210,7 +210,7 @@ class Classification(Lightning):
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
@blackbox_module
@serialize_cls
class _RegressionModule(_SupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.MSELoss,
learning_rate: float = 0.001,
......
......@@ -6,25 +6,25 @@ from typing import Dict, List
from .interface import AbstractExecutionEngine, AbstractGraphListener
from .. import codegen, utils
from ..graph import Model, ModelStatus, MetricData, TrainingConfig
from ..graph import Model, ModelStatus, MetricData, Evaluator
from ..integration_api import send_trial, receive_trial_parameters, get_advisor
_logger = logging.getLogger(__name__)
class BaseGraphData:
def __init__(self, model_script: str, training_config: TrainingConfig) -> None:
def __init__(self, model_script: str, evaluator: Evaluator) -> None:
self.model_script = model_script
self.training_config = training_config
self.evaluator = evaluator
def dump(self) -> dict:
return {
'model_script': self.model_script,
'training_config': self.training_config
'evaluator': self.evaluator
}
@staticmethod
def load(data) -> 'BaseGraphData':
return BaseGraphData(data['model_script'], data['training_config'])
return BaseGraphData(data['model_script'], data['evaluator'])
class BaseExecutionEngine(AbstractExecutionEngine):
......@@ -55,7 +55,7 @@ class BaseExecutionEngine(AbstractExecutionEngine):
def submit_models(self, *models: Model) -> None:
for model in models:
data = BaseGraphData(codegen.model_to_pytorch_script(model), model.training_config)
data = BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator)
self._running_models[send_trial(data.dump())] = model
def register_graph_listener(self, listener: AbstractGraphListener) -> None:
......@@ -107,5 +107,5 @@ class BaseExecutionEngine(AbstractExecutionEngine):
with open(file_name, 'w') as f:
f.write(graph_data.model_script)
model_cls = utils.import_(f'_generated_model.{random_str}._model')
graph_data.training_config._execute(model_cls)
graph_data.evaluator._execute(model_cls)
os.remove(file_name)
......@@ -44,7 +44,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
phy_models_and_placements = self._assemble(logical)
for model, placement, grouped_models in phy_models_and_placements:
data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement),
model.training_config)
model.evaluator)
for m in grouped_models:
self._original_models[m.model_id] = m
self._original_model_to_multi_model[m.model_id] = model
......
......@@ -145,11 +145,11 @@ class LogicalPlan:
# Add a flag to mark multi-model in graph json.
# Multi-model has a list of training configs in kwargs['model_kwargs']
if len(multi_model_placement) > 1:
phy_model.training_config.kwargs['is_multi_model'] = True
phy_model.training_config.kwargs['model_cls'] = phy_graph.name
phy_model.training_config.kwargs['model_kwargs'] = []
phy_model.evaluator.kwargs['is_multi_model'] = True
phy_model.evaluator.kwargs['model_cls'] = phy_graph.name
phy_model.evaluator.kwargs['model_kwargs'] = []
# FIXME: allow user to specify
phy_model.training_config.module = 'nni.retiarii.trainer.pytorch.PyTorchMultiModelTrainer'
phy_model.evaluator.module = 'nni.retiarii.trainer.pytorch.PyTorchMultiModelTrainer'
# merge sub-graphs
for model in multi_model_placement:
......@@ -160,7 +160,7 @@ class LogicalPlan:
# When replace logical nodes, merge the training configs when
# input/output nodes are replaced.
training_config_slot = {} # Model ID -> Slot ID
evaluator_slot = {} # Model ID -> Slot ID
input_slot_mapping = {}
output_slot_mapping = {}
# Replace all logical nodes to executable physical nodes
......@@ -181,25 +181,25 @@ class LogicalPlan:
new_node, placement = node.assemble(multi_model_placement)
if isinstance(new_node.operation, _IOPseudoOperation):
model_id = new_node.graph.model.model_id
if model_id not in training_config_slot:
phy_model.training_config.kwargs['model_kwargs'].append(new_node.graph.model.training_config.kwargs.copy())
training_config_slot[model_id] = len(phy_model.training_config.kwargs['model_kwargs']) - 1
slot = training_config_slot[model_id]
phy_model.training_config.kwargs['model_kwargs'][slot]['model_id'] = model_id
phy_model.training_config.kwargs['model_kwargs'][slot]['use_input'] = False
phy_model.training_config.kwargs['model_kwargs'][slot]['use_output'] = False
if model_id not in evaluator_slot:
phy_model.evaluator.kwargs['model_kwargs'].append(new_node.graph.model.evaluator.kwargs.copy())
evaluator_slot[model_id] = len(phy_model.evaluator.kwargs['model_kwargs']) - 1
slot = evaluator_slot[model_id]
phy_model.evaluator.kwargs['model_kwargs'][slot]['model_id'] = model_id
phy_model.evaluator.kwargs['model_kwargs'][slot]['use_input'] = False
phy_model.evaluator.kwargs['model_kwargs'][slot]['use_output'] = False
else:
slot = training_config_slot[model_id]
slot = evaluator_slot[model_id]
# If a model's inputs/outputs are not used in the multi-model
# the codegen and trainer should not generate and use them
# "use_input" and "use_output" are used to mark whether
# an input/output of a model is used in a multi-model
if new_node.operation.type == '_inputs':
input_slot_mapping[new_node] = slot
phy_model.training_config.kwargs['model_kwargs'][slot]['use_input'] = True
phy_model.evaluator.kwargs['model_kwargs'][slot]['use_input'] = True
if new_node.operation.type == '_outputs':
output_slot_mapping[new_node] = slot
phy_model.training_config.kwargs['model_kwargs'][slot]['use_output'] = True
phy_model.evaluator.kwargs['model_kwargs'][slot]['use_output'] = True
self.node_replace(node, new_node)
......
......@@ -45,9 +45,9 @@ class DedupInputOptimizer(AbstractOptimizer):
node_to_check.operation.type == '_inputs' and \
isinstance(root_node, OriginNode) and \
isinstance(node_to_check, OriginNode):
if root_node.original_graph.model.training_config.module not in _supported_training_modules:
if root_node.original_graph.model.evaluator.module not in _supported_training_modules:
return False
if root_node.original_graph.model.training_config == node_to_check.original_graph.model.training_config:
if root_node.original_graph.model.evaluator == node_to_check.original_graph.model.evaluator:
return True
else:
return False
......
......@@ -13,13 +13,12 @@ from nni.experiment.config.base import ConfigBase, PathLike
from nni.experiment.pipe import Pipe
from ..converter import convert_to_graph
from ..graph import Model, TrainingConfig
from ..graph import Model, Evaluator
from ..integration import RetiariiAdvisor
from ..mutator import Mutator
from ..nn.pytorch.mutator import process_inline_mutation
from ..strategy import BaseStrategy
from ..trainer.interface import BaseOneShotTrainer, BaseTrainer
from ..utils import get_records
from ..oneshot.interface import BaseOneShotTrainer
_logger = logging.getLogger(__name__)
......@@ -77,7 +76,7 @@ _validation_rules = {
class RetiariiExperiment(Experiment):
def __init__(self, base_model: nn.Module, trainer: Union[TrainingConfig, BaseOneShotTrainer],
def __init__(self, base_model: nn.Module, trainer: Union[Evaluator, BaseOneShotTrainer],
applied_mutators: List[Mutator] = None, strategy: BaseStrategy = None):
# TODO: The current design of init interface of Retiarii experiment needs to be reviewed.
self.config: RetiariiExeConfig = None
......@@ -87,7 +86,6 @@ class RetiariiExperiment(Experiment):
self.trainer = trainer
self.applied_mutators = applied_mutators
self.strategy = strategy
self.recorded_module_args = get_records()
self._dispatcher = RetiariiAdvisor()
self._dispatcher_thread: Optional[Thread] = None
......@@ -101,7 +99,7 @@ class RetiariiExperiment(Experiment):
_logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
raise e
base_model_ir = convert_to_graph(script_module, self.base_model)
base_model_ir.training_config = self.trainer
base_model_ir.evaluator = self.trainer
# handle inline mutations
mutators = process_inline_mutation(base_model_ir)
......