# Tutorial 5: Adding New Modules

In this tutorial, we will introduce some methods for customizing the optimizer, developing new components, and adding a new learning rate scheduler for this project.

<!-- TOC -->

- [Customize Optimizer](#customize-optimizer)
- [Customize Optimizer Constructor](#customize-optimizer-constructor)
- [Develop New Components](#develop-new-components)
  - [Add new backbones](#add-new-backbones)
  - [Add new heads](#add-new-heads)
  - [Add new loss](#add-new-loss)
- [Add new learning rate scheduler (updater)](#add-new-learning-rate-scheduler--updater-)

<!-- TOC -->

## Customize Optimizer

An example of a customized optimizer is [CopyOfSGD](/mmaction/core/optimizer/copy_of_sgd.py), defined in `mmaction/core/optimizer/copy_of_sgd.py`.
More generally, a customized optimizer could be defined as follows.

Assume you want to add an optimizer named `MyOptimizer`, which has arguments `a`, `b` and `c`.
You need to first implement the new optimizer in a file, e.g., in `mmaction/core/optimizer/my_optimizer.py`:

```python
from mmcv.runner import OPTIMIZERS
from torch.optim import Optimizer

@OPTIMIZERS.register_module()
class MyOptimizer(Optimizer):

    def __init__(self, a, b, c):
        pass
```
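
For illustration, here is a minimal sketch of what a working subclass might look like, implementing a plain SGD-style update (the name `MySGD` and its arguments are hypothetical, not part of this project):

```python
import torch
from mmcv.runner import OPTIMIZERS
from torch.optim import Optimizer


@OPTIMIZERS.register_module()
class MySGD(Optimizer):
    """Hypothetical optimizer: vanilla SGD with optional weight decay."""

    def __init__(self, params, lr=0.01, weight_decay=0):
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad
                if group['weight_decay'] != 0:
                    # L2 regularization folded into the gradient.
                    d_p = d_p.add(p, alpha=group['weight_decay'])
                # Plain gradient descent step.
                p.add_(d_p, alpha=-group['lr'])
        return loss
```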

Then add this module in `mmaction/core/optimizer/__init__.py`, so that the registry will find the new module and add it:

```python
from .my_optimizer import MyOptimizer
```

Then you can use `MyOptimizer` in the `optimizer` field of config files.
In the configs, the optimizers are defined by the field `optimizer` like the following:

```python
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
```

To use your own optimizer, the field can be changed to:

```python
optimizer = dict(type='MyOptimizer', a=a_value, b=b_value, c=c_value)
```

We already support using all the optimizers implemented by PyTorch, and the only modification needed is to change the `optimizer` field of the config files.
For example, if you want to use `Adam` (even though the performance will drop a lot), the modification could be as follows.

```python
optimizer = dict(type='Adam', lr=0.0003, weight_decay=0.0001)
```

The users can directly set arguments following the [API doc](https://pytorch.org/docs/stable/optim.html?highlight=optim#module-torch.optim) of PyTorch.

## Customize Optimizer Constructor

Some models may have parameter-specific settings for optimization, e.g., special weight decay settings for BatchNorm layers.
Users can do this kind of fine-grained parameter tuning by customizing the optimizer constructor.

You can write a new optimizer constructor inheriting from [DefaultOptimizerConstructor](https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/optimizer/default_constructor.py)
and overwriting the `add_params(self, params, module)` method.

An example of a customized optimizer constructor is [TSMOptimizerConstructor](/mmaction/core/optimizer/tsm_optimizer_constructor.py).
More generally, a customized optimizer constructor could be defined as follows.

In `mmaction/core/optimizer/my_optimizer_constructor.py`:

```python
from mmcv.runner import OPTIMIZER_BUILDERS, DefaultOptimizerConstructor

@OPTIMIZER_BUILDERS.register_module()
class MyOptimizerConstructor(DefaultOptimizerConstructor):

    def add_params(self, params, module):
        pass
```
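
As a concrete illustration, the sketch below gives BatchNorm parameters zero weight decay and reads an option from `paramwise_cfg` (the constructor name and the `no_bn_decay` option are hypothetical; real matching logic, e.g. in `TSMOptimizerConstructor`, is more involved):

```python
import torch.nn as nn
from mmcv.runner import OPTIMIZER_BUILDERS, DefaultOptimizerConstructor


@OPTIMIZER_BUILDERS.register_module()
class NoBNDecayOptimizerConstructor(DefaultOptimizerConstructor):
    """Hypothetical constructor: zero weight decay for BatchNorm params."""

    def add_params(self, params, module):
        no_bn_decay = self.paramwise_cfg.get('no_bn_decay', True)
        for param in module.parameters(recurse=False):
            if not param.requires_grad:
                continue
            group = {'params': [param]}
            if no_bn_decay and isinstance(module,
                                          nn.modules.batchnorm._BatchNorm):
                # Param groups without an explicit setting fall back to the
                # optimizer defaults; here we override weight decay only.
                group['weight_decay'] = 0.
            params.append(group)
        # Recurse into child modules so every parameter gets a group.
        for child in module.children():
            self.add_params(params, child)
```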

In `mmaction/core/optimizer/__init__.py`:

```python
from .my_optimizer_constructor import MyOptimizerConstructor
```

Then you can use `MyOptimizerConstructor` in the `optimizer` field of config files:

```python
# optimizer
optimizer = dict(
    type='SGD',
    constructor='MyOptimizerConstructor',
    paramwise_cfg=dict(fc_lr5=True),
    lr=0.02,
    momentum=0.9,
    weight_decay=0.0001)
```

## Develop New Components

We basically categorize model components into 4 types; a config sketch showing how they fit together follows this list.

- recognizer: the whole recognizer model pipeline, usually containing a backbone and a cls_head.
- backbone: usually a fully convolutional network that extracts feature maps, e.g., ResNet, BNInception.
- cls_head: the component for the classification task, usually containing an FC layer and some pooling layers.
- localizer: the model for the temporal localization task; currently available: BSN, BMN, SSN.
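
For orientation, here is how these components typically fit together in a model config (a simplified sketch based on field names used later in this tutorial; exact arguments vary per model):

```python
model = dict(
    type='Recognizer2D',      # the recognizer wrapping everything
    backbone=dict(
        type='ResNet',        # feature extractor
        depth=50),
    cls_head=dict(
        type='TSNHead',       # classification head
        num_classes=400,
        in_channels=2048))
```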

### Add new backbones

Here we show how to develop new components, using TSN as an example.

1. Create a new file `mmaction/models/backbones/resnet.py` (a runnable toy backbone is sketched after these steps).

   ```python
   import torch.nn as nn

   from ..builder import BACKBONES

   @BACKBONES.register_module()
   class ResNet(nn.Module):

       def __init__(self, arg1, arg2):
           pass

       def forward(self, x):  # should return a tuple
           pass

       def init_weights(self, pretrained=None):
           pass
   ```

2. Import the module in `mmaction/models/backbones/__init__.py`.

   ```python
   from .resnet import ResNet
   ```

3. Use it in your config file.

   ```python
   model = dict(
       ...
       backbone=dict(
           type='ResNet',
           arg1=xxx,
           arg2=xxx),
   )
   ```
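
To make the skeleton above concrete, here is a runnable toy backbone (the name `SimpleBackbone` and its layers are hypothetical, for illustration only):

```python
import torch.nn as nn

from ..builder import BACKBONES


@BACKBONES.register_module()
class SimpleBackbone(nn.Module):
    """A toy backbone: one conv stage returning a tuple of feature maps."""

    def __init__(self, in_channels=3, out_channels=64):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                              stride=2, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Backbones should return a tuple of feature maps.
        return (self.relu(self.bn(self.conv(x))),)

    def init_weights(self, pretrained=None):
        # Loading `pretrained` weights is omitted in this sketch.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
```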

### Add new heads

Here we show how to develop a new head, using TSNHead as an example.

1. Create a new file `mmaction/models/heads/tsn_head.py`.

   You can write a new classification head inheriting from [BaseHead](/mmaction/models/heads/base.py)
   and overwriting the `init_weights(self)` and `forward(self, x)` methods (a runnable toy head is sketched after these steps).

   ```python
   from ..builder import HEADS
   from .base import BaseHead


   @HEADS.register_module()
   class TSNHead(BaseHead):

       def __init__(self, arg1, arg2):
           pass

       def forward(self, x):
           pass

       def init_weights(self):
           pass
   ```

2. Import the module in `mmaction/models/heads/__init__.py`.

   ```python
   from .tsn_head import TSNHead
   ```

3. Use it in your config file.

   ```python
   model = dict(
       ...
       cls_head=dict(
           type='TSNHead',
           num_classes=400,
           in_channels=2048,
           arg1=xxx,
           arg2=xxx),
   )
   ```
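
To make the skeleton above concrete, here is a toy head that pools features and applies a single FC layer (a sketch assuming `BaseHead.__init__` takes `num_classes` and `in_channels`, as in the config example above; `SimpleHead` is a hypothetical name):

```python
import torch.nn as nn

from ..builder import HEADS
from .base import BaseHead


@HEADS.register_module()
class SimpleHead(BaseHead):
    """A toy head: global average pooling followed by one FC layer."""

    def __init__(self, num_classes, in_channels, **kwargs):
        super().__init__(num_classes, in_channels, **kwargs)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc_cls = nn.Linear(in_channels, num_classes)

    def init_weights(self):
        nn.init.normal_(self.fc_cls.weight, 0, 0.01)
        nn.init.constant_(self.fc_cls.bias, 0)

    def forward(self, x):
        # [N, in_channels, H, W] -> [N, in_channels] -> [N, num_classes]
        x = self.avg_pool(x).flatten(1)
        return self.fc_cls(x)
```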

### Add new loss

Assume you want to add a new loss named `MyLoss`. To add a new loss function, you need to implement it in `mmaction/models/losses/my_loss.py`.

```python
import torch
import torch.nn as nn

from ..builder import LOSSES

def my_loss(pred, target):
    # Element-wise L1 loss.
    assert pred.size() == target.size() and target.numel() > 0
    loss = torch.abs(pred - target)
    return loss


@LOSSES.register_module()
class MyLoss(nn.Module):

    def forward(self, pred, target):
        loss = my_loss(pred, target)
        return loss
```

Then you need to add it in `mmaction/models/losses/__init__.py`:

```python
from .my_loss import MyLoss, my_loss
```

To use it, modify the `loss_xxx` field. Since `MyLoss` is for regression, we can use it for the bbox loss `loss_bbox`.

```python
loss_bbox = dict(type='MyLoss')
```
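
To sanity-check the registration, you can build and call the loss directly (a quick sketch assuming `build_loss` is exported from `mmaction.models`, as with the other builder functions):

```python
import torch

from mmaction.models import build_loss  # assumes build_loss is exported here

loss_fn = build_loss(dict(type='MyLoss'))

pred = torch.rand(4, 10)
target = torch.rand(4, 10)
loss = loss_fn(pred, target)  # element-wise L1 values, shape [4, 10]
```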

## Add new learning rate scheduler (updater)

The default manner of constructing an lr updater (called a 'scheduler' in PyTorch convention) is to modify the config, such as:

```python
...
lr_config = dict(policy='step', step=[20, 40])
...
```

In the API for [`train.py`](/mmaction/apis/train.py), the learning rate updater hook is registered based on the config at:

```python
...
    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get('momentum_config', None))
...
```

So far, the supported updaters can be found in [mmcv](https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/lr_updater.py), but if you want to customize a new learning rate updater, you may follow the steps below:

1. First, write your own LrUpdaterHook in `$MMAction2/mmaction/core/scheduler`. The following snippet is an example of a customized lr updater that uses a list of learning rates `lrs`: the learning rate switches to the corresponding value in `lrs` at each milestone in `steps`:

```python
from mmcv.runner import HOOKS, LrUpdaterHook


@HOOKS.register_module()  # Register the hook here
class RelativeStepLrUpdaterHook(LrUpdaterHook):
    # You should inherit from mmcv's LrUpdaterHook.

    def __init__(self, steps, lrs, **kwargs):
        super().__init__(**kwargs)
        assert len(steps) == len(lrs)
        self.steps = steps
        self.lrs = lrs

    def get_lr(self, runner, base_lr):
        # Only this method needs to be overridden. It is called before each
        # training epoch (or iteration) and returns the learning rate to use.
        progress = runner.epoch if self.by_epoch else runner.iter
        for i in range(len(self.steps)):
            if progress < self.steps[i]:
                return self.lrs[i]
        # Keep the last learning rate once all milestones are passed.
        return self.lrs[-1]
```
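
As with the new optimizer above, the hook must be imported when the package loads so the registry can find it, e.g. in the package's `__init__.py` (the module file name below is hypothetical):

```python
# In mmaction/core/scheduler/__init__.py
from .relative_step_lr_updater import RelativeStepLrUpdaterHook

__all__ = ['RelativeStepLrUpdaterHook']
```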

2. Modify your config:

In your config file, replace the original `lr_config` with:

```python
lr_config = dict(policy='RelativeStep', steps=[20, 40, 60], lrs=[0.1, 0.01, 0.001])
```

Note that mmcv derives the hook class name by appending `LrUpdaterHook` to the `policy` name, so `policy='RelativeStep'` resolves to `RelativeStepLrUpdaterHook`.

More examples can be found in [mmcv](https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/lr_updater.py).