Commit a9c16905 authored by Baber

refactor: migrate utils functions to lm_eval.tasks and update references

parent e11fa05d
@@ -7,6 +7,8 @@ from functools import partial
from pathlib import Path
from typing import Union

+import lm_eval.tasks


def try_parse_json(value: str) -> Union[str, dict, None]:
    if value is None:
@@ -401,14 +403,14 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
            task_names = []
            yaml_path = os.path.join(args.tasks, "*.yaml")
            for yaml_file in glob.glob(yaml_path):
-                config = utils.load_yaml_config(yaml_file)
+                config = lm_eval.tasks.load_yaml_config(yaml_file)
                task_names.append(config)
        else:
            task_list = args.tasks.split(",")
            task_names = task_manager.match_tasks(task_list)
            for task in [task for task in task_list if task not in task_names]:
                if os.path.isfile(task):
-                    config = utils.load_yaml_config(task)
+                    config = lm_eval.tasks.load_yaml_config(task)
                    task_names.append(config)
            task_missing = [
                task for task in task_list if task not in task_names and "*" not in task
...
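
For reference, callers outside the CLI use the relocated helper the same way the branch above does. A minimal sketch, assuming this commit is installed and a hypothetical my_tasks/ directory of task YAMLs:

    # Hypothetical directory of *.yaml task configs; load each one with the
    # loader that now lives in lm_eval.tasks instead of lm_eval.utils.
    import glob
    import os

    import lm_eval.tasks

    task_configs = []
    for yaml_file in glob.glob(os.path.join("my_tasks", "*.yaml")):
        # mode="full" (the default) also resolves any !function tags
        task_configs.append(lm_eval.tasks.load_yaml_config(yaml_file))
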
@@ -24,6 +24,7 @@ import datasets
import numpy as np
from tqdm import tqdm

+import lm_eval.tasks
from lm_eval import utils
from lm_eval.api import samplers
from lm_eval.api.instance import Instance, OutputType
@@ -1124,7 +1125,7 @@ class ConfigurableTask(Task):
        # get task description
        if description := self.config.description:
-            description = utils.apply_template(self.config.description, doc)
+            description = lm_eval.tasks.apply_template(self.config.description, doc)

        # create system prompt based on the provided system instruction and description
        if system_instruction is not None and description:
@@ -1259,7 +1260,7 @@ class ConfigurableTask(Task):
                    return doc_to_decontamination_query(doc)
                else:
                    return ast.literal_eval(
-                        utils.apply_template(
+                        lm_eval.tasks.apply_template(
                            self.config.doc_to_decontamination_query, doc
                        )
                    )
@@ -1292,7 +1293,7 @@ class ConfigurableTask(Task):
                # else:
                return doc[doc_to_text]
            else:
-                text_string = utils.apply_template(doc_to_text, doc)
+                text_string = lm_eval.tasks.apply_template(doc_to_text, doc)
                if text_string.isdigit() and self._config.doc_to_choice is not None:
                    return ast.literal_eval(text_string)
                else:
@@ -1328,7 +1329,7 @@ class ConfigurableTask(Task):
                # else:
                return doc[doc_to_target]
            else:
-                target_string = utils.apply_template(doc_to_target, doc)
+                target_string = lm_eval.tasks.apply_template(doc_to_target, doc)
                if target_string.isdigit() and self._config.doc_to_choice is not None:
                    return ast.literal_eval(target_string)
                elif (
@@ -1371,7 +1372,9 @@ class ConfigurableTask(Task):
            if doc_to_choice in self.features:
                return doc[doc_to_choice]
            else:
-                return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
+                return ast.literal_eval(
+                    lm_eval.tasks.apply_template(doc_to_choice, doc)
+                )
        elif isinstance(doc_to_choice, list):
            return doc_to_choice
        elif isinstance(doc_to_choice, dict):
@@ -1400,7 +1403,7 @@ class ConfigurableTask(Task):
            if doc_to_image in self.features:
                return doc[doc_to_image]
            else:
-                return ast.literal_eval(utils.apply_template(doc_to_image, doc))
+                return ast.literal_eval(lm_eval.tasks.apply_template(doc_to_image, doc))
        elif callable(doc_to_image):
            return doc_to_image(doc)
        else:
@@ -1423,7 +1426,7 @@ class ConfigurableTask(Task):
            if doc_to_audio in self.features:
                return doc[doc_to_audio]
            else:
-                return ast.literal_eval(utils.apply_template(doc_to_audio, doc))
+                return ast.literal_eval(lm_eval.tasks.apply_template(doc_to_audio, doc))
        elif callable(doc_to_audio):
            return doc_to_audio(doc)
        else:
@@ -1434,7 +1437,7 @@ class ConfigurableTask(Task):
            if gen_prefix in self.features:
                return doc[gen_prefix]
            else:
-                return utils.apply_template(gen_prefix, doc)
+                return lm_eval.tasks.apply_template(gen_prefix, doc)
        return None

    def construct_requests(
...
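
All of the call sites above funnel into the same two-argument helper. A small sketch of how lm_eval.tasks.apply_template renders a document (the doc and templates here are invented; the isdigit() branches above then decide whether the rendered string is passed to ast.literal_eval):

    from lm_eval.tasks import apply_template

    doc = {"question": "2 + 2 = ?", "choices": ["3", "4"], "answer": 1}

    # doc_to_text-style template, rendered against the document's fields
    print(apply_template("Question: {{question}}\nAnswer:", doc))

    # doc_to_target-style template; a purely numeric rendering like this one
    # is later literal-eval'd by the task into an index into doc_to_choice
    print(apply_template("{{answer}}", doc))  # -> "1"
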
@@ -3,6 +3,7 @@ import logging
import os
from typing import Dict

+import lm_eval.tasks
from lm_eval import utils
@@ -122,7 +123,7 @@ class PromptString:
        if "doc_to_choice" in self.prompt_string:
            raise NotImplementedError("Not yet implemented to accept doc_to_choice")

-        text_string = utils.apply_template(doc_to_text, doc)
-        target_string = utils.apply_template(doc_to_target, doc)
+        text_string = lm_eval.tasks.apply_template(doc_to_text, doc)
+        target_string = lm_eval.tasks.apply_template(doc_to_target, doc)

        return [text_string, target_string]
import collections
+import functools
+import importlib.util
import inspect
import logging
-import os
+import re
+import sys
from functools import partial
+from glob import iglob
from pathlib import Path
-from typing import Dict, List, Mapping, Optional, Union
+from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Union

+import yaml
+from jinja2 import BaseLoader, Environment, StrictUndefined
+from yaml import YAMLError

from lm_eval import utils
from lm_eval.api.group import ConfigurableGroup, GroupConfig
@@ -13,8 +21,191 @@ from lm_eval.evaluator_utils import get_subtask_list
GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys())

+_CONFIG_CACHE: dict[tuple[Path, str], dict] = {}
+
eval_logger = logging.getLogger(__name__)
+_Base = yaml.CLoader if getattr(yaml, "__with_libyaml__", False) else yaml.FullLoader
+
+
+@functools.lru_cache(maxsize=128)  # ← reuse per (directory, simple) pair
+def _make_loader(yaml_dir: Path, simple: bool = False) -> type[yaml.Loader]:
+    """
+    Return a custom YAML Loader class bound to *yaml_dir*.
+
+    yaml_dir
+        Directory that holds the YAML file being parsed.
+        We capture it so that !function look-ups can resolve relative
+        Python files like my_utils.some_fn ➜ yaml_dir / "my_utils.py".
+    simple
+        If True we ignore !function completely (used by `mode="simple"`).
+    """
+
+    class Loader(_Base):
+        """Dynamically-generated loader that knows its base directory."""
+
+        # no extra state needed; the constructor stays the same
+
+    # Register (or stub) the !function constructor **for this Loader only**
+    if simple:
+        yaml.add_constructor("!function", lambda *_: None, Loader=Loader)
+    else:
+        yaml.add_constructor(
+            "!function",
+            # capture yaml_dir once so the lambda is fast and pickle-able
+            lambda ld, node, _dir=yaml_dir: _import_function(
+                ld.construct_scalar(node),
+                base_path=_dir,
+            ),
+            Loader=Loader,
+        )
+    return Loader
+
+
+@functools.lru_cache(maxsize=1000)  # ← cache module objects
+def _import_function(qualname: str, *, base_path: Path) -> Callable:
+    mod_path, _, func_name = qualname.rpartition(".")
+    if not mod_path:
+        raise ValueError(f"{qualname!r} has no module part")
+
+    file_path = base_path / f"{mod_path.replace('.', '/')}.py"
+    module_name = f"_yaml_dynamic.{hash(file_path)}_{file_path.stem}"
+
+    if module_name in sys.modules:
+        mod = sys.modules[module_name]
+    else:
+        spec = importlib.util.spec_from_file_location(module_name, file_path)
+        mod = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(mod)
+        sys.modules[module_name] = mod
+    return getattr(mod, func_name)
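
Together, _make_loader and _import_function are what resolve a !function tag against the YAML file's own directory. A rough end-to-end sketch using throwaway files (every name here is invented), assuming this commit is installed:

    import tempfile
    from pathlib import Path

    from lm_eval.tasks import load_yaml_config

    with tempfile.TemporaryDirectory() as tmp:
        tmp = Path(tmp)
        # helper module sitting next to the task config
        (tmp / "my_utils.py").write_text(
            "def doc_to_text(doc):\n    return doc['question']\n"
        )
        (tmp / "demo.yaml").write_text(
            "task: demo\ndoc_to_text: !function my_utils.doc_to_text\n"
        )

        cfg = load_yaml_config(tmp / "demo.yaml", mode="full")
        # the tag has been replaced by the imported callable
        print(cfg["doc_to_text"]({"question": "2 + 2 = ?"}))  # -> 2 + 2 = ?
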
+
+
+def ignore_constructor(loader: yaml.Loader, node: yaml.Node) -> None:
+    return None
+
+
+@functools.lru_cache(maxsize=1000)
+def _parse_yaml_file(path: Path, mode: str) -> dict:
+    loader_cls = _make_loader(path.parent, simple=(mode == "simple"))
+    with path.open("rb") as fh:
+        return yaml.load(fh, Loader=loader_cls)
+
+
+def load_yaml_config(
+    yaml_path: Union[Path, str, None] = None,
+    yaml_config: dict | None = None,
+    yaml_dir: Path | None = None,
+    mode: str = "full",
+    *,
+    _seen: set[tuple[Path, str]] | None = None,
+    resolve_includes: bool = True,
+) -> dict:
+    """
+    Parse a YAML config with optional include handling.
+
+    Parameters
+    ----------
+    yaml_path
+        Path to the main YAML file. Needed unless *yaml_config* is
+        supplied directly (e.g. by tests).
+    yaml_config
+        Pre-parsed dict to use instead of reading *yaml_path*.
+    yaml_dir
+        Base directory for resolving relative include paths. Defaults
+        to `yaml_path.parent`.
+    mode
+        "full"   – honour !function tags
+        "simple" – ignore !function (faster).
+    _seen
+        **Internal** recursion set: tuples of (absolute-path, mode).
+        Prevents include cycles such as A → B → A.
+    """
+    if yaml_config is None and yaml_path is None:
+        raise ValueError("load_yaml_config needs either yaml_path or yaml_config")
+
+    # ------------------------------------------------------------------ cycle guard
+    if _seen is None:
+        _seen = set()
+
+    if yaml_path is not None:
+        yaml_path = Path(yaml_path).expanduser().resolve()
+
+        # ---------- fast-path: return memoised, already-resolved cfg ----------
+        cache_key = (yaml_path, mode)
+        if yaml_config is None and resolve_includes and cache_key in _CONFIG_CACHE:
+            return _CONFIG_CACHE[cache_key]
+
+        key = (yaml_path.resolve(), mode)
+        if key in _seen:
+            raise ValueError(f"Include cycle detected at {yaml_path}")
+        _seen.add(key)
+
+    # ------------------------------------------------------------------ load / parse
+    if yaml_config is None:  # ordinary path-based load
+        yaml_config = _parse_yaml_file(yaml_path, mode)
+
+    if yaml_dir is None and yaml_path is not None:
+        yaml_dir = yaml_path.parent
+
+    assert yaml_dir is not None, "yaml_dir must be set by caller or deduced from path"
+
+    # ------------------------------------------------------------------ handle include
+    include = yaml_config.pop("include", None)
+    if not include and not resolve_includes:
+        return yaml_config
+
+    include_paths = include if isinstance(include, list) else [include]
+    final_cfg: dict = {}
+
+    for inc in reversed(include_paths):
+        if inc is None:  # guard against explicit nulls
+            continue
+        inc_path = Path(inc)
+        if not inc_path.is_absolute():
+            inc_path = (yaml_dir / inc_path).resolve()
+
+        included = load_yaml_config(
+            yaml_path=inc_path,
+            mode=mode,
+            yaml_dir=inc_path.parent,
+            _seen=_seen,  # <-- pass set downward
+        )
+        final_cfg.update(included)
+
+    final_cfg.update(yaml_config)  # local keys win
+
+    # -------- memoise after *all* includes are merged ----------
+    if yaml_config is None and resolve_includes:
+        _CONFIG_CACHE[cache_key] = final_cfg
+    return final_cfg
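
A usage sketch of the include merging above, again with throwaway files and invented keys; keys defined in the including file are expected to win over inherited ones:

    import tempfile
    from pathlib import Path

    from lm_eval.tasks import load_yaml_config

    with tempfile.TemporaryDirectory() as tmp:
        base = Path(tmp) / "base.yaml"
        child = Path(tmp) / "child.yaml"
        base.write_text("output_type: generate_until\nmetric: exact_match\n")
        child.write_text("include: base.yaml\ntask: demo_task\nmetric: f1\n")

        cfg = load_yaml_config(child, mode="simple")
        assert cfg["metric"] == "f1"                   # local key wins
        assert cfg["output_type"] == "generate_until"  # inherited from base
        assert "include" not in cfg                    # the include key is consumed
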
+
+
+def regex_replace(string, pattern, repl, count: int = 0):
+    """Implements the `re.sub` function as a custom Jinja filter."""
+    return re.sub(pattern, repl, string, count=count)
+
+
+@functools.lru_cache(maxsize=256)
+def _compile_tpl(src: str):
+    return apply_template._env.from_string(src)
+
+
+def apply_template(template: str, doc: dict) -> str:
+    if not hasattr(apply_template, "_env"):
+        apply_template._env = Environment(
+            loader=BaseLoader(),
+            undefined=StrictUndefined,
+            keep_trailing_newline=True,
+        )
+        apply_template._env.filters["regex_replace"] = regex_replace
+    return _compile_tpl(template).render(**doc)
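
The regex_replace helper registered above is exposed as a Jinja filter inside every task template. A small sketch with an invented document:

    from lm_eval.tasks import apply_template

    doc = {"passage": "Answer:   42   "}
    rendered = apply_template("{{ passage | regex_replace(' +', ' ') | trim }}", doc)
    print(rendered)  # -> "Answer: 42"
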
+
+
+def iter_yaml_files(root: Path) -> Generator[Path, Any, None]:
+    # '**/*.yaml' is handled internally by os.scandir.
+    for path in iglob("**/*.yaml", root_dir=root, recursive=True):
+        # quick ignore check
+        if "/__pycache__/" in path or "/.ipynb_checkpoints/" in path:
+            continue
+        yield root / path
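
iter_yaml_files is the directory scan the task index below relies on. A quick sketch that points it at the installed package's own task folder; this assumes the commit is installed and Python 3.10+ (glob's root_dir argument):

    from pathlib import Path

    import lm_eval.tasks
    from lm_eval.tasks import iter_yaml_files

    task_root = Path(lm_eval.tasks.__file__).parent
    for yaml_path in list(iter_yaml_files(task_root))[:5]:
        print(yaml_path)  # first few task configs; __pycache__ and checkpoints skipped
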

class TaskManager:
@@ -113,68 +304,99 @@ class TaskManager:
        list_tags: bool = True,
        list_subtasks: bool = True,
    ) -> str:
+        """
+        Return a Markdown table (as a string) listing groups, tags and/or subtasks
+        known to this TaskManager. Safe for configs whose yaml_path is -1 and for
+        task configs whose `include:` is a list.
+        """
        from pytablewriter import MarkdownTableWriter

+        # ------------------------------------------------------------------ helpers
        def sanitize_path(path: str) -> str:
-            # don't print full path if we are within the lm_eval/tasks dir !
-            # if we aren't though, provide the full path.
+            # print a relative path for anything inside lm_eval/tasks/
+            # path_str = str(path)
            if "lm_eval/tasks/" in path:
                return "lm_eval/tasks/" + path.split("lm_eval/tasks/")[-1]
-            else:
-                return path
+            return path

+        def first_output_type_from_includes(cfg: dict, base: Path) -> str:
+            """Walk cfg['include'] (string or list) and return the first
+            include that itself specifies an output_type."""
+            inc_raw = cfg.get("include")
+            if not inc_raw:
+                return ""
+            inc_list = inc_raw if isinstance(inc_raw, list) else [inc_raw]
+            for inc in inc_list:
+                inc_path = Path(inc)
+                if not inc_path.is_absolute():  # treat as relative include
+                    inc_path = base.parent / inc_path
+                try:
+                    inc_cfg = load_yaml_config(inc_path, mode="simple")
+                except FileNotFoundError:
+                    continue
+                if "output_type" in inc_cfg:
+                    return inc_cfg["output_type"]
+            return ""

+        # -------------------------------------------------------------- GROUP table
        group_table = MarkdownTableWriter()
        group_table.headers = ["Group", "Config Location"]
-        gt_values = []
-        for g in self.all_groups:
-            path = self.task_index[g]["yaml_path"]
-            if path == -1:
-                path = "---"
-            else:
-                path = sanitize_path(path)
-            gt_values.append([g, path])
-        group_table.value_matrix = gt_values
+        group_table.value_matrix = [
+            [
+                g,
+                "---"
+                if self.task_index[g]["yaml_path"] == -1
+                else sanitize_path(self.task_index[g]["yaml_path"]),
+            ]
+            for g in self.all_groups
+        ]

+        # ---------------------------------------------------------------- TAG table
        tag_table = MarkdownTableWriter()
        tag_table.headers = ["Tag"]
        tag_table.value_matrix = [[t] for t in self.all_tags]

+        # ------------------------------------------------------------ SUBTASK table
        subtask_table = MarkdownTableWriter()
        subtask_table.headers = ["Task", "Config Location", "Output Type"]
-        st_values = []
+        st_values: list[list[str]] = []
        for t in self.all_subtasks:
-            path = self.task_index[t]["yaml_path"]
-            output_type = ""
-            # read the yaml file to determine the output type
-            if path != -1:
-                config = utils.load_yaml_config(path, mode="simple")
-                if "output_type" in config:
-                    output_type = config["output_type"]
-                elif (
-                    "include" in config
-                ):  # if no output type, check if there is an include with an output type
-                    include_path = path.split("/")[:-1] + config["include"]
-                    include_config = utils.load_yaml_config(include_path, mode="simple")
-                    if "output_type" in include_config:
-                        output_type = include_config["output_type"]
-            if path == -1:
-                path = "---"
-            else:
-                path = sanitize_path(path)
-            st_values.append([t, path, output_type])
+            raw_path = self.task_index[t]["yaml_path"]
+            if raw_path == -1:
+                # python-only task or generated at runtime
+                display_path = "---"
+                output_type = ""
+            else:
+                path_obj = Path(raw_path)
+                display_path = sanitize_path(str(path_obj))
+                # load minimal YAML to discover output_type
+                cfg = load_yaml_config(path_obj, mode="simple")
+                if "output_type" in cfg:
+                    output_type = cfg["output_type"]
+                else:
+                    output_type = first_output_type_from_includes(cfg, path_obj)
+            st_values.append([t, display_path, output_type])
        subtask_table.value_matrix = st_values

-        result = "\n"
+        # ------------------------------------------------------------- final string
+        parts: list[str] = ["\n"]
        if list_groups:
-            result += group_table.dumps() + "\n\n"
+            parts.append(group_table.dumps())
+            parts.append("\n")
        if list_tags:
-            result += tag_table.dumps() + "\n\n"
+            parts.append(tag_table.dumps())
+            parts.append("\n")
        if list_subtasks:
-            result += subtask_table.dumps() + "\n\n"
+            parts.append(subtask_table.dumps())
+            parts.append("\n")
-        return result
+        return "".join(parts)
    def match_tasks(self, task_list: list[str]) -> list[str]:
        return utils.pattern_match(task_list, self.all_tasks)
@@ -225,7 +447,7 @@ class TaskManager:
        if yaml_path == -1:
            return {}
        else:
-            return utils.load_yaml_config(yaml_path, mode="full")
+            return load_yaml_config(Path(yaml_path), mode="full")

    def _get_tasklist(self, name: str) -> Union[List[str], int]:
        if self._name_is_task(name):
@@ -302,8 +524,8 @@ class TaskManager:
        original_task_name = config.get("task", task)
        config = {
-            **utils.load_yaml_config(
-                yaml_path=yaml_path,
+            **load_yaml_config(
+                yaml_path=Path(yaml_path),
                yaml_config={"include": config.pop("include")},
                mode="full" if yaml_path else "simple",
            ),
@@ -555,78 +777,79 @@ class TaskManager:
                    tasks_and_groups[tag]["task"].append(task)

        # TODO: remove group in next release
-        ignore_dirs = [
-            "__pycache__",
-            ".ipynb_checkpoints",
-        ]
+        # ignore_dirs = [
+        #     "__pycache__",
+        #     ".ipynb_checkpoints",
+        # ]
        tasks_and_groups = collections.defaultdict()
        task_dir_path = Path(task_dir)
-        for root, dirs, file_list in os.walk(task_dir_path):
-            dirs[:] = [d for d in dirs if d not in ignore_dirs]
-            root_path = Path(root)
-            for f in file_list:
-                if f.endswith(".yaml"):
-                    yaml_path = root_path / f
-                    config = utils.load_yaml_config(yaml_path, mode="simple")
+        for yaml_path in iter_yaml_files(task_dir_path):
+            try:
+                config = load_yaml_config(
+                    yaml_path, mode="simple", resolve_includes=False
+                )
+            except (FileNotFoundError, YAMLError, OSError) as err:
+                eval_logger.debug(f"File {yaml_path} could not be loaded ({err})")
+                continue
            if self._config_is_python_task(config):
                # This is a python class config
                task = config["task"]
                self._register_task(
                    task,
                    "python_task",
                    str(yaml_path),
                    tasks_and_groups,
                    config,
                    _populate_tags_and_groups,
                )
            elif self._config_is_group(config):
                # This is a group config
                tasks_and_groups[config["group"]] = {
                    "type": "group",
                    "task": -1,  # This signals that
                    # we don't need to know
                    # the task list for indexing
                    # as it can be loaded
                    # when called.
                    "yaml_path": str(yaml_path),
                }

                # # Registered the level 1 tasks from a group config
                # for config in config["task"]:
                #     if isinstance(config, dict) and self._config_is_task(config):
                #         task = config["task"]
                #         tasks_and_groups[task] = {
                #             "type": "task",
                #             "yaml_path": yaml_path,
                #         }

            elif self._config_is_task(config):
                # This is a task config
                task = config["task"]
                self._register_task(
                    task,
                    "task",
                    str(yaml_path),
                    tasks_and_groups,
                    config,
                    _populate_tags_and_groups,
                )
            elif self._config_is_task_list(config):
                # This is a task_list config
                for task_entry in config["task_list"]:
                    if isinstance(task_entry, dict) and "task" in task_entry:
                        task_name = task_entry["task"]
                        self._register_task(
                            task_name,
                            "task",
                            str(yaml_path),
                            tasks_and_groups,
                            config,
                            _populate_tags_and_groups,
                        )
            else:
-                eval_logger.debug(f"File {f} in {root} could not be loaded")
+                eval_logger.debug(f"File {yaml_path} could not be loaded")

        return tasks_and_groups
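
For orientation, the index entries registered by the loop above look roughly like this (names and paths invented; the group shape is taken from the branch above, the task shape from the commented-out block):

    tasks_and_groups = {
        "mmlu": {"type": "group", "task": -1, "yaml_path": "/tasks/mmlu/_mmlu.yaml"},
        "arc_easy": {"type": "task", "yaml_path": "/tasks/arc/arc_easy.yaml"},
        "my_python_task": {"type": "python_task", "yaml_path": "/tasks/custom/task.yaml"},
    }
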
...
@@ -10,12 +10,9 @@ import os
import re
from dataclasses import asdict, is_dataclass
from itertools import islice
-from pathlib import Path
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
+from typing import Any, Callable, Generator, List, Optional, Tuple

import numpy as np
-import yaml
-from jinja2 import BaseLoader, Environment, StrictUndefined


SPACING = " " * 47
@@ -441,114 +438,6 @@ def positional_deprecated(fn):
    return _wrapper
-def ignore_constructor(loader: yaml.Loader, node: yaml.Node) -> yaml.Node:
-    return node
-
-
-def import_function(loader: yaml.Loader, node: yaml.Node, yaml_path: Path) -> Callable:
-    function_name = loader.construct_scalar(node)
-    *module_name, function_name = function_name.split(".")
-    if isinstance(module_name, list):
-        module_name = ".".join(module_name)
-    module_path = yaml_path.parent / f"{module_name}.py"
-
-    spec = importlib.util.spec_from_file_location(module_name, module_path.as_posix())
-    if spec is None:
-        raise ImportError(f"Could not import module {module_name} from {module_path}.")
-    module = importlib.util.module_from_spec(spec)
-    if spec.loader is None:
-        raise ImportError(f"Module loader is None, {module_name} from {module_path}.")
-    spec.loader.exec_module(module)
-    function = getattr(module, function_name)
-    return function
-
-
-def load_yaml_config(
-    yaml_path: Optional[Union[str, Path]] = None,
-    yaml_config: Optional[Dict] = None,
-    yaml_dir: Optional[Union[str, Path]] = None,
-    mode: str = "full",
-) -> Dict:
-    # Convert yaml_path to Path object if it's a string
-    if yaml_path is not None:
-        yaml_path = Path(yaml_path)
-
-    # Convert yaml_dir to Path object if it's a string
-    if yaml_dir is not None:
-        yaml_dir = Path(yaml_dir)
-
-    if mode == "simple":
-        constructor_fn = ignore_constructor
-    elif mode == "full":
-        if yaml_path is None:
-            raise ValueError("yaml_path must be provided if mode is 'full'.")
-        # Attach yaml_path to the import function so that it can be used later
-        constructor_fn = functools.partial(import_function, yaml_path=yaml_path)
-
-    loader = yaml.CLoader if yaml.__with_libyaml__ else yaml.FullLoader
-    # Add the import_function constructor to the YAML loader
-    yaml.add_constructor("!function", constructor_fn, Loader=loader)
-
-    if yaml_config is None:
-        with open(yaml_path, "rb") as file:
-            yaml_config = yaml.load(file, Loader=loader)
-
-    if yaml_dir is None and yaml_path is not None:
-        yaml_dir = yaml_path.parent
-
-    assert yaml_dir is not None
-
-    if "include" in yaml_config:
-        include_path = yaml_config["include"]
-        del yaml_config["include"]
-
-        if isinstance(include_path, str):
-            include_path = [include_path]
-
-        # Load from the last one first
-        include_path.reverse()
-        final_yaml_config = {}
-        for path in include_path:
-            # Convert to Path object
-            path = Path(path)
-
-            # Assumes that path is a full path.
-            # If not found, assume the included yaml
-            # is in the same dir as the original yaml
-            if not path.is_file():
-                path = yaml_dir / path
-
-            try:
-                included_yaml_config = load_yaml_config(yaml_path=path, mode=mode)
-                final_yaml_config.update(included_yaml_config)
-            except Exception as ex:
-                # If failed to load, ignore
-                raise ex
-
-        final_yaml_config.update(yaml_config)
-        return final_yaml_config
-
-    return yaml_config
-
-
-def regex_replace(string, pattern, repl, count: int = 0):
-    """Implements the `re.sub` function as a custom Jinja filter."""
-    return re.sub(pattern, repl, string, count=count)
-
-
-def apply_template(template: str, doc: dict) -> str:
-    # Lazy initialization - only create Environment when actually needed
-    if not hasattr(apply_template, "_env"):
-        apply_template._env = Environment(
-            loader=BaseLoader(), undefined=StrictUndefined, keep_trailing_newline=True
-        )
-        apply_template._env.filters["regex_replace"] = regex_replace
-
-    rtemplate = apply_template._env.from_string(template)
-    return rtemplate.render(**doc)
-
-
def create_iterator(raw_iterator, *, rank=0, world_size=1, limit=None):
    """
    Method for creating a (potentially) sliced and limited
...
@@ -11,8 +11,8 @@ from lm_eval.api.metrics import (
    stderr_for_metric,
)
from lm_eval.models.utils import Collator
+from lm_eval.tasks import apply_template
from lm_eval.utils import (
-    apply_template,
    get_rolling_token_windows,
    make_disjoint_window,
)
...
import os
from typing import List, Union

-from lm_eval.utils import load_yaml_config
+from lm_eval.tasks import load_yaml_config

# {{{CI}}}
...