main.py 4.42 KB
Newer Older
Jason Phang's avatar
Jason Phang committed
1
2
import argparse
import json
Leo Gao's avatar
Leo Gao committed
3
import logging
4
import fnmatch
Leo Gao's avatar
Leo Gao committed
5

6
from lm_eval import tasks, evaluator
Jason Phang's avatar
lib  
Jason Phang committed
7

Leo Gao's avatar
Leo Gao committed
8
# Raise the "openai" logger's threshold to WARNING so its lower-severity
# (INFO/DEBUG) records are not emitted during evaluation.
logging.getLogger("openai").setLevel(logging.WARNING)
Leo Gao's avatar
Leo Gao committed
9

10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class MultiChoice:
    """Argparse ``choices`` container accepting comma-separated wildcard patterns.

    ``value in MultiChoice(names)`` is true only when *every* comma-separated
    piece of ``value`` matches at least one entry of ``names`` using
    shell-style (``fnmatch``) patterns, e.g. ``"arc_*,boolq"``.
    Iterating the container yields the underlying choices (argparse uses this
    when listing valid values in error messages).
    """

    def __init__(self, choices):
        self.choices = choices

    def __contains__(self, values):
        # Reject as soon as any single pattern matches none of the choices.
        return all(
            fnmatch.filter(self.choices, pattern)
            for pattern in values.split(",")
        )

    def __iter__(self):
        yield from self.choices

# Distinct names of the first base class of every registered task; these are
# the values accepted by --task_type. Sorted so the order (as shown by
# argparse) is stable across runs -- iterating a set of strings is not,
# because string hashing is randomized per process.
task_types = sorted({task.__bases__[0].__name__ for task in tasks.TASK_REGISTRY.values()})

Jason Phang's avatar
Jason Phang committed
29
30
31
32
def parse_args():
    """Build this script's command-line parser and parse ``sys.argv``.

    ``--tasks``, ``--task_type`` and ``--exclude_tasks`` take comma-separated,
    shell-style wildcard patterns that are validated by ``MultiChoice``.

    Returns:
        argparse.Namespace with one attribute per option registered below.
    """
    parser = argparse.ArgumentParser()

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = [
        ('--model', dict(required=True)),
        ('--model_args', dict(default="")),
        # Task selection.
        ('--tasks', dict(default=None, choices=MultiChoice(tasks.ALL_TASKS))),
        ('--task_type', dict(default=None, choices=MultiChoice(task_types))),
        ('--exclude_tasks', dict(default=None, choices=MultiChoice(tasks.ALL_TASKS))),
        # Evaluation settings.
        ('--provide_description', dict(action="store_true")),
        ('--num_fewshot', dict(type=int, default=0)),
        ('--batch_size', dict(type=int, default=None)),
        ('--device', dict(type=str, default=None)),
        ('--output_path', dict(default=None)),
        ('--limit', dict(type=int, default=None)),
        ('--no_cache', dict(action="store_true")),
        # Decontamination against training-set n-grams.
        ('--decontaminate', dict(action="store_true")),
        ('--ngrams_path', dict(default=None)),
        ('--ngrams_n_size', dict(type=int, default=None)),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()

49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def ensure_correct_decontamination_params(args):
    """Check that the options required by ``--decontaminate`` were supplied.

    When decontamination is requested, both ``--ngrams_n_size`` and
    ``--ngrams_path`` must be truthy; a message is printed for each one
    that is missing.

    Returns:
        True if decontamination is disabled or fully configured, else False.
    """
    if not args.decontaminate:
        return True

    missing = []
    if not args.ngrams_n_size:
        missing.append("Please specify n size of training set n-grams. (--ngrams_n_size)")
    if not args.ngrams_path:
        missing.append("Please specify path containing training set n-grams. (--ngrams_path)")

    for message in missing:
        print(message)
    return not missing

def pattern_match(patterns, source_list):
    """Return the entries of ``source_list`` matching at least one pattern.

    Patterns use shell-style wildcards (``fnmatch``). Each matching entry
    appears once. The result is sorted so that task selection -- and hence
    evaluation order and printed/JSON output -- is reproducible across runs.
    Returning ``list(set)`` directly is not reproducible, because string
    hashing is randomized per process.

    Args:
        patterns: iterable of fnmatch-style pattern strings.
        source_list: candidate names to match against.

    Returns:
        Sorted list of distinct matching names (empty if nothing matches).
    """
    task_names = set()
    for pattern in patterns:
        task_names.update(fnmatch.filter(source_list, pattern))
    return sorted(task_names)

Jason Phang's avatar
Jason Phang committed
70
def _select_tasks(args):
    """Resolve ``--tasks``/``--task_type``/``--exclude_tasks`` into task names.

    - ``--task_type``: registry entries whose first base class name is one of
      the comma-separated type names.
    - ``--tasks``: entries of ``tasks.ALL_TASKS`` matching any of the
      comma-separated wildcard patterns.
    - Both given: only tasks satisfying both filters are kept. Previously
      ``--tasks`` silently overrode ``--task_type``.
    - Neither given: ``tasks.ALL_TASKS``.
    - ``--exclude_tasks`` patterns are then removed from the selection.
    """
    # Names of registered tasks whose first base class matches a requested type.
    typed_names = None
    if args.task_type:
        wanted_types = set(args.task_type.split(","))
        typed_names = [
            name
            for name, task in tasks.TASK_REGISTRY.items()
            if task.__bases__[0].__name__ in wanted_types
        ]

    if args.tasks is not None:
        task_names = pattern_match(args.tasks.split(","), tasks.ALL_TASKS)
        if typed_names is not None:
            # Both filters requested: keep the intersection.
            allowed = set(typed_names)
            task_names = [name for name in task_names if name in allowed]
    elif typed_names is not None:
        task_names = typed_names
    else:
        task_names = tasks.ALL_TASKS

    if args.exclude_tasks:
        excluded = set(pattern_match(args.exclude_tasks.split(","), task_names))
        task_names = [name for name in task_names if name not in excluded]

    return task_names


def main():
    """Command-line entry point: evaluate a model on the selected tasks.

    Parses arguments and validates the decontamination options. It then
    resolves the task list and runs ``evaluator.simple_evaluate``. Results
    are printed as indented JSON and as a table. The JSON is also written
    to ``--output_path`` when that option is given.
    """
    args = parse_args()
    if not ensure_correct_decontamination_params(args):
        return

    # NOTE(review): --provide_description is parsed and echoed in the summary
    # line below but is not forwarded to simple_evaluate -- confirm intent.

    if args.limit:
        print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")

    task_names = _select_tasks(args)
    if not task_names:
        print("No tasks match the given --tasks/--task_type/--exclude_tasks filters, exiting.")
        return

    print(f"Selected Tasks: {task_names}")

    results = evaluator.simple_evaluate(
        args.model, args.model_args, task_names,
        num_fewshot=args.num_fewshot, batch_size=args.batch_size,
        device=args.device, no_cache=args.no_cache, limit=args.limit,
        decontaminate=args.decontaminate, ngrams_path=args.ngrams_path,
        ngrams_n_size=args.ngrams_n_size,
    )

    dumped = json.dumps(results, indent=2)
    print(dumped)

    if args.output_path:
        with open(args.output_path, "w", encoding="utf-8") as f:
            f.write(dumped)

    print(f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}")
    print(evaluator.make_table(results))
Jason Phang's avatar
lib  
Jason Phang committed
118

Jason Phang's avatar
Jason Phang committed
119
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()