test_tasks.py 7.28 KB
Newer Older
1
from itertools import islice
from typing import List, Optional

import pytest

import lm_eval.tasks as tasks
from lm_eval.api.task import ConfigurableTask
Leo Gao's avatar
Leo Gao committed
6
7


baberabb's avatar
baberabb committed
8
@pytest.fixture()
def task_class(task_name: Optional[List[str]] = None) -> ConfigurableTask:
    """Return the first registered task class whose name is in ``task_name``.

    Defaults to ``["arc_easy"]`` when no names are given. Raises IndexError
    if none of the requested names exist in ``tasks.TASK_REGISTRY``.

    NOTE(review): pytest treats fixture parameters as fixture requests, so a
    plain default-valued argument like ``task_name`` may not be overridable
    from a test — confirm this fixture is only ever used with the default.
    """
    if task_name is None:
        task_name = ["arc_easy"]
    matches = [cls for name, cls in tasks.TASK_REGISTRY.items() if name in task_name]
    return matches[0]
baberabb's avatar
baberabb committed
14
15
16


@pytest.fixture()
def limit() -> int:
    """Cap on how many documents each test draws from a split."""
    n_docs = 10
    return n_docs
19
20
21


# Tests
baberabb's avatar
baberabb committed
22

Leo Gao's avatar
Leo Gao committed
23

24
def test_download(task_class: ConfigurableTask):
    """``download()`` leaves the task with a populated ``dataset``."""
    task = task_class()
    task.download()
    # Assert on the same instance that performed the download; the original
    # checked a second, freshly-constructed instance, which only passes if the
    # constructor itself downloads.
    assert task.dataset is not None
Leo Gao's avatar
Leo Gao committed
27
28


29
def test_has_training_docs(task_class: ConfigurableTask):
    """``has_training_docs`` answers with a boolean value."""
    has_train = task_class().has_training_docs()
    assert has_train in (True, False)
Leo Gao's avatar
Leo Gao committed
31

Leo Gao's avatar
Leo Gao committed
32

33
def test_check_training_docs(task_class: ConfigurableTask):
    """A task advertising training docs must name its training split."""
    task = task_class()
    if not task.has_training_docs():
        return
    assert task._config["training_split"] is not None
37

38

baberabb's avatar
baberabb committed
39
def test_has_validation_docs(task_class):
    """``has_validation_docs`` answers with a boolean value."""
    has_val = task_class().has_validation_docs()
    assert has_val in (True, False)
Leo Gao's avatar
Leo Gao committed
41

42

baberabb's avatar
baberabb committed
43
def test_check_validation_docs(task_class):
    """A task advertising validation docs must name its validation split."""
    task = task_class()
    if not task.has_validation_docs():
        return
    assert task._config["validation_split"] is not None
47

Fabrizio Milo's avatar
Fabrizio Milo committed
48

baberabb's avatar
baberabb committed
49
def test_has_test_docs(task_class):
    """``has_test_docs`` answers with a boolean value."""
    has_test = task_class().has_test_docs()
    assert has_test in (True, False)
51
52


baberabb's avatar
baberabb committed
53
def test_check_test_docs(task_class):
    """A task advertising test docs must name its test split."""
    task = task_class()
    if not task.has_test_docs():
        return
    assert task._config["test_split"] is not None
57

Fabrizio Milo's avatar
Fabrizio Milo committed
58

baberabb's avatar
baberabb committed
59
def test_should_decontaminate(task_class):
    """If decontamination is enabled, its query field must be configured."""
    task = task_class()
    decontaminate = task.should_decontaminate()
    assert decontaminate in (True, False)
    if decontaminate:
        assert task._config["doc_to_decontamination_query"] is not None
64

65

baberabb's avatar
baberabb committed
66
def test_doc_to_text(task_class, limit):
    """``doc_to_text`` yields prompt strings that never end in a space."""
    task = task_class()
    if task.has_test_docs():
        docs = list(islice(task.test_docs(), limit))
    else:
        docs = list(islice(task.validation_docs(), limit))
    for doc in docs:
        text = task.doc_to_text(doc)
        assert isinstance(text, str)
        # Space convention: prompts must not end with " ". An empty string is
        # allowed for perplexity-like tasks since the model tacks an
        # <|endoftext|> on.
        if text:
            assert not text.endswith(" ")


80
def test_create_choices(task_class, limit):
    """``doc_to_choice`` returns a list of answer strings for each doc."""
    task = task_class()
    if task.has_test_docs():
        docs = list(islice(task.test_docs(), limit))
    else:
        docs = list(islice(task.validation_docs(), limit))
    for doc in docs:
        choices = task.doc_to_choice(doc)
        assert isinstance(choices, list)
        assert isinstance(choices[0], str)
    # A fixed choice count (e.g. 4) is deliberately not asserted — it varies
    # per task.


baberabb's avatar
baberabb committed
93
def test_doc_to_target(task_class, limit):
    """``doc_to_target`` yields an int label for every sampled document."""
    task = task_class()
    if task.has_test_docs():
        docs = list(islice(task.test_docs(), limit))
    else:
        docs = list(islice(task.validation_docs(), limit))
    targets = [task.doc_to_target(doc) for doc in docs]
    for target in targets:
        assert isinstance(target, int)
    if limit:
        assert len(targets) == limit
    # NOTE(review): a text/target whitespace-convention cross-check
    # (target starts with " " or text ends with "\n") existed here but was
    # disabled as "not working" — revisit.
106

Fabrizio Milo's avatar
Fabrizio Milo committed
107

baberabb's avatar
baberabb committed
108
109
110
def test_build_all_requests(task_class, limit):
    """``build_all_requests`` populates the task instance's ``instances``."""
    task = task_class()
    # NOTE(review): rank=1 with world_size=1 mirrors the original call; rank 0
    # would be the usual single-process value — confirm intent.
    task.build_all_requests(rank=1, limit=limit, world_size=1)
    # Assert on the instance the requests were built on; the original checked
    # ``task_class.instances`` — an attribute on the class object, not on the
    # (discarded) instance that was actually populated.
    assert task.instances is not None
111

112

baberabb's avatar
baberabb committed
113
def test_construct_requests(task_class, limit):
    """``construct_requests`` returns a list of requests per sampled doc."""
    task = task_class()
    if task.has_test_docs():
        docs = list(islice(task.test_docs(), limit))
    else:
        docs = list(islice(task.validation_docs(), limit))
    requests = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in docs]
    for request in requests:
        assert isinstance(request, list)
    if limit:
        assert len(requests) == limit


125
126
127
128
129
# def test_create_choices(task_class):
#     arr = list(islice(task_class().test_docs(), 1))
#     choices = task_class().create_choices(arr[0])
#     assert choices is not None
# checking if number of choices is correct
baberabb's avatar
baberabb committed
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229


# @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
# def test_basic_interface(taskname, task_class):
#     print("Evaluating task", taskname)
#     task = task_class()
#
#     assert task.has_training_docs() in [True, False]
#     assert task.has_validation_docs() in [True, False]
#     assert task.has_test_docs() in [True, False]
#
#     assert isinstance(task.aggregation(), dict)
#     assert isinstance(task.higher_is_better(), dict)
#     assert task.aggregation().keys() == task.higher_is_better().keys()
#
#     for v in task.higher_is_better().values():
#         assert v in [True, False]
#
#     assert isinstance(task.VERSION, int)
#
#     # test deterministic docs
#     # (don't test train because it's slow)
#
#     task2 = task_class()
#
#     limit = None
#
#     if taskname in ["triviaqa"] or taskname.startswith("pile_"):
#         limit = 10000
#     if task.has_validation_docs():
#         arr = list(islice(task.validation_docs(), limit))
#         arr2 = list(islice(task2.validation_docs(), limit))
#
#         assert arr == arr2
#
#         reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
#         reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
#
#         assert reqs == reqs2
#
#     if task.has_test_docs():
#         arr = list(islice(task.test_docs(), limit))
#         arr2 = list(islice(task2.test_docs(), limit))
#
#         assert arr == arr2
#
#         reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
#         reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
#
#         assert reqs == reqs2
#
#     if task.has_training_docs():
#         arr = list(islice(task.training_docs(), limit))
#         arr2 = list(islice(task2.training_docs(), limit))
#
#         assert arr == arr2
#
#         reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
#         reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
#
#         assert reqs == reqs2
#
#
# @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
# def test_documents_and_requests(taskname, task_class):
#     print("Evaluating task", taskname)
#     task = task_class()
#     fns = []
#     if task.has_training_docs():
#         fns.append(task.training_docs)
#     if task.has_validation_docs():
#         fns.append(task.validation_docs)
#     # test doc might not have labels
#     # if task.has_test_docs(): fns.append(task.test_docs)
#
#     for fn in fns:
#         # print(list(islice(fn(), 10)))
#         for doc in islice(fn(), 10):
#
#             txt = task.doc_to_text(doc)
#             tgt = task.doc_to_target(doc)
#
#             assert isinstance(txt, str)
#             assert isinstance(tgt, str)
#
#             # space convention
#             # allow txt to have length 0 for perplexity-like tasks since the model tacks an <|endoftext|> on
#             if len(txt) != 0:
#                 assert txt[-1] != " "
#                 assert tgt[0] == " " or txt[-1] == "\n"
#
#             reqs = task.construct_requests(doc, txt)
#
#             # construct_requests can return just one request
#             if not isinstance(reqs, (list, tuple)):
#                 reqs = [reqs]
#
#             # todo: mock lm after refactoring evaluator.py to not be a mess
#             # for req in reqs:
#             #     assert isinstance(req, base.Request)