test_evaluator.py 1.54 KB
Newer Older
1
import os
haileyschoelkopf's avatar
haileyschoelkopf committed
2
3
4

# import lm_eval.base as base
import lm_eval.api.registry as registry
Leo Gao's avatar
Leo Gao committed
5
import lm_eval.tasks as tasks
haileyschoelkopf's avatar
haileyschoelkopf committed
6
7

# import lm_eval.models as models
8
import lm_eval.api as api
Leo Gao's avatar
Leo Gao committed
9
import lm_eval.evaluator as evaluator
baberabb's avatar
baberabb committed
10
from typing import List
11
import random
Leo Gao's avatar
Leo Gao committed
12
13
import pytest

14
# Module-level side effect: runs once when this test module is imported, i.e.
# before test_evaluator calls tasks.get_task_dict(). Presumably this populates
# the task registry that lookup relies on -- confirm against lm_eval.tasks.
tasks.initialize_tasks()
Leo Gao's avatar
Leo Gao committed
15
16
17
18

# TODO: add more fine-grained unit tests, rather than this big honking
# integration test, once we break evaluator into smaller, more manageable pieces

19
20

@pytest.mark.parametrize(
    "task_name,limit,model,model_args",
    [
        (
            ["arc_easy"],
            10,
            "hf",
            "pretrained=EleutherAI/pythia-160m,dtype=float32,device=cpu",
        )
    ],
)
def test_evaluator(task_name: List[str], limit: int, model: str, model_args: str):
    """Check that both evaluator entry points report the same task results.

    The same model/task configuration is run twice: once through
    ``evaluator.simple_evaluate`` (given the model name and argument string)
    and once through ``evaluator.evaluate`` (given a model object and task
    dict built by hand here). The per-task ``"results"`` entries of both runs
    must match.

    Args:
        task_name: Names of the tasks to evaluate.
        limit: Value forwarded as ``limit`` to both evaluator calls.
        model: Model type name looked up via ``api.registry.get_model``.
        model_args: Argument string used to construct the model.
    """
    # NOTE: the parametrized `task_name` and `limit` are used as given; they
    # are deliberately not reassigned so that new parameter sets take effect.
    e1 = evaluator.simple_evaluate(
        model=model,
        tasks=task_name,
        limit=limit,
        model_args=model_args,
    )
    assert e1 is not None

    # Build the model and task dict manually to drive the lower-level API.
    lm = api.registry.get_model(model).create_from_arg_string(
        model_args,
        {
            "batch_size": None,
            "max_batch_size": None,
            "device": None,
        },
    )
    task_dict = tasks.get_task_dict(task_name, num_fewshot=0)

    e2 = evaluator.evaluate(
        lm=lm,
        task_dict=task_dict,
        limit=limit,
    )
    assert e2 is not None

    # Both runs must produce the same metric values, in the same order, for
    # every requested task. Comparing whole lists (rather than zipping them)
    # also fails when one run reports fewer values than the other, which a
    # zip-based check would silently accept.
    for name in task_name:
        simple_values = list(e1["results"][name].values())
        manual_values = list(e2["results"][name].values())
        assert simple_values == manual_values, (
            f"results for task {name!r} differ between simple_evaluate "
            f"and evaluate: {simple_values} != {manual_values}"
        )