Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
Bw-bestperf
Qwen2.5-14B-Instruct_dcu-megatron
Commits
f356f546
Commit
f356f546
authored
Feb 04, 2026
by
maming
Browse files
Initial commit
parents
Pipeline
#3339
canceled with stages
Changes
346
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2470 additions
and
0 deletions
+2470
-0
Megatron-Energon/src/megatron/energon/wrappers/log_sample_dataset.py
...ergon/src/megatron/energon/wrappers/log_sample_dataset.py
+109
-0
Megatron-Energon/src/megatron/energon/wrappers/map_dataset.py
...tron-Energon/src/megatron/energon/wrappers/map_dataset.py
+253
-0
Megatron-Energon/src/megatron/energon/wrappers/mix_batch_dataset.py
...nergon/src/megatron/energon/wrappers/mix_batch_dataset.py
+133
-0
Megatron-Energon/src/megatron/energon/wrappers/packing_dataset.py
...-Energon/src/megatron/energon/wrappers/packing_dataset.py
+521
-0
Megatron-Energon/src/megatron/energon/wrappers/repeat_dataset.py
...n-Energon/src/megatron/energon/wrappers/repeat_dataset.py
+112
-0
Megatron-Energon/src/megatron/energon/wrappers/shuffle_buffer_dataset.py
...n/src/megatron/energon/wrappers/shuffle_buffer_dataset.py
+71
-0
Megatron-Energon/src/megatron/energon/wrappers/skip.py
Megatron-Energon/src/megatron/energon/wrappers/skip.py
+6
-0
Megatron-Energon/src/megatron/energon/wrappers/watchdog_dataset.py
...Energon/src/megatron/energon/wrappers/watchdog_dataset.py
+76
-0
Megatron-Energon/tests/__init__.py
Megatron-Energon/tests/__init__.py
+2
-0
Megatron-Energon/tests/data/sync_test.mkv
Megatron-Energon/tests/data/sync_test.mkv
+0
-0
Megatron-Energon/tests/data/sync_test.mp4
Megatron-Energon/tests/data/sync_test.mp4
+0
-0
Megatron-Energon/tests/data/test_audio.flac
Megatron-Energon/tests/data/test_audio.flac
+0
-0
Megatron-Energon/tests/data/test_audio.wav
Megatron-Energon/tests/data/test_audio.wav
+0
-0
Megatron-Energon/tests/epath_s3_emulator.py
Megatron-Energon/tests/epath_s3_emulator.py
+84
-0
Megatron-Energon/tests/s3_emulator/__init__.py
Megatron-Energon/tests/s3_emulator/__init__.py
+6
-0
Megatron-Energon/tests/s3_emulator/auth.py
Megatron-Energon/tests/s3_emulator/auth.py
+212
-0
Megatron-Energon/tests/s3_emulator/handler.py
Megatron-Energon/tests/s3_emulator/handler.py
+434
-0
Megatron-Energon/tests/s3_emulator/main.py
Megatron-Energon/tests/s3_emulator/main.py
+60
-0
Megatron-Energon/tests/s3_emulator/server.py
Megatron-Energon/tests/s3_emulator/server.py
+122
-0
Megatron-Energon/tests/s3_emulator/state.py
Megatron-Energon/tests/s3_emulator/state.py
+269
-0
No files found.
Too many changes to show.
To preserve performance only
346 of 346+
files are displayed.
Plain diff
Email patch
Megatron-Energon/src/megatron/energon/wrappers/log_sample_dataset.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
typing
import
Any
,
Callable
,
Dict
,
Generic
,
Iterator
,
List
,
Literal
,
Optional
,
TypeVar
from
megatron.energon.flavors.base_dataset
import
SavableDataset
from
megatron.energon.worker
import
WorkerConfig
from
megatron.energon.wrappers.base
import
BaseWrapperDataset
T_sample
=
TypeVar
(
"T_sample"
)
def default_get_keys(batch: Any) -> Optional[List[str]]:
    """Default get_keys, which has some heuristics to find the sample keys.

    Looks for a list-of-str under the attributes ``__key__``/``__keys__`` (in that
    order), then under the dict keys ``"__key__"``/``"__keys__"``/``"keys"``.

    Args:
        batch: The batch (or list of batches, in which case the first one is inspected)
            to extract the sample keys from.

    Returns:
        The list of sample keys, or None if no keys could be found.
    """
    if isinstance(batch, list):
        # A list of batches: inspect the first one. An empty list has no keys.
        if not batch:
            return None
        batch = batch[0]
    # Attribute-based lookup: the value must be a list of strings.
    for attr in ("__key__", "__keys__"):
        value = getattr(batch, attr, None)
        if isinstance(value, list) and all(isinstance(k, str) for k in value):
            return value
    # Dict-based lookup: mirrors the original heuristic, which only verifies that
    # every element of the value is a string (no explicit list check).
    if isinstance(batch, dict):
        for key in ("__key__", "__keys__", "keys"):
            if key in batch and all(isinstance(k, str) for k in batch[key]):
                return batch[key]
    return None
class LogSampleDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]):
    """This dataset logs every yielded sample to the debug logs."""

    # Extracts the sample keys to include in each log entry.
    get_keys_fn: Callable[[T_sample], Optional[List[str]]]
    # Whether this wraps a training or validation dataset (recorded in each entry).
    mode: Literal["train", "val"]
    # Running count of yielded samples; saved/restored across checkpoints.
    _step: int

    _savable_fields = ("_step",)

    def __init__(
        self,
        dataset: SavableDataset[T_sample],
        mode: Literal["train", "val"],
        worker_config: WorkerConfig,
        get_keys_fn: Callable[[T_sample], Optional[List[str]]] = default_get_keys,
    ):
        """Construct the log sample dataset, which logs every yielded sample to the debug logs.

        Args:
            dataset: The input dataset to wrap
        """
        super().__init__(dataset, worker_config=worker_config)
        self.get_keys_fn = get_keys_fn
        self.mode = mode
        self.reset_state_own()

    def reset_state_own(self) -> None:
        """Reset the own (savable) state: start counting samples from zero."""
        self._step = 0

    def len_worker(self, worker_idx: int | None = None) -> int:
        """Delegate the per-worker length to the wrapped dataset."""
        return self.dataset.len_worker(worker_idx)

    def _log(self, sample: T_sample) -> None:
        """Write one debug log entry for *sample*, if logging is enabled."""
        wc = self.worker_config
        if not wc.should_log(level=1):
            return
        entry: Dict[str, Any] = {
            "t": "yield_batch",
            "r": wc.rank,
            "w": wc.global_worker_id(),
            "m": self.mode,
            "idx": self._step,
        }
        sample_keys = self.get_keys_fn(sample)
        if sample_keys is not None:
            entry["keys"] = sample_keys
        wc.worker_log(entry)

    def __iter__(self) -> Iterator[T_sample]:
        """Yield samples from the wrapped dataset unchanged, logging each one."""
        for item in self.dataset:
            self._log(item)
            self._step += 1
            yield item

    def config(self) -> Dict[str, Any]:
        # Transparent logger, it won't change the samples
        return self.dataset.config()

    def __str__(self):
        return f"LogSampleDataset(mode={self.mode}, get_keys_fn={self.get_keys_fn}, dataset={self.dataset})"
Megatron-Energon/src/megatron/energon/wrappers/map_dataset.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
inspect
from
typing
import
(
Any
,
Callable
,
Dict
,
Generator
,
Generic
,
Iterator
,
Optional
,
Sequence
,
Tuple
,
TypeVar
,
Union
,
)
from
megatron.energon.errors
import
SYSTEM_EXCEPTIONS
,
FatalSampleError
from
megatron.energon.flavors.base_dataset
import
SavableDataset
,
add_sample_restore_key
from
megatron.energon.source_info
import
SourceInfo
from
megatron.energon.worker
import
WorkerConfig
from
megatron.energon.wrappers._log_exception
import
log_exception
from
megatron.energon.wrappers.base
import
BaseWrapperDataset
,
SampleIndex
,
get_sample_restore_key
from
megatron.energon.wrappers.skip
import
SkipSample
T_sample
=
TypeVar
(
"T_sample"
)
T_sample_out
=
TypeVar
(
"T_sample_out"
)
class MapDataset(BaseWrapperDataset[T_sample, T_sample_out], Generic[T_sample, T_sample_out]):
    """This dataset wrapper applies a custom function to transform each sample."""

    # The user transform: returns either one mapped sample, or a generator of mapped samples.
    map_fn: Callable[[T_sample], Union[T_sample_out, Generator[T_sample_out, None, None]]]
    # Invoked when map_fn raises a non-system exception; default logs and skips the sample.
    error_handler: Callable[[Exception, T_sample, Sequence[SourceInfo]], None]
    # True if map_fn is deterministic and stateless, enabling sample restore through it.
    stateless_map_fn: bool
    # Static config dict (or factory returning one), reported by config().
    map_fn_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]]
    # Savable per-worker index of the currently processed sample.
    _sample_index: SampleIndex
    # Restore key of the input sample whose generator output was being yielded when saved.
    _generator_sample_key: Optional[Any]
    # How many outputs of that generator were already yielded (resume point after restore).
    _generator_offset: Optional[int]
    # Consecutive map_fn failures; reset to 0 on every successful yield. Not a savable field.
    _last_map_failures: int = 0

    _savable_fields = (
        "_sample_index",
        "_generator_sample_key",
        "_generator_offset",
    )

    def __init__(
        self,
        dataset: SavableDataset[T_sample],
        map_fn: Callable[[T_sample], Union[T_sample_out, Generator[T_sample_out, None, None]]],
        *,
        error_handler: Callable[[Exception, T_sample, Sequence[SourceInfo]], None] = log_exception,
        stateless_map_fn: bool = False,
        map_fn_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None,
        failure_tolerance: int = 100,
        worker_config: WorkerConfig,
    ):
        """Construct a MapDataset.

        If this should be savable, the map_fn must only return a sample, or a generator yielding
        0 or 1 sample per input sample. Otherwise this will be broken (see `IterMapDataset`).

        Args:
            dataset: The input dataset to wrap
            map_fn: The function to apply to each sample. May raise
                :exc:`megatron.energon.SkipSample` to skip a sample. Alternatively, may return a
                generator to yield multiple or no samples.
            error_handler: Handler for errors. Defaults to logging and ignoring the exception.
            stateless_map_fn: If true, the map_fn is deterministic and stateless
                (thus key for random access can propagate to inner dataset). Defaults to False.
            map_fn_config: Configuration for the map_fn function. If callable, it should return the
                configuration. Defaults to None.
            failure_tolerance: The number of consecutive failures after which the dataset is considered broken. Set to 0 to disable.
            worker_config: Worker configuration.
        """
        super().__init__(dataset, worker_config=worker_config)
        self.map_fn = map_fn
        self.error_handler = error_handler
        self.stateless_map_fn = stateless_map_fn
        self.map_fn_config = map_fn_config
        self.failure_tolerance = failure_tolerance
        self.reset_state_own()

    def reset_state_own(self) -> None:
        """Reset the own (savable) state: fresh sample index, no pending generator."""
        self._sample_index = SampleIndex(self.worker_config, src=self)
        self._generator_sample_key = None
        self._generator_offset = None

    def len_worker(self, worker_idx: int | None = None) -> int:
        """Delegate the per-worker length to the wrapped dataset (mapping is 1:1 by assumption)."""
        return self.dataset.len_worker(worker_idx)

    def __iter__(self) -> Iterator[T_sample_out]:
        """Iterate the wrapped dataset, applying map_fn to every sample.

        If state was restored mid-generator, the pending input sample is first re-mapped and
        the already-yielded generator outputs are skipped before resuming normal iteration.
        """
        if self._generator_sample_key is not None:
            # Resume an interrupted generator: re-fetch the input sample by its restore key.
            assert self._generator_offset is not None
            sample = self.dataset.restore_sample(self._generator_sample_key)
            # Do not increment the sample index, use previous index
            with self._sample_index.ctx(self._sample_index.current_idx) as sample_idx:
                mapped_sample = self.map_fn(sample)
            assert isinstance(mapped_sample, Generator)
            assert inspect.isgeneratorfunction(self.map_fn), (
                f"Generator in {self.map_fn} but not marked as such."
            )
            target_offset = self._generator_offset
            self._generator_offset = 0
            for idx, (sample_idx, inner_sample) in enumerate(
                self._sample_index.iter_ctx(mapped_sample, sample_idx)
            ):
                # Skip other samples
                if idx >= target_offset:
                    self._generator_offset = idx + 1
                    yield add_sample_restore_key(
                        inner_sample,
                        sample_idx,
                        idx,
                        src=self,
                    )
            self._generator_sample_key = None
            self._generator_offset = None
        for sample in self.dataset:
            restore_key = get_sample_restore_key(sample)
            try:
                with self._sample_index.ctx() as sample_idx:
                    mapped_sample = self.map_fn(sample)
                if isinstance(mapped_sample, Generator):
                    assert inspect.isgeneratorfunction(self.map_fn), (
                        f"Generator in {self.map_fn} but not marked as such."
                    )
                    # Remember which input sample the generator belongs to, so iteration can
                    # be resumed from the saved offset after a save/restore cycle.
                    self._generator_sample_key = restore_key
                    self._generator_offset = 0
                    # In case of a generator, additionally store the index of the yielded samples
                    # per input sample
                    for idx, (sample_idx, inner_sample) in enumerate(
                        self._sample_index.iter_ctx(mapped_sample, sample_idx)
                    ):
                        self._generator_offset = idx + 1
                        self._last_map_failures = 0
                        yield add_sample_restore_key(
                            inner_sample,
                            sample_idx,
                            idx,
                            src=self,
                        )
                    self._generator_sample_key = None
                    self._generator_offset = None
                else:
                    self._last_map_failures = 0
                    yield add_sample_restore_key(
                        mapped_sample,
                        sample_idx,
                        src=self,
                    )
            except GeneratorExit:
                raise
            except SkipSample:
                pass
            except SYSTEM_EXCEPTIONS:
                raise FatalSampleError.from_sample(sample)
            except Exception as e:
                # Non-fatal failure: delegate to the error handler, but give up after
                # failure_tolerance consecutive failures (0 disables the limit).
                self.error_handler(e, sample)
                self._last_map_failures += 1
                print(
                    f"MapDataset {self.map_fn} failed {self._last_map_failures}/{self.failure_tolerance} times in a row."
                )
                if self.failure_tolerance > 0 and self._last_map_failures >= self.failure_tolerance:
                    raise FatalSampleError.from_sample(
                        sample,
                        f"MapDataset {self.map_fn} failed {self._last_map_failures} times in a row. Likely your code or dataset are broken.",
                    )

    def can_restore_sample(self) -> bool:
        """Restoring requires the inner dataset to support it AND map_fn to be stateless."""
        return super().can_restore_sample() and self.stateless_map_fn

    def assert_can_restore(self) -> None:
        """Raise (via assert) if samples cannot be restored through this wrapper."""
        assert self.stateless_map_fn, (
            f"MapDataset can only restore samples if map_fn {self.map_fn} is stateless."
        )
        super().assert_can_restore()

    def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample_out:
        """Restore a previously yielded sample from its restore key.

        For generator map_fns, the key additionally carries the local output index, and the
        generator is re-run until that index is reached.
        """
        self.assert_can_restore()
        if inspect.isgeneratorfunction(self.map_fn):
            # NOTE: `id` shadows the builtin here; it holds the wrapper type name tag.
            id, sample_idx, local_idx = restore_key[:3]
            assert id == type(self).__name__
            restore_key = restore_key[3:]
            assert isinstance(local_idx, int)
        else:
            id, sample_idx = restore_key[:2]
            assert id == type(self).__name__
            restore_key = restore_key[2:]
        inner_sample = self.dataset.restore_sample(restore_key)
        try:
            with self._sample_index.ctx(sample_idx):
                mapped_sample = self.map_fn(inner_sample)
            if isinstance(mapped_sample, Generator):
                assert inspect.isgeneratorfunction(self.map_fn), (
                    f"Generator in {self.map_fn} but not marked as such."
                )
                # Re-run the generator until the output at local_idx is produced.
                for idx, (sample_idx, res_sample) in enumerate(
                    self._sample_index.iter_ctx(mapped_sample, sample_idx)
                ):
                    self._last_map_failures = 0
                    if idx == local_idx:
                        return add_sample_restore_key(res_sample, sample_idx, local_idx, src=self)
                assert False, (
                    "Generator did not yield enough samples, but is marked stateless/deterministic."
                )
            else:
                self._last_map_failures = 0
                return add_sample_restore_key(mapped_sample, sample_idx, src=self)
        except GeneratorExit:
            # During restore, any early exit/skip is fatal: the sample must be reproducible.
            raise FatalSampleError.from_sample(
                inner_sample,
                f"MapDataset {self.map_fn} generator exited while trying to restore a sample.",
            )
        except SkipSample:
            raise FatalSampleError.from_sample(
                inner_sample,
                f"MapDataset {self.map_fn} skipped while trying to restore a sample."
            )
        except SYSTEM_EXCEPTIONS:
            raise FatalSampleError.from_sample(inner_sample)
        except Exception as e:
            self.error_handler(e, inner_sample)
            self._last_map_failures += 1
            if self.failure_tolerance > 0 and self._last_map_failures >= self.failure_tolerance:
                raise FatalSampleError.from_sample(
                    inner_sample,
                    f"MapDataset {self.map_fn} failed {self._last_map_failures} times in a row. Likely your code or dataset are broken.",
                )

    def config(self) -> Dict[str, Any]:
        """Return the config of this wrapper, including the map_fn and its optional config."""
        return {
            "type": type(self).__qualname__,
            "dataset": self.dataset.config(),
            "map_fn": self._function_config(self.map_fn),
            **(
                {
                    "map_fn_config": (
                        self.map_fn_config() if callable(self.map_fn_config) else self.map_fn_config
                    )
                }
                if self.map_fn_config
                else {}
            ),
            "map_fn_stateless": self.stateless_map_fn,
        }

    def __str__(self):
        return f"MapDataset(map_fn={self.map_fn}, dataset={self.dataset})"
Megatron-Energon/src/megatron/energon/wrappers/mix_batch_dataset.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
dataclasses
from
typing
import
Any
,
Callable
,
Dict
,
Generator
,
Generic
,
Iterator
,
List
,
Tuple
,
TypeVar
,
Union
import
torch
from
megatron.energon.flavors.base_dataset
import
SavableDataset
from
megatron.energon.worker
import
WorkerConfig
from
megatron.energon.wrappers.base
import
BaseWrapperDataset
from
megatron.energon.wrappers.batch_dataset
import
BatchDataset
from
megatron.energon.wrappers.blend_dataset
import
BlendDataset
T_batch_in
=
TypeVar
(
"T_batch_in"
)
T_batch
=
TypeVar
(
"T_batch"
)
def generic_concat(batch: List[Any]) -> Any:
    """Based on the types/shapes of the batch: Will either pad and stack, or return as list.
    Recurses structures (dict, dataclass, namedtuple) and applies the same logic to each field.

    Args:
        batch: The list of samples (all of the same type) to combine.

    Returns:
        The combined batch: a padded/stacked tensor for tensors, the recursed structure for
        dicts/dataclasses/namedtuples, or the input list unchanged for anything else.
    """
    if not batch:
        # Nothing to inspect in an empty batch; return it unchanged.
        return batch
    first = batch[0]
    if isinstance(first, torch.Tensor):
        return concat_pad(batch)
    elif isinstance(first, dict):
        return {key: generic_concat([sample[key] for sample in batch]) for key in first.keys()}
    elif dataclasses.is_dataclass(first):
        return type(first)(
            **{
                field.name: generic_concat([getattr(sample, field.name) for sample in batch])
                for field in dataclasses.fields(first)
            }
        )
    elif isinstance(first, tuple) and hasattr(first, "_fields"):
        # NamedTuple
        return type(first)(
            **{
                field: generic_concat([getattr(sample, field) for sample in batch])
                for field in first._fields
            }
        )
    else:
        return batch
def concat_pad(batch: List[Any]) -> Any:
    """Concatenate a batch of arbitrary-sized tensors along dim 0, zero-padding every
    trailing dimension up to the largest extent found in the batch."""
    first = batch[0]
    # Output has one row per input row; trailing dims are the per-dimension maxima.
    n_rows = sum(t.shape[0] for t in batch)
    trailing = [max(t.shape[dim] for t in batch) for dim in range(1, first.ndim)]
    # new_zeros inherits dtype/device from the first tensor; padding is implicit zeros.
    out = first.new_zeros((n_rows, *trailing))
    row = 0
    for t in batch:
        index = (slice(row, row + t.shape[0]),) + tuple(slice(0, extent) for extent in t.shape[1:])
        out[index] = t
        row += t.shape[0]
    return out
def homogeneous_concat_mix(samples: List[T_batch_in]) -> T_batch:
    """
    Mixes a list of batches into a single batch. The default implementation is to concat the
    batches if they are all of the same type, otherwise return a list of batches.

    Args:
        samples: The samples to mix.

    Returns:
        The mixed batch.

    Raises:
        AssertionError: If the samples are not all of the same type.
    """
    first_type = type(samples[0])
    # The concat below only makes sense for a homogeneous batch; fail loudly otherwise.
    assert all(first_type is type(sample) for sample in samples), (
        f"homogeneous_concat_mix requires all samples to have the same type, got {first_type} and others."
    )
    # All the same type -> concat batches
    return generic_concat(samples)
class MixBatchDataset(BaseWrapperDataset[T_batch_in, T_batch], Generic[T_batch_in, T_batch]):
    """
    This dataset wrapper blends multiple iterable datasets together given a weight.
    The datasets may be infinite. This dataset is always infinite.

    Effectively combines :class:`megatron.energon.BlendDataset` and :class:`megatron.energon.BatchDataset`.
    """

    def __init__(
        self,
        *dataset_weights: Tuple[SavableDataset[T_batch_in], float],
        batch_size: int,
        batch_mix_fn: Callable[
            [List[T_batch_in]], Union[T_batch, Generator[T_batch, None, None]]
        ] = lambda x: x,
        worker_config: WorkerConfig,
    ):
        """Construct a MixBatchDataset.

        Args:
            dataset_weights: Each argument should be a tuple of (dataset, weight) with a weight
                between 0 and 1. The output samples are sampled from the input datasets with the
                given probabilities. The datasets should have a batch size of 1, otherwise the
                whole batches will be sampled.
            batch_size: The batch size to output.
            batch_mix_fn: A function that takes a list of samples from the input datasets and
                returns a batch sample. The default implementation returns a list of batches.
                For homogeneous datasets, it is recommended to use the
                :func:`megatron.energon.homogeneous_concat_mix` which concatenates the batches. May raise
                :exc:`megatron.energon.SkipSample` to skip a sample. May also return a generator, which
                will be iterated over to produce batches.
            worker_config: Configuration for the workers.
        """
        # Blend first (weighted sampling across datasets), then batch the blended stream.
        super().__init__(
            BatchDataset(
                BlendDataset(*dataset_weights, worker_config=worker_config),
                batch_size=batch_size,
                batcher=batch_mix_fn,
                worker_config=worker_config,
            ),
            worker_config=worker_config,
        )

    def reset_state_own(self) -> None:
        """No own state; all state lives in the wrapped blend/batch datasets."""
        return

    def len_worker(self, worker_idx: int | None = None) -> int:
        """Delegate the per-worker length to the wrapped (batched) dataset."""
        return self.dataset.len_worker(worker_idx)

    def __iter__(self) -> Iterator[T_batch]:
        """Yield mixed batches from the wrapped blend+batch pipeline."""
        yield from self.dataset

    def config(self) -> Dict[str, Any]:
        """Return the config of this wrapper and the wrapped pipeline."""
        return {
            "type": type(self).__qualname__,
            "dataset": self.dataset.config(),
        }

    def __str__(self):
        return f"MixBatchDataset(dataset={self.dataset})"
Megatron-Energon/src/megatron/energon/wrappers/packing_dataset.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
contextlib
import
inspect
from
typing
import
(
Any
,
Callable
,
Dict
,
Generator
,
Generic
,
Iterator
,
List
,
Optional
,
Sequence
,
TypeVar
,
Union
,
)
from
megatron.energon.errors
import
SYSTEM_EXCEPTIONS
,
FatalSampleError
from
megatron.energon.flavors.base_dataset
import
(
SavableDataset
,
add_sample_restore_key
,
set_sample_restore_key
,
)
from
megatron.energon.source_info
import
SourceInfo
from
megatron.energon.worker
import
WorkerConfig
from
megatron.energon.wrappers._log_exception
import
log_exception
from
megatron.energon.wrappers.base
import
BaseWrapperDataset
,
SampleIndex
,
get_sample_restore_key
from
megatron.energon.wrappers.buffer
import
SavableSampleBuffer
from
megatron.energon.wrappers.skip
import
SkipSample
T_sample
=
TypeVar
(
"T_sample"
)
T_encoded_sample
=
TypeVar
(
"T_encoded_sample"
)
T_batch_sample
=
TypeVar
(
"T_batch_sample"
)
class
PackingDataset
(
BaseWrapperDataset
[
T_sample
,
T_batch_sample
],
Generic
[
T_sample
,
T_encoded_sample
,
T_batch_sample
],
):
"""This dataset wrapper transforms samples of a dataset into chunks/packs of samples, which are
then combined into a batch."""
buffer_size
:
int
pre_packer
:
Callable
[[
List
[
T_sample
]],
List
[
List
[
T_sample
]]]
sample_encoder
:
Optional
[
Callable
[[
T_sample
],
T_encoded_sample
]]
sample_encoder_stateless
:
bool
final_packer
:
Callable
[[
List
[
T_encoded_sample
]],
T_batch_sample
]
final_packer_stateless
:
bool
packer_config
:
Optional
[
Union
[
Dict
[
str
,
Any
],
Callable
[[],
Dict
[
str
,
Any
]]]]
error_handler
:
Callable
[[
Exception
,
List
[
T_sample
],
Sequence
[
SourceInfo
]],
None
]
#: The buffer for collecting the samples that shall be packed.
_reading_buffer
:
SavableSampleBuffer
#: Contains the pre-selected samples to be packed.
#: The full buffer will be passed to the pre_packer.
_pre_packing_buffer
:
SavableSampleBuffer
#: Lengths of the selected groups of samples to be packed together.
#: The samples are stored sequentially in the pre_packing_buffer because
#: SavableSampleBuffer doesn't support nesting. But to keep the groups
#: separate, we need to store the lengths of the groups here.
_pre_packing_lengths
:
List
[
int
]
#: Sample index for the pre_packer
_pre_packing_sample_index
:
SampleIndex
#: Sample index for the sample_encoder
_sample_encoder_sample_index
:
SampleIndex
#: Sample index for the final_packer
_final_packing_sample_index
:
SampleIndex
# Local state: Tracking last failures for each component, to raise a fatal error after a certain number of failures.
_last_pre_pack_failures
:
int
=
0
_last_final_pack_failures
:
int
=
0
_last_sample_encoder_failures
:
int
=
0
_savable_fields
=
(
"_reading_buffer"
,
"_pre_packing_buffer"
,
"_pre_packing_lengths"
,
"_pre_packing_sample_index"
,
"_sample_encoder_sample_index"
,
"_final_packing_sample_index"
,
)
def
__init__
(
self
,
dataset
:
SavableDataset
[
T_sample
],
buffer_size
:
int
,
pre_packer
:
Callable
[[
List
[
T_sample
]],
List
[
List
[
T_sample
]]],
final_packer
:
Callable
[[
List
[
T_encoded_sample
]],
T_batch_sample
],
*
,
final_packer_stateless
:
bool
=
False
,
sample_encoder
:
Optional
[
Callable
[[
T_sample
],
T_encoded_sample
]]
=
None
,
sample_encoder_stateless
:
bool
=
False
,
packer_config
:
Optional
[
Union
[
Dict
[
str
,
Any
],
Callable
[[],
Dict
[
str
,
Any
]]]]
=
None
,
error_handler
:
Callable
[
[
Exception
,
List
[
T_sample
],
Sequence
[
SourceInfo
]],
None
]
=
log_exception
,
pre_packer_failure_tolerance
:
int
=
100
,
final_packer_failure_tolerance
:
int
=
100
,
sample_encoder_failure_tolerance
:
int
=
100
,
worker_config
:
WorkerConfig
,
):
"""Construct a PackingDataset which is used for sequence packing.
Using a pre_packer and final_packer, it buffers the incoming samples, groups
them together based on the logic provided by the pre_packer, and then (using
the final_packer) combines each group into a packed single sample also called
a "pack" or a "packed sequence".
Args:
dataset: The input dataset to wrap
buffer_size: The desired size of the input buffer for pre packing. Last buffer of a dataset may be smaller.
pre_packer: Function which selects samples from the buffer to be packed together.
May raise :exc:`megatron.energon.SkipSample` to skip a buffer.
final_packer: Function which combines the selected samples into a single sample.
final_packer_stateless: If True, the final_packer is stateless, thus samples can be
stored/restored.
sample_encoder: Function which encodes the samples.
sample_encoder_stateless: If True, the sample_encoder is stateless, thus samples can be
stored/restored.
packer_config: Configuration for the (pre|final)_packer functions. If callable, it should return the
configuration. Defaults to None.
error_handler: Function which handles exceptions raised by the batcher. The default
implementation logs the exception.
pre_packer_failure_tolerance: Maximum number of pre-packer failures before raising an error. Set to 0 to disable.
final_packer_failure_tolerance: Maximum number of final-packer failures before raising an error. Set to 0 to disable.
sample_encoder_failure_tolerance: Maximum number of sample-encoder failures before raising an error. Set to 0 to disable.
worker_config: Configuration for the workers.
"""
super
().
__init__
(
dataset
,
worker_config
=
worker_config
)
assert
buffer_size
>
0
,
"Packing buffer size must be greater than 0."
self
.
buffer_size
=
buffer_size
self
.
pre_packer
=
pre_packer
self
.
final_packer
=
final_packer
self
.
final_packer_stateless
=
final_packer_stateless
self
.
sample_encoder
=
sample_encoder
self
.
sample_encoder_stateless
=
True
if
sample_encoder
is
None
else
sample_encoder_stateless
self
.
packer_config
=
packer_config
self
.
error_handler
=
error_handler
self
.
pre_packer_failure_tolerance
=
pre_packer_failure_tolerance
self
.
final_packer_failure_tolerance
=
final_packer_failure_tolerance
self
.
sample_encoder_failure_tolerance
=
sample_encoder_failure_tolerance
self
.
reset_state_own
()
def
reset_state_own
(
self
)
->
None
:
self
.
_reading_buffer
=
SavableSampleBuffer
(
self
.
dataset
,
worker_config
=
self
.
worker_config
)
self
.
_pre_packing_buffer
=
SavableSampleBuffer
(
self
.
dataset
,
worker_config
=
self
.
worker_config
)
self
.
_pre_packing_lengths
=
[]
self
.
_pre_packing_sample_index
=
SampleIndex
(
self
.
worker_config
,
src
=
self
)
self
.
_final_packing_sample_index
=
SampleIndex
(
self
.
worker_config
,
src
=
self
)
self
.
_sample_encoder_sample_index
=
SampleIndex
(
self
.
worker_config
,
src
=
self
)
def
len_worker
(
self
,
worker_idx
:
int
|
None
=
None
)
->
int
:
# The real length is unknown, since it depends on the packing function.
# We approximate it by the length of the source dataset.
return
self
.
dataset
.
len_worker
(
worker_idx
)
def
_fill_reading_buffer
(
self
,
source_iter
:
Iterator
,
log_progress
:
bool
=
False
)
->
bool
:
"""
Fill the reading buffer with samples from the dataset source iterator.
Args:
source_iter: Iterator of samples from the dataset.
log_progress: If True, log the progress of the filling.
Returns:
True if samples are successfully read into the buffer, False if no more data.
"""
if
log_progress
:
import
tqdm
pbar_ctx
=
pbar
=
tqdm
.
tqdm
(
total
=
self
.
buffer_size
,
desc
=
"Filling reading buffer"
)
else
:
pbar_ctx
=
contextlib
.
nullcontext
()
pbar
=
None
with
pbar_ctx
:
while
(
self
.
_reading_buffer
.
len_worker
()
+
self
.
_pre_packing_buffer
.
len_worker
()
<
self
.
buffer_size
):
try
:
sample
=
next
(
source_iter
)
self
.
_reading_buffer
.
append
(
sample
)
if
pbar
is
not
None
:
pbar
.
update
(
1
)
except
StopIteration
:
return
False
return
True
def
__iter__
(
self
)
->
Iterator
[
T_batch_sample
]:
pre_packing_lengths
=
self
.
_pre_packing_lengths
# The source dataset
src_iter
=
iter
(
self
.
dataset
)
self
.
_pre_packing_buffer
.
worker_start
()
self
.
_reading_buffer
.
worker_start
()
is_initial_pack
=
True
def
encode_pack_samples
(
pack
:
List
[
T_sample
])
->
List
[
T_encoded_sample
]:
"""Encode the samples in the pack using the sample encoder."""
# Apply the sample encoder to the pack
if
self
.
sample_encoder
is
None
:
return
pack
encoded_pack
=
[]
for
sample
in
pack
:
try
:
with
self
.
_sample_encoder_sample_index
.
ctx
()
as
encode_idx
:
encoded_sample
=
self
.
sample_encoder
(
sample
)
assert
not
isinstance
(
encoded_sample
,
Generator
),
"Generator not supported"
encoded_pack
.
append
(
add_sample_restore_key
(
encoded_sample
,
encode_idx
,
src
=
self
,
)
)
self
.
_last_sample_encoder_failures
=
0
except
SkipSample
:
pass
except
SYSTEM_EXCEPTIONS
:
raise
FatalSampleError
.
from_sample
(
pack
)
except
Exception
as
e
:
self
.
error_handler
(
e
,
[
sample
])
self
.
_last_sample_encoder_failures
+=
1
if
(
self
.
sample_encoder_failure_tolerance
>
0
and
self
.
_last_sample_encoder_failures
>=
self
.
sample_encoder_failure_tolerance
):
raise
FatalSampleError
.
from_sample
(
pack
,
f
"Sample encoder
{
self
.
sample_encoder
}
failed
{
self
.
_last_sample_encoder_failures
}
times. Likely your code or dataset are broken."
,
)
return
encoded_pack
def
next_pre_pack
():
"""Take the samples from the reading buffer and select groups of samples to be packed
together."""
assert
self
.
_pre_packing_buffer
.
len_worker
()
==
0
if
self
.
_reading_buffer
.
len_worker
()
>
0
:
# Take all samples from the reading buffer and pre_pack them
samples
=
self
.
_reading_buffer
.
buffer
.
copy
()
# Clear buffer and pre_packing_lengths
self
.
_reading_buffer
.
clear
()
pre_packing_lengths
.
clear
()
# Now pre pack the samples
try
:
with
self
.
_pre_packing_sample_index
.
ctx
():
pre_packs
=
self
.
pre_packer
(
samples
)
self
.
_last_pre_pack_failures
=
0
except
SkipSample
:
pre_packs
=
[]
except
SYSTEM_EXCEPTIONS
:
raise
FatalSampleError
.
from_sample
(
samples
)
except
Exception
as
e
:
self
.
error_handler
(
e
,
samples
)
pre_packs
=
[]
self
.
_last_pre_pack_failures
+=
1
if
(
self
.
pre_packer_failure_tolerance
>
0
and
self
.
_last_pre_pack_failures
>=
self
.
pre_packer_failure_tolerance
):
raise
FatalSampleError
.
from_sample
(
samples
,
f
"Pre packer
{
self
.
pre_packer
}
failed
{
self
.
_last_pre_pack_failures
}
times. Likely your code or dataset are broken."
,
)
# Put the pre-packed samples into the pre_packing_buffer
# They will be flattened here to avoid nested buffers
# But the lengths of the groups are stored in pre_packing_lengths
# so that the groups can be separated later
for
pre_pack
in
pre_packs
:
if
len
(
pre_pack
)
>
0
:
self
.
_pre_packing_buffer
.
extend
(
pre_pack
)
pre_packing_lengths
.
append
(
len
(
pre_pack
))
def
next_final_pack
()
->
Generator
[
T_batch_sample
,
None
,
None
]:
"""Yield the next packs from the buffer. The final packer is called on the fly."""
pack
=
self
.
_pre_packing_buffer
.
buffer
[:
pre_packing_lengths
[
0
]].
copy
()
if
len
(
pack
)
==
0
:
return
pack
=
encode_pack_samples
(
pack
)
del
self
.
_pre_packing_buffer
[:
pre_packing_lengths
[
0
]]
del
pre_packing_lengths
[
0
]
try
:
pack_restore_keys
=
tuple
(
get_sample_restore_key
(
sample
)
for
sample
in
pack
)
with
self
.
_final_packing_sample_index
.
ctx
()
as
pack_idx
:
final_packed_sample
=
self
.
final_packer
(
pack
)
if
isinstance
(
final_packed_sample
,
Generator
):
assert
inspect
.
isgeneratorfunction
(
self
.
final_packer
),
(
f
"Generator in
{
self
.
final_packer
}
but not marked as such."
)
for
pack_sub_idx
,
(
pack_idx
,
inner_batch_sample
)
in
enumerate
(
self
.
_final_packing_sample_index
.
iter_ctx
(
final_packed_sample
,
pack_idx
)
):
self
.
_last_final_pack_failures
=
0
yield
set_sample_restore_key
(
inner_batch_sample
,
pack_idx
,
pack_sub_idx
,
*
pack_restore_keys
,
src
=
self
,
)
else
:
self
.
_last_final_pack_failures
=
0
yield
set_sample_restore_key
(
final_packed_sample
,
pack_idx
,
*
pack_restore_keys
,
src
=
self
,
)
except
SkipSample
:
pass
except
SYSTEM_EXCEPTIONS
:
raise
FatalSampleError
.
from_sample
(
pack
)
except
Exception
as
e
:
self
.
error_handler
(
e
,
pack
)
self
.
_last_final_pack_failures
+=
1
if
(
self
.
final_packer_failure_tolerance
>
0
and
self
.
_last_final_pack_failures
>=
self
.
final_packer_failure_tolerance
):
raise
FatalSampleError
.
from_sample
(
pack
,
f
"Final packer
{
self
.
final_packer
}
failed
{
self
.
_last_final_pack_failures
}
times. Likely your code or dataset are broken."
,
)
# Main loop:
pre_pack_round
=
0
while
True
:
if
(
self
.
pre_packer_failure_tolerance
>
0
and
pre_pack_round
>
self
.
pre_packer_failure_tolerance
):
raise
RuntimeError
(
f
"Pre packer
{
self
.
pre_packer
}
did not yield any packs after
{
pre_pack_round
}
rounds. Likely your code or dataset are broken."
)
# Fill a portion of the buffer
if
not
self
.
_fill_reading_buffer
(
src_iter
,
log_progress
=
is_initial_pack
):
# Break out of the main loop when the source is exhausted.
break
is_initial_pack
=
False
# Create new pre packs if necessary
if
len
(
pre_packing_lengths
)
==
0
:
assert
self
.
_pre_packing_buffer
.
len_worker
()
==
0
assert
self
.
_reading_buffer
.
len_worker
()
==
self
.
buffer_size
next_pre_pack
()
if
len
(
pre_packing_lengths
)
==
0
:
# Retry packing, nothing was returned.
pre_pack_round
+=
1
continue
if
len
(
pre_packing_lengths
)
>
0
:
pre_pack_round
=
0
yield
from
next_final_pack
()
# Yield the remaining packs, flushing the collecting buffer
while
len
(
pre_packing_lengths
)
>
0
:
yield
from
next_final_pack
()
# If there are still samples in the partial reading buffer, pre-pack them and yield the
# resulting (partial) packs
if
self
.
_reading_buffer
.
len_worker
()
>
0
:
next_pre_pack
()
# Yield the remaining packs, flushing the collecting buffer
while
len
(
pre_packing_lengths
)
>
0
:
yield
from
next_final_pack
()
def
can_restore_sample
(
self
)
->
bool
:
# Cannot really verify if the returned elements contain a __restore_key__.
# If the user wants to use this, well...
return
(
super
().
can_restore_sample
()
and
self
.
final_packer_stateless
and
self
.
sample_encoder_stateless
)
def
assert_can_restore
(
self
):
assert
self
.
final_packer_stateless
and
self
.
sample_encoder_stateless
,
(
f
"Final packer
{
self
.
final_packer
}
and sample encoder
{
self
.
sample_encoder
}
must be stateless to restore samples."
)
super
().
assert_can_restore
()
    def restore_sample(self, restore_key: Any) -> T_sample:
        """Restore a final-packed sample from its restore key.

        The restore key encodes: the class name, the final-pack index (and the
        sub-index, if the final packer is a generator function), followed by the
        restore keys of the individual inner samples making up the pack. The inner
        samples are restored from the wrapped dataset, re-encoded via the sample
        encoder (if any), re-packed with the final packer, and the resulting
        sample is returned with a fresh restore key attached.
        """
        # We need to store multiple indices to restore a batch.
        self.assert_can_restore()
        if inspect.isgeneratorfunction(self.final_packer):
            # Generator final packer: key also contains the sub-index within the pack.
            # NOTE(review): the unpacking below is performed twice on the same
            # restore_key; the second statement is redundant — confirm intended.
            id, pack_idx, pack_sub_idx, *pack_restore_keys = restore_key
            id, pack_idx, pack_sub_idx, *pack_restore_keys = restore_key
            assert id == type(self).__name__
        else:
            # NOTE(review): same redundant double unpacking here.
            id, pack_idx, *pack_restore_keys = restore_key
            id, pack_idx, *pack_restore_keys = restore_key
            assert id == type(self).__name__
        pack = []
        for inner_idx in pack_restore_keys:
            if self.sample_encoder is not None:
                # The sample encoder wraps keys via add_sample_restore_key; strip
                # that level to get the sample index and the inner dataset key.
                # NOTE(review): this strips TWO levels — presumably the keys are
                # nested twice when a sample encoder is active; verify against
                # how the keys are built in the iterator.
                id, sample_idx, *inner_idx = inner_idx
                assert id == type(self).__name__
                id, sample_idx, *inner_idx = inner_idx
                assert id == type(self).__name__
                assert isinstance(sample_idx, int)
            # Restore the raw inner sample from the wrapped dataset.
            sample = self.dataset.restore_sample(inner_idx)
            try:
                if self.sample_encoder is not None:
                    # Re-encode under the original sample index for determinism.
                    with self._sample_encoder_sample_index.ctx(sample_idx):
                        sample = self.sample_encoder(sample)
                    assert not isinstance(sample, Generator), "Generator not supported"
                    self._last_sample_encoder_failures = 0
                    sample = add_sample_restore_key(
                        sample,
                        sample_idx,
                        src=self,
                    )
            except SkipSample:
                # Skipping is not allowed during restoration: it would silently
                # change the pack contents, so it is escalated to a fatal error.
                raise FatalSampleError.from_sample(
                    sample,
                    f"PackingDataset sample encoder {self.sample_encoder} skipped while trying to restore a sample.",
                )
            except SYSTEM_EXCEPTIONS:
                raise FatalSampleError.from_sample(sample)
            except Exception as e:
                self.error_handler(e, sample)
                self._last_sample_encoder_failures += 1
                if (
                    self.sample_encoder_failure_tolerance > 0
                    and self._last_sample_encoder_failures >= self.sample_encoder_failure_tolerance
                ):
                    raise FatalSampleError.from_sample(
                        sample,
                        f"PackingDataset sample encoder {self.sample_encoder} failed {self._last_sample_encoder_failures} times. Likely your code or dataset are broken.",
                    )
            pack.append(sample)
        try:
            # Re-run the final packer on the restored pack.
            with self._final_packing_sample_index.ctx(pack_idx):
                final_pack = self.final_packer(pack)
            if isinstance(final_pack, Generator):
                assert inspect.isgeneratorfunction(self.final_packer), (
                    f"Generator in {self.final_packer} but not marked as such."
                )
                # Iterate the generator until the requested sub-sample is reached.
                for cur_batch_sub_idx, (pack_idx, inner_batch_sample) in enumerate(
                    self._final_packing_sample_index.iter_ctx(final_pack, pack_idx)
                ):
                    self._last_final_pack_failures = 0
                    if cur_batch_sub_idx == pack_sub_idx:
                        return set_sample_restore_key(
                            inner_batch_sample,
                            pack_idx,
                            pack_sub_idx,
                            *pack_restore_keys,
                            src=self,
                        )
                assert False, f"Pack sub-index {pack_sub_idx} not found in pack"
            else:
                self._last_final_pack_failures = 0
                return set_sample_restore_key(
                    final_pack,
                    pack_idx,
                    *pack_restore_keys,
                    src=self,
                )
        except GeneratorExit:
            raise FatalSampleError.from_sample(
                pack,
                f"PackingDataset {self.final_packer} generator exited while trying to restore a pack.",
            )
        except SkipSample:
            # As above: a skip during restoration is a fatal inconsistency.
            raise FatalSampleError.from_sample(
                pack,
                f"PackingDataset {self.final_packer} skipped while trying to restore a pack.",
            )
        except SYSTEM_EXCEPTIONS:
            raise FatalSampleError.from_sample(pack)
        except Exception as e:
            self.error_handler(e, pack)
            self._last_final_pack_failures += 1
            if (
                self.final_packer_failure_tolerance > 0
                and self._last_final_pack_failures >= self.final_packer_failure_tolerance
            ):
                raise FatalSampleError.from_sample(
                    pack,
                    f"PackingDataset {self.final_packer} failed {self._last_final_pack_failures} times. Likely your code or dataset are broken.",
                )
def
config
(
self
)
->
Dict
[
str
,
Any
]:
return
{
"type"
:
type
(
self
).
__qualname__
,
"buffer_size"
:
self
.
buffer_size
,
"pre_packer"
:
self
.
_function_config
(
self
.
pre_packer
),
"final_packer"
:
self
.
_function_config
(
self
.
final_packer
),
"final_packer_stateless"
:
self
.
final_packer_stateless
,
**
(
{
"packer_config"
:
(
self
.
packer_config
()
if
callable
(
self
.
packer_config
)
else
self
.
packer_config
)
}
if
self
.
packer_config
else
{}
),
"error_handler"
:
self
.
_function_config
(
self
.
error_handler
),
"worker_config"
:
self
.
worker_config
.
config
(),
"dataset"
:
self
.
dataset
.
config
(),
}
def
__str__
(
self
):
return
f
"PackingDataset(buffer_size=
{
self
.
buffer_size
}
, pre_packer=
{
self
.
pre_packer
}
, final_packer=
{
self
.
final_packer
}
, dataset=
{
self
.
dataset
}
)"
Megatron-Energon/src/megatron/energon/wrappers/repeat_dataset.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
math
from
typing
import
Any
,
Dict
,
Generic
,
Iterator
,
Optional
,
TypeVar
,
Union
from
megatron.energon.flavors.base_dataset
import
SavableDataset
from
megatron.energon.worker
import
WorkerConfig
from
megatron.energon.wrappers.base
import
BaseWrapperDataset
T_sample
=
TypeVar
(
"T_sample"
)
class RepeatDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]):
    """This dataset repeats the inner dataset indefinitely or a specific number of repeats."""

    # Number of repetitions; `None` repeats forever. A fractional value yields
    # only a corresponding fraction of the inner dataset on the last round.
    repeats: Optional[Union[int, float]]
    # Index of the repetition currently being iterated (saved/restored).
    _repetition: int
    # Number of samples already yielded within the current repetition (saved/restored).
    _index: int

    _savable_fields = ("_repetition", "_index")

    def __init__(
        self,
        dataset: SavableDataset[T_sample],
        *,
        repeats: Optional[Union[int, float]] = None,
        restart: bool = True,
        worker_config: WorkerConfig,
    ):
        """Construct a RepeatDataset.

        Args:
            dataset: The input dataset to repeat.
            repeats: Number of repeats, `None` for indefinitely repeating.
            restart: If true, restart the underlying dataset after iterating once through the
                repeats if repeats is set to an integer, but still stop iterating.
            worker_config: Configuration for the workers.
        """
        super().__init__(dataset, worker_config=worker_config)
        self.repeats = repeats
        self.restart = restart
        self.reset_state_own()

    def reset_state_own(self) -> None:
        # Start over from the first repetition and the first sample.
        self._repetition = 0
        self._index = 0

    def len_worker(self, worker_idx: int | None = None) -> int:
        # For indefinite repetition, report the inner dataset's length;
        # otherwise scale it by the (possibly fractional) repeat count.
        if self.repeats is None:
            return self.dataset.len_worker(worker_idx)
        return int(self.dataset.len_worker(worker_idx) * self.repeats)

    def __iter__(self) -> Iterator[T_sample]:
        assert self.repeats is not None or self.dataset.worker_has_samples(), (
            "Cannot repeat empty dataset indefinitely"
        )
        # TODO: There is a small difference in the total sum of samples (across ranks) * repeats
        # and the sum(len_worker() for all workers across ranks).
        # This is due to the fact that the number of samples is not exactly divisible by the number of workers.

        # The dataset length is the size for the current rank. Need to divide by the number of workers
        ds_len = self.dataset.len_worker()

        while self.repeats is None or self._repetition < self.repeats:
            if self.repeats is not None and self._repetition == math.floor(self.repeats):
                # Last iteration, adjust the number of samples
                fraction = self.repeats - math.floor(self.repeats)
                stop_after = math.floor(ds_len * fraction)
                if self._index >= stop_after:
                    # We restored an index and it is already past the stop_after
                    break
            else:
                stop_after = None
            for sample in self.dataset:
                self._index += 1
                yield sample
                if stop_after is not None and self._index >= stop_after:
                    # Partial (fractional) final repetition is complete.
                    break
            if self.worker_config.should_log(level=2):
                self.worker_config.worker_log(
                    {
                        "t": "RepeatDataset.repeat",
                        "r": self.worker_config.rank,
                        "w": self.worker_config.rank_worker_id(),
                        "offset": self._repetition,
                        "repeats": self.repeats,
                    }
                )
            # One full pass over the inner dataset is done.
            self._repetition += 1
            self._index = 0
        if self.restart:
            # Reset the state, so a subsequent __iter__ repeats again from scratch.
            self._repetition = 0
        else:
            # No more repeats
            self._repetition = math.ceil(self.repeats)

    def config(self) -> Dict[str, Any]:
        """Return the configuration dict identifying this wrapper and its inner dataset."""
        return {
            "type": type(self).__qualname__,
            "dataset": self.dataset.config(),
            "repeats": self.repeats,
            "worker_config": self.worker_config.config(),
        }

    def __str__(self):
        return f"RepeatDataset(repeats={self.repeats}, dataset={self.dataset})"
Megatron-Energon/src/megatron/energon/wrappers/shuffle_buffer_dataset.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
typing
import
Any
,
Dict
,
Generic
,
Iterator
,
Tuple
,
TypeVar
,
Union
from
megatron.energon.flavors.base_dataset
import
SavableDataset
from
megatron.energon.rng
import
WorkerRng
from
megatron.energon.worker
import
WorkerConfig
from
megatron.energon.wrappers.base
import
BaseWrapperDataset
from
megatron.energon.wrappers.buffer
import
SavableSampleBuffer
T_sample
=
TypeVar
(
"T_sample"
)
class ShuffleBufferDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]):
    """Shuffle buffer for the dataset."""

    # Target buffer size; larger buffers shuffle more thoroughly.
    size: int

    # Per-worker deterministic RNG (part of the savable state).
    _worker_rng: WorkerRng
    # Buffer of pending samples to draw from (part of the savable state).
    _active_buffer: SavableSampleBuffer[T_sample]

    _savable_fields = ("_active_buffer", "_worker_rng")

    def __init__(
        self,
        dataset: SavableDataset[T_sample],
        size: int,
        *,
        worker_config: WorkerConfig,
    ):
        """Create a shuffle buffer for the dataset."""
        super().__init__(dataset, worker_config=worker_config)
        self.size = size
        self.reset_state_own()

    def reset_state_own(self) -> None:
        # Fresh RNG and an empty buffer over the inner dataset.
        self._worker_rng = WorkerRng(self.worker_config)
        self._active_buffer = SavableSampleBuffer(self.dataset, worker_config=self.worker_config)

    def len_worker(self, worker_idx: int | None = None) -> int:
        # Shuffling does not change the number of samples.
        return self.dataset.len_worker(worker_idx)

    def __iter__(self) -> Iterator[T_sample]:
        self._active_buffer.worker_start()
        # Advancing `it` by one step pulls one inner sample into the buffer.
        it = iter(self._active_buffer.append_iter())
        while True:
            if self._active_buffer.len_worker() >= self.size:
                # Buffer is full: emit one uniformly random element from it.
                pop_idx = self._worker_rng.randbelow(self._active_buffer.len_worker())
                yield self._active_buffer.pop(pop_idx)
            else:
                try:
                    next(it)
                except StopIteration:
                    # Inner dataset exhausted; drain the remaining buffer below.
                    break
        while self._active_buffer.len_worker() > 0:
            pop_idx = self._worker_rng.randbelow(self._active_buffer.len_worker())
            yield self._active_buffer.pop(pop_idx)

    def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample:
        # Restoration is delegated to the buffer, which knows the inner dataset.
        return self._active_buffer.restore_sample(restore_key)

    def config(self) -> Dict[str, Any]:
        """Return the configuration dict identifying this wrapper and its inner dataset."""
        return {
            "type": type(self).__qualname__,
            "dataset": self.dataset.config(),
            "size": self.size,
            "worker_config": self.worker_config.config(),
        }

    def __str__(self):
        return f"ShuffleBufferDataset(size={self.size}, dataset={self.dataset})"
Megatron-Energon/src/megatron/energon/wrappers/skip.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
class SkipSample(Exception):
    """Raised inside a user-provided map_fn to drop the current sample.

    Wrapper datasets catch this exception and silently skip the offending
    sample instead of treating it as an error.
    """
Megatron-Energon/src/megatron/energon/wrappers/watchdog_dataset.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
warnings
from
typing
import
Any
,
Dict
,
Generic
,
Iterator
,
Optional
,
TypeVar
from
megatron.energon.flavors.base_dataset
import
SavableDataset
from
megatron.energon.watchdog
import
Watchdog
from
megatron.energon.worker
import
WorkerConfig
from
megatron.energon.wrappers.base
import
BaseWrapperDataset
T_sample
=
TypeVar
(
"T_sample"
)
class WatchdogDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]):
    """This dataset wraps another dataset and watches the time it takes to yield samples."""

    def __init__(
        self,
        dataset: SavableDataset[T_sample],
        worker_config: WorkerConfig,
        timeout_seconds: Optional[float] = 60,
        initial_timeout_seconds: Optional[float] = None,
        fail_on_timeout: bool = False,
    ):
        """Construct the watchdog dataset, which wraps another dataset and watches
        the time it takes to yield samples from the wrapped dataset.

        Args:
            dataset: The input dataset to wrap
            worker_config: The worker configuration
            timeout_seconds: The timeout in seconds. If None, the watchdog is disabled.
            initial_timeout_seconds: The initial timeout in seconds. If None, the timeout is the same as timeout_seconds.
            fail_on_timeout: If True, stops the whole process upon timeout, after printing a stack trace.
        """
        super().__init__(dataset, worker_config=worker_config)
        self.timeout_seconds = timeout_seconds
        self.initial_timeout_seconds = initial_timeout_seconds
        self.fail_on_timeout = fail_on_timeout

    def reset_state_own(self) -> None:
        # The watchdog itself holds no savable state.
        pass

    def len_worker(self, worker_idx: int | None = None) -> int:
        # Transparent wrapper: same length as the inner dataset.
        return self.dataset.len_worker(worker_idx)

    def _watchdog_trigger(self) -> None:
        """Callback invoked by the watchdog when a sample exceeded the timeout."""
        if self.fail_on_timeout:
            # Raising an exception here will kill the whole process
            raise TimeoutError(
                f"Watchdog triggered. Sample processing took longer than {self.timeout_seconds} seconds."
            )
        else:
            warnings.warn(
                f"Watchdog triggered. Sample processing took longer than {self.timeout_seconds} seconds.",
                RuntimeWarning,
            )

    def __iter__(self) -> Iterator[T_sample]:
        if self.timeout_seconds is None:
            # Watchdog disabled: iterate the inner dataset directly.
            yield from self.dataset
        else:
            watchdog = Watchdog(
                timeout=self.timeout_seconds,
                initial_timeout=self.initial_timeout_seconds,
                callback=self._watchdog_trigger,
                enabled=False,
            )
            yield from watchdog.watch_iter(self.dataset)

    def config(self) -> Dict[str, Any]:
        # Watchdog is transparent, it won't change the samples
        return self.dataset.config()

    def __str__(self):
        return f"WatchdogDataset(dataset={self.dataset})"
Megatron-Energon/tests/__init__.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
Megatron-Energon/tests/data/sync_test.mkv
0 → 100644
View file @
f356f546
File added
Megatron-Energon/tests/data/sync_test.mp4
0 → 100644
View file @
f356f546
File added
Megatron-Energon/tests/data/test_audio.flac
0 → 100644
View file @
f356f546
File added
Megatron-Energon/tests/data/test_audio.wav
0 → 100644
View file @
f356f546
File added
Megatron-Energon/tests/epath_s3_emulator.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
os
from
contextlib
import
contextmanager
from
pathlib
import
Path
from
typing
import
Generator
from
multistorageclient.rclone
import
read_rclone_config
from
tests.s3_emulator.state
import
S3State
from
tests.s3_emulator.test
import
s3_emulator
@contextmanager
def setup_s3_emulator(
    *,
    port: int = 0,
    access_key: str = "test",
    secret_key: str = "test",
    root_dir: str | None = None,
    region: str = "us-east-1",
    profile_name: str = "s3test",
) -> Generator[S3State, None, None]:
    """Set up S3 emulator and write necessary config files.

    Args:
        port: Port to bind the server to. Use 0 to let the OS choose a free port.
        access_key: Access key for authentication
        secret_key: Secret key for authentication
        root_dir: Optional directory to persist S3 data
        region: Region for authentication
        profile_name: Name of the rclone profile. Must be different in all tests, to ensure that a
            cached rclone config is used in MSC.

    Returns:
        The S3 emulator state. Can be used to quickly upload files to the emulator.
    """
    try:
        with s3_emulator(
            host="127.0.0.1",
            port=port,
            credentials={access_key: secret_key},
            root_dir=root_dir,
            region=region,
        ) as emu:
            # Create config directory
            config_dir = Path("/tmp/XDG_CONFIG_HOME/.config/rclone")
            config_dir.mkdir(parents=True, exist_ok=True)
            # Write rclone config
            config_path = config_dir / "rclone.conf"
            with config_path.open("w") as f:
                f.write(
                    "\n".join(
                        [
                            f"[{profile_name}]",
                            "type = s3",
                            "env_auth = false",
                            f"access_key_id = {access_key}",
                            f"secret_access_key = {secret_key}",
                            f"region = {region}",
                            # emu.port is the actual bound port (may differ from
                            # the requested one when port=0).
                            f"endpoint = http://127.0.0.1:{emu.port}",
                        ]
                    )
                )
            # Set environment variables
            # NOTE(review): these env vars are mutated globally and never
            # restored; tests running in the same process afterwards will see
            # the changed HOME/XDG_CONFIG_HOME — confirm this is intended.
            os.environ["XDG_CONFIG_HOME"] = "/tmp/XDG_CONFIG_HOME/.config"
            os.environ["HOME"] = "/tmp/XDG_CONFIG_HOME"
            # Hack to clear the cache of the rclone config for msc to get the "s3" profile
            read_rclone_config.cache_clear()
            yield emu.state
            read_rclone_config.cache_clear()
    except Exception as e:
        # Best-effort diagnostics before re-raising to the caller.
        print("ERROR in s3_emulator", flush=True)
        print("Full traceback:", flush=True)
        import traceback

        traceback.print_exc()
        raise e
Megatron-Energon/tests/s3_emulator/__init__.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
.server
import
S3EmulatorServer
from
.state
import
S3State
__all__
=
[
"S3EmulatorServer"
,
"S3State"
]
Megatron-Energon/tests/s3_emulator/auth.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
hmac
import
re
import
urllib.parse
as
_up
from
hashlib
import
sha256
from
typing
import
Dict
,
Mapping
,
MutableMapping
__all__ = ["S3Auth", "InvalidSignature"]

# Regexes extracting the individual fields of an AWS SigV4 Authorization header,
# e.g. "AWS4-HMAC-SHA256 Credential=..., SignedHeaders=..., Signature=...".
_SIGNED_HEADERS_RE = re.compile(r"SignedHeaders=([^,]+)")
_CREDENTIAL_RE = re.compile(r"Credential=([^,]+)")
_SIGNATURE_RE = re.compile(r"Signature=([0-9a-fA-F]+)")
class InvalidSignature(Exception):
    """Signals that a request's AWS SigV4 signature failed verification."""
class S3Auth:
    """Very small subset implementation of AWS Signature V4 verification.

    Only what is mandatory for the emulator to work for most typical SDK
    operations is implemented. Notably, chunked uploads and presigned URLs are
    not supported.
    """

    def __init__(self, credentials: Mapping[str, str], region: str = "us-east-1") -> None:
        """Initialize the S3 authentication handler.

        Args:
            credentials: Mapping of access_key to secret_key accepted by the server.
            region: AWS region assumed when verifying the signing key.
        """
        self._creds: Dict[str, str] = dict(credentials)
        self._region = region

    def verify(
        self,
        method: str,
        canonical_uri: str,
        canonical_querystring: str,
        headers: Mapping[str, str] | MutableMapping[str, str],
        payload: bytes,
    ) -> None:
        """Validate the Authorization header for the given request.

        Args:
            method: HTTP method of the request.
            canonical_uri: Canonical URI path.
            canonical_querystring: Canonical query string.
            headers: Request headers.
            payload: Request body.

        Raises:
            InvalidSignature: If the header is missing/malformed or the
                signature does not match.
            ValueError: If the ``x-amz-date`` header is missing.
        """
        auth_header = headers.get("authorization") or headers.get("Authorization")
        if auth_header is None:
            raise InvalidSignature("Missing Authorization header")

        # Extract the three SigV4 header components.
        signed_headers = _first_group(_SIGNED_HEADERS_RE, auth_header)
        credential_str = _first_group(_CREDENTIAL_RE, auth_header)
        signature = _first_group(_SIGNATURE_RE, auth_header)
        if not (signed_headers and credential_str and signature):
            raise InvalidSignature("Malformed Authorization header")

        # Credential scope: <access_key>/<date>/<region>/<service>/aws4_request
        # NOTE(review): a scope with a different number of '/'-parts raises an
        # uncaught ValueError here rather than InvalidSignature — confirm intended.
        access_key, date_str, region, service, terminator = credential_str.split("/")
        if service != "s3" or terminator != "aws4_request":
            raise InvalidSignature("Invalid credential scope")
        if region != self._region:
            # NOTE(review): a region mismatch is only logged, not rejected;
            # verification continues with the client-supplied region below.
            print(f"Signature region {region} does not match server region {self._region}")
        secret_key = self._creds.get(access_key)
        if secret_key is None:
            raise InvalidSignature("Unknown access key")

        # Canonical URI & query string (encode & normalise)
        canonical_uri = _canonical_uri(canonical_uri)
        canonical_querystring = _canonical_querystring(canonical_querystring)

        # Construct canonical request ------------------------------------------------
        # 1. Canonical headers
        canonical_headers = ""
        for hdr in signed_headers.split(";"):
            hdr_lower = hdr.lower()
            value = headers.get(hdr) or headers.get(hdr_lower)
            if value is None:
                raise InvalidSignature(f"Signed header '{hdr}' missing from request")
            canonical_headers += f"{hdr_lower}:{_normalize_whitespace(str(value))}\n"

        # 2. Hashed payload
        payload_hash = sha256(payload).hexdigest()

        # 3. Canonical request string
        canonical_request = "\n".join(
            [
                method,
                canonical_uri,
                canonical_querystring,
                canonical_headers,
                signed_headers,
                payload_hash,
            ]
        )
        hashed_canonical_request = sha256(canonical_request.encode()).hexdigest()

        # String to sign
        amz_date = headers.get("x-amz-date") or headers.get("X-Amz-Date")
        if amz_date is None:
            raise ValueError("Missing x-amz-date header")
        string_to_sign = "\n".join(
            [
                "AWS4-HMAC-SHA256",
                amz_date,
                "/".join([date_str, region, "s3", "aws4_request"]),
                hashed_canonical_request,
            ]
        )

        # Calculate signing key and signature
        # The signing key is derived by chained HMACs over date, region, service.
        date_key = _sign(("AWS4" + secret_key).encode(), date_str)
        region_key = _sign(date_key, region)
        service_key = _sign(region_key, "s3")
        signing_key = _sign(service_key, "aws4_request")
        calc_signature = hmac.new(signing_key, string_to_sign.encode(), sha256).hexdigest()

        # Constant-time comparison to avoid timing side channels.
        if not hmac.compare_digest(calc_signature, signature):
            print(f"Sig mismatch: expected={signature} got={calc_signature}")
            raise InvalidSignature("Signature mismatch")
def
_first_group
(
regex
:
re
.
Pattern
[
str
],
string
:
str
)
->
str
|
None
:
"""Extract the first capture group from a regex match.
Args:
regex: The regex pattern to match.
string: The string to search in.
Returns:
The first capture group if found, None otherwise.
"""
match
=
regex
.
search
(
string
)
return
match
.
group
(
1
)
if
match
else
None
def
_sign
(
key
:
bytes
,
msg
:
str
)
->
bytes
:
"""Sign a message with a key using HMAC-SHA256.
Args:
key: The signing key.
msg: The message to sign.
Returns:
The HMAC-SHA256 signature.
"""
return
hmac
.
new
(
key
,
msg
.
encode
(),
sha256
).
digest
()
def
_normalize_whitespace
(
value
:
str
)
->
str
:
"""Collapse consecutive whitespace.
Args:
value: The string to normalize.
Returns:
The normalized string with collapsed whitespace.
"""
return
" "
.
join
(
value
.
strip
().
split
())
def
_percent_encode
(
value
:
str
)
->
str
:
"""Percent encode a string using AWS safe characters.
Args:
value: The string to encode.
Returns:
The percent-encoded string.
"""
return
_up
.
quote
(
value
,
safe
=
"-_.~"
)
def _canonical_uri(uri: str) -> str:
    """Return the URI-encoded path as required by SigV4.

    Each ``/``-separated segment is decoded and re-encoded with the AWS safe
    list ``-_.~``. Duplicate slashes are preserved (AWS behaviour), and the
    result always starts with a single ``/``.

    Args:
        uri: The URI path to canonicalize.

    Returns:
        The canonical URI path.
    """
    if not uri:
        return "/"
    segments = []
    for segment in uri.split("/"):
        segments.append(_percent_encode(_up.unquote(segment)))
    joined = "/".join(segments)
    if uri.startswith("/"):
        # The leading empty segment already produced the leading slash.
        return joined
    return "/" + joined
def _canonical_querystring(raw_qs: str) -> str:
    """Canonicalize a query string according to AWS SigV4 rules.

    Parameters are percent-encoded with the AWS safe list and sorted.

    Args:
        raw_qs: The raw query string to canonicalize.

    Returns:
        The canonical query string.
    """
    if not raw_qs:
        return ""
    decoded_pairs = _up.parse_qsl(raw_qs, keep_blank_values=True)
    encoded_pairs = sorted(
        (_percent_encode(key), _percent_encode(value)) for key, value in decoded_pairs
    )
    return "&".join(f"{key}={value}" for key, value in encoded_pairs)
Megatron-Energon/tests/s3_emulator/handler.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
urllib.parse
as
_up
from
datetime
import
datetime
,
timezone
from
email.utils
import
formatdate
from
hashlib
import
md5
from
http
import
HTTPStatus
from
http.server
import
BaseHTTPRequestHandler
from
typing
import
Protocol
from
.auth
import
InvalidSignature
,
S3Auth
from
.state
import
S3State
__all__
=
[
"S3RequestHandler"
]
class
S3RequestHandler
(
BaseHTTPRequestHandler
):
"""HTTP request handler implementing a minimal S3-compatible API.
This handler processes HTTP requests and maps them to S3 operations.
It supports basic S3 operations like bucket and object management,
including multipart uploads.
"""
server
:
"S3ServerProtocol"
# type: ignore[assignment]
def
log_message
(
self
,
fmt
:
str
,
*
args
):
"""Log a message to stdout.
Args:
fmt: Format string for the message.
*args: Arguments to format the message with.
"""
print
(
f
"
{
self
.
client_address
[
0
]
}
- -
{
fmt
%
args
}
"
)
def
do_PUT
(
self
):
"""Handle PUT requests for object creation and bucket creation."""
self
.
_handle_write
()
def
do_GET
(
self
):
"""Handle GET requests for object retrieval and bucket listing."""
self
.
_handle_read
(
listing
=
False
)
def
do_HEAD
(
self
):
"""Handle HEAD requests for object metadata."""
self
.
_handle_read
(
listing
=
False
,
only_headers
=
True
)
def
do_DELETE
(
self
):
"""Handle DELETE requests for object and bucket deletion."""
self
.
_handle_delete
()
def
do_POST
(
self
):
"""Handle POST requests for multipart upload operations."""
self
.
_handle_post
()
def
_read_body
(
self
)
->
bytes
:
"""Read and return the request body.
Returns:
The request body as bytes.
"""
length
=
int
(
self
.
headers
.
get
(
"Content-Length"
,
0
))
if
length
==
0
:
return
b
""
data
=
self
.
rfile
.
read
(
length
)
return
data
def
_split_path
(
self
)
->
tuple
[
str
,
str
,
_up
.
ParseResult
]:
"""Split the request path into bucket and key components.
Returns:
A tuple of (bucket, key, parsed_url).
"""
parsed
=
_up
.
urlparse
(
self
.
path
)
parts
=
[
p
for
p
in
parsed
.
path
.
split
(
"/"
)
if
p
]
bucket
=
parts
[
0
]
if
parts
else
""
key
=
"/"
.
join
(
parts
[
1
:])
if
len
(
parts
)
>
1
else
""
return
bucket
,
key
,
parsed
def
_auth
(
self
,
payload
:
bytes
,
parsed
:
_up
.
ParseResult
)
->
bool
:
"""Verify the request signature.
Args:
payload: The request body.
parsed: The parsed URL.
Returns:
True if authentication succeeds, False otherwise.
"""
try
:
self
.
server
.
auth
.
verify
(
method
=
self
.
command
,
canonical_uri
=
parsed
.
path
or
"/"
,
canonical_querystring
=
parsed
.
query
,
headers
=
self
.
headers
,
payload
=
payload
,
)
except
InvalidSignature
as
err
:
self
.
_send_error
(
HTTPStatus
.
FORBIDDEN
,
str
(
err
))
return
False
except
ValueError
as
err
:
self
.
_send_error
(
HTTPStatus
.
BAD_REQUEST
,
str
(
err
))
return
False
return
True
def
_handle_write
(
self
):
"""Handle PUT requests for object creation and bucket creation."""
bucket
,
key
,
parsed
=
self
.
_split_path
()
body
=
self
.
_read_body
()
if
not
self
.
_auth
(
body
,
parsed
):
return
qs
=
_up
.
parse_qs
(
parsed
.
query
,
keep_blank_values
=
True
)
# Multipart: upload part
if
"uploadId"
in
qs
and
"partNumber"
in
qs
:
upload_id
=
qs
[
"uploadId"
][
0
]
try
:
part_no
=
int
(
qs
[
"partNumber"
][
0
])
except
ValueError
:
self
.
_send_error
(
HTTPStatus
.
BAD_REQUEST
,
"Invalid partNumber"
)
return
try
:
self
.
server
.
state
.
upload_part
(
upload_id
,
part_no
,
body
)
except
KeyError
:
self
.
_send_error
(
HTTPStatus
.
NOT_FOUND
,
"Upload not found"
)
return
self
.
_send_status
(
HTTPStatus
.
OK
,
extra_headers
=
{
"ETag"
:
_etag
(
body
)})
return
if
not
bucket
:
self
.
_send_error
(
HTTPStatus
.
BAD_REQUEST
,
"Bucket must be specified"
)
return
if
key
==
""
:
# Bucket create
self
.
server
.
state
.
create_bucket
(
bucket
)
self
.
_send_status
(
HTTPStatus
.
OK
)
return
# Put object
self
.
server
.
state
.
put_object
(
bucket
,
key
,
body
)
self
.
_send_status
(
HTTPStatus
.
OK
,
extra_headers
=
{
"ETag"
:
_etag
(
body
)},
)
def
_handle_read
(
self
,
listing
:
bool
,
only_headers
:
bool
=
False
):
"""Handle GET/HEAD requests for object retrieval and bucket listing.
Args:
listing: Whether this is a bucket listing request.
only_headers: Whether to return only headers (HEAD request).
"""
bucket
,
key
,
parsed
=
self
.
_split_path
()
body
=
b
""
# GET/HEAD normally payload considered in signature (hash of empty string)
if
not
self
.
_auth
(
body
,
parsed
):
return
if
not
bucket
:
self
.
_send_error
(
HTTPStatus
.
BAD_REQUEST
,
"Bucket must be specified"
)
return
if
key
==
""
:
# List bucket contents
if
not
listing
:
# We treat listing with GET only
try
:
objects
=
self
.
server
.
state
.
list_objects
(
bucket
)
except
KeyError
:
self
.
_send_error
(
HTTPStatus
.
NOT_FOUND
,
"Bucket not found"
)
return
xml_body
=
self
.
_render_bucket_list
(
bucket
,
objects
)
self
.
_send_bytes
(
xml_body
,
content_type
=
"application/xml"
)
else
:
self
.
_send_error
(
HTTPStatus
.
NOT_IMPLEMENTED
,
"Listing not implemented"
)
return
try
:
data
=
self
.
server
.
state
.
get_object
(
bucket
,
key
)
except
FileNotFoundError
:
self
.
_send_error
(
HTTPStatus
.
NOT_FOUND
,
"Not found"
)
return
range_header
=
self
.
headers
.
get
(
"Range"
)
if
range_header
and
range_header
.
startswith
(
"bytes="
):
rng
=
range_header
.
split
(
"="
,
1
)[
1
]
if
"-"
not
in
rng
:
self
.
_send_error
(
HTTPStatus
.
REQUESTED_RANGE_NOT_SATISFIABLE
,
"Invalid Range"
)
return
start_str
,
end_str
=
rng
.
split
(
"-"
,
1
)
try
:
start
=
int
(
start_str
)
if
start_str
else
0
end
=
int
(
end_str
)
if
end_str
else
len
(
data
)
-
1
except
ValueError
:
self
.
_send_error
(
HTTPStatus
.
REQUESTED_RANGE_NOT_SATISFIABLE
,
"Invalid Range"
)
return
if
start
>
end
or
start
>=
len
(
data
):
self
.
_send_error
(
HTTPStatus
.
REQUESTED_RANGE_NOT_SATISFIABLE
,
"Invalid Range"
)
return
end
=
min
(
end
,
len
(
data
)
-
1
)
slice_data
=
data
[
start
:
end
+
1
]
headers
=
{
"Content-Range"
:
f
"bytes
{
start
}
-
{
end
}
/
{
len
(
data
)
}
"
,
"Accept-Ranges"
:
"bytes"
,
"Content-Length"
:
str
(
len
(
slice_data
)),
"ETag"
:
_etag
(
data
),
}
if
only_headers
:
headers
.
setdefault
(
"Content-Type"
,
"application/octet-stream"
)
headers
.
setdefault
(
"Last-Modified"
,
formatdate
(
usegmt
=
True
))
self
.
_send_status
(
HTTPStatus
.
PARTIAL_CONTENT
,
extra_headers
=
headers
)
else
:
self
.
_send_bytes
(
slice_data
,
status
=
HTTPStatus
.
PARTIAL_CONTENT
,
content_type
=
"application/octet-stream"
,
extra_headers
=
headers
,
)
else
:
if
only_headers
:
self
.
_send_status
(
HTTPStatus
.
OK
,
extra_headers
=
{
"Content-Length"
:
str
(
len
(
data
)),
"Accept-Ranges"
:
"bytes"
,
"Content-Type"
:
"application/octet-stream"
,
"Last-Modified"
:
formatdate
(
usegmt
=
True
),
"ETag"
:
_etag
(
data
),
},
)
else
:
self
.
_send_bytes
(
data
,
content_type
=
"application/octet-stream"
,
extra_headers
=
{
"Accept-Ranges"
:
"bytes"
},
)
def
_handle_delete
(
self
):
"""Handle DELETE requests for object and bucket deletion."""
bucket
,
key
,
parsed
=
self
.
_split_path
()
body
=
b
""
# empty
if
not
self
.
_auth
(
body
,
parsed
):
return
if
not
bucket
:
self
.
_send_error
(
HTTPStatus
.
BAD_REQUEST
,
"Bucket must be specified"
)
return
if
key
==
""
:
try
:
self
.
server
.
state
.
delete_bucket
(
bucket
)
except
(
KeyError
,
RuntimeError
)
as
err
:
self
.
_send_error
(
HTTPStatus
.
BAD_REQUEST
,
str
(
err
))
return
self
.
_send_status
(
HTTPStatus
.
NO_CONTENT
)
return
try
:
self
.
server
.
state
.
delete_object
(
bucket
,
key
)
except
FileNotFoundError
:
self
.
_send_error
(
HTTPStatus
.
NOT_FOUND
,
"Not found"
)
return
self
.
_send_status
(
HTTPStatus
.
NO_CONTENT
)
def
_handle_post
(
self
):
"""Handle POST requests for multipart upload operations."""
bucket
,
key
,
parsed
=
self
.
_split_path
()
body
=
self
.
_read_body
()
if
not
self
.
_auth
(
body
,
parsed
):
return
qs
=
_up
.
parse_qs
(
parsed
.
query
,
keep_blank_values
=
True
)
# Initiate multipart: POST ?uploads
if
"uploads"
in
qs
or
parsed
.
query
==
"uploads"
:
upload_id
=
self
.
server
.
state
.
initiate_multipart
(
bucket
,
key
)
xml
=
(
'<?xml version="1.0" encoding="UTF-8"?>'
"<InitiateMultipartUploadResult>"
f
"<Bucket>
{
_escape_xml
(
bucket
)
}
</Bucket>"
f
"<Key>
{
_escape_xml
(
key
)
}
</Key>"
f
"<UploadId>
{
upload_id
}
</UploadId>"
"</InitiateMultipartUploadResult>"
).
encode
()
self
.
_send_bytes
(
xml
,
status
=
HTTPStatus
.
OK
,
content_type
=
"application/xml"
)
return
# Complete multipart: POST ?uploadId=xxxx
if
"uploadId"
in
qs
:
upload_id
=
qs
[
"uploadId"
][
0
]
try
:
self
.
server
.
state
.
complete_multipart
(
upload_id
)
except
KeyError
:
self
.
_send_error
(
HTTPStatus
.
NOT_FOUND
,
"Upload not found"
)
return
xml
=
(
'<?xml version="1.0" encoding="UTF-8"?>'
"<CompleteMultipartUploadResult>"
f
"<Bucket>
{
_escape_xml
(
bucket
)
}
</Bucket>"
f
"<Key>
{
_escape_xml
(
key
)
}
</Key>"
f
"<UploadId>
{
upload_id
}
</UploadId>"
"</CompleteMultipartUploadResult>"
).
encode
()
self
.
_send_bytes
(
xml
,
status
=
HTTPStatus
.
OK
,
content_type
=
"application/xml"
)
return
self
.
_send_error
(
HTTPStatus
.
NOT_IMPLEMENTED
,
"Unsupported POST request"
)
def
_send_status
(
self
,
status
:
HTTPStatus
,
extra_headers
:
dict
[
str
,
str
]
|
None
=
None
):
"""Send an HTTP response with the given status code.
Args:
status: The HTTP status code to send.
extra_headers: Optional additional headers to include.
"""
self
.
send_response
(
status
.
value
)
headers
=
{
"Server"
:
"s3-emulator"
}
if
extra_headers
:
headers
.
update
(
extra_headers
)
for
k
,
v
in
headers
.
items
():
self
.
send_header
(
k
,
v
)
self
.
end_headers
()
def
_send_error
(
self
,
status
:
HTTPStatus
,
message
:
str
):
"""Send an error response.
Args:
status: The HTTP status code to send.
message: The error message to include in the response.
"""
print
(
f
"Error
{
status
}
:
{
message
}
"
)
self
.
_send_bytes
(
message
.
encode
(),
status
=
status
,
content_type
=
"text/plain"
)
def
_send_bytes
(
self
,
data
:
bytes
,
status
:
HTTPStatus
=
HTTPStatus
.
OK
,
content_type
:
str
=
"application/octet-stream"
,
extra_headers
:
dict
[
str
,
str
]
|
None
=
None
,
)
->
None
:
"""Send a response with binary data.
Args:
data: The binary data to send.
status: The HTTP status code to send. Defaults to 200 OK.
content_type: The Content-Type header value. Defaults to application/octet-stream.
extra_headers: Optional additional headers to include.
"""
self
.
send_response
(
status
.
value
)
headers
=
{
"Server"
:
"s3-emulator"
,
"Content-Type"
:
content_type
,
"Content-Length"
:
str
(
len
(
data
)),
}
if
extra_headers
:
headers
.
update
(
extra_headers
)
for
k
,
v
in
headers
.
items
():
self
.
send_header
(
k
,
v
)
self
.
end_headers
()
if
self
.
command
!=
"HEAD"
:
self
.
wfile
.
write
(
data
)
@staticmethod
def _render_bucket_list(bucket: str, objects: list[str]) -> bytes:
    """Generate an XML listing of objects in a bucket.

    Args:
        bucket: The bucket name.
        objects: List of object keys in the bucket.

    Returns:
        The ListBucketResult XML document as bytes.
    """
    # All entries share one timestamp; the emulator does not track real mtimes.
    timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.000Z")
    contents = []
    for key in objects:
        try:
            payload = S3RequestHandler.server.state.get_object(bucket, key)  # type: ignore[attr-defined]
            size = len(payload)
            etag = _etag(payload)
        except Exception:  # noqa: BLE001
            # Missing or unreadable object: report it with a zero size.
            size = 0
            etag = '""'
        contents.append(
            "<Contents>"
            f"<Key>{_escape_xml(key)}</Key>"
            f"<LastModified>{timestamp}</LastModified>"
            f"<ETag>{etag}</ETag>"
            f"<Size>{size}</Size>"
            "</Contents>"
        )
    document = (
        '<?xml version="1.0" encoding="UTF-8"?>'
        "<ListBucketResult>"
        f"<Name>{_escape_xml(bucket)}</Name>"
        f"{''.join(contents)}"
        "</ListBucketResult>"
    )
    return document.encode()
class S3ServerProtocol(Protocol):
    """Structural type for the server object used by the request handler.

    Any server class hosting :class:`S3RequestHandler` is expected to expose
    these attributes so the handler can reach the shared store and the
    credential checker.
    """

    # Shared object store backing all requests.
    state: S3State
    # Credential/signature validator for incoming requests.
    auth: S3Auth
