train.py 8.32 KB
Newer Older
myownskyW7's avatar
myownskyW7 committed
1
from __future__ import division
2
import re
myownskyW7's avatar
myownskyW7 committed
3
4
5
6
from collections import OrderedDict

import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
7
from mmcv.runner import DistSamplerSeedHook, Runner, obj_from_dict
myownskyW7's avatar
myownskyW7 committed
8

9
from mmdet import datasets
10
11
12
from mmdet.core import (CocoDistEvalmAPHook, CocoDistEvalRecallHook,
                        DistEvalmAPHook, DistOptimizerHook, Fp16OptimizerHook)
from mmdet.datasets import DATASETS, build_dataloader
myownskyW7's avatar
myownskyW7 committed
13
from mmdet.models import RPN
Kai Chen's avatar
Kai Chen committed
14
from .env import get_root_logger
myownskyW7's avatar
myownskyW7 committed
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40


def parse_losses(losses):
    """Collapse a dict of raw model losses into a training loss and log values.

    Every tensor entry is reduced with ``.mean()``; a list entry has each
    element reduced with ``.mean()`` and the results summed. The training
    loss is the sum of the reduced entries whose name contains the
    substring ``'loss'``.

    Args:
        losses (dict): Maps a name to a ``torch.Tensor`` or a list of
            tensors.

    Returns:
        tuple: ``(loss, log_vars)``. ``loss`` is the total loss as a tensor.
        ``log_vars`` is an ``OrderedDict`` holding every reduced entry,
        converted to a Python number via ``.item()``, in input order, plus
        the total stored under the key ``'loss'`` (appended last unless the
        input already had a ``'loss'`` key, which is then overwritten).

    Raises:
        TypeError: If a value is neither a tensor nor a list.
    """
    log_vars = OrderedDict()
    for key, value in losses.items():
        if isinstance(value, torch.Tensor):
            reduced = value.mean()
        elif isinstance(value, list):
            reduced = sum(part.mean() for part in value)
        else:
            raise TypeError(
                '{} is not a tensor or list of tensors'.format(key))
        log_vars[key] = reduced

    # Only entries whose name contains "loss" contribute to the optimized
    # value; every other entry is only logged.
    total = sum(val for key, val in log_vars.items() if 'loss' in key)
    log_vars['loss'] = total

    # The caller keeps ``total`` as a tensor; the log dict gets plain numbers.
    scalars = OrderedDict((key, val.item()) for key, val in log_vars.items())
    return total, scalars


def batch_processor(model, data, train_mode):
    """Forward one batch and package the result for the mmcv ``Runner``.

    Args:
        model: Callable model; invoked as ``model(**data)``. Its return
            value is passed to :func:`parse_losses`, so it must be a dict of
            tensors / lists of tensors.
        data (dict): One batch, unpacked as keyword arguments into ``model``.
        train_mode: Unused here; presumably a train/val flag supplied by the
            ``Runner`` callback protocol.

    Returns:
        dict: ``loss`` (total loss tensor), ``log_vars`` (loggable numbers)
        and ``num_samples`` (``len(data['img'].data)``).
    """
    loss, log_vars = parse_losses(model(**data))
    # NOTE(review): assumes data['img'] exposes a ``.data`` sequence (e.g. an
    # mmcv DataContainer) — confirm exactly what its length counts.
    sample_count = len(data['img'].data)
    return dict(loss=loss, log_vars=log_vars, num_samples=sample_count)


Kai Chen's avatar
Kai Chen committed
47
48
49
50
51
52
53
54
def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   logger=None):
    """Train a detector, choosing the distributed or single-process path.

    Args:
        model: The detector to train.
        dataset: A dataset, or a list/tuple of datasets; forwarded unchanged
            to the selected training routine.
        cfg: Training config (read via attributes such as ``cfg.log_level``).
        distributed (bool): Use ``_dist_train`` when true, otherwise
            ``_non_dist_train``.
        validate (bool): Forwarded to the selected training routine.
        logger: Optional logger. When omitted, ``get_root_logger`` is called
            with ``cfg.log_level``; the logger is not used further here.
    """
    if logger is None:
        # NOTE(review): kept for its presumed side effect of initialising the
        # root logger; the returned object is not used in this function.
        logger = get_root_logger(cfg.log_level)

    # start training
    run_training = _dist_train if distributed else _non_dist_train
    run_training(model, dataset, cfg, validate=validate)
myownskyW7's avatar
myownskyW7 committed
61

Kai Chen's avatar
Kai Chen committed
62

63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
def build_optimizer(model, optimizer_cfg):
    """Build optimizer from configs.

    Args:
        model (:obj:`nn.Module`): The model with parameters to be optimized.
            If it has a ``module`` attribute (e.g. a data-parallel wrapper),
            ``model.module`` is used instead.
        optimizer_cfg (dict): The config dict of the optimizer. It is copied,
            so the caller's dict is not modified.
            Required fields are:
                - type: class name of the optimizer in ``torch.optim``.
                - lr: base learning rate.
            Optional fields are:
                - any arguments of the corresponding optimizer type, e.g.,
                  weight_decay, momentum, etc.
                - paramwise_options: a dict with 3 accepted fields
                  (bias_lr_mult, bias_decay_mult, norm_decay_mult).
                  `bias_lr_mult` and `bias_decay_mult` will be multiplied to
                  the lr and weight decay respectively for all bias parameters
                  (except for the normalization layers), and
                  `norm_decay_mult` will be multiplied to the weight decay
                  for all weight and bias parameters of normalization layers.
                  Normalization parameters are recognised by name: ``bn`` or
                  ``gn``, optional digits, then ``.weight`` / ``.bias``
                  (e.g. ``layer1.0.bn1.weight``, ``neck.gn.bias``).
                  If either decay multiplier is given, ``weight_decay`` must
                  also be set in ``optimizer_cfg``.

    Returns:
        torch.optim.Optimizer: The initialized optimizer.
    """
    if hasattr(model, 'module'):
        model = model.module

    optimizer_cfg = optimizer_cfg.copy()
    paramwise_options = optimizer_cfg.pop('paramwise_options', None)
    # if no paramwise option is specified, just use the global setting
    if paramwise_options is None:
        return obj_from_dict(optimizer_cfg, torch.optim,
                             dict(params=model.parameters()))

    assert isinstance(paramwise_options, dict), (
        'paramwise_options must be a dict, got {}'.format(
            type(paramwise_options)))
    # get base lr and weight decay
    base_lr = optimizer_cfg['lr']
    base_wd = optimizer_cfg.get('weight_decay', None)
    # weight_decay must be explicitly specified if mult is specified
    if ('bias_decay_mult' in paramwise_options
            or 'norm_decay_mult' in paramwise_options):
        assert base_wd is not None, (
            'weight_decay must be set in optimizer_cfg when bias_decay_mult '
            'or norm_decay_mult is given')
    # get param-wise options
    bias_lr_mult = paramwise_options.get('bias_lr_mult', 1.)
    bias_decay_mult = paramwise_options.get('bias_decay_mult', 1.)
    norm_decay_mult = paramwise_options.get('norm_decay_mult', 1.)

    # The dot is escaped so only a real "<bn|gn>[digits].<weight|bias>" name
    # segment matches; an unescaped "." would also accept names such as
    # "bn_weight". Compiled once, outside the loop.
    # TODO: obtain the norm layer prefixes dynamically
    norm_param_pattern = re.compile(r'(bn|gn)(\d+)?\.(weight|bias)')

    # set param-wise lr and weight decay
    params = []
    for name, param in model.named_parameters():
        param_group = {'params': [param]}
        if not param.requires_grad:
            # FP16 training needs to copy gradient/weight between master
            # weight copy and model weight, it is convenient to keep all
            # parameters here to align with model.parameters()
            params.append(param_group)
            continue

        # for norm layers, overwrite the weight decay of weight and bias
        if norm_param_pattern.search(name):
            if base_wd is not None:
                param_group['weight_decay'] = base_wd * norm_decay_mult
        # for other layers, overwrite both lr and weight decay of bias
        elif name.endswith('.bias'):
            param_group['lr'] = base_lr * bias_lr_mult
            if base_wd is not None:
                param_group['weight_decay'] = base_wd * bias_decay_mult
        # otherwise use the global settings

        params.append(param_group)

    optimizer_cls = getattr(torch.optim, optimizer_cfg.pop('type'))
    return optimizer_cls(params, **optimizer_cfg)


Kai Chen's avatar
Kai Chen committed
137
def _dist_train(model, dataset, cfg, validate=False):
    """Distributed training routine used by ``train_detector``.

    Builds one distributed data loader per dataset, moves the model to the
    GPU and wraps it in ``MMDistributedDataParallel``, then creates a
    ``Runner``. The runner gets the training hooks (using an fp16 optimizer
    hook when ``cfg.fp16`` is set, otherwise ``DistOptimizerHook``), a
    ``DistSamplerSeedHook`` and, when ``validate`` is true, one distributed
    evaluation hook. Finally it resumes (``cfg.resume_from`` takes priority)
    or loads (``cfg.load_from``) a checkpoint if configured, and runs
    ``cfg.workflow`` for ``cfg.total_epochs`` epochs.

    Args:
        model: Detector to train.
        dataset: A dataset, or a list/tuple of datasets.
        cfg: Training config.
        validate (bool): Whether to register an evaluation hook built from
            ``cfg.data.val`` and ``cfg.evaluation``.
    """
    # prepare data loaders
    dataset_seq = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    loaders = [
        build_dataloader(item, cfg.data.imgs_per_gpu,
                         cfg.data.workers_per_gpu, dist=True)
        for item in dataset_seq
    ]

    # put model on gpus
    ddp_model = MMDistributedDataParallel(model.cuda())

    # build runner
    runner = Runner(ddp_model, batch_processor,
                    build_optimizer(ddp_model, cfg.optimizer), cfg.work_dir,
                    cfg.log_level)

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is None:
        opt_hook = DistOptimizerHook(**cfg.optimizer_config)
    else:
        opt_hook = Fp16OptimizerHook(**cfg.optimizer_config, **fp16_cfg)

    # register hooks (order is kept: training hooks, seed hook, eval hook)
    runner.register_training_hooks(cfg.lr_config, opt_hook,
                                   cfg.checkpoint_config, cfg.log_config)
    runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        val_cfg = cfg.data.val
        eval_kwargs = cfg.get('evaluation', {})
        if isinstance(ddp_model.module, RPN):
            # TODO: implement recall hooks for other datasets
            eval_hook_cls = CocoDistEvalRecallHook
        elif issubclass(DATASETS.get(val_cfg.type), datasets.CocoDataset):
            eval_hook_cls = CocoDistEvalmAPHook
        else:
            eval_hook_cls = DistEvalmAPHook
        runner.register_hook(eval_hook_cls(val_cfg, **eval_kwargs))

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(loaders, cfg.workflow, cfg.total_epochs)


def _non_dist_train(model, dataset, cfg, validate=False):
    """Non-distributed training routine used by ``train_detector``.

    Builds one non-distributed data loader per dataset, wraps the model in
    ``MMDataParallel`` over ``range(cfg.gpus)`` and moves it to the GPU, then
    creates a ``Runner`` with its training hooks (using an fp16 optimizer
    hook with ``distributed=False`` when ``cfg.fp16`` is set). Finally it
    resumes (``cfg.resume_from`` takes priority) or loads (``cfg.load_from``)
    a checkpoint if configured, and runs ``cfg.workflow`` for
    ``cfg.total_epochs`` epochs.

    Args:
        model: Detector to train.
        dataset: A dataset, or a list/tuple of datasets.
        cfg: Training config.
        validate (bool): No evaluation hook exists on this path. Passing
            ``True`` emits a ``UserWarning`` instead of being silently
            ignored; training still proceeds without validation.
    """
    import warnings

    if validate:
        # Evaluation hooks are only registered by the distributed routine;
        # surface the ignored flag rather than dropping it silently.
        warnings.warn(
            'validate=True is ignored: built-in validation is not '
            'implemented for non-distributed training. Use distributed '
            'training or evaluate saved checkpoints separately.')

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.imgs_per_gpu,
            cfg.data.workers_per_gpu,
            cfg.gpus,
            dist=False) for ds in dataset
    ]
    # put model on gpus
    model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
    runner = Runner(model, batch_processor, optimizer, cfg.work_dir,
                    cfg.log_level)

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=False)
    else:
        # passed through as a plain config; presumably the Runner builds the
        # optimizer hook from it — confirm against mmcv
        optimizer_config = cfg.optimizer_config
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)