"vscode:/vscode.git/clone" did not exist on "f7201d1affd601e4f66e602066c24e6ccfff189b"
Commit 1731613e authored by Leo Gao

Merge branch 'master' of github.com:EleutherAI/lm_evaluation_harness into perplexity

# Conflicts:
#	requirements.txt
parents 45127aa7 f9d87ad3
import collections
import itertools
import random
import lm_eval.metrics
def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
@@ -89,5 +90,9 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
    for (task_name, metric), items in vals.items():
        task = task_dict[task_name]
        results[task_name][metric] = task.aggregation()[metric](items)

        stderr = lm_eval.metrics.stderr_for_metric(task.aggregation()[metric])
        if stderr is not None:
            results[task_name][metric + "_stderr"] = stderr(items)

    return results
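For context, the effect of the evaluator change above is that each task's results entry gains a <metric>_stderr value alongside the metric itself whenever stderr_for_metric (added below in the metrics module) returns a stderr function for it. A minimal sketch of the resulting shape, with a hypothetical task name and made-up numbers:

# Hypothetical output shape only; the task name and values are illustrative.
results = {"some_task": {"acc": 0.61, "acc_stderr": 0.02}}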
@@ -5,12 +5,23 @@ from pprint import pprint
import numpy as np
import sacrebleu
import sklearn
import random
def mean(arr):
    return sum(arr) / len(arr)
def stddev(arr):
    mu = mean(arr)
    return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))

def mean_stderr(arr):
    print(stddev(arr), len(arr))
    return stddev(arr) / math.sqrt(len(arr))
def median(arr):
    return arr[len(arr) // 2]
@@ -48,6 +59,23 @@ def acc_all(items):
    acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
    return acc
def acc_all_stderr(items):
    # Only count as correct if all answers are labeled correctly for each question
    question_scoring_dict = {}
    preds = list(zip(*items))[0]
    docs = list(zip(*items))[1]
    for doc, pred in zip(docs, preds):
        question_id = doc["idx"]["question"]
        if question_id not in question_scoring_dict:
            question_scoring_dict[question_id] = []
        gold_label = doc["label"] == 1
        question_scoring_dict[question_id].append(gold_label == pred)
    acc = mean_stderr([int(all(x)) for x in question_scoring_dict.values()])
    return acc
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Compute max metric between prediction and each ground truth."""
@@ -143,3 +171,42 @@ def _sacreformat(refs, preds):
    preds = [pred[0] for pred in preds]
    return refs, preds
## stderr stuff
def bootstrap_stderr(f, xs, iters=10000):
    rnd = random.Random()
    rnd.seed(42)
    res = []
    from tqdm import trange
    print("bootstrapping for stddev:", f.__name__)
    for i in trange(iters):
        # sample w replacement
        bootstrap = f(rnd.choices(xs, k=len(xs)))
        res.append(bootstrap)
    return stddev(res)
def stderr_for_metric(metric):
    bootstrappable = [
        median,
        matthews_corrcoef,
        f1_score,
        perplexity,
        bleu,
        chrf,
        ter,
    ]

    if metric in bootstrappable:
        return lambda x: bootstrap_stderr(metric, x)

    stderr = {
        mean: mean_stderr,
        acc_all: acc_all_stderr
    }

    return stderr.get(metric, None)
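A minimal usage sketch for these helpers (standalone, not taken from the repository; the per-example values are made up and the snippet assumes the package is importable as lm_eval):

import lm_eval.metrics as metrics

values = [1.0, 0.0, 1.0, 1.0, 0.0, 1.0]               # hypothetical per-example scores
stderr_fn = metrics.stderr_for_metric(metrics.mean)   # resolves to mean_stderr
if stderr_fn is not None:
    print(metrics.mean(values), "+/-", stderr_fn(values))

# Metrics with no closed-form stderr (e.g. metrics.median) fall back to
# bootstrap_stderr, which resamples the values with replacement 10000 times.
median_stderr_fn = metrics.stderr_for_metric(metrics.median)
print(median_stderr_fn(values))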
black==20.8b1
best_download>=0.0.5
datasets>=1.2.1
click>=7.1
scikit-learn>=0.24.1
torch>=1.7
transformers>=4.1
sqlitedict==1.6.0
pytablewriter==0.58.0
sacrebleu==1.5.0
pycountry==20.7.3
numexpr==2.7.2
lm_dataformat>=0.0.19
\ No newline at end of file
.
\ No newline at end of file
@@ -19,4 +19,18 @@ setuptools.setup(
"Operating System :: OS Independent", "Operating System :: OS Independent",
], ],
python_requires='>=3.6', python_requires='>=3.6',
    install_requires=[
        "black==20.8b1",
        "best_download>=0.0.5",
        "datasets>=1.2.1",
        "click>=7.1",
        "scikit-learn>=0.24.1",
        "torch>=1.7",
        "transformers>=4.1",
        "sqlitedict==1.6.0",
        "pytablewriter==0.58.0",
        "sacrebleu==1.5.0",
        "pycountry==20.7.3",
        "numexpr==2.7.2",
    ]
)
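Read together, the two hunks above move the dependency pins out of requirements.txt and into install_requires in setup.py; requirements.txt is reduced to the single line ".", which tells pip to install the package itself and thereby pull in the declared dependencies.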
@@ -40,5 +40,5 @@ def test_evaluator(taskname, Task):
        return res

    lm.loglikelihood = ll_fn
    lm.loglikelihood_rolling = ll_perp_fn
    evaluator.evaluate(lm, task_dict, False, 0, 10)
import pytest
import lm_eval.metrics as metrics
def test_bootstrapping():
    arr = list(range(100))
    expected = metrics.mean_stderr(arr)
    bootstrapped = metrics.bootstrap_stderr(metrics.mean, arr)
    assert bootstrapped == pytest.approx(expected)
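As a quick sanity check on the value this test targets (using the population-variance definition of stddev from the metrics module): for arr = list(range(100)) the standard deviation is sqrt((100**2 - 1) / 12) ≈ 28.866, so mean_stderr(arr) ≈ 28.866 / 10 ≈ 2.887. A bootstrap with 10000 resamples only approximates this to within roughly a percent, so the default pytest.approx tolerance may be tight here; a looser bound such as rel=1e-2 is a plausible adjustment, though that is an assumption rather than something taken from the commit.

# Illustrative check of the analytic value (not part of the test file).
import math
arr = list(range(100))
sigma = math.sqrt((100 ** 2 - 1) / 12)   # population stddev of 0..99
print(sigma / math.sqrt(len(arr)))       # ~2.8866, the expected mean stderr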