# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Any, Callable, Dict, Generic, Iterator, List, Literal, Optional, TypeVar
from megatron.energon.flavors.base_dataset import SavableDataset
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset
T_sample = TypeVar("T_sample")
def default_get_keys(batch: Any) -> Optional[List[str]]:
"""Default get_keys, which has some heuristics to find the sample keys."""
if isinstance(batch, list):
batch = batch[0]
if (
hasattr(batch, "__key__")
and isinstance(batch.__key__, list)
and all(isinstance(k, str) for k in batch.__key__)
):
return batch.__key__
elif (
hasattr(batch, "__keys__")
and isinstance(batch.__keys__, list)
and all(isinstance(k, str) for k in batch.__keys__)
):
return batch.__keys__
    elif (
        isinstance(batch, dict)
        and isinstance(batch.get("__key__"), list)
        and all(isinstance(k, str) for k in batch["__key__"])
    ):
        return batch["__key__"]
    elif (
        isinstance(batch, dict)
        and isinstance(batch.get("__keys__"), list)
        and all(isinstance(k, str) for k in batch["__keys__"])
    ):
        return batch["__keys__"]
    elif (
        isinstance(batch, dict)
        and isinstance(batch.get("keys"), list)
        and all(isinstance(k, str) for k in batch["keys"])
    ):
        return batch["keys"]
return None
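# A small runnable illustration of the heuristics above (hypothetical key
# values): a dict batch carrying a list of string keys under "__key__" is
# recognized; unknown structures yield None.
def _example_default_get_keys() -> None:
    batch = {"__key__": ["shard_0/000", "shard_0/001"], "pixels": [0, 1]}
    assert default_get_keys(batch) == ["shard_0/000", "shard_0/001"]
    assert default_get_keys([batch]) == ["shard_0/000", "shard_0/001"]  # lists use the first element
    assert default_get_keys({"pixels": [0, 1]}) is None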
class LogSampleDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]):
"""This dataset logs every yielded sample to the debug logs."""
get_keys_fn: Callable[[T_sample], Optional[List[str]]]
mode: Literal["train", "val"]
_step: int
_savable_fields = ("_step",)
def __init__(
self,
dataset: SavableDataset[T_sample],
mode: Literal["train", "val"],
worker_config: WorkerConfig,
get_keys_fn: Callable[[T_sample], Optional[List[str]]] = default_get_keys,
):
"""Construct the log sample dataset, which logs every yielded sample to the debug logs.
Args:
dataset: The input dataset to wrap
"""
super().__init__(dataset, worker_config=worker_config)
self.get_keys_fn = get_keys_fn
self.mode = mode
self.reset_state_own()
def reset_state_own(self) -> None:
self._step = 0
def len_worker(self, worker_idx: int | None = None) -> int:
return self.dataset.len_worker(worker_idx)
def _log(self, sample: T_sample) -> None:
if self.worker_config.should_log(level=1):
log_entry = {
"t": "yield_batch",
"r": self.worker_config.rank,
"w": self.worker_config.global_worker_id(),
"m": self.mode,
"idx": self._step,
}
keys = self.get_keys_fn(sample)
if keys is not None:
log_entry["keys"] = keys
self.worker_config.worker_log(log_entry)
def __iter__(self) -> Iterator[T_sample]:
for sample in self.dataset:
self._log(sample)
self._step += 1
yield sample
def config(self) -> Dict[str, Any]:
# Transparent logger, it won't change the samples
return self.dataset.config()
def __str__(self):
return f"LogSampleDataset(mode={self.mode}, get_keys_fn={self.get_keys_fn}, dataset={self.dataset})"
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import inspect
from typing import (
Any,
Callable,
Dict,
Generator,
Generic,
Iterator,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)
from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError
from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key
from megatron.energon.source_info import SourceInfo
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers._log_exception import log_exception
from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key
from megatron.energon.wrappers.skip import SkipSample
T_sample = TypeVar("T_sample")
T_sample_out = TypeVar("T_sample_out")
class MapDataset(BaseWrapperDataset[T_sample, T_sample_out], Generic[T_sample, T_sample_out]):
"""This dataset wrapper applies a custom function to transform each sample."""
map_fn: Callable[[T_sample], Union[T_sample_out, Generator[T_sample_out, None, None]]]
error_handler: Callable[[Exception, T_sample, Sequence[SourceInfo]], None]
stateless_map_fn: bool
map_fn_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]]
_sample_index: SampleIndex
_generator_sample_key: Optional[Any]
_generator_offset: Optional[int]
_last_map_failures: int = 0
_savable_fields = (
"_sample_index",
"_generator_sample_key",
"_generator_offset",
)
def __init__(
self,
dataset: SavableDataset[T_sample],
map_fn: Callable[[T_sample], Union[T_sample_out, Generator[T_sample_out, None, None]]],
*,
error_handler: Callable[[Exception, T_sample, Sequence[SourceInfo]], None] = log_exception,
stateless_map_fn: bool = False,
map_fn_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None,
failure_tolerance: int = 100,
worker_config: WorkerConfig,
):
"""Construct a MapDataset.
        If this dataset should be savable, the map_fn must return a single sample, or a
        generator yielding 0 or 1 samples per input sample. Otherwise, saving/restoring
        will be broken (use `IterMapDataset` instead).
Args:
dataset: The input dataset to wrap
map_fn: The function to apply to each sample. May raise
:exc:`megatron.energon.SkipSample` to skip a sample. Alternatively, may return a
generator to yield multiple or no samples.
error_handler: Handler for errors. Defaults to logging and ignoring the exception.
stateless_map_fn: If true, the map_fn is deterministic and stateless
(thus key for random access can propagate to inner dataset). Defaults to False.
map_fn_config: Configuration for the map_fn function. If callable, it should return the
configuration. Defaults to None.
failure_tolerance: The number of consecutive failures after which the dataset is considered broken. Set to 0 to disable.
worker_config: Worker configuration.
"""
super().__init__(dataset, worker_config=worker_config)
self.map_fn = map_fn
self.error_handler = error_handler
self.stateless_map_fn = stateless_map_fn
self.map_fn_config = map_fn_config
self.failure_tolerance = failure_tolerance
self.reset_state_own()
def reset_state_own(self) -> None:
self._sample_index = SampleIndex(self.worker_config, src=self)
self._generator_sample_key = None
self._generator_offset = None
def len_worker(self, worker_idx: int | None = None) -> int:
return self.dataset.len_worker(worker_idx)
def __iter__(self) -> Iterator[T_sample_out]:
if self._generator_sample_key is not None:
assert self._generator_offset is not None
sample = self.dataset.restore_sample(self._generator_sample_key)
# Do not increment the sample index, use previous index
with self._sample_index.ctx(self._sample_index.current_idx) as sample_idx:
mapped_sample = self.map_fn(sample)
assert isinstance(mapped_sample, Generator)
assert inspect.isgeneratorfunction(self.map_fn), (
f"Generator in {self.map_fn} but not marked as such."
)
target_offset = self._generator_offset
self._generator_offset = 0
for idx, (sample_idx, inner_sample) in enumerate(
self._sample_index.iter_ctx(mapped_sample, sample_idx)
):
# Skip other samples
if idx >= target_offset:
self._generator_offset = idx + 1
yield add_sample_restore_key(
inner_sample,
sample_idx,
idx,
src=self,
)
self._generator_sample_key = None
self._generator_offset = None
for sample in self.dataset:
restore_key = get_sample_restore_key(sample)
try:
with self._sample_index.ctx() as sample_idx:
mapped_sample = self.map_fn(sample)
if isinstance(mapped_sample, Generator):
assert inspect.isgeneratorfunction(self.map_fn), (
f"Generator in {self.map_fn} but not marked as such."
)
self._generator_sample_key = restore_key
self._generator_offset = 0
# In case of a generator, additionally store the index of the yielded samples
# per input sample
for idx, (sample_idx, inner_sample) in enumerate(
self._sample_index.iter_ctx(mapped_sample, sample_idx)
):
self._generator_offset = idx + 1
self._last_map_failures = 0
yield add_sample_restore_key(
inner_sample,
sample_idx,
idx,
src=self,
)
self._generator_sample_key = None
self._generator_offset = None
else:
self._last_map_failures = 0
yield add_sample_restore_key(
mapped_sample,
sample_idx,
src=self,
)
except GeneratorExit:
raise
except SkipSample:
pass
except SYSTEM_EXCEPTIONS:
raise FatalSampleError.from_sample(sample)
except Exception as e:
self.error_handler(e, sample)
self._last_map_failures += 1
print(
f"MapDataset {self.map_fn} failed {self._last_map_failures}/{self.failure_tolerance} times in a row."
)
if self.failure_tolerance > 0 and self._last_map_failures >= self.failure_tolerance:
raise FatalSampleError.from_sample(
sample,
f"MapDataset {self.map_fn} failed {self._last_map_failures} times in a row. Likely your code or dataset are broken.",
)
def can_restore_sample(self) -> bool:
return super().can_restore_sample() and self.stateless_map_fn
def assert_can_restore(self) -> None:
assert self.stateless_map_fn, (
f"MapDataset can only restore samples if map_fn {self.map_fn} is stateless."
)
super().assert_can_restore()
def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample_out:
self.assert_can_restore()
if inspect.isgeneratorfunction(self.map_fn):
id, sample_idx, local_idx = restore_key[:3]
assert id == type(self).__name__
restore_key = restore_key[3:]
assert isinstance(local_idx, int)
else:
id, sample_idx = restore_key[:2]
assert id == type(self).__name__
restore_key = restore_key[2:]
inner_sample = self.dataset.restore_sample(restore_key)
try:
with self._sample_index.ctx(sample_idx):
mapped_sample = self.map_fn(inner_sample)
if isinstance(mapped_sample, Generator):
assert inspect.isgeneratorfunction(self.map_fn), (
f"Generator in {self.map_fn} but not marked as such."
)
for idx, (sample_idx, res_sample) in enumerate(
self._sample_index.iter_ctx(mapped_sample, sample_idx)
):
self._last_map_failures = 0
if idx == local_idx:
return add_sample_restore_key(res_sample, sample_idx, local_idx, src=self)
assert False, (
"Generator did not yield enough samples, but is marked stateless/deterministic."
)
else:
self._last_map_failures = 0
return add_sample_restore_key(mapped_sample, sample_idx, src=self)
except GeneratorExit:
raise FatalSampleError.from_sample(
inner_sample,
f"MapDataset {self.map_fn} generator exited while trying to restore a sample.",
)
except SkipSample:
raise FatalSampleError.from_sample(
inner_sample, f"MapDataset {self.map_fn} skipped while trying to restore a sample."
)
except SYSTEM_EXCEPTIONS:
raise FatalSampleError.from_sample(inner_sample)
except Exception as e:
self.error_handler(e, inner_sample)
self._last_map_failures += 1
if self.failure_tolerance > 0 and self._last_map_failures >= self.failure_tolerance:
raise FatalSampleError.from_sample(
inner_sample,
f"MapDataset {self.map_fn} failed {self._last_map_failures} times in a row. Likely your code or dataset are broken.",
)
def config(self) -> Dict[str, Any]:
return {
"type": type(self).__qualname__,
"dataset": self.dataset.config(),
"map_fn": self._function_config(self.map_fn),
**(
{
"map_fn_config": (
self.map_fn_config() if callable(self.map_fn_config) else self.map_fn_config
)
}
if self.map_fn_config
else {}
),
"map_fn_stateless": self.stateless_map_fn,
}
def __str__(self):
return f"MapDataset(map_fn={self.map_fn}, dataset={self.dataset})"
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import dataclasses
from typing import Any, Callable, Dict, Generator, Generic, Iterator, List, Tuple, TypeVar, Union
import torch
from megatron.energon.flavors.base_dataset import SavableDataset
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset
from megatron.energon.wrappers.batch_dataset import BatchDataset
from megatron.energon.wrappers.blend_dataset import BlendDataset
T_batch_in = TypeVar("T_batch_in")
T_batch = TypeVar("T_batch")
def generic_concat(batch: List[Any]) -> Any:
"""Based on the types/shapes of the batch: Will either pad and stack, or return as list.
Recurses structures (dict, dataclass, namedtuple) and applies the same logic to each field."""
if isinstance(batch[0], torch.Tensor):
return concat_pad(batch)
elif isinstance(batch[0], dict):
return {key: generic_concat([sample[key] for sample in batch]) for key in batch[0].keys()}
elif dataclasses.is_dataclass(batch[0]):
return type(batch[0])(
**{
field.name: generic_concat([getattr(sample, field.name) for sample in batch])
for field in dataclasses.fields(batch[0])
}
)
elif isinstance(batch[0], tuple) and hasattr(batch[0], "_fields"):
# NamedTuple
return type(batch[0])(
**{
field: generic_concat([getattr(sample, field) for sample in batch])
for field in batch[0]._fields
}
)
else:
return batch
def concat_pad(batch: List[Any]) -> Any:
"""Concat a batch of arbitrary-sized tensors padded with 0s."""
total_bs = sum(b.shape[0] for b in batch)
max_size = [max(b.shape[dim] for b in batch) for dim in range(1, batch[0].ndim)]
concat_tensor = batch[0].new_zeros((total_bs, *max_size))
b_idx = 0
for b in batch:
concat_tensor[(slice(b_idx, b_idx + b.shape[0]), *(slice(0, s) for s in b.shape[1:]))] = b
b_idx += b.shape[0]
    # Entries not filled above remain zero, providing the padding up to max_size
return concat_tensor
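# A runnable worked example of the padding logic above: tensors with batch
# sizes 2 and 1 and widths 3 and 5 are stacked into a (3, 5) tensor, where the
# shorter rows keep their zero padding.
def _example_concat_pad() -> None:
    a = torch.ones(2, 3)
    b = torch.full((1, 5), 2.0)
    out = concat_pad([a, b])
    assert out.shape == (3, 5)
    assert out[0, 3] == 0.0  # rows from `a` are zero-padded in columns 3..4
    assert out[2, 4] == 2.0  # the row from `b` is copied verbatim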
def homogeneous_concat_mix(samples: List[T_batch_in]) -> T_batch:
"""
Mixes a list of batches into a single batch. The default implementation is to concat the
batches if they are all of the same type, otherwise return a list of batches.
Args:
        samples: The samples to mix.
Returns:
The mixed batch.
"""
first_type = type(samples[0])
assert all(first_type is type(sample) for sample in samples)
# All the same type -> concat batches
return generic_concat(samples)
class MixBatchDataset(BaseWrapperDataset[T_batch_in, T_batch], Generic[T_batch_in, T_batch]):
"""
    This dataset wrapper blends multiple iterable datasets together, each given a weight.
The datasets may be infinite. This dataset is always infinite.
Effectively combines :class:`megatron.energon.BlendDataset` and :class:`megatron.energon.BatchDataset`.
"""
def __init__(
self,
*dataset_weights: Tuple[SavableDataset[T_batch_in], float],
batch_size: int,
batch_mix_fn: Callable[
[List[T_batch_in]], Union[T_batch, Generator[T_batch, None, None]]
] = lambda x: x,
worker_config: WorkerConfig,
):
"""Construct a BlendDataset.
Args:
dataset_weights: Each argument should be a tuple of (dataset, weight) with a weight
between 0 and 1. The output samples are sampled from the input datasets with the
                given probabilities. The datasets should have a batch size of 1; otherwise
                whole batches are sampled as single units.
batch_size: The batch size to output.
batch_mix_fn: A function that takes a list of samples from the input datasets and
returns a batch sample. The default implementation returns a list of batches.
For homogeneous datasets, it is recommended to use the
:func:`megatron.energon.homogeneous_concat_mix` which concatenates the batches. May raise
:exc:`megatron.energon.SkipSample` to skip a sample. May also return a generator, which
will be iterated over to produce batches.
worker_config: Configuration for the workers.
"""
super().__init__(
BatchDataset(
BlendDataset(*dataset_weights, worker_config=worker_config),
batch_size=batch_size,
batcher=batch_mix_fn,
worker_config=worker_config,
),
worker_config=worker_config,
)
def reset_state_own(self) -> None:
return
def len_worker(self, worker_idx: int | None = None) -> int:
return self.dataset.len_worker(worker_idx)
def __iter__(self) -> Iterator[T_batch]:
yield from self.dataset
def config(self) -> Dict[str, Any]:
return {
"type": type(self).__qualname__,
"dataset": self.dataset.config(),
}
def __str__(self):
return f"MixBatchDataset(dataset={self.dataset})"
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import contextlib
import inspect
from typing import (
Any,
Callable,
Dict,
Generator,
Generic,
Iterator,
List,
Optional,
Sequence,
TypeVar,
Union,
)
from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError
from megatron.energon.flavors.base_dataset import (
SavableDataset,
add_sample_restore_key,
set_sample_restore_key,
)
from megatron.energon.source_info import SourceInfo
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers._log_exception import log_exception
from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key
from megatron.energon.wrappers.buffer import SavableSampleBuffer
from megatron.energon.wrappers.skip import SkipSample
T_sample = TypeVar("T_sample")
T_encoded_sample = TypeVar("T_encoded_sample")
T_batch_sample = TypeVar("T_batch_sample")
class PackingDataset(
BaseWrapperDataset[T_sample, T_batch_sample],
Generic[T_sample, T_encoded_sample, T_batch_sample],
):
"""This dataset wrapper transforms samples of a dataset into chunks/packs of samples, which are
then combined into a batch."""
buffer_size: int
pre_packer: Callable[[List[T_sample]], List[List[T_sample]]]
sample_encoder: Optional[Callable[[T_sample], T_encoded_sample]]
sample_encoder_stateless: bool
final_packer: Callable[[List[T_encoded_sample]], T_batch_sample]
final_packer_stateless: bool
packer_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]]
error_handler: Callable[[Exception, List[T_sample], Sequence[SourceInfo]], None]
#: The buffer for collecting the samples that shall be packed.
_reading_buffer: SavableSampleBuffer
#: Contains the pre-selected samples to be packed.
#: The full buffer will be passed to the pre_packer.
_pre_packing_buffer: SavableSampleBuffer
#: Lengths of the selected groups of samples to be packed together.
#: The samples are stored sequentially in the pre_packing_buffer because
#: SavableSampleBuffer doesn't support nesting. But to keep the groups
#: separate, we need to store the lengths of the groups here.
_pre_packing_lengths: List[int]
#: Sample index for the pre_packer
_pre_packing_sample_index: SampleIndex
#: Sample index for the sample_encoder
_sample_encoder_sample_index: SampleIndex
#: Sample index for the final_packer
_final_packing_sample_index: SampleIndex
# Local state: Tracking last failures for each component, to raise a fatal error after a certain number of failures.
_last_pre_pack_failures: int = 0
_last_final_pack_failures: int = 0
_last_sample_encoder_failures: int = 0
_savable_fields = (
"_reading_buffer",
"_pre_packing_buffer",
"_pre_packing_lengths",
"_pre_packing_sample_index",
"_sample_encoder_sample_index",
"_final_packing_sample_index",
)
def __init__(
self,
dataset: SavableDataset[T_sample],
buffer_size: int,
pre_packer: Callable[[List[T_sample]], List[List[T_sample]]],
final_packer: Callable[[List[T_encoded_sample]], T_batch_sample],
*,
final_packer_stateless: bool = False,
sample_encoder: Optional[Callable[[T_sample], T_encoded_sample]] = None,
sample_encoder_stateless: bool = False,
packer_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None,
error_handler: Callable[
[Exception, List[T_sample], Sequence[SourceInfo]], None
] = log_exception,
pre_packer_failure_tolerance: int = 100,
final_packer_failure_tolerance: int = 100,
sample_encoder_failure_tolerance: int = 100,
worker_config: WorkerConfig,
):
"""Construct a PackingDataset which is used for sequence packing.
Using a pre_packer and final_packer, it buffers the incoming samples, groups
them together based on the logic provided by the pre_packer, and then (using
        the final_packer) combines each group into a single packed sample, also called
a "pack" or a "packed sequence".
Args:
dataset: The input dataset to wrap
buffer_size: The desired size of the input buffer for pre packing. Last buffer of a dataset may be smaller.
pre_packer: Function which selects samples from the buffer to be packed together.
May raise :exc:`megatron.energon.SkipSample` to skip a buffer.
final_packer: Function which combines the selected samples into a single sample.
final_packer_stateless: If True, the final_packer is stateless, thus samples can be
stored/restored.
sample_encoder: Function which encodes the samples.
sample_encoder_stateless: If True, the sample_encoder is stateless, thus samples can be
stored/restored.
packer_config: Configuration for the (pre|final)_packer functions. If callable, it should return the
configuration. Defaults to None.
error_handler: Function which handles exceptions raised by the batcher. The default
implementation logs the exception.
pre_packer_failure_tolerance: Maximum number of pre-packer failures before raising an error. Set to 0 to disable.
final_packer_failure_tolerance: Maximum number of final-packer failures before raising an error. Set to 0 to disable.
sample_encoder_failure_tolerance: Maximum number of sample-encoder failures before raising an error. Set to 0 to disable.
worker_config: Configuration for the workers.
"""
super().__init__(dataset, worker_config=worker_config)
assert buffer_size > 0, "Packing buffer size must be greater than 0."
self.buffer_size = buffer_size
self.pre_packer = pre_packer
self.final_packer = final_packer
self.final_packer_stateless = final_packer_stateless
self.sample_encoder = sample_encoder
self.sample_encoder_stateless = True if sample_encoder is None else sample_encoder_stateless
self.packer_config = packer_config
self.error_handler = error_handler
self.pre_packer_failure_tolerance = pre_packer_failure_tolerance
self.final_packer_failure_tolerance = final_packer_failure_tolerance
self.sample_encoder_failure_tolerance = sample_encoder_failure_tolerance
self.reset_state_own()
def reset_state_own(self) -> None:
self._reading_buffer = SavableSampleBuffer(self.dataset, worker_config=self.worker_config)
self._pre_packing_buffer = SavableSampleBuffer(
self.dataset, worker_config=self.worker_config
)
self._pre_packing_lengths = []
self._pre_packing_sample_index = SampleIndex(self.worker_config, src=self)
self._final_packing_sample_index = SampleIndex(self.worker_config, src=self)
self._sample_encoder_sample_index = SampleIndex(self.worker_config, src=self)
def len_worker(self, worker_idx: int | None = None) -> int:
# The real length is unknown, since it depends on the packing function.
# We approximate it by the length of the source dataset.
return self.dataset.len_worker(worker_idx)
def _fill_reading_buffer(self, source_iter: Iterator, log_progress: bool = False) -> bool:
"""
Fill the reading buffer with samples from the dataset source iterator.
Args:
source_iter: Iterator of samples from the dataset.
log_progress: If True, log the progress of the filling.
Returns:
True if samples are successfully read into the buffer, False if no more data.
"""
if log_progress:
import tqdm
pbar_ctx = pbar = tqdm.tqdm(total=self.buffer_size, desc="Filling reading buffer")
else:
pbar_ctx = contextlib.nullcontext()
pbar = None
with pbar_ctx:
while (
self._reading_buffer.len_worker() + self._pre_packing_buffer.len_worker()
< self.buffer_size
):
try:
sample = next(source_iter)
self._reading_buffer.append(sample)
if pbar is not None:
pbar.update(1)
except StopIteration:
return False
return True
def __iter__(self) -> Iterator[T_batch_sample]:
pre_packing_lengths = self._pre_packing_lengths
# The source dataset
src_iter = iter(self.dataset)
self._pre_packing_buffer.worker_start()
self._reading_buffer.worker_start()
is_initial_pack = True
def encode_pack_samples(pack: List[T_sample]) -> List[T_encoded_sample]:
"""Encode the samples in the pack using the sample encoder."""
# Apply the sample encoder to the pack
if self.sample_encoder is None:
return pack
encoded_pack = []
for sample in pack:
try:
with self._sample_encoder_sample_index.ctx() as encode_idx:
encoded_sample = self.sample_encoder(sample)
assert not isinstance(encoded_sample, Generator), "Generator not supported"
encoded_pack.append(
add_sample_restore_key(
encoded_sample,
encode_idx,
src=self,
)
)
self._last_sample_encoder_failures = 0
except SkipSample:
pass
except SYSTEM_EXCEPTIONS:
raise FatalSampleError.from_sample(pack)
except Exception as e:
self.error_handler(e, [sample])
self._last_sample_encoder_failures += 1
if (
self.sample_encoder_failure_tolerance > 0
and self._last_sample_encoder_failures
>= self.sample_encoder_failure_tolerance
):
raise FatalSampleError.from_sample(
pack,
f"Sample encoder {self.sample_encoder} failed {self._last_sample_encoder_failures} times. Likely your code or dataset are broken.",
)
return encoded_pack
def next_pre_pack():
"""Take the samples from the reading buffer and select groups of samples to be packed
together."""
assert self._pre_packing_buffer.len_worker() == 0
if self._reading_buffer.len_worker() > 0:
# Take all samples from the reading buffer and pre_pack them
samples = self._reading_buffer.buffer.copy()
# Clear buffer and pre_packing_lengths
self._reading_buffer.clear()
pre_packing_lengths.clear()
# Now pre pack the samples
try:
with self._pre_packing_sample_index.ctx():
pre_packs = self.pre_packer(samples)
self._last_pre_pack_failures = 0
except SkipSample:
pre_packs = []
except SYSTEM_EXCEPTIONS:
raise FatalSampleError.from_sample(samples)
except Exception as e:
self.error_handler(e, samples)
pre_packs = []
self._last_pre_pack_failures += 1
if (
self.pre_packer_failure_tolerance > 0
and self._last_pre_pack_failures >= self.pre_packer_failure_tolerance
):
raise FatalSampleError.from_sample(
samples,
f"Pre packer {self.pre_packer} failed {self._last_pre_pack_failures} times. Likely your code or dataset are broken.",
)
# Put the pre-packed samples into the pre_packing_buffer
# They will be flattened here to avoid nested buffers
# But the lengths of the groups are stored in pre_packing_lengths
# so that the groups can be separated later
for pre_pack in pre_packs:
if len(pre_pack) > 0:
self._pre_packing_buffer.extend(pre_pack)
pre_packing_lengths.append(len(pre_pack))
def next_final_pack() -> Generator[T_batch_sample, None, None]:
"""Yield the next packs from the buffer. The final packer is called on the fly."""
pack = self._pre_packing_buffer.buffer[: pre_packing_lengths[0]].copy()
if len(pack) == 0:
return
pack = encode_pack_samples(pack)
del self._pre_packing_buffer[: pre_packing_lengths[0]]
del pre_packing_lengths[0]
try:
pack_restore_keys = tuple(get_sample_restore_key(sample) for sample in pack)
with self._final_packing_sample_index.ctx() as pack_idx:
final_packed_sample = self.final_packer(pack)
if isinstance(final_packed_sample, Generator):
assert inspect.isgeneratorfunction(self.final_packer), (
f"Generator in {self.final_packer} but not marked as such."
)
for pack_sub_idx, (pack_idx, inner_batch_sample) in enumerate(
self._final_packing_sample_index.iter_ctx(final_packed_sample, pack_idx)
):
self._last_final_pack_failures = 0
yield set_sample_restore_key(
inner_batch_sample,
pack_idx,
pack_sub_idx,
*pack_restore_keys,
src=self,
)
else:
self._last_final_pack_failures = 0
yield set_sample_restore_key(
final_packed_sample,
pack_idx,
*pack_restore_keys,
src=self,
)
except SkipSample:
pass
except SYSTEM_EXCEPTIONS:
raise FatalSampleError.from_sample(pack)
except Exception as e:
self.error_handler(e, pack)
self._last_final_pack_failures += 1
if (
self.final_packer_failure_tolerance > 0
and self._last_final_pack_failures >= self.final_packer_failure_tolerance
):
raise FatalSampleError.from_sample(
pack,
f"Final packer {self.final_packer} failed {self._last_final_pack_failures} times. Likely your code or dataset are broken.",
)
# Main loop:
pre_pack_round = 0
while True:
if (
self.pre_packer_failure_tolerance > 0
and pre_pack_round > self.pre_packer_failure_tolerance
):
raise RuntimeError(
f"Pre packer {self.pre_packer} did not yield any packs after {pre_pack_round} rounds. Likely your code or dataset are broken."
)
# Fill a portion of the buffer
if not self._fill_reading_buffer(src_iter, log_progress=is_initial_pack):
# Break out of the main loop when the source is exhausted.
break
is_initial_pack = False
# Create new pre packs if necessary
if len(pre_packing_lengths) == 0:
assert self._pre_packing_buffer.len_worker() == 0
assert self._reading_buffer.len_worker() == self.buffer_size
next_pre_pack()
if len(pre_packing_lengths) == 0:
# Retry packing, nothing was returned.
pre_pack_round += 1
continue
if len(pre_packing_lengths) > 0:
pre_pack_round = 0
yield from next_final_pack()
# Yield the remaining packs, flushing the collecting buffer
while len(pre_packing_lengths) > 0:
yield from next_final_pack()
# If there are still samples in the partial reading buffer, pre-pack them and yield the
# resulting (partial) packs
if self._reading_buffer.len_worker() > 0:
next_pre_pack()
# Yield the remaining packs, flushing the collecting buffer
while len(pre_packing_lengths) > 0:
yield from next_final_pack()
def can_restore_sample(self) -> bool:
# Cannot really verify if the returned elements contain a __restore_key__.
# If the user wants to use this, well...
return (
super().can_restore_sample()
and self.final_packer_stateless
and self.sample_encoder_stateless
)
def assert_can_restore(self):
assert self.final_packer_stateless and self.sample_encoder_stateless, (
f"Final packer {self.final_packer} and sample encoder {self.sample_encoder} must be stateless to restore samples."
)
super().assert_can_restore()
def restore_sample(self, restore_key: Any) -> T_sample:
# We need to store multiple indices to restore a batch.
self.assert_can_restore()
if inspect.isgeneratorfunction(self.final_packer):
id, pack_idx, pack_sub_idx, *pack_restore_keys = restore_key
assert id == type(self).__name__
else:
id, pack_idx, *pack_restore_keys = restore_key
assert id == type(self).__name__
pack = []
for inner_idx in pack_restore_keys:
if self.sample_encoder is not None:
id, sample_idx, *inner_idx = inner_idx
assert id == type(self).__name__
assert isinstance(sample_idx, int)
sample = self.dataset.restore_sample(inner_idx)
try:
if self.sample_encoder is not None:
with self._sample_encoder_sample_index.ctx(sample_idx):
sample = self.sample_encoder(sample)
assert not isinstance(sample, Generator), "Generator not supported"
self._last_sample_encoder_failures = 0
sample = add_sample_restore_key(sample, sample_idx, src=self)
except SkipSample:
raise FatalSampleError.from_sample(
sample,
f"PackingDataset sample encoder {self.sample_encoder} skipped while trying to restore a sample.",
)
except SYSTEM_EXCEPTIONS:
raise FatalSampleError.from_sample(sample)
except Exception as e:
self.error_handler(e, sample)
self._last_sample_encoder_failures += 1
if (
self.sample_encoder_failure_tolerance > 0
and self._last_sample_encoder_failures >= self.sample_encoder_failure_tolerance
):
raise FatalSampleError.from_sample(
sample,
f"PackingDataset sample encoder {self.sample_encoder} failed {self._last_sample_encoder_failures} times. Likely your code or dataset are broken.",
)
pack.append(sample)
try:
with self._final_packing_sample_index.ctx(pack_idx):
final_pack = self.final_packer(pack)
if isinstance(final_pack, Generator):
assert inspect.isgeneratorfunction(self.final_packer), (
f"Generator in {self.final_packer} but not marked as such."
)
for cur_batch_sub_idx, (pack_idx, inner_batch_sample) in enumerate(
self._final_packing_sample_index.iter_ctx(final_pack, pack_idx)
):
self._last_final_pack_failures = 0
if cur_batch_sub_idx == pack_sub_idx:
return set_sample_restore_key(
inner_batch_sample,
pack_idx,
pack_sub_idx,
*pack_restore_keys,
src=self,
)
assert False, f"Pack sub-index {pack_sub_idx} not found in pack"
else:
self._last_final_pack_failures = 0
return set_sample_restore_key(final_pack, pack_idx, *pack_restore_keys, src=self)
except GeneratorExit:
raise FatalSampleError.from_sample(
pack,
f"PackingDataset {self.final_packer} generator exited while trying to restore a pack.",
)
except SkipSample:
raise FatalSampleError.from_sample(
pack, f"PackingDataset {self.final_packer} skipped while trying to restore a pack."
)
except SYSTEM_EXCEPTIONS:
raise FatalSampleError.from_sample(pack)
except Exception as e:
self.error_handler(e, pack)
self._last_final_pack_failures += 1
if (
self.final_packer_failure_tolerance > 0
and self._last_final_pack_failures >= self.final_packer_failure_tolerance
):
raise FatalSampleError.from_sample(
pack,
f"PackingDataset {self.final_packer} failed {self._last_final_pack_failures} times. Likely your code or dataset are broken.",
)
def config(self) -> Dict[str, Any]:
return {
"type": type(self).__qualname__,
"buffer_size": self.buffer_size,
"pre_packer": self._function_config(self.pre_packer),
"final_packer": self._function_config(self.final_packer),
"final_packer_stateless": self.final_packer_stateless,
**(
{
"packer_config": (
self.packer_config() if callable(self.packer_config) else self.packer_config
)
}
if self.packer_config
else {}
),
"error_handler": self._function_config(self.error_handler),
"worker_config": self.worker_config.config(),
"dataset": self.dataset.config(),
}
def __str__(self):
return f"PackingDataset(buffer_size={self.buffer_size}, pre_packer={self.pre_packer}, final_packer={self.final_packer}, dataset={self.dataset})"
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import math
from typing import Any, Dict, Generic, Iterator, Optional, TypeVar, Union
from megatron.energon.flavors.base_dataset import SavableDataset
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset
T_sample = TypeVar("T_sample")
class RepeatDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]):
"""This dataset repeats the inner dataset indefinitely or a specific number of repeats."""
repeats: Optional[Union[int, float]]
_repetition: int
_index: int
_savable_fields = ("_repetition", "_index")
def __init__(
self,
dataset: SavableDataset[T_sample],
*,
repeats: Optional[Union[int, float]] = None,
restart: bool = True,
worker_config: WorkerConfig,
):
"""Construct a RepeatDataset.
Args:
dataset: The input dataset to repeat.
repeats: Number of repeats, `None` for indefinitely repeating.
            restart: If true, reset the repetition counter once the given repeats are
                exhausted, so a subsequent iteration of the dataset starts over (each
                iteration still stops after `repeats`). If false, subsequent iterations
                yield nothing.
worker_config: Configuration for the workers.
"""
super().__init__(dataset, worker_config=worker_config)
self.repeats = repeats
self.restart = restart
self.reset_state_own()
def reset_state_own(self) -> None:
self._repetition = 0
self._index = 0
def len_worker(self, worker_idx: int | None = None) -> int:
if self.repeats is None:
return self.dataset.len_worker(worker_idx)
return int(self.dataset.len_worker(worker_idx) * self.repeats)
def __iter__(self) -> Iterator[T_sample]:
assert self.repeats is not None or self.dataset.worker_has_samples(), (
"Cannot repeat empty dataset indefinitely"
)
# TODO: There is a small difference in the total sum of samples (across ranks) * repeats
# and the sum(len_worker() for all workers across ranks).
# This is due to the fact that the number of samples is not exactly divisible by the number of workers.
# The dataset length is the size for the current rank. Need to divide by the number of workers
ds_len = self.dataset.len_worker()
while self.repeats is None or self._repetition < self.repeats:
if self.repeats is not None and self._repetition == math.floor(self.repeats):
# Last iteration, adjust the number of samples
fraction = self.repeats - math.floor(self.repeats)
stop_after = math.floor(ds_len * fraction)
if self._index >= stop_after:
# We restored an index and it is already past the stop_after
break
else:
stop_after = None
for sample in self.dataset:
self._index += 1
yield sample
if stop_after is not None and self._index >= stop_after:
break
if self.worker_config.should_log(level=2):
self.worker_config.worker_log(
{
"t": "RepeatDataset.repeat",
"r": self.worker_config.rank,
"w": self.worker_config.rank_worker_id(),
"offset": self._repetition,
"repeats": self.repeats,
}
)
self._repetition += 1
self._index = 0
if self.restart:
self._repetition = 0
else:
# No more repeats
self._repetition = math.ceil(self.repeats)
def config(self) -> Dict[str, Any]:
return {
"type": type(self).__qualname__,
"dataset": self.dataset.config(),
"repeats": self.repeats,
"worker_config": self.worker_config.config(),
}
def __str__(self):
return f"RepeatDataset(repeats={self.repeats}, dataset={self.dataset})"
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Any, Dict, Generic, Iterator, Tuple, TypeVar, Union
from megatron.energon.flavors.base_dataset import SavableDataset
from megatron.energon.rng import WorkerRng
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset
from megatron.energon.wrappers.buffer import SavableSampleBuffer
T_sample = TypeVar("T_sample")
class ShuffleBufferDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]):
"""Shuffle buffer for the dataset."""
size: int
_worker_rng: WorkerRng
_active_buffer: SavableSampleBuffer[T_sample]
_savable_fields = ("_active_buffer", "_worker_rng")
def __init__(
self,
dataset: SavableDataset[T_sample],
size: int,
*,
worker_config: WorkerConfig,
):
"""Create a shuffle buffer for the dataset."""
super().__init__(dataset, worker_config=worker_config)
self.size = size
self.reset_state_own()
def reset_state_own(self) -> None:
self._worker_rng = WorkerRng(self.worker_config)
self._active_buffer = SavableSampleBuffer(self.dataset, worker_config=self.worker_config)
def len_worker(self, worker_idx: int | None = None) -> int:
return self.dataset.len_worker(worker_idx)
def __iter__(self) -> Iterator[T_sample]:
self._active_buffer.worker_start()
it = iter(self._active_buffer.append_iter())
while True:
if self._active_buffer.len_worker() >= self.size:
pop_idx = self._worker_rng.randbelow(self._active_buffer.len_worker())
yield self._active_buffer.pop(pop_idx)
else:
try:
next(it)
except StopIteration:
break
while self._active_buffer.len_worker() > 0:
pop_idx = self._worker_rng.randbelow(self._active_buffer.len_worker())
yield self._active_buffer.pop(pop_idx)
def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample:
return self._active_buffer.restore_sample(restore_key)
def config(self) -> Dict[str, Any]:
return {
"type": type(self).__qualname__,
"dataset": self.dataset.config(),
"size": self.size,
"worker_config": self.worker_config.config(),
}
def __str__(self):
return f"ShuffleBufferDataset(size={self.size}, dataset={self.dataset})"
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
class SkipSample(Exception):
"""Exception to raise in the map_fn to skip a sample."""
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import warnings
from typing import Any, Dict, Generic, Iterator, Optional, TypeVar
from megatron.energon.flavors.base_dataset import SavableDataset
from megatron.energon.watchdog import Watchdog
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset
T_sample = TypeVar("T_sample")
class WatchdogDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]):
"""This dataset wraps another dataset and watches the time it takes to yield samples."""
def __init__(
self,
dataset: SavableDataset[T_sample],
worker_config: WorkerConfig,
timeout_seconds: Optional[float] = 60,
initial_timeout_seconds: Optional[float] = None,
fail_on_timeout: bool = False,
):
"""Construct the watchdog dataset, which wraps another dataset and watches
the time it takes to yield samples from the wrapped dataset.
Args:
dataset: The input dataset to wrap
worker_config: The worker configuration
timeout_seconds: The timeout in seconds. If None, the watchdog is disabled.
initial_timeout_seconds: The initial timeout in seconds. If None, the timeout is the same as timeout_seconds.
fail_on_timeout: If True, stops the whole process upon timeout, after printing a stack trace.
"""
super().__init__(dataset, worker_config=worker_config)
self.timeout_seconds = timeout_seconds
self.initial_timeout_seconds = initial_timeout_seconds
self.fail_on_timeout = fail_on_timeout
def reset_state_own(self) -> None:
pass
def len_worker(self, worker_idx: int | None = None) -> int:
return self.dataset.len_worker(worker_idx)
def _watchdog_trigger(self) -> None:
if self.fail_on_timeout:
# Raising an exception here will kill the whole process
raise TimeoutError(
f"Watchdog triggered. Sample processing took longer than {self.timeout_seconds} seconds."
)
else:
warnings.warn(
f"Watchdog triggered. Sample processing took longer than {self.timeout_seconds} seconds.",
RuntimeWarning,
)
def __iter__(self) -> Iterator[T_sample]:
if self.timeout_seconds is None:
yield from self.dataset
else:
watchdog = Watchdog(
timeout=self.timeout_seconds,
initial_timeout=self.initial_timeout_seconds,
callback=self._watchdog_trigger,
enabled=False,
)
yield from watchdog.watch_iter(self.dataset)
def config(self) -> Dict[str, Any]:
# Watchdog is transparent, it won't change the samples
return self.dataset.config()
def __str__(self):
return f"WatchdogDataset(dataset={self.dataset})"
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import os
from contextlib import contextmanager
from pathlib import Path
from typing import Generator
from multistorageclient.rclone import read_rclone_config
from tests.s3_emulator.state import S3State
from tests.s3_emulator.test import s3_emulator
@contextmanager
def setup_s3_emulator(
*,
port: int = 0,
access_key: str = "test",
secret_key: str = "test",
root_dir: str | None = None,
region: str = "us-east-1",
profile_name: str = "s3test",
) -> Generator[S3State, None, None]:
"""Set up S3 emulator and write necessary config files.
Args:
port: Port to bind the server to. Use 0 to let the OS choose a free port.
access_key: Access key for authentication
secret_key: Secret key for authentication
root_dir: Optional directory to persist S3 data
region: Region for authentication
        profile_name: Name of the rclone profile. Must be different in each test, so that
            MSC does not pick up a stale cached rclone config.
Returns:
The S3 emulator state. Can be used to quickly upload files to the emulator.
"""
try:
with s3_emulator(
host="127.0.0.1",
port=port,
credentials={access_key: secret_key},
root_dir=root_dir,
region=region,
) as emu:
# Create config directory
config_dir = Path("/tmp/XDG_CONFIG_HOME/.config/rclone")
config_dir.mkdir(parents=True, exist_ok=True)
# Write rclone config
config_path = config_dir / "rclone.conf"
with config_path.open("w") as f:
f.write(
"\n".join(
[
f"[{profile_name}]",
"type = s3",
"env_auth = false",
f"access_key_id = {access_key}",
f"secret_access_key = {secret_key}",
f"region = {region}",
f"endpoint = http://127.0.0.1:{emu.port}",
]
)
)
# Set environment variables
os.environ["XDG_CONFIG_HOME"] = "/tmp/XDG_CONFIG_HOME/.config"
os.environ["HOME"] = "/tmp/XDG_CONFIG_HOME"
# Hack to clear the cache of the rclone config for msc to get the "s3" profile
read_rclone_config.cache_clear()
yield emu.state
read_rclone_config.cache_clear()
except Exception as e:
print("ERROR in s3_emulator", flush=True)
print("Full traceback:", flush=True)
import traceback
traceback.print_exc()
raise e
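# A hedged usage sketch: run the emulator and seed it through the yielded state
# object. `create_bucket`/`put_object` mirror the calls the request handler
# makes on `server.state`; their exact S3State signatures are assumed here.
def _example_setup_s3_emulator():
    with setup_s3_emulator(profile_name="s3test-example") as state:
        state.create_bucket("demo")
        state.put_object("demo", "hello.txt", b"hello world")
        # ... exercise code under test against the "s3test-example" MSC profile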
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from .server import S3EmulatorServer
from .state import S3State
__all__ = ["S3EmulatorServer", "S3State"]
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import hmac
import re
import urllib.parse as _up
from hashlib import sha256
from typing import Dict, Mapping, MutableMapping
__all__ = ["S3Auth", "InvalidSignature"]
_SIGNED_HEADERS_RE = re.compile(r"SignedHeaders=([^,]+)")
_CREDENTIAL_RE = re.compile(r"Credential=([^,]+)")
_SIGNATURE_RE = re.compile(r"Signature=([0-9a-fA-F]+)")
class InvalidSignature(Exception):
"""Raised when the supplied signature does not match."""
class S3Auth:
"""Very small subset implementation of AWS Signature V4 verification.
Only what is mandatory for the emulator to work for most typical SDK
operations is implemented. Notably, chunked uploads and presigned URLs are
not supported.
"""
def __init__(self, credentials: Mapping[str, str], region: str = "us-east-1") -> None:
"""Initialize the S3 authentication handler.
Args:
credentials: Mapping of access_key to secret_key accepted by the server.
region: AWS region assumed when verifying the signing key.
"""
self._creds: Dict[str, str] = dict(credentials)
self._region = region
def verify(
self,
method: str,
canonical_uri: str,
canonical_querystring: str,
headers: Mapping[str, str] | MutableMapping[str, str],
payload: bytes,
) -> None:
"""Validate the Authorization header for the given request.
Args:
method: HTTP method of the request.
canonical_uri: Canonical URI path.
canonical_querystring: Canonical query string.
headers: Request headers.
payload: Request body.
"""
auth_header = headers.get("authorization") or headers.get("Authorization")
if auth_header is None:
raise InvalidSignature("Missing Authorization header")
signed_headers = _first_group(_SIGNED_HEADERS_RE, auth_header)
credential_str = _first_group(_CREDENTIAL_RE, auth_header)
signature = _first_group(_SIGNATURE_RE, auth_header)
if not (signed_headers and credential_str and signature):
raise InvalidSignature("Malformed Authorization header")
access_key, date_str, region, service, terminator = credential_str.split("/")
if service != "s3" or terminator != "aws4_request":
raise InvalidSignature("Invalid credential scope")
if region != self._region:
print(f"Signature region {region} does not match server region {self._region}")
secret_key = self._creds.get(access_key)
if secret_key is None:
raise InvalidSignature("Unknown access key")
# Canonical URI & query string (encode & normalise)
canonical_uri = _canonical_uri(canonical_uri)
canonical_querystring = _canonical_querystring(canonical_querystring)
# Construct canonical request ------------------------------------------------
# 1. Canonical headers
canonical_headers = ""
for hdr in signed_headers.split(";"):
hdr_lower = hdr.lower()
value = headers.get(hdr) or headers.get(hdr_lower)
if value is None:
raise InvalidSignature(f"Signed header '{hdr}' missing from request")
canonical_headers += f"{hdr_lower}:{_normalize_whitespace(str(value))}\n"
# 2. Hashed payload
payload_hash = sha256(payload).hexdigest()
# 3. Canonical request string
canonical_request = "\n".join(
[
method,
canonical_uri,
canonical_querystring,
canonical_headers,
signed_headers,
payload_hash,
]
)
hashed_canonical_request = sha256(canonical_request.encode()).hexdigest()
# String to sign
amz_date = headers.get("x-amz-date") or headers.get("X-Amz-Date")
if amz_date is None:
raise ValueError("Missing x-amz-date header")
string_to_sign = "\n".join(
[
"AWS4-HMAC-SHA256",
amz_date,
"/".join([date_str, region, "s3", "aws4_request"]),
hashed_canonical_request,
]
)
# Calculate signing key and signature
date_key = _sign(("AWS4" + secret_key).encode(), date_str)
region_key = _sign(date_key, region)
service_key = _sign(region_key, "s3")
signing_key = _sign(service_key, "aws4_request")
calc_signature = hmac.new(signing_key, string_to_sign.encode(), sha256).hexdigest()
if not hmac.compare_digest(calc_signature, signature):
print(f"Sig mismatch: expected={signature} got={calc_signature}")
raise InvalidSignature("Signature mismatch")
def _first_group(regex: re.Pattern[str], string: str) -> str | None:
"""Extract the first capture group from a regex match.
Args:
regex: The regex pattern to match.
string: The string to search in.
Returns:
The first capture group if found, None otherwise.
"""
match = regex.search(string)
return match.group(1) if match else None
def _sign(key: bytes, msg: str) -> bytes:
"""Sign a message with a key using HMAC-SHA256.
Args:
key: The signing key.
msg: The message to sign.
Returns:
The HMAC-SHA256 signature.
"""
return hmac.new(key, msg.encode(), sha256).digest()
def _normalize_whitespace(value: str) -> str:
"""Collapse consecutive whitespace.
Args:
value: The string to normalize.
Returns:
The normalized string with collapsed whitespace.
"""
return " ".join(value.strip().split())
def _percent_encode(value: str) -> str:
"""Percent encode a string using AWS safe characters.
Args:
value: The string to encode.
Returns:
The percent-encoded string.
"""
return _up.quote(value, safe="-_.~")
def _canonical_uri(uri: str) -> str:
"""Return URI-encoded path as required by SigV4.
Each segment between / must be percent-encoded with the AWS safe list
-_.~. Duplicate slashes are preserved (AWS behaviour).
Args:
uri: The URI path to canonicalize.
Returns:
The canonical URI path.
"""
if uri == "":
return "/"
encoded_parts = [_percent_encode(_up.unquote(part)) for part in uri.split("/")]
prefix = "" if uri.startswith("/") else "/"
return prefix + "/".join(encoded_parts)
def _canonical_querystring(raw_qs: str) -> str:
"""Canonicalize a query string according to AWS SigV4 rules.
Args:
raw_qs: The raw query string to canonicalize.
Returns:
The canonical query string.
"""
if raw_qs == "":
return ""
pairs = _up.parse_qsl(raw_qs, keep_blank_values=True)
encoded_pairs = [(_percent_encode(k), _percent_encode(v)) for k, v in pairs]
encoded_pairs.sort()
return "&".join(f"{k}={v}" for k, v in encoded_pairs)
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import urllib.parse as _up
from datetime import datetime, timezone
from email.utils import formatdate
from hashlib import md5
from http import HTTPStatus
from http.server import BaseHTTPRequestHandler
from typing import Protocol
from .auth import InvalidSignature, S3Auth
from .state import S3State
__all__ = ["S3RequestHandler"]
class S3RequestHandler(BaseHTTPRequestHandler):
"""HTTP request handler implementing a minimal S3-compatible API.
This handler processes HTTP requests and maps them to S3 operations.
It supports basic S3 operations like bucket and object management,
including multipart uploads.
"""
server: "S3ServerProtocol" # type: ignore[assignment]
def log_message(self, fmt: str, *args):
"""Log a message to stdout.
Args:
fmt: Format string for the message.
*args: Arguments to format the message with.
"""
print(f"{self.client_address[0]} - - {fmt % args}")
def do_PUT(self):
"""Handle PUT requests for object creation and bucket creation."""
self._handle_write()
def do_GET(self):
"""Handle GET requests for object retrieval and bucket listing."""
self._handle_read(listing=False)
def do_HEAD(self):
"""Handle HEAD requests for object metadata."""
self._handle_read(listing=False, only_headers=True)
def do_DELETE(self):
"""Handle DELETE requests for object and bucket deletion."""
self._handle_delete()
def do_POST(self):
"""Handle POST requests for multipart upload operations."""
self._handle_post()
def _read_body(self) -> bytes:
"""Read and return the request body.
Returns:
The request body as bytes.
"""
length = int(self.headers.get("Content-Length", 0))
if length == 0:
return b""
data = self.rfile.read(length)
return data
def _split_path(self) -> tuple[str, str, _up.ParseResult]:
"""Split the request path into bucket and key components.
Returns:
A tuple of (bucket, key, parsed_url).
"""
parsed = _up.urlparse(self.path)
parts = [p for p in parsed.path.split("/") if p]
bucket = parts[0] if parts else ""
key = "/".join(parts[1:]) if len(parts) > 1 else ""
return bucket, key, parsed
def _auth(self, payload: bytes, parsed: _up.ParseResult) -> bool:
"""Verify the request signature.
Args:
payload: The request body.
parsed: The parsed URL.
Returns:
True if authentication succeeds, False otherwise.
"""
try:
self.server.auth.verify(
method=self.command,
canonical_uri=parsed.path or "/",
canonical_querystring=parsed.query,
headers=self.headers,
payload=payload,
)
except InvalidSignature as err:
self._send_error(HTTPStatus.FORBIDDEN, str(err))
return False
except ValueError as err:
self._send_error(HTTPStatus.BAD_REQUEST, str(err))
return False
return True
def _handle_write(self):
"""Handle PUT requests for object creation and bucket creation."""
bucket, key, parsed = self._split_path()
body = self._read_body()
if not self._auth(body, parsed):
return
qs = _up.parse_qs(parsed.query, keep_blank_values=True)
# Multipart: upload part
if "uploadId" in qs and "partNumber" in qs:
upload_id = qs["uploadId"][0]
try:
part_no = int(qs["partNumber"][0])
except ValueError:
self._send_error(HTTPStatus.BAD_REQUEST, "Invalid partNumber")
return
try:
self.server.state.upload_part(upload_id, part_no, body)
except KeyError:
self._send_error(HTTPStatus.NOT_FOUND, "Upload not found")
return
self._send_status(HTTPStatus.OK, extra_headers={"ETag": _etag(body)})
return
if not bucket:
self._send_error(HTTPStatus.BAD_REQUEST, "Bucket must be specified")
return
if key == "": # Bucket create
self.server.state.create_bucket(bucket)
self._send_status(HTTPStatus.OK)
return
# Put object
self.server.state.put_object(bucket, key, body)
self._send_status(
HTTPStatus.OK,
extra_headers={"ETag": _etag(body)},
)
def _handle_read(self, listing: bool, only_headers: bool = False):
"""Handle GET/HEAD requests for object retrieval and bucket listing.
Args:
listing: Whether this is a bucket listing request.
only_headers: Whether to return only headers (HEAD request).
"""
bucket, key, parsed = self._split_path()
body = b"" # GET/HEAD normally payload considered in signature (hash of empty string)
if not self._auth(body, parsed):
return
if not bucket:
self._send_error(HTTPStatus.BAD_REQUEST, "Bucket must be specified")
return
if key == "": # List bucket contents
if not listing:
# We treat listing with GET only
try:
objects = self.server.state.list_objects(bucket)
except KeyError:
self._send_error(HTTPStatus.NOT_FOUND, "Bucket not found")
return
xml_body = self._render_bucket_list(bucket, objects)
self._send_bytes(xml_body, content_type="application/xml")
else:
self._send_error(HTTPStatus.NOT_IMPLEMENTED, "Listing not implemented")
return
try:
data = self.server.state.get_object(bucket, key)
except FileNotFoundError:
self._send_error(HTTPStatus.NOT_FOUND, "Not found")
return
range_header = self.headers.get("Range")
if range_header and range_header.startswith("bytes="):
rng = range_header.split("=", 1)[1]
if "-" not in rng:
self._send_error(HTTPStatus.REQUESTED_RANGE_NOT_SATISFIABLE, "Invalid Range")
return
start_str, end_str = rng.split("-", 1)
try:
start = int(start_str) if start_str else 0
end = int(end_str) if end_str else len(data) - 1
except ValueError:
self._send_error(HTTPStatus.REQUESTED_RANGE_NOT_SATISFIABLE, "Invalid Range")
return
if start > end or start >= len(data):
self._send_error(HTTPStatus.REQUESTED_RANGE_NOT_SATISFIABLE, "Invalid Range")
return
end = min(end, len(data) - 1)
slice_data = data[start : end + 1]
headers = {
"Content-Range": f"bytes {start}-{end}/{len(data)}",
"Accept-Ranges": "bytes",
"Content-Length": str(len(slice_data)),
"ETag": _etag(data),
}
if only_headers:
headers.setdefault("Content-Type", "application/octet-stream")
headers.setdefault("Last-Modified", formatdate(usegmt=True))
self._send_status(HTTPStatus.PARTIAL_CONTENT, extra_headers=headers)
else:
self._send_bytes(
slice_data,
status=HTTPStatus.PARTIAL_CONTENT,
content_type="application/octet-stream",
extra_headers=headers,
)
else:
if only_headers:
self._send_status(
HTTPStatus.OK,
extra_headers={
"Content-Length": str(len(data)),
"Accept-Ranges": "bytes",
"Content-Type": "application/octet-stream",
"Last-Modified": formatdate(usegmt=True),
"ETag": _etag(data),
},
)
else:
self._send_bytes(
data,
content_type="application/octet-stream",
extra_headers={"Accept-Ranges": "bytes"},
)
def _handle_delete(self):
"""Handle DELETE requests for object and bucket deletion."""
bucket, key, parsed = self._split_path()
body = b"" # empty
if not self._auth(body, parsed):
return
if not bucket:
self._send_error(HTTPStatus.BAD_REQUEST, "Bucket must be specified")
return
if key == "":
try:
self.server.state.delete_bucket(bucket)
except (KeyError, RuntimeError) as err:
self._send_error(HTTPStatus.BAD_REQUEST, str(err))
return
self._send_status(HTTPStatus.NO_CONTENT)
return
try:
self.server.state.delete_object(bucket, key)
except FileNotFoundError:
self._send_error(HTTPStatus.NOT_FOUND, "Not found")
return
self._send_status(HTTPStatus.NO_CONTENT)
def _handle_post(self):
"""Handle POST requests for multipart upload operations."""
bucket, key, parsed = self._split_path()
body = self._read_body()
if not self._auth(body, parsed):
return
qs = _up.parse_qs(parsed.query, keep_blank_values=True)
# Initiate multipart: POST ?uploads
if "uploads" in qs or parsed.query == "uploads":
upload_id = self.server.state.initiate_multipart(bucket, key)
xml = (
'<?xml version="1.0" encoding="UTF-8"?>'
"<InitiateMultipartUploadResult>"
f"<Bucket>{_escape_xml(bucket)}</Bucket>"
f"<Key>{_escape_xml(key)}</Key>"
f"<UploadId>{upload_id}</UploadId>"
"</InitiateMultipartUploadResult>"
).encode()
self._send_bytes(xml, status=HTTPStatus.OK, content_type="application/xml")
return
# Complete multipart: POST ?uploadId=xxxx
if "uploadId" in qs:
upload_id = qs["uploadId"][0]
try:
self.server.state.complete_multipart(upload_id)
except KeyError:
self._send_error(HTTPStatus.NOT_FOUND, "Upload not found")
return
xml = (
'<?xml version="1.0" encoding="UTF-8"?>'
"<CompleteMultipartUploadResult>"
f"<Bucket>{_escape_xml(bucket)}</Bucket>"
f"<Key>{_escape_xml(key)}</Key>"
f"<UploadId>{upload_id}</UploadId>"
"</CompleteMultipartUploadResult>"
).encode()
self._send_bytes(xml, status=HTTPStatus.OK, content_type="application/xml")
return
self._send_error(HTTPStatus.NOT_IMPLEMENTED, "Unsupported POST request")
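    # Wire-level sketch of the multipart exchange handled above. Part upload
    # itself goes through the PUT handler; the "partNumber" query name follows
    # the S3 convention and is assumed here:
    #
    #   POST /bucket/key?uploads                  -> <UploadId>...</UploadId>
    #   PUT  /bucket/key?uploadId=X&partNumber=1     (body: part bytes)
    #   POST /bucket/key?uploadId=X               -> assembles and stores the object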
def _send_status(self, status: HTTPStatus, extra_headers: dict[str, str] | None = None):
"""Send an HTTP response with the given status code.
Args:
status: The HTTP status code to send.
extra_headers: Optional additional headers to include.
"""
self.send_response(status.value)
headers = {"Server": "s3-emulator"}
if extra_headers:
headers.update(extra_headers)
for k, v in headers.items():
self.send_header(k, v)
self.end_headers()
def _send_error(self, status: HTTPStatus, message: str):
"""Send an error response.
Args:
status: The HTTP status code to send.
message: The error message to include in the response.
"""
print(f"Error {status}: {message}")
self._send_bytes(message.encode(), status=status, content_type="text/plain")
def _send_bytes(
self,
data: bytes,
status: HTTPStatus = HTTPStatus.OK,
content_type: str = "application/octet-stream",
extra_headers: dict[str, str] | None = None,
) -> None:
"""Send a response with binary data.
Args:
data: The binary data to send.
status: The HTTP status code to send. Defaults to 200 OK.
content_type: The Content-Type header value. Defaults to application/octet-stream.
extra_headers: Optional additional headers to include.
"""
self.send_response(status.value)
headers = {
"Server": "s3-emulator",
"Content-Type": content_type,
"Content-Length": str(len(data)),
}
if extra_headers:
headers.update(extra_headers)
for k, v in headers.items():
self.send_header(k, v)
self.end_headers()
if self.command != "HEAD":
self.wfile.write(data)
    def _render_bucket_list(self, bucket: str, objects: list[str]) -> bytes:
        """Generate an XML listing of objects in a bucket.
        Args:
            bucket: The bucket name.
            objects: List of object keys in the bucket.
        Returns:
            The XML document as bytes.
        """
        entries = []
        now = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.000Z")
        for key in objects:
            try:
                # `server` is set per request instance, so this must be an
                # instance method; the former staticmethod accessed the class
                # attribute and always fell into the except branch below.
                data = self.server.state.get_object(bucket, key)  # type: ignore[attr-defined]
size = len(data)
etag = _etag(data)
except Exception: # noqa: BLE001
size = 0
etag = '""'
entries.append(
"<Contents>"
f"<Key>{_escape_xml(key)}</Key>"
f"<LastModified>{now}</LastModified>"
f"<ETag>{etag}</ETag>"
f"<Size>{size}</Size>"
"</Contents>"
)
obj_elems = "".join(entries)
xml = (
'<?xml version="1.0" encoding="UTF-8"?>'
"<ListBucketResult>"
f"<Name>{_escape_xml(bucket)}</Name>"
f"{obj_elems}"
"</ListBucketResult>"
)
return xml.encode()
class S3ServerProtocol(Protocol): # noqa: D101
state: S3State
auth: S3Auth
def _escape_xml(text: str) -> str: # noqa: D401
"""Escape special characters for XML.
Args:
text: The text to escape.
Returns:
The escaped text.
"""
return (
text.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
.replace('"', "&quot;")
.replace("'", "&apos;")
)
def _etag(data: bytes) -> str: # noqa: D401
"""Generate an ETag for binary data.
Args:
data: The binary data to generate an ETag for.
Returns:
The MD5 hash of the data as a hex string.
"""
return md5(data).hexdigest()
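# Note: for multipart uploads, real S3 reports a composite ETag of the form
# "<md5-of-concatenated-part-md5s>-<part-count>"; this emulator always reports
# the plain MD5 of the full payload.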
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from pathlib import Path
import click
from .server import S3EmulatorServer
@click.command()
@click.option(
"--host",
default="0.0.0.0",
help="Host to bind the server to",
)
@click.option(
"--port",
default=9000,
type=int,
help="Port to bind the server to",
)
@click.option(
"--root-dir",
type=click.Path(path_type=Path),
help="Directory to persist S3 data",
)
@click.option(
"--access-key",
default="test",
help="Access key for authentication",
)
@click.option(
"--secret-key",
default="test",
help="Secret key for authentication",
)
@click.option(
"--region",
default="us-east-1",
help="Region for authentication",
)
def main(
host: str, port: int, root_dir: Path | None, access_key: str, secret_key: str, region: str
) -> None:
"""Start an S3 emulator server."""
server = S3EmulatorServer(
host=host,
port=port,
credentials={access_key: secret_key},
root_dir=root_dir,
region=region,
)
try:
server.serve_forever()
except KeyboardInterrupt:
server.shutdown()
if __name__ == "__main__":
main()
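# Example invocation (sketch; the installed module path may differ):
#   python -m <package>.cli --port 9000 --root-dir /tmp/s3-data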
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import threading
from http.server import ThreadingHTTPServer
from pathlib import Path
from typing import Mapping
from .auth import S3Auth
from .handler import S3RequestHandler
from .state import S3State
__all__ = ["S3EmulatorServer"]
class S3EmulatorServer:
"""A lightweight, blocking S3 HTTP emulator.
This server provides a minimal S3-compatible HTTP interface for testing purposes.
It supports basic S3 operations like bucket and object management.
Example (blocking)::
server = S3EmulatorServer(
host="127.0.0.1",
port=9000,
credentials={"ACCESS": "SECRET"},
)
server.serve_forever()
Example (threaded)::
server = S3EmulatorServer(
host="127.0.0.1",
port=9000,
credentials={"ACCESS": "SECRET"},
)
server.start_background()
# ...
server.shutdown()
server.join()
"""
def __init__(
self,
host: str = "0.0.0.0",
port: int = 0,
*,
credentials: Mapping[str, str] | None = None,
root_dir: str | Path | None = None,
region: str = "us-east-1",
):
"""
This server provides a minimal S3-compatible HTTP interface for testing purposes.
It supports basic S3 operations like bucket and object management.
There is no need to check that the port is bound, it is already bound after initialization.
Retrieve the real port with `.port` if set to 0.
The server is listening to the port immediately, but will only start processing after
`start_background()` (threaded) or `.serve_forever()` (blocking) is called.
Args:
host: The host address to bind to.
port: The port to bind to. Use 0 to let the OS choose a free port.
credentials: Optional mapping of access keys to secret keys.
root_dir: Optional path to persist the S3 store on disk.
region: AWS region to emulate.
"""
self._state = S3State(Path(root_dir) if root_dir else None)
self._auth = S3Auth(credentials or {"test": "test"}, region=region)
        # Expose the shared state and auth as class attributes so the request
        # handler can reach them through `self.server.state` / `self.server.auth`
        # (see S3ServerProtocol).
        class _Server(ThreadingHTTPServer):
            state = self._state
            auth = self._auth
self._httpd: ThreadingHTTPServer = _Server((host, port), S3RequestHandler)
self._thread: threading.Thread | None = None
print(f"S3 emulator on http://{host}:{self.port}", flush=True)
@property
def port(self) -> int:
"""Returns the port number the server is bound to."""
return self._httpd.server_port
@property
def state(self) -> S3State:
"""Returns the internal S3 state object."""
return self._state
def serve_forever(self):
"""Start the server and block until shutdown is called.
This method will block the calling thread. For non-blocking usage,
see start_background().
"""
try:
self._httpd.serve_forever()
finally:
self._state.flush()
def shutdown(self):
"""Shutdown the server and flush any pending state changes."""
self._httpd.shutdown()
self._state.flush()
def start_background(self):
"""Start the server in a background thread."""
if self._thread and self._thread.is_alive():
raise RuntimeError("Server already running")
def _run():
self.serve_forever()
self._thread = threading.Thread(target=_run, daemon=True)
self._thread.start()
def join(self, timeout: float | None = None):
"""Join the background thread.
Args:
timeout: Optional timeout in seconds to wait for thread completion.
"""
if self._thread is None:
return
self._thread.join(timeout)
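# Sketch (assumes `boto3` is installed): pointing a standard S3 client at the
# emulator. boto3 falls back to path-style addressing for an IP endpoint, which
# is what this handler expects.
#
#   import boto3
#   server = S3EmulatorServer(host="127.0.0.1", port=0, credentials={"test": "test"})
#   server.start_background()
#   s3 = boto3.client(
#       "s3",
#       endpoint_url=f"http://127.0.0.1:{server.port}",
#       aws_access_key_id="test",
#       aws_secret_access_key="test",
#       region_name="us-east-1",
#   )
#   s3.create_bucket(Bucket="demo")
#   s3.put_object(Bucket="demo", Key="hello.txt", Body=b"hi")
#   server.shutdown()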
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import json
import shutil
from pathlib import Path
from threading import RLock
from typing import Dict, Optional
from uuid import uuid4
__all__ = ["S3State"]
class S3State:
"""A minimal, thread-safe, in-memory representation of an S3 object store.
Optionally, a root_dir can be supplied to persist the store on the local
file system. The directory structure mirrors the S3 layout:
<root_dir>/<bucket>/<key>
Buckets are directories, objects are stored as regular files. Metadata is
not currently persisted beyond the object byte payload.
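    Example (in-memory usage)::
        state = S3State()
        state.create_bucket("demo")
        state.put_object("demo", "a/b.txt", b"payload")
        assert state.get_object("demo", "a/b.txt") == b"payload"
        assert state.list_objects("demo") == ["a/b.txt"]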
"""
def __init__(self, root_dir: Optional[Path] = None) -> None:
"""
Args:
root_dir: Path to persist the store on disk.
"""
self._fs: Dict[str, Dict[str, bytes]] = {}
self._uploads: Dict[str, _MultipartUpload] = {}
self._lock = RLock()
self._root_dir = root_dir
if self._root_dir is not None:
self._root_dir.mkdir(parents=True, exist_ok=True)
self._load_from_disk()
def list_buckets(self) -> list[str]:
"""List all buckets in the store.
Returns:
Sorted list of bucket names.
"""
with self._lock:
return sorted(self._fs.keys())
def create_bucket(self, bucket: str) -> None:
"""Create a new bucket.
Args:
bucket: Name of the bucket to create.
"""
with self._lock:
if bucket in self._fs:
print(f"Bucket '{bucket}' already exists")
return
self._fs[bucket] = {}
if self._root_dir is not None:
(self._root_dir / bucket).mkdir(parents=True, exist_ok=True)
def delete_bucket(self, bucket: str) -> None:
"""Delete a bucket.
Args:
bucket: Name of the bucket to delete.
"""
with self._lock:
if bucket not in self._fs:
raise KeyError(f"Bucket '{bucket}' does not exist")
if self._fs[bucket]:
raise RuntimeError("Bucket not empty")
del self._fs[bucket]
            if self._root_dir is not None:
                bucket_path = self._root_dir / bucket
                if bucket_path.exists():
                    # rmtree also removes empty subdirectories left behind by
                    # deleted nested keys, which plain unlink() would choke on.
                    shutil.rmtree(bucket_path)
def put_object(self, bucket: str, key: str, data: bytes) -> None:
"""Store an object in a bucket.
Args:
bucket: Name of the bucket.
key: Object key.
data: Object data.
"""
if not bucket:
raise ValueError("Bucket name must be given")
with self._lock:
if bucket not in self._fs:
self._fs[bucket] = {}
self._fs[bucket][key] = data
if self._root_dir is not None:
obj_path = (self._root_dir / bucket / key).resolve()
obj_path.parent.mkdir(parents=True, exist_ok=True)
obj_path.write_bytes(data)
    def get_object(self, bucket: str, key: str) -> bytes:
        """Retrieve an object from a bucket.
        Args:
            bucket: Name of the bucket.
            key: Object key.
        Returns:
            The object data.
        """
        with self._lock:
            try:
                data = self._fs[bucket][key]
            except KeyError as exc:
                raise FileNotFoundError(f"{bucket}/{key}") from exc
            if not data and self._root_dir is not None:
                # Payloads are not loaded at startup (see _load_from_disk);
                # read them lazily from disk on first access.
                obj_path = self._root_dir / bucket / key
                if obj_path.is_file():
                    data = obj_path.read_bytes()
                    self._fs[bucket][key] = data
            return data
def delete_object(self, bucket: str, key: str) -> None:
"""Delete an object from a bucket.
Args:
bucket: Name of the bucket.
key: Object key.
"""
with self._lock:
try:
del self._fs[bucket][key]
except KeyError as exc:
raise FileNotFoundError(f"{bucket}/{key}") from exc
            if self._root_dir is not None:
                (self._root_dir / bucket / key).unlink(missing_ok=True)
def list_objects(self, bucket: str) -> list[str]:
"""List all objects in a bucket.
Args:
bucket: Name of the bucket.
Returns:
Sorted list of object keys.
"""
with self._lock:
if bucket not in self._fs:
raise KeyError(f"Bucket '{bucket}' does not exist")
return sorted(self._fs[bucket].keys())
STATE_FILE = "__state.json"
def _load_from_disk(self) -> None:
"""Load persisted state from root_dir.
The object payload itself is not loaded in memory to keep startup
affordable. Only the structure (bucket -> keys) is persisted in a
state file.
"""
if self._root_dir is None:
return
state_file = self._root_dir / self.STATE_FILE
if not state_file.exists():
return
try:
mapping = json.loads(state_file.read_text())
except Exception as err: # noqa: BLE001
print(f"Failed to read persisted state: {err}")
return
with self._lock:
self._fs = {bucket: {key: b"" for key in keys} for bucket, keys in mapping.items()}
def flush(self) -> None:
"""Persist only the structure of the store to disk."""
if self._root_dir is None:
return
mapping = {bucket: list(objects.keys()) for bucket, objects in self._fs.items()}
(self._root_dir / self.STATE_FILE).write_text(json.dumps(mapping))
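    # The structure file written by flush() is plain JSON mapping each bucket
    # to its keys, e.g. {"demo": ["a/b.txt", "c.bin"]}; payloads live alongside
    # it as regular files under <root_dir>/<bucket>/<key>.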
def initiate_multipart(self, bucket: str, key: str) -> str:
"""Create a new multipart upload.
Args:
bucket: Name of the bucket.
key: Object key.
Returns:
The upload ID.
"""
with self._lock:
upload_id = uuid4().hex
self._uploads[upload_id] = _MultipartUpload(bucket, key)
if bucket not in self._fs:
self._fs[bucket] = {}
return upload_id
def upload_part(self, upload_id: str, part_number: int, data: bytes) -> None:
"""Upload a part of a multipart upload.
Args:
upload_id: The upload ID.
part_number: The part number.
data: The part data.
"""
with self._lock:
mp = self._uploads.get(upload_id)
if mp is None:
raise KeyError("Invalid upload_id")
mp.parts[part_number] = data
def complete_multipart(self, upload_id: str) -> None:
"""Complete a multipart upload.
Args:
upload_id: The upload ID.
"""
with self._lock:
mp = self._uploads.pop(upload_id, None)
if mp is None:
raise KeyError("Invalid upload_id")
data = mp.assemble()
if mp.bucket not in self._fs:
self._fs[mp.bucket] = {}
self._fs[mp.bucket][mp.key] = data
if self._root_dir is not None:
obj_path = (self._root_dir / mp.bucket / mp.key).resolve()
obj_path.parent.mkdir(parents=True, exist_ok=True)
obj_path.write_bytes(data)
def abort_multipart(self, upload_id: str) -> None:
"""Abort a multipart upload.
Args:
upload_id: The upload ID.
"""
with self._lock:
self._uploads.pop(upload_id, None)
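    # Multipart lifecycle against this state object (sketch):
    #   upload_id = state.initiate_multipart("demo", "big.bin")
    #   state.upload_part(upload_id, 1, b"AAAA")
    #   state.upload_part(upload_id, 2, b"BBBB")
    #   state.complete_multipart(upload_id)  # stores b"AAAABBBB" at demo/big.bin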
def add_file(self, src: Path, dst: str):
"""Add a file or directory to the store.
Args:
src: Source file or directory path.
dst: Destination path in S3 format (bucket/key).
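        Example::
            # Every file below ./local_dir lands in bucket "demo" with key
            # prefix "prefix/" (directories recurse).
            state.add_file(Path("./local_dir"), "demo/prefix")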
"""
if src.is_dir():
dst = dst.removesuffix("/")
for file in src.iterdir():
self.add_file(file, dst=f"{dst}/{file.name}")
elif src.is_file():
bucket, key = dst.removeprefix("/").split("/", 1)
self.put_object(bucket, key, src.read_bytes())
else:
raise ValueError(f"Invalid file: {src}")
class _MultipartUpload:
"""Internal helper class for managing multipart uploads."""
__slots__ = ("bucket", "key", "parts")
def __init__(self, bucket: str, key: str):
self.bucket = bucket
self.key = key
self.parts: Dict[int, bytes] = {}
def assemble(self) -> bytes:
"""Assemble the uploaded parts into a complete object.
Returns:
The complete object data.
"""
if not self.parts:
return b""
return b"".join(self.parts[n] for n in sorted(self.parts))