# test_version_stable.py: regression tests that pin evaluation results per task version.
import collections
import collections.abc
import hashlib
import json
import os
import random

import pytest

import lm_eval.evaluator as evaluator
import lm_eval.models as models
import lm_eval.tasks as tasks


# Fixture files are read from and written to this directory; create it up
# front so a first run can store new targets there.
os.makedirs("tests/testdata", exist_ok=True)


def assert_target(name, ob):
    """Compare ``ob`` with the stored JSON fixture, creating it if missing.

    Args:
        name: Fixture base name; the file is ``tests/testdata/{name}.json``.
        ob: JSON-serialisable mapping (nested mappings allowed).

    Raises:
        AssertionError: If the fixture exists and its flattened contents
            differ from ``ob`` beyond the tolerances used below.
    """
    fname = f"tests/testdata/{name}.json"
    if os.path.exists(fname):
        # Read the fixture fully and close the file before comparing.
        with open(fname, encoding="utf-8") as fh:
            expected = flatten(json.load(fh))
        # Round-trip `ob` through JSON so it is normalised the same way the
        # stored fixture was when it was written.
        actual = flatten(json.loads(json.dumps(ob, sort_keys=True)))
        # Use relative tolerance of 1e-5 and absolute tolerance of 1e-8
        # assuming most metrics work on `float32` values, which is the common
        # default floating type across popular libraries (PyTorch, Tensorflow, and JAX).
        assert expected == pytest.approx(actual, rel=1e-5, abs=1e-8)
    else:
        # First run for this name: record the current value as the target.
        with open(fname, "w", encoding="utf-8") as fh:
            json.dump(ob, fh, sort_keys=True)

31

32
33
34
35
def assert_target_hashed(name, ob):
    """Compare a SHA-256 digest of ``ob`` with the stored one, or store it.

    The digest is taken over the JSON dump (sorted keys, UTF-8) of the
    ``__dict__`` of every element of ``ob``.

    Args:
        name: Fixture file name inside ``tests/testdata``.
        ob: Iterable of objects whose ``__dict__`` is JSON-serialisable.

    Raises:
        AssertionError: If the fixture exists and holds a different digest.
        TypeError: If an element's ``__dict__`` cannot be serialised to JSON.
    """
    fname = f"tests/testdata/{name}"
    # Compute the digest before touching the file system: if serialisation
    # fails, no empty fixture file is left behind to break later runs.
    payload = json.dumps([o.__dict__ for o in ob], sort_keys=True)
    digest = hashlib.sha256(payload.encode("utf-8")).hexdigest()

    if os.path.exists(fname):
        with open(fname, encoding="utf-8") as fh:
            stored = fh.read()
        assert stored == digest, f"hash mismatch for fixture {fname!r}"
    else:
        # First run for this name: record the digest as the target.
        with open(fname, "w", encoding="utf-8") as fh:
            fh.write(digest)

50

51
# from https://stackoverflow.com/a/6027615
def flatten(d, parent_key="", sep="."):
    """Flatten a nested mapping into a single-level dict.

    Keys of nested mutable mappings are joined to their parent key with
    ``sep``; every other value is kept as a leaf.

    Args:
        d: Mapping to flatten.
        parent_key: Prefix for every produced key; empty means no prefix.
        sep: Separator placed between a parent key and a child key.

    Returns:
        A new dict mapping joined keys to leaf values.
    """
    flat = {}
    for key, value in d.items():
        # Format rather than concatenate so non-string nested keys (e.g.
        # ints) do not raise TypeError; string keys give the same result.
        full_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, collections.abc.MutableMapping):
            flat.update(flatten(value, full_key, sep=sep))
        else:
            flat[full_key] = value
    return flat


# Make sure eval results for a task version are stable.
@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_versions_stable(taskname, task_class):
    """Check that a task's requests and results match the stored fixtures.

    The request handlers of the "dummy" model are replaced with stubs that
    validate the incoming requests, compare a hash of them with the fixture
    stored for this task version, and return seeded pseudo-random (or
    constant) responses. The evaluation result is then compared with the
    stored ``-res`` fixture.
    """
    task_dict = tasks.get_task_dict([taskname])
    lm = models.get_model("dummy")()
    # Common prefix of every fixture name belonging to this task version.
    version_tag = f"{taskname}-v{task_class.VERSION}"

    def ll_fn(reqs):
        for ctx, cont in (req.args for req in reqs):
            if not ctx:
                continue
            # space convention
            assert ctx[-1] != " "
            assert cont[0] == " " or ctx[-1] == "\n"

        assert_target_hashed(f"{version_tag}-loglikelihood", reqs)

        # Reseed so the fake scores are identical on every call.
        # NOTE(review): this reseeds the module-level RNG on purpose; code
        # running afterwards may depend on that state, so confirm before
        # switching to a local `random.Random`.
        random.seed(42)
        return [(-random.random(), False) for _ in reqs]

    def ll_perp_fn(reqs):
        for (string,) in (req.args for req in reqs):
            assert isinstance(string, str)

        assert_target_hashed(f"{version_tag}-loglikelihood_rolling", reqs)

        # Same reseeding as in `ll_fn`, kept on the module-level RNG.
        random.seed(42)
        return [-random.random() for _ in reqs]

    def generate_until(reqs):
        assert_target_hashed(f"{version_tag}-generate_until", reqs)

        for ctx, _ in (req.args for req in reqs):
            assert ctx.strip() != ""

        # Every request gets the same constant generation.
        return ["lol" for _ in reqs]

    lm.loglikelihood = ll_fn
    lm.loglikelihood_rolling = ll_perp_fn
    lm.generate_until = generate_until

    result = evaluator.evaluate(
        lm=lm,
        task_dict=task_dict,
        num_fewshot=0,
        limit=None,
        bootstrap_iters=10,
    )

    assert_target(f"{version_tag}-res", result)