"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "9c03a7da43a6fecfc5d1db978f3fd7d8f3f85ae4"
Unverified commit 7451bb2a, authored by Da Zheng, committed by GitHub

merge eval results in all processes. (#1160)

parent dfb10db8
@@ -139,13 +139,26 @@ def main(args):
     args.step = 0
     args.max_step = 0
     if args.num_proc > 1:
+        queue = mp.Queue(args.num_proc)
         procs = []
         for i in range(args.num_proc):
-            proc = mp.Process(target=test, args=(args, model, [test_sampler_heads[i], test_sampler_tails[i]]))
+            proc = mp.Process(target=test, args=(args, model, [test_sampler_heads[i], test_sampler_tails[i]],
+                                                 'Test', queue))
             procs.append(proc)
             proc.start()
         for proc in procs:
             proc.join()
+
+        total_metrics = {}
+        for i in range(args.num_proc):
+            metrics = queue.get()
+            for k, v in metrics.items():
+                if i == 0:
+                    total_metrics[k] = v / args.num_proc
+                else:
+                    total_metrics[k] += v / args.num_proc
+        for k, v in total_metrics.items():
+            print('Test average {} at [{}/{}]: {}'.format(k, args.step, args.max_step, v))
     else:
         test(args, model, [test_sampler_head, test_sampler_tail])
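The aggregation pattern this hunk adds is easy to demonstrate in isolation. Below is a minimal sketch, not code from this repository: it assumes the standard multiprocessing module (the patched files import their own mp) and a made-up fake_eval worker standing in for test(). Each worker puts its metrics dict on a shared mp.Queue, and the parent averages the dicts over all processes.

import multiprocessing as mp

def fake_eval(rank, queue):
    # Stand-in for test(): each worker computes its own metrics
    # and ships them to the parent instead of printing them.
    queue.put({'MRR': 0.5 + 0.01 * rank, 'HITS@10': 0.7})   # sample values only

if __name__ == '__main__':
    num_proc = 4
    queue = mp.Queue(num_proc)
    procs = [mp.Process(target=fake_eval, args=(i, queue)) for i in range(num_proc)]
    for p in procs:
        p.start()

    # Average the per-process metrics, mirroring the loops added to main() and run().
    total_metrics = {}
    for _ in range(num_proc):
        metrics = queue.get()
        for k, v in metrics.items():
            total_metrics[k] = total_metrics.get(k, 0.0) + v / num_proc

    for p in procs:
        p.join()
    for k, v in total_metrics.items():
        print('Test average {}: {}'.format(k, v))

Note the ordering question the two patched call sites answer differently: run() drains the queue before joining the workers, which is the safer choice because join() can block while a child is still flushing a large queued object; with small metric dicts, the join-first ordering in main() also works in practice.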
@@ -263,16 +263,32 @@ def run(args, logger):
     # test
     if args.test:
+        start = time.time()
         if args.num_proc > 1:
+            queue = mp.Queue(args.num_proc)
             procs = []
             for i in range(args.num_proc):
-                proc = mp.Process(target=test, args=(args, model, [test_sampler_heads[i], test_sampler_tails[i]]))
+                proc = mp.Process(target=test, args=(args, model, [test_sampler_heads[i], test_sampler_tails[i]],
+                                                     'Test', queue))
                 procs.append(proc)
                 proc.start()
+
+            total_metrics = {}
+            for i in range(args.num_proc):
+                metrics = queue.get()
+                for k, v in metrics.items():
+                    if i == 0:
+                        total_metrics[k] = v / args.num_proc
+                    else:
+                        total_metrics[k] += v / args.num_proc
+            for k, v in total_metrics.items():
+                print('Test average {} at [{}/{}]: {}'.format(k, args.step, args.max_step, v))
             for proc in procs:
                 proc.join()
         else:
             test(args, model, [test_sampler_head, test_sampler_tail])
+        print('test:', time.time() - start)
 
 if __name__ == '__main__':
     args = ArgParser().parse_args()
@@ -61,7 +61,7 @@ def train(args, model, train_sampler, valid_samplers=None):
         # clear cache
         logs = []
 
-def test(args, model, test_samplers, mode='Test'):
+def test(args, model, test_samplers, mode='Test', queue=None):
     logs = []
     for sampler in test_samplers:
@@ -80,10 +80,9 @@ def train(args, model, train_sampler, valid_samplers=None):
             test(args, model, valid_samplers, mode='Valid')
         print('test:', time.time() - start)
 
-def test(args, model, test_samplers, mode='Test'):
+def test(args, model, test_samplers, mode='Test', queue=None):
     if args.num_proc > 1:
         th.set_num_threads(1)
-    start = time.time()
     with th.no_grad():
         logs = []
         for sampler in test_samplers:
@@ -96,9 +95,10 @@ def test(args, model, test_samplers, mode='Test'):
         if len(logs) > 0:
             for metric in logs[0].keys():
                 metrics[metric] = sum([log[metric] for log in logs]) / len(logs)
-
-    for k, v in metrics.items():
-        print('{} average {} at [{}/{}]: {}'.format(mode, k, args.step, args.max_step, v))
-    print('test:', time.time() - start)
+    if queue is not None:
+        queue.put(metrics)
+    else:
+        for k, v in metrics.items():
+            print('{} average {} at [{}/{}]: {}'.format(mode, k, args.step, args.max_step, v))
     test_samplers[0] = test_samplers[0].reset()
     test_samplers[1] = test_samplers[1].reset()
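On the producer side, the change to test() amounts to one new branch: hand the per-process averages to the parent when a queue is supplied, otherwise print them as before. The snippet below is a hypothetical, self-contained illustration of that contract; report() and its arguments are made up for this sketch and are not names from the repository.

import multiprocessing as mp

def report(metrics, mode='Test', step=0, max_step=0, queue=None):
    # Hypothetical helper mirroring the branch added to test():
    # with a queue, defer printing to the parent, which merges all processes;
    # without one, keep the original per-process printout.
    if queue is not None:
        queue.put(metrics)
    else:
        for k, v in metrics.items():
            print('{} average {} at [{}/{}]: {}'.format(mode, k, step, max_step, v))

if __name__ == '__main__':
    # Single-process path: prints locally, as before the change.
    report({'MRR': 0.41}, step=10, max_step=100)
    # Multi-process path: the metrics dict travels through the queue instead.
    q = mp.Queue(1)
    report({'MRR': 0.41}, queue=q)
    print(q.get())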