Bw-bestperf / Qwen2.5-14B-Instruct_dcu-megatron

Commit f356f546, authored Feb 04, 2026 by maming

Initial commit

Pipeline #3339 canceled with stages · Changes: 346 · Pipelines: 1

Showing 20 changed files with 2142 additions and 0 deletions (+2142 −0)
Megatron-Energon/src/megatron/energon/cache/no_cache.py (+54 −0)
Megatron-Energon/src/megatron/energon/cli/__init__.py (+2 −0)
Megatron-Energon/src/megatron/energon/cli/main.py (+62 −0)
Megatron-Energon/src/megatron/energon/dataset_config.py (+142 −0)
Megatron-Energon/src/megatron/energon/edataclass.py (+20 −0)
Megatron-Energon/src/megatron/energon/epathlib/__init__.py (+6 −0)
Megatron-Energon/src/megatron/energon/epathlib/epath.py (+253 −0)
Megatron-Energon/src/megatron/energon/epathlib/rclone_config.py (+87 −0)
Megatron-Energon/src/megatron/energon/errors.py (+128 −0)
Megatron-Energon/src/megatron/energon/eyaml.py (+88 −0)
Megatron-Energon/src/megatron/energon/flavors/__init__.py (+89 −0)
Megatron-Energon/src/megatron/energon/flavors/base_dataset.py (+485 −0)
Megatron-Energon/src/megatron/energon/flavors/captioning.py (+38 −0)
Megatron-Energon/src/megatron/energon/flavors/crude.py (+51 −0)
Megatron-Energon/src/megatron/energon/flavors/image.py (+35 −0)
Megatron-Energon/src/megatron/energon/flavors/image_classification.py (+40 −0)
Megatron-Energon/src/megatron/energon/flavors/interleaved.py (+36 −0)
Megatron-Energon/src/megatron/energon/flavors/jsonl/__init__.py (+11 −0)
Megatron-Energon/src/megatron/energon/flavors/jsonl/crude_jsonl_dataset.py (+219 −0)
Megatron-Energon/src/megatron/energon/flavors/jsonl/ijsonl.py (+296 −0)
Megatron-Energon/src/megatron/energon/cache/no_cache.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Any, Optional, TypeVar

from megatron.energon.cache.base import CachePool, FileStore, Lazy, MockLazy
from megatron.energon.edataclass import edataclass
from megatron.energon.source_info import SourceInfo, add_source_info

T = TypeVar("T")


@edataclass
class DirectLazy(Lazy[T]):
    """
    This is not really lazy, it will just defer the dataset access to the first get().
    """

    _data: Optional[tuple[T, SourceInfo]] = None

    def get(self, sample: Any = None) -> T:
        """Get the lazy data now and add the source info to the sample."""
        if self._data is None:
            self._data = self.ds[self.fname]
        assert self._data is not None
        add_source_info(sample, self._data[1])
        return self._data[0]

    def __hash__(self) -> int:
        """Allows usage in sets and dicts as key."""
        return hash((id(self.ds), self.fname))

    def __eq__(self, other: Any) -> bool:
        """Allows usage in sets and dicts as key. Compares the data source and the filename."""
        if not isinstance(other, Lazy):
            return False
        return self.ds is other.ds and self.fname == other.fname


class NoCachePool(CachePool):
    """A pass-through cache pool that does not cache anything."""

    def get(self, ds: FileStore, fname: str, sample: Any = None) -> Any:
        """Get the data for a given file and add the source info to the sample."""
        return ds.get(fname, sample)

    def get_lazy(self, ds: FileStore, fname: str) -> DirectLazy:
        return DirectLazy(ds=ds, fname=fname, pool=self)

    def to_cache(self, data: T, name: str) -> DirectLazy:
        return MockLazy(fname=name, get_fn=lambda _: data)

    def close(self) -> None:
        pass
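
For orientation, a minimal usage sketch of the pass-through pool (not part of this commit). It assumes a FileStore whose __getitem__ returns a (data, SourceInfo) tuple, which is what DirectLazy.get() expects; my_store and sample are placeholders.

pool = NoCachePool()
lazy = pool.get_lazy(my_store, "sample_000.json")  # no I/O yet, just a DirectLazy handle
data = lazy.get(sample)        # first call reads from my_store and memoizes (data, source_info)
data_again = lazy.get(sample)  # served from the memoized _data tuple, source info added again
pool.close()                   # no-op for this pool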
Megatron-Energon/src/megatron/energon/cli/__init__.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
Megatron-Energon/src/megatron/energon/cli/main.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

"""Main entry point for energon CLI tools."""

import click

from megatron.energon.tools.analyze_debug import command as analyze_debug_command
from megatron.energon.tools.checkpoint import command as checkpoint_command
from megatron.energon.tools.info import command as info_command
from megatron.energon.tools.lint import command as lint_command
from megatron.energon.tools.mount import command as mount_command
from megatron.energon.tools.prepare import command as prepare_command
from megatron.energon.tools.preview import command as preview_command


def print_version(ctx, param, value):
    """Print version and exit."""
    if not value or ctx.resilient_parsing:
        return
    try:
        from megatron.energon import __version__

        click.echo(f"megatron-energon {__version__}")
    except ImportError:
        click.echo("megatron-energon version unknown")
    ctx.exit()


@click.group(
    context_settings=dict(help_option_names=["-h", "--help"]),
    invoke_without_command=True,
)
@click.option(
    "--version",
    "-V",
    is_flag=True,
    expose_value=False,
    is_eager=True,
    callback=print_version,
    help="Show version and exit.",
)
@click.pass_context
def main(ctx):
    """A set of tools that energon provides.

    Among other things, you can use it to lint or preprocess your dataset.
    See help of commands to learn more."""
    # This is needed to show help if no subcommand is provided
    if ctx.invoked_subcommand is None:
        click.echo(main.get_help(ctx))


main.add_command(analyze_debug_command)
main.add_command(checkpoint_command)
main.add_command(lint_command)
main.add_command(info_command)
main.add_command(mount_command)
main.add_command(prepare_command)
main.add_command(preview_command)

if __name__ == "__main__":
    main()
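
Since the group is plain click, it can be exercised in-process with click's own test runner; a small sketch (the subcommands come from the megatron.energon.tools modules imported above):

from click.testing import CliRunner

runner = CliRunner()
result = runner.invoke(main, ["--version"])  # eager flag triggers print_version and exits
print(result.output)                         # e.g. "megatron-energon <version>"
result = runner.invoke(main, [])             # no subcommand: the group echoes its help text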
Megatron-Energon/src/megatron/energon/dataset_config.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from pathlib import Path
from typing import Any, Dict, Optional, Type, TypeVar, Union

from megatron.energon.epathlib import EPath
from megatron.energon.eyaml import load_yaml
from megatron.energon.flavors import (
    BaseCoreDatasetFactory,
    CrudeSample,
    DefaultCrudeJsonlDatasetFactory,
    StandardWebdatasetFactory,
)
from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME
from megatron.energon.flavors.webdataset.metadata import EnergonDatasetType, get_dataset_type
from megatron.energon.typed_converter import JsonParser
from megatron.energon.worker import WorkerConfig

T = TypeVar("T")


def load_config(
    path: Union[EPath, Dict[str, Any]],
    *,
    default_type: Type[T],
    default_kwargs: Optional[Dict[str, Any]] = None,
    parser: JsonParser = JsonParser(strict=True),
) -> T:
    """
    Loads a config from a file or directly from a dictionary.

    Args:
        path: Path to the config to load or a dictionary containing the config.
        default_type: If set, this is the type to use if no type is specified in the config.
        default_kwargs: Default kwargs to use, will be overridden by the config.

    Returns:
        The instantiated type.
    """
    if isinstance(path, dict):
        data = path
    else:
        # Read the config from a file
        with path.open("rb") as f:
            data: dict = load_yaml(f)
    if default_kwargs is not None:
        new_data = default_kwargs.copy()
        new_data.update(data)
        data = new_data
    return parser.raw_to_instance(data, default_type)


T_sample = TypeVar("T_sample", covariant=True)


def get_dataset_from_config(
    path: Union[EPath, Path, str],
    *,
    dataset_config: str | None = None,
    split_config: str | None = None,
    split_part: str | None = None,
    training: bool = True,
    subflavors: Optional[Dict[str, Any]] = None,
    worker_config: WorkerConfig,
    sample_type: Optional[Type[T_sample]] = None,
    **kwargs,
) -> BaseCoreDatasetFactory[T_sample]:
    """
    Gets a dataset from a config path or a path to a jsonl file.

    Args:
        path: Path to the folder where the `.nv-meta` folder is contained, or path to a jsonl file.
        dataset_config: Filename of the dataset config file (`path / '.nv-meta' / config`), or None for jsonl datasets.
        split_config: Filename of the split config file (`path / '.nv-meta' / split_config`), or None for jsonl datasets.
        split_part: Name of the split to load, or None for jsonl datasets.
        training: If true, apply training randomization and loop the dataset.
        subflavors: Merge-override the __subflavors__ property of each sample.
        worker_config: If set, use this worker config instead of the default one.
        sample_type: Type of the samples to load, only used to ensure typing.
        **kwargs: Additional arguments to be passed to the dataset constructor.

    Returns:
        The instantiated dataset.
    """
    path = EPath(path)

    dataset: BaseCoreDatasetFactory[T_sample]

    ds_type = get_dataset_type(path)

    if ds_type == EnergonDatasetType.JSONL:
        assert sample_type is CrudeSample or sample_type is None, (
            f"Sample type must be CrudeSample for jsonl datasets, but got {sample_type}"
        )
        assert dataset_config is None, (
            f"Dataset config must be None for jsonl datasets, but got {dataset_config}"
        )
        assert split_config is None, (
            f"Split config must be None for jsonl datasets, but got {split_config}"
        )
        # Note: We ignore split_part for jsonl datasets and always return the full dataset.
        dataset = DefaultCrudeJsonlDatasetFactory(
            path,
            training=training,
            subflavors=subflavors,
            worker_config=worker_config,
            **kwargs,
        )
    elif ds_type == EnergonDatasetType.WEBDATASET:
        if dataset_config is None:
            dataset_config = "dataset.yaml"
        if split_config is None:
            split_config = "split.yaml"
        if split_part is None:
            split_part = "train"
        dataset = load_config(
            path / MAIN_FOLDER_NAME / dataset_config,
            default_kwargs=dict(
                path=path,
                split_config=split_config,
                split_part=split_part,
                training=training,
                worker_config=worker_config,
                **kwargs,
            ),
            default_type=StandardWebdatasetFactory,
        )
    else:
        raise ValueError(
            f"Path {path} does not contain a {MAIN_FOLDER_NAME}/.info.yaml or .info.json file nor is it a jsonl file. "
            f"Did you forget to prepare the dataset? Please check the documentation for an introduction to dataset "
            f"preparation."
        )

    if subflavors is not None:
        dataset.subflavors.update(subflavors)

    if sample_type is not None:
        assert issubclass(dataset.__sample_type__, sample_type), (
            f"Sample of type {dataset.__sample_type__} is not a subclass of {sample_type}."
        )
    return dataset
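
A hedged sketch of calling get_dataset_from_config; the dataset path is a placeholder and the WorkerConfig arguments are assumed, not taken from this commit:

from megatron.energon.worker import WorkerConfig

worker_config = WorkerConfig(rank=0, world_size=1, num_workers=2)  # assumed constructor arguments
factory = get_dataset_from_config(
    "/data/my_webdataset",  # hypothetical folder containing .nv-meta/
    split_part="train",
    training=True,
    worker_config=worker_config,
)
# factory is a BaseCoreDatasetFactory; factory.build() constructs the savable dataset.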
Megatron-Energon/src/megatron/energon/edataclass.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from dataclasses import dataclass

from typing_extensions import dataclass_transform


# We define an alias for `@dataclass(slots=True, kw_only=True)`,
# because we want to use this combination almost everywhere.
@dataclass_transform(kw_only_default=True, slots_default=True)
def edataclass(cls):
    """
    A dataclass transform that sets the kw_only and slots defaults to True.
    This is equivalent to `@dataclass(slots=True, kw_only=True)`.

    If you need more options, use `dataclass` directly.
    E.g.: `@dataclass(slots=True, kw_only=True, eq=False)`.
    """
    return dataclass(kw_only=True, slots=True)(cls)
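
The alias in action (a small self-contained example): fields become keyword-only and the class gets __slots__, exactly as @dataclass(slots=True, kw_only=True) would:

@edataclass
class Point:
    x: int
    y: int = 0

p = Point(x=1, y=2)  # keyword-only construction
# Point(1, 2)        # would raise TypeError: fields are keyword-only
# p.z = 3            # would raise AttributeError: __slots__ forbids new attributes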
Megatron-Energon/src/megatron/energon/epathlib/__init__.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from megatron.energon.epathlib.epath import EPath

__all__ = ["EPath"]
Megatron-Energon/src/megatron/energon/epathlib/epath.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import re
import shutil
from pathlib import Path as PathlibPath
from pathlib import PurePosixPath
from typing import BinaryIO, Generator, Literal, Optional, TextIO, Tuple, Union, overload

import multistorageclient as msc

from megatron.energon.errors import warn_deprecated


class EPath:
    """EPath works much like pathlib.Path, but supports local as well as remote file systems.

    To initialize a remote path, prepend the path with "msc://" to use the Multi-Storage Client (MSC).
    For example:

        EPath("msc://profilename/my_datasets/webdataset-000.tar")

    You will need to have your MSC configuration (~/.msc_config.yaml) set up to access the object
    stores, or use your rclone configuration. See
    https://nvidia.github.io/multi-storage-client/config/index.html for more information.
    """

    # The path without the protocol. Can also be in S3 for example
    internal_path: PurePosixPath
    # The profile used to access the file system
    profile: str
    # The file system
    fs: msc.StorageClient

    def __init__(
        self,
        initial_path: Union[str, "EPath", PathlibPath],
    ) -> None:
        if isinstance(initial_path, EPath):
            self.internal_path = initial_path.internal_path
            self.profile = initial_path.profile
            self.fs = initial_path.fs
        else:
            if isinstance(initial_path, PathlibPath):
                path = str(initial_path.absolute())
                profile = "default"
            else:
                protocol, profile, path = self._split_protocol(initial_path)
                if protocol is None or protocol == "file":
                    profile = "default"
                    path = str(PathlibPath(path).absolute())
                elif protocol == "rclone":
                    warn_deprecated("rclone:// protocol is deprecated. Use msc:// instead.")
                else:
                    assert protocol == "msc", f"Unknown protocol: {protocol}"
                if not path.startswith("/"):
                    path = "/" + path
            self.internal_path = self._resolve(path)
            assert profile is not None
            self.profile = profile
            # Resolve the client. Only depends on the protocol and the first part of the path
            self.fs, _ = msc.resolve_storage_client(f"msc://{self.profile}")

    def __getstate__(self) -> dict:
        return {
            "internal_path": self.internal_path,
            "profile": self.profile,
            # Do not save the fs when serializing, to avoid leaking credentials
        }

    def __setstate__(self, state: dict) -> None:
        self.internal_path = state["internal_path"]
        self.profile = state["profile"]
        self.fs, _ = msc.resolve_storage_client(f"msc://{self.profile}")

    @staticmethod
    def _resolve(path: Union[str, PurePosixPath]) -> PurePosixPath:
        """Resolve a path, removing .. and . components."""
        if isinstance(path, str):
            path = PurePosixPath(path)
        parts = path.parts
        if parts[0] != "/":
            raise ValueError("Only absolute paths are supported")
        if ".." in parts or "." in parts:
            new_parts = []
            for part in parts[1:]:
                if part == "..":
                    if len(new_parts) == 0:
                        raise ValueError(f"Path above root: {path}")
                    new_parts.pop()
                elif part == ".":
                    pass
                else:
                    new_parts.append(part)
            path = PurePosixPath("/", *new_parts)
        return path

    @staticmethod
    def _split_protocol(path: str) -> Tuple[Optional[str], Optional[str], str]:
        regex = re.compile(r"^(?P<protocol>[a-z]+)://(?P<profile>[^/]+?)/(?P<path>.+)$")
        m = regex.match(path)
        if m is None:
            return None, None, path
        return m.group("protocol"), m.group("profile"), m.group("path")

    @property
    def _internal_str_path(self) -> str:
        """Return the path as used inside the file system, without the protocol and fs part."""
        return str(self.internal_path)

    @overload
    def open(self, mode: Literal["r", "w"] = "r", block_size: Optional[int] = None) -> TextIO: ...

    @overload
    def open(self, mode: Literal["rb", "wb"], block_size: Optional[int] = None) -> BinaryIO: ...

    def open(
        self, mode: Literal["r", "rb", "w", "wb"] = "r", block_size: Optional[int] = None
    ) -> Union[TextIO, BinaryIO]:
        return self.fs.open(self._internal_str_path, mode)

    def read_text(self) -> str:
        with self.open() as f:
            return f.read()

    def read_bytes(self) -> bytes:
        with self.open("rb") as f:
            return f.read()

    def write_text(self, text: str) -> None:
        with self.open("w") as f:
            f.write(text)

    def write_bytes(self, data: bytes) -> None:
        with self.open("wb") as f:
            f.write(data)

    def copy(self, target: "EPath") -> None:
        """Copy a file or folder to a new path, possibly between different file systems.

        Args:
            target: The target path to copy to.
        """
        if self.is_file():
            if self.fs == target.fs:
                self.fs.copy(self._internal_str_path, target._internal_str_path)
            elif target.is_local():
                self.fs.download_file(self._internal_str_path, target._internal_str_path)
            elif self.is_local():
                target.fs.upload_file(target._internal_str_path, self._internal_str_path)
            else:
                with self.open("rb") as src_f, target.open("wb") as dst_f:
                    shutil.copyfileobj(src_f, dst_f)
        else:
            inner_path = EPath(self)
            for fpath in self.fs.list(self._internal_str_path):
                inner_path.internal_path = PurePosixPath("/" + fpath.key)
                inner_path.copy(target / inner_path.relative_to(self))

    @property
    def name(self) -> str:
        return self.internal_path.name

    @property
    def parent(self) -> "EPath":
        new_path = EPath(self)
        new_path.internal_path = self.internal_path.parent
        return new_path

    @property
    def url(self) -> str:
        if self.is_local():
            return self._internal_str_path
        int_path_str = str(self.internal_path)
        return f"msc://{self.profile}{int_path_str}"

    def is_local(self) -> bool:
        return self.profile == "default"

    def is_dir(self) -> bool:
        try:
            return self.fs.info(self._internal_str_path).type == "directory"
        except FileNotFoundError:
            return False

    def is_file(self) -> bool:
        return self.fs.is_file(self._internal_str_path)

    def mkdir(self, exist_ok: bool = True, parents: bool = False):
        pass

    def glob(self, pattern) -> Generator["EPath", None, None]:
        search_path_pattern = (self / pattern)._internal_str_path
        for path in self.fs.glob(search_path_pattern):
            assert isinstance(path, str)
            new_path = EPath(self)
            new_path.internal_path = self._resolve(self.internal_path / PurePosixPath(path))
            yield new_path

    def size(self) -> int:
        return self.fs.info(self._internal_str_path).content_length

    def with_suffix(self, suffix: str) -> "EPath":
        new_path = EPath(self)
        new_path.internal_path = self.internal_path.with_suffix(suffix)
        return new_path

    def move(self, target: "EPath") -> None:
        self.copy(target)
        self.unlink()

    def unlink(self) -> None:
        return self.fs.delete(self._internal_str_path)

    def relative_to(self, other: "EPath") -> str:
        assert self.profile == other.profile, "Can only use relative_to within same profile"
        return str(self.internal_path.relative_to(other.internal_path))

    def __truediv__(self, other: Union[str, "EPath"]) -> "EPath":
        if isinstance(other, EPath):
            # Always absolute
            return other
        if other.startswith("/") or "://" in other:
            return EPath(other)
        new_path = EPath(self)
        new_path.internal_path = self._resolve(self.internal_path / other)
        return new_path

    def __lt__(self, other: "EPath") -> bool:
        assert self.profile == other.profile, "Cannot compare paths from different profiles"
        return self.internal_path < other.internal_path

    def __str__(self) -> str:
        return self.url

    def __repr__(self) -> str:
        return f"EPath({str(self)!r})"

    def __hash__(self) -> int:
        return hash((self.internal_path, self.profile))

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, EPath)
            and self.internal_path == other.internal_path
            and self.profile == other.profile
        )
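
A short usage sketch based on the class docstring; the msc profile name is a placeholder and would need a matching entry in ~/.msc_config.yaml:

local = EPath("/tmp/data") / "shard-000.tar"  # local path, profile "default"
print(local.url)                              # /tmp/data/shard-000.tar

remote = EPath("msc://myprofile/my_datasets/webdataset-000.tar")  # hypothetical profile
print(remote.profile, remote.name)             # myprofile webdataset-000.tar
remote.copy(EPath("/tmp/webdataset-000.tar"))  # downloads via the storage client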
Megatron-Energon/src/megatron/energon/epathlib/rclone_config.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import configparser
import os
import shutil
from pathlib import Path
from typing import Dict, Optional

from megatron.energon.edataclass import edataclass


@edataclass
class ConfigEntry:
    name: str
    type: str
    provider: Optional[str]
    access_key_id: Optional[str]
    secret_access_key: Optional[str]
    region: Optional[str]
    endpoint: Optional[str]


def find_executable_path(executable_name):
    """Find the path of an executable in the PATH environment variable. Returns None if not found."""
    executable_path = shutil.which(executable_name)
    if executable_path:
        return Path(executable_path)
    return None


def get_rclone_config_path() -> Optional[Path]:
    # First check if the rclone executable is in PATH; if yes, check if rclone.conf is in the same directory
    rclone_exe_path = find_executable_path("rclone")
    if rclone_exe_path is not None and rclone_exe_path.is_file():
        rclone_config_path = rclone_exe_path.with_name("rclone.conf")
        if rclone_config_path.is_file():
            return rclone_config_path

    # As a second option, check the XDG_CONFIG_HOME environment variable; if it is set, check for rclone/rclone.conf in that directory
    xdg_config_home = os.environ.get("XDG_CONFIG_HOME")
    if xdg_config_home and Path(xdg_config_home).is_dir():
        rclone_config_path = Path(xdg_config_home) / "rclone" / "rclone.conf"
        if rclone_config_path.is_file():
            return rclone_config_path

    # As a third option, check the default location ~/.config/rclone/rclone.conf
    rclone_config_path = Path.home() / ".config" / "rclone" / "rclone.conf"
    if rclone_config_path.is_file():
        return rclone_config_path

    # The last option is to check the legacy location ~/.rclone.conf
    legacy_config_path = Path.home() / ".rclone.conf"
    if legacy_config_path.is_file():
        return legacy_config_path

    return None


def read_rclone_config_at_path(config_path: Path) -> Dict[str, ConfigEntry]:
    """Reads the config file and returns a dictionary with the config entries."""
    config = configparser.ConfigParser()
    config.read(config_path)

    config_entries = {}
    for section in config.sections():
        entry = ConfigEntry(
            name=section,
            type=config[section].get("type"),
            provider=config[section].get("provider"),
            access_key_id=config[section].get("access_key_id"),
            secret_access_key=config[section].get("secret_access_key"),
            region=config[section].get("region"),
            endpoint=config[section].get("endpoint"),
        )
        config_entries[section] = entry
    return config_entries


def read_rclone_config() -> Dict[str, ConfigEntry]:
    config_path = get_rclone_config_path()
    if config_path is None:
        raise FileNotFoundError("Could not find rclone configuration file.")
    return read_rclone_config_at_path(config_path)
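
The lookup chain in use (a sketch; it only succeeds if one of the four documented rclone.conf locations exists on the machine):

entries = read_rclone_config()  # raises FileNotFoundError if no rclone.conf is found
for name, entry in entries.items():
    print(name, entry.type, entry.endpoint)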
Megatron-Energon/src/megatron/energon/errors.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import dataclasses
import itertools
import warnings
from functools import wraps
from typing import Any, Type, TypeVar, Union

import numpy as np
import torch


def compact_str(
    value: Union[dict, list, str, int, bool, None],
    depth: int = 3,
    max_items: int = 10,
    max_str_len: int = 50,
) -> str:
    """
    Compact representation of a value as a string.

    Args:
        value: The value to compact
        depth: The maximum depth to compact
        max_items: The maximum number of items to show in a list or dict
        max_str_len: The maximum string length to show

    Returns: The printable string
    """
    if isinstance(value, dict):
        if depth <= 0:
            return "{...}"
        return (
            "{"
            + ", ".join(
                (
                    f"{k}: {v!r}"
                    if isinstance(k, str) and k.startswith("__")
                    else f"{k}: {compact_str(v, depth - 1, max_items, max_str_len)}"
                )
                for k, v in itertools.islice(value.items(), max_items)
            )
            + "}"
        )
    elif isinstance(value, list):
        if depth <= 0:
            return "[...]"
        return (
            "["
            + ", ".join(
                compact_str(v, depth - 1, max_items, max_str_len) for v in value[:max_items]
            )
            + "]"
        )
    elif isinstance(value, tuple):
        if depth <= 0:
            return "(...)"
        return (
            "("
            + ", ".join(
                compact_str(v, depth - 1, max_items, max_str_len) for v in value[:max_items]
            )
            + ")"
        )
    elif isinstance(value, str):
        if len(value) > max_str_len:
            return repr(value[:max_str_len] + "...")
        return repr(value)
    elif isinstance(value, torch.Tensor):
        return f"Tensor(shape={value.shape}, dtype={value.dtype}, device={value.device})"
    elif isinstance(value, np.ndarray):
        return f"np.ndarray(shape={value.shape}, dtype={value.dtype})"
    elif dataclasses.is_dataclass(value):
        return f"{value.__class__.__name__}({', '.join(f'{field.name}={compact_str(getattr(value, field.name))}' for field in dataclasses.fields(value))})"
    else:
        return compact_str(repr(value), depth, max_items, max_str_len)


T = TypeVar("T")


class SampleException(ValueError):
    @classmethod
    def from_sample_key(cls: Type[T], sample_key: str) -> T:
        return cls(f"Sample {sample_key} failed")

    @classmethod
    def from_sample(cls: Type[T], sample: Any, message: str = "") -> T:
        if message:
            message = f": {message}"
        return cls(f"Sample {compact_str(sample)} failed{message}")


class FatalSampleError(SampleException):
    # This will not be handled by the error handler
    pass


def warn_deprecated(reason, stacklevel=2):
    warnings.warn(reason, FutureWarning, stacklevel=stacklevel)


def deprecated(reason):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            warn_deprecated(f"{func.__name__} is deprecated: {reason}", stacklevel=3)
            return func(*args, **kwargs)

        return wrapper

    return decorator


SYSTEM_EXCEPTIONS = (
    SystemError,
    SyntaxError,
    ImportError,
    StopIteration,
    StopAsyncIteration,
    MemoryError,
    RecursionError,
    ReferenceError,
    NameError,
    UnboundLocalError,
    FatalSampleError,
)
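
Two small self-contained demonstrations of this module: compact_str truncating a nested value, and deprecated wrapping a function so that callers get a FutureWarning:

print(compact_str({"text": "x" * 100, "nums": list(range(50))}, max_items=3))
# The 100-char string is cut at max_str_len, the list at max_items.

@deprecated("use new_fn instead")
def old_fn(x):
    return x + 1

old_fn(1)  # emits FutureWarning: "old_fn is deprecated: use new_fn instead"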
Megatron-Energon/src/megatron/energon/eyaml.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import json
from typing import Any, BinaryIO, Union

import ryml

from megatron.energon.epathlib.epath import EPath


def load_yaml_json(path: EPath) -> Any:
    """Load a YAML or JSON file from a path. Loading json is much faster than yaml.

    Args:
        path: The path to the file to load.

    Returns:
        The parsed YAML or JSON data as a native Python object.
    """
    if path.name.endswith(".yaml"):
        with path.open("rb") as f:
            return load_yaml(f)
    elif path.name.endswith(".json"):
        with path.open("rb") as f:
            return json.load(f)
    else:
        raise ValueError(f"Invalid file extension: {path.name}")


def load_yaml(stream: Union[BinaryIO, bytes]) -> Any:
    """Load a YAML file from a stream or bytes object using rapidyaml/ryml.

    This is much faster than the standard yaml library, but we don't
    support all YAML features.

    Args:
        stream: A stream or bytes object containing the YAML data.

    Returns:
        The parsed YAML data as a native Python object.
    """
    if isinstance(stream, bytes):
        bytes_data = stream
    else:
        bytes_data = stream.read()
    tree = ryml.parse_in_arena(bytes_data)
    native_obj = _ryml_tree_to_native(tree)
    return native_obj


def _cast_scalar(tree: ryml.Tree, nid: int) -> object:
    """Convert YAML scalar to bool | int | float | None | str."""
    memview = tree.val(nid)
    s = memview.tobytes().decode()
    if tree.is_val_quoted(nid):
        return s
    s_lower = s.lower()
    if s_lower in {"null", "~"}:
        return None
    if s_lower in {"true", "yes"}:
        return True
    if s_lower in {"false", "no"}:
        return False
    try:
        if "." in s or "e" in s_lower:
            return float(s)
        return int(s)
    except ValueError:
        return s


def _to_native(t: ryml.Tree, nid: int):
    if t.is_map(nid):
        # iterate children of a mapping node
        return {t.key(ch).tobytes().decode(): _to_native(t, ch) for ch in ryml.children(t, nid)}
    if t.is_seq(nid):
        return [_to_native(t, ch) for ch in ryml.children(t, nid)]
    # scalar leaf
    return _cast_scalar(t, nid)


def _ryml_tree_to_native(tree: ryml.Tree):
    return _to_native(tree, tree.root_id())
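
A minimal round-trip through the ryml-based loader (self-contained, assuming the ryml bindings are installed), showing the scalar casting rules from _cast_scalar:

doc = b"""
name: example
flags: [true, false, null]
count: 3
ratio: 0.5
"""
print(load_yaml(doc))
# -> {'name': 'example', 'flags': [True, False, None], 'count': 3, 'ratio': 0.5}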
Megatron-Energon/src/megatron/energon/flavors/__init__.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from megatron.energon.flavors.base_dataset import (
    BaseCoreDatasetFactory,
    PinMemoryMixin,
    Sample,
    SavableDataset,
)
from megatron.energon.flavors.captioning import CaptioningSample, CaptioningWebdataset
from megatron.energon.flavors.crude import CrudeSample, CrudeWebdataset
from megatron.energon.flavors.image import ImageSample, ImageWebdataset
from megatron.energon.flavors.image_classification import (
    ImageClassificationSample,
    ImageClassificationWebdataset,
)
from megatron.energon.flavors.interleaved import InterleavedSample, InterleavedWebdataset
from megatron.energon.flavors.jsonl import (
    CrudeJsonlDatasetFactory,
    DefaultCrudeJsonlDatasetFactory,
)
from megatron.energon.flavors.multichoice_vqa import MultiChoiceVQASample, MultiChoiceVQAWebdataset
from megatron.energon.flavors.ocr import OCRSample, OCRWebdataset
from megatron.energon.flavors.similarity_interleaved import (
    SimilarityInterleavedSample,
    SimilarityInterleavedWebdataset,
)
from megatron.energon.flavors.text import TextSample, TextWebdataset
from megatron.energon.flavors.vid_qa import VidQASample, VidQAWebdataset
from megatron.energon.flavors.vqa import VQASample, VQAWebdataset
from megatron.energon.flavors.vqa_and_ocr import VQAOCRWebdataset
from megatron.energon.flavors.webdataset import (
    AVData,
    AVDecoder,
    AVDecoderType,
    BaseWebdatasetFactory,
    DefaultDecoderWebdatasetFactory,
    DefaultGenericWebdatasetFactory,
    EmptyDatasetError,
    ImageDecoderType,
    JoinedWebdatasetFactory,
    SampleDecoder,
    StandardWebdatasetFactory,
    WebdatasetMeta,
)

__all__ = [
    "AVData",
    "AVDecoder",
    "AVDecoderType",
    "BaseCoreDatasetFactory",
    "BaseWebdatasetFactory",
    "CaptioningSample",
    "CaptioningWebdataset",
    "CrudeJsonlDatasetFactory",
    "CrudeSample",
    "CrudeWebdataset",
    "DefaultCrudeJsonlDatasetFactory",
    "DefaultDecoderWebdatasetFactory",
    "DefaultGenericWebdatasetFactory",
    "EmptyDatasetError",
    "ImageClassificationSample",
    "ImageClassificationWebdataset",
    "ImageDecoderType",
    "ImageSample",
    "ImageWebdataset",
    "InterleavedSample",
    "InterleavedWebdataset",
    "JoinedWebdatasetFactory",
    "MultiChoiceVQASample",
    "MultiChoiceVQAWebdataset",
    "OCRSample",
    "OCRWebdataset",
    "PinMemoryMixin",
    "Sample",
    "SampleDecoder",
    "SavableDataset",
    "SimilarityInterleavedSample",
    "SimilarityInterleavedWebdataset",
    "StandardWebdatasetFactory",
    "TextSample",
    "TextWebdataset",
    "VidQASample",
    "VidQAWebdataset",
    "VQAOCRWebdataset",
    "VQASample",
    "VQAWebdataset",
    "WebdatasetMeta",
]
Megatron-Energon/src/megatron/energon/flavors/base_dataset.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import dataclasses
import inspect
import typing
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import is_dataclass
from typing import (
    Any,
    Callable,
    ClassVar,
    Dict,
    Generic,
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)

import torch
from torch.utils.data import IterableDataset
from typing_extensions import Self

from megatron.energon.cache import FileStore
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib import EPath
from megatron.energon.savable import Savable
from megatron.energon.source_info import SourceInfo
from megatron.energon.state import FlexState
from megatron.energon.worker import WorkerConfig

T_sample = TypeVar("T_sample", covariant=True)
T = TypeVar("T", covariant=True)


class PinMemoryMixin:
    """A mixin class providing a generic `pin_memory` function."""

    def _pin_memory(self, batch: T, device: Union[torch.device, str, None] = None) -> T:
        """Pin memory of a batch. Uses recursion to handle nested structures. Supports nested
        structures of dicts, dataclasses, namedtuples, lists and tuples."""
        if isinstance(batch, torch.Tensor):
            return batch.pin_memory(device)
        elif isinstance(batch, dict):
            return {key: self._pin_memory(value, device) for key, value in batch.items()}
        elif dataclasses.is_dataclass(batch):
            return type(batch)(
                **{
                    field.name: self._pin_memory(getattr(batch, field.name), device)
                    for field in dataclasses.fields(batch)
                }
            )
        elif isinstance(batch, (tuple, list)):
            if hasattr(batch, "_fields"):
                # NamedTuple
                return type(batch)(*[self._pin_memory(val, device) for val in batch])
            else:
                # list / tuple
                return type(batch)(self._pin_memory(val, device) for val in batch)
        else:
            return batch

    def pin_memory(self: Self) -> Self:
        return self._pin_memory(self)


class ExtendableDataclassMixin:
    """A mixin class providing a generic `extend` function for copying dataclasses."""

    @classmethod
    def extend(cls: Type[T], src, **kwargs) -> T:
        """
        Used for overridden dataclass instances. Example:

        .. code-block:: python

            @dataclass
            class MyBaseClass:
                a: List[int]

            @dataclass
            class MyExtendedClass(MyBaseClass):
                # Add a new field `b` to the state
                b: List[int]

            base = MyBaseClass(a=[1, 2, 3])
            extended = MyExtendedClass.extend(base, b=[4, 5, 6])

        Args:
            src: The source dataclass instance to extend.
            **kwargs: The new fields to add to the instance to construct the new instance.

        Returns:
            The extended dataclass instance.
        """
        assert is_dataclass(cls), "Must be a dataclass"
        assert issubclass(cls, type(src)), "Cannot extend class of different type"
        for f in dataclasses.fields(src):
            if not f.init or f.type is ClassVar or typing.get_origin(f.type) is ClassVar:
                continue
            if f.name not in kwargs:
                kwargs[f.name] = getattr(src, f.name)
        return cls(**kwargs)


@edataclass
class Sample(ABC, PinMemoryMixin, ExtendableDataclassMixin):
    """An abstract base class for one element of a batch.
    Each task should derive a specific subclass as a `@dataclass`, like
    :class:`megatron.energon.CaptioningBatchSample`, and add the input and output fields as needed
    for training.
    """

    #: Uniquely identifies each sample in the dataset.
    __key__: str
    #: Key for restoring the sample. This is used to restore the sample from a checkpoint. It
    # should be a (nested) tuple of strings and integers, which can be used to index the dataset.
    __restore_key__: Tuple[Union[str, int, tuple], ...]

    #: A dataset may define subflavors to distinguish between samples of the same sample type.
    __subflavors__: Optional[Dict[str, Any]] = None

    #: Information about the source of the sample, i.e. where the data was loaded from.
    __sources__: Optional[tuple[SourceInfo, ...]] = None

    @classmethod
    def derive_from(cls: Type[T_sample], base_sample: "Sample", **kwargs) -> T_sample:
        """
        Uses the base fields of `Sample` from base_sample (i.e. __key__, __restore_key__,
        __subflavors__, __sources__) and creates a new sample with the kwargs as fields.
        This is useful for creating new samples, while keeping the metadata of the base sample.

        Args:
            base_sample: The base sample to copy the base fields / metadata from.
            kwargs: The fields of the new sample.

        Returns:
            The new sample.
        """
        base_kwargs = {
            field.name: getattr(base_sample, field.name)
            for field in dataclasses.fields(Sample)
            if field.name not in kwargs
        }
        return cls(
            **base_kwargs,
            **kwargs,
        )

    @classmethod
    def from_joined(
        cls: Type[T_sample], *args: "Optional[Sample]", **kwargs: "Optional[Sample]"
    ) -> T_sample:
        """
        Creates a sample from joined samples. The samples are either passed as positional arguments
        or as keyword arguments. The first sample is the primary sample, which is used to
        initialize the key and subflavors.

        In the default implementation, the joined samples' fields will be joined together, such
        that latter joined samples will update the fields last (i.e. take precedence), except for
        the key and subflavors. The restore key is later set externally.

        Args:
            args: The samples to join (either this or kwargs is specified).
            kwargs: The samples to join (either this or args is specified). Not supported for the
                default implementation. Overwriting implementations may use this.

        Returns:
            The joined constructed sample.
        """
        assert len(kwargs) == 0, (
            "Please specify joined datasets as list for the default joiner. Keyword arguments are "
            "confusing, because keys are ignored."
        )
        excluded_fields = set(field.name for field in dataclasses.fields(Sample))
        init_args = {}
        if len(args) > 0:
            primary = args[0]
            assert primary is not None, "Primary sample must not be None."
            fields = dataclasses.fields(primary)
            for field in fields:
                init_args[field.name] = getattr(primary, field.name)
            # Merge sources from all joined samples
            init_args["__sources__"] = (
                *(primary.__sources__ or ()),
                *(
                    src
                    for arg in args
                    if arg is not None and arg.__sources__ is not None
                    for src in arg.__sources__
                ),
            )
            for arg in args:
                if arg is None:
                    continue
                fields = dataclasses.fields(arg)
                for field in fields:
                    if field.name not in excluded_fields:
                        init_args[field.name] = getattr(arg, field.name)
        return cls(**init_args)


@edataclass
class State(ABC, ExtendableDataclassMixin):
    """An abstract base class for the state of a dataset. See :class:`megatron.energon.SavableDataset`.

    The state of a dataset is used to save and restore the dataset state (i.e. random generators,
    buffer states, file pointers, etc.).

    Each dataset should derive a specific subclass as a `@dataclass` and add the fields as needed
    for training.

    To extend subclasses, use the .extend method. Example:

    .. code-block:: python

        @dataclass
        class MyState(State):
            a: int

        @dataclass
        class MyExtendedState(MyState):
            # Add a new field `b` to the state
            b: int

        class MyStateSaver:
            def save_state(self) -> MyState:
                return MyState(a=42)

        class MyExtendedStateSaver(MyStateSaver):
            def save_state(self) -> MyExtendedState:
                # Fetch state from super class, which is already a complete instance (cannot add
                # new fields to it, type is fixed).
                state: MyState = super().save_state()

                # Now extend the state of the super class (of type `MyState`) with the new field
                # required to define `MyExtendedState`.
                return MyExtendedState.extend(state, b=21)
    """


class SavableDataset(IterableDataset[T_sample], Savable, Generic[T_sample], ABC):
    """A dataset that can be saved and restored (i.e. the random state, internal buffers, etc.).
    I.e. it can be resumed from a checkpoint.

    How dataset state saving works:

    1. The dataset state needs to be saved in all forked worker processes which contain a copy of
       the main dataset instance (see :class:`megatron.energon.SavableDataLoader`). Each worker
       returns only its own state.
    2. The main process merges the states via the
       :meth:`megatron.energon.SavableDataset.merge_states` method in the main process on the main
       dataset instance (which doesn't hold the worker states, as they were forked).
    3. The main process saves the merged state to the checkpoint.
    """

    worker_config: WorkerConfig

    #: List of names of the fields that are saved and restored in the state.
    _savable_fields: ClassVar[Tuple[str, ...]] = ()

    def __init__(self, worker_config: WorkerConfig):
        self.worker_config = worker_config

    @abstractmethod
    def len_worker(self, worker_idx: int | None = None) -> int:
        """Returns the length of the dataset for the current or a specific worker.
        The length is the number of different available samples.
        The number of actually yielded samples may be different (considering skipping samples or
        generator functions).

        Args:
            worker_idx: The index of the worker to return the length for.
                If None, the length of the current worker is returned (must be in worker context).
        """
        ...

    def len_rank(self) -> int:
        """Returns the length of the dataset for the current rank.
        The length is the number of different available samples.
        The number of actually yielded samples may be different (considering skipping samples or
        generator functions).
        """
        return sum(self.len_worker(i) for i in range(self.worker_config.num_workers or 1))

    def __len__(self) -> int:
        """Returns the length of the dataset for the current rank. Corresponds to `len_rank`."""
        return self.len_rank()

    def save_state(self) -> FlexState:
        """
        Saves the state of the dataset. This will save and return the state of all fields
        in the _savable_fields tuple.

        Can only be called in a worker process.
        """
        state = FlexState()
        state["__class__"] = type(self).__name__
        for key in self._savable_fields:
            attr = getattr(self, key)

            if isinstance(attr, Savable):
                state[key] = attr.save_state()
            else:
                # Check if this field is a simple python type or a user class
                if attr is not None and getattr(attr, "__module__", "builtins") != "builtins":
                    import warnings

                    warnings.warn(
                        f"The savable attribute {key} of class {type(self)} does "
                        "not inherit from Savable, nor is it a simple builtin type. "
                        "Please double-check.",
                        UserWarning,
                    )
                state[key] = deepcopy(getattr(self, key))
        return state

    def restore_state(self, state: FlexState) -> None:
        """
        Restores the state of the dataset. This will restore the state of all fields
        in the _savable_fields tuple.

        Can only be called in a worker process.

        Args:
            state: The state of the dataset as savable object. If None, restore initial state.
        """
        assert state["__class__"] == type(self).__name__, (
            f"Class name mismatch: {state['__class__']} != {type(self).__name__}"
        )
        for key in self._savable_fields:
            assert key in state, f"Key {key} not in state {state}"
            value = state.get(key)
            assert hasattr(self, key), f"Savable field {key} not in dataset {self}"
            if isinstance(getattr(self, key), Savable):
                getattr(self, key).restore_state(value)
            else:
                setattr(self, key, value)

    @abstractmethod
    def reset_state_own(self) -> None:
        """Resets the state of the dataset to the initial state. Can only be called in a worker process."""
        ...

    def reset_state_deep(self) -> None:
        """Resets the state of the dataset to the initial state. Can only be called in a worker process."""
        self.reset_state_own()

    @abstractmethod
    def worker_has_samples(self) -> bool:
        """Returns True if the worker's split has samples. This is used to determine if this
        dataset yields anything."""
        ...

    @staticmethod
    def _function_config(fn: Callable) -> str:
        mod = inspect.getmodule(fn)
        if mod is not None:
            mod_name = mod.__name__
        else:
            mod_name = getattr(fn, "__module__", "<unknown>")
        return f"{mod_name}.{getattr(fn, '__qualname__', getattr(fn, '__name__', '<unknown>'))}"

    @abstractmethod
    def config(self) -> Dict[str, Any]:
        """Return a config dict that can be used to check if datasets have the same settings.
        Variables in dicts starting with "_" represent a possibly changeable setting, like a full
        path which may be changed."""
        return {
            "type": type(self).__qualname__,
        }

    def can_restore_sample(self) -> bool:
        """Returns True if the dataset can restore a sample from a key."""
        return False

    def assert_can_restore(self) -> None:
        """Asserts that the dataset can restore a sample from a key."""
        assert self.can_restore_sample(), "This dataset cannot restore samples."

    def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample:
        """
        Generic key type, because it might be either an integer (for a core dataset), or something
        more complex (e.g. for blended datasets).

        Default raises an exception (assumed non-deterministic if not implemented, does not
        guarantee determinism).
        """
        raise NotImplementedError(
            "This dataset does not support indexing, because it is not safely deterministic."
        )


class BaseCoreDatasetFactory(Generic[T_sample], ABC):
    """Base type for an inner dataset sample loader. This factory can be used to construct a
    sample loader, or for joining in a joined dataset."""

    __sample_type__: Type[T_sample] = cast(Type[T_sample], None)
    paths: List[EPath]

    subflavors: Dict[str, Any]

    @abstractmethod
    def build(self, worker_rotation_offset: int = 0) -> SavableDataset[T_sample]:
        """Builds the dataset."""
        ...

    @abstractmethod
    def as_file_store(self) -> "FileStore":
        """Returns the dataset as a random access dataset."""
        ...

    @abstractmethod
    def __len__(self) -> int:
        """Returns the length of the dataset across all ranks."""
        ...


def add_sample_restore_key(
    sample: T_sample, *key: Union[int, str], src: Any, fail_otherwise: bool = False
) -> T_sample:
    """Adds a key to a sample. The sample must be a valid `Sample` or dict containing
    __restore_key__, which is a tuple of keys that can be used to restore the inner sample.
    This restore key is prepended with the `key`."""
    if isinstance(sample, Sample) or hasattr(sample, "__restore_key__"):
        try:
            sample.__restore_key__ = (type(src).__name__, *key, *sample.__restore_key__)
        except KeyError:
            pass
    elif isinstance(sample, dict) and "__restore_key__" in sample:
        sample["__restore_key__"] = (type(src).__name__, *key, *sample["__restore_key__"])
    elif fail_otherwise:
        raise RuntimeError(
            "Did not yield a sample with a restore key, but is marked stateless/deterministic."
        )
    return sample


def set_sample_restore_key(
    sample: T_sample, *key: Union[int, str], src: Any, fail_otherwise: bool = False
) -> T_sample:
    """Sets the restore key for a sample. The sample must be a valid `Sample` or dict containing
    __restore_key__, which is a tuple of keys that can be used to restore the inner sample.
    This restore key is prepended with the `key`."""
    if isinstance(sample, Sample) or hasattr(sample, "__restore_key__"):
        try:
            sample.__restore_key__ = (type(src).__name__, *key)
        except KeyError:
            pass
    elif isinstance(sample, dict) and "__restore_key__" in sample:
        sample["__restore_key__"] = (type(src).__name__, *key)
    elif fail_otherwise:
        raise RuntimeError(
            "Did not yield a sample with a restore key, but is marked stateless/deterministic."
        )
    return sample


def legacy_handler(
    handler: Union[
        Callable[[Exception, Optional[str], Optional[list[SourceInfo]]], None],
        Callable[[Exception, Optional[str]], None],
    ],
) -> Callable[[Exception, Optional[str], Optional[list[SourceInfo]]], None]:
    """Safely returns the new style three argument handler. If the handler takes 2 arguments,
    it wraps it."""
    import functools
    import inspect

    handler_sig = inspect.signature(handler)
    if len(handler_sig.parameters) != 3:
        original_handler = handler

        @functools.wraps(original_handler)
        def wrapped_handler(
            exc: Exception, key: Optional[str], source_infos: Optional[list[SourceInfo]]
        ) -> None:
            return original_handler(exc, key)

        return wrapped_handler
    else:
        return handler
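
A sketch of the Sample.derive_from metadata flow, using the CaptioningSample type defined further down in this commit; the field values are made up:

import torch

crude = CaptioningSample(
    __key__="shard_000/0001",           # hypothetical key
    __restore_key__=("Webdataset", 0),  # hypothetical restore key
    image=torch.zeros(3, 224, 224),
    caption="a photo",
)
derived = CaptioningSample.derive_from(crude, image=crude.image, caption=crude.caption.upper())
assert derived.__key__ == crude.__key__  # metadata fields carried over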
Megatron-Energon/src/megatron/energon/flavors/captioning.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import torch

from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib.epath import EPath
from megatron.energon.errors import warn_deprecated
from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory


@edataclass
class CaptioningSample(Sample):
    """Sample type for image captioning."""

    #: The input image tensor in the shape (C, H, W)
    image: torch.Tensor
    #: The caption string
    caption: str


class CaptioningWebdataset(DefaultDecoderWebdatasetFactory[CaptioningSample]):
    __sample_type__ = CaptioningSample

    def __init__(self, path: EPath, **kwargs):
        warn_deprecated(
            f"{type(self)} is deprecated, use the default instead and set the sample_type:\n"
            f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n"
            f"# remove top-level __module__ and __class__\n"
            f"sample_type:\n"
            f"  __module__: megatron.energon\n"
            f"  __class__: {self.__sample_type__.__name__}\n"
            f"# Keep the remaining content"
        )
        super().__init__(path, **kwargs)
Megatron-Energon/src/megatron/energon/flavors/crude.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Any, Callable, Dict, List, Optional, Union

from megatron.energon.epathlib import EPath
from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory


class CrudeSample(dict):
    """Generic sample type to be processed later."""


class CrudeWebdataset(DefaultDecoderWebdatasetFactory[CrudeSample]):
    """The CrudeWebdataset is used to load crude / raw samples and
    decode them in the user code using so-called cookers.

    See the documentation under "Crude Data" for more information.
    """

    __sample_type__ = CrudeSample

    def __init__(
        self,
        path: EPath,
        *,
        subflavors: Optional[Dict[str, Any]] = None,
        part_filter: Union[str, List[str], Callable[[str], bool]] = lambda _: True,
        **kwargs,
    ):
        """
        Constructs a crude webdataset.

        Args:
            path: Root path to the joined datasets.
            subflavors: Subflavors dictionary to set for all loaded samples.
            part_filter: Function for filtering tar files to load by dict keys.
            **kwargs: Additional arguments to the BaseWebdataset constructor.
        """
        # We skip the parent class __init__ and call the BaseWebdataset.__init__ directly
        if "sample_loader" in kwargs:
            raise ValueError("sample_loader is not allowed to be set when using CrudeWebdataset")
        super().__init__(
            path,
            subflavors=subflavors,
            sample_loader=lambda sample: sample,
            part_filter=part_filter,
            **kwargs,
        )
Megatron-Energon/src/megatron/energon/flavors/image.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import torch

from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib.epath import EPath
from megatron.energon.errors import warn_deprecated
from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory


@edataclass
class ImageSample(Sample):
    """Sample type for an image, e.g. for image reconstruction."""

    #: The input image tensor in the shape (C, H, W)
    image: torch.Tensor


class ImageWebdataset(DefaultDecoderWebdatasetFactory[ImageSample]):
    __sample_type__ = ImageSample

    def __init__(self, path: EPath, **kwargs):
        warn_deprecated(
            f"{type(self)} is deprecated, use the default instead and set the sample_type:\n"
            f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n"
            f"# remove top-level __module__ and __class__\n"
            f"sample_type:\n"
            f"  __module__: megatron.energon\n"
            f"  __class__: {self.__sample_type__.__name__}\n"
            f"# Keep the remaining content"
        )
        super().__init__(path, **kwargs)
Megatron-Energon/src/megatron/energon/flavors/image_classification.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Optional

import torch

from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib.epath import EPath
from megatron.energon.errors import warn_deprecated
from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory


@edataclass
class ImageClassificationSample(Sample):
    """Sample type for classifying an image."""

    #: The input image tensor in the shape (C, H, W)
    image: torch.Tensor
    #: The class label of the image
    label: Optional[int] = None
    #: The name of the class label of the image
    label_name: Optional[str] = None


class ImageClassificationWebdataset(DefaultDecoderWebdatasetFactory[ImageClassificationSample]):
    __sample_type__ = ImageClassificationSample

    def __init__(self, path: EPath, **kwargs):
        warn_deprecated(
            f"{type(self)} is deprecated, use the default instead and set the sample_type:\n"
            f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n"
            f"# remove top-level __module__ and __class__\n"
            f"sample_type:\n"
            f"  __module__: megatron.energon\n"
            f"  __class__: {self.__sample_type__.__name__}\n"
            f"# Keep the remaining content"
        )
        super().__init__(path, **kwargs)
Megatron-Energon/src/megatron/energon/flavors/interleaved.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from typing import List, Union

import torch

from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib.epath import EPath
from megatron.energon.errors import warn_deprecated
from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory


@edataclass
class InterleavedSample(Sample):
    """Sample type for interleaved media such as text with images."""

    #: The interleaved media (either torch.tensor for an image, or str for text)
    sequence: List[Union[torch.Tensor, str]]


class InterleavedWebdataset(DefaultDecoderWebdatasetFactory[InterleavedSample]):
    __sample_type__ = InterleavedSample

    def __init__(self, path: EPath, **kwargs):
        warn_deprecated(
            f"{type(self)} is deprecated, use the default instead and set the sample_type:\n"
            f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n"
            f"# remove top-level __module__ and __class__\n"
            f"sample_type:\n"
            f"  __module__: megatron.energon\n"
            f"  __class__: {self.__sample_type__.__name__}\n"
            f"# Keep the remaining content"
        )
        super().__init__(path, **kwargs)
Megatron-Energon/src/megatron/energon/flavors/jsonl/__init__.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from megatron.energon.flavors.jsonl.crude_jsonl_dataset import (
    CrudeJsonlDatasetFactory,
    DefaultCrudeJsonlDatasetFactory,
)

__all__ = [
    "CrudeJsonlDatasetFactory",
    "DefaultCrudeJsonlDatasetFactory",
]
Megatron-Energon/src/megatron/energon/flavors/jsonl/crude_jsonl_dataset.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import json
import logging
from typing import Any, Callable, Dict, Optional

from megatron.energon.cache import FileStore
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.base_dataset import (
    BaseCoreDatasetFactory,
    SavableDataset,
    legacy_handler,
)
from megatron.energon.flavors.crude import CrudeSample
from megatron.energon.flavors.jsonl.ijsonl import IJsonlIndexReader
from megatron.energon.flavors.jsonl.jsonl_prepare import JsonlPreparator
from megatron.energon.flavors.webdataset.base_webdataset import _print_shard_slices
from megatron.energon.flavors.webdataset.error_handler import ErrorHandler
from megatron.energon.flavors.webdataset.sample_loader import (
    RawSampleData,
    WebdatasetSampleLoaderDataset,
)
from megatron.energon.flavors.webdataset.sharder import Sharder
from megatron.energon.flavors.webdataset.structs import (
    DatasetSubset,
    FilteredSample,
    ShardInfo,
    reraise_exception,
)
from megatron.energon.source_info import SourceInfo
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.map_dataset import MapDataset

logger = logging.getLogger(__name__)


class CrudeJsonlDatasetFactory(
    BaseCoreDatasetFactory[CrudeSample],
    JsonlPreparator,
    Sharder,
    ErrorHandler,
):
    """
    Factory class for creating a crude dataset from JSONL (JSON Lines) files.

    This factory creates datasets from JSONL files where each line contains a JSON object.
    The samples are returned as CrudeSample objects (dictionary-like) containing the raw
    JSON data.
    """

    __sample_type__ = CrudeSample

    path: EPath
    training: bool
    worker_config: WorkerConfig

    def __init__(
        self,
        path: EPath,
        *,
        training: bool,
        worker_config: WorkerConfig,
        shuffle_over_epochs: Optional[int] = 1,
        parallel_shard_iters: Optional[int] = None,
        max_samples_per_sequence: Optional[int] = None,
        subset: Optional[DatasetSubset] = None,
        part_filter: Optional[Callable[[str], bool]] = None,
        handler: Callable[
            [Exception, Optional[str], Optional[list[SourceInfo]]], None
        ] = reraise_exception,
    ):
        """
        Factory for a jsonl file as a crude dataset.

        Args:
            path: Path to the jsonl file.
            training: If true, apply shuffling and loop the dataset.
            worker_config: Configuration for the workers.
            shuffle_over_epochs: Only effective if training=True.
                How many epochs to shuffle over if training.
                If = 1, every sample is seen exactly once per epoch.
                If > 1, samples (or rather shard slices) are shuffled within this number of
                epochs (i.e. randomly selected without replacement).
                If -1, the shards are effectively shuffled over infinite epochs (i.e. shard
                slices are drawn with replacement).
            parallel_shard_iters: Number of parallel opened shards per worker, shuffling between.
            max_samples_per_sequence: Maximum number of samples per sequence (=how many samples
                will be sequentially iterated).
            subset: If specified, the dataset will be subsetted.
            part_filter: (internal) Function for filtering tar files by dict keys
            handler: Exception handler. Args: (exception, key, source_infos).
        """
        assert self.__sample_type__ is not None, (
            f"Class {type(self)} must define __sample_type__"
        )
        self.path = path
        self.paths = [path]
        self.training = training
        self.worker_config = worker_config
        self.shuffle_over_epochs = shuffle_over_epochs
        self.parallel_shard_iters = parallel_shard_iters
        self.max_samples_per_sequence = max_samples_per_sequence
        self.subset = subset
        self.part_filter = part_filter
        self.handler = legacy_handler(handler)

        if part_filter is None or part_filter("json"):
            self._len = IJsonlIndexReader.count_samples(path)
        else:
            self._len = 0

        assert self.path.size() == IJsonlIndexReader.size(path), (
            "The index of the jsonl file does not match the file. Regenerate the index."
        )

    def __len__(self) -> int:
        return self._len

    def build(self, worker_rotation_offset: int = 0) -> SavableDataset[CrudeSample]:
        from megatron.energon.flavors.jsonl.ijsonl_reader import IJsonlReader

        if self.parallel_shard_iters is None:
            if self.training:
                # 16 seems to be a good choice since we don't want too many file handles open
                parallel_shard_iters = 16
            else:
                parallel_shard_iters = 1
        else:
            parallel_shard_iters = self.parallel_shard_iters

        virtual_shards = [
            ShardInfo(
                name=self.path.name,
                path=self.path,
                count=self._len,
            )
        ]

        workers_sample_slice_offsets = self.shard_workers(
            virtual_shards,
            worker_config=self.worker_config,
            max_samples_per_sequence=self.max_samples_per_sequence,
            rotation_offset=worker_rotation_offset,
            subset=self.subset,
        )

        _print_shard_slices(self.worker_config, virtual_shards, workers_sample_slice_offsets)

        itar_reader = IJsonlReader(
            self.path,
            index_cache_size=parallel_shard_iters,
        )

        dataset = WebdatasetSampleLoaderDataset(
            join_readers=[itar_reader],
            workers_sample_slice_offsets=workers_sample_slice_offsets,
            worker_config=self.worker_config,
            shuffle_over_epochs=self.shuffle_over_epochs if self.training else None,
            parallel_slice_iters=parallel_shard_iters,
        )
        return MapDataset(
            dataset,
            self._load_sample_raw,
            error_handler=self.error_handler,
            stateless_map_fn=True,
            map_fn_config=self.config,
            worker_config=self.worker_config,
        )

    def as_file_store(self) -> "FileStore":
        from megatron.energon.cache.file_store import JsonlFileStore

        return JsonlFileStore(self.path)

    def _load_sample(self, sample: FilteredSample) -> CrudeSample:
        return CrudeSample(sample)

    def _load_sample_raw(self, raw_sample: RawSampleData) -> CrudeSample:
        # Just a wrapper for the inner tuple. Tuple should be of length 1.
        assert len(raw_sample.data) == 1 and raw_sample.data[0] is not None
        return self._load_sample(raw_sample.data[0])

    def config(self) -> Dict[str, Any]:
        return dict(
            type=type(self).__qualname__,
            training=self.training,
            _path=str(self.path),
            jsonl_filename=self.path.name,
            count=self._len,
            shuffle_over_epochs=self.shuffle_over_epochs,
            parallel_shard_iters=self.parallel_shard_iters,
            max_samples_per_sequence=self.max_samples_per_sequence,
            subset=self.subset.config() if self.subset is not None else None,
        )

    def __str__(self):
        return f"{type(self).__name__}(path={self.path})"


class DefaultCrudeJsonlDatasetFactory(CrudeJsonlDatasetFactory):
    """
    Adds subflavors to the sample and loads the json.
    """

    def __init__(
        self,
        path: EPath,
        *,
        subflavors: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        if "decoder" in kwargs:
            del kwargs["decoder"]
        super().__init__(path, **kwargs)
        self.subflavors = subflavors

    def _load_sample(self, sample: FilteredSample) -> CrudeSample:
        sample["__subflavors__"] = self.subflavors
        # Instead of using a decoder, we just load the json here, as we know it's json.
        sample["json"] = json.loads(sample["json"])
        return super()._load_sample(sample)

    def config(self) -> Dict[str, Any]:
        return dict(
            **super().config(),
            subflavors=self.subflavors,
        )
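For orientation, a hedged usage sketch (not part of the commit) of wiring the factory by hand; in a real setup the dataset.yaml / dataset_config machinery constructs it. The path and the WorkerConfig.default_worker_config() call are assumptions for illustration.

# Hedged sketch, not part of this commit. Assumes /data/corpus.jsonl exists with
# a matching /data/corpus.jsonl.idx index, and that
# WorkerConfig.default_worker_config() is available in this version.
from megatron.energon.worker import WorkerConfig

factory = DefaultCrudeJsonlDatasetFactory(
    EPath("/data/corpus.jsonl"),
    training=True,
    worker_config=WorkerConfig.default_worker_config(),
    shuffle_over_epochs=2,           # shuffle shard slices across two epochs
    subflavors={"source": "corpus"},
)
dataset = factory.build()            # SavableDataset yielding CrudeSample dicts
# Each sample carries the parsed line under "json" and the subflavors under
# "__subflavors__" (see _load_sample above).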
Megatron-Energon/src/megatron/energon/flavors/jsonl/ijsonl.py 0 → 100644
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import struct
from typing import BinaryIO, Dict, Generator, Optional, Tuple, Union

from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib import EPath

IJSONL_SUFFIX = ".jsonl.idx"


@edataclass
class IJsonlSamplePointer:
    """
    Points to a sample inside some jsonl file on disk.
    """

    # The index of the sample in the jsonl file.
    index: int
    # The byte offset of the sample in the jsonl file.
    byte_offset: int
    # The size of the sample in the jsonl file.
    byte_size: int


class IJsonlIndexReader:
    def __init__(self, jsonl_path: Union[EPath, str]):
        jsonl_path = EPath(jsonl_path)
        index_path = jsonl_path.with_suffix(IJSONL_SUFFIX)
        self._length = index_path.size() // 8
        self.ijsonl = index_path.open("rb")

    def __getitem__(self, index: int) -> int:
        if index >= self._length or index < 0:
            raise IndexError(f"Index {index} out of range")
        if self.ijsonl.tell() != 8 * index:
            self.ijsonl.seek(8 * index)
        return struct.unpack("Q", self.ijsonl.read(8))[0]

    def __iter__(self) -> Generator[int, None, None]:
        self.ijsonl.seek(0)
        while True:
            raw = self.ijsonl.read(8)
            if len(raw) == 0:
                break
            assert len(raw) == 8
            yield struct.unpack("Q", raw)[0]

    def __len__(self) -> int:
        return self._length

    def close(self):
        self.ijsonl.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    @staticmethod
    def count_samples(jsonl_path: EPath | str) -> int:
        return EPath(jsonl_path).with_suffix(IJSONL_SUFFIX).size() // 8 - 1

    @staticmethod
    def size(jsonl_path: EPath) -> int:
        with IJsonlIndexReader(jsonl_path) as reader:
            return reader[len(reader) - 1]


class IJsonlIndexWriter:
    def __init__(self, jsonl_path: EPath):
        self.final_name = jsonl_path.with_suffix(IJSONL_SUFFIX)
        self.tmp_name = jsonl_path.with_suffix(IJSONL_SUFFIX + ".tmp")
        self.ijsonl = self.tmp_name.open("wb")

    def append(self, offset: int):
        self.ijsonl.write(struct.pack("Q", offset))

    def close(self, finalize: bool = True):
        self.ijsonl.close()
        if finalize:
            self.tmp_name.move(self.final_name)
        else:
            self.tmp_name.unlink()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close(finalize=exc_val is None)
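The index layout is only implicit in the reader and writer above; the following standalone sketch (an inference from count_samples and size, not code from this commit) reproduces it: a flat array of 8-byte unsigned offsets, one per sample start, plus a trailing entry equal to the total jsonl file size. That trailing entry is why count_samples is size() // 8 - 1, and why the factory can compare the jsonl's own size against IJsonlIndexReader.size(path).

# Standalone sketch of the inferred .jsonl.idx layout; not part of this commit.
import struct

def build_index_bytes(jsonl_bytes: bytes) -> bytes:
    # One offset per sample start, plus the end-of-file offset as last entry.
    offsets = [0]
    for i, byte in enumerate(jsonl_bytes):
        if byte == 0x0A:             # a newline terminates a sample
            offsets.append(i + 1)    # the next sample starts right after it
    return b"".join(struct.pack("Q", o) for o in offsets)

data = b'{"a": 1}\n{"a": 2}\n'
idx = build_index_bytes(data)
assert len(idx) // 8 - 1 == 2                        # two samples, as count_samples computes
assert struct.unpack("Q", idx[-8:])[0] == len(data)  # last entry == file size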
@edataclass
class CacheEntry:
    ijsonl_index_reader: IJsonlIndexReader
    lookahead_offset: Optional[int] = None
    lookahead_byteoffset: Optional[int] = None


class CachedIJsonlOffsetReader:
    """
    This class is a high-level wrapper around IJsonlIndexReader that caches some
    of the recent lookups for faster access. It is designed for the case when
    you need to read multiple offsets from the same jsonl file.

    Args:
        cache_size: The number of entries to keep in the cache. By default, we keep 32.
    """

    def __init__(self, jsonl_file: Union[str, EPath], cache_size: int = 32):
        # Maps current_offset -> CacheEntry
        self.ijsonl_index_reader_cache: Dict[int, CacheEntry] = {}
        self.cache_size = cache_size
        self.jsonl_file = EPath(jsonl_file)

    def close(self):
        for cache_entry in self.ijsonl_index_reader_cache.values():
            cache_entry.ijsonl_index_reader.close()
        self.ijsonl_index_reader_cache.clear()

    def _find_or_create_entry(self, sample_offset: int) -> CacheEntry:
        """
        1. If we already have a key == sample_offset, return it.
        2. Otherwise, create a new entry or reuse the oldest entry.
        """
        # Direct hit in the cache?
        if sample_offset in self.ijsonl_index_reader_cache:
            return self.ijsonl_index_reader_cache[sample_offset]

        # We didn't find an existing entry. Create a new one.
        # Evict if needed.
        if len(self.ijsonl_index_reader_cache) >= self.cache_size:
            # Reuse the oldest entry
            oldest_key = next(iter(self.ijsonl_index_reader_cache))
            cache_entry = self.ijsonl_index_reader_cache.pop(oldest_key)
        else:
            new_reader = IJsonlIndexReader(self.jsonl_file)
            cache_entry = CacheEntry(ijsonl_index_reader=new_reader)

        self.ijsonl_index_reader_cache[sample_offset] = cache_entry
        return cache_entry

    def _get_ijsonl_byte_offset_with_entry(
        self,
        cache_entry: CacheEntry,
        sample_offset: int,
    ) -> Tuple[int, int]:
        """
        Return (start_byte_offset, length_to_next),
        possibly using per-entry lookahead for speed.
        """
        ijsonl_index_reader = cache_entry.ijsonl_index_reader

        # If offset=0, define the result as byte offset=0 for convenience
        if sample_offset == 0:
            result_byte_offset = 0
        elif sample_offset == cache_entry.lookahead_offset:
            # Reuse the previously cached byte offset from the lookahead
            assert cache_entry.lookahead_byteoffset is not None, (
                "Lookahead offset matched but no lookahead byte offset found."
            )
            result_byte_offset = cache_entry.lookahead_byteoffset
        else:
            # Normal random access
            result_byte_offset = ijsonl_index_reader[sample_offset]

        # Prepare the lookahead for (sample_offset+1)
        next_offset = sample_offset + 1
        try:
            cache_entry.lookahead_byteoffset = ijsonl_index_reader[next_offset]
            cache_entry.lookahead_offset = next_offset
        except IndexError:
            cache_entry.lookahead_offset = None
            cache_entry.lookahead_byteoffset = None

        # length = difference to the next offset, or 0 if none
        if cache_entry.lookahead_byteoffset is not None:
            length = cache_entry.lookahead_byteoffset - result_byte_offset
        else:
            length = 0

        return result_byte_offset, length

    def get_ijsonl_byte_offset(self, sample_offset: int = 0) -> Tuple[int, int]:
        """
        High-level API to get the byte offset and length for the given file & sample_offset.
        """
        # Find or create the suitable CacheEntry
        entry = self._find_or_create_entry(sample_offset)

        # Use (and update) the per-entry lookahead logic
        result_byte_offset, length = self._get_ijsonl_byte_offset_with_entry(
            entry, sample_offset
        )

        # Update cache entry with the new offset
        self.ijsonl_index_reader_cache.pop(sample_offset)
        if entry.lookahead_offset is not None:
            new_key = entry.lookahead_offset
            if new_key not in self.ijsonl_index_reader_cache:
                self.ijsonl_index_reader_cache[new_key] = entry
            else:
                # Already have this entry in the cache, so we can close the reader and use
                # the existing one.
                # TODO: We may actually want to keep multiple readers open, because multiple
                # sequences may point to the same offset.
                entry.ijsonl_index_reader.close()
        else:
            # No lookahead, so we can close the reader
            entry.ijsonl_index_reader.close()

        return result_byte_offset, length

    def __len__(self) -> int:
        if len(self.ijsonl_index_reader_cache) == 0:
            return IJsonlIndexReader.count_samples(self.jsonl_file)
        return len(next(iter(self.ijsonl_index_reader_cache.values())).ijsonl_index_reader) - 1

    def get_total_size(self) -> int:
        if len(self.ijsonl_index_reader_cache) == 0:
            self.ijsonl_index_reader_cache[0] = CacheEntry(
                ijsonl_index_reader=IJsonlIndexReader(self.jsonl_file)
            )
        reader = next(iter(self.ijsonl_index_reader_cache.values())).ijsonl_index_reader
        return reader[len(reader) - 1]
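A short usage sketch (not part of the commit) of the access pattern the per-entry lookahead is optimized for: after a call for offset k, the cache entry is re-keyed under k + 1, so strictly sequential calls reuse the cached byte offset instead of performing a second index read. The path is illustrative.

# Hedged sketch, not part of this commit.
offsets = CachedIJsonlOffsetReader("/data/corpus.jsonl")
try:
    spans = [offsets.get_ijsonl_byte_offset(k) for k in range(len(offsets))]
    # Each (byte_offset, byte_size) spans up to the next sample start; pass
    # both to IJsonlFile.next() (see below) to fetch the raw line.
finally:
    offsets.close()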
class IJsonlFile:
    """
    This class is a high-level wrapper around a binary file that allows for reading a jsonl file,
    with random access while keeping the file open.

    Usage:

        with open(filename, "rb") as fileobj:
            with IJsonlFile(fileobj=fileobj) as f:
                data = f.next(offset=101888, size=100)
                json.loads(data)

        # Or, if you want to read the whole file:
        with open(filename, "rb") as fileobj:
            with IJsonlFile(fileobj=fileobj) as f:
                while True:
                    data = f.next()
                    if data is None:
                        break
                    json.loads(data)

        # Or, equivalently, iterating over the whole file:
        with open(filename, "rb") as fileobj:
            with IJsonlFile(fileobj=fileobj) as f:
                for data in f:
                    json.loads(data)
    """

    def __init__(self, fileobj: BinaryIO):
        self.fileobj = fileobj

    def seek(self, offset: int):
        self.fileobj.seek(offset)

    def next(self, offset: int | None = None, size: int | None = None) -> bytes | None:
        if offset is not None and offset != self.fileobj.tell():
            self.fileobj.seek(offset)
        if size is None:
            entry = self.fileobj.readline()
            if entry == b"":
                return None
            return entry
        else:
            assert size > 0, "Size must contain at least the line terminator and a json object"
            data = self.fileobj.read(size)
            if data == b"":
                return None
            return data

    def __iter__(self) -> Generator[bytes, None, None]:
        while True:
            data = self.next()
            if data is None:
                break
            yield data

    def close(self):
        self.fileobj.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.fileobj.close()
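Putting the pieces of this file together, a hedged end-to-end sketch (not part of the commit) of random access to sample k of an indexed jsonl file; the helper name and path handling are illustrative.

# Hedged sketch, not part of this commit.
import json

def read_sample(jsonl_path: str, k: int):
    offsets = CachedIJsonlOffsetReader(jsonl_path)
    try:
        byte_offset, byte_size = offsets.get_ijsonl_byte_offset(k)
    finally:
        offsets.close()
    with open(jsonl_path, "rb") as fileobj:
        with IJsonlFile(fileobj=fileobj) as f:
            # A byte_size of 0 would trip next()'s size assert; fall back to
            # readline-style reading in that case.
            data = f.next(offset=byte_offset, size=byte_size or None)
    return json.loads(data)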