import collections
import itertools
import numpy as np
import random
import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
from lm_eval.utils import positional_deprecated, run_task_tests


@positional_deprecated
def simple_evaluate(
    model,
    model_args=None,
    tasks=[],
    num_fewshot=0,
    batch_size=None,
    max_batch_size=None,
    device=None,
    no_cache=False,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
    check_integrity=False,
    decontamination_ngrams_path=None,
    write_out=False,
    output_base_path=None,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str]
        String arguments for each model class, see LM.create_from_arg_string.
        Ignored if `model` argument is a LM object.
    :param tasks: list[Union[str, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int or str, optional
        Batch size for model
    :param max_batch_size: int, optional
        Maximal batch size to try with automatic batch size detection
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param no_cache: bool
        If True, do not cache the model's responses (skips the lm_cache/ on-disk cache)
    :param limit: int or float, optional
        Limit the number of examples per task (only use this for testing). If <1, limit is a percentage of the total number of examples.
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :param write_out: bool
        If True, write details about prompts and logits to json for all tasks
    :param output_base_path: str, optional
        Directory to which detailed eval info will be written. Defaults to the current working directory.
    :return
        Dictionary of results
    """
    random.seed(1234)
    np.random.seed(1234)

    assert tasks != [], "No tasks specified"

    if isinstance(model, str):
        if model_args is None:
            model_args = ""
        lm = lm_eval.models.get_model(model).create_from_arg_string(
            model_args,
            {
                "batch_size": batch_size,
                "max_batch_size": max_batch_size,
                "device": device,
            },
        )
    else:
        assert isinstance(model, lm_eval.base.LM)
        lm = model

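    # Unless no_cache is set, wrap the LM in CachingLM so that identical requests are
    # answered from an on-disk cache (keyed by the model name/args below) instead of
    # being re-run against the model.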
    if not no_cache:
        lm = lm_eval.base.CachingLM(
            lm,
            "lm_cache/"
            + (model if isinstance(model, str) else model.model.config._name_or_path)
            + "_"
            + model_args.replace("=", "-").replace(",", "_").replace("/", "-")
            + ".db",
        )

    task_dict = lm_eval.tasks.get_task_dict(tasks)

    if check_integrity:
        run_task_tests(task_list=tasks)

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        num_fewshot=num_fewshot,
        limit=limit,
        bootstrap_iters=bootstrap_iters,
        description_dict=description_dict,
        decontamination_ngrams_path=decontamination_ngrams_path,
        write_out=write_out,
        output_base_path=output_base_path,
    )

    # add info about the model and few shot config
    results["config"] = {
        "model": (model if isinstance(model, str) else model.model.config._name_or_path),
        "model_args": model_args,
        "num_fewshot": num_fewshot,
        "batch_size": batch_size,
        "batch_sizes": list(lm.batch_sizes.values()),
        "device": device,
        "no_cache": no_cache,
        "limit": limit,
        "bootstrap_iters": bootstrap_iters,
        "description_dict": description_dict,
    }

    return results
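
# A minimal usage sketch (illustrative only; the model and task names below are
# examples and may not match every installation of lm_eval):
#
#   import lm_eval.evaluator as evaluator
#
#   results = evaluator.simple_evaluate(
#       model="gpt2",                  # a key registered in lm_eval.models
#       model_args="pretrained=gpt2",  # forwarded to LM.create_from_arg_string
#       tasks=["hellaswag"],           # names resolvable by lm_eval.tasks
#       num_fewshot=0,
#       batch_size=8,
#   )
#   print(evaluator.make_table(results))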


decontaminate_suffix = "_decontaminate"


@positional_deprecated
def evaluate(
    lm,
    task_dict,
    provide_description=None,
    num_fewshot=0,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
    decontamination_ngrams_path=None,
    write_out=False,
    output_base_path=None,
):
    """Evaluate a model on a dictionary of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param provide_description: bool
        Not implemented. This option is deprecated and will be removed in a future version in favor of a different description-providing method.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param limit: int or float, optional
        Limit the number of examples per task (only use this for testing). If <1, limit is a percentage of the total number of examples.
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :param write_out: bool
        If True, write all prompts, logits and metrics to json for offline analysis
    :param output_base_path: str, optional
        Directory to which detailed eval info will be written. Defaults to the current working directory.
    :return
        Dictionary of results
    """
    # TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces

    # TODO: implement proper description-providing system
    assert not provide_description  # not implemented.
    if provide_description is not None:
        # nudge people to not specify it at all
        print(
            "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
        )

    decontaminate = decontamination_ngrams_path is not None

    task_dict_items = [
        (name, task)
        for name, task in task_dict.items()
        if (task.has_validation_docs() or task.has_test_docs())
    ]

    results = collections.defaultdict(dict)
    versions = collections.defaultdict(dict)

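    # `requests` collects every Request of a given type across all tasks and docs;
    # `requests_origin` stores (i, task_name, doc, doc_id) in the same order so each
    # response can be routed back to the doc that produced it.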
    requests = collections.defaultdict(list)
    requests_origin = collections.defaultdict(list)

    overlaps = collections.defaultdict(list)  # {task_name: contaminated_docs}

    # If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger
    # memory, we can always modify this plumbing to support that, but I didn't want to include it just yet because
    # over-engineering is bad (or we could make it write the requests to disk and then read them back out again,
    # probably using an sqlite db, because of all the moving parts we have).

    # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
    docs = {}
    write_out_info = {}

    docs_for_decontamination = collections.defaultdict(list)

    # get lists of each type of request
    for task_name, task in task_dict_items:
        versions[task_name] = task.VERSION
        # default to test docs, falling back to validation docs if test docs are unavailable
        # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
        if task.has_test_docs():
            task_doc_func = task.test_docs
            task_set = "test"  # Required for caching in the decontamination
        elif task.has_validation_docs():
            task_set = "val"  # Required for caching in the decontamination
            task_doc_func = task.validation_docs
        else:
            raise RuntimeError("Task has neither test_docs nor validation_docs")

        # deterministically shuffle the docs and keep only the first `limit`, because sometimes docs come in some kind of order
        task_docs = list(task_doc_func())
        rnd = random.Random()
        rnd.seed(42)
        rnd.shuffle(task_docs)
        print(f"Task: {task_name}; number of docs: {len(task_docs)}")

        if write_out:
            prompt_details = []

        description = (
            description_dict[task_name]
            if description_dict and task_name in description_dict
            else ""
        )
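        # a fractional `limit` (< 1.0) is interpreted as a share of this task's docs,
        # an integer `limit` as an absolute number of documents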
        if limit is not None:
            limit = int(len(task_docs) * limit) if limit < 1.0 else int(limit)

        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
            if decontaminate and task.should_decontaminate():
                docs_for_decontamination[(task_name, task_set)].append(
                    task.doc_to_decontamination_query(doc)
                )

            docs[(task_name, doc_id)] = doc
            ctx = task.fewshot_context(
                doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
            )
            reqs = task.construct_requests(doc, ctx)

            if write_out:
                prompt_details.append({"doc_id": doc_id})

            # print the context prompt for the first document of each task
            if doc_id < 1:
                print(
                    f"Task: {task_name}; document {doc_id}; context prompt (starting on next line):\n{ctx}\n(end of prompt on previous line)"
                )
                print("Requests:", reqs)

            if not isinstance(reqs, (list, tuple)):
                reqs = [reqs]
            for i, req in enumerate(reqs):
                requests[req.request_type].append(req)
                # i: index in requests for a single task instance
                # doc_id: unique id that lets us map a response back to its doc via `docs`
                requests_origin[req.request_type].append((i, task_name, doc, doc_id))

                if write_out:
                    prompt_details[-1][f"prompt_{i}"] = "".join(
                        (map(lambda x: "".join(x), req.args))
                    )

        if write_out:
            write_out_info[task_name] = prompt_details

    # Compare all tasks/sets at once to ensure a single training set scan
    if decontaminate:
        from lm_eval.decontamination.decontaminate import get_train_overlap

        print("Finding train/test overlap, please wait...")
        overlaps = get_train_overlap(
            docs_for_decontamination, decontamination_ngrams_path, limit
        )

    # all responses for each (task, doc)
    process_res_queue = collections.defaultdict(list)

    # execute each type of request
    for reqtype, reqs in requests.items():
        # TODO: right now, this code runs multiple separate LM requests for multiple Requests differing
        #       only in index. We could implement some kind of caching, but that would be more of a band-aid
        #       solution. We could also implement some kind of auto-grouping here;
        #       they should end up next to each other.

        print("Running", reqtype, "requests")
        resps = getattr(lm, reqtype)([req.args for req in reqs])
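        # a Request constructed with an index selects a single element from a
        # multi-valued response; an index of None means the response is used as-is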
        resps = [
            x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
        ]

        for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
            process_res_queue[(task_name, doc_id)].append((i, resp))

            if write_out:
                write_out_info[task_name][doc_id][f"logit_{i}"] = resp
                task = task_dict[task_name]
                if isinstance(task, lm_eval.base.MultipleChoiceTask):
                    write_out_info[task_name][doc_id]["truth"] = doc["gold"]
                elif isinstance(task, lm_eval.tasks.winogrande.Winogrande):
                    write_out_info[task_name][doc_id]["truth"] = task.answer_to_num[
                        doc["answer"]
                    ]
                else:
                    write_out_info[task_name][doc_id]["truth"] = task.doc_to_target(doc)

    vals = collections.defaultdict(list)

    # unpack results and sort back in order and return control to Task
    for (task_name, doc_id), requests in process_res_queue.items():
        requests.sort(key=lambda x: x[0])
        requests = [x[1] for x in requests]

        task = task_dict[task_name]
        doc = docs[(task_name, doc_id)]

        metrics = task.process_results(doc, requests)
        for metric, value in metrics.items():
            vals[(task_name, metric)].append(value)

            if write_out:
                write_out_info[task_name][doc_id][metric] = str(value)

            # Re-use the evaluation for the decontaminated set by just ignoring the overlaps
            if decontaminate and task_name in overlaps:
                if doc_id not in overlaps[task_name]:
                    vals[(task_name, metric + decontaminate_suffix)].append(value)

    # aggregate results
    for (task_name, metric), items in vals.items():
        task = task_dict[task_name]
        real_metric = metric  # key when looking up the metric with task.aggregation
        if metric.endswith(decontaminate_suffix):
            real_metric = metric.replace(
                decontaminate_suffix, ""
            )  # decontaminated still uses the same metric
        results[task_name][metric] = task.aggregation()[real_metric](items)

        # hotfix: bleu, chrf and ter seem to be really expensive to bootstrap,
        # so we run them for fewer iterations. Still looking for a cleaner way to do this.

        stderr = lm_eval.metrics.stderr_for_metric(
            metric=task.aggregation()[real_metric],
            bootstrap_iters=min(bootstrap_iters, 1000)
            if metric in ["bleu", "chrf", "ter"]
            else bootstrap_iters,
        )

        if stderr is not None:
            results[task_name][metric + "_stderr"] = stderr(items)

    if write_out:
        import json
        import pathlib

        output_base_path = (
            pathlib.Path(output_base_path)
            if output_base_path is not None
            else pathlib.Path(".")
        )
        try:
            output_base_path.mkdir(parents=True, exist_ok=False)
        except FileExistsError:
            pass

        for task_name, _ in task_dict_items:
            with open(
                output_base_path.joinpath(f"{task_name}_write_out_info.json"),
                "w",
                encoding="utf8",
            ) as fp:
                json.dump(write_out_info[task_name], fp, indent=4, ensure_ascii=False)

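    # Rough shape of the returned dict (this is what `make_table` below consumes):
    #   {"results": {task_name: {metric: value, f"{metric}_stderr": value, ...}},
    #    "versions": {task_name: version}}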
    return {"results": dict(results), "versions": dict(versions)}


def make_table(result_dict):
    """Generate table of results."""
    from pytablewriter import MarkdownTableWriter, LatexTableWriter

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
    latex_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]

    values = []

    for k, dic in result_dict["results"].items():
        version = result_dict["versions"][k]
        for m, v in dic.items():
            if m.endswith("_stderr"):
                continue

            if m + "_stderr" in dic:
                se = dic[m + "_stderr"]
                values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se])
            else:
                values.append([k, version, m, "%.4f" % v, "", ""])
            k = ""
            version = ""
    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # todo: make latex table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()
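
# Hypothetical rendering of the table above (placeholder numbers, shown only to
# illustrate the column layout produced by MarkdownTableWriter):
#
#   |  Task   |Version|Metric|Value |   |Stderr|
#   |---------|------:|------|-----:|---|-----:|
#   |hellaswag|      0|acc   |0.5000|±  |0.0100|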