eyaml.py 2.3 KB
Newer Older
maming's avatar
maming committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import json
from typing import Any, BinaryIO, Union

import ryml

from megatron.energon.epathlib.epath import EPath


def load_yaml_json(path: EPath) -> Any:
    """Load a YAML or JSON file from a path.

    The parser is selected from the file extension: ``.yaml`` and ``.yml``
    files are parsed with :func:`load_yaml` (rapidyaml), ``.json`` files with
    the standard :mod:`json` module. Loading json is much faster than yaml.

    Args:
        path: The path to the file to load.

    Returns:
        The parsed YAML or JSON data as a native Python object.

    Raises:
        ValueError: If the file name does not end in ``.yaml``, ``.yml`` or
            ``.json``.
    """
    name = path.name
    # ``.yml`` is the other widely used YAML extension; accept it as well.
    if name.endswith((".yaml", ".yml")):
        with path.open("rb") as f:
            return load_yaml(f)
    if name.endswith(".json"):
        # json.load accepts a binary file and detects the UTF encoding itself.
        with path.open("rb") as f:
            return json.load(f)
    raise ValueError(f"Invalid file extension: {path.name}")


def load_yaml(stream: Union[BinaryIO, bytes]) -> Any:
    """Parse YAML from raw bytes or a binary stream using rapidyaml (ryml).

    rapidyaml is considerably faster than the standard yaml library, at the
    cost of supporting only a subset of YAML features.

    Args:
        stream: Either the YAML document as ``bytes``, or a binary stream whose
            full contents (``stream.read()``) form the document.

    Returns:
        The parsed YAML data as a native Python object.
    """
    # Raw bytes are used directly; anything else is drained via read().
    raw = stream if isinstance(stream, bytes) else stream.read()
    return _ryml_tree_to_native(ryml.parse_in_arena(raw))


def _cast_scalar(tree: ryml.Tree, nid: int) -> object:
    """Convert a YAML scalar node to bool | int | float | None | str.

    Quoted scalars are always returned as ``str``. Plain (unquoted) scalars
    are resolved case-insensitively, roughly following the YAML core schema:

    * ``""`` (empty), ``null`` and ``~`` -> ``None``
    * ``true``/``yes`` and ``false``/``no`` -> ``bool``
    * ``.inf``/``+.inf``, ``-.inf`` and ``.nan`` -> special ``float`` values
    * text containing ``.`` or ``e`` that ``float()`` accepts -> ``float``
    * text that ``int()`` accepts -> ``int``
    * anything else -> the original ``str``

    Args:
        tree: The parsed rapidyaml tree.
        nid: Id of the scalar node inside ``tree``.

    Returns:
        The scalar converted to a native Python value.
    """
    memview = tree.val(nid)
    if memview is None:
        # Defensive: a node without a value buffer is YAML null; without this
        # guard ``.tobytes()`` below would raise AttributeError.
        return None
    s = memview.tobytes().decode()

    if tree.is_val_quoted(nid):
        return s

    s_lower = s.lower()

    # An empty plain scalar (e.g. ``key:``) is null in YAML; previously it fell
    # through to int("") and was returned as "".
    if s_lower in {"", "null", "~"}:
        return None
    if s_lower in {"true", "yes"}:
        return True
    if s_lower in {"false", "no"}:
        return False
    # YAML spells infinity/NaN with a leading dot, which float() rejects.
    if s_lower in {".inf", "+.inf"}:
        return float("inf")
    if s_lower == "-.inf":
        return float("-inf")
    if s_lower == ".nan":
        return float("nan")
    try:
        if "." in s or "e" in s_lower:
            return float(s)
        return int(s)
    except ValueError:
        return s


def _to_native(t: ryml.Tree, nid: int):
    """Recursively convert the subtree rooted at ``nid`` into native Python objects.

    Mapping nodes become ``dict`` (keys are always decoded to ``str``; a later
    duplicate key overwrites an earlier one). Sequence nodes become ``list``.
    Every other node is treated as a scalar leaf and converted with
    ``_cast_scalar``.
    """
    if t.is_map(nid):
        mapping = {}
        for child in ryml.children(t, nid):
            key_text = t.key(child).tobytes().decode()
            mapping[key_text] = _to_native(t, child)
        return mapping

    if t.is_seq(nid):
        items = []
        for child in ryml.children(t, nid):
            items.append(_to_native(t, child))
        return items

    return _cast_scalar(t, nid)


def _ryml_tree_to_native(tree: ryml.Tree):
    """Convert an entire parsed rapidyaml tree, starting at its root node, to native Python objects."""
    root_node = tree.root_id()
    return _to_native(tree, root_node)