# Copyright (c) OpenMMLab. All rights reserved.
from __future__ import division
import argparse
import copy
import os
import time
import warnings
from os import path as osp

import mmcv
import torch
import torch.distributed as dist
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist

from mmdet import __version__ as mmdet_version
from mmdet3d import __version__ as mmdet3d_version
from mmdet3d.apis import init_random_seed, train_model
from mmdet3d.datasets import build_dataset
from mmdet3d.models import build_model
from mmdet3d.utils import collect_env, get_root_logger
from mmdet.apis import set_random_seed
from mmseg import __version__ as mmseg_version

try:
    # If mmdet version > 2.20.0, setup_multi_processes would be imported and
    # used from mmdet instead of mmdet3d.
    from mmdet.utils import setup_multi_processes
except ImportError:
    from mmdet3d.utils import setup_multi_processes


def parse_args():
    """Parse the command-line arguments of the training script.

    Besides plain parsing, this function:

    * exports the parsed ``--local_rank`` to the ``LOCAL_RANK`` environment
      variable when it is not already set, and
    * folds the deprecated ``--options`` into ``--cfg-options``.

    Returns:
        argparse.Namespace: The parsed arguments.

    Raises:
        ValueError: If both ``--options`` and ``--cfg-options`` are given.
    """
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')
    parser.add_argument(
        '--auto-resume',
        action='store_true',
        help='resume from the latest checkpoint automatically')
    parser.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')
    # The three GPU-selection flags are mutually exclusive; --gpus and
    # --gpu-ids are kept only for backward compatibility.
    group_gpus = parser.add_mutually_exclusive_group()
    group_gpus.add_argument(
        '--gpus',
        type=int,
        help='(Deprecated, please use --gpu-id) number of gpus to use '
        '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='(Deprecated, please use --gpu-id) ids of gpus to use '
        '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-id',
        type=int,
        default=0,
        help='id of gpu to use '
        '(only applicable to non-distributed training)')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument(
        '--diff-seed',
        action='store_true',
        help='Whether or not set different seeds for different ranks')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file (deprecated), '
        'change to --cfg-options instead.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b. '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]". '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    # Accept both spellings a distributed launcher may pass (newer PyTorch
    # launchers use the hyphenated form); both are stored in args.local_rank.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    parser.add_argument(
        '--autoscale-lr',
        action='store_true',
        help='automatically scale lr with the number of gpus')
    args = parser.parse_args()
    # Expose the rank through the environment unless the launcher already
    # set it, so code that reads LOCAL_RANK sees a consistent value.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    # --options is the deprecated alias of --cfg-options.
    if args.options and args.cfg_options:
        raise ValueError(
            '--options and --cfg-options cannot be both specified, '
            '--options is deprecated in favor of --cfg-options')
    if args.options:
        warnings.warn('--options is deprecated in favor of --cfg-options')
        args.cfg_options = args.options

    return args


def main():
    """Train a model as described by the config file given on the CLI.

    In order, this:

    1. parses the CLI and loads/overrides the config;
    2. resolves work dir, resume and GPU settings;
    3. initializes the distributed environment when a launcher is used;
    4. applies optional linear LR scaling;
    5. sets up logging and random seeds;
    6. builds the model and datasets and hands them to ``train_model``.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # set multi-process settings
    setup_multi_processes(cfg)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from

    if args.auto_resume:
        cfg.auto_resume = args.auto_resume
        warnings.warn('`--auto-resume` is only supported when mmdet '
                      'version >= 2.20.0 for 3D detection model or '
                      'mmsegmentation version >= 0.21.0 for 3D '
                      'segmentation model')

    # Non-distributed training supports a single GPU only, so every GPU
    # option is reduced to exactly one device id.
    if args.gpus is not None:
        cfg.gpu_ids = range(1)
        warnings.warn('`--gpus` is deprecated because we only support '
                      'single GPU mode in non-distributed training. '
                      'Use `gpus=1` now.')
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids[0:1]
        warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
                      'Because we only support single GPU mode in '
                      'non-distributed training. Use the first GPU '
                      'in `gpu_ids` now.')
    if args.gpus is None and args.gpu_ids is None:
        cfg.gpu_ids = [args.gpu_id]

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        # re-set gpu_ids with distributed training mode
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    # Scale the LR only after gpu_ids reflects the distributed world size;
    # before that point len(cfg.gpu_ids) is always 1 and the LR would be
    # under-scaled for multi-GPU runs.
    if args.autoscale_lr:
        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677);
        # the `/ 8` presumably means the config LR targets 8 GPUs.
        cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    # specify logger name, if we still use 'mmdet', the output info will be
    # filtered and won't be saved in the log_file
    # TODO: ugly workaround to judge whether we are training det or seg model
    if cfg.model.type in ['EncoderDecoder3D']:
        logger_name = 'mmseg'
    else:
        logger_name = 'mmdet'
    logger = get_root_logger(
        log_file=log_file, log_level=cfg.log_level, name=logger_name)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info
    meta['config'] = cfg.pretty_text

    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds
    seed = init_random_seed(args.seed)
    if args.diff_seed and distributed:
        # Offset the seed by rank so each process differs. dist.get_rank()
        # needs an initialized process group, which only exists after
        # init_dist(); a non-distributed run is rank 0, i.e. no offset.
        seed += dist.get_rank()
    logger.info(f'Set random seed to {seed}, '
                f'deterministic: {args.deterministic}')
    set_random_seed(seed, deterministic=args.deterministic)
    cfg.seed = seed
    meta['seed'] = seed
    meta['exp_name'] = osp.basename(args.config)

    model = build_model(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    model.init_weights()

    logger.info(f'Model:\n{model}')
    datasets = [build_dataset(cfg.data.train)]
    # A two-stage workflow also needs a validation dataset built with the
    # training pipeline.
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        # in case we use a dataset wrapper
        if 'dataset' in cfg.data.train:
            val_dataset.pipeline = cfg.data.train.dataset.pipeline
        else:
            val_dataset.pipeline = cfg.data.train.pipeline
        # set test_mode=False here in deep copied config
        # which do not affect AP/AR calculation later
        # refer to https://mmdetection3d.readthedocs.io/en/latest/tutorials/customize_runtime.html#customize-workflow  # noqa
        val_dataset.test_mode = False
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmdet version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmdet_version=mmdet_version,
            mmseg_version=mmseg_version,
            mmdet3d_version=mmdet3d_version,
            config=cfg.pretty_text,
            CLASSES=datasets[0].CLASSES,
            PALETTE=datasets[0].PALETTE  # for segmentors
            if hasattr(datasets[0], 'PALETTE') else None)
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
    train_model(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)


if __name__ == '__main__':
    # Start training only when run as a script, not when this module is imported.
    main()