Commit d77596da authored by Baber's avatar Baber
Browse files

refactor: improve type hints and simplify YAML loading functions

parent 15e930af
...@@ -116,7 +116,6 @@ def _import_function(qual: str, base_dir: Path): ...@@ -116,7 +116,6 @@ def _import_function(qual: str, base_dir: Path):
2. python importlib (package/module already importable) 2. python importlib (package/module already importable)
Uses file *mtime* so edits are reloaded without killing the process. Uses file *mtime* so edits are reloaded without killing the process.
""" """
import importlib
if "." not in qual: if "." not in qual:
msg = f"!function value '{qual}' must contain a '.'" msg = f"!function value '{qual}' must contain a '.'"
...@@ -137,9 +136,6 @@ def _import_function(qual: str, base_dir: Path): ...@@ -137,9 +136,6 @@ def _import_function(qual: str, base_dir: Path):
sys.modules[module_key] = mod sys.modules[module_key] = mod
return getattr(mod, fn_name) return getattr(mod, fn_name)
# Fallback to regular import mechanism
import importlib
module = importlib.import_module(mod_part) module = importlib.import_module(mod_part)
return getattr(module, fn_name) return getattr(module, fn_name)
...@@ -183,87 +179,6 @@ def load_yaml_config( ...@@ -183,87 +179,6 @@ def load_yaml_config(
return merged return merged
# def load_yaml_config(
# yaml_path: Union[Path, str, None] = None,
# yaml_config: Optional[dict] = None,
# yaml_dir: Optional[Path] = None,
# mode: str = "full",
# *,
# _seen: Optional[set[tuple[Path, str]]] = None,
# resolve_includes: bool = True,
# ) -> dict:
# """
# Parse a YAML config with optional include handling.
#
# Parameters
# ----------
# yaml_path
# Path to the main YAML file. Needed unless *yaml_config* is
# supplied directly (e.g. by tests).
# yaml_config
# Pre-parsed dict to use instead of reading *yaml_path*.
# yaml_dir
# Base directory for resolving relative include paths. Defaults
# to `yaml_path.parent`.
# mode
# "full" - honour !function tags
# "simple" - ignore !function (faster).
# _seen
# **Internal** recursion set: tuples of (absolute-path, mode).
# Prevents include cycles such as A → B → A.
# """
# if yaml_config is None and yaml_path is None:
# raise ValueError("load_yaml_config needs either yaml_path or yaml_config")
#
# # ------------------------------------------------------------------ cycle guard
# if _seen is None:
# _seen = set()
# if yaml_path is not None:
# yaml_path = Path(yaml_path).expanduser().resolve()
#
# # ---------- fast-path: use LRU cached function ----------
# if yaml_config is None and resolve_includes:
# return _get_cached_config(yaml_path, mode)
#
# key = (yaml_path.resolve(), mode)
# if key in _seen:
# raise ValueError(f"Include cycle detected at {yaml_path}")
# _seen.add(key)
#
# # ------------------------------------------------------------------ load / parse
# if yaml_config is None: # ordinary path-based load
# yaml_config = _parse_yaml_file(yaml_path, mode)
#
# if yaml_dir is None and yaml_path is not None:
# yaml_dir = yaml_path.parent
# assert yaml_dir is not None, "yaml_dir must be set by caller or deduced from path"
#
# # ------------------------------------------------------------------ handle include
# include = yaml_config.pop("include", None)
# if not include and not resolve_includes:
# return yaml_config
#
# include_paths = include if isinstance(include, list) else [include]
# final_cfg: dict = {}
#
# for inc in reversed(include_paths):
# if inc is None: # guard against explicit nulls
# continue
# inc_path = Path(inc)
# if not inc_path.is_absolute():
# inc_path = (yaml_dir / inc_path).resolve()
# included = load_yaml_config(
# yaml_path=inc_path,
# mode=mode,
# yaml_dir=inc_path.parent,
# _seen=_seen, # <-- pass set downward
# )
# final_cfg.update(included)
#
# final_cfg.update(yaml_config) # local keys win
# return final_cfg
def iter_yaml_files(root: Path, ignore=_IGNORE_DIRS) -> Generator[Path, Any, None]: def iter_yaml_files(root: Path, ignore=_IGNORE_DIRS) -> Generator[Path, Any, None]:
"""Recursively iterate over all YAML files in a directory tree. """Recursively iterate over all YAML files in a directory tree.
...@@ -631,8 +546,8 @@ class TaskManager: ...@@ -631,8 +546,8 @@ class TaskManager:
raise ValueError raise ValueError
return self.task_index[name]["task"] return self.task_index[name]["task"]
@staticmethod
def _register_task( def _register_task(
self,
task_name: str, task_name: str,
task_type: str, task_type: str,
yaml_path: str, yaml_path: str,
...@@ -649,8 +564,8 @@ class TaskManager: ...@@ -649,8 +564,8 @@ class TaskManager:
if config and task_type != "group" and populate_tags_fn: if config and task_type != "group" and populate_tags_fn:
populate_tags_fn(config, task_name, tasks_and_groups) populate_tags_fn(config, task_name, tasks_and_groups)
@staticmethod
def _merge_task_configs( def _merge_task_configs(
self,
base_config: dict, base_config: dict,
task_specific_config: dict, task_specific_config: dict,
task_name: str, task_name: str,
...@@ -675,7 +590,8 @@ class TaskManager: ...@@ -675,7 +590,8 @@ class TaskManager:
) )
return dict(collections.ChainMap(*map(fn, reversed(subtask_list)))) return dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
def _process_alias(self, config: dict, group: str | None = None) -> dict: @staticmethod
def _process_alias(config: dict, group: str | None = None) -> dict:
"""Process group alias configuration. """Process group alias configuration.
If the group is not the same as the original group which the group alias If the group is not the same as the original group which the group alias
...@@ -1047,7 +963,6 @@ class TaskManager: ...@@ -1047,7 +963,6 @@ class TaskManager:
# This is a python class config # This is a python class config
task = config["task"] task = config["task"]
self._register_task( self._register_task(
task,
"python_task", "python_task",
str(yaml_path), str(yaml_path),
tasks_and_groups, tasks_and_groups,
...@@ -1079,7 +994,6 @@ class TaskManager: ...@@ -1079,7 +994,6 @@ class TaskManager:
# This is a task config # This is a task config
task = config["task"] task = config["task"]
self._register_task( self._register_task(
task,
"task", "task",
str(yaml_path), str(yaml_path),
tasks_and_groups, tasks_and_groups,
...@@ -1092,7 +1006,6 @@ class TaskManager: ...@@ -1092,7 +1006,6 @@ class TaskManager:
if isinstance(task_entry, dict) and "task" in task_entry: if isinstance(task_entry, dict) and "task" in task_entry:
task_name = task_entry["task"] task_name = task_entry["task"]
self._register_task( self._register_task(
task_name,
"task", "task",
str(yaml_path), str(yaml_path),
tasks_and_groups, tasks_and_groups,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment