train.py 5.68 KB
Newer Older
zhangwenwei's avatar
zhangwenwei committed
1
2
3
from __future__ import division
import argparse
import copy
4
import logging
zhangwenwei's avatar
zhangwenwei committed
5
6
7
8
9
10
import os
import os.path as osp
import time

import mmcv
import torch
zww's avatar
zww committed
11
from mmcv import Config, DictAction
zhangwenwei's avatar
zhangwenwei committed
12
13
14
15
16
from mmcv.runner import init_dist

from mmdet3d import __version__
from mmdet3d.datasets import build_dataset
from mmdet3d.models import build_detector
zhangwenwei's avatar
zhangwenwei committed
17
from mmdet3d.utils import collect_env, get_root_logger
zhangwenwei's avatar
zhangwenwei committed
18
from mmdet.apis import set_random_seed, train_detector
zhangwenwei's avatar
zhangwenwei committed
19
20
21
22
23


def parse_args():
    """Parse the command-line arguments of the training script.

    Besides parsing, this function exports ``LOCAL_RANK`` into
    ``os.environ`` (taken from ``--local_rank``) when that variable is not
    already set.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')
    parser.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')
    # ``--gpus`` and ``--gpu-ids`` are two ways of selecting devices, so
    # only one of them may be given at a time.
    group_gpus = parser.add_mutually_exclusive_group()
    group_gpus.add_argument(
        '--gpus',
        type=int,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='ids of gpus to use '
        '(only applicable to non-distributed training)')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--options', nargs='+', action=DictAction, help='arguments in dict')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    # PyTorch >= 2.0 launchers pass ``--local-rank`` while older ones pass
    # ``--local_rank``. Accept both spellings; argparse derives the dest
    # from the first long option, so it is still ``args.local_rank``.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    parser.add_argument(
        '--autoscale-lr',
        action='store_true',
        help='automatically scale lr with the number of gpus')
    args = parser.parse_args()
    # Do not override a LOCAL_RANK that the launcher already exported.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args


def main():
    """Entry point: build the config, model and datasets, then train.

    The steps are, in order: parse CLI arguments, load the config file and
    merge ``--options`` into it, resolve ``work_dir`` / ``resume_from`` /
    ``gpu_ids``, optionally rescale the learning rate, initialise the
    distributed environment, set up logging, seed the RNGs, build the
    detector and dataset(s), and finally call ``train_detector``.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.options is not None:
        cfg.merge_from_dict(args.options)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    # explicit ids win; otherwise use the first ``--gpus`` devices
    # (a single device when neither option is given)
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    if args.autoscale_lr:
        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
        # the divisor 8 is presumably the GPU count the configured lr was
        # tuned for -- TODO confirm against the configs
        cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # add a logging filter
    # A filter callable receives a ``logging.LogRecord``, which has no
    # ``find`` method; match on the record's logger name instead.
    # NOTE(review): this filter is never attached to a logger or handler
    # in this function, so it currently has no effect -- confirm whether
    # it should be registered or removed.
    logging_filter = logging.Filter('mmdet')
    logging_filter.filter = lambda record: record.name.find('mmdet') != -1

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info

    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds
    if args.seed is not None:
        logger.info(f'Set random seed to {args.seed}, '
                    f'deterministic: {args.deterministic}')
        set_random_seed(args.seed, deterministic=args.deterministic)
    cfg.seed = args.seed
    meta['seed'] = args.seed

    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    logger.info(f'Model:\n{model}')
    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        # a two-stage workflow also needs the val split; it is built with
        # the training pipeline rather than its own
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmdet version, config file content and class names in
        # checkpoints as meta data
        # NOTE(review): ``__version__`` is imported from ``mmdet3d`` although
        # the key is named ``mmdet_version`` -- confirm this is intended.
        cfg.checkpoint_config.meta = dict(
            mmdet_version=__version__,
            config=cfg.pretty_text,
            CLASSES=datasets[0].CLASSES)
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
    train_detector(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)


# Run training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()