__init__.py 16.1 KB
Newer Older
1
import os
2
import abc
lintangsutawika's avatar
lintangsutawika committed
3
import yaml
4
import collections
lintangsutawika's avatar
lintangsutawika committed
5
6

from functools import partial
7
from typing import List, Union, Dict
&'s avatar
& committed
8

9
from lm_eval import utils
10
from lm_eval import prompts
11
from lm_eval.api.task import TaskConfig, Task, ConfigurableTask
lintangsutawika's avatar
lintangsutawika committed
12
from lm_eval.api.registry import (
13
14
    register_task,
    register_group,
lintangsutawika's avatar
lintangsutawika committed
15
16
    TASK_REGISTRY,
    GROUP_REGISTRY,
lintangsutawika's avatar
lintangsutawika committed
17
)
lintangsutawika's avatar
lintangsutawika committed
18

lintangsutawika's avatar
lintangsutawika committed
19
import logging
lintangsutawika's avatar
lintangsutawika committed
20

lintangsutawika's avatar
lintangsutawika committed
21
# import python tasks
lintangsutawika's avatar
lintangsutawika committed
22
from .squadv2.task import SQuAD2
lintangsutawika's avatar
lintangsutawika committed
23
24
25
26
27
28
29
30
from .scrolls.task import (
    QuALITY,
    NarrativeQA,
    ContractNLI,
    GovReport,
    SummScreenFD,
    QMSum,
)
lintangsutawika's avatar
lintangsutawika committed
31

32
# Alias of the shared ``utils.eval_logger``; the task-loading functions in
# this module emit their debug/warning messages through it.
eval_logger = utils.eval_logger
33

lintangsutawika's avatar
lintangsutawika committed
34

35
class TaskManager(abc.ABC):
    """Index task/group YAML configs by name and load them on demand.

    On construction, this package's directory (plus any ``include_path``) is
    indexed with :func:`get_task_and_group`. Tasks are only instantiated when
    requested through :meth:`load_task_or_group`.
    """

    def __init__(self, verbosity="INFO", include_path=None) -> None:
        """
        Args:
            verbosity: Name of a ``logging`` level such as ``"INFO"`` or
                ``"DEBUG"`` (case-insensitive); applied to ``eval_logger``.
            include_path: Extra directory, or list of directories, scanned for
                task YAML files after the built-in task directory.
        """
        self.verbosity = verbosity
        self.include_path = include_path
        # ``Logger.setLevel`` returns None, so set the level first and keep a
        # reference to the logger itself.
        eval_logger.setLevel(getattr(logging, str(verbosity).upper()))
        self.logger = eval_logger

        self.ALL_TASKS = self.initialize_tasks(include_path=include_path)

    def initialize_tasks(self, include_path=None):
        """Build the name -> index-entry mapping for every task and group.

        Args:
            include_path: A directory path or list of paths, scanned after the
                built-in task directory.

        Returns:
            dict mapping each name to the entry produced by
            ``get_task_and_group``. When a name occurs in several directories
            the entry from the earliest scanned directory is kept.
        """
        all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"]
        if include_path is not None:
            if isinstance(include_path, str):
                include_path = [include_path]
            all_paths.extend(include_path)

        all_tasks = {}
        for task_dir in all_paths:
            tasks = get_task_and_group(task_dir)
            # Entries already collected take precedence over new ones.
            all_tasks = {**tasks, **all_tasks}
        return all_tasks

    def all_tasks(self):
        """Return the sorted names of every indexed task and group."""
        return sorted(self.ALL_TASKS)

    def _name_is_registered(self, name):
        """Return True if ``name`` is an indexed task or group."""
        return name in self.ALL_TASKS

    def _name_is_task(self, name):
        """Return True if ``name`` indexes a task (False for a group).

        Raises KeyError for an unknown name.
        """
        return self.ALL_TASKS[name]["type"] == "task"

    def _config_is_task(self, config):
        """Return False for a group definition, True for any other config.

        A group definition is a dict whose keys are exactly ``group`` and
        ``task``, in any order.
        """
        return set(config.keys()) != {"group", "task"}

    def _get_config(self, name):
        """Load the YAML config file indexed for ``name``.

        Raises:
            KeyError: if ``name`` is not indexed.
        """
        if name not in self.ALL_TASKS:
            raise KeyError(f"{name} is not a registered task or group")
        yaml_path = self.ALL_TASKS[name]["yaml_path"]
        return utils.load_yaml_config(yaml_path)

    def _get_tasklist(self, name):
        """Return the indexed member list of group ``name``.

        The value is either a list of task names, or -1 when the list must be
        read from the group's own YAML file.

        Raises:
            ValueError: if ``name`` is a task rather than a group.
        """
        if self._name_is_task(name):
            raise ValueError(f"{name} is a task, not a group")
        return self.ALL_TASKS[name]["task"]

    def _load_individual_task_or_group(
        self,
        name_or_config: Union[str, dict] = None,
        parent_name: str = None,
    ) -> dict:
        """Instantiate one task, or every task of one group (recursively).

        Args:
            name_or_config: A registered task/group name, or an inline config
                dict. A dict with exactly the keys ``group`` and ``task`` is a
                group definition; any other dict is a task config whose keys
                override a registered task of the same ``task`` name, if any.
            parent_name: Group that pulled this entry in. When set, every task
                value is wrapped as ``(parent_name, task_object)``.

        Returns:
            dict mapping task name to a ``ConfigurableTask``, or to a
            ``(group_name, ConfigurableTask)`` tuple for group members.

        Raises:
            TypeError: if ``name_or_config`` is neither a str nor a dict.
        """
        eval_logger.debug("Loading %s", name_or_config)

        def _wrap(task_object):
            # Group members are tagged with the group that loaded them.
            if parent_name is not None:
                return (parent_name, task_object)
            return task_object

        if isinstance(name_or_config, str):
            if self._name_is_task(name_or_config):
                task_config = self._get_config(name_or_config)
                return {name_or_config: _wrap(ConfigurableTask(config=task_config))}

            group_name = name_or_config
            subtask_list = self._get_tasklist(name_or_config)
            if subtask_list == -1:
                # Explicit group file: the member list lives in the YAML itself.
                subtask_list = self._get_config(name_or_config)["task"]

        elif isinstance(name_or_config, dict):
            if self._config_is_task(name_or_config):
                task_name = name_or_config["task"]
                if self._name_is_registered(task_name):
                    # Inline keys override the registered task's YAML config.
                    task_config = {
                        **self._get_config(task_name),
                        **name_or_config,
                    }
                else:
                    task_config = name_or_config
                return {task_name: _wrap(ConfigurableTask(config=task_config))}

            group_name = name_or_config["group"]
            subtask_list = name_or_config["task"]

        else:
            raise TypeError(
                "Expected a task/group name or a config dict, got "
                f"{type(name_or_config).__name__}"
            )

        # Load every member; on duplicate names ChainMap keeps the first one.
        fn = partial(self._load_individual_task_or_group, parent_name=group_name)
        return dict(collections.ChainMap(*map(fn, subtask_list)))

    def load_task_or_group(self, task_list: Union[str, list] = None) -> dict:
        """Load tasks and/or groups given by name or inline config.

        Args:
            task_list: A single name or config dict, or a list of them.
                ``None`` loads nothing.

        Returns:
            dict mapping task name to a task object, or to a
            ``(group, task)`` tuple for group members. If several entries
            produce the same task name, the earliest entry wins.
        """
        if task_list is None:
            return {}
        if isinstance(task_list, (str, dict)):
            task_list = [task_list]

        all_loaded_tasks = {}
        for task in task_list:
            task_object = self._load_individual_task_or_group(name_or_config=task)
            # Entries loaded earlier keep precedence over later duplicates.
            all_loaded_tasks = {**task_object, **all_loaded_tasks}
        return all_loaded_tasks
154
155


156
def register_configurable_task(config: Dict[str, str]) -> int:
    """Create a ``ConfigurableTask`` subclass for ``config`` and register it.

    The subclass is registered in the task registry under ``config["task"]``,
    and in the group registry under every name in ``config["group"]`` (a str
    or a list of str), if present.

    Args:
        config: A single task config; must contain a ``task`` key.

    Returns:
        0 on success.

    Raises:
        KeyError: if ``config`` has no ``task`` key.
        ValueError: if the task name equals one of its own group names. This
            is checked before anything is registered.
    """
    task_name = config["task"]

    group_names = config.get("group") if "group" in config else None
    if isinstance(group_names, str):
        group_names = [group_names]
    # Validate before registering, so a bad config leaves no partial state.
    if group_names is not None and task_name in group_names:
        raise ValueError("task and group name cannot be the same")

    SubClass = type(
        task_name + "ConfigurableTask",
        (ConfigurableTask,),
        {"CONFIG": TaskConfig(**config)},
    )
    register_task(task_name)(SubClass)

    if group_names is not None:
        for group in group_names:
            register_group(group)(SubClass)

    return 0

lintangsutawika's avatar
format  
lintangsutawika committed
180

lintangsutawika's avatar
lintangsutawika committed
181
def register_configurable_group(
    config: Dict[str, str], yaml_path: str = None, all_tasks=None
) -> int:
    """Register a group config and its members in the global registries.

    ``config["task"]`` may mix three kinds of entries:

    * str -- a task/group name or pattern. It is matched with
      ``utils.pattern_match`` against the known names, and every match found
      in ``TASK_REGISTRY`` or ``GROUP_REGISTRY`` is appended to
      ``GROUP_REGISTRY[group]``.
    * dict with exactly the keys ``group`` and ``task`` -- a nested group. It
      is registered recursively and appended as a sub-group.
    * any other dict -- an inline task config. It is registered through
      ``register_configurable_task`` with ``group`` set to this group. If it
      names an already registered task, that task's config is used as the base
      and the new task is renamed ``<group>_<task>``.

    Args:
        config: The group config; requires ``group`` and ``task`` keys.
        yaml_path: Path of the YAML file ``config`` came from. It is passed to
            ``utils.load_yaml_config`` with each inline task config, and its
            directory is used for prompt lookup. It must be given when
            ``config`` contains inline task configs.
        all_tasks: Optional iterable of known task/group names used as the
            pattern-matching universe. Defaults to every name currently in
            ``TASK_REGISTRY`` and ``GROUP_REGISTRY``. It is not mutated.

    Returns:
        0 on success.
    """
    group = config["group"]

    task_config_list = []
    group_config_list = []
    registered_task_or_group_list = []
    for task in config["task"]:
        if isinstance(task, str):
            registered_task_or_group_list.append(task)
        elif set(task.keys()) == {"group", "task"}:
            # Nested group definition.
            group_config_list.append(task)
        else:
            task_config_list.append(task)

    for task_config in task_config_list:
        base_config = {}
        task_name_config = {}
        if "task" in task_config:
            task_name = task_config["task"]
            if task_name in TASK_REGISTRY:
                task_obj = get_task_dict(task_name)[task_name]
                if isinstance(task_obj, tuple):
                    _, task_obj = task_obj

                if task_obj is not None:
                    # Start from the registered task's config, and rename so
                    # the group-local variant does not clash with the original.
                    base_config = task_obj._config.to_dict(keep_callable=True)
                    task_name_config["task"] = f"{group}_{task_name}"

        task_config = utils.load_yaml_config(yaml_path, task_config)
        var_configs = check_prompt_config(
            {
                **base_config,
                **task_config,
                **{"group": group},
                **task_name_config,
            },
            yaml_path=os.path.dirname(yaml_path),
        )
        for var_config in var_configs:
            register_configurable_task(var_config)

    for group_config in group_config_list:
        sub_group = group_config["group"]
        register_configurable_group(group_config, yaml_path, all_tasks=all_tasks)
        GROUP_REGISTRY.setdefault(group, []).append(sub_group)

    if all_tasks is None:
        # Computed after the registrations above, so they are included.
        all_tasks = set(TASK_REGISTRY) | set(GROUP_REGISTRY)
    task_names = utils.pattern_match(registered_task_or_group_list, sorted(all_tasks))
    for task in task_names:
        if (task in TASK_REGISTRY) or (task in GROUP_REGISTRY):
            GROUP_REGISTRY.setdefault(group, []).append(task)

    return 0

242

lintangsutawika's avatar
lintangsutawika committed
243
244
245
def check_prompt_config(
    config: Dict[str, str], yaml_path: str = None
) -> List[Dict[str, str]]:
    """Expand a task config into one config per prompt variation.

    Without a ``use_prompt`` key, the very same ``config`` object is returned
    in a one-element list. Otherwise ``prompts.load_prompt_list`` is called,
    and each returned prompt yields a copy of ``config`` in which:

    * ``use_prompt`` is set to that prompt;
    * ``task`` is suffixed with the prompt (only its last path component for
      ``.yaml`` prompts);
    * ``output_type`` is forced to ``"generate_until"``.

    Args:
        config: A single task config. ``dataset_path`` is required when
            ``use_prompt`` is present; ``dataset_name`` is optional.
        yaml_path: Directory forwarded to ``prompts.load_prompt_list``.

    Returns:
        A list of task configs.
    """
    if "use_prompt" not in config:
        return [config]

    variations = prompts.load_prompt_list(
        use_prompt=config["use_prompt"],
        dataset_name=config["dataset_path"],
        subset_name=config.get("dataset_name"),
        yaml_path=yaml_path,
    )

    def _suffix(prompt_variation):
        # YAML prompt files contribute only their file name to the task name.
        if ".yaml" in prompt_variation:
            return prompt_variation.split("/")[-1]
        return prompt_variation

    expanded = []
    for prompt_variation in variations:
        if "task" in config:
            base_name = config["task"]
        else:
            base_name = get_task_name_from_config(config)
        variant = dict(config)
        variant["use_prompt"] = prompt_variation
        variant["task"] = "_".join([base_name, _suffix(prompt_variation)])
        variant["output_type"] = "generate_until"
        expanded.append(variant)
    return expanded


279
def get_task_name_from_config(task_config: Dict[str, str]) -> str:
    """Derive a task name from the dataset fields of ``task_config``.

    Returns ``"<dataset_path>_<dataset_name>"`` when the ``dataset_name`` key
    is present (whatever its value), otherwise ``"<dataset_path>"``.
    Raises KeyError if ``dataset_path`` is missing.
    """
    template = "{dataset_path}"
    if "dataset_name" in task_config:
        template += "_{dataset_name}"
    return template.format(**task_config)
284
285


lintangsutawika's avatar
lintangsutawika committed
286
def include_task_folder(
    task_dir: str, register_task: bool = True, task_name: str = None
) -> int:
    """Register the task or group configs found in ``.yaml`` files under ``task_dir``.

    The directory is walked recursively, and YAML files without a ``task`` key
    are skipped. Each config is first expanded per prompt variation with
    ``check_prompt_config``. Then, depending on ``register_task``, only the
    single-task configs (``task`` is a str) or only the group configs (``task``
    is a list) are registered.

    Args:
        task_dir: Root directory to scan.
        register_task: If True, register single-task configs. If False,
            register group configs. Presumably a task pass runs before a group
            pass so groups can refer to registered tasks -- confirm at callers.
        task_name: Unused; kept for backward compatibility.

    Returns:
        0. Configs that fail to load are logged and skipped, not raised:
        missing-dependency errors are logged at DEBUG level, followed by one
        summary WARNING at the end. Any other error is logged as a WARNING
        with its traceback.
    """
    import traceback

    # Track whether any config failed because of a missing dependency.
    import_fail = False
    for root, _, file_list in os.walk(task_dir):
        for f in file_list:
            if not f.endswith(".yaml"):
                continue
            yaml_path = os.path.join(root, f)
            try:
                config = utils.load_yaml_config(yaml_path)
                if "task" not in config:
                    continue

                all_configs = check_prompt_config(
                    config, yaml_path=os.path.dirname(yaml_path)
                )
                for sub_config in all_configs:
                    if register_task:
                        if isinstance(sub_config["task"], str):
                            register_configurable_task(sub_config)
                    elif isinstance(sub_config["task"], list):
                        register_configurable_group(sub_config, yaml_path)

            # Missing dependencies are logged quietly; they only show up when
            # the user selects DEBUG verbosity.
            except (ImportError, ModuleNotFoundError) as e:
                import_fail = True
                eval_logger.debug(
                    f"{yaml_path}: {e}. Config will not be added to registry."
                )
            except Exception as error:
                eval_logger.warning(
                    "Unexpected error loading config in\n"
                    f"                                 {yaml_path}\n"
                    "                                 Config will not be added to registry\n"
                    f"                                 Error: {error}\n"
                    f"                                 Traceback: {traceback.format_exc()}"
                )

    if import_fail:
        eval_logger.warning(
            "Some tasks could not be loaded due to missing dependencies."
            " Run with `--verbosity DEBUG` for full details."
        )
    return 0
339
340


341
def get_task_and_group(task_dir: str):
    """Index every task and group YAML config under ``task_dir`` by name.

    Files are only parsed (via ``utils.simple_load_yaml_config``); nothing is
    instantiated or registered.

    Returns:
        dict mapping a name to one of these entries:

        * ``{"type": "task", "yaml_path": path}`` for a task config.
        * ``{"type": "group", "task": -1, "yaml_path": path}`` for an explicit
          group file (keys exactly ``group`` and ``task``, in any order). Its
          member list is read from that file when the group is loaded.
        * ``{"type": "group", "task": [names...], "yaml_path": -1}`` for a
          group that exists only because task configs name it in ``group``.

        YAML files that are empty or lack a ``task`` key are skipped. An
        explicit group file takes precedence over membership collected from
        task configs, whichever file is scanned first.
    """
    tasks_and_groups = {}
    for root, _, file_list in os.walk(task_dir):
        for f in file_list:
            if not f.endswith(".yaml"):
                continue
            yaml_path = os.path.join(root, f)
            config = utils.simple_load_yaml_config(yaml_path)
            if not config or "task" not in config:
                # Neither a task nor a group config; nothing to index.
                continue

            if set(config.keys()) == {"group", "task"}:
                # Explicit group config. "task": -1 signals that the member
                # list is not needed for indexing; it is read from yaml_path
                # when the group is loaded.
                tasks_and_groups[config["group"]] = {
                    "type": "group",
                    "task": -1,
                    "yaml_path": yaml_path,
                }
                continue

            # Task config.
            task = config["task"]
            tasks_and_groups[task] = {"type": "task", "yaml_path": yaml_path}

            groups = config.get("group", [])
            if isinstance(groups, str):
                groups = [groups]
            for group in groups:
                entry = tasks_and_groups.get(group)
                if entry is None:
                    # Implicit group, known only through its members.
                    tasks_and_groups[group] = {
                        "type": "group",
                        "task": [task],
                        "yaml_path": -1,
                    }
                elif entry["task"] != -1:
                    entry["task"].append(task)
                # Otherwise an explicit group file owns this name and its
                # member list comes from that file, so nothing is recorded.
    return tasks_and_groups
383
384


lintangsutawika's avatar
lintangsutawika committed
385

386
387
def get_task(task_name, config):
    """Instantiate the task class registered as ``task_name`` with ``config``.

    Raises:
        KeyError: if ``task_name`` is not in ``TASK_REGISTRY``. The available
            task and group names are logged at INFO level first.
    """
    try:
        task_class = TASK_REGISTRY[task_name]
    except KeyError:
        eval_logger.info("Available tasks:")
        eval_logger.info(list(TASK_REGISTRY) + list(GROUP_REGISTRY))
        raise KeyError(f"Missing task {task_name}") from None
    # Construct outside the try, so a KeyError raised inside the task's own
    # constructor is not misreported as a missing task.
    return task_class(config=config)


def get_task_name_from_object(task_object):
    """Return the name to report for a task class or task instance.

    If ``task_object`` is a class registered in ``TASK_REGISTRY``, or an
    instance whose exact class is registered, the registry name is returned.
    Otherwise the object's ``EVAL_HARNESS_NAME`` attribute is used, falling
    back to its class name.
    """
    task_type = type(task_object)
    for name, class_ in TASK_REGISTRY.items():
        # Registry values are classes, so compare against the object itself
        # (a class was passed) and against its type (an instance was passed).
        if class_ is task_object or class_ is task_type:
            return name

    # TODO: scrap this
    # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
    return (
        task_object.EVAL_HARNESS_NAME
        if hasattr(task_object, "EVAL_HARNESS_NAME")
        else task_type.__name__
    )


# TODO: pass num_fewshot and other cmdline overrides in a better way
def get_task_dict(task_name_list: List[Union[str, Dict, Task]], **kwargs):
    """Resolve a mix of task names, group names, configs and Task objects.

    Args:
        task_name_list: A single element or a list of elements. Each element
            is one of:

            * A group name in ``GROUP_REGISTRY``. It expands to its members,
              each mapped to ``(group_name, task)``. A member that itself
              expands as a group maps to ``(group_name, None)``, and its own
              members are merged in.
            * Any other str: a task name, instantiated via ``get_task``.
            * A dict: an inline ``ConfigurableTask`` config, keyed by
              ``get_task_name_from_config``. The caller's dict is not modified.
            * A ``Task`` instance: used as-is, keyed by
              ``get_task_name_from_object``.

            Elements of any other type are ignored.
        **kwargs: Overrides. They are passed as ``config`` to registry tasks
            (including group members) and merged over every inline dict
            config.

    Returns:
        dict mapping task name to a task object or to a ``(group, task)``
        tuple. On name collisions, entries from dict configs override registry
        entries, and ``Task`` objects override both.
    """
    config = {**kwargs}

    task_name_from_registry_dict = {}
    task_name_from_config_dict = {}
    task_name_from_object_dict = {}

    if not isinstance(task_name_list, list):
        task_name_list = [task_name_list]

    for task_element in task_name_list:
        if isinstance(task_element, str):
            if task_element in GROUP_REGISTRY:
                group_name = task_element
                for task_name in GROUP_REGISTRY[task_element]:
                    if task_name in task_name_from_registry_dict:
                        continue
                    task_obj = get_task_dict(task_name, **kwargs)
                    if task_name in task_obj:
                        task_name_from_registry_dict[task_name] = (
                            group_name,
                            task_obj[task_name],
                        )
                    else:
                        # task_name expanded as a group: record it without a
                        # task object and merge in its members.
                        task_name_from_registry_dict[task_name] = (group_name, None)
                        task_name_from_registry_dict.update(task_obj)
            elif task_element not in task_name_from_registry_dict:
                task_name_from_registry_dict[task_element] = get_task(
                    task_name=task_element, config=config
                )

        elif isinstance(task_element, dict):
            # Merge overrides into a copy so the caller's dict is untouched.
            task_config = {**task_element, **config}
            name = get_task_name_from_config(task_config)
            task_name_from_config_dict[name] = ConfigurableTask(config=task_config)

        elif isinstance(task_element, Task):
            name = get_task_name_from_object(task_element)
            task_name_from_object_dict[name] = task_element

    assert set(task_name_from_registry_dict).isdisjoint(
        task_name_from_object_dict
    ), "the same task was requested both by name and as a Task object"

    return {
        **task_name_from_registry_dict,
        **task_name_from_config_dict,
        **task_name_from_object_dict,
    }