"vscode:/vscode.git/clone" did not exist on "8afc3c079aae1c19d916cd31b2acf16bf023bece"
Commit 1f8a8c1d authored by jon-tow's avatar jon-tow
Browse files

Merge branch 'master' of https://github.com/EleutherAI/lm-evaluation-harness into remove-dataset

parents b4c0275d b0acb337
...@@ -16,13 +16,14 @@ def assert_target(name, ob): ...@@ -16,13 +16,14 @@ def assert_target(name, ob):
fname = f"tests/testdata/{name}.json" fname = f"tests/testdata/{name}.json"
if os.path.exists(fname): if os.path.exists(fname):
with open(fname) as fh: with open(fname) as fh:
# Use relative tolerance of 1e-5 and absolute tolerance of 1e-8 # Use relative tolerance of 1e-5 and absolute tolerance of 1e-8
# assuming most metrics work on `float32` values, which is the common # assuming most metrics work on `float32` values, which is the common
# default floating type across popular libraries (PyTorch, Tensorflow, and JAX). # default floating type across popular libraries (PyTorch, Tensorflow, and JAX).
assert flatten(json.load(fh)) == pytest.approx( assert flatten(json.load(fh)) == pytest.approx(
flatten(json.loads(json.dumps(ob, sort_keys=True))), rel=1e-5, abs=1e-8) flatten(json.loads(json.dumps(ob, sort_keys=True))), rel=1e-5, abs=1e-8
)
else: else:
with open(fname, 'w') as fh: with open(fname, "w") as fh:
json.dump(ob, fh, sort_keys=True) json.dump(ob, fh, sort_keys=True)
...@@ -30,14 +31,23 @@ def assert_target_hashed(name, ob): ...@@ -30,14 +31,23 @@ def assert_target_hashed(name, ob):
fname = f"tests/testdata/{name}" fname = f"tests/testdata/{name}"
if os.path.exists(fname): if os.path.exists(fname):
with open(fname) as fh: with open(fname) as fh:
assert fh.read() == hashlib.sha256(json.dumps(ob, sort_keys=True).encode('utf-8')).hexdigest() assert (
fh.read()
== hashlib.sha256(
json.dumps(ob, sort_keys=True).encode("utf-8")
).hexdigest()
)
else: else:
with open(fname, 'w') as fh: with open(fname, "w") as fh:
fh.write(hashlib.sha256(json.dumps(ob, sort_keys=True).encode('utf-8')).hexdigest()) fh.write(
hashlib.sha256(
json.dumps(ob, sort_keys=True).encode("utf-8")
).hexdigest()
)
# from https://stackoverflow.com/a/6027615 # from https://stackoverflow.com/a/6027615
def flatten(d, parent_key='', sep='.'): def flatten(d, parent_key="", sep="."):
items = [] items = []
for k, v in d.items(): for k, v in d.items():
new_key = parent_key + sep + k if parent_key else k new_key = parent_key + sep + k if parent_key else k
...@@ -47,24 +57,26 @@ def flatten(d, parent_key='', sep='.'): ...@@ -47,24 +57,26 @@ def flatten(d, parent_key='', sep='.'):
items.append((new_key, v)) items.append((new_key, v))
return dict(items) return dict(items)
# make sure eval results for a task version are stable # make sure eval results for a task version are stable
@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items()) @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_versions_stable(taskname, task_class): def test_versions_stable(taskname, task_class):
task_dict = tasks.get_task_dict([taskname]) task_dict = tasks.get_task_dict([taskname])
lm = models.get_model('dummy')() lm = models.get_model("dummy")()
def ll_fn(reqs): def ll_fn(reqs):
for ctx, cont in reqs: for ctx, cont in reqs:
if len(ctx) == 0: if len(ctx) == 0:
continue continue
# space convention # space convention
assert ctx[-1] != ' ' assert ctx[-1] != " "
assert cont[0] == ' ' or ctx[-1] == '\n' assert cont[0] == " " or ctx[-1] == "\n"
assert_target_hashed(f"{taskname}-v{task_class.VERSION}-loglikelihood", reqs) assert_target_hashed(f"{taskname}-v{task_class.VERSION}-loglikelihood", reqs)
res = [] res = []
random.seed(42) random.seed(42)
for _ in reqs: for _ in reqs:
res.append((-random.random(), False)) res.append((-random.random(), False))
...@@ -72,10 +84,12 @@ def test_versions_stable(taskname, task_class): ...@@ -72,10 +84,12 @@ def test_versions_stable(taskname, task_class):
return res return res
def ll_perp_fn(reqs): def ll_perp_fn(reqs):
for string, in reqs: for (string,) in reqs:
assert isinstance(string, str) assert isinstance(string, str)
assert_target_hashed(f"{taskname}-v{task_class.VERSION}-loglikelihood_rolling", reqs) assert_target_hashed(
f"{taskname}-v{task_class.VERSION}-loglikelihood_rolling", reqs
)
res = [] res = []
random.seed(42) random.seed(42)
...@@ -83,14 +97,14 @@ def test_versions_stable(taskname, task_class): ...@@ -83,14 +97,14 @@ def test_versions_stable(taskname, task_class):
res.append(-random.random()) res.append(-random.random())
return res return res
def greedy_until(reqs): def greedy_until(reqs):
res = [] res = []
assert_target_hashed(f"{taskname}-v{task_class.VERSION}-greedy_until", reqs) assert_target_hashed(f"{taskname}-v{task_class.VERSION}-greedy_until", reqs)
for ctx, _ in reqs: for ctx, _ in reqs:
res.append("lol") res.append("lol")
assert ctx.strip() != '' assert ctx.strip() != ""
return res return res
...@@ -100,12 +114,12 @@ def test_versions_stable(taskname, task_class): ...@@ -100,12 +114,12 @@ def test_versions_stable(taskname, task_class):
limit = None limit = None
result = evaluator.evaluate( result = evaluator.evaluate(
lm=lm, lm=lm,
task_dict=task_dict, task_dict=task_dict,
num_fewshot=0, num_fewshot=0,
limit=limit, limit=limit,
bootstrap_iters=10, bootstrap_iters=10,
description_dict=None description_dict=None,
) )
assert_target(f"{taskname}-v{task_class.VERSION}-res", result) assert_target(f"{taskname}-v{task_class.VERSION}-res", result)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment