loader.py 10.5 KB
Newer Older
maming's avatar
maming committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from pathlib import Path
from typing import Any, List, Literal, Optional, Tuple, TypeVar, Union

from megatron.energon.epathlib import EPath
from megatron.energon.errors import warn_deprecated
from megatron.energon.flavors import BaseCoreDatasetFactory, SavableDataset
from megatron.energon.flavors.webdataset.sample_decoder import SampleDecoder
from megatron.energon.metadataset import load_dataset
from megatron.energon.task_encoder.base import DefaultTaskEncoder, TaskEncoder, WorkerConfig

T = TypeVar("T", covariant=True)


def _split_kwargs(kwargs: dict) -> dict:
    loader_kwargs = {}
    try:
        loader_kwargs["split_part"] = kwargs.pop("split_part")
    except KeyError:
        pass
    try:
        loader_kwargs["dataset_config"] = kwargs.pop("dataset_config")
    except KeyError:
        pass
    try:
        loader_kwargs["split_config"] = kwargs.pop("split_config")
    except KeyError:
        pass
    return loader_kwargs


def _split_deprecated_decoder_kwargs(kwargs: dict, task_encoder: TaskEncoder) -> None:
    """
    auto_decode: bool = True,
    image_decode: ImageDecoder = "torchrgb",
    ignore_decoder_errors: bool = False,
    av_decode: AVDecoder = "AVDecoder",
    video_decode_audio: bool = False,
    """
    auto_decode = True

    decoder_kwargs = {}
    if "auto_decode" in kwargs:
        auto_decode = kwargs.pop("auto_decode")
    if "image_decode" in kwargs:
        decoder_kwargs["image_decode"] = kwargs.pop("image_decode")
    if "av_decode" in kwargs:
        decoder_kwargs["av_decode"] = kwargs.pop("av_decode")
    if "video_decode_audio" in kwargs:
        decoder_kwargs["video_decode_audio"] = kwargs.pop("video_decode_audio")

    if not auto_decode:
        task_encoder.decoder = None
    elif len(decoder_kwargs) > 0:
        warn_deprecated(
            "The following decoder kwargs are deprecated and will be removed in a future version: "
            + ", ".join(decoder_kwargs.keys())
            + ". Instead, set the decoder directly in your task encoder."
        )

        if (
            hasattr(task_encoder, "decoder")
            and task_encoder.decoder is not None
            and task_encoder.decoder is not DefaultTaskEncoder.decoder
        ):
            # The task encoder already has a decoder set.
            # The user might be reusing the task encoder in multiple calls to get_train_dataset
            # and get_val_dataset.
            # We need to check if the decoder is the same as the one we are setting here.
            # If it is, we can return.
            if task_encoder.decoder.config() == SampleDecoder(**decoder_kwargs).config():
                # It's the same decoder, nothing to do.
                return
            else:
                raise ValueError(
                    "Task encoder already has a decoder, and you are setting a different decoder, which is not allowed."
                )

        assert (
            not hasattr(task_encoder, "decoder")
            or task_encoder.decoder is DefaultTaskEncoder.decoder
        ), "Task encoder already has a decoder, and setting using deprecated kwargs is not allowed."

        task_encoder.decoder = SampleDecoder(**decoder_kwargs)


def get_train_dataset(
    path: Union[str, EPath, Path],
    *,
    split_part: Union[Literal["train"], str] = "train",
    worker_config: WorkerConfig,
    batch_size: Optional[int],
    batch_drop_last: bool = False,
    packing_buffer_size: Optional[int] = None,
    shuffle_buffer_size: Optional[int],
    max_samples_per_sequence: Optional[int],
    virtual_epoch_length: int = 0,
    shuffle_over_epochs_multiplier: Optional[int] = 1,
    task_encoder: TaskEncoder[Any, Any, Any, T] = DefaultTaskEncoder(),
    repeat: bool = True,
    **kwargs,
) -> SavableDataset[T]:
    """
    Get training data loader with sensible defaults. See `get_dataset` for more details.

    The following recipe will be used:
      - :func:`megatron.energon.dataset_config.get_dataset_from_config` (loads the raw dataset)
      - `task_encoder.encode_sample`
      - (:class:`megatron.energon.MixDataset` if mixing)
      - :class:`megatron.energon.BatchDataset` with `task_encoder.batch` for collation
      - `task_encoder.encode_batch`
      - :class:`megatron.energon.EpochizeDataset` (if `virtual_epoch_length` is set)

    Args:
        path: Path to the dataset.
        split_part: Default split part to use.
        worker_config: Worker configuration to use.
        batch_size: Size of a batch. If None, do not batch
        batch_drop_last: If true, drop the last batch if it is smaller than `batch_size`.
        shuffle_buffer_size: Size of the sample shuffle buffer (before task encoding).
        max_samples_per_sequence: If set, limit the number of samples per sample-sequence to this.
        virtual_epoch_length: If set, the dataset will be epochized to this length (=iterating
            will be suspended and the for-loop returns, next for-loop continues iterating).
            Otherwise, the dataset will loop indefinitely.
        shuffle_over_epochs_multiplier: Shuffle the shards over this many epochs.
        task_encoder: Task encoder to use.
        repeat: By default, the inner datasets will loop. If set to False, stop iteration after
            one epoch. Must only be set to False in conjunction with blend_epochized in the
            metadataset if one is used.
        cache_pool: If set, the cache pool to use for the dataset.
        **kwargs: Additional arguments to the dataset constructor.

    Returns:
        The dataloader.
    """

    loader = load_dataset(path, **_split_kwargs(kwargs))
    _split_deprecated_decoder_kwargs(kwargs, task_encoder)

    datasets = loader.get_datasets(
        training=True,
        split_part=split_part,
        worker_config=worker_config,
        max_samples_per_sequence=max_samples_per_sequence,
        shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier,
        decoder=task_encoder.decoder,
        **kwargs,
    )
    return task_encoder.build_train_datasets(
        datasets=datasets.datasets,
        worker_config=worker_config,
        batch_size=batch_size,
        batch_drop_last=batch_drop_last,
        packing_buffer_size=packing_buffer_size,
        virtual_epoch_length=virtual_epoch_length,
        shuffle_buffer_size=shuffle_buffer_size,
        blend_mode=datasets.blend_mode,
        repeat=repeat,
    )


def get_val_dataset(
    path: Union[str, EPath, Path],
    *,
    split_part: Union[Literal["val", "test"], str] = "val",
    worker_config: WorkerConfig,
    batch_size: int,
    batch_drop_last: bool = False,
    packing_buffer_size: Optional[int] = None,
    limit: Optional[int] = None,
    task_encoder: TaskEncoder[Any, Any, Any, T] = DefaultTaskEncoder(),
    **kwargs,
) -> SavableDataset[T]:
    """
    Get the validation/test dataset with sensible defaults. See `get_dataset` for more details.

    The following recipe will be used:
      - :func:`megatron.energon.dataset_config.get_dataset_from_config` (loads the raw dataset)
      - `task_encoder.encode_sample`
      - (:class:`megatron.energon.MixDataset` if mixing)
      - :class:`megatron.energon.BatchDataset` with `task_encoder.batch` for collation
      - :class:`megatron.energon.LimitDataset` (if `limit` is set)
      - `task_encoder.encode_batch`

    Args:
        path: Path to the dataset.
        split_part: Default split part to use.
        worker_config: Worker configuration to use.
        batch_size: Size of a batch
        batch_drop_last: If true, drop the last batch if it is smaller than `batch_size`.
        limit: If set, limit the number of batches loaded from the dataset to this.
        task_encoder: Task encoder to use.
        **kwargs: Additional arguments to the dataset constructor.

    Returns:
        The loaded dataset.
    """
    _split_deprecated_decoder_kwargs(kwargs, task_encoder)
    loader = load_dataset(path, **_split_kwargs(kwargs))
    datasets = loader.get_datasets(
        training=False,
        split_part=split_part,
        worker_config=worker_config,
        decoder=task_encoder.decoder,
        **kwargs,
    )
    return task_encoder.build_val_datasets(
        datasets=datasets.datasets,
        worker_config=worker_config,
        batch_size=batch_size,
        batch_drop_last=batch_drop_last,
        packing_buffer_size=packing_buffer_size,
        limit=limit,
    )


def get_val_datasets(
    path: Union[str, EPath, Path],
    *,
    split_part: Union[Literal["val", "test"], str] = "val",
    worker_config: WorkerConfig,
    batch_size: int,
    batch_drop_last: bool = False,
    packing_buffer_size: Optional[int] = None,
    limit: Optional[int] = None,
    task_encoder: TaskEncoder[Any, Any, Any, T] = DefaultTaskEncoder(),
    **kwargs,
) -> List[Tuple[SavableDataset[T], BaseCoreDatasetFactory]]:
    """
    Get the validation/test dataset with sensible defaults. See `get_dataset` for more details.

    The following recipe will be used:
      - :func:`megatron.energon.dataset_config.get_dataset_from_config` (loads the raw dataset)
      - `task_encoder.encode_sample`
      - (:class:`megatron.energon.MixDataset` if mixing)
      - :class:`megatron.energon.BatchDataset` with `task_encoder.batch` for collation
      - :class:`megatron.energon.LimitDataset` (if `limit` is set)
      - `task_encoder.encode_batch`

    Args:
        path: Path to the dataset.
        split_part: Default split part to use.
        worker_config: Worker configuration to use.
        batch_size: Size of a batch
        batch_drop_last: If true, drop the last batch if it is smaller than `batch_size`.
        limit: If set, limit the number of batches loaded from the dataset to this.
        task_encoder: Task encoder to use.
        **kwargs: Additional arguments to the dataset constructor.

    Returns:
        The loaded val datasets, with the source datasets.
    """
    _split_deprecated_decoder_kwargs(kwargs, task_encoder)
    loader = load_dataset(path, **_split_kwargs(kwargs))
    datasets = loader.get_datasets(
        training=False,
        split_part=split_part,
        worker_config=worker_config,
        decoder=task_encoder.decoder,
        **kwargs,
    )
    return [
        (
            task_encoder.build_val_datasets(
                datasets=[dataset],
                worker_config=worker_config,
                batch_size=batch_size,
                batch_drop_last=batch_drop_last,
                packing_buffer_size=packing_buffer_size,
                limit=limit,
            ),
            dataset.dataset,
        )
        for dataset in datasets.datasets
    ]