main.py 2.07 KB
Newer Older
Jason Phang's avatar
Jason Phang committed
1
2
import argparse
import json
Jason Phang's avatar
seed  
Jason Phang committed
3
4
import numpy as np
import random
Leo Gao's avatar
Leo Gao committed
5
import logging
Leo Gao's avatar
Leo Gao committed
6

Leo Gao's avatar
Leo Gao committed
7
from lm_eval import models, tasks, evaluator, base
Jason Phang's avatar
lib  
Jason Phang committed
8

Leo Gao's avatar
Leo Gao committed
9
# Raise the "openai" logger's threshold so only WARNING and above get through.
logging.getLogger("openai").setLevel(logging.WARNING)
Leo Gao's avatar
Leo Gao committed
10

Jason Phang's avatar
Jason Phang committed
11
12
13
14
15
16
def parse_args(argv=None):
    """Parse the command-line options for an evaluation run.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so
            calling ``parse_args()`` with no arguments behaves as before.

    Returns:
        An ``argparse.Namespace`` with the attributes ``model``,
        ``model_args``, ``tasks``, ``provide_description``, ``num_fewshot``,
        ``seed``, ``output_path``, ``limit`` and ``no_cache``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model', required=True,
        help="name of the model to evaluate (looked up with models.get_model)")
    parser.add_argument(
        '--model_args', default="",
        help="argument string handed to the model's create_from_arg_string")
    parser.add_argument(
        '--tasks', default="all_tasks",
        help='comma-separated task names, or "all_tasks" to use tasks.ALL_TASKS')
    parser.add_argument(
        '--provide_description', action="store_true",
        help="forwarded to evaluator.evaluate as the provide_description flag")
    parser.add_argument(
        '--num_fewshot', type=int, default=0,
        help="few-shot count forwarded to evaluator.evaluate")
    parser.add_argument(
        '--seed', type=int, default=1234,
        help="seed for Python's random module and numpy.random")
    parser.add_argument(
        '--output_path', default=None,
        help="if given, the JSON results are also written to this file")
    parser.add_argument(
        '--limit', type=int, default=None,
        help="limit forwarded to evaluator.evaluate; for testing only")
    parser.add_argument(
        '--no_cache', action="store_true",
        help="do not wrap the model in the on-disk CachingLM cache")
    return parser.parse_args(argv)

def main():
    """Run the evaluation described by the command-line arguments.

    Steps, in order:
      1. Parse the arguments and seed ``random`` and ``numpy.random``.
      2. Build the model via ``models.get_model`` and, unless ``--no_cache``
         was given, wrap it in ``base.CachingLM``.
      3. Resolve the task list and call ``evaluator.evaluate``.
      4. Print the results as JSON (also written to ``--output_path`` when
         given) and as a Markdown table of task / metric / value rows.
    """
    # Local import: keeps this block self-contained without touching the
    # file-level import section.
    import os

    args = parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)

    lm = models.get_model(args.model).create_from_arg_string(args.model_args)

    # Compare against None so that an explicit ``--limit 0`` also triggers
    # the warning (a plain truthiness test would skip it).
    if args.limit is not None:
        print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")

    if not args.no_cache:
        # Make sure the cache directory exists before handing the DB path to
        # CachingLM. NOTE(review): harmless if CachingLM already creates it.
        os.makedirs('lm_cache', exist_ok=True)
        # '=', ',' and '/' in the model args are replaced so the argument
        # string can be embedded in a file name.
        sanitized_args = args.model_args.replace('=', '-').replace(',', '_').replace('/', '-')
        lm = base.CachingLM(lm, 'lm_cache/' + args.model + '_' + sanitized_args + '.db')

    if args.tasks == "all_tasks":
        task_names = tasks.ALL_TASKS
    else:
        # Strip whitespace so that "--tasks 'a, b'" resolves the same as "a,b".
        task_names = [name.strip() for name in args.tasks.split(",")]
    task_dict = tasks.get_task_dict(task_names)

    results = evaluator.evaluate(lm, task_dict, args.provide_description, args.num_fewshot, args.limit)

    dumped = json.dumps(results, indent=2)
    print(dumped)
    if args.output_path:
        with open(args.output_path, "w", encoding="utf-8") as f:
            f.write(dumped)

    # MAKE TABLE
    # Imported here so pytablewriter is only needed once results exist.
    from pytablewriter import MarkdownTableWriter

    writer = MarkdownTableWriter()
    writer.headers = ["Task", "Metric", "Value"]

    rows = []
    for task_name, metrics in results.items():
        for index, (metric, value) in enumerate(metrics.items()):
            # Show the task name only on the first row of each task's group.
            rows.append([task_name if index == 0 else "", metric, '%.4f' % value])
    writer.value_matrix = rows

    print(writer.dumps())
Jason Phang's avatar
lib  
Jason Phang committed
66

Jason Phang's avatar
Jason Phang committed
67
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()