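"""Entry point for running lm_eval evaluations from the command line.

Illustrative usage (the model name, model_args string, and task names below are
hypothetical examples; the valid choices are whatever lm_eval.models and
lm_eval.tasks actually register):

    python main.py --model gpt2 --model_args device=cuda --tasks lambada,hellaswag --num_fewshot 5
"""
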
import argparse
import json
import numpy as np
import random
import itertools
import collections

from lm_eval import models, tasks


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', required=True)
    parser.add_argument('--model_args', default="")
    parser.add_argument('--tasks', default="all_tasks")
    parser.add_argument('--provide_description', action="store_true")
    parser.add_argument('--num_fewshot', type=int, default=1)
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--output_path', default=None)
    parser.add_argument('--limit', type=int, default=None)
    return parser.parse_args()

def main():
    args = parse_args()
    random.seed(args.seed)
    np.random.seed(args.seed)

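    # Instantiate the language model named by --model and configure it from the
    # --model_args string (the accepted format and keys depend on the model class).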
    lm = models.get_model(args.model).create_from_arg_string(args.model_args)
    if args.tasks == "all_tasks":
        task_names = tasks.ALL_TASKS
    else:
        task_names = args.tasks.split(",")
    task_dict = tasks.get_task_dict(task_names)

    # TODO: fall back to test docs
    task_dict_items = [(name, task) for name, task in task_dict.items() if task.has_validation_docs()]

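    # results[task_name][submetric] holds per-metric metadata now and the aggregated value later.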
    results = collections.defaultdict(dict)

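    # Requests are grouped by request type so every request of the same type can be sent to the
    # model in one batched call; requests_origin records (request index, task name, doc, doc_id)
    # for each request so its response can be routed back to the right doc for scoring.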
    requests = collections.defaultdict(list)
    requests_origin = collections.defaultdict(list)

    # If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with more memory,
    # we can always modify this plumbing to support that, but I didn't want to include it just yet because overengineering is bad.
    # (Or we could make it write the requests to disk and then read them back out again, probably using an sqlite db, because of all the moving parts we have.)

    # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable

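    # Keep each doc around, keyed by (task_name, doc_id), so process_results can see it later.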
    docs = {}

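    # Build all requests up front: for every task, walk its validation docs (truncated by --limit
    # if given) and construct the few-shot context and LM requests for each doc.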
    for task_name, task in task_dict_items:
        for doc_id, doc in enumerate(itertools.islice(task.validation_docs(), 0, args.limit)):
            docs[(task_name, doc_id)] = doc

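            # Build the few-shot prompt for this doc, controlled by --provide_description and --num_fewshot.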
            ctx = task.fewshot_context(
                doc=doc,
                provide_description=args.provide_description,
                num_fewshot=args.num_fewshot,
            )

            reqs = task.construct_requests(ctx)

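            # A single doc can produce several requests; record each one's index within the doc,
            # plus its task and doc, so the responses can be matched back up later.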
            for i, req in enumerate(reqs):
                requests[req.type].append(req)
                requests_origin[req.type].append((i, task_name, doc, doc_id))

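    # process_res_queue[(task_name, doc_id)] collects (request index, response) pairs for each doc.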
    process_res_queue = collections.defaultdict(list)

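    # Execute all requests of each type in one batched call; the model object is expected to
    # expose a method named after each request type, looked up here via getattr.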
    for reqtype, reqs in requests.items():
        resps = getattr(lm, reqtype)([req.args for req in reqs])

        for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
            process_res_queue[(task_name, doc_id)].append((i, resp))
    
    vals = collections.defaultdict(list)

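    # Score each doc: restore its responses to request order and let the task turn them into
    # per-doc metric values, accumulated in vals for aggregation below.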
    for (task_name, doc_id), responses in process_res_queue.items():
        responses.sort(key=lambda x: x[0])
        responses = [x[1] for x in responses]

        task = task_dict[task_name]
        doc = docs[(task_name, doc_id)]

        metrics = task.process_results(doc, responses)
        for metric in metrics:
            results[task_name][metric['submetric']] = {
                "higher_is_better": metric["higher_is_better"],
                "aggregation": metric["aggregation"]
            }
            vals[(task_name, metric['submetric'])].append(metric['value'])
    
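    # Aggregate the per-doc values for each (task, submetric) into the final reported value.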
    for task_name, submetrics in results.items():
        for k in submetrics.keys():
            submetrics[k]['value'] = submetrics[k]['aggregation'](vals[(task_name, k)])

            # can't serialize a function
            del submetrics[k]['aggregation']


    dumped = json.dumps(results, indent=2)
    print(dumped)
    if args.output_path:
        with open(args.output_path, "w") as f:
            f.write(dumped)


if __name__ == "__main__":
    main()