Unverified Commit 8cfa0d74 authored by Giulio Lovisotto's avatar Giulio Lovisotto Committed by GitHub
Browse files

Use yaml.CLoader to load yaml files when available. (#2777)

parent 07bd7e23
...@@ -10,6 +10,7 @@ import os ...@@ -10,6 +10,7 @@ import os
import re import re
from dataclasses import asdict, is_dataclass from dataclasses import asdict, is_dataclass
from itertools import islice from itertools import islice
from pathlib import Path
from typing import Any, Callable, Generator, List, Tuple from typing import Any, Callable, Generator, List, Tuple
import numpy as np import numpy as np
...@@ -428,17 +429,22 @@ def ignore_constructor(loader, node): ...@@ -428,17 +429,22 @@ def ignore_constructor(loader, node):
return node return node
def import_function(loader, node): def import_function(loader: yaml.Loader, node, yaml_path: Path):
function_name = loader.construct_scalar(node) function_name = loader.construct_scalar(node)
yaml_path = os.path.dirname(loader.name)
*module_name, function_name = function_name.split(".") *module_name, function_name = function_name.split(".")
if isinstance(module_name, list): if isinstance(module_name, list):
module_name = ".".join(module_name) module_name = ".".join(module_name)
module_path = os.path.normpath(os.path.join(yaml_path, "{}.py".format(module_name))) module_path = yaml_path.parent / f"{module_name}.py"
spec = importlib.util.spec_from_file_location(module_name, module_path) spec = importlib.util.spec_from_file_location(module_name, module_path.as_posix())
if spec is None:
raise ImportError(f"Could not import module {module_name} from {module_path}.")
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
if spec.loader is None:
raise ImportError(f"Module loader is None, {module_name} from {module_path}.")
spec.loader.exec_module(module) spec.loader.exec_module(module)
function = getattr(module, function_name) function = getattr(module, function_name)
...@@ -449,13 +455,17 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full ...@@ -449,13 +455,17 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full
if mode == "simple": if mode == "simple":
constructor_fn = ignore_constructor constructor_fn = ignore_constructor
elif mode == "full": elif mode == "full":
constructor_fn = import_function if yaml_path is None:
raise ValueError("yaml_path must be provided if mode is 'full'.")
# Attach yaml_path to the import function so that it can be used later
constructor_fn = functools.partial(import_function, yaml_path=Path(yaml_path))
loader = yaml.CLoader if yaml.__with_libyaml__ else yaml.FullLoader
# Add the import_function constructor to the YAML loader # Add the import_function constructor to the YAML loader
yaml.add_constructor("!function", constructor_fn) yaml.add_constructor("!function", constructor_fn, Loader=loader)
if yaml_config is None: if yaml_config is None:
with open(yaml_path, "rb") as file: with open(yaml_path, "rb") as file:
yaml_config = yaml.full_load(file) yaml_config = yaml.load(file, Loader=loader)
if yaml_dir is None: if yaml_dir is None:
yaml_dir = os.path.dirname(yaml_path) yaml_dir = os.path.dirname(yaml_path)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment