__init__.py 46.1 KB
Newer Older
Baber's avatar
Baber committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
"""
Task Management Module for LM Evaluation Harness.

This module provides comprehensive task discovery, loading, and management functionality
for the LM Evaluation Harness. It handles YAML configuration parsing with include support,
dynamic function importing, and task indexing across multiple directories.

Key Components:
- TaskManager: Main class for task discovery and management
- YAML configuration loading with !function tag support
- Task, group, and tag indexing
- Include resolution with cycle detection
- Caching for performance optimization

Example:
    Basic usage::

        task_manager = TaskManager()
        all_tasks = task_manager.all_tasks
        task_config = task_manager._get_config("hellaswag")

    Custom task paths::

        task_manager = TaskManager(
            include_path="/path/to/custom/tasks",
            include_defaults=True
        )
"""

30
import collections
31
32
import functools
import importlib.util
33
import inspect
34
import logging
35
import sys
36
from functools import partial
37
from glob import iglob
Baber's avatar
Baber committed
38
from pathlib import Path
Baber's avatar
nit  
Baber committed
39
40
41
42
43
44
45
46
47
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Generator,
    Mapping,
    Optional,
    Union,
)
48
49
50

import yaml
from yaml import YAMLError
&'s avatar
& committed
51

52
from lm_eval.api.group import GroupConfig
Lintang Sutawika's avatar
Lintang Sutawika committed
53
from lm_eval.evaluator_utils import get_subtask_list
Baber's avatar
nit  
Baber committed
54
55
56
57
58
from lm_eval.utils import pattern_match, setup_logging


if TYPE_CHECKING:
    from lm_eval.api.task import ConfigurableTask, Task
Lintang Sutawika's avatar
Lintang Sutawika committed
59

Baber's avatar
nit  
Baber committed
60
eval_logger = logging.getLogger(__name__)
Lintang Sutawika's avatar
Lintang Sutawika committed
61

Baber's avatar
Baber committed
62
#: List of configuration keys that are specific to groups only
Lintang Sutawika's avatar
Lintang Sutawika committed
63
GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys())
Baber's avatar
Baber committed
64
65

#: Base YAML loader class - uses C loader if available for performance
Baber's avatar
nit  
Baber committed
66
# NOTE(review): in PyYAML, CLoader appears to be the C counterpart of yaml.Loader
# (the unrestricted loader), while the fallback here is the restricted FullLoader;
# the C analogue of FullLoader would be CFullLoader -- confirm which is intended.
_Base = yaml.CLoader if getattr(yaml, "__with_libyaml__", False) else yaml.FullLoader
Baber's avatar
Baber committed
67
68

#: Directory names to ignore during task discovery
Baber's avatar
nit  
Baber committed
69
70
71
72
_IGNORE_DIRS = (
    "__pycache__",
    ".ipynb_checkpoints",
)
73

lintangsutawika's avatar
lintangsutawika committed
74

Baber's avatar
nit  
Baber committed
75
def ignore_constructor(loader: yaml.Loader, node: yaml.Node) -> None:
    """
    Stand-in constructor for the ``!function`` tag.

    Both arguments are accepted only to satisfy the YAML constructor calling
    convention; the tagged node is discarded and ``None`` is produced in its
    place, so no Python code is imported while parsing.
    """
    return None
78
79


Baber's avatar
nit  
Baber committed
80
@functools.lru_cache(maxsize=2048)  # one Loader class per (directory, simple) pair
def _make_loader(yaml_dir: Path, simple: bool = False) -> type[yaml.Loader]:
    """
    Build a YAML Loader class whose ``!function`` handling is tied to *yaml_dir*.

    Args:
        yaml_dir: Directory of the YAML file being parsed. It is captured so a
            tag value such as ``my_utils.some_fn`` is looked up in
            ``yaml_dir / "my_utils.py"``.
        simple: When True, ``!function`` tags are mapped to ``None`` via
            ``ignore_constructor`` instead of importing anything.

    Returns:
        A fresh subclass of the module's base loader with the ``!function``
        constructor registered on that subclass only.
    """

    class Loader(_Base):
        """Dynamically-generated loader that knows its base directory."""

    def _resolve_function(ld, node, _dir=yaml_dir):
        # The directory is bound as a default so the lookup always uses the
        # directory this loader was created for.
        return _import_function(ld.construct_scalar(node), base_path=_dir)

    # Register the real constructor or the stub on this Loader subclass only.
    constructor = ignore_constructor if simple else _resolve_function
    yaml.add_constructor("!function", constructor, Loader=Loader)

    return Loader


Baber's avatar
nit  
Baber committed
114
@functools.lru_cache(maxsize=None)  # ← cache module objects
115
def _import_function(qualname: str, *, base_path: Path) -> Callable:
Baber's avatar
Baber committed
116
117
118
119
    """
    Dynamically import a function from a Python module relative to base_path.

    This function enables YAML files to reference Python functions using
120
    the !function tag. Supports dot notation for nested modules.
Baber's avatar
Baber committed
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135

    Args:
        qualname: Qualified function name like "my_module.my_function"
        base_path: Base directory for resolving relative module paths

    Returns:
        The imported callable function

    Raises:
        ValueError: If qualname doesn't contain a module part

    Example:
        >>> func = _import_function("utils.custom_metric", base_path=Path("/tasks"))
        >>> result = func(predictions, references)
    """
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
    mod_path, _, func_name = qualname.rpartition(".")
    if not mod_path:
        raise ValueError(f"{qualname!r} has no module part")
    file_path = base_path / f"{mod_path.replace('.', '/')}.py"
    module_name = f"_yaml_dynamic.{hash(file_path)}_{file_path.stem}"
    if module_name in sys.modules:
        mod = sys.modules[module_name]
    else:
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        sys.modules[module_name] = mod
    return getattr(mod, func_name)


Baber's avatar
Baber committed
151
@functools.lru_cache(maxsize=4096)
def _parse_yaml_file(path: Path, mode: str) -> dict:
    """
    Read one YAML file using a loader bound to the file's directory.

    Args:
        path: Location of the YAML file
        mode: "simple" selects the loader that ignores !function tags; any
            other value selects the loader that resolves them

    Returns:
        Whatever the YAML document deserialises to. The result is memoised per
        (path, mode), so every caller receives the same object.
    """
    ignore_functions = mode == "simple"
    loader_cls = _make_loader(path.parent, simple=ignore_functions)
    with path.open("rb") as stream:
        parsed = yaml.load(stream, Loader=loader_cls)
    return parsed


Baber's avatar
nit  
Baber committed
168
def _resolve_config_includes(
    yaml_path: Path, mode: str, chain: tuple[Path, ...]
) -> dict:
    """
    Parse *yaml_path* and merge everything it includes, refusing include cycles.

    Args:
        yaml_path: YAML file to resolve
        mode: Parsing mode forwarded to ``_parse_yaml_file``
        chain: Files currently being resolved, outermost first

    Returns:
        A new dict holding the included keys overlaid with the file's own keys

    Raises:
        ValueError: If *yaml_path* is already part of *chain*
    """
    if yaml_path in chain:
        trail = " -> ".join(str(p) for p in (*chain, yaml_path))
        raise ValueError(f"Include cycle detected at {yaml_path} ({trail})")

    # Work on a shallow copy: _parse_yaml_file is LRU-cached, so popping
    # "include" from its result would strip the key for every later reader.
    yaml_config = _parse_yaml_file(yaml_path, mode).copy()

    include = yaml_config.pop("include", None)
    if not include:
        return yaml_config

    yaml_dir = yaml_path.parent
    include_paths = include if isinstance(include, list) else [include]
    final_cfg: dict = {}

    # Iterate in reverse so that earlier entries of the include list are
    # applied last and therefore take precedence.
    for inc in reversed(include_paths):
        if inc is None:
            continue
        inc_path = Path(inc)
        if not inc_path.is_absolute():
            inc_path = (yaml_dir / inc_path).resolve()
        final_cfg.update(
            _resolve_config_includes(inc_path, mode, (*chain, yaml_path))
        )

    final_cfg.update(yaml_config)  # local keys win
    return final_cfg


@functools.lru_cache(maxsize=4096)
def _get_cached_config(yaml_path: Path, mode: str) -> dict:
    """
    Load a YAML config with its includes merged, memoised per (path, mode).

    The returned dict is the cached object itself; callers that need to
    modify it should copy it first.

    Raises:
        ValueError: If the include chain loops back onto a file already being
            resolved
    """
    return _resolve_config_includes(yaml_path, mode, ())


197
198
def load_yaml_config(
    yaml_path: Union[Path, str, None] = None,
    yaml_config: Optional[dict] = None,
    yaml_dir: Optional[Path] = None,
    mode: str = "full",
    *,
    _seen: Optional[set[tuple[Path, str]]] = None,
    resolve_includes: bool = True,
) -> dict:
    """
    Parse a YAML config with optional include handling.

    Parameters
    ----------
    yaml_path
        Path to the main YAML file.  Needed unless *yaml_config* is
        supplied directly (e.g. by tests).
    yaml_config
        Pre-parsed dict to use instead of reading *yaml_path*.  It is not
        modified; the function works on a shallow copy.
    yaml_dir
        Base directory for resolving relative include paths.  Defaults
        to `yaml_path.parent`.
    mode
        "full"  – honour  !function  tags
        "simple" – ignore !function  (faster).
    _seen
        **Internal** set of (absolute-path, mode) keys already visited by
        this call chain.  It is only consulted when the cached fast path is
        not taken.
    resolve_includes
        When False, the ``include`` key is dropped and the file's own keys
        are returned without merging any included file.

    Returns
    -------
    dict
        A dict the caller may modify freely; included keys are overlaid with
        the local keys.

    Raises
    ------
    ValueError
        If neither *yaml_path* nor *yaml_config* is given, if the same
        (path, mode) is visited twice via *_seen*, or if a relative include
        has to be resolved without any base directory.
    """
    if yaml_config is None and yaml_path is None:
        raise ValueError("load_yaml_config needs either yaml_path or yaml_config")

    # ------------------------------------------------------------------ cycle guard
    if _seen is None:
        _seen = set()
    if yaml_path is not None:
        yaml_path = Path(yaml_path).expanduser().resolve()

        # ---------- fast-path: use LRU cached function ----------
        if yaml_config is None and resolve_includes:
            # Hand out a copy so callers cannot alter the memoised result.
            return dict(_get_cached_config(yaml_path, mode))

        key = (yaml_path, mode)
        if key in _seen:
            raise ValueError(f"Include cycle detected at {yaml_path}")
        _seen.add(key)

    # ------------------------------------------------------------------ load / parse
    if yaml_config is None:  # ordinary path-based load
        yaml_config = _parse_yaml_file(yaml_path, mode)
    # Copy before popping "include": the dict is either owned by the caller or
    # by _parse_yaml_file's LRU cache, and neither should lose the key.
    yaml_config = yaml_config.copy()

    if yaml_dir is None and yaml_path is not None:
        yaml_dir = yaml_path.parent

    # ------------------------------------------------------------------ handle include
    include = yaml_config.pop("include", None)
    if not include or not resolve_includes:
        return yaml_config

    include_paths = include if isinstance(include, list) else [include]
    final_cfg: dict = {}

    for inc in reversed(include_paths):
        if inc is None:  # guard against explicit nulls
            continue
        inc_path = Path(inc)
        if not inc_path.is_absolute():
            if yaml_dir is None:
                raise ValueError(
                    "yaml_dir must be set by caller or deduced from path"
                )
            inc_path = (yaml_dir / inc_path).resolve()
        included = load_yaml_config(
            yaml_path=inc_path,
            mode=mode,
            yaml_dir=inc_path.parent,
            _seen=_seen,  # <-- pass set downward
        )
        final_cfg.update(included)

    final_cfg.update(yaml_config)  # local keys win
    return final_cfg


278
def iter_yaml_files(root: Path, ignore=_IGNORE_DIRS) -> Generator[Path, Any, None]:
    """
    Walk *root* recursively and yield every ``*.yaml`` file found.

    A match is dropped when any component of its path appears in *ignore*.

    Args:
        root: Directory at which the recursive search starts
        ignore: Path component names that disqualify a match

    Yields:
        A Path for each YAML file that survives the ignore filter

    Example:
        >>> for yaml_file in iter_yaml_files(Path("tasks")):
        ...     print(f"Found task config: {yaml_file}")
    """
    pattern = str(root / "**/*.yaml")
    for match in iglob(pattern, recursive=True):
        candidate = Path(match)
        if not any(part in ignore for part in candidate.parts):
            yield candidate
Lintang Sutawika's avatar
Lintang Sutawika committed
301

302

303
class TaskManager:
Baber's avatar
Baber committed
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
    """
    Central manager for task discovery, indexing, and loading.

    TaskManager scans directories for YAML task configurations and maintains
    an index of all available tasks, groups, and tags. It provides methods
    for listing, filtering, and loading tasks with their configurations.

    The manager supports:
    - Automatic discovery from default lm_eval/tasks/ directory
    - Custom task directories via include_path
    - Task grouping and tagging
    - Configuration inheritance via YAML includes
    - Caching for performance

    Attributes:
        include_path: Additional directories to search for tasks
        metadata: Global metadata to inject into all task configs
        task_group_map: Mapping of tasks to their parent groups
322

Baber's avatar
Baber committed
323
324
325
326
327
328
329
330
331
332
333
334
335
336
    Example:
        Basic usage::

            tm = TaskManager()
            print(f"Found {len(tm.all_tasks)} tasks")
            hellaswag_config = tm._get_config("hellaswag")

        With custom tasks::

            tm = TaskManager(
                include_path="/my/custom/tasks",
                verbosity="INFO"
            )
            custom_tasks = [t for t in tm.all_tasks if "custom" in t]
337
338
    """

339
340
    def __init__(
        self,
Lintang Sutawika's avatar
Lintang Sutawika committed
341
        verbosity: Optional[str] = None,
Baber's avatar
nit  
Baber committed
342
        include_path: Optional[Union[str, Path, list[Union[str, Path]]]] = None,
343
        include_defaults: bool = True,
344
        metadata: Optional[dict[str, dict[str, Any]]] = None,
345
    ) -> None:
Baber's avatar
Baber committed
346
347
348
349
350
351
352
353
354
355
        """
        Initialize the TaskManager.

        Args:
            verbosity: Logging verbosity level (DEBUG, INFO, WARNING, ERROR)
            include_path: Additional path(s) to search for tasks. Can be a single
                         path or list of paths.
            include_defaults: Whether to include default tasks from lm_eval/tasks/
            metadata: Global metadata dictionary to inject into all task configs
        """
Lintang Sutawika's avatar
Lintang Sutawika committed
356
        if verbosity is not None:
Baber's avatar
nit  
Baber committed
357
            setup_logging(verbosity)
358
        self.include_path = include_path
Baber Abbasi's avatar
Baber Abbasi committed
359
        self.metadata = metadata
360
361
362
        self._task_index = self.initialize_tasks(
            include_path=include_path, include_defaults=include_defaults
        )
363
        self._all_tasks = sorted(list(self._task_index.keys()))
364

365
366
367
368
        self._all_groups = sorted(
            [x for x in self._all_tasks if self._task_index[x]["type"] == "group"]
        )
        self._all_subtasks = sorted(
369
370
371
372
373
            [
                x
                for x in self._all_tasks
                if self._task_index[x]["type"] in ["task", "python_task"]
            ]
374
375
376
377
378
        )
        self._all_tags = sorted(
            [x for x in self._all_tasks if self._task_index[x]["type"] == "tag"]
        )

379
        self.task_group_map = collections.defaultdict(list)
380

381
382
    def initialize_tasks(
        self,
Baber's avatar
nit  
Baber committed
383
        include_path: Optional[Union[str, Path, list[Union[str, Path]]]] = None,
384
        include_defaults: bool = True,
Baber Abbasi's avatar
Baber Abbasi committed
385
386
    ) -> dict[str, dict]:
        """Creates a dictionary of tasks indexes.
387

Baber's avatar
nit  
Baber committed
388
        :param include_path: Union[str, list] = None
389
390
391
392
            An additional path to be searched for tasks recursively.
            Can provide more than one such path as a list.
        :param include_defaults: bool = True
            If set to false, default tasks (those in lm_eval/tasks/) are not indexed.
Baber Abbasi's avatar
Baber Abbasi committed
393
        return
Baber's avatar
nit  
Baber committed
394
            dictionary of task names as key and task metadata
395
        """
396
        if include_defaults:
Baber's avatar
Baber committed
397
            all_paths = [Path(__file__).parent]
398
399
        else:
            all_paths = []
400
        if include_path is not None:
Baber's avatar
Baber committed
401
            if isinstance(include_path, (str, Path)):
402
                include_path = [include_path]
Baber's avatar
Baber committed
403
404
            # Convert all paths to Path objects
            all_paths.extend(Path(p) for p in include_path)
405

406
407
408
409
        task_index = {}
        for task_dir in all_paths:
            tasks = self._get_task_and_group(task_dir)
            task_index = {**tasks, **task_index}
lintangsutawika's avatar
format  
lintangsutawika committed
410

411
412
413
        return task_index

    @property
Baber's avatar
nit  
Baber committed
414
    def all_tasks(self) -> list[str]:
Baber's avatar
Baber committed
415
        """Get sorted list of all task names (tasks, groups, and tags)."""
416
417
        return self._all_tasks

418
    @property
Baber's avatar
nit  
Baber committed
419
    def all_groups(self) -> list[str]:
Baber's avatar
Baber committed
420
        """Get sorted list of all group names."""
421
422
423
        return self._all_groups

    @property
Baber's avatar
nit  
Baber committed
424
    def all_subtasks(self) -> list[str]:
Baber's avatar
Baber committed
425
        """Get sorted list of all individual task names (excludes groups and tags)."""
426
427
428
        return self._all_subtasks

    @property
Baber's avatar
nit  
Baber committed
429
    def all_tags(self) -> list[str]:
Baber's avatar
Baber committed
430
        """Get sorted list of all tag names."""
431
432
        return self._all_tags

433
    @property
Baber's avatar
nit  
Baber committed
434
    def task_index(self) -> dict[str, dict[str, Union[str, int, list[str]]]]:
Baber's avatar
Baber committed
435
        """Get the complete task index with metadata for all tasks."""
436
437
        return self._task_index

438
    def list_all_tasks(
        self,
        list_groups: bool = True,
        list_tags: bool = True,
        list_subtasks: bool = True,
    ) -> str:
        """
        Return a Markdown table (as a string) listing groups, tags and/or subtasks
        known to this TaskManager.  Safe for configs whose yaml_path is -1 and for
        task configs whose `include:` is a list.

        Each table is only built when its flag is set, so the per-task YAML
        loads needed for the subtask table are skipped when it is not requested.

        Args:
            list_groups: Include the group table
            list_tags: Include the tag table
            list_subtasks: Include the subtask table

        Returns:
            The requested tables, each followed by a newline, after one
            leading newline.
        """
        from pytablewriter import MarkdownTableWriter

        # ------------------------------------------------------------------ helpers
        def sanitize_path(path: str) -> str:
            # print a relative path for anything inside lm_eval/tasks/
            if "lm_eval/tasks/" in path:
                return "lm_eval/tasks/" + path.split("lm_eval/tasks/")[-1]
            return path

        def first_output_type_from_includes(cfg: dict, base: Path) -> str:
            """Walk cfg['include'] (string or list) and return the first
            include that itself specifies an output_type."""
            inc_raw = cfg.get("include")
            if not inc_raw:
                return ""

            inc_list = inc_raw if isinstance(inc_raw, list) else [inc_raw]
            for inc in inc_list:
                if not inc:
                    continue
                inc_path = Path(inc)
                if not inc_path.is_absolute():  # treat as relative include
                    inc_path = base.parent / inc_path
                try:
                    inc_cfg = load_yaml_config(inc_path, mode="simple")
                except FileNotFoundError:
                    continue
                if "output_type" in inc_cfg:
                    return inc_cfg["output_type"]
            return ""

        def render_groups() -> str:
            # -------------------------------------------------------------- GROUP table
            table = MarkdownTableWriter()
            table.headers = ["Group", "Config Location"]
            table.value_matrix = [
                [
                    g,
                    "---"
                    if self.task_index[g]["yaml_path"] == -1
                    else sanitize_path(self.task_index[g]["yaml_path"]),
                ]
                for g in self.all_groups
            ]
            return table.dumps()

        def render_tags() -> str:
            # ---------------------------------------------------------------- TAG table
            table = MarkdownTableWriter()
            table.headers = ["Tag"]
            table.value_matrix = [[t] for t in self.all_tags]
            return table.dumps()

        def render_subtasks() -> str:
            # ------------------------------------------------------------ SUBTASK table
            table = MarkdownTableWriter()
            table.headers = ["Task", "Config Location", "Output Type"]
            rows: list[list[str]] = []

            for t in self.all_subtasks:
                raw_path = self.task_index[t]["yaml_path"]

                if raw_path == -1:
                    # python-only task or generated at runtime
                    display_path = "---"
                    output_type = ""
                else:
                    path_obj = Path(raw_path)
                    display_path = sanitize_path(str(path_obj))

                    # load minimal YAML to discover output_type
                    cfg = load_yaml_config(path_obj, mode="simple")
                    if "output_type" in cfg:
                        output_type = cfg["output_type"]
                    else:
                        # Fallback for configs that still carry an include key.
                        output_type = first_output_type_from_includes(cfg, path_obj)

                rows.append([t, display_path, output_type])

            table.value_matrix = rows
            return table.dumps()

        # ------------------------------------------------------------- final string
        parts: list[str] = ["\n"]
        if list_groups:
            parts.append(render_groups())
            parts.append("\n")
        if list_tags:
            parts.append(render_tags())
            parts.append("\n")
        if list_subtasks:
            parts.append(render_subtasks())
            parts.append("\n")

        return "".join(parts)
538

Baber Abbasi's avatar
Baber Abbasi committed
539
    def match_tasks(self, task_list: list[str]) -> list[str]:
        """Run ``pattern_match`` on *task_list* against every indexed name."""
        known_names = self.all_tasks
        return pattern_match(task_list, known_names)
542

Baber Abbasi's avatar
Baber Abbasi committed
543
    def _name_is_registered(self, name: str) -> bool:
Baber's avatar
Baber committed
544
        """Check if a name is registered in the task index."""
545
        return name in self.all_tasks
546

Baber Abbasi's avatar
Baber Abbasi committed
547
    def _name_is_task(self, name: str) -> bool:
Baber's avatar
Baber committed
548
        """Check if a name refers to an individual task (not group or tag)."""
549
550
551
        return (
            self._name_is_registered(name) and self.task_index[name]["type"] == "task"
        )
Lintang Sutawika's avatar
Lintang Sutawika committed
552

Baber Abbasi's avatar
Baber Abbasi committed
553
    def _name_is_tag(self, name: str) -> bool:
Baber's avatar
Baber committed
554
        """Check if a name refers to a tag."""
555
        return self._name_is_registered(name) and self.task_index[name]["type"] == "tag"
556

Baber Abbasi's avatar
Baber Abbasi committed
557
    def _name_is_group(self, name: str) -> bool:
Baber's avatar
Baber committed
558
        """Check if a name refers to a group."""
559
560
561
        return (
            self._name_is_registered(name) and self.task_index[name]["type"] == "group"
        )
562

Baber Abbasi's avatar
Baber Abbasi committed
563
    def _name_is_python_task(self, name: str) -> bool:
Baber's avatar
Baber committed
564
        """Check if a name refers to a Python-defined task."""
565
566
567
568
        return (
            self._name_is_registered(name)
            and self.task_index[name]["type"] == "python_task"
        )
569

Baber's avatar
fix  
Baber committed
570
571
    @staticmethod
    def _config_is_task(config: dict) -> bool:
Baber's avatar
Baber committed
572
        """Check if a config dictionary defines a single task."""
573
        return "task" in config and isinstance(config["task"], str)
574

Baber's avatar
fix  
Baber committed
575
576
    @staticmethod
    def _config_is_group(config: dict) -> bool:
Baber's avatar
Baber committed
577
        """Check if a config dictionary defines a group of tasks."""
578
        return "task" in config and isinstance(config["task"], list)
579

Baber's avatar
fix  
Baber committed
580
581
    @staticmethod
    def _config_is_python_task(config: dict) -> bool:
Baber's avatar
Baber committed
582
        """Check if a config dictionary defines a Python class-based task."""
583
584
        return "class" in config

Baber's avatar
fix  
Baber committed
585
586
    @staticmethod
    def _config_is_task_list(config: dict) -> bool:
Baber's avatar
Baber committed
587
        """Check if a config dictionary defines a task list."""
588
        return "task_list" in config and isinstance(config["task_list"], list)
589

Baber's avatar
fix  
Baber committed
590
    def _get_yaml_path(self, name: str) -> Union[str, int, list[str]]:
Baber's avatar
Baber committed
591
592
593
594
595
596
597
598
599
600
601
602
        """
        Get the YAML file path for a registered task.

        Args:
            name: Task name

        Returns:
            Path to YAML file, or -1 for Python-only tasks

        Raises:
            ValueError: If task name is not registered
        """
603
604
        if name not in self.task_index:
            raise ValueError
605
606
        return self.task_index[name]["yaml_path"]

Baber's avatar
nit  
Baber committed
607
    def _get_config(self, name: str) -> dict:
Baber's avatar
Baber committed
608
609
610
611
612
613
614
615
616
617
618
619
        """
        Load the full configuration for a registered task.

        Args:
            name: Task name

        Returns:
            Complete task configuration dictionary

        Raises:
            ValueError: If task name is not registered
        """
620
621
        if name not in self.task_index:
            raise ValueError
622
623
624
625
        yaml_path = self._get_yaml_path(name)
        if yaml_path == -1:
            return {}
        else:
626
            return load_yaml_config(Path(yaml_path), mode="full")
627

Baber's avatar
nit  
Baber committed
628
    def _get_tasklist(self, name: str) -> Union[list[str], int]:
Baber's avatar
Baber committed
629
630
631
632
633
634
635
636
637
638
639
640
        """
        Get the task list for a group or tag.

        Args:
            name: Group or tag name

        Returns:
            List of task names in the group/tag

        Raises:
            ValueError: If name refers to an individual task
        """
641
642
        if self._name_is_task(name):
            raise ValueError
643
644
        return self.task_index[name]["task"]

645
646
647
648
649
    def _register_task(
        self,
        task_name: str,
        task_type: str,
        yaml_path: str,
Baber's avatar
nit  
Baber committed
650
651
        tasks_and_groups: dict[str, dict],
        config: Optional[dict] = None,
Baber's avatar
fix  
Baber committed
652
        populate_tags_fn: Optional[Callable] = None,
Baber's avatar
Baber committed
653
    ) -> None:
654
655
656
657
658
659
660
        """Helper method to register a task in the tasks_and_groups dict"""
        tasks_and_groups[task_name] = {
            "type": task_type,
            "yaml_path": yaml_path,
        }
        # Only populate tags for configs that support it (not groups)
        if config and task_type != "group" and populate_tags_fn:
661
            populate_tags_fn(config, task_name, tasks_and_groups)
662
663

    def _merge_task_configs(
Baber's avatar
nit  
Baber committed
664
665
        self, base_config: dict, task_specific_config: dict, task_name: str
    ) -> dict:
666
667
668
669
670
671
672
        """Merge base config with task-specific overrides for task_list configs"""
        if task_specific_config:
            task_specific_config = task_specific_config.copy()
            task_specific_config.pop("task", None)
            return {**base_config, **task_specific_config, "task": task_name}
        return {**base_config, "task": task_name}

Baber's avatar
Baber committed
673
    def _process_tag_subtasks(
Baber's avatar
nit  
Baber committed
674
675
        self, tag_name: str, update_config: Optional[dict] = None
    ) -> dict:
676
677
678
679
680
681
682
683
        """Process subtasks for a tag and return loaded tasks"""
        subtask_list = self._get_tasklist(tag_name)
        fn = partial(
            self._load_individual_task_or_group,
            update_config=update_config,
        )
        return dict(collections.ChainMap(*map(fn, reversed(subtask_list))))

Baber's avatar
nit  
Baber committed
684
    def _process_alias(self, config: dict, group: Optional[str] = None) -> dict:
Baber's avatar
Baber committed
685
686
687
688
689
690
691
692
693
694
695
696
697
        """
        Process group alias configuration.

        If the group is not the same as the original group which the group alias
        was intended for, set the group_alias to None instead.

        Args:
            config: Task configuration dictionary
            group: Group name to validate against

        Returns:
            Modified configuration with processed aliases
        """
698
699
700
701
702
        if ("group_alias" in config) and ("group" in config) and group is not None:
            if config["group"] != group:
                config["group_alias"] = None
        return config

Baber's avatar
Baber committed
703
    def _class_has_config_in_constructor(self, cls) -> bool:
Baber's avatar
Baber committed
704
705
706
707
708
709
710
711
712
        """
        Check if a class constructor accepts a 'config' parameter.

        Args:
            cls: Class to inspect

        Returns:
            True if constructor has 'config' parameter, False otherwise
        """
713
714
715
716
717
718
719
        constructor = getattr(cls, "__init__", None)
        return (
            "config" in inspect.signature(constructor).parameters
            if constructor
            else False
        )

720
721
722
723
724
    ###############################################################################
    # NEW: Refactored _load_individual_task_or_group and helper methods          #
    ###############################################################################

    def _create_task_object(
        self,
        cfg: dict,
        task_name: str,
        yaml_path: Union[str, None],
    ) -> dict:
        """
        Build one task object from *cfg* and return it as ``{task_name: task}``.

        An "include" key in *cfg* is removed and resolved through
        ``load_yaml_config``; the keys already present in *cfg* override the
        included ones and the task name is kept as it was before the merge.
        ``self.metadata`` (if set) is merged over the config's own metadata.
        A config with a "class" key is instantiated through that class,
        anything else through ``ConfigurableTask``.
        """
        from lm_eval.api.task import ConfigurableTask, Task  # local import avoids cycle

        # ---- include handling ---------------------------------------------------
        if "include" in cfg:
            # Remember the name first so the merged include cannot replace it.
            preserved_name = cfg.get("task", task_name)
            include_spec = cfg.pop("include")
            inherited = load_yaml_config(
                yaml_path=Path(yaml_path) if yaml_path else None,
                yaml_config={"include": include_spec},
                mode="full" if yaml_path else "simple",
            )
            cfg = {**inherited, **cfg, "task": preserved_name}

        # ---- metadata merge -----------------------------------------------------
        own_metadata = cfg.get("metadata", {})
        cfg["metadata"] = (
            own_metadata if self.metadata is None else own_metadata | self.metadata
        )

        # ---- python-task vs YAML-task -------------------------------------------
        if not self._config_is_python_task(cfg):
            return {task_name: ConfigurableTask(config=cfg)}  # type: ignore

        cls = cfg["class"]
        task_obj: Task
        if self._class_has_config_in_constructor(cls):
            task_obj = cls(config=cfg)
        else:
            task_obj = cls()
        # make sure name propagates when the class inherits ConfigurableTask
        if isinstance(task_obj, ConfigurableTask):  # type: ignore
            task_obj.config.task = task_name

        return {task_name: task_obj}
Baber's avatar
Baber committed
771

772
773
774
    def _create_group_object(
        self,
        cfg: dict,
        parent_name: Union[str, None] = None,
    ) -> tuple[GroupConfig, list[Union[str, dict]]]:
        """
        Construct the GroupConfig for *cfg* and list the group's direct members.

        Args:
            cfg: Keyword arguments for ``GroupConfig``. When the manager carries
                metadata, ``cfg["metadata"]`` is overwritten in place with the
                union of its current value and ``self.metadata`` (manager
                values win on key clashes).
            parent_name: Not used by this method.

        Returns:
            ``(group_obj, members)``. Each string member that names a tag is
            replaced by the entries from ``self._get_tasklist``; every other
            member is passed through unchanged.
        """
        manager_metadata = self.metadata
        if manager_metadata is not None:
            cfg["metadata"] = cfg.get("metadata", {}) | manager_metadata

        group_obj = GroupConfig(**cfg)
        members: list[Union[str, dict]] = []
        # A falsy `task` (None, empty list) yields no members.
        for member in group_obj.task or ():
            is_tag = isinstance(member, str) and self._name_is_tag(member)
            if is_tag:
                members.extend(self._get_tasklist(member))
            else:
                members.append(member)
        return group_obj, members
Baber's avatar
Baber committed
793

794
795
796
    def _load_subtasks(
        self,
        subtasks: list[Union[str, dict]],
        parent_name: Union[str, GroupConfig, None],
        update_config: Union[dict, None],
    ) -> Mapping:
        """
        Load every entry of *subtasks* and merge the results into one dict.

        Entries are loaded from last to first. In the merged dict, keys appear
        in the order of *subtasks*, and when two entries produce the same key
        the value from the later entry is kept.
        """
        load_one = functools.partial(
            self._load_individual_task_or_group,
            parent_name=parent_name,
            update_config=update_config,
        )
        # Loading happens back-to-front; all loads finish before merging starts.
        loaded_back_to_front = [load_one(entry) for entry in reversed(subtasks)]

        merged: dict = {}
        # Walk front-to-back again so earlier entries fix the key position and
        # later entries supply the final value.
        for mapping in reversed(loaded_back_to_front):
            merged.update(mapping)
        return merged
Baber's avatar
Baber committed
807

808
809
    def _load_individual_task_or_group(
        self,
        payload: Union[str, dict],
        *,
        parent_name: Union[str, GroupConfig, None] = None,
        update_config: Union[dict, None] = None,
    ) -> Mapping:
        """
        Turn *payload* (str task/group/tag name **or** dict config) into a
        nested Mapping of {name_or_group_obj: task_obj | sub_mapping}.

        Args:
            payload: Registered name, or an inline config dict.
            parent_name: Enclosing group, if any. Used as the key into
                ``self.task_group_map`` for duplicate-name bookkeeping. Group
                loading passes the ``GroupConfig`` object itself here.
            update_config: Overrides applied to a string payload. When truthy,
                the name is re-dispatched as ``{"task": payload, **update_config}``.

        Returns:
            Mapping from task name (or group object) to the loaded task (or the
            mapping of the group's subtasks).

        Raises:
            ValueError: *payload* is a string that is not a registered task,
                group or tag.
            TypeError: *payload* is neither str nor dict, or is a dict that is
                not recognised as a task, group or python-task config.
        """

        # ------------------------------------------------------------------ STRING
        if isinstance(payload, str):
            # If caller supplied extra overrides, treat as dict immediately
            if update_config:
                return self._load_individual_task_or_group(
                    {"task": payload, **update_config},
                    parent_name=parent_name,
                )

            # ------------ registered TASK (YAML or python) -----------------
            if self._name_is_task(payload) or self._name_is_python_task(payload):
                yaml_path = self._get_yaml_path(payload)
                cfg = self._get_config(payload)

                # task_list configs: extract the per-task override ------------
                if "task_list" in cfg:
                    override = next(
                        (
                            entry
                            for entry in cfg["task_list"]
                            if isinstance(entry, dict) and entry.get("task") == payload
                        ),
                        None,
                    )
                    base = {k: v for k, v in cfg.items() if k != "task_list"}
                    # Always drop `task_list` and pin the task name, even when
                    # no matching entry is found, so the whole list is never
                    # handed to the task constructor.
                    cfg = {**base, **(override or {}), "task": payload}
                return self._create_task_object(cfg, payload, yaml_path)

            # ------------ registered GROUP ----------------------------------
            if self._name_is_group(payload):
                group_cfg = self._get_config(payload)
                grp_only = {k: v for k, v in group_cfg.items() if k in GROUP_ONLY_KEYS}
                grp_obj, subtasks = self._create_group_object(grp_only, parent_name)
                return {
                    grp_obj: self._load_subtasks(subtasks, grp_obj, update_config=None)
                }

            # ------------ registered TAG ------------------------------------
            if self._name_is_tag(payload):
                return self._process_tag_subtasks(payload, update_config=None)

            raise ValueError(f"Unknown task / group / tag name: {payload!r}")

        # ------------------------------------------------------------------- DICT
        if isinstance(payload, dict):
            # ------------------ simple 'task: name' dict --------------------
            if self._config_is_task(payload):
                name = payload["task"]
                # override existing registered YAML if exists
                if self._name_is_registered(name):
                    base_cfg = self._get_config(name)
                    yaml_path = self._get_yaml_path(name)
                    merged = {**base_cfg, **payload}
                else:
                    merged = payload
                    yaml_path = None

                # duplicate-naming guard when inside a group: the n-th config
                # whose name is a prefix of an already-recorded member is
                # loaded as "<name>-<n>"
                if parent_name is not None:
                    siblings = self.task_group_map[parent_name]
                    count = sum(1 for n in siblings if n.startswith(name))
                    if count:
                        name = f"{name}-{count}"
                    siblings.append(name)

                return self._create_task_object(merged, name, yaml_path)

            # ----------------- literal group dict (task: [...]) -------------
            if self._config_is_group(payload):
                grp_cfg = {k: v for k, v in payload.items() if k in GROUP_ONLY_KEYS}
                # every non-group key is pushed down to the subtasks as an override
                sub_override = {
                    k: v for k, v in payload.items() if k not in GROUP_ONLY_KEYS
                } or None
                grp_obj, subtasks = self._create_group_object(grp_cfg, parent_name)
                return {grp_obj: self._load_subtasks(subtasks, grp_obj, sub_override)}

            # ----------------- python-task dict ('class': …) ----------------
            if self._config_is_python_task(payload):
                name = payload["task"]
                return self._create_task_object(payload, name, yaml_path=None)

            # A dict that matched none of the shapes above. Kept as TypeError
            # (the type raised before) but with a message that names the real
            # problem instead of "expected str | dict, got dict".
            raise TypeError(
                "_load_individual_task_or_group could not interpret the dict payload "
                f"as a task, group or python-task config (keys: {list(payload)})"
            )

        raise TypeError(
            f"_load_individual_task_or_group expected str | dict, got {type(payload)}"
        )
910

Baber's avatar
Baber committed
911
    def load_task_or_group(
        self, task_list: Optional[Union[str, list[str]]] = None
    ) -> dict:
        """
        Load multiple tasks or groups from a list of names.

        This is the main entry point for loading tasks. Each item is handed to
        ``_load_individual_task_or_group`` and the per-item results are merged.

        Args:
            task_list: Single task name or list of task names to load.
                      Can include individual tasks, groups, and tags.
                      A single inline config dict is also accepted and treated
                      as one item. ``None`` (the default) loads nothing.

        Returns:
            Dictionary mapping task/group names to loaded task objects.
            Results from all requested items are merged into a single dict.
            If two items produce the same key, the value from the item that
            appears first in *task_list* is kept.

        Example:
            Load multiple tasks::

                tasks = tm.load_task_or_group(["hellaswag", "arc_easy"])
                # Returns: {"hellaswag": Task1, "arc_easy": Task2}

            Load a group::

                tasks = tm.load_task_or_group("arc_group")
                # Returns: {"arc_group": {"arc_easy": Task1, "arc_challenge": Task2}}
        """
        if task_list is None:
            # The declared default must not blow up inside `map`.
            return {}
        if isinstance(task_list, (str, dict)):
            # A bare dict is ONE inline config; iterating it would yield its
            # keys ("task", "num_fewshot", ...) as if they were task names.
            task_list = [task_list]

        all_loaded_tasks = dict(
            collections.ChainMap(
                *map(self._load_individual_task_or_group, task_list)
            )
        )
        return all_loaded_tasks

Baber's avatar
nit  
Baber committed
953
    def load_config(self, config: dict) -> Mapping:
        """
        Load whatever an inline configuration dictionary describes.

        Thin wrapper that forwards *config* to
        ``_load_individual_task_or_group`` with no parent group and no
        overrides.

        Args:
            config: Inline configuration dictionary

        Returns:
            The mapping produced by ``_load_individual_task_or_group``

        Example:
            >>> config = {"task": "hellaswag", "num_fewshot": 5}
            >>> task_dict = tm.load_config(config)
        """
        loaded = self._load_individual_task_or_group(config)
        return loaded

Baber's avatar
nit  
Baber committed
969
    def _get_task_and_group(self, task_dir: Union[str, Path]) -> dict[str, dict]:
        """
        Scan a directory for task configurations and build an index.

        Every YAML file yielded by ``iter_yaml_files`` is parsed in 'simple'
        mode without include resolution and classified as one of:
        - Python class-based task (registered with type 'python_task')
        - Group config (indexed under its 'group' name)
        - Regular task config (registered with type 'task')
        - Task list config (each dict entry with a 'task' key is registered)
        Files that fail to load, or match none of these shapes, are skipped
        with a debug log message.

        Args:
            task_dir: Directory path to scan for YAML task configurations

        Returns:
            Dictionary mapping names to metadata dictionaries. Entries written
            directly by this method contain:
            - 'type': 'group' or 'tag'
            - 'yaml_path': Path of the source YAML file as a string, or -1 for
              tag entries (tags have no file of their own)
            - 'task': -1 for groups (members are resolved when the group is
              loaded), or the list of member task names for tags
            Task and python-task entries are created by ``self._register_task``.
        """

        def _populate_tags_and_groups(
            config: dict, task: str, tasks_and_groups: dict[str, dict]
        ) -> None:
            """
            Record *task* under every tag listed in ``config["tag"]``.

            Args:
                config: Task configuration dictionary
                task: Name of the task being processed
                tasks_and_groups: Master index to update with tag information
            """
            # TODO: remove group in next release
            if "tag" not in config:
                return

            attr_list = config["tag"]
            if isinstance(attr_list, str):
                attr_list = [attr_list]

            for tag in attr_list:
                if tag not in tasks_and_groups:
                    tasks_and_groups[tag] = {
                        "type": "tag",
                        "task": [task],
                        "yaml_path": -1,
                    }
                elif tasks_and_groups[tag]["type"] != "tag":
                    eval_logger.info(
                        f"The tag '{tag}' is already registered as a group, this tag will not be registered. "
                        "This may affect tasks you want to call."
                    )
                    # Skip only the clashing tag; the task's remaining tags
                    # are unrelated and must still be indexed.
                    continue
                else:
                    tasks_and_groups[tag]["task"].append(task)

        # A plain dict: no default factory is needed for the index.
        tasks_and_groups: dict[str, dict] = {}
        task_dir_path = Path(task_dir)

        for yaml_path in iter_yaml_files(task_dir_path):
            try:
                config = load_yaml_config(
                    yaml_path, mode="simple", resolve_includes=False
                )
            except (FileNotFoundError, YAMLError, OSError) as err:
                eval_logger.debug(f"File {yaml_path} could not be loaded ({err})")
                continue
            if self._config_is_python_task(config):
                # This is a python class config
                task = config["task"]
                self._register_task(
                    task,
                    "python_task",
                    str(yaml_path),
                    tasks_and_groups,
                    config,
                    _populate_tags_and_groups,
                )
            elif self._config_is_group(config):
                # This is a group config
                tasks_and_groups[config["group"]] = {
                    "type": "group",
                    # -1 signals that the task list is not needed for
                    # indexing; it is loaded when the group is called.
                    "task": -1,
                    "yaml_path": str(yaml_path),
                }
            elif self._config_is_task(config):
                # This is a task config
                task = config["task"]
                self._register_task(
                    task,
                    "task",
                    str(yaml_path),
                    tasks_and_groups,
                    config,
                    _populate_tags_and_groups,
                )
            elif self._config_is_task_list(config):
                # This is a task_list config: every dict entry that names a
                # task is registered against the same file and file-level config
                for task_entry in config["task_list"]:
                    if isinstance(task_entry, dict) and "task" in task_entry:
                        task_name = task_entry["task"]
                        self._register_task(
                            task_name,
                            "task",
                            str(yaml_path),
                            tasks_and_groups,
                            config,
                            _populate_tags_and_groups,
                        )
            else:
                eval_logger.debug(f"File {yaml_path} could not be loaded")

        return tasks_and_groups
lintangsutawika's avatar
lintangsutawika committed
1109

1110

Baber's avatar
nit  
Baber committed
1111
def get_task_name_from_config(task_config: dict[str, str]) -> str:
    """
    Derive a task name from a configuration dictionary.

    Lookup order:
    1. the ``task`` value, when the key is present;
    2. ``<dataset_path>_<dataset_name>``, when ``dataset_name`` is present;
    3. ``<dataset_path>`` alone.

    Args:
        task_config: Task configuration dictionary

    Returns:
        String name for the task

    Raises:
        KeyError: If ``task`` is absent and ``dataset_path`` is missing

    Example:
        >>> config = {"task": "hellaswag", "num_fewshot": 5}
        >>> get_task_name_from_config(config)
        'hellaswag'

        >>> config = {"dataset_path": "custom", "dataset_name": "mytask"}
        >>> get_task_name_from_config(config)
        'custom_mytask'
    """
    if "task" in task_config:
        return task_config["task"]

    if "dataset_name" not in task_config:
        return f"{task_config['dataset_path']}"
    return f"{task_config['dataset_path']}_{task_config['dataset_name']}"
lintangsutawika's avatar
lintangsutawika committed
1139

1140

Baber's avatar
nit  
Baber committed
1141
def get_task_name_from_object(task_object: Union["ConfigurableTask", "Task"]) -> str:
    """
    Extract the name from an instantiated task object.

    Objects that expose a ``config`` attribute report the ``"task"`` entry of
    their configuration. The private ``_config`` attribute is used when it is
    present, otherwise the public ``config`` attribute that was just checked
    for. Other objects fall back to ``EVAL_HARNESS_NAME`` and finally to
    their class name.

    Args:
        task_object: An instantiated task object

    Returns:
        String name of the task

    Example:
        >>> task = ConfigurableTask(config={"task": "hellaswag"})
        >>> get_task_name_from_object(task)
        'hellaswag'
    """
    if hasattr(task_object, "config"):
        # The guard tests `config`, so do not assume `_config` exists too.
        config = getattr(task_object, "_config", None)
        if config is None:
            config = task_object.config
        return config["task"]

    # TODO: scrap this
    # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
    return (
        task_object.EVAL_HARNESS_NAME
        if hasattr(task_object, "EVAL_HARNESS_NAME")
        else type(task_object).__name__
    )

1170

Baber's avatar
nit  
Baber committed
1171
def _check_duplicates(task_dict: dict[str, list[str]]) -> None:
Baber's avatar
Baber committed
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
    """
    Validate that no tasks appear in multiple groups simultaneously.

    Helper function used to prevent conflicts when multiple groups claim
    the same constituent task. This could lead to ambiguous configuration
    like conflicting num_fewshot values.

    Args:
        task_dict: Dictionary mapping group names to lists of subtask names

    Raises:
        ValueError: If any tasks appear in multiple groups

    Example:
        >>> task_dict = {
        ...     "group1": ["task_a", "task_b"],
        ...     "group2": ["task_b", "task_c"]  # task_b appears twice!
        ... }
        >>> _check_duplicates(task_dict)  # Raises ValueError
Lintang Sutawika's avatar
Lintang Sutawika committed
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
    """
    subtask_names = []
    for key, value in task_dict.items():
        subtask_names.extend(value)

    duplicate_tasks = {
        task_name for task_name in subtask_names if subtask_names.count(task_name) > 1
    }

    # locate the potentially problematic groups that seem to 'compete' for constituent subtasks
    competing_groups = [
        group
        for group in task_dict.keys()
        if len(set(task_dict[group]).intersection(duplicate_tasks)) > 0
    ]

    if len(duplicate_tasks) > 0:
        raise ValueError(
            f"Found 1 or more tasks while trying to call get_task_dict() that were members of more than 1 called group: {list(duplicate_tasks)}. Offending groups: {competing_groups}. Please call groups which overlap their constituent tasks in separate evaluation runs."
        )


1213
def get_task_dict(
    task_name_list: Union[str, list[Union[str, dict, "Task"]]],
    task_manager: Optional[TaskManager] = None,
) -> dict[str, Union["ConfigurableTask", "Task"]]:
    """
    Create a dictionary of task objects from mixed input types.

    This is the main public API for loading tasks. It accepts various input
    formats (names, configs, objects) and returns a unified dictionary of
    instantiated task objects ready for evaluation.

    The function handles:
    - String task names (looked up via TaskManager)
    - Configuration dictionaries (processed as inline configs)
    - Pre-instantiated Task objects (used as-is)
    - Validation to prevent conflicting group memberships

    Args:
        task_name_list: Mixed list of task specifications:
                       - str: Task name to look up
                       - dict: Inline task configuration
                       - Task: Pre-instantiated task object
        task_manager: TaskManager instance for name resolution.
                     If None, creates a default TaskManager.

    Returns:
        Dictionary mapping task names to instantiated task objects.
        All tasks are ready for evaluation.

    Raises:
        TypeError: If task_name_list contains unsupported types
        ValueError: If two specifications resolve to the same name, or if
            there are conflicting group memberships

    Example:
        Mixed input types::

            tasks = get_task_dict([
                "hellaswag",                              # lookup by name
                {"task": "arc_easy", "num_fewshot": 5},   # inline config
                pre_existing_task_object                  # direct object
            ])

        Simple case::

            tasks = get_task_dict("hellaswag")
            # Returns: {"hellaswag": ConfigurableTask(...)}

        With custom TaskManager::

            tm = TaskManager(include_path="/custom/tasks")
            tasks = get_task_dict(["custom_task"], task_manager=tm)
    """
    from lm_eval.api.task import Task

    # Normalize input to list
    if isinstance(task_name_list, str):
        task_name_list = [task_name_list]
    elif not isinstance(task_name_list, list):
        raise TypeError(
            f"Expected a 'str' or 'list' but received {type(task_name_list)}."
        )

    # Validate list items
    if not all(isinstance(task, (str, dict, Task)) for task in task_name_list):
        raise TypeError(
            "Expected all list items to be of types 'str', 'dict', or 'Task', but at least one entry did not match."
        )

    # Ensure we have a task manager
    if task_manager is None:
        task_manager = TaskManager()

    # Process all items
    final_task_dict = {}
    for task_spec in task_name_list:
        if isinstance(task_spec, Task):
            # Pre-instantiated task object
            task_name = get_task_name_from_object(task_spec)
            if task_name in final_task_dict:
                raise ValueError(f"Duplicate task name: {task_name}")
            final_task_dict[task_name] = task_spec
        else:
            # String or dict. Always pass a one-element list: a bare dict
            # would be iterated by load_task_or_group key by key, turning
            # "task", "num_fewshot", ... into bogus task names.
            result = task_manager.load_task_or_group([task_spec])
            # Check for duplicate names
            for name in result:
                if name in final_task_dict:
                    raise ValueError(f"Duplicate task name: {name}")
            final_task_dict.update(result)

    # Check for conflicting group memberships
    _check_duplicates(get_subtask_list(final_task_dict))

    return final_task_dict