# Build your engine & Customize your trainer

## Build your engine

To better understand how the `Engine` class works, let's start from the concept of the process function in common engines. The process function 
usually controls the behavior over a batch of a dataset, and the `Engine` class simply wraps this process function. A standard process 
function is given in the following code block.

```python
def process_function(dataloader, model, criterion, optim):
    optim.zero_grad()
    data, label = next(dataloader)
    output = model(data)
    loss = criterion(output, label)
    loss.backward()
    optim.step()
```

In `ignite.engine` or `keras.engine`, the process function is always provided by users. However, it is tricky for users to write their own process 
functions for pipeline parallelism. Aiming to offer accessible hybrid parallelism, we provide the powerful `Engine` class, which 
enables pipeline parallelism and offers a one-forward-one-backward non-interleaving strategy. You can also use a pre-defined learning rate scheduler 
in the `Engine` class to adjust the learning rate during training.
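The one-forward-one-backward (1F1B) schedule mentioned above can be illustrated with a small sketch. The function below is a toy model of the per-stage operation order, not Colossal-AI's actual implementation: each stage runs a few warm-up forward passes, then alternates forward and backward passes over the remaining microbatches, and finally drains the outstanding backward passes.

```python
def one_f_one_b_order(num_stages, stage_id, num_microbatches):
    """Toy sketch of the per-stage op order in a 1F1B pipeline schedule."""
    # Earlier stages run more warm-up forwards so later stages can start.
    warmup = min(num_stages - stage_id - 1, num_microbatches)
    ops = [("F", i) for i in range(warmup)]
    fwd, bwd = warmup, 0
    # Steady state: one forward immediately followed by one backward.
    while fwd < num_microbatches:
        ops.append(("F", fwd)); fwd += 1
        ops.append(("B", bwd)); bwd += 1
    # Cool-down: drain the remaining backward passes.
    while bwd < num_microbatches:
        ops.append(("B", bwd)); bwd += 1
    return ops

# The last stage alternates strictly; earlier stages front-load forwards.
print(one_f_one_b_order(num_stages=4, stage_id=3, num_microbatches=4))
print(one_f_one_b_order(num_stages=4, stage_id=0, num_microbatches=4))
```

Note how every stage keeps at most a few forward activations alive at a time, which is what bounds the memory footprint compared with running all forwards first.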
To build your engine, just set the variables `model`, `criterion`, `optimizer`, `lr_scheduler` and `schedule`. The following code block provides
an example.

```python
import torch
import torch.nn as nn
import torchvision.models as models
import colossalai
from colossalai.engine import Engine

model = models.resnet18()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
lr_scheduler = colossalai.nn.lr_scheduler.CosineAnnealingLR(optimizer, 1000)
schedule = colossalai.engine.schedule.NoPipelineSchedule()

MyEngine = Engine(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    schedule=schedule
)
```

More information regarding the class can be found in the API references.

## Customize your trainer

### Overview

To learn how to customize a trainer which meets your needs, let's first take a look at the `Trainer` class. We highly recommend that you read the *Get Started* 
section and *Build your engine* first.

The `Trainer` class enables researchers and engineers to use our system more conveniently. Instead of writing your own scripts, you can simply 
construct your own trainer by calling the `Trainer` class, just as in the following code block.
```python
MyTrainer = Trainer(MyEngine)
```
After that, you can use the `fit` method to train or evaluate your model. To make our `Trainer` class even more powerful, we incorporate a set of 
handy tools into the class. For example, you can monitor or record the running states and metrics which indicate the current performance of the model. These
functions are realized by hooks. The `BasicHook` class allows you to execute your hook functions at specified times. We have already created some practical
hooks for you, as listed below; you only need to pick the ones which suit your needs. Detailed descriptions of the class can be found 
in the API references.

```python
hooks = [
    dict(type='LogMetricByEpochHook'),
    dict(type='LogTimingByEpochHook'),
    dict(type='LogMemoryByEpochHook'),
    dict(type='AccuracyHook'),
    dict(type='LossHook'),
    dict(type='TensorboardHook', log_dir='./tfb_logs'),
    dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
    dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
]
```
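Each entry above is a plain dictionary whose `type` field names a hook class, and the remaining keys become constructor arguments. The snippet below is a simplified, self-contained sketch of this config-driven pattern; the registry and the two hook classes here are stand-ins for illustration, not Colossal-AI's real implementation:

```python
HOOK_REGISTRY = {}

def register(cls):
    """Map a class name to the class so configs can refer to it by string."""
    HOOK_REGISTRY[cls.__name__] = cls
    return cls

@register
class AccuracyHook:
    pass

@register
class SaveCheckpointHook:
    def __init__(self, interval=1, checkpoint_dir='./ckpt'):
        self.interval = interval
        self.checkpoint_dir = checkpoint_dir

def build_hooks(configs):
    """Instantiate hooks from a list of dict configs."""
    hooks = []
    for cfg in configs:
        cfg = dict(cfg)                     # copy so we can pop 'type'
        cls = HOOK_REGISTRY[cfg.pop('type')]
        hooks.append(cls(**cfg))            # leftover keys are kwargs
    return hooks

hooks = build_hooks([
    dict(type='AccuracyHook'),
    dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
])
print(type(hooks[1]).__name__, hooks[1].interval)
```

This indirection is what lets hooks be declared in a config file rather than constructed by hand in a training script.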

These hook functions will record metrics, elapsed time and memory usage, and write them to the log after each epoch. Besides, they print the current loss and 
accuracy so that users can monitor the performance of the model.

### Hook

If you have specific needs, feel free to extend our `BaseHook` class to add your own functions, or our `MetricHook` class to write a metric collector. 
These hook functions can be called at twelve timings in the trainer's life cycle. Besides, you can define the priorities of all hooks to arrange their execution order. 
More information can be found in the API references.
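As an illustration of how priorities arrange execution order, the sketch below shows a trainer sorting its hooks before firing a life-cycle event. The class and method names here (`after_train_epoch`, the `priority` attribute) are illustrative assumptions, not the exact Colossal-AI API:

```python
class BaseHook:
    """Minimal stand-in for a trainer hook; a smaller priority runs first."""
    def __init__(self, priority=10):
        self.priority = priority
    def after_train_epoch(self, log):
        pass

class MyMetricHook(BaseHook):
    def after_train_epoch(self, log):
        log.append('compute metrics')

class MyLogHook(BaseHook):
    def after_train_epoch(self, log):
        log.append('write log')

# Give the metric hook a higher priority (smaller number) so the
# log hook can rely on the metrics already being computed.
hooks = [MyLogHook(priority=10), MyMetricHook(priority=0)]
log = []
for hook in sorted(hooks, key=lambda h: h.priority):
    hook.after_train_epoch(log)
print(log)
```

Registration order no longer matters: the sort on `priority` alone decides who runs first at each timing.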

### Metric

You can write your own metrics by extending our `Metric` class. It should be used with the `MetricHook` class. When you write your own metric hooks, please set 
the priority carefully and make sure the hook is called before other hooks which might require the results of the metric hook.
We've already provided some metric hooks, and we store metric objects in `runner.states['metrics']`. It is a dictionary, and metrics can be accessed by their names.
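The pattern can be sketched as follows. The `Accuracy` class here is a simplified stand-in for a real `Metric` subclass, and `states` mimics the `runner.states` dictionary mentioned above; none of this is the exact Colossal-AI API:

```python
class Metric:
    """Minimal stand-in for the Metric base class."""
    def update(self, *args):
        raise NotImplementedError
    def value(self):
        raise NotImplementedError

class Accuracy(Metric):
    def __init__(self):
        self.correct = 0
        self.total = 0
    def update(self, predictions, labels):
        self.correct += sum(p == l for p, l in zip(predictions, labels))
        self.total += len(labels)
    def value(self):
        return self.correct / self.total

# A metric hook would store its metric under a name, so later hooks
# (and users) can look it up from the runner's state dictionary.
states = {'metrics': {}}
states['metrics']['Accuracy'] = Accuracy()
states['metrics']['Accuracy'].update([0, 1, 1, 0], [0, 1, 0, 0])
print(states['metrics']['Accuracy'].value())
```

Because the dictionary is keyed by name, a downstream hook only needs the metric's name, not a reference to the hook that produced it.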