__init__.py 46 KB
Newer Older
Baber's avatar
Baber committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
"""
Task Management Module for LM Evaluation Harness.

This module provides comprehensive task discovery, loading, and management functionality
for the LM Evaluation Harness. It handles YAML configuration parsing with include support,
dynamic function importing, and task indexing across multiple directories.

Key Components:
- TaskManager: Main class for task discovery and management
- YAML configuration loading with !function tag support
- Task, group, and tag indexing
- Include resolution with cycle detection
- Caching for performance optimization

Example:
    Basic usage::

        task_manager = TaskManager()
        all_tasks = task_manager.all_tasks
        task_config = task_manager._get_config("hellaswag")

    Custom task paths::

        task_manager = TaskManager(
            include_path="/path/to/custom/tasks",
            include_defaults=True
        )
"""

30
import collections
31
32
import functools
import importlib.util
33
import inspect
34
import logging
35
import sys
36
from functools import partial
37
from glob import iglob
Baber's avatar
Baber committed
38
from pathlib import Path
Baber's avatar
nit  
Baber committed
39
40
41
42
43
44
45
46
47
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Generator,
    Mapping,
    Optional,
    Union,
)
48
49
50

import yaml
from yaml import YAMLError
&'s avatar
& committed
51

52
from lm_eval.api.group import GroupConfig
Lintang Sutawika's avatar
Lintang Sutawika committed
53
from lm_eval.evaluator_utils import get_subtask_list
Baber's avatar
nit  
Baber committed
54
55
56
57
58
from lm_eval.utils import pattern_match, setup_logging


if TYPE_CHECKING:
    from lm_eval.api.task import ConfigurableTask, Task
Lintang Sutawika's avatar
Lintang Sutawika committed
59

Baber's avatar
nit  
Baber committed
60
eval_logger = logging.getLogger(__name__)

#: Configuration keys that only groups accept (the fields of a default GroupConfig)
GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys())

#: Base YAML loader class: the libyaml-backed loader when PyYAML was built with
#: libyaml (faster), otherwise the pure-Python FullLoader.
#: NOTE(review): yaml.CLoader is the C counterpart of yaml.Loader, not of
#: yaml.FullLoader (that would be yaml.CFullLoader), so the set of accepted
#: ``!!python/*`` tags differs depending on whether libyaml is installed.
if getattr(yaml, "__with_libyaml__", False):
    _Base = yaml.CLoader
else:
    _Base = yaml.FullLoader

#: Directory names skipped during task discovery
_IGNORE_DIRS = ("__pycache__", ".ipynb_checkpoints")
73

lintangsutawika's avatar
lintangsutawika committed
74

Baber's avatar
nit  
Baber committed
75
def ignore_constructor(loader: yaml.Loader, node: yaml.Node) -> None:
    """Resolve a ``!function`` node to ``None`` without importing anything.

    Registered as the ``!function`` constructor on "simple"-mode loaders so
    YAML files can be read without executing Python code.
    """
    # Both arguments are deliberately unused.
    del loader, node
    return None
78
79


Baber's avatar
nit  
Baber committed
80
@functools.lru_cache(maxsize=2048)  # one Loader class per (yaml_dir, simple) pair
def _make_loader(yaml_dir: Path, simple: bool = False) -> type[yaml.Loader]:
    """
    Build a YAML Loader subclass whose ``!function`` tag is bound to *yaml_dir*.

    yaml_dir
        Directory containing the YAML file being parsed. ``!function`` values
        such as ``my_utils.some_fn`` are resolved against it, i.e. the function
        is loaded from ``yaml_dir / "my_utils.py"``.
    simple
        When True, ``!function`` nodes become ``None`` and nothing is imported
        (used by ``mode="simple"``, e.g. while indexing on TaskManager init).

    Thanks to the LRU cache, repeated calls with the same arguments return the
    very same class object.
    """

    class Loader(_Base):
        """Per-directory loader; constructors are registered on this class only."""

    if not simple:

        def _construct_function(loader, node):
            # Closure over yaml_dir; the node's scalar is the dotted function name.
            # (Closures, like lambdas, are not picklable.)
            return _import_function(loader.construct_scalar(node), base_path=yaml_dir)

        constructor = _construct_function
    else:
        constructor = ignore_constructor

    yaml.add_constructor("!function", constructor, Loader=Loader)
    return Loader


Baber's avatar
nit  
Baber committed
114
@functools.lru_cache(maxsize=None)  # memoise (qualname, base_path) -> callable
def _import_function(qualname: str, *, base_path: Path) -> Callable:
    """
    Dynamically import a function from a Python module relative to base_path.

    This function enables YAML files to reference Python functions using
    the !function tag. Supports dot notation for nested modules.

    The module is registered in ``sys.modules`` *before* its code runs (as in
    the importlib "import a source file directly" recipe). Code that looks up
    its own module during import therefore works, e.g. dataclasses under
    ``from __future__ import annotations``. If execution fails, the partially
    initialised module is removed again.

    Args:
        qualname: Qualified function name like "my_module.my_function"
        base_path: Base directory for resolving relative module paths

    Returns:
        The imported callable function

    Raises:
        ValueError: If qualname doesn't contain a module part
        ImportError: If no import spec/loader can be created for the file
        AttributeError: If the module has no attribute named like the function

    Example:
        >>> func = _import_function("utils.custom_metric", base_path=Path("/tasks"))
        >>> result = func(predictions, references)
    """
    mod_path, _, func_name = qualname.rpartition(".")
    if not mod_path:
        raise ValueError(f"{qualname!r} has no module part")
    file_path = base_path / f"{mod_path.replace('.', '/')}.py"
    module_name = f"_yaml_dynamic.{hash(file_path)}_{file_path.stem}"

    mod = sys.modules.get(module_name)
    if mod is None:
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        if spec is None or spec.loader is None:
            raise ImportError(
                f"Cannot create an import spec for {file_path} (!function {qualname!r})"
            )
        mod = importlib.util.module_from_spec(spec)
        # Register first so the module can find itself while executing.
        sys.modules[module_name] = mod
        try:
            spec.loader.exec_module(mod)
        except BaseException:
            # Don't leave a half-initialised module behind for later lookups.
            sys.modules.pop(module_name, None)
            raise
    return getattr(mod, func_name)


Baber's avatar
Baber committed
151
@functools.lru_cache(maxsize=4096)
def _parse_yaml_file(path: Path, mode: str) -> dict:
    """
    Parse a single YAML file with the appropriate loader.

    Results are memoised per (path, mode) by the LRU cache, so the returned
    dict is shared between all callers. Copy it before mutating it.

    Args:
        path: Path to the YAML file
        mode: Parsing mode ("full" or "simple"; "simple" ignores !function tags)

    Returns:
        Parsed YAML configuration as dictionary. An empty YAML document yields
        an empty dict instead of None.
    """
    loader_cls = _make_loader(path.parent, simple=(mode == "simple"))
    with path.open("rb") as fh:
        # NOTE: non-safe loader; task YAML is treated as trusted input
        # (!function executes Python code from the task directory anyway).
        data = yaml.load(fh, Loader=loader_cls)
    return {} if data is None else data


Baber's avatar
nit  
Baber committed
168
@functools.lru_cache(maxsize=4096)
def _get_cached_config(
    yaml_path: Path, mode: str, _chain: tuple[Path, ...] = ()
) -> dict:
    """
    Load *yaml_path*, merge in its ``include`` files and memoise the result.

    Args:
        yaml_path: Path of the YAML file to load.
        mode: "full" or "simple", forwarded to the parser and to every include.
        _chain: **Internal.** Files currently being resolved above this one,
            used to detect include cycles such as A -> B -> A. Because it is
            part of the cache key, an include reached through different parent
            files gets separate cache entries; parsing itself stays cached in
            ``_parse_yaml_file``.

    Returns:
        The merged configuration. Earlier entries of the ``include`` list win
        over later ones, and keys defined in *yaml_path* itself win over all
        includes. The dict is shared through the LRU cache, so callers must
        copy it before mutating it.

    Raises:
        ValueError: If an include cycle is detected.
    """
    if yaml_path in _chain:
        raise ValueError(f"Include cycle detected at {yaml_path}")

    parsed = _parse_yaml_file(yaml_path, mode)
    # Copy before popping "include": the parsed dict is an LRU-cache entry of
    # _parse_yaml_file and must stay intact for later callers.
    yaml_config = {} if parsed is None else dict(parsed)

    # Handle includes
    include = yaml_config.pop("include", None)
    if not include:
        return yaml_config

    yaml_dir = yaml_path.parent
    chain = _chain + (yaml_path,)
    include_paths = include if isinstance(include, list) else [include]
    final_cfg: dict = {}

    # Applied in reverse so the first listed include is merged last and wins.
    for inc in reversed(include_paths):
        if inc is None:
            continue
        inc_path = Path(inc)
        if not inc_path.is_absolute():
            inc_path = (yaml_dir / inc_path).resolve()
        final_cfg.update(_get_cached_config(inc_path, mode, chain))

    final_cfg.update(yaml_config)  # local keys win
    return final_cfg


197
198
def _copy_config_containers(obj: Any) -> Any:
    """Return *obj* with every nested dict and list freshly copied.

    Leaf values (strings, numbers, callables produced by ``!function``, ...)
    are shared rather than copied, so this works for objects that cannot be
    deep-copied.
    """
    if isinstance(obj, dict):
        return {key: _copy_config_containers(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [_copy_config_containers(item) for item in obj]
    return obj


def load_yaml_config(
    yaml_path: Union[Path, str, None] = None,
    yaml_config: Optional[dict] = None,
    yaml_dir: Optional[Path] = None,
    mode: str = "full",
    *,
    _seen: Optional[set[tuple[Path, str]]] = None,
    resolve_includes: bool = True,
) -> dict:
    """
    Parse a YAML config with optional include handling.

    The returned dict is always a fresh copy, nested dicts and lists included.
    Callers may therefore mutate it without corrupting the module's LRU caches
    or a ``yaml_config`` they passed in.

    Parameters
    ----------
    yaml_path
        Path to the main YAML file.  Needed unless *yaml_config* is
        supplied directly (e.g. by tests).
    yaml_config
        Pre-parsed dict to use instead of reading *yaml_path*.  It is not
        modified.
    yaml_dir
        Base directory for resolving relative include paths.  Defaults
        to `yaml_path.parent`.
    mode
        "full"  – honour  !function  tags
        "simple" – ignore !function  (faster).
    _seen
        **Internal** recursion set: tuples of (absolute-path, mode).
        Guards against include cycles such as A → B → A on this uncached
        path. Includes reached through the cached fast path are handled by
        ``_get_cached_config`` instead.
    resolve_includes
        When True (default), ``include`` entries are loaded and merged, with
        local keys overriding included ones.  When False, the config is
        returned as written, including its ``include`` key.

    Raises
    ------
    ValueError
        If neither *yaml_path* nor *yaml_config* is given, if an include cycle
        is detected on the uncached path, or if includes must be resolved but
        no base directory is known.
    """
    if yaml_config is None and yaml_path is None:
        raise ValueError("load_yaml_config needs either yaml_path or yaml_config")

    # ------------------------------------------------------------------ cycle guard
    if _seen is None:
        _seen = set()
    if yaml_path is not None:
        yaml_path = Path(yaml_path).expanduser().resolve()

        # ---------- fast-path: use LRU cached function ----------
        if yaml_config is None and resolve_includes:
            # Copy so callers cannot mutate the cached entry.
            return _copy_config_containers(_get_cached_config(yaml_path, mode))

        key = (yaml_path, mode)  # yaml_path is already resolved above
        if key in _seen:
            raise ValueError(f"Include cycle detected at {yaml_path}")
        _seen.add(key)

    # ------------------------------------------------------------------ load / parse
    if yaml_config is None:  # ordinary path-based load
        parsed = _parse_yaml_file(yaml_path, mode)
        yaml_config = {} if parsed is None else parsed
    # Private copy: a parsed dict is shared via the LRU cache, and a
    # caller-supplied dict belongs to the caller.
    yaml_config = _copy_config_containers(yaml_config)

    if not resolve_includes:
        return yaml_config

    # ------------------------------------------------------------------ handle include
    include = yaml_config.pop("include", None)
    if not include:
        return yaml_config

    if yaml_dir is None and yaml_path is not None:
        yaml_dir = yaml_path.parent
    if yaml_dir is None:
        raise ValueError("yaml_dir must be set by caller or deduced from path")

    include_paths = include if isinstance(include, list) else [include]
    final_cfg: dict = {}

    # Applied in reverse so the first listed include is merged last and wins.
    for inc in reversed(include_paths):
        if inc is None:  # guard against explicit nulls
            continue
        inc_path = Path(inc)
        if not inc_path.is_absolute():
            inc_path = (yaml_dir / inc_path).resolve()
        included = load_yaml_config(
            yaml_path=inc_path,
            mode=mode,
            yaml_dir=inc_path.parent,
            _seen=_seen,  # <-- pass set downward
        )
        final_cfg.update(included)

    final_cfg.update(yaml_config)  # local keys win
    return final_cfg


278
def iter_yaml_files(root: Path, ignore=_IGNORE_DIRS) -> Generator[Path, Any, None]:
    """
    Recursively iterate over all ``*.yaml`` files in a directory tree.

    A file is skipped when any path component below *root* (a directory
    between *root* and the file, or the file name itself) is listed in
    *ignore*. By default these are ``__pycache__`` and ``.ipynb_checkpoints``.
    Components of *root* itself are not checked, so a task tree that happens
    to live under such a directory is still scanned.

    *root* is escaped before globbing, so directory names containing glob
    metacharacters (``[``, ``*``, ``?``) are matched literally.

    Args:
        root: Root directory to search for YAML files (``str`` also accepted)
        ignore: Path component names to skip

    Yields:
        Path objects for each discovered YAML file

    Example:
        >>> for yaml_file in iter_yaml_files(Path("tasks")):
        ...     print(f"Found task config: {yaml_file}")
    """
    import glob

    root = Path(root)
    pattern = str(Path(glob.escape(str(root))) / "**" / "*.yaml")
    for p in iglob(pattern, recursive=True):
        path = Path(p)
        try:
            rel_parts = path.relative_to(root).parts
        except ValueError:  # defensive: glob results should lie under root
            rel_parts = path.parts
        # ignore check: only components below root are considered
        if any(part in ignore for part in rel_parts):
            continue
        yield path
Lintang Sutawika's avatar
Lintang Sutawika committed
301

302

303
class TaskManager:
Baber's avatar
Baber committed
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
    """
    Central manager for task discovery, indexing, and loading.

    TaskManager scans directories for YAML task configurations and maintains
    an index of all available tasks, groups, and tags. It provides methods
    for listing, filtering, and loading tasks with their configurations.

    The manager supports:
    - Automatic discovery from default lm_eval/tasks/ directory
    - Custom task directories via include_path
    - Task grouping and tagging
    - Configuration inheritance via YAML includes
    - Caching for performance

    Attributes:
        include_path: Additional directories to search for tasks
        metadata: Global metadata to inject into all task configs
        task_group_map: Mapping of tasks to their parent groups
322

Baber's avatar
Baber committed
323
324
325
326
327
328
329
330
331
332
333
334
335
336
    Example:
        Basic usage::

            tm = TaskManager()
            print(f"Found {len(tm.all_tasks)} tasks")
            hellaswag_config = tm._get_config("hellaswag")

        With custom tasks::

            tm = TaskManager(
                include_path="/my/custom/tasks",
                verbosity="INFO"
            )
            custom_tasks = [t for t in tm.all_tasks if "custom" in t]
337
338
    """

339
340
    def __init__(
        self,
        verbosity: Optional[str] = None,
        include_path: Optional[Union[str, Path, list[Union[str, Path]]]] = None,
        include_defaults: bool = True,
        metadata: Optional[dict[str, dict[str, Any]]] = None,
    ) -> None:
        """
        Build the task index and the sorted name lists derived from it.

        Args:
            verbosity: If given, passed to ``setup_logging`` (e.g. DEBUG, INFO,
                WARNING, ERROR)
            include_path: Extra directory (or list of directories) searched for
                task configs
            include_defaults: Whether to index the bundled lm_eval/tasks/ configs
            metadata: Metadata merged into every task config when tasks are built
        """
        if verbosity is not None:
            setup_logging(verbosity)

        self.include_path = include_path
        self.metadata = metadata

        self._task_index = self.initialize_tasks(
            include_path=include_path, include_defaults=include_defaults
        )
        names = sorted(self._task_index)
        self._all_tasks = names

        def _names_of(*kinds: str) -> list[str]:
            # `names` is already sorted, so the filtered lists stay sorted.
            return [n for n in names if self._task_index[n]["type"] in kinds]

        self._all_groups = _names_of("group")
        self._all_subtasks = _names_of("task", "python_task")
        self._all_tags = _names_of("tag")

        self.task_group_map = collections.defaultdict(list)
380

381
382
    def initialize_tasks(
        self,
        include_path: Optional[Union[str, Path, list[Union[str, Path]]]] = None,
        include_defaults: bool = True,
    ) -> dict[str, dict]:
        """Build the index of every task, group and tag found on the search paths.

        :param include_path: Union[str, Path, list] = None
            Extra directory searched recursively for task configs; a list
            provides several such directories.
        :param include_defaults: bool = True
            If False, the bundled tasks in lm_eval/tasks/ are not indexed.
        :return:
            dict mapping each name to its index entry. When two search paths
            define the same name, the entry from the earlier path is kept
            (so bundled defaults take precedence over include_path entries).
        """
        search_paths: list[Path] = [Path(__file__).parent] if include_defaults else []
        if include_path is not None:
            extra = (
                [include_path] if isinstance(include_path, (str, Path)) else include_path
            )
            search_paths += [Path(p) for p in extra]

        merged: dict[str, dict] = {}
        for directory in search_paths:
            # Existing keys (from earlier directories) override newly found ones.
            merged = {**self._get_task_and_group(directory), **merged}
        return merged

    @property
    def all_tasks(self) -> list[str]:
        """Every indexed name (tasks, python tasks, groups and tags), sorted.

        This is the manager's own list, not a copy.
        """
        return self._all_tasks

418
    @property
    def all_groups(self) -> list[str]:
        """Sorted names whose index entry has type "group"."""
        return self._all_groups

    @property
    def all_subtasks(self) -> list[str]:
        """Sorted names of individual tasks (types "task" and "python_task");
        groups and tags are excluded."""
        return self._all_subtasks

    @property
    def all_tags(self) -> list[str]:
        """Sorted names whose index entry has type "tag"."""
        return self._all_tags

433
    @property
    def task_index(self) -> dict[str, dict[str, Union[str, int, list[str]]]]:
        """The full name -> entry index built at construction time.

        Each entry carries at least "type" and "yaml_path" (-1 when there is no
        YAML file). The dict itself is returned, not a copy.
        """
        return self._task_index

438
    def list_all_tasks(
        self,
        list_groups: bool = True,
        list_tags: bool = True,
        list_subtasks: bool = True,
    ) -> str:
        """
        Return a Markdown table (as a string) listing groups, tags and/or subtasks
        known to this TaskManager.

        Only the requested tables are built, so subtask YAML files are read only
        when ``list_subtasks`` is True. Entries whose yaml_path is -1 are shown
        as ``---``; task configs whose ``include:`` is a list are supported.

        Args:
            list_groups: Include the "Group / Config Location" table.
            list_tags: Include the "Tag" table.
            list_subtasks: Include the "Task / Config Location / Output Type" table.

        Returns:
            A leading newline followed by each selected table and a newline.
        """
        from pytablewriter import MarkdownTableWriter

        tasks_marker = "lm_eval/tasks/"

        # ------------------------------------------------------------------ helpers
        def sanitize_path(path: Union[str, Path]) -> str:
            # Print a relative path for anything inside lm_eval/tasks/. The POSIX
            # form lets the marker match on Windows too, and accepting Path
            # objects keeps rows working whatever type yaml_path was stored as.
            path_str = Path(path).as_posix()
            if tasks_marker in path_str:
                return tasks_marker + path_str.split(tasks_marker)[-1]
            return path_str

        def first_output_type_from_includes(cfg: dict, base: Path) -> str:
            """Walk cfg['include'] (string or list) and return the output_type of
            the first include that specifies one, or "" if none does.

            load_yaml_config resolves includes by default, so this is only a
            fallback for configs that still carry an ``include`` key."""
            inc_raw = cfg.get("include")
            if not inc_raw:
                return ""

            inc_list = inc_raw if isinstance(inc_raw, list) else [inc_raw]
            for inc in inc_list:
                if inc:
                    inc_path = Path(inc)
                    if not inc_path.is_absolute():  # treat as relative include
                        inc_path = base.parent / inc_path
                    try:
                        inc_cfg = load_yaml_config(inc_path, mode="simple")
                    except FileNotFoundError:
                        continue
                    if "output_type" in inc_cfg:
                        return inc_cfg["output_type"]
            return ""

        parts: list[str] = ["\n"]

        # -------------------------------------------------------------- GROUP table
        if list_groups:
            group_table = MarkdownTableWriter()
            group_table.headers = ["Group", "Config Location"]
            group_table.value_matrix = [
                [
                    g,
                    "---"
                    if self.task_index[g]["yaml_path"] == -1
                    else sanitize_path(self.task_index[g]["yaml_path"]),
                ]
                for g in self.all_groups
            ]
            parts.append(group_table.dumps())
            parts.append("\n")

        # ---------------------------------------------------------------- TAG table
        if list_tags:
            tag_table = MarkdownTableWriter()
            tag_table.headers = ["Tag"]
            tag_table.value_matrix = [[t] for t in self.all_tags]
            parts.append(tag_table.dumps())
            parts.append("\n")

        # ------------------------------------------------------------ SUBTASK table
        if list_subtasks:
            subtask_table = MarkdownTableWriter()
            subtask_table.headers = ["Task", "Config Location", "Output Type"]
            st_values: list[list[str]] = []

            for t in self.all_subtasks:
                raw_path = self.task_index[t]["yaml_path"]

                if raw_path == -1:
                    # python-only task or generated at runtime
                    display_path = "---"
                    output_type = ""
                else:
                    path_obj = Path(raw_path)
                    display_path = sanitize_path(path_obj)

                    # load minimal YAML (no !function imports) to discover output_type
                    cfg = load_yaml_config(path_obj, mode="simple")
                    if "output_type" in cfg:
                        output_type = cfg["output_type"]
                    else:
                        output_type = first_output_type_from_includes(cfg, path_obj)

                st_values.append([t, display_path, output_type])

            subtask_table.value_matrix = st_values
            parts.append(subtask_table.dumps())
            parts.append("\n")

        return "".join(parts)
538

Baber Abbasi's avatar
Baber Abbasi committed
539
    def match_tasks(self, task_list: list[str]) -> list[str]:
        """Expand the glob-style patterns in *task_list* against every indexed name
        using ``lm_eval.utils.pattern_match``."""
        known_names = self.all_tasks
        return pattern_match(task_list, known_names)
542

Baber Abbasi's avatar
Baber Abbasi committed
543
    def _name_is_registered(self, name: str) -> bool:
        """Check if a name is registered in the task index.

        Tests membership in the index dict (O(1)) rather than scanning the
        sorted ``all_tasks`` list (O(n) per call). This is also the same
        container that ``_get_yaml_path`` and the ``_name_is_*`` helpers index
        into afterwards.
        """
        return name in self.task_index
546

Baber Abbasi's avatar
Baber Abbasi committed
547
    def _name_is_task(self, name: str) -> bool:
        """True only for registered names of type "task" (python tasks, groups
        and tags all return False)."""
        if not self._name_is_registered(name):
            return False
        return self.task_index[name]["type"] == "task"
Lintang Sutawika's avatar
Lintang Sutawika committed
552

Baber Abbasi's avatar
Baber Abbasi committed
553
    def _name_is_tag(self, name: str) -> bool:
        """True when *name* is registered with type "tag"."""
        if not self._name_is_registered(name):
            return False
        return self.task_index[name]["type"] == "tag"
556

Baber Abbasi's avatar
Baber Abbasi committed
557
    def _name_is_group(self, name: str) -> bool:
        """True when *name* is registered with type "group"."""
        if not self._name_is_registered(name):
            return False
        return self.task_index[name]["type"] == "group"
562

Baber Abbasi's avatar
Baber Abbasi committed
563
    def _name_is_python_task(self, name: str) -> bool:
        """True when *name* is registered with type "python_task"."""
        if not self._name_is_registered(name):
            return False
        return self.task_index[name]["type"] == "python_task"
569

Baber Abbasi's avatar
Baber Abbasi committed
570
    def _config_is_task(self, config: dict) -> bool:
        """A config describes a single task when its "task" value is a string."""
        if "task" not in config:
            return False
        return isinstance(config["task"], str)
573

Baber Abbasi's avatar
Baber Abbasi committed
574
    def _config_is_group(self, config: dict) -> bool:
        """A config describes a group when its "task" value is a list."""
        if "task" not in config:
            return False
        return isinstance(config["task"], list)
577

Baber Abbasi's avatar
Baber Abbasi committed
578
    def _config_is_python_task(self, config: dict) -> bool:
        """A config describes a Python class-based task when it has a "class" key,
        whatever its value."""
        has_class_key = "class" in config
        return has_class_key

    def _config_is_task_list(self, config: dict) -> bool:
        """A config defines a task list when its "task_list" value is a list."""
        if "task_list" not in config:
            return False
        return isinstance(config["task_list"], list)
585

Baber's avatar
Baber committed
586
    def _get_yaml_path(self, name: str) -> Union[str, int]:
        """
        Get the YAML file path for a registered task.

        Args:
            name: Task name

        Returns:
            Path to YAML file, or -1 for Python-only tasks

        Raises:
            ValueError: If task name is not registered
        """
        if name not in self.task_index:
            raise ValueError(f"{name!r} is not a registered task, group or tag")
        return self.task_index[name]["yaml_path"]

Baber's avatar
nit  
Baber committed
603
    def _get_config(self, name: str) -> dict:
        """
        Load the full configuration for a registered task.

        Args:
            name: Task name

        Returns:
            Complete task configuration dictionary (``!function`` tags imported,
            includes merged), or ``{}`` when the entry has no YAML file
            (yaml_path == -1).

        Raises:
            ValueError: If task name is not registered
        """
        if name not in self.task_index:
            raise ValueError(f"{name!r} is not a registered task, group or tag")

        yaml_path = self._get_yaml_path(name)
        if yaml_path == -1:
            return {}
        return load_yaml_config(Path(yaml_path), mode="full")
623

Baber's avatar
nit  
Baber committed
624
    def _get_tasklist(self, name: str) -> Union[list[str], int]:
        """
        Get the task list for a group or tag.

        Args:
            name: Group or tag name

        Returns:
            The index entry's "task" value: the list of task names in the
            group/tag

        Raises:
            ValueError: If name refers to an individual task
            KeyError: If name is not in the index at all
        """
        if self._name_is_task(name):
            raise ValueError(f"{name!r} is an individual task and has no task list")
        return self.task_index[name]["task"]

641
642
643
644
645
    def _register_task(
        self,
        task_name: str,
        task_type: str,
        yaml_path: Union[str, int],
        tasks_and_groups: dict[str, dict],
        config: Optional[dict] = None,
        populate_tags_fn: Optional[Callable[[dict, str, dict[str, dict]], Any]] = None,
    ) -> None:
        """
        Record *task_name* in *tasks_and_groups* with its type and YAML path.

        Args:
            task_name: Name to register (an existing entry is overwritten).
            task_type: Entry type, e.g. "task", "python_task", "group" or "tag".
            yaml_path: Path of the defining YAML file, or -1 when there is none.
            tasks_and_groups: Index dict being built; mutated in place.
            config: Parsed config of the entry, forwarded to *populate_tags_fn*.
            populate_tags_fn: Called as ``populate_tags_fn(config, task_name,
                tasks_and_groups)`` for non-group entries when both it and a
                non-empty *config* are given (presumably to index the tags the
                config declares).
        """
        tasks_and_groups[task_name] = {
            "type": task_type,
            "yaml_path": yaml_path,
        }
        # Only populate tags for configs that support it (not groups)
        if config and task_type != "group" and populate_tags_fn:
            populate_tags_fn(config, task_name, tasks_and_groups)
658
659

    def _merge_task_configs(
        self, base_config: dict, task_specific_config: dict, task_name: str
    ) -> dict:
        """
        Combine one task_list entry with the shared base config.

        Keys from *task_specific_config* (except its own "task" key, which is
        dropped) override *base_config*. The result's "task" is always
        *task_name*. Neither input dict is modified.
        """
        overrides: dict = {}
        if task_specific_config:
            overrides = task_specific_config.copy()
            overrides.pop("task", None)
        return {**base_config, **overrides, "task": task_name}

Baber's avatar
Baber committed
669
    def _process_tag_subtasks(
        self, tag_name: str, update_config: Optional[dict] = None
    ) -> dict:
        """
        Load every subtask listed under *tag_name* and merge the results.

        Subtasks are loaded from the last listed to the first. When two results
        share a key, the value from the subtask listed later in the tag wins.
        """
        subtask_names = self._get_tasklist(tag_name)
        loaded = [
            self._load_individual_task_or_group(sub, update_config=update_config)
            for sub in reversed(subtask_names)
        ]
        # ChainMap gives precedence to its first map, i.e. the last-listed subtask.
        return dict(collections.ChainMap(*loaded))

Baber's avatar
nit  
Baber committed
680
    def _process_alias(self, config: dict, group: Optional[str] = None) -> dict:
        """
        Drop a group alias that was written for a different group.

        When *group* is given and *config* carries both "group" and
        "group_alias", but ``config["group"]`` differs from *group*, the alias
        is set to None. *config* is modified in place and also returned.

        Args:
            config: Task configuration dictionary
            group: Group name the alias must belong to

        Returns:
            The same (possibly modified) configuration dict
        """
        alias_present = "group_alias" in config and "group" in config
        if group is not None and alias_present and config["group"] != group:
            config["group_alias"] = None
        return config

Baber's avatar
Baber committed
699
    def _class_has_config_in_constructor(self, cls) -> bool:
        """
        Report whether ``cls.__init__`` declares a parameter named ``config``.

        Args:
            cls: Class to inspect

        Returns:
            True when the constructor's signature lists ``config``. False when
            it does not, or when ``cls`` has no truthy ``__init__`` attribute.
        """
        init = getattr(cls, "__init__", None)
        if not init:
            return False
        params = inspect.signature(init).parameters
        return "config" in params

716
717
718
719
720
    ###############################################################################
    # NEW: Refactored _load_individual_task_or_group and helper methods          #
    ###############################################################################

    def _create_task_object(
        self,
        cfg: dict,
        task_name: str,
        yaml_path: Union[str, None],
    ) -> dict:
        """
        Instantiate a single task (ConfigurableTask **or** python-task) from *cfg*.

        *cfg* itself is not modified: the method works on a shallow copy, so
        dicts handed out by the YAML caches or owned by the caller stay intact.

        Args:
            cfg: Task configuration. An ``include`` key is resolved relative to
                *yaml_path*; a ``class`` key selects a Python task class.
            task_name: Name the task is registered under.
            yaml_path: Path of the YAML file *cfg* came from, or None.

        Returns:
            ``{task_name: task_object}``.
        """
        from lm_eval.api.task import ConfigurableTask, Task  # local import avoids cycle

        cfg = dict(cfg)  # never mutate the caller's (possibly cached) dict

        # ---- include handling ---------------------------------------------------
        if "include" in cfg:
            # keep original name so include merging cannot clobber it
            orig_name = cfg.get("task", task_name)
            include = cfg.pop("include")
            # NOTE(review): with yaml_path=None, load_yaml_config has no base
            # directory for relative includes and rejects the call. Confirm that
            # callers always pass yaml_path when cfg carries an include.
            included = load_yaml_config(
                yaml_path=Path(yaml_path) if yaml_path else None,
                yaml_config={"include": include},
                mode="full" if yaml_path else "simple",
            )
            cfg = {**included, **cfg, "task": orig_name}

        # ---- metadata merge -----------------------------------------------------
        if self.metadata is not None:
            # `or {}` also covers an explicit `metadata: null` in the YAML.
            cfg["metadata"] = (cfg.get("metadata") or {}) | self.metadata
        else:
            cfg["metadata"] = cfg.get("metadata", {})

        # ---- python-task vs YAML-task -------------------------------------------
        if self._config_is_python_task(cfg):
            cls = cfg["class"]
            task_obj: Task
            if self._class_has_config_in_constructor(cls):
                task_obj = cls(config=cfg)
            else:
                task_obj = cls()
            # make sure name propagates when the class inherits ConfigurableTask
            if isinstance(task_obj, ConfigurableTask):  # type: ignore
                task_obj.config.task = task_name
        else:
            task_obj = ConfigurableTask(config=cfg)  # type: ignore

        return {task_name: task_obj}
Baber's avatar
Baber committed
767

768
769
770
    def _create_group_object(
        self,
        cfg: dict,
        parent_name: Union[str, None] = None,
    ) -> tuple[GroupConfig, list[Union[str, dict]]]:
        """
        Build a GroupConfig from *cfg* and collect its subtask entries.

        Args:
            cfg: Group configuration (keyword arguments for GroupConfig).
                It is not mutated.
            parent_name: Currently unused by this method.

        Returns:
            ``(group_obj, subtasks)``. Any string entry of ``group_obj.task``
            that names a registered tag is replaced by that tag's task list
            (via ``_get_tasklist``); every other entry is kept as-is.
        """
        if self.metadata is not None:
            # Build a new dict instead of writing into the caller's cfg;
            # `or {}` tolerates an explicit `metadata: null`.
            cfg = {**cfg, "metadata": (cfg.get("metadata") or {}) | self.metadata}

        grp = GroupConfig(**cfg)

        members = grp.task
        if isinstance(members, str):
            # a single-task group (`task: name`) would otherwise be iterated
            # character by character
            members = [members]

        subtasks: list[Union[str, dict]] = []
        for entry in members:
            if isinstance(entry, str) and self._name_is_tag(entry):
                subtasks.extend(self._get_tasklist(entry))
            else:
                subtasks.append(entry)
        return grp, subtasks
Baber's avatar
Baber committed
788

789
790
791
    def _load_subtasks(
        self,
        subtasks: list[Union[str, dict]],
        parent_name: Union[str, GroupConfig, None],
        update_config: Union[dict, None],
    ) -> Mapping:
        """
        Load every entry of *subtasks* and merge the results into one dict.

        Entries are loaded in reverse list order. The merge goes through
        ``collections.ChainMap`` over those reversed results. As a result, a key
        produced by a later subtask wins over the same key from an earlier one.
        The returned dict still lists keys in first-appearance order of
        *subtasks*.

        Args:
            subtasks: Task/group/tag names or inline config dicts.
            parent_name: Forwarded to ``_load_individual_task_or_group``.
            update_config: Forwarded to ``_load_individual_task_or_group``.

        Returns:
            A plain dict of all merged loaded mappings.
        """
        load_one = functools.partial(
            self._load_individual_task_or_group,
            parent_name=parent_name,
            update_config=update_config,
        )
        loaded_mappings = [load_one(entry) for entry in reversed(subtasks)]
        return dict(collections.ChainMap(*loaded_mappings))
Baber's avatar
Baber committed
802

803
804
    def _load_individual_task_or_group(
        self,
        payload: Union[str, dict],
        *,
        parent_name: Union[str, None] = None,
        update_config: Union[dict, None] = None,
    ) -> Mapping:
        """
        Turn *payload* (a task/group/tag name **or** a config dict) into a
        nested Mapping of ``{name_or_group_obj: task_obj | sub_mapping}``.

        Args:
            payload: A registered task, python-task, group or tag name, or a
                config dict. Dicts are recognised, in order, as:

                - a ``task: <name>`` config, where *name* may also be a
                  registered group or tag;
                - a literal group config (``task: [...]``);
                - a python-class task.
            parent_name: Enclosing group (name or GroupConfig) when loading a
                subtask. It is used to de-duplicate repeated task names inside
                that group.
            update_config: Overrides for *payload*. A string payload with
                overrides is re-dispatched as ``{"task": payload, **update_config}``.

        Returns:
            ``{task_name: task_obj}`` for tasks, ``{group_obj: {...}}`` for
            groups, or the result of ``_process_tag_subtasks`` for tags.

        Raises:
            ValueError: *payload* is a string that is not a registered name.
            TypeError: *payload* is neither a str nor a recognised dict config.
        """

        def _select_task_list_entry(cfg: dict, task_name: str) -> dict:
            # For a `task_list` config, merge the shared top-level keys with the
            # entry whose "task" equals *task_name* (entry keys win), drop
            # "task_list" and pin "task". Any other config is returned as a
            # shallow copy, so the original dict is never passed on to be mutated.
            if "task_list" not in cfg:
                return dict(cfg)
            entry = next(
                (
                    e
                    for e in cfg["task_list"]
                    if isinstance(e, dict) and e.get("task") == task_name
                ),
                {},
            )
            base = {k: v for k, v in cfg.items() if k != "task_list"}
            return {**base, **entry, "task": task_name}

        def _load_named_group(name: str, overrides: Union[dict, None]) -> Mapping:
            # Registered group. GROUP_ONLY_KEYS of (group YAML + overrides) build
            # the GroupConfig; the remaining override keys go to every subtask.
            overrides = overrides or {}
            merged_cfg = {**self._get_config(name), **overrides}
            grp_cfg = {k: v for k, v in merged_cfg.items() if k in GROUP_ONLY_KEYS}
            sub_override = {
                k: v for k, v in overrides.items() if k not in GROUP_ONLY_KEYS
            } or None
            grp_obj, subtasks = self._create_group_object(grp_cfg, parent_name)
            return {grp_obj: self._load_subtasks(subtasks, grp_obj, sub_override)}

        # ------------------------------------------------------------------ STRING
        if isinstance(payload, str):
            # If caller supplied extra overrides, treat as dict immediately
            if update_config:
                return self._load_individual_task_or_group(
                    {"task": payload, **update_config},
                    parent_name=parent_name,
                )

            # ------------ registered TASK (YAML or python) -----------------
            if self._name_is_task(payload) or self._name_is_python_task(payload):
                yaml_path = self._get_yaml_path(payload)
                cfg = _select_task_list_entry(self._get_config(payload), payload)
                return self._create_task_object(cfg, payload, yaml_path)

            # ------------ registered GROUP ----------------------------------
            if self._name_is_group(payload):
                return _load_named_group(payload, None)

            # ------------ registered TAG ------------------------------------
            if self._name_is_tag(payload):
                return self._process_tag_subtasks(payload, update_config=None)

            raise ValueError(f"Unknown task / group / tag name: {payload!r}")

        # ------------------------------------------------------------------- DICT
        if isinstance(payload, dict):
            # ------------------ 'task: name' dict ---------------------------
            if self._config_is_task(payload):
                name = payload["task"]
                overrides = {k: v for k, v in payload.items() if k != "task"}

                # A registered group/tag referenced by name (this is also how a
                # string payload with update_config arrives here) must expand
                # into its members instead of being built as a single task.
                if self._name_is_group(name):
                    return _load_named_group(name, overrides)
                if self._name_is_tag(name):
                    return self._process_tag_subtasks(
                        name, update_config=overrides or None
                    )

                # override existing registered YAML if exists
                if self._name_is_registered(name):
                    base_cfg = _select_task_list_entry(self._get_config(name), name)
                    yaml_path = self._get_yaml_path(name)
                    merged = {**base_cfg, **payload}
                else:
                    merged = dict(payload)
                    yaml_path = None

                # duplicate-naming guard when inside a group.
                # NOTE: this is a prefix match, so e.g. "arc" also counts an
                # existing "arc_easy" in the same group.
                if parent_name is not None:
                    count = len(
                        [
                            n
                            for n in self.task_group_map[parent_name]
                            if n.startswith(name)
                        ]
                    )
                    if count:
                        name = f"{name}-{count}"
                    self.task_group_map[parent_name].append(name)

                return self._create_task_object(merged, name, yaml_path)

            # ----------------- literal group dict (task: [...]) -------------
            if self._config_is_group(payload):
                grp_cfg = {k: v for k, v in payload.items() if k in GROUP_ONLY_KEYS}
                sub_override = {
                    k: v for k, v in payload.items() if k not in GROUP_ONLY_KEYS
                } or None
                grp_obj, subtasks = self._create_group_object(grp_cfg, parent_name)
                return {grp_obj: self._load_subtasks(subtasks, grp_obj, sub_override)}

            # ----------------- python-task dict ('class': …) ----------------
            if self._config_is_python_task(payload):
                name = payload["task"]
                return self._create_task_object(dict(payload), name, yaml_path=None)

            raise TypeError(
                "_load_individual_task_or_group could not interpret config dict "
                f"with keys {list(payload)}: expected a 'task' name, a group "
                "config, or a python-task 'class'"
            )

        raise TypeError(
            f"_load_individual_task_or_group expected str | dict, got {type(payload)}"
        )
905

Baber's avatar
Baber committed
906
    def load_task_or_group(
        self,
        task_list: Optional[Union[str, dict, list[Union[str, dict]]]] = None,
    ) -> dict:
        """
        Load one or more tasks, groups, or tags and merge the results.

        This is the main entry point for loading tasks. Each item is handed to
        ``_load_individual_task_or_group`` and the resulting mappings are
        merged with ``collections.ChainMap``. When two items produce the same
        key, the value from the earlier item is kept.

        Args:
            task_list: A task/group/tag name, a single inline config dict, or
                a list of names and/or config dicts. ``None`` (the default) and
                an empty list load nothing.

        Returns:
            Dictionary mapping task names (and group objects) to loaded task
            objects or nested group mappings.

        Example:
            Load multiple tasks::

                tasks = tm.load_task_or_group(["hellaswag", "arc_easy"])
                # Returns: {"hellaswag": Task1, "arc_easy": Task2}

            Load a group (the key is the group's GroupConfig object)::

                tasks = tm.load_task_or_group("arc_group")
                # Returns: {<arc_group>: {"arc_easy": Task1, "arc_challenge": Task2}}
        """
        if task_list is None:
            return {}
        if isinstance(task_list, (str, dict)):
            # a bare name or a single inline config; iterating a dict directly
            # would walk its keys as if they were task names
            task_list = [task_list]

        loaded = [self._load_individual_task_or_group(item) for item in task_list]
        return dict(collections.ChainMap(*loaded))

Baber's avatar
nit  
Baber committed
948
    def load_config(self, config: dict) -> Mapping:
        """
        Load a task (or group) described by an inline configuration dict.

        This delegates directly to ``_load_individual_task_or_group`` with
        *config* as the payload and no parent group or overrides.

        Args:
            config: Inline configuration dictionary, e.g.
                ``{"task": "hellaswag", "num_fewshot": 5}``.

        Returns:
            Mapping of task name (or group object) to the loaded task object
            or nested group mapping.

        Example:
            >>> config = {"task": "hellaswag", "num_fewshot": 5}
            >>> task_dict = tm.load_config(config)
        """
        loaded = self._load_individual_task_or_group(payload=config)
        return loaded

Baber's avatar
nit  
Baber committed
964
    def _get_task_and_group(self, task_dir: Union[str, Path]) -> dict[str, dict]:
        """
        Scan a directory for task configurations and build an index.

        Every YAML file yielded by ``iter_yaml_files`` is parsed in ``"simple"``
        mode without resolving includes. It is then registered according to
        its shape:

        - python class-based tasks (``class`` key) as ``"python_task"``
        - group configs (``group`` key) as ``"group"``
        - regular task configs (``task`` key) as ``"task"``
        - each dict entry of a ``task_list`` config as ``"task"``
        - tags listed under ``tag`` as ``"tag"`` entries

        Files that fail to load, or match none of the shapes above, are
        skipped with a debug log.

        Args:
            task_dir: Directory path to scan for YAML task configurations.

        Returns:
            Dictionary mapping names to metadata dictionaries with:

            - ``'type'``: one of ``'task'``, ``'python_task'``, ``'group'``, ``'tag'``
            - ``'yaml_path'``: path of the source YAML file (``-1`` for tags)
            - ``'task'``: for tags, the list of member task names; for groups,
              ``-1`` (members are resolved when the group is loaded)
        """

        def _populate_tags_and_groups(
            config: dict, task: str, tasks_and_groups: dict[str, dict]
        ) -> None:
            """
            Register *task* under every tag listed in ``config["tag"]``.

            ``config["tag"]`` may be a single string or a list of strings. If a
            tag name is already registered as something other than a tag, that
            tag is skipped with an info log. The remaining tags are still
            processed.

            Args:
                config: Task configuration dictionary.
                task: Name of the task being processed.
                tasks_and_groups: Master index to update with tag information.
            """
            if "tag" not in config:
                return
            tags = config["tag"]
            if isinstance(tags, str):
                tags = [tags]

            for tag in tags:
                if tag not in tasks_and_groups:
                    tasks_and_groups[tag] = {
                        "type": "tag",
                        "task": [task],
                        "yaml_path": -1,
                    }
                elif tasks_and_groups[tag]["type"] != "tag":
                    eval_logger.info(
                        f"The tag '{tag}' is already registered as a group, this tag will not be registered. "
                        "This may affect tasks you want to call."
                    )
                    # skip only this tag; a `break` here would silently drop
                    # every tag listed after it as well
                    continue
                else:
                    tasks_and_groups[tag]["task"].append(task)

        tasks_and_groups: dict[str, dict] = {}

        for yaml_path in iter_yaml_files(Path(task_dir)):
            try:
                config = load_yaml_config(
                    yaml_path, mode="simple", resolve_includes=False
                )
            except (FileNotFoundError, YAMLError, OSError) as err:
                eval_logger.debug(f"File {yaml_path} could not be loaded ({err})")
                continue

            if self._config_is_python_task(config):
                # This is a python class config
                self._register_task(
                    config["task"],
                    "python_task",
                    str(yaml_path),
                    tasks_and_groups,
                    config,
                    _populate_tags_and_groups,
                )
            elif self._config_is_group(config):
                # The member list is not needed for indexing: it is resolved
                # when the group is loaded, so -1 is stored as a placeholder.
                tasks_and_groups[config["group"]] = {
                    "type": "group",
                    "task": -1,
                    "yaml_path": str(yaml_path),
                }
            elif self._config_is_task(config):
                # This is a task config
                self._register_task(
                    config["task"],
                    "task",
                    str(yaml_path),
                    tasks_and_groups,
                    config,
                    _populate_tags_and_groups,
                )
            elif self._config_is_task_list(config):
                # Each dict entry with a "task" key becomes its own task; the
                # whole config is passed along for tag registration.
                for task_entry in config["task_list"]:
                    if isinstance(task_entry, dict) and "task" in task_entry:
                        self._register_task(
                            task_entry["task"],
                            "task",
                            str(yaml_path),
                            tasks_and_groups,
                            config,
                            _populate_tags_and_groups,
                        )
            else:
                eval_logger.debug(f"File {yaml_path} could not be loaded")

        return tasks_and_groups
lintangsutawika's avatar
lintangsutawika committed
1104

1105

Baber's avatar
nit  
Baber committed
1106
def get_task_name_from_config(task_config: dict[str, str]) -> str:
    """
    Derive a task name from a configuration dictionary.

    An explicit ``task`` key wins. Otherwise the name is built from
    ``dataset_path``, joined with ``dataset_name`` by an underscore when a
    ``dataset_name`` key is present.

    Args:
        task_config: Task configuration dictionary.

    Returns:
        String name for the task.

    Raises:
        KeyError: Neither ``task`` nor ``dataset_path`` is present.

    Example:
        >>> config = {"task": "hellaswag", "num_fewshot": 5}
        >>> get_task_name_from_config(config)
        'hellaswag'

        >>> config = {"dataset_path": "custom", "dataset_name": "mytask"}
        >>> get_task_name_from_config(config)
        'custom_mytask'
    """
    if "task" in task_config:
        return task_config["task"]

    template = (
        "{dataset_path}_{dataset_name}"
        if "dataset_name" in task_config
        else "{dataset_path}"
    )
    return template.format(**task_config)
lintangsutawika's avatar
lintangsutawika committed
1134

1135

Baber's avatar
nit  
Baber committed
1136
def get_task_name_from_object(task_object: Union["ConfigurableTask", "Task"]) -> str:
    """
    Extract the name from an instantiated task object.

    Lookup order:

    1. ``task_object.config["task"]``, when the object has a non-``None``
       ``config`` whose ``task`` entry is not ``None``.
    2. The ``EVAL_HARNESS_NAME`` attribute, if defined.
    3. The class name.

    Args:
        task_object: An instantiated task object.

    Returns:
        String name of the task.

    Example:
        >>> task = ConfigurableTask(config={"task": "hellaswag"})
        >>> get_task_name_from_object(task)
        'hellaswag'
    """
    # Read through the same attribute that is probed (`config`); previously the
    # check used `config` but the read used `_config`.
    config = getattr(task_object, "config", None)
    if config is not None:
        name = config["task"]
        # A task built without a config has task=None; fall through to the
        # legacy naming below instead of reporting the name as None.
        if name is not None:
            return name

    # TODO: scrap this
    # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
    return getattr(task_object, "EVAL_HARNESS_NAME", type(task_object).__name__)

1165

Baber's avatar
nit  
Baber committed
1166
def _check_duplicates(task_dict: dict[str, list[str]]) -> None:
Baber's avatar
Baber committed
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
    """
    Validate that no tasks appear in multiple groups simultaneously.

    Helper function used to prevent conflicts when multiple groups claim
    the same constituent task. This could lead to ambiguous configuration
    like conflicting num_fewshot values.

    Args:
        task_dict: Dictionary mapping group names to lists of subtask names

    Raises:
        ValueError: If any tasks appear in multiple groups

    Example:
        >>> task_dict = {
        ...     "group1": ["task_a", "task_b"],
        ...     "group2": ["task_b", "task_c"]  # task_b appears twice!
        ... }
        >>> _check_duplicates(task_dict)  # Raises ValueError
Lintang Sutawika's avatar
Lintang Sutawika committed
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
    """
    subtask_names = []
    for key, value in task_dict.items():
        subtask_names.extend(value)

    duplicate_tasks = {
        task_name for task_name in subtask_names if subtask_names.count(task_name) > 1
    }

    # locate the potentially problematic groups that seem to 'compete' for constituent subtasks
    competing_groups = [
        group
        for group in task_dict.keys()
        if len(set(task_dict[group]).intersection(duplicate_tasks)) > 0
    ]

    if len(duplicate_tasks) > 0:
        raise ValueError(
            f"Found 1 or more tasks while trying to call get_task_dict() that were members of more than 1 called group: {list(duplicate_tasks)}. Offending groups: {competing_groups}. Please call groups which overlap their constituent tasks in separate evaluation runs."
        )


1208
def get_task_dict(
    task_name_list: Union[str, list[Union[str, dict, "Task"]]],
    task_manager: Optional[TaskManager] = None,
) -> dict[str, Union["ConfigurableTask", "Task"]]:
    """
    Create a dictionary of task objects from mixed input types.

    This is the main public API for loading tasks. It accepts various input
    formats (names, configs, objects) and returns a unified dictionary of
    instantiated task objects ready for evaluation.

    The function handles:

    - String task names, loaded via ``TaskManager.load_task_or_group``
    - Configuration dictionaries, loaded via ``TaskManager.load_config``
    - Pre-instantiated Task objects, used as-is
    - Validation to prevent duplicate names and conflicting group memberships

    Args:
        task_name_list: A single task name, or a mixed list of task
            specifications:

            - str: task/group/tag name to look up
            - dict: inline task configuration
            - Task: pre-instantiated task object
        task_manager: TaskManager instance for name resolution. If None, a
            default TaskManager is created.

    Returns:
        Dictionary mapping task names (and group objects, for groups) to
        instantiated task objects or nested group mappings.

    Raises:
        TypeError: If task_name_list, or any of its items, has an unsupported type.
        ValueError: If two specs produce the same name, or if groups have
            conflicting memberships.

    Example:
        Mixed input types::

            tasks = get_task_dict([
                "hellaswag",                              # lookup by name
                {"task": "arc_easy", "num_fewshot": 5},   # inline config
                pre_existing_task_object                  # direct object
            ])

        Simple case::

            tasks = get_task_dict("hellaswag")
            # Returns: {"hellaswag": ConfigurableTask(...)}

        With custom TaskManager::

            tm = TaskManager(include_path="/custom/tasks")
            tasks = get_task_dict(["custom_task"], task_manager=tm)
    """
    from lm_eval.api.task import Task

    # Normalize input to list
    if isinstance(task_name_list, str):
        task_name_list = [task_name_list]
    elif not isinstance(task_name_list, list):
        raise TypeError(
            f"Expected a 'str' or 'list' but received {type(task_name_list)}."
        )

    # Validate list items
    if not all(isinstance(task, (str, dict, Task)) for task in task_name_list):
        raise TypeError(
            "Expected all list items to be of types 'str', 'dict', or 'Task', but at least one entry did not match."
        )

    # Ensure we have a task manager
    if task_manager is None:
        task_manager = TaskManager()

    final_task_dict = {}

    def _add(name, task_obj) -> None:
        # Reject any name already produced by an earlier spec.
        if name in final_task_dict:
            raise ValueError(f"Duplicate task name: {name}")
        final_task_dict[name] = task_obj

    for task_spec in task_name_list:
        if isinstance(task_spec, Task):
            # Pre-instantiated task object
            _add(get_task_name_from_object(task_spec), task_spec)
            continue

        if isinstance(task_spec, dict):
            # Inline config: must be loaded as ONE config. Handing a dict to
            # load_task_or_group would iterate its keys as task names.
            result = task_manager.load_config(config=task_spec)
        else:
            result = task_manager.load_task_or_group(task_spec)

        for name, task_obj in result.items():
            _add(name, task_obj)

    # Check for conflicting group memberships
    _check_duplicates(get_subtask_list(final_task_dict))

    return final_task_dict