Commit 1039ad0e authored by ChaimZhu's avatar ChaimZhu
Browse files

add amp

parent 3a939d7f
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import argparse import argparse
import logging
import os import os
import os.path as osp import os.path as osp
from mmengine.config import Config, DictAction from mmengine.config import Config, DictAction
from mmengine.logging import print_log
from mmengine.runner import Runner from mmengine.runner import Runner
from mmdet3d.utils import register_all_modules from mmdet3d.utils import register_all_modules
...@@ -13,6 +15,11 @@ def parse_args(): ...@@ -13,6 +15,11 @@ def parse_args():
parser = argparse.ArgumentParser(description='Train a 3D detector') parser = argparse.ArgumentParser(description='Train a 3D detector')
parser.add_argument('config', help='train config file path') parser.add_argument('config', help='train config file path')
parser.add_argument('--work-dir', help='the dir to save logs and models') parser.add_argument('--work-dir', help='the dir to save logs and models')
parser.add_argument(
'--amp',
action='store_true',
default=False,
help='enable automatic-mixed-precision training')
parser.add_argument( parser.add_argument(
'--cfg-options', '--cfg-options',
nargs='+', nargs='+',
...@@ -40,11 +47,13 @@ def main(): ...@@ -40,11 +47,13 @@ def main():
# register all modules in mmdet3d into the registries # register all modules in mmdet3d into the registries
# do not init the default scope here because it will be init in the runner # do not init the default scope here because it will be init in the runner
register_all_modules(init_default_scope=False) register_all_modules(init_default_scope=False)
# load config # load config
cfg = Config.fromfile(args.config) cfg = Config.fromfile(args.config)
cfg.launcher = args.launcher cfg.launcher = args.launcher
if args.cfg_options is not None: if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options) cfg.merge_from_dict(args.cfg_options)
# work_dir is determined in this priority: CLI > segment in file > filename # work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None: if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None # update configs according to CLI args if args.work_dir is not None
...@@ -53,8 +62,25 @@ def main(): ...@@ -53,8 +62,25 @@ def main():
# use config filename as default work_dir if cfg.work_dir is None # use config filename as default work_dir if cfg.work_dir is None
cfg.work_dir = osp.join('./work_dirs', cfg.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0]) osp.splitext(osp.basename(args.config))[0])
# enable automatic-mixed-precision training
if args.amp is True:
optim_wrapper = cfg.optim_wrapper.type
if optim_wrapper == 'AmpOptimWrapper':
print_log(
'AMP training is already enabled in your config.',
logger='current',
level=logging.WARNING)
else:
assert optim_wrapper == 'OptimWrapper', (
'`--amp` is only supported when the optimizer wrapper type is '
f'`OptimWrapper` but got {optim_wrapper}.')
cfg.optim_wrapper.type = 'AmpOptimWrapper'
cfg.optim_wrapper.loss_scale = 'dynamic'
# build the runner from config # build the runner from config
runner = Runner.from_cfg(cfg) runner = Runner.from_cfg(cfg)
# start training # start training
runner.train() runner.train()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment