"lm_eval/_cli/evaluate.py" did not exist on "82517de7b6aa5b465747502f40236da7a2944381"
main.py 4.24 KB
Newer Older
Jason Phang's avatar
Jason Phang committed
1
2
import argparse
import json
Jason Phang's avatar
seed  
Jason Phang committed
3
4
import numpy as np
import random
Stella Biderman's avatar
Stella Biderman committed
5
import itertools
Leo Gao's avatar
Update  
Leo Gao committed
6
import collections
Leo Gao's avatar
Leo Gao committed
7

Jason Phang's avatar
lib  
Jason Phang committed
8
9
from lm_eval import models, tasks

Leo Gao's avatar
Leo Gao committed
10

Jason Phang's avatar
Jason Phang committed
11
12
13
14
15
16
def parse_args(argv=None):
    """Parse the command-line options for an evaluation run.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the behaviour of the original zero-argument call.

    Returns:
        An ``argparse.Namespace`` with the attributes ``model``,
        ``model_args``, ``tasks``, ``provide_description``, ``num_fewshot``,
        ``seed``, ``output_path`` and ``limit``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', required=True,
                        help="name of the model to look up")
    parser.add_argument('--model_args', default="",
                        help="argument string used to construct the model")
    parser.add_argument('--tasks', default="all_tasks",
                        help="comma-separated task names, or 'all_tasks'")
    parser.add_argument('--provide_description', action="store_true",
                        help="forwarded to each task's few-shot context")
    parser.add_argument('--num_fewshot', type=int, default=1,
                        help="number of few-shot examples per context")
    parser.add_argument('--seed', type=int, default=1234,
                        help="seed for the random and numpy RNGs")
    parser.add_argument('--output_path', default=None,
                        help="if given, also write the JSON results here")
    parser.add_argument('--limit', type=int, default=None,
                        help="evaluate at most this many docs per task")
    return parser.parse_args(argv)

def _collect_requests(task_items, provide_description, num_fewshot, limit):
    """Build every LM request needed to evaluate the given tasks.

    Args:
        task_items: Iterable of ``(task_name, task)`` pairs.
        provide_description: Forwarded to ``task.fewshot_context``.
        num_fewshot: Forwarded to ``task.fewshot_context``.
        limit: Maximum number of validation docs taken per task, or ``None``
            for no limit.

    Returns:
        A ``(docs, requests, origins)`` tuple where ``docs`` maps
        ``(task_name, doc_id)`` to the doc, ``requests`` maps a request type
        to the list of requests of that type, and ``origins`` maps a request
        type to a parallel list of ``(i, task_name, doc_id)`` tuples.
    """
    # if we ever run into issues where the eval tasks don't fit in memory and
    # we can't afford a machine with bigger memory, we can always modify this
    # plumbing to support that, but i didn't want to include it just yet
    # because overengineering is bad (or we could make it write the requests
    # to disk and then read them back out again - probably using an sqlite db
    # because of all the moving parts we have)

    # TODO: we need unit tests & sanity checks or something to ensure that the
    # return of `validation_docs` is stable

    docs = {}
    requests = collections.defaultdict(list)
    origins = collections.defaultdict(list)

    # get lists of each type of request
    for task_name, task in task_items:
        task_docs = itertools.islice(task.validation_docs(), 0, limit)
        for doc_id, doc in enumerate(task_docs):
            docs[(task_name, doc_id)] = doc

            ctx = task.fewshot_context(
                doc=doc,
                provide_description=provide_description,
                num_fewshot=num_fewshot,
            )

            for i, req in enumerate(task.construct_requests(doc, ctx)):
                requests[req.type].append(req)
                # i: index in requests for a single task instance
                # doc_id: unique id that we can get back to a doc using `docs`
                origins[req.type].append((i, task_name, doc_id))

    return docs, requests, origins


def _execute_requests(lm, requests, origins):
    """Run each batch of requests through the LM and regroup the responses.

    Args:
        lm: Object exposing one method per request type; the method named by
            the request type is called with the list of request args.
        requests: Mapping of request type to list of requests.
        origins: Mapping of request type to ``(i, task_name, doc_id)`` tuples
            parallel to ``requests``.

    Returns:
        Mapping of ``(task_name, doc_id)`` to a list of ``(i, response)``.

    Raises:
        ValueError: If the LM returns a different number of responses than
            the number of requests it was given.
    """
    # all responses for each (task, doc)
    responses = collections.defaultdict(list)

    # execute each type of request
    for reqtype, reqs in requests.items():
        # TODO: right now, this code runs multiple separate LM requests for
        # multiple Requests differing only in index. We could implement some
        # kind of caching, but that would be more of a bandaid solution. we
        # could also implement some kind of autogrouping here; they should
        # end up next to each other.
        resps = list(getattr(lm, reqtype)([req.args for req in reqs]))

        # zip() would silently drop the tail on a mismatch and documents
        # would then be scored on partial data, so fail loudly instead.
        if len(resps) != len(reqs):
            raise ValueError(
                "LM returned {} responses for {} '{}' requests".format(
                    len(resps), len(reqs), reqtype
                )
            )

        for req, resp, (i, task_name, doc_id) in zip(reqs, resps, origins[reqtype]):
            if req.index is not None:
                resp = resp[req.index]
            responses[(task_name, doc_id)].append((i, resp))

    return responses


def _aggregate_results(task_dict, docs, responses):
    """Score every document and aggregate the per-document metric values.

    Args:
        task_dict: Mapping of task name to task object.
        docs: Mapping of ``(task_name, doc_id)`` to the doc.
        responses: Mapping of ``(task_name, doc_id)`` to ``(i, response)``
            pairs, in any order.

    Returns:
        A ``defaultdict`` mapping task name to ``{metric: aggregated value}``.
    """
    vals = collections.defaultdict(list)

    # unpack results and sort back in order and return control to Task
    for (task_name, doc_id), indexed in responses.items():
        # Sort on the request index only; responses are never compared.
        ordered = [resp for _, resp in sorted(indexed, key=lambda pair: pair[0])]

        task = task_dict[task_name]
        metrics = task.process_results(docs[(task_name, doc_id)], ordered)
        for metric, value in metrics.items():
            vals[(task_name, metric)].append(value)

    # aggregate results
    results = collections.defaultdict(dict)
    for (task_name, metric), items in vals.items():
        task = task_dict[task_name]
        results[task_name][metric] = task.aggregation()[metric](items)

    return results


def main():
    """Evaluate the selected model on the selected tasks and report results.

    Seeds the ``random`` and ``numpy`` generators, builds the model and the
    tasks from the command-line arguments, evaluates each task that has
    validation docs, prints the results as indented JSON and, when
    ``--output_path`` is set, writes the same JSON to that file.
    """
    args = parse_args()
    random.seed(args.seed)
    np.random.seed(args.seed)

    lm = models.get_model(args.model).create_from_arg_string(args.model_args)
    if args.tasks == "all_tasks":
        task_names = tasks.ALL_TASKS
    else:
        task_names = args.tasks.split(",")
    task_dict = tasks.get_task_dict(task_names)

    # TODO: fall back to test docs
    task_items = [
        (name, task)
        for name, task in task_dict.items()
        if task.has_validation_docs()
    ]

    docs, requests, origins = _collect_requests(
        task_items,
        provide_description=args.provide_description,
        num_fewshot=args.num_fewshot,
        limit=args.limit,
    )
    responses = _execute_requests(lm, requests, origins)
    results = _aggregate_results(task_dict, docs, responses)

    dumped = json.dumps(results, indent=2)
    print(dumped)
    if args.output_path:
        # Explicit encoding so the file does not depend on the platform default.
        with open(args.output_path, "w", encoding="utf-8") as f:
            f.write(dumped)
Jason Phang's avatar
Jason Phang committed
110

Jason Phang's avatar
lib  
Jason Phang committed
111

Jason Phang's avatar
Jason Phang committed
112
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()