Commit f356f546 authored by maming

Initial commit
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Any, Callable, ClassVar, Dict, Optional, Union
from megatron.energon.epathlib import EPath
class ModuleLoader:
"""Class providing access to python files to be loaded as modules."""
#: Loaded modules. Global singleton, no need to load a module multiple times.
_modules: ClassVar[Dict[EPath, Dict[str, Any]]] = {}
def _load_module(self, path: Union[str, EPath]):
"""Internally loads a module"""
path = EPath(path)
with path.open("r") as rf:
src_code = rf.read()
code = compile(src_code, path.url, "exec")
mod = {}
        exec(code, mod)
self._modules[path] = mod
def __getitem__(self, path: Union[str, EPath]) -> Any:
"""
Gets a module by path.
Args:
path: The path to load
Returns:
Loaded module dict.
"""
path = EPath(path)
if path not in self._modules:
self._load_module(path)
return self._modules[path]
def get_function(
self,
module_fn: str,
default_fn_name: Optional[str] = None,
relative_path: Optional[EPath] = None,
) -> Callable:
"""
Loads a function in the form `"module/submodule.py:function_name"`.
Args:
module_fn: The function string to load. Must contain the path of the python module
and a function name separated by a colon (might be omitted if `default_fn_name` is
specified).
default_fn_name: Default name of the function if not given in `module_fn` string.
relative_path: The relative parent path to the module. If not specified, the current
working directory / absolute path is used.
Returns:
The function from the module
"""
if ":" in module_fn:
module, fn_name = module_fn.rsplit(":", 1)
else:
if default_fn_name is None:
raise ValueError("Function name must be specified")
fn_name = default_fn_name
module = module_fn
if relative_path is not None:
module = str(relative_path / module)
fn = self[module][fn_name]
assert callable(fn)
return fn
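# Usage sketch (hypothetical file "augment.py" defining `def encode(sample): ...`).
# Modules are cached in the class-level `_modules` dict, so each file is compiled
# at most once per process:
#
#   loader = ModuleLoader()
#   encode = loader.get_function("augment.py:encode")
#   encode = loader.get_function("augment.py", default_fn_name="encode")  # equivalent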
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from types import TracebackType
from typing import BinaryIO, Optional, Type
from megatron.energon.epathlib import EPath
class RetryReadStream(BinaryIO):
"""A stream that retries reading from a file. Only supports reading bytes."""
_path: EPath
_file: Optional[BinaryIO]
_pos: int
_size: int
def __init__(self, path: EPath):
"""Construct a RetryReadStream. It reads only bytes from a file."""
self._path = path
self._file = None
self._pos = 0
self._size = path.size()
def __enter__(self) -> "RetryReadStream":
return self
def __exit__(
self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType
) -> None:
self.close()
    def close(self) -> None:
        if self._file is not None:
            self._file.close()
            self._file = None
    def read(self, n: int = -1) -> bytes:
        buf = b""
        for retry in range(10):
            try:
                if self._file is None:
                    self._file = self._path.open("rb")
                    self._file.seek(self._pos)
                # Only request the bytes still missing, so that partial reads across
                # retries cannot accumulate more than `n` bytes in total
                res = self._file.read(n if n < 0 else n - len(buf))
                self._pos += len(res)
                buf += res
                if (
                    (n == -1 and self._pos >= self._size)
                    or len(buf) == n
                    or self._pos >= self._size
                ):
                    # Return all accumulated data, not just the last chunk
                    return buf
            except IOError:
                if self._file is not None:
                    try:
                        self._file.close()
                    except IOError:
                        pass
                self._file = None
                if retry == 9:
                    raise
                continue
        return buf
def seek(self, offset: int, whence: int = 0) -> int:
if whence == 0:
pass
elif whence == 1:
offset += self._pos
elif whence == 2:
offset += self._size
else:
raise ValueError(f"Invalid whence value: {whence}")
offset = min(max(offset, 0), self._size)
self._pos = offset
try:
if self._file is not None:
self._file.seek(offset)
except IOError:
pass
return self._pos
def tell(self) -> int:
return self._pos
def isatty(self) -> bool:
return False
def readable(self) -> bool:
return True
def seekable(self) -> bool:
return True
def writable(self) -> bool:
return False
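# Usage sketch (hypothetical path): on IOError the stream reopens the file and
# seeks back to the tracked position, so flaky remote reads look like a plain
# binary stream to the caller:
#
#   with RetryReadStream(EPath("remote://bucket/shard.tar")) as f:
#       header = f.read(512)
#       size = f.seek(0, 2)  # seek relative to end; returns the clamped position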
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import hashlib
import random
from typing import Any, List, Mapping, Optional, Sequence, TypeVar
import numpy
import torch
import torch.distributed
import torch.utils.data
from megatron.energon.edataclass import edataclass
from megatron.energon.savable import FlexState, Savable
from megatron.energon.worker import WorkerConfig
T = TypeVar("T")
class WorkerRng(Savable):
"""Helper class for getting a worker random generator, which is still in itself deterministic.
If not in a worker, uses the global random generator's seed to initialize a new rng."""
worker_config: WorkerConfig
_rng: Optional[torch.Generator] = None
_restore_state: Optional[bytes] = None
def __init__(self, worker_config: WorkerConfig):
self.worker_config = worker_config
@property
def rng(self) -> torch.Generator:
if self._rng is None or self._restore_state is not None:
self.worker_config.assert_worker()
self._rng = torch.Generator()
if self._restore_state is not None:
self._rng.set_state(
torch.frombuffer(
bytearray(self._restore_state),
dtype=torch.uint8,
).clone()
)
else:
                # Seed the initial state (no restore state was set)
self._rng.manual_seed(self.worker_config.worker_seed())
self._restore_state = None
return self._rng
def randbelow(self, n: int) -> int:
return torch.randint(0, n, (), generator=self.rng).item()
def choice_idx(self, probs: torch.Tensor) -> int:
if len(probs) == 1:
return 0
else:
            # Custom implementation of multinomial to ensure consistency.
            # Torch changed its implementation of torch.multinomial in 2.7.0, and to be
            # consistent across torch versions, we use a custom implementation here instead.
            # This is anyway just a very simple case of multinomial, so this is fine.
            # In fact, benchmarks show that this is faster than torch.multinomial by a
            # factor of 10, even on CPU.
cdf = torch.cumsum(probs, dim=0)
val = torch.rand(1, generator=self.rng) * cdf[-1]
return torch.searchsorted(cdf, val).item()
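    # Worked example of the inverse-CDF draw above (illustrative values):
    #
    #   probs = torch.tensor([0.1, 0.7, 0.2])
    #   cdf = torch.cumsum(probs, dim=0)     # tensor([0.1000, 0.8000, 1.0000])
    #   val = torch.rand(1) * cdf[-1]        # e.g. tensor([0.5500])
    #   torch.searchsorted(cdf, val).item()  # -> 1, i.e. the 0.7 entry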
def choice(self, l: List[T], probs: Optional[torch.Tensor] = None) -> T:
if probs is None:
return l[self.randbelow(len(l))]
assert len(l) == len(probs)
return l[self.choice_idx(probs)]
def shuffle(self, l: List[T]) -> List[T]:
"""Returns a new list with shuffled entries"""
p = torch.randperm(len(l), generator=self.rng)
return [l[p[i]] for i in range(len(l))]
def rand_pop(self, l: List[T]) -> T:
return l.pop(self.randbelow(len(l)))
def save_state(self) -> FlexState:
        return FlexState(rng=None if self._rng is None else bytes(self._rng.get_state().tolist()))
def restore_state(self, state: FlexState):
if state["rng"] is None:
self._restore_state = None
else:
self._restore_state = state["rng"]
@edataclass
class SystemRngState:
"""The state of the global random generators.
Note that the data types of the internal RNG states are implementation details of the
respective libraries and may change in the future.
    Python does not even specify the type in its docs. Hence we allow arbitrary types,
    because all that matters is that we can save and restore them; the data is not used
    anywhere else.
"""
torch: Any # Currently `torch.Tensor`
numpy: Any # Currently `dict[str, Any] | tuple[str, NDArray[uint32], int, int, float]`
random: Any # Currently a nested tuple
def _hashable_value(self, value: Any) -> Any:
if isinstance(value, (int, float, bool, str)) or value is None:
return value
elif isinstance(value, torch.Tensor):
return self._hashable_value(value.tolist())
elif isinstance(value, numpy.ndarray):
return self._hashable_value(value.tolist())
elif isinstance(value, Mapping):
return tuple(
(self._hashable_value(k), self._hashable_value(v)) for k, v in value.items()
)
elif isinstance(value, Sequence):
return tuple(self._hashable_value(v) for v in value)
else:
raise ValueError(f"Cannot hash value of type {type(value)}: {value!r}")
def __repr__(self):
# If the hash is the same, the state is the same. Should suffice to identify the state.
return f"SystemRngState(hash={hash(self._hashable_value((self.torch, self.numpy, self.random)))})"
class SystemRng:
"""A class to seed, save or restore the global random generators.
This affects torch, numpy and the standard library random module."""
@staticmethod
def seed(seed: int) -> None:
"""Seeds the global random generators."""
torch.manual_seed(seed)
numpy.random.seed(seed)
random.seed(seed)
@staticmethod
def save_state() -> SystemRngState:
"""Saves the global rng state for torch, numpy and random."""
return SystemRngState(
torch=torch.get_rng_state(),
numpy=numpy.random.get_state(),
random=random.getstate(),
)
@staticmethod
def restore_state(state: SystemRngState) -> None:
"""Restores the global rng state for torch, numpy and random."""
torch.set_rng_state(state.torch)
numpy.random.set_state(state.numpy)
random.setstate(state.random)
@staticmethod
def get_seed_from_args(*args: Any) -> int:
"""Deterministically generates a seed from the given arguments.
The str() representation of each arg is used."""
# Use a deterministic hash function to compute the seed
hash_digest = hashlib.sha1("|".join([str(obj) for obj in args]).encode("utf-8")).digest()
# We use the first 4 bytes of the hash as the seed and fix the endianness
seed_value = int.from_bytes(hash_digest[:4], byteorder="big")
return seed_value
@staticmethod
def seed_args(*args: Any) -> None:
"""Seeds the global random generators deterministically from the given arguments."""
SystemRng.seed(SystemRng.get_seed_from_args(*args))
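# Usage sketch: deterministic seeding of torch, numpy and random from arbitrary
# arguments, plus saving/restoring all three states at once (values illustrative):
#
#   SystemRng.seed_args("rank", 3, "epoch", 7)  # same args -> same seed
#   state = SystemRng.save_state()
#   a, b, c = torch.rand(1), numpy.random.rand(), random.random()
#   SystemRng.restore_state(state)
#   assert random.random() == c  # the draws replay identically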
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from abc import ABC, abstractmethod
from megatron.energon.state import FlexState
class Savable(ABC):
@abstractmethod
def save_state(self) -> FlexState:
"""Dumps the state of the object into some serializable format.
Should only dump the part that is persistable for energon usage.
"""
...
@abstractmethod
def restore_state(self, state: FlexState):
"""Restores the state of the object.
If state is None, the object is restored to its default state."""
...
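# Minimal implementation sketch (hypothetical counter object): a Savable only
# needs to round-trip its persistable fields through a FlexState:
#
#   class Counter(Savable):
#       def __init__(self):
#           self.n = 0
#
#       def save_state(self) -> FlexState:
#           return FlexState(n=self.n)
#
#       def restore_state(self, state: FlexState):
#           self.n = state["n"]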
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import gc
import multiprocessing
import queue
import threading
import time
import traceback
import weakref
from functools import partial
from typing import (
Any,
ClassVar,
Dict,
Generic,
Iterator,
List,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)
import numpy as np
import torch
import torch.distributed
import torch.multiprocessing
from torch.utils.data import DataLoader, IterableDataset
from megatron.energon.cache import CachePool
from megatron.energon.edataclass import edataclass
from megatron.energon.errors import deprecated
from megatron.energon.flavors.base_dataset import (
FlexState,
SavableDataset,
State,
add_sample_restore_key,
)
from megatron.energon.rng import SystemRng, SystemRngState
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset
from megatron.energon.wrappers.batch_dataset import BatchDataset
from megatron.energon.wrappers.gc_dataset import GC_DEFAULT_EVERY_N_ITER, GcDataset, gc_init_worker
from megatron.energon.wrappers.log_sample_dataset import default_get_keys
from megatron.energon.wrappers.watchdog_dataset import WatchdogDataset
T = TypeVar("T")
def _init_worker(seed_per_worker: List[int], worker_id: int):
"""Initializes the the worker process.
Sets the random seeds and prepare EPath for the forked worker process.
"""
gc_init_worker(worker_id)
worker_seed = seed_per_worker[worker_id]
SystemRng.seed(worker_seed)
class SimpleSavableDatasetWrapper(BaseWrapperDataset[T, Tuple[int, int, T]], Generic[T]):
"""Wrapper for non-multiprocessing savable datasets. Restarts the inner dataset. This class is
not intended to be used directly."""
#: The cache pool to use for the dataset.
cache_pool: CachePool
_state_restored: bool
_sample_index: int
_savable_fields = ("_sample_index",)
def __init__(
self, dataset: SavableDataset[T], worker_config: WorkerConfig, cache_pool: CachePool
):
"""
Args:
dataset: The dataset to wrap.
worker_config: The worker config to use for the dataset.
cache_pool: The cache pool to use for the dataset.
"""
super().__init__(dataset, worker_config=worker_config)
self.cache_pool = cache_pool
self.reset_state_own()
def reset_state_own(self) -> None:
self._sample_index = 0
self._state_restored = False
def len_worker(self, worker_idx: int | None = None) -> int:
return self.dataset.len_worker(worker_idx)
@property
def __len__(self):
        # Note: This disables hasattr(self, "__len__"), because accessing the attribute raises.
raise AttributeError("Disabled direct length access to avoid DataLoader warnings.")
def __iter__(self):
self._state_restored = True
worker_id = self.worker_config.rank_worker_id()
global_worker_id = self.worker_config.global_worker_id()
while self._state_restored:
self._state_restored = False
self.worker_config.worker_activate(self._sample_index, cache_pool=self.cache_pool)
worker_active = True
try:
for src_data in self.dataset:
self.worker_config.worker_deactivate()
worker_active = False
sample_index = self._sample_index
src_data = add_sample_restore_key(
src_data, global_worker_id, sample_index, src=self
)
self._sample_index += 1
yield worker_id, sample_index, src_data
if self._state_restored:
# Restart iterator after restore
break
self.worker_config.worker_activate(
self._sample_index, cache_pool=self.cache_pool
)
worker_active = True
finally:
if worker_active:
self.worker_config.worker_deactivate()
def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T:
id, global_worker_id, sample_idx = restore_key[:3]
assert id == type(self).__name__
restore_key = restore_key[3:]
self.worker_config.worker_activate(
sample_idx, override_global_rank=global_worker_id, cache_pool=self.cache_pool
)
try:
return add_sample_restore_key(
self.dataset.restore_sample(restore_key),
global_worker_id,
sample_idx,
src=self,
)
finally:
self.worker_config.worker_deactivate()
def config(self) -> Dict[str, Any]:
return self.dataset.config()
def __str__(self):
return f"SimpleSavableDatasetWrapper(dataset={self.dataset})"
@edataclass
class SavableDatasetState(State):
"""State of the dataset wrapper. It stores the global random states and the index of the next
sample to be returned from the dataset. This class is not intended to be used directly, but by
:class:`megatron.energon.SavableDatasetWrapper`."""
#: The state of all the system random number generators
rng: SystemRngState
#: The state of the savable dataset
dataset_state: FlexState
#: Index of the next sample to be returned from the dataset
sample_index: int
def __repr__(self):
return f"SavableDatasetState(rng={self.rng!r}, sample_index={self.sample_index})"
@edataclass
class SavableCheckpoint:
"""Checkpoint data for :class:`megatron.energon.SavableDatasetWrapper`. An instance is created
regularly to be able to save the state of the dataset wrapper before the currently emitted
sample.
"""
#: The state of the wrapper
state: Optional[SavableDatasetState]
#: The time at which the checkpoint was created
checkpoint_time: float
#: Index of the next sample to be returned from the dataset after restoring the checkpoint
sample_index: int
@edataclass
class SavableDatasetCheckpoint(State):
"""Checkpoint data for :class:`megatron.energon.SavableDatasetWrapper`. The checkpoint state
represents a state before that checkpoint, with an offset (i.e. samples to be skipped)."""
#: The state of the wrapper at the sample index when the checkpoint was created.
state: Optional[SavableDatasetState]
#: Offset of the checkpoint to the actual sample index to be restored.
offset: int
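    # Example (illustrative numbers): if the newest stored checkpoint was taken at
    # sample_index 96 and the trainer last consumed sample index 99, the checkpoint
    # is saved with offset = 100 - 96 = 4; on restore, the worker replays and skips
    # 4 samples before yielding new ones.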
class SavableDatasetWrapper(IterableDataset[Tuple[int, int, T]], Generic[T]):
"""Internal class for wrapping a savable dataset for a worker process. Provides communication
with the :class:`megatron.energon.SavableDataLoader`. This class is not intended to be used directly.
See :class:`megatron.energon.SavableDataLoader` for more information."""
#: The wrapped dataset
dataset: SavableDataset[T]
#: The configuration of the worker process
worker_config: WorkerConfig
#: The time interval in seconds to wait at minimum between two checkpoints
checkpoint_every_sec: float
#: The minimum number of samples to be emitted between two checkpoints. Should be `number of
# workers * 2`.
checkpoint_every_min_n_samples: int
#: The number of checkpoints to keep in memory, before discarding. Should be 2.
n_checkpoints: int
#: The cache pool to use for the dataset.
cache_pool: CachePool
#: The queue of the worker process to receive commands from the `SavableDataLoader`.
_cmd_queues: List[torch.multiprocessing.Queue]
#: The queue of the worker process to send results to the `SavableDataLoader`.
_result_queues: List[torch.multiprocessing.Queue]
_sample_index: int = 0
_worker_offset: int = 0
_last_checkpoints: List[SavableCheckpoint]
    _workers_restore_from: List[Optional[SavableDatasetState]]
_workers_skip_samples: List[int]
_running: bool = False
_command_lock: Optional[threading.RLock] = None
_cmd_thread: Optional[threading.Thread] = None
def __init__(
self,
dataset: SavableDataset[T],
worker_config: WorkerConfig,
checkpoint_every_sec: float,
checkpoint_every_min_n_samples: int,
n_checkpoints: int = 2,
*,
cmd_queues: List[torch.multiprocessing.Queue],
result_queues: List[torch.multiprocessing.Queue],
cache_pool: CachePool,
):
"""
Create the savable dataset wrapper for multiprocessing data loading.
Args:
dataset: The dataset to wrap
worker_config: The worker config as used by all datasets
checkpoint_every_sec: The time interval in seconds to wait at minimum between two
checkpoints.
checkpoint_every_min_n_samples: The minimum number of samples to be emitted between
two checkpoints. Should be `number of workers * 2`.
n_checkpoints: Number of checkpoints to keep.
cmd_queues: The command queues for communicating with the worker processes.
result_queues: The result queues for communicating with the worker processes.
cache_pool: The cache pool to use for the dataset.
"""
num_workers = max(worker_config.num_workers, 1)
self.dataset = dataset
self.worker_config = worker_config
self.checkpoint_every_sec = checkpoint_every_sec
self.checkpoint_every_min_n_samples = checkpoint_every_min_n_samples
self.n_checkpoints = n_checkpoints
self._last_checkpoints = [
SavableCheckpoint(state=None, checkpoint_time=time.perf_counter(), sample_index=-1)
]
self._workers_restore_from = [None] * num_workers
self._workers_skip_samples = [0] * num_workers
self._cmd_queues = cmd_queues
self._result_queues = result_queues
self.cache_pool = cache_pool
@staticmethod
def _command_thread(self: "SavableDatasetWrapper"):
"""The internal thread, which processes the command and result queues. This thread is
static, because `self` is actually passed as weakref proxy, to avoid keeping the dataset
alive via the thread.
"""
# print(f"{id(self)}:{multiprocessing.current_process().ident} Worker command thread starting")
assert self._command_lock is not None
try:
while self._running:
try:
cmd_args = self._cmd_queues[self._worker_id].get(timeout=1)
except queue.Empty:
continue
# print(f"recv cmd {cmd_args}")
with self._command_lock:
cmd = cmd_args[0]
if cmd is None:
break
try:
fn = getattr(self, cmd)
self._result_queues[self._worker_id].put(
{self._worker_id: fn(*cmd_args[1:])}
)
# print(f"result sent")
except Exception as e:
traceback.print_exc()
self._result_queues[self._worker_id].put({self._worker_id: e})
# print(f"exc sent")
except BaseException:
traceback.print_exc()
raise
finally:
pass
# print(f"{id(self)}:{multiprocessing.current_process().ident} Worker command thread closing")
def len_worker(self, worker_idx: int | None = None) -> int:
return self.dataset.len_worker(worker_idx)
def len_rank(self):
return self.dataset.len_rank()
@property
def __len__(self):
        # Note: This disables hasattr(self, "__len__"), because accessing the attribute raises.
raise AttributeError("Disabled direct length access to avoid DataLoader warnings.")
def __del__(self):
if self._cmd_thread is not None:
# print(f"{id(self)}:{multiprocessing.current_process().ident} Closing cmd thread")
self._running = False
self._cmd_thread.join()
self._command_lock = None
self._cmd_thread = None
# print(f"{id(self)}:{multiprocessing.current_process().ident} Cmd thread closed")
def __iter__(self):
# First: Set the worker offset globally for the current worker
WorkerConfig.worker_id_offset = self._worker_offset
self._worker_id = self.worker_config.rank_worker_id()
global_worker_id = self.worker_config.global_worker_id()
if self._cmd_thread is None:
self._running = True
self._command_lock = threading.RLock()
weakref_self = weakref.proxy(self)
self._cmd_thread = threading.Thread(
target=SavableDatasetWrapper._command_thread,
name="command_thread",
args=(weakref_self,),
daemon=True,
)
self._cmd_thread.start()
# atexit.register(lambda: weakref_self.__del__())
try:
assert self._command_lock is not None
with self._command_lock:
if self._workers_restore_from[self._worker_id] is not None:
                    my_state = self._workers_restore_from[self._worker_id]
                    assert my_state is not None
                    my_ds_state = my_state.dataset_state
if my_ds_state is None:
self.dataset.reset_state_deep()
else:
self.dataset.restore_state(my_ds_state)
self._restore_state(my_state)
self._workers_restore_from[self._worker_id] = None
else:
# Store the initial state of the worker if we stop before the first sample
self._store_checkpoint()
# If skipping, also restart the iterator to reach the start of the restored
# checkpoint
last_was_skip = True
while last_was_skip:
dataset_has_samples = False
self.worker_config.worker_activate(
self._sample_index, cache_pool=self.cache_pool
)
worker_active = True
try:
for src_data in self.dataset:
self.worker_config.worker_deactivate()
worker_active = False
dataset_has_samples = True
if self._workers_skip_samples[self._worker_id] > 0:
# Skip ahead to reach the start of the restored checkpoint
# print(f"Skip [{self._sample_index}:{self._worker_id}] {src_data}")
self._workers_skip_samples[self._worker_id] -= 1
self._sample_index += 1
last_was_skip = True
self.worker_config.worker_activate(
self._sample_index, cache_pool=self.cache_pool
)
worker_active = True
continue
last_was_skip = False
sample_index = self._sample_index
add_sample_restore_key(
src_data, global_worker_id, sample_index, src=self
)
self._sample_index += 1
self._store_checkpoint()
try:
self._command_lock.release()
# print(f"{id(self)}:{multiprocessing.current_process().ident} Lock released")
# Commands may be executed only when data was yielded, not during
# iteration fetching.
# print(f"Yield next data [{sample_index}:{self._worker_id}] {src_data}")
yield self._worker_id, sample_index, src_data
finally:
# print(f"{id(self)}:{multiprocessing.current_process().ident} Lock acquiring")
self._command_lock.acquire()
# print(f"{id(self)}:{multiprocessing.current_process().ident} Lock acquired")
self.worker_config.worker_activate(
self._sample_index, cache_pool=self.cache_pool
)
worker_active = True
finally:
if worker_active:
self.worker_config.worker_deactivate()
# If the dataset is empty, don't try again and again
if not dataset_has_samples:
break
finally:
# print(f"{id(self)}:{multiprocessing.current_process().ident} Worker iter closing")
# Always store a final checkpoint (it's likely to be saved)
self._store_checkpoint(force=True)
def _store_checkpoint(self, force: bool = False) -> None:
"""
Internally create a checkpoint for the current state. This is required to store states
from the past, which have already been yielded here, but not yet been retrieved from the
intermediate queues.
Args:
force: If true, ignore time or frequency condition.
"""
if (
force
or (
self._last_checkpoints[-1].checkpoint_time + self.checkpoint_every_sec
< time.perf_counter()
and self._last_checkpoints[-1].sample_index + self.checkpoint_every_min_n_samples
<= self._sample_index
)
or self._sample_index <= 1
):
# print(f"Storing checkpoint at {self._worker_id}:{self._sample_index}")
self._last_checkpoints.append(
SavableCheckpoint(
state=self._save_state(),
checkpoint_time=time.perf_counter(),
sample_index=self._sample_index,
)
)
if len(self._last_checkpoints) > self.n_checkpoints:
self._last_checkpoints.pop(0)
def _save_state(self) -> SavableDatasetState:
"""Saves the internal state"""
return SavableDatasetState(
rng=SystemRng.save_state(),
dataset_state=self.dataset.save_state(),
sample_index=self._sample_index,
)
def _restore_state(self, state: SavableDatasetState) -> None:
"""Restores the internal worker state"""
assert torch.utils.data.get_worker_info() is not None, "Can only restore in worker process"
if state.rng is None:
SystemRng.seed(torch.initial_seed() & 0xFFFFFFFF)
else:
SystemRng.restore_state(state.rng)
self._sample_index = state.sample_index
self._last_checkpoints = [
SavableCheckpoint(
state=self._save_state(),
checkpoint_time=time.perf_counter(),
sample_index=self._sample_index,
)
]
def get_checkpoint(self, last_sample_indexes: List[int]) -> SavableDatasetCheckpoint:
"""
Get a checkpoint given the last emitted sample indexes for all workers.
Args:
last_sample_indexes: The last emitted sample indexes for all workers.
Returns:
The found checkpoint including the offset to the next sample index
"""
sample_index = last_sample_indexes[self._worker_id] + 1
for checkpoint in reversed(self._last_checkpoints):
if checkpoint.sample_index <= sample_index:
# print(f"Found cp for {sample_index} at {checkpoint.sample_index}")
return SavableDatasetCheckpoint(
state=checkpoint.state,
offset=sample_index - checkpoint.sample_index,
)
# Immediate save after restore
if len(self._last_checkpoints) == 0:
return SavableDatasetCheckpoint(
state=self._workers_restore_from[self._worker_id],
offset=self._workers_skip_samples[self._worker_id],
)
raise ValueError("No checkpoint found")
def restore_checkpoint(
self,
worker_states: Optional[List[SavableDatasetCheckpoint]],
worker_offset: int,
) -> None:
"""
Restores the merged checkpoint from all worker processes.
Args:
worker_states: The state to restore for each worker
            worker_offset: The offset of the last worker which has emitted a sample. This will be
                set in all worker processes to ensure the right worker starts first.
"""
assert torch.utils.data.get_worker_info() is None, "Cannot restore in worker process"
num_workers = max(self.worker_config.num_workers, 1)
if worker_states is None:
self._workers_restore_from = [None] * num_workers
assert worker_offset == 0
self._worker_offset = 0
self._workers_skip_samples = [0] * num_workers
else:
assert isinstance(worker_states, list)
assert len(worker_states) == num_workers
assert isinstance(worker_states[0], SavableDatasetCheckpoint)
self._worker_offset = worker_offset
            # Split the worker_states list (len == num_workers) into the internal
            # per-worker arrays
self._workers_restore_from = [state.state for state in worker_states]
self._workers_skip_samples = [state.offset for state in worker_states]
def get_initial_checkpoint(self) -> Optional[List[SavableDatasetCheckpoint]]:
"""
Get the initial checkpoint for all worker processes if they have not started yet.
Returns:
The initial checkpoint for all worker processes and the worker offset.
"""
assert torch.utils.data.get_worker_info() is None, (
"Cannot get initial checkpoint in worker process"
)
if all(s is None for s in self._workers_restore_from):
assert all(s == 0 for s in self._workers_skip_samples)
# Initial state, no checkpoint
return None
return [
SavableDatasetCheckpoint(
state=state,
offset=offset,
)
for state, offset in zip(self._workers_restore_from, self._workers_skip_samples)
]
def can_restore_sample(self) -> bool:
return self.dataset.can_restore_sample()
def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T:
id, global_worker_id, sample_idx = restore_key[:3]
assert id == type(self).__name__
restore_key = restore_key[3:]
self.worker_config.worker_activate(sample_idx, override_global_rank=global_worker_id)
try:
return add_sample_restore_key(
self.dataset.restore_sample(restore_key),
global_worker_id,
sample_idx,
src=self,
)
finally:
self.worker_config.worker_deactivate()
def config(self) -> Dict[str, Any]:
return self.dataset.config()
def __str__(self):
return f"SavableDatasetWrapper(dataset={self.dataset})"
@edataclass
class SavableDataLoaderState(State):
"""Saved state of the :class:`megatron.energon.SavableDataLoader`. Contains the state for all worker
processed of a single rank."""
#: The internal state of the dataset (for each worker process)
worker_states: List[Union[SavableDatasetCheckpoint, FlexState]]
#: Which worker will be the next to emit a sample. Used to restore the proper order
next_worker_id: int
#: The micro batch size that was used, if available.
#: On restore, this is used to compare the new and old micro batch size.
micro_batch_size: Optional[int]
class SavableDataLoader(DataLoader[T], Generic[T]):
"""DataLoader that supports saving and restoring the state of the dataset.
    When restoring, the dataloader and dataset must be instantiated with exactly the same
    parameters.
How this works (for no worker processes)
----------------------------------------
1. The state of the dataset is saved using :meth:`megatron.energon.SavableDataset.save_state`
    2. (for compatibility) The state of the dataset is converted to use inner arrays using
       :meth:`megatron.energon.SavableDataset.merge_states`.
3. The state can be restored using :meth:`megatron.energon.SavableDataset.restore_state` given the
previously saved (and merged) state.
How this works (for worker processes)
-------------------------------------
    - The first issue is that worker processes use internal queues to pass loaded samples
      to the main process (which also performs the collating). This means that the whole
      state of the dataset is not directly accessible from the main process.
- To solve this issue, the dataset regularly saves a checkpoint of its state to be able to
resume from that state (and skip the samples that have already been yielded).
- To have a consistent state, the sample index from the latest yielded samples is saved for all
worker instances. Thus, the main process knows exactly which sample indexes should come next
from which worker.
- Internally, pytorch iterates through the workers in order to retrieve the next worker's
samples. Unfortunately, that next worker index cannot be restored in pytorch's dataloader,
thus the workers are shifted internally by that offset
(see :attr:`megatron.energon.WorkerConfig.worker_id_offset`).
1. The dataset is wrapped in a :class:`megatron.energon.SavableDatasetWrapper`. This allows the main
process to communicate with the worker and send commands to the workers and retrieve the
results.
2. The state of the dataset is saved using
:meth:`megatron.energon.SavableDatasetWrapper.get_checkpoint`. This gives the last checkpoint
from the requested sample index and stores the offset (i.e. number of samples to skip) from
that checkpoint.
3. The state is merged using :meth:`megatron.energon.SavableDatasetWrapper.merge_checkpoints`. This
merges the states of all workers and returns a single state that can be used to restore the
state of the dataset.
4. The state can be restored using :meth:`megatron.energon.SavableDatasetWrapper.restore_state`
before a worker is started, such that all workers initially receive the same state array.
       The worker first sets the worker index offset, then uses its (shifted) own index to get its
required state from the merged state array.
"""
#: The worker config
worker_config: WorkerConfig
#: The wrapped dataset. For multiprocessing, this is a :class:`megatron.energon.SavableDatasetWrapper`
dataset: Union[SavableDatasetWrapper[T], SimpleSavableDatasetWrapper[T]]
#: The global ID counter
_next_id: ClassVar[int] = 0
#: Class instance id
id: int = 0
#: The queues used to send commands to the workers
cmd_queues: List[torch.multiprocessing.Queue]
#: The queues used to receive results from the workers
result_queues: List[torch.multiprocessing.Queue]
#: Instance of the current data iterator. There shall be only one active iterator, such that the
# dataset is not iterated multiple times in parallel. The state will continue between epochs.
_epoch_iterator: Optional[Iterator[T]] = None
#: Whether the dataloader has running workers.
_has_workers: bool = False
#: The index of the current worker. -1 if not started yet.
_worker_sample_counters: List[int]
#: Id of the next worker to retrieve data from
_next_worker_id: int = 0
#: Global index of the last yielded sample
_global_sample_idx: int = 0
#: Current iterator index of the last yielded sample
_sample_idx: int = 0
def __init__(
self,
dataset: SavableDataset[T],
*,
checkpoint_every_sec: float = 60,
checkpoint_every_min_n_samples: Optional[int] = None,
n_checkpoints: Optional[int] = None,
gc_collect_every_n_steps: int = GC_DEFAULT_EVERY_N_ITER,
gc_freeze_at_start: bool = True,
prefetch_factor: int = 2,
cache_pool: Optional[CachePool] = None,
watchdog_timeout_seconds: Optional[float] = 60,
watchdog_initial_timeout_seconds: Optional[float] = None,
fail_on_timeout: bool = False,
):
"""
Create the dataloader supporting saving and restoring the state.
Args:
dataset: The dataset to load.
checkpoint_every_sec: This is the time in seconds after which a checkpoint is saved.
It may take the same duration to restore a checkpoint, but introduces additional
overhead during reading data from the dataset, so this should be chosen accordingly.
Only applies if using workers.
checkpoint_every_min_n_samples: Overwrites the minimum number of samples between
checkpoints. Defaults to `number of workers * 2`. Only applies if using workers.
n_checkpoints: The number of checkpoints to keep in memory. Only applies if using
workers. If None, computes a suitable value.
            gc_collect_every_n_steps: The number of steps after which the garbage collector is
                invoked. As we usually handle few but large tensors here, and the Python garbage
                collector already tracks many objects from imports alone, this can reduce the
                memory footprint considerably and may even be necessary to avoid running out
                of memory.
            gc_freeze_at_start: If true, the garbage collector is frozen at the start of the worker
                processes. This greatly improves garbage collection performance. In rare cases it
                may cause issues and can be disabled; keep it enabled if you experience no issues.
cache_pool: If set, the cache pool to use for the dataset.
watchdog_timeout_seconds: The timeout in seconds. If None, the watchdog is disabled.
watchdog_initial_timeout_seconds: The initial timeout in seconds. If None, the timeout is the same as watchdog_timeout_seconds.
fail_on_timeout: If True, stops the whole process upon timeout, after printing a stack trace.
"""
self.worker_config = dataset.worker_config
self.id = self.next_id()
dataset = WatchdogDataset(
dataset,
worker_config=self.worker_config,
timeout_seconds=watchdog_timeout_seconds,
initial_timeout_seconds=watchdog_initial_timeout_seconds,
fail_on_timeout=fail_on_timeout,
)
if gc_collect_every_n_steps > 0:
dataset = GcDataset(
dataset,
worker_config=self.worker_config,
every_n_iter=gc_collect_every_n_steps,
freeze=gc_freeze_at_start,
)
self.cmd_queues = [multiprocessing.Queue() for _ in range(self.worker_config.num_workers)]
self.result_queues = [
multiprocessing.Queue() for _ in range(self.worker_config.num_workers)
]
num_procs = max(self.worker_config.num_workers, 1)
if n_checkpoints is None:
n_checkpoints = prefetch_factor * num_procs + 1
if self.worker_config.num_workers > 0:
if checkpoint_every_min_n_samples is None:
checkpoint_every_min_n_samples = self.worker_config.num_workers * 2
dataset = SavableDatasetWrapper(
dataset,
self.worker_config,
checkpoint_every_sec=checkpoint_every_sec,
checkpoint_every_min_n_samples=checkpoint_every_min_n_samples,
n_checkpoints=n_checkpoints,
cmd_queues=self.cmd_queues,
result_queues=self.result_queues,
cache_pool=cache_pool,
)
else:
dataset = SimpleSavableDatasetWrapper(
dataset, self.worker_config, cache_pool=cache_pool
)
self._worker_sample_counters = [-1] * num_procs
kwargs = {}
if self.worker_config.num_workers > 0:
kwargs["persistent_workers"] = True
kwargs["prefetch_factor"] = prefetch_factor
# Assert that prefetch_factor works well with num_checkpoints.
# This ensures that the oldest checkpoint is old enough to cover
# all the buffered samples in the torch dataloader.
assert prefetch_factor * num_procs + 1 <= n_checkpoints, (
"When increasing prefetch_factor, also increase n_checkpoints, so that "
"the number of checkpoints is at least as large as num_workers * prefetch_factor + 1"
)
# Compute seeds for each worker, based on current rank
seed_per_worker = [
self.worker_config.worker_seed(i) for i in range(self.worker_config.num_workers)
]
super().__init__(
dataset,
batch_size=None,
shuffle=False,
num_workers=self.worker_config.num_workers,
pin_memory=True,
worker_init_fn=partial(_init_worker, seed_per_worker),
**kwargs,
)
if self.worker_config.should_log(level=1):
self.worker_config.worker_log(
{
"t": "SavableLoader.__init__",
"r": self.worker_config.rank,
"w": None,
"id": self.id,
"config": dataset.config(),
}
)
@staticmethod
def next_id() -> int:
next_id = SavableDataLoader._next_id
SavableDataLoader._next_id += 1
return next_id
def __len__(self):
# We override this, because otherwise we'll see warnings
return self.dataset.len_rank()
def _epoch_iter(self):
"""Iterator for one epoch, i.e. until the inner dataset raises StopIteration."""
iter_idx = 0
id = self.next_id()
if self.worker_config.should_log(level=1):
self.worker_config.worker_log(
{
"t": "SavableDataLoader.iter",
"r": self.worker_config.rank,
"w": None,
"id": self.id,
"iter_id": id,
}
)
try:
for worker_id, sample_idx, sample in super().__iter__():
self._worker_sample_counters[worker_id] = sample_idx
# If the next sample will be from the first worker, we can safely resume
self._next_worker_id = (worker_id + 1) % max(self.num_workers, 1)
# self._debugf.write(
# f"[w={worker_id}, s={sample_idx}] {self._sample_str(sample)}\n"
# )
# self._debugf.flush()
if self.worker_config.should_log(level=1):
keys = default_get_keys(sample)
self.worker_config.worker_log(
{
**{
"t": "SavableDataLoader.yield",
"r": self.worker_config.rank,
"w": None,
"id": self.id,
"iter_id": id,
"worker_id": worker_id,
"worker_idx": sample_idx,
"idx": self._sample_idx,
"iter_idx": iter_idx,
"global_idx": self._global_sample_idx,
},
**({} if keys is None else {"keys": keys}),
}
)
self._sample_idx += 1
self._global_sample_idx += 1
iter_idx += 1
yield sample
self._epoch_iterator = None
self._next_worker_id = 0
finally:
if self.worker_config.should_log(level=1):
self.worker_config.worker_log(
{
"t": "SavableDataLoader.StopIteration",
"r": self.worker_config.rank,
"w": None,
"id": self.id,
"iter_id": self.id,
}
)
def __iter__(self):
if self.num_workers > 0:
# Always keep same iterator alive, as long as it yields data
if self._epoch_iterator is None:
self._epoch_iterator = self._epoch_iter()
self._sample_idx = 0
self._has_workers = True
# print("New Iterator", self._persistent_iterator)
return self._epoch_iterator
else:
return self._epoch_iter()
def _worker_command(self, *cmd_args) -> List[Any]:
"""Executes a command in all workers and returns the results."""
# print(f"cmd: {cmd_args}")
for cmd_queue in self.cmd_queues:
cmd_queue.put(cmd_args)
# print(f"waiting for res")
assert len(self.result_queues) == self.worker_config.num_workers
res = {k: v for results_queue in self.result_queues for k, v in results_queue.get().items()}
res = [res[i] for i in range(len(res))]
# print(f"res: {res}")
for r in res:
if isinstance(r, Exception):
raise r
return res
def _get_batch_size(self) -> Optional[int]:
"""Try to infer micro batch size from the dataset"""
if isinstance(self.dataset, (SavableDatasetWrapper, SimpleSavableDatasetWrapper)):
dataset = self.dataset.dataset
else:
dataset = self.dataset
if (
isinstance(dataset, BaseWrapperDataset)
and (bds := dataset._find_wrapped_dataset(BatchDataset)) is not None
):
assert isinstance(bds, BatchDataset)
return bds.batch_size
else:
return None
def save_state_rank(self) -> Optional[SavableDataLoaderState]:
"""
Saves the state of the dataset for the current rank. Allows for restoring the state later
using `restore_state_rank`, given the result of this method.
Returns:
The state of the dataset.
"""
# Fetch current rank's worker's state
if self.num_workers == 0:
# No workers configured
assert isinstance(self.dataset, SimpleSavableDatasetWrapper)
worker_states = [self.dataset.save_state()]
assert self._next_worker_id == 0
elif self._has_workers:
# Fetch from worker processes
worker_states = self._worker_command("get_checkpoint", self._worker_sample_counters)
else:
# Workers configured, but not started yet.
# If a state has already been restored, it will be returned.
assert isinstance(self.dataset, SavableDatasetWrapper)
worker_states = self.dataset.get_initial_checkpoint()
if worker_states is None:
return None
# Merge the states
merged_state = SavableDataLoaderState(
worker_states=worker_states,
next_worker_id=self._next_worker_id,
micro_batch_size=self._get_batch_size(),
)
# Not distributed -> return the merged state
return merged_state
def restore_state_rank(self, state: Optional[SavableDataLoaderState]) -> None:
"""
Restores the saved state for the current rank.
Args:
state: The state to restore, as saved by `save_state_rank`.
"""
assert not self._has_workers, "Cannot restore state while workers are running"
if state is None:
# Assume initial state
return
assert isinstance(state, SavableDataLoaderState)
old_micro_batch_size = state.micro_batch_size
micro_batch_size = self._get_batch_size()
if self.num_workers == 0:
# No workers configured
assert isinstance(self.dataset, SimpleSavableDatasetWrapper)
assert micro_batch_size == old_micro_batch_size, (
"Changing micro batch size is not allowed without workers"
)
assert len(state.worker_states) == 1
assert isinstance(state.worker_states[0], FlexState)
self.dataset.restore_state(state.worker_states[0])
else:
# Workers configured
assert isinstance(self.dataset, SavableDatasetWrapper)
assert all(isinstance(s, SavableDatasetCheckpoint) for s in state.worker_states)
# Check batch sizes (before and after)
if micro_batch_size != old_micro_batch_size:
assert micro_batch_size is not None and old_micro_batch_size is not None, (
"Cannot resume with different batching mode "
"(batching to non-batching or vice versa)"
)
if micro_batch_size > old_micro_batch_size:
raise ValueError(
"Resuming with larger micro batch size is not allowed: "
f"{micro_batch_size} > {state.micro_batch_size}"
)
elif (
micro_batch_size < old_micro_batch_size
and old_micro_batch_size % micro_batch_size != 0
):
raise ValueError(
"Resuming with smaller micro batch size only allowed if the old "
f"micro batch size is a multiple of the new one: {micro_batch_size} < {state.micro_batch_size}"
)
batch_size_ratio = old_micro_batch_size // micro_batch_size
for worker_state in state.worker_states:
assert isinstance(worker_state, SavableDatasetCheckpoint)
# When resuming with a smaller micro batch size, the offset must be scaled
# up to the new micro batch size to skip the same number of samples as before.
worker_state.offset *= batch_size_ratio
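                # Example (illustrative): old micro batch size 8 and new micro batch
                # size 2 give batch_size_ratio 4; a stored offset of 3 old batches
                # becomes 12 new batches, i.e. the same 24 raw samples are skipped.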
self.dataset.restore_checkpoint(state.worker_states, worker_offset=state.next_worker_id)
# Initialize the worker-sample counters so that every worker owns a valid
# "last emitted sample" index. Workers that have not emitted anything yet keep
# the default value ``-1``.
assert isinstance(state.worker_states, list)
self._worker_sample_counters = [
(
ws.state.sample_index - 1
if (isinstance(ws, SavableDatasetCheckpoint) and ws.state is not None)
else -1
)
for ws in state.worker_states
]
self._next_worker_id = state.next_worker_id
@deprecated(
"`save_state` is deprecated and was renamed to `save_state_global` and will be removed "
"in a future update. If you actually do not want to gather the states to a rank, use "
"`save_state_rank` instead."
)
def save_state(self, dst_rank: int) -> Optional[Sequence[Optional[SavableDataLoaderState]]]:
"""Deprecated. Use `save_state_global` (or `save_state_rank`) instead."""
return self.save_state_global(dst_rank)
def save_state_global(
self, global_dst_rank: int
) -> Optional[Sequence[Optional[SavableDataLoaderState]]]:
"""
Saves the state of the dataset globally, collecting the state from all ranks using torch
distributed. Allows for restoring the state later using `restore_state_global`, given the
result of this method.
Typical scenario: Save the state to disk only on the `dst_rank`, the other ranks do not
save the state. Later, restore the state either only loaded on the `dst_rank` or
loading on all ranks separately using `restore_state_global`.
Note: If you want to save/restore the state per rank separately, use `save_state_rank` and
the corresponding `restore_state_rank`. Also, these do not rely on torch distributed.
Args:
global_dst_rank: The state will be gathered to this rank. The rank refers to the
global rank, not the rank within the data parallel group.
Returns:
The state of the dataset (or `None`, if not on `dst_rank`).
"""
# Fetch current rank's worker's state
merged_state = self.save_state_rank()
# Gather the merged states
if self.worker_config.world_size > 1:
output: Optional[Sequence[Optional[SavableDataLoaderState]]]
if self.worker_config.global_rank() == global_dst_rank:
output = [None] * self.worker_config.world_size
else:
# Check if the global_dst_rank is in the same group at all
if self.worker_config.data_parallel_group is not None:
try:
_ = torch.distributed.get_group_rank(
self.worker_config.data_parallel_group, global_dst_rank
)
except RuntimeError:
raise ValueError(
f"global_dst_rank {global_dst_rank} is not in the group of the current rank's worker config"
)
output = None
torch.distributed.gather_object(
merged_state,
output,
global_dst_rank,
group=self.worker_config.data_parallel_group,
)
return output
else:
# Not distributed -> return the merged state
return [merged_state]
@deprecated(
"`restore_state` was renamed to `restore_state_global` and will be removed in a future update."
)
def restore_state(
self,
state: Optional[Sequence[Optional[SavableDataLoaderState]]],
) -> None:
"""Deprecated. Use `restore_state_global` (or `restore_state_rank`) instead."""
return self.restore_state_global(state)
def restore_state_global(
self,
state: Optional[Sequence[Optional[SavableDataLoaderState]]],
*,
src_rank: Optional[int] = None,
) -> None:
"""
Restores the saved state from `save_state_global` (in torch distributed setup).
        The global state needs to be loaded on every rank that has a data loader instance.
Optionally, one can specify a src_rank and only provide the state once.
In case of multiple data parallel groups, you must provide the state once
in each data parallel group. In this case the `src_rank` is the rank within the
data parallel group.
Args:
state: The state to restore, as saved by `save_state_global`.
src_rank: The rank from which the state is broadcasted (within the data parallel group, if using DP groups).
"""
assert self._epoch_iterator is None, "Cannot restore state while workers are running"
# Only restore multi-rank if state is actually a list and we are in a distributed setup.
# Otherwise treat as single rank state.
if src_rank is None or self.worker_config.world_size == 1:
assert isinstance(state, list), "State must be a list in distributed setup"
assert len(state) == self.worker_config.world_size, (
"State must be a list of size world_size"
)
# All ranks have the state
# Select the state of the current rank
rank_state = state[self.worker_config.rank]
else:
if self.worker_config.data_parallel_group is not None:
# Only the src_rank has the state within this dp group
try:
global_src_rank = torch.distributed.get_global_rank(
self.worker_config.data_parallel_group, src_rank
)
except RuntimeError:
raise ValueError(
f"src_rank {src_rank} is not in the group of the current rank's worker config"
)
else:
# If no DP group is given, we assume the global rank is
# the same as the data parallel rank
global_src_rank = src_rank
if self.worker_config.rank != src_rank:
# Send the state to all other ranks
assert state is None
# Must still be a list of Nones
state = [None] * self.worker_config.world_size
else:
assert isinstance(state, list), "State must be a list in distributed setup"
assert len(state) == self.worker_config.world_size, (
"State must be a list of size world_size"
)
local_object = [None]
torch.distributed.scatter_object_list(
local_object,
state,
src=global_src_rank,
group=self.worker_config.data_parallel_group,
)
rank_state = local_object[0]
self.restore_state_rank(rank_state)
def can_restore_sample(self) -> bool:
return self.dataset.can_restore_sample()
def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T:
"""Restores a sample from a key. This is useful to debug the dataset."""
return self.dataset.restore_sample(restore_key)
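    # Restore-key sketch (illustrative structure): keys are nested tuples built by
    # add_sample_restore_key, e.g.
    #   ("SavableDatasetWrapper", <global_worker_id>, <sample_idx>, <inner key...>).
    # Assuming the sample object carries `__restore_key__` (set by
    # add_sample_restore_key), a sample can be re-produced deterministically:
    #
    #   sample = next(iter(loader))
    #   again = loader.restore_sample(sample.__restore_key__)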
def config(self):
"""Get the configuration, which defines the dataset. Useful in conjunction with `save_state`
and `restore_state` to match the configuration as well."""
return {
"type": type(self).__qualname__,
"num_workers": self.num_workers,
"persistent_workers": self.persistent_workers,
"pin_memory": self.pin_memory,
"prefetch_factor": None if self.num_workers == 0 else self.prefetch_factor,
"dataset": self.dataset.config(),
}
class BasicDataLoader(DataLoader[T], Generic[T]):
"""DataLoader that supports debugging the dataset without saving capability (e.g. for val/eval)."""
#: The worker config
worker_config: WorkerConfig
#: The wrapped dataset. For multiprocessing, this is a :class:`megatron.energon.SavableDatasetWrapper`
dataset: Union[SavableDatasetWrapper[T], SavableDataset[T]]
id: int
_sample_idx: int = 0
def __init__(
self,
dataset: SavableDataset[T],
gc_collect_every_n_steps: int = GC_DEFAULT_EVERY_N_ITER,
gc_freeze_at_start: bool = True,
prefetch_factor: int = 2,
cache_pool: Optional[CachePool] = None,
watchdog_timeout_seconds: Optional[float] = 60,
watchdog_initial_timeout_seconds: Optional[float] = None,
fail_on_timeout: bool = False,
):
"""
        Create the basic dataloader. It does not support saving or restoring the state.
Args:
dataset: The dataset to load.
            gc_collect_every_n_steps: The number of steps after which the garbage collector is
                invoked. As we usually handle few but large tensors here, and the Python garbage
                collector already tracks many objects from imports alone, this can reduce the
                memory footprint considerably and may even be necessary to avoid running out
                of memory.
            gc_freeze_at_start: If true, the garbage collector is frozen at the start of the worker
                processes. This greatly improves garbage collection performance. In rare cases it
                may cause issues and can be disabled; keep it enabled if you experience no issues.
cache_pool: If set, the cache pool to use for the dataset.
watchdog_timeout_seconds: The timeout in seconds. If None, the watchdog is disabled.
watchdog_initial_timeout_seconds: The initial timeout in seconds. If None, the timeout is the same as watchdog_timeout_seconds.
fail_on_timeout: If True, stops the whole process upon timeout, after printing a stack trace.
"""
self.worker_config = dataset.worker_config
self.id = SavableDataLoader.next_id()
dataset = WatchdogDataset(
dataset,
worker_config=self.worker_config,
timeout_seconds=watchdog_timeout_seconds,
initial_timeout_seconds=watchdog_initial_timeout_seconds,
fail_on_timeout=fail_on_timeout,
)
if gc_collect_every_n_steps > 0:
dataset = GcDataset(
dataset,
worker_config=self.worker_config,
every_n_iter=gc_collect_every_n_steps,
freeze=gc_freeze_at_start,
)
dataset = SimpleSavableDatasetWrapper(
dataset, worker_config=self.worker_config, cache_pool=cache_pool
)
self._worker_sample_counters = [0] * max(self.worker_config.num_workers, 1)
kwargs = {}
if self.worker_config.num_workers > 0:
            # These must not be specified for num_workers == 0
kwargs["persistent_workers"] = True
kwargs["prefetch_factor"] = prefetch_factor
seed_per_worker = [
self.worker_config.worker_seed(i) for i in range(self.worker_config.num_workers)
]
gc.collect() # This ensures that we don't include any old worker refs in the newly forked worker processes
super().__init__(
dataset,
batch_size=None,
shuffle=False,
num_workers=self.worker_config.num_workers,
pin_memory=True,
worker_init_fn=partial(_init_worker, seed_per_worker),
**kwargs,
)
if self.worker_config.should_log(level=1):
self.worker_config.worker_log(
{
"t": "BasicDataLoader.__init__",
"r": self.worker_config.rank,
"w": None,
"id": self.id,
"config": self.config(),
}
)
def __len__(self):
# We override this, because otherwise we'll see warnings
return self.dataset.len_rank()
def __iter__(self):
def _inner_generator(iterator):
iter_idx = 0
id = SavableDataLoader.next_id()
if self.worker_config.should_log(level=1):
self.worker_config.worker_log(
{
"t": "BasicDataLoader.iter",
"r": self.worker_config.rank,
"w": None,
"id": self.id,
"iter_id": id,
}
)
try:
for worker_id, sample_idx, sample in iterator:
if self.worker_config.should_log(level=1):
keys = default_get_keys(sample)
self.worker_config.worker_log(
{
**{
"t": "BasicDataLoader.yield",
"r": self.worker_config.rank,
"w": None,
"id": self.id,
"iter_id": self.id,
"worker_id": worker_id,
"worker_idx": sample_idx,
"idx": iter_idx,
"iter_idx": iter_idx,
"global_idx": self._sample_idx,
},
**({} if keys is None else {"keys": keys}),
}
)
self._sample_idx += 1
iter_idx += 1
yield sample
finally:
if self.worker_config.should_log(level=1):
self.worker_config.worker_log(
{
"t": "BasicDataLoader.StopIteration",
"r": self.worker_config.rank,
"w": None,
"id": self.id,
"iter_id": id,
}
)
return _inner_generator(super().__iter__())
def config(self):
"""Get the configuration, which defines the dataset. Useful in conjunction with `save_state`
and `restore_state` to match the configuration as well."""
return {
"type": type(self).__qualname__,
"num_workers": self.worker_config.num_workers,
"persistent_workers": self.persistent_workers,
"pin_memory": self.pin_memory,
"prefetch_factor": None if self.num_workers == 0 else self.prefetch_factor,
"dataset": self.dataset.config(),
}
def can_restore_sample(self) -> bool:
return self.dataset.can_restore_sample()
def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T:
"""Restores a sample from a key. This is useful to debug the dataset."""
return self.dataset.restore_sample(restore_key)
def _sample_str(self, sample):
"""Returns a human readable debug string for a single sample, also uniquely identifying it."""
import dataclasses
import hashlib
if isinstance(sample, torch.Tensor):
return f"Tensor(shape={sample.shape}, dtype={sample.dtype}, sha256={hashlib.sha256(sample.detach().cpu().numpy().tobytes()).hexdigest()!r})"
elif isinstance(sample, np.ndarray):
return f"ndarray(shape={sample.shape}, dtype={sample.dtype}, sha256={hashlib.sha256(sample.tobytes()).hexdigest()!r})"
elif isinstance(sample, (str, bytes)):
if len(sample) > 100:
return f"{sample[:100]!r}..."
return repr(sample)
elif isinstance(sample, (int, float)):
return repr(sample)
        elif isinstance(sample, dict):
            innerstr = ", ".join(f"{k}={self._sample_str(v)}" for k, v in sample.items())
            return f"dict({innerstr})"
        elif isinstance(sample, (list, tuple)):
            innerstr = ", ".join(self._sample_str(v) for v in sample)
            return f"{type(sample).__name__}({innerstr})"
        elif dataclasses.is_dataclass(sample):
            innerstr = ", ".join(
                f"{field.name}={self._sample_str(getattr(sample, field.name))}"
                for field in dataclasses.fields(sample)
            )
            return f"{type(sample).__name__}({innerstr})"
else:
return repr(sample)
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from dataclasses import dataclass
from typing import Any, Union
from megatron.energon.epathlib import EPath
@dataclass(frozen=True, slots=True, kw_only=True)
class SourceInfo:
"""Information about the source of a sample, i.e. where the data was loaded from."""
#: The path to the dataset
dataset_path: EPath
#: The index of the sample in the dataset
index: Union[str, int]
#: The name of the shard tar file
shard_name: str
#: The names of the files in the shard used to create the sample
file_names: tuple[str, ...]
def add_source_info(sample: Any, *source_infos: SourceInfo) -> Any:
"""Adds a source info to a sample."""
from megatron.energon.flavors.base_dataset import Sample
if isinstance(sample, Sample) or hasattr(sample, "__sources__"):
if sample.__sources__ is None:
sample.__sources__ = source_infos
else:
sample.__sources__ = (*sample.__sources__, *source_infos)
elif isinstance(sample, dict):
if sample.get("__sources__") is None:
sample["__sources__"] = source_infos
else:
sample["__sources__"] = (*sample["__sources__"], *source_infos)
return sample
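# Usage sketch (hypothetical values): recording where a sample came from. Multiple
# infos accumulate in the sample's `__sources__` tuple:
#
#   info = SourceInfo(
#       dataset_path=EPath("/data/mydataset"),
#       index=42,
#       shard_name="shard_00000.tar",
#       file_names=("000042.jpg", "000042.txt"),
#   )
#   sample = add_source_info(sample, info)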
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
class FlexState(dict):
    """Flexible state dictionary, used by :class:`Savable` objects to save and restore their state."""
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from megatron.energon.task_encoder.base import (
AugmentTaskEncoder,
Batch,
DefaultTaskEncoder,
TaskEncoder,
batch_list,
batch_pad_stack,
batch_stack,
generic_batch,
stateless,
)
from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys, cooker
from megatron.energon.task_encoder.loader import (
get_train_dataset,
get_val_dataset,
get_val_datasets,
)
__all__ = [
"AugmentTaskEncoder",
"basic_sample_keys",
"batch_list",
"batch_pad_stack",
"batch_stack",
"Batch",
"cooker",
"Cooker",
"DefaultTaskEncoder",
"generic_batch",
"get_train_dataset",
"get_val_dataset",
"get_val_datasets",
"stateless",
"TaskEncoder",
]
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import dataclasses
import functools
import inspect
from abc import ABC
from dataclasses import is_dataclass
from types import MethodType
from typing import (
Any,
Callable,
Dict,
Generator,
Generic,
Hashable,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
overload,
)
import torch
from typing_extensions import ParamSpec
from megatron.energon.cache import CachePool, DecodeFileStore, FileStore
from megatron.energon.edataclass import edataclass
from megatron.energon.flavors import (
CrudeSample,
CrudeWebdataset,
PinMemoryMixin,
Sample,
SampleDecoder,
SavableDataset,
)
from megatron.energon.flavors.base_dataset import ExtendableDataclassMixin
from megatron.energon.metadataset.loader_interface import DatasetBlendMode, LoadedDataset
from megatron.energon.rng import SystemRng
from megatron.energon.source_info import SourceInfo
from megatron.energon.task_encoder.cooking import Cooker
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers import (
BatchDataset,
BlendDataset,
ConcatDataset,
EpochizeDataset,
GroupBatchDataset,
LimitDataset,
LogSampleDataset,
MapDataset,
PackingDataset,
ShuffleBufferDataset,
)
from megatron.energon.wrappers.repeat_dataset import RepeatDataset
T = TypeVar("T")
V = TypeVar("V")
T_sample = TypeVar("T_sample", bound=Sample)
T_encoded_sample = TypeVar("T_encoded_sample")
T_raw_batch = TypeVar("T_raw_batch")
T_batch = TypeVar("T_batch")
FeatureBatcher = Callable[[List[Any]], Any]
def generic_batch(batch: List[Any]) -> Any:
"""Based on the types/shapes of the batch: Will either pad and stack, or return as list.
Recurses structures (dict, dataclass, namedtuple) and applies the same logic to each field."""
if isinstance(batch[0], torch.Tensor):
return batch_pad_stack(batch)
elif isinstance(batch[0], dict):
return {key: generic_batch([sample[key] for sample in batch]) for key in batch[0].keys()}
elif is_dataclass(batch[0]):
if hasattr(batch[0], "from_samples"):
# The dataclass defines a method for batching samples
return batch[0].from_samples(batch)
return type(batch[0])(
**{
field.name: generic_batch([getattr(sample, field.name) for sample in batch])
for field in dataclasses.fields(batch[0])
}
)
elif isinstance(batch[0], tuple) and hasattr(batch[0], "_fields"):
# NamedTuple
return type(batch[0])(
**{
field: generic_batch([getattr(sample, field) for sample in batch])
for field in batch[0]._fields
}
)
else:
return batch_list(batch)
def batch_stack(batch: List[Any]) -> Any:
"""Stack a batch of tensors."""
return torch.stack(batch, dim=0)
def batch_pad_stack(batch: List[Any]) -> Any:
"""Stack a batch of arbitrary-sized tensors padded with 0s."""
max_size = [max(b.shape[dim] for b in batch) for dim in range(batch[0].ndim)]
    # Pad all tensors to max_size and stack them into one tensor
    batch_tensor = batch[0].new_zeros((len(batch), *max_size))
    for i, b in enumerate(batch):
        batch_tensor[(i, *(slice(0, s) for s in b.shape))] = b
    return batch_tensor
def batch_list(batch: List[Any]) -> Any:
"""Stack a batch of tensors padded with 0s."""
return batch
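# --- Illustrative sketch (not part of the library) ---------------------------
# Demonstrates the default batchers defined above: `batch_pad_stack` zero-pads
# ragged tensors to the per-dimension maximum before stacking, and
# `generic_batch` recurses into dicts, putting non-tensor leaves into lists.
def _example_default_batchers() -> None:
    a, b = torch.ones(2, 3), torch.ones(4, 1)
    assert batch_pad_stack([a, b]).shape == (2, 4, 3)

    batch = generic_batch([{"x": a, "label": 1}, {"x": b, "label": 2}])
    assert batch["x"].shape == (2, 4, 3)
    assert batch["label"] == [1, 2]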
P = ParamSpec("P")
@overload
def stateless(
*, restore_seeds: bool = False, failure_tolerance: Optional[int] = None
) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
@overload
def stateless(fn: Callable[P, T]) -> Callable[P, T]: ...
def stateless(
fn: Optional[Callable[..., T]] = None,
*,
restore_seeds: bool = False,
failure_tolerance: Optional[int] = None,
) -> Union[Callable[[Callable[..., T]], Callable[..., T]], Callable[..., T]]:
"""Decorator to mark a function of the task encoder as restorable.
Args:
fn: The function to decorate.
restore_seeds: Whether to restore the seeds for the function. I.e. the seeds are set
from the sample index and the worker seed, such that they can be restored when a sample
is restored from that function.
failure_tolerance: The number of consecutive exceptions that are handled, after which a `FatalSampleError` is
raised for this function. Set to 0 to disable.
Usage:
.. code-block:: python
@stateless
def encode_sample(self, sample: T_sample) -> T_encoded_sample:
...
# Or if randomness is used (e.g. for augmentations):
@stateless(restore_seeds=True)
def encode_sample(self, sample: T_sample) -> T_encoded_sample:
...
"""
if fn is None:
return lambda f: stateless(
f, restore_seeds=restore_seeds, failure_tolerance=failure_tolerance
)
if restore_seeds:
worker_seed = None
@functools.wraps(fn)
def seed_wrapper_generator(self, *args, **kwargs):
nonlocal worker_seed
if worker_seed is None:
worker_seed = WorkerConfig.active_worker_config.worker_seed()
# Save the RNG states and set the new seed
outer_rng_state = SystemRng.save_state()
# Before constructing the generator and before the first
# iteration, set inner RNG based on seed computed
# from worker_seed and current sample index
SystemRng.seed_args(worker_seed, self.current_sample_index)
it = iter(fn(self, *args, **kwargs))
inner_rand_state = None
while True:
if inner_rand_state is not None:
# Restore inner random state before calling the generator
# This will not be done on the first iteration
SystemRng.restore_state(inner_rand_state)
try:
# Now call the generator. This will yield the sample
# But note it may also throw an exception or a StopIteration
sample = next(it)
# Save inner random state after calling the generator
inner_rand_state = SystemRng.save_state()
except StopIteration:
# We're stopping here, but the outer random state
# will be restored before returning (in finally below)
break
finally:
# Restore outer rand state before yielding or when an exception was raised
SystemRng.restore_state(outer_rng_state)
# Now yield the sample.
# This will give control back to the caller who may
# change the random state.
yield sample
# Save outer random state after yielding
outer_rng_state = SystemRng.save_state()
@functools.wraps(fn)
def seed_wrapper(self, *args, **kwargs):
nonlocal worker_seed
if worker_seed is None:
worker_seed = WorkerConfig.active_worker_config.worker_seed()
# Save the RNG states and set the new seed
rng_state = SystemRng.save_state()
SystemRng.seed_args(worker_seed, self.current_sample_index)
try:
return fn(self, *args, **kwargs)
finally:
# Restore the RNGs
SystemRng.restore_state(rng_state)
        if inspect.isgeneratorfunction(fn):
            wrapper = seed_wrapper_generator
        else:
            wrapper = seed_wrapper
        setattr(wrapper, "__stateless__", True)
        if failure_tolerance is not None:
            setattr(wrapper, "__failure_tolerance__", failure_tolerance)
        return wrapper
    setattr(fn, "__stateless__", True)
    if failure_tolerance is not None:
        setattr(fn, "__failure_tolerance__", failure_tolerance)
    return fn
def get_stateless(fn: Callable) -> bool:
"""Get whether a function is stateless."""
return getattr(fn, "__stateless__", False)
def get_failure_tolerance(
fn: Callable, default_failure_tolerance: Optional[int] = None
) -> Optional[int]:
"""Get the failure tolerance of a function."""
return getattr(fn, "__failure_tolerance__", default_failure_tolerance)
@edataclass
class Batch(PinMemoryMixin, ExtendableDataclassMixin):
"""Base class for a batch dataclass. Provides a default implementation for pinning memory.
Additionally, it provides a future safe implementation for creating an instance from another
batch `Batch.derive_from`."""
#: Uniquely identifies each sample in the dataset.
__key__: list[str]
#: Key for restoring the sample. This is used to restore the sample from a checkpoint. It
# should be a (nested) tuple of strings and integers, which can be used to index the dataset.
__restore_key__: Tuple[Union[str, int, tuple], ...]
#: A dataset may define a subflavors to distinguish between samples of the same sample type.
__subflavors__: Optional[list[Optional[Dict[str, Any]]]] = None
#: Information about the source of the sample, i.e. where the data was loaded from.
__sources__: Optional[tuple[SourceInfo, ...]] = None
@classmethod
def derive_from(cls: Type[T_batch], base_batch: "Batch", **kwargs) -> T_batch:
"""
Uses the base fields of `Batch` from base_batch (i.e. __key__, __restore_key__, __subflavors__, __sources__)
and creates a new batch with the kwargs as fields. This is useful for creating new batches, while keeping the
metadata of the base batch.
        Use like:
.. code-block:: python
def encode_batch(batch: RawBatch) -> Batch:
return Batch.derive_from(batch, field1=batch.field1 + 1)
Args:
base_batch: The base batch to copy the base fields / metadata from.
kwargs: The fields of the new batch.
Returns:
The new batch.
"""
base_kwargs = {
field.name: getattr(base_batch, field.name) for field in dataclasses.fields(Batch)
}
return cls(
**base_kwargs,
**kwargs,
)
@classmethod
def from_samples(cls: Type[T_batch], samples: Sequence[Sample], **kwargs) -> T_batch:
"""
Creates a batch from samples to be batched. Tensors will be padded and stacked, other types will be put into
lists. This is the default implementation for `Batch.from_samples`.
Args:
samples: The samples to batch.
kwargs: Additional (overriding) fields of the batch.
Returns:
The constructed batch.
"""
        assert all(dataclasses.is_dataclass(sample) for sample in samples), (
"Samples must be dataclasses"
)
# assert dataclasses.is_dataclass(cls), "Batch must be dataclass"
init_args = {}
fields = dataclasses.fields(cls)
for field in fields:
if field.name in kwargs:
init_args[field.name] = kwargs[field.name]
elif field.name == "__sources__":
if any(sample.__sources__ is not None for sample in samples):
# Special handling, needs flattening
init_args[field.name] = tuple(
source
for sample in samples
if sample.__sources__
for source in sample.__sources__
)
elif field.name == "__subflavors__":
if any(sample.__subflavors__ is not None for sample in samples):
init_args[field.name] = [
sample.__subflavors__ for sample in samples if sample.__subflavors__
]
else:
                value = [getattr(sample, field.name) for sample in samples]
                if len(value) > 0 and isinstance(value[0], torch.Tensor):
                    value = batch_pad_stack(value)
                init_args[field.name] = value
return cls(**init_args)
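# --- Illustrative sketch (not part of the library) ---------------------------
# A hypothetical typed batch, decorated with `@edataclass` like `Batch` itself:
# `from_samples` pads/stacks tensor fields and collects the metadata fields,
# while `derive_from` carries that metadata over to a newly constructed batch.
@edataclass
class _ExampleTokenBatch(Batch):
    tokens: torch.Tensor

def _example_derive(base: "_ExampleTokenBatch") -> "_ExampleTokenBatch":
    # Keeps __key__/__restore_key__/__subflavors__/__sources__ from `base`.
    return _ExampleTokenBatch.derive_from(base, tokens=base.tokens * 2)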
class TaskEncoder(ABC, Generic[T_sample, T_encoded_sample, T_raw_batch, T_batch]):
"""
Base class for task encoders.
Task encoding follows these steps:
0. Data comes from the dataset
1. :meth:`megatron.energon.TaskEncoder.encode_sample` / :meth:`megatron.energon.TaskEncoder.preencode_sample` is called on each sample
2. :meth:`megatron.energon.TaskEncoder.select_samples_to_pack` is called on the buffer of samples
3. :meth:`megatron.energon.TaskEncoder.postencode_sample` is called on each sample of the current pack
4. :meth:`megatron.energon.TaskEncoder.pack_selected_samples` is called on the selected sample pack
5. :meth:`megatron.energon.TaskEncoder.batch` is called on the list of encoded samples
6. :meth:`megatron.energon.TaskEncoder.encode_batch` is called on the batch
7. yield to main process
8. :meth:`megatron.energon.Batch.to_device` is called on the encoded batch
9. resulting encoded batch is passed to the network
"""
__default_failure_tolerance__: Optional[int] = 100
cookers: Sequence[Cooker[T_sample]] = ()
#: Internal: List of registered cookers. Will be the same as `cookers` after registering cookers.
_registered_cookers: List[Cooker[T_sample]]
#: The decoder to use for decoding samples. Set manually as needed to override options.
decoder: Optional[SampleDecoder] = SampleDecoder()
@stateless
def cook_crude_sample(
self,
sample: Union[T_sample, CrudeSample],
get_primary_aux: Callable[[], FileStore],
**aux: FileStore,
) -> T_sample:
"""
Cooks a crude sample.
Args:
sample: The sample to cook.
get_primary_aux: A function that returns the (cached) primary auxiliary dataset.
**aux: The auxiliary side dishes to use for cooking.
Returns: The cooked sample.
"""
if isinstance(sample, CrudeSample):
for cooker in self.cookers:
if cooker.is_match(sample):
assert get_stateless(cooker.cook), "Cooker must be stateless"
if not cooker.need_primary and not cooker.need_cache:
kwargs = aux
else:
kwargs: dict = {}
if cooker.need_primary:
kwargs["primary"] = get_primary_aux()
kwargs.update(aux)
if cooker.need_cache:
kwargs["cache"] = self.cache
return cooker.cook(sample, **kwargs)
raise NotImplementedError(
"You are using crude samples but not providing a way to cook them: "
f"Sample key={sample['__key__']}, subflavors={sample['__subflavors__']}, "
f"self.cookers={self.cookers}"
)
else:
assert isinstance(sample, Sample), "Sample must be a complete Sample or a CrudeSample"
return sample
def _is_overridden(
self, bound_method: Callable[..., Any], bases: Optional[Sequence[Type[Any]]] = None
) -> bool:
"""Check if a method is overridden by a subclass of the base class(es).
By default, only TaskEncoder is used as a base class.
This is mainly used for optimization purposes. If the default method
is a no-op, we can skip it entirely unless the user has overridden it.
Args:
bound_method: The method to check.
bases: The base classes to check against.
Returns:
True if the method is overridden outside of TaskEncoder, False otherwise.
"""
if not isinstance(bound_method, MethodType):
# If the method is not bound, it is always overridden
return True
# Get the underlying function
func = bound_method.__func__
# Check if the subclass method matches any of the base class methods
if bases is None:
bases = (TaskEncoder,)
return not any(getattr(base, func.__name__) is func for base in bases)
@stateless
def encode_sample(
self, sample: T_sample
) -> Union[T_encoded_sample, Generator[T_encoded_sample, None, None]]:
"""Encode a single sample. May raise :exc:`megatron.energon.SkipSample` to skip a sample.
Alternatively, this can be a generator that yields (or ignores) new samples.
If this is defined, :func:`preencode_sample` and :func:`postencode_sample` must not be defined.
"""
return sample
@stateless
def preencode_sample(
self, sample: T_sample
) -> Union[T_sample, Generator[T_sample, None, None]]:
"""Pre-encode a single sample. May raise :exc:`megatron.energon.SkipSample` to skip a sample.
Alternatively, this can be a generator that yields (or ignores) new samples.
Use in conjunction with packing and caching.
If this is defined, :func:`encode_sample` must not be defined.
"""
return sample
@stateless
def postencode_sample(
self, sample: T_sample
) -> Union[T_encoded_sample, Generator[T_encoded_sample, None, None]]:
"""Post-encode a single sample. May raise :exc:`megatron.energon.SkipSample` to skip a sample.
Alternatively, this can be a generator that yields (or ignores) new samples.
Use in conjunction with packing and caching.
If this is defined, :func:`encode_sample` must not be defined.
"""
return sample
@stateless
def batch(self, samples: List[T_encoded_sample]) -> T_raw_batch:
"""Move a batch to a device. May raise :exc:`megatron.energon.SkipSample` to skip a batch."""
return self._batch(samples, type(samples[0]))
def batch_group_criterion(self, sample: T_encoded_sample) -> Tuple[Hashable, Optional[int]]:
"""
Return a group criterion for the sample. Default implementation does not group
(effectively, it returns a single value `(None, None)`, thus only one group is used).
Returns the key of the bucket to put this sample into, and the size of the bucket (=batch size).
The bucket size must always be the same for the same bucket key.
        May raise :exc:`megatron.energon.SkipSample` to skip a sample.
"""
return None, None
@stateless
def encode_batch(self, batch: T_raw_batch) -> Union[T_batch, Generator[T_batch, None, None]]:
"""Encode a batch of samples. May raise :exc:`megatron.energon.SkipSample` to skip a batch.
Alternatively, this can be a generator that yields (or ignores) new batches."""
return batch
def _batch(
self,
samples: List[T_encoded_sample],
result_type: Type[T_raw_batch],
actions: Optional[Dict[str, FeatureBatcher]] = None,
default_action: FeatureBatcher = generic_batch,
) -> T_raw_batch:
"""
Batch a list of samples.
Args:
samples: The samples to batch
result_type: Type of the result (might be dict, dataclass, or namedtuple)
actions: For each field (=key), may specify a specific batcher
            default_action: The batcher to apply to all fields not in `actions`
Returns:
The batched result
"""
if dataclasses.is_dataclass(result_type) and hasattr(result_type, "from_samples"):
return result_type.from_samples(samples)
# Get dict of samples
if isinstance(samples[0], dict):
list_samples = {key: [sample[key] for sample in samples] for key in samples[0].keys()}
elif is_dataclass(samples[0]):
list_samples = {
field.name: [getattr(sample, field.name) for sample in samples]
for field in dataclasses.fields(samples[0])
}
elif isinstance(samples[0], tuple) and hasattr(samples[0], "_fields"):
# NamedTuple
list_samples = {
field: [getattr(sample, field) for sample in samples]
for field in samples[0]._fields
}
else:
raise ValueError("Unrecognized sample type.")
# Convert each field
if actions is not None:
list_samples = {
key: default_action(value) if key not in actions else actions[key](value)
for key, value in list_samples.items()
}
else:
list_samples = {key: default_action(value) for key, value in list_samples.items()}
# Construct result
if issubclass(result_type, dict):
return list_samples
elif dataclasses.is_dataclass(result_type) or issubclass(result_type, tuple):
# DataClass or NamedTuple
return result_type(**list_samples)
else:
raise ValueError("Unrecognized result type.")
def select_samples_to_pack(
self, samples: List[T_encoded_sample]
) -> List[List[T_encoded_sample]]:
"""
For packing, selects the samples to be packed together.
Packing is only active when packing_buffer_size is set.
Internally this stage is called "pre_packing".
Args:
samples: The samples to pre-pack. A full buffer will be passed into the function.
Returns: The pre-packed samples as a list of lists of samples.
"""
raise NotImplementedError("Packing only effective when overridden.")
def pack_selected_samples(self, samples: List[T_encoded_sample]) -> T_encoded_sample:
"""
Given one set of samples to pack, returns the final packed sample.
Packing is only active when packing_buffer_size is set.
Internally this stage is called "final_packing".
Args:
samples: The samples to pack into a single sample
Returns: The final packed sample.
"""
raise NotImplementedError("Packing only effective when overridden.")
def build_batch(
self,
dataset: SavableDataset[T_encoded_sample],
*,
batch_size: Optional[int],
batch_drop_last: bool = False,
packing_buffer_size: Optional[int] = None,
worker_config: WorkerConfig,
) -> SavableDataset[T_raw_batch]:
"""Applies the batcher to the dataset."""
dataset: SavableDataset[Any]
if packing_buffer_size is not None:
select_samples_to_pack_provided = self._is_overridden(self.select_samples_to_pack)
pack_selected_samples_provided = self._is_overridden(self.pack_selected_samples)
assert select_samples_to_pack_provided and pack_selected_samples_provided, (
"Both select_samples_to_pack and pack_selected_samples methods must be provided in the TaskEncoder when using packing_buffer_size"
)
if self._is_overridden(self.postencode_sample):
post_encode_fn = self.postencode_sample
else:
post_encode_fn = None
dataset = PackingDataset(
dataset,
buffer_size=packing_buffer_size,
pre_packer=self.select_samples_to_pack,
final_packer=self.pack_selected_samples,
final_packer_stateless=get_stateless(self.pack_selected_samples),
sample_encoder=post_encode_fn,
sample_encoder_stateless=True
if post_encode_fn is None
else get_stateless(post_encode_fn),
worker_config=worker_config,
pre_packer_failure_tolerance=get_failure_tolerance(
self.select_samples_to_pack, self.__default_failure_tolerance__
),
final_packer_failure_tolerance=get_failure_tolerance(
self.pack_selected_samples, self.__default_failure_tolerance__
),
sample_encoder_failure_tolerance=None
if post_encode_fn is None
else get_failure_tolerance(post_encode_fn, self.__default_failure_tolerance__),
)
elif self._is_overridden(self.postencode_sample):
dataset = MapDataset(
dataset,
self.postencode_sample,
worker_config=worker_config,
stateless_map_fn=get_stateless(self.postencode_sample),
failure_tolerance=get_failure_tolerance(
self.postencode_sample, self.__default_failure_tolerance__
),
)
if self._is_overridden(self.batch_group_criterion):
dataset = GroupBatchDataset(
dataset,
fixed_batch_size=batch_size,
sample_group_key=self.batch_group_criterion,
batcher=self.batch,
drop_last=batch_drop_last,
worker_config=worker_config,
failure_tolerance=get_failure_tolerance(
self.batch, self.__default_failure_tolerance__
),
)
if self._is_overridden(self.encode_batch):
dataset = MapDataset(
dataset,
self.encode_batch,
worker_config=worker_config,
stateless_map_fn=get_stateless(self.encode_batch),
failure_tolerance=get_failure_tolerance(
self.encode_batch, self.__default_failure_tolerance__
),
)
else:
# No grouping is active
if batch_size is not None:
dataset = BatchDataset(
dataset,
batch_size=batch_size,
batcher=self.batch,
batcher_stateless=get_stateless(self.batch),
drop_last=batch_drop_last,
worker_config=worker_config,
failure_tolerance=get_failure_tolerance(
self.batch, self.__default_failure_tolerance__
),
)
if self._is_overridden(self.encode_batch):
dataset = MapDataset(
dataset,
self.encode_batch,
worker_config=worker_config,
stateless_map_fn=get_stateless(self.encode_batch),
failure_tolerance=get_failure_tolerance(
self.encode_batch, self.__default_failure_tolerance__
),
)
return dataset
def build_cook_crude_sample(
self,
dataset: SavableDataset[Union[T_sample, dict]],
*,
worker_config: WorkerConfig,
subflavors: Dict[str, Any],
get_primary_aux: Callable[[], FileStore],
aux: Optional[Dict[str, FileStore]] = None,
) -> SavableDataset[T_sample]:
"""Applies the sample cooker to the dataset if we have cookers registered."""
assert self.cookers, "No cookers registered, but got crude dataset."
if aux is not None and self.decoder is not None:
aux = {k: DecodeFileStore(v, decoder=self.decoder) for k, v in aux.items()}
# Cache the primary auxiliary dataset for this dataset, i.e. construct it once when needed
primary_aux = None
def _get_primary_aux():
nonlocal primary_aux
if primary_aux is None:
try:
if aux is not None:
primary_aux = aux.get("primary")
if primary_aux is None:
primary_aux = get_primary_aux()
assert primary_aux is not None, "Primary auxiliary dataset must always exist"
if self.decoder is not None:
primary_aux = DecodeFileStore(primary_aux, decoder=self.decoder)
except Exception as e:
                    # Let the exception propagate for the sample being loaded
raise SystemError("Error getting primary auxiliary dataset") from e
return primary_aux
if aux is not None:
cook_fn = functools.partial(
self.cook_crude_sample, get_primary_aux=_get_primary_aux, **aux
)
else:
cook_fn = functools.partial(self.cook_crude_sample, get_primary_aux=_get_primary_aux)
return MapDataset(
dataset,
cook_fn,
worker_config=worker_config,
stateless_map_fn=True,
map_fn_config=dict(
cookers=[
dict(
cook=SavableDataset._function_config(cooker.cook),
has_subflavors=cooker.has_subflavors,
)
for cooker in self.cookers
],
subflavors=subflavors,
),
failure_tolerance=get_failure_tolerance(cook_fn, self.__default_failure_tolerance__),
)
def _load_dataset(
self, dataset: LoadedDataset, worker_rotation_offset: int, worker_config: WorkerConfig
) -> SavableDataset[T_sample]:
"""Loads a train dataset, optionally cooking the samples."""
if dataset.dataset.__sample_type__ == CrudeSample:
return self.build_cook_crude_sample(
dataset.dataset.build(worker_rotation_offset=worker_rotation_offset),
worker_config=worker_config,
subflavors=dataset.dataset.subflavors,
get_primary_aux=dataset.dataset.as_file_store,
aux=dataset.aux,
)
else:
assert dataset.aux is None, "Aux is not supported for non-crude datasets."
return dataset.dataset.build(worker_rotation_offset=worker_rotation_offset)
def build_encode_sample(
self,
dataset: SavableDataset[T_sample],
*,
worker_config: WorkerConfig,
) -> SavableDataset[T_encoded_sample]:
"""Applies the sample encoder to the dataset."""
if self._is_overridden(self.preencode_sample):
pre_encode_fn = self.preencode_sample
assert not self._is_overridden(
self.encode_sample, bases=(TaskEncoder, DefaultTaskEncoder)
), "Cannot have both pre- and post-encode functions defined."
elif self._is_overridden(self.encode_sample):
pre_encode_fn = self.encode_sample
else:
pre_encode_fn = None
if pre_encode_fn is not None:
dataset = MapDataset(
dataset,
pre_encode_fn,
worker_config=worker_config,
stateless_map_fn=get_stateless(pre_encode_fn),
failure_tolerance=get_failure_tolerance(
pre_encode_fn, self.__default_failure_tolerance__
),
)
return dataset
def build_train_datasets(
self,
*,
datasets: List[LoadedDataset],
worker_config: WorkerConfig,
batch_size: Optional[int],
batch_drop_last: bool = False,
packing_buffer_size: Optional[int] = None,
virtual_epoch_length: int = 0,
shuffle_buffer_size: Optional[int] = None,
blend_mode: DatasetBlendMode = DatasetBlendMode.NONE,
repeat: bool = True,
) -> SavableDataset[T_batch]:
"""Combines train datasets to a single dataset."""
# Check if there's a CrudeWebdataset but no cookers
for dataset in datasets:
if isinstance(dataset.dataset, CrudeWebdataset):
assert self.cookers, "CrudeWebdataset found, but no cookers registered."
        global_workers = max(1, worker_config.num_workers) * worker_config.world_size
        # The worker rotation offset of each dataset is the cumulative sum of the preceding
        # dataset lengths, modulo the global number of workers.
        rotation_lengths = [len(dataset.dataset) for dataset in datasets]
        for i in range(1, len(rotation_lengths)):
            rotation_lengths[i] += rotation_lengths[i - 1]
        worker_rotation_offsets = [
            rotation_length % global_workers for rotation_length in [0] + rotation_lengths[:-1]
        ]
if blend_mode == DatasetBlendMode.DATASET_WEIGHT:
            assert repeat, (
                "Cannot blend datasets by weight when repeat=False. Without repeating, datasets "
                "may only use sample repetitions or no blend mode."
            )
inner_datasets = [
(
RepeatDataset(
self._load_dataset(
dataset, worker_rotation_offset, worker_config=worker_config
),
worker_config=worker_config,
),
1.0 if dataset.weight is None else float(dataset.weight),
)
for dataset, worker_rotation_offset in zip(datasets, worker_rotation_offsets)
]
# Already repeating the inner datasets, so no need to repeat again
repeat = False
elif blend_mode == DatasetBlendMode.SAMPLE_REPETITIONS or (
not repeat and blend_mode == DatasetBlendMode.NONE
):
inner_datasets = [
(
(
self._load_dataset(
dataset, worker_rotation_offset, worker_config=worker_config
)
if dataset.repetitions is None or dataset.repetitions == 1
else RepeatDataset(
self._load_dataset(
dataset, worker_rotation_offset, worker_config=worker_config
),
repeats=dataset.repetitions,
worker_config=worker_config,
)
),
len(dataset.dataset)
* (1 if dataset.repetitions is None else dataset.repetitions),
)
for dataset, worker_rotation_offset in zip(datasets, worker_rotation_offsets)
]
else:
inner_datasets = [
(
RepeatDataset(
self._load_dataset(
dataset, worker_rotation_offset, worker_config=worker_config
),
worker_config=worker_config,
),
1.0,
)
for dataset, worker_rotation_offset in zip(datasets, worker_rotation_offsets)
]
# Already repeating the inner datasets, so no need to repeat again
repeat = False
        if len(inner_datasets) > 1:
            dataset = BlendDataset(
*[inner_dataset[:2] for inner_dataset in inner_datasets],
worker_config=worker_config,
)
elif len(datasets) == 1:
dataset = inner_datasets[0][0]
else:
raise ValueError("No datasets given.")
if repeat:
# Still need to repeat the dataset
dataset = RepeatDataset(dataset, worker_config=worker_config)
if shuffle_buffer_size is not None and shuffle_buffer_size > 1:
dataset = ShuffleBufferDataset(
dataset,
size=shuffle_buffer_size,
worker_config=worker_config,
)
dataset = self.build_encode_sample(dataset, worker_config=worker_config)
dataset = self.build_batch(
dataset,
batch_size=batch_size,
batch_drop_last=batch_drop_last,
packing_buffer_size=packing_buffer_size,
worker_config=worker_config,
)
if virtual_epoch_length > 0:
dataset = EpochizeDataset(
dataset,
length=virtual_epoch_length,
worker_config=worker_config,
)
if worker_config.should_log(level=1):
dataset = LogSampleDataset(dataset, mode="train", worker_config=worker_config)
return dataset
def build_val_datasets(
self,
*,
datasets: List[LoadedDataset],
worker_config: WorkerConfig,
batch_size: int,
batch_drop_last: bool = False,
packing_buffer_size: Optional[int] = None,
limit: Optional[int] = None,
) -> SavableDataset[T_batch]:
"""Combines val datasets to a single dataset."""
# Check if there's a CrudeWebdataset but no cookers
for dataset in datasets:
            if isinstance(dataset.dataset, CrudeWebdataset):
assert self.cookers, "CrudeWebdataset found, but no cookers registered."
global_workers = max(1, worker_config.num_workers) * worker_config.world_size
rotation_lengths = [len(dataset.dataset) for dataset in datasets]
for i in range(1, len(rotation_lengths)):
rotation_lengths[i] += rotation_lengths[i - 1]
worker_rotation_offsets = [
rotation_length % global_workers for rotation_length in [0] + rotation_lengths[:-1]
]
if len(datasets) > 1:
dataset = ConcatDataset(
*[
self._load_dataset(dataset, worker_rotation_offset, worker_config)
for dataset, worker_rotation_offset in zip(datasets, worker_rotation_offsets)
],
worker_config=worker_config,
)
elif len(datasets) == 1:
dataset = self._load_dataset(datasets[0], worker_rotation_offsets[0], worker_config)
else:
raise ValueError("No datasets given.")
dataset = self.build_encode_sample(dataset, worker_config=worker_config)
dataset = self.build_batch(
dataset,
batch_size=batch_size,
batch_drop_last=batch_drop_last,
packing_buffer_size=packing_buffer_size,
worker_config=worker_config,
)
if limit is not None and limit > 0:
dataset = LimitDataset(
dataset,
length=limit,
worker_config=worker_config,
reset_after_epoch=True,
)
if worker_config.should_log(level=2):
dataset = LogSampleDataset(dataset, mode="val", worker_config=worker_config)
return dataset
@property
def current_batch_index(self) -> int:
"""Returns the current index for the next batch yielded from the current worker. Each batch
on the current rank will get a strictly increasing unique number. Counting happens on each
rank separately (i.e. each rank will get the same numbers for same batch index)."""
assert WorkerConfig.active_worker_config is not None, (
"The batch_index can only be fetched within the worker, and to be usable, you must use the get_(savable_)loader methods provided from the package."
)
return WorkerConfig.active_worker_config.active_worker_batch_index
@property
def current_sample_index(self) -> int:
"""Returns the current index for the next sample yielded from the current routine (e.g.
for `encode_sample`, `batch`, or `encode_batch`). Each routine will get a number
representing the number of calls to that function. Across workers, this number will be
        unique, but it is not synced across workers, thus it may increase at different rates (e.g.
if batching does not work the same for all batches). When restoring a sample, this number is
also restored and can be relied on for deterministic randomness reproduction of a sample."""
assert WorkerConfig.active_worker_config is not None, (
"The batch_index can only be fetched within the worker, and to be usable, you must use the get_(savable_)loader methods provided from the package."
)
return WorkerConfig.active_worker_config.active_worker_sample_index
@property
def cache(self) -> CachePool:
"""Returns the cache pool to use for caching out sample data to disk (for use with cookers / aux file stores).
This is set and configured externally by the loader."""
assert WorkerConfig.active_worker_config is not None, (
"The cache can only be fetched within the worker, and to be usable, you must use the get_(savable_)loader methods provided from the package."
)
assert WorkerConfig.active_worker_config._cache_pool is not None, (
"Cache pool must be set by the loader."
)
return WorkerConfig.active_worker_config._cache_pool
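# --- Illustrative sketch (not part of the library) ---------------------------
# A minimal task encoder that buckets samples by image orientation so landscape
# and portrait images never share a batch. The `image` field and the bucket
# size of 32 are assumptions of this sketch; the criterion must return the same
# size for the same bucket key, and the loader's `batch_size` should typically
# be None when the size is determined here.
class _ExampleBucketingEncoder(TaskEncoder):
    def batch_group_criterion(self, sample) -> Tuple[Hashable, Optional[int]]:
        h, w = sample.image.shape[-2:]
        return ("landscape" if w >= h else "portrait"), 32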
class DefaultTaskEncoder(
TaskEncoder[T_sample, T_encoded_sample, T_raw_batch, T_batch],
ABC,
Generic[T_sample, T_encoded_sample, T_raw_batch, T_batch],
):
"""
The default task encoder supports automagically mapping to target types.
You may override any methods to customize the behavior. By default, `encode_sample` is the
    identity function, `batch` calls `_batch` with the type of the first sample, and `encode_batch`
    is also the identity function. If you set any of `encoded_sample_type`, `raw_batch_type` or
    `batch_type`, the corresponding methods return that type, automatically mapping the fields
    (by name) to your new type.
"""
_encoded_sample_type: Optional[Type[T_encoded_sample]]
_raw_batch_type: Optional[Type[T_raw_batch]]
_batch_type: Optional[Type[T_batch]]
def __init__(
self,
*,
encoded_sample_type: Optional[Type[T_encoded_sample]] = None,
raw_batch_type: Optional[Type[T_raw_batch]] = None,
batch_type: Optional[Type[T_batch]] = None,
):
"""
Initialize the default task encoder.
Types may be:
- A `@dataclass` class: Return that typed dataclass. Field names must match the input
fields.
- A `NamedTuple` class: Return that typed namedtuple. Field names must match the input
fields.
- `dict`: Simply return the input as dict with field names as keys.
Args:
encoded_sample_type: Type of encoded samples (before batching)
raw_batch_type: Type of the batched samples (after batching)
batch_type: Type of the encoded batched samples
"""
self._encoded_sample_type = encoded_sample_type
self._raw_batch_type = raw_batch_type
self._batch_type = batch_type
@stateless
def encode_sample(
self, sample: T_sample
) -> Union[T_encoded_sample, Generator[T_encoded_sample, None, None]]:
"""Encode a single sample. The default implementation converts to the
_encoded_sample_type."""
if self._encoded_sample_type is None or isinstance(sample, self._encoded_sample_type):
return sample
if is_dataclass(sample):
fields = {
field.name: getattr(sample, field.name) for field in dataclasses.fields(sample)
}
elif isinstance(sample, tuple) and hasattr(sample, "_fields"):
fields = {field: getattr(sample, field) for field in sample._fields}
elif isinstance(sample, dict):
fields = sample
else:
raise ValueError("Unrecognized sample type.")
if issubclass(self._encoded_sample_type, dict):
return fields
elif dataclasses.is_dataclass(self._encoded_sample_type) or issubclass(
self._encoded_sample_type, tuple
):
# DataClass or NamedTuple
return self._encoded_sample_type(**fields)
else:
raise ValueError("Unrecognized encoded sample type.")
@stateless
def batch(self, samples: List[T_encoded_sample]) -> T_raw_batch:
"""Batch a list of samples. The default implementation uses default batching to convert
to _batch_type."""
actions = None
if isinstance(samples[0], Sample):
actions = {
"__subflavors__": lambda x: x,
}
return self._batch(
samples,
type(samples[0]) if self._raw_batch_type is None else self._raw_batch_type,
actions=actions,
)
@stateless
def encode_batch(self, batch: T_raw_batch) -> Union[T_batch, Generator[T_batch, None, None]]:
"""Encode a batch of samples. The default implementation converts to the
        _batch_type."""
if self._batch_type is None or self._raw_batch_type == self._batch_type:
return batch
if is_dataclass(batch):
fields = {field.name: getattr(batch, field.name) for field in dataclasses.fields(batch)}
elif isinstance(batch, tuple) and hasattr(batch, "_fields"):
fields = {field: getattr(batch, field) for field in batch._fields}
elif isinstance(batch, dict):
fields = batch
else:
raise ValueError("Unrecognized sample type.")
if issubclass(self._batch_type, dict):
return fields
elif dataclasses.is_dataclass(self._batch_type) or issubclass(self._batch_type, tuple):
# DataClass or NamedTuple
return self._batch_type(**fields)
else:
raise ValueError("Unrecognized encoded sample type.")
class AugmentTaskEncoder(
TaskEncoder[T_sample, T_encoded_sample, T_raw_batch, T_batch],
Generic[T_sample, T_encoded_sample, T_raw_batch, T_batch],
):
"""Augment a task encoder with additional functionality. By default, delegates everything to the
original task encoder."""
def __init__(self, task_encoder: TaskEncoder[T_sample, T_encoded_sample, T_raw_batch, T_batch]):
"""Initialize the augmenting task encoder.
Args:
task_encoder: The delegate task encoder. All calls will by default be forwarded to this.
"""
self._task_encoder = task_encoder
@property
def decoder(self) -> SampleDecoder:
return self._task_encoder.decoder
def encode_sample(self, sample: T_sample) -> T_encoded_sample:
return self._task_encoder.encode_sample(sample)
def batch(self, samples: List[T_encoded_sample]) -> T_raw_batch:
return self._task_encoder.batch(samples)
def encode_batch(self, batch_data: T_raw_batch) -> T_batch:
return self._task_encoder.encode_batch(batch_data)
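# --- Illustrative sketch (not part of the library) ---------------------------
# A hypothetical packing encoder: `select_samples_to_pack` greedily groups the
# buffered samples so the summed token count per pack stays under a budget, and
# `pack_selected_samples` concatenates each pack into one sample. Assumes the
# sample type exposes a 1-D `tokens` tensor; packing is activated by passing
# `packing_buffer_size` to the loader.
class _ExamplePackingEncoder(DefaultTaskEncoder):
    max_pack_tokens = 4096  # hypothetical budget

    def select_samples_to_pack(self, samples):
        packs, current, current_len = [], [], 0
        for sample in samples:
            n = sample.tokens.shape[0]
            if current and current_len + n > self.max_pack_tokens:
                packs.append(current)
                current, current_len = [], 0
            current.append(sample)
            current_len += n
        if current:
            packs.append(current)
        return packs

    @stateless
    def pack_selected_samples(self, samples):
        # Concatenate the token streams of one pack into a single sample.
        return dataclasses.replace(
            samples[0], tokens=torch.cat([s.tokens for s in samples])
        )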
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import dataclasses
import functools
from dataclasses import dataclass
from typing import Callable, Generic, Optional, TypeVar, Union, overload
from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.crude import CrudeSample
from megatron.energon.source_info import SourceInfo
T_sample = TypeVar("T_sample", bound=Sample, covariant=True)
F = TypeVar("F", bound=Callable[..., Sample])
@overload
def cooker(
fn: None = None,
) -> Callable[[F], F]: ...
@overload
def cooker(
*,
need_cache: bool = False,
need_primary: bool = False,
) -> Callable[[F], F]: ...
def cooker(
fn: Optional[F] = None,
*,
need_cache: bool = False,
need_primary: bool = False,
) -> Union[
F,
Callable[[F], F],
]:
"""Decorator to mark a function as a cooker, optionally enabling cache and primary dataset
arguments."""
if fn is None:
return functools.partial(
cooker,
need_cache=need_cache,
need_primary=need_primary,
)
@functools.wraps(fn)
def fn_wrapper(*args, **kwargs):
return fn(*args, **kwargs)
setattr(fn_wrapper, "__cooker_need_cache__", need_cache)
setattr(fn_wrapper, "__cooker_need_primary__", need_primary)
return fn_wrapper
def get_cooker_need_cache(fn: Callable[..., T_sample]) -> bool:
"""Get whether a function is a cooker."""
return getattr(fn, "__cooker_need_cache__", False)
def get_cooker_need_primary(fn: Callable[..., T_sample]) -> bool:
"""Get whether a function is a cooker."""
return getattr(fn, "__cooker_need_primary__", False)
@dataclass
class Cooker(Generic[T_sample]):
"""A cooker transforms a crude sample (simple dict) into a specific sample type inheriting
from `Sample`.
The `cook` method performs the transformation, the other fields are used to select the
samples which this cooker can transform. If no filters are provided, the cooker will transform
any `CrudeSample`.
"""
#: The callable that performs the cooking (i.e. loading / transforming the crude sample).
# Signature is:
    # `(/, raw_sample: dict, *, primary?: FileStore, **aux: FileStore, cache?: CachePool) -> Sample`.
    # `primary` is passed only if the cook function was decorated with `need_primary=True`.
    # `cache` is passed only if the cook function was decorated with `need_cache=True`.
cook: Callable[..., T_sample]
#: The subflavors to be present in the sample to be cooked by this cooker. All keys and values
# must match.
has_subflavors: Optional[dict] = None
@property
def need_primary(self) -> bool:
return get_cooker_need_primary(self.cook)
@property
def need_cache(self) -> bool:
return get_cooker_need_cache(self.cook)
def is_match(self, crude_sample: CrudeSample) -> bool:
if self.has_subflavors is not None:
# Checks if the dict entries provided as a filter all match
# the ones in the sample. The sample may have additional entries.
for k, v in self.has_subflavors.items():
if (
k not in crude_sample["__subflavors__"]
or crude_sample["__subflavors__"][k] != v
):
return False
return True
def basic_sample_keys(
crude_sample: dict, additional_source_info: tuple[SourceInfo, ...] = ()
) -> dict:
"""A convenience helper to extract the basic keys from a crude sample,
which you will always need to forward to the cooked sample."""
res = {
field.name: crude_sample[field.name]
for field in dataclasses.fields(Sample)
if field.name in crude_sample
}
if additional_source_info:
res["__sources__"] = (*crude_sample["__sources__"], *additional_source_info)
return res
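# --- Illustrative sketch (not part of the library) ---------------------------
# A hypothetical cooker turning a crude dict sample into a typed sample.
# `_ExampleTextSample` and the "txt" entry are assumptions of this sketch;
# `basic_sample_keys` forwards the metadata fields every cooked sample needs.
# When dispatched via `TaskEncoder.cook_crude_sample`, the cook function must
# also be stateless (see `stateless` in megatron.energon.task_encoder.base).
from megatron.energon.edataclass import edataclass

@edataclass
class _ExampleTextSample(Sample):
    text: str

@cooker
def _cook_text(sample: dict) -> _ExampleTextSample:
    return _ExampleTextSample(
        **basic_sample_keys(sample),
        text=sample["txt"].decode("utf-8"),
    )

# A task encoder would register it as, e.g.:
#   cookers = [Cooker(_cook_text, has_subflavors={"kind": "text"})]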
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from pathlib import Path
from typing import Any, List, Literal, Optional, Tuple, TypeVar, Union
from megatron.energon.epathlib import EPath
from megatron.energon.errors import warn_deprecated
from megatron.energon.flavors import BaseCoreDatasetFactory, SavableDataset
from megatron.energon.flavors.webdataset.sample_decoder import SampleDecoder
from megatron.energon.metadataset import load_dataset
from megatron.energon.task_encoder.base import DefaultTaskEncoder, TaskEncoder, WorkerConfig
T = TypeVar("T", covariant=True)
def _split_kwargs(kwargs: dict) -> dict:
loader_kwargs = {}
try:
loader_kwargs["split_part"] = kwargs.pop("split_part")
except KeyError:
pass
try:
loader_kwargs["dataset_config"] = kwargs.pop("dataset_config")
except KeyError:
pass
try:
loader_kwargs["split_config"] = kwargs.pop("split_config")
except KeyError:
pass
return loader_kwargs
def _split_deprecated_decoder_kwargs(kwargs: dict, task_encoder: TaskEncoder) -> None:
"""
auto_decode: bool = True,
image_decode: ImageDecoder = "torchrgb",
ignore_decoder_errors: bool = False,
av_decode: AVDecoder = "AVDecoder",
video_decode_audio: bool = False,
"""
auto_decode = True
decoder_kwargs = {}
if "auto_decode" in kwargs:
auto_decode = kwargs.pop("auto_decode")
if "image_decode" in kwargs:
decoder_kwargs["image_decode"] = kwargs.pop("image_decode")
if "av_decode" in kwargs:
decoder_kwargs["av_decode"] = kwargs.pop("av_decode")
if "video_decode_audio" in kwargs:
decoder_kwargs["video_decode_audio"] = kwargs.pop("video_decode_audio")
if not auto_decode:
task_encoder.decoder = None
elif len(decoder_kwargs) > 0:
warn_deprecated(
"The following decoder kwargs are deprecated and will be removed in a future version: "
+ ", ".join(decoder_kwargs.keys())
+ ". Instead, set the decoder directly in your task encoder."
)
if (
hasattr(task_encoder, "decoder")
and task_encoder.decoder is not None
and task_encoder.decoder is not DefaultTaskEncoder.decoder
):
# The task encoder already has a decoder set.
# The user might be reusing the task encoder in multiple calls to get_train_dataset
# and get_val_dataset.
# We need to check if the decoder is the same as the one we are setting here.
# If it is, we can return.
if task_encoder.decoder.config() == SampleDecoder(**decoder_kwargs).config():
# It's the same decoder, nothing to do.
return
else:
raise ValueError(
"Task encoder already has a decoder, and you are setting a different decoder, which is not allowed."
)
assert (
not hasattr(task_encoder, "decoder")
or task_encoder.decoder is DefaultTaskEncoder.decoder
), "Task encoder already has a decoder, and setting using deprecated kwargs is not allowed."
task_encoder.decoder = SampleDecoder(**decoder_kwargs)
def get_train_dataset(
path: Union[str, EPath, Path],
*,
split_part: Union[Literal["train"], str] = "train",
worker_config: WorkerConfig,
batch_size: Optional[int],
batch_drop_last: bool = False,
packing_buffer_size: Optional[int] = None,
shuffle_buffer_size: Optional[int],
max_samples_per_sequence: Optional[int],
virtual_epoch_length: int = 0,
shuffle_over_epochs_multiplier: Optional[int] = 1,
task_encoder: TaskEncoder[Any, Any, Any, T] = DefaultTaskEncoder(),
repeat: bool = True,
**kwargs,
) -> SavableDataset[T]:
"""
    Get the training dataset with sensible defaults. See `get_dataset` for more details.
The following recipe will be used:
- :func:`megatron.energon.dataset_config.get_dataset_from_config` (loads the raw dataset)
- `task_encoder.encode_sample`
- (:class:`megatron.energon.MixDataset` if mixing)
- :class:`megatron.energon.BatchDataset` with `task_encoder.batch` for collation
- `task_encoder.encode_batch`
- :class:`megatron.energon.EpochizeDataset` (if `virtual_epoch_length` is set)
Args:
path: Path to the dataset.
split_part: Default split part to use.
worker_config: Worker configuration to use.
batch_size: Size of a batch. If None, do not batch
        batch_drop_last: If true, drop the last batch if it is smaller than `batch_size`.
        packing_buffer_size: Size of the sample buffer used for packing
            (see `TaskEncoder.select_samples_to_pack`). If None, packing is disabled.
shuffle_buffer_size: Size of the sample shuffle buffer (before task encoding).
max_samples_per_sequence: If set, limit the number of samples per sample-sequence to this.
virtual_epoch_length: If set, the dataset will be epochized to this length (=iterating
will be suspended and the for-loop returns, next for-loop continues iterating).
Otherwise, the dataset will loop indefinitely.
shuffle_over_epochs_multiplier: Shuffle the shards over this many epochs.
task_encoder: Task encoder to use.
repeat: By default, the inner datasets will loop. If set to False, stop iteration after
one epoch. Must only be set to False in conjunction with blend_epochized in the
metadataset if one is used.
cache_pool: If set, the cache pool to use for the dataset.
**kwargs: Additional arguments to the dataset constructor.
Returns:
        The training dataset.
"""
loader = load_dataset(path, **_split_kwargs(kwargs))
_split_deprecated_decoder_kwargs(kwargs, task_encoder)
datasets = loader.get_datasets(
training=True,
split_part=split_part,
worker_config=worker_config,
max_samples_per_sequence=max_samples_per_sequence,
shuffle_over_epochs_multiplier=shuffle_over_epochs_multiplier,
decoder=task_encoder.decoder,
**kwargs,
)
return task_encoder.build_train_datasets(
datasets=datasets.datasets,
worker_config=worker_config,
batch_size=batch_size,
batch_drop_last=batch_drop_last,
packing_buffer_size=packing_buffer_size,
virtual_epoch_length=virtual_epoch_length,
shuffle_buffer_size=shuffle_buffer_size,
blend_mode=datasets.blend_mode,
repeat=repeat,
)
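# --- Illustrative sketch (not part of the library) ---------------------------
# Typical call, assuming a prepared energon dataset under the hypothetical path
# below. The returned SavableDataset is normally wrapped with the package's
# get_(savable_)loader helpers before iterating.
def _example_get_train_dataset() -> SavableDataset:
    worker_config = WorkerConfig(rank=0, world_size=1, num_workers=2)
    return get_train_dataset(
        "/data/example-dataset",
        worker_config=worker_config,
        batch_size=4,
        shuffle_buffer_size=100,
        max_samples_per_sequence=100,
    )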
def get_val_dataset(
path: Union[str, EPath, Path],
*,
split_part: Union[Literal["val", "test"], str] = "val",
worker_config: WorkerConfig,
batch_size: int,
batch_drop_last: bool = False,
packing_buffer_size: Optional[int] = None,
limit: Optional[int] = None,
task_encoder: TaskEncoder[Any, Any, Any, T] = DefaultTaskEncoder(),
**kwargs,
) -> SavableDataset[T]:
"""
Get the validation/test dataset with sensible defaults. See `get_dataset` for more details.
The following recipe will be used:
- :func:`megatron.energon.dataset_config.get_dataset_from_config` (loads the raw dataset)
- `task_encoder.encode_sample`
- (:class:`megatron.energon.MixDataset` if mixing)
- :class:`megatron.energon.BatchDataset` with `task_encoder.batch` for collation
- :class:`megatron.energon.LimitDataset` (if `limit` is set)
- `task_encoder.encode_batch`
Args:
path: Path to the dataset.
split_part: Default split part to use.
worker_config: Worker configuration to use.
batch_size: Size of a batch
        batch_drop_last: If true, drop the last batch if it is smaller than `batch_size`.
        packing_buffer_size: Size of the sample buffer used for packing
            (see `TaskEncoder.select_samples_to_pack`). If None, packing is disabled.
limit: If set, limit the number of batches loaded from the dataset to this.
task_encoder: Task encoder to use.
**kwargs: Additional arguments to the dataset constructor.
Returns:
The loaded dataset.
"""
_split_deprecated_decoder_kwargs(kwargs, task_encoder)
loader = load_dataset(path, **_split_kwargs(kwargs))
datasets = loader.get_datasets(
training=False,
split_part=split_part,
worker_config=worker_config,
decoder=task_encoder.decoder,
**kwargs,
)
return task_encoder.build_val_datasets(
datasets=datasets.datasets,
worker_config=worker_config,
batch_size=batch_size,
batch_drop_last=batch_drop_last,
packing_buffer_size=packing_buffer_size,
limit=limit,
)
def get_val_datasets(
path: Union[str, EPath, Path],
*,
split_part: Union[Literal["val", "test"], str] = "val",
worker_config: WorkerConfig,
batch_size: int,
batch_drop_last: bool = False,
packing_buffer_size: Optional[int] = None,
limit: Optional[int] = None,
task_encoder: TaskEncoder[Any, Any, Any, T] = DefaultTaskEncoder(),
**kwargs,
) -> List[Tuple[SavableDataset[T], BaseCoreDatasetFactory]]:
"""
    Get the validation/test datasets with sensible defaults, one per inner dataset. See
    `get_dataset` for more details.
The following recipe will be used:
- :func:`megatron.energon.dataset_config.get_dataset_from_config` (loads the raw dataset)
- `task_encoder.encode_sample`
- (:class:`megatron.energon.MixDataset` if mixing)
- :class:`megatron.energon.BatchDataset` with `task_encoder.batch` for collation
- :class:`megatron.energon.LimitDataset` (if `limit` is set)
- `task_encoder.encode_batch`
Args:
path: Path to the dataset.
split_part: Default split part to use.
worker_config: Worker configuration to use.
batch_size: Size of a batch
        batch_drop_last: If true, drop the last batch if it is smaller than `batch_size`.
        packing_buffer_size: Size of the sample buffer used for packing
            (see `TaskEncoder.select_samples_to_pack`). If None, packing is disabled.
limit: If set, limit the number of batches loaded from the dataset to this.
task_encoder: Task encoder to use.
**kwargs: Additional arguments to the dataset constructor.
Returns:
The loaded val datasets, with the source datasets.
"""
_split_deprecated_decoder_kwargs(kwargs, task_encoder)
loader = load_dataset(path, **_split_kwargs(kwargs))
datasets = loader.get_datasets(
training=False,
split_part=split_part,
worker_config=worker_config,
decoder=task_encoder.decoder,
**kwargs,
)
return [
(
task_encoder.build_val_datasets(
datasets=[dataset],
worker_config=worker_config,
batch_size=batch_size,
batch_drop_last=batch_drop_last,
packing_buffer_size=packing_buffer_size,
limit=limit,
),
dataset.dataset,
)
for dataset in datasets.datasets
]
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import json
import time
import traceback
from concurrent.futures.process import ProcessPoolExecutor
from dataclasses import dataclass
from pathlib import Path
from typing import (
Container,
Dict,
Generator,
Iterable,
List,
Literal,
Optional,
Tuple,
TypedDict,
Union,
)
import click
import numpy as np
from PIL import Image
from tqdm import tqdm
cpal = np.array(
[
[int(x) for x in line.split(" ")]
for line in """255 255 255
1 0 103
213 255 0
255 0 86
158 0 142
14 76 161
255 229 2
0 95 57
0 255 0
149 0 58
255 147 126
164 36 0
0 21 68
145 208 203
98 14 0
107 104 130
0 0 255
0 125 181
106 130 108
0 174 126
194 140 159
190 153 112
0 143 156
95 173 78
255 0 0
255 0 246
255 2 157
104 61 59
255 116 163
150 138 232
152 255 82
167 87 64
1 255 254
255 238 232
254 137 0
189 198 255
1 208 255
187 136 0
117 68 177
165 255 210
255 166 254
119 77 0
122 71 130
38 52 0
0 71 84
67 0 44
181 0 255
255 177 103
255 219 102
144 251 146
126 45 210
189 211 147
229 111 254
222 255 116
0 255 120
0 155 255
0 100 1
0 118 255
133 169 0
0 185 23
120 130 49
0 255 198
255 110 65
232 94 190""".split("\n")
],
dtype=np.int32,
)
class YieldBatchLogLine(TypedDict):
# Json example:
# {
# "t": "yield_batch",
# "r": 1,
# "w": 1,
# "m": "train",
# "idx": 1,
# "keys": ["parts/data-train-000051.tar/528866", ...],
# }
t: Literal["yield_batch"]
r: int
w: int
m: Literal["train", "val"]
idx: int
keys: List[str]
class SampleLoaderYieldLogLine(TypedDict):
# Json example:
# {
# "t": "WebdatasetSampleLoaderDataset._slices_iter.yield",
# "r": 1,
# "w": 1,
# "index": 528800,
# "key": "parts/data-train-000051.tar/528866",
# "shard": "parts/data-train-000051.tar",
# "count": 633,
# "epoch": 0,
# "epoch_count": 633
# }
t: Literal["WebdatasetSampleLoaderDataset._slices_iter.yield"]
r: int
w: int
#: The global index in the underlying dataset (concats of all shards)
index: int
#: The sample key from the shard, concatenated as f"{shard}/{key}"
key: str
#: Name of the shard
shard: str
#: Number of samples yielded from the sample loader over all epochs
count: int
#: Number of repetitions of the dataset (=epochs). First epoch is 0.
epoch: int
#: Number of samples yielded from the sample loader in the current epoch
epoch_count: int
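# --- Illustrative sketch (not part of the library) ---------------------------
# Parsing one debug-log line into the TypedDicts above; the literal is a
# shortened, hypothetical log entry.
def _example_parse_log_line() -> None:
    line = (
        '{"t": "yield_batch", "r": 0, "w": 1, "m": "train", "idx": 1,'
        ' "keys": ["parts/data-train-000051.tar/528866"]}'
    )
    entry = json.loads(line)
    if entry.get("t") == "yield_batch":
        batch_line: YieldBatchLogLine = entry
        assert batch_line["m"] in ("train", "val")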
class AutosizingHeatmapWriter:
"""Writes a heatmap, automatically resizing it if necessary."""
def __init__(self, heatmap_samples: int, heatmap_steps: int, colorize: bool = True):
self.heatmap = np.zeros((heatmap_samples, heatmap_steps, 3), dtype=np.int32)
self.heatmap_sample_factor = 1
self.heatmap_step_factor = 1
self.heatmap_sample_max = -1
self.heatmap_step_max = -1
self.colors_size = cpal.shape[0] if colorize else 1
def add(self, sample_id: int, step: int, src: int) -> None:
"""
Add a point to the heatmap (i.e. increase count at that position).
Args:
sample_id: The sample id (y-axis)
            step: The step (x-axis)
            src: Index of the data source; selects the color of the point
        """
# Resize heatmap?
while self.heatmap.shape[0] * self.heatmap_sample_factor <= sample_id:
self.heatmap[: self.heatmap.shape[0] // 2] = self.heatmap[::2] + self.heatmap[1::2]
self.heatmap[self.heatmap.shape[0] // 2 :] = 0
self.heatmap_sample_factor *= 2
self.heatmap_sample_max = 0
while self.heatmap.shape[1] * self.heatmap_step_factor <= step:
self.heatmap[:, : self.heatmap.shape[1] // 2] = (
self.heatmap[:, ::2] + self.heatmap[:, 1::2]
)
self.heatmap[:, self.heatmap.shape[1] // 2 :] = 0
self.heatmap_step_factor *= 2
self.heatmap_step_max = 0
# Save point
step //= self.heatmap_step_factor
sample_id //= self.heatmap_sample_factor
self.heatmap[sample_id, step] += cpal[src % self.colors_size]
self.heatmap_step_max = max(self.heatmap_step_max, step)
self.heatmap_sample_max = max(self.heatmap_sample_max, sample_id)
def save(self, path: Union[Path, str], gain: float):
"""
Save the heatmap to the given path.
Args:
path: The path to save the heatmap to.
gain: The gain (=multiplication factor) for the heatmap.
Returns:
The maximum sample id and step id that were used in the heatmap.
"""
heatmap = self.heatmap[: self.heatmap_sample_max + 1, : self.heatmap_step_max + 1]
heatmap = heatmap.astype(np.float32)
heatmap = np.clip(heatmap * gain / heatmap.max((0, 1)) * 255, 0, 255).astype(np.uint8)
Image.fromarray(heatmap).save(path)
return (
self.heatmap_sample_max * self.heatmap_sample_factor,
self.heatmap_step_max * self.heatmap_step_factor,
)
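# --- Illustrative sketch (not part of the library) ---------------------------
# Feeding synthetic points into the writer; steps beyond the initial resolution
# trigger the automatic 2x downscaling in `add`. The output path is hypothetical.
def _example_heatmap() -> None:
    writer = AutosizingHeatmapWriter(heatmap_samples=100, heatmap_steps=100)
    for step in range(250):  # >100 steps forces two step-axis resizes
        writer.add(sample_id=step % 100, step=step, src=0)
    writer.save("example_heatmap.png", gain=10.0)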
@click.command(name="analyze-debug")
@click.argument(
"log_paths",
nargs=-1,
type=click.Path(exists=True, file_okay=True, dir_okay=True, path_type=Path),
)
@click.option(
"--heatmap-path",
type=click.Path(exists=False, writable=True, dir_okay=False, path_type=Path),
default=Path("heatmap.png"),
)
@click.option(
"--heatmap-steps",
type=int,
default=1000,
help="Size of the heatmap in step direction. All steps will be downscaled to this size.",
)
@click.option(
"--heatmap-samples",
type=int,
default=1000,
help="Size of the heatmap in sample direction. All samples will be downscaled to this size.",
)
@click.option(
"--heatmap-gain",
type=float,
default=10,
help="Gain (=multiplication factor) for the heatmap",
)
@click.option(
"--force-loading-order",
is_flag=True,
default=False,
help="If true, force using the dataloader loading order instead of batch data",
)
@click.option(
"--include-modality",
type=str,
default="train",
help="Choose which modality/modalities (train,val) to include. Comma separate for multiple.",
)
@click.option(
"--skip",
type=int,
default=0,
help="If >0, skip this many steps at the beginning of log file parsing.",
)
@click.option(
"--no-colors",
is_flag=True,
default=False,
help="If set, disable colorizing ranks.",
)
def command(
log_paths: List[Path],
heatmap_path: Path,
heatmap_steps: int,
heatmap_samples: int,
heatmap_gain: float,
force_loading_order: bool,
include_modality: str,
skip: int,
no_colors: bool,
):
"""Internal tool to analyze randomness.
    Each LOG_PATH should point to a folder with the debug logs, or to a single log file."""
if len(log_paths) == 0:
raise click.ClickException("No log paths specified")
log_files = []
for log_path in log_paths:
if log_path.is_dir():
log_files.extend(sorted(log_path.glob("*.jsonl")))
elif log_path.is_file():
log_files.append(log_path)
else:
raise click.ClickException(f"Invalid log path: {log_path}")
if len(log_files) == 0:
raise click.ClickException("No log files found")
heatmap = AutosizingHeatmapWriter(heatmap_samples, heatmap_steps, colorize=not no_colors)
print(f"Analyzing {len(log_files)} logs...")
modalities = [m.strip() for m in include_modality.split(",")]
key_index = {}
count = 0
if not force_loading_order:
loaders = [LoaderLogIter(log_file, start_idx=skip) for log_file in log_files]
loaders_by_id: Dict[int, Tuple[LoaderInfo, List[LoaderLogIter]]] = {}
with ProcessPoolExecutor(max_workers=16) as executor:
            for loader, file_loader_infos in tqdm(
                executor.map(_proc_map_loader, loaders), total=len(loaders)
            ):
                for loader_id, loader_info in file_loader_infos.items():
if loader_id in loaders_by_id:
existing_loader_info, existing_loaders = loaders_by_id[loader_id]
assert (
existing_loader_info.modality == loader_info.modality
and existing_loader_info.path == loader_info.path
), (
f"Found multiple loaders for {loader_id}: {existing_loader_info.modality, existing_loader_info.path} and {loader_info.modality, loader_info.path}"
)
existing_loader_info.global_count = max(
existing_loader_info.global_count, loader_info.global_count
)
existing_loaders.append(loader)
else:
loaders_by_id[loader_id] = (loader_info, [loader])
print("Available loaders:")
selected_loader_id = None
must_select = False
for loader_id, (loader_info, _iters) in loaders_by_id.items():
print(
f" {loader_id}: {loader_info.modality} {loader_info.path} {loader_info.global_count} steps"
)
if loader_info.modality in modalities:
if selected_loader_id is None:
selected_loader_id = loader_id
else:
# Have multiple loaders
must_select = True
if must_select:
while True:
loader_id_str = input("Choose loader id: ")
try:
selected_loader_id = int(loader_id_str)
except ValueError:
print(f"Invalid loader id {loader_id_str} 1")
continue
if selected_loader_id in loaders_by_id:
break
print(f"Invalid loader id {selected_loader_id}")
assert selected_loader_id is not None
selected_loader_info, selected_loader_readers = loaders_by_id[selected_loader_id]
print(
f"Reading for loader {selected_loader_id}: {selected_loader_info.modality} {selected_loader_info.path}"
)
log_iters = [
(idx, loader.log_entries(loader_ids={selected_loader_id}))
for idx, loader in enumerate(selected_loader_readers)
]
with tqdm(total=selected_loader_info.global_count) as pbar:
while len(log_iters) > 0:
cur_count = 0
# Iterate over all iterators for this count and put into heatmap
for src_idx, log_iter in tuple(log_iters):
# Iterate until None (=next count) is encountered
while True:
try:
log_keys = next(log_iter)
except StopIteration:
log_iters.remove((src_idx, log_iter))
break
except OSError:
traceback.print_exc()
log_iters.remove((src_idx, log_iter))
break
else:
if log_keys is None:
break
for log_key in log_keys:
key_id = key_index.setdefault(log_key, len(key_index))
heatmap.add(key_id, count, src_idx)
cur_count += 1
if cur_count == 0:
print(f"No data for step {count}")
count += 1
pbar.update(1)
if len(key_index) == 0:
if force_loading_order:
print("Forcing to use sample loader logs")
else:
print("No batch information in logs, trying sample loader logs...")
        if set(modalities) != {"train", "val"}:
print(" Data includes all modalities (train and val)")
print(
" Shuffle buffer and batching will not be considered, only the loading order from disk"
)
log_iters = [
_iter_sl_log_line_keys(_iter_sl_log_samples(log_file), start_idx=skip)
for log_file in log_files
]
key_index = {}
count = 0
start = time.time()
while len(log_iters) > 0:
cur_count = 0
# Iterate over all iterators for this count and put into heatmap
for log_iter in tuple(log_iters):
# Iterate until None (=next count) is encountered
while True:
try:
log_key = next(log_iter)
except StopIteration:
log_iters.remove(log_iter)
break
except OSError:
traceback.print_exc()
log_iters.remove(log_iter)
break
else:
if log_key is None:
break
key_id = key_index.setdefault(log_key, len(key_index))
heatmap.add(key_id, count)
cur_count += 1
if cur_count == 0:
print(f"No data for step {count}")
if time.time() - start > 10:
print(f" Step {count}")
start = time.time()
count += 1
if count == 0:
raise click.ClickException("No data found in logs")
print(f"Found {len(key_index)} unique sample keys, {count} steps")
# print(f"Heatmap factors: {heatmap_sample_factor} samples, {heatmap_step_factor} steps")
# print(f"Heatmap max: {heatmap_sample_max} samples, {heatmap_step_max} steps")
n_samples, n_steps = heatmap.save(heatmap_path, heatmap_gain)
print(f"Wrote heatmap to {heatmap_path}")
print("Heatmap axes:")
print(f" x-axis: {n_steps} worker steps")
print(f" y-axis: {n_samples} samples")
class LoaderInitLogLine(TypedDict):
t: Literal["SavableLoader.__init__", "BasicDataLoader.__init__"]
r: int
w: None
id: int
config: dict
class LoaderIterLogLine(TypedDict):
t: Literal["SavableDataLoader.iter", "BasicDataLoader.iter"]
r: int
w: None
id: int
iter_id: int
class LoaderYieldLogLine(TypedDict):
t: Literal["SavableDataLoader.yield", "BasicDataLoader.yield"]
r: int
w: None
id: int
iter_id: int
worker_id: int
worker_idx: int
idx: int
iter_idx: int
global_idx: int
keys: Optional[List[str]]
class LoaderStopLogLine(TypedDict):
t: Literal["SavableDataLoader.StopIteration", "BasicDataLoader.StopIteration"]
r: int
w: None
id: int
iter_id: int
LoaderLines = Union[
LoaderInitLogLine,
LoaderIterLogLine,
LoaderYieldLogLine,
LoaderStopLogLine,
]
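# Illustrative example of a matching JSONL log line (field values are made up,
# but follow the LoaderYieldLogLine schema above):
#   {"t": "SavableDataLoader.yield", "r": 0, "w": null, "id": 1, "iter_id": 0,
#    "worker_id": 0, "worker_idx": 0, "idx": 5, "iter_idx": 5, "global_idx": 5,
#    "keys": ["parts/shard_000/000123"]}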
LOADER_LOG_LINE_TYPES_T = (
"SavableLoader.__init__",
"BasicDataLoader.__init__",
"SavableDataLoader.iter",
"BasicDataLoader.iter",
"SavableDataLoader.yield",
"BasicDataLoader.yield",
"SavableDataLoader.StopIteration",
"BasicDataLoader.StopIteration",
)
@dataclass
class LoaderInfo:
id: int
modality: str
path: str
global_count: int
class LoaderLogIter:
def __init__(self, path: Path, start_idx: int = 0):
self._path = path
self._start_idx = start_idx
def _iter_log_lines(self, which: Iterable[str]) -> Generator[LoaderLines, None, None]:
try:
with self._path.open("r") as rf:
for line in rf:
if any(f'"t": "{t}"' in line for t in which):
try:
yield json.loads(line.strip())
except json.JSONDecodeError:
print("Cannot decode line", repr(line))
except IOError as e:
print(f"Ignoring IOError: {e} for {self._path}")
@staticmethod
def _find_config_modality(config: dict) -> Literal["train", "val"]:
assert isinstance(config, dict)
if "map_fn_config" in config and "training" in config["map_fn_config"]:
return "train" if config["map_fn_config"]["training"] else "val"
elif "dataset" in config:
return LoaderLogIter._find_config_modality(config["dataset"])
elif "dataset_weights" in config:
return LoaderLogIter._find_config_modality(config["dataset_weights"][0][0])
elif "datasets" in config:
return LoaderLogIter._find_config_modality(config["datasets"][0])
assert False, f"Unrecognized config {config}"
@staticmethod
def _find_config_path(config: dict) -> str:
assert isinstance(config, dict)
if "map_fn_config" in config and "_path" in config["map_fn_config"]:
return config["map_fn_config"]["_path"]
elif "dataset" in config:
return LoaderLogIter._find_config_path(config["dataset"])
elif "dataset_weights" in config:
return LoaderLogIter._find_config_path(config["dataset_weights"][0][0])
elif "datasets" in config:
return LoaderLogIter._find_config_path(config["datasets"][0])
assert False, f"Unrecognized config {config}"
def loaders(self) -> Dict[int, LoaderInfo]:
loaders = {}
for log_line in self._iter_log_lines(
(
"SavableLoader.__init__",
"BasicDataLoader.__init__",
"SavableDataLoader.yield",
"BasicDataLoader.yield",
)
):
if log_line["t"] in ("SavableLoader.__init__", "BasicDataLoader.__init__"):
loaders[log_line["id"]] = LoaderInfo(
id=log_line["id"],
modality=self._find_config_modality(log_line["config"]),
path=self._find_config_path(log_line["config"]),
global_count=0,
)
elif log_line["t"] in ("SavableDataLoader.yield", "BasicDataLoader.yield"):
loaders[log_line["id"]].global_count = log_line["global_idx"]
return loaders
def log_entries(self, loader_ids: Container[int]) -> Generator[Optional[List[str]], None, None]:
idx = self._start_idx
for log_line in self._iter_log_lines(("SavableDataLoader.yield", "BasicDataLoader.yield")):
if (
log_line["t"] in ("SavableDataLoader.yield", "BasicDataLoader.yield")
and log_line["id"] in loader_ids
):
assert log_line["global_idx"] >= idx, (
f"Found entry {log_line} with wrong idx <{idx}"
)
while log_line["global_idx"] != idx:
yield None
idx += 1
if "keys" in log_line:
yield log_line["keys"]
def __repr__(self) -> str:
return f"log({str(self._path)})"
def _proc_map_loader(loader: LoaderLogIter) -> Tuple[LoaderLogIter, Dict[int, LoaderInfo]]:
return (loader, loader.loaders())
def _iter_sl_log_line_keys(
log_lines: Iterable[SampleLoaderYieldLogLine],
start_idx: int = 0,
) -> Generator[Optional[str], None, None]:
count = start_idx
for log_line in log_lines:
if log_line["count"] < start_idx:
continue
assert log_line["count"] >= count
while log_line["count"] != count:
yield None
count += 1
yield log_line["key"]
def _iter_sl_log_samples(path: Path) -> Generator[SampleLoaderYieldLogLine, None, None]:
with path.open("r") as rf:
for line in rf:
if '"t": "WebdatasetSampleLoaderDataset._slices_iter.yield"' in line:
try:
yield json.loads(line.strip())
except json.JSONDecodeError:
print("Cannot decode line", repr(line))
if __name__ == "__main__":
command()
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import re
from typing import List, Optional
import click
import torch
from megatron.energon.epathlib import EPath
from megatron.energon.savable_loader import SavableDataLoaderState
def natural_sort_key(s):
"""
Function to use for natural sorting of filenames.
This splits the input string by numbers and non-numbers and ensures
that numbers are compared as integers, not as strings.
"""
return [int(text) if text.isdigit() else text.lower() for text in re.split(r"(\d+)", s)]
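# Minimal usage sketch (illustration only): numeric tokens compare as integers,
# so "rank_9.pt" sorts before "rank_10.pt", which plain string sorting gets wrong:
#   sorted(["rank_10.pt", "rank_9.pt"], key=natural_sort_key)
#   -> ["rank_9.pt", "rank_10.pt"]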
def detect_and_replicate_pattern(file_list):
"""
Given a list of file paths, detect the single numeric pattern and return
a function that, when called with integer n (starting from 0), generates
the n-th filename following that pattern.
Raises an Exception if no pattern or multiple patterns are found.
"""
if not file_list:
raise ValueError("Cannot detect a pattern from an empty list.")
# -- 1) Sort the list using a natural key so that numbers compare numerically
sorted_files = sorted(file_list, key=natural_sort_key)
# -- 2) Tokenize each filename into [text, number, text, number, ...].
# We'll look for the pattern of tokens across all files.
def tokenize_filename(fname):
# Use the same split so that digit tokens are separated
# from non-digit tokens.
tokens = re.split(r"(\d+)", fname)
# tokens is like ["f", "001", ".txt"] for "f001.txt"
return tokens
tokenized = [tokenize_filename(f) for f in sorted_files]
# Check that all tokenized filenames have the same number of chunks:
token_len = len(tokenized[0])
for t in tokenized:
if len(t) != token_len:
raise Exception("Filenames do not share a consistent token structure.")
# -- 3) Identify exactly one numeric token position that changes across all files.
# All other positions must be identical across the entire list.
num_positions = [] # positions in the token list that differ
for pos in range(token_len):
# Check if this chunk is the same for all files or not:
# We compare "raw text" for non-digit chunks, and "integer value" for digit chunks.
# For the first file's token, check if it's digits or not
example_token = tokenized[0][pos]
example_is_digit = example_token.isdigit()
# Collect how all files differ at this position
all_tokens_at_pos = [t[pos] for t in tokenized]
# If it's supposed to be a numeric token,
# we compare the integer values to see if they differ or not.
# If it's a non-numeric token, they must all be identical.
if example_is_digit:
# Parse integer values
values = [int(x) if x.isdigit() else None for x in all_tokens_at_pos]
# If *any* of them is None or they vary, we track that as "differences".
# But let's see if indeed they differ across the files or not.
if len(set(values)) > 1:
# This token position changes among files
num_positions.append(pos)
else:
# The numeric token is the same for all files, so no variation here
pass
else:
# Non-digit token, must be identical across all files
if len(set(all_tokens_at_pos)) != 1:
raise Exception("Non-digit token differs among files. Invalid pattern.")
# We expect exactly 1 changing numeric token position
if len(num_positions) == 0:
raise Exception("No numeric portion found that differs among files.")
if len(num_positions) > 1:
raise Exception("Multiple numeric portions found that differ. Not a single pattern.")
varying_pos = num_positions[0]
# -- 4) Extract the numeric values of that varying position for all sorted files,
# check consecutive increments and find the zero-padding width.
numeric_values = [int(t[varying_pos]) for t in tokenized]
# Check if consecutive differences are all +1
for i in range(len(numeric_values) - 1):
if numeric_values[i + 1] - numeric_values[i] != 1:
raise Exception("Numeric values are not consecutive. Pattern is invalid.")
# The "base" number is numeric_values[0], i.e. the value for n=0
base_value = numeric_values[0]
# The zero-padding width is based on the first file's numeric token
zero_padding_width = len(tokenized[0][varying_pos])
# -- 5) Construct the function that, given n, returns the enumerated filename.
# We'll verify it against the original sorted list as well.
def generate_filename(n):
# Rebuild the token array from the first file's tokens,
# except we replace the one numeric token with (base_value + n) zero-padded.
new_tokens = tokenized[0][:]
new_int_value = base_value + n
# zero-pad with the discovered width
new_str_value = str(new_int_value).zfill(zero_padding_width)
# Replace the numeric position
new_tokens[varying_pos] = new_str_value
# Join all tokens back into a string
return "".join(new_tokens)
# -- 6) Verify that generate_filename(i) reproduces the sorted list exactly
# for i in [0..len(sorted_files)-1].
for i in range(len(sorted_files)):
candidate = generate_filename(i)
if candidate != sorted_files[i]:
raise Exception(
"Verification failed. The generated pattern does not match the input list."
)
# If we get here, everything is good. Return the generator function.
return generate_filename
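# Worked example (hypothetical file names): for consecutive, zero-padded inputs,
# the returned function reproduces and extrapolates the pattern:
#   gen = detect_and_replicate_pattern(["rank_00.pt", "rank_01.pt", "rank_02.pt"])
#   gen(0) -> "rank_00.pt"
#   gen(3) -> "rank_03.pt"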
class RankStateIterable:
"""Iterates the SavableDatasetCheckpoints of mulitple ranks in a round-robin fashion."""
def __init__(self, state_files: List[EPath]):
        state_file_names = [state_file.name for state_file in state_files]
        if len(state_file_names) > 1:
            self.file_pattern_func = detect_and_replicate_pattern(state_file_names)
        else:
            # A single input file (e.g. a global checkpoint) has no varying
            # numeric pattern to detect.
            self.file_pattern_func = None
self.num_states = len(state_files)
# First open the first one to figure out if this is a global checkpoint or not
first_state = torch.load(str(state_files[0]), weights_only=False)
if isinstance(first_state, dict) and "dataloader_state_dict" in first_state:
self.megatron_style = True
first_state = first_state["dataloader_state_dict"]
else:
self.megatron_style = False
if isinstance(first_state, SavableDataLoaderState):
if self.megatron_style:
self.rank_states = [first_state] + [
torch.load(str(state_file), weights_only=False)["dataloader_state_dict"]
for state_file in state_files[1:]
]
else:
self.rank_states = [first_state] + [
torch.load(str(state_file), weights_only=False)
for state_file in state_files[1:]
]
self.is_global_checkpoint = False
elif isinstance(first_state, list):
assert len(state_files) == 1, "Global checkpoint must contain exactly one file"
assert all(isinstance(state, SavableDataLoaderState) for state in first_state)
self.rank_states = first_state
self.is_global_checkpoint = True
else:
raise ValueError(f"Unknown checkpoint type: {type(first_state)}")
self.rank_cur_worker = [0] * len(self.rank_states)
self.rank_worker_offset = [state.next_worker_id for state in self.rank_states]
self.rank_num_workers = [len(state.worker_states) for state in self.rank_states]
assert all(
self.rank_num_workers[0] == num_workers for num_workers in self.rank_num_workers
), "All ranks must have the same number of workers."
def write_new_states_to_folder(
self, output_folder: EPath, new_states: List[SavableDataLoaderState]
):
        assert self.file_pattern_func is not None, (
            "Cannot derive per-rank output file names from a single input file"
        )
        for rank_idx, rank_state in enumerate(new_states):
output_file = output_folder / self.file_pattern_func(rank_idx)
if self.megatron_style:
torch.save(
{"dataloader_state_dict": rank_state},
str(output_file),
)
else:
torch.save(rank_state, str(output_file))
def get_num_ranks(self):
return len(self.rank_states)
def get_num_workers(self):
return self.rank_num_workers[0]
def get_micro_batch_size(self):
return self.rank_states[0].micro_batch_size
def __iter__(self):
"""Iterates the SavableDatasetCheckpoints of mulitple ranks in a round-robin fashion."""
for rank, state in enumerate(self.rank_states):
for worker_state in state.worker_states:
yield worker_state
@click.command(name="redist")
@click.argument(
"input_files",
nargs=-1,
type=click.Path(file_okay=True, dir_okay=False, exists=True, path_type=EPath),
required=True,
)
@click.argument(
"output_path",
type=click.Path(file_okay=False, dir_okay=True, path_type=EPath),
)
@click.option(
"--new-world-size", type=int, help="Number of ranks to redistribute to", required=False
)
def command_redist(
input_files: List[EPath], output_path: EPath, new_world_size: Optional[int] = None
):
"""Redistribute a checkpoint.
Read checkpoint files from INPUT_FILES and redistribute them for a new
number of ranks. Write the output to OUTPUT_PATH."""
# Verify input files
if not input_files:
raise click.ClickException("No input files provided")
input_file_list = sorted(input_files, key=lambda x: natural_sort_key(x.name))
click.echo(f"Processing {len(input_file_list)} checkpoint files")
# Determine if we're processing a single global checkpoint or multiple rank files
rsi = RankStateIterable(input_file_list)
if not rsi.rank_states:
raise click.ClickException("No valid checkpoint states found")
if new_world_size is None:
click.echo(f"Current DP world size: {rsi.get_num_ranks()}")
click.echo(f"Current number of workers per DP rank: {rsi.get_num_workers()}")
new_world_size = click.prompt("Please enter the new DP world size", type=int)
assert isinstance(new_world_size, int)
if new_world_size <= 0:
raise click.ClickException("New world size must be greater than 0")
total_num_workers = rsi.get_num_workers() * rsi.get_num_ranks()
    assert total_num_workers % new_world_size == 0, (
        "The total number of workers must be divisible by the new DP world size"
    )
new_workers_per_rank = total_num_workers // new_world_size
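    # Worked example (illustrative): 4 old ranks x 2 workers = 8 worker states in
    # total; a new world size of 2 gives 8 // 2 = 4 workers per new rank.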
# Ensure output directory exists
output_path.mkdir(exist_ok=True, parents=True)
new_rank_states = [list() for _ in range(new_world_size)]
rsi_iter = iter(rsi)
for rank_idx in range(new_world_size):
for _ in range(new_workers_per_rank):
state = next(rsi_iter)
new_rank_states[rank_idx].append(state)
assert all(
len(new_rank_states[0]) == len(new_rank_states[rank]) for rank in range(1, new_world_size)
), "All ranks must have the same number of workers, also for the new distribution."
new_states = [
SavableDataLoaderState(
worker_states=new_rank_state,
next_worker_id=0, # Reset the next worker ID
micro_batch_size=rsi.get_micro_batch_size(),
)
for new_rank_state in new_rank_states
]
# Save the redistributed checkpoint
if rsi.is_global_checkpoint:
# Save as a single global checkpoint file
output_file = output_path / input_file_list[0].name
torch.save(new_states, str(output_file))
click.echo(f"Saved global checkpoint to {output_file}")
else:
rsi.write_new_states_to_folder(output_path, new_states)
click.echo(f"Saved {new_world_size} rank checkpoint files to {output_path}")
@click.command(name="info")
@click.argument(
"input_files",
nargs=-1,
type=click.Path(file_okay=True, dir_okay=False, exists=True, path_type=EPath),
required=True,
)
def command_info(input_files: List[EPath]):
"""Display information about a checkpoint.
    Read checkpoint files from INPUT_FILES (one or more per-rank files, or a single
    global checkpoint file) and display information about them.
"""
# Load the checkpoint(s)
rsi = RankStateIterable(input_files)
# Display basic information
if rsi.is_global_checkpoint:
click.echo("Checkpoint type: Global checkpoint")
else:
click.echo("Checkpoint type: Per-rank checkpoint files")
click.echo(f"Number of DP ranks: {rsi.get_num_ranks()}")
click.echo(f"Number of workers per DP rank: {rsi.get_num_workers()}")
# Additional detailed information
click.echo("\nDetailed information:")
for rank_idx, state in enumerate(rsi.rank_states):
if rsi.is_global_checkpoint:
click.echo(f" Rank {rank_idx}:")
else:
click.echo(f" Rank {rank_idx} ({input_files[rank_idx].name}):")
click.echo(f" Next worker ID: {state.next_worker_id}")
click.echo(f" Number of worker states: {len(state.worker_states)}")
@click.group(
name="checkpoint",
context_settings=dict(help_option_names=["-h", "--help"]),
invoke_without_command=True,
)
@click.pass_context
def command(ctx):
"""Tools for energon checkpoints."""
# This is needed to show help if no subcommand is provided
if ctx.invoked_subcommand is None:
click.echo(command.get_help(ctx))
command.add_command(command_redist)
command.add_command(command_info)
if __name__ == "__main__":
command()
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from importlib import import_module
from typing import Type
import braceexpand
import click
import megatron.energon
from megatron.energon.epathlib import EPath
from megatron.energon.eyaml import load_yaml, load_yaml_json
from megatron.energon.flavors.base_dataset import BaseCoreDatasetFactory
from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME
from megatron.energon.flavors.webdataset.metadata import get_dataset_info
fmt = """
* {dataset_name}: `{path}` ({samples_count:,} samples, {samples_size} in {shards_count} shards)
* Created with energon version: {energon_version}
* Sample Type: {{py:class}}`{sample_name} <{sample_fullname}>`
* Default Splits:
{splits_str}
"""
split_fmt = """ * `{split_name}`: {split_ratio:.0f}%, {split_samples_count:,} samples in {split_shards_count} shards
"""
def fmt_size(size: int) -> str:
keys = ["B", "KiB", "MiB", "GiB", "TiB"]
for key in keys:
if size < 1024:
return f"{size:.2f} {key}"
size /= 1024
return f"{size:.2f} PiB"
@click.command(name="info")
@click.argument(
"path",
type=click.Path(file_okay=False, dir_okay=True, path_type=EPath),
)
@click.option(
"--split-config", default="split.yaml", help="Split config file name", show_default=True
)
@click.option(
"--dataset-config", default="dataset.yaml", help="Dataset config file name", show_default=True
)
def command(
path: EPath,
split_config: str,
dataset_config: str,
):
"""
Get summarizing information about a dataset.
"""
ds_config = load_yaml((path / MAIN_FOLDER_NAME / dataset_config).read_bytes())
info_config = get_dataset_info(path)
split_config_obj = load_yaml_json(path / MAIN_FOLDER_NAME / split_config)
ds_energon_version = info_config.get("energon_version", "unknown")
samples_count = sum(info_config["shard_counts"].values())
dict_sample_type = ds_config["sample_type"]
sample_module = import_module(dict_sample_type["__module__"])
sample_cls: Type[BaseCoreDatasetFactory] = getattr(sample_module, dict_sample_type["__class__"])
sample_module = sample_cls.__module__
if (
sample_module.startswith("megatron.energon")
and getattr(megatron.energon, dict_sample_type["__class__"], None) == sample_cls
):
sample_module = "megatron.energon"
sample_name = sample_cls.__name__
sample_fullname = sample_module + "." + sample_name
def srt_key(pair):
try:
return ("train", "val", "test").index(pair[0])
except ValueError:
return 3
# Brace expand all the split part files
expanded_split_parts = {}
for split_name, split_parts in split_config_obj["split_parts"].items():
expanded_split_parts[split_name] = []
for split_part in split_parts:
for name in braceexpand.braceexpand(split_part):
expanded_split_parts[split_name].append(name)
splits_str = "".join(
split_fmt.format(
split_name=split_name,
split_ratio=round(
100
* sum(info_config["shard_counts"][shard] for shard in split_parts)
/ samples_count,
2,
),
split_samples_count=sum(info_config["shard_counts"][shard] for shard in split_parts),
split_shards_count=len(split_parts),
)
for split_name, split_parts in sorted(expanded_split_parts.items(), key=srt_key)
)
print(
fmt.format(
dataset_name=path.name,
path=str(path),
samples_count=samples_count,
samples_size=fmt_size(
                sum((path / shard_name).size() for shard_name in info_config["shard_counts"].keys())
),
shards_count=len(info_config["shard_counts"]),
sample_name=sample_name,
sample_fullname=sample_fullname,
splits_str=splits_str,
energon_version=ds_energon_version,
)
)
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import logging
import traceback
from typing import Optional
import click
import tqdm
from megatron.energon import WorkerConfig, get_loader, get_val_dataset
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.webdataset.empty_dataset_error import EmptyDatasetError
from megatron.energon.flavors.webdataset.itar import TarIndexReader
from megatron.energon.flavors.webdataset.metadata import get_dataset_info
def check_index(dataset_path: EPath) -> bool:
    if not dataset_path.is_dir():
        click.echo(f"{dataset_path} is not a directory, therefore the index will not be checked")
        return True
ok = True
# Get info file
info = get_dataset_info(dataset_path)
click.echo("Checking the index files...")
shards = info["shard_counts"]
for shard_file, length in shards.items():
with TarIndexReader(dataset_path / shard_file) as itar:
            index_len = len(itar)
            if index_len - 1 != length:
                ok = False
                print(
                    f"Error in shard {shard_file}: shard length in info file ({length}) != {index_len - 1} (length in index)"
)
return ok
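# Note (inferred from the "- 1" above): the tar index apparently stores one offset
# per sample plus a final end-of-archive offset, so its length is expected to be
# the sample count plus one.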
@click.command(name="lint")
@click.argument(
"path",
type=click.Path(path_type=EPath),
)
@click.option(
"--split-parts", default="train,val,test", help="The splits to verify", show_default=True
)
@click.option(
"--dataset-config", default="dataset.yaml", help="Dataset config file name", show_default=True
)
@click.option(
"--split-config", default="split.yaml", help="Split config file name", show_default=True
)
@click.option(
"--parallel", default=1, help="Number of parallel workers", show_default=True, type=int
)
def command(path: EPath, split_parts: str, dataset_config: str, split_config: str, parallel: int):
"""Check energon dataset for errors.
The PATH should point to the folder with the dataset.
The dataset must comply with the energon dataset format. See README.md for more details."""
# Check the tar file index
if not check_index(path):
raise click.ClickException("Validation failed with errors, see logs for details.")
# Check the dataset
failed = False
ignore_list = []
def handler(exc: Exception, key: Optional[str] = None) -> None:
nonlocal failed
failed = True
logging.exception(str(exc))
if key is not None:
ignore_list.append(key)
kwargs = {}
if dataset_config != "dataset.yaml":
kwargs["dataset_config"] = dataset_config
if split_config != "split.yaml":
kwargs["split_config"] = split_config
worker_config = WorkerConfig(rank=0, world_size=1, num_workers=parallel)
for split_part in split_parts.split(","):
try:
dataset = get_val_dataset(
EPath(path),
split_part=split_part,
worker_config=worker_config,
batch_size=1,
handler=handler,
**kwargs,
)
except EmptyDatasetError:
click.echo(f"Skipping empty split part {split_part}")
continue
try:
for _ in tqdm.tqdm(get_loader(dataset)):
pass
except InterruptedError:
raise
except BaseException:
traceback.print_exc()
raise click.ClickException("Validation failed with errors, see logs for details.")
if failed:
click.echo(
"The following shards/samples failed (maybe set as dataset.yaml:ignore_list):", err=True
)
for item in ignore_list:
click.echo(f"- {item}", err=True)
raise click.ClickException("Validation failed with errors, see logs for details.")
if __name__ == "__main__":
command()
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import os
import stat
from errno import EBADF, ENOENT, EROFS
from pathlib import Path
from sqlite3 import OperationalError
from typing import Generator, Optional
import click
from megatron.energon.cache.file_store import WebdatasetFileStore
from megatron.energon.epathlib import EPath
MULTI_WARN = "WARNING_SAME_KEY_IN_MULTIPLE_TAR_FILES"
EnergonFS: Optional[type]
try:
from mfusepy import FUSE, FuseOSError, Operations
class _EnergonFS(Operations):
"""
Read-only filesystem that exposes an energon WebdatasetFileStore.
"""
def __init__(
self,
db_path: EPath,
*,
sample_folders: bool = False,
print_debug: int = 0,
allow_slow_mode: bool = False,
) -> None:
self._sample_folders = sample_folders
self._wds_filestore = WebdatasetFileStore(db_path)
self._all_sample_parts = {}
self._slow_mode = False
try:
for key, size, tar_file_id in self._wds_filestore.list_all_sample_parts():
if key not in self._all_sample_parts:
# Only take the first tar file id
self._all_sample_parts[key] = size
except OperationalError:
if not allow_slow_mode:
raise RuntimeError(
"The dataset was prepared with an older version of energon. Either update the dataset, or allow slow mode."
)
else:
assert sample_folders, (
"Only sample_folders mode is supported when using slow mode."
)
self._slow_mode = True
self._samples_with_multiple_tar_files = set()
self._all_samples = {}
for key, size, tar_file_id in self._wds_filestore.list_all_samples():
if key not in self._all_samples:
self._all_samples[key] = size
else:
self._samples_with_multiple_tar_files.add(key)
self._total_size = None
# When a file is opened, we keep the bytes in memory for now (until it is closed)
self._open_files = {}
# Get current uid and gid
self._uid = os.getuid()
self._gid = os.getgid()
# Get modification time of the db file
try:
self._mtime = os.path.getmtime(str(db_path))
except FileNotFoundError:
# Remote file systems have no modification time
self._mtime = 0
self._print = print_debug
def statfs(self, path: str) -> dict:
"""Return information about the file system.
This is called when the user runs `df` on the mount point.
"""
if self._total_size is None:
print("Computing total size...", end="", flush=True)
self._total_size = self._wds_filestore.get_total_size()
print(f"done: {self._total_size} bytes")
return dict(
f_bsize=512,
f_blocks=self._total_size // 512,
f_bavail=0,
f_bfree=0,
f_files=len(self._all_sample_parts) if not self._slow_mode else 0,
f_ffree=0,
f_namemax=1024,
)
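        # Illustration: with the values above, `df <mountpoint>` reports the total
        # dataset size in 512-byte blocks and zero free space (read-only).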
def getattr(self, path: str, fh: int = 0) -> dict:
"""Return information about one file or folder.
This is called when using `ls -l` etc.
Returns a dict with the following keys:
- st_mode: File mode (S_IFDIR, S_IFREG, etc.)
- st_nlink: Number of links
- st_size: Size of the file
- st_ctime: Creation time
- st_mtime: Modification time
- st_atime: Access time
- st_uid: User ID of the file
- st_gid: Group ID of the file
"""
if path[0] != "/":
raise FuseOSError(ENOENT)
if path == "/":
return dict(
st_mode=0o555 | stat.S_IFDIR,
st_nlink=2,
st_size=0,
st_ctime=self._mtime,
st_mtime=self._mtime,
st_atime=self._mtime,
st_uid=self._uid,
st_gid=self._gid,
)
# Strip leading '/'
path = path[1:]
if path.endswith(MULTI_WARN):
return dict(
st_mode=0o000 | stat.S_IFBLK,
st_nlink=1,
st_size=0,
st_ctime=self._mtime,
st_mtime=self._mtime,
)
if self._sample_folders:
folder, part_name = self._path_parts(path)
if part_name != "":
# This is a sample part (file)
if folder not in self._all_samples:
raise FuseOSError(ENOENT)
full_name = f"{folder}.{part_name}"
if self._slow_mode and full_name not in self._all_sample_parts:
# Slow mode
for entry, size, tar_file_id in self._wds_filestore.list_sample_parts(
folder, slow_mode=True
):
cur_full_name = f"{folder}.{entry}"
self._all_sample_parts[cur_full_name] = size
if full_name not in self._all_sample_parts:
raise FuseOSError(ENOENT)
file_size = self._all_sample_parts[full_name]
mode = 0o444 | stat.S_IFREG
else:
# This is a sample (directory)
if path not in self._all_samples:
raise FuseOSError(ENOENT)
file_size = self._all_samples[path]
mode = 0o555 | stat.S_IFDIR
else:
if path not in self._all_sample_parts:
raise FuseOSError(ENOENT)
file_size = self._all_sample_parts[path]
mode = 0o444 | stat.S_IFREG
return dict(
st_mode=mode,
st_nlink=1,
st_size=file_size,
st_ctime=self._mtime,
st_mtime=self._mtime,
st_atime=self._mtime,
st_uid=self._uid,
st_gid=self._gid,
)
def _path_parts(self, path: str) -> tuple[str, str]:
"""Split a path into a folder and a part name and check for errors.
We only allow paths of the form "sample_key/part_name".
The leading "/" must be stripped before.
"""
path_parts = path.split("/")
            # path_parts[0] == "sample_key"
            # path_parts[1] == "part_name"
if len(path_parts) > 2:
raise FuseOSError(ENOENT)
if len(path_parts) == 1:
part_name = ""
else:
part_name = path_parts[1]
return path_parts[0], part_name
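        # Illustration: "sample0001/image.jpg" -> ("sample0001", "image.jpg"),
        # while a bare "sample0001" -> ("sample0001", "").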
def readdir(self, path: str, fh: int = 0) -> Generator[str, None, None]:
"""List the contents of a directory.
This is called when using `ls` etc.
Returns a generator of the entries in the directory as strings.
"""
if path[0] != "/":
raise FuseOSError(ENOENT)
path = path[1:]
if self._sample_folders:
if path == "":
yield "."
yield ".."
for entry in self._all_samples.keys():
yield entry
else:
folder, part_name = self._path_parts(path)
if folder not in self._all_samples or part_name != "":
raise FuseOSError(ENOENT)
yield "."
yield ".."
single_tar_id = None
all_entries = list(
self._wds_filestore.list_sample_parts(folder, slow_mode=self._slow_mode)
)
for entry, size, tar_file_id in all_entries:
if single_tar_id is None:
single_tar_id = tar_file_id
elif single_tar_id != tar_file_id:
break
yield entry
if folder in self._samples_with_multiple_tar_files:
yield MULTI_WARN
else:
if path != "":
# Only "/" is allowed for listing all sample parts
raise FuseOSError(ENOENT)
yield "."
yield ".."
for entry in self._all_sample_parts.keys():
yield entry
for key in self._samples_with_multiple_tar_files:
yield f"{key}.{MULTI_WARN}"
def open(self, path: str, flags: int = 0) -> int:
"""Open a file for reading.
Actually, we already read the file into memory when it is opened.
The read operation just returns a slice of the memory buffer.
Returns a dummy file descriptor.
"""
if path[0] != "/":
raise FuseOSError(ENOENT)
path = path[1:]
            # Read-only filesystem: deny any write access
            if flags & (os.O_WRONLY | os.O_RDWR | os.O_APPEND):
                raise FuseOSError(EROFS)
if self._sample_folders:
folder, part_name = self._path_parts(path)
if folder not in self._all_samples:
raise FuseOSError(ENOENT)
full_name = f"{folder}.{part_name}"
file_bytes, _ = self._wds_filestore[full_name]
else:
if path not in self._all_sample_parts:
raise FuseOSError(ENOENT)
file_bytes, _ = self._wds_filestore[path]
assert isinstance(file_bytes, bytes)
self._open_files[path] = file_bytes
# dummy file handle
return 0
def read(self, path: str, size: int, offset: int, fh: int = 0) -> bytes:
"""Read from an open file.
This is called when using `read` etc.
Returns the bytes object of a previously opened file.
"""
if path[0] != "/":
raise FuseOSError(EBADF)
path = path[1:]
if path not in self._open_files:
raise FuseOSError(ENOENT)
data = self._open_files[path]
return data[offset : offset + size]
def release(self, path: str, fh: int = 0) -> None:
"""Release an open file.
This is called when the file is closed. We can now discard the memory buffer.
"""
if path[0] != "/":
raise FuseOSError(ENOENT)
path = path[1:]
if path not in self._open_files:
raise FuseOSError(ENOENT)
del self._open_files[path]
def destroy(self, path: str) -> None:
print("Closing energon mount.")
if len(self._open_files) > 0:
print(f"Number of still open files: {len(self._open_files)}")
self._wds_filestore.close()
EnergonFS = _EnergonFS
except (ImportError, OSError):
# mfusepy or fuse not installed, so we can't mount the filesystem
EnergonFS = None
@click.command(name="mount")
@click.argument(
"path",
type=click.Path(path_type=EPath),
)
@click.argument(
"mountpoint",
type=click.Path(path_type=Path),
)
@click.option(
"--detach",
"-d",
is_flag=True,
default=False,
help="Run in background",
)
@click.option(
"--sample-folders",
"-s",
is_flag=True,
default=False,
help="Present a virtual folder per sample. Otherwise a single folder with all sample parts.",
)
@click.option(
"--allow-slow",
"-w",
is_flag=True,
default=False,
help="Allow slow mode. This is useful if the dataset was prepared with an older version of energon.",
)
def command(path: EPath, mountpoint: Path, detach: bool, sample_folders: bool, allow_slow: bool):
"""
Mount an energon WebdatasetFileStore at the given mountpoint.
The PATH should point to the folder with the dataset.
"""
path = EPath(path)
mountpoint = Path(mountpoint)
print(f"Mounting {path} at {mountpoint}...")
mountpoint.mkdir(parents=True, exist_ok=True)
if EnergonFS is not None:
energon_fs = EnergonFS(path, sample_folders=sample_folders, allow_slow_mode=allow_slow)
else:
print("mfusepy or fuse not installed, so we can't mount the filesystem.")
return
print("Mount is ready to go.")
try:
FUSE(
operations=energon_fs,
mountpoint=str(mountpoint),
foreground=not detach,
nothreads=True,
ro=True,
)
except RuntimeError as e:
if str(e) != "7":
raise
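# Example invocation (hypothetical paths; `-s` presents one folder per sample,
# the exact entry point depends on how the CLI is installed):
#   energon mount /data/my_webdataset /mnt/my_webdataset -s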
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import dataclasses
import inspect
import json
import re
import typing
from types import FunctionType
from typing import Any, List, Optional, Tuple, Type
import click
import yaml
from megatron.energon.epathlib import EPath
from megatron.energon.flavors import (
BaseWebdatasetFactory,
CrudeJsonlDatasetFactory,
CrudeWebdataset,
Sample,
)
from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME
from megatron.energon.flavors.webdataset.metadata import (
EnergonDatasetType,
check_dataset_info_present,
get_dataset_info,
get_dataset_type,
)
from megatron.energon.metadataset.loader import prepare_metadataset
def type_str(tp: Type) -> str:
"""Returns a human-readable string for a type."""
if typing.get_origin(tp) is not None:
return repr(tp)
if isinstance(tp, type):
if tp.__module__ == "builtins":
return tp.__qualname__
return f"{tp.__module__}.{tp.__qualname__}"
if tp is ...:
return "..."
if isinstance(tp, FunctionType):
return tp.__name__
return repr(tp)
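# Illustration (assumed behavior of the branches above):
#   type_str(int)           -> "int"                       (builtin)
#   type_str(EPath)         -> "<defining module>.EPath"   (module-qualified)
#   type_str(Optional[str]) -> "typing.Optional[str]"      (repr of the typing construct)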
def sample_loader_template(fields: Tuple[dataclasses.Field, ...], parts: list):
"""Returns a template for a sample_loader.py file."""
fields_str = ""
for field in fields:
if field.name in ("__key__", "__restore_key__", "__subflavors__"):
continue
line = f""" {field.name}=raw["TODO"], # expected type: {type_str(field.type)}"""
if field.default is not dataclasses.MISSING:
line += ", default: " + repr(field.default)
fields_str += line + "\n"
return "\n".join(
[
"# This file was automatically generated by `energon prepare`.",
"# TODO: Edit it to return the proper fields",
"# import torch",
"",
"def sample_loader(raw: dict) -> dict:"
" # Note: Images are already decoded to tensors",
" # TODO: Set the correct values for all (required) fields",
" return dict(",
fields_str,
" )",
"",
"def part_filter(part: str) -> bool:",
" # TODO: Filter for parts required by the sample_loader",
" # E.g. if your dataset contains jpeg, txt and json, but you won't use json,",
" # remove it from the list, such that it is not decoded. If you need all, keep as is",
f" return part in {tuple(parts)!r}",
"",
]
)
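# Illustration: for a hypothetical sample type with a single field
# `image: torch.Tensor`, the template would emit a line like
#   image=raw["TODO"], # expected type: torch.Tensor
# inside the generated `return dict(...)` block.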
def printify_json(data: Any) -> Any:
"""Shortens json data to a human-readable length."""
if isinstance(data, dict):
return {k: printify_json(v) for k, v in data.items()}
elif isinstance(data, list):
if len(data) > 3:
return [printify_json(v) for v in data[:3]] + ["..."]
return [printify_json(v) for v in data]
elif isinstance(data, str):
return data[:25] + ("..." if len(data) > 25 else "")
return data
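# Illustration (assumed): printify_json({"caption": "x" * 40, "ids": [1, 2, 3, 4]})
# shortens the string to its first 25 characters plus "..." and truncates the
# list to its first three entries plus an "..." marker.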
@click.command(name="prepare")
@click.argument(
"path",
type=click.Path(path_type=EPath),
)
@click.option(
"--progress/--no-progress",
default=True,
)
@click.option(
"--split-parts",
help="Path pattern for parts in the form 'train:train/{000000-009999}.tar'. Will ignore ratio.",
multiple=True,
default=None,
)
@click.option(
"--exclude",
help="Exclude tar file paths (relative to root) matching that regex (at any position)",
)
@click.option(
"--num-workers",
type=int,
default=16,
help="Number of workers to use to index tar files",
)
@click.option(
"--tar-index-only",
help="Only (re)generate the tar-index",
is_flag=True,
)
@click.option(
"--shuffle-tars",
help="If set, the tar files will be shuffled before splitting.",
is_flag=True,
)
def command(
path: EPath,
progress: bool,
split_parts: Optional[List[str]],
exclude: str,
num_workers: int,
tar_index_only: bool,
shuffle_tars: bool,
):
"""Prepare WebDataset for use with energon.
The PATH should point to the folder with the dataset.
This tool will add the required metadata yaml files to the dataset. See README.md for more
details.
"""
ds_type = get_dataset_type(path)
if ds_type == EnergonDatasetType.METADATASET:
print("Preparing metadataset...")
prepare_metadataset(path)
return
elif ds_type == EnergonDatasetType.JSONL:
print("Preparing jsonl dataset...")
count = CrudeJsonlDatasetFactory.prepare_dataset(path)
print(f"Done. Found {count} samples.")
return
assert path.is_dir(), f"Path {path} is not a known dataset type"
if tar_index_only:
info = get_dataset_info(path)
all_tars = list(info["shard_counts"].keys())
else:
if check_dataset_info_present(path):
        if not click.confirm(
            "It seems the dataset has already been prepared. Do you want to continue?"
        ):
return
all_tars = list(path.glob("**/*.tar")) + list(path.glob("**/*.tgz"))
all_tars = [str(p.relative_to(path)) for p in sorted(all_tars)]
if exclude:
all_tars = [p for p in all_tars if not re.search(exclude, p)]
if len(all_tars) == 0:
click.echo("Did not find any tar files. Exiting.")
return
if not tar_index_only:
click.echo(f"Found {len(all_tars)} tar files in total. The first and last ones are:")
click.echo(f"- {all_tars[0]}")
click.echo(f"- {all_tars[-1]}")
click.echo(
"If you want to exclude some of them, cancel with ctrl+c and specify an exclude "
"filter in the command line."
)
split_parts_patterns: Optional[List[Tuple[str, str]]]
if split_parts:
split_parts_patterns = [tuple(x.split(":", 1)) for x in split_parts]
split_parts_ratio = None
elif not tar_index_only:
split_input = click.prompt(
'Please enter a desired train/val/test split like "0.5, 0.2, 0.3" or "8,1,1"', type=str
)
# Extract split floats
try:
split = [float(x.strip()) for x in split_input.split(",")]
assert len(split) == 3
except (ValueError, AssertionError):
click.echo("Invalid split. Stopping.")
return
split_parts_ratio = [("train", split[0]), ("val", split[1]), ("test", split[2])]
split_parts_patterns = None
else:
split_parts_ratio = None
split_parts_patterns = None
if progress:
def progress_fn(els, length=None):
with click.progressbar(
els,
label="Indexing shards",
show_pos=True,
length=length,
) as bar:
yield from bar
else:
def progress_fn(els, length=None):
return els
found_types, duplicates = BaseWebdatasetFactory.prepare_dataset(
path,
all_tars,
split_parts_ratio=split_parts_ratio,
split_parts_patterns=split_parts_patterns,
progress_fn=progress_fn,
tar_index_only=tar_index_only,
shuffle_seed=42 if shuffle_tars else None,
workers=num_workers,
)
if duplicates:
print(f"Examples of duplicates found: {duplicates}")
print()
print(
"The dataset has duplicate keys. Best practice is to use unique keys. "
"You won't be able to use this dataset for joining "
"later on."
)
found_types = list(found_types)
if tar_index_only:
return
if duplicates:
if not click.confirm("Do you want to continue?"):
return
# Print json of first two samples
for sample_idx, data in enumerate(
BaseWebdatasetFactory.iter_dataset_content(path / all_tars[0], ("json",))
):
print(f"Sample {sample_idx}, keys:")
for key in data.keys():
print(f" - {key}")
if "json" in data:
print(f"Json content of sample {sample_idx} of {all_tars[0]}:")
print(json.dumps(printify_json(json.loads(data["json"])), indent=2))
if sample_idx >= 1:
break
    if len(found_types) > 10:
        click.echo(
            f"Found the following part types in the dataset: {', '.join(found_types[:10])} and more..."
        )
allow_interactive_field_map = False
else:
click.echo(f"Found the following part types in the dataset: {', '.join(found_types)}")
allow_interactive_field_map = True
if click.confirm("Do you want to create a dataset.yaml interactively?", default=True):
# Get a list of all classes in megatron.energon that are subclasses of WebdatasetBase
import megatron.energon as data_import
display_name_and_class = [
(name, cls)
for name, cls in inspect.getmembers(data_import)
if isinstance(cls, type) and issubclass(cls, Sample)
]
display_name_and_class.append(("Crude sample (plain dict for cooking)", CrudeWebdataset))
# Print all classes and ask user to pick one
click.echo("The following sample types are available:")
for i, (name, cls) in enumerate(display_name_and_class):
click.echo(f"{i}. {name}")
while True:
choice = click.prompt("Please enter a number to choose a class", type=int)
try:
_, cls = display_name_and_class[choice]
break
except IndexError:
click.echo("Invalid choice. Please try again.")
continue
if cls == CrudeWebdataset:
click.echo(
"CrudeWebdataset does not need a field map. You will need to provide a `Cooker` for your dataset samples in your `TaskEncoder`."
)
click.echo(
"Furthermore, you might want to add `subflavors` in your meta dataset specification."
)
dataset_definition = {
"__module__": "megatron.energon",
"__class__": cls.__name__,
}
else:
click.echo("The sample type you selected:\n")
click.echo(inspect.getsource(cls))
dataset_definition = {
"sample_type": {
"__module__": "megatron.energon",
"__class__": cls.__name__,
},
}
if not allow_interactive_field_map:
click.echo(
"You cannot set a field_map for this dataset. You will need a sample_loader."
)
if allow_interactive_field_map and click.confirm(
"Do you want to set a simple field_map[Y] (or write your own sample_loader [n])?",
default=True,
):
click.echo(
"\nFor each field, please specify the corresponding name in the WebDataset."
)
click.echo(f"Available types in WebDataset: {', '.join(found_types)}")
click.echo("Leave empty for skipping optional field")
click.echo(
"You may also access json fields e.g. by setting the field to: json[field][field]"
)
click.echo("You may also specify alternative fields e.g. by setting to: jpg,png")
click.echo(f"Please enter the field_map for {cls.__name__}:")
dataset_definition["field_map"] = field_map = {}
for field in dataclasses.fields(cls):
if field.name in (
"__key__",
"__restore_key__",
"__subflavors__",
"__sources__",
):
continue
while True:
if (
field.default is dataclasses.MISSING
and field.default_factory is dataclasses.MISSING
):
default = ""
elif field.default is not dataclasses.MISSING:
default = f", default: {field.default}"
elif field.default_factory is not dataclasses.MISSING:
default = f", default: {field.default_factory!r}"
else:
raise RuntimeError("This should never happen")
field_map[field.name] = input(
f"Please enter a webdataset field name for '{field.name}' "
f"({field.type}{default}): ",
)
if not field_map[field.name] and default:
del field_map[field.name]
break
type_ok = True
for option in field_map[field.name].split(","):
field_name = option.split("[", 1)[0]
if field_name not in found_types:
click.echo(
"That type doesn't exist in the WebDataset. Please try again."
)
type_ok = False
if type_ok:
break
else:
if not allow_interactive_field_map:
template_part_types = set(["TODO"])
else:
template_part_types = found_types
if not (path / MAIN_FOLDER_NAME / "sample_loader.py").is_file() or click.confirm(
"Do you want to override the existing sample loader?"
):
with (path / MAIN_FOLDER_NAME / "sample_loader.py").open("w") as f:
f.write(
sample_loader_template(
dataclasses.fields(cls),
parts=template_part_types,
)
)
click.echo(
f"\nCreated {path / MAIN_FOLDER_NAME / 'sample_loader.py'}. Please edit it to "
f"return the proper values."
)
dataset_definition["sample_loader"] = "sample_loader.py:sample_loader"
dataset_definition["part_filter"] = "sample_loader.py:part_filter"
# Write the dataset.yaml file
with (path / MAIN_FOLDER_NAME / "dataset.yaml").open("w") as f:
yaml.dump(dataset_definition, f, sort_keys=False)
else:
click.echo("You will have to add a dataset.yaml manually.")
click.echo("Done")
if __name__ == "__main__":
command()
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import dataclasses
import shutil
import traceback
from typing import Any
import click
import torch
from megatron.energon import Sample, StandardWebdatasetFactory
from megatron.energon.dataset_config import load_config
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME
from megatron.energon.flavors.webdataset.empty_dataset_error import EmptyDatasetError
from megatron.energon.loader import get_loader
from megatron.energon.worker import WorkerConfig
def to_str(value: Any, indent: str) -> str:
if isinstance(value, torch.Tensor):
orig_value = value
# Probably image?
if value.ndim == 3 and value.shape[0] in [1, 3, 4]:
# Convert to grayscale
if value.shape[0] == 1:
value = value[0]
elif value.shape[0] == 3:
value = value.to(dtype=torch.float32).mean(dim=0)
elif value.shape[0] == 4:
value = value[:3].to(dtype=torch.float32).mean(dim=0)
if value.ndim == 2:
# 2d image -> ascii print
# Resize to fit terminal
dst_w, dst_h = shutil.get_terminal_size((80, 24))
orig_h, orig_w = value.shape
dst_w -= len(indent)
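            # Correction factor for the character cell aspect ratio (terminal
            # cells are taller than wide), presumably ~0.3 rows per column.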
procrustes = 0.3
# keep aspect ratio
if orig_w / orig_h < dst_w / dst_h:
dst_h = int(dst_w * procrustes * orig_h / orig_w)
else:
dst_w = int(dst_h / procrustes * orig_w / orig_h)
value = torch.nn.functional.interpolate(
value[None, None, :, :].to(dtype=torch.float32), size=(dst_h, dst_w), mode="area"
)[0, 0]
# normalize
value = (value - value.min()) / (value.max() - value.min())
# to ascii text
return (
f"Tensor(shape={orig_value.shape}, dtype={orig_value.dtype}):\n{indent}"
+ f"\n{indent}".join(
"".join(" .:-=+*#%@@"[int(v * 10)] for v in row) for row in value.tolist()
)
+ "\n"
)
elif value.ndim == 1:
# 1d array... print it?
return f"Tensor(shape={value.shape}, dtype={value.dtype}): {value[:128].tolist()}"
else:
return f"Tensor(shape={value.shape}, dtype={value.dtype})"
elif isinstance(value, (str, int, float, bool, type(None))):
return repr(value)
    elif isinstance(value, (list, tuple)):
        if hasattr(value, "_fields"):
            # Named tuple: _fields holds the field names as strings
            return (
                f"{type(value).__name__}(\n{indent}"
                + f",\n{indent} ".join(
                    f"{name}={to_str(v, indent + ' ')}"
                    for v, name in zip(value, value._fields)
                )
                + f"\n{indent})"
            )
        if len(value) > 0 and all(isinstance(v, (str, int, float, bool)) for v in value):
            return repr(type(value)(to_str(v, indent) for v in value))
else:
return (
f"[\n{indent}"
+ f"\n{indent} ".join(to_str(v, indent + " ") for v in value)
+ f"\n{indent}]"
)
elif isinstance(value, bytes):
return f"bytes(length={len(value)}, value={value[:128]!r})"
return repr(value)
def pprint(idx: int, sample: Sample):
click.echo(f"Sample {idx}")
for field in dataclasses.fields(sample):
if field.name in ("__restore_key__", "__subflavors__", "__sources__"):
continue
click.echo(f" - {field.name} ({field.type}): {to_str(getattr(sample, field.name), '')}")
@click.command(name="preview")
@click.argument(
"path",
type=click.Path(file_okay=False, dir_okay=True, path_type=EPath),
)
@click.option(
"--split-parts", default="train,val,test", help="The splits to verify", show_default=True
)
@click.option(
"--dataset-config", default="dataset.yaml", help="Dataset config file name", show_default=True
)
def command(path: EPath, split_parts: str, dataset_config: str):
"""Preview samples of a dataset on the console."""
worker_config = WorkerConfig(rank=0, world_size=1, num_workers=0)
for split_part in split_parts.split(","):
try:
dataset = load_config(
EPath(path) / MAIN_FOLDER_NAME / dataset_config,
default_kwargs=dict(
path=path,
split_part=split_part,
training=False,
worker_config=worker_config,
),
default_type=StandardWebdatasetFactory,
)
except EmptyDatasetError:
click.echo(f"Dataset {split_part} is empty. Skipping.")
continue
try:
for idx, sample in enumerate(get_loader(dataset.build())):
pprint(idx, sample)
click.confirm("Continue?", abort=True)
except click.Abort:
click.echo("Exiting Preview")
except BaseException:
traceback.print_exc()
raise click.ClickException("Validation failed with errors, see logs for details.")
if __name__ == "__main__":
command()
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from megatron.energon.transforms.custom import CustomTransform
from megatron.energon.transforms.mappers import (
CenterCropMapper,
RandomCropMapper,
RandomHorizontalFlipMapper,
RandomPerspectiveMapper,
RandomResizedCropMapper,
RandomRotationMapper,
RandomVerticalFlipMapper,
ResizeMapper,
TransformMapper,
)
from megatron.energon.transforms.merge import MergeTransform
__all__ = [
"TransformMapper",
"ResizeMapper",
"RandomResizedCropMapper",
"CenterCropMapper",
"RandomCropMapper",
"RandomHorizontalFlipMapper",
"RandomVerticalFlipMapper",
"RandomRotationMapper",
"RandomPerspectiveMapper",
"MergeTransform",
"CustomTransform",
]