__init__.py 17.4 KB
Newer Older
1
import os
2
import abc
lintangsutawika's avatar
lintangsutawika committed
3
import yaml
4
import collections
5
from typing import List, Union, Dict
&'s avatar
& committed
6

7
from lm_eval import utils
8
from lm_eval import prompts
9
from lm_eval.api.task import TaskConfig, Task, ConfigurableTask
lintangsutawika's avatar
lintangsutawika committed
10
from lm_eval.api.registry import (
11
12
    register_task,
    register_group,
lintangsutawika's avatar
lintangsutawika committed
13
14
    TASK_REGISTRY,
    GROUP_REGISTRY,
lintangsutawika's avatar
lintangsutawika committed
15
)
lintangsutawika's avatar
lintangsutawika committed
16

lintangsutawika's avatar
lintangsutawika committed
17
import logging
lintangsutawika's avatar
lintangsutawika committed
18

lintangsutawika's avatar
lintangsutawika committed
19
# import python tasks
lintangsutawika's avatar
lintangsutawika committed
20
from .squadv2.task import SQuAD2
lintangsutawika's avatar
lintangsutawika committed
21
22
23
24
25
26
27
28
from .scrolls.task import (
    QuALITY,
    NarrativeQA,
    ContractNLI,
    GovReport,
    SummScreenFD,
    QMSum,
)
lintangsutawika's avatar
lintangsutawika committed
29

30
# Module-level alias for the logger exposed by lm_eval.utils; the functions
# below log through this name.
eval_logger = utils.eval_logger
31

lintangsutawika's avatar
lintangsutawika committed
32

33
34
35
36
37
def is_group(task):
    """Return True when ``task`` is an inline group config.

    A group config is a mapping whose keys are exactly ``"group"`` and
    ``"task"``. The comparison ignores key order, so a config written as
    ``{"task": [...], "group": "name"}`` is recognised as well.
    """
    return set(task.keys()) == {"group", "task"}

38
class TaskManager(abc.ABC):
39

40
41
42
43
44
    def __init__(
        self, 
        verbosity="INFO",
        include_path=None
        ) -> None:
45

lintangsutawika's avatar
lintangsutawika committed
46
47
        self.verbosity = verbosity
        self.include_path = include_path
lintangsutawika's avatar
lintangsutawika committed
48
        self.logger = eval_logger.setLevel(getattr(logging, f"{verbosity}"))
lintangsutawika's avatar
lintangsutawika committed
49
50

        self.ALL_TASKS = self.initialize_tasks(
51
52
53
            include_path=include_path
            )

lintangsutawika's avatar
lintangsutawika committed
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
    def initialize_tasks(self, include_path=None):
        
        all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"]
        if include_path is not None:
            if isinstance(include_path, str):
                include_path = [include_path]
            all_paths.extend(include_path)

        ALL_TASKS = {}
        for task_dir in all_paths:
            tasks = get_task_and_group(task_dir)
            ALL_TASKS = {**tasks, **ALL_TASKS}

        return ALL_TASKS

69
    def all_tasks(self):
lintangsutawika's avatar
lintangsutawika committed
70
        return sorted(list(self.ALL_TASKS.keys()))
71
72
73
74
75
76
77
78
79
80

    def _load_individual_task_or_group(self, task_name_or_config: Union[str, dict] = None) -> ConfigurableTask:

        print("Loading", task_name_or_config)
        if isinstance(task_name_or_config, str):
            task_info = self.ALL_TASKS[task_name_or_config]
            yaml_path = task_info["yaml_path"]
            task_type = task_info["type"]
            subtask_list = task_info["task"] if "task" in task_info else -1
            if task_type == "task":
81
                task_config = utils.load_yaml_config(yaml_path)
82
                return ConfigurableTask(config=task_config)
83
            else:
84
85
86
87
                if subtask_list == -1:
                    task_config = utils.load_yaml_config(yaml_path)
                    group_name = task_config["group"]
                    subtask_list = task_config["task"]
88
                else:
89
90
91
92
93
94
                    group_name = task_name_or_config

                all_subtasks = {}
                for task_or_config in subtask_list:
                    if isinstance(task_or_config, str):
                        all_subtasks[task_or_config] = (group_name, None)
lintangsutawika's avatar
lintangsutawika committed
95
                        task_object = self._load_individual_task_or_group(task_name_or_config=task_or_config)
96
97
98
99
100
                    elif isinstance(task_or_config, dict):
                        if "group" in task_or_config:
                            all_subtasks[task_or_config["group"]] = (group_name, None)
                        elif "task" in task_or_config:
                            all_subtasks[task_or_config["task"]] = (group_name, None)
lintangsutawika's avatar
lintangsutawika committed
101
                        task_object = self._load_individual_task_or_group(task_name_or_config=task_or_config)
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119

                    if isinstance(task_object, dict):
                        all_subtasks = {**task_object, **all_subtasks}
                    else:
                        task_name = task_object._config["task"]
                        all_subtasks[task_name] = (group_name, task_object)
                        # if group_name is not None:
                        #     all_subtasks[task_name] = (group_name, task_object)
                        # else:
                        #     all_subtasks[task_name] = task_object
                return all_subtasks
        elif isinstance(task_name_or_config, dict):
            if is_group(task_name_or_config):
                group_name = task_name_or_config["group"]
                subtask_list = task_name_or_config["task"]
                all_subtasks = {}
                for task_or_config in subtask_list:
                    if isinstance(task_or_config, str):
lintangsutawika's avatar
lintangsutawika committed
120
                        task_object = self._load_individual_task_or_group(task_name_or_config=task_or_config)
121
122
                        task_name = task_or_config
                    elif isinstance(task_or_config, dict):
lintangsutawika's avatar
lintangsutawika committed
123
                        task_object = self._load_individual_task_or_group(task_name_or_config=task_or_config)
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153

                    if isinstance(task_object, dict):
                        all_subtasks = {**task_object, **all_subtasks}
                    else:
                        task_name = task_object._config["task"]
                        all_subtasks[task_name] = (group_name, task_object)
                return all_subtasks
            else:
                task_type = "task"
                task_name = task_name_or_config["task"]
                base_task_info = self.ALL_TASKS[task_name]
                base_yaml_path = base_task_info["yaml_path"]
                base_task_config = utils.load_yaml_config(base_yaml_path)

                return ConfigurableTask(
                    config={
                        **base_task_config,
                        **task_name_or_config,
                    }
                )

    def load_task_or_group(self, task_list: Union[str, list] = None) -> dict:

        if isinstance(task_list, str):
            task_list = [task_list]

        all_loaded_tasks = {}
        for task in task_list:
            task_object = self._load_individual_task_or_group(
                task_name_or_config=task,
154
            )
155
156
157
158
159
160
161
162
163
164
165
            if isinstance(task, str):
                task_name = task
            elif isinstance(task, dict):
                task_name = task["task"]

            if isinstance(task_object, dict):
                all_loaded_tasks = {**task_object, **self.ALL_TASKS}
            else:
                all_loaded_tasks[task_name] = task_object
        
        return all_loaded_tasks
166
167


168
def register_configurable_task(config: Dict[str, str]) -> int:
    """Create a ``ConfigurableTask`` subclass for ``config`` and register it.

    The subclass is named ``<task>ConfigurableTask``, carries the config as
    its ``CONFIG`` attribute, and is registered under ``config["task"]``.
    When ``config`` has a ``"group"`` entry (one name or a list of names),
    the subclass is registered under each of those groups too.

    Args:
        config: Task config; must contain ``"task"``.

    Returns:
        Always ``0``.

    Raises:
        KeyError: If ``config`` has no ``"task"`` key.
        ValueError: If the group name equals the task name.
    """
    task_name = config["task"]

    SubClass = type(
        task_name + "ConfigurableTask",
        (ConfigurableTask,),
        {"CONFIG": TaskConfig(**config)},
    )
    register_task("{}".format(task_name))(SubClass)

    if "group" in config:
        groups = config["group"]
        if groups == task_name:
            raise ValueError("task and group name cannot be the same")
        if isinstance(groups, str):
            groups = [groups]

        for group in groups:
            register_group(group)(SubClass)

    return 0

lintangsutawika's avatar
format  
lintangsutawika committed
192

lintangsutawika's avatar
lintangsutawika committed
193
def register_configurable_group(config: Dict[str, str], yaml_path: str = None) -> int:
    """Register a group config and everything listed under its ``"task"`` key.

    Each entry of ``config["task"]`` is handled by kind:

    * a string is treated as a pattern and matched against the names already
      present in ``TASK_REGISTRY`` / ``GROUP_REGISTRY``;
    * an inline group config (see ``is_group``) is registered recursively
      and recorded as a member of this group;
    * any other mapping is an inline task config; it is layered over the
      registered base task (if there is one), expanded by
      ``check_prompt_config`` and registered as ``<group>_<task>``.

    Args:
        config: Group config with ``"group"`` and ``"task"`` keys.
        yaml_path: Path of the YAML file the config came from.

    Returns:
        Always ``0``.
    """
    group = config["group"]

    def _add_member(member):
        # Record ``member`` under this group, creating the entry if needed.
        if group in GROUP_REGISTRY:
            GROUP_REGISTRY[group].append(member)
        else:
            GROUP_REGISTRY[group] = [member]

    task_config_list = []
    group_config_list = []
    registered_task_or_group_list = []
    for task in config["task"]:
        if isinstance(task, str):
            registered_task_or_group_list.append(task)
        elif is_group(task):
            group_config_list.append(task)
        else:
            task_config_list.append(task)

    for task_config in task_config_list:
        base_config = {}
        task_name_config = {}
        if "task" in task_config:
            task_name = task_config["task"]
            if task_name in TASK_REGISTRY:
                task_obj = get_task_dict(task_name)[task_name]
                # Group lookups yield (group_name, task) tuples.
                if isinstance(task_obj, tuple):
                    _, task_obj = task_obj

                if task_obj is not None:
                    base_config = task_obj._config.to_dict(keep_callable=True)
                    task_name_config["task"] = f"{group}_{task_name}"

        task_config = utils.load_yaml_config(yaml_path, task_config)
        var_configs = check_prompt_config(
            {
                **base_config,
                **task_config,
                **{"group": group},
                **task_name_config,
            },
            yaml_path=os.path.dirname(yaml_path),
        )
        for variant_config in var_configs:
            register_configurable_task(variant_config)

    for group_config in group_config_list:
        sub_group = group_config["group"]
        register_configurable_group(group_config, yaml_path)
        _add_member(sub_group)

    # This is a module-level function, so there is no ``self`` to read a task
    # index from; match the patterns against the registered names instead.
    known_names = list(TASK_REGISTRY) + list(GROUP_REGISTRY)
    task_names = utils.pattern_match(registered_task_or_group_list, known_names)
    for task in task_names:
        if (task in TASK_REGISTRY) or (task in GROUP_REGISTRY):
            _add_member(task)

    return 0

254

lintangsutawika's avatar
lintangsutawika committed
255
256
257
def check_prompt_config(
    config: Dict[str, str], yaml_path: str = None
) -> List[Dict[str, str]]:
    """Expand a config into one config per prompt variation.

    A config without ``"use_prompt"`` is returned unchanged as the only
    element of the list. Otherwise the prompt list is resolved through
    ``prompts.load_prompt_list`` and, for every variation, a copy of the
    config is produced with ``"use_prompt"`` set to that variation,
    ``"task"`` set to ``<task name>_<variation>`` and ``"output_type"`` set
    to ``"generate_until"``.
    """
    if "use_prompt" not in config:
        return [config]

    subset = config["dataset_name"] if "dataset_name" in config else None
    variations = prompts.load_prompt_list(
        use_prompt=config["use_prompt"],
        dataset_name=config["dataset_path"],
        subset_name=subset,
        yaml_path=yaml_path,
    )

    expanded = []
    for variation in variations:
        if "task" in config:
            base_name = config["task"]
        else:
            base_name = get_task_name_from_config(config)

        # For a path to a YAML prompt file, only the last path component is
        # used as the suffix.
        if ".yaml" in variation:
            suffix = variation.split("/")[-1]
        else:
            suffix = variation

        expanded.append(
            {
                **config,
                "use_prompt": variation,
                "task": "_".join([base_name, suffix]),
                "output_type": "generate_until",
            }
        )
    return expanded


291
def get_task_name_from_config(task_config: Dict[str, str]) -> str:
    """Build a task name from the dataset fields of ``task_config``.

    Returns ``<dataset_path>_<dataset_name>`` when ``"dataset_name"`` is
    present, otherwise just ``<dataset_path>``.
    """
    template = (
        "{dataset_path}_{dataset_name}"
        if "dataset_name" in task_config
        else "{dataset_path}"
    )
    return template.format(**task_config)
296
297


lintangsutawika's avatar
lintangsutawika committed
298
def include_task_folder(task_dir: str, register_task: bool = True, task_name: str = None) -> None:
    """
    Walk ``task_dir`` and register the configs found in its ``.yaml`` files.

    Every YAML config that has a ``"task"`` key is expanded with
    ``check_prompt_config``. With ``register_task`` true, configs whose
    ``"task"`` is a string are registered as tasks; with it false, configs
    whose ``"task"`` is a list are registered as groups. Import errors are
    logged at debug level and summarised in one warning at the end; any
    other error is logged as a warning with its traceback, and the file is
    skipped. ``task_name`` is not used. The function returns ``0``.
    """
    import traceback

    # Set when at least one config could not be loaded because of an import error.
    missing_dependency = False

    for root, _subdirs, file_names in os.walk(task_dir):
        for file_name in file_names:
            if not file_name.endswith(".yaml"):
                continue

            yaml_path = os.path.join(root, file_name)
            try:
                loaded = utils.load_yaml_config(yaml_path)
                if "task" not in loaded:
                    continue

                variants = check_prompt_config(
                    loaded, yaml_path=os.path.dirname(yaml_path)
                )
                for variant in variants:
                    task_field_type = type(variant["task"])
                    if register_task:
                        if task_field_type is str:
                            register_configurable_task(variant)
                    elif task_field_type is list:
                        register_configurable_group(variant, yaml_path)

            # Log this silently and show it only when
            # the user defines the appropriate verbosity.
            except (ImportError, ModuleNotFoundError) as import_error:
                missing_dependency = True
                eval_logger.debug(
                    f"{yaml_path}: {import_error}. Config will not be added to registry."
                )
            except Exception as exc:
                eval_logger.warning(
                    "Unexpected error loading config in\n"
                    f"                                 {yaml_path}\n"
                    "                                 Config will not be added to registry\n"
                    f"                                 Error: {exc}\n"
                    f"                                 Traceback: {traceback.format_exc()}"
                )

    if missing_dependency:
        eval_logger.warning(
            "Some tasks could not be loaded due to missing dependencies."
            " Run with `--verbosity DEBUG` for full details."
        )
    return 0
351
352


353
def get_task_and_group(task_dir: str):
    """Index every task and group declared by the ``.yaml`` files in ``task_dir``.

    The directory is walked recursively and each YAML file is read with
    ``utils.simple_load_yaml_config``.

    * A config whose keys are exactly ``"group"`` and ``"task"`` is indexed
      as a group backed by that file; its ``"task"`` entry is ``-1`` because
      the member list can be read from the file when the group is loaded.
    * Any other config with a ``"task"`` key is indexed as a task. Each
      group named in its ``"group"`` field (one name or a list) that is not
      indexed yet gets an implicit entry with ``"yaml_path": -1`` and a
      member list; later tasks naming the same group are appended to it.
    * Configs without a ``"task"`` key are skipped.

    Returns:
        Mapping from name to an info dict with ``"type"``, ``"yaml_path"``
        and, for groups, ``"task"``.
    """
    # No default factory is given, so this behaves like a plain dict; the
    # concrete type is kept so the return value is unchanged for callers.
    tasks_and_groups = collections.defaultdict()

    for root, _, file_list in os.walk(task_dir):
        for f in file_list:
            if not f.endswith(".yaml"):
                continue

            yaml_path = os.path.join(root, f)
            config = utils.simple_load_yaml_config(yaml_path)

            # Compare as a set so the key order in the file does not matter.
            if set(config.keys()) == {"group", "task"}:
                tasks_and_groups[config["group"]] = {
                    "type": "group",
                    "task": -1,
                    "yaml_path": yaml_path,
                }
                continue

            if "task" not in config:
                continue

            task = config["task"]
            tasks_and_groups[task] = {
                "type": "task",
                "yaml_path": yaml_path,
            }

            if "group" not in config:
                continue

            groups = config["group"]
            if isinstance(groups, str):
                groups = [groups]

            for group in groups:
                entry = tasks_and_groups.get(group)
                if entry is None:
                    tasks_and_groups[group] = {
                        "type": "group",
                        "task": [task],
                        "yaml_path": -1,
                    }
                elif isinstance(entry.get("task"), list):
                    entry["task"].append(task)
                # Otherwise the name is already indexed from its own YAML
                # file (member list stored as -1) or as a task, and there is
                # no list to append to.

    return tasks_and_groups
395
396


lintangsutawika's avatar
lintangsutawika committed
397

398
399
def get_task(task_name, config):
    """Instantiate the task registered under ``task_name``.

    Args:
        task_name: Key to look up in ``TASK_REGISTRY``.
        config: Passed to the task class as the ``config`` keyword.

    Raises:
        KeyError: If ``task_name`` is not registered; the registered task
            and group names are logged at info level first.
    """
    # Only the registry lookup is guarded, so a KeyError raised while the
    # task is being constructed is not reported as a missing task.
    try:
        task_class = TASK_REGISTRY[task_name]
    except KeyError:
        eval_logger.info("Available tasks:")
        eval_logger.info(list(TASK_REGISTRY) + list(GROUP_REGISTRY))
        raise KeyError(f"Missing task {task_name}")
    return task_class(config=config)


def get_task_name_from_object(task_object):
    """Return the name to report for ``task_object``.

    The name under which ``task_object`` itself is stored in
    ``TASK_REGISTRY`` is used when there is one. Otherwise the object's
    ``EVAL_HARNESS_NAME`` attribute is used if present, falling back to the
    name of its type.
    """
    not_found = object()
    registered_name = next(
        (name for name, class_ in TASK_REGISTRY.items() if class_ is task_object),
        not_found,
    )
    if registered_name is not not_found:
        return registered_name

    # TODO: scrap this
    # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
    if hasattr(task_object, "EVAL_HARNESS_NAME"):
        return task_object.EVAL_HARNESS_NAME
    return type(task_object).__name__


# TODO: pass num_fewshot and other cmdline overrides in a better way
422
def get_task_dict(task_name_list: List[Union[str, Dict, Task]], **kwargs):
    """Build a dict of task objects from names, inline configs and instances.

    Args:
        task_name_list: A single element or a list of elements. Each element
            is a registered task or group name (str), an inline task config
            (dict), or a ``Task`` instance.
        **kwargs: Config overrides. They are passed to tasks created from a
            task name and merged over inline configs.

    Returns:
        Dict keyed by task name. Tasks requested by name, config or instance
        map to the task object. Members of a requested group map to
        ``(group_name, task)``, or ``(group_name, None)`` for a member that
        is itself a group, alongside the entries of that nested group.

    Raises:
        KeyError: If a name is in neither registry (see ``get_task``).
        AssertionError: If a name obtained from the registries clashes with
            the name of a passed ``Task`` instance.
    """
    config = {**kwargs}

    task_name_from_registry_dict = {}
    task_name_from_config_dict = {}
    task_name_from_object_dict = {}

    if not isinstance(task_name_list, list):
        task_name_list = [task_name_list]

    for task_element in task_name_list:
        if isinstance(task_element, str):
            if task_element in GROUP_REGISTRY:
                group_name = task_element
                for task_name in GROUP_REGISTRY[task_element]:
                    if task_name in task_name_from_registry_dict:
                        continue

                    # NOTE(review): the overrides in ``kwargs`` are not
                    # forwarded to group members here - confirm whether that
                    # is intended.
                    task_obj = get_task_dict(task_name)
                    if task_name in task_obj:
                        task_dict = {
                            task_name: (group_name, task_obj[task_name]),
                        }
                    else:
                        # ``task_name`` is a nested group: mark it and keep
                        # the entries of its own members.
                        task_dict = {
                            task_name: (group_name, None),
                            **task_obj,
                        }
                    task_name_from_registry_dict.update(task_dict)
            else:
                task_name = task_element
                if task_name not in task_name_from_registry_dict:
                    task_name_from_registry_dict[task_name] = get_task(
                        task_name=task_element, config=config
                    )

        elif isinstance(task_element, dict):
            # Merge into a copy so the caller's dict is left untouched.
            merged_config = {**task_element, **config}
            config_task_name = get_task_name_from_config(merged_config)
            task_name_from_config_dict[config_task_name] = ConfigurableTask(
                config=merged_config
            )

        elif isinstance(task_element, Task):
            object_task_name = get_task_name_from_object(task_element)
            task_name_from_object_dict[object_task_name] = task_element

    assert set(task_name_from_registry_dict.keys()).isdisjoint(
        set(task_name_from_object_dict.keys())
    )
    return {
        **task_name_from_registry_dict,
        **task_name_from_config_dict,
        **task_name_from_object_dict,
    }