Commit 6e1866f5 authored by Baber
Browse files

add task factory

parent 4254c7bd
from __future__ import annotations
import functools
import importlib.util
import sys
from pathlib import Path
......@@ -15,8 +14,6 @@ _Base = (
_IGNORE_DIRS = {"__pycache__", ".ipynb_checkpoints"}
# --------------------------------------------------------------------------- helpers
@functools.lru_cache(128)
def _mk_function_ctor(base_dir: Path, resolve: bool):
def ctor(loader: yaml.Loader, node: yaml.Node):
spec = loader.construct_scalar(node) # type: ignore[arg-type]
......@@ -27,7 +24,6 @@ def _mk_function_ctor(base_dir: Path, resolve: bool):
return ctor
@functools.lru_cache(maxsize=512)
def _make_loader(base_dir: Path, *, resolve_funcs: bool) -> type[yaml.Loader]:
class Loader(_Base): ... # type: ignore[no-redef]
......@@ -39,7 +35,61 @@ def _make_loader(base_dir: Path, *, resolve_funcs: bool) -> type[yaml.Loader]:
return Loader
@functools.lru_cache(maxsize=128)
def _load_module_with_cache(module_path: Path) -> Any:
"""Load a module from a file path with caching and hot-reload support.
Args:
module_path: Path to the Python file to load
Returns:
The loaded module
"""
# Determine module name based on location
path_str = str(module_path)
# Check if this is a built-in task module
if "/lm_eval/tasks/" in path_str:
# Find the position of lm_eval/tasks/ in the path
tasks_idx = path_str.find("/lm_eval/tasks/")
if tasks_idx != -1:
# Extract path starting from lm_eval/tasks/
# e.g., /path/to/lm_eval/tasks/hellaswag/utils.py → hellaswag/utils.py
relative_path = path_str[tasks_idx + len("/lm_eval/tasks/") :]
# Remove .py and convert to module name
# e.g., hellaswag/utils.py → lm_eval.tasks.hellaswag.utils
module_parts = relative_path.replace(".py", "").replace("/", ".")
module_name = f"lm_eval.tasks.{module_parts}"
else:
# Fallback to full path if pattern not found
module_name = str(module_path.with_suffix(""))
else:
# External module - use full path without extension
module_name = str(module_path.with_suffix(""))
# Check if we need to reload the module
if module_name in sys.modules:
existing_module = sys.modules[module_name]
# Check if it was modified
current_mtime = module_path.stat().st_mtime_ns
if (
hasattr(existing_module, "__mtime__")
and existing_module.__mtime__ == current_mtime
):
# Module hasn't changed, reuse it
return existing_module
# Load or reload the module
spec = importlib.util.spec_from_file_location(module_name, module_path)
if spec is None or spec.loader is None:
raise ImportError(f"Cannot load module from {module_path}") from None
module = importlib.util.module_from_spec(spec)
# Store mtime for future checks
module.__mtime__ = module_path.stat().st_mtime_ns
spec.loader.exec_module(module) # type: ignore[arg-type]
sys.modules[module_name] = module
return module
def _import_func_in_yml(qual: str, base_dir: Path):
"""Import function from qual: utils.process_doc, checking local files first then standard imports.
......@@ -48,26 +98,17 @@ def _import_func_in_yml(qual: str, base_dir: Path):
base_dir: Directory to search for local modules
"""
mod_path, _, fn_name = qual.rpartition(".")
# 1) relative utils.py next to YAML
# 1) relative "utils.py" next to YAML
rel = (base_dir / f"{mod_path.replace('.', '/')}.py").resolve()
if rel.exists():
mtime = rel.stat().st_mtime_ns
key = f"{rel}:{mtime}" # one module per mtime
if key not in sys.modules:
spec = importlib.util.spec_from_file_location(key, rel)
if spec is None or spec.loader is None:
raise ImportError(f"Cannot load module from {rel}") from None
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod) # type: ignore[arg-type]
sys.modules[key] = mod
return getattr(sys.modules[key], fn_name)
module = _load_module_with_cache(rel)
return getattr(module, fn_name)
# 2) already-importable module
module = __import__(mod_path, fromlist=[fn_name])
return getattr(module, fn_name)
@functools.lru_cache(maxsize=128)
def _import_fun_from_str(path_str: str) -> Any:
"""Import a function from a string in the form '/absolute/path/to/module.function_name'."""
try:
......@@ -92,19 +133,7 @@ def _import_fun_from_str(path_str: str) -> Any:
if not module_path.exists():
raise ImportError(f"Module file not found: {module_path}")
# Use similar approach to _import_func_in_yml for consistency
mtime = module_path.stat().st_mtime_ns
cache_key = f"{module_path}:{mtime}"
if cache_key not in sys.modules:
spec = importlib.util.spec_from_file_location(cache_key, module_path)
if spec is None or spec.loader is None:
raise ImportError(f"Cannot load module from {module_path}") from None
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
sys.modules[cache_key] = module
module = sys.modules[cache_key]
module = _load_module_with_cache(module_path)
if not hasattr(module, function_name):
raise AttributeError(
......
......@@ -3,7 +3,6 @@ from __future__ import annotations
import inspect
from collections.abc import Mapping
from copy import deepcopy
from functools import lru_cache
from typing import Any
from lm_eval.api.group import GroupConfig
......@@ -12,7 +11,7 @@ from lm_eval.tasks._config_loader import load_yaml as load_cfg
from lm_eval.tasks.index import Entry, Kind
load_cfg_cached = lru_cache(maxsize=512)(load_cfg) # type: ignore[no-redef]
load_cfg_cached = load_cfg # type: ignore[no-redef]
class TaskFactory:
......@@ -108,7 +107,6 @@ class TaskFactory:
) -> dict[str, Any]:
if entry.yaml_path:
cfg = deepcopy(load_cfg_cached(entry.yaml_path, resolve_functions=True))
print(f"Loaded task config from {load_cfg_cached.cache_info()}")
else:
cfg = {"metadata": {"config": "unknown"}} # python task without YAML
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment