# Copyright (c) OpenMMLab. All rights reserved.
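# Entry point for training MMDetection3D models (detectors and 3D segmentors).
# A typical single-GPU invocation looks roughly like the following; the paths
# are placeholders, not values taken from this repository:
#   python train.py <path/to/config.py> --work-dir <path/to/work_dir>
# Distributed jobs additionally pass --launcher pytorch/slurm/mpi, usually via
# a wrapper launch script.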
from __future__ import division

import argparse
import copy
import mmcv
import os
import time
import torch
import warnings
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist
from os import path as osp

from mmdet import __version__ as mmdet_version
from mmdet3d import __version__ as mmdet3d_version
from mmdet3d.apis import train_model
from mmdet3d.datasets import build_dataset
from mmdet3d.models import build_model
from mmdet3d.utils import collect_env, get_root_logger
from mmdet.apis import set_random_seed
from mmseg import __version__ as mmseg_version


def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')
    parser.add_argument(
        '--no-validate',
        action='store_true',
        help='do not evaluate the checkpoint during training')
    group_gpus = parser.add_mutually_exclusive_group()
    group_gpus.add_argument(
        '--gpus',
        type=int,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='ids of gpus to use '
        '(only applicable to non-distributed training)')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into the config file (deprecated), '
        'change to --cfg-options instead.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into the config file. If the value '
        'to be overwritten is a list, it should be like key="[a,b]" or '
        'key=a,b. It also allows nested list/tuple values, e.g. '
        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
        'and that no white space is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument(
        '--autoscale-lr',
        action='store_true',
        help='automatically scale lr with the number of gpus')
    args = parser.parse_args()
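    # mirror --local_rank into the environment so that the distributed
    # initialization below can read it even when only the CLI flag is set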
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if args.options and args.cfg_options:
        raise ValueError(
            '--options and --cfg-options cannot be both specified, '
            '--options is deprecated in favor of --cfg-options')
    if args.options:
        warnings.warn('--options is deprecated in favor of --cfg-options')
        args.cfg_options = args.options

    return args


def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
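    # apply --cfg-options (or the deprecated --options alias) on top of the
    # config loaded from file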
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined with this priority: CLI > config file > config filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
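    # resolve GPU ids: explicit --gpu-ids wins, then --gpus, otherwise default
    # to a single GPU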
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    if args.autoscale_lr:
        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
        cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        # reset gpu_ids to match the world size in distributed training mode
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    # specify the logger name; if we kept using 'mmdet', the output info
    # would be filtered out and would not be saved to the log_file
    # TODO: ugly workaround to judge whether we are training a det or seg model
    if cfg.model.type in ['EncoderDecoder3D']:
        logger_name = 'mmseg'
    else:
        logger_name = 'mmdet'
    logger = get_root_logger(
        log_file=log_file, log_level=cfg.log_level, name=logger_name)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info
    meta['config'] = cfg.pretty_text

    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds
    if args.seed is not None:
        logger.info(f'Set random seed to {args.seed}, '
                    f'deterministic: {args.deterministic}')
        set_random_seed(args.seed, deterministic=args.deterministic)
    cfg.seed = args.seed
    meta['seed'] = args.seed
    meta['exp_name'] = osp.basename(args.config)

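    # build the detector/segmentor from the config and initialize its weights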
    model = build_model(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    model.init_weights()

    logger.info(f'Model:\n{model}')
    datasets = [build_dataset(cfg.data.train)]
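    # a 2-step workflow (train + val) also needs the val split built here,
    # reusing the training pipeline so it can run in training mode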
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        # in case we use a dataset wrapper
        if 'dataset' in cfg.data.train:
            val_dataset.pipeline = cfg.data.train.dataset.pipeline
        else:
            val_dataset.pipeline = cfg.data.train.pipeline
        # set test_mode=False in the deep-copied config;
        # this does not affect the AP/AR calculation later
        # refer to https://mmdetection3d.readthedocs.io/en/latest/tutorials/customize_runtime.html#customize-workflow  # noqa
        val_dataset.test_mode = False
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmdet version, config file content and class names in
        # checkpoints as metadata
        cfg.checkpoint_config.meta = dict(
            mmdet_version=mmdet_version,
            mmseg_version=mmseg_version,
            mmdet3d_version=mmdet3d_version,
            config=cfg.pretty_text,
            CLASSES=datasets[0].CLASSES,
            PALETTE=datasets[0].PALETTE  # for segmentors
            if hasattr(datasets[0], 'PALETTE') else None)
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
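    # hand off to the mmdet3d training loop; dataloaders, the runner and its
    # hooks are assembled inside train_model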
    train_model(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)


if __name__ == '__main__':
    main()