Commit 377a1f45 authored by thomasw21's avatar thomasw21
Browse files

Merge remote-tracking branch 'origin/master' into thomas/fix_head_qa

parents 22c4124f f16e8b5c
......@@ -6,6 +6,7 @@ import pytest
import os
import json
import hashlib
import collections
os.makedirs("tests/testdata", exist_ok=True)
......@@ -15,11 +16,16 @@ def assert_target(name, ob):
fname = f"tests/testdata/{name}.json"
if os.path.exists(fname):
with open(fname) as fh:
assert json.load(fh) == json.loads(json.dumps(ob, sort_keys=True))
# Use relative tolerance of 1e-5 and absolute tolerance of 1e-8
# assuming most metrics work on `float32` values, which is the common
# default floating type across popular libraries (PyTorch, Tensorflow, and JAX).
assert flatten(json.load(fh)) == pytest.approx(
flatten(json.loads(json.dumps(ob, sort_keys=True))), rel=1e-5, abs=1e-8)
else:
with open(fname, 'w') as fh:
json.dump(ob, fh, sort_keys=True)
def assert_target_hashed(name, ob):
fname = f"tests/testdata/{name}"
if os.path.exists(fname):
......@@ -29,22 +35,34 @@ def assert_target_hashed(name, ob):
with open(fname, 'w') as fh:
fh.write(hashlib.sha256(json.dumps(ob, sort_keys=True).encode('utf-8')).hexdigest())
# from https://stackoverflow.com/a/6027615
def flatten(d, parent_key='', sep='.'):
items = []
for k, v in d.items():
new_key = parent_key + sep + k if parent_key else k
if isinstance(v, collections.MutableMapping):
items.extend(flatten(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)
# make sure eval results for a task version are stable
@pytest.mark.parametrize("taskname,Task", tasks.TASK_REGISTRY.items())
def test_versions_stable(taskname, Task):
@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_versions_stable(taskname, task_class):
task_dict = tasks.get_task_dict([taskname])
lm = models.get_model('dummy')()
def ll_fn(reqs):
for ctx, cont in reqs:
if len(ctx) == 0: continue
if len(ctx) == 0:
continue
# space convention
assert ctx[-1] != ' '
assert cont[0] == ' ' or ctx[-1] == '\n'
assert_target_hashed(f"{taskname}-v{Task.VERSION}-loglikelihood", reqs)
assert_target_hashed(f"{taskname}-v{task_class.VERSION}-loglikelihood", reqs)
res = []
random.seed(42)
......@@ -57,7 +75,7 @@ def test_versions_stable(taskname, Task):
for string, in reqs:
assert isinstance(string, str)
assert_target_hashed(f"{taskname}-v{Task.VERSION}-loglikelihood_rolling", reqs)
assert_target_hashed(f"{taskname}-v{task_class.VERSION}-loglikelihood_rolling", reqs)
res = []
random.seed(42)
......@@ -68,7 +86,7 @@ def test_versions_stable(taskname, Task):
def greedy_until(reqs):
res = []
assert_target_hashed(f"{taskname}-v{Task.VERSION}-greedy_until", reqs)
assert_target_hashed(f"{taskname}-v{task_class.VERSION}-greedy_until", reqs)
for ctx, _ in reqs:
res.append("lol")
......@@ -81,5 +99,5 @@ def test_versions_stable(taskname, Task):
lm.greedy_until = greedy_until
limit = None
res = evaluator.evaluate(lm, task_dict, False, 0, limit, bootstrap_iters=10)
assert_target(f"{taskname}-v{Task.VERSION}-res", res)
result = evaluator.evaluate(lm, task_dict, False, 0, limit, bootstrap_iters=10)
assert_target(f"{taskname}-v{task_class.VERSION}-res", result)
1a280973bbac2b7ac29dd64dddac474fb4749585f7de893483b4034814466c67
\ No newline at end of file
{"results": {"truthfulqa_gen": {"bleu_acc": 0.0, "bleu_acc_stderr": 0.0, "bleu_diff": 0.0, "bleu_diff_stderr": 0.0, "bleu_max": 0.0, "bleu_max_stderr": 0.0, "bleurt_acc": 0.835985312117503, "bleurt_acc_stderr": 0.012962704327492454, "bleurt_diff": 0.14077322143090107, "bleurt_diff_stderr": 0.005459888909582694, "bleurt_max": -1.4399358725752065, "bleurt_max_stderr": 0.0022126992369197133, "rouge1_acc": 0.0, "rouge1_acc_stderr": 0.0, "rouge1_diff": 0.0, "rouge1_diff_stderr": 0.0, "rouge1_max": 0.0, "rouge1_max_stderr": 0.0, "rouge2_acc": 0.0, "rouge2_acc_stderr": 0.0, "rouge2_diff": 0.0, "rouge2_diff_stderr": 0.0, "rouge2_max": 0.0, "rouge2_max_stderr": 0.0, "rougeL_acc": 0.0, "rougeL_acc_stderr": 0.0, "rougeL_diff": 0.0, "rougeL_diff_stderr": 0.0, "rougeL_max": 0.0, "rougeL_max_stderr": 0.0}}, "versions": {"truthfulqa_gen": 1}}
\ No newline at end of file
1e07020e9cf41d46ed65312eb39d2b8e6599673d4f0d6b67c0d0eba0efb493bb
\ No newline at end of file
{"results": {"truthfulqa_mc": {"mc1": 0.23255813953488372, "mc1_stderr": 0.01478915753108052, "mc2": 0.4462325560722362, "mc2_stderr": 0.004986523944692003}}, "versions": {"truthfulqa_mc": 1}}
\ No newline at end of file
8a0f81661d2ab2334bbc8031fac31c0c8882f1d9271dd51599d21dfdbb726dea
\ No newline at end of file
{"results": {"wnli": {"acc": 0.5633802816901409, "acc_stderr": 0.0592793555841297}}, "versions": {"wnli": 1}}
\ No newline at end of file
976a5cac4bdb724632eebd4cb9e522203ce3da8d5525288a597c86e80469f3f2
\ No newline at end of file
{"results": {"blimp_adjunct_island": {"acc": 0.485, "acc_stderr": 0.0158121796418149}}, "versions": {"blimp_adjunct_island": 0}}
\ No newline at end of file
2d8964e56a17661502ecf3f09c0befba63915360ddf2145b0bd845816950515d
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment