Unverified commit 391f99c5 authored by VVsssssk, committed by GitHub

[Fix] Fix tools to support auto-scale lr, custom runner, auto resume and the markdown version used for docs (#1708)

* fix train.py

* fix test.py

* fix rea

* add auto-resume
parent b16c8dfa
 docutils==0.16.0
 m2r
+markdown<3.4.0
 mistune==0.8.4
 myst-parser
 -e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
......
@@ -4,6 +4,7 @@ import os
 import os.path as osp
 
 from mmengine.config import Config, DictAction
+from mmengine.registry import RUNNERS
 from mmengine.runner import Runner
 
 from mmdet3d.utils import register_all_modules, replace_ceph_backend
@@ -68,7 +69,13 @@ def main():
     cfg.load_from = args.checkpoint
 
     # build the runner from config
-    runner = Runner.from_cfg(cfg)
+    if 'runner_type' not in cfg:
+        # build the default runner
+        runner = Runner.from_cfg(cfg)
+    else:
+        # build customized runner from the registry
+        # if 'runner_type' is set in the cfg
+        runner = RUNNERS.build(cfg)
 
     # start testing
     runner.test()
......
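
The same default-vs-custom runner selection is added to both test.py and train.py. Below is a minimal sketch of that logic in isolation; the helper name build_runner is purely illustrative and not part of this commit. As far as MMEngine's registry goes, RUNNERS.build(cfg) resolves the runner class named by the config's runner_type field, so a config can swap in any registered runner instead of the stock Runner.

    from mmengine.config import Config
    from mmengine.registry import RUNNERS
    from mmengine.runner import Runner


    def build_runner(cfg: Config):
        """Mirror the branch added to test.py/train.py in this commit."""
        if 'runner_type' not in cfg:
            # no runner_type key: fall back to the default MMEngine Runner
            return Runner.from_cfg(cfg)
        # runner_type present: let the RUNNERS registry pick the runner class
        return RUNNERS.build(cfg)
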
@@ -6,6 +6,7 @@ import os.path as osp
 from mmengine.config import Config, DictAction
 from mmengine.logging import print_log
+from mmengine.registry import RUNNERS
 from mmengine.runner import Runner
 
 from mmdet3d.utils import register_all_modules, replace_ceph_backend
@@ -20,6 +21,14 @@ def parse_args():
         action='store_true',
         default=False,
         help='enable automatic-mixed-precision training')
+    parser.add_argument(
+        '--auto-scale-lr',
+        action='store_true',
+        help='enable automatically scaling LR.')
+    parser.add_argument(
+        '--resume',
+        action='store_true',
+        help='resume from the latest checkpoint in the work_dir automatically')
     parser.add_argument(
         '--cfg-options',
         nargs='+',
@@ -81,8 +90,27 @@ def main():
         cfg.optim_wrapper.type = 'AmpOptimWrapper'
         cfg.optim_wrapper.loss_scale = 'dynamic'
 
+    # enable automatically scaling LR
+    if args.auto_scale_lr:
+        if 'auto_scale_lr' in cfg and \
+                'enable' in cfg.auto_scale_lr and \
+                'base_batch_size' in cfg.auto_scale_lr:
+            cfg.auto_scale_lr.enable = True
+        else:
+            raise RuntimeError('Can not find "auto_scale_lr" or '
+                               '"auto_scale_lr.enable" or '
+                               '"auto_scale_lr.base_batch_size" in your'
+                               ' configuration file.')
+
+    cfg.resume = args.resume
+
     # build the runner from config
-    runner = Runner.from_cfg(cfg)
+    if 'runner_type' not in cfg:
+        # build the default runner
+        runner = Runner.from_cfg(cfg)
+    else:
+        # build customized runner from the registry
+        # if 'runner_type' is set in the cfg
+        runner = RUNNERS.build(cfg)
 
     # start training
     runner.train()
......
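
For --auto-scale-lr to take effect, the config must already carry an auto_scale_lr section with both keys the new check looks for. A sketch of such a config fragment follows; the base_batch_size value of 32 is illustrative only and should match whatever total batch size the config's learning rate was tuned for.

    # config fragment expected by the --auto-scale-lr check (values illustrative)
    auto_scale_lr = dict(
        enable=False,        # flipped to True when --auto-scale-lr is passed
        base_batch_size=32)  # total batch size the base LR was tuned for

With the flag passed and such a section present, MMEngine's runner should scale the learning rate by the ratio of the actual total batch size (GPUs × samples per GPU) to base_batch_size. The --resume flag simply sets cfg.resume, which makes the runner resume from the latest checkpoint it finds in the work directory.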