Commit 1de882c2 authored by Baber's avatar Baber
Browse files

fix test

parent b58e5556
...@@ -146,8 +146,8 @@ def _import_fun_from_str(path_str: str) -> Any: ...@@ -146,8 +146,8 @@ def _import_fun_from_str(path_str: str) -> Any:
def load_yaml( def load_yaml(
path: str | Path, path: str | Path,
*, *,
resolve_functions: bool = True, resolve_func: bool = True,
resolve_includes: bool = True, recursive: bool = True,
_seen: set[Path] | None = None, _seen: set[Path] | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Pure data-loading helper. """Pure data-loading helper.
...@@ -161,11 +161,11 @@ def load_yaml( ...@@ -161,11 +161,11 @@ def load_yaml(
raise ValueError(f"Include cycle at {path}") raise ValueError(f"Include cycle at {path}")
_seen.add(path) _seen.add(path)
loader_cls = _make_loader(path.parent, resolve_funcs=resolve_functions) loader_cls = _make_loader(path.parent, resolve_funcs=resolve_func)
with path.open("rb") as fh: with path.open("rb") as fh:
cfg = yaml.load(fh, Loader=loader_cls) cfg = yaml.load(fh, Loader=loader_cls)
if not resolve_includes or "include" not in cfg: if not recursive or "include" not in cfg:
return cfg return cfg
else: else:
includes = cfg.pop("include") includes = cfg.pop("include")
...@@ -176,8 +176,8 @@ def load_yaml( ...@@ -176,8 +176,8 @@ def load_yaml(
merged.update( merged.update(
load_yaml( load_yaml(
inc_path, inc_path,
resolve_functions=resolve_functions, resolve_func=resolve_func,
resolve_includes=True, recursive=True,
_seen=_seen, _seen=_seen,
), ),
) )
......
...@@ -106,7 +106,7 @@ class TaskFactory: ...@@ -106,7 +106,7 @@ class TaskFactory:
self, entry: Entry, overrides: dict[str, Any] | None self, entry: Entry, overrides: dict[str, Any] | None
) -> dict[str, Any]: ) -> dict[str, Any]:
if entry.yaml_path: if entry.yaml_path:
cfg = deepcopy(load_cfg_cached(entry.yaml_path, resolve_functions=True)) cfg = deepcopy(load_cfg_cached(entry.yaml_path, resolve_func=True))
else: else:
cfg = {"metadata": {"config": "unknown"}} # python task without YAML cfg = {"metadata": {"config": "unknown"}} # python task without YAML
......
...@@ -56,8 +56,8 @@ class TaskIndex: ...@@ -56,8 +56,8 @@ class TaskIndex:
try: try:
cfg = load_cfg( cfg = load_cfg(
yaml_path, yaml_path,
resolve_functions=False, resolve_func=False,
resolve_includes=resolve_includes, recursive=resolve_includes,
) )
self.process_cfg(cfg, yaml_path, index) self.process_cfg(cfg, yaml_path, index)
except Exception as err: except Exception as err:
......
...@@ -308,7 +308,7 @@ doc_to_text: !function utils.process_doc ...@@ -308,7 +308,7 @@ doc_to_text: !function utils.process_doc
""" """
file_path = yaml_file(content) file_path = yaml_file(content)
result = load_yaml(file_path, resolve_functions=True) result = load_yaml(file_path, resolve_func=True)
assert callable(result["doc_to_text"]) assert callable(result["doc_to_text"])
assert result["doc_to_text"]("hello") == "HELLO" assert result["doc_to_text"]("hello") == "HELLO"
...@@ -320,7 +320,7 @@ doc_to_text: !function utils.process_doc ...@@ -320,7 +320,7 @@ doc_to_text: !function utils.process_doc
""" """
file_path = yaml_file(content) file_path = yaml_file(content)
result = load_yaml(file_path, resolve_functions=False) result = load_yaml(file_path, resolve_func=False)
assert isinstance(result["doc_to_text"], str) assert isinstance(result["doc_to_text"], str)
# When resolve_functions=False, it returns the full path + function spec # When resolve_functions=False, it returns the full path + function spec
...@@ -345,7 +345,7 @@ shared_value: 100 ...@@ -345,7 +345,7 @@ shared_value: 100
""" """
main_path = yaml_file(main_content, "main.yaml") main_path = yaml_file(main_content, "main.yaml")
result = load_yaml(main_path, resolve_includes=True) result = load_yaml(main_path, recursive=True)
assert result["task"] == "main_task" assert result["task"] == "main_task"
assert result["shared_metric"] == "f1_score" assert result["shared_metric"] == "f1_score"
...@@ -368,7 +368,7 @@ main_key: main_value ...@@ -368,7 +368,7 @@ main_key: main_value
""" """
main_path = yaml_file(main_content) main_path = yaml_file(main_content)
result = load_yaml(main_path, resolve_includes=True) result = load_yaml(main_path, recursive=True)
assert result["main_key"] == "main_value" assert result["main_key"] == "main_value"
assert result["included_key"] == "included_value" assert result["included_key"] == "included_value"
...@@ -381,7 +381,7 @@ task: test_task ...@@ -381,7 +381,7 @@ task: test_task
""" """
file_path = yaml_file(content) file_path = yaml_file(content)
result = load_yaml(file_path, resolve_includes=False) result = load_yaml(file_path, recursive=False)
assert result["include"] == ["other.yaml"] assert result["include"] == ["other.yaml"]
assert result["task"] == "test_task" assert result["task"] == "test_task"
......
import os import os
from typing import List, Union from typing import List, Union
from lm_eval.tasks import load_yaml_config from lm_eval.tasks._config_loader import load_yaml
# {{{CI}}} # {{{CI}}}
...@@ -12,7 +12,7 @@ from lm_eval.tasks import load_yaml_config ...@@ -12,7 +12,7 @@ from lm_eval.tasks import load_yaml_config
# reads a text file and returns a list of words # reads a text file and returns a list of words
# used to read the output of the changed txt from tj-actions/changed-files # used to read the output of the changed txt from tj-actions/changed-files
def load_changed_files(file_path: str) -> List[str]: def load_changed_files(file_path: str) -> List[str]:
with open(file_path, "r", encoding="utf-8") as f: with open(file_path, encoding="utf-8") as f:
content = f.read() content = f.read()
words_list = list(content.split()) words_list = list(content.split())
return words_list return words_list
...@@ -26,7 +26,7 @@ def parser(full_path: List[str]) -> List[str]: ...@@ -26,7 +26,7 @@ def parser(full_path: List[str]) -> List[str]:
_output = set() _output = set()
for x in full_path: for x in full_path:
if x.endswith(".yaml") and os.path.exists(x): if x.endswith(".yaml") and os.path.exists(x):
config = load_yaml_config(x, mode="simple") config = load_yaml(x, recursive=True, resolve_func=True)
if isinstance(config["task"], str): if isinstance(config["task"], str):
_output.add(config["task"]) _output.add(config["task"])
elif isinstance(config["task"], list): elif isinstance(config["task"], list):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment