Unverified Commit a2cada5d authored by Jonathan Tow's avatar Jonathan Tow Committed by GitHub
Browse files

Merge pull request #317 from EleutherAI/Mistobaan/add-pre-commit

Add pre-commit
parents 7a038118 83507c4b
...@@ -6,7 +6,7 @@ from itertools import islice ...@@ -6,7 +6,7 @@ from itertools import islice
@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items()) @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_basic_interface(taskname, task_class): def test_basic_interface(taskname, task_class):
print('Evaluating task', taskname) print("Evaluating task", taskname)
# dl = task_class.download # dl = task_class.download
# task_class.download = MagicMock() # task_class.download = MagicMock()
task = task_class() task = task_class()
...@@ -42,7 +42,7 @@ def test_basic_interface(taskname, task_class): ...@@ -42,7 +42,7 @@ def test_basic_interface(taskname, task_class):
reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr] reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2] reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
assert reqs == reqs2 assert reqs == reqs2
if task.has_test_docs(): if task.has_test_docs():
...@@ -53,7 +53,7 @@ def test_basic_interface(taskname, task_class): ...@@ -53,7 +53,7 @@ def test_basic_interface(taskname, task_class):
reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr] reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2] reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
assert reqs == reqs2 assert reqs == reqs2
if task.has_training_docs(): if task.has_training_docs():
...@@ -64,13 +64,13 @@ def test_basic_interface(taskname, task_class): ...@@ -64,13 +64,13 @@ def test_basic_interface(taskname, task_class):
reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr] reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2] reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
assert reqs == reqs2 assert reqs == reqs2
@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items()) @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_documents_and_requests(taskname, task_class): def test_documents_and_requests(taskname, task_class):
print('Evaluating task', taskname) print("Evaluating task", taskname)
task = task_class() task = task_class()
fns = [] fns = []
if task.has_training_docs(): if task.has_training_docs():
...@@ -83,21 +83,21 @@ def test_documents_and_requests(taskname, task_class): ...@@ -83,21 +83,21 @@ def test_documents_and_requests(taskname, task_class):
for fn in fns: for fn in fns:
# print(list(islice(fn(), 10))) # print(list(islice(fn(), 10)))
for doc in islice(fn(), 10): for doc in islice(fn(), 10):
txt = task.doc_to_text(doc) txt = task.doc_to_text(doc)
tgt = task.doc_to_target(doc) tgt = task.doc_to_target(doc)
assert isinstance(txt, str) assert isinstance(txt, str)
assert isinstance(tgt, str) assert isinstance(tgt, str)
# space convention # space convention
# allow txt to have length 0 for perplexity-like tasks since the model tacks an <|endoftext|> on # allow txt to have length 0 for perplexity-like tasks since the model tacks an <|endoftext|> on
if len(txt) != 0: if len(txt) != 0:
assert txt[-1] != ' ' assert txt[-1] != " "
assert tgt[0] == ' ' or txt[-1] == '\n' assert tgt[0] == " " or txt[-1] == "\n"
reqs = task.construct_requests(doc, txt) reqs = task.construct_requests(doc, txt)
# construct_requests can return just one request # construct_requests can return just one request
if not isinstance(reqs, (list, tuple)): if not isinstance(reqs, (list, tuple)):
reqs = [reqs] reqs = [reqs]
......
...@@ -5,8 +5,14 @@ from lm_eval.utils import get_rolling_token_windows, make_disjoint_window ...@@ -5,8 +5,14 @@ from lm_eval.utils import get_rolling_token_windows, make_disjoint_window
def test_get_rolling_token_windows_v1(): def test_get_rolling_token_windows_v1():
gold = [ gold = [
([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), ([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
([9, 10, 11, 12, 13, 14, 15, 16, 17, 18], [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]), (
([19, 20, 21, 22, 23, 24, 25, 26, 27, 28], [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]), [9, 10, 11, 12, 13, 14, 15, 16, 17, 18],
[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
),
(
[19, 20, 21, 22, 23, 24, 25, 26, 27, 28],
[20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
),
([23, 24, 25, 26, 27, 28, 29, 30, 31, 32], [30, 31, 32, 33]), ([23, 24, 25, 26, 27, 28, 29, 30, 31, 32], [30, 31, 32, 33]),
] ]
x = list(range(34)) x = list(range(34))
...@@ -123,7 +129,6 @@ def test_get_rolling_token_windows_v4(): ...@@ -123,7 +129,6 @@ def test_get_rolling_token_windows_v4():
([17, 18, 19, 20, 21, 22, 23, 24, 25, 26], [27]), ([17, 18, 19, 20, 21, 22, 23, 24, 25, 26], [27]),
([18, 19, 20, 21, 22, 23, 24, 25, 26, 27], [28]), ([18, 19, 20, 21, 22, 23, 24, 25, 26, 27], [28]),
([19, 20, 21, 22, 23, 24, 25, 26, 27, 28], [29]), ([19, 20, 21, 22, 23, 24, 25, 26, 27, 28], [29]),
] ]
x = list(range(30)) x = list(range(30))
generator = get_rolling_token_windows( generator = get_rolling_token_windows(
...@@ -145,8 +150,14 @@ def test_get_rolling_token_windows_v4(): ...@@ -145,8 +150,14 @@ def test_get_rolling_token_windows_v4():
def test_get_rolling_token_windows_v5(): def test_get_rolling_token_windows_v5():
gold = [ gold = [
([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), ([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
([9, 10, 11, 12, 13, 14, 15, 16, 17, 18], [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]), (
([19, 20, 21, 22, 23, 24, 25, 26, 27, 28], [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]), [9, 10, 11, 12, 13, 14, 15, 16, 17, 18],
[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
),
(
[19, 20, 21, 22, 23, 24, 25, 26, 27, 28],
[20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
),
] ]
x = list(range(30)) x = list(range(30))
generator = get_rolling_token_windows( generator = get_rolling_token_windows(
...@@ -203,5 +214,8 @@ def test_get_rolling_token_windows_empty(): ...@@ -203,5 +214,8 @@ def test_get_rolling_token_windows_empty():
def test_make_disjoint_window(): def test_make_disjoint_window():
assert make_disjoint_window(([1,2,3,4,5], [2,3,4,5,6])) == ([1], [2,3,4,5,6]) assert make_disjoint_window(([1, 2, 3, 4, 5], [2, 3, 4, 5, 6])) == (
assert make_disjoint_window(([1,2,3,4,5], [4,5,6])) == ([1,2,3], [4,5,6]) [1],
\ No newline at end of file [2, 3, 4, 5, 6],
)
assert make_disjoint_window(([1, 2, 3, 4, 5], [4, 5, 6])) == ([1, 2, 3], [4, 5, 6])
...@@ -16,13 +16,14 @@ def assert_target(name, ob): ...@@ -16,13 +16,14 @@ def assert_target(name, ob):
fname = f"tests/testdata/{name}.json" fname = f"tests/testdata/{name}.json"
if os.path.exists(fname): if os.path.exists(fname):
with open(fname) as fh: with open(fname) as fh:
# Use relative tolerance of 1e-5 and absolute tolerance of 1e-8 # Use relative tolerance of 1e-5 and absolute tolerance of 1e-8
# assuming most metrics work on `float32` values, which is the common # assuming most metrics work on `float32` values, which is the common
# default floating type across popular libraries (PyTorch, Tensorflow, and JAX). # default floating type across popular libraries (PyTorch, Tensorflow, and JAX).
assert flatten(json.load(fh)) == pytest.approx( assert flatten(json.load(fh)) == pytest.approx(
flatten(json.loads(json.dumps(ob, sort_keys=True))), rel=1e-5, abs=1e-8) flatten(json.loads(json.dumps(ob, sort_keys=True))), rel=1e-5, abs=1e-8
)
else: else:
with open(fname, 'w') as fh: with open(fname, "w") as fh:
json.dump(ob, fh, sort_keys=True) json.dump(ob, fh, sort_keys=True)
...@@ -30,14 +31,23 @@ def assert_target_hashed(name, ob): ...@@ -30,14 +31,23 @@ def assert_target_hashed(name, ob):
fname = f"tests/testdata/{name}" fname = f"tests/testdata/{name}"
if os.path.exists(fname): if os.path.exists(fname):
with open(fname) as fh: with open(fname) as fh:
assert fh.read() == hashlib.sha256(json.dumps(ob, sort_keys=True).encode('utf-8')).hexdigest() assert (
fh.read()
== hashlib.sha256(
json.dumps(ob, sort_keys=True).encode("utf-8")
).hexdigest()
)
else: else:
with open(fname, 'w') as fh: with open(fname, "w") as fh:
fh.write(hashlib.sha256(json.dumps(ob, sort_keys=True).encode('utf-8')).hexdigest()) fh.write(
hashlib.sha256(
json.dumps(ob, sort_keys=True).encode("utf-8")
).hexdigest()
)
# from https://stackoverflow.com/a/6027615 # from https://stackoverflow.com/a/6027615
def flatten(d, parent_key='', sep='.'): def flatten(d, parent_key="", sep="."):
items = [] items = []
for k, v in d.items(): for k, v in d.items():
new_key = parent_key + sep + k if parent_key else k new_key = parent_key + sep + k if parent_key else k
...@@ -47,24 +57,26 @@ def flatten(d, parent_key='', sep='.'): ...@@ -47,24 +57,26 @@ def flatten(d, parent_key='', sep='.'):
items.append((new_key, v)) items.append((new_key, v))
return dict(items) return dict(items)
# make sure eval results for a task version are stable # make sure eval results for a task version are stable
@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items()) @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_versions_stable(taskname, task_class): def test_versions_stable(taskname, task_class):
task_dict = tasks.get_task_dict([taskname]) task_dict = tasks.get_task_dict([taskname])
lm = models.get_model('dummy')() lm = models.get_model("dummy")()
def ll_fn(reqs): def ll_fn(reqs):
for ctx, cont in reqs: for ctx, cont in reqs:
if len(ctx) == 0: if len(ctx) == 0:
continue continue
# space convention # space convention
assert ctx[-1] != ' ' assert ctx[-1] != " "
assert cont[0] == ' ' or ctx[-1] == '\n' assert cont[0] == " " or ctx[-1] == "\n"
assert_target_hashed(f"{taskname}-v{task_class.VERSION}-loglikelihood", reqs) assert_target_hashed(f"{taskname}-v{task_class.VERSION}-loglikelihood", reqs)
res = [] res = []
random.seed(42) random.seed(42)
for _ in reqs: for _ in reqs:
res.append((-random.random(), False)) res.append((-random.random(), False))
...@@ -72,10 +84,12 @@ def test_versions_stable(taskname, task_class): ...@@ -72,10 +84,12 @@ def test_versions_stable(taskname, task_class):
return res return res
def ll_perp_fn(reqs): def ll_perp_fn(reqs):
for string, in reqs: for (string,) in reqs:
assert isinstance(string, str) assert isinstance(string, str)
assert_target_hashed(f"{taskname}-v{task_class.VERSION}-loglikelihood_rolling", reqs) assert_target_hashed(
f"{taskname}-v{task_class.VERSION}-loglikelihood_rolling", reqs
)
res = [] res = []
random.seed(42) random.seed(42)
...@@ -83,14 +97,14 @@ def test_versions_stable(taskname, task_class): ...@@ -83,14 +97,14 @@ def test_versions_stable(taskname, task_class):
res.append(-random.random()) res.append(-random.random())
return res return res
def greedy_until(reqs): def greedy_until(reqs):
res = [] res = []
assert_target_hashed(f"{taskname}-v{task_class.VERSION}-greedy_until", reqs) assert_target_hashed(f"{taskname}-v{task_class.VERSION}-greedy_until", reqs)
for ctx, _ in reqs: for ctx, _ in reqs:
res.append("lol") res.append("lol")
assert ctx.strip() != '' assert ctx.strip() != ""
return res return res
...@@ -100,12 +114,12 @@ def test_versions_stable(taskname, task_class): ...@@ -100,12 +114,12 @@ def test_versions_stable(taskname, task_class):
limit = None limit = None
result = evaluator.evaluate( result = evaluator.evaluate(
lm=lm, lm=lm,
task_dict=task_dict, task_dict=task_dict,
num_fewshot=0, num_fewshot=0,
limit=limit, limit=limit,
bootstrap_iters=10, bootstrap_iters=10,
description_dict=None description_dict=None,
) )
assert_target(f"{taskname}-v{task_class.VERSION}-res", result) assert_target(f"{taskname}-v{task_class.VERSION}-res", result)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment