import collections
import itertools
import random

import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
from lm_eval.utils import positional_deprecated, run_task_tests
from lm_eval.models.gpt2 import HFLM

import numpy as np
import transformers


@positional_deprecated
def simple_evaluate(
    model,
    model_args=None,
    tasks=[],
    num_fewshot=0,
    batch_size=None,
    max_batch_size=None,
    device=None,
    no_cache=False,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
    check_integrity=False,
    decontamination_ngrams_path=None,
    tokenizer=None,
    write_out=False,
    output_base_path=None,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model, transformers.PreTrainedModel object, or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str]
        String arguments for each model class, see LM.create_from_arg_string.
        Ignored if `model` argument is a LM object.
    :param tasks: list[Union[str, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int or str, optional
        Batch size for model
    :param max_batch_size: int, optional
        Maximum batch size to try with automatic batch size detection
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param no_cache: bool
        If True, do not cache model responses (by default they are cached in a local database under lm_cache/)
    :param limit: int or float, optional
        Limit the number of examples per task (only use this for testing). If < 1, the limit is interpreted as a fraction of the total number of examples.
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :param tokenizer: optional
        Tokenizer override, forwarded to the model constructor when `model` is given as a string
    :param write_out: bool
        If True, write details about prompts and logits to json for all tasks
    :param output_base_path: str, optional
        Directory to which detailed eval info will be written. Defaults to present working dir.
    :return
        Dictionary of results
    """
    random.seed(1234)
    np.random.seed(1234)

    assert tasks != [], "No tasks specified"

    if isinstance(model, str):
        if model_args is None:
            model_args = ""
        lm = lm_eval.models.get_model(model).create_from_arg_string(
            model_args,
            {
                "batch_size": batch_size,
                "max_batch_size": max_batch_size,
                "device": device,
                "tokenizer": tokenizer,
                "trust_remote_code": True,
            },
        )
    elif isinstance(model, transformers.PreTrainedModel):
        lm = lm_eval.models.get_model("hf-causal")(
            pretrained=model,
            batch_size=batch_size,
            max_batch_size=max_batch_size,
        )
        no_cache = True
    else:
        assert isinstance(model, lm_eval.base.LM)
        lm = model

    if not no_cache:
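        # Wrap the model so responses are cached on disk; the cache filename combines
        # the model identifier with the sanitized model_args string, so different
        # configurations get separate caches.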
        lm = lm_eval.base.CachingLM(
            lm,
            "lm_cache/"
            + (model if isinstance(model, str) else model.model.config._name_or_path)
            + "_"
            + model_args.replace("=", "-").replace(",", "_").replace("/", "-")
            + ".db",
        )

    task_dict = lm_eval.tasks.get_task_dict(tasks)

    if check_integrity:
        run_task_tests(task_list=tasks)

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        num_fewshot=num_fewshot,
        limit=limit,
        bootstrap_iters=bootstrap_iters,
121
        description_dict=description_dict,
        decontamination_ngrams_path=decontamination_ngrams_path,
        write_out=write_out,
        output_base_path=output_base_path,
    )

    # add info about the model and few shot config
    model_name = None
    if isinstance(model, str):
        model_name = model
    elif isinstance(model, transformers.PreTrainedModel):
        model_name = "pretrained=" + model.config._name_or_path
    results["config"] = {
        "model": model_name,
        "model_args": model_args,
        "num_fewshot": num_fewshot,
        "batch_size": batch_size,
        "batch_sizes": list(lm.batch_sizes.values())
        if hasattr(lm, "batch_sizes")
        else [],
        "device": device,
        "no_cache": no_cache,
        "limit": limit,
        "bootstrap_iters": bootstrap_iters,
        "description_dict": description_dict,
    }

    return results


decontaminate_suffix = "_decontaminate"


@positional_deprecated
def evaluate(
    lm,
    task_dict,
    provide_description=None,
    num_fewshot=0,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
    decontamination_ngrams_path=None,
    write_out=False,
    output_base_path=None,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param provide_description: bool
        Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
    :param num_fewshot: int
        Number of examples in few-shot context
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :param write_out: bool
        If True, write all prompts, logits and metrics to json for offline analysis
    :param output_base_path: str, optional
        Directory to which detailed eval info will be written. Defaults to present working dir
    :return
        Dictionary of results
    """
    # TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces

    # TODO: implement proper description-providing system
    assert not provide_description  # not implemented.
    if provide_description is not None:
        # nudge people to not specify it at all
        print(
            "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
        )

    decontaminate = decontamination_ngrams_path is not None

    task_dict_items = [
        (name, task)
        for name, task in task_dict.items()
        if (task.has_validation_docs() or task.has_test_docs())
    ]

    results = collections.defaultdict(dict)
    versions = collections.defaultdict(dict)

    requests = collections.defaultdict(list)
    requests_origin = collections.defaultdict(list)

    overlaps = collections.defaultdict(list)  # {task_name: contaminated_docs}

    # If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger
    # memory, we can always modify this plumbing to support that, but I didn't want to include it just yet because
    # over-engineering is bad (or we could make it write the requests to disk and then read them back out again,
    # probably using an sqlite db because of all the moving parts we have).

    # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
    docs = {}
    write_out_info = {}

    docs_for_decontamination = collections.defaultdict(list)

    # get lists of each type of request
    for task_name, task in task_dict_items:
        versions[task_name] = task.VERSION
        # default to test docs, fall back to val docs if test docs are unavailable
        # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
        if task.has_test_docs():
            task_doc_func = task.test_docs
            task_set = "test"  # Required for caching in the decontamination
        elif task.has_validation_docs():
            task_set = "val"  # Required for caching in the decontamination
            task_doc_func = task.validation_docs
        else:
            raise RuntimeError("Task has neither test_docs nor validation_docs")

        # deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
        task_docs = list(task_doc_func())
        rnd = random.Random()
        rnd.seed(42)
        rnd.shuffle(task_docs)
        print(f"Task: {task_name}; number of docs: {len(task_docs)}")

        if write_out:
            prompt_details = []

        description = (
            description_dict[task_name]
            if description_dict and task_name in description_dict
            else ""
        )
        if limit is not None:
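            # A fractional limit (< 1.0) selects that share of the task's docs;
            # otherwise it is treated as an absolute number of docs.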
            limit = int(len(task_docs) * limit) if limit < 1.0 else int(limit)

        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
            if decontaminate and task.should_decontaminate():
                docs_for_decontamination[(task_name, task_set)].append(
                    task.doc_to_decontamination_query(doc)
                )

            docs[(task_name, doc_id)] = doc
            ctx = task.fewshot_context(
                doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
            )
            reqs = task.construct_requests(doc, ctx)

            if write_out:
                prompt_details.append({"doc_id": doc_id})

            # print the prompt for the first few documents
            if doc_id < 1:
                print(
                    f"Task: {task_name}; document {doc_id}; context prompt (starting on next line):\n{ctx}\n(end of prompt on previous line)"
                )
                print("Requests:", reqs)

            if not isinstance(reqs, (list, tuple)):
                reqs = [reqs]
            for i, req in enumerate(reqs):
                requests[req.request_type].append(req)
                # i: index in requests for a single task instance
                # doc_id: unique id that we can get back to a doc using `docs`
                requests_origin[req.request_type].append((i, task_name, doc, doc_id))

                if write_out:
                    prompt_details[-1][f"prompt_{i}"] = "".join(
                        (map(lambda x: "".join(x), req.args))
                    )

        if write_out:
            write_out_info[task_name] = prompt_details

    # Compare all tasks/sets at once to ensure a single training set scan
    if decontaminate:
        from lm_eval.decontamination.decontaminate import get_train_overlap

        print("Finding train/test overlap, please wait...")
        overlaps = get_train_overlap(
            docs_for_decontamination, decontamination_ngrams_path, limit
        )

    # all responses for each (task, doc)
    process_res_queue = collections.defaultdict(list)

    # execute each type of request
    for reqtype, reqs in requests.items():
        # TODO: right now, this code runs multiple separate LM requests for multiple Requests differing
        #       only in index. We could implement some kind of caching, but that would be more of a band-aid
        #       solution. we could also implement some kind of auto-grouping here;
        #       they should end up next to each other.

        print("Running", reqtype, "requests")
        resps = getattr(lm, reqtype)([req.args for req in reqs])
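        # Some request types return multi-part results; req.index selects the piece
        # this particular request cares about.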
        resps = [
            x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
        ]

        for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
            process_res_queue[(task_name, doc_id)].append((i, resp))

            if write_out:
                write_out_info[task_name][doc_id][f"logit_{i}"] = resp
                task = task_dict[task_name]
                if isinstance(task, lm_eval.base.MultipleChoiceTask):
                    write_out_info[task_name][doc_id]["truth"] = doc["gold"]
                elif isinstance(task, lm_eval.tasks.winogrande.Winogrande):
                    write_out_info[task_name][doc_id]["truth"] = task.answer_to_num[
                        doc["answer"]
                    ]
                else:
                    write_out_info[task_name][doc_id]["truth"] = task.doc_to_target(doc)

    vals = collections.defaultdict(list)

    # unpack results and sort back in order and return control to Task
    for (task_name, doc_id), requests in process_res_queue.items():
        requests.sort(key=lambda x: x[0])
        requests = [x[1] for x in requests]

        task = task_dict[task_name]
        doc = docs[(task_name, doc_id)]

        metrics = task.process_results(doc, requests)
        for metric, value in metrics.items():
            vals[(task_name, metric)].append(value)

            if write_out:
                write_out_info[task_name][doc_id][metric] = str(value)

            # Re-use the evaluation for the decontaminated set by just ignoring the overlaps
            if decontaminate and task_name in overlaps:
                if doc_id not in overlaps[task_name]:
                    vals[(task_name, metric + decontaminate_suffix)].append(value)

    # aggregate results
    for (task_name, metric), items in vals.items():
        task = task_dict[task_name]
        real_metric = metric  # key when looking up the metric with task.aggregation
        if metric.endswith(decontaminate_suffix):
            real_metric = metric.replace(
                decontaminate_suffix, ""
            )  # decontaminated still uses the same metric
        results[task_name][metric] = task.aggregation()[real_metric](items)

        # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap,
        # so we run them for fewer iterations. still looking for a cleaner way to do this

        stderr = lm_eval.metrics.stderr_for_metric(
            metric=task.aggregation()[real_metric],
            bootstrap_iters=min(bootstrap_iters, 1000)
            if metric in ["bleu", "chrf", "ter"]
            else bootstrap_iters,
        )

        if stderr is not None:
            results[task_name][metric + "_stderr"] = stderr(items)

    if write_out:
        import json
        import pathlib

        output_base_path = (
            pathlib.Path(output_base_path)
            if output_base_path is not None
            else pathlib.Path(".")
        )
        try:
            output_base_path.mkdir(parents=True, exist_ok=False)
        except FileExistsError:
            pass

        for task_name, _ in task_dict_items:
            with open(
                output_base_path.joinpath(f"{task_name}_write_out_info.json"),
                "w",
                encoding="utf8",
            ) as fp:
                json.dump(write_out_info[task_name], fp, indent=4, ensure_ascii=False)

    return {"results": dict(results), "versions": dict(versions)}


def make_table(result_dict):
    """Generate table of results."""
    from pytablewriter import MarkdownTableWriter, LatexTableWriter

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
    latex_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]

    values = []

    for k, dic in result_dict["results"].items():
        version = result_dict["versions"][k]
        for m, v in dic.items():
            if m.endswith("_stderr"):
                continue

            if m + "_stderr" in dic:
                se = dic[m + "_stderr"]
                values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se])
            else:
                values.append([k, version, m, "%.4f" % v, "", ""])
            k = ""
            version = ""
    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # todo: make latex table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()
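# Example usage (a hedged sketch; the model and task identifiers below are
# illustrative and assume they are registered in this version of the harness):
#
#   results = simple_evaluate(
#       model="hf-causal",
#       model_args="pretrained=gpt2",
#       tasks=["hellaswag"],
#       num_fewshot=0,
#   )
#   print(make_table(results))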