"""Command-line entry point that evaluates a model on lm_eval tasks.

Run as a script; see ``parse_args`` for the available options.
"""
import argparse
import fnmatch
import json
import logging
import os

from lm_eval import evaluator, tasks

# Only WARNING and above from the "openai" logger reach the output.
logging.getLogger("openai").setLevel(logging.WARNING)


class MultiChoice:
    """``choices`` container for argparse that understands wildcard lists.

    A string is "in" this container when it is a comma-separated list of
    shell-style (``fnmatch``) patterns and every one of those patterns
    matches at least one of the wrapped choices.
    """

    def __init__(self, choices):
        # Kept as-is; it is searched by fnmatch and iterated by __iter__.
        self.choices = choices

    def __contains__(self, values):
        # all() stops at the first pattern whose match list is empty.
        return all(
            fnmatch.filter(self.choices, pattern) for pattern in values.split(",")
        )

    def __iter__(self):
        # argparse iterates the container when listing valid choices.
        yield from self.choices

def parse_args(argv=None):
    """Build the evaluation CLI parser and parse the command line.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)
    parser.add_argument("--model_args", default="")
    # Comma-separated task names; each entry may be an fnmatch wildcard and
    # must match at least one known task (validated by MultiChoice).
    parser.add_argument("--tasks", default=None, choices=MultiChoice(tasks.ALL_TASKS))
    # Accepted for compatibility, but main() rejects it (not implemented).
    parser.add_argument("--provide_description", action="store_true")
    parser.add_argument("--num_fewshot", type=int, default=0)
    parser.add_argument("--batch_size", type=str, default=None)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--output_path", default=None)
    parser.add_argument("--limit", type=float, default=None,
                        help="Limit the number of examples per task. "
                             "If <1, limit is a percentage of the total number of examples.")
    # NOTE(review): parsed but never read by main(); kept so existing
    # command lines that pass it keep working.
    parser.add_argument("--data_sampling", type=float, default=None)
    parser.add_argument("--no_cache", action="store_true")
    parser.add_argument("--decontamination_ngrams_path", default=None)
    # Path to a JSON file loaded into the description_dict for the evaluator.
    parser.add_argument("--description_dict_path", default=None)
    parser.add_argument("--check_integrity", action="store_true")
    parser.add_argument("--write_out", action="store_true", default=False)
    parser.add_argument("--output_base_path", type=str, default=None)

    return parser.parse_args(argv)

def pattern_match(patterns, source_list):
    """Return the names in *source_list* that match at least one pattern.

    Args:
        patterns: Iterable of shell-style wildcard patterns (``fnmatch`` rules).
        source_list: Iterable of candidate names. It is materialized once, so
            one-shot iterables (e.g. generators) work with several patterns.

    Returns:
        A sorted list of the matching names, each listed once even when
        several patterns match it.
    """
    # fnmatch.filter iterates its input on every call; without this copy a
    # generator would be exhausted by the first pattern.
    names = list(source_list)
    task_names = set()
    for pattern in patterns:
        task_names.update(fnmatch.filter(names, pattern))
    return sorted(task_names)

def main():
    """Run an lm_eval evaluation configured from the command line.

    Parses the CLI options, resolves the requested tasks (all tasks when
    ``--tasks`` is omitted), loads the optional description dictionary,
    calls ``evaluator.simple_evaluate``, prints the results as indented
    JSON (also writing it to ``--output_path`` when given), and finally
    prints a one-line run summary and ``evaluator.make_table(results)``.

    Raises:
        NotImplementedError: if ``--provide_description`` is passed.
    """
    args = parse_args()

    # Explicit raise instead of ``assert``: asserts are stripped under
    # ``python -O``, which would silently accept the unsupported flag.
    if args.provide_description:
        raise NotImplementedError("--provide_description is not implemented")

    if args.limit:
        print(
            "WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
        )

    if args.tasks is None:
        task_names = tasks.ALL_TASKS
    else:
        # --tasks is a comma-separated list of fnmatch patterns; expand it to
        # concrete, sorted, de-duplicated task names.
        task_names = pattern_match(args.tasks.split(","), tasks.ALL_TASKS)

    print(f"Selected Tasks: {task_names}")

    description_dict = {}
    if args.description_dict_path:
        # JSON is UTF-8 text; don't depend on the platform default encoding.
        with open(args.description_dict_path, "r", encoding="utf-8") as f:
            description_dict = json.load(f)

    # NOTE(review): args.data_sampling is parsed but not forwarded here.
    results = evaluator.simple_evaluate(
        model=args.model,
        model_args=args.model_args,
        tasks=task_names,
        num_fewshot=args.num_fewshot,
        batch_size=args.batch_size,
        device=args.device,
        no_cache=args.no_cache,
        limit=args.limit,
        description_dict=description_dict,
        decontamination_ngrams_path=args.decontamination_ngrams_path,
        check_integrity=args.check_integrity,
        write_out=args.write_out,
        output_base_path=args.output_base_path,
    )

    dumped = json.dumps(results, indent=2)
    print(dumped)

    if args.output_path:
        output_dir = os.path.dirname(args.output_path)
        # dirname() is "" for a bare filename such as "results.json", and
        # os.makedirs("") raises FileNotFoundError; only create real dirs.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(args.output_path, "w", encoding="utf-8") as f:
            f.write(dumped)

    print(
        f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, "
        f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}"
    )
    print(evaluator.make_table(results))


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()