evaluator.py
import collections
import itertools
import json
import random
import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
import numpy as np


def simple_evaluate(model, model_args, task_names,
                    num_fewshot=0, batch_size=None, device=None,
                    no_cache=False, limit=None, bootstrap_iters=100000,
                    description_dict_path=None):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: str
        Name of model, see lm_eval.models.get_model
    :param model_args: str
        String arguments for each model class, see LM.create_from_arg_string
    :param task_names: list[str]
        List of task names
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int, optional
        Batch size for model
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param no_cache: bool
        If True, skip caching of model responses to disk
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict_path:
        Path to a JSON file containing `task_name: description` key-values for custom prompts
    :return:
        Dictionary of results
    """
    random.seed(1234)
    np.random.seed(1234)

    lm = lm_eval.models.get_model(model).create_from_arg_string(model_args, {
        'batch_size': batch_size, 'device': device
    })

    if not no_cache:
        lm = lm_eval.base.CachingLM(
            lm, 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db'
        )
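        # e.g. with model="gpt2" and model_args="pretrained=gpt2" (illustrative values,
        # not taken from the original), the cache path becomes "lm_cache/gpt2_pretrained-gpt2.db"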
    
    task_dict = lm_eval.tasks.get_task_dict(task_names)

    description_dict = {}
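    # A description file, if provided, maps task names to custom prompt descriptions.
    # An illustrative sketch of its contents (the task names and prompts are assumptions):
    #
    #     {
    #         "lambada": "Please complete the following passage.",
    #         "hellaswag": "Choose the most plausible continuation."
    #     }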
    if description_dict_path:
        with open(description_dict_path, 'r') as f:
            description_dict = json.load(f)

    results = evaluate(lm, task_dict, False, num_fewshot, limit, bootstrap_iters=bootstrap_iters, description_dict=description_dict)

    # add info about the model and few-shot config
    results["config"] = {
        "model": model,
        "model_args": model_args,
        "num_fewshot": num_fewshot,
        "batch_size": batch_size,
        "device": device,
        "no_cache": no_cache,
        "limit": limit,
        "bootstrap_iters": bootstrap_iters
    }

    return results
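
# Illustrative usage sketch (not part of the original module; the model, model_args and
# task names below are assumptions):
#
#     results = simple_evaluate(
#         model="gpt2",
#         model_args="pretrained=gpt2",
#         task_names=["lambada", "hellaswag"],
#         num_fewshot=0,
#         batch_size=8,
#         device="cuda:0",
#     )
#     print(make_table(results))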


def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_iters=100000, description_dict=None):
    """Instantiate and evaluate a model on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks
    :param provide_description: bool
        Not implemented. This option is deprecated and will be removed in a future version in favor of a different description-providing method.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict:
        Dictionary of task descriptions of the form: `task_name: description` 
    :return:
        Dictionary of results
    """
    # TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces

    # TODO: implement a proper description-providing system
    assert not provide_description  # not implemented.

    task_dict_items = [
        (name, task)
        for name, task in task_dict.items()
        if task.has_validation_docs() or task.has_test_docs()
    ]

    results = collections.defaultdict(dict)
    versions = collections.defaultdict(dict)

    requests = collections.defaultdict(list)
    requests_origin = collections.defaultdict(list)

    # If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger
    # memory, we can always modify this plumbing to support that, but I didn't want to include it just yet because
    # over-engineering is bad (or we could make it write the requests to disk and then read them back out again -
    # probably using an sqlite db, because of all the moving parts we have).

    # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
    docs = {}

    # get lists of each type of request
    for task_name, task in task_dict_items:
        versions[task_name] = task.VERSION
        # default to test docs; fall back to validation docs if test docs are unavailable
        # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
        if task.has_test_docs():
            task_doc_func = task.test_docs
        elif task.has_validation_docs():
            task_doc_func = task.validation_docs
        else:
            raise RuntimeError("Task has neither test_docs nor validation_docs")

        # deterministically shuffle the docs and keep only the first `limit`, because sometimes docs come in some kind of order
        task_docs = list(task_doc_func())
        rnd = random.Random()
        rnd.seed(42)
        rnd.shuffle(task_docs)

        description = description_dict[task_name] if description_dict and task_name in description_dict else ""

        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
            docs[(task_name, doc_id)] = doc
            ctx = task.fewshot_context(
                doc=doc,
                num_fewshot=num_fewshot,
                provide_description=provide_description,
                rnd=rnd,
                description=description
            )
            reqs = task.construct_requests(doc, ctx)
            if not isinstance(reqs, (list, tuple)):
                reqs = [reqs]
            for i, req in enumerate(reqs):
                requests[req.request_type].append(req)
                # i: index in requests for a single task instance
                # doc_id: unique id that we can get back to a doc using `docs`
                requests_origin[req.request_type].append((i, task_name, doc, doc_id))

    # all responses for each (task, doc)
    process_res_queue = collections.defaultdict(list)

    # execute each type of request
    for reqtype, reqs in requests.items():
        # TODO: right now, this code runs multiple separate LM requests for multiple Requests differing
        #       only in index. We could implement some kind of caching, but that would be more of a band-aid
        #       solution. We could also implement some kind of auto-grouping here;
        #       requests differing only in index should end up next to each other anyway.

        print("Running", reqtype, "requests")
        # run all requests of this type through the LM in one call; when a request
        # carries an index, keep only that element of its returned response
        resps = getattr(lm, reqtype)([req.args for req in reqs])
        resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)]

        for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
            process_res_queue[(task_name, doc_id)].append((i, resp))
    
    vals = collections.defaultdict(list)

    # unpack the results, sort them back into per-doc order, and hand them back to each Task
    for (task_name, doc_id), requests in process_res_queue.items():
        requests.sort(key=lambda x: x[0])
        requests = [x[1] for x in requests]

        task = task_dict[task_name]
        doc = docs[(task_name, doc_id)]

        metrics = task.process_results(doc, requests)
        for metric, value in metrics.items():
            vals[(task_name, metric)].append(value)
    
    # aggregate results
    for (task_name, metric), items in vals.items():
        task = task_dict[task_name]
        results[task_name][metric] = task.aggregation()[metric](items)

        # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap,
        # so we run them for fewer iterations. Still looking for a cleaner way to do this.
        stderr = lm_eval.metrics.stderr_for_metric(
            metric=task.aggregation()[metric],
            bootstrap_iters=min(bootstrap_iters, 1000) if metric in ["bleu", "chrf", "ter"] else bootstrap_iters,
        )
        if stderr is not None:
            results[task_name][metric + "_stderr"] = stderr(items)
    
    return {
        "results": dict(results),
        "versions": dict(versions)
    }
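
# Shape of the dictionary returned by `evaluate` (illustrative sketch; the task name,
# version, and metric names are assumptions and depend on which tasks are run):
#
#     {
#         "results": {"lambada": {"ppl": 3.1, "ppl_stderr": 0.05}},
#         "versions": {"lambada": 0},
#     }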


def make_table(result_dict):
    """Generate table of results."""
    from pytablewriter import MarkdownTableWriter, LatexTableWriter

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
    latex_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]

    values = []

    for k, dic in result_dict["results"].items():
        version = result_dict["versions"][k]
        for m, v in dic.items():
            if m.endswith("_stderr"):
                continue

            if m + "_stderr" in dic:
                se = dic[m + "_stderr"]
                values.append([k, version, m, '%.4f' % v, '±', '%.4f' % se])
            else:
                values.append([k, version, m, '%.4f' % v, '', ''])
            k = ""
            version = ""
    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # TODO: make the LaTeX table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()
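
# Roughly what `make_table` returns as markdown (an illustrative sketch, not output from
# a real run; the task, version, metric, and values are assumptions):
#
#     |  Task   |Version|Metric|Value |   |Stderr|
#     |---------|------:|------|-----:|---|-----:|
#     |lambada  |      0|ppl   |3.1000|±  |0.0500|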