test_tasks.py 7.33 KB
Newer Older
1
from itertools import islice
2
3
import pytest
from typing import List
4
import lm_eval.tasks as tasks
5
from lm_eval.api.task import ConfigurableTask
Leo Gao's avatar
Leo Gao committed
6

baberabb's avatar
baberabb committed
7
# Using fixtures to get the task class and limit
@pytest.fixture()
def task_class() -> ConfigurableTask:
    """Return the registered task class for the benchmark under test.

    Direct registry lookup: a missing task raises a KeyError that names the
    task, instead of the opaque IndexError the old list-comprehension scan
    over the whole registry produced.
    """
    task_name = "arc_easy"
    return tasks.TASK_REGISTRY[task_name]
baberabb's avatar
baberabb committed
13
14
15


@pytest.fixture()
def limit() -> int:
    """Cap on the number of documents each test draws from the task."""
    doc_cap = 10
    return doc_cap
18
19
20


# Tests
baberabb's avatar
baberabb committed
21

Leo Gao's avatar
Leo Gao committed
22

23
def test_download(task_class: ConfigurableTask):
    """Downloading must populate the task's dataset."""
    task = task_class()
    task.download()
    # Assert on the same instance we downloaded into: the original built a
    # second, separate instance for the assertion, so it never actually
    # checked the object that performed the download.
    assert task.dataset is not None
Leo Gao's avatar
Leo Gao committed
26
27


28
def test_has_training_docs(task_class: ConfigurableTask):
    """has_training_docs() must report a boolean-like flag."""
    flag = task_class().has_training_docs()
    assert flag in [True, False]
Leo Gao's avatar
Leo Gao committed
30

Leo Gao's avatar
Leo Gao committed
31

32
def test_check_training_docs(task_class: ConfigurableTask):
    """A task that advertises training docs must name its training split."""
    task = task_class()
    if not task.has_training_docs():
        return
    assert task._config["training_split"] is not None
36

37

baberabb's avatar
baberabb committed
38
def test_has_validation_docs(task_class):
    """has_validation_docs() must report a boolean-like flag."""
    flag = task_class().has_validation_docs()
    assert flag in [True, False]
Leo Gao's avatar
Leo Gao committed
40

41

baberabb's avatar
baberabb committed
42
def test_check_validation_docs(task_class):
    """A task that advertises validation docs must name its validation split."""
    task = task_class()
    if not task.has_validation_docs():
        return
    assert task._config["validation_split"] is not None
46

Fabrizio Milo's avatar
Fabrizio Milo committed
47

baberabb's avatar
baberabb committed
48
def test_has_test_docs(task_class):
    """has_test_docs() must report a boolean-like flag."""
    flag = task_class().has_test_docs()
    assert flag in [True, False]
50
51


baberabb's avatar
baberabb committed
52
def test_check_test_docs(task_class):
    """A task that advertises test docs must name its test split."""
    task = task_class()
    if not task.has_test_docs():
        return
    assert task._config["test_split"] is not None
56

Fabrizio Milo's avatar
Fabrizio Milo committed
57

baberabb's avatar
baberabb committed
58
def test_should_decontaminate(task_class):
    """should_decontaminate() is boolean-like; when set, the task must also
    define its decontamination query."""
    task = task_class()
    decontaminate = task.should_decontaminate()
    assert decontaminate in [True, False]
    if decontaminate:
        assert task._config["doc_to_decontamination_query"] is not None
63

64

baberabb's avatar
baberabb committed
65
def test_doc_to_text(task_class, limit):
    """doc_to_text() must yield strings that never end in a trailing space."""
    task = task_class()
    # Prefer the test split when present, otherwise fall back to validation.
    source = task.test_docs() if task.has_test_docs() else task.validation_docs()
    docs = list(islice(source, limit))
    texts = [task.doc_to_text(doc) for doc in docs]
    # Space convention: a prompt must not end with " ".  Empty strings are
    # allowed for perplexity-like tasks, since the model tacks an
    # <|endoftext|> on.
    for text in texts:
        assert isinstance(text, str)
        if len(text) != 0:
            assert text[-1] != " "


79
def test_create_choices(task_class, limit):
    """For multiple-choice tasks, doc_to_choice() returns a list of strings."""
    task = task_class()
    # Prefer the test split when present, otherwise fall back to validation.
    source = task.test_docs() if task.has_test_docs() else task.validation_docs()
    docs = list(islice(source, limit))
    if "multiple_choice" in task._config.group:
        # Choice count varies by task, so no fixed-length assertion is made.
        for doc in docs:
            choices = task.doc_to_choice(doc)
            assert isinstance(choices, list)
            assert isinstance(choices[0], str)
91
92


