main.py 4.72 KB
Newer Older
Jason Phang's avatar
Jason Phang committed
1
2
import argparse
import json
Leo Gao's avatar
Leo Gao committed
3
import logging
4
import fnmatch
Leo Gao's avatar
Leo Gao committed
5

6
from lm_eval import tasks, evaluator
Jason Phang's avatar
lib  
Jason Phang committed
7

Leo Gao's avatar
Leo Gao committed
8
# Raise the threshold of the "openai" logger to WARNING, suppressing its
# lower-severity (INFO/DEBUG) records.
logging.getLogger(name="openai").setLevel(level=logging.WARNING)
Leo Gao's avatar
Leo Gao committed
9

10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class MultiChoice:
    """Argparse ``choices`` container accepting comma-separated glob patterns.

    ``"a,b*" in MultiChoice(names)`` is True only when every comma-separated
    pattern matches at least one entry of ``names`` (``fnmatch`` rules, i.e.
    shell-style wildcards). Iterating yields the wrapped choices unchanged.
    """

    def __init__(self, choices):
        self.choices = choices

    def __contains__(self, values):
        # A single pattern with no match rejects the whole comma-separated value.
        return all(fnmatch.filter(self.choices, pattern) for pattern in values.split(","))

    def __iter__(self):
        return iter(self.choices)

# Names of the first direct base class of every registered task: the vocabulary
# accepted by --task_type. Sorted so argparse help/error listings are stable
# across runs (iteration order of a set of str depends on the hash seed).
task_types = sorted({task.__bases__[0].__name__ for task in tasks.TASK_REGISTRY.values()})

