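"""Unit tests for ConfigurableTask.

Covers dataset download, split availability, document formatting
(doc_to_text / doc_to_target / doc_to_choice), and request construction,
using `arc_easy` as the default task under test.
"""
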
from itertools import islice
from typing import List, Optional, Type

import pytest

import lm_eval.tasks as tasks
from lm_eval.api.task import ConfigurableTask


@pytest.fixture()
def task_class(task_name: Optional[List[str]] = None) -> Type[ConfigurableTask]:
    if task_name is None:
        task_name = ["arc_easy"]
    # Return the first registered task class whose name matches.
    matches = [cls for name, cls in tasks.TASK_REGISTRY.items() if name in task_name]
    return matches[0]


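# `any_new_tasks` is assumed to be supplied by a fixture defined elsewhere
# (e.g. in conftest.py); newly added tasks get a larger sample budget.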
@pytest.fixture()
def limit(any_new_tasks: bool) -> int:
    return 100 if any_new_tasks else 10


# Tests
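# Each test resolves `task_class` from the fixture and instantiates a fresh
# task; `limit`, where used, caps how many test documents are materialized.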


def test_download(task_class: Type[ConfigurableTask]):
    task = task_class()
    task.download()
    assert task.dataset is not None


def test_has_training_docs(task_class: Type[ConfigurableTask]):
    assert task_class().has_training_docs() in [True, False]


def test_check_training_docs(task_class: Type[ConfigurableTask]):
    task = task_class()
    if task._config["training_split"]:
        assert task.has_training_docs()


def test_has_validation_docs(task_class: Type[ConfigurableTask]):
    assert task_class().has_validation_docs() in [True, False]


def test_check_validation_docs(task_class: Type[ConfigurableTask]):
    task = task_class()
    if task._config["validation_split"]:
        assert task.has_validation_docs()


def test_has_test_docs(task_class: Type[ConfigurableTask]):
    assert task_class().has_test_docs() in [True, False]


def test_check_test_docs(task_class: Type[ConfigurableTask]):
    task = task_class()
    if task._config["test_split"]:
        assert task.has_test_docs()


def test_should_decontaminate(task_class: Type[ConfigurableTask]):
    task = task_class()
    assert task.should_decontaminate() in [True, False]
    if task.should_decontaminate():
        assert task._config["doc_to_decontamination_query"] is not None


def test_doc_to_text(task_class: Type[ConfigurableTask], limit: int):
    task = task_class()
    arr = list(islice(task.test_docs(), limit)) if limit else list(task.test_docs())
    _array = [task.doc_to_text(doc) for doc in arr]
    # space convention; allow txt to have length 0 for perplexity-like tasks since the model tacks an <|endoftext|> on
    assert all(
        isinstance(x, str) and (x[-1] != " " if len(x) != 0 else True) for x in _array
    )
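    # e.g. a prompt ending in "...\nAnswer:" is fine; the leading space
    # belongs to the target (" yes"), never to the end of the prompt.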


def test_create_choices(task_class: Type[ConfigurableTask], limit: int):
    task = task_class()
    arr = list(islice(task.test_docs(), limit)) if limit else list(task.test_docs())
    _array = [task.doc_to_choice(doc) for doc in arr]
    # assert all(len(x) == 4 for x in _array)
    assert all(isinstance(x, list) for x in _array)
    assert all(isinstance(x[0], str) for x in _array)


def test_doc_to_target(task_class: Type[ConfigurableTask], limit: int):
    task = task_class()
    arr = list(islice(task.test_docs(), limit)) if limit else list(task.test_docs())
    _array_target = [task.doc_to_target(doc) for doc in arr]
    assert all(isinstance(label, int) for label in _array_target)
    if limit:
        assert len(_array_target) == limit
    # _array_text = [task.doc_to_text(doc) for doc in arr]
    # Not working
    # assert all(tgt[0] == " " or txt[-1] == "\n" if  len(txt) != 0 else True for txt, tgt in zip(_array_text, _array_target))


def test_build_all_requests(task_class: Type[ConfigurableTask], limit: int):
    task = task_class()
    # rank is 0-indexed, so a single-worker run is rank=0 of world_size=1
    task.build_all_requests(rank=0, limit=limit, world_size=1)
    assert task.instances is not None


def test_construct_requests(task_class: Type[ConfigurableTask], limit: int):
    task = task_class()
    arr = list(islice(task.test_docs(), limit)) if limit else list(task.test_docs())
    requests = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
    assert all(isinstance(doc, list) for doc in requests)
    if limit:
        assert len(requests) == limit


# def test_create_choices(task_class):
#     arr = list(islice(task_class().test_docs(), 1))
#     choices = task_class().create_choices(arr[0])
#     assert choices is not None
#     # checking if number of choices is correct


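# The commented-out tests below appear to target the older pre-refactor Task
# interface; they are kept for reference only.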
# @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
# def test_basic_interface(taskname, task_class):
#     print("Evaluating task", taskname)
#     task = task_class()
#
#     assert task.has_training_docs() in [True, False]
#     assert task.has_validation_docs() in [True, False]
#     assert task.has_test_docs() in [True, False]
#
#     assert isinstance(task.aggregation(), dict)
#     assert isinstance(task.higher_is_better(), dict)
#     assert task.aggregation().keys() == task.higher_is_better().keys()
#
#     for v in task.higher_is_better().values():
#         assert v in [True, False]
#
#     assert isinstance(task.VERSION, int)
#
#     # test deterministic docs
#     # (don't test train because it's slow)
#
#     task2 = task_class()
#
#     limit = None
#
#     if taskname in ["triviaqa"] or taskname.startswith("pile_"):
#         limit = 10000
#     if task.has_validation_docs():
#         arr = list(islice(task.validation_docs(), limit))
#         arr2 = list(islice(task2.validation_docs(), limit))
#
#         assert arr == arr2
#
#         reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
#         reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
#
#         assert reqs == reqs2
#
#     if task.has_test_docs():
#         arr = list(islice(task.test_docs(), limit))
#         arr2 = list(islice(task2.test_docs(), limit))
#
#         assert arr == arr2
#
#         reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
#         reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
#
#         assert reqs == reqs2
#
#     if task.has_training_docs():
#         arr = list(islice(task.training_docs(), limit))
#         arr2 = list(islice(task2.training_docs(), limit))
#
#         assert arr == arr2
#
#         reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
#         reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
#
#         assert reqs == reqs2
#
#
# @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
# def test_documents_and_requests(taskname, task_class):
#     print("Evaluating task", taskname)
#     task = task_class()
#     fns = []
#     if task.has_training_docs():
#         fns.append(task.training_docs)
#     if task.has_validation_docs():
#         fns.append(task.validation_docs)
#     # test doc might not have labels
#     # if task.has_test_docs(): fns.append(task.test_docs)
#
#     for fn in fns:
#         # print(list(islice(fn(), 10)))
#         for doc in islice(fn(), 10):
#
#             txt = task.doc_to_text(doc)
#             tgt = task.doc_to_target(doc)
#
#             assert isinstance(txt, str)
#             assert isinstance(tgt, str)
#
#             # space convention
#             # allow txt to have length 0 for perplexity-like tasks since the model tacks an <|endoftext|> on
#             if len(txt) != 0:
#                 assert txt[-1] != " "
#                 assert tgt[0] == " " or txt[-1] == "\n"
#
#             reqs = task.construct_requests(doc, txt)
#
#             # construct_requests can return just one request
#             if not isinstance(reqs, (list, tuple)):
#                 reqs = [reqs]
#
#             # todo: mock lm after refactoring evaluator.py to not be a mess
#             # for req in reqs:
#             #     assert isinstance(req, base.Request)