Customize A New Model Evaluator
===============================

A model evaluator is needed to evaluate the performance of newly explored models. It usually covers training, validation, and testing of a single model. We provide two ways for users to write a new model evaluator, each demonstrated below.

With FunctionalEvaluator
------------------------

The simplest way to customize a new evaluator is with the functional API, which is convenient when training code is already available. Users only need to write a ``fit`` function that wraps everything. This function takes one positional argument (``model_cls``) and optional keyword arguments. The keyword arguments (everything other than ``model_cls``) are passed to ``FunctionalEvaluator`` as its initialization parameters. In this way, users keep everything under their control but expose less information to the framework, leaving fewer opportunities for optimization. An example is as follows:

.. code-block:: python

    import nni
    from nni.retiarii.evaluator import FunctionalEvaluator
    from nni.retiarii.experiment.pytorch import RetiariiExperiment

    # train, test, DataLoader(foo, bar), base_model, mutators and strategy
    # are placeholders for your own code.
    def fit(model_cls, dataloader):
        model = model_cls()  # instantiate the candidate model
        train(model, dataloader)
        acc = test(model, dataloader)
        nni.report_final_result(acc)  # send the final metric back to the strategy

    evaluator = FunctionalEvaluator(fit, dataloader=DataLoader(foo, bar))
    experiment = RetiariiExperiment(base_model, evaluator, mutators, strategy)
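
The way keyword arguments reach ``fit`` can be pictured with a small stand-in class. This is only a sketch of the calling convention described above, not NNI's implementation: ``SketchFunctionalEvaluator``, ``evaluate``, ``TinyModel`` and the list used as ``dataloader`` are hypothetical names. The metric is returned rather than reported through ``nni.report_final_result`` so that the snippet runs without NNI installed.

.. code-block:: python

    class SketchFunctionalEvaluator:
        """Hypothetical stand-in showing how keyword arguments are stored."""

        def __init__(self, function, **kwargs):
            self.function = function
            self.arguments = kwargs  # kept until the function is invoked

        def evaluate(self, model_cls):
            # model_cls is the only positional argument; everything else
            # is forwarded as keyword arguments.
            return self.function(model_cls, **self.arguments)


    class TinyModel:
        """Hypothetical candidate model class standing in for a searched model."""

        def predict(self, x):
            return x * 2


    def fit(model_cls, dataloader):
        model = model_cls()  # the candidate is instantiated inside fit
        correct = sum(1 for x, y in dataloader if model.predict(x) == y)
        return correct / len(dataloader)


    evaluator = SketchFunctionalEvaluator(fit, dataloader=[(1, 2), (2, 4), (3, 7)])
    accuracy = evaluator.evaluate(TinyModel)

Because ``fit`` receives the model class rather than an instance, it decides when and how the candidate is constructed.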

.. note:: Due to a limitation of the current implementation, the ``fit`` function should be placed in a separate Python file rather than in the main file. This limitation will be fixed in a future release.

.. note:: When using customized evaluators, if you want to visualize models, you need to export your model and save it into ``$NNI_OUTPUT_DIR/model.onnx`` in your evaluator.
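
A minimal sketch of building that output path and exporting to it is shown below. The helper names ``onnx_output_path`` and ``export_for_visualization`` are hypothetical, and falling back to the current directory when ``NNI_OUTPUT_DIR`` is unset is an assumption made for local debugging. ``dummy_input`` must be a tensor (or tuple of tensors) matching the model's expected input.

.. code-block:: python

    import os
    from pathlib import Path


    def onnx_output_path(default_dir="."):
        """Return <NNI_OUTPUT_DIR>/model.onnx.

        The fallback to ``default_dir`` is an assumption for running
        outside an NNI trial, not documented NNI behavior.
        """
        return Path(os.environ.get("NNI_OUTPUT_DIR", default_dir)) / "model.onnx"


    def export_for_visualization(model, dummy_input):
        # Imported lazily so the path helper works without PyTorch installed.
        import torch

        path = onnx_output_path()
        path.parent.mkdir(parents=True, exist_ok=True)
        torch.onnx.export(model, dummy_input, str(path))
        return path

Such a helper could be called at the end of ``fit``, after training has finished.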

With PyTorch-Lightning
----------------------

It is recommended to write training code in PyTorch-Lightning style, that is, to write a LightningModule that defines all elements needed for training (e.g., loss function, optimizer) and to define a trainer that takes (optional) dataloaders to execute the training. Before that, please read the `PyTorch-Lightning documentation <https://pytorch-lightning.readthedocs.io/>`__ to learn the basic concepts and components it provides.

In practice, a new training module in Retiarii should inherit from ``nni.retiarii.evaluator.pytorch.lightning.LightningModule``, which has a ``set_model`` method that is called after ``__init__`` to save the candidate model (generated by the strategy) as ``self.model``. The rest of the process (such as ``training_step``) is the same as writing any other Lightning module. Evaluators should also communicate with strategies via two API calls: ``nni.report_intermediate_result`` for periodic metrics and ``nni.report_final_result`` for final metrics, added in ``on_validation_epoch_end`` and ``teardown`` respectively.

An example is as follows:

.. code-block:: python

    import nni
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from nni.retiarii import basic_unit  # assumed import location; adjust for your NNI version
    from nni.retiarii.evaluator.pytorch.lightning import LightningModule  # please import this one

    @basic_unit
    class AutoEncoder(LightningModule):
        def __init__(self):
            super().__init__()
            self.decoder = nn.Sequential(
                nn.Linear(3, 64),
                nn.ReLU(),
                nn.Linear(64, 28*28)
            )

        def forward(self, x):
            embedding = self.model(x)  # let's search for encoder
            return embedding

        def training_step(self, batch, batch_idx):
            # training_step defines the training loop.
            # It is independent of forward.
            x, y = batch
            x = x.view(x.size(0), -1)
            z = self.model(x)  # model is the one that is searched for
            x_hat = self.decoder(z)
            loss = F.mse_loss(x_hat, x)
            # Logging to TensorBoard by default
            self.log('train_loss', loss)
            return loss

        def validation_step(self, batch, batch_idx):
            x, y = batch
            x = x.view(x.size(0), -1)
            z = self.model(x)
            x_hat = self.decoder(z)
            loss = F.mse_loss(x_hat, x)
            self.log('val_loss', loss)

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            return optimizer

        def on_validation_epoch_end(self):
            nni.report_intermediate_result(self.trainer.callback_metrics['val_loss'].item())

        def teardown(self, stage):
            if stage == 'fit':
                nni.report_final_result(self.trainer.callback_metrics['val_loss'].item())
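
The order in which these hooks fire can be sketched with a plain-Python stand-in. Everything here is hypothetical and only mirrors the sequence described above (``set_model`` after ``__init__``, one intermediate report per validation epoch, one final report at teardown). ``SketchModule`` and ``run_trial`` are not NNI or PyTorch-Lightning classes, and the hook signatures are simplified to take the loss directly.

.. code-block:: python

    class SketchModule:
        """Hypothetical stand-in for an evaluator module."""

        def __init__(self):
            self.model = None
            self.reported = []  # stands in for calls to nni.report_*

        def set_model(self, model):
            self.model = model  # candidate generated by the strategy

        def on_validation_epoch_end(self, val_loss):
            self.reported.append(("intermediate", val_loss))

        def teardown(self, stage, val_loss):
            if stage == "fit":
                self.reported.append(("final", val_loss))


    def run_trial(module, candidate, val_losses):
        module.set_model(candidate)  # happens after __init__
        for loss in val_losses:  # one entry per validation epoch
            module.on_validation_epoch_end(loss)
        module.teardown("fit", val_losses[-1])
        return module.reported


    log = run_trial(SketchModule(), "candidate", [0.9, 0.5, 0.4])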

Then, users need to wrap everything (including LightningModule, trainer and dataloaders) into a ``Lightning`` object, and pass this object into a Retiarii experiment.

.. code-block:: python

    import nni.retiarii.evaluator.pytorch.lightning as pl
    from nni.retiarii.experiment.pytorch import RetiariiExperiment

    lightning = pl.Lightning(AutoEncoder(),
                             pl.Trainer(max_epochs=10),
                             train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
                             val_dataloaders=pl.DataLoader(test_dataset, batch_size=100))
    experiment = RetiariiExperiment(base_model, lightning, mutators, strategy)