main.py 4.56 KB
Newer Older
Jason Phang's avatar
Jason Phang committed
1
2
import argparse
import json
Jason Phang's avatar
seed  
Jason Phang committed
3
4
import numpy as np
import random
Stella Biderman's avatar
Stella Biderman committed
5
import itertools
Leo Gao's avatar
Update  
Leo Gao committed
6
import collections
Leo Gao's avatar
Leo Gao committed
7

Jason Phang's avatar
lib  
Jason Phang committed
8
9
from lm_eval import models, tasks

Leo Gao's avatar
Leo Gao committed
10

Jason Phang's avatar
Jason Phang committed
11
12
13
14
15
16
def parse_args(argv=None):
    """Parse the command-line options for an evaluation run.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, exactly as the
            original zero-argument version did; passing a list lets other
            Python code (and tests) drive the parser directly.

    Returns:
        An ``argparse.Namespace`` with the attributes ``model``,
        ``model_args``, ``tasks``, ``provide_description``, ``num_fewshot``,
        ``seed``, ``output_path`` and ``limit``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model', required=True,
        help="model name, looked up with models.get_model",
    )
    parser.add_argument(
        '--model_args', default="",
        help="argument string passed to the model's create_from_arg_string",
    )
    parser.add_argument(
        '--tasks', default="all_tasks",
        help="comma-separated task names, or all_tasks for every task",
    )
    parser.add_argument(
        '--provide_description', action="store_true",
        help="include the task description in the few-shot context",
    )
    parser.add_argument(
        '--num_fewshot', type=int, default=1,
        help="number of few-shot examples in each context",
    )
    parser.add_argument(
        '--seed', type=int, default=1234,
        help="seed for Python's random module and NumPy's global RNG",
    )
    parser.add_argument(
        '--output_path', default=None,
        help="if given, also write the JSON results to this file",
    )
    parser.add_argument(
        '--limit', type=int, default=None,
        help="maximum number of documents evaluated per task",
    )
    return parser.parse_args(argv)

def main():
    """Run the evaluation described by the command-line arguments.

    Steps:
      1. Seed Python's ``random`` module and NumPy's global RNG with ``--seed``.
      2. Build the model named by ``--model`` from ``--model_args``.
      3. Resolve the requested tasks, keeping only those that have
         validation or test documents.
      4. For every task document (up to ``--limit`` per task), build a
         few-shot context and collect the requests the task constructs,
         grouped by request type.
      5. Send each group to the model with one call per request type.
      6. Route every response back to its (task, doc), restore the
         per-document request order, and hand the responses to the task's
         ``process_results``.
      7. Aggregate each metric with the task's ``aggregation()`` function,
         print the results as JSON and optionally write them to
         ``--output_path``.

    Raises:
        RuntimeError: if the model returns a different number of responses
            than the requests it was given (previously such a mismatch was
            silently truncated by ``zip``, dropping results).
    """
    args = parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)

    lm = models.get_model(args.model).create_from_arg_string(args.model_args)

    if args.tasks == "all_tasks":
        task_names = tasks.ALL_TASKS
    else:
        # Tolerate "a, b" and trailing commas: strip whitespace, drop empty names.
        task_names = [name.strip() for name in args.tasks.split(",") if name.strip()]
    task_dict = tasks.get_task_dict(task_names)

    # Only tasks that have validation or test docs can be evaluated.
    task_dict_items = [
        (name, task)
        for name, task in task_dict.items()
        if task.has_validation_docs() or task.has_test_docs()
    ]

    results = collections.defaultdict(dict)

    # requests[reqtype]: every request of that type, across all tasks and docs.
    requests = collections.defaultdict(list)
    # requests_origin[reqtype][k] = (i, task_name, doc_id) describes where
    # requests[reqtype][k] came from, so its response can be routed back.
    requests_origin = collections.defaultdict(list)

    # if we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger memory,
    # we can always modify this plumbing to support that, but i didn't want to include it just yet because overengineering is bad
    # (or we could make it write the requests to disk and then read them back out again - probably using an sqlite db because of all the moving parts we have)

    # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable

    docs = {}

    # get lists of each type of request
    for task_name, task in task_dict_items:
        # default to validation docs, fall back to test docs if validation is unavailable
        # TODO: the val-fallback-to-test system isn't final, we should revisit it at some point
        if task.has_validation_docs():
            task_doc_func = task.validation_docs
        else:
            # the task_dict_items filter guarantees test docs exist here
            task_doc_func = task.test_docs

        for doc_id, doc in enumerate(itertools.islice(task_doc_func(), 0, args.limit)):
            docs[(task_name, doc_id)] = doc

            ctx = task.fewshot_context(
                doc=doc,
                provide_description=args.provide_description,
                num_fewshot=args.num_fewshot,
            )

            reqs = task.construct_requests(doc, ctx)

            for i, req in enumerate(reqs):
                requests[req.type].append(req)
                # i: index in requests for a single task instance
                # doc_id: unique id that we can get back to a doc using `docs`
                requests_origin[req.type].append((i, task_name, doc_id))

    # all responses for each (task_name, doc_id), tagged with their request index
    process_res_queue = collections.defaultdict(list)

    # execute each type of request
    for reqtype, reqs in requests.items():
        # TODO: right now, this code runs multiple separate LM requests for multiple Requests differing
        # only in index. We could implement some kind of caching, but that would be more of a bandaid
        # solution. we could also implement some kind of autogrouping here; they should end up next to each other.

        resps = getattr(lm, reqtype)([req.args for req in reqs])

        # A request with an index selects one element of the model's raw response.
        resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)]

        # zip() above truncates silently; a short response list would drop results.
        if len(resps) != len(reqs):
            raise RuntimeError(
                f"model returned {len(resps)} responses for {len(reqs)} "
                f"'{reqtype}' requests"
            )

        for resp, (i, task_name, doc_id) in zip(resps, requests_origin[reqtype]):
            process_res_queue[(task_name, doc_id)].append((i, resp))

    vals = collections.defaultdict(list)

    # unpack results and sort back in order and return control to Task
    for (task_name, doc_id), indexed_resps in process_res_queue.items():
        indexed_resps.sort(key=lambda pair: pair[0])
        doc_resps = [resp for _, resp in indexed_resps]

        task = task_dict[task_name]
        doc = docs[(task_name, doc_id)]

        metrics = task.process_results(doc, doc_resps)
        for metric, value in metrics.items():
            vals[(task_name, metric)].append(value)

    # aggregate results
    for (task_name, metric), items in vals.items():
        task = task_dict[task_name]
        results[task_name][metric] = task.aggregation()[metric](items)

    dumped = json.dumps(results, indent=2)
    print(dumped)
    if args.output_path:
        with open(args.output_path, "w", encoding="utf-8") as f:
            f.write(dumped)
Jason Phang's avatar
Jason Phang committed
116

Jason Phang's avatar
lib  
Jason Phang committed
117

Jason Phang's avatar
Jason Phang committed
118
if __name__ == "__main__":
    # Script entry point: run the evaluation only when executed directly,
    # not when this module is imported.
    main()