Commit 4d49dd03 authored by lintangsutawika

aggregation to compute_metric

parent c6a91582
@@ -367,7 +367,7 @@ def evaluate(
             # subset instances to only this document id ; sort by idx
             requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
             requests.sort(key=lambda x: x.idx)
-            items = task.process_results(
+            metrics = task.process_results(
                 doc, [req.filtered_resps[key] for req in requests]
             )
             if log_samples:
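With this rename, `process_results` is treated as returning a dict mapping each metric name to that document's score, rather than an opaque `items` payload. A minimal sketch of the expected shape, with a hypothetical task and invented field names:

```python
# Hypothetical per-document scoring, shaped like the dict this commit
# expects task.process_results() to return; field names are illustrative.
def process_results(doc, results):
    gold = doc["gold"]                  # index of the correct choice
    pred = results.index(max(results))  # argmax over per-choice scores
    return {
        "acc": float(pred == gold),     # one value per metric name
        "acc_norm": float(pred == gold),
    }
```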
@@ -380,11 +380,10 @@ def evaluate(
                 "resps": [req.resps for req in requests],
                 "filtered_resps": [req.filtered_resps[key] for req in requests],
             }
-            example.update(items)
+            example.update(metrics)
             samples[task_name].append(example)
-            vals[(task_name, key)].append(items)
-            # for metric, value in results.items():
-            #     vals[(task_name, key, metric)].append(value)
+            for metric, value in metrics.items():
+                vals[(task_name, key, metric)].append(value)

     if lm.world_size > 1:
         # if multigpu, then gather data across all ranks
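Per-document values are now bucketed under a `(task_name, key, metric)` triple instead of `(task_name, key)`, so each metric gets its own list to aggregate later. A runnable sketch of the bookkeeping (task and filter names invented):

```python
import collections

vals = collections.defaultdict(list)

# pretend output of process_results for one document
metrics = {"acc": 1.0, "acc_norm": 0.0}
task_name, key = "my_task", "none"  # hypothetical task name / filter key

for metric, value in metrics.items():
    vals[(task_name, key, metric)].append(value)

assert vals[("my_task", "none", "acc")] == [1.0]
assert vals[("my_task", "none", "acc_norm")] == [0.0]
```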
@@ -397,8 +396,7 @@ def evaluate(

         # then collect metrics across all ranks
         vals_torch = collections.defaultdict(list)
-        # for (task_name, key, metric), items in vals.items():
-        for (task_name, key), items in vals.items():
+        for (task_name, key, metric), items in vals.items():
             numitem = 0
             if type(items[0]) == tuple:
                 numitem = len(items[0])
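The `numitem` check survives because some per-document values are tuples rather than scalars (corpus metrics such as bleu keep prediction/reference pairs), and each tuple slot has to be gathered separately. A tiny illustration with made-up strings:

```python
# Tuple-valued items, as produced by corpus-level metrics like bleu:
items = [("the cat sat", "a cat sat"), ("hello there", "hello there")]

numitem = 0
if type(items[0]) == tuple:
    numitem = len(items[0])  # 2 slots: predictions and references

# each slot becomes its own column for the cross-rank gather
columns = list(zip(*items))
# columns[0] == ('the cat sat', 'hello there')  -> predictions
# columns[1] == ('a cat sat', 'hello there')    -> references
```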
@@ -434,8 +432,7 @@
                 gathered_item = [tuple(g) for g in gathered_item]

             if lm.rank == 0:
-                # vals_torch[(task_name, key, metric)] = gathered_item
-                vals_torch[(task_name, key)] = gathered_item
+                vals_torch[(task_name, key, metric)] = gathered_item

     vals = vals_torch
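End to end, the multi-GPU path now gathers one list per `(task_name, key, metric)` triple and keeps the merged result on rank 0. The surrounding code (elided here) pads tensors and uses torch collectives; the sketch below swaps in `torch.distributed.all_gather_object` for brevity and assumes an already-initialized process group:

```python
import collections
import torch.distributed as dist

def gather_vals(vals):
    """Merge per-rank metric lists onto rank 0 (simplified stand-in)."""
    vals_torch = collections.defaultdict(list)
    for (task_name, key, metric), items in vals.items():
        gathered = [None] * dist.get_world_size()
        dist.all_gather_object(gathered, items)  # every rank sends its list
        if dist.get_rank() == 0:
            # flatten the per-rank lists back into one list of per-doc values
            vals_torch[(task_name, key, metric)] = [
                v for rank_items in gathered for v in rank_items
            ]
    return vals_torch
```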
@@ -443,25 +440,26 @@
     ### Aggregate results over all datapoints ###
     # aggregate results ; run bootstrap CIs
-    # for (task_name, key, metric), items in vals.items():
-    for (task_name, key), items in vals.items():
+    for (task_name, key, metric), items in vals.items():
         task = task_dict[task_name]
-        # metric_key = metric + "," + key
+        metric_key = metric + "," + key

         if type(task) == tuple:
             group_name, task = task
         else:
             group_name = None

-        for metric_key, metric_fn in task.aggregation().items():
-            results[task_name][metric_key] = metric_fn(*list(zip(*items)))
-            results[task_name]["samples"] = len(items)
+        metric_fn = task.compute_metric()[metric]
+        results[task_name][metric_key] = metric_fn(items)
+        results[task_name]["samples"] = len(items)

         # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
         # so we run them less iterations. still looking for a cleaner way to do this
         if bootstrap_iters > 0:
             stderr = lm_eval.api.metrics.stderr_for_metric(
-                metric=task.aggregation()[metric],
+                # metric=task.aggregation()[metric],
+                metric=task.compute_metric()[metric],
                 bootstrap_iters=min(bootstrap_iters, 100)
                 if metric in ["bleu", "chrf", "ter"]
                 else bootstrap_iters,
...
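The aggregation step changes in two ways: the per-metric aggregation function is looked up via `task.compute_metric()[metric]` (assumed here to return a `{metric_name: aggregation_fn}` mapping, mirroring the old `task.aggregation()`), and it now receives the list of per-document values directly instead of the old unzipped form `metric_fn(*list(zip(*items)))`. A self-contained sketch with invented values:

```python
def mean(items):
    # stand-in aggregation fn, like task.compute_metric()["acc"] might return
    return sum(items) / len(items)

compute_metric = {"acc": mean}   # pretend output of task.compute_metric()

items = [1.0, 0.0, 1.0, 1.0]     # gathered per-document "acc" values
metric, key = "acc", "none"
metric_key = metric + "," + key  # results are keyed "metric,filter"

results = {
    metric_key: compute_metric[metric](items),  # 0.75
    "samples": len(items),
}
```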