import hashlib
import json
import os
import random

import pytest

import lm_eval.evaluator as evaluator
import lm_eval.models as models
import lm_eval.tasks as tasks


# Golden files live here: written on the first run, compared against on later runs.
os.makedirs("tests/testdata", exist_ok=True)


def assert_target(name, ob):
    """Snapshot test helper: compare `ob` against a stored JSON golden file,
    creating the file if it does not exist yet."""
    fname = f"tests/testdata/{name}.json"
    if os.path.exists(fname):
        with open(fname) as fh:
            # Round-trip `ob` through json so key order and JSON-representable
            # types match the stored file exactly.
            assert json.load(fh) == json.loads(json.dumps(ob, sort_keys=True))
    else:
        with open(fname, "w") as fh:
            json.dump(ob, fh, sort_keys=True)
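
# Note: the json round-trip above also normalizes Python tuples to lists
# (json has no tuple type), e.g.:
#     json.loads(json.dumps({"pair": (1, 2)})) == {"pair": [1, 2]}
# so freshly computed objects compare cleanly against the stored file.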

def assert_target_hashed(name, ob):
    """Like `assert_target`, but store only a sha256 digest of the JSON dump,
    which keeps bulky request lists out of the repository."""
    fname = f"tests/testdata/{name}"
    digest = hashlib.sha256(json.dumps(ob, sort_keys=True).encode("utf-8")).hexdigest()
    if os.path.exists(fname):
        with open(fname) as fh:
            assert fh.read() == digest
    else:
        with open(fname, "w") as fh:
            fh.write(digest)
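
# A minimal illustrative check (added here; not part of the original suite):
# sort_keys=True canonicalizes key order, which is what makes the digest
# above deterministic for equal objects.
def test_sorted_json_hash_is_order_independent():
    def digest(ob):
        return hashlib.sha256(
            json.dumps(ob, sort_keys=True).encode("utf-8")
        ).hexdigest()

    assert digest({"a": 1, "b": 2}) == digest({"b": 2, "a": 1})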


# Make sure the requests and evaluation results produced for each task version
# stay stable: any change to them should come with a VERSION bump, which keys
# a fresh set of golden files.

@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_versions_stable(taskname, task_class):
    task_dict = tasks.get_task_dict([taskname])
    # The dummy model is a stand-in LM; its request methods are replaced below
    # with deterministic stubs so the whole pipeline is reproducible.
    lm = models.get_model('dummy')()

    def ll_fn(reqs):
        for ctx, cont in reqs:
            if len(ctx) == 0:
                continue
            # Space convention: the context must not end with a space, and the
            # continuation must start with one unless the context ends with a
            # newline, e.g. ctx="Question: ...\nAnswer:", cont=" yes".
            assert ctx[-1] != ' '
            assert cont[0] == ' ' or ctx[-1] == '\n'

        assert_target_hashed(f"{taskname}-v{task_class.VERSION}-loglikelihood", reqs)

        # Return deterministic fake (logprob, is_greedy) results; seeding makes
        # the downstream metrics reproducible across runs.
        random.seed(42)
        res = [(-random.random(), False) for _ in reqs]

        return res

    def ll_perp_fn(reqs):
        # Rolling-loglikelihood requests are 1-tuples holding a single string.
        for string, in reqs:
            assert isinstance(string, str)

        assert_target_hashed(f"{taskname}-v{task_class.VERSION}-loglikelihood_rolling", reqs)

        # Deterministic fake rolling loglikelihoods, as above.
        random.seed(42)
        res = [-random.random() for _ in reqs]

        return res
    
    def greedy_until(reqs):
        assert_target_hashed(f"{taskname}-v{task_class.VERSION}-greedy_until", reqs)

        res = []
        for ctx, _ in reqs:
            # Every generation request must carry a non-empty context; the
            # stub answers each request with a fixed string.
            assert ctx.strip() != ''
            res.append("lol")

        return res

    # Swap the dummy model's request handlers for the deterministic stubs.
    lm.loglikelihood = ll_fn
    lm.loglikelihood_rolling = ll_perp_fn
    lm.greedy_until = greedy_until

    limit = None
    # Run the full evaluation; the positional False and 0 are assumed to be
    # provide_description and num_fewshot in the evaluate() signature this
    # test targets. bootstrap_iters is kept small so the test runs quickly.
    result = evaluator.evaluate(lm, task_dict, False, 0, limit, bootstrap_iters=10)
    assert_target(f"{taskname}-v{task_class.VERSION}-res", result)
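
# Workflow note (inferred from the helpers above): deleting a file under
# tests/testdata and re-running pytest regenerates it, since the assert_*
# helpers write the target on the first run and only compare on later runs.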