main.py 2.09 KB
Newer Older
Jason Phang's avatar
Jason Phang committed
1
2
import argparse
import json
Jason Phang's avatar
seed  
Jason Phang committed
3
4
import numpy as np
import random
Stella Biderman's avatar
Stella Biderman committed
5
import itertools
Leo Gao's avatar
Update  
Leo Gao committed
6
import collections
Leo Gao's avatar
Leo Gao committed
7
import logging
Leo Gao's avatar
Leo Gao committed
8

Leo Gao's avatar
Leo Gao committed
9
from lm_eval import models, tasks, evaluator, base
Jason Phang's avatar
lib  
Jason Phang committed
10

Leo Gao's avatar
Leo Gao committed
11
# Raise the threshold of the "openai" logger to WARNING so its INFO/DEBUG
# records are not emitted during evaluation (presumably per-request logging
# from the OpenAI client library used by an API-backed model — confirm).
logging.getLogger("openai").setLevel(logging.WARNING)
Leo Gao's avatar
Leo Gao committed
12

Jason Phang's avatar
Jason Phang committed
13
14
15
16
17
18
def parse_args(argv=None):
    """Parse the evaluation script's command-line options.

    Args:
        argv: list of argument strings to parse. When None (the default),
            argparse reads ``sys.argv[1:]``, exactly as before; passing a list
            makes the parser usable programmatically and in tests.

    Returns:
        argparse.Namespace with attributes ``model``, ``model_args``,
        ``tasks``, ``provide_description``, ``num_fewshot``, ``seed``,
        ``output_path``, ``limit`` and ``no_cache``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', required=True,
                        help="model name, looked up with lm_eval.models.get_model")
    parser.add_argument('--model_args', default="",
                        help="argument string forwarded to the model's create_from_arg_string")
    parser.add_argument('--tasks', default="all_tasks",
                        help='comma-separated task names, or "all_tasks" for every task')
    parser.add_argument('--provide_description', action="store_true",
                        help="forwarded as provide_description to evaluator.evaluate")
    parser.add_argument('--num_fewshot', type=int, default=0,
                        help="forwarded as num_fewshot to evaluator.evaluate")
    parser.add_argument('--seed', type=int, default=1234,
                        help="seed for Python's random module and NumPy's global RNG")
    parser.add_argument('--output_path', default=None,
                        help="if given, also write the JSON results to this file")
    parser.add_argument('--limit', type=int, default=None,
                        help="forwarded to evaluator.evaluate; for testing only, not for real metrics")
    parser.add_argument('--no_cache', action="store_true",
                        help="do not wrap the model in base.CachingLM")
    return parser.parse_args(argv)

def _cache_path(model, model_args):
    """Return the cache database path for a model name and its argument string.

    '=' and ',' in ``model_args`` are rewritten as before. '/' is rewritten
    as well: without that, an argument such as ``pretrained=org/name``
    produced a path pointing into a non-existent subdirectory of lm_cache/.
    """
    safe_args = model_args.replace('=', '-').replace(',', '_').replace('/', '-')
    return 'lm_cache/' + model + '_' + safe_args + '.db'


def _resolve_task_names(spec):
    """Expand the --tasks value into a task-name selection.

    "all_tasks" selects ``tasks.ALL_TASKS``. Any other value is split on
    commas; surrounding whitespace is stripped and empty entries (e.g. from a
    trailing comma) are dropped.
    """
    if spec == "all_tasks":
        return tasks.ALL_TASKS
    return [name.strip() for name in spec.split(",") if name.strip()]


def _format_cell(value):
    """Format a metric value to 4 decimals, falling back to str() for non-numbers."""
    try:
        return '%.4f' % value
    except TypeError:
        # A non-numeric metric (string, list, None, ...) used to crash the
        # table after the JSON results had already been printed/written.
        return str(value)


def _make_table(results):
    """Render ``{task: {metric: value}}`` results as a Markdown table string.

    The task name appears only on the first row of its metrics.
    """
    from pytablewriter import MarkdownTableWriter

    rows = []
    for task_name, metrics in results.items():
        label = task_name
        for metric, value in metrics.items():
            rows.append([label, metric, _format_cell(value)])
            label = ""  # blank repeated task names for readability

    writer = MarkdownTableWriter()
    writer.headers = ["Task", "Metric", "Value"]
    writer.value_matrix = rows
    return writer.dumps()


def main():
    """Run the evaluation described by the command-line arguments.

    Steps:
    1. Seed Python's ``random`` and NumPy's global RNG.
    2. Build the model, wrapping it in ``base.CachingLM`` unless --no_cache.
    3. Evaluate the selected tasks.
    4. Print the results as indented JSON, also writing them to --output_path
       when given.
    5. Print a Markdown summary table.
    """
    import os  # only needed to create the cache directory

    args = parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)

    lm = models.get_model(args.model).create_from_arg_string(args.model_args)

    if args.limit:
        print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")

    if not args.no_cache:
        cache_path = _cache_path(args.model, args.model_args)
        # NOTE(review): created up front in case base.CachingLM does not create
        # parent directories for its (presumably on-disk) database — confirm.
        os.makedirs(os.path.dirname(cache_path), exist_ok=True)
        lm = base.CachingLM(lm, cache_path)

    task_dict = tasks.get_task_dict(_resolve_task_names(args.tasks))

    results = evaluator.evaluate(lm, task_dict, args.provide_description, args.num_fewshot, args.limit)

    dumped = json.dumps(results, indent=2)
    print(dumped)
    if args.output_path:
        with open(args.output_path, "w", encoding="utf-8") as f:
            f.write(dumped)

    print(_make_table(results))

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()