import argparse
import json
import numpy as np
import random
import itertools
import collections

from lm_eval import models, tasks


def parse_args():
    """Parse the command-line options for an evaluation run.

    Options:
        --model: required model name.
        --model_args: argument string for the model (default ``""``).
        --tasks: comma-separated task names, or ``"all_tasks"`` (default).
        --provide_description: flag; false unless given.
        --num_fewshot: integer number of few-shot examples (default 1).
        --seed: integer random seed (default 1234).
        --output_path: optional file path for the JSON results.
        --limit: optional integer cap on the documents taken per task;
            ``None`` (the default) means no cap.

    Returns:
        argparse.Namespace: the parsed options, read from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', required=True)
    parser.add_argument('--model_args', default="")
    parser.add_argument('--tasks', default="all_tasks")
    parser.add_argument('--provide_description', action="store_true")
    parser.add_argument('--num_fewshot', type=int, default=1)
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--output_path', default=None)
    # The limit is used as the stop value of itertools.islice, which only
    # accepts None or an integer; without type=int a command-line value would
    # arrive as a string and make islice raise ValueError.
    parser.add_argument('--limit', type=int, default=None)
    return parser.parse_args()

def main():
    """Run the command-line evaluation driver.

    Seeds the ``random`` and ``numpy`` generators, builds the language model
    from ``--model`` / ``--model_args``, collects the requests that every
    selected task constructs for its validation documents, sends them to the
    model grouped by request type, and finally prints ``results`` as JSON
    (also writing it to ``--output_path`` when one is given).

    NOTE: the evaluation step is unfinished here: the model responses are
    discarded and ``results`` is never filled in, so the emitted JSON is
    always ``{}``.
    """
    args = parse_args()

    # Seed both random number generators with the same value.
    random.seed(args.seed)
    np.random.seed(args.seed)

    lm = models.get_model(args.model).create_from_arg_string(args.model_args)

    task_names = (
        tasks.ALL_TASKS if args.tasks == "all_tasks" else args.tasks.split(",")
    )
    task_dict_items = list(tasks.get_task_dict(task_names).items())
    results = {}

    # Requests from all tasks, grouped by request type.
    requests = collections.defaultdict(list)
    # For each request type, one count per document that produced requests
    # of that type.
    requests_lengths = collections.defaultdict(list)

    for _task_name, task in task_dict_items:
        # TODO: fall back to test docs
        if not task.has_validation_docs():
            continue

        for doc in itertools.islice(task.validation_docs(), 0, args.limit):
            context = task.fewshot_context(
                doc=doc,
                provide_description=args.provide_description,
                num_fewshot=args.num_fewshot,
            )

            per_doc_counts = collections.defaultdict(int)
            for request in task.construct_requests(context):
                requests[request.type].append(request)
                per_doc_counts[request.type] += 1

            for req_type, count in per_doc_counts.items():
                requests_lengths[req_type].append(count)

    # TODO: finish implementation
    # Each request type names the model method that handles it; the
    # responses are not used yet.
    for reqname, reqs in requests.items():
        handler = getattr(lm, reqname)
        handler([req.args for req in reqs])

    # Placeholder loop for per-task result processing; it currently has no
    # body beyond skipping tasks without validation documents.
    for _task_name, task in task_dict_items:
        if not task.has_validation_docs():
            continue

    serialized = json.dumps(results, indent=2)
    print(serialized)
    if args.output_path:
        with open(args.output_path, "w") as out_file:
            out_file.write(serialized)


# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()