main.py 1.95 KB
Newer Older
Jason Phang's avatar
Jason Phang committed
1
2
import argparse
import json
Jason Phang's avatar
seed  
Jason Phang committed
3
4
import numpy as np
import random
Stella Biderman's avatar
Stella Biderman committed
5
import itertools
Leo Gao's avatar
Update  
Leo Gao committed
6
import collections
Leo Gao's avatar
Leo Gao committed
7
import logging
Leo Gao's avatar
Leo Gao committed
8

Leo Gao's avatar
Leo Gao committed
9
from lm_eval import models, tasks, evaluator, base
Jason Phang's avatar
lib  
Jason Phang committed
10

Leo Gao's avatar
Leo Gao committed
11
# Raise the threshold of the "openai" logger to WARNING so that records
# below WARNING sent to that logger are not emitted by this script.
logging.getLogger(name="openai").setLevel(level=logging.WARNING)
Leo Gao's avatar
Leo Gao committed
12

Jason Phang's avatar
Jason Phang committed
13
14
15
16
17
18
def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour;
            passing an explicit list lets tests or other scripts reuse the
            parser without touching ``sys.argv``.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', required=True,
                        help="model name, looked up via models.get_model")
    parser.add_argument('--model_args', default="",
                        help="argument string passed to the model's create_from_arg_string")
    parser.add_argument('--tasks', default="all_tasks",
                        help="comma-separated task names, or 'all_tasks' for every task")
    # The next option is forwarded unchanged to evaluator.evaluate.
    parser.add_argument('--provide_description', action="store_true")
    # Forwarded unchanged to evaluator.evaluate.
    parser.add_argument('--num_fewshot', type=int, default=0)
    parser.add_argument('--seed', type=int, default=1234,
                        help="seed for Python's random module and NumPy's global RNG")
    parser.add_argument('--output_path', default=None,
                        help="if set, also write the JSON results to this file")
    # Forwarded unchanged to evaluator.evaluate.
    parser.add_argument('--limit', type=int, default=None)
    parser.add_argument('--no_cache', action="store_true",
                        help="do not wrap the model in base.CachingLM")
    return parser.parse_args(argv)

def main():
Leo Gao's avatar
Leo Gao committed
27

Jason Phang's avatar
Jason Phang committed
28
    args = parse_args()
Jason Phang's avatar
seed  
Jason Phang committed
29
30
31
    random.seed(args.seed)
    np.random.seed(args.seed)

Jason Phang's avatar
lib  
Jason Phang committed
32
    lm = models.get_model(args.model).create_from_arg_string(args.model_args)
Leo Gao's avatar
Leo Gao committed
33

Leo Gao's avatar
Leo Gao committed
34
    if not args.no_cache:
Leo Gao's avatar
Leo Gao committed
35
        lm = base.CachingLM(lm, 'lm_cache/' + args.model + '_' + args.model_args.replace('=', '-').replace(',', '_') + '.db')
Jason Phang's avatar
Jason Phang committed
36
37
38
39
    if args.tasks == "all_tasks":
        task_names = tasks.ALL_TASKS
    else:
        task_names = args.tasks.split(",")
Jason Phang's avatar
cleanup  
Jason Phang committed
40
    task_dict = tasks.get_task_dict(task_names)
Leo Gao's avatar
Leo Gao committed
41

Leo Gao's avatar
Leo Gao committed
42
    results = evaluator.evaluate(lm, task_dict, args.provide_description, args.num_fewshot, args.limit)
Leo Gao's avatar
Update  
Leo Gao committed
43

Jason Phang's avatar
Jason Phang committed
44
45
46
47
48
    dumped = json.dumps(results, indent=2)
    print(dumped)
    if args.output_path:
        with open(args.output_path, "w") as f:
            f.write(dumped)
Jason Phang's avatar
Jason Phang committed
49

Leo Gao's avatar
Leo Gao committed
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
    # MAKE TABLE
    from pytablewriter import MarkdownTableWriter

    writer = MarkdownTableWriter()
    writer.headers = ["Task", "Metric", "Value"]

    values = []

    for k, dic in results.items():
        for m, v in dic.items():
            values.append([k, m, '%.4f' % v])
            k = ""
    writer.value_matrix = values

    print(writer.dumps())
Jason Phang's avatar
lib  
Jason Phang committed
65

Jason Phang's avatar
Jason Phang committed
66
if __name__ == "__main__":
    # Entry point: run the CLI only when executed as a script, not on import.
    main()