Commit f356f546 authored by maming

Initial commit
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from enum import Enum
from typing import TypeVar
import PIL.Image
import torch
TTransform = TypeVar("TTransform", bound=torch.nn.Module)
class Sampling(Enum):
BILINEAR = PIL.Image.BILINEAR
NEAREST = PIL.Image.NEAREST
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import torch
from megatron.energon.transforms.mappers import TransformMapper
class CustomTransform(TransformMapper[torch.nn.Module]):
"""Abstract class for additional custom transforms.
Inherit from this and override at least `apply_transform`.
"""
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import math
import numbers
from abc import ABC, abstractmethod
from typing import Any, ClassVar, Generic, List, Optional, Tuple, Type, Union
import numpy as np
import torch
from torchvision import transforms as T
from megatron.energon.transforms.common import TTransform
class TransformMapper(ABC, Generic[TTransform]):
source_type: ClassVar[Type[torch.nn.Module]]
transform: TTransform
def __init__(self, transform: TTransform):
self.transform = transform
def __call__(self, sample):
return self.transform(sample)
@staticmethod
def translate(x: float, y: float) -> np.ndarray:
m = np.eye(3, dtype=np.float64)
m[0, 2] = x
m[1, 2] = y
return m
@staticmethod
def rotate(angle: float) -> np.ndarray:
"""Counter-clockwise rotation. Note that the Y-axis is point down."""
m = np.eye(3, dtype=np.float64)
m[:2, :2] = np.array([[np.cos(angle), np.sin(angle)], [-np.sin(angle), np.cos(angle)]])
return m
@staticmethod
def scale(x: float, y: float) -> np.ndarray:
m = np.eye(3, dtype=np.float64)
m[0, 0] = x
m[1, 1] = y
return m
@staticmethod
def shear(x: float, y: float) -> np.ndarray:
m = np.eye(3, dtype=np.float64)
m[0, 1] = x
m[1, 0] = y
return m
@staticmethod
def hflip() -> np.ndarray:
m = np.eye(3, dtype=np.float64)
m[0, 0] = -1
return m
@staticmethod
def vflip() -> np.ndarray:
m = np.eye(3, dtype=np.float64)
m[1, 1] = -1
return m
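    # Each helper above returns a 3x3 homogeneous matrix acting on (x, y, 1) column
    # vectors; `apply_transform` implementations compose them onto the running matrix
    # via left-multiplication, e.g. `matrix = self.translate(tx, ty) @ matrix`.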
@abstractmethod
def apply_transform(
self, matrix: np.ndarray, dst_size: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, Any]: ...
def fill(
self,
) -> Optional[Union[int, float, Tuple[Union[int, float], ...], List[Union[int, float]]]]:
return None
def interpolation(self) -> Optional[T.InterpolationMode]:
return None
class ResizeMapper(TransformMapper[T.Resize]):
source_type = T.Resize
def __init__(self, transform: T.Resize):
super().__init__(transform)
def _compute_resized_output_size(
self, image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> List[int]:
if len(size) == 1: # specified size only for the smallest edge
h, w = image_size
short, long = (w, h) if w <= h else (h, w)
requested_new_short = size[0]
new_short, new_long = requested_new_short, int(requested_new_short * long / short)
if max_size is not None:
if max_size <= requested_new_short:
raise ValueError(
f"max_size = {max_size} must be strictly greater than the requested "
f"size for the smaller edge size = {size}"
)
if new_long > max_size:
new_short, new_long = int(max_size * new_short / new_long), max_size
new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
else: # specified both h and w
new_w, new_h = size[1], size[0]
return [new_h, new_w]
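    # Example: for an input of (h, w) = (480, 640) and size=[256] (smallest edge), this
    # returns [256, 341]: the short edge is scaled to 256 and the long edge is scaled
    # proportionally (truncated to an int).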
def apply_transform(
self, matrix: np.ndarray, dst_size: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, Tuple[Any, ...]]:
size = self.transform.size
if isinstance(size, int):
size = [size]
h, w = self._compute_resized_output_size(dst_size, size, self.transform.max_size)
matrix = self.scale(w / dst_size[1], h / dst_size[0]) @ matrix
# matrix = self.scale((w - 1) / (dst_size[1] - 1), (h - 1) / (dst_size[0] - 1)) @ matrix
# matrix = self.translate(0.25, 0.25) @ matrix
# matrix = self.translate(0.1, 0) @ matrix
dst_size = np.array((h, w), dtype=dst_size.dtype)
# print(f"Resize s={size}")
return matrix, dst_size, (self.source_type.__name__, size)
def interpolation(self) -> Optional[T.InterpolationMode]:
return self.transform.interpolation
class RandomResizedCropMapper(TransformMapper[T.RandomResizedCrop]):
source_type = T.RandomResizedCrop
def __init__(self, transform: T.RandomResizedCrop):
super().__init__(transform)
def get_params(self, size: Tuple[int, int]) -> Tuple[int, int, int, int]:
"""
Gets the parameters for a random resized crop.
This function is derived from T.RandomResizedCrop.get_params, but without requiring the
input image (to determine the input size).
Returns:
Tuple of (top, left, height, width).
"""
height, width = size
area = height * width
log_ratio = torch.log(torch.tensor(self.transform.ratio))
for _ in range(10):
target_area = (
area
* torch.empty(1).uniform_(self.transform.scale[0], self.transform.scale[1]).item()
)
aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if 0 < w <= width and 0 < h <= height:
i = torch.randint(0, height - h + 1, size=(1,)).item()
j = torch.randint(0, width - w + 1, size=(1,)).item()
return i, j, h, w
# Fallback to central crop
in_ratio = float(width) / float(height)
if in_ratio < min(self.transform.ratio):
w = width
h = int(round(w / min(self.transform.ratio)))
elif in_ratio > max(self.transform.ratio):
h = height
w = int(round(h * max(self.transform.ratio)))
else: # whole image
w = width
h = height
i = (height - h) // 2
j = (width - w) // 2
return i, j, h, w
def apply_transform(
self, matrix: np.ndarray, dst_size: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, Tuple[Any, ...]]:
top, left, height, width = self.get_params(dst_size)
# print(
# "RandomResizedCrop", top, left, dst_size[0] - height - top, dst_size[1] - width - left
# )
# Crop to left, top, height, width
matrix = self.translate(-left, -top) @ matrix
dst_size = np.array([height, width], dtype=dst_size.dtype)
# Resize to target size
matrix = (
self.scale(self.transform.size[1] / dst_size[1], self.transform.size[0] / dst_size[0])
@ matrix
)
dst_size = np.array(self.transform.size, dtype=dst_size.dtype)
return matrix, dst_size, (self.source_type.__name__, (top, left, height, width))
def interpolation(self) -> Optional[T.InterpolationMode]:
return self.transform.interpolation
class RandomHorizontalFlipMapper(TransformMapper[T.RandomHorizontalFlip]):
source_type = T.RandomHorizontalFlip
def __init__(self, transform: T.RandomHorizontalFlip):
super().__init__(transform)
def apply_transform(
self, matrix: np.ndarray, dst_size: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, Any]:
do_flip = torch.rand(1) < self.transform.p
if do_flip:
matrix = self.hflip() @ matrix
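            # Mirroring about x=0 moves the image into negative x coordinates; the
            # following translation by the width shifts it back into [0, width).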
matrix = self.translate(dst_size[1], 0) @ matrix
# print(f"RandomHorizontalFlip")
return matrix, dst_size, (self.source_type.__name__, do_flip)
class RandomVerticalFlipMapper(TransformMapper[T.RandomVerticalFlip]):
source_type = T.RandomVerticalFlip
def __init__(self, transform: T.RandomVerticalFlip):
super().__init__(transform)
def apply_transform(
self, matrix: np.ndarray, dst_size: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, Any]:
do_flip = torch.rand(1) < self.transform.p
if do_flip:
matrix = self.vflip() @ matrix
matrix = self.translate(0, dst_size[0]) @ matrix
# print(f"RandomVerticalFlip")
return matrix, dst_size, (self.source_type.__name__, do_flip)
class RandomRotationMapper(TransformMapper[T.RandomRotation]):
source_type = T.RandomRotation
def __init__(self, transform: T.RandomRotation):
super().__init__(transform)
def apply_transform(
self, matrix: np.ndarray, dst_size: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, Any]:
assert self.transform.center is None, "Only centered rotation is supported"
degrees = self.transform.get_params(self.transform.degrees)
rads = degrees * np.pi / 180
# print(f"Rotate deg={degrees}")
orig_size = dst_size
if self.transform.expand:
# Compute size of rotated rectangle
w = np.abs(np.sin(rads)) * dst_size[0] + np.abs(np.cos(rads)) * dst_size[1]
h = np.abs(np.sin(rads)) * dst_size[1] + np.abs(np.cos(rads)) * dst_size[0]
# Round in the same way as PIL does
rounded_w = np.ceil(orig_size[1] / 2 + w / 2) - np.floor(orig_size[1] / 2 - w / 2)
rounded_h = np.ceil(orig_size[0] / 2 + h / 2) - np.floor(orig_size[0] / 2 - h / 2)
# New size is h, w
dst_size = np.array([int(rounded_h), int(rounded_w)], dtype=dst_size.dtype)
matrix = (
self.translate(dst_size[1] / 2, dst_size[0] / 2)
@ self.rotate(rads)
@ self.translate(-orig_size[1] / 2, -orig_size[0] / 2)
@ matrix
)
return matrix, dst_size, (self.source_type.__name__, degrees)
def fill(
self,
) -> Optional[Union[int, float, Tuple[Union[int, float], ...], List[Union[int, float]]]]:
return self.transform.fill
def interpolation(self) -> Optional[T.InterpolationMode]:
return self.transform.interpolation
class RandomCropMapper(TransformMapper[T.RandomCrop]):
source_type = T.RandomCrop
def __init__(self, transform: T.RandomCrop):
super().__init__(transform)
def apply_transform(
self, matrix: np.ndarray, dst_size: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, Any]:
th, tw = self.transform.size # Target height and width
# pad the width if needed
if self.transform.pad_if_needed and dst_size[1] < tw:
padding = tw - dst_size[1] # Pad this much on both left and right
matrix = self.translate(padding, 0) @ matrix
dst_size[1] += 2 * padding
# pad the height if needed
if self.transform.pad_if_needed and dst_size[0] < th:
padding = th - dst_size[0] # Pad this much on both top and bottom
matrix = self.translate(0, padding) @ matrix
dst_size[0] += 2 * padding
h, w = dst_size
if h < th or w < tw:
raise ValueError(
f"Required crop size {(th, tw)} is larger than input image size {(h, w)}"
)
if w == tw and h == th:
# No need to crop if we're at the target size already
i = 0
j = 0
else:
i = torch.randint(0, h - th + 1, size=(1,)).item() # Offset y
j = torch.randint(0, w - tw + 1, size=(1,)).item() # Offset x
matrix = self.translate(-j, -i) @ matrix
if self.transform.pad_if_needed:
dst_size = np.array((th, tw), dtype=dst_size.dtype)
else:
dst_size = np.array((min(th, dst_size[0]), min(tw, dst_size[1])), dtype=dst_size.dtype)
# print(f"RandomCrop t=[{dx}, {dy}], s={dst_size}")
return matrix, dst_size, (self.source_type.__name__, (j, i, th, tw))
def fill(
self,
) -> Optional[Union[int, float, Tuple[Union[int, float], ...], List[Union[int, float]]]]:
return self.transform.fill
class RandomPerspectiveMapper(TransformMapper[T.RandomPerspective]):
source_type = T.RandomPerspective
def __init__(self, transform: T.RandomPerspective):
super().__init__(transform)
@staticmethod
def compute_homography(
startpoints: List[Tuple[float, float]], endpoints: List[Tuple[float, float]]
) -> np.ndarray:
assert len(startpoints) == len(endpoints) == 4
a_matrix = torch.zeros(2 * len(startpoints), 8, dtype=torch.float)
for i, (p1, p2) in enumerate(zip(endpoints, startpoints)):
a_matrix[2 * i, :] = torch.tensor(
[p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]]
)
a_matrix[2 * i + 1, :] = torch.tensor(
[0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]]
)
b_matrix = torch.tensor(startpoints, dtype=torch.float).view(8)
res = torch.linalg.lstsq(a_matrix, b_matrix, driver="gels").solution
m = np.eye(3, dtype=np.float32)
m[0, :] = res[:3]
m[1, :] = res[3:6]
m[2, :2] = res[6:]
return m
def apply_transform(
self, matrix: np.ndarray, dst_size: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, Any]:
assert self.transform.fill == 0, "Only zero fill is supported"
startpoints = None
endpoints = None
if torch.rand(1) <= self.transform.p:
startpoints, endpoints = self.transform.get_params(
dst_size[1], dst_size[0], self.transform.distortion_scale
)
# print(
# f"Perspective ds={self.transform.distortion_scale}: sp={startpoints} -> ep={endpoints}"
# )
matrix = self.compute_homography(endpoints, startpoints) @ matrix
return matrix, dst_size, (self.source_type.__name__, startpoints, endpoints)
def fill(
self,
) -> Optional[Union[int, float, Tuple[Union[int, float], ...], List[Union[int, float]]]]:
return self.transform.fill
def interpolation(self) -> Optional[T.InterpolationMode]:
return self.transform.interpolation
class CenterCropMapper(TransformMapper[T.CenterCrop]):
source_type = T.CenterCrop
def __init__(self, transform: T.CenterCrop):
super().__init__(transform)
def apply_transform(
self, matrix: np.ndarray, dst_size: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, Any]:
if isinstance(self.transform.size, numbers.Number):
th, tw = int(self.transform.size), int(self.transform.size)
elif isinstance(self.transform.size, (tuple, list)) and len(self.transform.size) == 1:
th, tw = self.transform.size[0], self.transform.size[0]
else:
th, tw = self.transform.size
shift_y = round(float(th - dst_size[0]) / 2)
shift_x = round(float(tw - dst_size[1]) / 2)
matrix = self.translate(shift_x, shift_y) @ matrix
dst_size = np.array((th, tw), dtype=dst_size.dtype)
# print(f"CenterCrop t=[{dx}, {dy}], s={dst_size}")
return matrix, dst_size, (self.source_type.__name__, (shift_y, shift_x, th, tw))
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import List, Type
import numpy as np
import PIL.Image
import torch
from torchvision import transforms as T
from torchvision.transforms.functional import InterpolationMode, pil_modes_mapping
from megatron.energon.transforms.custom import CustomTransform
from megatron.energon.transforms.mappers import TransformMapper
class MergeTransform(torch.nn.Module):
transforms: List[TransformMapper]
def __init__(self, transforms: List[torch.nn.Module]):
super().__init__()
self.transforms = [self._transform_to_mapper(t) for t in transforms]
# Get common interpolation and fill values
interpolation = None
interpolation_t = None
fill = None
fill_t = None
for t in self.transforms:
t_fill = t.fill()
if t_fill is not None:
if fill is None:
fill = t_fill
fill_t = t
if fill is not None and t_fill != fill:
raise ValueError(
f"Fill values are not equal: {fill} ({fill_t}) != {t_fill} ({t})"
)
t_interpolation = t.interpolation()
if t_interpolation is not None:
if interpolation is None:
interpolation = t_interpolation
interpolation_t = t
if interpolation is not None and t_interpolation != interpolation:
raise ValueError(
f"Interpolation values are not equal: {interpolation} ({interpolation_t}) != {t_interpolation} ({t})"
)
self.interpolation = InterpolationMode.BILINEAR if interpolation is None else interpolation
self.fill_value = fill
    def _transform_to_mapper(self, transform: torch.nn.Module) -> TransformMapper:
"""Given a transform object, instantiate the corresponding mapper.
This also handles objects of derived transform classes."""
if isinstance(transform, CustomTransform):
# Custom transforms can be used as-is, they provide the apply_transform method
return transform
for m in TransformMapper.__subclasses__():
if isinstance(transform, m.source_type):
return m(transform) # Instantiate
raise ValueError(f"Unsupported transform type {type(transform)}")
def forward(self, x):
matrix = np.eye(3, dtype=np.float64)
if isinstance(x, PIL.Image.Image):
dst_size = np.array((x.height, x.width), dtype=np.int64)
else:
dst_size = np.array(x.shape[-2:], dtype=np.int64)
all_params = []
for transform in self.transforms:
matrix, dst_size, params = transform.apply_transform(matrix, dst_size)
all_params.append(params)
if isinstance(x, PIL.Image.Image):
try:
interpolation = pil_modes_mapping[self.interpolation]
except KeyError:
raise NotImplementedError(f"interpolation: {self.interpolation}")
# Invert matrix for backward mapping
matrix = np.linalg.inv(matrix)
# Scale matrix
matrix /= matrix[2, 2]
if self.fill_value is None:
fill_color = None
elif isinstance(self.fill_value, (int, float)):
fill_color = (self.fill_value,) * len(x.getbands())
else:
fill_color = self.fill_value
if np.allclose(matrix[2, :2], [0, 0]):
# print("PIL Affine")
return x.transform(
tuple(dst_size[::-1]),
PIL.Image.AFFINE,
matrix.flatten()[:6],
interpolation,
fillcolor=fill_color,
)
else:
# print("PIL Perspective")
return x.transform(
tuple(dst_size[::-1]),
PIL.Image.PERSPECTIVE,
matrix.flatten()[:8],
interpolation,
fillcolor=fill_color,
)
elif isinstance(x, torch.Tensor):
print("torch affine")
if self.interpolation == T.InterpolationMode.NEAREST:
interpolation = "nearest"
elif self.interpolation == T.InterpolationMode.BILINEAR:
interpolation = "bilinear"
elif self.interpolation == T.InterpolationMode.BICUBIC:
interpolation = "bicubic"
else:
raise NotImplementedError(f"interpolation: {self.interpolation}")
if self.fill_value is not None and self.fill_value != 0:
raise NotImplementedError(
f"Fill value {self.fill_value} is not supported for torch"
)
# Normalize to [-1, 1] range
matrix = (
TransformMapper.translate(-1, -1)
@ TransformMapper.scale(2 / dst_size[1], 2 / dst_size[0])
@ matrix
@ TransformMapper.scale(x.shape[-1] / 2, x.shape[-2] / 2)
@ TransformMapper.translate(1, 1)
)
matrix = np.linalg.inv(matrix)
if np.allclose(matrix[2, :2], [0, 0]):
grid = torch.nn.functional.affine_grid(
torch.as_tensor(matrix[None, :2, :], dtype=torch.float32),
torch.Size((1, 3, *dst_size)),
)
else:
xs = torch.linspace(-1, 1, dst_size[1], dtype=torch.float32)
ys = torch.linspace(-1, 1, dst_size[0], dtype=torch.float32)
zs = torch.ones((1,), dtype=torch.float32)
                # shape: (3<x,y,1>, W, H)
grid = torch.stack(torch.meshgrid([xs, ys, zs], indexing="ij"))[..., 0]
                # shape: (H, W, 3<x,y,1>)
grid = grid.permute(2, 1, 0)
# shape: (H, W, 3<x,y,w>, 1)
grid = (
torch.as_tensor(matrix, dtype=torch.float32)[None, None, ...] @ grid[..., None]
)
# shape: (H, W, 2<x,y>)
grid = grid[:, :, :2, 0] / grid[:, :, 2:3, 0]
# shape: (1, H, W, 2<x,y>)
grid = grid[None, ...]
return torch.nn.functional.grid_sample(
x[None, ...], grid, interpolation, padding_mode="zeros", align_corners=False
)[0, ...]
else:
raise NotImplementedError()
# TODO: Needs implementation and testing
import cv2
return cv2.warpAffine(x, matrix[:2], tuple(dst_size), flags=cv2.INTER_LINEAR)
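# Illustrative usage sketch (the transform arguments below are assumptions, not part of
# this commit): multiple torchvision transforms are merged into a single warp.
#
#     merged = MergeTransform([
#         T.RandomResizedCrop(224),
#         T.RandomHorizontalFlip(),
#     ])
#     out = merged(PIL.Image.open("example.jpg"))  # both ops applied in one resampling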
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
"""Provides methods for converting typed objects to json objects and vice versa."""
import dataclasses
import importlib
import inspect
import json
import re
import typing
from typing import Any, Callable, Dict, Literal, Optional, Tuple, Type, TypeVar, Union
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib import EPath
class JsonValueError(ValueError):
"""Internal error class, for better error messages."""
def __init__(
self, msg: str, expected_type: Type, value: Any, path: str, stage: Tuple[int, ...]
) -> None:
"""New error"""
super().__init__(msg)
self.expected_type = expected_type
self.value = value
self.path = path
self.stage = stage
class _MissingValue:
"""Internal class representing a missing value (i.e. a field which has not been specified,
    as opposed to explicit `None`) for better error messages."""
def __repr__(self) -> str:
return "missing value"
_missing_value = _MissingValue()
TType = TypeVar("TType")
def _check_instance_type(cls, inst_type: Type) -> bool:
"""Check if a class is an instance of a type."""
if inst_type is None:
return True
elif typing.get_origin(inst_type) is not None:
org = typing.get_origin(inst_type)
if org is Union:
for check_type in typing.get_args(inst_type):
if _check_instance_type(cls, check_type):
return True
elif isinstance(org, type) and issubclass(cls, org):
return True
return False
elif inst_type is Any:
return True
else:
return not issubclass(inst_type, type) or issubclass(cls, inst_type)
@edataclass
class JsonParser:
strict: bool = True
def _resolve_object(
self,
module_name: str,
object_name: str,
cls: Type[TType],
is_type: bool,
is_callable: bool,
is_instantiating_class: bool,
is_calling_function: bool,
) -> TType:
"""Resolve an object reference to the object."""
try:
module = importlib.import_module(module_name)
except ModuleNotFoundError:
try_energon_module = importlib.import_module("megatron.energon", package=None)
if hasattr(try_energon_module, object_name):
module = try_energon_module
else:
raise
try:
return getattr(module, object_name)
except AttributeError:
raise ModuleNotFoundError(f"Object {object_name} not found in {module_name}")
def raw_to_instance(
self,
kwargs: dict,
inst_type: Type[TType],
_path: str = "root",
_stage: Tuple[int, ...] = (),
) -> TType:
"""
Try to import and instantiate a class from a dict with "__module__" and "__class__"/"__function__" keys.
Args:
kwargs: The dict to parse
inst_type: Expected return type, used if type is not specified in the kwargs
_path: (internal for recursive call) The path to the object being converted from the root
_stage: (internal for recursive call) Numbers representing the position of the current
object being converted from the root
Returns:
Instantiated class
"""
kwargs = kwargs.copy()
module_name = kwargs.pop("__module__", None)
        # Check if the annotation is a Type[...] or just a class. Type[...] will return the class instead
# of instantiating it.
is_type = typing.get_origin(inst_type) is type
is_callable = typing.get_origin(inst_type) is typing.get_origin(Callable)
is_calling_function = False
is_instantiating_class = False
if is_type:
inst_type = typing.get_args(inst_type)[0]
object_name = kwargs.pop("__class__", None)
if module_name is None or object_name is None:
raise JsonValueError(
f"Expected __module__ and __class__ for Type[{inst_type}], got {kwargs}",
inst_type,
(module_name, object_name),
_path,
_stage,
)
elif is_callable:
object_name = kwargs.pop("__function__", None)
if module_name is None or object_name is None:
raise JsonValueError(
f"Expected __module__ and __function__ for {inst_type}, got {kwargs}",
inst_type,
(module_name, object_name),
_path,
_stage,
)
else:
if "__class__" in kwargs:
object_name = kwargs.pop("__class__", None)
is_instantiating_class = True
is_calling_function = False
elif "__function__" in kwargs:
object_name = kwargs.pop("__function__", None)
is_instantiating_class = False
is_calling_function = True
# Else case: It's a plain type, and nothing was passed, use the default cls
if module_name is None or object_name is None:
cls = inst_type
else:
cls = self._resolve_object(
module_name,
object_name,
inst_type,
is_type,
is_callable,
is_instantiating_class,
is_calling_function,
)
if is_type:
if isinstance(inst_type, type) and (
not isinstance(cls, type) or not issubclass(cls, inst_type)
):
raise JsonValueError(
f"Expected Type[{inst_type}], got {cls}", inst_type, cls, _path, _stage
)
elif is_callable:
if not callable(cls):
raise JsonValueError(
f"Expected a callable, got {cls}", inst_type, cls, _path, _stage
)
elif is_instantiating_class:
if not isinstance(cls, type) or not _check_instance_type(cls, inst_type):
raise JsonValueError(
f"Expected {inst_type}, got {cls}", inst_type, cls, _path, _stage
)
else:
assert is_calling_function
if not callable(cls):
raise JsonValueError(
f"Expected {inst_type}, got {cls}", inst_type, cls, _path, _stage
)
if is_type or is_callable:
inst = cls
else:
            # Do not assert the other cases; we fall back to the passed cls
inst = self.safe_call_function(kwargs, cls, allow_imports=True)
assert not isinstance(cls, type) or _check_instance_type(type(inst), inst_type), (
f"Expected {inst_type}, got {cls}"
)
return inst
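    # Example (illustrative): a raw dict such as
    #     {"__module__": "torch.nn", "__class__": "ReLU", "inplace": True}
    # resolves torch.nn.ReLU and instantiates it as ReLU(inplace=True) via
    # safe_call_function.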
def raw_to_typed( # noqa: C901
self,
raw_data: Union[dict, list, str, int, bool, float, None],
inst_type: Type[TType],
allow_imports: bool = False,
_path: str = "root",
_stage: Tuple[int, ...] = (),
) -> TType:
"""
Converts raw data (i.e. dicts, lists and primitives) to typed objects (like
`NamedTuple` or `dataclasses.dataclass`). Validates that python typing matches.
Usage::
class MyNamedTuple(NamedTuple):
x: int
y: str
            assert raw_to_typed({'x': 5, 'y': "foo"}, MyNamedTuple) == MyNamedTuple(x=5, y="foo")
Args:
raw_data: The raw (e.g. json) data to be made as `inst_type`
inst_type: The type to return
allow_imports: If true, parse '__module__' and '__class__/__function__' attributes to allow explicit
instantiation of types
_path: (internal for recursive call) The path to the object being converted from the root
_stage: (internal for recursive call) Numbers representing the position of the current
object being converted from the root
Returns:
The input data as `inst_type`.
"""
type_name = getattr(inst_type, "__name__", repr(inst_type))
if raw_data is _missing_value:
raise JsonValueError(
f"Missing value at {_path}",
inst_type,
raw_data,
_path,
_stage,
)
elif inst_type in (str, int, float, bool, None, type(None)):
# Literal types or missing data
if not isinstance(raw_data, inst_type) and not (
isinstance(raw_data, int) and inst_type is float
):
raise JsonValueError(
f"Type does not match, expected {type_name} at {_path}, got {raw_data!r}",
inst_type,
raw_data,
_path,
_stage,
)
return raw_data
elif inst_type is Any:
if (
allow_imports
and isinstance(raw_data, dict)
and "__module__" in raw_data
and ("__class__" in raw_data or "__function__" in raw_data)
):
return self.raw_to_instance(raw_data, inst_type, _path=_path, _stage=_stage)
# Any
return raw_data
elif typing.get_origin(inst_type) is Literal:
# Literal[value[, ...]]
values = typing.get_args(inst_type)
if raw_data not in values:
raise JsonValueError(
f"Expected {type_name} at {_path}, got {raw_data!r}",
inst_type,
raw_data,
_path,
_stage,
)
return raw_data
elif typing.get_origin(inst_type) is Union:
# Union[union_types[0], union_types[1], ...]
union_types = typing.get_args(inst_type)
            if type(None) in union_types:
# Fast Optional path
if raw_data is None:
return None
best_inner_error: Optional[JsonValueError] = None
inner_exceptions = []
for subtype in union_types:
try:
return self.raw_to_typed(
raw_data,
subtype,
allow_imports,
f"{_path} -> {getattr(subtype, '__name__', repr(subtype))}",
_stage + (1,),
)
except JsonValueError as err:
if best_inner_error is None or len(err.stage) > len(best_inner_error.stage):
best_inner_error = err
inner_exceptions.clear()
inner_exceptions.append(err)
elif len(err.stage) == len(best_inner_error.stage):
inner_exceptions.append(err)
continue
if len(inner_exceptions) > 0:
cur_exc = inner_exceptions[0]
for next_exc in inner_exceptions[1:]:
try:
raise next_exc from cur_exc
except JsonValueError as e:
cur_exc = e
raise cur_exc
else:
raise JsonValueError(
f"Expected {inst_type} at {_path}, got {raw_data!r}",
inst_type,
raw_data,
_path,
_stage,
)
elif (
isinstance(inst_type, type)
and issubclass(inst_type, tuple)
and hasattr(inst_type, "__annotations__")
):
# class MyClass(NamedTuple): ...
if not isinstance(raw_data, dict):
raise JsonValueError(
f"Expected {type_name} at {_path}, got {raw_data!r}",
inst_type,
raw_data,
_path,
_stage,
)
if getattr(inst_type, "__dash_keys__", "False"):
raw_data = {key.replace("-", "_"): val for key, val in raw_data.items()}
defaults = getattr(inst_type, "_field_defaults", {})
kwargs = {
field_name: self.raw_to_typed(
raw_data.get(field_name, defaults.get(field_name, _missing_value)),
field_type,
allow_imports,
f"{_path} -> {type_name}:{field_name}",
_stage + (idx,),
)
for idx, (field_name, field_type) in enumerate(inst_type.__annotations__.items())
}
if self.strict and not set(raw_data).issubset(inst_type.__annotations__):
raise JsonValueError(
f"Additional attributes for {type_name} at {_path}, got {raw_data!r}",
inst_type,
raw_data,
_path,
_stage,
)
try:
return inst_type(**kwargs)
except BaseException:
raise JsonValueError(
f"Expected {type_name} at {_path}, got {raw_data!r}",
inst_type,
raw_data,
_path,
_stage,
)
elif dataclasses.is_dataclass(inst_type):
# dataclass
if not isinstance(raw_data, dict):
raise JsonValueError(
f"Expected {type_name} at {_path}, got {raw_data!r}",
inst_type,
raw_data,
_path,
_stage,
)
kwargs = {
field.name: self.raw_to_typed(
raw_data.get(
field.name,
(
(
_missing_value
if field.default_factory is dataclasses.MISSING
else field.default_factory()
)
if field.default is dataclasses.MISSING
else field.default
),
),
field.type,
allow_imports,
f"{_path} -> {type_name}:{field.name}",
_stage + (idx,),
)
for idx, field in enumerate(dataclasses.fields(inst_type))
if field.init
}
if self.strict and not set(raw_data).issubset(
field.name for field in dataclasses.fields(inst_type) if field.init
):
raise JsonValueError(
f"Additional attributes for {type_name} at {_path}, got {raw_data!r}",
inst_type,
raw_data,
_path,
_stage,
)
try:
return inst_type(**kwargs)
except BaseException:
raise JsonValueError(
f"Expected {type_name} at {_path}, got {raw_data!r}",
inst_type,
raw_data,
_path,
_stage,
)
elif typing.get_origin(inst_type) is list:
# List[inner_type]
(inner_type,) = typing.get_args(inst_type)
if not isinstance(raw_data, list):
raise JsonValueError(
f"Expected {type_name} at {_path}, got {raw_data!r}",
inst_type,
raw_data,
_path,
_stage,
)
return [
self.raw_to_typed(
val, inner_type, allow_imports, f"{_path} -> {idx}", _stage + (idx,)
)
for idx, val in enumerate(raw_data)
]
elif typing.get_origin(inst_type) is set:
# Set[inner_type]
(inner_type,) = typing.get_args(inst_type)
if not isinstance(raw_data, list):
raise JsonValueError(
f"Expected {type_name} at {_path}, got {raw_data!r}",
inst_type,
raw_data,
_path,
_stage,
)
res = set(
self.raw_to_typed(
val, inner_type, allow_imports, f"{_path} -> {idx}", _stage + (idx,)
)
for idx, val in enumerate(raw_data)
)
if len(res) != len(raw_data):
raise JsonValueError(
f"Duplicate element at {_path}",
inst_type,
raw_data,
_path,
_stage,
)
return res
elif typing.get_origin(inst_type) is tuple:
# Tuple[inner_types[0], inner_types[1], ...] or Tuple[inner_types[0], Ellipsis/...]
inner_types = typing.get_args(inst_type)
if not isinstance(raw_data, list):
raise JsonValueError(
f"Expected {type_name} at {_path}, got {raw_data!r}",
inst_type,
raw_data,
_path,
_stage,
)
if len(inner_types) == 2 and inner_types[1] is Ellipsis:
# Tuple of arbitrary length, all elements same type
# Tuple[inner_types[0], Ellipsis/...]
return tuple(
self.raw_to_typed(
val, inner_types[0], allow_imports, f"{_path} -> {idx}", _stage + (idx,)
)
for idx, val in enumerate(raw_data)
)
else:
# Fixed size/typed tuple
# Tuple[inner_types[0], inner_types[1], ...]
if len(raw_data) != len(inner_types):
raise JsonValueError(
f"Expected {type_name} at {_path}, got {raw_data!r}",
inst_type,
raw_data,
_path,
_stage,
)
                return tuple(
                    self.raw_to_typed(
                        val, inner_type, allow_imports, f"{_path} -> {idx}", _stage + (idx,)
                    )
                    for idx, (val, inner_type) in enumerate(zip(raw_data, inner_types))
                )
elif typing.get_origin(inst_type) is dict:
# Dict[str, value_type]
key_type, value_type = typing.get_args(inst_type)
assert key_type is str
if not isinstance(raw_data, dict):
raise JsonValueError(
f"Expected {type_name} at {_path}, got {raw_data!r}",
inst_type,
raw_data,
_path,
_stage,
)
return {
key: self.raw_to_typed(
val, value_type, allow_imports, f"{_path} -> {key!r}", _stage + (idx,)
)
for idx, (key, val) in enumerate(raw_data.items())
}
elif inst_type in (dict, list):
# dict, list (no subtyping)
if not isinstance(raw_data, inst_type):
raise JsonValueError(
f"Expected {type_name} at {_path}, got {raw_data!r}",
inst_type,
raw_data,
_path,
_stage,
)
return raw_data
elif inst_type is EPath:
if isinstance(raw_data, str):
return EPath(raw_data)
elif not isinstance(raw_data, EPath):
raise JsonValueError(
f"Expected {type_name} at {_path}, got {raw_data!r}",
inst_type,
raw_data,
_path,
_stage,
)
return raw_data
elif (
allow_imports
and isinstance(raw_data, dict)
and "__module__" in raw_data
and ("__class__" in raw_data or "__function__" in raw_data)
):
return self.raw_to_instance(raw_data, inst_type, _path=_path, _stage=_stage)
else:
return raw_data
def safe_call_function(
self,
raw_data: Union[dict, list, str, int, bool, float, None],
fn: Callable[..., TType],
allow_imports: bool = False,
) -> TType:
"""
Converts raw data (i.e. dicts, lists and primitives) to typed call arguments.
Validates that python typing matches.
Usage::
def fn(arg1: float, arg2: MyType, arg3) -> Any:
assert isinstance(arg1, float)
assert isinstance(arg2, MyType)
            safe_call_function({"arg1": 3.141, "arg2": MyType(), "arg3": None}, fn)
Args:
raw_data: The raw (e.g. json) data to be made as `inst_type`
fn: The function to call with the converted data
allow_imports: If true, allow instantiating objects by specifying __module__ and __class__/__function__.
Returns:
The return value of `fn`
"""
parameters = list(inspect.signature(fn).parameters.items())
if inspect.isclass(fn):
init_sig = getattr(fn, "__init__", None)
if init_sig is not None:
parameters = list(inspect.signature(init_sig).parameters.items())[1:]
args = []
kwargs = {}
if isinstance(raw_data, dict):
unused_args = raw_data.copy()
for idx, (key, param) in enumerate(parameters):
t = Any if param.annotation is inspect.Parameter.empty else param.annotation
if param.kind in (
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY,
):
if param.default is inspect.Parameter.empty and key not in unused_args:
raise ValueError(f"Missing required argument {key!r} for {fn}")
kwargs[key] = self.raw_to_typed(
unused_args.pop(key, param.default),
t,
allow_imports,
_path=key,
_stage=(idx,),
)
elif param.kind == inspect.Parameter.VAR_KEYWORD:
for arg_key, arg_val in unused_args.items():
kwargs[arg_key] = self.raw_to_typed(
arg_val, t, allow_imports, _path=key, _stage=(idx,)
)
unused_args.clear()
elif param.kind == inspect.Parameter.VAR_POSITIONAL:
# No way to pass positional arguments
pass
elif param.kind == inspect.Parameter.POSITIONAL_ONLY:
# No way to pass positional arguments
raise RuntimeError(f"Unsupported positional only argument {key!r}")
else:
raise RuntimeError(f"Unknown parameter kind {param.kind!r}")
if self.strict and len(unused_args) > 0:
raise ValueError(f"Unexpected arguments: {unused_args!r}")
elif isinstance(raw_data, list):
unused_args = raw_data.copy()
for idx, (key, param) in enumerate(parameters):
t = Any if param.annotation is inspect.Parameter.empty else param.annotation
if param.kind == inspect.Parameter.POSITIONAL_ONLY:
if param.default is inspect.Parameter.empty and len(unused_args) == 0:
raise ValueError(
f"Missing required positional-only argument {key!r} at index {idx}"
)
args.append(
self.raw_to_typed(
                            unused_args.pop(0), t, allow_imports, _path=key, _stage=(idx,)
)
)
elif param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
if param.default is inspect.Parameter.empty and len(unused_args) == 0:
raise ValueError(
f"Missing required positional argument {key!r} at index {idx}"
)
if len(unused_args) == 0:
arg_val = param.default
else:
                        arg_val = unused_args.pop(0)
args.append(
self.raw_to_typed(arg_val, t, allow_imports, _path=key, _stage=(idx,))
)
elif param.kind == inspect.Parameter.VAR_POSITIONAL:
for arg_val in unused_args:
args.append(
self.raw_to_typed(arg_val, t, allow_imports, _path=key, _stage=(idx,))
)
unused_args.clear()
elif param.kind == inspect.Parameter.VAR_KEYWORD:
# No way to pass keyword arguments
pass
elif param.kind == inspect.Parameter.KEYWORD_ONLY:
raise RuntimeError(f"Unsupported keyword-only argument {key!r}")
else:
raise RuntimeError(f"Unknown parameter kind {param.kind!r}")
if self.strict and len(unused_args) > 0:
raise ValueError(f"Unexpected arguments: {unused_args!r}")
else:
raise ValueError(
f"Cannot call function with raw data of type {type(raw_data)!r}, require list or dict"
)
return fn(*args, **kwargs)
def override( # noqa: C901
self,
value: TType,
overrides: Any,
inst_type: Optional[Type[TType]] = None,
allow_imports: bool = False,
_path: str = "root",
_stage: Tuple[int, ...] = (),
) -> TType:
"""
Allows overriding values of a typed object using environment config.
Allows overriding single config variables, or whole objects.
Examples::
class MyNamedTuple(NamedTuple):
x: int
y: str
class MyNested(NamedTuple):
nested: MyNamedTuple
assert override(
MyNested(nested=MyNamedTuple(x=42, y="foo")),
{'nested.x': 5},
) == MyNested(nested=MyNamedTuple(x=5, y="foo"))
assert override(
MyNested(nested=MyNamedTuple(x=42, y="foo")),
{'nested': '{"x": 5, "y": "bar"}'},
) == MyNested(nested=MyNamedTuple(x=5, y="bar"))
Args:
value: The base value to override.
overrides: The overrides to apply
inst_type: If given, validate against this base type instead of the type of `value`.
allow_imports: If true, allow instantiating types with dicts of __module__ and __class__/__function__.
_path: Internal: The path to the current value.
_stage: Internal: The current stage of the override.
Returns:
Same type as the input object (or `inst_type` if set), copied and updated from the
overrides.
"""
if inst_type is None:
inst_type = type(value)
type_name = getattr(inst_type, "__name__", repr(inst_type))
if inst_type in (str, int, float, bool, None, type(None)):
# Literal types
if inst_type in (None, type(None)) and overrides == "None":
overrides = None
elif inst_type is bool and overrides in ("True", "true", "1", "False", "false", "0"):
overrides = overrides in ("True", "true", "1")
elif inst_type in (int, float) and isinstance(overrides, str):
overrides = inst_type(overrides)
if not isinstance(overrides, inst_type) and not (
isinstance(overrides, int) and inst_type is float
):
raise JsonValueError(
f"Type does not match, expected {type_name} at {_path}, got {overrides!r}",
inst_type,
overrides,
_path,
_stage,
)
return overrides
elif inst_type is Any:
# Any
if isinstance(overrides, str):
if overrides.isnumeric():
return int(overrides)
elif overrides == "True":
return True
elif overrides == "False":
                    return False
return overrides
if isinstance(value, (dict, list, tuple)):
                # Merge with dict, list, tuple
return self.override(value, overrides, type(value), allow_imports, _path, _stage)
raise JsonValueError(
f"Expected {type_name} at {_path}, got {overrides!r}",
inst_type,
overrides,
_path,
_stage,
)
elif typing.get_origin(inst_type) is Literal:
            # Literal[value[, ...]]
            values = typing.get_args(inst_type)
            if overrides not in values:
raise JsonValueError(
f"Expected {type_name} at {_path}, got {overrides!r}",
inst_type,
overrides,
_path,
_stage,
)
            return overrides
elif typing.get_origin(inst_type) is Union:
# Union[union_types[0], union_types[1], ...]
union_types = typing.get_args(inst_type)
if isinstance(overrides, str):
for subtype in union_types:
if subtype is None and overrides == "None":
return None
elif subtype is bool:
if overrides == "True":
return True
elif overrides == "False":
return False
elif subtype is int and overrides.strip().isnumeric():
return int(overrides)
elif subtype is str:
return overrides
elif subtype is float and float_pattern.fullmatch(overrides):
return float(overrides)
if overrides.lstrip().startswith("{") or overrides.lstrip().startswith("["):
overrides = json.loads(overrides)
return self.raw_to_typed(
overrides,
inst_type,
allow_imports,
_path,
_stage,
)
for subtype in union_types:
if _isinstance_deep(value, subtype):
return self.override(
value,
overrides,
subtype,
allow_imports,
f"{_path} -> {getattr(subtype, '__name__', repr(subtype))}",
_stage + (1,),
)
raise JsonValueError(
f"Expected {type_name} at {_path}, existing is {value!r} which is invalid",
inst_type,
value,
_path,
_stage,
)
elif (
isinstance(inst_type, type)
and issubclass(inst_type, tuple)
and hasattr(inst_type, "__annotations__")
):
# class MyClass(NamedTuple): ...
if not isinstance(overrides, (dict, str)):
raise JsonValueError(
f"Expected {type_name} at {_path}, got {overrides!r}",
inst_type,
overrides,
_path,
_stage,
)
if isinstance(overrides, str):
return self.raw_to_typed(
json.loads(overrides),
inst_type,
allow_imports,
_path,
_stage,
)
local_overrides = _split_dict_keys(overrides)
if getattr(inst_type, "__dash_keys__", "False"):
local_overrides = {
key.replace("-", "_"): val for key, val in local_overrides.items()
}
kwargs = {
field_name: (
self.override(
getattr(value, field_name),
local_overrides.pop(field_name),
field_type,
allow_imports,
f"{_path} -> {type_name}:{field_name}",
_stage + (idx,),
)
if field_name in local_overrides
else getattr(value, field_name)
)
for idx, (field_name, field_type) in enumerate(inst_type.__annotations__.items())
}
if self.strict and len(local_overrides) != 0:
raise JsonValueError(
f"Invalid config keys {', '.join(local_overrides.keys())} for {type_name} at "
f"{_path}",
inst_type,
overrides,
_path,
_stage,
)
try:
return inst_type(**kwargs)
except BaseException:
raise JsonValueError(
f"Expected {type_name} at {_path}, got {overrides!r}",
inst_type,
overrides,
_path,
_stage,
)
elif dataclasses.is_dataclass(inst_type):
# dataclass
if not isinstance(overrides, (dict, str)):
raise JsonValueError(
f"Expected {type_name} at {_path}, got {overrides!r}",
inst_type,
overrides,
_path,
_stage,
)
if isinstance(overrides, str):
return self.raw_to_typed(
json.loads(overrides),
inst_type,
allow_imports,
_path,
_stage,
)
local_overrides = _split_dict_keys(overrides)
if getattr(inst_type, "__dash_keys__", "False"):
local_overrides = {
key.replace("-", "_"): val for key, val in local_overrides.items()
}
kwargs = {
field.name: (
self.override(
getattr(value, field.name),
local_overrides.pop(field.name),
field.type,
allow_imports,
f"{_path} -> {type_name}:{field.name}",
_stage + (idx,),
)
if field.name in local_overrides
else getattr(value, field.name)
)
for idx, field in enumerate(dataclasses.fields(inst_type))
if field.init
}
if self.strict and len(local_overrides) != 0:
raise JsonValueError(
f"Invalid config keys {', '.join(local_overrides.keys())} for {type_name} at "
f"{_path}",
inst_type,
overrides,
_path,
_stage,
)
try:
return inst_type(**kwargs)
except BaseException:
raise JsonValueError(
f"Expected {type_name} at {_path}, got {overrides!r}",
inst_type,
overrides,
_path,
_stage,
)
elif (
typing.get_origin(inst_type) is list
or typing.get_origin(inst_type) is tuple
or inst_type in (list, tuple)
):
# List[inner_type] or Tuple[inner_type, Ellipsis] or
# Tuple[inner_type[0], inner_type[1], ...]
if inst_type is list:
inner_type = Any
inner_types = []
cls = list
elif inst_type is tuple:
inner_type = Any
inner_types = []
cls = tuple
elif typing.get_origin(inst_type) is list:
(inner_type,) = typing.get_args(inst_type)
inner_types = []
cls = list
else:
inner_types = typing.get_args(inst_type)
if len(inner_types) == 2 and inner_types[1] is Ellipsis:
inner_type = inner_types[0]
else:
inner_type = None
cls = tuple
if not isinstance(overrides, (dict, str)):
raise JsonValueError(
f"Expected {type_name} at {_path}, got {overrides!r}",
inst_type,
overrides,
_path,
_stage,
)
if isinstance(overrides, str):
return self.raw_to_typed(
json.loads(overrides),
inst_type,
allow_imports,
_path,
_stage,
)
local_overrides = _split_dict_keys(overrides)
if not all(key.isnumeric() for key in local_overrides.keys()):
raise JsonValueError(
f"Expected {type_name} at {_path}, got {overrides!r}, expected integer keys",
inst_type,
overrides,
_path,
_stage,
)
local_overrides_int = {int(key): value for key, value in local_overrides.items()}
new_max_idx = max(local_overrides_int.keys())
original_max_idx = len(value)
if inner_type is None and new_max_idx >= len(inner_types):
raise JsonValueError(
f"Expected {type_name} at {_path}, got {overrides!r}, index {new_max_idx} out of "
f"bounds",
inst_type,
overrides,
_path,
_stage,
)
for i in range(original_max_idx, new_max_idx):
if i not in local_overrides_int:
raise JsonValueError(
f"Expected {type_name} at {_path}, got {overrides!r}, missing value for index "
f"{i}",
inst_type,
overrides,
_path,
_stage,
)
return cls(
(
self.override(
value[idx],
local_overrides_int[idx],
inner_type,
allow_imports,
f"{_path} -> {idx}",
_stage + (idx,),
)
if idx in local_overrides_int
else value[idx]
)
for idx in range(max(new_max_idx + 1, original_max_idx))
)
elif typing.get_origin(inst_type) is dict or inst_type is dict:
# Dict[str, value_type]
if inst_type is dict:
value_type = Any
else:
key_type, value_type = typing.get_args(inst_type)
assert key_type is str
if not isinstance(overrides, (dict, str)):
raise JsonValueError(
f"Expected {type_name} at {_path}, got {overrides!r}",
inst_type,
overrides,
_path,
_stage,
)
if isinstance(overrides, str):
return self.raw_to_typed(
json.loads(overrides),
inst_type,
allow_imports,
_path,
_stage,
)
local_overrides = _split_dict_keys(overrides)
if getattr(inst_type, "__dash_keys__", "False"):
local_overrides = {
key.replace("-", "_"): val for key, val in local_overrides.items()
}
res = {
key: (
self.override(
subvalue,
local_overrides.pop(key),
value_type,
allow_imports,
f"{_path} -> {type_name}:{key!r}",
_stage + (idx,),
)
if key in local_overrides
else subvalue
)
                for idx, (key, subvalue) in enumerate(value.items())
}
for key, val in local_overrides.items():
if not isinstance(val, str):
raise JsonValueError(
f"Expected new {type_name} at {_path} -> {type_name}:{key!r}, got {val!r}",
inst_type,
overrides,
_path,
_stage,
)
res[key] = self.raw_to_typed(
json.loads(val),
value_type,
allow_imports,
f"{_path} -> {type_name}:{key!r}",
_stage + (len(res),),
)
return res
else:
raise RuntimeError(f"Unknown type {inst_type}")
def to_json_object(obj: Any) -> Any:
"""
Converts the given object to a json object.
Args:
obj: The object to convert
Returns:
The json-like object.
"""
if isinstance(obj, (str, int, float, bool, type(None))):
# Literal types
return obj
elif isinstance(obj, tuple) and hasattr(obj, "__annotations__"):
# class MyClass(NamedTuple): ...
return {
field_name: to_json_object(getattr(obj, field_name))
for field_name in obj.__annotations__.keys()
}
elif dataclasses.is_dataclass(obj):
# dataclass
return {
field.name: to_json_object(getattr(obj, field.name))
for field in dataclasses.fields(obj)
if field.init
}
elif isinstance(obj, (list, tuple)):
return [to_json_object(val) for val in obj]
elif isinstance(obj, dict):
return {key: to_json_object(val) for key, val in obj.items()}
else:
raise RuntimeError(f"Unknown type {type(obj)}")
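# e.g. for the NamedTuple example above: to_json_object(MyNamedTuple(x=5, y="foo"))
# returns {"x": 5, "y": "foo"}.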
float_pattern = re.compile(r"[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?")
def _split_dict_keys(dct: Dict[str, Any]) -> Dict[str, Any]:
"""Splits the given dict keys by first '.' to subdicts."""
res = {}
for key, value in dct.items():
if "." in key:
outer_key, _, inner_key = key.partition(".")
if outer_key in res:
if not isinstance(res[outer_key], dict):
raise ValueError(f"Cannot combine {outer_key!r} with {res!r}")
res[outer_key][inner_key] = value
else:
res[outer_key] = {inner_key: value}
else:
if key in res:
raise ValueError(f"Cannot combine {key!r} with {res!r}")
res[key] = value
return res
def _isinstance_deep(val: Any, tp_chk: Type) -> bool:
"""Verifies if the given value is an instance of the tp_chk, allowing for typing extensions."""
if tp_chk is Any:
return True
elif typing.get_origin(tp_chk) is Literal:
        (value,) = typing.get_args(tp_chk)
return val == value
elif typing.get_origin(tp_chk) is list:
        (inner_type,) = typing.get_args(tp_chk)
return isinstance(val, list) and all(_isinstance_deep(v, inner_type) for v in val)
elif typing.get_origin(tp_chk) is tuple:
        inner_types = typing.get_args(tp_chk)
if len(inner_types) == 2 and inner_types[1] == Ellipsis:
return isinstance(val, tuple) and all(_isinstance_deep(v, inner_types[0]) for v in val)
else:
return (
isinstance(val, tuple)
and len(val) == len(inner_types)
and all(_isinstance_deep(v, inner_type) for v, inner_type in zip(val, inner_types))
)
elif typing.get_origin(tp_chk) is dict:
        key_type, value_type = typing.get_args(tp_chk)
return isinstance(val, dict) and all(
_isinstance_deep(k, key_type) and _isinstance_deep(v, value_type)
for k, v in val.items()
)
else:
return isinstance(val, tp_chk)
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import inspect
import linecache
import os
import sys
import threading
import time
import traceback
from time import perf_counter
from typing import Any, Callable, Iterable, Iterator, Optional, TypeVar
import torch
from torch.distributed._shard.sharded_tensor import ShardedTensorBase
# For the watch_iter type
T = TypeVar("T")
# Maximum length of a single object string to print.
PRINT_LOCAL_MAX_LENGTH = 250
class Watchdog:
"""
A watchdog timer that:
- can be 'enabled' or 'disabled' by presence/absence of a deadline,
- resets automatically when 'enable()' is called,
- can be used as a context manager,
- can wrap an iterator to watch only the time for 'next()' calls,
- attempts a two-phase shutdown on callback error:
1) sys.exit(1) for graceful,
2) if still alive after 10s, os._exit(1).
"""
def __init__(
self,
timeout: float,
initial_timeout: Optional[float] = None,
callback: Optional[Callable[[], None]] = None,
dump_stacks: bool = True,
enabled: bool = True,
) -> None:
"""
Args:
timeout: Number of seconds before the watchdog fires if not reset/disabled.
initial_timeout: Number of seconds before the watchdog fires in the first iteration.
callback: Optional function to call upon timeout.
dump_stacks: If True, print full stack traces for all threads on timeout (except watchdog's own thread).
enabled: If False, watchdog starts disabled until enable() is called.
"""
self._timeout = timeout
self._initial_timeout = initial_timeout
self._callback = callback
self._dump_stacks = dump_stacks
self._is_first_iteration = True
# If _deadline is None, the watchdog is disabled.
        # Otherwise, _deadline = perf_counter() + _timeout if enabled.
if enabled:
self._deadline: Optional[float] = perf_counter() + self._get_next_timeout()
else:
self._deadline = None
self._stop = False # signals permanent shutdown (finish)
# Condition variable to manage state changes
self._cv = threading.Condition()
# Background thread (daemon) that monitors timeouts
self._worker_thread = threading.Thread(target=self._worker, daemon=True)
self._worker_thread.start()
def _get_next_timeout(self) -> float:
if self._is_first_iteration:
self._is_first_iteration = False
return self._initial_timeout if self._initial_timeout is not None else self._timeout
else:
return self._timeout
def _worker(self) -> None:
"""
Background thread that periodically checks if the watchdog has expired.
Once it times out or is told to stop, it exits.
"""
while True:
with self._cv:
if self._stop:
# finish() was called; end the worker.
return
if self._deadline is None:
# Disabled; no deadline. Just wait a bit, then re-check.
self._cv.wait(timeout=1.0)
continue
remaining = self._deadline - perf_counter()
if remaining <= 0:
# We have timed out
self._on_timeout()
return
else:
# Wait until either the deadline or a state change
self._cv.wait(timeout=remaining)
def _on_timeout(self) -> None:
"""
Called exactly once if the watchdog times out.
1) Optionally dumps stacks,
2) Calls user callback,
3) If callback raises an error,
- print traceback,
- sys.exit(1),
- fallback to os._exit(1) after 10s if process not terminated.
"""
watchdog_thread_id = threading.get_ident()
# 1) Dump stacks if requested
if self._dump_stacks:
print("Watchdog triggered: Dumping thread stacks")
self._print_all_thread_stacks(skip_thread_id=watchdog_thread_id)
# 2) Call user callback
if self._callback:
try:
self._callback()
except Exception:
# Print the traceback
traceback.print_exc()
# Start a background kill-switch after 10 seconds
def force_exit_after_delay() -> None:
time.sleep(10)
os._exit(1)
killer = threading.Thread(target=force_exit_after_delay, daemon=True)
killer.start()
# Attempt graceful shutdown
sys.exit(1)
def _print_all_thread_stacks(self, skip_thread_id: Optional[int] = None) -> None:
"""
Dump stacks of all threads in a style reminiscent of py-spy, from
innermost (current) to outermost. Skip the watchdog's own thread if given.
Args:
skip_thread_id: If given, skip this thread's stack.
"""
frames = sys._current_frames() # thread_id -> frame
# We gather known threads to print their names
all_threads = {t.ident: t for t in threading.enumerate()}
for thread_id, frame in frames.items():
if skip_thread_id is not None and thread_id == skip_thread_id:
continue
thread = all_threads.get(thread_id)
thread_name = thread.name if thread else f"Unknown-{thread_id}"
print(f'Thread {thread_id}: "{thread_name}"')
# Build the stack from current (innermost) to outermost
stack_frames = []
f = frame
while f is not None:
stack_frames.append(f)
f = f.f_back
for fr in stack_frames:
code = fr.f_code
func_name = code.co_name
filename = code.co_filename
lineno = fr.f_lineno
print(f" {func_name} ({filename}:{lineno})")
# Attempt to read the actual line of source
line = linecache.getline(filename, lineno).rstrip()
if line:
print(f" > {line}")
# Show arguments and locals
arg_info = inspect.getargvalues(fr)
arg_names = arg_info.args
varargs = arg_info.varargs
varkw = arg_info.keywords
local_vars = arg_info.locals
# Separate out the arguments
arg_dict = {}
for arg in arg_names:
if arg in local_vars:
arg_dict[arg] = local_vars[arg]
if varargs and varargs in local_vars:
arg_dict["*" + varargs] = local_vars[varargs]
if varkw and varkw in local_vars:
arg_dict["**" + varkw] = local_vars[varkw]
if arg_dict:
print(" Arguments:")
for k, v in arg_dict.items():
print(f" {k}: {repr_short(v)}")
other_locals = {k: v for k, v in local_vars.items() if k not in arg_dict}
if other_locals:
print(" Locals:")
for k, v in other_locals.items():
print(f" {k}: {repr_short(v)}")
print(flush=True)
def reset(self) -> None:
"""
Reset the watchdog timer (push out deadline by `timeout` seconds),
but only if currently enabled (i.e., _deadline is not None).
"""
with self._cv:
if self._deadline is not None:
self._deadline = perf_counter() + self._timeout
self._cv.notify()
def enable(self) -> None:
"""
Enable (or re-enable) the watchdog. Always resets the deadline to
        `perf_counter() + timeout` (or `initial_timeout` on the first iteration).
"""
with self._cv:
self._deadline = perf_counter() + self._get_next_timeout()
self._cv.notify()
def disable(self) -> None:
"""
Disable the watchdog (no timeout will fire until re-enabled).
"""
with self._cv:
self._deadline = None
self._cv.notify()
def finish(self) -> None:
"""
Permanently stop the watchdog thread and disarm the timer.
After calling finish(), you cannot re-enable this watchdog.
"""
with self._cv:
self._stop = True
self._cv.notify()
self._worker_thread.join()
def __enter__(self) -> "Watchdog":
# If currently disabled, calling enable() will also reset the timer
if self._deadline is None:
self.enable()
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
# End the watchdog on context exit
self.finish()
def watch_iter(self, iterable: Iterable[T]) -> Iterator[T]:
"""
Wrap an iterable so that each 'next()' call is watched by the watchdog,
but the time in between iterations is not watched. Usage:
wd = Watchdog(timeout=3, enabled=False)
for item in wd.watch_iter(generator()):
# processing item not timed by the watchdog
pass
This pattern:
- enable() -> sets/extends deadline
- next(...) -> measured portion
- disable() -> stops timer
Args:
iterable: The iterable to wrap and watch.
Returns:
An iterator that wraps the input iterable and watches for timeouts.
"""
try:
self.enable()
for item in iterable:
self.disable()
yield item
self.enable()
finally:
self.disable()
def repr_short(obj: Any) -> str:
"""
Return a short repr of an object.
"""
if isinstance(obj, torch.Tensor):
if isinstance(obj, ShardedTensorBase) or obj.is_cuda:
return "<CUDA tensor>"
s = repr(obj)
if len(s) > PRINT_LOCAL_MAX_LENGTH:
s = s[: PRINT_LOCAL_MAX_LENGTH // 2] + "..." + s[-PRINT_LOCAL_MAX_LENGTH // 2 :]
return s
if __name__ == "__main__":
# Example usage
def my_callback() -> None:
print("Watchdog timed out in callback.")
# Demonstrate an error
raise ValueError("Example error from callback.")
print("Simple usage example:")
wd = Watchdog(timeout=2, callback=my_callback, enabled=True)
print("Sleeping 3s so the watchdog times out.")
time.sleep(30)
# Because we never reset or finish, the watchdog should fire and
# forcibly exit, after printing the traceback and stack dumps.
print("You won't see this line if the watchdog fired first.")
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import hashlib
import json
import multiprocessing
from dataclasses import dataclass
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Optional, TextIO, TypeVar
import torch.distributed
import torch.utils.data
from megatron.energon.cache import CachePool
__all__ = ("WorkerConfig",)
T = TypeVar("T")
@dataclass(slots=True, kw_only=True, eq=False)
class WorkerConfig:
"""
Provides information about the current worker and the global configuration. This gives each
data parallel rank its proper config. Every `rank` (up to `world_size-1`) must be used.
If set wrong, the datasets might yield the same data or data might be missing, as data
is split over the data parallel ranks with this config!
You may set the same rank, if you need multiple ranks to retrieve the same data.
"""
#: The data parallel rank/id of the current process.
rank: int
#: The total number of data parallel processes.
world_size: int
#: The number of workers per rank. May be 0 to disable worker processes.
num_workers: int
#: If not using all ranks for data parallel, set this to the corresponding group.
data_parallel_group: Optional[torch.distributed.ProcessGroup] = None
#: The id offset of the current worker. e.g. the worker may live as `worker_info.id=0`, but
# actually yield samples for id=1 (i.e. worker_id_offset=1). Required to support restoring the
# worker state if last emitted sample was not for worker_id=0. Required by SavableDataLoader to
# restore the worker state. Is only set to nonzero within a worker process.
worker_id_offset: ClassVar[int] = 0
    #: The following seed_offset is used at two points in the code.
# 1. The seed_offset in the worker_config that is passed to the dataset initialization, is used
# to set the seed for the dataset shuffling and shuffled blending (All code that uses WorkerRng).
# 2. The worker_config passed to the data loader initialization, is used to set the seed for the
# torch, numpy and random libraries. This does not affect the dataset shuffling, but only the
# user code (e.g. code in TaskEncoder).
seed_offset: int = 0
#: The path to the debug file for the current worker. Should contain "{worker_id}" and "{pid}"
# to separate the workers.
worker_debug_path: Optional[str] = None
#: Log level for worker logging.
worker_log_level: int = 0
#: The opened file for the current worker. Should not be set from outside.
_worker_debug_file: Optional[TextIO] = None
#: worker_id of the opened worker debug file
_worker_debug_file_worker_id: Optional[int] = None
#: The current sample index within the current iterating worker
_sample_index_stack: ClassVar[Optional[List[int]]] = None
#: The current worker config within the current iterating worker
active_worker_config: ClassVar[Optional["WorkerConfig"]] = None
#: The global rank override for the worker. Required for restoring samples.
    _worker_override_global_rank: ClassVar[Optional[int]] = None
#: The current cache pool for the worker.
_cache_pool: "ClassVar[Optional[CachePool]]" = None
def worker_activate(
self,
sample_index: int,
override_global_rank: Optional[int] = None,
cache_pool: "Optional[CachePool]" = None,
):
"""Activates the worker config for the current worker and sets it as actively iterating.
Must be called before next() call on the datasets."""
assert WorkerConfig.active_worker_config is None
WorkerConfig._sample_index_stack = [sample_index]
WorkerConfig.active_worker_config = self
WorkerConfig._worker_override_global_rank = override_global_rank
WorkerConfig._cache_pool = cache_pool
def worker_push_sample_index(self, sample_index: int):
"""Pushes a new sample index to the sample index stack. Should be set by wrapping datasets
before calling inners."""
assert WorkerConfig.active_worker_config is not None
WorkerConfig._sample_index_stack.append(sample_index)
def worker_pop_sample_index(self):
"""Pushes a new sample index to the sample index stack. Should be set by wrapping datasets
before calling inners."""
assert WorkerConfig.active_worker_config is not None
return WorkerConfig._sample_index_stack.pop()
def worker_deactivate(self):
"""Deactivates the worker config for the current worker and deactivates it for iterating.
Must be called after next() call on the datasets."""
if WorkerConfig.active_worker_config is not None:
assert len(WorkerConfig._sample_index_stack) == 1, (
f"Sample index stack not empty: {WorkerConfig._sample_index_stack}"
)
WorkerConfig._sample_index_stack = None
WorkerConfig.active_worker_config = None
WorkerConfig._worker_override_global_rank = None
@property
def active_worker_sample_index(self) -> int:
"""Returns the current sample index for the actively iterating worker."""
# Internal sample index is for the local worker. If using multiple workers per rank, this
# must be multiplied by the number of workers and offset by the local worker index.
return (
WorkerConfig._sample_index_stack[-1] * max(self.num_workers, 1) + self.rank_worker_id()
)
@property
def active_worker_batch_index(self) -> int:
"""Returns the current batch index for the actively iterating worker."""
# Internal batch index is for the local worker. If using multiple workers per rank, this
# must be multiplied by the number of workers and offset by the local worker index.
return (
WorkerConfig._sample_index_stack[0] * max(self.num_workers, 1) + self.rank_worker_id()
)
def global_rank(self) -> int:
"""Returns the global rank of this worker config but as a global rank, not
as a rank within the data parallel group."""
if self.data_parallel_group is None:
return self.rank
return torch.distributed.get_global_rank(self.data_parallel_group, self.rank)
def __eq__(self, other):
"""Do not compare everything to check for equal config"""
if not isinstance(other, WorkerConfig):
return NotImplementedError()
return all(
[
self.rank == other.rank,
self.world_size == other.world_size,
self.num_workers == other.num_workers,
]
)
@staticmethod
def default_worker_config(
num_workers: int = 4, data_parallel_group: Optional[torch.distributed.ProcessGroup] = None
) -> "WorkerConfig":
"""Returns the default worker config using torch distributed if available.
If torch distributed is not available, a single local rank is assumed."""
if torch.distributed.is_available() and torch.distributed.is_initialized():
rank = torch.distributed.get_rank(data_parallel_group)
world_size = torch.distributed.get_world_size(data_parallel_group)
else:
rank = 0
world_size = 1
return WorkerConfig(
rank=rank,
world_size=world_size,
num_workers=num_workers,
data_parallel_group=data_parallel_group,
)
def rank_worker_id(self) -> int:
"""Returns the self worker id within the current rank."""
if self._worker_override_global_rank:
assert self.worker_id_offset == 0
return self._worker_override_global_rank % self.num_workers
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
return self.worker_id_offset
assert worker_info.num_workers == self.num_workers
# Apply the worker_id_offset as a left rotation of the logical worker ids.
# This ensures that after restoring a checkpoint the first physical
# worker (id=0) corresponds to the logical worker that should emit the
# next sample. For example, if `worker_id_offset` is 1, logical worker
# 1 becomes the first to emit a sample, shifting the ordering forward.
return (worker_info.id + self.worker_id_offset) % worker_info.num_workers
def assert_worker(self):
"""Checks if the current process is a worker (if configured so), and that the workers are
properly configured."""
if self.num_workers <= 1:
assert self.rank_worker_id() == 0
else:
worker_info = torch.utils.data.get_worker_info()
assert worker_info is not None, "Cannot iterate out of worker context"
assert worker_info.num_workers == self.num_workers, (
f"Actual number of workers for this rank ({worker_info.num_workers}) does not "
f"match the configured number of workers ({self.num_workers})"
)
def global_worker_id(self, override_local_worker_id: Optional[int] = None) -> int:
"""Returns the global worker index by multiplying the rank with the number of workers.
Alternatively, you can override the local worker id.
Args:
override_local_worker_id (int, optional): The local worker id to override. None means
the current worker, which is the default.
"""
if self._worker_override_global_rank is not None:
assert override_local_worker_id is None
return self._worker_override_global_rank
if override_local_worker_id is not None:
return self.rank * self.num_workers + override_local_worker_id
else:
self.assert_worker()
return self.rank * self.num_workers + self.rank_worker_id()
def worker_seed(self, override_local_worker_id: Optional[int] = None) -> int:
"""Returns the seed for the current worker (or a specified worker).
        Based on the current worker id and the seed offset, compute a seed.
Alternatively, you can override the local worker id with a fixed one to
pregenerate seeds for multiple workers.
Args:
override_local_worker_id (int, optional): The local worker id to override. None means
the current worker, which is the default.
"""
if self.num_workers == 0:
# If we are not using workers, different ranks should still get a different seed
global_worker_id = self.rank
else:
global_worker_id = self.global_worker_id(override_local_worker_id)
seed_offset = self.seed_offset
seed_hash = hashlib.sha1(f"{global_worker_id},{seed_offset}".encode("utf-8")).digest()
return int.from_bytes(seed_hash, byteorder="big", signed=False) & 0xFFFFFFFF
def config(self) -> Dict[str, Any]:
return {
"rank": self.rank,
"world_size": self.world_size,
"num_workers": self.num_workers,
"data_parallel_group": (
self.data_parallel_group.size() if self.data_parallel_group else None
),
}
def should_log(self, level: int) -> bool:
return level <= self.worker_log_level
def worker_log(self, data: dict) -> None:
"""Logs the given data to the worker debug file."""
if self.worker_debug_path is None:
print(json.dumps(data) + "\n", end="", flush=True)
else:
in_worker = torch.utils.data.get_worker_info() is not None
# Additional "worker" with rank_worker_id=0 is the main process. All workers have +1
# as their worker_id.
worker_id = (
self.rank * (self.num_workers + 1) + self.rank_worker_id() + (1 if in_worker else 0)
)
if self._worker_debug_file is None or self._worker_debug_file_worker_id != worker_id:
if self._worker_debug_file is not None:
self._worker_debug_file.close()
path = Path(
self.worker_debug_path.format(
worker_id=worker_id, pid=multiprocessing.current_process().ident
)
)
path.parent.mkdir(exist_ok=True, parents=True)
self._worker_debug_file = path.open("w")
self._worker_debug_file_worker_id = worker_id
self._worker_debug_file.write(json.dumps(data) + "\n")
self._worker_debug_file.flush()
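if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the library API): build the
    # default config and derive per-worker seeds. Without an initialized
    # torch.distributed process group, this falls back to a single rank.
    wc = WorkerConfig.default_worker_config(num_workers=2)
    print(wc.config())
    # Seeds can be precomputed for each local worker by overriding the worker id.
    print([wc.worker_seed(override_local_worker_id=i) for i in range(wc.num_workers)])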
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from megatron.energon.wrappers.batch_dataset import BatchDataset
from megatron.energon.wrappers.blend_dataset import BlendDataset
from megatron.energon.wrappers.concat_dataset import ConcatDataset
from megatron.energon.wrappers.epochize_dataset import EpochizeDataset
from megatron.energon.wrappers.filter_dataset import FilterDataset
from megatron.energon.wrappers.gc_dataset import GcDataset
from megatron.energon.wrappers.group_batch_dataset import GroupBatchDataset
from megatron.energon.wrappers.iter_map_dataset import IterMapDataset
from megatron.energon.wrappers.limit_dataset import LimitDataset
from megatron.energon.wrappers.log_sample_dataset import LogSampleDataset
from megatron.energon.wrappers.map_dataset import MapDataset
from megatron.energon.wrappers.mix_batch_dataset import (
MixBatchDataset,
concat_pad,
generic_concat,
homogeneous_concat_mix,
)
from megatron.energon.wrappers.packing_dataset import PackingDataset
from megatron.energon.wrappers.repeat_dataset import RepeatDataset
from megatron.energon.wrappers.shuffle_buffer_dataset import ShuffleBufferDataset
from megatron.energon.wrappers.skip import SkipSample
__all__ = [
"BatchDataset",
"BlendDataset",
"ConcatDataset",
"EpochizeDataset",
"FilterDataset",
"GcDataset",
"GroupBatchDataset",
"IterMapDataset",
"LimitDataset",
"LogSampleDataset",
"MapDataset",
"MixBatchDataset",
"RepeatDataset",
"ShuffleBufferDataset",
"SkipSample",
"PackingDataset",
"concat_pad",
"generic_concat",
"homogeneous_concat_mix",
]
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import traceback
from typing import Optional, Sequence
from megatron.energon.source_info import SourceInfo
def log_exception(_e: Exception, sample, sources: Optional[Sequence[SourceInfo]] = None):
traceback.print_exc()
print("-" * 10)
if sources:
print("Sources:")
for source in sources:
print(
f" - {source.dataset_path}[{source.index}] {source.shard_name}{source.file_names!r}"
)
print("-" * 10)
sample_str = str(sample)
if len(sample_str) > 400:
sample_str = sample_str[:200] + "..." + sample_str[-200:]
print(sample_str)
print("-" * 10)
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Generator, Generic, Iterable, Optional, Tuple, Type, TypeVar, Union
from megatron.energon.flavors.base_dataset import (
FlexState,
Sample,
SavableDataset,
add_sample_restore_key,
)
from megatron.energon.savable import Savable
from megatron.energon.worker import WorkerConfig
T = TypeVar("T")
T_sample = TypeVar("T_sample", covariant=True)
T_sample_out = TypeVar("T_sample_out", covariant=True)
T_sample_in = TypeVar("T_sample_in", covariant=True)
class BaseWrapperDataset(SavableDataset[T_sample_out], Generic[T_sample_in, T_sample_out], ABC):
"""Base class for dataset wrappers. All dataset wrappers should derive from this. A dataset
wrapper takes one dataset and modifies its samples to make a new dataset. This can be for
shuffling samples or applying custom functions to the data. Some wrappers only modify the
length of the dataset or how it's repeated."""
datasets: Tuple[SavableDataset[T_sample_in], ...]
def __init__(
self,
datasets: Union[SavableDataset[T_sample_in], Iterable[SavableDataset[T_sample_in]]],
*,
worker_config: WorkerConfig,
):
super().__init__(worker_config=worker_config)
if isinstance(datasets, SavableDataset):
self.datasets = (datasets,)
else:
self.datasets = tuple(datasets)
for d in self.datasets:
# Check that the dataset worker configs are the same as the wrapper worker config
assert d.worker_config == self.worker_config, (
"Dataset and wrapper worker configs must match."
)
@property
def dataset(self) -> SavableDataset:
"""Convenience property, if only one dataset is wrapped."""
assert len(self.datasets) == 1
return self.datasets[0]
def can_restore_sample(self) -> bool:
return all(ds.can_restore_sample() for ds in self.datasets)
def assert_can_restore(self) -> None:
for ds in self.datasets:
ds.assert_can_restore()
def worker_has_samples(self) -> bool:
return any(ds.worker_has_samples() for ds in self.datasets)
def _find_wrapped_dataset(self, cls: Type[SavableDataset]) -> Optional[SavableDataset]:
"""Find the outermost dataset wrapped in this dataset that is of type cls."""
for ds in self.datasets:
if isinstance(ds, cls):
return ds
elif isinstance(ds, BaseWrapperDataset):
res = ds._find_wrapped_dataset(cls)
if res is not None:
return res
return None
def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample_out:
if len(self.datasets) == 1:
return self.datasets[0].restore_sample(restore_key)
else:
id, ds_idx = restore_key[:2]
assert id == type(self).__name__
restore_key = restore_key[2:]
assert isinstance(ds_idx, int)
return add_sample_restore_key(
self.datasets[ds_idx].restore_sample(restore_key),
ds_idx,
src=self,
)
def save_state(self) -> FlexState:
own_state = super().save_state()
return FlexState(datasets=[ds.save_state() for ds in self.datasets], **own_state)
def restore_state(self, state: FlexState) -> None:
assert len(self.datasets) == len(state["datasets"])
for dataset, dstate in zip(self.datasets, state["datasets"]):
dataset.restore_state(dstate)
super().restore_state(state)
def reset_state_deep(self) -> None:
"""Resets the state of the inner datasets and then the own state."""
for ds in self.datasets:
if isinstance(ds, BaseWrapperDataset):
ds.reset_state_deep()
else:
ds.reset_state_own()
self.reset_state_own()
@abstractmethod
def reset_state_own(self) -> None:
"""Resets the state of the dataset, excl. the inner datasets."""
...
class SampleIndex(Savable):
"""A simple class to hold the sample index for one worker."""
worker_config: WorkerConfig
current_idx: int
actives = 0
def __init__(self, worker_config: WorkerConfig, *, src: Any) -> None:
self.worker_config = worker_config
self.current_idx = 0
self.src = src
def get_next(self) -> int:
res = self.current_idx
self.current_idx += 1
return res
@contextmanager
def ctx(self, sample_idx: Optional[int] = None):
if sample_idx is None:
sample_idx = self.get_next()
assert WorkerConfig.active_worker_config is not None
WorkerConfig.active_worker_config.worker_push_sample_index(sample_idx)
# print(" " * SampleIndex.actives + f"Activated from {type(self.src).__name__}({id(self.src)}) {sample_idx} -> {WorkerConfig.active_worker_config._sample_index_stack}")
SampleIndex.actives += 1
try:
yield sample_idx
finally:
assert WorkerConfig.active_worker_config is not None
popped = WorkerConfig.active_worker_config.worker_pop_sample_index()
SampleIndex.actives -= 1
# print(" " * SampleIndex.actives + f"Deactivate from {type(self.src).__name__}({id(self.src)}) {sample_idx} -> {WorkerConfig.active_worker_config._sample_index_stack}")
assert popped == sample_idx, f"Expected {sample_idx}, got {popped}"
def iter_ctx(
self,
it: Iterable[T_sample],
sample_idx: Optional[int] = None,
) -> Generator[Tuple[int, T_sample], None, None]:
it = iter(it)
try:
while True:
try:
with self.ctx(sample_idx) as res_sample_idx:
x = next(it)
yield res_sample_idx, x
except StopIteration:
break
finally:
if hasattr(it, "close"):
it.close()
def save_state(self) -> int:
return self.current_idx
def restore_state(self, state: Optional[int]) -> None:
if state is None:
self.current_idx = 0
else:
self.current_idx = state
def get_sample_restore_key(sample: Any) -> Optional[Union[str, int]]:
"""Gets the restore key from an arbitrary sample."""
if isinstance(sample, Sample) or hasattr(sample, "__restore_key__"):
return sample.__restore_key__
elif isinstance(sample, dict) and "__restore_key__" in sample:
return sample["__restore_key__"]
else:
return None
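if __name__ == "__main__":
    # Minimal usage sketch of `get_sample_restore_key`: dict samples expose their
    # restore key via the "__restore_key__" entry, other objects via the attribute,
    # and anything else yields None.
    print(get_sample_restore_key({"__restore_key__": ("MyDataset", 42)}))
    print(get_sample_restore_key({"no_key": 1}))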
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import inspect
from typing import (
Any,
Callable,
Dict,
Generator,
Generic,
Iterator,
List,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)
from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError
from megatron.energon.flavors.base_dataset import SavableDataset, set_sample_restore_key
from megatron.energon.source_info import SourceInfo
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers._log_exception import log_exception
from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key
from megatron.energon.wrappers.skip import SkipSample
T_batch = TypeVar("T_batch", covariant=True)
T_batch_sample = TypeVar("T_batch_sample", covariant=True)
class BatchDataset(BaseWrapperDataset[T_batch_sample, T_batch], Generic[T_batch_sample, T_batch]):
"""This dataset wrapper transforms a dataset of samples into a dataset of batches."""
batch_size: int
batcher: Callable[[List[T_batch_sample]], T_batch]
drop_last: bool
error_handler: Callable[[Exception, list[T_batch_sample], Sequence[SourceInfo]], None]
_sample_index: SampleIndex
_generator_sample_keys: Optional[Any]
_generator_offset: Optional[int]
_last_batch_failures: int = 0
_savable_fields = ("_sample_index", "_generator_sample_keys", "_generator_offset")
def __init__(
self,
dataset: SavableDataset[T_batch_sample],
batch_size: int,
batcher: Callable[[List[T_batch_sample]], T_batch],
*,
batcher_stateless: bool = False,
batcher_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None,
drop_last: bool = False,
error_handler: Callable[
[Exception, list[T_batch_sample], Sequence[SourceInfo]], None
] = log_exception,
failure_tolerance: int = 100,
worker_config: WorkerConfig,
):
"""Construct a BatchDataset.
Args:
dataset: The input dataset to wrap
batch_size: The desired batch size. The last batch may be smaller.
batcher: Function which combines separate samples into a single object. May raise
:exc:`megatron.energon.SkipSample` to skip a sample.
batcher_stateless: If True, the batcher is stateless, thus samples can be stored/
restored.
batcher_config: Configuration for the batcher function. If callable, it should return the
configuration. Defaults to None.
drop_last: If True, the last batch is dropped if it is smaller than the batch size.
error_handler: Function which handles exceptions raised by the batcher. The default
implementation logs the exception.
failure_tolerance: The number of consecutive failures after which the dataset is considered broken. Set to 0 to disable.
worker_config: Configuration for the workers.
"""
super().__init__(dataset, worker_config=worker_config)
self.batch_size = batch_size
self.batcher = batcher
self.batcher_stateless = batcher_stateless
self.batcher_config = batcher_config
self.drop_last = drop_last
self.error_handler = error_handler
self.failure_tolerance = failure_tolerance
self.reset_state_own()
def reset_state_own(self) -> None:
self._sample_index = SampleIndex(self.worker_config, src=self)
self._generator_sample_keys = None
self._generator_offset = None
def len_worker(self, worker_idx: int | None = None) -> int:
n_samples = self.dataset.len_worker(worker_idx)
n_batches = n_samples // self.batch_size
if n_samples % self.batch_size != 0 and not self.drop_last:
n_batches += 1
return n_batches
def __iter__(self) -> Iterator[T_batch]:
batch: List[T_batch_sample] = []
sample_restore_keys = []
if self._generator_sample_keys is not None:
sample_restore_keys = self._generator_sample_keys
assert self._generator_offset is not None
batch = [self.dataset.restore_sample(inner_idx) for inner_idx in sample_restore_keys]
with self._sample_index.ctx(self._sample_index.current_idx) as sample_idx:
batch_sample = self.batcher(batch)
assert isinstance(batch_sample, Generator)
assert inspect.isgeneratorfunction(self.batcher), (
f"Generator in {self.batcher} but not marked as such."
)
target_offset = self._generator_offset
self._generator_offset = 0
for batch_sub_idx, (sample_idx, inner_batch_sample) in enumerate(
self._sample_index.iter_ctx(batch_sample, sample_idx)
):
# Skip other samples
if batch_sub_idx >= target_offset:
self._generator_offset = batch_sub_idx + 1
yield set_sample_restore_key(
inner_batch_sample,
sample_idx,
batch_sub_idx,
*sample_restore_keys,
src=self,
)
self._generator_sample_keys = None
self._generator_offset = None
batch.clear()
sample_restore_keys = []
def flush() -> Generator[T_batch, None, None]:
try:
with self._sample_index.ctx() as sample_idx:
batch_sample = self.batcher(batch)
if isinstance(batch_sample, Generator):
assert inspect.isgeneratorfunction(self.batcher), (
f"Generator in {self.batcher} but not marked as such."
)
self._generator_sample_keys = sample_restore_keys
self._generator_offset = 0
for batch_sub_idx, (sample_idx, inner_batch_sample) in enumerate(
self._sample_index.iter_ctx(batch_sample, sample_idx)
):
self._last_batch_failures = 0
self._generator_offset = batch_sub_idx + 1
yield set_sample_restore_key(
inner_batch_sample,
sample_idx,
batch_sub_idx,
*sample_restore_keys,
src=self,
)
self._generator_sample_keys = None
self._generator_offset = None
else:
self._last_batch_failures = 0
set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self)
yield batch_sample
except GeneratorExit:
raise
except SkipSample:
pass
except SYSTEM_EXCEPTIONS:
raise FatalSampleError.from_sample(batch)
except Exception as e:
self.error_handler(e, batch)
self._last_batch_failures += 1
if (
self.failure_tolerance > 0
and self._last_batch_failures >= self.failure_tolerance
):
raise FatalSampleError.from_sample(
batch,
f"BatchDataset {self.batcher} failed {self._last_batch_failures} times in a row. Likely your code or dataset are broken.",
)
finally:
sample_restore_keys.clear()
for sample in self.dataset:
batch.append(sample)
sample_restore_keys.append(get_sample_restore_key(sample))
if len(batch) == self.batch_size:
yield from flush()
batch = []
if len(batch) > 0 and not self.drop_last:
yield from flush()
def can_restore_sample(self) -> bool:
# Cannot really verify if the returned elements contain a __restore_key__.
# If the user wants to use this, well...
return super().can_restore_sample() and self.batcher_stateless
def assert_can_restore(self) -> None:
assert self.batcher_stateless, (
f"Batcher {self.batcher} must be stateless to restore samples"
)
super().assert_can_restore()
def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_batch:
# We need to store multiple indices to restore a batch.
self.assert_can_restore()
if inspect.isgeneratorfunction(self.batcher):
id, sample_idx, batch_sub_idx, *samples_restore_keys = restore_key
assert id == type(self).__name__
else:
id, sample_idx, *samples_restore_keys = restore_key
assert id == type(self).__name__
batch = [self.dataset.restore_sample(inner_idx) for inner_idx in samples_restore_keys]
try:
with self._sample_index.ctx(sample_idx):
batch_sample = self.batcher(batch)
if isinstance(batch_sample, Generator):
assert inspect.isgeneratorfunction(self.batcher), (
f"Generator in {self.batcher} but not marked as such."
)
for cur_batch_sub_idx, (sample_idx, inner_batch_sample) in enumerate(
self._sample_index.iter_ctx(batch_sample, sample_idx)
):
self._last_batch_failures = 0
if cur_batch_sub_idx == batch_sub_idx:
return set_sample_restore_key(
inner_batch_sample,
sample_idx,
batch_sub_idx,
*samples_restore_keys,
src=self,
)
assert False, f"Batch sub-index {batch_sub_idx} not found in batch"
else:
self._last_batch_failures = 0
return set_sample_restore_key(
batch_sample,
sample_idx,
*samples_restore_keys,
src=self,
)
except GeneratorExit:
raise FatalSampleError.from_sample(
batch,
f"BatchDataset {self.batcher} generator exitedwhile trying to restore a batch.",
)
except SkipSample:
raise FatalSampleError.from_sample(
batch, f"BatchDataset {self.batcher} skipped while trying to restore a batch."
)
except SYSTEM_EXCEPTIONS:
raise FatalSampleError.from_sample(batch)
except Exception as e:
self.error_handler(e, batch)
self._last_batch_failures += 1
if self.failure_tolerance > 0 and self._last_batch_failures >= self.failure_tolerance:
raise FatalSampleError.from_sample(
batch,
f"BatchDataset {self.batcher} failed {self._last_batch_failures} times in a row. Likely your code or dataset are broken.",
)
def config(self) -> Dict[str, Any]:
return {
"type": type(self).__qualname__,
"batch_size": self.batch_size,
"batcher": self._function_config(self.batcher),
**(
{
"batcher_config": (
self.batcher_config()
if callable(self.batcher_config)
else self.batcher_config
)
}
if self.batcher_config
else {}
),
"batcher_stateless": self.batcher_stateless,
"drop_last": self.drop_last,
"error_handler": self._function_config(self.error_handler),
"worker_config": self.worker_config.config(),
"dataset": self.dataset.config(),
}
def __str__(self):
return f"BatchDataset(batch_size={self.batch_size}, drop_last={self.drop_last}, batcher={self.batcher}, dataset={self.dataset})"
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Any, Dict, Iterator, List, Sequence, Tuple, TypeVar
import torch
from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key
from megatron.energon.rng import WorkerRng
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset
T_sample = TypeVar("T_sample")
class BlendDataset(BaseWrapperDataset[T_sample, T_sample]):
"""
    This dataset wrapper blends multiple iterable datasets together given a weighting.
The datasets may be infinite. This dataset is always infinite.
"""
datasets: List[SavableDataset[T_sample]]
weights: Tuple[float, ...]
dataset_weights: Sequence[Tuple[SavableDataset[T_sample], float]]
exhausted: List[bool]
_worker_rng: WorkerRng
_savable_fields = ("exhausted", "_worker_rng")
def __init__(
self,
*dataset_weights: Tuple[SavableDataset[T_sample], float],
worker_config: WorkerConfig,
):
"""Construct a BlendDataset.
Args:
dataset_weights: Each argument should be a tuple of (dataset, weight) with a weight
between 0 and 1. The output samples are sampled from the input datasets with the
given probabilities.
worker_config: Configuration for the workers.
"""
# datasets = [dataset for dataset, _weight in dataset_weights]
self.datasets, self.weights = zip(*dataset_weights)
super().__init__(self.datasets, worker_config=worker_config)
self.dataset_weights = dataset_weights
self.reset_state_own()
def reset_state_own(self) -> None:
self._worker_rng = WorkerRng(self.worker_config)
self.exhausted = [False] * len(self.weights)
def len_worker(self, worker_idx: int | None = None) -> int:
# Give the number of samples in inner datasets, disregarding the weight
return sum(dataset.len_worker(worker_idx) for dataset in self.datasets)
def __iter__(self) -> Iterator[T_sample]:
assert self.worker_has_samples(), "Cannot blend all empty datasets"
# Create a list of datasets and their weights, but
# set the weight to 0 if the dataset has no samples on this worker.
dataset_iters = []
weights = []
for idx, (dataset, weight) in enumerate(self.dataset_weights):
assert weight > 0, "All blending weights must be > 0"
if dataset.worker_has_samples():
dataset_iters.append(iter(dataset))
weights.append(weight)
else:
dataset_iters.append(None)
weights.append(0)
self.exhausted[idx] = True
weights = torch.tensor(weights, dtype=torch.float32)
if weights.sum() == 0:
raise RuntimeError(
"There is a worker with no samples in any of the blended datasets. "
"This can happen if you have a lot of workers and your dataset is too small. "
"Currently this case is not supported."
)
# Some may already be exhausted on this worker when restoring a state.
for idx, exhausted in enumerate(self.exhausted):
if exhausted:
weights[idx] = 0
dataset_iters[idx] = None
while True:
ds_idx = self._worker_rng.choice_idx(probs=weights)
if dataset_iters[ds_idx] is None:
if all(dataset_iter is None for dataset_iter in dataset_iters):
break
continue
try:
sample = next(dataset_iters[ds_idx])
except StopIteration:
dataset_iters[ds_idx] = None
weights[ds_idx] = 0
self.exhausted[ds_idx] = True
if all(dataset_iter is None for dataset_iter in dataset_iters):
break
else:
yield add_sample_restore_key(sample, ds_idx, src=self)
self.exhausted = [False] * len(self.dataset_weights)
def config(self) -> Dict[str, Any]:
return {
"type": type(self).__qualname__,
"dataset_weights": [
(dataset.config(), weight) for dataset, weight in self.dataset_weights
],
"worker_config": self.worker_config.config(),
}
def __str__(self):
return f"BlendDataset(dataset_weights={self.dataset_weights})"
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import (
Any,
Dict,
Generator,
Generic,
Iterator,
List,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)
from megatron.energon.flavors.base_dataset import FlexState, SavableDataset
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset, get_sample_restore_key
T_sample = TypeVar("T_sample")
class SavableSampleBuffer(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]):
"""A buffer of samples, savable."""
_buffer: List[T_sample]
_restore_keys: List[Tuple[Union[str, int, tuple], ...]]
_savable_fields = ("_restore_keys",)
_restore_pending: bool = False
def __init__(self, dataset: SavableDataset[T_sample], *, worker_config: WorkerConfig):
super().__init__(dataset, worker_config=worker_config)
self.reset_state_own()
def reset_state_own(self) -> None:
self._buffer = []
self._restore_keys = []
def worker_start(self) -> None:
if self._restore_pending:
assert len(self._buffer) == 0
self._restore_pending = False
for restore_key in self._restore_keys:
self._buffer.append(self.restore_sample(restore_key))
assert len(self._buffer) == len(self._restore_keys)
def append(self, sample: T_sample) -> T_sample:
self._buffer.append(sample)
self._restore_keys.append(get_sample_restore_key(sample))
return sample
def extend(self, samples: List[T_sample], restore_keys: Optional[Sequence[Any]] = None) -> None:
self._buffer.extend(samples)
if restore_keys is None:
self._restore_keys.extend(get_sample_restore_key(sample) for sample in samples)
else:
self._restore_keys.extend(restore_keys)
def append_iter(self) -> Generator[T_sample, None, None]:
for sample in self.dataset:
yield self.append(sample)
def pop(self, index: int) -> T_sample:
self._restore_keys.pop(index)
return self._buffer.pop(index)
def flush(self) -> Tuple[List[T_sample], Tuple[Any, ...]]:
buffer = list(self._buffer)
restore_key = tuple(self._restore_keys)
self._buffer.clear()
self._restore_keys.clear()
return buffer, restore_key
@property
def buffer(self) -> List[T_sample]:
return self._buffer
def __iter__(self) -> Iterator[T_sample]:
return iter(self._buffer)
def __getitem__(self, index: Union[int, slice]) -> Union[T_sample, List[T_sample]]:
return self._buffer[index]
def __setitem__(self, index: Union[int, slice], value: T_sample) -> None:
self._buffer[index] = value
if isinstance(index, slice):
self._restore_keys[index] = (get_sample_restore_key(v) for v in value)
else:
self._restore_keys[index] = get_sample_restore_key(value)
def __delitem__(self, index: Union[int, slice]) -> None:
del self._buffer[index]
del self._restore_keys[index]
def len_worker(self, worker_idx: int | None = None) -> int:
self.worker_config.assert_worker()
assert worker_idx is None or worker_idx == self.worker_config.rank_worker_id(), (
"SavableSampleBuffer.len_worker only available for the current worker"
)
return len(self._restore_keys)
def len_rank(self) -> int:
raise NotImplementedError("len_rank is not available for SavableSampleBuffer")
def save_state(self) -> FlexState:
# Don't call super().save_state() because we don't want to save the wrapped datasets
# Just save the own state
return SavableDataset.save_state(self)
def restore_state(self, state: FlexState) -> None:
# Don't call super().restore_state() because we don't want to restore the wrapped datasets
# Just restore the own state
SavableDataset.restore_state(self, state)
self._restore_pending = True
def restore_key(self) -> Tuple[Union[str, int], ...]:
return tuple(self._restore_keys)
def restore_samples(
self, index: Tuple[Union[str, int, tuple], ...]
) -> Tuple[Tuple[Union[str, int, tuple], ...], List[T_sample]]:
buffer = []
restore_keys = []
for sub_index in index:
sample = self.restore_sample(sub_index)
restore_keys.append(get_sample_restore_key(sample))
buffer.append(sample)
return tuple(restore_keys), buffer
def clear(self) -> None:
self._buffer.clear()
self._restore_keys.clear()
def config(self) -> Dict[str, Any]:
return {
"type": type(self).__qualname__,
"dataset": self.dataset.config(),
"worker_config": self.worker_config.config(),
}
def debug_print(self, indent: str = ""):
print(
f"{indent}SavableSampleBuffer(size={len(self._restore_keys)}, res_pend={self._restore_pending}):\n",
end="",
)
for i, (sample, restore_key) in enumerate(zip(self._buffer, self._restore_keys)):
print(f"{indent}Sample {i} [{restore_key!r}]: {sample.__key__}\n", end="")
def __str__(self):
return f"SavableSampleBuffer(size={len(self._buffer)})"
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Any, Dict, Generic, Iterator, TypeVar
from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset
T_sample = TypeVar("T_sample")
class ConcatDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]):
"""
This dataset wrapper concatenates multiple iterable datasets together. The datasets must be
finite, otherwise not all datasets can be sampled. This is only useful for validation / test
datasets.
"""
def __init__(
self,
*datasets: SavableDataset[T_sample],
worker_config: WorkerConfig,
):
"""Construct a concatenated dataset."""
super().__init__(datasets, worker_config=worker_config)
assert len(self) >= 0, "Datasets must be finite."
def reset_state_own(self) -> None:
return
def len_worker(self, worker_idx: int | None = None) -> int:
return sum(dataset.len_worker(worker_idx) for dataset in self.datasets)
def __iter__(self) -> Iterator[T_sample]:
for ds_idx, dataset in enumerate(self.datasets):
for sample in dataset:
yield add_sample_restore_key(
sample,
ds_idx,
src=self,
)
def config(self) -> Dict[str, Any]:
return {
"type": type(self).__qualname__,
"datasets": [dataset.config() for dataset in self.datasets],
}
def __str__(self):
return f"ConcatDataset(datasets={self.datasets})"
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Any, Dict, Generic, Iterator, Optional, TypeVar
from megatron.energon.flavors.base_dataset import SavableDataset
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset
T_sample = TypeVar("T_sample")
class EpochizeDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]):
"""
    Uses the base dataset and creates one epoch of `length` samples. Keeps the underlying
    dataset iterator alive across epochs (i.e. if it is an infinite dataset, it will keep its
    state). Repeats the underlying dataset if the iterator is exhausted.
"""
length: int
_active_iter: Optional[Iterator[T_sample]]
_offset: int
_savable_fields = ("_offset",)
def __init__(
self,
dataset: SavableDataset[T_sample],
length: int,
worker_config: WorkerConfig,
):
"""
Create the epochized dataset.
Args:
dataset: The source dataset (possibly infinite)
            length: Number of samples to iterate before iteration stops (i.e. one epoch). When
                iteration continues, the original dataset iterator is resumed and only restarts
                once it is exhausted.
worker_config: Configuration for the workers.
"""
super().__init__(dataset, worker_config=worker_config)
self.length = length
self._active_iter = None
self.reset_state_own()
def reset_state_own(self) -> None:
self._offset = 0
def __iter__(self) -> Iterator[T_sample]:
# Compute the local length for this worker, i.e. all worker's lengths sum up to the total
if self.worker_config.num_workers <= 1:
local_length = self.length
else:
local_length = self.length // self.worker_config.num_workers
if self.worker_config.rank_worker_id() < self.length % self.worker_config.num_workers:
local_length += 1
if self.worker_config.should_log(level=2):
self.worker_config.worker_log(
{
"t": "EpochizeDataset.epoch_start",
"r": self.worker_config.rank,
"w": self.worker_config.rank_worker_id(),
"offset": self._offset,
"local_length": local_length,
"length": self.length,
}
)
offset_range = list(range(self._offset, local_length))
# Only iterate if there are samples to iterate
if len(offset_range) > 0:
if self._active_iter is None:
self._active_iter = iter(self.dataset)
for idx in offset_range:
self._offset = (idx + 1) % local_length
try:
sample = next(self._active_iter)
except StopIteration:
break
yield sample
if self.worker_config.should_log(level=2):
self.worker_config.worker_log(
{
"t": "EpochizeDataset.epoch_end",
"r": self.worker_config.rank,
"w": self.worker_config.rank_worker_id(),
"offset": self._offset,
"local_length": local_length,
"length": self.length,
}
)
def len_worker(self, worker_idx: int | None = None) -> int:
if worker_idx is None:
self.worker_config.assert_worker()
worker_idx = self.worker_config.rank_worker_id()
if self.worker_config.num_workers <= 1:
assert worker_idx == 0
return self.length
else:
local_length = self.length // self.worker_config.num_workers
if worker_idx < self.length % self.worker_config.num_workers:
local_length += 1
return local_length
def config(self) -> Dict[str, Any]:
return {
"type": type(self).__qualname__,
"dataset": self.dataset.config(),
"length": self.length,
"worker_config": self.worker_config.config(),
}
def __str__(self):
return f"EpochizeDataset(length={self.length}, dataset={self.dataset})"
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Any, Callable, Dict, Generic, Iterator, Optional, TypeVar, Union
from megatron.energon.flavors.base_dataset import SavableDataset
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex
T_sample = TypeVar("T_sample")
class FilterDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]):
"""This dataset wrapper applies a custom filter function to each sample and does not yield
filtered samples."""
filter_fn: Callable[[T_sample], bool]
filter_fn_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]]
_sample_index: SampleIndex
_savable_fields = ("_sample_index",)
def __init__(
self,
dataset: SavableDataset[T_sample],
*,
filter_fn: Callable[[T_sample], bool],
filter_fn_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None,
worker_config: WorkerConfig,
):
"""Construct a MapDataset.
Args:
dataset: The input dataset to wrap
filter_fn: The function to apply to each sample. If it returns `True`, the sample is
accepted.
filter_fn_config: Configuration for the filter function. If callable, it should return the
configuration. Defaults to None.
worker_config: Configuration for the workers.
"""
super().__init__(dataset, worker_config=worker_config)
self.filter_fn = filter_fn
self.filter_fn_config = filter_fn_config
self.reset_state_own()
def reset_state_own(self) -> None:
self._sample_index = SampleIndex(self.worker_config, src=self)
def len_worker(self, worker_idx: int | None = None) -> int:
return self.dataset.len_worker(worker_idx)
def __iter__(self) -> Iterator[T_sample]:
for sample in self.dataset:
with self._sample_index.ctx():
filter_res = self.filter_fn(sample)
if filter_res:
yield sample
def config(self) -> Dict[str, Any]:
return {
"type": type(self).__qualname__,
"dataset": self.dataset.config(),
"filter_fn": self._function_config(self.filter_fn),
**(
{
"filter_fn_config": (
self.filter_fn_config()
if callable(self.filter_fn_config)
else self.filter_fn_config
)
}
if self.filter_fn_config
else {}
),
}
def __str__(self):
return f"FilterDataset(filter_fn={self.filter_fn}, dataset={self.dataset})"
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import gc
from typing import Any, Dict, Generic, Iterator, TypeVar
import torch
import torch.utils.data
import torch.utils.data.dataloader
from torch.distributed._shard.sharded_tensor import ShardedTensorBase
from torch.distributed.distributed_c10d import reduce_op
from megatron.energon.flavors.base_dataset import SavableDataset
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset
T_sample = TypeVar("T_sample")
_frozen_cuda_tensors = set()
_frozen_cuda_tensors_initialized = False
GC_DEFAULT_EVERY_N_ITER = 10
class GcFreezeError(RuntimeError):
pass
def gc_init_worker(worker_id: int):
"""This function should be called by any forked worker process that uses CUDA.
It should be called as early as possible in the worker process, ideally in
the worker_init_fn of the DataLoader.
By keeping a reference to all CUDA tensors in the worker process, we can
prevent the forked tensors from being garbage collected."""
global _frozen_cuda_tensors_initialized, _frozen_cuda_tensors
num_tensors = 0
for o in gc.get_objects():
try:
if o is not reduce_op:
if isinstance(o, torch.Tensor):
if isinstance(o, ShardedTensorBase) or o.is_cuda:
# Calling .is_cuda or any hasattr on ShardedTensor will raise an error
# Hence, o.is_cuda is only called if o is not a ShardedTensor (in the if above)
_frozen_cuda_tensors.add(o)
num_tensors += 1
elif isinstance(o, torch.utils.data.dataloader._MultiProcessingDataLoaderIter):
o._shutdown = True
except ReferenceError:
# Can happen if the object is a weakref proxy, don't care
pass
_frozen_cuda_tensors_initialized = True
class GcDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]):
"""Applies a garbage collection step. This is needed, because python garbage collection
does not work well with very large objects, such as tensors. This case happens, if there are
a few hundred objects created and released every epoch (some of them being (large) tensors),
where a lot of them are alive at the same time, but released later. In that case, those objects
may end up in gc generation 2, where they may live until a lot of objects have been created,
until automatic garbage collection of gen2 is actually triggered. To avoid this memory leak,
`gc.collect()` is best to be called regularly. In addition, if `gc.freeze()` is used before the
loop, it will remove the objects currently alive from garbage collection checks, thus making the
gc faster.
"""
every_n_iter: int
freeze: bool
def __init__(
self,
dataset: SavableDataset[T_sample],
*,
worker_config: WorkerConfig,
every_n_iter: int = GC_DEFAULT_EVERY_N_ITER,
freeze: bool = True,
):
"""Construct a GcDataset, which applies garbage collection every `every_n_iter` iterations.
Args:
dataset: The input dataset to wrap
every_n_iter: How often to perform garbage collection
freeze: If true, run `gc.freeze()` before the loop, and `gc.unfreeze()` after the loop.
This will speed up garbage collection, but will keep all initially alive objects
alive until the end of the loop (i.e. if the dataset state was restored, that state
will be saved as well).
"""
super().__init__(dataset, worker_config=worker_config)
self.every_n_iter = every_n_iter
self.freeze = freeze
def reset_state_own(self) -> None:
return
def len_worker(self, worker_idx: int | None = None) -> int:
return self.dataset.len_worker(worker_idx)
def __iter__(self) -> Iterator[T_sample]:
in_worker = torch.utils.data.get_worker_info() is not None
if in_worker and not _frozen_cuda_tensors_initialized:
raise GcFreezeError(
"You are using GcDataset with multiple workers, but forgot to call gc_init_worker() in at least one forked worker process."
)
if self.freeze:
gc.collect()
gc.freeze()
try:
iter = 0
for sample in self.dataset:
yield sample
iter += 1
if iter >= self.every_n_iter:
gc.collect()
iter = 0
finally:
if self.freeze:
gc.unfreeze()
def config(self) -> Dict[str, Any]:
# This is transparent, no config to be saved (it does not affect the dataset)
return self.dataset.config()
def __str__(self):
return f"GcDataset(every_n_iter={self.every_n_iter}, dataset={self.dataset}, freeze={self.freeze})"
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import inspect
from typing import (
Any,
Callable,
Dict,
Generator,
Generic,
Hashable,
Iterator,
List,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)
from megatron.energon.edataclass import edataclass
from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError
from megatron.energon.flavors.base_dataset import (
FlexState,
SavableDataset,
set_sample_restore_key,
)
from megatron.energon.savable import Savable
from megatron.energon.source_info import SourceInfo
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers._log_exception import log_exception
from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex
from megatron.energon.wrappers.buffer import SavableSampleBuffer
from megatron.energon.wrappers.skip import SkipSample
T_batch = TypeVar("T_batch", covariant=True)
T_batch_sample = TypeVar("T_batch_sample", covariant=True)
@edataclass
class Bucket(Savable, Generic[T_batch_sample]):
batch_size: int
samples: SavableSampleBuffer[T_batch_sample]
def save_state(self) -> FlexState:
return FlexState(
batch_size=self.batch_size,
samples=self.samples.save_state(),
)
def restore_state(self, state: FlexState):
self.batch_size = state["batch_size"]
self.samples.restore_state(state["samples"])
class GroupBatchDataset(
BaseWrapperDataset[T_batch_sample, T_batch], Generic[T_batch_sample, T_batch]
):
"""This dataset wrapper transforms a dataset of samples into a dataset of batches, grouped by some criterion.
    The reported length is not exact: the number of batches cannot be predicted because there is
    no fixed batch size, so the inner dataset size is returned instead.
An example use case is: Image-Text samples, which are to be grouped by the image size into three
size categories (e.g. 128x128, 256x256, 512x512) for efficient augmentation and batching.
"""
dataset: SavableDataset[T_batch_sample]
sample_group_key: Callable[[T_batch_sample], Tuple[Hashable, Optional[int]]]
batcher: Callable[[List[T_batch_sample]], T_batch]
drop_last: bool
error_handler: Callable[[Exception, List[T_batch_sample], list[SourceInfo]], None]
_group_key_sample_index: SampleIndex
_batch_sample_index: SampleIndex
_buckets: Dict[Hashable, Bucket[T_batch_sample]]
_last_batch_failures: int = 0
def __init__(
self,
dataset: SavableDataset[T_batch_sample],
fixed_batch_size: Optional[int],
sample_group_key: Callable[[T_batch_sample], Tuple[Hashable, Optional[int]]],
batcher: Callable[[List[T_batch_sample]], T_batch],
*,
batcher_stateless: bool = False,
batcher_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None,
drop_last: bool = False,
error_handler: Callable[
[Exception, List[T_batch_sample], Sequence[SourceInfo]], None
] = log_exception,
failure_tolerance: int = 100,
worker_config: WorkerConfig,
):
"""Construct a GroupBatchDataset.
Args:
dataset: The input dataset to wrap
fixed_batch_size: Fixed batch size to use for all buckets. If None, the batch size is determined by the sample_group_key function.
sample_group_key: Function which determines the bucket of a sample.
batcher: Function which combines separate samples into a single object. May raise
:exc:`megatron.energon.SkipSample` to skip a sample.
drop_last: If True, the last batch is dropped if it is smaller than the batch size.
error_handler: Handler for errors. Defaults to logging and ignoring the exception.
failure_tolerance: The number of consecutive failures after which the dataset is considered broken. Set to 0 to disable.
worker_config: Configuration for the workers.
"""
super().__init__(dataset, worker_config=worker_config)
self.fixed_batch_size = fixed_batch_size
self.sample_group_key = sample_group_key
self.batcher = batcher
self.batcher_stateless = batcher_stateless
self.batcher_config = batcher_config
self.drop_last = drop_last
self.error_handler = error_handler
self.failure_tolerance = failure_tolerance
self.reset_state_own()
assert not inspect.isgeneratorfunction(batcher), (
f"Batcher {batcher} must not be a generator function for grouped batching."
)
def reset_state_own(self) -> None:
self._group_key_sample_index = SampleIndex(self.worker_config, src=self)
self._batch_sample_index = SampleIndex(self.worker_config, src=self)
self._buckets = {}
def len_worker(self, worker_idx: int | None = None) -> int:
# Return an upper bound. This is for sure not correct.
return self.dataset.len_worker(worker_idx)
def __iter__(self) -> Iterator[T_batch]:
buckets = self._buckets
if buckets is None:
buckets = self._buckets = dict()
# Load saved state if available
for bucket in buckets.values():
bucket.samples.worker_start()
# print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] initial GroupBatchDataset state:\n", end="")
# for bucket_key, bucket in buckets.items():
# print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] - Bucket [{bucket_key}] (bs={bucket.batch_size}, len(samples)={len(bucket.samples)}):\n", end="")
# bucket.samples.debug_print(" ")
# print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] initial done\n", end="")
def flush(bucket: Bucket[T_batch_sample]) -> Generator[T_batch, None, None]:
# Debug print the state
# print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] flush GroupBatchDataset state:\n", end="")
# for dbg_bucket_key, dbg_bucket in buckets.items():
# print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] - Bucket [{dbg_bucket_key}{'*' if dbg_bucket_key == bucket_key else ''}] (bs={dbg_bucket.batch_size}, len(samples)={len(dbg_bucket.samples)}):\n", end="")
# dbg_bucket.samples.debug_print(" ")
batch_items, sample_restore_keys = bucket.samples.flush()
# print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] flushed: len(batch)={len(batch_items)} len(samples)={len(bucket.samples)}\n", end="")
try:
with self._batch_sample_index.ctx() as sample_idx:
batch_sample = self.batcher(batch_items)
assert not isinstance(batch_sample, Generator), (
f"Batcher {self.batcher} returned a generator, which is not supported for grouped batching yet."
)
self._last_batch_failures = 0
set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self)
yield batch_sample
except SkipSample:
pass
except SYSTEM_EXCEPTIONS:
raise FatalSampleError.from_sample(batch_items)
except Exception as e:
self.error_handler(e, batch_items)
self._last_batch_failures += 1
if (
self.failure_tolerance > 0
and self._last_batch_failures >= self.failure_tolerance
):
raise FatalSampleError.from_sample(
batch_items,
f"GroupBatchDataset {self.batcher} failed {self._last_batch_failures} times in a row. Likely your code or dataset are broken.",
)
# Add samples to the buckets
for sample in self.dataset:
try:
with self._group_key_sample_index.ctx():
bucket_key, batch_size = self.sample_group_key(sample)
assert (batch_size is None) != (self.fixed_batch_size is None), (
f"A sample in group for key {bucket_key} returned batch size {batch_size}, but fixed "
f"batch size is set to {self.fixed_batch_size}. One of the two should be None."
)
if self.fixed_batch_size is not None:
batch_size = self.fixed_batch_size
except SkipSample:
continue
except SYSTEM_EXCEPTIONS:
raise FatalSampleError.from_sample(sample)
except Exception as e:
self.error_handler(e, [sample])
continue
bucket = buckets.get(bucket_key)
if bucket is None:
assert batch_size is not None
buckets[bucket_key] = bucket = Bucket(
batch_size=batch_size,
samples=SavableSampleBuffer(self.dataset, worker_config=self.worker_config),
)
else:
assert bucket.batch_size == batch_size, (
f"Got different batch size for group {bucket_key}: {bucket.batch_size} != {batch_size}."
)
bucket.samples.append(sample)
if bucket.samples.len_worker() >= bucket.batch_size:
yield from flush(bucket)
# Flush out last samples
if not self.drop_last:
for bucket in buckets.values():
if bucket.samples.len_worker() > 0:
yield from flush(bucket)
# Clear the buckets
self._buckets.clear()
def save_state(self) -> FlexState:
return FlexState(
bucket_sample_index=self._group_key_sample_index.save_state(),
batch_sample_index=self._batch_sample_index.save_state(),
buckets={key: bucket.save_state() for key, bucket in self._buckets.items()},
**super().save_state(),
)
def restore_state(self, state: FlexState) -> None:
super().restore_state(state)
self._group_key_sample_index.restore_state(state["bucket_sample_index"])
self._batch_sample_index.restore_state(state["batch_sample_index"])
for key, bucket_state in state["buckets"].items():
self._buckets[key] = Bucket(
batch_size=-1,
samples=SavableSampleBuffer(self.dataset, worker_config=self.worker_config),
)
self._buckets[key].restore_state(bucket_state)
def can_restore_sample(self) -> bool:
return super().can_restore_sample() and self.batcher_stateless
def assert_can_restore(self) -> None:
assert self.batcher_stateless, (
f"Batcher {self.batcher} must be stateless to restore samples"
)
super().assert_can_restore()
def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_batch:
self.assert_can_restore()
id, sample_idx, *sample_restore_keys = index
assert id == type(self).__name__
batch = [self.dataset.restore_sample(inner_idx) for inner_idx in sample_restore_keys]
try:
with self._batch_sample_index.ctx(sample_idx):
batch_sample = self.batcher(batch)
set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self)
self._last_batch_failures = 0
except SkipSample:
pass
except SYSTEM_EXCEPTIONS:
raise FatalSampleError.from_sample(batch)
except Exception as e:
self.error_handler(e, batch)
self._last_batch_failures += 1
if self.failure_tolerance > 0 and self._last_batch_failures >= self.failure_tolerance:
raise FatalSampleError.from_sample(
batch,
f"GroupBatchDataset {self.batcher} failed {self._last_batch_failures} times in a row. Likely your code or dataset are broken.",
)
return batch_sample
def config(self) -> Dict[str, Any]:
return {
"type": type(self).__qualname__,
"bucket": self._function_config(self.sample_group_key),
"batcher": self._function_config(self.batcher),
**(
{
"batcher_config": (
self.batcher_config()
if callable(self.batcher_config)
else self.batcher_config
)
}
if self.batcher_config
else {}
),
"batcher_stateless": self.batcher_stateless,
"drop_last": self.drop_last,
"error_handler": self._function_config(self.error_handler),
"worker_config": self.worker_config.config(),
"dataset": self.dataset.config(),
}
def __str__(self):
return f"GroupBatchDataset(bucket={self.sample_group_key}, batcher={self.batcher}, drop_last={self.drop_last}, dataset={self.dataset})"
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import (
Any,
Callable,
Dict,
Generator,
Generic,
Iterator,
Optional,
Tuple,
TypeVar,
Union,
)
from torch.utils.data import IterableDataset
from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError
from megatron.energon.flavors.base_dataset import SavableDataset, set_sample_restore_key
from megatron.energon.source_info import SourceInfo
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers._log_exception import log_exception
from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key
T_sample = TypeVar("T_sample")
T_sample_out = TypeVar("T_sample_out")
class IterMapDataset(BaseWrapperDataset[T_sample, T_sample_out], Generic[T_sample, T_sample_out]):
"""This dataset wrapper applies a custom function to transform the stream of samples and yield
a new stream of samples.
    If used in a savable dataset context, it is critical that `iter_map_fn` is either stateless,
or that the state of the `iter_map_fn` is saved and restored externally.
"""
iter_map_fn: Callable[[Iterator[T_sample]], Iterator[T_sample_out]]
len_map_fn: Callable[[int], int]
error_handler: Callable[[Exception, Optional[T_sample], list[SourceInfo]], None]
stateless_iter_fn: bool
iter_map_fn_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]]
_sample_index: SampleIndex
_savable_fields = ("_sample_index",)
def __init__(
self,
dataset: SavableDataset[T_sample],
iter_map_fn: Callable[[Iterator[T_sample]], Iterator[T_sample_out]],
*,
len_map_fn: Callable[[int], int] = lambda x: x,
error_handler: Callable[
[Exception, Optional[T_sample], list[SourceInfo]], None
] = log_exception,
stateless_iter_fn: bool = False,
iter_map_fn_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None,
worker_config: WorkerConfig,
):
"""Construct a IterMapDataset.
For saving and restoring samples, the iter_map_fn must only yield 0 or 1 sample per
iterated sample.
Args:
dataset: The input dataset to wrap
iter_map_fn: The function to apply to the stream of samples. Returns a new stream of
samples. If savability should be preserved, this function should be stateless.
len_map_fn: The function to apply to the length of the dataset. Returns the new
(approximate) length of the resulting stream of samples based on the original
length.
error_handler: Handler for errors. Defaults to logging and ignoring the exception.
            stateless_iter_fn: If true, assume the iter_map_fn is deterministic and stateless:
                it does not aggregate samples (so the key for random access can propagate to the
                inner dataset), and yielding zero or multiple samples per fetched sample is fine.
                Defaults to False.
iter_map_fn_config: Configuration for the iter_map_fn function. If callable, it should return the
configuration. Defaults to None.
worker_config: Configuration for the workers.
"""
super().__init__(dataset, worker_config=worker_config)
self.iter_map_fn = iter_map_fn
self.len_map_fn = len_map_fn
self.error_handler = error_handler
self.stateless_iter_fn = stateless_iter_fn
self.iter_map_fn_config = iter_map_fn_config
self.reset_state_own()
def reset_state_own(self) -> None:
self._sample_index = SampleIndex(self.worker_config, src=self)
def len_worker(self, worker_idx: int | None = None) -> int:
return self.len_map_fn(self.dataset.len_worker(worker_idx))
def __iter__(self) -> Iterator[T_sample_out]:
last_sample_wrapper = _LastSampleWrapper(self.dataset)
        # If the iter_map_fn is stateless, we need to know which inner sample(s) created the
        # outer sample, and the relative outer sample index, so we can restore it.
        # iter_idx is the index of the yielded sample relative to the current inner sample(s)
iter_idx = 0
sample_idx = 0
sample_restore_keys = []
def reset_idx_iter() -> Generator[T_sample, None, None]:
# Resets the inner sample index
nonlocal iter_idx, sample_restore_keys
for entry in last_sample_wrapper:
iter_idx = 0
sample_restore_keys.append(get_sample_restore_key(entry))
yield entry
ds_iter = iter(reset_idx_iter())
        # The loop breaks when the inner dataset is exhausted, but continues after a handled exception
while True:
iter_idx = 0
try:
for sample_idx, sample in self._sample_index.iter_ctx(self.iter_map_fn(ds_iter)):
yield set_sample_restore_key(
sample,
sample_idx,
iter_idx,
*sample_restore_keys,
src=self,
)
sample_restore_keys.clear()
iter_idx += 1
except SYSTEM_EXCEPTIONS:
raise FatalSampleError.from_sample(last_sample_wrapper.last_sample)
except Exception as e:
self.error_handler(e, last_sample_wrapper.last_sample)
else:
break
def can_restore_sample(self) -> bool:
return super().can_restore_sample() and self.stateless_iter_fn
def assert_can_restore(self) -> None:
assert self.stateless_iter_fn, (
"IterMapDataset can only restore samples if iter_map_fn is stateless."
)
super().assert_can_restore()
def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample:
self.assert_can_restore()
id, sample_idx, iter_idx, *sample_restore_keys = restore_key
assert id == type(self).__name__
assert isinstance(iter_idx, int)
to_be_mapped = (
self.dataset.restore_sample(inner_index) for inner_index in sample_restore_keys
)
        inner_iter = None
        try:
            inner_iter = iter(self.iter_map_fn(to_be_mapped))
# Skip inner yielded samples to get the correct sample
for skip_idx in range(iter_idx):
with self._sample_index.ctx(sample_idx - iter_idx + skip_idx):
next(inner_iter)
# This is the sample to restore
with self._sample_index.ctx(sample_idx):
sample = next(inner_iter)
return set_sample_restore_key(
sample,
sample_idx,
iter_idx,
*sample_restore_keys,
src=self,
)
except StopIteration:
raise RuntimeError(
"Generator did not yield enough samples, but is marked stateless/deterministic."
)
except GeneratorExit:
raise FatalSampleError.from_sample(
to_be_mapped,
f"IterMapDataset {self.iter_map_fn} generator exited while trying to restore a sample.",
)
except SYSTEM_EXCEPTIONS:
raise FatalSampleError.from_sample(to_be_mapped)
except Exception as e:
self.error_handler(e, to_be_mapped)
        finally:
            # Properly close the inner iterator if it's a generator. It may still be None if
            # iter_map_fn raised before an iterator was created.
            if inner_iter is not None and hasattr(inner_iter, "close"):
                inner_iter.close()
def config(self) -> Dict[str, Any]:
return {
"type": type(self).__qualname__,
"dataset": self.dataset.config(),
"iter_map_fn": self._function_config(self.iter_map_fn),
**(
{
"iter_map_fn_config": (
self.iter_map_fn_config()
if callable(self.iter_map_fn_config)
else self.iter_map_fn_config
)
}
if self.iter_map_fn_config
else {}
),
"len_map_fn": self._function_config(self.len_map_fn),
"error_handler": self._function_config(self.error_handler),
}
def __str__(self):
return f"IterMapDataset(iter_map_fn={self.iter_map_fn}, dataset={self.dataset})"
class _LastSampleWrapper:
"""
Wraps the inner dataset and stores the last iterated sample.
"""
last_sample: Optional[T_sample] = None
dataset: IterableDataset[T_sample]
def __init__(self, dataset: IterableDataset[T_sample]):
self.dataset = dataset
def __iter__(self) -> Iterator[T_sample]:
for sample in self.dataset:
self.last_sample = sample
yield sample
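As a concrete, hedged illustration of the contract documented above (a stateless `iter_map_fn` that yields 0 or 1 sample per iterated sample), the sketch below filters out short text samples. The length threshold and the commented wiring into IterMapDataset (the `inner_dataset` and `worker_config` objects) are illustrative assumptions, not part of this module.

from typing import Iterator

# Hedged sketch: a stateless iter_map_fn that drops samples shorter than a threshold.
# It yields 0 or 1 output per input sample, so save/restore keys stay well-defined.
def drop_short_samples(samples: Iterator[str]) -> Iterator[str]:
    for text in samples:
        if len(text) >= 8:
            yield text

# Quick check on a plain iterator (IterMapDataset would feed it the inner dataset instead):
assert list(drop_short_samples(iter(["tiny", "long enough sample"]))) == ["long enough sample"]

# Plugged into the wrapper (inner_dataset and worker_config construction not shown here):
# mapped = IterMapDataset(
#     inner_dataset,
#     drop_short_samples,
#     len_map_fn=lambda n: n,   # length stays an (approximate) upper bound
#     stateless_iter_fn=True,   # enables can_restore_sample / restore_sample
#     worker_config=worker_config,
# )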
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Any, Dict, Generic, Iterator, TypeVar
from megatron.energon.flavors.base_dataset import SavableDataset
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset
T_sample = TypeVar("T_sample")
class LimitDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]):
"""Limits the length of the dataset."""
length: int
current_offset: int
_savable_fields = ("current_offset",)
def __init__(
self,
dataset: SavableDataset[T_sample],
length: int,
*,
reset_after_epoch: bool = False,
worker_config: WorkerConfig,
):
"""
Limits the length of the dataset.
Args:
dataset: The dataset to limit
length: The length to limit to
reset_after_epoch: If true, reset the underlying dataset after one epoch.
worker_config: Configuration for the workers.
"""
super().__init__(dataset, worker_config=worker_config)
self.length = length
self.reset_after_epoch = reset_after_epoch
self.reset_state_own()
def reset_state_own(self) -> None:
self.current_offset = 0
def len_worker(self, worker_idx: int | None = None) -> int:
if worker_idx is None:
self.worker_config.assert_worker()
worker_idx = self.worker_config.rank_worker_id()
if self.worker_config.num_workers <= 1:
return self.length
else:
local_limit = self.length // self.worker_config.num_workers
if worker_idx < self.length % self.worker_config.num_workers:
local_limit += 1
return local_limit
def len_rank(self) -> int:
return min(self.length, self.dataset.len_rank())
def __iter__(self) -> Iterator[T_sample]:
worker_id = self.worker_config.rank_worker_id()
        # Compute the local limit for this worker, i.e. all workers' limits sum up to the total
if self.worker_config.num_workers <= 1:
local_limit = self.length
else:
local_limit = self.length // self.worker_config.num_workers
if worker_id < self.length % self.worker_config.num_workers:
local_limit += 1
if self.worker_config.should_log(level=2):
self.worker_config.worker_log(
{
"t": "LimitDataset.start",
"r": self.worker_config.rank,
"w": worker_id,
"offset": self.current_offset,
"local_limit": local_limit,
"limit": self.length,
}
)
offset_range = list(range(self.current_offset, local_limit))
# Only iterate self.dataset if there are samples to iterate
if len(offset_range) > 0:
for sample, offset in zip(
self.dataset,
offset_range,
):
self.current_offset = offset + 1
yield sample
if self.worker_config.should_log(level=2):
self.worker_config.worker_log(
{
"t": "LimitDataset.done",
"r": self.worker_config.rank,
"w": worker_id,
"offset": self.current_offset,
"local_limit": local_limit,
"limit": self.length,
}
)
# Reset the inner dataset
self.dataset.reset_state_deep()
self.current_offset = 0
if self.reset_after_epoch:
self.dataset.reset_state_deep()
def worker_has_samples(self) -> bool:
return super().worker_has_samples() and self.length > 0
def config(self) -> Dict[str, Any]:
return {
"type": type(self).__qualname__,
"dataset": self.dataset.config(),
"length": self.length,
"reset_after_epoch": self.reset_after_epoch,
"worker_config": self.worker_config.config(),
}
def __str__(self):
return f"LimitDataset(length={self.length}, dataset={self.dataset})"