__init__.py 50.7 KB
Newer Older
Baber's avatar
Baber committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
"""
Task Management Module for LM Evaluation Harness.

This module provides comprehensive task discovery, loading, and management functionality
for the LM Evaluation Harness. It handles YAML configuration parsing with include support,
dynamic function importing, and task indexing across multiple directories.

Key Components:
- TaskManager: Main class for task discovery and management
- YAML configuration loading with !function tag support
- Task, group, and tag indexing
- Include resolution with cycle detection
- Caching for performance optimization

Example:
    Basic usage::

        task_manager = TaskManager()
        all_tasks = task_manager.all_tasks
        task_config = task_manager._get_config("hellaswag")

    Custom task paths::

        task_manager = TaskManager(
            include_path="/path/to/custom/tasks",
            include_defaults=True
        )
"""

30
import collections
31
32
import functools
import importlib.util
33
import inspect
34
import logging
35
import sys
36
from functools import partial
37
from glob import iglob
Baber's avatar
Baber committed
38
from pathlib import Path
Baber's avatar
nit  
Baber committed
39
40
41
42
43
44
45
46
47
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Generator,
    Mapping,
    Optional,
    Union,
)
48
49
50

import yaml
from yaml import YAMLError
&'s avatar
& committed
51

Lintang Sutawika's avatar
Lintang Sutawika committed
52
53
from lm_eval.api.group import ConfigurableGroup, GroupConfig
from lm_eval.evaluator_utils import get_subtask_list
Baber's avatar
nit  
Baber committed
54
55
56
57
58
from lm_eval.utils import pattern_match, setup_logging


if TYPE_CHECKING:
    from lm_eval.api.task import ConfigurableTask, Task
Lintang Sutawika's avatar
Lintang Sutawika committed
59

Baber's avatar
nit  
Baber committed
60
# Logger for this module, named after the module itself.
eval_logger = logging.getLogger(__name__)

#: Configuration keys that belong to groups only (the keys of a default GroupConfig).
GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys())

#: Base class for the per-directory YAML loaders: the C loader when PyYAML
#: reports libyaml support, otherwise the pure-Python FullLoader.
#: NOTE(review): CLoader and FullLoader may not accept the same set of YAML
#: tags - confirm that both are acceptable for the task files being loaded.
_Base = yaml.CLoader if getattr(yaml, "__with_libyaml__", False) else yaml.FullLoader

#: Directory names that task discovery must not descend into.
_IGNORE_DIRS = ("__pycache__", ".ipynb_checkpoints")
73

lintangsutawika's avatar
lintangsutawika committed
74

Baber's avatar
nit  
Baber committed
75
def ignore_constructor(loader: yaml.Loader, node: yaml.Node) -> None:
    """
    Stand-in constructor for the ``!function`` tag that resolves nothing.

    `_make_loader` registers it when ``simple=True`` so that tagged values
    are accepted without importing any Python code; each such value loads
    as ``None``.

    Args:
        loader: YAML loader instance invoking the constructor (unused)
        node: YAML node carrying the tag (unused)

    Returns:
        None
    """
    return None
90
91


Baber's avatar
nit  
Baber committed
92
@functools.lru_cache(maxsize=2048)  # one Loader class per (yaml_dir, simple) pair
def _make_loader(yaml_dir: Path, simple: bool = False) -> type[yaml.Loader]:
    """
    Build a YAML Loader subclass whose ``!function`` handling is tied to *yaml_dir*.

    yaml_dir
        Directory of the YAML file being parsed.  It is captured so that a
        value such as  my_utils.some_fn  is looked up in
        yaml_dir / "my_utils.py".
    simple
        When True, ``!function`` values are not resolved at all
        (`ignore_constructor` is registered instead); this is what
        ``mode="simple"`` uses.
    """

    class Loader(_Base):
        """Dynamically-generated loader that knows its base directory."""

    def _resolve_function(ld, node, _dir=yaml_dir):
        # The tag's scalar text is the qualified name to import.
        return _import_function(ld.construct_scalar(node), base_path=_dir)

    constructor = ignore_constructor if simple else _resolve_function
    # Register on the freshly created subclass only, not on _Base.
    yaml.add_constructor("!function", constructor, Loader=Loader)
    return Loader


Baber's avatar
nit  
Baber committed
126
@functools.lru_cache(maxsize=None)  # ← cache the resolved callable per (qualname, base_path)
def _import_function(qualname: str, *, base_path: Path) -> Callable:
    """
    Dynamically import a function from a Python file located under base_path.

    This is what backs the ``!function`` YAML tag.  Everything before the last
    dot of *qualname* is turned into a file path (dots become directory
    separators, ``.py`` is appended); the part after the last dot is the
    attribute fetched from the loaded module.  Loaded modules are stored in
    ``sys.modules`` under a synthetic ``_yaml_dynamic.<hash>_<stem>`` name so
    the same file is executed only once per process.

    Args:
        qualname: Qualified function name like "my_module.my_function"
        base_path: Base directory for resolving the module file

    Returns:
        The attribute named by the last component of *qualname*

    Raises:
        ValueError: If qualname doesn't contain a module part
        ImportError: If no import spec/loader can be created for the file
        AttributeError: If the module has no attribute with the given name
        Exception: Whatever executing the module file raises (e.g.
            FileNotFoundError when the file does not exist)

    Example:
        >>> func = _import_function("utils.custom_metric", base_path=Path("/tasks"))
        >>> result = func(predictions, references)
    """
    mod_path, _, func_name = qualname.rpartition(".")
    if not mod_path:
        raise ValueError(f"{qualname!r} has no module part")
    file_path = base_path / f"{mod_path.replace('.', '/')}.py"
    module_name = f"_yaml_dynamic.{hash(file_path)}_{file_path.stem}"

    mod = sys.modules.get(module_name)
    if mod is None:
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        if spec is None or spec.loader is None:
            raise ImportError(
                f"Cannot create an import spec for {qualname!r} from {file_path}"
            )
        mod = importlib.util.module_from_spec(spec)
        # Register BEFORE executing: code that runs at import time may look the
        # module up through sys.modules[__name__].
        sys.modules[module_name] = mod
        try:
            spec.loader.exec_module(mod)
        except BaseException:
            # Do not leave a half-initialised module behind for later calls.
            sys.modules.pop(module_name, None)
            raise
    return getattr(mod, func_name)


Baber's avatar
Baber committed
164
@functools.lru_cache(maxsize=4096)
def _parse_yaml_file(path: Path, mode: str) -> dict:
    """
    Parse a single YAML file with the loader matching *mode*.

    Results are memoised per (path, mode).  The dict handed back is the cached
    object itself, so callers must copy it before mutating it.  ``include``
    entries are NOT resolved here.

    Args:
        path: Path to the YAML file
        mode: Parsing mode; "simple" skips ``!function`` resolution, any other
            value resolves it

    Returns:
        Parsed YAML content; an empty dict when the document is empty
    """
    loader_cls = _make_loader(path.parent, simple=(mode == "simple"))
    # NOTE(review): the loader derives from a non-safe PyYAML loader and, in
    # full mode, imports Python code named by !function - only load trusted files.
    with path.open("rb") as fh:
        parsed = yaml.load(fh, Loader=loader_cls)
    # yaml.load yields None for an empty document; keep the declared dict
    # contract so callers can use .pop()/.get() unconditionally.
    return {} if parsed is None else parsed


Baber's avatar
nit  
Baber committed
181
@functools.lru_cache(maxsize=4096)
def _get_cached_config(yaml_path: Path, mode: str) -> dict:
    """
    Load a YAML config with its ``include`` chain merged in, memoised with LRU eviction.

    Included files are merged first (for a list of includes, earlier entries
    take precedence over later ones) and the keys of the including file win
    over everything it includes.  Relative include paths are resolved against
    the directory of the file that names them.

    The merged result is cached per (yaml_path, mode) and the cached object
    itself is returned: callers must copy it before mutating it.  Individual
    file parses are cached separately by `_parse_yaml_file`.

    Args:
        yaml_path: Path of the top-level YAML file
        mode: Parsing mode forwarded to `_parse_yaml_file`

    Returns:
        The merged configuration without its ``include`` key

    Raises:
        ValueError: If the include chain loops back onto a file that is
            still being resolved
    """

    def _resolve(path: Path, chain: tuple) -> dict:
        marker = path.resolve()
        if marker in chain:
            raise ValueError(f"Include cycle detected at {path}")
        # Copy first: the parse result is shared through _parse_yaml_file's
        # cache and popping "include" from it would corrupt later lookups.
        config = dict(_parse_yaml_file(path, mode) or {})
        include = config.pop("include", None)
        if not include:
            return config

        include_paths = include if isinstance(include, list) else [include]
        merged: dict = {}
        for inc in reversed(include_paths):
            if inc is None:  # guard against explicit nulls
                continue
            inc_path = Path(inc)
            if not inc_path.is_absolute():
                inc_path = (path.parent / inc_path).resolve()
            merged.update(_resolve(inc_path, chain + (marker,)))

        merged.update(config)  # local keys win
        return merged

    return _resolve(yaml_path, ())


210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
def load_yaml_config(
    yaml_path: Union[Path, str, None] = None,
    yaml_config: dict | None = None,
    yaml_dir: Path | None = None,
    mode: str = "full",
    *,
    _seen: set[tuple[Path, str]] | None = None,
    resolve_includes: bool = True,
) -> dict:
    """
    Parse a YAML config with optional include handling.

    Parameters
    ----------
    yaml_path
        Path to the main YAML file.  Needed unless *yaml_config* is
        supplied directly (e.g. by tests).
    yaml_config
        Pre-parsed dict to use instead of reading *yaml_path*.
    yaml_dir
        Base directory for resolving relative include paths.  Defaults
        to `yaml_path.parent`.
    mode
        "full"  – honour  !function  tags
        "simple" – ignore !function  (faster).
    _seen
        **Internal** recursion set: tuples of (absolute-path, mode).
        Prevents include cycles such as  A → B → A.
    """
    if yaml_config is None and yaml_path is None:
        raise ValueError("load_yaml_config needs either yaml_path or yaml_config")

    # ------------------------------------------------------------------ cycle guard
    if _seen is None:
        _seen = set()
    if yaml_path is not None:
        yaml_path = Path(yaml_path).expanduser().resolve()

Baber's avatar
nit  
Baber committed
248
249
250
        # ---------- fast-path: use LRU cached function ----------
        if yaml_config is None and resolve_includes:
            return _get_cached_config(yaml_path, mode)
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291

        key = (yaml_path.resolve(), mode)
        if key in _seen:
            raise ValueError(f"Include cycle detected at {yaml_path}")
        _seen.add(key)

    # ------------------------------------------------------------------ load / parse
    if yaml_config is None:  # ordinary path-based load
        yaml_config = _parse_yaml_file(yaml_path, mode)

    if yaml_dir is None and yaml_path is not None:
        yaml_dir = yaml_path.parent
    assert yaml_dir is not None, "yaml_dir must be set by caller or deduced from path"

    # ------------------------------------------------------------------ handle include
    include = yaml_config.pop("include", None)
    if not include and not resolve_includes:
        return yaml_config

    include_paths = include if isinstance(include, list) else [include]
    final_cfg: dict = {}

    for inc in reversed(include_paths):
        if inc is None:  # guard against explicit nulls
            continue
        inc_path = Path(inc)
        if not inc_path.is_absolute():
            inc_path = (yaml_dir / inc_path).resolve()
        included = load_yaml_config(
            yaml_path=inc_path,
            mode=mode,
            yaml_dir=inc_path.parent,
            _seen=_seen,  # <-- pass set downward
        )
        final_cfg.update(included)

    final_cfg.update(yaml_config)  # local keys win
    return final_cfg


def iter_yaml_files(root: Path) -> Generator[Path, Any, None]:
Baber's avatar
Baber committed
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
    """
    Recursively iterate over all YAML files in a directory tree.

    Excludes files in ignored directories like __pycache__ and .ipynb_checkpoints.

    Args:
        root: Root directory to search for YAML files

    Yields:
        Path objects for each discovered YAML file

    Example:
        >>> for yaml_file in iter_yaml_files(Path("tasks")):
        ...     print(f"Found task config: {yaml_file}")
    """
Baber's avatar
nit  
Baber committed
307
308
    for p in iglob("**/*.yaml", root_dir=root, recursive=True):
        # ignore check
Baber's avatar
nit  
Baber committed
309
        if Path(p).parts[0] in _IGNORE_DIRS:
310
            continue
Baber's avatar
nit  
Baber committed
311
        yield root / p
Lintang Sutawika's avatar
Lintang Sutawika committed
312

313

314
class TaskManager:
Baber's avatar
Baber committed
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
    """
    Central manager for task discovery, indexing, and loading.

    TaskManager scans directories for YAML task configurations and maintains
    an index of all available tasks, groups, and tags. It provides methods
    for listing, filtering, and loading tasks with their configurations.

    The manager supports:
    - Automatic discovery from default lm_eval/tasks/ directory
    - Custom task directories via include_path
    - Task grouping and tagging
    - Configuration inheritance via YAML includes
    - Caching for performance

    Attributes:
        include_path: Additional directories to search for tasks
        metadata: Global metadata to inject into all task configs
        task_group_map: Mapping of tasks to their parent groups
333

Baber's avatar
Baber committed
334
335
336
337
338
339
340
341
342
343
344
345
346
347
    Example:
        Basic usage::

            tm = TaskManager()
            print(f"Found {len(tm.all_tasks)} tasks")
            hellaswag_config = tm._get_config("hellaswag")

        With custom tasks::

            tm = TaskManager(
                include_path="/my/custom/tasks",
                verbosity="INFO"
            )
            custom_tasks = [t for t in tm.all_tasks if "custom" in t]
348
349
    """

350
351
    def __init__(
        self,
Lintang Sutawika's avatar
Lintang Sutawika committed
352
        verbosity: Optional[str] = None,
Baber's avatar
nit  
Baber committed
353
        include_path: Optional[Union[str, Path, list[Union[str, Path]]]] = None,
354
        include_defaults: bool = True,
Baber Abbasi's avatar
Baber Abbasi committed
355
        metadata: Optional[dict] = None,
356
    ) -> None:
Baber's avatar
Baber committed
357
358
359
360
361
362
363
364
365
366
        """
        Initialize the TaskManager.

        Args:
            verbosity: Logging verbosity level (DEBUG, INFO, WARNING, ERROR)
            include_path: Additional path(s) to search for tasks. Can be a single
                         path or list of paths.
            include_defaults: Whether to include default tasks from lm_eval/tasks/
            metadata: Global metadata dictionary to inject into all task configs
        """
Lintang Sutawika's avatar
Lintang Sutawika committed
367
        if verbosity is not None:
Baber's avatar
nit  
Baber committed
368
            setup_logging(verbosity)
369
        self.include_path = include_path
Baber Abbasi's avatar
Baber Abbasi committed
370
        self.metadata = metadata
371
372
373
        self._task_index = self.initialize_tasks(
            include_path=include_path, include_defaults=include_defaults
        )
374
        self._all_tasks = sorted(list(self._task_index.keys()))
375

376
377
378
379
        self._all_groups = sorted(
            [x for x in self._all_tasks if self._task_index[x]["type"] == "group"]
        )
        self._all_subtasks = sorted(
380
381
382
383
384
            [
                x
                for x in self._all_tasks
                if self._task_index[x]["type"] in ["task", "python_task"]
            ]
385
386
387
388
389
        )
        self._all_tags = sorted(
            [x for x in self._all_tasks if self._task_index[x]["type"] == "tag"]
        )

390
        self.task_group_map = collections.defaultdict(list)
391

392
393
    def initialize_tasks(
        self,
Baber's avatar
nit  
Baber committed
394
        include_path: Optional[Union[str, Path, list[Union[str, Path]]]] = None,
395
        include_defaults: bool = True,
Baber Abbasi's avatar
Baber Abbasi committed
396
397
    ) -> dict[str, dict]:
        """Creates a dictionary of tasks indexes.
398

Baber's avatar
nit  
Baber committed
399
        :param include_path: Union[str, list] = None
400
401
402
403
            An additional path to be searched for tasks recursively.
            Can provide more than one such path as a list.
        :param include_defaults: bool = True
            If set to false, default tasks (those in lm_eval/tasks/) are not indexed.
Baber Abbasi's avatar
Baber Abbasi committed
404
        return
Baber's avatar
nit  
Baber committed
405
            dictionary of task names as key and task metadata
406
        """
407
        if include_defaults:
Baber's avatar
Baber committed
408
            all_paths = [Path(__file__).parent]
409
410
        else:
            all_paths = []
411
        if include_path is not None:
Baber's avatar
Baber committed
412
            if isinstance(include_path, (str, Path)):
413
                include_path = [include_path]
Baber's avatar
Baber committed
414
415
            # Convert all paths to Path objects
            all_paths.extend(Path(p) for p in include_path)
416

417
418
419
420
        task_index = {}
        for task_dir in all_paths:
            tasks = self._get_task_and_group(task_dir)
            task_index = {**tasks, **task_index}
lintangsutawika's avatar
format  
lintangsutawika committed
421

422
423
424
        return task_index

    @property
Baber's avatar
nit  
Baber committed
425
    def all_tasks(self) -> list[str]:
Baber's avatar
Baber committed
426
        """Get sorted list of all task names (tasks, groups, and tags)."""
427
428
        return self._all_tasks

429
    @property
Baber's avatar
nit  
Baber committed
430
    def all_groups(self) -> list[str]:
Baber's avatar
Baber committed
431
        """Get sorted list of all group names."""
432
433
434
        return self._all_groups

    @property
Baber's avatar
nit  
Baber committed
435
    def all_subtasks(self) -> list[str]:
Baber's avatar
Baber committed
436
        """Get sorted list of all individual task names (excludes groups and tags)."""
437
438
439
        return self._all_subtasks

    @property
Baber's avatar
nit  
Baber committed
440
    def all_tags(self) -> list[str]:
Baber's avatar
Baber committed
441
        """Get sorted list of all tag names."""
442
443
        return self._all_tags

444
    @property
Baber's avatar
nit  
Baber committed
445
    def task_index(self) -> dict[str, dict[str, Union[str, int, list[str]]]]:
Baber's avatar
Baber committed
446
        """Get the complete task index with metadata for all tasks."""
447
448
        return self._task_index

449
    def list_all_tasks(
Baber's avatar
Baber committed
450
451
452
453
        self,
        list_groups: bool = True,
        list_tags: bool = True,
        list_subtasks: bool = True,
454
    ) -> str:
455
456
457
458
459
        """
        Return a Markdown table (as a string) listing groups, tags and/or subtasks
        known to this TaskManager.  Safe for configs whose yaml_path is -1 and for
        task configs whose `include:` is a list.
        """
460
461
        from pytablewriter import MarkdownTableWriter

462
        # ------------------------------------------------------------------ helpers
Baber's avatar
Baber committed
463
        def sanitize_path(path: str) -> str:
464
465
            # print a relative path for anything inside lm_eval/tasks/
            # path_str = str(path)
466
467
            if "lm_eval/tasks/" in path:
                return "lm_eval/tasks/" + path.split("lm_eval/tasks/")[-1]
468
469
470
471
472
473
474
475
476
477
478
            return path

        def first_output_type_from_includes(cfg: dict, base: Path) -> str:
            """Walk cfg['include'] (string or list) and return the first
            include that itself specifies an output_type."""
            inc_raw = cfg.get("include")
            if not inc_raw:
                return ""

            inc_list = inc_raw if isinstance(inc_raw, list) else [inc_raw]
            for inc in inc_list:
Baber's avatar
nit  
Baber committed
479
480
481
482
483
484
485
486
487
488
                if inc:
                    inc_path = Path(inc)
                    if not inc_path.is_absolute():  # treat as relative include
                        inc_path = base.parent / inc_path
                    try:
                        inc_cfg = load_yaml_config(inc_path, mode="simple")
                    except FileNotFoundError:
                        continue
                    if "output_type" in inc_cfg:
                        return inc_cfg["output_type"]
489
490
491
            return ""

        # -------------------------------------------------------------- GROUP table
492
493
        group_table = MarkdownTableWriter()
        group_table.headers = ["Group", "Config Location"]
494
495
496
497
498
499
500
501
502
        group_table.value_matrix = [
            [
                g,
                "---"
                if self.task_index[g]["yaml_path"] == -1
                else sanitize_path(self.task_index[g]["yaml_path"]),
            ]
            for g in self.all_groups
        ]
503

504
        # ---------------------------------------------------------------- TAG table
505
506
507
508
        tag_table = MarkdownTableWriter()
        tag_table.headers = ["Tag"]
        tag_table.value_matrix = [[t] for t in self.all_tags]

509
        # ------------------------------------------------------------ SUBTASK table
510
511
        subtask_table = MarkdownTableWriter()
        subtask_table.headers = ["Task", "Config Location", "Output Type"]
512
513
        st_values: list[list[str]] = []

514
        for t in self.all_subtasks:
515
516
517
518
519
520
            raw_path = self.task_index[t]["yaml_path"]

            if raw_path == -1:
                # python-only task or generated at runtime
                display_path = "---"
                output_type = ""
521
            else:
522
523
524
525
526
527
528
529
530
531
532
533
                path_obj = Path(raw_path)
                display_path = sanitize_path(str(path_obj))

                # load minimal YAML to discover output_type
                cfg = load_yaml_config(path_obj, mode="simple")
                if "output_type" in cfg:
                    output_type = cfg["output_type"]
                else:
                    output_type = first_output_type_from_includes(cfg, path_obj)

            st_values.append([t, display_path, output_type])

534
535
        subtask_table.value_matrix = st_values

536
537
        # ------------------------------------------------------------- final string
        parts: list[str] = ["\n"]
538
        if list_groups:
539
540
            parts.append(group_table.dumps())
            parts.append("\n")
541
        if list_tags:
542
543
            parts.append(tag_table.dumps())
            parts.append("\n")
544
        if list_subtasks:
545
546
547
548
            parts.append(subtask_table.dumps())
            parts.append("\n")

        return "".join(parts)
549

Baber Abbasi's avatar
Baber Abbasi committed
550
    def match_tasks(self, task_list: list[str]) -> list[str]:
        """
        Resolve name patterns against every name in the index.

        The work is delegated to ``pattern_match`` with ``self.all_tasks`` as
        the candidate list.

        Args:
            task_list: List of task name patterns to match

        Returns:
            Whatever ``pattern_match`` returns for these patterns

        Example:
            >>> tm.match_tasks(["hella*", "arc_*"])
            ['hellaswag', 'arc_easy', 'arc_challenge']
        """
        candidates = self.all_tasks
        return pattern_match(task_list, candidates)
567

Baber Abbasi's avatar
Baber Abbasi committed
568
    def _name_is_registered(self, name: str) -> bool:
Baber's avatar
Baber committed
569
        """Check if a name is registered in the task index."""
570
        return name in self.all_tasks
571

Baber Abbasi's avatar
Baber Abbasi committed
572
    def _name_is_task(self, name: str) -> bool:
Baber's avatar
Baber committed
573
        """Check if a name refers to an individual task (not group or tag)."""
574
575
576
        return (
            self._name_is_registered(name) and self.task_index[name]["type"] == "task"
        )
Lintang Sutawika's avatar
Lintang Sutawika committed
577

Baber Abbasi's avatar
Baber Abbasi committed
578
    def _name_is_tag(self, name: str) -> bool:
Baber's avatar
Baber committed
579
        """Check if a name refers to a tag."""
580
        return self._name_is_registered(name) and self.task_index[name]["type"] == "tag"
581

Baber Abbasi's avatar
Baber Abbasi committed
582
    def _name_is_group(self, name: str) -> bool:
Baber's avatar
Baber committed
583
        """Check if a name refers to a group."""
584
585
586
        return (
            self._name_is_registered(name) and self.task_index[name]["type"] == "group"
        )
587

Baber Abbasi's avatar
Baber Abbasi committed
588
    def _name_is_python_task(self, name: str) -> bool:
Baber's avatar
Baber committed
589
        """Check if a name refers to a Python-defined task."""
590
591
592
593
        return (
            self._name_is_registered(name)
            and self.task_index[name]["type"] == "python_task"
        )
594

Baber Abbasi's avatar
Baber Abbasi committed
595
    def _config_is_task(self, config: dict) -> bool:
Baber's avatar
Baber committed
596
        """Check if a config dictionary defines a single task."""
597
        return "task" in config and isinstance(config["task"], str)
598

Baber Abbasi's avatar
Baber Abbasi committed
599
    def _config_is_group(self, config: dict) -> bool:
Baber's avatar
Baber committed
600
        """Check if a config dictionary defines a group of tasks."""
601
        return "task" in config and isinstance(config["task"], list)
602

Baber Abbasi's avatar
Baber Abbasi committed
603
    def _config_is_python_task(self, config: dict) -> bool:
Baber's avatar
Baber committed
604
        """Check if a config dictionary defines a Python class-based task."""
605
606
607
        return "class" in config

    def _config_is_task_list(self, config: dict) -> bool:
Baber's avatar
Baber committed
608
        """Check if a config dictionary defines a task list."""
609
        return "task_list" in config and isinstance(config["task_list"], list)
610

Baber's avatar
Baber committed
611
    def _get_yaml_path(self, name: str) -> Union[str, int]:
Baber's avatar
Baber committed
612
613
614
615
616
617
618
619
620
621
622
623
        """
        Get the YAML file path for a registered task.

        Args:
            name: Task name

        Returns:
            Path to YAML file, or -1 for Python-only tasks

        Raises:
            ValueError: If task name is not registered
        """
624
625
        if name not in self.task_index:
            raise ValueError
626
627
        return self.task_index[name]["yaml_path"]

Baber's avatar
nit  
Baber committed
628
    def _get_config(self, name: str) -> dict:
Baber's avatar
Baber committed
629
630
631
632
633
634
635
636
637
638
639
640
        """
        Load the full configuration for a registered task.

        Args:
            name: Task name

        Returns:
            Complete task configuration dictionary

        Raises:
            ValueError: If task name is not registered
        """
641
642
        if name not in self.task_index:
            raise ValueError
643
644
645
646
        yaml_path = self._get_yaml_path(name)
        if yaml_path == -1:
            return {}
        else:
647
            return load_yaml_config(Path(yaml_path), mode="full")
648

Baber's avatar
nit  
Baber committed
649
    def _get_tasklist(self, name: str) -> Union[list[str], int]:
Baber's avatar
Baber committed
650
651
652
653
654
655
656
657
658
659
660
661
        """
        Get the task list for a group or tag.

        Args:
            name: Group or tag name

        Returns:
            List of task names in the group/tag

        Raises:
            ValueError: If name refers to an individual task
        """
662
663
        if self._name_is_task(name):
            raise ValueError
664
665
        return self.task_index[name]["task"]

