# test_evaluator.py
# import lm_eval.base as base
from typing import List, Optional

import pytest

# import lm_eval.models as models
import lm_eval.api as api
import lm_eval.evaluator as evaluator
from lm_eval import tasks

# TODO: replace this large end-to-end integration test with finer-grained unit
# tests once the evaluator is broken into smaller, more manageable pieces.

@pytest.mark.parametrize(
    "task_name,limit,model,model_args,bootstrap_iters",
    [
        (
            ["arc_easy"],
            10,
            "hf",
            "pretrained=EleutherAI/pythia-160m,dtype=float32,device=cpu",
            0,
        ),
        (
            ["mmlu_abstract_algebra"],
            None,
            "hf",
            "pretrained=EleutherAI/pythia-160m,dtype=float32,device=cpu",
            10000,
        ),
    ],
)
def test_evaluator(
    task_name: List[str],
    limit: Optional[int],
    model: str,
    model_args: str,
    bootstrap_iters: int,
):
    """Check that ``simple_evaluate`` and ``evaluate`` report identical results.

    The first run uses ``evaluator.simple_evaluate``, which is given the model
    registry name, the task names and the model-argument string directly. The
    second run builds the model (via the model registry) and the task dict by
    hand and passes them to ``evaluator.evaluate``. Both runs receive the same
    ``limit`` and ``bootstrap_iters``, so the per-task results must match.

    Args:
        task_name: Names of the tasks to evaluate; each one is also used as
            the key into the ``"results"`` dict of both outputs.
        limit: Forwarded unchanged to both entry points. One parametrization
            passes ``None`` (presumably "evaluate every document").
        model: Model registry name, e.g. ``"hf"``.
        model_args: Comma-separated ``key=value`` model argument string.
        bootstrap_iters: Forwarded unchanged to both entry points.
    """
    e1 = evaluator.simple_evaluate(
        model=model,
        tasks=task_name,
        limit=limit,
        model_args=model_args,
        bootstrap_iters=bootstrap_iters,
    )
    assert e1 is not None

    # Rebuild the same model and tasks manually for the lower-level API.
    lm = api.registry.get_model(model).create_from_arg_string(
        model_args,
        {
            "batch_size": None,
            "max_batch_size": None,
            "device": None,
        },
    )
    task_manager = tasks.TaskManager()
    task_dict = tasks.get_task_dict(task_name, task_manager)

    e2 = evaluator.evaluate(
        lm=lm,
        task_dict=task_dict,
        limit=limit,
        bootstrap_iters=bootstrap_iters,
    )
    assert e2 is not None

    # Both entry points must agree on every requested task. Compare the whole
    # per-task dicts (keys and values) rather than zipping their values, which
    # would silently drop a missing/extra metric and depend on key order.
    for name in task_name:
        assert e1["results"][name] == e2["results"][name], name