import argparse
import json
import numpy as np
import random
import itertools
import collections

from lm_eval import models, tasks, evaluator, base


def parse_args(argv=None):
    """Parse the command-line options for an evaluation run.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which matches
            calling this function with no arguments.

    Returns:
        argparse.Namespace with the attributes ``model`` (required),
        ``model_args`` (default ``""``), ``tasks`` (default ``"all_tasks"``),
        ``provide_description`` (flag), ``num_fewshot`` (int, default 1),
        ``seed`` (int, default 1234), ``output_path`` (default ``None``),
        ``limit`` (int, default ``None``) and ``cache`` (flag).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', required=True)
    parser.add_argument('--model_args', default="")
    parser.add_argument('--tasks', default="all_tasks")
    parser.add_argument('--provide_description', action="store_true")
    parser.add_argument('--num_fewshot', type=int, default=1)
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--output_path', default=None)
    parser.add_argument('--limit', type=int, default=None)
    parser.add_argument('--cache', action="store_true")
    return parser.parse_args(argv)

def main():
    """Run an evaluation from the command line and report the results.

    Parses the command-line arguments, seeds ``random`` and ``numpy.random``
    with ``--seed``, builds the model named by ``--model`` from the
    ``--model_args`` string, wraps it in ``base.CachingLM`` when ``--cache``
    is set, evaluates it on the selected tasks, prints the results as
    indented JSON and, when ``--output_path`` is given, also writes the same
    JSON text to that file.
    """
    args = parse_args()

    # Seed Python's and NumPy's global RNGs with the requested seed.
    random.seed(args.seed)
    np.random.seed(args.seed)

    lm = models.get_model(args.model).create_from_arg_string(args.model_args)

    if args.cache:
        # The cache file name is derived from the model name and its argument
        # string, with '=' and ',' replaced in the latter.
        cache_suffix = args.model_args.replace('=', '-').replace(',', '_')
        lm = base.CachingLM(lm, 'lm_cache/' + args.model + '_' + cache_suffix + '.db')

    # "all_tasks" selects tasks.ALL_TASKS; anything else is a comma-separated list.
    if args.tasks == "all_tasks":
        task_names = tasks.ALL_TASKS
    else:
        task_names = args.tasks.split(",")
    task_dict = tasks.get_task_dict(task_names)

    results = evaluator.evaluate(lm, task_dict, args.provide_description, args.num_fewshot, args.limit)

    dumped = json.dumps(results, indent=2)
    print(dumped)
    if args.output_path:
        with open(args.output_path, "w", encoding="utf-8") as f:
            f.write(dumped)


if __name__ == "__main__":
    # Run the evaluation only when this file is executed as a script.
    main()