import collections
import itertools
import os
import random
import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
import numpy as np
from lm_eval.utils import positional_deprecated


@positional_deprecated
def simple_evaluate(model, model_args=None, tasks=[],
                    num_fewshot=0, batch_size=None, device=None,
                    no_cache=False, limit=None, bootstrap_iters=100000,
                    description_dict=None):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str]
        String arguments for each model class, see LM.create_from_arg_string.
        Ignored if `model` argument is a LM object.
    :param tasks: list[Union[str, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int, optional
        Batch size for model
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param no_cache: bool
        Whether or not to cache
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :return
        Dictionary of results
    """
    random.seed(1234)
    np.random.seed(1234)

    assert tasks != [], "No tasks specified"

    if isinstance(model, str):
        if model_args is None: model_args = ""
        lm = lm_eval.models.get_model(model).create_from_arg_string(model_args, {
            'batch_size': batch_size, 'device': device
        })
    else:
        assert isinstance(model, lm_eval.base.LM)
        lm = model

    if not no_cache:
        # note: the cache filename is built from the `model` name string and `model_args`,
        # so callers passing an already-constructed LM object should set no_cache=True
        lm = lm_eval.base.CachingLM(
            lm, 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db'
        )
    
    task_dict = lm_eval.tasks.get_task_dict(tasks)

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        num_fewshot=num_fewshot,
        limit=limit,
        description_dict=description_dict
    )

    # add info about the model and few shot config
    results["config"] = {
        "model": model,
        "model_args": model_args,
        "num_fewshot": num_fewshot,
        "batch_size": batch_size,
        "device": device,
        "no_cache": no_cache,
        "limit": limit,
        "bootstrap_iters": bootstrap_iters,
        "description_dict": description_dict
    }

    return results
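
# For reference, simple_evaluate returns a dictionary shaped roughly like the sketch below
# (illustrative placeholders only; actual task and metric names depend on the tasks run,
# and "<metric>_stderr" entries only appear when a stderr function is available):
#
#   {
#       "results":  {"<task_name>": {"<metric>": <aggregated value>,
#                                    "<metric>_stderr": <bootstrap stderr>}},
#       "versions": {"<task_name>": <task version>},
#       "config":   {"model": ..., "model_args": ..., "num_fewshot": ..., "batch_size": ...,
#                    "device": ..., "no_cache": ..., "limit": ..., "bootstrap_iters": ...,
#                    "description_dict": ...},
#   }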


@positional_deprecated
def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, bootstrap_iters=100000, description_dict=None):
    """Evaluate a model on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param provide_description: bool
        Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
    :param num_fewshot: int
        Number of examples in few-shot context
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :return
        Dictionary of results
    """
    # TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces
    print(f"{'='*20}")
    print(f"Task Module: {lm_eval.base.MultipleChoiceTask.__name__}")
    print(f"{'='*20}")

    # TODO: implement proper description-providing system
    assert not provide_description  # not implemented.
    if provide_description is not None:
        # nudge people to not specify it at all
        print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")

    task_dict_items = [
        (name, task)
        for name, task in task_dict.items()
        if (task.has_validation_docs() or task.has_test_docs())
    ]

    results = collections.defaultdict(dict)
    versions = collections.defaultdict(dict)

    requests = collections.defaultdict(list)
    requests_origin = collections.defaultdict(list)

    # If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger
    # memory, we can always modify this plumbing to support that, but I didn't want to include it just yet because
    # over-engineering is bad (or we could make it write the requests to disk and then read them back out again
    #  - probably using an sqlite db because of all the moving parts we have)

    # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
    docs = {}

    # get lists of each type of request
    for task_name, task in task_dict_items:
        versions[task_name] = task.VERSION
        # default to test docs, fall back to val docs if test docs are unavailable
        # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
        if task.has_test_docs():
            task_doc_func = task.test_docs
        elif task.has_validation_docs():
            task_doc_func = task.validation_docs
        else:
            raise RuntimeError("Task has neither test_docs nor validation_docs")

        # deterministically shuffle docs and take only the first `limit` because sometimes docs are in some kind of order
        task_docs = list(task_doc_func())
        rnd = random.Random()
        rnd.seed(42)
        rnd.shuffle(task_docs)

        description = description_dict[task_name] if description_dict and task_name in description_dict else ""

        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
            docs[(task_name, doc_id)] = doc
            ctx = task.fewshot_context(
                doc=doc,
                num_fewshot=num_fewshot,
                rnd=rnd,
                description=description
            )
            reqs = task.construct_requests(doc, ctx)
            if not isinstance(reqs, (list, tuple)):
                reqs = [reqs]
            for i, req in enumerate(reqs):
                requests[req.request_type].append(req)
                # i: index in requests for a single task instance
                # doc_id: unique id that we can get back to a doc using `docs`
                requests_origin[req.request_type].append((i, task_name, doc, doc_id))

    # all responses for each (task, doc)
    process_res_queue = collections.defaultdict(list)

    # execute each type of request
    for reqtype, reqs in requests.items():
        # TODO: right now, this code runs multiple separate LM requests for multiple Requests differing
        #       only in index. We could implement some kind of caching, but that would be more of a band-aid
        #       solution. we could also implement some kind of auto-grouping here;
        #       they should end up next to each other.

        print("Running", reqtype, "requests")
        resps = getattr(lm, reqtype)([req.args for req in reqs])
        resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)]

        for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
            process_res_queue[(task_name, doc_id)].append((i, resp))
    
    vals = collections.defaultdict(list)

    # unpack results and sort back in order and return control to Task
    for (task_name, doc_id), requests in process_res_queue.items():
        requests.sort(key=lambda x: x[0])
        requests = [x[1] for x in requests]

        task = task_dict[task_name]
        doc = docs[(task_name, doc_id)]

        metrics = task.process_results(doc, requests)
        for metric, value in metrics.items():
            vals[(task_name, metric)].append(value)
    
    # aggregate results
    for (task_name, metric), items in vals.items():
        task = task_dict[task_name]
        results[task_name][metric] = task.aggregation()[metric](items)

        # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
        # so we run them for fewer iterations. still looking for a cleaner way to do this
        stderr = lm_eval.metrics.stderr_for_metric(
            metric=task.aggregation()[metric],
            bootstrap_iters=min(bootstrap_iters, 1000) if metric in ["bleu", "chrf", "ter"] else bootstrap_iters,
        )
        if stderr is not None:
            results[task_name][metric + "_stderr"] = stderr(items)
    
    return {
        "results": dict(results),
        "versions": dict(versions)
    }
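
# Minimal sketch of calling evaluate() directly instead of going through simple_evaluate()
# (illustrative only; "gpt2" and "lambada" are assumed model/task registry names):
#
#   lm = lm_eval.models.get_model("gpt2").create_from_arg_string(
#       "", {"batch_size": 1, "device": "cpu"})
#   task_dict = lm_eval.tasks.get_task_dict(["lambada"])
#   results = evaluate(lm=lm, task_dict=task_dict, num_fewshot=0, limit=10)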


def make_table(result_dict):
    """Generate table of results."""
    from pytablewriter import MarkdownTableWriter, LatexTableWriter

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
    latex_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]

    values = []

    for k, dic in result_dict["results"].items():
        version = result_dict["versions"][k]
        for m, v in dic.items():
            if m.endswith("_stderr"):
                continue

            if m + "_stderr" in dic:
                se = dic[m + "_stderr"]
                values.append([k, version, m, '%.4f' % v, '±', '%.4f' % se])
            else:
                values.append([k, version, m, '%.4f' % v, '', ''])
            k = ""
            version = ""
    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # todo: make latex table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()
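
# Illustrative usage sketch, assuming "gpt2" and "lambada" are registered model/task names;
# substitute whatever is registered in lm_eval.models / lm_eval.tasks. It shows how
# simple_evaluate and make_table fit together for a quick smoke test.
if __name__ == "__main__":
    demo_results = simple_evaluate(
        model="gpt2",        # resolved via lm_eval.models.get_model (assumed name)
        model_args="",       # extra constructor args as a "key=value,key=value" string
        tasks=["lambada"],   # resolved via lm_eval.tasks.get_task_dict (assumed name)
        num_fewshot=0,
        batch_size=1,
        device="cpu",
        no_cache=True,       # skip CachingLM so no lm_cache/*.db file is written
        limit=10,            # evaluate only a few docs; for testing only
    )
    print(make_table(demo_results))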