Commit 7c2b8148 authored by Kai Chen's avatar Kai Chen
Browse files

add an argument to specify process per gpu

parent 20762ce9
...@@ -39,7 +39,13 @@ def parse_args(): ...@@ -39,7 +39,13 @@ def parse_args():
parser = argparse.ArgumentParser(description='MMDet test detector') parser = argparse.ArgumentParser(description='MMDet test detector')
parser.add_argument('config', help='test config file path') parser.add_argument('config', help='test config file path')
parser.add_argument('checkpoint', help='checkpoint file') parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument('--gpus', default=1, type=int) parser.add_argument(
'--gpus', default=1, type=int, help='GPU number used for testing')
parser.add_argument(
'--proc_per_gpu',
default=1,
type=int,
help='Number of processes per GPU')
parser.add_argument('--out', help='output result file') parser.add_argument('--out', help='output result file')
parser.add_argument( parser.add_argument(
'--eval', '--eval',
...@@ -81,8 +87,14 @@ def main(): ...@@ -81,8 +87,14 @@ def main():
model_args = cfg.model.copy() model_args = cfg.model.copy()
model_args.update(train_cfg=None, test_cfg=cfg.test_cfg) model_args.update(train_cfg=None, test_cfg=cfg.test_cfg)
model_type = getattr(detectors, model_args.pop('type')) model_type = getattr(detectors, model_args.pop('type'))
outputs = parallel_test(model_type, model_args, args.checkpoint, outputs = parallel_test(
dataset, _data_func, range(args.gpus)) model_type,
model_args,
args.checkpoint,
dataset,
_data_func,
range(args.gpus),
workers_per_gpu=args.proc_per_gpu)
if args.out: if args.out:
print('writing results to {}'.format(args.out)) print('writing results to {}'.format(args.out))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment