from __future__ import division

from collections import OrderedDict

import torch
from mmcv.runner import Runner, DistSamplerSeedHook
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel

from mmdet import datasets
from mmdet.core import (DistOptimizerHook, DistEvalmAPHook,
                        CocoDistEvalRecallHook, CocoDistEvalmAPHook)
from mmdet.datasets import build_dataloader
from mmdet.models import RPN
from .env import get_root_logger


def parse_losses(losses):
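    """Average raw model losses and collect scalar values for logging.

    Each loss item may be a single tensor or a list of tensors (e.g. one
    per feature level); list entries are averaged and summed. Every key
    containing 'loss' contributes to the total loss returned for backprop.
    """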
    log_vars = OrderedDict()
    for loss_name, loss_value in losses.items():
        if isinstance(loss_value, torch.Tensor):
            log_vars[loss_name] = loss_value.mean()
        elif isinstance(loss_value, list):
            log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
        else:
            raise TypeError(
                '{} is not a tensor or list of tensors'.format(loss_name))

    loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)

    log_vars['loss'] = loss
    for name in log_vars:
        log_vars[name] = log_vars[name].item()

    return loss, log_vars


def batch_processor(model, data, train_mode):
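    """Run one forward pass and pack the outputs for mmcv's Runner.

    `train_mode` is part of the batch_processor interface expected by
    Runner but is unused here; `num_samples` lets the logger average
    metrics over the batch size.
    """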
    losses = model(**data)
    loss, log_vars = parse_losses(losses)

    outputs = dict(
        loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))

    return outputs


def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   logger=None):
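    """Train a detector, dispatching to the distributed or non-distributed
    training loop below."""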
    if logger is None:
        logger = get_root_logger(cfg.log_level)

    # start training
    if distributed:
        _dist_train(model, dataset, cfg, validate=validate)
    else:
        _non_dist_train(model, dataset, cfg, validate=validate)


def _dist_train(model, dataset, cfg, validate=False):
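    """Distributed training: wraps the model in MMDistributedDataParallel
    and registers distributed optimizer, sampler-seed and eval hooks."""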
    # prepare data loaders
    data_loaders = [
        build_dataloader(
            dataset,
            cfg.data.imgs_per_gpu,
            cfg.data.workers_per_gpu,
            dist=True)
    ]
    # put model on gpus
    model = MMDistributedDataParallel(model.cuda())
    # build runner
    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
                    cfg.log_level)
    # register hooks
    optimizer_config = DistOptimizerHook(**cfg.optimizer_config)
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)
    runner.register_hook(DistSamplerSeedHook())
    # register eval hooks
    if validate:
        val_dataset_cfg = cfg.data.val
        if isinstance(model.module, RPN):
            # TODO: implement recall hooks for other datasets
            runner.register_hook(CocoDistEvalRecallHook(val_dataset_cfg))
        else:
            dataset_type = getattr(datasets, val_dataset_cfg.type)
            if issubclass(dataset_type, datasets.CocoDataset):
                runner.register_hook(CocoDistEvalmAPHook(val_dataset_cfg))
            else:
                runner.register_hook(DistEvalmAPHook(val_dataset_cfg))

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)


def _non_dist_train(model, dataset, cfg, validate=False):
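    """Single-process training: wraps the model in MMDataParallel across
    `cfg.gpus` GPUs."""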
    # prepare data loaders
    data_loaders = [
        build_dataloader(
            dataset,
            cfg.data.imgs_per_gpu,
            cfg.data.workers_per_gpu,
            cfg.gpus,
            dist=False)
    ]
    # put model on gpus
    model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
    # build runner
    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
                    cfg.log_level)
    runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
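

# A minimal usage sketch. The config path and wiring below are illustrative
# only (the real entry point is a driver script such as tools/train.py, and
# names like get_dataset follow the mmdet API of this era):
#
#     from mmcv import Config
#     from mmdet.models import build_detector
#     from mmdet.datasets import get_dataset
#
#     cfg = Config.fromfile('configs/faster_rcnn_r50_fpn_1x.py')
#     model = build_detector(cfg.model, train_cfg=cfg.train_cfg,
#                            test_cfg=cfg.test_cfg)
#     dataset = get_dataset(cfg.data.train)
#     train_detector(model, dataset, cfg, distributed=False, validate=True)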