manager.py 2.51 KB
Newer Older
Baber's avatar
Baber committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from __future__ import annotations

from collections import defaultdict
from itertools import chain
from pathlib import Path
from typing import Any

from lm_eval.tasks.factory import TaskFactory
from lm_eval.tasks.index import Entry, Kind, TaskIndex
from lm_eval.utils import setup_logging


class TaskManager:
    """Index task definitions found on disk and build them on request.

    At construction time a ``TaskIndex`` is built over the configured
    directories and the indexed names are grouped by each entry's ``kind``
    into sorted name lists (``_all_tasks``, ``_all_groups``, ``_all_tags``).
    Building the actual objects is delegated to a ``TaskFactory``.
    """

    def __init__(
        self,
        verbosity: str | None = None,
        include_path: str | Path | list[str | Path] | None = None,
        include_defaults: bool = True,
        metadata: dict[str, dict[str, Any]] | None = None,
    ) -> None:
        """Build the index and the per-kind name lists.

        Args:
            verbosity: When truthy, passed to ``setup_logging``.
            include_path: One extra path, or a list/tuple of paths, to index
                in addition to the defaults.
            include_defaults: When true, the directory containing this module
                is indexed first.
            metadata: Forwarded to ``TaskFactory`` as ``meta``.
        """
        if verbosity:
            setup_logging(verbosity)

        index = TaskIndex()
        self._factory = TaskFactory(meta=metadata)

        all_paths: list[Path] = []
        if include_defaults:
            all_paths.append(Path(__file__).parent)
        if include_path:
            # Normalise "one path" and "many paths" to a single iterable.
            extra_paths = (
                include_path
                if isinstance(include_path, (list, tuple))
                else [include_path]
            )
            all_paths.extend(Path(p) for p in extra_paths)

        self._index = index.build(all_paths)

        # Group indexed names by entry kind. Because this is a defaultdict,
        # a kind with no entries simply reads back as an empty list below.
        buckets = defaultdict(list)
        for k, e in self._index.items():
            buckets[e.kind].append(k)

        self._all_tasks = sorted(
            chain.from_iterable(buckets[k] for k in {Kind.TASK, Kind.PY_TASK})
        )
        self._all_groups = sorted(buckets[Kind.GROUP])
        self._all_tags = sorted(buckets[Kind.TAG])

    def _entry(self, name: str) -> Entry:
        """Return the index entry registered under ``name``.

        Raises:
            KeyError: If ``name`` is not present in the index.
        """
        if name not in self._index:
            raise KeyError(f"Unknown task/group/tag: {name}")
        return self._index[name]

    def load_spec(self, spec: str | dict[str, Any]):
        """Build the object described by a single spec.

        Args:
            spec: Either a registered task / group / tag name, or a dict of
                inline overrides whose ``'task'`` key names the registered
                base entry, e.g. ``{'task': 'hellaswag', 'num_fewshot': 5}``.

        Returns:
            Whatever ``TaskFactory.build`` returns for the resolved entry.

        Raises:
            KeyError: If the name is not in the index, or a dict spec has no
                ``'task'`` key.
            TypeError: If ``spec`` is neither ``str`` nor ``dict``.
        """
        if isinstance(spec, str):
            entry = self._entry(spec)
            return self._factory.build(entry, overrides=None, registry=self._index)

        if isinstance(spec, dict):
            # Fail with an explanatory message instead of a bare KeyError('task').
            if "task" not in spec:
                raise KeyError(
                    "inline spec dict must contain a 'task' key naming the base entry"
                )
            # The whole dict (including 'task') is passed on as the overrides.
            entry = self._entry(spec["task"])
            return self._factory.build(entry, overrides=spec, registry=self._index)

        raise TypeError("spec must be str or dict")

    def load_task_or_group(
        self,
        task_list: (
            str
            | dict[str, Any]
            | list[str | dict[str, Any]]
            | tuple[str | dict[str, Any], ...]
        ),
    ):
        """Build one or several specs and always return a list.

        Args:
            task_list: A single spec (name or inline dict), or a list/tuple
                of specs. Tuples are accepted the same way ``include_path``
                accepts them in ``__init__``.

        Returns:
            A list with one built object per spec, in input order.
        """
        specs = task_list if isinstance(task_list, (list, tuple)) else [task_list]
        return [self.load_spec(s) for s in specs]