Unverified commit 7451bb2a authored by Da Zheng, committed by GitHub

merge eval results in all processes. (#1160)

parent dfb10db8
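
The change below threads a multiprocessing queue through test(): each evaluation worker now puts its metrics dict on the queue instead of printing it, and the parent process averages the dicts across workers before printing. A minimal, self-contained sketch of that pattern, assuming a placeholder compute_metrics() with invented metric names (not the actual dgl-ke evaluation code):

    import multiprocessing as mp

    def compute_metrics(rank):
        # Stand-in for evaluating one shard of the test set in one process.
        return {'MRR': 0.30 + 0.01 * rank, 'HITS@10': 0.45 + 0.01 * rank}

    def worker(rank, queue):
        # Report this process's metrics back to the parent.
        queue.put(compute_metrics(rank))

    if __name__ == '__main__':
        num_proc = 4
        queue = mp.Queue(num_proc)
        procs = [mp.Process(target=worker, args=(i, queue)) for i in range(num_proc)]
        for p in procs:
            p.start()
        # Drain the queue before joining so no worker stays blocked on unread output.
        total_metrics = {}
        for _ in range(num_proc):
            metrics = queue.get()
            for k, v in metrics.items():
                # Divide while accumulating so total_metrics ends up as the mean.
                total_metrics[k] = total_metrics.get(k, 0.0) + v / num_proc
        for p in procs:
            p.join()
        for k, v in total_metrics.items():
            print('Test average {}: {}'.format(k, v))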
@@ -139,13 +139,26 @@ def main(args):
     args.step = 0
     args.max_step = 0
     if args.num_proc > 1:
+        queue = mp.Queue(args.num_proc)
         procs = []
         for i in range(args.num_proc):
-            proc = mp.Process(target=test, args=(args, model, [test_sampler_heads[i], test_sampler_tails[i]]))
+            proc = mp.Process(target=test, args=(args, model, [test_sampler_heads[i], test_sampler_tails[i]],
+                                                 'Test', queue))
             procs.append(proc)
             proc.start()
         for proc in procs:
             proc.join()
+        total_metrics = {}
+        for i in range(args.num_proc):
+            metrics = queue.get()
+            for k, v in metrics.items():
+                if i == 0:
+                    total_metrics[k] = v / args.num_proc
+                else:
+                    total_metrics[k] += v / args.num_proc
+        for k, v in total_metrics.items():
+            print('Test average {} at [{}/{}]: {}'.format(k, args.step, args.max_step, v))
     else:
         test(args, model, [test_sampler_head, test_sampler_tail])
@@ -263,16 +263,32 @@ def run(args, logger):
     # test
     if args.test:
         start = time.time()
         if args.num_proc > 1:
+            queue = mp.Queue(args.num_proc)
             procs = []
             for i in range(args.num_proc):
-                proc = mp.Process(target=test, args=(args, model, [test_sampler_heads[i], test_sampler_tails[i]]))
+                proc = mp.Process(target=test, args=(args, model, [test_sampler_heads[i], test_sampler_tails[i]],
+                                                     'Test', queue))
                 procs.append(proc)
                 proc.start()
+            total_metrics = {}
+            for i in range(args.num_proc):
+                metrics = queue.get()
+                for k, v in metrics.items():
+                    if i == 0:
+                        total_metrics[k] = v / args.num_proc
+                    else:
+                        total_metrics[k] += v / args.num_proc
+            for k, v in total_metrics.items():
+                print('Test average {} at [{}/{}]: {}'.format(k, args.step, args.max_step, v))
             for proc in procs:
                 proc.join()
         else:
             test(args, model, [test_sampler_head, test_sampler_tail])
         print('test:', time.time() - start)

 if __name__ == '__main__':
     args = ArgParser().parse_args()
@@ -61,7 +61,7 @@ def train(args, model, train_sampler, valid_samplers=None):
             # clear cache
             logs = []

-def test(args, model, test_samplers, mode='Test'):
+def test(args, model, test_samplers, mode='Test', queue=None):
     logs = []
     for sampler in test_samplers:
@@ -80,10 +80,9 @@ def train(args, model, train_sampler, valid_samplers=None):
                 test(args, model, valid_samplers, mode='Valid')
                 print('test:', time.time() - start)

-def test(args, model, test_samplers, mode='Test'):
+def test(args, model, test_samplers, mode='Test', queue=None):
     if args.num_proc > 1:
         th.set_num_threads(1)
-    start = time.time()
     with th.no_grad():
         logs = []
         for sampler in test_samplers:
@@ -96,9 +95,10 @@ def test(args, model, test_samplers, mode='Test'):
         if len(logs) > 0:
             for metric in logs[0].keys():
                 metrics[metric] = sum([log[metric] for log in logs]) / len(logs)
-        for k, v in metrics.items():
-            print('{} average {} at [{}/{}]: {}'.format(mode, k, args.step, args.max_step, v))
-    print('test:', time.time() - start)
+        if queue is not None:
+            queue.put(metrics)
+        else:
+            for k, v in metrics.items():
+                print('{} average {} at [{}/{}]: {}'.format(mode, k, args.step, args.max_step, v))
     test_samplers[0] = test_samplers[0].reset()
     test_samplers[1] = test_samplers[1].reset()
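
Two details of the merged code are worth noting. In run() the parent drains the queue before joining the workers, while in main() it joins first; get-before-join is generally the safer order, because a child that has put a large object on an mp.Queue does not exit until that data is flushed to the pipe, so join() can hang if the parent is not reading (with small metric dicts like these, both orders work). And since each value is divided by args.num_proc as it is accumulated, total_metrics ends up holding the cross-process mean, which equals the true test-set average only when every process evaluates an equally sized shard.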