train.py 4.5 KB
Newer Older
pangjm's avatar
pangjm committed
1
from __future__ import division
Kai Chen's avatar
Kai Chen committed
2

pangjm's avatar
pangjm committed
3
import argparse
4
import logging
Kai Chen's avatar
Kai Chen committed
5
from collections import OrderedDict
pangjm's avatar
pangjm committed
6
7
8

import torch
from mmcv import Config
Kai Chen's avatar
Kai Chen committed
9
10
11
from mmcv.torchpack import Runner, obj_from_dict

from mmdet import datasets
12
from mmdet.core import (init_dist, DistOptimizerHook, DistSamplerSeedHook,
Kai Chen's avatar
Kai Chen committed
13
                        MMDataParallel, MMDistributedDataParallel,
14
                        CocoDistEvalRecallHook, CocoDistEvalmAPHook)
Kai Chen's avatar
Kai Chen committed
15
from mmdet.datasets.loader import build_dataloader
Kai Chen's avatar
Kai Chen committed
16
from mmdet.models import build_detector, RPN
Kai Chen's avatar
Kai Chen committed
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38


def parse_losses(losses):
    """Reduce raw model losses to one total loss plus loggable scalars.

    Args:
        losses (dict): maps a name to either a tensor or a list of tensors.

    Returns:
        tuple: ``(loss, log_vars)``. ``loss`` is the sum of every reduced
            entry whose name contains ``'loss'``. ``log_vars`` is an
            ``OrderedDict`` holding each reduced entry, plus the total under
            the key ``'loss'``, each converted with ``.item()``.

    Raises:
        TypeError: if a value is neither a tensor nor a list.
    """
    reduced = OrderedDict()
    for key, value in losses.items():
        if isinstance(value, torch.Tensor):
            reduced[key] = value.mean()
        elif isinstance(value, list):
            # A list contributes the sum of the means of its members.
            subtotal = 0
            for part in value:
                subtotal = subtotal + part.mean()
            reduced[key] = subtotal
        else:
            raise TypeError(
                '{} is not a tensor or list of tensors'.format(key))

    # Only entries whose name contains 'loss' count towards the total;
    # other entries are kept for logging only.
    loss = 0
    for key, value in reduced.items():
        if 'loss' in key:
            loss = loss + value
    reduced['loss'] = loss

    log_vars = OrderedDict(
        (key, value.item()) for key, value in reduced.items())
    return loss, log_vars


Kai Chen's avatar
Kai Chen committed
39
def batch_processor(model, data, train_mode):
    """Run the model on one batch and package the losses.

    Args:
        model: callable invoked as ``model(**data)``; its return value is
            handed to ``parse_losses``.
        data (dict): keyword arguments for ``model``. It must contain an
            ``'img'`` entry whose ``.data`` attribute supports ``len()``.
        train_mode: accepted but not read by this function.
            NOTE(review): presumably required by the runner's callback
            signature -- confirm against the caller.

    Returns:
        dict: ``loss`` (total loss from ``parse_losses``), ``log_vars``
            (scalar values for logging) and ``num_samples``
            (``len(data['img'].data)``).
    """
    losses = model(**data)
    loss, log_vars = parse_losses(losses)

    outputs = dict(
        loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))

    return outputs
pangjm's avatar
pangjm committed
47
48


49
50
51
52
53
54
55
def get_logger(log_level):
    """Configure the logging module and return the root logger.

    Args:
        log_level: level passed to ``logging.basicConfig``.

    Returns:
        logging.Logger: the root logger.
    """
    log_format = '%(asctime)s - %(levelname)s - %(message)s'
    logging.basicConfig(level=log_level, format=log_format)
    return logging.getLogger()


pangjm's avatar
pangjm committed
56
def parse_args():
    """Parse the command-line arguments of the training script.

    Returns:
        argparse.Namespace: with the attributes ``config`` (positional,
            config file path), ``work_dir`` (``None`` unless given),
            ``validate`` (bool flag), ``gpus`` (int, default 1),
            ``launcher`` (one of ``'none'``, ``'pytorch'``, ``'slurm'``,
            ``'mpi'``; default ``'none'``) and ``local_rank`` (int,
            default 0).
    """
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work_dir', help='the dir to save logs and models')
    parser.add_argument(
        '--validate',
        action='store_true',
        help='whether to add a validate phase')
    parser.add_argument(
        '--gpus', type=int, default=1, help='number of gpus to use')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()

    return args


def main():
    """Train a detector as described by the command line and config file.

    Steps, in order: parse the arguments, load the config (``--work_dir``
    and ``--gpus`` from the command line are written into it), set up
    logging, optionally initialise distributed training, build the data
    loaders and the model, create the runner, register its hooks, restore a
    checkpoint if the config names one, and start the run.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)
    # Command-line values are written into the loaded config.
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    cfg.gpus = args.gpus

    logger = get_logger(cfg.log_level)

    # init distributed environment if necessary
    if args.launcher == 'none':
        dist = False
        logger.info('Disabled distributed training.')
    else:
        dist = True
        init_dist(args.launcher, **cfg.dist_params)
        # Every process except rank 0 has its logger raised to ERROR level.
        if torch.distributed.get_rank() != 0:
            logger.setLevel('ERROR')
        logger.info('Enabled distributed training.')

    # prepare data loaders
    train_dataset = obj_from_dict(cfg.data.train, datasets)
    data_loaders = [
        build_dataloader(train_dataset, cfg.data.imgs_per_gpu,
                         cfg.data.workers_per_gpu, cfg.gpus, dist)
    ]
    if args.validate:
        # The validation loader is built with the same batch and worker
        # settings as the training loader.
        val_dataset = obj_from_dict(cfg.data.val, datasets)
        data_loaders.append(
            build_dataloader(val_dataset, cfg.data.imgs_per_gpu,
                             cfg.data.workers_per_gpu, cfg.gpus, dist))

    # build model
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    if dist:
        model = MMDistributedDataParallel(model.cuda())
    else:
        model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()

    # build runner
    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
                    cfg.log_level)
    # register hooks
    # In distributed mode the optimizer settings are wrapped in a
    # DistOptimizerHook; otherwise the raw config is passed through.
    optimizer_config = DistOptimizerHook(
        **cfg.optimizer_config) if dist else cfg.optimizer_config
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)
    if dist:
        runner.register_hook(DistSamplerSeedHook())
        # register eval hooks
        # These are registered only in distributed mode, and independently
        # of the --validate flag.
        if isinstance(model.module, RPN):
            runner.register_hook(CocoDistEvalRecallHook(cfg.data.val))
        elif cfg.data.val.type == 'CocoDataset':
            runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))

    # resume_from is checked before load_from; at most one is applied.
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
pangjm's avatar
pangjm committed
139
140


Kai Chen's avatar
Kai Chen committed
141
# Run the training entry point only when this file is executed as a script.
if __name__ == '__main__':
    main()