dataset_config.py 5.09 KB
Newer Older
maming's avatar
maming committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from pathlib import Path
from typing import Any, Dict, Optional, Type, TypeVar, Union

from megatron.energon.epathlib import EPath
from megatron.energon.eyaml import load_yaml
from megatron.energon.flavors import (
    BaseCoreDatasetFactory,
    CrudeSample,
    DefaultCrudeJsonlDatasetFactory,
    StandardWebdatasetFactory,
)
from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME
from megatron.energon.flavors.webdataset.metadata import EnergonDatasetType, get_dataset_type
from megatron.energon.typed_converter import JsonParser
from megatron.energon.worker import WorkerConfig

# Generic type of the object that `load_config` is asked to instantiate and return.
T = TypeVar("T")


def load_config(
    path: Union[EPath, Dict[str, Any]],
    *,
    default_type: Type[T],
    default_kwargs: Optional[Dict[str, Any]] = None,
    parser: JsonParser = JsonParser(strict=True),
) -> T:
    """
    Build an instance from a config given either as a yaml file or as an in-memory dictionary.

    Args:
        path: Either an already loaded config dictionary (used as-is), or a path object whose
            `open("rb")` stream is read with `load_yaml`.
        default_type: Type handed to the parser; it is used when the config does not specify
            a type itself.
        default_kwargs: Optional base values. When given, a copy is made and the config
            entries are written on top of it, so the config wins on conflicting keys. The
            passed dictionary itself is not modified.
        parser: Object whose `raw_to_instance` turns the raw config into the instance.
            The default strict `JsonParser` is created once at function definition time and
            is therefore shared between calls that do not pass their own parser.

    Returns:
        The instantiated type.
    """
    if isinstance(path, dict):
        # The caller already supplied the parsed config.
        raw_config = path
    else:
        # Otherwise the config lives in a file and has to be parsed first.
        with path.open("rb") as stream:
            raw_config = load_yaml(stream)

    if default_kwargs is None:
        # Nothing to merge, instantiate straight from the raw config.
        return parser.raw_to_instance(raw_config, default_type)

    # Start from the defaults and let the config override them.
    merged = default_kwargs.copy()
    merged.update(raw_config)
    return parser.raw_to_instance(merged, default_type)


# Covariant sample type parameter of the dataset factory returned by `get_dataset_from_config`.
T_sample = TypeVar("T_sample", covariant=True)


def get_dataset_from_config(
    path: Union[EPath, Path, str],
    *,
    dataset_config: str | None = None,
    split_config: str | None = None,
    split_part: str | None = None,
    training: bool = True,
    subflavors: Optional[Dict[str, Any]] = None,
    worker_config: WorkerConfig,
    sample_type: Optional[Type[T_sample]] = None,
    **kwargs,
) -> BaseCoreDatasetFactory[T_sample]:
    """
    Create the dataset factory for a prepared webdataset folder or for a jsonl file.

    The kind of dataset is determined by `get_dataset_type` on the given path.

    Args:
        path: Path to the folder which contains the `.nv-meta` folder, or path to a jsonl file.
        dataset_config: Filename of the dataset config inside the meta folder
            (`path / MAIN_FOLDER_NAME / dataset_config`). Falls back to "dataset.yaml" for
            webdatasets. Must be None for jsonl datasets.
        split_config: Filename of the split config. Falls back to "split.yaml" for
            webdatasets. Must be None for jsonl datasets.
        split_part: Name of the split to load. Falls back to "train" for webdatasets.
            Ignored for jsonl datasets, which are always returned in full.
        training: If true, apply training randomization and loop the dataset. Forwarded to
            the dataset factory.
        subflavors: If not None, merged into (overriding) the `subflavors` of the created
            dataset.
        worker_config: Required worker configuration, forwarded to the dataset factory.
        sample_type: Expected sample type. If given, the `__sample_type__` of the created
            dataset must be a subclass of it. For jsonl datasets only `CrudeSample` (or None)
            is accepted.
        **kwargs: Additional arguments forwarded to the dataset constructor. For webdatasets
            they act as defaults which the dataset config file may override.

    Returns:
        The instantiated dataset

    Raises:
        ValueError: If the path is neither detected as a webdataset nor as a jsonl dataset.
        AssertionError: If one of the consistency checks on the arguments or on the resulting
            sample type fails.
    """
    root = EPath(path)
    detected_type = get_dataset_type(root)
    factory: BaseCoreDatasetFactory[T_sample]

    if detected_type == EnergonDatasetType.WEBDATASET:
        # Resolve the config file first, using the conventional name if none was given.
        config_name = "dataset.yaml" if dataset_config is None else dataset_config
        config_path = root / MAIN_FOLDER_NAME / config_name

        # These values are only defaults: entries of the config file take precedence.
        factory_defaults = {
            "path": root,
            "split_config": "split.yaml" if split_config is None else split_config,
            "split_part": "train" if split_part is None else split_part,
            "training": training,
            "worker_config": worker_config,
            **kwargs,
        }
        factory = load_config(
            config_path,
            default_type=StandardWebdatasetFactory,
            default_kwargs=factory_defaults,
        )
    elif detected_type == EnergonDatasetType.JSONL:
        # Jsonl datasets have no meta folder, so config related arguments make no sense here.
        assert sample_type is None or sample_type is CrudeSample, (
            f"Sample type must be CrudeSample for jsonl datasets, but got {sample_type}"
        )
        assert dataset_config is None, (
            f"Dataset config must be None for jsonl datasets, but got {dataset_config}"
        )
        assert split_config is None, (
            f"Split config must be None for jsonl datasets, but got {split_config}"
        )
        # Note: split_part is deliberately not checked, the full dataset is always returned.

        factory = DefaultCrudeJsonlDatasetFactory(
            root,
            training=training,
            subflavors=subflavors,
            worker_config=worker_config,
            **kwargs,
        )
    else:
        raise ValueError(
            f"Path {root} does not contain a {MAIN_FOLDER_NAME}/.info.yaml or .info.json file nor is it a jsonl file. "
            f"Did you forget to prepare the dataset? Please check the documentation for an introduction to dataset "
            f"preparation."
        )

    # Caller supplied subflavors override whatever the dataset brought along.
    if subflavors is not None:
        factory.subflavors.update(subflavors)

    # The sample type is only verified if the caller asked for a specific one.
    assert sample_type is None or issubclass(factory.__sample_type__, sample_type), (
        f"Sample of type {factory.__sample_type__} is not a subclass of {sample_type}."
    )
    return factory