import collections
import itertools
import random
import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
import numpy as np


def simple_evaluate(model, model_args, task_names,
                    num_fewshot=0, batch_size=None, device=None,
                    no_cache=False, limit=None, bootstrap_iters=100000):
    """

    :param model: str
        Name of model, see lm_eval.models.get_model
    :param model_args: str
        String arguments for each model class, see LM.create_from_arg_string
    :param task_names: list[str]
        List of task names
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int, optional
        Batch size for model
    :param device: str, optional

    :param no_cache: bool
        Whether or not
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :return
    """
    random.seed(1234)
    np.random.seed(1234)

    lm = lm_eval.models.get_model(model).create_from_arg_string(model_args, {
        'batch_size': batch_size, 'device': device
    })

    if not no_cache:
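        # the cache filename is derived from the model name and args with '=', ',' and '/'
        # sanitized; e.g. (illustrative) model="gpt2", model_args="pretrained=gpt2"
        # would cache to "lm_cache/gpt2_pretrained-gpt2.db"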
        lm = lm_eval.base.CachingLM(
            lm, 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db'
        )
    
    task_dict = lm_eval.tasks.get_task_dict(task_names)
    results = evaluate(lm, task_dict, False, num_fewshot, limit, bootstrap_iters=bootstrap_iters)

    # add info about the model and few shot config
    results["config"] = {
        "model": model,
        "model_args": model_args,
        "num_fewshot": num_fewshot,
        "batch_size": batch_size,
        "device": device,
        "no_cache": no_cache,
        "limit": limit,
        "bootstrap_iters": bootstrap_iters
    }

    return results
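
# Example usage (illustrative sketch, not executed on import): the model and task
# names below are assumptions and must correspond to entries registered in
# lm_eval.models / lm_eval.tasks.
#
#     results = simple_evaluate(
#         model="gpt2",
#         model_args="pretrained=gpt2",
#         task_names=["lambada", "hellaswag"],
#         num_fewshot=0,
#         batch_size=8,
#         device="cuda:0",
#     )
#     print(make_table(results))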


def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_iters=100000):
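    """Evaluate an instantiated language model on a dictionary of tasks.

    :param lm: obj
        Language model instance (see lm_eval.models / lm_eval.base.CachingLM)
    :param task_dict: dict[str, Task]
        Dictionary of task names mapped to task objects (see lm_eval.tasks.get_task_dict)
    :param provide_description: bool
        Not implemented; must be False
    :param num_fewshot: int
        Number of examples in few-shot context
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters: int
        Number of iterations for bootstrap statistics
    :return: dict
        Dictionary with "results" and "versions" keys
    """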
    # TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces

    # TODO: implement proper description-providing system
    assert not provide_description  # not implemented.

    task_dict_items = [
        (name, task)
        for name, task in task_dict.items()
        if (task.has_validation_docs() or task.has_test_docs())
    ]

    results = collections.defaultdict(dict)
    versions = collections.defaultdict(dict)

    requests = collections.defaultdict(list)
    requests_origin = collections.defaultdict(list)

    # If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger
    # memory, we can always modify this plumbing to support that, but I didn't want to include it just yet because
    # over-engineering is bad (or we could make it write the requests to disk and then read them back out again -
    # probably using an sqlite db because of all the moving parts we have).

    # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
    docs = {}

    # get lists of each type of request
    for task_name, task in task_dict_items:
        versions[task_name] = task.VERSION
        # default to test doc, fall back to val doc if test unavailable
        # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
        if task.has_test_docs():
            task_doc_func = task.test_docs
        elif task.has_validation_docs():
            task_doc_func = task.validation_docs
        else:
            raise RuntimeError("Task has neither test_docs nor validation_docs")

        # deterministically shuffle docs, then keep only the first `limit`, because sometimes docs are in some kind of order
        task_docs = list(task_doc_func())
        rnd = random.Random()
        rnd.seed(42)
        rnd.shuffle(task_docs)

        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
            docs[(task_name, doc_id)] = doc

            ctx = task.fewshot_context(
                doc=doc,
                provide_description=provide_description,
                num_fewshot=num_fewshot,
                rnd=rnd
            )

            reqs = task.construct_requests(doc, ctx)
            if not isinstance(reqs, (list, tuple)):
                reqs = [reqs]
            for i, req in enumerate(reqs):
                requests[req.req_type].append(req)
                # i: index in requests for a single task instance
                # doc_id: unique id that we can get back to a doc using `docs`
                requests_origin[req.req_type].append((i, task_name, doc, doc_id))
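
    # At this point `requests` maps each request type to a flat list of Request objects
    # collected across all tasks and docs, and `requests_origin` holds the matching
    # (i, task_name, doc, doc_id) tuples used to route responses back to their task instance.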

    # all responses for each (task, doc)
    process_res_queue = collections.defaultdict(list)

    # execute each type of request
    for reqtype, reqs in requests.items():
        # TODO: right now, this code runs multiple separate LM requests for multiple Requests differing
        #       only in index. We could implement some kind of caching, but that would be more of a band-aid
        #       solution. We could also implement some kind of auto-grouping here;
        #       they should end up next to each other.

        print("Running", reqtype, "requests")
        resps = getattr(lm, reqtype)([req.args for req in reqs])
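        # some requests only want one element of a multi-part response (req.index is set);
        # select that element here, otherwise pass the full response through unchanged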
        resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)]

        for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
            process_res_queue[(task_name, doc_id)].append((i, resp))
    
    vals = collections.defaultdict(list)

    # unpack results and sort back in order and return control to Task
    for (task_name, doc_id), requests in process_res_queue.items():
        requests.sort(key=lambda x: x[0])
        requests = [x[1] for x in requests]

        task = task_dict[task_name]
        doc = docs[(task_name, doc_id)]

        metrics = task.process_results(doc, requests)
        for metric, value in metrics.items():
            vals[(task_name, metric)].append(value)
    
    # aggregate results
    for (task_name, metric), items in vals.items():
        task = task_dict[task_name]
        results[task_name][metric] = task.aggregation()[metric](items)

        # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
        # so we run them for fewer iterations. still looking for a cleaner way to do this
        stderr = lm_eval.metrics.stderr_for_metric(
            metric=task.aggregation()[metric],
            bootstrap_iters=min(bootstrap_iters, 1000) if metric in ["bleu", "chrf", "ter"] else bootstrap_iters,
        )
        if stderr is not None:
            results[task_name][metric + "_stderr"] = stderr(items)
    
    return {
        "results": dict(results),
        "versions": dict(versions)
    }


def make_table(result_dict):
    """Generate table of results."""
    from pytablewriter import MarkdownTableWriter, LatexTableWriter

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
    latex_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]

    values = []

    for k, dic in result_dict["results"].items():
        version = result_dict["versions"][k]
        for m, v in dic.items():
            if m.endswith("_stderr"):
                continue

            if m + "_stderr" in dic:
                se = dic[m + "_stderr"]
                values.append([k, version, m, '%.4f' % v, '±', '%.4f' % se])
            else:
                values.append([k, version, m, '%.4f' % v, '', ''])
            k = ""
            version = ""
    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # todo: make latex table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()