Commit 45fefe9f authored by Leo Gao

implement stderr calculation

parent 8846bec0
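The change computes a standard error alongside each aggregated metric: an analytic formula for metrics aggregated with mean, and a bootstrap-based estimate for metrics with no closed form (BLEU, Matthews correlation, etc.). A minimal sketch of the analytic case, assuming plain Python lists of per-example scores (names below are illustrative, not the commit's exact code):

import math

def mean_stderr_sketch(scores):
    # Standard error of the mean: the sample spread shrinks with sqrt(n).
    n = len(scores)
    mu = sum(scores) / n
    stddev = math.sqrt(sum((x - mu) ** 2 for x in scores) / n)
    return stddev / math.sqrt(n)

# e.g. mean_stderr_sketch([1, 0, 1, 1, 0]) ~= 0.219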
import collections
import itertools
import random

import lm_eval.metrics

def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
@@ -88,5 +89,9 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
    for (task_name, metric), items in vals.items():
        task = task_dict[task_name]
        results[task_name][metric] = task.aggregation()[metric](items)

        stderr = lm_eval.metrics.stderr_for_metric(task.aggregation()[metric])
        if stderr is not None:
            results[task_name][metric + "_stderr"] = stderr(items)

    return results
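With this change, whenever stderr_for_metric returns a function for a task's aggregation, the results dict gains a "<metric>_stderr" entry next to the point estimate. Roughly (the task name and numbers below are made up for illustration):

results = {
    "some_task": {
        "acc": 0.72,
        "acc_stderr": 0.006,  # new: standard error of the accuracy estimate
    },
}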
@@ -5,12 +5,23 @@ from pprint import pprint
import numpy as np
import sacrebleu
import sklearn
import random

def mean(arr):
    return sum(arr) / len(arr)
def stddev(arr):
    mu = mean(arr)
    return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))

def mean_stderr(arr):
    print(stddev(arr), len(arr))
    return stddev(arr) / math.sqrt(len(arr))

def median(arr):
    return arr[len(arr) // 2]
@@ -48,6 +59,23 @@ def acc_all(items):
    acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
    return acc
def acc_all_stderr(items):
    # Only count as correct if all answers are labeled correctly for each question
    question_scoring_dict = {}
    preds = list(zip(*items))[0]
    docs = list(zip(*items))[1]

    for doc, pred in zip(docs, preds):
        question_id = doc["idx"]["question"]
        if question_id not in question_scoring_dict:
            question_scoring_dict[question_id] = []

        gold_label = doc["label"] == 1
        question_scoring_dict[question_id].append(gold_label == pred)

    acc = mean_stderr([int(all(x)) for x in question_scoring_dict.values()])
    return acc
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Compute max metric between prediction and each ground truth."""
@@ -138,3 +166,42 @@ def _sacreformat(refs, preds):
    preds = [pred[0] for pred in preds]
    return refs, preds
## stderr stuff

def bootstrap_stddev(f, xs, iters=10000):
    rnd = random.Random()
    rnd.seed(42)
    res = []
    from tqdm import trange
    print("bootstrapping for stddev:", f.__name__)
    for i in trange(iters):
        # sample w replacement
        bootstrap = rnd.choices(xs, k=len(xs))
        res.append(stddev(bootstrap))
    return mean(res)

def stderr_for_metric(metric):
    bootstrappable = [
        median,
        matthews_corrcoef,
        f1_score,
        perplexity,
        bleu,
        chrf,
        ter,
    ]

    if metric in bootstrappable:
        return lambda x: bootstrap_stddev(metric, x) / math.sqrt(len(x))

    stderr = {
        mean: mean_stderr,
        acc_all: acc_all_stderr
    }

    return stderr.get(metric, None)
\ No newline at end of file
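A hypothetical call site, to show how the dispatcher is meant to be used (not part of the commit): mean-aggregated metrics get the analytic mean_stderr, metrics listed in bootstrappable get a bootstrap-based estimate, and anything else gets None.

import lm_eval.metrics as metrics

scores = [1, 0, 1, 1, 0, 1, 1, 0]
stderr_fn = metrics.stderr_for_metric(metrics.mean)  # resolves to mean_stderr
if stderr_fn is not None:
    print(metrics.mean(scores), stderr_fn(scores))   # point estimate plus its standard error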