Commit 15e930af authored by Baber

refactor: improve type hints and simplify YAML loading functions

parent 4a0a8bd8
# ruff: noqa: E402
from __future__ import annotations
"""
Task Management Module for LM Evaluation Harness.
@@ -26,7 +30,6 @@ Example:
include_defaults=True
)
"""
import collections
import functools
import importlib.util
@@ -34,16 +37,11 @@ import inspect
import logging
import sys
from functools import partial
from glob import iglob
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generator,
Mapping,
Optional,
Union,
)
import yaml
@@ -55,6 +53,8 @@ from lm_eval.utils import pattern_match, setup_logging
if TYPE_CHECKING:
from collections.abc import Generator, Mapping
from lm_eval.api.task import ConfigurableTask, Task
eval_logger = logging.getLogger(__name__)
@@ -72,212 +72,200 @@ _IGNORE_DIRS = (
)
def ignore_constructor(loader: yaml.Loader, node: yaml.Node) -> None:
"""YAML constructor that ignores !function tags during simple parsing."""
return None
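# Illustrative example (hypothetical YAML, not from this repo): with this
# constructor registered, a document containing
#     process_docs: !function utils.process_docs
# parses to {"process_docs": None} -- enough to index task names without imports.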
def _mk_function_ctor(base_dir: Path, resolve: bool):
"""Return a constructor that resolves !function relative to *base_dir*."""
def ctor(loader: yaml.Loader, node: yaml.Node):
spec = loader.construct_scalar(node)
if not resolve: # “simple” mode → stub
return lambda *a, **kw: None
return _import_function(spec, base_dir)
@functools.lru_cache(maxsize=2048) # ← reuse per (directory, simple) pair
def _make_loader(yaml_dir: Path, simple: bool = False) -> type[yaml.Loader]:
"""
Return a custom YAML Loader class bound to *yaml_dir*.
yaml_dir
Directory that holds the YAML file being parsed.
We capture it so that !function look-ups can resolve relative
Python files like my_utils.some_fn ➜ yaml_dir / "my_utils.py".
simple
If True we ignore !function completely (used by `mode="simple"`);
TaskManager relies on this at init time to index tasks quickly.
return ctor
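# Usage sketch (hypothetical task file): given tasks/my_task.yaml containing
#     process_results: !function my_utils.score
# a ctor built with resolve=True imports my_utils.py from base_dir and returns
# the `score` callable; with resolve=False it returns a harmless no-op stub.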
@functools.lru_cache(maxsize=1024)
def make_yaml_loader(base_dir: Path, *, simple: bool) -> type[yaml.Loader]:
"""Factory that returns a *cached* PyYAML Loader subclass bound to *base_dir*.
simple=True → !function returns a stub (used when only metadata is needed).
"""
class Loader(_Base):
"""Dynamically-generated loader that knows its base directory."""
# Register (or stub) the !function constructor **for this Loader only**
if simple:
yaml.add_constructor("!function", ignore_constructor, Loader=Loader)
else:
yaml.add_constructor(
"!function",
# capture yaml_dir once so the lambda is fast and pickle-able
lambda ld, node, _dir=yaml_dir: _import_function(
ld.construct_scalar(node),
base_path=_dir,
),
Loader=Loader,
)
pass # dynamic subclass just to carry custom constructors
yaml.add_constructor(
"!function",
_mk_function_ctor(base_dir, resolve=not simple),
Loader=Loader,
)
return Loader
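# Minimal usage sketch (paths hypothetical): the factory is cached, so every
# YAML file in one directory shares a single Loader subclass per
# (base_dir, simple) pair.
#
#     loader_cls = make_yaml_loader(Path("tasks"), simple=False)
#     with open("tasks/my_task.yaml", "rb") as fh:
#         cfg = yaml.load(fh, Loader=loader_cls)
#     # cfg["process_results"] is now the imported callable, not a string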
@functools.lru_cache(maxsize=None) # ← cache module objects
def _import_function(qualname: str, *, base_path: Path) -> Callable:
"""
Dynamically import a function from a Python module relative to base_path.
This function enables YAML files to reference Python functions using
the !function tag. Supports dot notation for nested modules.
Args:
qualname: Qualified function name like "my_module.my_function"
base_path: Base directory for resolving relative module paths
Returns:
The imported callable function
@functools.lru_cache(maxsize=4096)
def _read_yaml(path: Path, *, resolve_functions: bool) -> dict:
loader_cls = make_yaml_loader(path.parent, simple=not resolve_functions)
with path.open("rb") as fh:
return yaml.load(fh, Loader=loader_cls)
Raises:
ValueError: If qualname doesn't contain a module part
Example:
>>> func = _import_function("utils.custom_metric", base_path=Path("/tasks"))
>>> result = func(predictions, references)
@functools.cache
def _import_function(qual: str, base_dir: Path):
"""Import `qual` where qual looks like "my_utils.some_fn".
Search order:
1. <base_dir>/my_utils.py (relative file)
2. python importlib (package/module already importable)
Uses file *mtime* so edits are reloaded without killing the process.
"""
mod_path, _, func_name = qualname.rpartition(".")
if not mod_path:
raise ValueError(f"{qualname!r} has no module part")
file_path = base_path / f"{mod_path.replace('.', '/')}.py"
module_name = f"_yaml_dynamic.{hash(file_path)}_{file_path.stem}"
if module_name in sys.modules:
mod = sys.modules[module_name]
else:
spec = importlib.util.spec_from_file_location(module_name, file_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
sys.modules[module_name] = mod
return getattr(mod, func_name)
import importlib
@functools.lru_cache(maxsize=4096)
def _parse_yaml_file(path: Path, mode: str) -> dict:
"""
Parse a single YAML file with the appropriate loader.
if "." not in qual:
msg = f"!function value '{qual}' must contain a '.'"
raise ValueError(msg)
Args:
path: Path to the YAML file
mode: Parsing mode ("full" or "simple")
Returns:
Parsed YAML configuration as dictionary
"""
loader_cls = _make_loader(path.parent, simple=(mode == "simple"))
with path.open("rb") as fh:
return yaml.load(fh, Loader=loader_cls)
mod_part, _, fn_name = qual.rpartition(".")
relative_path = (base_dir / f"{mod_part.replace('.', '/')}.py").resolve()
if relative_path.exists():
mtime = relative_path.stat().st_mtime_ns # for cache busting
module_key = f"{relative_path}:{mtime}"
if module_key in sys.modules:
mod = sys.modules[module_key]
else:
spec = importlib.util.spec_from_file_location(module_key, relative_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod) # type: ignore[arg-type]
sys.modules[module_key] = mod
return getattr(mod, fn_name)
@functools.lru_cache(maxsize=4096)
def _get_cached_config(yaml_path: Path, mode: str) -> dict:
"""Load and cache resolved YAML configs"""
# Parse the YAML file
yaml_config = _parse_yaml_file(yaml_path, mode)
yaml_dir = yaml_path.parent
# Handle includes
include = yaml_config.pop("include", None)
if not include:
return yaml_config
include_paths = include if isinstance(include, list) else [include]
final_cfg: dict = {}
for inc in reversed(include_paths):
if inc is None:
continue
inc_path = Path(inc)
if not inc_path.is_absolute():
inc_path = (yaml_dir / inc_path).resolve()
# Recursive call will use the cache
included = _get_cached_config(inc_path, mode)
final_cfg.update(included)
# Fallback to regular import mechanism
import importlib
final_cfg.update(yaml_config) # local keys win
return final_cfg
module = importlib.import_module(mod_part)
return getattr(module, fn_name)
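# Worked example (hypothetical layout): with tasks/my_utils.py defining
#     def lowercase(doc): return doc.lower()
# _import_function("my_utils.lowercase", Path("tasks")) loads tasks/my_utils.py
# under an mtime-stamped module key and returns the `lowercase` callable; if no
# such file exists it falls back to importlib.import_module("my_utils").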
def load_yaml_config(
yaml_path: Union[Path, str, None] = None,
yaml_config: Optional[dict] = None,
yaml_dir: Optional[Path] = None,
mode: str = "full",
yaml_path: Path | str,
*,
_seen: Optional[set[tuple[Path, str]]] = None,
resolve_functions: bool = True,
resolve_includes: bool = True,
_seen: set[tuple[Path, bool]] | None = None,
) -> dict:
"""
Parse a YAML config with optional include handling.
Parameters
----------
yaml_path
Path to the main YAML file. Needed unless *yaml_config* is
supplied directly (e.g. by tests).
yaml_config
Pre-parsed dict to use instead of reading *yaml_path*.
yaml_dir
Base directory for resolving relative include paths. Defaults
to `yaml_path.parent`.
mode
"full" – honour !function tags
"simple" – ignore !function (faster).
_seen
**Internal** recursion set: tuples of (absolute-path, mode).
Prevents include cycles such as A → B → A.
"""
if yaml_config is None and yaml_path is None:
raise ValueError("load_yaml_config needs either yaml_path or yaml_config")
# ------------------------------------------------------------------ cycle guard
"""Read YAML once, optionally walk `include:` chains, with cycle detection."""
path = Path(yaml_path).expanduser().resolve()
if _seen is None:
_seen = set()
if yaml_path is not None:
yaml_path = Path(yaml_path).expanduser().resolve()
# ---------- fast-path: use LRU cached function ----------
if yaml_config is None and resolve_includes:
return _get_cached_config(yaml_path, mode)
key = (yaml_path.resolve(), mode)
if key in _seen:
raise ValueError(f"Include cycle detected at {yaml_path}")
_seen.add(key)
# ------------------------------------------------------------------ load / parse
if yaml_config is None: # ordinary path-based load
yaml_config = _parse_yaml_file(yaml_path, mode)
if yaml_dir is None and yaml_path is not None:
yaml_dir = yaml_path.parent
assert yaml_dir is not None, "yaml_dir must be set by caller or deduced from path"
# ------------------------------------------------------------------ handle include
include = yaml_config.pop("include", None)
if not include and not resolve_includes:
return yaml_config
include_paths = include if isinstance(include, list) else [include]
final_cfg: dict = {}
for inc in reversed(include_paths):
if inc is None: # guard against explicit nulls
continue
inc_path = Path(inc)
if not inc_path.is_absolute():
inc_path = (yaml_dir / inc_path).resolve()
included = load_yaml_config(
yaml_path=inc_path,
mode=mode,
yaml_dir=inc_path.parent,
_seen=_seen, # <-- pass set downward
key = (path, resolve_functions)
if key in _seen:
msg = f"Include cycle at {path}"
raise ValueError(msg)
_seen.add(key)
cfg = _read_yaml(path, resolve_functions=resolve_functions)
if not resolve_includes or "include" not in cfg:
return cfg
base_dir = path.parent
merged: dict = {}
includes = cfg.pop("include")
if isinstance(includes, str):  # tolerate `include: other.yaml` given as a bare string
includes = [includes]
for inc in includes:
inc_path = (
(base_dir / inc).resolve() if not Path(inc).is_absolute() else Path(inc)
)
final_cfg.update(included)
final_cfg.update(yaml_config) # local keys win
return final_cfg
merged.update(
load_yaml_config(
inc_path,
resolve_functions=resolve_functions,
_seen=_seen,
),
)
merged.update(cfg) # local keys win
return merged
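# Sketch of the include-merge semantics (hypothetical files): if base.yaml holds
#     metric: acc
#     num_fewshot: 0
# and child.yaml holds
#     include: base.yaml
#     num_fewshot: 5
# then load_yaml_config("child.yaml") returns {"metric": "acc", "num_fewshot": 5}:
# included keys are applied first, and the including file's own keys win.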
def iter_yaml_files(root: Path, ignore=_IGNORE_DIRS) -> Generator[Path, Any, None]:
"""
Recursively iterate over all YAML files in a directory tree.
"""Recursively iterate over all YAML files in a directory tree.
Excludes files in ignored directories like __pycache__ and .ipynb_checkpoints.
@@ -290,8 +278,10 @@ def iter_yaml_files(root: Path, ignore=_IGNORE_DIRS) -> Generator[Path, Any, None]:
Example:
>>> for yaml_file in iter_yaml_files(Path("tasks")):
... print(f"Found task config: {yaml_file}")
"""
for p in iglob(str(root / "**/*.yaml"), recursive=True):
for p in root.glob("**/*.yaml"):
# ignore check
path = Path(p)
# Check if any parent directory is in the ignore list
@@ -301,8 +291,7 @@ def iter_yaml_files(root: Path, ignore=_IGNORE_DIRS) -> Generator[Path, Any, None]:
class TaskManager:
"""
Central manager for task discovery, indexing, and loading.
"""Central manager for task discovery, indexing, and loading.
TaskManager scans directories for YAML task configurations and maintains
an index of all available tasks, groups, and tags. It provides methods
@@ -334,17 +323,17 @@ class TaskManager:
verbosity="INFO"
)
custom_tasks = [t for t in tm.all_tasks if "custom" in t]
"""
def __init__(
self,
verbosity: Optional[str] = None,
include_path: Optional[Union[str, Path, list[Union[str, Path]]]] = None,
verbosity: str | None = None,
include_path: str | Path | list[str | Path] | None = None,
include_defaults: bool = True,
metadata: Optional[dict[str, dict[str, Any]]] = None,
metadata: dict[str, dict[str, Any]] | None = None,
) -> None:
"""
Initialize the TaskManager.
"""Initialize the TaskManager.
Args:
verbosity: Logging verbosity level (DEBUG, INFO, WARNING, ERROR)
@@ -352,35 +341,37 @@
path or list of paths.
include_defaults: Whether to include default tasks from lm_eval/tasks/
metadata: Global metadata dictionary to inject into all task configs
"""
if verbosity is not None:
setup_logging(verbosity)
self.include_path = include_path
self.metadata = metadata
self._task_index = self.initialize_tasks(
include_path=include_path, include_defaults=include_defaults
include_path=include_path,
include_defaults=include_defaults,
)
self._all_tasks = sorted(list(self._task_index.keys()))
self._all_tasks = sorted(self._task_index.keys())
self._all_groups = sorted(
[x for x in self._all_tasks if self._task_index[x]["type"] == "group"]
[x for x in self._all_tasks if self._task_index[x]["type"] == "group"],
)
self._all_subtasks = sorted(
[
x
for x in self._all_tasks
if self._task_index[x]["type"] in ["task", "python_task"]
]
],
)
self._all_tags = sorted(
[x for x in self._all_tasks if self._task_index[x]["type"] == "tag"]
[x for x in self._all_tasks if self._task_index[x]["type"] == "tag"],
)
self.task_group_map = collections.defaultdict(list)
def initialize_tasks(
self,
include_path: Optional[Union[str, Path, list[Union[str, Path]]]] = None,
include_path: str | Path | list[str | Path] | None = None,
include_defaults: bool = True,
) -> dict[str, dict]:
"""Creates a dictionary of tasks indexes.
@@ -390,13 +381,12 @@ class TaskManager:
Can provide more than one such path as a list.
:param include_defaults: bool = True
If set to false, default tasks (those in lm_eval/tasks/) are not indexed.
Returns:
dictionary mapping task names to their metadata
"""
if include_defaults:
all_paths = [Path(__file__).parent]
else:
all_paths = []
all_paths = [Path(__file__).parent] if include_defaults else []
if include_path is not None:
if isinstance(include_path, (str, Path)):
include_path = [include_path]
@@ -431,7 +421,7 @@ class TaskManager:
return self._all_tags
@property
def task_index(self) -> dict[str, dict[str, Union[str, int, list[str]]]]:
def task_index(self) -> dict[str, dict[str, str | int | list[str]]]:
"""Get the complete task index with metadata for all tasks."""
return self._task_index
@@ -441,8 +431,7 @@ class TaskManager:
list_tags: bool = True,
list_subtasks: bool = True,
) -> str:
"""
Return a Markdown table (as a string) listing groups, tags and/or subtasks
"""Return a Markdown table (as a string) listing groups, tags and/or subtasks
known to this TaskManager. Safe for configs whose yaml_path is -1 and for
task configs whose `include:` is a list.
"""
@@ -458,7 +447,8 @@ class TaskManager:
def first_output_type_from_includes(cfg: dict, base: Path) -> str:
"""Walk cfg['include'] (string or list) and return the first
include that itself specifies an output_type."""
include that itself specifies an output_type.
"""
inc_raw = cfg.get("include")
if not inc_raw:
return ""
@@ -587,9 +577,8 @@ class TaskManager:
"""Check if a config dictionary defines a task list."""
return "task_list" in config and isinstance(config["task_list"], list)
def _get_yaml_path(self, name: str) -> Union[str, int, list[str]]:
"""
Get the YAML file path for a registered task.
def _get_yaml_path(self, name: str) -> str | int | list[str]:
"""Get the YAML file path for a registered task.
Args:
name: Task name
@@ -599,14 +588,14 @@ class TaskManager:
Raises:
ValueError: If task name is not registered
"""
if name not in self.task_index:
msg = f"Task '{name}' is not registered."
raise ValueError(msg)
return self.task_index[name]["yaml_path"]
def _get_config(self, name: str) -> dict:
"""
Load the full configuration for a registered task.
"""Load the full configuration for a registered task.
Args:
name: Task name
@@ -616,18 +605,17 @@ class TaskManager:
Raises:
ValueError: If task name is not registered
"""
if name not in self.task_index:
msg = f"Task '{name}' is not registered."
raise ValueError(msg)
yaml_path = self._get_yaml_path(name)
if yaml_path == -1:
return {}
else:
return load_yaml_config(Path(yaml_path), mode="full")
return load_yaml_config(Path(yaml_path))
def _get_tasklist(self, name: str) -> Union[list[str], int]:
"""
Get the task list for a group or tag.
def _get_tasklist(self, name: str) -> list[str] | int:
"""Get the task list for a group or tag.
Args:
name: Group or tag name
@@ -637,6 +625,7 @@ class TaskManager:
Raises:
ValueError: If name refers to an individual task
"""
if self._name_is_task(name):
msg = f"'{name}' refers to an individual task, not a group or tag."
raise ValueError(msg)
@@ -648,10 +637,10 @@ class TaskManager:
task_type: str,
yaml_path: str,
tasks_and_groups: dict[str, dict],
config: Optional[dict] = None,
populate_tags_fn: Optional[Callable] = None,
config: dict | None = None,
populate_tags_fn: Callable | None = None,
) -> None:
"""Helper method to register a task in the tasks_and_groups dict"""
"""Helper method to register a task in the tasks_and_groups dict."""
tasks_and_groups[task_name] = {
"type": task_type,
"yaml_path": yaml_path,
@@ -661,9 +650,12 @@ class TaskManager:
populate_tags_fn(config, task_name, tasks_and_groups)
def _merge_task_configs(
self, base_config: dict, task_specific_config: dict, task_name: str
self,
base_config: dict,
task_specific_config: dict,
task_name: str,
) -> dict:
"""Merge base config with task-specific overrides for task_list configs"""
"""Merge base config with task-specific overrides for task_list configs."""
if task_specific_config:
task_specific_config = task_specific_config.copy()
task_specific_config.pop("task", None)
@@ -671,9 +663,11 @@ class TaskManager:
return {**base_config, "task": task_name}
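# Hedged example (the hunk above elides where overrides are folded into the
# base): merging base={"dataset_path": "d", "num_fewshot": 0} with override
# {"task": "old_name", "num_fewshot": 5} for task_name "sub1" is expected to
# drop the override's "task" key and yield
# {"dataset_path": "d", "num_fewshot": 5, "task": "sub1"}.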
def _process_tag_subtasks(
self, tag_name: str, update_config: Optional[dict] = None
self,
tag_name: str,
update_config: dict | None = None,
) -> dict:
"""Process subtasks for a tag and return loaded tasks"""
"""Process subtasks for a tag and return loaded tasks."""
subtask_list = self._get_tasklist(tag_name)
fn = partial(
self._load_individual_task_or_group,
@@ -681,9 +675,8 @@ class TaskManager:
)
return dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
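# Note on precedence: collections.ChainMap prefers earlier mappings, and the
# subtask list is mapped in reverse, so when two subtasks share a name the one
# listed later under the tag wins (assumed intent).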
def _process_alias(self, config: dict, group: Optional[str] = None) -> dict:
"""
Process group alias configuration.
def _process_alias(self, config: dict, group: str | None = None) -> dict:
"""Process group alias configuration.
If the group is not the same as the original group which the group alias
was intended for, set the group_alias to None instead.
@@ -694,21 +687,26 @@ class TaskManager:
Returns:
Modified configuration with processed aliases
"""
if ("group_alias" in config) and ("group" in config) and group is not None:
if config["group"] != group:
config["group_alias"] = None
if (
("group_alias" in config)
and ("group" in config)
and group is not None
and config["group"] != group
):
config["group_alias"] = None
return config
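# Example: _process_alias({"group": "math", "group_alias": "Math Suite"},
# group="science") clears group_alias, since the alias was declared for the
# "math" group and would be misleading under "science".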
def _class_has_config_in_constructor(self, cls) -> bool:
"""
Check if a class constructor accepts a 'config' parameter.
"""Check if a class constructor accepts a 'config' parameter.
Args:
cls: Class to inspect
Returns:
True if constructor has 'config' parameter, False otherwise
"""
constructor = getattr(cls, "__init__", None)
return (
@@ -725,10 +723,9 @@ class TaskManager:
self,
cfg: dict,
task_name: str,
yaml_path: Union[str, None],
yaml_path: str | None,
) -> dict:
"""
Instantiate a single task (ConfigurableTask **or** python-task) from *cfg*.
"""Instantiate a single task (ConfigurableTask **or** python-task) from *cfg*.
Returns {task_name: task_object}.
"""
from lm_eval.api.task import ConfigurableTask, Task # local import avoids cycle
@@ -772,17 +769,16 @@ class TaskManager:
def _create_group_object(
self,
cfg: dict,
parent_name: Union[str, None] = None,
) -> tuple[GroupConfig, list[Union[str, dict]]]:
"""
Build GroupConfig and return (group_obj, subtask_names).
parent_name: str | None = None,
) -> tuple[GroupConfig, list[str | dict]]:
"""Build GroupConfig and return (group_obj, subtask_names).
Resolves tag expansion.
"""
if self.metadata is not None:
cfg["metadata"] = cfg.get("metadata", {}) | self.metadata
grp = GroupConfig(**cfg)
subtasks: list[Union[str, dict]] = []
subtasks: list[str | dict] = []
if grp.task:
for t in grp.task:
if isinstance(t, str) and self._name_is_tag(t):
@@ -793,9 +789,9 @@ class TaskManager:
def _load_subtasks(
self,
subtasks: list[Union[str, dict]],
parent_name: Union[str, GroupConfig, None],
update_config: Union[dict, None],
subtasks: list[str | dict],
parent_name: str | GroupConfig | None,
update_config: dict | None,
) -> Mapping:
"""Return merged mapping of all subtasks, handling duplicates."""
fn = functools.partial(
@@ -807,16 +803,14 @@ class TaskManager:
def _load_individual_task_or_group(
self,
payload: Union[str, dict],
payload: str | dict,
*,
parent_name: Union[str, None] = None,
update_config: Union[dict, None] = None,
parent_name: str | None = None,
update_config: dict | None = None,
) -> Mapping:
"""
Public helper that turns *payload* (str task/group/tag **or** dict config)
"""Public helper that turns *payload* (str task/group/tag **or** dict config)
into a nested Mapping of {name_or_group_obj: task_obj | sub_mapping}.
"""
# ------------------------------------------------------------------ STRING
if isinstance(payload, str):
# If caller supplied extra overrides, treat as dict immediately
@@ -852,14 +846,15 @@ class TaskManager:
grp_only = {k: v for k, v in group_cfg.items() if k in GROUP_ONLY_KEYS}
grp_obj, subtasks = self._create_group_object(grp_only, parent_name)
return {
grp_obj: self._load_subtasks(subtasks, grp_obj, update_config=None)
grp_obj: self._load_subtasks(subtasks, grp_obj, update_config=None),
}
# ------------ registered TAG ------------------------------------
if self._name_is_tag(payload):
return self._process_tag_subtasks(payload, update_config=None)
raise ValueError(f"Unknown task / group / tag name: {payload!r}")
msg = f"Unknown task / group / tag name: {payload!r}"
raise ValueError(msg)
# ------------------------------------------------------------------- DICT
if isinstance(payload, dict):
@@ -882,7 +877,7 @@ class TaskManager:
n
for n in self.task_group_map[parent_name]
if n.startswith(name)
]
],
)
if count:
name = f"{name}-{count}"
@@ -904,15 +899,16 @@ class TaskManager:
name = payload["task"]
return self._create_task_object(payload, name, yaml_path=None)
msg = f"_load_individual_task_or_group expected str | dict, got {type(payload)}"
raise TypeError(
f"_load_individual_task_or_group expected str | dict, got {type(payload)}"
msg,
)
def load_task_or_group(
self, task_list: Optional[Union[str, list[str]]] = None
self,
task_list: str | list[str] | None = None,
) -> dict:
"""
Load multiple tasks or groups from a list of names.
"""Load multiple tasks or groups from a list of names.
This is the main entry point for loading tasks. It handles lists
of task names and delegates to _load_individual_task_or_group for
@@ -936,23 +932,19 @@ class TaskManager:
tasks = tm.load_task_or_group("arc_group")
# Returns: {"arc_group": {"arc_easy": Task1, "arc_challenge": Task2}}
"""
if isinstance(task_list, str):
task_list = [task_list]
all_loaded_tasks = dict(
return dict(
collections.ChainMap(
*map(
lambda task: self._load_individual_task_or_group(task),
task_list,
)
)
*(self._load_individual_task_or_group(task) for task in task_list),
),
)
return all_loaded_tasks
def load_config(self, config: dict) -> Mapping:
"""
Load a task from an inline configuration dictionary.
"""Load a task from an inline configuration dictionary.
Args:
config: Configuration dictionary defining the task
@@ -963,12 +955,12 @@ class TaskManager:
Example:
>>> config = {"task": "hellaswag", "num_fewshot": 5}
>>> task_dict = tm.load_config(config)
"""
return self._load_individual_task_or_group(config)
def _get_task_and_group(self, task_dir: Union[str, Path]) -> dict[str, dict]:
"""
Scan a directory for task configurations and build an index.
def _get_task_and_group(self, task_dir: str | Path) -> dict[str, dict]:
"""Scan a directory for task configurations and build an index.
Creates a dictionary of task metadata by recursively scanning for
YAML files and parsing their configurations. This method handles:
@@ -991,13 +983,15 @@ class TaskManager:
Note:
This method is called during TaskManager initialization to build
the master task index. It uses 'simple' parsing mode for performance.
"""
def _populate_tags_and_groups(
config: dict, task: str, tasks_and_groups: dict[str, dict]
config: dict,
task: str,
tasks_and_groups: dict[str, dict],
) -> None:
"""
Extract and register tags from a task configuration.
"""Extract and register tags from a task configuration.
Tags allow grouping tasks by theme or category. This function
processes the 'tag' field in task configs and maintains tag
@@ -1007,6 +1001,7 @@ class TaskManager:
config: Task configuration dictionary
task: Name of the task being processed
tasks_and_groups: Master index to update with tag information
"""
# TODO: remove group in next release
if "tag" in config:
@@ -1024,7 +1019,7 @@ class TaskManager:
elif tasks_and_groups[tag]["type"] != "tag":
eval_logger.info(
f"The tag '{tag}' is already registered as a group, this tag will not be registered. "
"This may affect tasks you want to call."
"This may affect tasks you want to call.",
)
break
else:
@@ -1041,7 +1036,9 @@ class TaskManager:
for yaml_path in iter_yaml_files(task_dir_path):
try:
config = load_yaml_config(
yaml_path, mode="simple", resolve_includes=False
yaml_path,
resolve_functions=False,
resolve_includes=False,
)
except (FileNotFoundError, YAMLError, OSError) as err:
eval_logger.debug(f"File {yaml_path} could not be loaded ({err})")
@@ -1109,8 +1106,7 @@ class TaskManager:
def get_task_name_from_config(task_config: dict[str, str]) -> str:
"""
Extract a task name from a configuration dictionary.
"""Extract a task name from a configuration dictionary.
Determines the canonical name for a task based on its configuration,
with fallback strategies for different config formats.
@@ -1129,18 +1125,17 @@ def get_task_name_from_config(task_config: dict[str, str]) -> str:
>>> config = {"dataset_path": "custom", "dataset_name": "mytask"}
>>> get_task_name_from_config(config)
'custom_mytask'
"""
if "task" in task_config:
return task_config["task"]
if "dataset_name" in task_config:
return "{dataset_path}_{dataset_name}".format(**task_config)
else:
return "{dataset_path}".format(**task_config)
return "{dataset_path}".format(**task_config)
def get_task_name_from_object(task_object: Union["ConfigurableTask", "Task"]) -> str:
"""
Extract the name from an instantiated task object.
def get_task_name_from_object(task_object: ConfigurableTask | Task) -> str:
"""Extract the name from an instantiated task object.
Handles both ConfigurableTask and legacy Task objects with different
attribute conventions for storing the task name.
@@ -1155,6 +1150,7 @@ def get_task_name_from_object(task_object: Union["ConfigurableTask", "Task"]) -> str:
>>> task = ConfigurableTask(config={"task": "hellaswag"})
>>> get_task_name_from_object(task)
'hellaswag'
"""
if hasattr(task_object, "config"):
return task_object._config["task"]
@@ -1169,8 +1165,7 @@ def get_task_name_from_object(task_object: Union["ConfigurableTask", "Task"]) -> str:
def _check_duplicates(task_dict: dict[str, list[str]]) -> None:
"""
Validate that no tasks appear in multiple groups simultaneously.
"""Validate that no tasks appear in multiple groups simultaneously.
Helper function used to prevent conflicts when multiple groups claim
the same constituent task. This could lead to ambiguous configuration
@@ -1188,9 +1183,10 @@ def _check_duplicates(task_dict: dict[str, list[str]]) -> None:
... "group2": ["task_b", "task_c"] # task_b appears twice!
... }
>>> _check_duplicates(task_dict) # Raises ValueError
"""
subtask_names = []
for key, value in task_dict.items():
for value in task_dict.values():
subtask_names.extend(value)
duplicate_tasks = {
@@ -1200,22 +1196,22 @@ def _check_duplicates(task_dict: dict[str, list[str]]) -> None:
# locate the potentially problematic groups that seem to 'compete' for constituent subtasks
competing_groups = [
group
for group in task_dict.keys()
for group in task_dict
if len(set(task_dict[group]).intersection(duplicate_tasks)) > 0
]
if len(duplicate_tasks) > 0:
msg = f"Found 1 or more tasks while trying to call get_task_dict() that were members of more than 1 called group: {list(duplicate_tasks)}. Offending groups: {competing_groups}. Please call groups which overlap their constituent tasks in separate evaluation runs."
raise ValueError(
f"Found 1 or more tasks while trying to call get_task_dict() that were members of more than 1 called group: {list(duplicate_tasks)}. Offending groups: {competing_groups}. Please call groups which overlap their constituent tasks in separate evaluation runs."
msg,
)
def get_task_dict(
task_name_list: Union[str, list[Union[str, dict, "Task"]]],
task_manager: Optional[TaskManager] = None,
) -> dict[str, Union["ConfigurableTask", "Task"]]:
"""
Create a dictionary of task objects from mixed input types.
task_name_list: str | list[str | dict | Task],
task_manager: TaskManager | None = None,
) -> dict[str, ConfigurableTask | Task]:
"""Create a dictionary of task objects from mixed input types.
This is the main public API for loading tasks. It accepts various input
formats (names, configs, objects) and returns a unified dictionary of
@@ -1261,6 +1257,7 @@ def get_task_dict(
tm = TaskManager(include_path="/custom/tasks")
tasks = get_task_dict(["custom_task"], task_manager=tm)
"""
from lm_eval.api.task import Task
@@ -1268,14 +1265,16 @@ def get_task_dict(
if isinstance(task_name_list, str):
task_name_list = [task_name_list]
elif not isinstance(task_name_list, list):
msg = f"Expected a 'str' or 'list' but received {type(task_name_list)}."
raise TypeError(
f"Expected a 'str' or 'list' but received {type(task_name_list)}."
msg,
)
# Validate list items
if not all(isinstance(task, (str, dict, Task)) for task in task_name_list):
msg = "Expected all list items to be of types 'str', 'dict', or 'Task', but at least one entry did not match."
raise TypeError(
"Expected all list items to be of types 'str', 'dict', or 'Task', but at least one entry did not match."
msg,
)
# Ensure we have a task manager
@@ -1289,7 +1288,8 @@ def get_task_dict(
# Pre-instantiated task object
task_name = get_task_name_from_object(task_spec)
if task_name in final_task_dict:
raise ValueError(f"Duplicate task name: {task_name}")
msg = f"Duplicate task name: {task_name}"
raise ValueError(msg)
final_task_dict[task_name] = task_spec
else:
# String or dict - use load_task_or_group
@@ -1297,7 +1297,8 @@ def get_task_dict(
# Check for duplicate names
for name in result:
if name in final_task_dict:
raise ValueError(f"Duplicate task name: {name}")
msg = f"Duplicate task name: {name}"
raise ValueError(msg)
final_task_dict.update(result)
# Check for conflicting group memberships
......