# test_tasks.py
import lm_eval.tasks as tasks
import lm_eval.base as base
import pytest
4
from itertools import islice
Leo Gao's avatar
Leo Gao committed
5
6


7
8
def _assert_split_is_deterministic(task, task2, docs_method, limit):
    """Assert two instances of one task agree on a split's docs and requests.

    Args:
        task: First task instance.
        task2: Second, independently constructed instance of the same task.
        docs_method: Name of the doc-producing method to call on both
            instances (e.g. ``"validation_docs"``).
        limit: Maximum number of docs to read from each instance;
            ``None`` reads every doc.
    """
    arr = list(islice(getattr(task, docs_method)(), limit))
    arr2 = list(islice(getattr(task2, docs_method)(), limit))

    assert arr == arr2

    # Each instance builds requests from its own docs, so a difference in
    # either doc_to_text or construct_requests shows up here.
    reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
    reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]

    assert reqs == reqs2


@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_basic_interface(taskname, task_class):
    """Check the basic task interface for every registered task.

    Verifies that the ``has_*_docs`` methods return booleans, that
    ``aggregation()`` and ``higher_is_better()`` are dicts over the same keys
    (the latter with boolean values), that ``VERSION`` is an int, and that two
    separately constructed instances yield equal docs and equal requests for
    each split the task provides.
    """
    print("Evaluating task", taskname)
    task = task_class()

    assert task.has_training_docs() in [True, False]
    assert task.has_validation_docs() in [True, False]
    assert task.has_test_docs() in [True, False]

    assert isinstance(task.aggregation(), dict)
    assert isinstance(task.higher_is_better(), dict)
    assert task.aggregation().keys() == task.higher_is_better().keys()

    for v in task.higher_is_better().values():
        assert v in [True, False]

    assert isinstance(task.VERSION, int)

    # test deterministic docs: a second instance must reproduce the first
    task2 = task_class()

    # Cap the number of docs compared for triviaqa and the pile_* tasks;
    # every other task is compared in full.
    limit = None
    if taskname in ["triviaqa"] or taskname.startswith("pile_"):
        limit = 10000

    # Validation, test and training splits are all checked, in that order.
    for has_docs_method, docs_method in (
        ("has_validation_docs", "validation_docs"),
        ("has_test_docs", "test_docs"),
        ("has_training_docs", "training_docs"),
    ):
        if getattr(task, has_docs_method)():
            _assert_split_is_deterministic(task, task2, docs_method, limit)

67

68
69
@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_documents_and_requests(taskname, task_class):
    """Check doc formatting and request types for every registered task.

    For the first 10 training docs and the first 10 validation docs (when the
    task has them), asserts that the text and target are strings, that they
    follow the space convention, and that ``construct_requests`` yields
    ``base.Request`` objects.
    """
    print("Evaluating task", taskname)
    task = task_class()

    # Test docs are deliberately left out: they might not have labels.
    doc_sources = []
    if task.has_training_docs():
        doc_sources.append(task.training_docs)
    if task.has_validation_docs():
        doc_sources.append(task.validation_docs)

    for load_docs in doc_sources:
        for doc in islice(load_docs(), 10):
            context = task.doc_to_text(doc)
            target = task.doc_to_target(doc)

            assert isinstance(context, str)
            assert isinstance(target, str)

            # space convention
            # allow the text to have length 0 for perplexity-like tasks since
            # the model tacks an <|endoftext|> on
            if context:
                assert not context.endswith(" ")
                assert target[0] == " " or context[-1] == "\n"

            requests = task.construct_requests(doc, context)

            # construct_requests can return just one request
            if not isinstance(requests, (list, tuple)):
                requests = [requests]

            # todo: mock lm after refactoring evaluator.py to not be a mess
            for request in requests:
                assert isinstance(request, base.Request)