__init__.py 45.6 KB
Newer Older
1
2
3
4
# ruff: noqa E402
from __future__ import annotations


Baber's avatar
Baber committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
"""
Task Management Module for LM Evaluation Harness.

This module provides comprehensive task discovery, loading, and management functionality
for the LM Evaluation Harness. It handles YAML configuration parsing with include support,
dynamic function importing, and task indexing across multiple directories.

Key Components:
- TaskManager: Main class for task discovery and management
- YAML configuration loading with !function tag support
- Task, group, and tag indexing
- Include resolution with cycle detection
- Caching for performance optimization

Example:
    Basic usage::

        task_manager = TaskManager()
        all_tasks = task_manager.all_tasks
        task_config = task_manager._get_config("hellaswag")

    Custom task paths::

        task_manager = TaskManager(
            include_path="/path/to/custom/tasks",
            include_defaults=True
        )
"""
33
import collections
34
35
import functools
import importlib.util
36
import inspect
37
import logging
38
import sys
39
from functools import partial
Baber's avatar
Baber committed
40
from pathlib import Path
Baber's avatar
nit  
Baber committed
41
42
43
44
45
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
)
46
47
48

import yaml
from yaml import YAMLError
&'s avatar
& committed
49

50
from lm_eval.api.group import GroupConfig
Lintang Sutawika's avatar
Lintang Sutawika committed
51
from lm_eval.evaluator_utils import get_subtask_list
Baber's avatar
nit  
Baber committed
52
53
54
55
from lm_eval.utils import pattern_match, setup_logging


if TYPE_CHECKING:
56
57
    from collections.abc import Generator, Mapping

Baber's avatar
nit  
Baber committed
58
    from lm_eval.api.task import ConfigurableTask, Task
Lintang Sutawika's avatar
Lintang Sutawika committed
59

Baber's avatar
nit  
Baber committed
60
# Module-level logger, named after this module.
eval_logger = logging.getLogger(__name__)

#: Configuration keys that exist on a default-constructed GroupConfig,
#: i.e. keys that are specific to groups only.
GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys())

#: Base YAML loader class: the C loader when PyYAML reports libyaml support,
#: otherwise the pure-Python FullLoader.
if getattr(yaml, "__with_libyaml__", False):
    _Base = yaml.CLoader
else:
    _Base = yaml.FullLoader

#: Directory names to ignore during task discovery
_IGNORE_DIRS = ("__pycache__", ".ipynb_checkpoints")
73

lintangsutawika's avatar
lintangsutawika committed
74

75
76
def _mk_function_ctor(base_dir: Path, resolve: bool):
    """Return a constructor that resolves !function relative to *base_dir*."""
77

78
79
80
81
82
    def ctor(loader: yaml.Loader, node: yaml.Node):
        spec = loader.construct_scalar(node)
        if not resolve:  # “simple” mode → stub
            return lambda *a, **kw: None
        return _import_function(spec, base_dir)
83

84
85
86
87
88
89
90
    return ctor


@functools.lru_cache(maxsize=1024)
def make_yaml_loader(base_dir: Path, *, simple: bool) -> type[yaml.Loader]:
    """Return a PyYAML Loader subclass whose ``!function`` tag is bound to *base_dir*.

    The result is memoised per ``(base_dir, simple)`` by ``lru_cache``.

    Args:
        base_dir: Directory passed to the ``!function`` constructor.
        simple: When True, ``!function`` yields a stub instead of the real
            callable (used when only metadata is needed).

    Returns:
        A fresh subclass of the module's base loader with the constructor
        registered on it.
    """

    class Loader(_Base):
        """Dynamic subclass that only exists to carry custom constructors."""

    function_ctor = _mk_function_ctor(base_dir, resolve=not simple)
    yaml.add_constructor("!function", function_ctor, Loader=Loader)
    return Loader


104
105
106
107
108
@functools.lru_cache(maxsize=4096)
def _read_yaml(path: Path, *, resolve_functions: bool) -> dict:
    """Parse the YAML file at *path* with a loader bound to its directory.

    Results are memoised per ``(path, resolve_functions)``; because of the
    cache, repeated calls hand back the very same object.

    Args:
        path: YAML file to read.
        resolve_functions: When False, ``!function`` tags load as stubs.

    Returns:
        Whatever ``yaml.load`` produces for the file.
    """
    yaml_loader = make_yaml_loader(path.parent, simple=not resolve_functions)
    with path.open("rb") as stream:
        parsed = yaml.load(stream, Loader=yaml_loader)
    return parsed
Baber's avatar
Baber committed
109
110


111
112
113
114
115
116
117
@functools.cache
def _import_function(qual: str, base_dir: Path):
    """Import `qual` where qual looks like  "my_utils.some_fn".
    Search order:
      1. <base_dir>/my_utils.py            (relative file)
      2. python importlib (package/module already importable)
    Uses file *mtime* so edits are reloaded without killing the process.
Baber's avatar
Baber committed
118
    """
119
    import importlib
120

121
122
123
    if "." not in qual:
        msg = f"!function value '{qual}' must contain a '.'"
        raise ValueError(msg)
Baber's avatar
Baber committed
124

125
126
    mod_part, _, fn_name = qual.rpartition(".")
    relative_path = (base_dir / f"{mod_part.replace('.', '/')}.py").resolve()
127

128
129
130
131
132
133
134
135
136
137
138
    if relative_path.exists():
        mtime = relative_path.stat().st_mtime_ns  # for cache busting
        module_key = f"{relative_path}:{mtime}"
        if module_key in sys.modules:
            mod = sys.modules[module_key]
        else:
            spec = importlib.util.spec_from_file_location(module_key, relative_path)
            mod = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(mod)  # type: ignore[arg-type]
            sys.modules[module_key] = mod
        return getattr(mod, fn_name)
139

140
141
    # Fallback to regular import mechanism
    import importlib
Baber's avatar
nit  
Baber committed
142

143
144
    module = importlib.import_module(mod_part)
    return getattr(module, fn_name)
Baber's avatar
nit  
Baber committed
145
146


147
def load_yaml_config(
    yaml_path: Path | str,
    *,
    resolve_functions: bool = True,
    resolve_includes: bool = True,
    _seen: set[tuple[Path, bool]] | None = None,
) -> dict:
    """Read a YAML config, optionally merging its ``include:`` chain.

    Included files are merged in the order listed (later includes override
    earlier ones) and the including file's own keys override all of them.

    Args:
        yaml_path: File to load; ``~`` is expanded and the path resolved.
        resolve_functions: When False, ``!function`` tags load as stubs.
        resolve_includes: When False, the file is returned as parsed, with
            any ``include`` key left in place.
        _seen: Internal. The ``(path, resolve_functions)`` pairs currently
            being expanded, used to detect include cycles.

    Returns:
        A new top-level dict. Nested values are shared with the parse cache,
        so callers should not mutate them in place.

    Raises:
        ValueError: If a file (transitively) includes itself.
    """
    path = Path(yaml_path).expanduser().resolve()
    if _seen is None:
        _seen = set()

    key = (path, resolve_functions)
    if key in _seen:
        msg = f"Include cycle at {path}"
        raise ValueError(msg)

    # _read_yaml is lru_cached and returns the same object every time; copy
    # the top level so popping "include" below cannot corrupt the cache.
    # An empty YAML file parses to None, which is treated as an empty config.
    cfg = dict(_read_yaml(path, resolve_functions=resolve_functions) or {})

    if not resolve_includes or "include" not in cfg:
        return cfg

    includes = cfg.pop("include")
    if not isinstance(includes, (list, tuple)):
        # `include: some_file.yaml` (a bare string) is a single include
        includes = [includes]

    base_dir = path.parent
    merged: dict = {}
    # Track only the files on the current include path, so that two siblings
    # including the same file (a "diamond") are not mistaken for a cycle.
    _seen.add(key)
    try:
        for inc in includes:
            if inc is None:  # guard against explicit nulls
                continue
            inc_path = Path(inc)
            if not inc_path.is_absolute():
                inc_path = (base_dir / inc_path).resolve()
            merged.update(
                load_yaml_config(
                    inc_path,
                    resolve_functions=resolve_functions,
                    _seen=_seen,
                ),
            )
    finally:
        _seen.discard(key)
    merged.update(cfg)  # local keys win
    return merged


# def load_yaml_config(
#     yaml_path: Union[Path, str, None] = None,
#     yaml_config: Optional[dict] = None,
#     yaml_dir: Optional[Path] = None,
#     mode: str = "full",
#     *,
#     _seen: Optional[set[tuple[Path, str]]] = None,
#     resolve_includes: bool = True,
# ) -> dict:
#     """
#     Parse a YAML config with optional include handling.
#
#     Parameters
#     ----------
#     yaml_path
#         Path to the main YAML file.  Needed unless *yaml_config* is
#         supplied directly (e.g. by tests).
#     yaml_config
#         Pre-parsed dict to use instead of reading *yaml_path*.
#     yaml_dir
#         Base directory for resolving relative include paths.  Defaults
#         to `yaml_path.parent`.
#     mode
#         "full"  - honour  !function  tags
#         "simple" - ignore !function  (faster).
#     _seen
#         **Internal** recursion set: tuples of (absolute-path, mode).
#         Prevents include cycles such as  A → B → A.
#     """
#     if yaml_config is None and yaml_path is None:
#         raise ValueError("load_yaml_config needs either yaml_path or yaml_config")
#
#     # ------------------------------------------------------------------ cycle guard
#     if _seen is None:
#         _seen = set()
#     if yaml_path is not None:
#         yaml_path = Path(yaml_path).expanduser().resolve()
#
#         # ---------- fast-path: use LRU cached function ----------
#         if yaml_config is None and resolve_includes:
#             return _get_cached_config(yaml_path, mode)
#
#         key = (yaml_path.resolve(), mode)
#         if key in _seen:
#             raise ValueError(f"Include cycle detected at {yaml_path}")
#         _seen.add(key)
#
#     # ------------------------------------------------------------------ load / parse
#     if yaml_config is None:  # ordinary path-based load
#         yaml_config = _parse_yaml_file(yaml_path, mode)
#
#     if yaml_dir is None and yaml_path is not None:
#         yaml_dir = yaml_path.parent
#     assert yaml_dir is not None, "yaml_dir must be set by caller or deduced from path"
#
#     # ------------------------------------------------------------------ handle include
#     include = yaml_config.pop("include", None)
#     if not include and not resolve_includes:
#         return yaml_config
#
#     include_paths = include if isinstance(include, list) else [include]
#     final_cfg: dict = {}
#
#     for inc in reversed(include_paths):
#         if inc is None:  # guard against explicit nulls
#             continue
#         inc_path = Path(inc)
#         if not inc_path.is_absolute():
#             inc_path = (yaml_dir / inc_path).resolve()
#         included = load_yaml_config(
#             yaml_path=inc_path,
#             mode=mode,
#             yaml_dir=inc_path.parent,
#             _seen=_seen,  # <-- pass set downward
#         )
#         final_cfg.update(included)
#
#     final_cfg.update(yaml_config)  # local keys win
#     return final_cfg
265
266


267
def iter_yaml_files(root: Path, ignore=_IGNORE_DIRS) -> Generator[Path, Any, None]:
    """Yield every ``*.yaml`` file found recursively below *root*.

    A file is skipped when any component of its path (as produced by the
    glob, so including the components of *root* itself) appears in *ignore*.

    Args:
        root: Root directory to search for YAML files
        ignore: Path component names that disqualify a file; defaults to the
            module's ignored-directory names.

    Yields:
        Path objects for each discovered YAML file

    Example:
        >>> for yaml_file in iter_yaml_files(Path("tasks")):
        ...     print(f"Found task config: {yaml_file}")

    """
    for candidate in root.glob("**/*.yaml"):
        yaml_file = Path(candidate)
        is_ignored = any(component in ignore for component in yaml_file.parts)
        if not is_ignored:
            yield yaml_file
Lintang Sutawika's avatar
Lintang Sutawika committed
291

292

293
class TaskManager:
294
    """Central manager for task discovery, indexing, and loading.
Baber's avatar
Baber committed
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310

    TaskManager scans directories for YAML task configurations and maintains
    an index of all available tasks, groups, and tags. It provides methods
    for listing, filtering, and loading tasks with their configurations.

    The manager supports:
    - Automatic discovery from default lm_eval/tasks/ directory
    - Custom task directories via include_path
    - Task grouping and tagging
    - Configuration inheritance via YAML includes
    - Caching for performance

    Attributes:
        include_path: Additional directories to search for tasks
        metadata: Global metadata to inject into all task configs
        task_group_map: Mapping of tasks to their parent groups
311

Baber's avatar
Baber committed
312
313
314
315
316
317
318
319
320
321
322
323
324
325
    Example:
        Basic usage::

            tm = TaskManager()
            print(f"Found {len(tm.all_tasks)} tasks")
            hellaswag_config = tm._get_config("hellaswag")

        With custom tasks::

            tm = TaskManager(
                include_path="/my/custom/tasks",
                verbosity="INFO"
            )
            custom_tasks = [t for t in tm.all_tasks if "custom" in t]
326

327
328
    """

329
330
    def __init__(
        self,
331
332
        verbosity: str | None = None,
        include_path: str | Path | list[str | Path] | None = None,
333
        include_defaults: bool = True,
334
        metadata: dict[str, dict[str, Any]] | None = None,
335
    ) -> None:
336
        """Initialize the TaskManager.
Baber's avatar
Baber committed
337
338
339
340
341
342
343

        Args:
            verbosity: Logging verbosity level (DEBUG, INFO, WARNING, ERROR)
            include_path: Additional path(s) to search for tasks. Can be a single
                         path or list of paths.
            include_defaults: Whether to include default tasks from lm_eval/tasks/
            metadata: Global metadata dictionary to inject into all task configs
344

Baber's avatar
Baber committed
345
        """
Lintang Sutawika's avatar
Lintang Sutawika committed
346
        if verbosity is not None:
Baber's avatar
nit  
Baber committed
347
            setup_logging(verbosity)
348
        self.include_path = include_path
Baber Abbasi's avatar
Baber Abbasi committed
349
        self.metadata = metadata
350
        self._task_index = self.initialize_tasks(
351
352
            include_path=include_path,
            include_defaults=include_defaults,
353
        )
354
        self._all_tasks = sorted(self._task_index.keys())
355

356
        self._all_groups = sorted(
357
            [x for x in self._all_tasks if self._task_index[x]["type"] == "group"],
358
359
        )
        self._all_subtasks = sorted(
360
361
362
363
            [
                x
                for x in self._all_tasks
                if self._task_index[x]["type"] in ["task", "python_task"]
364
            ],
365
366
        )
        self._all_tags = sorted(
367
            [x for x in self._all_tasks if self._task_index[x]["type"] == "tag"],
368
369
        )

370
        self.task_group_map = collections.defaultdict(list)
371

372
373
    def initialize_tasks(
        self,
374
        include_path: str | Path | list[str | Path] | None = None,
375
        include_defaults: bool = True,
Baber Abbasi's avatar
Baber Abbasi committed
376
377
    ) -> dict[str, dict]:
        """Creates a dictionary of tasks indexes.
378

Baber's avatar
nit  
Baber committed
379
        :param include_path: Union[str, list] = None
380
381
382
383
            An additional path to be searched for tasks recursively.
            Can provide more than one such path as a list.
        :param include_defaults: bool = True
            If set to false, default tasks (those in lm_eval/tasks/) are not indexed.
384
385

        Return:
Baber's avatar
nit  
Baber committed
386
            dictionary of task names as key and task metadata
387

388
        """
389
        all_paths = [Path(__file__).parent] if include_defaults else []
390
        if include_path is not None:
Baber's avatar
Baber committed
391
            if isinstance(include_path, (str, Path)):
392
                include_path = [include_path]
Baber's avatar
Baber committed
393
394
            # Convert all paths to Path objects
            all_paths.extend(Path(p) for p in include_path)
395

396
397
398
399
        task_index = {}
        for task_dir in all_paths:
            tasks = self._get_task_and_group(task_dir)
            task_index = {**tasks, **task_index}
lintangsutawika's avatar
format  
lintangsutawika committed
400

401
402
403
        return task_index

    @property
Baber's avatar
nit  
Baber committed
404
    def all_tasks(self) -> list[str]:
Baber's avatar
Baber committed
405
        """Get sorted list of all task names (tasks, groups, and tags)."""
406
407
        return self._all_tasks

408
    @property
Baber's avatar
nit  
Baber committed
409
    def all_groups(self) -> list[str]:
Baber's avatar
Baber committed
410
        """Get sorted list of all group names."""
411
412
413
        return self._all_groups

    @property
Baber's avatar
nit  
Baber committed
414
    def all_subtasks(self) -> list[str]:
Baber's avatar
Baber committed
415
        """Get sorted list of all individual task names (excludes groups and tags)."""
416
417
418
        return self._all_subtasks

    @property
Baber's avatar
nit  
Baber committed
419
    def all_tags(self) -> list[str]:
Baber's avatar
Baber committed
420
        """Get sorted list of all tag names."""
421
422
        return self._all_tags

423
    @property
424
    def task_index(self) -> dict[str, dict[str, str | int | list[str]]]:
Baber's avatar
Baber committed
425
        """Get the complete task index with metadata for all tasks."""
426
427
        return self._task_index

428
    def list_all_tasks(
        self,
        list_groups: bool = True,
        list_tags: bool = True,
        list_subtasks: bool = True,
    ) -> str:
        """Return Markdown tables (as one string) listing groups, tags and/or subtasks
        known to this TaskManager.  Safe for configs whose yaml_path is -1 and for
        task configs whose `include:` is a string or a list.

        Args:
            list_groups: Include the "Group / Config Location" table.
            list_tags: Include the "Tag" table.
            list_subtasks: Include the "Task / Config Location / Output Type"
                table. Only when this is set are the task YAML files read.

        Returns:
            A string starting with a newline, followed by each requested
            table and a newline after every table.
        """
        from pytablewriter import MarkdownTableWriter

        # ------------------------------------------------------------------ helpers
        def sanitize_path(path: str) -> str:
            # print a relative path for anything inside lm_eval/tasks/
            if "lm_eval/tasks/" in path:
                return "lm_eval/tasks/" + path.split("lm_eval/tasks/")[-1]
            return path

        def read_raw(path: Path) -> dict:
            # Metadata only: stub out !function and leave `include` unexpanded
            # so the include chain can be walked (and missing files tolerated)
            # here. An empty YAML file yields an empty config.
            return (
                load_yaml_config(
                    path,
                    resolve_functions=False,
                    resolve_includes=False,
                )
                or {}
            )

        def first_output_type_from_includes(
            cfg: dict,
            base: Path,
            _visited: set[Path] | None = None,
        ) -> str:
            """Walk cfg['include'] (string or list), depth-first, and return the
            first output_type found in an included file; "" if there is none.
            """
            visited = set() if _visited is None else _visited
            inc_raw = cfg.get("include")
            if not inc_raw:
                return ""

            inc_list = inc_raw if isinstance(inc_raw, list) else [inc_raw]
            for inc in inc_list:
                if not inc:
                    continue
                inc_path = Path(inc)
                if not inc_path.is_absolute():  # treat as relative include
                    inc_path = base.parent / inc_path
                inc_path = inc_path.resolve()
                if inc_path in visited:  # include cycle: do not loop forever
                    continue
                visited.add(inc_path)
                try:
                    inc_cfg = read_raw(inc_path)
                except FileNotFoundError:
                    continue
                if "output_type" in inc_cfg:
                    return inc_cfg["output_type"]
                nested = first_output_type_from_includes(inc_cfg, inc_path, visited)
                if nested:
                    return nested
            return ""

        parts: list[str] = ["\n"]

        # -------------------------------------------------------------- GROUP table
        if list_groups:
            group_table = MarkdownTableWriter()
            group_table.headers = ["Group", "Config Location"]
            group_table.value_matrix = [
                [
                    g,
                    "---"
                    if self.task_index[g]["yaml_path"] == -1
                    else sanitize_path(self.task_index[g]["yaml_path"]),
                ]
                for g in self.all_groups
            ]
            parts.append(group_table.dumps())
            parts.append("\n")

        # ---------------------------------------------------------------- TAG table
        if list_tags:
            tag_table = MarkdownTableWriter()
            tag_table.headers = ["Tag"]
            tag_table.value_matrix = [[t] for t in self.all_tags]
            parts.append(tag_table.dumps())
            parts.append("\n")

        # ------------------------------------------------------------ SUBTASK table
        if list_subtasks:
            subtask_table = MarkdownTableWriter()
            subtask_table.headers = ["Task", "Config Location", "Output Type"]
            st_values: list[list[str]] = []

            for t in self.all_subtasks:
                raw_path = self.task_index[t]["yaml_path"]

                if raw_path == -1:
                    # no YAML file recorded for this entry
                    display_path = "---"
                    output_type = ""
                else:
                    path_obj = Path(raw_path)
                    display_path = sanitize_path(str(path_obj))

                    # load minimal YAML to discover output_type
                    cfg = read_raw(path_obj)
                    if "output_type" in cfg:
                        output_type = cfg["output_type"]
                    else:
                        output_type = first_output_type_from_includes(cfg, path_obj)

                st_values.append([t, display_path, output_type])

            subtask_table.value_matrix = st_values
            parts.append(subtask_table.dumps())
            parts.append("\n")

        # ------------------------------------------------------------- final string
        return "".join(parts)
528

Baber Abbasi's avatar
Baber Abbasi committed
529
    def match_tasks(self, task_list: list[str]) -> list[str]:
        """Match *task_list* against every indexed name via ``pattern_match``."""
        known_names = self.all_tasks
        return pattern_match(task_list, known_names)
532

Baber Abbasi's avatar
Baber Abbasi committed
533
    def _name_is_registered(self, name: str) -> bool:
Baber's avatar
Baber committed
534
        """Check if a name is registered in the task index."""
535
        return name in self.all_tasks
536

Baber Abbasi's avatar
Baber Abbasi committed
537
    def _name_is_task(self, name: str) -> bool:
Baber's avatar
Baber committed
538
        """Check if a name refers to an individual task (not group or tag)."""
539
540
541
        return (
            self._name_is_registered(name) and self.task_index[name]["type"] == "task"
        )
Lintang Sutawika's avatar
Lintang Sutawika committed
542

Baber Abbasi's avatar
Baber Abbasi committed
543
    def _name_is_tag(self, name: str) -> bool:
Baber's avatar
Baber committed
544
        """Check if a name refers to a tag."""
545
        return self._name_is_registered(name) and self.task_index[name]["type"] == "tag"
546

Baber Abbasi's avatar
Baber Abbasi committed
547
    def _name_is_group(self, name: str) -> bool:
Baber's avatar
Baber committed
548
        """Check if a name refers to a group."""
549
550
551
        return (
            self._name_is_registered(name) and self.task_index[name]["type"] == "group"
        )
552

Baber Abbasi's avatar
Baber Abbasi committed
553
    def _name_is_python_task(self, name: str) -> bool:
Baber's avatar
Baber committed
554
        """Check if a name refers to a Python-defined task."""
555
556
557
558
        return (
            self._name_is_registered(name)
            and self.task_index[name]["type"] == "python_task"
        )
559

Baber's avatar
fix  
Baber committed
560
561
    @staticmethod
    def _config_is_task(config: dict) -> bool:
Baber's avatar
Baber committed
562
        """Check if a config dictionary defines a single task."""
563
        return "task" in config and isinstance(config["task"], str)
564

Baber's avatar
fix  
Baber committed
565
566
    @staticmethod
    def _config_is_group(config: dict) -> bool:
Baber's avatar
Baber committed
567
        """Check if a config dictionary defines a group of tasks."""
568
        return "task" in config and isinstance(config["task"], list)
569

Baber's avatar
fix  
Baber committed
570
571
    @staticmethod
    def _config_is_python_task(config: dict) -> bool:
Baber's avatar
Baber committed
572
        """Check if a config dictionary defines a Python class-based task."""
573
574
        return "class" in config

Baber's avatar
fix  
Baber committed
575
576
    @staticmethod
    def _config_is_task_list(config: dict) -> bool:
Baber's avatar
Baber committed
577
        """Check if a config dictionary defines a task list."""
578
        return "task_list" in config and isinstance(config["task_list"], list)
579

580
581
    def _get_yaml_path(self, name: str) -> str | int | list[str]:
        """Get the YAML file path for a registered task.
Baber's avatar
Baber committed
582
583
584
585
586
587
588
589
590

        Args:
            name: Task name

        Returns:
            Path to YAML file, or -1 for Python-only tasks

        Raises:
            ValueError: If task name is not registered
591

Baber's avatar
Baber committed
592
        """
593
594
        if name not in self.task_index:
            raise ValueError
595
596
        return self.task_index[name]["yaml_path"]

Baber's avatar
nit  
Baber committed
597
    def _get_config(self, name: str) -> dict:
598
        """Load the full configuration for a registered task.
Baber's avatar
Baber committed
599
600
601
602
603
604
605
606
607

        Args:
            name: Task name

        Returns:
            Complete task configuration dictionary

        Raises:
            ValueError: If task name is not registered
608

Baber's avatar
Baber committed
609
        """
610
611
        if name not in self.task_index:
            raise ValueError
612
613
614
        yaml_path = self._get_yaml_path(name)
        if yaml_path == -1:
            return {}
615
        return load_yaml_config(Path(yaml_path))
616

617
618
    def _get_tasklist(self, name: str) -> list[str] | int:
        """Get the task list for a group or tag.
Baber's avatar
Baber committed
619
620
621
622
623
624
625
626
627

        Args:
            name: Group or tag name

        Returns:
            List of task names in the group/tag

        Raises:
            ValueError: If name refers to an individual task
628

Baber's avatar
Baber committed
629
        """
630
631
        if self._name_is_task(name):
            raise ValueError
632
633
        return self.task_index[name]["task"]

634
635
636
637
638
    def _register_task(
        self,
        task_name: str,
        task_type: str,
        yaml_path: str,
Baber's avatar
nit  
Baber committed
639
        tasks_and_groups: dict[str, dict],
640
641
        config: dict | None = None,
        populate_tags_fn: Callable | None = None,
Baber's avatar
Baber committed
642
    ) -> None:
643
        """Helper method to register a task in the tasks_and_groups dict."""
644
645
646
647
648
649
        tasks_and_groups[task_name] = {
            "type": task_type,
            "yaml_path": yaml_path,
        }
        # Only populate tags for configs that support it (not groups)
        if config and task_type != "group" and populate_tags_fn:
650
            populate_tags_fn(config, task_name, tasks_and_groups)
651
652

    def _merge_task_configs(
653
654
655
656
        self,
        base_config: dict,
        task_specific_config: dict,
        task_name: str,
Baber's avatar
nit  
Baber committed
657
    ) -> dict:
658
        """Merge base config with task-specific overrides for task_list configs."""
659
660
661
662
663
664
        if task_specific_config:
            task_specific_config = task_specific_config.copy()
            task_specific_config.pop("task", None)
            return {**base_config, **task_specific_config, "task": task_name}
        return {**base_config, "task": task_name}

Baber's avatar
Baber committed
665
    def _process_tag_subtasks(
666
667
668
        self,
        tag_name: str,
        update_config: dict | None = None,
Baber's avatar
nit  
Baber committed
669
    ) -> dict:
670
        """Process subtasks for a tag and return loaded tasks."""
671
672
673
674
675
676
677
        subtask_list = self._get_tasklist(tag_name)
        fn = partial(
            self._load_individual_task_or_group,
            update_config=update_config,
        )
        return dict(collections.ChainMap(*map(fn, reversed(subtask_list))))

678
679
    def _process_alias(self, config: dict, group: str | None = None) -> dict:
        """Process group alias configuration.
Baber's avatar
Baber committed
680
681
682
683
684
685
686
687
688
689

        If the group is not the same as the original group which the group alias
        was intended for, set the group_alias to None instead.

        Args:
            config: Task configuration dictionary
            group: Group name to validate against

        Returns:
            Modified configuration with processed aliases
690

Baber's avatar
Baber committed
691
        """
692
693
694
695
696
697
698
        if (
            ("group_alias" in config)
            and ("group" in config)
            and group is not None
            and config["group"] != group
        ):
            config["group_alias"] = None
699
700
        return config

Baber's avatar
Baber committed
701
    def _class_has_config_in_constructor(self, cls) -> bool:
702
        """Check if a class constructor accepts a 'config' parameter.
Baber's avatar
Baber committed
703
704
705
706
707
708

        Args:
            cls: Class to inspect

        Returns:
            True if constructor has 'config' parameter, False otherwise
709

Baber's avatar
Baber committed
710
        """
711
712
713
714
715
716
717
        constructor = getattr(cls, "__init__", None)
        return (
            "config" in inspect.signature(constructor).parameters
            if constructor
            else False
        )

718
719
720
721
722
    ###############################################################################
    # NEW: Refactored _load_individual_task_or_group and helper methods          #
    ###############################################################################

    def _create_task_object(
        self,
        cfg: dict,
        task_name: str,
        yaml_path: str | None,
    ) -> dict:
        """Instantiate a single task (ConfigurableTask **or** python-task) from *cfg*.

        Args:
            cfg: Task configuration. It is copied (top level only) before any
                changes are made, so the caller's dict is left untouched.
            task_name: Key of the returned mapping; also written to the task
                object's config when the python-task class is a
                ConfigurableTask.
            yaml_path: File the config came from; relative ``include`` entries
                are resolved against its directory.

        Returns:
            ``{task_name: task_object}``.

        Raises:
            ValueError: If *cfg* has a relative ``include`` but no *yaml_path*
                is available to resolve it against.
        """
        from lm_eval.api.task import ConfigurableTask, Task  # local import avoids cycle

        # Shallow copy: configs may come from a cache or be shared by callers.
        cfg = dict(cfg)

        # ---- include handling ---------------------------------------------------
        if "include" in cfg:
            # keep original name so include merging cannot clobber it
            orig_name = cfg.get("task", task_name)
            include_raw = cfg.pop("include")
            include_list = (
                include_raw if isinstance(include_raw, list) else [include_raw]
            )
            base_dir = Path(yaml_path).parent if yaml_path else None

            included: dict = {}
            for inc in include_list:
                if not inc:  # guard against explicit nulls / empty strings
                    continue
                inc_path = Path(inc)
                if not inc_path.is_absolute():
                    if base_dir is None:
                        msg = (
                            f"Cannot resolve relative include '{inc}' for task "
                            f"'{task_name}' without a yaml_path"
                        )
                        raise ValueError(msg)
                    inc_path = base_dir / inc_path
                # later includes override earlier ones, as in load_yaml_config
                included.update(
                    load_yaml_config(inc_path, resolve_functions=bool(yaml_path)),
                )
            cfg = {**included, **cfg, "task": orig_name}

        # ---- metadata merge -----------------------------------------------------
        # `or {}` also covers an explicit `metadata: null` in the YAML
        task_metadata = cfg.get("metadata") or {}
        if self.metadata is not None:
            cfg["metadata"] = task_metadata | self.metadata
        else:
            cfg["metadata"] = task_metadata

        # ---- python-task vs YAML-task -------------------------------------------
        if self._config_is_python_task(cfg):
            cls = cfg["class"]
            task_obj: Task
            if self._class_has_config_in_constructor(cls):
                task_obj = cls(config=cfg)
            else:
                task_obj = cls()
            # make sure name propagates when the class inherits ConfigurableTask
            if isinstance(task_obj, ConfigurableTask):  # type: ignore
                task_obj.config.task = task_name
        else:
            task_obj = ConfigurableTask(config=cfg)  # type: ignore

        return {task_name: task_obj}
Baber's avatar
Baber committed
768

769
770
771
    def _create_group_object(
        self,
        cfg: dict,
        parent_name: str | None = None,
    ) -> tuple[GroupConfig, list[str | dict]]:
        """Construct the GroupConfig for *cfg* and list its subtasks.

        String entries of the group's ``task`` list that name a tag are
        replaced by the tag's task list; every other entry is kept as-is.
        When the manager carries metadata it is merged over ``cfg["metadata"]``
        (in place) before the GroupConfig is built.

        Args:
            cfg: Keyword arguments for ``GroupConfig``.
            parent_name: Accepted for call-site symmetry; not used here.

        Returns:
            ``(group_config, subtasks)``.
        """
        manager_metadata = self.metadata
        if manager_metadata is not None:
            cfg["metadata"] = cfg.get("metadata", {}) | manager_metadata

        group = GroupConfig(**cfg)

        expanded: list[str | dict] = []
        for member in group.task or ():
            is_tag_name = isinstance(member, str) and self._name_is_tag(member)
            if is_tag_name:
                # a tag stands for several tasks: splice them in
                expanded.extend(self._get_tasklist(member))
            else:
                expanded.append(member)
        return group, expanded
Baber's avatar
Baber committed
789

790
791
    def _load_subtasks(
        self,
792
793
794
        subtasks: list[str | dict],
        parent_name: str | GroupConfig | None,
        update_config: dict | None,
795
796
797
798
799
800
801
802
    ) -> Mapping:
        """Return merged mapping of all subtasks, handling duplicates."""
        fn = functools.partial(
            self._load_individual_task_or_group,
            parent_name=parent_name,
            update_config=update_config,
        )
        return dict(collections.ChainMap(*map(fn, reversed(subtasks))))
Baber's avatar
Baber committed
803

804
805
    def _load_individual_task_or_group(
        self,
806
        payload: str | dict,
807
        *,
808
809
        parent_name: str | None = None,
        update_config: dict | None = None,
810
    ) -> Mapping:
811
        """Public helper that turns *payload* (str task/group/tag **or** dict config)
812
813
814
815
816
817
818
819
820
821
        into a nested Mapping of {name_or_group_obj: task_obj | sub_mapping}.
        """
        # ------------------------------------------------------------------ STRING
        if isinstance(payload, str):
            # If caller supplied extra overrides, treat as dict immediately
            if update_config:
                return self._load_individual_task_or_group(
                    {"task": payload, **update_config},
                    parent_name=parent_name,
                )
822

823
824
825
826
827
828
829
830
831
832
833
834
835
836
            # ------------ registered TASK (YAML or python) -----------------
            if self._name_is_task(payload) or self._name_is_python_task(payload):
                yaml_path = self._get_yaml_path(payload)
                cfg = self._get_config(payload)

                # task_list configs: extract the per-task override ------------
                if "task_list" in cfg:
                    override = next(
                        (
                            entry
                            for entry in cfg["task_list"]
                            if isinstance(entry, dict) and entry.get("task") == payload
                        ),
                        None,
Lintang Sutawika's avatar
Lintang Sutawika committed
837
                    )
838
839
840
841
842
843
844
845
846
847
848
                    base = {k: v for k, v in cfg.items() if k != "task_list"}
                    if override:
                        cfg = {**base, **override, "task": payload}
                return self._create_task_object(cfg, payload, yaml_path)

            # ------------ registered GROUP ----------------------------------
            if self._name_is_group(payload):
                group_cfg = self._get_config(payload)
                grp_only = {k: v for k, v in group_cfg.items() if k in GROUP_ONLY_KEYS}
                grp_obj, subtasks = self._create_group_object(grp_only, parent_name)
                return {
849
                    grp_obj: self._load_subtasks(subtasks, grp_obj, update_config=None),
850
                }
851

852
853
854
855
            # ------------ registered TAG ------------------------------------
            if self._name_is_tag(payload):
                return self._process_tag_subtasks(payload, update_config=None)

856
857
            msg = f"Unknown task / group / tag name: {payload!r}"
            raise ValueError(msg)
858
859
860
861
862
863
864
865
866
867
868

        # ------------------------------------------------------------------- DICT
        if isinstance(payload, dict):
            # ------------------ simple 'task: name' dict --------------------
            if self._config_is_task(payload):
                name = payload["task"]
                # override existing registered YAML if exists
                if self._name_is_registered(name):
                    base_cfg = self._get_config(name)
                    yaml_path = self._get_yaml_path(name)
                    merged = {**base_cfg, **payload}
869
                else:
870
                    merged = payload
871
                    yaml_path = None
872

873
874
875
876
877
878
879
                # duplicate-naming guard when inside a group
                if parent_name is not None:
                    count = len(
                        [
                            n
                            for n in self.task_group_map[parent_name]
                            if n.startswith(name)
880
                        ],
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
                    )
                    if count:
                        name = f"{name}-{count}"
                    self.task_group_map[parent_name].append(name)

                return self._create_task_object(merged, name, yaml_path)

            # ----------------- literal group dict (task: [...]) -------------
            if self._config_is_group(payload):
                grp_cfg = {k: v for k, v in payload.items() if k in GROUP_ONLY_KEYS}
                sub_override = {
                    k: v for k, v in payload.items() if k not in GROUP_ONLY_KEYS
                } or None
                grp_obj, subtasks = self._create_group_object(grp_cfg, parent_name)
                return {grp_obj: self._load_subtasks(subtasks, grp_obj, sub_override)}

            # ----------------- python-task dict ('class': …) ----------------
            if self._config_is_python_task(payload):
                name = payload["task"]
                return self._create_task_object(payload, name, yaml_path=None)

902
        msg = f"_load_individual_task_or_group expected str | dict, got {type(payload)}"
903
        raise TypeError(
904
            msg,
905
        )
906

Baber's avatar
Baber committed
907
    def load_task_or_group(
908
909
        self,
        task_list: str | list[str] | None = None,
Baber's avatar
nit  
Baber committed
910
    ) -> dict:
911
        """Load multiple tasks or groups from a list of names.
Baber's avatar
Baber committed
912
913
914
915
916
917
918
919

        This is the main entry point for loading tasks. It handles lists
        of task names and delegates to _load_individual_task_or_group for
        each item, then merges the results.

        Args:
            task_list: Single task name or list of task names to load.
                      Can include individual tasks, groups, and tags.
920

Baber's avatar
Baber committed
921
922
923
        Returns:
            Dictionary mapping task/group names to loaded task objects.
            Results from all requested items are merged into a single dict.
924

Baber's avatar
Baber committed
925
926
927
928
929
930
931
932
933
934
        Example:
            Load multiple tasks::

                tasks = tm.load_task_or_group(["hellaswag", "arc_easy"])
                # Returns: {"hellaswag": Task1, "arc_easy": Task2}

            Load a group::

                tasks = tm.load_task_or_group("arc_group")
                # Returns: {"arc_group": {"arc_easy": Task1, "arc_challenge": Task2}}
935

936
937
938
        """
        if isinstance(task_list, str):
            task_list = [task_list]
939

940
        return dict(
Baber Abbasi's avatar
Baber Abbasi committed
941
            collections.ChainMap(
942
943
                *(self._load_individual_task_or_group(task) for task in task_list),
            ),
944
945
        )

Baber's avatar
nit  
Baber committed
946
    def load_config(self, config: dict) -> Mapping:
947
        """Load a task from an inline configuration dictionary.
Baber's avatar
Baber committed
948
949
950
951
952
953
954
955
956
957

        Args:
            config: Configuration dictionary defining the task

        Returns:
            Mapping of task name to loaded task object

        Example:
            >>> config = {"task": "hellaswag", "num_fewshot": 5}
            >>> task_dict = tm.load_config(config)
958

Baber's avatar
Baber committed
959
        """
960
961
        return self._load_individual_task_or_group(config)

962
963
    def _get_task_and_group(self, task_dir: str | Path) -> dict[str, dict]:
        """Scan a directory for task configurations and build an index.

        Every YAML file yielded by ``iter_yaml_files`` is loaded with
        ``resolve_functions=False`` and ``resolve_includes=False`` and then
        classified, in this order:

        - python-task config -> registered via ``_register_task`` as
          ``"python_task"``
        - group config       -> indexed directly under ``config["group"]``
        - task config        -> registered via ``_register_task`` as ``"task"``
        - task_list config   -> each dict entry that has a ``"task"`` key is
          registered via ``_register_task`` as ``"task"``

        Files that fail to load, or that match none of the shapes above, are
        skipped with a debug log message.

        Args:
            task_dir: Directory path to scan for YAML task configurations

        Returns:
            Index mapping names to metadata dictionaries. The entries written
            directly by this method have these shapes:

            - group: ``{"type": "group", "task": -1, "yaml_path": <str>}``
              (``-1`` means the member list is not indexed; it is resolved
              when the group is loaded)
            - tag: ``{"type": "tag", "task": [<task names>], "yaml_path": -1}``

            Task and python-task entries are written by ``_register_task``.
        """

        def _populate_tags_and_groups(
            config: dict,
            task: str,
            tasks_and_groups: dict[str, dict],
        ) -> None:
            """Record *task* under every tag listed in ``config["tag"]``.

            Args:
                config: Task configuration dictionary; only its ``"tag"`` key
                    (a string or a list of strings) is read.
                task: Name of the task being processed
                tasks_and_groups: Index to update in place with tag entries
            """
            # TODO: remove group in next release
            if "tag" in config:
                attr_list = config["tag"]
                # a single tag may be given as a bare string
                if isinstance(attr_list, str):
                    attr_list = [attr_list]

                for tag in attr_list:
                    if tag not in tasks_and_groups:
                        # first task seen for this tag: create the tag entry
                        tasks_and_groups[tag] = {
                            "type": "tag",
                            "task": [task],
                            "yaml_path": -1,
                        }
                    elif tasks_and_groups[tag]["type"] != "tag":
                        # name already taken by a non-tag entry
                        eval_logger.info(
                            f"The tag '{tag}' is already registered as a group, this tag will not be registered. "
                            "This may affect tasks you want to call.",
                        )
                        # NOTE(review): `break` also skips any remaining tags
                        # of this task - confirm this is intended rather than
                        # `continue`.
                        break
                    else:
                        tasks_and_groups[tag]["task"].append(task)

        # TODO: remove group in next release
        # ignore_dirs = [
        #     "__pycache__",
        #     ".ipynb_checkpoints",
        # ]
        # No default_factory is given, so this behaves like a plain dict
        # (missing keys still raise KeyError).
        tasks_and_groups = collections.defaultdict()
        task_dir_path = Path(task_dir)

        for yaml_path in iter_yaml_files(task_dir_path):
            try:
                config = load_yaml_config(
                    yaml_path,
                    resolve_functions=False,
                    resolve_includes=False,
                )
            except (FileNotFoundError, YAMLError, OSError) as err:
                # unreadable / malformed file: skip it, keep scanning
                eval_logger.debug(f"File {yaml_path} could not be loaded ({err})")
                continue
            if self._config_is_python_task(config):
                # This is a python class config
                task = config["task"]
                self._register_task(
                    task,
                    "python_task",
                    str(yaml_path),
                    tasks_and_groups,
                    config,
                    _populate_tags_and_groups,
                )
            elif self._config_is_group(config):
                # This is a group config
                tasks_and_groups[config["group"]] = {
                    "type": "group",
                    "task": -1,  # This signals that
                    # we don't need to know
                    # the task list for indexing
                    # as it can be loaded
                    # when called.
                    "yaml_path": str(yaml_path),
                }

                # # Registered the level 1 tasks from a group config
                # for config in config["task"]:
                #     if isinstance(config, dict) and self._config_is_task(config):
                #         task = config["task"]
                #         tasks_and_groups[task] = {
                #             "type": "task",
                #             "yaml_path": yaml_path,
                #             }

            elif self._config_is_task(config):
                # This is a task config
                task = config["task"]
                self._register_task(
                    task,
                    "task",
                    str(yaml_path),
                    tasks_and_groups,
                    config,
                    _populate_tags_and_groups,
                )
            elif self._config_is_task_list(config):
                # This is a task_list config: register each named entry,
                # passing the whole file config (not just the entry) along
                for task_entry in config["task_list"]:
                    if isinstance(task_entry, dict) and "task" in task_entry:
                        task_name = task_entry["task"]
                        self._register_task(
                            task_name,
                            "task",
                            str(yaml_path),
                            tasks_and_groups,
                            config,
                            _populate_tags_and_groups,
                        )
            else:
                # The file parsed, but its config matched none of the shapes
                # handled above.
                eval_logger.debug(f"File {yaml_path} could not be loaded")

        return tasks_and_groups
lintangsutawika's avatar
lintangsutawika committed
1106

1107

Baber's avatar
nit  
Baber committed
1108
def get_task_name_from_config(task_config: dict[str, str]) -> str:
    """Derive a task name from a configuration dictionary.

    The explicit ``"task"`` entry wins when present. Otherwise the name is
    built from ``dataset_path``, with ``_<dataset_name>`` appended when a
    ``dataset_name`` key exists.

    Args:
        task_config: Task configuration dictionary

    Returns:
        String name for the task

    Raises:
        KeyError: Neither ``"task"`` nor ``"dataset_path"`` is present.

    Example:
        >>> get_task_name_from_config({"task": "hellaswag", "num_fewshot": 5})
        'hellaswag'
        >>> get_task_name_from_config(
        ...     {"dataset_path": "custom", "dataset_name": "mytask"}
        ... )
        'custom_mytask'
    """
    if "task" in task_config:
        return task_config["task"]

    # Fall back to a name assembled from the dataset coordinates.
    template = (
        "{dataset_path}_{dataset_name}"
        if "dataset_name" in task_config
        else "{dataset_path}"
    )
    return template.format(**task_config)
lintangsutawika's avatar
lintangsutawika committed
1135

1136

1137
1138
def get_task_name_from_object(task_object: ConfigurableTask | Task) -> str:
    """Return the reporting name of an instantiated task object.

    Objects exposing a ``config`` attribute yield ``_config["task"]``. Any
    other object yields its ``EVAL_HARNESS_NAME`` attribute when it has one,
    and its class name otherwise.

    Args:
        task_object: An instantiated task object

    Returns:
        String name of the task

    Example:
        >>> task = ConfigurableTask(config={"task": "hellaswag"})
        >>> get_task_name_from_object(task)
        'hellaswag'
    """
    if not hasattr(task_object, "config"):
        # TODO: scrap this
        # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
        fallback_name = type(task_object).__name__
        return getattr(task_object, "EVAL_HARNESS_NAME", fallback_name)

    return task_object._config["task"]

1166

Baber's avatar
nit  
Baber committed
1167
def _check_duplicates(task_dict: dict[str, list[str]]) -> None:
    """Raise if any subtask name occurs more than once across the groups.

    All member lists of *task_dict* are pooled; a name counts as a duplicate
    when it occurs more than once in that pool (in different groups, or
    repeated inside one group).

    Args:
        task_dict: Dictionary mapping group names to lists of subtask names

    Raises:
        ValueError: If at least one subtask name is duplicated. The message
            lists the duplicated names (in first-seen order) and every group
            that contains one of them.

    Example:
        >>> task_dict = {
        ...     "group1": ["task_a", "task_b"],
        ...     "group2": ["task_b", "task_c"]  # task_b appears twice!
        ... }
        >>> _check_duplicates(task_dict)  # Raises ValueError
    """
    # One counting pass instead of list.count() per element (quadratic).
    occurrences = collections.Counter(
        name for members in task_dict.values() for name in members
    )
    # Counter keeps insertion order, so the report is deterministic.
    duplicate_tasks = [name for name, seen in occurrences.items() if seen > 1]
    if not duplicate_tasks:
        return

    # locate the potentially problematic groups that seem to 'compete' for constituent subtasks
    duplicated = set(duplicate_tasks)
    competing_groups = [
        group
        for group, members in task_dict.items()
        if not duplicated.isdisjoint(members)
    ]

    msg = f"Found 1 or more tasks while trying to call get_task_dict() that were members of more than 1 called group: {duplicate_tasks}. Offending groups: {competing_groups}. Please call groups which overlap their constituent tasks in separate evaluation runs."
    raise ValueError(
        msg,
    )


1210
def get_task_dict(
    task_name_list: str | list[str | dict | Task],
    task_manager: TaskManager | None = None,
) -> dict[str, ConfigurableTask | Task]:
    """Create a dictionary of task objects from mixed input types.

    Each item of *task_name_list* is resolved according to its type:

    - ``str``: looked up through ``task_manager.load_task_or_group``
    - ``dict``: treated as one inline config and loaded through
      ``task_manager.load_config``
    - ``Task``: used as-is, keyed by ``get_task_name_from_object``

    Args:
        task_name_list: A single task name, or a list mixing names, inline
            config dicts and pre-instantiated Task objects.
        task_manager: TaskManager instance for name resolution.
            If None, creates a default TaskManager.

    Returns:
        Dictionary with everything that was loaded, merged across all items.

    Raises:
        TypeError: If *task_name_list* is not a ``str`` or ``list``, or if a
            list item is not a ``str``, ``dict`` or ``Task``.
        ValueError: If two items produce the same key, or if
            ``_check_duplicates`` finds a subtask that belongs to more than
            one loaded group.

    Example:
        Mixed input types::

            tasks = get_task_dict([
                "hellaswag",                              # lookup by name
                {"task": "arc_easy", "num_fewshot": 5},   # inline config
                pre_existing_task_object                  # direct object
            ])

        With custom TaskManager::

            tm = TaskManager(include_path="/custom/tasks")
            tasks = get_task_dict(["custom_task"], task_manager=tm)
    """
    from lm_eval.api.task import Task

    # Normalize input to list
    if isinstance(task_name_list, str):
        task_name_list = [task_name_list]
    elif not isinstance(task_name_list, list):
        msg = f"Expected a 'str' or 'list' but received {type(task_name_list)}."
        raise TypeError(
            msg,
        )

    # Validate list items
    if not all(isinstance(task, (str, dict, Task)) for task in task_name_list):
        msg = "Expected all list items to be of types 'str', 'dict', or 'Task', but at least one entry did not match."
        raise TypeError(
            msg,
        )

    # Ensure we have a task manager
    if task_manager is None:
        task_manager = TaskManager()

    # Process all items
    final_task_dict = {}
    for task_spec in task_name_list:
        if isinstance(task_spec, Task):
            # Pre-instantiated task object
            loaded = {get_task_name_from_object(task_spec): task_spec}
        elif isinstance(task_spec, dict):
            # Inline config: must be handed over as ONE config. Passing the
            # dict to load_task_or_group would iterate its keys and try to
            # load each key ("task", "num_fewshot", ...) as a task name.
            loaded = task_manager.load_config(task_spec)
        else:
            # Registered task / group / tag name
            loaded = task_manager.load_task_or_group(task_spec)

        # Check for duplicate names before merging anything from this item
        for name in loaded:
            if name in final_task_dict:
                msg = f"Duplicate task name: {name}"
                raise ValueError(msg)
        final_task_dict.update(loaded)

    # Check for conflicting group memberships
    _check_duplicates(get_subtask_list(final_task_dict))

    return final_task_dict