Bw-bestperf / Qwen2.5-14B-Instruct_dcu-megatron, commit f356f546

Authored Feb 04, 2026 by maming

    Initial commit

Pipeline #3339 canceled with stages. 346 changed files, 1 pipeline.
Showing 20 of the changed files, with 2313 additions and 0 deletions.
Megatron-Energon/src/megatron/energon/flavors/jsonl/ijsonl_reader.py (+190)
Megatron-Energon/src/megatron/energon/flavors/jsonl/jsonl_prepare.py (+78)
Megatron-Energon/src/megatron/energon/flavors/multichoice_vqa.py (+43)
Megatron-Energon/src/megatron/energon/flavors/ocr.py (+56)
Megatron-Energon/src/megatron/energon/flavors/similarity_interleaved.py (+49)
Megatron-Energon/src/megatron/energon/flavors/text.py (+33)
Megatron-Energon/src/megatron/energon/flavors/vid_qa.py (+43)
Megatron-Energon/src/megatron/energon/flavors/vqa.py (+43)
Megatron-Energon/src/megatron/energon/flavors/vqa_and_ocr.py (+50)
Megatron-Energon/src/megatron/energon/flavors/webdataset/__init__.py (+38)
Megatron-Energon/src/megatron/energon/flavors/webdataset/aggregator_pool.py (+163)
Megatron-Energon/src/megatron/energon/flavors/webdataset/base_webdataset.py (+283)
Megatron-Energon/src/megatron/energon/flavors/webdataset/config.py (+9)
Megatron-Energon/src/megatron/energon/flavors/webdataset/default_decoder_webdataset.py (+52)
Megatron-Energon/src/megatron/energon/flavors/webdataset/default_generic_webdataset.py (+111)
Megatron-Energon/src/megatron/energon/flavors/webdataset/empty_dataset_error.py (+6)
Megatron-Energon/src/megatron/energon/flavors/webdataset/error_handler.py (+57)
Megatron-Energon/src/megatron/energon/flavors/webdataset/field_access.py (+88)
Megatron-Energon/src/megatron/energon/flavors/webdataset/indexing.py (+541)
Megatron-Energon/src/megatron/energon/flavors/webdataset/itar.py (+380)
Megatron-Energon/src/megatron/energon/flavors/jsonl/ijsonl_reader.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABC
from typing import (
    Callable,
    Generator,
    Optional,
    Tuple,
    TypeVar,
)

from megatron.energon.epathlib import EPath
from megatron.energon.flavors.jsonl.ijsonl import (
    CachedIJsonlOffsetReader,
    IJsonlFile,
    IJsonlIndexReader,
    IJsonlSamplePointer,
)
from megatron.energon.flavors.webdataset.structs import FilteredSample
from megatron.energon.source_info import SourceInfo

T_index = TypeVar("T_index", covariant=False)


class IJsonlReader(ABC):
    """
    Class for reading indexed jsonl files containing json samples.
    The common usage patterns and random-access interfaces are provided here.

    Args:
        base_path: The path to the dataset.
        jsonl_path: The path to the jsonl file.
        jsonl_filename: The jsonl file name.
        sample_filter: An optional filter function to select samples by their key.
        index_cache_size: The size of the index cache.
    """

    jsonl_path: EPath
    sample_filter: Optional[Callable[[str], bool]]
    cached_offset_reader: CachedIJsonlOffsetReader
    ijsonl_file: IJsonlFile | None = None

    def __init__(
        self,
        jsonl_path: EPath,
        sample_filter: Optional[Callable[[str], bool]] = None,
        index_cache_size: int = 5,
    ):
        self.jsonl_path = jsonl_path
        self.sample_filter = sample_filter
        self.cached_offset_reader = CachedIJsonlOffsetReader(
            jsonl_path, cache_size=index_cache_size
        )

    def __len__(self) -> int:
        return len(self.cached_offset_reader)

    def __str__(self) -> str:
        return f"IJsonlReader(jsonl_path={self.jsonl_path})"

    def _get_item_by_sample_pointer(
        self,
        sample_pointer: IJsonlSamplePointer,
    ) -> FilteredSample | None:
        """
        Get a sample from the dataset or slice it.

        Args:
            sample_pointer: The sample pointer to get the sample from.

        Returns:
            The sample or None if the sample is invalid.
        """
        key = str(sample_pointer.index)
        if self.sample_filter is not None and not self.sample_filter(key):
            return None
        if self.ijsonl_file is None:
            self.ijsonl_file = IJsonlFile(self.jsonl_path.open("rb"))
        json_data = self.ijsonl_file.next(sample_pointer.byte_offset, sample_pointer.byte_size)
        if json_data is None:
            return None
        return FilteredSample(
            __key__=f"{self.jsonl_path.name}/{key}",
            __shard__=self.jsonl_path.name,
            __restore_key__=("Webdataset", sample_pointer.index),
            __sources__=(
                SourceInfo(
                    dataset_path=str(self.jsonl_path),
                    index=sample_pointer.index,
                    shard_name=self.jsonl_path.name,
                    file_names=(f"{key}.json",),
                ),
            ),
            json=json_data,
        )

    def __getitem__(self, idx: int | str) -> FilteredSample | tuple[bytes, SourceInfo] | None:
        """
        Get a sample from the dataset.
        """
        assert isinstance(idx, (int, str)), f"Invalid argument type for __getitem__: {type(idx)}"

        full_entry_name = False
        if isinstance(idx, str):
            if idx.endswith(".json"):
                num_idx = idx.removesuffix(".json")
                full_entry_name = True
            else:
                num_idx = idx
            try:
                idx = int(num_idx)
            except ValueError:
                raise ValueError(f"Invalid JSONL sample key: {idx}")

        byte_offset, byte_size = self.cached_offset_reader.get_ijsonl_byte_offset(idx)

        sample: FilteredSample | None = self._get_item_by_sample_pointer(
            IJsonlSamplePointer(
                index=idx,
                byte_offset=byte_offset,
                byte_size=byte_size,
            )
        )
        if sample is None:
            return None
        if full_entry_name:
            assert len(sample["__sources__"]) == 1
            return sample["json"], sample["__sources__"][0]
        else:
            return sample

    def list_all_samples(self) -> Generator[Tuple[str, int, int], None, None]:
        """List all samples in the jsonl file.

        Returns:
            A generator of tuples of (sample_key, size, tar_file_id)
        """
        last_byte_offset = 0
        with IJsonlIndexReader(self.jsonl_path) as ijsonl_index_reader:
            for sample_idx, byte_offset in enumerate(ijsonl_index_reader):
                if last_byte_offset == byte_offset:
                    continue
                yield str(sample_idx), byte_offset - last_byte_offset, 0
                last_byte_offset = byte_offset

    def list_all_sample_parts(self) -> Generator[Tuple[str, int, int], None, None]:
        """List all sample parts in the jsonl file.

        Returns:
            A generator of tuples of (sample_key + "." + part_name, size, tar_file_id)
        """
        last_byte_offset = 0
        with IJsonlIndexReader(self.jsonl_path) as ijsonl_index_reader:
            for sample_idx, byte_offset in enumerate(ijsonl_index_reader):
                if last_byte_offset == byte_offset:
                    continue
                yield f"{sample_idx}.json", byte_offset - last_byte_offset, 0
                last_byte_offset = byte_offset

    def list_sample_parts(self, sample_key: str) -> Generator[Tuple[str, int, int], None, None]:
        """Given a sample key, list all its parts. (E.g. given 1, list 1.jpg, 1.json, etc.)

        Args:
            sample_key: The sample key to list the parts of.

        Returns:
            A generator of tuples of (part_name, size, tar_file_id)
        """
        try:
            sample_idx = int(sample_key)
        except ValueError:
            raise ValueError(f"Invalid JSONL sample key: {sample_key}")
        _, byte_size = self.cached_offset_reader.get_ijsonl_byte_offset(sample_idx)
        yield f"{sample_key}.json", byte_size, 0

    def get_total_size(self) -> int:
        return self.cached_offset_reader.get_total_size()

    def close(self):
        if self.ijsonl_file is not None:
            self.ijsonl_file.close()
        self.cached_offset_reader.close()
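
A short usage sketch may help to see how the random-access pieces fit together. This is not part of the commit; the path is hypothetical, and it assumes the offset index next to the .jsonl file has already been written (see jsonl_prepare.py below).

# Sketch (not from the commit): random access into an indexed jsonl file.
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.jsonl.ijsonl_reader import IJsonlReader

reader = IJsonlReader(EPath("/data/train.jsonl"))   # hypothetical path
print(len(reader))                  # number of indexed samples

sample = reader[0]                  # FilteredSample, or None if filtered out
if sample is not None:
    print(sample["__key__"], sample["json"])

payload, source = reader["0.json"]  # full entry name: returns the json payload and its SourceInfo
reader.close()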
Megatron-Energon/src/megatron/energon/flavors/jsonl/jsonl_prepare.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import logging
from pathlib import Path
from typing import (
    Any,
    Dict,
    Generator,
    TypeVar,
    Union,
)

from megatron.energon.epathlib import EPath
from megatron.energon.flavors.jsonl.ijsonl import IJsonlFile, IJsonlIndexWriter

logger = logging.getLogger(__name__)

T = TypeVar("T", covariant=True)


class JsonlPreparator:
    @staticmethod
    def iter_dataset_content(
        path: Union[str, EPath],
    ) -> Generator[Dict[str, Any], None, None]:
        """
        Yield example dataset content for a few samples.

        Args:
            path: Path to the jsonl file.
        """
        with EPath(path).open("rb") as f:
            with IJsonlFile(f) as index_reader:
                for entry in index_reader:
                    yield {"json": entry}

    @classmethod
    def prepare_dataset(
        cls,
        path: Union[Path, EPath],
    ) -> int:
        """
        Preprocess the jsonl file. Preprocessing is done in parallel.
        Counts the number of samples.

        Args:
            path: Path to the jsonl file

        Returns:
            Count of samples in the jsonl file.
        """
        count = 0
        # Processing is lagging behind. The offsets include empty lines. The whole file must be covered!
        last_offset = 0
        with IJsonlIndexWriter(EPath(path)) as iw:
            with EPath(path).open("rb") as f:
                while True:
                    line = f.readline()
                    if not line:
                        break
                    line = line.strip()
                    if not line:
                        if last_offset:
                            last_offset = f.tell()
                        continue
                    assert line.startswith(b"{") and line.endswith(b"}"), (
                        f"Line {line} does not start and end with a json object {{}}."
                    )
                    iw.append(last_offset)
                    last_offset = f.tell()
                    count += 1
                assert last_offset == f.tell(), (
                    f"The last offset {last_offset} does not match the file size {f.tell()}."
                )
                assert last_offset != 0, "File is empty."
                iw.append(last_offset)
        return count
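
Since prepare_dataset and iter_dataset_content are the entry points for indexing, a brief sketch of the intended flow follows. It is not part of the commit, and the path is hypothetical.

# Sketch (not from the commit): write the offset index for a jsonl file, then peek at it.
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.jsonl.jsonl_prepare import JsonlPreparator

num_samples = JsonlPreparator.prepare_dataset(EPath("/data/train.jsonl"))  # hypothetical path
print(f"indexed {num_samples} samples")

for i, entry in enumerate(JsonlPreparator.iter_dataset_content("/data/train.jsonl")):
    print(entry["json"])
    if i >= 2:          # only show the first few samples
        break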
Megatron-Energon/src/megatron/energon/flavors/multichoice_vqa.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from typing import List, Optional

import torch

from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib.epath import EPath
from megatron.energon.errors import warn_deprecated
from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory


@edataclass
class MultiChoiceVQASample(Sample):
    """Sample type for visual question answering."""

    #: The input image tensor in the shape (C, H, W)
    image: torch.Tensor
    #: The context/question for the image
    context: str

    #: The candidate answers.
    choices: Optional[List[str]] = None
    #: The index of the correct answer.
    correct_choice_idx: int = 0


class MultiChoiceVQAWebdataset(DefaultDecoderWebdatasetFactory[MultiChoiceVQASample]):
    __sample_type__ = MultiChoiceVQASample

    def __init__(self, path: EPath, **kwargs):
        warn_deprecated(
            f"{type(self)} is deprecated, use the default instead and set the sample_type:\n"
            f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n"
            f"# remove top-level __module__ and __class__\n"
            f"sample_type:\n"
            f"  __module__: megatron.energon\n"
            f"  __class__: {self.__sample_type__.__name__}\n"
            f"# Keep the remaining content"
        )
        super().__init__(path, **kwargs)
Megatron-Energon/src/megatron/energon/flavors/ocr.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from typing import List, Optional, Union

import torch

from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib.epath import EPath
from megatron.energon.errors import warn_deprecated
from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory


@edataclass
class OCRSample(Sample):
    """Sample type for optical character recognition."""

    #: The input image tensor in the shape (C, H, W)
    image: torch.Tensor
    #: The text contained in the image
    text: str

    #: The bounding boxes of the blocks in the image float(N, 4|5<x, y, w, h>)
    block_boxes: Optional[torch.Tensor] = None
    #: The classes of the blocks in the image int(N, 1<block_class>)
    block_classes: Optional[Union[torch.Tensor, List[str]]] = None
    #: The text contained in each block (N,)
    block_text: Optional[List[str]] = None

    #: The bounding boxes of the lines in the image float(N, 4|5<x, y, w, h[, confidence]>)
    lines_boxes: Optional[torch.Tensor] = None
    #: The text contained in each line (N,)
    lines_text: Optional[List[str]] = None

    #: The bounding boxes of the words in the image float(N, 4|5<x, y, w, h[, confidence]>)
    words_boxes: Optional[torch.Tensor] = None
    #: The text contained in each word (N,)
    words_text: Optional[List[str]] = None

    #: The bounding boxes of the chars in the image float(N, 4|5<x, y, w, h[, confidence]>)
    chars_boxes: Optional[torch.Tensor] = None
    #: The character contained in each char (N,)
    chars_text: Optional[List[str]] = None


class OCRWebdataset(DefaultDecoderWebdatasetFactory[OCRSample]):
    __sample_type__ = OCRSample

    def __init__(self, path: EPath, **kwargs):
        warn_deprecated(
            f"{type(self)} is deprecated, use the default instead and set the sample_type:\n"
            f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n"
            f"# remove top-level __module__ and __class__\n"
            f"sample_type:\n"
            f"  __module__: megatron.energon\n"
            f"  __class__: {self.__sample_type__.__name__}\n"
            f"# Keep the remaining content"
        )
        super().__init__(path, **kwargs)
Megatron-Energon/src/megatron/energon/flavors/similarity_interleaved.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from typing import List, Optional

import torch

from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib.epath import EPath
from megatron.energon.errors import warn_deprecated
from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory


@edataclass
class SimilarityInterleavedSample(Sample):
    """Sample type for interleaved media such as text with images, but without image-text
    alignment. That alignment has to be assigned from the similarity matrix."""

    #: The images of the sequence
    images: List[torch.Tensor]
    #: The texts of the sequence
    texts: List[str]

    #: The optional audio samples of the sequence
    audio: Optional[List[torch.Tensor]] = None
    #: The optional video frames of the sequence
    video: Optional[List[torch.Tensor]] = None

    #: Similarity matrix between image and text entries in the sequence
    similarity_matrix: Optional[torch.Tensor] = None

    #: The index within texts representing the sentence that this image is matched to
    matched_text_indices: Optional[List[int]] = None


class SimilarityInterleavedWebdataset(
    DefaultDecoderWebdatasetFactory[SimilarityInterleavedSample]
):
    __sample_type__ = SimilarityInterleavedSample

    def __init__(self, path: EPath, **kwargs):
        warn_deprecated(
            f"{type(self)} is deprecated, use the default instead and set the sample_type:\n"
            f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n"
            f"# remove top-level __module__ and __class__\n"
            f"sample_type:\n"
            f"  __module__: megatron.energon\n"
            f"  __class__: {self.__sample_type__.__name__}\n"
            f"# Keep the remaining content"
        )
        super().__init__(path, **kwargs)
Megatron-Energon/src/megatron/energon/flavors/text.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib.epath import EPath
from megatron.energon.errors import warn_deprecated
from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory


@edataclass
class TextSample(Sample):
    """Sample type for simple text."""

    #: The text of the sample
    text: str


class TextWebdataset(DefaultDecoderWebdatasetFactory[TextSample]):
    __sample_type__ = TextSample

    def __init__(self, path: EPath, **kwargs):
        warn_deprecated(
            f"{type(self)} is deprecated, use the default instead and set the sample_type:\n"
            f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n"
            f"# remove top-level __module__ and __class__\n"
            f"sample_type:\n"
            f"  __module__: megatron.energon\n"
            f"  __class__: {self.__sample_type__.__name__}\n"
            f"# Keep the remaining content"
        )
        super().__init__(path, **kwargs)
Megatron-Energon/src/megatron/energon/flavors/vid_qa.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from typing import List, Optional

import torch

from megatron.energon.av import AVDecoder
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib.epath import EPath
from megatron.energon.errors import warn_deprecated
from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory


@edataclass
class VidQASample(Sample):
    """Sample type for video question answering."""

    #: The video data containing the image and audio info.
    video: AVDecoder
    #: The context/question for the image.
    context: str

    #: The possible answers. Not set for testing.
    answers: Optional[List[str]] = None
    #: The weights of the possible answers. Optionally available.
    answer_weights: Optional[torch.Tensor] = None


class VidQAWebdataset(DefaultDecoderWebdatasetFactory[VidQASample]):
    __sample_type__ = VidQASample

    def __init__(self, path: EPath, **kwargs):
        warn_deprecated(
            f"{type(self)} is deprecated, use the default instead and set the sample_type:\n"
            f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n"
            f"# remove top-level __module__ and __class__\n"
            f"sample_type:\n"
            f"  __module__: megatron.energon\n"
            f"  __class__: {self.__sample_type__.__name__}\n"
            f"# Keep the remaining content"
        )
        super().__init__(path, **kwargs)
Megatron-Energon/src/megatron/energon/flavors/vqa.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from typing import List, Optional

import torch

from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib.epath import EPath
from megatron.energon.errors import warn_deprecated
from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory


@edataclass
class VQASample(Sample):
    """Sample type for visual question answering."""

    #: The input image tensor in the shape (C, H, W)
    image: torch.Tensor
    #: The context/question for the image
    context: str

    #: The possible answers. Not set for testing.
    answers: Optional[List[str]] = None
    #: The weights of the possible answers. Optionally available.
    answer_weights: Optional[torch.Tensor] = None


class VQAWebdataset(DefaultDecoderWebdatasetFactory[VQASample]):
    __sample_type__ = VQASample

    def __init__(self, path: EPath, **kwargs):
        warn_deprecated(
            f"{type(self)} is deprecated, use the default instead and set the sample_type:\n"
            f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n"
            f"# remove top-level __module__ and __class__\n"
            f"sample_type:\n"
            f"  __module__: megatron.energon\n"
            f"  __class__: {self.__sample_type__.__name__}\n"
            f"# Keep the remaining content"
        )
        super().__init__(path, **kwargs)
Megatron-Energon/src/megatron/energon/flavors/vqa_and_ocr.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from typing import List, Optional

import torch

from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib.epath import EPath
from megatron.energon.errors import warn_deprecated
from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory


@edataclass
class VQAOCRSample(Sample):
    """Sample type for visual question answering combined with OCR."""

    #: The input image tensor in the shape (C, H, W)
    image: torch.Tensor
    #: The context/question for the image (VQA)
    context: str
    #: The text contained in the image (OCR)
    text: str

    #: The possible answers. Not set for testing. (VQA)
    answers: Optional[List[str]] = None
    #: The weights of the possible answers. Optionally available. (VQA)
    answer_weights: Optional[torch.Tensor] = None

    #: The bounding boxes of the words in the image (N, 4|5<x, y, w, h[, confidence]>) (OCR)
    words_boxes: Optional[torch.Tensor] = None
    #: The text contained in each word (N,) (OCR)
    words_text: Optional[List[str]] = None


class VQAOCRWebdataset(DefaultDecoderWebdatasetFactory[VQAOCRSample]):
    __sample_type__ = VQAOCRSample

    def __init__(self, path: EPath, **kwargs):
        warn_deprecated(
            f"{type(self)} is deprecated, use the default instead and set the sample_type:\n"
            f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n"
            f"# remove top-level __module__ and __class__\n"
            f"sample_type:\n"
            f"  __module__: megatron.energon\n"
            f"  __class__: {self.__sample_type__.__name__}\n"
            f"# Keep the remaining content"
        )
        super().__init__(path, **kwargs)
Megatron-Energon/src/megatron/energon/flavors/webdataset/__init__.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from megatron.energon.av import AVData, AVDecoder
from megatron.energon.flavors.webdataset.base_webdataset import BaseWebdatasetFactory
from megatron.energon.flavors.webdataset.config import JOIN_INDEX_FILENAME, MAIN_FOLDER_NAME
from megatron.energon.flavors.webdataset.default_decoder_webdataset import (
    DefaultDecoderWebdatasetFactory,
)
from megatron.energon.flavors.webdataset.default_generic_webdataset import (
    DefaultGenericWebdatasetFactory,
)
from megatron.energon.flavors.webdataset.empty_dataset_error import EmptyDatasetError
from megatron.energon.flavors.webdataset.joined_webdataset import JoinedWebdatasetFactory
from megatron.energon.flavors.webdataset.metadata import WebdatasetMeta
from megatron.energon.flavors.webdataset.sample_decoder import (
    AVDecoderType,
    ImageDecoderType,
    SampleDecoder,
)
from megatron.energon.flavors.webdataset.standard_webdataset import StandardWebdatasetFactory

__all__ = [
    "AVData",
    "AVDecoder",
    "AVDecoderType",
    "BaseWebdatasetFactory",
    "DefaultDecoderWebdatasetFactory",
    "DefaultGenericWebdatasetFactory",
    "EmptyDatasetError",
    "ImageDecoderType",
    "JOIN_INDEX_FILENAME",
    "JoinedWebdatasetFactory",
    "MAIN_FOLDER_NAME",
    "SampleDecoder",
    "StandardWebdatasetFactory",
    "WebdatasetMeta",
]
Megatron-Energon/src/megatron/energon/flavors/webdataset/aggregator_pool.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import multiprocessing
from abc import ABC, abstractmethod
from typing import Any, Callable, Generic, Iterable, List, Optional, TypeVar

T_result = TypeVar("T_result")
T_aggregation_data = TypeVar("T_aggregation_data")
T_input_data = TypeVar("T_input_data")


class BaseAggregator(ABC, Generic[T_aggregation_data, T_result]):
    """
    Base class for a user-defined aggregator.
    Implement on_start, on_item, and on_finish to handle aggregator logic.
    """

    def on_start(self, aggregator_pool: AggregatorPool) -> None:
        """
        Called exactly once in the aggregator process before receiving any items.
        """
        pass

    @abstractmethod
    def on_item(self, item: T_aggregation_data, aggregator_pool: AggregatorPool) -> None:
        """
        Called for each item produced by the workers.
        """
        ...

    def on_finish(self, aggregator_pool: AggregatorPool) -> None:
        """
        Called once when all workers have signaled completion (i.e. all items are processed).
        """
        pass

    def get_final_result_data(self) -> T_result:
        """
        Called after on_finish to retrieve any final data produced by the aggregator.
        """
        return None


class AggregatorPool(Generic[T_input_data, T_aggregation_data, T_result]):
    """
    A pool that manages multiple worker processes sending results to
    a single aggregator process.

    The user must provide:
      - user_produce_data(task) -> yields items (streaming results)
      - aggregator: an instance of a class derived from BaseAggregator
        which implements on_start, on_item, on_finish, etc.
    """

    num_workers: int
    user_produce_data: Callable[[T_input_data], Iterable[Any]]
    aggregator: BaseAggregator[T_aggregation_data, T_result]

    task_queue: multiprocessing.Queue[Optional[T_input_data]]
    result_queue: multiprocessing.Queue[Optional[T_aggregation_data]]

    def __init__(
        self,
        num_workers: int,
        user_produce_data: Callable[[T_input_data], Iterable[Any]],
        aggregator: BaseAggregator[T_aggregation_data, T_result],
    ) -> None:
        """
        Args:
            num_workers: Number of worker processes.
            user_produce_data: Function that takes a task and yields items (the "large" data stream).
            aggregator: An instance of a user-defined class for handling aggregator logic.
        """
        self.num_workers = num_workers
        self.user_produce_data = user_produce_data
        self.aggregator = aggregator

        # Queues for tasks and results
        self.task_queue = multiprocessing.Queue()
        self.result_queue = multiprocessing.Queue()

        # Queue to pass final aggregator data back to the main process
        self._final_result_data_queue = multiprocessing.Queue()
        # Will store whatever is pulled from _final_data_queue in close()
        self._aggregator_final_result_data: Optional[Any] = None

    def _worker(self, worker_id: int) -> None:
        """Function that runs inside each worker process."""
        while True:
            task = self.task_queue.get()
            if task is None:
                # No more tasks, signal aggregator that this worker is done
                break
            # Produce data in a streaming fashion
            for item in self.user_produce_data(task):
                self.result_queue.put(item)
        # After finishing all tasks, send a sentinel to the aggregator
        self.result_queue.put(None)

    def _aggregator_run(self) -> T_result:
        """
        Function that runs in the aggregator process.

        Keeps reading items from result_queue.
        - If an item is None, that means a worker finished all of its tasks.
        - Otherwise, call aggregator.on_item(...) with that item.
        """
        # Let the aggregator do any initialization it needs
        self.aggregator.on_start(self)

        finished_workers = 0
        while finished_workers < self.num_workers:
            item = self.result_queue.get()
            if item is None:
                # A worker has finished all of its tasks
                finished_workers += 1
            else:
                # Process the item in the aggregator
                self.aggregator.on_item(item, self)

        # All workers done, aggregator can finalize
        self.aggregator.on_finish(self)

        # After finishing, serialize the aggregator's final data
        return self.aggregator.get_final_result_data()

    def submit_task(self, task: T_input_data) -> None:
        """
        Submit a task to be processed by a worker.
        """
        self.task_queue.put(task)

    def process(self) -> T_result:
        """
        Starts the worker processes and runs the aggregation in the main process.
        Waits for all workers to finish and retrieves the aggregator's final data.
        """
        workers: List[multiprocessing.Process] = []

        # Start worker processes
        for w_id in range(self.num_workers):
            p = multiprocessing.Process(target=self._worker, args=(w_id,), daemon=True)
            p.start()
            workers.append(p)

        # Send the sentinel (None) to each worker
        for _ in range(self.num_workers):
            self.task_queue.put(None)

        # Now wait for aggregator to finish reading the queue
        result = self._aggregator_run()

        # Wait for all workers to finish
        for p in workers:
            p.join()

        return result
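
Because the pool's protocol (tasks in, streamed items out, one aggregator folding them into a result) is easiest to see end to end, here is a minimal sketch. It is not part of the commit; the file paths and helper names are made up.

# Sketch (not from the commit): sum line lengths of several files with the pool.
from megatron.energon.flavors.webdataset.aggregator_pool import AggregatorPool, BaseAggregator


def produce_line_lengths(path: str):
    # Worker side: stream one item per line of the (hypothetical) input file.
    with open(path, "rb") as f:
        for line in f:
            yield len(line)


class SumAggregator(BaseAggregator[int, int]):
    # Aggregator side: fold the streamed items into a single total.
    def on_start(self, aggregator_pool):
        self.total = 0

    def on_item(self, item: int, aggregator_pool):
        self.total += item

    def get_final_result_data(self) -> int:
        return self.total


pool = AggregatorPool(4, produce_line_lengths, SumAggregator())
for path in ["a.jsonl", "b.jsonl", "c.jsonl"]:
    pool.submit_task(path)   # submit all tasks before process(), which enqueues the stop sentinels
print(pool.process())        # total byte count over all lines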
Megatron-Energon/src/megatron/energon/flavors/webdataset/base_webdataset.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Generic, List, Optional, Sequence, TypeVar

import numpy as np

from megatron.energon.cache import FileStore
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.base_dataset import (
    BaseCoreDatasetFactory,
    SavableDataset,
    legacy_handler,
)
from megatron.energon.flavors.webdataset.error_handler import ErrorHandler
from megatron.energon.flavors.webdataset.metadata import WebdatasetMeta
from megatron.energon.flavors.webdataset.prepare import WebdatasetPreparator
from megatron.energon.flavors.webdataset.sample_loader import (
    RawSampleData,
    WebdatasetSampleLoaderDataset,
)
from megatron.energon.flavors.webdataset.sharder import Sharder
from megatron.energon.flavors.webdataset.structs import (
    DatasetSubset,
    FilteredSample,
    ShardInfo,
    reraise_exception,
)
from megatron.energon.source_info import SourceInfo
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.map_dataset import MapDataset

T_sample = TypeVar("T_sample", covariant=True)
T = TypeVar("T", covariant=True)

logger = logging.getLogger(__name__)


class BaseWebdatasetFactory(
    BaseCoreDatasetFactory[T_sample],
    WebdatasetPreparator,
    Sharder,
    ErrorHandler,
    Generic[T_sample],
    ABC,
):
    """
    Base class for all webdataset sample loader factories. Applies proper sharding across workers.
    """

    path: EPath
    paths: list[EPath]

    shards: List[ShardInfo]
    sample_excludes: set[str]
    split_part_files: list[str]

    training: bool
    worker_config: WorkerConfig
    shuffle_over_epochs: Optional[int]
    parallel_shard_iters: Optional[int]
    max_samples_per_sequence: Optional[int]
    subset: Optional[DatasetSubset]
    part_filter: Optional[Callable[[str], bool]]
    handler: Callable[[Exception, Optional[str], Optional[list[SourceInfo]]], None]

    def __init__(
        self,
        path: EPath,
        *,
        split_part: str,
        training: bool,
        worker_config: WorkerConfig,
        shuffle_over_epochs: Optional[int] = 1,
        parallel_shard_iters: Optional[int] = None,
        max_samples_per_sequence: Optional[int] = None,
        subset: Optional[DatasetSubset] = None,
        split_config: Optional[str] = None,
        part_filter: Optional[Callable[[str], bool]] = None,
        handler: Callable[
            [Exception, Optional[str], Optional[list[SourceInfo]]], None
        ] = reraise_exception,
    ):
        """
        Base factory for the webdataset sample loader.

        Args:
            path: Path to the dataset.
            split_part: Which part to load (e.g. 'train', 'val', 'test').
            training: If true, apply shuffling and loop the dataset.
            worker_config: Configuration for the workers.
            shuffle_over_epochs: Only effective if training=True.
                How many epochs to shuffle over if training.
                If = 1, every sample is seen exactly once per epoch.
                If > 1, samples (or rather shard slices) are shuffled within this number of epochs
                (i.e. randomly selected without replacement).
                If -1, the shards are effectively shuffled over infinite epochs (i.e. shard slices
                are drawn with replacement).
            parallel_shard_iters: Number of parallel opened shards per worker, shuffling between.
            max_samples_per_sequence: Maximum number of samples per sequence (=how many samples
                will be sequentially iterated).
            subset: If specified, the dataset will be subsetted.
            split_config: Config file to use for shard split definitions.
            part_filter: (internal) Function for filtering tar files by dict keys
            handler: Exception handler. Args: (exception, key, source_info).
        """
        assert self.__sample_type__ is not None, (
            f"Class {type(self)} must define __sample_type__"
        )
        wds_meta = WebdatasetMeta.from_config(
            path=path, split_part=split_part, split_config=split_config
        )
        self.path = path
        self.paths = [path]
        self.shards = wds_meta.shards
        self.sample_excludes = wds_meta.sample_excludes
        self.split_part_files = wds_meta.split_part_files
        self.training = training
        self.worker_config = worker_config
        self.shuffle_over_epochs = shuffle_over_epochs
        self.parallel_shard_iters = parallel_shard_iters
        self.max_samples_per_sequence = max_samples_per_sequence
        self.subset = subset
        self.part_filter = part_filter
        self.handler = legacy_handler(handler)

    def __len__(self) -> int:
        return sum(shard.count for shard in self.shards)

    def build(self, worker_rotation_offset: int = 0) -> SavableDataset[T_sample]:
        from megatron.energon.flavors.webdataset.itar_reader import ShardInfosITarReader

        if self.parallel_shard_iters is None:
            if self.training:
                # 16 seems to be a good choice since we don't want too many file handles open
                parallel_shard_iters = 16
            else:
                parallel_shard_iters = 1
        else:
            parallel_shard_iters = self.parallel_shard_iters

        workers_sample_slice_offsets = self.shard_workers(
            self.shards,
            worker_config=self.worker_config,
            max_samples_per_sequence=self.max_samples_per_sequence,
            rotation_offset=worker_rotation_offset,
            subset=self.subset,
        )

        _print_shard_slices(self.worker_config, self.shards, workers_sample_slice_offsets)

        itar_reader = ShardInfosITarReader(
            self.path,
            self.shards,
            part_filter=self.part_filter,
            sample_filter=self.sample_filter,
            itar_cache_size=parallel_shard_iters,
        )

        dataset = WebdatasetSampleLoaderDataset(
            join_readers=[itar_reader],
            workers_sample_slice_offsets=workers_sample_slice_offsets,
            worker_config=self.worker_config,
            shuffle_over_epochs=self.shuffle_over_epochs if self.training else None,
            parallel_slice_iters=parallel_shard_iters,
        )
        return MapDataset(
            dataset,
            self._load_sample_raw,
            error_handler=self.error_handler,
            stateless_map_fn=True,
            map_fn_config=self.config,
            worker_config=self.worker_config,
        )

    def as_file_store(self) -> "FileStore":
        from megatron.energon.cache.file_store import WebdatasetFileStore

        return WebdatasetFileStore(self.path)

    def sample_filter(self, key: str) -> bool:
        return key not in self.sample_excludes

    def _load_sample_raw(self, raw_sample: RawSampleData) -> T_sample:
        # Just a wrapper for the inner tuple. Tuple should be of length 1.
        assert len(raw_sample.data) == 1 and raw_sample.data[0] is not None
        return self.load_sample(raw_sample.data[0])

    @abstractmethod
    def load_sample(self, raw_data: FilteredSample) -> T_sample:
        """Loads the sample from the dataset."""
        ...

    def config(self) -> Dict[str, Any]:
        return dict(
            type=type(self).__qualname__,
            training=self.training,
            _path=str(self.path),
            shards=[
                dict(
                    name=shard.name,
                    count=shard.count,
                    _path=str(shard.path),
                )
                for shard in self.shards
            ],
            sample_excludes=list(self.sample_excludes),
            shuffle_over_epochs=self.shuffle_over_epochs,
            parallel_shard_iters=self.parallel_shard_iters,
            max_samples_per_sequence=self.max_samples_per_sequence,
            subset=self.subset.config() if self.subset is not None else None,
        )

    def __str__(self):
        return f"{type(self).__name__}(path={self.path})"


def _print_shard_slices(
    worker_config: WorkerConfig, shards: List[ShardInfo], slice_offsets: Sequence[Sequence[int]]
):
    shard_starts = np.cumsum([0] + [shard.count for shard in shards])

    def shard_range_info(start: int, end: int) -> str:
        start_shard_idx = np.searchsorted(shard_starts, start, side="right") - 1
        end_shard_idx = np.searchsorted(shard_starts, end, side="left") - 1
        if start_shard_idx == end_shard_idx:
            shard = shards[start_shard_idx]
            if start - shard_starts[start_shard_idx] == 0:
                start_str = "(start)"
            else:
                start_str = ""
            if end - shard_starts[start_shard_idx] == shard.count:
                end_str = "(end)"
            else:
                end_str = ""
            return (
                f"{shard.name}"
                f"[{start - shard_starts[start_shard_idx]}{start_str},"
                f"{end - shard_starts[start_shard_idx]}{end_str}]"
            )
        else:
            start_shard = shards[start_shard_idx]
            end_shard = shards[end_shard_idx]
            if start - shard_starts[start_shard_idx] == 0:
                start_str = "(start)"
            else:
                start_str = ""
            if end - shard_starts[end_shard_idx] == end_shard.count:
                end_str = "(end)"
            else:
                end_str = ""
            return (
                f"{start_shard.name}[{start - shard_starts[start_shard_idx]}{start_str},]-"
                f"{end_shard.name}[,{end - shard_starts[end_shard_idx]}{end_str}]"
            )

    for worker_idx, sample_slice_offsets in enumerate(slice_offsets):
        start_idx = sample_slice_offsets[0]
        end_idx = sample_slice_offsets[-1]
        if len(sample_slice_offsets) > 6:
            offset_str = (
                f"{', '.join(str(o) for o in sample_slice_offsets[:3])}"
                f" ...<{len(sample_slice_offsets) - 6}> "
                f"{', '.join(str(o) for o in sample_slice_offsets[-3:])}"
            )
        else:
            offset_str = ", ".join(str(o) for o in sample_slice_offsets)
        if len(sample_slice_offsets) > 6:
            slices_str = (
                ", ".join(
                    shard_range_info(start, end)
                    for start, end in zip(sample_slice_offsets[:3], sample_slice_offsets[1:4])
                )
                + f" ...<{len(sample_slice_offsets) - 6}> "
                + ", ".join(
                    shard_range_info(start, end)
                    for start, end in zip(sample_slice_offsets[-4:-1], sample_slice_offsets[-3:])
                )
            )
        else:
            slices_str = ", ".join(
                shard_range_info(start, end)
                for start, end in zip(sample_slice_offsets[:-1], sample_slice_offsets[1:])
            )
        print(
            f"rank={worker_config.rank}, worker={worker_idx}: "
            f"sample_range=[{start_idx}, {end_idx}] in {len(sample_slice_offsets) - 1} slices, "
            f"sum(count)={end_idx - start_idx}: indexes=[{offset_str}] slices=[{slices_str}]"
        )
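
The _print_shard_slices helper above relies on a cumulative-count table plus binary search to translate global sample offsets into shard-local positions. A stripped-down sketch of just that mapping follows; it is not part of the commit and the shard counts are made up.

# Sketch (not from the commit): locate a global sample offset inside the shard list.
import numpy as np

shard_counts = [100, 250, 80]                     # hypothetical shard sizes
shard_starts = np.cumsum([0] + shard_counts)      # [0, 100, 350, 430]

global_offset = 260
shard_idx = np.searchsorted(shard_starts, global_offset, side="right") - 1
local_offset = global_offset - shard_starts[shard_idx]
print(shard_idx, local_offset)                    # 1 160 -> sample 160 of the second shard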
Megatron-Energon/src/megatron/energon/flavors/webdataset/config.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import re

split_name_re = re.compile(r"^((?:.*/|)[^.]+)[.]([^/]*)$")
skip_meta_re = re.compile(r"__[^/]*__($|/)")

MAIN_FOLDER_NAME = ".nv-meta"
JOIN_INDEX_FILENAME = "join_index.bin"
Megatron-Energon/src/megatron/energon/flavors/webdataset/default_decoder_webdataset.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Any, Dict, Generic, Optional, TypeVar

from megatron.energon.epathlib import EPath
from megatron.energon.flavors.webdataset.default_generic_webdataset import (
    DefaultGenericWebdatasetFactory,
)
from megatron.energon.flavors.webdataset.sample_decoder import DEFAULT_DECODER, SampleDecoder
from megatron.energon.flavors.webdataset.structs import FilteredSample

T_sample = TypeVar("T_sample", covariant=True)


class DefaultDecoderWebdatasetFactory(DefaultGenericWebdatasetFactory[T_sample], Generic[T_sample]):
    """
    Extends the default webdataset loading with decoding of contained files, such as images,
    videos or nested containers.
    """

    # The webdataset decoder function, if to be applied
    _decoder: Optional[SampleDecoder]

    def __init__(
        self,
        path: EPath,
        *,
        decoder: Optional[SampleDecoder] = DEFAULT_DECODER,
        **kwargs,
    ):
        """
        Factory for the webdataset sample loader including the decoder.

        Args:
            path: Path to the dataset (passed to parent)
            decoder: If provided, use this decoder, otherwise just load raw bytes.
            **kwargs: Args passed to parent constructor
        """
        self._decoder = decoder
        super().__init__(path, **kwargs)

    def load_sample(self, sample: FilteredSample) -> T_sample:
        if self._decoder is not None:
            sample = self._decoder(sample)
        return super().load_sample(sample)

    def config(self) -> Dict[str, Any]:
        return dict(
            **super().config(),
            **(self._decoder.config() if self._decoder is not None else {}),
        )
Megatron-Energon/src/megatron/energon/flavors/webdataset/default_generic_webdataset.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import dataclasses
from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Union

from megatron.energon.epathlib import EPath
from megatron.energon.flavors.base_dataset import SavableDataset
from megatron.energon.flavors.webdataset.base_webdataset import BaseWebdatasetFactory
from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME
from megatron.energon.flavors.webdataset.field_access import field_access, split_field_access
from megatron.energon.flavors.webdataset.structs import FilteredSample
from megatron.energon.module_loader import ModuleLoader

T_sample = TypeVar("T_sample", covariant=True)


class DefaultGenericWebdatasetFactory(BaseWebdatasetFactory[T_sample], Generic[T_sample]):
    """
    Default implementation of webdataset for generic samples and the generic config interface
    for use with dataset.yaml.
    """

    _sample_loader: Callable[[Dict[str, Any]], Dict[str, Any]]

    def __init__(
        self,
        path: EPath,
        *,
        subflavors: Optional[Dict[str, Any]] = None,
        field_map: Optional[Dict[str, str]] = None,
        sample_loader: Optional[Union[str, Callable[[dict], dict]]] = None,
        part_filter: Optional[Union[str, List[str], Callable[[str], bool]]] = None,
        **kwargs,
    ):
        """
        Factory for the webdataset sample loader and basic configuration options.

        Args:
            subflavors: Subflavors dictionary to set for all loaded samples.
            field_map: Mapping from the webdataset fields to the sample fields.
            sample_loader: Function to load the sample from the webdataset fields. May be a string
                in order to load a function from a module, or a callable directly.
            part_filter: Filter for the parts to load. May be a string in order to load a function
                from a module, or a callable directly.
            **kwargs: Args passed to parent constructor.
        """
        assert (field_map is None) != (sample_loader is None), (
            "Either field_map or sample_loader must be provided."
        )
        if sample_loader is not None:
            assert part_filter is not None, (
                "part_filter must be provided if sample_loader is provided."
            )
            module_loader = ModuleLoader()
            if isinstance(sample_loader, str):
                sample_loader = module_loader.get_function(
                    sample_loader, "sample_loader", relative_path=path / MAIN_FOLDER_NAME
                )
            else:
                assert callable(sample_loader)
                sample_loader = sample_loader
            if isinstance(part_filter, list):
                parts = set(part_filter)
                part_filter = lambda part: part in parts
            elif isinstance(part_filter, str):
                part_filter = module_loader.get_function(
                    part_filter, "part_filter", relative_path=path / MAIN_FOLDER_NAME
                )
            else:
                assert callable(part_filter)
            self._sample_loader = sample_loader
        else:
            assert field_map is not None
            assert part_filter is None
            # Split field map fields by json[field][field]
            fields = {key: split_field_access(field) for key, field in field_map.items()}
            assert set(
                field.name for field in dataclasses.fields(self.__sample_type__)
            ).issuperset(fields.keys()) and set(
                field.name
                for field in dataclasses.fields(self.__sample_type__)
                if field.default is not dataclasses.MISSING
                and field.default_factory is not dataclasses.MISSING
            ).issubset(field_map.keys()), (
                f"field_map does not map to type {self.__sample_type__.__name__} fields"
            )
            self._sample_loader = lambda sample: {
                k: field_access(sample, v) for k, v in fields.items()
            }
            parts = set(access[0] for options in fields.values() for access in options)
            part_filter = lambda part: part in parts
        inner_sample_loader = self._sample_loader
        self._sample_loader = lambda sample: {
            "__key__": sample["__key__"],
            **inner_sample_loader(sample),
            "__restore_key__": sample["__restore_key__"],
            "__subflavors__": self.subflavors,
            "__sources__": sample["__sources__"],
        }
        super().__init__(path, **kwargs, part_filter=part_filter)
        self.subflavors = subflavors or {}

    def load_sample(self, sample: FilteredSample) -> T_sample:
        return self.__sample_type__(**self._sample_loader(sample))

    def config(self) -> Dict[str, Any]:
        return dict(
            **super().config(),
            subflavors=self.subflavors,
            sample_loader=SavableDataset._function_config(self._sample_loader),
        )
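
The field_map branch above builds the sample loader from access descriptors. A small standalone sketch of that transformation (not part of the commit; the field_map and raw dict are hypothetical):

# Sketch (not from the commit): what a field_map-driven loader does with a raw sample dict.
from megatron.energon.flavors.webdataset.field_access import field_access, split_field_access

field_map = {"text": "txt", "context": "json[question]"}  # hypothetical mapping
fields = {key: split_field_access(desc) for key, desc in field_map.items()}

raw = {"txt": "hello", "json": {"question": "What is shown?"}}
loaded = {key: field_access(raw, accesses) for key, accesses in fields.items()}
print(loaded)  # {'text': 'hello', 'context': 'What is shown?'}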
Megatron-Energon/src/megatron/energon/flavors/webdataset/empty_dataset_error.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause


class EmptyDatasetError(Exception):
    """Raised when a dataset is empty."""
Megatron-Energon/src/megatron/energon/flavors/webdataset/error_handler.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Callable, Generic, Optional, Tuple, TypeVar, Union

from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError
from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.webdataset.structs import FilteredSample, reraise_exception
from megatron.energon.source_info import SourceInfo

T_sample = TypeVar("T_sample", covariant=True)


class ErrorHandler(Generic[T_sample]):
    handler: Callable[[Exception, Optional[str], Optional[list[SourceInfo]]], None] = (
        reraise_exception
    )

    def sample_error_handler(
        self,
        e: Exception,
        sample_key: Optional[str],
        sources: Optional[list[SourceInfo]] = None,
    ):
        if isinstance(e, SYSTEM_EXCEPTIONS):
            raise FatalSampleError(f"Error in sample {sample_key!r}: {e}") from e
        self.handler(e, sample_key, sources)

    def error_handler(
        self,
        e: Exception,
        sample: Union[
            T_sample,
            dict,
            FilteredSample,
            None,
            Tuple[Union[T_sample, dict, FilteredSample, None], ...],
        ],
    ):
        if isinstance(sample, dict):
            key = sample.get("__key__")
            sources = sample.get("__sources__")
        elif isinstance(sample, list):
            if isinstance(sample[0], dict):
                key = ",".join("None" if s is None else s.get("__key__") for s in sample)
                sources = [src for s in sample for src in s.get("__sources__", ())]
            elif isinstance(sample[0], Sample):
                key = ",".join("None" if s is None else s.__key__ for s in sample)
                sources = [src for s in sample for src in s.__sources__]
            else:
                key = None
                sources = None
        elif isinstance(sample, Sample):
            key = sample.__key__
            sources = sample.__sources__
        else:
            key = None
            sources = None
        self.sample_error_handler(e, key, sources)
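
The default handler re-raises; a more permissive handler only needs to match the signature of the handler attribute above. The sketch below (not part of the commit) logs and skips broken samples; it could presumably be passed as the handler argument of BaseWebdatasetFactory, while system exceptions still escalate to FatalSampleError in sample_error_handler.

# Sketch (not from the commit): a logging handler that skips broken samples.
import logging
from typing import Optional

from megatron.energon.source_info import SourceInfo

logger = logging.getLogger(__name__)


def log_and_skip(e: Exception, key: Optional[str], sources: Optional[list[SourceInfo]]) -> None:
    # Record the failure and return, so the pipeline drops the sample instead of aborting.
    logger.warning("Skipping sample %r (%s): %s", key, sources, e)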
Megatron-Energon/src/megatron/energon/flavors/webdataset/field_access.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import re
from typing import Any, List, Union

from megatron.energon.errors import SampleException, compact_str

_field_split_re = re.compile(r"^(?P<field_name>[^[\]]+)(?P<access>(?:\[([^[\]]+)])*)$")


def split_field_access(field_desc: str) -> List[List[str]]:
    """
    Splits a field_map for access::

        'abcdef,ghi' -> [['abcdef'], ['ghi']]
        'abcdef[ghi]' -> [['abcdef', 'ghi']]
        'abcdef[ghi][jkl]' -> [['abcdef', 'ghi', 'jkl']]
    """
    options = field_desc.split(",")
    option_fields = []
    for option in options:
        match = _field_split_re.match(option)
        if match:
            option_fields.append(
                [match.group("field_name")]
                + [
                    access.lstrip("[").rstrip("]")
                    for access in match.group("access").split("][")
                    if access
                ]
            )
        else:
            option_fields.append([field_desc])
    return option_fields


class FieldAccessError(SampleException):
    pass


def _field_access(value: Union[dict, list, str, int, bool, None], field: List[str]) -> Any:
    """
    Accesses a (nested) field in the value.

    Args:
        value: The value to access
        field: The access instruction (e.g. `['field1', 'field2']` for
            `value['field1']['field2']`)

    Returns:
        The accessed value
    """
    try:
        if len(field) == 0:
            return value
        elif isinstance(value, dict):
            return _field_access(value[field[0]], field[1:])
        elif isinstance(value, list):
            return _field_access(value[int(field[0])], field[1:])
        else:
            raise FieldAccessError(
                f"Cannot access literal value {compact_str(value)} with {field!r}"
            )
    except FieldAccessError:
        raise
    except KeyError:
        raise FieldAccessError(f"Cannot access {'.'.join(field)!r} in {compact_str(value)}")


def field_access(value: Union[dict, list, str, int, bool, None], field: List[List[str]]) -> Any:
    """
    Accesses a (nested) field in the value.

    Args:
        value: The value to access
        field: The access instruction (e.g. `[['field1', 'field2']]` for
            `value['field1']['field2']`, or `[['field1'], ['field2']]` for
            `value.get('field1', value['field2'])`)

    Returns:
        The accessed value
    """
    for f in field[:-1]:
        try:
            return _field_access(value, f)
        except (KeyError, ValueError, IndexError):
            pass
    return _field_access(value, field[-1])
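
Since the docstrings above describe the descriptor syntax only briefly, here is a short demonstration of nested access and the error type. It is not part of the commit; the descriptor and data are made up.

# Sketch (not from the commit): nested access descriptors resolve through dicts and lists;
# a missing key surfaces as FieldAccessError (a SampleException).
from megatron.energon.flavors.webdataset.field_access import (
    FieldAccessError,
    field_access,
    split_field_access,
)

accesses = split_field_access("json[choices][1]")
print(accesses)  # [['json', 'choices', '1']]
print(field_access({"json": {"choices": ["a", "b"]}}, accesses))  # 'b'

try:
    field_access({"json": {}}, accesses)
except FieldAccessError as e:
    print(e)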
Megatron-Energon/src/megatron/energon/flavors/webdataset/indexing.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
sqlite3
import
struct
from
pathlib
import
Path
from
typing
import
BinaryIO
,
Generator
,
List
,
Optional
,
Tuple
,
Union
from
numpy
import
int8
from
megatron.energon.epathlib
import
EPath
from
megatron.energon.flavors.webdataset.itar
import
ITarRawSamplePartPointer
,
ITarSamplePointer
from
megatron.energon.flavors.webdataset.thread_local_sqlite
import
ThreadLocalSqlite
from
megatron.energon.local_copy
import
ensure_local_copy
class
SqliteIndexWriter
:
sqlite_path
:
EPath
db
:
Optional
[
sqlite3
.
Connection
]
duplicates
:
List
[
Tuple
[
str
,
int
]]
def
__init__
(
self
,
sqlite_path
:
EPath
):
"""
Initializes an SQLite database and sets up the samples table:
- samples(tar_file_id INTEGER,
sample_key TEXT,
sample_index INTEGER,
byte_offset INTEGER,
byte_size INTEGER)
and the sample_parts table:
- sample_parts(tar_file_id INTEGER,
sample_index INTEGER,
part_name TEXT,
content_byte_offset INTEGER,
content_byte_size INTEGER)
Also creates indexes:
- samples(sample_key)
- samples(tar_file_id, sample_index)
- sample_parts(tar_file_id, sample_index, content_byte_offset)
- sample_parts(tar_file_id, sample_index, part_name, content_byte_offset, content_byte_size)
"""
# Final path and temporary path
self
.
sqlite_path
=
sqlite_path
# Initialize SQLite connection
path
=
str
(
self
.
sqlite_path
)
# Only supporting local file system, because sqlite does not support remote file systems.
# TODO: Implement remote file systems. Maybe create locally in tmp then upload?
assert
path
.
startswith
(
"/"
),
(
f
"SQLite path must be absolute local file system path:
{
self
.
sqlite_path
}
"
)
Path
(
path
).
parent
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
self
.
db
=
sqlite3
.
connect
(
path
)
self
.
db
.
execute
(
"PRAGMA busy_timeout = 5000;"
)
# wait up to 5000ms when locked
self
.
db
.
execute
(
"PRAGMA journal_mode = WAL;"
)
# Create the sample table
self
.
db
.
execute
(
"DROP INDEX IF EXISTS idx_samples_sample_key"
)
self
.
db
.
execute
(
"DROP INDEX IF EXISTS idx_samples_by_tar_and_idx"
)
self
.
db
.
execute
(
"DROP TABLE IF EXISTS samples"
)
self
.
db
.
execute
(
"""
CREATE TABLE samples (
tar_file_id INTEGER,
sample_key TEXT,
sample_index INTEGER,
byte_offset INTEGER,
byte_size INTEGER
)
"""
)
# Create the sample parts table
self
.
db
.
execute
(
"DROP INDEX IF EXISTS idx_sample_parts_seq"
)
self
.
db
.
execute
(
"DROP INDEX IF EXISTS idx_sample_parts_full"
)
self
.
db
.
execute
(
"DROP TABLE IF EXISTS sample_parts"
)
self
.
db
.
execute
(
"""
CREATE TABLE sample_parts (
tar_file_id INTEGER,
sample_index INTEGER,
part_name TEXT,
content_byte_offset INTEGER,
content_byte_size INTEGER
)
"""
)
self
.
duplicates
=
[]
def
append_sample
(
self
,
tar_file_id
:
int8
,
sample_key
:
str
,
sample_index
:
int
,
byte_offset
:
Optional
[
int
],
byte_size
:
Optional
[
int
],
):
"""
Adds a new sample row to the samples table.
Args:
tar_file_id: The index of the tar file in the reader.
sample_key: The key of the sample.
sample_index: The index of the sample in the tar file.
byte_offset: The byte offset of the sample in the tar file.
byte_size: The size of the sample in the tar file.
"""
assert
self
.
db
is
not
None
,
"Database is closed"
# Insert a row in the samples table
self
.
db
.
execute
(
"""
INSERT INTO samples (tar_file_id, sample_key, sample_index, byte_offset, byte_size)
VALUES (?, ?, ?, ?, ?)
"""
,
(
tar_file_id
,
sample_key
,
sample_index
,
byte_offset
,
byte_size
),
)
def
append_part
(
self
,
tar_file_id
:
int8
,
sample_index
:
int
,
part_name
:
str
,
content_byte_offset
:
int
,
content_byte_size
:
int
,
):
"""Adds a new part row to the samples table."""
assert
self
.
db
is
not
None
,
"Database is closed"
# Insert a row in the sample parts table
self
.
db
.
execute
(
"""
INSERT INTO sample_parts (tar_file_id, sample_index, part_name, content_byte_offset, content_byte_size)
VALUES (?, ?, ?, ?, ?)
"""
,
(
tar_file_id
,
sample_index
,
part_name
,
content_byte_offset
,
content_byte_size
),
)
def
close
(
self
):
"""
Closes the DB connection. If finalize=True, the temporary database is
renamed to the final name, overwriting if necessary.
"""
assert
self
.
db
is
not
None
,
"Database is closed"
# Create the index after adding all the samples for better speed
# Index on sample_key for fast lookups
self
.
db
.
execute
(
"CREATE INDEX IF NOT EXISTS idx_samples_sample_key ON samples(sample_key)"
)
# Create index on the samples table. Help the planner if it chooses `samples` as the probe side of the join
self
.
db
.
execute
(
"CREATE INDEX IF NOT EXISTS idx_samples_by_tar_and_idx ON samples(tar_file_id, sample_index)"
)
# Create index on the sample_parts table for fast sequential access
self
.
db
.
execute
(
"CREATE INDEX IF NOT EXISTS idx_sample_parts_seq ON sample_parts(tar_file_id, sample_index, content_byte_offset)"
)
# Create a full index on the sample_parts table for equality lookups and getting offsets directly from key
self
.
db
.
execute
(
"CREATE INDEX IF NOT EXISTS idx_sample_parts_full ON sample_parts(tar_file_id, sample_index, part_name, content_byte_offset, content_byte_size)"
)
# Check if sample_key are all unique
# self.db.execute("CREATE TEMP TABLE temp AS SELECT sample_key, COUNT(*) AS c FROM samples GROUP BY sample_key HAVING c > 1")
duplicates
=
self
.
db
.
execute
(
"SELECT sample_key, COUNT(*) AS c FROM samples GROUP BY sample_key HAVING c > 1 LIMIT 5"
).
fetchall
()
if
len
(
duplicates
)
>
0
:
self
.
duplicates
=
duplicates
if
self
.
db
is
not
None
:
self
.
db
.
commit
()
self
.
db
.
close
()
self
.
db
=
None
def
__enter__
(
self
):
return
self
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
# If an exception occurred, do not finalize (so you can inspect the temp file)
self
.
close
()
class JoinIndexWriter:
    """Describes how one primary dataset is joined with multiple secondary datasets.

    For fast random access, this is a binary format that is memory-mapped.
    The first 16 bytes are a header with the number of columns (1 primary + N secondary).
    Each row contains (shard_idx, byte_offset, byte_size) for each column.
    """

    def __init__(self, join_index_path: EPath):
        self.join_index_path = join_index_path
        self.join_index_file = join_index_path.open("wb")
        self.num_columns = None

    def append(self, *columns: Tuple[int, int, int]):
        """Appends a new row to the join index file.

        Each row contains (shard_idx, byte_offset, byte_size) for each column.
        """
        if self.num_columns is None:
            # Write the number of columns
            self.join_index_file.write(b"JIDX0001")  # Magic bytes with version
            self.join_index_file.write(struct.pack("q", len(columns)))
            self.num_columns = len(columns)
        else:
            assert len(columns) == self.num_columns, (
                f"Inconsistent number of keys: Had {self.num_columns} before, got {len(columns)}"
            )

        # Write the columns
        for key in columns:
            assert isinstance(key, tuple) and len(key) == 3
            self.join_index_file.write(struct.pack("qqq", *key))

    def close(self):
        self.join_index_file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
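# Usage sketch (illustrative, not part of the module): writing a two-column
# join index (one primary plus one secondary dataset). Every appended row adds
# num_columns * 3 * 8 = 48 bytes after the 16-byte header, so the two rows below
# yield a 112-byte file. The path and numbers are hypothetical.
#
#   with JoinIndexWriter(EPath("/data/dataset.join.idx")) as writer:
#       writer.append((0, 0, 1024), (2, 4096, 512))       # row 0
#       writer.append((0, 1024, 2048), (2, 4608, 256))    # row 1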
class SqliteIndexReader:
    """Reads samples from an SQLite database created by SqliteIndexWriter.

    The database contains tables with the following schema:
    - samples(tar_file_id INTEGER,
              sample_key TEXT,
              sample_index INTEGER,
              byte_offset INTEGER,
              byte_size INTEGER)
    - sample_parts(tar_file_id INTEGER,
                   sample_index INTEGER,
                   part_name TEXT,
                   content_byte_offset INTEGER,
                   content_byte_size INTEGER)
    """

    sqlite_path: EPath
    db: ThreadLocalSqlite

    def __init__(self, sqlite_path: EPath):
        """Initialize the SQLite database reader.

        Args:
            sqlite_path: Path to the SQLite database file
        """
        self.sqlite_path = ensure_local_copy(sqlite_path)

        # Initialize SQLite connection
        path = str(self.sqlite_path)
        # Only supporting the local file system, because sqlite does not support remote file systems
        assert path.startswith("/"), (
            f"SQLite path must be an absolute local file system path: {self.sqlite_path}"
        )
        path = f"file:{path}?mode=ro&immutable=1"
        self.db = ThreadLocalSqlite(path, is_uri=True)
    def db_has_sample_parts(self) -> bool:
        """Check if the database has a sample_parts table.

        Returns:
            True if the sample_parts table exists, False otherwise.
        """
        assert self.db is not None, "Database is closed"
        db_exists = self.db.select_one(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='sample_parts'"
        )
        self.db.thread_close()
        return db_exists is not None

    def list_all_samples(self) -> Generator[Tuple[str, int, int], None, None]:
        """List all samples in the database.

        Yields:
            Tuples of (sample_key, byte_size, tar_file_id)
        """
        assert self.db is not None, "Database is closed"
        for row in self.db.select_all(
            "SELECT sample_key, byte_size, tar_file_id FROM samples"
        ):
            yield row[0], row[1], row[2]
    def list_all_sample_parts(self) -> Generator[Tuple[str, int, int], None, None]:
        """List all sample parts (i.e. individual files) in the database.

        Yields:
            Tuples of (full_key, size, tar_file_id)
        """
        assert self.db is not None, "Database is closed"
        # Select all parts (sorted by tar_file_id, sample_index) joined with the sample_key names
        for row in self.db.select_all(
            "SELECT "
            "s.sample_key || '.' || sp.part_name AS full_key, "
            "sp.content_byte_size AS size, "
            "sp.tar_file_id AS tar_file_id "
            "FROM sample_parts AS sp "
            "JOIN samples AS s "
            "ON sp.tar_file_id = s.tar_file_id AND sp.sample_index = s.sample_index "
            "ORDER BY sp.tar_file_id, sp.sample_index, sp.content_byte_offset"
        ):
            yield row[0], row[1], row[2]

    def list_sample_parts(self, sample_key: str) -> Generator[Tuple[str, int, int], None, None]:
        """List all parts (i.e. individual files) of one sample in the database.

        Args:
            sample_key: The sample key to look up

        Yields:
            Tuples of (part_name, size, tar_file_id)
        """
        assert self.db is not None, "Database is closed"
        # Select all parts of the sample (sorted by tar_file_id, sample_index), joined with the sample_key names
        for row in self.db.select_all(
            "SELECT "
            "sp.part_name AS part_name, "
            "sp.content_byte_size AS size, "
            "sp.tar_file_id AS tar_file_id "
            "FROM sample_parts AS sp "
            "JOIN samples AS s "
            "ON sp.tar_file_id = s.tar_file_id AND sp.sample_index = s.sample_index "
            "WHERE s.sample_key = ? "
            "ORDER BY sp.tar_file_id, sp.sample_index, sp.content_byte_offset",
            (sample_key,),
        ):
            yield row[0], row[1], row[2]
    def get_total_size(self) -> int:
        """Get the total size of all samples in the database."""
        assert self.db is not None, "Database is closed"
        count = self.db.select_one("SELECT SUM(byte_size) FROM samples")
        return count[0] if count else 0

    def get_sample_count(self) -> int:
        """Get the total number of samples in the database."""
        assert self.db is not None, "Database is closed"
        count = self.db.select_one("SELECT COUNT(*) FROM samples")
        return count[0] if count else 0
    def get_sample_part(self, key: str, part_name: str) -> ITarRawSamplePartPointer:
        """Get a sample part by its key name and part name.

        Args:
            key: The sample key to look up
            part_name: The part name to look up

        Returns:
            Pointer to the sample part raw data.
        """
        assert self.db is not None, "Database is closed"
        row = self.db.select_one(
            "SELECT sp.tar_file_id, sp.content_byte_offset, sp.content_byte_size "
            "FROM sample_parts AS sp "
            "JOIN samples AS s "
            "ON sp.tar_file_id = s.tar_file_id AND sp.sample_index = s.sample_index "
            "WHERE s.sample_key = ? AND sp.part_name = ?",
            (key, part_name),
        )
        if row is None:
            raise KeyError(
                f"Sample part not found: key={key}, part_name={part_name} in {self.sqlite_path}"
            )
        return ITarRawSamplePartPointer(
            tar_file_id=row[0],
            raw_byte_offset=row[1],
            raw_byte_size=row[2],
        )
    def get_sample_pointer_by_key(self, key: str) -> ITarSamplePointer:
        """Get a sample by its key name.

        Args:
            key: The sample key to look up

        Returns:
            ITarSamplePointer with (tar_file_id, byte_offset, byte_size) for the sample.
        """
        assert self.db is not None, "Database is closed"
        sample = self.db.select_one(
            "SELECT tar_file_id, sample_key, sample_index, byte_offset, byte_size "
            "FROM samples WHERE sample_key = ?",
            (key,),
        )
        if sample is None:
            raise KeyError(f"Sample key not found: {key}")
        return ITarSamplePointer(
            tar_file_id=sample[0],
            byte_offset=sample[3],
            byte_size=sample[4],
        )
    def close(self):
        """Close the database connection."""
        if self.db is not None:
            self.db.thread_close()
            del self.db

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
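# Usage sketch (illustrative, not part of the module): resolving a sample and
# one of its parts to raw tar regions. The database path and keys are hypothetical.
#
#   with SqliteIndexReader(EPath("/data/ds/index.sqlite")) as reader:
#       sample = reader.get_sample_pointer_by_key("sample_000042")
#       # sample.tar_file_id / sample.byte_offset / sample.byte_size locate the
#       # whole sample (all of its tar members) inside one tar shard.
#       part = reader.get_sample_part("sample_000042", "jpg")
#       # part.raw_byte_offset / part.raw_byte_size delimit just that file's payload.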
class JoinIndexReader:
    """Reads a join index file in different ways.

    If a column is specified, only that column is read, otherwise the full rows.

    You can iterate over the rows, read a specific row by index, or get the full tensor.

    Each row contains (shard_idx, byte_offset, byte_size) for each column.
    """

    join_index_path: EPath
    join_index_file: BinaryIO
    column: Optional[int]
    num_columns: int
    has_iterated: bool
    index_row_position: int

    def __init__(self, join_index_path: EPath, column: Optional[int] = None):
        self.join_index_path = join_index_path
        self.join_index_byte_size = join_index_path.size()
        self.column = column

        self.join_index_file = join_index_path.open("rb")
        self.has_iterated = False
        self.index_row_position = -1

        # Read the header
        bytes_magic = self.join_index_file.read(8)
        assert isinstance(bytes_magic, bytes)
        assert bytes_magic[:4] == b"JIDX", f"Invalid magic bytes: {bytes_magic}"
        assert bytes_magic[4:8] == b"0001", f"Unsupported version: {bytes_magic[4:8]}"

        # Read the number of columns
        bytes_seckeys = self.join_index_file.read(8)
        assert isinstance(bytes_seckeys, bytes)
        self.num_columns = struct.unpack("q", bytes_seckeys)[0]

        self.index_row_position = 0
    def get_as_tensor(self):
        """Returns the join index as a tensor with shape (N, num_columns, 3)."""
        assert not self.has_iterated, "Cannot get_as_tensor after iterating"
        import torch

        # Read the raw bytes for all N * num_columns * 3 int64s.
        data = self.join_index_file.read()
        self.index_file_position = self.join_index_file.tell()
        assert len(data) % (8 * 3) == 0, (
            f"Index file reading: Expected multiple of 3 * 8 bytes, got {len(data)} bytes"
        )
        return torch.frombuffer(data, dtype=torch.int64).view(-1, self.num_columns, 3)

    def __len__(self):
        return (self.join_index_byte_size - 16) // (self.num_columns * 8 * 3)
    def __iter__(self):
        return self

    def _read_one_row(
        self, column: Optional[int] = None
    ) -> Union[None, List[Tuple[int, int, int]]]:
        row = []
        for col_idx in range(self.num_columns):
            if column is not None and col_idx != column:
                # Skip this column
                self.join_index_file.seek(8 * 3, 1)
                continue

            bytes_key = self.join_index_file.read(8 * 3)
            if not bytes_key:
                return None
            assert isinstance(bytes_key, bytes)
            key_tuple = struct.unpack("qqq", bytes_key)
            row.append(key_tuple)

        self.index_row_position += 1
        return row

    def __next__(self) -> Union[None, List[Tuple[int, int, int]]]:
        self.has_iterated = True
        return self._read_one_row(column=self.column)
    def tell_row(self) -> int:
        return self.index_row_position

    def __getitem__(self, idx: int) -> List[Tuple[int, int, int]]:
        """Reads the idx-th row of the join index file."""
        assert 0 <= idx < len(self), f"Index out of bounds: {idx} not in [0, {len(self)})"

        # Seek to the correct position
        if self.index_row_position != idx:
            self.join_index_file.seek(16 + idx * self.num_columns * 8 * 3)
            self.index_row_position = idx

        # Read the secondary keys
        row = self._read_one_row(column=self.column)
        assert row is not None, f"Failed to read row {idx}"
        return row

    def close(self):
        self.join_index_file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
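# Usage sketch (illustrative, not part of the module): reading back a two-column
# join index such as the one sketched above. The path is hypothetical. Note that
# __next__ returns None at the end of the file instead of raising StopIteration.
#
#   with JoinIndexReader(EPath("/data/dataset.join.idx")) as reader:
#       print(len(reader))                                  # number of rows
#       shard_idx, byte_offset, byte_size = reader[0][1]    # column 1 of row 0
#       for row in reader:                                  # or stream the rows
#           if row is None:
#               break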
Megatron-Energon/src/megatron/energon/flavors/webdataset/itar.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import contextlib
import struct
import tarfile
from types import TracebackType
from typing import BinaryIO, Dict, Generator, Optional, Tuple, Type, Union

from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib import EPath
from megatron.energon.retry_stream import RetryReadStream

ITAR_SUFFIX = ".tar.idx"


@edataclass
class ITarSamplePointer:
    """
    Points to a sample inside some tar file on disk.
    The tar_file_id refers to the tar_filenames in the reader.
    """

    # The index of the tar file, to be matched with the tar_filenames in the reader.
    tar_file_id: int
    # The byte offset of the sample in the tar file.
    byte_offset: int
    # The size of the sample in the tar file.
    byte_size: int


@edataclass
class ITarRawSamplePartPointer:
    """
    Points to a part of a sample inside some tar file on disk.
    The tar_file_id refers to the tar_filenames in the reader.
    The raw_byte_offset and raw_byte_size refer to the sample's part's raw data in the tar file.
    """

    # The index of the tar file, to be matched with the tar_filenames in the reader.
    tar_file_id: int
    # The byte offset of the file's data in the tar file.
    raw_byte_offset: int
    # The size of the file's data in the tar file.
    raw_byte_size: int
class TarIndexReader:
    def __init__(self, tar_path: Union[EPath, str]):
        tar_path = EPath(tar_path)
        index_path = tar_path.with_suffix(ITAR_SUFFIX)

        self._length = index_path.size() // 8
        self.itar = index_path.open("rb")

    def __getitem__(self, index: int) -> int:
        if index >= self._length or index < 0:
            raise IndexError(f"Index {index} out of range")
        if self.itar.tell() != 8 * index:
            self.itar.seek(8 * index)
        return struct.unpack("Q", self.itar.read(8))[0]

    def __iter__(self) -> Generator[int, None, None]:
        self.itar.seek(0)
        while True:
            raw = self.itar.read(8)
            if len(raw) == 0:
                break
            assert len(raw) == 8
            yield struct.unpack("Q", raw)[0]

    def __len__(self) -> int:
        return self._length

    def close(self):
        self.itar.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
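# Format note and usage sketch (illustrative, not part of the module): the
# ".tar.idx" file that sits next to a tar shard is a flat array of 8-byte
# unsigned integers (struct format "Q"), so the offset entry for index i lives
# at file position 8 * i. The path below is hypothetical.
#
#   with TarIndexReader("/data/shards/data-000000.tar") as index:
#       print(len(index))     # number of offset entries in the index
#       print(index[10])      # byte offset of entry 10 inside the tar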
class TarIndexWriter:
    def __init__(self, tar_path: EPath):
        self.final_name = tar_path.with_suffix(ITAR_SUFFIX)
        self.tmp_name = tar_path.with_suffix(ITAR_SUFFIX + ".tmp")
        self.itar = self.tmp_name.open("wb")

    def append(self, offset: int):
        self.itar.write(struct.pack("Q", offset))

    def close(self, finalize: bool = True):
        self.itar.close()
        if finalize:
            self.tmp_name.move(self.final_name)
        else:
            self.tmp_name.unlink()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close(finalize=exc_val is None)
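# Usage sketch (illustrative, not part of the module): the writer streams the
# offsets into "<name>.tar.idx.tmp" and only moves it to "<name>.tar.idx" when
# the with-block exits without an exception, so readers never see a partial
# index. The path and offsets below are hypothetical.
#
#   with TarIndexWriter(EPath("/data/shards/data-000000.tar")) as writer:
#       for offset in (0, 10240, 20480):
#           writer.append(offset)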
class SubFileReader(BinaryIO):
    """A file-like object that reads a subfile (i.e. offset, size defined portion) of a larger
    file."""

    def __init__(self, stream: BinaryIO, offset: int, size: int):
        self.offset = offset
        self._pos = 0
        self.size = size
        self.stream = stream
        self.stream.seek(self.offset)

    def read(self, n: int = -1) -> bytes:
        if n == -1:
            n = self.size - self._pos
        else:
            n = min(n, self.size - self._pos)
        if n == 0:
            return b""
        read = self.stream.read(n)
        self._pos += len(read)
        return read

    def seek(self, offset: int, whence: int = 0) -> int:
        if whence == 0:
            self._pos = offset
        elif whence == 1:
            self._pos += offset
        elif whence == 2:
            self._pos = self.size + offset
        else:
            raise ValueError("Invalid whence value")
        self._pos = max(0, min(self._pos, self.size))
        self.stream.seek(self.offset + self._pos)
        return self._pos

    def tell(self) -> int:
        return self._pos

    def __enter__(self) -> BinaryIO:
        return self

    def __exit__(
        self,
        exc_type: Type[BaseException],
        exc_val: BaseException,
        exc_tb: TracebackType,
    ) -> None:
        self.close()

    def close(self) -> None:
        self.stream.close()

    def isatty(self) -> bool:
        return False

    def seekable(self) -> bool:
        return True

    def writable(self) -> bool:
        return False
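# Usage sketch (illustrative, not part of the module): expose bytes
# [4096, 4096 + 1024) of a larger file as a small self-contained stream that
# cannot read past its window. The path is hypothetical.
#
#   with open("/data/shards/data-000000.tar", "rb") as raw:
#       with SubFileReader(raw, offset=4096, size=1024) as window:
#           window.seek(0, 2)               # seek relative to the window's end
#           assert window.tell() == 1024
#           window.seek(0)
#           payload = window.read()         # at most 1024 bytes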
def get_itar_byte_offset(
    path: Union[str, EPath],
    sample_offset: int = 0,
) -> int:
    """Gets the byte offset from sample offsets."""
    if sample_offset == 0:
        return 0
    with TarIndexReader(path) as itar:
        return itar[sample_offset]


@edataclass
class CacheEntry:
    tar_index_reader: TarIndexReader
    lookahead_offset: Optional[int] = None
    lookahead_byteoffset: Optional[int] = None
class CachedItarOffsetReader:
    """
    This class is a high-level wrapper around TarIndexReader that caches some
    of the recent lookups for faster access. It is designed for the case when
    you need to read multiple offsets from the same tar file or from multiple
    tar files.

    Args:
        cache_size: The number of entries to keep in the cache. By default, we keep 32.
    """

    def __init__(self, cache_size: int = 32):
        # Maps (tar_file, current_offset) -> CacheEntry
        self.tar_index_reader_cache: Dict[Tuple[str, int], CacheEntry] = {}
        self.cache_size = cache_size

    def _find_or_create_entry(
        self,
        tar_file: Union[str, "EPath"],
        sample_offset: int,
    ) -> Tuple[Tuple[str, int], CacheEntry]:
        """
        1. If we already have a key == (tar_file, sample_offset), return it.
        2. Otherwise, create a new entry (and evict if necessary).
        """
        tar_file = str(tar_file)
        key = (tar_file, sample_offset)

        # Direct hit in the cache?
        if key in self.tar_index_reader_cache:
            return key, self.tar_index_reader_cache[key]

        # We didn't find an existing entry. Create a new one.
        # Evict if needed.
        if len(self.tar_index_reader_cache) >= self.cache_size:
            self._evict_one_entry()

        new_reader = TarIndexReader(tar_file)
        cache_entry = CacheEntry(tar_index_reader=new_reader)
        self.tar_index_reader_cache[key] = cache_entry
        return key, cache_entry
    def _evict_one_entry(self):
        """
        Evict the 'oldest' item in the cache. Here we just pop the first item
        returned by iter(...), which in Python 3.7+ is insertion order, so this
        is FIFO rather than a strict LRU. For true LRU, use OrderedDict or similar.
        """
        oldest_key = next(iter(self.tar_index_reader_cache))
        oldest_entry = self.tar_index_reader_cache.pop(oldest_key)
        oldest_entry.tar_index_reader.close()
    def _get_itar_byte_offset_with_entry(
        self,
        cache_entry: CacheEntry,
        sample_offset: int,
    ) -> Tuple[int, int]:
        """
        Return (start_byte_offset, length_to_next),
        possibly using the per-entry lookahead for speed.
        """
        tar_index_reader = cache_entry.tar_index_reader

        # If offset=0, define the result as byte offset=0 for convenience
        if sample_offset == 0:
            result_byte_offset = 0
        elif sample_offset == cache_entry.lookahead_offset:
            # Reuse the previously cached byte offset from the lookahead
            assert cache_entry.lookahead_byteoffset is not None, (
                "Lookahead offset matched but no lookahead byte offset found."
            )
            result_byte_offset = cache_entry.lookahead_byteoffset
        else:
            # Normal random access
            result_byte_offset = tar_index_reader[sample_offset]

        # Prepare the lookahead for (sample_offset + 1)
        next_offset = sample_offset + 1
        try:
            cache_entry.lookahead_byteoffset = tar_index_reader[next_offset]
            cache_entry.lookahead_offset = next_offset
        except IndexError:
            cache_entry.lookahead_offset = None
            cache_entry.lookahead_byteoffset = None

        # length = difference to the next offset, or 0 if none
        if cache_entry.lookahead_byteoffset is not None:
            length = cache_entry.lookahead_byteoffset - result_byte_offset
        else:
            length = 0

        return result_byte_offset, length
    def get_itar_byte_offset(
        self,
        tar_file: Union[str, "EPath"],
        sample_offset: int = 0,
    ) -> Tuple[int, int]:
        """
        High-level API to get the byte offset and length for the given file & sample_offset.
        """
        # Find or create the suitable CacheEntry
        key, entry = self._find_or_create_entry(tar_file, sample_offset)

        # Use (and update) the per-entry lookahead logic
        result_byte_offset, length = self._get_itar_byte_offset_with_entry(entry, sample_offset)

        # Update the cache entry with the new offset
        self.tar_index_reader_cache.pop(key)
        if entry.lookahead_offset is not None:
            new_key = (str(tar_file), entry.lookahead_offset)
            if new_key not in self.tar_index_reader_cache:
                self.tar_index_reader_cache[new_key] = entry
            else:
                # Already have this entry in the cache, so we can close the reader and use the existing one
                # TODO: We may actually want to keep multiple readers open, because there may be
                # multiple sequences to the same sequence.
                entry.tar_index_reader.close()
        else:
            # No lookahead, so we can close the reader
            entry.tar_index_reader.close()

        return result_byte_offset, length
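# Usage sketch (illustrative, not part of the module): resolving consecutive
# sample offsets of the same shard benefits from the per-entry lookahead, since
# each call also caches the byte offset of sample_offset + 1. The path is hypothetical.
#
#   offsets = CachedItarOffsetReader(cache_size=16)
#   byte_offset, byte_size = offsets.get_itar_byte_offset(
#       "/data/shards/data-000000.tar", sample_offset=10
#   )
#   # The next call hits the cached lookahead instead of re-reading the index file:
#   next_offset, next_size = offsets.get_itar_byte_offset(
#       "/data/shards/data-000000.tar", sample_offset=11
#   )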
class ITarFile(tarfile.TarFile):
    """This class is a subclass of tarfile.TarFile that allows for reading a tarfile
    with random access while keeping the file open.

    Usage:
        with open(filename, "rb") as fileobj:
            with ITarFile.open(fileobj=fileobj, mode="r:") as f:
                f.offset = 101888
                tarinfo = f.next()
                print(tarinfo.name)
                member_bytes = f.extractfile(tarinfo)

                # Read more offsets here ...
    """

    def __init__(self, *args, **kwargs):
        self.in_init = True
        try:
            super().__init__(*args, **kwargs)
        finally:
            self.in_init = False

    def next(self):
        if self.in_init:
            # Don't automatically read the first member
            return None

        if self.offset != self.fileobj.tell():
            # This prevents tarfile from reading the one byte before
            self.fileobj.seek(self.offset)
        return super().next()
@contextlib.contextmanager
def open_itar(
    path: Union[str, EPath],
    byte_offset: int = 0,
    byte_size: Optional[int] = None,
):
    """
    Open an indexed tarfile with offset and size.

    Args:
        path: Path to the tarfile to open
        byte_offset: Byte offset within the file
        byte_size: Size of the file to read

    Returns:
        The opened tarfile
    """
    path = EPath(path)

    # TODO: if tar file startswith(b"\x1f\x8b\x08") -> Seekable gzip file
    with path.open("rb") as f:
        if f.read(3) == b"\x1f\x8b\x08":
            # Open as seekable tgz
            raise ValueError("Seekable tgz not supported yet")

    if byte_offset != 0 or byte_size is not None:
        if byte_size is None:
            byte_size = path.size() - byte_offset
        with RetryReadStream(path) as stream:
            with SubFileReader(
                stream,
                offset=byte_offset,
                size=byte_size,
            ) as fileobj:
                with ITarFile.open(fileobj=fileobj, mode="r:") as f:
                    yield f
    else:
        with RetryReadStream(path) as fileobj:
            with ITarFile.open(fileobj=fileobj, mode="r:") as f:
                yield f
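# Usage sketch (illustrative, not part of the module): combining an
# ITarSamplePointer-style (byte_offset, byte_size) region with open_itar to
# iterate only the members of a single sample. The path and numbers are hypothetical.
#
#   with open_itar("/data/shards/data-000000.tar", byte_offset=101888, byte_size=20480) as tar:
#       member = tar.next()
#       while member is not None:
#           data = tar.extractfile(member).read()
#           member = tar.next()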