Commit f356f546 authored by maming's avatar maming
Browse files

Initial commit

parents
Pipeline #3339 canceled with stages
[build-system]
requires = ["hatchling", "hatch-vcs"]
build-backend = "hatchling.build"
[project]
name = "megatron-energon"
dynamic = ["version"]
authors = [
{ name="Lukas Vögtle", email="lvoegtle@nvidia.com" },
{ name="Philipp Fischer", email="pfischer@nvidia.com" },
]
description = "Megatron's multi-modal data loader"
readme = "README.md"
license = "BSD-3-Clause"
requires-python = ">=3.10"
classifiers = [
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Operating System :: OS Independent",
]
dependencies = [
"braceexpand",
"click",
"dataslots; python_version<'3.10'",
"mfusepy",
"multi-storage-client>=0.18.0,<0.26.0",
"numpy",
"pillow>=10.0.1", # WEBP vulnerability fixed starting from 10.0.1
"pyyaml",
"rapidyaml>=0.10.0",
"s3fs",
"torch",
"tqdm",
"webdataset",
]
[project.optional-dependencies]
dev = [
"ruff",
"sphinxcontrib-napoleon",
"sphinx",
"myst-parser",
"soundfile",
"sphinx-rtd-theme",
"sphinx-click",
]
transforms = [
"torchvision", # Needed for megatron.energon.transforms
]
# Storage services for MSC
s3 = [
"multi-storage-client[boto3]",
]
aistore = [
"multi-storage-client[aistore]",
]
azure-storage-blob = [
"multi-storage-client[azure-storage-blob]",
]
google-cloud-storage = [
"multi-storage-client[google-cloud-storage]",
]
oci = [
"multi-storage-client[oci]",
]
# Dependencies for video decoding
av_decode = [
# needed for efficient audio and video file decoding
"bitstring>=4.2.3",
"sortedcontainers>=2.4.0",
"filetype>=1.2.0",
"ebmlite>=3.3.1",
"av>=14.4.0",
]
# If using guess_content=True for decoding
guess_content = [
"filetype>=1.0.0",
]
[project.urls]
Homepage = "https://github.com/NVIDIA/Megatron-Energon"
[tool.hatch.build.targets.wheel]
packages = ["src/megatron"]
[tool.hatch.build.targets.sdist]
packages = ["src/megatron"]
[project.scripts]
energon = "megatron.energon.cli.main:main"
[tool.hatch.version]
source = "vcs"
[tool.ruff.lint]
extend-select = ["I"]
ignore = ["E741", "E731"]
[tool.ruff]
line-length = 100
target-version = "py310"
include = ["**/*.py", "**/*.pyi"]
exclude = [
".idea",
"docs",
]
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple
import click
@dataclass
class HeaderUpdater:
    """Checks and rewrites the license header of files with a given extension.

    Either ``line_comment`` (a per-line comment prefix) or both ``comment_start``
    and ``comment_end`` (block comment markers) must be provided.
    """

    #: File extension (including the dot) this updater is responsible for.
    file_ext: str
    #: Per-line comment prefix (e.g. "# "), if the file type uses line comments.
    line_comment: Optional[str] = None
    #: Block comment opener (e.g. "<!--- "); used when line_comment is None.
    comment_start: Optional[str] = None
    #: Block comment closer (e.g. " -->"); used when line_comment is None.
    comment_end: Optional[str] = None

    #: Marker identifying an existing (possibly outdated) header to be replaced.
    UPDATE_IDENTIFIER = "Copyright"
    HEADER_LINES: Tuple[str, ...] = (
        "Copyright (c) 2025, NVIDIA CORPORATION.",
        "SPDX-License-Identifier: BSD-3-Clause",
    )

    #: The fully commented header lines; computed in __post_init__.
    _expected_lines: Tuple[str, ...] = ()

    def __post_init__(self):
        if self.line_comment is not None:
            # Line comments: prefix every header line.
            self._expected_lines = tuple(self.line_comment + line for line in self.HEADER_LINES)
        else:
            # Block comments: wrap the whole header in one start/end marker pair.
            assert self.comment_start is not None and self.comment_end is not None
            if len(self.HEADER_LINES) >= 2:
                self._expected_lines = (
                    self.comment_start + self.HEADER_LINES[0],
                    *self.HEADER_LINES[1:-1],
                    self.HEADER_LINES[-1] + self.comment_end,
                )
            else:
                assert len(self.HEADER_LINES) == 1
                self._expected_lines = (
                    self.comment_start + self.HEADER_LINES[0] + self.comment_end,
                )

    def has_header(self, file: Path) -> bool:
        """Return True if the file starts with exactly the expected header lines."""
        # Header lines are ASCII; read as UTF-8 so the check does not depend on the
        # platform's locale encoding.
        with file.open(encoding="utf-8") as rf:
            num_lines = 0
            for line, expected in zip(rf, self._expected_lines):
                num_lines += 1
                if line.rstrip("\n") != expected:
                    return False
            # Also fail if the file is shorter than the header.
            return num_lines == len(self._expected_lines)

    def fix_header(self, file: Path):
        """Prepend the header to the file, or replace an existing (outdated) header."""
        contents = file.read_text(encoding="utf-8")
        first_comment = self.line_comment if self.line_comment is not None else self.comment_start
        if contents.startswith(first_comment) and contents[len(first_comment) :].startswith(
            self.UPDATE_IDENTIFIER
        ):
            # Already has a header, but it may be outdated: drop the old header
            # lines and re-add the current ones.
            *_, remainder = contents.split("\n", len(self._expected_lines))
            new_contents = "\n".join(self._expected_lines) + "\n" + remainder
        else:
            # No header, add it
            new_contents = "\n".join(self._expected_lines) + "\n" + contents
        file.write_text(new_contents, encoding="utf-8")
# One updater per handled file type, selecting the comment syntax to use.
headers = (
    HeaderUpdater(
        file_ext=".py",
        line_comment="# ",
    ),
    HeaderUpdater(
        file_ext=".sh",
        line_comment="# ",
    ),
    # Do not add to yamls
    # HeaderUpdater(
    #     file_ext=".yml",
    #     line_comment="# ",
    # ),
    # HeaderUpdater(
    #     file_ext=".yaml",
    #     line_comment="# ",
    # ),
    HeaderUpdater(
        file_ext=".md",
        comment_start="<!--- ",
        comment_end=" -->",
    ),
)
@click.command()
@click.argument(
    "path",
    type=click.Path(exists=True, file_okay=False, dir_okay=True, readable=True, path_type=Path),
)
@click.option(
    "--fix", is_flag=True, help="Automatically add missing/adapt existing license headers."
)
def main(path: Path, fix: bool) -> None:
    """Check license headers of all handled file types below PATH; optionally fix them."""
    # Get all files to check
    success = True
    for header_update in headers:
        for filepath in path.rglob(f"*{header_update.file_ext}"):
            if not header_update.has_header(filepath):
                success = False
                print(f"Missing header in {filepath}")
                if fix:
                    header_update.fix_header(filepath)
                    print(f"Added header to {filepath}")
    if success:
        print("All files have the correct license header.")
    elif fix:
        print("Some license headers were fixed.")
    else:
        print(
            "Some license headers are missing or incorrect. Run `python scripts/license_headers.py . --fix` to fix."
        )
        # Raise SystemExit directly instead of calling the site-provided `exit()`
        # helper, which is not guaranteed to exist (e.g. under `python -S`).
        raise SystemExit(1)


if __name__ == "__main__":
    main()
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
try:
    # Catch only the "package not installed" case; any other failure should
    # surface instead of being silently mapped to "unknown".
    from importlib.metadata import PackageNotFoundError, version

    __version__ = version("megatron-energon")
except PackageNotFoundError:
    # Fallback for development mode when package is not installed
    __version__ = "unknown"
from megatron.energon.cache import (
CachePool,
DecodeFileStore,
DirectLazy,
FileCacheLazy,
FileStore,
FileStoreCachePool,
FileStoreDecoder,
Lazy,
MockLazy,
NoCachePool,
SystemFileStore,
)
from megatron.energon.edataclass import edataclass
from megatron.energon.flavors import (
BaseCoreDatasetFactory,
BaseWebdatasetFactory,
CaptioningSample,
CaptioningWebdataset,
CrudeJsonlDatasetFactory,
CrudeSample,
CrudeWebdataset,
DefaultCrudeJsonlDatasetFactory,
DefaultDecoderWebdatasetFactory,
DefaultGenericWebdatasetFactory,
ImageClassificationSample,
ImageClassificationWebdataset,
ImageSample,
ImageWebdataset,
InterleavedSample,
InterleavedWebdataset,
JoinedWebdatasetFactory,
MultiChoiceVQASample,
MultiChoiceVQAWebdataset,
OCRSample,
OCRWebdataset,
Sample,
SampleDecoder,
SavableDataset,
SimilarityInterleavedSample,
SimilarityInterleavedWebdataset,
StandardWebdatasetFactory,
TextSample,
TextWebdataset,
VidQASample,
VidQAWebdataset,
VQAOCRWebdataset,
VQASample,
VQAWebdataset,
)
from megatron.energon.loader import get_loader, get_savable_loader
from megatron.energon.metadataset import (
DatasetLoader,
DatasetLoaderInterface,
Metadataset,
MetadatasetV2,
load_dataset,
prepare_metadataset,
)
from megatron.energon.savable_loader import SavableDataLoader
from megatron.energon.source_info import SourceInfo
from megatron.energon.task_encoder import (
AugmentTaskEncoder,
Batch,
Cooker,
DefaultTaskEncoder,
TaskEncoder,
basic_sample_keys,
batch_list,
batch_pad_stack,
batch_stack,
cooker,
generic_batch,
get_train_dataset,
get_val_dataset,
get_val_datasets,
stateless,
)
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers import (
BatchDataset,
BlendDataset,
ConcatDataset,
EpochizeDataset,
FilterDataset,
GcDataset,
GroupBatchDataset,
IterMapDataset,
LimitDataset,
LogSampleDataset,
MapDataset,
MixBatchDataset,
PackingDataset,
RepeatDataset,
ShuffleBufferDataset,
SkipSample,
concat_pad,
generic_concat,
homogeneous_concat_mix,
)
# Public API of megatron.energon; kept in case-insensitive alphabetical order.
__all__ = [
    "__version__",
    "AugmentTaskEncoder",
    "BaseCoreDatasetFactory",
    "BaseWebdatasetFactory",
    "basic_sample_keys",
    "batch_list",
    "batch_pad_stack",
    "batch_stack",
    "Batch",
    "BatchDataset",
    "BlendDataset",
    "CachePool",
    "CaptioningSample",
    "CaptioningWebdataset",
    "concat_pad",
    "ConcatDataset",
    "cooker",
    "Cooker",
    "CrudeJsonlDatasetFactory",
    "CrudeSample",
    "CrudeWebdataset",
    "DatasetLoader",
    "DatasetLoaderInterface",
    "DecodeFileStore",
    "DefaultCrudeJsonlDatasetFactory",
    "DefaultDecoderWebdatasetFactory",
    "DefaultGenericWebdatasetFactory",
    "DefaultTaskEncoder",
    "DirectLazy",
    "edataclass",
    "EpochizeDataset",
    "FileCacheLazy",
    "FileStore",
    "FileStoreCachePool",
    "FileStoreDecoder",
    "FilterDataset",
    "GcDataset",
    "generic_batch",
    "generic_concat",
    "get_loader",
    "get_savable_loader",
    "get_train_dataset",
    "get_val_dataset",
    "get_val_datasets",
    "GroupBatchDataset",
    "homogeneous_concat_mix",
    "ImageClassificationSample",
    "ImageClassificationWebdataset",
    "ImageSample",
    "ImageWebdataset",
    "InterleavedSample",
    "InterleavedWebdataset",
    "IterMapDataset",
    "JoinedWebdatasetFactory",
    "Lazy",
    "LimitDataset",
    "load_dataset",
    "LogSampleDataset",
    "MapDataset",
    "Metadataset",
    "MetadatasetV2",
    "MixBatchDataset",
    "MockLazy",
    "MultiChoiceVQASample",
    "MultiChoiceVQAWebdataset",
    "NoCachePool",
    "OCRSample",
    "OCRWebdataset",
    "PackingDataset",
    "prepare_metadataset",
    "RepeatDataset",
    "Sample",
    "SampleDecoder",
    "SavableDataLoader",
    "SavableDataset",
    "ShuffleBufferDataset",
    "SimilarityInterleavedSample",
    "SimilarityInterleavedWebdataset",
    "SkipSample",
    "SourceInfo",
    "StandardWebdatasetFactory",
    "stateless",
    "SystemFileStore",
    "TaskEncoder",
    "TextSample",
    "TextWebdataset",
    "VidQASample",
    "VidQAWebdataset",
    "VQAOCRWebdataset",
    "VQASample",
    "VQAWebdataset",
    "WorkerConfig",
]
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from megatron.energon.av.av_data import AVData
from megatron.energon.av.av_decoder import (
AVDecoder,
AVMetadata,
AVWebdatasetDecoder,
)
from megatron.energon.av.utils import get_clips_uniform, get_single_frames_uniform
# Public API of megatron.energon.av (order preserved from the imports above).
__all__ = [
    "AVDecoder", "AVMetadata", "AVWebdatasetDecoder",
    "AVData",
    "get_clips_uniform", "get_single_frames_uniform",
]
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Optional
import torch
from megatron.energon.edataclass import edataclass
@edataclass
class AVData:
    """Decoded audio and/or video clips together with their time ranges.

    As produced by ``AVDecoder`` (see av_decoder.py): the clip lists are empty
    for a modality that was not requested.
    """

    #: A list of video clips in the shape (frames, channels, h, w)
    video_clips: Optional[list[torch.Tensor]]
    #: The timestamps for the video clips. List of tuples (start, end) in seconds
    video_timestamps: Optional[list[tuple[float, float]]]
    #: A list of audio clips in the shape (channels, samples)
    audio_clips: Optional[list[torch.Tensor]]
    #: The timestamps for the audio clips. List of tuples (start, end) in seconds
    audio_timestamps: Optional[list[tuple[float, float]]]
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import io
import warnings
from collections.abc import Iterator
from fractions import Fraction
from typing import BinaryIO, Literal, Optional, Sequence, Union, overload
import numpy as np
import torch
from megatron.energon.av.av_init import av_open
from megatron.energon.edataclass import edataclass
# Try importing optional dependencies
try:
import av
import av.container
import filetype
from megatron.energon.av.fastseek import Fastseek
AV_DECODE_AVAILABLE = True
except ImportError as e:
AV_DECODE_AVAILABLE = False
MISSING_DEPENDENCY = str(e)
from megatron.energon.av.av_data import AVData
class AVDecoder:
    """A class that provides a flexible interface for decoding audio and video data.

    This class allows users to control decoding parameters at runtime rather than having them fixed
    during initialization. It's particularly useful for cases where different samples may need different
    decoding parameters.
    """

    #: Fast seeking helper built from the container's metadata/index.
    seeker: "Fastseek"
    #: The underlying seekable binary stream containing the media data.
    stream: BinaryIO
    #: If True, do not warn about inexact unit conversions while seeking.
    suppress_warnings: bool

    def __init__(self, stream: BinaryIO, suppress_warnings: bool = False) -> None:
        if not AV_DECODE_AVAILABLE:
            raise ImportError(
                f"AV decoding is not available. Please install the required dependencies with:\n"
                f"pip install megatron-energon[av_decode]\n"
                f"Missing dependency: {MISSING_DEPENDENCY}. Install megatron-energon[av_decode] to use AVDecoder."
            )
        self.stream = stream
        self.suppress_warnings = suppress_warnings
        assert "t" not in getattr(stream, "mode", "rb") and not isinstance(stream, io.TextIOBase), (
            "Stream must not be opened in text mode"
        )
        try:
            self.seeker = Fastseek(self.stream)
        except ValueError:
            # Metadata-based indexing failed; retry with a full probe of the stream.
            self.stream.seek(0)
            self.seeker = Fastseek(self.stream, probe=True)
        self.stream.seek(0)

    def get_video(self) -> AVData:
        """Get the entire video data from the stream (without audio)."""
        video_clips, video_timestamps = self.get_video_clips(video_clip_ranges=[(0, float("inf"))])
        return AVData(
            video_clips=video_clips,
            video_timestamps=video_timestamps,
            audio_clips=[],
            audio_timestamps=[],
        )

    def get_video_clips(
        self,
        video_clip_ranges: Sequence[tuple[float, float]],
        video_unit: Literal["frames", "seconds"] = "seconds",
        video_out_frame_size: Optional[tuple[int, int]] = None,
    ) -> tuple[list[torch.Tensor], list[tuple[float, float]]]:
        """Get video clips from the video stream.

        Args:
            video_clip_ranges: List of video clip start and end positions in the given unit (see video_unit)
            video_unit: Unit of the video clip positions ("frames" for frame number, "seconds" for timestamp)
            video_out_frame_size: Output size for video frames (width, height), or None to use the original frame size

        Returns:
            A tuple containing:
                - video_clips: List of video clips
                - video_clips_timestamps: List of timestamps for each video clip start and end in seconds
        """
        assert video_unit in ("frames", "seconds")
        self.stream.seek(0)  # Reset the video stream so that pyav can read the entire container
        with av_open(self.stream) as input_container:
            assert len(input_container.streams.video) > 0, (
                "No video stream found, but video_clips are requested"
            )
            video_stream = input_container.streams.video[0]

            # Pre-calculate timing info for video
            average_rate: Fraction = video_stream.average_rate  # Frames per second
            assert average_rate, "Video stream has no FPS."
            time_base: Fraction = video_stream.time_base  # Seconds per PTS unit

            if video_clip_ranges is not None:
                # Convert video_clip_ranges to seeker unit
                if video_unit == "frames" and self.seeker.unit == "pts":
                    # Convert from frames to pts units
                    video_clip_ranges = [
                        (
                            clip[0] / average_rate / time_base,
                            clip[1] / average_rate / time_base,
                        )
                        for clip in video_clip_ranges
                    ]
                    if not self.suppress_warnings:
                        warnings.warn(
                            "Video container unit is frames, but seeking in time units. The resulting frames may be slightly off.",
                            RuntimeWarning,
                        )
                elif video_unit == "seconds" and self.seeker.unit == "frames":
                    # Convert from seconds to frames
                    video_clip_ranges = [
                        (
                            clip[0] * average_rate,
                            clip[1] * average_rate,
                        )
                        for clip in video_clip_ranges
                    ]
                    if not self.suppress_warnings:
                        warnings.warn(
                            "Video container unit is time units, but seeking using frame number. The resulting frames may be slightly off.",
                            RuntimeWarning,
                        )
                elif video_unit == "seconds" and self.seeker.unit == "pts":
                    # Convert from seconds to pts units
                    video_clip_ranges = [
                        (clip[0] / time_base, clip[1] / time_base) for clip in video_clip_ranges
                    ]

            frame_iterator: Iterator[av.VideoFrame] = input_container.decode(video=0)
            previous_frame_index: int = 0

            video_clips_frames: list[list[torch.Tensor]] = []
            video_clips_timestamps: list[tuple[float, float]] = []
            for video_clip_range in video_clip_ranges:
                start_frame_index, end_frame_index = video_clip_range
                # Convert to int if possible, set end to None if infinite
                start_frame_index = int(start_frame_index)
                end_frame_index = int(end_frame_index) if end_frame_index != float("inf") else None

                clip_frames: list[torch.Tensor] = []
                clip_timestamp_start = None
                clip_timestamp_end = None

                # Find start frame; only seek if the target is not reachable by
                # simply continuing to decode forward.
                if (
                    iframe_info := self.seeker.should_seek(previous_frame_index, start_frame_index)
                ) is not None:
                    input_container.seek(iframe_info.pts, stream=input_container.streams.video[0])
                    previous_frame_index = iframe_info.index

                for frame in frame_iterator:
                    take_frame = False
                    last_frame = False
                    # Container uses frame counts, we can find the exact target frame by counting
                    # from the iframe which is at a known offset
                    if self.seeker.unit == "frames":
                        if previous_frame_index >= start_frame_index:
                            take_frame = True
                        if end_frame_index is not None and previous_frame_index >= end_frame_index:
                            last_frame = True
                    # Container uses time, the target frame might not correspond exactly to any
                    # metadata but the desired timestamp should fall within a frames display period
                    if self.seeker.unit == "pts":
                        if start_frame_index <= (frame.pts + frame.duration):
                            take_frame = True
                        if end_frame_index is not None and end_frame_index <= (
                            frame.pts + frame.duration
                        ):
                            last_frame = True
                    if take_frame:
                        if video_out_frame_size is not None:
                            frame = frame.reformat(
                                width=video_out_frame_size[0],
                                height=video_out_frame_size[1],
                                format="rgb24",
                                interpolation="BILINEAR",
                            )
                        else:
                            frame = frame.reformat(format="rgb24")
                        clip_frames.append(torch.from_numpy(frame.to_ndarray()))
                        if clip_timestamp_start is None:
                            clip_timestamp_start = float(frame.pts * frame.time_base)
                        clip_timestamp_end = float((frame.pts + frame.duration) * frame.time_base)
                    previous_frame_index += 1
                    if last_frame:
                        break
                # Only emit the clip if at least one frame was decoded for it.
                if clip_timestamp_start is not None and clip_timestamp_end is not None:
                    video_clips_frames.append(clip_frames)
                    video_clips_timestamps.append((clip_timestamp_start, clip_timestamp_end))

        # Stack frames within each clip; (frames, h, w, channels) -> (frames, channels, h, w)
        out_video_clips = [
            torch.stack(clip_frames).permute((0, 3, 1, 2)) for clip_frames in video_clips_frames
        ]
        return out_video_clips, video_clips_timestamps

    def get_audio(self) -> AVData:
        """Get the entire audio data from the stream."""
        audio_clips, audio_timestamps = self.get_audio_clips(audio_clip_ranges=[(0, float("inf"))])
        return AVData(
            video_clips=[],
            video_timestamps=[],
            audio_clips=audio_clips,
            audio_timestamps=audio_timestamps,
        )

    def get_audio_clips(
        self,
        audio_clip_ranges: Sequence[tuple[float, float]],
        audio_unit: Literal["samples", "seconds"] = "seconds",
    ) -> tuple[list[torch.Tensor], list[tuple[float, float]]]:
        """Get audio clips from the audio stream.

        Args:
            audio_clip_ranges: List of audio clip start and end positions in the given unit (see audio_unit)
            audio_unit: Unit of the audio clip positions ("samples" for sample number, "seconds" for timestamp)

        Returns:
            A tuple containing:
                - audio_clips: List of audio clips
                - audio_clips_timestamps: List of timestamps for each audio clip start and end in seconds
        """
        assert audio_unit in ("samples", "seconds")
        self.stream.seek(0)  # Reset the video stream so that pyav can read the entire container
        with av_open(self.stream) as input_container:
            assert len(input_container.streams.audio) > 0, (
                "No audio stream found, but audio_clips are requested"
            )
            audio_stream = input_container.streams.audio[0]
            audio_sample_rate = audio_stream.sample_rate
            assert audio_sample_rate, "Audio streams without sample rate are not supported"
            if audio_unit == "samples":
                # Convert from samples to seconds
                audio_clip_ranges = [
                    (
                        float(clip[0] / audio_sample_rate),
                        float(clip[1] / audio_sample_rate),
                    )
                    for clip in audio_clip_ranges
                ]

            out_audio_clips: list[torch.Tensor] = []
            out_audio_clips_timestamps: list[tuple[float, float]] = []

            def audio_frame_array(frame: av.AudioFrame) -> np.ndarray:
                """Return the frame's samples as a (channels, samples) array."""
                if frame.format.is_planar:
                    arr_processed = frame.to_ndarray()  # Already (channels, samples)
                else:
                    # Calculate the number of channels and samples
                    channels = int(frame.layout.nb_channels)
                    samples = int(frame.samples)
                    # Reshape the interleaved data to (samples, channels), then transpose to (channels, samples)
                    arr_processed = np.reshape(frame.to_ndarray(), (samples, channels)).transpose(
                        1, 0
                    )
                return arr_processed

            for start_time, end_time in audio_clip_ranges:
                # Seek near start time, but rounded down to the nearest frame
                input_container.seek(int(start_time * av.time_base))
                if end_time != float("inf"):
                    desired_duration = end_time - start_time
                    desired_sample_count = int(desired_duration * audio_sample_rate + 0.5)
                else:
                    desired_sample_count = None
                clip_start_time = None
                clip_end_time = None
                decoded_samples = []
                decoded_sample_count = 0
                previous_frame = None
                for frame in input_container.decode(audio=0):
                    assert frame.pts is not None, "Audio frame has no PTS timestamp"
                    cur_frame_time = float(frame.pts * frame.time_base)
                    cur_frame_duration = float(frame.duration * frame.time_base)
                    if cur_frame_time < start_time:
                        # Skip frames before the start time
                        previous_frame = frame
                        continue
                    if clip_start_time is None:
                        # This is our first matching frame
                        if previous_frame is not None:
                            # We have a previous frame that we need to crop to the start time
                            prev_start_time = float(previous_frame.pts * previous_frame.time_base)
                            prev_frame_array = audio_frame_array(previous_frame)
                            prev_frame_array = prev_frame_array[
                                :, int((start_time - prev_start_time) * audio_sample_rate + 0.5) :
                            ]
                            decoded_samples.append(prev_frame_array)
                            decoded_sample_count += prev_frame_array.shape[1]
                            clip_start_time = start_time
                            # NOTE(review): this uses the *current* frame's duration for the
                            # previous frame's end time; presumably frame durations are equal
                            # here — TODO confirm.
                            clip_end_time = prev_start_time + cur_frame_duration
                        else:
                            clip_start_time = cur_frame_time
                    # Stop decoding if the end of the frame is past the end time
                    if cur_frame_time + cur_frame_duration >= end_time:
                        # Crop the last frame to the end time
                        last_frame_array = audio_frame_array(frame)
                        additional_samples = int(
                            (end_time - cur_frame_time) * audio_sample_rate + 0.5
                        )
                        projected_total_samples = decoded_sample_count + additional_samples
                        if (
                            desired_sample_count is not None
                            and 0 < abs(projected_total_samples - desired_sample_count) < 2
                        ):
                            # We are within 2 samples of the desired duration, let's adjust
                            # the last frame so that we get the desired duration
                            additional_samples = desired_sample_count - decoded_sample_count
                        last_frame_array = last_frame_array[:, :additional_samples]
                        decoded_samples.append(last_frame_array)
                        decoded_sample_count += last_frame_array.shape[1]
                        clip_end_time = end_time
                        break
                    frame_nd = audio_frame_array(frame)  # (channels, samples)
                    decoded_samples.append(frame_nd)
                    decoded_sample_count += frame_nd.shape[1]
                    clip_end_time = cur_frame_time + cur_frame_duration
                if decoded_samples:
                    # Combine all channels/samples along samples axis
                    clip_all = np.concatenate(decoded_samples, axis=-1)  # (channels, total_samples)
                    if clip_start_time is not None and clip_end_time is not None:
                        out_audio_clips.append(torch.from_numpy(clip_all))
                        out_audio_clips_timestamps.append((clip_start_time, clip_end_time))

        return out_audio_clips, out_audio_clips_timestamps

    def get_video_with_audio(self) -> AVData:
        """Get the entire video and audio data from the stream."""
        return self.get_clips(
            video_clip_ranges=[(0, float("inf"))],
            audio_clip_ranges=[(0, float("inf"))],
            video_unit="seconds",
            audio_unit="seconds",
        )

    def get_clips(
        self,
        video_clip_ranges: Optional[Sequence[tuple[float, float]]] = None,
        audio_clip_ranges: Optional[Sequence[tuple[float, float]]] = None,
        video_unit: Literal["frames", "seconds"] = "seconds",
        audio_unit: Literal["samples", "seconds"] = "seconds",
        video_out_frame_size: Optional[tuple[int, int]] = None,
    ) -> AVData:
        """Get clips from the video and/or audio streams.

        Given a list of (start, end) tuples, this method will decode the video and/or audio clips
        at the specified start and end times. The units of the start and end times are specified by
        the `video_unit` and `audio_unit` arguments.

        Args:
            video_clip_ranges: List of video clip start and end positions in the given unit (see video_unit)
            audio_clip_ranges: List of audio clip start and end positions in the given unit (see audio_unit)
            video_unit: Unit of the video clip positions ("frames" for frame number, "seconds" for timestamp)
            audio_unit: Unit of the audio clip positions ("samples" for sample number, "seconds" for timestamp)
            video_out_frame_size: Output size for video frames (width, height), or None to use the original frame size

        Returns:
            AVData containing the decoded video and audio clips
        """
        if video_clip_ranges is not None:
            ret_video_clips, ret_video_clips_timestamps = self.get_video_clips(
                video_clip_ranges, video_unit, video_out_frame_size
            )
        else:
            ret_video_clips = []
            ret_video_clips_timestamps = []
        if audio_clip_ranges is not None:
            ret_audio_clips, ret_audio_clips_timestamps = self.get_audio_clips(
                audio_clip_ranges, audio_unit
            )
        else:
            ret_audio_clips = []
            ret_audio_clips_timestamps = []
        return AVData(
            video_clips=ret_video_clips,
            video_timestamps=ret_video_clips_timestamps,
            audio_clips=ret_audio_clips,
            audio_timestamps=ret_audio_clips_timestamps,
        )

    def get_frames(
        self,
        video_decode_audio: bool = False,
    ) -> Optional[AVData]:
        """Decode the entire audio/video data and return an AVData object.

        Args:
            video_decode_audio: Whether to decode audio from video

        Returns:
            VideoData containing the decoded frames and metadata, or None if decoding failed
            The video tensor is in the shape (frames, channels, height, width)
            The audio tensor is in the shape (channels, samples)
        """
        # Dispatch on the sniffed container type (see _get_extension).
        extension = self._get_extension()
        if extension is not None:
            extension = extension.lower()
        if extension in ("mov", "mp4", "webm", "mkv", "avi", "m4v"):
            if video_decode_audio:
                return self.get_video_with_audio()
            else:
                return self.get_video()
        elif extension in ("flac", "mp3", "wav"):
            return self.get_audio()
        else:
            return None

    def _get_extension(self) -> Optional[str]:
        """Get the file extension from the raw data."""
        # Try to guess the file type using the first few bytes
        self.stream.seek(0)  # Reset stream position before guessing
        ftype = filetype.guess(self.stream)
        if ftype is None:
            return None
        return ftype.extension

    def get_video_fps(self) -> float:
        """Get the FPS of the video stream."""
        metadata = self.get_metadata(
            get_video=True,
            get_video_duration=False,
            get_video_frame_count=False,
            get_video_frame_size=False,
            get_audio=False,
        )
        assert metadata.video_fps is not None
        return metadata.video_fps

    def get_audio_samples_per_second(self) -> int:
        """Get the number of samples per second of the audio stream."""
        metadata = self.get_metadata(
            get_video=False,
            get_audio=True,
            get_audio_duration=False,
        )
        assert metadata.audio_sample_rate is not None
        return metadata.audio_sample_rate

    def has_audio_stream(self) -> bool:
        """Check if the stream has an audio stream."""
        self.stream.seek(0)
        with av_open(self.stream) as input_container:
            return len(input_container.streams.audio) > 0

    def has_video_stream(self) -> bool:
        """Check if the stream has a video stream."""
        self.stream.seek(0)
        with av_open(self.stream) as input_container:
            return len(input_container.streams.video) > 0

    def get_audio_duration(self) -> Optional[float]:
        """Get the duration of the audio stream.

        Returns:
            The duration of the audio stream in seconds
        """
        metadata = self.get_metadata(
            get_video=False,
            get_audio=True,
            get_audio_duration=True,
        )
        return metadata.audio_duration

    @overload
    def get_video_duration(self, get_frame_count: Literal[True]) -> tuple[Optional[float], int]: ...

    @overload
    def get_video_duration(
        self, get_frame_count: bool = False
    ) -> tuple[Optional[float], Optional[int]]: ...

    def get_video_duration(
        self, get_frame_count: bool = False
    ) -> tuple[Optional[float], Optional[int]]:
        """Get the duration of the video stream.

        Args:
            get_frame_count: Whether to return the number of frames in the video. This is a more costly operation.

        Returns:
            A tuple containing the duration in seconds, and the number of frames in the video
        """
        metadata = self.get_metadata(
            get_video=True,
            get_video_duration=True,
            get_video_frame_count=get_frame_count,
            get_video_frame_size=False,
            get_audio=False,
            get_audio_duration=False,
        )
        return metadata.video_duration, metadata.video_num_frames

    def get_metadata(
        self,
        get_video: bool = True,
        get_video_duration: bool = True,
        get_video_frame_count: bool = True,
        get_video_frame_size: bool = True,
        get_audio: bool = True,
        get_audio_duration: bool = True,
    ) -> "AVMetadata":
        """Get the metadata of the media object.

        Args:
            get_video: Compute video metadata.
            get_video_duration: Compute video duration if not found in header.
            get_video_frame_count: Compute video frame count if not found in header.
            get_video_frame_size: Compute video frame size if not found in header.
            get_audio: Compute audio metadata.
            get_audio_duration: Compute audio duration if not found in header.
        """
        self.stream.seek(0)
        with av_open(self.stream) as input_container:
            metadata = AVMetadata()
            if get_video and input_container.streams.video:
                video_stream = input_container.streams.video[0]
                metadata.video_duration = video_stream.duration
                if get_video_duration and metadata.video_duration is None:
                    # If duration isn't found in header the whole video is decoded to
                    # determine the duration.
                    metadata.video_num_frames = 0
                    last_packet = None
                    for packet in input_container.demux(video=0):
                        if packet.pts is not None:
                            metadata.video_num_frames += 1
                            last_packet = packet
                    if last_packet is not None and last_packet.duration is not None:
                        assert last_packet.pts is not None
                        metadata.video_duration = last_packet.pts + last_packet.duration
                if metadata.video_duration is not None:
                    # Convert from stream time-base units to seconds.
                    if video_stream.start_time is not None:
                        metadata.video_duration -= video_stream.start_time
                    if video_stream.time_base is not None:
                        metadata.video_duration *= float(video_stream.time_base)
                if get_video_frame_count and metadata.video_num_frames is None:
                    metadata.video_num_frames = sum(
                        1 for p in input_container.demux(video=0) if p.pts is not None
                    )
                if video_stream.average_rate is not None:
                    metadata.video_fps = float(video_stream.average_rate)
                elif metadata.video_num_frames is not None and metadata.video_duration is not None:
                    metadata.video_fps = metadata.video_num_frames / metadata.video_duration
                if get_video_frame_size:
                    # Decode the first frame to get the actual decoded frame size.
                    input_container.seek(0)
                    for first_frame in input_container.decode(video=0):
                        metadata.video_width = first_frame.width
                        metadata.video_height = first_frame.height
                        break
                else:
                    # Fall back to the header values without decoding.
                    metadata.video_width = video_stream.width
                    metadata.video_height = video_stream.height
            if get_audio and input_container.streams.audio:
                audio_stream = input_container.streams.audio[0]
                metadata.audio_sample_rate = audio_stream.sample_rate
                metadata.audio_duration = audio_stream.duration
                if get_audio_duration and metadata.audio_duration is None:
                    # Duration not in header: demux to the last packet to compute it.
                    last_packet = None
                    input_container.seek(0)
                    for packet in input_container.demux(audio=0):
                        if packet.pts is not None:
                            last_packet = packet
                    if last_packet is not None and last_packet.duration is not None:
                        assert last_packet.pts is not None
                        metadata.audio_duration = last_packet.pts + last_packet.duration
                if metadata.audio_duration is not None:
                    # Convert from stream time-base units to seconds.
                    if audio_stream.start_time is not None:
                        metadata.audio_duration -= audio_stream.start_time
                    if audio_stream.time_base is not None:
                        metadata.audio_duration *= float(audio_stream.time_base)
                metadata.audio_channels = audio_stream.channels
            return metadata

    def __repr__(self):
        return f"AVDecoder(stream={self.stream!r})"
class AVWebdatasetDecoder:
    """A decoder class for audio and video data that provides a consistent interface for decoding media files.

    Instances are callable, so they plug directly into webdataset-style
    pipelines: each call receives a file extension and raw bytes and returns
    the decoded result, or None for unsupported extensions.

    Args:
        video_decode_audio: Whether to decode audio from video files. If True, audio will be
            extracted alongside video frames.
        av_decode: Decoding mode: "AVDecoder" returns an AVDecoder for flexible decoding,
            "torch" returns fully decoded VideoData, "pyav" returns a raw PyAV input container.

    Example:
        >>> decoder = AVWebdatasetDecoder(
        ...     video_decode_audio=True,
        ...     av_decode="AVDecoder"
        ... )
        >>> result = decoder("video.mp4", video_bytes)
    """

    def __init__(
        self,
        video_decode_audio: bool,
        av_decode: Literal["torch", "AVDecoder", "pyav"] = "AVDecoder",
    ) -> None:
        self.video_decode_audio = video_decode_audio
        self.av_decode = av_decode

    def read_av_data(self, data: bytes) -> AVDecoder:
        """Wrap raw media bytes in an AVDecoder for flexible, lazy decoding.

        Args:
            data: The raw bytes of the media file

        Returns:
            An AVDecoder backed by an in-memory stream over the given bytes
        """
        return AVDecoder(io.BytesIO(data))

    def __call__(
        self, key: str, data: bytes
    ) -> Optional[
        Union[AVData, AVDecoder, "av.container.InputContainer", "av.container.OutputContainer"]
    ]:
        """
        Extract the video or audio data from default media extensions.

        Args:
            key: media file extension
            data: raw media bytes

        Returns:
            Depending on `av_decode`: decoded VideoData ("torch"), an AVDecoder
            instance ("AVDecoder"), or an av.container.InputContainer ("pyav").
            None if the extension is not a supported media type.
        """
        key = key.lower()
        # Accept either a bare extension ("mp4") or a full name ("clip.mp4").
        is_media = False
        for ext in ("mp4", "avi", "mov", "webm", "mkv", "flac", "mp3", "wav", "flv"):
            if key == ext or key.endswith("." + ext):
                is_media = True
                break
        if not is_media:
            return None
        av_decoder = self.read_av_data(data)
        mode = self.av_decode
        if mode == "AVDecoder":
            return av_decoder
        if mode == "pyav":
            return av_open(av_decoder.stream)
        if mode == "torch":
            return av_decoder.get_frames(
                video_decode_audio=self.video_decode_audio,
            )
        raise ValueError(f"Invalid av_decode value: {self.av_decode}")
@edataclass
class AVMetadata:
    """Metadata of the media object."""

    # Video stream properties (None if absent or not determined).
    video_duration: Optional[float] = None
    video_num_frames: Optional[int] = None
    video_fps: Optional[float] = None
    video_width: Optional[int] = None
    video_height: Optional[int] = None
    # Audio stream properties (None if absent or not determined).
    # audio_duration is in seconds (pts scaled by the stream time_base).
    audio_duration: Optional[float] = None
    audio_channels: Optional[int] = None
    audio_sample_rate: Optional[int] = None
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
try:
# Try importing optional dependencies
import av
import av.container
except ImportError:
pass
def av_open(file: str) -> "av.container.InputContainer":
    """Open a file with PyAV.

    This function is a wrapper around av.open that disables additional threads in the container.

    Args:
        file: The input to open. NOTE(review): also called with in-memory
            file-like objects elsewhere in this package despite the ``str``
            annotation — confirm and widen the annotation.

    Returns:
        The opened input container with codec threading disabled.
    """
    input_container = av.open(file, "r")
    try:
        initialize_av_container(input_container)
    except Exception:
        # Don't leak the container if thread-disabling initialization fails.
        input_container.close()
        raise
    return input_container
def initialize_av_container(input_container: "av.container.InputContainer") -> None:
    """Every PyAV container should be initialized with this function.

    This function ensures that no additional threads are created.
    This is to avoid deadlocks in ffmpeg when deallocating the container.
    Furthermore, we cannot have multiple threads before forking the process when
    using torch data loaders with multiple workers.
    """
    for stream in input_container.streams:
        cc = stream.codec_context
        # Some streams can lack a codec context — guard before configuring.
        if cc is not None:
            # Disable ffmpeg-internal threading for this codec entirely.
            cc.thread_type = "NONE"
            cc.thread_count = 0
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from .fastseek import Fastseek as Fastseek
from .keyframeinfo import KeyframeInfo as KeyframeInfo
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from collections import defaultdict
from bitstring.bits import BitsType
from ebmlite import MasterElement, loadSchema
from sortedcontainers import SortedList
from ..keyframeinfo import KeyframeInfo
class CueTrackPositions:
    """Holds the track number referenced by a Matroska CueTrackPositions element."""

    track: int

    def __init__(self, el: MasterElement) -> None:
        """Extract the CueTrack value from the element's children."""
        for child in el:
            if child.name == "CueTrack":
                self.track = child.value
class CuePoint:
    """Represents a Matroska CuePoint: a cue time plus its track positions."""

    time: int
    track_positions: CueTrackPositions

    def __init__(self, el: MasterElement) -> None:
        """Extract CueTime and CueTrackPositions from the element's children."""
        for child in el:
            if child.name == "CueTime":
                self.time = child.value
            elif child.name == "CueTrackPositions":
                self.track_positions = CueTrackPositions(child)
def parse_matroska(file: BitsType) -> dict[int, SortedList]:
    """Parse a Matroska/WebM container and collect its cue points per track.

    Args:
        file: The raw container data in any form accepted by ebmlite's loader.

    Returns:
        Mapping (defaultdict) from track number to a SortedList of KeyframeInfo,
        where both index and pts are the cue time.

    Raises:
        ValueError: If the EBML document cannot be parsed.
    """
    try:
        schema = loadSchema("matroska.xml")
        doc = schema.load(file, headers=True)
    except (KeyError, IOError, TypeError) as e:
        # Chain the original exception so the root cause stays visible.
        raise ValueError(f"Matroska parsing failed with error {e}") from e
    # Get cue times: walk the tree iteratively, only descending into the
    # elements that can contain cue information.
    stack = [c for c in doc if c.name == "Segment"]
    cues = defaultdict(SortedList)
    while len(stack) > 0:
        el = stack.pop()
        if el.name == "CuePoint":
            cue = CuePoint(el)
            cues[cue.track_positions.track].add(KeyframeInfo(cue.time, cue.time))
        elif isinstance(el, MasterElement):
            stack.extend([c for c in el if c.name in ["Cues", "CuePoint"]])
    return cues
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from collections import defaultdict
from itertools import accumulate
from typing import Any, Generator
from bitstring import ConstBitStream, Error
from bitstring.bits import BitsType
from sortedcontainers import SortedList
from ..keyframeinfo import KeyframeInfo
# Atom types that are pure containers for child atoms: these must be descended
# into rather than skipped over (see Atom._parse).
box_atoms = {"moov", "trak", "mdia", "minf", "stbl", "edts"}  # Non-exhaustive
def parse_table(
    cbs: "ConstBitStream", table_size: int, struct: dict[str, str]
) -> list[dict[str, Any]]:
    """Read ``table_size`` fixed-layout records from the bitstream.

    Args:
        cbs: The bitstream to read from (anything providing ``readlist``).
        table_size: Number of records to read.
        struct: Mapping of field name to bitstring format token (e.g. "uint:32").

    Returns:
        One dict per record, mapping each field name to its parsed value.
        (BUG FIX: the return annotation previously claimed a single dict,
        but the function has always returned a list of dicts.)
    """
    # The format string is loop-invariant; build it once.
    fmt = ", ".join(struct.values())
    return [dict(zip(struct.keys(), cbs.readlist(fmt))) for _ in range(table_size)]
class Atom:
    """Base class for an MPEG-4/QuickTime atom (box).

    Generic atoms record their size/name and skip their payload; known atom
    types are dispatched to the same-named subclass by ``make_atom``.
    """

    # NOTE(review): declared but never referenced in this module — the
    # version/flags skip is done unconditionally in make_atom instead.
    skip_version_and_flags: bool = False

    @staticmethod
    def make_atom(cbs: ConstBitStream) -> "Atom":
        """Read one atom header from the stream and parse the atom.

        Dispatches to the subclass whose name matches the upper-cased type
        code, falling back to the generic Atom which skips the payload.
        """
        # Atom header: 32-bit big-endian size, then a 4-character type code.
        size: int = cbs.read("uint:32")
        name: str = cbs.read("bytes:4").decode("ascii")
        box: bool = name in box_atoms
        if size == 0:
            raise RuntimeError(
                "MPEG parser detected a zero byte atom, this likely indicates a corrupt video."
            )
        subclass_list = [c for c in Atom.__subclasses__() if c.__name__ == name.upper()]
        atom_class: type = Atom
        if len(subclass_list) > 0:
            atom_class: type = subclass_list[0]
        cbs.bytepos += 4  # Skip version and flags TODO not every atom needs this
        atom = atom_class(size, name, box)
        atom._parse(cbs)
        return atom

    def __init__(self, size: int, name: str, box: bool) -> None:
        self.size: int = size  # total atom size in bytes, including the header
        self.name: str = name  # 4-character type code
        self.box: bool = box  # True if the atom contains child atoms
        
    def _parse(self, cbs: ConstBitStream) -> None:
        """Default parse: skip the payload; box atoms fall through so children get read.

        NOTE(review): make_atom has consumed 12 bytes at this point (size, type,
        version/flags) but this skips size - 8; MDAT skips size - 12 for the
        normal case — confirm this does not overshoot by 4 bytes.
        """
        if not self.box:
            cbs.bytepos += self.size - 8

    def __str__(self) -> str:
        return f"{self.name=}, {self.size=}, {self.box=}"
class TKHD(Atom):
    """
    Parses the track header atom, see https://developer.apple.com/documentation/quicktime-file-format/track_header_atom
    """

    def _parse(self, cbs: ConstBitStream) -> None:
        cbs.bytepos += 8  # skip creation time and modification time
        # The track id associates the following sample tables with this track
        # (consumed by parse_mpeg as "current_track").
        self.track_id: int = cbs.read("uint:32")
        cbs.bytepos += 68  # Skip rest of structure
class HDLR(Atom):
    """
    Parses the media handler atom, see https://developer.apple.com/documentation/quicktime-file-format/handler_reference_atom
    NOTE: currently unused but could speed up parsing by skipping audio tracks
    """

    def _parse(self, cbs: ConstBitStream) -> None:
        self.component_type = cbs.read("bytes:4").decode("ascii")
        self.component_subtype = cbs.read("bytes:4").decode("ascii")
        # Skip rest of structure, the last field is variable so we need to use the total size
        # 20 bytes already read (size (4), type (4), version (1), flags (3), component type (4), component subtype (4))
        cbs.bytepos += self.size - 20
class STSS(Atom):
    """
    Parses the sync sample atom https://developer.apple.com/documentation/quicktime-file-format/sample_table_atom/sync_sample_atom
    """

    def _parse(self, cbs: ConstBitStream) -> None:
        self.number_of_entries: int = cbs.read("uint:32")
        # Each entry is the 1-based sample number of a keyframe
        # (converted to 0-based later in parse_mpeg).
        self.sync_sample_table: dict[str, Any] = parse_table(
            cbs, self.number_of_entries, {"number": "uint:32"}
        )
class STTS(Atom):
    """
    Parses the time to sample atom https://developer.apple.com/documentation/quicktime-file-format/time-to-sample_atom
    """

    def _parse(self, cbs: ConstBitStream) -> None:
        self.number_of_entries: int = cbs.read("uint:32")
        # Run-length encoded durations: each entry applies sample_duration to
        # sample_count consecutive samples (expanded in parse_mpeg).
        self.time_to_sample_table: dict[str, Any] = parse_table(
            cbs,
            self.number_of_entries,
            {"sample_count": "uint:32", "sample_duration": "uint:32"},
        )
class CTTS(Atom):
    """
    Parses the composition offset atom https://developer.apple.com/documentation/quicktime-file-format/composition_offset_atom
    """

    def _parse(self, cbs: ConstBitStream) -> None:
        self.number_of_entries: int = cbs.read("uint:32")
        # Each entry applies a signed composition offset to a run of samples.
        # BUG FIX: a stray empty "media_rate" format token (copy-pasted from
        # ELST) was removed — the composition offset atom has no such field,
        # and the empty token produced a trailing separator in the bitstring
        # format while the key was silently dropped by zip().
        self.composition_offset_table: dict[str, Any] = parse_table(
            cbs,
            self.number_of_entries,
            {
                "sample_count": "uint:32",
                "composition_offset": "int:32",
            },
        )
class ELST(Atom):
    """
    Parses the edit list atom https://developer.apple.com/documentation/quicktime-file-format/edit_list_atom
    """

    def _parse(self, cbs: ConstBitStream) -> None:
        self.number_of_entries: int = cbs.read("uint:32")
        # media_time of the first entry is consumed by parse_mpeg as a (negated)
        # start offset; the remaining entries are currently unused.
        self.edit_list_table: dict[str, Any] = parse_table(
            cbs,
            self.number_of_entries,
            {
                "track_duration": "uint:32",
                "media_time": "int:32",
                "media_rate": "int:32",
            },
        )
class MDAT(Atom):
    """
    Parses the media data atom https: https://developer.apple.com/documentation/quicktime-file-format/movie_data_atom
    This is only here to handle the unusual size handling of mdat, if the normal size field is set to 1
    then the actual size is stored as a 64 bit integer
    """

    def _parse(self, cbs: ConstBitStream) -> None:
        if self.size == 1:
            # The 64-bit extended size directly follows the type code, so undo
            # the 4-byte version/flags skip done unconditionally by make_atom.
            cbs.bytepos -= 4  # No version or flags for mdat
            self.size = cbs.read("uint:64")
            # Header consumed so far: size (4) + type (4) + extended size (8).
            seekto = self.size - 16
        else:
            # Header consumed so far: size (4) + type (4) + skipped flags (4).
            seekto = self.size - 12
        if cbs.bytepos + seekto >= (cbs.len / 8):
            # Payload runs to (or past) end of stream; StopIteration signals
            # parse_atoms to terminate cleanly.
            raise StopIteration()
        cbs.bytepos += seekto
def parse_atoms(file: BitsType) -> Generator[Atom, None, None]:
    """Yield successive atoms from an MPEG-4/QuickTime byte stream.

    Iteration stops cleanly when an atom parser raises StopIteration (used by
    MDAT to signal that the payload extends to end-of-stream).

    Args:
        file: The raw container bytes in any form accepted by ConstBitStream.

    Yields:
        Parsed Atom instances (specialized subclasses where available).

    Raises:
        ValueError: If the underlying bitstream parsing fails.
    """
    try:
        cbs = ConstBitStream(file)
        while cbs.pos < len(cbs):
            try:
                yield Atom.make_atom(cbs)
            except StopIteration:
                return
    except Error as e:
        # Chain the bitstring error so the root cause stays visible.
        raise ValueError(f"MPEG parsing failed with error {e}") from e
def parse_mpeg(file: BitsType) -> dict[int, SortedList]:
    """Parse an MPEG-4/QuickTime container and compute keyframe info per track.

    Walks the atom stream collecting sync-sample (keyframe) numbers, decode
    timestamps, composition offsets and edit-list start offsets, then combines
    them into a presentation timestamp for every keyframe of every track.

    Args:
        file: The raw container bytes in any form accepted by ConstBitStream.

    Returns:
        Mapping from track id to a SortedList of KeyframeInfo
        (sample number, presentation timestamp).

    Raises:
        ValueError: If the collected tables are missing or inconsistent.
    """
    sync_samples = {}  # track id -> 0-based keyframe sample numbers
    decode_timestamps = {}  # track id -> cumulative decode time per sample
    presentation_time_offsets = {}  # track id -> per-sample composition offset
    start_offsets = defaultdict(int)  # track id -> edit-list shift (default 0)
    current_track = -1
    for a in parse_atoms(file):
        if a.name == "tkhd":
            a: TKHD
            # Track header: the following tables belong to this track.
            current_track = a.track_id
        elif a.name == "stts":
            a: STTS
            # Expand (count, duration) runs, then accumulate into absolute
            # decode times (leading 0 so index n is the start time of sample n).
            decode_timestamps[current_track] = list(
                accumulate(
                    sum(
                        [
                            [entry["sample_duration"]] * entry["sample_count"]
                            for entry in a.time_to_sample_table
                        ],
                        [0],
                    )
                )
            )
        elif a.name == "ctts":
            a: CTTS
            # Expand (count, offset) runs into one composition offset per sample.
            presentation_time_offsets[current_track] = sum(
                [
                    [entry["composition_offset"]] * entry["sample_count"]
                    for entry in a.composition_offset_table
                ],
                [],
            )
        elif a.name == "stss":
            a: STSS
            # Sync sample numbers are 1-based in the file; convert to 0-based.
            sync_samples[current_track] = [ss["number"] - 1 for ss in a.sync_sample_table]
        elif a.name == "elst":
            # NOTE the "media_time" here is a "delay" between decoding and presenting the first sample.
            # We follow the ffmpeg convention that the first frame displays at time 0 which means we should
            # *subtract* this offset from the decoding time values rather than adding it to presentation time values
            # TODO there can be more than one of these, figure out how to handle it
            a: ELST
            start_offsets[current_track] = -a.edit_list_table[0]["media_time"]
    keyframes = defaultdict(SortedList)
    try:
        for track_id in sync_samples.keys():
            ptos = presentation_time_offsets.get(track_id)
            dts = decode_timestamps[track_id]
            for keyframe_number in sync_samples[track_id]:
                # pts = decode time + edit-list shift + optional composition offset
                pts = (
                    dts[keyframe_number]
                    + start_offsets[track_id]
                    + (0 if ptos is None else ptos[keyframe_number])
                )
                keyframes[track_id].add(KeyframeInfo(keyframe_number, pts))
    except (KeyError, IndexError) as e:
        raise ValueError(f"MPEG parsing failed with error {e}")
    return keyframes
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from sortedcontainers import SortedList
from ...av_init import av_open
from ..keyframeinfo import KeyframeInfo
def parse_probe(file):
    """Find keyframes by demuxing the container with PyAV (no decoding).

    Works with any container PyAV supports; frames are indexed by packet
    number within the stream.
    """
    keyframes = {}
    with av_open(file) as input_container:
        for stream_idx, stream in enumerate(input_container.streams.video):
            packet_pts = []
            for index, packet in enumerate(input_container.demux(video=stream_idx)):
                if packet.is_keyframe:
                    packet_pts.append((index, packet.pts))
            # Order by presentation timestamp before building the sorted index.
            packet_pts.sort(key=lambda item: item[1])
            keyframes[stream.id] = SortedList([KeyframeInfo(*item) for item in packet_pts])
    return keyframes
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Literal, Optional
import filetype
from bitstring.bits import BitsType
from sortedcontainers import SortedList
from .containers.matroska import parse_matroska
from .containers.mpeg import parse_mpeg
from .containers.probe import parse_probe
from .keyframeinfo import KeyframeInfo
class Fastseek:
    """
    Gathers information from the video container file (e.g. metadata which requires minimal decoding)
    to find keyframes in the video for fast seeking.
    Information is returned in the form of KeyframeInfo structures which can be used by a decoding loop
    to make informed decisions about the best seeking behavior
    Currently supports:
    - MP4/MOV: frames are indexed by number and frame counting can be used to get the exact frame
    - Matroska/WebM: frames are indexed by time and inter-frame duration must be accounted for to get to the right frame
    If your container is not listed above, pass "probe=True" to the constructor, this will use ffmpeg to parse the stream
    without decoding it. Frames will be indexed by number. This is not as fast as using a supported container but is still
    significantly faster than sequential decoding.
    """

    # Per-stream keyframe index, keyed by track/stream id.
    keyframes: dict[int, SortedList[KeyframeInfo]]
    # Unit of KeyframeInfo.index: sample numbers ("frames") or timestamps ("pts").
    unit: Literal["frames", "pts"]
    # Detected MIME type. NOTE(review): only assigned when probe=False —
    # confirm nothing reads it after probe-mode construction.
    mime: str

    def __init__(self, file: BitsType, probe: bool = False) -> None:
        """Initialize the Fastseek object.

        Args:
            file: The video file data as a bitstring BitsType object. This should contain the raw bytes of the video file.
            probe: If True, use ffmpeg to probe the stream without decoding. This is slower but works with any container format.
                If False (default), attempt to parse the container format directly. Only works with MP4/MOV and Matroska/WebM.

        Raises:
            ValueError: If the file type cannot be determined or if the container format is not supported (when probe=False).
        """
        if probe:
            self.keyframes = parse_probe(file)
            self.unit = "frames"
        else:
            ftype = filetype.guess(file)
            if ftype is None:
                raise ValueError(
                    "Unable to determine file type (hint: try passing probe=True to the Fastseek constructor)"
                )
            self.mime = ftype.mime
            if ftype.mime in ["video/mp4", "video/quicktime"]:
                self.keyframes = parse_mpeg(file)
                self.unit = "frames"
            elif ftype.mime in ["video/x-matroska", "video/webm"]:
                self.keyframes = parse_matroska(file)
                self.unit = "pts"
            else:
                raise ValueError(
                    f"Unsupported container: {ftype.mime} (hint: try passing probe=True to the Fastseek constructor)"
                )
            # Sanity checks: a parser may succeed yet find nothing useful.
            if len(self.keyframes) == 0:
                raise ValueError(
                    f"The parser for {ftype.mime} was unable to find any streams (hint: try passing probe=True to the Fastseek constructor)"
                )
            if all(len(kf) == 0 for kf in self.keyframes.values()):
                raise ValueError(
                    f"The parser for {ftype.mime} was unable to find any keyframes (hint: try passing probe=True to the Fastseek constructor)"
                )

    def should_seek(self, current: int, target: int, stream: int = 0) -> Optional[KeyframeInfo]:
        """Determine if seeking to a keyframe is necessary to reach the target frame.

        This method helps optimize video seeking by determining whether a seek operation
        is needed to reach the target frame. It returns information about the nearest
        keyframe only if seeking would be beneficial (i.e., if sequential decoding from
        the current position would be less efficient).

        Args:
            current: The current frame number or timestamp (depending on container format)
            target: The desired frame number or timestamp to seek to
            stream: The video stream index to use. Defaults to 0.

        Returns:
            Information about the nearest keyframe if seeking would be beneficial,
            or None if sequential decoding from current position is more efficient.
            The KeyframeInfo contains the keyframe's position and timing information.

        Note:
            The units for current and target depend on the container format:
            - For MP4/MOV: frame numbers (count-based)
            - For Matroska/WebM: timestamps (time-based)
        """
        nearest_iframe: KeyframeInfo = self.nearest_keyframe(target, stream)
        # Seek only if a keyframe lies strictly between the current position and
        # the target, or if we would otherwise have to decode backwards.
        return (
            nearest_iframe
            if (current < nearest_iframe.index <= target) or (target < current)
            else None
        )

    def nearest_keyframe(self, target: int, stream: int = 0) -> KeyframeInfo:
        """Find the nearest keyframe that comes before the target frame.

        This method performs a binary search to find the keyframe that is closest to,
        but not after, the target frame position. This is useful for determining the
        optimal starting point for decoding to reach a specific frame.

        Args:
            target: The target frame number or timestamp to find the nearest keyframe for.
                The unit (frame count or timestamp) depends on the container format.
            stream: The video stream index to use. Defaults to 0.
                Used when the container has multiple video streams.

        Returns:
            Information about the nearest keyframe before the target position.
            Contains details like the keyframe's position, timestamp, and file offset.

        Note:
            The implementation currently uses a list-based approach for stream selection
            as some video containers don't report track IDs correctly. This is a temporary
            workaround and may be updated in the future.
        """
        if stream >= len(self.keyframes):
            raise ValueError(f"No stream with index {stream}")
        stream_id = list(self.keyframes.keys())[stream]
        if len(self.keyframes[stream_id]) == 0:
            raise ValueError(f"No keyframes found for stream {stream}")
        # NOTE(review): bisect_left(target) - 1 selects the keyframe strictly
        # before an exact match; bisect_right would include a keyframe that
        # sits exactly at target — confirm the exclusive behavior is intended.
        nearest_iframe_to_target_index: int = self.keyframes[stream_id].bisect_left(target) - 1
        return self.keyframes[stream_id][max(0, nearest_iframe_to_target_index)]
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from dataclasses import dataclass
@dataclass
class KeyframeInfo:
    """
    Information about a detected keyframe.

    The exact meaning of the fields varies per container, however in general:
    index: the unit of representation for a frame (e.g., frame number)
    pts: a timestamp that can be used by a decoder to seek to exactly this frame

    Comparisons are by ``index`` only, and also work against plain numbers so
    instances can be bisected against raw frame numbers / timestamps.
    """

    index: int
    pts: int

    def __eq__(self, other) -> bool:
        rhs = other.index if isinstance(other, KeyframeInfo) else other
        return self.index == rhs

    def __lt__(self, other) -> bool:
        rhs = other.index if isinstance(other, KeyframeInfo) else other
        return self.index < rhs
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import List, Literal, Optional, Tuple, Union, overload
import numpy as np
import torch
from megatron.energon.av.av_decoder import AVData, AVDecoder
def get_clips_uniform(
    av_decoder: AVDecoder,
    clip_duration_seconds: float,
    num_clips: int,
    request_video: bool = False,
    request_audio: bool = False,
    video_out_frame_size: Optional[tuple[int, int]] = None,
) -> AVData:
    """Extract ``num_clips`` clips of equal duration, spaced equidistantly.

    Args:
        av_decoder: An AVDecoder instance.
        clip_duration_seconds: The duration of each clip in seconds. A value of
            0 selects single frames, centered on frame start times.
        num_clips: The number of clips to extract.
        request_video: Whether to request video clips.
        request_audio: Whether to request audio clips.
        video_out_frame_size: The size of the video frames to output, or None to use the original size.

    Returns:
        An AVData object containing the extracted video and audio clips.
    """
    if not (request_video or request_audio):
        raise ValueError("You must request at least one of video or audio")
    video_len = float("inf")
    audio_len = float("inf")
    if request_video:
        video_len, _ = av_decoder.get_video_duration()
        if video_len is None:
            raise ValueError("No video duration found")
    if request_audio:
        audio_len = av_decoder.get_audio_duration()
        if audio_len is None:
            raise ValueError("No audio duration found")
    # Audio and video rarely have exactly the same duration; use the shorter
    # one so every clip fits in both streams.
    total_duration = min(video_len, audio_len)
    assert total_duration != float("inf")
    if clip_duration_seconds == 0:
        # Single-frame mode: center the start points on the first/last frame.
        seconds_per_frame = 1 / av_decoder.get_video_fps()
        first_start_time = seconds_per_frame * 0.5
        last_start_time = total_duration - seconds_per_frame * 0.5
    else:
        first_start_time = 0
        last_start_time = total_duration - clip_duration_seconds
    starts = np.linspace(first_start_time, last_start_time, num_clips)
    clips = [(float(s), float(s + clip_duration_seconds)) for s in starts]
    return av_decoder.get_clips(
        video_clip_ranges=clips if request_video else None,
        audio_clip_ranges=clips if request_audio else None,
        video_unit="seconds",
        audio_unit="seconds",
        video_out_frame_size=video_out_frame_size,
    )
@overload
def get_single_frames_uniform(
    av_decoder: "AVDecoder",
    num_frames: int,
    *,
    video_out_frame_size: Optional[Tuple[int, int]] = None,
    return_timestamps: Literal[False] = False,
) -> torch.Tensor: ...
@overload
def get_single_frames_uniform(
    av_decoder: "AVDecoder",
    num_frames: int,
    *,
    video_out_frame_size: Optional[Tuple[int, int]] = None,
    return_timestamps: Literal[True],
) -> Tuple[torch.Tensor, List[float]]: ...
def get_single_frames_uniform(
    av_decoder: AVDecoder,
    num_frames: int,
    *,
    video_out_frame_size: Optional[tuple[int, int]] = None,
    return_timestamps: bool = False,
) -> Union[torch.Tensor, tuple[torch.Tensor, list[float]]]:
    """Extract ``num_frames`` single frames spread uniformly over the video.

    Args:
        av_decoder: An AVDecoder instance.
        num_frames: The number of frames to extract.
        video_out_frame_size: The size of the video frames to output, or None to use the original size.
        return_timestamps: If True, also return the start timestamp of each
            extracted frame.

    Returns:
        A tensor of shape (num_frames, channels, height, width), or a tuple of
        that tensor and the per-frame timestamps when ``return_timestamps`` is True.

    Raises:
        ValueError: If no video frames could be extracted.
    """
    # A clip duration of 0 makes get_clips_uniform select single frames.
    av_data = get_clips_uniform(
        av_decoder=av_decoder,
        clip_duration_seconds=0,
        num_clips=num_frames,
        request_video=True,
        request_audio=False,
        video_out_frame_size=video_out_frame_size,
    )
    if not av_data.video_clips:
        raise ValueError("No video frames found")
    # Each clip holds exactly one frame; concatenate along the frame axis.
    frames = torch.cat(av_data.video_clips, dim=0)
    if not return_timestamps:
        return frames
    return frames, [start for start, _ in av_data.video_timestamps]
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import re
from collections import defaultdict
from typing import List, Tuple
__all__ = ["collapse"]
# NOTE(review): this string literal is NOT a module docstring (it does not
# precede __all__ as the module's first statement) — it is a no-op expression.
"""Helper functions for string tokenization and expression building."""
_num_re = re.compile(r"\d+")
def _tokenize(s: str) -> Tuple[List[str], List[Tuple[str, int, int]]]:
"""
Split the string into literal and numeric parts.
Always starts with a literal (sometimes empty)
Example:
"partition_00/shard_000000.tar" ->
lits = ["partition_", "/shard_", ".tar"]
nums = [("00", 0, 2), ("000000", 0, 6)]
Args:
s: Input string to tokenize.
Returns:
Tuple containing:
- lits: List of literal pieces, length = #nums + 1
- nums: List of tuples (raw, value, width) where:
- raw: original substring (keeps zero-padding)
- value: int(raw)
- width: len(raw)
"""
lits, nums = [], []
pos = 0
for m in _num_re.finditer(s):
lits.append(s[pos : m.start()])
raw = m.group(0)
nums.append((raw, int(raw), len(raw)))
pos = m.end()
lits.append(s[pos:])
return lits, nums
def _build_expr(
lits: List[str],
nums: List[Tuple[str, int, int]],
var_idx: int,
start_raw: str,
end_raw: str,
) -> str:
"""
Re-assemble the template, replacing slot with brace expansion syntax.
Args:
lits: List of literal pieces of the string.
nums: List of numeric parts as tuples (raw, value, width).
var_idx: Index of the numeric slot to replace with range.
start_raw: Starting value (raw string).
end_raw: Ending value (raw string).
Returns:
String with brace expansion syntax.
"""
parts: List[str] = []
for i in range(len(nums)):
parts.append(lits[i])
if i == var_idx:
parts.append(f"{{{start_raw}..{end_raw}}}")
else:
parts.append(nums[i][0])
parts.append(lits[-1])
return "".join(parts)
def _streaming_mode(strings: List[str]) -> List[str]:
    """
    Compress strings in order-preserving streaming mode.

    Greedily extends a run from each position: consecutive strings that share
    the same template and increment exactly one numeric slot by +1 (with
    unchanged zero-padded width) are merged into one brace expression.

    Complexity: O(N)

    Args:
        strings: List of strings to compress.

    Returns:
        List of compressed expressions.
    """
    # Result list with brace expressions
    out: List[str] = []
    # Total number of strings
    n = len(strings)
    # Current index
    i = 0
    while i < n:
        lits0, nums0 = _tokenize(strings[i])
        # Strings without numbers can never form a range
        if not nums0:
            out.append(strings[i])
            i += 1
            continue
        # Which numeric slot is changing?
        var_idx: int = -1
        start_raw: str = ""
        prev_nums = nums0
        # Last index in the current candidate range
        run_end = i
        # Starting with string `i` as the template, check subsequent strings `j` as long as they match
        j = i + 1
        while j < n:
            lits1, nums1 = _tokenize(strings[j])
            # Template must be identical (same number of literals and numeric slots)
            if lits1 != lits0 or len(nums1) != len(nums0):
                break
            # Exactly one numeric slot may differ ─ find it
            diff_slots = [k for k, (a, b) in enumerate(zip(prev_nums, nums1)) if a[1] != b[1]]
            if len(diff_slots) != 1:
                break
            k = diff_slots[0]
            # Width must stay the same
            if nums1[k][2] != prev_nums[k][2]:
                break
            # Same changing slot for the whole run
            if var_idx == -1:
                var_idx, start_raw = k, nums0[k][0]
            elif var_idx != k:
                break
            # Contiguous ascending (+1) only
            if nums1[k][1] != prev_nums[k][1] + 1:
                break
            # OK - extend run
            run_end = j
            prev_nums = nums1
            j += 1
        run_len = run_end - i + 1
        if run_len >= 2 and var_idx != -1:
            # Emit range
            end_raw = prev_nums[var_idx][0]
            out.append(_build_expr(lits0, nums0, var_idx, start_raw, end_raw))
            i = run_end + 1
        else:
            # Single string
            out.append(strings[i])
            i += 1
    return out
def _bucket_greedy_mode(strings: List[str]) -> List[str]:
    """
    Compress strings using bucket + greedy algorithm to minimize pattern count.

    Complexity: O(N log N)

    Args:
        strings: List of strings to compress.

    Returns:
        List of compressed expressions (order may change).
    """
    # Tokenize all strings once
    tokenized = []
    for s in strings:
        lits, nums = _tokenize(s)
        tokenized.append({"lits": lits, "nums": nums, "orig": s})
    # Build buckets: one per (varying slot, frozen template) so all strings
    # that differ only in that slot land in the same bucket.
    buckets: defaultdict = defaultdict(list)
    for idx, t in enumerate(tokenized):
        lits, nums = t["lits"], t["nums"]
        for var_idx, (raw, value, width) in enumerate(nums):
            key_tokens = []
            for k in range(len(nums)):
                key_tokens.append(lits[k])
                # The varying slot is wildcarded with None in the bucket key.
                key_tokens.append(None if k == var_idx else nums[k][0])
            key_tokens.append(lits[-1])
            key = (var_idx, tuple(key_tokens))
            buckets[key].append((idx, value, raw, width))
    # Find contiguous runs inside every bucket
    # candidate contain tuples (covered_size, indices, expression)
    candidates = []
    for (var_idx, _), entries in buckets.items():
        # Sort by numeric *value*
        entries.sort(key=lambda e: e[1])
        # Start with the first entry
        run = [entries[0]]

        # Emit the current run as a candidate if it spans >= 2 strings.
        # (Closure reads `run`/`var_idx` at call time, so it always sees the
        # current run of the current bucket.)
        def _flush():
            if len(run) >= 2:
                idxs = [e[0] for e in run]
                start_raw, end_raw = run[0][2], run[-1][2]
                t0 = tokenized[idxs[0]]
                expr = _build_expr(t0["lits"], t0["nums"], var_idx, start_raw, end_raw)
                candidates.append((len(run), idxs, expr))

        # Check subsequent entries
        for e in entries[1:]:
            prev = run[-1]
            if e[1] == prev[1] + 1 and e[3] == prev[3]:  # contiguous, same width
                run.append(e)
            else:
                _flush()
                run = [e]
        _flush()
    # Greedy cover: longest first, no overlaps
    candidates.sort(key=lambda c: (-c[0], c[2]))  # stable order
    covered = [False] * len(strings)
    out: List[str] = []
    for _, idxs, expr in candidates:
        if all(not covered[i] for i in idxs):  # keep only disjoint
            out.append(expr)
            for i in idxs:
                covered[i] = True
    # Leftover single strings
    out.extend(t["orig"] for i, t in enumerate(tokenized) if not covered[i])
    return out
def collapse(strings: List[str], keep_order: bool = False) -> List[str]:
    """
    Reverse-brace-expand a list of strings.

    Args:
        strings: The filenames / words to be compressed.
        keep_order: Whether to preserve original order.
            * False → minimise the **count** of patterns (order may change).
            * True → keep the order of the input in the expanded output.

    Returns:
        List of brace-expressions plus (possibly) single strings.
    """
    if keep_order:
        return _streaming_mode(strings)
    return _bucket_greedy_mode(strings)
if __name__ == "__main__":
    """Self-test for the module."""
    import time

    # Hand-checked examples: out-of-order numbering, mixed templates, and
    # multiple numeric slots per string.
    ex1 = [
        "/path/to/file001.tar.gz",
        "/path/to/file003.tar.gz",
        "/path/to/file002.tar.gz",
    ]
    ex2 = ["python2", "python3.1", "python3.2", "python3.5"]
    ex3 = ["a2b3c", "a4b3c", "a3b3c", "a4b2c", "a5b2c"]
    for case in (ex1, ex2, ex3):
        print("#", case)
        print("unordered :", collapse(case))
        print("ordered :", collapse(case, keep_order=True))
        print()
    # ex4 = [f"shard_{x:06d}" for x in range(10_000_000)]
    # Large-scale timing check: 5 million strings across 5 partitions.
    ex4 = [
        f"partition_{partition:02d}/shard_{x:06d}.tar"
        for partition in range(5)
        for x in range(1_000_000)
    ]
    start = time.perf_counter()
    res = collapse(ex4, keep_order=True)
    print(res)
    print(time.perf_counter() - start, "seconds")
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from megatron.energon.cache.base import CachePool, FileStore, FileStoreDecoder, Lazy, MockLazy
from megatron.energon.cache.file_cache_pool import FileCacheLazy, FileStoreCachePool
from megatron.energon.cache.file_store import DecodeFileStore, SystemFileStore
from megatron.energon.cache.no_cache import DirectLazy, NoCachePool
__all__ = [
"CachePool",
"DecodeFileStore",
"DirectLazy",
"FileCacheLazy",
"FileStore",
"FileStoreCachePool",
"FileStoreDecoder",
"Lazy",
"MockLazy",
"NoCachePool",
"SystemFileStore",
]
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from abc import ABC, abstractmethod
from typing import Any, Callable, Generic, TypeVar
from megatron.energon.edataclass import edataclass
from megatron.energon.source_info import SourceInfo, add_source_info
T = TypeVar("T")
class FileStore(Generic[T]):
    """Base type for a dataset that can be accessed randomly by sample key."""

    # NOTE(review): this class does not inherit from abc.ABC, so the
    # @abstractmethod markers below are not enforced at instantiation time —
    # confirm whether that is intentional.
    @abstractmethod
    def __getitem__(self, key: str) -> tuple[T, SourceInfo]:
        """Returns the data for the given key."""
        ...

    def get(self, key: str, sample: Any = None) -> Any:
        """Returns the data for the given key and adds the source info to the sample."""
        data, source_info = self[key]
        add_source_info(sample, source_info)
        return data

    @abstractmethod
    def get_path(self) -> str:
        """Returns the path to the dataset."""
        ...
@edataclass
class Lazy(Generic[T]):
    """
    Abstract base class for lazy references to data.
    """

    # The data source the referenced file lives in.
    ds: FileStore
    # The file name (key) within the data source.
    fname: str
    # The cache pool managing this lazy reference.
    pool: "CachePool"

    @abstractmethod
    def get(self, sample: Any = None) -> T:
        """
        Get the lazy data now and adds the source info to the sample.
        """
        ...

    def __hash__(self) -> int:
        """Allows usage in sets and dicts as key."""
        # Identity of the data source plus the file name identify the reference.
        return hash((id(self.ds), self.fname))

    def __eq__(self, other: Any) -> bool:
        """Allows usage in sets and dicts as key. Compares the data source and the filename."""
        if not isinstance(other, Lazy):
            return False
        return self.ds is other.ds and self.fname == other.fname
@edataclass
class MockLazy(Lazy[T]):
    """
    Mock object, which can be used as a Lazy. Allows the user to set the function to retrieve the
    data. May be used to create a Lazy that is initialized from a function.
    """

    # NOTE(review): ds and pool are declared with non-optional types but are
    # set to None in __init__ — the mock intentionally has no backing store.
    ds: FileStore
    fname: str
    pool: "CachePool"
    # User-supplied function producing the data for fname.
    get_fn: Callable[[str], T]

    def __init__(self, fname: str, get_fn: Callable[[str], T]):
        """
        Initialize the MockLazy object.

        Args:
            fname: The file name of the mock object (may be used by the user).
            get_fn: The function to retrieve/generate the data.
        """
        self.ds = None
        self.fname = fname
        self.pool = None
        self.get_fn = get_fn

    def get(self, sample: Any = None) -> T:
        """
        Get the lazy data now and adds no source info to the sample.
        """
        return self.get_fn(self.fname)

    def __hash__(self) -> int:
        # Unlike Lazy, hash on the name and generator function (ds is None).
        return hash((self.fname, self.get_fn))

    def __eq__(self, other: Any) -> bool:
        # Two mocks compare equal when they produce the same data the same way.
        if not isinstance(other, MockLazy):
            return False
        return self.fname == other.fname and self.get_fn == other.get_fn

    def __repr__(self) -> str:
        return f"MockLazy(fname={self.fname!r}, get_fn={self.get_fn!r})"
class CachePool(ABC):
    """
    A cache pool allows to load the needed data in the background and access it later.
    The most important example being `FileStoreCachePool` which caches data on a local SSD disk.

    To request data, use `get_lazy` to get a `Lazy` object.
    Then, call `Lazy.get()` to get the data later on.
    """

    @abstractmethod
    def get(self, ds: FileStore, fname: str, sample: Any = None) -> Any:
        """
        Get the data for a given file and adds the source info to the sample.
        """
        ...

    @abstractmethod
    def get_lazy(self, ds: FileStore, fname: str) -> Lazy:
        """
        Get a lazy reference to the data for a given file.
        """
        ...

    # T is a method-scoped type variable here (the class itself is not generic).
    @abstractmethod
    def to_cache(self, data: T, name: str) -> Lazy[T]:
        """
        Move the data to the cache and return a lazy to fetch it later.
        The name is used for debugging the data in the cache.
        """
        ...

    @abstractmethod
    def close(self) -> None:
        """
        Close the cache pool.
        """
        ...
class FileStoreDecoder(ABC):
    """
    Abstract base class for decoders.
    """

    @abstractmethod
    def decode(self, fname: str, data: bytes) -> Any:
        """
        Decode the specified file (i.e. path/key.ext). The extension is used to select the decoder.

        Args:
            fname: The file name of the file to decode.
            data: The raw bytes of the file to decode.

        Returns:
            The decoded field's data.
        """
        ...
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import hashlib
import os
import pickle
import random
import string
import tempfile
import threading
import uuid
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Tuple, TypeVar
from megatron.energon.cache.base import CachePool, FileStore, Lazy
from megatron.energon.cache.file_store import DecodeFileStore
from megatron.energon.edataclass import edataclass
from megatron.energon.fork_hook import ForkMixin
from megatron.energon.source_info import SourceInfo, add_source_info
T = TypeVar("T")
@edataclass
class FileCacheLazy(Lazy[T]):
    """
    Represents a reference to a background prefetch in a `FileStoreCachePool`.

    Calling :meth:`get` resolves the reference (waiting for, or cancelling, the
    background cache-out job) and returns the data.
    """

    # The cache pool that we're using.
    pool: "FileStoreCachePool"
    # The entry in the cache pool that we're using.
    entry: "_PendingTask"
    # Once get() was called, the (data, source_info) tuple is kept here, so
    # repeated get() calls do not hit the pool/cache again.
    _data: Optional[tuple[T, SourceInfo]] = None

    def get(self, sample: Any = None) -> T:
        """
        Returns the data and adds the source info to the sample.

        If the background job hasn't started, we cancel it,
        do a direct read, and remove ourselves from the pool's references.
        Otherwise, we wait for the job to finish, read from cache, and remove ourselves.
        """
        if self._data is None:
            self._data = self.pool._get_data(self.ds, self.fname, self.entry)
            assert self._data is not None
        # Attach the source info and unwrap the data also for repeated calls
        # (previously the early-return handed back the raw (data, source_info)
        # tuple and skipped attaching the source info).
        add_source_info(sample, self._data[1])
        return self._data[0]

    def __hash__(self) -> int:
        """Allows usage in sets and dicts as key."""
        return hash((id(self.ds), self.fname))

    def __eq__(self, other: Any) -> bool:
        """Allows usage in sets and dicts as key. Compares the data source and the filename."""
        if not isinstance(other, Lazy):
            return False
        return self.ds is other.ds and self.fname == other.fname

    def __del__(self):
        if self._data is None:
            with self.pool._lock:
                # Data was never fetched, still decrement refcount to delete the cache entry
                self.pool._decrement_refcount_and_cleanup((self.ds.get_path(), self.fname))
@edataclass
class CacheFileLazy(Lazy[T]):
    """
    Represents a reference to a cached object without deduplication.
    """

    # The path to the file that contains the cached pickled object.
    # Set to None once the data was fetched (the cache file is deleted then).
    cache_path: Path | None
    # If get() was called, this will be the data (uncached).
    _data: Optional[T] = None

    def get(self, sample: Any = None) -> T:
        """
        Get the lazy data now and adds no source info to the sample.
        """
        if self._data is None:
            with open(self.cache_path, "rb") as f:
                self._data = pickle.load(f)
            # One-shot read: the cache file is consumed on first access.
            self.cache_path.unlink()
            self.cache_path = None
        return self._data

    def __del__(self):
        # Best-effort cleanup if the data was never fetched.
        if self.cache_path is not None:
            self.cache_path.unlink(missing_ok=True)
            self.cache_path = None

    def __hash__(self) -> int:
        # NOTE(review): cache_path is set to None by get(), so the hash of an
        # instance changes after the first fetch — do not use as a dict/set key
        # across a get() call. Confirm whether this is relied upon anywhere.
        return hash((self.fname, self.cache_path))

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, CacheFileLazy):
            return False
        return self.fname == other.fname and self.cache_path == other.cache_path

    def __repr__(self) -> str:
        return f"CacheFileLazy(fname={self.fname!r}, cache_path={self.cache_path!r})"
@edataclass
class _PendingTask:
    """Dataclass for storing a pending background cache-out task."""

    # The dataset that we're caching.
    ds: FileStore
    # The file name that we're caching.
    fname: str
    # The future for the background task that sends the data to the cache.
    send_to_cache_future: Future
    # The number of references (FileCacheLazy objects) to the cache entry.
    refcount: int = 1
    # The size of the data to be cached (in bytes; set once the data was read).
    data_size: int = 0
    # Whether the data is required now, i.e. a reading thread is waiting for it.
    # Makes the cache-out task bail out instead of waiting for cache space.
    require_data_now: bool = False
    # The path to the cache file (set once the data was written to the cache).
    cache_path: Optional[Path] = None
    # The source info for the data (set once the data was read).
    source_info: Optional[SourceInfo] = None
class FileStoreCachePool(CachePool, ForkMixin):
    """
    Manages a thread pool to pre-fetch data onto an SSD cache.

    Each (ds, fname) has one Future (one read). Multiple requests
    share that same future. We track usage with a refcount.

    To avoid multi-process collisions, we generate a random subfolder
    for each instance (and a fresh one after every fork).
    """

    # Root directory for this instance's cached files (random per process).
    cache_dir: Path
    # Maximum total bytes allowed in the cache before cache-out tasks wait.
    max_cache_size: int
    # Maximum number of cached files before cache-out tasks wait.
    max_cache_count: int
    # Bytes currently reserved in the cache.
    current_cache_size: int
    # Number of files currently reserved in the cache.
    current_cache_count: int
    # "raw": cache the undecoded bytes; "pickle": decode first, cache the pickle.
    method: Literal["raw", "pickle"]

    # Thread pool for out-caching tasks
    _worker_pool: Optional[ThreadPoolExecutor] = None
    # (ds.path, fname) -> _PendingTask
    _pending_tasks: Dict[Tuple[str, str], _PendingTask]
    # Lock for all shared structures
    _lock: threading.Lock
    # Condition variable to signal when cache space is available
    _cache_space_available: threading.Condition
    # Whether the pool is shutting down
    _shutting_down: bool = False

    def __init__(
        self,
        *,
        parent_cache_dir: Optional[Path] = None,
        num_workers: int = 8,
        max_cache_size_gbytes: float = 1024,
        max_cache_count: int = 10_000_000,
        method: Literal["raw", "pickle"] = "raw",
    ):
        """
        Initialize the cache pool.

        Args:
            parent_cache_dir: The parent directory for the cache. If None, the system
                temp directory is used.
            num_workers: The number of worker threads to use for copying the data to the cache for lazy loading.
            max_cache_size_gbytes: The maximum size of the cache in gigabytes. If the cache exceeds this size,
                the prefetching will wait until the cache is below this size.
            max_cache_count: The maximum number of files in the cache. If the cache exceeds this number,
                the prefetching will wait until the cache is below this number.
            method: The method to use for caching. "raw" store the non-decoded raw data. "pickle": first decode the data
                and then store the pickled data.
        """
        super().__init__()
        # If no parent directory is given, use the system temp directory
        if parent_cache_dir is None:
            parent_cache_dir = Path(tempfile.gettempdir())
        self.parent_cache_dir = parent_cache_dir
        self.num_workers = num_workers
        # Initialize the process-volatile fields (worker pool and cache dir)
        self.__after_fork__(initial=True)
        self.method = method
        # We'll store _pending_tasks in the form:
        #   (ds.path, fname) -> _PendingTask
        self._pending_tasks = {}
        # Cache size management
        self.max_cache_size = int(max_cache_size_gbytes * (1024**3))
        self.max_cache_count = max_cache_count
        self.current_cache_size = 0
        self.current_cache_count = 0
        # A lock to protect all shared structures
        self._lock = threading.Lock()
        # Condition variable to signal when cache space is available
        self._cache_space_available = threading.Condition(self._lock)

    def get(self, ds: FileStore, fname: str, sample: Any = None) -> Any:
        """
        Synchronous read from the dataset (no cache usage).
        """
        return ds.get(fname, sample)

    def _get_data(self, ds: FileStore, fname: str, entry: _PendingTask) -> tuple[Any, SourceInfo]:
        """
        Get the data for a given file from the cache and purge cache if no references are left.

        * If the cache-out is complete, read from cache.
        * If the cache-out is currently prefetching the data to local storage, wait until it's done.
        * If the cache-out job is waiting for space, skip the cache and do a direct read.
        * If the cache-out job is queued for caching, cancel and do a direct read.
        * If the cache-out job failed, raise through and keep for other references.
        * If the cache-out job is cancelled, requeue if there are other references waiting for it.
        """
        result: tuple[Any, SourceInfo]
        with self._lock:
            try:
                # Attempt to cancel if the job hasn't started
                if entry.send_to_cache_future.cancel():
                    was_cached = False
                    try:
                        # Cancelled => job never ran. We'll do a direct read.
                        result = ds[fname]
                    finally:
                        # Decrement refcount
                        self._decrement_refcount_and_cleanup(key=(ds.get_path(), fname))
                else:
                    # Future is already running or done.
                    # Release the lock so the background job can proceed,
                    # then reacquire it after waiting. Otherwise we might block the worker.
                    entry.require_data_now = True
                    self._cache_space_available.notify_all()
                    self._lock.release()
                    # If the job failed, let's keep the exception for other references.
                    was_cached = True
                    try:
                        # Can raise exception if job failed
                        was_cached = entry.send_to_cache_future.result()
                        if was_cached:
                            # The job is complete; read from cache
                            result = self._read_from_cache(entry)
                        else:
                            # The job bailed out (e.g. waiting for space), do a direct read
                            result = ds[fname]
                    finally:
                        self._lock.acquire()
                        entry.require_data_now = False
                        # Decrement refcount
                        self._decrement_refcount_and_cleanup(key=(ds.get_path(), fname))
            finally:
                if entry.refcount > 0 and not was_cached:
                    # TODO: Could write to cache here, data is already fetched.
                    # Requeue the job, there is another reference to the cache entry
                    entry.send_to_cache_future = self._worker_pool.submit(
                        self._cache_out_task, ds, fname, entry
                    )
        return result

    def _cache_out_task(self, ds: FileStore, fname: str, entry: _PendingTask) -> bool:
        """
        Background worker task: read the data and write it to the local cache.

        Returns:
            True if the data was written to the cache, False if caching was
            skipped (shutdown, no remaining references, or a reader requires
            the data immediately).
        """
        with self._lock:
            if self._shutting_down:
                return False
        # Perform the data read (outside the lock, this may take a while)
        if self.method == "raw":
            if isinstance(ds, DecodeFileStore):
                # Cache the raw bytes; decoding happens on read from cache
                data, entry.source_info = ds.inner_reader[fname]
            else:
                data, entry.source_info = ds[fname]
        elif self.method == "pickle":
            data, entry.source_info = ds[fname]
            data = pickle.dumps(data)
        else:
            raise ValueError(f"Invalid method: {self.method}")
        # Wait until there's enough space in the cache
        with self._lock:
            entry.data_size = file_size = len(data)
            while (
                self.current_cache_count + 1 > self.max_cache_count
                or self.current_cache_size + entry.data_size > self.max_cache_size
            ):
                # Release the lock and wait for notification
                self._cache_space_available.wait()
                if entry.require_data_now or self._shutting_down:
                    # At least one reference requires the data now, stop waiting for space and exit immediately
                    return False
            if self._shutting_down or entry.refcount <= 0:
                # No more references to this background job, don't write to cache.
                # This check must happen *before* reserving space; otherwise the
                # reservation would leak forever (cache_path stays None, so
                # _remove_cached_file would never release it).
                return False
            # Reserve the space
            self.current_cache_size += file_size
            self.current_cache_count += 1
        try:
            assert entry.cache_path is None, (
                f"cache_path should be None, but is {entry.cache_path!r}"
            )
            # Write to cache
            cache_path = self._make_cache_path(ds, fname)
            self._write_to_cache(cache_path, data)
        except BaseException:
            with self._lock:
                # Revert the space reservation
                self.current_cache_size -= file_size
                self.current_cache_count -= 1
                self._cache_space_available.notify_all()
            raise
        else:
            with self._lock:
                entry.cache_path = cache_path
        # Data is cached now, return True
        return True

    def get_lazy(self, ds: FileStore, fname: str) -> FileCacheLazy:
        """
        Schedule a background pre-fetch. If multiple calls come in for the same (ds, fname),
        they'll share the same Future and increment reference counts.
        """
        key = (ds.get_path(), fname)
        with self._lock:
            if self._shutting_down:
                raise RuntimeError("Cache pool is already shutting down")
            entry = self._pending_tasks.get(key)
            if entry:
                # Already have a background task for this (ds, fname)
                entry.refcount += 1
            else:
                # Create a new background task
                entry = _PendingTask(
                    ds=ds,
                    fname=fname,
                    send_to_cache_future=None,
                )
                self._pending_tasks[key] = entry
                entry.send_to_cache_future = self._worker_pool.submit(
                    self._cache_out_task, ds, fname, entry
                )
        return FileCacheLazy(ds=ds, fname=fname, pool=self, entry=entry)

    def to_cache(self, data: T, name: str) -> CacheFileLazy[T]:
        """
        Move the data to the cache and return a lazy to fetch it later.

        NOTE(review): this path intentionally bypasses the cache size/count
        accounting (no reservation is made) — confirm this is acceptable for
        large objects.
        """
        raw_data = pickle.dumps(data)
        # Random file name: to_cache entries are not deduplicated
        cache_fname = str(uuid.uuid4())
        cache_path = self.cache_dir / cache_fname
        self._write_to_cache(cache_path, raw_data)
        return CacheFileLazy(ds=None, fname=name, pool=self, cache_path=cache_path)

    def close(self) -> None:
        """
        Shutdown the pool, wait for tasks, and clear our structures.
        """
        with self._lock:
            self._shutting_down = True
            for entry in self._pending_tasks.values():
                entry.send_to_cache_future.cancel()
            # Wake up any tasks waiting for cache space so they can exit
            self._cache_space_available.notify_all()
        self._worker_pool.shutdown(wait=True)
        with self._lock:
            self._pending_tasks.clear()

    def _decrement_refcount_and_cleanup(self, key: Tuple[str, str]) -> None:
        """
        Decrement the reference count in `_pending_tasks`.
        If it hits zero, remove the entry and its cached file.
        Assumes the caller holds `self._lock`.

        Args:
            key: The (ds.get_path(), fname) key of the pending task.
        """
        entry = self._pending_tasks.get(key)
        if not entry:
            # Already cleaned up
            return
        entry.refcount -= 1
        if entry.refcount <= 0:
            # No more references to this background job
            del self._pending_tasks[key]
            self._remove_cached_file(entry)
            assert entry.refcount == 0, f"refcount should be 0: {entry.refcount}"

    def _make_cache_path(self, ds: FileStore, fname: str) -> Path:
        """Compute the deterministic cache file path for a (ds, fname) pair."""
        # Hash collisions across datasets/files are effectively impossible, and
        # the parent cache dir is unique per instance, so this is safe.
        ds_hash = hashlib.md5(ds.get_path().encode("utf-8")).hexdigest()
        fn_hash = hashlib.md5(fname.encode("utf-8")).hexdigest()
        return self.cache_dir / f"{ds_hash}_{fn_hash}"

    def _write_to_cache(self, path: Path, data: bytes) -> None:
        """Write raw bytes to the given cache file, creating parent dirs."""
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as f:
            f.write(data)

    def _read_from_cache(self, entry: _PendingTask) -> tuple[Any, SourceInfo]:
        """Read (and decode, if applicable) the cached data for an entry."""
        assert entry.source_info is not None, "source_info should have been set"
        with open(entry.cache_path, "rb") as f:
            if self.method == "raw":
                raw = f.read()
                if isinstance(entry.ds, DecodeFileStore):
                    # Raw bytes were cached, decode now
                    return entry.ds.decoder.decode(entry.fname, raw), entry.source_info
                else:
                    return raw, entry.source_info
            else:
                return pickle.load(f), entry.source_info

    def _remove_cached_file(self, entry: _PendingTask) -> None:
        """
        Removes a file from disk and updates size counters.
        Assumes the caller holds `self._lock`.
        """
        if entry.cache_path is None:
            # Nothing was written; any reservation was reverted by the writer
            return
        if not entry.cache_path.exists():
            return
        try:
            entry.cache_path.unlink()
        except OSError:
            # Best-effort removal; counters are still released below
            pass
        entry.cache_path = None
        # A set cache_path implies the space was reserved; release both counters.
        # The count is decremented unconditionally (also for zero-byte files,
        # which still count towards max_cache_count).
        self.current_cache_size -= entry.data_size
        self.current_cache_count -= 1
        # Notify waiting threads that space is now available
        self._cache_space_available.notify_all()

    def __before_fork__(self):
        """Shut down the worker pool so no threads leak across the fork."""
        assert len(self._pending_tasks) == 0, "Pending tasks should be empty before fork"
        self._worker_pool.shutdown(wait=True)
        self._worker_pool = None

    def __after_in_child_fork__(self):
        self.__after_fork__()

    def __after_in_parent_fork__(self):
        self.__after_fork__()

    def __after_fork__(self, initial: bool = False):
        """(Re-)create the process-volatile state: worker pool and cache dir."""
        # Create a random subdirectory name to avoid collisions with other processes.
        # As the global random generator is cloned across processes, we need to use
        # a process-specific seed.
        random_suffix = "".join(
            random.Random(os.getpid() ^ random.randint(0, 2**32)).choices(
                string.ascii_lowercase + string.digits, k=16
            )
        )
        assert self._worker_pool is None
        self._worker_pool = ThreadPoolExecutor(
            max_workers=self.num_workers, thread_name_prefix="CacheWorker"
        )
        self.cache_dir = (self.parent_cache_dir / f"cache_{random_suffix}").resolve()
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def __str__(self):
        return f"FileStoreCachePool(cache_dir={self.cache_dir}, max_cache_size={self.max_cache_size}, max_cache_count={self.max_cache_count}, method={self.method}, current_cache_size={self.current_cache_size}, current_cache_count={self.current_cache_count})"
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Any, Optional, Union
from megatron.energon.cache.base import FileStore, FileStoreDecoder
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.jsonl.ijsonl_reader import IJsonlReader
from megatron.energon.flavors.webdataset.itar_reader import SqliteITarEntryReader
from megatron.energon.source_info import SourceInfo
class DecodeFileStore(FileStore[Any]):
    """Wraps another FileStore and applies a decoder to every item on access."""

    def __init__(
        self,
        inner_reader: FileStore[bytes],
        *,
        decoder: FileStoreDecoder,
    ):
        """
        Construct the decoding wrapper.

        Args:
            inner_reader: The FileStore to wrap.
            decoder: The decoder to apply to every item read from the FileStore.
        """
        self.inner_reader = inner_reader
        self.decoder = decoder

    def __getitem__(self, fname: str) -> tuple[Any, SourceInfo]:
        """Read `fname` from the wrapped store and return (decoded data, source info)."""
        raw_data, source_info = self.inner_reader[fname]
        decoded = self.decoder.decode(fname, raw_data)
        return decoded, source_info

    def get_path(self) -> str:
        """Forward the dataset path of the wrapped store."""
        return self.inner_reader.get_path()

    def __str__(self):
        return f"DecodeFileStore(inner_reader={self.inner_reader}, decoder={self.decoder})"
class SystemFileStore(FileStore[bytes]):
    """A FileStore that reads files directly from the file system."""

    def __init__(self, base_dir: Optional[Union[EPath, str]] = None):
        """
        Construct the file-system store.

        Args:
            base_dir: The base directory to use for relative paths. If None, you should only pass
                absolute paths to __getitem__.
        """
        self.base_dir = None if base_dir is None else EPath(base_dir)

    def __getitem__(self, key: str) -> tuple[bytes, SourceInfo]:
        """Read the file at `key` (relative to base_dir, if set) and return its bytes."""
        # Resolve the key against the base directory when one was configured
        file_path = EPath(key) if self.base_dir is None else self.base_dir / key
        with file_path.open("rb") as f:
            contents = f.read()
        source_info = SourceInfo(
            dataset_path=self.base_dir,
            index=None,
            shard_name=None,
            file_names=(key,),
        )
        return contents, source_info

    def get_path(self) -> str:
        """Returns the path to the dataset."""
        return str(self.base_dir)

    def __str__(self):
        return f"SystemFileStore(base_dir={self.base_dir})"
class WebdatasetFileStore(SqliteITarEntryReader, FileStore[bytes]):
    """This dataset will directly read files from the dataset tar files from a prepared energon dataset."""

    def __init__(
        self,
        dataset_path: EPath,
    ):
        """
        Args:
            dataset_path: Path to the prepared energon (webdataset) dataset.
        """
        # key_is_full_entryname=True: keys passed to __getitem__ are full tar
        # entry names (inherited from SqliteITarEntryReader).
        super().__init__(base_path=dataset_path, key_is_full_entryname=True)

    def get_path(self) -> str:
        """Returns the path to the dataset."""
        return str(self.base_path)
class JsonlFileStore(IJsonlReader, FileStore[bytes]):
    """This dataset will directly read entries from a jsonl file."""

    def get_path(self) -> str:
        """Returns the path to the jsonl file (inherited from IJsonlReader)."""
        return str(self.jsonl_path)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment