from __future__ import division

import argparse
from collections import OrderedDict

import torch
from mmcv import Config
from mmcv.torchpack import Runner, obj_from_dict

from mmdet import datasets
from mmdet.core import (init_dist, DistOptimizerHook, DistSamplerSeedHook,
                        MMDataParallel, MMDistributedDataParallel)
from mmdet.datasets.loader import build_dataloader
from mmdet.models import build_detector


def parse_losses(losses):
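    """Reduce raw model losses to scalars for logging.

    Each value in ``losses`` may be a tensor or a list of tensors; every
    entry is averaged, all keys containing 'loss' are summed into the total
    loss, and the logged values are converted to Python floats.
    """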
    log_vars = OrderedDict()
    for loss_name, loss_value in losses.items():
        if isinstance(loss_value, torch.Tensor):
            log_vars[loss_name] = loss_value.mean()
        elif isinstance(loss_value, list):
            log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
        else:
            raise TypeError(
                '{} is not a tensor or list of tensors'.format(loss_name))

    loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)

    log_vars['loss'] = loss
    for name in log_vars:
        log_vars[name] = log_vars[name].item()

    return loss, log_vars


def batch_processor(model, data, train_mode):
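    """Forward one batch through the model and package the outputs for the
    mmcv Runner: the total ``loss``, the ``log_vars`` produced by
    ``parse_losses``, and ``num_samples`` taken from the batched ``img``
    field for logging.
    """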
    losses = model(**data)
    loss, log_vars = parse_losses(losses)

    outputs = dict(
        loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))

    return outputs


def parse_args():
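    """Parse command-line arguments: config file path, optional validation
    phase, number of GPUs, and the distributed job launcher."""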
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--validate',
        action='store_true',
        help='whether to add a validate phase')
    parser.add_argument(
        '--gpus', type=int, default=1, help='number of gpus to use')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()

    return args


def main():
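    """Entry point: load the config, optionally initialize distributed
    training, build the data loaders and detector, then launch the mmcv
    Runner for the configured workflow."""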
    args = parse_args()

    cfg = Config.fromfile(args.config)
    cfg.update(gpus=args.gpus)

    # init distributed environment if necessary
    if args.launcher == 'none':
        dist = False
        print('Disabled distributed training.')
    else:
        dist = True
        print('Enabled distributed training.')
        init_dist(args.launcher, **cfg.dist_params)

    # prepare data loaders
    train_dataset = obj_from_dict(cfg.data.train, datasets)
    data_loaders = [
        build_dataloader(train_dataset, cfg.data.imgs_per_gpu,
                         cfg.data.workers_per_gpu, cfg.gpus, dist)
    ]
    if args.validate:
        val_dataset = obj_from_dict(cfg.data.val, datasets)
        data_loaders.append(
            build_dataloader(val_dataset, cfg.data.imgs_per_gpu,
                             cfg.data.workers_per_gpu, cfg.gpus, dist))

    # build model
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    if dist:
        model = MMDistributedDataParallel(model.cuda())
    else:
        model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()

    # build runner
    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
                    cfg.log_level)
    # register hooks
    optimizer_config = DistOptimizerHook(
        **cfg.optimizer_config) if dist else cfg.optimizer_config
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)
    if dist:
        runner.register_hook(DistSamplerSeedHook())

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)


if __name__ == '__main__':
    main()
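
# Example invocations, matching the arguments defined in parse_args()
# (the config path below is only a placeholder, not a real file):
#
#   single machine, 1 GPU, with a validation phase:
#     python train.py configs/some_detector.py --gpus 1 --validate
#
#   distributed training, e.g. 8 GPUs via the PyTorch launcher
#   (which supplies --local_rank to each process):
#     python -m torch.distributed.launch --nproc_per_node=8 \
#         train.py configs/some_detector.py --launcher pytorch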