train.py 1.52 KB
Newer Older
pangjm's avatar
pangjm committed
1
from __future__ import division
Kai Chen's avatar
Kai Chen committed
2

pangjm's avatar
pangjm committed
3
4
import argparse
from mmcv import Config
myownskyW7's avatar
myownskyW7 committed
5
from mmcv.runner import obj_from_dict
6

myownskyW7's avatar
myownskyW7 committed
7
8
9
from mmdet import datasets
from mmdet.api import train_detector
from mmdet.models import build_detector
Kai Chen's avatar
Kai Chen committed
10
11


pangjm's avatar
pangjm committed
12
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv (list[str], optional): Argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so a
            plain ``parse_args()`` call behaves as before.

    Returns:
        argparse.Namespace: Parsed options with the attributes ``config``,
            ``work_dir``, ``validate``, ``gpus``, ``seed``, ``launcher``
            and ``local_rank``.
    """
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work_dir', help='the dir to save logs and models')
    parser.add_argument(
        '--validate',
        action='store_true',
        help='whether to add a validate phase')
    parser.add_argument(
        '--gpus', type=int, default=1, help='number of gpus to use')
    # No default is given, so ``seed`` is None unless passed explicitly.
    parser.add_argument('--seed', type=int, help='random seed')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args(argv)

    return args


def main():
    """Train a detector from a config file plus command-line overrides.

    Loads the config file named on the command line, copies the parsed
    command-line options onto it, builds the detector and the training
    dataset from that config, and passes all three to ``train_detector``.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)
    # Keep the work_dir from the config file unless one was given explicitly.
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    # The remaining options are always copied onto the config, overwriting
    # any value of the same name that the config file may have set.
    cfg.validate = args.validate
    cfg.gpus = args.gpus
    cfg.seed = args.seed
    cfg.launcher = args.launcher
    cfg.local_rank = args.local_rank

    # build model
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

    # build the training dataset from cfg.data.train, then start training
    train_dataset = obj_from_dict(cfg.data.train, datasets)
    train_detector(model, train_dataset, cfg)
pangjm's avatar
pangjm committed
49
50


Kai Chen's avatar
Kai Chen committed
51
# Start training only when this file is executed as a script.
if __name__ == '__main__':
    main()