from __future__ import division

import argparse
import logging
import random
from collections import OrderedDict

import numpy as np
import torch
from mmcv import Config
from mmcv.runner import Runner, obj_from_dict

from mmdet import datasets, __version__
from mmdet.core import (init_dist, DistOptimizerHook, DistSamplerSeedHook,
                        MMDataParallel, MMDistributedDataParallel,
                        CocoDistEvalRecallHook, CocoDistEvalmAPHook)
from mmdet.datasets import build_dataloader
from mmdet.models import build_detector, RPN


def parse_losses(losses):
    """Reduce raw model losses to a scalar total plus loggable numbers.

    Args:
        losses (dict): maps a name to either a tensor or a list of tensors.

    Returns:
        tuple: ``(loss, log_vars)``. ``loss`` is the sum of every reduced
            entry whose name contains ``'loss'``. ``log_vars`` is an
            ``OrderedDict`` holding each reduced entry, plus the total under
            the key ``'loss'``, all converted with ``.item()``.

    Raises:
        TypeError: if a value is neither a tensor nor a list.
    """
    reduced = OrderedDict()
    for key, value in losses.items():
        if isinstance(value, list):
            # a list contributes the sum of its per-tensor means
            reduced[key] = sum(part.mean() for part in value)
        elif isinstance(value, torch.Tensor):
            reduced[key] = value.mean()
        else:
            raise TypeError(
                '{} is not a tensor or list of tensors'.format(key))

    # only entries whose name contains 'loss' count towards the total
    loss = sum(reduced[key] for key in reduced if 'loss' in key)

    reduced['loss'] = loss
    log_vars = OrderedDict(
        (key, value.item()) for key, value in reduced.items())

    return loss, log_vars


def batch_processor(model, data, train_mode):
    """Run one batch through ``model`` and package the result.

    Args:
        model: callable invoked as ``model(**data)``; its return value is
            handed to ``parse_losses``.
        data (dict): keyword arguments for ``model``. ``data['img'].data``
            must support ``len()``; that length is reported as the sample
            count.
        train_mode: accepted but not used by this function.

    Returns:
        dict: ``loss`` (the total returned by ``parse_losses``),
            ``log_vars`` (the per-entry values returned by ``parse_losses``)
            and ``num_samples``.
    """
    losses = model(**data)
    loss, log_vars = parse_losses(losses)

    outputs = dict(
        loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))

    return outputs


def get_logger(log_level):
    """Configure logging via ``logging.basicConfig`` and return the root logger.

    Args:
        log_level: level passed to ``logging.basicConfig``.

    Returns:
        logging.Logger: the logger returned by ``logging.getLogger()``.
    """
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(message)s', level=log_level)
    logger = logging.getLogger()
    return logger


def set_random_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed (int): value passed unchanged to ``random.seed``,
            ``np.random.seed``, ``torch.manual_seed`` and
            ``torch.cuda.manual_seed_all``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def parse_args():
    """Parse the command-line arguments of the training script.

    Returns:
        argparse.Namespace: with attributes ``config`` (positional),
            ``work_dir``, ``validate`` (flag), ``gpus`` (int, default 1),
            ``seed`` (int, default ``None``), ``launcher`` (one of
            ``none``/``pytorch``/``slurm``/``mpi``, default ``none``) and
            ``local_rank`` (int, default 0).
    """
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work_dir', help='the dir to save logs and models')
    parser.add_argument(
        '--validate',
        action='store_true',
        help='whether to add a validate phase')
    parser.add_argument(
        '--gpus', type=int, default=1, help='number of gpus to use')
    parser.add_argument('--seed', type=int, help='random seed')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()

    return args


def main():
    """Entry point: parse the CLI, build data/model/runner, and train.

    Steps, in order: load the config and apply the command-line overrides,
    set up logging and (if given) the random seed, initialise the distributed
    environment when a launcher other than ``none`` is selected, build the
    training data loader and the detector, create the runner and register
    its hooks, optionally resume from or load a checkpoint, and finally run
    the configured workflow.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    cfg.gpus = args.gpus
    # save mmdet version and the config text in checkpoint as meta data
    cfg.checkpoint_config.meta = dict(
        mmdet_version=__version__, config=cfg.text)

    logger = get_logger(cfg.log_level)

    # set random seed if specified
    if args.seed is not None:
        logger.info('Set random seed to {}'.format(args.seed))
        set_random_seed(args.seed)

    # init distributed environment if necessary
    if args.launcher == 'none':
        dist = False
        logger.info('Non-distributed training.')
    else:
        dist = True
        init_dist(args.launcher, **cfg.dist_params)
        # raise the log threshold on every process except rank 0
        if torch.distributed.get_rank() != 0:
            logger.setLevel('ERROR')
        logger.info('Distributed training.')

    # prepare data loaders
    train_dataset = obj_from_dict(cfg.data.train, datasets)
    data_loaders = [
        build_dataloader(train_dataset, cfg.data.imgs_per_gpu,
                         cfg.data.workers_per_gpu, cfg.gpus, dist)
    ]

    # build model
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    if dist:
        model = MMDistributedDataParallel(model.cuda())
    else:
        model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()

    # build runner
    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
                    cfg.log_level)

    # register hooks
    optimizer_config = DistOptimizerHook(
        **cfg.optimizer_config) if dist else cfg.optimizer_config
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)
    if dist:
        runner.register_hook(DistSamplerSeedHook())
        # register eval hooks; this branch is only reached in distributed
        # mode, so --validate has no effect when launcher is 'none'
        if args.validate:
            if isinstance(model.module, RPN):
                runner.register_hook(CocoDistEvalRecallHook(cfg.data.val))
            elif cfg.data.val.type == 'CocoDataset':
                runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))

    # resume_from takes precedence over load_from when both are set
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)


# Start training only when this file is executed as a script.
if __name__ == '__main__':
    main()