manager.py 3.02 KB
Newer Older
Baber's avatar
Baber committed
1
2
3
4
5
from __future__ import annotations

from collections import defaultdict
from itertools import chain
from pathlib import Path
Baber's avatar
Baber committed
6
from typing import TYPE_CHECKING, Any
Baber's avatar
Baber committed
7

Baber's avatar
Baber committed
8
from lm_eval.api.task import Task
Baber's avatar
Baber committed
9
10
11
12
13
from lm_eval.tasks.factory import TaskFactory
from lm_eval.tasks.index import Entry, Kind, TaskIndex
from lm_eval.utils import setup_logging


Baber's avatar
Baber committed
14
15
16
17
if TYPE_CHECKING:
    from lm_eval.api.task import Task


Baber's avatar
Baber committed
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
class TaskManager:
    """Index task/group/tag definitions and build them on request.

    Construction indexes the configured search paths with ``TaskIndex``;
    names are later resolved through that index and materialized by a
    ``TaskFactory``.
    """

    def __init__(
        self,
        verbosity: str | None = None,
        include_path: str | Path | list[str | Path] | None = None,
        include_defaults: bool = True,
        metadata: dict[str, dict[str, Any]] | None = None,
    ) -> None:
        """Build the index over every requested search path.

        Args:
            verbosity: When truthy, passed to ``setup_logging`` first.
            include_path: An extra path, or a list/tuple of paths, to index
                in addition to the defaults.
            include_defaults: Whether to index the directory that contains
                this module.
            metadata: Forwarded to ``TaskFactory`` as ``meta``.
        """
        if verbosity:
            setup_logging(verbosity)

        index = TaskIndex()
        self._factory = TaskFactory(meta=metadata)

        all_paths: list[Path] = []
        if include_defaults:
            # The built-in definitions are indexed from this module's directory.
            all_paths.append(Path(__file__).parent)
        if include_path:
            extra = (
                include_path
                if isinstance(include_path, (list, tuple))
                else [include_path]
            )
            all_paths.extend(Path(p) for p in extra)

        self._index = index.build(all_paths)

        # Group index keys by entry kind to produce the sorted listings below.
        buckets = defaultdict(list)
        for name, entry in self._index.items():
            buckets[entry.kind].append(name)

        # Both TASK and PY_TASK entries are listed together as tasks.
        self._all_tasks = sorted(buckets[Kind.TASK] + buckets[Kind.PY_TASK])
        self._all_groups = sorted(buckets[Kind.GROUP])
        self._all_tags = sorted(buckets[Kind.TAG])

    def _entry(self, name: str) -> Entry:
        """Return the index entry registered under *name*.

        Raises:
            KeyError: If *name* is not a known task, group, or tag.
        """
        if name not in self._index:
            raise KeyError(f"Unknown task/group/tag: {name}")
        return self._index[name]

    def load_spec(self, spec: str | dict[str, Any]):
        """Build the object described by *spec* via the task factory.

        Spec can be:
        • str  task / group / tag name (registered)
        • dict inline overrides   {'task': 'hellaswag', 'num_fewshot': 5}

        For a dict, the ``"task"`` value selects the base entry and the whole
        dict is passed to the factory as ``overrides``.

        Raises:
            KeyError: If the name is unknown, or a dict spec has no ``"task"``.
            TypeError: If *spec* is neither a str nor a dict.
        """
        if isinstance(spec, str):
            entry = self._entry(spec)
            return self._factory.build(entry, overrides=None, registry=self._index)

        if isinstance(spec, dict):
            # inline dict => find base entry, then pass overrides
            if "task" not in spec:
                raise KeyError("inline task spec must include a 'task' key")
            entry = self._entry(spec["task"])
            return self._factory.build(entry, overrides=spec, registry=self._index)

        raise TypeError("spec must be str or dict")

    def load_task_or_group(
        self,
        task_list: str
        | dict[str, Any]
        | list[str | dict[str, Any]]
        | tuple[str | dict[str, Any], ...],
    ):
        """Load one spec or a list/tuple of specs; always returns a list.

        Each item is handled by ``load_spec``. Tuples are accepted alongside
        lists, matching how ``__init__`` treats ``include_path``.
        """
        specs = task_list if isinstance(task_list, (list, tuple)) else [task_list]
        return [self.load_spec(s) for s in specs]
Baber's avatar
Baber committed
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101


def get_task_dict(
    task_name_list: str | list[str | dict | Task],
    task_manager: TaskManager | None = None,
):
    """Map task specs to their loaded objects.

    Args:
        task_name_list: A single spec or a list of specs.
            - str: loaded with ``task_manager.load_spec`` and keyed by the string.
            - dict: an inline spec such as ``{'task': 'x', 'num_fewshot': 5}``.
              It is loaded with ``task_manager.load_spec`` and keyed by its
              ``"task"`` value, because dicts cannot be dict keys. If two
              dict specs share a ``"task"`` value, the later one wins.
            - anything else (e.g. an already-built ``Task``): passed through
              unchanged and used as its own key.
        task_manager: Manager used to resolve specs. A fresh ``TaskManager()``
            is created when it is ``None``.

    Returns:
        A dict from the keys described above to the loaded/passed-through objects.

    Raises:
        TypeError: If *task_manager* is provided but is not a ``TaskManager``.
        KeyError: Propagated from ``load_spec`` for unknown names, or from a
            dict spec that lacks a ``"task"`` key.
    """
    if task_manager is None:
        task_manager = TaskManager()
    elif not isinstance(task_manager, TaskManager):
        # Raise rather than assert: assertions are stripped under `python -O`.
        raise TypeError(
            f"task_manager must be a TaskManager, got {type(task_manager).__name__}"
        )

    if isinstance(task_name_list, (str, dict)):
        # Without this, a bare str would be iterated character by character
        # (and a dict key by key).
        task_name_list = [task_name_list]

    task_dict = {}
    for item in task_name_list:
        if isinstance(item, str):
            task_dict[item] = task_manager.load_spec(item)
        elif isinstance(item, dict):
            task_dict[item["task"]] = task_manager.load_spec(item)
        else:
            task_dict[item] = item
    return task_dict