Commit a9c16905 authored by Baber

refactor: migrate utils functions to lm_eval.tasks and update references

parent e11fa05d
......@@ -7,6 +7,8 @@ from functools import partial
from pathlib import Path
from typing import Union
import lm_eval.tasks
def try_parse_json(value: str) -> Union[str, dict, None]:
if value is None:
......@@ -401,14 +403,14 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
task_names = []
yaml_path = os.path.join(args.tasks, "*.yaml")
for yaml_file in glob.glob(yaml_path):
config = utils.load_yaml_config(yaml_file)
config = lm_eval.tasks.load_yaml_config(yaml_file)
task_names.append(config)
else:
task_list = args.tasks.split(",")
task_names = task_manager.match_tasks(task_list)
for task in [task for task in task_list if task not in task_names]:
if os.path.isfile(task):
config = utils.load_yaml_config(task)
config = lm_eval.tasks.load_yaml_config(task)
task_names.append(config)
task_missing = [
task for task in task_list if task not in task_names and "*" not in task
......
......@@ -24,6 +24,7 @@ import datasets
import numpy as np
from tqdm import tqdm
import lm_eval.tasks
from lm_eval import utils
from lm_eval.api import samplers
from lm_eval.api.instance import Instance, OutputType
......@@ -1124,7 +1125,7 @@ class ConfigurableTask(Task):
# get task description
if description := self.config.description:
description = utils.apply_template(self.config.description, doc)
description = lm_eval.tasks.apply_template(self.config.description, doc)
# create system prompt based on the provided system instruction and description
if system_instruction is not None and description:
......@@ -1259,7 +1260,7 @@ class ConfigurableTask(Task):
return doc_to_decontamination_query(doc)
else:
return ast.literal_eval(
utils.apply_template(
lm_eval.tasks.apply_template(
self.config.doc_to_decontamination_query, doc
)
)
......@@ -1292,7 +1293,7 @@ class ConfigurableTask(Task):
# else:
return doc[doc_to_text]
else:
text_string = utils.apply_template(doc_to_text, doc)
text_string = lm_eval.tasks.apply_template(doc_to_text, doc)
if text_string.isdigit() and self._config.doc_to_choice is not None:
return ast.literal_eval(text_string)
else:
......@@ -1328,7 +1329,7 @@ class ConfigurableTask(Task):
# else:
return doc[doc_to_target]
else:
target_string = utils.apply_template(doc_to_target, doc)
target_string = lm_eval.tasks.apply_template(doc_to_target, doc)
if target_string.isdigit() and self._config.doc_to_choice is not None:
return ast.literal_eval(target_string)
elif (
......@@ -1371,7 +1372,9 @@ class ConfigurableTask(Task):
if doc_to_choice in self.features:
return doc[doc_to_choice]
else:
return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
return ast.literal_eval(
lm_eval.tasks.apply_template(doc_to_choice, doc)
)
elif isinstance(doc_to_choice, list):
return doc_to_choice
elif isinstance(doc_to_choice, dict):
......@@ -1400,7 +1403,7 @@ class ConfigurableTask(Task):
if doc_to_image in self.features:
return doc[doc_to_image]
else:
return ast.literal_eval(utils.apply_template(doc_to_image, doc))
return ast.literal_eval(lm_eval.tasks.apply_template(doc_to_image, doc))
elif callable(doc_to_image):
return doc_to_image(doc)
else:
......@@ -1423,7 +1426,7 @@ class ConfigurableTask(Task):
if doc_to_audio in self.features:
return doc[doc_to_audio]
else:
return ast.literal_eval(utils.apply_template(doc_to_audio, doc))
return ast.literal_eval(lm_eval.tasks.apply_template(doc_to_audio, doc))
elif callable(doc_to_audio):
return doc_to_audio(doc)
else:
......@@ -1434,7 +1437,7 @@ class ConfigurableTask(Task):
if gen_prefix in self.features:
return doc[gen_prefix]
else:
return utils.apply_template(gen_prefix, doc)
return lm_eval.tasks.apply_template(gen_prefix, doc)
return None
def construct_requests(
......
......@@ -3,6 +3,7 @@ import logging
import os
from typing import Dict
import lm_eval.tasks
from lm_eval import utils
......@@ -122,7 +123,7 @@ class PromptString:
if "doc_to_choice" in self.prompt_string:
raise NotImplementedError("Not yet implemented to accept doc_to_choice")
text_string = utils.apply_template(doc_to_text, doc)
target_string = utils.apply_template(doc_to_target, doc)
text_string = lm_eval.tasks.apply_template(doc_to_text, doc)
target_string = lm_eval.tasks.apply_template(doc_to_target, doc)
return [text_string, target_string]
import collections
import functools
import importlib.util
import inspect
import logging
import os
import re
import sys
from functools import partial
from glob import iglob
from pathlib import Path
from typing import Dict, List, Mapping, Optional, Union
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Union
import yaml
from jinja2 import BaseLoader, Environment, StrictUndefined
from yaml import YAMLError
from lm_eval import utils
from lm_eval.api.group import ConfigurableGroup, GroupConfig
......@@ -13,8 +21,191 @@ from lm_eval.evaluator_utils import get_subtask_list
GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys())
_CONFIG_CACHE: dict[tuple[Path, str], dict] = {}
eval_logger = logging.getLogger(__name__)
_Base = yaml.CLoader if getattr(yaml, "__with_libyaml__", False) else yaml.FullLoader
@functools.lru_cache(maxsize=128) # ← reuse per (directory, simple) pair
def _make_loader(yaml_dir: Path, simple: bool = False) -> type[yaml.Loader]:
"""
Return a custom YAML Loader class bound to *yaml_dir*.
yaml_dir
Directory that holds the YAML file being parsed.
We capture it so that !function look-ups can resolve relative
Python files like my_utils.some_fn ➜ yaml_dir / "my_utils.py".
simple
If True we ignore !function completely (used by `mode="simple"`).
"""
class Loader(_Base):
"""Dynamically-generated loader that knows its base directory."""
# no extra state needed; the constructor stays the same
# Register (or stub) the !function constructor **for this Loader only**
if simple:
yaml.add_constructor("!function", lambda *_: None, Loader=Loader)
else:
yaml.add_constructor(
"!function",
# bind yaml_dir as a default argument so the constructor resolves paths relative to this YAML's directory
lambda ld, node, _dir=yaml_dir: _import_function(
ld.construct_scalar(node),
base_path=_dir,
),
Loader=Loader,
)
return Loader
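A minimal usage sketch for the loader factory above, assuming a hypothetical task directory lm_eval/tasks/my_task/ whose my_task.yaml contains a `process_docs: !function utils.process_docs` entry and a sibling utils.py defining process_docs:
task_dir = Path("lm_eval/tasks/my_task")                 # hypothetical directory
full_loader = _make_loader(task_dir)                     # full mode: !function is resolved
cfg = yaml.load((task_dir / "my_task.yaml").read_bytes(), Loader=full_loader)
assert callable(cfg["process_docs"])                     # imported from my_task/utils.py
simple_loader = _make_loader(task_dir, simple=True)      # simple mode: the tag parses to None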
@functools.lru_cache(maxsize=1000) # ← cache module objects
def _import_function(qualname: str, *, base_path: Path) -> Callable:
mod_path, _, func_name = qualname.rpartition(".")
if not mod_path:
raise ValueError(f"{qualname!r} has no module part")
file_path = base_path / f"{mod_path.replace('.', '/')}.py"
module_name = f"_yaml_dynamic.{hash(file_path)}_{file_path.stem}"
if module_name in sys.modules:
mod = sys.modules[module_name]
else:
spec = importlib.util.spec_from_file_location(module_name, file_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
sys.modules[module_name] = mod
return getattr(mod, func_name)
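For instance (hypothetical names), a tag value of my_utils.build_prompt inside lm_eval/tasks/foo/ resolves like this:
fn = _import_function("my_utils.build_prompt", base_path=Path("lm_eval/tasks/foo"))
# executes lm_eval/tasks/foo/my_utils.py once, keeps it in sys.modules, and returns its build_prompt attribute
# a dotted module part maps to a nested file: "helpers.metrics.f1" -> helpers/metrics.py, attribute f1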
def ignore_constructor(loader: yaml.Loader, node: yaml.Node) -> None:
return None
@functools.lru_cache(maxsize=1000) # ← cache parsed YAML per (path, mode)
def _parse_yaml_file(path: Path, mode: str) -> dict:
loader_cls = _make_loader(path.parent, simple=(mode == "simple"))
with path.open("rb") as fh:
return yaml.load(fh, Loader=loader_cls)
def load_yaml_config(
yaml_path: Union[Path, str, None] = None,
yaml_config: dict | None = None,
yaml_dir: Path | None = None,
mode: str = "full",
*,
_seen: set[tuple[Path, str]] | None = None,
resolve_includes: bool = True,
) -> dict:
"""
Parse a YAML config with optional include handling.
Parameters
----------
yaml_path
Path to the main YAML file. Needed unless *yaml_config* is
supplied directly (e.g. by tests).
yaml_config
Pre-parsed dict to use instead of reading *yaml_path*.
yaml_dir
Base directory for resolving relative include paths. Defaults
to `yaml_path.parent`.
mode
"full" – honour !function tags
"simple" – ignore !function (faster).
_seen
**Internal** recursion set: tuples of (absolute-path, mode).
Prevents include cycles such as A → B → A.
resolve_includes
If False, skip `include:` merging and return only the file's own keys.
"""
if yaml_config is None and yaml_path is None:
raise ValueError("load_yaml_config needs either yaml_path or yaml_config")
# ------------------------------------------------------------------ cycle guard
if _seen is None:
_seen = set()
if yaml_path is not None:
yaml_path = Path(yaml_path).expanduser().resolve()
# ---------- fast-path: return memoised, already-resolved cfg ----------
cache_key = (yaml_path, mode)
if yaml_config is None and resolve_includes and cache_key in _CONFIG_CACHE:
return _CONFIG_CACHE[cache_key]
key = (yaml_path.resolve(), mode)
if key in _seen:
raise ValueError(f"Include cycle detected at {yaml_path}")
_seen.add(key)
# ------------------------------------------------------------------ load / parse
loaded_from_path = yaml_config is None
if loaded_from_path: # ordinary path-based load
# shallow-copy the lru-cached parse result so the include-pop below cannot mutate it
yaml_config = dict(_parse_yaml_file(yaml_path, mode))
if yaml_dir is None and yaml_path is not None:
yaml_dir = yaml_path.parent
assert yaml_dir is not None, "yaml_dir must be set by caller or deduced from path"
# ------------------------------------------------------------------ handle include
include = yaml_config.pop("include", None)
if not include or not resolve_includes:
return yaml_config
include_paths = include if isinstance(include, list) else [include]
final_cfg: dict = {}
for inc in reversed(include_paths):
if inc is None: # guard against explicit nulls
continue
inc_path = Path(inc)
if not inc_path.is_absolute():
inc_path = (yaml_dir / inc_path).resolve()
included = load_yaml_config(
yaml_path=inc_path,
mode=mode,
yaml_dir=inc_path.parent,
_seen=_seen, # <-- pass set downward
)
final_cfg.update(included)
final_cfg.update(yaml_config) # local keys win
# -------- memoise after *all* includes are merged ----------
if loaded_from_path and resolve_includes:
_CONFIG_CACHE[cache_key] = final_cfg
return final_cfg
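A sketch of the include semantics with hypothetical files base.yaml and my_task.yaml in the same directory: includes are merged last-listed-first, the including file's own keys always win, and a self-including file is rejected by the _seen cycle guard.
# base.yaml:     output_type: generate_until
#                num_fewshot: 0
# my_task.yaml:  include: base.yaml
#                task: my_task
#                num_fewshot: 5
cfg = load_yaml_config("lm_eval/tasks/my_task/my_task.yaml")
assert cfg["task"] == "my_task"
assert cfg["num_fewshot"] == 5                  # local key overrides the include
assert cfg["output_type"] == "generate_until"   # inherited from base.yaml
assert "include" not in cfg                     # consumed during resolution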
def regex_replace(string, pattern, repl, count: int = 0):
"""Implements the `re.sub` function as a custom Jinja filter."""
return re.sub(pattern, repl, string, count=count)
@functools.lru_cache(maxsize=256)
def _compile_tpl(src: str):
return apply_template._env.from_string(src)
def apply_template(template: str, doc: dict) -> str:
if not hasattr(apply_template, "_env"):
apply_template._env = Environment(
loader=BaseLoader(),
undefined=StrictUndefined,
keep_trailing_newline=True,
)
apply_template._env.filters["regex_replace"] = regex_replace
return _compile_tpl(template).render(**doc)
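For example (field names are illustrative), rendering a doc through the cached environment and the custom regex_replace filter:
doc = {"question": "What is 2 + 2?", "answer": "  4  "}
apply_template("Q: {{ question }}\nA:", doc)                       # -> "Q: What is 2 + 2?\nA:"
apply_template(r"{{ answer | regex_replace('\s+', '') }}", doc)    # -> "4"
Because of StrictUndefined, a template that references a field missing from the doc raises instead of silently rendering an empty string.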
def iter_yaml_files(root: Path) -> Generator[Path, Any, None]:
# '**/*.yaml' is handled internally by os.scandir.
for path in iglob("**/*.yaml", root_dir=root, recursive=True):
# quick ignore check
if "/__pycache__/" in path or "/.ipynb_checkpoints/" in path:
continue
yield root / path
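A one-line usage sketch (directory is illustrative):
task_yamls = list(iter_yaml_files(Path("lm_eval/tasks")))   # every *.yaml under the root, __pycache__ and .ipynb_checkpoints entries skipped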
class TaskManager:
......@@ -113,68 +304,99 @@ class TaskManager:
list_tags: bool = True,
list_subtasks: bool = True,
) -> str:
"""
Return a Markdown table (as a string) listing groups, tags and/or subtasks
known to this TaskManager. Safe for configs whose yaml_path is -1 and for
task configs whose `include:` is a list.
"""
from pytablewriter import MarkdownTableWriter
# ------------------------------------------------------------------ helpers
def sanitize_path(path: str) -> str:
# don't print full path if we are within the lm_eval/tasks dir !
# if we aren't though, provide the full path.
# print a relative path for anything inside lm_eval/tasks/
# path_str = str(path)
if "lm_eval/tasks/" in path:
return "lm_eval/tasks/" + path.split("lm_eval/tasks/")[-1]
else:
return path
return path
def first_output_type_from_includes(cfg: dict, base: Path) -> str:
"""Walk cfg['include'] (string or list) and return the first
include that itself specifies an output_type."""
inc_raw = cfg.get("include")
if not inc_raw:
return ""
inc_list = inc_raw if isinstance(inc_raw, list) else [inc_raw]
for inc in inc_list:
inc_path = Path(inc)
if not inc_path.is_absolute(): # treat as relative include
inc_path = base.parent / inc_path
try:
inc_cfg = load_yaml_config(inc_path, mode="simple")
except FileNotFoundError:
continue
if "output_type" in inc_cfg:
return inc_cfg["output_type"]
return ""
# -------------------------------------------------------------- GROUP table
group_table = MarkdownTableWriter()
group_table.headers = ["Group", "Config Location"]
gt_values = []
for g in self.all_groups:
path = self.task_index[g]["yaml_path"]
if path == -1:
path = "---"
else:
path = sanitize_path(path)
gt_values.append([g, path])
group_table.value_matrix = gt_values
group_table.value_matrix = [
[
g,
"---"
if self.task_index[g]["yaml_path"] == -1
else sanitize_path(self.task_index[g]["yaml_path"]),
]
for g in self.all_groups
]
# ---------------------------------------------------------------- TAG table
tag_table = MarkdownTableWriter()
tag_table.headers = ["Tag"]
tag_table.value_matrix = [[t] for t in self.all_tags]
# ------------------------------------------------------------ SUBTASK table
subtask_table = MarkdownTableWriter()
subtask_table.headers = ["Task", "Config Location", "Output Type"]
st_values = []
st_values: list[list[str]] = []
for t in self.all_subtasks:
path = self.task_index[t]["yaml_path"]
output_type = ""
# read the yaml file to determine the output type
if path != -1:
config = utils.load_yaml_config(path, mode="simple")
if "output_type" in config:
output_type = config["output_type"]
elif (
"include" in config
): # if no output type, check if there is an include with an output type
include_path = path.split("/")[:-1] + config["include"]
include_config = utils.load_yaml_config(include_path, mode="simple")
if "output_type" in include_config:
output_type = include_config["output_type"]
if path == -1:
path = "---"
raw_path = self.task_index[t]["yaml_path"]
if raw_path == -1:
# python-only task or generated at runtime
display_path = "---"
output_type = ""
else:
path = sanitize_path(path)
st_values.append([t, path, output_type])
path_obj = Path(raw_path)
display_path = sanitize_path(str(path_obj))
# load minimal YAML to discover output_type
cfg = load_yaml_config(path_obj, mode="simple")
if "output_type" in cfg:
output_type = cfg["output_type"]
else:
output_type = first_output_type_from_includes(cfg, path_obj)
st_values.append([t, display_path, output_type])
subtask_table.value_matrix = st_values
result = "\n"
# ------------------------------------------------------------- final string
parts: list[str] = ["\n"]
if list_groups:
result += group_table.dumps() + "\n\n"
parts.append(group_table.dumps())
parts.append("\n")
if list_tags:
result += tag_table.dumps() + "\n\n"
parts.append(tag_table.dumps())
parts.append("\n")
if list_subtasks:
result += subtask_table.dumps() + "\n\n"
return result
parts.append(subtask_table.dumps())
parts.append("\n")
return "".join(parts)
def match_tasks(self, task_list: list[str]) -> list[str]:
return utils.pattern_match(task_list, self.all_tasks)
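A hedged usage sketch of the TaskManager surfaces touched above, assuming the default task index:
tm = TaskManager()
print(tm.list_all_tasks(list_groups=True, list_tags=False, list_subtasks=False))
matched = tm.match_tasks(["arc_*"])     # wildcard patterns expand against tm.all_tasks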
......@@ -225,7 +447,7 @@ class TaskManager:
if yaml_path == -1:
return {}
else:
return utils.load_yaml_config(yaml_path, mode="full")
return load_yaml_config(Path(yaml_path), mode="full")
def _get_tasklist(self, name: str) -> Union[List[str], int]:
if self._name_is_task(name):
......@@ -302,8 +524,8 @@ class TaskManager:
original_task_name = config.get("task", task)
config = {
**utils.load_yaml_config(
yaml_path=yaml_path,
**load_yaml_config(
yaml_path=Path(yaml_path),
yaml_config={"include": config.pop("include")},
mode="full" if yaml_path else "simple",
),
......@@ -555,78 +777,79 @@ class TaskManager:
tasks_and_groups[tag]["task"].append(task)
# TODO: remove group in next release
ignore_dirs = [
"__pycache__",
".ipynb_checkpoints",
]
# ignore_dirs = [
# "__pycache__",
# ".ipynb_checkpoints",
# ]
tasks_and_groups = collections.defaultdict()
task_dir_path = Path(task_dir)
for root, dirs, file_list in os.walk(task_dir_path):
dirs[:] = [d for d in dirs if d not in ignore_dirs]
root_path = Path(root)
for f in file_list:
if f.endswith(".yaml"):
yaml_path = root_path / f
config = utils.load_yaml_config(yaml_path, mode="simple")
if self._config_is_python_task(config):
# This is a python class config
task = config["task"]
self._register_task(
task,
"python_task",
str(yaml_path),
tasks_and_groups,
config,
_populate_tags_and_groups,
)
elif self._config_is_group(config):
# This is a group config
tasks_and_groups[config["group"]] = {
"type": "group",
"task": -1, # This signals that
# we don't need to know
# the task list for indexing
# as it can be loaded
# when called.
"yaml_path": str(yaml_path),
}
for yaml_path in iter_yaml_files(task_dir_path):
try:
config = load_yaml_config(
yaml_path, mode="simple", resolve_includes=False
)
except (FileNotFoundError, YAMLError, OSError) as err:
eval_logger.debug(f"File {yaml_path} could not be loaded ({err})")
continue
if self._config_is_python_task(config):
# This is a python class config
task = config["task"]
self._register_task(
task,
"python_task",
str(yaml_path),
tasks_and_groups,
config,
_populate_tags_and_groups,
)
elif self._config_is_group(config):
# This is a group config
tasks_and_groups[config["group"]] = {
"type": "group",
"task": -1, # This signals that
# we don't need to know
# the task list for indexing
# as it can be loaded
# when called.
"yaml_path": str(yaml_path),
}
# # Registered the level 1 tasks from a group config
# for config in config["task"]:
# if isinstance(config, dict) and self._config_is_task(config):
# task = config["task"]
# tasks_and_groups[task] = {
# "type": "task",
# "yaml_path": yaml_path,
# }
elif self._config_is_task(config):
# This is a task config
task = config["task"]
# # Registered the level 1 tasks from a group config
# for config in config["task"]:
# if isinstance(config, dict) and self._config_is_task(config):
# task = config["task"]
# tasks_and_groups[task] = {
# "type": "task",
# "yaml_path": yaml_path,
# }
elif self._config_is_task(config):
# This is a task config
task = config["task"]
self._register_task(
task,
"task",
str(yaml_path),
tasks_and_groups,
config,
_populate_tags_and_groups,
)
elif self._config_is_task_list(config):
# This is a task_list config
for task_entry in config["task_list"]:
if isinstance(task_entry, dict) and "task" in task_entry:
task_name = task_entry["task"]
self._register_task(
task,
task_name,
"task",
str(yaml_path),
tasks_and_groups,
config,
_populate_tags_and_groups,
)
elif self._config_is_task_list(config):
# This is a task_list config
for task_entry in config["task_list"]:
if isinstance(task_entry, dict) and "task" in task_entry:
task_name = task_entry["task"]
self._register_task(
task_name,
"task",
str(yaml_path),
tasks_and_groups,
config,
_populate_tags_and_groups,
)
else:
eval_logger.debug(f"File {f} in {root} could not be loaded")
else:
eval_logger.debug(f"File {yaml_path} could not be loaded")
return tasks_and_groups
......
......@@ -10,12 +10,9 @@ import os
import re
from dataclasses import asdict, is_dataclass
from itertools import islice
from pathlib import Path
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
from typing import Any, Callable, Generator, List, Optional, Tuple
import numpy as np
import yaml
from jinja2 import BaseLoader, Environment, StrictUndefined
SPACING = " " * 47
......@@ -441,114 +438,6 @@ def positional_deprecated(fn):
return _wrapper
def ignore_constructor(loader: yaml.Loader, node: yaml.Node) -> yaml.Node:
return node
def import_function(loader: yaml.Loader, node: yaml.Node, yaml_path: Path) -> Callable:
function_name = loader.construct_scalar(node)
*module_name, function_name = function_name.split(".")
if isinstance(module_name, list):
module_name = ".".join(module_name)
module_path = yaml_path.parent / f"{module_name}.py"
spec = importlib.util.spec_from_file_location(module_name, module_path.as_posix())
if spec is None:
raise ImportError(f"Could not import module {module_name} from {module_path}.")
module = importlib.util.module_from_spec(spec)
if spec.loader is None:
raise ImportError(f"Module loader is None, {module_name} from {module_path}.")
spec.loader.exec_module(module)
function = getattr(module, function_name)
return function
def load_yaml_config(
yaml_path: Optional[Union[str, Path]] = None,
yaml_config: Optional[Dict] = None,
yaml_dir: Optional[Union[str, Path]] = None,
mode: str = "full",
) -> Dict:
# Convert yaml_path to Path object if it's a string
if yaml_path is not None:
yaml_path = Path(yaml_path)
# Convert yaml_dir to Path object if it's a string
if yaml_dir is not None:
yaml_dir = Path(yaml_dir)
if mode == "simple":
constructor_fn = ignore_constructor
elif mode == "full":
if yaml_path is None:
raise ValueError("yaml_path must be provided if mode is 'full'.")
# Attach yaml_path to the import function so that it can be used later
constructor_fn = functools.partial(import_function, yaml_path=yaml_path)
loader = yaml.CLoader if yaml.__with_libyaml__ else yaml.FullLoader
# Add the import_function constructor to the YAML loader
yaml.add_constructor("!function", constructor_fn, Loader=loader)
if yaml_config is None:
with open(yaml_path, "rb") as file:
yaml_config = yaml.load(file, Loader=loader)
if yaml_dir is None and yaml_path is not None:
yaml_dir = yaml_path.parent
assert yaml_dir is not None
if "include" in yaml_config:
include_path = yaml_config["include"]
del yaml_config["include"]
if isinstance(include_path, str):
include_path = [include_path]
# Load from the last one first
include_path.reverse()
final_yaml_config = {}
for path in include_path:
# Convert to Path object
path = Path(path)
# Assumes that path is a full path.
# If not found, assume the included yaml
# is in the same dir as the original yaml
if not path.is_file():
path = yaml_dir / path
try:
included_yaml_config = load_yaml_config(yaml_path=path, mode=mode)
final_yaml_config.update(included_yaml_config)
except Exception as ex:
# If failed to load, ignore
raise ex
final_yaml_config.update(yaml_config)
return final_yaml_config
return yaml_config
def regex_replace(string, pattern, repl, count: int = 0):
"""Implements the `re.sub` function as a custom Jinja filter."""
return re.sub(pattern, repl, string, count=count)
def apply_template(template: str, doc: dict) -> str:
# Lazy initialization - only create Environment when actually needed
if not hasattr(apply_template, "_env"):
apply_template._env = Environment(
loader=BaseLoader(), undefined=StrictUndefined, keep_trailing_newline=True
)
apply_template._env.filters["regex_replace"] = regex_replace
rtemplate = apply_template._env.from_string(template)
return rtemplate.render(**doc)
def create_iterator(raw_iterator, *, rank=0, world_size=1, limit=None):
"""
Method for creating a (potentially) sliced and limited
......
......@@ -11,8 +11,8 @@ from lm_eval.api.metrics import (
stderr_for_metric,
)
from lm_eval.models.utils import Collator
from lm_eval.tasks import apply_template
from lm_eval.utils import (
apply_template,
get_rolling_token_windows,
make_disjoint_window,
)
......
import os
from typing import List, Union
from lm_eval.utils import load_yaml_config
from lm_eval.tasks import load_yaml_config
# {{{CI}}}
......