__init__.py 42.5 KB
Newer Older
1
2
3
4
# ruff: noqa E402
from __future__ import annotations


Baber's avatar
Baber committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
"""
Task Management Module for LM Evaluation Harness.

This module provides comprehensive task discovery, loading, and management functionality
for the LM Evaluation Harness. It handles YAML configuration parsing with include support,
dynamic function importing, and task indexing across multiple directories.

Key Components:
- TaskManager: Main class for task discovery and management
- YAML configuration loading with !function tag support
- Task, group, and tag indexing
- Include resolution with cycle detection
- Caching for performance optimization

Example:
    Basic usage::

        task_manager = TaskManager()
        all_tasks = task_manager.all_tasks
        task_config = task_manager._get_config("hellaswag")

    Custom task paths::

        task_manager = TaskManager(
            include_path="/path/to/custom/tasks",
            include_defaults=True
        )
"""
33
import collections
34
35
import functools
import importlib.util
36
import inspect
37
import logging
38
import sys
39
from functools import partial
Baber's avatar
Baber committed
40
from pathlib import Path
Baber's avatar
nit  
Baber committed
41
42
43
44
45
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
)
46
47
48

import yaml
from yaml import YAMLError
&'s avatar
& committed
49

50
from lm_eval.api.group import GroupConfig
Lintang Sutawika's avatar
Lintang Sutawika committed
51
from lm_eval.evaluator_utils import get_subtask_list
Baber's avatar
nit  
Baber committed
52
53
54
55
from lm_eval.utils import pattern_match, setup_logging


if TYPE_CHECKING:
56
57
    from collections.abc import Generator, Mapping

Baber's avatar
nit  
Baber committed
58
    from lm_eval.api.task import ConfigurableTask, Task
Lintang Sutawika's avatar
Lintang Sutawika committed
59

Baber's avatar
nit  
Baber committed
60
eval_logger = logging.getLogger(__name__)

#: Configuration keys that are specific to groups only (the field names of a
#: default-constructed GroupConfig).
GROUP_ONLY_KEYS = [*GroupConfig().to_dict()]

#: Base YAML loader class: the C-accelerated loader when PyYAML was built with
#: libyaml, otherwise the pure-Python FullLoader.
#: NOTE(review): these two are not like-for-like — PyYAML's ``CLoader`` is the C
#: counterpart of the unrestricted ``yaml.Loader``, while ``FullLoader`` rejects
#: some python/* tags; ``yaml.CFullLoader`` would be the matching C class.
#: Confirm no task YAML relies on the difference before changing it.
if getattr(yaml, "__with_libyaml__", False):
    _Base = yaml.CLoader
else:
    _Base = yaml.FullLoader

#: Directory names skipped during task discovery.
_IGNORE_DIRS = ("__pycache__", ".ipynb_checkpoints")
73

lintangsutawika's avatar
lintangsutawika committed
74

75
76
def _mk_function_ctor(base_dir: Path, resolve: bool):
    """Build a PyYAML constructor for the ``!function`` tag.

    Args:
        base_dir: Directory that relative ``!function`` modules are resolved
            against (forwarded to ``_import_function``).
        resolve: When False nothing is imported; every ``!function`` value
            becomes a fresh no-op callable that returns None.

    Returns:
        A ``(loader, node)`` callable suitable for registering as a constructor.
    """

    def ctor(loader: yaml.Loader, node: yaml.Node):
        qualname = loader.construct_scalar(node)
        if resolve:
            return _import_function(qualname, base_dir)
        # "simple" (metadata-only) mode: a stand-in accepting any arguments
        return lambda *args, **kwargs: None

    return ctor


@functools.lru_cache(maxsize=1024)
def make_yaml_loader(base_dir: Path, *, simple: bool) -> type[yaml.Loader]:
    """Return a memoized Loader subclass whose ``!function`` tag is bound to *base_dir*.

    Args:
        base_dir: Directory relative ``!function`` modules are resolved against.
        simple: If True, ``!function`` yields a no-op stub instead of importing
            (enough when only metadata is needed).

    Returns:
        A subclass of the module-level ``_Base`` loader. Thanks to ``lru_cache``
        the same class object is returned for repeated (base_dir, simple) pairs.
    """
    function_ctor = _mk_function_ctor(base_dir, resolve=not simple)

    class Loader(_Base):
        """Dynamic subclass that only exists to carry the custom constructor."""

    yaml.add_constructor("!function", function_ctor, Loader=Loader)
    return Loader


104
105
106
107
108
@functools.lru_cache(maxsize=4096)
def _read_yaml(path: Path, *, resolve_functions: bool) -> dict:
    """Parse the YAML file at *path*, memoized per (path, resolve_functions).

    ``!function`` tags are resolved relative to ``path.parent`` when
    *resolve_functions* is True and become stubs otherwise.

    Warning: because of the cache, every caller receives the *same* object for
    the same arguments — treat it as read-only and copy before modifying.
    """
    loader = make_yaml_loader(path.parent, simple=not resolve_functions)
    with open(path, "rb") as stream:
        document = yaml.load(stream, Loader=loader)
    return document
Baber's avatar
Baber committed
109
110


111
112
113
114
115
116
117
@functools.cache
def _import_function(qual: str, base_dir: Path):
    """Import `qual` where qual looks like  "my_utils.some_fn".
    Search order:
      1. <base_dir>/my_utils.py            (relative file)
      2. python importlib (package/module already importable)
    Uses file *mtime* so edits are reloaded without killing the process.
Baber's avatar
Baber committed
118
    """
119

120
121
122
    if "." not in qual:
        msg = f"!function value '{qual}' must contain a '.'"
        raise ValueError(msg)
Baber's avatar
Baber committed
123

124
125
    mod_part, _, fn_name = qual.rpartition(".")
    relative_path = (base_dir / f"{mod_part.replace('.', '/')}.py").resolve()
126

127
128
129
130
131
132
133
134
135
136
137
    if relative_path.exists():
        mtime = relative_path.stat().st_mtime_ns  # for cache busting
        module_key = f"{relative_path}:{mtime}"
        if module_key in sys.modules:
            mod = sys.modules[module_key]
        else:
            spec = importlib.util.spec_from_file_location(module_key, relative_path)
            mod = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(mod)  # type: ignore[arg-type]
            sys.modules[module_key] = mod
        return getattr(mod, fn_name)
138

139
140
    module = importlib.import_module(mod_part)
    return getattr(module, fn_name)
Baber's avatar
nit  
Baber committed
141
142


143
def load_yaml_config(
    yaml_path: Path | str,
    *,
    resolve_functions: bool = True,
    resolve_includes: bool = True,
    _seen: set[tuple[Path, bool]] | None = None,
) -> dict:
    """Read a YAML config and optionally merge its ``include:`` chain.

    Args:
        yaml_path: Path to the YAML file (``~`` expanded, resolved to absolute).
        resolve_functions: If True ``!function`` tags are imported; if False
            they become no-op stubs (enough for metadata-only reads).
        resolve_includes: If True, files named under ``include:`` (a single
            path or a list, relative to this file's directory unless absolute)
            are loaded recursively and merged in order, later includes
            overriding earlier ones and this file's own keys overriding all.
        _seen: Internal; the include chain currently being loaded, used for
            cycle detection.

    Returns:
        A new dict the caller may modify freely (``{}`` for an empty file).
        The parsed YAML is cached by ``_read_yaml`` and is never mutated here.

    Raises:
        ValueError: If an include cycle is detected.
    """
    path = Path(yaml_path).expanduser().resolve()
    if _seen is None:
        _seen = set()

    key = (path, resolve_functions)
    if key in _seen:
        msg = f"Include cycle at {path}"
        raise ValueError(msg)

    raw = _read_yaml(path, resolve_functions=resolve_functions)
    if raw is None:  # empty YAML document
        return {}
    if not isinstance(raw, dict):
        return raw  # not a mapping: nothing to merge
    # _read_yaml is lru-cached: copy before popping/merging so the cached
    # object stays pristine and callers never share mutable state.
    cfg = dict(raw)

    if not resolve_includes or "include" not in cfg:
        return cfg

    includes = cfg.pop("include")
    if isinstance(includes, (str, Path)):  # a single include, not a sequence of chars
        includes = [includes]

    base_dir = path.parent
    merged: dict = {}
    # Only the *current* chain stays in _seen, so a file reached through two
    # different branches (a "diamond") is not mistaken for a cycle.
    _seen.add(key)
    try:
        for inc in includes or []:
            inc_path = Path(inc)
            if not inc_path.is_absolute():
                inc_path = (base_dir / inc_path).resolve()
            merged.update(
                load_yaml_config(
                    inc_path,
                    resolve_functions=resolve_functions,
                    _seen=_seen,
                ),
            )
    finally:
        _seen.discard(key)
    merged.update(cfg)  # local keys win
    return merged


182
def iter_yaml_files(root: Path, ignore=_IGNORE_DIRS) -> Generator[Path, Any, None]:
    """Recursively iterate over all ``*.yaml`` files below *root*.

    Files that sit inside a directory whose name is in *ignore* (by default
    ``__pycache__`` and ``.ipynb_checkpoints``) are skipped. Only the path
    components *below* root are checked, so an ignored name appearing in root's
    own path does not hide the whole tree.

    Args:
        root: Root directory to search (a ``str`` is accepted too).
        ignore: Collection of directory names to skip.

    Yields:
        Path objects for each discovered YAML file.

    Example:
        >>> for yaml_file in iter_yaml_files(Path("tasks")):
        ...     print(f"Found task config: {yaml_file}")
    """
    root = Path(root)
    for path in root.glob("**/*.yaml"):
        if any(part in ignore for part in path.relative_to(root).parts):
            continue
        yield path
Lintang Sutawika's avatar
Lintang Sutawika committed
206

207

208
class TaskManager:
209
    """Central manager for task discovery, indexing, and loading.
Baber's avatar
Baber committed
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225

    TaskManager scans directories for YAML task configurations and maintains
    an index of all available tasks, groups, and tags. It provides methods
    for listing, filtering, and loading tasks with their configurations.

    The manager supports:
    - Automatic discovery from default lm_eval/tasks/ directory
    - Custom task directories via include_path
    - Task grouping and tagging
    - Configuration inheritance via YAML includes
    - Caching for performance

    Attributes:
        include_path: Additional directories to search for tasks
        metadata: Global metadata to inject into all task configs
        task_group_map: Mapping of tasks to their parent groups
226

Baber's avatar
Baber committed
227
228
229
230
231
232
233
234
235
236
237
238
239
240
    Example:
        Basic usage::

            tm = TaskManager()
            print(f"Found {len(tm.all_tasks)} tasks")
            hellaswag_config = tm._get_config("hellaswag")

        With custom tasks::

            tm = TaskManager(
                include_path="/my/custom/tasks",
                verbosity="INFO"
            )
            custom_tasks = [t for t in tm.all_tasks if "custom" in t]
241

242
243
    """

244
245
    def __init__(
        self,
        verbosity: str | None = None,
        include_path: str | Path | list[str | Path] | None = None,
        include_defaults: bool = True,
        metadata: dict[str, dict[str, Any]] | None = None,
    ) -> None:
        """Index every available task, group and tag.

        Args:
            verbosity: If given, passed to ``setup_logging`` (DEBUG, INFO,
                WARNING, ERROR).
            include_path: Extra directory, or list of directories, to scan for
                task configs.
            include_defaults: Whether to also scan the bundled lm_eval/tasks/.
            metadata: Global metadata merged into task and group configs when
                they are loaded.
        """
        if verbosity is not None:
            setup_logging(verbosity)
        self.include_path = include_path
        self.metadata = metadata
        self._task_index = self.initialize_tasks(
            include_path=include_path,
            include_defaults=include_defaults,
        )

        index = self._task_index
        self._all_tasks = sorted(index)

        def names_of_type(*kinds: str) -> list[str]:
            # _all_tasks is already sorted, so filtering preserves the order.
            return [name for name in self._all_tasks if index[name]["type"] in kinds]

        self._all_groups = names_of_type("group")
        self._all_subtasks = names_of_type("task", "python_task")
        self._all_tags = names_of_type("tag")

        self.task_group_map = collections.defaultdict(list)
286

287
288
    def initialize_tasks(
        self,
289
        include_path: str | Path | list[str | Path] | None = None,
290
        include_defaults: bool = True,
Baber Abbasi's avatar
Baber Abbasi committed
291
292
    ) -> dict[str, dict]:
        """Creates a dictionary of tasks indexes.
293

Baber's avatar
nit  
Baber committed
294
        :param include_path: Union[str, list] = None
295
296
297
298
            An additional path to be searched for tasks recursively.
            Can provide more than one such path as a list.
        :param include_defaults: bool = True
            If set to false, default tasks (those in lm_eval/tasks/) are not indexed.
299
300

        Return:
Baber's avatar
nit  
Baber committed
301
            dictionary of task names as key and task metadata
302

303
        """
304
        all_paths = [Path(__file__).parent] if include_defaults else []
305
        if include_path is not None:
Baber's avatar
Baber committed
306
            if isinstance(include_path, (str, Path)):
307
                include_path = [include_path]
Baber's avatar
Baber committed
308
309
            # Convert all paths to Path objects
            all_paths.extend(Path(p) for p in include_path)
310

311
312
313
314
        task_index = {}
        for task_dir in all_paths:
            tasks = self._get_task_and_group(task_dir)
            task_index = {**tasks, **task_index}
lintangsutawika's avatar
format  
lintangsutawika committed
315

316
317
318
        return task_index

    @property
Baber's avatar
nit  
Baber committed
319
    def all_tasks(self) -> list[str]:
Baber's avatar
Baber committed
320
        """Get sorted list of all task names (tasks, groups, and tags)."""
321
322
        return self._all_tasks

323
    @property
Baber's avatar
nit  
Baber committed
324
    def all_groups(self) -> list[str]:
Baber's avatar
Baber committed
325
        """Get sorted list of all group names."""
326
327
328
        return self._all_groups

    @property
Baber's avatar
nit  
Baber committed
329
    def all_subtasks(self) -> list[str]:
Baber's avatar
Baber committed
330
        """Get sorted list of all individual task names (excludes groups and tags)."""
331
332
333
        return self._all_subtasks

    @property
Baber's avatar
nit  
Baber committed
334
    def all_tags(self) -> list[str]:
Baber's avatar
Baber committed
335
        """Get sorted list of all tag names."""
336
337
        return self._all_tags

338
    @property
339
    def task_index(self) -> dict[str, dict[str, str | int | list[str]]]:
Baber's avatar
Baber committed
340
        """Get the complete task index with metadata for all tasks."""
341
342
        return self._task_index

343
    def list_all_tasks(
        self,
        list_groups: bool = True,
        list_tags: bool = True,
        list_subtasks: bool = True,
    ) -> str:
        """Return a Markdown listing of the groups, tags and/or subtasks known here.

        Each requested section is a separate ``pytablewriter`` Markdown table;
        the sections are concatenated with newlines.

        Args:
            list_groups: Include the "Group | Config Location" table.
            list_tags: Include the "Tag" table.
            list_subtasks: Include the "Task | Config Location | Output Type" table.

        Returns:
            The Markdown text (starting with a newline).

        Entries whose ``yaml_path`` is ``-1`` are shown as ``---``. A subtask's
        output type comes from its own YAML or, failing that, from the first of
        its ``include:`` files (string or list) that defines one; missing include
        files are skipped. YAML is read in metadata-only mode (``!function``
        tags are stubbed, nothing is imported).
        """
        from pytablewriter import MarkdownTableWriter

        # ------------------------------------------------------------------ helpers
        def sanitize_path(path: str) -> str:
            # print a relative path for anything inside lm_eval/tasks/
            if "lm_eval/tasks/" in path:
                return "lm_eval/tasks/" + path.split("lm_eval/tasks/")[-1]
            return path

        def first_output_type_from_includes(cfg: dict, base: Path) -> str:
            """Return the output_type of the first include in cfg['include']
            (string or list) that defines one, or "" if none does.
            """
            inc_raw = cfg.get("include")
            if not inc_raw:
                return ""

            inc_list = inc_raw if isinstance(inc_raw, list) else [inc_raw]
            for inc in inc_list:
                if not inc:
                    continue
                inc_path = Path(inc)
                if not inc_path.is_absolute():  # relative to the including file
                    inc_path = base.parent / inc_path
                try:
                    # metadata-only read; the include's own includes are followed
                    inc_cfg = load_yaml_config(inc_path, resolve_functions=False)
                except FileNotFoundError:
                    continue
                if "output_type" in inc_cfg:
                    return inc_cfg["output_type"]
            return ""

        # -------------------------------------------------------------- GROUP table
        group_table = MarkdownTableWriter()
        group_table.headers = ["Group", "Config Location"]
        group_table.value_matrix = [
            [
                g,
                "---"
                if self.task_index[g]["yaml_path"] == -1
                else sanitize_path(self.task_index[g]["yaml_path"]),
            ]
            for g in self.all_groups
        ]

        # ---------------------------------------------------------------- TAG table
        tag_table = MarkdownTableWriter()
        tag_table.headers = ["Tag"]
        tag_table.value_matrix = [[t] for t in self.all_tags]

        # ------------------------------------------------------------ SUBTASK table
        subtask_table = MarkdownTableWriter()
        subtask_table.headers = ["Task", "Config Location", "Output Type"]
        st_values: list[list[str]] = []

        for t in self.all_subtasks:
            raw_path = self.task_index[t]["yaml_path"]

            if raw_path == -1:
                # no YAML file to read
                display_path = "---"
                output_type = ""
            else:
                path_obj = Path(raw_path)
                display_path = sanitize_path(str(path_obj))

                # Read only this file: includes are walked by the helper above,
                # so one missing include cannot abort the whole listing.
                cfg = (
                    load_yaml_config(
                        path_obj,
                        resolve_functions=False,
                        resolve_includes=False,
                    )
                    or {}
                )
                if "output_type" in cfg:
                    output_type = cfg["output_type"]
                else:
                    output_type = first_output_type_from_includes(cfg, path_obj)

            st_values.append([t, display_path, output_type])

        subtask_table.value_matrix = st_values

        # ------------------------------------------------------------- final string
        parts: list[str] = ["\n"]
        if list_groups:
            parts.append(group_table.dumps())
            parts.append("\n")
        if list_tags:
            parts.append(tag_table.dumps())
            parts.append("\n")
        if list_subtasks:
            parts.append(subtask_table.dumps())
            parts.append("\n")

        return "".join(parts)
