Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
Bw-bestperf
Qwen2.5-14B-Instruct_dcu-megatron
Commits
f356f546
Commit
f356f546
authored
Feb 04, 2026
by
maming
Browse files
Initial commit
parents
Pipeline
#3339
canceled with stages
Changes
346
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2951 additions
and
0 deletions
+2951
-0
Megatron-Energon/pyproject.toml
Megatron-Energon/pyproject.toml
+110
-0
Megatron-Energon/scripts/license_headers.py
Megatron-Energon/scripts/license_headers.py
+125
-0
Megatron-Energon/src/megatron/energon/__init__.py
Megatron-Energon/src/megatron/energon/__init__.py
+206
-0
Megatron-Energon/src/megatron/energon/av/__init__.py
Megatron-Energon/src/megatron/energon/av/__init__.py
+19
-0
Megatron-Energon/src/megatron/energon/av/av_data.py
Megatron-Energon/src/megatron/energon/av/av_data.py
+20
-0
Megatron-Energon/src/megatron/energon/av/av_decoder.py
Megatron-Energon/src/megatron/energon/av/av_decoder.py
+732
-0
Megatron-Energon/src/megatron/energon/av/av_init.py
Megatron-Energon/src/megatron/energon/av/av_init.py
+42
-0
Megatron-Energon/src/megatron/energon/av/fastseek/__init__.py
...tron-Energon/src/megatron/energon/av/fastseek/__init__.py
+4
-0
Megatron-Energon/src/megatron/energon/av/fastseek/containers/__init__.py
...n/src/megatron/energon/av/fastseek/containers/__init__.py
+2
-0
Megatron-Energon/src/megatron/energon/av/fastseek/containers/matroska.py
...n/src/megatron/energon/av/fastseek/containers/matroska.py
+51
-0
Megatron-Energon/src/megatron/energon/av/fastseek/containers/mpeg.py
...ergon/src/megatron/energon/av/fastseek/containers/mpeg.py
+241
-0
Megatron-Energon/src/megatron/energon/av/fastseek/containers/probe.py
...rgon/src/megatron/energon/av/fastseek/containers/probe.py
+23
-0
Megatron-Energon/src/megatron/energon/av/fastseek/fastseek.py
...tron-Energon/src/megatron/energon/av/fastseek/fastseek.py
+143
-0
Megatron-Energon/src/megatron/energon/av/fastseek/keyframeinfo.py
...-Energon/src/megatron/energon/av/fastseek/keyframeinfo.py
+31
-0
Megatron-Energon/src/megatron/energon/av/utils.py
Megatron-Energon/src/megatron/energon/av/utils.py
+137
-0
Megatron-Energon/src/megatron/energon/bracecollapse.py
Megatron-Energon/src/megatron/energon/bracecollapse.py
+280
-0
Megatron-Energon/src/megatron/energon/cache/__init__.py
Megatron-Energon/src/megatron/energon/cache/__init__.py
+21
-0
Megatron-Energon/src/megatron/energon/cache/base.py
Megatron-Energon/src/megatron/energon/cache/base.py
+161
-0
Megatron-Energon/src/megatron/energon/cache/file_cache_pool.py
...ron-Energon/src/megatron/energon/cache/file_cache_pool.py
+506
-0
Megatron-Energon/src/megatron/energon/cache/file_store.py
Megatron-Energon/src/megatron/energon/cache/file_store.py
+97
-0
No files found.
Too many changes to show.
To preserve performance only
346 of 346+
files are displayed.
Plain diff
Email patch
Megatron-Energon/pyproject.toml
0 → 100644
View file @
f356f546
[build-system]
requires = ["hatchling", "hatch-vcs"]
build-backend = "hatchling.build"

[project]
name = "megatron-energon"
# Version is derived from VCS tags via hatch-vcs (see [tool.hatch.version]).
dynamic = ["version"]
authors = [
    { name="Lukas Vögtle", email="lvoegtle@nvidia.com" },
    { name="Philipp Fischer", email="pfischer@nvidia.com" },
]
description = "Megatron's multi-modal data loader"
readme = "README.md"
license = "BSD-3-Clause"
requires-python = ">=3.10"
classifiers = [
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "Operating System :: OS Independent",
]
dependencies = [
    "braceexpand",
    "click",
    "dataslots; python_version<'3.10'",
    "mfusepy",
    "multi-storage-client>=0.18.0,<0.26.0",
    "numpy",
    "pillow>=10.0.1",  # WEBP vulnerability fixed starting from 10.0.1
    "pyyaml",
    "rapidyaml>=0.10.0",
    "s3fs",
    "torch",
    "tqdm",
    "webdataset",
]

[project.optional-dependencies]
dev = [
    "ruff",
    "sphinxcontrib-napoleon",
    "sphinx",
    "myst-parser",
    "soundfile",
    "sphinx-rtd-theme",
    "sphinx-click",
]
transforms = [
    "torchvision",  # Needed for megatron.energon.transforms
]
# Storage services for MSC
s3 = [
    "multi-storage-client[boto3]",
]
aistore = [
    "multi-storage-client[aistore]",
]
azure-storage-blob = [
    "multi-storage-client[azure-storage-blob]",
]
google-cloud-storage = [
    "multi-storage-client[google-cloud-storage]",
]
oci = [
    "multi-storage-client[oci]",
]
# Dependencies for video decoding
av_decode = [
    # needed for efficient audio and video file decoding
    "bitstring>=4.2.3",
    "sortedcontainers>=2.4.0",
    "filetype>=1.2.0",
    "ebmlite>=3.3.1",
    "av>=14.4.0",
]
# If using guess_content=True for decoding
guess_content = [
    "filetype>=1.0.0",
]

[project.urls]
Homepage = "https://github.com/NVIDIA/Megatron-Energon"

[tool.hatch.build.targets.wheel]
packages = ["src/megatron"]

[tool.hatch.build.targets.sdist]
packages = ["src/megatron"]

[project.scripts]
energon = "megatron.energon.cli.main:main"

[tool.hatch.version]
source = "vcs"

[tool.ruff.lint]
extend-select = ["I"]
ignore = ["E741", "E731"]

[tool.ruff]
line-length = 100
target-version = "py310"
include = ["**/*.py", "**/*.pyi"]
exclude = [
    ".idea",
    "docs",
]
Megatron-Energon/scripts/license_headers.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
dataclasses
import
dataclass
from
pathlib
import
Path
from
typing
import
Optional
,
Tuple
import
click
@dataclass
class HeaderUpdater:
    """Checks and rewrites the license header for one file type.

    Configure either ``line_comment`` (per-line comment prefix, e.g. ``"# "``)
    or both ``comment_start`` and ``comment_end`` (block comment markers).
    """

    file_ext: str
    line_comment: Optional[str] = None
    comment_start: Optional[str] = None
    comment_end: Optional[str] = None

    # A file whose first comment starts with this word is treated as already
    # carrying a (possibly outdated) header that should be replaced in place.
    UPDATE_IDENTIFIER = "Copyright"
    HEADER_LINES: Tuple[str, ...] = (
        "Copyright (c) 2025, NVIDIA CORPORATION.",
        "SPDX-License-Identifier: BSD-3-Clause",
    )

    # Fully rendered header lines (including comment markers), built in __post_init__.
    _expected_lines: Tuple[str, ...] = ()

    def __post_init__(self):
        if self.line_comment is not None:
            # Line-comment style: prefix every header line.
            self._expected_lines = tuple(f"{self.line_comment}{text}" for text in self.HEADER_LINES)
            return
        assert self.comment_start is not None and self.comment_end is not None
        if len(self.HEADER_LINES) >= 2:
            # Block-comment style spanning several lines: open on the first
            # line, close on the last, keep the middle lines bare.
            self._expected_lines = (
                self.comment_start + self.HEADER_LINES[0],
                *self.HEADER_LINES[1:-1],
                self.HEADER_LINES[-1] + self.comment_end,
            )
        else:
            assert len(self.HEADER_LINES) == 1
            # Single-line block comment.
            self._expected_lines = (
                self.comment_start + self.HEADER_LINES[0] + self.comment_end,
            )

    def has_header(self, file: Path) -> bool:
        """Return True iff *file* begins with exactly the expected header lines."""
        matched = 0
        with file.open() as handle:
            for actual, wanted in zip(handle, self._expected_lines):
                matched += 1
                if actual.rstrip("\n") != wanted:
                    return False
        # Also fails when the file is shorter than the header.
        return matched == len(self._expected_lines)

    def fix_header(self, file: Path):
        """Prepend the header to *file*, or replace an existing outdated one."""
        text = file.read_text()
        header_block = "\n".join(self._expected_lines) + "\n"
        prefix = self.comment_start if self.line_comment is None else self.line_comment
        if text.startswith(prefix) and text[len(prefix):].startswith(self.UPDATE_IDENTIFIER):
            # Already has a header: drop as many leading lines as the new
            # header has and splice the new header in front of the rest.
            rest = text.split("\n", len(self._expected_lines))[-1]
            file.write_text(header_block + rest)
        else:
            # No header yet: simply prepend it.
            file.write_text(header_block + text)
# One updater per supported file type. Checked/fixed in order by main().
headers = (
    HeaderUpdater(
        file_ext=".py",
        line_comment="# ",
    ),
    HeaderUpdater(
        file_ext=".sh",
        line_comment="# ",
    ),
    # Do not add to yamls
    # HeaderUpdater(
    #     file_ext=".yml",
    #     line_comment="# ",
    # ),
    # HeaderUpdater(
    #     file_ext=".yaml",
    #     line_comment="# ",
    # ),
    HeaderUpdater(
        file_ext=".md",
        comment_start="<!--- ",
        comment_end=" -->",
    ),
)
@click.command()
@click.argument(
    "path",
    type=click.Path(
        exists=True,
        file_okay=False,
        dir_okay=True,
        readable=True,
        path_type=Path,
    ),
)
@click.option(
    "--fix", is_flag=True, help="Automatically add missing/adapt existing license headers."
)
def main(path: Path, fix: bool) -> None:
    """Check license headers of all matching files under PATH; repair them with --fix."""
    # Get all files to check
    success = True
    for header_update in headers:
        # Recursively visit every file with this updater's extension.
        for filepath in path.rglob(f"*{header_update.file_ext}"):
            if not header_update.has_header(filepath):
                # NOTE: also reached when a header exists but is outdated.
                success = False
                print(f"Missing header in {filepath}")
                if fix:
                    header_update.fix_header(filepath)
                    print(f"Added header to {filepath}")
    if success:
        print("All files have the correct license header.")
    elif fix:
        print("Some license headers were fixed.")
    else:
        print(
            "Some license headers are missing or incorrect. Run `python scripts/license_headers.py . --fix` to fix."
        )
        # Non-zero exit code so CI fails when headers are missing.
        exit(1)


if __name__ == "__main__":
    main()
Megatron-Energon/src/megatron/energon/__init__.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
try:
    from importlib.metadata import version

    # Resolve the installed distribution's version from package metadata.
    __version__ = version("megatron-energon")
except Exception:
    # Fallback for development mode when package is not installed
    __version__ = "unknown"
from
megatron.energon.cache
import
(
CachePool
,
DecodeFileStore
,
DirectLazy
,
FileCacheLazy
,
FileStore
,
FileStoreCachePool
,
FileStoreDecoder
,
Lazy
,
MockLazy
,
NoCachePool
,
SystemFileStore
,
)
from
megatron.energon.edataclass
import
edataclass
from
megatron.energon.flavors
import
(
BaseCoreDatasetFactory
,
BaseWebdatasetFactory
,
CaptioningSample
,
CaptioningWebdataset
,
CrudeJsonlDatasetFactory
,
CrudeSample
,
CrudeWebdataset
,
DefaultCrudeJsonlDatasetFactory
,
DefaultDecoderWebdatasetFactory
,
DefaultGenericWebdatasetFactory
,
ImageClassificationSample
,
ImageClassificationWebdataset
,
ImageSample
,
ImageWebdataset
,
InterleavedSample
,
InterleavedWebdataset
,
JoinedWebdatasetFactory
,
MultiChoiceVQASample
,
MultiChoiceVQAWebdataset
,
OCRSample
,
OCRWebdataset
,
Sample
,
SampleDecoder
,
SavableDataset
,
SimilarityInterleavedSample
,
SimilarityInterleavedWebdataset
,
StandardWebdatasetFactory
,
TextSample
,
TextWebdataset
,
VidQASample
,
VidQAWebdataset
,
VQAOCRWebdataset
,
VQASample
,
VQAWebdataset
,
)
from
megatron.energon.loader
import
get_loader
,
get_savable_loader
from
megatron.energon.metadataset
import
(
DatasetLoader
,
DatasetLoaderInterface
,
Metadataset
,
MetadatasetV2
,
load_dataset
,
prepare_metadataset
,
)
from
megatron.energon.savable_loader
import
SavableDataLoader
from
megatron.energon.source_info
import
SourceInfo
from
megatron.energon.task_encoder
import
(
AugmentTaskEncoder
,
Batch
,
Cooker
,
DefaultTaskEncoder
,
TaskEncoder
,
basic_sample_keys
,
batch_list
,
batch_pad_stack
,
batch_stack
,
cooker
,
generic_batch
,
get_train_dataset
,
get_val_dataset
,
get_val_datasets
,
stateless
,
)
from
megatron.energon.worker
import
WorkerConfig
from
megatron.energon.wrappers
import
(
BatchDataset
,
BlendDataset
,
ConcatDataset
,
EpochizeDataset
,
FilterDataset
,
GcDataset
,
GroupBatchDataset
,
IterMapDataset
,
LimitDataset
,
LogSampleDataset
,
MapDataset
,
MixBatchDataset
,
PackingDataset
,
RepeatDataset
,
ShuffleBufferDataset
,
SkipSample
,
concat_pad
,
generic_concat
,
homogeneous_concat_mix
,
)
# Explicit public API of megatron.energon, re-exporting the names imported above.
__all__ = [
    "__version__",
    "AugmentTaskEncoder",
    "BaseCoreDatasetFactory",
    "BaseWebdatasetFactory",
    "basic_sample_keys",
    "batch_list",
    "batch_pad_stack",
    "batch_stack",
    "Batch",
    "BatchDataset",
    "BlendDataset",
    "CachePool",
    "CaptioningSample",
    "CaptioningWebdataset",
    "concat_pad",
    "ConcatDataset",
    "cooker",
    "Cooker",
    "CrudeJsonlDatasetFactory",
    "CrudeSample",
    "CrudeWebdataset",
    "DatasetLoader",
    "DatasetLoaderInterface",
    "DecodeFileStore",
    "DefaultCrudeJsonlDatasetFactory",
    "DefaultDecoderWebdatasetFactory",
    "DefaultGenericWebdatasetFactory",
    "DefaultTaskEncoder",
    "DirectLazy",
    "edataclass",
    "EpochizeDataset",
    "FileCacheLazy",
    "FileStore",
    "FileStoreCachePool",
    "FileStoreDecoder",
    "FilterDataset",
    "GcDataset",
    "generic_batch",
    "generic_concat",
    "get_loader",
    "get_savable_loader",
    "get_train_dataset",
    "get_val_dataset",
    "get_val_datasets",
    "GroupBatchDataset",
    "homogeneous_concat_mix",
    "ImageClassificationSample",
    "ImageClassificationWebdataset",
    "ImageSample",
    "ImageWebdataset",
    "InterleavedSample",
    "InterleavedWebdataset",
    "IterMapDataset",
    "JoinedWebdatasetFactory",
    "Lazy",
    "LimitDataset",
    "load_dataset",
    "LogSampleDataset",
    "MapDataset",
    "Metadataset",
    "MetadatasetV2",
    "MixBatchDataset",
    "MockLazy",
    "MultiChoiceVQASample",
    "MultiChoiceVQAWebdataset",
    "NoCachePool",
    "OCRSample",
    "OCRWebdataset",
    "PackingDataset",
    "prepare_metadataset",
    "RepeatDataset",
    "Sample",
    "SampleDecoder",
    "SavableDataLoader",
    "SavableDataset",
    "ShuffleBufferDataset",
    "SimilarityInterleavedSample",
    "SimilarityInterleavedWebdataset",
    "SkipSample",
    "SourceInfo",
    "StandardWebdatasetFactory",
    "stateless",
    "SystemFileStore",
    "TaskEncoder",
    "TextSample",
    "TextWebdataset",
    "VidQASample",
    "VidQAWebdataset",
    "VQAOCRWebdataset",
    "VQASample",
    "VQAWebdataset",
    "WorkerConfig",
]
Megatron-Energon/src/megatron/energon/av/__init__.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
megatron.energon.av.av_data
import
AVData
from
megatron.energon.av.av_decoder
import
(
AVDecoder
,
AVMetadata
,
AVWebdatasetDecoder
,
)
from
megatron.energon.av.utils
import
get_clips_uniform
,
get_single_frames_uniform
# Public API of the megatron.energon.av subpackage.
__all__ = [
    "AVDecoder",
    "AVMetadata",
    "AVWebdatasetDecoder",
    "AVData",
    "get_clips_uniform",
    "get_single_frames_uniform",
]
Megatron-Energon/src/megatron/energon/av/av_data.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
typing
import
Optional
import
torch
from
megatron.energon.edataclass
import
edataclass
@edataclass
class AVData:
    """Container for decoded audio/video clips and their time ranges, as returned by AVDecoder."""

    #: A list of video clips in the shape (frames, channels, h, w)
    video_clips: Optional[list[torch.Tensor]]
    #: The timestamps for the video clips. List of tuples (start, end) in seconds
    video_timestamps: Optional[list[tuple[float, float]]]
    #: A list of audio clips in the shape (channels, samples)
    audio_clips: Optional[list[torch.Tensor]]
    #: The timestamps for the audio clips. List of tuples (start, end) in seconds
    audio_timestamps: Optional[list[tuple[float, float]]]
Megatron-Energon/src/megatron/energon/av/av_decoder.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
io
import
warnings
from
collections.abc
import
Iterator
from
fractions
import
Fraction
from
typing
import
BinaryIO
,
Literal
,
Optional
,
Sequence
,
Union
,
overload
import
numpy
as
np
import
torch
from
megatron.energon.av.av_init
import
av_open
from
megatron.energon.edataclass
import
edataclass
# Try importing optional dependencies
try:
    import av
    import av.container
    import filetype

    from megatron.energon.av.fastseek import Fastseek

    AV_DECODE_AVAILABLE = True
except ImportError as e:
    # Decoding support is optional; AVDecoder.__init__ raises a helpful
    # ImportError (including MISSING_DEPENDENCY) if it is used without them.
    AV_DECODE_AVAILABLE = False
    MISSING_DEPENDENCY = str(e)
from
megatron.energon.av.av_data
import
AVData
class AVDecoder:
    """A class that provides a flexible interface for decoding audio and video data.

    This class allows users to control decoding parameters at runtime rather than having them fixed
    during initialization. It's particularly useful for cases where different samples may need different
    decoding parameters.
    """

    #: Keyframe/seek helper built over the stream (see __init__).
    seeker: "Fastseek"
    #: The underlying seekable binary stream holding the media container.
    stream: BinaryIO
    #: If True, unit-conversion warnings during clip seeking are silenced.
    suppress_warnings: bool
    def __init__(self, stream: BinaryIO, suppress_warnings: bool = False) -> None:
        """Wrap a seekable binary stream for audio/video decoding.

        Args:
            stream: Seekable binary stream containing the media container. Must not
                be opened in text mode.
            suppress_warnings: If True, suppress unit-conversion warnings during seeking.

        Raises:
            ImportError: If the optional av_decode dependencies are not installed.
        """
        if not AV_DECODE_AVAILABLE:
            raise ImportError(
                f"AV decoding is not available. Please install the required dependencies with:\n"
                f"pip install megatron-energon[av_decode]\n"
                f"Missing dependency: {MISSING_DEPENDENCY}. Install megatron-energon[av_decode] to use AVDecoder."
            )
        self.stream = stream
        self.suppress_warnings = suppress_warnings
        # Reject text-mode streams: decoding requires raw bytes.
        assert "t" not in getattr(stream, "mode", "rb") and not isinstance(
            stream, io.TextIOBase
        ), ("Stream must not be opened in text mode")
        try:
            self.seeker = Fastseek(self.stream)
        except ValueError:
            # First attempt failed; retry with probing after rewinding the stream.
            self.stream.seek(0)
            self.seeker = Fastseek(self.stream, probe=True)
        # Leave the stream rewound for subsequent decoding calls.
        self.stream.seek(0)
def
get_video
(
self
)
->
AVData
:
"""Get the entire video data from the stream (without audio)."""
video_clips
,
video_timestamps
=
self
.
get_video_clips
(
video_clip_ranges
=
[(
0
,
float
(
"inf"
))])
return
AVData
(
video_clips
=
video_clips
,
video_timestamps
=
video_timestamps
,
audio_clips
=
[],
audio_timestamps
=
[],
)
    def get_video_clips(
        self,
        video_clip_ranges: Sequence[tuple[float, float]],
        video_unit: Literal["frames", "seconds"] = "seconds",
        video_out_frame_size: Optional[tuple[int, int]] = None,
    ) -> tuple[list[torch.Tensor], list[tuple[float, float]]]:
        """Get video clips from the video stream.

        Args:
            video_clip_ranges: List of video clip start and end positions in the given unit (see video_unit)
            video_unit: Unit of the video clip positions ("frames" for frame number, "seconds" for timestamp)
            video_out_frame_size: Output size for video frames (width, height), or None to use the original frame size

        Returns:
            A tuple containing:
                - video_clips: List of video clips, each of shape (frames, channels, h, w)
                - video_clips_timestamps: List of timestamps for each video clip start and end in seconds
        """
        assert video_unit in ("frames", "seconds")
        self.stream.seek(0)  # Reset the video stream so that pyav can read the entire container
        with av_open(self.stream) as input_container:
            assert len(input_container.streams.video) > 0, (
                "No video stream found, but video_clips are requested"
            )
            video_stream = input_container.streams.video[0]
            # Pre-calculate timing info for video
            average_rate: Fraction = video_stream.average_rate  # Frames per second
            assert average_rate, "Video stream has no FPS."
            time_base: Fraction = video_stream.time_base  # Seconds per PTS unit
            if video_clip_ranges is not None:
                # Convert video_clip_ranges to seeker unit (the unit the container
                # natively seeks in: "frames" or "pts").
                if video_unit == "frames" and self.seeker.unit == "pts":
                    # Convert from frames to pts units
                    video_clip_ranges = [
                        (
                            clip[0] / average_rate / time_base,
                            clip[1] / average_rate / time_base,
                        )
                        for clip in video_clip_ranges
                    ]
                    if not self.suppress_warnings:
                        warnings.warn(
                            "Video container unit is frames, but seeking in time units. The resulting frames may be slightly off.",
                            RuntimeWarning,
                        )
                elif video_unit == "seconds" and self.seeker.unit == "frames":
                    # Convert from seconds to frames
                    video_clip_ranges = [
                        (
                            clip[0] * average_rate,
                            clip[1] * average_rate,
                        )
                        for clip in video_clip_ranges
                    ]
                    if not self.suppress_warnings:
                        warnings.warn(
                            "Video container unit is time units, but seeking using frame number. The resulting frames may be slightly off.",
                            RuntimeWarning,
                        )
                elif video_unit == "seconds" and self.seeker.unit == "pts":
                    # Convert from seconds to pts units
                    video_clip_ranges = [
                        (clip[0] / time_base, clip[1] / time_base) for clip in video_clip_ranges
                    ]
            frame_iterator: Iterator[av.VideoFrame] = input_container.decode(video=0)
            previous_frame_index: int = 0
            video_clips_frames: list[list[torch.Tensor]] = []
            video_clips_timestamps: list[tuple[float, float]] = []
            # NOTE(review): ranges appear to be expected in ascending order — the
            # frame iterator is shared across clips and only moves forward.
            for video_clip_range in video_clip_ranges:
                start_frame_index, end_frame_index = video_clip_range
                # Convert to int if possible, set end to None if infinite
                start_frame_index = int(start_frame_index)
                end_frame_index = int(end_frame_index) if end_frame_index != float("inf") else None
                clip_frames: list[torch.Tensor] = []
                clip_timestamp_start = None
                clip_timestamp_end = None
                # Find start frame: ask the seeker whether jumping to a keyframe
                # is cheaper than decoding forward from the current position.
                if (
                    iframe_info := self.seeker.should_seek(previous_frame_index, start_frame_index)
                ) is not None:
                    input_container.seek(iframe_info.pts, stream=input_container.streams.video[0])
                    previous_frame_index = iframe_info.index
                for frame in frame_iterator:
                    take_frame = False
                    last_frame = False
                    # Container uses frame counts, we can find the exact target frame by counting from the iframe which is at a known offset
                    if self.seeker.unit == "frames":
                        if previous_frame_index >= start_frame_index:
                            take_frame = True
                        if end_frame_index is not None and previous_frame_index >= end_frame_index:
                            last_frame = True
                    # Container uses time, the target frame might not correspond exactly to any metadata but the desired timestamp should
                    # fall within a frames display period
                    if self.seeker.unit == "pts":
                        if start_frame_index <= (frame.pts + frame.duration):
                            take_frame = True
                        if end_frame_index is not None and end_frame_index <= (
                            frame.pts + frame.duration
                        ):
                            last_frame = True
                    if take_frame:
                        # Convert to RGB (and optionally resize) before collecting.
                        if video_out_frame_size is not None:
                            frame = frame.reformat(
                                width=video_out_frame_size[0],
                                height=video_out_frame_size[1],
                                format="rgb24",
                                interpolation="BILINEAR",
                            )
                        else:
                            frame = frame.reformat(format="rgb24")
                        clip_frames.append(torch.from_numpy(frame.to_ndarray()))
                        if clip_timestamp_start is None:
                            clip_timestamp_start = float(frame.pts * frame.time_base)
                        clip_timestamp_end = float((frame.pts + frame.duration) * frame.time_base)
                    previous_frame_index += 1
                    if last_frame:
                        break
                if clip_timestamp_start is not None and clip_timestamp_end is not None:
                    video_clips_frames.append(clip_frames)
                    video_clips_timestamps.append((clip_timestamp_start, clip_timestamp_end))
        # Stack frames within each clip
        out_video_clips = [
            torch.stack(clip_frames).permute((0, 3, 1, 2)) for clip_frames in video_clips_frames
        ]
        return out_video_clips, video_clips_timestamps
def
get_audio
(
self
)
->
AVData
:
"""Get the entire audio data from the stream."""
audio_clips
,
audio_timestamps
=
self
.
get_audio_clips
(
audio_clip_ranges
=
[(
0
,
float
(
"inf"
))])
return
AVData
(
video_clips
=
[],
video_timestamps
=
[],
audio_clips
=
audio_clips
,
audio_timestamps
=
audio_timestamps
,
)
def
get_audio_clips
(
self
,
audio_clip_ranges
:
Sequence
[
tuple
[
float
,
float
]],
audio_unit
:
Literal
[
"samples"
,
"seconds"
]
=
"seconds"
,
)
->
tuple
[
list
[
torch
.
Tensor
],
list
[
tuple
[
float
,
float
]]]:
"""Get audio clips from the audio stream.
Args:
audio_clip_ranges: List of audio clip start and end positions in the given unit (see audio_unit)
audio_unit: Unit of the audio clip positions ("samples" for sample number, "seconds" for timestamp)
Returns:
A tuple containing:
- audio_clips: List of audio clips
- audio_clips_timestamps: List of timestamps for each audio clip start and end in seconds
"""
assert
audio_unit
in
(
"samples"
,
"seconds"
)
self
.
stream
.
seek
(
0
)
# Reset the video stream so that pyav can read the entire container
with
av_open
(
self
.
stream
)
as
input_container
:
assert
len
(
input_container
.
streams
.
audio
)
>
0
,
(
"No audio stream found, but audio_clips are requested"
)
audio_stream
=
input_container
.
streams
.
audio
[
0
]
audio_sample_rate
=
audio_stream
.
sample_rate
assert
audio_sample_rate
,
"Audio streams without sample rate are not supported"
if
audio_unit
==
"samples"
:
# Convert from samples to seconds
audio_clip_ranges
=
[
(
float
(
clip
[
0
]
/
audio_sample_rate
),
float
(
clip
[
1
]
/
audio_sample_rate
),
)
for
clip
in
audio_clip_ranges
]
out_audio_clips
:
list
[
torch
.
Tensor
]
=
[]
out_audio_clips_timestamps
:
list
[
tuple
[
float
,
float
]]
=
[]
def
audio_frame_array
(
frame
:
av
.
AudioFrame
)
->
np
.
ndarray
:
if
frame
.
format
.
is_planar
:
arr_processed
=
frame
.
to_ndarray
()
# Already (channels, samples)
else
:
# Calculate the number of channels and samples
channels
=
int
(
frame
.
layout
.
nb_channels
)
samples
=
int
(
frame
.
samples
)
# Reshape the interleaved data to (samples, channels), then transpose to (channels, samples)
arr_processed
=
np
.
reshape
(
frame
.
to_ndarray
(),
(
samples
,
channels
)).
transpose
(
1
,
0
)
return
arr_processed
for
start_time
,
end_time
in
audio_clip_ranges
:
# Seek near start time, but rounded down to the nearest frame
input_container
.
seek
(
int
(
start_time
*
av
.
time_base
))
if
end_time
!=
float
(
"inf"
):
desired_duration
=
end_time
-
start_time
desired_sample_count
=
int
(
desired_duration
*
audio_sample_rate
+
0.5
)
else
:
desired_sample_count
=
None
clip_start_time
=
None
clip_end_time
=
None
decoded_samples
=
[]
decoded_sample_count
=
0
previous_frame
=
None
for
frame
in
input_container
.
decode
(
audio
=
0
):
assert
frame
.
pts
is
not
None
,
"Audio frame has no PTS timestamp"
cur_frame_time
=
float
(
frame
.
pts
*
frame
.
time_base
)
cur_frame_duration
=
float
(
frame
.
duration
*
frame
.
time_base
)
if
cur_frame_time
<
start_time
:
# Skip frames before the start time
previous_frame
=
frame
continue
if
clip_start_time
is
None
:
# This is our first matching frame
if
previous_frame
is
not
None
:
# We have a previous frame that we need to crop to the start time
prev_start_time
=
float
(
previous_frame
.
pts
*
previous_frame
.
time_base
)
prev_frame_array
=
audio_frame_array
(
previous_frame
)
prev_frame_array
=
prev_frame_array
[
:,
int
((
start_time
-
prev_start_time
)
*
audio_sample_rate
+
0.5
)
:
]
decoded_samples
.
append
(
prev_frame_array
)
decoded_sample_count
+=
prev_frame_array
.
shape
[
1
]
clip_start_time
=
start_time
clip_end_time
=
prev_start_time
+
cur_frame_duration
else
:
clip_start_time
=
cur_frame_time
# Stop decoding if the end of the frame is past the end time
if
cur_frame_time
+
cur_frame_duration
>=
end_time
:
# Crop the last frame to the end time
last_frame_array
=
audio_frame_array
(
frame
)
additional_samples
=
int
(
(
end_time
-
cur_frame_time
)
*
audio_sample_rate
+
0.5
)
projected_total_samples
=
decoded_sample_count
+
additional_samples
projected_total_samples
=
decoded_sample_count
+
additional_samples
if
(
desired_sample_count
is
not
None
and
0
<
abs
(
projected_total_samples
-
desired_sample_count
)
<
2
):
# We are within 2 samples of the desired duration, let's adjust
# the last frame so that we get the desired duration
additional_samples
=
desired_sample_count
-
decoded_sample_count
last_frame_array
=
last_frame_array
[:,
:
additional_samples
]
decoded_samples
.
append
(
last_frame_array
)
decoded_sample_count
+=
last_frame_array
.
shape
[
1
]
clip_end_time
=
end_time
break
frame_nd
=
audio_frame_array
(
frame
)
# (channels, samples)
decoded_samples
.
append
(
frame_nd
)
decoded_sample_count
+=
frame_nd
.
shape
[
1
]
clip_end_time
=
cur_frame_time
+
cur_frame_duration
if
decoded_samples
:
# Combine all channels/samples along samples axis
clip_all
=
np
.
concatenate
(
decoded_samples
,
axis
=-
1
)
# (channels, total_samples)
if
clip_start_time
is
not
None
and
clip_end_time
is
not
None
:
out_audio_clips
.
append
(
torch
.
from_numpy
(
clip_all
))
out_audio_clips_timestamps
.
append
((
clip_start_time
,
clip_end_time
))
return
out_audio_clips
,
out_audio_clips_timestamps
def
get_video_with_audio
(
self
)
->
AVData
:
"""Get the entire video and audio data from the stream."""
return
self
.
get_clips
(
video_clip_ranges
=
[(
0
,
float
(
"inf"
))],
audio_clip_ranges
=
[(
0
,
float
(
"inf"
))],
video_unit
=
"seconds"
,
audio_unit
=
"seconds"
,
)
def
get_clips
(
self
,
video_clip_ranges
:
Optional
[
Sequence
[
tuple
[
float
,
float
]]]
=
None
,
audio_clip_ranges
:
Optional
[
Sequence
[
tuple
[
float
,
float
]]]
=
None
,
video_unit
:
Literal
[
"frames"
,
"seconds"
]
=
"seconds"
,
audio_unit
:
Literal
[
"samples"
,
"seconds"
]
=
"seconds"
,
video_out_frame_size
:
Optional
[
tuple
[
int
,
int
]]
=
None
,
)
->
AVData
:
"""Get clips from the video and/or audio streams.
Given a list of (start, end) tuples, this method will decode the video and/or audio clips
at the specified start and end times. The units of the start and end times are specified by
the `video_unit` and `audio_unit` arguments.
Args:
video_clip_ranges: List of video clip start and end positions in the given unit (see video_unit)
audio_clip_ranges: List of audio clip start and end positions in the given unit (see audio_unit)
video_unit: Unit of the video clip positions ("frames" for frame number, "seconds" for timestamp)
audio_unit: Unit of the audio clip positions ("samples" for sample number, "seconds" for timestamp)
video_out_frame_size: Output size for video frames (width, height), or None to use the original frame size
Returns:
AVData containing the decoded video and audio clips
"""
if
video_clip_ranges
is
not
None
:
ret_video_clips
,
ret_video_clips_timestamps
=
self
.
get_video_clips
(
video_clip_ranges
,
video_unit
,
video_out_frame_size
)
else
:
ret_video_clips
=
[]
ret_video_clips_timestamps
=
[]
if
audio_clip_ranges
is
not
None
:
ret_audio_clips
,
ret_audio_clips_timestamps
=
self
.
get_audio_clips
(
audio_clip_ranges
,
audio_unit
)
else
:
ret_audio_clips
=
[]
ret_audio_clips_timestamps
=
[]
return
AVData
(
video_clips
=
ret_video_clips
,
video_timestamps
=
ret_video_clips_timestamps
,
audio_clips
=
ret_audio_clips
,
audio_timestamps
=
ret_audio_clips_timestamps
,
)
def
get_frames
(
self
,
video_decode_audio
:
bool
=
False
,
)
->
Optional
[
AVData
]:
"""Decode the entire audio/video data and return an AVData object.
Args:
video_decode_audio: Whether to decode audio from video
Returns:
VideoData containing the decoded frames and metadata, or None if decoding failed
The video tensor is in the shape (frames, channels, height, width)
The audio tensor is in the shape (channels, samples)
"""
extension
=
self
.
_get_extension
()
if
extension
is
not
None
:
extension
=
extension
.
lower
()
if
extension
in
(
"mov"
,
"mp4"
,
"webm"
,
"mkv"
,
"avi"
,
"m4v"
):
if
video_decode_audio
:
return
self
.
get_video_with_audio
()
else
:
return
self
.
get_video
()
elif
extension
in
(
"flac"
,
"mp3"
,
"wav"
):
return
self
.
get_audio
()
else
:
return
None
def
_get_extension
(
self
)
->
Optional
[
str
]:
"""Get the file extension from the raw data."""
# Try to guess the file type using the first few bytes
self
.
stream
.
seek
(
0
)
# Reset stream position before guessing
ftype
=
filetype
.
guess
(
self
.
stream
)
if
ftype
is
None
:
return
None
return
ftype
.
extension
def
get_video_fps
(
self
)
->
float
:
"""Get the FPS of the video stream."""
metadata
=
self
.
get_metadata
(
get_video
=
True
,
get_video_duration
=
False
,
get_video_frame_count
=
False
,
get_video_frame_size
=
False
,
get_audio
=
False
,
)
assert
metadata
.
video_fps
is
not
None
return
metadata
.
video_fps
def
get_audio_samples_per_second
(
self
)
->
int
:
"""Get the number of samples per second of the audio stream."""
metadata
=
self
.
get_metadata
(
get_video
=
False
,
get_audio
=
True
,
get_audio_duration
=
False
,
)
assert
metadata
.
audio_sample_rate
is
not
None
return
metadata
.
audio_sample_rate
def
has_audio_stream
(
self
)
->
bool
:
"""Check if the stream has an audio stream."""
self
.
stream
.
seek
(
0
)
with
av_open
(
self
.
stream
)
as
input_container
:
return
len
(
input_container
.
streams
.
audio
)
>
0
def
has_video_stream
(
self
)
->
bool
:
"""Check if the stream has a video stream."""
self
.
stream
.
seek
(
0
)
with
av_open
(
self
.
stream
)
as
input_container
:
return
len
(
input_container
.
streams
.
video
)
>
0
def
get_audio_duration
(
self
)
->
Optional
[
float
]:
"""Get the duration of the audio stream.
Returns:
The duration of the audio stream in seconds
"""
metadata
=
self
.
get_metadata
(
get_video
=
False
,
get_audio
=
True
,
get_audio_duration
=
True
,
)
return
metadata
.
audio_duration
@overload
def get_video_duration(self, get_frame_count: Literal[True]) -> tuple[Optional[float], int]: ...

@overload
def get_video_duration(
    self, get_frame_count: bool = False
) -> tuple[Optional[float], Optional[int]]: ...

def get_video_duration(
    self, get_frame_count: bool = False
) -> tuple[Optional[float], Optional[int]]:
    """Return the duration of the video stream.

    Args:
        get_frame_count: Whether to also count the frames in the video; this is
            a more costly operation (may require demuxing the whole stream).

    Returns:
        Tuple of (duration in seconds or None, frame count or None).
    """
    meta = self.get_metadata(
        get_video=True,
        get_video_duration=True,
        get_video_frame_count=get_frame_count,
        get_video_frame_size=False,
        get_audio=False,
        get_audio_duration=False,
    )
    return meta.video_duration, meta.video_num_frames
def get_metadata(
    self,
    get_video: bool = True,
    get_video_duration: bool = True,
    get_video_frame_count: bool = True,
    get_video_frame_size: bool = True,
    get_audio: bool = True,
    get_audio_duration: bool = True,
) -> "AVMetadata":
    """Get the metadata of the media object.

    Only the first video stream and first audio stream are inspected.

    Args:
        get_video: Compute video metadata.
        get_video_duration: Compute video duration if not found in header.
        get_video_frame_count: Compute video frame count if not found in header.
        get_video_frame_size: Compute video frame size if not found in header.
        get_audio: Compute audio metadata.
        get_audio_duration: Compute audio duration if not found in header.

    Returns:
        An AVMetadata with the requested fields filled in; fields that were not
        requested or could not be determined stay None.
    """
    self.stream.seek(0)
    with av_open(self.stream) as input_container:
        metadata = AVMetadata()
        if get_video and input_container.streams.video:
            video_stream = input_container.streams.video[0]
            # Header duration is in stream time_base units here; converted to
            # seconds further below.
            metadata.video_duration = video_stream.duration
            if get_video_duration and metadata.video_duration is None:
                # If duration isn't found in header the whole video is decoded to
                # determine the duration. The frame count is collected in the
                # same pass for free.
                metadata.video_num_frames = 0
                last_packet = None
                for packet in input_container.demux(video=0):
                    if packet.pts is not None:
                        metadata.video_num_frames += 1
                        last_packet = packet
                if last_packet is not None and last_packet.duration is not None:
                    assert last_packet.pts is not None
                    # End of stream = start of last packet plus its duration.
                    metadata.video_duration = last_packet.pts + last_packet.duration
            if metadata.video_duration is not None:
                if video_stream.start_time is not None:
                    metadata.video_duration -= video_stream.start_time
                if video_stream.time_base is not None:
                    # Convert from time_base units to seconds.
                    metadata.video_duration *= float(video_stream.time_base)
            if get_video_frame_count and metadata.video_num_frames is None:
                # Frame count was not collected above; demux once more counting
                # packets with a valid pts.
                metadata.video_num_frames = sum(
                    1 for p in input_container.demux(video=0) if p.pts is not None
                )
            if video_stream.average_rate is not None:
                metadata.video_fps = float(video_stream.average_rate)
            elif metadata.video_num_frames is not None and metadata.video_duration is not None:
                # Fallback: derive fps from frame count and duration (seconds).
                metadata.video_fps = metadata.video_num_frames / metadata.video_duration
            if get_video_frame_size:
                input_container.seek(0)
                # Decode a single frame for the true size; fall back to the
                # header values if nothing decodes (for-else).
                for first_frame in input_container.decode(video=0):
                    metadata.video_width = first_frame.width
                    metadata.video_height = first_frame.height
                    break
                else:
                    metadata.video_width = video_stream.width
                    metadata.video_height = video_stream.height
        if get_audio and input_container.streams.audio:
            audio_stream = input_container.streams.audio[0]
            metadata.audio_sample_rate = audio_stream.sample_rate
            metadata.audio_duration = audio_stream.duration
            if get_audio_duration and metadata.audio_duration is None:
                # Duration missing from header: demux the audio stream and use
                # the last timestamped packet.
                last_packet = None
                input_container.seek(0)
                for packet in input_container.demux(audio=0):
                    if packet.pts is not None:
                        last_packet = packet
                if last_packet is not None and last_packet.duration is not None:
                    assert last_packet.pts is not None
                    metadata.audio_duration = last_packet.pts + last_packet.duration
            if metadata.audio_duration is not None:
                if audio_stream.start_time is not None:
                    metadata.audio_duration -= audio_stream.start_time
                if audio_stream.time_base is not None:
                    # Convert from time_base units to seconds.
                    metadata.audio_duration *= float(audio_stream.time_base)
            metadata.audio_channels = audio_stream.channels
    return metadata
def __repr__(self):
    """Return a debug representation naming the wrapped stream."""
    return "AVDecoder(stream={!r})".format(self.stream)
class AVWebdatasetDecoder:
    """A decoder class for audio and video data that provides a consistent interface for decoding media files.

    Encapsulates the decoding parameters and provides a callable interface usable
    with webdataset or other data loading pipelines.

    Args:
        video_decode_audio: Whether to decode audio from video files. If True, audio will be
            extracted alongside video frames.
        av_decode: If "AVDecoder", returns an AVDecoder instance for flexible decoding. If "torch",
            returns decoded VideoData. If "pyav", returns an opened av container.

    Example:
        >>> decoder = AVWebdatasetDecoder(
        ...     video_decode_audio=True,
        ...     av_decode="AVDecoder"
        ... )
        >>> result = decoder("video.mp4", video_bytes)
    """

    def __init__(
        self,
        video_decode_audio: bool,
        av_decode: Literal["torch", "AVDecoder", "pyav"] = "AVDecoder",
    ) -> None:
        self.video_decode_audio = video_decode_audio
        self.av_decode = av_decode

    def read_av_data(self, data: bytes) -> AVDecoder:
        """Wrap the raw media bytes in an AVDecoder for flexible decoding.

        Args:
            data: The raw bytes of the media file.

        Returns:
            AVDecoder object that can be used to decode the media with custom parameters.
        """
        return AVDecoder(io.BytesIO(data))

    def __call__(
        self, key: str, data: bytes
    ) -> Optional[
        Union[
            AVData,
            AVDecoder,
            "av.container.InputContainer",
            "av.container.OutputContainer",
        ]
    ]:
        """
        Extract the video or audio data from default media extensions.

        Args:
            key: media file extension
            data: raw media bytes

        Returns:
            If av_decode is "torch", returns VideoData containing the decoded frames and metadata.
            If av_decode is "AVDecoder", returns an AVDecoder instance for flexible decoding.
            If av_decode is "pyav", returns an av.container.InputContainer instance.
            Returns None if decoding failed or file type is not supported.
        """
        key = key.lower()
        supported = ("mp4", "avi", "mov", "webm", "mkv", "flac", "mp3", "wav", "flv")
        # Accept either a bare extension or a filename ending in ".<ext>".
        if not any(key == ext or key.endswith("." + ext) for ext in supported):
            return None
        decoder = self.read_av_data(data)
        if self.av_decode == "AVDecoder":
            return decoder
        if self.av_decode == "pyav":
            return av_open(decoder.stream)
        if self.av_decode == "torch":
            return decoder.get_frames(
                video_decode_audio=self.video_decode_audio,
            )
        raise ValueError(f"Invalid av_decode value: {self.av_decode}")
@edataclass
class AVMetadata:
    """Metadata of the media object."""

    # Video stream properties; None when the stream is absent or the field was
    # not requested / could not be determined (see AVDecoder.get_metadata).
    video_duration: Optional[float] = None  # Duration in seconds
    video_num_frames: Optional[int] = None  # Total number of frames
    video_fps: Optional[float] = None  # Average frames per second
    video_width: Optional[int] = None  # Frame width in pixels
    video_height: Optional[int] = None  # Frame height in pixels
    # Audio stream properties; same None convention as above.
    audio_duration: Optional[float] = None  # Duration in seconds
    audio_channels: Optional[int] = None  # Number of channels
    audio_sample_rate: Optional[int] = None  # Samples per second (Hz)
Megatron-Energon/src/megatron/energon/av/av_init.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
try
:
# Try importing optional dependencies
import
av
import
av.container
except
ImportError
:
pass
def av_open(file: str) -> "av.container.InputContainer":
    """Open a file with PyAV.

    This function is a wrapper around av.open that disables additional threads in the container.

    Args:
        file: Passed straight to ``av.open``. NOTE(review): callers elsewhere
            also pass file-like objects (e.g. ``io.BytesIO``) despite the
            ``str`` annotation — consider widening the annotation.

    Returns:
        The opened input container, with codec threading disabled.
    """
    input_container = av.open(file, "r")
    try:
        initialize_av_container(input_container)
    except Exception:
        # Don't leak the container if initialization fails.
        input_container.close()
        raise
    return input_container
def initialize_av_container(input_container: "av.container.InputContainer") -> None:
    """Every PyAV container should be initialized with this function.

    Disables all codec-context threading: extra ffmpeg threads can deadlock
    when deallocating the container, and we cannot have multiple threads
    before forking the process when using torch data loaders with multiple
    workers.
    """
    for stream in input_container.streams:
        codec_ctx = stream.codec_context
        if codec_ctx is None:
            continue
        codec_ctx.thread_type = "NONE"
        codec_ctx.thread_count = 0
Megatron-Energon/src/megatron/energon/av/fastseek/__init__.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
.fastseek
import
Fastseek
as
Fastseek
from
.keyframeinfo
import
KeyframeInfo
as
KeyframeInfo
Megatron-Energon/src/megatron/energon/av/fastseek/containers/__init__.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
Megatron-Energon/src/megatron/energon/av/fastseek/containers/matroska.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
collections
import
defaultdict
from
bitstring.bits
import
BitsType
from
ebmlite
import
MasterElement
,
loadSchema
from
sortedcontainers
import
SortedList
from
..keyframeinfo
import
KeyframeInfo
class CueTrackPositions:
    """Extracted ``CueTrackPositions`` EBML element: holds the cue's track number."""

    track: int

    def __init__(self, el: MasterElement) -> None:
        # Scan the children for the CueTrack entry (the last one wins).
        for child in el:
            if child.name == "CueTrack":
                self.track = child.value
class CuePoint:
    """Extracted ``CuePoint`` EBML element: cue time plus its track positions."""

    time: int
    track_positions: CueTrackPositions

    def __init__(self, el: MasterElement) -> None:
        # Pick up the CueTime and CueTrackPositions children as they appear.
        for child in el:
            if child.name == "CueTime":
                self.time = child.value
            if child.name == "CueTrackPositions":
                self.track_positions = CueTrackPositions(child)
def parse_matroska(file: BitsType) -> dict[int, SortedList]:
    """Index Matroska/WebM cue points as keyframes, grouped per track.

    Args:
        file: The raw container bytes (anything ebmlite's schema loader accepts).

    Returns:
        Mapping of track number -> SortedList[KeyframeInfo]; both index and pts
        of each KeyframeInfo are the cue time. (Fix: the previous return
        annotation claimed ``SortedList``, but a dict of SortedLists is
        returned, matching ``parse_mpeg``.)

    Raises:
        ValueError: If the EBML document cannot be parsed.
    """
    try:
        schema = loadSchema("matroska.xml")
        doc = schema.load(file, headers=True)
    except (KeyError, IOError, TypeError) as e:
        raise ValueError(f"Matroska parsing failed with error {e}")
    # Get cue times: depth-first walk starting at the Segment elements,
    # descending only into "Cues"/"CuePoint" children to avoid touching the
    # rest of the document.
    stack = [c for c in doc if c.name == "Segment"]
    cues = defaultdict(SortedList)
    while len(stack) > 0:
        el = stack.pop()
        if el.name == "CuePoint":
            cue = CuePoint(el)
            cues[cue.track_positions.track].add(KeyframeInfo(cue.time, cue.time))
        elif isinstance(el, MasterElement):
            stack.extend([c for c in el if c.name in ["Cues", "CuePoint"]])
    return cues
Megatron-Energon/src/megatron/energon/av/fastseek/containers/mpeg.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
collections
import
defaultdict
from
itertools
import
accumulate
from
typing
import
Any
,
Generator
from
bitstring
import
ConstBitStream
,
Error
from
bitstring.bits
import
BitsType
from
sortedcontainers
import
SortedList
from
..keyframeinfo
import
KeyframeInfo
# Atom types that are containers ("boxes") whose children are parsed rather
# than skipped (see Atom._parse).
box_atoms = {"moov", "trak", "mdia", "minf", "stbl", "edts"}  # Non-exhaustive
def parse_table(
    cbs: "ConstBitStream", table_size: int, struct: "dict[str, str]"
) -> "list[dict[str, Any]]":
    """Read ``table_size`` fixed-layout records from the bitstream.

    Fixes the previous return annotation, which claimed ``dict[str, Any]``
    although a list of such dicts is returned. The readlist format string is
    also built once instead of per record.

    Args:
        cbs: Bitstream positioned at the start of the table.
        table_size: Number of records to read.
        struct: Mapping of field name -> bitstring format token (e.g. "uint:32").

    Returns:
        One dict per record, mapping each field name to its decoded value.
    """
    fmt = ", ".join(struct.values())
    return [dict(zip(struct.keys(), cbs.readlist(fmt))) for _ in range(table_size)]
class Atom:
    """Base class for MPEG-4/QuickTime atoms; unknown atom payloads are skipped."""

    # Declared but unused here; subclasses could opt out of the version/flags skip.
    skip_version_and_flags: bool = False

    @staticmethod
    def make_atom(cbs: ConstBitStream) -> "Atom":
        """Read one atom header from the stream and dispatch to a subclass parser.

        Dispatch is by name: a subclass whose class name equals the upper-cased
        4-character atom type is used, otherwise the generic Atom (skip) parser.
        """
        size: int = cbs.read("uint:32")
        name: str = cbs.read("bytes:4").decode("ascii")
        box: bool = name in box_atoms
        if size == 0:
            raise RuntimeError(
                "MPEG parser detected a zero byte atom, this likely indicates a corrupt video."
            )
        subclass_list = [c for c in Atom.__subclasses__() if c.__name__ == name.upper()]
        atom_class: type = Atom
        if len(subclass_list) > 0:
            atom_class: type = subclass_list[0]
        cbs.bytepos += 4  # Skip version and flags TODO not every atom needs this
        atom = atom_class(size, name, box)
        atom._parse(cbs)
        return atom

    def __init__(self, size: int, name: str, box: bool) -> None:
        # size includes the 8-byte header (size + type fields).
        self.size: int = size
        self.name: str = name
        self.box: bool = box

    def _parse(self, cbs: ConstBitStream) -> None:
        """Generic payload handling: skip non-box atoms, descend into boxes.

        NOTE(review): 12 bytes have been consumed by make_atom (size, type,
        version/flags), but size - 8 is skipped here, which would overrun
        generic atoms by 4 bytes; MDAT compensates with its own size - 12.
        Confirm whether this is intended.
        """
        if not self.box:
            cbs.bytepos += self.size - 8

    def __str__(self) -> str:
        return f"{self.name=}, {self.size=}, {self.box=}"
class TKHD(Atom):
    """
    Parses the track header atom, see https://developer.apple.com/documentation/quicktime-file-format/track_header_atom
    """

    def _parse(self, cbs: ConstBitStream) -> None:
        cbs.bytepos += 8  # skip creation time and modification time
        # Track id identifies which track subsequent sample-table atoms belong to.
        self.track_id: int = cbs.read("uint:32")
        cbs.bytepos += 68  # Skip rest of structure
class HDLR(Atom):
    """
    Parses the media handler atom, see https://developer.apple.com/documentation/quicktime-file-format/handler_reference_atom

    NOTE: currently unused but could speed up parsing by skipping audio tracks
    """

    def _parse(self, cbs: ConstBitStream) -> None:
        self.component_type = cbs.read("bytes:4").decode("ascii")
        self.component_subtype = cbs.read("bytes:4").decode("ascii")
        # Skip rest of structure, the last field is variable so we need to use the total size.
        # 20 bytes already read: size (4), type (4), version (1), flags (3),
        # component type (4), component subtype (4).
        cbs.bytepos += self.size - 20
class STSS(Atom):
    """
    Parses the sync sample atom https://developer.apple.com/documentation/quicktime-file-format/sample_table_atom/sync_sample_atom
    """

    def _parse(self, cbs: ConstBitStream) -> None:
        self.number_of_entries: int = cbs.read("uint:32")
        # 1-based sample numbers of the sync samples (keyframes).
        self.sync_sample_table: list[dict[str, Any]] = parse_table(
            cbs, self.number_of_entries, {"number": "uint:32"}
        )
class STTS(Atom):
    """
    Parses the time to sample atom https://developer.apple.com/documentation/quicktime-file-format/time-to-sample_atom
    """

    def _parse(self, cbs: ConstBitStream) -> None:
        self.number_of_entries: int = cbs.read("uint:32")
        # Run-length encoded sample durations: each entry covers sample_count
        # consecutive samples of sample_duration each.
        self.time_to_sample_table: list[dict[str, Any]] = parse_table(
            cbs,
            self.number_of_entries,
            {"sample_count": "uint:32", "sample_duration": "uint:32"},
        )
class CTTS(Atom):
    """
    Parses the composition offset atom https://developer.apple.com/documentation/quicktime-file-format/composition_offset_atom
    """

    def _parse(self, cbs: ConstBitStream) -> None:
        self.number_of_entries: int = cbs.read("uint:32")
        # Run-length encoded pts-dts offsets per sample.
        # NOTE(review): "media_rate" uses an empty format token, so it reads
        # nothing and zip() drops the key — confirm whether it can be removed.
        self.composition_offset_table: list[dict[str, Any]] = parse_table(
            cbs,
            self.number_of_entries,
            {
                "sample_count": "uint:32",
                "composition_offset": "int:32",
                "media_rate": "",
            },
        )
class ELST(Atom):
    """
    Parses the edit list atom https://developer.apple.com/documentation/quicktime-file-format/edit_list_atom
    """

    def _parse(self, cbs: ConstBitStream) -> None:
        self.number_of_entries: int = cbs.read("uint:32")
        # Edit segments; media_time of the first entry is used downstream as a
        # start-time offset (see parse_mpeg).
        self.edit_list_table: list[dict[str, Any]] = parse_table(
            cbs,
            self.number_of_entries,
            {
                "track_duration": "uint:32",
                "media_time": "int:32",
                "media_rate": "int:32",
            },
        )
class MDAT(Atom):
    """
    Parses the media data atom https://developer.apple.com/documentation/quicktime-file-format/movie_data_atom

    This is only here to handle the unusual size handling of mdat, if the normal size field is set to 1
    then the actual size is stored as a 64 bit integer
    """

    def _parse(self, cbs: ConstBitStream) -> None:
        if self.size == 1:
            cbs.bytepos -= 4  # No version or flags for mdat
            # Extended 64-bit size follows the type field; 16 bytes consumed in
            # total (size 4 + type 4 + extended size 8).
            self.size = cbs.read("uint:64")
            seekto = self.size - 16
        else:
            # 12 bytes already consumed: size (4), type (4), version/flags skip (4).
            seekto = self.size - 12
        # mdat often extends to end-of-file; signal the parsing loop to stop
        # instead of seeking past the buffer.
        if cbs.bytepos + seekto >= (cbs.len / 8):
            raise StopIteration()
        cbs.bytepos += seekto
def parse_atoms(file: BitsType) -> Generator[Atom, None, None]:
    """Yield successive top-level atoms from an MPEG-4/QuickTime byte stream.

    Ends cleanly when MDAT signals end-of-data by raising StopIteration.

    Raises:
        ValueError: If bitstring fails to parse the stream.
    """
    try:
        cbs = ConstBitStream(file)
        while cbs.pos < len(cbs):
            try:
                yield Atom.make_atom(cbs)
            except StopIteration:
                # Raised by MDAT._parse when the payload reaches end-of-stream.
                return
    except Error as e:
        raise ValueError(f"MPEG parsing failed with error {e}")
def parse_mpeg(file: BitsType) -> dict[int, SortedList]:
    """Build a per-track keyframe index from the MP4/MOV sample tables.

    Walks the atom stream collecting, per track: sync sample numbers (stss),
    cumulative decode timestamps (stts), composition offsets (ctts) and the
    edit-list start offset (elst), then combines them into KeyframeInfo
    entries (frame number + presentation timestamp).

    Raises:
        ValueError: If the tables are inconsistent (missing track / out-of-range
            sample index).
    """
    sync_samples = {}
    decode_timestamps = {}
    presentation_time_offsets = {}
    start_offsets = defaultdict(int)
    # Sample-table atoms follow their track's tkhd, so remember the last track id.
    current_track = -1
    for a in parse_atoms(file):
        if a.name == "tkhd":
            a: TKHD
            current_track = a.track_id
        elif a.name == "stts":
            a: STTS
            # Expand run-length durations and accumulate into absolute decode
            # timestamps, starting at 0.
            decode_timestamps[current_track] = list(
                accumulate(
                    sum(
                        [
                            [entry["sample_duration"]] * entry["sample_count"]
                            for entry in a.time_to_sample_table
                        ],
                        [0],
                    )
                )
            )
        elif a.name == "ctts":
            a: CTTS
            # Expand run-length composition (pts - dts) offsets per sample.
            presentation_time_offsets[current_track] = sum(
                [
                    [entry["composition_offset"]] * entry["sample_count"]
                    for entry in a.composition_offset_table
                ],
                [],
            )
        elif a.name == "stss":
            a: STSS
            # Convert 1-based sample numbers to 0-based frame indices.
            sync_samples[current_track] = [ss["number"] - 1 for ss in a.sync_sample_table]
        elif a.name == "elst":
            # NOTE the "media_time" here is a "delay" between decoding and presenting the first sample.
            # We follow the ffmpeg convention that the first frame displays at time 0 which means we should
            # *subtract* this offset from the decoding time values rather than adding it to presentation time values
            # TODO there can be more than one of these, figure out how to handle it
            a: ELST
            start_offsets[current_track] = -a.edit_list_table[0]["media_time"]
    keyframes = defaultdict(SortedList)
    try:
        for track_id in sync_samples.keys():
            ptos = presentation_time_offsets.get(track_id)
            dts = decode_timestamps[track_id]
            for keyframe_number in sync_samples[track_id]:
                # pts = decode time + edit-list offset + composition offset (if any).
                pts = (
                    dts[keyframe_number]
                    + start_offsets[track_id]
                    + (0 if ptos is None else ptos[keyframe_number])
                )
                keyframes[track_id].add(KeyframeInfo(keyframe_number, pts))
    except (KeyError, IndexError) as e:
        raise ValueError(f"MPEG parsing failed with error {e}")
    return keyframes
Megatron-Energon/src/megatron/energon/av/fastseek/containers/probe.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
sortedcontainers
import
SortedList
from
...av_init
import
av_open
from
..keyframeinfo
import
KeyframeInfo
def parse_probe(file):
    """Index keyframes by demuxing every video stream with PyAV (no decoding).

    Returns a mapping of stream id -> SortedList[KeyframeInfo], where index is
    the packet number and pts the packet timestamp.
    """
    keyframes = {}
    with av_open(file) as input_container:
        for stream_idx, stream in enumerate(input_container.streams.video):
            # Collect (packet number, pts) for every keyframe packet.
            keyframe_pts = [
                (packet_index, packet.pts)
                for packet_index, packet in enumerate(
                    input_container.demux(video=stream_idx)
                )
                if packet.is_keyframe
            ]
            # Order by pts so the SortedList entries are consistent.
            keyframe_pts.sort(key=lambda item: item[1])
            keyframes[stream.id] = SortedList(
                KeyframeInfo(*entry) for entry in keyframe_pts
            )
    return keyframes
Megatron-Energon/src/megatron/energon/av/fastseek/fastseek.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
typing
import
Literal
,
Optional
import
filetype
from
bitstring.bits
import
BitsType
from
sortedcontainers
import
SortedList
from
.containers.matroska
import
parse_matroska
from
.containers.mpeg
import
parse_mpeg
from
.containers.probe
import
parse_probe
from
.keyframeinfo
import
KeyframeInfo
class Fastseek:
    """
    Gathers information from the video container file (e.g. metadata which requires minimal decoding)
    to find keyframes in the video for fast seeking.

    Information is returned in the form of KeyframeInfo structures which can be used by a decoding loop
    to make informed decisions about the best seeking behavior.

    Currently supports:

    - MP4/MOV: frames are indexed by number and frame counting can be used to get the exact frame
    - Matroska/WebM: frames are indexed by time and inter-frame duration must be accounted for to get to the right frame

    If your container is not listed above, pass "probe=True" to the constructor, this will use ffmpeg to parse the stream
    without decoding it. Frames will be indexed by number. This is not as fast as using a supported container but is still
    significantly faster than sequential decoding.
    """

    # Keyframe index per stream/track id, sorted by KeyframeInfo.index.
    keyframes: dict[int, SortedList[KeyframeInfo]]
    # Unit of the keyframe index: frame numbers ("frames") or timestamps ("pts").
    unit: Literal["frames", "pts"]
    # Detected MIME type; only assigned when probe=False.
    mime: str

    def __init__(self, file: BitsType, probe: bool = False) -> None:
        """Initialize the Fastseek object.

        Args:
            file: The video file data as a bitstring BitsType object. This should contain the raw bytes of the video file.
            probe: If True, use ffmpeg to probe the stream without decoding. This is slower but works with any container format.
                If False (default), attempt to parse the container format directly. Only works with MP4/MOV and Matroska/WebM.

        Raises:
            ValueError: If the file type cannot be determined, the container format is
                not supported (when probe=False), or no streams/keyframes were found.
        """
        if probe:
            self.keyframes = parse_probe(file)
            self.unit = "frames"
            # Fix: the error messages below previously referenced ftype.mime,
            # which is undefined on the probe path and raised a NameError
            # instead of the intended ValueError.
            parser_name = "probe"
        else:
            ftype = filetype.guess(file)
            if ftype is None:
                raise ValueError(
                    "Unable to determine file type (hint: try passing probe=True to the Fastseek constructor)"
                )
            self.mime = ftype.mime
            parser_name = ftype.mime
            if ftype.mime in ["video/mp4", "video/quicktime"]:
                self.keyframes = parse_mpeg(file)
                self.unit = "frames"
            elif ftype.mime in ["video/x-matroska", "video/webm"]:
                self.keyframes = parse_matroska(file)
                self.unit = "pts"
            else:
                raise ValueError(
                    f"Unsupported container: {ftype.mime} (hint: try passing probe=True to the Fastseek constructor)"
                )
        if len(self.keyframes) == 0:
            raise ValueError(
                f"The parser for {parser_name} was unable to find any streams (hint: try passing probe=True to the Fastseek constructor)"
            )
        if all(len(kf) == 0 for kf in self.keyframes.values()):
            raise ValueError(
                f"The parser for {parser_name} was unable to find any keyframes (hint: try passing probe=True to the Fastseek constructor)"
            )

    def should_seek(self, current: int, target: int, stream: int = 0) -> Optional[KeyframeInfo]:
        """Determine if seeking to a keyframe is necessary to reach the target frame.

        Returns information about the nearest keyframe only if seeking would be
        beneficial (i.e., if sequential decoding from the current position would
        be less efficient).

        Args:
            current: The current frame number or timestamp (depending on container format)
            target: The desired frame number or timestamp to seek to
            stream: The video stream index to use. Defaults to 0.

        Returns:
            Information about the nearest keyframe if seeking would be beneficial,
            or None if sequential decoding from current position is more efficient.

        Note:
            The units for current and target depend on the container format:
            - For MP4/MOV: frame numbers (count-based)
            - For Matroska/WebM: timestamps (time-based)
        """
        nearest_iframe: KeyframeInfo = self.nearest_keyframe(target, stream)
        # Seek when the best keyframe lies strictly ahead of the current
        # position (but not past the target), or when the target is behind us
        # (backwards decoding is impossible).
        return (
            nearest_iframe
            if (current < nearest_iframe.index <= target) or (target < current)
            else None
        )

    def nearest_keyframe(self, target: int, stream: int = 0) -> KeyframeInfo:
        """Find the nearest keyframe that comes before the target frame.

        Args:
            target: The target frame number or timestamp to find the nearest keyframe for.
                The unit (frame count or timestamp) depends on the container format.
            stream: The video stream index to use. Defaults to 0.

        Returns:
            Information about the nearest keyframe before the target position.

        Note:
            The implementation currently uses a list-based approach for stream selection
            as some video containers don't report track IDs correctly.
        """
        if stream >= len(self.keyframes):
            raise ValueError(f"No stream with index {stream}")
        stream_id = list(self.keyframes.keys())[stream]
        if len(self.keyframes[stream_id]) == 0:
            raise ValueError(f"No keyframes found for stream {stream}")
        # NOTE(review): bisect_left - 1 means an exact hit on a keyframe index
        # resolves to the *preceding* keyframe (safe but suboptimal); confirm
        # whether bisect_right - 1 was intended.
        nearest_iframe_to_target_index: int = self.keyframes[stream_id].bisect_left(target) - 1
        return self.keyframes[stream_id][max(0, nearest_iframe_to_target_index)]
Megatron-Energon/src/megatron/energon/av/fastseek/keyframeinfo.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
dataclasses
import
dataclass
@dataclass
class KeyframeInfo:
    """
    Information about a detected keyframe

    The exact meaning of the information will vary per container, however in general
    index: the unit of representation for a frame (e.g., frame number)
    pts: a timestamp that can be used by a decoder to seek to exactly this frame
    """

    index: int
    pts: int

    def __eq__(self, other) -> bool:
        # Compare by index only; bare numbers compare directly against index.
        other_index = other.index if isinstance(other, KeyframeInfo) else other
        return self.index == other_index

    def __lt__(self, other) -> bool:
        # Order by index only, so sorted containers can bisect with raw numbers.
        other_index = other.index if isinstance(other, KeyframeInfo) else other
        return self.index < other_index
Megatron-Energon/src/megatron/energon/av/utils.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
typing
import
List
,
Literal
,
Optional
,
Tuple
,
Union
,
overload
import
numpy
as
np
import
torch
from
megatron.energon.av.av_decoder
import
AVData
,
AVDecoder
def get_clips_uniform(
    av_decoder: AVDecoder,
    clip_duration_seconds: float,
    num_clips: int,
    request_video: bool = False,
    request_audio: bool = False,
    video_out_frame_size: Optional[tuple[int, int]] = None,
) -> AVData:
    """Extract equally long, equidistant clips from the media.

    Args:
        av_decoder: An AVDecoder instance.
        clip_duration_seconds: The duration of each clip in seconds.
        num_clips: The number of clips to extract.
        request_video: Whether to request video clips.
        request_audio: Whether to request audio clips.
        video_out_frame_size: The size of the video frames to output, or None to use the original size.

    Returns:
        An AVData object containing the extracted video and audio clips.

    Raises:
        ValueError: If neither video nor audio is requested, or a required
            duration cannot be determined.
    """
    if not (request_video or request_audio):
        raise ValueError("You must request at least one of video or audio")
    video_duration = float("inf")
    audio_duration = float("inf")
    if request_video:
        video_duration, _ = av_decoder.get_video_duration()
        if video_duration is None:
            raise ValueError("No video duration found")
    if request_audio:
        audio_duration = av_decoder.get_audio_duration()
        if audio_duration is None:
            raise ValueError("No audio duration found")
    # Typically, audio and video don't have the exact same duration, so we take
    # the minimum so that we can safely extract clips of equal duration.
    total_duration = min(video_duration, audio_duration)
    assert total_duration != float("inf")
    if clip_duration_seconds == 0:
        # Special case of single frames: End point should be start of last frame.
        seconds_per_frame = 1 / av_decoder.get_video_fps()
        first_start_time = seconds_per_frame * 0.5
        last_start_time = total_duration - seconds_per_frame * 0.5
    else:
        first_start_time = 0
        last_start_time = total_duration - clip_duration_seconds
    start_times = np.linspace(first_start_time, last_start_time, num_clips)
    clips = [
        (float(start), float(start + clip_duration_seconds)) for start in start_times
    ]
    return av_decoder.get_clips(
        video_clip_ranges=clips if request_video else None,
        audio_clip_ranges=clips if request_audio else None,
        video_unit="seconds",
        audio_unit="seconds",
        video_out_frame_size=video_out_frame_size,
    )
@overload
def get_single_frames_uniform(
    av_decoder: "AVDecoder",
    num_frames: int,
    *,
    video_out_frame_size: Optional[Tuple[int, int]] = None,
    return_timestamps: Literal[False] = False,
) -> torch.Tensor: ...

@overload
def get_single_frames_uniform(
    av_decoder: "AVDecoder",
    num_frames: int,
    *,
    video_out_frame_size: Optional[Tuple[int, int]] = None,
    return_timestamps: Literal[True],
) -> Tuple[torch.Tensor, List[float]]: ...

def get_single_frames_uniform(
    av_decoder: AVDecoder,
    num_frames: int,
    *,
    video_out_frame_size: Optional[tuple[int, int]] = None,
    return_timestamps: bool = False,
) -> Union[torch.Tensor, tuple[torch.Tensor, list[float]]]:
    """Extract equidistant single frames from the video.

    Args:
        av_decoder: An AVDecoder instance.
        num_frames: The number of frames to extract.
        video_out_frame_size: The size of the video frames to output, or None to use the original size.
        return_timestamps: If True, also return the start timestamps of the frames.

    Returns:
        A tensor of shape (num_frames, channels, height, width), and additionally
        the list of frame timestamps if return_timestamps is True.

    Raises:
        ValueError: If no video frames could be extracted.
    """
    # Zero-length clips of exactly one frame each, spread uniformly.
    av_data = get_clips_uniform(
        av_decoder=av_decoder,
        clip_duration_seconds=0,
        num_clips=num_frames,
        request_video=True,
        request_audio=False,
        video_out_frame_size=video_out_frame_size,
    )
    if not av_data.video_clips:
        raise ValueError("No video frames found")
    # Concatenate all video single-frame clips to form a single tensor.
    frames = torch.cat(av_data.video_clips, dim=0)
    if return_timestamps:
        return frames, [start for start, _ in av_data.video_timestamps]
    return frames
Megatron-Energon/src/megatron/energon/bracecollapse.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
re
from
collections
import
defaultdict
from
typing
import
List
,
Tuple
__all__
=
[
"collapse"
]
"""Helper functions for string tokenization and expression building."""
_num_re
=
re
.
compile
(
r
"\d+"
)
def
_tokenize
(
s
:
str
)
->
Tuple
[
List
[
str
],
List
[
Tuple
[
str
,
int
,
int
]]]:
"""
Split the string into literal and numeric parts.
Always starts with a literal (sometimes empty)
Example:
"partition_00/shard_000000.tar" ->
lits = ["partition_", "/shard_", ".tar"]
nums = [("00", 0, 2), ("000000", 0, 6)]
Args:
s: Input string to tokenize.
Returns:
Tuple containing:
- lits: List of literal pieces, length = #nums + 1
- nums: List of tuples (raw, value, width) where:
- raw: original substring (keeps zero-padding)
- value: int(raw)
- width: len(raw)
"""
lits
,
nums
=
[],
[]
pos
=
0
for
m
in
_num_re
.
finditer
(
s
):
lits
.
append
(
s
[
pos
:
m
.
start
()])
raw
=
m
.
group
(
0
)
nums
.
append
((
raw
,
int
(
raw
),
len
(
raw
)))
pos
=
m
.
end
()
lits
.
append
(
s
[
pos
:])
return
lits
,
nums
def
_build_expr
(
lits
:
List
[
str
],
nums
:
List
[
Tuple
[
str
,
int
,
int
]],
var_idx
:
int
,
start_raw
:
str
,
end_raw
:
str
,
)
->
str
:
"""
Re-assemble the template, replacing slot with brace expansion syntax.
Args:
lits: List of literal pieces of the string.
nums: List of numeric parts as tuples (raw, value, width).
var_idx: Index of the numeric slot to replace with range.
start_raw: Starting value (raw string).
end_raw: Ending value (raw string).
Returns:
String with brace expansion syntax.
"""
parts
:
List
[
str
]
=
[]
for
i
in
range
(
len
(
nums
)):
parts
.
append
(
lits
[
i
])
if
i
==
var_idx
:
parts
.
append
(
f
"{{
{
start_raw
}
..
{
end_raw
}
}}"
)
else
:
parts
.
append
(
nums
[
i
][
0
])
parts
.
append
(
lits
[
-
1
])
return
""
.
join
(
parts
)
def _streaming_mode(strings: List[str]) -> List[str]:
    """
    Compress strings in order-preserving streaming mode.

    Scans the input left to right and greedily extends a "run": consecutive
    strings sharing the same literal template where exactly one numeric slot
    increments by +1 (keeping its zero-padded width). Each run of length >= 2
    is emitted as one brace expression; everything else passes through as-is.

    Complexity: O(N)

    Args:
        strings: List of strings to compress.

    Returns:
        List of compressed expressions.
    """
    # Result list with brace expressions
    out: List[str] = []
    # Total number of strings
    n = len(strings)
    # Current index
    i = 0
    while i < n:
        lits0, nums0 = _tokenize(strings[i])
        # Strings without numbers can never form a range
        if not nums0:
            out.append(strings[i])
            i += 1
            continue
        # Which numeric slot is changing? (-1 until the first successor fixes it)
        var_idx: int = -1
        start_raw: str = ""
        prev_nums = nums0
        # Last index in the current candidate range
        run_end = i
        # Starting with string `i` as the template, check subsequent strings `j` as long as they match
        j = i + 1
        while j < n:
            lits1, nums1 = _tokenize(strings[j])
            # Template must be identical (same number of literals and numeric slots)
            if lits1 != lits0 or len(nums1) != len(nums0):
                break
            # Exactly one numeric slot may differ ─ find it
            diff_slots = [
                k for k, (a, b) in enumerate(zip(prev_nums, nums1)) if a[1] != b[1]
            ]
            if len(diff_slots) != 1:
                break
            k = diff_slots[0]
            # Width must stay the same
            if nums1[k][2] != prev_nums[k][2]:
                break
            # Same changing slot for the whole run
            if var_idx == -1:
                var_idx, start_raw = k, nums0[k][0]
            elif var_idx != k:
                break
            # Contiguous ascending (+1) only
            if nums1[k][1] != prev_nums[k][1] + 1:
                break
            # OK - extend run
            run_end = j
            prev_nums = nums1
            j += 1
        run_len = run_end - i + 1
        if run_len >= 2 and var_idx != -1:
            # Emit range: end value is taken from the last accepted string
            end_raw = prev_nums[var_idx][0]
            out.append(_build_expr(lits0, nums0, var_idx, start_raw, end_raw))
            i = run_end + 1
        else:
            # Single string
            out.append(strings[i])
            i += 1
    return out
def _bucket_greedy_mode(strings: List[str]) -> List[str]:
    """
    Compress strings using bucket + greedy algorithm to minimize pattern count.

    Each string is bucketed by (varying slot index, template with that slot
    blanked out); within a bucket, sorted numeric values are grouped into
    contiguous +1 runs, and the longest disjoint runs are kept greedily.

    Complexity: O(N log N)

    Args:
        strings: List of strings to compress.

    Returns:
        List of compressed expressions (order may change).
    """
    # Tokenize all strings once
    tokenized = []
    for s in strings:
        lits, nums = _tokenize(s)
        tokenized.append({"lits": lits, "nums": nums, "orig": s})
    # Build buckets: one candidate key per numeric slot of each string
    buckets: defaultdict = defaultdict(list)
    for idx, t in enumerate(tokenized):
        lits, nums = t["lits"], t["nums"]
        for var_idx, (raw, value, width) in enumerate(nums):
            # The key is the template with the varying slot replaced by None
            key_tokens = []
            for k in range(len(nums)):
                key_tokens.append(lits[k])
                key_tokens.append(None if k == var_idx else nums[k][0])
            key_tokens.append(lits[-1])
            key = (var_idx, tuple(key_tokens))
            buckets[key].append((idx, value, raw, width))
    # Find contiguous runs inside every bucket
    # candidate contain tuples (covered_size, indices, expression)
    candidates = []
    for (var_idx, _), entries in buckets.items():
        # Sort by numeric *value*
        entries.sort(key=lambda e: e[1])
        # Start with the first entry
        run = [entries[0]]

        def _flush():
            # NOTE: reads `run`/`var_idx` from the enclosing scope at call time
            if len(run) >= 2:
                idxs = [e[0] for e in run]
                start_raw, end_raw = run[0][2], run[-1][2]
                t0 = tokenized[idxs[0]]
                expr = _build_expr(t0["lits"], t0["nums"], var_idx, start_raw, end_raw)
                candidates.append((len(run), idxs, expr))

        # Check subsequent entries
        for e in entries[1:]:
            prev = run[-1]
            if e[1] == prev[1] + 1 and e[3] == prev[3]:
                # contiguous, same width
                run.append(e)
            else:
                _flush()
                run = [e]
        _flush()
    # Greedy cover: longest first, no overlaps
    candidates.sort(key=lambda c: (-c[0], c[2]))
    # stable order
    covered = [False] * len(strings)
    out: List[str] = []
    for _, idxs, expr in candidates:
        if all(not covered[i] for i in idxs):
            # keep only disjoint
            out.append(expr)
            for i in idxs:
                covered[i] = True
    # Leftover single strings
    out.extend(t["orig"] for i, t in enumerate(tokenized) if not covered[i])
    return out
def collapse(strings: List[str], keep_order: bool = False) -> List[str]:
    """
    Reverse-brace-expand a list of strings.

    Args:
        strings: The filenames / words to be compressed.
        keep_order: Whether to preserve original order.
            * False → minimise the **count** of patterns (order may change).
            * True → keep the order of the input in the expanded output.

    Returns:
        List of brace-expressions plus (possibly) single strings.
    """
    if keep_order:
        return _streaming_mode(strings)
    return _bucket_greedy_mode(strings)
if __name__ == "__main__":
    """Self-test for the module."""
    import time

    # Small hand-crafted examples covering unordered input, mixed templates,
    # and multiple competing numeric slots.
    ex1 = [
        "/path/to/file001.tar.gz",
        "/path/to/file003.tar.gz",
        "/path/to/file002.tar.gz",
    ]
    ex2 = ["python2", "python3.1", "python3.2", "python3.5"]
    ex3 = ["a2b3c", "a4b3c", "a3b3c", "a4b2c", "a5b2c"]
    for case in (ex1, ex2, ex3):
        print("#", case)
        print("unordered :", collapse(case))
        print("ordered :", collapse(case, keep_order=True))
        print()
    # ex4 = [f"shard_{x:06d}" for x in range(10_000_000)]
    # Large-scale timing check: 5M strings in 5 partitions.
    ex4 = [
        f"partition_{partition:02d}/shard_{x:06d}.tar"
        for partition in range(5)
        for x in range(1_000_000)
    ]
    start = time.perf_counter()
    res = collapse(ex4, keep_order=True)
    print(res)
    print(time.perf_counter() - start, "seconds")
Megatron-Energon/src/megatron/energon/cache/__init__.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
megatron.energon.cache.base
import
CachePool
,
FileStore
,
FileStoreDecoder
,
Lazy
,
MockLazy
from
megatron.energon.cache.file_cache_pool
import
FileCacheLazy
,
FileStoreCachePool
from
megatron.energon.cache.file_store
import
DecodeFileStore
,
SystemFileStore
from
megatron.energon.cache.no_cache
import
DirectLazy
,
NoCachePool
__all__
=
[
"CachePool"
,
"DecodeFileStore"
,
"DirectLazy"
,
"FileCacheLazy"
,
"FileStore"
,
"FileStoreCachePool"
,
"FileStoreDecoder"
,
"Lazy"
,
"MockLazy"
,
"NoCachePool"
,
"SystemFileStore"
,
]
Megatron-Energon/src/megatron/energon/cache/base.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
,
Callable
,
Generic
,
TypeVar
from
megatron.energon.edataclass
import
edataclass
from
megatron.energon.source_info
import
SourceInfo
,
add_source_info
T
=
TypeVar
(
"T"
)
class FileStore(Generic[T]):
    """Base type for a dataset that can be accessed randomly by sample key."""

    @abstractmethod
    def __getitem__(self, key: str) -> tuple[T, SourceInfo]:
        """Returns the data for the given key, together with its source info."""
        ...

    def get(self, key: str, sample: Any = None) -> Any:
        """Returns the data for the given key and adds the source info to the sample."""
        payload, src_info = self[key]
        add_source_info(sample, src_info)
        return payload

    @abstractmethod
    def get_path(self) -> str:
        """Returns the path to the dataset."""
        ...
@edataclass
class Lazy(Generic[T]):
    """
    Abstract base class for lazy references to data.

    Identifies one file (`fname`) inside a data source (`ds`) managed by a
    cache pool; calling `get()` materializes the referenced data.
    """

    ds: FileStore
    fname: str
    pool: "CachePool"

    @abstractmethod
    def get(self, sample: Any = None) -> T:
        """
        Get the lazy data now and adds the source info to the sample.
        """
        ...

    def __hash__(self) -> int:
        """Allows usage in sets and dicts as key."""
        return hash((id(self.ds), self.fname))

    def __eq__(self, other: Any) -> bool:
        """Allows usage in sets and dicts as key. Compares the data source and the filename."""
        return (
            isinstance(other, Lazy)
            and self.ds is other.ds
            and self.fname == other.fname
        )
@edataclass
class MockLazy(Lazy[T]):
    """
    Mock object, which can be used as a Lazy. Allows the user to set the function to retrieve the
    data. May be used to create a Lazy that is initialized from a function.
    """

    ds: FileStore
    fname: str
    pool: "CachePool"
    get_fn: Callable[[str], T]

    def __init__(self, fname: str, get_fn: Callable[[str], T]):
        """
        Initialize the MockLazy object.

        Args:
            fname: The file name of the mock object (may be used by the user).
            get_fn: The function to retrieve/generate the data.
        """
        # No backing data source or pool — this lazy is purely function-driven.
        self.ds = None
        self.fname = fname
        self.pool = None
        self.get_fn = get_fn

    def get(self, sample: Any = None) -> T:
        """
        Get the lazy data now and adds no source info to the sample.
        """
        return self.get_fn(self.fname)

    def __hash__(self) -> int:
        return hash((self.fname, self.get_fn))

    def __eq__(self, other: Any) -> bool:
        return (
            isinstance(other, MockLazy)
            and self.fname == other.fname
            and self.get_fn == other.get_fn
        )

    def __repr__(self) -> str:
        return f"MockLazy(fname={self.fname!r}, get_fn={self.get_fn!r})"
class CachePool(ABC):
    """
    A cache pool allows to load the needed data in the background and access it later.
    The most important example being `FileStoreCachePool` which caches data on a local SSD disk.
    To request data, use `get_lazy` to get a `Lazy` object.
    Then, call `Lazy.get()` to get the data later on.
    """

    @abstractmethod
    def get(self, ds: FileStore, fname: str, sample: Any = None) -> Any:
        """
        Get the data for a given file and adds the source info to the sample.

        Args:
            ds: The data source to read from.
            fname: The file name inside the data source.
            sample: Optional sample object to attach the source info to.
        """
        ...

    @abstractmethod
    def get_lazy(self, ds: FileStore, fname: str) -> Lazy:
        """
        Get a lazy reference to the data for a given file.
        """
        ...

    @abstractmethod
    def to_cache(self, data: T, name: str) -> Lazy[T]:
        """
        Move the data to the cache and return a lazy to fetch it later.
        The name is used for debugging the data in the cache.
        """
        ...

    @abstractmethod
    def close(self) -> None:
        """
        Close the cache pool.
        """
        ...
class FileStoreDecoder(ABC):
    """
    Abstract base class for decoders.
    """

    @abstractmethod
    def decode(self, fname: str, data: bytes) -> Any:
        """
        Decode the specified file (i.e. path/key.ext). The extension is used to select the decoder.

        Args:
            fname: The file name of the file to decode.
            data: The raw bytes of the file to decode.

        Returns:
            The decoded field's data.
        """
        ...
Megatron-Energon/src/megatron/energon/cache/file_cache_pool.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
hashlib
import
os
import
pickle
import
random
import
string
import
tempfile
import
threading
import
uuid
from
concurrent.futures
import
Future
,
ThreadPoolExecutor
from
pathlib
import
Path
from
typing
import
Any
,
Dict
,
Literal
,
Optional
,
Tuple
,
TypeVar
from
megatron.energon.cache.base
import
CachePool
,
FileStore
,
Lazy
from
megatron.energon.cache.file_store
import
DecodeFileStore
from
megatron.energon.edataclass
import
edataclass
from
megatron.energon.fork_hook
import
ForkMixin
from
megatron.energon.source_info
import
SourceInfo
,
add_source_info
T
=
TypeVar
(
"T"
)
@edataclass
class FileCacheLazy(Lazy[T]):
    """
    Represents a reference to a background prefetch.
    """

    # The cache pool that we're using.
    pool: "FileStoreCachePool"
    # The entry in the cache pool that we're using.
    entry: "_PendingTask"
    # If get() was called, this will be the data (uncached).
    _data: Optional[tuple[T, SourceInfo]] = None

    def get(self, sample: Any = None) -> T:
        """
        Returns the data and adds the source info to the sample.
        If the background job hasn't started, we cancel it,
        do a direct read, and remove ourselves from the pool's references.
        Otherwise, we wait for the job to finish, read from cache, and remove ourselves.
        """
        if self._data is not None:
            # Bug fix: previously returned the whole (data, source_info) tuple on
            # repeated calls and did not attach the source info to the sample.
            add_source_info(sample, self._data[1])
            return self._data[0]
        self._data = self.pool._get_data(self.ds, self.fname, self.entry)
        assert self._data is not None
        add_source_info(sample, self._data[1])
        return self._data[0]

    def __hash__(self) -> int:
        """Allows usage in sets and dicts as key."""
        return hash((id(self.ds), self.fname))

    def __eq__(self, other: Any) -> bool:
        """Allows usage in sets and dicts as key. Compares the data source and the filename."""
        if not isinstance(other, Lazy):
            return False
        return self.ds is other.ds and self.fname == other.fname

    def __del__(self):
        if self._data is None:
            with self.pool._lock:
                # Data was never fetched, still decrement refcount to delete the cache entry
                self.pool._decrement_refcount_and_cleanup((self.ds.get_path(), self.fname))
@edataclass
class CacheFileLazy(Lazy[T]):
    """
    Represents a reference to a cached object without deduplication.

    The cache file is consumed on first `get()` (loaded, then deleted);
    afterwards the data lives only in memory on this object.
    """

    # The path to the file that contains the cached pickled object.
    cache_path: Path | None
    # If get() was called, this will be the data (uncached).
    _data: Optional[T] = None

    def get(self, sample: Any = None) -> T:
        """
        Get the lazy data now and adds no source info to the sample.
        """
        if self._data is None:
            # First access: unpickle the object, then delete the file — each
            # CacheFileLazy owns its cache file exclusively (no deduplication).
            with open(self.cache_path, "rb") as f:
                self._data = pickle.load(f)
            self.cache_path.unlink()
            self.cache_path = None
        return self._data

    def __del__(self):
        # If the data was never fetched, remove the orphaned cache file.
        if self.cache_path is not None:
            self.cache_path.unlink(missing_ok=True)
            self.cache_path = None

    def __hash__(self) -> int:
        return hash((self.fname, self.cache_path))

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, CacheFileLazy):
            return False
        return self.fname == other.fname and self.cache_path == other.cache_path

    def __repr__(self) -> str:
        return f"CacheFileLazy(fname={self.fname!r}, cache_path={self.cache_path!r})"
@edataclass
class _PendingTask:
    """Dataclass for storing a pending background task.

    One instance exists per (dataset path, file name) key; all lazies for that
    key share it and account for themselves via `refcount`.
    """

    # The dataset that we're caching.
    ds: FileStore
    # The file name that we're caching.
    fname: str
    # The future for the background task that sends the data to the cache.
    send_to_cache_future: Future
    # The number of references to the cache entry.
    refcount: int = 1
    # The size of the data to be cached.
    data_size: int = 0
    # Whether the data is required now, i.e. a reading thread is waiting for it.
    require_data_now: bool = False
    # The path to the cache file.
    cache_path: Optional[Path] = None
    # The source info for the data.
    source_info: Optional[SourceInfo] = None
class FileStoreCachePool(CachePool, ForkMixin):
    """
    Manages a thread pool to pre-fetch data onto an SSD cache.
    Each (ds, fname) has one Future (one read). Multiple requests
    share that same future. We track usage with a refcount.
    To avoid multi-process collisions, we generate a random subfolder
    for each instance.
    """

    # Per-process cache directory (random subfolder under parent_cache_dir).
    cache_dir: Path
    # Maximum total size of cached data in bytes.
    max_cache_size: int
    # Maximum number of cached files.
    max_cache_count: int
    # Currently reserved cache size in bytes.
    current_cache_size: int
    # Currently reserved number of cache files.
    current_cache_count: int
    # Caching strategy: store raw bytes, or decoded-then-pickled objects.
    method: Literal["raw", "pickle"]
    # Thread pool for out-caching tasks
    _worker_pool: Optional[ThreadPoolExecutor] = None
    # (ds.path, fname) -> PendingTask
    _pending_tasks: Dict[Tuple[str, str], _PendingTask]
    # Lock for all shared structures
    _lock: threading.Lock
    # Condition variable to signal when cache space is available
    _cache_space_available: threading.Condition
    # Whether the pool is shutting down
    _shutting_down: bool = False

    def __init__(
        self,
        *,
        parent_cache_dir: Optional[Path] = None,
        num_workers: int = 8,
        max_cache_size_gbytes: float = 1024,
        max_cache_count: int = 10_000_000,
        method: Literal["raw", "pickle"] = "raw",
    ):
        """
        Initialize the cache pool.

        Args:
            parent_cache_dir: The parent directory for the cache.
            num_workers: The number of worker threads to use for copying the data to the cache for lazy loading.
            max_cache_size_gbytes: The maximum size of the cache in gigabytes. If the cache exceeds this size,
                the prefetching will wait until the cache is below this size.
            max_cache_count: The maximum number of files in the cache. If the cache exceeds this number,
                the prefetching will wait until the cache is below this number.
            method: The method to use for caching. "raw" store the non-decoded raw data. "pickle": first decode the data
                and then store the pickled data.
        """
        super().__init__()
        # If no parent directory is given, create a temp directory
        if parent_cache_dir is None:
            parent_cache_dir = Path(tempfile.gettempdir())
        self.parent_cache_dir = parent_cache_dir
        self.num_workers = num_workers
        # Initialize the cache pool (process volatile fields)
        self.__after_fork__(initial=True)
        self.method = method
        # We'll store _pending_tasks in the form:
        # (ds.path, fname) -> PendingTask
        self._pending_tasks = {}
        # Cache size management
        self.max_cache_size = int(max_cache_size_gbytes * (1024**3))
        self.max_cache_count = max_cache_count
        self.current_cache_size = 0
        self.current_cache_count = 0
        # A lock to protect all shared structures
        self._lock = threading.Lock()
        # Condition variable to signal when cache space is available
        self._cache_space_available = threading.Condition(self._lock)

    def get(self, ds: FileStore, fname: str, sample: Any = None) -> Any:
        """
        Synchronous read from the dataset (no cache usage).
        """
        return ds.get(fname, sample)

    def _get_data(
        self, ds: FileStore, fname: str, entry: _PendingTask
    ) -> tuple[Any, SourceInfo]:
        """
        Get the data for a given file from the cache and purge cache if no references are left.
        * If the cache-out is complete, read from cache.
        * If the cache-out is currently prefetching the data to local storage, wait until it's done.
        * If the cache-out job is waiting for space, skip the cache and do a direct read.
        * If the cache-out job is queued for caching, cancel and do a direct read.
        * If the cache-out job failed, raise through and keep for other references.
        * If the cache-out job is cancelled, requeue if there are other references waiting for it.
        """
        result: tuple[Any, SourceInfo]
        with self._lock:
            try:
                # Attempt to cancel if the job hasn't started
                if entry.send_to_cache_future.cancel():
                    was_cached = False
                    try:
                        # Cancelled => job never ran. We'll do a direct read.
                        result = ds[fname]
                    finally:
                        # Decrement refcount
                        self._decrement_refcount_and_cleanup(key=(ds.get_path(), fname))
                else:
                    # Future is already running or done.
                    # Release the lock so the background job can proceed,
                    # then reacquire it after waiting. Otherwise we might block the worker.
                    entry.require_data_now = True
                    self._cache_space_available.notify_all()
                    self._lock.release()
                    # If the job failed, let's keep the exception for other references.
                    was_cached = True
                    try:
                        # Can raise exception if job failed
                        was_cached = entry.send_to_cache_future.result()
                        if was_cached:
                            # The job is complete; read from cache
                            result = self._read_from_cache(entry)
                        else:
                            # The job failed, so we'll do a direct decode
                            result = ds[fname]
                    finally:
                        self._lock.acquire()
                        entry.require_data_now = False
                        # Decrement refcount
                        self._decrement_refcount_and_cleanup(key=(ds.get_path(), fname))
            finally:
                if entry.refcount > 0 and not was_cached:
                    # TODO: Could write to cache here, data is already fetched.
                    # Write the result to the cache
                    # Requeue the job, there is another reference to the cache entry
                    entry.send_to_cache_future = self._worker_pool.submit(
                        self._cache_out_task, ds, fname, entry
                    )
        return result

    def _cache_out_task(self, ds: FileStore, fname: str, entry: _PendingTask) -> bool:
        # Background worker: read the data, wait for cache space, write it to
        # the cache. Returns True iff the data ended up in the cache.
        with self._lock:
            if self._shutting_down:
                return False
        # Perform the data read
        if self.method == "raw":
            if isinstance(ds, DecodeFileStore):
                # Store the un-decoded bytes; decoding happens on cache read.
                data, entry.source_info = ds.inner_reader[fname]
            else:
                data, entry.source_info = ds[fname]
        elif self.method == "pickle":
            data, entry.source_info = ds[fname]
            data = pickle.dumps(data)
        else:
            raise ValueError(f"Invalid method: {self.method}")
        # Wait until there's enough space in the cache
        with self._lock:
            entry.data_size = file_size = len(data)
            while (
                self.current_cache_count + 1 > self.max_cache_count
                or self.current_cache_size + entry.data_size > self.max_cache_size
            ):
                # Release the lock and wait for notification
                self._cache_space_available.wait()
                if entry.require_data_now or self._shutting_down:
                    # At least one reference requires the data now, stop waiting for space and exit immediately
                    return False
            # Reserve the space
            self.current_cache_size += file_size
            self.current_cache_count += 1
            if self._shutting_down or entry.refcount <= 0:
                # No more references to this background job, don't write to cache
                # NOTE(review): the reservation above is not reverted on this path
                # (cache_path is still None, so _remove_cached_file won't) — confirm intended.
                return False
        try:
            assert entry.cache_path is None, (
                f"cache_path should be None, but is {entry.cache_path!r}"
            )
            # Write to cache
            cache_path = self._make_cache_path(ds, fname)
            self._write_to_cache(cache_path, data)
        except:
            with self._lock:
                # Revert the space reservation
                self.current_cache_size -= file_size
                self.current_cache_count -= 1
                self._cache_space_available.notify_all()
            raise
        else:
            with self._lock:
                entry.cache_path = cache_path
        # print(
        #     f"FSCP r={torch.distributed.get_rank()}, pid={os.getpid()}: Wrote to cache {cache_path} (rc={entry.refcount}, size={file_size}, name={fname})\n",
        #     end="",
        # )
        # Data is cached now, return True
        return True

    def get_lazy(self, ds: FileStore, fname: str) -> FileCacheLazy:
        """
        Schedule a background pre-fetch. If multiple calls come in for the same (ds, fname),
        they'll share the same Future and increment reference counts.
        """
        key = (ds.get_path(), fname)
        with self._lock:
            if self._shutting_down:
                raise RuntimeError("Cache pool is already shutting down")
            entry = self._pending_tasks.get(key)
            if entry:
                # Already have a background task for this (ds, fname)
                entry.refcount += 1
            else:
                # Create a new background task
                entry = _PendingTask(
                    ds=ds,
                    fname=fname,
                    send_to_cache_future=None,
                )
                self._pending_tasks[key] = entry
                entry.send_to_cache_future = self._worker_pool.submit(
                    self._cache_out_task, ds, fname, entry
                )
        return FileCacheLazy(ds=ds, fname=fname, pool=self, entry=entry)

    def to_cache(self, data: T, name: str) -> CacheFileLazy[T]:
        """
        Move the data to the cache and return a lazy to fetch it later.
        """
        raw_data = pickle.dumps(data)
        # A UUID avoids any collision; this entry is not deduplicated.
        cache_fname = str(uuid.uuid4())
        cache_path = self.cache_dir / cache_fname
        self._write_to_cache(cache_path, raw_data)
        return CacheFileLazy(ds=None, fname=name, pool=self, cache_path=cache_path)

    def close(self) -> None:
        """
        Shutdown the pool, wait for tasks, and clear our structures.
        """
        with self._lock:
            self._shutting_down = True
            for entry in self._pending_tasks.values():
                entry.send_to_cache_future.cancel()
            # Wake up workers blocked on cache space so they can observe shutdown
            self._cache_space_available.notify_all()
        self._worker_pool.shutdown(wait=True)
        with self._lock:
            self._pending_tasks.clear()

    def _decrement_refcount_and_cleanup(self, key: Tuple[FileStore, str]) -> None:
        """
        Decrement the reference count in `_pending_tasks`.
        If it hits zero, remove the entry. Optionally remove the file if so.
        Assumes the caller holds `self._lock`.
        """
        entry = self._pending_tasks.get(key)
        if not entry:
            # Already cleaned up
            return
        entry.refcount -= 1
        if entry.refcount <= 0:
            # No more references to this background job
            del self._pending_tasks[key]
            self._remove_cached_file(entry)
            assert entry.refcount == 0, f"refcount should be 0: {entry.refcount}"

    def _make_cache_path(self, ds: FileStore, fname: str) -> Path:
        # This is safe, because the parent cache dir is unique per instance.
        ds_hash = hashlib.md5(ds.get_path().encode("utf-8")).hexdigest()
        fn_hash = hashlib.md5(fname.encode("utf-8")).hexdigest()
        # ds_hash = str(ds.get_path()).replace("/", "_")
        # fn_hash = fname.replace("/", "_")
        return self.cache_dir / f"{ds_hash}_{fn_hash}"

    def _write_to_cache(self, path: Path, data: bytes) -> None:
        # Write the raw bytes to the given cache path, creating parent dirs as needed.
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as f:
            f.write(data)

    def _read_from_cache(self, entry: _PendingTask) -> tuple[Any, SourceInfo]:
        # Read a completed cache entry back, decoding/unpickling per `method`.
        assert entry.source_info is not None, "source_info should have been set"
        with open(entry.cache_path, "rb") as f:
            if self.method == "raw":
                raw = f.read()
                if isinstance(entry.ds, DecodeFileStore):
                    return entry.ds.decoder.decode(entry.fname, raw), entry.source_info
                else:
                    return raw, entry.source_info
            else:
                return pickle.load(f), entry.source_info

    def _remove_cached_file(self, entry: _PendingTask) -> None:
        """
        Removes a file from disk and updates size counters.
        Assumes the caller holds `self._lock`.
        """
        if entry.cache_path is None:
            return
        if not entry.cache_path.exists():
            return
        try:
            # print(
            #     f"FSCP r={torch.distributed.get_rank()}, pid={os.getpid()}: Removing cached file {entry.cache_path} (rc={entry.refcount})\n",
            #     end="",
            # )
            entry.cache_path.unlink()
        except OSError:
            pass
        entry.cache_path = None
        if entry.data_size > 0:
            self.current_cache_size -= entry.data_size
            self.current_cache_count -= 1
            # Notify waiting threads that space is now available
            self._cache_space_available.notify_all()

    def __before_fork__(self):
        # Ensure the worker pool is shutdown before the fork
        assert len(self._pending_tasks) == 0, "Pending tasks should be empty before fork"
        # print(
        #     f"FSCP r={torch.distributed.get_rank()}, pid={os.getpid()}: Before fork for oid={id(self)} random_suffix={self.cache_dir.name!r}\n",
        #     end="",
        # )
        self._worker_pool.shutdown(wait=True)
        self._worker_pool = None

    def __after_in_child_fork__(self):
        self.__after_fork__()

    def __after_in_parent_fork__(self):
        self.__after_fork__()

    def __after_fork__(self, initial: bool = False):
        # Re-create process-volatile state (worker pool and cache directory).
        random_suffix = "".join(
            random.Random(os.getpid() ^ random.randint(0, 2**32)).choices(
                string.ascii_lowercase + string.digits, k=16
            )
        )
        assert self._worker_pool is None
        self._worker_pool = ThreadPoolExecutor(
            max_workers=self.num_workers, thread_name_prefix="CacheWorker"
        )
        # Create a random subdirectory name to avoid collisions with other processes
        # As the global random generator is cloned across processes, we need to use a process-specific seed
        self.cache_dir = (self.parent_cache_dir / f"cache_{random_suffix}").resolve()
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        # if initial:
        #     print(
        #         f"FSCP r={torch.distributed.get_rank()}, pid={os.getpid()}: Init oid={id(self)} random_suffix={random_suffix!r}\n",
        #         end="",
        #     )
        # else:
        #     print(
        #         f"FSCP r={torch.distributed.get_rank()}, pid={os.getpid()}: After fork for pid={os.getpid()} oid={id(self)} random_suffix={random_suffix!r}\n",
        #         end="",
        #     )

    def __str__(self):
        return f"FileStoreCachePool(cache_dir={self.cache_dir}, max_cache_size={self.max_cache_size}, max_cache_count={self.max_cache_count}, method={self.method}, current_cache_size={self.current_cache_size}, current_cache_count={self.current_cache_count})"
Megatron-Energon/src/megatron/energon/cache/file_store.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
typing
import
Any
,
Optional
,
Union
from
megatron.energon.cache.base
import
FileStore
,
FileStoreDecoder
from
megatron.energon.epathlib
import
EPath
from
megatron.energon.flavors.jsonl.ijsonl_reader
import
IJsonlReader
from
megatron.energon.flavors.webdataset.itar_reader
import
SqliteITarEntryReader
from
megatron.energon.source_info
import
SourceInfo
class DecodeFileStore(FileStore[Any]):
    """Used to wrap a FileStore and decode the data on access."""

    def __init__(
        self,
        inner_reader: FileStore[bytes],
        *,
        decoder: FileStoreDecoder,
    ):
        """
        Args:
            inner_reader: The FileStore to wrap.
            decoder: The decoder to apply to every item read from the FileStore.
        """
        self.inner_reader = inner_reader
        self.decoder = decoder

    def __getitem__(self, fname: str) -> tuple[Any, SourceInfo]:
        # Fetch raw bytes from the wrapped store, then decode by file name/extension.
        raw, source_info = self.inner_reader[fname]
        decoded = self.decoder.decode(fname, raw)
        return decoded, source_info

    def get_path(self) -> str:
        return self.inner_reader.get_path()

    def __str__(self):
        return f"DecodeFileStore(inner_reader={self.inner_reader}, decoder={self.decoder})"
class SystemFileStore(FileStore[bytes]):
    """A FileStore that reads files directly from the file system."""

    def __init__(self, base_dir: Optional[Union[EPath, str]] = None):
        """
        Args:
            base_dir: The base directory to use for relative paths. If None, you should only pass
                absolute paths to __getitem__.
        """
        self.base_dir = EPath(base_dir) if base_dir is not None else None

    def __getitem__(self, key: str) -> tuple[bytes, SourceInfo]:
        # Construct the full path from the dataset path and the file key
        file_path = EPath(key) if self.base_dir is None else self.base_dir / key
        # Read and return the file contents as bytes
        with file_path.open("rb") as f:
            data = f.read()
        return data, SourceInfo(
            dataset_path=self.base_dir,
            index=None,
            shard_name=None,
            file_names=(key,),
        )

    def get_path(self) -> str:
        """Returns the path to the dataset."""
        return str(self.base_dir)

    def __str__(self):
        return f"SystemFileStore(base_dir={self.base_dir})"
class WebdatasetFileStore(SqliteITarEntryReader, FileStore[bytes]):
    """This dataset will directly read files from the dataset tar files from a prepared energon dataset."""

    def __init__(
        self,
        dataset_path: EPath,
    ):
        """
        Args:
            dataset_path: Path to the prepared energon webdataset.
        """
        # Keys passed to __getitem__ are full tar entry names.
        super().__init__(base_path=dataset_path, key_is_full_entryname=True)

    def get_path(self) -> str:
        """Returns the path to the dataset."""
        return str(self.base_path)
class JsonlFileStore(IJsonlReader, FileStore[bytes]):
    """This dataset will directly read entries from a jsonl file."""

    def get_path(self) -> str:
        """Returns the path to the dataset (the jsonl file)."""
        return str(self.jsonl_path)
Prev
1
2
3
4
5
6
7
8
…
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment