import collections
import itertools
import random

import numpy as np

import lm_eval.base
import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
from lm_eval.utils import positional_deprecated, run_task_tests


@positional_deprecated
def simple_evaluate(
    model,
    model_args=None,
    tasks=[],
    num_fewshot=0,
    batch_size=None,
    device=None,
    no_cache=False,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
    check_integrity=False,
    decontamination_ngrams_path=None,
    tokenizer=None,
    write_out=False,
    output_base_path=None,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str]
        String arguments for each model class, see LM.create_from_arg_string.
        Ignored if `model` argument is a LM object.
    :param tasks: list[Union[str, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int, optional
        Batch size for model
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param no_cache: bool
        If True, do not cache model responses (skips the CachingLM wrapper)
    :param limit: int or float, optional
        Limit the number of examples per task (only use this for testing). If <1, the limit is interpreted as a fraction of the total number of examples.
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
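    :param decontamination_ngrams_path: str, optional
        Path to the training-set n-grams used for decontamination (see lm_eval.decontamination)
    :param tokenizer: optional
        Tokenizer name or object forwarded to the model constructor; only used when `model` is a string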
    :param write_out: bool
        If True, write details about prompts and logits to json for all tasks
    :param output_base_path: str, optional
        Directory to which detailed eval info will be written. Defaults to the present working dir.
    :return:
        Dictionary of results
    """
    random.seed(1234)
    np.random.seed(1234)

    assert tasks != [], "No tasks specified"

    if isinstance(model, str):
        if model_args is None:
            model_args = ""
        lm = lm_eval.models.get_model(model).create_from_arg_string(
            model_args,
            {
                "batch_size": batch_size,
                "device": device,
                "tokenizer": tokenizer,
                "trust_remote_code": True,
            },
        )
    else:
        assert isinstance(model, lm_eval.base.LM)
        lm = model

    if not no_cache:
        lm = lm_eval.base.CachingLM(
            lm,
            # guard against `model` being an LM object and `model_args` being
            # None, either of which would break the string concatenation below
            "lm_cache/"
            + (model if isinstance(model, str) else type(model).__name__)
            + "_"
            + (model_args or "").replace("=", "-").replace(",", "_").replace("/", "-")
            + ".db",
        )

    task_dict = lm_eval.tasks.get_task_dict(tasks)

    if check_integrity:
        run_task_tests(task_list=tasks)

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        num_fewshot=num_fewshot,
        limit=limit,
        bootstrap_iters=bootstrap_iters,
        description_dict=description_dict,
        decontamination_ngrams_path=decontamination_ngrams_path,
        write_out=write_out,
        output_base_path=output_base_path,
    )

    # add info about the model and few shot config
    results["config"] = {
        "model": model,
        "model_args": model_args,
        "num_fewshot": num_fewshot,
        "batch_size": batch_size,
        "device": device,
        "no_cache": no_cache,
        "limit": limit,
        "bootstrap_iters": bootstrap_iters,
        "description_dict": description_dict,
    }

    return results
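

# Example usage (a minimal sketch; the model name, model_args, and task list
# below are illustrative assumptions, not values fixed by this module):
#
#     results = simple_evaluate(
#         model="gpt2",
#         model_args="pretrained=gpt2",
#         tasks=["lambada_openai"],
#         num_fewshot=0,
#         limit=10,  # small limit for smoke tests only
#     )
#     print(make_table(results))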


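# suffix appended to a metric name for its decontaminated variant, i.e. the same
# metric computed while ignoring docs that overlap the training set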
decontaminate_suffix = "_decontaminate"


@positional_deprecated
def evaluate(
    lm,
    task_dict,
    provide_description=None,
    num_fewshot=0,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
    decontamination_ngrams_path=None,
    write_out=False,
    output_base_path=None,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param provide_description: bool
        Not implemented. This option is deprecated and will be removed in a future version, in favor of a different description-providing method.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :param write_out: bool
        If True, write all prompts, logits and metrics to json for offline analysis
    :param output_base_path: str, optional
        Directory to which detailed eval info will be written. Defaults to the present working dir.
    :return:
        Dictionary of results
    """
    # TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces

    # TODO: implement a proper description-providing system
    assert not provide_description  # not implemented.
    if provide_description is not None:
        # nudge people to not specify it at all
        print(
            "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
        )

    decontaminate = decontamination_ngrams_path is not None

    task_dict_items = [
        (name, task)
        for name, task in task_dict.items()
        if (task.has_validation_docs() or task.has_test_docs())
    ]

    results = collections.defaultdict(dict)
    versions = collections.defaultdict(dict)

    requests = collections.defaultdict(list)
    requests_origin = collections.defaultdict(list)

    overlaps = collections.defaultdict(list)  # {task_name: contaminated_docs}

    # If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger
    # memory, we can always modify this plumbing to support that, but I didn't want to include it just yet because
    # over-engineering is bad (or we could make it write the requests to disk and then read them back out again,
    # probably using an sqlite db because of all the moving parts we have).

    # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
    docs = {}
    write_out_info = {}

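    # decontamination queries keyed by (task_name, task_set), where task_set is "test" or "val"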
    docs_for_decontamination = collections.defaultdict(list)

    # get lists of each type of request
    for task_name, task in task_dict_items:
        versions[task_name] = task.VERSION
        # default to test docs; fall back to validation docs if test docs are unavailable
        # TODO: the test-fallback-to-val system isn't final; we should revisit it at some point
        if task.has_test_docs():
            task_doc_func = task.test_docs
            task_set = "test"  # Required for caching in the decontamination
        elif task.has_validation_docs():
            task_set = "val"  # Required for caching in the decontamination
            task_doc_func = task.validation_docs
        else:
            raise RuntimeError("Task has neither test_docs nor validation_docs")

        # deterministically shuffle docs before truncating to the first `limit`, because some datasets come in a meaningful order
        task_docs = list(task_doc_func())
        rnd = random.Random()
        rnd.seed(42)
        rnd.shuffle(task_docs)
        print(f"Task: {task_name}; number of docs: {len(task_docs)}")

        if write_out:
            prompt_details = []

        description = (
            description_dict[task_name]
            if description_dict and task_name in description_dict
            else ""
        )
        if limit is not None:
            # compute a per-task limit so a fractional `limit` is interpreted
            # against each task's own doc count instead of overwriting `limit`
            task_limit = int(len(task_docs) * limit) if limit < 1.0 else int(limit)
        else:
            task_limit = None

        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, task_limit)):
            if decontaminate and task.should_decontaminate():
                docs_for_decontamination[(task_name, task_set)].append(
                    task.doc_to_decontamination_query(doc)
                )

            docs[(task_name, doc_id)] = doc
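            # build the full prompt: task description (if any), few-shot examples, then this doc's query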
            ctx = task.fewshot_context(
                doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
            )
            reqs = task.construct_requests(doc, ctx)

            if write_out:
                prompt_details.append({"doc_id": doc_id})

            # print the prompt for the first document of each task
            if doc_id < 1:
                print(
                    f"Task: {task_name}; document {doc_id}; context prompt (starting on next line):\n{ctx}\n(end of prompt on previous line)"
                )
                print("Requests:", reqs)

            if not isinstance(reqs, (list, tuple)):
                reqs = [reqs]
            for i, req in enumerate(reqs):
                requests[req.request_type].append(req)
                # i: index in requests for a single task instance
                # doc_id: unique id that lets us map back to the doc via `docs`
                requests_origin[req.request_type].append((i, task_name, doc, doc_id))

                if write_out:
                    prompt_details[-1][f"prompt_{i}"] = "".join(
                        "".join(x) for x in req.args
                    )

        if write_out:
            write_out_info[task_name] = prompt_details

    # Compare all tasks/sets at once to ensure a single training set scan
    if decontaminate:
        from lm_eval.decontamination.decontaminate import get_train_overlap

        print("Finding train/test overlap, please wait...")
        overlaps = get_train_overlap(
            docs_for_decontamination, decontamination_ngrams_path, limit
        )

    # all responses for each (task, doc)
    process_res_queue = collections.defaultdict(list)

    # execute each type of request
    for reqtype, reqs in requests.items():
        # TODO: right now, this code runs multiple separate LM requests for multiple Requests differing
        #       only in index. We could implement some kind of caching, but that would be more of a band-aid
        #       solution. we could also implement some kind of auto-grouping here;
        #       they should end up next to each other.

        print("Running", reqtype, "requests")
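        # `reqtype` names an LM method (e.g. "loglikelihood" or "greedy_until"),
        # so this dispatches the whole batch of request args to that method at once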
        resps = getattr(lm, reqtype)([req.args for req in reqs])
        resps = [
            x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
        ]

        for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
            process_res_queue[(task_name, doc_id)].append((i, resp))

            if write_out:
                write_out_info[task_name][doc_id][f"logit_{i}"] = resp
                task = task_dict[task_name]
                if isinstance(task, lm_eval.base.MultipleChoiceTask):
                    write_out_info[task_name][doc_id]["truth"] = doc["gold"]
                elif isinstance(task, lm_eval.tasks.winogrande.Winogrande):
                    write_out_info[task_name][doc_id]["truth"] = task.answer_to_num[
                        doc["answer"]
                    ]
                else:
                    write_out_info[task_name][doc_id]["truth"] = task.doc_to_target(doc)

    vals = collections.defaultdict(list)

    # unpack results, sort them back into order, and hand control back to each Task
    # (renamed from `requests` to avoid shadowing the requests dict above)
    for (task_name, doc_id), per_doc_requests in process_res_queue.items():
        per_doc_requests.sort(key=lambda x: x[0])
        per_doc_requests = [x[1] for x in per_doc_requests]

        task = task_dict[task_name]
        doc = docs[(task_name, doc_id)]

        metrics = task.process_results(doc, per_doc_requests)
        for metric, value in metrics.items():
            vals[(task_name, metric)].append(value)

            if write_out:
                write_out_info[task_name][doc_id][metric] = str(value)

            # Re-use the evaluation for the decontaminated set by just ignoring the overlaps
            if decontaminate and task_name in overlaps:
                if doc_id not in overlaps[task_name]:
                    vals[(task_name, metric + decontaminate_suffix)].append(value)

    # aggregate results
    for (task_name, metric), items in vals.items():
        task = task_dict[task_name]
        real_metric = metric  # key when looking up the metric with task.aggregation
        if metric.endswith(decontaminate_suffix):
            real_metric = metric.replace(
                decontaminate_suffix, ""
            )  # decontaminated still uses the same metric
        results[task_name][metric] = task.aggregation()[real_metric](items)

        # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap,
        # so we run them for fewer iterations. still looking for a cleaner way to do this

        stderr = lm_eval.metrics.stderr_for_metric(
            metric=task.aggregation()[real_metric],
            bootstrap_iters=min(bootstrap_iters, 1000)
            if metric in ["bleu", "chrf", "ter"]
            else bootstrap_iters,
        )

        if stderr is not None:
            results[task_name][metric + "_stderr"] = stderr(items)

    if write_out:
        import json
        import pathlib

        output_base_path = (
            pathlib.Path(output_base_path)
            if output_base_path is not None
            else pathlib.Path(".")
        )
        output_base_path.mkdir(parents=True, exist_ok=True)

        for task_name, _ in task_dict_items:
            with open(
                output_base_path.joinpath(f"{task_name}_write_out_info.json"),
                "w",
                encoding="utf8",
            ) as fp:
                json.dump(write_out_info[task_name], fp, indent=4, ensure_ascii=False)

    return {"results": dict(results), "versions": dict(versions)}


def make_table(result_dict):
    """Generate table of results."""
    from pytablewriter import MarkdownTableWriter, LatexTableWriter

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
    latex_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]

    values = []

    for k, dic in result_dict["results"].items():
        version = result_dict["versions"][k]
        for m, v in dic.items():
            if m.endswith("_stderr"):
                continue

            if m + "_stderr" in dic:
                se = dic[m + "_stderr"]
                values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se])
            else:
                values.append([k, version, m, "%.4f" % v, "", ""])
            k = ""
            version = ""
    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # TODO: make the LaTeX table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()
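

# `make_table` expects the shape produced by `evaluate`/`simple_evaluate`;
# the task name and metric values below are purely illustrative:
#
#     result_dict = {
#         "results": {"hellaswag": {"acc": 0.5, "acc_stderr": 0.01}},
#         "versions": {"hellaswag": 0},
#     }
#     print(make_table(result_dict))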