# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from abc import ABC
from typing import (
Callable,
Generator,
Optional,
Tuple,
TypeVar,
)
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.jsonl.ijsonl import (
CachedIJsonlOffsetReader,
IJsonlFile,
IJsonlIndexReader,
IJsonlSamplePointer,
)
from megatron.energon.flavors.webdataset.structs import FilteredSample
from megatron.energon.source_info import SourceInfo
T_index = TypeVar("T_index", covariant=False)
class IJsonlReader(ABC):
"""
Class for reading indexed jsonl files containing json samples.
The common usage patterns and random-access interfaces are provided here.
Args:
base_path: The path to the dataset.
jsonl_path: The path to the jsonl file.
jsonl_filename: The jsonl file name.
sample_filter: An optional filter function to select samples by their key.
index_cache_size: The size of the index cache.
"""
jsonl_path: EPath
sample_filter: Optional[Callable[[str], bool]]
cached_offset_reader: CachedIJsonlOffsetReader
ijsonl_file: IJsonlFile | None = None
def __init__(
self,
jsonl_path: EPath,
sample_filter: Optional[Callable[[str], bool]] = None,
index_cache_size: int = 5,
):
self.jsonl_path = jsonl_path
self.sample_filter = sample_filter
self.cached_offset_reader = CachedIJsonlOffsetReader(
jsonl_path, cache_size=index_cache_size
)
def __len__(self) -> int:
return len(self.cached_offset_reader)
def __str__(self) -> str:
return f"IJsonlReader(jsonl_path={self.jsonl_path})"
def _get_item_by_sample_pointer(
self,
sample_pointer: IJsonlSamplePointer,
) -> FilteredSample | None:
"""
Get a sample from the dataset via its sample pointer.
Args:
sample_pointer: The sample pointer to read the sample from.
Returns:
The sample, or None if the sample is filtered out or invalid.
"""
key = str(sample_pointer.index)
if self.sample_filter is not None and not self.sample_filter(key):
return None
if self.ijsonl_file is None:
self.ijsonl_file = IJsonlFile(self.jsonl_path.open("rb"))
json_data = self.ijsonl_file.next(sample_pointer.byte_offset, sample_pointer.byte_size)
if json_data is None:
return None
return FilteredSample(
__key__=f"{self.jsonl_path.name}/{key}",
__shard__=self.jsonl_path.name,
__restore_key__=("Webdataset", sample_pointer.index),
__sources__=(
SourceInfo(
dataset_path=str(self.jsonl_path),
index=sample_pointer.index,
shard_name=self.jsonl_path.name,
file_names=(f"{key}.json",),
),
),
json=json_data,
)
def __getitem__(self, idx: int | str) -> FilteredSample | tuple[bytes, SourceInfo] | None:
"""
Get a sample from the dataset.
"""
assert isinstance(idx, (int, str)), f"Invalid argument type for __getitem__: {type(idx)}"
full_entry_name = False
if isinstance(idx, str):
if idx.endswith(".json"):
num_idx = idx.removesuffix(".json")
full_entry_name = True
else:
num_idx = idx
try:
idx = int(num_idx)
except ValueError:
raise ValueError(f"Invalid JSONL sample key: {idx}")
byte_offset, byte_size = self.cached_offset_reader.get_ijsonl_byte_offset(idx)
sample: FilteredSample | None = self._get_item_by_sample_pointer(
IJsonlSamplePointer(
index=idx,
byte_offset=byte_offset,
byte_size=byte_size,
)
)
if sample is None:
return None
if full_entry_name:
assert len(sample["__sources__"]) == 1
return sample["json"], sample["__sources__"][0]
else:
return sample
def list_all_samples(self) -> Generator[Tuple[str, int, int], None, None]:
"""List all samples in the jsonl file.
Returns:
A generator of tuples of (sample_key, size, tar_file_id)
"""
last_byte_offset = 0
with IJsonlIndexReader(self.jsonl_path) as ijsonl_index_reader:
for sample_idx, byte_offset in enumerate(ijsonl_index_reader):
if last_byte_offset == byte_offset:
continue
yield str(sample_idx), byte_offset - last_byte_offset, 0
last_byte_offset = byte_offset
def list_all_sample_parts(self) -> Generator[Tuple[str, int, int], None, None]:
"""List all sample parts in the jsonl file.
Returns:
A generator of tuples of (sample_key + "." + part_name, size, tar_file_id)
"""
last_byte_offset = 0
with IJsonlIndexReader(self.jsonl_path) as ijsonl_index_reader:
for sample_idx, byte_offset in enumerate(ijsonl_index_reader):
if last_byte_offset == byte_offset:
continue
yield f"{sample_idx}.json", byte_offset - last_byte_offset, 0
last_byte_offset = byte_offset
def list_sample_parts(self, sample_key: str) -> Generator[Tuple[str, int, int], None, None]:
"""Given a sample key, list all its parts. (E.g. given 1, list 1.jpg, 1.json, etc.)
Args:
sample_key: The sample key to list the parts of.
Returns:
A generator of tuples of (part_name, size, tar_file_id)
"""
try:
sample_idx = int(sample_key)
except ValueError:
raise ValueError(f"Invalid JSONL sample key: {sample_key}")
_, byte_size = self.cached_offset_reader.get_ijsonl_byte_offset(sample_idx)
yield f"{sample_key}.json", byte_size, 0
def get_total_size(self) -> int:
return self.cached_offset_reader.get_total_size()
def close(self):
if self.ijsonl_file is not None:
self.ijsonl_file.close()
self.cached_offset_reader.close()
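# Illustrative usage sketch (not part of the original module): random access into a prepared
# jsonl dataset. The path is hypothetical, and the offset index (written by JsonlPreparator)
# is assumed to already exist next to the .jsonl file.
def _example_read_ijsonl() -> None:
    reader = IJsonlReader(EPath("/data/my_dataset/data.jsonl"))
    try:
        print(len(reader))          # number of indexed samples
        sample = reader[0]          # FilteredSample holding the raw json of sample 0
        if sample is not None:
            print(sample["__key__"])
        entry = reader["0.json"]    # (raw json, SourceInfo) for a full entry name
    finally:
        reader.close()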
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import logging
from pathlib import Path
from typing import (
Any,
Dict,
Generator,
TypeVar,
Union,
)
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.jsonl.ijsonl import IJsonlFile, IJsonlIndexWriter
logger = logging.getLogger(__name__)
T = TypeVar("T", covariant=True)
class JsonlPreparator:
@staticmethod
def iter_dataset_content(
path: Union[str, EPath],
) -> Generator[Dict[str, Any], None, None]:
"""
Yield the dataset content sample by sample (e.g. to preview a few samples).
Args:
path: Path to the jsonl file.
"""
with EPath(path).open("rb") as f:
with IJsonlFile(f) as index_reader:
for entry in index_reader:
yield {"json": entry}
@classmethod
def prepare_dataset(
cls,
path: Union[Path, EPath],
) -> int:
"""
Preprocess the jsonl file: build the byte-offset index and count the samples.
Args:
path: Path to the jsonl file
Returns:
Count of samples in the jsonl file.
"""
count = 0
# Offset bookkeeping lags one line behind the line being processed. Empty lines are absorbed into the offsets so that the whole file is covered.
last_offset = 0
with IJsonlIndexWriter(EPath(path)) as iw:
with EPath(path).open("rb") as f:
while True:
line = f.readline()
if not line:
break
line = line.strip()
if not line:
if last_offset:
last_offset = f.tell()
continue
assert line.startswith(b"{") and line.endswith(b"}"), (
f"Line {line} does not start and end with a json object {{}}."
)
iw.append(last_offset)
last_offset = f.tell()
count += 1
assert last_offset == f.tell(), (
f"The last offset {last_offset} does not match the file size {f.tell()}."
)
assert last_offset != 0, "File is empty."
iw.append(last_offset)
return count
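# Illustrative usage sketch (not part of the original module): building the offset index for
# a jsonl file and previewing its content. The path is hypothetical.
def _example_prepare_jsonl() -> None:
    num_samples = JsonlPreparator.prepare_dataset(EPath("/data/my_dataset/data.jsonl"))
    logger.info("Indexed %d samples", num_samples)
    for i, content in enumerate(JsonlPreparator.iter_dataset_content("/data/my_dataset/data.jsonl")):
        logger.info("Sample %d: %r", i, content["json"])
        if i >= 2:
            break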
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import List, Optional
import torch
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib.epath import EPath
from megatron.energon.errors import warn_deprecated
from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory
@edataclass
class MultiChoiceVQASample(Sample):
"""Sample type for visual question answering."""
#: The input image tensor in the shape (C, H, W)
image: torch.Tensor
#: The context/question for the image
context: str
#: The candidate answers.
choices: Optional[List[str]] = None
#: The index of the correct answer.
correct_choice_idx: int = 0
class MultiChoiceVQAWebdataset(DefaultDecoderWebdatasetFactory[MultiChoiceVQASample]):
__sample_type__ = MultiChoiceVQASample
def __init__(self, path: EPath, **kwargs):
warn_deprecated(
f"{type(self)} is deprecated, use the default instead and set the sample_type:\n"
f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n"
f"# remove top-level __module__ and __class__\n"
f"sample_type:\n"
f" __module__: megatron.energon\n"
f" __class__: {self.__sample_type__.__name__}\n"
f"# Keep the remaining content"
)
super().__init__(path, **kwargs)
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import List, Optional, Union
import torch
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib.epath import EPath
from megatron.energon.errors import warn_deprecated
from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory
@edataclass
class OCRSample(Sample):
"""Sample type for optical character recognition."""
#: The input image tensor in the shape (C, H, W)
image: torch.Tensor
#: The text contained in the image
text: str
#: The bounding boxes of the blocks in the image float(N, 4|5<x, y, w, h[, confidence]>)
block_boxes: Optional[torch.Tensor] = None
#: The classes of the blocks in the image int(N, 1<block_class>)
block_classes: Optional[Union[torch.Tensor, List[str]]] = None
#: The text contained in each block (N,)
block_text: Optional[List[str]] = None
#: The bounding boxes of the lines in the image float(N, 4|5<x, y, w, h[, confidence]>)
lines_boxes: Optional[torch.Tensor] = None
#: The text contained in each line (N,)
lines_text: Optional[List[str]] = None
#: The bounding boxes of the words in the image float(N, 4|5<x, y, w, h[, confidence]>)
words_boxes: Optional[torch.Tensor] = None
#: The text contained in each word (N,)
words_text: Optional[List[str]] = None
#: The bounding boxes of the chars in the image float(N, 4|5<x, y, w, h[, confidence]>)
chars_boxes: Optional[torch.Tensor] = None
#: The character contained in each char (N,)
chars_text: Optional[List[str]] = None
class OCRWebdataset(DefaultDecoderWebdatasetFactory[OCRSample]):
__sample_type__ = OCRSample
def __init__(self, path: EPath, **kwargs):
warn_deprecated(
f"{type(self)} is deprecated, use the default instead and set the sample_type:\n"
f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n"
f"# remove top-level __module__ and __class__\n"
f"sample_type:\n"
f" __module__: megatron.energon\n"
f" __class__: {self.__sample_type__.__name__}\n"
f"# Keep the remaining content"
)
super().__init__(path, **kwargs)
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import List, Optional
import torch
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib.epath import EPath
from megatron.energon.errors import warn_deprecated
from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory
@edataclass
class SimilarityInterleavedSample(Sample):
"""Sample type for interleaved media such as text with images, but without image-text alignment.
That alignment has to be assigned from the similarity matrix."""
#: The images of the sequence
images: List[torch.Tensor]
#: The texts of the sequence
texts: List[str]
#: The optional audio samples of the sequence
audio: Optional[List[torch.Tensor]] = None
#: The optional video frames of the sequence
video: Optional[List[torch.Tensor]] = None
#: Similarity matrix between image and text entries in the sequence
similarity_matrix: Optional[torch.Tensor] = None
#: For each image, the index within texts of the sentence that the image is matched to
matched_text_indices: Optional[List[int]] = None
class SimilarityInterleavedWebdataset(DefaultDecoderWebdatasetFactory[SimilarityInterleavedSample]):
__sample_type__ = SimilarityInterleavedSample
def __init__(self, path: EPath, **kwargs):
warn_deprecated(
f"{type(self)} is deprecated, use the default instead and set the sample_type:\n"
f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n"
f"# remove top-level __module__ and __class__\n"
f"sample_type:\n"
f" __module__: megatron.energon\n"
f" __class__: {self.__sample_type__.__name__}\n"
f"# Keep the remaining content"
)
super().__init__(path, **kwargs)
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib.epath import EPath
from megatron.energon.errors import warn_deprecated
from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory
@edataclass
class TextSample(Sample):
"""Sample type for simple text."""
#: The text of the sample
text: str
class TextWebdataset(DefaultDecoderWebdatasetFactory[TextSample]):
__sample_type__ = TextSample
def __init__(self, path: EPath, **kwargs):
warn_deprecated(
f"{type(self)} is deprecated, use the default instead and set the sample_type:\n"
f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n"
f"# remove top-level __module__ and __class__\n"
f"sample_type:\n"
f" __module__: megatron.energon\n"
f" __class__: {self.__sample_type__.__name__}\n"
f"# Keep the remaining content"
)
super().__init__(path, **kwargs)
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import List, Optional
import torch
from megatron.energon.av import AVDecoder
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib.epath import EPath
from megatron.energon.errors import warn_deprecated
from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory
@edataclass
class VidQASample(Sample):
"""Sample type for video question answering."""
#: The video data containing the image and audio info.
video: AVDecoder
#: The context/question for the image.
context: str
#: The possible answers. Not set for testing.
answers: Optional[List[str]] = None
#: The weights of the possible answers. Optionally available.
answer_weights: Optional[torch.Tensor] = None
class VidQAWebdataset(DefaultDecoderWebdatasetFactory[VidQASample]):
__sample_type__ = VidQASample
def __init__(self, path: EPath, **kwargs):
warn_deprecated(
f"{type(self)} is deprecated, use the default instead and set the sample_type:\n"
f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n"
f"# remove top-level __module__ and __class__\n"
f"sample_type:\n"
f" __module__: megatron.energon\n"
f" __class__: {self.__sample_type__.__name__}\n"
f"# Keep the remaining content"
)
super().__init__(path, **kwargs)
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import List, Optional
import torch
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib.epath import EPath
from megatron.energon.errors import warn_deprecated
from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory
@edataclass
class VQASample(Sample):
"""Sample type for visual question answering."""
#: The input image tensor in the shape (C, H, W)
image: torch.Tensor
#: The context/question for the image
context: str
#: The possible answers. Not set for testing.
answers: Optional[List[str]] = None
#: The weights of the possible answers. Optionally available.
answer_weights: Optional[torch.Tensor] = None
class VQAWebdataset(DefaultDecoderWebdatasetFactory[VQASample]):
__sample_type__ = VQASample
def __init__(self, path: EPath, **kwargs):
warn_deprecated(
f"{type(self)} is deprecated, use the default instead and set the sample_type:\n"
f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n"
f"# remove top-level __module__ and __class__\n"
f"sample_type:\n"
f" __module__: megatron.energon\n"
f" __class__: {self.__sample_type__.__name__}\n"
f"# Keep the remaining content"
)
super().__init__(path, **kwargs)
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import List, Optional
import torch
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib.epath import EPath
from megatron.energon.errors import warn_deprecated
from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory
@edataclass
class VQAOCRSample(Sample):
"""Sample type for visual question answering."""
#: The input image tensor in the shape (C, H, W)
image: torch.Tensor
#: The context/question for the image (VQA)
context: str
#: The text contained in the image (OCR)
text: str
#: The possible answers. Not set for testing. (VQA)
answers: Optional[List[str]] = None
#: The weights of the possible answers. Optionally available. (VQA)
answer_weights: Optional[torch.Tensor] = None
#: The bounding boxes of the words in the image (N, 4|5<x, y, w, h[, confidence]>) (OCR)
words_boxes: Optional[torch.Tensor] = None
#: The text contained in each word (N,) (OCR)
words_text: Optional[List[str]] = None
class VQAOCRWebdataset(DefaultDecoderWebdatasetFactory[VQAOCRSample]):
__sample_type__ = VQAOCRSample
def __init__(self, path: EPath, **kwargs):
warn_deprecated(
f"{type(self)} is deprecated, use the default instead and set the sample_type:\n"
f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n"
f"# remove top-level __module__ and __class__\n"
f"sample_type:\n"
f" __module__: megatron.energon\n"
f" __class__: {self.__sample_type__.__name__}\n"
f"# Keep the remaining content"
)
super().__init__(path, **kwargs)
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from megatron.energon.av import AVData, AVDecoder
from megatron.energon.flavors.webdataset.base_webdataset import BaseWebdatasetFactory
from megatron.energon.flavors.webdataset.config import JOIN_INDEX_FILENAME, MAIN_FOLDER_NAME
from megatron.energon.flavors.webdataset.default_decoder_webdataset import (
DefaultDecoderWebdatasetFactory,
)
from megatron.energon.flavors.webdataset.default_generic_webdataset import (
DefaultGenericWebdatasetFactory,
)
from megatron.energon.flavors.webdataset.empty_dataset_error import EmptyDatasetError
from megatron.energon.flavors.webdataset.joined_webdataset import JoinedWebdatasetFactory
from megatron.energon.flavors.webdataset.metadata import WebdatasetMeta
from megatron.energon.flavors.webdataset.sample_decoder import (
AVDecoderType,
ImageDecoderType,
SampleDecoder,
)
from megatron.energon.flavors.webdataset.standard_webdataset import StandardWebdatasetFactory
__all__ = [
"AVData",
"AVDecoder",
"AVDecoderType",
"BaseWebdatasetFactory",
"DefaultDecoderWebdatasetFactory",
"DefaultGenericWebdatasetFactory",
"EmptyDatasetError",
"ImageDecoderType",
"JOIN_INDEX_FILENAME",
"JoinedWebdatasetFactory",
"MAIN_FOLDER_NAME",
"SampleDecoder",
"StandardWebdatasetFactory",
"WebdatasetMeta",
]
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import multiprocessing
from abc import ABC, abstractmethod
from typing import Any, Callable, Generic, Iterable, List, Optional, TypeVar
T_result = TypeVar("T_result")
T_aggregation_data = TypeVar("T_aggregation_data")
T_input_data = TypeVar("T_input_data")
class BaseAggregator(ABC, Generic[T_aggregation_data, T_result]):
"""
Base class for a user-defined aggregator.
Implement on_start, on_item, and on_finish to handle aggregator logic.
"""
def on_start(self, aggregator_pool: AggregatorPool) -> None:
"""
Called exactly once in the aggregator process before receiving any items.
"""
pass
@abstractmethod
def on_item(self, item: T_aggregation_data, aggregator_pool: AggregatorPool) -> None:
"""
Called for each item produced by the workers.
"""
...
def on_finish(self, aggregator_pool: AggregatorPool) -> None:
"""
Called once when all workers have signaled completion (i.e. all items are processed).
"""
pass
def get_final_result_data(self) -> T_result:
"""
Called after on_finish to retrieve any final data produced by the aggregator.
"""
return None
class AggregatorPool(Generic[T_input_data, T_aggregation_data, T_result]):
"""
A pool that manages multiple worker processes sending results to
a single aggregator process.
The user must provide:
- user_produce_data(task) -> yields items (streaming results)
- aggregator: an instance of a class derived from BaseAggregator
which implements on_start, on_item, on_finish, etc.
"""
num_workers: int
user_produce_data: Callable[[T_input_data], Iterable[Any]]
aggregator: BaseAggregator[T_aggregation_data, T_result]
task_queue: multiprocessing.Queue[Optional[T_input_data]]
result_queue: multiprocessing.Queue[Optional[T_aggregation_data]]
def __init__(
self,
num_workers: int,
user_produce_data: Callable[[T_input_data], Iterable[Any]],
aggregator: BaseAggregator[T_aggregation_data, T_result],
) -> None:
"""
Args:
num_workers: Number of worker processes.
user_produce_data: Function that takes a task and yields items (the "large" data stream).
aggregator: An instance of a user-defined class for handling aggregator logic.
"""
self.num_workers = num_workers
self.user_produce_data = user_produce_data
self.aggregator = aggregator
# Queues for tasks and results
self.task_queue = multiprocessing.Queue()
self.result_queue = multiprocessing.Queue()
# Queue to pass final aggregator data back to the main process
self._final_result_data_queue = multiprocessing.Queue()
# Will store the aggregator's final result data
self._aggregator_final_result_data: Optional[Any] = None
def _worker(self, worker_id: int) -> None:
"""Function that runs inside each worker process."""
while True:
task = self.task_queue.get()
if task is None:
# No more tasks, signal aggregator that this worker is done
break
# Produce data in a streaming fashion
for item in self.user_produce_data(task):
self.result_queue.put(item)
# After finishing all tasks, send a sentinel to the aggregator
self.result_queue.put(None)
def _aggregator_run(self) -> T_result:
"""
Runs the aggregation loop (in the main process, see process()).
Keeps reading items from result_queue:
- If an item is None, a worker has finished all of its tasks.
- Otherwise, call aggregator.on_item(...) with that item.
"""
# Let the aggregator do any initialization it needs
self.aggregator.on_start(self)
finished_workers = 0
while finished_workers < self.num_workers:
item = self.result_queue.get()
if item is None:
# A worker has finished all of its tasks
finished_workers += 1
else:
# Process the item in the aggregator
self.aggregator.on_item(item, self)
# All workers done, aggregator can finalize
self.aggregator.on_finish(self)
# After finishing, return the aggregator's final result data
return self.aggregator.get_final_result_data()
def submit_task(self, task: T_input_data) -> None:
"""
Submit a task to be processed by a worker.
"""
self.task_queue.put(task)
def process(self) -> T_result:
"""
Starts the worker processes and runs the aggregation loop in the main process.
All tasks must be submitted via submit_task() before calling this method, since the
worker sentinels are enqueued right after the workers are started.
Waits for all workers to finish and returns the aggregator's final result data.
"""
workers: List[multiprocessing.Process] = []
# Start worker processes
for w_id in range(self.num_workers):
p = multiprocessing.Process(target=self._worker, args=(w_id,), daemon=True)
p.start()
workers.append(p)
# Send the sentinel (None) to each worker
for _ in range(self.num_workers):
self.task_queue.put(None)
# Now wait for aggregator to finish reading the queue
result = self._aggregator_run()
# Wait for all workers to finish
for p in workers:
p.join()
return result
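# Illustrative usage sketch (not part of the original module): counting items streamed by
# the workers. The names _produce_range, _SumAggregator and _example_aggregate are
# hypothetical and only show the intended call pattern.
def _produce_range(task: int):
    # A task produces a stream of items; here simply the numbers 0..task-1.
    return range(task)


class _SumAggregator(BaseAggregator[int, int]):
    def on_start(self, aggregator_pool: AggregatorPool) -> None:
        self.total = 0

    def on_item(self, item: int, aggregator_pool: AggregatorPool) -> None:
        self.total += item

    def get_final_result_data(self) -> int:
        return self.total


def _example_aggregate() -> int:
    pool = AggregatorPool(num_workers=2, user_produce_data=_produce_range, aggregator=_SumAggregator())
    for task in (3, 4, 5):
        pool.submit_task(task)
    # process() starts the workers, runs the aggregation loop in this process and returns
    # the aggregator's final result: sum(range(3)) + sum(range(4)) + sum(range(5)) = 19.
    return pool.process()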
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Generic, List, Optional, Sequence, TypeVar
import numpy as np
from megatron.energon.cache import FileStore
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.base_dataset import (
BaseCoreDatasetFactory,
SavableDataset,
legacy_handler,
)
from megatron.energon.flavors.webdataset.error_handler import ErrorHandler
from megatron.energon.flavors.webdataset.metadata import WebdatasetMeta
from megatron.energon.flavors.webdataset.prepare import WebdatasetPreparator
from megatron.energon.flavors.webdataset.sample_loader import (
RawSampleData,
WebdatasetSampleLoaderDataset,
)
from megatron.energon.flavors.webdataset.sharder import Sharder
from megatron.energon.flavors.webdataset.structs import (
DatasetSubset,
FilteredSample,
ShardInfo,
reraise_exception,
)
from megatron.energon.source_info import SourceInfo
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.map_dataset import MapDataset
T_sample = TypeVar("T_sample", covariant=True)
T = TypeVar("T", covariant=True)
logger = logging.getLogger(__name__)
class BaseWebdatasetFactory(
BaseCoreDatasetFactory[T_sample],
WebdatasetPreparator,
Sharder,
ErrorHandler,
Generic[T_sample],
ABC,
):
"""
Base class for all webdataset sample loader factories. Applies proper sharding across workers.
"""
path: EPath
paths: list[EPath]
shards: List[ShardInfo]
sample_excludes: set[str]
split_part_files: list[str]
training: bool
worker_config: WorkerConfig
shuffle_over_epochs: Optional[int]
parallel_shard_iters: Optional[int]
max_samples_per_sequence: Optional[int]
subset: Optional[DatasetSubset]
part_filter: Optional[Callable[[str], bool]]
handler: Callable[[Exception, Optional[str], Optional[list[SourceInfo]]], None]
def __init__(
self,
path: EPath,
*,
split_part: str,
training: bool,
worker_config: WorkerConfig,
shuffle_over_epochs: Optional[int] = 1,
parallel_shard_iters: Optional[int] = None,
max_samples_per_sequence: Optional[int] = None,
subset: Optional[DatasetSubset] = None,
split_config: Optional[str] = None,
part_filter: Optional[Callable[[str], bool]] = None,
handler: Callable[
[Exception, Optional[str], Optional[list[SourceInfo]]], None
] = reraise_exception,
):
"""
Base factory for the webdataset sample loader.
Args:
path: Path to the dataset.
split_part: Which part to load (e.g. 'train', 'val', 'test').
training: If true, apply shuffling and loop the dataset.
worker_config: Configuration for the workers.
shuffle_over_epochs: Only effective if training=True.
How many epochs to shuffle over if training.
If = 1, every sample is seen exactly once per epoch.
If > 1, samples (or rather shard slices) are shuffled within this number of epochs
(i.e. randomly selected without replacement).
If -1, the shards are effectively shuffled over infinite epochs (i.e. shard slices
are drawn with replacement).
parallel_shard_iters: Number of shards opened in parallel per worker to shuffle between.
max_samples_per_sequence: Maximum number of samples per sequence (=how many samples
will be sequentially iterated).
subset: If specified, the dataset will be subsetted.
split_config: Config file to use for shard split definitions.
part_filter: (internal) Function for filtering tar files by dict keys
handler: Exception handler. Args: (exception, key, source_info).
"""
assert self.__sample_type__ is not None, f"Class {type(self)} must define __sample_type__"
wds_meta = WebdatasetMeta.from_config(
path=path, split_part=split_part, split_config=split_config
)
self.path = path
self.paths = [path]
self.shards = wds_meta.shards
self.sample_excludes = wds_meta.sample_excludes
self.split_part_files = wds_meta.split_part_files
self.training = training
self.worker_config = worker_config
self.shuffle_over_epochs = shuffle_over_epochs
self.parallel_shard_iters = parallel_shard_iters
self.max_samples_per_sequence = max_samples_per_sequence
self.subset = subset
self.part_filter = part_filter
self.handler = legacy_handler(handler)
def __len__(self) -> int:
return sum(shard.count for shard in self.shards)
def build(self, worker_rotation_offset: int = 0) -> SavableDataset[T_sample]:
from megatron.energon.flavors.webdataset.itar_reader import ShardInfosITarReader
if self.parallel_shard_iters is None:
if self.training:
# 16 seems to be a good choice since we don't want too many file handles open
parallel_shard_iters = 16
else:
parallel_shard_iters = 1
else:
parallel_shard_iters = self.parallel_shard_iters
workers_sample_slice_offsets = self.shard_workers(
self.shards,
worker_config=self.worker_config,
max_samples_per_sequence=self.max_samples_per_sequence,
rotation_offset=worker_rotation_offset,
subset=self.subset,
)
_print_shard_slices(self.worker_config, self.shards, workers_sample_slice_offsets)
itar_reader = ShardInfosITarReader(
self.path,
self.shards,
part_filter=self.part_filter,
sample_filter=self.sample_filter,
itar_cache_size=parallel_shard_iters,
)
dataset = WebdatasetSampleLoaderDataset(
join_readers=[itar_reader],
workers_sample_slice_offsets=workers_sample_slice_offsets,
worker_config=self.worker_config,
shuffle_over_epochs=self.shuffle_over_epochs if self.training else None,
parallel_slice_iters=parallel_shard_iters,
)
return MapDataset(
dataset,
self._load_sample_raw,
error_handler=self.error_handler,
stateless_map_fn=True,
map_fn_config=self.config,
worker_config=self.worker_config,
)
def as_file_store(self) -> "FileStore":
from megatron.energon.cache.file_store import WebdatasetFileStore
return WebdatasetFileStore(self.path)
def sample_filter(self, key: str) -> bool:
return key not in self.sample_excludes
def _load_sample_raw(self, raw_sample: RawSampleData) -> T_sample:
# Just a wrapper for the inner tuple. Tuple should be of length 1.
assert len(raw_sample.data) == 1 and raw_sample.data[0] is not None
return self.load_sample(raw_sample.data[0])
@abstractmethod
def load_sample(self, raw_data: FilteredSample) -> T_sample:
"""Loads the sample from the dataset."""
...
def config(self) -> Dict[str, Any]:
return dict(
type=type(self).__qualname__,
training=self.training,
_path=str(self.path),
shards=[
dict(
name=shard.name,
count=shard.count,
_path=str(shard.path),
)
for shard in self.shards
],
sample_excludes=list(self.sample_excludes),
shuffle_over_epochs=self.shuffle_over_epochs,
parallel_shard_iters=self.parallel_shard_iters,
max_samples_per_sequence=self.max_samples_per_sequence,
subset=self.subset.config() if self.subset is not None else None,
)
def __str__(self):
return f"{type(self).__name__}(path={self.path})"
def _print_shard_slices(
worker_config: WorkerConfig, shards: List[ShardInfo], slice_offsets: Sequence[Sequence[int]]
):
shard_starts = np.cumsum([0] + [shard.count for shard in shards])
def shard_range_info(start: int, end: int) -> str:
start_shard_idx = np.searchsorted(shard_starts, start, side="right") - 1
end_shard_idx = np.searchsorted(shard_starts, end, side="left") - 1
if start_shard_idx == end_shard_idx:
shard = shards[start_shard_idx]
if start - shard_starts[start_shard_idx] == 0:
start_str = "(start)"
else:
start_str = ""
if end - shard_starts[start_shard_idx] == shard.count:
end_str = "(end)"
else:
end_str = ""
return f"{shard.name}[{start - shard_starts[start_shard_idx]}{start_str}, {end - shard_starts[start_shard_idx]}{end_str}]"
else:
start_shard = shards[start_shard_idx]
end_shard = shards[end_shard_idx]
if start - shard_starts[start_shard_idx] == 0:
start_str = "(start)"
else:
start_str = ""
if end - shard_starts[end_shard_idx] == end_shard.count:
end_str = "(end)"
else:
end_str = ""
return f"{start_shard.name}[{start - shard_starts[start_shard_idx]}{start_str},]-{end_shard.name}[,{end - shard_starts[end_shard_idx]}{end_str}]"
for worker_idx, sample_slice_offsets in enumerate(slice_offsets):
start_idx = sample_slice_offsets[0]
end_idx = sample_slice_offsets[-1]
if len(sample_slice_offsets) > 6:
offset_str = f"{', '.join(str(o) for o in sample_slice_offsets[:3])} ...<{len(sample_slice_offsets) - 6}> {', '.join(str(o) for o in sample_slice_offsets[-3:])}"
else:
offset_str = ", ".join(str(o) for o in sample_slice_offsets)
if len(sample_slice_offsets) > 6:
slices_str = (
", ".join(
shard_range_info(start, end)
for start, end in zip(sample_slice_offsets[:3], sample_slice_offsets[1:4])
)
+ f" ...<{len(sample_slice_offsets) - 6}> "
+ ", ".join(
shard_range_info(start, end)
for start, end in zip(sample_slice_offsets[-4:-1], sample_slice_offsets[-3:])
)
)
else:
slices_str = ", ".join(
shard_range_info(start, end)
for start, end in zip(sample_slice_offsets[:-1], sample_slice_offsets[1:])
)
print(
f"rank={worker_config.rank}, worker={worker_idx}: sample_range=[{start_idx}, {end_idx}] in {len(sample_slice_offsets) - 1} slices, "
f"sum(count)={end_idx - start_idx}: indexes=[{offset_str}] slices=[{slices_str}]"
)
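# Illustrative sketch (not part of the original module): the minimal contract for a concrete
# factory -- define __sample_type__ and implement load_sample(). The sample type _RawTxtSample
# and the "txt" tar part are hypothetical; the dunder fields mirror what
# DefaultGenericWebdatasetFactory passes to the sample constructor.
from megatron.energon.edataclass import edataclass as _edataclass
from megatron.energon.flavors.base_dataset import Sample as _Sample


@_edataclass
class _RawTxtSample(_Sample):
    #: Raw bytes of the sample's ".txt" tar member
    txt: bytes


class _RawTxtWebdataset(BaseWebdatasetFactory[_RawTxtSample]):
    __sample_type__ = _RawTxtSample

    def load_sample(self, raw_data: FilteredSample) -> _RawTxtSample:
        # raw_data maps tar member extensions (here "txt") to their raw bytes and also
        # carries the __key__/__restore_key__/__sources__ bookkeeping entries.
        return _RawTxtSample(
            __key__=raw_data["__key__"],
            __restore_key__=raw_data["__restore_key__"],
            __subflavors__={},
            __sources__=raw_data["__sources__"],
            txt=raw_data["txt"],
        )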
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import re
split_name_re = re.compile(r"^((?:.*/|)[^.]+)[.]([^/]*)$")
skip_meta_re = re.compile(r"__[^/]*__($|/)")
MAIN_FOLDER_NAME = ".nv-meta"
JOIN_INDEX_FILENAME = "join_index.bin"
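# Illustrative sketch (not part of the original module): what the two patterns match.
# split_name_re splits a tar member path into (sample key, part/extension name);
# skip_meta_re matches members whose name starts with a dunder component (e.g. "__meta__/").
def _example_config_regexes() -> None:  # hypothetical helper, for illustration only
    m = split_name_re.match("shard_000/sample_0001.jpg")
    assert m is not None and m.groups() == ("shard_000/sample_0001", "jpg")
    assert skip_meta_re.match("__meta__/info.json") is not None
    assert skip_meta_re.match("sample_0001.jpg") is None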
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Any, Dict, Generic, Optional, TypeVar
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.webdataset.default_generic_webdataset import (
DefaultGenericWebdatasetFactory,
)
from megatron.energon.flavors.webdataset.sample_decoder import DEFAULT_DECODER, SampleDecoder
from megatron.energon.flavors.webdataset.structs import FilteredSample
T_sample = TypeVar("T_sample", covariant=True)
class DefaultDecoderWebdatasetFactory(DefaultGenericWebdatasetFactory[T_sample], Generic[T_sample]):
"""
Extends the default webdataset loading with decoding of contained files, such as images, videos or nested
containers.
"""
# The webdataset decoder, if one is to be applied
_decoder: Optional[SampleDecoder]
def __init__(
self,
path: EPath,
*,
decoder: Optional[SampleDecoder] = DEFAULT_DECODER,
**kwargs,
):
"""
Factory for the webdataset sample loader including the decoder.
Args:
path: Path to the dataset (passed to parent)
decoder: If provided, use this decoder, otherwise just load raw bytes.
**kwargs: Args passed to parent constructor
"""
self._decoder = decoder
super().__init__(path, **kwargs)
def load_sample(self, sample: FilteredSample) -> T_sample:
if self._decoder is not None:
sample = self._decoder(sample)
return super().load_sample(sample)
def config(self) -> Dict[str, Any]:
return dict(
**super().config(),
**(self._decoder.config() if self._decoder is not None else {}),
)
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import dataclasses
from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Union
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.base_dataset import SavableDataset
from megatron.energon.flavors.webdataset.base_webdataset import BaseWebdatasetFactory
from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME
from megatron.energon.flavors.webdataset.field_access import field_access, split_field_access
from megatron.energon.flavors.webdataset.structs import FilteredSample
from megatron.energon.module_loader import ModuleLoader
T_sample = TypeVar("T_sample", covariant=True)
class DefaultGenericWebdatasetFactory(BaseWebdatasetFactory[T_sample], Generic[T_sample]):
"""
Default implementation of webdataset for generic samples and the generic config interface for use with dataset.yaml.
"""
_sample_loader: Callable[[Dict[str, Any]], Dict[str, Any]]
def __init__(
self,
path: EPath,
*,
subflavors: Optional[Dict[str, Any]] = None,
field_map: Optional[Dict[str, str]] = None,
sample_loader: Optional[Union[str, Callable[[dict], dict]]] = None,
part_filter: Optional[Union[str, List[str], Callable[[str], bool]]] = None,
**kwargs,
):
"""
Factory for the webdataset sample loader and basic configuration options.
Args:
subflavors: Subflavors dictionary to set for all loaded samples.
field_map: Mapping from the webdataset fields to the sample fields.
sample_loader: Function to load the sample from the webdataset fields. May be a string
in order to load a function from a module, or a callable directly.
part_filter: Filter for the parts to load. May be a string in order to load a function
from a module, or a callable directly.
**kwargs: Args passed to parent constructor.
"""
assert (field_map is None) != (sample_loader is None), (
"Either field_map or sample_loader must be provided."
)
if sample_loader is not None:
assert part_filter is not None, (
"part_filter must be provided if sample_loader is provided."
)
module_loader = ModuleLoader()
if isinstance(sample_loader, str):
sample_loader = module_loader.get_function(
sample_loader, "sample_loader", relative_path=path / MAIN_FOLDER_NAME
)
else:
assert callable(sample_loader)
sample_loader = sample_loader
if isinstance(part_filter, list):
parts = set(part_filter)
part_filter = lambda part: part in parts
elif isinstance(part_filter, str):
part_filter = module_loader.get_function(
part_filter, "part_filter", relative_path=path / MAIN_FOLDER_NAME
)
else:
assert callable(part_filter)
self._sample_loader = sample_loader
else:
assert field_map is not None
assert part_filter is None
# Split field map fields by json[field][field]
fields = {key: split_field_access(field) for key, field in field_map.items()}
assert set(field.name for field in dataclasses.fields(self.__sample_type__)).issuperset(
fields.keys()
) and set(
field.name
for field in dataclasses.fields(self.__sample_type__)
if field.default is dataclasses.MISSING
and field.default_factory is dataclasses.MISSING
).issubset(field_map.keys()), (
f"field_map does not map to type {self.__sample_type__.__name__} fields"
)
self._sample_loader = lambda sample: {
k: field_access(sample, v) for k, v in fields.items()
}
parts = set(access[0] for options in fields.values() for access in options)
part_filter = lambda part: part in parts
inner_sample_loader = self._sample_loader
self._sample_loader = lambda sample: {
"__key__": sample["__key__"],
**inner_sample_loader(sample),
"__restore_key__": sample["__restore_key__"],
"__subflavors__": self.subflavors,
"__sources__": sample["__sources__"],
}
super().__init__(path, **kwargs, part_filter=part_filter)
self.subflavors = subflavors or {}
def load_sample(self, sample: FilteredSample) -> T_sample:
return self.__sample_type__(**self._sample_loader(sample))
def config(self) -> Dict[str, Any]:
return dict(
**super().config(),
subflavors=self.subflavors,
sample_loader=SavableDataset._function_config(self._sample_loader),
)
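# Illustrative sketch (not part of the original module): the two configuration styles this
# factory accepts. Part and field names are hypothetical; in practice they usually come from
# the dataset's dataset.yaml.
def _example_generic_config():
    # 1) field_map: map tar parts (optionally with nested json access) to sample fields.
    #    "json[caption]" selects data["json"]["caption"]; comma-separated options such as
    #    "txt,json[text]" are tried left to right (see field_access).
    field_map = {
        "image": "jpg",
        "text": "txt,json[text]",
    }

    # 2) sample_loader + part_filter: build the sample field dict directly from raw parts.
    def sample_loader(raw: dict) -> dict:
        return {"image": raw["jpg"], "text": raw["txt"].decode("utf-8")}

    part_filter = ["jpg", "txt"]
    return field_map, sample_loader, part_filter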
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
class EmptyDatasetError(Exception):
"""Raised when a dataset is empty."""
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Callable, Generic, Optional, Tuple, TypeVar, Union
from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError
from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.webdataset.structs import FilteredSample, reraise_exception
from megatron.energon.source_info import SourceInfo
T_sample = TypeVar("T_sample", covariant=True)
class ErrorHandler(Generic[T_sample]):
handler: Callable[[Exception, Optional[str], Optional[list[SourceInfo]]], None] = (
reraise_exception
)
def sample_error_handler(
self, e: Exception, sample_key: Optional[str], sources: Optional[list[SourceInfo]] = None
):
if isinstance(e, SYSTEM_EXCEPTIONS):
raise FatalSampleError(f"Error in sample {sample_key!r}: {e}") from e
self.handler(e, sample_key, sources)
def error_handler(
self,
e: Exception,
sample: Union[
T_sample,
dict,
FilteredSample,
None,
Tuple[Union[T_sample, dict, FilteredSample, None], ...],
],
):
if isinstance(sample, dict):
key = sample.get("__key__")
sources = sample.get("__sources__")
elif isinstance(sample, list):
if isinstance(sample[0], dict):
key = ",".join("None" if s is None else s.get("__key__") for s in sample)
sources = [src for s in sample for src in s.get("__sources__", ())]
elif isinstance(sample[0], Sample):
key = ",".join("None" if s is None else s.__key__ for s in sample)
sources = [src for s in sample for src in s.__sources__]
else:
key = None
sources = None
elif isinstance(sample, Sample):
key = sample.__key__
sources = sample.__sources__
else:
key = None
sources = None
self.sample_error_handler(e, key, sources)
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import re
from typing import Any, List, Union
from megatron.energon.errors import SampleException, compact_str
_field_split_re = re.compile(r"^(?P<field_name>[^[\]]+)(?P<access>(?:\[([^[\]]+)])*)$")
def split_field_access(field_desc: str) -> List[List[str]]:
"""
Splits a field_map for access::
'abcdef,ghi' -> [['abcdef'], ['ghi']]
'abcdef[ghi]' -> [['abcdef', 'ghi']]
'abcdef[ghi][jkl]' -> [['abcdef', 'ghi', 'jkl']]
"""
options = field_desc.split(",")
option_fields = []
for option in options:
match = _field_split_re.match(option)
if match:
option_fields.append(
[match.group("field_name")]
+ [
access.lstrip("[").rstrip("]")
for access in match.group("access").split("][")
if access
]
)
else:
option_fields.append([field_desc])
return option_fields
class FieldAccessError(SampleException):
pass
def _field_access(value: Union[dict, list, str, int, bool, None], field: List[str]) -> Any:
"""
Accesses a (nested) field in the value.
Args:
value: The value to access
field: The access instruction (e.g. `['field1', 'field2']` for
`value['field1']['field2']`)
Returns:
The accessed value
"""
try:
if len(field) == 0:
return value
elif isinstance(value, dict):
return _field_access(value[field[0]], field[1:])
elif isinstance(value, list):
return _field_access(value[int(field[0])], field[1:])
else:
raise FieldAccessError(
f"Cannot access literal value {compact_str(value)} with {field!r}"
)
except FieldAccessError:
raise
except KeyError:
raise FieldAccessError(f"Cannot access {'.'.join(field)!r} in {compact_str(value)}")
def field_access(value: Union[dict, list, str, int, bool, None], field: List[List[str]]) -> Any:
"""
Accesses a (nested) field in the value.
Args:
value: The value to access
field: The access instruction (e.g. `[['field1', 'field2']]` for
`value['field1']['field2']`, or `[['field1'], ['field2']]` for `value.get('field1', value['field2'])`)
Returns:
The accessed value
"""
for f in field[:-1]:
try:
return _field_access(value, f)
except (FieldAccessError, KeyError, ValueError, IndexError):
pass
return _field_access(value, field[-1])
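# Illustrative sketch (not part of the original module): how split_field_access and
# field_access work together. The sample dict below is hypothetical.
def _example_field_access() -> None:
    sample = {"json": {"caption": "a cat", "tags": ["animal", "pet"]}, "txt": b"raw text"}
    assert field_access(sample, split_field_access("json[caption]")) == "a cat"
    assert field_access(sample, split_field_access("json[tags][1]")) == "pet"
    # Comma-separated options are tried left to right; here "txt" already matches.
    assert field_access(sample, split_field_access("txt,json[caption]")) == b"raw text"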
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import sqlite3
import struct
from pathlib import Path
from typing import BinaryIO, Generator, List, Optional, Tuple, Union
from numpy import int8
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.webdataset.itar import ITarRawSamplePartPointer, ITarSamplePointer
from megatron.energon.flavors.webdataset.thread_local_sqlite import ThreadLocalSqlite
from megatron.energon.local_copy import ensure_local_copy
class SqliteIndexWriter:
sqlite_path: EPath
db: Optional[sqlite3.Connection]
duplicates: List[Tuple[str, int]]
def __init__(self, sqlite_path: EPath):
"""
Initializes an SQLite database and sets up the samples table:
- samples(tar_file_id INTEGER,
sample_key TEXT,
sample_index INTEGER,
byte_offset INTEGER,
byte_size INTEGER)
and the sample_parts table:
- sample_parts(tar_file_id INTEGER,
sample_index INTEGER,
part_name TEXT,
content_byte_offset INTEGER,
content_byte_size INTEGER)
Also creates indexes:
- samples(sample_key)
- samples(tar_file_id, sample_index)
- sample_parts(tar_file_id, sample_index, content_byte_offset)
- sample_parts(tar_file_id, sample_index, part_name, content_byte_offset, content_byte_size)
"""
# Final path and temporary path
self.sqlite_path = sqlite_path
# Initialize SQLite connection
path = str(self.sqlite_path)
# Only supporting local file system, because sqlite does not support remote file systems.
# TODO: Implement remote file systems. Maybe create locally in tmp then upload?
assert path.startswith("/"), (
f"SQLite path must be absolute local file system path: {self.sqlite_path}"
)
Path(path).parent.mkdir(parents=True, exist_ok=True)
self.db = sqlite3.connect(path)
self.db.execute("PRAGMA busy_timeout = 5000;") # wait up to 5000ms when locked
self.db.execute("PRAGMA journal_mode = WAL;")
# Create the sample table
self.db.execute("DROP INDEX IF EXISTS idx_samples_sample_key")
self.db.execute("DROP INDEX IF EXISTS idx_samples_by_tar_and_idx")
self.db.execute("DROP TABLE IF EXISTS samples")
self.db.execute(
"""
CREATE TABLE samples (
tar_file_id INTEGER,
sample_key TEXT,
sample_index INTEGER,
byte_offset INTEGER,
byte_size INTEGER
)
"""
)
# Create the sample parts table
self.db.execute("DROP INDEX IF EXISTS idx_sample_parts_seq")
self.db.execute("DROP INDEX IF EXISTS idx_sample_parts_full")
self.db.execute("DROP TABLE IF EXISTS sample_parts")
self.db.execute(
"""
CREATE TABLE sample_parts (
tar_file_id INTEGER,
sample_index INTEGER,
part_name TEXT,
content_byte_offset INTEGER,
content_byte_size INTEGER
)
"""
)
self.duplicates = []
def append_sample(
self,
tar_file_id: int8,
sample_key: str,
sample_index: int,
byte_offset: Optional[int],
byte_size: Optional[int],
):
"""
Adds a new sample row to the samples table.
Args:
tar_file_id: The index of the tar file in the reader.
sample_key: The key of the sample.
sample_index: The index of the sample in the tar file.
byte_offset: The byte offset of the sample in the tar file.
byte_size: The size of the sample in the tar file.
"""
assert self.db is not None, "Database is closed"
# Insert a row in the samples table
self.db.execute(
"""
INSERT INTO samples (tar_file_id, sample_key, sample_index, byte_offset, byte_size)
VALUES (?, ?, ?, ?, ?)
""",
(tar_file_id, sample_key, sample_index, byte_offset, byte_size),
)
def append_part(
self,
tar_file_id: int8,
sample_index: int,
part_name: str,
content_byte_offset: int,
content_byte_size: int,
):
"""Adds a new part row to the samples table."""
assert self.db is not None, "Database is closed"
# Insert a row in the sample parts table
self.db.execute(
"""
INSERT INTO sample_parts (tar_file_id, sample_index, part_name, content_byte_offset, content_byte_size)
VALUES (?, ?, ?, ?, ?)
""",
(tar_file_id, sample_index, part_name, content_byte_offset, content_byte_size),
)
def close(self):
"""
Creates the indexes, checks for duplicate sample keys, then commits and closes
the DB connection.
"""
assert self.db is not None, "Database is closed"
# Create the index after adding all the samples for better speed
# Index on sample_key for fast lookups
self.db.execute("CREATE INDEX IF NOT EXISTS idx_samples_sample_key ON samples(sample_key)")
# Create index on the samples table. Help the planner if it chooses `samples` as the probe side of the join
self.db.execute(
"CREATE INDEX IF NOT EXISTS idx_samples_by_tar_and_idx ON samples(tar_file_id, sample_index)"
)
# Create index on the sample_parts table for fast sequential access
self.db.execute(
"CREATE INDEX IF NOT EXISTS idx_sample_parts_seq ON sample_parts(tar_file_id, sample_index, content_byte_offset)"
)
# Create a full index on the sample_parts table for equality lookups and getting offsets directly from key
self.db.execute(
"CREATE INDEX IF NOT EXISTS idx_sample_parts_full ON sample_parts(tar_file_id, sample_index, part_name, content_byte_offset, content_byte_size)"
)
# Check that all sample_key values are unique
# self.db.execute("CREATE TEMP TABLE temp AS SELECT sample_key, COUNT(*) AS c FROM samples GROUP BY sample_key HAVING c > 1")
duplicates = self.db.execute(
"SELECT sample_key, COUNT(*) AS c FROM samples GROUP BY sample_key HAVING c > 1 LIMIT 5"
).fetchall()
if len(duplicates) > 0:
self.duplicates = duplicates
if self.db is not None:
self.db.commit()
self.db.close()
self.db = None
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
class JoinIndexWriter:
"""Describes how one primary dataset is joined with multiple secondary datasets.
For fast random access, this is a binary format that is memory-mapped.
The first 16 bytes are a header with the number of columns (1 primary + N secondary).
Each row contains (shard_idx, byte_offset, byte_size) for each column.
"""
def __init__(self, join_index_path: EPath):
self.join_index_path = join_index_path
self.join_index_file = join_index_path.open("wb")
self.num_columns = None
def append(self, *columns: Tuple[int, int, int]):
"""Appends a new row to the join index file.
Each row contains (shard_idx, byte_offset, byte_size) for each column.
"""
if self.num_columns is None:
# Write the number of columns
self.join_index_file.write(b"JIDX0001") # Magic bytes with version
self.join_index_file.write(struct.pack("q", len(columns)))
self.num_columns = len(columns)
else:
assert len(columns) == self.num_columns, (
f"Inconsistent number of keys: Had {self.num_columns} before, got {len(columns)}"
)
# Write the columns
for key in columns:
assert isinstance(key, tuple) and len(key) == 3
self.join_index_file.write(struct.pack("qqq", *key))
def close(self):
self.join_index_file.close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
class SqliteIndexReader:
"""Reads samples from an SQLite database created by SqliteIndexWriter.
The database contains a table with the following schema:
- samples(tar_file_id INTEGER,
sample_key TEXT,
sample_index INTEGER,
byte_offset INTEGER,
byte_size INTEGER)
- sample_parts(tar_file_id INTEGER,
sample_index INTEGER,
part_name TEXT,
content_byte_offset INTEGER,
content_byte_size INTEGER)
"""
sqlite_path: EPath
db: ThreadLocalSqlite
def __init__(self, sqlite_path: EPath):
"""Initialize the SQLite database reader.
Args:
sqlite_path: Path to the SQLite database file
"""
self.sqlite_path = ensure_local_copy(sqlite_path)
# Initialize SQLite connection
path = str(self.sqlite_path)
# Only supporting local file system, because sqlite does not support remote file systems
assert path.startswith("/"), (
f"SQLite path must be absolute local file system path: {self.sqlite_path}"
)
path = f"file:{path}?mode=ro&immutable=1"
self.db = ThreadLocalSqlite(path, is_uri=True)
def db_has_sample_parts(self) -> bool:
"""Check if the database has a sample_parts table.
Returns:
True if sample_parts table exists, False otherwise.
"""
assert self.db is not None, "Database is closed"
db_exists = self.db.select_one(
"SELECT name FROM sqlite_master WHERE type='table' AND name='sample_parts'"
)
self.db.thread_close()
return db_exists is not None
def list_all_samples(self) -> Generator[Tuple[str, int, int], None, None]:
"""List all sample keys in the database.
Returns:
Tuple of (sample_key, byte_size)
"""
assert self.db is not None, "Database is closed"
for row in self.db.select_all("SELECT sample_key, byte_size, tar_file_id FROM samples"):
yield row[0], row[1], row[2]
def list_all_sample_parts(self) -> Generator[Tuple[str, int, int], None, None]:
"""List all sample parts (i.e. individual files) in the database.
Returns:
Tuple of (full_key, size, tar_file_id)
"""
assert self.db is not None, "Database is closed"
# Select all parts (sorted by tar_file_id, sample_index) but joined with the sample_key names
for row in self.db.select_all(
"SELECT "
"s.sample_key || '.' || sp.part_name AS full_key, "
"sp.content_byte_size AS size, "
"sp.tar_file_id AS tar_file_id "
"FROM sample_parts AS sp "
"JOIN samples AS s "
"ON sp.tar_file_id = s.tar_file_id AND sp.sample_index = s.sample_index "
"ORDER BY sp.tar_file_id, sp.sample_index, sp.content_byte_offset"
):
yield row[0], row[1], row[2]
def list_sample_parts(self, sample_key: str) -> Generator[Tuple[str, int, int], None, None]:
"""List all sample parts (i.e. individual files) in the database.
Args:
sample_key: The sample key to look up
Returns:
Tuple of (part_name, size, tar_file_id)
"""
assert self.db is not None, "Database is closed"
# Select all parts (sorted by tar_file_id, sample_index) but joined with the sample_key names
for row in self.db.select_all(
"SELECT "
"sp.part_name AS part_name, "
"sp.content_byte_size AS size, "
"sp.tar_file_id AS tar_file_id "
"FROM sample_parts AS sp "
"JOIN samples AS s "
"ON sp.tar_file_id = s.tar_file_id AND sp.sample_index = s.sample_index "
"WHERE s.sample_key = ? "
"ORDER BY sp.tar_file_id, sp.sample_index, sp.content_byte_offset",
(sample_key,),
):
yield row[0], row[1], row[2]
def get_total_size(self) -> int:
"""Get the total size of all samples in the database."""
assert self.db is not None, "Database is closed"
count = self.db.select_one("SELECT SUM(byte_size) FROM samples")
return count[0] if count else 0
def get_sample_count(self) -> int:
"""Get the total number of samples in the database."""
assert self.db is not None, "Database is closed"
count = self.db.select_one("SELECT COUNT(*) FROM samples")
return count[0] if count else 0
def get_sample_part(self, key: str, part_name: str) -> ITarRawSamplePartPointer:
"""Get a sample part by its key name and part name.
Args:
key: The sample key to look up
part_name: The part name to look up
Returns:
Pointer to the sample part raw data.
"""
assert self.db is not None, "Database is closed"
row = self.db.select_one(
"SELECT sp.tar_file_id, sp.content_byte_offset, sp.content_byte_size "
"FROM sample_parts AS sp "
"JOIN samples AS s "
"ON sp.tar_file_id = s.tar_file_id AND sp.sample_index = s.sample_index "
"WHERE s.sample_key = ? AND sp.part_name = ?",
(key, part_name),
)
if row is None:
raise KeyError(
f"Sample part not found: key={key}, part_name={part_name} in {self.sqlite_path}"
)
return ITarRawSamplePartPointer(
tar_file_id=row[0],
raw_byte_offset=row[1],
raw_byte_size=row[2],
)
def get_sample_pointer_by_key(self, key: str) -> ITarSamplePointer:
"""Get a sample by its key name.
Args:
key: The sample key to look up
Returns:
Tuple of (tar_file_id, sample_key, sample_index, byte_offset, byte_size)
"""
assert self.db is not None, "Database is closed"
sample = self.db.select_one(
"SELECT tar_file_id, sample_key, sample_index, byte_offset, byte_size "
"FROM samples WHERE sample_key = ?",
(key,),
)
if sample is None:
raise KeyError(f"Sample key not found: {key}")
return ITarSamplePointer(
tar_file_id=sample[0],
byte_offset=sample[3],
byte_size=sample[4],
)
def close(self):
"""Close the database connection."""
if self.db is not None:
self.db.thread_close()
del self.db
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
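# Illustrative sketch (not part of the original module): populating the index database and
# querying it back. The path and sample names are hypothetical; the database must live on a
# local file system.
def _example_sqlite_index() -> None:
    db_path = EPath("/tmp/example_index.sqlite")
    with SqliteIndexWriter(db_path) as writer:
        writer.append_sample(tar_file_id=0, sample_key="sample_0000", sample_index=0,
                             byte_offset=0, byte_size=2048)
        writer.append_part(tar_file_id=0, sample_index=0, part_name="jpg",
                           content_byte_offset=512, content_byte_size=1024)
    reader = SqliteIndexReader(db_path)
    assert reader.get_sample_count() == 1
    pointer = reader.get_sample_pointer_by_key("sample_0000")   # ITarSamplePointer
    part = reader.get_sample_part("sample_0000", "jpg")         # ITarRawSamplePartPointer
    assert (pointer.byte_size, part.raw_byte_size) == (2048, 1024)
    reader.close()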
class JoinIndexReader:
"""Reads a join index file in different ways.
If a column is specified, only that column is read, otherwise the full rows.
You can iterate over the rows, or read a specific row by index, or get the full tensor.
Each row contains (shard_idx, byte_offset, byte_size) for each column.
"""
join_index_path: EPath
join_index_file: BinaryIO
column: Optional[int]
num_columns: int
has_iterated: bool
index_row_position: int
def __init__(self, join_index_path: EPath, column: Optional[int] = None):
self.join_index_path = join_index_path
self.join_index_byte_size = join_index_path.size()
self.column = column
self.join_index_file = join_index_path.open("rb")
self.has_iterated = False
self.index_row_position = -1
# Read the header
bytes_magic = self.join_index_file.read(8)
assert isinstance(bytes_magic, bytes)
assert bytes_magic[:4] == b"JIDX", f"Invalid magic bytes: {bytes_magic}"
assert bytes_magic[4:8] == b"0001", f"Unsupported version: {bytes_magic[4:8]}"
# Read the number of columns
bytes_seckeys = self.join_index_file.read(8)
assert isinstance(bytes_seckeys, bytes)
self.num_columns = struct.unpack("q", bytes_seckeys)[0]
self.index_row_position = 0
def get_as_tensor(self):
"""Returns the join index as a tensor with shape (N, num_columns, 3)."""
assert not self.has_iterated, "Cannot get_as_tensor after iterating"
import torch
# Read the raw bytes for all N * 3 int64s.
data = self.join_index_file.read()
self.index_row_position = len(self)
assert len(data) % (8 * 3) == 0, (
f"Index file reading: Expected multiple of 3 * 8 bytes, got {len(data)} bytes"
)
return torch.frombuffer(data, dtype=torch.int64).view(-1, self.num_columns, 3)
def __len__(self):
return (self.join_index_byte_size - 16) // (self.num_columns * 8 * 3)
def __iter__(self):
return self
def _read_one_row(
self, column: Optional[int] = None
) -> Union[None, List[Tuple[int, int, int]]]:
row = []
for col_idx in range(self.num_columns):
if column is not None and col_idx != column:
# Skip this column
self.join_index_file.seek(8 * 3, 1)
continue
bytes_key = self.join_index_file.read(8 * 3)
if not bytes_key:
return None
assert isinstance(bytes_key, bytes)
key_tuple = struct.unpack("qqq", bytes_key)
row.append(key_tuple)
self.index_row_position += 1
return row
def __next__(self) -> Union[None, List[Tuple[int, int, int]]]:
self.has_iterated = True
return self._read_one_row(column=self.column)
def tell_row(self) -> int:
return self.index_row_position
def __getitem__(self, idx: int) -> List[Tuple[int, int, int]]:
"""Reads the idx-th row of the join index file."""
assert 0 <= idx < len(self), f"Index out of bounds: {idx} not in [0, {len(self)})"
# Seek to the correct position
if self.index_row_position != idx:
self.join_index_file.seek(16 + idx * self.num_columns * 8 * 3)
self.index_row_position = idx
# Read the secondary keys
row = self._read_one_row(column=self.column)
assert row is not None, f"Failed to read row {idx}"
return row
def close(self):
self.join_index_file.close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
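# Hedged sketch, not part of the original module: open a join index file, report its
# shape, and fetch one row by index. The path argument is a placeholder.
def _example_read_join_index(index_path: EPath) -> None:
    with JoinIndexReader(index_path) as jir:
        print(f"{len(jir)} rows with {jir.num_columns} columns each")
        # Random access: a list with one (shard_idx, byte_offset, byte_size) tuple per column.
        first_row = jir[0]
        print(first_row)
        # For bulk access, get_as_tensor() returns all rows as an
        # (N, num_columns, 3) int64 tensor (requires torch).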
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import contextlib
import struct
import tarfile
from types import TracebackType
from typing import BinaryIO, Dict, Generator, Optional, Tuple, Type, Union
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib import EPath
from megatron.energon.retry_stream import RetryReadStream
ITAR_SUFFIX = ".tar.idx"
@edataclass
class ITarSamplePointer:
"""
Points to a sample inside some tar file on disk.
The tar_file_id refers to the tar_filenames in the reader.
"""
# The index of the tar file, to be matched with the tar_filenames in the reader.
tar_file_id: int
# The byte offset of the sample in the tar file.
byte_offset: int
# The size of the sample in the tar file.
byte_size: int
@edataclass
class ITarRawSamplePartPointer:
"""
Points to a part of a sample inside some tar file on disk.
The tar_file_id refers to the tar_filenames in the reader.
The raw_byte_offset and raw_byte_size refer to the sample's part's raw data in the tar file.
"""
# The index of the tar file, to be matched with the tar_filenames in the reader.
tar_file_id: int
# The byte offset of the file's data in the tar file.
raw_byte_offset: int
# The size of the file's data in the tar file.
raw_byte_size: int
class TarIndexReader:
def __init__(self, tar_path: Union[EPath, str]):
tar_path = EPath(tar_path)
index_path = tar_path.with_suffix(ITAR_SUFFIX)
self._length = index_path.size() // 8
self.itar = index_path.open("rb")
def __getitem__(self, index: int) -> int:
if index >= self._length or index < 0:
raise IndexError(f"Index {index} out of range")
if self.itar.tell() != 8 * index:
self.itar.seek(8 * index)
return struct.unpack("Q", self.itar.read(8))[0]
def __iter__(self) -> Generator[int, None, None]:
self.itar.seek(0)
while True:
raw = self.itar.read(8)
if len(raw) == 0:
break
assert len(raw) == 8
yield struct.unpack("Q", raw)[0]
def __len__(self) -> int:
return self._length
def close(self):
self.itar.close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
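# Hedged sketch, not part of the original module: look up the byte offset of the k-th
# entry of an existing .tar.idx file next to a tar file. Path and index are placeholders.
def _example_read_tar_index(tar_path: EPath, k: int = 0) -> int:
    with TarIndexReader(tar_path) as itar:
        print(f"{len(itar)} offsets in the index")
        return itar[k]  # Byte offset of entry k, as a plain int.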
class TarIndexWriter:
def __init__(self, tar_path: EPath):
self.final_name = tar_path.with_suffix(ITAR_SUFFIX)
self.tmp_name = tar_path.with_suffix(ITAR_SUFFIX + ".tmp")
self.itar = self.tmp_name.open("wb")
def append(self, offset: int):
self.itar.write(struct.pack("Q", offset))
def close(self, finalize: bool = True):
self.itar.close()
if finalize:
self.tmp_name.move(self.final_name)
else:
self.tmp_name.unlink()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close(finalize=exc_val is None)
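# Hedged sketch, not part of the original module: build a .tar.idx next to a tar file by
# recording each member's header offset. The real indexer may group members into samples
# differently and may not store a terminal offset; this only illustrates the writer API
# (append + atomic finalize on close).
def _example_write_tar_index(tar_path: EPath) -> None:
    with TarIndexWriter(tar_path) as writer:
        with tar_path.open("rb") as fileobj:
            with tarfile.open(fileobj=fileobj, mode="r:") as tf:
                for member in tf:
                    writer.append(member.offset)
        # A terminal entry with the archive size; the real format may differ.
        writer.append(tar_path.size())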
class SubFileReader(BinaryIO):
"""A file-like object that reads a subfile (i.e. offset, size defined portion) of a larger
file."""
def __init__(self, stream: BinaryIO, offset: int, size: int):
self.offset = offset
self._pos = 0
self.size = size
self.stream = stream
self.stream.seek(self.offset)
def read(self, n: int = -1) -> bytes:
if n == -1:
n = self.size - self._pos
else:
n = min(n, self.size - self._pos)
if n == 0:
return b""
read = self.stream.read(n)
self._pos += len(read)
return read
def seek(self, offset: int, whence: int = 0) -> int:
if whence == 0:
self._pos = offset
elif whence == 1:
self._pos += offset
elif whence == 2:
self._pos = self.size + offset
else:
raise ValueError("Invalid whence value")
self._pos = max(0, min(self._pos, self.size))
self.stream.seek(self.offset + self._pos)
return self._pos
def tell(self) -> int:
return self._pos
def __enter__(self) -> BinaryIO:
return self
def __exit__(
self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType
) -> None:
self.close()
def close(self) -> None:
self.stream.close()
def isatty(self) -> bool:
return False
def seekable(self) -> bool:
return True
def writable(self) -> bool:
return False
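# Hedged sketch, not part of the original module: expose a byte range of a larger file as
# an independent seekable stream. The offset and size values are placeholders.
def _example_subfile(path: EPath, offset: int = 100, size: int = 64) -> bytes:
    with path.open("rb") as stream:
        sub = SubFileReader(stream, offset=offset, size=size)
        sub.seek(0)
        # Reads at most `size` bytes, starting at byte `offset` of the underlying file.
        return sub.read()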
def get_itar_byte_offset(
path: Union[str, EPath],
sample_offset: int = 0,
) -> int:
"""Gets the byte offset from sample offsets."""
if sample_offset == 0:
return 0
with TarIndexReader(path) as itar:
return itar[sample_offset]
@edataclass
class CacheEntry:
tar_index_reader: TarIndexReader
lookahead_offset: Optional[int] = None
lookahead_byteoffset: Optional[int] = None
class CachedItarOffsetReader:
"""
This class is a high-level wrapper around TarIndexReader that caches some
of the recent lookups for faster access. It is designed for the case when
you need to read multiple offsets from the same tar file or from multiple
tar files.
Args:
cache_size: The number of entries to keep in the cache. By default, we keep 32.
"""
def __init__(self, cache_size: int = 32):
# Maps (tar_file, current_offset) -> CacheEntry
self.tar_index_reader_cache: Dict[Tuple[str, int], CacheEntry] = {}
self.cache_size = cache_size
def _find_or_create_entry(
self,
tar_file: Union[str, "EPath"],
sample_offset: int,
) -> Tuple[Tuple[str, int], CacheEntry]:
"""
1. If we already have a key == (tar_file, sample_offset), return it.
2. Otherwise, create a new entry (and evict if necessary).
"""
tar_file = str(tar_file)
key = (tar_file, sample_offset)
# Direct hit in the cache?
if key in self.tar_index_reader_cache:
return key, self.tar_index_reader_cache[key]
# We didn't find an existing entry. Create a new one.
# Evict if needed.
if len(self.tar_index_reader_cache) >= self.cache_size:
self._evict_one_entry()
new_reader = TarIndexReader(tar_file)
cache_entry = CacheEntry(tar_index_reader=new_reader)
self.tar_index_reader_cache[key] = cache_entry
return key, cache_entry
def _evict_one_entry(self):
"""
        Evict the oldest item in the cache: we pop the first key returned by
        iter(...), which in Python 3.7+ is insertion order, so this is FIFO rather
        than a strict LRU. For a true LRU, use an OrderedDict or similar.
"""
oldest_key = next(iter(self.tar_index_reader_cache))
oldest_entry = self.tar_index_reader_cache.pop(oldest_key)
oldest_entry.tar_index_reader.close()
def _get_itar_byte_offset_with_entry(
self,
cache_entry: CacheEntry,
sample_offset: int,
) -> Tuple[int, int]:
"""
Return (start_byte_offset, length_to_next),
possibly using per-entry lookahead for speed.
"""
tar_index_reader = cache_entry.tar_index_reader
# If offset=0, define the result as byte offset=0 for convenience
if sample_offset == 0:
result_byte_offset = 0
elif sample_offset == cache_entry.lookahead_offset:
# Reuse the previously cached byte offset from the lookahead
assert cache_entry.lookahead_byteoffset is not None, (
"Lookahead offset matched but no lookahead byte offset found."
)
result_byte_offset = cache_entry.lookahead_byteoffset
else:
# Normal random access
result_byte_offset = tar_index_reader[sample_offset]
# Prepare the lookahead for (sample_offset+1)
next_offset = sample_offset + 1
try:
cache_entry.lookahead_byteoffset = tar_index_reader[next_offset]
cache_entry.lookahead_offset = next_offset
except IndexError:
cache_entry.lookahead_offset = None
cache_entry.lookahead_byteoffset = None
# length = difference to the next offset, or 0 if none
if cache_entry.lookahead_byteoffset is not None:
length = cache_entry.lookahead_byteoffset - result_byte_offset
else:
length = 0
return result_byte_offset, length
def get_itar_byte_offset(
self,
tar_file: Union[str, "EPath"],
sample_offset: int = 0,
) -> Tuple[int, int]:
"""
High-level API to get the byte offset and length for the given file & sample_offset.
"""
# Find or create the suitable CacheEntry
key, entry = self._find_or_create_entry(tar_file, sample_offset)
# Use (and update) the per-entry lookahead logic
result_byte_offset, length = self._get_itar_byte_offset_with_entry(entry, sample_offset)
# Update cache entry with the new offset
self.tar_index_reader_cache.pop(key)
if entry.lookahead_offset is not None:
new_key = (str(tar_file), entry.lookahead_offset)
if new_key not in self.tar_index_reader_cache:
self.tar_index_reader_cache[new_key] = entry
else:
                # This entry already exists in the cache, so close this reader and keep the existing one.
                # TODO: We may actually want to keep multiple readers open, because multiple
                # sequences may be reading from the same tar file.
entry.tar_index_reader.close()
else:
# No lookahead, so we can close the reader
entry.tar_index_reader.close()
return result_byte_offset, length
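# Hedged sketch, not part of the original module: resolve byte ranges for a few
# consecutive samples of one tar file (assuming a .tar.idx exists next to it).
# Consecutive sample offsets hit the per-entry lookahead, so each step needs at most
# one extra index read.
def _example_cached_offsets(tar_path: EPath, num_samples: int = 3) -> None:
    reader = CachedItarOffsetReader(cache_size=8)
    for sample_offset in range(num_samples):
        byte_offset, byte_size = reader.get_itar_byte_offset(tar_path, sample_offset)
        print(f"sample {sample_offset}: byte offset {byte_offset}, {byte_size} bytes")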
class ITarFile(tarfile.TarFile):
"""This class is a subclass of tarfile.TarFile that allows for reading a tarfile,
with random access while keeping the file open.
Usage:
with open(filename, "rb") as fileobj:
with ITarFile.open(fileobj=fileobj, mode="r:") as f:
f.offset = 101888
tarinfo = f.next()
print(tarinfo.name)
                member_file = f.extractfile(tarinfo)  # A file-like object, not raw bytes
# Read more offsets here ...
"""
def __init__(self, *args, **kwargs):
self.in_init = True
try:
super().__init__(*args, **kwargs)
finally:
self.in_init = False
def next(self):
if self.in_init:
# Don't automatically read the first member
return None
if self.offset != self.fileobj.tell():
            # Seek ourselves, so tarfile's next() does not seek back and read the
            # single byte before self.offset as its end-of-data check.
self.fileobj.seek(self.offset)
return super().next()
@contextlib.contextmanager
def open_itar(path: Union[str, EPath], byte_offset: int = 0, byte_size: Optional[int] = None):
"""
Open an indexed tarfile with offset and size.
Args:
path: Path to the tarfile to open
byte_offset: Byte offset within the file
byte_size: Size of the file to read
Returns:
The opened tarfile
"""
path = EPath(path)
# TODO: if tar file startswith(b"\x1f\x8b\x08") -> Seekable gzip file
with path.open("rb") as f:
if f.read(3) == b"\x1f\x8b\x08":
# Open as seekable tgz
raise ValueError("Seekable tgz not supported yet")
if byte_offset != 0 or byte_size is not None:
if byte_size is None:
byte_size = path.size() - byte_offset
with RetryReadStream(path) as stream:
with SubFileReader(
stream,
offset=byte_offset,
size=byte_size,
) as fileobj:
with ITarFile.open(fileobj=fileobj, mode="r:") as f:
yield f
else:
with RetryReadStream(path) as fileobj:
with ITarFile.open(fileobj=fileobj, mode="r:") as f:
yield f
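# Hedged sketch, not part of the original module: read the members of one sample given
# its byte range, e.g. as returned by CachedItarOffsetReader. The range values are
# placeholders.
def _example_read_sample(tar_path: EPath, byte_offset: int, byte_size: int) -> None:
    with open_itar(tar_path, byte_offset=byte_offset, byte_size=byte_size) as tf:
        for member in tf:
            fileobj = tf.extractfile(member)
            size = len(fileobj.read()) if fileobj is not None else 0
            print(f"{member.name}: {size} bytes")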