# main.py
import argparse
import collections
import itertools
import json
import random
import sys

import numpy as np

from lm_eval import models, tasks


def parse_args(argv=None):
    """Parse the command-line options for an evaluation run.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the behaviour existing callers rely on.

    Returns:
        An ``argparse.Namespace`` with the attributes ``model``,
        ``model_args``, ``tasks``, ``provide_description``, ``num_fewshot``,
        ``seed``, ``output_path`` and ``limit``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model', required=True,
        help="model name, looked up with models.get_model",
    )
    parser.add_argument(
        '--model_args', default="",
        help="argument string handed to the model's create_from_arg_string",
    )
    parser.add_argument(
        '--tasks', default="all_tasks",
        help='comma-separated task names, or "all_tasks" for every task',
    )
    parser.add_argument(
        '--provide_description', action="store_true",
        help="forwarded to each task's fewshot_context",
    )
    parser.add_argument(
        '--num_fewshot', type=int, default=1,
        help="forwarded to each task's fewshot_context",
    )
    parser.add_argument(
        '--seed', type=int, default=1234,
        help="seed for the random and numpy.random generators",
    )
    parser.add_argument(
        '--output_path', default=None,
        help="if given, the JSON results are also written to this file",
    )
    parser.add_argument(
        '--limit', type=int, default=None,
        help="maximum number of validation docs to use per task",
    )
    return parser.parse_args(argv)

def main():
    """Evaluate the selected model on the selected tasks and report metrics.

    Steps, in order:
      1. Parse the command line and seed ``random`` and ``numpy.random``.
      2. Build the model and the task dictionary.
      3. For every validation doc of every task, build the few-shot context
         and collect the requests the task constructs from it, grouped by
         ``req.type``.
      4. Send each group to the model in a single call to the model method
         named after the request type.
      5. Route the responses back to their (task, doc) pair, let the task
         turn them into per-doc metrics, and aggregate each submetric with
         the aggregation function the task supplied.
      6. Print the results as JSON and, if ``--output_path`` was given,
         write the same JSON to that file.
    """
    args = parse_args()
    random.seed(args.seed)
    np.random.seed(args.seed)

    lm = models.get_model(args.model).create_from_arg_string(args.model_args)

    if args.tasks == "all_tasks":
        task_names = tasks.ALL_TASKS
    else:
        task_names = args.tasks.split(",")
    task_dict = tasks.get_task_dict(task_names)

    # TODO: fall back to test docs
    task_dict_items = []
    for name, task in task_dict.items():
        if task.has_validation_docs():
            task_dict_items.append((name, task))
        else:
            # Tell the user instead of silently producing no results for the
            # task. stderr keeps the JSON printed on stdout clean.
            print(
                "warning: skipping task {!r}: it has no validation docs".format(name),
                file=sys.stderr,
            )

    results = collections.defaultdict(dict)

    # Requests grouped by request type, plus a parallel list recording where
    # each request came from: (index within its doc, task name, doc id).
    requests = collections.defaultdict(list)
    requests_origin = collections.defaultdict(list)

    # if we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger memory,
    # we can always modify this plumbing to support that, but i didn't want to include it just yet because overengineering is bad
    # (or we could make it write the requests to disk and then read them back out again - probably using an sqlite db because
    # of all the moving parts we have)

    # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable

    docs = {}

    for task_name, task in task_dict_items:
        # islice with stop=None (no --limit) iterates over every doc.
        task_docs = itertools.islice(task.validation_docs(), 0, args.limit)
        for doc_id, doc in enumerate(task_docs):
            docs[(task_name, doc_id)] = doc

            ctx = task.fewshot_context(
                doc=doc,
                provide_description=args.provide_description,
                num_fewshot=args.num_fewshot,
            )

            reqs = task.construct_requests(ctx)

            for i, req in enumerate(reqs):
                requests[req.type].append(req)
                requests_origin[req.type].append((i, task_name, doc_id))

    # (task name, doc id) -> list of (request index, response)
    process_res_queue = collections.defaultdict(list)

    for reqtype, reqs in requests.items():
        # One batched model call per request type.
        resps = getattr(lm, reqtype)([req.args for req in reqs])

        # NOTE(review): this pairing assumes the model returns exactly one
        # response per request, in request order - confirm against the
        # model implementations.
        for resp, (i, task_name, doc_id) in zip(resps, requests_origin[reqtype]):
            process_res_queue[(task_name, doc_id)].append((i, resp))

    # (task name, submetric name) -> list of per-doc metric values
    vals = collections.defaultdict(list)

    for (task_name, doc_id), indexed_resps in process_res_queue.items():
        # Restore the order in which the task constructed the requests.
        # Sort on the index only so responses are never compared.
        indexed_resps.sort(key=lambda pair: pair[0])
        responses = [resp for _, resp in indexed_resps]

        task = task_dict[task_name]
        doc = docs[(task_name, doc_id)]

        metrics = task.process_results(doc, responses)
        for metric in metrics:
            results[task_name][metric['submetric']] = {
                "higher_is_better": metric["higher_is_better"],
                "aggregation": metric["aggregation"],
            }
            vals[(task_name, metric['submetric'])].append(metric['value'])

    for task_name, submetrics in results.items():
        for submetric_name, info in submetrics.items():
            # Remove the aggregation function before storing the value:
            # a function can't be serialized to JSON.
            aggregate = info.pop('aggregation')
            info['value'] = aggregate(vals[(task_name, submetric_name)])

    dumped = json.dumps(results, indent=2)
    print(dumped)
    if args.output_path:
        with open(args.output_path, "w", encoding="utf-8") as f:
            f.write(dumped)
# Run the evaluation only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()