__init__.py 6.19 KB
Newer Older
1
import os
lintangsutawika's avatar
lintangsutawika committed
2
import yaml
lintangsutawika's avatar
lintangsutawika committed
3
from typing import List, Union
&'s avatar
& committed
4

5
from lm_eval import utils
lintangsutawika's avatar
lintangsutawika committed
6
from lm_eval.logger import eval_logger
7
from lm_eval.api.task import TaskConfig, Task, ConfigurableTask
lintangsutawika's avatar
lintangsutawika committed
8
from lm_eval.api.registry import (
9
10
    register_task,
    register_group,
lintangsutawika's avatar
lintangsutawika committed
11
12
    TASK_REGISTRY,
    GROUP_REGISTRY,
haileyschoelkopf's avatar
haileyschoelkopf committed
13
    ALL_TASKS,
lintangsutawika's avatar
lintangsutawika committed
14
)
lintangsutawika's avatar
lintangsutawika committed
15

16

17
def get_task_name_from_config(task_config):
    """Derive a task name from the dataset fields of an inline task config.

    Returns ``"<dataset_path>_<dataset_name>"`` when the config has a
    ``dataset_name`` key, otherwise just ``"<dataset_path>"``. Missing
    ``dataset_path`` raises ``KeyError``.
    """
    template = (
        "{dataset_path}_{dataset_name}"
        if "dataset_name" in task_config
        else "{dataset_path}"
    )
    return template.format(**task_config)
22
23


24
25
26
27
28
def include_task_folder(task_dir):
    """Register every YAML-configured task found in the leaf folders of ``task_dir``.

    Walks ``task_dir`` recursively and scans directories that contain at
    least one file and no subdirectories other than ``__pycache__``. Each
    ``*.yaml`` file there is loaded with ``utils.load_yaml_config``. If the
    config has a ``task`` key:

    * a ``ConfigurableTask`` subclass holding ``TaskConfig(**config)`` is
      created and registered under that task name;
    * the subclass is also registered under every group named in the
      optional ``group`` key, which may be a single string or a list.

    Configs without a ``task`` key are skipped. Any file that fails to load
    or register is logged as a warning and skipped, so one bad config does
    not abort the import of the rest.

    Args:
        task_dir: Root directory to search for task YAML files.
    """
    for root, subdirs, file_list in os.walk(task_dir):
        # Only leaf folders (ignoring the bytecode cache) that hold files are scanned.
        if subdirs not in ([], ["__pycache__"]) or not file_list:
            continue
        for f in file_list:
            if not f.endswith(".yaml"):
                continue
            yaml_path = os.path.join(root, f)
            try:
                config = utils.load_yaml_config(yaml_path)

                # Check for the task name before building the subclass,
                # because the subclass name is derived from it.
                if "task" not in config:
                    continue
                task_name = str(config["task"])

                SubClass = type(
                    task_name + "ConfigurableTask",
                    (ConfigurableTask,),
                    {"CONFIG": TaskConfig(**config)},
                )
                register_task(task_name)(SubClass)

                # ``group`` may be a single name or a list of names; a bare
                # string must not be iterated character by character.
                groups = config.get("group") or []
                if isinstance(groups, str):
                    groups = [groups]
                for group in groups:
                    register_group(group)(SubClass)
            except Exception as error:
                eval_logger.warning(
                    "Failed to load config in\n"
                    f"                                 {yaml_path}\n"
                    "                                 Config will not be added to registry\n"
                    f"                                 Error: {error}"
                )
56

57

lintangsutawika's avatar
lintangsutawika committed
58
59
60
61
62
63
def include_benchmarks(task_dir, benchmark_dir="benchmarks"):
    """Register benchmark groups defined by YAML files under ``task_dir/benchmark_dir``.

    Scans leaf folders of ``os.path.join(task_dir, benchmark_dir)``, meaning
    folders that hold files and have no subdirectories other than
    ``__pycache__``. Each ``*.yaml`` file there must define:

    * ``group`` - the group name;
    * ``task`` - the patterns passed to ``utils.pattern_match`` against
      ``ALL_TASKS``.

    Every matched name that is a registered task is added to
    ``GROUP_REGISTRY[group]``, skipping names already present. The group
    name is added to ``ALL_TASKS`` when the group is first created. Files
    that fail to load or are malformed are logged as warnings and skipped.

    Args:
        task_dir: Root tasks directory.
        benchmark_dir: Subdirectory of ``task_dir`` holding benchmark YAMLs.
    """
    for root, subdirs, file_list in os.walk(os.path.join(task_dir, benchmark_dir)):
        # Only leaf folders (ignoring the bytecode cache) that hold files are scanned.
        if subdirs not in ([], ["__pycache__"]) or not file_list:
            continue
        for f in file_list:
            if not f.endswith(".yaml"):
                continue
            # Bound outside the ``try`` so the warning below can always report it.
            benchmark_path = os.path.join(root, f)
            try:
                with open(benchmark_path, "rb") as file:
                    yaml_config = yaml.full_load(file)

                # Explicit check instead of ``assert``, which ``python -O`` strips.
                if "group" not in yaml_config:
                    raise ValueError("benchmark config is missing the 'group' key")
                group = yaml_config["group"]
                task_list = yaml_config["task"]
                task_names = utils.pattern_match(task_list, ALL_TASKS)
                for task in task_names:
                    # ALL_TASKS also holds group names (added below), so keep
                    # only names that refer to concrete registered tasks.
                    if task not in TASK_REGISTRY:
                        continue
                    if group in GROUP_REGISTRY:
                        # Avoid listing the same task twice in one group.
                        if task not in GROUP_REGISTRY[group]:
                            GROUP_REGISTRY[group].append(task)
                    else:
                        GROUP_REGISTRY[group] = [task]
                        ALL_TASKS.add(group)
            except Exception as error:
                eval_logger.warning(
                    "Failed to load benchmark in\n"
                    f"                                 {benchmark_path}\n"
                    "                                 Benchmark will not be added to registry\n"
                    f"                                 Error: {error}"
                )
lintangsutawika's avatar
lintangsutawika committed
88
89


90
91
# Populate the task and group registries at import time from the YAML
# configs that live next to this module (path kept with a trailing "/").
task_dir = f"{os.path.dirname(os.path.abspath(__file__))}/"
include_task_folder(task_dir)
include_benchmarks(task_dir)
lintangsutawika's avatar
lintangsutawika committed
93

lintangsutawika's avatar
lintangsutawika committed
94

95
96
def get_task(task_name, config):
    """Instantiate the task registered as ``task_name``.

    Args:
        task_name: Name under which the task class is in ``TASK_REGISTRY``.
        config: Passed to the task class as its ``config`` keyword argument.

    Returns:
        A new instance of the registered task class.

    Raises:
        KeyError: If no task is registered under ``task_name``. The names of
            all registered tasks and groups are logged at INFO level first.
    """
    try:
        task_class = TASK_REGISTRY[task_name]
    except KeyError:
        eval_logger.info("Available tasks:")
        eval_logger.info(list(TASK_REGISTRY) + list(GROUP_REGISTRY))
        # ``from None``: the registry lookup failure adds nothing beyond this message.
        raise KeyError(f"Missing task {task_name}") from None
    # Construct outside the ``try`` so a KeyError raised inside the task's
    # own constructor propagates as-is instead of being reported as a
    # missing task.
    return task_class(config=config)


def get_task_name_from_object(task_object):
    """Return a reporting name for a task class or task instance.

    If ``task_object`` is a class in ``TASK_REGISTRY``, or an instance of
    one, the name it was registered under is returned. Otherwise the name
    falls back to ``task_object.EVAL_HARNESS_NAME`` if that attribute
    exists, else to the class name of ``task_object``.
    """
    for name, class_ in TASK_REGISTRY.items():
        # Callers such as ``get_task_dict`` pass task *instances*, so match
        # on the object's type as well as on the object itself.
        if class_ is task_object or class_ is type(task_object):
            return name

    # TODO: scrap this
    # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
    return (
        task_object.EVAL_HARNESS_NAME
        if hasattr(task_object, "EVAL_HARNESS_NAME")
        else type(task_object).__name__
    )


# TODO: pass num_fewshot and other cmdline overrides in a better way
lintangsutawika's avatar
lintangsutawika committed
119
120
121
def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
    """Build a mapping from task name to task object.

    Each element of ``task_name_list`` may be one of:

    * ``str`` - a group in ``GROUP_REGISTRY``, in which case every member
      task is instantiated, or the name of a single registered task;
    * ``dict`` - an inline task config; ``kwargs`` are merged over a copy of
      it and a ``ConfigurableTask`` is built from the result;
    * ``Task`` - an already-built task object, used as-is.

    Elements of any other type are skipped with a warning.

    Args:
        task_name_list: Task/group names, inline configs and/or task objects.
        **kwargs: Config overrides applied to every task created here.

    Returns:
        dict mapping task name to task object. Registry-derived entries are
        merged first, then config-derived, then object-derived, so the latter
        override the former on a name collision.

    Raises:
        KeyError: If a string names neither a group nor a registered task
            (raised by ``get_task``).
        AssertionError: If a registry-derived name equals the name of a
            passed-in task object.
    """
    config = {**kwargs}

    task_name_from_registry_dict = {}
    task_name_from_config_dict = {}
    task_name_from_object_dict = {}

    for task_element in task_name_list:
        if isinstance(task_element, str):
            # A group name expands to all of its member tasks.
            if task_element in GROUP_REGISTRY:
                member_names = GROUP_REGISTRY[task_element]
            else:
                member_names = [task_element]
            for task_name in member_names:
                if task_name not in task_name_from_registry_dict:
                    task_name_from_registry_dict[task_name] = get_task(
                        task_name=task_name, config=config
                    )

        elif isinstance(task_element, dict):
            # Merge into a new dict so the caller's config is not mutated;
            # the kwargs overrides take precedence, as with dict.update.
            task_config = {**task_element, **config}
            task_name_from_config_dict[
                get_task_name_from_config(task_config)
            ] = ConfigurableTask(config=task_config)

        elif isinstance(task_element, Task):
            task_name_from_object_dict[
                get_task_name_from_object(task_element)
            ] = task_element

        else:
            eval_logger.warning(
                f"Ignoring unsupported task element of type "
                f"{type(task_element).__name__}: {task_element!r}"
            )

    assert set(task_name_from_registry_dict).isdisjoint(
        task_name_from_object_dict
    ), "Registry task names overlap with the names of passed task objects"

    return {
        **task_name_from_registry_dict,
        **task_name_from_config_dict,
        **task_name_from_object_dict,
    }