main.py 3.53 KB
Newer Older
Jason Phang's avatar
Jason Phang committed
1
2
import argparse
import json
Leo Gao's avatar
Leo Gao committed
3
import logging
4
import fnmatch
5
import yaml
Leo Gao's avatar
Leo Gao committed
6

7
from lm_eval import tasks, evaluator
8
from lm_eval.api.task import ConfigurableTask
Jason Phang's avatar
lib  
Jason Phang committed
9

Leo Gao's avatar
Leo Gao committed
10
# Raise the "openai" logger's threshold to WARNING so that its lower-severity
# records (DEBUG/INFO) are not emitted while this script runs.
logging.getLogger("openai").setLevel(logging.WARNING)
Leo Gao's avatar
Leo Gao committed
11

Fabrizio Milo's avatar
Fabrizio Milo committed
12

13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class MultiChoice:
    """Collection of allowed choices that understands comma-separated wildcards.

    Membership tests take a single string holding one or more comma-separated
    patterns (shell-style wildcards, as implemented by ``fnmatch``).
    Iteration yields the underlying choices unchanged.
    """

    def __init__(self, choices):
        # The choices are stored as given; no copy or validation is made.
        self.choices = choices

    def __contains__(self, values):
        """Return True only if every comma-separated pattern matches a choice."""
        # ``fnmatch.filter`` returns the list of matching choices; an empty
        # list is falsy, so one unmatched pattern makes the whole test fail.
        return all(
            fnmatch.filter(self.choices, pattern)
            for pattern in values.split(",")
        )

    def __iter__(self):
        """Yield each stored choice in its original order."""
        yield from self.choices

Fabrizio Milo's avatar
Fabrizio Milo committed
29

Jason Phang's avatar
Jason Phang committed
30
31
def parse_args():
    """Declare this script's command-line options and parse the arguments.

    ``--model`` is the only required option. ``--tasks`` is validated against
    ``MultiChoice(tasks.ALL_TASKS)``, so it accepts comma-separated names or
    wildcard patterns.

    Returns:
        The ``argparse.Namespace`` produced by ``ArgumentParser.parse_args()``.
    """
    cli = argparse.ArgumentParser()

    # (flag, keyword arguments for ``add_argument``), registered in this order.
    option_specs = [
        ("--model", {"required": True}),
        ("--model_args", {"default": ""}),
        ("--tasks", {"default": None, "choices": MultiChoice(tasks.ALL_TASKS)}),
        ("--config", {"default": None}),
        ("--provide_description", {"action": "store_true"}),
        ("--num_fewshot", {"type": int, "default": 0}),
        ("--batch_size", {"type": int, "default": 1}),
        ("--device", {"type": str, "default": None}),
        ("--output_path", {"default": None}),
        ("--limit", {"type": int, "default": None}),
        ("--no_cache", {"action": "store_true"}),
        ("--decontamination_ngrams_path", {"default": None}),
        ("--description_dict_path", {"default": None}),
        ("--check_integrity", {"action": "store_true"}),
    ]
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli.parse_args()

Fabrizio Milo's avatar
Fabrizio Milo committed
49

50
51
52
53
54
55
56
57
58
def pattern_match(patterns, source_list):
    """Return the entries of *source_list* that match at least one pattern.

    Args:
        patterns: Iterable of shell-style wildcard patterns (``fnmatch``
            syntax).
        source_list: Sequence of candidate strings; it is scanned once per
            pattern.

    Returns:
        A sorted list of the unique matching entries (empty if nothing
        matches).
    """
    matched = set()
    for pattern in patterns:
        matched.update(fnmatch.filter(source_list, pattern))
    # Sort before returning: the iteration order of a set of strings can
    # differ between interpreter runs, which would make the selected task
    # order (and anything derived from it) non-reproducible.
    return sorted(matched)

Fabrizio Milo's avatar
Fabrizio Milo committed
59

60
def main():
    """Run an evaluation from the command line and report the results.

    Steps:
      1. Parse the command-line arguments.
      2. Decide what to evaluate: tasks matching ``--tasks`` if given,
         otherwise the YAML configs listed in ``--config``, otherwise every
         entry of ``tasks.ALL_TASKS``.
      3. Call ``evaluator.simple_evaluate`` and print the results as JSON,
         writing them to ``--output_path`` as well when it is set.
      4. Print a one-line run summary and the table built by
         ``evaluator.make_table``.
    """
    args = parse_args()

    if args.limit:
        print(
            "WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
        )

    if args.tasks is not None:
        # --tasks takes precedence over --config when both are supplied.
        task_names = pattern_match(args.tasks.split(","), tasks.ALL_TASKS)
    elif args.config:
        # With --config, each entry of ``task_names`` is the loaded config
        # mapping, not a task-name string.
        task_names = []
        for config_path in args.config.split(","):
            with open(config_path, "r") as f:
                # NOTE(review): yaml.Loader is PyYAML's full loader and can
                # construct arbitrary Python objects. Only pass trusted config
                # files here; consider yaml.SafeLoader if configs never rely
                # on Python-specific tags.
                config = yaml.load(f, yaml.Loader)

            # Command-line values override what the config file specifies.
            # A --num_fewshot of 0 (the default) leaves the file's value alone.
            if args.num_fewshot != 0:
                config["num_fewshot"] = args.num_fewshot

            if args.batch_size is not None:
                config["batch_size"] = args.batch_size

            task_names.append(config)
    else:
        task_names = tasks.ALL_TASKS

    print(f"Selected Tasks: {task_names}")

    results = evaluator.simple_evaluate(
        model=args.model,
        model_args=args.model_args,
        tasks=task_names,
        num_fewshot=args.num_fewshot,
        batch_size=args.batch_size,
        device=args.device,
        limit=args.limit,
        decontamination_ngrams_path=args.decontamination_ngrams_path,
        check_integrity=args.check_integrity,
    )

    dumped = json.dumps(results, indent=2)
    print(dumped)

    if args.output_path:
        with open(args.output_path, "w") as f:
            f.write(dumped)

    print(
        f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, "
        f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}"
    )
    print(evaluator.make_table(results))


if __name__ == "__main__":
    main()