# _config_loader.py
from __future__ import annotations

import functools
import importlib.util
import sys
from pathlib import Path
from typing import Any

import yaml


# Base YAML loader class. Use the libyaml-backed CSafeLoader when PyYAML was
# built with libyaml, otherwise its pure-Python equivalent SafeLoader. Both
# branches must be the same loader flavour so a config parses identically with
# or without libyaml (a FullLoader fallback would accept tags such as
# !!python/tuple that CSafeLoader rejects).
_Base = (
    yaml.CSafeLoader if getattr(yaml, "__with_libyaml__", False) else yaml.SafeLoader
)

# Directory names to skip; presumably consulted by directory-scanning code
# elsewhere in this module (not visible here) — TODO confirm.
_IGNORE_DIRS = {"__pycache__", ".ipynb_checkpoints"}


# --------------------------------------------------------------------------- helpers
@functools.lru_cache(maxsize=128)
def _mk_function_ctor(base_dir: Path, resolve: bool):
    """Build the YAML constructor used for the ``!function`` tag.

    Args:
        base_dir: Directory of the YAML file being loaded; ``!function``
            references are interpreted relative to it.
        resolve: If True, the constructor returns the referenced callable
            (looked up via ``_import_func_in_yml``). If False, it returns the
            string ``"<base_dir>/<scalar>"`` (with ``~`` expanded), i.e. the
            ``/path/to/module.function_name`` form that
            ``_import_fun_from_str`` parses.

    Returns:
        A ``(loader, node) -> value`` constructor for ``yaml.add_constructor``.
    """

    def ctor(loader: yaml.Loader, node: yaml.Node):
        # The tag's scalar payload, e.g. "utils.process_doc".
        spec = loader.construct_scalar(node)  # type: ignore[arg-type]
        if not resolve:
            return str(base_dir.expanduser() / spec)
        return _import_func_in_yml(spec, base_dir)

    return ctor


@functools.lru_cache(maxsize=512)
def _make_loader(base_dir: Path, *, resolve_funcs: bool) -> type[yaml.Loader]:
    """Return a ``_Base`` subclass that understands the ``!function`` tag.

    A fresh subclass is created per ``(base_dir, resolve_funcs)`` pair and the
    ``!function`` constructor (which closes over ``base_dir``) is registered
    on that subclass rather than on ``_Base``, so each directory gets its own
    binding. Results are memoised, so repeated loads from the same directory
    reuse one class.

    Args:
        base_dir: Directory of the YAML file being loaded.
        resolve_funcs: Forwarded to ``_mk_function_ctor`` as ``resolve``.
    """

    class Loader(_Base): ...  # type: ignore[no-redef]

    yaml.add_constructor(
        "!function",
        _mk_function_ctor(base_dir, resolve_funcs),
        Loader=Loader,
    )
    return Loader


Baber's avatar
Baber committed
42
43
44
45
46
47
48
49
@functools.lru_cache(maxsize=128)
def _import_func_in_yml(qual: str, base_dir: Path):
    """Import function from qual: utils.process_doc, checking local files first then standard imports.

    Args:
        qual: Qualified function name (e.g., 'utils.process_doc')
        base_dir: Directory to search for local modules
    """
Baber's avatar
Baber committed
50
51
52
53
54
55
56
57
    mod_path, _, fn_name = qual.rpartition(".")
    # 1) relative “utils.py” next to YAML
    rel = (base_dir / f"{mod_path.replace('.', '/')}.py").resolve()
    if rel.exists():
        mtime = rel.stat().st_mtime_ns
        key = f"{rel}:{mtime}"  # one module per mtime
        if key not in sys.modules:
            spec = importlib.util.spec_from_file_location(key, rel)
Baber's avatar
Baber committed
58
59
            if spec is None or spec.loader is None:
                raise ImportError(f"Cannot load module from {rel}") from None
Baber's avatar
Baber committed
60
61
62
63
64
            mod = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(mod)  # type: ignore[arg-type]
            sys.modules[key] = mod
        return getattr(sys.modules[key], fn_name)

Baber's avatar
Baber committed
65
    # 2) already-importable module
Baber's avatar
Baber committed
66
67
68
69
    module = __import__(mod_path, fromlist=[fn_name])
    return getattr(module, fn_name)


Baber's avatar
Baber committed
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
@functools.lru_cache(maxsize=128)
def _import_fun_from_str(path_str: str) -> Any:
    """Import a function from a string in the form '/absolute/path/to/module.function_name'."""
    try:
        # Split off the function name from the rightmost dot
        module_path_str, function_name = path_str.rsplit(".", 1)
    except ValueError as e:
        raise ValueError(
            f"Invalid path format: {path_str}. Expected format: /path/to/module.function_name"
        ) from e

    # Convert to Path and handle .py extension
    module_path = Path(module_path_str)
    if not module_path.suffix:
        module_path = module_path.with_suffix(".py")
    elif module_path.suffix != ".py":
        # If it has a non-.py suffix, the user might have included .py in the path
        # e.g., "/path/to/module.py.function_name"
        base_path = module_path.with_suffix("")
        if base_path.with_suffix(".py").exists():
            module_path = base_path.with_suffix(".py")

    if not module_path.exists():
        raise ImportError(f"Module file not found: {module_path}")

    # Use similar approach to _import_func_in_yml for consistency
    mtime = module_path.stat().st_mtime_ns
    cache_key = f"{module_path}:{mtime}"

    if cache_key not in sys.modules:
        spec = importlib.util.spec_from_file_location(cache_key, module_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot load module from {module_path}") from None
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        sys.modules[cache_key] = module

    module = sys.modules[cache_key]

    if not hasattr(module, function_name):
        raise AttributeError(
            f"Function '{function_name}' not found in module {module_path}"
        )

    return getattr(module, function_name)


def load_yaml(
    path: str | Path,
    *,
    resolve_functions: bool = True,
    resolve_includes: bool = True,
    _seen: set[Path] | None = None,
) -> dict[str, Any]:
    """Pure data-loading helper.

    Returns a dict ready for higher-level interpretation.

    - No task/group/tag semantics here.

    Args:
        path: YAML file to load; ``~`` is expanded and the path resolved.
        resolve_functions: If True, ``!function`` tags are turned into the
            referenced callables (importing the Python module that defines
            them); if False they are left as
            ``"<yaml dir>/<module.function>"`` strings.
        resolve_includes: If True and the document has a top-level
            ``include`` key (a path or list of paths, relative ones resolved
            against this file's directory), each include is loaded
            recursively and merged in order, with keys from this file
            overriding included ones.
        _seen: Internal. Files on the current include chain, used for cycle
            detection.

    Returns:
        The loaded document. An empty document yields ``{}``.

    Raises:
        ValueError: If a file (transitively) includes itself.
    """
    path = Path(path).expanduser().resolve()
    if _seen is None:
        _seen = set()
    if path in _seen:
        raise ValueError(f"Include cycle at {path}")

    loader_cls = _make_loader(path.parent, resolve_funcs=resolve_functions)
    with path.open("rb") as fh:
        cfg = yaml.load(fh, Loader=loader_cls)
    if cfg is None:  # empty document
        cfg = {}

    if not resolve_includes or "include" not in cfg:
        return cfg

    includes = cfg.pop("include")
    if includes is None:  # bare "include:" with no value
        includes = []
    elif not isinstance(includes, list):
        includes = [includes]

    # _seen holds only the *current* include chain: add this file before
    # recursing and remove it afterwards, so a file reached via two branches
    # (a diamond, e.g. A -> B -> D and A -> C -> D) is not mistaken for a cycle.
    _seen.add(path)
    try:
        merged: dict[str, Any] = {}
        for inc in includes:
            # Relative includes resolve against this file's directory; joining
            # with an absolute right-hand side yields that absolute path.
            merged.update(
                load_yaml(
                    path.parent / inc,
                    resolve_functions=resolve_functions,
                    resolve_includes=True,
                    _seen=_seen,
                ),
            )
    finally:
        _seen.discard(path)

    merged.update(cfg)  # local keys win
    return merged