import collections
import itertools
import pathlib
import random

import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
import promptsource
import numpy as np

from promptsource.templates import DatasetTemplates
from lm_eval.utils import positional_deprecated, run_task_tests


@positional_deprecated
def simple_evaluate(
    model,
    model_args=None,
    tasks=[],
    num_fewshot=0,
    batch_size=None,
    device=None,
    no_cache=False,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
    check_integrity=False,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str]
        String arguments for each model class, see LM.create_from_arg_string.
        Ignored if `model` argument is a LM object.
    :param tasks: list[Union[str, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int, optional
        Batch size for model
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param no_cache: bool
        If True, do not read from or write to the model response cache
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters: int
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :return
        Dictionary of results
    """
    random.seed(1234)
    np.random.seed(1234)

    assert tasks != [], "No tasks specified"

    if isinstance(model, str):
        if model_args is None:
            model_args = ""
        lm = lm_eval.models.get_model(model).create_from_arg_string(
            model_args, {"batch_size": batch_size, "device": device}
        )
    else:
        assert isinstance(model, lm_eval.base.LM)
        lm = model

    # TODO: Hard-code turning off cache while testing. Remove once testing is completed.
    no_cache = True
    if not no_cache:
        lm = lm_eval.base.CachingLM(
            lm,
            "lm_cache/"
            + model
            + "_"
            + model_args.replace("=", "-").replace(",", "_").replace("/", "-")
            + ".db",
        )
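        # Illustrative example (values are not from the original file): model="gpt2"
        # with model_args="pretrained=gpt2,revision=main" would be cached at
        # "lm_cache/gpt2_pretrained-gpt2_revision-main.db".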

    task_dict = lm_eval.tasks.get_task_dict_promptsource(tasks)

    if check_integrity:
        run_task_tests(task_list=tasks)

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        num_fewshot=num_fewshot,
        limit=limit,
        bootstrap_iters=bootstrap_iters,
        description_dict=description_dict,
    )

    # add info about the model and few shot config
    results["config"] = {
        "model": model,
        "model_args": model_args,
        "num_fewshot": num_fewshot,
        "batch_size": batch_size,
        "device": device,
        "no_cache": no_cache,
        "limit": limit,
        "bootstrap_iters": bootstrap_iters,
        "description_dict": description_dict,
    }

    return results


@positional_deprecated
def evaluate(
    lm,
    task_dict,
    provide_description=None,
    num_fewshot=0,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param provide_description: bool
        Not implemented. Deprecated and will be removed in a future version in favor of a different description-providing method
    :param num_fewshot: int
        Number of examples in few-shot context
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters: int
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :return
        Dictionary of results
    """
    # TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces

    # TODO: implement proper description-providing system
    assert not provide_description  # not implemented.
    if provide_description is not None:
        # nudge people to not specify it at all
        print(
            "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
        )

    task_dict_items = [
        (name, task)
        for name, task in task_dict.items()
        if (task.has_validation_docs() or task.has_test_docs())
    ]

    results = collections.defaultdict(dict)
    versions = collections.defaultdict(dict)

    requests = collections.defaultdict(list)
    requests_origin = collections.defaultdict(list)
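    # requests maps a request type (e.g. "loglikelihood") to a flat list of Request
    # objects; requests_origin holds, in the same order, the bookkeeping tuples
    # (i, task_prompt_name, doc, doc_id, fewshotex_logging_info) used below to route
    # each response back to its (task, doc) pair.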

    # If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger
    # memory, we can always modify this plumbing to support that, but I didn't want to include it just yet because
    # over-engineering is bad (or we could make it write the requests to disk and then read them back out again,
    # probably using an sqlite db, given all the moving parts we have).

    # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
    docs = {}

    # get lists of each type of request
    for task_prompt_name, task in task_dict_items:
        versions[task_prompt_name] = task.VERSION
        # default to test docs; fall back to validation docs if test docs are unavailable
        # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
        if task.has_test_docs():
            task_doc_func = task.test_docs
        elif task.has_validation_docs():
            task_doc_func = task.validation_docs
        else:
            raise RuntimeError("Task has neither test_docs nor validation_docs")

        # deterministically shuffle the docs and keep only the first `limit`, since docs sometimes come in a meaningful order
        task_docs = list(enumerate(list(task_doc_func())))
        rnd = random.Random()
        rnd.seed(42)
        rnd.shuffle(task_docs)

        description = (
            description_dict[task_prompt_name]
            if description_dict and task_prompt_name in description_dict
            else ""
        )

        for doc_id, (original_doc_id, doc) in enumerate(
            itertools.islice(task_docs, 0, limit)
        ):
            if task.invalid_doc_for_prompt(doc):
                continue

            docs[(task_prompt_name, doc_id)] = doc
            ctx, fewshotex_logging_info = task.fewshot_context(
                doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
            )
            fewshotex_logging_info["doc_id"] = original_doc_id
            args = {"num_fewshot": num_fewshot}
            reqs = task.construct_requests(doc, ctx, args)
            if not isinstance(reqs, (list, tuple)):
                reqs = [reqs]
            for i, req in enumerate(reqs):
                requests[req.request_type].append(req)
                # i: index in requests for a single task instance
                # doc_id: unique id that we can get back to a doc using `docs`
                requests_origin[req.request_type].append(
                    (i, task_prompt_name, doc, doc_id, fewshotex_logging_info)
                )

    # all responses for each (task, doc)
    process_res_queue = collections.defaultdict(list)

    # execute each type of request
    for reqtype, reqs in requests.items():
        # TODO: right now, this code runs a separate LM request for each Request, even when Requests differ
        #       only in index. We could implement some kind of caching, but that would be more of a band-aid
        #       solution. We could also implement some kind of auto-grouping here;
        #       requests that differ only in index should end up next to each other anyway.

        print("Running", reqtype, "requests")
        resps = getattr(lm, reqtype)([req.args for req in reqs])
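        # reqtype names a method on the LM (e.g. "loglikelihood"), so this dispatches
        # the whole batch of request args to that method in a single call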
        resps = [
            x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
        ]
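        # requests that produce multiple outputs carry an index; pick out the piece
        # of the response this particular Request asked for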

        for resp, (i, task_prompt_name, doc, doc_id, fewshotex_logging_info) in zip(
            resps, requests_origin[reqtype]
        ):
            process_res_queue[(task_prompt_name, doc_id)].append(
                (i, resp, fewshotex_logging_info)
            )

    vals = collections.defaultdict(list)
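    # vals collects, for each (task_prompt_name, metric) key, the per-document metric
    # values that the task's aggregation functions reduce to a single number below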

    # unpack results, sort them back into order, and return control to the Task
    examples = []
    for (task_prompt_name, doc_id), per_doc_requests in process_res_queue.items():
        per_doc_requests.sort(key=lambda x: x[0])
        per_doc_results = [x[1] for x in per_doc_requests]
        fewshot_logging_info = [x[2] for x in per_doc_requests][0]

        task = task_dict[task_prompt_name]
        doc = docs[(task_prompt_name, doc_id)]

        output = task.process_results(doc, per_doc_results)
        if task.save_examples:
            metrics, example = output
            example.update(fewshot_logging_info)
            example.update(task.get_logging_info())
            examples.append(example)
        else:
            metrics = output
            example = fewshot_logging_info
            example.update(task.get_logging_info())
            examples.append(example)

        for metric, value in metrics.items():
            vals[(task_prompt_name, metric)].append(value)

    # aggregate results
    metric_results = []
    for (task_prompt_name, metric), items in vals.items():
        task_name, prompt_name = task_prompt_name.split("+")
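        # task_prompt_name has the form "<task>+<prompt>"; e.g. a hypothetical
        # "anli+can_we_infer" splits into ("anli", "can_we_infer")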

        results[task_prompt_name]["task_name"] = task_name
        results[task_prompt_name]["prompt_name"] = prompt_name
        task = task_dict[task_prompt_name]
        results[task_prompt_name][metric] = task.aggregation()[metric](items)

        _metric_results = {
            "task_name": task_name,
            "prompt_name": prompt_name,
            metric: task.aggregation()[metric](items),
            **task.get_logging_info(),
        }

        # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
        # so we run them for fewer iterations. still looking for a cleaner way to do this
        stderr = lm_eval.metrics.stderr_for_metric(
            metric=task.aggregation()[metric],
            bootstrap_iters=min(bootstrap_iters, 1000)
            if metric in ["bleu", "chrf", "ter"]
            else bootstrap_iters,
        )
        if stderr is not None:
            results[task_prompt_name][metric + "_stderr"] = stderr(items)
            _metric_results[metric + "_stderr"] = stderr(items)
        metric_results.append(_metric_results)

    return {
        # List of results that tracks the averages per model and prompt.
        "results": metric_results,
        "versions": dict(versions),
        # List of all prompt x doc examples with additional information in it.
        "examples": examples,
        # Original results used for generating the table when running this file.
        "table_results": dict(results),
    }


def make_table(result_dict):
    """Generate table of results."""
    from pytablewriter import MarkdownTableWriter, LatexTableWriter

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = ["Task", "Prompt", "Version", "Metric", "Value", "", "Stderr"]
    latex_writer.headers = [
        "Task",
        "Prompt",
        "Version",
        "Metric",
        "Value",
        "",
        "Stderr",
    ]

    values = []
    for k, dic in result_dict["table_results"].items():
        version = result_dict["versions"][k]
        for m, v in dic.items():
            if m.endswith("_stderr"):
                continue
            if "_name" in m:
                continue
            if m + "_stderr" in dic:
                se = dic[m + "_stderr"]
                values.append(
                    [
                        dic["task_name"],
                        dic["prompt_name"],
                        version,
                        m,
                        "%.4f" % v,
                        "±",
                        "%.4f" % se,
                    ]
                )
            else:
                values.append(
                    [
                        dic["task_name"],
                        dic["prompt_name"],
                        version,
                        m,
                        "%.4f" % v,
                        "",
                        "",
                    ]
                )
            k = ""
            version = ""
    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # todo: make latex table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()