__init__.py 17.5 KB
Newer Older
1
import os
2
import abc
lintangsutawika's avatar
lintangsutawika committed
3
import yaml
4
import collections
5
from typing import List, Union, Dict
&'s avatar
& committed
6

7
from lm_eval import utils
8
from lm_eval import prompts
9
from lm_eval.api.task import TaskConfig, Task, ConfigurableTask
lintangsutawika's avatar
lintangsutawika committed
10
from lm_eval.api.registry import (
11
12
    register_task,
    register_group,
lintangsutawika's avatar
lintangsutawika committed
13
14
    TASK_REGISTRY,
    GROUP_REGISTRY,
lintangsutawika's avatar
lintangsutawika committed
15
)
lintangsutawika's avatar
lintangsutawika committed
16

lintangsutawika's avatar
lintangsutawika committed
17
import logging
lintangsutawika's avatar
lintangsutawika committed
18

lintangsutawika's avatar
lintangsutawika committed
19
# import python tasks
lintangsutawika's avatar
lintangsutawika committed
20
from .squadv2.task import SQuAD2
lintangsutawika's avatar
lintangsutawika committed
21
22
23
24
25
26
27
28
from .scrolls.task import (
    QuALITY,
    NarrativeQA,
    ContractNLI,
    GovReport,
    SummScreenFD,
    QMSum,
)
lintangsutawika's avatar
lintangsutawika committed
29

30
eval_logger = utils.eval_logger
31

lintangsutawika's avatar
lintangsutawika committed
32

33
34
35
36
37
def is_group(task):
    """Return True when *task* is a bare group config.

    A group config is a mapping whose keys are exactly ``"group"`` and
    ``"task"``. The comparison is done on the key *set*, so the order in
    which the keys were written does not matter.
    """
    return set(task.keys()) == {"group", "task"}

38
class TaskManager(abc.ABC):
    """Index task/group YAML configs and load them on demand.

    ``ALL_TASKS`` maps every task or group name to an index entry produced
    by ``get_task_and_group`` (a dict with a ``"type"``, a ``"yaml_path"``
    and, for groups, a ``"task"`` member list or the ``-1`` sentinel).
    """

    def __init__(
        self,
        verbosity="INFO",
        include_path=None
        ) -> None:
        """Build the task index.

        Args:
            verbosity: name of a ``logging`` level (e.g. ``"INFO"``).
            include_path: optional extra directory, or list of directories,
                to index in addition to this package's own directory.
        """
        self.verbosity = verbosity
        self.include_path = include_path
        # The logger lives at module level; the instance has no
        # ``eval_logger`` attribute, so it must not be reached via ``self``.
        eval_logger.setLevel(getattr(logging, f"{verbosity}"))

        self.ALL_TASKS = self.initialize_tasks(
            include_path=include_path
            )

    def initialize_tasks(self, include_path=None):
        """Return the name -> index-entry mapping for all searched directories.

        The directory containing this module is always searched first;
        ``include_path`` (a string or a list of strings) is searched after.
        When the same name is found more than once, the entry from the
        earlier directory is kept.
        """
        search_dirs = [os.path.dirname(os.path.abspath(__file__)) + "/"]
        if include_path is not None:
            if isinstance(include_path, str):
                include_path = [include_path]
            search_dirs.extend(include_path)

        index = {}
        for task_dir in search_dirs:
            found = get_task_and_group(task_dir)
            # Existing entries are unpacked last so they take precedence.
            index = {**found, **index}

        return index

    @property
    def all_tasks(self):
        """Sorted list of every indexed task and group name."""
        return sorted(self.ALL_TASKS.keys())

    def _load_subtasks(self, group_name, subtask_list, add_placeholders):
        """Load every member of a group and return them as one flat dict.

        Args:
            group_name: name recorded as the parent of each member.
            subtask_list: iterable of task/group names (str) or inline
                configs (dict).
            add_placeholders: when true, a ``(group_name, None)`` entry is
                recorded under each member's name before it is loaded.
                Name-based groups do this; inline group configs do not.

        Returns:
            dict mapping names to ``(group_name, task_object_or_None)``.

        Raises:
            TypeError: if a member is neither a ``str`` nor a ``dict``.
        """
        all_subtasks = {}
        for task_or_config in subtask_list:
            if not isinstance(task_or_config, (str, dict)):
                raise TypeError(
                    "group members must be str or dict, got "
                    f"{type(task_or_config).__name__}"
                )

            if add_placeholders:
                if isinstance(task_or_config, str):
                    all_subtasks[task_or_config] = (group_name, None)
                elif "group" in task_or_config:
                    all_subtasks[task_or_config["group"]] = (group_name, None)
                elif "task" in task_or_config:
                    all_subtasks[task_or_config["task"]] = (group_name, None)

            task_object = self._load_individual_task_or_group(
                task_name_or_config=task_or_config
            )

            if isinstance(task_object, dict):
                # A nested group: merge its members, keeping entries that
                # were already recorded for this group.
                all_subtasks = {**task_object, **all_subtasks}
            else:
                task_name = task_object._config["task"]
                all_subtasks[task_name] = (group_name, task_object)
        return all_subtasks

    def _load_individual_task_or_group(
        self, task_name_or_config: Union[str, dict] = None
    ) -> Union["ConfigurableTask", dict]:
        """Load one task or group, given its name or an inline config.

        Returns:
            A ``ConfigurableTask`` for a single task, or a dict of
            ``name -> (group_name, task_object_or_None)`` for a group.

        Raises:
            KeyError: if a referenced name is not in ``ALL_TASKS``.
            TypeError: if the argument is neither a ``str`` nor a ``dict``.
        """
        if isinstance(task_name_or_config, str):
            task_info = self.ALL_TASKS[task_name_or_config]
            yaml_path = task_info["yaml_path"]
            task_type = task_info["type"]
            subtask_list = task_info["task"] if "task" in task_info else -1

            if task_type == "task":
                task_config = utils.load_yaml_config(yaml_path)
                return ConfigurableTask(config=task_config)

            if subtask_list == -1:
                # Member list was not indexed; read it from the group YAML.
                task_config = utils.load_yaml_config(yaml_path)
                group_name = task_config["group"]
                subtask_list = task_config["task"]
            else:
                group_name = task_name_or_config

            return self._load_subtasks(
                group_name, subtask_list, add_placeholders=True
            )

        if isinstance(task_name_or_config, dict):
            if is_group(task_name_or_config):
                return self._load_subtasks(
                    task_name_or_config["group"],
                    task_name_or_config["task"],
                    add_placeholders=False,
                )

            # Inline task config: overlay it on the indexed base config.
            task_name = task_name_or_config["task"]
            base_yaml_path = self.ALL_TASKS[task_name]["yaml_path"]
            base_task_config = utils.load_yaml_config(base_yaml_path)
            return ConfigurableTask(
                config={
                    **base_task_config,
                    **task_name_or_config,
                }
            )

        raise TypeError(
            "task_name_or_config must be str or dict, got "
            f"{type(task_name_or_config).__name__}"
        )

    def load_task_or_group(self, task_list: Union[str, list] = None) -> dict:
        """Load every requested task/group and return them in a single dict.

        Args:
            task_list: a name, or a list of names and/or inline configs.

        Returns:
            dict mapping task names to either a task object (for directly
            requested tasks) or ``(group_name, task_object_or_None)``
            tuples (for group members). Entries loaded earlier win on
            name clashes with later group expansions.
        """
        if isinstance(task_list, str):
            task_list = [task_list]

        all_loaded_tasks = {}
        for task in task_list:
            task_object = self._load_individual_task_or_group(
                task_name_or_config=task,
            )

            if isinstance(task_object, dict):
                # Merge into what has been loaded so far (not into the
                # index, which holds metadata rather than task objects).
                all_loaded_tasks = {**task_object, **all_loaded_tasks}
            else:
                task_name = task if isinstance(task, str) else task["task"]
                all_loaded_tasks[task_name] = task_object

        return all_loaded_tasks
167
168


169
def register_configurable_task(config: Dict[str, str]) -> int:
    """Create a ``ConfigurableTask`` subclass for *config* and register it.

    The subclass is named ``<task>ConfigurableTask``, carries the config as
    its ``CONFIG`` attribute, is registered under ``config["task"]`` and,
    when ``config["group"]`` is present, under each listed group as well.

    Args:
        config: task configuration; must contain a ``"task"`` key.
            ``"group"`` may be a single name or an iterable of names.

    Returns:
        Always ``0``.

    Raises:
        KeyError: if ``"task"`` is missing from *config*.
        ValueError: if the group name equals the task name.
    """
    task_name = config["task"]
    SubClass = type(
        task_name + "ConfigurableTask",
        (ConfigurableTask,),
        {"CONFIG": TaskConfig(**config)},
    )

    # "task" is guaranteed to be present here (it was read above), so the
    # task registration is unconditional.
    register_task("{}".format(task_name))(SubClass)

    if "group" in config:
        group_value = config["group"]
        if group_value == task_name:
            raise ValueError("task and group name cannot be the same")

        # isinstance (rather than an exact type comparison) keeps str
        # subclasses from being iterated character by character.
        group_names = [group_value] if isinstance(group_value, str) else group_value
        for group in group_names:
            register_group(group)(SubClass)

    return 0

lintangsutawika's avatar
format  
lintangsutawika committed
193

lintangsutawika's avatar
lintangsutawika committed
194
def _append_group_member(group: str, member: str) -> None:
    """Record *member* under *group* in ``GROUP_REGISTRY``."""
    if group in GROUP_REGISTRY:
        GROUP_REGISTRY[group].append(member)
    else:
        GROUP_REGISTRY[group] = [member]


def register_configurable_group(config: Dict[str, str], yaml_path: str = None) -> int:
    """Register the members of a group config under ``config["group"]``.

    Each entry of ``config["task"]`` is handled according to its shape:

    * ``str`` - treated as a name pattern, matched with
      ``utils.pattern_match`` against the currently registered task and
      group names; every match is added to the group.
    * bare group config (see ``is_group``) - registered recursively and
      added to this group as a sub-group.
    * any other dict - treated as an inline task config. If it names an
      already registered task, that task's config is used as the base and
      the new task is named ``<group>_<task>``. The result is registered
      through ``register_configurable_task``.

    Args:
        config: group config with ``"group"`` and ``"task"`` keys.
        yaml_path: path of the YAML file the config came from; passed to
            ``utils.load_yaml_config`` and used to resolve prompt files.

    Returns:
        Always ``0``.
    """
    group = config["group"]

    task_config_list = []
    group_config_list = []
    registered_task_or_group_list = []
    for task in config["task"]:
        if isinstance(task, str):
            registered_task_or_group_list.append(task)
        elif is_group(task):
            group_config_list.append(task)
        else:
            task_config_list.append(task)

    for task_config in task_config_list:
        base_config = {}
        task_name_config = {}
        if "task" in task_config:
            task_name = task_config["task"]
            if task_name in TASK_REGISTRY:
                task_obj = get_task_dict(task_name)[task_name]
                # Group members come back as (group_name, task) tuples.
                if isinstance(task_obj, tuple):
                    _, task_obj = task_obj

                if task_obj is not None:
                    base_config = task_obj._config.to_dict(keep_callable=True)
                    task_name_config["task"] = f"{group}_{task_name}"

        task_config = utils.load_yaml_config(yaml_path, task_config)
        var_configs = check_prompt_config(
            {
                **base_config,
                **task_config,
                **{"group": group},
                **task_name_config,
            },
            yaml_path=os.path.dirname(yaml_path),
        )
        # Distinct loop name: do not rebind the ``config`` parameter.
        for var_config in var_configs:
            register_configurable_task(var_config)

    for group_config in group_config_list:
        sub_group = group_config["group"]
        register_configurable_group(group_config, yaml_path)
        _append_group_member(group, sub_group)

    # NOTE(review): there is no ``self`` in this module-level function, so
    # the set of known names is taken from the registries themselves (the
    # same listing ``get_task`` reports as "Available tasks"). Confirm this
    # matches the intended source of names.
    known_names = list(TASK_REGISTRY) + list(GROUP_REGISTRY)
    task_names = utils.pattern_match(registered_task_or_group_list, known_names)
    for task in task_names:
        if (task in TASK_REGISTRY) or (task in GROUP_REGISTRY):
            _append_group_member(group, task)

    return 0

255

lintangsutawika's avatar
lintangsutawika committed
256
257
258
def check_prompt_config(
    config: Dict[str, str], yaml_path: str = None
) -> List[Dict[str, str]]:
    """Expand a config into one config per prompt variation.

    Without a ``"use_prompt"`` key the config is returned unchanged, as the
    single element of a list (the same dict object, not a copy).

    With ``"use_prompt"``, ``prompts.load_prompt_list`` is asked for the
    matching prompt variations and one new config is produced per
    variation, in which:

    * ``"use_prompt"`` is the variation itself,
    * ``"task"`` is ``<base>_<variation>``, where ``<base>`` is
      ``config["task"]`` if present and otherwise the name derived by
      ``get_task_name_from_config``; for variations containing ``".yaml"``
      only the part after the last ``"/"`` is used,
    * ``"output_type"`` is forced to ``"generate_until"``.

    Args:
        config: task configuration.
        yaml_path: directory forwarded to ``prompts.load_prompt_list``.

    Returns:
        List of configs (possibly empty when no prompt variation matches).
    """
    if "use_prompt" not in config:
        return [config]

    prompt_list = prompts.load_prompt_list(
        use_prompt=config["use_prompt"],
        dataset_name=config["dataset_path"],
        subset_name=config.get("dataset_name"),
        yaml_path=yaml_path,
    )

    # The base name does not depend on the prompt, so compute it once.
    if "task" in config:
        base_name = config["task"]
    else:
        base_name = get_task_name_from_config(config)

    all_configs = []
    for prompt_variation in prompt_list:
        if ".yaml" in prompt_variation:
            suffix = prompt_variation.split("/")[-1]
        else:
            suffix = prompt_variation

        all_configs.append(
            {
                **config,
                "use_prompt": prompt_variation,
                "task": "_".join([base_name, suffix]),
                "output_type": "generate_until",
            }
        )
    return all_configs


292
def get_task_name_from_config(task_config: Dict[str, str]) -> str:
    """Derive a task name from the dataset fields of *task_config*.

    Returns ``"<dataset_path>_<dataset_name>"`` when ``"dataset_name"`` is
    present, otherwise just ``"<dataset_path>"``.

    Raises:
        KeyError: if ``"dataset_path"`` is missing.
    """
    if "dataset_name" in task_config:
        template = "{dataset_path}_{dataset_name}"
    else:
        template = "{dataset_path}"
    return template.format(**task_config)
297
298


lintangsutawika's avatar
lintangsutawika committed
299
def include_task_folder(task_dir: str, register_task: bool = True, task_name: str = None) -> int:
    """Walk *task_dir* and register the YAML configs found in it.

    Every ``*.yaml`` file below *task_dir* is loaded with
    ``utils.load_yaml_config``; files without a ``"task"`` key are skipped.
    Each remaining config is expanded with ``check_prompt_config`` and then:

    * if *register_task* is true, configs whose ``"task"`` is a string are
      registered with ``register_configurable_task``;
    * otherwise, configs whose ``"task"`` is a list are registered with
      ``register_configurable_group``.

    A file that fails to load is logged and skipped rather than aborting
    the walk: import errors at DEBUG level (with one summary warning at the
    end), anything else as a warning including the traceback.

    Args:
        task_dir: directory to search recursively.
        register_task: selects task registration (true) or group
            registration (false). Inside this function the name shadows
            the imported ``register_task`` decorator.
        task_name: unused.

    Returns:
        Always ``0``.
    """
    import traceback

    # Track whether any tasks failed during loading
    import_fail = False
    for root, _, file_list in os.walk(task_dir):
        for f in file_list:
            if not f.endswith(".yaml"):
                continue

            yaml_path = os.path.join(root, f)
            try:
                config = utils.load_yaml_config(yaml_path)

                if "task" not in config:
                    continue

                all_configs = check_prompt_config(
                    config, yaml_path=os.path.dirname(yaml_path)
                )
                for task_config in all_configs:
                    if register_task:
                        if isinstance(task_config["task"], str):
                            register_configurable_task(task_config)
                    elif isinstance(task_config["task"], list):
                        register_configurable_group(task_config, yaml_path)

            # Log this silently and show it only when
            # the user defines the appropriate verbosity.
            except (ImportError, ModuleNotFoundError) as e:
                import_fail = True
                eval_logger.debug(
                    f"{yaml_path}: {e}. Config will not be added to registry."
                )
            except Exception as error:
                eval_logger.warning(
                    "Unexpected error loading config in\n"
                    f"                                 {yaml_path}\n"
                    "                                 Config will not be added to registry\n"
                    f"                                 Error: {error}\n"
                    f"                                 Traceback: {traceback.format_exc()}"
                )

    if import_fail:
        eval_logger.warning(
            "Some tasks could not be loaded due to missing dependencies."
            " Run with `--verbosity DEBUG` for full details."
        )
    return 0
352
353


354
def get_task_and_group(task_dir: str):
    """Index the task and group YAML configs found below *task_dir*.

    Every ``*.yaml`` file is read with ``utils.simple_load_yaml_config``:

    * A config whose keys are exactly ``"group"`` and ``"task"`` (in any
      order) is indexed as a group with ``"task": -1``; the member list is
      not stored because it can be read from ``"yaml_path"`` when needed.
    * A config with a ``"task"`` key is indexed as a task. Each name in its
      optional ``"group"`` entry (a string or a list of strings) gets a
      group entry with ``"yaml_path": -1`` listing the tasks that tag it.
    * A config without a ``"task"`` key is skipped.

    Returns:
        dict mapping each task/group name to its index entry.
    """
    tasks_and_groups = {}
    for root, _, file_list in os.walk(task_dir):
        for f in file_list:
            if not f.endswith(".yaml"):
                continue

            yaml_path = os.path.join(root, f)
            config = utils.simple_load_yaml_config(yaml_path)

            if set(config.keys()) == {"group", "task"}:
                # This is a group config
                tasks_and_groups[config["group"]] = {
                    "type": "group",
                    "task": -1,  # member list is loaded from the YAML on demand
                    "yaml_path": yaml_path,
                }
                continue

            if "task" not in config:
                # Not a task definition; nothing to index.
                continue

            # This is a task config
            task = config["task"]
            tasks_and_groups[task] = {
                "type": "task",
                "yaml_path": yaml_path,
            }

            if "group" not in config:
                continue

            groups = config["group"]
            if isinstance(groups, str):
                groups = [groups]

            for group in groups:
                entry = tasks_and_groups.get(group)
                if entry is None:
                    tasks_and_groups[group] = {
                        "type": "group",
                        "task": [task],
                        "yaml_path": -1,
                    }
                elif isinstance(entry.get("task"), list):
                    entry["task"].append(task)
                # Otherwise the name is already taken by a YAML-defined
                # group (member list is the -1 sentinel) or by a task
                # entry; neither has a list to append to.

    return tasks_and_groups
396
397


lintangsutawika's avatar
lintangsutawika committed
398

399
400
def get_task(task_name, config):
    """Instantiate the task registered under *task_name*.

    Args:
        task_name: key into ``TASK_REGISTRY``.
        config: passed to the task class as its ``config`` keyword.

    Returns:
        The new task instance.

    Raises:
        KeyError: if *task_name* is not registered. The registered task
            and group names are logged at INFO level first.
    """
    # Only the registry lookup is guarded, so a KeyError raised while the
    # task itself is being constructed is not misreported as a missing task.
    try:
        task_class = TASK_REGISTRY[task_name]
    except KeyError:
        eval_logger.info("Available tasks:")
        eval_logger.info(list(TASK_REGISTRY) + list(GROUP_REGISTRY))
        raise KeyError(f"Missing task {task_name}")
    return task_class(config=config)


def get_task_name_from_object(task_object):
    """Return a reporting name for *task_object*.

    The name under which *task_object* itself is stored in
    ``TASK_REGISTRY`` is preferred. Failing that, the object's
    ``EVAL_HARNESS_NAME`` attribute is used when it has one, and the name
    of its type otherwise.
    """
    for registered_name, registered_entry in TASK_REGISTRY.items():
        if registered_entry is task_object:
            return registered_name

    # TODO: scrap this
    # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
    if hasattr(task_object, "EVAL_HARNESS_NAME"):
        return task_object.EVAL_HARNESS_NAME
    return type(task_object).__name__


# TODO: pass num_fewshot and other cmdline overrides in a better way
423
def get_task_dict(task_name_list: List[Union[str, Dict, "Task"]], **kwargs):
    """Build a ``name -> task`` dict from names, configs and task objects.

    Each element of *task_name_list* (a single non-list value is wrapped in
    a list) is handled by type:

    * ``str`` naming a group in ``GROUP_REGISTRY`` - every member is
      resolved recursively and stored as ``(group_name, task_or_None)``;
      entries produced by nested groups are merged in as well.
    * any other ``str`` - instantiated through ``get_task`` with *kwargs*
      as its config.
    * ``dict`` - combined with *kwargs* (which win on clashes) and wrapped
      in a ``ConfigurableTask``, keyed by ``get_task_name_from_config``.
      The caller's dict is left unmodified.
    * ``Task`` instance - used as is, keyed by
      ``get_task_name_from_object``.

    Args:
        task_name_list: element or list of elements as described above.
        **kwargs: config overrides.

    Returns:
        dict merging registry, config and object entries, later categories
        overriding earlier ones on key clashes.

    Raises:
        AssertionError: if a registry-derived name collides with the name
            of a passed-in task object.
    """
    config = {**kwargs}

    task_name_from_registry_dict = {}
    task_name_from_config_dict = {}
    task_name_from_object_dict = {}

    if not isinstance(task_name_list, list):
        task_name_list = [task_name_list]

    for task_element in task_name_list:
        if isinstance(task_element, str):
            if task_element in GROUP_REGISTRY:
                group_name = task_element
                for task_name in GROUP_REGISTRY[task_element]:
                    if task_name in task_name_from_registry_dict:
                        continue

                    task_obj = get_task_dict(task_name)
                    if task_name in task_obj:
                        task_dict = {
                            task_name: (group_name, task_obj[task_name]),
                        }
                    else:
                        # task_name is itself a group: keep a placeholder
                        # for it alongside its expanded members.
                        task_dict = {
                            task_name: (group_name, None),
                            **task_obj,
                        }

                    task_name_from_registry_dict.update(task_dict)
            elif task_element not in task_name_from_registry_dict:
                task_name_from_registry_dict[task_element] = get_task(
                    task_name=task_element, config=config
                )

        elif isinstance(task_element, dict):
            # Merge into a new dict instead of updating the caller's in place.
            merged_config = {**task_element, **config}
            task_name_from_config_dict[
                get_task_name_from_config(merged_config)
            ] = ConfigurableTask(config=merged_config)

        elif isinstance(task_element, Task):
            task_name_from_object_dict[
                get_task_name_from_object(task_element)
            ] = task_element

    assert set(task_name_from_registry_dict.keys()).isdisjoint(
        set(task_name_from_object_dict.keys())
    )
    return {
        **task_name_from_registry_dict,
        **task_name_from_config_dict,
        **task_name_from_object_dict,
    }