Jason Phang's avatar
Jason Phang committed
29
30
31
32
def parse_args(argv=None):
    """Build the command-line parser and parse ``argv``.

    Args:
        argv: Argument list to parse; ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, exactly as before.

    Returns:
        The ``argparse.Namespace`` of parsed options. ``--tasks``,
        ``--task_type`` and ``--exclude_tasks`` take comma-separated
        ``fnmatch`` patterns validated through ``MultiChoice``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', required=True)
    parser.add_argument('--model_args', default="")
    parser.add_argument('--tasks', default=None, choices=MultiChoice(tasks.ALL_TASKS))
    parser.add_argument('--task_type', default=None, choices=MultiChoice(task_types))
    parser.add_argument('--exclude_tasks', default=None, choices=MultiChoice(tasks.ALL_TASKS))
    parser.add_argument('--provide_description', action="store_true")
    parser.add_argument('--num_fewshot', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=None)
    parser.add_argument('--device', type=str, default=None)
    parser.add_argument('--output_path', default=None)
    parser.add_argument('--limit', type=int, default=None)
    parser.add_argument('--no_cache', action="store_true")
    # Path to a JSON file that main() loads into the ``description_dict``
    # passed to evaluator.simple_evaluate. It must be registered here:
    # main() reads args.description_dict_path.
    parser.add_argument('--description_dict_path', default=None)
    parser.add_argument('--decontaminate', action="store_true")
    parser.add_argument('--ngrams_path', default=None)
    parser.add_argument('--ngrams_n_size', type=int, default=None)

    return parser.parse_args(argv)

49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def ensure_correct_decontamination_params(args):
    """Check that --decontaminate comes with the n-gram options it needs.

    Prints one message per missing option (``--ngrams_n_size`` first, then
    ``--ngrams_path``) and returns False if any is missing; returns True when
    decontamination is off or both options are set.
    """
    if not args.decontaminate:
        return True

    problems = []
    if not args.ngrams_n_size:
        problems.append("Please specify n size of training set n-grams. (--ngrams_n_size)")
    if not args.ngrams_path:
        problems.append("Please specify path containing training set n-grams. (--ngrams_path)")

    for problem in problems:
        print(problem)
    return not problems

def pattern_match(patterns, source_list):
    """Return the entries of ``source_list`` matching at least one glob pattern.

    Args:
        patterns: Iterable of ``fnmatch`` patterns (``*``, ``?``, ``[seq]``).
        source_list: Iterable of candidate names.

    Returns:
        A list of matching names, each appearing once, in the order it first
        occurs in ``source_list``. Keeping source order (instead of building a
        set) makes the result identical from run to run regardless of the
        interpreter's hash seed.
    """
    # Materialize once: the patterns are scanned again for every candidate.
    patterns = list(patterns)
    # dict.fromkeys de-duplicates while preserving first-occurrence order.
    matched = dict.fromkeys(
        name
        for name in source_list
        if any(fnmatch.fnmatch(name, pattern) for pattern in patterns)
    )
    return list(matched)

Jason Phang's avatar
Jason Phang committed
70
def _select_task_names(args):
    """Resolve the --tasks / --task_type / --exclude_tasks options to task names.

    Precedence: ``--tasks`` patterns win over ``--task_type`` patterns; with
    neither given, every name in ``tasks.ALL_TASKS`` is selected. Names
    matching ``--exclude_tasks`` are then removed. All options are
    comma-separated ``fnmatch`` patterns.
    """
    if args.tasks is not None:
        task_names = pattern_match(args.tasks.split(","), tasks.ALL_TASKS)
    elif args.task_type is not None:
        # --task_type is validated by MultiChoice, which accepts wildcards, so
        # the filter must use the same glob matching. Exact membership made a
        # value like "Multi*" pass validation yet select no tasks.
        type_patterns = args.task_type.split(",")
        task_names = [
            name
            for name, task_class in tasks.TASK_REGISTRY.items()
            if any(
                fnmatch.fnmatch(task_class.__bases__[0].__name__, pattern)
                for pattern in type_patterns
            )
        ]
    else:
        task_names = list(tasks.ALL_TASKS)

    if args.exclude_tasks:
        excluded = set(pattern_match(args.exclude_tasks.split(","), task_names))
        task_names = [name for name in task_names if name not in excluded]

    return task_names


def main():
    """Command-line entry point: parse options, run the evaluation, report results.

    Prints the JSON-serialized results and ``evaluator.make_table(results)``
    to stdout, and also writes the JSON to ``--output_path`` when given.

    Raises:
        NotImplementedError: if ``--provide_description`` is passed.
    """
    args = parse_args()

    if not ensure_correct_decontamination_params(args):
        return

    # Explicit check rather than ``assert`` so it is not stripped by ``python -O``.
    if args.provide_description:
        raise NotImplementedError("--provide_description is not implemented")

    if args.limit:
        print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")

    task_names = _select_task_names(args)
    if not task_names:
        print("You must have excluded the tasks you specified, exiting.")
        return

    print(f"Selected Tasks: {task_names}")

    # Tolerate a Namespace without --description_dict_path: treat it as
    # "no descriptions" instead of failing with AttributeError.
    description_dict = {}
    description_dict_path = getattr(args, "description_dict_path", None)
    if description_dict_path:
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
        with open(description_dict_path, "r", encoding="utf-8") as f:
            description_dict = json.load(f)

    results = evaluator.simple_evaluate(
        model=args.model,
        model_args=args.model_args,
        tasks=task_names,
        num_fewshot=args.num_fewshot,
        batch_size=args.batch_size,
        device=args.device,
        no_cache=args.no_cache,
        limit=args.limit,
        description_dict=description_dict,
        decontaminate=args.decontaminate,
        ngrams_path=args.ngrams_path,
        ngrams_n_size=args.ngrams_n_size,
    )

    dumped = json.dumps(results, indent=2)
    print(dumped)

    if args.output_path:
        with open(args.output_path, "w", encoding="utf-8") as f:
            f.write(dumped)

    print(
        f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, "
        f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}"
    )
    print(evaluator.make_table(results))
Jason Phang's avatar
lib  
Jason Phang committed
139

140

Jason Phang's avatar
Jason Phang committed
141
if __name__ == "__main__":
    # Start the CLI only when run as a script; importing the module does nothing else.
    main()