import argparse
import json
import numpy as np
import random
import itertools
import collections

from lm_eval import models, tasks


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', required=True)
    parser.add_argument('--model_args', default="")
    parser.add_argument('--tasks', default="all_tasks")
    parser.add_argument('--provide_description', action="store_true")
    parser.add_argument('--num_fewshot', type=int, default=1)
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--output_path', default=None)
    parser.add_argument('--limit', type=int, default=None)
    return parser.parse_args()

def main():
    args = parse_args()
    random.seed(args.seed)
    np.random.seed(args.seed)

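    # Instantiate the language model named by --model, configured from the --model_args string.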
    lm = models.get_model(args.model).create_from_arg_string(args.model_args)
    if args.tasks == "all_tasks":
        task_names = tasks.ALL_TASKS
    else:
        task_names = args.tasks.split(",")
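    # Resolve the selected task names into task objects.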
    task_dict = tasks.get_task_dict(task_names)

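    # For now, only evaluate tasks that provide validation docs.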
    # TODO: fall back to test docs
    task_dict_items = [(name, task) for name, task in task_dict.items() if task.has_validation_docs()]

    results = {}

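    # Requests are grouped by request type so that each lm method is called once on a large batch;
    # requests_origin records (request index, task, doc) for each request so the responses can be
    # routed back to the doc they belong to.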
    requests = collections.defaultdict(list)
    requests_origin = collections.defaultdict(list)

    # If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine
    # with more memory, we can always modify this plumbing to support that; I didn't want to include it
    # just yet because overengineering is bad. (Or we could write the requests to disk and read them back
    # out again, probably using an SQLite db because of all the moving parts we have.)

    # TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable

    docs = {}

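    # Build the requests: for every validation doc, construct a few-shot context and ask the task
    # which requests it needs for that doc.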
    for task_name, task in task_dict_items:
        for doc_id, doc in enumerate(itertools.islice(task.validation_docs(), 0, args.limit)):
            docs[(task_name, doc_id)] = doc

            ctx = task.fewshot_context(
                doc=doc,
                provide_description=args.provide_description,
                num_fewshot=args.num_fewshot,
            )

            reqs = task.construct_requests(ctx)

            for i, req in enumerate(reqs):
                requests[req.type].append(req)
                requests_origin[req.type].append((i, task_name, doc, doc_id))

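    # Execute each batch of requests with the matching lm method (req.type names the method) and
    # queue the responses under the (task, doc) they originated from.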
    process_res_queue = collections.defaultdict(list)

    for reqtype, reqs in requests.items():
        resps = getattr(lm, reqtype)([req.args for req in reqs])

        for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
            process_res_queue[(task_name, doc_id)].append((i, resp))
    
    vals = collections.defaultdict(list)

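    # Put each doc's responses back into request order and hand them to the task's process_results;
    # collect the per-doc metric values along with each metric's aggregation function.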
    for (task_name, doc_id), responses in process_res_queue.items():
        responses.sort(key=lambda x: x[0])
        responses = [x[1] for x in responses]

        task = task_dict[task_name]
        doc = docs[(task_name, doc_id)]

        metrics = task.process_results(doc, responses)
        for metric in metrics:
            results[(task_name, metric['submetric'])] = {
                "higher_is_better": metric["higher_is_better"],
                "aggregation": metric["aggregation"]
            }
            vals[(task_name, metric['submetric'])].append(metric['value'])
    
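    # Aggregate the per-doc values into a single number for each (task, submetric).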
    for k in results.keys():
        results[k]['value'] = results[k]['aggregation'](vals[k])

        # can't serialize a function
        del results[k]['aggregation']


    # results is keyed by (task_name, submetric) tuples, which json can't serialize; flatten the keys to strings
    dumped = json.dumps({",".join(k): v for k, v in results.items()}, indent=2)
    print(dumped)
    if args.output_path:
        with open(args.output_path, "w") as f:
            f.write(dumped)


if __name__ == "__main__":
    main()