__init__.py 12.7 KB
Newer Older
1
import os
lintangsutawika's avatar
lintangsutawika committed
2
import yaml
3
import collections
4
from typing import List, Union, Dict
&'s avatar
& committed
5

6
from lm_eval import utils
7
from lm_eval import prompts
8
from lm_eval.api.task import TaskConfig, Task, ConfigurableTask
lintangsutawika's avatar
lintangsutawika committed
9
from lm_eval.api.registry import (
10
11
    register_task,
    register_group,
lintangsutawika's avatar
lintangsutawika committed
12
13
    TASK_REGISTRY,
    GROUP_REGISTRY,
haileyschoelkopf's avatar
haileyschoelkopf committed
14
    ALL_TASKS,
lintangsutawika's avatar
lintangsutawika committed
15
)
lintangsutawika's avatar
lintangsutawika committed
16

lintangsutawika's avatar
lintangsutawika committed
17
import logging
lintangsutawika's avatar
lintangsutawika committed
18

lintangsutawika's avatar
lintangsutawika committed
19
# import python tasks
lintangsutawika's avatar
lintangsutawika committed
20
from .squadv2.task import SQuAD2
lintangsutawika's avatar
lintangsutawika committed
21
22
23
24
25
26
27
28
from .scrolls.task import (
    QuALITY,
    NarrativeQA,
    ContractNLI,
    GovReport,
    SummScreenFD,
    QMSum,
)
lintangsutawika's avatar
lintangsutawika committed
29

30
# Module-wide logger for task loading; reuses the logger object exposed as
# ``utils.eval_logger`` instead of creating a separate one.
eval_logger = utils.eval_logger
31

lintangsutawika's avatar
lintangsutawika committed
32

33
34
35
36
37
38
def load_task_or_group(yaml_path: str) -> ConfigurableTask:
    """Build a ``ConfigurableTask`` directly from the YAML file at ``yaml_path``.

    The file is parsed with ``utils.load_yaml_config`` and the resulting
    mapping is passed to ``ConfigurableTask`` unchanged.

    NOTE(review): despite the name, no group handling happens here -- the
    result is always a single ``ConfigurableTask``.
    """
    return ConfigurableTask(config=utils.load_yaml_config(yaml_path))


39
def register_configurable_task(config: Dict[str, str]) -> int:
    """Create a ``ConfigurableTask`` subclass for ``config`` and register it.

    A new class named ``<task>ConfigurableTask`` is created whose ``CONFIG``
    attribute is ``TaskConfig(**config)``. It is registered as a task under
    ``config["task"]`` (when that key is present) and as a member of every
    group named in ``config["group"]`` (a single string or a list of strings;
    a missing or ``None`` value means no groups).

    Args:
        config: Task configuration mapping, e.g. as loaded from a task YAML.

    Returns:
        Always ``0``.

    Raises:
        ValueError: If the task name is also one of its group names. This is
            checked before anything is registered.
    """
    task_name = config.get("task")

    # Normalise ``group`` to a list so one name and a list of names are
    # handled identically.
    group_names = config.get("group")
    if group_names is None:
        group_names = []
    elif isinstance(group_names, str):
        group_names = [group_names]

    # Validate up front so a bad config does not leave a half-registered task.
    if task_name is not None and task_name in group_names:
        raise ValueError("task and group name cannot be the same")

    SubClass = type(
        f"{task_name or ''}ConfigurableTask",
        (ConfigurableTask,),
        {"CONFIG": TaskConfig(**config)},
    )

    if "task" in config:
        register_task(str(task_name))(SubClass)

    for group in group_names:
        register_group(group)(SubClass)

    return 0

lintangsutawika's avatar
format  
lintangsutawika committed
63

lintangsutawika's avatar
lintangsutawika committed
64
def register_configurable_group(config: Dict[str, str], yaml_path: str = None) -> int:
    """Register a group config and every entry listed under ``config["task"]``.

    Each entry of ``config["task"]`` is one of:

    * a string -- a name or pattern, resolved with ``utils.pattern_match``
      against ``ALL_TASKS``; matches that are registered tasks or groups are
      added to this group;
    * a mapping with exactly the keys ``group`` and ``task`` (in any order) --
      a nested group config, registered recursively and added as a sub-group;
    * any other mapping -- an inline task config. If it names a task already
      in ``TASK_REGISTRY``, that task's config is used as the base and the
      copy is renamed ``<group>_<task>``. The result is tagged with this
      group, expanded into prompt variants and registered.

    Args:
        config: Group config with a ``group`` name and a ``task`` list.
        yaml_path: Path of the YAML file the config came from; passed to
            ``utils.load_yaml_config`` and used as the prompt directory.
            NOTE(review): inline task configs call ``os.path.dirname`` on it,
            so the ``None`` default only works when there are none.

    Returns:
        Always ``0``.
    """
    group = config["group"]

    # NOTE(review): development allow-list -- every other group is silently
    # skipped. Remove once group loading is finished.
    if group not in ["grouptest", "arc_stuff"]:
        return 0

    def _add_to_group(member: str) -> None:
        # Append ``member`` to this group's registry entry, creating the entry
        # (and adding the group to ALL_TASKS) on first use.
        if group in GROUP_REGISTRY:
            GROUP_REGISTRY[group].append(member)
        else:
            GROUP_REGISTRY[group] = [member]
            ALL_TASKS.add(group)

    task_config_list = []
    group_config_list = []
    registered_task_or_group_list = []
    for entry in config["task"]:
        if isinstance(entry, str):
            registered_task_or_group_list.append(entry)
        elif set(entry) == {"group", "task"}:
            # Compare as a set so YAML key order does not matter.
            group_config_list.append(entry)
        else:
            task_config_list.append(entry)

    for task_config in task_config_list:
        base_config = {}
        task_name_config = {}
        if "task" in task_config:
            task_name = task_config["task"]
            if task_name in TASK_REGISTRY:
                task_obj = get_task_dict(task_name)[task_name]
                if isinstance(task_obj, tuple):
                    _, task_obj = task_obj

                if task_obj is not None:
                    base_config = task_obj._config.to_dict(keep_callable=True)
                    task_name_config["task"] = f"{group}_{task_name}"

        task_config = utils.load_yaml_config(yaml_path, task_config)
        # Later mappings win: inline values override the registered config,
        # and the group tag / renamed task override both.
        var_configs = check_prompt_config(
            {
                **base_config,
                **task_config,
                **{"group": group},
                **task_name_config,
            },
            yaml_path=os.path.dirname(yaml_path),
        )
        for var_config in var_configs:
            register_configurable_task(var_config)

    for group_config in group_config_list:
        register_configurable_group(group_config, yaml_path)
        _add_to_group(group_config["group"])

    task_names = utils.pattern_match(registered_task_or_group_list, ALL_TASKS)
    for task in task_names:
        if (task in TASK_REGISTRY) or (task in GROUP_REGISTRY):
            _add_to_group(task)

    return 0

128

lintangsutawika's avatar
lintangsutawika committed
129
130
131
def check_prompt_config(
    config: Dict[str, str], yaml_path: str = None
) -> List[Dict[str, str]]:
    """Expand ``config`` into one config per prompt variant.

    Without a ``use_prompt`` key the result is ``[config]`` (the same object,
    not a copy). Otherwise the prompts selected by ``use_prompt`` are loaded
    via ``prompts.load_prompt_list``. For each prompt, a new dict is built
    that copies ``config`` and then sets:

    * ``use_prompt`` to that single prompt;
    * ``task`` to ``"<base>_<prompt id>"``. Here ``<base>`` is
      ``config["task"]`` or, if absent, the name derived from the dataset
      fields. ``<prompt id>`` is the last ``/`` component for prompts
      containing ``.yaml``, and the prompt string itself otherwise;
    * ``output_type`` to ``"generate_until"``.
    """
    if "use_prompt" not in config:
        return [config]

    prompt_list = prompts.load_prompt_list(
        use_prompt=config["use_prompt"],
        dataset_name=config["dataset_path"],
        subset_name=config.get("dataset_name", None),
        yaml_path=yaml_path,
    )

    variants = []
    for prompt in prompt_list:
        if "task" in config:
            base_name = config["task"]
        else:
            base_name = get_task_name_from_config(config)
        prompt_id = prompt.split("/")[-1] if ".yaml" in prompt else prompt
        variants.append(
            {
                **config,
                "use_prompt": prompt,
                "task": "_".join([base_name, prompt_id]),
                "output_type": "generate_until",
            }
        )
    return variants


165
def get_task_name_from_config(task_config: Dict[str, str]) -> str:
    """Derive a task name from the dataset fields of ``task_config``.

    Returns ``"<dataset_path>_<dataset_name>"`` when the ``dataset_name`` key
    is present (even if its value is ``None``), else ``"<dataset_path>"``.
    A missing ``dataset_path`` raises ``KeyError``.
    """
    template = (
        "{dataset_path}_{dataset_name}"
        if "dataset_name" in task_config
        else "{dataset_path}"
    )
    return template.format(**task_config)
170
171


lintangsutawika's avatar
lintangsutawika committed
172
def include_task_folder(task_dir: str, register_task: bool = True, task_name: str = None) -> int:
    """Walk ``task_dir`` recursively and register the task YAMLs found there.

    Each ``*.yaml`` file with a ``task`` key is loaded with
    ``utils.load_yaml_config`` and expanded into prompt variants with
    ``check_prompt_config``. Then, depending on ``register_task``:

    * ``True``  -- variants whose ``task`` is a string are registered as tasks;
    * ``False`` -- variants whose ``task`` is a list are registered as groups.

    Import errors are logged at DEBUG level, and a single summary warning is
    emitted at the end. Any other error is logged as a warning with its
    traceback. Neither stops the walk.

    Args:
        task_dir: Root directory to search.
        register_task: Selects the pass described above. NOTE: inside this
            body the parameter shadows the module-level ``register_task``.
        task_name: Unused. NOTE(review): kept for interface compatibility.

    Returns:
        Always ``0``.
    """
    # Track whether any tasks failed during loading
    import_fail = False
    for root, _, file_list in os.walk(task_dir):
        for f in file_list:
            if not f.endswith(".yaml"):
                continue
            yaml_path = os.path.join(root, f)
            try:
                config = utils.load_yaml_config(yaml_path)

                # A YAML without a ``task`` key is not registrable on its own.
                if "task" not in config:
                    continue

                all_configs = check_prompt_config(
                    config, yaml_path=os.path.dirname(yaml_path)
                )
                for variant in all_configs:
                    if register_task:
                        if isinstance(variant["task"], str):
                            register_configurable_task(variant)
                    elif isinstance(variant["task"], list):
                        register_configurable_group(variant, yaml_path)

            # Log this silently and show it only when
            # the user defines the appropriate verbosity.
            except ImportError as e:  # ModuleNotFoundError is a subclass
                import_fail = True
                eval_logger.debug(
                    f"{yaml_path}: {e}. Config will not be added to registry."
                )
            except Exception as error:
                import traceback

                eval_logger.warning(
                    "Unexpected error loading config in\n"
                    f"                                 {yaml_path}\n"
                    "                                 Config will not be added to registry\n"
                    f"                                 Error: {error}\n"
                    f"                                 Traceback: {traceback.format_exc()}"
                )

    if import_fail:
        eval_logger.warning(
            "Some tasks could not be loaded due to missing dependencies."
            " Run with `--verbosity DEBUG` for full details."
        )
    return 0
225
226


227
def get_task_and_group(task_dir: str):
    """Index the tasks and groups defined by the YAML files under ``task_dir``.

    Returns a dict keyed by task/group name, with values of these forms:

    * task: ``{"type": "task", "yaml_path": <path>}``;
    * group defined by its own YAML: ``{"type": "group", "task": -1,
      "yaml_path": <path>}``. A group YAML is one with a ``group`` key and
      either exactly the keys ``group``/``task`` or a list-valued ``task``.
      ``-1`` means the member list is not indexed here; it is loaded from the
      YAML when needed;
    * group known only from tasks' ``group`` field:
      ``{"type": "group", "task": [<task names>], "yaml_path": -1}``.

    Files that do not parse to a dict with a ``task`` key are skipped.
    """
    tasks_and_groups = {}
    for root, _, file_list in os.walk(task_dir):
        for f in file_list:
            if not f.endswith(".yaml"):
                continue
            yaml_path = os.path.join(root, f)
            config = utils.simple_load_yaml_config(yaml_path)
            # Same rule as include_task_folder: no ``task`` key, nothing to index.
            if not isinstance(config, dict) or "task" not in config:
                continue

            task = config["task"]
            if "group" in config and (
                set(config) == {"group", "task"} or isinstance(task, list)
            ):
                # This is a group config
                tasks_and_groups[config["group"]] = {
                    "type": "group",
                    # -1 signals that the task list is not needed for
                    # indexing; it can be loaded when called.
                    "task": -1,
                    "yaml_path": yaml_path,
                }
                continue

            if isinstance(task, list):
                # A task list without a group name has no usable key.
                continue

            # This is a task config
            tasks_and_groups[task] = {
                "type": "task",
                "yaml_path": yaml_path,
            }

            groups = config.get("group")
            if groups is None:
                groups = []
            elif isinstance(groups, str):
                groups = [groups]

            for group in groups:
                entry = tasks_and_groups.get(group)
                if entry is None:
                    tasks_and_groups[group] = {
                        "type": "group",
                        "task": [task],
                        "yaml_path": -1,
                    }
                elif isinstance(entry.get("task"), list):
                    entry["task"].append(task)
                # Otherwise the name is a group defined by its own YAML
                # ("task" == -1, loaded on demand) or a task; there is no
                # member list to extend.

    return tasks_and_groups
269
270

def initialize_tasks(verbosity="INFO", include_path=None):
    """Set the logger level and index the tasks/groups in the task directories.

    Args:
        verbosity: Name of a ``logging`` level such as ``"INFO"`` or
            ``"DEBUG"``; matched case-insensitively.
        include_path: Extra directory, or list of directories, searched after
            the directory that contains this module.

    Returns:
        Mapping of task/group name to index entry, as produced by
        ``get_task_and_group``. If a name occurs in several directories, the
        entry from the earliest directory wins, so built-in tasks take
        precedence over those found via ``include_path``.
    """
    eval_logger.setLevel(getattr(logging, str(verbosity).upper()))

    all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"]
    if include_path is not None:
        if isinstance(include_path, str):
            include_path = [include_path]
        all_paths.extend(include_path)

    # Named ``task_index`` so it does not shadow the ``ALL_TASKS`` registry
    # imported at module level.
    task_index = {}
    for task_dir in all_paths:
        # Entries already collected (from earlier directories) win.
        task_index = {**get_task_and_group(task_dir), **task_index}

    return task_index
lintangsutawika's avatar
lintangsutawika committed
284

285
286
def get_task(task_name, config):
    """Instantiate the task registered as ``task_name`` with ``config``.

    Raises:
        KeyError: If ``task_name`` is not in ``TASK_REGISTRY``. The available
            task and group names are logged at INFO level first. KeyErrors
            raised by the task's own constructor propagate unchanged.
    """
    try:
        task_class = TASK_REGISTRY[task_name]
    except KeyError:
        eval_logger.info("Available tasks:")
        eval_logger.info(list(TASK_REGISTRY) + list(GROUP_REGISTRY))
        raise KeyError(f"Missing task {task_name}") from None
    # Construct outside the ``try`` so a KeyError raised while building the
    # task (e.g. a missing config field) is not reported as "Missing task".
    return task_class(config=config)


def get_task_name_from_object(task_object):
    """Return the registry name for ``task_object`` (a task class or instance).

    Looks for a ``TASK_REGISTRY`` entry that either is ``task_object`` itself
    or is exactly its type. Objects that match no entry fall back to their
    ``EVAL_HARNESS_NAME`` attribute, and failing that to their class name.
    """
    for name, class_ in TASK_REGISTRY.items():
        # get_task_dict passes Task *instances*, which are never identical to
        # a registered class, so also match on the instance's exact type.
        if class_ is task_object or type(task_object) is class_:
            return name

    # TODO: scrap this
    # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
    return (
        task_object.EVAL_HARNESS_NAME
        if hasattr(task_object, "EVAL_HARNESS_NAME")
        else type(task_object).__name__
    )


# TODO: pass num_fewshot and other cmdline overrides in a better way
309
def get_task_dict(task_name_list: List[Union[str, Dict, Task]], **kwargs):
    """Resolve task names, inline configs and task objects into a task dict.

    Args:
        task_name_list: A single element or a list of elements. Each element
            is handled according to its type:

            * a name in ``GROUP_REGISTRY`` -- expanded recursively. Each
              member maps to ``(group_name, task)``, or to
              ``(group_name, None)`` when the member is itself a group (whose
              own members are then merged into the result too);
            * any other string -- instantiated with ``get_task``;
            * a dict -- built as a ``ConfigurableTask`` keyed by
              ``get_task_name_from_config``. The caller's dict is not mutated;
            * a ``Task`` instance -- keyed by ``get_task_name_from_object``.

            Elements of any other type are ignored.
        **kwargs: Overrides merged into the config of every task built here,
            including members of expanded groups.

    Returns:
        Dict of name -> task object (or ``(group, task)`` tuple). Later
        sources win: dict-config entries override registry entries, and
        ``Task``-object entries override both. Registry and object names must
        not overlap (checked with ``assert``).

    Raises:
        KeyError: If a string is neither a registered task nor a group.
    """
    config = {**kwargs}

    task_name_from_registry_dict = {}
    task_name_from_config_dict = {}
    task_name_from_object_dict = {}

    if not isinstance(task_name_list, list):
        task_name_list = [task_name_list]

    for task_element in task_name_list:
        if isinstance(task_element, str):
            if task_element in GROUP_REGISTRY:
                group_name = task_element
                for task_name in GROUP_REGISTRY[group_name]:
                    if task_name in task_name_from_registry_dict:
                        continue
                    # Forward the overrides so group members are configured
                    # like tasks requested directly by name.
                    task_obj = get_task_dict(task_name, **kwargs)
                    if task_name in task_obj:
                        task_dict = {task_name: (group_name, task_obj[task_name])}
                    else:
                        # ``task_name`` is a sub-group: record it without a
                        # task object and pull in its members.
                        task_dict = {task_name: (group_name, None), **task_obj}
                    task_name_from_registry_dict.update(task_dict)
            elif task_element not in task_name_from_registry_dict:
                task_name_from_registry_dict[task_element] = get_task(
                    task_name=task_element, config=config
                )

        elif isinstance(task_element, dict):
            # Merge into a copy so the caller's dict is left untouched.
            task_config = {**task_element, **config}
            task_name_from_config_dict[
                get_task_name_from_config(task_config)
            ] = ConfigurableTask(config=task_config)

        elif isinstance(task_element, Task):
            task_name_from_object_dict[
                get_task_name_from_object(task_element)
            ] = task_element

    assert set(task_name_from_registry_dict).isdisjoint(task_name_from_object_dict)

    return {
        **task_name_from_registry_dict,
        **task_name_from_config_dict,
        **task_name_from_object_dict,
    }