main.py 3.03 KB
Newer Older
Jason Phang's avatar
Jason Phang committed
1
2
import argparse
import json
Jason Phang's avatar
seed  
Jason Phang committed
3
4
import numpy as np
import random
Leo Gao's avatar
Leo Gao committed
5
import logging
Leo Gao's avatar
Leo Gao committed
6

Leo Gao's avatar
Leo Gao committed
7
from lm_eval import models, tasks, evaluator, base
Jason Phang's avatar
lib  
Jason Phang committed
8

Leo Gao's avatar
Leo Gao committed
9
# Raise the threshold of the logger named "openai" to WARNING so that its
# lower-severity (DEBUG/INFO) records are not emitted while this script runs.
logging.getLogger("openai").setLevel(logging.WARNING)
Leo Gao's avatar
Leo Gao committed
10

Jason Phang's avatar
Jason Phang committed
11
12
13
14
15
16
def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so
            existing zero-argument callers behave exactly as before.

    Returns:
        An ``argparse.Namespace`` with the attributes ``model`` (required),
        ``model_args``, ``tasks``, ``provide_description``, ``num_fewshot``,
        ``batch_size``, ``device``, ``seed``, ``output_path``, ``limit`` and
        ``no_cache``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', required=True)
    parser.add_argument('--model_args', default="")
    # Either the sentinel "all_tasks" or a comma-separated list of task names.
    parser.add_argument('--tasks', default="all_tasks")
    parser.add_argument('--provide_description', action="store_true")
    parser.add_argument('--num_fewshot', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=None)
    parser.add_argument('--device', type=str, default=None)
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--output_path', default=None)
    parser.add_argument('--limit', type=int, default=None)
    parser.add_argument('--no_cache', action="store_true")

    return parser.parse_args(argv)

def _build_table_rows(results):
    """Flatten ``results`` into the rows of the summary table.

    Reads ``results["results"]`` (task name -> {metric name -> value}) and
    ``results["versions"]`` (task name -> version). Metrics whose name ends in
    ``_stderr`` do not get a row of their own; they fill the "Stderr" column
    of the metric they belong to. The task name and version appear only on
    the first row emitted for each task; later rows carry empty strings there.

    Returns:
        A list of 6-element lists matching the headers
        ``["Task", "Version", "Metric", "Value", "", "Stderr"]``.
    """
    rows = []
    for task_name, metrics in results["results"].items():
        task_label = task_name
        version_label = results["versions"][task_name]
        for metric, value in metrics.items():
            if metric.endswith("_stderr"):
                continue

            stderr_key = metric + "_stderr"
            if stderr_key in metrics:
                rows.append([task_label, version_label, metric,
                             '%.4f' % value, '±', '%.4f' % metrics[stderr_key]])
            else:
                rows.append([task_label, version_label, metric,
                             '%.4f' % value, '', ''])

            # Only the first emitted row of a task shows its name and version.
            task_label = ""
            version_label = ""
    return rows


def main():
    """Run the evaluation described by the command-line arguments.

    Seeds ``random`` and NumPy with ``--seed``, builds the model named by
    ``--model``, evaluates it on the selected tasks, prints the results as
    indented JSON (also writing them to ``--output_path`` when one is given),
    and finally prints a one-line run summary plus a Markdown results table.
    """
    args = parse_args()
    random.seed(args.seed)
    np.random.seed(args.seed)

    lm = models.get_model(args.model).create_from_arg_string(args.model_args, {
        'batch_size': args.batch_size, 'device': args.device
    })

    # Compare against None so that an explicit `--limit 0` is also flagged.
    if args.limit is not None:
        print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")

    if not args.no_cache:
        # The cache file name is derived from the model name and its argument
        # string, with '=', ',' and '/' replaced so it forms a single path
        # component under lm_cache/.
        lm = base.CachingLM(lm, 'lm_cache/' + args.model + '_' + args.model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db')

    if args.tasks == "all_tasks":
        task_names = tasks.ALL_TASKS
    else:
        task_names = args.tasks.split(",")
    task_dict = tasks.get_task_dict(task_names)

    results = evaluator.evaluate(lm, task_dict, args.provide_description, args.num_fewshot, args.limit)

    dumped = json.dumps(results, indent=2)
    print(dumped)
    if args.output_path:
        with open(args.output_path, "w") as f:
            f.write(dumped)

    # MAKE TABLE
    # Imported here, as in the original, so pytablewriter is only needed once
    # the evaluation has finished and a table is actually rendered.
    from pytablewriter import MarkdownTableWriter, LatexTableWriter

    headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = list(headers)
    latex_writer.headers = list(headers)

    values = _build_table_rows(results)
    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # todo: make latex table look good
    # print(latex_writer.dumps())

    print(f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}")
    print(md_writer.dumps())
Jason Phang's avatar
lib  
Jason Phang committed
86

Jason Phang's avatar
Jason Phang committed
87
# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()