import collections
import itertools
import pathlib
import random
import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
import promptsource
import numpy as np

from promptsource.templates import DatasetTemplates
from lm_eval.utils import positional_deprecated, run_task_tests


@positional_deprecated
def simple_evaluate(
    model,
    model_args=None,
    tasks=[],
    num_fewshot=0,
    batch_size=None,
    device=None,
    no_cache=False,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
    check_integrity=False,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str]
        String arguments for each model class, see LM.create_from_arg_string.
        Ignored if `model` argument is a LM object.
    :param tasks: list[Union[str, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int, optional
        Batch size for model
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param no_cache: bool
        Whether or not to cache
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :return
        Dictionary of results
    """
    random.seed(1234)
    np.random.seed(1234)

    assert tasks != [], "No tasks specified"

    if isinstance(model, str):
        if model_args is None:
            model_args = ""
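        # `model_args` is a comma-separated key=value string consumed by
        # LM.create_from_arg_string; e.g. "pretrained=gpt2" for the HuggingFace
        # model class (illustrative only -- accepted keys depend on the model class)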
        lm = lm_eval.models.get_model(model).create_from_arg_string(
            model_args, {"batch_size": batch_size, "device": device}
        )
    else:
        assert isinstance(model, lm_eval.base.LM)
        lm = model

    # TODO: Hard-code turning off cache while testing. Remove once testing is completed.
    no_cache = True
    if not no_cache:
        lm = lm_eval.base.CachingLM(
            lm,
            "lm_cache/"
            + model
            + "_"
            + model_args.replace("=", "-").replace(",", "_").replace("/", "-")
            + ".db",
        )
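        # e.g. model="gpt2" with model_args="pretrained=gpt2" caches to
        # "lm_cache/gpt2_pretrained-gpt2.db" (illustrative values; the path is just
        # the model name plus the sanitized argument string built above)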

    task_dict = lm_eval.tasks.get_task_dict_promptsource(tasks)

    if check_integrity:
        run_task_tests(task_list=tasks)

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        num_fewshot=num_fewshot,
        limit=limit,
        bootstrap_iters=bootstrap_iters,
        description_dict=description_dict,
    )

    # add info about the model and few shot config
    results["config"] = {
        "model": model,
        "model_args": model_args,
        "num_fewshot": num_fewshot,
        "batch_size": batch_size,
        "device": device,
        "no_cache": no_cache,
        "limit": limit,
        "bootstrap_iters": bootstrap_iters,
        "description_dict": description_dict,
    }

    return results
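# Example usage (illustrative sketch -- the model, arguments, and task name below
# are placeholders, not values defined by this module):
#
#     results = simple_evaluate(
#         model="gpt2",
#         model_args="pretrained=gpt2",
#         tasks=["<task_name>"],
#         num_fewshot=0,
#         batch_size=8,
#     )
#     print(make_table(results))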


@positional_deprecated
def evaluate(
    lm,
    task_dict,
    provide_description=None,
    num_fewshot=0,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param provide_description: bool
        Not implemented; this option is deprecated and will be removed in a future version in favor of a different description-providing method
    :param num_fewshot: int
        Number of examples in few-shot context
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :return
        Dictionary of results
    """
    # TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces

    # TODO: implement a proper description-providing system
    assert not provide_description  # not implemented.
    if provide_description is not None:
        # nudge people to not specify it at all
        print(
            "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
        )

    task_dict_items = [
        (name, task)
        for name, task in task_dict.items()
        if (task.has_validation_docs() or task.has_test_docs())
    ]

    results = collections.defaultdict(dict)
    versions = collections.defaultdict(dict)

    requests = collections.defaultdict(list)
    requests_origin = collections.defaultdict(list)

    # If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with more
    # memory, we can always modify this plumbing to support that, but I didn't want to include it just yet because
    # over-engineering is bad (or we could make it write the requests to disk and then read them back out again,
    # probably using an sqlite db because of all the moving parts we have).

    # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
    docs = {}

    # get lists of each type of request
    for task_prompt_name, task in task_dict_items:
        # if task.is_generation_task():
        #     print(f"WARNING: Skipping generation prompt {task.prompt.name}.")
        #     continue

        versions[task_prompt_name] = task.VERSION
        # default to test docs, fall back to validation docs if test docs are unavailable
        # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
        if task.has_test_docs():
            task_doc_func = task.test_docs
        elif task.has_validation_docs():
            task_doc_func = task.validation_docs
        else:
            raise RuntimeError("Task has neither test_docs nor validation_docs")

        # deterministically shuffle docs and keep only the first `limit`, because sometimes docs are in some kind of order
        task_docs = list(task_doc_func())
        rnd = random.Random()
        rnd.seed(42)
        rnd.shuffle(task_docs)
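        # note: the same seeded `rnd` is reused by fewshot_context below, so
        # few-shot example sampling is deterministic across runs as well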

        description = (
            description_dict[task_prompt_name]
            if description_dict and task_prompt_name in description_dict
            else ""
        )

        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
            docs[(task_prompt_name, doc_id)] = doc
            ctx = task.fewshot_context(
                doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
            )
            reqs = task.construct_requests(doc, ctx)
            if not isinstance(reqs, (list, tuple)):
                reqs = [reqs]
            for i, req in enumerate(reqs):
                requests[req.request_type].append(req)
                # i: index in requests for a single task instance
                # doc_id: unique id that lets us map back to the doc via `docs`
                requests_origin[req.request_type].append(
                    (i, task_prompt_name, doc, doc_id)
                )

    # all responses for each (task, doc)
    process_res_queue = collections.defaultdict(list)

    # execute each type of request
    for reqtype, reqs in requests.items():
        # TODO: right now, this code runs multiple separate LM requests for multiple Requests differing
        #       only in index. We could implement some kind of caching, but that would be more of a band-aid
        #       solution. We could also implement some kind of auto-grouping here;
        #       requests differing only in index should end up next to each other.

        print("Running", reqtype, "requests")
        resps = getattr(lm, reqtype)([req.args for req in reqs])
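        # a Request may carry an `index` selecting one element of the model's
        # response (e.g. only the log-likelihood out of a (log-likelihood, is_greedy)
        # pair); responses without an index are passed through unchanged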
        resps = [
            x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
        ]

        for resp, (i, task_prompt_name, doc, doc_id) in zip(
            resps, requests_origin[reqtype]
        ):
            process_res_queue[(task_prompt_name, doc_id)].append((i, resp))

    vals = collections.defaultdict(list)

    # unpack results, sort them back into order, and hand them back to each Task for scoring
    for (task_prompt_name, doc_id), requests in process_res_queue.items():
        requests.sort(key=lambda x: x[0])
        requests = [x[1] for x in requests]

        task = task_dict[task_prompt_name]
        doc = docs[(task_prompt_name, doc_id)]

        metrics = task.process_results(doc, requests)
        for metric, value in metrics.items():
            vals[(task_prompt_name, metric)].append(value)

    # aggregate results
    for (task_prompt_name, metric), items in vals.items():
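        # keys are of the form "<task_name>+<prompt_name>"; split them so the
        # results table can report task and prompt separately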
        task_name, prompt_name = task_prompt_name.split("+")
        results[task_prompt_name]["task_name"] = task_name
        results[task_prompt_name]["prompt_name"] = prompt_name
        task = task_dict[task_prompt_name]
        results[task_prompt_name][metric] = task.aggregation()[metric](items)

        # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap,
        # so we run them for fewer iterations. still looking for a cleaner way to do this
        stderr = lm_eval.metrics.stderr_for_metric(
            metric=task.aggregation()[metric],
            bootstrap_iters=min(bootstrap_iters, 1000)
            if metric in ["bleu", "chrf", "ter"]
            else bootstrap_iters,
        )
        if stderr is not None:
            results[task_prompt_name][metric + "_stderr"] = stderr(items)

    return {"results": dict(results), "versions": dict(versions)}
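# Shape of the dict returned by `evaluate` (illustrative sketch -- task, prompt,
# and metric names are placeholders):
#
#     {
#         "results": {
#             "<task_name>+<prompt_name>": {
#                 "task_name": "<task_name>",
#                 "prompt_name": "<prompt_name>",
#                 "<metric>": <aggregated value>,
#                 "<metric>_stderr": <bootstrap stderr, when available>,
#             },
#         },
#         "versions": {"<task_name>+<prompt_name>": <task version>},
#     }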


def make_table(result_dict):
    """Generate table of results."""
    from pytablewriter import MarkdownTableWriter, LatexTableWriter

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = ["Task", "Prompt", "Version", "Metric", "Value", "", "Stderr"]
    latex_writer.headers = [
        "Task",
        "Prompt",
        "Version",
        "Metric",
        "Value",
        "",
        "Stderr",
    ]

    values = []
    for k, dic in result_dict["results"].items():
        version = result_dict["versions"][k]
        for m, v in dic.items():
            if m.endswith("_stderr"):
                continue
            if "_name" in m:
                continue
            if m + "_stderr" in dic:
                se = dic[m + "_stderr"]
                values.append(
                    [
                        dic["task_name"],
                        dic["prompt_name"],
                        version,
                        m,
                        "%.4f" % v,
                        "±",
                        "%.4f" % se,
                    ]
                )
            else:
                values.append(
                    [
                        dic["task_name"],
                        dic["prompt_name"],
                        version,
                        m,
                        "%.4f" % v,
                        "",
                        "",
                    ]
                )
            k = ""
            version = ""
    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # todo: make latex table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()
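# The markdown table produced above has the columns
# Task | Prompt | Version | Metric | Value | ± | Stderr, with one row per
# (task+prompt, metric) pair found in `result_dict["results"]`.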