import os

# import lm_eval.base as base
import lm_eval.api.registry as registry
import lm_eval.tasks as tasks

# import lm_eval.models as models

import lm_eval.evaluator as evaluator
import random
import pytest


# TODO: more fine grained unit tests rather than this big honking integration
# test once we break evaluator into smaller, more manageable pieces


@pytest.mark.parametrize(
    "task_name,limit,model,model_args",
    [
        (
            ["arc_easy"],
            10,
            "hf",
            "pretrained=EleutherAI/pythia-160m,dtype=float32,device=cpu",
        )
    ],
)
def test_evaluator(task_name: list[str], limit: int, model: str, model_args: str):
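    """End-to-end smoke test for evaluator.simple_evaluate.

    Runs a small evaluation (arc_easy, 10 examples, pythia-160m on CPU)
    twice and asserts that the two result dicts are identical.
    """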
    # task_dict = tasks.get_task_dict(task)

    # TODO: re-add cachingLM
    # os.system("rm test_cache.db")
    # lm = base.CachingLM(models.get_model("dummy")(), "test_cache.db")
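    # NOTE: the dummy-model and mocked-loglikelihood setup below appears to
    # belong to the disabled CachingLM test above; kept for when caching
    # support is re-added.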
    # lm = registry.get_model("dummy")()

    # def ll_fn(reqs):
    #     for ctx, cont in [req.args for req in reqs]:
    #         if len(ctx) == 0:
    #             continue
    #         # space convention
    #         assert ctx[-1] != " "
    #         assert cont[0] == " " or ctx[-1] == "\n"
    #
    #     res = []
    #
    #     random.seed(42)
    #     for _ in reqs:
    #         res.append((-random.random(), False))
    #
    #     return res
    #
    # def ll_perp_fn(reqs):
    #     for (string,) in reqs:
    #         assert isinstance(string, str)
    #
    #     res = []
    #     random.seed(42)
    #     for _ in reqs:
    #         res.append(-random.random())
    #
    #     return res
    #
    # lm.loglikelihood = ll_fn
    # lm.loglikelihood_rolling = ll_perp_fn

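    # run the full evaluation pipeline twice with identical arguments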
    e1 = evaluator.simple_evaluate(
        model=model,
        tasks=task_name,
        limit=limit,
        model_args=model_args,
    )
    e2 = evaluator.simple_evaluate(
        model=model,
        tasks=task_name,
        limit=limit,
        model_args=model_args,
    )

    # check that the two identical runs produce identical results (and, once
    # CachingLM is re-added, that caching preserves them)
    assert e1 == e2