Commit aeebf6f2 authored by haileyschoelkopf
Browse files

fix and gather sample logging

parent d15ee17a
...@@ -281,7 +281,7 @@ def evaluate( ...@@ -281,7 +281,7 @@ def evaluate(
"doc_id": doc_id, "doc_id": doc_id,
"doc": doc, "doc": doc,
"target": target, "target": target,
"arguments": req.args, "arguments": requests[0].args,
"resps": [req.resps for req in requests], "resps": [req.resps for req in requests],
"filtered_resps": [req.filtered_resps[key] for req in requests], "filtered_resps": [req.filtered_resps[key] for req in requests],
} }
...@@ -292,6 +292,15 @@ def evaluate( ...@@ -292,6 +292,15 @@ def evaluate(
if lm.world_size > 1: if lm.world_size > 1:
# if multigpu, then gather data across all ranks # if multigpu, then gather data across all ranks
# first gather logged samples across all ranks
for task_name, task_samples in list(samples.items()):
full_samples = [None] * lm.world_size
torch.distributed.all_gather_object(full_samples, task_samples)
samples[task_name] = list(itertools.chain.from_iterable(full_samples))
# then collect metrics across all ranks
vals_torch = collections.defaultdict(list) vals_torch = collections.defaultdict(list)
for (task_name, key, metric), items in vals.items(): for (task_name, key, metric), items in vals.items():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment