# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from abc import ABC, abstractmethod
from bisect import bisect_right
from typing import (
Callable,
Dict,
Generator,
Generic,
List,
Optional,
Tuple,
TypeVar,
Union,
overload,
)
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.webdataset.config import skip_meta_re, split_name_re
from megatron.energon.flavors.webdataset.indexing import JoinIndexReader, SqliteIndexReader
from megatron.energon.flavors.webdataset.itar import (
CachedItarOffsetReader,
ITarFile,
ITarRawSamplePartPointer,
ITarSamplePointer,
)
from megatron.energon.flavors.webdataset.metadata import get_info_shard_files
from megatron.energon.flavors.webdataset.structs import FilteredSample, ShardInfo
from megatron.energon.source_info import SourceInfo
T_index = TypeVar("T_index", covariant=False)
class ITarReader(ABC, Generic[T_index]):
"""
An abstract base class for reading a sequence of tar files containing samples.
The common usage patterns and random-access interfaces are provided here.
Args:
base_path: The base path for the relative tar_filepaths.
tar_filenames: The list of tar file names.
tar_filepaths: The corresponding list of full paths to the tar files.
part_filter: An optional filter function to select parts of the samples.
itar_cache_size: The number of tar readers to keep open at the same time.
sample_filter: An optional filter function to select samples by their key.
"""
base_path: EPath
tar_filenames: List[str]
tar_filepaths: List[EPath]
part_filter: Optional[Callable[[str], bool]]
itar_files_cache: Dict[int, ITarFile]
sample_filter: Optional[Callable[[str], bool]]
def __init__(
self,
base_path: EPath,
tar_filenames: List[str],
tar_filepaths: List[EPath],
part_filter: Optional[Callable[[str], bool]] = None,
itar_cache_size: int = 5,
sample_filter: Optional[Callable[[str], bool]] = None,
):
assert len(tar_filenames) == len(tar_filepaths), (
f"tar_filenames length ({len(tar_filenames)}) does not match "
f"tar_filepaths length ({len(tar_filepaths)})"
)
self.base_path = base_path
self.tar_filenames = tar_filenames
self.tar_filepaths = tar_filepaths
self.part_filter = part_filter
self.itar_files_cache = {}
self.itar_cache_size = itar_cache_size
self.sample_filter = sample_filter
@abstractmethod
def __len__(self) -> int:
"""Returns the total number of samples in the reader."""
raise NotImplementedError
@abstractmethod
def __str__(self) -> str:
"""
Must return a descriptive string of the concrete reader.
"""
raise NotImplementedError
def close(self):
for tar_file in self.itar_files_cache.values():
tar_file.fileobj.close()
tar_file.close()
self.itar_files_cache.clear()
@abstractmethod
def _get_itar_sample_pointer(self, idx: T_index) -> ITarSamplePointer:
"""Get the ITarSample object for the given index."""
raise NotImplementedError
def _get_itarfile_cached(self, tar_file_id: int) -> ITarFile:
"""
Get the ITarFile object for the given tar file id.
If the file is not already open, open it. If we exceed
the global cache limit, close the least recently used file.
"""
if tar_file_id not in self.itar_files_cache:
file_object = self.tar_filepaths[tar_file_id].open(mode="rb")
tar_file = ITarFile.open(fileobj=file_object, mode="r:")
self.itar_files_cache[tar_file_id] = tar_file
# If we hit the limit of open files, close the least recently used file
while len(self.itar_files_cache) > self.itar_cache_size:
# Get the oldest file
lru_key = next(iter(self.itar_files_cache))
self.itar_files_cache[lru_key].fileobj.close()
self.itar_files_cache[lru_key].close()
del self.itar_files_cache[lru_key]
return self.itar_files_cache[tar_file_id]
def _get_part_by_raw_sample_pointer(
self,
raw_sample_pointer: ITarRawSamplePartPointer,
entry_name: str,
) -> tuple[bytes, SourceInfo]:
"""
Get a sample part and the source info from the dataset.
Args:
raw_sample_pointer: The raw sample part pointer to read from.
entry_name: The name of the entry inside the tar file (used for the source info).
Returns:
A tuple of the raw data bytes and the corresponding SourceInfo.
"""
# Open the tar file (cached)
tar_file = self._get_itarfile_cached(raw_sample_pointer.tar_file_id)
shard_name = self.tar_filenames[raw_sample_pointer.tar_file_id]
# Get the raw data from the tar file
rest = tar_file.fileobj.tell()
tar_file.fileobj.seek(raw_sample_pointer.raw_byte_offset)
raw_data = tar_file.fileobj.read(raw_sample_pointer.raw_byte_size)
tar_file.fileobj.seek(rest)
return raw_data, SourceInfo(
dataset_path=self.base_path,
index=entry_name,
shard_name=shard_name,
file_names=(entry_name,),
)
def _get_item_by_sample_pointer(
self,
sample_pointer: ITarSamplePointer,
restore_index: str | int,
entry_match_fn: Optional[Callable[[str], bool]] = None,
) -> FilteredSample | None:
"""
Get a sample from the dataset by its sample pointer.
Args:
sample_pointer: The sample pointer to get the sample from.
restore_index: The restore key of the sample (its global index or sample key).
entry_match_fn: An optional function to filter the entries in the sample.
Returns:
The sample or None if the sample is not found.
"""
# Open the tar file (cached)
tar_file = self._get_itarfile_cached(sample_pointer.tar_file_id)
shard_name = self.tar_filenames[sample_pointer.tar_file_id]
sample_base_name = None
sample_name = None
group_parts: Dict[str, bytes] = {}
file_names: list[str] = []
# Position the tar file at the correct offset
tar_file.offset = sample_pointer.byte_offset
while tar_file.offset < sample_pointer.byte_offset + sample_pointer.byte_size:
tarinfo = tar_file.next()
if tarinfo is None:
raise ValueError(
f"Unexpected end of tar file: {self.tar_filenames[sample_pointer.tar_file_id]}"
)
fname = tarinfo.name
if not tarinfo.isfile() or fname is None:
continue
if skip_meta_re.match(fname):
continue
# Extract the base_name and extension
m = split_name_re.match(fname)
if not m:
continue
cur_base_name, cur_ext = m.groups()
if sample_base_name is None:
sample_base_name = cur_base_name
sample_name = f"{shard_name}/{cur_base_name}"
if self.sample_filter is not None and not self.sample_filter(sample_name):
return None
else:
if sample_base_name != cur_base_name:
raise ValueError(
f"Inconsistent sample base name: {sample_base_name} vs {cur_base_name}"
)
if entry_match_fn is not None:
# If entry_match_fn is provided, use it to determine if we should take this entry
take_entry = entry_match_fn(fname)
else:
# If no entry_match_fn is provided, use the part_filter to determine if we should take this entry
take_entry = self.part_filter is None or self.part_filter(cur_ext)
if take_entry:
member_bytes = tar_file.extractfile(tarinfo).read()
group_parts[cur_ext] = member_bytes
file_names.append(fname)
if sample_base_name is None:
raise ValueError(f"No valid files found in sample {sample_pointer}")
return FilteredSample(
__key__=f"{shard_name}/{sample_base_name}",
__shard__=self.tar_filenames[sample_pointer.tar_file_id],
__restore_key__=("Webdataset", restore_index),
__sources__=(
SourceInfo(
dataset_path=self.base_path,
index=restore_index,
shard_name=shard_name,
file_names=tuple(file_names),
),
),
**group_parts,
)
def __getitem__(self, idx: T_index) -> FilteredSample | None:
"""
Get a sample from the dataset by its index.
"""
assert isinstance(idx, int), f"Invalid argument type for __getitem__: {type(idx)}"
sample_pointer = self._get_itar_sample_pointer(idx)
return self._get_item_by_sample_pointer(sample_pointer, idx)
class JoinIndexFileITarReader(ITarReader[int]):
"""
A concrete ITarReader that reads samples from a join index file (via JoinIndexReader).
"""
index_file: EPath
column: int
index_reader_cache: Dict[int, JoinIndexReader]
index_reader_cache_size: int
def __init__(
self,
index_file: EPath,
column: int,
tar_filenames: List[str],
base_path: EPath,
part_filter: Optional[Callable[[str], bool]] = None,
itar_cache_size: int = 5,
sample_filter: Optional[Callable[[str], bool]] = None,
):
self.index_file = index_file
self.column = column
# Create the full path to each tar file
tar_filepaths = [base_path / fn for fn in tar_filenames]
self.index_reader_cache = {}
self.index_reader_cache_size = itar_cache_size
super().__init__(
base_path=base_path,
tar_filenames=tar_filenames,
tar_filepaths=tar_filepaths,
part_filter=part_filter,
itar_cache_size=itar_cache_size,
sample_filter=sample_filter,
)
def _get_join_index_reader_cached(self, sample_idx: int) -> JoinIndexReader:
"""
Get the JoinIndexReader object for the given sample index, or create it if it doesn't exist.
"""
if sample_idx not in self.index_reader_cache:
index_reader = JoinIndexReader(self.index_file, column=self.column)
self.index_reader_cache[sample_idx] = index_reader
# If we hit the limit of open files, close the least recently used file
while len(self.index_reader_cache) > self.index_reader_cache_size:
# Get the oldest file
lru_key = next(iter(self.index_reader_cache))
self.index_reader_cache[lru_key].close()
del self.index_reader_cache[lru_key]
return self.index_reader_cache[sample_idx]
def _get_itar_sample_pointer(self, sample_idx: int) -> ITarSamplePointer:
"""
Get the ITarSamplePointer for the given index.
"""
index_reader = self._get_join_index_reader_cached(sample_idx)
row = index_reader[sample_idx]
# Update cache entry
new_offset = index_reader.tell_row()
del self.index_reader_cache[sample_idx]
self.index_reader_cache[new_offset] = index_reader
assert len(row) == 1
shard_idx, byte_offset, byte_size = row[0]
return ITarSamplePointer(
tar_file_id=shard_idx,
byte_offset=byte_offset,
byte_size=byte_size,
)
def __len__(self) -> int:
try:
# Get any reader, they will all work
index_reader = next(iter(self.index_reader_cache.values()))
except StopIteration:
# If there's no reader yet, we need to create one to get the length
index_reader = self._get_join_index_reader_cached(0)
return len(index_reader)
def __str__(self) -> str:
return (
f"JoinIndexFileITarReader("
f"len={len(self)}, base_path={self.base_path}, "
f"len(shards)={len(self.tar_filenames)}, "
f"shards=[{self.tar_filenames[0] if self.tar_filenames else 'N/A'}, ...])"
)
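# Illustrative sketch (not part of the library): random access into one column of a join
# index via JoinIndexFileITarReader. The index path, column, and shard names below are
# hypothetical placeholders.
def _example_join_index_reader_usage():
    reader = JoinIndexFileITarReader(
        index_file=EPath("/data/joined_dataset/.nv-meta/index.bin"),  # hypothetical path
        column=0,
        tar_filenames=["shard_0000.tar", "shard_0001.tar"],
        base_path=EPath("/data/joined_dataset"),
    )
    try:
        print(f"{len(reader)} samples in column 0")
        sample = reader[0]  # FilteredSample (or None if filtered out)
        return sample
    finally:
        reader.close()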
class ShardInfosITarReader(ITarReader[int]):
"""
A concrete ITarReader that constructs its internal sample list from a list of ShardInfos.
"""
shard_infos: List[ShardInfo]
shard_tar_file_idxs: List[int]
shard_count_cumsum: List[int]
cached_offset_reader: CachedItarOffsetReader
def __init__(
self,
base_path: EPath,
shard_infos: List[ShardInfo],
part_filter: Optional[Callable[[str], bool]] = None,
itar_cache_size: int = 5,
sample_filter: Optional[Callable[[str], bool]] = None,
):
# Build the tar_filenames and tar_filepaths from shard_infos,
# constructing the samples tensor as we go.
cur_tar_files: Dict[str, Tuple[int, EPath]] = {}
self.shard_infos = shard_infos
# Compute the cumsum of the shard counts, so that we can look up
# the shard index for a given sample index.
# Get all tar files from the shard_infos
self.shard_count_cumsum = [0]
self.shard_tar_file_idxs = []
sample_idx = 0
for shardinfo in shard_infos:
filepath = shardinfo.path
filename = shardinfo.name
if filename not in cur_tar_files:
cur_tar_files[filename] = (len(cur_tar_files), filepath)
sample_idx += shardinfo.count
self.shard_count_cumsum.append(sample_idx)
self.shard_tar_file_idxs.append(cur_tar_files[filename][0])
tar_filenames = list(cur_tar_files.keys())
tar_filepaths = [p[1] for p in cur_tar_files.values()]
# Instantiate cached reader for the .tar.idx files
self.cached_offset_reader = CachedItarOffsetReader(cache_size=itar_cache_size)
super().__init__(
base_path=base_path,
tar_filenames=tar_filenames,
tar_filepaths=tar_filepaths,
part_filter=part_filter,
itar_cache_size=itar_cache_size,
sample_filter=sample_filter,
)
def _get_itar_sample_pointer(self, idx: int) -> ITarSamplePointer:
"""
Get the ITarSamplePointer for the given index.
"""
# Find the shard index using binary search
shard_idx = bisect_right(self.shard_count_cumsum, idx) - 1
if shard_idx < 0 or shard_idx >= len(self.shard_infos):
raise IndexError(f"Index out of bounds: {idx}")
# Get the shard info for the given index
shard = self.shard_infos[shard_idx]
sample_idx_in_shard_file = idx - self.shard_count_cumsum[shard_idx]
# Now we know the tar file and the sample offset in the file.
# We need to figure out the byte offset and size of the sample,
# by looking it up in the .tar.idx file.
byte_offset, byte_size = self.cached_offset_reader.get_itar_byte_offset(
shard.path, sample_idx_in_shard_file
)
return ITarSamplePointer(
tar_file_id=self.shard_tar_file_idxs[shard_idx],
byte_offset=byte_offset,
byte_size=byte_size,
)
def __len__(self) -> int:
return self.shard_count_cumsum[-1]
def __str__(self) -> str:
return (
f"ShardInfosITarReader("
f"len={len(self)}, base_path={self.base_path}, "
f"len(shards)={len(self.tar_filenames)}, "
f"shards=[{self.tar_filenames[0] if self.tar_filenames else 'N/A'}, ...])"
)
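# Minimal sketch (illustrative only): how the cumulative shard counts above map a global
# sample index to (shard index, index within that shard). The counts below are made up.
def _example_shard_cumsum_lookup():
    shard_count_cumsum = [0, 100, 250, 300]  # three shards with 100, 150 and 50 samples
    idx = 120
    shard_idx = bisect_right(shard_count_cumsum, idx) - 1
    sample_idx_in_shard_file = idx - shard_count_cumsum[shard_idx]
    assert (shard_idx, sample_idx_in_shard_file) == (1, 20)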
class SqliteITarEntryReader(ITarReader[str]):
"""
A concrete ITarReader that constructs its internal sample list from a SQLite database.
"""
sqlite_reader: SqliteIndexReader
db_has_sample_parts: bool
def __init__(
self,
base_path: EPath,
part_filter: Optional[Callable[[str], bool]] = None,
itar_cache_size: int = 5,
sample_filter: Optional[Callable[[str], bool]] = None,
key_is_full_entryname: bool = False,
):
from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME
from megatron.energon.flavors.webdataset.indexing import SqliteIndexReader
# shard_name_to_info_idx = {name: i for i, name in enumerate(wds_meta.info_shard_files)}
tar_filenames = get_info_shard_files(base_path)
tar_filepaths = [base_path / fn for fn in tar_filenames]
# Initialize the SQLite reader
sqlite_path = base_path / MAIN_FOLDER_NAME / "index.sqlite"
self.sqlite_reader = SqliteIndexReader(sqlite_path)
self.db_has_sample_parts = self.sqlite_reader.db_has_sample_parts()
self.key_is_full_entryname = key_is_full_entryname
super().__init__(
base_path=base_path,
tar_filenames=tar_filenames,
tar_filepaths=tar_filepaths,
part_filter=part_filter,
itar_cache_size=itar_cache_size,
sample_filter=sample_filter,
)
def _get_itar_sample_pointer(self, sample_key: str) -> ITarSamplePointer:
"""
Get the ITarSamplePointer for the given sample key.
"""
return self.sqlite_reader.get_sample_pointer_by_key(sample_key)
def list_all_samples(self) -> Generator[Tuple[str, int, int], None, None]:
"""List all samples in the jsonl file.
Returns:
A generator of tuples of (sample_key, size, tar_file_id)
"""
return self.sqlite_reader.list_all_samples()
def list_all_sample_parts(self) -> Generator[Tuple[str, int, int], None, None]:
"""List all sample parts in the jsonl file.
Returns:
A generator of tuples of (sample_key + "." + part_name, size, tar_file_id)
"""
return self.sqlite_reader.list_all_sample_parts()
def list_sample_parts(
self, sample_key: str, slow_mode: bool = False
) -> Generator[Tuple[str, int, int], None, None]:
"""Given a sample key, list all its parts. (E.g. given 001, list 001.jpg, 001.json, etc.)
If slow_mode is True, the parts are read directly from the tar files instead of the
sample_parts table (e.g. for older databases that do not contain that table).
Args:
sample_key: The sample key to list the parts of.
slow_mode: If True, read the parts from the tar files instead of the sample_parts
table (e.g. if the database is an older version without that table).
Returns:
A generator of tuples of (part_name, size, tar_file_id)
"""
if not slow_mode:
yield from self.sqlite_reader.list_sample_parts(sample_key)
else:
sample_pointer = self._get_itar_sample_pointer(sample_key)
sample = self._get_item_by_sample_pointer(sample_pointer, 0, entry_match_fn=None)
assert isinstance(sample, dict), f"Sample not found: {sample_pointer}"
for ext in sample.keys():
if not ext.startswith("__"):
yield ext, len(sample[ext]), sample_pointer.tar_file_id
def get_total_size(self) -> int:
return self.sqlite_reader.get_total_size()
@overload
def __getitem__(self, key: str) -> Union[FilteredSample, tuple[bytes, SourceInfo]]: ...
@overload
def __getitem__(self, key: slice) -> "ITarReader": ...
def __getitem__(
self, key: Union[slice, str]
) -> Union[FilteredSample, tuple[bytes, SourceInfo], ITarReader]:
"""
Either get a sample from the dataset by the sample key including all its entries,
or get the bytes of a specific entry by the full filename of the entry inside the tar.
"""
if isinstance(key, slice):
# Return a new reader with a sliced samples tensor
raise NotImplementedError("Slicing is not yet implemented")
assert isinstance(key, str), "Invalid argument type for __getitem__"
if self.key_is_full_entryname:
m = split_name_re.match(key)
if not m:
raise ValueError(f"Invalid file name: {key}")
sample_key, sample_ext = m.groups()
entry_match_fn = lambda fname: key == fname
if self.db_has_sample_parts:
# Directly fetch the sample part (byte offset and size) from the database
raw_sample_pointer = self.sqlite_reader.get_sample_part(sample_key, sample_ext)
raw_data, source_info = self._get_part_by_raw_sample_pointer(
raw_sample_pointer, key
)
return raw_data, source_info
else:
sample_key = key
sample_ext = None
entry_match_fn = None
sample_pointer = self._get_itar_sample_pointer(sample_key)
sample = self._get_item_by_sample_pointer(
sample_pointer, key, entry_match_fn=entry_match_fn
)
assert sample is not None, f"Sample not found: {sample_key}"
if self.key_is_full_entryname:
assert isinstance(sample_ext, str)
assert len(sample["__sources__"]) == 1
# Return the bytes directly
return sample[sample_ext], sample["__sources__"][0]
else:
return sample # Return the FilteredSample
def __len__(self) -> int:
"""Return the total number of samples in the database."""
return self.sqlite_reader.get_sample_count()
def __str__(self) -> str:
"""Return a descriptive string of this reader."""
return (
f"SqliteITarEntryReader("
f"len={len(self)}, base_path={self.base_path}, "
f"len(shards)={len(self.tar_filenames)}, "
f"shards=[{self.tar_filenames[0] if self.tar_filenames else 'N/A'}, ...])"
)
def close(self):
"""Close the SQLite reader and any open ITarFiles."""
# Close the SQLite reader
if hasattr(self, "sqlite_reader") and self.sqlite_reader is not None:
self.sqlite_reader.close()
# Close any open ITarFiles (using parent class implementation)
for tar_file_id in list(self.itar_files_cache.keys()):
tar_file = self.itar_files_cache[tar_file_id]
if (
tar_file is not None
and hasattr(tar_file, "fileobj")
and tar_file.fileobj is not None
):
tar_file.fileobj.close()
if tar_file is not None and hasattr(tar_file, "close"):
tar_file.close()
del self.itar_files_cache[tar_file_id]
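# Illustrative sketch (not part of the library): random access by sample key via
# SqliteITarEntryReader. The dataset path below is a hypothetical placeholder.
def _example_sqlite_entry_reader_usage():
    reader = SqliteITarEntryReader(base_path=EPath("/data/my_webdataset"))
    try:
        print(f"{len(reader)} samples: {reader}")
        for sample_key, size, tar_file_id in reader.list_all_samples():
            sample = reader[sample_key]  # FilteredSample with all parts of this sample
            print(sample_key, size, tar_file_id, sorted(sample.keys()))
            break
    finally:
        reader.close()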
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from abc import ABC
from typing import (
Any,
Callable,
Dict,
Generic,
List,
Mapping,
Optional,
Sequence,
Type,
TypeVar,
Union,
)
from megatron.energon.cache import FileStore
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.base_dataset import (
BaseCoreDatasetFactory,
Sample,
SavableDataset,
legacy_handler,
set_sample_restore_key,
)
from megatron.energon.flavors.webdataset.base_webdataset import BaseWebdatasetFactory
from megatron.energon.flavors.webdataset.error_handler import ErrorHandler
from megatron.energon.flavors.webdataset.indexing import JoinIndexReader
from megatron.energon.flavors.webdataset.itar_reader import JoinIndexFileITarReader
from megatron.energon.flavors.webdataset.sample_loader import (
RawSampleData,
WebdatasetSampleLoaderDataset,
)
from megatron.energon.flavors.webdataset.sharder import Sharder
from megatron.energon.flavors.webdataset.structs import DatasetSubset, ShardInfo, reraise_exception
from megatron.energon.source_info import SourceInfo
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.map_dataset import MapDataset
T_sample = TypeVar("T_sample", covariant=True)
class JoinedWebdatasetFactory(
BaseCoreDatasetFactory[T_sample], Sharder, ErrorHandler[T_sample], Generic[T_sample], ABC
):
"""
Base class for all webdataset loaders. Applies proper sharding across workers. Can join multiple datasets.
"""
training: bool
worker_config: WorkerConfig
shuffle_over_epochs: Optional[int] = 1
parallel_shard_iters: Optional[int]
max_samples_per_sequence: Optional[int]
subset: Optional[DatasetSubset]
join_index: EPath
handler: Callable[[Exception, Optional[str], Optional[list[SourceInfo]]], None]
shards: List[Sequence[ShardInfo]]
part_datasets: SavableDataset[T_sample]
inner_datasets: List[BaseWebdatasetFactory]
inner_dataset_keys: Optional[List[str]]
_sample_joiner: Callable[..., T_sample]
def __init__(
self,
inner_datasets: Union[Sequence[BaseWebdatasetFactory], Mapping[str, BaseWebdatasetFactory]],
*,
training: bool,
worker_config: WorkerConfig,
shuffle_over_epochs: Optional[int] = 1,
parallel_shard_iters: Optional[int] = None,
max_samples_per_sequence: Optional[int] = None,
subset: Optional[DatasetSubset] = None,
join_index: EPath,
joiner: Union[Type[T_sample], Callable[..., T_sample]],
handler: Callable[
[Exception, Optional[str], Optional[list[SourceInfo]]], None
] = reraise_exception,
):
"""
Constructs the loader for a joined webdataset. The samples from the inner datasets are joined into a single
sample using the joiner function.
Args:
inner_datasets: The inner datasets. Must be loaded internally with `_is_composed=True`.
Either a list (\\*args for joiner) or a dict (\\*\\*kwargs for joiner) of datasets,
where the samples will be passed to the joiner function as \\*args or \\*\\*kwargs.
training: If true, apply shuffling and loop the dataset.
worker_config: Configuration for the workers.
shuffle_over_epochs: Only effective if training=True.
How many epochs to shuffle over if training.
If = 1, every sample is seen exactly once per epoch.
If > 1, samples (or rather shard slices) are shuffled within this number of epochs
(i.e. randomly selected without replacement).
If -1, the shards are effectively shuffled over infinite epochs (i.e. shard slices
are drawn with replacement).
parallel_shard_iters: Number of parallel opened shards per worker, shuffling between.
max_samples_per_sequence: Maximum number of samples per sequence (=how many samples
will be sequentially iterated).
subset: If specified, the inner dataset(s) will be subsetted.
join_index: Path to the join index file. Only required for join_method="left".
joiner: Type of the joined samples or a method for joining the samples.
handler: Exception handler. Args: (exception, key, source_infos).
"""
self.__sample_type__ = joiner
assert all(not hasattr(d, "dataset") for d in inner_datasets), (
"Inner dataset was not instantiated with _is_composed=True"
)
if isinstance(joiner, type) and issubclass(joiner, Sample):
joiner = joiner.from_joined
else:
assert callable(joiner), f"Joiner {joiner} must be a callable or a Sample subclass"
if isinstance(inner_datasets, Mapping):
inner_keys = list(inner_datasets.keys())
self.inner_dataset_keys = inner_keys
# Wrap the joiner to pass the samples as kwargs
self._sample_joiner = lambda *samples: joiner(**dict(zip(inner_keys, samples)))
inner_datasets = list(inner_datasets.values())
else:
assert isinstance(inner_datasets, Sequence)
self._sample_joiner = joiner
self.inner_dataset_keys = None
self.join_index = join_index
self.inner_datasets = inner_datasets
self.shards = list(zip(*(dataset.shards for dataset in self.inner_datasets)))
self.training = training
self.worker_config = worker_config
self.shuffle_over_epochs = shuffle_over_epochs
self.parallel_shard_iters = parallel_shard_iters
self.max_samples_per_sequence = max_samples_per_sequence
self.subset = subset
self.handler = legacy_handler(handler)
def __len__(self) -> int:
return sum(shard.count for shard in self.inner_datasets[0].shards)
def build(self, worker_rotation_offset: int = 0) -> SavableDataset[T_sample]:
if self.parallel_shard_iters is None:
if self.training:
# 16 seems to be a good choice since we don't want too many file handles open
parallel_shard_iters = 16
else:
parallel_shard_iters = 1
else:
parallel_shard_iters = self.parallel_shard_iters
# Get join index, get size, distribute samples
# Get samples for each worker on current rank
assert self.join_index.is_file(), (
f"Join index {self.join_index} does not exist, did you prepare the metadataset? "
"If you already prepared the metadataset, the join index might be outdated due to "
"modifications to the inner datasets. In this case, you need to re-prepare the metadataset."
)
with JoinIndexReader(self.join_index) as jir:
total_samples = len(jir)
workers_sample_slice_offsets = self.slice_workers(
total_samples,
worker_config=self.worker_config,
max_samples_per_sequence=self.max_samples_per_sequence,
rotation_offset=worker_rotation_offset,
subset=self.subset,
)
for worker_idx, sample_slice_offsets in enumerate(workers_sample_slice_offsets):
start_idx = sample_slice_offsets[0]
end_idx = sample_slice_offsets[-1]
if len(sample_slice_offsets) > 6:
offset_str = f"{', '.join(str(o) for o in sample_slice_offsets[:3])} ...<{len(sample_slice_offsets) - 6}> {', '.join(str(o) for o in sample_slice_offsets[-3:])}"
else:
offset_str = ", ".join(str(o) for o in sample_slice_offsets)
print(
f"rank={self.worker_config.rank}, worker={worker_idx}: sample_range=[{start_idx}, {end_idx}) in {len(sample_slice_offsets) - 1} slices, "
f"sum(count)={end_idx - start_idx}: [{offset_str}]"
)
itar_readers = [
JoinIndexFileITarReader(
index_file=self.join_index,
column=col_idx,
tar_filenames=indexed_dataset.split_part_files,
base_path=indexed_dataset.path,
part_filter=indexed_dataset.part_filter,
itar_cache_size=parallel_shard_iters,
)
for col_idx, indexed_dataset in enumerate(self.inner_datasets)
]
dataset = WebdatasetSampleLoaderDataset(
join_readers=itar_readers,
workers_sample_slice_offsets=workers_sample_slice_offsets,
worker_config=self.worker_config,
shuffle_over_epochs=self.shuffle_over_epochs if self.training else None,
parallel_slice_iters=parallel_shard_iters,
)
return self._process_samples(dataset)
def as_file_store(self) -> FileStore:
raise NotImplementedError("Not supported on joined datasets")
@property
def paths(self) -> List[EPath]:
return [dataset.path for dataset in self.inner_datasets]
def _process_samples(self, dataset: SavableDataset[RawSampleData]) -> SavableDataset[T_sample]:
"""Internally loads the sample."""
return MapDataset(
dataset,
self.load_sample,
error_handler=self.error_handler,
stateless_map_fn=True,
map_fn_config=self.config,
worker_config=self.worker_config,
)
def load_sample(self, samples: RawSampleData) -> T_sample:
assert len(samples.data) > 0 and samples.data[0] is not None, "Always need primary sample"
# First call the loaders of all inner datasets
loaded_samples = tuple(
None if sample is None else dataset.load_sample(sample)
for dataset, sample in zip(self.inner_datasets, samples.data)
)
# Then combine the loaded samples into the final type
return set_sample_restore_key(
self._sample_joiner(*loaded_samples),
*samples.__restore_key__,
src=self,
fail_otherwise=True,
)
def config(self) -> Dict[str, Any]:
return dict(
type=type(self).__qualname__,
joined_datasets=[dataset.config() for dataset in self.inner_datasets],
training=self.training,
shuffle_over_epochs=self.shuffle_over_epochs,
parallel_shard_iters=self.parallel_shard_iters,
max_samples_per_sequence=self.max_samples_per_sequence,
subset=self.subset.config() if self.subset is not None else None,
)
def __str__(self):
return f"{type(self).__name__}(paths={self.paths})"
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import json
from enum import Enum
from typing import List, Set
import braceexpand
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib import EPath
from megatron.energon.eyaml import load_yaml, load_yaml_json
from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME
from megatron.energon.flavors.webdataset.empty_dataset_error import EmptyDatasetError
from megatron.energon.flavors.webdataset.structs import (
ShardInfo,
WebdatasetInfo,
WebdatasetSplits,
)
from megatron.energon.typed_converter import JsonParser
class EnergonDatasetType(Enum):
METADATASET = "metadataset"
WEBDATASET = "webdataset"
JSONL = "jsonl"
INVALID = "invalid"
@edataclass
class WebdatasetMeta:
"""Class for getting metadata from a webdataset."""
sample_excludes: Set[str]
shards: List[ShardInfo]
split_part_files: List[str]
info_shard_files: List[str]
@staticmethod
def from_config(
path: EPath,
*,
split_part: str,
split_config: str | None = None,
) -> "WebdatasetMeta":
"""
Loads the metadata for a webdataset, i.e. the shards and sample excludes.
Args:
path: The base path of the dataset.
split_part: Which part to load (e.g. 'train', 'val', 'test').
split_config: Config file to use for shard split definitions.
"""
if split_config is None:
split_config = "split.yaml"
parser = JsonParser(strict=True)
info_object = get_dataset_info(path)
info = parser.raw_to_typed(
info_object,
WebdatasetInfo,
)
try:
splits = parser.raw_to_typed(
load_yaml_json(path / MAIN_FOLDER_NAME / split_config),
WebdatasetSplits,
)
except FileNotFoundError:
if split_config == "split.yaml":
# Try split.json instead
splits = parser.raw_to_typed(
load_yaml_json(path / MAIN_FOLDER_NAME / "split.json"),
WebdatasetSplits,
)
else:
raise
assert split_part in splits.split_parts, f"Invalid split part: {split_part!r}"
split_excludes = {
excluded
for excluded in splits.exclude
for excluded in braceexpand.braceexpand(excluded)
}
all_split_part_files = [
name
for name in splits.split_parts[split_part]
for name in braceexpand.braceexpand(name)
]
split_part_files = [name for name in all_split_part_files if name not in split_excludes]
if len(split_part_files) == 0:
raise EmptyDatasetError(f"No shards found in split part {split_part!r}")
return WebdatasetMeta(
sample_excludes={excluded for excluded in split_excludes if "/" in excluded},
shards=[
ShardInfo(
name=name,
path=path / name,
count=info.shard_counts[name],
)
for name in split_part_files
],
split_part_files=all_split_part_files,
info_shard_files=list(info.shard_counts.keys()),
)
def get_info_shard_files(path: EPath) -> List[str]:
"""Use this if you don't need the full metadata for split parts, but just the shard files."""
parser = JsonParser(strict=True)
info = parser.raw_to_typed(
get_dataset_info(path),
WebdatasetInfo,
)
return list(info.shard_counts.keys())
def get_dataset_info(path: EPath) -> dict:
"""Given the path to an energon webdataset that contains a .nv-meta folder,
return the dataset info as a dict.
"""
info_config = path / MAIN_FOLDER_NAME / ".info.json"
# YAML for backwards compatibility
yaml_info_config = path / MAIN_FOLDER_NAME / ".info.yaml"
if info_config.is_file():
with info_config.open("r") as rf:
return json.load(rf)
elif yaml_info_config.is_file():
return load_yaml(yaml_info_config.read_bytes())
else:
raise ValueError(f"No info config file found at {info_config} or {yaml_info_config}")
def check_dataset_info_present(path: EPath) -> bool:
"""Given the path to an energon webdataset that contains a .nv-meta folder,
return True if the dataset info is present, False otherwise.
"""
return (path / MAIN_FOLDER_NAME / ".info.json").is_file() or (
path / MAIN_FOLDER_NAME / ".info.yaml"
).is_file()
def get_dataset_type(path: EPath) -> EnergonDatasetType:
"""Get the type of the dataset at the given path.
Args:
path: The path to the dataset as specified by the user.
Returns:
The type of the dataset.
"""
if path.is_file():
if path.name.endswith(".jsonl"):
return EnergonDatasetType.JSONL
elif path.name.endswith(".yaml"):
return EnergonDatasetType.METADATASET
else:
return EnergonDatasetType.INVALID
elif check_dataset_info_present(path):
return EnergonDatasetType.WEBDATASET
else:
return EnergonDatasetType.INVALID
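# Illustrative sketch (not part of the library): probing a path before loading it.
# The dataset path below is a hypothetical placeholder.
def _example_dataset_type_probe():
    path = EPath("/data/my_webdataset")
    if get_dataset_type(path) == EnergonDatasetType.WEBDATASET:
        meta = WebdatasetMeta.from_config(path, split_part="train")
        print(f"{len(meta.shards)} shards in the 'train' split")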
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import functools
import json
import logging
import random
import re
import tarfile
import uuid
from dataclasses import asdict
from pathlib import Path
from typing import (
Any,
Callable,
Container,
Dict,
Generator,
Iterator,
List,
Optional,
Set,
Tuple,
TypeVar,
Union,
)
import braceexpand
import yaml
from megatron.energon import __version__
from megatron.energon.bracecollapse import collapse
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib import EPath
from megatron.energon.eyaml import load_yaml
from megatron.energon.flavors.webdataset.aggregator_pool import AggregatorPool, BaseAggregator
from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME, skip_meta_re, split_name_re
from megatron.energon.flavors.webdataset.indexing import SqliteIndexWriter
from megatron.energon.flavors.webdataset.itar import TarIndexWriter
from megatron.energon.flavors.webdataset.structs import ShardInfo, WebdatasetInfo, WebdatasetSplits
from megatron.energon.typed_converter import to_json_object
logger = logging.getLogger(__name__)
T = TypeVar("T", covariant=True)
@edataclass
class IndexAggregatable:
"""
A base class for all objects that can be returned/yielded by `_preprocess_tar` and
received by `SqliteIndexWriterAggregator.on_item`.
"""
...
@edataclass
class IndexSample(IndexAggregatable):
tar_file_id: int
sample_key: str
sample_index: int
byte_offset: int
byte_size: int
@edataclass
class IndexSamplePart(IndexAggregatable):
tar_file_id: int
sample_index: int
part_name: str
content_byte_offset: int
content_byte_size: int
@edataclass
class IndexShardInfo(IndexAggregatable):
shard_info: ShardInfo
parts: Set[str]
class SqliteIndexWriterAggregator(
BaseAggregator[
Tuple[ShardInfo, Set[str]], Tuple[List[ShardInfo], Set[str], bool, List[Tuple[str, int]]]
]
):
sqlite_path: EPath
total_tasks: int
progress_fn: Optional[Callable]
writer: Optional[SqliteIndexWriter]
had_update: bool
shards: List[ShardInfo]
found_parts: Set[str]
prog_iter: Iterator
def __init__(
self,
sqlite_path: EPath,
total_tasks: int,
progress_fn: Optional[Callable[[Iterator[Any], int], Iterator[T]]] = None,
):
self.sqlite_path = sqlite_path
self.total_tasks = total_tasks
self.writer = None
self.had_update = False
self.shards = []
self.found_parts = set()
if progress_fn is not None:
self.prog_iter = progress_fn(iter(range(self.total_tasks)), self.total_tasks)
else:
self.prog_iter = iter(range(self.total_tasks))
def on_start(self, aggregator_pool: AggregatorPool) -> None:
self.writer = SqliteIndexWriter(self.sqlite_path)
def on_item(
self,
item: IndexAggregatable,
aggregator_pool: AggregatorPool,
) -> None:
assert self.writer is not None, "Writer is not initialized."
if isinstance(item, IndexSample):
self.writer.append_sample(**asdict(item))
self.had_update = True
elif isinstance(item, IndexSamplePart):
self.writer.append_part(**asdict(item))
elif isinstance(item, IndexShardInfo):
# This is an IndexShardInfo carrying the shard info and its found parts
next(self.prog_iter)
shard_info, cur_parts = item.shard_info, item.parts
assert shard_info.count != 0, f"Shard {shard_info.name} has no samples."
self.shards.append(shard_info)
if len(self.found_parts) < 50:
self.found_parts.update(cur_parts)
def on_finish(self, aggregator_pool: AggregatorPool) -> None:
assert self.writer is not None, "Writer is not initialized."
self.writer.close()
def get_final_result_data(
self,
) -> Tuple[List[ShardInfo], Set[str], bool, List[Tuple[str, int]]]:
assert self.writer is not None, "Writer is not initialized."
return self.shards, self.found_parts, self.had_update, self.writer.duplicates
class WebdatasetPreparator:
@staticmethod
def _preprocess_tar(
path: str,
shard_to_idx: Dict[str, int],
parent_path: EPath,
max_parts: int,
) -> Generator[IndexAggregatable, None, None]:
"""Process a single tar file, i.e. read the tarinfos, generate the tar index and return
stats.
This method is passed to the `user_produce_data` argument of AggregatorPool.
Args:
path: Path to the tar file.
shard_to_idx: Mapping from shard path to its index
parent_path: Root path of the dataset.
max_parts: Maximum number of different parts to return
Returns:
A generator of items that will be processed by SqliteIndexWriterAggregator.
See method `on_item` of SqliteIndexWriterAggregator.
The items are either:
- An IndexSample or IndexSamplePart with information about the offset, key etc.
- Or an IndexShardInfo with the shard info and the set of found parts for statistics.
"""
shard_info = ShardInfo(name=path, path=parent_path / path, count=0)
try:
# Note: Write to .tmp file first, then remove .tmp extension, to make sure only complete
# files are used.
tar: tarfile.TarFile
with shard_info.path.open("rb") as f:
with (
tarfile.open(fileobj=f, mode="r:*") as tar,
TarIndexWriter(shard_info.path) as iw,
):
count = 0
# The parts set is used to collect various file endings that are
# available in the dataset. This is used for the interactive prepare wizard.
parts = set()
last_base_name = None
member: tarfile.TarInfo
next_index_sample = None
for member in tar:
if not member.isreg():
continue
if member.name is None:
continue
if skip_meta_re.match(member.name):
continue
name_match = split_name_re.match(member.name)
if name_match is None:
continue
base_name = name_match.group(1)
if len(parts) < max_parts:
parts.add(name_match.group(2))
if last_base_name != base_name:
iw.append(member.offset)
if next_index_sample is not None:
next_index_sample["byte_size"] = (
member.offset - next_index_sample["byte_offset"]
)
yield IndexSample(**next_index_sample)
next_index_sample = dict(
tar_file_id=shard_to_idx[path],
sample_key=base_name,
sample_index=count,
byte_offset=member.offset,
)
last_base_name = base_name
count += 1
# Yield this part of the sample to the aggregator
yield IndexSamplePart(
tar_file_id=shard_to_idx[path],
sample_index=count - 1,
part_name=name_match.group(2),
content_byte_offset=member.offset_data,
content_byte_size=member.size,
)
shard_info.count = count
iw.append(tar.offset)
if next_index_sample is not None:
next_index_sample["byte_size"] = (
tar.offset - next_index_sample["byte_offset"]
)
yield IndexSample(**next_index_sample)
yield IndexShardInfo(shard_info=shard_info, parts=parts)
return
except BaseException:
logger.exception(f"Shard failed to load: {path!r}. Skipping it.")
yield IndexShardInfo(shard_info=shard_info, parts=set())
return
@staticmethod
def iter_dataset_content(
path: Union[str, EPath],
extract_keys: Container[str] = (),
) -> Generator[Dict[str, Any], None, None]:
"""
Yield example dataset content for a few samples.
Args:
path: Path to the tar file.
extract_keys: Part names (extensions) whose contents are read; all other parts are set to None.
"""
path = EPath(path)
with path.open("rb") as f:
tar: tarfile.TarFile
with tarfile.open(fileobj=f, mode="r:*") as tar:
last_base_name = None
sample = {}
member: tarfile.TarInfo
for member in tar:
if not member.isreg():
continue
if member.name is None:
continue
if skip_meta_re.match(member.name):
continue
name_match = split_name_re.match(member.name)
if name_match is None:
continue
base_name = name_match.group(1)
if last_base_name != base_name:
if sample:
yield sample
sample = {}
last_base_name = base_name
if name_match:
if name_match.group(2) in extract_keys:
sample[name_match.group(2)] = tar.extractfile(member).read()
else:
sample[name_match.group(2)] = None
if sample:
yield sample
@classmethod
def prepare_dataset(
cls,
parent_path: Union[Path, EPath],
paths: List[str],
*,
split_parts_ratio: Optional[List[Tuple[str, float]]] = None,
split_parts_patterns: Optional[List[Tuple[str, str]]] = None,
split_config: str = "split.yaml",
shuffle_seed: Optional[int] = 42,
progress_fn: Callable[[Iterator[Any], int], Iterator[T]] = (lambda x, y: x),
workers: int = 32,
tar_index_only: bool = False,
) -> Tuple[Set[str], List[Tuple[str, int]]]:
"""
Preprocess the shards and write the split config. Preprocessing is done in parallel.
Counts the number of samples in each shard.
Args:
parent_path: Common parent path for the shards
paths: Paths to the shards
split_parts_ratio: Names of splits and their ratio (will be normalized)
split_parts_patterns: Names of splits and their path patterns
split_config: Filename for the split config (`parent_path / '.nv-meta' / split_config`), may be yaml or json
shuffle_seed: Seed for shuffling shards before splitting into split_parts. None to
disable.
progress_fn: Callback for progress bar
workers: Number of parallel workers for reading each shard
tar_index_only: Only create tar-index, then exit
Returns:
The set of all parts found in the shards (at most 50), and the list of duplicate sample keys.
"""
parent_path = EPath(parent_path)
paths = [path for path in paths for path in braceexpand.braceexpand(path)]
# Construct a mapping from relative shard path to its index
shard_to_idx = {path: idx for idx, path in enumerate(paths)}
(parent_path / MAIN_FOLDER_NAME).mkdir(exist_ok=True)
aggregator = SqliteIndexWriterAggregator(
parent_path / MAIN_FOLDER_NAME / "index.sqlite",
total_tasks=len(paths),
progress_fn=progress_fn,
)
process_tar = functools.partial(
cls._preprocess_tar,
shard_to_idx=shard_to_idx,
parent_path=parent_path,
max_parts=50,
)
pool = AggregatorPool(
num_workers=workers,
user_produce_data=process_tar,
aggregator=aggregator,
)
for path in paths:
pool.submit_task(path)
shards, found_parts, had_update, duplicates = pool.process()
if had_update:
logger.info("Regenerating dataset UUID...")
with (parent_path / MAIN_FOLDER_NAME / "index.uuid").open("w") as f:
f.write(str(uuid.uuid4()))
json_info_config = parent_path / MAIN_FOLDER_NAME / ".info.json"
yaml_info_config = parent_path / MAIN_FOLDER_NAME / ".info.yaml"
if tar_index_only:
if yaml_info_config.is_file() and not json_info_config.is_file():
# Convert legacy .info.yaml to .info.json
with json_info_config.open("w") as f:
json.dump(load_yaml(yaml_info_config.read_bytes()), f, indent=2)
return found_parts, duplicates
assert len(shards) == len(shard_to_idx), (
f"Lengths of shards and shard_to_idx do not match: {len(shards)} != {len(shard_to_idx)}"
)
# Sort the shards according to the order in the input list
shards.sort(key=lambda shard: shard_to_idx[shard.name])
# Save info
assert [shard.name for shard in shards] == list(shard_to_idx.keys()), (
"Shards are not in the same order as in the input list."
)
info = WebdatasetInfo(
energon_version=__version__,
shard_counts={shard.name: shard.count for shard in shards},
)
print(f"Saving info to {json_info_config}")
with json_info_config.open("w") as wf:
json.dump(to_json_object(info), wf, indent=2)
if yaml_info_config.is_file():
# If a .info.yaml existed previously, let's also update it
# to keep them in sync
with yaml_info_config.open("w") as wf:
yaml.dump(to_json_object(info), wf)
if split_parts_ratio is not None:
# Normalize ratio
total_ratio = sum(split_ratio for _, split_ratio in split_parts_ratio)
split_parts_ratio = [
(split_part, split_ratio / total_ratio)
for split_part, split_ratio in split_parts_ratio
]
# Sample from shards based on the split ratio from split parts
split_shards = {}
if shuffle_seed is not None:
random.Random(shuffle_seed).shuffle(shards)
split_total = 0
split_offset = 0
for split_part, split_ratio in split_parts_ratio:
split_total += split_ratio
split_end = int(len(shards) * split_total)
split_shards[split_part] = [shard.name for shard in shards[split_offset:split_end]]
split_offset = split_end
else:
assert split_parts_patterns is not None, (
"Require either split_parts_ratio or split_parts_patterns"
)
# Sample from shards based on the split patterns from split parts
split_shards = {}
for split_part, split_pattern in split_parts_patterns:
patterns = [
re.compile(pattern) for pattern in braceexpand.braceexpand(split_pattern)
]
split_shards[split_part] = [
shard.name
for shard in shards
if any(pattern.match(shard.name) for pattern in patterns)
]
# Optimize the split parts by trying to bracecollapse the shard names
print("Collapsing split parts... ", flush=True, end="")
for split_part in split_shards:
split_shards[split_part] = collapse(split_shards[split_part], keep_order=True)
print("Done", flush=True)
# Save split config
splits_config = WebdatasetSplits(split_parts=split_shards)
with (parent_path / MAIN_FOLDER_NAME / split_config).open("w") as wf:
if split_config.endswith(".yaml"):
yaml.dump(to_json_object(splits_config), wf, sort_keys=False)
elif split_config.endswith(".json"):
json.dump(to_json_object(splits_config), wf, indent=2)
else:
raise ValueError(f"Invalid split config extension: {split_config}")
return found_parts, duplicates
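# Illustrative sketch (not part of the library): preparing a dataset with a 90/10
# train/val split. The parent path and shard pattern below are hypothetical placeholders.
def _example_prepare_dataset():
    found_parts, duplicates = WebdatasetPreparator.prepare_dataset(
        EPath("/data/my_webdataset"),
        ["shards/shard_{0000..0099}.tar"],
        split_parts_ratio=[("train", 0.9), ("val", 0.1)],
        workers=8,
    )
    print(f"Found parts: {sorted(found_parts)}; {len(duplicates)} duplicate sample keys")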
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Any, Callable, Literal
import webdataset
from megatron.energon.av import AVWebdatasetDecoder
from megatron.energon.cache import FileStoreDecoder
from megatron.energon.flavors.webdataset.config import split_name_re
try:
import filetype
except ImportError:
filetype = None
ImageDecoderType = Literal[
"l8",
"rgb8",
"rgba8",
"l",
"rgb",
"rgba",
"torchl8",
"torchrgb8",
"torchrgba8",
"torchl",
"torchrgb",
"torch",
"torchrgba",
"pill",
"pil",
"pilrgb",
"pilrgba",
]
AVDecoderType = Literal["torch", "AVDecoder", "pyav"]
class GuessingHandlerWrapper:
"""A wrapper that guesses the extension of the file using the `filetype` package."""
def __init__(self, handler: Callable[[str, bytes], Any]):
"""
Wraps a handler to guess the extension of the file using the `filetype` package.
Args:
handler: The handler to wrap.
"""
self.handler = handler
if filetype is None:
raise ImportError("filetype is not installed. Install it with `pip install filetype`.")
def __call__(self, key: str, data: bytes) -> Any:
"""The handler that guesses the extension of the file using the `filetype` package, then calls the delegate handler."""
kind = filetype.guess(data)
if kind is not None:
key = kind.extension
return self.handler(key, data)
@staticmethod
def wrap(
active: bool, handlers: list[Callable[[str, bytes], Any]]
) -> list[Callable[[str, bytes], Any]]:
"""
Wraps a list of handlers to guess the extension of the file using the `filetype` package.
Args:
active: Whether to wrap the handlers.
handlers: The handlers to wrap.
Returns:
The list of wrapped handlers.
"""
if not active:
return handlers
return [GuessingHandlerWrapper(handler) for handler in handlers]
class SampleDecoder(FileStoreDecoder):
"""The default decoder for webdataset samples."""
def __init__(
self,
*,
image_decode: ImageDecoderType = "torchrgb",
av_decode: AVDecoderType = "AVDecoder",
video_decode_audio: bool = False,
guess_content: bool = False,
):
"""
Args:
image_decode: How images are decoded and the type they are returned as (see ImageDecoderType).
av_decode: If "AVDecoder", returns an AVDecoder instance for flexible decoding. If "torch",
returns decoded VideoData.
video_decode_audio: Whether to decode audio from video files.
guess_content: Whether to guess the contents of the file using the `filetype` package.
"""
self._config = dict(
image_decode=image_decode,
av_decode=av_decode,
video_decode_audio=video_decode_audio,
guess_content=guess_content,
)
self._decoder = webdataset.autodecode.Decoder(
GuessingHandlerWrapper.wrap(
guess_content,
[
webdataset.autodecode.imagehandler(image_decode),
AVWebdatasetDecoder(
video_decode_audio=video_decode_audio,
av_decode=av_decode,
),
],
),
)
def decode(self, fname: str, raw: bytes) -> Any:
m = split_name_re.match(fname)
if m:
cur_base_name, ext = m.groups()
else:
# If the file name does not match the expected format, use the file name as is.
# The handlers may still handle the file correctly (e.g. if guessing by content).
cur_base_name = fname
ext = "."
return self(
{
"__key__": cur_base_name,
ext: raw,
}
)[ext]
def __call__(self, sample: dict) -> dict:
return self._decoder(sample)
def config(self) -> dict:
return self._config
DEFAULT_DECODER = SampleDecoder()
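# Illustrative sketch (not part of the library): decoding one raw tar entry by its file
# name. The file name and the way the raw bytes are obtained are hypothetical.
def _example_sample_decoder_usage(raw_jpeg_bytes: bytes):
    decoder = SampleDecoder(image_decode="torchrgb")
    # split_name_re splits "000001.jpg" into base name "000001" and extension "jpg";
    # the extension selects the image handler, which here returns a torch tensor.
    return decoder.decode("000001.jpg", raw_jpeg_bytes)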
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Any, Dict, Generator, Iterator, List, Optional, Sequence, Tuple, Union
import torch
from megatron.energon.edataclass import edataclass
from megatron.energon.flavors.base_dataset import FlexState, SavableDataset
from megatron.energon.flavors.webdataset.itar_reader import ITarReader
from megatron.energon.flavors.webdataset.structs import FilteredSample
from megatron.energon.rng import WorkerRng
from megatron.energon.worker import WorkerConfig
@edataclass
class RawSampleData:
"""Represents the iteration state of a single slice slice to the index."""
#: Index of the sample. This is also the restore key
__restore_key__: Tuple[str, int]
#: The sample data
data: Tuple[Optional[FilteredSample], ...]
@edataclass
class SliceState:
"""Represents the iteration state of a single slice slice to the index."""
#: The slice index of this slice state
index: int
#: The actual state: The global sample offset (`slice[index] <= offset < slice[index + 1]`)
current: int
class WebdatasetSampleLoaderDataset(SavableDataset[RawSampleData]):
"""Internal class for loading samples from webdataset slices"""
#: The readers for each joined dataset
join_readers: Sequence[ITarReader]
#: The offsets of the sample slices to iterate over for the current worker
slice_offsets: Optional[Sequence[int]]
# If None, shuffling is disabled. If = 1, every sample is seen exactly once per epoch.
# If > 1, samples (or rather sample slices) are shuffled within this number of epochs
# (i.e. randomly selected without replacement). If -1, the slices are effectively
# shuffled over infinite epochs (i.e. slices are drawn with replacement).
shuffle_over_epochs: Optional[int]
# Number of parallel iterators to be opened simultaneously (samples are randomly drawn between them)
parallel_slice_iters: int
# Worker's random generator
_worker_rng: WorkerRng
#: The RNG state to be used for regenerating the pending slices
_pending_slices_rng_state: Optional[FlexState]
#: The number of slices that have already been opened / processed and thus been removed from the
# pending slices.
_pending_slices_offset: Optional[int]
#: Pending slices are the slices which have not yet been opened, but should be processed
# in the current "epoch". If None, regenerate from the seed and offset.
_pending_slice_indexes: Optional[List[int]]
#: The active slices are the currently opened slices. May contain `None`, if there are fewer
# slices available (i.e. pending_slices empty) than parallel slice iterators requested.
_active_slice_state: List[Optional[SliceState]]
#: The total number of samples retrieved; a monotonically increasing counter
_sample_count: int
#: Number of epochs this dataset has been iterated over
_epoch_count: int
#: The number of samples retrieved in current epoch
_epoch_sample_count: int
_savable_fields = (
"_worker_rng",
"_pending_slices_offset",
"_pending_slice_indexes",
"_active_slice_state",
"_sample_count",
"_epoch_count",
"_epoch_sample_count",
)
def __init__(
self,
join_readers: Sequence[ITarReader],
workers_sample_slice_offsets: Sequence[Sequence[int]],
*,
worker_config: WorkerConfig,
shuffle_over_epochs: Optional[int] = None,
parallel_slice_iters: int = 1,
):
"""
The webdataset loader. Iterates over the slice infos and yields the samples.
Args:
join_readers: A sequence of the joined readers (or just a single reader) to iterate over.
workers_sample_slice_offsets: The sample slice offsets to iterate over, for each worker.
worker_config: The worker configuration.
shuffle_over_epochs: If None, disable shuffling.
If = 1, every sample is seen exactly once per epoch.
If > 1, samples (or rather sample slices) are shuffled within this number of epochs
(i.e. randomly selected without replacement).
If -1, the slices are effectively shuffled over infinite epochs (i.e. slices
are drawn with replacement).
parallel_slice_iters: If > 1, samples are randomly drawn from parallel slice iterators.
This will not impact performance, but increase randomness. If = 1, the slices are
iterated in order.
"""
super().__init__(worker_config=worker_config)
self.join_readers = join_readers
self.shuffle_over_epochs = shuffle_over_epochs
self.parallel_slice_iters = parallel_slice_iters
# Store the slices for all workers
# The slices for the current worker, will have to be extracted from this list later
self.workers_slice_offsets = workers_sample_slice_offsets
self.slice_offsets = None
self.reset_state_own()
assert shuffle_over_epochs is None or shuffle_over_epochs == -1 or shuffle_over_epochs >= 1
assert self.parallel_slice_iters >= 1
def reset_state_own(self) -> None:
self._worker_rng = WorkerRng(self.worker_config)
self._pending_slice_indexes = None
self._pending_slices_offset = None
self._pending_slices_rng_state = None
self._active_slice_state = [None] * self.parallel_slice_iters
self._sample_count = 0
self._epoch_count = 0
self._epoch_sample_count = 0
def ensure_slice_offsets(self) -> None:
self.worker_config.assert_worker()
if self.slice_offsets is None:
self.slice_offsets = self.workers_slice_offsets[self.worker_config.rank_worker_id()]
def _get_sample(self, index: int) -> RawSampleData:
return RawSampleData(
__restore_key__=("Webdataset", index),
data=tuple(reader[index] for reader in self.join_readers),
)
def _slices_once(self) -> List[int]:
"""Yields the indexes to slice offsets once. Possibly shuffles the list."""
assert self.slice_offsets is not None
num_slices = len(self.slice_offsets) - 1
slices_offset = self._pending_slices_offset
if self.shuffle_over_epochs is None:
# No shuffling
res_list = list(range(num_slices))
if slices_offset is None:
slices_offset = 0
else:
# Restore state or start new (and save)
if slices_offset is None:
# Start new state. First, save the state to restore the same order.
self._pending_slices_rng_state = self._worker_rng.save_state()
rng = self._worker_rng
slices_offset = 0
else:
# Restore the state. Create a dedicated rng for this, as the main rng is in the
# state for iterating from the next iterator.
assert self._pending_slices_rng_state is not None
rng = WorkerRng(self.worker_config)
rng.restore_state(self._pending_slices_rng_state)
if self.shuffle_over_epochs == -1:
# Shuffle with replacement (i.e. infinite epochs), effectively return as many slices
# as are required for parallel slice iterators.
# Next slices are drawn in the _slices_iter.
res_list = [rng.randbelow(num_slices) for _ in range(self.parallel_slice_iters)]
elif self.shuffle_over_epochs >= 1:
# Shuffle without replacement (potentially over multiple epochs)
res_list = rng.shuffle(list(range(num_slices)) * self.shuffle_over_epochs)
else:
raise ValueError(f"Invalid shuffle_over_epochs: {self.shuffle_over_epochs}")
# Reverse, such that pop returns the first element (in O(1) time)
res_list.reverse()
# Skip the slices of the restored slice list that have already been processed
assert slices_offset is not None
self._pending_slices_offset = slices_offset
if slices_offset > 0:
# Those have already been popped in the current state
del res_list[-slices_offset:]
# Set the pending slices
self._pending_slice_indexes = res_list
return res_list
def _slices_iter(self) -> Generator[RawSampleData, None, None]:
"""Iterates the samples in a list of slices, possibly using multiple parallel iterators over
the slices."""
assert self.slice_offsets is not None
active_slice_probs = torch.zeros(self.parallel_slice_iters, dtype=torch.float32)
active_slices = self._active_slice_state
pending_slice_indexes = self._pending_slice_indexes
def slice_at(idx: int) -> SliceState:
assert self.slice_offsets is not None
return SliceState(
index=idx,
current=self.slice_offsets[idx],
)
# Weight the slices by their size to get a more even distribution of samples
if any(s is not None for s in active_slices) or self._pending_slices_offset is not None:
# Having an active state, or pending slices. This means we are resuming an epoch.
if pending_slice_indexes is None:
# Need to restore the pending slices
pending_slice_indexes = self._slices_once()
assert pending_slice_indexes is not None
# Restore the state
assert len(active_slices) == self.parallel_slice_iters
for idx, slice_state in enumerate(active_slices):
if slice_state is not None:
active_slice_probs[idx] = (
self.slice_offsets[slice_state.index + 1]
- self.slice_offsets[slice_state.index]
)
if self.worker_config.should_log(level=1):
self.worker_config.worker_log(
{
"t": "WebdatasetSampleLoaderDataset._slices_iter.resume_epoch",
"r": self.worker_config.rank,
"w": self.worker_config.rank_worker_id(),
"pending_slice_indexes": pending_slice_indexes,
"active_slices": [
(
None
if state is None
else {
"index": state.index,
"current": state.current,
}
)
for state in active_slices
],
"count": self._sample_count,
"epoch": self._epoch_count,
"epoch_count": self._epoch_sample_count,
"probs": active_slice_probs.tolist(),
}
)
else:
# Start a new epoch
assert pending_slice_indexes is None
pending_slice_indexes = self._slices_once()
if self.worker_config.should_log(level=1):
self.worker_config.worker_log(
{
"t": "WebdatasetSampleLoaderDataset._slices_iter.next_epoch",
"r": self.worker_config.rank,
"w": self.worker_config.rank_worker_id(),
"pending_slice_indexes": pending_slice_indexes,
"count": self._sample_count,
"epoch": self._epoch_count,
"epoch_count": self._epoch_sample_count,
"probs": active_slice_probs.tolist(),
"shuffle_over_epochs": self.shuffle_over_epochs,
}
)
assert self._pending_slices_offset is not None
# List of slice iterators, always of length `parallel_slice_iters`. May contain `None`.
active_slices.clear()
# Fill up the slice iterators
while len(pending_slice_indexes) > 0 and len(active_slices) < self.parallel_slice_iters:
slice_index = pending_slice_indexes.pop()
self._pending_slices_offset += 1
slice_state = slice_at(slice_index)
active_slice_probs[len(active_slices)] = (
self.slice_offsets[slice_state.index + 1]
- self.slice_offsets[slice_state.index]
)
active_slices.append(slice_state)
# Fill up the slice iterators with None
for _ in range(len(active_slices), self.parallel_slice_iters):
active_slices.append(None)
# print(
# f"Next slice iters generated for {self.worker_config.rank}:{self.worker_config.rank_worker_id()}: probs={active_slice_probs}"
# )
# for slice_state in active_slices:
# if slice_state is None:
# print(" - None")
# else:
# print(
# f" - [{slice_offsets[slice_state.index]}, {slice_offsets[slice_state.index + 1]}] at {slice_state.current}"
# )
# Iterate over the slice iterators while there is an iterator left
while torch.count_nonzero(active_slice_probs).item() > 0:
if self.shuffle_over_epochs is None:
# No shuffling, deterministic order, always the same
assert self.parallel_slice_iters == 1
slice_idx = 0
else:
# Take a random slice iterator
slice_idx = self._worker_rng.choice_idx(active_slice_probs)
slice_state = active_slices[slice_idx]
assert slice_state is not None
sample = self._get_sample(slice_state.current)
# print(f"Read sample at {slice_state.current} -> {'None' if sample is None or sample.data[0] is None else sample.data[0]['__key__']}")
slice_state.current += 1
self._sample_count += 1
self._epoch_sample_count += 1
if slice_state.current >= self.slice_offsets[slice_state.index + 1]:
# Iterator exhausted -> take next / remove from list
if len(pending_slice_indexes) > 0 or self.shuffle_over_epochs == -1:
if len(pending_slice_indexes) > 0:
# Take the next slice (without replacement)
next_idx = pending_slice_indexes.pop()
assert self._pending_slices_offset is not None
self._pending_slices_offset += 1
else:
# Randomly select a new slice directly (with replacement)
num_slices = len(self.slice_offsets) - 1
next_idx = self._worker_rng.randbelow(num_slices)
next_slice_state = slice_at(next_idx)
active_slice_probs[slice_idx] = (
self.slice_offsets[next_slice_state.index + 1]
- self.slice_offsets[next_slice_state.index]
)
active_slices[slice_idx] = next_slice_state
# print(
# f"Slice iter for {self.worker_config.rank}:{self.worker_config.rank_worker_id()} "
# f"[{slice_offsets[slice_state.index]}, {slice_offsets[slice_state.index + 1]}] exhausted at {slice_state.current}, "
# f"taking next slice {next_slice_state} [{slice_offsets[next_slice_state.index]}, {slice_offsets[next_slice_state.index + 1]}], "
# f"{len(pending_slice_indexes)} slices left, probs={active_slice_probs.tolist()}"
# )
else:
active_slice_probs[slice_idx] = 0
active_slices[slice_idx] = None
# print(
# f"Slice iter for {self.worker_config.rank}:{self.worker_config.rank_worker_id()} "
# f"[{slice_offsets[slice_state.index]}, {slice_offsets[slice_state.index + 1]}] exhausted at {slice_state.current}, "
# f"no next slice, probs={active_slice_probs.tolist()}"
# )
if self.worker_config.should_log(level=2):
self.worker_config.worker_log(
{
"t": "WebdatasetSampleLoaderDataset._slices_iter.exhausted",
"r": self.worker_config.rank,
"w": self.worker_config.rank_worker_id(),
"remaining": len(pending_slice_indexes),
"count": self._sample_count,
"epoch": self._epoch_count,
"epoch_count": self._epoch_sample_count,
"probs": active_slice_probs.tolist(),
}
)
if sample.data[0] is not None:
# Otherwise the sample was skipped.
if self.worker_config.should_log(level=1):
self.worker_config.worker_log(
{
"t": "WebdatasetSampleLoaderDataset._slices_iter.yield",
"r": self.worker_config.rank,
"w": self.worker_config.rank_worker_id(),
"index": sample.__restore_key__[1],
"key": sample.data[0]["__key__"],
"shard": sample.data[0]["__shard__"],
"count": self._sample_count,
"epoch": self._epoch_count,
"epoch_count": self._epoch_sample_count,
}
)
# Now, yield the sample
yield sample
del sample
if self.worker_config.should_log(level=2):
self.worker_config.worker_log(
{
"t": "WebdatasetSampleLoaderDataset._slices_iter.all_exhausted",
"r": self.worker_config.rank,
"w": self.worker_config.rank_worker_id(),
"count": self._sample_count,
"epoch": self._epoch_count,
"epoch_count": self._epoch_sample_count,
}
)
# Epoch has finished, reset states.
self._epoch_count += 1
self._epoch_sample_count = 0
self._pending_slice_indexes = None
self._pending_slices_offset = None
# print(
# f"slice iters exhausted for {self.worker_config.rank}:{self.worker_config.rank_worker_id()} after {cnt} samples"
# )
def len_worker(self, worker_idx: int | None = None) -> int:
if worker_idx is None:
self.worker_config.assert_worker()
worker_idx = self.worker_config.rank_worker_id()
worker_slice_offsets = self.workers_slice_offsets[worker_idx]
return worker_slice_offsets[-1] - worker_slice_offsets[0]
def worker_has_samples(self) -> bool:
self.worker_config.assert_worker()
self.ensure_slice_offsets()
assert self.slice_offsets is not None
return len(self.slice_offsets) > 1
def __iter__(self) -> Iterator[RawSampleData]:
self.worker_config.assert_worker()
self.ensure_slice_offsets()
assert self.slice_offsets is not None
if self.worker_config.should_log(level=1):
self.worker_config.worker_log(
{
"t": "WebdatasetSampleLoaderDataset.__iter__",
"r": self.worker_config.rank,
"w": self.worker_config.rank_worker_id(),
"slice_offsets": self.slice_offsets,
"parallel_slice_iters": self.parallel_slice_iters,
"shuffle_over_epochs": self.shuffle_over_epochs,
}
)
if len(self.slice_offsets) <= 1:
return
yield from self._slices_iter()
def can_restore_sample(self) -> bool:
return True
def assert_can_restore(self) -> None:
pass
def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> RawSampleData:
# Key is: ("Webdataset", index)
# The key is joined in the dataset's typed joining (i.e. load_sample of JoinedWebdatasetFactory).
id, index = restore_key
assert id == "Webdataset"
assert isinstance(index, int)
return self._get_sample(index)
def config(self) -> Dict[str, Any]:
return {
"type": type(self).__qualname__,
"joins": len(self.join_readers),
"len": len(self),
"slice_offsets": [
list(worker_slice_offsets) for worker_slice_offsets in self.workers_slice_offsets
],
"worker_config": self.worker_config.config(),
"shuffle_over_epochs": self.shuffle_over_epochs,
"parallel_slice_iters": self.parallel_slice_iters,
}
def __str__(self):
return f"WebdatasetSampleLoaderDataset(join_readers={self.join_readers}, shuffle_over_epochs={self.shuffle_over_epochs}, parallel_slice_iters={self.parallel_slice_iters})"
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from itertools import zip_longest
from typing import Generator, Optional, Sequence, Tuple, Union
import numpy as np
from megatron.energon.flavors.webdataset.structs import DatasetSubset, ShardInfo
from megatron.energon.worker import WorkerConfig
class Sharder:
@staticmethod
def _split_shard(
start_offset: int,
end_offset: int,
max_samples_per_sequence: Optional[int],
) -> Tuple[int, ...]:
"""Splits a shard into multiple slices of max_samples_per_sequence (more or less).
Returns the starting index of each slice (excluding the end_offset)."""
if (
max_samples_per_sequence is not None
and end_offset - start_offset > max_samples_per_sequence * 1.5
):
# Split the shard into slices of max_samples_per_sequence (more or less)
slice_count = max(round((end_offset - start_offset) / max_samples_per_sequence), 1)
samples_per_sequence = (end_offset - start_offset) / slice_count
# Note that the end offset is not included here; the callers append it, giving slice_count + 1 offsets in total
return tuple(
start_offset + int(slice_idx * samples_per_sequence) for slice_idx in range(slice_count)
)
else:
return (start_offset,)
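# Illustrative only (not part of the original module): a sketch of how
# `_split_shard` behaves. Note that `round(2.5) == 2` (banker's rounding):
#
#   >>> Sharder._split_shard(start_offset=0, end_offset=10, max_samples_per_sequence=4)
#   (0, 5)   # two slices of ~5 samples; the callers append the end offset 10
#   >>> Sharder._split_shard(start_offset=0, end_offset=5, max_samples_per_sequence=4)
#   (0,)     # 5 <= 4 * 1.5, so the shard is not split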
@classmethod
def _split_shards(
cls,
shard_cumsums: np.ndarray,
offsets: Sequence[int],
*,
max_samples_per_sequence: Optional[int],
) -> Generator[Sequence[int], None, None]:
"""
Splits the shards into multiple slice sequences based on the offsets. The first offset is
the start of the first sequence emitted, the last offset is the end of the last sequence
emitted (i.e. the number of slice sequences emitted is `len(offsets) - 1`).
Args:
shard_cumsums: The source shard offsets
offsets: The offsets to samples to get shards for (must be strictly increasing)
max_samples_per_sequence: Maximum number of samples per sequence (=how many samples
will be sequential).
Returns:
The slice start offsets followed by the end offset, for each consecutive pair of offsets
"""
# Find shard idx for start
start_index = np.searchsorted(shard_cumsums, offsets[0], side="right") - 1
for start_offset, end_offset in zip(offsets, offsets[1:]):
# Find shard idx for end
end_index = start_index
while end_index + 1 < len(shard_cumsums) and end_offset > shard_cumsums[end_index + 1]:
end_index += 1
if start_index == end_index:
yield (
*cls._split_shard(
start_offset=start_offset,
end_offset=end_offset,
max_samples_per_sequence=max_samples_per_sequence,
),
end_offset,
)
else:
# Middle is the original shards, start and end get an offset/length
yield (
*(
cls._split_shard(
start_offset=start_offset,
end_offset=shard_cumsums[start_index + 1],
max_samples_per_sequence=max_samples_per_sequence,
)
if shard_cumsums[start_index + 1] > start_offset
else ()
),
*(
offset
for inner_shard_start, inner_shard_end in zip(
shard_cumsums[start_index + 1 : end_index],
shard_cumsums[start_index + 2 : end_index + 1],
)
for offset in cls._split_shard(
start_offset=inner_shard_start,
end_offset=inner_shard_end,
max_samples_per_sequence=max_samples_per_sequence,
)
),
*cls._split_shard(
start_offset=shard_cumsums[end_index],
end_offset=end_offset,
max_samples_per_sequence=max_samples_per_sequence,
),
end_offset,
)
start_index = end_index
@classmethod
def _split_slices(
cls,
offsets: Sequence[int],
*,
max_samples_per_sequence: Optional[int],
) -> Generator[Sequence[int], None, None]:
"""
Splits the offsets into approximately `max_samples_per_sequence` sized slices. Each sequence
of slices includes the end of that sequence.
Args:
offsets: The offsets to samples to get shards for (must be strictly increasing)
max_samples_per_sequence: Maximum number of samples per sequence (=how many samples
will be sequential).
Returns:
A list of offsets for each slice sequence.
"""
for start, end in zip(offsets[:-1], offsets[1:]):
yield (
*cls._split_shard(
start_offset=start,
end_offset=end,
max_samples_per_sequence=max_samples_per_sequence,
),
end,
)
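# Illustrative only: unlike `_split_shard`, each yielded sequence includes its
# end offset:
#
#   >>> list(Sharder._split_slices([0, 10, 13], max_samples_per_sequence=4))
#   [(0, 5, 10), (10, 13)]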
@classmethod
def _generalized_bit_reversal(
cls, length_or_indices: Union[int, Sequence[int]]
) -> Sequence[int]:
"""This function creates a permutation of given length.
The sequence is created by a recursive divide and interleave algorithm
to ensure a balanced distribution across ranks.
It corresponds to a generalized bit reversal permutation, which - for lengths
of power of two - is the reversed binary representation of the original indices.
For example for 16 indices, the sequence is:
[0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15]
Visual illustration:
Step|0|1|2|3|4|5|6|7|8|9|A|B|C|D|E|F|
|-------------------------------|
0|X| | | | | | | | | | | | | | | |
1|X| | | | | | | |X| | | | | | | |
2|X| | | |X| | | |X| | | | | | | |
3|X| | | |X| | | |X| | | |X| | | |
4|X| |X| |X| | | |X| | | |X| | | |
5|X| |X| |X| | | |X| |X| |X| | | |
6|X| |X| |X| |X| |X| |X| |X| | | |
7|X| |X| |X| |X| |X| |X| |X| |X| |
8|X|X|X| |X| |X| |X| |X| |X| |X| |
9|X|X|X| |X| |X| |X|X|X| |X| |X| |
10|X|X|X| |X|X|X| |X|X|X| |X| |X| |
11|X|X|X| |X|X|X| |X|X|X| |X|X|X| |
12|X|X|X|X|X|X|X| |X|X|X| |X|X|X| |
13|X|X|X|X|X|X|X| |X|X|X|X|X|X|X| |
14|X|X|X|X|X|X|X|X|X|X|X|X|X|X|X| |
15|X|X|X|X|X|X|X|X|X|X|X|X|X|X|X|X|
"""
if isinstance(length_or_indices, int):
indices = list(range(length_or_indices))
else:
indices = length_or_indices
if len(indices) <= 2:
return indices
mid = len(indices) // 2
left = indices[:mid]
right = indices[mid:]
left_result = cls._generalized_bit_reversal(left)
right_result = cls._generalized_bit_reversal(right)
# Interleave the results
zipped = zip_longest(left_result, right_result)
result = [item for sublist in zipped for item in sublist if item is not None]
return result
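# Illustrative only: the recursion also handles non-power-of-two lengths by
# interleaving the two (uneven) halves:
#
#   >>> Sharder._generalized_bit_reversal(8)
#   [0, 4, 2, 6, 1, 5, 3, 7]
#   >>> Sharder._generalized_bit_reversal(6)
#   [0, 3, 1, 4, 2, 5]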
@classmethod
def split_samples_to_workers(
cls,
start_samples: int,
end_samples: int,
worker_config: WorkerConfig,
*,
rotation_offset: int = 0,
) -> Sequence[int]:
# We split the total number of samples into the number of global workers across all ranks.
# Note that the global number of workers intentionally stays the same if you
# divide the number of ranks by N, and multiply the number of workers per rank by N.
# This allows to reproduce the same global batches with a different number of ranks.
total_samples = end_samples - start_samples
num_workers = max(1, worker_config.num_workers)
global_workers = num_workers * worker_config.world_size
min_samples_per_worker = int(total_samples / global_workers)
num_workers_with_more_samples = total_samples % global_workers
# We are going to compute the samples assigned to each worker on the current rank.
# This is done in multiple steps.
# Some of these steps could be collapsed into one, but we keep them separate for clarity:
# 1. Compute the number of samples per global worker (rotated by rotation_offset,
# typically given by previous datasets).
# 2. Permute the number of samples per global worker by a generalized bit reversal sequence
# 3. Given the sample counts, compute the start and end indices for each global worker
# 4. Extract the local worker sample assignments for the current rank.
# 5. Split the shards based on the start and end indices.
# 1. Let's compute it globally for all workers first
num_samples_per_global_worker = []
for global_worker_idx in range(global_workers):
if (
global_worker_idx - rotation_offset + global_workers
) % global_workers < num_workers_with_more_samples:
# This worker gets one more sample
num_samples_per_global_worker.append(min_samples_per_worker + 1)
else:
# This worker gets the minimum number of samples
num_samples_per_global_worker.append(min_samples_per_worker)
# 2. Permute the number of samples per global worker
worker_bitrev_seq = cls._generalized_bit_reversal(global_workers)
# The worker_bitrev_seq is the order in which any remainder samples shall
# be assigned to workers.
# That means, the x-axis (array index) is the remainder sample index
# and the y-axis (value) is the global worker index.
# So we map the y (value) to the old global worker index from the linear sequence.
new_num_samples_per_global_worker = [-1] * global_workers
for old_worker_idx, new_worker_idx in enumerate(worker_bitrev_seq):
new_num_samples_per_global_worker[new_worker_idx] = num_samples_per_global_worker[
old_worker_idx
]
num_samples_per_global_worker = new_num_samples_per_global_worker
# 3. Compute the global worker sample start and end indices
global_worker_sample_split_offsets = [start_samples]
cur_offset = start_samples
for global_worker_idx in range(global_workers):
cur_offset += num_samples_per_global_worker[global_worker_idx]
global_worker_sample_split_offsets.append(cur_offset)
# 4. Now we extract the local rank's worker ranges
local_worker_sample_split_offsets = global_worker_sample_split_offsets[
worker_config.rank * num_workers : (worker_config.rank + 1) * num_workers + 1
]
assert len(local_worker_sample_split_offsets) == num_workers + 1, (
"If this fails, there's a bug in the code above."
)
return local_worker_sample_split_offsets
@staticmethod
def _clean_offsets(offsets: Sequence[int]) -> Sequence[int]:
"""Removes empty offset slices, i.e. duplicates from offsets."""
return (
*(int(start) for start, end in zip(offsets, offsets[1:]) if start < end),
int(offsets[-1]),
)
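# Illustrative only: equal consecutive offsets denote empty slices and are dropped:
#
#   >>> Sharder._clean_offsets((0, 0, 5, 5, 10))
#   (0, 5, 10)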
@staticmethod
def _compute_subset(
total_samples: int,
subset: Optional[DatasetSubset] = None,
) -> tuple[int, int]:
start_samples = 0
end_samples = total_samples
if subset is None:
return start_samples, end_samples
if subset.absolute_range is not None:
start_samples, end_samples = subset.absolute_range
if end_samples is None:
end_samples = total_samples
assert end_samples <= total_samples, (
f"Subset samples {subset.absolute_range} {end_samples=} > {total_samples=}"
)
assert start_samples <= end_samples, (
f"Subset samples {subset.absolute_range} {start_samples=} > {end_samples=}"
)
assert start_samples >= 0, (
f"Subset samples {subset.absolute_range} {start_samples=} < 0"
)
if subset.range is not None:
previous_total = end_samples - start_samples
end_samples = start_samples + int(previous_total * subset.range[1])
start_samples += int(previous_total * subset.range[0])
assert end_samples <= total_samples, (
f"Subset ratio {subset.range} {end_samples=} is larger than total samples {total_samples}"
)
assert start_samples <= end_samples, (
f"Subset ratio {subset.range} {start_samples=} > {end_samples=}"
)
assert start_samples >= 0, f"Subset ratio {subset.range} {start_samples=} < 0"
return start_samples, end_samples
@classmethod
def shard_workers(
cls,
shards: Sequence[ShardInfo],
worker_config: WorkerConfig,
*,
max_samples_per_sequence: Optional[int],
subset: Optional[DatasetSubset] = None,
rotation_offset: int = 0,
) -> Sequence[Sequence[int]]:
"""
Creates shard slices for each worker of the current rank.
For that, the number of global samples is split across the number of global workers across all
ranks. Then each worker gets a slice of the global samples.
Args:
shards: The shards to split
worker_config: The config for the current rank and workers
max_samples_per_sequence: Maximum number of samples per sequence (=how many samples
will be sequential).
subset: If specified, the dataset will be subsetted to the given ratio.
rotation_offset: The offset to use for the worker rotation.
Returns:
The slice offsets for all workers of the current rank
"""
end_samples = sum(shard.count for shard in shards)
if subset is not None:
start_samples, end_samples = subset.compute_subset(end_samples)
else:
start_samples = 0
local_worker_sample_split_offsets = cls.split_samples_to_workers(
start_samples,
end_samples,
worker_config,
rotation_offset=rotation_offset,
)
shard_cumsums = np.cumsum([0] + [shard.count for shard in shards])
return tuple(
# Filter out any empty shards for this worker
cls._clean_offsets(offsets)
for offsets in cls._split_shards(
shard_cumsums,
local_worker_sample_split_offsets,
max_samples_per_sequence=max_samples_per_sequence,
)
)
@classmethod
def slice_workers(
cls,
total_samples: int,
worker_config: WorkerConfig,
*,
max_samples_per_sequence: Optional[int],
subset: Optional[DatasetSubset] = None,
rotation_offset: int = 0,
) -> Sequence[Sequence[int]]:
"""
Creates shard slices for each worker of the current rank.
For that, the number of global samples is split across the number of global workers across all
ranks. Then each worker gets a slice of the global samples.
Args:
total_samples: The total number of samples
worker_config: The config for the current rank and workers
max_samples_per_sequence: Maximum number of samples per sequence (=how many samples
will be sequential).
subset: If specified, the dataset will be subsetted to the given ratio.
rotation_offset: The offset to use for the worker rotation.
Returns:
The slice offsets for all workers of the current rank
"""
start_samples, end_samples = cls._compute_subset(total_samples, subset)
local_worker_sample_split_offsets = cls.split_samples_to_workers(
start_samples,
end_samples,
worker_config,
rotation_offset=rotation_offset,
)
# Split the shards
return tuple(
# Filter out any empty shards for this worker
cls._clean_offsets(offsets)
for offsets in cls._split_slices(
local_worker_sample_split_offsets,
max_samples_per_sequence=max_samples_per_sequence,
)
)
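# Illustrative only, a minimal end-to-end sketch of the worker split. The
# WorkerConfig construction is an assumption based on the fields used above
# (rank, world_size, num_workers):
#
#   >>> wc = WorkerConfig(rank=0, world_size=1, num_workers=2)
#   >>> Sharder.slice_workers(total_samples=10, worker_config=wc, max_samples_per_sequence=None)
#   ((0, 5), (5, 10))   # worker 0 iterates samples [0, 5), worker 1 iterates [5, 10)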
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Generic, Type, TypeVar
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.webdataset.default_decoder_webdataset import (
DefaultDecoderWebdatasetFactory,
)
T_sample = TypeVar("T_sample", covariant=True)
class StandardWebdatasetFactory(DefaultDecoderWebdatasetFactory[T_sample], Generic[T_sample]):
"""
This dataset sample loader factory uses the sample type (e.g. as given in a dataset.yaml) and applies
the default loading logic, which includes decoding images, videos and containers.
"""
def __init__(
self,
path: EPath,
*,
sample_type: Type[T_sample],
**kwargs,
):
"""
Factory for the standard webdataset sample loader.
Args:
path: Path to the dataset (passed to parent)
sample_type: Type of the sample to be loaded
auto_decode: If true, use the default webdataset sample decoder.
image_decode: This defines the decoding results.
ignore_decoder_errors: If true, ignore errors when decoding.
subflavors: Subflavors dictionary to set for all loaded samples.
field_map: Mapping from the webdataset fields to the sample fields.
sample_loader: Function to load the sample from the webdataset fields. May be a string
in order to load a function from a module, or a callable directly.
part_filter: Filter for the parts to load. May be a string in order to load a function
from a module, or a callable directly.
split_part: Which part to load (e.g. 'train', 'val', 'test').
training: If true, apply shuffling and loop the dataset.
worker_config: Configuration for the workers.
shuffle_over_epochs: Only effective if training=True.
How many epochs to shuffle over if training.
If = 1, every sample is seen exactly once per epoch.
If > 1, samples (or rather shard slices) are shuffled within this number of epochs
(i.e. randomly selected without replacement).
If -1, the shards are effectively shuffled over infinite epochs (i.e. shard slices
are drawn with replacement).
parallel_shard_iters: Number of parallel opened shards per worker, shuffling between.
max_samples_per_sequence: Maximum number of samples per sequence (=how many samples
will be sequentially iterated).
subset: If specified, the dataset will be subsetted.
split_config: Config file to use for shard split definitions.
handler: Exception handler. Args: (exception, key).
"""
self.__sample_type__ = sample_type
assert issubclass(sample_type, Sample), (
f"sample_type must be subclass of megatron.energon.Sample, got {sample_type.__qualname__}"
)
super().__init__(path, **kwargs)
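# Illustrative only, a minimal sketch of constructing the factory directly;
# usually this happens indirectly via the dataset.yaml loading machinery.
# CaptioningSample is one of the energon sample types; the remaining kwargs are
# assumed to be forwarded to the parent factories as documented above:
#
#   factory = StandardWebdatasetFactory(
#       EPath("/data/my_dataset"),
#       sample_type=CaptioningSample,
#       training=True,
#       worker_config=worker_config,
#   )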
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from dataclasses import field
from typing import Dict, List, Optional, Tuple, TypedDict
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib import EPath
from megatron.energon.source_info import SourceInfo
@edataclass
class WebdatasetInfo:
"""Info about a webdataset. Format for `.nv-meta/.info.yaml` or `.nv-meta/.info.json`."""
#: The version of the energon library that was used to prepare the dataset
energon_version: Optional[str] = None
#: Maps shard name to number of samples in that shard
shard_counts: Dict[str, int]
@edataclass
class WebdatasetSplits:
"""Info about the splits of a webdataset. Format for `.nv-meta/split.yaml` or `.nv-meta/split.json`
(or custom user yaml/json)."""
#: Maps split part to list of shard names
split_parts: Dict[str, List[str]]
#: Set of "<shard name>" or "<shard name>/<sample index>" to exclude
exclude: List[str] = field(default_factory=list)
@edataclass
class ShardInfo:
"""Info about a single shard as passed through internally. Not exposed to the user."""
#: Name of the shard file (relative path from the nvinfo dir)
name: str
#: The path to the shard file
path: EPath
#: The number of samples in this shard
count: int
class FilteredSample(TypedDict):
"""This is just a definition for the internal loaders. Not exposed to the user."""
#: The key of the sample within the tar file.
#: If the tar file contains files 12.jpg and 12.txt,
#: those two files make one sample with the key "12"
__key__: str
#: The base name of the shard file e.g. "shard_000"
__shard__: str
#: Globally unique key to restore a sample from disk.
#: For example `("Webdataset", 123)` would restore the sample at index 123.
__restore_key__: Tuple[str, int]
#: The source information for the sample.
__sources__: tuple[SourceInfo, ...]
@edataclass
class DatasetSubset:
"""A subset of a dataset.
A range is a tuple of two values, where the first value is the start of the subset and the second value is the end of the subset.
The sharder uses the (absolute/relative) ranges to compute the subsets:
* `absolute_range` (unit is samples) is applied first on the (e.g. train/val/test) subset
* then `range` (where `(0, 1)` would correspond to the whole dataset) is applied as relative ratio on the subset that is left.
This is the struct used internally for computing the range. The config is loaded via the metadataset_v2.
"""
range: tuple[float, float] | None = None
absolute_range: tuple[int, int | None] | None = None
def compute_subset(
self,
total_samples: int,
) -> tuple[int, int]:
"""
Computes the absolute subset of samples from the total number of samples.
The absolute range is applied first, then the relative range is applied on the subset that is left.
"""
start_samples = 0
end_samples = total_samples
if self.absolute_range is not None:
start_samples, end_samples = self.absolute_range
if end_samples is None:
end_samples = total_samples
assert end_samples <= total_samples, (
f"Subset samples {self.absolute_range} {end_samples=} > {total_samples=}"
)
assert start_samples <= end_samples, (
f"Subset samples {self.absolute_range} {start_samples=} > {end_samples=}"
)
assert start_samples >= 0, f"Subset samples {self.absolute_range} {start_samples=} < 0"
if self.range is not None:
previous_total = end_samples - start_samples
end_samples = start_samples + int(previous_total * self.range[1])
start_samples += int(previous_total * self.range[0])
assert end_samples <= total_samples, (
f"Subset ratio {self.range} {end_samples=} is larger than total samples {total_samples}"
)
assert start_samples <= end_samples, (
f"Subset ratio {self.range} {start_samples=} > {end_samples=}"
)
assert start_samples >= 0, f"Subset ratio {self.range} {start_samples=} < 0"
return start_samples, end_samples
def config(self) -> dict:
return {
"range": self.range,
"absolute_range": self.absolute_range,
}
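# Illustrative only: the absolute range is applied first, then the relative
# ratio on what is left:
#
#   >>> DatasetSubset(absolute_range=(100, 1100), range=(0.0, 0.5)).compute_subset(2000)
#   (100, 600)   # [100, 1100) holds 1000 samples; the first half is [100, 600)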
def reraise_exception(
exc: Exception, key: Optional[str], sources: Optional[list[SourceInfo]] = None
) -> None:
if sources:
raise Exception(
f"For sample {key!r} from {', '.join(f'{source.dataset_path}[{source.index}] {source.shard_name}{source.file_names!r}' for source in sources)}"
) from exc
elif key:
raise Exception(f"For sample {key!r}") from exc
else:
raise
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import os
import random
import sqlite3
import threading
import time
from typing import Any, ClassVar
class ThreadLocalStorage:
"""
A class that allows storing selected attributes in thread-local storage.
Example Usage:
```python
class MyThreadLocalStorage(ThreadLocalStorage):
__thread_local__ = ("my_data",)
# This is shared across threads
other_data: int
# This is local per thread
my_data: int
def __thread_init__(self):
# This is called the first time the thread-local data is accessed on a
# thread, to set the initial value of that data.
self.my_data = 0
```
"""
__thread_local__: ClassVar[tuple[str, ...]]
_storage: object
def __init__(self):
self._storage = threading.local()
def __getattribute__(self, name: str) -> Any:
if name in ("__thread_local__", "_storage"):
return object.__getattribute__(self, name)
if name in self.__thread_local__:
if not self._thread_initialized:
self._storage.__initialized__ = True
self.__thread_init__()
return getattr(self._storage, name)
return object.__getattribute__(self, name)
def __delattr__(self, name: str) -> None:
if name in self.__thread_local__:
delattr(self._storage, name)
return
object.__delattr__(self, name)
def __setattr__(self, name: str, value: Any) -> None:
if name in self.__thread_local__:
if not self._thread_initialized:
self._storage.__initialized__ = True
self.__thread_init__()
setattr(self._storage, name, value)
return
object.__setattr__(self, name, value)
@property
def _thread_initialized(self) -> bool:
"""Check if the thread has been initialized."""
return getattr(self._storage, "__initialized__", False)
def thread_close(self):
"""Close the thread-local storage."""
if self._thread_initialized:
delattr(self._storage, "__initialized__")
def __thread_init__(self):
"""Called when the data on a thread is accessed for the first time, to
set the initial value of that data."""
# Copy the data from the default values
for name in self.__thread_local__:
try:
default_value = object.__getattribute__(self, name)
except AttributeError:
pass
else:
setattr(self._storage, name, default_value)
class ThreadLocalSqlite(ThreadLocalStorage):
"""A class that allows to store data in a thread-local storage."""
database: str
is_uri: bool
__thread_local__ = ("connection", "cursor")
connection: sqlite3.Connection
cursor: sqlite3.Cursor
def __init__(self, database: str, is_uri: bool = False):
super().__init__()
self.database = database
self.is_uri = is_uri
def __thread_init__(self):
"""Initialize the connection and cursor."""
self.connection = sqlite3.connect(self.database, uri=self.is_uri)
self.cursor = self.connection.cursor()
self.connection.execute("PRAGMA busy_timeout = 5000;")
def select_one(self, query: str, params: tuple[Any, ...] = ()):
"""Select one row from the database."""
self.cursor.execute(query, params)
return self.cursor.fetchone()
def select_all(self, query: str, params: tuple[Any, ...] = ()):
"""Select all rows from the database."""
self.cursor.execute(query, params)
return self.cursor.fetchall()
def thread_close(self):
"""Close the connection and cursor."""
if self._thread_initialized:
self.cursor.close()
self.connection.close()
super().thread_close()
def main():
"""Test the ThreadLocalSqlite class."""
import concurrent.futures
sqlite = ThreadLocalSqlite("tmp.sqlite")
sqlite.cursor.execute("CREATE TABLE IF NOT EXISTS test (id INTEGER PRIMARY KEY, name TEXT)")
sqlite.cursor.execute("INSERT INTO test (name) VALUES (?)", ("test",))
sqlite.connection.commit()
def _test_thread_local(sqlite_thread_local: ThreadLocalSqlite):
time.sleep(random.random())
print(sqlite_thread_local.select_all("SELECT * FROM test"))
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = []
for _ in range(20):
futures.append(executor.submit(_test_thread_local, sqlite))
for future in concurrent.futures.as_completed(futures):
future.result()
os.remove("tmp.sqlite")
if __name__ == "__main__":
main()
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import functools
import os
import weakref
from dataclasses import dataclass
from typing import Callable
def _cleanup(hooks, key, wr):
hooks.pop(key)
class WeakCallbacks:
"""
A class that manages weak references to callback functions.
"""
# A dictionary of weak (or strong) references to functions.
_hooks: dict[int, Callable[[], Callable[..., None] | None]]
def __init__(self):
"""
Initialize the registry.
"""
self._hooks: dict[int, Callable[[], Callable[..., None] | None]] = {}
def add_hook(self, callable: Callable[..., None], make_persistent: bool = False) -> None:
"""
Add a callback to the registry.
Args:
callable: The function to run before the fork of a worker process.
make_persistent: If True, the function will be stored as a strong reference, otherwise a weak reference is used.
"""
if make_persistent:
# Not a weakref, but always return the callable.
self._hooks[id(callable)] = lambda: callable
elif getattr(callable, "__self__", None):
# Add a method reference to the hooks
key = id(callable.__self__)
self._hooks[key] = weakref.WeakMethod(
callable, functools.partial(_cleanup, self._hooks, key)
)
else:
# Add a function reference to the hooks
key = id(callable)
self._hooks[key] = weakref.ref(callable, functools.partial(_cleanup, self._hooks, key))
def run(self, *args, **kwargs) -> None:
"""
Run all the callbacks in the registry, passing the given arguments.
"""
for hook in self._hooks.values():
ref = hook()
if ref is not None:
ref(*args, **kwargs)
_after_in_child_fork_hooks = WeakCallbacks()
_after_in_parent_fork_hooks = WeakCallbacks()
_before_fork_hooks = WeakCallbacks()
def before_fork_hook(callable: Callable[[], None], make_persistent: bool = False):
"""
Register a function to be run before a worker process forks.
The function must be persistent (i.e. not a lambda) or an instance method.
Args:
callable: The function to run before the fork of a worker process.
make_persistent: If True, the function will be stored as a strong reference, otherwise a weak reference is used.
"""
_before_fork_hooks.add_hook(callable, make_persistent)
def after_in_parent_fork_hook(callable: Callable[[], None], make_persistent: bool = False):
"""
Register a function to be run in the parent process after a worker process forks.
The function must be persistent (i.e. not a lambda) or an instance method.
Args:
callable: The function to run after the fork of a worker process.
make_persistent: If True, the function will be stored as a strong reference, otherwise a weak reference is used.
"""
_after_in_parent_fork_hooks.add_hook(callable, make_persistent)
def after_in_child_fork_hook(callable: Callable[[], None], make_persistent: bool = False):
"""
Register a function to be run in the child process after a worker process forks.
The function must be persistent (i.e. not a lambda) or an instance method.
Args:
callable: The function to run after the fork of a worker process.
make_persistent: If True, the function will be stored as a strong reference, otherwise a weak reference is used.
"""
_after_in_child_fork_hooks.add_hook(callable, make_persistent)
class ForkMixin:
"""
A mixin that registers its hook methods to run before and after a worker process forks.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.__post_init__()
def __post_init__(self):
if getattr(self.__before_fork__, "__func__", None) is not ForkMixin.__before_fork__:
before_fork_hook(self.__before_fork__)
if (
getattr(self.__after_in_child_fork__, "__func__", None)
is not ForkMixin.__after_in_child_fork__
):
after_in_child_fork_hook(self.__after_in_child_fork__)
if (
getattr(self.__after_in_parent_fork__, "__func__", None)
is not ForkMixin.__after_in_parent_fork__
):
after_in_parent_fork_hook(self.__after_in_parent_fork__)
def __after_in_child_fork__(self):
"""
A method that runs after the fork in the child process.
"""
pass
def __after_in_parent_fork__(self):
"""
A method that runs after the fork in the parent process.
"""
pass
def __before_fork__(self):
"""
A method that runs before the fork of a worker process.
"""
pass
@dataclass
class DataclassForkMixin:
"""
A mixin that registers its hook methods to run before and after a worker process forks.
"""
def __post_init__(self):
if (
getattr(self.__before_fork__, "__func__", None)
is not DataclassForkMixin.__before_fork__
):
before_fork_hook(self.__before_fork__)
if (
getattr(self.__after_in_child_fork__, "__func__", None)
is not DataclassForkMixin.__after_in_child_fork__
):
after_in_child_fork_hook(self.__after_in_child_fork__)
if (
getattr(self.__after_in_parent_fork__, "__func__", None)
is not DataclassForkMixin.__after_in_parent_fork__
):
after_in_parent_fork_hook(self.__after_in_parent_fork__)
def __after_in_child_fork__(self):
"""
A method that runs after the fork in the child process.
"""
pass
def __after_in_parent_fork__(self):
"""
A method that runs after the fork in the parent process.
"""
pass
def __before_fork__(self):
"""
A method that runs before the fork of a worker process.
"""
pass
os.register_at_fork(
before=_before_fork_hooks.run,
after_in_child=_after_in_child_fork_hooks.run,
after_in_parent=_after_in_parent_fork_hooks.run,
)
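# Illustrative only, a minimal sketch of using ForkMixin; `_open_connection` is
# a hypothetical helper, not part of this module:
#
#   class ConnectionHolder(ForkMixin):
#       def __init__(self):
#           super().__init__()  # registers the overridden hook below
#           self.conn = _open_connection()
#
#       def __after_in_child_fork__(self):
#           # File handles/connections must not be shared across fork.
#           self.conn = _open_connection()
#
# The hooks are held as weak references (unless make_persistent is set), so the
# instance must be kept alive by the caller for the hook to fire.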
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Optional, TypeVar
from megatron.energon.cache import CachePool
from megatron.energon.errors import warn_deprecated
from megatron.energon.flavors import SavableDataset
from megatron.energon.savable_loader import BasicDataLoader, SavableDataLoader
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.gc_dataset import GC_DEFAULT_EVERY_N_ITER
T = TypeVar("T")
def get_savable_loader(
dataset: SavableDataset[T],
*,
worker_config: Optional[WorkerConfig] = None,
checkpoint_every_sec: float = 60,
checkpoint_every_min_n_samples: Optional[int] = None,
n_checkpoints: Optional[int] = None,
gc_collect_every_n_steps: int = GC_DEFAULT_EVERY_N_ITER,
prefetch_factor: int = 2,
cache_pool: Optional[CachePool] = None,
watchdog_timeout_seconds: Optional[float] = 60,
watchdog_initial_timeout_seconds: Optional[float] = None,
fail_on_timeout: bool = False,
) -> SavableDataLoader[T]:
"""
Get a dataloader for the given dataset.
Args:
dataset: The dataset to create a loader for.
worker_config: Deprecated. Please pass this to the dataset instead.
checkpoint_every_sec: This is the time in seconds after which an internal checkpoint is
saved. Restoring a checkpoint may take about the same duration, and checkpointing adds
overhead while reading data from the dataset, so this should be chosen accordingly.
Only applies if using workers.
checkpoint_every_min_n_samples: Overwrites the minimum number of samples between
checkpoints. Defaults to `number of workers * 2`. Only applies if using workers.
n_checkpoints: The number of internal checkpoints to keep. Only applies if using workers.
If None, computes a suitable value.
cache_pool: If set, the cache pool to use for the dataset.
watchdog_timeout_seconds: The timeout in seconds. If None, the watchdog is disabled.
watchdog_initial_timeout_seconds: The initial timeout in seconds. If None, the timeout is the same as watchdog_timeout_seconds.
fail_on_timeout: If True, stops the whole process upon timeout, after printing a stack trace.
Returns:
The instantiated :class:`megatron.energon.SavableDataLoader`, yielding batches from the dataset,
allowing to save the state of the dataset.
"""
if worker_config is not None:
if worker_config != dataset.worker_config:
raise AssertionError(
"The worker_config passed to get_savable_loader() does not match the one of the dataset. "
"Also note, it is deprecated to pass one to get_savable_loader() and it will have no effect."
)
else:
warn_deprecated(
"Passing a worker_config to get_savable_loader() is deprecated and will have no effect."
)
return SavableDataLoader(
dataset,
checkpoint_every_sec=checkpoint_every_sec,
checkpoint_every_min_n_samples=checkpoint_every_min_n_samples,
n_checkpoints=n_checkpoints,
gc_collect_every_n_steps=gc_collect_every_n_steps,
prefetch_factor=prefetch_factor,
cache_pool=cache_pool,
watchdog_timeout_seconds=watchdog_timeout_seconds,
watchdog_initial_timeout_seconds=watchdog_initial_timeout_seconds,
fail_on_timeout=fail_on_timeout,
)
def get_loader(
dataset: SavableDataset[T],
*,
worker_config: Optional[WorkerConfig] = None,
prefetch_factor: int = 2,
cache_pool: Optional[CachePool] = None,
watchdog_timeout_seconds: Optional[float] = 60,
watchdog_initial_timeout_seconds: Optional[float] = None,
fail_on_timeout: bool = False,
) -> BasicDataLoader[T]:
"""
Get a dataloader for the given dataset.
Args:
dataset: The dataset to create a loader for.
worker_config: Deprecated. Please pass this to the dataset instead.
cache_pool: If set, the cache pool to use for the dataset.
watchdog_timeout_seconds: The timeout in seconds. If None, the watchdog is disabled.
watchdog_initial_timeout_seconds: The initial timeout in seconds. If None, the timeout is the same as watchdog_timeout_seconds.
fail_on_timeout: If True, stops the whole process upon timeout, after printing a stack trace.
Returns:
The instantiated :class:`torch.utils.data.DataLoader`, yielding batches from the dataset.
"""
if worker_config is not None:
if worker_config != dataset.worker_config:
raise AssertionError(
"The worker_config passed to get_loader() does not match the one of the dataset. "
"Also note, it is deprecated to pass one to get_loader() and it will have no effect."
)
else:
warn_deprecated(
"Passing a worker_config to get_loader() is deprecated and will have no effect."
)
return BasicDataLoader(
dataset,
prefetch_factor=prefetch_factor,
cache_pool=cache_pool,
watchdog_timeout_seconds=watchdog_timeout_seconds,
watchdog_initial_timeout_seconds=watchdog_initial_timeout_seconds,
fail_on_timeout=fail_on_timeout,
)
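# Illustrative only, a minimal sketch assuming `ds` is any SavableDataset (e.g.
# built via megatron.energon's dataset factories):
#
#   loader = get_savable_loader(ds)   # checkpointable loader for training
#   for batch in loader:
#       ...
#
#   loader = get_loader(ds)           # plain loader, e.g. for val/test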
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import hashlib
from pathlib import Path
from filelock import FileLock
from megatron.energon.epathlib import EPath
LOCAL_COPY_TMP_DIR = Path("/tmp/energon_local_copy")
def ensure_local_copy(path: EPath) -> EPath:
"""If the path is not local, copy it to a temporary directory and return the
path to the local copy.
Args:
path: The path to the file to copy.
Returns:
The path to the local copy of the file or the original path if it is already local.
"""
assert path.is_file(), f"Path {path} is not a file"
if path.is_local():
return path
LOCAL_COPY_TMP_DIR.mkdir(parents=True, exist_ok=True)
digest = hashlib.sha256(str(path).encode()).hexdigest()
final_path = LOCAL_COPY_TMP_DIR / f"{digest}.bin"
lock_path = final_path.with_suffix(".lock")
tmp_path = final_path.with_suffix(".part")
# Block until lock is free
with FileLock(lock_path, timeout=60 * 5):
if final_path.exists(): # someone else already produced it
return EPath(final_path)
# We are the downloader
try:
path.copy(EPath(tmp_path))
tmp_path.rename(final_path)
finally:
tmp_path.unlink(missing_ok=True)
return EPath(final_path)
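# Illustrative only; the remote path scheme depends on the configured EPath
# backends:
#
#   local_index = ensure_local_copy(EPath("msc://profile/bucket/ds/.nv-meta/index.sqlite"))
#
# Concurrent callers block on the same file lock, so only one process downloads;
# the rest reuse the finished copy at /tmp/energon_local_copy/<sha256>.bin.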
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from megatron.energon.metadataset.dataset_loader import DatasetLoader
from megatron.energon.metadataset.loader import load_dataset, prepare_metadataset
from megatron.energon.metadataset.loader_interface import DatasetLoaderInterface
from megatron.energon.metadataset.metadataset import Metadataset
from megatron.energon.metadataset.metadataset_v2 import MetadatasetV2
__all__ = [
"DatasetLoader",
"load_dataset",
"DatasetLoaderInterface",
"Metadataset",
"MetadatasetV2",
"prepare_metadataset",
]
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Any, Dict, Literal, Optional, Union
from megatron.energon.dataset_config import get_dataset_from_config
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib import EPath
from megatron.energon.flavors import BaseCoreDatasetFactory
from megatron.energon.flavors.webdataset.structs import DatasetSubset
from megatron.energon.metadataset.loader_interface import (
DatasetBlendMode,
DatasetLoaderInterface,
LoadedDataset,
LoadedDatasetList,
)
from megatron.energon.worker import WorkerConfig
@edataclass
class DatasetLoader(DatasetLoaderInterface):
"""Loads a dataset from a path."""
path: Union[str, EPath]
split_part: Optional[str] = None
subflavors: Optional[Dict[str, Any]] = None
shuffle_over_epochs_multiplier: Optional[int] = 1
dataset_config: Optional[str] = None
split_config: Optional[str] = None
def post_initialize(self, mds_path: Optional[EPath] = None):
pass
def get_dataset(
self,
*,
training: bool,
split_part: Optional[str] = None,
worker_config: WorkerConfig,
subflavors: Optional[Dict[str, Any]] = None,
shuffle_over_epochs: Optional[int] = 1,
split_config: Optional[str] = None,
dataset_config: Optional[str] = None,
subset: Optional[DatasetSubset] = None,
**kwargs,
) -> BaseCoreDatasetFactory:
"""
Args:
training: If true, apply training randomization.
split_part: Default split part to use.
worker_config: Worker configuration.
shuffle_buffer_size: Size of the sample shuffle buffer (before task encoding).
subflavors: Subflavors to use, might be overridden by inner datasets.
shuffle_over_epochs: Shuffle the dataset over this many epochs.
subset: If specified, the inner dataset(s) will be subsetted.
**kwargs: Additional arguments to the dataset constructor.
Returns:
The loaded dataset
"""
if self.split_part is not None:
split_part = self.split_part
if split_part is None:
raise ValueError("Missing split part")
if self.subflavors is not None:
subflavors = {**self.subflavors, **(subflavors or {})}
if split_config is None:
split_config = self.split_config
if dataset_config is None:
dataset_config = self.dataset_config
return get_dataset_from_config(
self.path,
training=training,
split_part=split_part,
worker_config=worker_config,
subflavors=subflavors,
dataset_config=dataset_config,
split_config=split_config,
shuffle_over_epochs=shuffle_over_epochs,
subset=subset,
**kwargs,
)
def get_datasets(
self,
*,
training: bool,
split_part: Union[Literal["train", "val", "test"], str],
worker_config: WorkerConfig,
subflavors: Optional[Dict[str, Any]] = None,
shuffle_over_epochs_multiplier: Optional[int] = 1,
subset: Optional[DatasetSubset] = None,
**kwargs,
) -> LoadedDatasetList:
return LoadedDatasetList(
blend_mode=DatasetBlendMode.NONE,
datasets=[
LoadedDataset(
dataset=self.get_dataset(
training=training,
split_part=split_part,
worker_config=worker_config,
subflavors=subflavors,
shuffle_over_epochs=shuffle_over_epochs_multiplier,
subset=subset,
**kwargs,
),
weight=None,
)
],
)
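# Illustrative only, a minimal sketch; `worker_config` is assumed to be a
# WorkerConfig constructed elsewhere:
#
#   loader = DatasetLoader(path="/data/my_dataset", split_part="train")
#   loader.post_initialize()
#   factory = loader.get_dataset(training=True, worker_config=worker_config)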
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import hashlib
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib.epath import EPath
from megatron.energon.flavors import (
BaseCoreDatasetFactory,
BaseWebdatasetFactory,
JoinedWebdatasetFactory,
Sample,
)
from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME
from megatron.energon.flavors.webdataset.indexing import JoinIndexWriter
from megatron.energon.flavors.webdataset.metadata import WebdatasetMeta
from megatron.energon.flavors.webdataset.structs import DatasetSubset
from megatron.energon.metadataset.dataset_loader import DatasetLoader
from megatron.energon.metadataset.loader_interface import (
DatasetBlendMode,
DatasetLoaderInterface,
LoadedDataset,
LoadedDatasetList,
)
from megatron.energon.worker import WorkerConfig
def filter_samples_with_excludes(
conn,
db_alias: str,
meta: "JoinedDatasetMetaInfo",
):
"""
Filter the samples in the database with the given excludes.
"""
filtered_name = f"{db_alias}_filtered"
conn.execute(f"DROP VIEW IF EXISTS {filtered_name}")
if not meta.excludes:
# Nothing to exclude, just use the original table
conn.execute(f"CREATE TEMP VIEW {filtered_name} AS SELECT * FROM {db_alias}.samples")
return
# Split the excludes into shard-level excludes and sample-level excludes
excluded_shard_ids = []
excluded_sample_keys = []
for exclude in meta.excludes:
if exclude in meta.shard_name_to_info_idx:
excluded_shard_ids.append(meta.shard_name_to_info_idx[exclude])
else:
# Find the shard name for the sample key
# Trivial split by .tar/
if ".tar/" in exclude:
tarname, sample_key = exclude.split(".tar/", 1)
shard_idx = meta.shard_name_to_info_idx[tarname + ".tar"]
excluded_sample_keys.append((shard_idx, sample_key))
elif exclude.endswith(".tar"):
# This is a shard and was probably already excluded outside this function
pass
else:
raise ValueError(
f"Invalid exclusion: Cannot split exclude {exclude} into shard and sample key"
)
# Create a temporary table for the shard excludes
# The key will be integers according to the tar_file_id column of the samples table
conn.execute(f"DROP TABLE IF EXISTS temp_shard_excludes_{db_alias}")
conn.execute(
f"""
CREATE TEMP TABLE temp_shard_excludes_{db_alias} (
exclude_key INTEGER PRIMARY KEY
)
"""
)
for shard_id in excluded_shard_ids:
conn.execute(
f"INSERT INTO temp_shard_excludes_{db_alias}(exclude_key) values (?)", (shard_id,)
)
# Create a temporary table for the sample excludes
conn.execute(f"DROP TABLE IF EXISTS temp_sample_excludes_{db_alias}")
conn.execute(
f"""
CREATE TEMP TABLE temp_sample_excludes_{db_alias} (
shard_idx INTEGER,
exclude_key TEXT,
PRIMARY KEY (shard_idx, exclude_key)
)
"""
)
conn.executemany(
f"INSERT INTO temp_sample_excludes_{db_alias}(shard_idx, exclude_key) values (?, ?)",
[(shard_idx, sample_key) for shard_idx, sample_key in excluded_sample_keys],
)
# Create view for filtered samples
conn.execute(
f"""
CREATE TEMP VIEW {filtered_name} AS
SELECT *
FROM {db_alias}.samples s
WHERE s.tar_file_id NOT IN (
SELECT exclude_key
FROM temp_shard_excludes_{db_alias}
)
AND NOT EXISTS (
SELECT 1
FROM temp_sample_excludes_{db_alias} e
WHERE e.shard_idx = s.tar_file_id
AND e.exclude_key = s.sample_key
)
"""
)
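# Illustrative only: the exclude entries follow the split config convention
# ("<shard name>" or "<shard name>/<sample key>"), e.g.:
#
#   excludes = [
#       "shard_000.tar",              # excludes the whole shard
#       "shard_001.tar/sample_0042",  # excludes a single sample within a shard
#   ]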
def join_multiple_indices(
meta_infos: List["JoinedDatasetMetaInfo"],
output_join_index_path: EPath,
):
"""
Joins the 'samples' table of one primary_db with multiple secondary_dbs
by 'sample_key'. For each secondary DB, we select three columns:
- tar_file_id
- byte_offset
- byte_size
The result is streamed out row by row and written to the join index.
Note that the order of samples is determined by the shard order of the primary DB's split config.
Args:
meta_infos: List of meta infos for all datasets.
output_join_index_path: Path to the output join index.
"""
primary = meta_infos[0]
secondaries = meta_infos[1:]
assert primary.nonmatch == "error", (
"Primary join dataset must have nonmatch set 'error' (default)"
)
import sqlite3
# 1. Connect to the primary DB in 'main'
conn = sqlite3.connect(f"file:{primary.db_path!s}?mode=ro", uri=True)
# For safety: the connection is read-only (mode=ro) and uses a generous busy timeout
conn.execute("PRAGMA busy_timeout = 5000;")
conn.execute("PRAGMA journal_mode = WAL;")
# 2. Attach each secondary DB under a unique alias, e.g. db1, db2, ...
secondary_aliases = []
for i, sec_mi in enumerate(secondaries, start=1):
alias = f"db{i}"
secondary_aliases.append(alias)
conn.execute(f"ATTACH DATABASE ? AS {alias}", (f"file:{sec_mi.db_path}?mode=ro",))
# Filter the primary and each secondary DB for excluded samples by creating
# a new VIEW for each
for alias, mi in zip(["main"] + secondary_aliases, meta_infos):
filter_samples_with_excludes(conn, alias, mi)
# Check each primary and secondary DB for duplicate sample_key values
for alias, mi in zip(["main"] + secondary_aliases, meta_infos):
duplicates = conn.execute(
f"""
SELECT sample_key, COUNT(*) AS c
FROM {alias}_filtered
GROUP BY sample_key
HAVING c > 1
LIMIT 5
"""
).fetchall()
if duplicates:
raise ValueError(
f"Can't join. Found duplicate sample keys in {mi.db_path}: {duplicates}"
)
# Create a temporary table to order the shards as in the current split config
conn.execute("DROP TABLE IF EXISTS primary_order")
conn.execute(
"""
CREATE TEMP TABLE primary_order (
tar_file_id INTEGER PRIMARY KEY,
split_index INTEGER
)
"""
)
conn.executemany(
"INSERT INTO primary_order(tar_file_id, split_index) values (?, ?)",
((n, i) for i, n in enumerate(primary.split_part_oder)),
)
# Map from tar_file_id to shard idx in the split part
tar_files_id_mapping = {}
for alias, mi in zip(["main"] + secondary_aliases, meta_infos):
tar_files_id_mapping[alias] = {
tar_file_id: shard_idx for shard_idx, tar_file_id in enumerate(mi.split_part_oder)
}
# These are the columns we want to select in the main SQL query
select_cols = [
"main_filtered.tar_file_id AS main_tar_file_id",
"main_filtered.byte_offset AS main_byte_offset",
"main_filtered.byte_size AS main_byte_size",
]
for i, alias in enumerate(secondary_aliases, start=1):
select_cols.append(f"{alias}_filtered.tar_file_id AS tar_file_id_{i}")
select_cols.append(f"{alias}_filtered.byte_offset AS byte_offset_{i}")
select_cols.append(f"{alias}_filtered.byte_size AS byte_size_{i}")
# Build the LEFT JOIN or INNER JOIN clauses
join_clauses = ""
for alias, mi in zip(secondary_aliases, secondaries):
if mi.nonmatch == "skip":
join_type = "INNER JOIN"
else:
join_type = "LEFT JOIN"
join_clauses += f" {join_type} {alias}_filtered ON main_filtered.sample_key = {alias}_filtered.sample_key"
# Construct the full SQL query
# We select three columns for the primary and each secondary DB
# Those are (tar_file_id, byte_offset, and byte_size)
# We join the secondary DBs to the primary DB using a LEFT JOIN, i.e.
# we keep all rows from the primary DB and add columns from the secondary DBs if available
# Finally, we also join the temporary shard order table to order the shards as in the split config.
# This join is done using an INNER JOIN, i.e. we only keep rows that have a matching shard index in the primary dataset,
# so we'll not include shards that come from other split parts
sql = f"""
SELECT
{", ".join(select_cols)}
FROM main_filtered
{join_clauses}
INNER JOIN primary_order o
ON main_tar_file_id = o.tar_file_id
ORDER BY o.split_index
"""
# 3. Execute the query; this returns a cursor we can iterate over row by row
cursor = conn.execute(sql)
all_db_aliases = ["main"] + secondary_aliases
# 4. Write the results to a binary file join index file row by row
with JoinIndexWriter(output_join_index_path) as join_index_writer:
# Stream over the cursor and write each row to the join index
num_rows = 0
num_missing = [0] * len(meta_infos)
for row in cursor:
# 'row' is a tuple of columns in the order of select_cols
join_tuples = []
for i, (alias, meta_info) in enumerate(zip(all_db_aliases, meta_infos)):
tar_file_id = row[3 * i]
if tar_file_id is None:
# This column is missing in this secondary dataset
# How we handle this case depends on the nonmatch setting
if meta_info.nonmatch == "none":
# The user accepts missing samples, we'll just add a dummy entry
join_tuples.append((-1, -1, -1))
num_missing[i] += 1
elif meta_info.nonmatch == "skip":
# The user wants to skip rows with missing samples.
# Skipping rows is already handled by the INNER JOIN above, so
# this case should not happen.
raise AssertionError(
f"Join has encountered a missing sample: Sample key {row[0]} missing from "
f"{meta_info.db_path}, although nonmatch_skip is set"
)
else:
# The user wants to raise an error on missing samples
raise ValueError(
f"Join has encountered a missing sample: Sample key {row[0]} missing from "
f"{meta_info.db_path}, although neither nonmatch_none nor nonmatch_skip are set"
)
else:
shard_idx = tar_files_id_mapping[alias][tar_file_id]
byte_offset = row[3 * i + 1]
byte_size = row[3 * i + 2]
join_tuples.append((shard_idx, byte_offset, byte_size))
else:
# for-else: runs once the loop has completed without break. The joined row now
# holds one (shard_idx, byte_offset, byte_size) tuple per dataset.
join_index_writer.append(*join_tuples)
num_rows += 1
any_skip = any(mi.nonmatch == "skip" for mi in meta_infos)
num_samples = conn.execute(
"SELECT COUNT(*) FROM main_filtered INNER JOIN primary_order o ON main_filtered.tar_file_id = o.tar_file_id"
).fetchone()[0]
if not any_skip:
# If no dataset has skipping active, we can check that the number of rows matches the number of samples in the primary DB
assert num_rows == num_samples, (
f"Number of rows in join index ({num_rows}) does not match number of samples in primary DB ({num_samples})"
)
print(f"Joined all {num_rows} samples")
else:
print(
f"Joined {num_rows}/{num_samples} samples, skipped {num_samples - num_rows} samples due to join"
)
if any(num_missing):
print(f"Non-matching samples filled with None for each dataset: {num_missing}")
conn.close()
@edataclass
class JoinedDatasetInfo:
"""Internal for passing the joined datasets."""
dataset: DatasetLoader
nonmatch: Literal["skip", "none", "error"]
@edataclass
class JoinedDatasetMetaInfo:
"""Internal for passing the joined datasets."""
db_path: EPath
uuid: str
excludes: List[str]
shard_name_to_info_idx: Dict[str, int]
split_part_oder: List[int]
nonmatch: Literal["skip", "none", "error"]
@edataclass
class JoinDatasetLoader(DatasetLoaderInterface):
"""Loads a joined dataset from a path."""
datasets: Union[List[JoinedDatasetInfo], Dict[str, JoinedDatasetInfo]]
joiner: Union[Type[Sample], Callable[..., Sample]]
cache_path: Optional[EPath] = None
split_part: Optional[str] = None
split_config: Optional[str] = None
subflavors: Optional[Dict[str, Any]] = None
shuffle_over_epochs_multiplier: Optional[int] = 1
def _get_joined_meta(self, split_part: str) -> Tuple[EPath, List[JoinedDatasetMetaInfo]]:
"""
Collect the metadata for the joined dataset.
Returns:
The hashfile path, and a list of the meta infos.
"""
# Get list of joinable datasets
datasets = self.datasets
if isinstance(datasets, dict):
datasets = list(datasets.values())
meta_infos: List[JoinedDatasetMetaInfo] = []
for dataset in datasets:
print(f" - {dataset}")
uuid_path = EPath(dataset.dataset.path) / MAIN_FOLDER_NAME / "index.uuid"
try:
uuid = uuid_path.read_text()
except FileNotFoundError:
raise FileNotFoundError(
f"Missing uuid file in {uuid_path}. You need to prepare the dataset "
"(with a recent version of energon). If you have already prepared the "
"dataset, it should be sufficient to run prepare with --tar-index-only."
)
db_path = EPath(dataset.dataset.path) / MAIN_FOLDER_NAME / "index.sqlite"
# Precedence for split_part is:
# 1. Individual dataset split part (overrides the join dataset split part)
# 2. Join dataset split part
# 3. If none of the above is set, use the split part of the surrounding meta dataset
cur_split_part = dataset.dataset.split_part or self.split_part or split_part
assert cur_split_part is not None, "Missing split part"
wds_meta = WebdatasetMeta.from_config(
path=EPath(dataset.dataset.path),
split_part=cur_split_part,
split_config=dataset.dataset.split_config,
)
shard_name_to_info_idx = {name: i for i, name in enumerate(wds_meta.info_shard_files)}
# Given wds_meta.split_part_files, translate their order to info idx IDs
split_part_oder = [shard_name_to_info_idx[name] for name in wds_meta.split_part_files]
meta_infos.append(
JoinedDatasetMetaInfo(
db_path=db_path,
uuid=uuid,
excludes=list(wds_meta.sample_excludes),
shard_name_to_info_idx=shard_name_to_info_idx,
split_part_oder=split_part_oder,
nonmatch=dataset.nonmatch,
)
)
        # Combine the per-dataset uuids, excludes and nonmatch settings into a single sha256 hash
        hasher = hashlib.sha256()
        for meta_info in meta_infos:
            hasher.update(b"\0uuid=")
            hasher.update(meta_info.uuid.encode())
            hasher.update(b"\0excludes=")
            for exclude in meta_info.excludes:
                hasher.update(exclude.encode())
                hasher.update(b"\0")
            hasher.update(f"\0nonmatch={meta_info.nonmatch}\0".encode())
        assert self.cache_path is not None
        return self.cache_path / f"join_index_{hasher.hexdigest()}.bin", meta_infos
def post_initialize(self, mds_path: Optional[EPath] = None):
assert mds_path is not None
self.cache_path = mds_path.parent / f"{mds_path.name}.cache"
def prepare(self, split_part: Optional[str] = None) -> Sequence[EPath]:
assert self.cache_path is not None
assert split_part is not None
join_index_path, meta_infos = self._get_joined_meta(split_part)
if join_index_path.is_file():
print(f"Joined dataset already prepared at {join_index_path} and up-to-date")
return (join_index_path,)
print(f"Preparing joined dataset in {join_index_path}")
join_index_path.parent.mkdir(parents=True, exist_ok=True)
join_multiple_indices(
meta_infos=meta_infos,
output_join_index_path=join_index_path,
)
return (join_index_path,)
def get_dataset(
self,
*,
training: bool,
split_part: Optional[str] = None,
worker_config: WorkerConfig,
subflavors: Optional[Dict[str, Any]] = None,
shuffle_over_epochs: Optional[int] = 1,
split_config: Optional[str] = None,
subset: Optional[DatasetSubset] = None,
**kwargs,
) -> BaseCoreDatasetFactory:
"""
Args:
training: If true, apply training randomization.
split_part: Default split part to use.
worker_config: Worker configuration.
            subflavors: Subflavors to use, might be overridden by inner datasets.
            shuffle_over_epochs: Shuffle the dataset over this many epochs.
            split_config: Name of the split configuration to use.
            subset: If specified, the inner dataset(s) will be subsetted.
**kwargs: Additional arguments to the dataset constructor.
Returns:
The loaded dataset
"""
if self.split_config is not None:
split_config = self.split_config
if self.split_part is not None:
split_part = self.split_part
if split_part is None:
raise ValueError("Missing split part")
if self.subflavors is not None:
subflavors = {**self.subflavors, **(subflavors or {})}
join_index_path, _ = self._get_joined_meta(split_part)
if isinstance(self.datasets, list):
inner_datasets = [
dataset.dataset.get_dataset(
training=training,
split_part=split_part,
worker_config=worker_config,
subflavors=subflavors,
shuffle_over_epochs=shuffle_over_epochs,
split_config=split_config,
**kwargs,
)
for dataset in self.datasets
]
assert all(isinstance(d, BaseWebdatasetFactory) for d in inner_datasets), (
"Can only merge webdatasets efficiently"
)
elif isinstance(self.datasets, dict):
inner_datasets = {
key: dataset.dataset.get_dataset(
training=training,
split_part=split_part,
worker_config=worker_config,
subflavors=subflavors,
shuffle_over_epochs=shuffle_over_epochs,
split_config=split_config,
**kwargs,
)
for key, dataset in self.datasets.items()
}
assert all(isinstance(d, BaseWebdatasetFactory) for d in inner_datasets.values()), (
"Can only merge webdatasets efficiently"
)
else:
raise ValueError("Invalid join type")
# Remove decoder from kwargs, it is already handled by the inner datasets
kwargs.pop("decoder", None)
return JoinedWebdatasetFactory(
inner_datasets=inner_datasets,
training=training,
worker_config=worker_config,
shuffle_over_epochs=shuffle_over_epochs,
join_index=join_index_path,
joiner=self.joiner,
subset=subset,
**kwargs,
)
def get_datasets(
self,
*,
training: bool,
split_part: Union[Literal["train", "val", "test"], str],
worker_config: WorkerConfig,
subflavors: Optional[Dict[str, Any]] = None,
shuffle_over_epochs_multiplier: Optional[int] = 1,
subset: Optional[DatasetSubset] = None,
**kwargs,
) -> LoadedDatasetList:
return LoadedDatasetList(
blend_mode=DatasetBlendMode.NONE,
datasets=[
LoadedDataset(
dataset=self.get_dataset(
training=training,
split_part=split_part,
worker_config=worker_config,
subflavors=subflavors,
shuffle_over_epochs=shuffle_over_epochs_multiplier,
subset=subset,
**kwargs,
),
weight=None,
)
],
)
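# A minimal, self-contained sketch of the cache-key derivation in
# _get_joined_meta above; the function name and the uuid/excludes/nonmatch
# values are illustrative stand-ins for what the prepared datasets provide.
def _example_join_cache_key() -> str:
    import hashlib
    hasher = hashlib.sha256()
    for uuid, excludes, nonmatch in [("0000-1111", ["bad_sample"], "error")]:
        hasher.update(b"\0uuid=")
        hasher.update(uuid.encode())
        hasher.update(b"\0excludes=")
        for exclude in excludes:
            hasher.update(exclude.encode())
            hasher.update(b"\0")
        hasher.update(f"\0nonmatch={nonmatch}\0".encode())
    # The resulting name is unique to the joined datasets' contents and settings.
    return f"join_index_{hasher.hexdigest()}.bin"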
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from pathlib import Path
from typing import Any, Union
from megatron.energon.dataset_config import load_config
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.webdataset.metadata import EnergonDatasetType, get_dataset_type
from megatron.energon.metadataset.dataset_loader import DatasetLoader
from megatron.energon.metadataset.loader_interface import DatasetLoaderInterface
from megatron.energon.metadataset.metadataset import Metadataset
from megatron.energon.typed_converter import JsonParser
def load_dataset(
    path: Union[str, EPath, Path, dict],
**kwargs,
) -> DatasetLoaderInterface:
"""Loads a (meta)dataset."""
if isinstance(path, dict):
mds = load_config(
path,
default_type=Metadataset,
default_kwargs=dict(path=EPath("/dict"), **kwargs),
)
return mds
path = EPath(path)
ds_type = get_dataset_type(path)
if ds_type == EnergonDatasetType.METADATASET:
mds = load_config(
path,
default_type=Metadataset,
default_kwargs=dict(path=path, **kwargs),
)
mds.post_initialize()
return mds
elif ds_type in (EnergonDatasetType.WEBDATASET, EnergonDatasetType.JSONL):
ds = DatasetLoader(path=path, **kwargs)
ds.post_initialize()
return ds
else:
raise ValueError(f"Invalid dataset at {path}")
class MockJsonParser(JsonParser):
"""Json Parser, which translates unknown objects to a mock class."""
def _resolve_object(
self,
module_name: str,
object_name: str,
cls: type,
is_type: bool,
is_callable: bool,
is_instantiating_class: bool,
is_calling_function: bool,
) -> Any:
try:
return super()._resolve_object(
module_name,
object_name,
cls,
is_type,
is_callable,
is_instantiating_class,
is_calling_function,
)
except ModuleNotFoundError:
class MockObject(cls):
def __init__(self, *_, **__):
pass
            if is_type or is_instantiating_class or is_callable or is_calling_function:
                return MockObject
def prepare_metadataset(path: EPath):
from megatron.energon.dataset_config import load_config
from megatron.energon.metadataset.metadataset import Metadataset
meta_ds = load_config(
path,
default_type=Metadataset,
default_kwargs=dict(path=path),
parser=MockJsonParser(strict=True),
)
meta_ds.post_initialize()
meta_ds.prepare()
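# Hedged usage sketch (hypothetical path): preparing all splits of a
# metadataset programmatically, roughly what the `energon prepare` entry point
# presumably does:
#   prepare_metadataset(EPath("/data/my_metadataset.yaml"))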
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Sequence, Union
from megatron.energon.cache import FileStore
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.base_dataset import BaseCoreDatasetFactory
from megatron.energon.flavors.webdataset.structs import DatasetSubset
from megatron.energon.worker import WorkerConfig
class DatasetBlendMode(Enum):
"""Determines how the the datasets are to be blended. Either by using the associated number as
the weight for sampling from that dataset, or alternatively by using the number as the number
of repetitions for samples in that dataset in one epoch (effectively, that corresponds to the
weight for samples)."""
NONE = "none"
DATASET_WEIGHT = "dataset_weight"
SAMPLE_REPETITIONS = "sample_repetitions"
@edataclass
class LoadedDataset:
dataset: BaseCoreDatasetFactory
weight: Union[float, int, None] = None
repetitions: Union[float, int, None] = None
aux: Optional[Dict[str, FileStore]] = None
@edataclass
class LoadedDatasetList:
datasets: List[LoadedDataset]
blend_mode: DatasetBlendMode = DatasetBlendMode.NONE
class DatasetLoaderInterface(ABC):
"""General interface for a dataset loader."""
@abstractmethod
def post_initialize(self, mds_path: Optional[EPath] = None):
"""Called to finally initialize the dataset."""
...
@abstractmethod
def get_datasets(
self,
*,
training: bool,
split_part: Union[Literal["train", "val", "test"], str],
worker_config: WorkerConfig,
subflavors: Optional[Dict[str, Any]] = None,
shuffle_over_epochs_multiplier: Optional[int] = 1,
subset: Optional[DatasetSubset] = None,
**kwargs,
) -> LoadedDatasetList:
"""
Calls :func:`megatron.energon.dataset_config.get_dataset_from_config` (loads the raw dataset)
for all innermost datasets and resolves their relative weights to absolute weights.
Args:
training: If true, apply training randomization.
split_part: Default split part to use.
worker_config: Worker configuration to use.
subflavors: Set the default subflavors for all datasets.
shuffle_over_epochs_multiplier: Multiply the inner datasets
`shuffle_over_epochs(_multiplier)` by this factor. E.g. if the inner dataset
has `shuffle_over_epochs_multiplier=2` and this function has
`shuffle_over_epochs_multiplier=3`, the inner dataset will be shuffled
over 6 epochs. Shuffling over `n` epochs guarantees that each sample is seen
exactly `n` times in `n` epochs of the inner dataset. Use -1 for shuffling over
an infinite number of epochs (effectively, this will draw shard slices with
replacement).
subset: If specified, the inner dataset(s) will be subsetted.
**kwargs: Additional arguments to the dataset constructor.
Returns:
The dataset blending mode and the instantiated core datasets with their weights/repetitions.
"""
...
def prepare(self, split_part: Optional[str] = None) -> Sequence[EPath]:
"""
Prepares the loader by creating caches and other necessary structures on disk.
Args:
split_part: Name of the split to load.
Returns:
List of paths to the cache paths. This is used for cleanup.
"""
return ()
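# A minimal sketch (not part of this module) of implementing the interface
# above: a hypothetical loader wrapping a single, already-constructed factory.
class _SingleDatasetLoader(DatasetLoaderInterface):
    def __init__(self, factory: BaseCoreDatasetFactory):
        self._factory = factory
    def post_initialize(self, mds_path: Optional[EPath] = None):
        # Nothing to resolve; a real loader would resolve relative paths here.
        pass
    def get_datasets(
        self,
        *,
        training: bool,
        split_part: Union[Literal["train", "val", "test"], str],
        worker_config: WorkerConfig,
        subflavors: Optional[Dict[str, Any]] = None,
        shuffle_over_epochs_multiplier: Optional[int] = 1,
        subset: Optional[DatasetSubset] = None,
        **kwargs,
    ) -> LoadedDatasetList:
        # A single dataset, no blending: blend_mode NONE, weight/repetitions unset.
        return LoadedDatasetList(
            blend_mode=DatasetBlendMode.NONE,
            datasets=[LoadedDataset(dataset=self._factory)],
        )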
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Any, Dict, List, Literal, Optional, Union
from megatron.energon.dataset_config import load_config
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib import EPath
from megatron.energon.errors import warn_deprecated
from megatron.energon.flavors.webdataset.metadata import check_dataset_info_present
from megatron.energon.flavors.webdataset.structs import DatasetSubset
from megatron.energon.metadataset.dataset_loader import DatasetLoader
from megatron.energon.metadataset.loader_interface import (
DatasetBlendMode,
DatasetLoaderInterface,
LoadedDatasetList,
)
from megatron.energon.worker import WorkerConfig
@edataclass
class DatasetReference:
path: Union[str, EPath]
split_part: Optional[str] = None
# Note: subflavor is only for legacy compatibility.
subflavor: Optional[str] = None
subflavors: Optional[Dict[str, Any]] = None
shuffle_over_epochs_multiplier: Optional[int] = 1
dataset_config: Optional[str] = None
split_config: Optional[str] = None
weight: float = 1.0
_dataset: Optional[DatasetLoaderInterface] = None
def __post_init__(self):
if self.subflavor is not None:
warn_deprecated(
"subflavor is deprecated, use subflavors instead. This will be removed in a future release."
)
if self.subflavors is None:
self.subflavors = {"__subflavor__": self.subflavor}
elif "__subflavor__" not in self.subflavors:
self.subflavors = {"__subflavor__": self.subflavor, **(self.subflavors or {})}
self.subflavor = None
def post_initialize(self, mds_path: Optional[EPath] = None):
assert mds_path is not None
if not isinstance(self.path, EPath):
self.path = mds_path.parent / self.path
if self.path.is_file():
assert self.dataset_config is None, "Must not set dataset_config"
assert self.split_config is None, "Must not set split_config"
self._dataset = load_config(
self.path,
default_type=Metadataset,
default_kwargs=dict(path=self.path),
)
self._dataset.post_initialize()
elif check_dataset_info_present(self.path):
self._dataset = DatasetLoader(
path=self.path,
split_config=self.split_config,
dataset_config=self.dataset_config,
)
self._dataset.post_initialize()
else:
raise FileNotFoundError(self.path)
def get_datasets(
self,
*,
training: bool,
split_part: Union[Literal["train", "val", "test"], str],
worker_config: WorkerConfig,
subflavors: Optional[Dict[str, Any]] = None,
shuffle_over_epochs_multiplier: Optional[int] = 1,
subset: Optional[DatasetSubset] = None,
**kwargs,
) -> LoadedDatasetList:
if self.subflavors is not None:
subflavors = {**self.subflavors, **(subflavors or {})}
assert self._dataset is not None
if shuffle_over_epochs_multiplier is None or self.shuffle_over_epochs_multiplier is None:
# If no shuffling is requested, this has override priority.
new_shuffle_over_epochs_multiplier = None
elif shuffle_over_epochs_multiplier == -1 or self.shuffle_over_epochs_multiplier == -1:
            # Next priority: -1, i.e. shuffling over an infinite number of epochs (drawing with replacement).
new_shuffle_over_epochs_multiplier = -1
else:
# Otherwise, multiply the shuffle over epochs multiplier.
new_shuffle_over_epochs_multiplier = (
shuffle_over_epochs_multiplier * self.shuffle_over_epochs_multiplier
)
return self._dataset.get_datasets(
training=training,
split_part=self.split_part or split_part,
worker_config=worker_config,
subflavors=subflavors,
shuffle_over_epochs_multiplier=new_shuffle_over_epochs_multiplier,
subset=subset,
**kwargs,
)
@edataclass
class MetadatasetBlender:
"""Internal blending of the dataset."""
datasets: List[DatasetReference]
def post_initialize(self, mds_path: Optional[EPath] = None):
assert mds_path is not None
for dataset in self.datasets:
dataset.post_initialize(mds_path)
def get_datasets(
self,
*,
training: bool,
split_part: Union[Literal["train", "val", "test"], str],
worker_config: WorkerConfig,
subflavors: Optional[Dict[str, Any]] = None,
shuffle_over_epochs_multiplier: Optional[int] = 1,
subset: Optional[DatasetSubset] = None,
**kwargs,
) -> LoadedDatasetList:
sum_weight = sum(dataset.weight for dataset in self.datasets)
datasets = []
for dataset in self.datasets:
inner_result = dataset.get_datasets(
training=training,
split_part=split_part,
worker_config=worker_config,
subflavors=subflavors,
shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier,
subset=subset,
**kwargs,
)
if inner_result.blend_mode not in (
DatasetBlendMode.NONE,
DatasetBlendMode.DATASET_WEIGHT,
):
raise ValueError(
"Can only blend datasets which are of the same blend mode. Cannot mix blend with blend_epochized."
)
for loaded_dataset in inner_result.datasets:
if inner_result.blend_mode == DatasetBlendMode.DATASET_WEIGHT:
assert isinstance(loaded_dataset.weight, float)
else:
assert loaded_dataset.weight is None
loaded_dataset.weight = 1.0
loaded_dataset.weight = loaded_dataset.weight * dataset.weight / sum_weight
datasets.append(loaded_dataset)
return LoadedDatasetList(
blend_mode=DatasetBlendMode.DATASET_WEIGHT,
datasets=datasets,
)
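# Hedged numeric example of the normalization above (weights illustrative): two
# references with weight=1.0 and weight=3.0 give sum_weight=4.0, so their
# samples receive absolute blend weights 1.0*1.0/4.0 = 0.25 and
# 1.0*3.0/4.0 = 0.75 respectively.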
class Metadataset(DatasetLoaderInterface):
"""Main entry for metadataset."""
_path: EPath
_splits: Dict[str, MetadatasetBlender]
def __init__(
self,
path: Union[EPath, str],
splits: Dict[str, MetadatasetBlender],
):
"""Create the metadataset"""
self._path = EPath(path)
self._splits = splits
def post_initialize(self, mds_path: Optional[EPath] = None):
assert mds_path is None
for split in self._splits.values():
split.post_initialize(self._path)
def get_datasets(
self,
*,
training: bool,
split_part: Union[Literal["train", "val", "test"], str],
worker_config: WorkerConfig,
subflavors: Optional[Dict[str, Any]] = None,
shuffle_over_epochs_multiplier: Optional[int] = 1,
subset: Optional[DatasetSubset] = None,
**kwargs,
) -> LoadedDatasetList:
return self._splits[split_part].get_datasets(
training=training,
split_part=split_part,
worker_config=worker_config,
subflavors=subflavors,
shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier,
subset=subset,
**kwargs,
)
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import re
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Type, Union
from megatron.energon.cache import FileStore
from megatron.energon.cache.file_store import (
SystemFileStore,
WebdatasetFileStore,
)
from megatron.energon.dataset_config import load_config
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib import EPath
from megatron.energon.flavors import Sample
from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME
from megatron.energon.flavors.webdataset.metadata import EnergonDatasetType, get_dataset_type
from megatron.energon.flavors.webdataset.structs import DatasetSubset
from megatron.energon.metadataset.dataset_loader import DatasetLoader
from megatron.energon.metadataset.join_dataset_loader import JoinDatasetLoader, JoinedDatasetInfo
from megatron.energon.metadataset.loader_interface import (
DatasetBlendMode,
DatasetLoaderInterface,
LoadedDatasetList,
)
from megatron.energon.metadataset.metadataset import Metadataset
from megatron.energon.worker import WorkerConfig
# Regex for any URL-like string (any protocol)
url_regex = re.compile(r"^(?P<protocol>[a-z][a-z0-9+.-]*)://(?P<path>.*)", re.IGNORECASE)
@edataclass
class AuxDatasetReference:
path: Union[str, EPath]
def post_initialize(self, mds_path: Optional[EPath] = None) -> None:
assert mds_path is not None
if not isinstance(self.path, EPath):
self.path = mds_path.parent / self.path
assert not self.path.is_file(), (
"Auxiliary datasets must not be metadataset, but direct dataset references"
)
assert (self.path / MAIN_FOLDER_NAME / "index.sqlite").is_file(), (
"Auxiliary datasets must be prepared Energon datasets. This one does not exist or is not prepared: "
+ str(self.path)
)
def get_file_store(self) -> FileStore:
assert isinstance(self.path, EPath), "Missing call to post_initialize"
return WebdatasetFileStore(self.path)
@edataclass
class AuxFilesystemReference:
fs_path: Union[str, EPath]
def post_initialize(self, mds_path: Optional[EPath] = None) -> None:
assert mds_path is not None
if not isinstance(self.fs_path, EPath):
self.fs_path = mds_path.parent / self.fs_path
def get_file_store(self) -> FileStore:
assert isinstance(self.fs_path, EPath), "Missing call to post_initialize"
return SystemFileStore(self.fs_path)
@edataclass
class Subset:
"""
A subset range to be applied to a dataset. The range is always consecutive.
The range is a tuple of two values, where the first value is the start of the subset and the second value is the end of the subset (end not included).
The range can either be an absolute range with sample indices, or a ratio of the dataset size.
Relative range example: [25%, 75%]. This would limit the subset to the middle 50% of the dataset.
Absolute range example: [100, 200]. This would limit the subset to the 100 samples with indices 100-199.
For absolute ranges, the end can be set to "end" to indicate the end of the dataset, for example [100, end].
    Since subsets can be specified at multiple levels of a hierarchy, for example in a blend,
    their effects are merged into a single subset.
    Note, however, that absolute ranges are only allowed for leaf datasets, while relative
    ranges can be applied at any level.
"""
range: tuple[str | int, str | int]
def as_dataset_subset(self) -> DatasetSubset:
"""Convert the subset with string values to a DatasetSubset object with `range` and `absolute_range`."""
start, end = self.range
def _conv(value: str | int) -> float | int | None:
if isinstance(value, int):
return value
else:
assert isinstance(value, str), "Range must be a string if it's not an integer"
if value.strip() == "end":
return None
assert value.endswith("%"), "Range must be a percentage"
percentage = float(value.removesuffix("%"))
assert 0 <= percentage <= 100, "Percentage must be between 0 and 100"
return percentage / 100.0
start = _conv(start)
end = _conv(end)
if isinstance(start, int):
assert isinstance(end, int) or end is None, (
"End must be an integer if start is an integer"
)
return DatasetSubset(absolute_range=(start, end), range=(0, 1))
else:
assert isinstance(start, float), "Range start must be a float if it's not an integer"
assert isinstance(end, float) or end is None, "End must be a float if start is a float"
assert 0 <= start <= 1, "Start must be between 0 and 1"
assert 0 <= end <= 1, "End must be between 0 and 1"
assert start <= end, "Start must be less than end"
return DatasetSubset(range=(start, end), absolute_range=None)
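    # Hedged examples of the conversion above (values illustrative):
    #   Subset(range=("25%", "75%")).as_dataset_subset()
    #     -> DatasetSubset(range=(0.25, 0.75), absolute_range=None)
    #   Subset(range=(100, "end")).as_dataset_subset()
    #     -> DatasetSubset(absolute_range=(100, None), range=(0, 1))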
def merge(self, parent_subset: DatasetSubset | None) -> DatasetSubset:
"""Merge this subset with a parent subset.
If the parent subset is None, return the subset.
If the parent subset is an absolute range, fail, because that's not allowed.
If the parent subset is a ratio, merge it with the subset.
        Merging a child absolute range with a parent relative range:
        In this case, both are kept in the DatasetSubset object and applied in "absolute first" order later.
Merging a child relative range with a parent relative range:
In this case, the relative parent range is applied to the child's relative range.
The absolute range is not affected.
For details on how this is applied, see `DatasetSubset.compute_subset`.
"""
        assert parent_subset is None or parent_subset.absolute_range is None, (
            f"Cannot merge absolute subset ranges. Absolute ranges are only allowed for a leaf dataset. {parent_subset.absolute_range=} {self.range=}"
        )
my_subset = self.as_dataset_subset()
if parent_subset is None or parent_subset.range is None:
return my_subset
# Assuming inner ratio: [0.25, 0.75] and outer ratio: [0, 0.5]
# Then the total ratio is supposed to be: [0.25 + 0*0.5, 0.25 + 0.5 * 0.5] = [0.25, 0.5]
total = my_subset.range[1] - my_subset.range[0]
return DatasetSubset(
range=(
my_subset.range[0] + parent_subset.range[0] * total,
my_subset.range[0] + parent_subset.range[1] * total,
),
absolute_range=my_subset.absolute_range,
)
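# Hedged numeric example of Subset.merge above (values illustrative): a child
# relative range [25%, 75%] under a parent relative range [0%, 50%]:
#   Subset(range=("25%", "75%")).merge(DatasetSubset(range=(0.0, 0.5), absolute_range=None))
#     -> DatasetSubset(range=(0.25, 0.5), absolute_range=None)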
@edataclass
class SubsetRatioMixin:
subset: Optional[Subset] = None
def _get_subset(self, parent_subset: Optional[DatasetSubset]) -> Optional[DatasetSubset]:
if parent_subset is not None:
assert parent_subset.absolute_range is None, (
f"Can only use absolute subset ranges for a leaf dataset (Range {parent_subset.absolute_range=})"
)
if self.subset is not None:
return self.subset.merge(parent_subset)
else:
return parent_subset
elif self.subset is not None:
return self.subset.merge(None)
return None
@edataclass
class DatasetReference(SubsetRatioMixin, DatasetLoaderInterface):
path: Union[str, EPath]
split_part: Optional[str] = None
subflavors: Optional[Dict[str, Any]] = None
shuffle_over_epochs_multiplier: Optional[int] = 1
dataset_config: Optional[str] = None
split_config: Optional[str] = None
    #: Auxiliary datasets. May only be specified for crude datasets for cooking. Cookers will
    # get these references to load data from. If specified as a string, it will be interpreted
    # as a dataset path.
aux: Optional[Dict[str, str]] = None
_dataset: Optional[DatasetLoaderInterface] = None
def post_initialize(self, mds_path: Optional[EPath] = None) -> None:
assert mds_path is not None
if not isinstance(self.path, EPath):
self.path = mds_path.parent / self.path
ds_type = get_dataset_type(self.path)
if ds_type == EnergonDatasetType.METADATASET:
            assert self.aux is None, "Auxiliary datasets are only allowed on direct (crude) dataset references, not on metadatasets"
assert self.dataset_config is None, "Must not set dataset_config"
assert self.split_config is None, "Must not set split_config"
# Note: For backwards compatibility, the type must be Metadataset (V1).
self._dataset = load_config(
self.path,
default_type=Metadataset,
default_kwargs=dict(path=self.path),
)
self._dataset.post_initialize()
elif ds_type in (EnergonDatasetType.WEBDATASET, EnergonDatasetType.JSONL):
self._dataset = DatasetLoader(
path=self.path,
split_config=self.split_config,
dataset_config=self.dataset_config,
)
self._dataset.post_initialize()
if self.aux is not None:
new_aux = {}
for k, v in self.aux.items():
if m := url_regex.match(v):
if m.group("protocol") == "filesystem":
new_aux[k] = AuxFilesystemReference(fs_path=m.group("path"))
else:
raise ValueError(f"Unsupported protocol: {m.group('protocol')}")
else:
new_aux[k] = AuxDatasetReference(path=v)
new_aux[k].post_initialize(mds_path)
self.aux = new_aux
else:
raise FileNotFoundError(self.path)
def prepare(self, split_part: Optional[str] = None) -> Sequence[EPath]:
assert self._dataset is not None
return self._dataset.prepare(split_part=split_part)
def get_datasets(
self,
*,
training: bool,
split_part: Union[Literal["train", "val", "test"], str],
worker_config: WorkerConfig,
subflavors: Optional[Dict[str, Any]] = None,
shuffle_over_epochs_multiplier: Optional[int] = 1,
subset: Optional[DatasetSubset] = None,
**kwargs,
) -> LoadedDatasetList:
if self.subflavors is not None:
subflavors = {**self.subflavors, **(subflavors or {})}
assert self._dataset is not None
if shuffle_over_epochs_multiplier is None or self.shuffle_over_epochs_multiplier is None:
# If no shuffling is requested, this has override priority.
new_shuffle_over_epochs_multiplier = None
elif shuffle_over_epochs_multiplier == -1 or self.shuffle_over_epochs_multiplier == -1:
            # Next priority: -1, i.e. shuffling over an infinite number of epochs (drawing with replacement).
new_shuffle_over_epochs_multiplier = -1
else:
# Otherwise, multiply the shuffle over epochs multiplier.
new_shuffle_over_epochs_multiplier = (
shuffle_over_epochs_multiplier * self.shuffle_over_epochs_multiplier
)
subset = self._get_subset(subset)
result = self._dataset.get_datasets(
training=training,
split_part=self.split_part or split_part,
worker_config=worker_config,
subflavors=subflavors,
shuffle_over_epochs_multiplier=new_shuffle_over_epochs_multiplier,
subset=subset,
**kwargs,
)
if self.aux is not None:
aux = {k: v.get_file_store() for k, v in self.aux.items()}
for loaded_dataset in result.datasets:
if loaded_dataset.aux is None:
loaded_dataset.aux = aux
else:
loaded_dataset.aux.update(aux)
return result
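# Hedged illustration of the aux classification in post_initialize above (keys
# and values illustrative): plain values become dataset references, while
# "filesystem://" URLs become filesystem roots:
#   aux = {"images": "image_ds/", "raw": "filesystem://raw_files/"}
#     -> {"images": AuxDatasetReference(path="image_ds/"),
#         "raw": AuxFilesystemReference(fs_path="raw_files/")}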
@edataclass
class JoinDatasetReference(DatasetReference):
nonmatch: Literal["skip", "none", "error"] = "error"
def post_initialize(self, mds_path: Optional[EPath] = None) -> DatasetLoader:
assert mds_path is not None
# Override and disable another metadataset reference, only allow direct dataset references.
# Do not store the loader, the parent MetadatasetJoin will do that.
if not isinstance(self.path, EPath):
self.path = mds_path.parent / self.path
ds_type = get_dataset_type(self.path)
if ds_type == EnergonDatasetType.WEBDATASET:
return DatasetLoader(
path=self.path,
split_part=self.split_part,
subflavors=self.subflavors,
shuffle_over_epochs_multiplier=self.shuffle_over_epochs_multiplier,
dataset_config=self.dataset_config,
split_config=self.split_config,
)
else:
raise ValueError(f"Not a joinabledataset at {self.path}")
def prepare(self, split_part: Optional[str] = None):
        raise NotImplementedError(
            "JoinDatasetReference should not be used directly, but only by MetadatasetJoin"
        )
def get_datasets(
self,
**kwargs,
) -> LoadedDatasetList:
        raise NotImplementedError(
            "JoinDatasetReference should not be used directly, but only by MetadatasetJoin"
        )
@edataclass
class MetadatasetJoin(SubsetRatioMixin, DatasetLoaderInterface):
join: Union[List[JoinDatasetReference], Dict[str, JoinDatasetReference]]
joiner: Union[Type[Sample], Callable[..., Sample]]
split_part: Optional[str] = None
subflavors: Optional[Dict[str, Any]] = None
shuffle_over_epochs_multiplier: Optional[int] = 1
dataset_config: Optional[str] = None
split_config: Optional[str] = None
_dataset: Optional[JoinDatasetLoader] = None
def post_initialize(self, mds_path: Optional[EPath] = None):
assert mds_path is not None
assert self.join is not None
assert self.joiner is not None, "Must set joiner for joining datasets"
assert self.dataset_config is None, "Cannot set dataset_config for joining datasets"
assert self.split_config is None, "Cannot set split_config for joining datasets"
if isinstance(self.join, list):
inner_loaders = [
JoinedDatasetInfo(
dataset=join.post_initialize(mds_path),
nonmatch=join.nonmatch,
)
for join in self.join
]
elif isinstance(self.join, dict):
inner_loaders = {
key: JoinedDatasetInfo(
dataset=join.post_initialize(mds_path),
nonmatch=join.nonmatch,
)
for key, join in self.join.items()
}
else:
raise ValueError("Invalid join type")
self._dataset = JoinDatasetLoader(
datasets=inner_loaders,
joiner=self.joiner,
split_part=self.split_part,
subflavors=self.subflavors,
shuffle_over_epochs_multiplier=self.shuffle_over_epochs_multiplier,
split_config=self.split_config,
)
self._dataset.post_initialize(mds_path)
def prepare(self, split_part: Optional[str] = None) -> Sequence[EPath]:
assert self._dataset is not None, "Missing post_initialize call."
return self._dataset.prepare(split_part=split_part)
def get_datasets(
self,
*,
training: bool,
split_part: Union[Literal["train", "val", "test"], str],
worker_config: WorkerConfig,
subflavors: Optional[Dict[str, Any]] = None,
shuffle_over_epochs_multiplier: Optional[int] = 1,
subset: Optional[DatasetSubset] = None,
**kwargs,
) -> LoadedDatasetList:
assert self._dataset is not None, "Missing post_initialize call."
subset = self._get_subset(subset)
return self._dataset.get_datasets(
training=training,
split_part=split_part,
worker_config=worker_config,
subflavors=subflavors,
shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier,
subset=subset,
**kwargs,
)
@dataclass
class BlendWeightMixin:
weight: float = 1.0
@edataclass
class BlendDatasetReference(BlendWeightMixin, DatasetReference):
pass
@edataclass
class BlendJoinDatasetReference(BlendWeightMixin, MetadatasetJoin):
pass
@edataclass
class MetadatasetBlend(DatasetLoaderInterface, SubsetRatioMixin):
"""Blending of datasets by specifying the sampling weight for the inner datasets."""
blend: List[Union[BlendDatasetReference, BlendJoinDatasetReference]]
def post_initialize(self, mds_path: Optional[EPath] = None):
assert mds_path is not None
for dataset in self.blend:
dataset.post_initialize(mds_path)
def prepare(self, split_part: Optional[str] = None) -> Sequence[EPath]:
files = []
for dataset in self.blend:
files.extend(dataset.prepare(split_part=split_part))
return files
def get_datasets(
self,
*,
training: bool,
split_part: Union[Literal["train", "val", "test"], str],
worker_config: WorkerConfig,
subflavors: Optional[Dict[str, Any]] = None,
shuffle_over_epochs_multiplier: Optional[int] = 1,
subset: Optional[DatasetSubset] = None,
**kwargs,
) -> LoadedDatasetList:
subset = self._get_subset(subset)
sum_weight = sum(dataset.weight for dataset in self.blend)
datasets = []
for dataset in self.blend:
inner_result = dataset.get_datasets(
training=training,
split_part=split_part,
worker_config=worker_config,
subflavors=subflavors,
shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier,
subset=subset,
**kwargs,
)
if inner_result.blend_mode not in (
DatasetBlendMode.NONE,
DatasetBlendMode.DATASET_WEIGHT,
):
raise ValueError(
"Can only blend datasets which are of the same blend mode. Cannot mix blend with blend_epochized."
)
for loaded_dataset in inner_result.datasets:
if inner_result.blend_mode == DatasetBlendMode.DATASET_WEIGHT:
assert isinstance(loaded_dataset.weight, float)
else:
assert inner_result.blend_mode == DatasetBlendMode.NONE
assert loaded_dataset.weight is None
assert loaded_dataset.repetitions is None
loaded_dataset.weight = 1.0
loaded_dataset.weight = loaded_dataset.weight * dataset.weight / sum_weight
datasets.append(loaded_dataset)
return LoadedDatasetList(
blend_mode=DatasetBlendMode.DATASET_WEIGHT,
datasets=datasets,
)
@dataclass
class BlendRepetitionsMixin:
repetitions: Union[int, float] = 1
@edataclass
class BlendEpochizedDatasetReference(BlendRepetitionsMixin, DatasetReference):
pass
@edataclass
class BlendEpochizedJoinDatasetReference(BlendRepetitionsMixin, MetadatasetJoin):
pass
@edataclass
class MetadatasetBlendEpochized(SubsetRatioMixin, DatasetLoaderInterface):
"""Blending of datasets, by specifying the number of repetitions for samples from the inner
datasets. Ensures that the constraint, that samples are seen exactly this many times before
repeating the "epoch" (i.e. one epoch contains the total number of repetitions for each inner
dataset)."""
blend_epochized: List[Union[BlendEpochizedDatasetReference, BlendEpochizedJoinDatasetReference]]
def post_initialize(self, mds_path: Optional[EPath] = None):
assert mds_path is not None
for dataset in self.blend_epochized:
dataset.post_initialize(mds_path)
def prepare(self, split_part: Optional[str] = None) -> Sequence[EPath]:
files = []
for dataset in self.blend_epochized:
files.extend(dataset.prepare(split_part=split_part))
return files
def get_datasets(
self,
*,
training: bool,
split_part: Union[Literal["train", "val", "test"], str],
worker_config: WorkerConfig,
subflavors: Optional[Dict[str, Any]] = None,
shuffle_over_epochs_multiplier: Optional[int] = 1,
subset: Optional[DatasetSubset] = None,
**kwargs,
) -> LoadedDatasetList:
subset = self._get_subset(subset)
datasets = []
for dataset in self.blend_epochized:
inner_result = dataset.get_datasets(
training=training,
split_part=split_part,
worker_config=worker_config,
subflavors=subflavors,
shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier,
subset=subset,
**kwargs,
)
if inner_result.blend_mode not in (
DatasetBlendMode.NONE,
DatasetBlendMode.SAMPLE_REPETITIONS,
):
raise ValueError(
"Can only blend datasets which are of the same blend mode. Cannot mix blend with blend_epochized."
)
for loaded_dataset in inner_result.datasets:
if inner_result.blend_mode == DatasetBlendMode.SAMPLE_REPETITIONS:
assert isinstance(loaded_dataset.repetitions, (int, float))
else:
assert loaded_dataset.weight is None
assert loaded_dataset.repetitions is None
loaded_dataset.repetitions = 1
loaded_dataset.repetitions = dataset.repetitions * loaded_dataset.repetitions
datasets.append(loaded_dataset)
return LoadedDatasetList(
blend_mode=DatasetBlendMode.SAMPLE_REPETITIONS,
datasets=datasets,
)
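# Hedged numeric example (repetitions illustrative): with blend_epochized
# entries of repetitions=2 and repetitions=1, one "epoch" contains every sample
# of the first dataset exactly twice and every sample of the second exactly
# once, giving a deterministic 2:1 sample ratio instead of weighted sampling.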
@edataclass
class MetadatasetV2(DatasetLoaderInterface):
path: EPath
splits: Dict[
str, Union[MetadatasetBlend, MetadatasetBlendEpochized, MetadatasetJoin, DatasetReference]
]
def post_initialize(self, mds_path: Optional[EPath] = None):
assert mds_path is None
for split in self.splits.values():
split.post_initialize(self.path)
def prepare(self, split_part: Optional[str] = None) -> Sequence[EPath]:
        # In the case of prepare for MetadatasetV2, we ignore the passed cache_path
        # and instead use our own path.
# If someone runs energon prepare on a metadataset that refers to another metadataset,
# any actions concerning the inner metadataset will be done on the inner metadataset's path.
if split_part is None:
files = []
for split_part, split in self.splits.items():
files.extend(split.prepare(split_part=split_part))
else:
files = self.splits[split_part].prepare(split_part=split_part)
# Cleanup paths here
cache_path = EPath(self.path.parent / f"{self.path.name}.cache")
remove_files = set(cache_path.glob("*")) - set(files)
for file in remove_files:
print(f"Cleanup cache file {file}")
file.unlink()
        # Do not return paths; cleanup was already handled here, so this acts as a barrier.
return ()
def get_datasets(
self,
*,
training: bool,
split_part: Union[Literal["train", "val", "test"], str],
worker_config: WorkerConfig,
subflavors: Optional[Dict[str, Any]] = None,
shuffle_over_epochs_multiplier: Optional[int] = 1,
subset: Optional[DatasetSubset] = None,
**kwargs,
) -> LoadedDatasetList:
return self.splits[split_part].get_datasets(
training=training,
split_part=split_part,
worker_config=worker_config,
subflavors=subflavors,
shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier,
subset=subset,
**kwargs,
)
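# Hedged construction sketch (paths and weights are illustrative): the object
# structure a metadataset v2 config is parsed into by the typed config loader;
# post_initialize() would then resolve the relative paths against the config's
# parent directory, which requires the referenced datasets to exist and be prepared.
def _example_metadataset_v2() -> MetadatasetV2:
    return MetadatasetV2(
        path=EPath("/data/blend.yaml"),
        splits={
            "train": MetadatasetBlend(
                blend=[
                    BlendDatasetReference(path="ds_a", weight=3.0),
                    BlendDatasetReference(path="ds_b", weight=1.0),
                ]
            ),
        },
    )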