666
667
668
669
670
    def _register_task(
        self,
        task_name: str,
        task_type: str,
        yaml_path: str,
Baber's avatar
nit  
Baber committed
671
672
        tasks_and_groups: dict[str, dict],
        config: Optional[dict] = None,
Baber's avatar
Baber committed
673
674
        populate_tags_fn: Optional[callable] = None,
    ) -> None:
675
676
677
678
679
680
681
        """Helper method to register a task in the tasks_and_groups dict"""
        tasks_and_groups[task_name] = {
            "type": task_type,
            "yaml_path": yaml_path,
        }
        # Only populate tags for configs that support it (not groups)
        if config and task_type != "group" and populate_tags_fn:
682
            populate_tags_fn(config, task_name, tasks_and_groups)
683
684

    def _merge_task_configs(
Baber's avatar
nit  
Baber committed
685
686
        self, base_config: dict, task_specific_config: dict, task_name: str
    ) -> dict:
687
688
689
690
691
692
693
        """Merge base config with task-specific overrides for task_list configs"""
        if task_specific_config:
            task_specific_config = task_specific_config.copy()
            task_specific_config.pop("task", None)
            return {**base_config, **task_specific_config, "task": task_name}
        return {**base_config, "task": task_name}

Baber's avatar
Baber committed
694
    def _process_tag_subtasks(
        self, tag_name: str, update_config: Optional[dict] = None
    ) -> dict:
        """
        Load every subtask of a tag and merge the results into one dict.

        Subtasks are loaded from the last listed to the first.  In the merged
        dict, keys appear in listing order and, when two subtasks produce the
        same key, the value from the later-listed subtask is kept.
        """
        load_one = partial(
            self._load_individual_task_or_group,
            update_config=update_config,
        )
        loaded = [load_one(subtask) for subtask in reversed(self._get_tasklist(tag_name))]

        merged: dict = {}
        # Walk back in listing order so that later subtasks overwrite earlier ones.
        for mapping in reversed(loaded):
            merged.update(mapping)
        return merged

Baber's avatar
nit  
Baber committed
705
    def _process_alias(self, config: dict, group: Optional[str] = None) -> dict:
Baber's avatar
Baber committed
706
707
708
709
710
711
712
713
714
715
716
717
718
        """
        Process group alias configuration.

        If the group is not the same as the original group which the group alias
        was intended for, set the group_alias to None instead.

        Args:
            config: Task configuration dictionary
            group: Group name to validate against

        Returns:
            Modified configuration with processed aliases
        """
719
720
721
722
723
        if ("group_alias" in config) and ("group" in config) and group is not None:
            if config["group"] != group:
                config["group_alias"] = None
        return config

Baber's avatar
Baber committed
724
    def _class_has_config_in_constructor(self, cls) -> bool:
Baber's avatar
Baber committed
725
726
727
728
729
730
731
732
733
        """
        Check if a class constructor accepts a 'config' parameter.

        Args:
            cls: Class to inspect

        Returns:
            True if constructor has 'config' parameter, False otherwise
        """
734
735
736
737
738
739
740
        constructor = getattr(cls, "__init__", None)
        return (
            "config" in inspect.signature(constructor).parameters
            if constructor
            else False
        )

741
    def _load_individual_task_or_group(
        self,
        name_or_config: Optional[Union[str, dict]] = None,
        parent_name: Optional[str] = None,
        update_config: Optional[dict] = None,
    ) -> Mapping:
        """
        Load a single task or group with all its configurations and dependencies.

        This is the core method for instantiating task objects from either task names
        or configuration dictionaries. It handles complex scenarios including:
        - Individual tasks and Python class-based tasks
        - Groups and their constituent subtasks
        - Tags and their associated tasks
        - Configuration merging and inheritance
        - Duplicate detection and name resolution
        - Include processing and YAML inheritance

        The method recurses: once a group and its subtask list are resolved, it
        calls itself for every subtask with ``parent_name`` set to the group
        and ``update_config`` set to the group's task-level overrides.

        Args:
            name_or_config: Either a task name (str) or configuration dict.
                           If str, looks up the task in the index.
                           If dict, processes as inline configuration. The
                           caller's dict is not modified by the ``task`` /
                           ``include`` handling (a shallow copy is used).
            parent_name: Parent group used for duplicate detection. Recursive
                         calls pass the ``ConfigurableGroup`` built for the
                         parent; it is used as the key into
                         ``self.task_group_map``.
            update_config: Additional configuration to merge into task configs

        Returns:
            Mapping of task/group names to instantiated task objects.
            For individual tasks: {task_name: ConfigurableTask}
            For groups: {group: {subtask1: Task1, subtask2: Task2, ...}} where
            ``group`` is the ``ConfigurableGroup`` object.
            For tags: whatever ``self._process_tag_subtasks`` returns.

        Raises:
            TypeError: If ``name_or_config`` is neither a str nor a dict.

        Example:
            Load individual task::

                task_dict = tm._load_individual_task_or_group("hellaswag")
                # Returns: {"hellaswag": ConfigurableTask(...)}

            Load with config override::

                task_dict = tm._load_individual_task_or_group(
                    {"task": "hellaswag", "num_fewshot": 5}
                )

            Load a group::

                group_dict = tm._load_individual_task_or_group("arc_group")
                # Returns: {"arc_group": {"arc_easy": Task1, "arc_challenge": Task2}}
        """
        from lm_eval.api.task import ConfigurableTask, Task

        if not isinstance(name_or_config, (str, dict)):
            # Without this guard the function falls through both branches and
            # fails later with an opaque UnboundLocalError on ``group_name``.
            raise TypeError(
                "name_or_config must be a task name (str) or a config dict, "
                f"got {type(name_or_config).__name__}"
            )

        def _load_task(
            config: dict, task: str, yaml_path: Optional[str] = None
        ) -> dict[str, Union["ConfigurableTask", "Task"]]:
            """
            Create a single task object from configuration.

            Handles include processing, Python class instantiation, and metadata injection.

            Args:
                config: Task configuration dictionary
                task: Task name
                yaml_path: Path to source YAML file (for include resolution).
                    May be None for inline configs of unregistered tasks.

            Returns:
                Dictionary mapping task name to instantiated task object
            """
            if "include" in config:
                # Store the task name to preserve it after include processing.
                # This prevents tasks with the same include from being treated
                # as duplicates.
                original_task_name = config.get("task", task)
                # Build the override layer without popping from ``config`` so
                # a dict owned by the caller (or by the config index) keeps
                # its "include" key.
                overrides = {k: v for k, v in config.items() if k != "include"}
                config = {
                    **load_yaml_config(
                        # Path(None) raises TypeError, so only wrap a real
                        # path. NOTE(review): assumes load_yaml_config accepts
                        # yaml_path=None when yaml_config is given - confirm.
                        yaml_path=Path(yaml_path) if yaml_path else None,
                        yaml_config={"include": config["include"]},
                        mode="full" if yaml_path else "simple",
                    ),
                    **overrides,
                    "task": original_task_name,
                }

            if self._config_is_python_task(config):
                task_class = config["class"]
                if self._class_has_config_in_constructor(task_class):
                    task_object = task_class(config=config)
                else:
                    task_object = task_class()
                if isinstance(task_object, ConfigurableTask):
                    # very scuffed: set task name here. TODO: fixme?
                    task_object.config.task = task
            else:
                # Manager-level metadata is layered over the task's own
                # metadata (manager values take precedence).
                if self.metadata is not None:
                    config["metadata"] = config.get("metadata", {}) | self.metadata
                else:
                    config["metadata"] = config.get("metadata", {})
                task_object = ConfigurableTask(config=config)

            return {task: task_object}

        def _get_group_and_subtask_from_config(
            config: dict,
        ) -> tuple[ConfigurableGroup, list[str]]:
            """
            Extract group object and subtask list from group configuration.

            Expands any tags in the task list to their constituent tasks.

            Args:
                config: Group configuration dictionary

            Returns:
                Tuple of (ConfigurableGroup, list of subtask names or configs)
            """
            if self.metadata is not None:
                config["metadata"] = config.get("metadata", {}) | self.metadata
            group_name = ConfigurableGroup(config=config)
            subtask_list = []
            for task in group_name.config["task"]:
                if isinstance(task, str) and self._name_is_tag(task):
                    subtask_list.extend(self._get_tasklist(task))
                else:
                    subtask_list.append(task)
            return group_name, subtask_list

        def _process_group_config(
            config: dict, update_config: Optional[dict] = None
        ) -> tuple[dict, Optional[dict]]:
            """
            Separate group-specific config from task-level config overrides.

            Keys listed in GROUP_ONLY_KEYS stay with the group, while all
            other keys become config overrides for constituent tasks.

            Args:
                config: Full configuration dictionary
                update_config: Additional config to merge

            Returns:
                Tuple of (group_config, task_update_config); the second item
                is None when there are no task-level overrides.
            """
            if update_config is not None:
                config = {**config, **update_config}
            _update_config = {
                k: v for k, v in config.items() if k not in GROUP_ONLY_KEYS
            } or None
            group_config = {k: v for k, v in config.items() if k in GROUP_ONLY_KEYS}
            return group_config, _update_config

        if isinstance(name_or_config, str):
            if update_config is not None:
                # Process name_or_config as a dict instead (handled below)
                name_or_config = {"task": name_or_config, **update_config}
            elif self._name_is_task(name_or_config) or self._name_is_python_task(
                name_or_config
            ):
                # Get the yaml_path for this task
                yaml_path = self._get_yaml_path(name_or_config)
                task_config = self._get_config(name_or_config)

                # Handle task_list configs
                if "task_list" in task_config:
                    # Find the specific task entry
                    task_specific_config = next(
                        (
                            task_entry
                            for task_entry in task_config["task_list"]
                            if isinstance(task_entry, dict)
                            and task_entry.get("task") == name_or_config
                        ),
                        None,
                    )

                    if task_specific_config:
                        # Create base config without task_list
                        base_config = {
                            k: v for k, v in task_config.items() if k != "task_list"
                        }
                        # Merge using helper method
                        task_config = self._merge_task_configs(
                            base_config, task_specific_config, name_or_config
                        )
                    else:
                        # Task not found in task_list, shouldn't happen if indexing worked correctly
                        eval_logger.warning(
                            f"Task {name_or_config} not found in task_list"
                        )
                        task_config = {"task": name_or_config}

                return _load_task(task_config, task=name_or_config, yaml_path=yaml_path)
            else:
                subtask_list = self._get_tasklist(name_or_config)
                if subtask_list == -1:
                    # -1 marks a group whose task list lives in its YAML
                    group_config = self._get_config(name_or_config)
                    group_config, update_config = _process_group_config(group_config)
                    group_name, subtask_list = _get_group_and_subtask_from_config(
                        group_config
                    )
                elif self._name_is_tag(name_or_config):
                    # name_or_config is a str in this branch, so there are no
                    # inline overrides to forward.
                    return self._process_tag_subtasks(name_or_config, None)
                else:
                    group_name = ConfigurableGroup(
                        config={"group": name_or_config, "task": subtask_list}
                    )

        if isinstance(name_or_config, dict):
            if self._config_is_task(name_or_config):
                # Shallow-copy before popping so the caller's dict (e.g. an
                # entry of a group's task list) keeps its "task" key and can
                # be loaded again.
                name_or_config = dict(name_or_config)
                name = name_or_config.pop("task")
                if update_config is not None:
                    name_or_config = {**name_or_config, **update_config}
                # If the name is registered as a group
                if self._name_is_group(name):
                    group_config = self._get_config(name)

                    group_config, update_config = _process_group_config(
                        group_config, name_or_config
                    )
                    group_name, subtask_list = _get_group_and_subtask_from_config(
                        group_config
                    )
                elif self._name_is_tag(name):
                    return self._process_tag_subtasks(name, name_or_config)
                else:
                    yaml_path = None
                    if self._name_is_registered(name):
                        yaml_path = self._get_yaml_path(name)
                        base_task_config = self._get_config(name)

                        # Check if this is a duplicate.
                        if parent_name is not None:
                            num_duplicate = sum(
                                1
                                for existing in self.task_group_map[parent_name]
                                if existing.startswith(name)
                            )
                            if num_duplicate > 0:
                                # Disambiguate repeated subtasks of one group
                                name = f"{name}-{num_duplicate}"
                            self.task_group_map[parent_name].append(name)

                        # Inline values override the registered base config
                        task_config = {
                            **base_task_config,
                            **name_or_config,
                        }
                    else:
                        task_config = name_or_config
                    return _load_task(task_config, task=name, yaml_path=yaml_path)
            else:
                group_config, update_config = _process_group_config(name_or_config)
                group_name, subtask_list = _get_group_and_subtask_from_config(
                    group_config
                )

        # Only groups reach this point: load every subtask recursively.
        fn = partial(
            self._load_individual_task_or_group,
            parent_name=group_name,
            update_config=update_config,
        )
        # ChainMap gives earlier maps precedence; reversing the subtask list
        # keeps the original precedence and ordering of the merged result.
        return {
            group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
        }
1011

Baber's avatar
Baber committed
1012
    def load_task_or_group(
        self, task_list: Optional[Union[str, list[str]]] = None
    ) -> dict:
        """
        Load multiple tasks or groups from a list of names.

        This is the main entry point for loading tasks. It handles lists
        of task names and delegates to _load_individual_task_or_group for
        each item, then merges the results.

        Args:
            task_list: Single task name or list of task names to load.
                      Can include individual tasks, groups, and tags.
                      ``None`` (the default) is treated as an empty list.

        Returns:
            Dictionary mapping task/group names to loaded task objects.
            Results from all requested items are merged into a single dict;
            when two items produce the same key, the earlier item wins.

        Example:
            Load multiple tasks::

                tasks = tm.load_task_or_group(["hellaswag", "arc_easy"])
                # Returns: {"hellaswag": Task1, "arc_easy": Task2}

            Load a group::

                tasks = tm.load_task_or_group("arc_group")
                # Returns: {"arc_group": {"arc_easy": Task1, "arc_challenge": Task2}}
        """
        if task_list is None:
            # The declared default used to crash inside map(); an empty
            # request now simply yields an empty result.
            task_list = []
        elif isinstance(task_list, str):
            task_list = [task_list]

        loaded_per_item = map(self._load_individual_task_or_group, task_list)
        return dict(collections.ChainMap(*loaded_per_item))

Baber's avatar
nit  
Baber committed
1054
    def load_config(self, config: dict) -> Mapping:
        """
        Load a task from an inline configuration dictionary.

        Args:
            config: Configuration dictionary defining the task. A shallow
                copy is handed to the loader, so keys the loader removes
                while processing do not disappear from the caller's dict
                and the same ``config`` can be loaded again.

        Returns:
            Mapping of task name to loaded task object

        Example:
            >>> config = {"task": "hellaswag", "num_fewshot": 5}
            >>> task_dict = tm.load_config(config)
        """
        return self._load_individual_task_or_group(dict(config))

Baber's avatar
nit  
Baber committed
1070
    def _get_task_and_group(self, task_dir: Union[str, Path]) -> dict[str, dict]:
        """
        Scan a directory for task configurations and build an index.

        Creates a dictionary of task metadata by iterating over the YAML
        files found under ``task_dir`` and classifying each configuration:
        - Python class-based task configs
        - Group configs with 'group' key
        - Regular task configs with 'task' key
        - Task list configs with 'task_list' key
        - Tag extraction and registration

        Files that fail to load, or that match none of the categories, are
        skipped with a debug log entry.

        Args:
            task_dir: Directory path to scan for YAML task configurations

        Returns:
            Dictionary mapping task/group/tag names to metadata dictionaries.
            Entries written directly by this method contain:
            - 'type': 'group' or 'tag'
            - 'yaml_path': Path to source YAML file as a str (-1 for tags)
            - 'task': -1 for groups (resolved lazily when the group is
              loaded), or the list of member task names for tags
            Task and python-task entries are written by
            ``self._register_task`` with type 'task' / 'python_task'.

        Note:
            Configs are parsed in 'simple' mode with include resolution
            disabled, so only the keys present in each file itself are seen.
        """

        def _populate_tags_and_groups(
            config: dict, task: str, tasks_and_groups: dict[str, dict]
        ) -> None:
            """
            Extract and register tags from a task configuration.

            Tags allow grouping tasks by theme or category. This function
            processes the 'tag' field in task configs and maintains tag
            indices for quick lookup.

            Args:
                config: Task configuration dictionary
                task: Name of the task being processed
                tasks_and_groups: Master index to update with tag information
            """
            # TODO: remove group in next release
            if "tag" not in config:
                return

            tags = config["tag"]
            if isinstance(tags, str):
                tags = [tags]

            for tag in tags:
                if tag not in tasks_and_groups:
                    tasks_and_groups[tag] = {
                        "type": "tag",
                        "task": [task],
                        "yaml_path": -1,
                    }
                elif tasks_and_groups[tag]["type"] != "tag":
                    # The name is taken by a non-tag entry: skip this tag
                    # only. Remaining tags of the task are still registered
                    # (previously a `break` here silently dropped them).
                    eval_logger.info(
                        f"The tag '{tag}' is already registered as a group, this tag will not be registered. "
                        "This may affect tasks you want to call."
                    )
                else:
                    tasks_and_groups[tag]["task"].append(task)

        # TODO: remove group in next release
        # A defaultdict without a factory behaves like a plain dict; the
        # type is kept so callers receive the same kind of object as before.
        tasks_and_groups = collections.defaultdict()
        task_dir_path = Path(task_dir)

        for yaml_path in iter_yaml_files(task_dir_path):
            try:
                config = load_yaml_config(
                    yaml_path, mode="simple", resolve_includes=False
                )
            except (FileNotFoundError, YAMLError, OSError) as err:
                eval_logger.debug(f"File {yaml_path} could not be loaded ({err})")
                continue

            # The order of these checks matters: the python-task check runs
            # before the generic task check.
            if self._config_is_python_task(config):
                # This is a python class config
                self._register_task(
                    config["task"],
                    "python_task",
                    str(yaml_path),
                    tasks_and_groups,
                    config,
                    _populate_tags_and_groups,
                )
            elif self._config_is_group(config):
                # This is a group config. The -1 signals that the task list
                # is not needed for indexing; it is loaded when the group is
                # called.
                tasks_and_groups[config["group"]] = {
                    "type": "group",
                    "task": -1,
                    "yaml_path": str(yaml_path),
                }
            elif self._config_is_task(config):
                # This is a task config
                self._register_task(
                    config["task"],
                    "task",
                    str(yaml_path),
                    tasks_and_groups,
                    config,
                    _populate_tags_and_groups,
                )
            elif self._config_is_task_list(config):
                # This is a task_list config: register every named entry.
                # The file-level config (not the entry) is passed along.
                for task_entry in config["task_list"]:
                    if isinstance(task_entry, dict) and "task" in task_entry:
                        self._register_task(
                            task_entry["task"],
                            "task",
                            str(yaml_path),
                            tasks_and_groups,
                            config,
                            _populate_tags_and_groups,
                        )
            else:
                eval_logger.debug(f"File {yaml_path} could not be loaded")

        return tasks_and_groups
lintangsutawika's avatar
lintangsutawika committed
1210

1211

Baber's avatar
nit  
Baber committed
1212
def get_task_name_from_config(task_config: dict[str, str]) -> str:
    """
    Extract a task name from a configuration dictionary.

    The name is resolved in this order:
    1. the explicit ``task`` entry, if present;
    2. ``"<dataset_path>_<dataset_name>"`` if ``dataset_name`` is present;
    3. ``"<dataset_path>"`` otherwise.

    Args:
        task_config: Task configuration dictionary

    Returns:
        String name for the task

    Raises:
        KeyError: If there is no ``task`` entry and ``dataset_path`` is missing.

    Example:
        >>> config = {"task": "hellaswag", "num_fewshot": 5}
        >>> get_task_name_from_config(config)
        'hellaswag'

        >>> config = {"dataset_path": "custom", "dataset_name": "mytask"}
        >>> get_task_name_from_config(config)
        'custom_mytask'
    """
    if "task" in task_config:
        return task_config["task"]
    # Direct lookups instead of str.format(**task_config): only the needed
    # keys are touched, so unrelated (e.g. non-string) keys cannot break it.
    if "dataset_name" in task_config:
        return f"{task_config['dataset_path']}_{task_config['dataset_name']}"
    return f"{task_config['dataset_path']}"
lintangsutawika's avatar
lintangsutawika committed
1240

1241

Baber's avatar
nit  
Baber committed
1242
def get_task_name_from_object(task_object: Union["ConfigurableTask", "Task"]) -> str:
    """
    Extract the name from an instantiated task object.

    Objects exposing a ``config`` attribute report the ``"task"`` entry of
    their ``_config``. Any other object falls back to its
    ``EVAL_HARNESS_NAME`` attribute when present, else to its class name.

    Args:
        task_object: An instantiated task object

    Returns:
        String name of the task

    Example:
        >>> task = ConfigurableTask(config={"task": "hellaswag"})
        >>> get_task_name_from_object(task)
        'hellaswag'
    """
    # NOTE(review): presence is checked on ``config`` but the value is read
    # from ``_config``; presumably ``config`` is a property over ``_config``.
    if hasattr(task_object, "config"):
        return task_object._config["task"]

    # TODO: scrap this
    # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
    return getattr(task_object, "EVAL_HARNESS_NAME", type(task_object).__name__)

1271

Baber's avatar
nit  
Baber committed
1272
def _check_duplicates(task_dict: dict[str, list[str]]) -> None:
Baber's avatar
Baber committed
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
    """
    Validate that no tasks appear in multiple groups simultaneously.

    Helper function used to prevent conflicts when multiple groups claim
    the same constituent task. This could lead to ambiguous configuration
    like conflicting num_fewshot values.

    Args:
        task_dict: Dictionary mapping group names to lists of subtask names

    Raises:
        ValueError: If any tasks appear in multiple groups

    Example:
        >>> task_dict = {
        ...     "group1": ["task_a", "task_b"],
        ...     "group2": ["task_b", "task_c"]  # task_b appears twice!
        ... }
        >>> _check_duplicates(task_dict)  # Raises ValueError
Lintang Sutawika's avatar
Lintang Sutawika committed
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
    """
    subtask_names = []
    for key, value in task_dict.items():
        subtask_names.extend(value)

    duplicate_tasks = {
        task_name for task_name in subtask_names if subtask_names.count(task_name) > 1
    }

    # locate the potentially problematic groups that seem to 'compete' for constituent subtasks
    competing_groups = [
        group
        for group in task_dict.keys()
        if len(set(task_dict[group]).intersection(duplicate_tasks)) > 0
    ]

    if len(duplicate_tasks) > 0:
        raise ValueError(
            f"Found 1 or more tasks while trying to call get_task_dict() that were members of more than 1 called group: {list(duplicate_tasks)}. Offending groups: {competing_groups}. Please call groups which overlap their constituent tasks in separate evaluation runs."
        )


1314
def get_task_dict(
    task_name_list: Union[str, list[Union[str, dict, "Task"]]],
    task_manager: Optional[TaskManager] = None,
) -> dict[str, Union["ConfigurableTask", "Task"]]:
    """
    Create a dictionary of task objects from mixed input types.

    This is the main public API for loading tasks. It accepts various input
    formats (names, configs, objects) and returns a unified dictionary of
    instantiated task objects ready for evaluation.

    The function handles:
    - String task names (looked up via TaskManager)
    - Configuration dictionaries (processed as inline configs)
    - Pre-instantiated Task objects (used as-is)
    - Validation to prevent conflicting group memberships

    Args:
        task_name_list: A single task name, or a mixed list of task specifications:
                       - str: Task name to look up
                       - dict: Inline task configuration
                       - Task: Pre-instantiated task object
        task_manager: TaskManager instance used to resolve names and inline
                     configs. If None, a default TaskManager is created when
                     at least one str or dict entry has to be resolved.

    Returns:
        Dictionary mapping task names to instantiated task objects.
        Entries loaded from names come first, then inline configs, then
        pre-instantiated objects; later entries override earlier ones that
        share a key.

    Raises:
        TypeError: If task_name_list is neither a str nor a list, or if the
            list contains unsupported types
        ValueError: If a pre-instantiated Task shares its name with a task
            loaded by name, or if there are conflicting group memberships

    Example:
        Mixed input types::

            tasks = get_task_dict([
                "hellaswag",                              # lookup by name
                {"task": "arc_easy", "num_fewshot": 5},   # inline config
                pre_existing_task_object                  # direct object
            ])

        Simple case::

            tasks = get_task_dict("hellaswag")
            # Returns: {"hellaswag": ConfigurableTask(...)}

        With custom TaskManager::

            tm = TaskManager(include_path="/custom/tasks")
            tasks = get_task_dict(["custom_task"], task_manager=tm)
    """
    # Imported here rather than at module level (the module only imports these
    # names under TYPE_CHECKING).
    from lm_eval.api.task import Task

    if isinstance(task_name_list, str):
        task_name_list = [task_name_list]
    elif isinstance(task_name_list, list):
        if not all(isinstance(task, (str, dict, Task)) for task in task_name_list):
            raise TypeError(
                "Expected all list items to be of types 'str', 'dict', or 'Task', but at least one entry did not match."
            )
    else:
        raise TypeError(
            f"Expected a 'str' or 'list' but received {type(task_name_list)}."
        )

    string_task_name_list = [task for task in task_name_list if isinstance(task, str)]
    others_task_name_list = [
        task for task in task_name_list if not isinstance(task, str)
    ]

    # Both name lookups and inline configs go through the TaskManager, so a
    # default one must exist even when the caller passed only dict configs.
    needs_task_manager = bool(string_task_name_list) or any(
        isinstance(task, dict) for task in others_task_name_list
    )
    if task_manager is None and needs_task_manager:
        task_manager = TaskManager()

    task_name_from_string_dict = {}
    task_name_from_config_dict = {}
    task_name_from_object_dict = {}

    if string_task_name_list:
        task_name_from_string_dict = task_manager.load_task_or_group(
            string_task_name_list
        )

    for task_element in others_task_name_list:
        if isinstance(task_element, dict):
            task_name_from_config_dict.update(
                task_manager.load_config(config=task_element)
            )
        elif isinstance(task_element, Task):
            task_name_from_object_dict[get_task_name_from_object(task_element)] = (
                task_element
            )

    # A pre-instantiated object must not silently shadow a task that was also
    # requested by name.
    clashing_names = [
        name
        for name in task_name_from_object_dict
        if name in task_name_from_string_dict
    ]
    if clashing_names:
        raise ValueError(
            f"Task name(s) {clashing_names} were requested both by name and as pre-instantiated Task objects. Please pass each task only once."
        )

    final_task_dict = {
        **task_name_from_string_dict,
        **task_name_from_config_dict,
        **task_name_from_object_dict,
    }

    # behavior can get odd if one tries to invoke several groups that "compete" for the same task.
    # (notably, because one could request several num_fewshot values at once in GroupConfig overrides for the subtask
    # and we'd be unsure which to use and report.)
    # we explicitly check and error in this case.
    _check_duplicates(get_subtask_list(final_task_dict))

    return final_task_dict