Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
Bw-bestperf
Qwen2.5-14B-Instruct_dcu-megatron
Commits
f356f546
Commit
f356f546
authored
Feb 04, 2026
by
maming
Browse files
Initial commit
parents
Pipeline
#3339
canceled with stages
Changes
346
Pipelines
1
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4904 additions
and
0 deletions
+4904
-0
Megatron-Energon/src/megatron/energon/flavors/webdataset/itar_reader.py
...on/src/megatron/energon/flavors/webdataset/itar_reader.py
+625
-0
Megatron-Energon/src/megatron/energon/flavors/webdataset/joined_webdataset.py
.../megatron/energon/flavors/webdataset/joined_webdataset.py
+252
-0
Megatron-Energon/src/megatron/energon/flavors/webdataset/metadata.py
...ergon/src/megatron/energon/flavors/webdataset/metadata.py
+164
-0
Megatron-Energon/src/megatron/energon/flavors/webdataset/prepare.py
...nergon/src/megatron/energon/flavors/webdataset/prepare.py
+456
-0
Megatron-Energon/src/megatron/energon/flavors/webdataset/sample_decoder.py
...src/megatron/energon/flavors/webdataset/sample_decoder.py
+142
-0
Megatron-Energon/src/megatron/energon/flavors/webdataset/sample_loader.py
.../src/megatron/energon/flavors/webdataset/sample_loader.py
+469
-0
Megatron-Energon/src/megatron/energon/flavors/webdataset/sharder.py
...nergon/src/megatron/energon/flavors/webdataset/sharder.py
+407
-0
Megatron-Energon/src/megatron/energon/flavors/webdataset/standard_webdataset.py
...egatron/energon/flavors/webdataset/standard_webdataset.py
+64
-0
Megatron-Energon/src/megatron/energon/flavors/webdataset/structs.py
...nergon/src/megatron/energon/flavors/webdataset/structs.py
+128
-0
Megatron-Energon/src/megatron/energon/flavors/webdataset/thread_local_sqlite.py
...egatron/energon/flavors/webdataset/thread_local_sqlite.py
+154
-0
Megatron-Energon/src/megatron/energon/fork_hook.py
Megatron-Energon/src/megatron/energon/fork_hook.py
+190
-0
Megatron-Energon/src/megatron/energon/loader.py
Megatron-Energon/src/megatron/energon/loader.py
+119
-0
Megatron-Energon/src/megatron/energon/local_copy.py
Megatron-Energon/src/megatron/energon/local_copy.py
+49
-0
Megatron-Energon/src/megatron/energon/metadataset/__init__.py
...tron-Energon/src/megatron/energon/metadataset/__init__.py
+17
-0
Megatron-Energon/src/megatron/energon/metadataset/dataset_loader.py
...nergon/src/megatron/energon/metadataset/dataset_loader.py
+111
-0
Megatron-Energon/src/megatron/energon/metadataset/join_dataset_loader.py
...n/src/megatron/energon/metadataset/join_dataset_loader.py
+559
-0
Megatron-Energon/src/megatron/energon/metadataset/loader.py
Megatron-Energon/src/megatron/energon/metadataset/loader.py
+94
-0
Megatron-Energon/src/megatron/energon/metadataset/loader_interface.py
...rgon/src/megatron/energon/metadataset/loader_interface.py
+96
-0
Megatron-Energon/src/megatron/energon/metadataset/metadataset.py
...n-Energon/src/megatron/energon/metadataset/metadataset.py
+202
-0
Megatron-Energon/src/megatron/energon/metadataset/metadataset_v2.py
...nergon/src/megatron/energon/metadataset/metadataset_v2.py
+606
-0
No files found.
Too many changes to show.
To preserve performance only
346 of 346+
files are displayed.
Plain diff
Email patch
Megatron-Energon/src/megatron/energon/flavors/webdataset/itar_reader.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
abc
import
ABC
,
abstractmethod
from
bisect
import
bisect_right
from
typing
import
(
Callable
,
Dict
,
Generator
,
Generic
,
List
,
Optional
,
Tuple
,
TypeVar
,
Union
,
overload
,
)
from
megatron.energon.epathlib
import
EPath
from
megatron.energon.flavors.webdataset.config
import
skip_meta_re
,
split_name_re
from
megatron.energon.flavors.webdataset.indexing
import
JoinIndexReader
,
SqliteIndexReader
from
megatron.energon.flavors.webdataset.itar
import
(
CachedItarOffsetReader
,
ITarFile
,
ITarRawSamplePartPointer
,
ITarSamplePointer
,
)
from
megatron.energon.flavors.webdataset.metadata
import
get_info_shard_files
from
megatron.energon.flavors.webdataset.structs
import
FilteredSample
,
ShardInfo
from
megatron.energon.source_info
import
SourceInfo
T_index
=
TypeVar
(
"T_index"
,
covariant
=
False
)
class ITarReader(ABC, Generic[T_index]):
    """
    An abstract base class for reading a sequence of tar files containing samples.
    The common usage patterns and random-access interfaces are provided here.

    Concrete subclasses provide ``__len__``, ``__str__`` and
    ``_get_itar_sample_pointer`` (index -> tar file id / byte range); this base
    class implements the bounded tar-file handle cache and sample extraction.

    Args:
        base_path: The base path for the relative tar_filepaths.
        tar_filenames: The list of tar file names.
        tar_filepaths: The corresponding list of full paths to the tar files.
        part_filter: An optional filter function to select parts of the samples.
        itar_cache_size: The number of tar readers to keep open at the same time.
        sample_filter: An optional filter function to select samples by their key.
    """

    # Base path for the relative tar file names; also reported in SourceInfo.
    base_path: EPath
    # Tar file names, indexed by tar_file_id.
    tar_filenames: List[str]
    # Full paths to the tar files, parallel to tar_filenames.
    tar_filepaths: List[EPath]
    # Optional predicate on a part's extension; non-matching parts are skipped.
    part_filter: Optional[Callable[[str], bool]]
    # Open tar file handles keyed by tar_file_id, bounded by itar_cache_size.
    itar_files_cache: Dict[int, ITarFile]
    # Maximum number of simultaneously open tar files.
    itar_cache_size: int
    # Optional predicate on "{shard_name}/{base_name}"; failing samples yield None.
    sample_filter: Optional[Callable[[str], bool]]

    def __init__(
        self,
        base_path: EPath,
        tar_filenames: List[str],
        tar_filepaths: List[EPath],
        part_filter: Optional[Callable[[str], bool]] = None,
        itar_cache_size: int = 5,
        sample_filter: Optional[Callable[[str], bool]] = None,
    ):
        assert len(tar_filenames) == len(tar_filepaths), (
            f"tar_filenames length ({len(tar_filenames)}) does not match "
            f"tar_filepaths length ({len(tar_filepaths)})"
        )
        self.base_path = base_path
        self.tar_filenames = tar_filenames
        self.tar_filepaths = tar_filepaths
        self.part_filter = part_filter
        self.itar_files_cache = {}
        self.itar_cache_size = itar_cache_size
        self.sample_filter = sample_filter

    @abstractmethod
    def __len__(self) -> int:
        """Returns the total number of samples in the reader."""
        raise NotImplementedError

    @abstractmethod
    def __str__(self) -> str:
        """
        Must return a descriptive string of the concrete reader.
        """
        raise NotImplementedError

    def close(self):
        """Close all cached tar file handles (and their underlying file objects)
        and clear the cache."""
        for tar_file in self.itar_files_cache.values():
            tar_file.fileobj.close()
            tar_file.close()
        self.itar_files_cache.clear()

    @abstractmethod
    def _get_itar_sample_pointer(self, idx: T_index) -> ITarSamplePointer:
        """Get the ITarSample object for the given index."""
        raise NotImplementedError

    def _get_itarfile_cached(self, tar_file_id: int) -> ITarFile:
        """
        Get the ITarFile object for the given tar file id.
        If the file is not already open, open it. If we exceed
        the global cache limit, close the least recently used file.

        NOTE(review): eviction relies on dict insertion order and entries are
        never re-ordered on access, so this is effectively FIFO
        (least-recently-*opened*), not strict LRU.
        """
        if tar_file_id not in self.itar_files_cache:
            file_object = self.tar_filepaths[tar_file_id].open(mode="rb")
            tar_file = ITarFile.open(fileobj=file_object, mode="r:")
            self.itar_files_cache[tar_file_id] = tar_file

            # If we hit the limit of open files, close the least recently used file
            while len(self.itar_files_cache) > self.itar_cache_size:
                # Get the oldest file (first inserted key)
                lru_key = next(iter(self.itar_files_cache))
                self.itar_files_cache[lru_key].fileobj.close()
                self.itar_files_cache[lru_key].close()
                del self.itar_files_cache[lru_key]

        return self.itar_files_cache[tar_file_id]

    def _get_part_by_raw_sample_pointer(
        self,
        raw_sample_pointer: ITarRawSamplePartPointer,
        entry_name: str,
    ) -> tuple[bytes, SourceInfo]:
        """
        Get a sample part and the source info from the dataset.

        Args:
            raw_sample_pointer: The raw data sample pointer to get the sample from.
            entry_name: Full name of the entry inside the tar; used for the
                SourceInfo's index and file_names fields.

        Returns:
            A tuple of (raw data bytes, SourceInfo describing their origin).
        """
        # Open the tar file (cached)
        tar_file = self._get_itarfile_cached(raw_sample_pointer.tar_file_id)
        shard_name = self.tar_filenames[raw_sample_pointer.tar_file_id]

        # Get the raw data from the tar file.
        # Save and restore the file position so the cached handle's state
        # is unaffected by this raw read.
        rest = tar_file.fileobj.tell()
        tar_file.fileobj.seek(raw_sample_pointer.raw_byte_offset)
        raw_data = tar_file.fileobj.read(raw_sample_pointer.raw_byte_size)
        tar_file.fileobj.seek(rest)

        return raw_data, SourceInfo(
            dataset_path=self.base_path,
            index=entry_name,
            shard_name=shard_name,
            file_names=(entry_name,),
        )

    def _get_item_by_sample_pointer(
        self,
        sample_pointer: ITarSamplePointer,
        restore_index: str | int,
        entry_match_fn: Optional[Callable[[str], bool]] = None,
    ) -> FilteredSample | None:
        """
        Get a sample from the dataset or slice it.

        Args:
            sample_pointer: The sample pointer to get the sample from.
            restore_index: Index stored in the sample's __restore_key__ and
                SourceInfo (the global sample index or the lookup key).
            entry_match_fn: An optional function to filter the entries in the sample.
                If given, it takes precedence over self.part_filter.

        Returns:
            The sample, or None if it is rejected by self.sample_filter.

        Raises:
            ValueError: If the tar file ends prematurely, if entries with
                differing base names are found within the pointed-to byte
                range, or if no valid files are found at all.
        """
        # Open the tar file (cached)
        tar_file = self._get_itarfile_cached(sample_pointer.tar_file_id)
        shard_name = self.tar_filenames[sample_pointer.tar_file_id]

        sample_base_name = None
        sample_name = None
        group_parts: Dict[str, bytes] = {}
        file_names: list[str] = []

        # Position the tar file at the correct offset
        tar_file.offset = sample_pointer.byte_offset
        # Iterate all tar members within the sample's byte range.
        while tar_file.offset < sample_pointer.byte_offset + sample_pointer.byte_size:
            tarinfo = tar_file.next()
            if tarinfo is None:
                raise ValueError(
                    f"Unexpected end of tar file: {self.tar_filenames[sample_pointer.tar_file_id]}"
                )
            fname = tarinfo.name

            # Skip non-regular files and nameless entries
            if not tarinfo.isfile() or fname is None:
                continue
            # Skip metadata entries (pattern defined in webdataset config)
            if skip_meta_re.match(fname):
                continue

            # Extract the base_name and extension
            m = split_name_re.match(fname)
            if not m:
                continue
            cur_base_name, cur_ext = m.groups()

            if sample_base_name is None:
                # First entry determines the sample's base name / key
                sample_base_name = cur_base_name
                sample_name = f"{shard_name}/{cur_base_name}"
                if self.sample_filter is not None and not self.sample_filter(sample_name):
                    # Sample rejected as a whole
                    return None
            else:
                if sample_base_name != cur_base_name:
                    raise ValueError(
                        f"Inconsistent sample base name: {sample_base_name} vs {cur_base_name}"
                    )

            if entry_match_fn is not None:
                # If entry_match_fn is provided, use it to determine if we should take this entry
                take_entry = entry_match_fn(fname)
            else:
                # If no entry_match_fn is provided, use the part_filter to determine if we should take this entry
                take_entry = self.part_filter is None or self.part_filter(cur_ext)

            if take_entry:
                member_bytes = tar_file.extractfile(tarinfo).read()
                group_parts[cur_ext] = member_bytes
                file_names.append(fname)

        if sample_base_name is None:
            raise ValueError(f"No valid files found in sample {sample_pointer}")

        # Assemble the sample dict: dunder metadata plus one entry per part
        # extension mapping to its raw bytes.
        return FilteredSample(
            __key__=f"{shard_name}/{sample_base_name}",
            __shard__=self.tar_filenames[sample_pointer.tar_file_id],
            __restore_key__=("Webdataset", restore_index),
            __sources__=(
                SourceInfo(
                    dataset_path=self.base_path,
                    index=restore_index,
                    shard_name=shard_name,
                    file_names=tuple(file_names),
                ),
            ),
            **group_parts,
        )

    def __getitem__(self, idx: T_index) -> FilteredSample | None:
        """
        Get a sample from the dataset or slice it.

        NOTE(review): the assert restricts this base implementation to integer
        indices; the str-keyed subclass (SqliteITarEntryReader) overrides
        __getitem__ entirely.
        """
        assert isinstance(idx, int), f"Invalid argument type for __getitem__: {type(idx)}"
        sample_pointer = self._get_itar_sample_pointer(idx)
        return self._get_item_by_sample_pointer(sample_pointer, idx)
class JoinIndexFileITarReader(ITarReader[int]):
    """
    A concrete ITarReader that reads samples from a join index file (via JoinIndexReader).

    A small pool of JoinIndexReader instances is kept, keyed by the row offset
    at which each reader is positioned, so sequential access can reuse an
    already-positioned reader instead of seeking.
    """

    # Path to the join index file.
    index_file: EPath
    # Column of the join index this reader resolves.
    column: int
    # Open index readers keyed by the row offset they are positioned at.
    index_reader_cache: Dict[int, JoinIndexReader]
    # Maximum number of simultaneously open index readers.
    index_reader_cache_size: int

    def __init__(
        self,
        index_file: EPath,
        column: int,
        tar_filenames: List[str],
        base_path: EPath,
        part_filter: Optional[Callable[[str], bool]] = None,
        itar_cache_size: int = 5,
        sample_filter: Optional[Callable[[str], bool]] = None,
    ):
        """
        Args:
            index_file: Path to the join index file.
            column: Which column of the join index to read pointers from.
            tar_filenames: Tar file names, relative to base_path.
            base_path: Base path the tar file names are resolved against.
            part_filter: Optional filter on part extensions.
            itar_cache_size: Size of both the tar-file and index-reader caches.
            sample_filter: Optional filter on sample keys.
        """
        self.index_file = index_file
        self.column = column

        # Create the full path to each tar file
        tar_filepaths = [base_path / fn for fn in tar_filenames]

        self.index_reader_cache = {}
        self.index_reader_cache_size = itar_cache_size

        super().__init__(
            base_path=base_path,
            tar_filenames=tar_filenames,
            tar_filepaths=tar_filepaths,
            part_filter=part_filter,
            itar_cache_size=itar_cache_size,
            sample_filter=sample_filter,
        )

    def _get_join_index_reader_cached(self, sample_idx: int) -> JoinIndexReader:
        """
        Get the JoinIndexReader object for the given sample index, or create it if it doesn't exist.

        A cache hit means a reader is already positioned at this row (see
        _get_itar_sample_pointer, which re-keys readers by their new position).
        """
        if sample_idx not in self.index_reader_cache:
            index_reader = JoinIndexReader(self.index_file, column=self.column)
            self.index_reader_cache[sample_idx] = index_reader

            # If we hit the limit of open files, close the least recently used file
            # NOTE(review): eviction is insertion-ordered (FIFO), not strict LRU.
            while len(self.index_reader_cache) > self.index_reader_cache_size:
                # Get the oldest file
                lru_key = next(iter(self.index_reader_cache))
                self.index_reader_cache[lru_key].close()
                del self.index_reader_cache[lru_key]

        return self.index_reader_cache[sample_idx]

    def _get_itar_sample_pointer(self, sample_idx: int) -> ITarSamplePointer:
        """
        Get the ITarSample object for the given index.
        """
        index_reader = self._get_join_index_reader_cached(sample_idx)
        row = index_reader[sample_idx]

        # Update cache entry: reading advanced the reader, so re-key it by the
        # row offset it now points at, enabling reuse on sequential access.
        new_offset = index_reader.tell_row()
        del self.index_reader_cache[sample_idx]
        self.index_reader_cache[new_offset] = index_reader

        assert len(row) == 1
        shard_idx, byte_offset, byte_size = row[0]

        return ITarSamplePointer(
            tar_file_id=shard_idx,
            byte_offset=byte_offset,
            byte_size=byte_size,
        )

    def __len__(self) -> int:
        """Return the number of rows in the join index."""
        try:
            # Get any reader, they will all work
            index_reader = next(iter(self.index_reader_cache.values()))
        except StopIteration:
            # If there's no reader yet, we need to create one to get the length
            index_reader = self._get_join_index_reader_cached(0)
        return len(index_reader)

    def __str__(self) -> str:
        """Return a descriptive string of this reader."""
        return (
            f"JoinIndexFileITarReader("
            f"len={len(self)}, base_path={self.base_path}, "
            f"len(shards)={len(self.tar_filenames)}, "
            f"shards=[{self.tar_filenames[0] if self.tar_filenames else 'N/A'}, ...])"
        )
class ShardInfosITarReader(ITarReader[int]):
    """
    A concrete ITarReader that constructs its internal sample list from a list of ShardInfos.

    Sample lookup is two-step: a binary search over the cumulative shard
    sample counts finds the owning shard, then the shard's .tar.idx file
    (via CachedItarOffsetReader) yields the byte offset/size within the tar.
    """

    # The shard infos this reader was constructed from.
    shard_infos: List[ShardInfo]
    # For each shard, the dense id of its tar file (index into tar_filenames).
    shard_tar_file_idxs: List[int]
    # Cumulative sample counts; entry i is the global index of shard i's first
    # sample, with the grand total as the final entry.
    shard_count_cumsum: List[int]
    # Cached reader for the per-tar .tar.idx offset files.
    cached_offset_reader: CachedItarOffsetReader

    def __init__(
        self,
        base_path: EPath,
        shard_infos: List[ShardInfo],
        part_filter: Optional[Callable[[str], bool]] = None,
        itar_cache_size: int = 5,
        sample_filter: Optional[Callable[[str], bool]] = None,
    ):
        # Registry of distinct tar files: name -> (dense id, full path).
        tar_registry: Dict[str, Tuple[int, EPath]] = {}

        self.shard_infos = shard_infos

        # Build the cumulative sample counts and the shard -> tar file id
        # mapping in one pass over the shard infos.
        self.shard_count_cumsum = [0]
        self.shard_tar_file_idxs = []
        running_total = 0
        for shard in shard_infos:
            if shard.name not in tar_registry:
                # First time we see this tar file: assign the next dense id.
                tar_registry[shard.name] = (len(tar_registry), shard.path)
            running_total += shard.count
            self.shard_count_cumsum.append(running_total)
            self.shard_tar_file_idxs.append(tar_registry[shard.name][0])

        # Registry insertion order defines the tar_file_id numbering.
        tar_filenames = list(tar_registry.keys())
        tar_filepaths = [path for _, path in tar_registry.values()]

        # Instantiate cached reader for the .tar.idx files
        self.cached_offset_reader = CachedItarOffsetReader(cache_size=itar_cache_size)

        super().__init__(
            base_path=base_path,
            tar_filenames=tar_filenames,
            tar_filepaths=tar_filepaths,
            part_filter=part_filter,
            itar_cache_size=itar_cache_size,
            sample_filter=sample_filter,
        )

    def _get_itar_sample_pointer(self, idx: int) -> ITarSamplePointer:
        """
        Resolve a global sample index to a pointer (tar file id, byte
        offset, byte size) into the owning tar file.
        """
        # Binary search: the shard whose sample range covers idx.
        shard_idx = bisect_right(self.shard_count_cumsum, idx) - 1
        if not 0 <= shard_idx < len(self.shard_infos):
            raise IndexError(f"Index out of bounds: {idx}")

        shard = self.shard_infos[shard_idx]
        # Position of the sample within its shard file.
        local_idx = idx - self.shard_count_cumsum[shard_idx]

        # Translate the per-shard sample position into a byte range by
        # consulting the shard's .tar.idx file.
        byte_offset, byte_size = self.cached_offset_reader.get_itar_byte_offset(
            shard.path, local_idx
        )

        return ITarSamplePointer(
            tar_file_id=self.shard_tar_file_idxs[shard_idx],
            byte_offset=byte_offset,
            byte_size=byte_size,
        )

    def __len__(self) -> int:
        """Total sample count (final cumulative-sum entry)."""
        return self.shard_count_cumsum[-1]

    def __str__(self) -> str:
        """Return a descriptive string of this reader."""
        first_shard = self.tar_filenames[0] if self.tar_filenames else "N/A"
        return (
            f"ShardInfosITarReader(len={len(self)}, base_path={self.base_path}, "
            f"len(shards)={len(self.tar_filenames)}, shards=[{first_shard}, ...])"
        )
class SqliteITarEntryReader(ITarReader[str]):
    """
    A concrete ITarReader that constructs its internal sample list from a SQLite database.

    Samples are addressed by string key (optionally by full entry name,
    i.e. ``"{sample_key}.{ext}"``, when key_is_full_entryname is True).
    """

    # Reader for the dataset's index.sqlite database.
    sqlite_reader: SqliteIndexReader
    # Whether the database has the sample_parts table (older DBs may not).
    # NOTE(review): annotated int but used as a truth flag — presumably bool;
    # confirm SqliteIndexReader.db_has_sample_parts's return type.
    db_has_sample_parts: int
    # If True, __getitem__ keys are full entry names and return raw bytes.
    key_is_full_entryname: bool

    def __init__(
        self,
        base_path: EPath,
        part_filter: Optional[Callable[[str], bool]] = None,
        itar_cache_size: int = 5,
        sample_filter: Optional[Callable[[str], bool]] = None,
        key_is_full_entryname: bool = False,
    ):
        """
        Args:
            base_path: Root of the dataset; the index database is expected at
                base_path / MAIN_FOLDER_NAME / "index.sqlite".
            part_filter: Optional filter on part extensions.
            itar_cache_size: Number of tar readers to keep open.
            sample_filter: Optional filter on sample keys.
            key_is_full_entryname: If True, __getitem__ takes a full entry name
                ("key.ext") and returns (bytes, SourceInfo) instead of a
                FilteredSample.
        """
        # Local imports to avoid import cycles with config/indexing modules.
        from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME
        from megatron.energon.flavors.webdataset.indexing import SqliteIndexReader

        # shard_name_to_info_idx = {name: i for i, name in enumerate(wds_meta.info_shard_files)}
        tar_filenames = get_info_shard_files(base_path)
        tar_filepaths = [base_path / fn for fn in tar_filenames]

        # Initialize the SQLite reader
        sqlite_path = base_path / MAIN_FOLDER_NAME / "index.sqlite"
        self.sqlite_reader = SqliteIndexReader(sqlite_path)
        self.db_has_sample_parts = self.sqlite_reader.db_has_sample_parts()
        self.key_is_full_entryname = key_is_full_entryname

        super().__init__(
            base_path=base_path,
            tar_filenames=tar_filenames,
            tar_filepaths=tar_filepaths,
            part_filter=part_filter,
            itar_cache_size=itar_cache_size,
            sample_filter=sample_filter,
        )

    def _get_itar_sample_pointer(self, sample_key: str) -> ITarSamplePointer:
        """
        Get the ITarSample object for the given sample key (via the database).
        """
        return self.sqlite_reader.get_sample_pointer_by_key(sample_key)

    def list_all_samples(self) -> Generator[Tuple[str, int, int], None, None]:
        """List all samples in the jsonl file.

        Returns:
            A generator of tuples of (sample_key, size, tar_file_id)
        """
        return self.sqlite_reader.list_all_samples()

    def list_all_sample_parts(self) -> Generator[Tuple[str, int, int], None, None]:
        """List all sample parts in the jsonl file.

        Returns:
            A generator of tuples of (sample_key + "." + part_name, size, tar_file_id)
        """
        return self.sqlite_reader.list_all_sample_parts()

    def list_sample_parts(
        self, sample_key: str, slow_mode: bool = False
    ) -> Generator[Tuple[str, int, int], None, None]:
        """Given a sample key, list all its parts. (E.g. given 001, list 001.jpg, 001.json, etc.)

        Args:
            sample_key: The sample key to list the parts of.
            slow_mode: If False, answer from the database's sample_parts table.
                If True, read the actual sample from the tar file and list its
                parts (for older databases without the sample_parts table).

        Returns:
            A generator of tuples of (part_name, size, tar_file_id)
        """
        if not slow_mode:
            yield from self.sqlite_reader.list_sample_parts(sample_key)
        else:
            # Fall back to extracting the sample from the tar file and
            # enumerating its non-dunder entries.
            sample_pointer = self._get_itar_sample_pointer(sample_key)
            sample = self._get_item_by_sample_pointer(sample_pointer, 0, entry_match_fn=None)
            assert isinstance(sample, dict), f"Sample not found: {sample_pointer}"
            for ext in sample.keys():
                if not ext.startswith("__"):
                    yield ext, len(sample[ext]), sample_pointer.tar_file_id

    def get_total_size(self) -> int:
        """Return the total byte size reported by the database."""
        return self.sqlite_reader.get_total_size()

    @overload
    def __getitem__(self, key: str) -> Union[FilteredSample, tuple[bytes, SourceInfo]]: ...

    @overload
    def __getitem__(self, key: slice) -> "ITarReader": ...

    def __getitem__(
        self, key: Union[slice, str]
    ) -> Union[FilteredSample, tuple[bytes, SourceInfo], ITarReader]:
        """
        Either get a sample from the dataset by the sample key including all its entries,
        or get the bytes of a specific entry by the full filename of the entry inside the tar.
        """
        if isinstance(key, slice):
            # Return a new reader with a sliced samples tensor
            raise NotImplementedError("Slicing is not yet implemented")

        assert isinstance(key, str), "Invalid argument type for __getitem__"

        if self.key_is_full_entryname:
            # Split "base.ext" into the sample key and the part extension
            m = split_name_re.match(key)
            if not m:
                raise ValueError(f"Invalid file name: {key}")
            sample_key, sample_ext = m.groups()
            entry_match_fn = lambda fname: key == fname

            if self.db_has_sample_parts:
                # Directly fetch the sample part (byte offset and size) from the database
                raw_sample_pointer = self.sqlite_reader.get_sample_part(sample_key, sample_ext)
                raw_data, source_info = self._get_part_by_raw_sample_pointer(
                    raw_sample_pointer, key
                )
                return raw_data, source_info
            # Otherwise (older DB without sample_parts): fall through and scan
            # the whole sample, matching only the requested entry.
        else:
            sample_key = key
            sample_ext = None
            entry_match_fn = None

        sample_pointer = self._get_itar_sample_pointer(sample_key)
        sample = self._get_item_by_sample_pointer(
            sample_pointer, key, entry_match_fn=entry_match_fn
        )
        assert sample is not None, f"Sample not found: {sample_key}"

        if self.key_is_full_entryname:
            assert isinstance(sample_ext, str)
            assert len(sample["__sources__"]) == 1
            # Return the bytes directly
            return sample[sample_ext], sample["__sources__"][0]
        else:
            return sample  # Return the FilteredSample

    def __len__(self) -> int:
        """Return the total number of samples in the database."""
        return self.sqlite_reader.get_sample_count()

    def __str__(self) -> str:
        """Return a descriptive string of this reader."""
        return (
            f"SqliteITarEntryReader("
            f"len={len(self)}, base_path={self.base_path}, "
            f"len(shards)={len(self.tar_filenames)}, "
            f"shards=[{self.tar_filenames[0] if self.tar_filenames else 'N/A'}, ...])"
        )

    def close(self):
        """Close the SQLite reader and any open ITarFiles."""
        # Close the SQLite reader
        # (hasattr guards against a partially-constructed instance)
        if hasattr(self, "sqlite_reader") and self.sqlite_reader is not None:
            self.sqlite_reader.close()

        # Close any open ITarFiles (using parent class implementation)
        for tar_file_id in list(self.itar_files_cache.keys()):
            tar_file = self.itar_files_cache[tar_file_id]
            if (
                tar_file is not None
                and hasattr(tar_file, "fileobj")
                and tar_file.fileobj is not None
            ):
                tar_file.fileobj.close()
            if tar_file is not None and hasattr(tar_file, "close"):
                tar_file.close()
            del self.itar_files_cache[tar_file_id]
Megatron-Energon/src/megatron/energon/flavors/webdataset/joined_webdataset.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
abc
import
ABC
from
typing
import
(
Any
,
Callable
,
Dict
,
Generic
,
List
,
Mapping
,
Optional
,
Sequence
,
Type
,
TypeVar
,
Union
,
)
from
megatron.energon.cache
import
FileStore
from
megatron.energon.epathlib
import
EPath
from
megatron.energon.flavors.base_dataset
import
(
BaseCoreDatasetFactory
,
Sample
,
SavableDataset
,
legacy_handler
,
set_sample_restore_key
,
)
from
megatron.energon.flavors.webdataset.base_webdataset
import
BaseWebdatasetFactory
from
megatron.energon.flavors.webdataset.error_handler
import
ErrorHandler
from
megatron.energon.flavors.webdataset.indexing
import
JoinIndexReader
from
megatron.energon.flavors.webdataset.itar_reader
import
JoinIndexFileITarReader
from
megatron.energon.flavors.webdataset.sample_loader
import
(
RawSampleData
,
WebdatasetSampleLoaderDataset
,
)
from
megatron.energon.flavors.webdataset.sharder
import
Sharder
from
megatron.energon.flavors.webdataset.structs
import
DatasetSubset
,
ShardInfo
,
reraise_exception
from
megatron.energon.source_info
import
SourceInfo
from
megatron.energon.worker
import
WorkerConfig
from
megatron.energon.wrappers.map_dataset
import
MapDataset
T_sample
=
TypeVar
(
"T_sample"
,
covariant
=
True
)
class JoinedWebdatasetFactory(
    BaseCoreDatasetFactory[T_sample],
    Sharder,
    ErrorHandler[T_sample],
    Generic[T_sample],
    ABC,
):
    """
    Base class for all webdataset loaders. Applies proper sharding across workers. Can join multiple datasets.
    """

    # If true, apply shuffling and loop the dataset.
    training: bool
    # Worker/rank configuration used for sharding and worker slicing.
    worker_config: WorkerConfig
    # Number of epochs to shuffle over (see __init__ docstring); None disables.
    shuffle_over_epochs: Optional[int] = 1
    # Number of parallel opened shards per worker; resolved in build().
    parallel_shard_iters: Optional[int]
    # Maximum number of sequentially iterated samples per sequence.
    max_samples_per_sequence: Optional[int]
    # Optional subset specification applied during worker slicing.
    subset: Optional[DatasetSubset]
    # Path to the precomputed join index file.
    join_index: EPath
    # Exception handler: (exception, key, source infos) -> None.
    handler: Callable[[Exception, Optional[str], Optional[list[SourceInfo]]], None]

    # Per-sample-row tuples of ShardInfos, one entry per inner dataset.
    shards: List[Sequence[ShardInfo]]
    part_datasets: SavableDataset[T_sample]

    # The joined inner dataset factories (always stored as a list).
    inner_datasets: List[BaseWebdatasetFactory]
    # Keys of the inner datasets when a Mapping was passed; else None.
    inner_dataset_keys: Optional[List[str]]
    # Joiner normalized to accept the loaded samples as positional args.
    _sample_joiner: Callable[..., T_sample]

    def __init__(
        self,
        inner_datasets: Union[
            Sequence[BaseWebdatasetFactory], Mapping[str, BaseWebdatasetFactory]
        ],
        *,
        training: bool,
        worker_config: WorkerConfig,
        shuffle_over_epochs: Optional[int] = 1,
        parallel_shard_iters: Optional[int] = None,
        max_samples_per_sequence: Optional[int] = None,
        subset: Optional[DatasetSubset] = None,
        join_index: EPath,
        joiner: Union[Type[T_sample], Callable[..., T_sample]],
        handler: Callable[
            [Exception, Optional[str], Optional[list[SourceInfo]]], None
        ] = reraise_exception,
    ):
        """
        Constructs the loader for a joined webdataset. The samples from the inner datasets are joined into a single
        sample using the joiner function.

        Args:
            inner_datasets: The inner datasets. Must be loaded internally with `_is_composed=True`.
                Either a list (\\*args for joiner) or a dict (\\*\\*kwargs for joiner) of datasets,
                where the samples will be passed to the joiner function as \\*args or \\*\\*kwargs.
            training: If true, apply shuffling and loop the dataset.
            worker_config: Configuration for the workers.
            shuffle_over_epochs: Only effective if training=True.
                How many epochs to shuffle over if training.
                If = 1, every sample is seen exactly once per epoch.
                If > 1, samples (or rather shard slices) are shuffled within this number of epochs
                (i.e. randomly selected without replacement).
                If -1, the shards are effectively shuffle over infinite epochs (i.e. shard slices
                are drawn with replacement).
            parallel_shard_iters: Number of parallel opened shards per worker, shuffling between.
            max_samples_per_sequence: Maximum number of samples per sequence (=how many samples
                will be sequentially iterated).
            subset: If specified, the inner dataset(s) will be subsetted.
            join_index: Path to the join index file. Only required for join_method="left".
            joiner: Type of the joined samples or a method for joining the samples.
            handler: Exception handler. Args: (exception, key).
        """
        self.__sample_type__ = joiner
        assert all(not hasattr(d, "dataset") for d in inner_datasets), (
            "Inner dataset was not instantiated with _is_composed=True"
        )
        if isinstance(joiner, type) and issubclass(joiner, Sample):
            # A Sample subclass was given: use its from_joined constructor
            joiner = joiner.from_joined
        else:
            assert callable(joiner), f"Joiner {joiner} must be a callable or a Sample subclass"
        if isinstance(inner_datasets, Mapping):
            inner_keys = list(inner_datasets.keys())
            self.inner_dataset_keys = inner_keys
            # Wrap the joiner to pass the samples as kwargs
            self._sample_joiner = lambda *samples: joiner(**dict(zip(inner_keys, samples)))
            inner_datasets = list(inner_datasets.values())
        else:
            assert isinstance(inner_datasets, Sequence)
            self._sample_joiner = joiner
            self.inner_dataset_keys = None
        self.join_index = join_index
        self.inner_datasets = inner_datasets
        # Transpose: one tuple of ShardInfos (one per inner dataset) per shard
        self.shards = list(zip(*(dataset.shards for dataset in self.inner_datasets)))
        self.training = training
        self.worker_config = worker_config
        self.shuffle_over_epochs = shuffle_over_epochs
        self.parallel_shard_iters = parallel_shard_iters
        self.max_samples_per_sequence = max_samples_per_sequence
        self.subset = subset
        self.handler = legacy_handler(handler)

    def __len__(self) -> int:
        # The first (primary) inner dataset determines the joined length
        return sum(shard.count for shard in self.inner_datasets[0].shards)

    def build(self, worker_rotation_offset: int = 0) -> SavableDataset[T_sample]:
        """Build the savable dataset pipeline: slice the join index across
        workers, create one index-backed tar reader per inner dataset, and
        wrap the raw sample loader with the sample-joining map stage."""
        if self.parallel_shard_iters is None:
            if self.training:
                # 16 seems to be a good choice since we don't want too many file handles open
                parallel_shard_iters = 16
            else:
                parallel_shard_iters = 1
        else:
            parallel_shard_iters = self.parallel_shard_iters

        # Get join index, get size, distribute samples
        # Get samples for each worker on current rank
        assert self.join_index.is_file(), (
            f"Join index {self.join_index} does not exist, did you prepare the metadataset? "
            "If you already prepared the metadataset, the join index might be outdated due to "
            "modifications to the inner datasets. In this case, you need to re-prepare the metadataset."
        )
        with JoinIndexReader(self.join_index) as jir:
            total_samples = len(jir)

        workers_sample_slice_offsets = self.slice_workers(
            total_samples,
            worker_config=self.worker_config,
            max_samples_per_sequence=self.max_samples_per_sequence,
            rotation_offset=worker_rotation_offset,
            subset=self.subset,
        )

        # Debug output: print each worker's sample slice boundaries to stdout
        for worker_idx, sample_slice_offsets in enumerate(workers_sample_slice_offsets):
            start_idx = sample_slice_offsets[0]
            end_idx = sample_slice_offsets[-1]
            if len(sample_slice_offsets) > 6:
                # Abbreviate long offset lists: first 3 ... last 3
                offset_str = f"{', '.join(str(o) for o in sample_slice_offsets[:3])}...<{len(sample_slice_offsets) - 6}>{', '.join(str(o) for o in sample_slice_offsets[-3:])}"
            else:
                offset_str = ", ".join(str(o) for o in sample_slice_offsets)
            print(
                f"rank={self.worker_config.rank}, worker={worker_idx}: sample_range=[{start_idx}, {end_idx}) in {len(sample_slice_offsets) - 1} slices, "
                f"sum(count)={end_idx - start_idx}: [{offset_str}]"
            )

        # One reader per inner dataset, each resolving its own index column
        itar_readers = [
            JoinIndexFileITarReader(
                index_file=self.join_index,
                column=col_idx,
                tar_filenames=indexed_dataset.split_part_files,
                base_path=indexed_dataset.path,
                part_filter=indexed_dataset.part_filter,
                itar_cache_size=parallel_shard_iters,
            )
            for col_idx, indexed_dataset in enumerate(self.inner_datasets)
        ]
        dataset = WebdatasetSampleLoaderDataset(
            join_readers=itar_readers,
            workers_sample_slice_offsets=workers_sample_slice_offsets,
            worker_config=self.worker_config,
            shuffle_over_epochs=self.shuffle_over_epochs if self.training else None,
            parallel_slice_iters=parallel_shard_iters,
        )
        return self._process_samples(dataset)

    def as_file_store(self) -> FileStore:
        raise NotImplementedError("Not supported on joined datasets")

    @property
    def paths(self) -> List[EPath]:
        """Paths of all inner datasets."""
        return [dataset.path for dataset in self.inner_datasets]

    def _process_samples(
        self, dataset: SavableDataset[RawSampleData]
    ) -> SavableDataset[T_sample]:
        """Internally loads the sample."""
        return MapDataset(
            dataset,
            self.load_sample,
            error_handler=self.error_handler,
            stateless_map_fn=True,
            map_fn_config=self.config,
            worker_config=self.worker_config,
        )

    def load_sample(self, samples: RawSampleData) -> T_sample:
        """Load each raw inner sample with its owning dataset's loader, then
        join the results into the final sample type.

        Raises via assert if the primary (first) sample is missing."""
        assert len(samples.data) > 0 and samples.data[0] is not None, "Always need primary sample"
        # First call the loaders of all inner datasets
        loaded_samples = tuple(
            None if sample is None else dataset.load_sample(sample)
            for dataset, sample in zip(self.inner_datasets, samples.data)
        )
        # Then combine the loaded samples into the final type
        return set_sample_restore_key(
            self._sample_joiner(*loaded_samples),
            *samples.__restore_key__,
            src=self,
            fail_otherwise=True,
        )

    def config(self) -> Dict[str, Any]:
        """Return a config dict describing this factory (for map_fn_config)."""
        return dict(
            type=type(self).__qualname__,
            joined_datasets=[dataset.config() for dataset in self.inner_datasets],
            training=self.training,
            shuffle_over_epochs=self.shuffle_over_epochs,
            parallel_shard_iters=self.parallel_shard_iters,
            max_samples_per_sequence=self.max_samples_per_sequence,
            subset=self.subset.config() if self.subset is not None else None,
        )

    def __str__(self):
        return f"{type(self).__name__}(paths={self.paths})"
Megatron-Energon/src/megatron/energon/flavors/webdataset/metadata.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
json
from
enum
import
Enum
from
typing
import
List
,
Set
import
braceexpand
from
megatron.energon.edataclass
import
edataclass
from
megatron.energon.epathlib
import
EPath
from
megatron.energon.eyaml
import
load_yaml
,
load_yaml_json
from
megatron.energon.flavors.webdataset.config
import
MAIN_FOLDER_NAME
from
megatron.energon.flavors.webdataset.empty_dataset_error
import
EmptyDatasetError
from
megatron.energon.flavors.webdataset.structs
import
(
ShardInfo
,
WebdatasetInfo
,
WebdatasetSplits
,
)
from
megatron.energon.typed_converter
import
JsonParser
class EnergonDatasetType(Enum):
    """The kind of dataset detected at a user-provided path."""

    METADATASET = "metadataset"
    WEBDATASET = "webdataset"
    JSONL = "jsonl"
    INVALID = "invalid"
@edataclass
class WebdatasetMeta:
    """Class for getting metadata from a webdataset."""

    # Excluded sample keys, i.e. split exclude entries that contain a "/"
    # (shard-level excludes without a "/" are applied when filtering shards).
    sample_excludes: Set[str]
    # Shard infos (name, path and sample count) for the selected split part,
    # with excluded shards already removed.
    shards: List[ShardInfo]
    # All shard file names of the split part, before applying excludes.
    split_part_files: List[str]
    # All shard file names known to the dataset info (across all split parts).
    info_shard_files: List[str]

    @staticmethod
    def from_config(
        path: EPath,
        *,
        split_part: str,
        split_config: str | None = None,
    ) -> "WebdatasetMeta":
        """
        Loads the metadata for a webdataset, i.e. the shards and sample excludes.

        Args:
            path: Root path of the dataset (containing the meta folder).
            split_part: Which part to load (e.g. 'train', 'val', 'test').
            split_config: Config file to use for shard split definitions. Defaults to
                "split.yaml"; if that default is missing, "split.json" is tried as a
                fallback.

        Returns:
            The loaded metadata for the requested split part.

        Raises:
            EmptyDatasetError: If the split part contains no shards after excludes.
            FileNotFoundError: If an explicitly given split config does not exist.
        """
        if split_config is None:
            split_config = "split.yaml"
        parser = JsonParser(strict=True)
        info_object = get_dataset_info(path)
        info = parser.raw_to_typed(
            info_object,
            WebdatasetInfo,
        )
        try:
            splits = parser.raw_to_typed(
                load_yaml_json(path / MAIN_FOLDER_NAME / split_config),
                WebdatasetSplits,
            )
        except FileNotFoundError:
            if split_config == "split.yaml":
                # Try split.json instead
                splits = parser.raw_to_typed(
                    load_yaml_json(path / MAIN_FOLDER_NAME / "split.json"),
                    WebdatasetSplits,
                )
            else:
                raise
        assert split_part in splits.split_parts, f"Invalid split part: {split_part!r}"
        # Flatten brace patterns; the nested comprehension deliberately reuses the
        # loop variable so each exclude entry is replaced by its expansions.
        split_excludes = {
            excluded
            for excluded in splits.exclude
            for excluded in braceexpand.braceexpand(excluded)
        }
        # Same brace-expansion flattening for the shard name patterns of the split part.
        all_split_part_files = [
            name
            for name in splits.split_parts[split_part]
            for name in braceexpand.braceexpand(name)
        ]
        # Drop shards that are excluded at shard level.
        split_part_files = [
            name for name in all_split_part_files if name not in split_excludes
        ]
        if len(split_part_files) == 0:
            raise EmptyDatasetError(f"No shards found in split part {split_part!r}")
        return WebdatasetMeta(
            # Only excludes with a "/" refer to individual samples within a shard.
            sample_excludes={
                excluded for excluded in split_excludes if "/" in excluded
            },
            shards=[
                ShardInfo(
                    name=name,
                    path=path / name,
                    count=info.shard_counts[name],
                )
                for name in split_part_files
            ],
            split_part_files=all_split_part_files,
            info_shard_files=list(info.shard_counts.keys()),
        )
def get_info_shard_files(path: EPath) -> List[str]:
    """Use this if you don't need the full metadata for split parts, but just the shard files."""
    raw_info = get_dataset_info(path)
    info = JsonParser(strict=True).raw_to_typed(raw_info, WebdatasetInfo)
    return [shard_name for shard_name in info.shard_counts]
def get_dataset_info(path: EPath) -> dict:
    """Given the path to an energon webdataset that contains a .nv-meta folder,
    return the dataset info as a dict.
    """
    meta_folder = path / MAIN_FOLDER_NAME
    info_config = meta_folder / ".info.json"
    # YAML for backwards compatibility
    yaml_info_config = meta_folder / ".info.yaml"
    if info_config.is_file():
        with info_config.open("r") as rf:
            return json.load(rf)
    if yaml_info_config.is_file():
        return load_yaml(yaml_info_config.read_bytes())
    raise ValueError(f"No info config file found at {info_config} or {yaml_info_config}")
def check_dataset_info_present(path: EPath) -> bool:
    """Given the path to an energon webdataset that contains a .nv-meta folder,
    return True if the dataset info is present, False otherwise.
    """
    meta_folder = path / MAIN_FOLDER_NAME
    if (meta_folder / ".info.json").is_file():
        return True
    return (meta_folder / ".info.yaml").is_file()
def get_dataset_type(path: EPath) -> EnergonDatasetType:
    """Get the type of the dataset at the given path.

    Args:
        path: The path to the dataset as specified by the user.

    Returns:
        The type of the dataset.
    """
    if path.is_file():
        # A plain file is classified by its extension.
        suffix_to_type = (
            (".jsonl", EnergonDatasetType.JSONL),
            (".yaml", EnergonDatasetType.METADATASET),
        )
        for suffix, dataset_type in suffix_to_type:
            if path.name.endswith(suffix):
                return dataset_type
        return EnergonDatasetType.INVALID
    # A directory is a webdataset if it carries the prepared dataset info.
    if check_dataset_info_present(path):
        return EnergonDatasetType.WEBDATASET
    return EnergonDatasetType.INVALID
Megatron-Energon/src/megatron/energon/flavors/webdataset/prepare.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
functools
import
json
import
logging
import
random
import
re
import
tarfile
import
uuid
from
dataclasses
import
asdict
from
pathlib
import
Path
from
typing
import
(
Any
,
Callable
,
Container
,
Dict
,
Generator
,
Iterator
,
List
,
Optional
,
Set
,
Tuple
,
TypeVar
,
Union
,
)
import
braceexpand
import
yaml
from
megatron.energon
import
__version__
from
megatron.energon.bracecollapse
import
collapse
from
megatron.energon.edataclass
import
edataclass
from
megatron.energon.epathlib
import
EPath
from
megatron.energon.eyaml
import
load_yaml
from
megatron.energon.flavors.webdataset.aggregator_pool
import
AggregatorPool
,
BaseAggregator
from
megatron.energon.flavors.webdataset.config
import
MAIN_FOLDER_NAME
,
skip_meta_re
,
split_name_re
from
megatron.energon.flavors.webdataset.indexing
import
SqliteIndexWriter
from
megatron.energon.flavors.webdataset.itar
import
TarIndexWriter
from
megatron.energon.flavors.webdataset.structs
import
ShardInfo
,
WebdatasetInfo
,
WebdatasetSplits
from
megatron.energon.typed_converter
import
to_json_object
logger
=
logging
.
getLogger
(
__name__
)
T
=
TypeVar
(
"T"
,
covariant
=
True
)
@edataclass
class IndexAggregatable:
    """
    A base class for all objects that can be returned/yielded by `_preprocess_tar` and
    received by `SqliteIndexWriterAggregator.on_item`.
    """

    # Marker base class only; the concrete payloads are defined by subclasses.
    ...
@edataclass
class IndexSample(IndexAggregatable):
    """One complete sample within a tar shard, as recorded in the sqlite index."""

    # Index of the tar file this sample lives in (see `shard_to_idx`).
    tar_file_id: int
    # The sample's base name, shared by all of its part files.
    sample_key: str
    # Running sample index within the tar file.
    sample_index: int
    # Byte offset of the sample's first tar member header within the tar file.
    byte_offset: int
    # Size in bytes of the whole sample, up to the next sample (or end of archive).
    byte_size: int
@edataclass
class IndexSamplePart(IndexAggregatable):
    """One part (file) of a sample within a tar shard, as recorded in the sqlite index."""

    # Index of the tar file this part lives in (see `shard_to_idx`).
    tar_file_id: int
    # Index of the sample this part belongs to within the tar file.
    sample_index: int
    # The part name, i.e. the file name suffix after the sample key.
    part_name: str
    # Byte offset of the part's content (past the tar member header) within the tar file.
    content_byte_offset: int
    # Size of the part's content in bytes.
    content_byte_size: int
@edataclass
class IndexShardInfo(IndexAggregatable):
    """Per-shard completion notice yielded once a whole tar file was processed."""

    # Shard info (name, path and final sample count) of the processed tar file.
    shard_info: ShardInfo
    # Part names (file suffixes) seen in this shard, capped at `max_parts` entries.
    parts: Set[str]
class SqliteIndexWriterAggregator(
    BaseAggregator[
        Tuple[ShardInfo, Set[str]],
        Tuple[List[ShardInfo], Set[str], bool, List[Tuple[str, int]]],
    ]
):
    """Aggregates the index records produced by `_preprocess_tar` workers into a single
    sqlite index file, while collecting per-shard stats and driving a progress bar."""

    # Target path of the sqlite index file.
    sqlite_path: EPath
    # Total number of shards to be processed (one progress step each).
    total_tasks: int
    # Optional progress bar factory wrapping an iterator.
    progress_fn: Optional[Callable]
    # The sqlite writer; created lazily in `on_start`.
    writer: Optional[SqliteIndexWriter]
    # True once at least one sample was written (i.e. the index changed).
    had_update: bool
    # Completed shard infos, in arrival order.
    shards: List[ShardInfo]
    # Part names found across all shards (growth capped at about 50 entries).
    found_parts: Set[str]
    # Iterator advanced once per finished shard to move the progress bar.
    prog_iter: Iterator

    def __init__(
        self,
        sqlite_path: EPath,
        total_tasks: int,
        progress_fn: Optional[Callable[[Iterator[Any], int], Iterator[T]]] = None,
    ):
        """
        Args:
            sqlite_path: Target path of the sqlite index file.
            total_tasks: Number of shards that will be processed.
            progress_fn: Optional callback wrapping an iterator for progress display.
        """
        self.sqlite_path = sqlite_path
        self.total_tasks = total_tasks
        self.writer = None
        self.had_update = False
        self.shards = []
        self.found_parts = set()
        if progress_fn is not None:
            self.prog_iter = progress_fn(iter(range(self.total_tasks)), self.total_tasks)
        else:
            self.prog_iter = iter(range(self.total_tasks))

    def on_start(self, aggregator_pool: AggregatorPool) -> None:
        # Open the sqlite writer in the aggregator process, not in __init__,
        # so it is not created before the pool actually starts.
        self.writer = SqliteIndexWriter(self.sqlite_path)

    def on_item(
        self,
        item: IndexAggregatable,
        aggregator_pool: AggregatorPool,
    ) -> None:
        """Dispatch one record from a worker to the sqlite writer / shard stats."""
        assert self.writer is not None, "Writer is not initialized."
        if isinstance(item, IndexSample):
            self.writer.append_sample(**asdict(item))
            self.had_update = True
        elif isinstance(item, IndexSamplePart):
            self.writer.append_part(**asdict(item))
        elif isinstance(item, IndexShardInfo):
            # A whole shard finished: advance the progress bar and record its stats.
            next(self.prog_iter)
            shard_info, cur_parts = item.shard_info, item.parts
            # An empty (or failed) shard aborts preparation loudly here.
            assert shard_info.count != 0, f"Shard {shard_info.name} has no samples."
            self.shards.append(shard_info)
            # Cap the collected part names at roughly 50 entries
            # (the whole set is merged, so it may slightly exceed 50).
            if len(self.found_parts) < 50:
                self.found_parts.update(cur_parts)

    def on_finish(self, aggregator_pool: AggregatorPool) -> None:
        assert self.writer is not None, "Writer is not initialized."
        self.writer.close()

    def get_final_result_data(
        self,
    ) -> Tuple[List[ShardInfo], Set[str], bool, List[Tuple[str, int]]]:
        """Return (shards, found_parts, had_update, duplicate sample keys)."""
        assert self.writer is not None, "Writer is not initialized."
        return self.shards, self.found_parts, self.had_update, self.writer.duplicates
class WebdatasetPreparator:
    """Prepares a webdataset: builds tar indexes, the sqlite index, the dataset info
    and the split config."""

    @staticmethod
    def _preprocess_tar(
        path: str,
        shard_to_idx: Dict[str, int],
        parent_path: EPath,
        max_parts: int,
    ) -> Generator[IndexAggregatable, None, None]:
        """Process a single tar file, i.e. read the tarinfos, generate the tar index and return
        stats.

        This method is passed to the `user_produce_data` argument of AggregatorPool.

        Args:
            path: Path to the tar file.
            shard_to_idx: Mapping from shard path to its index
            parent_path: Root path of the dataset.
            max_parts: Maximum number of different parts to return

        Returns:
            A generator of items that will be processed by SqliteIndexWriterAggregator.
            See method `on_item` of SqliteIndexWriterAggregator.
            The items are either:
              - A sample dictionary with information about the offset, key etc.
              - Or a tuple of shard info and a set of found parts for statistics.
        """
        shard_info = ShardInfo(name=path, path=parent_path / path, count=0)
        try:
            # Note: Write to .tmp file first, then remove .tmp extension, to make sure only complete
            # files are used.
            tar: tarfile.TarFile
            with shard_info.path.open("rb") as f:
                with (
                    tarfile.open(fileobj=f, mode="r:*") as tar,
                    TarIndexWriter(shard_info.path) as iw,
                ):
                    count = 0
                    # The parts set is used to collect various file endings that are
                    # available in the dataset. This is used for the interactive prepare wizard.
                    parts = set()
                    last_base_name = None
                    member: tarfile.TarInfo
                    # The IndexSample for the previously started sample; its byte_size can
                    # only be computed once the next sample (or the end) is reached.
                    next_index_sample = None
                    for member in tar:
                        # Only regular files with a valid, non-meta, splittable name count.
                        if not member.isreg():
                            continue
                        if member.name is None:
                            continue
                        if skip_meta_re.match(member.name):
                            continue
                        name_match = split_name_re.match(member.name)
                        if name_match is None:
                            continue
                        base_name = name_match.group(1)
                        if len(parts) < max_parts:
                            parts.add(name_match.group(2))
                        if last_base_name != base_name:
                            # A new sample starts here: record its tar header offset in the
                            # tar index and flush the previous sample with its now-known size.
                            iw.append(member.offset)
                            if next_index_sample is not None:
                                next_index_sample["byte_size"] = (
                                    member.offset - next_index_sample["byte_offset"]
                                )
                                yield IndexSample(**next_index_sample)
                            next_index_sample = dict(
                                tar_file_id=shard_to_idx[path],
                                sample_key=base_name,
                                sample_index=count,
                                # member.offset is the offset of the tar member header block
                                byte_offset=member.offset,
                            )
                            last_base_name = base_name
                            count += 1
                        # Yield this part of the sample to the aggregator
                        yield IndexSamplePart(
                            tar_file_id=shard_to_idx[path],
                            sample_index=count - 1,
                            part_name=name_match.group(2),
                            # member.offset_data is the offset of the payload (past the header)
                            content_byte_offset=member.offset_data,
                            content_byte_size=member.size,
                        )
                    shard_info.count = count
                    # Final index entry: end of the archive's data area.
                    iw.append(tar.offset)
                    if next_index_sample is not None:
                        # Flush the last sample, sized up to the end of the archive.
                        next_index_sample["byte_size"] = (
                            tar.offset - next_index_sample["byte_offset"]
                        )
                        yield IndexSample(**next_index_sample)
            # Report shard completion (with sample count and found part names).
            yield IndexShardInfo(shard_info=shard_info, parts=parts)
            return
        except BaseException:
            # A broken shard is skipped; report it with count=0 so the aggregator notices.
            logger.exception(f"Shard failed to load: {path!r}. Skipping it.")
            yield IndexShardInfo(shard_info=shard_info, parts=set())
            return

    @staticmethod
    def iter_dataset_content(
        path: Union[str, EPath],
        extract_keys: Container[str] = (),
    ) -> Generator[Dict[str, Any], None, None]:
        """
        Yield example dataset content for a few samples.

        Only parts listed in `extract_keys` are actually read; all other parts are
        reported with a `None` value.

        Args:
            path: Path to the tar file.
            extract_keys: Part names whose content should be extracted.
        """
        path = EPath(path)
        with path.open("rb") as f:
            tar: tarfile.TarFile
            with tarfile.open(fileobj=f, mode="r:*") as tar:
                last_base_name = None
                sample = {}
                member: tarfile.TarInfo
                for member in tar:
                    # Same filtering as in `_preprocess_tar`.
                    if not member.isreg():
                        continue
                    if member.name is None:
                        continue
                    if skip_meta_re.match(member.name):
                        continue
                    name_match = split_name_re.match(member.name)
                    if name_match is None:
                        continue
                    base_name = name_match.group(1)
                    if last_base_name != base_name:
                        # A new sample starts: emit the previous one.
                        if sample:
                            yield sample
                        sample = {}
                        last_base_name = base_name
                    if name_match:
                        if name_match.group(2) in extract_keys:
                            sample[name_match.group(2)] = tar.extractfile(member).read()
                        else:
                            sample[name_match.group(2)] = None
        # Emit the trailing sample, if any.
        if sample:
            yield sample

    @classmethod
    def prepare_dataset(
        cls,
        parent_path: Union[Path, EPath],
        paths: List[str],
        *,
        split_parts_ratio: Optional[List[Tuple[str, float]]] = None,
        split_parts_patterns: Optional[List[Tuple[str, str]]] = None,
        split_config: str = "split.yaml",
        shuffle_seed: Optional[int] = 42,
        progress_fn: Callable[[Iterator[Any], int], Iterator[T]] = (lambda x, y: x),
        workers: int = 32,
        tar_index_only: bool = False,
    ) -> Tuple[Set[str], List[Tuple[str, int]]]:
        """
        Preprocess the shards and write the split config. Preprocessing is done in parallel.
        Counts the number of samples in each shard.

        Args:
            parent_path: Common parent path for the shards
            paths: Paths to the shards
            split_parts_ratio: Names of splits and their ratio (will be normalized)
            split_parts_patterns: Names of splits and their path patterns
            split_config: Filename for the split config (`parent_path / '.nv-meta' / split_config`), may be yaml or json
            shuffle_seed: Seed for shuffling shards before splitting into split_parts. None to
                disable.
            progress_fn: Callback for progress bar
            workers: Number of parallel workers for reading each shard
            tar_index_only: Only create tar-index, then exit

        Returns:
            The set of all parts found in the shards (but at most 50), and the list of
            duplicate sample keys.
        """
        parent_path = EPath(parent_path)
        # Expand brace patterns to concrete shard paths.
        paths = [path for path in paths for path in braceexpand.braceexpand(path)]

        # Construct a mapping from relative shard path to its index
        shard_to_idx = {path: idx for idx, path in enumerate(paths)}

        (parent_path / MAIN_FOLDER_NAME).mkdir(exist_ok=True)

        aggregator = SqliteIndexWriterAggregator(
            parent_path / MAIN_FOLDER_NAME / "index.sqlite",
            total_tasks=len(paths),
            progress_fn=progress_fn,
        )
        process_tar = functools.partial(
            cls._preprocess_tar,
            shard_to_idx=shard_to_idx,
            parent_path=parent_path,
            max_parts=50,
        )
        # Workers read the tar files; the aggregator writes the sqlite index.
        pool = AggregatorPool(
            num_workers=workers,
            user_produce_data=process_tar,
            aggregator=aggregator,
        )
        for path in paths:
            pool.submit_task(path)
        shards, found_parts, had_update, duplicates = pool.process()

        if had_update:
            # The index changed, so any cached state keyed on the old UUID is invalid.
            logger.info("Regenerating dataset UUID...")
            with (parent_path / MAIN_FOLDER_NAME / "index.uuid").open("w") as f:
                f.write(str(uuid.uuid4()))

        json_info_config = parent_path / MAIN_FOLDER_NAME / ".info.json"
        yaml_info_config = parent_path / MAIN_FOLDER_NAME / ".info.yaml"

        if tar_index_only:
            if yaml_info_config.is_file() and not json_info_config.is_file():
                # Convert legacy .info.yaml to .info.json
                with json_info_config.open("w") as f:
                    json.dump(load_yaml(yaml_info_config.read_bytes()), f, indent=2)
            return found_parts, duplicates

        assert len(shards) == len(shard_to_idx), (
            f"Lengths of shards and shard_to_idx do not match: {len(shards)} != {len(shard_to_idx)}"
        )

        # Sort the shards according to the order in the input list
        shards.sort(key=lambda shard: shard_to_idx[shard.name])

        # Save info
        assert [shard.name for shard in shards] == list(shard_to_idx.keys()), (
            "Shards are not in the same order as in the input list."
        )
        info = WebdatasetInfo(
            energon_version=__version__,
            shard_counts={shard.name: shard.count for shard in shards},
        )
        print(f"Saving info to {json_info_config}")
        with json_info_config.open("w") as wf:
            json.dump(to_json_object(info), wf, indent=2)
        if yaml_info_config.is_file():
            # If a .info.yaml existed previously, let's also update it
            # to keep them in sync
            with yaml_info_config.open("w") as wf:
                yaml.dump(to_json_object(info), wf)

        if split_parts_ratio is not None:
            # Normalize ratio
            total_ratio = sum(split_ratio for _, split_ratio in split_parts_ratio)
            split_parts_ratio = [
                (split_part, split_ratio / total_ratio)
                for split_part, split_ratio in split_parts_ratio
            ]
            # Sample from shards based on the split ratio from split parts
            split_shards = {}
            if shuffle_seed is not None:
                random.Random(shuffle_seed).shuffle(shards)
            split_total = 0
            split_offset = 0
            for split_part, split_ratio in split_parts_ratio:
                # Cumulative ratio determines each split's end index in the shard list.
                split_total += split_ratio
                split_end = int(len(shards) * split_total)
                split_shards[split_part] = [
                    shard.name for shard in shards[split_offset:split_end]
                ]
                split_offset = split_end
        else:
            assert split_parts_patterns is not None, (
                "Require either split_parts_ratio or split_parts_patterns"
            )
            # Sample from shards based on the split patterns from split parts
            split_shards = {}
            for split_part, split_pattern in split_parts_patterns:
                patterns = [
                    re.compile(pattern)
                    for pattern in braceexpand.braceexpand(split_pattern)
                ]
                split_shards[split_part] = [
                    shard.name
                    for shard in shards
                    if any(pattern.match(shard.name) for pattern in patterns)
                ]

        # Optimize the split parts by trying to bracecollapse the shard names
        print("Collapsing split parts... ", flush=True, end="")
        for split_part in split_shards:
            split_shards[split_part] = collapse(split_shards[split_part], keep_order=True)
        print("Done", flush=True)

        # Save split config
        splits_config = WebdatasetSplits(split_parts=split_shards)
        with (parent_path / MAIN_FOLDER_NAME / split_config).open("w") as wf:
            if split_config.endswith(".yaml"):
                yaml.dump(to_json_object(splits_config), wf, sort_keys=False)
            elif split_config.endswith(".json"):
                json.dump(to_json_object(splits_config), wf, indent=2)
            else:
                raise ValueError(f"Invalid split config extension: {split_config}")
        return found_parts, duplicates
Megatron-Energon/src/megatron/energon/flavors/webdataset/sample_decoder.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
typing
import
Any
,
Callable
,
Literal
import
webdataset
from
megatron.energon.av
import
AVWebdatasetDecoder
from
megatron.energon.cache
import
FileStoreDecoder
from
megatron.energon.flavors.webdataset.config
import
split_name_re
try
:
import
filetype
except
ImportError
:
filetype
=
None
# Supported image decoding targets: numpy ("l8", "rgb8", ... / "l", "rgb", ...),
# torch tensors ("torch*") or PIL images ("pil*"), in l/rgb/rgba color modes.
ImageDecoderType = Literal[
    "l8",
    "rgb8",
    "rgba8",
    "l",
    "rgb",
    "rgba",
    "torchl8",
    "torchrgb8",
    "torchrgba8",
    "torchl",
    "torchrgb",
    "torch",
    "torchrgba",
    "pill",
    "pil",
    "pilrgb",
    "pilrgba",
]

# Supported audio/video decoding modes (see `SampleDecoder` for their semantics).
AVDecoderType = Literal["torch", "AVDecoder", "pyav"]
class GuessingHandlerWrapper:
    """A wrapper that guesses the extension of the file using the `filetype` package."""

    def __init__(self, handler: Callable[[str, bytes], Any]):
        """
        Wraps a handler to guess the extension of the file using the `filetype` package.

        Args:
            handler: The handler to wrap.
        """
        self.handler = handler
        if filetype is None:
            raise ImportError(
                "filetype is not installed. Install it with `pip install filetype`."
            )

    def __call__(self, key: str, data: bytes) -> Any:
        """The handler that guesses the extension of the file using the `filetype` package, then calls the delegate handler."""
        guessed = filetype.guess(data)
        # Only override the key if the content type could actually be determined.
        effective_key = key if guessed is None else guessed.extension
        return self.handler(effective_key, data)

    @staticmethod
    def wrap(
        active: bool, handlers: list[Callable[[str, bytes], Any]]
    ) -> list[Callable[[str, bytes], Any]]:
        """
        Wraps a list of handlers to guess the extension of the file using the `filetype` package.

        Args:
            active: Whether to wrap the handlers.
            handlers: The handlers to wrap.

        Returns:
            The list of wrapped handlers.
        """
        if active:
            return [GuessingHandlerWrapper(h) for h in handlers]
        return handlers
class SampleDecoder(FileStoreDecoder):
    """The default decoder for webdataset samples."""

    def __init__(
        self,
        *,
        image_decode: ImageDecoderType = "torchrgb",
        av_decode: AVDecoderType = "AVDecoder",
        video_decode_audio: bool = False,
        guess_content: bool = False,
    ):
        """
        Args:
            image_decode: This defines the decoding results.
            av_decode: If "AVDecoder", returns an AVDecoder instance for flexible decoding. If "torch",
                returns decoded VideoData.
            video_decode_audio: Whether to decode audio from video files.
            guess_content: Whether to guess the contents of the file using the `filetype` package.
        """
        # Keep the constructor arguments around for `config()`.
        self._config = {
            "image_decode": image_decode,
            "av_decode": av_decode,
            "video_decode_audio": video_decode_audio,
            "guess_content": guess_content,
        }
        handlers = [
            webdataset.autodecode.imagehandler(image_decode),
            AVWebdatasetDecoder(
                video_decode_audio=video_decode_audio,
                av_decode=av_decode,
            ),
        ]
        self._decoder = webdataset.autodecode.Decoder(
            GuessingHandlerWrapper.wrap(guess_content, handlers),
        )

    def decode(self, fname: str, raw: bytes) -> Any:
        """Decode a single raw file by wrapping it into a one-entry sample dict."""
        name_match = split_name_re.match(fname)
        if name_match is None:
            # If the file name does not match the expected format, use the file name as is.
            # The handlers may still handle the file correctly (e.g. if guessing by content).
            cur_base_name = fname
            ext = "."
        else:
            cur_base_name, ext = name_match.groups()
        decoded = self(
            {
                "__key__": cur_base_name,
                ext: raw,
            }
        )
        return decoded[ext]

    def __call__(self, sample: dict) -> dict:
        """Apply the underlying webdataset decoder to a full sample dict."""
        return self._decoder(sample)

    def config(self) -> dict:
        """Return the constructor arguments used to build this decoder."""
        return self._config
# Shared module-level decoder instance using the default settings
# (image_decode="torchrgb", av_decode="AVDecoder", no audio from video, no guessing).
DEFAULT_DECODER = SampleDecoder()
Megatron-Energon/src/megatron/energon/flavors/webdataset/sample_loader.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
typing
import
Any
,
Dict
,
Generator
,
Iterator
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
import
torch
from
megatron.energon.edataclass
import
edataclass
from
megatron.energon.flavors.base_dataset
import
FlexState
,
SavableDataset
from
megatron.energon.flavors.webdataset.itar_reader
import
ITarReader
from
megatron.energon.flavors.webdataset.structs
import
FilteredSample
from
megatron.energon.rng
import
WorkerRng
from
megatron.energon.worker
import
WorkerConfig
@edataclass
class RawSampleData:
    """The raw data of one (possibly joined) sample: one optional entry per joined
    dataset, together with the key needed to restore the sample later."""

    #: Index of the sample. This is also the restore key
    __restore_key__: Tuple[str, int]

    #: The sample data; one entry per joined dataset, None where a dataset has no match
    data: Tuple[Optional[FilteredSample], ...]
@edataclass
class SliceState:
    """Represents the iteration state of a single slice into the index."""

    #: The slice index of this slice state
    index: int

    #: The actual state: the global sample offset
    # (`slice_offsets[index] <= current < slice_offsets[index + 1]`)
    current: int
class WebdatasetSampleLoaderDataset(SavableDataset[RawSampleData]):
    """Internal class for loading samples from webdataset slices"""

    #: The readers for each joined dataset
    join_readers: Sequence[ITarReader]
    #: The offsets of the slice slices to iterate over for the current worker
    slice_offsets: Optional[Sequence[int]]

    # If = 1, every sample is seen exactly once per epoch. If > 1, samples
    # (or rather slice slices) are shuffled within this number of epochs (i.e. randomly
    # selected without replacement). If None, the slices are effectively shuffle over
    # infinite epochs (i.e. slice slices are drawn with replacement).
    shuffle_over_epochs: Optional[int]
    # Number of parallel iterators to be opened simultaneously (and random sample between them)
    parallel_slice_iters: int

    # Worker's random generator
    _worker_rng: WorkerRng
    #: The RNG state to be used for regenerating the pending slices
    _pending_slices_rng_state: Optional[FlexState]
    #: The number of slices that have already been opened / processed and thus been removed from
    # the pending slices.
    _pending_slices_offset: Optional[int]
    #: Pending slices are the slices which have not yet been opened, but should be processed
    # in the current "epoch". If None, regenerate from the seed and offset.
    _pending_slice_indexes: Optional[List[int]]
    #: The active slices are the currently opened slices. May contain `None`, if there are fewer
    # slices available (i.e. pending_slices empty) than parallel slice iterators requested.
    _active_slice_state: List[Optional[SliceState]]
    #: The total number of samples retrieved, it's just a monotonically increasing counter
    _sample_count: int
    #: Number of epochs this dataset has been iterated over
    _epoch_count: int
    #: The number of samples retrieved in current epoch
    _epoch_sample_count: int

    # Fields persisted by save_state/restore_state (handled by SavableDataset).
    _savable_fields = (
        "_worker_rng",
        "_pending_slices_offset",
        "_pending_slice_indexes",
        "_active_slice_state",
        "_sample_count",
        "_epoch_count",
        "_epoch_sample_count",
    )

    def __init__(
        self,
        join_readers: Sequence[ITarReader],
        workers_sample_slice_offsets: Sequence[Sequence[int]],
        *,
        worker_config: WorkerConfig,
        shuffle_over_epochs: Optional[int] = None,
        parallel_slice_iters: int = 1,
    ):
        """
        The webdataset loader. Iterates over the slice infos and yields the samples.

        Args:
            join_readers: A sequence of the joined readers (or just a single reader) to iterate
                over.
            workers_sample_slice_offsets: The offsets of the slice slices to iterate over, for
                each worker.
            worker_config: The worker configuration.
            shuffle_over_epochs: If None, disable shuffling.
                If = 1, every sample is seen exactly once per epoch.
                If > 1, samples (or rather slice slices) are shuffled within this number of epochs
                (i.e. randomly selected without replacement).
                If -1, the slices are effectively shuffle over infinite epochs (i.e. slice slices
                are drawn with replacement).
            parallel_slice_iters: If > 1, samples are randomly drawn from parallel slice iterators.
                This will not impact performance, but increase randomness. If = 1, the slices are
                iterated in order.
        """
        super().__init__(worker_config=worker_config)
        self.join_readers = join_readers
        self.shuffle_over_epochs = shuffle_over_epochs
        self.parallel_slice_iters = parallel_slice_iters
        # Store the slices for all workers
        # The slices for the current worker, will have to be extracted from this list later
        self.workers_slice_offsets = workers_sample_slice_offsets
        self.slice_offsets = None
        self.reset_state_own()
        assert shuffle_over_epochs is None or shuffle_over_epochs == -1 or shuffle_over_epochs >= 1
        assert self.parallel_slice_iters >= 1

    def reset_state_own(self) -> None:
        """Resets this dataset's own iteration state: RNG, pending/active slices
        and sample/epoch counters."""
        self._worker_rng = WorkerRng(self.worker_config)
        self._pending_slice_indexes = None
        self._pending_slices_offset = None
        self._pending_slices_rng_state = None
        self._active_slice_state = [None] * self.parallel_slice_iters
        self._sample_count = 0
        self._epoch_count = 0
        self._epoch_sample_count = 0

    def ensure_slice_offsets(self) -> None:
        """Lazily selects the current worker's slice offsets from the per-worker
        list. Must be called from within a worker."""
        self.worker_config.assert_worker()
        if self.slice_offsets is None:
            self.slice_offsets = self.workers_slice_offsets[self.worker_config.rank_worker_id()]

    def _get_sample(self, index: int) -> RawSampleData:
        """Reads the sample at the given global index from all joined readers and
        wraps it with its restore key."""
        return RawSampleData(
            __restore_key__=("Webdataset", index),
            data=tuple(reader[index] for reader in self.join_readers),
        )

    def _slices_once(self) -> List[int]:
        """Yields the indexes to slice offsets once. Possibly shuffles the list."""
        assert self.slice_offsets is not None
        num_slices = len(self.slice_offsets) - 1
        slices_offset = self._pending_slices_offset
        if self.shuffle_over_epochs is None:
            # No shuffling
            res_list = list(range(num_slices))
            if slices_offset is None:
                slices_offset = 0
        else:
            # Restore state or start new (and save)
            if slices_offset is None:
                # Start new state. First, save the state to restore the same order.
                self._pending_slices_rng_state = self._worker_rng.save_state()
                rng = self._worker_rng
                slices_offset = 0
            else:
                # Restore the state. Create a dedicated rng for this, as the main rng is in the
                # state for iterating from the next iterator.
                assert self._pending_slices_rng_state is not None
                rng = WorkerRng(self.worker_config)
                rng.restore_state(self._pending_slices_rng_state)
            if self.shuffle_over_epochs == -1:
                # Shuffle with replacement (i.e. infinite epochs), effectively return as many
                # slices as are required for parallel slice iterators.
                # Next slices are drawn in the _slices_iter.
                res_list = [rng.randbelow(num_slices) for _ in range(self.parallel_slice_iters)]
            elif self.shuffle_over_epochs >= 1:
                # Shuffle without replacement (potentially over multiple epochs)
                res_list = rng.shuffle(list(range(num_slices)) * self.shuffle_over_epochs)
            else:
                raise ValueError(f"Invalid shuffle_over_epochs: {self.shuffle_over_epochs}")
        # Reverse, such that pop returns the first element (in O(1) time)
        res_list.reverse()
        # Skip restored slice list already processed slices
        assert slices_offset is not None
        self._pending_slices_offset = slices_offset
        if slices_offset > 0:
            # Those have already been popped in the current state
            del res_list[-slices_offset:]
        # Set the pending slices
        self._pending_slice_indexes = res_list
        return res_list

    def _slices_iter(self) -> Generator[RawSampleData, None, None]:
        """Iterates the samples in a list of slices, possibly using multiple parallel iterators
        over the slices."""
        assert self.slice_offsets is not None
        active_slice_probs = torch.zeros(self.parallel_slice_iters, dtype=torch.float32)
        active_slices = self._active_slice_state
        pending_slice_indexes = self._pending_slice_indexes

        def slice_at(idx: int) -> SliceState:
            # Fresh slice state positioned at the start of slice `idx`.
            assert self.slice_offsets is not None
            return SliceState(
                index=idx,
                current=self.slice_offsets[idx],
            )

        # Weight the slices by their size to get a more even distribution of samples
        if any(s is not None for s in active_slices) or self._pending_slices_offset is not None:
            # Having an active state, or pending slices. This means we are resuming an epoch.
            if pending_slice_indexes is None:
                # Need to restore the pending slices
                pending_slice_indexes = self._slices_once()
            assert pending_slice_indexes is not None
            # Restore the state
            assert len(active_slices) == self.parallel_slice_iters
            for idx, slice_state in enumerate(active_slices):
                if slice_state is not None:
                    active_slice_probs[idx] = (
                        self.slice_offsets[slice_state.index + 1]
                        - self.slice_offsets[slice_state.index]
                    )
            if self.worker_config.should_log(level=1):
                self.worker_config.worker_log(
                    {
                        "t": "WebdatasetSampleLoaderDataset._slices_iter.resume_epoch",
                        "r": self.worker_config.rank,
                        "w": self.worker_config.rank_worker_id(),
                        "pending_slice_indexes": pending_slice_indexes,
                        "active_slices": [
                            (
                                None
                                if state is None
                                else {
                                    "index": state.index,
                                    "current": state.current,
                                }
                            )
                            for state in active_slices
                        ],
                        "count": self._sample_count,
                        "epoch": self._epoch_count,
                        "epoch_count": self._epoch_sample_count,
                        "probs": active_slice_probs.tolist(),
                    }
                )
        else:
            # Start a new epoch
            assert pending_slice_indexes is None
            pending_slice_indexes = self._slices_once()
            if self.worker_config.should_log(level=1):
                self.worker_config.worker_log(
                    {
                        "t": "WebdatasetSampleLoaderDataset._slices_iter.next_epoch",
                        "r": self.worker_config.rank,
                        "w": self.worker_config.rank_worker_id(),
                        "pending_slice_indexes": pending_slice_indexes,
                        "count": self._sample_count,
                        "epoch": self._epoch_count,
                        "epoch_count": self._epoch_sample_count,
                        "probs": active_slice_probs.tolist(),
                        "shuffle_over_epochs": self.shuffle_over_epochs,
                    }
                )
            assert self._pending_slices_offset is not None
            # List of slice iterators, always of length `parallel_slice_iters`. May contain
            # `None`.
            active_slices.clear()
            # Fill up the slice iterators
            while (
                len(pending_slice_indexes) > 0
                and len(active_slices) < self.parallel_slice_iters
            ):
                slice_index = pending_slice_indexes.pop()
                self._pending_slices_offset += 1
                slice_state = slice_at(slice_index)
                active_slice_probs[len(active_slices)] = (
                    self.slice_offsets[slice_state.index + 1]
                    - self.slice_offsets[slice_state.index]
                )
                active_slices.append(slice_state)
            # Fill up the slice iterators with None
            for _ in range(len(active_slices), self.parallel_slice_iters):
                active_slices.append(None)
        # Iterate over the slice iterators while there is an iterator left
        while torch.count_nonzero(active_slice_probs).item() > 0:
            if self.shuffle_over_epochs is None:
                # No shuffling, deterministic order, always the same
                assert self.parallel_slice_iters == 1
                slice_idx = 0
            else:
                # Take a random slice iterator
                slice_idx = self._worker_rng.choice_idx(active_slice_probs)
            slice_state = active_slices[slice_idx]
            assert slice_state is not None
            sample = self._get_sample(slice_state.current)
            slice_state.current += 1
            self._sample_count += 1
            self._epoch_sample_count += 1
            if slice_state.current >= self.slice_offsets[slice_state.index + 1]:
                # Iterator exhausted -> take next / remove from list
                if len(pending_slice_indexes) > 0 or self.shuffle_over_epochs == -1:
                    if len(pending_slice_indexes) > 0:
                        # Take the next slice (without replacement)
                        next_idx = pending_slice_indexes.pop()
                        assert self._pending_slices_offset is not None
                        self._pending_slices_offset += 1
                    else:
                        # Randomly select a new slice directly (with replacement)
                        num_slices = len(self.slice_offsets) - 1
                        next_idx = self._worker_rng.randbelow(num_slices)
                    next_slice_state = slice_at(next_idx)
                    active_slice_probs[slice_idx] = (
                        self.slice_offsets[next_slice_state.index + 1]
                        - self.slice_offsets[next_slice_state.index]
                    )
                    active_slices[slice_idx] = next_slice_state
                else:
                    # No more slices to open; disable this iterator slot.
                    active_slice_probs[slice_idx] = 0
                    active_slices[slice_idx] = None
                if self.worker_config.should_log(level=2):
                    self.worker_config.worker_log(
                        {
                            "t": "WebdatasetSampleLoaderDataset._slices_iter.exhausted",
                            "r": self.worker_config.rank,
                            "w": self.worker_config.rank_worker_id(),
                            "remaining": len(pending_slice_indexes),
                            "count": self._sample_count,
                            "epoch": self._epoch_count,
                            "epoch_count": self._epoch_sample_count,
                            "probs": active_slice_probs.tolist(),
                        }
                    )
            if sample.data[0] is not None:
                # Otherwise the sample was skipped.
                if self.worker_config.should_log(level=1):
                    self.worker_config.worker_log(
                        {
                            "t": "WebdatasetSampleLoaderDataset._slices_iter.yield",
                            "r": self.worker_config.rank,
                            "w": self.worker_config.rank_worker_id(),
                            "index": sample.__restore_key__[1],
                            "key": sample.data[0]["__key__"],
                            "shard": sample.data[0]["__shard__"],
                            "count": self._sample_count,
                            "epoch": self._epoch_count,
                            "epoch_count": self._epoch_sample_count,
                        }
                    )
                # Now, yield the sample
                yield sample
            del sample
        if self.worker_config.should_log(level=2):
            self.worker_config.worker_log(
                {
                    "t": "WebdatasetSampleLoaderDataset._slices_iter.all_exhausted",
                    "r": self.worker_config.rank,
                    "w": self.worker_config.rank_worker_id(),
                    "count": self._sample_count,
                    "epoch": self._epoch_count,
                    "epoch_count": self._epoch_sample_count,
                }
            )
        # Epoch has finished, reset states.
        self._epoch_count += 1
        self._epoch_sample_count = 0
        self._pending_slice_indexes = None
        self._pending_slices_offset = None

    def len_worker(self, worker_idx: int | None = None) -> int:
        """Number of samples assigned to the given worker (the current worker if
        `worker_idx` is None; then it must be called from within a worker)."""
        if worker_idx is None:
            self.worker_config.assert_worker()
            worker_idx = self.worker_config.rank_worker_id()
        worker_slice_offsets = self.workers_slice_offsets[worker_idx]
        return worker_slice_offsets[-1] - worker_slice_offsets[0]

    def worker_has_samples(self) -> bool:
        """Whether the current worker has at least one slice (and thus samples)."""
        self.worker_config.assert_worker()
        self.ensure_slice_offsets()
        assert self.slice_offsets is not None
        return len(self.slice_offsets) > 1

    def __iter__(self) -> Iterator[RawSampleData]:
        """Iterates the current worker's samples; empty if no slices are assigned."""
        self.worker_config.assert_worker()
        self.ensure_slice_offsets()
        assert self.slice_offsets is not None
        if self.worker_config.should_log(level=1):
            self.worker_config.worker_log(
                {
                    "t": "WebdatasetSampleLoaderDataset.__iter__",
                    "r": self.worker_config.rank,
                    "w": self.worker_config.rank_worker_id(),
                    "slice_offsets": self.slice_offsets,
                    "parallel_slice_iters": self.parallel_slice_iters,
                    "shuffle_over_epochs": self.shuffle_over_epochs,
                }
            )
        if len(self.slice_offsets) <= 1:
            return
        yield from self._slices_iter()

    def can_restore_sample(self) -> bool:
        """Samples can always be restored by their global index."""
        return True

    def assert_can_restore(self) -> None:
        """No preconditions for restoring; see `can_restore_sample`."""
        pass

    def restore_sample(
        self, restore_key: Tuple[Union[str, int, tuple], ...]
    ) -> RawSampleData:
        """Re-reads the sample identified by its restore key.

        Args:
            restore_key: The key as produced by `_get_sample`: ("Webdataset", index).
                The key is joined in the dataset's typed joining (i.e. load_sample of
                JoinedWebdatasetFactory).

        Returns:
            The raw sample data at that index.
        """
        id, index = restore_key
        assert id == "Webdataset"
        assert isinstance(index, int)
        return self._get_sample(index)

    def config(self) -> Dict[str, Any]:
        """Returns a config dict describing this dataset (for fingerprinting/logging)."""
        return {
            "type": type(self).__qualname__,
            "joins": len(self.join_readers),
            "len": len(self),
            "slice_offsets": [
                list(worker_slice_offsets)
                for worker_slice_offsets in self.workers_slice_offsets
            ],
            "worker_config": self.worker_config.config(),
            "shuffle_over_epochs": self.shuffle_over_epochs,
            "parallel_slice_iters": self.parallel_slice_iters,
        }

    def __str__(self):
        return f"WebdatasetSampleLoaderDataset(join_readers={self.join_readers}, shuffle_over_epochs={self.shuffle_over_epochs}, parallel_slice_iters={self.parallel_slice_iters})"
Megatron-Energon/src/megatron/energon/flavors/webdataset/sharder.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
itertools
import
zip_longest
from
typing
import
Generator
,
Optional
,
Sequence
,
Tuple
,
Union
import
numpy
as
np
from
megatron.energon.flavors.webdataset.structs
import
DatasetSubset
,
ShardInfo
from
megatron.energon.worker
import
WorkerConfig
class Sharder:
    """Splits dataset samples into per-worker sequences of slice offsets.

    All methods are static or class methods; the class only groups the sharding
    logic. Offsets are global sample indexes; a "slice" is a contiguous range of
    samples described by consecutive offsets.
    """

    @staticmethod
    def _split_shard(
        start_offset: int,
        end_offset: int,
        max_samples_per_sequence: Optional[int],
    ) -> Tuple[int, ...]:
        """Splits a shard into multiple slices of max_samples_per_sequence (more or less).

        Returns the starting index of each slice (excluding the end_offset; the
        callers append the end offset themselves)."""
        if (
            max_samples_per_sequence is not None
            and end_offset - start_offset > max_samples_per_sequence * 1.5
        ):
            # Split the shard into slices of max_samples_per_sequence (more or less).
            # The 1.5 factor avoids creating a tiny trailing slice.
            slice_count = max(round((end_offset - start_offset) / max_samples_per_sequence), 1)
            samples_per_sequence = (end_offset - start_offset) / slice_count
            # Only the slice start offsets are emitted here; the end offset is
            # appended by the callers.
            # (renamed loop variable: `slice` shadowed the builtin)
            return tuple(
                start_offset + int(slice_idx * samples_per_sequence)
                for slice_idx in range(slice_count)
            )
        else:
            return (start_offset,)

    @classmethod
    def _split_shards(
        cls,
        shard_cumsums: np.ndarray,
        offsets: Sequence[int],
        *,
        max_samples_per_sequence: Optional[int],
    ) -> Generator[Sequence[int], None, None]:
        """
        Splits the shards into multiple lists based on the offsets. The first offset is the start
        of the first shard emitted, the last offset is the beginning of the last shard emitted.
        (i.e. number of slice sequences emitted is `len(offsets) - 1`).

        Args:
            shard_cumsums: The source shard offsets
            offsets: The offsets to samples to get shards for (must be strictly increasing)
            max_samples_per_sequence: Maximum number of samples per sequence (=how many samples
                will be sequential).

        Returns:
            A list of starting offsets for each slice (including the end offset)
        """
        # Find shard idx for start
        start_index = np.searchsorted(shard_cumsums, offsets[0], side="right") - 1
        for start_offset, end_offset in zip(offsets, offsets[1:]):
            # Find shard idx for end (linear scan; offsets are increasing, so
            # start_index only ever moves forward)
            end_index = start_index
            while (
                end_index + 1 < len(shard_cumsums)
                and end_offset > shard_cumsums[end_index + 1]
            ):
                end_index += 1
            if start_index == end_index:
                # The whole range lies within a single source shard
                yield (
                    *cls._split_shard(
                        start_offset=start_offset,
                        end_offset=end_offset,
                        max_samples_per_sequence=max_samples_per_sequence,
                    ),
                    end_offset,
                )
            else:
                # Middle is the original shards, start and end get an offset/length
                yield (
                    *(
                        cls._split_shard(
                            start_offset=start_offset,
                            end_offset=shard_cumsums[start_index + 1],
                            max_samples_per_sequence=max_samples_per_sequence,
                        )
                        if shard_cumsums[start_index + 1] > start_offset
                        else ()
                    ),
                    *(
                        offset
                        for inner_shard_start, inner_shard_end in zip(
                            shard_cumsums[start_index + 1 : end_index],
                            shard_cumsums[start_index + 2 : end_index + 1],
                        )
                        for offset in cls._split_shard(
                            start_offset=inner_shard_start,
                            end_offset=inner_shard_end,
                            max_samples_per_sequence=max_samples_per_sequence,
                        )
                    ),
                    *cls._split_shard(
                        start_offset=shard_cumsums[end_index],
                        end_offset=end_offset,
                        max_samples_per_sequence=max_samples_per_sequence,
                    ),
                    end_offset,
                )
            start_index = end_index

    @classmethod
    def _split_slices(
        cls,
        offsets: Sequence[int],
        *,
        max_samples_per_sequence: Optional[int],
    ) -> Generator[Sequence[int], None, None]:
        """
        Splits the offsets into approximately `max_samples_per_sequence` sized slices. Each
        sequence of slices includes the end of that sequence.

        Args:
            offsets: The offsets to samples to get shards for (must be strictly increasing)
            max_samples_per_sequence: Maximum number of samples per sequence (=how many samples
                will be sequential).

        Returns:
            A list of offsets for each slice sequence.
        """
        for start, end in zip(offsets[:-1], offsets[1:]):
            yield (
                *cls._split_shard(
                    start_offset=start,
                    end_offset=end,
                    max_samples_per_sequence=max_samples_per_sequence,
                ),
                end,
            )

    @classmethod
    def _generalized_bit_reversal(
        cls, length_or_indices: Union[int, Sequence[int]]
    ) -> Sequence[int]:
        """This function creates a permutation of given length.
        The sequence is created by a recursive divide and interleave algorithm
        to ensure a balanced distribution across ranks.
        It corresponds to a generalized bit reversal permutation, which - for lengths
        of power of two - is the reversed binary representation of the original indices.
        For example for 16 indices, the sequence is:
        [0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15]

        Visual illustration:
        Step|0|1|2|3|4|5|6|7|8|9|A|B|C|D|E|F|
            |-------------------------------|
           0|X| | | | | | | | | | | | | | | |
           1|X| | | | | | | |X| | | | | | | |
           2|X| | | |X| | | |X| | | | | | | |
           3|X| | | |X| | | |X| | | |X| | | |
           4|X| |X| |X| | | |X| | | |X| | | |
           5|X| |X| |X| | | |X| |X| |X| | | |
           6|X| |X| |X| |X| |X| |X| |X| | | |
           7|X| |X| |X| |X| |X| |X| |X| |X| |
           8|X|X|X| |X| |X| |X| |X| |X| |X| |
           9|X|X|X| |X| |X| |X|X|X| |X| |X| |
          10|X|X|X| |X|X|X| |X|X|X| |X| |X| |
          11|X|X|X| |X|X|X| |X|X|X| |X|X|X| |
          12|X|X|X|X|X|X|X| |X|X|X| |X|X|X| |
          13|X|X|X|X|X|X|X| |X|X|X|X|X|X|X| |
          14|X|X|X|X|X|X|X|X|X|X|X|X|X|X|X| |
          15|X|X|X|X|X|X|X|X|X|X|X|X|X|X|X|X|
        """
        if isinstance(length_or_indices, int):
            indices = list(range(length_or_indices))
        else:
            indices = length_or_indices
        if len(indices) <= 2:
            return indices
        # Recursively split in halves and interleave the results
        mid = len(indices) // 2
        left = indices[:mid]
        right = indices[mid:]
        left_result = cls._generalized_bit_reversal(left)
        right_result = cls._generalized_bit_reversal(right)
        # Interleave the results
        zipped = zip_longest(left_result, right_result)
        result = [item for sublist in zipped for item in sublist if item is not None]
        return result

    @classmethod
    def split_samples_to_workers(
        cls,
        start_samples: int,
        end_samples: int,
        worker_config: WorkerConfig,
        *,
        rotation_offset: int = 0,
    ) -> Sequence[int]:
        """Computes the sample split offsets for the workers of the current rank.

        Args:
            start_samples: Global index of the first sample to distribute.
            end_samples: Global index one past the last sample to distribute.
            worker_config: The config for the current rank and workers.
            rotation_offset: Rotation offset for assigning remainder samples
                (typically carried over from previous datasets).

        Returns:
            `num_workers + 1` offsets; worker `i` gets the samples in
            `[result[i], result[i + 1])`.
        """
        # We split the total number of samples into the number of global workers across all ranks.
        # Note that the global number of workers intentionally stays the same if you
        # divide the number of ranks by N, and multiply the number of workers per rank by N.
        # This allows to reproduce the same global batches with a different number of ranks.
        total_samples = end_samples - start_samples
        num_workers = max(1, worker_config.num_workers)
        global_workers = num_workers * worker_config.world_size
        min_samples_per_worker = int(total_samples / global_workers)
        num_workers_with_more_samples = total_samples % global_workers
        # We are going to compute the samples assigned to each worker on the current rank.
        # This is done in multiple steps.
        # Some of these steps could be collapsed into one, but we keep them separate for clarity:
        # 1. Compute the number of samples per global worker (rotated by rotation_offset,
        #    typically given by previous datasets).
        # 2. Permute the number of samples per global worker by a generalized bit reversal
        #    sequence
        # 3. Given the sample counts, compute the start and end indices for each global worker
        # 4. Extract the local worker sample assignments for the current rank.
        # 5. Split the shards based on the start and end indices.

        # 1. Let's compute it globally for all workers first
        num_samples_per_global_worker = []
        for global_worker_idx in range(global_workers):
            if (
                global_worker_idx - rotation_offset + global_workers
            ) % global_workers < num_workers_with_more_samples:
                # This worker gets one more sample
                num_samples_per_global_worker.append(min_samples_per_worker + 1)
            else:
                # This worker gets the minimum number of samples
                num_samples_per_global_worker.append(min_samples_per_worker)

        # 2. Permute the number of samples per global worker
        worker_bitrev_seq = cls._generalized_bit_reversal(global_workers)
        # The worker_bitrev_seq is the order in which any remainder samples shall
        # be assigned to workers.
        # That means, the x-axis (array index) is the remainder sample index
        # and the y-axis (value) is the global worker index.
        # So we map the y (value) to the old global worker index from the linear sequence.
        new_num_samples_per_global_worker = [-1] * global_workers
        for old_worker_idx, new_worker_idx in enumerate(worker_bitrev_seq):
            new_num_samples_per_global_worker[new_worker_idx] = num_samples_per_global_worker[
                old_worker_idx
            ]
        num_samples_per_global_worker = new_num_samples_per_global_worker

        # 3. Compute the global worker sample start and end indices
        global_worker_sample_split_offsets = [start_samples]
        cur_offset = start_samples
        for global_worker_idx in range(global_workers):
            cur_offset += num_samples_per_global_worker[global_worker_idx]
            global_worker_sample_split_offsets.append(cur_offset)

        # 4. Now we extract the local rank's worker ranges
        local_worker_sample_split_offsets = global_worker_sample_split_offsets[
            worker_config.rank * num_workers : (worker_config.rank + 1) * num_workers + 1
        ]
        assert len(local_worker_sample_split_offsets) == num_workers + 1, (
            "If this fails, there's a bug in the code above."
        )
        return local_worker_sample_split_offsets

    @staticmethod
    def _clean_offsets(offsets: Sequence[int]) -> Sequence[int]:
        """Removes empty offset slices, i.e. duplicates from offsets."""
        return (
            *(int(start) for start, end in zip(offsets, offsets[1:]) if start < end),
            int(offsets[-1]),
        )

    @staticmethod
    def _compute_subset(
        total_samples: int,
        subset: Optional[DatasetSubset] = None,
    ) -> tuple[int, int]:
        """Computes the (start, end) sample range for an optional dataset subset.

        First applies `subset.absolute_range` (absolute sample indexes), then
        `subset.range` (fractions of the remaining range). Returns
        `(0, total_samples)` if no subset is given.
        """
        start_samples = 0
        end_samples = total_samples
        if subset is None:
            return start_samples, end_samples
        if subset.absolute_range is not None:
            start_samples, end_samples = subset.absolute_range
            if end_samples is None:
                end_samples = total_samples
            assert end_samples <= total_samples, (
                f"Subset samples {subset.absolute_range} {end_samples=} > {total_samples=}"
            )
            assert start_samples <= end_samples, (
                f"Subset samples {subset.absolute_range} {start_samples=} > {end_samples=}"
            )
            assert start_samples >= 0, (
                f"Subset samples {subset.absolute_range} {start_samples=} < 0"
            )
        if subset.range is not None:
            previous_total = end_samples - start_samples
            end_samples = start_samples + int(previous_total * subset.range[1])
            start_samples += int(previous_total * subset.range[0])
            assert end_samples <= total_samples, (
                f"Subset ratio {subset.range} {end_samples=} is larger than total samples "
                f"{total_samples}"
            )
            assert start_samples <= end_samples, (
                f"Subset ratio {subset.range} {start_samples=} > {end_samples=}"
            )
            assert start_samples >= 0, f"Subset ratio {subset.range} {start_samples=} < 0"
        return start_samples, end_samples

    @classmethod
    def shard_workers(
        cls,
        shards: Sequence[ShardInfo],
        worker_config: WorkerConfig,
        *,
        max_samples_per_sequence: Optional[int],
        subset: Optional[DatasetSubset] = None,
        rotation_offset: int = 0,
    ) -> Sequence[Sequence[int]]:
        """
        Creates shard slices for each worker of the current rank.

        For that, the number of global samples is split across the number of global workers
        across all ranks. Then each worker gets a slice of the global samples.

        Args:
            shards: The shards to split
            worker_config: The config for the current rank and workers
            max_samples_per_sequence: Maximum number of samples per sequence (=how many samples
                will be sequential).
            subset: If specified, the dataset will be subsetted to the given ratio.
            rotation_offset: The offset to use for the worker rotation.

        Returns:
            The shards for the current rank and all workers
        """
        end_samples = sum(shard.count for shard in shards)
        if subset is not None:
            # NOTE(review): this uses `subset.compute_subset(...)` while `slice_workers`
            # uses `cls._compute_subset(...)` — presumably equivalent; confirm against
            # DatasetSubset.
            start_samples, end_samples = subset.compute_subset(end_samples)
        else:
            start_samples = 0
        local_worker_sample_split_offsets = cls.split_samples_to_workers(
            start_samples,
            end_samples,
            worker_config,
            rotation_offset=rotation_offset,
        )
        shard_cumsums = np.cumsum([0] + [shard.count for shard in shards])
        return tuple(
            # Filter out any empty shards for this worker
            cls._clean_offsets(offsets)
            for offsets in cls._split_shards(
                shard_cumsums,
                local_worker_sample_split_offsets,
                max_samples_per_sequence=max_samples_per_sequence,
            )
        )

    @classmethod
    def slice_workers(
        cls,
        total_samples: int,
        worker_config: WorkerConfig,
        *,
        max_samples_per_sequence: Optional[int],
        subset: Optional[DatasetSubset] = None,
        rotation_offset: int = 0,
    ) -> Sequence[Sequence[int]]:
        """
        Creates shard slices for each worker of the current rank.

        For that, the number of global samples is split across the number of global workers
        across all ranks. Then each worker gets a slice of the global samples.

        Args:
            total_samples: The total number of samples
            worker_config: The config for the current rank and workers
            max_samples_per_sequence: Maximum number of samples per sequence (=how many samples
                will be sequential).
            subset: If specified, the dataset will be subsetted to the given ratio.
            rotation_offset: The offset to use for the worker rotation.

        Returns:
            The shards for the current rank and all workers
        """
        start_samples, end_samples = cls._compute_subset(total_samples, subset)
        local_worker_sample_split_offsets = cls.split_samples_to_workers(
            start_samples,
            end_samples,
            worker_config,
            rotation_offset=rotation_offset,
        )
        # Split the shards
        return tuple(
            # Filter out any empty shards for this worker
            cls._clean_offsets(offsets)
            for offsets in cls._split_slices(
                local_worker_sample_split_offsets,
                max_samples_per_sequence=max_samples_per_sequence,
            )
        )
Megatron-Energon/src/megatron/energon/flavors/webdataset/standard_webdataset.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
typing
import
Generic
,
Type
,
TypeVar
from
megatron.energon.epathlib
import
EPath
from
megatron.energon.flavors.base_dataset
import
Sample
from
megatron.energon.flavors.webdataset.default_decoder_webdataset
import
(
DefaultDecoderWebdatasetFactory
,
)
#: The sample type produced by the factory. Covariant, so a factory producing a
# subtype is usable where a factory of the base sample type is expected.
T_sample = TypeVar("T_sample", covariant=True)
class StandardWebdatasetFactory(DefaultDecoderWebdatasetFactory[T_sample], Generic[T_sample]):
    """
    This dataset sample loader factory uses the sample type e.g. given from a dataset.yaml, and
    applies the default loading logic, which includes decoding images, videos and containers.
    """

    def __init__(
        self,
        path: EPath,
        *,
        sample_type: Type[T_sample],
        **kwargs,
    ):
        """
        Factory for the standard webdataset sample loader.

        Sets `__sample_type__` to the given sample type and forwards everything
        else to `DefaultDecoderWebdatasetFactory`.

        Args:
            path: Path to the dataset (passed to parent)
            sample_type: Type of the sample to be loaded; must be a subclass of
                `megatron.energon.Sample` (asserted).
            auto_decode: If true, use the default webdataset sample decoder.
            image_decode: This defines the decoding results.
            ignore_decoder_errors: If true, ignore errors when decoding.
            subflavors: Subflavors dictionary to set for all loaded samples.
            field_map: Mapping from the webdataset fields to the sample fields.
            sample_loader: Function to load the sample from the webdataset fields. May be a string
                in order to load a function from a module, or a callable directly.
            part_filter: Filter for the parts to load. May be a string in order to load a function
                from a module, or a callable directly.
            split_part: Which part to load (e.g. 'train', 'val', 'test').
            training: If true, apply shuffling and loop the dataset.
            worker_config: Configuration for the workers.
            shuffle_over_epochs: Only effective if training=True.
                How many epochs to shuffle over if training.
                If = 1, every sample is seen exactly once per epoch.
                If > 1, samples (or rather shard slices) are shuffled within this number of epochs
                (i.e. randomly selected without replacement).
                If -1, the shards are effectively shuffle over infinite epochs (i.e. shard slices
                are drawn with replacement).
            parallel_shard_iters: Number of parallel opened shards per worker, shuffling between.
            max_samples_per_sequence: Maximum number of samples per sequence (=how many samples
                will be sequentially iterated).
            subset: If specified, the dataset will be subsetted.
            split_config: Config file to use for shard split definitions.
            handler: Exception handler. Args: (exception, key).
        """
        self.__sample_type__ = sample_type
        assert issubclass(sample_type, Sample), (
            f"sample_type must be subclass of megatron.energon.Sample, got {sample_type.__qualname__}"
        )
        super().__init__(path, **kwargs)
Megatron-Energon/src/megatron/energon/flavors/webdataset/structs.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
dataclasses
import
field
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
TypedDict
from
megatron.energon.edataclass
import
edataclass
from
megatron.energon.epathlib
import
EPath
from
megatron.energon.source_info
import
SourceInfo
@edataclass
class WebdatasetInfo:
    """Info about a webdataset. Format for `.nv-meta/.info.yaml` or `.nv-meta/.info.json`."""

    #: The version of the energon library that was used to prepare the dataset
    energon_version: Optional[str] = None
    # NOTE(review): a non-defaulted field after a defaulted one requires kw_only
    # semantics — presumably provided by @edataclass; confirm.
    #: Maps shard name to number of samples in that shard
    shard_counts: Dict[str, int]
@edataclass
class WebdatasetSplits:
    """Info about the splits of a webdataset. Format for `.nv-meta/split.yaml` or `.nv-meta/split.json`
    (or custom user yaml/json)."""

    #: Maps split part to list of shard names
    split_parts: Dict[str, List[str]]
    #: Set of "<shard name>" or "<shard name>/<sample index>" to exclude
    exclude: List[str] = field(default_factory=list)
@edataclass
class ShardInfo:
    """Info about a single shard as passed through internally. Not exposed to the user."""

    #: Name of the shard file (relative path from the nvinfo dir)
    name: str
    #: The path to the shard file
    path: EPath
    #: The number of samples in this shard
    count: int
class FilteredSample(TypedDict):
    """This is just a definition for the internal loaders. Not exposed to the user."""

    #: The key of the sample within the tar file.
    #: If the tar file contains files 12.jpg and 12.txt,
    #: those two files make one sample with the key "12"
    __key__: str
    #: The base name of the shard file e.g. "shard_000"
    __shard__: str
    #: Globally unique key to restore a sample from disk.
    #: For example `("Webdataset", 123)` would restore the sample at index 123.
    __restore_key__: Tuple[str, int]
    #: The source information for the sample.
    __sources__: tuple[SourceInfo, ...]
@edataclass
class DatasetSubset:
    """A subset of a dataset.

    A range is a tuple of two values, where the first value is the start of the subset and the second value is the end of the subset.
    The sharder uses the (absolute/relative) ranges to compute the subsets:
    * `absolute_range` (unit is samples) is applied first on the (e.g. train/val/test) subset
    * then `range` (where `(0, 1)` would correspond to the whole dataset) is applied as relative ratio on the subset that is left.
    This is the struct used internally for computing the range. The config is loaded via the metadataset_v2.
    """

    #: Relative (start, end) ratio in [0, 1]; applied after absolute_range.
    range: tuple[float, float] | None = None
    #: Absolute (start, end) sample indices; end may be None for "until the end".
    absolute_range: tuple[int, int | None] | None = None

    def compute_subset(
        self,
        total_samples: int,
    ) -> tuple[int, int]:
        """
        Computes the absolute subset of samples from the total number of samples.
        The absolute range is applied first, then the relative range is applied on the subset that is left.
        """
        start_samples = 0
        end_samples = total_samples
        if self.absolute_range is not None:
            start_samples, end_samples = self.absolute_range
            # An open end (None) means: up to the last sample.
            if end_samples is None:
                end_samples = total_samples
            assert end_samples <= total_samples, (
                f"Subset samples {self.absolute_range} {end_samples=} > {total_samples=}"
            )
            assert start_samples <= end_samples, (
                f"Subset samples {self.absolute_range} {start_samples=} > {end_samples=}"
            )
            assert start_samples >= 0, f"Subset samples {self.absolute_range} {start_samples=} < 0"
        if self.range is not None:
            # Apply the relative ratio on what is left after the absolute range.
            previous_total = end_samples - start_samples
            # NOTE: end_samples must be computed before start_samples is shifted,
            # because it is based on the original start_samples.
            end_samples = start_samples + int(previous_total * self.range[1])
            start_samples += int(previous_total * self.range[0])
            assert end_samples <= total_samples, (
                f"Subset ratio {self.range} {end_samples=} is larger than total samples {total_samples}"
            )
            assert start_samples <= end_samples, (
                f"Subset ratio {self.range} {start_samples=} > {end_samples=}"
            )
            assert start_samples >= 0, f"Subset ratio {self.range} {start_samples=} < 0"
        return start_samples, end_samples

    def config(self) -> dict:
        """Return a plain-dict representation of this subset (for config dumping)."""
        return {
            "range": self.range,
            "absolute_range": self.absolute_range,
        }
def reraise_exception(
    exc: Exception, key: Optional[str], sources: Optional[list[SourceInfo]] = None
) -> None:
    """Default error handler: re-raise *exc* wrapped with sample context.

    If *sources* are given, raises a new ``Exception`` naming the sample key and
    each source (dataset path, index, shard name, file names), chained to the
    original via ``from exc``. With only *key*, the wrapper names just the key.
    With neither, the current exception is re-raised unchanged.

    NOTE(review): the bare ``raise`` in the last branch assumes this function is
    called from within an ``except`` block — confirm at call sites.
    """
    if sources:
        raise Exception(
            f"For sample {key!r} from {', '.join(f'{source.dataset_path}[{source.index}]{source.shard_name}{source.file_names!r}' for source in sources)}"
        ) from exc
    elif key:
        raise Exception(f"For sample {key!r}") from exc
    else:
        raise
Megatron-Energon/src/megatron/energon/flavors/webdataset/thread_local_sqlite.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
os
import
random
import
sqlite3
import
threading
import
time
from
typing
import
Any
,
ClassVar
class ThreadLocalStorage:
    """
    A class that allows to store data in a thread-local storage.

    Example Usage:
    ```python
    class MyThreadLocalStorage(ThreadLocalStorage):
        __thread_local__ = ("my_data",)

        # This is shared across threads
        other_data: int

        # This is local per thread
        my_data: int

        def __thread_init__(self):
            # This is called when the data on a thread is initialized, which has
            # not been accessed yet on that thread to set the value of that data.
            self.my_data = 0
    ```
    """

    #: Names of the attributes that are stored per-thread instead of on the instance.
    __thread_local__: ClassVar[tuple[str, ...]]
    #: The underlying threading.local object holding the per-thread attribute values.
    _storage: object

    def __init__(self):
        self._storage = threading.local()

    def __getattribute__(self, name: str) -> Any:
        # Bootstrap: these two must be resolved directly to avoid infinite recursion,
        # since the checks below access them through this very method.
        if name in ("__thread_local__", "_storage"):
            return object.__getattribute__(self, name)
        if name in self.__thread_local__:
            # Lazily initialize this thread's values on first access.
            # The flag is set *before* calling __thread_init__ so that attribute
            # accesses inside __thread_init__ do not re-enter initialization.
            if not self._thread_initialized:
                self._storage.__initialized__ = True
                self.__thread_init__()
            return getattr(self._storage, name)
        # Non-thread-local attributes are resolved normally on the instance/class.
        return object.__getattribute__(self, name)

    def __delattr__(self, name: str) -> None:
        if name in self.__thread_local__:
            # Only remove this thread's copy; other threads keep their values.
            delattr(self._storage, name)
            return
        object.__delattr__(self, name)

    def __setattr__(self, name: str, value: Any) -> None:
        if name in self.__thread_local__:
            # Same lazy-init dance as in __getattribute__, so defaults are copied
            # before the first per-thread write.
            if not self._thread_initialized:
                self._storage.__initialized__ = True
                self.__thread_init__()
            setattr(self._storage, name, value)
            return
        object.__setattr__(self, name, value)

    @property
    def _thread_initialized(self) -> bool:
        """Check if the thread has been initialized."""
        return getattr(self._storage, "__initialized__", False)

    def thread_close(self):
        """Close the thread-local storage."""
        # Clearing the flag makes the next access on this thread re-run __thread_init__.
        if self._thread_initialized:
            delattr(self._storage, "__initialized__")

    def __thread_init__(self):
        """Called when the data on a thread is accessed for the first time, to
        set the initial value of that data."""
        # Copy the data from the default values
        # (i.e. values stored on the instance/class itself, if any).
        for name in self.__thread_local__:
            try:
                default_value = object.__getattribute__(self, name)
            except AttributeError:
                # No default declared; the attribute stays unset on this thread.
                pass
            else:
                setattr(self._storage, name, default_value)
class ThreadLocalSqlite(ThreadLocalStorage):
    """A class that allows to store data in a thread-local storage.

    Each thread that touches ``connection`` or ``cursor`` transparently gets its
    own sqlite3 connection/cursor (sqlite3 objects must not be shared across
    threads by default).
    """

    #: Database path or URI to connect to.
    database: str
    #: Whether ``database`` is to be interpreted as a URI (sqlite3 ``uri=`` flag).
    is_uri: bool

    # The connection and cursor are per-thread; everything else is shared.
    __thread_local__ = ("connection", "cursor")
    connection: sqlite3.Connection
    cursor: sqlite3.Cursor

    def __init__(self, database: str, is_uri: bool = False):
        super().__init__()
        self.database = database
        self.is_uri = is_uri

    def __thread_init__(self):
        """Initialize the connection and cursor."""
        # Called once per thread on first access to a thread-local attribute.
        self.connection = sqlite3.connect(self.database, uri=self.is_uri)
        self.cursor = self.connection.cursor()
        # Wait up to 5s on a locked database instead of failing immediately.
        self.connection.execute("PRAGMA busy_timeout = 5000;")

    def select_one(self, query: str, params: tuple[Any, ...] = ()):
        """Select one row from the database."""
        self.cursor.execute(query, params)
        return self.cursor.fetchone()

    def select_all(self, query: str, params: tuple[Any, ...] = ()):
        """Select all rows from the database."""
        self.cursor.execute(query, params)
        return self.cursor.fetchall()

    def thread_close(self):
        """Close the connection and cursor."""
        # Only closes the *current thread's* connection; other threads' connections
        # are unaffected.
        if self._thread_initialized:
            self.cursor.close()
            self.connection.close()
        super().thread_close()
def main():
    """Test the ThreadLocalSqlite class.

    Creates a throwaway database, inserts one row, then queries it concurrently
    from a thread pool; each worker thread transparently opens its own sqlite
    connection via the thread-local storage.
    """
    import concurrent.futures

    sqlite = ThreadLocalSqlite("tmp.sqlite")
    try:
        sqlite.cursor.execute("CREATE TABLE IF NOT EXISTS test (id INTEGER PRIMARY KEY, name TEXT)")
        sqlite.cursor.execute("INSERT INTO test (name) VALUES (?)", ("test",))
        sqlite.connection.commit()

        def _test_thread_local(sqlite_thread_local: ThreadLocalSqlite):
            # Random sleep so the worker threads' connections overlap in time.
            time.sleep(random.random())
            print(sqlite_thread_local.select_all("SELECT * FROM test"))

        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            futures = []
            for _ in range(20):
                futures.append(executor.submit(_test_thread_local, sqlite))
            for future in concurrent.futures.as_completed(futures):
                future.result()
    finally:
        # Fix: close the main thread's connection before deleting the database
        # (deleting an open file fails on Windows and leaks an fd elsewhere),
        # and remove the file even if the test body raised. Connections opened
        # by pool threads are not closed here; they are reclaimed when the
        # ThreadLocalSqlite object is garbage collected.
        sqlite.thread_close()
        try:
            os.remove("tmp.sqlite")
        except FileNotFoundError:
            pass
# Allow running this module directly as a quick smoke test.
if __name__ == "__main__":
    main()
Megatron-Energon/src/megatron/energon/fork_hook.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
functools
import
os
import
weakref
from
dataclasses
import
dataclass
from
typing
import
Callable
def
_cleanup
(
hooks
,
key
,
wr
):
hooks
.
pop
(
key
)
class WeakCallbacks:
    """
    A class that manages weak references to callback functions.

    Callbacks are stored as zero-argument "getters" that return either the live
    callable or None once the referent has been collected. Dead weak entries are
    removed automatically via the ``_cleanup`` finalizer.
    """

    # A dictionary of weak (or strong) references to functions.
    _hooks: dict[int, Callable[[], Callable[..., None] | None]]

    def __init__(self):
        """
        Initialize the registry.
        """
        self._hooks: dict[int, Callable[[], Callable[..., None] | None]] = {}

    def add_hook(self, callable: Callable[..., None], make_persistent: bool = False) -> None:
        """
        Add a callback to the registry.

        Args:
            callable: The function to run before the fork of a worker process.
            make_persistent: If True, the function will be stored as a strong reference, otherwise a weak reference is used.
        """
        if make_persistent:
            # Not a weakref, but always return the callable.
            # (Keeps the callable alive for the lifetime of the registry.)
            self._hooks[id(callable)] = lambda: callable
        elif getattr(callable, "__self__", None):
            # Add a method reference to the hooks.
            # Bound methods need WeakMethod (a plain ref to a bound method dies
            # immediately). Keyed by id of the instance, so at most one hook per
            # instance is kept — a later registration for the same instance
            # replaces the earlier one.
            key = id(callable.__self__)
            self._hooks[key] = weakref.WeakMethod(
                callable, functools.partial(_cleanup, self._hooks, key)
            )
        else:
            # Add a function reference to the hooks
            key = id(callable)
            self._hooks[key] = weakref.ref(callable, functools.partial(_cleanup, self._hooks, key))

    def run(self, *args, **kwargs) -> None:
        """
        Run all the callbacks in the registry, passing the given arguments.
        """
        for hook in self._hooks.values():
            ref = hook()
            # ref is None when the weakly-referenced callable has been collected.
            if ref is not None:
                ref(*args, **kwargs)
# Module-level registries wired into os.register_at_fork at import time (below).
_after_in_child_fork_hooks = WeakCallbacks()
_after_in_parent_fork_hooks = WeakCallbacks()
_before_fork_hooks = WeakCallbacks()
def before_fork_hook(callable: Callable[[], None], make_persistent: bool = False):
    """
    Run function before the fork of a worker process.
    The function must be persistent (i.e. not a lambda) or an instance method.

    Args:
        callable: The function to run before the fork of a worker process.
        make_persistent: If True, the function will be stored as a strong reference, otherwise a weak reference is used.
    """
    # Delegates to the module-level registry wired into os.register_at_fork.
    _before_fork_hooks.add_hook(callable, make_persistent)
def after_in_parent_fork_hook(callable: Callable[[], None], make_persistent: bool = False):
    """
    Run function after the fork of a worker process (in the parent process).
    The function must be persistent (i.e. not a lambda) or an instance method.

    Args:
        callable: The function to run after the fork of a worker process.
        make_persistent: If True, the function will be stored as a strong reference, otherwise a weak reference is used.
    """
    # Delegates to the module-level registry wired into os.register_at_fork.
    _after_in_parent_fork_hooks.add_hook(callable, make_persistent)
def after_in_child_fork_hook(callable: Callable[[], None], make_persistent: bool = False):
    """
    Run function after the fork of a worker process (in the child process).
    The function must be persistent (i.e. not a lambda) or an instance method.

    Args:
        callable: The function to run after the fork of a worker process.
        make_persistent: If True, the function will be stored as a strong reference, otherwise a weak reference is used.
    """
    # Delegates to the module-level registry wired into os.register_at_fork.
    _after_in_child_fork_hooks.add_hook(callable, make_persistent)
class ForkMixin:
    """
    A mixin that runs a method after the fork of a worker process.

    Subclasses override any of __before_fork__ / __after_in_child_fork__ /
    __after_in_parent_fork__; only overridden methods are registered as hooks
    (registration is weak, so the instance is not kept alive by the registry).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__post_init__()

    def __post_init__(self):
        # Register a hook only if the subclass actually overrides the method:
        # compare the bound method's underlying function against the mixin's
        # own default implementation.
        if getattr(self.__before_fork__, "__func__", None) is not ForkMixin.__before_fork__:
            before_fork_hook(self.__before_fork__)
        if (
            getattr(self.__after_in_child_fork__, "__func__", None)
            is not ForkMixin.__after_in_child_fork__
        ):
            after_in_child_fork_hook(self.__after_in_child_fork__)
        if (
            getattr(self.__after_in_parent_fork__, "__func__", None)
            is not ForkMixin.__after_in_parent_fork__
        ):
            after_in_parent_fork_hook(self.__after_in_parent_fork__)

    def __after_in_child_fork__(self):
        """
        A method that runs after the fork in the child process.
        """
        pass

    def __after_in_parent_fork__(self):
        """
        A method that runs after the fork in the parent process.
        """
        pass

    def __before_fork__(self):
        """
        A method that runs before the fork of a worker process.
        """
        pass
@dataclass
class DataclassForkMixin:
    """
    A mixin that runs a method after the fork of a worker process.

    Dataclass variant of ForkMixin: hook registration happens in __post_init__,
    which dataclasses call automatically after field initialization.
    """

    def __post_init__(self):
        # Register a hook only if the subclass actually overrides the method
        # (compared against this mixin's default implementation).
        if (
            getattr(self.__before_fork__, "__func__", None)
            is not DataclassForkMixin.__before_fork__
        ):
            before_fork_hook(self.__before_fork__)
        if (
            getattr(self.__after_in_child_fork__, "__func__", None)
            is not DataclassForkMixin.__after_in_child_fork__
        ):
            after_in_child_fork_hook(self.__after_in_child_fork__)
        if (
            getattr(self.__after_in_parent_fork__, "__func__", None)
            is not DataclassForkMixin.__after_in_parent_fork__
        ):
            after_in_parent_fork_hook(self.__after_in_parent_fork__)

    def __after_in_child_fork__(self):
        """
        A method that runs after the fork in the child process.
        """
        pass

    def __after_in_parent_fork__(self):
        """
        A method that runs after the fork in the parent process.
        """
        pass

    def __before_fork__(self):
        """
        A method that runs before the fork of a worker process.
        """
        pass
# Wire the module-level registries into the interpreter's fork machinery once,
# at import time. os.register_at_fork callbacks cannot be unregistered, which is
# why the registries themselves hold (mostly weak) references.
os.register_at_fork(
    before=_before_fork_hooks.run,
    after_in_child=_after_in_child_fork_hooks.run,
    after_in_parent=_after_in_parent_fork_hooks.run,
)
Megatron-Energon/src/megatron/energon/loader.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
typing
import
Optional
,
TypeVar
from
megatron.energon.cache
import
CachePool
from
megatron.energon.errors
import
warn_deprecated
from
megatron.energon.flavors
import
SavableDataset
from
megatron.energon.savable_loader
import
BasicDataLoader
,
SavableDataLoader
from
megatron.energon.worker
import
WorkerConfig
from
megatron.energon.wrappers.gc_dataset
import
GC_DEFAULT_EVERY_N_ITER
T
=
TypeVar
(
"T"
)
def get_savable_loader(
    dataset: SavableDataset[T],
    *,
    worker_config: Optional[WorkerConfig] = None,
    checkpoint_every_sec: float = 60,
    checkpoint_every_min_n_samples: Optional[int] = None,
    n_checkpoints: Optional[int] = None,
    gc_collect_every_n_steps: int = GC_DEFAULT_EVERY_N_ITER,
    prefetch_factor: int = 2,
    cache_pool: Optional[CachePool] = None,
    watchdog_timeout_seconds: Optional[float] = 60,
    watchdog_initial_timeout_seconds: Optional[float] = None,
    fail_on_timeout: bool = False,
) -> SavableDataLoader[T]:
    """Construct a savable dataloader for the given dataset.

    Args:
        dataset: The dataset to create a loader for.
        worker_config: Deprecated. Please pass this to the dataset instead.
        checkpoint_every_sec: Seconds between internal checkpoints. Restoring may
            take about as long, and checkpointing adds overhead while reading, so
            choose accordingly. Only applies if using workers.
        checkpoint_every_min_n_samples: Overrides the minimum number of samples
            between checkpoints (default: `number of workers * 2`). Only applies
            if using workers.
        n_checkpoints: Number of internal checkpoints to keep; computed
            automatically when None. Only applies if using workers.
        cache_pool: If set, the cache pool to use for the dataset.
        watchdog_timeout_seconds: The timeout in seconds. If None, the watchdog is disabled.
        watchdog_initial_timeout_seconds: The initial timeout in seconds. If None, the timeout is the same as watchdog_timeout_seconds.
        fail_on_timeout: If True, stops the whole process upon timeout, after printing a stack trace.

    Returns:
        The instantiated :class:`megatron.energon.SavableDataLoader`, yielding batches from the dataset,
        allowing to save the state of the dataset.
    """
    if worker_config is not None:
        # Deprecated argument: tolerated (with a warning) only when it matches
        # the dataset's own worker config; otherwise it would silently lie.
        if worker_config == dataset.worker_config:
            warn_deprecated(
                "Passing a worker_config to get_savable_loader() is deprecated and will have no effect."
            )
        else:
            raise AssertionError(
                "The worker_config passed to get_savable_loader() does not match the one of the dataset. "
                "Also note, it is deprecated to pass one to get_savable_loader() and it will have no effect."
            )
    return SavableDataLoader(
        dataset,
        checkpoint_every_sec=checkpoint_every_sec,
        checkpoint_every_min_n_samples=checkpoint_every_min_n_samples,
        n_checkpoints=n_checkpoints,
        gc_collect_every_n_steps=gc_collect_every_n_steps,
        prefetch_factor=prefetch_factor,
        cache_pool=cache_pool,
        watchdog_timeout_seconds=watchdog_timeout_seconds,
        watchdog_initial_timeout_seconds=watchdog_initial_timeout_seconds,
        fail_on_timeout=fail_on_timeout,
    )
def get_loader(
    dataset: SavableDataset[T],
    *,
    worker_config: Optional[WorkerConfig] = None,
    prefetch_factor: int = 2,
    cache_pool: Optional[CachePool] = None,
    watchdog_timeout_seconds: Optional[float] = 60,
    watchdog_initial_timeout_seconds: Optional[float] = None,
    fail_on_timeout: bool = False,
) -> BasicDataLoader[T]:
    """Construct a basic (non-savable) dataloader for the given dataset.

    Args:
        dataset: The dataset to create a loader for.
        worker_config: Deprecated. Please pass this to the dataset instead.
        cache_pool: If set, the cache pool to use for the dataset.
        watchdog_timeout_seconds: The timeout in seconds. If None, the watchdog is disabled.
        watchdog_initial_timeout_seconds: The initial timeout in seconds. If None, the timeout is the same as watchdog_timeout_seconds.
        fail_on_timeout: If True, stops the whole process upon timeout, after printing a stack trace.

    Returns:
        The instantiated :class:`torch.data.DataLoader`, yielding batches from the dataset.
    """
    if worker_config is not None:
        # Deprecated argument: tolerated (with a warning) only when it matches
        # the dataset's own worker config; otherwise it would silently lie.
        if worker_config == dataset.worker_config:
            warn_deprecated(
                "Passing a worker_config to get_loader() is deprecated and will have no effect."
            )
        else:
            raise AssertionError(
                "The worker_config passed to get_loader() does not match the one of the dataset. "
                "Also note, it is deprecated to pass one to get_loader() and it will have no effect."
            )
    return BasicDataLoader(
        dataset,
        prefetch_factor=prefetch_factor,
        cache_pool=cache_pool,
        watchdog_timeout_seconds=watchdog_timeout_seconds,
        watchdog_initial_timeout_seconds=watchdog_initial_timeout_seconds,
        fail_on_timeout=fail_on_timeout,
    )
Megatron-Energon/src/megatron/energon/local_copy.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
hashlib
from
pathlib
import
Path
from
filelock
import
FileLock
from
megatron.energon.epathlib
import
EPath
# Shared directory for cached local copies of remote files; shared across
# processes on the same machine (coordination is done via file locks).
LOCAL_COPY_TMP_DIR = Path("/tmp/energon_local_copy")
def ensure_local_copy(path: EPath) -> EPath:
    """Make sure the file at *path* is available on the local filesystem.

    Remote files are downloaded once into a shared temp directory. The local
    copy is keyed by a digest of the remote path, and concurrent callers
    (including other processes) coordinate through a file lock so the file is
    downloaded at most once.

    Args:
        path: The path to the file to copy.

    Returns:
        The path to the local copy of the file or the original path if it is already local.
    """
    assert path.is_file(), f"Path {path} is not a file"
    if path.is_local():
        # Nothing to do for local files.
        return path

    LOCAL_COPY_TMP_DIR.mkdir(parents=True, exist_ok=True)
    digest = hashlib.sha256(str(path).encode()).hexdigest()
    final_path = LOCAL_COPY_TMP_DIR / f"{digest}.bin"
    lock_path = final_path.with_suffix(".lock")
    tmp_path = final_path.with_suffix(".part")

    # Serialize downloaders; blocks until the lock is free (up to 5 minutes).
    with FileLock(lock_path, timeout=60 * 5):
        if not final_path.exists():
            # We won the race and are the downloader. Write to a ".part" file
            # first and rename, so a partial download never looks complete; the
            # finally-clause removes any leftover partial file on failure.
            try:
                path.copy(EPath(tmp_path))
                tmp_path.rename(final_path)
            finally:
                tmp_path.unlink(missing_ok=True)
    return EPath(final_path)
Megatron-Energon/src/megatron/energon/metadataset/__init__.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
megatron.energon.metadataset.dataset_loader
import
DatasetLoader
from
megatron.energon.metadataset.loader
import
load_dataset
,
prepare_metadataset
from
megatron.energon.metadataset.loader_interface
import
DatasetLoaderInterface
from
megatron.energon.metadataset.metadataset
import
Metadataset
from
megatron.energon.metadataset.metadataset_v2
import
MetadatasetV2
# Public API of the metadataset subpackage.
__all__ = [
    "DatasetLoader",
    "load_dataset",
    "DatasetLoaderInterface",
    "Metadataset",
    "MetadatasetV2",
    "prepare_metadataset",
]
Megatron-Energon/src/megatron/energon/metadataset/dataset_loader.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
typing
import
Any
,
Dict
,
Literal
,
Optional
,
Union
from
megatron.energon.dataset_config
import
get_dataset_from_config
from
megatron.energon.edataclass
import
edataclass
from
megatron.energon.epathlib
import
EPath
from
megatron.energon.flavors
import
BaseCoreDatasetFactory
from
megatron.energon.flavors.webdataset.structs
import
DatasetSubset
from
megatron.energon.metadataset.loader_interface
import
(
DatasetBlendMode
,
DatasetLoaderInterface
,
LoadedDataset
,
LoadedDatasetList
,
)
from
megatron.energon.worker
import
WorkerConfig
@edataclass
class DatasetLoader(DatasetLoaderInterface):
    """Loads a dataset from a path."""

    #: Path to the prepared (webdataset) dataset directory.
    path: Union[str, EPath]
    #: If set, overrides the split part passed to get_dataset().
    split_part: Optional[str] = None
    #: Default subflavors; merged under any subflavors passed to get_dataset().
    subflavors: Optional[Dict[str, Any]] = None
    shuffle_over_epochs_multiplier: Optional[int] = 1
    #: Fallback dataset config file name, used if none is passed to get_dataset().
    dataset_config: Optional[str] = None
    #: Fallback split config file name, used if none is passed to get_dataset().
    split_config: Optional[str] = None

    def post_initialize(self, mds_path: Optional[EPath] = None):
        # No post-initialization needed for a plain dataset loader.
        pass

    def get_dataset(
        self,
        *,
        training: bool,
        split_part: Optional[str] = None,
        worker_config: WorkerConfig,
        subflavors: Optional[Dict[str, Any]] = None,
        shuffle_over_epochs: Optional[int] = 1,
        split_config: Optional[str] = None,
        dataset_config: Optional[str] = None,
        subset: Optional[DatasetSubset] = None,
        **kwargs,
    ) -> BaseCoreDatasetFactory:
        """
        Args:
            training: If true, apply training randomization.
            split_part: Default split part to use.
            worker_config: Worker configuration.
            subflavors: Subflavors to use, might be overridden by inner datasets.
            shuffle_over_epochs: Shuffle the dataset over this many epochs.
            split_config: Config file to use for shard split definitions.
            dataset_config: Dataset config file name to use.
            subset: If specified, the inner dataset(s) will be subsetted.
            **kwargs: Additional arguments to the dataset constructor.

        Returns:
            The loaded dataset
        """
        # Precedence: the loader's own split_part (from config) wins over the
        # caller-provided one.
        if self.split_part is not None:
            split_part = self.split_part
        if split_part is None:
            raise ValueError("Missing split part")
        # Caller-provided subflavors override the loader's defaults key-by-key.
        if self.subflavors is not None:
            subflavors = {**self.subflavors, **(subflavors or {})}
        # Config file names: caller-provided values win; fall back to the
        # loader's configured defaults.
        if split_config is None:
            split_config = self.split_config
        if dataset_config is None:
            dataset_config = self.dataset_config
        return get_dataset_from_config(
            self.path,
            training=training,
            split_part=split_part,
            worker_config=worker_config,
            subflavors=subflavors,
            dataset_config=dataset_config,
            split_config=split_config,
            shuffle_over_epochs=shuffle_over_epochs,
            subset=subset,
            **kwargs,
        )

    def get_datasets(
        self,
        *,
        training: bool,
        split_part: Union[Literal["train", "val", "test"], str],
        worker_config: WorkerConfig,
        subflavors: Optional[Dict[str, Any]] = None,
        shuffle_over_epochs_multiplier: Optional[int] = 1,
        subset: Optional[DatasetSubset] = None,
        **kwargs,
    ) -> LoadedDatasetList:
        """Wrap the single dataset from get_dataset() as a (non-blended)
        single-element LoadedDatasetList, with no weight assigned."""
        return LoadedDatasetList(
            blend_mode=DatasetBlendMode.NONE,
            datasets=[
                LoadedDataset(
                    dataset=self.get_dataset(
                        training=training,
                        split_part=split_part,
                        worker_config=worker_config,
                        subflavors=subflavors,
                        shuffle_over_epochs=shuffle_over_epochs_multiplier,
                        subset=subset,
                        **kwargs,
                    ),
                    weight=None,
                )
            ],
        )
Megatron-Energon/src/megatron/energon/metadataset/join_dataset_loader.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
hashlib
from
typing
import
(
Any
,
Callable
,
Dict
,
List
,
Literal
,
Optional
,
Sequence
,
Tuple
,
Type
,
Union
,
)
from
megatron.energon.edataclass
import
edataclass
from
megatron.energon.epathlib.epath
import
EPath
from
megatron.energon.flavors
import
(
BaseCoreDatasetFactory
,
BaseWebdatasetFactory
,
JoinedWebdatasetFactory
,
Sample
,
)
from
megatron.energon.flavors.webdataset.config
import
MAIN_FOLDER_NAME
from
megatron.energon.flavors.webdataset.indexing
import
JoinIndexWriter
from
megatron.energon.flavors.webdataset.metadata
import
WebdatasetMeta
from
megatron.energon.flavors.webdataset.structs
import
DatasetSubset
from
megatron.energon.metadataset.dataset_loader
import
DatasetLoader
from
megatron.energon.metadataset.loader_interface
import
(
DatasetBlendMode
,
DatasetLoaderInterface
,
LoadedDataset
,
LoadedDatasetList
,
)
from
megatron.energon.worker
import
WorkerConfig
def filter_samples_with_excludes(
    conn,
    db_alias: str,
    meta: "JoinedDatasetMetaInfo",
):
    """
    Filter the samples in the database with the given excludes.

    Creates a TEMP VIEW named ``{db_alias}_filtered`` over ``{db_alias}.samples``
    that omits all excluded shards and excluded individual samples.

    Args:
        conn: An open sqlite3 connection (with the dataset attached as db_alias).
        db_alias: SQL alias of the attached database (e.g. "main", "db1").
            NOTE(review): interpolated into SQL via f-strings — safe only
            because aliases are generated internally, never user input.
        meta: Meta info providing ``excludes`` (shard names or
            "<shard>.tar/<sample key>" entries) and ``shard_name_to_info_idx``.
    """
    filtered_name = f"{db_alias}_filtered"
    conn.execute(f"DROP VIEW IF EXISTS {filtered_name}")
    if not meta.excludes:
        # Nothing to exclude, just use the original table
        conn.execute(f"CREATE TEMP VIEW {filtered_name} AS SELECT * FROM {db_alias}.samples")
        return
    # Split the excludes into shard-level excludes and sample-level excludes
    excluded_shard_ids = []
    excluded_sample_keys = []
    for exclude in meta.excludes:
        if exclude in meta.shard_name_to_info_idx:
            # Whole shard excluded; store its numeric tar_file_id.
            excluded_shard_ids.append(meta.shard_name_to_info_idx[exclude])
        else:
            # Find the shard name for the sample key
            # Trivial split by .tar/
            if ".tar/" in exclude:
                tarname, sample_key = exclude.split(".tar/", 1)
                shard_idx = meta.shard_name_to_info_idx[tarname + ".tar"]
                excluded_sample_keys.append((shard_idx, sample_key))
            elif exclude.endswith(".tar"):
                # This is a shard and was probably already excluded outside this function
                pass
            else:
                raise ValueError(
                    f"Invalid exclusion: Cannot split exclude {exclude} into shard and sample key"
                )
    # Create a temporary table for the shard excludes
    # The key will be integers according to the tar_file_id column of the samples table
    conn.execute(f"DROP TABLE IF EXISTS temp_shard_excludes_{db_alias}")
    conn.execute(
        f"""
        CREATE TEMP TABLE temp_shard_excludes_{db_alias} (
            exclude_key INTEGER PRIMARY KEY
        )
        """
    )
    for shard_id in excluded_shard_ids:
        conn.execute(
            f"INSERT INTO temp_shard_excludes_{db_alias} (exclude_key) values (?)", (shard_id,)
        )
    # Create a temporary table for the sample excludes
    conn.execute(f"DROP TABLE IF EXISTS temp_sample_excludes_{db_alias}")
    conn.execute(
        f"""
        CREATE TEMP TABLE temp_sample_excludes_{db_alias} (
            shard_idx INTEGER,
            exclude_key TEXT,
            PRIMARY KEY (shard_idx, exclude_key)
        )
        """
    )
    conn.executemany(
        f"INSERT INTO temp_sample_excludes_{db_alias} (shard_idx, exclude_key) values (?, ?)",
        [(shard_idx, sample_key) for shard_idx, sample_key in excluded_sample_keys],
    )
    # Create view for filtered samples: drop rows whose shard is excluded, and
    # rows whose (shard, sample_key) pair is individually excluded.
    conn.execute(
        f"""
        CREATE TEMP VIEW {filtered_name} AS
        SELECT *
        FROM {db_alias}.samples s
        WHERE s.tar_file_id NOT IN (
            SELECT exclude_key
            FROM temp_shard_excludes_{db_alias}
        )
        AND NOT EXISTS (
            SELECT 1
            FROM temp_sample_excludes_{db_alias} e
            WHERE e.shard_idx = s.tar_file_id
            AND e.exclude_key = s.sample_key
        )
        """
    )
def
join_multiple_indices
(
meta_infos
:
List
[
"JoinedDatasetMetaInfo"
],
output_join_index_path
:
EPath
,
):
"""
Joins the 'samples' table of one primary_db with multiple secondary_dbs
by 'sample_key'. For each secondary DB, we select three columns:
- tar_file_id
- byte_offset
- byte_size
The result is streamed out row-by-row and written to join index.
Note that the order of samples is determined by the shard_map of the primary DB.
Args:
meta_infos: List of meta infos for all datasets.
output_join_index_path: Path to the output join index.
"""
primary
=
meta_infos
[
0
]
secondaries
=
meta_infos
[
1
:]
assert
primary
.
nonmatch
==
"error"
,
(
"Primary join dataset must have nonmatch set 'error' (default)"
)
import
sqlite3
# 1. Connect to the primary DB in 'main'
conn
=
sqlite3
.
connect
(
f
"file:
{
primary
.
db_path
!
s
}
?mode=ro"
,
uri
=
True
)
# For safety, enable a read-only or big timeouts
conn
.
execute
(
"PRAGMA busy_timeout = 5000;"
)
conn
.
execute
(
"PRAGMA journal_mode = WAL;"
)
# 2. Attach each secondary DB under a unique alias, e.g. db1, db2, ...
secondary_aliases
=
[]
for
i
,
sec_mi
in
enumerate
(
secondaries
,
start
=
1
):
alias
=
f
"db
{
i
}
"
secondary_aliases
.
append
(
alias
)
conn
.
execute
(
f
"ATTACH DATABASE ? AS
{
alias
}
"
,
(
f
"file:
{
sec_mi
.
db_path
}
?mode=ro"
,))
# Filter the primary and each secondary DB for excluded samples by creating
# a new VIEW for each
for
alias
,
mi
in
zip
([
"main"
]
+
secondary_aliases
,
meta_infos
):
filter_samples_with_excludes
(
conn
,
alias
,
mi
)
# Check each primary and secondary DB for duplicate sample_key values
for
alias
,
mi
in
zip
([
"main"
]
+
secondary_aliases
,
meta_infos
):
duplicates
=
conn
.
execute
(
f
"""
SELECT sample_key, COUNT(*) AS c
FROM
{
alias
}
_filtered
GROUP BY sample_key
HAVING c > 1
LIMIT 5
"""
).
fetchall
()
if
duplicates
:
raise
ValueError
(
f
"Can't join. Found duplicate sample keys in
{
mi
.
db_path
}
:
{
duplicates
}
"
)
# Create a temporary table to order the shards as in the current split config
conn
.
execute
(
"DROP TABLE IF EXISTS primary_order"
)
conn
.
execute
(
"""
CREATE TEMP TABLE primary_order (
tar_file_id INTEGER PRIMARY KEY,
split_index INTEGER
)
"""
)
conn
.
executemany
(
"INSERT INTO primary_order(tar_file_id, split_index) values (?, ?)"
,
((
n
,
i
)
for
i
,
n
in
enumerate
(
primary
.
split_part_oder
)),
)
# Map from tar_file_id to shard idx in the split part
tar_files_id_mapping
=
{}
for
alias
,
mi
in
zip
([
"main"
]
+
secondary_aliases
,
meta_infos
):
tar_files_id_mapping
[
alias
]
=
{
tar_file_id
:
shard_idx
for
shard_idx
,
tar_file_id
in
enumerate
(
mi
.
split_part_oder
)
}
# These are the columns we want to select in the main SQL query
select_cols
=
[
"main_filtered.tar_file_id AS main_tar_file_id"
,
"main_filtered.byte_offset AS main_byte_offset"
,
"main_filtered.byte_size AS main_byte_size"
,
]
for
i
,
alias
in
enumerate
(
secondary_aliases
,
start
=
1
):
select_cols
.
append
(
f
"
{
alias
}
_filtered.tar_file_id AS tar_file_id_
{
i
}
"
)
select_cols
.
append
(
f
"
{
alias
}
_filtered.byte_offset AS byte_offset_
{
i
}
"
)
select_cols
.
append
(
f
"
{
alias
}
_filtered.byte_size AS byte_size_
{
i
}
"
)
# Build the LEFT JOIN or INNER JOIN clauses
join_clauses
=
""
for
alias
,
mi
in
zip
(
secondary_aliases
,
secondaries
):
if
mi
.
nonmatch
==
"skip"
:
join_type
=
"INNER JOIN"
else
:
join_type
=
"LEFT JOIN"
join_clauses
+=
f
"
{
join_type
}
{
alias
}
_filtered ON main_filtered.sample_key =
{
alias
}
_filtered.sample_key"
# Construct the full SQL query
# We select three columns for the primary and each secondary DB
# Those are (tar_file_id, byte_offset, and byte_size)
# We join the secondary DBs to the primary DB using a LEFT JOIN, i.e.
# we keep all rows from the primary DB and add columns from the secondary DBs if available
# Finally, we also join the temporary shard order table to order the shards as in the split config.
# This join is done using an INNER JOIN, i.e. we only keep rows that have a matching shard index in the primary dataset,
# so we'll not include shards that come from other split parts
sql
=
f
"""
SELECT
{
", "
.
join
(
select_cols
)
}
FROM main_filtered
{
join_clauses
}
INNER JOIN primary_order o
ON main_tar_file_id = o.tar_file_id
ORDER BY o.split_index
"""
# 3. Execute the query; this returns a cursor we can iterate over row by row
cursor
=
conn
.
execute
(
sql
)
all_db_aliases
=
[
"main"
]
+
secondary_aliases
# 4. Write the results to a binary file join index file row by row
with
JoinIndexWriter
(
output_join_index_path
)
as
join_index_writer
:
# Example: We'll just show how to iterate the rows and pseudo-write them
num_rows
=
0
num_missing
=
[
0
]
*
len
(
meta_infos
)
for
row
in
cursor
:
# 'row' is a tuple of columns in the order of select_cols
join_tuples
=
[]
for
i
,
(
alias
,
meta_info
)
in
enumerate
(
zip
(
all_db_aliases
,
meta_infos
)):
tar_file_id
=
row
[
3
*
i
]
if
tar_file_id
is
None
:
# This column is missing in this secondary dataset
# How we handle this case depends on the nonmatch setting
if
meta_info
.
nonmatch
==
"none"
:
# The user accepts missing samples, we'll just add a dummy entry
join_tuples
.
append
((
-
1
,
-
1
,
-
1
))
num_missing
[
i
]
+=
1
elif
meta_info
.
nonmatch
==
"skip"
:
# The user wants to skip rows with missing samples.
# Skipping rows is already handled by the INNER JOIN above, so
# this case should not happen.
raise
AssertionError
(
f
"Join has encountered a missing sample: Sample key
{
row
[
0
]
}
missing from "
f
"
{
meta_info
.
db_path
}
, although nonmatch_skip is set"
)
else
:
# The user wants to raise an error on missing samples
raise
ValueError
(
f
"Join has encountered a missing sample: Sample key
{
row
[
0
]
}
missing from "
f
"
{
meta_info
.
db_path
}
, although neither nonmatch_none nor nonmatch_skip are set"
)
else
:
shard_idx
=
tar_files_id_mapping
[
alias
][
tar_file_id
]
byte_offset
=
row
[
3
*
i
+
1
]
byte_size
=
row
[
3
*
i
+
2
]
join_tuples
.
append
((
shard_idx
,
byte_offset
,
byte_size
))
else
:
# Each row contains (shard_idx, byte_offset, byte_size) for each secondary key.
join_index_writer
.
append
(
*
join_tuples
)
num_rows
+=
1
any_skip
=
any
(
mi
.
nonmatch
==
"skip"
for
mi
in
meta_infos
)
num_samples
=
conn
.
execute
(
"SELECT COUNT(*) FROM main_filtered INNER JOIN primary_order o ON main_filtered.tar_file_id = o.tar_file_id"
).
fetchone
()[
0
]
if
not
any_skip
:
# If no dataset has skipping active, we can check that the number of rows matches the number of samples in the primary DB
assert
num_rows
==
num_samples
,
(
f
"Number of rows in join index (
{
num_rows
}
) does not match number of samples in primary DB (
{
num_samples
}
)"
)
print
(
f
"Joined all
{
num_rows
}
samples"
)
else
:
print
(
f
"Joined
{
num_rows
}
/
{
num_samples
}
samples, skipped
{
num_samples
-
num_rows
}
samples due to join"
)
if
any
(
num_missing
):
print
(
f
"Non-matching samples filled with None for each dataset:
{
num_missing
}
"
)
conn
.
close
()
@edataclass
class JoinedDatasetInfo:
    """Internal for passing the joined datasets."""

    # Loader for this join participant.
    dataset: DatasetLoader
    # How to handle a primary sample with no match in this dataset:
    # "skip" drops the joined row, "none" fills a placeholder entry,
    # "error" (default for the primary) raises.
    nonmatch: Literal["skip", "none", "error"]
@edataclass
class JoinedDatasetMetaInfo:
    """Internal for passing the joined datasets."""

    # Path to the dataset's "index.sqlite" (under the dataset's main meta folder).
    db_path: EPath
    # Content of the dataset's "index.uuid" file; hashed into the join cache file name.
    uuid: str
    # Sample keys excluded from the join (from the split/meta config).
    excludes: List[str]
    # Maps shard (tar file) name to its index in the info shard file list.
    shard_name_to_info_idx: Dict[str, int]
    # Tar-file info indices in the order of the split part's shard files.
    # NOTE(review): "oder" looks like a typo for "order"; kept as-is because
    # callers reference this field by name.
    split_part_oder: List[int]
    # Nonmatch handling for this dataset ("skip" drops the row via INNER JOIN,
    # "none" inserts a (-1, -1, -1) placeholder, "error" raises).
    nonmatch: Literal["skip", "none", "error"]
@edataclass
class JoinDatasetLoader(DatasetLoaderInterface):
    """Loads a joined dataset from a path."""

    # The datasets to join; either positional (list) or named (dict). The first
    # (list) entry is the primary dataset of the join.
    datasets: Union[List[JoinedDatasetInfo], Dict[str, JoinedDatasetInfo]]
    # Sample type or callable that combines one sample from each inner dataset.
    joiner: Union[Type[Sample], Callable[..., Sample]]
    # Directory for the precomputed join index files; set in post_initialize.
    cache_path: Optional[EPath] = None
    # Optional overrides applied to all inner datasets.
    split_part: Optional[str] = None
    split_config: Optional[str] = None
    subflavors: Optional[Dict[str, Any]] = None
    shuffle_over_epochs_multiplier: Optional[int] = 1

    def _get_joined_meta(self, split_part: str) -> Tuple[EPath, List[JoinedDatasetMetaInfo]]:
        """
        Collect the metadata for the joined dataset.

        Args:
            split_part: Split part to use for datasets that do not override it.

        Returns:
            The hashfile path (join index cache file, named after a digest of all
            dataset uuids/excludes/nonmatch settings), and a list of the meta infos.
        """
        # Get list of joinable datasets
        datasets = self.datasets
        if isinstance(datasets, dict):
            datasets = list(datasets.values())
        meta_infos: List[JoinedDatasetMetaInfo] = []
        for dataset in datasets:
            print(f" - {dataset}")
            uuid_path = EPath(dataset.dataset.path) / MAIN_FOLDER_NAME / "index.uuid"
            try:
                uuid = uuid_path.read_text()
            except FileNotFoundError:
                raise FileNotFoundError(
                    f"Missing uuid file in {uuid_path}. You need to prepare the dataset "
                    "(with a recent version of energon). If you have already prepared the "
                    "dataset, it should be sufficient to run prepare with --tar-index-only."
                )
            db_path = EPath(dataset.dataset.path) / MAIN_FOLDER_NAME / "index.sqlite"
            # Precedence for split_part is:
            # 1. Join dataset split part (overrides individual dataset split parts)
            # 2. Individual dataset split part
            # 3. If none of the above is set, use the split part of the surrounding meta dataset
            cur_split_part = dataset.dataset.split_part or self.split_part or split_part
            assert cur_split_part is not None, "Missing split part"
            wds_meta = WebdatasetMeta.from_config(
                path=EPath(dataset.dataset.path),
                split_part=cur_split_part,
                split_config=dataset.dataset.split_config,
            )
            shard_name_to_info_idx = {
                name: i for i, name in enumerate(wds_meta.info_shard_files)
            }
            # Given wds_meta.split_part_files, translate their order to info idx IDs
            split_part_oder = [
                shard_name_to_info_idx[name] for name in wds_meta.split_part_files
            ]
            meta_infos.append(
                JoinedDatasetMetaInfo(
                    db_path=db_path,
                    uuid=uuid,
                    excludes=list(wds_meta.sample_excludes),
                    shard_name_to_info_idx=shard_name_to_info_idx,
                    split_part_oder=split_part_oder,
                    nonmatch=dataset.nonmatch,
                )
            )
        # Combine all join-relevant metadata into a single sha256 digest; any
        # change in uuids, excludes or nonmatch settings yields a new cache file.
        hash = hashlib.sha256()
        for meta_info in meta_infos:
            hash.update(b"\0uuid=")
            hash.update(meta_info.uuid.encode())
            hash.update(b"\0excludes=")
            for exclude in meta_info.excludes:
                hash.update(exclude.encode())
                hash.update(b"\0")
            hash.update(f"\0nonmatch={meta_info.nonmatch}\0".encode())
        assert self.cache_path is not None
        return self.cache_path / f"join_index_{hash.hexdigest()}.bin", meta_infos

    def post_initialize(self, mds_path: Optional[EPath] = None):
        """Derive the cache path next to the surrounding metadataset file."""
        assert mds_path is not None
        self.cache_path = mds_path.parent / f"{mds_path.name}.cache"

    def prepare(self, split_part: Optional[str] = None) -> Sequence[EPath]:
        """Build (or reuse) the binary join index for the given split part.

        Returns:
            The paths created in the cache (used for cleanup).
        """
        assert self.cache_path is not None
        assert split_part is not None
        join_index_path, meta_infos = self._get_joined_meta(split_part)
        # The file name encodes a digest of all inputs, so existence implies up-to-date.
        if join_index_path.is_file():
            print(f"Joined dataset already prepared at {join_index_path} and up-to-date")
            return (join_index_path,)
        print(f"Preparing joined dataset in {join_index_path}")
        join_index_path.parent.mkdir(parents=True, exist_ok=True)
        join_multiple_indices(
            meta_infos=meta_infos,
            output_join_index_path=join_index_path,
        )
        return (join_index_path,)

    def get_dataset(
        self,
        *,
        training: bool,
        split_part: Optional[str] = None,
        worker_config: WorkerConfig,
        subflavors: Optional[Dict[str, Any]] = None,
        shuffle_over_epochs: Optional[int] = 1,
        split_config: Optional[str] = None,
        subset: Optional[DatasetSubset] = None,
        **kwargs,
    ) -> BaseCoreDatasetFactory:
        """
        Args:
            training: If true, apply training randomization.
            split_part: Default split part to use.
            worker_config: Worker configuration.
            subflavors: Subflavors to use, might be overridden by inner datasets.
            shuffle_over_epochs: Shuffle the dataset over this many epochs.
            split_config: Default split config to use.
            subset: If specified, the inner dataset(s) will be subsetted.
            **kwargs: Additional arguments to the dataset constructor.

        Returns:
            The loaded dataset
        """
        # The loader-level settings override the surrounding metadataset's defaults.
        if self.split_config is not None:
            split_config = self.split_config
        if self.split_part is not None:
            split_part = self.split_part
        if split_part is None:
            raise ValueError("Missing split part")
        if self.subflavors is not None:
            subflavors = {**self.subflavors, **(subflavors or {})}
        join_index_path, _ = self._get_joined_meta(split_part)
        if isinstance(self.datasets, list):
            inner_datasets = [
                dataset.dataset.get_dataset(
                    training=training,
                    split_part=split_part,
                    worker_config=worker_config,
                    subflavors=subflavors,
                    shuffle_over_epochs=shuffle_over_epochs,
                    split_config=split_config,
                    **kwargs,
                )
                for dataset in self.datasets
            ]
            assert all(isinstance(d, BaseWebdatasetFactory) for d in inner_datasets), (
                "Can only merge webdatasets efficiently"
            )
        elif isinstance(self.datasets, dict):
            inner_datasets = {
                key: dataset.dataset.get_dataset(
                    training=training,
                    split_part=split_part,
                    worker_config=worker_config,
                    subflavors=subflavors,
                    shuffle_over_epochs=shuffle_over_epochs,
                    split_config=split_config,
                    **kwargs,
                )
                for key, dataset in self.datasets.items()
            }
            assert all(
                isinstance(d, BaseWebdatasetFactory) for d in inner_datasets.values()
            ), (
                "Can only merge webdatasets efficiently"
            )
        else:
            raise ValueError("Invalid join type")
        # Remove decoder from kwargs, it is already handled by the inner datasets
        kwargs.pop("decoder", None)
        return JoinedWebdatasetFactory(
            inner_datasets=inner_datasets,
            training=training,
            worker_config=worker_config,
            shuffle_over_epochs=shuffle_over_epochs,
            join_index=join_index_path,
            joiner=self.joiner,
            subset=subset,
            **kwargs,
        )

    def get_datasets(
        self,
        *,
        training: bool,
        split_part: Union[Literal["train", "val", "test"], str],
        worker_config: WorkerConfig,
        subflavors: Optional[Dict[str, Any]] = None,
        shuffle_over_epochs_multiplier: Optional[int] = 1,
        subset: Optional[DatasetSubset] = None,
        **kwargs,
    ) -> LoadedDatasetList:
        """Wrap the single joined dataset as an unblended one-element list."""
        return LoadedDatasetList(
            blend_mode=DatasetBlendMode.NONE,
            datasets=[
                LoadedDataset(
                    dataset=self.get_dataset(
                        training=training,
                        split_part=split_part,
                        worker_config=worker_config,
                        subflavors=subflavors,
                        shuffle_over_epochs=shuffle_over_epochs_multiplier,
                        subset=subset,
                        **kwargs,
                    ),
                    weight=None,
                )
            ],
        )
Megatron-Energon/src/megatron/energon/metadataset/loader.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
pathlib
import
Path
from
typing
import
Any
,
Union
from
megatron.energon.dataset_config
import
load_config
from
megatron.energon.epathlib
import
EPath
from
megatron.energon.flavors.webdataset.metadata
import
EnergonDatasetType
,
get_dataset_type
from
megatron.energon.metadataset.dataset_loader
import
DatasetLoader
from
megatron.energon.metadataset.loader_interface
import
DatasetLoaderInterface
from
megatron.energon.metadataset.metadataset
import
Metadataset
from
megatron.energon.typed_converter
import
JsonParser
def load_dataset(
    path: Union[str, EPath, Path, dict],
    **kwargs,
) -> DatasetLoaderInterface:
    """Loads a (meta)dataset.

    Args:
        path: Path to a prepared (web)dataset, a jsonl dataset, or a metadataset
            yaml file. May also be an already-parsed metadataset config as a dict,
            in which case a synthetic path ("/dict") is used.
        **kwargs: Forwarded to the loader's constructor defaults.

    Returns:
        The instantiated dataset loader.

    Raises:
        ValueError: If the path does not point to a recognized dataset type.
    """
    # Fix: the annotation previously omitted `dict`, although this branch
    # explicitly supports passing a parsed metadataset config.
    if isinstance(path, dict):
        mds = load_config(
            path,
            default_type=Metadataset,
            default_kwargs=dict(path=EPath("/dict"), **kwargs),
        )
        # NOTE(review): unlike the file-based branches, no post_initialize() is
        # called here — presumably because "/dict" has no meaningful parent path;
        # confirm against Metadataset.post_initialize semantics.
        return mds
    path = EPath(path)
    ds_type = get_dataset_type(path)
    if ds_type == EnergonDatasetType.METADATASET:
        mds = load_config(
            path,
            default_type=Metadataset,
            default_kwargs=dict(path=path, **kwargs),
        )
        mds.post_initialize()
        return mds
    elif ds_type in (EnergonDatasetType.WEBDATASET, EnergonDatasetType.JSONL):
        ds = DatasetLoader(path=path, **kwargs)
        ds.post_initialize()
        return ds
    else:
        raise ValueError(f"Invalid dataset at {path}")
class MockJsonParser(JsonParser):
    """Json parser that substitutes a mock class for objects whose module
    cannot be imported, so configs can be parsed without all dependencies."""

    def _resolve_object(
        self,
        module_name: str,
        object_name: str,
        cls: type,
        is_type: bool,
        is_callable: bool,
        is_instantiating_class: bool,
        is_calling_function: bool,
    ) -> Any:
        # First try the regular resolution; only fall back when the target
        # module is not importable in the current environment.
        try:
            return super()._resolve_object(
                module_name,
                object_name,
                cls,
                is_type,
                is_callable,
                is_instantiating_class,
                is_calling_function,
            )
        except ModuleNotFoundError:
            # Stand-in subclass that accepts (and ignores) any constructor args.
            class MockObject(cls):
                def __init__(self, *args, **kwargs):
                    pass

            # All usage modes receive the same mock; anything else resolves to None.
            if is_type or is_instantiating_class or is_callable or is_calling_function:
                return MockObject
            return None
def prepare_metadataset(path: EPath):
    """Load and prepare the metadataset at *path*.

    Uses :class:`MockJsonParser` so that sample types from modules that are not
    importable in this environment are replaced by mocks, allowing preparation
    (cache/index building) to run regardless.

    Args:
        path: Path to the metadataset yaml file.
    """
    # Fix: removed redundant function-level imports of `load_config` and
    # `Metadataset` — both are already imported at module level in this file.
    meta_ds = load_config(
        path,
        default_type=Metadataset,
        default_kwargs=dict(path=path),
        parser=MockJsonParser(strict=True),
    )
    meta_ds.post_initialize()
    meta_ds.prepare()
Megatron-Energon/src/megatron/energon/metadataset/loader_interface.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
abc
import
ABC
,
abstractmethod
from
enum
import
Enum
from
typing
import
Any
,
Dict
,
List
,
Literal
,
Optional
,
Sequence
,
Union
from
megatron.energon.cache
import
FileStore
from
megatron.energon.edataclass
import
edataclass
from
megatron.energon.epathlib
import
EPath
from
megatron.energon.flavors.base_dataset
import
BaseCoreDatasetFactory
from
megatron.energon.flavors.webdataset.structs
import
DatasetSubset
from
megatron.energon.worker
import
WorkerConfig
class DatasetBlendMode(Enum):
    """Determines how the datasets are to be blended. Either by using the associated number as
    the weight for sampling from that dataset, or alternatively by using the number as the number
    of repetitions for samples in that dataset in one epoch (effectively, that corresponds to the
    weight for samples)."""

    # No blending; datasets carry no weight/repetition number.
    NONE = "none"
    # The number is a sampling weight for the whole dataset.
    DATASET_WEIGHT = "dataset_weight"
    # The number is a per-sample repetition count within one epoch.
    SAMPLE_REPETITIONS = "sample_repetitions"
@edataclass
class LoadedDataset:
    """A single instantiated core dataset with its blending parameters."""

    # The instantiated core dataset factory.
    dataset: BaseCoreDatasetFactory
    # Sampling weight; meaningful with DatasetBlendMode.DATASET_WEIGHT, else None.
    weight: Union[float, int, None] = None
    # Per-sample repetitions; meaningful with DatasetBlendMode.SAMPLE_REPETITIONS, else None.
    repetitions: Union[float, int, None] = None
    # Optional auxiliary file stores by name (e.g. for crude dataset cooking).
    aux: Optional[Dict[str, FileStore]] = None
@edataclass
class LoadedDatasetList:
    """The result of resolving a loader: the datasets plus how to blend them."""

    # The instantiated datasets with their weights/repetitions.
    datasets: List[LoadedDataset]
    # How the numbers attached to the datasets are to be interpreted.
    blend_mode: DatasetBlendMode = DatasetBlendMode.NONE
class DatasetLoaderInterface(ABC):
    """General interface for a dataset loader."""

    @abstractmethod
    def post_initialize(self, mds_path: Optional[EPath] = None):
        """Called to finally initialize the dataset.

        Args:
            mds_path: Path of the surrounding metadataset file, if any; used to
                resolve relative paths.
        """
        ...

    @abstractmethod
    def get_datasets(
        self,
        *,
        training: bool,
        split_part: Union[Literal["train", "val", "test"], str],
        worker_config: WorkerConfig,
        subflavors: Optional[Dict[str, Any]] = None,
        shuffle_over_epochs_multiplier: Optional[int] = 1,
        subset: Optional[DatasetSubset] = None,
        **kwargs,
    ) -> LoadedDatasetList:
        """
        Calls :func:`megatron.energon.dataset_config.get_dataset_from_config` (loads the raw dataset)
        for all innermost datasets and resolves their relative weights to absolute weights.

        Args:
            training: If true, apply training randomization.
            split_part: Default split part to use.
            worker_config: Worker configuration to use.
            subflavors: Set the default subflavors for all datasets.
            shuffle_over_epochs_multiplier: Multiply the inner datasets
                `shuffle_over_epochs(_multiplier)` by this factor. E.g. if the inner dataset
                has `shuffle_over_epochs_multiplier=2` and this function has
                `shuffle_over_epochs_multiplier=3`, the inner dataset will be shuffled
                over 6 epochs. Shuffling over `n` epochs guarantees that each sample is seen
                exactly `n` times in `n` epochs of the inner dataset. Use -1 for shuffling over
                an infinite number of epochs (effectively, this will draw shard slices with
                replacement).
            subset: If specified, the inner dataset(s) will be subsetted.
            **kwargs: Additional arguments to the dataset constructor.

        Returns:
            The dataset blending mode and the instantiated core datasets with their weights/repetitions.
        """
        ...

    def prepare(self, split_part: Optional[str] = None) -> Sequence[EPath]:
        """
        Prepares the loader by creating caches and other necessary structures on disk.

        Default implementation prepares nothing and returns no paths.

        Args:
            split_part: Name of the split to load.

        Returns:
            List of paths to the cache paths. This is used for cleanup.
        """
        return ()
Megatron-Energon/src/megatron/energon/metadataset/metadataset.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
typing
import
Any
,
Dict
,
List
,
Literal
,
Optional
,
Union
from
megatron.energon.dataset_config
import
load_config
from
megatron.energon.edataclass
import
edataclass
from
megatron.energon.epathlib
import
EPath
from
megatron.energon.errors
import
warn_deprecated
from
megatron.energon.flavors.webdataset.metadata
import
check_dataset_info_present
from
megatron.energon.flavors.webdataset.structs
import
DatasetSubset
from
megatron.energon.metadataset.dataset_loader
import
DatasetLoader
from
megatron.energon.metadataset.loader_interface
import
(
DatasetBlendMode
,
DatasetLoaderInterface
,
LoadedDatasetList
,
)
from
megatron.energon.worker
import
WorkerConfig
@edataclass
class DatasetReference:
    """Reference to an inner (meta)dataset within a metadataset blend."""

    # Path to the referenced dataset; resolved relative to the metadataset file
    # in post_initialize.
    path: Union[str, EPath]
    # Optional split part override for the referenced dataset.
    split_part: Optional[str] = None
    # Note: subflavor is only for legacy compatibility.
    subflavor: Optional[str] = None
    subflavors: Optional[Dict[str, Any]] = None
    shuffle_over_epochs_multiplier: Optional[int] = 1
    dataset_config: Optional[str] = None
    split_config: Optional[str] = None
    # Relative blend weight; normalized by the enclosing MetadatasetBlender.
    weight: float = 1.0

    # Resolved loader; set by post_initialize.
    _dataset: Optional[DatasetLoaderInterface] = None

    def __post_init__(self):
        # Migrate the deprecated single `subflavor` into the `subflavors` dict
        # under the "__subflavor__" key (without overriding an existing entry).
        if self.subflavor is not None:
            warn_deprecated(
                "subflavor is deprecated, use subflavors instead. This will be removed in a future release."
            )
            if self.subflavors is None:
                self.subflavors = {"__subflavor__": self.subflavor}
            elif "__subflavor__" not in self.subflavors:
                self.subflavors = {"__subflavor__": self.subflavor, **(self.subflavors or {})}
            self.subflavor = None

    def post_initialize(self, mds_path: Optional[EPath] = None):
        """Resolve the path and instantiate the referenced loader.

        A file path is loaded as a (nested) metadataset; a directory containing
        dataset info is loaded as a plain dataset; anything else raises.
        """
        assert mds_path is not None
        if not isinstance(self.path, EPath):
            self.path = mds_path.parent / self.path
        if self.path.is_file():
            # Nested metadataset: per-dataset configs are not applicable.
            assert self.dataset_config is None, "Must not set dataset_config"
            assert self.split_config is None, "Must not set split_config"
            self._dataset = load_config(
                self.path,
                default_type=Metadataset,
                default_kwargs=dict(path=self.path),
            )
            self._dataset.post_initialize()
        elif check_dataset_info_present(self.path):
            self._dataset = DatasetLoader(
                path=self.path,
                split_config=self.split_config,
                dataset_config=self.dataset_config,
            )
            self._dataset.post_initialize()
        else:
            raise FileNotFoundError(self.path)

    def get_datasets(
        self,
        *,
        training: bool,
        split_part: Union[Literal["train", "val", "test"], str],
        worker_config: WorkerConfig,
        subflavors: Optional[Dict[str, Any]] = None,
        shuffle_over_epochs_multiplier: Optional[int] = 1,
        subset: Optional[DatasetSubset] = None,
        **kwargs,
    ) -> LoadedDatasetList:
        """Delegate to the resolved loader, merging subflavors and combining
        shuffle multipliers (None > -1 > product, in override priority)."""
        if self.subflavors is not None:
            subflavors = {**self.subflavors, **(subflavors or {})}
        assert self._dataset is not None
        if shuffle_over_epochs_multiplier is None or self.shuffle_over_epochs_multiplier is None:
            # If no shuffling is requested, this has override priority.
            new_shuffle_over_epochs_multiplier = None
        elif shuffle_over_epochs_multiplier == -1 or self.shuffle_over_epochs_multiplier == -1:
            # Next priority is sampling without replacement.
            new_shuffle_over_epochs_multiplier = -1
        else:
            # Otherwise, multiply the shuffle over epochs multiplier.
            new_shuffle_over_epochs_multiplier = (
                shuffle_over_epochs_multiplier * self.shuffle_over_epochs_multiplier
            )
        return self._dataset.get_datasets(
            training=training,
            split_part=self.split_part or split_part,
            worker_config=worker_config,
            subflavors=subflavors,
            shuffle_over_epochs_multiplier=new_shuffle_over_epochs_multiplier,
            subset=subset,
            **kwargs,
        )
@edataclass
class MetadatasetBlender:
    """Internal blending of the dataset."""

    # The dataset references to blend, each carrying a relative weight.
    datasets: List[DatasetReference]

    def post_initialize(self, mds_path: Optional[EPath] = None):
        """Resolve all referenced datasets relative to the metadataset path."""
        assert mds_path is not None
        for dataset in self.datasets:
            dataset.post_initialize(mds_path)

    def get_datasets(
        self,
        *,
        training: bool,
        split_part: Union[Literal["train", "val", "test"], str],
        worker_config: WorkerConfig,
        subflavors: Optional[Dict[str, Any]] = None,
        shuffle_over_epochs_multiplier: Optional[int] = 1,
        subset: Optional[DatasetSubset] = None,
        **kwargs,
    ) -> LoadedDatasetList:
        """Load all inner datasets and normalize their weights to sum to 1
        (scaled by each reference's relative weight)."""
        sum_weight = sum(dataset.weight for dataset in self.datasets)
        datasets = []
        for dataset in self.datasets:
            inner_result = dataset.get_datasets(
                training=training,
                split_part=split_part,
                worker_config=worker_config,
                subflavors=subflavors,
                shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier,
                subset=subset,
                **kwargs,
            )
            # Only weight-based (or unweighted) inner results can be blended here.
            if inner_result.blend_mode not in (
                DatasetBlendMode.NONE,
                DatasetBlendMode.DATASET_WEIGHT,
            ):
                raise ValueError(
                    "Can only blend datasets which are of the same blend mode. Cannot mix blend with blend_epochized."
                )
            for loaded_dataset in inner_result.datasets:
                if inner_result.blend_mode == DatasetBlendMode.DATASET_WEIGHT:
                    assert isinstance(loaded_dataset.weight, float)
                else:
                    # Unweighted inner dataset: treat as weight 1 before scaling.
                    assert loaded_dataset.weight is None
                    loaded_dataset.weight = 1.0
                # Scale the inner weight by this reference's normalized weight
                # (mutates the LoadedDataset in place).
                loaded_dataset.weight = loaded_dataset.weight * dataset.weight / sum_weight
                datasets.append(loaded_dataset)
        return LoadedDatasetList(
            blend_mode=DatasetBlendMode.DATASET_WEIGHT,
            datasets=datasets,
        )
class Metadataset(DatasetLoaderInterface):
    """Main entry for metadataset."""

    _path: EPath
    _splits: Dict[str, MetadatasetBlender]

    def __init__(
        self,
        path: Union[EPath, str],
        splits: Dict[str, MetadatasetBlender],
    ):
        """Create the metadataset"""
        self._splits = splits
        self._path = EPath(path)

    def post_initialize(self, mds_path: Optional[EPath] = None):
        # Top-level entry: no surrounding metadataset path may be passed.
        assert mds_path is None
        # Let every split blender resolve its dataset references.
        for blender in self._splits.values():
            blender.post_initialize(self._path)

    def get_datasets(
        self,
        *,
        training: bool,
        split_part: Union[Literal["train", "val", "test"], str],
        worker_config: WorkerConfig,
        subflavors: Optional[Dict[str, Any]] = None,
        shuffle_over_epochs_multiplier: Optional[int] = 1,
        subset: Optional[DatasetSubset] = None,
        **kwargs,
    ) -> LoadedDatasetList:
        # Delegate directly to the blender registered for the requested split.
        split_blender = self._splits[split_part]
        return split_blender.get_datasets(
            training=training,
            split_part=split_part,
            worker_config=worker_config,
            subflavors=subflavors,
            shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier,
            subset=subset,
            **kwargs,
        )
Megatron-Energon/src/megatron/energon/metadataset/metadataset_v2.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
re
from
dataclasses
import
dataclass
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Literal
,
Optional
,
Sequence
,
Type
,
Union
from
megatron.energon.cache
import
FileStore
from
megatron.energon.cache.file_store
import
(
SystemFileStore
,
WebdatasetFileStore
,
)
from
megatron.energon.dataset_config
import
load_config
from
megatron.energon.edataclass
import
edataclass
from
megatron.energon.epathlib
import
EPath
from
megatron.energon.flavors
import
Sample
from
megatron.energon.flavors.webdataset
import
MAIN_FOLDER_NAME
from
megatron.energon.flavors.webdataset.metadata
import
EnergonDatasetType
,
get_dataset_type
from
megatron.energon.flavors.webdataset.structs
import
DatasetSubset
from
megatron.energon.metadataset.dataset_loader
import
DatasetLoader
from
megatron.energon.metadataset.join_dataset_loader
import
JoinDatasetLoader
,
JoinedDatasetInfo
from
megatron.energon.metadataset.loader_interface
import
(
DatasetBlendMode
,
DatasetLoaderInterface
,
LoadedDatasetList
,
)
from
megatron.energon.metadataset.metadataset
import
Metadataset
from
megatron.energon.worker
import
WorkerConfig
# Regex for any URL-like string (any protocol); captures the scheme as
# "protocol" and the remainder after "://" as "path".
url_regex = re.compile(r"^(?P<protocol>[a-z][a-z0-9+.-]*)://(?P<path>.*)", re.IGNORECASE)
@edataclass
class AuxDatasetReference:
    """Reference to a prepared Energon dataset used as an auxiliary data source."""

    # Dataset path; resolved relative to the metadataset file in post_initialize.
    path: Union[str, EPath]

    def post_initialize(self, mds_path: Optional[EPath] = None) -> None:
        assert mds_path is not None
        # Resolve a relative (string) path against the metadataset's directory.
        if not isinstance(self.path, EPath):
            self.path = mds_path.parent / self.path
        assert not self.path.is_file(), (
            "Auxiliary datasets must not be metadataset, but direct dataset references"
        )
        index_path = self.path / MAIN_FOLDER_NAME / "index.sqlite"
        assert index_path.is_file(), (
            "Auxiliary datasets must be prepared Energon datasets. This one does not exist or is not prepared: "
            + str(self.path)
        )

    def get_file_store(self) -> FileStore:
        assert isinstance(self.path, EPath), "Missing call to post_initialize"
        return WebdatasetFileStore(self.path)
@edataclass
class AuxFilesystemReference:
    """Reference to a plain filesystem path used as an auxiliary data source."""

    # Filesystem path; resolved relative to the metadataset file in post_initialize.
    fs_path: Union[str, EPath]

    def post_initialize(self, mds_path: Optional[EPath] = None) -> None:
        assert mds_path is not None
        if isinstance(self.fs_path, EPath):
            return
        # Resolve a relative (string) path against the metadataset's directory.
        self.fs_path = mds_path.parent / self.fs_path

    def get_file_store(self) -> FileStore:
        assert isinstance(self.fs_path, EPath), "Missing call to post_initialize"
        return SystemFileStore(self.fs_path)
@edataclass
class Subset:
    """
    A subset range to be applied to a dataset. The range is always consecutive.
    The range is a tuple of two values, where the first value is the start of the subset and the second value is the end of the subset (end not included).
    The range can either be an absolute range with sample indices, or a ratio of the dataset size.

    Relative range example: [25%, 75%]. This would limit the subset to the middle 50% of the dataset.
    Absolute range example: [100, 200]. This would limit the subset to the 100 samples with indices 100-199.
    For absolute ranges, the end can be set to "end" to indicate the end of the dataset, for example [100, end].

    Since subsets can be specified at multiple levels of a hierarchy, for example in a blend,
    their effects can be merged to a single subset.
    Note however, that absolute ranges are only allowed for leaf datasets, while relative ranges
    can be applied at any level.
    """

    # The (start, end) pair: ints for absolute indices, "N%" strings for
    # relative ratios, or "end" for the end of the dataset.
    range: tuple[str | int, str | int]

    def as_dataset_subset(self) -> DatasetSubset:
        """Convert the subset with string values to a DatasetSubset object with `range` and `absolute_range`."""
        start, end = self.range

        def _conv(value: str | int) -> float | int | None:
            # int -> absolute index; "end" -> None; "N%" -> ratio in [0, 1].
            if isinstance(value, int):
                return value
            else:
                assert isinstance(value, str), "Range must be a string if it's not an integer"
                if value.strip() == "end":
                    return None
                assert value.endswith("%"), "Range must be a percentage"
                percentage = float(value.removesuffix("%"))
                assert 0 <= percentage <= 100, "Percentage must be between 0 and 100"
                return percentage / 100.0

        start = _conv(start)
        end = _conv(end)
        if isinstance(start, int):
            assert isinstance(end, int) or end is None, (
                "End must be an integer if start is an integer"
            )
            return DatasetSubset(absolute_range=(start, end), range=(0, 1))
        else:
            assert isinstance(start, float), "Range start must be a float if it's not an integer"
            assert isinstance(end, float) or end is None, "End must be a float if start is a float"
            # BUGFIX: `end` may be None here (spelled as "end"); the previous
            # unconditional comparisons `0 <= end <= 1` and `start <= end`
            # raised TypeError for that case. A relative range ending at "end"
            # is equivalent to ratio 1.0, so normalize it.
            if end is None:
                end = 1.0
            assert 0 <= start <= 1, "Start must be between 0 and 1"
            assert 0 <= end <= 1, "End must be between 0 and 1"
            assert start <= end, "Start must be less than end"
            return DatasetSubset(range=(start, end), absolute_range=None)

    def merge(self, parent_subset: DatasetSubset | None) -> DatasetSubset:
        """Merge this subset with a parent subset.

        If the parent subset is None, return the subset.
        If the parent subset is an absolute range, fail, because that's not allowed.
        If the parent subset is a ratio, merge it with the subset.

        Merging a child absolute range with a parent relative range:
        In this case, both are kept in the DatasetSubset object and applies in "absolute first" order later.

        Merging a child relative range with a parent relative range:
        In this case, the relative parent range is applied to the child's relative range.
        The absolute range is not affected.

        For details on how this is applied, see `DatasetSubset.compute_subset`.
        """
        # BUGFIX: the assertion message previously referenced
        # `self.absolute_range`/`self.range`, but `Subset` has no
        # `absolute_range` attribute — a failing assert would have raised
        # AttributeError instead of the intended message. Reference the
        # parent subset, which is what is actually being checked.
        assert parent_subset is None or parent_subset.absolute_range is None, (
            f"Cannot merge absolute subset ranges. Absolute ranges are only allowed for a leaf dataset. {parent_subset.absolute_range=}"
        )
        my_subset = self.as_dataset_subset()
        if parent_subset is None or parent_subset.range is None:
            return my_subset
        # Assuming inner ratio: [0.25, 0.75] and outer ratio: [0, 0.5]
        # Then the total ratio is supposed to be: [0.25 + 0*0.5, 0.25 + 0.5 * 0.5] = [0.25, 0.5]
        total = my_subset.range[1] - my_subset.range[0]
        return DatasetSubset(
            range=(
                my_subset.range[0] + parent_subset.range[0] * total,
                my_subset.range[0] + parent_subset.range[1] * total,
            ),
            absolute_range=my_subset.absolute_range,
        )
@edataclass
class SubsetRatioMixin:
    """Mixin adding an optional subset and the logic to merge it with a parent subset."""

    # Subset configured at this level, if any.
    subset: Optional[Subset] = None

    def _get_subset(self, parent_subset: Optional[DatasetSubset]) -> Optional[DatasetSubset]:
        # No parent: either propagate our own subset (converted) or nothing.
        if parent_subset is None:
            if self.subset is None:
                return None
            return self.subset.merge(None)
        # Parents may only carry relative ranges; absolute ranges are leaf-only.
        assert parent_subset.absolute_range is None, (
            f"Can only use absolute subset ranges for a leaf dataset (Range {parent_subset.absolute_range=})"
        )
        # No local subset: the parent's applies unchanged.
        if self.subset is None:
            return parent_subset
        return self.subset.merge(parent_subset)
@edataclass
class DatasetReference(SubsetRatioMixin, DatasetLoaderInterface):
    """Reference to a dataset by path, resolving to either a (V1) metadataset, a webdataset,
    or a jsonl dataset at `post_initialize` time.
    """

    # Path to the referenced dataset. Relative paths are resolved against the metadataset's
    # parent directory in `post_initialize`.
    path: Union[str, EPath]
    # If set, overrides the split part requested by the caller in `get_datasets`.
    split_part: Optional[str] = None
    # Subflavors merged over (i.e. taking precedence within) the caller-provided subflavors.
    subflavors: Optional[Dict[str, Any]] = None
    # Multiplier for shuffling over epochs; combined multiplicatively with the caller's value.
    # None disables shuffling; -1 means sampling without replacement.
    shuffle_over_epochs_multiplier: Optional[int] = 1
    # Optional dataset config file name; only valid for non-metadataset references.
    dataset_config: Optional[str] = None
    # Optional split config file name; only valid for non-metadataset references.
    split_config: Optional[str] = None

    #: Auxiliary datasets. May only be specified for crude datasets for cooking. Cooking will get
    # these references to load data from. If specified as string, it will be interpreted as a
    # dataset path.
    aux: Optional[Dict[str, str]] = None

    # The resolved inner loader; set by `post_initialize`.
    _dataset: Optional[DatasetLoaderInterface] = None

    def post_initialize(self, mds_path: Optional[EPath] = None) -> None:
        """Resolve `path` relative to the metadataset location and construct the inner loader.

        Args:
            mds_path: Path to the referencing metadataset file; must not be None.

        Raises:
            FileNotFoundError: If the path is not a recognized dataset type.
            ValueError: If an aux reference uses an unsupported URL protocol.
        """
        assert mds_path is not None
        if not isinstance(self.path, EPath):
            # Resolve relative string paths against the metadataset's directory.
            self.path = mds_path.parent / self.path
        ds_type = get_dataset_type(self.path)
        if ds_type == EnergonDatasetType.METADATASET:
            assert self.aux is None, "Cannot specify auxiliary datasets for crude datasets"
            assert self.dataset_config is None, "Must not set dataset_config"
            assert self.split_config is None, "Must not set split_config"
            # Note: For backwards compatibility, the type must be Metadataset (V1).
            self._dataset = load_config(
                self.path,
                default_type=Metadataset,
                default_kwargs=dict(path=self.path),
            )
            self._dataset.post_initialize()
        elif ds_type in (EnergonDatasetType.WEBDATASET, EnergonDatasetType.JSONL):
            self._dataset = DatasetLoader(
                path=self.path,
                split_config=self.split_config,
                dataset_config=self.dataset_config,
            )
            self._dataset.post_initialize()
            if self.aux is not None:
                # Convert the plain string aux entries into typed reference objects:
                # "filesystem://" URLs become filesystem references, everything else is
                # treated as a dataset path.
                new_aux = {}
                for k, v in self.aux.items():
                    if m := url_regex.match(v):
                        if m.group("protocol") == "filesystem":
                            new_aux[k] = AuxFilesystemReference(fs_path=m.group("path"))
                        else:
                            raise ValueError(f"Unsupported protocol: {m.group('protocol')}")
                    else:
                        new_aux[k] = AuxDatasetReference(path=v)
                    new_aux[k].post_initialize(mds_path)
                self.aux = new_aux
        else:
            raise FileNotFoundError(self.path)

    def prepare(self, split_part: Optional[str] = None) -> Sequence[EPath]:
        """Delegate preparation to the resolved inner dataset loader."""
        assert self._dataset is not None
        return self._dataset.prepare(split_part=split_part)

    def get_datasets(
        self,
        *,
        training: bool,
        split_part: Union[Literal["train", "val", "test"], str],
        worker_config: WorkerConfig,
        subflavors: Optional[Dict[str, Any]] = None,
        shuffle_over_epochs_multiplier: Optional[int] = 1,
        subset: Optional[DatasetSubset] = None,
        **kwargs,
    ) -> LoadedDatasetList:
        """Load the referenced datasets, merging subflavors, shuffle multipliers, and subsets
        of this reference with the caller-provided values, then attach aux file stores to
        every loaded dataset.
        """
        if self.subflavors is not None:
            # Own subflavors take precedence as the base; caller values override entry-wise.
            subflavors = {**self.subflavors, **(subflavors or {})}
        assert self._dataset is not None
        if shuffle_over_epochs_multiplier is None or self.shuffle_over_epochs_multiplier is None:
            # If no shuffling is requested, this has override priority.
            new_shuffle_over_epochs_multiplier = None
        elif shuffle_over_epochs_multiplier == -1 or self.shuffle_over_epochs_multiplier == -1:
            # Next priority is sampling without replacement.
            new_shuffle_over_epochs_multiplier = -1
        else:
            # Otherwise, multiply the shuffle over epochs multiplier.
            new_shuffle_over_epochs_multiplier = (
                shuffle_over_epochs_multiplier * self.shuffle_over_epochs_multiplier
            )
        subset = self._get_subset(subset)
        result = self._dataset.get_datasets(
            training=training,
            split_part=self.split_part or split_part,
            worker_config=worker_config,
            subflavors=subflavors,
            shuffle_over_epochs_multiplier=new_shuffle_over_epochs_multiplier,
            subset=subset,
            **kwargs,
        )
        if self.aux is not None:
            # Materialize the aux references into file stores and attach them to each
            # loaded dataset, merging with any aux the inner loader already set.
            aux = {k: v.get_file_store() for k, v in self.aux.items()}
            for loaded_dataset in result.datasets:
                if loaded_dataset.aux is None:
                    loaded_dataset.aux = aux
                else:
                    loaded_dataset.aux.update(aux)
        return result
@edataclass
class JoinDatasetReference(DatasetReference):
    """A dataset reference used as a member of a join. Unlike `DatasetReference`, its
    `post_initialize` returns the constructed loader instead of storing it, and the
    direct `prepare`/`get_datasets` entry points are disabled — the parent
    `MetadatasetJoin` drives loading.
    """

    # Behavior when a sample key has no match in this join member:
    # "skip" the sample, use "none" for the missing part, or raise an "error".
    nonmatch: Literal["skip", "none", "error"] = "error"

    def post_initialize(self, mds_path: Optional[EPath] = None) -> DatasetLoader:
        """Resolve the path and return a `DatasetLoader` for the joinable dataset.

        Args:
            mds_path: Path to the referencing metadataset file; must not be None.

        Returns:
            The constructed loader (not stored on self).

        Raises:
            ValueError: If the path is not a webdataset (the only joinable type).
        """
        assert mds_path is not None
        # Override and disable another metadataset reference, only allow direct dataset references.
        # Do not store the loader, the parent MetadatasetJoin will do that.
        if not isinstance(self.path, EPath):
            self.path = mds_path.parent / self.path
        ds_type = get_dataset_type(self.path)
        if ds_type == EnergonDatasetType.WEBDATASET:
            return DatasetLoader(
                path=self.path,
                split_part=self.split_part,
                subflavors=self.subflavors,
                shuffle_over_epochs_multiplier=self.shuffle_over_epochs_multiplier,
                dataset_config=self.dataset_config,
                split_config=self.split_config,
            )
        else:
            # Fixed message typo ("joinabledataset" -> "joinable dataset").
            raise ValueError(f"Not a joinable dataset at {self.path}")

    def prepare(self, split_part: Optional[str] = None):
        # Explicit raise instead of `assert False`: asserts are stripped under `python -O`,
        # which would make this guard silently return None. Exception type is unchanged.
        raise AssertionError(
            "JoinDatasetReference should not be used directly, but only by MetadatasetJoin"
        )

    def get_datasets(
        self,
        **kwargs,
    ) -> LoadedDatasetList:
        # Explicit raise instead of `assert False` (see `prepare`).
        raise AssertionError(
            "JoinDatasetReference should not be used directly, but only by MetadatasetJoin"
        )
@edataclass
class MetadatasetJoin(SubsetRatioMixin, DatasetLoaderInterface):
    """Joins multiple datasets sample-wise using a joiner, delegating to `JoinDatasetLoader`."""

    # The join members, either positional (list) or named (dict of key -> reference).
    join: Union[List[JoinDatasetReference], Dict[str, JoinDatasetReference]]
    # The sample type or callable used to combine the joined parts into one sample.
    joiner: Union[Type[Sample], Callable[..., Sample]]
    # If set, overrides/forwards the split part to the inner join loader.
    split_part: Optional[str] = None
    # Subflavors forwarded to the inner join loader.
    subflavors: Optional[Dict[str, Any]] = None
    # Shuffle-over-epochs multiplier forwarded to the inner join loader.
    shuffle_over_epochs_multiplier: Optional[int] = 1
    # Must stay None: a join cannot carry its own dataset config (asserted below).
    dataset_config: Optional[str] = None
    # Must stay None: a join cannot carry its own split config (asserted below).
    split_config: Optional[str] = None
    # The constructed join loader; set by `post_initialize`.
    _dataset: Optional[JoinDatasetLoader] = None

    def post_initialize(self, mds_path: Optional[EPath] = None):
        """Validate the configuration and build the inner `JoinDatasetLoader`.

        Args:
            mds_path: Path to the referencing metadataset file; must not be None.

        Raises:
            ValueError: If `join` is neither a list nor a dict.
        """
        assert mds_path is not None
        assert self.join is not None
        assert self.joiner is not None, "Must set joiner for joining datasets"
        assert self.dataset_config is None, "Cannot set dataset_config for joining datasets"
        assert self.split_config is None, "Cannot set split_config for joining datasets"
        if isinstance(self.join, list):
            # Each member's post_initialize returns its loader (see JoinDatasetReference).
            inner_loaders = [
                JoinedDatasetInfo(
                    dataset=join.post_initialize(mds_path),
                    nonmatch=join.nonmatch,
                )
                for join in self.join
            ]
        elif isinstance(self.join, dict):
            inner_loaders = {
                key: JoinedDatasetInfo(
                    dataset=join.post_initialize(mds_path),
                    nonmatch=join.nonmatch,
                )
                for key, join in self.join.items()
            }
        else:
            raise ValueError("Invalid join type")
        self._dataset = JoinDatasetLoader(
            datasets=inner_loaders,
            joiner=self.joiner,
            split_part=self.split_part,
            subflavors=self.subflavors,
            shuffle_over_epochs_multiplier=self.shuffle_over_epochs_multiplier,
            # NOTE(review): always None here due to the assert above — confirm intended.
            split_config=self.split_config,
        )
        self._dataset.post_initialize(mds_path)

    def prepare(self, split_part: Optional[str] = None) -> Sequence[EPath]:
        """Delegate preparation to the inner join loader."""
        assert self._dataset is not None, "Missing post_initialize call."
        return self._dataset.prepare(split_part=split_part)

    def get_datasets(
        self,
        *,
        training: bool,
        split_part: Union[Literal["train", "val", "test"], str],
        worker_config: WorkerConfig,
        subflavors: Optional[Dict[str, Any]] = None,
        shuffle_over_epochs_multiplier: Optional[int] = 1,
        subset: Optional[DatasetSubset] = None,
        **kwargs,
    ) -> LoadedDatasetList:
        """Load the joined datasets, applying the merged subset restriction."""
        assert self._dataset is not None, "Missing post_initialize call."
        subset = self._get_subset(subset)
        return self._dataset.get_datasets(
            training=training,
            split_part=split_part,
            worker_config=worker_config,
            subflavors=subflavors,
            shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier,
            subset=subset,
            **kwargs,
        )
@dataclass
class BlendWeightMixin:
    """Mixin adding a sampling weight for blended datasets."""

    # Relative sampling weight; normalized against the sum of all weights in the blend.
    weight: float = 1.0
@edataclass
class BlendDatasetReference(BlendWeightMixin, DatasetReference):
    """A dataset reference carrying a blend weight, for use inside `MetadatasetBlend`."""

    pass
@edataclass
class BlendJoinDatasetReference(BlendWeightMixin, MetadatasetJoin):
    """A joined dataset carrying a blend weight, for use inside `MetadatasetBlend`."""

    pass
@edataclass
# NOTE(review): base-class order differs from MetadatasetJoin/MetadatasetBlendEpochized
# (DatasetLoaderInterface listed first). Changing it would alter the dataclass MRO/field
# order, so it is left as-is — confirm intended.
class MetadatasetBlend(DatasetLoaderInterface, SubsetRatioMixin):
    """Blending of datasets by specifying the sampling weight for the inner datasets."""

    # The datasets to blend, each carrying a `weight`.
    blend: List[Union[BlendDatasetReference, BlendJoinDatasetReference]]

    def post_initialize(self, mds_path: Optional[EPath] = None):
        """Forward post-initialization to all blended datasets."""
        assert mds_path is not None
        for dataset in self.blend:
            dataset.post_initialize(mds_path)

    def prepare(self, split_part: Optional[str] = None) -> Sequence[EPath]:
        """Prepare all blended datasets and return the concatenated file lists."""
        files = []
        for dataset in self.blend:
            files.extend(dataset.prepare(split_part=split_part))
        return files

    def get_datasets(
        self,
        *,
        training: bool,
        split_part: Union[Literal["train", "val", "test"], str],
        worker_config: WorkerConfig,
        subflavors: Optional[Dict[str, Any]] = None,
        shuffle_over_epochs_multiplier: Optional[int] = 1,
        subset: Optional[DatasetSubset] = None,
        **kwargs,
    ) -> LoadedDatasetList:
        """Load all blended datasets, scaling each loaded dataset's weight by this blend's
        normalized weights, and return them in DATASET_WEIGHT blend mode.

        Raises:
            ValueError: If an inner dataset uses the incompatible SAMPLE_REPETITIONS mode.
        """
        subset = self._get_subset(subset)
        # Weights are normalized over the sum so the final weights form a distribution.
        sum_weight = sum(dataset.weight for dataset in self.blend)
        datasets = []
        for dataset in self.blend:
            inner_result = dataset.get_datasets(
                training=training,
                split_part=split_part,
                worker_config=worker_config,
                subflavors=subflavors,
                shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier,
                subset=subset,
                **kwargs,
            )
            if inner_result.blend_mode not in (
                DatasetBlendMode.NONE,
                DatasetBlendMode.DATASET_WEIGHT,
            ):
                raise ValueError(
                    "Can only blend datasets which are of the same blend mode. Cannot mix blend with blend_epochized."
                )
            for loaded_dataset in inner_result.datasets:
                if inner_result.blend_mode == DatasetBlendMode.DATASET_WEIGHT:
                    assert isinstance(loaded_dataset.weight, float)
                else:
                    # Unweighted inner datasets get a neutral weight of 1.0 before scaling.
                    assert inner_result.blend_mode == DatasetBlendMode.NONE
                    assert loaded_dataset.weight is None
                    assert loaded_dataset.repetitions is None
                    loaded_dataset.weight = 1.0
                # Scale by this blend entry's normalized weight.
                loaded_dataset.weight = loaded_dataset.weight * dataset.weight / sum_weight
                datasets.append(loaded_dataset)
        return LoadedDatasetList(
            blend_mode=DatasetBlendMode.DATASET_WEIGHT,
            datasets=datasets,
        )
@dataclass
class BlendRepetitionsMixin:
    """Mixin adding a repetition count for epochized blending."""

    # Number of times samples of this dataset are repeated within one blend "epoch".
    repetitions: Union[int, float] = 1
@edataclass
class BlendEpochizedDatasetReference(BlendRepetitionsMixin, DatasetReference):
    """A dataset reference carrying a repetition count, for `MetadatasetBlendEpochized`."""

    pass
@edataclass
class BlendEpochizedJoinDatasetReference(BlendRepetitionsMixin, MetadatasetJoin):
    """A joined dataset carrying a repetition count, for `MetadatasetBlendEpochized`."""

    pass
@edataclass
class MetadatasetBlendEpochized(SubsetRatioMixin, DatasetLoaderInterface):
    """Blending of datasets, by specifying the number of repetitions for samples from the inner
    datasets. Ensures that the constraint, that samples are seen exactly this many times before
    repeating the "epoch" (i.e. one epoch contains the total number of repetitions for each inner
    dataset)."""

    # The datasets to blend, each carrying a `repetitions` count.
    blend_epochized: List[Union[BlendEpochizedDatasetReference, BlendEpochizedJoinDatasetReference]]

    def post_initialize(self, mds_path: Optional[EPath] = None):
        """Forward post-initialization to all blended datasets."""
        assert mds_path is not None
        for dataset in self.blend_epochized:
            dataset.post_initialize(mds_path)

    def prepare(self, split_part: Optional[str] = None) -> Sequence[EPath]:
        """Prepare all blended datasets and return the concatenated file lists."""
        files = []
        for dataset in self.blend_epochized:
            files.extend(dataset.prepare(split_part=split_part))
        return files

    def get_datasets(
        self,
        *,
        training: bool,
        split_part: Union[Literal["train", "val", "test"], str],
        worker_config: WorkerConfig,
        subflavors: Optional[Dict[str, Any]] = None,
        shuffle_over_epochs_multiplier: Optional[int] = 1,
        subset: Optional[DatasetSubset] = None,
        **kwargs,
    ) -> LoadedDatasetList:
        """Load all blended datasets, multiplying each loaded dataset's repetitions by this
        blend entry's repetitions, and return them in SAMPLE_REPETITIONS blend mode.

        Raises:
            ValueError: If an inner dataset uses the incompatible DATASET_WEIGHT mode.
        """
        subset = self._get_subset(subset)
        datasets = []
        for dataset in self.blend_epochized:
            inner_result = dataset.get_datasets(
                training=training,
                split_part=split_part,
                worker_config=worker_config,
                subflavors=subflavors,
                shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier,
                subset=subset,
                **kwargs,
            )
            if inner_result.blend_mode not in (
                DatasetBlendMode.NONE,
                DatasetBlendMode.SAMPLE_REPETITIONS,
            ):
                raise ValueError(
                    "Can only blend datasets which are of the same blend mode. Cannot mix blend with blend_epochized."
                )
            for loaded_dataset in inner_result.datasets:
                if inner_result.blend_mode == DatasetBlendMode.SAMPLE_REPETITIONS:
                    assert isinstance(loaded_dataset.repetitions, (int, float))
                else:
                    # Added for consistency with MetadatasetBlend: after the membership
                    # check above, the only other possible mode here is NONE.
                    assert inner_result.blend_mode == DatasetBlendMode.NONE
                    assert loaded_dataset.weight is None
                    assert loaded_dataset.repetitions is None
                    # Unrepeated inner datasets get a neutral repetition count of 1.
                    loaded_dataset.repetitions = 1
                # Scale by this blend entry's repetition count.
                loaded_dataset.repetitions = dataset.repetitions * loaded_dataset.repetitions
                datasets.append(loaded_dataset)
        return LoadedDatasetList(
            blend_mode=DatasetBlendMode.SAMPLE_REPETITIONS,
            datasets=datasets,
        )
@edataclass
class MetadatasetV2(DatasetLoaderInterface):
    """Top-level metadataset (V2): maps split parts to their loader configurations."""

    # Path to this metadataset file; used to resolve relative inner references and as
    # the base for the prepare cache directory.
    path: EPath
    # Mapping from split part name (e.g. "train") to its loader configuration.
    splits: Dict[
        str, Union[MetadatasetBlend, MetadatasetBlendEpochized, MetadatasetJoin, DatasetReference]
    ]

    def post_initialize(self, mds_path: Optional[EPath] = None):
        """Initialize all splits, passing the own path as the reference base.

        Args:
            mds_path: Must be None — this is the top-level entry point.
        """
        assert mds_path is None
        for split in self.splits.values():
            split.post_initialize(self.path)

    def prepare(self, split_part: Optional[str] = None) -> Sequence[EPath]:
        """Prepare the given split part (or all splits) and clean up stale cache files.

        Returns:
            An empty tuple — this is a barrier; inner cache paths are not propagated.
        """
        # In the case of prepare for MetadatasetV2, we ignore the passed cache_path
        # and instead use the own path.
        # If someone runs energon prepare on a metadataset that refers to another metadataset,
        # any actions concerning the inner metadataset will be done on the inner metadataset's path.
        if split_part is None:
            files = []
            # Fixed: the loop variable previously shadowed the `split_part` parameter.
            for part_name, split in self.splits.items():
                files.extend(split.prepare(split_part=part_name))
        else:
            files = self.splits[split_part].prepare(split_part=split_part)
        # Cleanup paths here: anything in the cache directory not produced by prepare is stale.
        cache_path = EPath(self.path.parent / f"{self.path.name}.cache")
        remove_files = set(cache_path.glob("*")) - set(files)
        for file in remove_files:
            print(f"Cleanup cache file {file}")
            file.unlink()
        # Do not return paths. This is a barrier
        return ()

    def get_datasets(
        self,
        *,
        training: bool,
        split_part: Union[Literal["train", "val", "test"], str],
        worker_config: WorkerConfig,
        subflavors: Optional[Dict[str, Any]] = None,
        shuffle_over_epochs_multiplier: Optional[int] = 1,
        subset: Optional[DatasetSubset] = None,
        **kwargs,
    ) -> LoadedDatasetList:
        """Load the datasets of the requested split part, forwarding all options.

        Raises:
            KeyError: If `split_part` is not a configured split.
        """
        return self.splits[split_part].get_datasets(
            training=training,
            split_part=split_part,
            worker_config=worker_config,
            subflavors=subflavors,
            shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier,
            subset=subset,
            **kwargs,
        )
Prev
1
…
3
4
5
6
7
8
9
10
11
…
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment