Commit 7a39d68b authored by Leo Gao

Make bootstrap_iters configurable and fix some tests

parent 0fa316f8
@@ -4,7 +4,7 @@ import random
 import lm_eval.metrics

-def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
+def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_iters=100000):
     # TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces
     task_dict_items = [(name, task) for name, task in task_dict.items() if(task.has_validation_docs() or task.has_test_docs())]
@@ -91,7 +91,7 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
         task = task_dict[task_name]
         results[task_name][metric] = task.aggregation()[metric](items)

-        stderr = lm_eval.metrics.stderr_for_metric(task.aggregation()[metric])
+        stderr = lm_eval.metrics.stderr_for_metric(task.aggregation()[metric], bootstrap_iters=bootstrap_iters)
         if stderr is not None:
             results[task_name][metric + "_stderr"] = stderr(items)
...
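Downstream of this change, callers can trade standard-error precision for runtime. A minimal usage sketch, assuming the module is importable as `lm_eval.evaluator` and that `lm` and `task_dict` have already been constructed (as in the test change further down); the default of 100000 reproduces the previously hard-coded behaviour:

from lm_eval import evaluator

# `lm` and `task_dict` are assumed to exist already (model wrapper + {task_name: Task} mapping).

# Default: 100000 bootstrap iterations per bootstrappable metric, same as the old hard-coded value.
results = evaluator.evaluate(lm, task_dict, provide_description=False, num_fewshot=0, limit=10)

# Cheap run for smoke tests / CI: only 10 resampling iterations.
quick_results = evaluator.evaluate(lm, task_dict, provide_description=False, num_fewshot=0, limit=10, bootstrap_iters=10)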
@@ -195,7 +195,7 @@ class _bootstrap_internal:
         return res

-def bootstrap_stderr(f, xs, iters=100000):
+def bootstrap_stderr(f, xs, iters):
     import multiprocessing as mp
     pool = mp.Pool(mp.cpu_count())
     # this gives a biased estimate of the stderr (i.e w/ the mean, it gives something
@@ -205,9 +205,10 @@ def bootstrap_stderr(f, xs, iters=100000):
     # that would be ad-hoc and I can't prove that that would actually be an unbiased estimator)
     # Thankfully, shouldn't matter because our samples are pretty big usually anyways
     res = []
+    chunk_size = min(1000, iters)
     from tqdm import tqdm
     print("bootstrapping for stddev:", f.__name__)
-    for bootstrap in tqdm(pool.imap(_bootstrap_internal(f, 1000), [(i, xs) for i in range(iters // 1000)]), total=iters // 1000):
+    for bootstrap in tqdm(pool.imap(_bootstrap_internal(f, chunk_size), [(i, xs) for i in range(iters // chunk_size)]), total=iters // chunk_size):
         # sample w replacement
         res.extend(bootstrap)
@@ -215,7 +216,7 @@ def bootstrap_stderr(f, xs, iters=100000):
     return sample_stddev(res)

-def stderr_for_metric(metric):
+def stderr_for_metric(metric, bootstrap_iters):
     bootstrappable = [
         median,
         matthews_corrcoef,
@@ -227,7 +228,7 @@ def stderr_for_metric(metric):
     ]
     if metric in bootstrappable:
-        return lambda x: bootstrap_stderr(metric, x)
+        return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)
     stderr = {
         mean: mean_stderr,
...
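For reference, a minimal single-process sketch of what the chunked bootstrap computes. The real `bootstrap_stderr` fans the chunks out over a multiprocessing pool via `_bootstrap_internal` and uses its own `sample_stddev`; here `statistics.stdev` and `statistics.mean` stand in. The point of `chunk_size = min(1000, iters)` is that `iters // chunk_size` no longer rounds down to zero chunks when a caller asks for fewer than 1000 iterations (e.g. `bootstrap_iters=10` in the tests):

import random
import statistics

def bootstrap_stderr_sketch(f, xs, iters):
    # Draw bootstrap resamples in chunks; clamping the chunk size to `iters`
    # guarantees at least one chunk even for very small iteration counts.
    chunk_size = min(1000, iters)
    res = []
    for _ in range(iters // chunk_size):
        # sample with replacement, apply the metric to each resample
        res.extend(f(random.choices(xs, k=len(xs))) for _ in range(chunk_size))
    return statistics.stdev(res)

xs = [random.gauss(0.0, 1.0) for _ in range(500)]
print(bootstrap_stderr_sketch(statistics.mean, xs, iters=10))    # the old hard-coded `iters // 1000` would have run 0 chunks here
print(bootstrap_stderr_sketch(statistics.mean, xs, iters=2000))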
@@ -41,4 +41,4 @@ def test_evaluator(taskname, Task):
     lm.loglikelihood = ll_fn
     lm.loglikelihood_rolling = ll_perp_fn

-    evaluator.evaluate(lm, task_dict, False, 0, 10)
+    evaluator.evaluate(lm, task_dict, False, 0, 10, bootstrap_iters=10)
@@ -41,19 +41,19 @@ def test_gpt2():
     targets = [-61.60536193847656, -56.57843780517578, -62.131004333496094, -9.799489974975586, -153.96334838867188, -341.222900390625, -731.1475830078125, -61.60536193847656, -8.682319641113281]

     for (pred, _), tgt in zip(vals, targets):
-        assert pred == pytest.approx(tgt, abs=1e-3)
+        assert pred == pytest.approx(tgt, rel=1e-3)

 def test_gpt2_perplexity():
     gpt2 = models.get_model('gpt2').create_from_arg_string("device=cpu")
     test_string = "We study empirical scaling laws for language model performance on the cross-entropy loss."
-    perplexity = gpt2.loglikelihood_perplexity([(test_string,)])[0]
+    perplexity = gpt2.loglikelihood_rolling([(test_string,)])[0]
     tgt = sum([-4.9599953, -8.069298, -8.308624, -10.178513, -8.906924, -1.9318912, -7.745445, -7.146077, -5.2072, -3.5882986, -1.9957212, -8.044922, -0.20841774, -5.1096807, -0.099879116, -8.888423, -4.6180487])
-    assert perplexity == pytest.approx(tgt, abs=1e-3)
+    assert perplexity == pytest.approx(tgt, rel=1e-3)

     # Hack: modify gpt2 to have shorter context length to induce rolling windows
     gpt2.max_length = 5
-    perplexity = gpt2.loglikelihood_perplexity([(test_string,)])[0]
+    perplexity = gpt2.loglikelihood_rolling([(test_string,)])[0]
     tgt = sum([-4.96001, -8.069275, -8.308612, -10.178482, -8.90691, -4.037338, -8.09261, -11.662385, -10.206891, -4.425003, -2.2563353, -7.909143, -1.9304147, -7.3610134, -2.3120654, -7.3229, -2.1643813])
-    assert perplexity == pytest.approx(tgt, abs=1e-3)
+    assert perplexity == pytest.approx(tgt, rel=1e-3)
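The switch from `abs=1e-3` to `rel=1e-3` matters because these targets are sums of many token log-likelihoods with magnitudes in the hundreds, where small numerical differences across environments can exceed an absolute 1e-3 while staying far inside 0.1% relative error. A quick self-contained illustration using one of the targets above:

import pytest

tgt = -731.1475830078125            # one of the loglikelihood targets above
measured = tgt * (1 + 5e-6)         # ~0.0037 absolute drift, 0.0005% relative drift

assert measured == pytest.approx(tgt, rel=1e-3)   # passes: well within 0.1% of the target
assert measured != pytest.approx(tgt, abs=1e-3)   # the old absolute tolerance would have rejected it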
@@ -71,8 +71,10 @@ def test_documents_and_requests(taskname, Task):
     assert isinstance(tgt, str)

     # space convention
-    assert txt[-1] != ' '
-    assert tgt[0] == ' ' or txt[-1] == '\n'
+    # allow txt to have length 0 for perplexity-like tasks since the model tacks an <|endoftext|> on
+    if len(txt) != 0:
+        assert txt[-1] != ' '
+        assert tgt[0] == ' ' or txt[-1] == '\n'

     reqs = task.construct_requests(doc, txt)
...
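An illustration of the relaxed space-convention check with hypothetical `(txt, tgt)` pairs: ordinary tasks still must end the context without a trailing space and carry the separating whitespace on the target, while perplexity-like tasks may legitimately use an empty context:

cases = [
    ("Question: What is 2 + 2?\nAnswer:", " 4"),                             # hypothetical ordinary task: target carries the leading space
    ("", "A full document scored as a rolling loglikelihood."),              # hypothetical perplexity-like task: empty context
]

for txt, tgt in cases:
    if len(txt) != 0:
        assert txt[-1] != ' '
        assert tgt[0] == ' ' or txt[-1] == '\n'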