Commit 6117c507 authored by lintangsutawika

changed how metrics are calculated
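
Per-example outputs of task.process_results are now appended whole under a
(task_name, key) tuple instead of one value per (task_name, key, metric)
triple; aggregation then walks task.aggregation() and computes every metric
from the collected items in a single pass (see the sketch after the diff).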

parent 028f04c7
@@ -370,7 +370,7 @@ def evaluate(
         # subset instances to only this document id ; sort by idx
         requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
         requests.sort(key=lambda x: x.idx)
-        metrics = task.process_results(
+        items = task.process_results(
             doc, [req.filtered_resps[key] for req in requests]
         )
         if log_samples:
@@ -383,10 +383,11 @@ def evaluate(
                 "resps": [req.resps for req in requests],
                 "filtered_resps": [req.filtered_resps[key] for req in requests],
             }
-            example.update(metrics)
+            example.update(items)
             samples[task_name].append(example)
-        for metric, value in metrics.items():
-            vals[(task_name, key, metric)].append(value)
+        vals[(task_name, key)].append(items)
+        # for metric, value in results.items():
+        #     vals[(task_name, key, metric)].append(value)
     if lm.world_size > 1:
         # if multigpu, then gather data across all ranks
@@ -399,7 +400,8 @@ def evaluate(
         # then collect metrics across all ranks
         vals_torch = collections.defaultdict(list)
-        for (task_name, key, metric), items in vals.items():
+        # for (task_name, key, metric), items in vals.items():
+        for (task_name, key), items in vals.items():
             numitem = 0
             if type(items[0]) == tuple:
                 numitem = len(items[0])
@@ -435,7 +437,8 @@ def evaluate(
                 gathered_item = [tuple(g) for g in gathered_item]
             if lm.rank == 0:
-                vals_torch[(task_name, key, metric)] = gathered_item
+                # vals_torch[(task_name, key, metric)] = gathered_item
+                vals_torch[(task_name, key)] = gathered_item
         vals = vals_torch
@@ -469,17 +472,18 @@ def evaluate(
     ### Aggregate results over all datapoints ###
     # aggregate results ; run bootstrap CIs
-    for (task_name, key, metric), items in vals.items():
+    # for (task_name, key, metric), items in vals.items():
+    for (task_name, key), items in vals.items():
         task = task_dict[task_name]
-        metric_key = metric + "," + key
+        # metric_key = metric + "," + key
         if type(task) == tuple:
             group_name, task = task
         else:
             group_name = None
-        agg_fn = task.aggregation()[metric]
-        results[task_name][metric_key] = agg_fn(items)
+        for metric_key, metric_fn in task.aggregation().items():
+            results[task_name][metric_key] = metric_fn(*list(zip(*items)))
         results[task_name]["samples"] = len(items)
         # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
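
A minimal sketch of the new aggregation path. The metric names, the sample
data, and acc_agg / f1_agg below are hypothetical stand-ins for what a task
would return from task.aggregation(); the sketch only illustrates how
metric_fn(*list(zip(*items))) behaves when process_results yields one tuple
of values per example.

# acc_agg / f1_agg and the sample data are hypothetical stand-ins,
# not the harness's actual API.
items = [
    (1, 0.50),  # (acc, f1) for doc 0
    (0, 0.25),  # (acc, f1) for doc 1
    (1, 0.75),  # (acc, f1) for doc 2
]

def acc_agg(accs, f1s):
    # every aggregation fn receives *all* per-metric columns positionally
    return sum(accs) / len(accs)

def f1_agg(accs, f1s):
    return sum(f1s) / len(f1s)

aggregation = {"acc": acc_agg, "f1": f1_agg}

results = {}
for metric_key, metric_fn in aggregation.items():
    # zip(*items) transposes per-example tuples into per-metric columns:
    # (1, 0, 1) and (0.50, 0.25, 0.75)
    results[metric_key] = metric_fn(*list(zip(*items)))

print(results)  # {'acc': 0.666..., 'f1': 0.5}

Note the trade-off this implies: a single loop covers all metrics, but each
aggregation function's signature becomes coupled to the full per-example
tuple rather than to its own metric alone.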