import collections
import itertools
import pathlib
import random
import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
import numpy as np
from lm_eval.utils import positional_deprecated, run_task_tests


@positional_deprecated
def simple_evaluate(model, model_args=None, tasks=[],
                    num_fewshot=0, batch_size=None, device=None,
                    no_cache=False, limit=None, bootstrap_iters=100000,
                    description_dict=None, check_integrity=False):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str]
        String arguments for each model class, see LM.create_from_arg_string. 
        Ignored if `model` argument is a LM object.
    :param tasks: list[Union[str, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int, optional
        Batch size for model
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param no_cache: bool
        Whether or not to disable caching of model responses
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description` 
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :return
        Dictionary of results
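
    Example (illustrative sketch; the model and task names are assumptions and must
    be registered in lm_eval.models / lm_eval.tasks for this to run)::

        results = simple_evaluate(
            model="gpt2",
            tasks=["lambada", "hellaswag"],
            num_fewshot=0,
            batch_size=8,
        )
        print(make_table(results))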
    """
    random.seed(1234)
    np.random.seed(1234)

    assert tasks != [], "No tasks specified"

    if isinstance(model, str):
        if model_args is None:
            model_args = ""
        lm = lm_eval.models.get_model(model).create_from_arg_string(model_args, {
            'batch_size': batch_size, 'device': device
        })
    else:
        assert isinstance(model, lm_eval.base.LM)
        lm = model

    if not no_cache:
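        # NOTE: the cache filename below is built from `model` and `model_args`, so this
        # branch assumes both are strings; pass no_cache=True when supplying an LM object.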
        lm = lm_eval.base.CachingLM(
            lm, 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db'
        )
    
    task_dict = lm_eval.tasks.get_task_dict(tasks)

    if check_integrity:
        run_task_tests(task_list=tasks)

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        num_fewshot=num_fewshot,
        limit=limit,
        description_dict=description_dict
    )

    # add info about the model and few shot config
    results["config"] = {
        "model": model,
        "model_args": model_args,
        "num_fewshot": num_fewshot,
        "batch_size": batch_size,
        "device": device,
        "no_cache": no_cache,
        "limit": limit,
        "bootstrap_iters": bootstrap_iters,
        "description_dict": description_dict
    }

    return results


@positional_deprecated
def evaluate(lm, task_dict, provide_description=None, num_fewshot=0,
             limit=None, bootstrap_iters=100000, description_dict=None):
    """Instantiate and evaluate a model on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param provide_description: bool
        Not implemented. This option is deprecated and will be removed in a future version in favor of a different method of providing task descriptions
    :param num_fewshot: int
        Number of examples in few-shot context
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description` 
    :return
        Dictionary of results
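
    Example (illustrative sketch; assumes "gpt2" and "lambada" are registered in
    lm_eval.models / lm_eval.tasks, and mirrors how `simple_evaluate` builds its LM)::

        lm = lm_eval.models.get_model("gpt2").create_from_arg_string(
            "", {"batch_size": None, "device": None}
        )
        task_dict = lm_eval.tasks.get_task_dict(["lambada"])
        results = evaluate(lm=lm, task_dict=task_dict, num_fewshot=0, limit=10)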
    """
    # TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces

    # TODO: implement proper description-providing system
    assert not provide_description  # not implemented.
    if provide_description is not None:
        # nudge people to not specify it at all
        print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")

    task_dict_items = [
        (name, task)
        for name, task in task_dict.items()
        if task.has_validation_docs() or task.has_test_docs()
    ]

    results = collections.defaultdict(dict)
    versions = collections.defaultdict(dict)

    requests = collections.defaultdict(list)
    requests_origin = collections.defaultdict(list)

    # If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger
    # memory, we can always modify this plumbing to support that, but I didn't want to include it just yet because
    # over-engineering is bad (or we could make it write the requests to disk and then read them back out again,
    # probably using an sqlite db because of all the moving parts we have).

    # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
    docs = {}

    # get lists of each type of request
    for task_name, task in task_dict_items:
        versions[task_name] = task.VERSION
        # default to test docs; fall back to validation docs if test docs are unavailable
        # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
        if task.has_test_docs():
            task_doc_func = task.test_docs
        elif task.has_validation_docs():
            task_doc_func = task.validation_docs
        else:
            raise RuntimeError("Task has neither test_docs nor validation_docs")

        # deterministically shuffle the docs and then take only the first `limit`, because sometimes docs come in a meaningful order
        task_docs = list(task_doc_func())
        rnd = random.Random()
        rnd.seed(42)
        rnd.shuffle(task_docs)

        description = description_dict[task_name] if description_dict and task_name in description_dict else ""

        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
            docs[(task_name, doc_id)] = doc
            ctx = task.fewshot_context(
                doc=doc,
                num_fewshot=num_fewshot,
                rnd=rnd,
                description=description
            )
            reqs = task.construct_requests(doc, ctx)
            if not isinstance(reqs, (list, tuple)):
                reqs = [reqs]
            for i, req in enumerate(reqs):
                requests[req.request_type].append(req)
                # i: index in requests for a single task instance
                # doc_id: unique id that we can get back to a doc using `docs`
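                # e.g. requests_origin["loglikelihood"] might accumulate tuples like
                # (0, "lambada", <doc>, 17) -- one entry per constructed request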
                requests_origin[req.request_type].append((i, task_name, doc, doc_id))

    # all responses for each (task, doc)
    process_res_queue = collections.defaultdict(list)

    # execute each type of request
    for reqtype, reqs in requests.items():
        # TODO: right now, this code runs multiple separate LM requests for multiple Requests differing
        #       only in index. We could implement some kind of caching, but that would be more of a band-aid
        #       solution. We could also implement some kind of auto-grouping here;
        #       they should end up next to each other.

        print("Running", reqtype, "requests")
        resps = getattr(lm, reqtype)([req.args for req in reqs])
        resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)]

        for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
            process_res_queue[(task_name, doc_id)].append((i, resp))
    
    vals = collections.defaultdict(list)

    # unpack results, sort them back into order, and return control to the Task
    for (task_name, doc_id), requests in process_res_queue.items():
        requests.sort(key=lambda x: x[0])
        requests = [x[1] for x in requests]

        task = task_dict[task_name]
        doc = docs[(task_name, doc_id)]

        metrics = task.process_results(doc, requests)
        for metric, value in metrics.items():
            vals[(task_name, metric)].append(value)
    
    # aggregate results
    for (task_name, metric), items in vals.items():
        task = task_dict[task_name]
        results[task_name][metric] = task.aggregation()[metric](items)

        # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap,
        # so we run them for fewer iterations. Still looking for a cleaner way to do this
        stderr = lm_eval.metrics.stderr_for_metric(
            metric=task.aggregation()[metric],
            bootstrap_iters=min(bootstrap_iters, 1000) if metric in ["bleu", "chrf", "ter"] else bootstrap_iters,
        )
        if stderr is not None:
            results[task_name][metric + "_stderr"] = stderr(items)
    
    return {
        "results": dict(results),
        "versions": dict(versions)
    }


def make_table(result_dict):
    """Generate table of results."""
    from pytablewriter import MarkdownTableWriter, LatexTableWriter

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
    latex_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]

    values = []

    for k, dic in result_dict["results"].items():
        version = result_dict["versions"][k]
        for m, v in dic.items():
            if m.endswith("_stderr"):
                continue

            if m + "_stderr" in dic:
                se = dic[m + "_stderr"]
                values.append([k, version, m, '%.4f' % v, '±', '%.4f' % se])
            else:
                values.append([k, version, m, '%.4f' % v, '', ''])
            k = ""
            version = ""
    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # todo: make latex table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()
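

# Example usage (illustrative sketch, kept as a comment so nothing runs on import):
# `make_table` expects the dict shape returned by `evaluate`/`simple_evaluate`; the
# task and metric names below are placeholders rather than real registered tasks.
#
#     dummy = {
#         "results": {"lambada": {"ppl": 3.0, "ppl_stderr": 0.1}},
#         "versions": {"lambada": 0},
#     }
#     print(make_table(dummy))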