import collections
import logging
import os
from functools import partial
from typing import Dict, List, Mapping, Optional, Union

from lm_eval import utils
from lm_eval.api.task import ConfigurableTask, Task


class TaskManager:
    """TaskManager indexes all tasks from the default `lm_eval/tasks/`
    and an optional directory if provided.

    The index maps every task/group name to a small metadata dictionary
    (see `_get_task_and_group`). Task objects themselves are only built on
    demand by `load_task_or_group` / `load_config`.
    """

    def __init__(
        self,
        verbosity="INFO",
        include_path: Optional[Union[str, List[str]]] = None,
    ) -> None:
        self.verbosity = verbosity
        self.include_path = include_path
        self.logger = utils.eval_logger
        # `verbosity` is a `logging` level name such as "INFO" or "DEBUG".
        self.logger.setLevel(getattr(logging, f"{verbosity}"))

        self._task_index = self.initialize_tasks(include_path=include_path)
        self._all_tasks = sorted(list(self._task_index.keys()))

        # parent group name -> task names already loaded under that group.
        # Used to give a task repeated inside one group a distinct key.
        self.task_group_map = collections.defaultdict(list)

    def initialize_tasks(self, include_path: Optional[Union[str, List[str]]] = None):
        """Creates a dictionary of tasks index.

        :param include_path: Union[str, List[str]] = None
            An additional path (or list of paths) to be searched for tasks.
            On name collisions, entries from the default task directory take
            precedence over entries found under `include_path`.

        :return
            Dictionary of task names as key and task metadata
        """
        all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"]
        if include_path is not None:
            if isinstance(include_path, str):
                include_path = [include_path]
            all_paths.extend(include_path)

        task_index = {}
        for task_dir in all_paths:
            tasks = self._get_task_and_group(task_dir)
            # Keys already present win, so earlier paths take precedence.
            task_index = {**tasks, **task_index}

        return task_index

    @property
    def all_tasks(self):
        """Sorted list of every indexed task and group name."""
        return self._all_tasks

    @property
    def task_index(self):
        """Mapping of name -> index metadata (see `_get_task_and_group`)."""
        return self._task_index

    def match_tasks(self, task_list):
        """Match `task_list` patterns against the indexed names via `utils.pattern_match`."""
        return utils.pattern_match(task_list, self.all_tasks)

    def _name_is_registered(self, name) -> bool:
        """Return True if `name` is any indexed task, python task or group."""
        # Dict membership is O(1); `all_tasks` holds the same keys, sorted.
        return name in self.task_index

    def _name_is_task(self, name) -> bool:
        """Return True if `name` is indexed as a task.

        This is a substring test on the index `type`, so both "task" and
        "python_task" entries count as tasks.
        """
        return self._name_is_registered(name) and (
            "task" in self.task_index[name]["type"]
        )

    def _name_is_group(self, name) -> bool:
        """Return True if `name` is indexed as a group."""
        return self._name_is_registered(name) and (
            self.task_index[name]["type"] == "group"
        )

    def _name_is_python_task(self, name):
        """Return True if `name` is indexed as a python-class task."""
        return self._name_is_registered(name) and (
            self.task_index[name]["type"] == "python_task"
        )

    def _config_is_task(self, config) -> bool:
        """A config describes a single task when its `task` value is a string."""
        return ("task" in config) and isinstance(config["task"], str)

    def _config_is_group(self, config) -> bool:
        """A config describes a group when its `task` value is a list."""
        return ("task" in config) and isinstance(config["task"], list)

    def _config_is_python_task(self, config) -> bool:
        """A config describes a python-class task when it has a `class` key."""
        return "class" in config

    def _get_yaml_path(self, name):
        """Return the indexed yaml path for `name` (-1 for implicit groups).

        :raises ValueError: if `name` is not indexed.
        """
        if name not in self.task_index:
            raise ValueError(f"'{name}' is not a registered task or group.")
        return self.task_index[name]["yaml_path"]

    def _get_config(self, name):
        """Load the full yaml config for `name`.

        Groups that only exist through tasks' `group` fields have no yaml
        file (`yaml_path == -1`) and yield an empty dict.

        :raises ValueError: if `name` is not indexed.
        """
        yaml_path = self._get_yaml_path(name)
        if yaml_path == -1:
            return {}
        return utils.load_yaml_config(yaml_path, mode="full")

    def _get_tasklist(self, name):
        """Return the indexed subtask list of group `name`.

        The value is -1 when the group is defined by its own yaml file. In that
        case the list must be read from the group config.

        :raises ValueError: if `name` is unknown or is a task rather than a group.
        """
        if not self._name_is_registered(name):
            raise ValueError(f"'{name}' is not a registered task or group.")
        if self._name_is_task(name):
            raise ValueError(f"'{name}' is a task, not a group; it has no subtask list.")
        return self.task_index[name]["task"]

    def _process_alias(self, config, group=None):
        """Drop a `group_alias` meant for a different group.

        If the group is not the same as the original group which the group
        alias was intended for, set `group_alias` to None instead.
        Mutates and returns `config`.
        """
        if ("group_alias" in config) and ("group" in config) and group is not None:
            if config["group"] != group:
                config["group_alias"] = None
        return config

    def _register_in_group_map(self, name: str, parent_name: str) -> str:
        """Record `name` under `parent_name` and return a key unique in that group.

        The first occurrence keeps `name`; repeats become `name-1`, `name-2`, ...
        Only exact repeats count. Another task that merely shares a prefix
        (e.g. `arc_easy` vs `arc`) does not.
        """
        siblings = self.task_group_map[parent_name]
        suffix_prefix = f"{name}-"
        num_duplicate = sum(
            1
            for existing in siblings
            if existing == name
            or (
                existing.startswith(suffix_prefix)
                and existing[len(suffix_prefix) :].isdigit()
            )
        )
        if num_duplicate > 0:
            name = f"{name}-{num_duplicate}"
        siblings.append(name)
        return name

    def _load_individual_task_or_group(
        self,
        name_or_config: Optional[Union[str, dict]] = None,
        parent_name: Optional[str] = None,
        update_config: Optional[dict] = None,
        yaml_path: Optional[str] = None,
    ) -> Mapping:
        """Load a single task or group, recursing into group members.

        :param name_or_config: Union[str, dict]
            A registered task/group name, or a task/group config dictionary.
        :param parent_name: str = None
            Name of the group this entry is loaded under; None at the root.
        :param update_config: dict = None
            Overrides inherited from the enclosing group. They are merged on
            top of each subtask config.
        :param yaml_path: str = None
            Yaml file the config came from, used to resolve `include`.

        :return
            Mapping from task name to task object. Tasks loaded under a group
            map to `(group_name, task_object)`. Each non-root group maps to
            `(parent_name, None)`.
        :raises TypeError: if `name_or_config` is neither a str nor a dict.
        :raises ValueError: for unknown names or an unresolvable `include`.
        """

        def load_task(config, task, group=None, yaml_path=None):
            # Shallow copy so popping `include` / rewriting `group_alias`
            # never mutates a dict owned by the caller.
            config = {**config}
            if "include" in config:
                if yaml_path is None:
                    raise ValueError(
                        f"Task '{task}' uses `include` but no yaml_path is "
                        "available to resolve it against."
                    )
                # `pop` runs before `**config` is expanded, so the included
                # base is overridden by every other key in this config.
                config = {
                    **utils.load_yaml_config(
                        yaml_path,
                        yaml_config={"include": config.pop("include")},
                        mode="full",
                    ),
                    **config,
                }
            if self._config_is_python_task(config):
                task_object = config["class"]()
            else:
                config = self._process_alias(config, group=group)
                task_object = ConfigurableTask(config=config)
            if group is not None:
                task_object = (group, task_object)
            return {task: task_object}

        if not isinstance(name_or_config, (str, dict)):
            raise TypeError(
                "Expected a task/group name (str) or a config (dict), "
                f"but received {type(name_or_config)}."
            )

        if isinstance(name_or_config, str):
            if update_config is not None:
                # Process name_or_config as a dict instead
                name_or_config = {"task": name_or_config, **update_config}
            elif self._name_is_task(name_or_config):
                task_config = self._get_config(name_or_config)
                return load_task(task_config, task=name_or_config, group=parent_name)
            else:
                group_name = name_or_config
                group_config = None
                subtask_list = self._get_tasklist(name_or_config)
                if subtask_list == -1:
                    # Group defined by its own yaml: read its task list there.
                    group_config = self._get_config(name_or_config)
                    subtask_list = group_config["task"]

                # This checks if we're at the root.
                if parent_name is None:
                    if group_config is None:
                        group_config = self._get_config(name_or_config)
                    if set(group_config.keys()) > {"task", "group"}:
                        update_config = {
                            k: v
                            for k, v in group_config.items()
                            if k not in ["task", "group"]
                        }
                    yaml_path = self._get_yaml_path(group_name)

                    if (update_config is not None) and ("group_alias" in update_config):
                        group_name = update_config["group_alias"]
                        update_config.pop("group_alias")

        if isinstance(name_or_config, dict):
            # Always work on a fresh dict so later edits (e.g. setting
            # `group`) never leak back into the caller's config.
            name_or_config = {**name_or_config, **(update_config or {})}

            if self._config_is_task(name_or_config):
                name = name_or_config["task"]
                # The `task` value may name a registered group.
                if self._name_is_group(name):
                    group_name = name
                    update_config = {
                        k: v for k, v in name_or_config.items() if k != "task"
                    }
                    subtask_list = self._get_tasklist(name)
                    if subtask_list == -1:
                        subtask_list = self._get_config(name)["task"]
                else:
                    if self._name_is_registered(name):
                        base_task_config = self._get_config(name)

                        # Give repeated tasks within one group distinct keys.
                        if parent_name is not None:
                            name_or_config["group"] = parent_name
                            name = self._register_in_group_map(name, parent_name)

                        task_config = {
                            **base_task_config,
                            **name_or_config,
                        }
                    else:
                        task_config = name_or_config
                    return load_task(
                        task_config, task=name, group=parent_name, yaml_path=yaml_path
                    )
            else:
                group_name = name_or_config["group"]
                subtask_list = name_or_config["task"]
                if set(name_or_config.keys()) > {"task", "group"}:
                    update_config = {
                        k: v
                        for k, v in name_or_config.items()
                        if k not in ["task", "group"]
                    }

        all_subtasks = {}
        if parent_name is not None:
            all_subtasks = {group_name: (parent_name, None)}

        fn = partial(
            self._load_individual_task_or_group,
            parent_name=group_name,
            update_config=update_config,
            yaml_path=yaml_path,
        )
        # ChainMap gives earlier subtasks precedence on key collisions.
        all_subtasks = {
            **all_subtasks,
            **dict(collections.ChainMap(*map(fn, subtask_list))),
        }
        return all_subtasks

    def load_task_or_group(self, task_list: Optional[Union[str, list]] = None) -> dict:
        """Loads a dictionary of task objects from a list

        :param task_list: Union[str, list] = None
            Single string or list of string of task names to be loaded

        :return
            Dictionary of task objects
        """
        if isinstance(task_list, str):
            task_list = [task_list]

        all_loaded_tasks = dict(
            collections.ChainMap(*map(self._load_individual_task_or_group, task_list))
        )
        return all_loaded_tasks

    def load_config(self, config: Dict):
        """Load the task(s) or group described by a single config dictionary."""
        return self._load_individual_task_or_group(config)

    def _register_task_in_groups(self, tasks_and_groups, task, groups) -> None:
        """Record `task` as a member of each group named in `groups`.

        `groups` is a single group name or a list of names.

        Groups defined by their own yaml file (`task == -1`) are left
        untouched. Their yaml task list is authoritative, and a group yaml
        indexed later would overwrite an implicit entry anyway. So this keeps
        the result independent of directory walk order. Names already indexed
        as tasks are not turned into groups.
        """
        if isinstance(groups, str):
            groups = [groups]

        for group in groups:
            entry = tasks_and_groups.get(group)
            if entry is None:
                tasks_and_groups[group] = {
                    "type": "group",
                    "task": [task],
                    "yaml_path": -1,
                }
            elif entry["type"] != "group":
                self.logger.warning(
                    f"Cannot add task '{task}' to group '{group}': that name is "
                    f"already registered as a {entry['type']}."
                )
            elif entry["task"] == -1:
                continue
            else:
                entry["task"].append(task)

    def _get_task_and_group(self, task_dir: str):
        """Creates a dictionary of tasks index with the following metadata,
        - `type`, that can be either `task`, `python_task`, or `group`.
            `task` refer to regular task configs, `python_task` are special
            yaml files that only consists of `task` and `class` parameters.
            `group` are group configs.
        - `yaml_path`, path to the yaml file. If the entry is a `group` that
            was configured through a task config, the yaml_path will be -1
            and all subtasks will be listed in `task` (see below)
        - `task`, reserved for entries with `type` as `group`. This will list
            all subtasks. When a group config is created (as opposed to task
            config having `group` parameter set), this will be set to -1 to
            avoid recursive indexing. The whole list of subtasks will be loaded
            at evaluation.

        :param task_dir: str
            A directory to check for tasks

        :return
            Dictionary of task names as key and task metadata
        """
        tasks_and_groups = {}
        for root, _, file_list in os.walk(task_dir):
            for f in file_list:
                if not f.endswith(".yaml"):
                    continue
                yaml_path = os.path.join(root, f)
                config = utils.load_yaml_config(yaml_path, mode="simple")
                if self._config_is_python_task(config):
                    # This is a python class config
                    tasks_and_groups[config["task"]] = {
                        "type": "python_task",
                        "yaml_path": yaml_path,
                    }
                elif self._config_is_group(config):
                    # This is a group config. The task list is not needed
                    # for indexing; it is loaded from the yaml when called.
                    tasks_and_groups[config["group"]] = {
                        "type": "group",
                        "task": -1,
                        "yaml_path": yaml_path,
                    }
                elif self._config_is_task(config):
                    # This is a task config
                    task = config["task"]
                    tasks_and_groups[task] = {
                        "type": "task",
                        "yaml_path": yaml_path,
                    }
                    if "group" in config:
                        self._register_task_in_groups(
                            tasks_and_groups, task, config["group"]
                        )
                else:
                    self.logger.debug(f"File {f} in {root} could not be loaded")

        return tasks_and_groups


def get_task_name_from_config(task_config: Dict[str, str]) -> str:
    """Derive a task name from a raw task config dictionary.

    The explicit `task` key wins when present. Otherwise the name is built
    from `dataset_path`, followed by `_<dataset_name>` when a `dataset_name`
    is given. A missing `dataset_path` in that fallback raises KeyError.
    """
    if "task" not in task_config:
        template = (
            "{dataset_path}_{dataset_name}"
            if "dataset_name" in task_config
            else "{dataset_path}"
        )
        return template.format(**task_config)
    return task_config["task"]


def get_task_name_from_object(task_object):
    """Return the name under which a task object is reported.

    Objects exposing a `config` mapping use its `task` entry. The attribute
    that is checked is also the one that is read; previously `config` was
    checked but `_config` was read, which crashed for objects exposing only
    `config`. Otherwise the object's `EVAL_HARNESS_NAME` is used when
    defined, falling back to the class name.
    """
    if hasattr(task_object, "config"):
        return task_object.config["task"]

    # TODO: scrap this
    # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
    return (
        task_object.EVAL_HARNESS_NAME
        if hasattr(task_object, "EVAL_HARNESS_NAME")
        else type(task_object).__name__
    )


def get_task_dict(
    task_name_list: Union[str, List[Union[str, Dict, Task]]],
    task_manager: Optional[TaskManager] = None,
):
    """Creates a dictionary of task objects from either a name of task, config, or prepared Task object.

    :param task_name_list: Union[str, List[Union[str, Dict, Task]]]
        A single task/group name, or a list whose items are task/group names,
        task config dictionaries, or already-constructed Task objects.
    :param task_manager: TaskManager = None
        A TaskManager object that stores indexed tasks. If not set and any
        name or config needs resolving, a default TaskManager is created.
        This should be set by the user if there are additional paths that
        want to be included via `include_path`.

    :return
        Dictionary of task objects

    :raises TypeError: if `task_name_list` (or one of its items) has an
        unsupported type.
    :raises ValueError: if a name loaded from a string collides with the
        name of a provided Task object.
    """
    task_name_from_string_dict = {}
    task_name_from_config_dict = {}
    task_name_from_object_dict = {}

    if isinstance(task_name_list, str):
        task_name_list = [task_name_list]
    elif isinstance(task_name_list, list):
        if not all(isinstance(task, (str, dict, Task)) for task in task_name_list):
            raise TypeError(
                "Expected all list items to be of types 'str', 'dict', or 'Task', but at least one entry did not match."
            )
    else:
        raise TypeError(
            f"Expected a 'str' or 'list' but received {type(task_name_list)}."
        )

    string_task_name_list = [task for task in task_name_list if isinstance(task, str)]
    others_task_name_list = [
        task for task in task_name_list if not isinstance(task, str)
    ]

    # Both names and dict configs are resolved through the manager, so build
    # a default one if either is present. A list holding only dicts used to
    # crash on `None.load_config`.
    needs_manager = bool(string_task_name_list) or any(
        isinstance(task, dict) for task in others_task_name_list
    )
    if task_manager is None and needs_manager:
        task_manager = TaskManager()

    if string_task_name_list:
        task_name_from_string_dict = task_manager.load_task_or_group(
            string_task_name_list
        )

    for task_element in others_task_name_list:
        if isinstance(task_element, dict):
            task_name_from_config_dict = {
                **task_name_from_config_dict,
                **task_manager.load_config(config=task_element),
            }

        elif isinstance(task_element, Task):
            task_name_from_object_dict = {
                **task_name_from_object_dict,
                get_task_name_from_object(task_element): task_element,
            }

    overlap = set(task_name_from_string_dict).intersection(task_name_from_object_dict)
    if overlap:
        raise ValueError(
            f"Task names {sorted(map(str, overlap))} were given both as task "
            "names and as Task objects."
        )

    return {
        **task_name_from_string_dict,
        **task_name_from_config_dict,
        **task_name_from_object_dict,
    }