Unverified Commit a2cada5d authored by Jonathan Tow's avatar Jonathan Tow Committed by GitHub
Browse files

Merge pull request #317 from EleutherAI/Mistobaan/add-pre-commit

Add pre-commit
parents 7a038118 83507c4b
......@@ -6,7 +6,7 @@ from itertools import islice
@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_basic_interface(taskname, task_class):
print('Evaluating task', taskname)
print("Evaluating task", taskname)
# dl = task_class.download
# task_class.download = MagicMock()
task = task_class()
......@@ -42,7 +42,7 @@ def test_basic_interface(taskname, task_class):
reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
assert reqs == reqs2
if task.has_test_docs():
......@@ -53,7 +53,7 @@ def test_basic_interface(taskname, task_class):
reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
assert reqs == reqs2
if task.has_training_docs():
......@@ -64,13 +64,13 @@ def test_basic_interface(taskname, task_class):
reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
assert reqs == reqs2
@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_documents_and_requests(taskname, task_class):
print('Evaluating task', taskname)
print("Evaluating task", taskname)
task = task_class()
fns = []
if task.has_training_docs():
......@@ -83,21 +83,21 @@ def test_documents_and_requests(taskname, task_class):
for fn in fns:
# print(list(islice(fn(), 10)))
for doc in islice(fn(), 10):
txt = task.doc_to_text(doc)
tgt = task.doc_to_target(doc)
assert isinstance(txt, str)
assert isinstance(tgt, str)
# space convention
# allow txt to have length 0 for perplexity-like tasks since the model tacks an <|endoftext|> on
if len(txt) != 0:
assert txt[-1] != ' '
assert tgt[0] == ' ' or txt[-1] == '\n'
assert txt[-1] != " "
assert tgt[0] == " " or txt[-1] == "\n"
reqs = task.construct_requests(doc, txt)
# construct_requests can return just one request
if not isinstance(reqs, (list, tuple)):
reqs = [reqs]
......
......@@ -5,8 +5,14 @@ from lm_eval.utils import get_rolling_token_windows, make_disjoint_window
def test_get_rolling_token_windows_v1():
gold = [
([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
([9, 10, 11, 12, 13, 14, 15, 16, 17, 18], [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]),
([19, 20, 21, 22, 23, 24, 25, 26, 27, 28], [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]),
(
[9, 10, 11, 12, 13, 14, 15, 16, 17, 18],
[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
),
(
[19, 20, 21, 22, 23, 24, 25, 26, 27, 28],
[20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
),
([23, 24, 25, 26, 27, 28, 29, 30, 31, 32], [30, 31, 32, 33]),
]
x = list(range(34))
......@@ -123,7 +129,6 @@ def test_get_rolling_token_windows_v4():
([17, 18, 19, 20, 21, 22, 23, 24, 25, 26], [27]),
([18, 19, 20, 21, 22, 23, 24, 25, 26, 27], [28]),
([19, 20, 21, 22, 23, 24, 25, 26, 27, 28], [29]),
]
x = list(range(30))
generator = get_rolling_token_windows(
......@@ -145,8 +150,14 @@ def test_get_rolling_token_windows_v4():
def test_get_rolling_token_windows_v5():
gold = [
([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
([9, 10, 11, 12, 13, 14, 15, 16, 17, 18], [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]),
([19, 20, 21, 22, 23, 24, 25, 26, 27, 28], [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]),
(
[9, 10, 11, 12, 13, 14, 15, 16, 17, 18],
[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
),
(
[19, 20, 21, 22, 23, 24, 25, 26, 27, 28],
[20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
),
]
x = list(range(30))
generator = get_rolling_token_windows(
......@@ -203,5 +214,8 @@ def test_get_rolling_token_windows_empty():
def test_make_disjoint_window():
assert make_disjoint_window(([1,2,3,4,5], [2,3,4,5,6])) == ([1], [2,3,4,5,6])
assert make_disjoint_window(([1,2,3,4,5], [4,5,6])) == ([1,2,3], [4,5,6])
\ No newline at end of file
assert make_disjoint_window(([1, 2, 3, 4, 5], [2, 3, 4, 5, 6])) == (
[1],
[2, 3, 4, 5, 6],
)
assert make_disjoint_window(([1, 2, 3, 4, 5], [4, 5, 6])) == ([1, 2, 3], [4, 5, 6])
......@@ -16,13 +16,14 @@ def assert_target(name, ob):
fname = f"tests/testdata/{name}.json"
if os.path.exists(fname):
with open(fname) as fh:
# Use relative tolerance of 1e-5 and absolute tolerance of 1e-8
# assuming most metrics work on `float32` values, which is the common
# Use relative tolerance of 1e-5 and absolute tolerance of 1e-8
# assuming most metrics work on `float32` values, which is the common
# default floating type across popular libraries (PyTorch, Tensorflow, and JAX).
assert flatten(json.load(fh)) == pytest.approx(
flatten(json.loads(json.dumps(ob, sort_keys=True))), rel=1e-5, abs=1e-8)
flatten(json.loads(json.dumps(ob, sort_keys=True))), rel=1e-5, abs=1e-8
)
else:
with open(fname, 'w') as fh:
with open(fname, "w") as fh:
json.dump(ob, fh, sort_keys=True)
......@@ -30,14 +31,23 @@ def assert_target_hashed(name, ob):
fname = f"tests/testdata/{name}"
if os.path.exists(fname):
with open(fname) as fh:
assert fh.read() == hashlib.sha256(json.dumps(ob, sort_keys=True).encode('utf-8')).hexdigest()
assert (
fh.read()
== hashlib.sha256(
json.dumps(ob, sort_keys=True).encode("utf-8")
).hexdigest()
)
else:
with open(fname, 'w') as fh:
fh.write(hashlib.sha256(json.dumps(ob, sort_keys=True).encode('utf-8')).hexdigest())
with open(fname, "w") as fh:
fh.write(
hashlib.sha256(
json.dumps(ob, sort_keys=True).encode("utf-8")
).hexdigest()
)
# from https://stackoverflow.com/a/6027615
def flatten(d, parent_key='', sep='.'):
def flatten(d, parent_key="", sep="."):
items = []
for k, v in d.items():
new_key = parent_key + sep + k if parent_key else k
......@@ -47,24 +57,26 @@ def flatten(d, parent_key='', sep='.'):
items.append((new_key, v))
return dict(items)
# make sure eval results for a task version are stable
@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_versions_stable(taskname, task_class):
task_dict = tasks.get_task_dict([taskname])
lm = models.get_model('dummy')()
lm = models.get_model("dummy")()
def ll_fn(reqs):
for ctx, cont in reqs:
if len(ctx) == 0:
continue
# space convention
assert ctx[-1] != ' '
assert cont[0] == ' ' or ctx[-1] == '\n'
assert ctx[-1] != " "
assert cont[0] == " " or ctx[-1] == "\n"
assert_target_hashed(f"{taskname}-v{task_class.VERSION}-loglikelihood", reqs)
res = []
random.seed(42)
for _ in reqs:
res.append((-random.random(), False))
......@@ -72,10 +84,12 @@ def test_versions_stable(taskname, task_class):
return res
def ll_perp_fn(reqs):
for string, in reqs:
for (string,) in reqs:
assert isinstance(string, str)
assert_target_hashed(f"{taskname}-v{task_class.VERSION}-loglikelihood_rolling", reqs)
assert_target_hashed(
f"{taskname}-v{task_class.VERSION}-loglikelihood_rolling", reqs
)
res = []
random.seed(42)
......@@ -83,14 +97,14 @@ def test_versions_stable(taskname, task_class):
res.append(-random.random())
return res
def greedy_until(reqs):
res = []
assert_target_hashed(f"{taskname}-v{task_class.VERSION}-greedy_until", reqs)
for ctx, _ in reqs:
res.append("lol")
assert ctx.strip() != ''
assert ctx.strip() != ""
return res
......@@ -100,12 +114,12 @@ def test_versions_stable(taskname, task_class):
limit = None
result = evaluator.evaluate(
lm=lm,
task_dict=task_dict,
num_fewshot=0,
limit=limit,
bootstrap_iters=10,
description_dict=None
lm=lm,
task_dict=task_dict,
num_fewshot=0,
limit=limit,
bootstrap_iters=10,
description_dict=None,
)
assert_target(f"{taskname}-v{task_class.VERSION}-res", result)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment