"docs/git@developer.sourcefind.cn:Wenxuan/LightX2V.git" did not exist on "689e69b63864dc41c4c42ffe37ff8672b9ef8b85"
Commit 726ea95e authored by Leo Gao's avatar Leo Gao
Browse files

implement stderr calculation

parent ab0e4d26
import collections
import itertools
import random

import lm_eval.metrics


def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
@@ -88,5 +89,9 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
    for (task_name, metric), items in vals.items():
        task = task_dict[task_name]
        results[task_name][metric] = task.aggregation()[metric](items)

        stderr = lm_eval.metrics.stderr_for_metric(task.aggregation()[metric])
        if stderr is not None:
            results[task_name][metric + "_stderr"] = stderr(items)

    return results
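For reference, a hypothetical sketch (not part of this commit) of how the new "_stderr" keys sit next to each metric in the dict that evaluate() returns; the task name and numbers below are made up for illustration:

results = {"copa": {"acc": 0.71, "acc_stderr": 0.0143}}
for task_name, metrics in results.items():
    for name, value in metrics.items():
        if name.endswith("_stderr"):
            continue
        err = metrics.get(name + "_stderr")
        suffix = f" ± {err:.4f}" if err is not None else ""
        print(f"{task_name} {name}: {value:.4f}{suffix}")  # copa acc: 0.7100 ± 0.0143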
@@ -5,12 +5,23 @@ from pprint import pprint
import numpy as np
import sacrebleu
import sklearn
import random
def mean(arr):
    return sum(arr) / len(arr)


def stddev(arr):
    # Population standard deviation of arr.
    mu = mean(arr)
    return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))


def mean_stderr(arr):
    # Standard error of the sample mean: stddev / sqrt(n).
    return stddev(arr) / math.sqrt(len(arr))
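A quick worked check of the three helpers above: for [1, 2, 3, 4] the mean is 2.5, the population stddev is sqrt(5/4) ≈ 1.118, and the stderr of the mean is 1.118 / sqrt(4) ≈ 0.559. Assuming the definitions above are in scope:

xs = [1, 2, 3, 4]
print(mean(xs))         # 2.5
print(stddev(xs))       # ~1.118
print(mean_stderr(xs))  # ~0.559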
def median(arr):
    return arr[len(arr) // 2]
@@ -48,6 +59,23 @@ def acc_all(items):
    acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
    return acc
def acc_all_stderr(items):
    # Only count as correct if all answers are labeled correctly for each question
    question_scoring_dict = {}
    preds, docs = zip(*items)

    for doc, pred in zip(docs, preds):
        question_id = doc["idx"]["question"]
        if question_id not in question_scoring_dict:
            question_scoring_dict[question_id] = []

        gold_label = doc["label"] == 1
        question_scoring_dict[question_id].append(gold_label == pred)

    stderr = mean_stderr([int(all(x)) for x in question_scoring_dict.values()])
    return stderr
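To make the expected items shape concrete, a small hypothetical example using the (pred, doc) layout that acc_all and acc_all_stderr both assume, where several pairs share one question id:

items = [
    (True,  {"idx": {"question": 0}, "label": 1}),  # gold True,  pred True  -> correct
    (False, {"idx": {"question": 0}, "label": 0}),  # gold False, pred False -> correct
    (True,  {"idx": {"question": 1}, "label": 0}),  # gold False, pred True  -> wrong
]
# acc_all(items) == 0.5: question 0 is fully correct, question 1 is not.
# acc_all_stderr(items) is then the stderr of the mean over those two
# per-question scores, [1, 0].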
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Compute max metric between prediction and each ground truth."""
@@ -138,3 +166,42 @@ def _sacreformat(refs, preds):
    preds = [pred[0] for pred in preds]
    return refs, preds
## stderr stuff
def bootstrap_stddev(f, xs, iters=10000):
    # Nonparametric bootstrap: recompute the metric f on resamples of xs and
    # take the stddev of those estimates as the metric's standard error.
    rnd = random.Random()
    rnd.seed(42)

    res = []
    from tqdm import trange
    print("bootstrapping for stddev:", f.__name__)
    for i in trange(iters):
        # sample w replacement
        bootstrap = rnd.choices(xs, k=len(xs))
        res.append(f(bootstrap))

    return stddev(res)
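A minimal sanity check, assuming the definitions above are in scope (tqdm installed; iters lowered so it finishes quickly): the bootstrap estimate for the mean should land near the analytic mean_stderr:

xs = [2.1, 3.4, 5.0, 5.2, 6.8, 7.7, 9.3]
print(bootstrap_stddev(mean, xs, iters=1000))  # bootstrap estimate
print(mean_stderr(xs))                         # analytic value, should be close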
def stderr_for_metric(metric):
    bootstrappable = [
        median,
        matthews_corrcoef,
        f1_score,
        perplexity,
        bleu,
        chrf,
        ter,
    ]

    if metric in bootstrappable:
        # The bootstrap stddev already estimates the metric's standard error,
        # so no further division by sqrt(n) is needed.
        return lambda x: bootstrap_stddev(metric, x)

    stderr = {
        mean: mean_stderr,
        acc_all: acc_all_stderr,
    }

    return stderr.get(metric, None)
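Putting it together, a short usage sketch of how evaluate() consumes this module: look up the stderr function for an aggregation metric, then apply it to the collected per-item values. All names are as defined above:

se = stderr_for_metric(mean)
print(se([1, 2, 3, 4]))  # 0.559..., via mean_stderr
print(stderr_for_metric(acc_all) is acc_all_stderr)  # True
print(stderr_for_metric(median)([1, 2, 3, 4]))       # bootstrap estimate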