def
_escape_xml
(
text
:
str
)
->
str
:
# noqa: D401
"""Escape special characters for XML.
Args:
text: The text to escape.
Returns:
The escaped text.
"""
return
(
text
.
replace
(
"&"
,
"&"
)
.
replace
(
"<"
,
"<"
)
.
replace
(
">"
,
">"
)
.
replace
(
'"'
,
"""
)
.
replace
(
"'"
,
"'"
)
)
def
_etag
(
data
:
bytes
)
->
str
:
# noqa: D401
"""Generate an ETag for binary data.
Args:
data: The binary data to generate an ETag for.
Returns:
The MD5 hash of the data as a hex string.
"""
return
md5
(
data
).
hexdigest
()
Megatron-Energon/tests/s3_emulator/main.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from
pathlib
import
Path
import
click
from
.server
import
S3EmulatorServer
@click.command()
@click.option("--host", default="0.0.0.0", help="Host to bind the server to")
@click.option("--port", default=9000, type=int, help="Port to bind the server to")
@click.option("--root-dir", type=click.Path(path_type=Path), help="Directory to persist S3 data")
@click.option("--access-key", default="test", help="Access key for authentication")
@click.option("--secret-key", default="test", help="Secret key for authentication")
@click.option("--region", default="us-east-1", help="Region for authentication")
def main(
    host: str,
    port: int,
    root_dir: Path | None,
    access_key: str,
    secret_key: str,
    region: str,
) -> None:
    """Start an S3 emulator server."""
    credentials = {access_key: secret_key}
    server = S3EmulatorServer(
        host=host,
        port=port,
        credentials=credentials,
        root_dir=root_dir,
        region=region,
    )
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        # Ctrl-C: stop cleanly so pending state is flushed to disk.
        server.shutdown()
# Allow launching the emulator CLI by executing this module directly.
if __name__ == "__main__":
    main()
Megatron-Energon/tests/s3_emulator/server.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
threading
from
http.server
import
ThreadingHTTPServer
from
pathlib
import
Path
from
typing
import
Mapping
from
.auth
import
S3Auth
from
.handler
import
S3RequestHandler
from
.state
import
S3State
__all__
=
[
"S3EmulatorServer"
]
class S3EmulatorServer:
    """A lightweight, blocking S3 HTTP emulator.

    This server provides a minimal S3-compatible HTTP interface for testing
    purposes. It supports basic S3 operations like bucket and object
    management.

    Example (blocking)::

        server = S3EmulatorServer(
            host="127.0.0.1",
            port=9000,
            credentials={"ACCESS": "SECRET"},
        )
        server.serve_forever()

    Example (threaded)::

        server = S3EmulatorServer(
            host="127.0.0.1",
            port=9000,
            credentials={"ACCESS": "SECRET"},
        )
        server.start_background()
        # ...
        server.shutdown()
        server.join()
    """

    def __init__(
        self,
        host: str = "0.0.0.0",
        port: int = 0,
        *,
        credentials: Mapping[str, str] | None = None,
        root_dir: str | Path | None = None,
        region: str = "us-east-1",
    ):
        """Bind the emulator socket and prepare state and authentication.

        The port is bound during construction, so there is no need to poll
        for readiness; read the actual port from ``.port`` when ``port=0``.
        Requests are only processed once ``serve_forever()`` (blocking) or
        ``start_background()`` (threaded) is called.

        Args:
            host: The host address to bind to.
            port: The port to bind to. Use 0 to let the OS choose a free port.
            credentials: Optional mapping of access keys to secret keys.
            root_dir: Optional path to persist the S3 store on disk.
            region: AWS region to emulate.
        """
        self._state = S3State(Path(root_dir) if root_dir else None)
        self._auth = S3Auth(credentials or {"test": "test"}, region=region)

        emulator_state = self._state
        emulator_auth = self._auth

        # The request handler reaches the shared store and credential checker
        # through class attributes on the concrete server type (see the
        # S3ServerProtocol structural interface).
        class _BoundServer(ThreadingHTTPServer):
            state = emulator_state
            auth = emulator_auth

        self._httpd: ThreadingHTTPServer = _BoundServer((host, port), S3RequestHandler)
        self._thread: threading.Thread | None = None
        print(f"S3 emulator on http://{host}:{self.port}", flush=True)

    @property
    def port(self) -> int:
        """Returns the port number the server is bound to."""
        return self._httpd.server_port

    @property
    def state(self) -> S3State:
        """Returns the internal S3 state object."""
        return self._state

    def serve_forever(self):
        """Run the server until shutdown is called, then flush state.

        Blocks the calling thread; use start_background() for non-blocking
        operation.
        """
        try:
            self._httpd.serve_forever()
        finally:
            # Always persist the store layout, even if serving raised.
            self._state.flush()

    def shutdown(self):
        """Shutdown the server and flush any pending state changes."""
        self._httpd.shutdown()
        self._state.flush()

    def start_background(self):
        """Start the server in a daemon background thread."""
        if self._thread and self._thread.is_alive():
            raise RuntimeError("Server already running")
        self._thread = threading.Thread(target=self.serve_forever, daemon=True)
        self._thread.start()

    def join(self, timeout: float | None = None):
        """Join the background thread, if one was started.

        Args:
            timeout: Optional timeout in seconds to wait for thread completion.
        """
        if self._thread is not None:
            self._thread.join(timeout)
Megatron-Energon/tests/s3_emulator/state.py
0 → 100644
View file @
f356f546
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import
json
from
pathlib
import
Path
from
threading
import
RLock
from
typing
import
Dict
,
Optional
from
uuid
import
uuid4
__all__
=
[
"S3State"
]
class S3State:
    """A minimal, thread-safe, in-memory representation of an S3 object store.

    Optionally, a root_dir can be supplied to persist the store on the local
    file system. The directory structure mirrors the S3 layout:

        <root_dir>/<bucket>/<key>

    Buckets are directories, objects are stored as regular files. Metadata is
    not currently persisted beyond the object byte payload.
    """

    # Name of the JSON file inside root_dir that persists the bucket -> keys
    # layout (payloads live as plain files next to it).
    STATE_FILE = "__state.json"

    def __init__(self, root_dir: Optional[Path] = None) -> None:
        """
        Args:
            root_dir: Path to persist the store on disk.
        """
        # bucket name -> {key -> payload bytes}
        self._fs: Dict[str, Dict[str, bytes]] = {}
        # upload_id -> in-flight multipart upload (forward reference: the
        # helper class is defined after S3State in this module).
        self._uploads: Dict[str, "_MultipartUpload"] = {}
        # Reentrant so methods may call each other while holding the lock.
        self._lock = RLock()
        self._root_dir = root_dir
        if self._root_dir is not None:
            self._root_dir.mkdir(parents=True, exist_ok=True)
            self._load_from_disk()

    def list_buckets(self) -> list[str]:
        """List all buckets in the store.

        Returns:
            Sorted list of bucket names.
        """
        with self._lock:
            return sorted(self._fs.keys())

    def create_bucket(self, bucket: str) -> None:
        """Create a new bucket. Creating an existing bucket is a no-op.

        Args:
            bucket: Name of the bucket to create.
        """
        with self._lock:
            if bucket in self._fs:
                print(f"Bucket '{bucket}' already exists")
                return
            self._fs[bucket] = {}
            if self._root_dir is not None:
                (self._root_dir / bucket).mkdir(parents=True, exist_ok=True)

    def delete_bucket(self, bucket: str) -> None:
        """Delete an empty bucket.

        Args:
            bucket: Name of the bucket to delete.

        Raises:
            KeyError: If the bucket does not exist.
            RuntimeError: If the bucket still contains objects.
        """
        with self._lock:
            if bucket not in self._fs:
                raise KeyError(f"Bucket '{bucket}' does not exist")
            if self._fs[bucket]:
                raise RuntimeError("Bucket not empty")
            del self._fs[bucket]
            if self._root_dir is not None:
                bucket_path = self._root_dir / bucket
                if bucket_path.exists():
                    # FIX: keys containing "/" create nested directories on
                    # disk, and Path.unlink() fails on directories. Remove
                    # entries deepest-first, rmdir for dirs, unlink for files.
                    for p in sorted(
                        bucket_path.rglob("*"), key=lambda q: len(q.parts), reverse=True
                    ):
                        if p.is_dir():
                            p.rmdir()
                        else:
                            p.unlink()
                    bucket_path.rmdir()

    def put_object(self, bucket: str, key: str, data: bytes) -> None:
        """Store an object in a bucket, creating the bucket if needed.

        Args:
            bucket: Name of the bucket.
            key: Object key.
            data: Object data.

        Raises:
            ValueError: If bucket is empty.
        """
        if not bucket:
            raise ValueError("Bucket name must be given")
        with self._lock:
            if bucket not in self._fs:
                self._fs[bucket] = {}
            self._fs[bucket][key] = data
            if self._root_dir is not None:
                # NOTE(review): keys containing ".." could resolve outside
                # root_dir; acceptable for a test emulator, but verify if this
                # is ever exposed to untrusted input.
                obj_path = (self._root_dir / bucket / key).resolve()
                obj_path.parent.mkdir(parents=True, exist_ok=True)
                obj_path.write_bytes(data)

    def get_object(self, bucket: str, key: str) -> bytes:
        """Retrieve an object from a bucket.

        Note: objects restored from a persisted state file are tracked with an
        empty payload (see _load_from_disk).

        Args:
            bucket: Name of the bucket.
            key: Object key.

        Returns:
            The object data.

        Raises:
            FileNotFoundError: If the bucket or key does not exist.
        """
        with self._lock:
            try:
                return self._fs[bucket][key]
            except KeyError as exc:
                raise FileNotFoundError(f"{bucket}/{key}") from exc

    def delete_object(self, bucket: str, key: str) -> None:
        """Delete an object from a bucket.

        Args:
            bucket: Name of the bucket.
            key: Object key.

        Raises:
            FileNotFoundError: If the bucket or key does not exist.
        """
        with self._lock:
            try:
                del self._fs[bucket][key]
            except KeyError as exc:
                raise FileNotFoundError(f"{bucket}/{key}") from exc
            if self._root_dir is not None:
                obj_path = self._root_dir / bucket / key
                if obj_path.exists():
                    obj_path.unlink(missing_ok=True)

    def list_objects(self, bucket: str) -> list[str]:
        """List all objects in a bucket.

        Args:
            bucket: Name of the bucket.

        Returns:
            Sorted list of object keys.

        Raises:
            KeyError: If the bucket does not exist.
        """
        with self._lock:
            if bucket not in self._fs:
                raise KeyError(f"Bucket '{bucket}' does not exist")
            return sorted(self._fs[bucket].keys())

    def _load_from_disk(self) -> None:
        """Load persisted state from root_dir.

        The object payload itself is not loaded in memory to keep startup
        affordable. Only the structure (bucket -> keys) is persisted in a
        state file, so restored objects initially have empty payloads.
        """
        if self._root_dir is None:
            return
        state_file = self._root_dir / self.STATE_FILE
        if not state_file.exists():
            return
        try:
            mapping = json.loads(state_file.read_text())
        except Exception as err:  # noqa: BLE001
            # Best-effort restore: a corrupt state file means starting empty.
            print(f"Failed to read persisted state: {err}")
            return
        with self._lock:
            self._fs = {
                bucket: {key: b"" for key in keys} for bucket, keys in mapping.items()
            }

    def flush(self) -> None:
        """Persist only the structure of the store to disk."""
        if self._root_dir is None:
            return
        # FIX: snapshot under the lock — flush() may run from the shutdown
        # thread while request threads are still mutating the store.
        with self._lock:
            mapping = {
                bucket: list(objects.keys()) for bucket, objects in self._fs.items()
            }
        (self._root_dir / self.STATE_FILE).write_text(json.dumps(mapping))

    def initiate_multipart(self, bucket: str, key: str) -> str:
        """Create a new multipart upload.

        Args:
            bucket: Name of the bucket.
            key: Object key.

        Returns:
            The upload ID.
        """
        with self._lock:
            upload_id = uuid4().hex
            self._uploads[upload_id] = _MultipartUpload(bucket, key)
            if bucket not in self._fs:
                self._fs[bucket] = {}
            return upload_id

    def upload_part(self, upload_id: str, part_number: int, data: bytes) -> None:
        """Upload a part of a multipart upload.

        Args:
            upload_id: The upload ID.
            part_number: The part number.
            data: The part data.

        Raises:
            KeyError: If the upload ID is unknown.
        """
        with self._lock:
            mp = self._uploads.get(upload_id)
            if mp is None:
                raise KeyError("Invalid upload_id")
            mp.parts[part_number] = data

    def complete_multipart(self, upload_id: str) -> None:
        """Complete a multipart upload, assembling the parts into one object.

        Args:
            upload_id: The upload ID.

        Raises:
            KeyError: If the upload ID is unknown.
        """
        with self._lock:
            mp = self._uploads.pop(upload_id, None)
            if mp is None:
                raise KeyError("Invalid upload_id")
            data = mp.assemble()
            if mp.bucket not in self._fs:
                self._fs[mp.bucket] = {}
            self._fs[mp.bucket][mp.key] = data
            if self._root_dir is not None:
                obj_path = (self._root_dir / mp.bucket / mp.key).resolve()
                obj_path.parent.mkdir(parents=True, exist_ok=True)
                obj_path.write_bytes(data)

    def abort_multipart(self, upload_id: str) -> None:
        """Abort a multipart upload. Unknown IDs are silently ignored.

        Args:
            upload_id: The upload ID.
        """
        with self._lock:
            self._uploads.pop(upload_id, None)

    def add_file(self, src: Path, dst: str):
        """Add a file or directory tree to the store.

        Args:
            src: Source file or directory path.
            dst: Destination path in S3 format (bucket/key).

        Raises:
            ValueError: If src is neither a file nor a directory.
        """
        if src.is_dir():
            dst = dst.removesuffix("/")
            for file in src.iterdir():
                self.add_file(file, dst=f"{dst}/{file.name}")
        elif src.is_file():
            bucket, key = dst.removeprefix("/").split("/", 1)
            self.put_object(bucket, key, src.read_bytes())
        else:
            raise ValueError(f"Invalid file: {src}")
class
_MultipartUpload
:
"""Internal helper class for managing multipart uploads."""
__slots__
=
(
"bucket"
,
"key"
,
"parts"
)
def
__init__
(
self
,
bucket
:
str
,
key
:
str
):
self
.
bucket
=
bucket
self
.
key
=
key
self
.
parts
:
Dict
[
int
,
bytes
]
=
{}
def
assemble
(
self
)
->
bytes
:
"""Assemble the uploaded parts into a complete object.
Returns:
The complete object data.
"""
if
not
self
.
parts
:
return
b
""
return
b
""
.
join
(
self
.
parts
[
n
]
for
n
in
sorted
(
self
.
parts
))
Prev
1
…
6
7
8
9
10
11
12
13
14
…
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment