# Write Models

This section introduces how to implement a new model entirely from scratch and make it compatible with LiBai.


## Construct Models in LiBai

LiBai uses [LazyConfig](https://libai.readthedocs.io/en/latest/tutorials/Config_System.html) for a more flexible config system, which means you can import your own model directly in your config file and train it with LiBai.

For an image classification task, the input data is usually a batch of images and labels. The following code shows how to build a toy model for this task:
```python
# toy_model.py
import oneflow as flow
import oneflow.nn as nn


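class ToyModel(nn.Module):
    def __init__(self, num_classes=1000):
        super().__init__()
        # A single 3x3 conv as the feature extractor: 3 input channels (RGB) -> 64 channels
        self.features = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        # Global average pooling reduces each feature map to 1x1
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(64, num_classes)
        # The loss function lives inside the model so forward() can return the loss
        self.loss_func = nn.CrossEntropyLoss()

    def forward(self, images, labels=None):
        x = self.features(images)
        x = self.avgpool(x)
        x = flow.flatten(x, 1)  # (N, 64, 1, 1) -> (N, 64)
        x = self.classifier(x)

        if labels is not None and self.training:
            # Training: return the loss wrapped in a dict
            losses = self.loss_func(x, labels)
            return {"losses": losses}
        else:
            # Inference: return the logits wrapped in a dict
            return {"prediction_scores": x}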
```
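The shape changes in ``ToyModel``'s forward pass can be traced by hand. The plain-Python sketch below does this for a hypothetical batch of 8 RGB images at 224x224, using the standard convolution output-size formula `(size + 2 * padding - kernel_size) // stride + 1` (dilation 1). The helper names `conv2d_out` and `toy_model_shapes` are illustrative only and are not part of LiBai or OneFlow.

```python
def conv2d_out(size, kernel_size, stride=1, padding=0):
    """Standard output-size formula for a 2D convolution with dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


def toy_model_shapes(batch, height, width, num_classes=1000):
    """Trace the tensor shape after each stage of ToyModel.forward()."""
    shapes = [("images", (batch, 3, height, width))]
    # nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
    h = conv2d_out(height, kernel_size=3, stride=1, padding=1)
    w = conv2d_out(width, kernel_size=3, stride=1, padding=1)
    shapes.append(("features", (batch, 64, h, w)))
    shapes.append(("avgpool", (batch, 64, 1, 1)))   # nn.AdaptiveAvgPool2d(1)
    shapes.append(("flatten", (batch, 64)))         # flow.flatten(x, 1)
    shapes.append(("classifier", (batch, num_classes)))  # nn.Linear(64, num_classes)
    return shapes


for name, shape in toy_model_shapes(batch=8, height=224, width=224):
    print(f"{name:<10} {shape}")
```

With ``padding=1`` the 3x3 convolution keeps the 224x224 spatial size, and because the adaptive pooling always reduces to 1x1, ``nn.Linear(64, num_classes)`` only has to match the 64 channels, regardless of the input resolution.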

**Note:**
- For classification models, the ``forward`` function must take ``images`` and ``labels`` as arguments, whose names match the output of ``__getitem__`` in LiBai's built-in datasets. Please refer to [imagenet.py](https://github.com/Oneflow-Inc/libai/blob/main/libai/data/datasets/imagenet.py) for more details about the dataset.
- **This toy model** returns ``losses`` during training and ``prediction_scores`` during inference, and both must be wrapped in a ``dict``. Because the model itself returns the loss during training, you should implement the loss function inside your model, e.g. ``self.loss_func = nn.CrossEntropyLoss()`` as ``ToyModel`` does above.
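To see why the argument names and the ``dict`` outputs matter, here is a minimal pure-Python sketch of how a training loop might feed one batch to the model. It assumes, for illustration only, that each batch is a mapping keyed by ``images`` and ``labels`` and that the loop calls the model with keyword unpacking (``model(**batch)``). ``StubModel``, ``fake_batch``, and ``run_step`` are hypothetical names, and the stub returns plain floats instead of OneFlow tensors so the sketch runs without OneFlow.

```python
class StubModel:
    """Hypothetical stand-in for ToyModel: same forward() contract, no OneFlow."""

    def __init__(self):
        self.training = True

    def train(self):
        self.training = True
        return self

    def eval(self):
        self.training = False
        return self

    def forward(self, images, labels=None):
        scores = [sum(img) for img in images]  # placeholder "logits"
        if labels is not None and self.training:
            # Placeholder loss: mean absolute difference between scores and labels
            loss = sum(abs(s - y) for s, y in zip(scores, labels)) / len(labels)
            return {"losses": loss}
        return {"prediction_scores": scores}

    # Calling the instance runs forward(), like nn.Module.__call__
    __call__ = forward


def fake_batch():
    """Hypothetical batch whose keys match the forward() argument names."""
    return {"images": [[1.0, 2.0], [3.0, 4.0]], "labels": [3.0, 6.0]}


def run_step(model, batch):
    # Keyword unpacking: every key must match a forward() parameter name,
    # otherwise Python raises a TypeError before the model runs.
    outputs = model(**batch)
    if not isinstance(outputs, dict):
        raise TypeError("the model must return a dict")
    return outputs


print(run_step(StubModel().train(), fake_batch()))  # {'losses': 0.5}
print(run_step(StubModel().eval(), fake_batch()))   # {'prediction_scores': [3.0, 7.0]}
```

Renaming the argument (for example to ``img``) while the batch still uses the key ``images`` would make ``run_step`` fail with a ``TypeError``, which is why the names in ``forward`` have to follow the dataset's output.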


## Import the Model in Your Config

With the LazyConfig system, you can import the model directly in your config file. The following code shows how to use ``ToyModel`` in your config file:
```python
# config.py
from libai.config import LazyCall
from toy_model import ToyModel

model = LazyCall(ToyModel)(
    num_classes=1000
)
```
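``LazyCall(ToyModel)(num_classes=1000)`` does not build the model right away; the config records which callable to build and with which arguments, so values such as ``num_classes`` can still be changed before the model is created. The pure-Python sketch below imitates that deferred-construction idea. ``lazy_call`` and ``build`` are hypothetical helpers written for illustration, not LiBai's implementation, and the ``_target_`` key is an assumption borrowed from detectron2-style lazy configs; the ``ToyModel`` here is a stub so the sketch runs without OneFlow.

```python
class ToyModel:
    """Hypothetical stub of ToyModel that only records num_classes."""

    def __init__(self, num_classes=1000):
        self.num_classes = num_classes


def lazy_call(target):
    """Return a function that records target and kwargs instead of calling target."""
    def record(**kwargs):
        return {"_target_": target, **kwargs}
    return record


def build(cfg):
    """Instantiate a recorded config by calling its target with the stored kwargs."""
    kwargs = {k: v for k, v in cfg.items() if k != "_target_"}
    return cfg["_target_"](**kwargs)


# config.py-style declaration: nothing is constructed yet
model = lazy_call(ToyModel)(num_classes=1000)

# The argument can still be overridden, e.g. for a 10-class dataset
model["num_classes"] = 10

net = build(model)
print(type(net).__name__, net.num_classes)  # ToyModel 10
```

This sketch handles a single level only; a fuller version would also build nested configs (arguments that are themselves lazy calls) recursively.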