main.py 1.56 KB
Newer Older
Jason Phang's avatar
Jason Phang committed
1
2
import argparse
import json
Jason Phang's avatar
seed  
Jason Phang committed
3
4
import numpy as np
import random
Stella Biderman's avatar
Stella Biderman committed
5
import itertools
Leo Gao's avatar
Leo Gao committed
6

Jason Phang's avatar
lib  
Jason Phang committed
7
8
from lm_eval import models, tasks

Leo Gao's avatar
Leo Gao committed
9

Jason Phang's avatar
Jason Phang committed
10
11
12
13
14
15
def parse_args(argv=None):
    """Parse the command-line options for an evaluation run.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line (``sys.argv[1:]``). ``None`` (the default) keeps the
            original behaviour of reading the process arguments.

    Returns:
        An ``argparse.Namespace`` with the attributes ``model``,
        ``model_args``, ``tasks``, ``provide_description``, ``num_fewshot``,
        ``seed``, ``output_path`` and ``limit``. ``limit`` is an ``int`` or
        ``None``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', required=True,
                        help="name of the model to look up and evaluate")
    parser.add_argument('--model_args', default="",
                        help="argument string used to construct the model")
    parser.add_argument('--tasks', default="all_tasks",
                        help='comma-separated task names, or "all_tasks"')
    parser.add_argument('--provide_description', action="store_true",
                        help="pass provide_description=True to each task")
    parser.add_argument('--num_fewshot', type=int, default=1,
                        help="number of few-shot examples per task")
    parser.add_argument('--seed', type=int, default=1234,
                        help="seed for Python's and NumPy's RNGs")
    parser.add_argument('--output_path', default=None,
                        help="optional file to also write the JSON results to")
    # type=int: the value is used as an itertools.islice stop bound, which
    # rejects strings. None (the default) means "use every validation doc".
    parser.add_argument('--limit', type=int, default=None,
                        help="maximum number of validation docs per task")
    return parser.parse_args(argv)

def main():
    """Evaluate the requested tasks with a model and report results as JSON.

    Seeds Python's ``random`` and NumPy's global RNG from ``--seed``. It then
    builds the model named by ``--model`` from ``--model_args`` and evaluates
    every requested task that has validation docs, using only the first
    ``--limit`` docs when a limit is given. The results dict (task name ->
    evaluation result) is printed as indented JSON. When ``--output_path``
    is set, the same JSON is also written to that file.
    """
    args = parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)

    lm = models.get_model(args.model).create_from_arg_string(args.model_args)

    if args.tasks == "all_tasks":
        task_names = tasks.ALL_TASKS
    else:
        task_names = args.tasks.split(",")
    task_dict = tasks.get_task_dict(task_names)

    # itertools.islice only accepts an int or None as its stop bound, so
    # coerce defensively in case the parsed value is still a string.
    limit = None if args.limit is None else int(args.limit)

    results = {}
    for task_name, task in task_dict.items():
        # Only validation docs are evaluated. Tasks without them are skipped
        # and do not appear in the results.
        if not task.has_validation_docs():
            continue
        results[task_name] = task.evaluate(
            docs=itertools.islice(task.validation_docs(), 0, limit),
            lm=lm,
            provide_description=args.provide_description,
            num_fewshot=args.num_fewshot,
        )

    dumped = json.dumps(results, indent=2)
    print(dumped)
    if args.output_path:
        with open(args.output_path, "w", encoding="utf-8") as f:
            f.write(dumped)
Jason Phang's avatar
Jason Phang committed
50

Jason Phang's avatar
lib  
Jason Phang committed
51

Jason Phang's avatar
Jason Phang committed
52
# Run the evaluation CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()