"test/gtest-1.11.0/googletest/samples/sample6_unittest.cc" did not exist on "b2f89386d8f88655e47c4be0c719073dd6308a21"
main.py 3.32 KB
Newer Older
Jason Phang's avatar
Jason Phang committed
1
2
import argparse
import json
Leo Gao's avatar
Leo Gao committed
3
import logging
4
import fnmatch
5
import os
Leo Gao's avatar
Leo Gao committed
6

7
from lm_eval import tasks, evaluator
Jason Phang's avatar
lib  
Jason Phang committed
8

Leo Gao's avatar
Leo Gao committed
9
# Set the "openai" logger's level to WARNING so its DEBUG/INFO records are
# not emitted.
logging.getLogger("openai").setLevel(logging.WARNING)
Leo Gao's avatar
Leo Gao committed
10

Fabrizio Milo's avatar
Fabrizio Milo committed
11

12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class MultiChoice:
    """Collection of choices that accepts comma-separated wildcard queries.

    ``x in MultiChoice(choices)`` splits ``x`` on commas and is true only
    when every piece matches at least one choice as an ``fnmatch`` pattern
    (shell-style wildcards such as ``*`` and ``?``). Iterating the object
    yields the stored choices in their original order.
    """

    def __init__(self, choices):
        self.choices = choices

    def __contains__(self, values):
        # Each comma-separated pattern must hit at least one known choice;
        # fnmatch.filter returns an empty (falsy) list when nothing matches.
        return all(
            fnmatch.filter(self.choices, pattern)
            for pattern in values.split(",")
        )

    def __iter__(self):
        yield from self.choices

Fabrizio Milo's avatar
Fabrizio Milo committed
28

Jason Phang's avatar
Jason Phang committed
29
30
def parse_args():
    """Build the command-line parser and parse the process arguments.

    Returns:
        argparse.Namespace: The parsed options. ``--model`` is required;
        every other option falls back to the default declared below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)
    parser.add_argument("--model_args", default="")
    # MultiChoice lets argparse validate a comma-separated, wildcard-capable
    # task list against the known task names.
    parser.add_argument("--tasks", default=None, choices=MultiChoice(tasks.ALL_TASKS))
    parser.add_argument("--provide_description", action="store_true")
    parser.add_argument("--num_fewshot", type=int, default=0)
    parser.add_argument("--batch_size", type=int, default=None)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--output_path", default=None)
    parser.add_argument("--limit", type=int, default=None)
    parser.add_argument("--no_cache", action="store_true")
    parser.add_argument("--decontamination_ngrams_path", default=None)
    parser.add_argument("--description_dict_path", default=None)
    parser.add_argument("--check_integrity", action="store_true")

    return parser.parse_args()

Fabrizio Milo's avatar
Fabrizio Milo committed
47

48
49
50
51
52
53
54
55
56
def pattern_match(patterns, source_list):
    """Return the entries of ``source_list`` that match at least one pattern.

    Args:
        patterns: Iterable of ``fnmatch``-style patterns.
        source_list: Candidate names to filter.

    Returns:
        A sorted list of the distinct matching entries. The result is sorted
        so that the order is deterministic; the iteration order of a set of
        strings can differ from one interpreter run to the next.
    """
    matched = set()
    for pattern in patterns:
        matched.update(fnmatch.filter(source_list, pattern))
    return sorted(matched)

Fabrizio Milo's avatar
Fabrizio Milo committed
57

58
def main():
    """Run an evaluation from command-line options and report the results.

    Parses the CLI arguments, resolves the requested task patterns against
    ``tasks.ALL_TASKS``, calls ``evaluator.simple_evaluate``, prints the
    results as JSON, optionally writes that JSON to ``--output_path``, and
    finally prints a one-line run summary followed by the results table.
    """
    args = parse_args()

    # NOTE(review): assert is stripped under ``python -O``; kept as-is so the
    # raised exception type does not change.
    assert not args.provide_description  # not implemented

    if args.limit:
        print(
            "WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
        )

    if args.tasks is None:
        task_names = tasks.ALL_TASKS
    else:
        task_names = pattern_match(args.tasks.split(","), tasks.ALL_TASKS)

    print(f"Selected Tasks: {task_names}")

    description_dict = {}
    if args.description_dict_path:
        with open(args.description_dict_path, "r") as f:
            description_dict = json.load(f)

    results = evaluator.simple_evaluate(
        model=args.model,
        model_args=args.model_args,
        tasks=task_names,
        num_fewshot=args.num_fewshot,
        batch_size=args.batch_size,
        device=args.device,
        no_cache=args.no_cache,
        limit=args.limit,
        description_dict=description_dict,
        decontamination_ngrams_path=args.decontamination_ngrams_path,
        check_integrity=args.check_integrity,
    )

    dumped = json.dumps(results, indent=2)
    print(dumped)

    if args.output_path:
        output_dir = os.path.dirname(args.output_path)
        # A bare filename has an empty dirname, and os.makedirs("") raises
        # FileNotFoundError, so only create a directory when one is named.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(args.output_path, "w") as f:
            f.write(dumped)

    print(
        f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, "
        f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}"
    )
    print(evaluator.make_table(results))
Jason Phang's avatar
lib  
Jason Phang committed
107

108

Jason Phang's avatar
Jason Phang committed
109
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()