# test_tasks.py
import lm_eval.tasks as tasks
import lm_eval.base as base
from itertools import islice
import pytest


@pytest.mark.parametrize("taskname,Task", tasks.TASK_REGISTRY.items())
def test_basic_interface(taskname, Task):
    """Check a task's basic interface and that its documents load deterministically.

    Args:
        taskname: Key of the task in ``tasks.TASK_REGISTRY``; only printed.
        Task: The registered task class; it is instantiated with no arguments.
    """
    print('Evaluating task', taskname)
    task = Task()

    assert task.has_training_docs() in [True, False]
    assert task.has_validation_docs() in [True, False]
    assert task.has_test_docs() in [True, False]

    assert isinstance(task.aggregation(), dict)
    assert isinstance(task.higher_is_better(), dict)
    # Both dicts must be keyed by the same set of names.
    assert task.aggregation().keys() == task.higher_is_better().keys()

    for v in task.higher_is_better().values():
        assert v in [True, False]

    # test deterministic docs
    # (don't test train because it's slow)
    # A second, independent instance must yield the same first 100 documents
    # and build equal requests from them.
    task2 = Task()
    splits = [
        (task.has_validation_docs, task.validation_docs, task2.validation_docs),
        (task.has_test_docs, task.test_docs, task2.test_docs),
    ]
    for has_docs, docs, docs2 in splits:
        if not has_docs():
            continue

        arr = list(islice(docs(), 100))
        arr2 = list(islice(docs2(), 100))

        assert arr == arr2

        reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
        reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]

        assert reqs == reqs2


# ---------------------------------------------------------------------------

@pytest.mark.parametrize("taskname,Task", tasks.TASK_REGISTRY.items())
def test_documents_and_requests(taskname, Task):
    """Check document text/target formatting and the type of constructed requests.

    For the first 10 training and validation documents (when the task has
    them), the context and target must be strings that follow the space
    convention, and every constructed request must be a ``base.Request``.

    Args:
        taskname: Key of the task in ``tasks.TASK_REGISTRY``; only printed.
        Task: The registered task class; it is instantiated with no arguments.
    """
    print('Evaluating task', taskname)
    task = Task()
    fns = []
    if task.has_training_docs():
        fns.append(task.training_docs)
    if task.has_validation_docs():
        fns.append(task.validation_docs)
    # test docs might not have labels, so they are not checked here

    for fn in fns:
        for doc in islice(fn(), 10):
            txt = task.doc_to_text(doc)
            tgt = task.doc_to_target(doc)

            assert isinstance(txt, str)
            assert isinstance(tgt, str)

            # space convention: the context must not end with a space, and the
            # target must start with one unless the context ends with a newline.
            # The explicit non-empty checks turn what would otherwise be a bare
            # IndexError from the indexing below into a readable failure.
            assert txt, "doc_to_text returned an empty string"
            assert txt[-1] != ' '
            assert tgt, "doc_to_target returned an empty string"
            assert tgt[0] == ' ' or txt[-1] == '\n'

            reqs = task.construct_requests(doc, txt)

            # construct_requests can return just one request
            if not isinstance(reqs, (list, tuple)):
                reqs = [reqs]

            # TODO: mock lm after refactoring evaluator.py to not be a mess
            for req in reqs:
                assert isinstance(req, base.Request)