443

Baber Abbasi's avatar
Baber Abbasi committed
444
    def match_tasks(self, task_list: list[str]) -> list[str]:
        """Expand glob-style patterns in *task_list* against every known name."""
        known_names = self.all_tasks
        return pattern_match(task_list, known_names)
447

Baber Abbasi's avatar
Baber Abbasi committed
448
    def _name_is_registered(self, name: str) -> bool:
Baber's avatar
Baber committed
449
        """Check if a name is registered in the task index."""
450
        return name in self.all_tasks
451

Baber Abbasi's avatar
Baber Abbasi committed
452
    def _name_is_task(self, name: str) -> bool:
Baber's avatar
Baber committed
453
        """Check if a name refers to an individual task (not group or tag)."""
454
455
456
        return (
            self._name_is_registered(name) and self.task_index[name]["type"] == "task"
        )
Lintang Sutawika's avatar
Lintang Sutawika committed
457

Baber Abbasi's avatar
Baber Abbasi committed
458
    def _name_is_tag(self, name: str) -> bool:
Baber's avatar
Baber committed
459
        """Check if a name refers to a tag."""
460
        return self._name_is_registered(name) and self.task_index[name]["type"] == "tag"
461

Baber Abbasi's avatar
Baber Abbasi committed
462
    def _name_is_group(self, name: str) -> bool:
Baber's avatar
Baber committed
463
        """Check if a name refers to a group."""
464
465
466
        return (
            self._name_is_registered(name) and self.task_index[name]["type"] == "group"
        )
467

Baber Abbasi's avatar
Baber Abbasi committed
468
    def _name_is_python_task(self, name: str) -> bool:
Baber's avatar
Baber committed
469
        """Check if a name refers to a Python-defined task."""
470
471
472
473
        return (
            self._name_is_registered(name)
            and self.task_index[name]["type"] == "python_task"
        )
474

Baber's avatar
fix  
Baber committed
475
476
    @staticmethod
    def _config_is_task(config: dict) -> bool:
Baber's avatar
Baber committed
477
        """Check if a config dictionary defines a single task."""
478
        return "task" in config and isinstance(config["task"], str)
479

Baber's avatar
fix  
Baber committed
480
481
    @staticmethod
    def _config_is_group(config: dict) -> bool:
Baber's avatar
Baber committed
482
        """Check if a config dictionary defines a group of tasks."""
483
        return "task" in config and isinstance(config["task"], list)
484

Baber's avatar
fix  
Baber committed
485
486
    @staticmethod
    def _config_is_python_task(config: dict) -> bool:
Baber's avatar
Baber committed
487
        """Check if a config dictionary defines a Python class-based task."""
488
489
        return "class" in config

Baber's avatar
fix  
Baber committed
490
491
    @staticmethod
    def _config_is_task_list(config: dict) -> bool:
Baber's avatar
Baber committed
492
        """Check if a config dictionary defines a task list."""
493
        return "task_list" in config and isinstance(config["task_list"], list)
494

495
496
    def _get_yaml_path(self, name: str) -> str | int | list[str]:
        """Get the YAML file path for a registered task.
Baber's avatar
Baber committed
497
498
499
500
501
502
503
504
505

        Args:
            name: Task name

        Returns:
            Path to YAML file, or -1 for Python-only tasks

        Raises:
            ValueError: If task name is not registered
506

Baber's avatar
Baber committed
507
        """
508
509
        if name not in self.task_index:
            raise ValueError
510
511
        return self.task_index[name]["yaml_path"]

Baber's avatar
nit  
Baber committed
512
    def _get_config(self, name: str) -> dict:
513
        """Load the full configuration for a registered task.
Baber's avatar
Baber committed
514
515
516
517
518
519
520
521
522

        Args:
            name: Task name

        Returns:
            Complete task configuration dictionary

        Raises:
            ValueError: If task name is not registered
523

Baber's avatar
Baber committed
524
        """
525
526
        if name not in self.task_index:
            raise ValueError
527
528
529
        yaml_path = self._get_yaml_path(name)
        if yaml_path == -1:
            return {}
530
        return load_yaml_config(Path(yaml_path))
531

532
533
    def _get_tasklist(self, name: str) -> list[str] | int:
        """Get the task list for a group or tag.
Baber's avatar
Baber committed
534
535
536
537
538
539
540
541
542

        Args:
            name: Group or tag name

        Returns:
            List of task names in the group/tag

        Raises:
            ValueError: If name refers to an individual task
543

Baber's avatar
Baber committed
544
        """
545
546
        if self._name_is_task(name):
            raise ValueError
547
548
        return self.task_index[name]["task"]

549
    @staticmethod
550
551
552
553
    def _register_task(
        task_name: str,
        task_type: str,
        yaml_path: str,
Baber's avatar
nit  
Baber committed
554
        tasks_and_groups: dict[str, dict],
555
556
        config: dict | None = None,
        populate_tags_fn: Callable | None = None,
Baber's avatar
Baber committed
557
    ) -> None:
558
        """Helper method to register a task in the tasks_and_groups dict."""
559
560
561
562
563
564
        tasks_and_groups[task_name] = {
            "type": task_type,
            "yaml_path": yaml_path,
        }
        # Only populate tags for configs that support it (not groups)
        if config and task_type != "group" and populate_tags_fn:
565
            populate_tags_fn(config, task_name, tasks_and_groups)
566

567
    @staticmethod
568
    def _merge_task_configs(
569
570
571
        base_config: dict,
        task_specific_config: dict,
        task_name: str,
Baber's avatar
nit  
Baber committed
572
    ) -> dict:
573
        """Merge base config with task-specific overrides for task_list configs."""
574
575
576
577
578
579
        if task_specific_config:
            task_specific_config = task_specific_config.copy()
            task_specific_config.pop("task", None)
            return {**base_config, **task_specific_config, "task": task_name}
        return {**base_config, "task": task_name}

Baber's avatar
Baber committed
580
    def _process_tag_subtasks(
        self,
        tag_name: str,
        update_config: dict | None = None,
    ) -> dict:
        """Load every task listed under the tag *tag_name* and merge the results.

        Members are loaded from last to first. The merged dict keeps keys in
        list order, and on a key collision the later member's value wins.
        """
        members = self._get_tasklist(tag_name)
        load = functools.partial(
            self._load_individual_task_or_group,
            update_config=update_config,
        )
        loaded = [load(member) for member in reversed(members)]
        return dict(collections.ChainMap(*loaded))

593
594
    @staticmethod
    def _process_alias(config: dict, group: str | None = None) -> dict:
595
        """Process group alias configuration.
Baber's avatar
Baber committed
596
597
598
599
600
601
602
603
604
605

        If the group is not the same as the original group which the group alias
        was intended for, set the group_alias to None instead.

        Args:
            config: Task configuration dictionary
            group: Group name to validate against

        Returns:
            Modified configuration with processed aliases
606

Baber's avatar
Baber committed
607
        """
608
609
610
611
612
613
614
        if (
            ("group_alias" in config)
            and ("group" in config)
            and group is not None
            and config["group"] != group
        ):
            config["group_alias"] = None
615
616
        return config

Baber's avatar
Baber committed
617
    def _class_has_config_in_constructor(self, cls) -> bool:
618
        """Check if a class constructor accepts a 'config' parameter.
Baber's avatar
Baber committed
619
620
621
622
623
624

        Args:
            cls: Class to inspect

        Returns:
            True if constructor has 'config' parameter, False otherwise
625

Baber's avatar
Baber committed
626
        """
627
628
629
630
631
632
633
        constructor = getattr(cls, "__init__", None)
        return (
            "config" in inspect.signature(constructor).parameters
            if constructor
            else False
        )

634
635
636
637
638
    ###############################################################################
    # NEW: Refactored _load_individual_task_or_group and helper methods          #
    ###############################################################################

    def _create_task_object(
        self,
        cfg: dict,
        task_name: str,
        yaml_path: str | None,
    ) -> dict:
        """Instantiate a single task (ConfigurableTask **or** python-task) from *cfg*.

        Steps:
          1. If *cfg* has ``include`` (one path or a list), load those YAML files
             and merge them underneath *cfg* (keys in *cfg* win), keeping the
             task name *cfg* carried (or *task_name* if it had none).
          2. Merge this manager's global ``metadata`` over ``cfg["metadata"]``.
          3. With a ``class`` key, instantiate that class (passing ``config=cfg``
             if its constructor accepts ``config``); otherwise build
             ``ConfigurableTask(config=cfg)``.

        Args:
            cfg: Task configuration. It is copied, never modified in place.
            task_name: Key under which the task object is returned.
            yaml_path: YAML file *cfg* came from; relative includes resolve
                against its directory. ``None`` or ``-1`` (python-only task)
                falls back to the current working directory.

        Returns:
            ``{task_name: task_object}``.
        """
        from lm_eval.api.task import ConfigurableTask, Task  # local import avoids cycle

        # Shallow copy: *cfg* may be the lru-cached dict from _read_yaml (via
        # _get_config / load_yaml_config). Popping "include" or writing
        # "metadata" on it would leak into every later load of that file.
        cfg = dict(cfg)

        # ---- include handling ---------------------------------------------------
        if "include" in cfg:
            # keep original name so include merging cannot clobber it
            orig_name = cfg.get("task", task_name)
            includes = cfg.pop("include")
            if isinstance(includes, (str, Path)):
                includes = [includes]
            if isinstance(yaml_path, (str, Path)) and yaml_path:
                base_dir = Path(yaml_path).parent
            else:
                base_dir = Path.cwd()
            included: dict = {}
            for inc in includes or []:
                inc_path = Path(inc)
                if not inc_path.is_absolute():
                    inc_path = base_dir / inc_path
                included.update(load_yaml_config(inc_path))
            cfg = {**included, **cfg, "task": orig_name}

        # ---- metadata merge -----------------------------------------------------
        # ``or {}`` also covers an explicit ``metadata: null`` in the YAML.
        task_metadata = cfg.get("metadata") or {}
        if self.metadata is not None:
            task_metadata = task_metadata | self.metadata
        cfg["metadata"] = task_metadata

        # ---- python-task vs YAML-task -------------------------------------------
        if self._config_is_python_task(cfg):
            cls = cfg["class"]
            task_obj: Task
            if self._class_has_config_in_constructor(cls):
                task_obj = cls(config=cfg)
            else:
                task_obj = cls()
            # make sure name propagates when the class inherits ConfigurableTask
            if isinstance(task_obj, ConfigurableTask):  # type: ignore
                task_obj.config.task = task_name
        else:
            task_obj = ConfigurableTask(config=cfg)  # type: ignore

        return {task_name: task_obj}
Baber's avatar
Baber committed
684

685
686
687
    def _create_group_object(
        self,
        cfg: dict,
        parent_name: str | None = None,
    ) -> tuple[GroupConfig, list[str | dict]]:
        """Build a GroupConfig from *cfg* and list the subtasks it names.

        When this manager has global ``metadata`` it is merged over
        ``cfg["metadata"]`` (*cfg* is updated in place). Entries of the group's
        ``task`` list that are registered tags are replaced by that tag's task
        list; all other entries are kept as-is.

        Args:
            cfg: Keyword arguments for ``GroupConfig``.
            parent_name: Currently unused.

        Returns:
            ``(group_config, subtasks)``.
        """
        if self.metadata is not None:
            cfg["metadata"] = cfg.get("metadata", {}) | self.metadata

        group = GroupConfig(**cfg)
        expanded: list[str | dict] = []
        for entry in group.task or []:
            if isinstance(entry, str) and self._name_is_tag(entry):
                expanded += self._get_tasklist(entry)
            else:
                expanded.append(entry)
        return group, expanded
Baber's avatar
Baber committed
705

706
707
    def _load_subtasks(
        self,
        subtasks: list[str | dict],
        parent_name: str | GroupConfig | None,
        update_config: dict | None,
    ) -> Mapping:
        """Load each entry of *subtasks* and merge the results into one dict.

        Entries are loaded from last to first. The merged dict lists keys in
        subtask order, and on a key collision the later entry's value wins.
        """
        load_one = partial(
            self._load_individual_task_or_group,
            parent_name=parent_name,
            update_config=update_config,
        )
        loaded = [load_one(entry) for entry in reversed(subtasks)]
        return dict(collections.ChainMap(*loaded))
Baber's avatar
Baber committed
719

720
721
    def _load_individual_task_or_group(
        self,
        payload: str | dict,
        *,
        parent_name: str | GroupConfig | None = None,
        update_config: dict | None = None,
    ) -> Mapping:
        """Resolve *payload* into a nested mapping of loaded task objects.

        *payload* is either a registered name (task, python task, group or
        tag) or an inline config dict. The result maps a task name or group
        object to a task object, or, for groups, to a nested mapping of the
        group's subtasks.

        Args:
            payload: Registered name or inline configuration dictionary.
            parent_name: Group the payload is being loaded under, if any. It
                is used to disambiguate repeated inline task names within the
                same group.
            update_config: Extra config keys to overlay on the payload. A
                string payload becomes ``{"task": payload, **update_config}``.
                A dict payload becomes ``{**payload, **update_config}``.

        Returns:
            Mapping of ``{name_or_group_obj: task_obj | sub_mapping}``.

        Raises:
            ValueError: If a string payload is not a registered task, group
                or tag.
            TypeError: If *payload* is neither ``str`` nor ``dict``, or is a
                dict matching none of the supported config shapes.
        """
        # ------------------------------------------------------------ STRING
        if isinstance(payload, str):
            # Extra overrides force the dict path so they get merged in.
            if update_config:
                return self._load_individual_task_or_group(
                    {"task": payload, **update_config},
                    parent_name=parent_name,
                )

            # Registered task (YAML or python class).
            if self._name_is_task(payload) or self._name_is_python_task(payload):
                yaml_path = self._get_yaml_path(payload)
                cfg = self._get_config(payload)

                if "task_list" in cfg:
                    # A task_list YAML shares one base config between several
                    # tasks. Apply this task's entry (if any) and never leak
                    # the ``task_list`` key itself into the task config.
                    override = next(
                        (
                            entry
                            for entry in cfg["task_list"]
                            if isinstance(entry, dict) and entry.get("task") == payload
                        ),
                        None,
                    )
                    base = {k: v for k, v in cfg.items() if k != "task_list"}
                    cfg = {**base, **(override or {}), "task": payload}
                return self._create_task_object(cfg, payload, yaml_path)

            # Registered group: build the group object, then load its subtasks.
            if self._name_is_group(payload):
                group_cfg = self._get_config(payload)
                grp_only = {k: v for k, v in group_cfg.items() if k in GROUP_ONLY_KEYS}
                grp_obj, subtasks = self._create_group_object(grp_only, parent_name)
                return {
                    grp_obj: self._load_subtasks(subtasks, grp_obj, update_config=None),
                }

            # Registered tag: delegated to the tag expansion helper.
            if self._name_is_tag(payload):
                return self._process_tag_subtasks(payload, update_config=None)

            msg = f"Unknown task / group / tag name: {payload!r}"
            raise ValueError(msg)

        # -------------------------------------------------------------- DICT
        if isinstance(payload, dict):
            # Overrides coming from an enclosing group apply to dict subtasks
            # exactly as they do to string subtasks (see the str path above).
            # NOTE(review): group-level keys win over the subtask's own keys
            # here; confirm this precedence is the intended one.
            if update_config:
                payload = {**payload, **update_config}

            # Simple ``task: name`` dict, optionally overriding a registered task.
            if self._config_is_task(payload):
                name = payload["task"]
                if self._name_is_registered(name):
                    # Inline keys take precedence over the registered config.
                    base_cfg = self._get_config(name)
                    yaml_path = self._get_yaml_path(name)
                    merged = {**base_cfg, **payload}
                else:
                    merged = payload
                    yaml_path = None

                if parent_name is not None:
                    # A group may list the same task several times. Repeats
                    # are suffixed ``name-1``, ``name-2``, ... . Only exact
                    # matches and previously suffixed copies are counted; a
                    # bare prefix test would also count unrelated siblings
                    # such as ``arc_easy`` when loading ``arc``.
                    siblings = self.task_group_map[parent_name]
                    count = sum(
                        1 for n in siblings if n == name or n.startswith(f"{name}-")
                    )
                    if count:
                        name = f"{name}-{count}"
                    siblings.append(name)

                return self._create_task_object(merged, name, yaml_path)

            # Literal group dict: group-only keys build the group; every
            # other key is forwarded as an override to each subtask.
            if self._config_is_group(payload):
                grp_cfg = {k: v for k, v in payload.items() if k in GROUP_ONLY_KEYS}
                sub_override = {
                    k: v for k, v in payload.items() if k not in GROUP_ONLY_KEYS
                } or None
                grp_obj, subtasks = self._create_group_object(grp_cfg, parent_name)
                return {grp_obj: self._load_subtasks(subtasks, grp_obj, sub_override)}

            # Python-class task dict.
            if self._config_is_python_task(payload):
                name = payload["task"]
                return self._create_task_object(payload, name, yaml_path=None)

            msg = (
                "_load_individual_task_or_group could not interpret config dict "
                f"with keys {list(payload)}"
            )
            raise TypeError(msg)

        msg = f"_load_individual_task_or_group expected str | dict, got {type(payload)}"
        raise TypeError(msg)
822

Baber's avatar
Baber committed
823
    def load_task_or_group(
        self,
        task_list: str | dict | list[str | dict] | None = None,
    ) -> dict:
        """Load one or more tasks, groups or tags and merge the results.

        This is the main entry point for loading tasks. Each item is handed
        to ``_load_individual_task_or_group`` in order, and the resulting
        mappings are merged into one dict.

        Args:
            task_list: A single task/group/tag name, a single inline config
                dict, or a list of either. ``None`` loads nothing.

        Returns:
            Dictionary mapping task names (or group objects) to loaded task
            objects (or nested subtask mappings for groups). Keys follow the
            order of *task_list*; if two items produce the same key, the
            earlier item's value is kept.

        Example:
            Load multiple tasks::

                tasks = tm.load_task_or_group(["hellaswag", "arc_easy"])
                # Returns: {"hellaswag": Task1, "arc_easy": Task2}

            Load a group (keyed by the group object)::

                tasks = tm.load_task_or_group("arc_group")
                # Returns: {arc_group_obj: {"arc_easy": Task1, "arc_challenge": Task2}}
        """
        if task_list is None:
            return {}
        # A lone name or config dict is one item; iterating a dict directly
        # would treat each of its keys as a task name.
        if isinstance(task_list, (str, dict)):
            task_list = [task_list]

        merged: dict = {}
        for item in task_list:
            for key, value in self._load_individual_task_or_group(item).items():
                # First occurrence wins; insertion order follows the request.
                merged.setdefault(key, value)
        return merged

Baber's avatar
nit  
Baber committed
862
    def load_config(self, config: dict) -> Mapping:
        """Build task (or group) objects from a single inline config dict.

        The dict is passed unchanged to ``_load_individual_task_or_group``,
        whose dict path handles plain task configs, literal group configs
        and python-class task configs.

        Args:
            config: Inline configuration dictionary.

        Returns:
            Mapping from task name (or group object) to the loaded object(s).

        Example:
            >>> task_dict = tm.load_config({"task": "hellaswag", "num_fewshot": 5})
        """
        loader = self._load_individual_task_or_group
        return loader(config)

878
879
    def _get_task_and_group(self, task_dir: str | Path) -> dict[str, dict]:
        """Scan *task_dir* for YAML task configs and build a name index.

        Every YAML file yielded by ``iter_yaml_files`` is loaded without
        resolving ``!function`` tags or includes, then classified as one of:

        - a python-class task config (``_config_is_python_task``)
        - a group config (``_config_is_group``), indexed under ``config["group"]``
        - a plain task config (``_config_is_task``)
        - a ``task_list`` config, registering each dict entry with a ``task`` key

        Any ``tag`` field on a registered task is indexed as a tag entry by
        the nested helper. Files that fail to load, or match none of the
        shapes above, are skipped with a debug log message.

        Args:
            task_dir: Directory to scan for YAML task configurations.

        Returns:
            Dictionary mapping names to metadata dictionaries. Group and tag
            entries are built here; task entries are written by
            ``_register_task``. Keys used:
            - ``'type'``: one of ``'task'``, ``'python_task'``, ``'group'``, ``'tag'``
            - ``'yaml_path'``: source YAML path as ``str`` (``-1`` for tags)
            - ``'task'``: for tags, the member task names; for groups ``-1``
        """

        def _populate_tags_and_groups(
            config: dict,
            task: str,
            tasks_and_groups: dict[str, dict],
        ) -> None:
            """Register *task* under every tag listed in ``config["tag"]``.

            A tag name that is new becomes a ``'tag'`` entry. An existing
            ``'tag'`` entry gets *task* appended. A name already registered
            with another type is left untouched (with an info log), and the
            remaining tags are still processed.

            Args:
                config: Task configuration dictionary.
                task: Name of the task being registered.
                tasks_and_groups: Index to update in place.
            """
            if "tag" not in config:
                return
            tags = config["tag"]
            if isinstance(tags, str):
                tags = [tags]

            for tag in tags:
                if tag not in tasks_and_groups:
                    tasks_and_groups[tag] = {
                        "type": "tag",
                        "task": [task],
                        "yaml_path": -1,
                    }
                elif tasks_and_groups[tag]["type"] != "tag":
                    eval_logger.info(
                        f"The tag '{tag}' is already registered as a group, this tag will not be registered. "
                        "This may affect tasks you want to call.",
                    )
                    # Skip only this conflicting tag; the task's other tags
                    # are still registered.
                    continue
                else:
                    tasks_and_groups[tag]["task"].append(task)

        # TODO: remove group in next release
        tasks_and_groups: dict[str, dict] = {}

        for yaml_path in iter_yaml_files(Path(task_dir)):
            try:
                config = load_yaml_config(
                    yaml_path,
                    resolve_functions=False,
                    resolve_includes=False,
                )
            except (OSError, YAMLError) as err:  # OSError covers FileNotFoundError
                eval_logger.debug(f"File {yaml_path} could not be loaded ({err})")
                continue

            yaml_path_str = str(yaml_path)
            if self._config_is_python_task(config):
                self._register_task(
                    "python_task",
                    yaml_path_str,
                    tasks_and_groups,
                    config,
                    _populate_tags_and_groups,
                )
            elif self._config_is_group(config):
                # Group members are resolved lazily when the group is loaded,
                # so the index stores -1 instead of a task list.
                tasks_and_groups[config["group"]] = {
                    "type": "group",
                    "task": -1,
                    "yaml_path": yaml_path_str,
                }
            elif self._config_is_task(config):
                self._register_task(
                    "task",
                    yaml_path_str,
                    tasks_and_groups,
                    config,
                    _populate_tags_and_groups,
                )
            elif self._config_is_task_list(config):
                for task_entry in config["task_list"]:
                    if isinstance(task_entry, dict) and "task" in task_entry:
                        # The shared config has no top-level ``task`` key, so
                        # hand _register_task this entry's name while keeping
                        # the shared keys (e.g. ``tag``).
                        # NOTE(review): assumes _register_task derives the
                        # task name from config["task"]; confirm.
                        self._register_task(
                            "task",
                            yaml_path_str,
                            tasks_and_groups,
                            {**config, "task": task_entry["task"]},
                            _populate_tags_and_groups,
                        )
            else:
                eval_logger.debug(f"File {yaml_path} could not be loaded")

        return tasks_and_groups
lintangsutawika's avatar
lintangsutawika committed
1019

1020

Baber's avatar
nit  
Baber committed
1021
def get_task_name_from_config(task_config: dict[str, str]) -> str:
    """Derive a task's name from its configuration dictionary.

    Resolution order:

    1. ``task_config["task"]`` when a ``task`` key is present;
    2. ``"<dataset_path>_<dataset_name>"`` when a ``dataset_name`` key is present;
    3. otherwise ``dataset_path`` on its own.

    Args:
        task_config: Task configuration dictionary.

    Returns:
        The derived task name.

    Raises:
        KeyError: If ``dataset_path`` is needed but missing.

    Example:
        >>> get_task_name_from_config({"task": "hellaswag", "num_fewshot": 5})
        'hellaswag'

        >>> get_task_name_from_config({"dataset_path": "custom", "dataset_name": "mytask"})
        'custom_mytask'
    """
    if "task" in task_config:
        return task_config["task"]
    template = (
        "{dataset_path}_{dataset_name}"
        if "dataset_name" in task_config
        else "{dataset_path}"
    )
    return template.format(**task_config)
lintangsutawika's avatar
lintangsutawika committed
1048

1049

1050
1051
def get_task_name_from_object(task_object: ConfigurableTask | Task) -> str:
    """Extract the name from an instantiated task object.

    Objects exposing a ``config`` attribute are named by ``config["task"]``.
    Otherwise the class-level ``EVAL_HARNESS_NAME`` is used if defined, and
    the object's class name as a last resort.

    Args:
        task_object: An instantiated task object.

    Returns:
        String name of the task.

    Example:
        >>> task = ConfigurableTask(config={"task": "hellaswag"})
        >>> get_task_name_from_object(task)
        'hellaswag'
    """
    if hasattr(task_object, "config"):
        # Read the same attribute that was tested for. Reading ``_config``
        # here would fail on objects that only provide ``config``.
        # NOTE(review): on harness tasks ``config`` is presumably a property
        # over ``_config``; confirm.
        return task_object.config["task"]

    # TODO: scrap this
    # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
    return (
        task_object.EVAL_HARNESS_NAME
        if hasattr(task_object, "EVAL_HARNESS_NAME")
        else type(task_object).__name__
    )

1079

Baber's avatar
nit  
Baber committed
1080
def _check_duplicates(task_dict: dict[str, list[str]]) -> None:
    """Validate that no task appears in more than one requested group.

    Used to prevent conflicts when multiple groups claim the same
    constituent task, which could lead to ambiguous configuration such as
    conflicting ``num_fewshot`` values. A task name counts as duplicated
    when it occurs more than once across all of the lists in *task_dict*.

    Args:
        task_dict: Dictionary mapping group names to lists of subtask names.

    Raises:
        ValueError: If any task appears more than once. The message lists the
            duplicated tasks in first-seen order, followed by the groups that
            contain them.

    Example:
        >>> task_dict = {
        ...     "group1": ["task_a", "task_b"],
        ...     "group2": ["task_b", "task_c"]  # task_b appears twice!
        ... }
        >>> _check_duplicates(task_dict)  # Raises ValueError
    """
    # Single O(n) pass instead of calling list.count() for every name.
    counts = collections.Counter(
        name for subtasks in task_dict.values() for name in subtasks
    )
    # Counter keeps insertion order, so the reported list is deterministic.
    duplicate_tasks = [name for name, seen in counts.items() if seen > 1]
    if not duplicate_tasks:
        return

    # locate the potentially problematic groups that seem to 'compete' for constituent subtasks
    duplicate_set = set(duplicate_tasks)
    competing_groups = [
        group
        for group, subtasks in task_dict.items()
        if duplicate_set.intersection(subtasks)
    ]

    msg = f"Found 1 or more tasks while trying to call get_task_dict() that were members of more than 1 called group: {duplicate_tasks}. Offending groups: {competing_groups}. Please call groups which overlap their constituent tasks in separate evaluation runs."
    raise ValueError(msg)


1123
def get_task_dict(
    task_name_list: str | list[str | dict | Task],
    task_manager: TaskManager | None = None,
) -> dict[str, ConfigurableTask | Task]:
    """Create a dictionary of task objects from mixed input types.

    This is the main public API for loading tasks. It accepts various input
    formats (names, configs, objects) and returns a unified dictionary of
    instantiated task objects ready for evaluation.

    The function handles:
    - String task names, via ``TaskManager.load_task_or_group``
    - Configuration dictionaries, via ``TaskManager.load_config``
    - Pre-instantiated Task objects (used as-is)
    - Validation to prevent conflicting group memberships

    Args:
        task_name_list: A task name, or a mixed list of task specifications:
                       - str: Task name to look up
                       - dict: Inline task configuration
                       - Task: Pre-instantiated task object
        task_manager: TaskManager instance for name resolution.
                     If None, creates a default TaskManager.

    Returns:
        Dictionary mapping task names to instantiated task objects.

    Raises:
        TypeError: If task_name_list (or one of its items) has an unsupported type.
        ValueError: If two specifications produce the same name, or groups
            have conflicting memberships.

    Example:
        Mixed input types::

            tasks = get_task_dict([
                "hellaswag",                              # lookup by name
                {"task": "arc_easy", "num_fewshot": 5},   # inline config
                pre_existing_task_object                  # direct object
            ])

        Simple case::

            tasks = get_task_dict("hellaswag")
            # Returns: {"hellaswag": ConfigurableTask(...)}

        With custom TaskManager::

            tm = TaskManager(include_path="/custom/tasks")
            tasks = get_task_dict(["custom_task"], task_manager=tm)
    """
    from lm_eval.api.task import Task

    # Normalize input to list
    if isinstance(task_name_list, str):
        task_name_list = [task_name_list]
    elif not isinstance(task_name_list, list):
        msg = f"Expected a 'str' or 'list' but received {type(task_name_list)}."
        raise TypeError(msg)

    # Validate list items
    if not all(isinstance(task, (str, dict, Task)) for task in task_name_list):
        msg = "Expected all list items to be of types 'str', 'dict', or 'Task', but at least one entry did not match."
        raise TypeError(msg)

    if task_manager is None:
        task_manager = TaskManager()

    final_task_dict: dict = {}
    for task_spec in task_name_list:
        if isinstance(task_spec, Task):
            # Pre-instantiated task object, used as-is.
            loaded = {get_task_name_from_object(task_spec): task_spec}
        elif isinstance(task_spec, dict):
            # Inline config dicts go through load_config, the dedicated
            # entry point for a single configuration.
            loaded = task_manager.load_config(task_spec)
        else:
            loaded = task_manager.load_task_or_group(task_spec)

        for name in loaded:
            if name in final_task_dict:
                msg = f"Duplicate task name: {name}"
                raise ValueError(msg)
        final_task_dict.update(loaded)

    # Check for conflicting group memberships
    _check_duplicates(get_subtask_list(final_task_dict))

    return final_task_dict