"""Command-line entry point: run an ``lm_eval`` evaluation and print the results."""

import argparse
import json
import logging

from lm_eval import tasks, evaluator

# Raise the "openai" logger's threshold to WARNING so that lower-severity
# records emitted through it are not shown during a run.
logging.getLogger("openai").setLevel(logging.WARNING)


def parse_args(argv=None):
    """Parse the command-line options for an evaluation run.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behaviour of the original zero-argument call.

    Returns:
        argparse.Namespace with the attributes ``model``, ``model_args``,
        ``tasks``, ``provide_description``, ``num_fewshot``, ``batch_size``,
        ``device``, ``output_path``, ``limit`` and ``no_cache``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', required=True)
    parser.add_argument('--model_args', default="")
    parser.add_argument(
        '--tasks',
        default="all_tasks",
        help='comma-separated task names, or "all_tasks" for every task in tasks.ALL_TASKS',
    )
    parser.add_argument(
        '--provide_description',
        action="store_true",
        help="not implemented; the run is rejected when this flag is set",
    )
    parser.add_argument('--num_fewshot', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=None)
    parser.add_argument('--device', type=str, default=None)
    parser.add_argument(
        '--output_path',
        default=None,
        help="if given, the JSON results are also written to this file",
    )
    parser.add_argument(
        '--limit',
        type=int,
        default=None,
        help="for testing only; real metrics should not be computed using a limit",
    )
    parser.add_argument('--no_cache', action="store_true")

    return parser.parse_args(argv)

def main():
    """Run the evaluation selected on the command line and report the results.

    Parses the arguments, resolves the task list, calls
    ``evaluator.simple_evaluate``, prints the results as indented JSON
    (also writing them to ``--output_path`` when one is given), and finally
    prints a one-line summary of the run settings followed by
    ``evaluator.make_table(results)``.

    Raises:
        NotImplementedError: if ``--provide_description`` was passed.
    """
    args = parse_args()

    # Explicit raise instead of `assert`, so the check still runs under
    # `python -O` (which strips assert statements).
    if args.provide_description:
        raise NotImplementedError("--provide_description is not implemented")

    # Compare against None so that an explicit `--limit 0` is also flagged.
    if args.limit is not None:
        print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")

    # "all_tasks" is the sentinel default; anything else is a comma-separated list.
    if args.tasks == "all_tasks":
        task_names = tasks.ALL_TASKS
    else:
        task_names = args.tasks.split(",")

    results = evaluator.simple_evaluate(
        model=args.model,
        model_args=args.model_args,
        task_names=task_names,
        num_fewshot=args.num_fewshot,
        batch_size=args.batch_size,
        device=args.device,
        no_cache=args.no_cache,
        limit=args.limit,
    )

    dumped = json.dumps(results, indent=2)

    print(dumped)

    if args.output_path:
        with open(args.output_path, "w", encoding="utf-8") as f:
            f.write(dumped)

    print(
        f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, "
        f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}"
    )
    print(evaluator.make_table(results))


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()