import collections
import itertools
import numpy as np
import random
import inspect
import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
from lm_eval.utils import positional_deprecated, run_task_tests


@positional_deprecated
def simple_evaluate(
    model,
    model_args=None,
    tasks=[],
    num_fewshot=0,
    batch_size=None,
    device=None,
    no_cache=False,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
    check_integrity=False,
    decontamination_ngrams_path=None,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str]
        String arguments for each model class, see LM.create_from_arg_string.
        Ignored if `model` argument is a LM object.
    :param tasks: list[Union[str, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int, optional
        Batch size for model
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param no_cache: bool
        Whether to disable caching of model responses (responses are cached under lm_cache/ by default)
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :return:
        Dictionary of results
    """
    random.seed(1234)
    np.random.seed(1234)

    assert tasks != [], "No tasks specified"

    if isinstance(model, str):
        if model_args is None:
            model_args = ""
        lm = lm_eval.models.get_model(model).create_from_arg_string(
            model_args, {"batch_size": batch_size, "device": device}
        )
    else:
        assert isinstance(model, lm_eval.base.LM)
        lm = model

    if not no_cache:
        lm = lm_eval.base.CachingLM(
            lm,
            "lm_cache/"
            + model
            + "_"
            + model_args.replace("=", "-").replace(",", "_").replace("/", "-")
            + ".db",
        )

    task_dict = lm_eval.tasks.get_task_dict(tasks)

    if check_integrity:
        run_task_tests(task_list=tasks)

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        num_fewshot=num_fewshot,
        limit=limit,
        bootstrap_iters=bootstrap_iters,
        description_dict=description_dict,
        decontamination_ngrams_path=decontamination_ngrams_path,
    )

    # add info about the model and few shot config
    results["config"] = {
        "model": model,
        "model_args": model_args,
        "num_fewshot": num_fewshot,
        "batch_size": batch_size,
        "device": device,
        "no_cache": no_cache,
        "limit": limit,
        "bootstrap_iters": bootstrap_iters,
        "description_dict": description_dict,
    }

    return results
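# Example usage (a minimal sketch, kept as a comment so nothing runs at import time).
# The model and task names below are illustrative and assume they are registered with
# the installed lm_eval.models / lm_eval.tasks; adjust them to whatever your install
# actually provides.
#
#   results = simple_evaluate(
#       model="gpt2",                  # resolved via lm_eval.models.get_model
#       model_args="",                 # parsed by LM.create_from_arg_string
#       tasks=["hellaswag"],           # resolved via lm_eval.tasks.get_task_dict
#       num_fewshot=0,
#       batch_size=8,
#       device="cuda:0",
#   )
#   print(make_table(results))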


decontaminate_suffix = "_decontaminate"


@positional_deprecated
def evaluate(
    lm,
    task_dict,
    provide_description=None,
    num_fewshot=0,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
    decontamination_ngrams_path=None,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param provide_description: bool
        Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
    :param num_fewshot: int
        Number of examples in few-shot context
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :return:
        Dictionary of results
    """
    # TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces

    # TODO: implement proper description-providing system
    assert not provide_description  # not implemented.
    if provide_description is not None:
        # nudge people to not specify it at all
        print(
            "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
        )

    decontaminate = decontamination_ngrams_path is not None

    task_dict_items = [
        (name, task)
        for name, task in task_dict.items()
        if (task.has_validation_docs() or task.has_test_docs())
    ]

    results = collections.defaultdict(dict)
    versions = collections.defaultdict(dict)

    requests = collections.defaultdict(list)
    requests_origin = collections.defaultdict(list)

    overlaps = collections.defaultdict(list)  # {task_name: contaminated_docs}

    # If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger
    # memory, we can always modify this plumbing to support that, but I didn't want to include it just yet because
    # over-engineering is bad (or we could make it write the requests to disk and then read them back out again
    #  - probably using an sqlite db because of all the moving parts we have)

    # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
    docs = {}

    docs_for_decontamination = collections.defaultdict(list)
    task_to_description = {}

    # get lists of each type of request
    for task_name, task in task_dict_items:
        versions[task_name] = task.VERSION
        # default to test doc, fall back to val doc if validation unavailable
        # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
        if task.has_test_docs():
            task_doc_func = task.test_docs
            task_set = "test"  # Required for caching in the decontamination
        elif task.has_validation_docs():
            task_set = "val"  # Required for caching in the decontamination
            task_doc_func = task.validation_docs
        else:
            raise RuntimeError("Task has neither test_docs nor validation_docs")

        # deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
        task_docs = list(task_doc_func())
        rnd = random.Random()
        rnd.seed(42)
        rnd.shuffle(task_docs)

        description = (
            description_dict[task_name]
            if description_dict and task_name in description_dict
            else ""
        )
        task_to_description[task_name] = description

        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
            if decontaminate and task.should_decontaminate():
                docs_for_decontamination[(task_name, task_set)].append(
                    task.doc_to_decontamination_query(doc)
                )

            docs[(task_name, doc_id)] = doc
            ctx = task.fewshot_context(
                doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
            )
            if "description" in inspect.getfullargspec(task.construct_requests).args:
                reqs = task.construct_requests(doc, ctx, description=description)
            else:
                reqs = task.construct_requests(doc, ctx)
            if not isinstance(reqs, (list, tuple)):
                reqs = [reqs]
            for i, req in enumerate(reqs):
                requests[req.request_type].append(req)
                # i: index in requests for a single task instance
                # doc_id: unique id that we can get back to a doc using `docs`
                requests_origin[req.request_type].append((i, task_name, doc, doc_id))

    # Compare all tasks/sets at once to ensure a single training set scan
    if decontaminate:
        from lm_eval.decontamination.decontaminate import get_train_overlap

        print("Finding train/test overlap, please wait...")
        overlaps = get_train_overlap(
            docs_for_decontamination, decontamination_ngrams_path, limit
        )

    # all responses for each (task, doc)
    process_res_queue = collections.defaultdict(list)

    # execute each type of request
    for reqtype, reqs in requests.items():
        # TODO: right now, this code runs multiple separate LM requests for multiple Requests differing
        #       only in index. We could implement some kind of caching, but that would be more of a band-aid
        #       solution. we could also implement some kind of auto-grouping here;
        #       they should end up next to each other.

        print("Running", reqtype, "requests")
        resps = getattr(lm, reqtype)([req.args for req in reqs])
        resps = [
            x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
        ]

        for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
            process_res_queue[(task_name, doc_id)].append((i, resp))

    vals = collections.defaultdict(list)

    # unpack results and sort back in order and return control to Task
    for (task_name, doc_id), requests in process_res_queue.items():
        requests.sort(key=lambda x: x[0])
        requests = [x[1] for x in requests]

        task = task_dict[task_name]
        doc = docs[(task_name, doc_id)]

        # be backward compatible with tasks that do not allow description_dict in process_results
        if "description" in inspect.getfullargspec(task.process_results).args:
            metrics = task.process_results(doc, requests, task_to_description[task_name])
        else:
            metrics = task.process_results(doc, requests)
        for metric, value in metrics.items():
            vals[(task_name, metric)].append(value)

            # Re-use the evaluation for the decontaminated set by just ignoring the overlaps
            if decontaminate and task_name in overlaps:
                if doc_id not in overlaps[task_name]:
                    vals[(task_name, metric + decontaminate_suffix)].append(value)

    # aggregate results
    for (task_name, metric), items in vals.items():
        task = task_dict[task_name]
        real_metric = metric  # key when looking up the metric with task.aggregation
        if metric.endswith(decontaminate_suffix):
            real_metric = metric.replace(
                decontaminate_suffix, ""
            )  # decontaminated still uses the same metric
        results[task_name][metric] = task.aggregation()[real_metric](items)

        # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
        # so we run them for fewer iterations. still looking for a cleaner way to do this

        stderr = lm_eval.metrics.stderr_for_metric(
            metric=task.aggregation()[real_metric],
            bootstrap_iters=min(bootstrap_iters, 1000)
            if metric in ["bleu", "chrf", "ter"]
            else bootstrap_iters,
        )

        if stderr is not None:
            results[task_name][metric + "_stderr"] = stderr(items)

    return {"results": dict(results), "versions": dict(versions)}
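# A minimal sketch of calling evaluate() directly with a pre-built LM and task dict,
# mirroring what simple_evaluate() does above; "gpt2" and "hellaswag" are illustrative
# names assumed to be registered with the installed harness.
#
#   lm = lm_eval.models.get_model("gpt2").create_from_arg_string(
#       "", {"batch_size": None, "device": None}
#   )
#   task_dict = lm_eval.tasks.get_task_dict(["hellaswag"])
#   results = evaluate(lm=lm, task_dict=task_dict, num_fewshot=0, limit=10)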


def make_table(result_dict):
    """Generate table of results."""
    from pytablewriter import MarkdownTableWriter, LatexTableWriter

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
    latex_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]

    values = []

    for k, dic in result_dict["results"].items():
        version = result_dict["versions"][k]
        for m, v in dic.items():
            if m.endswith("_stderr"):
                continue

            if m + "_stderr" in dic:
                se = dic[m + "_stderr"]
                values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se])
            else:
                values.append([k, version, m, "%.4f" % v, "", ""])
            k = ""
            version = ""
    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # todo: make latex table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()
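# Example (a minimal sketch): make_table() only needs the "results" and "versions"
# keys that evaluate()/simple_evaluate() return; the numbers below are made up purely
# to show the expected shape.
#
#   print(make_table({
#       "results": {"some_task": {"acc": 0.5, "acc_stderr": 0.01}},
#       "versions": {"some_task": 0},
#   }))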