Commit f356f546 authored by maming's avatar maming
Browse files

Initial commit

parents
Pipeline #3339 canceled with stages
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Any, Optional, TypeVar
from megatron.energon.cache.base import CachePool, FileStore, Lazy, MockLazy
from megatron.energon.edataclass import edataclass
from megatron.energon.source_info import SourceInfo, add_source_info
T = TypeVar("T")
@edataclass
class DirectLazy(Lazy[T]):
    """
    This is not really lazy, it will just defer the dataset access to the first get().
    The fetched (data, source info) pair is memoized in `_data`, so repeated get()
    calls hit the underlying dataset only once.
    """

    #: Memoized result of the first fetch as (data, source_info), or None if not fetched yet.
    _data: Optional[tuple[T, SourceInfo]] = None

    def get(self, sample: Any = None) -> T:
        """Gets the lazy data now (fetching it on first call) and adds its source info to the sample."""
        if self._data is None:
            # First access: perform the deferred dataset read; yields (data, SourceInfo).
            self._data = self.ds[self.fname]
        assert self._data is not None
        add_source_info(sample, self._data[1])
        return self._data[0]

    def __hash__(self) -> int:
        """Allows usage in sets and dicts as key."""
        # Identity of the data source + filename; matches the comparison in __eq__.
        return hash((id(self.ds), self.fname))

    def __eq__(self, other: Any) -> bool:
        """Allows usage in sets and dicts as key. Compares the data source and the filename."""
        if not isinstance(other, Lazy):
            return False
        return self.ds is other.ds and self.fname == other.fname
class NoCachePool(CachePool):
    """A pass-through cache pool that does not cache anything."""

    def get(self, ds: FileStore, fname: str, sample: Any = None) -> Any:
        """Get the data for a given file and adds the source info to the sample."""
        return ds.get(fname, sample)

    def get_lazy(self, ds: FileStore, fname: str) -> DirectLazy:
        """Returns a DirectLazy handle, which defers the actual read to its first get()."""
        return DirectLazy(ds=ds, fname=fname, pool=self)

    def to_cache(self, data: T, name: str) -> MockLazy:
        """Wraps already-loaded data in a lazy handle; nothing is actually written to a cache."""
        # NOTE(review): return annotation corrected from DirectLazy to MockLazy to match
        # the value actually returned below.
        return MockLazy(fname=name, get_fn=lambda _: data)

    def close(self) -> None:
        """No resources to release for this pass-through pool."""
        pass
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
"""Main entry point for energon CLI tools."""
import click
from megatron.energon.tools.analyze_debug import command as analyze_debug_command
from megatron.energon.tools.checkpoint import command as checkpoint_command
from megatron.energon.tools.info import command as info_command
from megatron.energon.tools.lint import command as lint_command
from megatron.energon.tools.mount import command as mount_command
from megatron.energon.tools.prepare import command as prepare_command
from megatron.energon.tools.preview import command as preview_command
def print_version(ctx, param, value):
    """Eager click callback: print the installed package version and terminate the CLI."""
    # Only act when the flag was actually passed (and not during resilient parsing).
    if value and not ctx.resilient_parsing:
        try:
            from megatron.energon import __version__
        except ImportError:
            click.echo("megatron-energon version unknown")
        else:
            click.echo(f"megatron-energon {__version__}")
        ctx.exit()
@click.group(context_settings={"help_option_names": ["-h", "--help"]}, invoke_without_command=True)
@click.option(
    "--version",
    "-V",
    is_flag=True,
    expose_value=False,
    is_eager=True,
    callback=print_version,
    help="Show version and exit.",
)
@click.pass_context
def main(ctx):
    """A set of tools that energon provides.
    Among other things, you can use it to lint or preprocess your dataset.
    See help of commands to learn more."""
    if ctx.invoked_subcommand is not None:
        return
    # No subcommand given: show the group help instead of exiting silently.
    click.echo(main.get_help(ctx))
# Register all subcommands on the `main` click group.
main.add_command(analyze_debug_command)
main.add_command(checkpoint_command)
main.add_command(lint_command)
main.add_command(info_command)
main.add_command(mount_command)
main.add_command(prepare_command)
main.add_command(preview_command)

# Allow running this module directly as well as via the installed console script.
if __name__ == "__main__":
    main()
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from pathlib import Path
from typing import Any, Dict, Optional, Type, TypeVar, Union
from megatron.energon.epathlib import EPath
from megatron.energon.eyaml import load_yaml
from megatron.energon.flavors import (
BaseCoreDatasetFactory,
CrudeSample,
DefaultCrudeJsonlDatasetFactory,
StandardWebdatasetFactory,
)
from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME
from megatron.energon.flavors.webdataset.metadata import EnergonDatasetType, get_dataset_type
from megatron.energon.typed_converter import JsonParser
from megatron.energon.worker import WorkerConfig
T = TypeVar("T")
def load_config(
    path: Union[EPath, Dict[str, Any]],
    *,
    default_type: Type[T],
    default_kwargs: Optional[Dict[str, Any]] = None,
    parser: JsonParser = JsonParser(strict=True),
) -> T:
    """
    Loads a config from a file or directly from a dictionary.
    Args:
        path: Path to the config to load or a dictionary containing the config.
        default_type: If set, this is the type to use if no type is specified in the config.
        default_kwargs: Default kwargs to use, will be overridden by the config.
        parser: The parser used to convert the raw config data to the target instance.
    Returns:
        The instantiated type.
    """
    if isinstance(path, dict):
        data = path
    else:
        # Not an inline dict: read and parse the yaml config from the given path.
        with path.open("rb") as f:
            data: dict = load_yaml(f)
    if default_kwargs is not None:
        # Values from the config take precedence over the supplied defaults.
        data = {**default_kwargs, **data}
    return parser.raw_to_instance(data, default_type)
# Covariant type variable for the sample type produced by a dataset factory.
T_sample = TypeVar("T_sample", covariant=True)
def get_dataset_from_config(
    path: Union[EPath, Path, str],
    *,
    dataset_config: str | None = None,
    split_config: str | None = None,
    split_part: str | None = None,
    training: bool = True,
    subflavors: Optional[Dict[str, Any]] = None,
    worker_config: WorkerConfig,
    sample_type: Optional[Type[T_sample]] = None,
    **kwargs,
) -> BaseCoreDatasetFactory[T_sample]:
    """
    Gets a dataset from a config path or path to a jsonl file.
    Args:
        path: Path to the folder where the `.nv-meta` folder is contained, or path to a jsonl file.
        dataset_config: Filename of the dataset config file (`path / '.nv-meta' / config`), or None for jsonl datasets.
        split_config: Filename of the split config file (`path / '.nv-meta' / split_config`), or None for jsonl datasets.
        split_part: Name of the split to load, or None for jsonl datasets.
        training: If true, apply training randomization and loop the dataset.
        subflavors: Merge-Override the __subflavors__ property of each sample.
        worker_config: If set, use this worker config instead of the default one.
        sample_type: Type of the samples to load, only used to ensure typing.
        **kwargs: Additional arguments to be passed to the dataset constructor.
    Returns:
        The instantiated dataset
    Raises:
        ValueError: If the path is neither a prepared webdataset nor a jsonl dataset.
    """
    path = EPath(path)
    dataset: BaseCoreDatasetFactory[T_sample]
    # Detect the dataset flavor from the on-disk layout.
    ds_type = get_dataset_type(path)
    if ds_type == EnergonDatasetType.JSONL:
        # Jsonl datasets are always crude and carry no dataset/split configuration.
        assert sample_type is CrudeSample or sample_type is None, (
            f"Sample type must be CrudeSample for jsonl datasets, but got {sample_type}"
        )
        assert dataset_config is None, (
            f"Dataset config must be None for jsonl datasets, but got {dataset_config}"
        )
        assert split_config is None, (
            f"Split config must be None for jsonl datasets, but got {split_config}"
        )
        # Note: We ignore split_part for jsonl datasets and always return the full dataset.
        dataset = DefaultCrudeJsonlDatasetFactory(
            path,
            training=training,
            subflavors=subflavors,
            worker_config=worker_config,
            **kwargs,
        )
    elif ds_type == EnergonDatasetType.WEBDATASET:
        # Fill in the standard defaults for a prepared webdataset.
        if dataset_config is None:
            dataset_config = "dataset.yaml"
        if split_config is None:
            split_config = "split.yaml"
        if split_part is None:
            split_part = "train"
        # The yaml config may specify its own factory type; otherwise the standard factory is used.
        dataset = load_config(
            path / MAIN_FOLDER_NAME / dataset_config,
            default_kwargs=dict(
                path=path,
                split_config=split_config,
                split_part=split_part,
                training=training,
                worker_config=worker_config,
                **kwargs,
            ),
            default_type=StandardWebdatasetFactory,
        )
    else:
        raise ValueError(
            f"Path {path} does not contain a {MAIN_FOLDER_NAME}/.info.yaml or .info.json file nor is it a jsonl file. "
            f"Did you forget to prepare the dataset? Please check the documentation for an introduction to dataset "
            f"preparation."
        )
    if subflavors is not None:
        # Merge-override the dataset's own subflavors with the caller-provided ones.
        dataset.subflavors.update(subflavors)
    if sample_type is not None:
        # Typing-only check: the dataset must produce (a subclass of) the requested sample type.
        assert issubclass(dataset.__sample_type__, sample_type), (
            f"Sample of type {dataset.__sample_type__} is not a subclass of {sample_type}."
        )
    return dataset
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from dataclasses import dataclass
from typing_extensions import dataclass_transform
# We define an alias for `@dataclass(slots=True, kw_only=True)`,
# because we want to use this combination almost everywhere.
@dataclass_transform(kw_only_default=True, slots_default=True)
def edataclass(cls):
"""
A dataclass transform that sets the kw_only and slots defaults to True.
This is equivalent to `@dataclass(slots=True, kw_only=True)`.
If you need more options, use `dataclass` directly.
E.g.: `@dataclass(slots=True, kw_only=True, eq=False)`.
"""
return dataclass(kw_only=True, slots=True)(cls)
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from megatron.energon.epathlib.epath import EPath
__all__ = ["EPath"]
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import re
import shutil
from pathlib import Path as PathlibPath
from pathlib import PurePosixPath
from typing import BinaryIO, Generator, Literal, Optional, TextIO, Tuple, Union, overload
import multistorageclient as msc
from megatron.energon.errors import warn_deprecated
class EPath:
    """EPath work much like pathlib.Path but does support local as well as remote file systems.
    To initialize a remote path, prepend the path with "msc://" to use the Multi-Storage Client (MSC).
    For example:
    EPath("msc://profilename/my_datasets/webdataset-000.tar")
    You will need to have your MSC configuration (~/.msc_config.yaml) set up to access the object stores
    or use your rclone configuration. See https://nvidia.github.io/multi-storage-client/config/index.html
    for more information.
    """

    # The path without the protocol. Can also be in S3 for example
    internal_path: PurePosixPath
    # The profile used to access the file system ("default" means the local file system)
    profile: str
    # The file system
    fs: msc.StorageClient

    def __init__(
        self,
        initial_path: Union[str, "EPath", PathlibPath],
    ) -> None:
        """Creates a path from a string (optionally "msc://profile/..."), another EPath, or a pathlib path."""
        if isinstance(initial_path, EPath):
            # Copy construction: reuse the already-resolved state.
            self.internal_path = initial_path.internal_path
            self.profile = initial_path.profile
            self.fs = initial_path.fs
        else:
            if isinstance(initial_path, PathlibPath):
                # Local filesystem path: make absolute, use the default (local) profile.
                path = str(initial_path.absolute())
                profile = "default"
            else:
                protocol, profile, path = self._split_protocol(initial_path)
                if protocol is None or protocol == "file":
                    # Plain local path (possibly relative): absolutize against the cwd.
                    profile = "default"
                    path = str(PathlibPath(path).absolute())
                elif protocol == "rclone":
                    warn_deprecated("rclone:// protocol is deprecated. Use msc:// instead.")
                else:
                    assert protocol == "msc", f"Unknown protocol: {protocol}"
                # Remote paths are stored with a leading slash.
                if not path.startswith("/"):
                    path = "/" + path
            self.internal_path = self._resolve(path)
            assert profile is not None
            self.profile = profile
            # Resolve the client. Only depends on the protocol and the first part of the path
            self.fs, _ = msc.resolve_storage_client(f"msc://{self.profile}")

    def __getstate__(self) -> dict:
        """Pickle support: serialize only path and profile."""
        return {
            "internal_path": self.internal_path,
            "profile": self.profile,
            # Do not save the fs when serializing, to avoid leaking credentials
        }

    def __setstate__(self, state: dict) -> None:
        """Pickle support: re-resolve the storage client from the profile."""
        self.internal_path = state["internal_path"]
        self.profile = state["profile"]
        self.fs, _ = msc.resolve_storage_client(f"msc://{self.profile}")

    @staticmethod
    def _resolve(path: Union[str, PurePosixPath]) -> PurePosixPath:
        """Resolve a path, removing .. and . components."""
        if isinstance(path, str):
            path = PurePosixPath(path)
        parts = path.parts
        if parts[0] != "/":
            raise ValueError("Only absolute paths are supported")
        if ".." in parts or "." in parts:
            # Purely lexical normalization (no symlink resolution).
            new_parts = []
            for part in parts[1:]:
                if part == "..":
                    if len(new_parts) == 0:
                        raise ValueError(f"Path above root: {path}")
                    new_parts.pop()
                elif part == ".":
                    pass
                else:
                    new_parts.append(part)
            path = PurePosixPath("/", *new_parts)
        return path

    @staticmethod
    def _split_protocol(path: str) -> Tuple[Optional[str], Optional[str], str]:
        """Splits "proto://profile/path" into (proto, profile, path); (None, None, path) if no protocol."""
        regex = re.compile(r"^(?P<protocol>[a-z]+)://(?P<profile>[^/]+?)/(?P<path>.+)$")
        m = regex.match(path)
        if m is None:
            return None, None, path
        return m.group("protocol"), m.group("profile"), m.group("path")

    @property
    def _internal_str_path(self) -> str:
        """Return the path as used inside the file system, without the protocol and fs part."""
        return str(self.internal_path)

    @overload
    def open(self, mode: Literal["r", "w"] = "r", block_size: Optional[int] = None) -> TextIO: ...

    @overload
    def open(self, mode: Literal["rb", "wb"], block_size: Optional[int] = None) -> BinaryIO: ...

    def open(
        self, mode: Literal["r", "rb", "w", "wb"] = "r", block_size: Optional[int] = None
    ) -> Union[TextIO, BinaryIO]:
        """Opens the file in the given mode via the storage client.
        NOTE(review): `block_size` is accepted but currently not forwarded to the client.
        """
        return self.fs.open(self._internal_str_path, mode)

    def read_text(self) -> str:
        """Reads the whole file as text."""
        with self.open() as f:
            return f.read()

    def read_bytes(self) -> bytes:
        """Reads the whole file as bytes."""
        with self.open("rb") as f:
            return f.read()

    def write_text(self, text: str) -> None:
        """Writes the given text, replacing the file contents."""
        with self.open("w") as f:
            f.write(text)

    def write_bytes(self, data: bytes) -> None:
        """Writes the given bytes, replacing the file contents."""
        with self.open("wb") as f:
            f.write(data)

    def copy(self, target: "EPath") -> None:
        """Copy a file to a new path, possibly between different file systems.
        Args:
            target: The path to the local file to download to.
        """
        if self.is_file():
            if self.fs == target.fs:
                # Same storage client: let the client copy natively.
                self.fs.copy(self._internal_str_path, target._internal_str_path)
            elif target.is_local():
                # Remote -> local.
                self.fs.download_file(self._internal_str_path, target._internal_str_path)
            elif self.is_local():
                # Local -> remote. NOTE(review): assumes msc upload_file(remote_path, local_path)
                # argument order — confirm against the msc API.
                target.fs.upload_file(target._internal_str_path, self._internal_str_path)
            else:
                # Remote -> remote across different clients: stream through this process.
                with self.open("rb") as src_f, target.open("wb") as dst_f:
                    shutil.copyfileobj(src_f, dst_f)
        else:
            # Not a file: copy all objects listed under this prefix, preserving relative paths.
            inner_path = EPath(self)
            for fpath in self.fs.list(self._internal_str_path):
                inner_path.internal_path = PurePosixPath("/" + fpath.key)
                inner_path.copy(target / inner_path.relative_to(self))

    @property
    def name(self) -> str:
        """The final path component."""
        return self.internal_path.name

    @property
    def parent(self) -> "EPath":
        """The parent path within the same profile."""
        new_path = EPath(self)
        new_path.internal_path = self.internal_path.parent
        return new_path

    @property
    def url(self) -> str:
        """The full path including the msc:// protocol prefix for remote paths."""
        if self.is_local():
            return self._internal_str_path
        int_path_str = str(self.internal_path)
        return f"msc://{self.profile}{int_path_str}"

    def is_local(self) -> bool:
        """True if this path is on the local file system (the "default" profile)."""
        return self.profile == "default"

    def is_dir(self) -> bool:
        """True if the path exists and is a directory."""
        try:
            return self.fs.info(self._internal_str_path).type == "directory"
        except FileNotFoundError:
            return False

    def is_file(self) -> bool:
        """True if the path exists and is a regular file / object."""
        return self.fs.is_file(self._internal_str_path)

    def mkdir(self, exist_ok: bool = True, parents: bool = False):
        """Does nothing. NOTE(review): intentionally a no-op — presumably because object
        stores have no real directories; confirm for local paths."""
        pass

    def glob(self, pattern) -> Generator["EPath", None, None]:
        """Yields EPaths matching `pattern`, interpreted relative to this path."""
        search_path_pattern = (self / pattern)._internal_str_path
        for path in self.fs.glob(search_path_pattern):
            assert isinstance(path, str)
            new_path = EPath(self)
            # NOTE(review): assumes the client yields paths relative to this path — the join
            # would discard the left side if they were absolute; confirm msc glob semantics.
            new_path.internal_path = self._resolve(self.internal_path / PurePosixPath(path))
            yield new_path

    def size(self) -> int:
        """The file size in bytes."""
        return self.fs.info(self._internal_str_path).content_length

    def with_suffix(self, suffix: str) -> "EPath":
        """Returns a new path with the file suffix replaced."""
        new_path = EPath(self)
        new_path.internal_path = self.internal_path.with_suffix(suffix)
        return new_path

    def move(self, target: "EPath") -> None:
        """Moves this file to target, implemented as copy followed by delete."""
        self.copy(target)
        self.unlink()

    def unlink(self) -> None:
        """Deletes the file."""
        return self.fs.delete(self._internal_str_path)

    def relative_to(self, other: "EPath") -> str:
        """Returns this path relative to `other` as a string; both must use the same profile."""
        assert self.profile == other.profile, "Can only use relative_to within same profile"
        return str(self.internal_path.relative_to(other.internal_path))

    def __truediv__(self, other: Union[str, "EPath"]) -> "EPath":
        """Joins with a path segment; an EPath or absolute/protocol string replaces this path."""
        if isinstance(other, EPath):
            # Always absolute
            return other
        if other.startswith("/") or "://" in other:
            return EPath(other)
        new_path = EPath(self)
        new_path.internal_path = self._resolve(self.internal_path / other)
        return new_path

    def __lt__(self, other: "EPath") -> bool:
        """Lexical ordering of paths within the same profile."""
        assert self.profile == other.profile, "Cannot compare paths from different profiles"
        return self.internal_path < other.internal_path

    def __str__(self) -> str:
        return self.url

    def __repr__(self) -> str:
        return f"EPath({str(self)!r})"

    def __hash__(self) -> int:
        # Consistent with __eq__: same (path, profile) hash equal.
        return hash((self.internal_path, self.profile))

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, EPath)
            and self.internal_path == other.internal_path
            and self.profile == other.profile
        )
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import configparser
import os
import shutil
from pathlib import Path
from typing import Dict, Optional
from megatron.energon.edataclass import edataclass
@edataclass
class ConfigEntry:
    """One remote definition, i.e. one `[section]`, parsed from an rclone config file."""

    #: The section name of the remote.
    name: str
    #: Value of the section's `type` key (the rclone backend type).
    type: str
    #: Value of the `provider` key, if present.
    provider: Optional[str]
    #: Value of the `access_key_id` key, if present.
    access_key_id: Optional[str]
    #: Value of the `secret_access_key` key, if present.
    secret_access_key: Optional[str]
    #: Value of the `region` key, if present.
    region: Optional[str]
    #: Value of the `endpoint` key, if present.
    endpoint: Optional[str]
def find_executable_path(executable_name):
    """Find the path of an executable in the PATH environment variable. Returns None if not found."""
    located = shutil.which(executable_name)
    return Path(located) if located else None
def get_rclone_config_path() -> Optional[Path]:
    """Locates the rclone configuration file, checking the usual locations in order.
    Search order: next to the rclone executable, $XDG_CONFIG_HOME/rclone/rclone.conf,
    ~/.config/rclone/rclone.conf, and finally the legacy ~/.rclone.conf.
    Returns:
        The first existing config file, or None if none was found.
    """
    candidates = []
    # First: rclone.conf next to the rclone executable, if rclone is on PATH.
    rclone_exe_path = find_executable_path("rclone")
    if rclone_exe_path is not None and rclone_exe_path.is_file():
        candidates.append(rclone_exe_path.with_name("rclone.conf"))
    # Second: $XDG_CONFIG_HOME/rclone/rclone.conf, if the variable points to a directory.
    xdg_config_home = os.environ.get("XDG_CONFIG_HOME")
    if xdg_config_home and Path(xdg_config_home).is_dir():
        candidates.append(Path(xdg_config_home) / "rclone" / "rclone.conf")
    # Third: the default location; last: the legacy location.
    candidates.append(Path.home() / ".config" / "rclone" / "rclone.conf")
    candidates.append(Path.home() / ".rclone.conf")
    for candidate in candidates:
        if candidate.is_file():
            return candidate
    return None
def read_rclone_config_at_path(config_path: Path) -> Dict[str, ConfigEntry]:
    """Reads the config file and returns a dictionary with the config entries."""
    parser = configparser.ConfigParser()
    parser.read(config_path)

    def _section_to_entry(section: str) -> ConfigEntry:
        # Missing keys yield None via ConfigParser's get().
        sec = parser[section]
        return ConfigEntry(
            name=section,
            type=sec.get("type"),
            provider=sec.get("provider"),
            access_key_id=sec.get("access_key_id"),
            secret_access_key=sec.get("secret_access_key"),
            region=sec.get("region"),
            endpoint=sec.get("endpoint"),
        )

    return {section: _section_to_entry(section) for section in parser.sections()}
def read_rclone_config() -> Dict[str, ConfigEntry]:
    """Reads the rclone config from the first discovered standard location.
    Raises:
        FileNotFoundError: If no rclone config file could be located.
    """
    located = get_rclone_config_path()
    if located is None:
        raise FileNotFoundError("Could not find rclone configuration file.")
    return read_rclone_config_at_path(located)
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import dataclasses
import itertools
import warnings
from functools import wraps
from typing import Any, Type, TypeVar, Union
import numpy as np
import torch
def compact_str(
    value: Union[dict, list, str, int, bool, None],
    depth: int = 3,
    max_items: int = 10,
    max_str_len: int = 50,
) -> str:
    """
    Compact representation of a value as a string.
    Nested containers are elided beyond `depth`, long containers are truncated to
    `max_items` entries, and long strings to `max_str_len` characters.
    Args:
        value: The value to compact
        depth: The maximum depth to compact
        max_items: The maximum number of items to show in a list or dict
        max_str_len: The maximum string length to show
    Returns: The printable string
    """
    if isinstance(value, dict):
        if depth <= 0:
            return "{...}"
        return (
            "{"
            + ", ".join(
                (
                    # Dunder keys (e.g. __key__) are shown with their raw repr value.
                    f"{k}: {v!r}"
                    if isinstance(k, str) and k.startswith("__")
                    else f"{k}: {compact_str(v, depth - 1, max_items, max_str_len)}"
                )
                for k, v in itertools.islice(value.items(), max_items)
            )
            + "}"
        )
    elif isinstance(value, list):
        if depth <= 0:
            return "[...]"
        return (
            "["
            + ", ".join(
                compact_str(v, depth - 1, max_items, max_str_len) for v in value[:max_items]
            )
            + "]"
        )
    elif isinstance(value, tuple):
        if depth <= 0:
            return "(...)"
        return (
            "("
            + ", ".join(
                compact_str(v, depth - 1, max_items, max_str_len) for v in value[:max_items]
            )
            + ")"
        )
    elif isinstance(value, str):
        if len(value) > max_str_len:
            return repr(value[:max_str_len] + "...")
        return repr(value)
    elif isinstance(value, torch.Tensor):
        return f"Tensor(shape={value.shape}, dtype={value.dtype}, device={value.device})"
    elif isinstance(value, np.ndarray):
        return f"np.ndarray(shape={value.shape}, dtype={value.dtype})"
    elif dataclasses.is_dataclass(value):
        return f"{value.__class__.__name__}({', '.join(f'{field.name}={compact_str(getattr(value, field.name))}' for field in dataclasses.fields(value))})"
    else:
        # Fallback for all other scalars (int, float, bool, None, ...): show their repr,
        # truncated to max_str_len. Fix: previously this recursed through the str branch,
        # which wrapped e.g. ints in a spurious extra layer of quotes ("'5'").
        scalar_repr = repr(value)
        if len(scalar_repr) > max_str_len:
            return scalar_repr[:max_str_len] + "..."
        return scalar_repr
# Generic type variable used by the classmethod constructors below.
T = TypeVar("T")
class SampleException(ValueError):
    """Raised when processing of a single sample fails."""

    @classmethod
    def from_sample_key(cls: Type[T], sample_key: str) -> T:
        """Builds the exception message from just the sample's key."""
        return cls(f"Sample {sample_key} failed")

    @classmethod
    def from_sample(cls: Type[T], sample: Any, message: str = "") -> T:
        """Builds the exception message from a compact rendering of the sample itself,
        with an optional extra message appended."""
        suffix = f": {message}" if message else ""
        return cls(f"Sample {compact_str(sample)} failed{suffix}")
class FatalSampleError(SampleException):
    """A sample error that will not be handled by the error handler."""

    pass
def warn_deprecated(reason, stacklevel=2):
    """Emits a FutureWarning for deprecated functionality, attributed to the caller."""
    warnings.warn(reason, category=FutureWarning, stacklevel=stacklevel)
def deprecated(reason):
    """Decorator factory marking a function as deprecated for the given reason.
    Every call of the wrapped function emits a FutureWarning via `warn_deprecated`.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # stacklevel=3 attributes the warning to the wrapped function's caller.
            warn_deprecated(f"{func.__name__} is deprecated: {reason}", stacklevel=3)
            return func(*args, **kwargs)

        return wrapper

    return decorator
#: Exception types that indicate a broken program or environment rather than a single
#: bad sample. FatalSampleError is included so it is treated the same way.
#: NOTE(review): presumably consumed by sample error handlers to decide which
#: exceptions to re-raise instead of skipping — confirm against the handler code.
SYSTEM_EXCEPTIONS = (
    SystemError,
    SyntaxError,
    ImportError,
    StopIteration,
    StopAsyncIteration,
    MemoryError,
    RecursionError,
    ReferenceError,
    NameError,
    UnboundLocalError,
    FatalSampleError,
)
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import json
from typing import Any, BinaryIO, Union
import ryml
from megatron.energon.epathlib.epath import EPath
def load_yaml_json(path: EPath) -> Any:
    """Load a YAML or JSON file from a path. Loading json is much faster than yaml.
    Args:
        path: The path to the file to load.
    Returns:
        The parsed YAML or JSON data as a native Python object.
    Raises:
        ValueError: If the file has neither a .yaml nor a .json extension.
    """
    file_name = path.name
    if file_name.endswith(".yaml"):
        with path.open("rb") as f:
            return load_yaml(f)
    if file_name.endswith(".json"):
        with path.open("rb") as f:
            return json.load(f)
    raise ValueError(f"Invalid file extension: {path.name}")
def load_yaml(stream: Union[BinaryIO, bytes]) -> Any:
    """Load a YAML file from a stream or bytes object using rapidyaml/ryml.
    This is much faster than the standard yaml library, but we don't
    support all YAML features.
    Args:
        stream: A stream or bytes object containing the YAML data.
    Returns:
        The parsed YAML data as a native Python object.
    """
    bytes_data = stream if isinstance(stream, bytes) else stream.read()
    # Parse into a ryml tree, then convert it to plain dicts/lists/scalars.
    tree = ryml.parse_in_arena(bytes_data)
    return _ryml_tree_to_native(tree)
def _cast_scalar(tree: ryml.Tree, nid: int) -> object:
    """Convert YAML scalar to bool | int | float | None | str."""
    text = tree.val(nid).tobytes().decode()
    # Quoted scalars are always strings, never coerced.
    if tree.is_val_quoted(nid):
        return text
    lowered = text.lower()
    if lowered in ("null", "~"):
        return None
    if lowered in ("true", "yes"):
        return True
    if lowered in ("false", "no"):
        return False
    try:
        # A dot or exponent marks a float; otherwise try int.
        return float(text) if "." in text or "e" in lowered else int(text)
    except ValueError:
        # Not numeric: plain string.
        return text
def _to_native(t: ryml.Tree, nid: int):
    """Recursively convert a ryml tree node to a dict, list or scalar."""
    if t.is_map(nid):
        result = {}
        for ch in ryml.children(t, nid):
            result[t.key(ch).tobytes().decode()] = _to_native(t, ch)
        return result
    if t.is_seq(nid):
        return [_to_native(t, ch) for ch in ryml.children(t, nid)]
    # scalar leaf
    return _cast_scalar(t, nid)
def _ryml_tree_to_native(tree: ryml.Tree):
    """Convert a parsed ryml tree to native Python objects, starting at its root."""
    root = tree.root_id()
    return _to_native(tree, root)
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from megatron.energon.flavors.base_dataset import (
BaseCoreDatasetFactory,
PinMemoryMixin,
Sample,
SavableDataset,
)
from megatron.energon.flavors.captioning import CaptioningSample, CaptioningWebdataset
from megatron.energon.flavors.crude import CrudeSample, CrudeWebdataset
from megatron.energon.flavors.image import ImageSample, ImageWebdataset
from megatron.energon.flavors.image_classification import (
ImageClassificationSample,
ImageClassificationWebdataset,
)
from megatron.energon.flavors.interleaved import InterleavedSample, InterleavedWebdataset
from megatron.energon.flavors.jsonl import (
CrudeJsonlDatasetFactory,
DefaultCrudeJsonlDatasetFactory,
)
from megatron.energon.flavors.multichoice_vqa import MultiChoiceVQASample, MultiChoiceVQAWebdataset
from megatron.energon.flavors.ocr import OCRSample, OCRWebdataset
from megatron.energon.flavors.similarity_interleaved import (
SimilarityInterleavedSample,
SimilarityInterleavedWebdataset,
)
from megatron.energon.flavors.text import TextSample, TextWebdataset
from megatron.energon.flavors.vid_qa import VidQASample, VidQAWebdataset
from megatron.energon.flavors.vqa import VQASample, VQAWebdataset
from megatron.energon.flavors.vqa_and_ocr import VQAOCRWebdataset
from megatron.energon.flavors.webdataset import (
AVData,
AVDecoder,
AVDecoderType,
BaseWebdatasetFactory,
DefaultDecoderWebdatasetFactory,
DefaultGenericWebdatasetFactory,
EmptyDatasetError,
ImageDecoderType,
JoinedWebdatasetFactory,
SampleDecoder,
StandardWebdatasetFactory,
WebdatasetMeta,
)
__all__ = [
"AVData",
"AVDecoder",
"AVDecoderType",
"BaseCoreDatasetFactory",
"BaseWebdatasetFactory",
"CaptioningSample",
"CaptioningWebdataset",
"CrudeJsonlDatasetFactory",
"CrudeSample",
"CrudeWebdataset",
"DefaultCrudeJsonlDatasetFactory",
"DefaultDecoderWebdatasetFactory",
"DefaultGenericWebdatasetFactory",
"EmptyDatasetError",
"ImageClassificationSample",
"ImageClassificationWebdataset",
"ImageDecoderType",
"ImageSample",
"ImageWebdataset",
"InterleavedSample",
"InterleavedWebdataset",
"JoinedWebdatasetFactory",
"MultiChoiceVQASample",
"MultiChoiceVQAWebdataset",
"OCRSample",
"OCRWebdataset",
"PinMemoryMixin",
"Sample",
"SampleDecoder",
"SavableDataset",
"SimilarityInterleavedSample",
"SimilarityInterleavedWebdataset",
"StandardWebdatasetFactory",
"TextSample",
"TextWebdataset",
"VidQASample",
"VidQAWebdataset",
"VQAOCRWebdataset",
"VQASample",
"VQAWebdataset",
"WebdatasetMeta",
]
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import dataclasses
import inspect
import typing
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import is_dataclass
from typing import (
Any,
Callable,
ClassVar,
Dict,
Generic,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
)
import torch
from torch.utils.data import IterableDataset
from typing_extensions import Self
from megatron.energon.cache import FileStore
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib import EPath
from megatron.energon.savable import Savable
from megatron.energon.source_info import SourceInfo
from megatron.energon.state import FlexState
from megatron.energon.worker import WorkerConfig
T_sample = TypeVar("T_sample", covariant=True)
T = TypeVar("T", covariant=True)
class PinMemoryMixin:
"""A mixin class providing a generic `pin_memory` function."""
def _pin_memory(self, batch: T, device: Union[torch.device, str, None] = None) -> T:
"""Pin memory of a batch. Uses recursion to handle nested structures. Supports nested
structures of dicts, dataclasses, namedtuples, lists and tuples."""
if isinstance(batch, torch.Tensor):
return batch.pin_memory(device)
elif isinstance(batch, dict):
return {key: self._pin_memory(value, device) for key, value in batch.items()}
elif dataclasses.is_dataclass(batch):
return type(batch)(
**{
field.name: self._pin_memory(getattr(batch, field.name), device)
for field in dataclasses.fields(batch)
}
)
elif isinstance(batch, (tuple, list)):
if hasattr(batch, "_fields"):
# NamedTuple
return type(batch)(*[self._pin_memory(val, device) for val in batch])
else:
# list / tuple
return type(batch)(self._pin_memory(val, device) for val in batch)
else:
return batch
def pin_memory(self: Self) -> Self:
return self._pin_memory(self)
class ExtendableDataclassMixin:
    """A mixin class providing a generic `extend` function for copying dataclasses."""

    @classmethod
    def extend(cls: Type[T], src, **kwargs) -> T:
        """
        Used for overridden dataclass instances. Example
        .. code-block:: python
            @dataclass
            class MyBaseClass:
                a: List[int]
            @dataclass
            class MyExtendedClass(MyBaseClass):
                # Add a new field `b` to the state
                b: List[int]
            base = MyBaseClass(a=[1, 2, 3])
            extended = MyExtendedClass.extend(base, b=[4, 5, 6])
        Args:
            src: The source dataclass instance to extend.
            **kwargs: The new fields to add to the instance to construct the new instance.
        Returns:
            The extended dataclass instance.
        """
        assert is_dataclass(cls), "Must be a dataclass"
        assert issubclass(cls, type(src)), "Cannot extend class of different type"
        # Carry over every constructor field of src that the caller did not override;
        # ClassVar pseudo-fields and non-init fields are skipped.
        for fld in dataclasses.fields(src):
            if not fld.init or fld.type is ClassVar or typing.get_origin(fld.type) is ClassVar:
                continue
            kwargs.setdefault(fld.name, getattr(src, fld.name))
        return cls(**kwargs)
@edataclass
class Sample(ABC, PinMemoryMixin, ExtendableDataclassMixin):
    """An abstract base class for one element of a batch.
    Each task should derive a specific subclass as a `@dataclass`, like
    :class:`megatron.energon.CaptioningBatchSample`, and add the input and output fields as needed for
    training.
    """

    #: Uniquely identifies each sample in the dataset.
    __key__: str
    #: Key for restoring the sample. This is used to restore the sample from a checkpoint. It
    # should be a (nested) tuple of strings and integers, which can be used to index the dataset.
    __restore_key__: Tuple[Union[str, int, tuple], ...]
    #: A dataset may define a subflavors to distinguish between samples of the same sample type.
    __subflavors__: Optional[Dict[str, Any]] = None
    #: Information about the source of the sample, i.e. where the data was loaded from.
    __sources__: Optional[tuple[SourceInfo, ...]] = None

    @classmethod
    def derive_from(cls: Type[T_sample], base_sample: "Sample", **kwargs) -> T_sample:
        """
        Uses the base fields of `Sample` from base_sample (i.e. __key__, __restore_key__, __subflavors__, __sources__)
        and creates a new sample with the kwargs as fields. This is useful for creating new samples, while keeping the
        metadata of the base sample.
        Args:
            base_sample: The base sample to copy the base fields / metadata from.
            kwargs: The fields of the new sample.
        Returns:
            The new sample.
        """
        # Copy every Sample metadata field that the caller did not explicitly override.
        base_kwargs = {
            field.name: getattr(base_sample, field.name)
            for field in dataclasses.fields(Sample)
            if field.name not in kwargs
        }
        return cls(
            **base_kwargs,
            **kwargs,
        )

    @classmethod
    def from_joined(
        cls: Type[T_sample], *args: "Optional[Sample]", **kwargs: "Optional[Sample]"
    ) -> T_sample:
        """
        Creates a sample from joined samples. The samples are either passed as positional arguments or as keyword
        arguments. The first sample is the primary sample, which is used to initialize the key and subflavors.
        In the default implementation, the joined samples' fields will be joined together, such that latter joined
        samples will update the fields last (i.e. take precedence), except for the key and subflavors. The restore key
        is later set externally.
        Args:
            args: The samples to join (either this or kwargs is specified).
            kwargs: The samples to join (either this or args is specified). Not supported for the default
                implementation. Overwriting implementations may use this.
        Returns:
            The joined constructed sample.
        """
        assert len(kwargs) == 0, (
            "Please specify joined datasets as list for the default joiner. Keyword arguments are confusing, because keys are ignored."
        )
        # Sample metadata fields (__key__ etc.) come from the primary sample only.
        excluded_fields = set(field.name for field in dataclasses.fields(Sample))
        init_args = {}
        if len(args) > 0:
            primary = args[0]
            assert primary is not None, "Primary sample must not be None."
            # Start with all fields (including metadata) of the primary sample.
            fields = dataclasses.fields(primary)
            for field in fields:
                init_args[field.name] = getattr(primary, field.name)
            # Merge sources from all joined samples. Fix: a single pass over all args,
            # which already includes the primary — previously the primary's sources were
            # additionally prepended and thus duplicated.
            init_args["__sources__"] = tuple(
                src
                for arg in args
                if arg is not None and arg.__sources__ is not None
                for src in arg.__sources__
            )
        # Non-metadata fields: later samples take precedence over earlier ones.
        for arg in args:
            if arg is None:
                continue
            fields = dataclasses.fields(arg)
            for field in fields:
                if field.name not in excluded_fields:
                    init_args[field.name] = getattr(arg, field.name)
        return cls(**init_args)
@edataclass
class State(ABC, ExtendableDataclassMixin):
    """An abstract base class for the state of a dataset. See :class:`megatron.energon.SavableDataset`.

    The state of a dataset is used to save and restore the dataset state (i.e. random generators,
    buffer states, file pointers, etc.).
    Each dataset should derive a specific subclass as a `@dataclass` and add the fields as needed
    for training. This base class itself declares no fields.

    To extend subclasses, use the .extend method. Example:

    .. code-block:: python

        @dataclass
        class MyState(State):
            a: int

        @dataclass
        class MyExtendedState(MyState):
            # Add a new field `b` to the state
            b: int

        class MyStateSaver:
            def save_state(self) -> MyState:
                return MyState(a=42)

        class MyExtendedStateSaver(MyStateSaver):
            def save_state(self) -> MyExtendedState:
                # Fetch state from super class, which is already a complete instance (cannot add
                # new fields to it, type is fixed).
                state: MyState = super().save_state()

                # Now extend the state of the super class (of type `MyState`) with the new field
                # required to define `MyExtendedState`.
                return MyExtendedState.extend(state, b=21)
    """
class SavableDataset(IterableDataset[T_sample], Savable, Generic[T_sample], ABC):
    """A dataset that can be saved and restored (i.e. the random state, internal buffers, etc.).
    I.e. it can be resumed from a checkpoint.

    How dataset state saving works:

    1. The dataset state needs to be saved in all forked worker processes which contain a copy of
       the main dataset instance (see :class:`megatron.energon.SavableDataLoader`). Each worker
       returns only its own state.
    2. The main process merges the states via the :meth:`megatron.energon.SavableDataset.merge_states`
       method in the main process on the main dataset instance (which doesn't hold the worker
       states, as they were forked).
    3. The main process saves the merged state to the checkpoint.
    """

    #: Worker/rank configuration for this dataset.
    worker_config: WorkerConfig

    #: List of names of the fields that are saved and restored in the state.
    _savable_fields: ClassVar[Tuple[str, ...]] = ()

    def __init__(self, worker_config: WorkerConfig):
        self.worker_config = worker_config

    @abstractmethod
    def len_worker(self, worker_idx: int | None = None) -> int:
        """Returns the length of the dataset for the current or a specific worker.

        The length is the number of different available samples.
        The number of actually yielded samples may be different (considering skipping samples or
        generator functions).

        Args:
            worker_idx: The index of the worker to return the length for.
                If None, the length of the current worker is returned (must be in worker context).
        """
        ...

    def len_rank(self) -> int:
        """Returns the length of the dataset for the current rank.

        The length is the number of different available samples.
        The number of actually yielded samples may be different (considering skipping samples or
        generator functions).
        """
        return sum(self.len_worker(i) for i in range(self.worker_config.num_workers or 1))

    def __len__(self) -> int:
        """Returns the length of the dataset for the current rank. Corresponds to `len_rank`."""
        return self.len_rank()

    def save_state(self) -> FlexState:
        """
        Saves the state of the dataset. This will save and return the state of all fields
        in the _savable_fields tuple.

        Can only be called in a worker process.
        """
        state = FlexState()
        state["__class__"] = type(self).__name__
        for key in self._savable_fields:
            attr = getattr(self, key)
            if isinstance(attr, Savable):
                state[key] = attr.save_state()
            else:
                # Check if this field is a simple python type or a user class
                if attr is not None and getattr(attr, "__module__", "builtins") != "builtins":
                    import warnings

                    warnings.warn(
                        f"The savable attribute {key} of class {type(self)} does "
                        "not inherit from Savable, nor it is a simple builtin type. Please double-check.",
                        UserWarning,
                    )
                # Reuse the attribute fetched above (was a redundant second getattr).
                state[key] = deepcopy(attr)
        return state

    def restore_state(self, state: FlexState) -> None:
        """
        Restores the state of the dataset. This will restore the state of all fields
        in the _savable_fields tuple.

        Can only be called in a worker process.

        Args:
            state: The state of the dataset as savable object. If None, restore initial state.
        """
        assert state["__class__"] == type(self).__name__, (
            f"Class name mismatch: {state['__class__']} != {type(self).__name__}"
        )
        for key in self._savable_fields:
            assert key in state, f"Key {key} not in state {state}"
            # Membership asserted above, so index directly instead of .get().
            value = state[key]
            assert hasattr(self, key), f"Savable field {key} not in dataset {self}"
            if isinstance(getattr(self, key), Savable):
                getattr(self, key).restore_state(value)
            else:
                setattr(self, key, value)

    @abstractmethod
    def reset_state_own(self) -> None:
        """Resets the state of the dataset to the initial state. Can only be called in a worker process."""
        ...

    def reset_state_deep(self) -> None:
        """Resets the state of the dataset to the initial state. Can only be called in a worker process.

        NOTE(review): the base implementation only resets this dataset's own state; wrapper
        datasets presumably override this to also reset inner datasets — confirm in subclasses.
        """
        self.reset_state_own()

    @abstractmethod
    def worker_has_samples(self) -> bool:
        """Returns True if the worker's split has samples. This is used to determine if this dataset
        yields anything."""
        ...

    @staticmethod
    def _function_config(fn: Callable) -> str:
        """Returns a "module.qualname" identifier for a function, used in config dicts."""
        mod = inspect.getmodule(fn)
        if mod is not None:
            mod_name = mod.__name__
        else:
            mod_name = getattr(fn, "__module__", "<unknown>")
        return f"{mod_name}.{getattr(fn, '__qualname__', getattr(fn, '__name__', '<unknown>'))}"

    @abstractmethod
    def config(self) -> Dict[str, Any]:
        """Return a config dict that can be used to check if datasets have the same settings.
        Variables in dicts starting with "_" represent a possibly changable setting, like a full
        path which may be changed."""
        return {
            "type": type(self).__qualname__,
        }

    def can_restore_sample(self) -> bool:
        """Returns True if the dataset can restore a sample from a key."""
        return False

    def assert_can_restore(self) -> None:
        """Asserts that the dataset can restore a sample from a key."""
        assert self.can_restore_sample(), "This dataset cannot restore samples."

    def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample:
        """
        Generic key type, because it might be either an integer (for a core dataset), or something
        more complex (e.g. for blended datasets).

        Default raises an exception (assumed non-deterministic if not implemented, does not
        guarantee determinism).
        """
        raise NotImplementedError(
            "This dataset does not support indexing, because it is not safely deterministic."
        )
class BaseCoreDatasetFactory(Generic[T_sample], ABC):
    """Base type for an inner dataset sample loader. This factory can be used to construct a sample loader, or for
    joining in a joined dataset."""

    #: The type of the samples produced by the built dataset (set by subclasses).
    __sample_type__: Type[T_sample] = cast(Type[T_sample], None)

    #: Paths of the underlying dataset(s).
    paths: List[EPath]
    #: Subflavors dictionary set for all loaded samples.
    subflavors: Dict[str, Any]

    @abstractmethod
    def build(self, worker_rotation_offset: int = 0) -> SavableDataset[T_sample]:
        """Builds the dataset."""
        ...

    @abstractmethod
    def as_file_store(self) -> "FileStore":
        """Returns the dataset as a random access dataset."""
        ...

    @abstractmethod
    def __len__(self) -> int:
        """Returns the length of the dataset across all ranks."""
        ...
def add_sample_restore_key(
    sample: T_sample, *key: Union[int, str], src: Any, fail_otherwise: bool = False
) -> T_sample:
    """Prepends ``(type(src).__name__, *key)`` to the sample's restore key.

    The sample must be a `Sample` (or any object exposing ``__restore_key__``) or a
    dict containing ``"__restore_key__"``. Otherwise the sample is returned unchanged,
    unless ``fail_otherwise`` is set, in which case a RuntimeError is raised.
    """
    prefix = (type(src).__name__, *key)
    if isinstance(sample, Sample) or hasattr(sample, "__restore_key__"):
        try:
            sample.__restore_key__ = (*prefix, *sample.__restore_key__)
        except KeyError:
            # NOTE(review): attribute access/assignment would raise AttributeError,
            # not KeyError — preserved as-is from the original; confirm intent.
            pass
    elif isinstance(sample, dict) and "__restore_key__" in sample:
        sample["__restore_key__"] = (*prefix, *sample["__restore_key__"])
    elif fail_otherwise:
        raise RuntimeError(
            "Did not yield a sample with a restore key, but is marked stateless/deterministic."
        )
    return sample
def set_sample_restore_key(
    sample: T_sample, *key: Union[int, str], src: Any, fail_otherwise: bool = False
) -> T_sample:
    """Replaces the sample's restore key with ``(type(src).__name__, *key)``.

    The sample must be a `Sample` (or any object exposing ``__restore_key__``) or a
    dict containing ``"__restore_key__"``. Otherwise the sample is returned unchanged,
    unless ``fail_otherwise`` is set, in which case a RuntimeError is raised.
    """
    new_key = (type(src).__name__, *key)
    if isinstance(sample, Sample) or hasattr(sample, "__restore_key__"):
        try:
            sample.__restore_key__ = new_key
        except KeyError:
            # NOTE(review): attribute assignment would raise AttributeError, not
            # KeyError — preserved as-is from the original; confirm intent.
            pass
    elif isinstance(sample, dict) and "__restore_key__" in sample:
        sample["__restore_key__"] = new_key
    elif fail_otherwise:
        raise RuntimeError(
            "Did not yield a sample with a restore key, but is marked stateless/deterministic."
        )
    return sample
def legacy_handler(
    handler: Union[
        Callable[[Exception, Optional[str], Optional[list[SourceInfo]]], None],
        Callable[[Exception, Optional[str]], None],
    ],
) -> Callable[[Exception, Optional[str], Optional[list[SourceInfo]]], None]:
    """Safely returns the new style three argument handler. If the handler takes 2 arguments, it wraps it."""
    import functools
    import inspect

    # Already new-style? Return it unchanged.
    # NOTE(review): this counts declared parameters; handlers using *args would be
    # wrapped as two-argument handlers — preserved from the original behavior.
    if len(inspect.signature(handler).parameters) == 3:
        return handler

    legacy = handler

    @functools.wraps(legacy)
    def three_arg_handler(
        exc: Exception, key: Optional[str], source_infos: Optional[list[SourceInfo]]
    ) -> None:
        # Old-style handlers do not receive the source infos.
        return legacy(exc, key)

    return three_arg_handler
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import torch
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib.epath import EPath
from megatron.energon.errors import warn_deprecated
from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory
@edataclass
class CaptioningSample(Sample):
    """Sample type for image captioning: pairs one image with its caption text."""

    #: The input image tensor in the shape (C, H, W)
    image: torch.Tensor
    #: The caption string
    caption: str
class CaptioningWebdataset(DefaultDecoderWebdatasetFactory[CaptioningSample]):
    """Deprecated legacy factory for webdatasets of :class:`CaptioningSample`."""

    __sample_type__ = CaptioningSample

    def __init__(self, path: EPath, **kwargs):
        # Emit a deprecation warning with migration instructions, then defer to the
        # default factory implementation.
        message = (
            f"{type(self)} is deprecated, use the default instead and set the sample_type:\n"
            f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n"
            f"# remove top-level __module__ and __class__\n"
            f"sample_type:\n"
            f" __module__: megatron.energon\n"
            f" __class__: {self.__sample_type__.__name__}\n"
            f"# Keep the remaining content"
        )
        warn_deprecated(message)
        super().__init__(path, **kwargs)
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Any, Callable, Dict, List, Optional, Union
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory
class CrudeSample(dict):
    """Generic sample type to be processed later.

    Behaves exactly like a plain ``dict``; it only serves as a marker type for
    crude (undecoded) samples that are decoded later by user-provided cookers.
    """
class CrudeWebdataset(DefaultDecoderWebdatasetFactory[CrudeSample]):
    """The CrudeWebdataset is used to load crude / raw samples and
    decode them in the user code using so-called cookers.

    See the documentation under "Crude Data" for more information.
    """

    __sample_type__ = CrudeSample

    def __init__(
        self,
        path: EPath,
        *,
        subflavors: Optional[Dict[str, Any]] = None,
        part_filter: Union[str, List[str], Callable[[str], bool]] = lambda _: True,
        **kwargs,
    ):
        """
        Constructs a crude webdataset.

        Args:
            path: Root path to the joined datasets.
            subflavors: Subflavors dictionary to set for all loaded samples.
            part_filter: Function for filtering tar files to load by dict keys.
            **kwargs: Additional arguments to the BaseWebdataset constructor.
        """
        # A user-specified sample_loader would conflict with the identity loader
        # installed below.
        if "sample_loader" in kwargs:
            raise ValueError("sample_loader is not allowed to be set when using CrudeWebdataset")
        # Pass samples through raw (identity loader); decoding happens later in cookers.
        super().__init__(
            path,
            subflavors=subflavors,
            sample_loader=lambda sample: sample,
            part_filter=part_filter,
            **kwargs,
        )
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import torch
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib.epath import EPath
from megatron.energon.errors import warn_deprecated
from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory
@edataclass
class ImageSample(Sample):
    """Sample type for an image, e.g. for image reconstruction."""

    #: The input image tensor in the shape (C, H, W)
    image: torch.Tensor
class ImageWebdataset(DefaultDecoderWebdatasetFactory[ImageSample]):
    """Deprecated legacy factory for webdatasets of :class:`ImageSample`."""

    __sample_type__ = ImageSample

    def __init__(self, path: EPath, **kwargs):
        # Emit a deprecation warning with migration instructions, then defer to the
        # default factory implementation.
        message = (
            f"{type(self)} is deprecated, use the default instead and set the sample_type:\n"
            f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n"
            f"# remove top-level __module__ and __class__\n"
            f"sample_type:\n"
            f" __module__: megatron.energon\n"
            f" __class__: {self.__sample_type__.__name__}\n"
            f"# Keep the remaining content"
        )
        warn_deprecated(message)
        super().__init__(path, **kwargs)
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Optional
import torch
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib.epath import EPath
from megatron.energon.errors import warn_deprecated
from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory
@edataclass
class ImageClassificationSample(Sample):
    """Sample type for classifying an image."""

    #: The input image tensor in the shape (C, H, W)
    image: torch.Tensor
    #: The class label of the image
    label: Optional[int] = None
    #: The name of the class label
    label_name: Optional[str] = None
class ImageClassificationWebdataset(DefaultDecoderWebdatasetFactory[ImageClassificationSample]):
    """Deprecated legacy factory for webdatasets of :class:`ImageClassificationSample`."""

    __sample_type__ = ImageClassificationSample

    def __init__(self, path: EPath, **kwargs):
        # Emit a deprecation warning with migration instructions, then defer to the
        # default factory implementation.
        message = (
            f"{type(self)} is deprecated, use the default instead and set the sample_type:\n"
            f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n"
            f"# remove top-level __module__ and __class__\n"
            f"sample_type:\n"
            f" __module__: megatron.energon\n"
            f" __class__: {self.__sample_type__.__name__}\n"
            f"# Keep the remaining content"
        )
        warn_deprecated(message)
        super().__init__(path, **kwargs)
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from typing import List, Union
import torch
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib.epath import EPath
from megatron.energon.errors import warn_deprecated
from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory
@edataclass
class InterleavedSample(Sample):
    """Sample type for interleaved media such as text with images."""

    #: The interleaved media (either torch.tensor for an image, or str for text)
    sequence: List[Union[torch.Tensor, str]]
class InterleavedWebdataset(DefaultDecoderWebdatasetFactory[InterleavedSample]):
    """Deprecated legacy factory for webdatasets of :class:`InterleavedSample`."""

    __sample_type__ = InterleavedSample

    def __init__(self, path: EPath, **kwargs):
        # Emit a deprecation warning with migration instructions, then defer to the
        # default factory implementation.
        message = (
            f"{type(self)} is deprecated, use the default instead and set the sample_type:\n"
            f"To convert, update your {path}/.nv-meta/dataset.yaml to:\n"
            f"# remove top-level __module__ and __class__\n"
            f"sample_type:\n"
            f" __module__: megatron.energon\n"
            f" __class__: {self.__sample_type__.__name__}\n"
            f"# Keep the remaining content"
        )
        warn_deprecated(message)
        super().__init__(path, **kwargs)
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from megatron.energon.flavors.jsonl.crude_jsonl_dataset import (
CrudeJsonlDatasetFactory,
DefaultCrudeJsonlDatasetFactory,
)
# Public API of the jsonl flavor package.
__all__ = [
    "CrudeJsonlDatasetFactory",
    "DefaultCrudeJsonlDatasetFactory",
]
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import json
import logging
from typing import Any, Callable, Dict, Optional
from megatron.energon.cache import FileStore
from megatron.energon.epathlib import EPath
from megatron.energon.flavors.base_dataset import (
BaseCoreDatasetFactory,
SavableDataset,
legacy_handler,
)
from megatron.energon.flavors.crude import CrudeSample
from megatron.energon.flavors.jsonl.ijsonl import IJsonlIndexReader
from megatron.energon.flavors.jsonl.jsonl_prepare import JsonlPreparator
from megatron.energon.flavors.webdataset.base_webdataset import _print_shard_slices
from megatron.energon.flavors.webdataset.error_handler import ErrorHandler
from megatron.energon.flavors.webdataset.sample_loader import (
RawSampleData,
WebdatasetSampleLoaderDataset,
)
from megatron.energon.flavors.webdataset.sharder import Sharder
from megatron.energon.flavors.webdataset.structs import (
DatasetSubset,
FilteredSample,
ShardInfo,
reraise_exception,
)
from megatron.energon.source_info import SourceInfo
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.map_dataset import MapDataset
logger = logging.getLogger(__name__)
class CrudeJsonlDatasetFactory(
    BaseCoreDatasetFactory[CrudeSample],
    JsonlPreparator,
    Sharder,
    ErrorHandler,
):
    """
    Factory class for creating a crude dataset from JSONL (JSON Lines) files.

    This factory creates datasets from JSONL files where each line contains a JSON object.
    The samples are returned as CrudeSample objects (dictionary-like) containing the raw JSON data.
    """

    __sample_type__ = CrudeSample

    #: Path to the jsonl file backing this dataset.
    path: EPath
    #: If true, shuffling is applied and the dataset is looped.
    training: bool
    #: Worker/rank configuration used for sharding and iteration.
    worker_config: WorkerConfig

    def __init__(
        self,
        path: EPath,
        *,
        training: bool,
        worker_config: WorkerConfig,
        shuffle_over_epochs: Optional[int] = 1,
        parallel_shard_iters: Optional[int] = None,
        max_samples_per_sequence: Optional[int] = None,
        subset: Optional[DatasetSubset] = None,
        part_filter: Optional[Callable[[str], bool]] = None,
        handler: Callable[
            [Exception, Optional[str], Optional[list[SourceInfo]]], None
        ] = reraise_exception,
    ):
        """
        Factory for a jsonl file as a crude dataset.

        Args:
            path: Path to the jsonl file.
            training: If true, apply shuffling and loop the dataset.
            worker_config: Configuration for the workers.
            shuffle_over_epochs: Only effective if training=True.
                How many epochs to shuffle over if training.
                If = 1, every sample is seen exactly once per epoch.
                If > 1, samples (or rather shard slices) are shuffled within this number of epochs
                (i.e. randomly selected without replacement).
                If -1, the shards are effectively shuffle over infinite epochs (i.e. shard slices
                are drawn with replacement).
            parallel_shard_iters: Number of parallel opened shards per worker, shuffling between.
            max_samples_per_sequence: Maximum number of samples per sequence (=how many samples
                will be sequentially iterated).
            subset: If specified, the dataset will be subsetted.
            part_filter: (internal) Function for filtering tar files by dict keys
            handler: Exception handler. Args: (exception, key).
        """
        assert self.__sample_type__ is not None, f"Class {type(self)} must define __sample_type__"
        self.path = path
        self.paths = [path]
        self.training = training
        self.worker_config = worker_config
        self.shuffle_over_epochs = shuffle_over_epochs
        self.parallel_shard_iters = parallel_shard_iters
        self.max_samples_per_sequence = max_samples_per_sequence
        self.subset = subset
        self.part_filter = part_filter
        # Accept legacy two-argument handlers by wrapping them into the
        # three-argument form.
        self.handler = legacy_handler(handler)
        # If the "json" part is filtered out, the dataset is effectively empty.
        if part_filter is None or part_filter("json"):
            self._len = IJsonlIndexReader.count_samples(path)
        else:
            self._len = 0
        # The last index entry holds the jsonl file's byte size; a mismatch means
        # the index file is stale.
        assert self.path.size() == IJsonlIndexReader.size(path), (
            "The index of the jsonl file does not match the file. Regenerate the index."
        )

    def __len__(self) -> int:
        """Number of samples in the jsonl file (0 if the json part is filtered)."""
        return self._len

    def build(self, worker_rotation_offset: int = 0) -> SavableDataset[CrudeSample]:
        """Builds the savable dataset pipeline for this jsonl file.

        Args:
            worker_rotation_offset: Rotation offset for distributing shard slices
                across workers.
        """
        from megatron.energon.flavors.jsonl.ijsonl_reader import IJsonlReader

        if self.parallel_shard_iters is None:
            if self.training:
                # 16 seems to be a good choice since we don't want too many file handles open
                parallel_shard_iters = 16
            else:
                parallel_shard_iters = 1
        else:
            parallel_shard_iters = self.parallel_shard_iters
        # The whole jsonl file is treated as a single virtual shard, which the
        # sharder then slices across workers.
        virtual_shards = [
            ShardInfo(
                name=self.path.name,
                path=self.path,
                count=self._len,
            )
        ]
        workers_sample_slice_offsets = self.shard_workers(
            virtual_shards,
            worker_config=self.worker_config,
            max_samples_per_sequence=self.max_samples_per_sequence,
            rotation_offset=worker_rotation_offset,
            subset=self.subset,
        )
        _print_shard_slices(self.worker_config, virtual_shards, workers_sample_slice_offsets)
        itar_reader = IJsonlReader(
            self.path,
            index_cache_size=parallel_shard_iters,
        )
        dataset = WebdatasetSampleLoaderDataset(
            join_readers=[itar_reader],
            workers_sample_slice_offsets=workers_sample_slice_offsets,
            worker_config=self.worker_config,
            shuffle_over_epochs=self.shuffle_over_epochs if self.training else None,
            parallel_slice_iters=parallel_shard_iters,
        )
        # Map raw samples to CrudeSample with the configured error handler.
        return MapDataset(
            dataset,
            self._load_sample_raw,
            error_handler=self.error_handler,
            stateless_map_fn=True,
            map_fn_config=self.config,
            worker_config=self.worker_config,
        )

    def as_file_store(self) -> "FileStore":
        """Returns a random-access view of this jsonl file."""
        from megatron.energon.cache.file_store import JsonlFileStore

        return JsonlFileStore(self.path)

    def _load_sample(self, sample: FilteredSample) -> CrudeSample:
        # Wrap the filtered sample dict in the CrudeSample marker type.
        return CrudeSample(sample)

    def _load_sample_raw(self, raw_sample: RawSampleData) -> CrudeSample:
        # Just a wrapper for the inner tuple. Tuple should be of length 1.
        assert len(raw_sample.data) == 1 and raw_sample.data[0] is not None
        return self._load_sample(raw_sample.data[0])

    def config(self) -> Dict[str, Any]:
        """Config dict used to compare dataset settings; "_"-prefixed keys are
        changeable settings (e.g. the full path)."""
        return dict(
            type=type(self).__qualname__,
            training=self.training,
            _path=str(self.path),
            jsonl_filename=self.path.name,
            count=self._len,
            shuffle_over_epochs=self.shuffle_over_epochs,
            parallel_shard_iters=self.parallel_shard_iters,
            max_samples_per_sequence=self.max_samples_per_sequence,
            subset=self.subset.config() if self.subset is not None else None,
        )

    def __str__(self):
        return f"{type(self).__name__}(path={self.path})"
class DefaultCrudeJsonlDatasetFactory(CrudeJsonlDatasetFactory):
    """
    Adds subflavors to the sample and loads the json.
    """

    def __init__(self, path: EPath, *, subflavors: Optional[Dict[str, Any]] = None, **kwargs):
        # A decoder is not applicable here; silently drop it if passed through.
        kwargs.pop("decoder", None)
        super().__init__(path, **kwargs)
        self.subflavors = subflavors

    def _load_sample(self, sample: FilteredSample) -> CrudeSample:
        sample["__subflavors__"] = self.subflavors
        # Instead of using a decoder, we just load the json here, as we know it's json.
        sample["json"] = json.loads(sample["json"])
        return super()._load_sample(sample)

    def config(self) -> Dict[str, Any]:
        base_config = super().config()
        return dict(**base_config, subflavors=self.subflavors)
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import struct
from typing import BinaryIO, Dict, Generator, Optional, Tuple, Union
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib import EPath
IJSONL_SUFFIX = ".jsonl.idx"
@edataclass
class IJsonlSamplePointer:
    """
    Points to a sample inside some jsonl file on disk.
    """

    # The index of the sample in the jsonl file.
    index: int
    # The byte offset of the sample in the jsonl file.
    byte_offset: int
    # The size of the sample in the jsonl file, in bytes.
    byte_size: int
class IJsonlIndexReader:
    """Reads the ``.jsonl.idx`` side file of a jsonl file.

    The index is a flat sequence of 8-byte unsigned integers (``struct`` format
    ``"Q"``): one byte offset per sample plus a final entry holding the total
    byte size of the jsonl file.
    """

    def __init__(self, jsonl_path: Union[EPath, str]):
        index_path = EPath(jsonl_path).with_suffix(IJSONL_SUFFIX)
        # Number of 8-byte entries in the index file.
        self._length = index_path.size() // 8
        self.ijsonl = index_path.open("rb")

    def __getitem__(self, index: int) -> int:
        """Returns the byte offset stored at the given entry index."""
        if not 0 <= index < self._length:
            raise IndexError(f"Index {index} out of range")
        target = 8 * index
        # Only seek when not already positioned there.
        if self.ijsonl.tell() != target:
            self.ijsonl.seek(target)
        return struct.unpack("Q", self.ijsonl.read(8))[0]

    def __iter__(self) -> Generator[int, None, None]:
        """Yields all byte offsets in order from the start of the index."""
        self.ijsonl.seek(0)
        while True:
            chunk = self.ijsonl.read(8)
            if not chunk:
                break
            assert len(chunk) == 8
            yield struct.unpack("Q", chunk)[0]

    def __len__(self) -> int:
        """Number of entries in the index (samples + 1)."""
        return self._length

    def close(self):
        self.ijsonl.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    @staticmethod
    def count_samples(jsonl_path: EPath | str) -> int:
        """Number of samples, i.e. index entries minus the trailing size entry."""
        return EPath(jsonl_path).with_suffix(IJSONL_SUFFIX).size() // 8 - 1

    @staticmethod
    def size(jsonl_path: EPath) -> int:
        """Total byte size of the jsonl file, as stored in the last index entry."""
        with IJsonlIndexReader(jsonl_path) as reader:
            return reader[len(reader) - 1]
class IJsonlIndexWriter:
    """Writes the ``.jsonl.idx`` side file of a jsonl file.

    Writes to a temporary file first and moves it into place on successful close,
    so a partially written index is never visible under the final name.
    """

    def __init__(self, jsonl_path: EPath):
        self.final_name = jsonl_path.with_suffix(IJSONL_SUFFIX)
        self.tmp_name = jsonl_path.with_suffix(IJSONL_SUFFIX + ".tmp")
        self.ijsonl = self.tmp_name.open("wb")

    def append(self, offset: int):
        """Appends one byte offset as an 8-byte unsigned integer."""
        self.ijsonl.write(struct.pack("Q", offset))

    def close(self, finalize: bool = True):
        """Closes the file; moves it into place if finalize, else removes the temp file."""
        self.ijsonl.close()
        if finalize:
            self.tmp_name.move(self.final_name)
        else:
            self.tmp_name.unlink()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Only finalize the index when the with-block exited without an exception.
        self.close(finalize=exc_val is None)
@edataclass
class CacheEntry:
    """One cache slot of :class:`CachedIJsonlOffsetReader`: an open index reader
    plus an optional prefetched (lookahead) offset pair."""

    # Open index reader owned by this entry.
    ijsonl_index_reader: IJsonlIndexReader
    # Sample offset for which the byte offset below was prefetched (None if no lookahead).
    lookahead_offset: Optional[int] = None
    # Prefetched byte offset corresponding to lookahead_offset.
    lookahead_byteoffset: Optional[int] = None
class CachedIJsonlOffsetReader:
    """
    This class is a high-level wrapper around IJsonlIndexReader that caches some
    of the recent lookups for faster access. It is designed for the case when
    you need to read multiple offsets from the same jsonl file.

    Args:
        cache_size: The number of entries to keep in the cache. By default, we keep 32.
    """

    def __init__(self, jsonl_file: Union[str, EPath], cache_size: int = 32):
        # Maps current_offset -> CacheEntry
        self.ijsonl_index_reader_cache: Dict[int, CacheEntry] = {}
        self.cache_size = cache_size
        self.jsonl_file = EPath(jsonl_file)

    def close(self):
        """Closes all cached index readers and clears the cache."""
        for cache_entry in self.ijsonl_index_reader_cache.values():
            cache_entry.ijsonl_index_reader.close()
        self.ijsonl_index_reader_cache.clear()

    def _find_or_create_entry(
        self,
        sample_offset: int,
    ) -> CacheEntry:
        """
        1. If we already have a key == sample_offset, return it.
        2. Otherwise, create a new entry or reuse the oldest entry.
        """
        # Direct hit in the cache?
        if sample_offset in self.ijsonl_index_reader_cache:
            return self.ijsonl_index_reader_cache[sample_offset]

        # We didn't find an existing entry. Create a new one.
        # Evict if needed.
        if len(self.ijsonl_index_reader_cache) >= self.cache_size:
            # Reuse the oldest entry (dict preserves insertion order) and its open
            # reader. Any stale lookahead it carries still maps valid positions of
            # the same index file.
            oldest_key = next(iter(self.ijsonl_index_reader_cache))
            cache_entry = self.ijsonl_index_reader_cache.pop(oldest_key)
        else:
            new_reader = IJsonlIndexReader(self.jsonl_file)
            cache_entry = CacheEntry(ijsonl_index_reader=new_reader)
        self.ijsonl_index_reader_cache[sample_offset] = cache_entry
        return cache_entry

    def _get_ijsonl_byte_offset_with_entry(
        self,
        cache_entry: CacheEntry,
        sample_offset: int,
    ) -> Tuple[int, int]:
        """
        Return (start_byte_offset, length_to_next),
        possibly using per-entry lookahead for speed.
        """
        ijsonl_index_reader = cache_entry.ijsonl_index_reader
        # If offset=0, define the result as byte offset=0 for convenience
        if sample_offset == 0:
            result_byte_offset = 0
        elif sample_offset == cache_entry.lookahead_offset:
            # Reuse the previously cached byte offset from the lookahead
            assert cache_entry.lookahead_byteoffset is not None, (
                "Lookahead offset matched but no lookahead byte offset found."
            )
            result_byte_offset = cache_entry.lookahead_byteoffset
        else:
            # Normal random access
            result_byte_offset = ijsonl_index_reader[sample_offset]

        # Prepare the lookahead for (sample_offset+1)
        next_offset = sample_offset + 1
        try:
            cache_entry.lookahead_byteoffset = ijsonl_index_reader[next_offset]
            cache_entry.lookahead_offset = next_offset
        except IndexError:
            # Reached the end of the index; no lookahead available.
            cache_entry.lookahead_offset = None
            cache_entry.lookahead_byteoffset = None

        # length = difference to the next offset, or 0 if none
        if cache_entry.lookahead_byteoffset is not None:
            length = cache_entry.lookahead_byteoffset - result_byte_offset
        else:
            length = 0
        return result_byte_offset, length

    def get_ijsonl_byte_offset(
        self,
        sample_offset: int = 0,
    ) -> Tuple[int, int]:
        """
        High-level API to get the byte offset and length for the given file & sample_offset.
        """
        # Find or create the suitable CacheEntry
        entry = self._find_or_create_entry(sample_offset)

        # Use (and update) the per-entry lookahead logic
        result_byte_offset, length = self._get_ijsonl_byte_offset_with_entry(entry, sample_offset)

        # Re-key the entry from the consumed offset to the lookahead offset, so a
        # following sequential read hits the cache directly.
        self.ijsonl_index_reader_cache.pop(sample_offset)
        if entry.lookahead_offset is not None:
            new_key = entry.lookahead_offset
            if new_key not in self.ijsonl_index_reader_cache:
                self.ijsonl_index_reader_cache[new_key] = entry
            else:
                # Already have this entry in the cache, so we can close the reader and
                # use the existing one.
                # TODO: We may actually want to keep multiple readers open, because
                # there may be multiple sequences reading towards the same position.
                entry.ijsonl_index_reader.close()
        else:
            # No lookahead, so we can close the reader
            entry.ijsonl_index_reader.close()
        return result_byte_offset, length

    def __len__(self) -> int:
        """Number of samples in the jsonl file."""
        if len(self.ijsonl_index_reader_cache) == 0:
            return IJsonlIndexReader.count_samples(self.jsonl_file)
        # Reuse any open reader; -1 for the trailing total-size entry.
        return len(next(iter(self.ijsonl_index_reader_cache.values())).ijsonl_index_reader) - 1

    def get_total_size(self) -> int:
        """Total byte size of the jsonl file (the last index entry)."""
        if len(self.ijsonl_index_reader_cache) == 0:
            self.ijsonl_index_reader_cache[0] = CacheEntry(
                ijsonl_index_reader=IJsonlIndexReader(self.jsonl_file)
            )
        reader = next(iter(self.ijsonl_index_reader_cache.values())).ijsonl_index_reader
        return reader[len(reader) - 1]
class IJsonlFile:
"""
This class is a high-level wrapper around a binary file that allows for reading a jsonl file,
with random access while keeping the file open.
Usage:
with open(filename, "rb") as fileobj:
with IJsonlFile(fileobj=fileobj) as f:
data = f.next(offset=101888, size=100)
json.loads(data)
# Or, if you want to read the whole file:
with open(filename, "rb") as fileobj:
with IJsonlFile(fileobj=fileobj) as f:
while True:
data = f.next()
if data is None:
break
json.loads(data)
# Or, if you want to read the whole file:
with open(filename, "rb") as fileobj:
with IJsonlFile(fileobj=fileobj) as f:
for data in f:
json.loads(data)
"""
def __init__(self, fileobj: BinaryIO):
self.fileobj = fileobj
def seek(self, offset: int):
self.fileobj.seek(offset)
def next(self, offset: int | None = None, size: int | None = None) -> bytes | None:
if offset is not None and offset != self.fileobj.tell():
self.fileobj.seek(offset)
if size is None:
entry = self.fileobj.readline()
if entry == b"":
return None
return entry
else:
assert size > 0, "Size must contain at least the line terminator and a json object"
data = self.fileobj.read(size)
if data == b"":
return None
return data
def __iter__(self) -> Generator[bytes, None, None]:
while True:
data = self.next()
if data is None:
break
yield data
def close(self):
self.fileobj.close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.fileobj.close()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment