Commit 7a39d68b authored by Leo Gao

Make bootstrap_iters configurable and fix some tests

parent 0fa316f8
@@ -4,7 +4,7 @@ import random
 import lm_eval.metrics
-def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
+def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_iters=100000):
     # TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces
     task_dict_items = [(name, task) for name, task in task_dict.items() if(task.has_validation_docs() or task.has_test_docs())]
@@ -91,7 +91,7 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
         task = task_dict[task_name]
         results[task_name][metric] = task.aggregation()[metric](items)
-        stderr = lm_eval.metrics.stderr_for_metric(task.aggregation()[metric])
+        stderr = lm_eval.metrics.stderr_for_metric(task.aggregation()[metric], bootstrap_iters=bootstrap_iters)
         if stderr is not None:
             results[task_name][metric + "_stderr"] = stderr(items)
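For orientation, here is a minimal usage sketch of the new keyword. The setup mirrors the repo's tests (a CPU GPT-2 model); the get_task_dict helper and the "lambada" task name are assumptions about the surrounding library rather than part of this commit, and the tiny bootstrap_iters value simply trades stderr precision for speed.

import lm_eval.models as models
import lm_eval.tasks as tasks
import lm_eval.evaluator as evaluator

# Assumed setup: a CPU GPT-2 model and a task dict built from the library's
# task registry (get_task_dict and the task name are assumptions here).
lm = models.get_model("gpt2").create_from_arg_string("device=cpu")
task_dict = tasks.get_task_dict(["lambada"])

# provide_description=False, num_fewshot=0, limit=10 documents, and only
# 10 bootstrap resamples for the stderr estimates instead of the default 100000.
results = evaluator.evaluate(lm, task_dict, False, 0, 10, bootstrap_iters=10)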
@@ -195,7 +195,7 @@ class _bootstrap_internal:
         return res
-def bootstrap_stderr(f, xs, iters=100000):
+def bootstrap_stderr(f, xs, iters):
     import multiprocessing as mp
     pool = mp.Pool(mp.cpu_count())
     # this gives a biased estimate of the stderr (i.e w/ the mean, it gives something
@@ -205,9 +205,10 @@ def bootstrap_stderr(f, xs, iters=100000):
     # that would be ad-hoc and I can't prove that that would actually be an unbiased estimator)
     # Thankfully, shouldn't matter because our samples are pretty big usually anyways
     res = []
+    chunk_size = min(1000, iters)
     from tqdm import tqdm
     print("bootstrapping for stddev:", f.__name__)
-    for bootstrap in tqdm(pool.imap(_bootstrap_internal(f, 1000), [(i, xs) for i in range(iters // 1000)]), total=iters // 1000):
+    for bootstrap in tqdm(pool.imap(_bootstrap_internal(f, chunk_size), [(i, xs) for i in range(iters // chunk_size)]), total=iters // chunk_size):
         # sample w replacement
         res.extend(bootstrap)
@@ -215,7 +216,7 @@ def bootstrap_stderr(f, xs, iters=100000):
     return sample_stddev(res)
-def stderr_for_metric(metric):
+def stderr_for_metric(metric, bootstrap_iters):
     bootstrappable = [
         median,
         matthews_corrcoef,
@@ -227,7 +228,7 @@ def stderr_for_metric(metric):
     ]
     if metric in bootstrappable:
-        return lambda x: bootstrap_stderr(metric, x)
+        return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)
     stderr = {
         mean: mean_stderr,
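The chunking detail matters now that iters is caller-controlled: with the old hard-coded chunk of 1000, a small value such as bootstrap_iters=10 would give iters // 1000 == 0 chunks and therefore no bootstrap resamples at all, whereas chunk_size = min(1000, iters) guarantees at least one chunk. Below is a self-contained sketch of the same idea (simplified, not the repo's exact code; sample_stddev and the worker class are re-implemented here for illustration):

import multiprocessing as mp
import random
import statistics
from math import sqrt

def sample_stddev(arr):
    # unbiased sample standard deviation of the bootstrap replicates
    mu = sum(arr) / len(arr)
    return sqrt(sum((x - mu) ** 2 for x in arr) / (len(arr) - 1))

class _bootstrap_chunk:
    # picklable callable: each worker computes `n` bootstrap replicates of f
    def __init__(self, f, n):
        self.f = f
        self.n = n

    def __call__(self, v):
        i, xs = v
        rnd = random.Random(i)  # seed by chunk index for reproducibility
        return [self.f(rnd.choices(xs, k=len(xs))) for _ in range(self.n)]

def bootstrap_stderr_sketch(f, xs, iters):
    # min(1000, iters) keeps iters // chunk_size >= 1 even for tiny iters
    chunk_size = min(1000, iters)
    res = []
    with mp.Pool(mp.cpu_count()) as pool:
        for chunk in pool.imap(_bootstrap_chunk(f, chunk_size),
                               [(i, xs) for i in range(iters // chunk_size)]):
            res.extend(chunk)
    return sample_stddev(res)

if __name__ == "__main__":
    data = [random.gauss(0.0, 1.0) for _ in range(200)]
    # statistics.median is a plain module-level function, so it pickles cleanly
    print(bootstrap_stderr_sketch(statistics.median, data, iters=2000))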
@@ -41,4 +41,4 @@ def test_evaluator(taskname, Task):
     lm.loglikelihood = ll_fn
     lm.loglikelihood_rolling = ll_perp_fn
-    evaluator.evaluate(lm, task_dict, False, 0, 10)
+    evaluator.evaluate(lm, task_dict, False, 0, 10, bootstrap_iters=10)
@@ -41,19 +41,19 @@ def test_gpt2():
     targets = [-61.60536193847656, -56.57843780517578, -62.131004333496094, -9.799489974975586, -153.96334838867188, -341.222900390625, -731.1475830078125, -61.60536193847656, -8.682319641113281]
     for (pred, _), tgt in zip(vals, targets):
-        assert pred == pytest.approx(tgt, abs=1e-3)
+        assert pred == pytest.approx(tgt, rel=1e-3)
 def test_gpt2_perplexity():
     gpt2 = models.get_model('gpt2').create_from_arg_string("device=cpu")
     test_string = "We study empirical scaling laws for language model performance on the cross-entropy loss."
-    perplexity = gpt2.loglikelihood_perplexity([(test_string,)])[0]
+    perplexity = gpt2.loglikelihood_rolling([(test_string,)])[0]
     tgt = sum([-4.9599953, -8.069298, -8.308624, -10.178513, -8.906924, -1.9318912, -7.745445, -7.146077, -5.2072, -3.5882986, -1.9957212, -8.044922, -0.20841774, -5.1096807, -0.099879116, -8.888423, -4.6180487])
-    assert perplexity == pytest.approx(tgt, abs=1e-3)
+    assert perplexity == pytest.approx(tgt, rel=1e-3)
     # Hack: modify gpt2 to have shorter context length to induce rolling windows
     gpt2.max_length = 5
-    perplexity = gpt2.loglikelihood_perplexity([(test_string,)])[0]
+    perplexity = gpt2.loglikelihood_rolling([(test_string,)])[0]
     tgt = sum([-4.96001, -8.069275, -8.308612, -10.178482, -8.90691, -4.037338, -8.09261, -11.662385, -10.206891, -4.425003, -2.2563353, -7.909143, -1.9304147, -7.3610134, -2.3120654, -7.3229, -2.1643813])
-    assert perplexity == pytest.approx(tgt, abs=1e-3)
+    assert perplexity == pytest.approx(tgt, rel=1e-3)
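The switch from abs=1e-3 to rel=1e-3 is what lets targets of very different magnitudes (roughly -0.1 to -731 above) share one tolerance: pytest.approx with rel scales the allowed error with the target, while abs is a fixed window. A tiny illustration with made-up numbers:

import pytest

target = -731.1475830078125
measured = -731.5  # hypothetical value from a slightly different backend

assert measured == pytest.approx(target, rel=1e-3)           # tolerance ~0.73, passes
assert not (measured == pytest.approx(target, abs=1e-3))     # tolerance 0.001, fails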
@@ -71,8 +71,10 @@ def test_documents_and_requests(taskname, Task):
         assert isinstance(tgt, str)
         # space convention
-        assert txt[-1] != ' '
-        assert tgt[0] == ' ' or txt[-1] == '\n'
+        # allow txt to have length 0 for perplexity-like tasks since the model tacks an <|endoftext|> on
+        if len(txt) != 0:
+            assert txt[-1] != ' '
+            assert tgt[0] == ' ' or txt[-1] == '\n'
         reqs = task.construct_requests(doc, txt)
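Restated outside the test, the relaxed space-convention check looks like the sketch below (the helper name is mine, not from the repo): perplexity-style tasks produce an empty context and the model tacks an <|endoftext|> on, so the whitespace checks only apply when a context string exists.

def check_space_convention(txt, tgt):
    # Empty context is allowed for perplexity-like tasks; the model supplies <|endoftext|>.
    if len(txt) != 0:
        assert txt[-1] != ' ', "context must not end with a trailing space"
        assert tgt[0] == ' ' or txt[-1] == '\n', "target should start with a space unless context ends with a newline"

check_space_convention("Question: What is 2+2?\nAnswer:", " 4")             # normal task
check_space_convention("", "Some document text for a perplexity task.")     # empty context allowed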