Commit 495ea3a0 authored by Baber
Browse files

nit

parent 3c969207
......@@ -20,7 +20,6 @@ from typing import (
)
import yaml
from memory_profiler import profile
from yaml import YAMLError
from lm_eval.api.group import ConfigurableGroup, GroupConfig
......@@ -33,14 +32,13 @@ if TYPE_CHECKING:
# Key names of a default-constructed GroupConfig, as returned by its to_dict().
GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys())
# Module-level dict cache annotated as (Path, str) -> dict; load_yaml_config
# keys it by (yaml_path, mode) where it is used.
_CONFIG_CACHE: dict[tuple[Path, str], dict] = {}
# Logger named after this module.
eval_logger = logging.getLogger(__name__)
# Loader base: yaml.CLoader when PyYAML reports `__with_libyaml__`, otherwise
# yaml.FullLoader.
_Base = yaml.CLoader if getattr(yaml, "__with_libyaml__", False) else yaml.FullLoader
@functools.lru_cache(maxsize=128) # ← reuse per (directory, simple) pair
@functools.lru_cache(maxsize=None) # ← reuse per (directory, simple) pair
def _make_loader(yaml_dir: Path, simple: bool = False) -> type[yaml.Loader]:
"""
Return a custom YAML Loader class bound to *yaml_dir*.
......@@ -60,7 +58,7 @@ def _make_loader(yaml_dir: Path, simple: bool = False) -> type[yaml.Loader]:
# Register (or stub) the !function constructor **for this Loader only**
if simple:
yaml.add_constructor("!function", lambda *_: None, Loader=Loader)
yaml.add_constructor("!function", ignore_constructor, Loader=Loader)
else:
yaml.add_constructor(
"!function",
......@@ -75,7 +73,7 @@ def _make_loader(yaml_dir: Path, simple: bool = False) -> type[yaml.Loader]:
return Loader
@functools.lru_cache(maxsize=1000) # ← cache module objects
@functools.lru_cache(maxsize=None) # ← cache module objects
def _import_function(qualname: str, *, base_path: Path) -> Callable:
mod_path, _, func_name = qualname.rpartition(".")
if not mod_path:
......@@ -96,13 +94,42 @@ def ignore_constructor(loader: yaml.Loader, node: yaml.Node) -> None:
return None
@functools.lru_cache(maxsize=None)
def _parse_yaml_file(path: Path, mode: str) -> dict:
    """Parse the YAML file at *path* and memoise the result.

    The loader class comes from ``_make_loader`` and is bound to the file's
    parent directory; ``mode == "simple"`` selects the "simple" loader variant.
    Results are cached per ``(path, mode)`` with no size limit.

    Args:
        path: Location of the YAML file; must be hashable (it is a cache key).
        mode: Loading mode; only the value ``"simple"`` is treated specially.

    Returns:
        Whatever ``yaml.load`` produces for the file. The same object is
        returned on every cache hit, so callers must copy it before mutating.
    """
    loader_cls = _make_loader(path.parent, simple=(mode == "simple"))
    # NOTE(review): yaml.load with a custom Loader can construct arbitrary
    # objects depending on the loader's base class -- only use on trusted
    # task configs.
    with path.open("rb") as fh:
        return yaml.load(fh, Loader=loader_cls)
@functools.lru_cache(maxsize=None)
def _get_cached_config(yaml_path: Path, mode: str) -> dict:
    """Load *yaml_path*, merge in the files it ``include``s, and memoise.

    Results are cached per ``(yaml_path, mode)``. The cache is unbounded
    (``maxsize=None``), so entries are never evicted.

    Merge order: includes are applied last-to-first, so an earlier entry in
    the ``include`` list overrides a later one; keys defined in *yaml_path*
    itself override every included file.

    Args:
        yaml_path: Path of the YAML file to load. Relative include paths are
            resolved against this file's parent directory.
        mode: Loading mode, forwarded unchanged to ``_parse_yaml_file`` and
            to the recursive calls for included files.

    Returns:
        The merged configuration without the ``include`` key. The same dict
        object is returned on every cache hit; copy it before mutating.
    """
    # Work on a shallow copy: `_parse_yaml_file` is memoised, so popping
    # "include" from its return value would strip the key from the cached
    # parse that other callers read.
    yaml_config = _parse_yaml_file(yaml_path, mode).copy()

    include = yaml_config.pop("include", None)
    if not include:
        return yaml_config

    yaml_dir = yaml_path.parent
    include_paths = include if isinstance(include, list) else [include]

    final_cfg: dict = {}
    for inc in reversed(include_paths):
        if inc is None:
            continue
        inc_path = Path(inc)
        if not inc_path.is_absolute():
            inc_path = (yaml_dir / inc_path).resolve()
        # The recursive call is served from the cache when already resolved.
        # NOTE: there is no cycle detection here; files that include each
        # other recurse until RecursionError.
        final_cfg.update(_get_cached_config(inc_path, mode))

    final_cfg.update(yaml_config)  # local keys win
    return final_cfg
def load_yaml_config(
yaml_path: Union[Path, str, None] = None,
yaml_config: dict | None = None,
......@@ -141,10 +168,9 @@ def load_yaml_config(
if yaml_path is not None:
yaml_path = Path(yaml_path).expanduser().resolve()
# ---------- fast-path: return memoised, already-resolved cfg ----------
cache_key = (yaml_path, mode)
if yaml_config is None and resolve_includes and cache_key in _CONFIG_CACHE:
return _CONFIG_CACHE[cache_key]
# ---------- fast-path: use LRU cached function ----------
if yaml_config is None and resolve_includes:
return _get_cached_config(yaml_path, mode)
key = (yaml_path.resolve(), mode)
if key in _seen:
......@@ -182,20 +208,16 @@ def load_yaml_config(
final_cfg.update(included)
final_cfg.update(yaml_config) # local keys win
# -------- memoise after *all* includes are merged ----------
if yaml_config is None and resolve_includes:
_CONFIG_CACHE[cache_key] = final_cfg
return final_cfg
def iter_yaml_files(root: Path) -> Generator[Path, Any, None]:
# '**/*.yaml' is handled internally by os.scandir.
for path in iglob("**/*.yaml", root_dir=root, recursive=True):
# quick ignore check
if "/__pycache__/" in path or "/.ipynb_checkpoints/" in path:
for p in iglob("**/*.yaml", root_dir=root, recursive=True):
# ignore check
if p.startswith(("__pycache__", ".ipynb_checkpoints")):
continue
yield root / path
yield root / p
class TaskManager:
......@@ -204,7 +226,6 @@ class TaskManager:
"""
@profile
def __init__(
self,
verbosity: Optional[str] = None,
......@@ -237,7 +258,6 @@ class TaskManager:
self.task_group_map = collections.defaultdict(list)
@profile
def initialize_tasks(
self,
include_path: Optional[Union[str, Path, List[Union[str, Path]]]] = None,
......@@ -722,7 +742,6 @@ class TaskManager:
def load_config(self, config: Dict) -> Mapping:
return self._load_individual_task_or_group(config)
@profile
def _get_task_and_group(self, task_dir: Union[str, Path]) -> Dict[str, Dict]:
"""Creates a dictionary of tasks index with the following metadata,
- `type`, that can be either `task`, `python_task`, `group` or `tags`.
......@@ -898,7 +917,6 @@ def _check_duplicates(task_dict: Dict[str, List[str]]) -> None:
)
@profile
def get_task_dict(
task_name_list: Union[str, List[Union[str, Dict, "Task"]]],
task_manager: Optional[TaskManager] = None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment