Commit 4254c7bd authored by Baber's avatar Baber
Browse files

add task factory

parent eec9de3e
...@@ -29,7 +29,7 @@ repos: ...@@ -29,7 +29,7 @@ repos:
- id: mixed-line-ending - id: mixed-line-ending
args: [--fix=lf] args: [--fix=lf]
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.2 rev: v0.12.5
hooks: hooks:
# Run the linter. # Run the linter.
- id: ruff - id: ruff
......
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from inspect import getsource from inspect import getsource
from typing import Callable, List, Optional, Union from typing import Callable, Optional, Union
from datasets.features.pdf import field
@dataclass @dataclass
...@@ -25,9 +27,9 @@ class AggMetricConfig(dict): ...@@ -25,9 +27,9 @@ class AggMetricConfig(dict):
class GroupConfig: class GroupConfig:
group: Optional[str] = None group: Optional[str] = None
group_alias: Optional[str] = None group_alias: Optional[str] = None
task: Optional[Union[str, list]] = None task: Union[str, list] = field(default_factory=list)
aggregate_metric_list: Optional[ aggregate_metric_list: Optional[
Union[List[AggMetricConfig], AggMetricConfig, dict] Union[list[AggMetricConfig], AggMetricConfig, dict]
] = None ] = None
metadata: Optional[dict] = ( metadata: Optional[dict] = (
None # by default, not used in the code. allows for users to pass arbitrary info to tasks None # by default, not used in the code. allows for users to pass arbitrary info to tasks
......
This diff is collapsed.
...@@ -3,29 +3,31 @@ from __future__ import annotations ...@@ -3,29 +3,31 @@ from __future__ import annotations
import functools import functools
import importlib.util import importlib.util
import sys import sys
from collections.abc import Callable
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
import yaml import yaml
_Base = yaml.CLoader if getattr(yaml, "__with_libyaml__", False) else yaml.FullLoader _Base = (
yaml.CSafeLoader if getattr(yaml, "__with_libyaml__", False) else yaml.FullLoader
)
_IGNORE_DIRS = {"__pycache__", ".ipynb_checkpoints"} _IGNORE_DIRS = {"__pycache__", ".ipynb_checkpoints"}
# --------------------------------------------------------------------------- helpers # --------------------------------------------------------------------------- helpers
@functools.lru_cache(128)
def _mk_function_ctor(base_dir: Path, resolve: bool): def _mk_function_ctor(base_dir: Path, resolve: bool):
def ctor(loader: yaml.Loader, node: yaml.Node): def ctor(loader: yaml.Loader, node: yaml.Node):
spec = loader.construct_scalar(node) # type: ignore[arg-type] spec = loader.construct_scalar(node) # type: ignore[arg-type]
if not resolve: if not resolve:
return lambda *_, **__: None return str(base_dir.expanduser() / spec)
return _import_function(spec, base_dir) return _import_func_in_yml(spec, base_dir)
return ctor return ctor
@functools.lru_cache(maxsize=1024) @functools.lru_cache(maxsize=512)
def _make_loader(base_dir: Path, *, resolve_funcs: bool) -> type[yaml.Loader]: def _make_loader(base_dir: Path, *, resolve_funcs: bool) -> type[yaml.Loader]:
class Loader(_Base): ... # type: ignore[no-redef] class Loader(_Base): ... # type: ignore[no-redef]
...@@ -37,8 +39,14 @@ def _make_loader(base_dir: Path, *, resolve_funcs: bool) -> type[yaml.Loader]: ...@@ -37,8 +39,14 @@ def _make_loader(base_dir: Path, *, resolve_funcs: bool) -> type[yaml.Loader]:
return Loader return Loader
@functools.lru_cache(maxsize=4096) @functools.lru_cache(maxsize=128)
def _import_function(qual: str, base_dir: Path): def _import_func_in_yml(qual: str, base_dir: Path):
"""Import function from qual: utils.process_doc, checking local files first then standard imports.
Args:
qual: Qualified function name (e.g., 'utils.process_doc')
base_dir: Directory to search for local modules
"""
mod_path, _, fn_name = qual.rpartition(".") mod_path, _, fn_name = qual.rpartition(".")
# 1) relative “utils.py” next to YAML # 1) relative “utils.py” next to YAML
rel = (base_dir / f"{mod_path.replace('.', '/')}.py").resolve() rel = (base_dir / f"{mod_path.replace('.', '/')}.py").resolve()
...@@ -47,26 +55,74 @@ def _import_function(qual: str, base_dir: Path): ...@@ -47,26 +55,74 @@ def _import_function(qual: str, base_dir: Path):
key = f"{rel}:{mtime}" # one module per mtime key = f"{rel}:{mtime}" # one module per mtime
if key not in sys.modules: if key not in sys.modules:
spec = importlib.util.spec_from_file_location(key, rel) spec = importlib.util.spec_from_file_location(key, rel)
if spec is None or spec.loader is None:
raise ImportError(f"Cannot load module from {rel}") from None
mod = importlib.util.module_from_spec(spec) mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod) # type: ignore[arg-type] spec.loader.exec_module(mod) # type: ignore[arg-type]
sys.modules[key] = mod sys.modules[key] = mod
return getattr(sys.modules[key], fn_name) return getattr(sys.modules[key], fn_name)
# 2) alreadyimportable module # 2) already-importable module
module = __import__(mod_path, fromlist=[fn_name]) module = __import__(mod_path, fromlist=[fn_name])
return getattr(module, fn_name) return getattr(module, fn_name)
# --------------------------------------------------------------------- public API @functools.lru_cache(maxsize=128)
def _import_fun_from_str(path_str: str) -> Any:
"""Import a function from a string in the form '/absolute/path/to/module.function_name'."""
try:
# Split off the function name from the rightmost dot
module_path_str, function_name = path_str.rsplit(".", 1)
except ValueError as e:
raise ValueError(
f"Invalid path format: {path_str}. Expected format: /path/to/module.function_name"
) from e
# Convert to Path and handle .py extension
module_path = Path(module_path_str)
if not module_path.suffix:
module_path = module_path.with_suffix(".py")
elif module_path.suffix != ".py":
# If it has a non-.py suffix, the user might have included .py in the path
# e.g., "/path/to/module.py.function_name"
base_path = module_path.with_suffix("")
if base_path.with_suffix(".py").exists():
module_path = base_path.with_suffix(".py")
if not module_path.exists():
raise ImportError(f"Module file not found: {module_path}")
# Use similar approach to _import_func_in_yml for consistency
mtime = module_path.stat().st_mtime_ns
cache_key = f"{module_path}:{mtime}"
if cache_key not in sys.modules:
spec = importlib.util.spec_from_file_location(cache_key, module_path)
if spec is None or spec.loader is None:
raise ImportError(f"Cannot load module from {module_path}") from None
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
sys.modules[cache_key] = module
module = sys.modules[cache_key]
if not hasattr(module, function_name):
raise AttributeError(
f"Function '{function_name}' not found in module {module_path}"
)
return getattr(module, function_name)
def load_yaml( def load_yaml(
path: str | Path, path: str | Path,
*, *,
resolve_functions: bool = True, resolve_functions: bool = True,
resolve_includes: bool = True, resolve_includes: bool = True,
_seen: set[Path] | None = None, _seen: set[Path] | None = None,
) -> dict[str, str | Callable[..., Any]]: ) -> dict[str, Any]:
"""Pure dataloading helper. """Pure data-loading helper.
Returns a dict ready for higherlevel interpretation. Returns a dict ready for higher-level interpretation.
•No task/group/tag semantics here. •No task/group/tag semantics here.
""" """
path = Path(path).expanduser().resolve() path = Path(path).expanduser().resolve()
...@@ -82,9 +138,11 @@ def load_yaml( ...@@ -82,9 +138,11 @@ def load_yaml(
if not resolve_includes or "include" not in cfg: if not resolve_includes or "include" not in cfg:
return cfg return cfg
else:
includes = cfg.pop("include")
merged = {} merged = {}
for inc in cfg.pop("include"): for inc in includes if isinstance(includes, list) else [includes]:
inc_path = (path.parent / inc) if not Path(inc).is_absolute() else Path(inc) inc_path = (path.parent / inc) if not Path(inc).is_absolute() else Path(inc)
merged.update( merged.update(
load_yaml( load_yaml(
......
from __future__ import annotations
import inspect
from collections.abc import Mapping
from copy import deepcopy
from functools import lru_cache
from typing import Any
from lm_eval.api.group import GroupConfig
from lm_eval.api.task import ConfigurableTask, Task # noqa: F401 (typing)
from lm_eval.tasks._config_loader import load_yaml as load_cfg
from lm_eval.tasks.index import Entry, Kind
# Module-level LRU cache over the pure YAML loader. Cached configs are shared
# across builds; consumers must deepcopy before mutating (TaskFactory does).
load_cfg_cached = lru_cache(maxsize=512)(load_cfg)  # type: ignore[no-redef]
class TaskFactory:
    """
    Turns an *Entry* (plus optional overrides) into a
    *Task* | *ConfigurableTask* | *GroupConfig* hierarchy.
    """

    def __init__(self, *, meta: dict[str, Any] | None = None):
        # Factory-wide metadata merged into every config this factory builds.
        self._meta = meta or {}

    # ---------------------------------------------------------------- public API
    def build(
        self,
        entry: Entry,
        *,
        overrides: dict[str, Any] | None = None,
        registry: Mapping[str, Entry],
    ):
        """
        • entry.kind == TASK / PY_TASK ➜ returns instantiated task object
        • entry.kind == GROUP ➜ returns {GroupConfig: mapping-of-subtasks}
        • entry.kind == TAG ➜ returns mapping-of-tasks (tag expansion)
        """
        if entry.kind is Kind.TAG:
            return self._build_tag(entry, overrides, registry)
        if entry.kind is Kind.GROUP:
            return self._build_group(entry, overrides, registry)
        return self._build_task(entry, overrides)

    def _build_task(self, entry: Entry, overrides: dict[str, Any] | None):
        """Instantiate one task: PY_TASK via its `class`, otherwise a YAML task."""
        cfg = self._load_full_config(entry, overrides)

        if "class" in cfg:  # PY_TASK route
            cls = cfg["class"]
            # Only pass `config=` when the constructor actually accepts it.
            obj = cls(config=cfg) if _ctor_accepts_config(cls) else cls()
            if isinstance(obj, ConfigurableTask):
                obj.config.task = entry.name
            return obj

        # YAML task
        return ConfigurableTask(config=cfg)  # type: ignore[arg-type]

    def _build_group(
        self,
        entry: Entry,
        overrides: dict[str, Any] | None,
        registry: Mapping[str, Entry],
    ):
        """Build a GroupConfig plus its (recursively built) sub-tasks."""
        raw_cfg = self._load_full_config(entry, None)
        grp_cfg = {k: v for k, v in raw_cfg.items() if k in GroupConfig.__annotations__}
        # `metadata` may be present-but-None in the YAML; coalesce before merging
        # so `None | dict` does not raise TypeError.
        grp_cfg["metadata"] = (grp_cfg.get("metadata") or {}) | self._meta
        group_obj = GroupConfig(**grp_cfg)

        # A lone sub-task name may appear as a scalar; normalise so we never
        # iterate a string character-by-character.
        subtasks = group_obj.task
        if isinstance(subtasks, str):
            subtasks = [subtasks]

        children: dict[str, Any] = {}
        for item in subtasks or []:
            if isinstance(item, str):  # task: hellaswag
                name = item
                child = self.build(
                    registry[item],
                    overrides=overrides,  # group-level overrides propagate
                    registry=registry,
                )
            elif isinstance(item, dict):  # task: {task: hellaswag, num_fewshot: 5}
                name = item["task"]
                child = self.build(
                    registry[name],
                    overrides=item,  # per-item override
                    registry=registry,
                )
            else:
                raise TypeError(
                    f"Unsupported sub-entry {item!r} in group '{entry.name}'"
                )
            # `child` is a mapping for GROUP/TAG sub-entries ({GroupConfig: ...}
            # or {task-name: obj}), but a bare task object for TASK/PY_TASK —
            # wrap the latter under its name instead of calling dict.update on it.
            if isinstance(child, Mapping):
                children.update(child)
            else:
                children[name] = child

        return {group_obj: children}

    def _build_tag(
        self,
        entry: Entry,
        overrides: dict[str, Any] | None,
        registry: Mapping[str, Entry],
    ):
        """Expand a tag into {task-name: task-object} for every tagged task."""
        return {
            name: self._build_task(registry[name], overrides) for name in entry.tags
        }

    def _load_full_config(
        self, entry: Entry, overrides: dict[str, Any] | None
    ) -> dict[str, Any]:
        """Load the entry's YAML config (cached), apply overrides, merge metadata.

        The cached config is deep-copied so mutations never leak back into the
        shared LRU cache.
        """
        if entry.yaml_path:
            cfg = deepcopy(load_cfg_cached(entry.yaml_path, resolve_functions=True))
        else:
            cfg = {"metadata": {"config": "unknown"}}  # python task without YAML

        if overrides:
            cfg = {**cfg, **overrides}

        # Non-dict metadata is preserved under "_metadata" rather than dropped.
        cfg["metadata"] = (
            m if isinstance(m := cfg.get("metadata", {}), dict) else {"_metadata": m}
        ) | self._meta
        cfg.setdefault("task", entry.name)
        return cfg
def _ctor_accepts_config(cls) -> bool:
init = getattr(cls, "__init__", None)
return init and "config" in inspect.signature(init).parameters
# lm_eval/task_index.py (continued)
from __future__ import annotations from __future__ import annotations
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum, auto from enum import Enum, auto
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Any
from lm_eval.tasks._config_loader import load_yaml as load_cfg from lm_eval.tasks._config_loader import load_yaml as load_cfg
...@@ -14,137 +13,159 @@ if TYPE_CHECKING: ...@@ -14,137 +13,159 @@ if TYPE_CHECKING:
from pathlib import Path from pathlib import Path
class TaskKind(Enum): class Kind(Enum):
TASK = auto() # YAML task, or task_list entry TASK = auto() # YAML task, or task_list entry
PY_TASK = auto() # Pythondefined, via "class" PY_TASK = auto() # Python-defined, via "class"
GROUP = auto() GROUP = auto()
TAG = auto() TAG = auto()
TASK_LIST = auto() TASK_LIST = auto()
@dataclass @dataclass
class TaskEntry: class Entry:
name: str name: str
kind: TaskKind kind: Kind
yaml_path: Path | None # None for generated / py‑only entries yaml_path: Path | None # None for generated / py-only entries
cfg: dict[str, str] | None = None
tags: set[str] = field(default_factory=set) tags: set[str] = field(default_factory=set)
task_list_path: Path | None = None # only for GROUP / TAG when lazy‑loaded task_list_path: Path | None = None
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
_IGNORE_DIRS = {"__pycache__", ".ipynb_checkpoints"} _IGNORE_DIRS = {"__pycache__", ".ipynb_checkpoints"}
class TaskIndexBuilder: class TaskIndex:
"""Walks one or more directories, parses YAML quickly (functions unresolved), """Walks one or more directories, parses YAML quickly (functions unresolved),
and produces a mapping {task_name: TaskEntry}. and produces a mapping {task_name: Entry}.
""" """
def __init__(self, *, metadata: dict | None = None) -> None: def __init__(self, *, meta: dict[str, str] | None = None) -> None:
self._metadata = metadata or {} self._metadata = meta or {}
# ------------- public API --------------------------------------------------
def build( def build(
self, self,
paths: Iterable[Path], paths: Iterable[Path],
# include_defaults: bool = True, *,
) -> dict[str, TaskEntry]: resolve_includes=False,
index: dict[str, TaskEntry] = {} ) -> dict[str, Entry]:
index: dict[str, Entry] = {}
log.debug("Building task index from %s", paths)
for root in paths: for root in paths:
for yaml_path in self._iter_yaml_files(root): for yaml_path in self._iter_yaml_files(root):
try: try:
cfg = load_cfg( cfg = load_cfg(
yaml_path, yaml_path,
resolve_functions=False, resolve_functions=False,
resolve_includes=False, resolve_includes=resolve_includes,
) )
self.process_cfg(cfg, yaml_path, index)
except Exception as err: except Exception as err:
log.debug("Skip %s (%s)", yaml_path, err) log.debug("Skip %s (%s)", yaml_path, err)
continue continue
self._process_cfg(cfg, yaml_path, index) # self._process_cfg(cfg, yaml_path, index)
log.debug("Built task index with %d entries", len(index))
return index return index
# ------------- helpers ----------------------------------------------------- @staticmethod
def _iter_yaml_files(self, root: Path): def _iter_yaml_files(root: Path):
yield from ( yield from (
p p
for p in root.glob("**/*.yaml") for p in root.glob("**/*.yaml")
if not any(part in _IGNORE_DIRS for part in p.parts) if not any(part in _IGNORE_DIRS for part in p.parts)
) )
# --------------------------------------------------------------------------- @staticmethod
def _process_cfg( def process_cfg(
self, cfg: dict[str, Any],
cfg: dict,
path: Path, path: Path,
index: dict[str, TaskEntry], index: dict[str, Entry],
) -> None: ) -> None:
kind = self._kind_of(cfg) kind = TaskIndex._kind_of(cfg)
if kind is TaskKind.GROUP: if kind is Kind.GROUP:
grp_name = cfg["group"] grp_name = cfg["group"]
index[grp_name] = TaskEntry( index[grp_name] = Entry(
name=grp_name, name=grp_name,
kind=TaskKind.GROUP, kind=Kind.GROUP,
yaml_path=path, yaml_path=path,
tags=set(cfg.get("tag", [])), tags=TaskIndex._str_to_set(cfg.get("tag")),
cfg=cfg,
) )
return return
if kind is TaskKind.PY_TASK: if kind is Kind.PY_TASK:
name = cfg["task"] name = cfg["task"]
index[name] = TaskEntry( index[name] = Entry(
name=name, name=name,
kind=TaskKind.PY_TASK, kind=Kind.PY_TASK,
yaml_path=None, yaml_path=None,
tags=set(cfg.get("tag", [])), tags=TaskIndex._str_to_set(cfg.get("tag")),
cfg=cfg,
) )
self._register_tags(name, cfg.get("tag", []), index) TaskIndex._register_tags(name, cfg.get("tag"), index)
return return
if kind is TaskKind.TASK: if kind is Kind.TASK:
name = cfg["task"] name = cfg["task"]
index[name] = TaskEntry( index[name] = Entry(
name=name, name=name,
kind=TaskKind.TASK, kind=Kind.TASK,
yaml_path=path, yaml_path=path,
tags=set(cfg.get("tag", [])), tags=TaskIndex._str_to_set(cfg.get("tag")),
cfg=cfg,
) )
self._register_tags(name, cfg.get("tag", []), index) TaskIndex._register_tags(name, cfg.get("tag"), index)
return return
if kind is TaskKind.TASK_LIST: if kind is Kind.TASK_LIST:
for entry in cfg["task_list"]: for entry in cfg["task_list"]:
task_name = entry["task"] if isinstance(entry, dict) else entry task_name = entry["task"] if isinstance(entry, dict) else entry
index[task_name] = TaskEntry( index[task_name] = Entry(
name=task_name, name=task_name,
kind=TaskKind.TASK, kind=Kind.TASK,
yaml_path=path, yaml_path=path,
tags=set(entry.get("tag", [])) tags=TaskIndex._str_to_set(cfg.get("tag")),
if isinstance(entry, dict) cfg=cfg,
else set(),
) )
self._register_tags(task_name, entry.get("tag", []), index) TaskIndex._register_tags(task_name, entry.get("tag"), index)
return return
# --------------------------------------------------------------------------- @staticmethod
def _register_tags(self, task: str, tags, index) -> None: def _register_tags(
task: str,
tags: str | list[str] | None,
index: dict[str, Entry],
) -> None:
if not tags:
return
for tag in tags if isinstance(tags, list) else [tags]: for tag in tags if isinstance(tags, list) else [tags]:
if not tag:
continue
entry = index.setdefault( entry = index.setdefault(
tag, tag,
TaskEntry(name=tag, kind=TaskKind.TAG, yaml_path=None, tags=set()), Entry(name=tag, kind=Kind.TAG, yaml_path=None, tags=set()),
) )
entry.tags.add(task) # mutate ok; dataclass not frozen for TAG entry.tags.add(task)
@staticmethod @staticmethod
def _kind_of(cfg: dict) -> TaskKind: def _kind_of(cfg: dict) -> Kind:
if "class" in cfg: if "class" in cfg:
return TaskKind.PY_TASK return Kind.PY_TASK
if "group" in cfg:
return Kind.GROUP
if "task_list" in cfg: if "task_list" in cfg:
return TaskKind.TASK_LIST return Kind.TASK_LIST
if "task" in cfg: if "task" in cfg:
return TaskKind.GROUP if isinstance(cfg["task"], list) else TaskKind.TASK return Kind.GROUP if isinstance(cfg["task"], list) else Kind.TASK
msg = "Unknown config shape" msg = "Unknown config shape"
raise ValueError(msg) raise ValueError(msg) from None
@staticmethod
def _str_to_set(tags: str | list[str] | None = None) -> set[str]:
"""Convert a string or list of strings to a set of strings."""
return (
set(tags)
if isinstance(tags, list)
else {tags}
if isinstance(tags, str)
else set()
)
from __future__ import annotations
from collections import defaultdict
from itertools import chain
from pathlib import Path
from typing import Any
from lm_eval.tasks.factory import TaskFactory
from lm_eval.tasks.index import Entry, Kind, TaskIndex
from lm_eval.utils import setup_logging
class TaskManager:
    """Indexes task/group/tag YAML configs from the default directory (plus any
    user-supplied include paths) and builds them on demand via TaskFactory."""

    def __init__(
        self,
        verbosity: str | None = None,
        include_path: str | Path | list[str | Path] | None = None,
        include_defaults: bool = True,
        metadata: dict[str, dict[str, Any]] | None = None,
    ) -> None:
        if verbosity:
            setup_logging(verbosity)

        self._factory = TaskFactory(meta=metadata)

        # Collect every directory to scan: bundled tasks first, then extras.
        search_roots: list[Path] = []
        if include_defaults:
            search_roots.append(Path(__file__).parent)
        if include_path:
            extra = (
                include_path
                if isinstance(include_path, (list, tuple))
                else [include_path]
            )
            search_roots.extend(Path(p) for p in extra)

        self._index = TaskIndex().build(search_roots)

        # Bucket registered names by entry kind for the sorted listings below.
        by_kind: defaultdict[Kind, list[str]] = defaultdict(list)
        for name, entry in self._index.items():
            by_kind[entry.kind].append(name)

        self._all_tasks = sorted(by_kind[Kind.TASK] + by_kind[Kind.PY_TASK])
        self._all_groups = sorted(by_kind[Kind.GROUP])
        self._all_tags = sorted(by_kind[Kind.TAG])

    def _entry(self, name: str) -> Entry:
        """Look up a registered entry, raising KeyError with a clear message."""
        try:
            return self._index[name]
        except KeyError:
            raise KeyError(f"Unknown task/group/tag: {name}") from None

    def load_spec(self, spec: str | dict[str, Any]):
        """Spec can be:
        • str task / group / tag name (registered)
        • dict inline overrides {'task': 'hellaswag', 'num_fewshot': 5}
        """
        if isinstance(spec, str):
            return self._factory.build(
                self._entry(spec), overrides=None, registry=self._index
            )
        if isinstance(spec, dict):
            # Inline dict: resolve the base entry, then apply the dict as overrides.
            base = self._entry(spec["task"])
            return self._factory.build(base, overrides=spec, registry=self._index)
        raise TypeError("spec must be str or dict")

    def load_task_or_group(self, task_list: str | list[str]):
        """Load one spec or a list of specs; always returns a list."""
        specs = task_list if isinstance(task_list, list) else [task_list]
        return [self.load_spec(s) for s in specs]
...@@ -103,7 +103,8 @@ plugins.md029.allow_extended_start_values = true # ol-prefix ...@@ -103,7 +103,8 @@ plugins.md029.allow_extended_start_values = true # ol-prefix
plugins.md034.enabled = false # no-bare-urls plugins.md034.enabled = false # no-bare-urls
[tool.ruff.lint] [tool.ruff.lint]
extend-select = ["I"] select = ["ASYNC","B", "C4", "E", "F", "I", "LOG","PIE", "PTH","SIM", "UP", "PERF", "ISC001", "ISC002", "ICN001", "C901","FURB", "RUF"]
ignore = ["E501", "E111", "E114", "E117", "E501", "PERF203", "B011"]
[tool.ruff.lint.isort] [tool.ruff.lint.isort]
lines-after-imports = 2 lines-after-imports = 2
...@@ -111,7 +112,6 @@ known-first-party = ["lm_eval"] ...@@ -111,7 +112,6 @@ known-first-party = ["lm_eval"]
[tool.ruff.lint.extend-per-file-ignores] [tool.ruff.lint.extend-per-file-ignores]
"__init__.py" = ["F401","F402","F403"] "__init__.py" = ["F401","F402","F403"]
"utils.py" = ["F401"]
[dependency-groups] [dependency-groups]
dev = [ dev = [
......
...@@ -20,7 +20,7 @@ Test coverage: ...@@ -20,7 +20,7 @@ Test coverage:
- load(): - load():
- test_load_simple_yaml: basic YAML parsing - test_load_simple_yaml: basic YAML parsing
- test_load_with_function_resolved: !function tags resolved to callables - test_load_with_function_resolved: !function tags resolved to callables
- test_load_with_function_not_resolved: !function tags become no-op lambdas - test_load_with_function_not_resolved: !function tags become strings
- test_load_with_includes: include files merged, main values win - test_load_with_includes: include files merged, main values win
- test_load_with_absolute_include: absolute path includes - test_load_with_absolute_include: absolute path includes
- test_load_without_includes_resolution: includes preserved when disabled - test_load_without_includes_resolution: includes preserved when disabled
...@@ -38,9 +38,10 @@ import pytest ...@@ -38,9 +38,10 @@ import pytest
from lm_eval.tasks._config_loader import ( from lm_eval.tasks._config_loader import (
_Base, _Base,
_import_function, _import_func_in_yml,
_make_loader, _make_loader,
_mk_function_ctor, _mk_function_ctor,
import_fun_from_str,
load_yaml, load_yaml,
) )
...@@ -75,7 +76,7 @@ class TestMkFunctionCtor: ...@@ -75,7 +76,7 @@ class TestMkFunctionCtor:
"""Tests for the YAML !function constructor factory.""" """Tests for the YAML !function constructor factory."""
def test_mk_function_ctor_with_resolve_false(self, temp_dir): def test_mk_function_ctor_with_resolve_false(self, temp_dir):
"""When resolve=False, should return a no-op lambda.""" """When resolve=False, should return a string."""
ctor = _mk_function_ctor(temp_dir, resolve=False) ctor = _mk_function_ctor(temp_dir, resolve=False)
loader = MagicMock() loader = MagicMock()
...@@ -84,8 +85,7 @@ class TestMkFunctionCtor: ...@@ -84,8 +85,7 @@ class TestMkFunctionCtor:
result = ctor(loader, node) result = ctor(loader, node)
assert callable(result) assert isinstance(result, str)
assert result("arg1", kwarg="value") is None
def test_mk_function_ctor_with_resolve_true(self, temp_dir, python_module): def test_mk_function_ctor_with_resolve_true(self, temp_dir, python_module):
"""When resolve=True, should import and return the actual function.""" """When resolve=True, should import and return the actual function."""
...@@ -136,7 +136,7 @@ class TestImportFunction: ...@@ -136,7 +136,7 @@ class TestImportFunction:
# Create a local module # Create a local module
python_module("def local_func(x, y):\n return x + y\n") python_module("def local_func(x, y):\n return x + y\n")
func = _import_function("utils.local_func", temp_dir) func = _import_func_in_yml("utils.local_func", temp_dir)
assert callable(func) assert callable(func)
assert func(2, 3) == 5 assert func(2, 3) == 5
...@@ -149,7 +149,7 @@ class TestImportFunction: ...@@ -149,7 +149,7 @@ class TestImportFunction:
"def nested_func():\n return 'nested'\n" "def nested_func():\n return 'nested'\n"
) )
func = _import_function("sub.module.nested_func", temp_dir) func = _import_func_in_yml("sub.module.nested_func", temp_dir)
assert callable(func) assert callable(func)
assert func() == "nested" assert func() == "nested"
...@@ -157,19 +157,19 @@ class TestImportFunction: ...@@ -157,19 +157,19 @@ class TestImportFunction:
def test_import_standard_module(self, temp_dir): def test_import_standard_module(self, temp_dir):
"""Falls back to standard import for non-local modules.""" """Falls back to standard import for non-local modules."""
# Import from standard library # Import from standard library
func = _import_function("os.path.join", temp_dir) func = _import_func_in_yml("os.path.join", temp_dir)
assert callable(func) assert callable(func)
assert func("a", "b") in ("a/b", "a\\b") # Unix or Windows assert func("a", "b") in ("a/b", "a\\b") # Unix or Windows
def test_import_caching(self, temp_dir, python_module): def test_import_caching(self, temp_dir, python_module):
# Clear cache first # Clear cache first
_import_function.cache_clear() _import_func_in_yml.cache_clear()
python_module("def cached_func():\n return 42\n") python_module("def cached_func():\n return 42\n")
func1 = _import_function("utils.cached_func", temp_dir) func1 = _import_func_in_yml("utils.cached_func", temp_dir)
func2 = _import_function("utils.cached_func", temp_dir) func2 = _import_func_in_yml("utils.cached_func", temp_dir)
assert func1 is func2 # Cached assert func1 is func2 # Cached
...@@ -177,7 +177,7 @@ class TestImportFunction: ...@@ -177,7 +177,7 @@ class TestImportFunction:
"""Verifies LRU cache behavior - file changes require cache clear.""" """Verifies LRU cache behavior - file changes require cache clear."""
# Clear the LRU cache # Clear the LRU cache
_import_function.cache_clear() _import_func_in_yml.cache_clear()
# Create a module # Create a module
module_path = temp_dir / "test_mtime.py" module_path = temp_dir / "test_mtime.py"
...@@ -185,17 +185,102 @@ class TestImportFunction: ...@@ -185,17 +185,102 @@ class TestImportFunction:
# Import it # Import it
import_key = "test_mtime.value" import_key = "test_mtime.value"
value1 = _import_function(import_key, temp_dir) value1 = _import_func_in_yml(import_key, temp_dir)
assert value1 == 1 assert value1 == 1
value2 = _import_function(import_key, temp_dir) value2 = _import_func_in_yml(import_key, temp_dir)
assert value2 == 1 # From cache assert value2 == 1 # From cache
_import_function.cache_clear() _import_func_in_yml.cache_clear()
value3 = _import_function(import_key, temp_dir) value3 = _import_func_in_yml(import_key, temp_dir)
assert value3 == 1 # Re-imported assert value3 == 1 # Re-imported
class TestImportFunFromStr:
"""Tests for import_fun_from_str function."""
def test_import_from_absolute_path(self, temp_dir):
"""Test importing function from absolute path."""
# Create a test module
module_path = temp_dir / "test_module.py"
module_path.write_text("def test_func(x):\n return x * 2\n")
# Import using absolute path
func = import_fun_from_str(f"{module_path.with_suffix('')}.test_func")
assert callable(func)
assert func(5) == 10
def test_import_with_py_extension(self, temp_dir):
"""Test importing when .py is included in the path."""
# Create a test module
module_path = temp_dir / "test_module.py"
module_path.write_text("def test_func(x):\n return x + 10\n")
# Import with .py in the path
func = import_fun_from_str(f"{module_path}.test_func")
assert callable(func)
assert func(5) == 15
def test_import_nested_function(self, temp_dir):
"""Test importing from nested module structure."""
# Create nested directory
(temp_dir / "subdir").mkdir()
module_path = temp_dir / "subdir" / "nested.py"
module_path.write_text("def nested_func():\n return 'nested'\n")
# Import from nested path
func = import_fun_from_str(f"{module_path.with_suffix('')}.nested_func")
assert callable(func)
assert func() == "nested"
def test_import_missing_module(self, temp_dir):
"""Test error when module doesn't exist."""
with pytest.raises(ImportError, match="Module file not found"):
import_fun_from_str(f"{temp_dir}/nonexistent.test_func")
def test_import_missing_function(self, temp_dir):
"""Test error when function doesn't exist in module."""
module_path = temp_dir / "test_module.py"
module_path.write_text("def other_func():\n pass\n")
with pytest.raises(AttributeError, match="Function 'missing_func' not found"):
import_fun_from_str(f"{module_path.with_suffix('')}.missing_func")
def test_import_invalid_format(self):
"""Test error with invalid path format."""
with pytest.raises(ValueError, match="Invalid path format"):
import_fun_from_str("/path/without/function")
def test_import_caching(self, temp_dir):
"""Test that modules are cached by mtime."""
# Clear any existing cache
import sys
keys_to_remove = [k for k in sys.modules if str(temp_dir) in k]
for k in keys_to_remove:
del sys.modules[k]
module_path = temp_dir / "cached_module.py"
module_path.write_text(
"call_count = 0\ndef func():\n global call_count\n call_count += 1\n return call_count\n"
)
# First import
func1 = import_fun_from_str(f"{module_path.with_suffix('')}.func")
_result1 = func1()
# Second import should use cached module
func2 = import_fun_from_str(f"{module_path.with_suffix('')}.func")
result2 = func2()
# Both should refer to the same module instance
assert func1 is func2
assert result2 == 2 # call_count incremented
class TestLoad: class TestLoad:
"""Tests for the main YAML loading function with includes and function resolution.""" """Tests for the main YAML loading function with includes and function resolution."""
...@@ -237,8 +322,10 @@ doc_to_text: !function utils.process_doc ...@@ -237,8 +322,10 @@ doc_to_text: !function utils.process_doc
result = load_yaml(file_path, resolve_functions=False) result = load_yaml(file_path, resolve_functions=False)
assert callable(result["doc_to_text"]) assert isinstance(result["doc_to_text"], str)
assert result["doc_to_text"]("hello") is None # No-op lambda # When resolve_functions=False, it returns the full path + function spec
assert result["doc_to_text"].endswith("utils.process_doc")
assert result["doc_to_text"] == str(file_path.parent / "utils.process_doc")
def test_load_with_includes(self, temp_dir, yaml_file): def test_load_with_includes(self, temp_dir, yaml_file):
"""Include files are merged with local values taking precedence.""" """Include files are merged with local values taking precedence."""
...@@ -388,3 +475,7 @@ shared_key: from_main ...@@ -388,3 +475,7 @@ shared_key: from_main
mock_expand.assert_called_once() mock_expand.assert_called_once()
assert result["test"] == "value" assert result["test"] == "value"
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])
""" """Tests for the task index builder that discovers YAML task configurations.
Tests for the task index builder that discovers YAML task configurations.
Test coverage: Test coverage:
- TaskIndexBuilder._kind_of: identifies task/group/tag/task_list/py_task - TaskIndexBuilder._kind_of: identifies task/group/tag/task_list/py_task
...@@ -14,7 +13,7 @@ from pathlib import Path ...@@ -14,7 +13,7 @@ from pathlib import Path
import pytest import pytest
from lm_eval.tasks._task_index import TaskIndexBuilder, TaskKind from lm_eval.tasks._task_index import TaskIndex, TaskKind
@pytest.fixture @pytest.fixture
...@@ -40,28 +39,28 @@ class TestTaskKindOf: ...@@ -40,28 +39,28 @@ class TestTaskKindOf:
def test_kind_of_task(self): def test_kind_of_task(self):
"""Single task with string name.""" """Single task with string name."""
cfg = {"task": "my_task", "dataset_path": "data"} cfg = {"task": "my_task", "dataset_path": "data"}
assert TaskIndexBuilder._kind_of(cfg) == TaskKind.TASK assert TaskIndex._kind_of(cfg) == TaskKind.TASK
def test_kind_of_group(self): def test_kind_of_group(self):
"""Group has task as list.""" """Group has task as list."""
cfg = {"task": ["task1", "task2"], "group": "my_group"} cfg = {"task": ["task1", "task2"], "group": "my_group"}
assert TaskIndexBuilder._kind_of(cfg) == TaskKind.GROUP assert TaskIndex._kind_of(cfg) == TaskKind.GROUP
def test_kind_of_py_task(self): def test_kind_of_py_task(self):
"""Python task has class field.""" """Python task has class field."""
cfg = {"task": "my_task", "class": "tasks.MyTask"} cfg = {"task": "my_task", "class": "tasks.MyTask"}
assert TaskIndexBuilder._kind_of(cfg) == TaskKind.PY_TASK assert TaskIndex._kind_of(cfg) == TaskKind.PY_TASK
def test_kind_of_task_list(self): def test_kind_of_task_list(self):
"""Task list has task_list field.""" """Task list has task_list field."""
cfg = {"task_list": ["task1", "task2"]} cfg = {"task_list": ["task1", "task2"]}
assert TaskIndexBuilder._kind_of(cfg) == TaskKind.TASK_LIST assert TaskIndex._kind_of(cfg) == TaskKind.TASK_LIST
def test_kind_of_unknown(self): def test_kind_of_unknown(self):
"""Unknown config raises ValueError.""" """Unknown config raises ValueError."""
cfg = {"unknown": "field"} cfg = {"unknown": "field"}
with pytest.raises(ValueError, match="Unknown config shape"): with pytest.raises(ValueError, match="Unknown config shape"):
TaskIndexBuilder._kind_of(cfg) TaskIndex._kind_of(cfg)
class TestIterYamlFiles: class TestIterYamlFiles:
...@@ -75,8 +74,8 @@ class TestIterYamlFiles: ...@@ -75,8 +74,8 @@ class TestIterYamlFiles:
(temp_dir / "subdir" / "task2.yaml").touch() (temp_dir / "subdir" / "task2.yaml").touch()
(temp_dir / "other.txt").touch() (temp_dir / "other.txt").touch()
builder = TaskIndexBuilder() builder = TaskIndex()
yaml_files = list(builder._iter_yaml_files(temp_dir)) yaml_files = list(builder._iter_yaml_files())
assert len(yaml_files) == 2 assert len(yaml_files) == 2
names = {f.name for f in yaml_files} names = {f.name for f in yaml_files}
...@@ -90,8 +89,8 @@ class TestIterYamlFiles: ...@@ -90,8 +89,8 @@ class TestIterYamlFiles:
(temp_dir / ".ipynb_checkpoints").mkdir() (temp_dir / ".ipynb_checkpoints").mkdir()
(temp_dir / ".ipynb_checkpoints" / "also_ignored.yaml").touch() (temp_dir / ".ipynb_checkpoints" / "also_ignored.yaml").touch()
builder = TaskIndexBuilder() builder = TaskIndex()
yaml_files = list(builder._iter_yaml_files(temp_dir)) yaml_files = list(builder._iter_yaml_files())
assert len(yaml_files) == 1 assert len(yaml_files) == 1
assert yaml_files[0].name == "task.yaml" assert yaml_files[0].name == "task.yaml"
...@@ -106,8 +105,8 @@ class TestProcessCfg: ...@@ -106,8 +105,8 @@ class TestProcessCfg:
path = temp_dir / "task.yaml" path = temp_dir / "task.yaml"
index = {} index = {}
builder = TaskIndexBuilder() builder = TaskIndex()
builder._process_cfg(cfg, path, index) builder.process_cfg(cfg, path, index)
assert "my_task" in index assert "my_task" in index
entry = index["my_task"] entry = index["my_task"]
...@@ -122,8 +121,8 @@ class TestProcessCfg: ...@@ -122,8 +121,8 @@ class TestProcessCfg:
path = temp_dir / "group.yaml" path = temp_dir / "group.yaml"
index = {} index = {}
builder = TaskIndexBuilder() builder = TaskIndex()
builder._process_cfg(cfg, path, index) builder.process_cfg(cfg, path, index)
assert "my_group" in index assert "my_group" in index
entry = index["my_group"] entry = index["my_group"]
...@@ -138,8 +137,8 @@ class TestProcessCfg: ...@@ -138,8 +137,8 @@ class TestProcessCfg:
path = temp_dir / "py_task.yaml" path = temp_dir / "py_task.yaml"
index = {} index = {}
builder = TaskIndexBuilder() builder = TaskIndex()
builder._process_cfg(cfg, path, index) builder.process_cfg(cfg, path, index)
assert "py_task" in index assert "py_task" in index
entry = index["py_task"] entry = index["py_task"]
...@@ -154,27 +153,30 @@ class TestProcessCfg: ...@@ -154,27 +153,30 @@ class TestProcessCfg:
"task_list": [ "task_list": [
"simple_task", "simple_task",
{"task": "complex_task", "tag": ["tag1", "tag2"]}, {"task": "complex_task", "tag": ["tag1", "tag2"]},
] ],
} }
path = temp_dir / "list.yaml" path = temp_dir / "list.yaml"
index = {} index = {}
builder = TaskIndexBuilder() builder = TaskIndex()
# The implementation has a bug - it calls entry.get() on string entries # The implementation has a bug - it calls entry.get() on string entries
# This test documents the current behavior which will fail # This test documents the current behavior which will fail
with pytest.raises(AttributeError, match="'str' object has no attribute 'get'"): with pytest.raises(AttributeError, match="'str' object has no attribute 'get'"):
builder._process_cfg(cfg, path, index) builder.process_cfg(cfg, path, index)
def test_process_task_list_dict_entries(self, temp_dir): def test_process_task_list_dict_entries(self, temp_dir):
"""Task list with only dict entries works.""" """Task list with only dict entries works."""
cfg = { cfg = {
"task_list": [{"task": "task1"}, {"task": "task2", "tag": ["tag1", "tag2"]}] "task_list": [
{"task": "task1"},
{"task": "task2", "tag": ["tag1", "tag2"]},
],
} }
path = temp_dir / "list.yaml" path = temp_dir / "list.yaml"
index = {} index = {}
builder = TaskIndexBuilder() builder = TaskIndex()
builder._process_cfg(cfg, path, index) builder.process_cfg(cfg, path, index)
# Task without tags # Task without tags
assert "task1" in index assert "task1" in index
...@@ -197,7 +199,7 @@ class TestRegisterTags: ...@@ -197,7 +199,7 @@ class TestRegisterTags:
def test_register_single_tag(self): def test_register_single_tag(self):
"""Single tag creates TAG entry.""" """Single tag creates TAG entry."""
index = {} index = {}
builder = TaskIndexBuilder() builder = TaskIndex()
builder._register_tags("task1", "my_tag", index) builder._register_tags("task1", "my_tag", index)
...@@ -210,7 +212,7 @@ class TestRegisterTags: ...@@ -210,7 +212,7 @@ class TestRegisterTags:
def test_register_multiple_tags(self): def test_register_multiple_tags(self):
"""Multiple tags create multiple TAG entries.""" """Multiple tags create multiple TAG entries."""
index = {} index = {}
builder = TaskIndexBuilder() builder = TaskIndex()
builder._register_tags("task1", ["tag1", "tag2"], index) builder._register_tags("task1", ["tag1", "tag2"], index)
...@@ -222,7 +224,7 @@ class TestRegisterTags: ...@@ -222,7 +224,7 @@ class TestRegisterTags:
def test_register_tags_accumulates(self): def test_register_tags_accumulates(self):
"""Multiple tasks can have same tag.""" """Multiple tasks can have same tag."""
index = {} index = {}
builder = TaskIndexBuilder() builder = TaskIndex()
builder._register_tags("task1", "shared_tag", index) builder._register_tags("task1", "shared_tag", index)
builder._register_tags("task2", "shared_tag", index) builder._register_tags("task2", "shared_tag", index)
...@@ -237,7 +239,7 @@ class TestBuild: ...@@ -237,7 +239,7 @@ class TestBuild:
def test_build_empty_directory(self, temp_dir): def test_build_empty_directory(self, temp_dir):
"""Empty directory returns empty index.""" """Empty directory returns empty index."""
builder = TaskIndexBuilder() builder = TaskIndex()
index = builder.build([temp_dir]) index = builder.build([temp_dir])
assert index == {} assert index == {}
...@@ -245,7 +247,7 @@ class TestBuild: ...@@ -245,7 +247,7 @@ class TestBuild:
"""Single task file is discovered.""" """Single task file is discovered."""
yaml_file("task: my_task\ndataset_path: data\n") yaml_file("task: my_task\ndataset_path: data\n")
builder = TaskIndexBuilder() builder = TaskIndex()
index = builder.build([temp_dir]) index = builder.build([temp_dir])
assert len(index) == 1 assert len(index) == 1
...@@ -269,7 +271,7 @@ class TestBuild: ...@@ -269,7 +271,7 @@ class TestBuild:
# Python task # Python task
yaml_file("task: py_task\nclass: MyClass\n", "python.yaml") yaml_file("task: py_task\nclass: MyClass\n", "python.yaml")
builder = TaskIndexBuilder() builder = TaskIndex()
index = builder.build([temp_dir]) index = builder.build([temp_dir])
# Check all entries exist # Check all entries exist
...@@ -297,7 +299,7 @@ class TestBuild: ...@@ -297,7 +299,7 @@ class TestBuild:
yaml_file("task: sub_task\n", "subdir/sub.yaml") yaml_file("task: sub_task\n", "subdir/sub.yaml")
yaml_file("task: deep_task\n", "subdir/deeper/deep.yaml") yaml_file("task: deep_task\n", "subdir/deeper/deep.yaml")
builder = TaskIndexBuilder() builder = TaskIndex()
index = builder.build([temp_dir]) index = builder.build([temp_dir])
assert len(index) == 3 assert len(index) == 3
...@@ -308,7 +310,7 @@ class TestBuild: ...@@ -308,7 +310,7 @@ class TestBuild:
yaml_file("task: valid_task\n", "valid.yaml") yaml_file("task: valid_task\n", "valid.yaml")
yaml_file("invalid: [\n", "invalid.yaml") # Invalid YAML yaml_file("invalid: [\n", "invalid.yaml") # Invalid YAML
builder = TaskIndexBuilder() builder = TaskIndex()
index = builder.build([temp_dir]) index = builder.build([temp_dir])
assert len(index) == 1 assert len(index) == 1
...@@ -325,7 +327,7 @@ class TestBuild: ...@@ -325,7 +327,7 @@ class TestBuild:
(dir1 / "task1.yaml").write_text("task: task1\n") (dir1 / "task1.yaml").write_text("task: task1\n")
(dir2 / "task2.yaml").write_text("task: task2\n") (dir2 / "task2.yaml").write_text("task: task2\n")
builder = TaskIndexBuilder() builder = TaskIndex()
index = builder.build([dir1, dir2]) index = builder.build([dir1, dir2])
assert len(index) == 2 assert len(index) == 2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment