"tests/xlm/test_tokenization_xlm.py" did not exist on "a3c5883f2c9a12360cee0734dfb262f92b912b24"
__init__.py 8.77 KB
Newer Older
1
import os
lintangsutawika's avatar
lintangsutawika committed
2
import yaml
3
from typing import List, Union, Dict
&'s avatar
& committed
4

5
from lm_eval import utils
6
from lm_eval import prompts
7
from lm_eval.api.task import TaskConfig, Task, ConfigurableTask
lintangsutawika's avatar
lintangsutawika committed
8
from lm_eval.api.registry import (
9
10
    register_task,
    register_group,
lintangsutawika's avatar
lintangsutawika committed
11
12
    TASK_REGISTRY,
    GROUP_REGISTRY,
haileyschoelkopf's avatar
haileyschoelkopf committed
13
    ALL_TASKS,
lintangsutawika's avatar
lintangsutawika committed
14
)
lintangsutawika's avatar
lintangsutawika committed
15

lintangsutawika's avatar
lintangsutawika committed
16
17
import logging

lintangsutawika's avatar
lintangsutawika committed
18
19
20
# Module-level logger named "lm-eval"; used below to report task configs that
# could not be registered and to list the available tasks when a lookup fails.
eval_logger = logging.getLogger("lm-eval")

# import python tasks
21
from .squadv2.squad import SQuAD2
lintangsutawika's avatar
lintangsutawika committed
22

23

24

25
def register_configurable_task(config: Dict[str, str]) -> int:
    """Build a ``ConfigurableTask`` subclass from ``config`` and register it.

    The generated subclass carries ``TaskConfig(**config)`` as its ``CONFIG``
    attribute. It is registered under ``config["task"]`` and, when a ``"group"``
    key is present, additionally under every group name listed there.

    Args:
        config: Task configuration. Must contain a ``"task"`` key; ``"group"``
            may be a single name or a list of names.

    Returns:
        Always 0.

    Raises:
        KeyError: if ``config`` has no ``"task"`` key.
        ValueError: if any group name is the same as the task name.
    """
    # The subclass name is derived from the task name, so "task" is mandatory.
    # Check it up front instead of relying on a bare KeyError from the lookup.
    if "task" not in config:
        raise KeyError("task config is missing the required 'task' key")

    SubClass = type(
        config["task"] + "ConfigurableTask",
        (ConfigurableTask,),
        {"CONFIG": TaskConfig(**config)},
    )

    task_name = "{}".format(config["task"])
    register_task(task_name)(SubClass)

    if "group" in config:
        groups = config["group"]
        # A single group may be given as a bare string; normalise to a list.
        group_names = [groups] if isinstance(groups, str) else groups
        # A group sharing the task's name would collide in the registries,
        # whether the group was given as a string or inside a list.
        if config["task"] in group_names:
            raise ValueError("task and group name cannot be the same")

        for group in group_names:
            register_group(group)(SubClass)

    return 0

lintangsutawika's avatar
format  
lintangsutawika committed
49

lintangsutawika's avatar
lintangsutawika committed
50
def register_configurable_group(config: Dict[str, str], yaml_path: str = None) -> int:
    """Register a group whose ``"task"`` entry lists its member tasks.

    Non-string members are treated as inline task configs. Each one is:

    * loaded with ``utils.load_yaml_config(yaml_path, ...)``;
    * expanded into prompt variants by ``check_prompt_config``, with the group
      name attached;
    * registered with ``register_configurable_task``.

    String members are pattern-matched against ``ALL_TASKS``. Every match that
    is already in ``TASK_REGISTRY`` or ``GROUP_REGISTRY`` is appended to
    ``GROUP_REGISTRY[group]``. When this function creates that entry, it also
    adds the group to ``ALL_TASKS``.

    Args:
        config: Group configuration with a ``"group"`` name and a ``"task"`` list.
        yaml_path: Path of the YAML file the group came from. Its directory is
            passed to ``check_prompt_config`` as the prompt lookup path.

    Returns:
        Always 0.
    """
    group = config["group"]
    all_task_list = config["task"]

    config_list = [task for task in all_task_list if not isinstance(task, str)]
    task_list = [task for task in all_task_list if isinstance(task, str)]

    # os.path.dirname(None) raises TypeError, so only derive a directory
    # when a path was actually supplied.
    yaml_dir = os.path.dirname(yaml_path) if yaml_path is not None else None

    for task_config in config_list:
        task_config = utils.load_yaml_config(yaml_path, task_config)
        variant_configs = check_prompt_config(
            {**task_config, "group": group},
            yaml_path=yaml_dir,
        )
        # Distinct loop name so the ``config`` argument is not shadowed.
        for variant_config in variant_configs:
            register_configurable_task(variant_config)

    task_names = utils.pattern_match(task_list, ALL_TASKS)
    for task in task_names:
        if (task in TASK_REGISTRY) or (task in GROUP_REGISTRY):
            if group in GROUP_REGISTRY:
                GROUP_REGISTRY[group].append(task)
            else:
                GROUP_REGISTRY[group] = [task]
                ALL_TASKS.add(group)

    return 0

79

lintangsutawika's avatar
lintangsutawika committed
80
81
82
def check_prompt_config(
    config: Dict[str, str], yaml_path: str = None
) -> List[Dict[str, str]]:
    """Expand ``config`` into one config per prompt variant.

    If ``config`` has no ``"use_prompt"`` key, it is returned unchanged as the
    only element of a list (the same object, not a copy).

    Otherwise the prompt list is loaded with ``prompts.load_prompt_list``, and
    each variant yields a copy of ``config`` in which:

    * ``"use_prompt"`` is set to that variant;
    * ``"task"`` becomes ``"<base>_<variant>"``. The base is ``config["task"]``
      if present, else ``get_task_name_from_config(config)``. For variants
      containing ``".yaml"``, only the last path component is used;
    * ``"output_type"`` is forced to ``"generate_until"``.

    Args:
        config: Task configuration.
        yaml_path: Directory forwarded to ``prompts.load_prompt_list``.

    Returns:
        List of task configurations.
    """
    if "use_prompt" not in config:
        return [config]

    prompt_list = prompts.load_prompt_list(
        use_prompt=config["use_prompt"],
        dataset_name=config["dataset_path"],
        subset_name=config.get("dataset_name"),
        yaml_path=yaml_path,
    )

    # The base name does not depend on the variant, so compute it once.
    base_name = (
        config["task"] if "task" in config else get_task_name_from_config(config)
    )

    all_configs = []
    for prompt_variation in prompt_list:
        # Variants referring to prompt files contribute only their file name.
        suffix = (
            prompt_variation.split("/")[-1]
            if ".yaml" in prompt_variation
            else prompt_variation
        )
        all_configs.append(
            {
                **config,
                "use_prompt": prompt_variation,
                "task": "_".join([base_name, suffix]),
                "output_type": "generate_until",
            }
        )
    return all_configs


116
def get_task_name_from_config(task_config: Dict[str, str]) -> str:
    """Derive a task name from the dataset fields of ``task_config``.

    Produces ``"<dataset_path>_<dataset_name>"`` when a ``dataset_name`` key is
    present, otherwise just ``"<dataset_path>"``.
    """
    has_subset = "dataset_name" in task_config
    template = "{dataset_path}_{dataset_name}" if has_subset else "{dataset_path}"
    return template.format(**task_config)
121
122


lintangsutawika's avatar
lintangsutawika committed
123
def include_task_folder(task_dir: str, register_task: bool = True) -> int:
    """Walk ``task_dir`` recursively and register the task YAML configs found.

    Each ``*.yaml`` file is loaded with ``utils.load_yaml_config``. Files without
    a ``"task"`` key are skipped. The rest are expanded into prompt variants by
    ``check_prompt_config``. What gets registered depends on ``register_task``:

    * ``True``: only configs whose ``"task"`` is a string, via
      ``register_configurable_task``;
    * ``False``: only configs whose ``"task"`` is a list, via
      ``register_configurable_group``.

    Running the two passes separately lets groups see tasks that were already
    registered.

    Files that fail to load are skipped rather than aborting the walk. A
    ``ModuleNotFoundError`` is logged as a warning; any other exception is
    logged at debug level together with its traceback.

    Args:
        task_dir: Root directory to search.
        register_task: ``True`` to register individual tasks, ``False`` to
            register groups. Note that inside this function this parameter
            shadows the module-level ``register_task`` import.

    Returns:
        Always 0.
    """
    for root, _subdirs, file_list in os.walk(task_dir):
        for f in file_list:
            if not f.endswith(".yaml"):
                continue

            yaml_path = os.path.join(root, f)
            try:
                config = utils.load_yaml_config(yaml_path)

                if "task" not in config:
                    continue

                all_configs = check_prompt_config(
                    config, yaml_path=os.path.dirname(yaml_path)
                )
                for task_config in all_configs:
                    if register_task:
                        if isinstance(task_config["task"], str):
                            register_configurable_task(task_config)
                    elif isinstance(task_config["task"], list):
                        register_configurable_group(task_config, yaml_path)
            except ModuleNotFoundError as e:
                eval_logger.warning(
                    f"{yaml_path}: {e}. Config will not be added to registry."
                )
            except Exception as error:
                # Best-effort loading: one broken config must not prevent the
                # remaining tasks from being registered.
                import traceback

                eval_logger.debug(
                    "Failed to load config in\n"
                    f"                                 {yaml_path}\n"
                    "                                 Config will not be added to registry\n"
                    f"                                 Error: {error}\n"
                    f"                                 Traceback: {traceback.format_exc()}"
                )
    return 0
163
164
165
166
167
168
169


def include_path(task_dir):
    """Register every task under ``task_dir``, then every group.

    Groups are registered in a second pass because they reference tasks that
    must already be present in the registries.
    """
    # True -> individual tasks; False -> groups/benchmarks built from them.
    for register_individual_tasks in (True, False):
        include_task_folder(task_dir, register_task=register_individual_tasks)
    return 0
170

lintangsutawika's avatar
lintangsutawika committed
171

172
# At import time, register every task config stored alongside this module
# (the directory path keeps a trailing "/").
task_dir = f"{os.path.dirname(os.path.abspath(__file__))}/"
include_path(task_dir)
lintangsutawika's avatar
lintangsutawika committed
174

lintangsutawika's avatar
lintangsutawika committed
175

176
177
def get_task(task_name, config):
    """Instantiate the task registered as ``task_name``, passing ``config``.

    Raises:
        KeyError: if ``task_name`` is not in ``TASK_REGISTRY``. The available
            task and group names are logged at info level before raising.
    """
    try:
        task_class = TASK_REGISTRY[task_name]
    except KeyError:
        eval_logger.info("Available tasks:")
        eval_logger.info(list(TASK_REGISTRY) + list(GROUP_REGISTRY))
        raise KeyError(f"Missing task {task_name}") from None
    # Construct outside the try: a KeyError raised inside the task's own
    # constructor is a real error and must not be reported as a missing task.
    return task_class(config=config)


def get_task_name_from_object(task_object):
    """Return the registry name for ``task_object``, with fallbacks.

    ``task_object`` may be a registered task class or an instance of one.
    Instances are matched against the registry by their exact type. If no
    registry entry matches, the object's ``EVAL_HARNESS_NAME`` attribute is
    used if present, otherwise the name of its type.
    """
    object_type = type(task_object)
    for name, class_ in TASK_REGISTRY.items():
        # Callers such as get_task_dict pass Task *instances*. An identity
        # check against the class alone never matches those, so also compare
        # the instance's type.
        if class_ is task_object or class_ is object_type:
            return name

    # TODO: scrap this
    # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
    return (
        task_object.EVAL_HARNESS_NAME
        if hasattr(task_object, "EVAL_HARNESS_NAME")
        else object_type.__name__
    )


# TODO: pass num_fewshot and other cmdline overrides in a better way
def get_task_dict(task_name_list: List[Union[str, Dict, Task]], **kwargs):
    """Resolve task names, inline configs and task objects into task instances.

    Args:
        task_name_list: A single element or a list of elements. Each element is
            one of the following:

            * a ``str`` naming a registered group. Each member not yet resolved
              is resolved recursively (with the same ``kwargs``) and stored as
              ``name -> (group_name, task)``. A member whose own resolution
              does not contain its name (presumably a nested group) is stored
              as ``name -> (group_name, None)``, and its resolved entries are
              merged in.
            * any other ``str``, looked up with ``get_task`` and stored as
              ``name -> task``.
            * a ``dict`` task config. It is merged with ``kwargs`` into a new
              dict (the caller's dict is left untouched), wrapped in a
              ``ConfigurableTask``, and keyed by ``get_task_name_from_config``.
            * a ``Task`` instance, stored under ``get_task_name_from_object``.
        **kwargs: Config overrides applied to every task built here, including
            the members of groups.

    Returns:
        Dict of registry results, then config results, then object results.
    """
    config = {**kwargs}

    task_name_from_registry_dict = {}
    task_name_from_config_dict = {}
    task_name_from_object_dict = {}

    if not isinstance(task_name_list, list):
        task_name_list = [task_name_list]

    for task_element in task_name_list:
        if isinstance(task_element, str):
            if task_element in GROUP_REGISTRY:
                group_name = task_element
                for task_name in GROUP_REGISTRY[group_name]:
                    if task_name in task_name_from_registry_dict:
                        continue
                    # Forward the overrides so group members are configured
                    # the same way as tasks requested directly.
                    task_obj = get_task_dict(task_name, **kwargs)
                    if task_name in task_obj:
                        task_dict = {
                            task_name: (group_name, task_obj[task_name]),
                        }
                    else:
                        task_dict = {
                            task_name: (group_name, None),
                            **task_obj,
                        }
                    task_name_from_registry_dict.update(task_dict)
            elif task_element not in task_name_from_registry_dict:
                task_name_from_registry_dict[task_element] = get_task(
                    task_name=task_element, config=config
                )

        elif isinstance(task_element, dict):
            # Merge into a fresh dict so the caller's config is not mutated.
            task_config = {**task_element, **config}
            task_name_from_config_dict[
                get_task_name_from_config(task_config)
            ] = ConfigurableTask(config=task_config)

        elif isinstance(task_element, Task):
            task_name_from_object_dict[
                get_task_name_from_object(task_element)
            ] = task_element

    assert set(task_name_from_registry_dict).isdisjoint(task_name_from_object_dict)

    return {
        **task_name_from_registry_dict,
        **task_name_from_config_dict,
        **task_name_from_object_dict,
    }