import argparse
import json
import logging

from lm_eval import tasks, evaluator
# Suppress INFO/DEBUG chatter from the openai client; only warnings and above pass through.
logging.getLogger("openai").setLevel(logging.WARNING)


def parse_args():
    """Define and parse the command-line interface for the evaluation script."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", required=True)
    cli.add_argument("--model_args", default="")
    cli.add_argument("--tasks", default="all_tasks")
    cli.add_argument("--provide_description", action="store_true")
    cli.add_argument("--num_fewshot", type=int, default=0)
    cli.add_argument("--batch_size", type=int, default=None)
    cli.add_argument("--device", type=str, default=None)
    cli.add_argument("--output_path", default=None)
    cli.add_argument("--limit", type=int, default=None)
    cli.add_argument("--no_cache", action="store_true")
    cli.add_argument("--description_dict_path", default=None)
    cli.add_argument("--check_integrity", action="store_true")
    return cli.parse_args()

def main():
    """Parse CLI arguments, run the requested evaluations, and print results.

    Results are printed as indented JSON (and optionally written to
    ``--output_path``), followed by a one-line run summary and a
    human-readable metrics table.
    """
    args = parse_args()

    if args.provide_description:
        # Raise explicitly instead of using `assert`: asserts are stripped
        # under `python -O`, which would silently ignore the unsupported flag.
        raise NotImplementedError("--provide_description is not implemented")

    if args.limit:
        print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")

    # Resolve the task list: either every registered task, or a
    # comma-separated subset given on the command line.
    if args.tasks == "all_tasks":
        task_names = tasks.ALL_TASKS
    else:
        task_names = args.tasks.split(",")

    # Optional descriptions loaded from a JSON file (presumably a mapping of
    # task name -> description string — confirm against simple_evaluate).
    description_dict = {}
    if args.description_dict_path:
        # JSON is UTF-8 by specification; pin the encoding rather than
        # relying on the platform default.
        with open(args.description_dict_path, "r", encoding="utf-8") as f:
            description_dict = json.load(f)

    results = evaluator.simple_evaluate(
        model=args.model,
        model_args=args.model_args,
        tasks=task_names,
        num_fewshot=args.num_fewshot,
        batch_size=args.batch_size,
        device=args.device,
        no_cache=args.no_cache,
        limit=args.limit,
        description_dict=description_dict,
        check_integrity=args.check_integrity,
    )

    dumped = json.dumps(results, indent=2)
    print(dumped)

    if args.output_path:
        with open(args.output_path, "w", encoding="utf-8") as f:
            f.write(dumped)

    # One-line run summary, then the per-task metrics table.
    print(
        f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, "
        f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}"
    )
    print(evaluator.make_table(results))


# Standard entry-point guard: run the CLI only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()