main.py 4.68 KB
Newer Older
Jason Phang's avatar
Jason Phang committed
1
2
import argparse
import json
Leo Gao's avatar
Leo Gao committed
3
import logging
4
import fnmatch
Leo Gao's avatar
Leo Gao committed
5

6
from lm_eval import tasks, evaluator
Jason Phang's avatar
lib  
Jason Phang committed
7

Leo Gao's avatar
Leo Gao committed
8
# Only let WARNING-and-above records from the "openai" logger through, so its
# INFO/DEBUG messages do not interleave with this script's output.
logging.getLogger("openai").setLevel(logging.WARNING)
Leo Gao's avatar
Leo Gao committed
9

10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class MultiChoice:
    """Argparse ``choices`` container that accepts comma-separated wildcards.

    A value such as ``"a*,b"`` is accepted when every comma-separated piece
    matches at least one entry of ``choices`` under :func:`fnmatch.filter`
    (shell-style ``*``, ``?`` and ``[seq]`` patterns).
    """

    def __init__(self, choices):
        self.choices = choices

    def __contains__(self, values):
        # Reject as soon as one pattern matches no known choice.
        return all(
            fnmatch.filter(self.choices, pattern) for pattern in values.split(",")
        )

    def __iter__(self):
        # Argparse iterates the container when listing valid choices.
        return iter(self.choices)

# Names of the immediate base classes of every registered task; these are the
# valid values for --task_type. Sorted so the choices (and any argparse message
# that lists them) come out in the same order on every run -- iterating a bare
# set of strings follows hash order, which changes with PYTHONHASHSEED.
task_types = sorted({task.__bases__[0].__name__ for task in tasks.TASK_REGISTRY.values()})

