import collections
import itertools
import random
import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
from scripts.clean_training_data.contamination import get_train_overlap
import numpy as np

def simple_evaluate(model, model_args, task_names, num_fewshot=0, batch_size=None, device=None, 
                    no_cache=False, limit=None, bootstrap_iters=100000, decontaminate=False, 
                    ngrams_path=None, ngrams_n_size=None):
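    """Instantiate an LM by name and evaluate it on the given tasks.

    :param model: name of a model type registered in lm_eval.models.MODEL_REGISTRY
    :param model_args: comma-separated "key=value" string forwarded to the model constructor
    :param task_names: list of task names understood by lm_eval.tasks.get_task_dict
    :param num_fewshot: number of few-shot examples included in each context
    :param batch_size: batch size passed through to the model
    :param device: device passed through to the model
    :param no_cache: if True, do not wrap the model in CachingLM (no on-disk request cache)
    :param limit: if not None, evaluate only the first `limit` (shuffled) documents per task
    :param bootstrap_iters: number of bootstrap resamples used for stderr estimation
    :param decontaminate: if True, additionally report metrics over documents with no train-set overlap
    :param ngrams_path: path to the precomputed training-set n-grams (required when decontaminating)
    :param ngrams_n_size: n-gram size used when the overlap index was built
    :return: dict with "results", "versions" and the run "config"
    """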
    random.seed(1234)
    np.random.seed(1234)

    lm = lm_eval.models.MODEL_REGISTRY[model].create_from_arg_string(model_args, {
        'batch_size': batch_size, 'device': device
    })

    if not no_cache:
        lm = lm_eval.base.CachingLM(lm, 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db')
    
    task_dict = lm_eval.tasks.get_task_dict(task_names)
    results = evaluate(lm, task_dict, False, num_fewshot, limit, bootstrap_iters=bootstrap_iters, 
                       decontaminate=decontaminate, ngrams_path=ngrams_path, ngrams_n_size=ngrams_n_size)

    # add info about the model and few shot config
    results["config"] = {
        "model": model,
        "model_args": model_args,
        "num_fewshot": num_fewshot,
        "batch_size": batch_size,
        "device": device,
        "no_cache": no_cache,
        "limit": limit,
        "bootstrap_iters": bootstrap_iters
    }

    return results

decontaminate_suffix = "_decontaminate"

def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_iters=100000,
             decontaminate=False, ngrams_path=None, ngrams_n_size=None):
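    """Evaluate an instantiated LM on every task in `task_dict` and aggregate the metrics.

    :param lm: the (possibly CachingLM-wrapped) language model implementing the request-type methods
    :param task_dict: dict mapping task name -> Task object
    :param provide_description: must be False; a proper description-providing system is not implemented
    :param num_fewshot: number of few-shot examples placed in each context
    :param limit: if not None, evaluate only the first `limit` (shuffled) documents per task
    :param bootstrap_iters: number of bootstrap resamples used for stderr estimation
    :param decontaminate: if True, additionally report each metric on documents with no train-set overlap
    :param ngrams_path: path to the precomputed training-set n-grams
    :param ngrams_n_size: n-gram size used when the overlap index was built
    :return: dict with "results" (per-task metric values) and "versions" (per-task version numbers)
    """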
    assert not provide_description # not implemented. todo: implement proper description-providing system

    if decontaminate:
        assert ngrams_path and ngrams_n_size

    # TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces

    task_dict_items = [(name, task) for name, task in task_dict.items() if(task.has_validation_docs() or task.has_test_docs())]

    results = collections.defaultdict(dict)
    versions = collections.defaultdict(dict)

    requests = collections.defaultdict(list)
    requests_origin = collections.defaultdict(list)
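    # `requests` groups Request objects by request type so each type can be sent to the LM
    # in a single batch; `requests_origin` records (i, task_name, doc, doc_id) in the same
    # order so responses can be routed back to the task instance that produced them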

    overlaps = collections.defaultdict(list) # {task_name: contaminated_docs}

    # if we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger memory,
    # we can always modify this plumbing to support that, but I didn't want to include it just yet because overengineering is bad
    # (or we could make it write the requests to disk and then read them back out again - probably using an sqlite db because of all the moving parts we have)

    # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable

    docs = {}

    docs_for_decontamination = collections.defaultdict(list)

    # get lists of each type of request
    for task_name, task in task_dict_items:
        versions[task_name] = task.VERSION
        # default to test docs, fall back to val docs if test unavailable
        # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
        if task.has_test_docs():
            task_doc_func = task.test_docs
            task_set = "test" # Required for caching in the decontamination
        elif task.has_validation_docs():
            task_set = "val" # Required for caching in the decontamination
            task_doc_func = task.validation_docs

        # deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
        task_docs = list(task_doc_func())
        rnd = random.Random()
        rnd.seed(42)
        rnd.shuffle(task_docs)

        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):

            if decontaminate and task.should_decontaminate():
                docs_for_decontamination[(task_name, task_set)].append(task.doc_to_decontamination_query(doc))

            docs[(task_name, doc_id)] = doc

            ctx = task.fewshot_context(
                doc=doc,
                provide_description=provide_description,
                num_fewshot=num_fewshot,
                rnd=rnd
            )

            reqs = task.construct_requests(doc, ctx)
            if not isinstance(reqs, (list, tuple)): reqs = [reqs] 
            for i, req in enumerate(reqs):
                requests[req.type].append(req)
                # i: index in requests for a single task instance
                # doc_id: unique id that we can get back to a doc using `docs`
                requests_origin[req.type].append((i, task_name, doc, doc_id))

    # Compare all tasks/sets at once to ensure a single training set scan
    if decontaminate:
        print("Finding train/test overlap, please wait...")
        overlaps = get_train_overlap(docs_for_decontamination, ngrams_path, ngrams_n_size, limit)

    # all responses for each (task, doc)
    process_res_queue = collections.defaultdict(list)

    # execute each type of request
    for reqtype, reqs in requests.items():
        # TODO: right now, this code runs multiple separate LM requests for multiple Requests differing
        # only in index. We could implement some kind of caching, but that would be more of a bandaid
        # solution. we could also implement some kind of autogrouping here; they should end up next to each other.

        print("Running", reqtype, "requests")
        resps = getattr(lm, reqtype)([req.args for req in reqs])

        resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)]

        for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
            process_res_queue[(task_name, doc_id)].append((i, resp))
    
    vals = collections.defaultdict(list)

    # unpack results and sort back in order and return control to Task
    for (task_name, doc_id), requests in process_res_queue.items():
        requests.sort(key=lambda x: x[0])
        requests = [x[1] for x in requests]

        task = task_dict[task_name]
        doc = docs[(task_name, doc_id)]

        metrics = task.process_results(doc, requests)
        for metric, value in metrics.items():
            vals[(task_name, metric)].append(value)

            # Re-use the evaluation for the decontaminated set by just ignoring the overlaps
            if decontaminate and task_name in overlaps:
                if doc_id not in overlaps[task_name]:
                    vals[(task_name, metric + decontaminate_suffix)].append(value)
    
    # aggregate results
    for (task_name, metric), items in vals.items():
        task = task_dict[task_name]
        real_metric = metric # key when looking up the metric with task.aggregation
        if metric.endswith(decontaminate_suffix):
            real_metric = metric.replace(decontaminate_suffix, "") # decontaminated still uses the same metric
        results[task_name][metric] = task.aggregation()[real_metric](items)

        # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap,
        # so we run them with fewer iterations. still looking for a cleaner way to do this
        stderr = lm_eval.metrics.stderr_for_metric(task.aggregation()[real_metric], bootstrap_iters=min(bootstrap_iters, 1000) if metric in ["bleu", "chrf", "ter"] else bootstrap_iters)
        if stderr is not None:
            results[task_name][metric + "_stderr"] = stderr(items)
    
    return {
        "results": dict(results),
        "versions": dict(versions)
    }


def make_table(result_dict):
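    """Render the "results" and "versions" of an evaluation dict as a Markdown table string."""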
    from pytablewriter import MarkdownTableWriter, LatexTableWriter

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
    latex_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]

    values = []

    for k, dic in result_dict["results"].items():
        version = result_dict["versions"][k]
        for m, v in dic.items():
            if m.endswith("_stderr"): continue

            if m + "_stderr" in dic:
                se = dic[m + "_stderr"]

                values.append([k, version, m, '%.4f' % v, '±', '%.4f' % se])
            else:
                values.append([k, version, m, '%.4f' % v, '', ''])
            k = ""
            version = ""
    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # todo: make latex table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()
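

# Example usage sketch: the model name, model_args string and task list below are
# illustrative assumptions, not fixed choices; any model registered in
# lm_eval.models.MODEL_REGISTRY and any tasks known to lm_eval.tasks work the same way.
if __name__ == "__main__":
    example_results = simple_evaluate(
        model="gpt2",                   # assumed registry key, for illustration only
        model_args="pretrained=gpt2",   # parsed by create_from_arg_string into constructor kwargs
        task_names=["lambada"],         # assumed task name, for illustration only
        num_fewshot=0,
        batch_size=1,
        device="cpu",
        limit=10,                       # evaluate only the first 10 shuffled docs per task
    )
    print(make_table(example_results))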