import collections
import itertools
import numpy as np
import random
import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
from lm_eval.utils import positional_deprecated, run_task_tests


@positional_deprecated
def simple_evaluate(
    model,
    model_args=None,
    tasks=[],
    num_fewshot=0,
    batch_size=None,
    device=None,
    no_cache=False,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
    check_integrity=False,
    decontamination_ngrams_path=None,
):

    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str]
        String arguments for each model class, see LM.create_from_arg_string.
        Ignored if `model` argument is an LM object.
    :param tasks: list[Union[str, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int, optional
        Batch size for model
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param no_cache: bool
        Whether or not to disable caching of model responses (if False, responses are cached on disk)
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :return
        Dictionary of results
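
    Example (an illustrative sketch rather than a prescribed invocation; the
    model and task names below are assumptions that must exist in
    lm_eval.models and lm_eval.tasks):

        results = simple_evaluate(
            model="gpt2",
            tasks=["lambada", "hellaswag"],
            num_fewshot=0,
            batch_size=8,
        )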
    """
    random.seed(1234)
    np.random.seed(1234)

    assert tasks != [], "No tasks specified"

    if isinstance(model, str):
        if model_args is None:
            model_args = ""
        lm = lm_eval.models.get_model(model).create_from_arg_string(
            model_args, {"batch_size": batch_size, "device": device}
        )
    else:
        assert isinstance(model, lm_eval.base.LM)
        lm = model

    if not no_cache:
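        # wrap the model in a CachingLM so repeated requests are answered from an
        # on-disk cache file keyed by the model name and its arguments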
        lm = lm_eval.base.CachingLM(
            lm,
            "lm_cache/"
            + model
            + "_"
            + model_args.replace("=", "-").replace(",", "_").replace("/", "-")
            + ".db",
        )

    task_dict = lm_eval.tasks.get_task_dict(tasks)

    if check_integrity:
        run_task_tests(task_list=tasks)

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        num_fewshot=num_fewshot,
        limit=limit,
        bootstrap_iters=bootstrap_iters,
        description_dict=description_dict,
        decontamination_ngrams_path=decontamination_ngrams_path,
    )

    # add info about the model and few shot config
    results["config"] = {
        "model": model,
        "model_args": model_args,
        "num_fewshot": num_fewshot,
        "batch_size": batch_size,
        "device": device,
        "no_cache": no_cache,
        "limit": limit,
        "bootstrap_iters": bootstrap_iters,
        "description_dict": description_dict,
    }

    return results


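# Suffix appended to a metric's name for the score recomputed on only the documents
# with no detected overlap against the training set (see `overlaps` in evaluate()).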
decontaminate_suffix = "_decontaminate"


@positional_deprecated
def evaluate(
    lm,
    task_dict,
    provide_description=None,
    num_fewshot=0,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
    decontamination_ngrams_path=None,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param provide_description: bool
        Not implemented. This option is deprecated and will be removed in a future version in favor of a different description-providing method.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :return
        Dictionary of results
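        The returned dictionary is roughly of the form
        {"results": {task_name: {metric: value, ...}}, "versions": {task_name: version}},
        where "<metric>_stderr" entries hold bootstrap standard errors when available.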
    """
    # TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces

    # TODO: implement proper description-providing system
    assert not provide_description  # not implemented.
    if provide_description is not None:
        # nudge people to not specify it at all
        print(
            "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
        )

    decontaminate = decontamination_ngrams_path is not None

    task_dict_items = [
        (name, task)
        for name, task in task_dict.items()
        if (task.has_validation_docs() or task.has_test_docs())
    ]

    results = collections.defaultdict(dict)
    versions = collections.defaultdict(dict)

    requests = collections.defaultdict(list)
    requests_origin = collections.defaultdict(list)

    overlaps = collections.defaultdict(list)  # {task_name: contaminated_docs}

    # If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger
    # memory, we can always modify this plumbing to support that, but I didn't want to include it just yet because
    # over-engineering is bad (or we could make it write the requests to disk and then read them back out again,
    # probably using an sqlite db, because of all the moving parts we have).

    # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
    docs = {}

    docs_for_decontamination = collections.defaultdict(list)

    # get lists of each type of request
    for task_name, task in task_dict_items:
        versions[task_name] = task.VERSION
        # default to test docs; fall back to validation docs if test docs are unavailable
        # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
        if task.has_test_docs():
            task_doc_func = task.test_docs
            task_set = "test"  # Required for caching in the decontamination
        elif task.has_validation_docs():
            task_set = "val"  # Required for caching in the decontamination
            task_doc_func = task.validation_docs
        else:
            raise RuntimeError("Task has neither test_docs nor validation_docs")

        # deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
        task_docs = list(task_doc_func())
        rnd = random.Random()
        rnd.seed(42)
        rnd.shuffle(task_docs)

        description = (
            description_dict[task_name]
            if description_dict and task_name in description_dict
            else ""
        )

        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):

            if decontaminate and task.should_decontaminate():
                docs_for_decontamination[(task_name, task_set)].append(
                    task.doc_to_decontamination_query(doc)
                )

            docs[(task_name, doc_id)] = doc
            ctx = task.fewshot_context(
                doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
            )
            reqs = task.construct_requests(doc, ctx)
            if not isinstance(reqs, (list, tuple)):
                reqs = [reqs]
            for i, req in enumerate(reqs):
                requests[req.request_type].append(req)
                # i: index in requests for a single task instance
                # doc_id: unique id that we can get back to a doc using `docs`
                requests_origin[req.request_type].append((i, task_name, doc, doc_id))

    # Compare all tasks/sets at once to ensure a single training set scan
    if decontaminate:
        from lm_eval.decontamination.decontaminate import get_train_overlap

        print("Finding train/test overlap, please wait...")
        overlaps = get_train_overlap(
            docs_for_decontamination, decontamination_ngrams_path, limit
        )

    # all responses for each (task, doc)
    process_res_queue = collections.defaultdict(list)

    # execute each type of request
    for reqtype, reqs in requests.items():
        # TODO: right now, this code runs multiple separate LM requests for Requests that differ
        #       only in index. We could implement some kind of caching, but that would be more of a
        #       band-aid solution. We could also implement some kind of auto-grouping here;
        #       such requests should end up next to each other anyway.

        print("Running", reqtype, "requests")
        resps = getattr(lm, reqtype)([req.args for req in reqs])
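        # a request's response may be a tuple of values (e.g. loglikelihood requests
        # return both a log-probability and an is-greedy flag); req.index selects the
        # element this particular Request asked for, if any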
        resps = [
            x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
        ]

        for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
            process_res_queue[(task_name, doc_id)].append((i, resp))

    vals = collections.defaultdict(list)

    # unpack results and sort back in order and return control to Task
    for (task_name, doc_id), requests in process_res_queue.items():
        requests.sort(key=lambda x: x[0])
        requests = [x[1] for x in requests]

        task = task_dict[task_name]
        doc = docs[(task_name, doc_id)]

        metrics = task.process_results(doc, requests)
        for metric, value in metrics.items():
            vals[(task_name, metric)].append(value)

            # Re-use the evaluation for the decontaminated set by just ignoring the overlaps
            if decontaminate and task_name in overlaps:
                if doc_id not in overlaps[task_name]:
                    vals[(task_name, metric + decontaminate_suffix)].append(value)

    # aggregate results
    for (task_name, metric), items in vals.items():
        task = task_dict[task_name]
        real_metric = metric  # key when looking up the metric with task.aggregation
        if metric.endswith(decontaminate_suffix):
            real_metric = metric.replace(
                decontaminate_suffix, ""
            )  # decontaminated still uses the same metric
        results[task_name][metric] = task.aggregation()[real_metric](items)

        # hotfix: bleu, chrf, and ter seem to be really expensive to bootstrap,
        # so we run them with fewer iterations. still looking for a cleaner way to do this

        stderr = lm_eval.metrics.stderr_for_metric(
            metric=task.aggregation()[real_metric],
            bootstrap_iters=min(bootstrap_iters, 1000)
            if metric in ["bleu", "chrf", "ter"]
            else bootstrap_iters,
        )

        if stderr is not None:
            results[task_name][metric + "_stderr"] = stderr(items)

    return {"results": dict(results), "versions": dict(versions)}


def make_table(result_dict):
    """Generate table of results."""
    from pytablewriter import MarkdownTableWriter, LatexTableWriter

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
    latex_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]

    values = []

    for k, dic in result_dict["results"].items():
        version = result_dict["versions"][k]
        for m, v in dic.items():
            if m.endswith("_stderr"):
                continue

            if m + "_stderr" in dic:
                se = dic[m + "_stderr"]
                values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se])
            else:
                values.append([k, version, m, "%.4f" % v, "", ""])
            k = ""
            version = ""
    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # todo: make latex table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()
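

# Illustrative usage of make_table (a sketch; `results` is the dictionary
# returned by simple_evaluate or evaluate above):
#
#     print(make_table(results))
#
# One row is emitted per (task, metric) pair, with the columns
# Task | Version | Metric | Value | | Stderr.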