# evaluator.py

import collections
import itertools
import random

import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
from lm_eval.utils import positional_deprecated, run_task_tests
from lm_eval.models.gpt2 import HFLM

import numpy as np
import transformers


@positional_deprecated
def simple_evaluate(
    model,
    model_args=None,
    tasks=[],
    num_fewshot=0,
    batch_size=None,
    max_batch_size=None,
    device=None,
    no_cache=False,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
    check_integrity=False,
    decontamination_ngrams_path=None,
    write_out=False,
    output_base_path=None,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model, transformers.PreTrainedModel object, or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str]
        String arguments for each model class, see LM.create_from_arg_string.
        Ignored if `model` argument is an LM object.
    :param tasks: list[Union[str, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int or str, optional
        Batch size for model
    :param max_batch_size: int, optional
        Maximal batch size to try with automatic batch size detection
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param no_cache: bool
        Whether or not to cache
    :param limit: int or float, optional
        Limit the number of examples per task (only use this for testing). If <1, limit is a percentage of the total number of examples.
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :param write_out: bool
        If True, write details about prompts and logits to json for all tasks
    :param output_base_path: str, optional
        Directory to which detailed eval info will be written. Defaults to present working dir.
    :return
        Dictionary of results
    """
    random.seed(1234)
    np.random.seed(1234)

    assert tasks != [], "No tasks specified"

    if isinstance(model, str):
        if model_args is None:
            model_args = ""
        lm = lm_eval.models.get_model(model).create_from_arg_string(
            model_args,
            {
                "batch_size": batch_size,
                "max_batch_size": max_batch_size,
                "device": device,
            },
        )
    elif isinstance(model, transformers.PreTrainedModel):
        lm = lm_eval.models.get_model("hf-causal")(
            pretrained=model,
            batch_size=batch_size,
            max_batch_size=max_batch_size,
        )
        no_cache = True
    else:
        assert isinstance(model, lm_eval.base.LM)
        lm = model

    if not no_cache:
        lm = lm_eval.base.CachingLM(
            lm,
            "lm_cache/"
            + (model if isinstance(model, str) else model.model.config._name_or_path)
            + "_"
            + model_args.replace("=", "-").replace(",", "_").replace("/", "-")
            + ".db",
        )

    task_dict = lm_eval.tasks.get_task_dict(tasks)

    if check_integrity:
        run_task_tests(task_list=tasks)

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        num_fewshot=num_fewshot,
        limit=limit,
        bootstrap_iters=bootstrap_iters,
        description_dict=description_dict,
        decontamination_ngrams_path=decontamination_ngrams_path,
        write_out=write_out,
        output_base_path=output_base_path,
    )

    # add info about the model and few shot config
    model_name = None
    if isinstance(model, str):
        model_name = model
    elif isinstance(model, transformers.PreTrainedModel):
        model_name = "pretrained=" + model.config._name_or_path
    results["config"] = {
        "model": model_name,
        "model_args": model_args,
        "num_fewshot": num_fewshot,
        "batch_size": batch_size,
        "batch_sizes": list(lm.batch_sizes.values())
        if hasattr(lm, "batch_sizes")
        else [],
        "device": device,
        "no_cache": no_cache,
        "limit": limit,
        "bootstrap_iters": bootstrap_iters,
        "description_dict": description_dict,
    }

    return results
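

# Illustrative usage sketch (not from the original file): one way to drive simple_evaluate()
# as documented above. The model key "gpt2", the model_args string, and the task name
# "lambada_openai" are assumptions; substitute whatever is registered in lm_eval.models
# and lm_eval.tasks in your installation.
def _example_simple_evaluate():  # hypothetical helper, added for illustration only
    results = simple_evaluate(
        model="gpt2",  # assumed model key understood by lm_eval.models.get_model
        model_args="pretrained=gpt2",  # assumed HF checkpoint, parsed by create_from_arg_string
        tasks=["lambada_openai"],  # assumed task name accepted by lm_eval.tasks.get_task_dict
        num_fewshot=0,
        batch_size=8,
        no_cache=True,
    )
    # results["results"] maps task -> {metric: value, metric_stderr: value};
    # results["config"] records the model / few-shot configuration filled in above.
    return results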


decontaminate_suffix = "_decontaminate"


@positional_deprecated
def evaluate(
    lm,
    task_dict,
    provide_description=None,
    num_fewshot=0,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
    decontamination_ngrams_path=None,
    write_out=False,
    output_base_path=None,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param provide_description: bool
        Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description-providing method
    :param num_fewshot: int
        Number of examples in few-shot context
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :param write_out: bool
        If True, write all prompts, logits and metrics to json for offline analysis
    :param output_base_path: str, optional
        Directory to which detailed eval info will be written. Defaults to present working dir.
    :return
        Dictionary of results
    """
    # TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces

    # TODO: implement proper description-providing system
    assert not provide_description  # not implemented.
    if provide_description is not None:
        # nudge people to not specify it at all
        print(
            "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
        )

    decontaminate = decontamination_ngrams_path is not None

    task_dict_items = [
        (name, task)
        for name, task in task_dict.items()
        if (task.has_validation_docs() or task.has_test_docs())
    ]

    results = collections.defaultdict(dict)
    versions = collections.defaultdict(dict)

    requests = collections.defaultdict(list)
    requests_origin = collections.defaultdict(list)

    overlaps = collections.defaultdict(list)  # {task_name: contaminated_docs}

    # If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger
    # memory, we can always modify this plumbing to support that, but I didn't want to include it just yet because
    # over-engineering is bad (or we could make it write the requests to disk and then read them back out again
    # - probably using an SQLite db because of all the moving parts we have)

    # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
    docs = {}
    write_out_info = {}

    docs_for_decontamination = collections.defaultdict(list)

    # get lists of each type of request
    for task_name, task in task_dict_items:
        versions[task_name] = task.VERSION
        # default to test doc, fall back to val doc if validation unavailable
        # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
        if task.has_test_docs():
            task_doc_func = task.test_docs
            task_set = "test"  # Required for caching in the decontamination
        elif task.has_validation_docs():
            task_set = "val"  # Required for caching in the decontamination
            task_doc_func = task.validation_docs
        else:
            raise RuntimeError("Task has neither test_docs nor validation_docs")

        # deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
        task_docs = list(task_doc_func())
        rnd = random.Random()
        rnd.seed(42)
        rnd.shuffle(task_docs)

        print(f"Task: {task_name}; number of docs: {len(task_docs)}")

        if write_out:
            prompt_details = []

        description = (
            description_dict[task_name]
            if description_dict and task_name in description_dict
            else ""
        )
        if limit is not None:
            limit = int(len(task_docs) * limit) if limit < 1.0 else int(limit)

        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
            if decontaminate and task.should_decontaminate():
                docs_for_decontamination[(task_name, task_set)].append(
                    task.doc_to_decontamination_query(doc)
                )

            docs[(task_name, doc_id)] = doc
            ctx = task.fewshot_context(
                doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
            )
            reqs = task.construct_requests(doc, ctx)

            if write_out:
                prompt_details.append({"doc_id": doc_id})

            # print the prompt for the first few documents
            if doc_id < 1:
                print(
                    f"Task: {task_name}; document {doc_id}; context prompt (starting on next line):\n{ctx}\n(end of prompt on previous line)"
                )
                print("Requests:", reqs)

            if not isinstance(reqs, (list, tuple)):
                reqs = [reqs]
            for i, req in enumerate(reqs):
                requests[req.request_type].append(req)
                # i: index in requests for a single task instance
                # doc_id: unique id that we can get back to a doc using `docs`
                requests_origin[req.request_type].append((i, task_name, doc, doc_id))

                if write_out:
                    prompt_details[-1][f"prompt_{i}"] = "".join(
                        (map(lambda x: "".join(x), req.args))
                    )

        if write_out:
            write_out_info[task_name] = prompt_details

    # Compare all tasks/sets at once to ensure a single training set scan
    if decontaminate:
        from lm_eval.decontamination.decontaminate import get_train_overlap

        print("Finding train/test overlap, please wait...")
        overlaps = get_train_overlap(
            docs_for_decontamination, decontamination_ngrams_path, limit
        )

    # all responses for each (task, doc)
    process_res_queue = collections.defaultdict(list)

    # execute each type of request
    for reqtype, reqs in requests.items():
        # TODO: right now, this code runs multiple separate LM requests for multiple Requests differing
        #       only in index. We could implement some kind of caching, but that would be more of a band-aid
        #       solution. we could also implement some kind of auto-grouping here;
        #       they should end up next to each other.

        print("Running", reqtype, "requests")
        resps = getattr(lm, reqtype)([req.args for req in reqs])
        resps = [
            x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
        ]

        for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
            process_res_queue[(task_name, doc_id)].append((i, resp))

            if write_out:
                write_out_info[task_name][doc_id][f"logit_{i}"] = resp
                task = task_dict[task_name]
                if isinstance(task, lm_eval.base.MultipleChoiceTask):
                    write_out_info[task_name][doc_id]["truth"] = doc["gold"]
                elif isinstance(task, lm_eval.tasks.winogrande.Winogrande):
                    write_out_info[task_name][doc_id]["truth"] = task.answer_to_num[
                        doc["answer"]
                    ]
                else:
                    write_out_info[task_name][doc_id]["truth"] = task.doc_to_target(doc)

    vals = collections.defaultdict(list)

    # unpack results and sort back in order and return control to Task
    for (task_name, doc_id), requests in process_res_queue.items():
        requests.sort(key=lambda x: x[0])
        requests = [x[1] for x in requests]

        task = task_dict[task_name]
        doc = docs[(task_name, doc_id)]

        metrics = task.process_results(doc, requests)
        for metric, value in metrics.items():
            vals[(task_name, metric)].append(value)

            if write_out:
                write_out_info[task_name][doc_id][metric] = str(value)

            # Re-use the evaluation for the decontaminated set by just ignoring the overlaps
            if decontaminate and task_name in overlaps:
                if doc_id not in overlaps[task_name]:
                    vals[(task_name, metric + decontaminate_suffix)].append(value)

    # aggregate results
    for (task_name, metric), items in vals.items():
        task = task_dict[task_name]
        real_metric = metric  # key when looking up the metric with task.aggregation
        if metric.endswith(decontaminate_suffix):
            real_metric = metric.replace(
                decontaminate_suffix, ""
            )  # decontaminated still uses the same metric
        results[task_name][metric] = task.aggregation()[real_metric](items)

        # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
        # so we run them for fewer iterations. still looking for a cleaner way to do this

        stderr = lm_eval.metrics.stderr_for_metric(
            metric=task.aggregation()[real_metric],
            bootstrap_iters=min(bootstrap_iters, 1000)
            if metric in ["bleu", "chrf", "ter"]
            else bootstrap_iters,
        )

        if stderr is not None:
            results[task_name][metric + "_stderr"] = stderr(items)

    if write_out:
        import json
        import pathlib

        output_base_path = (
            pathlib.Path(output_base_path)
            if output_base_path is not None
            else pathlib.Path(".")
        )
        try:
            output_base_path.mkdir(parents=True, exist_ok=False)
        except FileExistsError:
            pass

        for task_name, _ in task_dict_items:
            with open(
                output_base_path.joinpath(f"{task_name}_write_out_info.json"),
                "w",
                encoding="utf8",
            ) as fp:
                json.dump(write_out_info[task_name], fp, indent=4, ensure_ascii=False)

    return {"results": dict(results), "versions": dict(versions)}
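

# Illustrative usage sketch (not from the original file): evaluate() can also be driven
# directly with an already-built LM and task_dict, which is what simple_evaluate() does
# internally. The task name "hellaswag" is an assumption; `lm` is any lm_eval.base.LM instance.
def _example_evaluate(lm):  # hypothetical helper, added for illustration only
    task_dict = lm_eval.tasks.get_task_dict(["hellaswag"])  # assumed task name
    out = evaluate(lm=lm, task_dict=task_dict, num_fewshot=0, limit=10)
    # out == {"results": {task: {metric: value, ...}}, "versions": {task: version}}
    return out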


def make_table(result_dict):
    """Generate table of results."""
    from pytablewriter import MarkdownTableWriter, LatexTableWriter

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
    latex_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]

    values = []

    for k, dic in result_dict["results"].items():
        version = result_dict["versions"][k]
        for m, v in dic.items():
            if m.endswith("_stderr"):
                continue

            if m + "_stderr" in dic:
                se = dic[m + "_stderr"]
                values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se])
            else:
                values.append([k, version, m, "%.4f" % v, "", ""])
            k = ""
            version = ""
    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # todo: make latex table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()
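

# Illustrative smoke test (not from the original file): make_table() only needs the
# {"results": ..., "versions": ...} shape returned by evaluate()/simple_evaluate(), so it
# can be exercised with a hand-built dict. The task and metric names below are made up.
if __name__ == "__main__":
    _fake_result_dict = {
        "results": {"example_task": {"acc": 0.5, "acc_stderr": 0.05}},
        "versions": {"example_task": 0},
    }
    print(make_table(_fake_result_dict))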