Commit 787b23f6 authored by lintangsutawika's avatar lintangsutawika
Browse files

readd aggregation

parent aaf64aab
...@@ -449,16 +449,15 @@ def evaluate( ...@@ -449,16 +449,15 @@ def evaluate(
else: else:
group_name = None group_name = None
metric_fn = task.compute_metric()[metric] agg_fn = task.aggregation()[metric]
results[task_name][metric_key] = metric_fn(items) results[task_name][metric_key] = agg_fn(items)
results[task_name]["samples"] = len(items) results[task_name]["samples"] = len(items)
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them less iterations. still looking for a cleaner way to do this # so we run them less iterations. still looking for a cleaner way to do this
if bootstrap_iters > 0: if bootstrap_iters > 0:
stderr = lm_eval.api.metrics.stderr_for_metric( stderr = lm_eval.api.metrics.stderr_for_metric(
# metric=task.aggregation()[metric], metric=task.aggregation()[metric],
metric=task.compute_metric()[metric],
bootstrap_iters=min(bootstrap_iters, 100) bootstrap_iters=min(bootstrap_iters, 100)
if metric in ["bleu", "chrf", "ter"] if metric in ["bleu", "chrf", "ter"]
else bootstrap_iters, else bootstrap_iters,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment