__init__.py 17.5 KB
Newer Older
1
import collections
2
3
import logging
import os
4
from functools import partial
5
from typing import Dict, List, Mapping, Optional, Union
&'s avatar
& committed
6

7
from lm_eval import utils
8
from lm_eval.api.task import ConfigurableTask, Task
lintangsutawika's avatar
lintangsutawika committed
9

10

11
12
13
class TaskManager:
    """Index task and group YAML configs and build task objects from them.

    Tasks are indexed from the default directory (the directory containing
    this file, i.e. `lm_eval/tasks/`) and from any additional directories
    passed through `include_path`.
    """

    def __init__(
        self,
        verbosity="INFO",
        include_path: Optional[Union[str, List]] = None,
        include_defaults: bool = True,
    ) -> None:
        """Build the task index.

        :param verbosity: str = "INFO"
            Name of a level attribute on the `logging` module; the shared
            eval logger is set to this level.
        :param include_path: Union[str, List] = None
            An additional path (or list of paths) searched for task configs.
        :param include_defaults: bool = True
            If set to false, the default task directory is not indexed.
        """
        self.verbosity = verbosity
        self.include_path = include_path
        self.logger = utils.eval_logger
        self.logger.setLevel(getattr(logging, f"{verbosity}"))

        self._task_index = self.initialize_tasks(
            include_path=include_path, include_defaults=include_defaults
        )
        self._all_tasks = sorted(self._task_index)

        # Per parent group: names already handed out to its subtasks. Used to
        # give repeated subtasks a unique "-<N>" suffix.
        self.task_group_map = collections.defaultdict(list)

    def initialize_tasks(
        self,
        include_path: Optional[Union[str, List]] = None,
        include_defaults: bool = True,
    ):
        """Creates a dictionary of tasks index.

        :param include_path: Union[str, List] = None
            An additional path to be searched for tasks recursively.
            Can provide more than one such path as a list.
        :param include_defaults: bool = True
            If set to false, default tasks (those in lm_eval/tasks/) are not indexed.
        :return
            Dictionary of task names as key and task metadata
        """
        if include_defaults:
            all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"]
        else:
            all_paths = []

        if include_path is not None:
            if isinstance(include_path, str):
                include_path = [include_path]
            all_paths.extend(include_path)

        task_index = {}
        for task_dir in all_paths:
            tasks = self._get_task_and_group(task_dir)
            # Entries indexed from earlier directories win on name clashes,
            # so the default directory (first in the list) takes precedence.
            task_index = {**tasks, **task_index}

        return task_index

    @property
    def all_tasks(self):
        """Sorted list of every indexed task and group name."""
        return self._all_tasks

    @property
    def task_index(self):
        """Mapping of indexed name to its metadata dict."""
        return self._task_index

    def match_tasks(self, task_list):
        """Match `task_list` against all indexed names via `utils.pattern_match`."""
        return utils.pattern_match(task_list, self.all_tasks)

    def _name_is_registered(self, name) -> bool:
        """Return True if `name` is a key of the task index."""
        # `all_tasks` is just the sorted keys of the index, so a dict lookup
        # is equivalent to scanning that list but O(1).
        return name in self.task_index

    def _name_is_task(self, name) -> bool:
        """Return True if `name` is indexed with a type containing "task".

        This is a substring check, so both `task` and `python_task` match.
        """
        return self._name_is_registered(name) and (
            "task" in self.task_index[name]["type"]
        )

    def _name_is_group(self, name) -> bool:
        """Return True if `name` is indexed with type `group`."""
        return self._name_is_registered(name) and (
            self.task_index[name]["type"] == "group"
        )

    def _name_is_python_task(self, name):
        """Return True if `name` is indexed with type `python_task`."""
        return self._name_is_registered(name) and (
            self.task_index[name]["type"] == "python_task"
        )

    def _config_is_task(self, config) -> bool:
        """Return True if `config["task"]` exists and is a string."""
        return ("task" in config) and isinstance(config["task"], str)

    def _config_is_group(self, config) -> bool:
        """Return True if `config["task"]` exists and is a list."""
        return ("task" in config) and isinstance(config["task"], list)

    def _config_is_python_task(self, config) -> bool:
        """Return True if `config` has a `class` key."""
        return "class" in config

    def _get_yaml_path(self, name):
        """Return the indexed `yaml_path` of `name` (may be the -1 sentinel).

        :raises ValueError: if `name` is not in the task index.
        """
        if name not in self.task_index:
            raise ValueError(f"'{name}' is not a registered task or group.")
        return self.task_index[name]["yaml_path"]

    def _get_config(self, name):
        """Return the full config of `name`.

        Returns an empty dict when the indexed `yaml_path` is -1 (a group that
        has no yaml file of its own); otherwise the result of
        `utils.load_yaml_config(yaml_path, mode="full")`.

        :raises ValueError: if `name` is not in the task index.
        """
        if name not in self.task_index:
            raise ValueError(f"'{name}' is not a registered task or group.")
        yaml_path = self._get_yaml_path(name)
        if yaml_path == -1:
            return {}
        return utils.load_yaml_config(yaml_path, mode="full")

    def _get_tasklist(self, name):
        """Return the indexed subtask list of group `name` (may be -1).

        :raises ValueError: if `name` is indexed as a task rather than a group.
        :raises KeyError: if `name` is not in the task index.
        """
        if self._name_is_task(name):
            raise ValueError(f"'{name}' is a task, not a group; it has no task list.")
        return self.task_index[name]["task"]

    def _process_alias(self, config, group=None):
        """Reset `group_alias` in `config` (in place) when the group differs.

        If `config` has both `group` and `group_alias` and is being loaded
        under a `group` other than `config["group"]`, the alias was meant for
        a different group, so it is set to None. Returns `config`.
        """
        if ("group_alias" in config) and ("group" in config) and group is not None:
            if config["group"] != group:
                config["group_alias"] = None
        return config

    @staticmethod
    def _is_duplicate_name(candidate, name) -> bool:
        """Return True if `candidate` is `name` or a numbered copy `name-<N>`."""
        if candidate == name:
            return True
        prefix = f"{name}-"
        return candidate.startswith(prefix) and candidate[len(prefix):].isdigit()

    def _load_individual_task_or_group(
        self,
        name_or_config: Optional[Union[str, dict]] = None,
        parent_name: Optional[str] = None,
        update_config: Optional[dict] = None,
        yaml_path: Optional[str] = None,
    ) -> Mapping:
        """Load one task, or recursively load every subtask of a group.

        :param name_or_config: Union[str, dict]
            A registered task/group name, or a config dict. A dict whose
            `task` is a string is a task (or a reference to a registered
            group); a dict whose `task` is a list is an inline group and
            must also carry `group`.
        :param parent_name: str = None
            Name of the group this entry is loaded under, None at the root.
        :param update_config: dict = None
            Overrides merged on top of the entry's config.
        :param yaml_path: str = None
            Yaml path handed down from the root group; needed to resolve an
            `include` key in a task config.

        :return
            Dictionary of task name to task object. When loaded under a
            parent, values are `(parent_name, task_object)` tuples and a
            nested group adds a `group_name: (parent_name, None)` entry.

        :raises TypeError: if `name_or_config` is neither str nor dict.
        """
        if not isinstance(name_or_config, (str, dict)):
            raise TypeError(
                f"Expected a task name or config dict but received {type(name_or_config)}."
            )

        def load_task(config, task, group=None, yaml_path=None):
            if "include" in config:
                if yaml_path is None:
                    raise ValueError(
                        f"Cannot resolve 'include' for task '{task}' without a yaml_path."
                    )
                # `include` is popped so it is not carried into the merged
                # config; explicit keys override the included ones.
                included = utils.load_yaml_config(
                    yaml_path,
                    yaml_config={"include": config.pop("include")},
                    mode="full",
                )
                config = {**included, **config}
            if self._config_is_python_task(config):
                task_object = config["class"]()
            else:
                config = self._process_alias(config, group=group)
                task_object = ConfigurableTask(config=config)
            if group is not None:
                task_object = (group, task_object)
            return {task: task_object}

        if isinstance(name_or_config, str):
            if update_config is not None:
                # Process name_or_config as a dict instead
                name_or_config = {"task": name_or_config, **update_config}
            elif self._name_is_task(name_or_config):
                task_config = self._get_config(name_or_config)
                return load_task(task_config, task=name_or_config, group=parent_name)
            else:
                group_name = name_or_config
                group_config = None
                subtask_list = self._get_tasklist(name_or_config)
                if subtask_list == -1:
                    # Group defined by its own yaml: the subtask list is
                    # only available from the full config.
                    group_config = self._get_config(name_or_config)
                    subtask_list = group_config["task"]

                # This checks if we're at the root.
                if parent_name is None:
                    if group_config is None:
                        group_config = self._get_config(name_or_config)
                    if set(group_config.keys()) > {"task", "group"}:
                        update_config = {
                            k: v
                            for k, v in group_config.items()
                            if k not in ["task", "group"]
                        }
                    yaml_path = self._get_yaml_path(group_name)

                    if (update_config is not None) and ("group_alias" in update_config):
                        group_name = update_config["group_alias"]
                        update_config.pop("group_alias")

        if isinstance(name_or_config, dict):
            if update_config is not None:
                name_or_config = {
                    **name_or_config,
                    **update_config,
                }

            if self._config_is_task(name_or_config):
                name = name_or_config["task"]
                if self._name_is_group(name):
                    # The config refers to a registered group: every other
                    # key becomes an override for its subtasks.
                    group_name = name
                    update_config = {
                        k: v for k, v in name_or_config.items() if k != "task"
                    }
                    subtask_list = self._get_tasklist(name)
                    if subtask_list == -1:
                        subtask_list = self._get_config(name)["task"]
                else:
                    if self._name_is_registered(name):
                        base_task_config = self._get_config(name)

                        # Check if this is a duplicate.
                        if parent_name is not None:
                            # Copy instead of writing into the caller's dict.
                            name_or_config = {**name_or_config, "group": parent_name}
                            # Only the exact name or its "-<N>" copies count;
                            # names that merely share a prefix do not.
                            num_duplicate = sum(
                                1
                                for loaded in self.task_group_map[parent_name]
                                if self._is_duplicate_name(loaded, name)
                            )
                            if num_duplicate > 0:
                                name = f"{name}-{num_duplicate}"
                            self.task_group_map[parent_name].append(name)

                        task_config = {
                            **base_task_config,
                            **name_or_config,
                        }
                    else:
                        task_config = name_or_config
                    return load_task(
                        task_config, task=name, group=parent_name, yaml_path=yaml_path
                    )
            else:
                group_name = name_or_config["group"]
                subtask_list = name_or_config["task"]
                if set(name_or_config.keys()) > {"task", "group"}:
                    update_config = {
                        k: v
                        for k, v in name_or_config.items()
                        if k not in ["task", "group"]
                    }

        all_subtasks = {}
        if parent_name is not None:
            all_subtasks = {group_name: (parent_name, None)}

        fn = partial(
            self._load_individual_task_or_group,
            parent_name=group_name,
            update_config=update_config,
            yaml_path=yaml_path,
        )
        all_subtasks = {
            **all_subtasks,
            **dict(collections.ChainMap(*map(fn, subtask_list))),
        }
        return all_subtasks

    def load_task_or_group(self, task_list: Optional[Union[str, list]] = None) -> dict:
        """Loads a dictionary of task objects from a list

        :param task_list: Union[str, list] = None
            Single string or list of string of task names to be loaded.
            None is treated as an empty list.

        :return
            Dictionary of task objects
        """
        if task_list is None:
            return {}
        if isinstance(task_list, str):
            task_list = [task_list]

        # ChainMap gives the first loaded mapping priority on name clashes.
        all_loaded_tasks = dict(
            collections.ChainMap(*map(self._load_individual_task_or_group, task_list))
        )
        return all_loaded_tasks

    def load_config(self, config: Dict):
        """Load the task or group described by a config dict."""
        return self._load_individual_task_or_group(config)

    def _get_task_and_group(self, task_dir: str):
        """Creates a dictionary of tasks index with the following metadata,

        - `type`, that can be either `task`, `python_task`, or `group`.
            `task` refer to regular task configs, `python_task` are special
            yaml files that only consists of `task` and `class` parameters.
            `group` are group configs.
        - `yaml_path`, path to the yaml file. If the entry is a `group` that
            was configured through a task config, the yaml_path will be -1
            and all subtasks will be listed in `task` (see below)
        - `task`, reserved for entries with `type` as `group`. This will list
            all subtasks. When a group config is created (as opposed to task
            config having `group` parameter set), this will be set to -1 to
            avoid recursive indexing. The whole list of subtasks will be loaded
            at evaluation.

        :param task_dir: str
            A directory to check for tasks

        :return
            Dictionary of task names as key and task metadata
        """
        ignore_dirs = [
            "__pycache__",
            ".ipynb_checkpoints",
        ]
        tasks_and_groups = {}
        for root, dirs, file_list in os.walk(task_dir):
            # Prune in place so os.walk does not descend into ignored dirs.
            dirs[:] = [d for d in dirs if d not in ignore_dirs]
            for f in file_list:
                if not f.endswith(".yaml"):
                    continue
                yaml_path = os.path.join(root, f)
                config = utils.load_yaml_config(yaml_path, mode="simple")
                if self._config_is_python_task(config):
                    # This is a python class config
                    tasks_and_groups[config["task"]] = {
                        "type": "python_task",
                        "yaml_path": yaml_path,
                    }
                elif self._config_is_group(config):
                    # This is a group config. `task` is -1 because the
                    # subtask list is read from the yaml when the group is
                    # actually loaded, not at indexing time.
                    tasks_and_groups[config["group"]] = {
                        "type": "group",
                        "task": -1,
                        "yaml_path": yaml_path,
                    }
                elif self._config_is_task(config):
                    # This is a task config
                    task = config["task"]
                    tasks_and_groups[task] = {
                        "type": "task",
                        "yaml_path": yaml_path,
                    }

                    if "group" in config:
                        groups = config["group"]
                        if isinstance(groups, str):
                            groups = [groups]

                        for group in groups:
                            entry = tasks_and_groups.get(group)
                            if entry is None:
                                tasks_and_groups[group] = {
                                    "type": "group",
                                    "task": [task],
                                    "yaml_path": -1,
                                }
                            elif isinstance(entry.get("task"), list):
                                entry["task"].append(task)
                            else:
                                # The name already belongs to a group yaml
                                # (task == -1) or to a non-group entry; there
                                # is no list to extend, so leave it alone.
                                self.logger.debug(
                                    f"Task {task} not added to {group}: "
                                    "entry has no indexable task list"
                                )
                else:
                    self.logger.debug(f"File {f} in {root} could not be loaded")

        return tasks_and_groups
lintangsutawika's avatar
lintangsutawika committed
379

380

381
382
383
384
385
386
387
def get_task_name_from_config(task_config: Dict[str, str]) -> str:
    """Derive a task name from a task config.

    An explicit `task` entry is returned as-is. Otherwise the name is built
    from `dataset_path`, followed by `_<dataset_name>` when `dataset_name`
    is present in the config.
    """
    if "task" in task_config:
        return task_config["task"]

    # Pick the template first, then fill it from the config in one place.
    template = (
        "{dataset_path}_{dataset_name}"
        if "dataset_name" in task_config
        else "{dataset_path}"
    )
    return template.format(**task_config)
lintangsutawika's avatar
lintangsutawika committed
388

389

lintangsutawika's avatar
lintangsutawika committed
390
def get_task_name_from_object(task_object):
    """Return the reporting name of a task object.

    If the object exposes a `config` attribute, the name is the `task` entry
    of its config (`_config` is preferred when present, otherwise `config`
    itself is used). Otherwise the name is the object's `EVAL_HARNESS_NAME`
    attribute when it has one, else its class name.
    """
    if hasattr(task_object, "config"):
        # The guard tests `config` but the lookup historically read
        # `_config`; fall back to `config` so an object exposing only the
        # public attribute no longer raises AttributeError.
        config = getattr(task_object, "_config", task_object.config)
        return config["task"]

    # TODO: scrap this
    # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
    return (
        task_object.EVAL_HARNESS_NAME
        if hasattr(task_object, "EVAL_HARNESS_NAME")
        else type(task_object).__name__
    )

402
403

def get_task_dict(
    task_name_list: Union[str, List[Union[str, Dict, Task]]],
    task_manager: Optional[TaskManager] = None,
):
    """Creates a dictionary of task objects from either a name of task, config, or prepared Task object.

    :param task_name_list: Union[str, List[Union[str, Dict, Task]]]
        A single task/group name, or a list whose items are task/group
        names, task config dicts, or already-built Task objects.
    :param task_manager: TaskManager = None
        A TaskManager object that stores indexed tasks. If not set,
        task_manager will load one. This should be set by the user
        if there are additional paths that want to be included
        via `include_path`

    :return
        Dictionary of task objects

    :raises TypeError: if `task_name_list` is not a str or list, or a list
        item is not a str, dict, or Task.
    :raises ValueError: if a name loaded from a string clashes with the name
        of a passed Task object.
    """
    if isinstance(task_name_list, str):
        task_name_list = [task_name_list]
    elif isinstance(task_name_list, list):
        if not all(isinstance(task, (str, dict, Task)) for task in task_name_list):
            raise TypeError(
                "Expected all list items to be of types 'str', 'dict', or 'Task', but at least one entry did not match."
            )
    else:
        raise TypeError(
            f"Expected a 'str' or 'list' but received {type(task_name_list)}."
        )

    string_task_name_list = [task for task in task_name_list if isinstance(task, str)]
    others_task_name_list = [
        task for task in task_name_list if not isinstance(task, str)
    ]

    # Both names and config dicts are resolved through a TaskManager, so one
    # is needed whenever either kind is present (previously a list holding
    # only dicts reached `None.load_config`).
    needs_manager = bool(string_task_name_list) or any(
        isinstance(task, dict) for task in others_task_name_list
    )
    if task_manager is None and needs_manager:
        task_manager = TaskManager()

    task_name_from_string_dict = {}
    if string_task_name_list:
        task_name_from_string_dict = task_manager.load_task_or_group(
            string_task_name_list
        )

    task_name_from_config_dict = {}
    task_name_from_object_dict = {}
    for task_element in others_task_name_list:
        if isinstance(task_element, dict):
            task_name_from_config_dict.update(
                task_manager.load_config(config=task_element)
            )
        elif isinstance(task_element, Task):
            task_name_from_object_dict[
                get_task_name_from_object(task_element)
            ] = task_element

    clashing = set(task_name_from_string_dict) & set(task_name_from_object_dict)
    if clashing:
        raise ValueError(
            f"Task names given as strings clash with passed Task objects: {sorted(clashing)}"
        )

    return {
        **task_name_from_string_dict,
        **task_name_from_config_dict,
        **task_name_from_object_dict,
    }