Commit bce7d0c3 authored by yinchimaoliang's avatar yinchimaoliang
Browse files

Merge branch 'master_temp' into indoor_pipeline

parents 1756485e 868c5fab
......@@ -8,7 +8,7 @@ import time
import mmcv
import torch
from mmcv import Config
from mmcv import Config, DictAction
from mmcv.runner import init_dist
from mmdet3d import __version__
......@@ -26,9 +26,9 @@ def parse_args():
parser.add_argument(
'--resume-from', help='the checkpoint file to resume from')
parser.add_argument(
'--validate',
'--no-validate',
action='store_true',
help='whether to evaluate the checkpoint during training')
help='whether not to evaluate the checkpoint during training')
group_gpus = parser.add_mutually_exclusive_group()
group_gpus.add_argument(
'--gpus',
......@@ -46,6 +46,8 @@ def parse_args():
'--deterministic',
action='store_true',
help='whether to set deterministic options for CUDNN backend.')
parser.add_argument(
'--options', nargs='+', action=DictAction, help='arguments in dict')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
......@@ -67,6 +69,9 @@ def main():
args = parse_args()
cfg = Config.fromfile(args.config)
if args.options is not None:
cfg.merge_from_dict(args.options)
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
......@@ -101,7 +106,7 @@ def main():
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
# init the logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(cfg.work_dir, '{}.log'.format(timestamp))
log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
# add a logging filter
......@@ -113,28 +118,27 @@ def main():
meta = dict()
# log env info
env_info_dict = collect_env()
env_info = '\n'.join([('{}: {}'.format(k, v))
for k, v in env_info_dict.items()])
env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
dash_line = '-' * 60 + '\n'
logger.info('Environment info:\n' + dash_line + env_info + '\n' +
dash_line)
meta['env_info'] = env_info
# log some basic info
logger.info('Distributed training: {}'.format(distributed))
logger.info('Config:\n{}'.format(cfg.text))
logger.info(f'Distributed training: {distributed}')
logger.info(f'Config:\n{cfg.pretty_text}')
# set random seeds
if args.seed is not None:
logger.info('Set random seed to {}, deterministic: {}'.format(
args.seed, args.deterministic))
logger.info(f'Set random seed to {args.seed}, '
f'deterministic: {args.deterministic}')
set_random_seed(args.seed, deterministic=args.deterministic)
cfg.seed = args.seed
meta['seed'] = args.seed
model = build_detector(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
logger.info('Model:\n{}'.format(model))
logger.info(f'Model:\n{model}')
datasets = [build_dataset(cfg.data.train)]
if len(cfg.workflow) == 2:
val_dataset = copy.deepcopy(cfg.data.val)
......@@ -145,7 +149,7 @@ def main():
# checkpoints as meta data
cfg.checkpoint_config.meta = dict(
mmdet_version=__version__,
config=cfg.text,
config=cfg.pretty_text,
CLASSES=datasets[0].CLASSES)
# add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES
......@@ -154,7 +158,7 @@ def main():
datasets,
cfg,
distributed=distributed,
validate=args.validate,
validate=(not args.no_validate),
timestamp=timestamp,
meta=meta)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment