Commit 4254c7bd authored by Baber's avatar Baber
Browse files

add task factory

parent eec9de3e
......@@ -29,7 +29,7 @@ repos:
- id: mixed-line-ending
args: [--fix=lf]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.2
rev: v0.12.5
hooks:
# Run the linter.
- id: ruff
......
from dataclasses import asdict, dataclass
from inspect import getsource
from typing import Callable, List, Optional, Union
from typing import Callable, Optional, Union
from datasets.features.pdf import field
@dataclass
......@@ -25,9 +27,9 @@ class AggMetricConfig(dict):
class GroupConfig:
group: Optional[str] = None
group_alias: Optional[str] = None
task: Optional[Union[str, list]] = None
task: Union[str, list] = field(default_factory=list)
aggregate_metric_list: Optional[
Union[List[AggMetricConfig], AggMetricConfig, dict]
Union[list[AggMetricConfig], AggMetricConfig, dict]
] = None
metadata: Optional[dict] = (
None # by default, not used in the code. allows for users to pass arbitrary info to tasks
......
This diff is collapsed.
......@@ -3,29 +3,31 @@ from __future__ import annotations
import functools
import importlib.util
import sys
from collections.abc import Callable
from pathlib import Path
from typing import Any
import yaml
_Base = yaml.CLoader if getattr(yaml, "__with_libyaml__", False) else yaml.FullLoader
_Base = (
yaml.CSafeLoader if getattr(yaml, "__with_libyaml__", False) else yaml.FullLoader
)
_IGNORE_DIRS = {"__pycache__", ".ipynb_checkpoints"}
# --------------------------------------------------------------------------- helpers
@functools.lru_cache(128)
def _mk_function_ctor(base_dir: Path, resolve: bool):
def ctor(loader: yaml.Loader, node: yaml.Node):
spec = loader.construct_scalar(node) # type: ignore[arg-type]
if not resolve:
return lambda *_, **__: None
return _import_function(spec, base_dir)
return str(base_dir.expanduser() / spec)
return _import_func_in_yml(spec, base_dir)
return ctor
@functools.lru_cache(maxsize=1024)
@functools.lru_cache(maxsize=512)
def _make_loader(base_dir: Path, *, resolve_funcs: bool) -> type[yaml.Loader]:
class Loader(_Base): ... # type: ignore[no-redef]
......@@ -37,8 +39,14 @@ def _make_loader(base_dir: Path, *, resolve_funcs: bool) -> type[yaml.Loader]:
return Loader
@functools.lru_cache(maxsize=4096)
def _import_function(qual: str, base_dir: Path):
@functools.lru_cache(maxsize=128)
def _import_func_in_yml(qual: str, base_dir: Path):
"""Import function from qual: utils.process_doc, checking local files first then standard imports.
Args:
qual: Qualified function name (e.g., 'utils.process_doc')
base_dir: Directory to search for local modules
"""
mod_path, _, fn_name = qual.rpartition(".")
# 1) relative “utils.py” next to YAML
rel = (base_dir / f"{mod_path.replace('.', '/')}.py").resolve()
......@@ -47,26 +55,74 @@ def _import_function(qual: str, base_dir: Path):
key = f"{rel}:{mtime}" # one module per mtime
if key not in sys.modules:
spec = importlib.util.spec_from_file_location(key, rel)
if spec is None or spec.loader is None:
raise ImportError(f"Cannot load module from {rel}") from None
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod) # type: ignore[arg-type]
sys.modules[key] = mod
return getattr(sys.modules[key], fn_name)
# 2) alreadyimportable module
# 2) already-importable module
module = __import__(mod_path, fromlist=[fn_name])
return getattr(module, fn_name)
# --------------------------------------------------------------------- public API
@functools.lru_cache(maxsize=128)
def _import_fun_from_str(path_str: str) -> Any:
"""Import a function from a string in the form '/absolute/path/to/module.function_name'."""
try:
# Split off the function name from the rightmost dot
module_path_str, function_name = path_str.rsplit(".", 1)
except ValueError as e:
raise ValueError(
f"Invalid path format: {path_str}. Expected format: /path/to/module.function_name"
) from e
# Convert to Path and handle .py extension
module_path = Path(module_path_str)
if not module_path.suffix:
module_path = module_path.with_suffix(".py")
elif module_path.suffix != ".py":
# If it has a non-.py suffix, the user might have included .py in the path
# e.g., "/path/to/module.py.function_name"
base_path = module_path.with_suffix("")
if base_path.with_suffix(".py").exists():
module_path = base_path.with_suffix(".py")
if not module_path.exists():
raise ImportError(f"Module file not found: {module_path}")
# Use similar approach to _import_func_in_yml for consistency
mtime = module_path.stat().st_mtime_ns
cache_key = f"{module_path}:{mtime}"
if cache_key not in sys.modules:
spec = importlib.util.spec_from_file_location(cache_key, module_path)
if spec is None or spec.loader is None:
raise ImportError(f"Cannot load module from {module_path}") from None
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
sys.modules[cache_key] = module
module = sys.modules[cache_key]
if not hasattr(module, function_name):
raise AttributeError(
f"Function '{function_name}' not found in module {module_path}"
)
return getattr(module, function_name)
def load_yaml(
path: str | Path,
*,
resolve_functions: bool = True,
resolve_includes: bool = True,
_seen: set[Path] | None = None,
) -> dict[str, str | Callable[..., Any]]:
"""Pure dataloading helper.
Returns a dict ready for higherlevel interpretation.
) -> dict[str, Any]:
"""Pure data-loading helper.
Returns a dict ready for higher-level interpretation.
•No task/group/tag semantics here.
"""
path = Path(path).expanduser().resolve()
......@@ -82,9 +138,11 @@ def load_yaml(
if not resolve_includes or "include" not in cfg:
return cfg
else:
includes = cfg.pop("include")
merged = {}
for inc in cfg.pop("include"):
for inc in includes if isinstance(includes, list) else [includes]:
inc_path = (path.parent / inc) if not Path(inc).is_absolute() else Path(inc)
merged.update(
load_yaml(
......
from __future__ import annotations
import inspect
from collections.abc import Mapping
from copy import deepcopy
from functools import lru_cache
from typing import Any
from lm_eval.api.group import GroupConfig
from lm_eval.api.task import ConfigurableTask, Task # noqa: F401 (typing)
from lm_eval.tasks._config_loader import load_yaml as load_cfg
from lm_eval.tasks.index import Entry, Kind
# Memoized wrapper around load_yaml: repeated builds of the same YAML path hit
# memory instead of re-reading disk (keyed on all call args, incl. the resolve flags).
load_cfg_cached = lru_cache(maxsize=512)(load_cfg)  # type: ignore[no-redef]
class TaskFactory:
    """
    Turns an *Entry* (plus optional overrides) into a
    *Task* | *ConfigurableTask* | *GroupConfig* hierarchy.
    """

    def __init__(self, *, meta: dict[str, Any] | None = None):
        # Metadata merged into every built config; its keys win over YAML-provided ones.
        self._meta = meta or {}

    # ---------------------------------------------------------------- public API
    def build(
        self,
        entry: Entry,
        *,
        overrides: dict[str, Any] | None = None,
        registry: Mapping[str, Entry],
    ):
        """
        • entry.kind == TASK / PY_TASK ➜ returns instantiated task object
        • entry.kind == GROUP ➜ returns (GroupConfig, mapping-of-subtasks)
        • entry.kind == TAG ➜ returns mapping-of-tasks (tag expansion)
        """
        if entry.kind is Kind.TAG:
            return self._build_tag(entry, overrides, registry)
        if entry.kind is Kind.GROUP:
            return self._build_group(entry, overrides, registry)
        return self._build_task(entry, overrides)

    def _build_task(self, entry: Entry, overrides: dict[str, Any] | None):
        """Instantiate one YAML- or python-defined task from its full config."""
        cfg = self._load_full_config(entry, overrides)

        if "class" in cfg:  # PY_TASK route
            cls = cfg["class"]
            # Pass the config through only when the ctor actually accepts it.
            obj = cls(config=cfg) if _ctor_accepts_config(cls) else cls()
            if isinstance(obj, ConfigurableTask):
                obj.config.task = entry.name
            return obj

        # YAML task
        return ConfigurableTask(config=cfg)  # type: ignore[arg-type]

    def _build_group(
        self,
        entry: Entry,
        overrides: dict[str, Any] | None,
        registry: Mapping[str, Entry],
    ):
        """Build a GroupConfig and, recursively, all of its children."""
        raw_cfg = self._load_full_config(entry, None)
        grp_cfg = {k: v for k, v in raw_cfg.items() if k in GroupConfig.__annotations__}
        # `metadata` may be explicitly null in YAML — fall back to {} before merging.
        grp_cfg["metadata"] = (grp_cfg.get("metadata") or {}) | self._meta
        group_obj = GroupConfig(**grp_cfg)

        children: dict[str, Any] = {}
        for item in group_obj.task:
            if isinstance(item, str):  # task: hellaswag
                child = self.build(
                    registry[item],
                    overrides=overrides,  # group-level overrides propagate
                    registry=registry,
                )
            elif isinstance(item, dict):  # task: {task: hellaswag, num_fewshot: 5}
                base_name = item["task"]
                child = self.build(
                    registry[base_name],
                    overrides=item,  # per-item override
                    registry=registry,
                )
            else:
                raise TypeError(
                    f"Unsupported sub-entry {item!r} in group '{entry.name}'"
                )
            # `child` itself is a mapping (task-name -> obj) or {GroupConfig: ...}
            # NOTE(review): for a plain TASK child, build() returns the task object
            # itself rather than a mapping — confirm that case cannot occur here,
            # or wrap it as {child_name: obj} before the update.
            children.update(child)
        return {group_obj: children}

    def _build_tag(
        self,
        entry: Entry,
        overrides: dict[str, Any] | None,
        registry: Mapping[str, Entry],
    ):
        """Expand a tag into {task_name: task_obj} for every task carrying it."""
        return {
            name: self._build_task(registry[name], overrides) for name in entry.tags
        }

    def _load_full_config(
        self, entry: Entry, overrides: dict[str, Any] | None
    ) -> dict[str, Any]:
        """Load the entry's YAML (cached), apply overrides, then merge metadata.

        Returns a fresh dict safe to mutate (the cached copy is deep-copied).
        """
        if entry.yaml_path:
            # deepcopy: the cached dict is shared across builds and must not be mutated.
            cfg = deepcopy(load_cfg_cached(entry.yaml_path, resolve_functions=True))
        else:
            cfg = {"metadata": {"config": "unknown"}}  # python task without YAML
        if overrides:
            cfg = {**cfg, **overrides}
        # Non-dict metadata is wrapped under "_metadata" so the merge always succeeds.
        cfg["metadata"] = (
            m if isinstance(m := cfg.get("metadata", {}), dict) else {"_metadata": m}
        ) | self._meta
        cfg.setdefault("task", entry.name)
        return cfg
def _ctor_accepts_config(cls) -> bool:
init = getattr(cls, "__init__", None)
return init and "config" in inspect.signature(init).parameters
# lm_eval/task_index.py (continued)
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from lm_eval.tasks._config_loader import load_yaml as load_cfg
......@@ -14,137 +13,159 @@ if TYPE_CHECKING:
from pathlib import Path
class TaskKind(Enum):
class Kind(Enum):
TASK = auto() # YAML task, or task_list entry
PY_TASK = auto() # Pythondefined, via "class"
PY_TASK = auto() # Python-defined, via "class"
GROUP = auto()
TAG = auto()
TASK_LIST = auto()
@dataclass
class TaskEntry:
class Entry:
name: str
kind: TaskKind
yaml_path: Path | None # None for generated / py‑only entries
kind: Kind
yaml_path: Path | None # None for generated / py-only entries
cfg: dict[str, str] | None = None
tags: set[str] = field(default_factory=set)
task_list_path: Path | None = None # only for GROUP / TAG when lazy‑loaded
task_list_path: Path | None = None
log = logging.getLogger(__name__)
_IGNORE_DIRS = {"__pycache__", ".ipynb_checkpoints"}
class TaskIndexBuilder:
class TaskIndex:
"""Walks one or more directories, parses YAML quickly (functions unresolved),
and produces a mapping {task_name: TaskEntry}.
and produces a mapping {task_name: Entry}.
"""
def __init__(self, *, metadata: dict | None = None) -> None:
self._metadata = metadata or {}
def __init__(self, *, meta: dict[str, str] | None = None) -> None:
self._metadata = meta or {}
# ------------- public API --------------------------------------------------
def build(
self,
paths: Iterable[Path],
# include_defaults: bool = True,
) -> dict[str, TaskEntry]:
index: dict[str, TaskEntry] = {}
*,
resolve_includes=False,
) -> dict[str, Entry]:
index: dict[str, Entry] = {}
log.debug("Building task index from %s", paths)
for root in paths:
for yaml_path in self._iter_yaml_files(root):
try:
cfg = load_cfg(
yaml_path,
resolve_functions=False,
resolve_includes=False,
resolve_includes=resolve_includes,
)
self.process_cfg(cfg, yaml_path, index)
except Exception as err:
log.debug("Skip %s (%s)", yaml_path, err)
continue
self._process_cfg(cfg, yaml_path, index)
# self._process_cfg(cfg, yaml_path, index)
log.debug("Built task index with %d entries", len(index))
return index
# ------------- helpers -----------------------------------------------------
def _iter_yaml_files(self, root: Path):
@staticmethod
def _iter_yaml_files(root: Path):
yield from (
p
for p in root.glob("**/*.yaml")
if not any(part in _IGNORE_DIRS for part in p.parts)
)
# ---------------------------------------------------------------------------
def _process_cfg(
self,
cfg: dict,
@staticmethod
def process_cfg(
cfg: dict[str, Any],
path: Path,
index: dict[str, TaskEntry],
index: dict[str, Entry],
) -> None:
kind = self._kind_of(cfg)
if kind is TaskKind.GROUP:
kind = TaskIndex._kind_of(cfg)
if kind is Kind.GROUP:
grp_name = cfg["group"]
index[grp_name] = TaskEntry(
index[grp_name] = Entry(
name=grp_name,
kind=TaskKind.GROUP,
kind=Kind.GROUP,
yaml_path=path,
tags=set(cfg.get("tag", [])),
tags=TaskIndex._str_to_set(cfg.get("tag")),
cfg=cfg,
)
return
if kind is TaskKind.PY_TASK:
if kind is Kind.PY_TASK:
name = cfg["task"]
index[name] = TaskEntry(
index[name] = Entry(
name=name,
kind=TaskKind.PY_TASK,
kind=Kind.PY_TASK,
yaml_path=None,
tags=set(cfg.get("tag", [])),
tags=TaskIndex._str_to_set(cfg.get("tag")),
cfg=cfg,
)
self._register_tags(name, cfg.get("tag", []), index)
TaskIndex._register_tags(name, cfg.get("tag"), index)
return
if kind is TaskKind.TASK:
if kind is Kind.TASK:
name = cfg["task"]
index[name] = TaskEntry(
index[name] = Entry(
name=name,
kind=TaskKind.TASK,
kind=Kind.TASK,
yaml_path=path,
tags=set(cfg.get("tag", [])),
tags=TaskIndex._str_to_set(cfg.get("tag")),
cfg=cfg,
)
self._register_tags(name, cfg.get("tag", []), index)
TaskIndex._register_tags(name, cfg.get("tag"), index)
return
if kind is TaskKind.TASK_LIST:
if kind is Kind.TASK_LIST:
for entry in cfg["task_list"]:
task_name = entry["task"] if isinstance(entry, dict) else entry
index[task_name] = TaskEntry(
index[task_name] = Entry(
name=task_name,
kind=TaskKind.TASK,
kind=Kind.TASK,
yaml_path=path,
tags=set(entry.get("tag", []))
if isinstance(entry, dict)
else set(),
tags=TaskIndex._str_to_set(cfg.get("tag")),
cfg=cfg,
)
self._register_tags(task_name, entry.get("tag", []), index)
TaskIndex._register_tags(task_name, entry.get("tag"), index)
return
# ---------------------------------------------------------------------------
def _register_tags(self, task: str, tags, index) -> None:
@staticmethod
def _register_tags(
task: str,
tags: str | list[str] | None,
index: dict[str, Entry],
) -> None:
if not tags:
return
for tag in tags if isinstance(tags, list) else [tags]:
if not tag:
continue
entry = index.setdefault(
tag,
TaskEntry(name=tag, kind=TaskKind.TAG, yaml_path=None, tags=set()),
Entry(name=tag, kind=Kind.TAG, yaml_path=None, tags=set()),
)
entry.tags.add(task) # mutate ok; dataclass not frozen for TAG
entry.tags.add(task)
@staticmethod
def _kind_of(cfg: dict) -> TaskKind:
def _kind_of(cfg: dict) -> Kind:
if "class" in cfg:
return TaskKind.PY_TASK
return Kind.PY_TASK
if "group" in cfg:
return Kind.GROUP
if "task_list" in cfg:
return TaskKind.TASK_LIST
return Kind.TASK_LIST
if "task" in cfg:
return TaskKind.GROUP if isinstance(cfg["task"], list) else TaskKind.TASK
return Kind.GROUP if isinstance(cfg["task"], list) else Kind.TASK
msg = "Unknown config shape"
raise ValueError(msg)
raise ValueError(msg) from None
@staticmethod
def _str_to_set(tags: str | list[str] | None = None) -> set[str]:
"""Convert a string or list of strings to a set of strings."""
return (
set(tags)
if isinstance(tags, list)
else {tags}
if isinstance(tags, str)
else set()
)
from __future__ import annotations
from collections import defaultdict
from itertools import chain
from pathlib import Path
from typing import Any
from lm_eval.tasks.factory import TaskFactory
from lm_eval.tasks.index import Entry, Kind, TaskIndex
from lm_eval.utils import setup_logging
class TaskManager:
    """Index-backed loader: scans task directories once at construction,
    then builds tasks / groups / tags on demand via TaskFactory."""

    def __init__(
        self,
        verbosity: str | None = None,
        include_path: str | Path | list[str | Path] | None = None,
        include_defaults: bool = True,
        metadata: dict[str, dict[str, Any]] | None = None,
    ) -> None:
        if verbosity:
            setup_logging(verbosity)
        self._factory = TaskFactory(meta=metadata)

        # Collect every directory to scan: bundled tasks first, then user paths.
        search_paths: list[Path] = []
        if include_defaults:
            search_paths.append(Path(__file__).parent)
        if include_path:
            user_paths = (
                include_path
                if isinstance(include_path, (list, tuple))
                else [include_path]
            )
            search_paths.extend(Path(p) for p in user_paths)

        self._index = TaskIndex().build(search_paths)

        # Partition registered names by kind for the convenience listings.
        by_kind: dict[Kind, list[str]] = defaultdict(list)
        for name, entry in self._index.items():
            by_kind[entry.kind].append(name)
        self._all_tasks = sorted(by_kind[Kind.TASK] + by_kind[Kind.PY_TASK])
        self._all_groups = sorted(by_kind[Kind.GROUP])
        self._all_tags = sorted(by_kind[Kind.TAG])

    def _entry(self, name: str) -> Entry:
        """Look up *name* in the index; raise a descriptive KeyError if absent."""
        if name in self._index:
            return self._index[name]
        raise KeyError(f"Unknown task/group/tag: {name}")

    def load_spec(self, spec: str | dict[str, Any]):
        """Build the object(s) described by *spec*.

        Accepts either a registered task / group / tag name (str), or an
        inline override dict such as {'task': 'hellaswag', 'num_fewshot': 5}.
        """
        if isinstance(spec, str):
            return self._factory.build(
                self._entry(spec), overrides=None, registry=self._index
            )
        if isinstance(spec, dict):
            # Inline dict => locate the base entry, apply the dict as overrides.
            return self._factory.build(
                self._entry(spec["task"]), overrides=spec, registry=self._index
            )
        raise TypeError("spec must be str or dict")

    def load_task_or_group(self, task_list: str | list[str]):
        """Normalise *task_list* to a list and build every spec in it."""
        specs = task_list if isinstance(task_list, list) else [task_list]
        return [self.load_spec(s) for s in specs]
......@@ -103,7 +103,8 @@ plugins.md029.allow_extended_start_values = true # ol-prefix
plugins.md034.enabled = false # no-bare-urls
[tool.ruff.lint]
extend-select = ["I"]
select = ["ASYNC","B", "C4", "E", "F", "I", "LOG","PIE", "PTH","SIM", "UP", "PERF", "ISC001", "ISC002", "ICN001", "C901","FURB", "RUF"]
ignore = ["E501", "E111", "E114", "E117", "E501", "PERF203", "B011"]
[tool.ruff.lint.isort]
lines-after-imports = 2
......@@ -111,7 +112,6 @@ known-first-party = ["lm_eval"]
[tool.ruff.lint.extend-per-file-ignores]
"__init__.py" = ["F401","F402","F403"]
"utils.py" = ["F401"]
[dependency-groups]
dev = [
......
......@@ -20,7 +20,7 @@ Test coverage:
- load():
- test_load_simple_yaml: basic YAML parsing
- test_load_with_function_resolved: !function tags resolved to callables
- test_load_with_function_not_resolved: !function tags become no-op lambdas
- test_load_with_function_not_resolved: !function tags become strings
- test_load_with_includes: include files merged, main values win
- test_load_with_absolute_include: absolute path includes
- test_load_without_includes_resolution: includes preserved when disabled
......@@ -38,9 +38,10 @@ import pytest
from lm_eval.tasks._config_loader import (
_Base,
_import_function,
_import_func_in_yml,
_make_loader,
_mk_function_ctor,
import_fun_from_str,
load_yaml,
)
......@@ -75,7 +76,7 @@ class TestMkFunctionCtor:
"""Tests for the YAML !function constructor factory."""
def test_mk_function_ctor_with_resolve_false(self, temp_dir):
"""When resolve=False, should return a no-op lambda."""
"""When resolve=False, should return a string."""
ctor = _mk_function_ctor(temp_dir, resolve=False)
loader = MagicMock()
......@@ -84,8 +85,7 @@ class TestMkFunctionCtor:
result = ctor(loader, node)
assert callable(result)
assert result("arg1", kwarg="value") is None
assert isinstance(result, str)
def test_mk_function_ctor_with_resolve_true(self, temp_dir, python_module):
"""When resolve=True, should import and return the actual function."""
......@@ -136,7 +136,7 @@ class TestImportFunction:
# Create a local module
python_module("def local_func(x, y):\n return x + y\n")
func = _import_function("utils.local_func", temp_dir)
func = _import_func_in_yml("utils.local_func", temp_dir)
assert callable(func)
assert func(2, 3) == 5
......@@ -149,7 +149,7 @@ class TestImportFunction:
"def nested_func():\n return 'nested'\n"
)
func = _import_function("sub.module.nested_func", temp_dir)
func = _import_func_in_yml("sub.module.nested_func", temp_dir)
assert callable(func)
assert func() == "nested"
......@@ -157,19 +157,19 @@ class TestImportFunction:
def test_import_standard_module(self, temp_dir):
"""Falls back to standard import for non-local modules."""
# Import from standard library
func = _import_function("os.path.join", temp_dir)
func = _import_func_in_yml("os.path.join", temp_dir)
assert callable(func)
assert func("a", "b") in ("a/b", "a\\b") # Unix or Windows
def test_import_caching(self, temp_dir, python_module):
# Clear cache first
_import_function.cache_clear()
_import_func_in_yml.cache_clear()
python_module("def cached_func():\n return 42\n")
func1 = _import_function("utils.cached_func", temp_dir)
func2 = _import_function("utils.cached_func", temp_dir)
func1 = _import_func_in_yml("utils.cached_func", temp_dir)
func2 = _import_func_in_yml("utils.cached_func", temp_dir)
assert func1 is func2 # Cached
......@@ -177,7 +177,7 @@ class TestImportFunction:
"""Verifies LRU cache behavior - file changes require cache clear."""
# Clear the LRU cache
_import_function.cache_clear()
_import_func_in_yml.cache_clear()
# Create a module
module_path = temp_dir / "test_mtime.py"
......@@ -185,17 +185,102 @@ class TestImportFunction:
# Import it
import_key = "test_mtime.value"
value1 = _import_function(import_key, temp_dir)
value1 = _import_func_in_yml(import_key, temp_dir)
assert value1 == 1
value2 = _import_function(import_key, temp_dir)
value2 = _import_func_in_yml(import_key, temp_dir)
assert value2 == 1 # From cache
_import_function.cache_clear()
value3 = _import_function(import_key, temp_dir)
_import_func_in_yml.cache_clear()
value3 = _import_func_in_yml(import_key, temp_dir)
assert value3 == 1 # Re-imported
class TestImportFunFromStr:
    """Tests for import_fun_from_str function."""
    # NOTE(review): the loader module defines `_import_fun_from_str` (leading
    # underscore); confirm it is re-exported as `import_fun_from_str` — TODO verify.

    def test_import_from_absolute_path(self, temp_dir):
        """Test importing function from absolute path."""
        # Create a test module
        module_path = temp_dir / "test_module.py"
        module_path.write_text("def test_func(x):\n return x * 2\n")
        # Import using absolute path
        func = import_fun_from_str(f"{module_path.with_suffix('')}.test_func")
        assert callable(func)
        assert func(5) == 10

    def test_import_with_py_extension(self, temp_dir):
        """Test importing when .py is included in the path."""
        # Create a test module
        module_path = temp_dir / "test_module.py"
        module_path.write_text("def test_func(x):\n return x + 10\n")
        # Import with .py in the path
        func = import_fun_from_str(f"{module_path}.test_func")
        assert callable(func)
        assert func(5) == 15

    def test_import_nested_function(self, temp_dir):
        """Test importing from nested module structure."""
        # Create nested directory
        (temp_dir / "subdir").mkdir()
        module_path = temp_dir / "subdir" / "nested.py"
        module_path.write_text("def nested_func():\n return 'nested'\n")
        # Import from nested path
        func = import_fun_from_str(f"{module_path.with_suffix('')}.nested_func")
        assert callable(func)
        assert func() == "nested"

    def test_import_missing_module(self, temp_dir):
        """Test error when module doesn't exist."""
        with pytest.raises(ImportError, match="Module file not found"):
            import_fun_from_str(f"{temp_dir}/nonexistent.test_func")

    def test_import_missing_function(self, temp_dir):
        """Test error when function doesn't exist in module."""
        module_path = temp_dir / "test_module.py"
        module_path.write_text("def other_func():\n pass\n")
        with pytest.raises(AttributeError, match="Function 'missing_func' not found"):
            import_fun_from_str(f"{module_path.with_suffix('')}.missing_func")

    def test_import_invalid_format(self):
        """Test error with invalid path format."""
        with pytest.raises(ValueError, match="Invalid path format"):
            import_fun_from_str("/path/without/function")

    def test_import_caching(self, temp_dir):
        """Test that modules are cached by mtime."""
        # Clear any existing cache
        import sys
        keys_to_remove = [k for k in sys.modules if str(temp_dir) in k]
        for k in keys_to_remove:
            del sys.modules[k]
        module_path = temp_dir / "cached_module.py"
        module_path.write_text(
            "call_count = 0\ndef func():\n global call_count\n call_count += 1\n return call_count\n"
        )
        # First import
        func1 = import_fun_from_str(f"{module_path.with_suffix('')}.func")
        _result1 = func1()
        # Second import should use cached module
        func2 = import_fun_from_str(f"{module_path.with_suffix('')}.func")
        result2 = func2()
        # Both should refer to the same module instance
        assert func1 is func2
        assert result2 == 2  # call_count incremented
class TestLoad:
"""Tests for the main YAML loading function with includes and function resolution."""
......@@ -237,8 +322,10 @@ doc_to_text: !function utils.process_doc
result = load_yaml(file_path, resolve_functions=False)
assert callable(result["doc_to_text"])
assert result["doc_to_text"]("hello") is None # No-op lambda
assert isinstance(result["doc_to_text"], str)
# When resolve_functions=False, it returns the full path + function spec
assert result["doc_to_text"].endswith("utils.process_doc")
assert result["doc_to_text"] == str(file_path.parent / "utils.process_doc")
def test_load_with_includes(self, temp_dir, yaml_file):
"""Include files are merged with local values taking precedence."""
......@@ -388,3 +475,7 @@ shared_key: from_main
mock_expand.assert_called_once()
assert result["test"] == "value"
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])
"""
Tests for the task index builder that discovers YAML task configurations.
"""Tests for the task index builder that discovers YAML task configurations.
Test coverage:
- TaskIndexBuilder._kind_of: identifies task/group/tag/task_list/py_task
......@@ -14,7 +13,7 @@ from pathlib import Path
import pytest
from lm_eval.tasks._task_index import TaskIndexBuilder, TaskKind
from lm_eval.tasks._task_index import TaskIndex, TaskKind
@pytest.fixture
......@@ -40,28 +39,28 @@ class TestTaskKindOf:
def test_kind_of_task(self):
"""Single task with string name."""
cfg = {"task": "my_task", "dataset_path": "data"}
assert TaskIndexBuilder._kind_of(cfg) == TaskKind.TASK
assert TaskIndex._kind_of(cfg) == TaskKind.TASK
def test_kind_of_group(self):
"""Group has task as list."""
cfg = {"task": ["task1", "task2"], "group": "my_group"}
assert TaskIndexBuilder._kind_of(cfg) == TaskKind.GROUP
assert TaskIndex._kind_of(cfg) == TaskKind.GROUP
def test_kind_of_py_task(self):
"""Python task has class field."""
cfg = {"task": "my_task", "class": "tasks.MyTask"}
assert TaskIndexBuilder._kind_of(cfg) == TaskKind.PY_TASK
assert TaskIndex._kind_of(cfg) == TaskKind.PY_TASK
def test_kind_of_task_list(self):
"""Task list has task_list field."""
cfg = {"task_list": ["task1", "task2"]}
assert TaskIndexBuilder._kind_of(cfg) == TaskKind.TASK_LIST
assert TaskIndex._kind_of(cfg) == TaskKind.TASK_LIST
def test_kind_of_unknown(self):
"""Unknown config raises ValueError."""
cfg = {"unknown": "field"}
with pytest.raises(ValueError, match="Unknown config shape"):
TaskIndexBuilder._kind_of(cfg)
TaskIndex._kind_of(cfg)
class TestIterYamlFiles:
......@@ -75,8 +74,8 @@ class TestIterYamlFiles:
(temp_dir / "subdir" / "task2.yaml").touch()
(temp_dir / "other.txt").touch()
builder = TaskIndexBuilder()
yaml_files = list(builder._iter_yaml_files(temp_dir))
builder = TaskIndex()
yaml_files = list(builder._iter_yaml_files())
assert len(yaml_files) == 2
names = {f.name for f in yaml_files}
......@@ -90,8 +89,8 @@ class TestIterYamlFiles:
(temp_dir / ".ipynb_checkpoints").mkdir()
(temp_dir / ".ipynb_checkpoints" / "also_ignored.yaml").touch()
builder = TaskIndexBuilder()
yaml_files = list(builder._iter_yaml_files(temp_dir))
builder = TaskIndex()
yaml_files = list(builder._iter_yaml_files())
assert len(yaml_files) == 1
assert yaml_files[0].name == "task.yaml"
......@@ -106,8 +105,8 @@ class TestProcessCfg:
path = temp_dir / "task.yaml"
index = {}
builder = TaskIndexBuilder()
builder._process_cfg(cfg, path, index)
builder = TaskIndex()
builder.process_cfg(cfg, path, index)
assert "my_task" in index
entry = index["my_task"]
......@@ -122,8 +121,8 @@ class TestProcessCfg:
path = temp_dir / "group.yaml"
index = {}
builder = TaskIndexBuilder()
builder._process_cfg(cfg, path, index)
builder = TaskIndex()
builder.process_cfg(cfg, path, index)
assert "my_group" in index
entry = index["my_group"]
......@@ -138,8 +137,8 @@ class TestProcessCfg:
path = temp_dir / "py_task.yaml"
index = {}
builder = TaskIndexBuilder()
builder._process_cfg(cfg, path, index)
builder = TaskIndex()
builder.process_cfg(cfg, path, index)
assert "py_task" in index
entry = index["py_task"]
......@@ -154,27 +153,30 @@ class TestProcessCfg:
"task_list": [
"simple_task",
{"task": "complex_task", "tag": ["tag1", "tag2"]},
]
],
}
path = temp_dir / "list.yaml"
index = {}
builder = TaskIndexBuilder()
builder = TaskIndex()
# The implementation has a bug - it calls entry.get() on string entries
# This test documents the current behavior which will fail
with pytest.raises(AttributeError, match="'str' object has no attribute 'get'"):
builder._process_cfg(cfg, path, index)
builder.process_cfg(cfg, path, index)
def test_process_task_list_dict_entries(self, temp_dir):
"""Task list with only dict entries works."""
cfg = {
"task_list": [{"task": "task1"}, {"task": "task2", "tag": ["tag1", "tag2"]}]
"task_list": [
{"task": "task1"},
{"task": "task2", "tag": ["tag1", "tag2"]},
],
}
path = temp_dir / "list.yaml"
index = {}
builder = TaskIndexBuilder()
builder._process_cfg(cfg, path, index)
builder = TaskIndex()
builder.process_cfg(cfg, path, index)
# Task without tags
assert "task1" in index
......@@ -197,7 +199,7 @@ class TestRegisterTags:
def test_register_single_tag(self):
"""Single tag creates TAG entry."""
index = {}
builder = TaskIndexBuilder()
builder = TaskIndex()
builder._register_tags("task1", "my_tag", index)
......@@ -210,7 +212,7 @@ class TestRegisterTags:
def test_register_multiple_tags(self):
"""Multiple tags create multiple TAG entries."""
index = {}
builder = TaskIndexBuilder()
builder = TaskIndex()
builder._register_tags("task1", ["tag1", "tag2"], index)
......@@ -222,7 +224,7 @@ class TestRegisterTags:
def test_register_tags_accumulates(self):
"""Multiple tasks can have same tag."""
index = {}
builder = TaskIndexBuilder()
builder = TaskIndex()
builder._register_tags("task1", "shared_tag", index)
builder._register_tags("task2", "shared_tag", index)
......@@ -237,7 +239,7 @@ class TestBuild:
def test_build_empty_directory(self, temp_dir):
"""Empty directory returns empty index."""
builder = TaskIndexBuilder()
builder = TaskIndex()
index = builder.build([temp_dir])
assert index == {}
......@@ -245,7 +247,7 @@ class TestBuild:
"""Single task file is discovered."""
yaml_file("task: my_task\ndataset_path: data\n")
builder = TaskIndexBuilder()
builder = TaskIndex()
index = builder.build([temp_dir])
assert len(index) == 1
......@@ -269,7 +271,7 @@ class TestBuild:
# Python task
yaml_file("task: py_task\nclass: MyClass\n", "python.yaml")
builder = TaskIndexBuilder()
builder = TaskIndex()
index = builder.build([temp_dir])
# Check all entries exist
......@@ -297,7 +299,7 @@ class TestBuild:
yaml_file("task: sub_task\n", "subdir/sub.yaml")
yaml_file("task: deep_task\n", "subdir/deeper/deep.yaml")
builder = TaskIndexBuilder()
builder = TaskIndex()
index = builder.build([temp_dir])
assert len(index) == 3
......@@ -308,7 +310,7 @@ class TestBuild:
yaml_file("task: valid_task\n", "valid.yaml")
yaml_file("invalid: [\n", "invalid.yaml") # Invalid YAML
builder = TaskIndexBuilder()
builder = TaskIndex()
index = builder.build([temp_dir])
assert len(index) == 1
......@@ -325,7 +327,7 @@ class TestBuild:
(dir1 / "task1.yaml").write_text("task: task1\n")
(dir2 / "task2.yaml").write_text("task: task2\n")
builder = TaskIndexBuilder()
builder = TaskIndex()
index = builder.build([dir1, dir2])
assert len(index) == 2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment