Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
Bw-bestperf
Qwen2.5-14B-Instruct_dcu-megatron
Commits
f356f546
Commit
f356f546
authored
Feb 04, 2026
by
maming
Browse files
Initial commit
parents
Pipeline
#3339
canceled with stages
Changes
346
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
5617 additions
and
0 deletions
+5617
-0
Megatron-Energon/src/megatron/energon/module_loader.py
Megatron-Energon/src/megatron/energon/module_loader.py
+71
-0
Megatron-Energon/src/megatron/energon/retry_stream.py
Megatron-Energon/src/megatron/energon/retry_stream.py
+94
-0
Megatron-Energon/src/megatron/energon/rng.py
Megatron-Energon/src/megatron/energon/rng.py
+172
-0
Megatron-Energon/src/megatron/energon/savable.py
Megatron-Energon/src/megatron/energon/savable.py
+21
-0
Megatron-Energon/src/megatron/energon/savable_loader.py
Megatron-Energon/src/megatron/energon/savable_loader.py
+1409
-0
Megatron-Energon/src/megatron/energon/source_info.py
Megatron-Energon/src/megatron/energon/source_info.py
+38
-0
Megatron-Energon/src/megatron/energon/state.py
Megatron-Energon/src/megatron/energon/state.py
+6
-0
Megatron-Energon/src/megatron/energon/task_encoder/__init__.py
...ron-Energon/src/megatron/energon/task_encoder/__init__.py
+38
-0
Megatron-Energon/src/megatron/energon/task_encoder/base.py
Megatron-Energon/src/megatron/energon/task_encoder/base.py
+1152
-0
Megatron-Energon/src/megatron/energon/task_encoder/cooking.py
...tron-Energon/src/megatron/energon/task_encoder/cooking.py
+125
-0
Megatron-Energon/src/megatron/energon/task_encoder/loader.py
Megatron-Energon/src/megatron/energon/task_encoder/loader.py
+277
-0
Megatron-Energon/src/megatron/energon/tools/__init__.py
Megatron-Energon/src/megatron/energon/tools/__init__.py
+2
-0
Megatron-Energon/src/megatron/energon/tools/analyze_debug.py
Megatron-Energon/src/megatron/energon/tools/analyze_debug.py
+623
-0
Megatron-Energon/src/megatron/energon/tools/checkpoint.py
Megatron-Energon/src/megatron/energon/tools/checkpoint.py
+357
-0
Megatron-Energon/src/megatron/energon/tools/info.py
Megatron-Energon/src/megatron/energon/tools/info.py
+119
-0
Megatron-Energon/src/megatron/energon/tools/lint.py
Megatron-Energon/src/megatron/energon/tools/lint.py
+123
-0
Megatron-Energon/src/megatron/energon/tools/mount.py
Megatron-Energon/src/megatron/energon/tools/mount.py
+407
-0
Megatron-Energon/src/megatron/energon/tools/prepare.py
Megatron-Energon/src/megatron/energon/tools/prepare.py
+413
-0
Megatron-Energon/src/megatron/energon/tools/preview.py
Megatron-Energon/src/megatron/energon/tools/preview.py
+140
-0
Megatron-Energon/src/megatron/energon/transforms/__init__.py
Megatron-Energon/src/megatron/energon/transforms/__init__.py
+30
-0
No files found.
Too many changes to show.
To preserve performance only
346 of 346+
files are displayed.
Plain diff
Email patch
Megatron-Energon/src/megatron/energon/module_loader.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
typing
import
Any
,
Callable
,
ClassVar
,
Dict
,
Optional
,
Union
from
megatron.energon.epathlib
import
EPath
class ModuleLoader:
    """Class providing access to python files to be loaded as modules."""

    #: Loaded modules. Global singleton, no need to load a module multiple times.
    _modules: ClassVar[Dict[EPath, Dict[str, Any]]] = {}

    def _load_module(self, path: Union[str, EPath]) -> None:
        """Internally loads a module and caches it in the global module cache.

        Args:
            path: The path of the python source file to load.
        """
        path = EPath(path)
        with path.open("r") as rf:
            src_code = rf.read()
        # Compile with the source url as filename so tracebacks point at the file.
        code = compile(src_code, path.url, "exec")
        mod = {}
        # NOTE: This executes arbitrary code from the given path; only load trusted files.
        # (`exec` is the idiomatic call for code compiled in "exec" mode.)
        exec(code, mod)
        self._modules[path] = mod

    def __getitem__(self, path: Union[str, EPath]) -> Any:
        """
        Gets a module by path.

        Args:
            path: The path to load

        Returns:
            Loaded module dict.
        """
        path = EPath(path)
        # Load once; subsequent accesses hit the class-level cache.
        if path not in self._modules:
            self._load_module(path)
        return self._modules[path]

    def get_function(
        self,
        module_fn: str,
        default_fn_name: Optional[str] = None,
        relative_path: Optional[EPath] = None,
    ) -> Callable:
        """
        Loads a function in the form `"module/submodule.py:function_name"`.

        Args:
            module_fn: The function string to load. Must contain the path of the python module
                and a function name separated by a colon (might be omitted if `default_fn_name` is
                specified).
            default_fn_name: Default name of the function if not given in `module_fn` string.
            relative_path: The relative parent path to the module. If not specified, the current
                working directory / absolute path is used.

        Returns:
            The function from the module

        Raises:
            ValueError: If no function name is given (neither in `module_fn` nor as default), or
                if the resolved object is not callable.
        """
        if ":" in module_fn:
            # rsplit, so the module path itself may contain colons (e.g. urls).
            module, fn_name = module_fn.rsplit(":", 1)
        else:
            if default_fn_name is None:
                raise ValueError("Function name must be specified")
            fn_name = default_fn_name
            module = module_fn
        if relative_path is not None:
            module = str(relative_path / module)
        fn = self[module][fn_name]
        # Explicit check instead of `assert`: must also hold when running with `-O`.
        if not callable(fn):
            raise ValueError(f"{module}:{fn_name} is not callable")
        return fn
Megatron-Energon/src/megatron/energon/retry_stream.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
types
import
TracebackType
from
typing
import
BinaryIO
,
Optional
,
Type
from
megatron.energon.epathlib
import
EPath
class RetryReadStream(BinaryIO):
    """A stream that retries reading from a file. Only supports reading bytes.

    The underlying file is lazily (re-)opened and re-positioned at the logical
    offset whenever an I/O error invalidates the current handle.
    """

    #: Path of the file to read; used to (re-)open the file on retries.
    _path: "EPath"
    #: The currently open file handle, or None if not (or no longer) open.
    _file: Optional[BinaryIO]
    #: The logical read position, kept valid across re-opens.
    _pos: int
    #: Total size of the file in bytes, determined once at construction.
    _size: int

    def __init__(self, path: "EPath"):
        """Construct a RetryReadStream. It reads only bytes from a file."""
        self._path = path
        self._file = None
        self._pos = 0
        self._size = path.size()

    def __enter__(self) -> "RetryReadStream":
        return self

    def __exit__(
        self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType
    ) -> None:
        self.close()

    def close(self) -> None:
        """Closes the underlying file handle, if open."""
        if self._file is not None:
            self._file.close()
            # Drop the handle so a later read() reopens instead of using a closed file.
            self._file = None

    def read(self, n: int = -1) -> bytes:
        """Reads up to ``n`` bytes (everything remaining if ``n == -1``).

        On ``IOError``, the file is reopened at the logical position and the read
        resumes; after 10 failed attempts the last error is re-raised.
        """
        buf = b""
        for retry in range(10):
            try:
                if self._file is None:
                    self._file = self._path.open("rb")
                    self._file.seek(self._pos)
                # Only request the still-missing bytes, so that resuming after a
                # partial read cannot fetch more than `n` bytes in total.
                remaining = -1 if n == -1 else n - len(buf)
                res = self._file.read(remaining)
                self._pos += len(res)
                buf += res
                if len(buf) == n or self._pos >= self._size:
                    # Return everything accumulated so far (including data read
                    # before a retry), not just the last chunk.
                    return buf
            except IOError:
                # The open itself may have failed, leaving no handle to close.
                if self._file is not None:
                    try:
                        self._file.close()
                    except IOError:
                        pass
                    self._file = None
                if retry == 9:
                    raise
        return buf

    def seek(self, offset: int, whence: int = 0) -> int:
        """Seeks to a position (clamped to ``[0, size]``); returns the new position."""
        if whence == 0:
            pass
        elif whence == 1:
            offset += self._pos
        elif whence == 2:
            offset += self._size
        else:
            raise ValueError(f"Invalid whence value: {whence}")
        offset = min(max(offset, 0), self._size)
        self._pos = offset
        try:
            if self._file is not None:
                self._file.seek(offset)
        except IOError:
            # Best-effort: the position is re-applied on the next read() after reopen.
            pass
        return self._pos

    def tell(self) -> int:
        """Returns the current logical read position."""
        return self._pos

    def isatty(self) -> bool:
        return False

    def readable(self) -> bool:
        return True

    def seekable(self) -> bool:
        return True

    def writable(self) -> bool:
        return False
Megatron-Energon/src/megatron/energon/rng.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
hashlib
import
random
from
typing
import
Any
,
List
,
Mapping
,
Optional
,
Sequence
,
TypeVar
import
numpy
import
torch
import
torch.distributed
import
torch.utils.data
from
megatron.energon.edataclass
import
edataclass
from
megatron.energon.savable
import
FlexState
,
Savable
from
megatron.energon.worker
import
WorkerConfig
T
=
TypeVar
(
"T"
)
class WorkerRng(Savable):
    """Helper class for getting a worker random generator, which is still in itself deterministic.

    If not in a worker, uses the global random generator's seed to initialize a new rng."""

    #: Worker configuration; provides the per-worker seed and worker assertions.
    worker_config: WorkerConfig
    #: Lazily created torch generator (created on first access of `rng`).
    _rng: Optional[torch.Generator] = None
    #: Serialized generator state pending to be applied on the next `rng` access.
    _restore_state: Optional[bytes] = None

    def __init__(self, worker_config: WorkerConfig):
        self.worker_config = worker_config

    @property
    def rng(self) -> torch.Generator:
        """The torch generator; lazily created inside the worker process, applying
        any pending restore state first."""
        if self._rng is None or self._restore_state is not None:
            # Creation/restore must happen inside the worker process.
            self.worker_config.assert_worker()
            self._rng = torch.Generator()
            if self._restore_state is not None:
                # Restore the exact generator state from the saved byte string.
                self._rng.set_state(
                    torch.frombuffer(
                        bytearray(self._restore_state),
                        dtype=torch.uint8,
                    ).clone()
                )
            else:
                # Restore to initial state (either due to zero sized states, or just initial state)
                self._rng.manual_seed(self.worker_config.worker_seed())
            # The pending restore state has now been consumed.
            self._restore_state = None
        return self._rng

    def randbelow(self, n: int) -> int:
        """Returns a random int in ``[0, n)``."""
        return torch.randint(0, n, (), generator=self.rng).item()

    def choice_idx(self, probs: torch.Tensor) -> int:
        """Samples an index weighted by ``probs`` (weights need not be normalized)."""
        if len(probs) == 1:
            return 0
        else:
            # Custom implementation of multinomial to ensure consistency
            # Torch changed their implementation of torch.multinomial in 2.7.0 and to be
            # consistent with any torch version, we use a custom implementation here instead.
            # This is anyways just a very simple case of multinomial, thus this should be fine.
            # Actually, benchmarks show that this is faster than torch.multinomial by a factor of
            # 10 even on CPU.
            cdf = torch.cumsum(probs, dim=0)
            # Scale by the total weight, so probs do not have to sum to 1.
            val = torch.rand(1, generator=self.rng) * cdf[-1]
            return torch.searchsorted(cdf, val).item()

    def choice(self, l: List[T], probs: Optional[torch.Tensor] = None) -> T:
        """Returns a random element of ``l``, optionally weighted by ``probs``."""
        if probs is None:
            return l[self.randbelow(len(l))]
        assert len(l) == len(probs)
        return l[self.choice_idx(probs)]

    def shuffle(self, l: List[T]) -> List[T]:
        """Returns a new list with shuffled entries"""
        p = torch.randperm(len(l), generator=self.rng)
        return [l[p[i]] for i in range(len(l))]

    def rand_pop(self, l: List[T]) -> T:
        """Removes and returns a random element of ``l`` (mutates ``l``)."""
        return l.pop(self.randbelow(len(l)))

    def save_state(self) -> FlexState:
        # NOTE(review): accessing `self.rng` lazily creates a generator, so the
        # `None` branch is effectively never taken — confirm intended.
        return FlexState(rng=None if self.rng is None else bytes(self.rng.get_state().tolist()))

    def restore_state(self, state: FlexState):
        # Only record the state here; it is applied lazily on the next `rng` access
        # (which must happen inside the worker process).
        if state["rng"] is None:
            self._restore_state = None
        else:
            self._restore_state = state["rng"]
@edataclass
class SystemRngState:
    """The state of the global random generators.

    Note that the data types of the internal RNG states are implementation details of the
    respective libraries and may change in the future.
    Python does not even specify the type in their docs. Hence we will allow arbitrary types,
    because all that matters is that we can save and restore them. We will not use the data
    anywhere else.
    """

    torch: Any  # Currently `torch.Tensor`
    numpy: Any  # Currently `dict[str, Any] | tuple[str, NDArray[uint32], int, int, float]`
    random: Any  # Currently a nested tuple

    def _hashable_value(self, value: Any) -> Any:
        """Recursively converts the (arbitrarily nested) rng state into a hashable value."""
        # Plain scalars (and None) are already hashable.
        if value is None or isinstance(value, (int, float, bool, str)):
            return value
        # Tensors / arrays: turn into nested lists, then recurse to get tuples.
        if isinstance(value, (torch.Tensor, numpy.ndarray)):
            return self._hashable_value(value.tolist())
        # Mappings become tuples of converted (key, value) pairs.
        if isinstance(value, Mapping):
            pairs = ((self._hashable_value(k), self._hashable_value(v)) for k, v in value.items())
            return tuple(pairs)
        # Note: str is a Sequence too, but was already handled above.
        if isinstance(value, Sequence):
            converted = [self._hashable_value(entry) for entry in value]
            return tuple(converted)
        raise ValueError(f"Cannot hash value of type {type(value)}: {value!r}")

    def __repr__(self):
        # If the hash is the same, the state is the same. Should suffice to identify the state.
        state_hash = hash(self._hashable_value((self.torch, self.numpy, self.random)))
        return f"SystemRngState(hash={state_hash})"
class SystemRng:
    """A class to seed, save or restore the global random generators.

    This affects torch, numpy and the standard library random module."""

    @staticmethod
    def seed(seed: int) -> None:
        """Seeds the global random generators."""
        # All three libraries are seeded with the same value.
        torch.manual_seed(seed)
        numpy.random.seed(seed)
        random.seed(seed)

    @staticmethod
    def save_state() -> SystemRngState:
        """Saves the global rng state for torch, numpy and random."""
        torch_state = torch.get_rng_state()
        numpy_state = numpy.random.get_state()
        python_state = random.getstate()
        return SystemRngState(torch=torch_state, numpy=numpy_state, random=python_state)

    @staticmethod
    def restore_state(state: SystemRngState) -> None:
        """Restores the global rng state for torch, numpy and random."""
        torch.set_rng_state(state.torch)
        numpy.random.set_state(state.numpy)
        random.setstate(state.random)

    @staticmethod
    def get_seed_from_args(*args: Any) -> int:
        """Deterministically generates a seed from the given arguments.

        The str() representation of each arg is used."""
        # Use a deterministic hash function to compute the seed
        joined = "|".join(str(obj) for obj in args)
        hash_digest = hashlib.sha1(joined.encode("utf-8")).digest()
        # We use the first 4 bytes of the hash as the seed and fix the endianness
        return int.from_bytes(hash_digest[:4], byteorder="big")

    @staticmethod
    def seed_args(*args: Any) -> None:
        """Seeds the global random generators deterministically from the given arguments."""
        SystemRng.seed(SystemRng.get_seed_from_args(*args))
Megatron-Energon/src/megatron/energon/savable.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
abc
import
ABC
,
abstractmethod
from
megatron.energon.state
import
FlexState
class Savable(ABC):
    """Interface for objects whose state can be saved to and restored from a
    :class:`FlexState`."""

    @abstractmethod
    def save_state(self) -> FlexState:
        """Dumps the state of the object into some serializable format.

        Should only dump the part that is persistable for energon usage.
        """
        ...

    @abstractmethod
    def restore_state(self, state: FlexState):
        """Restores the state of the object.

        If state is None, the object is restored to its default state."""
        ...
Megatron-Energon/src/megatron/energon/savable_loader.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
gc
import
multiprocessing
import
queue
import
threading
import
time
import
traceback
import
weakref
from
functools
import
partial
from
typing
import
(
Any
,
ClassVar
,
Dict
,
Generic
,
Iterator
,
List
,
Optional
,
Sequence
,
Tuple
,
TypeVar
,
Union
,
)
import
numpy
as
np
import
torch
import
torch.distributed
import
torch.multiprocessing
from
torch.utils.data
import
DataLoader
,
IterableDataset
from
megatron.energon.cache
import
CachePool
from
megatron.energon.edataclass
import
edataclass
from
megatron.energon.errors
import
deprecated
from
megatron.energon.flavors.base_dataset
import
(
FlexState
,
SavableDataset
,
State
,
add_sample_restore_key
,
)
from
megatron.energon.rng
import
SystemRng
,
SystemRngState
from
megatron.energon.worker
import
WorkerConfig
from
megatron.energon.wrappers.base
import
BaseWrapperDataset
from
megatron.energon.wrappers.batch_dataset
import
BatchDataset
from
megatron.energon.wrappers.gc_dataset
import
GC_DEFAULT_EVERY_N_ITER
,
GcDataset
,
gc_init_worker
from
megatron.energon.wrappers.log_sample_dataset
import
default_get_keys
from
megatron.energon.wrappers.watchdog_dataset
import
WatchdogDataset
T
=
TypeVar
(
"T"
)
def _init_worker(seed_per_worker: List[int], worker_id: int):
    """Initializes the worker process.

    Sets the random seeds and prepare EPath for the forked worker process.
    """
    # Per-worker GC setup must run before anything else in the worker.
    gc_init_worker(worker_id)
    # Each worker has its own precomputed seed for all RNG libraries.
    SystemRng.seed(seed_per_worker[worker_id])
class SimpleSavableDatasetWrapper(BaseWrapperDataset[T, Tuple[int, int, T]], Generic[T]):
    """Wrapper for non-multiprocessing savable datasets. Restarts the inner dataset. This class is
    not intended to be used directly."""

    #: The cache pool to use for the dataset.
    cache_pool: CachePool

    #: Set to True when a state was restored; signals __iter__ to restart the inner iterator.
    _state_restored: bool
    #: Index of the next sample to be emitted.
    _sample_index: int

    _savable_fields = ("_sample_index",)

    def __init__(
        self,
        dataset: SavableDataset[T],
        worker_config: WorkerConfig,
        cache_pool: CachePool,
    ):
        """
        Args:
            dataset: The dataset to wrap.
            worker_config: The worker config to use for the dataset.
            cache_pool: The cache pool to use for the dataset.
        """
        super().__init__(dataset, worker_config=worker_config)
        self.cache_pool = cache_pool
        self.reset_state_own()

    def reset_state_own(self) -> None:
        # Reset only this wrapper's own state (not the inner dataset's).
        self._sample_index = 0
        self._state_restored = False

    def len_worker(self, worker_idx: int | None = None) -> int:
        # Delegates to the wrapped dataset.
        return self.dataset.len_worker(worker_idx)

    @property
    def __len__(self):
        # Note: This disables hasattr(self, "__len__"), because that attr will
        raise AttributeError("Disabled direct length access to avoid DataLoader warnings.")

    def __iter__(self):
        """Iterates the inner dataset, yielding ``(worker_id, sample_index, sample)`` tuples.

        If a state is restored mid-iteration (sets `_state_restored`), the inner
        iterator is restarted from the restored state."""
        self._state_restored = True
        worker_id = self.worker_config.rank_worker_id()
        global_worker_id = self.worker_config.global_worker_id()
        while self._state_restored:
            self._state_restored = False
            # The worker context must be active while fetching from the inner dataset.
            self.worker_config.worker_activate(self._sample_index, cache_pool=self.cache_pool)
            worker_active = True
            try:
                for src_data in self.dataset:
                    # Deactivate while the sample is handed out to the consumer.
                    self.worker_config.worker_deactivate()
                    worker_active = False
                    sample_index = self._sample_index
                    src_data = add_sample_restore_key(
                        src_data,
                        global_worker_id,
                        sample_index,
                        src=self,
                    )
                    self._sample_index += 1
                    yield worker_id, sample_index, src_data
                    if self._state_restored:
                        # Restart iterator after restore
                        break
                    # Re-activate before fetching the next sample.
                    self.worker_config.worker_activate(
                        self._sample_index, cache_pool=self.cache_pool
                    )
                    worker_active = True
            finally:
                # Make sure the worker context is not left active on exit.
                if worker_active:
                    self.worker_config.worker_deactivate()

    def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T:
        """Restores a single sample from its restore key (as attached in __iter__)."""
        # The key is prefixed with (class name, global worker id, sample index).
        id, global_worker_id, sample_idx = restore_key[:3]
        assert id == type(self).__name__
        restore_key = restore_key[3:]
        self.worker_config.worker_activate(
            sample_idx,
            override_global_rank=global_worker_id,
            cache_pool=self.cache_pool,
        )
        try:
            return add_sample_restore_key(
                self.dataset.restore_sample(restore_key),
                global_worker_id,
                sample_idx,
                src=self,
            )
        finally:
            self.worker_config.worker_deactivate()

    def config(self) -> Dict[str, Any]:
        # Delegates to the wrapped dataset.
        return self.dataset.config()

    def __str__(self):
        return f"SimpleSavableDatasetWrapper(dataset={self.dataset})"
@edataclass
class SavableDatasetState(State):
    """State of the dataset wrapper. It stores the global random states and the index of the next
    sample to be returned from the dataset. This class is not intended to be used directly, but by
    :class:`megatron.energon.SavableDatasetWrapper`."""

    #: The state of all the system random number generators
    rng: SystemRngState
    #: The state of the savable dataset
    dataset_state: FlexState
    #: Index of the next sample to be returned from the dataset
    sample_index: int

    def __repr__(self):
        # Omit dataset_state (potentially large); rng's repr is a compact hash.
        return f"SavableDatasetState(rng={self.rng!r}, sample_index={self.sample_index})"
@edataclass
class SavableCheckpoint:
    """Checkpoint data for :class:`megatron.energon.SavableDatasetWrapper`. An instance is created
    regularly to be able to save the state of the dataset wrapper before the currently emitted
    sample.
    """

    #: The state of the wrapper
    state: Optional[SavableDatasetState]
    #: The time at which the checkpoint was created
    checkpoint_time: float
    #: Index of the next sample to be returned from the dataset after restoring the checkpoint
    sample_index: int
@edataclass
class SavableDatasetCheckpoint(State):
    """Checkpoint data for :class:`megatron.energon.SavableDatasetWrapper`. The checkpoint state
    represents a state before that checkpoint, with an offset (i.e. samples to be skipped)."""

    #: The state of the wrapper at the sample index when the checkpoint was created.
    state: Optional[SavableDatasetState]
    #: Offset of the checkpoint to the actual sample index to be restored.
    offset: int
class SavableDatasetWrapper(IterableDataset[Tuple[int, int, T]], Generic[T]):
    """Internal class for wrapping a savable dataset for a worker process. Provides communication
    with the :class:`megatron.energon.SavableDataLoader`. This class is not intended to be used directly.
    See :class:`megatron.energon.SavableDataLoader` for more information."""

    #: The wrapped dataset
    dataset: SavableDataset[T]
    #: The configuration of the worker process
    worker_config: WorkerConfig
    #: The time interval in seconds to wait at minimum between two checkpoints
    checkpoint_every_sec: float
    #: The minimum number of samples to be emitted between two checkpoints. Should be `number of
    # workers * 2`.
    checkpoint_every_min_n_samples: int
    #: The number of checkpoints to keep in memory, before discarding. Should be 2.
    n_checkpoints: int
    #: The cache pool to use for the dataset.
    cache_pool: CachePool

    #: The queue of the worker process to receive commands from the `SavableDataLoader`.
    _cmd_queues: List[torch.multiprocessing.Queue]
    #: The queue of the worker process to send results to the `SavableDataLoader`.
    _result_queues: List[torch.multiprocessing.Queue]

    #: Index of the next sample to be emitted.
    _sample_index: int = 0
    #: Offset applied to worker ids (set globally in each worker in __iter__).
    _worker_offset: int = 0
    #: Recent checkpoints (oldest first), bounded by n_checkpoints.
    _last_checkpoints: List[SavableCheckpoint]
    # NOTE(review): class-level mutable default is shared until __init__ overwrites it
    # per instance — confirm no access happens before __init__.
    _workers_restore_from: List[Optional[SavableDatasetState]] = list()
    #: Number of samples each worker still has to skip after restoring a checkpoint.
    _workers_skip_samples: List[int]
    #: Whether the command thread should keep running.
    _running: bool = False
    #: Lock serializing command execution against iteration.
    _command_lock: Optional[threading.RLock] = None
    #: The background thread processing commands from the main process.
    _cmd_thread: Optional[threading.Thread] = None

    def __init__(
        self,
        dataset: SavableDataset[T],
        worker_config: WorkerConfig,
        checkpoint_every_sec: float,
        checkpoint_every_min_n_samples: int,
        n_checkpoints: int = 2,
        *,
        cmd_queues: List[torch.multiprocessing.Queue],
        result_queues: List[torch.multiprocessing.Queue],
        cache_pool: CachePool,
    ):
        """
        Create the savable dataset wrapper for multiprocessing data loading.

        Args:
            dataset: The dataset to wrap
            worker_config: The worker config as used by all datasets
            checkpoint_every_sec: The time interval in seconds to wait at minimum between two
                checkpoints.
            checkpoint_every_min_n_samples: The minimum number of samples to be emitted between
                two checkpoints. Should be `number of workers * 2`.
            n_checkpoints: Number of checkpoints to keep.
            cmd_queues: The command queues for communicating with the worker processes.
            result_queues: The result queues for communicating with the worker processes.
            cache_pool: The cache pool to use for the dataset.
        """
        # At least one (virtual) worker, even if num_workers == 0.
        num_workers = max(worker_config.num_workers, 1)
        self.dataset = dataset
        self.worker_config = worker_config
        self.checkpoint_every_sec = checkpoint_every_sec
        self.checkpoint_every_min_n_samples = checkpoint_every_min_n_samples
        self.n_checkpoints = n_checkpoints
        # Sentinel checkpoint with sample_index=-1 representing "before the first sample".
        self._last_checkpoints = [
            SavableCheckpoint(state=None, checkpoint_time=time.perf_counter(), sample_index=-1)
        ]
        self._workers_restore_from = [None] * num_workers
        self._workers_skip_samples = [0] * num_workers
        self._cmd_queues = cmd_queues
        self._result_queues = result_queues
        self.cache_pool = cache_pool

    @staticmethod
    def _command_thread(self: "SavableDatasetWrapper"):
        """The internal thread, which processes the command and result queues. This thread is
        static, because `self` is actually passed as weakref proxy, to avoid keeping the dataset
        alive via the thread.
        """
        # print(f"{id(self)}:{multiprocessing.current_process().ident} Worker command thread starting")
        assert self._command_lock is not None
        try:
            while self._running:
                try:
                    # Poll with a timeout so that _running is re-checked regularly.
                    cmd_args = self._cmd_queues[self._worker_id].get(timeout=1)
                except queue.Empty:
                    continue
                # print(f"recv cmd {cmd_args}")
                # Hold the lock, so commands never run concurrently with iteration.
                with self._command_lock:
                    cmd = cmd_args[0]
                    if cmd is None:
                        # None is the shutdown command.
                        break
                    try:
                        # Commands are method names, dispatched with the remaining args.
                        fn = getattr(self, cmd)
                        self._result_queues[self._worker_id].put(
                            {self._worker_id: fn(*cmd_args[1:])}
                        )
                        # print(f"result sent")
                    except Exception as e:
                        traceback.print_exc()
                        # Send the exception back, so the main process can re-raise it.
                        self._result_queues[self._worker_id].put({self._worker_id: e})
                        # print(f"exc sent")
        except BaseException:
            traceback.print_exc()
            raise
        finally:
            pass
            # print(f"{id(self)}:{multiprocessing.current_process().ident} Worker command thread closing")

    def len_worker(self, worker_idx: int | None = None) -> int:
        # Delegates to the wrapped dataset.
        return self.dataset.len_worker(worker_idx)

    def len_rank(self):
        # Delegates to the wrapped dataset.
        return self.dataset.len_rank()

    @property
    def __len__(self):
        # Note: This disables hasattr(self, "__len__"), because that attr will
        raise AttributeError("Disabled direct length access to avoid DataLoader warnings.")

    def __del__(self):
        # Shut down the command thread cleanly if it was started.
        if self._cmd_thread is not None:
            # print(f"{id(self)}:{multiprocessing.current_process().ident} Closing cmd thread")
            self._running = False
            self._cmd_thread.join()
            self._command_lock = None
            self._cmd_thread = None
            # print(f"{id(self)}:{multiprocessing.current_process().ident} Cmd thread closed")

    def __iter__(self):
        """Iterates the wrapped dataset inside the worker process, yielding
        ``(worker_id, sample_index, sample)`` tuples, restoring / skipping samples as
        configured via `restore_checkpoint` and storing checkpoints regularly."""
        # First: Set the worker offset globally for the current worker
        WorkerConfig.worker_id_offset = self._worker_offset
        self._worker_id = self.worker_config.rank_worker_id()
        global_worker_id = self.worker_config.global_worker_id()
        if self._cmd_thread is None:
            # Lazily start the command thread on first iteration in this worker.
            self._running = True
            self._command_lock = threading.RLock()
            # Pass a weakref proxy, so the thread does not keep `self` alive.
            weakref_self = weakref.proxy(self)
            self._cmd_thread = threading.Thread(
                target=SavableDatasetWrapper._command_thread,
                name="command_thread",
                args=(weakref_self,),
                daemon=True,
            )
            self._cmd_thread.start()
            # atexit.register(lambda: weakref_self.__del__())
        try:
            assert self._command_lock is not None
            with self._command_lock:
                if self._workers_restore_from[self._worker_id] is not None:
                    # A checkpoint was set to restore from; apply it to the dataset and self.
                    my_state = self._workers_restore_from[self._worker_id]
                    my_ds_state = my_state.dataset_state
                    assert my_state is not None
                    if my_ds_state is None:
                        self.dataset.reset_state_deep()
                    else:
                        self.dataset.restore_state(my_ds_state)
                    self._restore_state(my_state)
                    # Consume the pending restore state.
                    self._workers_restore_from[self._worker_id] = None
                else:
                    # Store the initial state of the worker if we stop before the first sample
                    self._store_checkpoint()
                # If skipping, also restart the iterator to reach the start of the restored
                # checkpoint
                last_was_skip = True
                while last_was_skip:
                    dataset_has_samples = False
                    self.worker_config.worker_activate(
                        self._sample_index, cache_pool=self.cache_pool
                    )
                    worker_active = True
                    try:
                        for src_data in self.dataset:
                            self.worker_config.worker_deactivate()
                            worker_active = False
                            dataset_has_samples = True
                            if self._workers_skip_samples[self._worker_id] > 0:
                                # Skip ahead to reach the start of the restored checkpoint
                                # print(f"Skip [{self._sample_index}:{self._worker_id}] {src_data}")
                                self._workers_skip_samples[self._worker_id] -= 1
                                self._sample_index += 1
                                last_was_skip = True
                                self.worker_config.worker_activate(
                                    self._sample_index, cache_pool=self.cache_pool
                                )
                                worker_active = True
                                continue
                            last_was_skip = False
                            sample_index = self._sample_index
                            add_sample_restore_key(
                                src_data,
                                global_worker_id,
                                sample_index,
                                src=self,
                            )
                            self._sample_index += 1
                            self._store_checkpoint()
                            try:
                                self._command_lock.release()
                                # print(f"{id(self)}:{multiprocessing.current_process().ident} Lock released")
                                # Commands may be executed only when data was yielded, not during
                                # iteration fetching.
                                # print(f"Yield next data [{sample_index}:{self._worker_id}] {src_data}")
                                yield self._worker_id, sample_index, src_data
                            finally:
                                # print(f"{id(self)}:{multiprocessing.current_process().ident} Lock acquiring")
                                self._command_lock.acquire()
                                # print(f"{id(self)}:{multiprocessing.current_process().ident} Lock acquired")
                            self.worker_config.worker_activate(
                                self._sample_index, cache_pool=self.cache_pool
                            )
                            worker_active = True
                    finally:
                        if worker_active:
                            self.worker_config.worker_deactivate()
                    # If the dataset is empty, don't try again and again
                    if not dataset_has_samples:
                        break
        finally:
            # print(f"{id(self)}:{multiprocessing.current_process().ident} Worker iter closing")
            # Always store a final checkpoint (it's likely to be saved)
            self._store_checkpoint(force=True)

    def _store_checkpoint(self, force: bool = False) -> None:
        """
        Internally create a checkpoint for the current state. This is required to store states
        from the past, which have already been yielded here, but not yet been retrieved from the
        intermediate queues.

        Args:
            force: If true, ignore time or frequency condition.
        """
        # Store if forced, or if both enough time has passed AND enough samples were
        # emitted since the last checkpoint, or if at the very beginning.
        if (
            force
            or (
                self._last_checkpoints[-1].checkpoint_time + self.checkpoint_every_sec
                < time.perf_counter()
                and self._last_checkpoints[-1].sample_index + self.checkpoint_every_min_n_samples
                <= self._sample_index
            )
            or self._sample_index <= 1
        ):
            # print(f"Storing checkpoint at {self._worker_id}:{self._sample_index}")
            self._last_checkpoints.append(
                SavableCheckpoint(
                    state=self._save_state(),
                    checkpoint_time=time.perf_counter(),
                    sample_index=self._sample_index,
                )
            )
            # Bound the number of retained checkpoints, dropping the oldest.
            if len(self._last_checkpoints) > self.n_checkpoints:
                self._last_checkpoints.pop(0)

    def _save_state(self) -> SavableDatasetState:
        """Saves the internal state"""
        return SavableDatasetState(
            rng=SystemRng.save_state(),
            dataset_state=self.dataset.save_state(),
            sample_index=self._sample_index,
        )

    def _restore_state(self, state: SavableDatasetState) -> None:
        """Restores the internal worker state"""
        assert torch.utils.data.get_worker_info() is not None, "Can only restore in worker process"
        if state.rng is None:
            # No rng state saved: reseed deterministically from the worker's initial seed.
            SystemRng.seed(torch.initial_seed() & 0xFFFFFFFF)
        else:
            SystemRng.restore_state(state.rng)
        self._sample_index = state.sample_index
        # The restored state becomes the single known checkpoint.
        self._last_checkpoints = [
            SavableCheckpoint(
                state=self._save_state(),
                checkpoint_time=time.perf_counter(),
                sample_index=self._sample_index,
            )
        ]

    def get_checkpoint(self, last_sample_indexes: List[int]) -> SavableDatasetCheckpoint:
        """
        Get a checkpoint given the last emitted sample indexes for all workers.

        Args:
            last_sample_indexes: The last emitted sample indexes for all workers.

        Returns:
            The found checkpoint including the offset to the next sample index
        """
        # The next sample to be emitted by this worker.
        sample_index = last_sample_indexes[self._worker_id] + 1
        # Find the newest checkpoint at or before that sample index.
        for checkpoint in reversed(self._last_checkpoints):
            if checkpoint.sample_index <= sample_index:
                # print(f"Found cp for {sample_index} at {checkpoint.sample_index}")
                return SavableDatasetCheckpoint(
                    state=checkpoint.state,
                    offset=sample_index - checkpoint.sample_index,
                )
        # Immediate save after restore
        if len(self._last_checkpoints) == 0:
            return SavableDatasetCheckpoint(
                state=self._workers_restore_from[self._worker_id],
                offset=self._workers_skip_samples[self._worker_id],
            )
        raise ValueError("No checkpoint found")

    def restore_checkpoint(
        self,
        worker_states: Optional[List[SavableDatasetCheckpoint]],
        worker_offset: int,
    ) -> None:
        """
        Restores the merged checkpoint from all worker processes.

        Args:
            worker_states: The state to restore for each worker
            worker_offset: The offset of the last worker which has emitted a sample. This will be
                set in all worker processes to ensure the right worker starts as first.
        """
        assert torch.utils.data.get_worker_info() is None, "Cannot restore in worker process"
        num_workers = max(self.worker_config.num_workers, 1)
        if worker_states is None:
            # Reset to the initial (unrestored) state.
            self._workers_restore_from = [None] * num_workers
            assert worker_offset == 0
            self._worker_offset = 0
            self._workers_skip_samples = [0] * num_workers
        else:
            assert isinstance(worker_states, list)
            assert len(worker_states) == num_workers
            assert isinstance(worker_states[0], SavableDatasetCheckpoint)
            self._worker_offset = worker_offset
            # Tear the state_list apart (which has len=num_workers)
            # and store the states in the internal arrays
            self._workers_restore_from = [state.state for state in worker_states]
            self._workers_skip_samples = [state.offset for state in worker_states]

    def get_initial_checkpoint(self) -> Optional[List[SavableDatasetCheckpoint]]:
        """
        Get the initial checkpoint for all worker processes if they have not started yet.

        Returns:
            The initial checkpoint for all worker processes and the worker offset.
        """
        assert torch.utils.data.get_worker_info() is None, (
            "Cannot get initial checkpoint in worker process"
        )
        if all(s is None for s in self._workers_restore_from):
            assert all(s == 0 for s in self._workers_skip_samples)
            # Initial state, no checkpoint
            return None
        # Re-assemble the per-worker checkpoints from the internal arrays.
        return [
            SavableDatasetCheckpoint(
                state=state,
                offset=offset,
            )
            for state, offset in zip(self._workers_restore_from, self._workers_skip_samples)
        ]

    def can_restore_sample(self) -> bool:
        # Delegates to the wrapped dataset.
        return self.dataset.can_restore_sample()

    def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T:
        """Restores a single sample from its restore key (as attached in __iter__)."""
        # The key is prefixed with (class name, global worker id, sample index).
        id, global_worker_id, sample_idx = restore_key[:3]
        assert id == type(self).__name__
        restore_key = restore_key[3:]
        self.worker_config.worker_activate(sample_idx, override_global_rank=global_worker_id)
        try:
            return add_sample_restore_key(
                self.dataset.restore_sample(restore_key),
                global_worker_id,
                sample_idx,
                src=self,
            )
        finally:
            self.worker_config.worker_deactivate()

    def config(self) -> Dict[str, Any]:
        # Delegates to the wrapped dataset.
        return self.dataset.config()

    def __str__(self):
        return f"SavableDatasetWrapper(dataset={self.dataset})"
@edataclass
class SavableDataLoaderState(State):
    """Saved state of the :class:`megatron.energon.SavableDataLoader`. Contains the state for all worker
    processes of a single rank."""

    #: The internal state of the dataset (for each worker process)
    worker_states: List[Union[SavableDatasetCheckpoint, FlexState]]

    #: Which worker will be the next to emit a sample. Used to restore the proper order
    next_worker_id: int

    #: The micro batch size that was used, if available.
    #: On restore, this is used to compare the new and old micro batch size.
    micro_batch_size: Optional[int]
class
SavableDataLoader
(
DataLoader
[
T
],
Generic
[
T
]):
"""DataLoader that supports saving and restoring the state of the dataset.
When restoring, the dataloader and dataset must be instantiated with the exactly same
parameters.
How this works (for no worker processes)
----------------------------------------
1. The state of the dataset is saved using :meth:`megatron.energon.SavableDataset.save_state`
2. (for compatibility) The state of the dataset is converted to using inner arrays using
:meth:`megatron.energon.SavableDataset.merge_states`.
3. The state can be restored using :meth:`megatron.energon.SavableDataset.restore_state` given the
previously saved (and merged) state.
How this works (for worker processes)
-------------------------------------
- First issue is, that worker processes work with internal queues between processes to pass
loaded samples to the main process (also to perform collating). This means that the whole
state of the dataset is not directly accessible from the main process.
- To solve this issue, the dataset regularly saves a checkpoint of its state to be able to
resume from that state (and skip the samples that have already been yielded).
- To have a consistent state, the sample index from the latest yielded samples is saved for all
worker instances. Thus, the main process knows exactly which sample indexes should come next
from which worker.
- Internally, pytorch iterates through the workers in order to retrieve the next worker's
samples. Unfortunately, that next worker index cannot be restored in pytorch's dataloader,
thus the workers are shifted internally by that offset
(see :attr:`megatron.energon.WorkerConfig.worker_id_offset`).
1. The dataset is wrapped in a :class:`megatron.energon.SavableDatasetWrapper`. This allows the main
process to communicate with the worker and send commands to the workers and retrieve the
results.
2. The state of the dataset is saved using
:meth:`megatron.energon.SavableDatasetWrapper.get_checkpoint`. This gives the last checkpoint
from the requested sample index and stores the offset (i.e. number of samples to skip) from
that checkpoint.
3. The state is merged using :meth:`megatron.energon.SavableDatasetWrapper.merge_checkpoints`. This
merges the states of all workers and returns a single state that can be used to restore the
state of the dataset.
4. The state can be restored using :meth:`megatron.energon.SavableDatasetWrapper.restore_state`
before a worker is started, such that all workers initially receive the same state array.
The worker firstly sets the worker index offset, then uses its (shifted) own index to get its
required state from the merged state array.
"""
#: The worker config
worker_config
:
WorkerConfig
#: The wrapped dataset. For multiprocessing, this is a :class:`megatron.energon.SavableDatasetWrapper`
dataset
:
Union
[
SavableDatasetWrapper
[
T
],
SimpleSavableDatasetWrapper
[
T
]]
#: The global ID counter
_next_id
:
ClassVar
[
int
]
=
0
#: Class instance id
id
:
int
=
0
#: The queues used to send commands to the workers
cmd_queues
:
List
[
torch
.
multiprocessing
.
Queue
]
#: The queues used to receive results from the workers
result_queues
:
List
[
torch
.
multiprocessing
.
Queue
]
#: Instance of the current data iterator. There shall be only one active iterator, such that the
# dataset is not iterated multiple times in parallel. The state will continue between epochs.
_epoch_iterator
:
Optional
[
Iterator
[
T
]]
=
None
#: Whether the dataloader has running workers.
_has_workers
:
bool
=
False
#: The index of the current worker. -1 if not started yet.
_worker_sample_counters
:
List
[
int
]
#: Id of the next worker to retrieve data from
_next_worker_id
:
int
=
0
#: Global index of the last yielded sample
_global_sample_idx
:
int
=
0
#: Current iterator index of the last yielded sample
_sample_idx
:
int
=
0
def __init__(
    self,
    dataset: SavableDataset[T],
    *,
    checkpoint_every_sec: float = 60,
    checkpoint_every_min_n_samples: Optional[int] = None,
    n_checkpoints: Optional[int] = None,
    gc_collect_every_n_steps: int = GC_DEFAULT_EVERY_N_ITER,
    gc_freeze_at_start: bool = True,
    prefetch_factor: int = 2,
    cache_pool: Optional[CachePool] = None,
    watchdog_timeout_seconds: Optional[float] = 60,
    watchdog_initial_timeout_seconds: Optional[float] = None,
    fail_on_timeout: bool = False,
):
    """
    Create the dataloader supporting saving and restoring the state.

    Args:
        dataset: The dataset to load. Its `worker_config` is used for this loader.
        checkpoint_every_sec: This is the time in seconds after which a checkpoint is saved.
            It may take the same duration to restore a checkpoint, but introduces additional
            overhead during reading data from the dataset, so this should be chosen accordingly.
            Only applies if using workers.
        checkpoint_every_min_n_samples: Overwrites the minimum number of samples between
            checkpoints. Defaults to `number of workers * 2`. Only applies if using workers.
        n_checkpoints: The number of checkpoints to keep in memory. Only applies if using
            workers. If None, computes a suitable value.
        gc_collect_every_n_steps: The number of steps after which the garbage collector is
            called. As we're usually handling large (but few) tensors here, and the python
            garbage collection is already full of objects just by importing, this can improve
            the memory footprint quite a lot, and may even be necessary to avoid memory
            overflow.
        gc_freeze_at_start: If true, the garbage collector is frozen at the start of the worker
            processes. This improves the garbage collection performance by a lot.
            In rare cases, this may cause issues and can be disabled. Keep enabled if you
            experience no issues.
        prefetch_factor: Passed through to the torch DataLoader (only with workers); must
            be small enough that `prefetch_factor * num_workers + 1 <= n_checkpoints`.
        cache_pool: If set, the cache pool to use for the dataset.
        watchdog_timeout_seconds: The timeout in seconds. If None, the watchdog is disabled.
        watchdog_initial_timeout_seconds: The initial timeout in seconds. If None, the timeout is the same as watchdog_timeout_seconds.
        fail_on_timeout: If True, stops the whole process upon timeout, after printing a stack trace.
    """
    self.worker_config = dataset.worker_config
    self.id = self.next_id()
    # Wrap the dataset: innermost watchdog (detects stalls), then optional GC control.
    dataset = WatchdogDataset(
        dataset,
        worker_config=self.worker_config,
        timeout_seconds=watchdog_timeout_seconds,
        initial_timeout_seconds=watchdog_initial_timeout_seconds,
        fail_on_timeout=fail_on_timeout,
    )
    if gc_collect_every_n_steps > 0:
        dataset = GcDataset(
            dataset,
            worker_config=self.worker_config,
            every_n_iter=gc_collect_every_n_steps,
            freeze=gc_freeze_at_start,
        )
    # One command/result queue pair per worker process for out-of-band communication.
    self.cmd_queues = [multiprocessing.Queue() for _ in range(self.worker_config.num_workers)]
    self.result_queues = [
        multiprocessing.Queue() for _ in range(self.worker_config.num_workers)
    ]
    num_procs = max(self.worker_config.num_workers, 1)
    if n_checkpoints is None:
        n_checkpoints = prefetch_factor * num_procs + 1
    if self.worker_config.num_workers > 0:
        if checkpoint_every_min_n_samples is None:
            checkpoint_every_min_n_samples = self.worker_config.num_workers * 2
        dataset = SavableDatasetWrapper(
            dataset,
            self.worker_config,
            checkpoint_every_sec=checkpoint_every_sec,
            checkpoint_every_min_n_samples=checkpoint_every_min_n_samples,
            n_checkpoints=n_checkpoints,
            cmd_queues=self.cmd_queues,
            result_queues=self.result_queues,
            cache_pool=cache_pool,
        )
    else:
        dataset = SimpleSavableDatasetWrapper(dataset, self.worker_config, cache_pool=cache_pool)
    # -1 marks "no sample emitted yet" for each worker slot.
    self._worker_sample_counters = [-1] * num_procs
    kwargs = {}
    if self.worker_config.num_workers > 0:
        kwargs["persistent_workers"] = True
        kwargs["prefetch_factor"] = prefetch_factor
        # Assert that prefetch_factor works well with num_checkpoints.
        # This ensures that the oldest checkpoint is old enough to cover
        # all the buffered samples in the torch dataloader.
        assert prefetch_factor * num_procs + 1 <= n_checkpoints, (
            "When increasing prefetch_factor, also increase n_checkpoints, so that "
            "the number of checkpoints is at least as large as num_workers * prefetch_factor + 1"
        )
    # Compute seeds for each worker, based on current rank
    seed_per_worker = [
        self.worker_config.worker_seed(i) for i in range(self.worker_config.num_workers)
    ]
    super().__init__(
        dataset,
        batch_size=None,
        shuffle=False,
        num_workers=self.worker_config.num_workers,
        pin_memory=True,
        worker_init_fn=partial(_init_worker, seed_per_worker),
        **kwargs,
    )
    if self.worker_config.should_log(level=1):
        self.worker_config.worker_log(
            {
                "t": "SavableLoader.__init__",
                "r": self.worker_config.rank,
                "w": None,
                "id": self.id,
                "config": dataset.config(),
            }
        )
@staticmethod
def next_id() -> int:
    """Return a new unique id from the class-global counter."""
    assigned = SavableDataLoader._next_id
    SavableDataLoader._next_id = assigned + 1
    return assigned
def __len__(self):
    """Length for the current rank (overridden to avoid torch DataLoader warnings)."""
    rank_length = self.dataset.len_rank()
    return rank_length
def _epoch_iter(self):
    """Iterator for one epoch, i.e. until the inner dataset raises StopIteration.

    Tracks per-worker sample counters and the next worker id so that the loader
    state can be saved/restored consistently between yielded samples.
    """
    iter_idx = 0
    # Use a distinct name for the per-iterator id (also avoids shadowing builtin `id`).
    iteration_id = self.next_id()
    if self.worker_config.should_log(level=1):
        self.worker_config.worker_log(
            {
                "t": "SavableDataLoader.iter",
                "r": self.worker_config.rank,
                "w": None,
                "id": self.id,
                "iter_id": iteration_id,
            }
        )
    try:
        for worker_id, sample_idx, sample in super().__iter__():
            self._worker_sample_counters[worker_id] = sample_idx
            # If the next sample will be from the first worker, we can safely resume
            self._next_worker_id = (worker_id + 1) % max(self.num_workers, 1)
            if self.worker_config.should_log(level=1):
                keys = default_get_keys(sample)
                self.worker_config.worker_log(
                    {
                        **{
                            "t": "SavableDataLoader.yield",
                            "r": self.worker_config.rank,
                            "w": None,
                            "id": self.id,
                            "iter_id": iteration_id,
                            "worker_id": worker_id,
                            "worker_idx": sample_idx,
                            "idx": self._sample_idx,
                            "iter_idx": iter_idx,
                            "global_idx": self._global_sample_idx,
                        },
                        **({} if keys is None else {"keys": keys}),
                    }
                )
            self._sample_idx += 1
            self._global_sample_idx += 1
            iter_idx += 1
            yield sample
        # Epoch finished regularly: drop the persistent iterator and reset the worker order.
        self._epoch_iterator = None
        self._next_worker_id = 0
    finally:
        if self.worker_config.should_log(level=1):
            self.worker_config.worker_log(
                {
                    "t": "SavableDataLoader.StopIteration",
                    "r": self.worker_config.rank,
                    "w": None,
                    "id": self.id,
                    # Fix: log the iterator's id (was `self.id`), consistent with the
                    # "SavableDataLoader.iter"/".yield" events above.
                    "iter_id": iteration_id,
                }
            )
def __iter__(self):
    """Return the sample iterator; with workers, a single persistent iterator is reused."""
    if self.num_workers <= 0:
        # Without worker processes a fresh iterator per call is fine.
        return self._epoch_iter()
    # Always keep the same iterator alive, as long as it yields data; the state
    # continues across epochs.
    if self._epoch_iterator is None:
        self._epoch_iterator = self._epoch_iter()
        self._sample_idx = 0
        self._has_workers = True
    return self._epoch_iterator
def _worker_command(self, *cmd_args) -> List[Any]:
    """Broadcast a command tuple to all workers and return their results in worker order.

    Re-raises the first result that is an exception.
    """
    # Send the command to every worker process.
    for queue in self.cmd_queues:
        queue.put(cmd_args)
    assert len(self.result_queues) == self.worker_config.num_workers
    # Each result queue yields a dict {worker_index: result}; merge them all.
    by_worker = {}
    for result_queue in self.result_queues:
        by_worker.update(result_queue.get())
    ordered = [by_worker[idx] for idx in range(len(by_worker))]
    for result in ordered:
        if isinstance(result, Exception):
            raise result
    return ordered
def _get_batch_size(self) -> Optional[int]:
    """Try to infer the micro batch size from the (possibly wrapped) dataset chain."""
    inner = self.dataset
    if isinstance(inner, (SavableDatasetWrapper, SimpleSavableDatasetWrapper)):
        # Unwrap the savable wrapper to inspect the actual dataset chain.
        inner = inner.dataset
    if not isinstance(inner, BaseWrapperDataset):
        return None
    batch_dataset = inner._find_wrapped_dataset(BatchDataset)
    if batch_dataset is None:
        return None
    assert isinstance(batch_dataset, BatchDataset)
    return batch_dataset.batch_size
def save_state_rank(self) -> Optional[SavableDataLoaderState]:
    """
    Saves the state of the dataset for the current rank. Allows for restoring the state later
    using `restore_state_rank`, given the result of this method.

    Returns:
        The state of the dataset, or None if nothing has been iterated or restored yet.
    """
    # Fetch current rank's worker's state
    if self.num_workers == 0:
        # No workers configured; the state lives in this process.
        assert isinstance(self.dataset, SimpleSavableDatasetWrapper)
        worker_states = [self.dataset.save_state()]
        assert self._next_worker_id == 0
    elif self._has_workers:
        # Fetch from worker processes, relative to the last sample index each
        # worker has emitted to us.
        worker_states = self._worker_command("get_checkpoint", self._worker_sample_counters)
    else:
        # Workers configured, but not started yet.
        # If a state has already been restored, it will be returned.
        assert isinstance(self.dataset, SavableDatasetWrapper)
        worker_states = self.dataset.get_initial_checkpoint()
        if worker_states is None:
            # Pristine loader: no state to save.
            return None
    # Merge the states
    merged_state = SavableDataLoaderState(
        worker_states=worker_states,
        next_worker_id=self._next_worker_id,
        micro_batch_size=self._get_batch_size(),
    )
    # Not distributed -> return the merged state
    return merged_state
def restore_state_rank(self, state: Optional[SavableDataLoaderState]) -> None:
    """
    Restores the saved state for the current rank.

    Args:
        state: The state to restore, as saved by `save_state_rank`. None means
            "start from the initial state".
    """
    assert not self._has_workers, "Cannot restore state while workers are running"
    if state is None:
        # Assume initial state
        return
    assert isinstance(state, SavableDataLoaderState)
    old_micro_batch_size = state.micro_batch_size
    micro_batch_size = self._get_batch_size()
    if self.num_workers == 0:
        # No workers configured; restore directly in this process.
        assert isinstance(self.dataset, SimpleSavableDatasetWrapper)
        assert micro_batch_size == old_micro_batch_size, (
            "Changing micro batch size is not allowed without workers"
        )
        assert len(state.worker_states) == 1
        assert isinstance(state.worker_states[0], FlexState)
        self.dataset.restore_state(state.worker_states[0])
    else:
        # Workers configured
        assert isinstance(self.dataset, SavableDatasetWrapper)
        assert all(isinstance(s, SavableDatasetCheckpoint) for s in state.worker_states)
        # Check batch sizes (before and after)
        if micro_batch_size != old_micro_batch_size:
            assert micro_batch_size is not None and old_micro_batch_size is not None, (
                "Cannot resume with different batching mode "
                "(batching to non-batching or vice versa)"
            )
            if micro_batch_size > old_micro_batch_size:
                raise ValueError(
                    "Resuming with larger micro batch size is not allowed: "
                    f"{micro_batch_size} > {state.micro_batch_size}"
                )
            elif (
                micro_batch_size < old_micro_batch_size
                and old_micro_batch_size % micro_batch_size != 0
            ):
                raise ValueError(
                    "Resuming with smaller micro batch size only allowed if the old "
                    f"micro batch size is a multiple of the new one: {micro_batch_size} < {state.micro_batch_size}"
                )
            # Old batch size is a whole multiple of the new one.
            batch_size_ratio = old_micro_batch_size // micro_batch_size
            for worker_state in state.worker_states:
                assert isinstance(worker_state, SavableDatasetCheckpoint)
                # When resuming with a smaller micro batch size, the offset must be scaled
                # up to the new micro batch size to skip the same number of samples as before.
                worker_state.offset *= batch_size_ratio
        self.dataset.restore_checkpoint(state.worker_states, worker_offset=state.next_worker_id)
    # Initialize the worker-sample counters so that every worker owns a valid
    # "last emitted sample" index. Workers that have not emitted anything yet keep
    # the default value ``-1``.
    assert isinstance(state.worker_states, list)
    self._worker_sample_counters = [
        (
            ws.state.sample_index - 1
            if (isinstance(ws, SavableDatasetCheckpoint) and ws.state is not None)
            else -1
        )
        for ws in state.worker_states
    ]
    self._next_worker_id = state.next_worker_id
@deprecated(
    "`save_state` is deprecated and was renamed to `save_state_global` and will be removed "
    "in a future update. If you actually do not want to gather the states to a rank, use "
    "`save_state_rank` instead."
)
def save_state(self, dst_rank: int) -> Optional[Sequence[Optional[SavableDataLoaderState]]]:
    """Deprecated. Use `save_state_global` (or `save_state_rank`) instead."""
    # Thin delegation to the renamed implementation.
    result = self.save_state_global(dst_rank)
    return result
def save_state_global(
    self, global_dst_rank: int
) -> Optional[Sequence[Optional[SavableDataLoaderState]]]:
    """
    Saves the state of the dataset globally, collecting the state from all ranks using torch
    distributed. Allows for restoring the state later using `restore_state_global`, given the
    result of this method.

    Typical scenario: Save the state to disk only on the `dst_rank`, the other ranks do not
    save the state. Later, restore the state either only loaded on the `dst_rank` or
    loading on all ranks separately using `restore_state_global`.

    Note: If you want to save/restore the state per rank separately, use `save_state_rank` and
    the corresponding `restore_state_rank`. Also, these do not rely on torch distributed.

    Args:
        global_dst_rank: The state will be gathered to this rank. The rank refers to the
            global rank, not the rank within the data parallel group.

    Returns:
        The state of the dataset (or `None`, if not on `dst_rank`).
    """
    # Fetch current rank's worker's state
    merged_state = self.save_state_rank()
    # Gather the merged states
    if self.worker_config.world_size > 1:
        output: Optional[Sequence[Optional[SavableDataLoaderState]]]
        if self.worker_config.global_rank() == global_dst_rank:
            # Destination rank: allocate the receive buffer for all ranks' states.
            output = [None] * self.worker_config.world_size
        else:
            # Check if the global_dst_rank is in the same group at all
            if self.worker_config.data_parallel_group is not None:
                try:
                    _ = torch.distributed.get_group_rank(
                        self.worker_config.data_parallel_group, global_dst_rank
                    )
                except RuntimeError:
                    raise ValueError(
                        f"global_dst_rank {global_dst_rank} is not in the group of the current rank's worker config"
                    )
            # Non-destination ranks pass None as gather buffer.
            output = None
        torch.distributed.gather_object(
            merged_state,
            output,
            global_dst_rank,
            group=self.worker_config.data_parallel_group,
        )
        return output
    else:
        # Not distributed -> return the merged state
        return [merged_state]
@deprecated(
    "`restore_state` was renamed to `restore_state_global` and will be removed in a future update."
)
def restore_state(
    self,
    state: Optional[Sequence[Optional[SavableDataLoaderState]]],
) -> None:
    """Deprecated. Use `restore_state_global` (or `restore_state_rank`) instead."""
    # Thin delegation to the renamed implementation.
    return self.restore_state_global(state)
def restore_state_global(
    self,
    state: Optional[Sequence[Optional[SavableDataLoaderState]]],
    *,
    src_rank: Optional[int] = None,
) -> None:
    """
    Restores the saved state from `save_state_global` (in torch distributed setup).
    The global state needs be loaded on every rank that has a data loader instance.
    Optionally, one can specify a src_rank and only provide the state once.
    In case of multiple data parallel groups, you must provide the state once
    in each data parallel group. In this case the `src_rank` is the rank within the
    data parallel group.

    Args:
        state: The state to restore, as saved by `save_state_global`.
        src_rank: The rank from which the state is broadcasted (within the data parallel group, if using DP groups).
    """
    assert self._epoch_iterator is None, "Cannot restore state while workers are running"
    # Only restore multi-rank if state is actually a list and we are in a distributed setup.
    # Otherwise treat as single rank state.
    if src_rank is None or self.worker_config.world_size == 1:
        assert isinstance(state, list), "State must be a list in distributed setup"
        assert len(state) == self.worker_config.world_size, (
            "State must be a list of size world_size"
        )
        # All ranks have the state
        # Select the state of the current rank
        rank_state = state[self.worker_config.rank]
    else:
        if self.worker_config.data_parallel_group is not None:
            # Only the src_rank has the state within this dp group; translate the
            # group-local src_rank to the global rank used by torch.distributed.
            try:
                global_src_rank = torch.distributed.get_global_rank(
                    self.worker_config.data_parallel_group, src_rank
                )
            except RuntimeError:
                raise ValueError(
                    f"src_rank {src_rank} is not in the group of the current rank's worker config"
                )
        else:
            # If no DP group is given, we assume the global rank is
            # the same as the data parallel rank
            global_src_rank = src_rank
        if self.worker_config.rank != src_rank:
            # Receiving rank: it must not provide a state itself.
            assert state is None
            # Must still be a list of Nones for scatter_object_list's signature.
            state = [None] * self.worker_config.world_size
        else:
            assert isinstance(state, list), "State must be a list in distributed setup"
            assert len(state) == self.worker_config.world_size, (
                "State must be a list of size world_size"
            )
        # Scatter one per-rank state from src to every rank in the group.
        local_object = [None]
        torch.distributed.scatter_object_list(
            local_object,
            state,
            src=global_src_rank,
            group=self.worker_config.data_parallel_group,
        )
        rank_state = local_object[0]
    self.restore_state_rank(rank_state)
def can_restore_sample(self) -> bool:
    """Whether the wrapped dataset supports restoring single samples by restore key."""
    inner = self.dataset
    return inner.can_restore_sample()
def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T:
    """Restores a sample from a key. This is useful to debug the dataset."""
    inner = self.dataset
    return inner.restore_sample(restore_key)
def config(self):
    """Get the configuration, which defines the dataset. Useful in conjunction with `save_state`
    and `restore_state` to match the configuration as well."""
    cfg = {
        "type": type(self).__qualname__,
        "num_workers": self.num_workers,
        "persistent_workers": self.persistent_workers,
        "pin_memory": self.pin_memory,
    }
    # prefetch_factor is only meaningful when worker processes are used.
    cfg["prefetch_factor"] = None if self.num_workers == 0 else self.prefetch_factor
    cfg["dataset"] = self.dataset.config()
    return cfg
class
BasicDataLoader
(
DataLoader
[
T
],
Generic
[
T
]):
"""DataLoader that supports debugging the dataset without saving capability (e.g. for val/eval)."""
#: The worker config
worker_config
:
WorkerConfig
#: The wrapped dataset. For multiprocessing, this is a :class:`megatron.energon.SavableDatasetWrapper`
dataset
:
Union
[
SavableDatasetWrapper
[
T
],
SavableDataset
[
T
]]
id
:
int
_sample_idx
:
int
=
0
def __init__(
    self,
    dataset: SavableDataset[T],
    gc_collect_every_n_steps: int = GC_DEFAULT_EVERY_N_ITER,
    gc_freeze_at_start: bool = True,
    prefetch_factor: int = 2,
    cache_pool: Optional[CachePool] = None,
    watchdog_timeout_seconds: Optional[float] = 60,
    watchdog_initial_timeout_seconds: Optional[float] = None,
    fail_on_timeout: bool = False,
):
    """
    Create the dataloader (without state saving/restoring capability).

    Args:
        dataset: The dataset to load. Its `worker_config` is used for this loader.
        gc_collect_every_n_steps: The number of steps after which the garbage collector is
            called. As we're usually handling large (but few) tensors here, and the python
            garbage collection is already full of objects just by importing, this can improve
            the memory footprint quite a lot, and may even be necessary to avoid memory
            overflow.
        gc_freeze_at_start: If true, the garbage collector is frozen at the start of the worker
            processes. This improves the garbage collection performance by a lot.
            In rare cases, this may cause issues and can be disabled. Keep enabled if you
            experience no issues.
        prefetch_factor: Passed through to the torch DataLoader (only with workers).
        cache_pool: If set, the cache pool to use for the dataset.
        watchdog_timeout_seconds: The timeout in seconds. If None, the watchdog is disabled.
        watchdog_initial_timeout_seconds: The initial timeout in seconds. If None, the timeout is the same as watchdog_timeout_seconds.
        fail_on_timeout: If True, stops the whole process upon timeout, after printing a stack trace.
    """
    self.worker_config = dataset.worker_config
    self.id = SavableDataLoader.next_id()
    # Wrap the dataset: innermost watchdog (detects stalls), then optional GC control.
    dataset = WatchdogDataset(
        dataset,
        worker_config=self.worker_config,
        timeout_seconds=watchdog_timeout_seconds,
        initial_timeout_seconds=watchdog_initial_timeout_seconds,
        fail_on_timeout=fail_on_timeout,
    )
    if gc_collect_every_n_steps > 0:
        dataset = GcDataset(
            dataset,
            worker_config=self.worker_config,
            every_n_iter=gc_collect_every_n_steps,
            freeze=gc_freeze_at_start,
        )
    dataset = SimpleSavableDatasetWrapper(
        dataset, worker_config=self.worker_config, cache_pool=cache_pool
    )
    self._worker_sample_counters = [0] * max(self.worker_config.num_workers, 1)
    kwargs = {}
    if self.worker_config.num_workers > 0:
        # These must not be specified for num_workers =0
        kwargs["persistent_workers"] = True
        kwargs["prefetch_factor"] = prefetch_factor
    # Compute per-worker seeds based on the current rank.
    seed_per_worker = [
        self.worker_config.worker_seed(i) for i in range(self.worker_config.num_workers)
    ]
    gc.collect()  # This ensures that we don't include any old worker refs in the newly forked worker processes
    super().__init__(
        dataset,
        batch_size=None,
        shuffle=False,
        num_workers=self.worker_config.num_workers,
        pin_memory=True,
        worker_init_fn=partial(_init_worker, seed_per_worker),
        **kwargs,
    )
    if self.worker_config.should_log(level=1):
        self.worker_config.worker_log(
            {
                "t": "BasicDataLoader.__init__",
                "r": self.worker_config.rank,
                "w": None,
                "id": self.id,
                "config": self.config(),
            }
        )
def __len__(self):
    """Length for the current rank (overridden to avoid torch DataLoader warnings)."""
    rank_length = self.dataset.len_rank()
    return rank_length
def __iter__(self):
    """Iterate the inner dataloader, unwrapping (worker_id, sample_idx, sample) tuples
    and emitting only the samples, with optional debug logging."""

    def _inner_generator(iterator):
        iter_idx = 0
        # Use a distinct name for the per-iterator id (also avoids shadowing builtin `id`).
        iteration_id = SavableDataLoader.next_id()
        if self.worker_config.should_log(level=1):
            self.worker_config.worker_log(
                {
                    "t": "BasicDataLoader.iter",
                    "r": self.worker_config.rank,
                    "w": None,
                    "id": self.id,
                    "iter_id": iteration_id,
                }
            )
        try:
            for worker_id, sample_idx, sample in iterator:
                if self.worker_config.should_log(level=1):
                    keys = default_get_keys(sample)
                    self.worker_config.worker_log(
                        {
                            **{
                                "t": "BasicDataLoader.yield",
                                "r": self.worker_config.rank,
                                "w": None,
                                "id": self.id,
                                # Fix: log the iterator's id (was `self.id`), consistent
                                # with the ".iter"/".StopIteration" events.
                                "iter_id": iteration_id,
                                "worker_id": worker_id,
                                "worker_idx": sample_idx,
                                "idx": iter_idx,
                                "iter_idx": iter_idx,
                                "global_idx": self._sample_idx,
                            },
                            **({} if keys is None else {"keys": keys}),
                        }
                    )
                self._sample_idx += 1
                iter_idx += 1
                yield sample
        finally:
            if self.worker_config.should_log(level=1):
                self.worker_config.worker_log(
                    {
                        "t": "BasicDataLoader.StopIteration",
                        "r": self.worker_config.rank,
                        "w": None,
                        "id": self.id,
                        "iter_id": iteration_id,
                    }
                )

    return _inner_generator(super().__iter__())
def config(self):
    """Get the configuration, which defines the dataset. Useful in conjunction with `save_state`
    and `restore_state` to match the configuration as well."""
    cfg = {
        "type": type(self).__qualname__,
        "num_workers": self.worker_config.num_workers,
        "persistent_workers": self.persistent_workers,
        "pin_memory": self.pin_memory,
    }
    # prefetch_factor is only meaningful when worker processes are used.
    cfg["prefetch_factor"] = None if self.num_workers == 0 else self.prefetch_factor
    cfg["dataset"] = self.dataset.config()
    return cfg
def can_restore_sample(self) -> bool:
    """Whether the wrapped dataset supports restoring single samples by restore key."""
    inner = self.dataset
    return inner.can_restore_sample()
def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T:
    """Restores a sample from a key. This is useful to debug the dataset."""
    inner = self.dataset
    return inner.restore_sample(restore_key)
def _sample_str(self, sample):
    """Returns a human readable debug string for a single sample, also uniquely identifying it.

    Tensors/arrays are identified by shape, dtype and a sha256 of their raw bytes;
    containers and dataclasses are rendered recursively.
    """
    import dataclasses
    import hashlib

    if isinstance(sample, torch.Tensor):
        return f"Tensor(shape={sample.shape}, dtype={sample.dtype}, sha256={hashlib.sha256(sample.detach().cpu().numpy().tobytes()).hexdigest()!r})"
    elif isinstance(sample, np.ndarray):
        return f"ndarray(shape={sample.shape}, dtype={sample.dtype}, sha256={hashlib.sha256(sample.tobytes()).hexdigest()!r})"
    elif isinstance(sample, (str, bytes)):
        # Truncate long strings/bytes to keep log lines bounded.
        if len(sample) > 100:
            return f"{sample[:100]!r}..."
        return repr(sample)
    elif isinstance(sample, (int, float)):
        return repr(sample)
    elif isinstance(sample, dict):
        # Fix: the elements previously ended with ", " AND were joined with ", ",
        # which produced doubled separators like "dict(a=1, , b=2, )".
        innerstr = ", ".join(f"{k}={self._sample_str(v)}" for k, v in sample.items())
        return f"dict({innerstr})"
    elif isinstance(sample, (list, tuple)):
        innerstr = ", ".join(self._sample_str(v) for v in sample)
        return f"{type(sample)}({innerstr})"
    elif dataclasses.is_dataclass(sample):
        # Same doubled-separator fix as for dicts above.
        innerstr = ", ".join(
            f"{field.name}={self._sample_str(getattr(sample, field.name))}"
            for field in dataclasses.fields(sample)
        )
        return f"{type(sample).__name__}({innerstr})"
    else:
        return repr(sample)
Megatron-Energon/src/megatron/energon/source_info.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
dataclasses
import
dataclass
from
typing
import
Any
,
Union
from
megatron.energon.epathlib
import
EPath
# Frozen + slots: instances are immutable and lightweight; kw_only forces
# explicit field names at construction sites.
@dataclass(frozen=True, slots=True, kw_only=True)
class SourceInfo:
    """Information about the source of a sample, i.e. where the data was loaded from.

    Attached to samples (see `add_source_info`) so that any produced batch can be
    traced back to the exact dataset, shard and files it came from.
    """

    #: The path to the dataset
    dataset_path: EPath
    #: The index of the sample in the dataset (string key or integer position)
    index: Union[str, int]
    #: The name of the shard tar file
    shard_name: str
    #: The names of the files in the shard used to create the sample
    file_names: tuple[str, ...]
def add_source_info(sample: Any, *source_infos: SourceInfo) -> Any:
    """Attach the given source infos to a sample, appending to any already present.

    Works on `Sample`-like objects (anything exposing ``__sources__``) and on
    plain dicts; any other sample type is returned unmodified.
    """
    # Imported lazily to avoid a circular import at module load time.
    from megatron.energon.flavors.base_dataset import Sample

    if isinstance(sample, Sample) or hasattr(sample, "__sources__"):
        existing = sample.__sources__
        sample.__sources__ = source_infos if existing is None else (*existing, *source_infos)
    elif isinstance(sample, dict):
        existing = sample.get("__sources__")
        sample["__sources__"] = source_infos if existing is None else (*existing, *source_infos)
    return sample
Megatron-Energon/src/megatron/energon/state.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
class FlexState(dict):
    """Flexible state container for dataset save/restore.

    NOTE(review): a plain ``dict`` subclass with no added behavior — the subclass
    apparently exists only so state objects are a distinct type; confirm against
    usage in the savable dataset machinery.
    """

    pass
Megatron-Energon/src/megatron/energon/task_encoder/__init__.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
megatron.energon.task_encoder.base
import
(
AugmentTaskEncoder
,
Batch
,
DefaultTaskEncoder
,
TaskEncoder
,
batch_list
,
batch_pad_stack
,
batch_stack
,
generic_batch
,
stateless
,
)
from
megatron.energon.task_encoder.cooking
import
Cooker
,
basic_sample_keys
,
cooker
from
megatron.energon.task_encoder.loader
import
(
get_train_dataset
,
get_val_dataset
,
get_val_datasets
,
)
# Names re-exported as the public API of `megatron.energon.task_encoder`.
__all__ = [
    "AugmentTaskEncoder",
    "basic_sample_keys",
    "batch_list",
    "batch_pad_stack",
    "batch_stack",
    "Batch",
    "cooker",
    "Cooker",
    "DefaultTaskEncoder",
    "generic_batch",
    "get_train_dataset",
    "get_val_dataset",
    "get_val_datasets",
    "stateless",
    "TaskEncoder",
]
Megatron-Energon/src/megatron/energon/task_encoder/base.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
dataclasses
import
functools
import
inspect
from
abc
import
ABC
from
dataclasses
import
is_dataclass
from
types
import
MethodType
from
typing
import
(
Any
,
Callable
,
Dict
,
Generator
,
Generic
,
Hashable
,
List
,
Optional
,
Sequence
,
Tuple
,
Type
,
TypeVar
,
Union
,
overload
,
)
import
torch
from
typing_extensions
import
ParamSpec
from
megatron.energon.cache
import
CachePool
,
DecodeFileStore
,
FileStore
from
megatron.energon.edataclass
import
edataclass
from
megatron.energon.flavors
import
(
CrudeSample
,
CrudeWebdataset
,
PinMemoryMixin
,
Sample
,
SampleDecoder
,
SavableDataset
,
)
from
megatron.energon.flavors.base_dataset
import
ExtendableDataclassMixin
from
megatron.energon.metadataset.loader_interface
import
DatasetBlendMode
,
LoadedDataset
from
megatron.energon.rng
import
SystemRng
from
megatron.energon.source_info
import
SourceInfo
from
megatron.energon.task_encoder.cooking
import
Cooker
from
megatron.energon.worker
import
WorkerConfig
from
megatron.energon.wrappers
import
(
BatchDataset
,
BlendDataset
,
ConcatDataset
,
EpochizeDataset
,
GroupBatchDataset
,
LimitDataset
,
LogSampleDataset
,
MapDataset
,
PackingDataset
,
ShuffleBufferDataset
,
)
from
megatron.energon.wrappers.repeat_dataset
import
RepeatDataset
# Generic type variables for the task-encoding pipeline stages.
T = TypeVar("T")
V = TypeVar("V")
# The raw sample type coming out of the dataset (bounded to Sample).
T_sample = TypeVar("T_sample", bound=Sample)
# The per-sample type after sample encoding.
T_encoded_sample = TypeVar("T_encoded_sample")
# The batch type produced by `batch` (before `encode_batch`).
T_raw_batch = TypeVar("T_raw_batch")
# The final batch type produced by `encode_batch`.
T_batch = TypeVar("T_batch")

# A function that turns a list of per-sample field values into one batched value.
FeatureBatcher = Callable[[List[Any]], Any]
def generic_batch(batch: List[Any]) -> Any:
    """Batch a list of samples based on the type of the first element.

    Tensors are zero-padded and stacked; dicts, dataclasses and namedtuples are
    batched recursively field by field; anything else is kept as a plain list.
    """
    head = batch[0]
    if isinstance(head, torch.Tensor):
        return batch_pad_stack(batch)
    if isinstance(head, dict):
        return {k: generic_batch([item[k] for item in batch]) for k in head}
    if is_dataclass(head):
        if hasattr(head, "from_samples"):
            # The dataclass brings its own batching implementation.
            return head.from_samples(batch)
        field_values = {
            f.name: generic_batch([getattr(item, f.name) for item in batch])
            for f in dataclasses.fields(head)
        }
        return type(head)(**field_values)
    if isinstance(head, tuple) and hasattr(head, "_fields"):
        # NamedTuple: rebuild field-wise.
        named_values = {
            name: generic_batch([getattr(item, name) for item in batch])
            for name in head._fields
        }
        return type(head)(**named_values)
    return batch_list(batch)
def batch_stack(batch: List[Any]) -> Any:
    """Stack equally-shaped tensors along a new leading batch dimension."""
    return torch.stack(batch)
def batch_pad_stack(batch: List[Any]) -> Any:
    """Stack arbitrary-shaped tensors into one tensor, zero-padding each entry
    up to the per-dimension maximum size across the batch."""
    ndim = batch[0].ndim
    # Per-dimension maximum extent over all tensors in the batch.
    padded_shape = [max(t.shape[d] for t in batch) for d in range(ndim)]
    out = batch[0].new_zeros((len(batch), *padded_shape))
    for idx, t in enumerate(batch):
        # Copy each tensor into the top-left corner of its slot; the rest stays 0.
        region = (idx,) + tuple(slice(0, extent) for extent in t.shape)
        out[region] = t
    return out
def batch_list(batch: List[Any]) -> Any:
    """Keep the batch as a plain list of samples (no stacking or padding).

    Bugfix (docs only): the previous docstring ("Stack a batch of tensors padded
    with 0s") was copy-pasted from `batch_pad_stack`; this function is the
    identity fallback used by `generic_batch` for non-tensor fields.
    """
    return batch
# Parameter specification used so the overloads preserve the wrapped function's signature.
P = ParamSpec("P")


# Overload: used as `@stateless(...)` with keyword options — returns a decorator.
@overload
def stateless(
    *, restore_seeds: bool = False, failure_tolerance: Optional[int] = None
) -> Callable[[Callable[P, T]], Callable[P, T]]: ...


# Overload: used as bare `@stateless` — decorates the function directly.
@overload
def stateless(fn: Callable[P, T]) -> Callable[P, T]: ...
def stateless(
    fn: Optional[Callable[..., T]] = None,
    *,
    restore_seeds: bool = False,
    failure_tolerance: Optional[int] = None,
) -> Union[Callable[[Callable[..., T]], Callable[..., T]], Callable[..., T]]:
    """Decorator to mark a function of the task encoder as restorable.

    Args:
        fn: The function to decorate.
        restore_seeds: Whether to restore the seeds for the function. I.e. the seeds are set
            from the sample index and the worker seed, such that they can be restored when a
            sample is restored from that function.
        failure_tolerance: The number of consecutive exceptions that are handled, after which
            a `FatalSampleError` is raised for this function. Set to 0 to disable.

    Usage:

    .. code-block:: python

        @stateless
        def encode_sample(self, sample: T_sample) -> T_encoded_sample:
            ...

        # Or if randomness is used (e.g. for augmentations):
        @stateless(restore_seeds=True)
        def encode_sample(self, sample: T_sample) -> T_encoded_sample:
            ...

    """
    if fn is None:
        # Called with keyword options only: return the actual decorator.
        return lambda f: stateless(f, restore_seeds=restore_seeds, failure_tolerance=failure_tolerance)
    if restore_seeds:
        # Lazily fetched once per worker process (only valid inside an active worker).
        worker_seed = None

        @functools.wraps(fn)
        def seed_wrapper_generator(self, *args, **kwargs):
            """Generator wrapper: seeds the RNGs deterministically per sample and keeps
            the generator's RNG state isolated from the caller's RNG state."""
            nonlocal worker_seed
            if worker_seed is None:
                worker_seed = WorkerConfig.active_worker_config.worker_seed()

            # Save the RNG states and set the new seed
            outer_rng_state = SystemRng.save_state()

            # Before constructing the generator and before the first iteration, set the
            # inner RNG based on a seed computed from worker_seed and current sample index.
            SystemRng.seed_args(worker_seed, self.current_sample_index)

            it = iter(fn(self, *args, **kwargs))
            inner_rand_state = None
            while True:
                if inner_rand_state is not None:
                    # Restore inner random state before calling the generator.
                    # This will not be done on the first iteration.
                    SystemRng.restore_state(inner_rand_state)
                try:
                    # Call the generator; this yields the sample, but may also raise
                    # an exception or StopIteration.
                    sample = next(it)
                    # Save inner random state after calling the generator
                    inner_rand_state = SystemRng.save_state()
                except StopIteration:
                    # We're stopping here, but the outer random state
                    # will be restored before returning (in finally below)
                    break
                finally:
                    # Restore outer rand state before yielding or when an exception was raised
                    SystemRng.restore_state(outer_rng_state)

                # Yield the sample; control goes back to the caller, who may
                # change the random state.
                yield sample

                # Save outer random state after yielding
                outer_rng_state = SystemRng.save_state()

        @functools.wraps(fn)
        def seed_wrapper(self, *args, **kwargs):
            """Plain-function wrapper: seeds the RNGs deterministically per sample and
            restores the caller's RNG state afterwards."""
            nonlocal worker_seed
            if worker_seed is None:
                worker_seed = WorkerConfig.active_worker_config.worker_seed()

            # Save the RNG states and set the new seed
            rng_state = SystemRng.save_state()
            SystemRng.seed_args(worker_seed, self.current_sample_index)
            try:
                return fn(self, *args, **kwargs)
            finally:
                # Restore the RNGs
                SystemRng.restore_state(rng_state)

        wrapper = seed_wrapper_generator if inspect.isgeneratorfunction(fn) else seed_wrapper
        setattr(wrapper, "__stateless__", True)
        # Bugfix: failure_tolerance was previously dropped when restore_seeds=True;
        # propagate it onto the returned wrapper just like in the non-seeded path below.
        if failure_tolerance is not None:
            setattr(wrapper, "__failure_tolerance__", failure_tolerance)
        return wrapper
    setattr(fn, "__stateless__", True)
    if failure_tolerance is not None:
        setattr(fn, "__failure_tolerance__", failure_tolerance)
    return fn
def get_stateless(fn: Callable) -> bool:
    """Return whether *fn* has been marked stateless via the `@stateless` decorator."""
    is_marked = getattr(fn, "__stateless__", False)
    return is_marked
def get_failure_tolerance(
    fn: Callable, default_failure_tolerance: Optional[int] = None
) -> Optional[int]:
    """Return the failure tolerance attached to *fn* by `@stateless`, or the given default."""
    tolerance = getattr(fn, "__failure_tolerance__", default_failure_tolerance)
    return tolerance
@edataclass
class Batch(PinMemoryMixin, ExtendableDataclassMixin):
    """Base class for a batch dataclass. Provides a default implementation for pinning memory.
    Additionally, it provides a future safe implementation for creating an instance from another
    batch `Batch.derive_from`."""

    #: Uniquely identifies each sample in the dataset.
    __key__: list[str]
    #: Key for restoring the sample. This is used to restore the sample from a checkpoint. It
    # should be a (nested) tuple of strings and integers, which can be used to index the dataset.
    __restore_key__: Tuple[Union[str, int, tuple], ...]
    #: A dataset may define a subflavors to distinguish between samples of the same sample type.
    __subflavors__: Optional[list[Optional[Dict[str, Any]]]] = None
    #: Information about the source of the sample, i.e. where the data was loaded from.
    __sources__: Optional[tuple[SourceInfo, ...]] = None

    @classmethod
    def derive_from(cls: Type[T_batch], base_batch: "Batch", **kwargs) -> T_batch:
        """
        Uses the base fields of `Batch` from base_batch (i.e. __key__, __restore_key__,
        __subflavors__, __sources__) and creates a new batch with the kwargs as fields.
        This is useful for creating new batches, while keeping the metadata of the base batch.

        Use like::

        .. code-block:: python

            def encode_batch(batch: RawBatch) -> Batch:
                return Batch.derive_from(batch, field1=batch.field1 + 1)

        Args:
            base_batch: The base batch to copy the base fields / metadata from.
            kwargs: The fields of the new batch.

        Returns:
            The new batch.
        """
        # Only the fields declared on `Batch` itself are carried over; everything
        # else must be provided via kwargs.
        base_kwargs = {
            field.name: getattr(base_batch, field.name) for field in dataclasses.fields(Batch)
        }
        return cls(
            **base_kwargs,
            **kwargs,
        )

    @classmethod
    def from_samples(cls: Type[T_batch], samples: Sequence[Sample], **kwargs) -> T_batch:
        """
        Creates a batch from samples to be batched. Tensor fields will be padded and stacked,
        other types will be put into lists. This is the default implementation for
        `Batch.from_samples`.

        Args:
            samples: The samples to batch.
            kwargs: Additional (overriding) fields of the batch.

        Returns:
            The constructed batch.
        """
        assert all(dataclasses.is_dataclass(scls) for scls in samples), (
            "Samples must be dataclasses"
        )
        init_args = {}
        fields = dataclasses.fields(cls)
        for field in fields:
            if field.name in kwargs:
                # Explicitly provided fields take precedence.
                init_args[field.name] = kwargs[field.name]
            elif field.name == "__sources__":
                if any(sample.__sources__ is not None for sample in samples):
                    # Special handling: flatten the per-sample source tuples into one.
                    init_args[field.name] = tuple(
                        source
                        for sample in samples
                        if sample.__sources__
                        for source in sample.__sources__
                    )
            elif field.name == "__subflavors__":
                if any(sample.__subflavors__ is not None for sample in samples):
                    init_args[field.name] = [
                        sample.__subflavors__ for sample in samples if sample.__subflavors__
                    ]
            else:
                value = [getattr(sample, field.name) for sample in samples]
                # Bugfix: inspect the collected field *values* for tensors. The samples
                # themselves are dataclasses (asserted above) and never torch.Tensor,
                # so the previous `isinstance(samples[0], torch.Tensor)` could never be
                # true and tensor fields were left as lists, contradicting the docstring.
                if len(value) > 0 and isinstance(value[0], torch.Tensor):
                    value = batch_pad_stack(value)
                init_args[field.name] = value
        return cls(**init_args)
class
TaskEncoder
(
ABC
,
Generic
[
T_sample
,
T_encoded_sample
,
T_raw_batch
,
T_batch
]):
"""
Base class for task encoders.
Task encoding follows these steps:
0. Data comes from the dataset
1. :meth:`megatron.energon.TaskEncoder.encode_sample` / :meth:`megatron.energon.TaskEncoder.preencode_sample` is called on each sample
2. :meth:`megatron.energon.TaskEncoder.select_samples_to_pack` is called on the buffer of samples
3. :meth:`megatron.energon.TaskEncoder.postencode_sample` is called on each sample of the current pack
4. :meth:`megatron.energon.TaskEncoder.pack_selected_samples` is called on the selected sample pack
5. :meth:`megatron.energon.TaskEncoder.batch` is called on the list of encoded samples
6. :meth:`megatron.energon.TaskEncoder.encode_batch` is called on the batch
7. yield to main process
8. :meth:`megatron.energon.Batch.to_device` is called on the encoded batch
9. resulting encoded batch is passed to the network
"""
__default_failure_tolerance__
:
Optional
[
int
]
=
100
cookers
:
Sequence
[
Cooker
[
T_sample
]]
=
()
#: Internal: List of registered cookers. Will be the same as `cookers` after registering cookers.
_registered_cookers
:
List
[
Cooker
[
T_sample
]]
#: The decoder to use for decoding samples. Set manually as needed to override options.
decoder
:
Optional
[
SampleDecoder
]
=
SampleDecoder
()
@stateless
def cook_crude_sample(
    self,
    sample: Union[T_sample, CrudeSample],
    get_primary_aux: Callable[[], FileStore],
    **aux: FileStore,
) -> T_sample:
    """
    Cooks a crude sample by dispatching to the first registered cooker that matches it.
    Non-crude samples pass through unchanged.

    Args:
        sample: The sample to cook.
        get_primary_aux: A function that returns the (cached) primary auxiliary dataset.
        **aux: The auxiliary side dishes to use for cooking.

    Returns: The cooked sample.

    Raises:
        NotImplementedError: If the sample is crude but no registered cooker matches it.
    """
    if isinstance(sample, CrudeSample):
        for cooker in self.cookers:
            if cooker.is_match(sample):
                # Cookers must be restorable, since cooking is part of the sample pipeline.
                assert get_stateless(cooker.cook), "Cooker must be stateless"
                if not cooker.need_primary and not cooker.need_cache:
                    # Fast path: pass the aux mapping through directly (no copy needed).
                    kwargs = aux
                else:
                    # Build the kwargs, injecting primary/cache as requested by the cooker.
                    kwargs: dict = {}
                    if cooker.need_primary:
                        kwargs["primary"] = get_primary_aux()
                    kwargs.update(aux)
                    if cooker.need_cache:
                        kwargs["cache"] = self.cache
                return cooker.cook(sample, **kwargs)

        raise NotImplementedError(
            "You are using crude samples but not providing a way to cook them: "
            f"Sample key={sample['__key__']}, subflavors={sample['__subflavors__']}, "
            f"self.cookers={self.cookers}"
        )
    else:
        assert isinstance(sample, Sample), "Sample must be a complete Sample or a CrudeSample"
        return sample
def
_is_overridden
(
self
,
bound_method
:
Callable
[...,
Any
],
bases
:
Optional
[
Sequence
[
Type
[
Any
]]]
=
None
)
->
bool
:
"""Check if a method is overridden by a subclass of the base class(es).
By default, only TaskEncoder is used as a base class.
This is mainly used for optimization purposes. If the default method
is a no-op, we can skip it entirely unless the user has overridden it.
Args:
bound_method: The method to check.
bases: The base classes to check against.
Returns:
True if the method is overridden outside of TaskEncoder, False otherwise.
"""
if
not
isinstance
(
bound_method
,
MethodType
):
# If the method is not bound, it is always overridden
return
True
# Get the underlying function
func
=
bound_method
.
__func__
# Check if the subclass method matches any of the base class methods
if
bases
is
None
:
bases
=
(
TaskEncoder
,)
return
not
any
(
getattr
(
base
,
func
.
__name__
)
is
func
for
base
in
bases
)
@stateless
def encode_sample(
    self, sample: T_sample
) -> Union[T_encoded_sample, Generator[T_encoded_sample, None, None]]:
    """Encode a single sample. May raise :exc:`megatron.energon.SkipSample` to skip a sample.
    Alternatively, this can be a generator that yields (or ignores) new samples.
    If this is defined, :func:`preencode_sample` and :func:`postencode_sample` must not be defined.
    """
    # Default: identity — this stage is skipped entirely unless overridden
    # (see `_is_overridden` / `build_encode_sample`).
    return sample
@stateless
def preencode_sample(
    self, sample: T_sample
) -> Union[T_sample, Generator[T_sample, None, None]]:
    """Pre-encode a single sample. May raise :exc:`megatron.energon.SkipSample` to skip a sample.
    Alternatively, this can be a generator that yields (or ignores) new samples.
    Use in conjunction with packing and caching.
    If this is defined, :func:`encode_sample` must not be defined.
    """
    # Default: identity — only wired into the pipeline when overridden.
    return sample
@stateless
def postencode_sample(
    self, sample: T_sample
) -> Union[T_encoded_sample, Generator[T_encoded_sample, None, None]]:
    """Post-encode a single sample. May raise :exc:`megatron.energon.SkipSample` to skip a sample.
    Alternatively, this can be a generator that yields (or ignores) new samples.
    Use in conjunction with packing and caching.
    If this is defined, :func:`encode_sample` must not be defined.
    """
    # Default: identity — only wired into the pipeline when overridden.
    return sample
@stateless
def batch(self, samples: List[T_encoded_sample]) -> T_raw_batch:
    """Combine a list of encoded samples into one raw batch, dispatching on the type
    of the first sample (dict / dataclass / namedtuple — see :meth:`_batch`).

    May raise :exc:`megatron.energon.SkipSample` to skip a batch.
    """
    return self._batch(samples, type(samples[0]))
def batch_group_criterion(self, sample: T_encoded_sample) -> Tuple[Hashable, Optional[int]]:
    """
    Return a group criterion for the sample. Default implementation does not group
    (effectively, it returns a single value `(None, None)`, thus only one group is used).

    Returns the key of the bucket to put this sample into, and the size of the bucket
    (=batch size). The bucket size must always be the same for the same bucket key.

    May raise :exc:`megatron.energon.SkipSample` to skip a batch.
    """
    # Default: a single bucket with unbounded (externally configured) batch size.
    return None, None
@stateless
def encode_batch(self, batch: T_raw_batch) -> Union[T_batch, Generator[T_batch, None, None]]:
    """Encode a batch of samples. May raise :exc:`megatron.energon.SkipSample` to skip a batch.
    Alternatively, this can be a generator that yields (or ignores) new batches."""
    # Default: identity — only wired into the pipeline when overridden.
    return batch
def _batch(
    self,
    samples: List[T_encoded_sample],
    result_type: Type[T_raw_batch],
    actions: Optional[Dict[str, FeatureBatcher]] = None,
    default_action: FeatureBatcher = generic_batch,
) -> T_raw_batch:
    """
    Batch a list of samples.

    Args:
        samples: The samples to batch
        result_type: Type of the result (might be dict, dataclass, or namedtuple)
        actions: For each field (=key), may specify a specific batcher
        default_action: The batcher to apply to all fields not in `action`

    Returns:
        The batched result

    Raises:
        ValueError: If the sample type or the result type is not recognized.
    """
    # If the result type batches itself, delegate to it directly.
    if dataclasses.is_dataclass(result_type) and hasattr(result_type, "from_samples"):
        return result_type.from_samples(samples)

    # Get dict of samples: field name -> list of per-sample values.
    if isinstance(samples[0], dict):
        list_samples = {key: [sample[key] for sample in samples] for key in samples[0].keys()}
    elif is_dataclass(samples[0]):
        list_samples = {
            field.name: [getattr(sample, field.name) for sample in samples]
            for field in dataclasses.fields(samples[0])
        }
    elif isinstance(samples[0], tuple) and hasattr(samples[0], "_fields"):
        # NamedTuple
        list_samples = {
            field: [getattr(sample, field) for sample in samples] for field in samples[0]._fields
        }
    else:
        raise ValueError("Unrecognized sample type.")
    # Convert each field, applying per-field overrides from `actions` if given.
    if actions is not None:
        list_samples = {
            key: default_action(value) if key not in actions else actions[key](value)
            for key, value in list_samples.items()
        }
    else:
        list_samples = {key: default_action(value) for key, value in list_samples.items()}
    # Construct result of the requested type.
    if issubclass(result_type, dict):
        return list_samples
    elif dataclasses.is_dataclass(result_type) or issubclass(result_type, tuple):
        # DataClass or NamedTuple
        return result_type(**list_samples)
    else:
        raise ValueError("Unrecognized result type.")
def select_samples_to_pack(
    self, samples: List[T_encoded_sample]
) -> List[List[T_encoded_sample]]:
    """
    For packing, selects the samples to be packed together.
    Packing is only active when packing_buffer_size is set.
    Internally this stage is called "pre_packing".

    Args:
        samples: The samples to pre-pack. A full buffer will be passed into the function.

    Returns: The pre-packed samples as a list of lists of samples.

    Raises:
        NotImplementedError: Always, unless overridden by a subclass.
    """
    raise NotImplementedError("Packing only effective when overridden.")
def pack_selected_samples(self, samples: List[T_encoded_sample]) -> T_encoded_sample:
    """
    Given one set of samples to pack, returns the final packed sample.
    Packing is only active when packing_buffer_size is set.
    Internally this stage is called "final_packing".

    Args:
        samples: The samples to pack into a single sample

    Returns: The final packed sample.

    Raises:
        NotImplementedError: Always, unless overridden by a subclass.
    """
    raise NotImplementedError("Packing only effective when overridden.")
def build_batch(
    self,
    dataset: SavableDataset[T_encoded_sample],
    *,
    batch_size: Optional[int],
    batch_drop_last: bool = False,
    packing_buffer_size: Optional[int] = None,
    worker_config: WorkerConfig,
) -> SavableDataset[T_raw_batch]:
    """Applies the batcher to the dataset.

    Wires up (in order, each only if applicable): packing (or post-encoding),
    grouped or plain batching, and batch encoding.
    """
    dataset: SavableDataset[Any]

    if packing_buffer_size is not None:
        # Packing requires both user-provided hooks.
        select_samples_to_pack_provided = self._is_overridden(self.select_samples_to_pack)
        pack_selected_samples_provided = self._is_overridden(self.pack_selected_samples)

        assert select_samples_to_pack_provided and pack_selected_samples_provided, (
            "Both select_samples_to_pack and pack_selected_samples methods must be provided in the TaskEncoder when using packing_buffer_size"
        )

        # The per-sample post-encoder is folded into the packing dataset if provided.
        if self._is_overridden(self.postencode_sample):
            post_encode_fn = self.postencode_sample
        else:
            post_encode_fn = None

        dataset = PackingDataset(
            dataset,
            buffer_size=packing_buffer_size,
            pre_packer=self.select_samples_to_pack,
            final_packer=self.pack_selected_samples,
            final_packer_stateless=get_stateless(self.pack_selected_samples),
            sample_encoder=post_encode_fn,
            sample_encoder_stateless=True
            if post_encode_fn is None
            else get_stateless(post_encode_fn),
            worker_config=worker_config,
            pre_packer_failure_tolerance=get_failure_tolerance(
                self.select_samples_to_pack, self.__default_failure_tolerance__
            ),
            final_packer_failure_tolerance=get_failure_tolerance(
                self.pack_selected_samples, self.__default_failure_tolerance__
            ),
            sample_encoder_failure_tolerance=None
            if post_encode_fn is None
            else get_failure_tolerance(post_encode_fn, self.__default_failure_tolerance__),
        )
    elif self._is_overridden(self.postencode_sample):
        # No packing: apply the post-encoder as a plain map stage.
        dataset = MapDataset(
            dataset,
            self.postencode_sample,
            worker_config=worker_config,
            stateless_map_fn=get_stateless(self.postencode_sample),
            failure_tolerance=get_failure_tolerance(
                self.postencode_sample, self.__default_failure_tolerance__
            ),
        )

    if self._is_overridden(self.batch_group_criterion):
        # Grouped batching: samples are bucketed by the user-defined criterion.
        dataset = GroupBatchDataset(
            dataset,
            fixed_batch_size=batch_size,
            sample_group_key=self.batch_group_criterion,
            batcher=self.batch,
            drop_last=batch_drop_last,
            worker_config=worker_config,
            failure_tolerance=get_failure_tolerance(
                self.batch, self.__default_failure_tolerance__
            ),
        )

        if self._is_overridden(self.encode_batch):
            dataset = MapDataset(
                dataset,
                self.encode_batch,
                worker_config=worker_config,
                stateless_map_fn=get_stateless(self.encode_batch),
                failure_tolerance=get_failure_tolerance(
                    self.encode_batch, self.__default_failure_tolerance__
                ),
            )
    else:
        # No grouping is active
        if batch_size is not None:
            dataset = BatchDataset(
                dataset,
                batch_size=batch_size,
                batcher=self.batch,
                batcher_stateless=get_stateless(self.batch),
                drop_last=batch_drop_last,
                worker_config=worker_config,
                failure_tolerance=get_failure_tolerance(
                    self.batch, self.__default_failure_tolerance__
                ),
            )

        if self._is_overridden(self.encode_batch):
            dataset = MapDataset(
                dataset,
                self.encode_batch,
                worker_config=worker_config,
                stateless_map_fn=get_stateless(self.encode_batch),
                failure_tolerance=get_failure_tolerance(
                    self.encode_batch, self.__default_failure_tolerance__
                ),
            )

    return dataset
def build_cook_crude_sample(
    self,
    dataset: SavableDataset[Union[T_sample, dict]],
    *,
    worker_config: WorkerConfig,
    subflavors: Dict[str, Any],
    get_primary_aux: Callable[[], FileStore],
    aux: Optional[Dict[str, FileStore]] = None,
) -> SavableDataset[T_sample]:
    """Applies the sample cooker to the dataset if we have cookers registered.

    Wraps aux file stores (and the lazily-resolved primary store) with the
    configured decoder, then maps `cook_crude_sample` over the dataset.
    """
    assert self.cookers, "No cookers registered, but got crude dataset."

    if aux is not None and self.decoder is not None:
        # Wrap all aux stores so loaded files are decoded with the configured decoder.
        aux = {k: DecodeFileStore(v, decoder=self.decoder) for k, v in aux.items()}

    # Cache the primary auxiliary dataset for this dataset, i.e. construct it once when needed
    primary_aux = None

    def _get_primary_aux():
        """Lazily resolve (and cache) the primary auxiliary file store."""
        nonlocal primary_aux
        if primary_aux is None:
            try:
                if aux is not None:
                    # An explicit "primary" aux entry takes precedence (already decoded above).
                    primary_aux = aux.get("primary")
                if primary_aux is None:
                    primary_aux = get_primary_aux()
                    assert primary_aux is not None, (
                        "Primary auxiliary dataset must always exist"
                    )
                    if self.decoder is not None:
                        primary_aux = DecodeFileStore(primary_aux, decoder=self.decoder)
            except Exception as e:
                # Make the exception throw through for the sample being loaded
                raise SystemError("Error getting primary auxiliary dataset") from e
        return primary_aux

    if aux is not None:
        cook_fn = functools.partial(
            self.cook_crude_sample, get_primary_aux=_get_primary_aux, **aux
        )
    else:
        cook_fn = functools.partial(self.cook_crude_sample, get_primary_aux=_get_primary_aux)

    return MapDataset(
        dataset,
        cook_fn,
        worker_config=worker_config,
        stateless_map_fn=True,
        map_fn_config=dict(
            cookers=[
                dict(
                    cook=SavableDataset._function_config(cooker.cook),
                    has_subflavors=cooker.has_subflavors,
                )
                for cooker in self.cookers
            ],
            subflavors=subflavors,
        ),
        failure_tolerance=get_failure_tolerance(cook_fn, self.__default_failure_tolerance__),
    )
def _load_dataset(
    self, dataset: LoadedDataset, worker_rotation_offset: int, worker_config: WorkerConfig
) -> SavableDataset[T_sample]:
    """Loads a train dataset, cooking the samples if the dataset yields crude samples.

    Args:
        dataset: The loaded dataset descriptor (wraps the actual dataset plus aux stores).
        worker_rotation_offset: Offset for rotating which worker handles which shard.
        worker_config: The worker configuration.

    Returns:
        The built (and possibly cooked) savable dataset.
    """
    if dataset.dataset.__sample_type__ == CrudeSample:
        # Crude samples need cooking; wire up the cooker map stage.
        return self.build_cook_crude_sample(
            dataset.dataset.build(worker_rotation_offset=worker_rotation_offset),
            worker_config=worker_config,
            subflavors=dataset.dataset.subflavors,
            get_primary_aux=dataset.dataset.as_file_store,
            aux=dataset.aux,
        )
    else:
        assert dataset.aux is None, "Aux is not supported for non-crude datasets."
        return dataset.dataset.build(worker_rotation_offset=worker_rotation_offset)
def build_encode_sample(
    self,
    dataset: SavableDataset[T_sample],
    *,
    worker_config: WorkerConfig,
) -> SavableDataset[T_encoded_sample]:
    """Applies the sample encoder to the dataset.

    Uses `preencode_sample` if overridden (for packing/caching pipelines),
    otherwise `encode_sample` if overridden; skips the stage entirely when
    neither is provided.
    """
    if self._is_overridden(self.preencode_sample):
        pre_encode_fn = self.preencode_sample
        # preencode_sample and encode_sample are mutually exclusive.
        assert not self._is_overridden(
            self.encode_sample, bases=(TaskEncoder, DefaultTaskEncoder)
        ), "Cannot have both pre- and post-encode functions defined."
    elif self._is_overridden(self.encode_sample):
        pre_encode_fn = self.encode_sample
    else:
        pre_encode_fn = None
    if pre_encode_fn is not None:
        dataset = MapDataset(
            dataset,
            pre_encode_fn,
            worker_config=worker_config,
            stateless_map_fn=get_stateless(pre_encode_fn),
            failure_tolerance=get_failure_tolerance(
                pre_encode_fn, self.__default_failure_tolerance__
            ),
        )
    return dataset
def build_train_datasets(
    self,
    *,
    datasets: List[LoadedDataset],
    worker_config: WorkerConfig,
    batch_size: Optional[int],
    batch_drop_last: bool = False,
    packing_buffer_size: Optional[int] = None,
    virtual_epoch_length: int = 0,
    shuffle_buffer_size: Optional[int] = None,
    blend_mode: DatasetBlendMode = DatasetBlendMode.NONE,
    repeat: bool = True,
) -> SavableDataset[T_batch]:
    """Combines train datasets to a single dataset.

    Blends/concats the loaded datasets according to `blend_mode`, optionally
    repeats and shuffles, then applies sample encoding, batching, epochization
    and sample logging.
    """
    # Check if there's a CrudeWebdataset but no cookers
    for dataset in datasets:
        if isinstance(dataset.dataset, CrudeWebdataset):
            assert self.cookers, "CrudeWebdataset found, but no cookers registered."

    # Total number of workers across all ranks.
    global_workers = max(1, worker_config.num_workers) * worker_config.world_size
    # The worker offset for each dataset is the cumsum of the dataset lengths,
    # modulo the global number of workers.
    rotation_lengths = [len(dataset.dataset) for dataset in datasets]
    for i in range(1, len(rotation_lengths)):
        rotation_lengths[i] += rotation_lengths[i - 1]
    worker_rotation_offsets = [
        rotation_length % global_workers for rotation_length in [0] + rotation_lengths[:-1]
    ]

    if blend_mode == DatasetBlendMode.DATASET_WEIGHT:
        assert repeat, (
            "If repeat is False, the datasets can only be repeated or have no mode. Cannot blend with dataset weights."
        )
        # Weighted blending: each inner dataset repeats forever; blend by weight.
        inner_datasets = [
            (
                RepeatDataset(
                    self._load_dataset(
                        dataset, worker_rotation_offset, worker_config=worker_config
                    ),
                    worker_config=worker_config,
                ),
                1.0 if dataset.weight is None else float(dataset.weight),
            )
            for dataset, worker_rotation_offset in zip(datasets, worker_rotation_offsets)
        ]
        # Already repeating the inner datasets, so no need to repeat again
        repeat = False
    elif blend_mode == DatasetBlendMode.SAMPLE_REPETITIONS or (
        not repeat and blend_mode == DatasetBlendMode.NONE
    ):
        # Repetition-based blending: repeat each dataset a fixed number of times;
        # the second tuple entry is the effective (repeated) length.
        inner_datasets = [
            (
                (
                    self._load_dataset(
                        dataset, worker_rotation_offset, worker_config=worker_config
                    )
                    if dataset.repetitions is None or dataset.repetitions == 1
                    else RepeatDataset(
                        self._load_dataset(
                            dataset, worker_rotation_offset, worker_config=worker_config
                        ),
                        repeats=dataset.repetitions,
                        worker_config=worker_config,
                    )
                ),
                len(dataset.dataset)
                * (1 if dataset.repetitions is None else dataset.repetitions),
            )
            for dataset, worker_rotation_offset in zip(datasets, worker_rotation_offsets)
        ]
    else:
        # No blend mode: repeat each dataset forever and blend uniformly.
        inner_datasets = [
            (
                RepeatDataset(
                    self._load_dataset(
                        dataset, worker_rotation_offset, worker_config=worker_config
                    ),
                    worker_config=worker_config,
                ),
                1.0,
            )
            for dataset, worker_rotation_offset in zip(datasets, worker_rotation_offsets)
        ]
        # Already repeating the inner datasets, so no need to repeat again
        repeat = False

    if len(inner_datasets) > 1:
        # The worker offset for each dataset is the cumsum of the dataset lengths, but modulo the
        # global number of workers.
        dataset = BlendDataset(
            *[inner_dataset[:2] for inner_dataset in inner_datasets],
            worker_config=worker_config,
        )
    elif len(datasets) == 1:
        dataset = inner_datasets[0][0]
    else:
        raise ValueError("No datasets given.")
    if repeat:
        # Still need to repeat the dataset
        dataset = RepeatDataset(dataset, worker_config=worker_config)

    if shuffle_buffer_size is not None and shuffle_buffer_size > 1:
        dataset = ShuffleBufferDataset(
            dataset,
            size=shuffle_buffer_size,
            worker_config=worker_config,
        )

    dataset = self.build_encode_sample(dataset, worker_config=worker_config)
    dataset = self.build_batch(
        dataset,
        batch_size=batch_size,
        batch_drop_last=batch_drop_last,
        packing_buffer_size=packing_buffer_size,
        worker_config=worker_config,
    )
    if virtual_epoch_length > 0:
        # Artificially segment the (infinite) stream into epochs of fixed length.
        dataset = EpochizeDataset(
            dataset,
            length=virtual_epoch_length,
            worker_config=worker_config,
        )
    if worker_config.should_log(level=1):
        dataset = LogSampleDataset(dataset, mode="train", worker_config=worker_config)

    return dataset
def build_val_datasets(
    self,
    *,
    datasets: List[LoadedDataset],
    worker_config: WorkerConfig,
    batch_size: int,
    batch_drop_last: bool = False,
    packing_buffer_size: Optional[int] = None,
    limit: Optional[int] = None,
) -> SavableDataset[T_batch]:
    """Combines val datasets to a single dataset.

    Concatenates the loaded datasets (no blending/shuffling for validation),
    then applies sample encoding, batching, an optional length limit and
    sample logging.
    """
    # Check if there's a CrudeWebdataset but no cookers.
    for dataset in datasets:
        # Bugfix: `dataset` is a LoadedDataset wrapper, so the previous check
        # `isinstance(dataset, CrudeWebdataset)` could never be true; check the
        # wrapped dataset instead, consistent with build_train_datasets.
        if isinstance(dataset.dataset, CrudeWebdataset):
            assert self.cookers, "CrudeWebdataset found, but no cookers registered."

    # Total number of workers across all ranks.
    global_workers = max(1, worker_config.num_workers) * worker_config.world_size
    # Worker offset per dataset: cumsum of dataset lengths modulo the global worker count.
    rotation_lengths = [len(dataset.dataset) for dataset in datasets]
    for i in range(1, len(rotation_lengths)):
        rotation_lengths[i] += rotation_lengths[i - 1]
    worker_rotation_offsets = [
        rotation_length % global_workers for rotation_length in [0] + rotation_lengths[:-1]
    ]

    if len(datasets) > 1:
        dataset = ConcatDataset(
            *[
                self._load_dataset(dataset, worker_rotation_offset, worker_config)
                for dataset, worker_rotation_offset in zip(datasets, worker_rotation_offsets)
            ],
            worker_config=worker_config,
        )
    elif len(datasets) == 1:
        dataset = self._load_dataset(datasets[0], worker_rotation_offsets[0], worker_config)
    else:
        raise ValueError("No datasets given.")

    dataset = self.build_encode_sample(dataset, worker_config=worker_config)
    dataset = self.build_batch(
        dataset,
        batch_size=batch_size,
        batch_drop_last=batch_drop_last,
        packing_buffer_size=packing_buffer_size,
        worker_config=worker_config,
    )
    if limit is not None and limit > 0:
        # Restrict validation to a fixed number of batches, restarting each epoch.
        dataset = LimitDataset(
            dataset,
            length=limit,
            worker_config=worker_config,
            reset_after_epoch=True,
        )
    if worker_config.should_log(level=2):
        dataset = LogSampleDataset(dataset, mode="val", worker_config=worker_config)
    return dataset
@
property
def
current_batch_index
(
self
)
->
int
:
"""Returns the current index for the next batch yielded from the current worker. Each batch
on the current rank will get a strictly increasing unique number. Counting happens on each
rank separately (i.e. each rank will get the same numbers for same batch index)."""
assert
WorkerConfig
.
active_worker_config
is
not
None
,
(
"The batch_index can only be fetched within the worker, and to be usable, you must use the get_(savable_)loader methods provided from the package."
)
return
WorkerConfig
.
active_worker_config
.
active_worker_batch_index
@
property
def
current_sample_index
(
self
)
->
int
:
"""Returns the current index for the next sample yielded from the current routine (e.g.
for `encode_sample`, `batch`, or `encode_batch`). Each routine will get a number
representing the number of calls to that function. Across workers, this number will be
unique, but it is not synced across workers, thus it may raise in different intervals (e.g.
if batching does not work the same for all batches). When restoring a sample, this number is
also restored and can be relied on for deterministic randomness reproduction of a sample."""
assert
WorkerConfig
.
active_worker_config
is
not
None
,
(
"The batch_index can only be fetched within the worker, and to be usable, you must use the get_(savable_)loader methods provided from the package."
)
return
WorkerConfig
.
active_worker_config
.
active_worker_sample_index
@
property
def
cache
(
self
)
->
CachePool
:
"""Returns the cache pool to use for caching out sample data to disk (for use with cookers / aux file stores).
This is set and configured externally by the loader."""
assert
WorkerConfig
.
active_worker_config
is
not
None
,
(
"The cache can only be fetched within the worker, and to be usable, you must use the get_(savable_)loader methods provided from the package."
)
assert
WorkerConfig
.
active_worker_config
.
_cache_pool
is
not
None
,
(
"Cache pool must be set by the loader."
)
return
WorkerConfig
.
active_worker_config
.
_cache_pool
class DefaultTaskEncoder(
    TaskEncoder[T_sample, T_encoded_sample, T_raw_batch, T_batch],
    ABC,
    Generic[T_sample, T_encoded_sample, T_raw_batch, T_batch],
):
    """
    The default task encoder supports automagically mapping to target types.
    You may override any methods to customize the behavior. By default, `encode_sample` is the
    identity function, `batch` calls `\\_batch` with the type of the first sample, and
    `encode\\_batch` is also the identity function. If you set any of `encoded_sample_type`,
    `raw_batch_type` or `batch_type`, the corresponding method returns that type, where it
    automatically maps the fields (by name) to your new type.
    """

    _encoded_sample_type: Optional[Type[T_encoded_sample]]
    _raw_batch_type: Optional[Type[T_raw_batch]]
    _batch_type: Optional[Type[T_batch]]

    def __init__(
        self,
        *,
        encoded_sample_type: Optional[Type[T_encoded_sample]] = None,
        raw_batch_type: Optional[Type[T_raw_batch]] = None,
        batch_type: Optional[Type[T_batch]] = None,
    ):
        """
        Initialize the default task encoder.
        Types may be:

          - A `@dataclass` class: Return that typed dataclass. Field names must match the input
            fields.
          - A `NamedTuple` class: Return that typed namedtuple. Field names must match the input
            fields.
          - `dict`: Simply return the input as dict with field names as keys.

        Args:
            encoded_sample_type: Type of encoded samples (before batching)
            raw_batch_type: Type of the batched samples (after batching)
            batch_type: Type of the encoded batched samples
        """
        self._encoded_sample_type = encoded_sample_type
        self._raw_batch_type = raw_batch_type
        self._batch_type = batch_type

    @staticmethod
    def _fields_of(value) -> dict:
        """Extract a `{field_name: value}` dict from a dataclass, NamedTuple, or dict instance.

        Shared by `encode_sample` and `encode_batch` (previously duplicated inline).

        Raises:
            ValueError: If `value` is none of the supported kinds.
        """
        if is_dataclass(value):
            return {field.name: getattr(value, field.name) for field in dataclasses.fields(value)}
        elif isinstance(value, tuple) and hasattr(value, "_fields"):
            # NamedTuple instance
            return {field: getattr(value, field) for field in value._fields}
        elif isinstance(value, dict):
            return value
        else:
            raise ValueError("Unrecognized sample type.")

    @staticmethod
    def _construct_as(target_type, fields: dict):
        """Construct `target_type` (dict, dataclass, or NamedTuple) from a field dict.

        For `dict` targets, the field dict is returned as-is.

        Raises:
            ValueError: If `target_type` is none of the supported kinds.
        """
        if issubclass(target_type, dict):
            return fields
        elif dataclasses.is_dataclass(target_type) or issubclass(target_type, tuple):
            # DataClass or NamedTuple
            return target_type(**fields)
        else:
            raise ValueError("Unrecognized encoded sample type.")

    @stateless
    def encode_sample(
        self, sample: T_sample
    ) -> Union[T_encoded_sample, Generator[T_encoded_sample, None, None]]:
        """Encode a single sample. The default implementation converts to the
        _encoded_sample_type."""
        # Identity if no target type is configured or the sample already has it.
        if self._encoded_sample_type is None or isinstance(sample, self._encoded_sample_type):
            return sample
        return self._construct_as(self._encoded_sample_type, self._fields_of(sample))

    @stateless
    def batch(self, samples: List[T_encoded_sample]) -> T_raw_batch:
        """Batch a list of samples. The default implementation uses default batching to convert
        to _batch_type."""
        actions = None
        if isinstance(samples[0], Sample):
            # Keep __subflavors__ as-is instead of collating it.
            actions = {
                "__subflavors__": lambda x: x,
            }
        return self._batch(
            samples,
            type(samples[0]) if self._raw_batch_type is None else self._raw_batch_type,
            actions=actions,
        )

    @stateless
    def encode_batch(self, batch: T_raw_batch) -> Union[T_batch, Generator[T_batch, None, None]]:
        """Encode a batch of samples. The default implementation converts to the
        _encoded_batch_type."""
        # Identity if no target type is configured or raw and final batch types coincide.
        if self._batch_type is None or self._raw_batch_type == self._batch_type:
            return batch
        return self._construct_as(self._batch_type, self._fields_of(batch))
class AugmentTaskEncoder(
    TaskEncoder[T_sample, T_encoded_sample, T_raw_batch, T_batch],
    Generic[T_sample, T_encoded_sample, T_raw_batch, T_batch],
):
    """Wraps another task encoder to extend its functionality.

    Every operation is forwarded unchanged to the wrapped encoder; subclasses override
    only the parts they want to augment."""

    def __init__(
        self, task_encoder: TaskEncoder[T_sample, T_encoded_sample, T_raw_batch, T_batch]
    ):
        """Initialize the augmenting task encoder.

        Args:
            task_encoder: The delegate task encoder. All calls will by default be forwarded to this.
        """
        self._task_encoder = task_encoder

    @property
    def decoder(self) -> SampleDecoder:
        # Forward the decoder of the wrapped encoder.
        return self._task_encoder.decoder

    def encode_sample(self, sample: T_sample) -> T_encoded_sample:
        # Delegate sample encoding.
        delegate = self._task_encoder
        return delegate.encode_sample(sample)

    def batch(self, samples: List[T_encoded_sample]) -> T_raw_batch:
        # Delegate batching.
        delegate = self._task_encoder
        return delegate.batch(samples)

    def encode_batch(self, batch_data: T_raw_batch) -> T_batch:
        # Delegate batch encoding.
        delegate = self._task_encoder
        return delegate.encode_batch(batch_data)
Megatron-Energon/src/megatron/energon/task_encoder/cooking.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
dataclasses
import
functools
from
dataclasses
import
dataclass
from
typing
import
Callable
,
Generic
,
Optional
,
TypeVar
,
Union
,
overload
from
megatron.energon.flavors.base_dataset
import
Sample
from
megatron.energon.flavors.crude
import
CrudeSample
from
megatron.energon.source_info
import
SourceInfo
#: Type of the sample produced by a cooker; bound to `Sample`, covariant so more
#: specific sample types are accepted where a `Sample` is expected.
T_sample = TypeVar("T_sample", bound=Sample, covariant=True)

#: Type of a cooker callable (any signature, returning a `Sample`); used to type the
#: `cooker` decorator so it preserves the decorated function's type.
F = TypeVar("F", bound=Callable[..., Sample])
# Overload: bare usage as `cooker(fn)` / `@cooker` — returns the wrapped function
# directly. (The runtime implementation supports this path, but no overload covered
# it, so type checkers rejected bare decoration.)
@overload
def cooker(fn: F) -> F: ...


# Overload: usage as `cooker()` with no arguments — returns a decorator.
@overload
def cooker(
    fn: None = None,
) -> Callable[[F], F]: ...
# Overload: usage as `@cooker(need_cache=..., need_primary=...)` — returns a
# configured decorator that marks the wrapped function.
@overload
def cooker(
    *,
    need_cache: bool = False,
    need_primary: bool = False,
) -> Callable[[F], F]: ...
def cooker(
    fn: Optional[F] = None,
    *,
    need_cache: bool = False,
    need_primary: bool = False,
) -> Union[
    F,
    Callable[[F], F],
]:
    """Decorator to mark a function as a cooker, optionally enabling cache and primary dataset
    arguments.

    The flags are stored as attributes on a forwarding wrapper (leaving the original
    function untouched) and are read back via `get_cooker_need_cache` /
    `get_cooker_need_primary`.
    """
    if fn is None:
        # Called with keyword arguments only: return a decorator with them bound.
        return functools.partial(
            cooker,
            need_cache=need_cache,
            need_primary=need_primary,
        )

    @functools.wraps(fn)
    def fn_wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    # Tag the wrapper so the cooker machinery knows which extra arguments to pass.
    fn_wrapper.__cooker_need_cache__ = need_cache
    fn_wrapper.__cooker_need_primary__ = need_primary
    return fn_wrapper
def get_cooker_need_cache(fn: Callable[..., T_sample]) -> bool:
    """Get whether the cooker function requested a cache argument (set via `@cooker(need_cache=True)`).

    Returns False for functions not decorated with `@cooker`."""
    return getattr(fn, "__cooker_need_cache__", False)
def get_cooker_need_primary(fn: Callable[..., T_sample]) -> bool:
    """Get whether the cooker function requested the primary dataset argument (set via
    `@cooker(need_primary=True)`).

    Returns False for functions not decorated with `@cooker`."""
    return getattr(fn, "__cooker_need_primary__", False)
@dataclass
class Cooker(Generic[T_sample]):
    """A cooker transforms a crude sample (simple dict) into a specific sample type inheriting
    from `Sample`.

    The `cook` method performs the transformation, the other fields are used to select the
    samples which this cooker can transform. If no filters are provided, the cooker will transform
    any `CrudeSample`.
    """

    #: The callable that performs the cooking (i.e. loading / transforming the crude sample).
    # Signature is:
    # `(/, raw_sample: dict, *, primary?: RandomAccessDataset, **aux: RandomAccessDataset, cache?: Cache) -> Sample`.
    # `primary` is passed only if want_primary_random_access is true.
    # `cache` is passed only if want_cache is true.
    cook: Callable[..., T_sample]

    #: The subflavors to be present in the sample to be cooked by this cooker. All keys and values
    # must match.
    has_subflavors: Optional[dict] = None

    @property
    def need_primary(self) -> bool:
        # Flag set by the `@cooker` decorator on the cook callable.
        return get_cooker_need_primary(self.cook)

    @property
    def need_cache(self) -> bool:
        # Flag set by the `@cooker` decorator on the cook callable.
        return get_cooker_need_cache(self.cook)

    def is_match(self, crude_sample: CrudeSample) -> bool:
        # No filter configured: every crude sample matches.
        if self.has_subflavors is None:
            return True
        # All filter entries must be present in the sample's subflavors with equal
        # values; the sample may have additional entries. (Lazy generator: with an
        # empty filter, the sample's subflavors are never accessed.)
        return all(
            key in crude_sample["__subflavors__"]
            and crude_sample["__subflavors__"][key] == value
            for key, value in self.has_subflavors.items()
        )
def basic_sample_keys(
    crude_sample: dict, additional_source_info: tuple[SourceInfo, ...] = ()
) -> dict:
    """A convenience helper to extract the basic keys from a crude sample,
    which you will always need to forward to the cooked sample."""
    res = {}
    # Copy over every `Sample` base field that is present in the crude sample.
    for field in dataclasses.fields(Sample):
        if field.name in crude_sample:
            res[field.name] = crude_sample[field.name]
    if additional_source_info:
        # Append the extra source infos to the sample's existing ones.
        res["__sources__"] = (*crude_sample["__sources__"], *additional_source_info)
    return res
Megatron-Energon/src/megatron/energon/task_encoder/loader.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
pathlib
import
Path
from
typing
import
Any
,
List
,
Literal
,
Optional
,
Tuple
,
TypeVar
,
Union
from
megatron.energon.epathlib
import
EPath
from
megatron.energon.errors
import
warn_deprecated
from
megatron.energon.flavors
import
BaseCoreDatasetFactory
,
SavableDataset
from
megatron.energon.flavors.webdataset.sample_decoder
import
SampleDecoder
from
megatron.energon.metadataset
import
load_dataset
from
megatron.energon.task_encoder.base
import
DefaultTaskEncoder
,
TaskEncoder
,
WorkerConfig
#: Type of the batches yielded by the built dataset; covariant.
T = TypeVar("T", covariant=True)
def
_split_kwargs
(
kwargs
:
dict
)
->
dict
:
loader_kwargs
=
{}
try
:
loader_kwargs
[
"split_part"
]
=
kwargs
.
pop
(
"split_part"
)
except
KeyError
:
pass
try
:
loader_kwargs
[
"dataset_config"
]
=
kwargs
.
pop
(
"dataset_config"
)
except
KeyError
:
pass
try
:
loader_kwargs
[
"split_config"
]
=
kwargs
.
pop
(
"split_config"
)
except
KeyError
:
pass
return
loader_kwargs
def _split_deprecated_decoder_kwargs(kwargs: dict, task_encoder: TaskEncoder) -> None:
    """Pop deprecated decoder-related kwargs out of `kwargs` and apply them to the task encoder.

    Handled (deprecated) kwargs:

        auto_decode: bool = True,
        image_decode: ImageDecoder = "torchrgb",
        ignore_decoder_errors: bool = False,
        av_decode: AVDecoder = "AVDecoder",
        video_decode_audio: bool = False,

    NOTE(review): `ignore_decoder_errors` is listed above but never popped below — verify
    whether it should be handled here as well.
    """
    auto_decode = True
    decoder_kwargs = {}
    # Pop each known deprecated kwarg if present (they must not reach the dataset loader).
    if "auto_decode" in kwargs:
        auto_decode = kwargs.pop("auto_decode")
    if "image_decode" in kwargs:
        decoder_kwargs["image_decode"] = kwargs.pop("image_decode")
    if "av_decode" in kwargs:
        decoder_kwargs["av_decode"] = kwargs.pop("av_decode")
    if "video_decode_audio" in kwargs:
        decoder_kwargs["video_decode_audio"] = kwargs.pop("video_decode_audio")
    if not auto_decode:
        # Decoding explicitly disabled: clear the task encoder's decoder.
        task_encoder.decoder = None
    elif len(decoder_kwargs) > 0:
        warn_deprecated(
            "The following decoder kwargs are deprecated and will be removed in a future version: "
            + ", ".join(decoder_kwargs.keys())
            + ". Instead, set the decoder directly in your task encoder."
        )
        if (
            hasattr(task_encoder, "decoder")
            and task_encoder.decoder is not None
            and task_encoder.decoder is not DefaultTaskEncoder.decoder
        ):
            # The task encoder already has a decoder set.
            # The user might be reusing the task encoder in multiple calls to get_train_dataset
            # and get_val_dataset.
            # We need to check if the decoder is the same as the one we are setting here.
            # If it is, we can return.
            if task_encoder.decoder.config() == SampleDecoder(**decoder_kwargs).config():
                # It's the same decoder, nothing to do.
                return
            else:
                raise ValueError(
                    "Task encoder already has a decoder, and you are setting a different decoder, which is not allowed."
                )
        assert (
            not hasattr(task_encoder, "decoder")
            or task_encoder.decoder is DefaultTaskEncoder.decoder
        ), "Task encoder already has a decoder, and setting using deprecated kwargs is not allowed."
        # Install a decoder built from the deprecated kwargs on the task encoder.
        task_encoder.decoder = SampleDecoder(**decoder_kwargs)
def get_train_dataset(
    path: Union[str, EPath, Path],
    *,
    split_part: Union[Literal["train"], str] = "train",
    worker_config: WorkerConfig,
    batch_size: Optional[int],
    batch_drop_last: bool = False,
    packing_buffer_size: Optional[int] = None,
    shuffle_buffer_size: Optional[int],
    max_samples_per_sequence: Optional[int],
    virtual_epoch_length: int = 0,
    shuffle_over_epochs_multiplier: Optional[int] = 1,
    task_encoder: TaskEncoder[Any, Any, Any, T] = DefaultTaskEncoder(),
    repeat: bool = True,
    **kwargs,
) -> SavableDataset[T]:
    """
    Get training data loader with sensible defaults. See `get_dataset` for more details.

    The following recipe will be used:

      - :func:`megatron.energon.dataset_config.get_dataset_from_config` (loads the raw dataset)
      - `task_encoder.encode_sample`
      - (:class:`megatron.energon.MixDataset` if mixing)
      - :class:`megatron.energon.BatchDataset` with `task_encoder.batch` for collation
      - `task_encoder.encode_batch`
      - :class:`megatron.energon.EpochizeDataset` (if `virtual_epoch_length` is set)

    Args:
        path: Path to the dataset.
        split_part: Default split part to use.
        worker_config: Worker configuration to use.
        batch_size: Size of a batch. If None, do not batch
        batch_drop_last: If true, drop the last batch if it is smaller than `batch_size`.
        shuffle_buffer_size: Size of the sample shuffle buffer (before task encoding).
        max_samples_per_sequence: If set, limit the number of samples per sample-sequence to this.
        virtual_epoch_length: If set, the dataset will be epochized to this length (=iterating
            will be suspended and the for-loop returns, next for-loop continues iterating).
            Otherwise, the dataset will loop indefinitely.
        shuffle_over_epochs_multiplier: Shuffle the shards over this many epochs.
        task_encoder: Task encoder to use.
        repeat: By default, the inner datasets will loop. If set to False, stop iteration after
            one epoch. Must only be set to False in conjunction with blend_epochized in the
            metadataset if one is used.
        cache_pool: If set, the cache pool to use for the dataset (passed via `**kwargs`).
        **kwargs: Additional arguments to the dataset constructor.

    Returns:
        The dataloader.
    """
    # NOTE(review): the default `task_encoder` is a single instance created at import time
    # and shared across calls; `_split_deprecated_decoder_kwargs` may mutate its decoder —
    # confirm that this sharing is intended.
    # Split off the loader kwargs before they reach get_datasets.
    loader = load_dataset(path, **_split_kwargs(kwargs))
    _split_deprecated_decoder_kwargs(kwargs, task_encoder)
    datasets = loader.get_datasets(
        training=True,
        split_part=split_part,
        worker_config=worker_config,
        max_samples_per_sequence=max_samples_per_sequence,
        shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier,
        decoder=task_encoder.decoder,
        **kwargs,
    )
    return task_encoder.build_train_datasets(
        datasets=datasets.datasets,
        worker_config=worker_config,
        batch_size=batch_size,
        batch_drop_last=batch_drop_last,
        packing_buffer_size=packing_buffer_size,
        virtual_epoch_length=virtual_epoch_length,
        shuffle_buffer_size=shuffle_buffer_size,
        blend_mode=datasets.blend_mode,
        repeat=repeat,
    )
def get_val_dataset(
    path: Union[str, EPath, Path],
    *,
    split_part: Union[Literal["val", "test"], str] = "val",
    worker_config: WorkerConfig,
    batch_size: int,
    batch_drop_last: bool = False,
    packing_buffer_size: Optional[int] = None,
    limit: Optional[int] = None,
    task_encoder: TaskEncoder[Any, Any, Any, T] = DefaultTaskEncoder(),
    **kwargs,
) -> SavableDataset[T]:
    """
    Get the validation/test dataset with sensible defaults. See `get_dataset` for more details.

    The following recipe will be used:

      - :func:`megatron.energon.dataset_config.get_dataset_from_config` (loads the raw dataset)
      - `task_encoder.encode_sample`
      - (:class:`megatron.energon.MixDataset` if mixing)
      - :class:`megatron.energon.BatchDataset` with `task_encoder.batch` for collation
      - :class:`megatron.energon.LimitDataset` (if `limit` is set)
      - `task_encoder.encode_batch`

    Args:
        path: Path to the dataset.
        split_part: Default split part to use.
        worker_config: Worker configuration to use.
        batch_size: Size of a batch
        batch_drop_last: If true, drop the last batch if it is smaller than `batch_size`.
        limit: If set, limit the number of batches loaded from the dataset to this.
        task_encoder: Task encoder to use.
        **kwargs: Additional arguments to the dataset constructor.

    Returns:
        The loaded dataset.
    """
    # NOTE(review): the default `task_encoder` is a single instance shared across calls —
    # confirm that this sharing (incl. decoder mutation below) is intended.
    _split_deprecated_decoder_kwargs(kwargs, task_encoder)
    loader = load_dataset(path, **_split_kwargs(kwargs))
    datasets = loader.get_datasets(
        training=False,
        split_part=split_part,
        worker_config=worker_config,
        decoder=task_encoder.decoder,
        **kwargs,
    )
    return task_encoder.build_val_datasets(
        datasets=datasets.datasets,
        worker_config=worker_config,
        batch_size=batch_size,
        batch_drop_last=batch_drop_last,
        packing_buffer_size=packing_buffer_size,
        limit=limit,
    )
def get_val_datasets(
    path: Union[str, EPath, Path],
    *,
    split_part: Union[Literal["val", "test"], str] = "val",
    worker_config: WorkerConfig,
    batch_size: int,
    batch_drop_last: bool = False,
    packing_buffer_size: Optional[int] = None,
    limit: Optional[int] = None,
    task_encoder: TaskEncoder[Any, Any, Any, T] = DefaultTaskEncoder(),
    **kwargs,
) -> List[Tuple[SavableDataset[T], BaseCoreDatasetFactory]]:
    """
    Get the validation/test dataset with sensible defaults. See `get_dataset` for more details.
    In contrast to `get_val_dataset`, each source dataset is built into its own pipeline and
    returned separately, paired with its source dataset factory.

    The following recipe will be used:

      - :func:`megatron.energon.dataset_config.get_dataset_from_config` (loads the raw dataset)
      - `task_encoder.encode_sample`
      - (:class:`megatron.energon.MixDataset` if mixing)
      - :class:`megatron.energon.BatchDataset` with `task_encoder.batch` for collation
      - :class:`megatron.energon.LimitDataset` (if `limit` is set)
      - `task_encoder.encode_batch`

    Args:
        path: Path to the dataset.
        split_part: Default split part to use.
        worker_config: Worker configuration to use.
        batch_size: Size of a batch
        batch_drop_last: If true, drop the last batch if it is smaller than `batch_size`.
        limit: If set, limit the number of batches loaded from the dataset to this.
        task_encoder: Task encoder to use.
        **kwargs: Additional arguments to the dataset constructor.

    Returns:
        The loaded val datasets, with the source datasets.
    """
    # NOTE(review): the default `task_encoder` is a single instance shared across calls —
    # confirm that this sharing (incl. decoder mutation below) is intended.
    _split_deprecated_decoder_kwargs(kwargs, task_encoder)
    loader = load_dataset(path, **_split_kwargs(kwargs))
    datasets = loader.get_datasets(
        training=False,
        split_part=split_part,
        worker_config=worker_config,
        decoder=task_encoder.decoder,
        **kwargs,
    )
    # Build one independent val pipeline per source dataset.
    return [
        (
            task_encoder.build_val_datasets(
                datasets=[dataset],
                worker_config=worker_config,
                batch_size=batch_size,
                batch_drop_last=batch_drop_last,
                packing_buffer_size=packing_buffer_size,
                limit=limit,
            ),
            dataset.dataset,
        )
        for dataset in datasets.datasets
    ]
Megatron-Energon/src/megatron/energon/tools/__init__.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
Megatron-Energon/src/megatron/energon/tools/analyze_debug.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
json
import
time
import
traceback
from
concurrent.futures.process
import
ProcessPoolExecutor
from
dataclasses
import
dataclass
from
pathlib
import
Path
from
typing
import
(
Container
,
Dict
,
Generator
,
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
,
Union
,
)
import
click
import
numpy
as
np
from
PIL
import
Image
from
tqdm
import
tqdm
# Color palette: 64 RGB rows parsed from the literal below (one "R G B" triple per
# line). Row 0 is white; the rest are visually distinct colors used to colorize
# heatmap points by source index.
cpal = np.array(
    [
        [int(x) for x in line.split(" ")]
        for line in """255 255 255
1 0 103
213 255 0
255 0 86
158 0 142
14 76 161
255 229 2
0 95 57
0 255 0
149 0 58
255 147 126
164 36 0
0 21 68
145 208 203
98 14 0
107 104 130
0 0 255
0 125 181
106 130 108
0 174 126
194 140 159
190 153 112
0 143 156
95 173 78
255 0 0
255 0 246
255 2 157
104 61 59
255 116 163
150 138 232
152 255 82
167 87 64
1 255 254
255 238 232
254 137 0
189 198 255
1 208 255
187 136 0
117 68 177
165 255 210
255 166 254
119 77 0
122 71 130
38 52 0
0 71 84
67 0 44
181 0 255
255 177 103
255 219 102
144 251 146
126 45 210
189 211 147
229 111 254
222 255 116
0 255 120
0 155 255
0 100 1
0 118 255
133 169 0
0 185 23
120 130 49
0 255 198
255 110 65
232 94 190""".split("\n")
    ],
    dtype=np.int32,
)
class YieldBatchLogLine(TypedDict):
    """Schema of a `yield_batch` debug-log line."""

    # Json example:
    # {
    #     "t": "yield_batch",
    #     "r": 1,
    #     "w": 1,
    #     "m": "train",
    #     "idx": 1,
    #     "keys": ["parts/data-train-000051.tar/528866", ...],
    # }
    #: Log line type tag
    t: Literal["yield_batch"]
    # r/w: presumably rank and worker index (matches the other log line types) — confirm.
    r: int
    w: int
    #: Modality the batch belongs to
    m: Literal["train", "val"]
    #: Batch index
    idx: int
    #: Sample keys contained in the batch
    keys: List[str]
class SampleLoaderYieldLogLine(TypedDict):
    """Schema of a sample-loader yield debug-log line (one sample loaded from disk)."""

    # Json example:
    # {
    #     "t": "WebdatasetSampleLoaderDataset._slices_iter.yield",
    #     "r": 1,
    #     "w": 1,
    #     "index": 528800,
    #     "key": "parts/data-train-000051.tar/528866",
    #     "shard": "parts/data-train-000051.tar",
    #     "count": 633,
    #     "epoch": 0,
    #     "epoch_count": 633
    # }
    #: Log line type tag
    t: Literal["WebdatasetSampleLoaderDataset._slices_iter.yield"]
    # r/w: presumably rank and worker index (matches the other log line types) — confirm.
    r: int
    w: int
    #: The global index in the underlying dataset (concats of all shards)
    index: int
    #: The sample key from the shard, concatenated as f"{shard}/{key}"
    key: str
    #: Name of the shard
    shard: str
    #: Number of samples yielded from the sample loader over all epochs
    count: int
    #: Number of repetitions of the dataset (=epochs). First epoch is 0.
    epoch: int
    #: Number of samples yielded from the sample loader in the current epoch
    epoch_count: int
class AutosizingHeatmapWriter:
    """Writes a heatmap, automatically resizing it if necessary.

    The heatmap has a fixed-size backing array; whenever a point falls outside the
    current (sample, step) range, the corresponding axis is downscaled by a factor
    of 2 (adjacent bins summed) so the full range always fits.
    """

    def __init__(self, heatmap_samples: int, heatmap_steps: int, colorize: bool = True):
        """
        Args:
            heatmap_samples: Backing-array size in sample (y) direction.
            heatmap_steps: Backing-array size in step (x) direction.
            colorize: If true, points are colored by their source index via `cpal`;
                otherwise all points use the first palette color.
        """
        self.heatmap = np.zeros((heatmap_samples, heatmap_steps, 3), dtype=np.int32)
        # Downscale factors per axis (doubled on each resize).
        self.heatmap_sample_factor = 1
        self.heatmap_step_factor = 1
        # Maximum (downscaled) indices written so far; -1 = nothing written.
        self.heatmap_sample_max = -1
        self.heatmap_step_max = -1
        self.colors_size = cpal.shape[0] if colorize else 1

    def add(self, sample_id: int, step: int, src: int = 0) -> None:
        """
        Add a point to the heatmap (i.e. increase count at that position).

        Args:
            sample_id: The sample id (y-axis)
            step: The step (x-axis)
            src: Source index selecting the palette color (modulo the palette size).
                Defaults to 0 so callers that do not distinguish sources may omit it.
                (The single-source fallback path in `command` calls `add` with two
                arguments; without this default that call raised a TypeError.)
        """
        # Resize heatmap? Fold adjacent rows/columns together until the point fits.
        while self.heatmap.shape[0] * self.heatmap_sample_factor <= sample_id:
            self.heatmap[: self.heatmap.shape[0] // 2] = self.heatmap[::2] + self.heatmap[1::2]
            self.heatmap[self.heatmap.shape[0] // 2 :] = 0
            self.heatmap_sample_factor *= 2
            self.heatmap_sample_max = 0
        while self.heatmap.shape[1] * self.heatmap_step_factor <= step:
            self.heatmap[:, : self.heatmap.shape[1] // 2] = (
                self.heatmap[:, ::2] + self.heatmap[:, 1::2]
            )
            self.heatmap[:, self.heatmap.shape[1] // 2 :] = 0
            self.heatmap_step_factor *= 2
            self.heatmap_step_max = 0
        # Save point (in downscaled coordinates).
        step //= self.heatmap_step_factor
        sample_id //= self.heatmap_sample_factor
        self.heatmap[sample_id, step] += cpal[src % self.colors_size]
        self.heatmap_step_max = max(self.heatmap_step_max, step)
        self.heatmap_sample_max = max(self.heatmap_sample_max, sample_id)

    def save(self, path: Union[Path, str], gain: float):
        """
        Save the heatmap to the given path.

        Args:
            path: The path to save the heatmap to.
            gain: The gain (=multiplication factor) for the heatmap.

        Returns:
            The maximum sample id and step id that were used in the heatmap.
        """
        # Crop to the area actually written, normalize per channel and apply the gain.
        # NOTE(review): if nothing was added, the per-channel max is 0 and the division
        # produces NaNs — confirm `save` is only called after at least one `add`.
        heatmap = self.heatmap[: self.heatmap_sample_max + 1, : self.heatmap_step_max + 1]
        heatmap = heatmap.astype(np.float32)
        heatmap = np.clip(heatmap * gain / heatmap.max((0, 1)) * 255, 0, 255).astype(np.uint8)
        Image.fromarray(heatmap).save(path)
        return (
            self.heatmap_sample_max * self.heatmap_sample_factor,
            self.heatmap_step_max * self.heatmap_step_factor,
        )
@click.command(name="analyze-debug")
@click.argument(
    "log_paths",
    nargs=-1,
    type=click.Path(exists=True, file_okay=True, dir_okay=True, path_type=Path),
)
@click.option(
    "--heatmap-path",
    type=click.Path(exists=False, writable=True, dir_okay=False, path_type=Path),
    default=Path("heatmap.png"),
)
@click.option(
    "--heatmap-steps",
    type=int,
    default=1000,
    help="Size of the heatmap in step direction. All steps will be downscaled to this size.",
)
@click.option(
    "--heatmap-samples",
    type=int,
    default=1000,
    help="Size of the heatmap in sample direction. All samples will be downscaled to this size.",
)
@click.option(
    "--heatmap-gain",
    type=float,
    default=10,
    help="Gain (=multiplication factor) for the heatmap",
)
@click.option(
    "--force-loading-order",
    is_flag=True,
    default=False,
    help="If true, force using the dataloader loading order instead of batch data",
)
@click.option(
    "--include-modality",
    type=str,
    default="train",
    help="Choose which modality/modalities (train,val) to include. Comma separate for multiple.",
)
@click.option(
    "--skip",
    type=int,
    default=0,
    help="If >0, skip this many steps at the beginning of log file parsing.",
)
@click.option(
    "--no-colors",
    is_flag=True,
    default=False,
    help="If set, disable colorizing ranks.",
)
def command(
    log_paths: List[Path],
    heatmap_path: Path,
    heatmap_steps: int,
    heatmap_samples: int,
    heatmap_gain: float,
    force_loading_order: bool,
    include_modality: str,
    skip: int,
    no_colors: bool,
):
    """Internal tool to analyze randomness.

    The LOG_PATH should point to the folder with the debug log, or to a single log file."""
    # Collect the individual *.jsonl log files from the given paths.
    if len(log_paths) == 0:
        raise click.ClickException("No log paths specified")
    log_files = []
    for log_path in log_paths:
        if log_path.is_dir():
            log_files.extend(sorted(log_path.glob("*.jsonl")))
        elif log_path.is_file():
            log_files.append(log_path)
        else:
            raise click.ClickException(f"Invalid log path: {log_path}")
    if len(log_files) == 0:
        raise click.ClickException("No log files found")
    heatmap = AutosizingHeatmapWriter(heatmap_samples, heatmap_steps, colorize=not no_colors)
    print(f"Analyzing {len(log_files)} logs...")
    modalities = [m.strip() for m in include_modality.split(",")]
    # Maps a sample key to its row (y index) in the heatmap, in first-seen order.
    key_index = {}
    count = 0
    if not force_loading_order:
        # Primary path: use the batch-level log entries (includes shuffle/batching).
        loaders = [LoaderLogIter(log_file, start_idx=skip) for log_file in log_files]
        loaders_by_id: Dict[int, Tuple[LoaderInfo, List[LoaderLogIter]]] = {}
        # Scan all log files in parallel to discover the available loaders.
        with ProcessPoolExecutor(max_workers=16) as executor:
            for loader, loader_info in tqdm(
                executor.map(_proc_map_loader, loaders), total=len(loaders)
            ):
                # NOTE: the inner loop rebinds `loader_info` to the per-id entries.
                for loader_id, loader_info in loader_info.items():
                    if loader_id in loaders_by_id:
                        # Same loader seen in another log file: merge (must agree on
                        # modality and path; keep the largest step count).
                        existing_loader_info, existing_loaders = loaders_by_id[loader_id]
                        assert (
                            existing_loader_info.modality == loader_info.modality
                            and existing_loader_info.path == loader_info.path
                        ), (
                            f"Found multiple loaders for {loader_id}: {existing_loader_info.modality, existing_loader_info.path} and {loader_info.modality, loader_info.path}"
                        )
                        existing_loader_info.global_count = max(
                            existing_loader_info.global_count, loader_info.global_count
                        )
                        existing_loaders.append(loader)
                    else:
                        loaders_by_id[loader_id] = (loader_info, [loader])
        print("Available loaders:")
        selected_loader_id = None
        must_select = False
        for loader_id, (loader_info, _iters) in loaders_by_id.items():
            print(
                f"  {loader_id}: {loader_info.modality} {loader_info.path} {loader_info.global_count} steps"
            )
            if loader_info.modality in modalities:
                if selected_loader_id is None:
                    selected_loader_id = loader_id
                else:
                    # Have multiple loaders
                    must_select = True
        if must_select:
            # Multiple matching loaders: ask the user interactively.
            while True:
                loader_id_str = input("Choose loader id: ")
                try:
                    selected_loader_id = int(loader_id_str)
                except ValueError:
                    # NOTE(review): the trailing "1" in this message looks like a typo.
                    print(f"Invalid loader id {loader_id_str}1")
                    continue
                if selected_loader_id in loaders_by_id:
                    break
                print(f"Invalid loader id {selected_loader_id}")
        assert selected_loader_id is not None
        selected_loader_info, selected_loader_readers = loaders_by_id[selected_loader_id]
        print(
            f"Reading for loader {selected_loader_id}: {selected_loader_info.modality} {selected_loader_info.path}"
        )
        log_iters = [
            (idx, loader.log_entries(loader_ids={selected_loader_id}))
            for idx, loader in enumerate(selected_loader_readers)
        ]
        with tqdm(total=selected_loader_info.global_count) as pbar:
            while len(log_iters) > 0:
                cur_count = 0
                # Iterate over all iterators for this count and put into heatmap
                for src_idx, log_iter in tuple(log_iters):
                    # Iterate until None (=next count) is encountered
                    while True:
                        try:
                            log_keys = next(log_iter)
                        except StopIteration:
                            log_iters.remove((src_idx, log_iter))
                            break
                        except OSError:
                            traceback.print_exc()
                            log_iters.remove((src_idx, log_iter))
                            break
                        else:
                            if log_keys is None:
                                break
                            for log_key in log_keys:
                                key_id = key_index.setdefault(log_key, len(key_index))
                                heatmap.add(key_id, count, src_idx)
                                cur_count += 1
                if cur_count == 0:
                    print(f"No data for step {count}")
                count += 1
                pbar.update(1)
    if len(key_index) == 0:
        # Fallback path: no batch-level data found (or forced) — use the raw sample
        # loader logs instead (loading order from disk only).
        if force_loading_order:
            print("Forcing to use sample loader logs")
        else:
            print("No batch information in logs, trying sample loader logs...")
        # NOTE(review): `modalities` is a list, so comparing it to a set is always
        # unequal and this branch is always taken — likely meant
        # `set(modalities) != {"train", "val"}`; verify.
        if modalities != {"train", "val"}:
            print("  Data includes all modalities (train and val)")
        print(
            "  Shuffle buffer and batching will not be considered, only the loading order from disk"
        )
        log_iters = [
            _iter_sl_log_line_keys(_iter_sl_log_samples(log_file), start_idx=skip)
            for log_file in log_files
        ]
        key_index = {}
        count = 0
        start = time.time()
        while len(log_iters) > 0:
            cur_count = 0
            # Iterate over all iterators for this count and put into heatmap
            for log_iter in tuple(log_iters):
                # Iterate until None (=next count) is encountered
                while True:
                    try:
                        log_key = next(log_iter)
                    except StopIteration:
                        log_iters.remove(log_iter)
                        break
                    except OSError:
                        traceback.print_exc()
                        log_iters.remove(log_iter)
                        break
                    else:
                        if log_key is None:
                            break
                        key_id = key_index.setdefault(log_key, len(key_index))
                        # NOTE(review): `add` is called without the `src` argument here —
                        # this relies on `src` having a default value; verify against
                        # AutosizingHeatmapWriter.add's signature.
                        heatmap.add(key_id, count)
                        cur_count += 1
            if cur_count == 0:
                print(f"No data for step {count}")
            if time.time() - start > 10:
                # Periodic progress output for long runs.
                print(f"  Step {count}")
                start = time.time()
            count += 1
    if count == 0:
        raise click.ClickException("No data found in logs")
    print(f"Found {len(key_index)} unique sample keys, {count} steps")
    # print(f"Heatmap factors: {heatmap_sample_factor} samples, {heatmap_step_factor} steps")
    # print(f"Heatmap max: {heatmap_sample_max} samples, {heatmap_step_max} steps")
    n_samples, n_steps = heatmap.save(heatmap_path, heatmap_gain)
    print(f"Wrote heatmap to {heatmap_path}")
    print("Heatmap axes:")
    print(f"  x-axis: {n_steps} worker steps")
    print(f"  y-axis: {n_samples} samples")
class LoaderInitLogLine(TypedDict):
    """Schema of a data-loader construction debug-log line."""

    #: Log line type tag
    t: Literal["SavableLoader.__init__", "BasicDataLoader.__init__"]
    # r: presumably the rank — confirm.
    r: int
    #: Always None for loader-level lines (no worker context)
    w: None
    #: Loader id
    id: int
    #: Full dataset/loader configuration dict
    config: dict
class LoaderIterLogLine(TypedDict):
    """Schema of a debug-log line written when iteration of a loader starts."""

    #: Log line type tag
    t: Literal["SavableDataLoader.iter", "BasicDataLoader.iter"]
    # r: presumably the rank — confirm.
    r: int
    #: Always None for loader-level lines (no worker context)
    w: None
    #: Loader id
    id: int
    #: Id of this iteration of the loader
    iter_id: int
class LoaderYieldLogLine(TypedDict):
    """Schema of a debug-log line written when a loader yields a batch."""

    #: Log line type tag
    t: Literal["SavableDataLoader.yield", "BasicDataLoader.yield"]
    # r: presumably the rank — confirm.
    r: int
    #: Always None for loader-level lines (no worker context)
    w: None
    #: Loader id
    id: int
    #: Id of this iteration of the loader
    iter_id: int
    #: Worker identification
    worker_id: int
    worker_idx: int
    #: Yield indices (per worker / per iteration / global)
    idx: int
    iter_idx: int
    global_idx: int
    #: Sample keys contained in the yielded batch (if available)
    keys: Optional[List[str]]
class LoaderStopLogLine(TypedDict):
    """Schema of a debug-log line written when a loader iteration stops."""

    #: Log line type tag
    t: Literal["SavableDataLoader.StopIteration", "BasicDataLoader.StopIteration"]
    # r: presumably the rank — confirm.
    r: int
    #: Always None for loader-level lines (no worker context)
    w: None
    #: Loader id
    id: int
    #: Id of this iteration of the loader
    iter_id: int
#: Union of all loader-level debug-log line schemas.
LoaderLines = Union[
    LoaderInitLogLine,
    LoaderIterLogLine,
    LoaderYieldLogLine,
    LoaderStopLogLine,
]

#: All "t" type tags of loader-level log lines (used for line pre-filtering).
LOADER_LOG_LINE_TYPES_T = (
    "SavableLoader.__init__",
    "BasicDataLoader.__init__",
    "SavableDataLoader.iter",
    "BasicDataLoader.iter",
    "SavableDataLoader.yield",
    "BasicDataLoader.yield",
    "SavableDataLoader.StopIteration",
    "BasicDataLoader.StopIteration",
)
@dataclass
class LoaderInfo:
    """Summary of one data loader instance found in a debug log."""

    # Loader instance id.
    id: int
    # "train" or "val", derived from the loader config.
    modality: str
    # Dataset path, derived from the loader config.
    path: str
    # Highest global sample index seen for this loader.
    global_count: int
class LoaderLogIter:
    """Iterates loader-related entries of a single energon debug log file.

    Fix: `log_entries` previously never advanced its running index after
    yielding an entry's keys, so every matched entry was preceded by a
    spurious ``None`` placeholder. The index is now advanced once per
    consumed global index.
    """

    def __init__(self, path: Path, start_idx: int = 0):
        # Path to the debug log file (one JSON object per line).
        self._path = path
        # Global sample index to start yielding from in log_entries().
        self._start_idx = start_idx

    def _iter_log_lines(self, which: Iterable[str]) -> Generator[LoaderLines, None, None]:
        """Yield parsed JSON log lines whose "t" tag is in `which`.

        Lines are pre-filtered with a cheap substring check before JSON
        parsing; undecodable lines and IO errors are reported and skipped.
        """
        try:
            with self._path.open("r") as rf:
                for line in rf:
                    # Cheap textual pre-filter to avoid parsing every line.
                    if any(f'"t": "{t}"' in line for t in which):
                        try:
                            yield json.loads(line.strip())
                        except json.JSONDecodeError:
                            print("Cannot decode line", repr(line))
        except IOError as e:
            # Best-effort: a partially written/unreadable log is not fatal.
            print(f"Ignoring IOError: {e} for {self._path}")

    @staticmethod
    def _find_config_modality(config: dict) -> Literal["train", "val"]:
        """Recursively find whether a loader config describes a train or val loader."""
        assert isinstance(config, dict)
        if "map_fn_config" in config and "training" in config["map_fn_config"]:
            return "train" if config["map_fn_config"]["training"] else "val"
        elif "dataset" in config:
            return LoaderLogIter._find_config_modality(config["dataset"])
        elif "dataset_weights" in config:
            return LoaderLogIter._find_config_modality(config["dataset_weights"][0][0])
        elif "datasets" in config:
            return LoaderLogIter._find_config_modality(config["datasets"][0])
        assert False, f"Unrecognized config {config}"

    @staticmethod
    def _find_config_path(config: dict) -> str:
        """Recursively find the dataset path inside a (possibly nested) loader config."""
        assert isinstance(config, dict)
        if "map_fn_config" in config and "_path" in config["map_fn_config"]:
            return config["map_fn_config"]["_path"]
        elif "dataset" in config:
            return LoaderLogIter._find_config_path(config["dataset"])
        elif "dataset_weights" in config:
            return LoaderLogIter._find_config_path(config["dataset_weights"][0][0])
        elif "datasets" in config:
            return LoaderLogIter._find_config_path(config["datasets"][0])
        assert False, f"Unrecognized config {config}"

    def loaders(self) -> Dict[int, LoaderInfo]:
        """Collect all loader instances in the log, with their last global index."""
        loaders = {}
        for log_line in self._iter_log_lines(
            (
                "SavableLoader.__init__",
                "BasicDataLoader.__init__",
                "SavableDataLoader.yield",
                "BasicDataLoader.yield",
            )
        ):
            if log_line["t"] in ("SavableLoader.__init__", "BasicDataLoader.__init__"):
                loaders[log_line["id"]] = LoaderInfo(
                    id=log_line["id"],
                    modality=self._find_config_modality(log_line["config"]),
                    path=self._find_config_path(log_line["config"]),
                    global_count=0,
                )
            elif log_line["t"] in ("SavableDataLoader.yield", "BasicDataLoader.yield"):
                # Yield lines appear in increasing global_idx order, so this
                # ends up being the last observed index.
                loaders[log_line["id"]].global_count = log_line["global_idx"]
        return loaders

    def log_entries(
        self, loader_ids: Container[int]
    ) -> Generator[Optional[List[str]], None, None]:
        """Yield, per global sample index, the sample keys of matching yield lines.

        Yields one item per global index starting at `self._start_idx`:
        the entry's keys list (may be None if keys were not logged) for
        indices present in the log, and None for gaps.
        """
        idx = self._start_idx
        for log_line in self._iter_log_lines(
            ("SavableDataLoader.yield", "BasicDataLoader.yield")
        ):
            if (
                log_line["t"] in ("SavableDataLoader.yield", "BasicDataLoader.yield")
                and log_line["id"] in loader_ids
            ):
                assert log_line["global_idx"] >= idx, (
                    f"Found entry {log_line} with wrong idx < {idx}"
                )
                # Fill gaps with None so the yield position equals global_idx.
                while log_line["global_idx"] != idx:
                    yield None
                    idx += 1
                # BUGFIX: also consume this index; previously idx was not
                # advanced here, shifting every subsequent entry by one.
                yield log_line.get("keys")
                idx += 1

    def __repr__(self) -> str:
        return f"log({str(self._path)})"
def _proc_map_loader(loader: LoaderLogIter) -> Tuple[LoaderLogIter, Dict[int, LoaderInfo]]:
    """Resolve the loaders of one log file; picklable helper for process pools."""
    loader_infos = loader.loaders()
    return loader, loader_infos
def
_iter_sl_log_line_keys
(
log_lines
:
Iterable
[
SampleLoaderYieldLogLine
],
start_idx
:
int
=
0
,
)
->
Generator
[
Optional
[
str
],
None
,
None
]:
count
=
start_idx
for
log_line
in
log_lines
:
if
log_line
[
"count"
]
<
start_idx
:
continue
assert
log_line
[
"count"
]
>=
count
while
log_line
[
"count"
]
!=
count
:
yield
None
count
+=
1
yield
log_line
[
"key"
]
def _iter_sl_log_samples(path: Path) -> Generator[SampleLoaderYieldLogLine, None, None]:
    """Yield parsed sample-loader yield entries from the debug log at `path`."""
    marker = '"t": "WebdatasetSampleLoaderDataset._slices_iter.yield"'
    with path.open("r") as rf:
        for line in rf:
            # Cheap substring filter before attempting JSON parsing.
            if marker not in line:
                continue
            try:
                yield json.loads(line.strip())
            except json.JSONDecodeError:
                print("Cannot decode line", repr(line))


if __name__ == "__main__":
    command()
Megatron-Energon/src/megatron/energon/tools/checkpoint.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
re
from
typing
import
List
,
Optional
import
click
import
torch
from
megatron.energon.epathlib
import
EPath
from
megatron.energon.savable_loader
import
SavableDataLoaderState
def natural_sort_key(s):
    """
    Function to use for natural sorting of filenames.

    This splits the input string by numbers and non-numbers and ensures
    that numbers are compared as integers, not as strings.
    """
    parts = re.split(r"(\d+)", s)
    return [int(chunk) if chunk.isdigit() else chunk.lower() for chunk in parts]
def detect_and_replicate_pattern(file_list):
    """
    Given a list of file paths, detect the single numeric pattern and return
    a function that, when called with integer n (starting from 0), generates
    the n-th filename following that pattern.

    Raises an Exception if no pattern or multiple patterns are found.
    """
    if not file_list:
        raise ValueError("Cannot detect a pattern from an empty list.")

    # Natural sort so that the numeric portion compares numerically.
    sorted_files = sorted(file_list, key=natural_sort_key)

    # Tokenize each filename into alternating non-digit/digit chunks,
    # e.g. "f001.txt" -> ["f", "001", ".txt"].
    tokenized = [re.split(r"(\d+)", name) for name in sorted_files]

    # All filenames must decompose into the same number of chunks.
    token_len = len(tokenized[0])
    if any(len(tokens) != token_len for tokens in tokenized):
        raise Exception("Filenames do not share a consistent token structure.")

    # Find the (single) digit position that varies across files; all
    # non-digit positions must be identical across the whole list.
    num_positions = []
    for pos in range(token_len):
        column = [tokens[pos] for tokens in tokenized]
        if tokenized[0][pos].isdigit():
            # Compare digit chunks by integer value.
            values = [int(x) if x.isdigit() else None for x in column]
            if len(set(values)) > 1:
                num_positions.append(pos)
        elif len(set(column)) != 1:
            raise Exception("Non-digit token differs among files. Invalid pattern.")

    if len(num_positions) == 0:
        raise Exception("No numeric portion found that differs among files.")
    if len(num_positions) > 1:
        raise Exception("Multiple numeric portions found that differ. Not a single pattern.")
    varying_pos = num_positions[0]

    # The varying numbers must be consecutive integers after sorting.
    numeric_values = [int(tokens[varying_pos]) for tokens in tokenized]
    for prev, nxt in zip(numeric_values, numeric_values[1:]):
        if nxt - prev != 1:
            raise Exception("Numeric values are not consecutive. Pattern is invalid.")

    base_value = numeric_values[0]
    # Zero padding is taken from the first file's varying chunk.
    zero_padding_width = len(tokenized[0][varying_pos])

    def generate_filename(n):
        # Clone the first file's tokens and substitute the varying number.
        new_tokens = tokenized[0][:]
        new_tokens[varying_pos] = str(base_value + n).zfill(zero_padding_width)
        return "".join(new_tokens)

    # Sanity check: the generator must reproduce the sorted input exactly.
    for i, expected in enumerate(sorted_files):
        if generate_filename(i) != expected:
            raise Exception(
                "Verification failed. The generated pattern does not match the input list."
            )

    return generate_filename
class RankStateIterable:
    """Iterates the SavableDatasetCheckpoints of multiple ranks in a round-robin fashion."""

    def __init__(self, state_files: List[EPath]):
        state_file_names = [state_file.name for state_file in state_files]
        # Function mapping rank index -> checkpoint filename.
        self.file_pattern_func = detect_and_replicate_pattern(state_file_names)
        self.num_states = len(state_files)

        # Load the first file to detect the checkpoint layout.
        first_state = torch.load(str(state_files[0]), weights_only=False)
        # Megatron wraps the loader state in a dict under "dataloader_state_dict".
        if isinstance(first_state, dict) and "dataloader_state_dict" in first_state:
            self.megatron_style = True
            first_state = first_state["dataloader_state_dict"]
        else:
            self.megatron_style = False

        if isinstance(first_state, SavableDataLoaderState):
            # Per-rank checkpoint files: load the remaining ranks.
            if self.megatron_style:
                rest = [
                    torch.load(str(state_file), weights_only=False)["dataloader_state_dict"]
                    for state_file in state_files[1:]
                ]
            else:
                rest = [
                    torch.load(str(state_file), weights_only=False)
                    for state_file in state_files[1:]
                ]
            self.rank_states = [first_state] + rest
            self.is_global_checkpoint = False
        elif isinstance(first_state, list):
            # A single global checkpoint holding all rank states.
            assert len(state_files) == 1, "Global checkpoint must contain exactly one file"
            assert all(isinstance(state, SavableDataLoaderState) for state in first_state)
            self.rank_states = first_state
            self.is_global_checkpoint = True
        else:
            raise ValueError(f"Unknown checkpoint type: {type(first_state)}")

        self.rank_cur_worker = [0] * len(self.rank_states)
        self.rank_worker_offset = [state.next_worker_id for state in self.rank_states]
        self.rank_num_workers = [len(state.worker_states) for state in self.rank_states]
        assert all(
            self.rank_num_workers[0] == num_workers for num_workers in self.rank_num_workers
        ), "All ranks must have the same number of workers."

    def write_new_states_to_folder(
        self, output_folder: EPath, new_states: List[SavableDataLoaderState]
    ):
        """Write one checkpoint file per rank, reusing the detected filename pattern."""
        for rank_idx, rank_state in enumerate(new_states):
            output_file = output_folder / self.file_pattern_func(rank_idx)
            if self.megatron_style:
                # Re-wrap in the Megatron dict layout.
                torch.save({"dataloader_state_dict": rank_state}, str(output_file))
            else:
                torch.save(rank_state, str(output_file))

    def get_num_ranks(self):
        return len(self.rank_states)

    def get_num_workers(self):
        return self.rank_num_workers[0]

    def get_micro_batch_size(self):
        return self.rank_states[0].micro_batch_size

    def __iter__(self):
        """Iterates the SavableDatasetCheckpoints of multiple ranks in a round-robin fashion."""
        for rank, state in enumerate(self.rank_states):
            for worker_state in state.worker_states:
                yield worker_state
@click.command(name="redist")
@click.argument(
    "input_files",
    nargs=-1,
    type=click.Path(file_okay=True, dir_okay=False, exists=True, path_type=EPath),
    required=True,
)
@click.argument(
    "output_path",
    type=click.Path(file_okay=False, dir_okay=True, path_type=EPath),
)
@click.option("--new-world-size", type=int, help="Number of ranks to redistribute to", required=False)
def command_redist(
    input_files: List[EPath], output_path: EPath, new_world_size: Optional[int] = None
):
    """Redistribute a checkpoint.

    Read checkpoint files from INPUT_FILES and redistribute them for a new
    number of ranks. Write the output to OUTPUT_PATH."""
    # Verify input files
    if not input_files:
        raise click.ClickException("No input files provided")
    input_file_list = sorted(input_files, key=lambda x: natural_sort_key(x.name))
    click.echo(f"Processing {len(input_file_list)} checkpoint files")

    # Determine if we're processing a single global checkpoint or multiple rank files
    rsi = RankStateIterable(input_file_list)
    if not rsi.rank_states:
        raise click.ClickException("No valid checkpoint states found")

    # Interactively query the new world size if not passed on the CLI.
    if new_world_size is None:
        click.echo(f"Current DP world size: {rsi.get_num_ranks()}")
        click.echo(f"Current number of workers per DP rank: {rsi.get_num_workers()}")
        new_world_size = click.prompt("Please enter the new DP world size", type=int)
    assert isinstance(new_world_size, int)
    if new_world_size <= 0:
        raise click.ClickException("New world size must be greater than 0")

    total_num_workers = rsi.get_num_workers() * rsi.get_num_ranks()
    assert total_num_workers % new_world_size == 0, (
        "New DP world size must be a multiple of the current DP world size"
    )
    new_workers_per_rank = total_num_workers // new_world_size

    # Ensure output directory exists
    output_path.mkdir(exist_ok=True, parents=True)

    # Deal the worker states round-robin onto the new ranks.
    new_rank_states = [list() for _ in range(new_world_size)]
    rsi_iter = iter(rsi)
    for rank_idx in range(new_world_size):
        for _ in range(new_workers_per_rank):
            new_rank_states[rank_idx].append(next(rsi_iter))
    assert all(
        len(new_rank_states[0]) == len(new_rank_states[rank])
        for rank in range(1, new_world_size)
    ), "All ranks must have the same number of workers, also for the new distribution."

    new_states = [
        SavableDataLoaderState(
            worker_states=new_rank_state,
            next_worker_id=0,  # Reset the next worker ID
            micro_batch_size=rsi.get_micro_batch_size(),
        )
        for new_rank_state in new_rank_states
    ]

    # Save the redistributed checkpoint
    if rsi.is_global_checkpoint:
        # Save as a single global checkpoint file
        output_file = output_path / input_file_list[0].name
        torch.save(new_states, str(output_file))
        click.echo(f"Saved global checkpoint to {output_file}")
    else:
        rsi.write_new_states_to_folder(output_path, new_states)
        click.echo(f"Saved {new_world_size} rank checkpoint files to {output_path}")
@click.command(name="info")
@click.argument(
    "input_files",
    nargs=-1,
    type=click.Path(file_okay=True, dir_okay=False, exists=True, path_type=EPath),
    required=True,
)
def command_info(input_files: List[EPath]):
    """Display information about a checkpoint.

    Read a checkpoint from CHECKPOINT_PATH (either a single file or directory with *.pt files)
    and display information about it.
    """
    # Load the checkpoint(s)
    rsi = RankStateIterable(input_files)

    # Display basic information
    if rsi.is_global_checkpoint:
        click.echo("Checkpoint type: Global checkpoint")
    else:
        click.echo("Checkpoint type: Per-rank checkpoint files")
    click.echo(f"Number of DP ranks: {rsi.get_num_ranks()}")
    click.echo(f"Number of workers per DP rank: {rsi.get_num_workers()}")

    # Additional detailed information
    click.echo("\nDetailed information:")
    for rank_idx, state in enumerate(rsi.rank_states):
        if rsi.is_global_checkpoint:
            click.echo(f"  Rank {rank_idx}:")
        else:
            click.echo(f"  Rank {rank_idx} ({input_files[rank_idx].name}):")
        click.echo(f"    Next worker ID: {state.next_worker_id}")
        click.echo(f"    Number of worker states: {len(state.worker_states)}")
@click.group(
    name="checkpoint",
    context_settings=dict(help_option_names=["-h", "--help"]),
    invoke_without_command=True,
)
@click.pass_context
def command(ctx):
    """Tools for energon checkpoints."""
    # This is needed to show help if no subcommand is provided
    if ctx.invoked_subcommand is None:
        click.echo(command.get_help(ctx))


command.add_command(command_redist)
command.add_command(command_info)

if __name__ == "__main__":
    command()
Megatron-Energon/src/megatron/energon/tools/info.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
importlib
import
import_module
from
typing
import
Type
import
braceexpand
import
click
import
megatron.energon
from
megatron.energon.epathlib
import
EPath
from
megatron.energon.eyaml
import
load_yaml
,
load_yaml_json
from
megatron.energon.flavors.base_dataset
import
BaseCoreDatasetFactory
from
megatron.energon.flavors.webdataset
import
MAIN_FOLDER_NAME
from
megatron.energon.flavors.webdataset.metadata
import
get_dataset_info
# Markdown-ish template for the dataset summary printed by the info command.
fmt = """
* {dataset_name}: `{path}` ({samples_count:,} samples, {samples_size} in {shards_count} shards)
* Created with energon version: {energon_version}
* Sample Type: {{py:class}}`{sample_name} <{sample_fullname}>`
* Default Splits:
{splits_str}
"""

# Per-split line used to build `splits_str` above.
split_fmt = """  * `{split_name}`: {split_ratio:.0f}%, {split_samples_count:,} samples in {split_shards_count} shards
"""
def fmt_size(size: int) -> str:
    """Format a byte count as a human-readable string with binary units."""
    # Walk up the units, dividing by 1024 until the value fits.
    for unit in ("B", "KiB", "MiB", "GiB", "TiB"):
        if size < 1024:
            return f"{size:.2f} {unit}"
        size /= 1024
    # Anything beyond TiB is reported in PiB.
    return f"{size:.2f} PiB"
@click.command(name="info")
@click.argument(
    "path",
    type=click.Path(file_okay=False, dir_okay=True, path_type=EPath),
)
@click.option("--split-config", default="split.yaml", help="Split config file name", show_default=True)
@click.option("--dataset-config", default="dataset.yaml", help="Dataset config file name", show_default=True)
def command(
    path: EPath,
    split_config: str,
    dataset_config: str,
):
    """
    Get summarizing information about a dataset.
    """
    # Load dataset config, info metadata and split config.
    ds_config = load_yaml((path / MAIN_FOLDER_NAME / dataset_config).read_bytes())
    info_config = get_dataset_info(path)
    split_config_obj = load_yaml_json(path / MAIN_FOLDER_NAME / split_config)

    ds_energon_version = info_config.get("energon_version", "unknown")
    samples_count = sum(info_config["shard_counts"].values())

    # Resolve the sample class from its module/class name in the config.
    dict_sample_type = ds_config["sample_type"]
    sample_module = import_module(dict_sample_type["__module__"])
    sample_cls: Type[BaseCoreDatasetFactory] = getattr(
        sample_module, dict_sample_type["__class__"]
    )
    sample_module = sample_cls.__module__
    # Prefer the short "megatron.energon" module path when re-exported there.
    if (
        sample_module.startswith("megatron.energon")
        and getattr(megatron.energon, dict_sample_type["__class__"], None) == sample_cls
    ):
        sample_module = "megatron.energon"
    sample_name = sample_cls.__name__
    sample_fullname = sample_module + "." + sample_name

    def srt_key(pair):
        # Order known splits train/val/test first, everything else last.
        try:
            return ("train", "val", "test").index(pair[0])
        except ValueError:
            return 3

    # Brace expand all the split part files
    expanded_split_parts = {}
    for split_name, split_parts in split_config_obj["split_parts"].items():
        expanded = []
        for split_part in split_parts:
            expanded.extend(braceexpand.braceexpand(split_part))
        expanded_split_parts[split_name] = expanded

    splits_str = "".join(
        split_fmt.format(
            split_name=split_name,
            split_ratio=round(
                100
                * sum(info_config["shard_counts"][shard] for shard in split_parts)
                / samples_count,
                2,
            ),
            split_samples_count=sum(
                info_config["shard_counts"][shard] for shard in split_parts
            ),
            split_shards_count=len(split_parts),
        )
        for split_name, split_parts in sorted(expanded_split_parts.items(), key=srt_key)
    )
    print(
        fmt.format(
            dataset_name=path.name,
            path=str(path),
            samples_count=samples_count,
            samples_size=fmt_size(
                sum(
                    (path / split_name).size()
                    for split_name in info_config["shard_counts"].keys()
                )
            ),
            shards_count=len(info_config["shard_counts"]),
            sample_name=sample_name,
            sample_fullname=sample_fullname,
            splits_str=splits_str,
            energon_version=ds_energon_version,
        )
    )
Megatron-Energon/src/megatron/energon/tools/lint.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
logging
import
traceback
from
typing
import
Optional
import
click
import
tqdm
from
megatron.energon
import
WorkerConfig
,
get_loader
,
get_val_dataset
from
megatron.energon.epathlib
import
EPath
from
megatron.energon.flavors.webdataset.empty_dataset_error
import
EmptyDatasetError
from
megatron.energon.flavors.webdataset.itar
import
TarIndexReader
from
megatron.energon.flavors.webdataset.metadata
import
get_dataset_info
def check_index(dataset_path: EPath):
    """Verify that each shard's tar index length matches the info metadata.

    Returns True if all indices are consistent (or the path is not a
    directory, in which case the check is skipped), False otherwise.

    Fix: previously the early "not a directory" exit returned None (falsy),
    so the caller treated a skipped check as a validation failure even
    though the message says the index "will not be checked". It now returns
    True to make the skip non-fatal.
    """
    if not dataset_path.is_dir():
        click.echo(
            f"{dataset_path} is not a directory, therefore the index will not be checked"
        )
        # Skipped check must not count as a failure.
        return True
    ok = True
    # Get info file
    info = get_dataset_info(dataset_path)
    click.echo("Checking the index files...")
    shards = info["shard_counts"]
    for shard_file, length in shards.items():
        with TarIndexReader(dataset_path / shard_file) as itar:
            # The index has one extra entry (the end offset), hence len - 1.
            l = len(itar)
            if l - 1 != length:
                ok = False
                print(
                    f"Error in shard {shard_file}: Shard length in Info file {length} != {l - 1} (length in index)"
                )
    return ok
@click.command(name="lint")
@click.argument(
    "path",
    type=click.Path(path_type=EPath),
)
@click.option("--split-parts", default="train,val,test", help="The splits to verify", show_default=True)
@click.option("--dataset-config", default="dataset.yaml", help="Dataset config file name", show_default=True)
@click.option("--split-config", default="split.yaml", help="Split config file name", show_default=True)
@click.option("--parallel", default=1, help="Number of parallel workers", show_default=True, type=int)
def command(path: EPath, split_parts: str, dataset_config: str, split_config: str, parallel: int):
    """Check energon dataset for errors.

    The PATH should point to the folder with the dataset.
    The dataset must comply with the energon dataset format. See README.md for more details."""
    # Check the tar file index
    if not check_index(path):
        raise click.ClickException("Validation failed with errors, see logs for details.")

    # Check the dataset
    failed = False
    ignore_list = []

    def handler(exc: Exception, key: Optional[str] = None) -> None:
        # Record any sample error, remember the key for the ignore list.
        nonlocal failed
        failed = True
        logging.exception(str(exc))
        if key is not None:
            ignore_list.append(key)

    kwargs = {}
    # Only forward non-default config names.
    if dataset_config != "dataset.yaml":
        kwargs["dataset_config"] = dataset_config
    if split_config != "split.yaml":
        kwargs["split_config"] = split_config

    worker_config = WorkerConfig(rank=0, world_size=1, num_workers=parallel)
    for split_part in split_parts.split(","):
        try:
            dataset = get_val_dataset(
                EPath(path),
                split_part=split_part,
                worker_config=worker_config,
                batch_size=1,
                handler=handler,
                **kwargs,
            )
        except EmptyDatasetError:
            click.echo(f"Skipping empty split part {split_part}")
            continue
        try:
            # Iterate the whole split; errors are collected via `handler`.
            for _ in tqdm.tqdm(get_loader(dataset)):
                pass
        except InterruptedError:
            raise
        except BaseException:
            traceback.print_exc()
            raise click.ClickException("Validation failed with errors, see logs for details.")
    if failed:
        click.echo(
            "The following shards/samples failed (maybe set as dataset.yaml:ignore_list):",
            err=True,
        )
        for item in ignore_list:
            click.echo(f"- {item}", err=True)
        raise click.ClickException("Validation failed with errors, see logs for details.")


if __name__ == "__main__":
    command()
Megatron-Energon/src/megatron/energon/tools/mount.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
os
import
stat
from
errno
import
EBADF
,
ENOENT
from
pathlib
import
Path
from
sqlite3
import
OperationalError
from
typing
import
Generator
,
Optional
import
click
from
megatron.energon.cache.file_store
import
WebdatasetFileStore
from
megatron.energon.epathlib
import
EPath
# Marker entry shown for keys that occur in more than one tar file.
MULTI_WARN = "WARNING_SAME_KEY_IN_MULTIPLE_TAR_FILES"

EnergonFS: Optional[type]

try:
    from mfusepy import FUSE, FuseOSError, Operations

    class _EnergonFS(Operations):
        """
        Read-only filesystem that exposes an energon WebdatasetFileStore.
        """

        def __init__(
            self,
            db_path: EPath,
            *,
            sample_folders: bool = False,
            print_debug: int = 0,
            allow_slow_mode: bool = False,
        ) -> None:
            self._sample_folders = sample_folders
            self._wds_filestore = WebdatasetFileStore(db_path)
            self._all_sample_parts = {}
            self._slow_mode = False
            try:
                for key, size, tar_file_id in self._wds_filestore.list_all_sample_parts():
                    if key not in self._all_sample_parts:
                        # Only take the first tar file id
                        self._all_sample_parts[key] = size
            except OperationalError:
                # Older datasets lack the sample-parts table.
                if not allow_slow_mode:
                    raise RuntimeError(
                        "The dataset was prepared with an older version of energon. Either update the dataset, or allow slow mode."
                    )
                else:
                    assert sample_folders, (
                        "Only sample_folders mode is supported when using slow mode."
                    )
                    self._slow_mode = True
            self._samples_with_multiple_tar_files = set()
            self._all_samples = {}
            for key, size, tar_file_id in self._wds_filestore.list_all_samples():
                if key not in self._all_samples:
                    self._all_samples[key] = size
                else:
                    self._samples_with_multiple_tar_files.add(key)
            # Lazily computed on first statfs().
            self._total_size = None
            # When a file is opened, we keep the bytes in memory for now (until it is closed)
            self._open_files = {}
            # Get current uid and gid
            self._uid = os.getuid()
            self._gid = os.getgid()
            # Get modification time of the db file
            try:
                self._mtime = os.path.getmtime(str(db_path))
            except FileNotFoundError:
                # Remote file systems have no modification time
                self._mtime = 0
            self._print = print_debug

        def statfs(self, path: str) -> dict:
            """Return information about the file system.

            This is called when the user runs `df` on the mount point.
            """
            if self._total_size is None:
                print("Computing total size...", end="", flush=True)
                self._total_size = self._wds_filestore.get_total_size()
                print(f"done: {self._total_size} bytes")
            return dict(
                f_bsize=512,
                f_blocks=self._total_size // 512,
                f_bavail=0,
                f_bfree=0,
                f_files=len(self._all_sample_parts) if not self._slow_mode else 0,
                f_ffree=0,
                f_namemax=1024,
            )

        def getattr(self, path: str, fh: int = 0) -> dict:
            """Return information about one file or folder.

            This is called when using `ls -l` etc.
            Returns a dict with st_mode, st_nlink, st_size, st_ctime,
            st_mtime, st_atime, st_uid, st_gid.
            """
            if path[0] != "/":
                raise FuseOSError(ENOENT)
            if path == "/":
                # Mount root: a read-only directory.
                return dict(
                    st_mode=0o555 | stat.S_IFDIR,
                    st_nlink=2,
                    st_size=0,
                    st_ctime=self._mtime,
                    st_mtime=self._mtime,
                    st_atime=self._mtime,
                    st_uid=self._uid,
                    st_gid=self._gid,
                )
            # Strip leading '/'
            path = path[1:]
            if path.endswith(MULTI_WARN):
                # Warning marker entries are inaccessible block-device stubs.
                return dict(
                    st_mode=0o000 | stat.S_IFBLK,
                    st_nlink=1,
                    st_size=0,
                    st_ctime=self._mtime,
                    st_mtime=self._mtime,
                )
            if self._sample_folders:
                folder, part_name = self._path_parts(path)
                if part_name != "":
                    # This is a sample part (file)
                    if folder not in self._all_samples:
                        raise FuseOSError(ENOENT)
                    full_name = f"{folder}.{part_name}"
                    if self._slow_mode and full_name not in self._all_sample_parts:
                        # Slow mode: populate the part cache on demand.
                        for entry, size, tar_file_id in self._wds_filestore.list_sample_parts(
                            folder, slow_mode=True
                        ):
                            cur_full_name = f"{folder}.{entry}"
                            self._all_sample_parts[cur_full_name] = size
                    if full_name not in self._all_sample_parts:
                        raise FuseOSError(ENOENT)
                    file_size = self._all_sample_parts[full_name]
                    mode = 0o444 | stat.S_IFREG
                else:
                    # This is a sample (directory)
                    if path not in self._all_samples:
                        raise FuseOSError(ENOENT)
                    file_size = self._all_samples[path]
                    mode = 0o555 | stat.S_IFDIR
            else:
                if path not in self._all_sample_parts:
                    raise FuseOSError(ENOENT)
                file_size = self._all_sample_parts[path]
                mode = 0o444 | stat.S_IFREG
            return dict(
                st_mode=mode,
                st_nlink=1,
                st_size=file_size,
                st_ctime=self._mtime,
                st_mtime=self._mtime,
                st_atime=self._mtime,
                st_uid=self._uid,
                st_gid=self._gid,
            )

        def _path_parts(self, path: str) -> tuple[str, str]:
            """Split a path into a folder and a part name and check for errors.

            We only allow paths of the form "sample_key/part_name".
            The leading "/" must be stripped before.
            """
            pieces = path.split("/")
            if len(pieces) > 2:
                raise FuseOSError(ENOENT)
            part_name = pieces[1] if len(pieces) == 2 else ""
            return pieces[0], part_name

        def readdir(self, path: str, fh: int = 0) -> Generator[str, None, None]:
            """List the contents of a directory.

            This is called when using `ls` etc.
            Returns a generator of the entries in the directory as strings.
            """
            if path[0] != "/":
                raise FuseOSError(ENOENT)
            path = path[1:]
            if self._sample_folders:
                if path == "":
                    # Root: one virtual folder per sample.
                    yield "."
                    yield ".."
                    yield from self._all_samples.keys()
                else:
                    folder, part_name = self._path_parts(path)
                    if folder not in self._all_samples or part_name != "":
                        raise FuseOSError(ENOENT)
                    yield "."
                    yield ".."
                    single_tar_id = None
                    all_entries = list(
                        self._wds_filestore.list_sample_parts(folder, slow_mode=self._slow_mode)
                    )
                    for entry, size, tar_file_id in all_entries:
                        # Only list parts from the first tar file the sample occurs in.
                        if single_tar_id is None:
                            single_tar_id = tar_file_id
                        elif single_tar_id != tar_file_id:
                            break
                        yield entry
                    if folder in self._samples_with_multiple_tar_files:
                        yield MULTI_WARN
            else:
                if path != "":
                    # Only "/" is allowed for listing all sample parts
                    raise FuseOSError(ENOENT)
                yield "."
                yield ".."
                yield from self._all_sample_parts.keys()
                for key in self._samples_with_multiple_tar_files:
                    yield f"{key}.{MULTI_WARN}"

        def open(self, path: str, flags: int = 0) -> int:
            """Open a file for reading.

            Actually, we already read the file into memory when it is opened.
            The read operation just returns a slice of the memory buffer.
            Returns a dummy file descriptor.
            """
            if path[0] != "/":
                raise FuseOSError(ENOENT)
            path = path[1:]
            # read-only: deny write flags
            if flags & (os.O_WRONLY | os.O_RDWR | os.O_APPEND):
                raise FuseOSError(ENOENT)
            if self._sample_folders:
                folder, part_name = self._path_parts(path)
                if folder not in self._all_samples:
                    raise FuseOSError(ENOENT)
                full_name = f"{folder}.{part_name}"
                file_bytes, _ = self._wds_filestore[full_name]
            else:
                if path not in self._all_sample_parts:
                    raise FuseOSError(ENOENT)
                file_bytes, _ = self._wds_filestore[path]
            assert isinstance(file_bytes, bytes)
            self._open_files[path] = file_bytes
            # dummy file handle
            return 0

        def read(self, path: str, size: int, offset: int, fh: int = 0) -> bytes:
            """Read from an open file.

            This is called when using `read` etc.
            Returns the bytes object of a previously opened file.
            """
            if path[0] != "/":
                raise FuseOSError(EBADF)
            path = path[1:]
            if path not in self._open_files:
                raise FuseOSError(ENOENT)
            data = self._open_files[path]
            return data[offset : offset + size]

        def release(self, path: str, fh: int = 0) -> None:
            """Release an open file.

            This is called when the file is closed. We can now discard the memory buffer.
            """
            if path[0] != "/":
                raise FuseOSError(ENOENT)
            path = path[1:]
            if path not in self._open_files:
                raise FuseOSError(ENOENT)
            del self._open_files[path]

        def destroy(self, path: str) -> None:
            print("Closing energon mount.")
            if len(self._open_files) > 0:
                print(f"Number of still open files: {len(self._open_files)}")
            self._wds_filestore.close()

    EnergonFS = _EnergonFS
except (ImportError, OSError):
    # mfusepy or fuse not installed, so we can't mount the filesystem
    EnergonFS = None
@click.command(name="mount")
@click.argument(
    "path",
    type=click.Path(path_type=EPath),
)
@click.argument(
    "mountpoint",
    type=click.Path(path_type=Path),
)
@click.option(
    "--detach",
    "-d",
    is_flag=True,
    default=False,
    help="Run in background",
)
@click.option(
    "--sample-folders",
    "-s",
    is_flag=True,
    default=False,
    help="Present a virtual folder per sample. Otherwise a single folder with all sample parts.",
)
@click.option(
    "--allow-slow",
    "-w",
    is_flag=True,
    default=False,
    help="Allow slow mode. This is useful if the dataset was prepared with an older version of energon.",
)
def command(path: EPath, mountpoint: Path, detach: bool, sample_folders: bool, allow_slow: bool):
    """
    Mount an energon WebdatasetFileStore at the given mountpoint.

    The PATH should point to the folder with the dataset.
    """
    # Normalize the argument types (click may hand over plain strings).
    path = EPath(path)
    mountpoint = Path(mountpoint)
    print(f"Mounting {path} at {mountpoint}...")
    mountpoint.mkdir(parents=True, exist_ok=True)
    # EnergonFS is None when mfusepy/libfuse could not be imported (see the
    # except (ImportError, OSError) fallback above).
    if EnergonFS is not None:
        energon_fs = EnergonFS(path, sample_folders=sample_folders, allow_slow_mode=allow_slow)
    else:
        print("mfusepy or fuse not installed, so we can't mount the filesystem.")
        return
    print("Mount is ready to go.")
    try:
        # Blocks until unmounted unless --detach was given; mounted read-only.
        FUSE(
            operations=energon_fs,
            mountpoint=str(mountpoint),
            foreground=not detach,
            nothreads=True,
            ro=True,
        )
    except RuntimeError as e:
        # NOTE(review): error string "7" appears to be treated as a benign
        # shutdown condition here — anything else is re-raised. Confirm
        # against the mfusepy error reporting convention.
        if str(e) != "7":
            raise
Megatron-Energon/src/megatron/energon/tools/prepare.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
dataclasses
import
inspect
import
json
import
re
import
typing
from
types
import
FunctionType
from
typing
import
Any
,
List
,
Optional
,
Tuple
,
Type
import
click
import
yaml
from
megatron.energon.epathlib
import
EPath
from
megatron.energon.flavors
import
(
BaseWebdatasetFactory
,
CrudeJsonlDatasetFactory
,
CrudeWebdataset
,
Sample
,
)
from
megatron.energon.flavors.webdataset
import
MAIN_FOLDER_NAME
from
megatron.energon.flavors.webdataset.metadata
import
(
EnergonDatasetType
,
check_dataset_info_present
,
get_dataset_info
,
get_dataset_type
,
)
from
megatron.energon.metadataset.loader
import
prepare_metadataset
def type_str(tp: Type) -> str:
    """Return a human-readable string for a type annotation."""
    # Parameterized generics (e.g. List[int]) already have a readable repr.
    if typing.get_origin(tp) is not None:
        return repr(tp)
    if tp is ...:
        return "..."
    if isinstance(tp, FunctionType):
        return tp.__name__
    if isinstance(tp, type):
        # Builtins need no module prefix; everything else is fully qualified.
        module = tp.__module__
        qualname = tp.__qualname__
        return qualname if module == "builtins" else f"{module}.{qualname}"
    return repr(tp)
def sample_loader_template(fields: dict, parts: list):
    """Returns a template for a sample_loader.py file.

    Args:
        fields: The dataclass fields of the selected sample type.
        parts: The part types (file extensions) found in the dataset; used to
            generate the `part_filter` stub.

    Returns:
        The full text of the generated sample_loader.py template.
    """
    fields_str = ""
    for field in fields:
        # Internal bookkeeping fields are filled in by energon itself.
        if field.name in ("__key__", "__restore_key__", "__subflavors__"):
            continue
        line = f"""        {field.name}=raw["TODO"],  # expected type: {type_str(field.type)}"""
        if field.default is not dataclasses.MISSING:
            line += ", default: " + repr(field.default)
        fields_str += line + "\n"
    return "\n".join(
        [
            "# This file was automatically generated by `energon prepare`.",
            "# TODO: Edit it to return the proper fields",
            "# import torch",
            "",
            "def sample_loader(raw: dict) -> dict:",
            # Fix: a comma was missing after the line above, so this comment
            # line was glued onto the `def` line by implicit string
            # concatenation in the generated template.
            "    # Note: Images are already decoded to tensors",
            "    # TODO: Set the correct values for all (required) fields",
            "    return dict(",
            fields_str,
            "    )",
            "",
            "def part_filter(part: str) -> bool:",
            "    # TODO: Filter for parts required by the sample_loader",
            "    # E.g. if your dataset contains jpeg, txt and json, but you won't use json,",
            "    # remove it from the list, such that it is not decoded. If you need all, keep as is",
            f"    return part in {tuple(parts)!r}",
            "",
        ]
    )
def printify_json(data: Any) -> Any:
    """Recursively shorten JSON-like data to a human-readable size.

    Dicts keep all keys with shortened values, lists are truncated to three
    elements (plus an ellipsis marker), and strings are cut at 25 characters.
    All other values pass through unchanged.
    """
    if isinstance(data, dict):
        return {key: printify_json(val) for key, val in data.items()}
    if isinstance(data, list):
        shortened = [printify_json(item) for item in data[:3]]
        if len(data) > 3:
            # Mark that elements were dropped.
            shortened.append("...")
        return shortened
    if isinstance(data, str):
        if len(data) > 25:
            return data[:25] + "..."
        return data
    return data
@click.command(name="prepare")
@click.argument(
    "path",
    type=click.Path(path_type=EPath),
)
@click.option(
    "--progress/--no-progress",
    default=True,
)
@click.option(
    "--split-parts",
    help="Path pattern for parts in the form 'train:train/{000000-009999}.tar'. Will ignore ratio.",
    multiple=True,
    default=None,
)
@click.option(
    "--exclude",
    help="Exclude tar file paths (relative to root) matching that regex (at any position)",
)
@click.option(
    "--num-workers",
    type=int,
    default=16,
    help="Number of workers to use to index tar files",
)
@click.option(
    "--tar-index-only",
    help="Only (re)generate the tar-index",
    is_flag=True,
)
@click.option(
    "--shuffle-tars",
    help="If set, the tar files will be shuffled before splitting.",
    is_flag=True,
)
def command(
    path: EPath,
    progress: bool,
    split_parts: Optional[List[str]],
    exclude: str,
    num_workers: int,
    tar_index_only: bool,
    shuffle_tars: bool,
):
    """Prepare WebDataset for use with energon.

    The PATH should point to the folder with the dataset.

    This tool will add the required metadata yaml files to the dataset. See README.md for more
    details.
    """
    # Metadatasets and jsonl datasets have their own, simpler preparation paths.
    ds_type = get_dataset_type(path)
    if ds_type == EnergonDatasetType.METADATASET:
        print("Preparing metadataset...")
        prepare_metadataset(path)
        return
    elif ds_type == EnergonDatasetType.JSONL:
        print("Preparing jsonl dataset...")
        count = CrudeJsonlDatasetFactory.prepare_dataset(path)
        print(f"Done. Found {count} samples.")
        return
    assert path.is_dir(), f"Path {path} is not a known dataset type"
    if tar_index_only:
        # Reuse the shard list from the existing dataset info instead of globbing.
        info = get_dataset_info(path)
        all_tars = list(info["shard_counts"].keys())
    else:
        if check_dataset_info_present(path):
            if not click.confirm(
                "It seems the dataset had already been prepared. Do you want to continue?"
            ):
                return
        # Collect all shard archives; sorted for a deterministic ordering.
        all_tars = list(path.glob("**/*.tar")) + list(path.glob("**/*.tgz"))
        all_tars = [str(p.relative_to(path)) for p in sorted(all_tars)]
    if exclude:
        # Drop shards matching the user-supplied regex (searched anywhere in the path).
        all_tars = [p for p in all_tars if not re.search(exclude, p)]
    if len(all_tars) == 0:
        click.echo("Did not find any tar files. Exiting.")
        return
    if not tar_index_only:
        click.echo(f"Found {len(all_tars)} tar files in total. The first and last ones are:")
        click.echo(f"- {all_tars[0]}")
        click.echo(f"- {all_tars[-1]}")
        click.echo(
            "If you want to exclude some of them, cancel with ctrl+c and specify an exclude "
            "filter in the command line."
        )
    # Either explicit per-split path patterns, or an interactively entered ratio.
    split_parts_patterns: Optional[List[Tuple[str, str]]]
    if split_parts:
        split_parts_patterns = [tuple(x.split(":", 1)) for x in split_parts]
        split_parts_ratio = None
    elif not tar_index_only:
        split_input = click.prompt(
            'Please enter a desired train/val/test split like "0.5, 0.2, 0.3" or "8,1,1"', type=str
        )
        # Extract split floats
        try:
            split = [float(x.strip()) for x in split_input.split(",")]
            assert len(split) == 3
        except (ValueError, AssertionError):
            click.echo("Invalid split. Stopping.")
            return
        split_parts_ratio = [("train", split[0]), ("val", split[1]), ("test", split[2])]
        split_parts_patterns = None
    else:
        split_parts_ratio = None
        split_parts_patterns = None
    if progress:
        # Wrap the iterable in a click progress bar.
        def progress_fn(els, length=None):
            with click.progressbar(
                els,
                label="Indexing shards",
                show_pos=True,
                length=length,
            ) as bar:
                yield from bar
    else:
        # No-op passthrough when progress display is disabled.
        def progress_fn(els, length=None):
            return els
    found_types, duplicates = BaseWebdatasetFactory.prepare_dataset(
        path,
        all_tars,
        split_parts_ratio=split_parts_ratio,
        split_parts_patterns=split_parts_patterns,
        progress_fn=progress_fn,
        tar_index_only=tar_index_only,
        shuffle_seed=42 if shuffle_tars else None,
        workers=num_workers,
    )
    if duplicates:
        print(f"Examples of duplicates found: {duplicates}")
        print()
        print(
            "The dataset has duplicate keys. Best practice is to use unique keys. "
            "You won't be able to use this dataset for joining "
            "later on."
        )
    found_types = list(found_types)
    if tar_index_only:
        return
    if duplicates:
        if not click.confirm("Do you want to continue?"):
            return
    # Print json of first two samples
    for sample_idx, data in enumerate(
        BaseWebdatasetFactory.iter_dataset_content(path / all_tars[0], ("json",))
    ):
        print(f"Sample {sample_idx}, keys:")
        for key in data.keys():
            print(f" - {key}")
        if "json" in data:
            print(f"Json content of sample {sample_idx} of {all_tars[0]}:")
            print(json.dumps(printify_json(json.loads(data["json"])), indent=2))
        if sample_idx >= 1:
            break
    # Only offer the interactive field map when the part-type list is small.
    if len(found_types) > 10:
        click.echo(
            f"Found the following part types in the dataset: {', '.join(found_types[:10])} and more.."
        )
        allow_interactive_field_map = False
    else:
        click.echo(f"Found the following part types in the dataset: {', '.join(found_types)}")
        allow_interactive_field_map = True
    if click.confirm("Do you want to create a dataset.yaml interactively?", default=True):
        # Get a list of all classes in megatron.energon that are subclasses of WebdatasetBase
        import megatron.energon as data_import

        display_name_and_class = [
            (name, cls)
            for name, cls in inspect.getmembers(data_import)
            if isinstance(cls, type) and issubclass(cls, Sample)
        ]
        display_name_and_class.append(("Crude sample (plain dict for cooking)", CrudeWebdataset))
        # Print all classes and ask user to pick one
        click.echo("The following sample types are available:")
        for i, (name, cls) in enumerate(display_name_and_class):
            click.echo(f"{i}. {name}")
        while True:
            choice = click.prompt("Please enter a number to choose a class", type=int)
            try:
                _, cls = display_name_and_class[choice]
                break
            except IndexError:
                click.echo("Invalid choice. Please try again.")
                continue
        if cls == CrudeWebdataset:
            # Crude datasets carry raw dicts; no field map is needed.
            click.echo(
                "CrudeWebdataset does not need a field map. You will need to provide a `Cooker` for your dataset samples in your `TaskEncoder`."
            )
            click.echo(
                "Furthermore, you might want to add `subflavors` in your meta dataset specification."
            )
            dataset_definition = {
                "__module__": "megatron.energon",
                "__class__": cls.__name__,
            }
        else:
            click.echo("The sample type you selected:\n")
            click.echo(inspect.getsource(cls))
            dataset_definition = {
                "sample_type": {
                    "__module__": "megatron.energon",
                    "__class__": cls.__name__,
                },
            }
            if not allow_interactive_field_map:
                click.echo("You cannot set a field_map for this dataset. You will need a sample_loader.")
            if allow_interactive_field_map and click.confirm(
                "Do you want to set a simple field_map[Y] (or write your own sample_loader [n])?",
                default=True,
            ):
                click.echo("\nFor each field, please specify the corresponding name in the WebDataset.")
                click.echo(f"Available types in WebDataset: {', '.join(found_types)}")
                click.echo("Leave empty for skipping optional field")
                click.echo("You may also access json fields e.g. by setting the field to: json[field][field]")
                click.echo("You may also specify alternative fields e.g. by setting to: jpg,png")
                click.echo(f"Please enter the field_map for {cls.__name__}:")
                dataset_definition["field_map"] = field_map = {}
                for field in dataclasses.fields(cls):
                    # Internal fields are filled in by energon itself.
                    if field.name in (
                        "__key__",
                        "__restore_key__",
                        "__subflavors__",
                        "__sources__",
                    ):
                        continue
                    while True:
                        # Build a ", default: ..." hint for optional fields.
                        if (
                            field.default is dataclasses.MISSING
                            and field.default_factory is dataclasses.MISSING
                        ):
                            default = ""
                        elif field.default is not dataclasses.MISSING:
                            default = f", default: {field.default}"
                        elif field.default_factory is not dataclasses.MISSING:
                            default = f", default: {field.default_factory!r}"
                        else:
                            raise RuntimeError("This should never happen")
                        field_map[field.name] = input(
                            f"Please enter a webdataset field name for '{field.name}' "
                            f"({field.type}{default}): ",
                        )
                        # Empty input skips an optional field entirely.
                        if not field_map[field.name] and default:
                            del field_map[field.name]
                            break
                        # Validate that every alternative refers to a known part type.
                        type_ok = True
                        for option in field_map[field.name].split(","):
                            field_name = option.split("[", 1)[0]
                            if field_name not in found_types:
                                click.echo("That type doesn't exist in the WebDataset. Please try again.")
                                type_ok = False
                        if type_ok:
                            break
            else:
                # No field map: generate a sample_loader.py stub instead.
                if not allow_interactive_field_map:
                    template_part_types = set(["TODO"])
                else:
                    template_part_types = found_types
                if not (path / MAIN_FOLDER_NAME / "sample_loader.py").is_file() or click.confirm(
                    "Do you want to override the existing sample loader?"
                ):
                    with (path / MAIN_FOLDER_NAME / "sample_loader.py").open("w") as f:
                        f.write(
                            sample_loader_template(
                                dataclasses.fields(cls),
                                parts=template_part_types,
                            )
                        )
                    click.echo(
                        f"\nCreated {path / MAIN_FOLDER_NAME / 'sample_loader.py'}. Please edit it to "
                        f"return the proper values."
                    )
                    dataset_definition["sample_loader"] = "sample_loader.py:sample_loader"
                    dataset_definition["part_filter"] = "sample_loader.py:part_filter"
        # Write the dataset.yaml file
        with (path / MAIN_FOLDER_NAME / "dataset.yaml").open("w") as f:
            yaml.dump(dataset_definition, f, sort_keys=False)
    else:
        click.echo("You will have to add a dataset.yaml manually.")
    click.echo("Done")
# Allow running this tool directly as a script.
if __name__ == "__main__":
    command()
Megatron-Energon/src/megatron/energon/tools/preview.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
dataclasses
import
shutil
import
traceback
from
typing
import
Any
import
click
import
torch
from
megatron.energon
import
Sample
,
StandardWebdatasetFactory
from
megatron.energon.dataset_config
import
load_config
from
megatron.energon.epathlib
import
EPath
from
megatron.energon.flavors.webdataset
import
MAIN_FOLDER_NAME
from
megatron.energon.flavors.webdataset.empty_dataset_error
import
EmptyDatasetError
from
megatron.energon.loader
import
get_loader
from
megatron.energon.worker
import
WorkerConfig
def to_str(value: Any, indent: str) -> str:
    """Render *value* as a human-readable string for console preview.

    Tensors that look like images (CHW with 1/3/4 channels, or plain 2D) are
    down-scaled to the terminal size and rendered as ASCII art; 1D tensors are
    printed as a (truncated) value list. Named tuples, lists/tuples, bytes and
    scalars are formatted recursively with *indent* as the per-line prefix.
    """
    if isinstance(value, torch.Tensor):
        orig_value = value
        # Probably image? (channels-first with 1, 3 or 4 channels)
        if value.ndim == 3 and value.shape[0] in [1, 3, 4]:
            # Convert to grayscale
            if value.shape[0] == 1:
                value = value[0]
            elif value.shape[0] == 3:
                value = value.to(dtype=torch.float32).mean(dim=0)
            elif value.shape[0] == 4:
                # Drop the alpha channel before averaging.
                value = value[:3].to(dtype=torch.float32).mean(dim=0)
        if value.ndim == 2:
            # 2d image -> ascii print
            # Resize to fit terminal
            dst_w, dst_h = shutil.get_terminal_size((80, 24))
            orig_h, orig_w = value.shape
            dst_w -= len(indent)
            # Terminal cells are taller than wide; squash height accordingly.
            procrustes = 0.3
            # keep aspect ratio
            if orig_w / orig_h < dst_w / dst_h:
                dst_h = int(dst_w * procrustes * orig_h / orig_w)
            else:
                dst_w = int(dst_h / procrustes * orig_w / orig_h)
            value = torch.nn.functional.interpolate(
                value[None, None, :, :].to(dtype=torch.float32),
                size=(dst_h, dst_w),
                mode="area",
            )[0, 0]
            # Normalize to [0, 1]; guard against division by zero for constant
            # images (fix: previously produced NaN and crashed the glyph lookup).
            vmin = value.min()
            vmax = value.max()
            if vmax > vmin:
                value = (value - vmin) / (vmax - vmin)
            else:
                value = torch.zeros_like(value)
            # to ascii text (11 glyphs so that v == 1.0 maps to the last one)
            return (
                f"Tensor(shape={orig_value.shape}, dtype={orig_value.dtype}):\n{indent}"
                + f"\n{indent}".join(
                    "".join(" .:-=+*#%@@"[int(v * 10)] for v in row) for row in value.tolist()
                )
                + "\n"
            )
        elif value.ndim == 1:
            # 1d array... print it (truncated to 128 values)?
            return f"Tensor(shape={value.shape}, dtype={value.dtype}): {value[:128].tolist()}"
        else:
            return f"Tensor(shape={value.shape}, dtype={value.dtype})"
    elif isinstance(value, (str, int, float, bool, type(None))):
        return repr(value)
    elif isinstance(value, (list, tuple)):
        if hasattr(value, "_fields"):
            # Named tuple: render field=value pairs.
            # Fix: _fields contains plain strings (no .name attribute), and the
            # loop variable must not shadow the tuple being iterated.
            return (
                f"{type(value).__name__}(\n{indent}"
                + f",\n{indent}".join(
                    f"{field}={to_str(item, indent + '  ')}"
                    for item, field in zip(value, value._fields)
                )
                + f"\n{indent})"
            )
        # Fix: inspect the first element (not the container itself) to decide
        # whether the sequence holds scalars and can be rendered inline.
        if len(value) > 0 and isinstance(value[0], (str, int, float, bool)):
            return repr(type(value)(to_str(v, indent) for v in value))
        else:
            return (
                f"[\n{indent}"
                + f"\n{indent}".join(to_str(v, indent + "  ") for v in value)
                + f"\n{indent}]"
            )
    elif isinstance(value, bytes):
        return f"bytes(length={len(value)}, value={value[:128]!r})"
    return repr(value)
def pprint(idx: int, sample: Sample):
    """Pretty-print one sample with all its user-facing fields."""
    click.echo(f"Sample {idx}")
    # Internal bookkeeping fields are not shown.
    hidden = ("__restore_key__", "__subflavors__", "__sources__")
    for field in dataclasses.fields(sample):
        if field.name in hidden:
            continue
        rendered = to_str(getattr(sample, field.name), "")
        click.echo(f" - {field.name} ({field.type}): {rendered}")
@click.command(name="preview")
@click.argument(
    "path",
    type=click.Path(file_okay=False, dir_okay=True, path_type=EPath),
)
@click.option("--split-parts", default="train,val,test", help="The splits to verify", show_default=True)
@click.option("--dataset-config", default="dataset.yaml", help="Dataset config file name", show_default=True)
def command(path: EPath, split_parts: str, dataset_config: str):
    """Preview samples of a dataset on the console."""
    # Single-process, worker-less configuration: preview runs in the main process.
    worker_config = WorkerConfig(rank=0, world_size=1, num_workers=0)
    for split_part in split_parts.split(","):
        try:
            # Load the dataset factory for this split from its yaml config.
            dataset = load_config(
                EPath(path) / MAIN_FOLDER_NAME / dataset_config,
                default_kwargs=dict(
                    path=path,
                    split_part=split_part,
                    training=False,
                    worker_config=worker_config,
                ),
                default_type=StandardWebdatasetFactory,
            )
        except EmptyDatasetError:
            click.echo(f"Dataset {split_part} is empty. Skipping.")
            continue
        try:
            # Show one sample at a time; the user confirms to continue.
            for idx, sample in enumerate(get_loader(dataset.build())):
                pprint(idx, sample)
                click.confirm("Continue?", abort=True)
        except click.Abort:
            # User declined to continue; proceed with the next split.
            click.echo("Exiting Preview")
        except BaseException:
            # Log the original traceback before converting to a CLI error.
            traceback.print_exc()
            raise click.ClickException("Validation failed with errors, see logs for details.")
# Allow running this tool directly as a script.
if __name__ == "__main__":
    command()
Megatron-Energon/src/megatron/energon/transforms/__init__.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
megatron.energon.transforms.custom
import
CustomTransform
from
megatron.energon.transforms.mappers
import
(
CenterCropMapper
,
RandomCropMapper
,
RandomHorizontalFlipMapper
,
RandomPerspectiveMapper
,
RandomResizedCropMapper
,
RandomRotationMapper
,
RandomVerticalFlipMapper
,
ResizeMapper
,
TransformMapper
,
)
from
megatron.energon.transforms.merge
import
MergeTransform
# Public API of megatron.energon.transforms.
__all__ = [
    "TransformMapper",
    "ResizeMapper",
    "RandomResizedCropMapper",
    "CenterCropMapper",
    "RandomCropMapper",
    "RandomHorizontalFlipMapper",
    "RandomVerticalFlipMapper",
    "RandomRotationMapper",
    "RandomPerspectiveMapper",
    "MergeTransform",
    "CustomTransform",
]
Prev
1
…
4
5
6
7
8
9
10
11
12
…
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment