__init__.py 45.9 KB
Newer Older
Baber's avatar
Baber committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
"""
Task Management Module for LM Evaluation Harness.

This module provides comprehensive task discovery, loading, and management functionality
for the LM Evaluation Harness. It handles YAML configuration parsing with include support,
dynamic function importing, and task indexing across multiple directories.

Key Components:
- TaskManager: Main class for task discovery and management
- YAML configuration loading with !function tag support
- Task, group, and tag indexing
- Include resolution with cycle detection
- Caching for performance optimization

Example:
    Basic usage::

        task_manager = TaskManager()
        all_tasks = task_manager.all_tasks
        task_config = task_manager._get_config("hellaswag")

    Custom task paths::

        task_manager = TaskManager(
            include_path="/path/to/custom/tasks",
            include_defaults=True
        )
"""

30
import collections
31
32
import functools
import importlib.util
33
import inspect
34
import logging
35
import sys
36
from functools import partial
37
from glob import iglob
Baber's avatar
Baber committed
38
from pathlib import Path
Baber's avatar
nit  
Baber committed
39
40
41
42
43
44
45
46
47
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Generator,
    Mapping,
    Optional,
    Union,
)
48
49
50

import yaml
from yaml import YAMLError
&'s avatar
& committed
51

52
from lm_eval.api.group import GroupConfig
Lintang Sutawika's avatar
Lintang Sutawika committed
53
from lm_eval.evaluator_utils import get_subtask_list
Baber's avatar
nit  
Baber committed
54
55
56
57
58
from lm_eval.utils import pattern_match, setup_logging


if TYPE_CHECKING:
    from lm_eval.api.task import ConfigurableTask, Task
Lintang Sutawika's avatar
Lintang Sutawika committed
59

Baber's avatar
nit  
Baber committed
60
eval_logger = logging.getLogger(__name__)
Lintang Sutawika's avatar
Lintang Sutawika committed
61

Baber's avatar
Baber committed
62
#: List of configuration keys that are specific to groups only
Lintang Sutawika's avatar
Lintang Sutawika committed
63
GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys())
Baber's avatar
Baber committed
64
65

#: Base YAML loader class - uses C loader if available for performance
Baber's avatar
nit  
Baber committed
66
# NOTE(review): in PyYAML the C counterpart of FullLoader appears to be
# CFullLoader; CLoader mirrors the unrestricted yaml.Loader. The two branches
# below may therefore accept different tag sets — TODO confirm this is intended.
_Base = yaml.CLoader if getattr(yaml, "__with_libyaml__", False) else yaml.FullLoader
Baber's avatar
Baber committed
67
68

#: Directory names to ignore during task discovery
Baber's avatar
nit  
Baber committed
69
70
71
72
_IGNORE_DIRS = (
    "__pycache__",
    ".ipynb_checkpoints",
)
73

lintangsutawika's avatar
lintangsutawika committed
74

Baber's avatar
nit  
Baber committed
75
def ignore_constructor(loader: yaml.Loader, node: yaml.Node) -> None:
    """Stand-in ``!function`` constructor that discards the tagged node.

    Both arguments are accepted only to satisfy the YAML constructor
    signature; the tagged value is always read back as ``None``.
    """
    return
78
79


Baber's avatar
nit  
Baber committed
80
@functools.lru_cache(maxsize=2048)  # one Loader class per (yaml_dir, simple) pair
def _make_loader(yaml_dir: Path, simple: bool = False) -> type[yaml.Loader]:
    """
    Build a YAML Loader subclass whose ``!function`` tag is bound to *yaml_dir*.

    yaml_dir
        Directory of the YAML file being parsed.  It is captured so that a
        ``!function my_utils.some_fn`` value is looked up in
        ``yaml_dir / "my_utils.py"``.
    simple
        When True the ``!function`` tag is mapped to `ignore_constructor`
        (values become ``None``) instead of importing anything.
    """

    class Loader(_Base):
        """Dynamically-generated loader that knows its base directory."""

    if simple:
        function_constructor = ignore_constructor
    else:

        def function_constructor(ld, node, _dir=yaml_dir):
            # _dir is bound at definition time, so each Loader keeps its own base.
            return _import_function(ld.construct_scalar(node), base_path=_dir)

    # The constructor is registered on this Loader subclass only.
    yaml.add_constructor("!function", function_constructor, Loader=Loader)
    return Loader


Baber's avatar
nit  
Baber committed
114
@functools.lru_cache(maxsize=None)  # memoise per (qualname, base_path)
def _import_function(qualname: str, *, base_path: Path) -> Callable:
    """
    Dynamically import a function from a Python file located under base_path.

    This is what backs the ``!function`` YAML tag.  The part of *qualname*
    before the last dot is turned into a file path (dots become ``/``) and
    the part after it is fetched from the loaded module with ``getattr``.

    Args:
        qualname: Qualified function name like "my_module.my_function"
        base_path: Base directory for resolving the module file

    Returns:
        The attribute named by the last component of *qualname*

    Raises:
        ValueError: If qualname doesn't contain a module part
        ImportError: If no import spec/loader can be built for the file
        AttributeError: If the module has no attribute with that name

    Example:
        >>> func = _import_function("utils.custom_metric", base_path=Path("/tasks"))
        >>> result = func(predictions, references)
    """
    mod_path, _, func_name = qualname.rpartition(".")
    if not mod_path:
        raise ValueError(f"{qualname!r} has no module part")

    file_path = base_path / f"{mod_path.replace('.', '/')}.py"
    module_name = f"_yaml_dynamic.{hash(file_path)}_{file_path.stem}"

    mod = sys.modules.get(module_name)
    if mod is None:
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        if spec is None or spec.loader is None:
            raise ImportError(
                f"Cannot build an import spec for {qualname!r} from {file_path}"
            )
        mod = importlib.util.module_from_spec(spec)
        # Register *before* executing: code that looks up its own module in
        # sys.modules during import would otherwise not find it.
        sys.modules[module_name] = mod
        try:
            spec.loader.exec_module(mod)
        except BaseException:
            # Never leave a half-initialised module behind for later look-ups.
            sys.modules.pop(module_name, None)
            raise
    return getattr(mod, func_name)


Baber's avatar
Baber committed
151
@functools.lru_cache(maxsize=4096)
def _parse_yaml_file(path: Path, mode: str) -> dict:
    """
    Read one YAML file using the loader that matches *mode*.

    The result is memoised per (path, mode), so every caller receives the
    same object back.

    Args:
        path: Path to the YAML file
        mode: Parsing mode; "simple" selects the loader that ignores
            !function tags, any other value selects the importing loader

    Returns:
        Whatever ``yaml.load`` produces for the file
    """
    # NOTE(review): the non-simple loader imports Python files named by
    # !function tags, so only trusted YAML should be parsed here.
    use_simple = mode == "simple"
    loader_cls = _make_loader(path.parent, simple=use_simple)
    with path.open("rb") as stream:
        parsed = yaml.load(stream, Loader=loader_cls)
    return parsed


Baber's avatar
nit  
Baber committed
168
@functools.lru_cache(maxsize=4096)
def _get_cached_config(yaml_path: Path, mode: str) -> dict:
    """
    Load a YAML config with all ``include`` entries merged in, memoised.

    The returned dict is the cached object itself: callers that intend to
    modify it must copy it first.

    Args:
        yaml_path: Path to the YAML file
        mode: Parsing mode, forwarded to `_parse_yaml_file`

    Returns:
        The merged configuration (keys of the file itself win over includes)

    Raises:
        ValueError: If the include chain loops back onto a file already
            being resolved
    """
    return _resolve_include_chain(yaml_path, mode, ())


def _resolve_include_chain(
    yaml_path: Path, mode: str, chain: tuple[Path, ...]
) -> dict:
    """Merge *yaml_path* with its includes; *chain* holds the files above it."""
    if yaml_path in chain:
        raise ValueError(f"Include cycle detected at {yaml_path}")

    # Work on a copy: _parse_yaml_file is memoised, so popping "include" from
    # its result would silently strip the key for every later caller.
    # An empty YAML file parses to None and is treated as an empty config.
    yaml_config = dict(_parse_yaml_file(yaml_path, mode) or {})
    yaml_dir = yaml_path.parent

    include = yaml_config.pop("include", None)
    if not include:
        return yaml_config

    include_paths = include if isinstance(include, list) else [include]
    final_cfg: dict = {}
    child_chain = chain + (yaml_path,)

    # Reversed so that an include listed earlier overrides one listed later.
    for inc in reversed(include_paths):
        if inc is None:
            continue
        inc_path = Path(inc)
        if not inc_path.is_absolute():
            inc_path = (yaml_dir / inc_path).resolve()
        final_cfg.update(_resolve_include_chain(inc_path, mode, child_chain))

    final_cfg.update(yaml_config)  # local keys win
    return final_cfg


197
198
def load_yaml_config(
    yaml_path: Union[Path, str, None] = None,
    yaml_config: Optional[dict] = None,
    yaml_dir: Optional[Path] = None,
    mode: str = "full",
    *,
    _seen: Optional[set[tuple[Path, str]]] = None,
    resolve_includes: bool = True,
) -> dict:
    """
    Parse a YAML config with optional include handling.

    The returned dict is always a fresh top-level dict: neither the
    *yaml_config* passed in nor any memoised parse result is modified.
    Nested values are still shared with those objects.

    Parameters
    ----------
    yaml_path
        Path to the main YAML file.  Needed unless *yaml_config* is
        supplied directly (e.g. by tests).
    yaml_config
        Pre-parsed dict to use instead of reading *yaml_path*.
    yaml_dir
        Base directory for resolving relative include paths.  Defaults
        to `yaml_path.parent`.  Only required when a relative include
        actually has to be resolved.
    mode
        "full"  – honour  !function  tags
        "simple" – ignore !function  (faster).
    _seen
        **Internal** recursion set: tuples of (absolute-path, mode) for the
        files currently being resolved above this call.
        Prevents include cycles such as  A → B → A.
    resolve_includes
        When False the config is returned as parsed, with any ``include``
        key left in place and unresolved.

    Raises
    ------
    ValueError
        If neither *yaml_path* nor *yaml_config* is given, if an include
        cycle is detected, or if a relative include cannot be resolved
        because no base directory is known.
    """
    if yaml_config is None and yaml_path is None:
        raise ValueError("load_yaml_config needs either yaml_path or yaml_config")

    # ------------------------------------------------------------------ cycle guard
    is_top_level = _seen is None
    ancestors: set[tuple[Path, str]] = set() if _seen is None else _seen
    if yaml_path is not None:
        yaml_path = Path(yaml_path).expanduser().resolve()

        # ---------- fast-path: use LRU cached function ----------
        # Only for external (top-level) calls; recursive calls stay on the
        # slow path so that the ancestor set is actually consulted.
        if yaml_config is None and resolve_includes and is_top_level:
            # Copy so callers cannot poison the cache by editing the result.
            return dict(_get_cached_config(yaml_path, mode))

        key = (yaml_path, mode)
        if key in ancestors:
            raise ValueError(f"Include cycle detected at {yaml_path}")
        # New set per branch: a file included from two siblings (a "diamond")
        # is legitimate and must not be reported as a cycle.
        ancestors = ancestors | {key}

    # ------------------------------------------------------------------ load / parse
    if yaml_config is None:  # ordinary path-based load
        # An empty YAML file parses to None; treat it as an empty config.
        yaml_config = dict(_parse_yaml_file(yaml_path, mode) or {})
    else:
        yaml_config = dict(yaml_config)  # never mutate the caller's dict

    if not resolve_includes:
        return yaml_config

    # ------------------------------------------------------------------ handle include
    include = yaml_config.pop("include", None)
    if not include:
        return yaml_config

    if yaml_dir is None and yaml_path is not None:
        yaml_dir = yaml_path.parent

    include_paths = include if isinstance(include, list) else [include]
    final_cfg: dict = {}

    # Reversed so that an include listed earlier overrides one listed later.
    for inc in reversed(include_paths):
        if inc is None:  # guard against explicit nulls
            continue
        inc_path = Path(inc)
        if not inc_path.is_absolute():
            if yaml_dir is None:
                raise ValueError(
                    f"Cannot resolve relative include {inc!r}: "
                    "yaml_dir must be set by caller or deduced from yaml_path"
                )
            inc_path = (yaml_dir / inc_path).resolve()
        included = load_yaml_config(
            yaml_path=inc_path,
            mode=mode,
            yaml_dir=inc_path.parent,
            _seen=ancestors,  # <-- pass ancestors downward
        )
        final_cfg.update(included)

    final_cfg.update(yaml_config)  # local keys win
    return final_cfg


278
def iter_yaml_files(root: Path, ignore=_IGNORE_DIRS) -> Generator[Path, Any, None]:
    """
    Recursively iterate over all ``*.yaml`` files in a directory tree.

    A file is skipped when *any* component of its path relative to *root*
    is listed in *ignore*, so ignored directories are excluded at every
    depth and not only directly under *root*.

    Args:
        root: Root directory to search for YAML files
        ignore: Path component names to exclude

    Yields:
        ``root / relative_path`` for each discovered YAML file

    Example:
        >>> for yaml_file in iter_yaml_files(Path("tasks")):
        ...     print(f"Found task config: {yaml_file}")
    """
    for rel in iglob("**/*.yaml", root_dir=root, recursive=True):
        if any(part in ignore for part in Path(rel).parts):
            continue
        yield root / rel
Lintang Sutawika's avatar
Lintang Sutawika committed
299

300

301
class TaskManager:
Baber's avatar
Baber committed
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
    """
    Central manager for task discovery, indexing, and loading.

    TaskManager scans directories for YAML task configurations and maintains
    an index of all available tasks, groups, and tags. It provides methods
    for listing, filtering, and loading tasks with their configurations.

    The manager supports:
    - Automatic discovery from default lm_eval/tasks/ directory
    - Custom task directories via include_path
    - Task grouping and tagging
    - Configuration inheritance via YAML includes
    - Caching for performance

    Attributes:
        include_path: Additional directories to search for tasks
        metadata: Global metadata to inject into all task configs
        task_group_map: Mapping of tasks to their parent groups
320

Baber's avatar
Baber committed
321
322
323
324
325
326
327
328
329
330
331
332
333
334
    Example:
        Basic usage::

            tm = TaskManager()
            print(f"Found {len(tm.all_tasks)} tasks")
            hellaswag_config = tm._get_config("hellaswag")

        With custom tasks::

            tm = TaskManager(
                include_path="/my/custom/tasks",
                verbosity="INFO"
            )
            custom_tasks = [t for t in tm.all_tasks if "custom" in t]
335
336
    """

337
338
    def __init__(
        self,
Lintang Sutawika's avatar
Lintang Sutawika committed
339
        verbosity: Optional[str] = None,
Baber's avatar
nit  
Baber committed
340
        include_path: Optional[Union[str, Path, list[Union[str, Path]]]] = None,
341
        include_defaults: bool = True,
342
        metadata: Optional[dict[str, dict[str, Any]]] = None,
343
    ) -> None:
Baber's avatar
Baber committed
344
345
346
347
348
349
350
351
352
353
        """
        Initialize the TaskManager.

        Args:
            verbosity: Logging verbosity level (DEBUG, INFO, WARNING, ERROR)
            include_path: Additional path(s) to search for tasks. Can be a single
                         path or list of paths.
            include_defaults: Whether to include default tasks from lm_eval/tasks/
            metadata: Global metadata dictionary to inject into all task configs
        """
Lintang Sutawika's avatar
Lintang Sutawika committed
354
        if verbosity is not None:
Baber's avatar
nit  
Baber committed
355
            setup_logging(verbosity)
356
        self.include_path = include_path
Baber Abbasi's avatar
Baber Abbasi committed
357
        self.metadata = metadata
358
359
360
        self._task_index = self.initialize_tasks(
            include_path=include_path, include_defaults=include_defaults
        )
361
        self._all_tasks = sorted(list(self._task_index.keys()))
362

363
364
365
366
        self._all_groups = sorted(
            [x for x in self._all_tasks if self._task_index[x]["type"] == "group"]
        )
        self._all_subtasks = sorted(
367
368
369
370
371
            [
                x
                for x in self._all_tasks
                if self._task_index[x]["type"] in ["task", "python_task"]
            ]
372
373
374
375
376
        )
        self._all_tags = sorted(
            [x for x in self._all_tasks if self._task_index[x]["type"] == "tag"]
        )

377
        self.task_group_map = collections.defaultdict(list)
378

379
380
    def initialize_tasks(
        self,
Baber's avatar
nit  
Baber committed
381
        include_path: Optional[Union[str, Path, list[Union[str, Path]]]] = None,
382
        include_defaults: bool = True,
Baber Abbasi's avatar
Baber Abbasi committed
383
384
    ) -> dict[str, dict]:
        """Creates a dictionary of tasks indexes.
385

Baber's avatar
nit  
Baber committed
386
        :param include_path: Union[str, list] = None
387
388
389
390
            An additional path to be searched for tasks recursively.
            Can provide more than one such path as a list.
        :param include_defaults: bool = True
            If set to false, default tasks (those in lm_eval/tasks/) are not indexed.
Baber Abbasi's avatar
Baber Abbasi committed
391
        return
Baber's avatar
nit  
Baber committed
392
            dictionary of task names as key and task metadata
393
        """
394
        if include_defaults:
Baber's avatar
Baber committed
395
            all_paths = [Path(__file__).parent]
396
397
        else:
            all_paths = []
398
        if include_path is not None:
Baber's avatar
Baber committed
399
            if isinstance(include_path, (str, Path)):
400
                include_path = [include_path]
Baber's avatar
Baber committed
401
402
            # Convert all paths to Path objects
            all_paths.extend(Path(p) for p in include_path)
403

404
405
406
407
        task_index = {}
        for task_dir in all_paths:
            tasks = self._get_task_and_group(task_dir)
            task_index = {**tasks, **task_index}
lintangsutawika's avatar
format  
lintangsutawika committed
408

409
410
411
        return task_index

    @property
Baber's avatar
nit  
Baber committed
412
    def all_tasks(self) -> list[str]:
Baber's avatar
Baber committed
413
        """Get sorted list of all task names (tasks, groups, and tags)."""
414
415
        return self._all_tasks

416
    @property
Baber's avatar
nit  
Baber committed
417
    def all_groups(self) -> list[str]:
Baber's avatar
Baber committed
418
        """Get sorted list of all group names."""
419
420
421
        return self._all_groups

    @property
Baber's avatar
nit  
Baber committed
422
    def all_subtasks(self) -> list[str]:
Baber's avatar
Baber committed
423
        """Get sorted list of all individual task names (excludes groups and tags)."""
424
425
426
        return self._all_subtasks

    @property
Baber's avatar
nit  
Baber committed
427
    def all_tags(self) -> list[str]:
Baber's avatar
Baber committed
428
        """Get sorted list of all tag names."""
429
430
        return self._all_tags

431
    @property
Baber's avatar
nit  
Baber committed
432
    def task_index(self) -> dict[str, dict[str, Union[str, int, list[str]]]]:
Baber's avatar
Baber committed
433
        """Get the complete task index with metadata for all tasks."""
434
435
        return self._task_index

436
    def list_all_tasks(
        self,
        list_groups: bool = True,
        list_tags: bool = True,
        list_subtasks: bool = True,
    ) -> str:
        """
        Return a Markdown table (as a string) listing groups, tags and/or subtasks
        known to this TaskManager.  Safe for configs whose yaml_path is -1 and for
        task configs whose `include:` is a list.

        Only the tables that were asked for are built, so no task YAML is read
        when ``list_subtasks`` is False.

        Args:
            list_groups: Include the "Group / Config Location" table
            list_tags: Include the "Tag" table
            list_subtasks: Include the "Task / Config Location / Output Type" table

        Returns:
            The requested tables, each followed by a newline, after one
            leading newline.
        """
        from pytablewriter import MarkdownTableWriter

        # ------------------------------------------------------------------ helpers
        def sanitize_path(path) -> str:
            # print a relative path for anything inside lm_eval/tasks/
            path_str = str(path)
            if "lm_eval/tasks/" in path_str:
                return "lm_eval/tasks/" + path_str.split("lm_eval/tasks/")[-1]
            return path_str

        def first_output_type_from_includes(cfg: dict, base: Path) -> str:
            """Walk cfg['include'] (string or list) and return the output_type
            of the first include that specifies one."""
            inc_raw = cfg.get("include")
            if not inc_raw:
                return ""

            inc_list = inc_raw if isinstance(inc_raw, list) else [inc_raw]
            for inc in inc_list:
                if not inc:
                    continue
                inc_path = Path(inc)
                if not inc_path.is_absolute():  # treat as relative include
                    inc_path = base.parent / inc_path
                try:
                    inc_cfg = load_yaml_config(inc_path, mode="simple")
                except FileNotFoundError:
                    continue
                if "output_type" in inc_cfg:
                    return inc_cfg["output_type"]
            return ""

        def render(headers: list[str], rows: list[list[str]]) -> str:
            table = MarkdownTableWriter()
            table.headers = headers
            table.value_matrix = rows
            return table.dumps()

        def location_of(name: str) -> str:
            raw_path = self.task_index[name]["yaml_path"]
            return "---" if raw_path == -1 else sanitize_path(raw_path)

        # ------------------------------------------------------------- final string
        parts: list[str] = ["\n"]

        if list_groups:
            group_rows = [[g, location_of(g)] for g in self.all_groups]
            parts.append(render(["Group", "Config Location"], group_rows))
            parts.append("\n")

        if list_tags:
            parts.append(render(["Tag"], [[t] for t in self.all_tags]))
            parts.append("\n")

        if list_subtasks:
            st_values: list[list[str]] = []
            for t in self.all_subtasks:
                raw_path = self.task_index[t]["yaml_path"]
                if raw_path == -1:
                    # no YAML file recorded for this entry
                    st_values.append([t, "---", ""])
                    continue

                path_obj = Path(raw_path)
                # load minimal YAML to discover output_type
                cfg = load_yaml_config(path_obj, mode="simple")
                if "output_type" in cfg:
                    output_type = cfg["output_type"]
                else:
                    # Fallback for configs that still carry an `include` key.
                    output_type = first_output_type_from_includes(cfg, path_obj)
                st_values.append([t, sanitize_path(path_obj), output_type])

            parts.append(
                render(["Task", "Config Location", "Output Type"], st_values)
            )
            parts.append("\n")

        return "".join(parts)
536

Baber Abbasi's avatar
Baber Abbasi committed
537
    def match_tasks(self, task_list: list[str]) -> list[str]:
        """Return ``pattern_match`` of *task_list* against every known name."""
        candidates = self.all_tasks
        return pattern_match(task_list, candidates)
540

Baber Abbasi's avatar
Baber Abbasi committed
541
    def _name_is_registered(self, name: str) -> bool:
Baber's avatar
Baber committed
542
        """Check if a name is registered in the task index."""
543
        return name in self.all_tasks
544

Baber Abbasi's avatar
Baber Abbasi committed
545
    def _name_is_task(self, name: str) -> bool:
Baber's avatar
Baber committed
546
        """Check if a name refers to an individual task (not group or tag)."""
547
548
549
        return (
            self._name_is_registered(name) and self.task_index[name]["type"] == "task"
        )
Lintang Sutawika's avatar
Lintang Sutawika committed
550

Baber Abbasi's avatar
Baber Abbasi committed
551
    def _name_is_tag(self, name: str) -> bool:
Baber's avatar
Baber committed
552
        """Check if a name refers to a tag."""
553
        return self._name_is_registered(name) and self.task_index[name]["type"] == "tag"
554

Baber Abbasi's avatar
Baber Abbasi committed
555
    def _name_is_group(self, name: str) -> bool:
Baber's avatar
Baber committed
556
        """Check if a name refers to a group."""
557
558
559
        return (
            self._name_is_registered(name) and self.task_index[name]["type"] == "group"
        )
560

Baber Abbasi's avatar
Baber Abbasi committed
561
    def _name_is_python_task(self, name: str) -> bool:
Baber's avatar
Baber committed
562
        """Check if a name refers to a Python-defined task."""
563
564
565
566
        return (
            self._name_is_registered(name)
            and self.task_index[name]["type"] == "python_task"
        )
567

Baber Abbasi's avatar
Baber Abbasi committed
568
    def _config_is_task(self, config: dict) -> bool:
Baber's avatar
Baber committed
569
        """Check if a config dictionary defines a single task."""
570
        return "task" in config and isinstance(config["task"], str)
571

Baber Abbasi's avatar
Baber Abbasi committed
572
    def _config_is_group(self, config: dict) -> bool:
Baber's avatar
Baber committed
573
        """Check if a config dictionary defines a group of tasks."""
574
        return "task" in config and isinstance(config["task"], list)
575

Baber Abbasi's avatar
Baber Abbasi committed
576
    def _config_is_python_task(self, config: dict) -> bool:
Baber's avatar
Baber committed
577
        """Check if a config dictionary defines a Python class-based task."""
578
579
580
        return "class" in config

    def _config_is_task_list(self, config: dict) -> bool:
Baber's avatar
Baber committed
581
        """Check if a config dictionary defines a task list."""
582
        return "task_list" in config and isinstance(config["task_list"], list)
583

Baber's avatar
Baber committed
584
    def _get_yaml_path(self, name: str) -> Union[str, int]:
Baber's avatar
Baber committed
585
586
587
588
589
590
591
592
593
594
595
596
        """
        Get the YAML file path for a registered task.

        Args:
            name: Task name

        Returns:
            Path to YAML file, or -1 for Python-only tasks

        Raises:
            ValueError: If task name is not registered
        """
597
598
        if name not in self.task_index:
            raise ValueError
599
600
        return self.task_index[name]["yaml_path"]

Baber's avatar
nit  
Baber committed
601
    def _get_config(self, name: str) -> dict:
Baber's avatar
Baber committed
602
603
604
605
606
607
608
609
610
611
612
613
        """
        Load the full configuration for a registered task.

        Args:
            name: Task name

        Returns:
            Complete task configuration dictionary

        Raises:
            ValueError: If task name is not registered
        """
614
615
        if name not in self.task_index:
            raise ValueError
616
617
618
619
        yaml_path = self._get_yaml_path(name)
        if yaml_path == -1:
            return {}
        else:
620
            return load_yaml_config(Path(yaml_path), mode="full")
621

Baber's avatar
nit  
Baber committed
622
    def _get_tasklist(self, name: str) -> Union[list[str], int]:
Baber's avatar
Baber committed
623
624
625
626
627
628
629
630
631
632
633
634
        """
        Get the task list for a group or tag.

        Args:
            name: Group or tag name

        Returns:
            List of task names in the group/tag

        Raises:
            ValueError: If name refers to an individual task
        """
635
636
        if self._name_is_task(name):
            raise ValueError
637
638
        return self.task_index[name]["task"]

639
640
641
642
643
    def _register_task(
        self,
        task_name: str,
        task_type: str,
        yaml_path: str,
Baber's avatar
nit  
Baber committed
644
645
        tasks_and_groups: dict[str, dict],
        config: Optional[dict] = None,
Baber's avatar
Baber committed
646
647
        populate_tags_fn: Optional[callable] = None,
    ) -> None:
648
649
650
651
652
653
654
        """Helper method to register a task in the tasks_and_groups dict"""
        tasks_and_groups[task_name] = {
            "type": task_type,
            "yaml_path": yaml_path,
        }
        # Only populate tags for configs that support it (not groups)
        if config and task_type != "group" and populate_tags_fn:
655
            populate_tags_fn(config, task_name, tasks_and_groups)
656
657

    def _merge_task_configs(
Baber's avatar
nit  
Baber committed
658
659
        self, base_config: dict, task_specific_config: dict, task_name: str
    ) -> dict:
660
661
662
663
664
665
666
        """Merge base config with task-specific overrides for task_list configs"""
        if task_specific_config:
            task_specific_config = task_specific_config.copy()
            task_specific_config.pop("task", None)
            return {**base_config, **task_specific_config, "task": task_name}
        return {**base_config, "task": task_name}

Baber's avatar
Baber committed
667
    def _process_tag_subtasks(
Baber's avatar
nit  
Baber committed
668
669
        self, tag_name: str, update_config: Optional[dict] = None
    ) -> dict:
670
671
672
673
674
675
676
677
        """Process subtasks for a tag and return loaded tasks"""
        subtask_list = self._get_tasklist(tag_name)
        fn = partial(
            self._load_individual_task_or_group,
            update_config=update_config,
        )
        return dict(collections.ChainMap(*map(fn, reversed(subtask_list))))

Baber's avatar
nit  
Baber committed
678
    def _process_alias(self, config: dict, group: Optional[str] = None) -> dict:
Baber's avatar
Baber committed
679
680
681
682
683
684
685
686
687
688
689
690
691
        """
        Process group alias configuration.

        If the group is not the same as the original group which the group alias
        was intended for, set the group_alias to None instead.

        Args:
            config: Task configuration dictionary
            group: Group name to validate against

        Returns:
            Modified configuration with processed aliases
        """
692
693
694
695
696
        if ("group_alias" in config) and ("group" in config) and group is not None:
            if config["group"] != group:
                config["group_alias"] = None
        return config

Baber's avatar
Baber committed
697
    def _class_has_config_in_constructor(self, cls) -> bool:
Baber's avatar
Baber committed
698
699
700
701
702
703
704
705
706
        """
        Check if a class constructor accepts a 'config' parameter.

        Args:
            cls: Class to inspect

        Returns:
            True if constructor has 'config' parameter, False otherwise
        """
707
708
709
710
711
712
713
        constructor = getattr(cls, "__init__", None)
        return (
            "config" in inspect.signature(constructor).parameters
            if constructor
            else False
        )

714
715
716
717
718
    ###############################################################################
    # NEW: Refactored _load_individual_task_or_group and helper methods          #
    ###############################################################################

    def _create_task_object(
        self,
        cfg: dict,
        task_name: str,
        yaml_path: Union[str, None],
    ) -> dict:
        """
        Instantiate a single task (ConfigurableTask **or** python-task) from *cfg*.
        Returns {task_name: task_object}.

        *cfg* is copied (shallowly) before anything is changed, so a dict that
        came out of a memoised loader is not altered by the include handling
        or the metadata merge below.

        Args:
            cfg: Task configuration; may contain an ``include`` key
            task_name: Key of the returned mapping
            yaml_path: Location the include is resolved against; when falsy
                the include is loaded in "simple" mode without a path
        """
        from lm_eval.api.task import ConfigurableTask, Task  # local import avoids cycle

        cfg = dict(cfg)  # never edit the caller's (possibly cached) dict

        # ---- include handling ---------------------------------------------------
        if "include" in cfg:
            # keep original name so include merging cannot clobber it
            orig_name = cfg.get("task", task_name)
            included = load_yaml_config(
                yaml_path=Path(yaml_path) if yaml_path else None,
                yaml_config={"include": cfg.pop("include")},
                mode="full" if yaml_path else "simple",
            )
            cfg = {**included, **cfg, "task": orig_name}

        # ---- metadata merge -----------------------------------------------------
        if self.metadata is not None:
            # `or {}` also covers an explicit `metadata: null` in the YAML,
            # which would otherwise make the `|` merge raise TypeError.
            cfg["metadata"] = (cfg.get("metadata") or {}) | self.metadata
        else:
            cfg["metadata"] = cfg.get("metadata", {})

        # ---- python-task vs YAML-task -------------------------------------------
        if self._config_is_python_task(cfg):
            cls = cfg["class"]
            task_obj: Task
            if self._class_has_config_in_constructor(cls):
                task_obj = cls(config=cfg)
            else:
                task_obj = cls()
            # make sure name propagates when the class inherits ConfigurableTask
            if isinstance(task_obj, ConfigurableTask):  # type: ignore
                task_obj.config.task = task_name
        else:
            task_obj = ConfigurableTask(config=cfg)  # type: ignore

        return {task_name: task_obj}
Baber's avatar
Baber committed
765

766
767
768
    def _create_group_object(
        self,
        cfg: dict,
        parent_name: Union[str, None] = None,
    ) -> tuple[GroupConfig, list[Union[str, dict]]]:
        """
        Construct a GroupConfig from *cfg* and list its subtasks.

        Manager-level metadata, when set, is merged over ``cfg["metadata"]``
        before the GroupConfig is built. String entries of the group's task
        list that name a tag are replaced by that tag's task list; every other
        entry is passed through unchanged.

        Args:
            cfg: Keyword arguments for GroupConfig (updated in place with the
                merged metadata)
            parent_name: Accepted for signature compatibility; not used here

        Returns:
            Tuple ``(group_obj, subtask_names)``.
        """
        manager_metadata = self.metadata
        if manager_metadata is not None:
            cfg["metadata"] = cfg.get("metadata", {}) | manager_metadata

        group = GroupConfig(**cfg)

        expanded: list[Union[str, dict]] = []
        for entry in group.task:
            names_a_tag = isinstance(entry, str) and self._name_is_tag(entry)
            if not names_a_tag:
                expanded.append(entry)
                continue
            expanded.extend(self._get_tasklist(entry))
        return group, expanded
Baber's avatar
Baber committed
786

787
788
789
    def _load_subtasks(
        self,
        subtasks: list[Union[str, dict]],
790
        parent_name: Union[str, GroupConfig, None],
Baber's avatar
Baber committed
791
        update_config: Union[dict, None],
792
793
794
795
796
797
798
799
    ) -> Mapping:
        """Return merged mapping of all subtasks, handling duplicates."""
        fn = functools.partial(
            self._load_individual_task_or_group,
            parent_name=parent_name,
            update_config=update_config,
        )
        return dict(collections.ChainMap(*map(fn, reversed(subtasks))))
Baber's avatar
Baber committed
800

801
802
    def _load_individual_task_or_group(
        self,
Baber's avatar
Baber committed
803
        payload: Union[str, dict],
804
        *,
Baber's avatar
Baber committed
805
806
        parent_name: Union[str, None] = None,
        update_config: Union[dict, None] = None,
807
808
809
810
811
    ) -> Mapping:
        """
        Public helper that turns *payload* (str task/group/tag **or** dict config)
        into a nested Mapping of {name_or_group_obj: task_obj | sub_mapping}.
        """
Baber's avatar
Baber committed
812

813
814
815
816
817
818
819
820
        # ------------------------------------------------------------------ STRING
        if isinstance(payload, str):
            # If caller supplied extra overrides, treat as dict immediately
            if update_config:
                return self._load_individual_task_or_group(
                    {"task": payload, **update_config},
                    parent_name=parent_name,
                )
821

822
823
824
825
826
827
828
829
830
831
832
833
834
835
            # ------------ registered TASK (YAML or python) -----------------
            if self._name_is_task(payload) or self._name_is_python_task(payload):
                yaml_path = self._get_yaml_path(payload)
                cfg = self._get_config(payload)

                # task_list configs: extract the per-task override ------------
                if "task_list" in cfg:
                    override = next(
                        (
                            entry
                            for entry in cfg["task_list"]
                            if isinstance(entry, dict) and entry.get("task") == payload
                        ),
                        None,
Lintang Sutawika's avatar
Lintang Sutawika committed
836
                    )
837
838
839
840
841
842
843
844
845
846
847
848
849
                    base = {k: v for k, v in cfg.items() if k != "task_list"}
                    if override:
                        cfg = {**base, **override, "task": payload}
                return self._create_task_object(cfg, payload, yaml_path)

            # ------------ registered GROUP ----------------------------------
            if self._name_is_group(payload):
                group_cfg = self._get_config(payload)
                grp_only = {k: v for k, v in group_cfg.items() if k in GROUP_ONLY_KEYS}
                grp_obj, subtasks = self._create_group_object(grp_only, parent_name)
                return {
                    grp_obj: self._load_subtasks(subtasks, grp_obj, update_config=None)
                }
850

851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
            # ------------ registered TAG ------------------------------------
            if self._name_is_tag(payload):
                return self._process_tag_subtasks(payload, update_config=None)

            raise ValueError(f"Unknown task / group / tag name: {payload!r}")

        # ------------------------------------------------------------------- DICT
        if isinstance(payload, dict):
            # ------------------ simple 'task: name' dict --------------------
            if self._config_is_task(payload):
                name = payload["task"]
                # override existing registered YAML if exists
                if self._name_is_registered(name):
                    base_cfg = self._get_config(name)
                    yaml_path = self._get_yaml_path(name)
                    merged = {**base_cfg, **payload}
867
                else:
868
                    merged = payload
869
                    yaml_path = None
870

871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
                # duplicate-naming guard when inside a group
                if parent_name is not None:
                    count = len(
                        [
                            n
                            for n in self.task_group_map[parent_name]
                            if n.startswith(name)
                        ]
                    )
                    if count:
                        name = f"{name}-{count}"
                    self.task_group_map[parent_name].append(name)

                return self._create_task_object(merged, name, yaml_path)

            # ----------------- literal group dict (task: [...]) -------------
            if self._config_is_group(payload):
                grp_cfg = {k: v for k, v in payload.items() if k in GROUP_ONLY_KEYS}
                sub_override = {
                    k: v for k, v in payload.items() if k not in GROUP_ONLY_KEYS
                } or None
                grp_obj, subtasks = self._create_group_object(grp_cfg, parent_name)
                return {grp_obj: self._load_subtasks(subtasks, grp_obj, sub_override)}

            # ----------------- python-task dict ('class': …) ----------------
            if self._config_is_python_task(payload):
                name = payload["task"]
                return self._create_task_object(payload, name, yaml_path=None)

        raise TypeError(
            f"_load_individual_task_or_group expected str | dict, got {type(payload)}"
902
        )
903

Baber's avatar
Baber committed
904
    def load_task_or_group(
Baber's avatar
nit  
Baber committed
905
906
        self, task_list: Optional[Union[str, list[str]]] = None
    ) -> dict:
Baber's avatar
Baber committed
907
908
909
910
911
912
913
914
915
916
        """
        Load multiple tasks or groups from a list of names.

        This is the main entry point for loading tasks. It handles lists
        of task names and delegates to _load_individual_task_or_group for
        each item, then merges the results.

        Args:
            task_list: Single task name or list of task names to load.
                      Can include individual tasks, groups, and tags.
917

Baber's avatar
Baber committed
918
919
920
        Returns:
            Dictionary mapping task/group names to loaded task objects.
            Results from all requested items are merged into a single dict.
921

Baber's avatar
Baber committed
922
923
924
925
926
927
928
929
930
931
        Example:
            Load multiple tasks::

                tasks = tm.load_task_or_group(["hellaswag", "arc_easy"])
                # Returns: {"hellaswag": Task1, "arc_easy": Task2}

            Load a group::

                tasks = tm.load_task_or_group("arc_group")
                # Returns: {"arc_group": {"arc_easy": Task1, "arc_challenge": Task2}}
932
933
934
        """
        if isinstance(task_list, str):
            task_list = [task_list]
935

936
        all_loaded_tasks = dict(
Baber Abbasi's avatar
Baber Abbasi committed
937
938
939
940
941
942
            collections.ChainMap(
                *map(
                    lambda task: self._load_individual_task_or_group(task),
                    task_list,
                )
            )
943
944
945
        )
        return all_loaded_tasks

Baber's avatar
nit  
Baber committed
946
    def load_config(self, config: dict) -> Mapping:
Baber's avatar
Baber committed
947
948
949
950
951
952
953
954
955
956
957
958
959
        """
        Load a task from an inline configuration dictionary.

        Args:
            config: Configuration dictionary defining the task

        Returns:
            Mapping of task name to loaded task object

        Example:
            >>> config = {"task": "hellaswag", "num_fewshot": 5}
            >>> task_dict = tm.load_config(config)
        """
960
961
        return self._load_individual_task_or_group(config)

Baber's avatar
nit  
Baber committed
962
    def _get_task_and_group(self, task_dir: Union[str, Path]) -> dict[str, dict]:
        """
        Scan a directory for task configurations and build an index.

        Every YAML file yielded by ``iter_yaml_files`` is parsed in 'simple'
        mode without resolving includes, then classified as:
        - a Python class-based task ('class' key),
        - a group config,
        - a regular task config ('task' key), or
        - a task-list config ('task_list' key, one entry registered per item).
        Files that cannot be read or parsed, or that match none of these
        shapes, are skipped with a debug log message.

        Args:
            task_dir: Directory path to scan for YAML task configurations

        Returns:
            Dictionary mapping names to metadata dictionaries. Entries written
            directly by this method contain:
            - 'type': 'group' or 'tag' ('task' / 'python_task' entries are
              delegated to ``_register_task``)
            - 'yaml_path': Path to the source YAML file, or -1 for tags
            - 'task': list of member task names for tags; -1 for groups,
              whose members are resolved lazily when the group is loaded
        """

        def _populate_tags_and_groups(
            config: dict, task: str, tasks_and_groups: dict[str, dict]
        ) -> None:
            """
            Register *task* under every tag listed in ``config["tag"]``.

            Args:
                config: Task configuration dictionary
                task: Name of the task being processed
                tasks_and_groups: Master index to update with tag information
            """
            # TODO: remove group in next release
            if "tag" not in config:
                return

            attr_list = config["tag"]
            if isinstance(attr_list, str):
                attr_list = [attr_list]

            for tag in attr_list:
                if tag not in tasks_and_groups:
                    tasks_and_groups[tag] = {
                        "type": "tag",
                        "task": [task],
                        "yaml_path": -1,
                    }
                elif tasks_and_groups[tag]["type"] != "tag":
                    eval_logger.info(
                        f"The tag '{tag}' is already registered as a group, this tag will not be registered. "
                        "This may affect tasks you want to call."
                    )
                    # Only the conflicting tag is skipped; the task's remaining
                    # tags must still be registered.
                    continue
                else:
                    tasks_and_groups[tag]["task"].append(task)

        tasks_and_groups: dict[str, dict] = {}
        task_dir_path = Path(task_dir)

        for yaml_path in iter_yaml_files(task_dir_path):
            try:
                config = load_yaml_config(
                    yaml_path, mode="simple", resolve_includes=False
                )
            except (FileNotFoundError, YAMLError, OSError) as err:
                eval_logger.debug(f"File {yaml_path} could not be loaded ({err})")
                continue

            if self._config_is_python_task(config):
                # This is a python class config
                self._register_task(
                    config["task"],
                    "python_task",
                    str(yaml_path),
                    tasks_and_groups,
                    config,
                    _populate_tags_and_groups,
                )
            elif self._config_is_group(config):
                # This is a group config. "task" is -1 because the member list
                # is not needed for indexing; it is loaded when the group is
                # actually requested.
                tasks_and_groups[config["group"]] = {
                    "type": "group",
                    "task": -1,
                    "yaml_path": str(yaml_path),
                }
            elif self._config_is_task(config):
                # This is a task config
                self._register_task(
                    config["task"],
                    "task",
                    str(yaml_path),
                    tasks_and_groups,
                    config,
                    _populate_tags_and_groups,
                )
            elif self._config_is_task_list(config):
                # This is a task_list config: register each named entry
                for task_entry in config["task_list"]:
                    if isinstance(task_entry, dict) and "task" in task_entry:
                        self._register_task(
                            task_entry["task"],
                            "task",
                            str(yaml_path),
                            tasks_and_groups,
                            config,
                            _populate_tags_and_groups,
                        )
            else:
                eval_logger.debug(f"File {yaml_path} could not be loaded")

        return tasks_and_groups
lintangsutawika's avatar
lintangsutawika committed
1102

1103

Baber's avatar
nit  
Baber committed
1104
def get_task_name_from_config(task_config: dict[str, str]) -> str:
    """
    Derive a task name from a configuration dictionary.

    The explicit ``task`` entry is used when present. Otherwise the name is
    built from ``dataset_path``, with ``_<dataset_name>`` appended when a
    ``dataset_name`` entry exists.

    Args:
        task_config: Task configuration dictionary

    Returns:
        String name for the task

    Example:
        >>> config = {"task": "hellaswag", "num_fewshot": 5}
        >>> get_task_name_from_config(config)
        'hellaswag'

        >>> config = {"dataset_path": "custom", "dataset_name": "mytask"}
        >>> get_task_name_from_config(config)
        'custom_mytask'
    """
    if "task" in task_config:
        return task_config["task"]

    template = (
        "{dataset_path}_{dataset_name}"
        if "dataset_name" in task_config
        else "{dataset_path}"
    )
    return template.format(**task_config)
lintangsutawika's avatar
lintangsutawika committed
1132

1133

Baber's avatar
nit  
Baber committed
1134
def get_task_name_from_object(task_object: Union["ConfigurableTask", "Task"]) -> str:
    """
    Return the name of an instantiated task object.

    Objects exposing a ``config`` attribute report the ``"task"`` entry of
    their ``_config``. Any other object reports its ``EVAL_HARNESS_NAME``
    attribute when it has one, and its class name otherwise.

    Args:
        task_object: An instantiated task object

    Returns:
        String name of the task

    Example:
        >>> task = ConfigurableTask(config={"task": "hellaswag"})
        >>> get_task_name_from_object(task)
        'hellaswag'
    """
    if not hasattr(task_object, "config"):
        # TODO: scrap this
        # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
        if hasattr(task_object, "EVAL_HARNESS_NAME"):
            return task_object.EVAL_HARNESS_NAME
        return type(task_object).__name__

    return task_object._config["task"]

1163

Baber's avatar
nit  
Baber committed
1164
def _check_duplicates(task_dict: dict[str, list[str]]) -> None:
Baber's avatar
Baber committed
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
    """
    Validate that no tasks appear in multiple groups simultaneously.

    Helper function used to prevent conflicts when multiple groups claim
    the same constituent task. This could lead to ambiguous configuration
    like conflicting num_fewshot values.

    Args:
        task_dict: Dictionary mapping group names to lists of subtask names

    Raises:
        ValueError: If any tasks appear in multiple groups

    Example:
        >>> task_dict = {
        ...     "group1": ["task_a", "task_b"],
        ...     "group2": ["task_b", "task_c"]  # task_b appears twice!
        ... }
        >>> _check_duplicates(task_dict)  # Raises ValueError
Lintang Sutawika's avatar
Lintang Sutawika committed
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
    """
    subtask_names = []
    for key, value in task_dict.items():
        subtask_names.extend(value)

    duplicate_tasks = {
        task_name for task_name in subtask_names if subtask_names.count(task_name) > 1
    }

    # locate the potentially problematic groups that seem to 'compete' for constituent subtasks
    competing_groups = [
        group
        for group in task_dict.keys()
        if len(set(task_dict[group]).intersection(duplicate_tasks)) > 0
    ]

    if len(duplicate_tasks) > 0:
        raise ValueError(
            f"Found 1 or more tasks while trying to call get_task_dict() that were members of more than 1 called group: {list(duplicate_tasks)}. Offending groups: {competing_groups}. Please call groups which overlap their constituent tasks in separate evaluation runs."
        )


1206
def get_task_dict(
    task_name_list: Union[str, list[Union[str, dict, "Task"]]],
    task_manager: Optional[TaskManager] = None,
) -> dict[str, Union["ConfigurableTask", "Task"]]:
    """
    Create a dictionary of task objects from mixed input types.

    This is the main public API for loading tasks. It accepts various input
    formats (names, configs, objects) and returns a unified dictionary of
    instantiated task objects ready for evaluation.

    The function handles:
    - String task names (looked up via TaskManager.load_task_or_group)
    - Configuration dictionaries (loaded via TaskManager.load_config)
    - Pre-instantiated Task objects (used as-is)
    - Validation to prevent conflicting group memberships

    Args:
        task_name_list: Mixed list of task specifications:
                       - str: Task name to look up
                       - dict: Inline task configuration
                       - Task: Pre-instantiated task object
        task_manager: TaskManager instance for name resolution.
                     If None, creates a default TaskManager.

    Returns:
        Dictionary mapping task names to instantiated task objects.

    Raises:
        TypeError: If task_name_list contains unsupported types
        ValueError: If two specifications produce the same name, or if there
            are conflicting group memberships

    Example:
        Mixed input types::

            tasks = get_task_dict([
                "hellaswag",                              # lookup by name
                {"task": "arc_easy", "num_fewshot": 5},   # inline config
                pre_existing_task_object                  # direct object
            ])

        Simple case::

            tasks = get_task_dict("hellaswag")
            # Returns: {"hellaswag": ConfigurableTask(...)}

        With custom TaskManager::

            tm = TaskManager(include_path="/custom/tasks")
            tasks = get_task_dict(["custom_task"], task_manager=tm)
    """
    from lm_eval.api.task import Task

    # Normalize input to list
    if isinstance(task_name_list, str):
        task_name_list = [task_name_list]
    elif not isinstance(task_name_list, list):
        raise TypeError(
            f"Expected a 'str' or 'list' but received {type(task_name_list)}."
        )

    # Validate list items
    if not all(isinstance(task, (str, dict, Task)) for task in task_name_list):
        raise TypeError(
            "Expected all list items to be of types 'str', 'dict', or 'Task', but at least one entry did not match."
        )

    # Ensure we have a task manager
    if task_manager is None:
        task_manager = TaskManager()

    # Process all items
    final_task_dict = {}
    for task_spec in task_name_list:
        if isinstance(task_spec, Task):
            # Pre-instantiated task object
            task_name = get_task_name_from_object(task_spec)
            if task_name in final_task_dict:
                raise ValueError(f"Duplicate task name: {task_name}")
            final_task_dict[task_name] = task_spec
            continue

        if isinstance(task_spec, dict):
            # Inline config: must be loaded as ONE config. Handing the dict to
            # load_task_or_group would iterate its keys as if they were names.
            result = task_manager.load_config(task_spec)
        else:
            # Registered task / group / tag name
            result = task_manager.load_task_or_group(task_spec)

        # Check for duplicate names
        for name in result:
            if name in final_task_dict:
                raise ValueError(f"Duplicate task name: {name}")
        final_task_dict.update(result)

    # Check for conflicting group memberships
    _check_duplicates(get_subtask_list(final_task_dict))

    return final_task_dict