import collections
import itertools
import random

import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
from lm_eval.utils import positional_deprecated, run_task_tests
from lm_eval.models.gpt2 import HFLM

import numpy as np
import transformers


@positional_deprecated
def simple_evaluate(
    model,
    model_args=None,
    tasks=[],
    num_fewshot=0,
    batch_size=None,
    device=None,
    no_cache=False,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
    check_integrity=False,
    decontamination_ngrams_path=None,
    write_out=False,
    output_base_path=None,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str]
        String arguments for each model class, see LM.create_from_arg_string.
        Ignored if `model` argument is a LM object.
    :param tasks: list[Union[str, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int, optional
        Batch size for model
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param no_cache: bool
        If True, do not cache LM responses (the CachingLM wrapper is skipped)
    :param limit: int or float, optional
        Limit the number of examples per task (only use this for testing). If <1, limit is a percentage of the total number of examples.
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :param write_out: bool
        If True, write details about prompts and logits to json for all tasks
    :param output_base_path: str, optional
        Directory to which detailed eval info will be written. Defaults to present working dir.
    :return
        Dictionary of results
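
    Example (illustrative; the task names are examples and depend on which tasks are available):

        results = simple_evaluate(
            model="gpt2",
            tasks=["hellaswag", "piqa"],
            num_fewshot=0,
        )
        print(make_table(results))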
    """
    random.seed(1234)
    np.random.seed(1234)

    assert tasks != [], "No tasks specified"

    if isinstance(model, str):
        if model_args is None:
            model_args = ""
        lm = lm_eval.models.get_model(model).create_from_arg_string(
            model_args, {"batch_size": batch_size, "device": device}
        )
    elif isinstance(model, transformers.PreTrainedModel):
        lm = HFLM(pretrained=model)
        no_cache = True
    else:
        assert isinstance(model, lm_eval.base.LM)
        lm = model

    if not no_cache:
        lm = lm_eval.base.CachingLM(
            lm,
            "lm_cache/"
            + model
            + "_"
            + model_args.replace("=", "-").replace(",", "_").replace("/", "-")
            + ".db",
        )

    task_dict = lm_eval.tasks.get_task_dict(tasks)

    if check_integrity:
        run_task_tests(task_list=tasks)

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        num_fewshot=num_fewshot,
        limit=limit,
        bootstrap_iters=bootstrap_iters,
        description_dict=description_dict,
        decontamination_ngrams_path=decontamination_ngrams_path,
        write_out=write_out,
        output_base_path=output_base_path,
    )

    # add info about the model and few shot config
    results["config"] = {
        "model": model,
        "model_args": model_args,
        "num_fewshot": num_fewshot,
        "batch_size": batch_size,
        "device": device,
        "no_cache": no_cache,
        "limit": limit,
        "bootstrap_iters": bootstrap_iters,
        "description_dict": description_dict,
    }

    return results


decontaminate_suffix = "_decontaminate"
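# Suffix appended to a metric name for scores recomputed on only the documents that have
# no detected overlap with the training set (e.g. "acc" -> "acc_decontaminate").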


@positional_deprecated
def evaluate(
    lm,
    task_dict,
    provide_description=None,
    num_fewshot=0,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
    decontamination_ngrams_path=None,
    write_out=False,
    output_base_path=None,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param provide_description: bool
        Not implemented. Deprecated and will be removed in a future version, in favor of `description_dict`.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :param write_out: bool
        If True, write all prompts, logits and metrics to json for offline analysis
    :param output_base_path: str, optional
        Directory to which detailed eval info will be written. Defaults to present working dir
    :return
        Dictionary of results
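
    Example (illustrative; the model and task names are examples):

        lm = lm_eval.models.get_model("gpt2").create_from_arg_string(
            "", {"batch_size": None, "device": None}
        )
        task_dict = lm_eval.tasks.get_task_dict(["hellaswag"])
        results = evaluate(lm=lm, task_dict=task_dict, num_fewshot=0)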
    """
    # TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces

    # TODO: implement proper description-providing system
    assert not provide_description  # not implemented.
    if provide_description is not None:
        # nudge people to not specify it at all
        print(
            "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
        )

    decontaminate = decontamination_ngrams_path is not None

    task_dict_items = [
        (name, task)
        for name, task in task_dict.items()
        if (task.has_validation_docs() or task.has_test_docs())
    ]

    results = collections.defaultdict(dict)
    versions = collections.defaultdict(dict)

    requests = collections.defaultdict(list)
    requests_origin = collections.defaultdict(list)
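    # `requests` groups Request objects by request type; `requests_origin` records, per
    # request type, (i, task_name, doc, doc_id) tuples so each response can be routed back
    # to the task instance and document that produced it.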

    overlaps = collections.defaultdict(list)  # {task_name: contaminated_docs}

    # If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger
    # memory, we can always modify this plumbing to support that, but I didn't want to include it just yet because
    # over-engineering is bad (or we could make it write the requests to disk and then read them back out again,
    # probably using an sqlite db, because of all the moving parts we have).

    # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
    docs = {}
    write_out_info = {}
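    # write_out_info maps task_name -> list of per-document dicts holding prompts, logits,
    # gold targets and metric values; it is only populated when write_out=True.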

    docs_for_decontamination = collections.defaultdict(list)

    # get lists of each type of request
    for task_name, task in task_dict_items:
        versions[task_name] = task.VERSION
        # default to test doc, fall back to val doc if validation unavailable
        # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
        if task.has_test_docs():
            task_doc_func = task.test_docs
            task_set = "test"  # Required for caching in the decontamination
        elif task.has_validation_docs():
            task_set = "val"  # Required for caching in the decontamination
            task_doc_func = task.validation_docs
        else:
            raise RuntimeError("Task has neither test_docs nor validation_docs")

        # deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
        task_docs = list(task_doc_func())
        rnd = random.Random()
        rnd.seed(42)
        rnd.shuffle(task_docs)
        print(f"Task: {task_name}; number of docs: {len(task_docs)}")

        if write_out:
            prompt_details = []

        description = (
            description_dict[task_name]
            if description_dict and task_name in description_dict
            else ""
        )
        if limit is not None:
            limit = int(len(task_docs) * limit) if limit < 1.0 else int(limit)

        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
            if decontaminate and task.should_decontaminate():
                docs_for_decontamination[(task_name, task_set)].append(
                    task.doc_to_decontamination_query(doc)
                )

            docs[(task_name, doc_id)] = doc
            ctx = task.fewshot_context(
                doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
            )
            reqs = task.construct_requests(doc, ctx)

            if write_out:
                prompt_details.append({"doc_id": doc_id})

            # print the prompt for the first few documents
            if doc_id < 1:
                print(
                    f"Task: {task_name}; document {doc_id}; context prompt (starting on next line):\n{ctx}\n(end of prompt on previous line)"
                )
                print("Requests:", reqs)

            if not isinstance(reqs, (list, tuple)):
                reqs = [reqs]
            for i, req in enumerate(reqs):
                requests[req.request_type].append(req)
                # i: index in requests for a single task instance
                # doc_id: unique id that we can get back to a doc using `docs`
                requests_origin[req.request_type].append((i, task_name, doc, doc_id))

                if write_out:
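                    # Record this request's args, flattened into a single string, so the
                    # exact prompt can be inspected offline.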
                    prompt_details[-1][f"prompt_{i}"] = "".join(
                        (map(lambda x: "".join(x), req.args))
                    )

        if write_out:
            write_out_info[task_name] = prompt_details

    # Compare all tasks/sets at once to ensure a single training set scan
    if decontaminate:
        from lm_eval.decontamination.decontaminate import get_train_overlap

        print("Finding train/test overlap, please wait...")
        overlaps = get_train_overlap(
            docs_for_decontamination, decontamination_ngrams_path, limit
        )

    # all responses for each (task, doc)
    process_res_queue = collections.defaultdict(list)

    # execute each type of request
    for reqtype, reqs in requests.items():
        # TODO: right now, this code runs multiple separate LM requests for multiple Requests differing
        #       only in index. We could implement some kind of caching, but that would be more of a band-aid
        #       solution. we could also implement some kind of auto-grouping here;
        #       they should end up next to each other.

        print("Running", reqtype, "requests")
        resps = getattr(lm, reqtype)([req.args for req in reqs])
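        # A request may only need part of the model's response; req.index selects that
        # element, while req.index=None keeps the full response.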
        resps = [
            x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
        ]

        for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
            process_res_queue[(task_name, doc_id)].append((i, resp))

            if write_out:
                write_out_info[task_name][doc_id][f"logit_{i}"] = resp
                task = task_dict[task_name]
                if isinstance(task, lm_eval.base.MultipleChoiceTask):
                    write_out_info[task_name][doc_id]["truth"] = doc["gold"]
                elif isinstance(task, lm_eval.tasks.winogrande.Winogrande):
                    write_out_info[task_name][doc_id]["truth"] = task.answer_to_num[
                        doc["answer"]
                    ]
                else:
                    write_out_info[task_name][doc_id]["truth"] = task.doc_to_target(doc)

    vals = collections.defaultdict(list)

    # unpack results and sort back in order and return control to Task
    for (task_name, doc_id), requests in process_res_queue.items():
        requests.sort(key=lambda x: x[0])
        requests = [x[1] for x in requests]

        task = task_dict[task_name]
        doc = docs[(task_name, doc_id)]

        metrics = task.process_results(doc, requests)
        for metric, value in metrics.items():
            vals[(task_name, metric)].append(value)

            if write_out:
                write_out_info[task_name][doc_id][metric] = str(value)

            # Re-use the evaluation for the decontaminated set by just ignoring the overlaps
            if decontaminate and task_name in overlaps:
                if doc_id not in overlaps[task_name]:
                    vals[(task_name, metric + decontaminate_suffix)].append(value)

    # aggregate results
    for (task_name, metric), items in vals.items():
        task = task_dict[task_name]
        real_metric = metric  # key when looking up the metric with task.aggregation
        if metric.endswith(decontaminate_suffix):
            real_metric = metric.replace(
                decontaminate_suffix, ""
            )  # decontaminated still uses the same metric
        results[task_name][metric] = task.aggregation()[real_metric](items)

        # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap,
        # so we run them for fewer iterations. still looking for a cleaner way to do this

        stderr = lm_eval.metrics.stderr_for_metric(
            metric=task.aggregation()[real_metric],
            bootstrap_iters=min(bootstrap_iters, 1000)
            if metric in ["bleu", "chrf", "ter"]
            else bootstrap_iters,
        )

        if stderr is not None:
            results[task_name][metric + "_stderr"] = stderr(items)

    if write_out:
        import json
        import pathlib

        output_base_path = (
            pathlib.Path(output_base_path)
            if output_base_path is not None
            else pathlib.Path(".")
        )
        try:
            output_base_path.mkdir(parents=True, exist_ok=False)
        except FileExistsError:
            pass

        for task_name, _ in task_dict_items:
            with open(
                output_base_path.joinpath(f"{task_name}_write_out_info.json"),
                "w",
                encoding="utf8",
            ) as fp:
                json.dump(write_out_info[task_name], fp, indent=4, ensure_ascii=False)

    return {"results": dict(results), "versions": dict(versions)}


def make_table(result_dict):
    """Generate table of results."""
    from pytablewriter import MarkdownTableWriter, LatexTableWriter

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
    latex_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]

    values = []

    for k, dic in result_dict["results"].items():
        version = result_dict["versions"][k]
        for m, v in dic.items():
            if m.endswith("_stderr"):
                continue

            if m + "_stderr" in dic:
                se = dic[m + "_stderr"]
                values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se])
            else:
                values.append([k, version, m, "%.4f" % v, "", ""])
            k = ""
            version = ""
    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # todo: make latex table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()