Commit bce7d0c3 authored by yinchimaoliang's avatar yinchimaoliang
Browse files

Merge branch 'master_temp' into indoor_pipeline

parents 1756485e 868c5fab
...@@ -8,7 +8,7 @@ import time ...@@ -8,7 +8,7 @@ import time
import mmcv import mmcv
import torch import torch
from mmcv import Config from mmcv import Config, DictAction
from mmcv.runner import init_dist from mmcv.runner import init_dist
from mmdet3d import __version__ from mmdet3d import __version__
...@@ -26,9 +26,9 @@ def parse_args(): ...@@ -26,9 +26,9 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--resume-from', help='the checkpoint file to resume from') '--resume-from', help='the checkpoint file to resume from')
parser.add_argument( parser.add_argument(
'--validate', '--no-validate',
action='store_true', action='store_true',
help='whether to evaluate the checkpoint during training') help='whether not to evaluate the checkpoint during training')
group_gpus = parser.add_mutually_exclusive_group() group_gpus = parser.add_mutually_exclusive_group()
group_gpus.add_argument( group_gpus.add_argument(
'--gpus', '--gpus',
...@@ -46,6 +46,8 @@ def parse_args(): ...@@ -46,6 +46,8 @@ def parse_args():
'--deterministic', '--deterministic',
action='store_true', action='store_true',
help='whether to set deterministic options for CUDNN backend.') help='whether to set deterministic options for CUDNN backend.')
parser.add_argument(
'--options', nargs='+', action=DictAction, help='arguments in dict')
parser.add_argument( parser.add_argument(
'--launcher', '--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'], choices=['none', 'pytorch', 'slurm', 'mpi'],
...@@ -67,6 +69,9 @@ def main(): ...@@ -67,6 +69,9 @@ def main():
args = parse_args() args = parse_args()
cfg = Config.fromfile(args.config) cfg = Config.fromfile(args.config)
if args.options is not None:
cfg.merge_from_dict(args.options)
# set cudnn_benchmark # set cudnn_benchmark
if cfg.get('cudnn_benchmark', False): if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
...@@ -101,7 +106,7 @@ def main(): ...@@ -101,7 +106,7 @@ def main():
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
# init the logger before other steps # init the logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(cfg.work_dir, '{}.log'.format(timestamp)) log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
# add a logging filter # add a logging filter
...@@ -113,28 +118,27 @@ def main(): ...@@ -113,28 +118,27 @@ def main():
meta = dict() meta = dict()
# log env info # log env info
env_info_dict = collect_env() env_info_dict = collect_env()
env_info = '\n'.join([('{}: {}'.format(k, v)) env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
for k, v in env_info_dict.items()])
dash_line = '-' * 60 + '\n' dash_line = '-' * 60 + '\n'
logger.info('Environment info:\n' + dash_line + env_info + '\n' + logger.info('Environment info:\n' + dash_line + env_info + '\n' +
dash_line) dash_line)
meta['env_info'] = env_info meta['env_info'] = env_info
# log some basic info # log some basic info
logger.info('Distributed training: {}'.format(distributed)) logger.info(f'Distributed training: {distributed}')
logger.info('Config:\n{}'.format(cfg.text)) logger.info(f'Config:\n{cfg.pretty_text}')
# set random seeds # set random seeds
if args.seed is not None: if args.seed is not None:
logger.info('Set random seed to {}, deterministic: {}'.format( logger.info(f'Set random seed to {args.seed}, '
args.seed, args.deterministic)) f'deterministic: {args.deterministic}')
set_random_seed(args.seed, deterministic=args.deterministic) set_random_seed(args.seed, deterministic=args.deterministic)
cfg.seed = args.seed cfg.seed = args.seed
meta['seed'] = args.seed meta['seed'] = args.seed
model = build_detector( model = build_detector(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg) cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
logger.info('Model:\n{}'.format(model)) logger.info(f'Model:\n{model}')
datasets = [build_dataset(cfg.data.train)] datasets = [build_dataset(cfg.data.train)]
if len(cfg.workflow) == 2: if len(cfg.workflow) == 2:
val_dataset = copy.deepcopy(cfg.data.val) val_dataset = copy.deepcopy(cfg.data.val)
...@@ -145,7 +149,7 @@ def main(): ...@@ -145,7 +149,7 @@ def main():
# checkpoints as meta data # checkpoints as meta data
cfg.checkpoint_config.meta = dict( cfg.checkpoint_config.meta = dict(
mmdet_version=__version__, mmdet_version=__version__,
config=cfg.text, config=cfg.pretty_text,
CLASSES=datasets[0].CLASSES) CLASSES=datasets[0].CLASSES)
# add an attribute for visualization convenience # add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES model.CLASSES = datasets[0].CLASSES
...@@ -154,7 +158,7 @@ def main(): ...@@ -154,7 +158,7 @@ def main():
datasets, datasets,
cfg, cfg,
distributed=distributed, distributed=distributed,
validate=args.validate, validate=(not args.no_validate),
timestamp=timestamp, timestamp=timestamp,
meta=meta) meta=meta)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment