Unverified commit cae4308f, authored by Yuge Zhang, committed by GitHub

[Retiarii] Rename APIs and refine documentation (#3404)

parent d047d6f4
@@ -42,31 +42,31 @@ Graph Mutation APIs

Trainers
--------

.. autoclass:: nni.retiarii.evaluator.FunctionalEvaluator
   :members:

.. autoclass:: nni.retiarii.evaluator.pytorch.lightning.LightningModule
   :members:

.. autoclass:: nni.retiarii.evaluator.pytorch.lightning.Classification
   :members:

.. autoclass:: nni.retiarii.evaluator.pytorch.lightning.Regression
   :members:

Oneshot Trainers
----------------

.. autoclass:: nni.retiarii.oneshot.pytorch.DartsTrainer
   :members:

.. autoclass:: nni.retiarii.oneshot.pytorch.EnasTrainer
   :members:

.. autoclass:: nni.retiarii.oneshot.pytorch.ProxylessTrainer
   :members:

.. autoclass:: nni.retiarii.oneshot.pytorch.SinglePathTrainer
   :members:

Strategies
......
@@ -24,7 +24,7 @@ Define Base Model

Defining a base model is almost the same as defining a PyTorch (or TensorFlow) model. There are only two small differences.

* Replace the code ``import torch.nn as nn`` with ``import nni.retiarii.nn.pytorch as nn`` for PyTorch modules, such as ``nn.Conv2d``, ``nn.ReLU``.
* Some **user-defined** modules should be decorated with ``@basic_unit``. For example, a user-defined module used in ``LayerChoice`` should be decorated. Users can refer to `here <#serialize-module>`__ for detailed usage instructions of ``@basic_unit``.

Below is a very simple example of defining a base model; it is almost the same as defining a PyTorch model.

@@ -59,7 +59,7 @@ A base model is only one concrete model, not a model space. We provide APIs and p

For easy usability and also backward compatibility, we provide some APIs for users to easily express possible mutations after defining a base model. The APIs can be used just like PyTorch modules.

* ``nn.LayerChoice``. It allows users to put several candidate operations (e.g., PyTorch modules); one of them is chosen in each explored model. *Note that if a candidate is a user-defined module, it should be decorated as a `serialize module <#serialize-module>`__. In the following example, ``ops.PoolBN`` and ``ops.SepConv`` should be decorated.*

.. code-block:: python
@@ -83,7 +83,7 @@ For easy usability and also backward compatibility, we provide some APIs for use

    # invoked in `forward` function, choose one from the three
    out = self.input_switch([tensor1, tensor2, tensor3])
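To make the semantics concrete, below is a minimal pure-Python sketch (a hypothetical stand-in, not the real ``nn.InputChoice`` implementation) of what "choose one from several candidate inputs" means once the exploration strategy has fixed a choice:

```python
# Hypothetical stand-in mimicking a sampled nn.InputChoice: once the
# exploration strategy decides on an index, forwarding reduces to
# selecting that single candidate input.
class FixedInputChoice:
    def __init__(self, chosen_index):
        self.chosen_index = chosen_index  # decided by the strategy

    def __call__(self, candidate_inputs):
        # pick exactly one of the candidate inputs
        return candidate_inputs[self.chosen_index]

input_switch = FixedInputChoice(chosen_index=1)
out = input_switch(['tensor1', 'tensor2', 'tensor3'])
# out == 'tensor2'
```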
* ``nn.ValueChoice``. It is for choosing one value from some candidate values. It can only be used as an input argument of the modules in ``nn.modules`` and of ``@basic_unit``-decorated user-defined modules.

.. code-block:: python
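For intuition, here is a hedged pure-Python sketch (a hypothetical stand-in, not NNI's actual implementation) of the ``ValueChoice`` idea: a set of candidate values that collapses to one concrete value when a model is sampled, which is then passed to a module's constructor:

```python
import random

# Hypothetical sketch of ValueChoice semantics: a choice over candidate
# values; sampling (done by the exploration strategy in reality) picks
# one concrete value.
class FixedValueChoice:
    def __init__(self, candidates):
        self.candidates = candidates

    def sample(self, rng):
        return rng.choice(self.candidates)

kernel_size = FixedValueChoice([1, 3, 5]).sample(random.Random(0))
# kernel_size would then be passed to e.g. a conv module's __init__
```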
@@ -129,38 +129,37 @@ Use placeholder to make mutation easier: ``nn.Placeholder``. If you want to mutat

.. code-block:: python

    ph = nn.Placeholder(
        label='mutable_0',
        kernel_size_options=[1, 3, 5],
        n_layer_options=[1, 2, 3, 4],
        exp_ratio=exp_ratio,
        stride=stride
    )
``label`` is used by the mutator to identify this placeholder. The other parameters are the information required by the mutator; they can be accessed from ``node.operation.parameters`` as a dict, which can include any information users want to pass to a user-defined mutator. The complete example code can be found in :githublink:`Mnasnet base model <test/retiarii_test/mnasnet/base_mnasnet.py>`.
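The access pattern above can be sketched with minimal stand-in classes (hypothetical, not the real graph IR types) to show how a user-defined mutator would read the placeholder's info:

```python
# Hypothetical minimal stand-ins for graph IR objects, illustrating how
# a mutator reads the info attached to nn.Placeholder via
# node.operation.parameters (a plain dict).
class Operation:
    def __init__(self, parameters):
        self.parameters = parameters

class Node:
    def __init__(self, operation):
        self.operation = operation

node = Node(Operation({
    'label': 'mutable_0',
    'kernel_size_options': [1, 3, 5],
    'n_layer_options': [1, 2, 3, 4],
}))

# a user-defined mutator would read the info like this:
kernel_options = node.operation.parameters['kernel_size_options']
```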
Explore the Defined Model Space
-------------------------------

After the model space is defined, it is time to explore it. Users can choose a proper search strategy and model evaluator to explore the model space.

Create an Evaluator and Exploration Strategy
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

**Classic search approach:**

In this approach, the model evaluator is for training and testing each explored model, while the strategy is for sampling the models. Both evaluator and strategy are required to explore the model space. We recommend PyTorch-Lightning to write the full evaluation process.

**Oneshot (weight-sharing) search approach:**

In this approach, users only need a oneshot trainer, because this trainer takes charge of search, training and testing.

In the following table, we list the available evaluators and strategies.
.. list-table::
   :header-rows: 1
   :widths: auto

   * - Evaluator
     - Strategy
     - Oneshot Trainer
   * - Classification
@@ -178,24 +177,24 @@ In the following table, we listed the available evaluators and strategies.

Their usage and API documentation can be found `here <./ApiReference>`__.

Here is a simple example of using an evaluator and a strategy.

.. code-block:: python
    import nni.retiarii.evaluator.pytorch.lightning as pl
    from nni.retiarii import serialize
    from torchvision import transforms

    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_dataset = serialize(MNIST, root='data/mnist', train=True, download=True, transform=transform)
    test_dataset = serialize(MNIST, root='data/mnist', train=False, download=True, transform=transform)
    lightning = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
                                  val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
                                  max_epochs=10)
.. Note:: For NNI to capture the dataset and dataloader and distribute them across different runs, please wrap your dataset with ``serialize`` and use ``pl.DataLoader`` instead of ``torch.utils.data.DataLoader``. See the ``basic_unit`` section below for details.
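Conceptually, ``serialize`` records the class and its init arguments so the object can be re-created identically in another process. The sketch below is a hypothetical stand-in for that idea (``serialize_sketch`` and ``reconstruct`` are illustrative names, not NNI's API):

```python
# Hypothetical sketch of what `serialize` accomplishes: record the class
# and its init arguments so the object can be re-instantiated in a
# different process (e.g., in each trial run).
def serialize_sketch(cls, *args, **kwargs):
    return {'class': cls, 'args': args, 'kwargs': kwargs}

def reconstruct(record):
    # re-create the object from the recorded class and arguments
    return record['class'](*record['args'], **record['kwargs'])

record = serialize_sketch(dict, a=1, b=2)
restored = reconstruct(record)
```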
Users can refer to the `API reference <./ApiReference.rst>`__ for detailed usage of evaluators, to "`write a trainer <./WriteTrainer.rst>`__" for how to write a new trainer, and to `this document <./WriteStrategy.rst>`__ for how to write a new strategy.
Set up an Experiment
^^^^^^^^^^^^^^^^^^^^

@@ -231,17 +230,17 @@ If you are using *oneshot (weight-sharing) search approach*, you can invoke ``ex

Advanced and FAQ
----------------

.. _serialize-module:

**Serialize Module**
To understand the decorator ``basic_unit``, we first briefly explain how our framework works: it converts a user-defined model to a graph representation (called graph IR), and each instantiated module is converted to a subgraph. Then user-defined mutations are applied to the graph to generate new graphs. Each new graph is then converted back to PyTorch code and executed. ``@basic_unit`` here means that the module will not be converted to a subgraph; instead, it is converted to a single graph node. That is, the module will not be unfolded anymore. Users should/can decorate a user-defined module class in the following cases:
* When a module class cannot be successfully converted to a subgraph due to some implementation issues. For example, currently our framework does not support adhoc loops; if there is an adhoc loop in a module's forward, this class should be decorated as a serializable module. The following ``MyModule`` should be decorated.

.. code-block:: python

    @basic_unit
    class MyModule(nn.Module):
        def __init__(self):
            ...
@@ -249,6 +248,6 @@ To understand the decorator ``basic_unit``, we first briefly explain how ou

            for i in range(10):  # <- adhoc loop
                ...
* The candidate ops in ``LayerChoice`` should be decorated as serializable modules. For example, in ``self.op = nn.LayerChoice([Op1(...), Op2(...), Op3(...)])``, ``Op1``, ``Op2``, ``Op3`` should be decorated if they are user-defined modules.
* When users want to use ``ValueChoice`` in a module's input argument, the module should be decorated as a serializable module. For example, in ``self.conv = MyConv(kernel_size=nn.ValueChoice([1, 3, 5]))``, ``MyConv`` should be decorated.
* If no mutation is targeted on a module, this module *can be* decorated as a serializable module.
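What "converted to a single graph node" means can be sketched in plain Python. The decorator and converter below are hypothetical stand-ins (not NNI's implementation): a marked class is recorded whole, as a type name plus its init parameters, instead of being unfolded:

```python
# Hypothetical sketch of basic_unit semantics: a decorator marks the
# class, and a toy converter records marked modules as one opaque node
# instead of recursing into their internals.
def basic_unit_sketch(cls):
    cls._nni_basic_unit = True
    return cls

@basic_unit_sketch
class MyConv:
    def __init__(self, kernel_size):
        self.kernel_size = kernel_size

def convert(module):
    if getattr(type(module), '_nni_basic_unit', False):
        # opaque node: keep only the type and its init parameters
        return {'type': type(module).__name__,
                'parameters': dict(module.__dict__)}
    raise NotImplementedError('non-basic units would be unfolded here')

node = convert(MyConv(kernel_size=3))
```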
Customize A New Evaluator/Trainer
=================================

Evaluators/Trainers are necessary to evaluate the performance of newly explored models. In the NAS scenario, this further divides into two use cases:

1. **Single-arch evaluators**: evaluators that are used to train and evaluate one single model.
2. **One-shot trainers**: trainers that handle training and searching simultaneously, from an end-to-end perspective.

Single-arch evaluators
----------------------
With FunctionalEvaluator
^^^^^^^^^^^^^^^^^^^^^^^^
The simplest way to customize a new evaluator is with functional APIs, which is very easy when training code is already available. Users only need to write a fit function that wraps everything. This function takes one positional argument (the model) and possible keyword arguments. In this way, users get everything under their control, but expose less information to the framework and thus fewer opportunities for possible optimization. An example is as below:
.. code-block:: python

    from nni.retiarii.evaluator import FunctionalEvaluator
    from nni.retiarii.experiment.pytorch import RetiariiExperiment

    def fit(model, dataloader):
        train(model, dataloader)
        acc = test(model, dataloader)
        nni.report_final_result(acc)

    evaluator = FunctionalEvaluator(fit, dataloader=DataLoader(foo, bar))
    experiment = RetiariiExperiment(base_model, evaluator, mutators, strategy)
With PyTorch-Lightning
^^^^^^^^^^^^^^^^^^^^^^

It's recommended to write training code in PyTorch-Lightning style, that is, to write a LightningModule that defines all elements needed for training (e.g., loss function, optimizer) and to define a trainer that takes (optional) dataloaders to execute the training. Before that, please read the `document of PyTorch-lightning <https://pytorch-lightning.readthedocs.io/>`__ to learn the basic concepts and components provided by PyTorch-lightning.

In practice, writing a new training module in NNI should inherit ``nni.retiarii.evaluator.pytorch.lightning.LightningModule``, which has a ``set_model`` that will be called after ``__init__`` to save the candidate model (generated by the strategy) as ``self.model``. The rest of the process (like ``training_step``) should be the same as writing any other lightning module. Evaluators should also communicate with strategies via two API calls (``nni.report_intermediate_result`` for periodical metrics and ``nni.report_final_result`` for final metrics), added in ``on_validation_epoch_end`` and ``teardown`` respectively.

An example is as follows:

.. code-block:: python

    from nni.retiarii.evaluator.pytorch.lightning import LightningModule  # please import this one

    @basic_unit
    class AutoEncoder(LightningModule):
        def __init__(self):
            super().__init__()
@@ -69,9 +87,9 @@ An example is as follows:

Then, users need to wrap everything (including LightningModule, trainer and dataloaders) into a ``Lightning`` object, and pass this object into a Retiarii experiment.

.. code-block:: python

    import nni.retiarii.evaluator.pytorch.lightning as pl
    from nni.retiarii.experiment.pytorch import RetiariiExperiment

    lightning = pl.Lightning(AutoEncoder(),
@@ -80,38 +98,20 @@ Then, users need to wrap everything (including LightningModule, trainer and data
                             val_dataloaders=pl.DataLoader(test_dataset, batch_size=100))
    experiment = RetiariiExperiment(base_model, lightning, mutators, strategy)
One-shot trainers
-----------------

One-shot trainers should inherit ``nni.retiarii.oneshot.BaseOneShotTrainer``, and need to implement ``fit()`` (used to conduct the fitting and searching process) and the ``export()`` method (used to return the searched best architecture).

Writing a one-shot trainer is very different from writing classic evaluators. First of all, there are no more restrictions on init method arguments; any Python arguments are acceptable. Secondly, the model fed into one-shot trainers might be a model with Retiarii-specific modules, such as LayerChoice and InputChoice. Such a model cannot directly forward-propagate, and trainers need to decide how to handle those modules.

A typical example is DartsTrainer, where learnable parameters are used to combine multiple choices in LayerChoice. Retiarii provides easy-to-use utility functions for module-replacement purposes, namely ``replace_layer_choice`` and ``replace_input_choice``. A simplified example is as follows:
.. code-block:: python

    from nni.retiarii.oneshot import BaseOneShotTrainer
    from nni.retiarii.oneshot.pytorch import replace_layer_choice, replace_input_choice

    class DartsLayerChoice(nn.Module):
......
@@ -55,7 +55,7 @@ if __name__ == "__main__":
        trainer.train()
    else:
        from nni.retiarii.oneshot.pytorch import DartsTrainer
        trainer = DartsTrainer(
            model=model,
            loss=criterion,
......
@@ -48,7 +48,7 @@ class Cell(nn.Module):
        ], key=cell_name + "_op")

    def forward(self, prev_layers):
        from nni.retiarii.oneshot.pytorch.random import PathSamplingInputChoice
        out = self.input_choice(prev_layers)
        if isinstance(self.input_choice, PathSamplingInputChoice):
            # Retiarii pattern
......
@@ -66,7 +66,7 @@ if __name__ == "__main__":
        trainer.enable_visualization()
        trainer.train()
    else:
        from nni.retiarii.oneshot.pytorch.enas import EnasTrainer
        trainer = EnasTrainer(model,
                              loss=criterion,
                              metrics=accuracy,
......
@@ -84,7 +84,7 @@ if __name__ == "__main__":
    optimizer = torch.optim.SGD(get_parameters(model), lr=0.05, momentum=momentum, nesterov=nesterov, weight_decay=4e-5)
    if args.train_mode == 'search':
        from nni.retiarii.oneshot.pytorch import ProxylessTrainer
        from torchvision.datasets import ImageNet
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
......
@@ -2,4 +2,4 @@ from .operation import Operation
from .graph import *
from .execution import *
from .mutator import *
from .serializer import basic_unit, json_dump, json_dumps, json_load, json_loads, serialize, serialize_cls
import re

import torch

@@ -6,18 +5,16 @@ import torch

from ..graph import Graph, Model, Node
from ..nn.pytorch import InputChoice, LayerChoice, Placeholder
from ..operation import Cell, Operation
from ..serializer import get_init_parameters_or_fail
from ..utils import get_full_class_name

from .op_types import MODULE_EXCEPT_LIST, OpTypeName
from .utils import _convert_name, build_full_name


class GraphConverter:
    def __init__(self):
        self.global_seq = 0
        self.global_graph_id = 0

    def _add_edge_handle_source_node(self, _input, graph_inputs, ir_graph, output_remap, node_index):
        if _input in graph_inputs:
@@ -247,7 +244,7 @@ class GraphConverter:
            raise RuntimeError('Have not supported `if A and/or B`, please use two `if` statements instead.')
        else:
            raise RuntimeError(f'Unsupported op type {tensor.node().kind()} in if condition, '
                               'you are suggested to decorate the corresponding class with "@basic_unit".')
        expr = _generate_expr(cond_tensor)
        return eval(expr)
@@ -539,13 +536,8 @@ class GraphConverter:

    def _handle_layerchoice(self, module):
        choices = []
        for cand in list(module):
            cand_type = '__torch__.' + get_full_class_name(cand.__class__)
            choices.append({'type': cand_type, 'parameters': get_init_parameters_or_fail(cand)})
        return {
            'candidates': choices,
            'label': module.label
@@ -601,14 +593,13 @@ class GraphConverter:
        elif original_type_name == OpTypeName.ValueChoice:
            m_attrs = self._handle_valuechoice(module)
        elif original_type_name == OpTypeName.Placeholder:
            m_attrs = get_init_parameters_or_fail(module)
        elif module.__class__.__module__.startswith('torch.nn') and original_type_name in torch.nn.__dict__:
            # this is a basic module from pytorch, no need to parse its graph
            m_attrs = get_init_parameters_or_fail(module)
        else:
            # this module is marked as serialize, won't continue to parse
            m_attrs = get_init_parameters_or_fail(module, silently=True)
        if m_attrs is not None:
            return None, m_attrs
......
@@ -3,7 +3,7 @@ import graphviz
def convert_to_visualize(graph_ir, vgraph):
    for name, graph in graph_ir.items():
        if name == '_evaluator':
            continue
        with vgraph.subgraph(name='cluster'+name) as subgraph:
            subgraph.attr(color='blue')
......
from .functional import FunctionalEvaluator
from ..graph import Evaluator


class FunctionalEvaluator(Evaluator):
    """
    Functional evaluator that directly takes a function and thus should be general.

    Attributes
    ----------
@@ -19,7 +19,7 @@ class FunctionalEvaluator(Evaluator):
    @staticmethod
    def _load(ir):
        return FunctionalEvaluator(ir['function'], **ir['arguments'])

    def _dump(self):
        return {
......
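The `_dump`/`_load` pair above serializes the evaluator to an IR dict and back. A self-contained sketch of that round trip, with a stand-in class (NNI's `Evaluator` base and its exact IR format may differ):

```python
class FunctionalEvaluatorSketch:
    # Stand-in for nni.retiarii.evaluator.FunctionalEvaluator.
    def __init__(self, function, **kwargs):
        self.function = function
        self.arguments = kwargs

    def _dump(self):
        # Serialize to the IR dict that _load consumes.
        return {'function': self.function, 'arguments': self.arguments}

    @staticmethod
    def _load(ir):
        return FunctionalEvaluatorSketch(ir['function'], **ir['arguments'])

    def _execute(self, model_cls):
        return self.function(model_cls, **self.arguments)

def evaluate(model_cls, lr):
    return (model_cls, lr)

ev = FunctionalEvaluatorSketch(evaluate, lr=0.1)
restored = FunctionalEvaluatorSketch._load(ev._dump())
print(restored._execute('Net'))  # ('Net', 0.1)
```

The round trip matters because the evaluator is reconstructed inside the trial process before `_execute` runs on the generated model class.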
-from .base import PyTorchImageClassificationTrainer, PyTorchMultiModelTrainer
-from .lightning import *
...
 # This file is deprecated.
+import abc
 from typing import Any, List, Dict, Tuple

 import numpy as np
@@ -10,8 +11,10 @@ from torchvision import datasets, transforms
 import nni

-from ..interface import BaseTrainer
-from ...utils import register_trainer
+class BaseTrainer(abc.ABC):
+    @abc.abstractmethod
+    def fit(self) -> None:
+        pass

 def get_default_transform(dataset: str) -> Any:
@@ -45,7 +48,6 @@ def get_default_transform(dataset: str) -> Any:
     return None

-@register_trainer
 class PyTorchImageClassificationTrainer(BaseTrainer):
     """
     Image classification trainer for PyTorch.
...
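The deprecated file now inlines `BaseTrainer` instead of importing it from the removed interface module. A sketch of the contract that abstract base enforces (the subclass is hypothetical):

```python
import abc

class BaseTrainer(abc.ABC):
    # The inlined abstract base from the diff above.
    @abc.abstractmethod
    def fit(self) -> None:
        pass

class DummyTrainer(BaseTrainer):
    # Hypothetical subclass: implementing fit() makes it instantiable.
    def fit(self) -> None:
        self.fitted = True

t = DummyTrainer()
t.fit()
print(t.fitted)  # True

try:
    BaseTrainer()  # abstract method unimplemented: cannot instantiate
except TypeError as e:
    print('refused:', type(e).__name__)
```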
@@ -7,8 +7,8 @@ import torch.optim as optim
 from torch.utils.data import DataLoader

 import nni
-from ...graph import TrainingConfig
-from ...utils import blackbox_module
+from ...graph import Evaluator
+from ...serializer import serialize_cls

 __all__ = ['LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classification', 'Regression']
@@ -22,11 +22,11 @@ class LightningModule(pl.LightningModule):
         self.model = model

-Trainer = blackbox_module(pl.Trainer)
-DataLoader = blackbox_module(DataLoader)
+Trainer = serialize_cls(pl.Trainer)
+DataLoader = serialize_cls(DataLoader)

-class Lightning(TrainingConfig):
+class Lightning(Evaluator):
     """
     Delegate the whole training to PyTorch Lightning.
@@ -162,7 +162,7 @@ class _SupervisedLearningModule(LightningModule):
         return {name: self.trainer.callback_metrics['val_' + name].item() for name in self.metrics}

-@blackbox_module
+@serialize_cls
 class _ClassificationModule(_SupervisedLearningModule):
     def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
                  learning_rate: float = 0.001,
@@ -210,7 +210,7 @@ class Classification(Lightning):
                          train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)

-@blackbox_module
+@serialize_cls
 class _RegressionModule(_SupervisedLearningModule):
     def __init__(self, criterion: nn.Module = nn.MSELoss,
                  learning_rate: float = 0.001,
...
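`serialize_cls` (formerly `blackbox_module`) wraps a class so instances can be re-created in another process. Its real implementation is not in this diff; a toy version showing the idea, under the assumption that it only needs to capture constructor arguments:

```python
def serialize_cls_sketch(cls):
    # Wrap cls so every instance remembers its constructor arguments,
    # letting the framework rebuild the object (e.g. a DataLoader) elsewhere.
    class Wrapped(cls):
        def __init__(self, *args, **kwargs):
            self._init_args = args
            self._init_kwargs = kwargs
            super().__init__(*args, **kwargs)
    Wrapped.__name__ = cls.__name__
    return Wrapped

class Loader:
    # Stand-in for torch.utils.data.DataLoader.
    def __init__(self, batch_size=1):
        self.batch_size = batch_size

SerializableLoader = serialize_cls_sketch(Loader)
loader = SerializableLoader(batch_size=32)
print(loader._init_kwargs)  # {'batch_size': 32}

# Re-creation from the recorded arguments, as a trial process would do:
recreated = SerializableLoader(*loader._init_args, **loader._init_kwargs)
print(recreated.batch_size)  # 32
```

This is why `Trainer = serialize_cls(pl.Trainer)` is exported in place of the raw Lightning class: users construct it normally, and the arguments travel with the experiment.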
@@ -6,25 +6,25 @@ from typing import Dict, List
 from .interface import AbstractExecutionEngine, AbstractGraphListener
 from .. import codegen, utils
-from ..graph import Model, ModelStatus, MetricData, TrainingConfig
+from ..graph import Model, ModelStatus, MetricData, Evaluator
 from ..integration_api import send_trial, receive_trial_parameters, get_advisor

 _logger = logging.getLogger(__name__)

 class BaseGraphData:
-    def __init__(self, model_script: str, training_config: TrainingConfig) -> None:
+    def __init__(self, model_script: str, evaluator: Evaluator) -> None:
         self.model_script = model_script
-        self.training_config = training_config
+        self.evaluator = evaluator

     def dump(self) -> dict:
         return {
             'model_script': self.model_script,
-            'training_config': self.training_config
+            'evaluator': self.evaluator
         }

     @staticmethod
     def load(data) -> 'BaseGraphData':
-        return BaseGraphData(data['model_script'], data['training_config'])
+        return BaseGraphData(data['model_script'], data['evaluator'])

 class BaseExecutionEngine(AbstractExecutionEngine):
@@ -55,7 +55,7 @@ class BaseExecutionEngine(AbstractExecutionEngine):
     def submit_models(self, *models: Model) -> None:
         for model in models:
-            data = BaseGraphData(codegen.model_to_pytorch_script(model), model.training_config)
+            data = BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator)
             self._running_models[send_trial(data.dump())] = model

     def register_graph_listener(self, listener: AbstractGraphListener) -> None:
@@ -107,5 +107,5 @@ class BaseExecutionEngine(AbstractExecutionEngine):
         with open(file_name, 'w') as f:
             f.write(graph_data.model_script)
         model_cls = utils.import_(f'_generated_model.{random_str}._model')
-        graph_data.training_config._execute(model_cls)
+        graph_data.evaluator._execute(model_cls)
         os.remove(file_name)
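The `dump()`/`load()` pair above is what crosses the process boundary: the engine dumps a trial payload, and the trial loads it back before executing the evaluator. A self-contained round trip, with a plain string standing in for the serialized evaluator:

```python
class BaseGraphDataSketch:
    # Stand-in for BaseGraphData; the real evaluator field is an Evaluator object.
    def __init__(self, model_script, evaluator):
        self.model_script = model_script
        self.evaluator = evaluator

    def dump(self):
        return {'model_script': self.model_script, 'evaluator': self.evaluator}

    @staticmethod
    def load(data):
        return BaseGraphDataSketch(data['model_script'], data['evaluator'])

payload = BaseGraphDataSketch('class _model: ...', 'evaluator-ir').dump()
restored = BaseGraphDataSketch.load(payload)
print(restored.model_script)  # class _model: ...
print(restored.evaluator)     # evaluator-ir
```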
@@ -44,7 +44,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
         phy_models_and_placements = self._assemble(logical)
         for model, placement, grouped_models in phy_models_and_placements:
             data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement),
-                                 model.training_config)
+                                 model.evaluator)
             for m in grouped_models:
                 self._original_models[m.model_id] = m
                 self._original_model_to_multi_model[m.model_id] = model
...
@@ -145,11 +145,11 @@ class LogicalPlan:
         # Add a flag to mark multi-model in graph json.
         # Multi-model has a list of training configs in kwargs['model_kwargs']
         if len(multi_model_placement) > 1:
-            phy_model.training_config.kwargs['is_multi_model'] = True
-            phy_model.training_config.kwargs['model_cls'] = phy_graph.name
-            phy_model.training_config.kwargs['model_kwargs'] = []
+            phy_model.evaluator.kwargs['is_multi_model'] = True
+            phy_model.evaluator.kwargs['model_cls'] = phy_graph.name
+            phy_model.evaluator.kwargs['model_kwargs'] = []
             # FIXME: allow user to specify
-            phy_model.training_config.module = 'nni.retiarii.trainer.pytorch.PyTorchMultiModelTrainer'
+            phy_model.evaluator.module = 'nni.retiarii.trainer.pytorch.PyTorchMultiModelTrainer'

         # merge sub-graphs
         for model in multi_model_placement:
@@ -160,7 +160,7 @@ class LogicalPlan:
         # When replace logical nodes, merge the training configs when
         # input/output nodes are replaced.
-        training_config_slot = {}  # Model ID -> Slot ID
+        evaluator_slot = {}  # Model ID -> Slot ID
         input_slot_mapping = {}
         output_slot_mapping = {}
         # Replace all logical nodes to executable physical nodes
@@ -181,25 +181,25 @@ class LogicalPlan:
             new_node, placement = node.assemble(multi_model_placement)
             if isinstance(new_node.operation, _IOPseudoOperation):
                 model_id = new_node.graph.model.model_id
-                if model_id not in training_config_slot:
-                    phy_model.training_config.kwargs['model_kwargs'].append(new_node.graph.model.training_config.kwargs.copy())
-                    training_config_slot[model_id] = len(phy_model.training_config.kwargs['model_kwargs']) - 1
-                    slot = training_config_slot[model_id]
-                    phy_model.training_config.kwargs['model_kwargs'][slot]['model_id'] = model_id
-                    phy_model.training_config.kwargs['model_kwargs'][slot]['use_input'] = False
-                    phy_model.training_config.kwargs['model_kwargs'][slot]['use_output'] = False
+                if model_id not in evaluator_slot:
+                    phy_model.evaluator.kwargs['model_kwargs'].append(new_node.graph.model.evaluator.kwargs.copy())
+                    evaluator_slot[model_id] = len(phy_model.evaluator.kwargs['model_kwargs']) - 1
+                    slot = evaluator_slot[model_id]
+                    phy_model.evaluator.kwargs['model_kwargs'][slot]['model_id'] = model_id
+                    phy_model.evaluator.kwargs['model_kwargs'][slot]['use_input'] = False
+                    phy_model.evaluator.kwargs['model_kwargs'][slot]['use_output'] = False
                 else:
-                    slot = training_config_slot[model_id]
+                    slot = evaluator_slot[model_id]
                 # If a model's inputs/outputs are not used in the multi-model
                 # the codegen and trainer should not generate and use them
                 # "use_input" and "use_output" are used to mark whether
                 # an input/output of a model is used in a multi-model
                 if new_node.operation.type == '_inputs':
                     input_slot_mapping[new_node] = slot
-                    phy_model.training_config.kwargs['model_kwargs'][slot]['use_input'] = True
+                    phy_model.evaluator.kwargs['model_kwargs'][slot]['use_input'] = True
                 if new_node.operation.type == '_outputs':
                     output_slot_mapping[new_node] = slot
-                    phy_model.training_config.kwargs['model_kwargs'][slot]['use_output'] = True
+                    phy_model.evaluator.kwargs['model_kwargs'][slot]['use_output'] = True

             self.node_replace(node, new_node)
...
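The slot bookkeeping above, reduced to plain dicts: each original model gets one slot in the merged model's `kwargs['model_kwargs']`, allocated on first sight and reused afterwards, with `use_input`/`use_output` flipped when the corresponding I/O node is actually wired in. (Toy data; the real code walks graph nodes, not ids.)

```python
model_kwargs = []    # stands in for phy_model.evaluator.kwargs['model_kwargs']
evaluator_slot = {}  # model id -> slot index

def slot_for(model_id, kwargs):
    # Allocate a slot on first sight, reuse it afterwards.
    if model_id not in evaluator_slot:
        model_kwargs.append(dict(kwargs, model_id=model_id,
                                 use_input=False, use_output=False))
        evaluator_slot[model_id] = len(model_kwargs) - 1
    return evaluator_slot[model_id]

s0 = slot_for(0, {'lr': 0.1})
model_kwargs[s0]['use_input'] = True   # model 0's inputs are used
s0_again = slot_for(0, {'lr': 0.1})    # second I/O node of model 0: same slot
s1 = slot_for(1, {'lr': 0.2})
print(s0, s0_again, s1)  # 0 0 1
```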
@@ -45,9 +45,9 @@ class DedupInputOptimizer(AbstractOptimizer):
                 node_to_check.operation.type == '_inputs' and \
                 isinstance(root_node, OriginNode) and \
                 isinstance(node_to_check, OriginNode):
-            if root_node.original_graph.model.training_config.module not in _supported_training_modules:
+            if root_node.original_graph.model.evaluator.module not in _supported_training_modules:
                 return False
-            if root_node.original_graph.model.training_config == node_to_check.original_graph.model.training_config:
+            if root_node.original_graph.model.evaluator == node_to_check.original_graph.model.evaluator:
                 return True
             else:
                 return False
...
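The dedup test above boils down to: two input nodes may be merged only when their models' evaluators compare equal. A minimal stand-in showing why the evaluator class needs a value-based `__eq__` for this to work (the real Evaluator's equality semantics are an assumption here):

```python
class EvaluatorStub:
    # Value equality on module path + kwargs, as the dedup check relies on.
    def __init__(self, module, kwargs):
        self.module, self.kwargs = module, kwargs

    def __eq__(self, other):
        return self.module == other.module and self.kwargs == other.kwargs

a = EvaluatorStub('trainer.m', {'lr': 0.1})
b = EvaluatorStub('trainer.m', {'lr': 0.1})
c = EvaluatorStub('trainer.m', {'lr': 0.2})
print(a == b, a == c)  # True False
```

With default identity-based equality, `a == b` would be False and no inputs would ever be deduplicated.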
@@ -13,13 +13,12 @@ from nni.experiment.config.base import ConfigBase, PathLike
 from nni.experiment.pipe import Pipe

 from ..converter import convert_to_graph
-from ..graph import Model, TrainingConfig
+from ..graph import Model, Evaluator
 from ..integration import RetiariiAdvisor
 from ..mutator import Mutator
 from ..nn.pytorch.mutator import process_inline_mutation
 from ..strategy import BaseStrategy
-from ..trainer.interface import BaseOneShotTrainer, BaseTrainer
-from ..utils import get_records
+from ..oneshot.interface import BaseOneShotTrainer

 _logger = logging.getLogger(__name__)
@@ -77,7 +76,7 @@ _validation_rules = {
 class RetiariiExperiment(Experiment):
-    def __init__(self, base_model: nn.Module, trainer: Union[TrainingConfig, BaseOneShotTrainer],
+    def __init__(self, base_model: nn.Module, trainer: Union[Evaluator, BaseOneShotTrainer],
                  applied_mutators: List[Mutator] = None, strategy: BaseStrategy = None):
         # TODO: The current design of init interface of Retiarii experiment needs to be reviewed.
         self.config: RetiariiExeConfig = None
@@ -87,7 +86,6 @@ class RetiariiExperiment(Experiment):
         self.trainer = trainer
         self.applied_mutators = applied_mutators
         self.strategy = strategy
-        self.recorded_module_args = get_records()

         self._dispatcher = RetiariiAdvisor()
         self._dispatcher_thread: Optional[Thread] = None
@@ -101,7 +99,7 @@ class RetiariiExperiment(Experiment):
             _logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
             raise e
         base_model_ir = convert_to_graph(script_module, self.base_model)
-        base_model_ir.training_config = self.trainer
+        base_model_ir.evaluator = self.trainer
         # handle inline mutations
         mutators = process_inline_mutation(base_model_ir)
...
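Taken together, the hunks in this commit are a consistent rename pass. A summary of the old-path to new-path mapping, collected from the changes shown above (paths outside this diff may differ):

```python
# Old API path -> new API path, as renamed in this commit.
RENAMES = {
    'nni.retiarii.graph.TrainingConfig': 'nni.retiarii.graph.Evaluator',
    'nni.retiarii.trainer.FunctionalTrainer': 'nni.retiarii.evaluator.FunctionalEvaluator',
    'nni.retiarii.utils.blackbox_module': 'nni.retiarii.serializer.serialize_cls',
    'nni.retiarii.trainer.pytorch.DartsTrainer': 'nni.retiarii.oneshot.pytorch.DartsTrainer',
}
for old, new in RENAMES.items():
    print(f'{old} -> {new}')
```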