# Colossal-AI Engine & Customize Your Trainer

## Colossal-AI engine

To better understand how the `Engine` class works, let's start with the concept of a process function, which is common to most engines. The process function controls the behavior over one batch of a dataset, and the `Engine` class simply manages this process function. A standard process function looks like the following code block.

```python
def process_function(dataloader, model, criterion, optim):
    optim.zero_grad()                # clear gradients from the last step
    data, label = next(dataloader)   # fetch one batch
    output = model(data)             # forward pass
    loss = criterion(output, label)
    loss.backward()                  # backward pass
    optim.step()                     # update parameters
```

The `Engine` class is a high-level wrapper around these frequently used functions; it preserves the PyTorch-like function signatures while integrating with our features.

```python
import torch
import torch.nn as nn
import torchvision.models as models
import colossalai
from colossalai.engine import Engine
from torchvision.datasets import CIFAR10

model = models.resnet18()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

dataset = CIFAR10(...)
dataloader = colossalai.utils.get_dataloader(dataset)

engine, dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, dataloader)

# example of a training iteration
for img, label in dataloader:
    engine.zero_grad()
    output = engine(img)
    loss = engine.criterion(output, label)
    engine.backward(loss)
    engine.step()

```
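As a mental model, the wrapper idea behind the example above can be sketched in plain Python. The `ToyEngine` class below is a hypothetical stand-in, not the actual colossalai implementation, which handles much more internally:

```python
# Toy sketch of the wrapper idea: the engine exposes the familiar
# PyTorch-style calls and forwards them to the objects it wraps, so the
# training loop stays unchanged while extra features (mixed precision,
# gradient handling, etc.) can be slotted into each method.
class ToyEngine:
    def __init__(self, model, optimizer, criterion):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion

    def __call__(self, *args):
        # forward pass; a real engine could insert extra logic here
        return self.model(*args)

    def zero_grad(self):
        self.optimizer.zero_grad()

    def backward(self, loss):
        # a real engine could scale or post-process the loss here
        loss.backward()

    def step(self):
        self.optimizer.step()
```

Because each call is delegated, the loop body `engine.zero_grad()`, `engine(...)`, `engine.backward(loss)`, `engine.step()` reads exactly like a plain PyTorch loop.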

More information regarding the class can be found in the API references.

## Customize your trainer

### Overview

To learn how to customize a trainer that meets your needs, let's first take a look at the `Trainer` class. We highly recommend that you read the *Get Started* section and the *Colossal-AI engine* section above first.

The `Trainer` class enables researchers and engineers to use our system more conveniently. Instead of writing your own training scripts, you can simply construct your own trainer by instantiating the `Trainer` class, as in the following code block.

```python
trainer = Trainer(engine)
```

After that, you can use the `fit` method to train or evaluate your model. To make the `Trainer` class even more powerful, we incorporate a set of handy tools into it. For example, you can monitor or record the running states and metrics that indicate the current performance of the model. These functions are realized by hooks: the `BaseHook` class allows you to execute your hook functions at specified points in the training life cycle. We have already created some practical hooks for you, as listed below; you only need to pick the ones that suit your needs. Detailed descriptions of the class can be found in the API references.

These hook functions record metrics, elapsed time, and memory usage, and write them to the log after each epoch. They also print the current loss and accuracy so that users can monitor the performance of the model.

```python
import colossalai
from colossalai.trainer import hooks, Trainer
from colossalai.utils import MultiTimer
from colossalai.logging import get_dist_logger

... = colossalai.initialize(...)

timer = MultiTimer()
logger = get_dist_logger()

# if you want to save log to file
logger.log_to_file('./logs/')

trainer = Trainer(
    engine=engine,
    timer=timer,
    logger=logger
)

hook_list = [
    hooks.LossHook(),
    hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
    hooks.AccuracyHook(),
    hooks.TensorboardHook(log_dir='./tb_logs', ranks=[0]),
    hooks.LogMetricByEpochHook(logger),
    hooks.LogMemoryByEpochHook(logger),
    hooks.LogTimingByEpochHook(timer, logger),
    hooks.SaveCheckpointHook(checkpoint_dir='./ckpt')
]

trainer.fit(
    train_dataloader=train_dataloader,
    epochs=NUM_EPOCHS,
    test_dataloader=test_dataloader,
    test_interval=1,
    hooks=hook_list,
    display_progress=True
)

```

### Hook

If you have specific needs, feel free to extend our `BaseHook` class to add your own functions, or our `MetricHook` class to write a metric collector. These hook functions can be called at different stages in the trainer's life cycle.
Besides, you can define priorities for all hooks to arrange their execution order. More information can be found in the API references.
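The priority mechanism can be sketched in a standalone way. The names and signatures below are invented for illustration of the general pattern and do not match colossalai's actual `BaseHook` API, which is documented in the API references:

```python
# Standalone sketch of priority-ordered hook dispatch: each hook carries a
# priority, and at every life-cycle stage the trainer invokes the hooks
# that implement that stage, lowest priority value first.
class SketchHook:
    def __init__(self, priority=10):
        self.priority = priority

class MetricSketchHook(SketchHook):
    """Computes a value other hooks depend on, so it runs early."""
    def __init__(self, log, priority=0):
        super().__init__(priority)
        self.log = log

    def after_train_iter(self, state):
        self.log.append(('metric', state['loss']))

class LogSketchHook(SketchHook):
    """Logs results, so it runs after the metric hook."""
    def __init__(self, log, priority=10):
        super().__init__(priority)
        self.log = log

    def after_train_iter(self, state):
        self.log.append(('log', state['loss']))

def fire_stage(hooks, stage, state):
    # sort by priority so that, e.g., a metric hook always runs before
    # the logging hook that reads its result
    for hook in sorted(hooks, key=lambda h: h.priority):
        fn = getattr(hook, stage, None)
        if fn is not None:
            fn(state)

log = []
# registration order does not matter; priority decides execution order
fire_stage([LogSketchHook(log), MetricSketchHook(log)],
           'after_train_iter', {'loss': 0.25})
```

Here the metric hook fires before the logging hook even though it was registered second, which is exactly why priorities matter when one hook consumes another's output.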

### Metric

You can write your own metrics by extending our `Metric` class, which should be used together with the `MetricHook` class. When you write your own metric hooks, please set the priority carefully and make sure the hook is called before any other hooks that require the results of the metric hook.

We have already provided some metric hooks, and we store metric objects in `runner.states['metrics']`. It is a dictionary, and each metric can be accessed by its name.
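As an illustration of this accumulate-then-read pattern, here is a minimal standalone sketch. The class and dictionary below are hypothetical and assume nothing about the real `Metric` API beyond the description above:

```python
# Minimal sketch of a running-accuracy metric: a hook would call update()
# every iteration, and any consumer reads the accumulated value from a
# shared states dictionary by the metric's name.
class SketchAccuracy:
    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, predictions, labels):
        self.correct += sum(p == l for p, l in zip(predictions, labels))
        self.total += len(labels)

    def value(self):
        return self.correct / self.total if self.total else 0.0

states = {'metrics': {'Accuracy': SketchAccuracy()}}
states['metrics']['Accuracy'].update([1, 0, 1], [1, 1, 1])  # 2 of 3 correct
states['metrics']['Accuracy'].update([0, 0], [0, 1])        # 1 of 2 correct
```

Because the metric only accumulates state, any hook that fires later in the same stage can read a consistent value, which is why the metric hook's priority must place it before its consumers.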