"docs/zh_cn/advanced_guides/datasets/sunrgbd_det.md" did not exist on "508d918c3fdb81b12a94f1b5b11db15997657550"
main.py 3.56 KB
Newer Older
Jason Phang's avatar
Jason Phang committed
1
2
import argparse
import json
Leo Gao's avatar
Leo Gao committed
3
import logging
4
import fnmatch
5
import os
Leo Gao's avatar
Leo Gao committed
6

7
from lm_eval import tasks, evaluator
Jason Phang's avatar
lib  
Jason Phang committed
8

Leo Gao's avatar
Leo Gao committed
9
# Only let WARNING and above through for the logger named "openai".
# NOTE(review): presumably this keeps routine client-library log output off
# the console during evaluation runs -- confirm.
logging.getLogger("openai").setLevel(logging.WARNING)
Leo Gao's avatar
Leo Gao committed
10

Fabrizio Milo's avatar
Fabrizio Milo committed
11

12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class MultiChoice:
    """Set of allowed choices queried with comma-separated wildcard patterns.

    ``values in multi_choice`` splits ``values`` on commas and succeeds only
    when every resulting shell-style pattern (fnmatch syntax) matches at
    least one stored choice. Iterating yields the stored choices unchanged.
    """

    def __init__(self, choices):
        self.choices = choices

    def __contains__(self, values):
        # fnmatch.filter returns the list of choices matching a pattern; an
        # empty list is falsy, so all() stops at the first pattern that
        # matches nothing.
        return all(
            fnmatch.filter(self.choices, pattern)
            for pattern in values.split(",")
        )

    def __iter__(self):
        yield from self.choices

Fabrizio Milo's avatar
Fabrizio Milo committed
28

Jason Phang's avatar
Jason Phang committed
29
30
def parse_args():
    """Declare the script's command-line options and parse ``sys.argv``.

    Returns:
        argparse.Namespace: one attribute per option declared below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)
    parser.add_argument("--model_args", default="")
    # MultiChoice validates the value, so a comma-separated list of wildcard
    # patterns is accepted as long as each pattern matches a known task.
    parser.add_argument("--tasks", default=None, choices=MultiChoice(tasks.ALL_TASKS))
    parser.add_argument("--provide_description", action="store_true")
    parser.add_argument("--num_fewshot", type=int, default=0)
    # Kept as a string rather than converted to int here.
    parser.add_argument("--batch_size", type=str, default=None)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--output_path", default=None)
    parser.add_argument("--limit", type=float, default=None,
                        help="Limit the number of examples per task. "
                             "If <1, limit is a percentage of the total number of examples.")
    # NOTE(review): this value is not read by main() as visible in this file.
    parser.add_argument("--data_sampling", type=float, default=None)
    parser.add_argument("--no_cache", action="store_true")
    parser.add_argument("--decontamination_ngrams_path", default=None)
    parser.add_argument("--description_dict_path", default=None)
    parser.add_argument("--check_integrity", action="store_true")

    return parser.parse_args()

Fabrizio Milo's avatar
Fabrizio Milo committed
50

51
52
53
54
55
56
57
def pattern_match(patterns, source_list):
    """Return the entries of ``source_list`` that match at least one pattern.

    Args:
        patterns: Iterable of shell-style wildcard patterns (fnmatch syntax).
        source_list: Candidate strings; scanned once per pattern.

    Returns:
        A sorted list with each matching entry appearing once; empty when
        nothing matches or when ``patterns`` is empty.
    """
    matched = set()
    for pattern in patterns:
        matched.update(fnmatch.filter(source_list, pattern))
    return sorted(matched)
59

Fabrizio Milo's avatar
Fabrizio Milo committed
60

61
def main():
    """Run the evaluation requested on the command line and report results.

    Parses the arguments, resolves the requested task patterns against
    ``tasks.ALL_TASKS``, calls ``evaluator.simple_evaluate``, prints the
    results as JSON, optionally writes that JSON to ``--output_path``, and
    finally prints a one-line run summary followed by the table produced by
    ``evaluator.make_table``.

    Raises:
        AssertionError: if ``--provide_description`` is passed.
    """
    args = parse_args()

    assert not args.provide_description  # not implemented

    if args.limit:
        print(
            "WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
        )

    # No --tasks means "everything"; otherwise expand the wildcard patterns.
    if args.tasks is None:
        task_names = tasks.ALL_TASKS
    else:
        task_names = pattern_match(args.tasks.split(","), tasks.ALL_TASKS)

    print(f"Selected Tasks: {task_names}")

    description_dict = {}
    if args.description_dict_path:
        with open(args.description_dict_path, "r") as f:
            description_dict = json.load(f)

    results = evaluator.simple_evaluate(
        model=args.model,
        model_args=args.model_args,
        tasks=task_names,
        num_fewshot=args.num_fewshot,
        batch_size=args.batch_size,
        device=args.device,
        no_cache=args.no_cache,
        limit=args.limit,
        description_dict=description_dict,
        decontamination_ngrams_path=args.decontamination_ngrams_path,
        check_integrity=args.check_integrity,
    )

    dumped = json.dumps(results, indent=2)
    print(dumped)

    if args.output_path:
        output_dir = os.path.dirname(args.output_path)
        # A bare filename has an empty dirname, and os.makedirs("") raises
        # FileNotFoundError, so only create the directory when there is one.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(args.output_path, "w") as f:
            f.write(dumped)

    print(
        f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, "
        f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}"
    )
    print(evaluator.make_table(results))
Jason Phang's avatar
lib  
Jason Phang committed
110

111

Jason Phang's avatar
Jason Phang committed
112
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()