main.py 1.59 KB
Newer Older
Jason Phang's avatar
Jason Phang committed
1
2
import argparse
import json
Jason Phang's avatar
seed  
Jason Phang committed
3
4
import numpy as np
import random
Stella Biderman's avatar
Stella Biderman committed
5
import itertools
Leo Gao's avatar
Update  
Leo Gao committed
6
import collections
Leo Gao's avatar
Leo Gao committed
7
import logging
Leo Gao's avatar
Leo Gao committed
8

Leo Gao's avatar
Leo Gao committed
9
from lm_eval import models, tasks, evaluator, base
Jason Phang's avatar
lib  
Jason Phang committed
10

Leo Gao's avatar
Leo Gao committed
11
# Raise the "openai" logger's threshold to WARNING so that its DEBUG and INFO
# records are not emitted while this script runs.
logging.getLogger("openai").setLevel(logging.WARNING)
Leo Gao's avatar
Leo Gao committed
12

Jason Phang's avatar
Jason Phang committed
13
14
15
16
17
18
def parse_args(argv=None):
    """Parse the command-line options for an evaluation run.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads the process arguments, which is
            the behavior existing callers rely on.

    Returns:
        An ``argparse.Namespace`` with the attributes ``model``,
        ``model_args``, ``tasks``, ``provide_description``, ``num_fewshot``,
        ``seed``, ``output_path``, ``limit`` and ``cache``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model', required=True,
        help='name of the model to evaluate')
    parser.add_argument(
        '--model_args', default="",
        help='argument string used to construct the model')
    parser.add_argument(
        '--tasks', default="all_tasks",
        help='comma-separated task names, or "all_tasks" (default)')
    parser.add_argument(
        '--provide_description', action="store_true",
        help='flag passed through to the evaluator')
    parser.add_argument(
        '--num_fewshot', type=int, default=1,
        help='few-shot count passed through to the evaluator (default: 1)')
    parser.add_argument(
        '--seed', type=int, default=1234,
        help='seed for the random number generators (default: 1234)')
    parser.add_argument(
        '--output_path', default=None,
        help='if given, the JSON results are also written to this file')
    parser.add_argument(
        '--limit', type=int, default=None,
        help='optional integer limit passed through to the evaluator')
    parser.add_argument(
        '--cache', action="store_true",
        help='wrap the model in a cache stored under lm_cache/')
    return parser.parse_args(argv)

def _cache_db_path(model, model_args):
    """Return the cache database path for a model name and argument string."""
    # '=' and ',' are rewritten exactly as before. '/' is replaced as well:
    # a value such as "pretrained=org/name" would otherwise turn into a path
    # separator and point the database into an unintended subdirectory.
    safe_args = model_args.replace('=', '-').replace(',', '_').replace('/', '-')
    return 'lm_cache/' + model + '_' + safe_args + '.db'


def main():
    """Run an evaluation from the command line and report the results.

    Parses the command-line options, seeds ``random`` and ``numpy.random``
    with ``--seed``, builds the model from ``--model`` / ``--model_args``
    (wrapped in ``base.CachingLM`` when ``--cache`` is given), resolves the
    requested tasks and calls ``evaluator.evaluate``. The results are printed
    as indented JSON and, when ``--output_path`` is set, also written to that
    file.
    """
    args = parse_args()

    # Seed both generators so a run with the same --seed draws the same values.
    random.seed(args.seed)
    np.random.seed(args.seed)

    lm = models.get_model(args.model).create_from_arg_string(args.model_args)
    if args.cache:
        lm = base.CachingLM(lm, _cache_db_path(args.model, args.model_args))

    if args.tasks == "all_tasks":
        task_names = tasks.ALL_TASKS
    else:
        # Strip surrounding whitespace so "a, b" selects the same tasks as "a,b".
        task_names = [name.strip() for name in args.tasks.split(",")]
    task_dict = tasks.get_task_dict(task_names)

    results = evaluator.evaluate(
        lm, task_dict, args.provide_description, args.num_fewshot, args.limit)

    dumped = json.dumps(results, indent=2)
    print(dumped)
    if args.output_path:
        with open(args.output_path, "w") as f:
            f.write(dumped)
Jason Phang's avatar
Jason Phang committed
49

Jason Phang's avatar
lib  
Jason Phang committed
50

Jason Phang's avatar
Jason Phang committed
51
# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()