# test_tasks.py
# Interface and determinism checks for every task registered in lm_eval.tasks.
import lm_eval.tasks as tasks
import lm_eval.base as base
import pytest
from itertools import islice


def _assert_split_is_deterministic(task, task2, docs_method, limit):
    """Assert one document split is reproducible across two task instances.

    Args:
        task: First instance of the task under test.
        task2: Second, independently constructed instance of the same task.
        docs_method: Name of the split accessor to call on both instances
            (``"validation_docs"``, ``"test_docs"`` or ``"training_docs"``).
        limit: Maximum number of documents to compare; ``None`` compares
            every document the accessor yields.
    """
    arr = list(islice(getattr(task, docs_method)(), limit))
    arr2 = list(islice(getattr(task2, docs_method)(), limit))

    # Both instances must yield the same documents in the same order.
    assert arr == arr2

    reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
    reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]

    # ...and building requests from those documents must be repeatable too.
    assert reqs == reqs2


@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_basic_interface(taskname, task_class):
    """Check the basic task interface and that documents are deterministic.

    Runs once per ``(taskname, task_class)`` entry of ``tasks.TASK_REGISTRY``.
    Verifies the ``has_*_docs`` flags, the ``aggregation`` /
    ``higher_is_better`` dictionaries, the ``VERSION`` attribute, and that
    two separately constructed instances produce identical documents and
    requests for every split the task provides.
    """
    print('Evaluating task', taskname)
    # NOTE(review): constructing the task presumably triggers its dataset
    # download; it is not mocked here -- confirm before relying on this offline.
    task = task_class()

    assert task.has_training_docs() in [True, False]
    assert task.has_validation_docs() in [True, False]
    assert task.has_test_docs() in [True, False]

    assert isinstance(task.aggregation(), dict)
    assert isinstance(task.higher_is_better(), dict)
    # Both dictionaries must be keyed by the same names.
    assert task.aggregation().keys() == task.higher_is_better().keys()

    for v in task.higher_is_better().values():
        assert v in [True, False]

    assert isinstance(task.VERSION, int)

    # test deterministic docs
    # (every available split is checked; the listed tasks are capped via
    # `limit` so the comparison does not walk their whole document stream)
    task2 = task_class()

    limit = None

    if taskname in ["triviaqa"] or taskname.startswith("pile_"):
        limit = 10000

    if task.has_validation_docs():
        _assert_split_is_deterministic(task, task2, "validation_docs", limit)

    if task.has_test_docs():
        _assert_split_is_deterministic(task, task2, "test_docs", limit)

    if task.has_training_docs():
        _assert_split_is_deterministic(task, task2, "training_docs", limit)

70

71
72
@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_documents_and_requests(taskname, task_class):
    """Check document text/target formatting and the requests built from them.

    Runs once per ``(taskname, task_class)`` entry of ``tasks.TASK_REGISTRY``.
    For the first 10 documents of the training and validation splits (when
    the task provides them), verifies that ``doc_to_text`` and
    ``doc_to_target`` return strings that follow the whitespace convention,
    and that ``construct_requests`` yields ``base.Request`` objects.
    """
    print('Evaluating task', taskname)
    task = task_class()
    fns = []
    if task.has_training_docs():
        fns.append(task.training_docs)
    if task.has_validation_docs():
        fns.append(task.validation_docs)
    # test doc might not have labels
    # if task.has_test_docs(): fns.append(task.test_docs)

    for fn in fns:
        for doc in islice(fn(), 10):
            txt = task.doc_to_text(doc)
            tgt = task.doc_to_target(doc)

            assert isinstance(txt, str)
            assert isinstance(tgt, str)

            # space convention
            # allow txt to have length 0 for perplexity-like tasks since the model tacks an <|endoftext|> on
            if len(txt) != 0:
                assert txt[-1] != ' '
                # startswith() rather than tgt[0]: an empty target must be
                # judged by the convention, not crash with an IndexError.
                assert tgt.startswith(' ') or txt[-1] == '\n'

            reqs = task.construct_requests(doc, txt)

            # construct_requests can return just one request
            if not isinstance(reqs, (list, tuple)):
                reqs = [reqs]

            # todo: mock lm after refactoring evaluator.py to not be a mess
            for req in reqs:
                assert isinstance(req, base.Request)