main.py 2.15 KB
Newer Older
Jason Phang's avatar
Jason Phang committed
1
2
import argparse
import json
Leo Gao's avatar
Leo Gao committed
3
import logging
Leo Gao's avatar
Leo Gao committed
4

5
from lm_eval import tasks, evaluator
Jason Phang's avatar
lib  
Jason Phang committed
6

Leo Gao's avatar
Leo Gao committed
7
# Runs at import time: set the "openai" logger's threshold to WARNING so
# records below WARNING (INFO/DEBUG) logged through it are dropped.
logging.getLogger("openai").setLevel(logging.WARNING)
Leo Gao's avatar
Leo Gao committed
8

9

Jason Phang's avatar
Jason Phang committed
10
11
12
13
14
def parse_args(argv=None):
    """Build the evaluation CLI parser and parse command-line arguments.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. The default ``None`` makes argparse read
            ``sys.argv[1:]``, which is the original behavior.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', required=True,
                        help="Model name passed to evaluator.simple_evaluate.")
    parser.add_argument('--model_args', default="",
                        help="Model argument string passed to evaluator.simple_evaluate.")
    parser.add_argument('--tasks', default="all_tasks",
                        help="Comma-separated task names, or 'all_tasks' for every task.")
    parser.add_argument('--provide_description', action="store_true",
                        help="Not implemented; main() aborts if this flag is set.")
    parser.add_argument('--num_fewshot', type=int, default=0,
                        help="Number of few-shot examples (default 0).")
    parser.add_argument('--batch_size', type=int, default=None,
                        help="Batch size passed to evaluator.simple_evaluate.")
    parser.add_argument('--device', type=str, default=None,
                        help="Device passed to evaluator.simple_evaluate.")
    parser.add_argument('--output_path', default=None,
                        help="If set, also write the JSON results to this file.")
    parser.add_argument('--limit', type=int, default=None,
                        help="Limit passed to evaluator.simple_evaluate; for testing only.")
    parser.add_argument('--no_cache', action="store_true",
                        help="Passed as no_cache to evaluator.simple_evaluate.")
    parser.add_argument('--description_dict_path', default=None,
                        help="Path to a JSON file passed as description_dict to evaluator.simple_evaluate.")
    # parse_args(None) falls back to sys.argv[1:], preserving CLI behavior.
    return parser.parse_args(argv)

Leo Gao's avatar
Leo Gao committed
25

26
def main():
    """Command-line entry point: evaluate a model on tasks and report results.

    Parses the CLI flags with ``parse_args`` and runs
    ``evaluator.simple_evaluate`` with them. It prints the results as indented
    JSON, and also writes that JSON to ``--output_path`` when one is given.
    Finally it prints a one-line run summary and the table from
    ``evaluator.make_table``.

    Raises:
        NotImplementedError: if ``--provide_description`` is passed; the flag
            is accepted by the parser but not supported.
        ValueError: if ``--tasks`` contains no task names, or if the file at
            ``--description_dict_path`` does not hold a JSON object.
    """
    args = parse_args()

    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if args.provide_description:
        raise NotImplementedError("--provide_description is not implemented")

    # Compare against None so ``--limit 0`` (still forwarded below) also warns.
    if args.limit is not None:
        print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")

    if args.tasks == "all_tasks":
        task_names = tasks.ALL_TASKS
    else:
        # Tolerate spaces after commas and stray commas: "a, b," -> ["a", "b"].
        task_names = [name.strip() for name in args.tasks.split(",") if name.strip()]
        if not task_names:
            raise ValueError(f"--tasks did not contain any task names: {args.tasks!r}")

    description_dict = {}
    if args.description_dict_path:
        # Explicit UTF-8 so non-ASCII descriptions load identically on every platform.
        with open(args.description_dict_path, "r", encoding="utf-8") as f:
            description_dict = json.load(f)
        if not isinstance(description_dict, dict):
            raise ValueError(
                f"{args.description_dict_path} must contain a JSON object, "
                f"got {type(description_dict).__name__}"
            )

    results = evaluator.simple_evaluate(
        model=args.model,
        model_args=args.model_args,
        task_names=task_names,
        num_fewshot=args.num_fewshot,
        batch_size=args.batch_size,
        device=args.device,
        no_cache=args.no_cache,
        limit=args.limit,
        description_dict=description_dict,
    )

    dumped = json.dumps(results, indent=2)
    print(dumped)

    if args.output_path:
        with open(args.output_path, "w", encoding="utf-8") as f:
            f.write(dumped)

    print(
        f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, "
        f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}"
    )
    print(evaluator.make_table(results))
Jason Phang's avatar
lib  
Jason Phang committed
68

69

Jason Phang's avatar
Jason Phang committed
70
if __name__ == "__main__":
    # Only run the evaluation CLI when executed as a script, never on import.
    main()