Unverified Commit 8cfa0d74 authored by Giulio Lovisotto's avatar Giulio Lovisotto Committed by GitHub
Browse files

Use yaml.CLoader to load yaml files when available. (#2777)

parent 07bd7e23
......@@ -10,6 +10,7 @@ import os
import re
from dataclasses import asdict, is_dataclass
from itertools import islice
from pathlib import Path
from typing import Any, Callable, Generator, List, Tuple
import numpy as np
......@@ -428,17 +429,22 @@ def ignore_constructor(loader, node):
return node
def import_function(loader, node):
def import_function(loader: yaml.Loader, node, yaml_path: Path):
function_name = loader.construct_scalar(node)
yaml_path = os.path.dirname(loader.name)
*module_name, function_name = function_name.split(".")
if isinstance(module_name, list):
module_name = ".".join(module_name)
module_path = os.path.normpath(os.path.join(yaml_path, "{}.py".format(module_name)))
module_path = yaml_path.parent / f"{module_name}.py"
spec = importlib.util.spec_from_file_location(module_name, module_path)
spec = importlib.util.spec_from_file_location(module_name, module_path.as_posix())
if spec is None:
raise ImportError(f"Could not import module {module_name} from {module_path}.")
module = importlib.util.module_from_spec(spec)
if spec.loader is None:
raise ImportError(f"Module loader is None, {module_name} from {module_path}.")
spec.loader.exec_module(module)
function = getattr(module, function_name)
......@@ -449,13 +455,17 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full
if mode == "simple":
constructor_fn = ignore_constructor
elif mode == "full":
constructor_fn = import_function
if yaml_path is None:
raise ValueError("yaml_path must be provided if mode is 'full'.")
# Attach yaml_path to the import function so that it can be used later
constructor_fn = functools.partial(import_function, yaml_path=Path(yaml_path))
loader = yaml.CLoader if yaml.__with_libyaml__ else yaml.FullLoader
# Add the import_function constructor to the YAML loader
yaml.add_constructor("!function", constructor_fn)
yaml.add_constructor("!function", constructor_fn, Loader=loader)
if yaml_config is None:
with open(yaml_path, "rb") as file:
yaml_config = yaml.full_load(file)
yaml_config = yaml.load(file, Loader=loader)
if yaml_dir is None:
yaml_dir = os.path.dirname(yaml_path)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment