main.py 1.75 KB
Newer Older
Jason Phang's avatar
Jason Phang committed
1
2
import argparse
import json
import logging
import random

import numpy as np

from lm_eval import models, tasks, evaluator, base

# Raise the threshold of the "openai" logger so that only records at
# WARNING level or above are emitted from it.
logging.getLogger("openai").setLevel(logging.WARNING)
Leo Gao's avatar
Leo Gao committed
10

Jason Phang's avatar
Jason Phang committed
11
12
13
14
15
def parse_args(argv=None) -> argparse.Namespace:
    """Build the command-line parser and parse the evaluation options.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` reads ``sys.argv[1:]``, which is
            what a plain ``parse_args()`` call did before this parameter
            existed.

    Returns:
        An ``argparse.Namespace`` with the attributes ``model``,
        ``model_args``, ``tasks``, ``description_path``, ``num_fewshot``,
        ``batch_size``, ``device``, ``output_path``, ``limit`` and
        ``no_cache``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', required=True)
    parser.add_argument('--model_args', default="")
    parser.add_argument(
        '--tasks',
        default="all_tasks",
        help='comma-separated task names, or "all_tasks" for every task',
    )
    parser.add_argument('--description_path', default=None)
    parser.add_argument('--num_fewshot', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=None)
    parser.add_argument('--device', type=str, default=None)
    parser.add_argument(
        '--output_path',
        default=None,
        help='if given, the JSON results are also written to this file',
    )
    parser.add_argument(
        '--limit',
        type=int,
        default=None,
        help='testing only; real metrics should not be computed with it',
    )
    parser.add_argument('--no_cache', action="store_true")
    return parser.parse_args(argv)

def main():
    """Run one evaluation as described by the command-line arguments.

    Steps, in order: parse the arguments, warn if ``--limit`` is set,
    resolve the task list, call ``evaluator.simple_evaluate``, print the
    results as indented JSON, optionally write the same JSON to
    ``--output_path``, then print a one-line run summary followed by
    ``evaluator.make_table(results)``.
    """
    args = parse_args()

    if args.limit:
        print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")

    if args.tasks == "all_tasks":
        task_names = tasks.ALL_TASKS
    else:
        # Strip surrounding whitespace so `--tasks "a, b"` yields "a" and "b".
        # Empty entries are kept on purpose so a malformed list is still
        # rejected downstream rather than silently shortened.
        task_names = [name.strip() for name in args.tasks.split(",")]

    # Passed positionally, in exactly this order; the callee's signature is
    # not visible in this file, so the order must not be changed here.
    results = evaluator.simple_evaluate(
        args.model,
        args.model_args,
        task_names,
        args.description_path,
        args.num_fewshot,
        args.batch_size,
        args.device,
        args.no_cache,
        args.limit
    )

    dumped = json.dumps(results, indent=2)
    print(dumped)

    if args.output_path:
        # json.dumps escapes non-ASCII by default, so the explicit encoding
        # only makes the written bytes independent of the platform locale.
        with open(args.output_path, "w", encoding="utf-8") as f:
            f.write(dumped)

    print(f"{args.model} ({args.model_args}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}")
    print(evaluator.make_table(results))
Jason Phang's avatar
lib  
Jason Phang committed
59

Jason Phang's avatar
Jason Phang committed
60
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()