baberabb's avatar
baberabb committed
93
def test_doc_to_target(task_class, limit):
    """doc_to_target() must yield an integer label for every sampled doc."""
    task = task_class()
    # Prefer the test split when present, otherwise fall back to validation.
    source = task.test_docs() if task.has_test_docs() else task.validation_docs()
    docs = list(islice(source, limit))
    targets = [task.doc_to_target(doc) for doc in docs]
    for label in targets:
        assert isinstance(label, int)
    if limit:
        assert len(targets) == limit
    # NOTE(review): a prompt/target space-convention cross-check (target
    # starts with " " or the prompt ends with "\n") is known not to hold for
    # these tasks and is intentionally not asserted here.
106

Fabrizio Milo's avatar
Fabrizio Milo committed
107

baberabb's avatar
baberabb committed
108
109
110
def test_build_all_requests(task_class, limit):
    """build_all_requests() must populate the instance's request list."""
    task = task_class()
    task.build_all_requests(rank=1, limit=limit, world_size=1)
    # Assert on the instance we just built on: the original called
    # build_all_requests() on a temporary and then checked the *class*
    # attribute `task_class.instances`, so the built instance was discarded
    # and the assertion never exercised it.
    assert task.instances is not None
111

112

baberabb's avatar
baberabb committed
113
def test_construct_requests(task_class, limit):
    """construct_requests() returns a list of requests for every document."""
    task = task_class()
    # Prefer the test split when present, otherwise fall back to validation.
    source = task.test_docs() if task.has_test_docs() else task.validation_docs()
    docs = list(islice(source, limit))
    requests = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in docs]
    for per_doc_requests in requests:
        assert isinstance(per_doc_requests, list)
    if limit:
        assert len(requests) == limit


125
126
127
128
129
# def test_create_choices(task_class):
#     arr = list(islice(task_class().test_docs(), 1))
#     choices = task_class().create_choices(arr[0])
#     assert choices is not None
# checking if number of choices is correct
baberabb's avatar
baberabb committed
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229


# @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
# def test_basic_interface(taskname, task_class):
#     print("Evaluating task", taskname)
#     task = task_class()
#
#     assert task.has_training_docs() in [True, False]
#     assert task.has_validation_docs() in [True, False]
#     assert task.has_test_docs() in [True, False]
#
#     assert isinstance(task.aggregation(), dict)
#     assert isinstance(task.higher_is_better(), dict)
#     assert task.aggregation().keys() == task.higher_is_better().keys()
#
#     for v in task.higher_is_better().values():
#         assert v in [True, False]
#
#     assert isinstance(task.VERSION, int)
#
#     # test deterministic docs
#     # (don't test train because it's slow)
#
#     task2 = task_class()
#
#     limit = None
#
#     if taskname in ["triviaqa"] or taskname.startswith("pile_"):
#         limit = 10000
#     if task.has_validation_docs():
#         arr = list(islice(task.validation_docs(), limit))
#         arr2 = list(islice(task2.validation_docs(), limit))
#
#         assert arr == arr2
#
#         reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
#         reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
#
#         assert reqs == reqs2
#
#     if task.has_test_docs():
#         arr = list(islice(task.test_docs(), limit))
#         arr2 = list(islice(task2.test_docs(), limit))
#
#         assert arr == arr2
#
#         reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
#         reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
#
#         assert reqs == reqs2
#
#     if task.has_training_docs():
#         arr = list(islice(task.training_docs(), limit))
#         arr2 = list(islice(task2.training_docs(), limit))
#
#         assert arr == arr2
#
#         reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
#         reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
#
#         assert reqs == reqs2
#
#
# @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
# def test_documents_and_requests(taskname, task_class):
#     print("Evaluating task", taskname)
#     task = task_class()
#     fns = []
#     if task.has_training_docs():
#         fns.append(task.training_docs)
#     if task.has_validation_docs():
#         fns.append(task.validation_docs)
#     # test doc might not have labels
#     # if task.has_test_docs(): fns.append(task.test_docs)
#
#     for fn in fns:
#         # print(list(islice(fn(), 10)))
#         for doc in islice(fn(), 10):
#
#             txt = task.doc_to_text(doc)
#             tgt = task.doc_to_target(doc)
#
#             assert isinstance(txt, str)
#             assert isinstance(tgt, str)
#
#             # space convention
#             # allow txt to have length 0 for perplexity-like tasks since the model tacks an <|endoftext|> on
#             if len(txt) != 0:
#                 assert txt[-1] != " "
#                 assert tgt[0] == " " or txt[-1] == "\n"
#
#             reqs = task.construct_requests(doc, txt)
#
#             # construct_requests can return just one request
#             if not isinstance(reqs, (list, tuple)):
#                 reqs = [reqs]
#
#             # todo: mock lm after refactoring evaluator.py to not be a mess
#             # for req in reqs:
#             #     assert isinstance(req, base.Request)