Jason Phang's avatar
Jason Phang committed
29
30
31
32
def parse_args():
    """Define the command-line options and parse them from ``sys.argv``.

    Returns:
        argparse.Namespace holding one attribute per option defined below.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Model selection.
    add("--model", required=True)
    add("--model_args", default="")

    # Task selection: comma-separated values, each validated by MultiChoice,
    # which accepts fnmatch-style wildcards.
    add("--tasks", default=None, choices=MultiChoice(tasks.ALL_TASKS))
    add("--task_type", default=None, choices=MultiChoice(task_types))
    add("--exclude_tasks", default=None, choices=MultiChoice(tasks.ALL_TASKS))

    # Evaluation settings.
    add("--provide_description", action="store_true")
    add("--num_fewshot", type=int, default=0)
    add("--batch_size", type=int, default=None)
    add("--device", type=str, default=None)
    add("--output_path", default=None)
    add("--limit", type=int, default=None)
    add("--no_cache", action="store_true")

    # Decontamination options.
    add("--decontaminate", action="store_true")
    add("--ngrams_path", default=None)
    add("--ngrams_n_size", type=int, default=None)

    # Optional JSON file loaded into the evaluator's description_dict.
    add("--description_dict_path", default=None)

    parsed = parser.parse_args()
    return parsed

50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def ensure_correct_decontamination_params(args):
    """Check that ``--decontaminate`` comes with the options it needs.

    When decontamination is requested, a message is printed for each missing
    option (``--ngrams_n_size`` first, then ``--ngrams_path``). A falsy value
    counts as missing.

    Returns:
        True when decontamination is off or fully configured, otherwise False.
    """
    if not args.decontaminate:
        return True

    missing_n_size = not args.ngrams_n_size
    if missing_n_size:
        print("Please specify n size of training set n-grams. (--ngrams_n_size)")

    missing_path = not args.ngrams_path
    if missing_path:
        print("Please specify path containing training set n-grams. (--ngrams_path)")

    return not (missing_n_size or missing_path)

def pattern_match(patterns, source_list):
    """Return the values of *source_list* that match at least one pattern.

    Patterns are shell-style wildcards as understood by :func:`fnmatch.filter`.
    Each matching value appears once, in the order of its first occurrence in
    *source_list*. The order is therefore stable across runs, unlike iterating
    a ``set`` of strings, whose order depends on per-process hash seeding.

    Args:
        patterns: iterable of wildcard patterns.
        source_list: iterable of candidate names.

    Returns:
        list of matching names without duplicates.
    """
    # Materialize once: the candidates are scanned once per pattern and again
    # to build the ordered result.
    candidates = list(source_list)
    matched = set()
    for pattern in patterns:
        matched.update(fnmatch.filter(candidates, pattern))
    # dict.fromkeys de-duplicates while keeping first-occurrence order.
    return [name for name in dict.fromkeys(candidates) if name in matched]

71
def _select_task_names(args):
    """Turn the task-selection CLI options into the list of tasks to run.

    The options are applied as successive filters:

    1. ``--task_type`` (optional): start from the registered tasks whose
       immediate base-class name matches any comma-separated pattern;
       without it, start from ``tasks.ALL_TASKS``.
    2. ``--tasks`` (optional): keep the names matching any comma-separated
       pattern, within the selection from step 1.
    3. ``--exclude_tasks`` (optional): drop the names matching any pattern.

    All patterns are fnmatch-style wildcards, matching what ``MultiChoice``
    accepts when argparse validates the options.
    """
    if args.task_type:
        type_patterns = args.task_type.split(",")
        # Wildcards pass MultiChoice validation, so match base-class names with
        # fnmatch rather than exact membership (which would select nothing).
        task_names = [
            name
            for name, task in tasks.TASK_REGISTRY.items()
            if any(
                fnmatch.fnmatch(task.__bases__[0].__name__, pattern)
                for pattern in type_patterns
            )
        ]
    else:
        task_names = list(tasks.ALL_TASKS)

    if args.tasks is not None:
        # Narrow the --task_type selection instead of silently discarding it.
        task_names = pattern_match(args.tasks.split(","), task_names)

    if args.exclude_tasks:
        excluded = set(pattern_match(args.exclude_tasks.split(","), task_names))
        task_names = [name for name in task_names if name not in excluded]

    return task_names


def main():
    """CLI entry point: parse options, run the evaluation and report results.

    Results are printed as indented JSON, optionally written to
    ``--output_path``, and summarized with ``evaluator.make_table``.

    Raises:
        NotImplementedError: if ``--provide_description`` is passed.
    """
    args = parse_args()

    if not ensure_correct_decontamination_params(args):
        return

    if args.provide_description:
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would silently accept the unsupported flag.
        raise NotImplementedError("--provide_description is not implemented")

    if args.limit:
        print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")

    task_names = _select_task_names(args)
    if not task_names:
        print("No tasks left to run after applying --task_type/--tasks/--exclude_tasks, exiting.")
        return

    print(f"Selected Tasks: {task_names}")

    description_dict = {}
    if args.description_dict_path:
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(args.description_dict_path, "r", encoding="utf-8") as f:
            description_dict = json.load(f)

    results = evaluator.simple_evaluate(
        model=args.model,
        model_args=args.model_args,
        tasks=task_names,
        num_fewshot=args.num_fewshot,
        batch_size=args.batch_size,
        device=args.device,
        no_cache=args.no_cache,
        limit=args.limit,
        description_dict=description_dict,
        decontaminate=args.decontaminate,
        ngrams_path=args.ngrams_path,
        ngrams_n_size=args.ngrams_n_size
    )

    dumped = json.dumps(results, indent=2)
    print(dumped)

    if args.output_path:
        with open(args.output_path, "w", encoding="utf-8") as f:
            f.write(dumped)

    print(
        f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, "
        f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}"
    )
    print(evaluator.make_table(results))
Jason Phang's avatar
lib  
Jason Phang committed
135

136

Jason Phang's avatar
Jason Phang committed
137
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()