Commit 15e930af authored by Baber's avatar Baber
Browse files

refactor: improve type hints and simplify YAML loading functions

parent 4a0a8bd8
# ruff: noqa E402
from __future__ import annotations
""" """
Task Management Module for LM Evaluation Harness. Task Management Module for LM Evaluation Harness.
...@@ -26,7 +30,6 @@ Example: ...@@ -26,7 +30,6 @@ Example:
include_defaults=True include_defaults=True
) )
""" """
import collections import collections
import functools import functools
import importlib.util import importlib.util
...@@ -34,16 +37,11 @@ import inspect ...@@ -34,16 +37,11 @@ import inspect
import logging import logging
import sys import sys
from functools import partial from functools import partial
from glob import iglob
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Callable, Callable,
Generator,
Mapping,
Optional,
Union,
) )
import yaml import yaml
...@@ -55,6 +53,8 @@ from lm_eval.utils import pattern_match, setup_logging ...@@ -55,6 +53,8 @@ from lm_eval.utils import pattern_match, setup_logging
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Generator, Mapping
from lm_eval.api.task import ConfigurableTask, Task from lm_eval.api.task import ConfigurableTask, Task
eval_logger = logging.getLogger(__name__) eval_logger = logging.getLogger(__name__)
...@@ -72,212 +72,200 @@ _IGNORE_DIRS = ( ...@@ -72,212 +72,200 @@ _IGNORE_DIRS = (
) )
def ignore_constructor(loader: yaml.Loader, node: yaml.Node) -> None: def _mk_function_ctor(base_dir: Path, resolve: bool):
"""YAML constructor that ignores !function tags during simple parsing.""" """Return a constructor that resolves !function relative to *base_dir*."""
return None
def ctor(loader: yaml.Loader, node: yaml.Node):
spec = loader.construct_scalar(node)
if not resolve: # “simple” mode → stub
return lambda *a, **kw: None
return _import_function(spec, base_dir)
@functools.lru_cache(maxsize=2048) # ← reuse per (directory, simple) pair return ctor
def _make_loader(yaml_dir: Path, simple: bool = False) -> type[yaml.Loader]:
"""
Return a custom YAML Loader class bound to *yaml_dir*. @functools.lru_cache(maxsize=1024)
def make_yaml_loader(base_dir: Path, *, simple: bool) -> type[yaml.Loader]:
yaml_dir """Factory that returns a *cached* PyYAML Loader subclass bound to *base_dir*.
Directory that holds the YAML file being parsed. simple=True → !function returns a stub (used when only metadata is needed).
We capture it so that !function look-ups can resolve relative
Python files like my_utils.some_fn ➜ yaml_dir / "my_utils.py".
simple
If True we ignore !function completely (used by `mode="simple"`),
used on TaskManager init to index.
""" """
class Loader(_Base): class Loader(_Base):
"""Dynamically-generated loader that knows its base directory.""" pass # dynamic subclass just to carry custom constructors
# Register (or stub) the !function constructor **for this Loader only**
if simple:
yaml.add_constructor("!function", ignore_constructor, Loader=Loader)
else:
yaml.add_constructor(
"!function",
# capture yaml_dir once so the lambda is fast and pickle-able
lambda ld, node, _dir=yaml_dir: _import_function(
ld.construct_scalar(node),
base_path=_dir,
),
Loader=Loader,
)
yaml.add_constructor(
"!function",
_mk_function_ctor(base_dir, resolve=not simple),
Loader=Loader,
)
return Loader return Loader
@functools.lru_cache(maxsize=None) # ← cache module objects @functools.lru_cache(maxsize=4096)
def _import_function(qualname: str, *, base_path: Path) -> Callable: def _read_yaml(path: Path, *, resolve_functions: bool) -> dict:
""" loader_cls = make_yaml_loader(path.parent, simple=not resolve_functions)
Dynamically import a function from a Python module relative to base_path. with path.open("rb") as fh:
return yaml.load(fh, Loader=loader_cls)
This function enables YAML files to reference Python functions using
the !function tag. Supports dot notation for nested modules.
Args:
qualname: Qualified function name like "my_module.my_function"
base_path: Base directory for resolving relative module paths
Returns:
The imported callable function
Raises:
ValueError: If qualname doesn't contain a module part
Example: @functools.cache
>>> func = _import_function("utils.custom_metric", base_path=Path("/tasks")) def _import_function(qual: str, base_dir: Path):
>>> result = func(predictions, references) """Import `qual` where qual looks like "my_utils.some_fn".
Search order:
1. <base_dir>/my_utils.py (relative file)
2. python importlib (package/module already importable)
Uses file *mtime* so edits are reloaded without killing the process.
""" """
mod_path, _, func_name = qualname.rpartition(".") import importlib
if not mod_path:
raise ValueError(f"{qualname!r} has no module part")
file_path = base_path / f"{mod_path.replace('.', '/')}.py"
module_name = f"_yaml_dynamic.{hash(file_path)}_{file_path.stem}"
if module_name in sys.modules:
mod = sys.modules[module_name]
else:
spec = importlib.util.spec_from_file_location(module_name, file_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
sys.modules[module_name] = mod
return getattr(mod, func_name)
@functools.lru_cache(maxsize=4096) if "." not in qual:
def _parse_yaml_file(path: Path, mode: str) -> dict: msg = f"!function value '{qual}' must contain a '.'"
""" raise ValueError(msg)
Parse a single YAML file with the appropriate loader.
Args: mod_part, _, fn_name = qual.rpartition(".")
path: Path to the YAML file relative_path = (base_dir / f"{mod_part.replace('.', '/')}.py").resolve()
mode: Parsing mode ("full" or "simple")
Returns:
Parsed YAML configuration as dictionary
"""
loader_cls = _make_loader(path.parent, simple=(mode == "simple"))
with path.open("rb") as fh:
return yaml.load(fh, Loader=loader_cls)
if relative_path.exists():
mtime = relative_path.stat().st_mtime_ns # for cache busting
module_key = f"{relative_path}:{mtime}"
if module_key in sys.modules:
mod = sys.modules[module_key]
else:
spec = importlib.util.spec_from_file_location(module_key, relative_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod) # type: ignore[arg-type]
sys.modules[module_key] = mod
return getattr(mod, fn_name)
@functools.lru_cache(maxsize=4096) # Fallback to regular import mechanism
def _get_cached_config(yaml_path: Path, mode: str) -> dict: import importlib
"""Load and cache resolved YAML configs"""
# Parse the YAML file
yaml_config = _parse_yaml_file(yaml_path, mode)
yaml_dir = yaml_path.parent
# Handle includes
include = yaml_config.pop("include", None)
if not include:
return yaml_config
include_paths = include if isinstance(include, list) else [include]
final_cfg: dict = {}
for inc in reversed(include_paths):
if inc is None:
continue
inc_path = Path(inc)
if not inc_path.is_absolute():
inc_path = (yaml_dir / inc_path).resolve()
# Recursive call will use the cache
included = _get_cached_config(inc_path, mode)
final_cfg.update(included)
final_cfg.update(yaml_config) # local keys win module = importlib.import_module(mod_part)
return final_cfg return getattr(module, fn_name)
def load_yaml_config( def load_yaml_config(
yaml_path: Union[Path, str, None] = None, yaml_path: Path | str,
yaml_config: Optional[dict] = None,
yaml_dir: Optional[Path] = None,
mode: str = "full",
*, *,
_seen: Optional[set[tuple[Path, str]]] = None, resolve_functions: bool = True,
resolve_includes: bool = True, resolve_includes: bool = True,
_seen: set[tuple[Path, bool]] | None = None,
) -> dict: ) -> dict:
""" """Read YAML once, optionally walk `include:` chains, with cycle detection."""
Parse a YAML config with optional include handling. path = Path(yaml_path).expanduser().resolve()
Parameters
----------
yaml_path
Path to the main YAML file. Needed unless *yaml_config* is
supplied directly (e.g. by tests).
yaml_config
Pre-parsed dict to use instead of reading *yaml_path*.
yaml_dir
Base directory for resolving relative include paths. Defaults
to `yaml_path.parent`.
mode
"full" – honour !function tags
"simple" – ignore !function (faster).
_seen
**Internal** recursion set: tuples of (absolute-path, mode).
Prevents include cycles such as A → B → A.
"""
if yaml_config is None and yaml_path is None:
raise ValueError("load_yaml_config needs either yaml_path or yaml_config")
# ------------------------------------------------------------------ cycle guard
if _seen is None: if _seen is None:
_seen = set() _seen = set()
if yaml_path is not None: key = (path, resolve_functions)
yaml_path = Path(yaml_path).expanduser().resolve() if key in _seen:
msg = f"Include cycle at {path}"
# ---------- fast-path: use LRU cached function ---------- raise ValueError(msg)
if yaml_config is None and resolve_includes: _seen.add(key)
return _get_cached_config(yaml_path, mode)
cfg = _read_yaml(path, resolve_functions=resolve_functions)
key = (yaml_path.resolve(), mode)
if key in _seen: if not resolve_includes or "include" not in cfg:
raise ValueError(f"Include cycle detected at {yaml_path}") return cfg
_seen.add(key)
base_dir = path.parent
# ------------------------------------------------------------------ load / parse merged: dict = {}
if yaml_config is None: # ordinary path-based load for inc in cfg.pop("include"):
yaml_config = _parse_yaml_file(yaml_path, mode) inc_path = (
(base_dir / inc).resolve() if not Path(inc).is_absolute() else Path(inc)
if yaml_dir is None and yaml_path is not None:
yaml_dir = yaml_path.parent
assert yaml_dir is not None, "yaml_dir must be set by caller or deduced from path"
# ------------------------------------------------------------------ handle include
include = yaml_config.pop("include", None)
if not include and not resolve_includes:
return yaml_config
include_paths = include if isinstance(include, list) else [include]
final_cfg: dict = {}
for inc in reversed(include_paths):
if inc is None: # guard against explicit nulls
continue
inc_path = Path(inc)
if not inc_path.is_absolute():
inc_path = (yaml_dir / inc_path).resolve()
included = load_yaml_config(
yaml_path=inc_path,
mode=mode,
yaml_dir=inc_path.parent,
_seen=_seen, # <-- pass set downward
) )
final_cfg.update(included) merged.update(
load_yaml_config(
final_cfg.update(yaml_config) # local keys win inc_path,
return final_cfg resolve_functions=resolve_functions,
_seen=_seen,
),
)
merged.update(cfg) # local keys win
return merged
# def load_yaml_config(
# yaml_path: Union[Path, str, None] = None,
# yaml_config: Optional[dict] = None,
# yaml_dir: Optional[Path] = None,
# mode: str = "full",
# *,
# _seen: Optional[set[tuple[Path, str]]] = None,
# resolve_includes: bool = True,
# ) -> dict:
# """
# Parse a YAML config with optional include handling.
#
# Parameters
# ----------
# yaml_path
# Path to the main YAML file. Needed unless *yaml_config* is
# supplied directly (e.g. by tests).
# yaml_config
# Pre-parsed dict to use instead of reading *yaml_path*.
# yaml_dir
# Base directory for resolving relative include paths. Defaults
# to `yaml_path.parent`.
# mode
# "full" - honour !function tags
# "simple" - ignore !function (faster).
# _seen
# **Internal** recursion set: tuples of (absolute-path, mode).
# Prevents include cycles such as A → B → A.
# """
# if yaml_config is None and yaml_path is None:
# raise ValueError("load_yaml_config needs either yaml_path or yaml_config")
#
# # ------------------------------------------------------------------ cycle guard
# if _seen is None:
# _seen = set()
# if yaml_path is not None:
# yaml_path = Path(yaml_path).expanduser().resolve()
#
# # ---------- fast-path: use LRU cached function ----------
# if yaml_config is None and resolve_includes:
# return _get_cached_config(yaml_path, mode)
#
# key = (yaml_path.resolve(), mode)
# if key in _seen:
# raise ValueError(f"Include cycle detected at {yaml_path}")
# _seen.add(key)
#
# # ------------------------------------------------------------------ load / parse
# if yaml_config is None: # ordinary path-based load
# yaml_config = _parse_yaml_file(yaml_path, mode)
#
# if yaml_dir is None and yaml_path is not None:
# yaml_dir = yaml_path.parent
# assert yaml_dir is not None, "yaml_dir must be set by caller or deduced from path"
#
# # ------------------------------------------------------------------ handle include
# include = yaml_config.pop("include", None)
# if not include and not resolve_includes:
# return yaml_config
#
# include_paths = include if isinstance(include, list) else [include]
# final_cfg: dict = {}
#
# for inc in reversed(include_paths):
# if inc is None: # guard against explicit nulls
# continue
# inc_path = Path(inc)
# if not inc_path.is_absolute():
# inc_path = (yaml_dir / inc_path).resolve()
# included = load_yaml_config(
# yaml_path=inc_path,
# mode=mode,
# yaml_dir=inc_path.parent,
# _seen=_seen, # <-- pass set downward
# )
# final_cfg.update(included)
#
# final_cfg.update(yaml_config) # local keys win
# return final_cfg
def iter_yaml_files(root: Path, ignore=_IGNORE_DIRS) -> Generator[Path, Any, None]: def iter_yaml_files(root: Path, ignore=_IGNORE_DIRS) -> Generator[Path, Any, None]:
""" """Recursively iterate over all YAML files in a directory tree.
Recursively iterate over all YAML files in a directory tree.
Excludes files in ignored directories like __pycache__ and .ipynb_checkpoints. Excludes files in ignored directories like __pycache__ and .ipynb_checkpoints.
...@@ -290,8 +278,10 @@ def iter_yaml_files(root: Path, ignore=_IGNORE_DIRS) -> Generator[Path, Any, Non ...@@ -290,8 +278,10 @@ def iter_yaml_files(root: Path, ignore=_IGNORE_DIRS) -> Generator[Path, Any, Non
Example: Example:
>>> for yaml_file in iter_yaml_files(Path("tasks")): >>> for yaml_file in iter_yaml_files(Path("tasks")):
... print(f"Found task config: {yaml_file}") ... print(f"Found task config: {yaml_file}")
""" """
for p in iglob(str(root / "**/*.yaml"), recursive=True): # for p in iglob(str(root / "**/*.yaml"), recursive=True):
for p in root.glob("**/*.yaml"):
# ignore check # ignore check
path = Path(p) path = Path(p)
# Check if any parent directory is in the ignore list # Check if any parent directory is in the ignore list
...@@ -301,8 +291,7 @@ def iter_yaml_files(root: Path, ignore=_IGNORE_DIRS) -> Generator[Path, Any, Non ...@@ -301,8 +291,7 @@ def iter_yaml_files(root: Path, ignore=_IGNORE_DIRS) -> Generator[Path, Any, Non
class TaskManager: class TaskManager:
""" """Central manager for task discovery, indexing, and loading.
Central manager for task discovery, indexing, and loading.
TaskManager scans directories for YAML task configurations and maintains TaskManager scans directories for YAML task configurations and maintains
an index of all available tasks, groups, and tags. It provides methods an index of all available tasks, groups, and tags. It provides methods
...@@ -334,17 +323,17 @@ class TaskManager: ...@@ -334,17 +323,17 @@ class TaskManager:
verbosity="INFO" verbosity="INFO"
) )
custom_tasks = [t for t in tm.all_tasks if "custom" in t] custom_tasks = [t for t in tm.all_tasks if "custom" in t]
""" """
def __init__( def __init__(
self, self,
verbosity: Optional[str] = None, verbosity: str | None = None,
include_path: Optional[Union[str, Path, list[Union[str, Path]]]] = None, include_path: str | Path | list[str | Path] | None = None,
include_defaults: bool = True, include_defaults: bool = True,
metadata: Optional[dict[str, dict[str, Any]]] = None, metadata: dict[str, dict[str, Any]] | None = None,
) -> None: ) -> None:
""" """Initialize the TaskManager.
Initialize the TaskManager.
Args: Args:
verbosity: Logging verbosity level (DEBUG, INFO, WARNING, ERROR) verbosity: Logging verbosity level (DEBUG, INFO, WARNING, ERROR)
...@@ -352,35 +341,37 @@ class TaskManager: ...@@ -352,35 +341,37 @@ class TaskManager:
path or list of paths. path or list of paths.
include_defaults: Whether to include default tasks from lm_eval/tasks/ include_defaults: Whether to include default tasks from lm_eval/tasks/
metadata: Global metadata dictionary to inject into all task configs metadata: Global metadata dictionary to inject into all task configs
""" """
if verbosity is not None: if verbosity is not None:
setup_logging(verbosity) setup_logging(verbosity)
self.include_path = include_path self.include_path = include_path
self.metadata = metadata self.metadata = metadata
self._task_index = self.initialize_tasks( self._task_index = self.initialize_tasks(
include_path=include_path, include_defaults=include_defaults include_path=include_path,
include_defaults=include_defaults,
) )
self._all_tasks = sorted(list(self._task_index.keys())) self._all_tasks = sorted(self._task_index.keys())
self._all_groups = sorted( self._all_groups = sorted(
[x for x in self._all_tasks if self._task_index[x]["type"] == "group"] [x for x in self._all_tasks if self._task_index[x]["type"] == "group"],
) )
self._all_subtasks = sorted( self._all_subtasks = sorted(
[ [
x x
for x in self._all_tasks for x in self._all_tasks
if self._task_index[x]["type"] in ["task", "python_task"] if self._task_index[x]["type"] in ["task", "python_task"]
] ],
) )
self._all_tags = sorted( self._all_tags = sorted(
[x for x in self._all_tasks if self._task_index[x]["type"] == "tag"] [x for x in self._all_tasks if self._task_index[x]["type"] == "tag"],
) )
self.task_group_map = collections.defaultdict(list) self.task_group_map = collections.defaultdict(list)
def initialize_tasks( def initialize_tasks(
self, self,
include_path: Optional[Union[str, Path, list[Union[str, Path]]]] = None, include_path: str | Path | list[str | Path] | None = None,
include_defaults: bool = True, include_defaults: bool = True,
) -> dict[str, dict]: ) -> dict[str, dict]:
"""Creates a dictionary of tasks indexes. """Creates a dictionary of tasks indexes.
...@@ -390,13 +381,12 @@ class TaskManager: ...@@ -390,13 +381,12 @@ class TaskManager:
Can provide more than one such path as a list. Can provide more than one such path as a list.
:param include_defaults: bool = True :param include_defaults: bool = True
If set to false, default tasks (those in lm_eval/tasks/) are not indexed. If set to false, default tasks (those in lm_eval/tasks/) are not indexed.
return
Return:
dictionary of task names as key and task metadata dictionary of task names as key and task metadata
""" """
if include_defaults: all_paths = [Path(__file__).parent] if include_defaults else []
all_paths = [Path(__file__).parent]
else:
all_paths = []
if include_path is not None: if include_path is not None:
if isinstance(include_path, (str, Path)): if isinstance(include_path, (str, Path)):
include_path = [include_path] include_path = [include_path]
...@@ -431,7 +421,7 @@ class TaskManager: ...@@ -431,7 +421,7 @@ class TaskManager:
return self._all_tags return self._all_tags
@property @property
def task_index(self) -> dict[str, dict[str, Union[str, int, list[str]]]]: def task_index(self) -> dict[str, dict[str, str | int | list[str]]]:
"""Get the complete task index with metadata for all tasks.""" """Get the complete task index with metadata for all tasks."""
return self._task_index return self._task_index
...@@ -441,8 +431,7 @@ class TaskManager: ...@@ -441,8 +431,7 @@ class TaskManager:
list_tags: bool = True, list_tags: bool = True,
list_subtasks: bool = True, list_subtasks: bool = True,
) -> str: ) -> str:
""" """Return a Markdown table (as a string) listing groups, tags and/or subtasks
Return a Markdown table (as a string) listing groups, tags and/or subtasks
known to this TaskManager. Safe for configs whose yaml_path is -1 and for known to this TaskManager. Safe for configs whose yaml_path is -1 and for
task configs whose `include:` is a list. task configs whose `include:` is a list.
""" """
...@@ -458,7 +447,8 @@ class TaskManager: ...@@ -458,7 +447,8 @@ class TaskManager:
def first_output_type_from_includes(cfg: dict, base: Path) -> str: def first_output_type_from_includes(cfg: dict, base: Path) -> str:
"""Walk cfg['include'] (string or list) and return the first """Walk cfg['include'] (string or list) and return the first
include that itself specifies an output_type.""" include that itself specifies an output_type.
"""
inc_raw = cfg.get("include") inc_raw = cfg.get("include")
if not inc_raw: if not inc_raw:
return "" return ""
...@@ -587,9 +577,8 @@ class TaskManager: ...@@ -587,9 +577,8 @@ class TaskManager:
"""Check if a config dictionary defines a task list.""" """Check if a config dictionary defines a task list."""
return "task_list" in config and isinstance(config["task_list"], list) return "task_list" in config and isinstance(config["task_list"], list)
def _get_yaml_path(self, name: str) -> Union[str, int, list[str]]: def _get_yaml_path(self, name: str) -> str | int | list[str]:
""" """Get the YAML file path for a registered task.
Get the YAML file path for a registered task.
Args: Args:
name: Task name name: Task name
...@@ -599,14 +588,14 @@ class TaskManager: ...@@ -599,14 +588,14 @@ class TaskManager:
Raises: Raises:
ValueError: If task name is not registered ValueError: If task name is not registered
""" """
if name not in self.task_index: if name not in self.task_index:
raise ValueError raise ValueError
return self.task_index[name]["yaml_path"] return self.task_index[name]["yaml_path"]
def _get_config(self, name: str) -> dict: def _get_config(self, name: str) -> dict:
""" """Load the full configuration for a registered task.
Load the full configuration for a registered task.
Args: Args:
name: Task name name: Task name
...@@ -616,18 +605,17 @@ class TaskManager: ...@@ -616,18 +605,17 @@ class TaskManager:
Raises: Raises:
ValueError: If task name is not registered ValueError: If task name is not registered
""" """
if name not in self.task_index: if name not in self.task_index:
raise ValueError raise ValueError
yaml_path = self._get_yaml_path(name) yaml_path = self._get_yaml_path(name)
if yaml_path == -1: if yaml_path == -1:
return {} return {}
else: return load_yaml_config(Path(yaml_path))
return load_yaml_config(Path(yaml_path), mode="full")
def _get_tasklist(self, name: str) -> Union[list[str], int]: def _get_tasklist(self, name: str) -> list[str] | int:
""" """Get the task list for a group or tag.
Get the task list for a group or tag.
Args: Args:
name: Group or tag name name: Group or tag name
...@@ -637,6 +625,7 @@ class TaskManager: ...@@ -637,6 +625,7 @@ class TaskManager:
Raises: Raises:
ValueError: If name refers to an individual task ValueError: If name refers to an individual task
""" """
if self._name_is_task(name): if self._name_is_task(name):
raise ValueError raise ValueError
...@@ -648,10 +637,10 @@ class TaskManager: ...@@ -648,10 +637,10 @@ class TaskManager:
task_type: str, task_type: str,
yaml_path: str, yaml_path: str,
tasks_and_groups: dict[str, dict], tasks_and_groups: dict[str, dict],
config: Optional[dict] = None, config: dict | None = None,
populate_tags_fn: Optional[Callable] = None, populate_tags_fn: Callable | None = None,
) -> None: ) -> None:
"""Helper method to register a task in the tasks_and_groups dict""" """Helper method to register a task in the tasks_and_groups dict."""
tasks_and_groups[task_name] = { tasks_and_groups[task_name] = {
"type": task_type, "type": task_type,
"yaml_path": yaml_path, "yaml_path": yaml_path,
...@@ -661,9 +650,12 @@ class TaskManager: ...@@ -661,9 +650,12 @@ class TaskManager:
populate_tags_fn(config, task_name, tasks_and_groups) populate_tags_fn(config, task_name, tasks_and_groups)
def _merge_task_configs( def _merge_task_configs(
self, base_config: dict, task_specific_config: dict, task_name: str self,
base_config: dict,
task_specific_config: dict,
task_name: str,
) -> dict: ) -> dict:
"""Merge base config with task-specific overrides for task_list configs""" """Merge base config with task-specific overrides for task_list configs."""
if task_specific_config: if task_specific_config:
task_specific_config = task_specific_config.copy() task_specific_config = task_specific_config.copy()
task_specific_config.pop("task", None) task_specific_config.pop("task", None)
...@@ -671,9 +663,11 @@ class TaskManager: ...@@ -671,9 +663,11 @@ class TaskManager:
return {**base_config, "task": task_name} return {**base_config, "task": task_name}
def _process_tag_subtasks( def _process_tag_subtasks(
self, tag_name: str, update_config: Optional[dict] = None self,
tag_name: str,
update_config: dict | None = None,
) -> dict: ) -> dict:
"""Process subtasks for a tag and return loaded tasks""" """Process subtasks for a tag and return loaded tasks."""
subtask_list = self._get_tasklist(tag_name) subtask_list = self._get_tasklist(tag_name)
fn = partial( fn = partial(
self._load_individual_task_or_group, self._load_individual_task_or_group,
...@@ -681,9 +675,8 @@ class TaskManager: ...@@ -681,9 +675,8 @@ class TaskManager:
) )
return dict(collections.ChainMap(*map(fn, reversed(subtask_list)))) return dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
def _process_alias(self, config: dict, group: Optional[str] = None) -> dict: def _process_alias(self, config: dict, group: str | None = None) -> dict:
""" """Process group alias configuration.
Process group alias configuration.
If the group is not the same as the original group which the group alias If the group is not the same as the original group which the group alias
was intended for, set the group_alias to None instead. was intended for, set the group_alias to None instead.
...@@ -694,21 +687,26 @@ class TaskManager: ...@@ -694,21 +687,26 @@ class TaskManager:
Returns: Returns:
Modified configuration with processed aliases Modified configuration with processed aliases
""" """
if ("group_alias" in config) and ("group" in config) and group is not None: if (
if config["group"] != group: ("group_alias" in config)
config["group_alias"] = None and ("group" in config)
and group is not None
and config["group"] != group
):
config["group_alias"] = None
return config return config
def _class_has_config_in_constructor(self, cls) -> bool: def _class_has_config_in_constructor(self, cls) -> bool:
""" """Check if a class constructor accepts a 'config' parameter.
Check if a class constructor accepts a 'config' parameter.
Args: Args:
cls: Class to inspect cls: Class to inspect
Returns: Returns:
True if constructor has 'config' parameter, False otherwise True if constructor has 'config' parameter, False otherwise
""" """
constructor = getattr(cls, "__init__", None) constructor = getattr(cls, "__init__", None)
return ( return (
...@@ -725,10 +723,9 @@ class TaskManager: ...@@ -725,10 +723,9 @@ class TaskManager:
self, self,
cfg: dict, cfg: dict,
task_name: str, task_name: str,
yaml_path: Union[str, None], yaml_path: str | None,
) -> dict: ) -> dict:
""" """Instantiate a single task (ConfigurableTask **or** python-task) from *cfg*.
Instantiate a single task (ConfigurableTask **or** python-task) from *cfg*.
Returns {task_name: task_object}. Returns {task_name: task_object}.
""" """
from lm_eval.api.task import ConfigurableTask, Task # local import avoids cycle from lm_eval.api.task import ConfigurableTask, Task # local import avoids cycle
...@@ -772,17 +769,16 @@ class TaskManager: ...@@ -772,17 +769,16 @@ class TaskManager:
def _create_group_object( def _create_group_object(
self, self,
cfg: dict, cfg: dict,
parent_name: Union[str, None] = None, parent_name: str | None = None,
) -> tuple[GroupConfig, list[Union[str, dict]]]: ) -> tuple[GroupConfig, list[str | dict]]:
""" """Build GroupConfig and return (group_obj, subtask_names).
Build GroupConfig and return (group_obj, subtask_names).
Resolves tag expansion. Resolves tag expansion.
""" """
if self.metadata is not None: if self.metadata is not None:
cfg["metadata"] = cfg.get("metadata", {}) | self.metadata cfg["metadata"] = cfg.get("metadata", {}) | self.metadata
grp = GroupConfig(**cfg) grp = GroupConfig(**cfg)
subtasks: list[Union[str, dict]] = [] subtasks: list[str | dict] = []
if grp.task: if grp.task:
for t in grp.task: for t in grp.task:
if isinstance(t, str) and self._name_is_tag(t): if isinstance(t, str) and self._name_is_tag(t):
...@@ -793,9 +789,9 @@ class TaskManager: ...@@ -793,9 +789,9 @@ class TaskManager:
def _load_subtasks( def _load_subtasks(
self, self,
subtasks: list[Union[str, dict]], subtasks: list[str | dict],
parent_name: Union[str, GroupConfig, None], parent_name: str | GroupConfig | None,
update_config: Union[dict, None], update_config: dict | None,
) -> Mapping: ) -> Mapping:
"""Return merged mapping of all subtasks, handling duplicates.""" """Return merged mapping of all subtasks, handling duplicates."""
fn = functools.partial( fn = functools.partial(
...@@ -807,16 +803,14 @@ class TaskManager: ...@@ -807,16 +803,14 @@ class TaskManager:
def _load_individual_task_or_group( def _load_individual_task_or_group(
self, self,
payload: Union[str, dict], payload: str | dict,
*, *,
parent_name: Union[str, None] = None, parent_name: str | None = None,
update_config: Union[dict, None] = None, update_config: dict | None = None,
) -> Mapping: ) -> Mapping:
""" """Public helper that turns *payload* (str task/group/tag **or** dict config)
Public helper that turns *payload* (str task/group/tag **or** dict config)
into a nested Mapping of {name_or_group_obj: task_obj | sub_mapping}. into a nested Mapping of {name_or_group_obj: task_obj | sub_mapping}.
""" """
# ------------------------------------------------------------------ STRING # ------------------------------------------------------------------ STRING
if isinstance(payload, str): if isinstance(payload, str):
# If caller supplied extra overrides, treat as dict immediately # If caller supplied extra overrides, treat as dict immediately
...@@ -852,14 +846,15 @@ class TaskManager: ...@@ -852,14 +846,15 @@ class TaskManager:
grp_only = {k: v for k, v in group_cfg.items() if k in GROUP_ONLY_KEYS} grp_only = {k: v for k, v in group_cfg.items() if k in GROUP_ONLY_KEYS}
grp_obj, subtasks = self._create_group_object(grp_only, parent_name) grp_obj, subtasks = self._create_group_object(grp_only, parent_name)
return { return {
grp_obj: self._load_subtasks(subtasks, grp_obj, update_config=None) grp_obj: self._load_subtasks(subtasks, grp_obj, update_config=None),
} }
# ------------ registered TAG ------------------------------------ # ------------ registered TAG ------------------------------------
if self._name_is_tag(payload): if self._name_is_tag(payload):
return self._process_tag_subtasks(payload, update_config=None) return self._process_tag_subtasks(payload, update_config=None)
raise ValueError(f"Unknown task / group / tag name: {payload!r}") msg = f"Unknown task / group / tag name: {payload!r}"
raise ValueError(msg)
# ------------------------------------------------------------------- DICT # ------------------------------------------------------------------- DICT
if isinstance(payload, dict): if isinstance(payload, dict):
...@@ -882,7 +877,7 @@ class TaskManager: ...@@ -882,7 +877,7 @@ class TaskManager:
n n
for n in self.task_group_map[parent_name] for n in self.task_group_map[parent_name]
if n.startswith(name) if n.startswith(name)
] ],
) )
if count: if count:
name = f"{name}-{count}" name = f"{name}-{count}"
...@@ -904,15 +899,16 @@ class TaskManager: ...@@ -904,15 +899,16 @@ class TaskManager:
name = payload["task"] name = payload["task"]
return self._create_task_object(payload, name, yaml_path=None) return self._create_task_object(payload, name, yaml_path=None)
msg = f"_load_individual_task_or_group expected str | dict, got {type(payload)}"
raise TypeError( raise TypeError(
f"_load_individual_task_or_group expected str | dict, got {type(payload)}" msg,
) )
def load_task_or_group( def load_task_or_group(
self, task_list: Optional[Union[str, list[str]]] = None self,
task_list: str | list[str] | None = None,
) -> dict: ) -> dict:
""" """Load multiple tasks or groups from a list of names.
Load multiple tasks or groups from a list of names.
This is the main entry point for loading tasks. It handles lists This is the main entry point for loading tasks. It handles lists
of task names and delegates to _load_individual_task_or_group for of task names and delegates to _load_individual_task_or_group for
...@@ -936,23 +932,19 @@ class TaskManager: ...@@ -936,23 +932,19 @@ class TaskManager:
tasks = tm.load_task_or_group("arc_group") tasks = tm.load_task_or_group("arc_group")
# Returns: {"arc_group": {"arc_easy": Task1, "arc_challenge": Task2}} # Returns: {"arc_group": {"arc_easy": Task1, "arc_challenge": Task2}}
""" """
if isinstance(task_list, str): if isinstance(task_list, str):
task_list = [task_list] task_list = [task_list]
all_loaded_tasks = dict( return dict(
collections.ChainMap( collections.ChainMap(
*map( *(self._load_individual_task_or_group(task) for task in task_list),
lambda task: self._load_individual_task_or_group(task), ),
task_list,
)
)
) )
return all_loaded_tasks
def load_config(self, config: dict) -> Mapping: def load_config(self, config: dict) -> Mapping:
""" """Load a task from an inline configuration dictionary.
Load a task from an inline configuration dictionary.
Args: Args:
config: Configuration dictionary defining the task config: Configuration dictionary defining the task
...@@ -963,12 +955,12 @@ class TaskManager: ...@@ -963,12 +955,12 @@ class TaskManager:
Example: Example:
>>> config = {"task": "hellaswag", "num_fewshot": 5} >>> config = {"task": "hellaswag", "num_fewshot": 5}
>>> task_dict = tm.load_config(config) >>> task_dict = tm.load_config(config)
""" """
return self._load_individual_task_or_group(config) return self._load_individual_task_or_group(config)
def _get_task_and_group(self, task_dir: Union[str, Path]) -> dict[str, dict]: def _get_task_and_group(self, task_dir: str | Path) -> dict[str, dict]:
""" """Scan a directory for task configurations and build an index.
Scan a directory for task configurations and build an index.
Creates a dictionary of task metadata by recursively scanning for Creates a dictionary of task metadata by recursively scanning for
YAML files and parsing their configurations. This method handles: YAML files and parsing their configurations. This method handles:
...@@ -991,13 +983,15 @@ class TaskManager: ...@@ -991,13 +983,15 @@ class TaskManager:
Note: Note:
This method is called during TaskManager initialization to build This method is called during TaskManager initialization to build
the master task index. It uses 'simple' parsing mode for performance. the master task index. It uses 'simple' parsing mode for performance.
""" """
def _populate_tags_and_groups( def _populate_tags_and_groups(
config: dict, task: str, tasks_and_groups: dict[str, dict] config: dict,
task: str,
tasks_and_groups: dict[str, dict],
) -> None: ) -> None:
""" """Extract and register tags from a task configuration.
Extract and register tags from a task configuration.
Tags allow grouping tasks by theme or category. This function Tags allow grouping tasks by theme or category. This function
processes the 'tag' field in task configs and maintains tag processes the 'tag' field in task configs and maintains tag
...@@ -1007,6 +1001,7 @@ class TaskManager: ...@@ -1007,6 +1001,7 @@ class TaskManager:
config: Task configuration dictionary config: Task configuration dictionary
task: Name of the task being processed task: Name of the task being processed
tasks_and_groups: Master index to update with tag information tasks_and_groups: Master index to update with tag information
""" """
# TODO: remove group in next release # TODO: remove group in next release
if "tag" in config: if "tag" in config:
...@@ -1024,7 +1019,7 @@ class TaskManager: ...@@ -1024,7 +1019,7 @@ class TaskManager:
elif tasks_and_groups[tag]["type"] != "tag": elif tasks_and_groups[tag]["type"] != "tag":
eval_logger.info( eval_logger.info(
f"The tag '{tag}' is already registered as a group, this tag will not be registered. " f"The tag '{tag}' is already registered as a group, this tag will not be registered. "
"This may affect tasks you want to call." "This may affect tasks you want to call.",
) )
break break
else: else:
...@@ -1041,7 +1036,9 @@ class TaskManager: ...@@ -1041,7 +1036,9 @@ class TaskManager:
for yaml_path in iter_yaml_files(task_dir_path): for yaml_path in iter_yaml_files(task_dir_path):
try: try:
config = load_yaml_config( config = load_yaml_config(
yaml_path, mode="simple", resolve_includes=False yaml_path,
resolve_functions=False,
resolve_includes=False,
) )
except (FileNotFoundError, YAMLError, OSError) as err: except (FileNotFoundError, YAMLError, OSError) as err:
eval_logger.debug(f"File {yaml_path} could not be loaded ({err})") eval_logger.debug(f"File {yaml_path} could not be loaded ({err})")
...@@ -1109,8 +1106,7 @@ class TaskManager: ...@@ -1109,8 +1106,7 @@ class TaskManager:
def get_task_name_from_config(task_config: dict[str, str]) -> str:
    """Derive the canonical name for a task from its configuration.

    Preference order: an explicit ``task`` field first, then a
    ``dataset_path``/``dataset_name`` combination, then ``dataset_path``
    alone.

    Args:
        task_config: Task configuration dictionary.

    Returns:
        The canonical task name.

    Raises:
        KeyError: If neither ``task`` nor ``dataset_path`` is present.

    Example:
        >>> config = {"dataset_path": "custom", "dataset_name": "mytask"}
        >>> get_task_name_from_config(config)
        'custom_mytask'
    """
    if "task" in task_config:
        return task_config["task"]
    # Fall back to dataset-derived names; missing dataset_path raises KeyError.
    base = "{dataset_path}".format(**task_config)
    if "dataset_name" in task_config:
        return f"{base}_{task_config['dataset_name']}"
    return base
def get_task_name_from_object(task_object: Union["ConfigurableTask", "Task"]) -> str: def get_task_name_from_object(task_object: ConfigurableTask | Task) -> str:
""" """Extract the name from an instantiated task object.
Extract the name from an instantiated task object.
Handles both ConfigurableTask and legacy Task objects with different Handles both ConfigurableTask and legacy Task objects with different
attribute conventions for storing the task name. attribute conventions for storing the task name.
...@@ -1155,6 +1150,7 @@ def get_task_name_from_object(task_object: Union["ConfigurableTask", "Task"]) -> ...@@ -1155,6 +1150,7 @@ def get_task_name_from_object(task_object: Union["ConfigurableTask", "Task"]) ->
>>> task = ConfigurableTask(config={"task": "hellaswag"}) >>> task = ConfigurableTask(config={"task": "hellaswag"})
>>> get_task_name_from_object(task) >>> get_task_name_from_object(task)
'hellaswag' 'hellaswag'
""" """
if hasattr(task_object, "config"): if hasattr(task_object, "config"):
return task_object._config["task"] return task_object._config["task"]
...@@ -1169,8 +1165,7 @@ def get_task_name_from_object(task_object: Union["ConfigurableTask", "Task"]) -> ...@@ -1169,8 +1165,7 @@ def get_task_name_from_object(task_object: Union["ConfigurableTask", "Task"]) ->
def _check_duplicates(task_dict: dict[str, list[str]]) -> None: def _check_duplicates(task_dict: dict[str, list[str]]) -> None:
""" """Validate that no tasks appear in multiple groups simultaneously.
Validate that no tasks appear in multiple groups simultaneously.
Helper function used to prevent conflicts when multiple groups claim Helper function used to prevent conflicts when multiple groups claim
the same constituent task. This could lead to ambiguous configuration the same constituent task. This could lead to ambiguous configuration
...@@ -1188,9 +1183,10 @@ def _check_duplicates(task_dict: dict[str, list[str]]) -> None: ...@@ -1188,9 +1183,10 @@ def _check_duplicates(task_dict: dict[str, list[str]]) -> None:
... "group2": ["task_b", "task_c"] # task_b appears twice! ... "group2": ["task_b", "task_c"] # task_b appears twice!
... } ... }
>>> _check_duplicates(task_dict) # Raises ValueError >>> _check_duplicates(task_dict) # Raises ValueError
""" """
subtask_names = [] subtask_names = []
for key, value in task_dict.items(): for value in task_dict.values():
subtask_names.extend(value) subtask_names.extend(value)
duplicate_tasks = { duplicate_tasks = {
...@@ -1200,22 +1196,22 @@ def _check_duplicates(task_dict: dict[str, list[str]]) -> None: ...@@ -1200,22 +1196,22 @@ def _check_duplicates(task_dict: dict[str, list[str]]) -> None:
# locate the potentially problematic groups that seem to 'compete' for constituent subtasks # locate the potentially problematic groups that seem to 'compete' for constituent subtasks
competing_groups = [ competing_groups = [
group group
for group in task_dict.keys() for group in task_dict
if len(set(task_dict[group]).intersection(duplicate_tasks)) > 0 if len(set(task_dict[group]).intersection(duplicate_tasks)) > 0
] ]
if len(duplicate_tasks) > 0: if len(duplicate_tasks) > 0:
msg = f"Found 1 or more tasks while trying to call get_task_dict() that were members of more than 1 called group: {list(duplicate_tasks)}. Offending groups: {competing_groups}. Please call groups which overlap their constituent tasks in separate evaluation runs."
raise ValueError( raise ValueError(
f"Found 1 or more tasks while trying to call get_task_dict() that were members of more than 1 called group: {list(duplicate_tasks)}. Offending groups: {competing_groups}. Please call groups which overlap their constituent tasks in separate evaluation runs." msg,
) )
def get_task_dict( def get_task_dict(
task_name_list: Union[str, list[Union[str, dict, "Task"]]], task_name_list: str | list[str | dict | Task],
task_manager: Optional[TaskManager] = None, task_manager: TaskManager | None = None,
) -> dict[str, Union["ConfigurableTask", "Task"]]: ) -> dict[str, ConfigurableTask | Task]:
""" """Create a dictionary of task objects from mixed input types.
Create a dictionary of task objects from mixed input types.
This is the main public API for loading tasks. It accepts various input This is the main public API for loading tasks. It accepts various input
formats (names, configs, objects) and returns a unified dictionary of formats (names, configs, objects) and returns a unified dictionary of
...@@ -1261,6 +1257,7 @@ def get_task_dict( ...@@ -1261,6 +1257,7 @@ def get_task_dict(
tm = TaskManager(include_path="/custom/tasks") tm = TaskManager(include_path="/custom/tasks")
tasks = get_task_dict(["custom_task"], task_manager=tm) tasks = get_task_dict(["custom_task"], task_manager=tm)
""" """
from lm_eval.api.task import Task from lm_eval.api.task import Task
...@@ -1268,14 +1265,16 @@ def get_task_dict( ...@@ -1268,14 +1265,16 @@ def get_task_dict(
if isinstance(task_name_list, str): if isinstance(task_name_list, str):
task_name_list = [task_name_list] task_name_list = [task_name_list]
elif not isinstance(task_name_list, list): elif not isinstance(task_name_list, list):
msg = f"Expected a 'str' or 'list' but received {type(task_name_list)}."
raise TypeError( raise TypeError(
f"Expected a 'str' or 'list' but received {type(task_name_list)}." msg,
) )
# Validate list items # Validate list items
if not all(isinstance(task, (str, dict, Task)) for task in task_name_list): if not all(isinstance(task, (str, dict, Task)) for task in task_name_list):
msg = "Expected all list items to be of types 'str', 'dict', or 'Task', but at least one entry did not match."
raise TypeError( raise TypeError(
"Expected all list items to be of types 'str', 'dict', or 'Task', but at least one entry did not match." msg,
) )
# Ensure we have a task manager # Ensure we have a task manager
...@@ -1289,7 +1288,8 @@ def get_task_dict( ...@@ -1289,7 +1288,8 @@ def get_task_dict(
# Pre-instantiated task object # Pre-instantiated task object
task_name = get_task_name_from_object(task_spec) task_name = get_task_name_from_object(task_spec)
if task_name in final_task_dict: if task_name in final_task_dict:
raise ValueError(f"Duplicate task name: {task_name}") msg = f"Duplicate task name: {task_name}"
raise ValueError(msg)
final_task_dict[task_name] = task_spec final_task_dict[task_name] = task_spec
else: else:
# String or dict - use load_task_or_group # String or dict - use load_task_or_group
...@@ -1297,7 +1297,8 @@ def get_task_dict( ...@@ -1297,7 +1297,8 @@ def get_task_dict(
# Check for duplicate names # Check for duplicate names
for name in result: for name in result:
if name in final_task_dict: if name in final_task_dict:
raise ValueError(f"Duplicate task name: {name}") msg = f"Duplicate task name: {name}"
raise ValueError(msg)
final_task_dict.update(result) final_task_dict.update(result)
# Check for conflicting group memberships # Check for conflicting group memberships
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment