__init__.py 15.6 KB
Newer Older
1
import os
lintangsutawika's avatar
lintangsutawika committed
2
import yaml
3
import collections
4
from typing import List, Union, Dict
&'s avatar
& committed
5

6
from lm_eval import utils
7
from lm_eval import prompts
8
from lm_eval.api.task import TaskConfig, Task, ConfigurableTask
lintangsutawika's avatar
lintangsutawika committed
9
from lm_eval.api.registry import (
10
11
    register_task,
    register_group,
lintangsutawika's avatar
lintangsutawika committed
12
13
    TASK_REGISTRY,
    GROUP_REGISTRY,
haileyschoelkopf's avatar
haileyschoelkopf committed
14
    ALL_TASKS,
lintangsutawika's avatar
lintangsutawika committed
15
)
lintangsutawika's avatar
lintangsutawika committed
16

lintangsutawika's avatar
lintangsutawika committed
17
import logging
lintangsutawika's avatar
lintangsutawika committed
18

lintangsutawika's avatar
lintangsutawika committed
19
# import python tasks
lintangsutawika's avatar
lintangsutawika committed
20
from .squadv2.task import SQuAD2
lintangsutawika's avatar
lintangsutawika committed
21
22
23
24
25
26
27
28
from .scrolls.task import (
    QuALITY,
    NarrativeQA,
    ContractNLI,
    GovReport,
    SummScreenFD,
    QMSum,
)
lintangsutawika's avatar
lintangsutawika committed
29

30
# Module-level logger: the same object as ``utils.eval_logger``; every log
# call in this module (and the level set by ``initialize_tasks``) goes here.
eval_logger = utils.eval_logger
31

lintangsutawika's avatar
lintangsutawika committed
32

33
34
35
36
37
38
39
40
41
def is_group(task):
    """Tell whether ``task`` is a bare group config.

    A group config carries exactly two keys, ``group`` followed by ``task``.
    The comparison is made on the ordered key sequence, so the key order
    matters and any extra key makes the config a task config instead.
    """
    key_sequence = tuple(task.keys())
    return key_sequence == ("group", "task")


def load_task_or_group(
    ALL_TASKS, task_name: str = None, task_config: dict = None
) -> Union[ConfigurableTask, dict]:
    """Build a task object, or a flat mapping of a group's subtasks.

    ``task_name`` is used when given; otherwise ``task_config`` must be set.

    Args:
        ALL_TASKS: Index of the shape built by ``get_task_and_group``: maps a
            task or group name to a dict with ``"type"`` (``"task"`` or
            ``"group"``), ``"yaml_path"`` and, for groups, ``"task"`` (a list
            of member names, or ``-1`` when the member list has to be read
            from the group's own YAML file).
        task_name: Name of an indexed task or group to load.
        task_config: Inline config. Either a group config (exactly the keys
            ``group`` and ``task``) or a task config. A task config whose
            ``task`` names an indexed task is layered over that task's YAML
            config. Otherwise it is loaded on its own.

    Returns:
        A ``ConfigurableTask`` for a single task. For a group, a dict that
        maps each subtask name to ``(parent_group_name, task_object)``.
        Inline sub-groups appear with ``None`` in place of the task object.

    Raises:
        KeyError: if ``task_name`` is not a key of ``ALL_TASKS``.
        TypeError: if a group's task list contains an entry that is neither
            a name nor a config dict.
    """
    if task_name is not None:
        task_info = ALL_TASKS[task_name]
        if task_info["type"] == "task":
            return ConfigurableTask(
                config=utils.load_yaml_config(task_info["yaml_path"])
            )

        subtask_list = task_info.get("task", -1)
        if subtask_list == -1:
            # Group defined by its own YAML file: the member list lives there.
            group_config = utils.load_yaml_config(task_info["yaml_path"])
            group_name = group_config["group"]
            subtask_list = group_config["task"]
        else:
            # Group assembled from the ``group:`` field of its member tasks.
            group_name = task_name
        return _load_group_subtasks(ALL_TASKS, group_name, subtask_list)

    assert task_config is not None
    if is_group(task_config):
        return _load_group_subtasks(
            ALL_TASKS, task_config["group"], task_config["task"]
        )

    # Inline task config: layer its overrides on top of the indexed task of
    # the same name when there is one; otherwise it stands on its own.
    base_task_info = ALL_TASKS.get(task_config["task"])
    if base_task_info is None:
        return ConfigurableTask(config={**task_config})
    base_task_config = utils.load_yaml_config(base_task_info["yaml_path"])
    return ConfigurableTask(config={**base_task_config, **task_config})


def _load_group_subtasks(ALL_TASKS, group_name: str, subtask_list) -> dict:
    """Load every entry of a group's ``task`` list into one flat mapping."""
    all_subtasks = {}
    for task_or_config in subtask_list:
        if isinstance(task_or_config, str):
            task_object = load_task_or_group(ALL_TASKS, task_name=task_or_config)
        elif isinstance(task_or_config, dict):
            if "group" in task_or_config:
                # Record the inline sub-group itself under its parent group.
                all_subtasks[task_or_config["group"]] = (group_name, None)
            task_object = load_task_or_group(ALL_TASKS, task_config=task_or_config)
        else:
            raise TypeError(
                f"Unsupported entry in task list of group '{group_name}': "
                f"{task_or_config!r}"
            )

        if isinstance(task_object, dict):
            # Nested group: entries already collected at this level win.
            all_subtasks = {**task_object, **all_subtasks}
        else:
            all_subtasks[task_object._config["task"]] = (group_name, task_object)
    return all_subtasks
110
111


112
def register_configurable_task(config: Dict[str, str]) -> int:
    """Create a ``ConfigurableTask`` subclass for ``config`` and register it.

    The subclass is registered under ``config["task"]``. It is also added to
    every group named by ``config["group"]``, which may be a single name or a
    list of names.

    Args:
        config: Task configuration. It must contain ``task`` and is passed to
            ``TaskConfig(**config)``.

    Returns:
        Always ``0``.

    Raises:
        KeyError: if ``config`` has no ``task`` entry.
        ValueError: if any group name equals the task name. This is checked
            before anything is registered, so a rejected config leaves the
            registries untouched.
    """
    task_name = "{}".format(config["task"])

    groups = config.get("group")
    if groups is None:
        groups = []
    elif isinstance(groups, str):
        groups = [groups]
    if config["task"] in groups:
        raise ValueError("task and group name cannot be the same")

    SubClass = type(
        config["task"] + "ConfigurableTask",
        (ConfigurableTask,),
        {"CONFIG": TaskConfig(**config)},
    )

    register_task(task_name)(SubClass)
    for group in groups:
        register_group(group)(SubClass)

    return 0

lintangsutawika's avatar
format  
lintangsutawika committed
136

lintangsutawika's avatar
lintangsutawika committed
137
def register_configurable_group(config: Dict[str, str], yaml_path: str = None) -> int:
    """Register a group config together with everything it declares.

    ``config["task"]`` may mix three kinds of entries:

    * strings: matched with ``utils.pattern_match`` against ``ALL_TASKS``.
      Matches already present in ``TASK_REGISTRY`` or ``GROUP_REGISTRY`` are
      appended to ``GROUP_REGISTRY[config["group"]]``.
    * nested group configs (exactly ``group`` + ``task`` keys): registered
      recursively with the same ``yaml_path`` and listed as sub-groups.
    * inline task configs: registered as new tasks of this group. When the
      ``task`` entry names a task already in ``TASK_REGISTRY``, that task's
      config (``to_dict(keep_callable=True)``) is used as the base and the
      new task is registered as ``<group>_<task>``.

    Args:
        config: The group config.
        yaml_path: Path of the YAML file the group came from. It is passed to
            ``utils.load_yaml_config`` with each inline task config, and its
            directory is passed to ``check_prompt_config``.

    Returns:
        Always ``0``.
    """
    group = config["group"]

    task_config_list = []
    group_config_list = []
    registered_task_or_group_list = []
    for entry in config["task"]:
        if isinstance(entry, str):
            registered_task_or_group_list.append(entry)
        elif is_group(entry):
            group_config_list.append(entry)
        else:
            task_config_list.append(entry)

    def _add_to_group(member):
        # The first member creates the group entry and indexes the group name.
        if group in GROUP_REGISTRY:
            GROUP_REGISTRY[group].append(member)
        else:
            GROUP_REGISTRY[group] = [member]
            ALL_TASKS.add(group)

    for task_config in task_config_list:
        base_config = {}
        task_name_config = {}
        if "task" in task_config:
            task_name = task_config["task"]
            if task_name in TASK_REGISTRY:
                task_obj = get_task_dict(task_name)[task_name]
                if isinstance(task_obj, tuple):
                    _, task_obj = task_obj

                if task_obj is not None:
                    base_config = task_obj._config.to_dict(keep_callable=True)
                    task_name_config["task"] = f"{group}_{task_name}"

        task_config = utils.load_yaml_config(yaml_path, task_config)
        var_configs = check_prompt_config(
            {
                **base_config,
                **task_config,
                "group": group,
                **task_name_config,
            },
            yaml_path=os.path.dirname(yaml_path),
        )
        # Distinct name so the ``config`` parameter is never rebound.
        for var_config in var_configs:
            register_configurable_task(var_config)

    for group_config in group_config_list:
        register_configurable_group(group_config, yaml_path)
        _add_to_group(group_config["group"])

    for name in utils.pattern_match(registered_task_or_group_list, ALL_TASKS):
        if name in TASK_REGISTRY or name in GROUP_REGISTRY:
            _add_to_group(name)

    return 0

198

lintangsutawika's avatar
lintangsutawika committed
199
200
201
def check_prompt_config(
    config: Dict[str, str], yaml_path: str = None
) -> List[Dict[str, str]]:
    """Expand a task config into one config per prompt variation.

    When ``config`` has no ``use_prompt`` entry, it is returned as the only
    element of the list (the same dict object, not a copy). Otherwise the
    prompts are resolved with ``prompts.load_prompt_list`` and one copy of
    the config is produced per prompt, in which:

    * ``use_prompt`` is set to that prompt;
    * ``task`` becomes ``<task>_<prompt name>``. The base is the config's
      ``task``, or ``get_task_name_from_config`` when it has none. For
      prompts containing ``.yaml``, the prompt name is the last path
      component;
    * ``output_type`` is forced to ``"generate_until"``.

    Args:
        config: Task config.
        yaml_path: Directory passed through to ``prompts.load_prompt_list``.

    Returns:
        The list of resulting configs.
    """
    if "use_prompt" not in config:
        return [config]

    prompt_list = prompts.load_prompt_list(
        use_prompt=config["use_prompt"],
        dataset_name=config["dataset_path"],
        subset_name=config.get("dataset_name"),
        yaml_path=yaml_path,
    )

    # The base task name is identical for every variation; compute it once.
    base_task_name = (
        config["task"] if "task" in config else get_task_name_from_config(config)
    )

    all_configs = []
    for prompt_variation in prompt_list:
        prompt_name = (
            prompt_variation.split("/")[-1]
            if ".yaml" in prompt_variation
            else prompt_variation
        )
        all_configs.append(
            {
                **config,
                "use_prompt": prompt_variation,
                "task": "_".join([base_task_name, prompt_name]),
                "output_type": "generate_until",
            }
        )
    return all_configs


235
def get_task_name_from_config(task_config: Dict[str, str]) -> str:
    """Derive a task name from a config's dataset fields.

    Returns ``"<dataset_path>_<dataset_name>"`` when ``dataset_name`` is set,
    otherwise just ``"<dataset_path>"``. A ``dataset_name`` of ``None`` is
    treated as unset, so it does not leak into the name as ``"_None"``.

    Raises:
        KeyError: if ``dataset_path`` is missing.
    """
    dataset_path = task_config["dataset_path"]
    dataset_name = task_config.get("dataset_name")
    if dataset_name is None:
        return f"{dataset_path}"
    return f"{dataset_path}_{dataset_name}"
240
241


lintangsutawika's avatar
lintangsutawika committed
242
def include_task_folder(
    task_dir: str, register_task: bool = True, task_name: str = None
) -> int:
    """Register the YAML task or group configs found under ``task_dir``.

    The directory is walked recursively. Each ``*.yaml`` file that loads to
    a config with a ``task`` entry is expanded by ``check_prompt_config``.
    Then:

    * with ``register_task=True``, configs whose ``task`` is a string are
      passed to ``register_configurable_task``;
    * with ``register_task=False``, configs whose ``task`` is a list are
      passed to ``register_configurable_group``.

    NOTE(review): presumably called once per mode so that groups can refer
    to already-registered tasks; confirm with callers.

    Errors are not raised. A config that fails with ``ImportError`` is
    logged at DEBUG level, and a single WARNING summary is emitted at the
    end. Any other error is logged at WARNING level with its traceback. In
    both cases the walk continues.

    Args:
        task_dir: Root directory to search.
        register_task: Selects task configs (``True``) or group configs
            (``False``).
        task_name: Unused; kept for backward compatibility.

    Returns:
        Always ``0``.
    """
    import traceback

    # Set when any config was skipped because of a missing dependency.
    import_fail = False
    for root, _, file_list in os.walk(task_dir):
        for f in file_list:
            if not f.endswith(".yaml"):
                continue
            yaml_path = os.path.join(root, f)
            try:
                config = utils.load_yaml_config(yaml_path)
                if "task" not in config:
                    continue

                all_configs = check_prompt_config(
                    config, yaml_path=os.path.dirname(yaml_path)
                )
                for expanded in all_configs:
                    if register_task:
                        if isinstance(expanded["task"], str):
                            register_configurable_task(expanded)
                    elif isinstance(expanded["task"], list):
                        register_configurable_group(expanded, yaml_path)

            # Log this silently and show it only when
            # the user defines the appropriate verbosity.
            # (ModuleNotFoundError is a subclass of ImportError.)
            except ImportError as e:
                import_fail = True
                eval_logger.debug(
                    f"{yaml_path}: {e}. Config will not be added to registry."
                )
            except Exception as error:
                eval_logger.warning(
                    "Unexpected error loading config in\n"
                    f"                                 {yaml_path}\n"
                    "                                 Config will not be added to registry\n"
                    f"                                 Error: {error}\n"
                    f"                                 Traceback: {traceback.format_exc()}"
                )

    if import_fail:
        eval_logger.warning(
            "Some tasks could not be loaded due to missing dependencies."
            " Run with `--verbosity DEBUG` for full details."
        )
    return 0
295
296


297
def get_task_and_group(task_dir: str):
    """Index the task and group YAML configs found under ``task_dir``.

    Returns a dict that maps each name to an entry of one of these shapes:

    * ``{"type": "task", "yaml_path": path}`` for a task config;
    * ``{"type": "group", "task": -1, "yaml_path": path}`` for a group
      defined by its own YAML file (exactly ``group`` + ``task`` keys).
      ``-1`` signals that the member list is read from that file when the
      group is loaded;
    * ``{"type": "group", "task": [names...], "yaml_path": -1}`` for a group
      that exists only through the ``group:`` field of its member tasks.

    A group's own YAML file takes precedence over memberships declared by
    tasks, regardless of the order in which ``os.walk`` visits the files.
    Previously, one visiting order crashed on ``(-1).append``. YAML files
    that do not load to a dict with a ``task`` entry are skipped, matching
    ``include_task_folder``.
    """
    tasks_and_groups = {}
    for root, _, file_list in os.walk(task_dir):
        for f in file_list:
            if not f.endswith(".yaml"):
                continue
            yaml_path = os.path.join(root, f)
            config = utils.simple_load_yaml_config(yaml_path)
            if not isinstance(config, dict) or "task" not in config:
                continue

            if is_group(config):
                # Group config: overrides any member list collected from tasks.
                tasks_and_groups[config["group"]] = {
                    "type": "group",
                    "task": -1,
                    "yaml_path": yaml_path,
                }
                continue

            # Task config.
            task = config["task"]
            tasks_and_groups[task] = {
                "type": "task",
                "yaml_path": yaml_path,
            }

            groups = config.get("group")
            if groups is None:
                groups = []
            elif isinstance(groups, str):
                groups = [groups]

            for group in groups:
                entry = tasks_and_groups.get(group)
                if entry is None:
                    tasks_and_groups[group] = {
                        "type": "group",
                        "task": [task],
                        "yaml_path": -1,
                    }
                elif isinstance(entry.get("task"), list):
                    entry["task"].append(task)
                # Otherwise the group has its own YAML file ("task": -1),
                # which defines the membership; nothing to record here.

    return tasks_and_groups
339
340

def initialize_tasks(verbosity="INFO", include_path=None):
    """Set the module log level and build the task/group index.

    Args:
        verbosity: Name of a ``logging`` level, case-insensitive (for
            example ``"INFO"`` or ``"debug"``).
        include_path: Extra directory, or list of directories, to index in
            addition to the directory containing this file.

    Returns:
        A dict that merges ``get_task_and_group`` over all directories. When a
        name occurs in several directories, the entry from the earliest
        directory is kept (the built-in directory comes first).
    """
    eval_logger.setLevel(getattr(logging, f"{verbosity}".upper()))

    all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"]
    if include_path is not None:
        if isinstance(include_path, str):
            include_path = [include_path]
        all_paths.extend(include_path)

    # Local name deliberately differs from the module-level ALL_TASKS registry.
    task_index = {}
    for task_dir in all_paths:
        # Earlier directories win: existing entries override the new ones.
        task_index = {**get_task_and_group(task_dir), **task_index}

    return task_index
lintangsutawika's avatar
lintangsutawika committed
354

355
356
def get_task(task_name, config):
    """Instantiate the task registered as ``task_name`` with ``config``.

    Raises:
        KeyError: if no task is registered under ``task_name``. The available
            task and group names are logged at INFO level first. Any error
            raised while constructing the task itself propagates unchanged.
    """
    try:
        task_class = TASK_REGISTRY[task_name]
    except KeyError:
        eval_logger.info("Available tasks:")
        eval_logger.info(list(TASK_REGISTRY) + list(GROUP_REGISTRY))
        raise KeyError(f"Missing task {task_name}") from None
    # Construct outside the ``try`` so that a KeyError raised inside the
    # task's own constructor is not misreported as a missing task.
    return task_class(config=config)


def get_task_name_from_object(task_object):
    """Return the name under which ``task_object`` should be reported.

    Search order:

    1. The first ``TASK_REGISTRY`` key whose value is ``task_object``
       itself. This uses identity, not equality.
    2. The object's ``EVAL_HARNESS_NAME`` attribute, if it has one.
    3. The name of the object's class.
    """
    not_found = object()
    registered_name = next(
        (key for key, registered in TASK_REGISTRY.items() if registered is task_object),
        not_found,
    )
    if registered_name is not not_found:
        return registered_name

    # TODO: scrap this fallback. It only exists so that tasks that were never
    # registered can still report a custom name.
    return getattr(task_object, "EVAL_HARNESS_NAME", type(task_object).__name__)


# TODO: pass num_fewshot and other cmdline overrides in a better way
379
def get_task_dict(task_name_list: List[Union[str, Dict, Task]], **kwargs):
    """Resolve task names, configs and task objects into a name -> task dict.

    ``task_name_list`` may be a list or a single element. Each element may be:

    * a name in ``GROUP_REGISTRY``: every member is resolved recursively, with
      the same ``kwargs``. A member that is a task is stored as
      ``member: (group_name, task)``. A member that is itself a group is
      stored as ``member: (group_name, None)``, followed by its own entries;
    * any other string: passed to ``get_task`` with ``kwargs`` as its config;
    * a dict: ``kwargs`` are layered over a copy of it (the caller's dict is
      not modified). It becomes a ``ConfigurableTask`` keyed by
      ``get_task_name_from_config``;
    * a ``Task`` instance: used as is, keyed by ``get_task_name_from_object``.

    Elements of any other type are ignored.

    Returns:
        The registry-resolved entries, then the config-built ones, then the
        task objects, merged into one dict.
    """
    config = {**kwargs}

    task_name_from_registry_dict = {}
    task_name_from_config_dict = {}
    task_name_from_object_dict = {}

    if not isinstance(task_name_list, list):
        task_name_list = [task_name_list]

    for task_element in task_name_list:
        if isinstance(task_element, str):
            if task_element in GROUP_REGISTRY:
                group_name = task_element
                for task_name in GROUP_REGISTRY[group_name]:
                    if task_name in task_name_from_registry_dict:
                        continue
                    # Forward the overrides so group members receive them too.
                    task_obj = get_task_dict(task_name, **kwargs)
                    if task_name in task_obj:
                        task_name_from_registry_dict[task_name] = (
                            group_name,
                            task_obj[task_name],
                        )
                    else:
                        # Member is itself a group: record it, then its members.
                        task_name_from_registry_dict[task_name] = (group_name, None)
                        task_name_from_registry_dict.update(task_obj)
            elif task_element not in task_name_from_registry_dict:
                task_name_from_registry_dict[task_element] = get_task(
                    task_name=task_element, config=config
                )

        elif isinstance(task_element, dict):
            # Layer the overrides onto a copy instead of mutating the caller's dict.
            element_config = {**task_element, **config}
            # Compute the name first, so a config without dataset_path fails
            # before any task is constructed.
            config_task_name = get_task_name_from_config(element_config)
            task_name_from_config_dict[config_task_name] = ConfigurableTask(
                config=element_config
            )

        elif isinstance(task_element, Task):
            object_task_name = get_task_name_from_object(task_element)
            task_name_from_object_dict[object_task_name] = task_element

    assert set(task_name_from_registry_dict).isdisjoint(task_name_from_object_dict)

    return {
        **task_name_from_registry_dict,
        **task_name_from_config_dict,
        **task_name_from_object_dict,
    }