Unverified Commit a57ffba1 authored by Baber Abbasi's avatar Baber Abbasi Committed by GitHub
Browse files

Merge pull request #3133 from EleutherAI/tasklist

Add `tasklist`
parents 70314843 bcd6faaa
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass, field
from inspect import getsource from inspect import getsource
from typing import Any, Callable, Optional, Union from typing import Callable, Optional, Union
@dataclass @dataclass
...@@ -22,10 +22,10 @@ class AggMetricConfig(dict): ...@@ -22,10 +22,10 @@ class AggMetricConfig(dict):
@dataclass @dataclass
class GroupConfig(dict): class GroupConfig:
group: Optional[str] = None group: Optional[str] = None
group_alias: Optional[str] = None group_alias: Optional[str] = None
task: Optional[Union[str, list]] = None task: Union[str, list] = field(default_factory=list)
aggregate_metric_list: Optional[ aggregate_metric_list: Optional[
Union[list[AggMetricConfig], AggMetricConfig, dict] Union[list[AggMetricConfig], AggMetricConfig, dict]
] = None ] = None
...@@ -40,6 +40,24 @@ class GroupConfig(dict): ...@@ -40,6 +40,24 @@ class GroupConfig(dict):
def __setitem__(self, item, value): def __setitem__(self, item, value):
return setattr(self, item, value) return setattr(self, item, value)
def __contains__(self, item):
"""Support 'in' operator for dict-like behavior."""
return hasattr(self, item)
def get(self, key, default=None):
"""Dict-like get method."""
return getattr(self, key, default)
def __hash__(self):
"""Make GroupConfig hashable based on group name."""
return hash(self.group)
def __eq__(self, other):
"""Equality comparison based on group name."""
if not isinstance(other, GroupConfig):
return False
return self.group == other.group
def __post_init__(self): def __post_init__(self):
if self.aggregate_metric_list is not None: if self.aggregate_metric_list is not None:
if isinstance(self.aggregate_metric_list, dict): if isinstance(self.aggregate_metric_list, dict):
...@@ -88,33 +106,5 @@ class GroupConfig(dict): ...@@ -88,33 +106,5 @@ class GroupConfig(dict):
except (TypeError, OSError): except (TypeError, OSError):
return str(value) return str(value)
class ConfigurableGroup:
def __init__(
self,
config: Optional[dict] = None,
) -> None:
self._config = GroupConfig(**config)
@property
def group(self):
return self._config.group
@property
def group_alias(self):
return self._config.group_alias
@property
def version(self):
return self._config.version
@property
def config(self):
return self._config.to_dict()
@property
def group_name(self) -> Any:
return self._config.group
def __repr__(self): def __repr__(self):
return f"ConfigurableGroup(group={self.group},group_alias={self.group_alias})" return f"GroupConfig(group={self.group},group_alias={self.group_alias})"
...@@ -5,7 +5,7 @@ import ast ...@@ -5,7 +5,7 @@ import ast
import logging import logging
import random import random
import re import re
from collections.abc import Callable from collections.abc import Callable, Iterable, Iterator, Mapping
from copy import deepcopy from copy import deepcopy
from functools import cached_property from functools import cached_property
from typing import TYPE_CHECKING, Any, Literal, overload from typing import TYPE_CHECKING, Any, Literal, overload
...@@ -376,7 +376,6 @@ class Task(abc.ABC): ...@@ -376,7 +376,6 @@ class Task(abc.ABC):
The number of times each instance in a dataset is inferred on. Defaults to 1, The number of times each instance in a dataset is inferred on. Defaults to 1,
can be increased for techniques like majority voting. can be increased for techniques like majority voting.
""" """
pass
@abc.abstractmethod @abc.abstractmethod
def process_results(self, doc: dict, results: list) -> dict[str, Any]: def process_results(self, doc: dict, results: list) -> dict[str, Any]:
...@@ -1249,7 +1248,7 @@ class ConfigurableTask(Task): ...@@ -1249,7 +1248,7 @@ class ConfigurableTask(Task):
): # TODO: ensure that non-multimodal tasks aren't getting visual args ): # TODO: ensure that non-multimodal tasks aren't getting visual args
multimodal_arg = { multimodal_arg = {
**multimodal_arg, **multimodal_arg,
**{"visual": self.doc_to_image(doc)}, "visual": self.doc_to_image(doc),
} }
if ( if (
...@@ -1257,7 +1256,7 @@ class ConfigurableTask(Task): ...@@ -1257,7 +1256,7 @@ class ConfigurableTask(Task):
): # TODO: ensure that non-multimodal tasks aren't getting audio args ): # TODO: ensure that non-multimodal tasks aren't getting audio args
multimodal_arg = { multimodal_arg = {
**multimodal_arg, **multimodal_arg,
**{"audio": self.doc_to_audio(doc)}, "audio": self.doc_to_audio(doc),
} }
if bool(multimodal_arg): if bool(multimodal_arg):
...@@ -1543,6 +1542,8 @@ class MultipleChoiceTask(Task): ...@@ -1543,6 +1542,8 @@ class MultipleChoiceTask(Task):
} }
def aggregation(self) -> dict: def aggregation(self) -> dict:
from lm_eval.api.metrics import mean
return { return {
"acc": mean, "acc": mean,
"acc_norm": mean, "acc_norm": mean,
...@@ -1609,6 +1610,8 @@ class PerplexityTask(Task): ...@@ -1609,6 +1610,8 @@ class PerplexityTask(Task):
} }
def aggregation(self) -> dict: def aggregation(self) -> dict:
from lm_eval.api.metrics import bits_per_byte, weighted_perplexity
return { return {
"word_perplexity": weighted_perplexity, "word_perplexity": weighted_perplexity,
"byte_perplexity": weighted_perplexity, "byte_perplexity": weighted_perplexity,
......
...@@ -340,23 +340,25 @@ class EvaluatorConfig: ...@@ -340,23 +340,25 @@ class EvaluatorConfig:
metadata=self.metadata if self.metadata else {}, metadata=self.metadata if self.metadata else {},
) )
task_names = task_manager.match_tasks(self.tasks) task_names = self.tasks
# TODO: FIX TASKS VALIDATION!!!
# Check for any individual task files in the list # task_names = task_manager.match_tasks(self.tasks)
for task in [task for task in self.tasks if task not in task_names]:
task_path = Path(task) # # Check for any individual task files in the list
if task_path.is_file(): # for task in [task for task in self.tasks if task not in task_names]:
config = utils.load_yaml_config(str(task_path)) # task_path = Path(task)
task_names.append(config) # if task_path.is_file():
# config = utils.load_yaml_config(str(task_path))
# Check for missing tasks # task_names.append(config)
task_missing = [ #
task for task in self.tasks if task not in task_names and "*" not in task # # Check for missing tasks
] # task_missing = [
# task for task in self.tasks if task not in task_names and "*" not in task
if task_missing: # ]
missing = ", ".join(task_missing) #
raise ValueError(f"Tasks not found: {missing}") # if task_missing:
# missing = ", ".join(task_missing)
# raise ValueError(f"Tasks not found: {missing}")
# Update tasks with resolved names # Update tasks with resolved names
self.tasks = task_names self.tasks = task_names
......
...@@ -29,7 +29,8 @@ from lm_eval.evaluator_utils import ( ...@@ -29,7 +29,8 @@ from lm_eval.evaluator_utils import (
) )
from lm_eval.loggers import EvaluationTracker from lm_eval.loggers import EvaluationTracker
from lm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash from lm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash
from lm_eval.tasks import TaskManager, get_task_dict from lm_eval.tasks import TaskManager
from lm_eval.tasks.manager import get_task_dict
from lm_eval.utils import ( from lm_eval.utils import (
get_logger, get_logger,
handle_non_serializable, handle_non_serializable,
......
...@@ -5,7 +5,6 @@ import pathlib ...@@ -5,7 +5,6 @@ import pathlib
import sys import sys
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
from lm_eval.api.group import ConfigurableGroup
from lm_eval.api.metrics import ( from lm_eval.api.metrics import (
aggregate_subtask_metrics, aggregate_subtask_metrics,
mean, mean,
...@@ -153,11 +152,14 @@ def get_task_list(task_dict: dict) -> List[TaskOutput]: ...@@ -153,11 +152,14 @@ def get_task_list(task_dict: dict) -> List[TaskOutput]:
def get_subtask_list(task_dict, task_root=None, depth=0): def get_subtask_list(task_dict, task_root=None, depth=0):
from lm_eval.api.group import GroupConfig
from lm_eval.api.task import Task
subtask_list = {} subtask_list = {}
for group_obj, task_obj in task_dict.items(): for group_obj, task_obj in task_dict.items():
if isinstance(group_obj, ConfigurableGroup): if isinstance(group_obj, GroupConfig):
# group_name = group_obj.group_name # group_name = group_obj.group
group_name = group_obj.group_name group_name = group_obj.group
else: else:
group_name = group_obj group_name = group_obj
if isinstance(task_obj, dict): if isinstance(task_obj, dict):
...@@ -175,9 +177,9 @@ def get_subtask_list(task_dict, task_root=None, depth=0): ...@@ -175,9 +177,9 @@ def get_subtask_list(task_dict, task_root=None, depth=0):
subtask_list = {**subtask_list, **_subtask_list} subtask_list = {**subtask_list, **_subtask_list}
else: else:
if isinstance(task_obj, ConfigurableGroup): if isinstance(task_obj, GroupConfig):
# group_or_task_name = task_obj.group_name # group_or_task_name = task_obj.group
group_or_task_name = task_obj.group_name group_or_task_name = task_obj.group
elif isinstance(task_obj, Task): elif isinstance(task_obj, Task):
# group_or_task_name = task_obj.task_name # group_or_task_name = task_obj.task_name
group_or_task_name = task_obj.task_name group_or_task_name = task_obj.task_name
...@@ -224,6 +226,8 @@ def prepare_print_tasks( ...@@ -224,6 +226,8 @@ def prepare_print_tasks(
task_depth=0, task_depth=0,
group_depth=0, group_depth=0,
) -> Tuple[dict, dict]: ) -> Tuple[dict, dict]:
from lm_eval.api.task import Task
""" """
@param task_dict: Dictionary representing the group hierarchy of tasks. Each key is a group name and its @param task_dict: Dictionary representing the group hierarchy of tasks. Each key is a group name and its
value is a list of task names. value is a list of task names.
...@@ -238,6 +242,7 @@ def prepare_print_tasks( ...@@ -238,6 +242,7 @@ def prepare_print_tasks(
Prepares the task hierarchy and aggregates the results for each task and group recursively for printing. Prepares the task hierarchy and aggregates the results for each task and group recursively for printing.
""" """
from lm_eval.api.group import GroupConfig
def _sort_task_dict(task_dict): def _sort_task_dict(task_dict):
""" """
...@@ -248,8 +253,8 @@ def prepare_print_tasks( ...@@ -248,8 +253,8 @@ def prepare_print_tasks(
return dict( return dict(
sorted( sorted(
task_dict.items(), task_dict.items(),
key=lambda item: item[0].group_name key=lambda item: item[0].group
if isinstance(item[0], ConfigurableGroup) if isinstance(item[0], GroupConfig)
else item[0], else item[0],
) )
) )
...@@ -259,9 +264,9 @@ def prepare_print_tasks( ...@@ -259,9 +264,9 @@ def prepare_print_tasks(
task_dict = _sort_task_dict(task_dict) task_dict = _sort_task_dict(task_dict)
for task_or_group_name, task_or_group_obj in task_dict.items(): for task_or_group_name, task_or_group_obj in task_dict.items():
tab_string = " " * task_depth + "- " if task_depth > 0 else "" tab_string = " " * task_depth + "- " if task_depth > 0 else ""
if isinstance(task_or_group_name, ConfigurableGroup): if isinstance(task_or_group_name, GroupConfig):
# string_name = task_or_group_name.group_name # string_name = task_or_group_name.group
name = task_or_group_name.group_name name = task_or_group_name.group
from_configurable_group = True from_configurable_group = True
task_or_group_obj = _sort_task_dict(task_or_group_obj) task_or_group_obj = _sort_task_dict(task_or_group_obj)
elif isinstance(task_or_group_name, str): elif isinstance(task_or_group_name, str):
...@@ -395,6 +400,9 @@ def consolidate_group_results( ...@@ -395,6 +400,9 @@ def consolidate_group_results(
The method then returns the updated results, versions, show_group_table, and task_aggregation_list as a tuple. The method then returns the updated results, versions, show_group_table, and task_aggregation_list as a tuple.
In the top-level invocation of this function, task_aggregation_list is ignored. In the top-level invocation of this function, task_aggregation_list is ignored.
""" """
from lm_eval.api.group import GroupConfig
from lm_eval.api.task import Task
if task_root is None: if task_root is None:
task_root = {} task_root = {}
...@@ -403,9 +411,9 @@ def consolidate_group_results( ...@@ -403,9 +411,9 @@ def consolidate_group_results(
for group_or_task, group_or_task_info in task_dict.items(): for group_or_task, group_or_task_info in task_dict.items():
# Convert to string # Convert to string
if isinstance(group_or_task, ConfigurableGroup): if isinstance(group_or_task, GroupConfig):
group_config = group_or_task.config group_config = group_or_task.to_dict()
group_or_task = group_or_task.group_name group_or_task = group_or_task.group
else: else:
group_config = None group_config = None
...@@ -434,7 +442,7 @@ def consolidate_group_results( ...@@ -434,7 +442,7 @@ def consolidate_group_results(
) )
if (group_config is None) or ( if (group_config is None) or (
group_config["aggregate_metric_list"] is None group_config.get("aggregate_metric_list") is None
): ):
results[group_or_task][" "] = " " results[group_or_task][" "] = " "
continue continue
...@@ -443,7 +451,7 @@ def consolidate_group_results( ...@@ -443,7 +451,7 @@ def consolidate_group_results(
agg_metric_list = group_config["aggregate_metric_list"] agg_metric_list = group_config["aggregate_metric_list"]
show_group_table = show_group_table | bool( show_group_table = show_group_table | bool(
group_config["aggregate_metric_list"] group_config.get("aggregate_metric_list")
) )
task_list = _task_aggregation_list[group_or_task] task_list = _task_aggregation_list[group_or_task]
......
...@@ -3,6 +3,8 @@ import logging ...@@ -3,6 +3,8 @@ import logging
import os import os
from typing import Dict from typing import Dict
import lm_eval.tasks
import lm_eval.utils
from lm_eval import utils from lm_eval import utils
...@@ -122,7 +124,7 @@ class PromptString: ...@@ -122,7 +124,7 @@ class PromptString:
if "doc_to_choice" in self.prompt_string: if "doc_to_choice" in self.prompt_string:
raise NotImplementedError("Not yet implemented to accept doc_to_choice") raise NotImplementedError("Not yet implemented to accept doc_to_choice")
text_string = utils.apply_template(doc_to_text, doc) text_string = lm_eval.utils.apply_template(doc_to_text, doc)
target_string = utils.apply_template(doc_to_target, doc) target_string = lm_eval.utils.apply_template(doc_to_target, doc)
return [text_string, target_string] return [text_string, target_string]
This diff is collapsed.
from __future__ import annotations
import importlib.util
import sys
from pathlib import Path
from typing import Any
import yaml
_Base = (
yaml.CSafeLoader if getattr(yaml, "__with_libyaml__", False) else yaml.FullLoader
)
_IGNORE_DIRS = {"__pycache__", ".ipynb_checkpoints"}
def _mk_function_ctor(base_dir: Path, resolve: bool):
def ctor(loader: yaml.Loader, node: yaml.Node):
spec = loader.construct_scalar(node) # type: ignore[arg-type]
if not resolve:
return str(base_dir.expanduser() / spec)
return _import_func_in_yml(spec, base_dir)
return ctor
def _make_loader(base_dir: Path, *, resolve_funcs: bool) -> type[yaml.Loader]:
class Loader(_Base): ... # type: ignore[no-redef]
yaml.add_constructor(
"!function",
_mk_function_ctor(base_dir, resolve_funcs),
Loader=Loader,
)
return Loader
def _load_module_with_cache(module_path: Path) -> Any:
"""Load a module from a file path with caching and hot-reload support.
Args:
module_path: Path to the Python file to load
Returns:
The loaded module
"""
# Determine module name based on location
path_str = str(module_path)
# Check if this is a built-in task module
if "/lm_eval/tasks/" in path_str:
# Find the position of lm_eval/tasks/ in the path
tasks_idx = path_str.find("/lm_eval/tasks/")
if tasks_idx != -1:
# Extract path starting from lm_eval/tasks/
# e.g., /path/to/lm_eval/tasks/hellaswag/utils.py → hellaswag/utils.py
relative_path = path_str[tasks_idx + len("/lm_eval/tasks/") :]
# Remove .py and convert to module name
# e.g., hellaswag/utils.py → lm_eval.tasks.hellaswag.utils
module_parts = relative_path.replace(".py", "").replace("/", ".")
module_name = f"lm_eval.tasks.{module_parts}"
else:
# Fallback to full path if pattern not found
module_name = str(module_path.with_suffix(""))
else:
# External module - use full path without extension
module_name = str(module_path.with_suffix(""))
# Check if we need to reload the module
if module_name in sys.modules:
existing_module = sys.modules[module_name]
# Check if it was modified
current_mtime = module_path.stat().st_mtime_ns
if (
hasattr(existing_module, "__mtime__")
and existing_module.__mtime__ == current_mtime
):
# Module hasn't changed, reuse it
return existing_module
# Load or reload the module
spec = importlib.util.spec_from_file_location(module_name, module_path)
if spec is None or spec.loader is None:
raise ImportError(f"Cannot load module from {module_path}") from None
module = importlib.util.module_from_spec(spec)
# Store mtime for future checks
module.__mtime__ = module_path.stat().st_mtime_ns
spec.loader.exec_module(module) # type: ignore[arg-type]
sys.modules[module_name] = module
return module
def _import_func_in_yml(qual: str, base_dir: Path):
"""Import function from qual: utils.process_doc, checking local files first then standard imports.
Args:
qual: Qualified function name (e.g., 'utils.process_doc')
base_dir: Directory to search for local modules
"""
mod_path, _, fn_name = qual.rpartition(".")
# 1) relative "utils.py" next to YAML
rel = (base_dir / f"{mod_path.replace('.', '/')}.py").resolve()
if rel.exists():
module = _load_module_with_cache(rel)
return getattr(module, fn_name)
# 2) already-importable module
module = __import__(mod_path, fromlist=[fn_name])
return getattr(module, fn_name)
def _import_fun_from_str(path_str: str) -> Any:
"""Import a function from a string in the form '/absolute/path/to/module.function_name'."""
try:
# Split off the function name from the rightmost dot
module_path_str, function_name = path_str.rsplit(".", 1)
except ValueError as e:
raise ValueError(
f"Invalid path format: {path_str}. Expected format: /path/to/module.function_name"
) from e
# Convert to Path and handle .py extension
module_path = Path(module_path_str)
if not module_path.suffix:
module_path = module_path.with_suffix(".py")
elif module_path.suffix != ".py":
# If it has a non-.py suffix, the user might have included .py in the path
# e.g., "/path/to/module.py.function_name"
base_path = module_path.with_suffix("")
if base_path.with_suffix(".py").exists():
module_path = base_path.with_suffix(".py")
if not module_path.exists():
raise ImportError(f"Module file not found: {module_path}")
module = _load_module_with_cache(module_path)
if not hasattr(module, function_name):
raise AttributeError(
f"Function '{function_name}' not found in module {module_path}"
)
return getattr(module, function_name)
def load_yaml(
path: str | Path,
*,
resolve_func: bool = True,
recursive: bool = True,
_seen: set[Path] | None = None,
) -> dict[str, Any]:
"""Pure data-loading helper.
Returns a dict ready for higher-level interpretation.
•No task/group/tag semantics here.
"""
path = Path(path).expanduser().resolve()
if _seen is None:
_seen = set()
if path in _seen:
raise ValueError(f"Include cycle at {path}")
_seen.add(path)
loader_cls = _make_loader(path.parent, resolve_funcs=resolve_func)
with path.open("rb") as fh:
cfg = yaml.load(fh, Loader=loader_cls)
if not recursive or "include" not in cfg:
return cfg
else:
includes = cfg.pop("include")
merged = {}
for inc in includes if isinstance(includes, list) else [includes]:
inc_path = (path.parent / inc) if not Path(inc).is_absolute() else Path(inc)
merged.update(
load_yaml(
inc_path,
resolve_func=resolve_func,
recursive=True,
_seen=_seen,
),
)
merged.update(cfg) # local keys win
return merged
from __future__ import annotations
import inspect
from collections.abc import Mapping
from copy import deepcopy
from typing import Any
from lm_eval.api.group import GroupConfig
from lm_eval.api.task import ConfigurableTask, Task # noqa: F401 (typing)
from lm_eval.tasks._config_loader import load_yaml as load_cfg
from lm_eval.tasks.index import Entry, Kind
load_cfg_cached = load_cfg # type: ignore[no-redef]
class TaskFactory:
"""
Turns a *Entry* (plus optional overrides) into a
*Task* | *ConfigurableTask* | *GroupConfig* hierarchy.
"""
def __init__(self, *, meta: dict[str, Any] | None = None):
self._meta = meta or {}
# ---------------------------------------------------------------- public API
def build(
self,
entry: Entry,
*,
overrides: dict[str, Any] | None = None,
registry: Mapping[str, Entry],
):
"""
• entry.kind == TASK / PY_TASK ➜ returns instantiated task object
• entry.kind == GROUP ➜ returns (GroupConfig, mapping-of-subtasks)
• entry.kind == TAG ➜ returns mapping-of-tasks (tag expansion)
"""
if entry.kind is Kind.TAG:
return self._build_tag(entry, overrides, registry)
if entry.kind is Kind.GROUP:
return self._build_group(entry, overrides, registry)
return self._build_task(entry, overrides)
def _build_task(self, entry: Entry, overrides: dict[str, Any] | None):
cfg = self._load_full_config(entry, overrides)
if "class" in cfg: # PY_TASK route
cls = cfg["class"]
obj = cls(config=cfg) if _ctor_accepts_config(cls) else cls()
if isinstance(obj, ConfigurableTask):
obj.config.task = entry.name
return obj
# YAML task
return ConfigurableTask(config=cfg) # type: ignore[arg-type]
def _build_group(
self,
entry: Entry,
overrides: dict[str, Any] | None,
registry: Mapping[str, Entry],
):
raw_cfg = self._load_full_config(entry, None)
grp_cfg = {k: v for k, v in raw_cfg.items() if k in GroupConfig.__annotations__}
grp_cfg["metadata"] = grp_cfg.get("metadata", {}) | self._meta
group_obj = GroupConfig(**grp_cfg)
children: dict[str, Any] = {}
for item in group_obj.task:
if isinstance(item, str): # task: hellaswag
child = self.build(
registry[item],
overrides=overrides, # group-level overrides propagate
registry=registry,
)
elif isinstance(item, dict): # task: {task: hellaswag, num_fewshot: 5}
base_name = item["task"]
child = self.build(
registry[base_name],
overrides=item, # per-item override
registry=registry,
)
else:
raise TypeError(
f"Unsupported sub-entry {item!r} in group '{entry.name}'"
)
# `child` itself is a mapping (task-name -> obj) or {GroupConfig: ...}
children.update(child)
return {group_obj: children}
def _build_tag(
self,
entry: Entry,
overrides: dict[str, Any] | None,
registry: Mapping[str, Entry],
):
return {
name: self._build_task(registry[name], overrides) for name in entry.tags
}
def _load_full_config(
self, entry: Entry, overrides: dict[str, Any] | None
) -> dict[str, Any]:
if entry.yaml_path:
cfg = deepcopy(load_cfg_cached(entry.yaml_path, resolve_func=True))
else:
cfg = {"metadata": {"config": "unknown"}} # python task without YAML
if overrides:
cfg = {**cfg, **overrides}
cfg["metadata"] = (
m if isinstance(m := cfg.get("metadata", {}), dict) else {"_metadata": m}
) | self._meta
cfg.setdefault("task", entry.name)
return cfg
def _ctor_accepts_config(cls) -> bool:
init = getattr(cls, "__init__", None)
return init and "config" in inspect.signature(init).parameters
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import TYPE_CHECKING, Any
from lm_eval.tasks._config_loader import load_yaml as load_cfg
if TYPE_CHECKING:
from collections.abc import Iterable
from pathlib import Path
class Kind(Enum):
TASK = auto() # YAML task, or task_list entry
PY_TASK = auto() # Python-defined, via "class"
GROUP = auto()
TAG = auto()
TASK_LIST = auto()
@dataclass
class Entry:
name: str
kind: Kind
yaml_path: Path | None # None for generated / py-only entries
cfg: dict[str, str] | None = None
tags: set[str] = field(default_factory=set)
task_list_path: Path | None = None
log = logging.getLogger(__name__)
_IGNORE_DIRS = {"__pycache__", ".ipynb_checkpoints"}
class TaskIndex:
"""Walks one or more directories, parses YAML quickly (functions unresolved),
and produces a mapping {task_name: Entry}.
"""
def __init__(self, *, meta: dict[str, str] | None = None) -> None:
self._metadata = meta or {}
def build(
self,
paths: Iterable[Path],
*,
resolve_includes=False,
) -> dict[str, Entry]:
index: dict[str, Entry] = {}
log.debug("Building task index from %s", paths)
for root in paths:
for yaml_path in self._iter_yaml_files(root):
try:
cfg = load_cfg(
yaml_path,
resolve_func=False,
recursive=resolve_includes,
)
self.process_cfg(cfg, yaml_path, index)
except Exception as err:
log.debug("Skip %s (%s)", yaml_path, err)
continue
# self._process_cfg(cfg, yaml_path, index)
log.debug("Built task index with %d entries", len(index))
return index
@staticmethod
def _iter_yaml_files(root: Path):
yield from (
p
for p in root.glob("**/*.yaml")
if not any(part in _IGNORE_DIRS for part in p.parts)
)
@staticmethod
def process_cfg(
cfg: dict[str, Any],
path: Path,
index: dict[str, Entry],
) -> None:
kind = TaskIndex._kind_of(cfg)
if kind is Kind.GROUP:
grp_name = cfg["group"]
index[grp_name] = Entry(
name=grp_name,
kind=Kind.GROUP,
yaml_path=path,
tags=TaskIndex._str_to_set(cfg.get("tag")),
cfg=cfg,
)
return
if kind is Kind.PY_TASK:
name = cfg["task"]
index[name] = Entry(
name=name,
kind=Kind.PY_TASK,
yaml_path=None,
tags=TaskIndex._str_to_set(cfg.get("tag")),
cfg=cfg,
)
TaskIndex._register_tags(name, cfg.get("tag"), index)
return
if kind is Kind.TASK:
name = cfg["task"]
index[name] = Entry(
name=name,
kind=Kind.TASK,
yaml_path=path,
tags=TaskIndex._str_to_set(cfg.get("tag")),
cfg=cfg,
)
TaskIndex._register_tags(name, cfg.get("tag"), index)
return
if kind is Kind.TASK_LIST:
for entry in cfg["task_list"]:
task_name = entry["task"] if isinstance(entry, dict) else entry
index[task_name] = Entry(
name=task_name,
kind=Kind.TASK,
yaml_path=path,
tags=TaskIndex._str_to_set(cfg.get("tag")),
cfg=cfg,
)
TaskIndex._register_tags(task_name, entry.get("tag"), index)
return
@staticmethod
def _register_tags(
task: str,
tags: str | list[str] | None,
index: dict[str, Entry],
) -> None:
if not tags:
return
for tag in tags if isinstance(tags, list) else [tags]:
entry = index.setdefault(
tag,
Entry(name=tag, kind=Kind.TAG, yaml_path=None, tags=set()),
)
entry.tags.add(task)
@staticmethod
def _kind_of(cfg: dict) -> Kind:
if "class" in cfg:
return Kind.PY_TASK
if "group" in cfg:
return Kind.GROUP
if "task_list" in cfg:
return Kind.TASK_LIST
if "task" in cfg:
return Kind.GROUP if isinstance(cfg["task"], list) else Kind.TASK
msg = "Unknown config shape"
raise ValueError(msg) from None
@staticmethod
def _str_to_set(tags: str | list[str] | None = None) -> set[str]:
"""Convert a string or list of strings to a set of strings."""
return (
set(tags)
if isinstance(tags, list)
else {tags}
if isinstance(tags, str)
else set()
)
from __future__ import annotations
from collections import defaultdict
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING, Any
from lm_eval.api.task import Task
from lm_eval.tasks.factory import TaskFactory
from lm_eval.tasks.index import Entry, Kind, TaskIndex
from lm_eval.utils import setup_logging
if TYPE_CHECKING:
from lm_eval.api.task import Task
class TaskManager:
def __init__(
self,
verbosity: str | None = None,
include_path: str | Path | list[str | Path] | None = None,
include_defaults: bool = True,
metadata: dict[str, dict[str, Any]] | None = None,
) -> None:
if verbosity:
setup_logging(verbosity)
index = TaskIndex()
self._factory = TaskFactory(meta=metadata)
all_paths: list[Path] = []
if include_defaults:
all_paths.append(Path(__file__).parent)
if include_path:
all_paths += [
Path(p)
for p in (
include_path
if isinstance(include_path, (list, tuple))
else [include_path]
)
]
self._index = index.build(all_paths)
buckets = defaultdict(list)
for k, e in self._index.items():
buckets[e.kind].append(k)
self._all_tasks = sorted(
chain.from_iterable(buckets[k] for k in {Kind.TASK, Kind.PY_TASK})
)
self._all_groups = sorted(buckets[Kind.GROUP])
self._all_tags = sorted(buckets[Kind.TAG])
def _entry(self, name: str) -> Entry:
if name not in self._index:
raise KeyError(f"Unknown task/group/tag: {name}")
return self._index[name]
def load_spec(self, spec: str | dict[str, Any]):
"""Spec can be:
• str task / group / tag name (registered)
• dict inline overrides {'task': 'hellaswag', 'num_fewshot': 5}
"""
if isinstance(spec, str):
entry = self._entry(spec)
return self._factory.build(entry, overrides=None, registry=self._index)
if isinstance(spec, dict):
# inline dict => find base entry, then pass overrides
name = spec["task"]
entry = self._entry(name)
return self._factory.build(entry, overrides=spec, registry=self._index)
raise TypeError("spec must be str or dict")
def load_task_or_group(self, task_list: str | list[str]):
return (
[self.load_spec(s) for s in task_list]
if isinstance(task_list, list)
else [self.load_spec(task_list)]
)
def get_task_dict(
task_name_list: str | list[str | dict | Task],
task_manager: TaskManager | None = None,
):
if not task_manager:
task_manager = TaskManager()
else:
assert isinstance(task_manager, TaskManager)
return {
task_name: task_manager.load_spec(task_name)
if isinstance(task_name, str)
else task_name
for task_name in task_name_list
}
...@@ -15,11 +15,10 @@ from dataclasses import asdict, is_dataclass ...@@ -15,11 +15,10 @@ from dataclasses import asdict, is_dataclass
from functools import lru_cache, partial, wraps from functools import lru_cache, partial, wraps
from itertools import islice from itertools import islice
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional
import numpy as np import numpy as np
import yaml from jinja2 import BaseLoader, Environment, StrictUndefined
from jinja2 import BaseLoader, Environment, StrictUndefined, Template
SPACING = " " * 47 SPACING = " " * 47
...@@ -117,8 +116,7 @@ def setup_logging(verbosity=logging.INFO, suppress_third_party=True): ...@@ -117,8 +116,7 @@ def setup_logging(verbosity=logging.INFO, suppress_third_party=True):
# Configure custom formatter # Configure custom formatter
class CustomFormatter(logging.Formatter): class CustomFormatter(logging.Formatter):
def format(self, record): def format(self, record):
if record.name.startswith("lm_eval."): record.name = record.name.removeprefix("lm_eval.")
record.name = record.name[len("lm_eval.") :]
return super().format(record) return super().format(record)
formatter = CustomFormatter( formatter = CustomFormatter(
...@@ -527,105 +525,6 @@ def positional_deprecated(fn): ...@@ -527,105 +525,6 @@ def positional_deprecated(fn):
return _wrapper return _wrapper
def ignore_constructor(loader, node):
return node
def import_function(loader: yaml.Loader, node, yaml_path: Path):
function_name = loader.construct_scalar(node)
*module_name, function_name = function_name.split(".")
if isinstance(module_name, list):
module_name = ".".join(module_name)
module_path = yaml_path.parent / f"{module_name}.py"
spec = importlib.util.spec_from_file_location(module_name, module_path.as_posix())
if spec is None:
raise ImportError(f"Could not import module {module_name} from {module_path}.")
module = importlib.util.module_from_spec(spec)
if spec.loader is None:
raise ImportError(f"Module loader is None, {module_name} from {module_path}.")
spec.loader.exec_module(module)
function = getattr(module, function_name)
return function
def load_yaml_config(
yaml_path: str | None = None, yaml_config=None, yaml_dir=None, mode="full"
):
if mode == "simple":
constructor_fn = ignore_constructor
elif mode == "full":
if yaml_path is None:
raise ValueError("yaml_path must be provided if mode is 'full'.")
# Attach yaml_path to the import function so that it can be used later
constructor_fn = partial(import_function, yaml_path=Path(yaml_path))
loader = yaml.CLoader if yaml.__with_libyaml__ else yaml.FullLoader
# Add the import_function constructor to the YAML loader
yaml.add_constructor("!function", constructor_fn, Loader=loader)
if yaml_config is None:
with open(yaml_path, "rb") as file:
yaml_config = yaml.load(file, Loader=loader)
if yaml_dir is None:
yaml_dir = os.path.dirname(yaml_path)
assert yaml_dir is not None
if "include" in yaml_config:
include_path = yaml_config["include"]
del yaml_config["include"]
if isinstance(include_path, str):
include_path = [include_path]
# Load from the last one first
include_path.reverse()
final_yaml_config = {}
for path in include_path:
# Assumes that path is a full path.
# If not found, assume the included yaml
# is in the same dir as the original yaml
if not os.path.isfile(path):
path = os.path.join(yaml_dir, path)
try:
included_yaml_config = load_yaml_config(yaml_path=path, mode=mode)
final_yaml_config.update(included_yaml_config)
except Exception as ex:
# If failed to load, ignore
raise ex
final_yaml_config.update(yaml_config)
return final_yaml_config
return yaml_config
def regex_replace(string, pattern, repl, count: int = 0):
"""Implements the `re.sub` function as a custom Jinja filter."""
return re.sub(pattern, repl, string, count=count)
env = Environment(
loader=BaseLoader(), undefined=StrictUndefined, keep_trailing_newline=True
)
env.filters["regex_replace"] = regex_replace
@lru_cache(maxsize=128)
def _compile(raw: str) -> Template:
return env.from_string(raw)
def apply_template(template: str, doc: dict) -> str:
rtemplate = _compile(template)
return rtemplate.render(**doc)
def create_iterator( def create_iterator(
raw_iterator: collections.Iterator, raw_iterator: collections.Iterator,
*, *,
...@@ -705,3 +604,25 @@ def hash_dict_images(data_dict): ...@@ -705,3 +604,25 @@ def hash_dict_images(data_dict):
if importlib.util.find_spec("PIL") if importlib.util.find_spec("PIL")
else data_dict else data_dict
) )
def regex_replace(string, pattern, repl, count: int = 0):
"""Implements the `re.sub` function as a custom Jinja filter."""
return re.sub(pattern, repl, string, count=count)
@functools.lru_cache(maxsize=256)
def _compile_tpl(src: str):
return apply_template._env.from_string(src)
def apply_template(template: str, doc: dict) -> str:
if not hasattr(apply_template, "_env"):
apply_template._env = Environment(
loader=BaseLoader(),
undefined=StrictUndefined,
keep_trailing_newline=True,
)
apply_template._env.filters["regex_replace"] = regex_replace
return _compile_tpl(template).render(**doc)
...@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" ...@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
name = "lm_eval" name = "lm_eval"
version = "0.4.9.1" version = "0.4.9.1"
authors = [ authors = [
{name="EleutherAI", email="contact@eleuther.ai"} { name = "EleutherAI", email = "contact@eleuther.ai" }
] ]
description = "A framework for evaluating language models" description = "A framework for evaluating language models"
readme = "README.md" readme = "README.md"
...@@ -19,25 +19,22 @@ classifiers = [ ...@@ -19,25 +19,22 @@ classifiers = [
requires-python = ">=3.9" requires-python = ">=3.9"
license = { "text" = "MIT" } license = { "text" = "MIT" }
dependencies = [ dependencies = [
"accelerate>=0.26.0", "accelerate>=0.26.0",
"datasets>=2.16.0,<4.0", "datasets>=2.16.0,<4.0",
"evaluate>=0.4.0", "evaluate>=0.4.0",
"peft>=0.2.0", "peft>=0.2.0",
"pytablewriter", "pytablewriter",
"rouge-score>=0.0.4", "rouge-score>=0.0.4",
"sacrebleu>=1.5.0", "sacrebleu>=1.5.0",
"scikit-learn>=0.24.1", "scikit-learn>=0.24.1",
"sqlitedict", "sqlitedict",
"torch>=1.8", "torch>=1.8",
"transformers>=4.1", "transformers>=4.1",
"dill", "dill",
"word2number", "word2number",
"more_itertools" "more_itertools"
] ]
[tool.setuptools.packages.find]
include = ["lm_eval*"]
# required to include yaml files in pip installation # required to include yaml files in pip installation
[tool.setuptools.package-data] [tool.setuptools.package-data]
lm_eval = ["**/*.yaml", "tasks/**/*"] lm_eval = ["**/*.yaml", "tasks/**/*"]
...@@ -63,7 +60,7 @@ ifeval = ["langdetect", "immutabledict", "nltk>=3.9.1"] ...@@ -63,7 +60,7 @@ ifeval = ["langdetect", "immutabledict", "nltk>=3.9.1"]
ipex = ["optimum"] ipex = ["optimum"]
japanese_leaderboard = ["emoji==2.14.0", "neologdn==0.5.3", "fugashi[unidic-lite]", "rouge_score>=0.1.2"] japanese_leaderboard = ["emoji==2.14.0", "neologdn==0.5.3", "fugashi[unidic-lite]", "rouge_score>=0.1.2"]
longbench = ["jieba", "fuzzywuzzy", "rouge"] longbench = ["jieba", "fuzzywuzzy", "rouge"]
libra=["pymorphy2"] libra = ["pymorphy2"]
mamba = ["mamba_ssm", "causal-conv1d==1.0.2", "torch"] mamba = ["mamba_ssm", "causal-conv1d==1.0.2", "torch"]
math = ["sympy>=1.12", "antlr4-python3-runtime==4.11", "math_verify[antlr4_11_0]"] math = ["sympy>=1.12", "antlr4-python3-runtime==4.11", "math_verify[antlr4_11_0]"]
multilingual = ["nagisa>=0.2.7", "jieba>=0.42.1", "pycountry"] multilingual = ["nagisa>=0.2.7", "jieba>=0.42.1", "pycountry"]
...@@ -75,22 +72,17 @@ sae_lens = ["sae_lens"] ...@@ -75,22 +72,17 @@ sae_lens = ["sae_lens"]
sentencepiece = ["sentencepiece>=0.1.98"] sentencepiece = ["sentencepiece>=0.1.98"]
sparsify = ["sparsify"] sparsify = ["sparsify"]
discrim_eval = ["statsmodels==0.14.4"] discrim_eval = ["statsmodels==0.14.4"]
testing = ["pytest", "pytest-cov", "pytest-xdist"]
unitxt = ["unitxt==1.22.0"]
vllm = ["vllm>=0.4.2"]
wandb = ["wandb>=0.16.3", "pandas", "numpy"]
zeno = ["pandas", "zeno-client"]
tasks = [ tasks = [
"lm_eval[acpbench]", "lm_eval[acpbench]",
"lm_eval[discrim_eval]", "lm_eval[discrim_eval]",
"lm_eval[ifeval]", "lm_eval[ifeval]",
"lm_eval[japanese_leaderboard]", "lm_eval[japanese_leaderboard]",
"lm_eval[longbench]", "lm_eval[longbench]",
"lm_eval[libra]", "lm_eval[libra]",
"lm_eval[mamba]", "lm_eval[mamba]",
"lm_eval[math]", "lm_eval[math]",
"lm_eval[multilingual]", "lm_eval[multilingual]",
"lm_eval[ruler]" "lm_eval[ruler]"
] ]
testing = ["pytest", "pytest-cov", "pytest-xdist"] testing = ["pytest", "pytest-cov", "pytest-xdist"]
unitxt = ["unitxt==1.22.0"] unitxt = ["unitxt==1.22.0"]
...@@ -98,14 +90,6 @@ vllm = ["vllm>=0.4.2"] ...@@ -98,14 +90,6 @@ vllm = ["vllm>=0.4.2"]
wandb = ["wandb>=0.16.3", "pandas", "numpy"] wandb = ["wandb>=0.16.3", "pandas", "numpy"]
zeno = ["pandas", "zeno-client"] zeno = ["pandas", "zeno-client"]
[project.scripts]
lm-eval = "lm_eval.__main__:cli_evaluate"
lm_eval = "lm_eval.__main__:cli_evaluate"
[project.urls]
Homepage = "https://github.com/EleutherAI/lm-evaluation-harness"
Repository = "https://github.com/EleutherAI/lm-evaluation-harness"
[tool.pymarkdown] [tool.pymarkdown]
plugins.md013.enabled = false # line-length plugins.md013.enabled = false # line-length
plugins.md024.allow_different_nesting = true # no-duplicate-headers plugins.md024.allow_different_nesting = true # no-duplicate-headers
...@@ -128,9 +112,5 @@ combine-as-imports = true ...@@ -128,9 +112,5 @@ combine-as-imports = true
known-first-party = ["lm_eval"] known-first-party = ["lm_eval"]
lines-after-imports = 2 lines-after-imports = 2
# required to include yaml files in pip installation
[tool.setuptools.package-data]
lm_eval = ["**/*.yaml", "tasks/**/*"]
[tool.setuptools.packages.find] [tool.setuptools.packages.find]
include = ["lm_eval*"] include = ["lm_eval*"]
"""
Tests for the config loader pure functions.
Note: _import_function uses LRU caching, so file changes during runtime
won't be detected unless the cache is cleared.
Test coverage:
- _mk_function_ctor:
- test_mk_function_ctor_with_resolve_false: no-op lambda when resolve=False
- test_mk_function_ctor_with_resolve_true: actual function import when resolve=True
- _make_loader:
- test_make_loader_creates_loader_class: creates YAML loader with !function support
- test_make_loader_caching: loader classes cached by parameters
- _import_function:
- test_import_local_module: imports from local .py files
- test_import_nested_local_module: handles dot-separated nested paths
- test_import_standard_module: falls back to standard library imports
- test_import_caching: LRU cache behavior
- test_import_mtime_sensitivity: cache behavior with file changes
- load():
- test_load_simple_yaml: basic YAML parsing
- test_load_with_function_resolved: !function tags resolved to callables
- test_load_with_function_not_resolved: !function tags become strings
- test_load_with_includes: include files merged, main values win
- test_load_with_absolute_include: absolute path includes
- test_load_without_includes_resolution: includes preserved when disabled
- test_load_include_cycle_detection: circular includes raise error
- test_load_multiple_includes: include order precedence (later includes override earlier, main overrides all)
- test_load_recursive_includes: nested includes (main->inc1->inc2, main overrides inc1 overrides inc2)
- test_load_expanduser_path: ~ paths expanded
"""
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from lm_eval.tasks._config_loader import (
_Base,
_import_func_in_yml,
_make_loader,
_mk_function_ctor,
import_fun_from_str,
load_yaml,
)
@pytest.fixture
def temp_dir():
with tempfile.TemporaryDirectory() as td:
yield Path(td)
@pytest.fixture
def yaml_file(temp_dir):
def _create_yaml(content, filename="test.yaml"):
file_path = temp_dir / filename
file_path.write_text(content)
return file_path
return _create_yaml
@pytest.fixture
def python_module(temp_dir):
def _create_module(content, filename="utils.py"):
file_path = temp_dir / filename
file_path.write_text(content)
return file_path
return _create_module
class TestMkFunctionCtor:
"""Tests for the YAML !function constructor factory."""
def test_mk_function_ctor_with_resolve_false(self, temp_dir):
"""When resolve=False, should return a string."""
ctor = _mk_function_ctor(temp_dir, resolve=False)
loader = MagicMock()
node = MagicMock()
loader.construct_scalar.return_value = "module.function"
result = ctor(loader, node)
assert isinstance(result, str)
def test_mk_function_ctor_with_resolve_true(self, temp_dir, python_module):
"""When resolve=True, should import and return the actual function."""
# Create a local module
python_module("def test_func(x):\n return x * 2\n")
ctor = _mk_function_ctor(temp_dir, resolve=True)
loader = MagicMock()
node = MagicMock()
loader.construct_scalar.return_value = "utils.test_func"
result = ctor(loader, node)
assert callable(result)
assert result(5) == 10
class TestMakeLoader:
"""Tests for YAML loader class creation and caching."""
def test_make_loader_creates_loader_class(self, temp_dir):
loader_cls = _make_loader(temp_dir, resolve_funcs=True)
assert issubclass(loader_cls, _Base)
# !function constructor should be registered
constructors = loader_cls.yaml_constructors
assert "!function" in constructors
def test_make_loader_caching(self, temp_dir):
"""Loader classes should be cached by parameters."""
# Clear cache first
_make_loader.cache_clear()
loader1 = _make_loader(temp_dir, resolve_funcs=True)
loader2 = _make_loader(temp_dir, resolve_funcs=True)
loader3 = _make_loader(temp_dir, resolve_funcs=False)
assert loader1 is loader2 # Same params = same class
assert loader1 is not loader3 # Different params = different class
class TestImportFunction:
"""Tests for dynamic function importing with mtime-based module caching."""
def test_import_local_module(self, temp_dir, python_module):
# Create a local module
python_module("def local_func(x, y):\n return x + y\n")
func = _import_func_in_yml("utils.local_func", temp_dir)
assert callable(func)
assert func(2, 3) == 5
def test_import_nested_local_module(self, temp_dir):
"""Should handle dot-separated paths for nested modules."""
# Create nested directory structure
(temp_dir / "sub").mkdir()
(temp_dir / "sub" / "module.py").write_text(
"def nested_func():\n return 'nested'\n"
)
func = _import_func_in_yml("sub.module.nested_func", temp_dir)
assert callable(func)
assert func() == "nested"
def test_import_standard_module(self, temp_dir):
"""Falls back to standard import for non-local modules."""
# Import from standard library
func = _import_func_in_yml("os.path.join", temp_dir)
assert callable(func)
assert func("a", "b") in ("a/b", "a\\b") # Unix or Windows
def test_import_caching(self, temp_dir, python_module):
# Clear cache first
_import_func_in_yml.cache_clear()
python_module("def cached_func():\n return 42\n")
func1 = _import_func_in_yml("utils.cached_func", temp_dir)
func2 = _import_func_in_yml("utils.cached_func", temp_dir)
assert func1 is func2 # Cached
def test_import_mtime_sensitivity(self, temp_dir):
"""Verifies LRU cache behavior - file changes require cache clear."""
# Clear the LRU cache
_import_func_in_yml.cache_clear()
# Create a module
module_path = temp_dir / "test_mtime.py"
module_path.write_text("value = 1\n")
# Import it
import_key = "test_mtime.value"
value1 = _import_func_in_yml(import_key, temp_dir)
assert value1 == 1
value2 = _import_func_in_yml(import_key, temp_dir)
assert value2 == 1 # From cache
_import_func_in_yml.cache_clear()
value3 = _import_func_in_yml(import_key, temp_dir)
assert value3 == 1 # Re-imported
class TestImportFunFromStr:
"""Tests for import_fun_from_str function."""
def test_import_from_absolute_path(self, temp_dir):
"""Test importing function from absolute path."""
# Create a test module
module_path = temp_dir / "test_module.py"
module_path.write_text("def test_func(x):\n return x * 2\n")
# Import using absolute path
func = import_fun_from_str(f"{module_path.with_suffix('')}.test_func")
assert callable(func)
assert func(5) == 10
def test_import_with_py_extension(self, temp_dir):
"""Test importing when .py is included in the path."""
# Create a test module
module_path = temp_dir / "test_module.py"
module_path.write_text("def test_func(x):\n return x + 10\n")
# Import with .py in the path
func = import_fun_from_str(f"{module_path}.test_func")
assert callable(func)
assert func(5) == 15
def test_import_nested_function(self, temp_dir):
"""Test importing from nested module structure."""
# Create nested directory
(temp_dir / "subdir").mkdir()
module_path = temp_dir / "subdir" / "nested.py"
module_path.write_text("def nested_func():\n return 'nested'\n")
# Import from nested path
func = import_fun_from_str(f"{module_path.with_suffix('')}.nested_func")
assert callable(func)
assert func() == "nested"
def test_import_missing_module(self, temp_dir):
"""Test error when module doesn't exist."""
with pytest.raises(ImportError, match="Module file not found"):
import_fun_from_str(f"{temp_dir}/nonexistent.test_func")
def test_import_missing_function(self, temp_dir):
"""Test error when function doesn't exist in module."""
module_path = temp_dir / "test_module.py"
module_path.write_text("def other_func():\n pass\n")
with pytest.raises(AttributeError, match="Function 'missing_func' not found"):
import_fun_from_str(f"{module_path.with_suffix('')}.missing_func")
def test_import_invalid_format(self):
"""Test error with invalid path format."""
with pytest.raises(ValueError, match="Invalid path format"):
import_fun_from_str("/path/without/function")
def test_import_caching(self, temp_dir):
"""Test that modules are cached by mtime."""
# Clear any existing cache
import sys
keys_to_remove = [k for k in sys.modules if str(temp_dir) in k]
for k in keys_to_remove:
del sys.modules[k]
module_path = temp_dir / "cached_module.py"
module_path.write_text(
"call_count = 0\ndef func():\n global call_count\n call_count += 1\n return call_count\n"
)
# First import
func1 = import_fun_from_str(f"{module_path.with_suffix('')}.func")
_result1 = func1()
# Second import should use cached module
func2 = import_fun_from_str(f"{module_path.with_suffix('')}.func")
result2 = func2()
# Both should refer to the same module instance
assert func1 is func2
assert result2 == 2 # call_count incremented
class TestLoad:
"""Tests for the main YAML loading function with includes and function resolution."""
def test_load_simple_yaml(self, yaml_file):
content = """
task: test_task
description: A test task
metric: accuracy
"""
file_path = yaml_file(content)
result = load_yaml(file_path)
assert result["task"] == "test_task"
assert result["description"] == "A test task"
assert result["metric"] == "accuracy"
def test_load_with_function_resolved(self, yaml_file, python_module):
# Create a module with a function
python_module("def process_doc(doc):\n return doc.upper()\n")
content = """
task: test_task
doc_to_text: !function utils.process_doc
"""
file_path = yaml_file(content)
result = load_yaml(file_path, resolve_func=True)
assert callable(result["doc_to_text"])
assert result["doc_to_text"]("hello") == "HELLO"
def test_load_with_function_not_resolved(self, yaml_file):
content = """
task: test_task
doc_to_text: !function utils.process_doc
"""
file_path = yaml_file(content)
result = load_yaml(file_path, resolve_func=False)
assert isinstance(result["doc_to_text"], str)
# When resolve_functions=False, it returns the full path + function spec
assert result["doc_to_text"].endswith("utils.process_doc")
assert result["doc_to_text"] == str(file_path.parent / "utils.process_doc")
def test_load_with_includes(self, temp_dir, yaml_file):
"""Include files are merged with local values taking precedence."""
# Create included file with shared_value: 42
included_content = """
shared_metric: f1_score
shared_value: 42
"""
yaml_file(included_content, "included.yaml")
# Create main file that also defines shared_value: 100
main_content = """
include:
- included.yaml
task: main_task
shared_value: 100
"""
main_path = yaml_file(main_content, "main.yaml")
result = load_yaml(main_path, recursive=True)
assert result["task"] == "main_task"
assert result["shared_metric"] == "f1_score"
# Verify main file value (100) overrides included file value (42)
assert result["shared_value"] == 100 # Local wins
assert "include" not in result
def test_load_with_absolute_include(self, temp_dir, yaml_file):
# Create included file in different directory
other_dir = temp_dir / "other"
other_dir.mkdir()
included_path = other_dir / "included.yaml"
included_path.write_text("included_key: included_value\n")
# Create main file with absolute path
main_content = f"""
include:
- {included_path}
main_key: main_value
"""
main_path = yaml_file(main_content)
result = load_yaml(main_path, recursive=True)
assert result["main_key"] == "main_value"
assert result["included_key"] == "included_value"
def test_load_without_includes_resolution(self, yaml_file):
content = """
include:
- other.yaml
task: test_task
"""
file_path = yaml_file(content)
result = load_yaml(file_path, recursive=False)
assert result["include"] == ["other.yaml"]
assert result["task"] == "test_task"
def test_load_include_cycle_detection(self, temp_dir, yaml_file):
"""Circular includes should raise ValueError."""
# Create circular includes
yaml_file("include:\n - b.yaml\n", "a.yaml")
yaml_file("include:\n - c.yaml\n", "b.yaml")
yaml_file("include:\n - a.yaml\n", "c.yaml")
with pytest.raises(ValueError, match="Include cycle"):
load_yaml(temp_dir / "a.yaml")
def test_load_multiple_includes(self, temp_dir, yaml_file):
"""Multiple includes are processed in order, later values override earlier."""
# Create multiple included files
yaml_file("key1: value1\n", "inc1.yaml") # Sets key1 to "value1"
yaml_file(
"key2: value2\nmain_key: should_be_ignored\n", "inc2.yaml"
) # Tries to set main_key
yaml_file(
"key3: value3\nkey1: override\n", "inc3.yaml"
) # Overrides key1 to "override"
# Include order matters: inc3 comes after inc1, so its key1 value wins
main_content = """
include:
- inc1.yaml
- inc2.yaml
- inc3.yaml
main_key: main_value
"""
main_path = yaml_file(main_content)
result = load_yaml(main_path)
# Verify inc3's value overrides inc1's value for key1
assert result["key1"] == "override" # Last include wins
assert result["key2"] == "value2"
assert result["key3"] == "value3"
# Verify main file's value is NOT overridden by inc2.yaml
assert result["main_key"] == "main_value" # Main file wins over includes
def test_load_recursive_includes(self, temp_dir, yaml_file):
"""Includes can be recursive - inc1 can include inc2."""
# Create inc2.yaml (deepest level)
yaml_file(
"deep_key: deep_value\nshared_key: from_inc2\nshared_middle: inc2_middle\n",
"inc2.yaml",
)
# Create inc1.yaml that includes inc2.yaml
inc1_content = """include:
- inc2.yaml
middle_key: middle_value
shared_key: from_inc1
shared_middle: inc1_middle
"""
yaml_file(inc1_content, "inc1.yaml")
# Create main.yaml that includes inc1.yaml
main_content = """include:
- inc1.yaml
top_key: top_value
shared_key: from_main
"""
main_path = yaml_file(main_content, "main.yaml")
result = load_yaml(main_path)
# All keys should be present
assert result["deep_key"] == "deep_value" # From inc2
assert result["middle_key"] == "middle_value" # From inc1
assert result["top_key"] == "top_value" # From main
# Verify override order: main > inc1 > inc2
assert result["shared_key"] == "from_main" # Main wins
assert result["shared_middle"] == "inc1_middle" # inc1 wins over inc2
assert "include" not in result # Include directives removed
def test_load_expanduser_path(self, yaml_file):
"""Verifies that load() calls expanduser() on paths with ~."""
content = "test: value\n"
file_path = yaml_file(content)
# Mock expanduser to verify it's called and control the expansion
with patch.object(Path, "expanduser") as mock_expand:
mock_expand.return_value = file_path
result = load_yaml("~/test.yaml")
mock_expand.assert_called_once()
assert result["test"] == "value"
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])
import importlib # import importlib
import os # import os
import sys # import sys
from datetime import datetime # from datetime import datetime
from typing import List, Optional, Tuple # from typing import List, Optional, Tuple
#
import pytest # import pytest
import torch # import torch
#
from lm_eval.caching.cache import PATH # from lm_eval.caching.cache import PATH
#
#
MODULE_DIR = os.path.dirname(os.path.realpath(__file__)) # MODULE_DIR = os.path.dirname(os.path.realpath(__file__))
#
# NOTE the script this loads uses simple evaluate # # NOTE the script this loads uses simple evaluate
# TODO potentially test both the helper script and the normal script # # TODO potentially test both the helper script and the normal script
sys.path.append(f"{MODULE_DIR}/../scripts") # sys.path.append(f"{MODULE_DIR}/../scripts")
model_loader = importlib.import_module("requests_caching") # model_loader = importlib.import_module("requests_caching")
run_model_for_task_caching = model_loader.run_model_for_task_caching # run_model_for_task_caching = model_loader.run_model_for_task_caching
#
os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "1" # os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "1"
DEFAULT_TASKS = ["lambada_openai", "sciq"] # DEFAULT_TASKS = ["lambada_openai", "sciq"]
#
#
@pytest.fixture(autouse=True) # @pytest.fixture(autouse=True)
def setup_and_teardown(): # def setup_and_teardown():
# Setup # # Setup
torch.use_deterministic_algorithms(False) # torch.use_deterministic_algorithms(False)
clear_cache() # clear_cache()
# Yields control back to the test function # # Yields control back to the test function
yield # yield
# Cleanup here # # Cleanup here
#
#
def clear_cache(): # def clear_cache():
if os.path.exists(PATH): # if os.path.exists(PATH):
cache_files = os.listdir(PATH) # cache_files = os.listdir(PATH)
for file in cache_files: # for file in cache_files:
file_path = f"{PATH}/{file}" # file_path = f"{PATH}/{file}"
os.unlink(file_path) # os.unlink(file_path)
#
#
# leaving tasks here to allow for the option to select specific task files # # leaving tasks here to allow for the option to select specific task files
def get_cache_files(tasks: Optional[List[str]] = None) -> Tuple[List[str], List[str]]: # def get_cache_files(tasks: Optional[List[str]] = None) -> Tuple[List[str], List[str]]:
cache_files = os.listdir(PATH) # cache_files = os.listdir(PATH)
#
file_task_names = [] # file_task_names = []
#
for file in cache_files: # for file in cache_files:
file_without_prefix = file.split("-")[1] # file_without_prefix = file.split("-")[1]
file_without_prefix_and_suffix = file_without_prefix.split(".")[0] # file_without_prefix_and_suffix = file_without_prefix.split(".")[0]
file_task_names.extend([file_without_prefix_and_suffix]) # file_task_names.extend([file_without_prefix_and_suffix])
#
return cache_files, file_task_names # return cache_files, file_task_names
#
#
def assert_created(tasks: List[str], file_task_names: List[str]): # def assert_created(tasks: List[str], file_task_names: List[str]):
tasks.sort() # tasks.sort()
file_task_names.sort() # file_task_names.sort()
#
assert tasks == file_task_names # assert tasks == file_task_names
#
#
@pytest.mark.parametrize("tasks", [DEFAULT_TASKS]) # @pytest.mark.parametrize("tasks", [DEFAULT_TASKS])
def requests_caching_true(tasks: List[str]): # def requests_caching_true(tasks: List[str]):
run_model_for_task_caching(tasks=tasks, cache_requests="true") # run_model_for_task_caching(tasks=tasks, cache_requests="true")
#
cache_files, file_task_names = get_cache_files() # cache_files, file_task_names = get_cache_files()
print(file_task_names) # print(file_task_names)
assert_created(tasks=tasks, file_task_names=file_task_names) # assert_created(tasks=tasks, file_task_names=file_task_names)
#
#
@pytest.mark.parametrize("tasks", [DEFAULT_TASKS]) # @pytest.mark.parametrize("tasks", [DEFAULT_TASKS])
def requests_caching_refresh(tasks: List[str]): # def requests_caching_refresh(tasks: List[str]):
run_model_for_task_caching(tasks=tasks, cache_requests="true") # run_model_for_task_caching(tasks=tasks, cache_requests="true")
#
timestamp_before_test = datetime.now().timestamp() # timestamp_before_test = datetime.now().timestamp()
#
run_model_for_task_caching(tasks=tasks, cache_requests="refresh") # run_model_for_task_caching(tasks=tasks, cache_requests="refresh")
#
cache_files, file_task_names = get_cache_files() # cache_files, file_task_names = get_cache_files()
#
for file in cache_files: # for file in cache_files:
modification_time = os.path.getmtime(f"{PATH}/{file}") # modification_time = os.path.getmtime(f"{PATH}/{file}")
assert modification_time > timestamp_before_test # assert modification_time > timestamp_before_test
#
tasks.sort() # tasks.sort()
file_task_names.sort() # file_task_names.sort()
#
assert tasks == file_task_names # assert tasks == file_task_names
#
#
@pytest.mark.parametrize("tasks", [DEFAULT_TASKS]) # @pytest.mark.parametrize("tasks", [DEFAULT_TASKS])
def requests_caching_delete(tasks: List[str]): # def requests_caching_delete(tasks: List[str]):
# populate the data first, rerun this test within this test for additional confidence # # populate the data first, rerun this test within this test for additional confidence
# test_requests_caching_true(tasks=tasks) # # test_requests_caching_true(tasks=tasks)
#
run_model_for_task_caching(tasks=tasks, cache_requests="delete") # run_model_for_task_caching(tasks=tasks, cache_requests="delete")
#
cache_files, file_task_names = get_cache_files() # cache_files, file_task_names = get_cache_files()
#
assert len(cache_files) == 0 # assert len(cache_files) == 0
#
#
# useful for locally running tests through the debugger # # useful for locally running tests through the debugger
if __name__ == "__main__": # if __name__ == "__main__":
#
def run_tests(): # def run_tests():
tests = [ # tests = [
# test_requests_caching_true, # # test_requests_caching_true,
# test_requests_caching_refresh, # # test_requests_caching_refresh,
# test_requests_caching_delete, # # test_requests_caching_delete,
] # ]
# Lookups of global names within a loop is inefficient, so copy to a local variable outside of the loop first # # Lookups of global names within a loop is inefficient, so copy to a local variable outside of the loop first
default_tasks = DEFAULT_TASKS # default_tasks = DEFAULT_TASKS
for test_func in tests: # for test_func in tests:
clear_cache() # clear_cache()
test_func(tasks=default_tasks) # test_func(tasks=default_tasks)
#
print("Tests pass") # print("Tests pass")
#
run_tests() # run_tests()
"""Tests for the task index builder that discovers YAML task configurations.
Test coverage:
- TaskIndexBuilder._kind_of: identifies task/group/tag/task_list/py_task
- TaskIndexBuilder._iter_yaml_files: finds YAML files, ignores __pycache__
- TaskIndexBuilder._process_cfg: creates correct TaskEntry for each type
- TaskIndexBuilder._register_tags: creates TAG entries for task tags
- TaskIndexBuilder.build: discovers all task types in directory tree
"""
import tempfile
from pathlib import Path
import pytest
from lm_eval.tasks.index import Kind, TaskIndex
@pytest.fixture
def temp_dir():
with tempfile.TemporaryDirectory() as td:
yield Path(td)
@pytest.fixture
def yaml_file(temp_dir):
def _create_yaml(content, path="test.yaml"):
file_path = temp_dir / path
file_path.parent.mkdir(parents=True, exist_ok=True)
file_path.write_text(content)
return file_path
return _create_yaml
class TestKindOf:
"""Tests for identifying task configuration types."""
def test_kind_of_task(self):
"""Single task with string name."""
cfg = {"task": "my_task", "dataset_path": "data"}
assert TaskIndex._kind_of(cfg) == Kind.TASK
def test_kind_of_group(self):
"""Group has task as list."""
cfg = {"task": ["task1", "task2"], "group": "my_group"}
assert TaskIndex._kind_of(cfg) == Kind.GROUP
def test_kind_of_py_task(self):
"""Python task has class field."""
cfg = {"task": "my_task", "class": "tasks.MyTask"}
assert TaskIndex._kind_of(cfg) == Kind.PY_TASK
def test_kind_of_task_list(self):
"""Task list has task_list field."""
cfg = {"task_list": ["task1", "task2"]}
assert TaskIndex._kind_of(cfg) == Kind.TASK_LIST
def test_kind_of_unknown(self):
"""Unknown config raises ValueError."""
cfg = {"unknown": "field"}
with pytest.raises(ValueError, match="Unknown config shape"):
TaskIndex._kind_of(cfg)
class TestIterYamlFiles:
"""Tests for YAML file discovery."""
def test_iter_yaml_files_simple(self, temp_dir):
"""Finds .yaml files in directory tree."""
# Create some yaml files
(temp_dir / "task1.yaml").touch()
(temp_dir / "subdir").mkdir()
(temp_dir / "subdir" / "task2.yaml").touch()
(temp_dir / "other.txt").touch()
builder = TaskIndex()
yaml_files = list(builder._iter_yaml_files(temp_dir))
assert len(yaml_files) == 2
names = {f.name for f in yaml_files}
assert names == {"task1.yaml", "task2.yaml"}
def test_iter_yaml_files_ignores_pycache(self, temp_dir):
"""Ignores files in __pycache__ directories."""
(temp_dir / "task.yaml").touch()
(temp_dir / "__pycache__").mkdir()
(temp_dir / "__pycache__" / "ignored.yaml").touch()
(temp_dir / ".ipynb_checkpoints").mkdir()
(temp_dir / ".ipynb_checkpoints" / "also_ignored.yaml").touch()
builder = TaskIndex()
yaml_files = list(builder._iter_yaml_files(temp_dir))
assert len(yaml_files) == 1
assert yaml_files[0].name == "task.yaml"
class TestProcessCfg:
"""Tests for processing individual config files."""
def test_process_task(self, temp_dir):
"""Regular task creates TASK entry."""
cfg = {"task": "my_task", "tag": ["tag1", "tag2"]}
path = temp_dir / "task.yaml"
index = {}
builder = TaskIndex()
builder.process_cfg(cfg, path, index)
assert "my_task" in index
entry = index["my_task"]
assert entry.name == "my_task"
assert entry.kind == Kind.TASK
assert entry.yaml_path == path
assert entry.tags == {"tag1", "tag2"}
def test_process_group(self, temp_dir):
"""Group creates GROUP entry."""
cfg = {"task": ["t1", "t2"], "group": "my_group", "tag": ["grp_tag"]}
path = temp_dir / "group.yaml"
index = {}
builder = TaskIndex()
builder.process_cfg(cfg, path, index)
assert "my_group" in index
entry = index["my_group"]
assert entry.name == "my_group"
assert entry.kind == Kind.GROUP
assert entry.yaml_path == path
assert entry.tags == {"grp_tag"}
def test_process_py_task(self, temp_dir):
"""Python task creates PY_TASK entry."""
cfg = {"task": "py_task", "class": "MyTask", "tag": ["py_tag"]}
path = temp_dir / "py_task.yaml"
index = {}
builder = TaskIndex()
builder.process_cfg(cfg, path, index)
assert "py_task" in index
entry = index["py_task"]
assert entry.name == "py_task"
assert entry.kind == Kind.PY_TASK
assert entry.yaml_path is None # Python tasks don't store yaml_path
assert entry.tags == {"py_tag"}
def test_process_task_list(self, temp_dir):
"""Task list creates entries for each task."""
cfg = {
"task_list": [
"simple_task",
{"task": "complex_task", "tag": ["tag1", "tag2"]},
],
}
path = temp_dir / "list.yaml"
index = {}
builder = TaskIndex()
# The implementation has a bug - it calls entry.get() on string entries
# This test documents the current behavior which will fail
with pytest.raises(AttributeError, match="'str' object has no attribute 'get'"):
builder.process_cfg(cfg, path, index)
def test_process_task_list_dict_entries(self, temp_dir):
"""Task list with only dict entries works."""
cfg = {
"task_list": [
{"task": "task1"},
{"task": "task2", "tag": ["tag1", "tag2"]},
],
}
path = temp_dir / "list.yaml"
index = {}
builder = TaskIndex()
builder.process_cfg(cfg, path, index)
# Task without tags
assert "task1" in index
task1 = index["task1"]
assert task1.kind == Kind.TASK
assert task1.yaml_path == path
assert task1.tags == set()
# Task with tags
assert "task2" in index
task2 = index["task2"]
assert task2.kind == Kind.TASK
assert task2.yaml_path == path
assert task2.tags == {"tag1", "tag2"}
class TestRegisterTags:
"""Tests for tag registration."""
def test_register_single_tag(self):
"""Single tag creates TAG entry."""
index = {}
builder = TaskIndex()
builder._register_tags("task1", "my_tag", index)
assert "my_tag" in index
tag_entry = index["my_tag"]
assert tag_entry.kind == Kind.TAG
assert tag_entry.yaml_path is None
assert "task1" in tag_entry.tags # TAG entries use tags set for task names
def test_register_multiple_tags(self):
"""Multiple tags create multiple TAG entries."""
index = {}
builder = TaskIndex()
builder._register_tags("task1", ["tag1", "tag2"], index)
assert "tag1" in index
assert "tag2" in index
assert "task1" in index["tag1"].tags
assert "task1" in index["tag2"].tags
def test_register_tags_accumulates(self):
"""Multiple tasks can have same tag."""
index = {}
builder = TaskIndex()
builder._register_tags("task1", "shared_tag", index)
builder._register_tags("task2", "shared_tag", index)
assert "shared_tag" in index
tag_entry = index["shared_tag"]
assert tag_entry.tags == {"task1", "task2"}
class TestBuild:
"""Tests for the main build method."""
def test_build_empty_directory(self, temp_dir):
"""Empty directory returns empty index."""
builder = TaskIndex()
index = builder.build([temp_dir])
assert index == {}
def test_build_single_task(self, temp_dir, yaml_file):
"""Single task file is discovered."""
yaml_file("task: my_task\ndataset_path: data\n")
builder = TaskIndex()
index = builder.build([temp_dir])
assert len(index) == 1
assert "my_task" in index
assert index["my_task"].kind == Kind.TASK
def test_build_mixed_types(self, temp_dir, yaml_file):
"""Discovers various task types."""
# Regular task with list tag format
yaml_file("task: task1\ntag: [common]\n", "task1.yaml")
# Group
yaml_file("task: [t1, t2]\ngroup: group1\n", "group1.yaml")
# Task list with only dict entries (to avoid the bug)
yaml_file(
"task_list:\n - task: task2\n - task: task3\n tag: [common]\n",
"list.yaml",
)
# Python task
yaml_file("task: py_task\nclass: MyClass\n", "python.yaml")
builder = TaskIndex()
index = builder.build([temp_dir])
# Check all entries exist
assert "task1" in index
assert "group1" in index
assert "task2" in index
assert "task3" in index
assert "py_task" in index
assert "common" in index # Tag entry
# Check types
assert index["task1"].kind == Kind.TASK
assert index["group1"].kind == Kind.GROUP
assert index["task2"].kind == Kind.TASK
assert index["task3"].kind == Kind.TASK
assert index["py_task"].kind == Kind.PY_TASK
assert index["common"].kind == Kind.TAG
# Check tag has both tasks
assert index["common"].tags == {"task1", "task3"}
def test_build_nested_directories(self, temp_dir, yaml_file):
"""Discovers tasks in nested directories."""
yaml_file("task: root_task\n", "root.yaml")
yaml_file("task: sub_task\n", "subdir/sub.yaml")
yaml_file("task: deep_task\n", "subdir/deeper/deep.yaml")
builder = TaskIndex()
index = builder.build([temp_dir])
assert len(index) == 3
assert all(name in index for name in ["root_task", "sub_task", "deep_task"])
def test_build_skips_invalid_yaml(self, temp_dir, yaml_file):
"""Skips files that fail to parse."""
yaml_file("task: valid_task\n", "valid.yaml")
yaml_file("invalid: [\n", "invalid.yaml") # Invalid YAML
builder = TaskIndex()
index = builder.build([temp_dir])
assert len(index) == 1
assert "valid_task" in index
def test_build_multiple_paths(self, temp_dir):
"""Can search multiple root paths."""
# Create two separate directories
dir1 = temp_dir / "dir1"
dir2 = temp_dir / "dir2"
dir1.mkdir()
dir2.mkdir()
(dir1 / "task1.yaml").write_text("task: task1\n")
(dir2 / "task2.yaml").write_text("task: task2\n")
builder = TaskIndex()
index = builder.build([dir1, dir2])
assert len(index) == 2
assert "task1" in index
assert "task2" in index
...@@ -64,10 +64,10 @@ def test_python_task_inclusion( ...@@ -64,10 +64,10 @@ def test_python_task_inclusion(
verbosity="INFO", include_path=str(custom_task_files_dir) verbosity="INFO", include_path=str(custom_task_files_dir)
) )
# check if python tasks enters the global task_index # check if python tasks enters the global task_index
assert custom_task_name in task_manager.task_index assert custom_task_name in task_manager._index
# check if subtask is present # check if subtask is present
assert custom_task_name in task_manager.all_subtasks assert custom_task_name in task_manager._index
# check if tag is present # check if tag is present
assert custom_task_tag in task_manager.all_tags assert custom_task_tag in task_manager._index
# check if it can be loaded by tag (custom_task_tag) # check if it can be loaded by tag (custom_task_tag)
assert custom_task_name in task_manager.load_task_or_group(custom_task_tag) assert custom_task_name in task_manager.load_task_or_group(custom_task_tag)
#!/usr/bin/env python3
"""
Walkthrough tests using real dataset configurations.
These tests use YAML configs with existing datasets (hellaswag) to enable
complete code walkthrough of the task loading system, including:
- Basic task loading
- Task list functionality
- Group functionality
- Include inheritance
- Issue #2158 fix (include processing preserving task names)
"""
import os
import pytest
from lm_eval.tasks import TaskManager, get_task_dict
class TestWalkthroughConfigs:
"""Test walkthrough configurations for easier code demonstration"""
@pytest.fixture(autouse=True)
def setup_task_manager(self):
"""Set up TaskManager with test configs directory"""
test_configs_dir = os.path.join(os.path.dirname(__file__), "test_configs")
self.tm = TaskManager(include_path=test_configs_dir, include_defaults=False)
def test_simple_task_loading(self):
"""Test basic task loading - walkthrough starting point"""
# Simple task should be indexed
assert "simple_task" in self.tm.all_tasks
assert self.tm._name_is_task("simple_task")
# Load the task
task_dict = get_task_dict(["simple_task"], task_manager=self.tm)
assert "simple_task" in task_dict
# Verify task configuration
task_obj = task_dict["simple_task"]
assert hasattr(task_obj, "config")
assert task_obj.config.task == "simple_task"
def test_task_list_functionality(self):
"""Test task_list feature - multiple tasks sharing config"""
# All task_list tasks should be indexed as individual tasks
expected_tasks = ["task_list_fs0", "task_list_fs1", "task_list_fs3"]
for task_name in expected_tasks:
assert task_name in self.tm.all_tasks, f"Task {task_name} not indexed"
assert self.tm._name_is_task(task_name), (
f"Task {task_name} not recognized as task"
)
# Load all tasks from the task_list
task_dict = get_task_dict(expected_tasks, task_manager=self.tm)
# Each should be a separate task object
assert len(task_dict) == 3
for task_name in expected_tasks:
assert task_name in task_dict
task_obj = task_dict[task_name]
assert task_obj.config.task == task_name
# Verify different num_fewshot values were applied
assert task_dict["task_list_fs0"].config.num_fewshot == 0
assert task_dict["task_list_fs1"].config.num_fewshot == 1
assert task_dict["task_list_fs3"].config.num_fewshot == 3
def test_group_functionality(self):
"""Test group loading with task-specific overrides"""
# Group should be indexed
assert "test_group" in self.tm.all_groups
assert self.tm._name_is_group("test_group")
# Load the group
task_dict = get_task_dict(["test_group"], task_manager=self.tm)
# Should contain the group object and its subtasks
assert len(task_dict) == 1
group_obj = list(task_dict.keys())[0]
subtasks = task_dict[group_obj]
# Check expected subtasks
expected_subtasks = ["group_task_fs0", "group_task_fs2"]
for subtask_name in expected_subtasks:
assert subtask_name in subtasks
# Verify different configurations were applied
fs0_task = subtasks["group_task_fs0"]
fs2_task = subtasks["group_task_fs2"]
assert fs0_task.config.num_fewshot == 0
assert fs2_task.config.num_fewshot == 2
def test_include_inheritance(self):
"""Test include functionality and inheritance"""
# Test direct include tasks (these were created as separate files)
include_tasks = ["include_task_fs0", "include_task_fs1", "include_task_fs5"]
for task_name in include_tasks:
assert task_name in self.tm.all_tasks
# Load tasks that use include
task_dict = get_task_dict(
include_tasks[:1], task_manager=self.tm
) # Just test first one
# Should inherit from base config
task_obj = task_dict["include_task_fs0"]
# Should inherit dataset_path from include
assert task_obj.config.dataset_path == "json"
# Should inherit output_type from include
assert task_obj.config.output_type == "multiple_choice"
# Should preserve specific task name (not base_task_name)
assert task_obj.config.task == "include_task_fs0"
# Should have overridden num_fewshot
assert task_obj.config.num_fewshot == 0
def test_issue_2158_fix_demo(self):
"""
Test issue #2158 fix - multiple tasks with same include in group.
This demonstrates the specific scenario that was failing before the fix.
"""
# Group with multiple tasks using same include should work
assert "include_group" in self.tm.all_groups
# This should NOT raise a duplicate detection error
# Before the fix, this would fail with:
# "Please call groups which overlap their constituent tasks in separate evaluation runs"
task_dict = get_task_dict(["include_group"], task_manager=self.tm)
# Should successfully load the group
assert len(task_dict) == 1
group_obj = list(task_dict.keys())[0]
subtasks = task_dict[group_obj]
# Check all expected tasks are present with correct names
expected_tasks = ["include_task_fs0", "include_task_fs1", "include_task_fs5"]
for task_name in expected_tasks:
assert task_name in subtasks, f"Task {task_name} missing from group"
task_obj = subtasks[task_name]
# CRITICAL: Task name should be preserved, not overwritten by include
assert task_obj.config.task == task_name
# Should inherit base config from include
assert task_obj.config.dataset_path == "json"
assert task_obj.config.output_type == "multiple_choice"
# Verify different num_fewshot values
assert subtasks["include_task_fs0"].config.num_fewshot == 0
assert subtasks["include_task_fs1"].config.num_fewshot == 1
assert subtasks["include_task_fs5"].config.num_fewshot == 5
def test_config_types_detection(self):
"""Test that different config types are correctly detected"""
# Load various config types to test detection methods
configs = [
# Simple task config
{"task": "walkthrough_simple_task"},
# Group config
{"group": "test_group", "task": ["task1", "task2"]},
# Task list config (would need to be loaded from file)
]
# Test config detection methods
assert self.tm._config_is_task(configs[0])
assert not self.tm._config_is_group()
assert not self.tm._config_is_task_list(configs[0])
assert not self.tm._config_is_task(configs[1])
assert self.tm._config_is_group()
assert not self.tm._config_is_task_list(configs[1])
# Test task_list detection with actual config
task_list_config = {"task_list": [{"task": "task1"}, {"task": "task2"}]}
assert self.tm._config_is_task_list(task_list_config)
assert not self.tm._config_is_task(task_list_config)
assert not self.tm._config_is_group()
if __name__ == "__main__":
pytest.main([__file__, "-v"])
...@@ -12,6 +12,7 @@ from lm_eval.api.metrics import ( ...@@ -12,6 +12,7 @@ from lm_eval.api.metrics import (
) )
from lm_eval.models.utils import Collator from lm_eval.models.utils import Collator
from lm_eval.utils import ( from lm_eval.utils import (
apply_template,
get_rolling_token_windows, get_rolling_token_windows,
make_disjoint_window, make_disjoint_window,
) )
...@@ -396,3 +397,95 @@ def test_aggregate_stderrs(samples): ...@@ -396,3 +397,95 @@ def test_aggregate_stderrs(samples):
mean_stderr(list(itertools.chain.from_iterable(samples))), mean_stderr(list(itertools.chain.from_iterable(samples))),
atol=1.0e-3, atol=1.0e-3,
) )
def test_apply_template():
"""Test the apply_template function with various scenarios."""
# Test basic variable substitution
result = apply_template("Hello {{name}}!", {"name": "World"})
assert result == "Hello World!"
# Test multiple variables
result = apply_template(
"{{greeting}} {{name}}!", {"greeting": "Hi", "name": "Alice"}
)
assert result == "Hi Alice!"
# Test missing variable (should raise error due to StrictUndefined)
with pytest.raises(Exception): # Jinja2 will raise UndefinedError
apply_template("Hello {{missing}}!", {})
# Test empty template
result = apply_template("", {})
assert result == ""
# Test template with no variables
result = apply_template("Static text", {"unused": "variable"})
assert result == "Static text"
# Test numeric variables
result = apply_template("Count: {{count}}", {"count": 42})
assert result == "Count: 42"
# Test boolean variables
result = apply_template("Flag: {{flag}}", {"flag": True})
assert result == "Flag: True"
# Test list variables
result = apply_template("Items: {{items}}", {"items": [1, 2, 3]})
assert result == "Items: [1, 2, 3]"
# Test regex_replace filter
result = apply_template(
"{{text | regex_replace('[0-9]+', 'X')}}", {"text": "abc123def456"}
)
assert result == "abcXdefX"
# Test regex_replace with count parameter
result = apply_template(
"{{text | regex_replace('[0-9]+', 'X', 1)}}", {"text": "abc123def456"}
)
assert result == "abcXdef456"
# Test complex template with loops
result = apply_template(
"{% for item in items %}{{item}} {% endfor %}", {"items": ["a", "b", "c"]}
)
assert result == "a b c "
# Test conditional template
result = apply_template("{% if flag %}Yes{% else %}No{% endif %}", {"flag": True})
assert result == "Yes"
result = apply_template("{% if flag %}Yes{% else %}No{% endif %}", {"flag": False})
assert result == "No"
# Test whitespace handling (keep_trailing_newline=True)
result = apply_template("Line 1\nLine 2\n", {})
assert result == "Line 1\nLine 2\n"
def test_apply_template_lazy_initialization():
"""Test that the Jinja2 Environment is lazily initialized."""
# Clear any existing environment to test fresh initialization
if hasattr(apply_template, "_env"):
delattr(apply_template, "_env")
# Environment should not exist before first call
assert not hasattr(apply_template, "_env")
# First call should create the environment
apply_template("{{test}}", {"test": "value"})
assert hasattr(apply_template, "_env")
# Store reference to the environment
env = apply_template._env
# Second call should reuse the same environment
apply_template("{{test}}", {"test": "value"})
assert apply_template._env is env # Same object reference
# Environment should have the custom regex_replace filter
assert "regex_replace" in apply_template._env.filters
import os import os
from typing import List, Union from typing import List, Union
from lm_eval.utils import load_yaml_config from lm_eval.tasks._config_loader import load_yaml
# {{{CI}}} # {{{CI}}}
...@@ -12,7 +12,7 @@ from lm_eval.utils import load_yaml_config ...@@ -12,7 +12,7 @@ from lm_eval.utils import load_yaml_config
# reads a text file and returns a list of words # reads a text file and returns a list of words
# used to read the output of the changed txt from tj-actions/changed-files # used to read the output of the changed txt from tj-actions/changed-files
def load_changed_files(file_path: str) -> List[str]: def load_changed_files(file_path: str) -> List[str]:
with open(file_path, "r", encoding="utf-8") as f: with open(file_path, encoding="utf-8") as f:
content = f.read() content = f.read()
words_list = list(content.split()) words_list = list(content.split())
return words_list return words_list
...@@ -26,7 +26,7 @@ def parser(full_path: List[str]) -> List[str]: ...@@ -26,7 +26,7 @@ def parser(full_path: List[str]) -> List[str]:
_output = set() _output = set()
for x in full_path: for x in full_path:
if x.endswith(".yaml") and os.path.exists(x): if x.endswith(".yaml") and os.path.exists(x):
config = load_yaml_config(x, mode="simple") config = load_yaml(x, recursive=True, resolve_func=True)
if isinstance(config["task"], str): if isinstance(config["task"], str):
_output.add(config["task"]) _output.add(config["task"])
elif isinstance(config["task"], list): elif isinstance(config["task"], list):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment