train.py
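"""Train a detector.

Usage sketch (paths and GPU counts are illustrative, not part of this file):

    # single-machine, non-distributed training
    python train.py <config-file> --gpus 8 --validate

    # distributed training; torch.distributed.launch sets --local_rank
    python -m torch.distributed.launch --nproc_per_node=8 \
        train.py <config-file> --launcher pytorch
"""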
from __future__ import division

import argparse
from collections import OrderedDict

import torch
from mmcv import Config
from mmcv.torchpack import Runner, obj_from_dict

from mmdet import datasets
from mmdet.core import (init_dist, DistOptimizerHook, DistSamplerSeedHook,
                        MMDataParallel, MMDistributedDataParallel,
                        DistEvalRecallHook, CocoDistEvalmAPHook)
from mmdet.datasets.loader import build_dataloader
from mmdet.models import build_detector, RPN


def parse_losses(losses):
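    """Reduce raw model losses to a scalar loss and scalar log values.

    Each value in ``losses`` may be a tensor or a list of tensors; every
    entry is reduced with ``mean()``. Keys containing 'loss' are summed
    into the total loss used for the backward pass.
    """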
    log_vars = OrderedDict()
    for loss_name, loss_value in losses.items():
        if isinstance(loss_value, torch.Tensor):
            log_vars[loss_name] = loss_value.mean()
        elif isinstance(loss_value, list):
            log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
        else:
            raise TypeError(
                '{} is not a tensor or list of tensors'.format(loss_name))

    loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)

    log_vars['loss'] = loss
    for name in log_vars:
        log_vars[name] = log_vars[name].item()

    return loss, log_vars


def batch_processor(model, data, train_mode):
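    """Run one step on a batch and pack the outputs for the Runner.

    Returns a dict with the total ``loss``, scalar ``log_vars`` for logging,
    and ``num_samples`` (the batch size) used for averaging.
    """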
    losses = model(**data)
    loss, log_vars = parse_losses(losses)

    outputs = dict(
        loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))

    return outputs


def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--validate',
        action='store_true',
        help='whether to add a validate phase')
    parser.add_argument(
        '--gpus', type=int, default=1, help='number of gpus to use')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
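    # torch.distributed.launch supplies --local_rank; accept it so parsing
    # does not fail in distributed runs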
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()

    return args


def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    cfg.update(gpus=args.gpus)

    # init distributed environment if necessary
    if args.launcher == 'none':
        dist = False
        print('Disabled distributed training.')
    else:
        dist = True
        print('Enabled distributed training.')
        init_dist(args.launcher, **cfg.dist_params)

    # prepare data loaders
    train_dataset = obj_from_dict(cfg.data.train, datasets)
    data_loaders = [
        build_dataloader(train_dataset, cfg.data.imgs_per_gpu,
                         cfg.data.workers_per_gpu, cfg.gpus, dist)
    ]
    if args.validate:
        val_dataset = obj_from_dict(cfg.data.val, datasets)
        data_loaders.append(
            build_dataloader(val_dataset, cfg.data.imgs_per_gpu,
                             cfg.data.workers_per_gpu, cfg.gpus, dist))

    # build model
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    if dist:
        model = MMDistributedDataParallel(model.cuda())
    else:
        model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()

    # build runner
    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
                    cfg.log_level)
    # register hooks
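    # distributed runs wrap the optimizer step in DistOptimizerHook;
    # otherwise the raw optimizer config dict is used as-is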
    optimizer_config = DistOptimizerHook(
        **cfg.optimizer_config) if dist else cfg.optimizer_config
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)
    if dist:
        runner.register_hook(DistSamplerSeedHook())
        # register eval hooks
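        # RPN-only models are evaluated by proposal recall, COCO detectors
        # by mAP, matching the two hooks imported from mmdet.core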
        if isinstance(model.module, RPN):
            runner.register_hook(DistEvalRecallHook(cfg.data.val))
        elif cfg.data.val.type == 'CocoDataset':
            runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)


if __name__ == '__main__':
    main()