import collections
import itertools
import pathlib
import random

import numpy as np

import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
import lm_eval.decontamination
from lm_eval.utils import positional_deprecated, run_task_tests
from lm_eval.decontamination.decontaminate import get_train_overlap

@positional_deprecated
def simple_evaluate(model, model_args=None, tasks=[],
                    num_fewshot=0, batch_size=None, device=None,
                    no_cache=False, limit=None, bootstrap_iters=100000,
                    description_dict=None, check_integrity=False,
                    decontamination_ngrams_path=None):

    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str]
        String arguments for each model class, see LM.create_from_arg_string.
        Ignored if `model` argument is an LM object.
    :param tasks: list[Union[str, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int, optional
        Batch size for model
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param no_cache: bool
        If True, do not wrap the model in a CachingLM, i.e. do not cache model responses to disk
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters: int
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :param decontamination_ngrams_path: str, optional
        If set, train/test overlap is computed from the n-grams at this path and a decontaminated variant of each metric is also reported
    :return
        Dictionary of results
    """
    random.seed(1234)
    np.random.seed(1234)

    assert tasks != [], "No tasks specified"

    if isinstance(model, str):
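        # `model` is a registry name: look up its class in lm_eval.models and build it from the
        # CLI-style `model_args` string plus the explicit batch_size/device keyword overrides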
        if model_args is None: model_args = ""
        lm = lm_eval.models.get_model(model).create_from_arg_string(model_args, {
            'batch_size': batch_size, 'device': device
        })
    else:
        assert isinstance(model, lm_eval.base.LM)
        lm = model

    if not no_cache:
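        # wrap the model in a CachingLM so completed requests are persisted to a per-model
        # database file under lm_cache/ and can be reused by later runs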
        lm = lm_eval.base.CachingLM(
            lm, 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db'
        )
    
    task_dict = lm_eval.tasks.get_task_dict(tasks)

    if check_integrity:
        run_task_tests(task_list=tasks)

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        num_fewshot=num_fewshot,
        limit=limit,
        description_dict=description_dict,
        decontamination_ngrams_path=decontamination_ngrams_path,
    )

    # add info about the model and few-shot config
    results["config"] = {
        "model": model,
        "model_args": model_args,
        "num_fewshot": num_fewshot,
        "batch_size": batch_size,
        "device": device,
        "no_cache": no_cache,
        "limit": limit,
        "bootstrap_iters": bootstrap_iters,
        "description_dict": description_dict
    }

    return results
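

# Example usage (illustrative sketch, not part of the original module; the model and task names
# below are assumptions, so use whatever is registered in lm_eval.models and lm_eval.tasks):
#
#   results = simple_evaluate(
#       model="gpt2",
#       tasks=["lambada", "hellaswag"],
#       num_fewshot=0,
#   )
#   print(make_table(results))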

decontaminate_suffix = "_decontaminate"

@positional_deprecated
def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
             bootstrap_iters=100000, description_dict=None,
             decontamination_ngrams_path=None):
    """Instantiate and evaluate a model on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
Leo Gao's avatar
Leo Gao committed
107
        Dictionary of tasks. Tasks will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
108
    :param provide_description: bool
Leo Gao's avatar
Leo Gao committed
109
        Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
110
111
112
113
114
115
    :param num_fewshot: int
        Number of examples in few-shot context
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
Jonathan Tow's avatar
Jonathan Tow committed
116
    :param description_dict: dict[str, str]
117
        Dictionary of custom task descriptions of the form: `task_name: description` 
118
119
120
    :return
        Dictionary of results
    """
    # TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces

    # TODO: implement proper description-providing system
    assert not provide_description  # not implemented.
    if provide_description is not None:
        # nudge people to not specify it at all
        print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")

    decontaminate = decontamination_ngrams_path is not None

    task_dict_items = [
        (name, task)
        for name, task in task_dict.items()
        if (task.has_validation_docs() or task.has_test_docs())
    ]

    results = collections.defaultdict(dict)
    versions = collections.defaultdict(dict)

    # requests[request_type]: flat list of every Request of that type, across all tasks and docs
    requests = collections.defaultdict(list)
    # requests_origin[request_type]: parallel bookkeeping of (i, task_name, doc, doc_id) so responses can be routed back
    requests_origin = collections.defaultdict(list)

    overlaps = collections.defaultdict(list) # {task_name: contaminated_docs}

    # If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with more
    # memory, we can always modify this plumbing to support that, but it isn't included yet to avoid
    # over-engineering. (Alternatively, we could write the requests to disk and read them back out again,
    # probably using an SQLite db, given all the moving parts we have.)

    # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
    docs = {}

    # collects each doc's decontamination query text, keyed by (task_name, task_set)
    docs_for_decontamination = collections.defaultdict(list)

    # get lists of each type of request
    for task_name, task in task_dict_items:
        versions[task_name] = task.VERSION
        # default to test docs; fall back to validation docs if test docs are unavailable
        # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
        if task.has_test_docs():
            task_doc_func = task.test_docs
            task_set = "test" # Required for caching in the decontamination
        elif task.has_validation_docs():
            task_set = "val" # Required for caching in the decontamination
            task_doc_func = task.validation_docs
        else:
            raise RuntimeError("Task has neither test_docs nor validation_docs")

        # deterministically shuffle the docs and keep only the first `limit` of them, since docs are sometimes stored in a meaningful order
        task_docs = list(task_doc_func())
        rnd = random.Random()
        rnd.seed(42)
        rnd.shuffle(task_docs)

        description = description_dict[task_name] if description_dict and task_name in description_dict else ""

        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):

            if decontaminate and task.should_decontaminate():
                docs_for_decontamination[(task_name, task_set)].append(task.doc_to_decontamination_query(doc))

            docs[(task_name, doc_id)] = doc
            ctx = task.fewshot_context(
                doc=doc,
                num_fewshot=num_fewshot,
                rnd=rnd,
                description=description
            )
            reqs = task.construct_requests(doc, ctx)
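            # construct_requests may return a single Request or an iterable of them;
            # normalize to a list so the bookkeeping below is uniform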
            if not isinstance(reqs, (list, tuple)):
                reqs = [reqs]
            for i, req in enumerate(reqs):
                requests[req.request_type].append(req)
                # i: index in requests for a single task instance
                # doc_id: unique id that lets us map back to the doc via `docs`
                requests_origin[req.request_type].append((i, task_name, doc, doc_id))

    # Compare all tasks/sets at once to ensure a single training set scan
    if decontaminate:
        print("Finding train/test overlap, please wait...")
        overlaps = get_train_overlap(docs_for_decontamination, decontamination_ngrams_path, limit)

    # all responses for each (task, doc)
    process_res_queue = collections.defaultdict(list)

    # execute each type of request
    for reqtype, reqs in requests.items():
        # TODO: right now, this code runs multiple separate LM requests for Requests that differ
        #       only in index. We could implement some kind of caching, but that would be more of a band-aid
        #       solution. We could also implement some kind of auto-grouping here; requests that differ
        #       only in index should already end up next to each other.

        print("Running", reqtype, "requests")
        # dispatch all argument tuples for this request type to the matching LM method in one call
        resps = getattr(lm, reqtype)([req.args for req in reqs])
        # when a Request carries an index, keep only that element of the returned response
        resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)]

        for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
            process_res_queue[(task_name, doc_id)].append((i, resp))
    
    vals = collections.defaultdict(list)

    # unpack results, sort them back into order, and hand control back to each Task to compute metrics
    for (task_name, doc_id), requests in process_res_queue.items():
        requests.sort(key=lambda x: x[0])
        requests = [x[1] for x in requests]

        task = task_dict[task_name]
        doc = docs[(task_name, doc_id)]

        metrics = task.process_results(doc, requests)
        for metric, value in metrics.items():
            vals[(task_name, metric)].append(value)

            # Re-use the evaluation for the decontaminated set by just ignoring the overlaps
            if decontaminate and task_name in overlaps:
                if doc_id not in overlaps[task_name]:
                    vals[(task_name, metric + decontaminate_suffix)].append(value)
    
    # aggregate results
    for (task_name, metric), items in vals.items():
        task = task_dict[task_name]
        real_metric = metric # key when looking up the metric with task.aggregation
        if metric.endswith(decontaminate_suffix):
            real_metric = metric.replace(decontaminate_suffix, "") # decontaminated still uses the same metric
        results[task_name][metric] = task.aggregation()[real_metric](items)

        # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap,
        # so we run them for fewer iterations. still looking for a cleaner way to do this

        stderr = lm_eval.metrics.stderr_for_metric(
            metric=task.aggregation()[real_metric],
            bootstrap_iters=min(bootstrap_iters, 1000) if metric in ["bleu", "chrf", "ter"] else bootstrap_iters,
        )
        
        if stderr is not None:
            results[task_name][metric + "_stderr"] = stderr(items)
    
    return {
        "results": dict(results),
        "versions": dict(versions)
    }


def make_table(result_dict):
    """Generate table of results."""
    from pytablewriter import MarkdownTableWriter, LatexTableWriter

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
    latex_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]

    values = []

    for k, dic in result_dict["results"].items():
        version = result_dict["versions"][k]
        for m, v in dic.items():
            if m.endswith("_stderr"):
                continue

            if m + "_stderr" in dic:
                se = dic[m + "_stderr"]
                values.append([k, version, m, '%.4f' % v, '±', '%.4f' % se])
            else:
                values.append([k, version, m, '%.4f' % v, '', ''])
            k = ""
            version = ""
    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # todo: make latex table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()