Commit 8b47a12b authored by Kai Chen

minor updates for train/test scripts

parent f8dab59d
@@ -44,17 +44,16 @@ def parse_args():
         '--eval',
         type=str,
         nargs='+',
-        choices=['proposal', 'bbox', 'segm', 'keypoints'],
+        choices=['proposal', 'proposal_fast', 'bbox', 'segm', 'keypoints'],
         help='eval types')
     parser.add_argument('--show', action='store_true', help='show results')
     args = parser.parse_args()
     return args
 
 
-args = parse_args()
-
 
 def main():
+    args = parse_args()
     cfg = mmcv.Config.fromfile(args.config)
     cfg.model.pretrained = None
     cfg.data.test.test_mode = True
...
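This hunk does two things: it registers the new 'proposal_fast' eval type and moves argument parsing from import time into `main()`. A minimal sketch of the resulting pattern, assuming a standalone script (the `description` string is illustrative):

```python
# Minimal sketch of the pattern this hunk moves to: CLI arguments are
# parsed inside main() rather than at import time, so importing this
# module (e.g. from a test) never touches sys.argv.
import argparse


def parse_args():
    parser = argparse.ArgumentParser(description='test a detector')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        choices=['proposal', 'proposal_fast', 'bbox', 'segm', 'keypoints'],
        help='eval types')
    parser.add_argument('--show', action='store_true', help='show results')
    return parser.parse_args()


def main():
    args = parse_args()  # parsed only when the script actually runs
    print(args.eval, args.show)


if __name__ == '__main__':
    main()
```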
@@ -2,6 +2,7 @@ from __future__ import division
 import argparse
 import logging
+import random
 from collections import OrderedDict
 
 import numpy as np
@@ -55,6 +56,7 @@ def get_logger(log_level):
 
 def set_random_seed(seed):
+    random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
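With `import random` and `random.seed(seed)` added, `set_random_seed` now covers all three RNG sources a training run typically draws from. A self-contained sketch of the helper as it stands after this commit; the trailing call with seed 42 is just a usage example:

```python
# Self-contained view of the helper after this commit: all three RNG
# sources a training run draws from (Python stdlib, NumPy, PyTorch on
# CPU and GPU) get the same seed, so stdlib-based random choices in
# data augmentation become reproducible as well.
import random

import numpy as np
import torch


def set_random_seed(seed):
    random.seed(seed)                 # Python stdlib RNG (new in this commit)
    np.random.seed(seed)              # NumPy RNG
    torch.manual_seed(seed)           # PyTorch CPU RNG
    torch.cuda.manual_seed_all(seed)  # PyTorch RNG on all visible GPUs


set_random_seed(42)  # usage example; 42 is an arbitrary seed
```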
@@ -89,7 +91,7 @@ def main():
     if args.work_dir is not None:
         cfg.work_dir = args.work_dir
     cfg.gpus = args.gpus
-    # add mmdet version to checkpoint as meta data
+    # save mmdet version in checkpoint as meta data
     cfg.checkpoint_config.meta = dict(
         mmdet_version=__version__, config=cfg.text)
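Comment rewording aside, this hunk keeps the useful behavior: every checkpoint carries the mmdet version and the full config text in its meta. A hedged sketch of reading that metadata back, assuming the mmcv-style checkpoint layout with a top-level 'meta' key; the file path is illustrative:

```python
# Sketch of reading back the metadata that checkpoint_config.meta
# attaches to saved checkpoints. Assumes the mmcv-style layout where
# the saved dict has a top-level 'meta' key next to 'state_dict'.
import torch

ckpt = torch.load('work_dir/latest.pth', map_location='cpu')
meta = ckpt.get('meta', {})
print(meta.get('mmdet_version'))  # version of mmdet that produced it
print(meta.get('config'))         # full config text of the training run
```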
@@ -103,13 +105,13 @@ def main():
     # init distributed environment if necessary
     if args.launcher == 'none':
         dist = False
-        logger.info('Disabled distributed training.')
+        logger.info('Non-distributed training.')
     else:
         dist = True
         init_dist(args.launcher, **cfg.dist_params)
         if torch.distributed.get_rank() != 0:
             logger.setLevel('ERROR')
-        logger.info('Enabled distributed training.')
+        logger.info('Distributed training.')
 
     # prepare data loaders
     train_dataset = obj_from_dict(cfg.data.train, datasets)
...
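In the distributed branch, every process runs the same script, so the rank check raises all non-zero ranks to ERROR level and only rank 0 keeps INFO output. A standalone sketch of the same idea (logger name and setup are illustrative, not from this commit):

```python
# Standalone sketch of the rank-0-only logging idea from the hunk
# above: every process in a distributed run executes the same script,
# so ranks other than 0 are raised to ERROR level to avoid printing
# one copy of each message per GPU.
import logging

import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('train')

if torch.distributed.is_available() and torch.distributed.is_initialized():
    if torch.distributed.get_rank() != 0:
        logger.setLevel('ERROR')  # only rank 0 keeps INFO-level output
    logger.info('Distributed training.')
else:
    logger.info('Non-distributed training.')
```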