import logging
import random
import re
from collections import OrderedDict

import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, Runner, get_dist_info,
                         obj_from_dict)

from mmdet import datasets
from mmdet.core import (CocoDistEvalmAPHook, CocoDistEvalRecallHook,
                        DistEvalmAPHook, DistOptimizerHook, Fp16OptimizerHook)
from mmdet.datasets import DATASETS, build_dataloader
from mmdet.models import RPN


def set_random_seed(seed):
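    """Set the random seed for Python, NumPy and PyTorch (including all CUDA
    devices) so that training runs are reproducible.

    Args:
        seed (int): The seed to use.
    """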
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def get_root_logger(log_file=None, log_level=logging.INFO):
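    """Get the root logger of mmdet, initializing it on first use.

    Only the rank-0 process logs at ``log_level`` (and, if ``log_file`` is
    given, also to that file); all other ranks are restricted to ERROR level
    to avoid duplicated output in distributed training.

    Args:
        log_file (str, optional): Path of an extra log file written by rank 0.
        log_level (int): Logging level for the rank-0 process.

    Returns:
        logging.Logger: The root logger.
    """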
    logger = logging.getLogger('mmdet')
    # if the logger has been initialized, just return it
    if logger.hasHandlers():
        return logger

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(message)s', level=log_level)
    rank, _ = get_dist_info()
    if rank != 0:
        logger.setLevel('ERROR')
    elif log_file is not None:
        file_handler = logging.FileHandler(log_file, 'w')
        file_handler.setFormatter(
            logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
        file_handler.setLevel(log_level)
        logger.addHandler(file_handler)

    return logger


def parse_losses(losses):
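    """Combine the raw losses returned by a model into a total loss plus
    scalar values for logging.

    Each value in ``losses`` may be a tensor or a list of tensors, and every
    key containing ``'loss'`` contributes to the total loss. During
    distributed training each logged value is averaged over all processes.

    Returns:
        tuple: ``(loss, log_vars)``, the total loss tensor and an
            ``OrderedDict`` of scalar values to log.
    """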
    log_vars = OrderedDict()
    for loss_name, loss_value in losses.items():
        if isinstance(loss_value, torch.Tensor):
            log_vars[loss_name] = loss_value.mean()
        elif isinstance(loss_value, list):
            log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
        else:
            raise TypeError(
                '{} is not a tensor or list of tensors'.format(loss_name))

    loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)

    log_vars['loss'] = loss
    for loss_name, loss_value in log_vars.items():
        # reduce loss when distributed training
        if dist.is_initialized():
            loss_value = loss_value.data.clone()
            dist.all_reduce(loss_value.div_(dist.get_world_size()))
        log_vars[loss_name] = loss_value.item()

    return loss, log_vars


def batch_processor(model, data, train_mode):
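    """Run the model on one batch and pack the results for mmcv's Runner.

    ``train_mode`` is unused here but is part of the Runner interface.

    Returns:
        dict: ``loss`` (total loss), ``log_vars`` (scalars to log) and
            ``num_samples`` (the batch size of this iteration).
    """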
    losses = model(**data)
    loss, log_vars = parse_losses(losses)

    outputs = dict(
        loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))

    return outputs


def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None):
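    """Train a detector, dispatching to distributed or non-distributed
    training according to ``distributed``.

    Args:
        model (nn.Module): The detector to train.
        dataset (Dataset | list[Dataset]): Training dataset(s), one per
            workflow stage.
        cfg: The training config.
        distributed (bool): Whether to launch distributed training.
        validate (bool): Whether to register evaluation hooks (only
            supported together with distributed training).
        timestamp (str, optional): Timestamp used to name the log files.
    """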
    logger = get_root_logger(log_level=cfg.log_level)

    # start training
    if distributed:
        _dist_train(
            model,
            dataset,
            cfg,
            validate=validate,
            logger=logger,
            timestamp=timestamp)
    else:
        _non_dist_train(
            model,
            dataset,
            cfg,
            validate=validate,
            logger=logger,
            timestamp=timestamp)


def build_optimizer(model, optimizer_cfg):
    """Build optimizer from configs.

    Args:
        model (:obj:`nn.Module`): The model with parameters to be optimized.
        optimizer_cfg (dict): The config dict of the optimizer.
            Positional fields are:
                - type: class name of the optimizer.
                - lr: base learning rate.
            Optional fields are:
                - any arguments of the corresponding optimizer type, e.g.,
                  weight_decay, momentum, etc.
                - paramwise_options: a dict with 3 accepted fields
                  (bias_lr_mult, bias_decay_mult, norm_decay_mult).
                  `bias_lr_mult` and `bias_decay_mult` will be multiplied to
                  the lr and weight decay respectively for all bias parameters
                  (except for the normalization layers), and
                  `norm_decay_mult` will be multiplied to the weight decay
                  for all weight and bias parameters of normalization layers.

    Returns:
        torch.optim.Optimizer: The initialized optimizer.

    Example:
        >>> model = torch.nn.modules.Conv1d(1, 1, 1)
        >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
        >>>                      weight_decay=0.0001)
        >>> optimizer = build_optimizer(model, optimizer_cfg)
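        >>> # A hypothetical sketch of per-parameter options: biases get a
        >>> # doubled learning rate and norm layers get no weight decay.
        >>> optimizer_cfg = dict(
        >>>     type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001,
        >>>     paramwise_options=dict(bias_lr_mult=2., norm_decay_mult=0.))
        >>> optimizer = build_optimizer(model, optimizer_cfg)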
    """
    if hasattr(model, 'module'):
        model = model.module

    optimizer_cfg = optimizer_cfg.copy()
    paramwise_options = optimizer_cfg.pop('paramwise_options', None)
    # if no paramwise option is specified, just use the global setting
    if paramwise_options is None:
        return obj_from_dict(optimizer_cfg, torch.optim,
                             dict(params=model.parameters()))
    else:
        assert isinstance(paramwise_options, dict)
        # get base lr and weight decay
        base_lr = optimizer_cfg['lr']
        base_wd = optimizer_cfg.get('weight_decay', None)
        # weight_decay must be explicitly specified if mult is specified
        if ('bias_decay_mult' in paramwise_options
                or 'norm_decay_mult' in paramwise_options):
            assert base_wd is not None
        # get param-wise options
        bias_lr_mult = paramwise_options.get('bias_lr_mult', 1.)
        bias_decay_mult = paramwise_options.get('bias_decay_mult', 1.)
        norm_decay_mult = paramwise_options.get('norm_decay_mult', 1.)
        # set param-wise lr and weight decay
        params = []
        for name, param in model.named_parameters():
            param_group = {'params': [param]}
            if not param.requires_grad:
                # FP16 training copies gradients/weights between the master
                # (fp32) weight copy and the model weights, so it is
                # convenient to keep all parameters here to align with
                # model.parameters()
                params.append(param_group)
                continue

            # for norm layers, overwrite the weight decay of weight and bias
            # TODO: obtain the norm layer prefixes dynamically
            if re.search(r'(bn|gn)(\d+)?.(weight|bias)', name):
                if base_wd is not None:
                    param_group['weight_decay'] = base_wd * norm_decay_mult
            # for other layers, overwrite both lr and weight decay of bias
            elif name.endswith('.bias'):
                param_group['lr'] = base_lr * bias_lr_mult
                if base_wd is not None:
                    param_group['weight_decay'] = base_wd * bias_decay_mult
            # otherwise use the global settings

            params.append(param_group)

        optimizer_cls = getattr(torch.optim, optimizer_cfg.pop('type'))
        return optimizer_cls(params, **optimizer_cfg)


def _dist_train(model,
                dataset,
                cfg,
                validate=False,
                logger=None,
                timestamp=None):
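    """Distributed training: build the data loaders, wrap the model with
    MMDistributedDataParallel, set up the runner with its hooks, then run
    the configured workflow.
    """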
    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    data_loaders = [
        build_dataloader(
            ds, cfg.data.imgs_per_gpu, cfg.data.workers_per_gpu, dist=True)
        for ds in dataset
    ]
    # put model on gpus
    model = MMDistributedDataParallel(model.cuda())

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
    runner = Runner(
        model, batch_processor, optimizer, cfg.work_dir, logger=logger)
    # an ugly workaround to make the .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
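        # Fp16OptimizerHook replaces the plain optimizer step with a
        # mixed-precision one (loss scaling with fp32 master weights).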
        optimizer_config = Fp16OptimizerHook(**cfg.optimizer_config,
                                             **fp16_cfg)
    else:
        optimizer_config = DistOptimizerHook(**cfg.optimizer_config)

    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)
    runner.register_hook(DistSamplerSeedHook())
    # register eval hooks
    if validate:
        val_dataset_cfg = cfg.data.val
        eval_cfg = cfg.get('evaluation', {})
        if isinstance(model.module, RPN):
            # TODO: implement recall hooks for other datasets
            runner.register_hook(
                CocoDistEvalRecallHook(val_dataset_cfg, **eval_cfg))
        else:
            dataset_type = DATASETS.get(val_dataset_cfg.type)
            if issubclass(dataset_type, datasets.CocoDataset):
                runner.register_hook(
                    CocoDistEvalmAPHook(val_dataset_cfg, **eval_cfg))
            else:
                runner.register_hook(
                    DistEvalmAPHook(val_dataset_cfg, **eval_cfg))

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)


def _non_dist_train(model,
                    dataset,
                    cfg,
                    validate=False,
                    logger=None,
                    timestamp=None):
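    """Single-machine (non-distributed) training with MMDataParallel.

    Built-in validation is not supported here; use distributed training or
    the test/eval scripts instead.
    """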
    if validate:
        raise NotImplementedError('Built-in validation is not implemented '
                                  'yet in non-distributed training. Use '
                                  'distributed training or test.py and '
                                  '*eval.py scripts instead.')
    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.imgs_per_gpu,
            cfg.data.workers_per_gpu,
            cfg.gpus,
            dist=False) for ds in dataset
    ]
    # put model on gpus
    model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)
    runner = Runner(
        model, batch_processor, optimizer, cfg.work_dir, logger=logger)
    # an ugly workaround to make the .log and .log.json filenames the same
    runner.timestamp = timestamp
    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=False)
    else:
        optimizer_config = cfg.optimizer_config
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)