Commit f356f546 authored by maming

Initial commit

# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
from contextlib import contextmanager
from typing import Generator
from .server import S3EmulatorServer
@contextmanager
def s3_emulator(
host: str = "127.0.0.1",
port: int = 0,
*,
credentials: dict[str, str] | None = None,
root_dir: str | None = None,
region: str = "us-east-1",
) -> Generator[S3EmulatorServer, None, None]:
"""Context manager for running an S3 emulator server in the background.
Args:
host: Host to bind the server to
port: Port to bind the server to. Use 0 to let the OS choose a free port.
credentials: Optional credentials mapping
root_dir: Optional directory to persist S3 data
region: Region for authentication
Yields:
The running S3 emulator server instance
"""
server = S3EmulatorServer(
host=host,
port=port,
credentials=credentials,
root_dir=root_dir,
region=region,
)
try:
server.start_background()
yield server
finally:
server.shutdown()
server.join()
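# Minimal usage sketch (illustrative only). It assumes that S3EmulatorServer exposes
# `host` and `port` attributes for the bound address and that an S3 client such as
# boto3 is available; both are assumptions, not guaranteed by this module:
#
#     with s3_emulator(root_dir="/tmp/s3-data") as server:
#         # e.g. boto3.client("s3", endpoint_url=f"http://{server.host}:{server.port}")
#         ...  # run code against the emulated S3 endpoint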
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
"""This module tests video decoding."""
import io
import logging
import os
import pickle
import sys
import time
import unittest
from pathlib import Path
import av
import numpy as np
import torch
import torchvision.transforms as transforms
from megatron.energon.av import AVDecoder, AVMetadata, get_clips_uniform, get_single_frames_uniform
# Set multiprocessing start method to 'spawn' on macOS to avoid DataLoader cleanup issues
if sys.platform == "darwin":
import multiprocessing
multiprocessing.set_start_method("spawn", force=True)
def load_video_to_tensor(video_path: str) -> torch.Tensor:
"""Load a video file into a tensor using PyAV directly.
Args:
video_path: Path to the video file
Returns:
Tensor of shape [num_frames, channels, height, width]
"""
container = av.open(video_path)
for stream in container.streams:
cc = stream.codec_context
cc.thread_type = "NONE"
cc.thread_count = 0
frames = []
for frame in container.decode(video=0):
frame = frame.reformat(format="rgb24")
frames.append(torch.from_numpy(frame.to_ndarray()))
video_tensor = torch.stack(frames)
video_tensor = video_tensor.permute(0, 3, 1, 2)
return video_tensor
def tensors_close(tensor1: torch.Tensor, tensor2: torch.Tensor, tolerance: float = 0.01) -> bool:
"""Compare two tensors with a tolerance.
Args:
tensor1: First tensor of frames
tensor2: Second tensor of frames
tolerance: Maximum allowed mean absolute error
Returns:
True if tensors are close enough, False otherwise
"""
if tensor1.shape != tensor2.shape:
raise ValueError("Input tensors must have the same shape.")
tensor1 = tensor1.float() / 255.0
tensor2 = tensor2.float() / 255.0
# Compute Mean Absolute Error
mae = torch.mean(torch.abs(tensor1 - tensor2)).item()
return mae <= tolerance
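# Note on the tolerance: both tensors are normalized to [0, 1] above, so tolerance=0.01
# corresponds to a mean absolute error of roughly 2.55 intensity levels for uint8 image
# frames on the original 0-255 scale.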
class TestVideoDecode(unittest.TestCase):
"""Test video decoding functionality."""
def setUp(self):
"""Set up test fixtures."""
logging.basicConfig(stream=sys.stderr, level=logging.INFO)
self.decode_baseline_video_pyav()
self.loaders = [] # Keep track of loaders for cleanup
def tearDown(self):
"""Clean up test fixtures."""
# Clean up any loaders
for loader in self.loaders:
if hasattr(loader, "_iterator"):
loader._iterator = None
if hasattr(loader, "_shutdown_workers"):
try:
loader._shutdown_workers()
except Exception:
pass
def decode_baseline_video_pyav(self):
"""Load the baseline video using PyAV directly."""
self.complete_video_tensor = load_video_to_tensor("tests/data/sync_test.mp4")
def test_decode_all_frames(self):
"""Test decoding all frames from a video file."""
av_decoder = AVDecoder(io.BytesIO(Path("tests/data/sync_test.mp4").read_bytes()))
av_data = av_decoder.get_frames()
video_tensor = av_data.video_clips[0]
print(video_tensor.shape)
assert (video_tensor == self.complete_video_tensor).all(), (
"Energon decoded video does not match baseline"
)
def test_decode_metadata(self):
"""Test decoding metadata."""
expected_metadata = [
AVMetadata(
video_duration=63.054,
video_num_frames=1891,
video_fps=30.0,
video_width=192,
video_height=108,
audio_duration=63.103,
audio_channels=2,
audio_sample_rate=48000,
),
AVMetadata(
video_duration=63.03333333333333,
video_num_frames=1891,
video_fps=30.0,
video_width=192,
video_height=108,
audio_duration=63.068,
audio_channels=2,
audio_sample_rate=48000,
),
]
for video_file, expected_metadata in zip(
["tests/data/sync_test.mkv", "tests/data/sync_test.mp4"], expected_metadata
):
av_decoder = AVDecoder(io.BytesIO(Path(video_file).read_bytes()))
assert av_decoder.get_metadata() == expected_metadata, (
f"Metadata does not match expected metadata for {video_file}"
)
assert av_decoder.get_video_duration(get_frame_count=False) in (
(expected_metadata.video_duration, None),
(expected_metadata.video_duration, expected_metadata.video_num_frames),
)
assert av_decoder.get_video_duration(get_frame_count=True) == (
expected_metadata.video_duration,
expected_metadata.video_num_frames,
)
assert av_decoder.get_audio_duration() == expected_metadata.audio_duration
assert av_decoder.get_video_fps() == expected_metadata.video_fps
assert av_decoder.get_audio_samples_per_second() == expected_metadata.audio_sample_rate
def test_decode_strided_resized(self):
"""Test decoding a subset of frames with resizing."""
for video_file in ["tests/data/sync_test.mkv", "tests/data/sync_test.mp4"]:
print(f"================= Testing {video_file} ==================")
av_decoder = AVDecoder(io.BytesIO(Path(video_file).read_bytes()))
video_tensor = get_single_frames_uniform(
av_decoder=av_decoder,
num_frames=64,
video_out_frame_size=(224, 224),
)
# Get strided frames from baseline complete video tensor
strided_baseline_tensor = self.complete_video_tensor[
np.linspace(0, self.complete_video_tensor.shape[0] - 1, 64, dtype=int).tolist()
]
# Now resize the baseline frames
resize = transforms.Resize((224, 224))
strided_resized_baseline_tensor = resize(strided_baseline_tensor)
# We allow small numerical differences due to different resize implementations
assert tensors_close(video_tensor, strided_resized_baseline_tensor, tolerance=0.01), (
"Energon decoded video does not match baseline"
)
def test_video_audio_sync(self):
"""Test decoding video frames and audio clips together."""
av_decoder = AVDecoder(io.BytesIO(Path("tests/data/sync_test.mp4").read_bytes()))
# Extract a single frame every 2 seconds and an audio clip (0.05 seconds long) at the same time.
# We extract the frames from the sync video that shows the full white circle on the left,
# when the click sound occurs.
# Note that the click sound is actually off by 0.022 s in the original video
# (verified in DaVinci Resolve).
av_data = av_decoder.get_clips(
video_clip_ranges=[(a * 2 + 1 / 30, a * 2 + 1 / 30) for a in range(65)],
audio_clip_ranges=[(a * 2 + 1 / 30, a * 2 + 1 / 30 + 0.05) for a in range(65)],
video_unit="seconds",
audio_unit="seconds",
video_out_frame_size=None,
)
# We drop the first two extracted frames because the click sequence hasn't started yet
video_clips = av_data.video_clips[2:]
audio_clips = av_data.audio_clips[2:]
# Then we check that the first extracted frame is all white in the area (18, 18, 55, 55)
# Image.fromarray(video_clips[0][0, :, 18:55, 18:55].numpy().transpose(1,2,0)).save('circ.png')
assert (video_clips[0][0, :, 18:55, 18:55] > 250).all(), (
"First extracted frame is not all white in the area (18, 18, 55, 55)"
)
# Check that all the video frames are the same (close value)
for video_clip in video_clips:
assert tensors_close(video_clip, video_clips[0], tolerance=0.01), (
"All video frames are not the same"
)
# Check that the first audio clip has the click sound
assert (audio_clips[0] > 0.5).any(), "Audio click not found"
# Check that all the audio clips are the same (close value)
for audio_clip in audio_clips:
assert tensors_close(audio_clip, audio_clips[0], tolerance=0.01), (
"All audio clips are not the same"
)
def test_pickle_decoder(self):
"""Test AVDecoder on a video file can be pickled and unpickled."""
av_decoder = AVDecoder(io.BytesIO(Path("tests/data/sync_test.mp4").read_bytes()))
# Get metadata from original decoder
original_metadata = av_decoder.get_metadata()
# Pickle the decoder
pickled_data = pickle.dumps(av_decoder)
# Unpickle the decoder
unpickled_decoder = pickle.loads(pickled_data)
# Verify metadata matches
unpickled_metadata = unpickled_decoder.get_metadata()
assert unpickled_metadata == original_metadata, (
f"Unpickled metadata {unpickled_metadata} does not match original {original_metadata}"
)
# Verify we can still decode frames from the unpickled decoder
video_tensor = get_single_frames_uniform(
av_decoder=unpickled_decoder,
num_frames=16,
video_out_frame_size=(64, 64),
)
# Check that we got the expected shape
assert video_tensor.shape == (16, 3, 64, 64), (
f"Expected shape (16, 3, 64, 64), got {video_tensor.shape}"
)
def load_audio_to_tensor(audio_path: str) -> torch.Tensor:
"""Load an audio file into a tensor using PyAV directly.
Args:
audio_path: Path to the audio file
Returns:
Tensor of shape [channels, samples]
"""
container = av.open(audio_path)
frames = []
for frame in container.decode(audio=0):
frames.append(torch.from_numpy(frame.to_ndarray()))
audio_tensor = torch.cat(frames, dim=-1)
return audio_tensor
class TestAudioDecode(unittest.TestCase):
"""Test audio decoding functionality."""
def setUp(self):
"""Set up test fixtures."""
logging.basicConfig(stream=sys.stderr, level=logging.INFO)
self.decode_baseline_audio_pyav()
self.loaders = [] # Keep track of loaders for cleanup
def tearDown(self):
"""Clean up test fixtures."""
# Clean up any loaders
for loader in self.loaders:
if hasattr(loader, "_iterator"):
loader._iterator = None
if hasattr(loader, "_shutdown_workers"):
try:
loader._shutdown_workers()
except Exception:
pass
def decode_baseline_audio_pyav(self):
"""Load the baseline audio using PyAV directly."""
self.complete_audio_tensor = load_audio_to_tensor("tests/data/test_audio.flac")
def test_decode_all_samples(self):
"""Test decoding all samples from an audio file."""
with open("tests/data/test_audio.flac", "rb") as f:
raw_bytes = f.read()
stream = io.BytesIO(raw_bytes)
av_decoder = AVDecoder(stream)
av_data = av_decoder.get_audio()
audio_tensor = av_data.audio_clips[0]
assert (audio_tensor == self.complete_audio_tensor).all(), (
"Energon decoded audio does not match baseline"
)
def test_decode_clips(self):
"""Test decoding multiple clips from an audio file."""
with open("tests/data/test_audio.flac", "rb") as f:
raw_bytes = f.read()
stream = io.BytesIO(raw_bytes)
av_decoder = AVDecoder(stream)
av_data = get_clips_uniform(
av_decoder=av_decoder, num_clips=5, clip_duration_seconds=3, request_audio=True
)
audio_tensor = av_data.audio_clips[0]
audio_sps = av_decoder.get_audio_samples_per_second()
# Check audio tensor shape (5 clips, channels, 3 seconds at original sample rate)
assert len(av_data.audio_clips) == 5
assert len(av_data.audio_timestamps) == 5
assert audio_tensor.shape[1] >= int(3 * audio_sps)
assert audio_tensor.shape[1] <= int(4 * audio_sps)
def test_decode_wav(self):
"""Test decoding a WAV file."""
# Skip WAV test if file doesn't exist
if not os.path.exists("tests/data/test_audio.wav"):
self.skipTest("WAV test file not found")
return
with open("tests/data/test_audio.wav", "rb") as f:
raw_bytes = f.read()
stream = io.BytesIO(raw_bytes)
av_decoder = AVDecoder(stream)
av_data = get_clips_uniform(
av_decoder=av_decoder, num_clips=3, clip_duration_seconds=3, request_audio=True
)
audio_sps = av_decoder.get_audio_samples_per_second()
# Check audio tensor shape (3 clips, 2 channels, samples)
expected_samples = int(3 * audio_sps) # 3 seconds at original sample rate
assert all(
audio_tensor.shape == torch.Size([2, expected_samples])
for audio_tensor in av_data.audio_clips
), "Energon decoded WAV file has wrong shape."
def test_decode_wav_same_shape(self):
"""Test decoding a WAV file."""
# Skip WAV test if file doesn't exist
if not os.path.exists("tests/data/test_audio.wav"):
self.skipTest("WAV test file not found")
return
with open("tests/data/test_audio.wav", "rb") as f:
raw_bytes = f.read()
stream = io.BytesIO(raw_bytes)
av_decoder = AVDecoder(stream)
av_data = get_clips_uniform(
av_decoder=av_decoder,
num_clips=10,
clip_duration_seconds=0.9954783485892385,
request_audio=True,
)
audio_sps = av_decoder.get_audio_samples_per_second()
print(f"SPS: {audio_sps}")
for audio_tensor in av_data.audio_clips:
print(audio_tensor.shape)
assert all(
audio_tensor.shape == av_data.audio_clips[0].shape
for audio_tensor in av_data.audio_clips
), "Audio clips have different shapes"
def test_wav_decode_against_soundfile(self):
"""Test decoding a WAV file against the soundfile library."""
try:
import soundfile
except ImportError:
self.skipTest("soundfile library not found")
with open("tests/data/test_audio.wav", "rb") as f:
raw_bytes = f.read()
stream = io.BytesIO(raw_bytes)
av_decoder = AVDecoder(stream)
av_data = av_decoder.get_clips(audio_clip_ranges=[(0, float("inf"))], audio_unit="samples")
audio_tensor = av_data.audio_clips[0]
# Load the same audio file using soundfile
audio_data, _ = soundfile.read("tests/data/test_audio.wav", dtype="int16")
audio_tensor_soundfile = torch.from_numpy(audio_data).transpose(0, 1)
# Check that the two tensors are close
assert tensors_close(audio_tensor, audio_tensor_soundfile, tolerance=0.01), (
"Energon decoded audio does not match baseline"
)
# Now check partial extraction in the middle of the audio
av_data = av_decoder.get_clips(audio_clip_ranges=[(0.5, 1.0)], audio_unit="seconds")
audio_tensor = av_data.audio_clips[0]
audio_sps = av_decoder.get_audio_samples_per_second()
audio_tensor_soundfile = torch.from_numpy(
audio_data[int(0.5 * audio_sps) : int(1.0 * audio_sps)]
).transpose(0, 1)
# Check that the two tensors are close
assert tensors_close(audio_tensor, audio_tensor_soundfile, tolerance=0.01), (
"Energon decoded audio does not match baseline"
)
# Now compare the speed of the two implementations by repeatedly decoding the same audio
num_trials = 100
start_time = time.perf_counter()
for _ in range(num_trials):
av_data = av_decoder.get_clips(
audio_clip_ranges=[(0, float("inf"))], audio_unit="samples"
)
audio_tensor = av_data.audio_clips[0]
end_time = time.perf_counter()
print(f"AVDecoder time: {end_time - start_time} seconds")
# Now do the same with soundfile
start_time = time.perf_counter()
for _ in range(num_trials):
audio_data, _ = soundfile.read("tests/data/test_audio.wav", dtype="int16")
audio_tensor_soundfile = torch.from_numpy(audio_data).transpose(0, 1)
end_time = time.perf_counter()
print(f"Soundfile time: {end_time - start_time} seconds")
start_time = time.perf_counter()
for _ in range(num_trials):
av_data = av_decoder.get_clips(
audio_clip_ranges=[(0, float("inf"))], audio_unit="samples"
)
audio_tensor = av_data.audio_clips[0]
end_time = time.perf_counter()
print(f"AVDecoder time: {end_time - start_time} seconds")
# Now do the same with soundfile
start_time = time.perf_counter()
for _ in range(num_trials):
audio_data, _ = soundfile.read("tests/data/test_audio.wav", dtype="int16")
audio_tensor_soundfile = torch.from_numpy(audio_data).transpose(0, 1)
end_time = time.perf_counter()
print(f"Soundfile time: {end_time - start_time} seconds")
def test_decode_metadata(self):
"""Test decoding metadata."""
expected_metadata = [
AVMetadata(
audio_duration=10.0,
audio_channels=1,
audio_sample_rate=32000,
),
AVMetadata(
audio_duration=12.782585034013605,
audio_channels=2,
audio_sample_rate=44100,
),
]
for audio_file, expected_metadata in zip(
["tests/data/test_audio.flac", "tests/data/test_audio.wav"], expected_metadata
):
av_decoder = AVDecoder(io.BytesIO(Path(audio_file).read_bytes()))
assert av_decoder.get_metadata() == expected_metadata, (
f"Metadata does not match expected metadata for {audio_file}: {av_decoder.get_metadata()}"
)
assert av_decoder.get_audio_duration() == expected_metadata.audio_duration
assert av_decoder.get_audio_samples_per_second() == expected_metadata.audio_sample_rate
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
"""This module defines tests for crude datasets."""
import gc
import logging
import pickle
import re
import sys
import tempfile
import unittest
import warnings
from pathlib import Path
from typing import List
import torch
import webdataset as wds
from megatron.energon import (
Batch,
Cooker,
DefaultTaskEncoder,
TextSample,
WorkerConfig,
basic_sample_keys,
get_savable_loader,
get_train_dataset,
stateless,
)
from megatron.energon.cache import FileStore, FileStoreCachePool, Lazy
from megatron.energon.cache.base import CachePool
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib.epath import EPath
from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME
from megatron.energon.flavors.webdataset.sample_decoder import SampleDecoder
from megatron.energon.flavors.webdataset.structs import reraise_exception
from megatron.energon.source_info import SourceInfo
from megatron.energon.task_encoder.cooking import cooker
@edataclass
class LazyTextSample(Sample):
txt: str
next_txt: Lazy[str]
# Speed up tests significantly by reducing the torch status check interval for broken worker shutdown
try:
torch.utils.data._utils.worker.MP_STATUS_CHECK_INTERVAL = 0.1
torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL = 0.1
except AttributeError:
pass
@edataclass
class TextBatch(Batch):
txts: List[str]
@stateless
def cook_text(sample: dict) -> TextSample:
return TextSample(
**basic_sample_keys(sample),
text=f"<{sample['txt']}>",
)
@stateless
def cook_other(sample: dict) -> TextSample:
d = pickle.loads(sample["pkl"])
return TextSample(
**basic_sample_keys(sample),
text=f"<{sample['txt']}|{d['idx']}>",
)
@stateless
def cook_aux(sample: dict, pkl_source: FileStore, fs_source: FileStore) -> TextSample:
# ds2 is offset by 100
d = pkl_source.get(f"{int(sample['txt']) + 100:06d}.txt", sample)
return TextSample(
**basic_sample_keys(sample),
text=f"<{sample['txt']}|aux|{d}>",
)
class CookingTaskEncoder(DefaultTaskEncoder[TextSample, TextSample, TextBatch, TextBatch]):
"""A simple task encoder for captioning."""
cookers = [
Cooker(cook_text, has_subflavors={"crude_type": "txtpkl"}),
Cooker(cook_other, has_subflavors={"crude_type": "otherpkl"}),
Cooker(cook_aux, has_subflavors={"crude_type": "aux_random_access"}),
]
def batch(self, samples: List[TextSample]) -> TextBatch:
return TextBatch.from_samples(
samples,
txts=[sample.text for sample in samples],
)
def select_samples_to_pack(self, samples):
return [[sample] for sample in samples]
@stateless
def pack_selected_samples(self, samples):
return samples[0]
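# How cooking is routed in these tests: each crude sample carries the subflavors configured
# for its dataset in the metadataset YAML written in setUp() below, and the Cooker whose
# has_subflavors matches is applied. E.g. ds1 samples ({"crude_type": "txtpkl"}) go through
# cook_text, ds2 samples ({"crude_type": "otherpkl"}) through cook_other.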
@stateless
def cook_aux_filesystem_reference(
sample: dict, pkl_source: FileStore, fs_source: FileStore
) -> TextSample:
d = fs_source.get("aux_metadataset.yaml", sample)[:25].decode()
return TextSample(
**basic_sample_keys(sample),
text=f"<{sample['txt']}|aux|{d}>",
)
class CookingTaskEncoderWithAuxFilesystemReference(CookingTaskEncoder):
cookers = [
Cooker(cook_aux_filesystem_reference, has_subflavors={"crude_type": "aux_random_access"}),
]
@stateless
@cooker(need_cache=True, need_primary=True)
def cook_aux_primary_cache(
sample: dict, primary: FileStore, pkl_source: FileStore, fs_source: FileStore, cache: CachePool
) -> LazyTextSample:
# ds2 is offset by 100
d = pkl_source.get(f"{int(sample['txt']) + 100:06d}.txt", sample)
my_lazy_next_txt = cache.get_lazy(primary, f"{(int(sample['txt']) + 1) % 55:06d}.txt")
return LazyTextSample(
**basic_sample_keys(sample),
txt=f"<{sample['txt']}|aux|{d}>",
next_txt=my_lazy_next_txt,
)
class LazyCookingTaskEncoder(
DefaultTaskEncoder[LazyTextSample, LazyTextSample, TextBatch, TextBatch]
):
# Classvar is fine here.
decoder = SampleDecoder(image_decode="pilrgb")
cookers = [
Cooker(cook_aux_primary_cache, has_subflavors={"crude_type": "aux_random_access"}),
]
def select_samples_to_pack(self, samples: List[LazyTextSample]) -> List[List[LazyTextSample]]:
return [[sample] for sample in samples]
@stateless
def pack_selected_samples(self, samples: List[LazyTextSample]) -> TextSample:
assert len(samples) == 1, f"Expected 1 sample, got {len(samples)}"
next_txt = samples[0].next_txt.get(samples[0])
return TextSample.derive_from(
samples[0],
text=samples[0].txt + "|" + next_txt,
)
def batch(self, samples: List[TextSample]) -> TextBatch:
return TextBatch.from_samples(
samples,
txts=[sample.text for sample in samples],
)
class LazyCookingTaskEncoderWithPostencode(
DefaultTaskEncoder[LazyTextSample, LazyTextSample, TextBatch, TextBatch]
):
# Classvar is fine here.
decoder = SampleDecoder(image_decode="pilrgb")
cookers = [
Cooker(cook_aux_primary_cache, has_subflavors={"crude_type": "aux_random_access"}),
]
@stateless
def postencode_sample(self, sample: LazyTextSample) -> TextSample:
assert isinstance(sample, LazyTextSample)
return TextSample.derive_from(
sample,
text=sample.txt + "|" + sample.next_txt.get(sample),
)
def select_samples_to_pack(self, samples: List[LazyTextSample]) -> List[List[LazyTextSample]]:
return [[sample] for sample in samples]
@stateless
def pack_selected_samples(self, samples: List[TextSample]) -> TextSample:
assert len(samples) == 1
return samples[0]
def batch(self, samples: List[TextSample]) -> TextBatch:
return TextBatch.from_samples(
samples,
txts=[sample.text for sample in samples],
)
class GenericCookingTaskEncoder(DefaultTaskEncoder[TextSample, TextSample, TextBatch, TextBatch]):
"""A simple task encoder for captioning."""
cookers = [Cooker(cook_text)]
def batch(self, samples: List[TextSample]) -> TextBatch:
return TextBatch.from_samples(
samples,
txts=[sample.text for sample in samples],
)
class TestDataset(unittest.TestCase):
# Set up the test fixture
def setUp(self):
logging.basicConfig(stream=sys.stderr, level=logging.INFO)
warnings.simplefilter("ignore", ResourceWarning)
# Create a temporary directory
self.temp_dir = tempfile.TemporaryDirectory()
self.dataset_path = Path(self.temp_dir.name)
# self.dataset_path = Path("./test_dataset")
self.dataset_path.mkdir(exist_ok=True, parents=True)
(self.dataset_path / "ds1").mkdir(exist_ok=True, parents=True)
(self.dataset_path / "ds2").mkdir(exist_ok=True, parents=True)
# Create a small dummy captioning dataset
self.create_crude_text_test_dataset(self.dataset_path / "ds1", 0)
self.create_crude_text_test_dataset(self.dataset_path / "ds2", 100)
self.mds_path = self.dataset_path / "metadataset.yaml"
with open(self.mds_path, "w") as f:
f.write(
"\n".join(
[
"__module__: megatron.energon",
"__class__: Metadataset",
"splits:",
" train:",
" datasets:",
" - weight: 1",
" path: ds1",
" subflavors:",
" source: metadataset.yaml",
" number: 43",
" mds: mds",
" crude_type: txtpkl",
" shuffle_over_epochs_multiplier: 3",
" - weight: 1",
" path: ds2",
" subflavors:",
" source: metadataset.yaml",
" number: 44",
" mds: mds",
" crude_type: otherpkl",
" val:",
" datasets:",
" - weight: 1",
" path: ds1",
" split_part: train",
" - weight: 1",
" path: ds2",
" split_part: train",
]
)
)
self.aux_mds_path = self.dataset_path / "aux_metadataset.yaml"
with open(self.aux_mds_path, "w") as f:
f.write(
"\n".join(
[
"__module__: megatron.energon",
"__class__: MetadatasetV2",
"splits:",
" train:",
" path: ds1",
" aux:",
" pkl_source: ds2",
" fs_source: filesystem://.",
" subflavors:",
" crude_type: aux_random_access",
]
)
)
print(self.dataset_path)
def tearDown(self):
# Remove all temporary files
gc.collect()
self.temp_dir.cleanup()
@staticmethod
def create_crude_text_test_dataset(path: Path, offset: int):
"""Creates a small dummy test dataset for testing purposes."""
# Create 55 crude samples, each with a txt part and a pkl part
(path / "parts").mkdir(exist_ok=True, parents=True)
# Initialize the ShardWriter
with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=10) as shard_writer:
for idx in range(55):
# Write individual files to shards
shard_writer.write(
{
"__key__": f"{idx + offset:06d}",
"txt": f"{idx + offset}".encode(),
"pkl": pickle.dumps({"idx": idx + offset}),
},
)
total_shards = shard_writer.shard
from megatron.energon.flavors import BaseWebdatasetFactory
BaseWebdatasetFactory.prepare_dataset(
path,
[f"parts/data-{{0..{total_shards - 1}}}.tar"],
split_parts_ratio=[("train", 1.0)],
shuffle_seed=None,
workers=1,
)
with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f:
f.write(
"\n".join(
[
"__module__: megatron.energon",
"__class__: CrudeWebdataset",
"subflavors:",
" dataset.yaml: true",
" number: 42",
]
)
)
def test_metadataset(self):
torch.manual_seed(42)
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=0,
)
# Train mode dataset
torch.manual_seed(42)
train_dataset = get_train_dataset(
self.mds_path,
worker_config=worker_config,
batch_size=3,
task_encoder=CookingTaskEncoder(),
shuffle_buffer_size=None,
max_samples_per_sequence=None,
handler=reraise_exception,
)
loader = get_savable_loader(
train_dataset,
)
print(len(train_dataset))
# assert len(train_dataset) == 11
for idx, data in enumerate(loader):
if idx >= len(train_dataset):
break
assert isinstance(data, TextBatch)
print("Batch", idx)
for txt, key in zip(data.txts, data.__key__):
key_int = int(key.split("/")[-1])
if key_int < 100:
assert txt == f"<{key_int}>"
else:
assert txt == f"<{key_int}|{key_int}>"
print(key, txt)
def test_loader(self):
torch.manual_seed(42)
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=2,
)
loader = get_savable_loader(
get_train_dataset(
self.mds_path,
batch_size=2,
worker_config=worker_config,
task_encoder=CookingTaskEncoder(),
shuffle_buffer_size=None,
max_samples_per_sequence=None,
packing_buffer_size=2,
),
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=1,
)
samples = [s.__key__ for idx, s in zip(range(100), loader)]
print(samples)
state = loader.save_state_rank()
samples_after = [s.__key__ for idx, s in zip(range(100, 200), loader)]
print(samples_after)
loader = get_savable_loader(
get_train_dataset(
self.mds_path,
batch_size=2,
worker_config=worker_config,
task_encoder=CookingTaskEncoder(),
shuffle_buffer_size=None,
max_samples_per_sequence=None,
packing_buffer_size=2,
),
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=1,
)
loader.restore_state_rank(state)
samples_restored = [s.__key__ for idx, s in zip(range(100, 200), loader)]
print(samples_restored)
assert all([a == b for a, b in zip(samples_after, samples_restored)])
def test_aux_random_access(self):
torch.manual_seed(42)
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=2,
)
print("Initializing dataset")
loader = get_savable_loader(
get_train_dataset(
self.aux_mds_path,
batch_size=2,
worker_config=worker_config,
task_encoder=CookingTaskEncoder(),
shuffle_buffer_size=None,
max_samples_per_sequence=None,
packing_buffer_size=2,
),
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=1,
)
print("Iterating from dataset")
samples = [s.txts for idx, s in zip(range(100), loader)]
for idx, txts in enumerate(samples):
for txt in txts:
m = re.fullmatch(r"<([0-9]*)\|aux\|([0-9]*)>", txt)
assert m, f"Invalid aux text: {txt}"
assert int(m.group(2)) == int(m.group(1)) + 100
print(samples)
state = loader.save_state_rank()
samples_after = [s.__key__ for idx, s in zip(range(100, 200), loader)]
print(samples_after)
loader = get_savable_loader(
get_train_dataset(
self.aux_mds_path,
batch_size=2,
worker_config=worker_config,
task_encoder=CookingTaskEncoder(),
shuffle_buffer_size=None,
max_samples_per_sequence=None,
packing_buffer_size=2,
),
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=1,
)
loader.restore_state_rank(state)
samples_restored = [s.__key__ for idx, s in zip(range(100, 200), loader)]
print(samples_restored)
assert all([a == b for a, b in zip(samples_after, samples_restored)])
def test_aux_random_access_with_cache(self):
torch.manual_seed(42)
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=2,
)
print("Initializing dataset")
loader = get_savable_loader(
get_train_dataset(
self.aux_mds_path,
batch_size=2,
worker_config=worker_config,
task_encoder=LazyCookingTaskEncoder(),
shuffle_buffer_size=None,
max_samples_per_sequence=None,
packing_buffer_size=2,
),
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=1,
cache_pool=FileStoreCachePool(
parent_cache_dir=self.dataset_path / "cache",
num_workers=1,
),
)
print("Iterating from dataset")
samples = [s.txts for idx, s in zip(range(100), loader)]
for idx, txts in enumerate(samples):
for txt in txts:
m = re.fullmatch(r"<([0-9]*)\|aux\|([0-9]*)>\|([0-9]*)", txt)
assert m, f"Invalid aux text: {txt}"
assert int(m.group(2)) == int(m.group(1)) + 100
assert int(m.group(3)) == (int(m.group(1)) + 1) % 55
print(samples)
state = loader.save_state_rank()
samples_after = [s.__key__ for idx, s in zip(range(100, 200), loader)]
print(samples_after)
loader = get_savable_loader(
get_train_dataset(
self.aux_mds_path,
batch_size=2,
worker_config=worker_config,
task_encoder=CookingTaskEncoder(),
shuffle_buffer_size=None,
max_samples_per_sequence=None,
packing_buffer_size=2,
),
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=1,
cache_pool=FileStoreCachePool(
parent_cache_dir=self.dataset_path / "cache",
num_workers=1,
),
)
loader.restore_state_rank(state)
samples_restored = [s.__key__ for idx, s in zip(range(100, 200), loader)]
print(samples_restored)
assert all([a == b for a, b in zip(samples_after, samples_restored)])
def test_aux_random_access_with_cache_and_postencode(self):
torch.manual_seed(42)
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=2,
)
print("Initializing dataset")
loader = get_savable_loader(
get_train_dataset(
self.aux_mds_path,
batch_size=2,
worker_config=worker_config,
task_encoder=LazyCookingTaskEncoderWithPostencode(),
shuffle_buffer_size=None,
max_samples_per_sequence=None,
packing_buffer_size=2,
),
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=1,
cache_pool=FileStoreCachePool(
parent_cache_dir=self.dataset_path / "cache",
num_workers=1,
),
)
print("Iterating from dataset")
samples = [s.txts for idx, s in zip(range(100), loader)]
for idx, txts in enumerate(samples):
for txt in txts:
m = re.fullmatch(r"<([0-9]*)\|aux\|([0-9]*)>\|([0-9]*)", txt)
assert m, f"Invalid aux text: {txt}"
assert int(m.group(2)) == int(m.group(1)) + 100
assert int(m.group(3)) == (int(m.group(1)) + 1) % 55
print(samples)
state = loader.save_state_rank()
samples_after = [s.__key__ for idx, s in zip(range(100, 200), loader)]
print(samples_after)
loader = get_savable_loader(
get_train_dataset(
self.aux_mds_path,
batch_size=2,
worker_config=worker_config,
task_encoder=LazyCookingTaskEncoderWithPostencode(),
shuffle_buffer_size=None,
max_samples_per_sequence=None,
packing_buffer_size=2,
),
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=1,
cache_pool=FileStoreCachePool(
parent_cache_dir=self.dataset_path / "cache",
num_workers=1,
),
)
loader.restore_state_rank(state)
samples_restored = [s.__key__ for idx, s in zip(range(100, 200), loader)]
print(samples_restored)
assert all([a == b for a, b in zip(samples_after, samples_restored)])
# Verify that the sources are correct
sample_src_check = [s.__sources__ for idx, s in zip(range(1), loader)][0]
print(sample_src_check)
# NOTE: Auxiliary sources use a string as the index, not an int
assert sample_src_check == (
# Primary source for the sample, reading all source files
SourceInfo(
dataset_path=EPath(self.dataset_path / "ds1"),
index=2,
shard_name="parts/data-0.tar",
file_names=("000002.pkl", "000002.txt"),
),
# Auxiliary source for the sample, reading from ds2
SourceInfo(
dataset_path=EPath(self.dataset_path / "ds2"),
index="000102.txt",
shard_name="parts/data-0.tar",
file_names=("000102.txt",),
),
# Auxiliary source for the sample, reading from ds1, but next sample
SourceInfo(
dataset_path=EPath(self.dataset_path / "ds1"),
index="000003.txt",
shard_name="parts/data-0.tar",
file_names=("000003.txt",),
),
SourceInfo(
dataset_path=EPath(self.dataset_path / "ds1"),
index=21,
shard_name="parts/data-2.tar",
file_names=("000021.pkl", "000021.txt"),
),
SourceInfo(
dataset_path=EPath(self.dataset_path / "ds2"),
index="000121.txt",
shard_name="parts/data-2.tar",
file_names=("000121.txt",),
),
SourceInfo(
dataset_path=EPath(self.dataset_path / "ds1"),
index="000022.txt",
shard_name="parts/data-2.tar",
file_names=("000022.txt",),
),
)
def test_aux_filesystem_reference(self):
torch.manual_seed(42)
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=0,
)
loader = get_savable_loader(
get_train_dataset(
self.aux_mds_path,
batch_size=1,
worker_config=worker_config,
task_encoder=CookingTaskEncoderWithAuxFilesystemReference(),
shuffle_buffer_size=None,
max_samples_per_sequence=None,
),
)
sample = next(iter(loader))
assert sample.txts[0].endswith("|aux|__module__: megatron.ener>")
def test_nomds(self):
torch.manual_seed(42)
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=2,
)
loader = get_savable_loader(
get_train_dataset(
self.dataset_path / "ds1",
batch_size=2,
worker_config=worker_config,
task_encoder=GenericCookingTaskEncoder(),
shuffle_buffer_size=None,
max_samples_per_sequence=None,
),
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=1,
)
samples = [s.__key__ for idx, s in zip(range(100), loader)]
print(samples)
assert len(samples) == 100
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
"""This module defines tests for the dataset."""
import dataclasses
import gc
import io
import json
import logging
import math
import random
import sys
import tempfile
import unittest
import warnings
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Hashable, List, Tuple, Type, Union
import numpy as np
import torch
import webdataset as wds
from click.testing import CliRunner
from PIL import Image
from megatron.energon import (
AugmentTaskEncoder,
Batch,
BatchDataset,
BlendDataset,
CaptioningSample,
DefaultTaskEncoder,
MapDataset,
MixBatchDataset,
Sample,
SavableDataLoader,
TaskEncoder,
WorkerConfig,
generic_batch,
get_loader,
get_savable_loader,
get_train_dataset,
get_val_dataset,
homogeneous_concat_mix,
)
from megatron.energon.dataset_config import get_dataset_from_config
from megatron.energon.edataclass import edataclass
from megatron.energon.flavors import BaseWebdatasetFactory
from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME
from megatron.energon.task_encoder.base import stateless
from megatron.energon.tools.analyze_debug import command as analyze_debug_command
from megatron.energon.tools.info import command as info_command
from megatron.energon.tools.lint import command as lint_command
from megatron.energon.tools.prepare import command as prepare_command
from megatron.energon.tools.preview import command as preview_command
# Speed up tests significantly by reducing the torch status check interval for broken worker shutdown
try:
torch.utils.data._utils.worker.MP_STATUS_CHECK_INTERVAL = 0.1
torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL = 0.1
except AttributeError:
pass
DATASET_SIZE = 50
no_worker_config = WorkerConfig(rank=0, world_size=1, num_workers=0)
@edataclass
class ExtendedCaptioningSample(CaptioningSample):
batch_index: int
sample_index: int
rand_num: int
@edataclass
class EncodedCaptioningSample(Sample):
image: torch.Tensor
caption: torch.Tensor
@edataclass
class CaptioningEncodedBatch(CaptioningSample):
pass
@edataclass
class CaptioningBatch(Batch):
image: torch.Tensor
caption: torch.Tensor
class TestDataset(unittest.TestCase):
# Set up the test fixture
def setUp(self):
logging.basicConfig(stream=sys.stderr, level=logging.INFO)
warnings.simplefilter("ignore", ResourceWarning)
# Create a temporary directory
self.temp_dir = tempfile.TemporaryDirectory()
self.dataset_path = Path(self.temp_dir.name)
# self.dataset_path = Path("./test_dataset")
self.dataset_path.mkdir(exist_ok=True, parents=True)
# Create a small dummy captioning dataset
self.samples = self.create_captioning_test_dataset(self.dataset_path, DATASET_SIZE)
print(self.dataset_path)
def tearDown(self):
# Remove all temporary files
gc.collect()
self.temp_dir.cleanup()
@staticmethod
def create_captioning_test_dataset(path: Union[str, Path], num_samples: int = 50):
"""Creates a small dummy captioning dataset for testing purposes."""
path = Path(path)
animals = (
"ant bee beetle bug bumblebee butterfly caterpillar cicada cricket dragonfly earwig "
"firefly grasshopper honeybee hornet inchworm ladybug locust mantis mayfly mosquito "
"moth sawfly silkworm termite wasp woodlouse"
).split()
adjectives = (
"adorable affable amazing amiable attractive beautiful calm charming cherubic classic "
"classy convivial cordial cuddly curly cute debonair elegant famous fresh friendly "
"funny gorgeous graceful gregarious grinning handsome hilarious hot interesting kind "
"laughing lovely meek mellow merciful neat nifty notorious poetic pretty refined "
"refreshing sexy smiling sociable spiffy stylish sweet tactful whimsical"
).split()
# Set random seeds for numpy and torch
np.random.seed(42)
torch.manual_seed(42)
entries = []
assert num_samples < len(animals) * len(adjectives), (
"Cannot generate more samples than unique captions."
)
# Create num_samples unique captions
captions = set()
while len(captions) < num_samples:
# Create random description by sampling from adjectives and animals
adjective = np.random.choice(adjectives)
prefix = "An" if adjective[0] in "aeiou" else "A"
description = f"{prefix} {adjective} {np.random.choice(animals)}."
captions.add(description)
(path / "parts").mkdir(exist_ok=True, parents=True)
# Initialize the ShardWriter
with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=30) as shard_writer:
for idx in range(num_samples):
# Create a dummy image with random noise and save to disk
img_buf = io.BytesIO()
randimg = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
image = Image.fromarray(randimg)
image.save(img_buf, format="PNG")
img_bytes = img_buf.getvalue()
description = captions.pop()
entries.append({"image": randimg, "caption": description})
# Write individual files to shards
shard_writer.write(
{
"__key__": f"{idx:06d}",
"png": img_bytes,
"txt": description.encode("utf-8"),
"json": json.dumps({"caption": description}),
},
)
total_shards = shard_writer.shard
BaseWebdatasetFactory.prepare_dataset(
path,
[f"parts/data-{{0..{total_shards - 1}}}.tar"],
split_parts_ratio=[("train", 1.0)],
)
with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f:
f.write(
"\n".join(
[
"sample_type:",
" __module__: megatron.energon",
" __class__: CaptioningSample",
"field_map:",
" image: png",
" caption: txt",
]
)
)
with open(path / MAIN_FOLDER_NAME / "dataset_field.yaml", "w") as f:
f.write(
"\n".join(
[
"sample_type:",
" __module__: megatron.energon",
" __class__: CaptioningSample",
"field_map:",
" image: png",
" caption: json[caption]",
]
)
)
with open(path / MAIN_FOLDER_NAME / "dataset_sample_loader.yaml", "w") as f:
f.write(
"\n".join(
[
"sample_type:",
" __module__: megatron.energon",
" __class__: CaptioningSample",
"sample_loader: sample_loader.py:sample_loader",
"part_filter: sample_loader.py:part_filter",
]
)
)
with open(path / MAIN_FOLDER_NAME / "dataset_sample_loader_key.yaml", "w") as f:
f.write(
"\n".join(
[
"sample_type:",
" __module__: megatron.energon",
" __class__: CaptioningSample",
"sample_loader: sample_loader.py:sample_loader_key",
"part_filter: sample_loader.py:part_filter",
]
)
)
with open(path / MAIN_FOLDER_NAME / "sample_loader.py", "w") as f:
f.write(
"\n".join(
[
"def sample_loader(raw: dict) -> dict:",
" assert 'txt' not in raw",
" return dict(",
' image=raw["png"],',
' caption="<SL>" + raw["json"]["caption"],',
" )",
"",
"def sample_loader_key(raw: dict) -> dict:",
" assert 'txt' not in raw",
" return dict(",
' __key__="<SL>" + raw["__key__"],',
' image=raw["png"],',
' caption="<SL>" + raw["json"]["caption"],',
" )",
"",
"def part_filter(part: str) -> bool:",
' return part in ["json", "png"]',
"",
]
)
)
with open(path / MAIN_FOLDER_NAME / "dataset_exclude.yaml", "w") as f:
f.write(
"\n".join(
[
"sample_type:",
" __module__: megatron.energon",
" __class__: CaptioningSample",
"field_map:",
" image: png",
" caption: txt",
"split_config: split2.yaml",
]
)
)
with open(path / MAIN_FOLDER_NAME / "split2.yaml", "w") as f:
with open(path / MAIN_FOLDER_NAME / "split.yaml", "r") as rf:
origsplit = rf.read()
f.write(
origsplit
+ "\n"
+ "\n".join(
[
"exclude:",
" - parts/data-0.tar",
" - parts/data-1.tar/00003{5..9}",
]
)
)
return entries
def test_captioning_dataset(self):
ds = get_dataset_from_config(
self.dataset_path,
split_part="train",
worker_config=no_worker_config,
training=False,
sample_type=CaptioningSample,
)
ds = MapDataset(
ds.build(),
lambda x: CaptioningSample(
__key__=x.__key__,
__restore_key__=x.__restore_key__,
__subflavors__=x.__subflavors__,
image=x.image,
caption=torch.tensor(np.frombuffer(x.caption.encode(), dtype=np.uint8)),
),
worker_config=no_worker_config,
)
def get_ld(ds):
return get_loader(ds)
# Check len operator
assert len(ds) == 50
# Check if iterating returns the same
iter1 = list(get_ld(ds))
iter2 = list(get_ld(ds))
assert len(iter1) == 50
assert len(iter2) == 50
assert all(elem1.__key__ == elem2.__key__ for elem1, elem2 in zip(iter1, iter2))
# Check case when batch size is larger than dataset size
batch_sizes = []
for wrapped_sample in get_ld(
BatchDataset(
ds,
batch_size=DATASET_SIZE * 2,
batcher=generic_batch,
worker_config=no_worker_config,
)
):
batch_sizes.append(wrapped_sample.image.shape[0])
assert batch_sizes == [DATASET_SIZE]
# Check returned dimensions and batch sizes if batch size is smaller than dataset size
batch_size = 4
assert batch_size < DATASET_SIZE
batched_ds = BatchDataset(
ds, batch_size=batch_size, batcher=generic_batch, worker_config=no_worker_config
)
cnt = 0
expected_num_batches = math.ceil(DATASET_SIZE / batch_size)
for idx, wrapped_sample in enumerate(get_ld(batched_ds)):
# Check batch sizes
if idx < expected_num_batches - 1:
assert wrapped_sample.image.shape[0] == batch_size
assert wrapped_sample.caption.shape[0] == batch_size
else:
assert wrapped_sample.image.shape[0] == DATASET_SIZE % batch_size
assert wrapped_sample.caption.shape[0] == DATASET_SIZE % batch_size
# Check image size
assert tuple(wrapped_sample.image.shape[1:]) == (3, 100, 100)
cnt += 1
logging.info(f" Batch {idx}:")
logging.info(f" {wrapped_sample.image.shape=}")
logging.info(f" {wrapped_sample.caption.shape=}")
assert cnt == expected_num_batches
# Check if actual image and caption data are correct
loader = get_ld(
BatchDataset(ds, batch_size=9, batcher=generic_batch, worker_config=no_worker_config),
)
batch_sizes = []
dataset_samples = {sample["caption"]: sample["image"] for sample in self.samples}
for idx, sample in enumerate(loader):
batch_sizes.append(sample.image.shape[0])
for bidx in range(sample.image.shape[0]):
refimg = dataset_samples.pop(
sample.caption[bidx].numpy().tobytes().rstrip(b"\0").decode()
)
assert torch.allclose(
sample.image[bidx],
torch.permute(torch.tensor(refimg, dtype=torch.float32) / 255, (2, 0, 1)),
)
assert len(dataset_samples) == 0
assert batch_sizes == [9, 9, 9, 9, 9, 5]
def test_field_access(self):
ds = get_dataset_from_config(
self.dataset_path,
dataset_config="dataset_field.yaml",
split_part="train",
worker_config=no_worker_config,
training=False,
sample_type=CaptioningSample,
)
captions = set(sample["caption"] for sample in self.samples)
for sample in get_loader(ds.build()):
captions.remove(sample.caption)
assert len(captions) == 0
def test_sample_loader(self):
ds = get_dataset_from_config(
self.dataset_path,
dataset_config="dataset_sample_loader.yaml",
split_part="train",
worker_config=no_worker_config,
training=False,
sample_type=CaptioningSample,
)
captions = set(sample["caption"] for sample in self.samples)
for sample in get_loader(ds.build()):
assert sample.caption[:4] == "<SL>"
captions.remove(sample.caption[4:])
assert len(captions) == 0
def test_sample_loader_key(self):
ds = get_dataset_from_config(
self.dataset_path,
dataset_config="dataset_sample_loader_key.yaml",
split_part="train",
worker_config=no_worker_config,
training=False,
sample_type=CaptioningSample,
)
captions = set(sample["caption"] for sample in self.samples)
keys = set(
f"<SL>parts/data-{idx // 30:d}.tar/{idx:06d}" for idx in range(len(self.samples))
)
for sample in get_loader(ds.build()):
assert sample.caption[:4] == "<SL>"
captions.remove(sample.caption[4:])
keys.remove(sample.__key__)
assert len(captions) == 0
assert len(keys) == 0
def test_exclusion(self):
ds = get_dataset_from_config(
self.dataset_path,
dataset_config="dataset_exclude.yaml",
split_part="train",
worker_config=no_worker_config,
training=False,
sample_type=CaptioningSample,
)
keys = [entry.__key__ for entry in get_loader(ds.build())]
assert keys == [
f"parts/data-1.tar/{i:06d}" for i in list(range(30, 35)) + list(range(40, 50))
], keys
def test_loader(self):
torch.manual_seed(42)
class TestTaskEncoder(DefaultTaskEncoder):
def __init__(self):
super().__init__(raw_batch_type=CaptioningBatch)
def encode_sample(self, sample: CaptioningSample) -> EncodedCaptioningSample:
return EncodedCaptioningSample.derive_from(
sample,
image=sample.image,
caption=torch.frombuffer(bytearray(sample.caption.encode()), dtype=torch.uint8),
)
loader = get_loader(
get_train_dataset(
self.dataset_path,
batch_size=10,
worker_config=no_worker_config,
parallel_shard_iters=2,
virtual_epoch_length=2,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
task_encoder=TestTaskEncoder(),
)
)
assert len(loader) == 2
def hist(data):
"""Histogram function"""
r = defaultdict(lambda: 0)
for k in data:
r[k] += 1
return r
print([[batch.__key__ for batch in loader] for _ in range(100)])
keys = [key for _ in range(100) for batch in loader for key in batch.__key__]
# 100 iterations, 2 virtual epoch size, batch size 10
print(len(keys), keys)
keyhist = hist(keys)
print(sorted(keyhist.items()))
print(sorted(keyhist.items(), key=lambda x: (x[1], x[0])))
assert len(keys) == 100 * 2 * 10
# Data should be approximately sampled uniformly (40+-1 samples per key)
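# (100 iterations x virtual epoch length 2 x batch size 10 = 2000 samples drawn over
# 50 unique keys, i.e. 40 per key on average)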
assert len(keyhist) == 50
assert all(v in (39, 40, 41) for v in keyhist.values())
loader2 = get_loader(
get_val_dataset(
self.dataset_path,
split_part="train",
batch_size=10,
worker_config=no_worker_config,
task_encoder=TestTaskEncoder(),
)
)
assert len(loader2) == 5
# The order in the split is shuffled this way
assert list(key for batch in loader2 for key in batch.__key__) == [
f"parts/data-1.tar/{i:06d}" for i in range(30, 50)
] + [f"parts/data-0.tar/{i:06d}" for i in range(30)]
def test_default_dataset(self):
torch.manual_seed(42)
train_loader = get_loader(
get_train_dataset(
self.dataset_path,
batch_size=10,
worker_config=no_worker_config,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
)
)
val_loader = get_loader(
get_val_dataset(
self.dataset_path,
split_part="train",
batch_size=10,
worker_config=no_worker_config,
)
)
n_samples = 0
for i, sample in zip(range(100), train_loader):
assert sample.image.shape == (10, 3, 100, 100)
n_samples += sample.image.shape[0]
assert n_samples == 1000
n_samples = 0
for sample in val_loader:
assert sample.image.shape == (10, 3, 100, 100)
n_samples += sample.image.shape[0]
assert n_samples == 50
def test_no_batching(self):
train_loader = get_loader(
get_train_dataset(
self.dataset_path,
batch_size=None,
worker_config=no_worker_config,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
)
)
one_sample = next(iter(train_loader))
# Single sample without batching
assert isinstance(one_sample.image, torch.Tensor)
assert isinstance(one_sample.caption, str)
def test_dataset_len(self):
torch.manual_seed(42)
worker_config = WorkerConfig(rank=0, world_size=1, num_workers=4)
train_dataset = get_train_dataset(
self.dataset_path,
batch_size=11,
worker_config=worker_config,
virtual_epoch_length=12,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
)
train_loader = get_loader(train_dataset)
assert len(train_dataset) == 12
assert len(train_loader) == 12
assert len(list(train_loader)) == 12
val_dataset = get_val_dataset(
self.dataset_path, split_part="train", batch_size=1, worker_config=no_worker_config
)
val_loader = get_loader(val_dataset)
assert len(val_loader) == 50
assert len(list(val_loader)) == 50
val_dataset = get_val_dataset(
self.dataset_path, split_part="train", batch_size=11, worker_config=worker_config
)
val_loader = get_loader(val_dataset)
# n batches: 50 samples split over 4 workers (13, 12, 13, 12), each worker yields ceil(n / 11) = 2 batches -> 8 total
assert len(val_dataset) == 8
assert len(val_loader) == 8
assert len(list(val_loader)) == 8
assert [len(entry.__key__) for entry in val_loader] == [11, 11, 11, 11, 2, 1, 2, 1]
assert sum(len(entry.__key__) for entry in val_loader) == 50
def test_multirank_dataset(self):
torch.manual_seed(42)
worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2)
worker_config_r1 = WorkerConfig(rank=1, world_size=2, num_workers=2)
train_dataset = get_train_dataset(
self.dataset_path,
batch_size=11,
worker_config=worker_config_r0,
virtual_epoch_length=12,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
)
train_loader = get_loader(train_dataset)
assert len(train_dataset) == 12
assert len(train_loader) == 12
assert len(list(train_loader)) == 12
val_dataset0 = get_val_dataset(
self.dataset_path, split_part="train", batch_size=1, worker_config=worker_config_r0
)
val_loader0 = get_loader(val_dataset0)
print(len(val_loader0))
assert len(val_loader0) == 25
keys0 = set(key for entry in val_loader0 for key in entry.__key__)
assert len(keys0) == 25
val_dataset0b11 = get_val_dataset(
self.dataset_path, split_part="train", batch_size=11, worker_config=worker_config_r0
)
val_loader0b11 = get_loader(val_dataset0b11)
assert len(val_dataset0b11) == 4
assert len(val_loader0b11) == 4
assert len(list(val_loader0b11)) == 4
keys0b11 = set(key for entry in val_loader0b11 for key in entry.__key__)
print([len(entry.__key__) for entry in val_loader0b11])
assert [len(entry.__key__) for entry in val_loader0b11] == [11, 11, 2, 1]
assert len(keys0b11) == 25
assert keys0b11 == keys0
val_dataset1 = get_val_dataset(
self.dataset_path, split_part="train", batch_size=1, worker_config=worker_config_r1
)
val_loader1 = get_loader(val_dataset1)
print(len(val_loader1))
assert len(val_loader1) == 25
keys1 = set(key for entry in val_loader1 for key in entry.__key__)
assert len(keys1) == 25
print(sorted(keys1))
print(sorted(keys0))
assert keys1.isdisjoint(keys0)
val_dataset1b11 = get_val_dataset(
self.dataset_path, split_part="train", batch_size=11, worker_config=worker_config_r1
)
val_loader1b11 = get_loader(val_dataset1b11)
assert len(val_dataset1b11) == 4
assert len(val_loader1b11) == 4
assert len(list(val_loader1b11)) == 4
keys1b11 = set(key for entry in val_loader1b11 for key in entry.__key__)
print([len(entry.__key__) for entry in val_loader1b11])
assert [len(entry.__key__) for entry in val_loader1b11] == [11, 11, 2, 1]
assert len(keys1b11) == 25
assert keys1b11.isdisjoint(keys0b11)
assert keys1b11 == keys1
def test_weight_aug(self):
class WeightAugmentTaskEncoder(AugmentTaskEncoder):
def __init__(self, task_encoder: TaskEncoder, weight: float, target_data_class: type):
super().__init__(task_encoder)
self.weight = weight
self.target_data_class = target_data_class
def encode_sample(self, sample):
sample = super().encode_sample(sample)
return self.target_data_class(**dataclasses.asdict(sample), weight=self.weight)
torch.manual_seed(42)
@edataclass
class WeightedCaptioningBatch(Batch):
image: torch.Tensor
caption: List[str]
weight: float
loader = get_loader(
get_val_dataset(
self.dataset_path,
split_part="train",
batch_size=10,
worker_config=no_worker_config,
task_encoder=WeightAugmentTaskEncoder(
DefaultTaskEncoder(),
weight=0.8,
target_data_class=WeightedCaptioningBatch,
),
)
)
for data in loader:
assert data.weight == [0.8] * 10
def test_blending(self):
torch.manual_seed(42)
loader = get_loader(
BlendDataset(
(
get_train_dataset(
self.dataset_path,
batch_size=10,
worker_config=no_worker_config,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
),
2,
),
(
get_train_dataset(
self.dataset_path,
batch_size=20,
worker_config=no_worker_config,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
),
8,
),
worker_config=no_worker_config,
)
)
bs_hist = {10: 0, 20: 0}
for i, sample in zip(range(1000), loader):
bs_hist[sample.image.shape[0]] += 1
print(bs_hist)
assert 150 <= bs_hist[10] <= 250
assert 750 <= bs_hist[20] <= 850
def test_mixing_homogeneous(self):
@dataclass
class TestBatch(Batch):
image: torch.Tensor
caption: List[str]
source: int
class TestTaskEncoder(TaskEncoder):
def __init__(self, source: int):
self.source = source
def encode_batch(self, batch):
return TestBatch(**dataclasses.asdict(batch), source=self.source)
loader = get_loader(
MixBatchDataset(
(
get_train_dataset(
self.dataset_path,
batch_size=1,
worker_config=no_worker_config,
task_encoder=TestTaskEncoder(source=0),
shuffle_buffer_size=None,
max_samples_per_sequence=None,
),
2,
),
(
get_train_dataset(
self.dataset_path,
batch_size=1,
worker_config=no_worker_config,
task_encoder=TestTaskEncoder(source=1),
shuffle_buffer_size=None,
max_samples_per_sequence=None,
),
8,
),
batch_size=10,
batch_mix_fn=homogeneous_concat_mix,
worker_config=no_worker_config,
)
)
source_hist = {0: 0, 1: 0}
for i, sample in zip(range(1000), loader):
assert sample.image.shape == (10, 3, 100, 100)
for source in sample.source:
source_hist[source] += 1
assert 1500 <= source_hist[0] <= 2500
assert 7500 <= source_hist[1] <= 8500
def test_mixing_heterogeneous(self):
@dataclass
class TestBatch1(Batch):
image: torch.Tensor
caption: List[str]
source: int
@dataclass
class TestBatch2(TestBatch1):
pass
class TestTaskEncoder(TaskEncoder):
def __init__(self, source: int, batch_cls: Type[TestBatch1]):
self.source = source
self.batch_cls = batch_cls
def encode_batch(self, batch):
return self.batch_cls(**dataclasses.asdict(batch), source=self.source)
loader = get_loader(
MixBatchDataset(
(
get_train_dataset(
self.dataset_path,
batch_size=1,
worker_config=no_worker_config,
task_encoder=TestTaskEncoder(source=0, batch_cls=TestBatch1),
shuffle_buffer_size=None,
max_samples_per_sequence=None,
),
2,
),
(
get_train_dataset(
self.dataset_path,
batch_size=1,
worker_config=no_worker_config,
task_encoder=TestTaskEncoder(source=1, batch_cls=TestBatch2),
shuffle_buffer_size=None,
max_samples_per_sequence=None,
),
8,
),
batch_size=10,
worker_config=no_worker_config,
)
)
source_hist = {0: 0, 1: 0}
for i, samples in zip(range(1000), loader):
assert len(samples) == 10
for sample in samples:
assert sample.image.shape == (1, 3, 100, 100)
source_hist[sample.source] += 1
assert 1500 <= source_hist[0] <= 2500
assert 7500 <= source_hist[1] <= 8500
def test_val_limit(self):
torch.manual_seed(42)
loader = get_loader(
get_val_dataset(
self.dataset_path,
split_part="train",
batch_size=2,
worker_config=no_worker_config,
limit=3,
)
)
assert len(loader) == 3
samples = [[batch.__key__ for batch in loader] for _ in range(10)]
print(samples)
assert all(samples[0] == one_ep_samples for one_ep_samples in samples)
worker_config = WorkerConfig(rank=0, world_size=1, num_workers=2)
loader = get_loader(
get_val_dataset(
self.dataset_path,
split_part="train",
batch_size=2,
worker_config=worker_config,
limit=3,
)
)
assert len(loader) == 3
samples_wrk2 = [[batch.__key__ for batch in loader] for _ in range(10)]
print(samples)
assert all(samples_wrk2[0] == one_ep_samples for one_ep_samples in samples_wrk2)
def test_current_batch_index(self):
# Tests whether get_current_batch_index works properly
torch.manual_seed(42)
class TestTaskEncoder(TaskEncoder):
@stateless(restore_seeds=True)
def encode_sample(self, sample):
# print("si stack:", WorkerConfig._sample_index_stack)
return ExtendedCaptioningSample.extend(
sample,
batch_index=self.current_batch_index,
sample_index=self.current_sample_index,
rand_num=random.randint(0, 1000),
)
# First, test a simple single main-thread loader accessing get_current_batch_index
loader = get_loader(
get_train_dataset(
self.dataset_path,
batch_size=2,
task_encoder=TestTaskEncoder(),
worker_config=no_worker_config,
shuffle_buffer_size=20,
max_samples_per_sequence=10,
)
)
batches = list(zip(range(20), loader))
print("bi", [batch.batch_index for batch_idx, batch in batches])
assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches)
print("si", [batch.sample_index for batch_idx, batch in batches])
assert all(
all(
si == sample_offset + batch_idx * 2
for sample_offset, si in enumerate(batch.sample_index)
)
for batch_idx, batch in batches
)
print("pk", [batch.__key__ for batch_idx, batch in batches])
print("rk", [batch.__restore_key__ for batch_idx, batch in batches])
assert loader.can_restore_sample()
# These need to be hard coded to detect breaking changes
# If a change is expected, update the values with the ones printed below
ref_batch_rand_nums = [
[661, 762],
[206, 470],
[130, 283],
[508, 61],
[625, 661],
[296, 376],
[632, 514],
[715, 406],
[555, 27],
[760, 36],
[607, 610],
[825, 219],
[564, 832],
[876, 512],
[632, 605],
[357, 738],
[40, 378],
[609, 444],
[610, 367],
[367, 69],
]
batch_rand_nums = []
for batch_idx, batch in batches:
restore_batch = loader.restore_sample(batch.__restore_key__)
assert restore_batch.__key__ == batch.__key__
assert restore_batch.batch_index == batch.batch_index
assert restore_batch.sample_index == batch.sample_index
assert restore_batch.rand_num == batch.rand_num
batch_rand_nums.append(restore_batch.rand_num)
assert np.allclose(restore_batch.image, batch.image)
# For constructing the test data above:
print("batch_rand_nums: ", batch_rand_nums)
assert batch_rand_nums == ref_batch_rand_nums
# Now, test multi-worker loader with accessing get_current_batch_index
worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2)
worker_config_r1 = WorkerConfig(rank=1, world_size=2, num_workers=2)
loader = get_loader(
get_train_dataset(
self.dataset_path,
batch_size=2,
task_encoder=TestTaskEncoder(),
worker_config=worker_config_r0,
shuffle_buffer_size=20,
max_samples_per_sequence=10,
)
)
loader_r1 = get_loader(
get_train_dataset(
self.dataset_path,
batch_size=2,
task_encoder=TestTaskEncoder(),
worker_config=worker_config_r1,
shuffle_buffer_size=20,
max_samples_per_sequence=10,
)
)
batches = list(zip(range(20), loader))
print("bir0", [batch.batch_index for batch_idx, batch in batches])
assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches)
print("sir0", [batch.sample_index for batch_idx, batch in batches])
assert all(
all(
si == 2 * sample_offset + (batch_idx * 2 - batch_idx % 2)
for sample_offset, si in enumerate(batch.sample_index)
)
for batch_idx, batch in batches
)
batches_r1 = list(zip(range(20), loader_r1))
print("bir0", [batch.batch_index for batch_idx, batch in batches_r1])
print("sir1", [batch.sample_index for batch_idx, batch in batches_r1])
assert all(
all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches_r1
)
assert all(
all(
si == 2 * sample_offset + (batch_idx * 2 - batch_idx % 2)
for sample_offset, si in enumerate(batch.sample_index)
)
for batch_idx, batch in batches_r1
)
# Now, test multi-worker loader with accessing get_current_batch_index and save/restore state
loader = get_savable_loader(
get_train_dataset(
self.dataset_path,
batch_size=2,
task_encoder=TestTaskEncoder(),
worker_config=worker_config_r0,
shuffle_buffer_size=20,
max_samples_per_sequence=10,
)
)
loader_r1 = get_savable_loader(
get_train_dataset(
self.dataset_path,
batch_size=2,
task_encoder=TestTaskEncoder(),
worker_config=worker_config_r1,
shuffle_buffer_size=20,
max_samples_per_sequence=10,
)
)
batches = list(zip(range(20), loader))
print([batch.batch_index for batch_idx, batch in batches])
assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches)
assert all(
all(
si == 2 * sample_offset + (batch_idx * 2 - batch_idx % 2)
for sample_offset, si in enumerate(batch.sample_index)
)
for batch_idx, batch in batches
)
batches_r1 = list(zip(range(20), loader_r1))
print([batch.batch_index for batch_idx, batch in batches_r1])
assert all(
all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches_r1
)
assert all(
all(
si == 2 * sample_offset + (batch_idx * 2 - batch_idx % 2)
for sample_offset, si in enumerate(batch.sample_index)
)
for batch_idx, batch in batches_r1
)
# Save and restore state
state = loader.save_state_rank()
# Restore state and check if the batch index is restored correctly
loader = get_savable_loader(
get_train_dataset(
self.dataset_path,
batch_size=2,
task_encoder=TestTaskEncoder(),
worker_config=worker_config_r0,
shuffle_buffer_size=20,
max_samples_per_sequence=10,
)
)
loader.restore_state_rank(state)
batches = list(zip(range(20, 40), loader))
print([batch.batch_index for batch_idx, batch in batches])
print([batch.sample_index for batch_idx, batch in batches])
assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches)
assert all(
all(
si == 2 * sample_offset + (batch_idx * 2 - batch_idx % 2)
for sample_offset, si in enumerate(batch.sample_index)
)
for batch_idx, batch in batches
)
def test_current_batch_index_generator(self):
# Tests if the get_current_batch_index works properly
torch.manual_seed(42)
class TestTaskEncoder(TaskEncoder):
@stateless(restore_seeds=True)
def encode_sample(self, sample):
# print("si stack:", WorkerConfig._sample_index_stack)
yield ExtendedCaptioningSample.extend(
sample,
batch_index=self.current_batch_index,
sample_index=self.current_sample_index,
rand_num=random.randint(0, 1000) + 0,
)
yield ExtendedCaptioningSample.extend(
sample,
batch_index=self.current_batch_index,
sample_index=self.current_sample_index,
rand_num=random.randint(0, 1000) + 1000,
)
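# Note: each input sample yields two encoded samples (rand_num offset by 0 and by 1000),
# so the per-batch sample_index below advances only every other element (see the assertion formula).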
# First, test simple single main-thread loader with accessing get_current_batch_index
loader = get_loader(
get_train_dataset(
self.dataset_path,
batch_size=3,
task_encoder=TestTaskEncoder(),
worker_config=no_worker_config,
shuffle_buffer_size=20,
max_samples_per_sequence=10,
)
)
batches = list(zip(range(20), loader))
print("bi", [batch.batch_index for batch_idx, batch in batches])
assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches)
print("si", [batch.sample_index for batch_idx, batch in batches])
assert all(
all(
si == (sample_offset + batch_idx * 3) // 2
for sample_offset, si in enumerate(batch.sample_index)
)
for batch_idx, batch in batches
)
print("rk", [batch.__restore_key__ for batch_idx, batch in batches])
assert loader.can_restore_sample()
# These need to be hard coded to detect breaking changes
# If a change is expected, update the values with the ones printed below
ref_batch_rand_nums = [
[661, 1747, 762],
[1171, 206, 1921],
[470, 1705, 130],
[1722, 283, 1990],
[508, 1041, 61],
[1102, 625, 1559],
[661, 1512, 296],
[1866, 376, 1345],
[632, 1176, 514],
[1652, 715, 1702],
[406, 1552, 555],
[1303, 27, 1520],
[760, 1380, 36],
[1869, 607, 1292],
[610, 1084, 825],
[1113, 219, 1102],
[564, 1695, 832],
[1612, 876, 2000],
[512, 1308, 632],
[1425, 605, 1931],
]
batch_rand_nums = []
for batch_idx, batch in batches:
restore_batch = loader.restore_sample(batch.__restore_key__)
assert restore_batch.batch_index == batch.batch_index
assert restore_batch.sample_index == batch.sample_index
assert restore_batch.rand_num == batch.rand_num
batch_rand_nums.append(restore_batch.rand_num)
assert np.allclose(restore_batch.image, batch.image)
# For constructing the test data above:
print("batch_rand_nums: ", batch_rand_nums)
assert batch_rand_nums == ref_batch_rand_nums
# Now, test multi-worker loader with accessing get_current_batch_index
worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2)
worker_config_r1 = WorkerConfig(rank=1, world_size=2, num_workers=2)
loader = get_loader(
get_train_dataset(
self.dataset_path,
batch_size=3,
task_encoder=TestTaskEncoder(),
worker_config=worker_config_r0,
shuffle_buffer_size=20,
max_samples_per_sequence=10,
)
)
loader_r1 = get_loader(
get_train_dataset(
self.dataset_path,
batch_size=3,
task_encoder=TestTaskEncoder(),
worker_config=worker_config_r1,
shuffle_buffer_size=20,
max_samples_per_sequence=10,
)
)
batches = list(zip(range(20), loader))
print("bir0", [batch.batch_index for batch_idx, batch in batches])
assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches)
print("sir0", [batch.sample_index for batch_idx, batch in batches])
# [[0, 0, 2], [1, 1, 3], [2, 4, 4], [3, 5, 5], [6, 6, 8], [7, 7, 9], [8, 10, 10], [9, 11, 11], [12, 12, 14], [13, 13, 15], [14, 16, 16], [15, 17, 17], [18, 18, 20], [19, 19, 21], [20, 22, 22], [21, 23, 23], [24, 24, 26], [25, 25, 27], [26, 28, 28], [27, 29, 29]]
assert all(
all(
si == batch_idx + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2
for sample_offset, si in enumerate(batch.sample_index)
)
for batch_idx, batch in batches
)
batches_r1 = list(zip(range(20), loader_r1))
print("bir0", [batch.batch_index for batch_idx, batch in batches_r1])
print("sir1", [batch.sample_index for batch_idx, batch in batches_r1])
assert all(
all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches_r1
)
assert all(
all(
si == batch_idx + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2
for sample_offset, si in enumerate(batch.sample_index)
)
for batch_idx, batch in batches_r1
)
# Now, test multi-worker loader with accessing get_current_batch_index and save/restore state
loader = get_savable_loader(
get_train_dataset(
self.dataset_path,
batch_size=3,
task_encoder=TestTaskEncoder(),
worker_config=worker_config_r0,
shuffle_buffer_size=20,
max_samples_per_sequence=10,
),
worker_config=worker_config_r0,
)
loader_r1 = get_savable_loader(
get_train_dataset(
self.dataset_path,
batch_size=3,
task_encoder=TestTaskEncoder(),
worker_config=worker_config_r1,
shuffle_buffer_size=20,
max_samples_per_sequence=10,
),
worker_config=worker_config_r1,
)
batches = list(zip(range(20), loader))
print("bi:", [batch.batch_index for batch_idx, batch in batches])
print("si:", [batch.sample_index for batch_idx, batch in batches])
assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches)
assert all(
all(
si == batch_idx + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2
for sample_offset, si in enumerate(batch.sample_index)
)
for batch_idx, batch in batches
)
batches_r1 = list(zip(range(20), loader_r1))
print([batch.batch_index for batch_idx, batch in batches_r1])
assert all(
all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches_r1
)
assert all(
all(
si == batch_idx + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2
for sample_offset, si in enumerate(batch.sample_index)
)
for batch_idx, batch in batches_r1
)
# Save and restore state
state = loader.save_state_rank()
# Iter next 20 from the loader
cmp_batches = list(zip(range(20, 40), loader))
print("bi:", [batch.batch_index for batch_idx, batch in cmp_batches])
print("si:", [batch.sample_index for batch_idx, batch in cmp_batches])
print("rnd:", [batch.rand_num for batch_idx, batch in cmp_batches])
assert all(
all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in cmp_batches
)
assert all(
all(
si == batch_idx + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2
for sample_offset, si in enumerate(batch.sample_index)
)
for batch_idx, batch in cmp_batches
)
# Restore state and check if the batch index is restored correctly
loader = get_savable_loader(
get_train_dataset(
self.dataset_path,
batch_size=3,
task_encoder=TestTaskEncoder(),
worker_config=worker_config_r0,
shuffle_buffer_size=20,
max_samples_per_sequence=10,
),
worker_config=worker_config_r0,
)
loader.restore_state_rank(state)
batches = list(zip(range(20, 40), loader))
print("bi:", [batch.batch_index for batch_idx, batch in batches])
print("si:", [batch.sample_index for batch_idx, batch in batches])
print("rnd:", [batch.rand_num for batch_idx, batch in batches])
assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches)
assert all(
all(
si == batch_idx + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2
for sample_offset, si in enumerate(batch.sample_index)
)
for batch_idx, batch in batches
)
assert all(
all(b1s == b2s for b1s, b2s in zip(b1.rand_num, b2.rand_num))
for (_b1idx, b1), (_b2idx, b2) in zip(batches, cmp_batches)
)
def test_packing(self):
torch.manual_seed(42)
class TestTaskEncoder(DefaultTaskEncoder):
def __init__(self):
super().__init__(raw_batch_type=CaptioningBatch)
@stateless
def encode_sample(self, sample: CaptioningSample) -> EncodedCaptioningSample:
return EncodedCaptioningSample.derive_from(
sample,
image=sample.image,
caption=torch.frombuffer(sample.caption.encode(), dtype=torch.uint8),
)
def select_samples_to_pack(
self, samples: List[EncodedCaptioningSample]
) -> List[List[EncodedCaptioningSample]]:
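# The packing buffer of 21 samples is split into fixed packs of 1, 4 and 16 samples.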
assert len(samples) == 21
return [samples[:1], samples[1 : 1 + 4], samples[1 + 4 : 1 + 4 + 16]]
@stateless
def pack_selected_samples(
self, samples: List[EncodedCaptioningSample]
) -> EncodedCaptioningSample:
return EncodedCaptioningSample(
__key__=",".join([sample.__key__ for sample in samples]),
__restore_key__=(),
image=torch.stack([sample.image for sample in samples]),
caption=torch.cat([sample.caption for sample in samples]),
)
loader = get_loader(
get_train_dataset(
self.dataset_path,
batch_size=2,
packing_buffer_size=21,
worker_config=no_worker_config,
virtual_epoch_length=6,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
task_encoder=TestTaskEncoder(),
)
)
assert len(loader) == 6
samples = list(loader)
print([batch.__key__ for batch in samples])
print([batch.__restore_key__ for batch in samples])
print([len(batch.__key__) for batch in samples])
print([[len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples])
# Each batch should have 2 samples
assert [len(batch.__key__) for batch in samples] == [
2,
2,
2,
2,
2,
2,
]
# The packs of lengths 1, 4, 16 should be unrolled repeatedly across the batches of size 2
assert [
[len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples
] == [[1, 4], [16, 1], [4, 16], [1, 4], [16, 1], [4, 16]]
restored_sample_1 = loader.restore_sample(samples[1].__restore_key__)
assert restored_sample_1.__key__ == samples[1].__key__
assert restored_sample_1.__restore_key__ == samples[1].__restore_key__
worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2)
loader_r0 = get_savable_loader(
get_train_dataset(
self.dataset_path,
batch_size=2,
packing_buffer_size=21,
worker_config=worker_config_r0,
virtual_epoch_length=8,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
task_encoder=TestTaskEncoder(),
),
checkpoint_every_min_n_samples=1,
checkpoint_every_sec=0,
)
samples_r0 = list(loader_r0)
assert [
[len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples_r0
] == [[1, 4], [1, 4], [16, 1], [16, 1], [4, 16], [4, 16], [1, 4], [1, 4]]
restored_sample_1 = loader_r0.restore_sample(samples_r0[1].__restore_key__)
assert restored_sample_1.__key__ == samples_r0[1].__key__
assert restored_sample_1.__restore_key__ == samples_r0[1].__restore_key__
rank_state_r0 = loader_r0.save_state_rank()
samples_r0_cmp = list(loader_r0)
assert [
[len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples_r0_cmp
] == [[16, 1], [16, 1], [4, 16], [4, 16], [1, 4], [1, 4], [16, 1], [16, 1]]
loader_r0 = get_savable_loader(
get_train_dataset(
self.dataset_path,
batch_size=2,
packing_buffer_size=21,
worker_config=worker_config_r0,
virtual_epoch_length=8,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
task_encoder=TestTaskEncoder(),
),
checkpoint_every_min_n_samples=1,
checkpoint_every_sec=0,
)
loader_r0.restore_state_rank(rank_state_r0)
samples_r0_restored = list(loader_r0)
print("cmp", [batch.__key__ for batch in samples_r0_cmp])
print("rst", [batch.__key__ for batch in samples_r0_restored])
assert [
[len(batch_key.split(",")) for batch_key in batch.__key__]
for batch in samples_r0_restored
] == [[16, 1], [16, 1], [4, 16], [4, 16], [1, 4], [1, 4], [16, 1], [16, 1]]
assert all(s0.__key__ == s1.__key__ for s0, s1 in zip(samples_r0_cmp, samples_r0_restored))
def test_packing_val(self):
torch.manual_seed(42)
class TestTaskEncoder(DefaultTaskEncoder):
def __init__(self):
super().__init__(raw_batch_type=CaptioningBatch)
@stateless
def encode_sample(self, sample: CaptioningSample) -> EncodedCaptioningSample:
return EncodedCaptioningSample.derive_from(
sample,
image=sample.image,
caption=torch.frombuffer(sample.caption.encode(), dtype=torch.uint8),
)
def select_samples_to_pack(
self, samples: List[EncodedCaptioningSample]
) -> List[List[EncodedCaptioningSample]]:
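# A full buffer of 11 samples is split into packs of 2, 5, 3 and 1 (in that order);
# the final partial buffer (50 % 11 == 6 samples) is dropped by returning [].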
assert len(samples) in (1 + 3 + 5 + 2, 50 % 11)
if len(samples) < 11:
return []
return [
samples[1 + 3 + 5 : 1 + 3 + 5 + 2],
samples[1 + 3 : 1 + 3 + 5],
samples[1 : 1 + 3],
samples[:1],
]
@stateless
def pack_selected_samples(
self, samples: List[EncodedCaptioningSample]
) -> EncodedCaptioningSample:
return EncodedCaptioningSample(
__key__=",".join([sample.__key__ for sample in samples]),
__restore_key__=(),
image=torch.stack([sample.image for sample in samples]),
caption=torch.cat([sample.caption for sample in samples]),
)
loader = get_loader(
get_val_dataset(
self.dataset_path,
batch_size=2,
packing_buffer_size=11,
worker_config=no_worker_config,
task_encoder=TestTaskEncoder(),
split_part="train",
)
)
assert len(loader) == 25, f"len(loader) == {len(loader)}"
samples = list(loader)
print([batch.__key__ for batch in samples])
print([batch.__restore_key__ for batch in samples])
print([len(batch.__key__) for batch in samples])
print([[len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples])
# Each batch should have 2 samples
assert [len(batch.__key__) for batch in samples] == [
2,
2,
2,
2,
2,
2,
2,
2,
]
# The packs of lengths 2, 5, 3, 1 should be unrolled repeatedly across the batches of size 2
assert [
[len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples
] == [[2, 5], [3, 1], [2, 5], [3, 1], [2, 5], [3, 1], [2, 5], [3, 1]]
restored_sample_1 = loader.restore_sample(samples[1].__restore_key__)
assert restored_sample_1.__key__ == samples[1].__key__
assert restored_sample_1.__restore_key__ == samples[1].__restore_key__
def test_group_batch(self):
class GroupingTaskEncoder(
TaskEncoder[CaptioningSample, CaptioningSample, CaptioningSample, CaptioningSample]
):
@stateless
def encode_sample(self, sample: CaptioningSample) -> CaptioningSample:
sample.caption = sample.__key__.split("/")[-2]
return sample
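# Group samples by their source shard: samples from data-0.tar are batched under
# group key "shard1" with batch size 4, samples from data-1.tar under "shard2" with batch size 8.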
def batch_group_criterion(self, sample: CaptioningSample) -> Tuple[Hashable, int]:
if sample.caption == "data-0.tar":
return "shard1", 4
elif sample.caption == "data-1.tar":
return "shard2", 8
else:
assert False
@stateless
def encode_batch(self, batch: CaptioningSample) -> CaptioningEncodedBatch:
return CaptioningEncodedBatch(**dataclasses.asdict(batch))
worker_config = WorkerConfig(rank=0, world_size=1, num_workers=0)
loader = get_savable_loader(
get_train_dataset(
self.dataset_path,
batch_size=None,
worker_config=worker_config,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
task_encoder=GroupingTaskEncoder(),
),
checkpoint_every_min_n_samples=1,
checkpoint_every_sec=0,
)
batches = list(zip(range(40), loader))
print([batch.__key__ for idx, batch in batches])
assert all(isinstance(batch, CaptioningEncodedBatch) for idx, batch in batches)
assert all(all(key == batch.caption[0] for key in batch.caption) for idx, batch in batches)
worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2)
loader_r0 = get_savable_loader(
get_train_dataset(
self.dataset_path,
batch_size=None,
worker_config=worker_config_r0,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
task_encoder=GroupingTaskEncoder(),
),
checkpoint_every_min_n_samples=1,
checkpoint_every_sec=0,
)
batches = list(zip(range(40), loader_r0))
print([batch.__key__ for idx, batch in batches])
assert all(isinstance(batch, CaptioningEncodedBatch) for idx, batch in batches)
assert all(all(key == batch.caption[0] for key in batch.caption) for idx, batch in batches)
state = loader_r0.save_state_rank()
cmp_samples = list(zip(range(40, 80), loader_r0))
print([batch.__key__ for idx, batch in cmp_samples])
loader_r0 = get_savable_loader(
get_train_dataset(
self.dataset_path,
batch_size=None,
worker_config=worker_config_r0,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
task_encoder=GroupingTaskEncoder(),
),
checkpoint_every_min_n_samples=1,
checkpoint_every_sec=0,
)
loader_r0.restore_state_rank(state)
cmp_samples_rest = list(zip(range(40, 80), loader_r0))
print([batch.__key__ for idx, batch in cmp_samples_rest])
assert len(cmp_samples) == len(cmp_samples_rest)
assert all(
len(cmp_sample.caption) == len(cmp_sample_rest.caption)
for (idx, cmp_sample), (idx, cmp_sample_rest) in zip(cmp_samples, cmp_samples_rest)
)
assert all(
all(
cmp_cap == cmp_cap_rest
for cmp_cap, cmp_cap_rest in zip(cmp_sample.caption, cmp_sample_rest.caption)
)
for (idx, cmp_sample), (idx, cmp_sample_rest) in zip(cmp_samples, cmp_samples_rest)
)
def test_debug_dataset(self):
torch.manual_seed(42)
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=2,
worker_log_level=3,
worker_debug_path=str(self.dataset_path) + "/worker_debug/{worker_id}.jsonl",
)
# Reset this to 0 to make sure the test is deterministic
SavableDataLoader._next_id = 0
loader = get_savable_loader(
get_val_dataset(
self.dataset_path,
split_part="train",
batch_size=5,
worker_config=worker_config,
),
)
assert len(loader) == 10
samples = [[batch.__key__ for batch in loader] for _ in range(2)]
print(samples)
debug_log_path = self.dataset_path / "worker_debug"
assert (debug_log_path / "0.jsonl").is_file()
assert (debug_log_path / "1.jsonl").is_file()
assert (debug_log_path / "2.jsonl").is_file()
collected_keys_order = [[None] * 10 for _ in range(2)]
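# Reconstruct the per-epoch order of yielded batch keys from the debug log that contains
# the SavableDataLoader.yield events (0.jsonl) and compare it to the keys observed above.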
with (debug_log_path / "0.jsonl").open() as rf:
for line in rf:
line_data = json.loads(line)
if line_data["t"] == "SavableDataLoader.yield":
print(line_data)
for i in range(len(collected_keys_order)):
if collected_keys_order[i][line_data["idx"]] is None:
collected_keys_order[i][line_data["idx"]] = line_data["keys"]
break
else:
assert False, "Too many entries for key"
print(collected_keys_order)
assert collected_keys_order == samples
runner = CliRunner()
result = runner.invoke(
analyze_debug_command,
[
str(debug_log_path),
"--include-modality",
"train,val",
"--heatmap-path",
str(self.dataset_path / "heatmap.png"),
],
catch_exceptions=False,
)
print(result.stdout)
assert result.exit_code == 0, "Debug analysis failed, see output"
assert "Analyzing 3 logs" in result.stdout
assert "Found 50 unique sample keys, 20 steps" in result.stdout
def test_validate_captioning_dataset(self):
runner = CliRunner()
result = runner.invoke(
lint_command,
[str(self.dataset_path), "--split-parts=train"],
catch_exceptions=False,
)
assert result.exit_code == 0, "Validation failed, see output"
def test_prepare_dataset(self):
runner = CliRunner()
result = runner.invoke(
prepare_command,
[str(self.dataset_path)],
catch_exceptions=False,
input="y\n1,0,0\ny\n0\nY\npng\ntxt\n",
)
assert result.exit_code == 0, "Prepare failed, see output"
assert "Done" in result.stdout, "Prepare failed, see output"
def test_preview_captioning_dataset(self):
runner = CliRunner()
result = runner.invoke(
preview_command,
[str(self.dataset_path), "--split-parts=train"],
input="n\n",
catch_exceptions=False,
)
# First sample!
assert "__key__ (<class 'str'>): 'parts/data-1.tar/000030'" in result.stdout
assert result.exit_code == 0, "Preview failed, see output"
def test_info_captioning_dataset(self):
runner = CliRunner()
result = runner.invoke(
info_command,
[str(self.dataset_path)],
catch_exceptions=False,
)
print(result.stdout)
assert "50 samples" in result.stdout
assert "2 shards" in result.stdout
assert str(self.dataset_path) in result.stdout
assert "train" in result.stdout
assert result.exit_code == 0, "Preview failed, see output"
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
"""This module defines tests for the dataset."""
import gc
import logging
import random
import sys
import tempfile
import unittest
import warnings
from collections import Counter
from pathlib import Path
import torch
import webdataset as wds
import yaml
from click.testing import CliRunner
from megatron.energon import (
DefaultTaskEncoder,
TextSample,
WorkerConfig,
get_loader,
get_train_dataset,
)
from megatron.energon.dataset_config import get_dataset_from_config
from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME
from megatron.energon.loader import get_savable_loader
from megatron.energon.task_encoder.base import stateless
from megatron.energon.tools.checkpoint import command_redist
# Speed up tests significantly by reducing the interval at which torch checks worker status (used to detect broken worker shutdowns)
try:
torch.utils.data._utils.worker.MP_STATUS_CHECK_INTERVAL = 0.1
torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL = 0.1
except AttributeError:
pass
def _norng_state(state):
if isinstance(state, bytes):
if len(state) > 100:
return state[:5] + f"...<len={len(state)}>".encode()
return state
elif isinstance(state, str):
if len(state) > 100:
return state[:5] + f"...<len={len(state)}>"
return state
elif isinstance(state, dict):
return {k: _norng_state(v) for k, v in state.items()}
elif isinstance(state, (list, tuple)):
if len(state) > 100:
state = state[:5]
return type(state)(_norng_state(v) for v in state)
else:
return state
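# Helper to shorten long values (e.g. RNG states) for readable printing. Illustrative:
# _norng_state(b"x" * 200) == b"xxxxx...<len=200>"; short values and nested dicts/lists/tuples
# are passed through recursively, and lists/tuples longer than 100 items are cut to the first 5.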
class TestDataset(unittest.TestCase):
# Set up the test fixture
def setUp(self):
logging.basicConfig(stream=sys.stderr, level=logging.INFO)
warnings.simplefilter("ignore", ResourceWarning)
# Create a temporary directory
self.temp_dir = tempfile.TemporaryDirectory()
self.dataset_path = Path(self.temp_dir.name)
# self.dataset_path = Path("./test_dataset")
self.dataset_path.mkdir(exist_ok=True, parents=True)
# Create a small dummy captioning dataset
self.create_text_test_dataset(self.dataset_path)
# Create temporary directories for checkpoint files
self.checkpoint_dir = Path(self.temp_dir.name) / "checkpoints"
self.checkpoint_dir.mkdir(exist_ok=True, parents=True)
self.redist_dir = Path(self.temp_dir.name) / "redist_checkpoints"
self.redist_dir.mkdir(exist_ok=True, parents=True)
print(self.dataset_path)
def tearDown(self):
# Remove all temporary files
gc.collect()
self.temp_dir.cleanup()
@staticmethod
def create_text_test_dataset(path: Path):
"""Creates a small dummy test dataset for testing purposes."""
# Create num_samples unique captions
(path / "parts").mkdir(exist_ok=True, parents=True)
# Initialize the ShardWriter
with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=100) as shard_writer:
for idx in range(55):
# Write individual files to shards
shard_writer.write(
{
"__key__": f"{idx:06d}",
"txt": f"{idx}".encode(),
},
)
# Also create smaller shards, to verify distributions
if idx in (1, 3, 6, 10, 20, 30, 40, 50):
shard_writer.next_stream()
total_shards = shard_writer.shard
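# With the extra next_stream() calls above (and maxcount=100 never reached), this produces
# 9 shards of sizes 2, 2, 3, 4, 10, 10, 10, 10 and 4 samples (55 in total).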
from megatron.energon.flavors import BaseWebdatasetFactory
BaseWebdatasetFactory.prepare_dataset(
path,
[f"parts/data-{{0..{total_shards - 1}}}.tar"],
split_parts_ratio=[("train", 1.0)],
shuffle_seed=None,
)
with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f:
f.write(
"\n".join(
[
"sample_type:",
" __module__: megatron.energon",
" __class__: TextSample",
"field_map:",
" text: txt",
]
)
)
# Split with alternating train/val shards
with open(path / MAIN_FOLDER_NAME / "split2.yaml", "w") as f:
yaml.dump(
{
"split_parts": {
"train": [
"parts/data-4.tar",
"parts/data-0.tar",
"parts/data-2.tar",
],
"val": [
"parts/data-1.tar",
"parts/data-3.tar",
"parts/data-5.tar",
],
}
},
f,
)
def test_split_parts(self):
with open(self.dataset_path / MAIN_FOLDER_NAME / "split.yaml", "r") as f:
print(f.read())
with open(self.dataset_path / MAIN_FOLDER_NAME / "split2.yaml", "r") as f:
print(f.read())
ds = get_dataset_from_config(
self.dataset_path,
split_config="split2.yaml",
split_part="train",
worker_config=WorkerConfig(rank=0, world_size=1, num_workers=0),
training=False,
sample_type=TextSample,
)
dl = get_loader(ds.build())
all_keys = [sample.__key__ for sample in dl]
assert all_keys == [
"parts/data-4.tar/000011", # Shard 4 first
"parts/data-4.tar/000012",
"parts/data-4.tar/000013",
"parts/data-4.tar/000014",
"parts/data-4.tar/000015",
"parts/data-4.tar/000016",
"parts/data-4.tar/000017",
"parts/data-4.tar/000018",
"parts/data-4.tar/000019",
"parts/data-4.tar/000020",
"parts/data-0.tar/000000", # Shard 0
"parts/data-0.tar/000001",
"parts/data-2.tar/000004", # Shard 2
"parts/data-2.tar/000005",
"parts/data-2.tar/000006",
]
def test_text_dataset(self):
worker_config = WorkerConfig(rank=0, world_size=1, num_workers=0)
ds = get_dataset_from_config(
self.dataset_path,
split_part="train",
training=False,
sample_type=TextSample,
worker_config=worker_config,
).build()
# Check len operator
assert len(ds) == 55
# Check if iterating returns the same
iter1 = list(get_loader(ds))
iter2 = list(get_loader(ds))
assert len(iter1) == 55
assert len(iter2) == 55
assert all(elem1.__key__ == elem2.__key__ for elem1, elem2 in zip(iter1, iter2))
assert all(f"{idx}" == x.text for idx, x in enumerate(get_loader(ds)))
del ds
gc.collect()
def test_epoch(self):
torch.manual_seed(42)
worker_config = WorkerConfig(rank=0, world_size=1, num_workers=5)
# Without shuffle buffer, should yield everything exactly once
ds3 = get_dataset_from_config(
self.dataset_path,
split_part="train",
training=True,
sample_type=TextSample,
worker_config=worker_config,
)
loader5 = get_loader(ds3.build())
order9 = [data.text for idx, data in zip(range(55), loader5)]
print(order9)
print(Counter(order9))
assert all(v == 1 for v in Counter(order9).values())
def test_determinism(self):
worker_config2 = WorkerConfig(rank=0, world_size=1, num_workers=2)
worker_config2b = WorkerConfig(rank=0, world_size=1, num_workers=2, seed_offset=43)
worker_config4 = WorkerConfig(rank=0, world_size=1, num_workers=4)
# This seed is used by the dataset to shuffle the data
torch.manual_seed(42)
ds1 = get_train_dataset(
self.dataset_path,
split_part="train",
sample_type=TextSample,
worker_config=worker_config2,
batch_size=1,
shuffle_buffer_size=42,
max_samples_per_sequence=2,
)
ds1b = get_train_dataset( # Same but different seed
self.dataset_path,
split_part="train",
sample_type=TextSample,
worker_config=worker_config2b,
batch_size=1,
shuffle_buffer_size=42,
max_samples_per_sequence=2,
)
ds2 = get_train_dataset(
self.dataset_path,
split_part="train",
sample_type=TextSample,
worker_config=worker_config2,
batch_size=1,
shuffle_buffer_size=42,
max_samples_per_sequence=2,
)
ds3 = get_train_dataset(
self.dataset_path,
split_part="train",
sample_type=TextSample,
worker_config=worker_config4,
batch_size=1,
shuffle_buffer_size=42,
max_samples_per_sequence=2,
)
# Fork the dataset twice
loader1 = get_loader(ds1)
loader2 = get_loader(ds1)
order4 = [data.text[0] for idx, data in zip(range(55 * 20), loader1)]
order5 = [data.text[0] for idx, data in zip(range(55 * 20), loader1)]
order6 = [data.text[0] for idx, data in zip(range(55 * 20), loader2)]
print(order4)
print(Counter(order4))
# Due to the random shuffling, per-text counts may deviate slightly from the expected 20 (the assert allows 17..22)
assert all(17 <= v <= 22 for v in Counter(order4).values())
assert order4 != order5
assert order4 == order6
loader3 = get_loader(ds1b)
order7 = [data.text[0] for idx, data in zip(range(55 * 20), loader3)]
assert order6 != order7
loader4 = get_loader(ds3)
order8 = [data.text[0] for idx, data in zip(range(55 * 100), loader4)]
assert order6 != order8[: len(order6)]
print(Counter(order8))
assert all(90 <= v <= 110 for v in Counter(order8).values())
# Delete all locals, otherwise loaders might be kept alive
locals().clear()
gc.collect()
def test_determinism_taskencoder(self):
class TestTaskEncoder(DefaultTaskEncoder):
@stateless(restore_seeds=True)
def encode_sample(self, sample: TextSample) -> TextSample:
rand_str = f"_{torch.randint(0, 1000, (1,)).item()}_{random.randint(0, 1000)}"
return TextSample(
__key__=sample.__key__,
__restore_key__=sample.__restore_key__,
__subflavors__=sample.__subflavors__,
text=sample.text + rand_str,
)
for num_workers in [0, 1]:
worker_config1 = WorkerConfig(rank=0, world_size=1, num_workers=num_workers)
# This seed is used by the dataset to shuffle the data
torch.manual_seed(42)
ds1a = get_train_dataset(
self.dataset_path,
split_part="train",
sample_type=TextSample,
worker_config=worker_config1,
batch_size=1,
shuffle_buffer_size=42,
max_samples_per_sequence=2,
task_encoder=TestTaskEncoder(),
)
torch.manual_seed(44)
ds1b = get_train_dataset(
self.dataset_path,
split_part="train",
sample_type=TextSample,
worker_config=worker_config1,
batch_size=1,
shuffle_buffer_size=42,
max_samples_per_sequence=2,
task_encoder=TestTaskEncoder(),
)
# Fork the dataset twice
loader1a = get_loader(ds1a)
loader1b = get_loader(ds1b)
order1a = [data.text[0] for idx, data in zip(range(55 * 20), loader1a)]
order1b = [data.text[0] for idx, data in zip(range(55 * 20), loader1b)]
assert order1a == order1b
# Delete all locals, otherwise loaders might be kept alive
locals().clear()
gc.collect()
def test_determinism_taskencoder_save_restore(self):
class TestTaskEncoder(DefaultTaskEncoder):
@stateless(restore_seeds=True)
def encode_sample(self, sample: TextSample) -> TextSample:
rand_str = (
f"_{torch.randint(0, 1000, (1,)).item()}_{random.randint(0, 1000)}"
+ f"_{self.current_batch_index}_{self.current_sample_index}"
)
return TextSample(
__key__=sample.__key__,
__restore_key__=sample.__restore_key__,
__subflavors__=sample.__subflavors__,
text=sample.text + rand_str,
)
for num_workers in [1, 0]:
worker_config1 = WorkerConfig(rank=0, world_size=1, num_workers=num_workers)
# This seed is used by the dataset to shuffle the data
torch.manual_seed(42)
ds1a = get_train_dataset(
self.dataset_path,
split_part="train",
sample_type=TextSample,
worker_config=worker_config1,
batch_size=1,
shuffle_buffer_size=42,
max_samples_per_sequence=2,
task_encoder=TestTaskEncoder(),
)
torch.manual_seed(44)
ds1b = get_train_dataset(
self.dataset_path,
split_part="train",
sample_type=TextSample,
worker_config=worker_config1,
batch_size=1,
shuffle_buffer_size=42,
max_samples_per_sequence=2,
task_encoder=TestTaskEncoder(),
)
# Fork the dataset twice
loader1a = get_savable_loader(ds1a)
loader1b = get_savable_loader(ds1b)
# Load 7 samples
data_pre = [data.text[0] for idx, data in zip(range(7), loader1a)]
# Then save state
state = loader1a.save_state_rank()
# Load another 20 samples
data_post = [data.text[0] for idx, data in zip(range(20), loader1a)]
# Restore state
loader1b.restore_state_rank(state)
# Load 20 samples again
data_restored = [data.text[0] for idx, data in zip(range(20), loader1b)]
print("Data post:", data_post)
print("Data restored:", data_restored)
assert data_post == data_restored
# Delete all locals, otherwise loaders might be kept alive
locals().clear()
gc.collect()
def test_restore_state(self):
worker_config = WorkerConfig(rank=0, world_size=1, num_workers=0)
count1 = 55 * 20
count2 = 55 * 20
sbs = 42  # shuffle_buffer_size
# count1 = 4
# count2 = 2
# sbs = None
psi = None  # parallel_shard_iters
# This seed is used by the dataset to shuffle the data
torch.manual_seed(42)
loader = get_savable_loader(
get_train_dataset(
self.dataset_path,
split_part="train",
sample_type=TextSample,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=sbs,
max_samples_per_sequence=2,
parallel_shard_iters=psi,
)
)
# print("save state")
state_0 = loader.save_state_global(global_dst_rank=0)
# print("save state done")
order_1 = [data.text[0] for idx, data in zip(range(count1), loader)]
assert len(order_1) == count1
# print("save state")
state_1 = loader.save_state_global(global_dst_rank=0)
# print("save state done")
order_2 = [data.text[0] for idx, data in zip(range(count2), loader)]
assert len(order_2) == count2
print("state0", state_0)
print("state1", state_1)
torch.manual_seed(213)
loader = get_savable_loader(
get_train_dataset(
self.dataset_path,
split_part="train",
sample_type=TextSample,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=sbs,
max_samples_per_sequence=2,
parallel_shard_iters=psi,
)
)
loader.restore_state_global(state_0, src_rank=None)
order_45 = [data.text[0] for idx, data in zip(range(count1 + count2), loader)]
order_4 = order_45[:count1]
order_5 = order_45[count1:]
# print("order1", order_1)
# print("order2", order_2)
# print("order4", order_4)
assert order_1 == order_4
# print("order5", order_5)
assert order_2 == order_5
torch.manual_seed(145)
loader = get_savable_loader(
get_train_dataset(
self.dataset_path,
split_part="train",
sample_type=TextSample,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=sbs,
max_samples_per_sequence=2,
parallel_shard_iters=psi,
)
)
# print("restore state")
loader.restore_state_global(state_1, src_rank=None)
# print("restore state done")
order_3 = [data.text[0] for idx, data in zip(range(count2), loader)]
# print("order1", order_1)
# print("order2", order_2[:100])
# print("order3", order_3[:100])
assert order_2 == order_3
def test_restore_state_dist(self):
from multiprocessing import Manager, Process
import torch.distributed as dist
world_size = 3
count1 = 55 * 20
count2 = 55 * 20
sbs = 42
psi = None
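# Phase 1 runs in a separate process per rank: each rank saves its loader state twice
# (gathered on rank 0 via save_state_global) and records the sample order. Phase 2 restores
# those states in fresh processes (broadcast from rank 0 via src_rank=0) and checks that the
# sample order continues identically.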
def phase1(rank: int, world_size: int, shared_dict: dict):
worker_config = WorkerConfig(rank=rank, world_size=world_size, num_workers=0)
# This seed is used by the dataset to shuffle the data
torch.manual_seed(42)
loader = get_savable_loader(
get_train_dataset(
self.dataset_path,
split_part="train",
sample_type=TextSample,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=sbs,
max_samples_per_sequence=2,
parallel_shard_iters=psi,
)
)
state_0 = loader.save_state_global(global_dst_rank=0)
order_1 = [data.text[0] for idx, data in zip(range(count1), loader)]
assert len(order_1) == count1
# print(f"Rank {rank}: order_1", order_1)
state_1 = loader.save_state_global(global_dst_rank=0)
order_2 = [data.text[0] for idx, data in zip(range(count2), loader)]
assert len(order_2) == count2
shared_dict[(rank, "order_1")] = order_1
shared_dict[(rank, "order_2")] = order_2
if rank == 0:
shared_dict["state_0"] = state_0
shared_dict["state_1"] = state_1
def phase2(rank: int, world_size: int, shared_dict: dict):
order_1 = shared_dict[(rank, "order_1")]
order_2 = shared_dict[(rank, "order_2")]
if rank == 0:
state_0 = shared_dict["state_0"]
state_1 = shared_dict["state_1"]
else:
state_0 = None
state_1 = None
worker_config = WorkerConfig(rank=rank, world_size=world_size, num_workers=0)
torch.manual_seed(213)
loader = get_savable_loader(
get_train_dataset(
self.dataset_path,
split_part="train",
sample_type=TextSample,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=sbs,
max_samples_per_sequence=2,
parallel_shard_iters=psi,
)
)
loader.restore_state_global(state_0, src_rank=0)
order_45 = [data.text[0] for idx, data in zip(range(count1 + count2), loader)]
order_4 = order_45[:count1]
order_5 = order_45[count1:]
# print(f"Rank {rank}: order_4", order_4)
assert order_1 == order_4
assert order_2 == order_5
torch.manual_seed(213)
loader = get_savable_loader(
get_train_dataset(
self.dataset_path,
split_part="train",
sample_type=TextSample,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=sbs,
max_samples_per_sequence=2,
parallel_shard_iters=psi,
)
)
loader.restore_state_global(state_1, src_rank=0)
order_3 = [data.text[0] for idx, data in zip(range(count2), loader)]
assert order_2 == order_3
def init_process(rank, world_size, shared_dict, fn, backend="gloo"):
"""Initializes the distributed environment."""
dist.init_process_group(
backend=backend,
init_method="tcp://127.0.0.1:12355",
world_size=world_size,
rank=rank,
)
fn(rank, world_size, shared_dict)
dist.destroy_process_group()
with Manager() as manager:
shared_dict = manager.dict()
# Phase 1 (save state)
processes = []
for rank in range(world_size):
p = Process(target=init_process, args=(rank, world_size, shared_dict, phase1))
p.start()
processes.append(p)
for p in processes:
p.join()
# Phase 2 (restore state)
processes = []
for rank in range(world_size):
p = Process(target=init_process, args=(rank, world_size, shared_dict, phase2))
p.start()
processes.append(p)
for p in processes:
p.join()
def test_restore_state_workers(self):
worker_config = WorkerConfig(rank=0, world_size=1, num_workers=2)
psi = 2  # parallel_shard_iters
sbs = 42  # shuffle_buffer_size
n1 = 18  # samples drawn between state_0 and state_1
n2 = 109  # samples drawn between state_1 and state_2
n3 = 28  # samples drawn after state_2
ces = 0  # checkpoint_every_sec
# This seed is used by the dataset to shuffle the data
torch.manual_seed(42)
ds = get_train_dataset(
self.dataset_path,
split_part="train",
sample_type=TextSample,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=sbs,
max_samples_per_sequence=2,
parallel_shard_iters=psi,
)
loader = get_savable_loader(ds, checkpoint_every_sec=ces)
# print("save state")
state_0 = loader.save_state_rank()
it1 = iter(loader)
# print("save state done")
order_1 = [data.text[0] for idx, data in zip(range(n1), it1)]
# print("save state")
# time.sleep(0.5)
state_1 = loader.save_state_rank()
# print("save state done")
order_2 = [data.text[0] for idx, data in zip(range(n2), it1)]
state_2 = loader.save_state_rank()
order_3 = [data.text[0] for idx, data in zip(range(n3), it1)]
print("order_1", order_1)
print("order_2", order_2)
print("order_3", order_3)
# print("state0", state_0)
print("state1", state_1)
print("state2", state_2)
# Restoring the saved state into a new loader should yield the same order
torch.manual_seed(42)
ds = get_train_dataset(
self.dataset_path,
split_part="train",
sample_type=TextSample,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=sbs,
max_samples_per_sequence=2,
parallel_shard_iters=psi,
)
loader = get_savable_loader(ds)
loader.restore_state_rank(state_0)
order_6 = [data.text[0] for idx, data in zip(range(n1), loader)]
print("order1", order_1)
print("order6", order_6)
assert order_6 == order_1
# Restoring the saved state into a new loader should yield the same order
torch.manual_seed(42)
ds = get_train_dataset(
self.dataset_path,
split_part="train",
sample_type=TextSample,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=sbs,
max_samples_per_sequence=2,
parallel_shard_iters=psi,
)
loader = get_savable_loader(ds)
loader.restore_state_rank(state_1)
order_7 = [data.text[0] for idx, data in zip(range(n2), loader)]
print("order2", order_2[:100])
print("order7", order_7[:100])
assert order_7 == order_2
# Restoring the saved state into a new loader should yield the same order
torch.manual_seed(42)
ds = get_train_dataset(
self.dataset_path,
split_part="train",
sample_type=TextSample,
worker_config=worker_config,
batch_size=1,
max_samples_per_sequence=2,
shuffle_buffer_size=sbs,
parallel_shard_iters=psi,
)
loader = get_savable_loader(ds)
loader.restore_state_rank(state_2)
order_8 = [data.text[0] for idx, data in zip(range(n3), loader)]
print("order3", order_3)
print("order8", order_8)
assert order_8 == order_3
def test_invariance_global_samples(self):
# We'd like to ensure that the user can keep the same global batches
# (deterministic pseudo random order) when changing the number of ranks (world size).
# This can be achieved by obeying a few constraints:
# - Global batch size must stay the same across runs
# - Global batch size must be a multiple of (micro-batch size * world_size * num_workers)
# - Global batch size = micro-batch size * world_size * num_workers * gradient_accum_steps
# - world_size * num_workers must stay the same across runs
# - Set the same torch.manual_seed(...) on each rank before constructing the dataset and the data loader
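# For example (illustrative): with global_batch_size=8, micro_batch_size=2 and world_size=2
# (the second scenario below), gradient_accum_steps = 8 // (2 * 2) = 2, i.e. each global batch
# is assembled from 2 consecutive micro-batches of every rank.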
scenarios = [
dict(
configs=(WorkerConfig(rank=0, world_size=1, num_workers=4),),
micro_batch_size=2,
global_batch_size=8,
),
dict(
configs=(
WorkerConfig(rank=0, world_size=2, num_workers=2),
WorkerConfig(rank=1, world_size=2, num_workers=2),
),
micro_batch_size=2,
global_batch_size=8,
),
dict(
configs=(
WorkerConfig(rank=0, world_size=4, num_workers=1),
WorkerConfig(rank=1, world_size=4, num_workers=1),
WorkerConfig(rank=2, world_size=4, num_workers=1),
WorkerConfig(rank=3, world_size=4, num_workers=1),
),
micro_batch_size=2,
global_batch_size=8,
),
dict(
configs=(
WorkerConfig(rank=0, world_size=2, num_workers=2),
WorkerConfig(rank=1, world_size=2, num_workers=2),
),
micro_batch_size=1, # Micro-batch 1, more accum
global_batch_size=8,
),
]
# Check the constraints above for each scenario and collect the resulting global batches:
global_batches_per_scenario = []
for scenario in scenarios:
assert scenario["global_batch_size"] % scenario["micro_batch_size"] == 0, (
"Global batch size must be a multiple of the micro-batch size."
)
world_size = len(scenario["configs"])
gradient_accum_steps = scenario["global_batch_size"] // (
scenario["micro_batch_size"] * world_size
)
batches_per_rank = []
for rank_config in scenario["configs"]:
torch.manual_seed(42)
ds = get_train_dataset(
self.dataset_path,
split_part="train",
sample_type=TextSample,
worker_config=rank_config,
batch_size=scenario["micro_batch_size"],
shuffle_buffer_size=42,
max_samples_per_sequence=2,
)
loader = get_loader(ds)
micro_batches = [
data.text
for idx, data in zip(
range(55 * 8 // (world_size * scenario["micro_batch_size"])), loader
)
]
batches_per_rank.append(micro_batches)
# Compose global batches
global_batches_cur_rank = []
batch_index = 0
while batch_index < len(batches_per_rank[0]):
global_batch = []
for _ in range(gradient_accum_steps):
for rank_batches in batches_per_rank:
global_batch.extend(rank_batches[batch_index])
batch_index += 1
if batch_index >= len(batches_per_rank[0]):
# last global batch may be smaller
break
global_batches_cur_rank.append(sorted(global_batch))
global_batches_per_scenario.append(global_batches_cur_rank)
# Check that the global batches are the same
# Assert that all scenarios produced the same number of global batches
assert all(
len(global_batches) == len(global_batches_per_scenario[0])
for global_batches in global_batches_per_scenario
), "Number of global batches per scenario does not match."
for global_batches in global_batches_per_scenario:
print("= Global batches per scenario")
for global_batch in global_batches:
print(" Global batch: ", global_batch)
# Assert that all global batches are the same
for i in range(len(global_batches_per_scenario[0])):
for scenario_idx, global_batches in enumerate(global_batches_per_scenario):
assert global_batches[i] == global_batches_per_scenario[0][i], (
f"Global batch {i} of scenario {scenario_idx} does not match."
)
# Delete all locals, otherwise loaders might be kept alive
locals().clear()
gc.collect()
def test_redist(self):
scenarios = [
dict(
configs=(
WorkerConfig(rank=0, world_size=2, num_workers=2),
WorkerConfig(rank=1, world_size=2, num_workers=2),
),
micro_batch_size=2,
global_batch_size=8,
),
dict(
configs=(WorkerConfig(rank=0, world_size=1, num_workers=4),),
micro_batch_size=2,
global_batch_size=8,
),
dict(
configs=(
WorkerConfig(rank=0, world_size=4, num_workers=1),
WorkerConfig(rank=1, world_size=4, num_workers=1),
WorkerConfig(rank=2, world_size=4, num_workers=1),
WorkerConfig(rank=3, world_size=4, num_workers=1),
),
micro_batch_size=2,
global_batch_size=8,
),
dict(
configs=(
WorkerConfig(rank=0, world_size=2, num_workers=2),
WorkerConfig(rank=1, world_size=2, num_workers=2),
),
micro_batch_size=1, # Micro-batch 1, more accum
global_batch_size=8,
),
dict( # Same as original
configs=(
WorkerConfig(rank=0, world_size=2, num_workers=2),
WorkerConfig(rank=1, world_size=2, num_workers=2),
),
micro_batch_size=2,
global_batch_size=8,
),
]
# === Stage 1 first generate a saved state using scenario 0
checkpoint_files = []
global_batches_per_scenario = []
scenario = scenarios[0]
world_size = len(scenario["configs"])
gradient_accum_steps = scenario["global_batch_size"] // (
scenario["micro_batch_size"] * world_size
)
batches_per_rank = []
for rank_config in scenario["configs"]:
loader = get_savable_loader(
get_train_dataset(
self.dataset_path,
split_part="train",
sample_type=TextSample,
worker_config=rank_config,
batch_size=scenario["micro_batch_size"],
shuffle_buffer_size=42,
max_samples_per_sequence=2,
)
)
# Throw away some samples to advance the loader state
num_pre_samples = 20
for _ in zip(range(num_pre_samples), loader):
pass
# Save the state to a file
checkpoint_file = self.checkpoint_dir / f"state_rank{rank_config.rank}.pt"
state = loader.save_state_rank()
torch.save(state, str(checkpoint_file))
checkpoint_files.append(checkpoint_file)
# Now capture the next micro-batches
micro_batches = [
data.text
for idx, data in zip(
range(55 * 8 // (world_size * scenario["micro_batch_size"])), loader
)
]
batches_per_rank.append(micro_batches)
# Compose global batches
global_batches_cur_rank = []
batch_index = 0
while batch_index < len(batches_per_rank[0]):
global_batch = []
for _ in range(gradient_accum_steps):
for rank_batches in batches_per_rank:
global_batch.extend(rank_batches[batch_index])
batch_index += 1
if batch_index >= len(batches_per_rank[0]):
# last global batch may be smaller
break
global_batches_cur_rank.append(sorted(global_batch))
global_batches_per_scenario.append(global_batches_cur_rank)
# === Stage 2: Now check that the global batches are the same after redistribution
for scenario in scenarios[1:]:
# Redistribute the saved state
runner = CliRunner()
result = runner.invoke(
command_redist,
[
"--new-world-size",
str(len(scenario["configs"])),
*[str(cpt) for cpt in checkpoint_files],
str(self.redist_dir),
],
)
print(result.output)
assert result.exception is None, result.exception
assert result.exit_code == 0, "Redistribution failed"
# Load state and check that the global batches are the same
assert scenario["global_batch_size"] % scenario["micro_batch_size"] == 0, (
"Global batch size must be a multiple of the micro-batch size."
)
world_size = len(scenario["configs"])
gradient_accum_steps = scenario["global_batch_size"] // (
scenario["micro_batch_size"] * world_size
)
batches_per_rank = []
for rank_config in scenario["configs"]:
loader = get_savable_loader(
get_train_dataset(
self.dataset_path,
split_part="train",
sample_type=TextSample,
worker_config=rank_config,
batch_size=scenario["micro_batch_size"],
shuffle_buffer_size=42,
max_samples_per_sequence=2,
)
)
state = torch.load(
str(self.redist_dir / f"state_rank{rank_config.rank}.pt"), weights_only=False
)
loader.restore_state_rank(state)
micro_batches = [
data.text
for idx, data in zip(
range(55 * 8 // (world_size * scenario["micro_batch_size"])), loader
)
]
batches_per_rank.append(micro_batches)
# Compose global batches
global_batches_cur_rank = []
batch_index = 0
while batch_index < len(batches_per_rank[0]):
global_batch = []
for _ in range(gradient_accum_steps):
for rank_batches in batches_per_rank:
global_batch.extend(rank_batches[batch_index])
batch_index += 1
if batch_index >= len(batches_per_rank[0]):
# last global batch may be smaller
break
global_batches_cur_rank.append(sorted(global_batch))
global_batches_per_scenario.append(global_batches_cur_rank)
# Check that the global batches are the same
print()
# Assert that all scenarios produced the same global batches
assert all(
len(global_batches) == len(global_batches_per_scenario[0])
for global_batches in global_batches_per_scenario
), "Number of global batches per scenario does not match."
for global_batches in global_batches_per_scenario:
print("= Global batches per scenario")
for global_batch in global_batches:
print(" Global batch: ", global_batch)
# Assert that all global batches are the same
for i in range(len(global_batches_per_scenario[0])):
for scenario_idx, global_batches in enumerate(global_batches_per_scenario):
assert global_batches[i] == global_batches_per_scenario[0][i], (
f"Global batch {i} of scenario {scenario_idx} does not match."
)
# Delete all locals, otherwise loaders might be kept alive
locals().clear()
gc.collect()
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
"""This module tests the EPath class, our custom version of pathlib.Path"""
import logging
import multiprocessing
import os
import pickle
import struct
import sys
import unittest
from megatron.energon.epathlib import EPath
from tests.epath_s3_emulator import setup_s3_emulator
class TestEPath(unittest.TestCase):
def setUp(self):
logging.basicConfig(stream=sys.stderr, level=logging.INFO)
def tearDown(self):
pass
def test_basic(self):
"""Some basic functionality tests"""
p_rel = EPath("./subdir")
p_abs = EPath("/tmp")
p_comb = p_abs / p_rel
# Those should not raise:
assert not p_comb.is_file()
assert not p_abs.is_file()
assert p_abs.is_dir()
def test_contextman(self):
"""Test the context manager"""
tmp_file_path = "/tmp/testfile.bin"
# First create a file
with open(tmp_file_path, "wb") as f:
f.write(struct.pack("H10s", 1337, b"1234567890"))
# Test context manager reading
p = EPath(tmp_file_path).open("rb")
print(p)
with p:
b = p.read()
assert isinstance(b, bytes)
num, data = struct.unpack("H10s", b)
logging.info(f"num: {num}")
assert num == 1337
assert data == b"1234567890"
# Test context manager writing
tmp_file_path2 = "/tmp/testfile2.bin"
with EPath(tmp_file_path2).open("wb") as p:
p.write(struct.pack("H10s", 1337, b"1234567890"))
def test_localfs(self):
"""Test the local filesystem"""
p = EPath("/tmp/testfile.bin")
with p.open("wb") as f:
f.write(b"dummycontent")
assert p.is_file()
assert p.size() == 12
with p.open("rb") as f:
assert f.read() == b"dummycontent"
# Test relative paths
revert_dir = os.getcwd()
try:
os.chdir("/tmp")
p = EPath("testfile.bin")
assert str(p) == "/tmp/testfile.bin"
assert p.is_file()
assert p.size() == 12
with p.open("rb") as f:
assert f.read() == b"dummycontent"
p = EPath("nonexisting/../testfile.bin")
assert str(p) == "/tmp/testfile.bin"
p = EPath("../tmp/testfile.bin")
assert str(p) == "/tmp/testfile.bin"
finally:
os.chdir(revert_dir)
p.unlink()
assert p.is_file() is False
def test_glob(self):
"""Test the glob functionality"""
# First create some files
for i in range(10):
with open(f"/tmp/epathtestfile_{i}.bin", "wb") as f:
f.write(b"dummycontent")
# Test globbing
p = EPath("/tmp").glob("epathtestfile_*.bin")
logging.info(f"p: {p}, type of p: {type(p)}")
elems = list(p)
assert len(elems) == 10
for i, e in enumerate(elems):
logging.info(f"glob_result[{i}]: {e}")
assert isinstance(e, EPath)
assert e.is_file()
# Test globbing with a pattern
p = EPath("/tmp").glob("epathtestfile_[0-3].bin")
assert len(list(p)) == 4
def test_s3_path_resolution(self):
"""Test s3 path resolution"""
rclone_config_path = EPath("/tmp/XDG_CONFIG_HOME/.config/rclone/rclone.conf")
with rclone_config_path.open("w") as f:
f.write(
"\n".join(
[
"[s3]",
"type = s3",
"env_auth = false",
"access_key_id = dummy",
"secret_access_key = dummy",
"region = dummy",
"endpoint = https://localhost",
]
)
)
orig_xdg_config_home = os.environ.get("XDG_CONFIG_HOME")
orig_home = os.environ.get("HOME")
os.environ["XDG_CONFIG_HOME"] = "/tmp/XDG_CONFIG_HOME/.config"
os.environ["HOME"] = "/tmp/XDG_CONFIG_HOME"
# Hack to clear the cache of the rclone config for msc to get the "s3" profile
from multistorageclient.rclone import read_rclone_config
read_rclone_config.cache_clear()
try:
# Test globbing
p = EPath("msc://s3/tmp/path/subpath.txt")
assert str(p) == "msc://s3/tmp/path/subpath.txt", str(p)
p2 = p / ".." / "subpath2.txt"
assert str(p2) == "msc://s3/tmp/path/subpath2.txt", str(p2)
p3 = EPath("msc://s3/tmp/path/.././subpath.txt")
assert str(p3) == "msc://s3/tmp/subpath.txt", str(p3)
p4 = p3.parent / "../bla/bla/bla/../../../no/../subpath2.txt"
assert str(p4) == "msc://s3/subpath2.txt", str(p4)
# Test warning for deprecated rclone protocol
with self.assertWarns((DeprecationWarning, FutureWarning)) as warning:
# Test rclone backwards compatibility
pr = EPath("rclone://s3/tmp/path/.././subpath.txt")
assert str(pr) == "msc://s3/tmp/subpath.txt", str(pr)
assert "deprecated" in str(warning.warnings[0].message)
# Test pickle / unpickle
p4serialized = pickle.dumps(p4)
# The credentials from the rclone config must not end up in the serialized data
assert b"dummy" not in p4serialized
finally:
if orig_xdg_config_home is not None:
os.environ["XDG_CONFIG_HOME"] = orig_xdg_config_home
else:
del os.environ["XDG_CONFIG_HOME"]
if orig_home is not None:
os.environ["HOME"] = orig_home
else:
del os.environ["HOME"]
rclone_config_path.unlink()
def test_multi_storage_client(self):
"""Test the Multi-Storage Client integration"""
# Test path handling
p = EPath("msc://default/etc/resolv.conf")
assert str(p) == "/etc/resolv.conf", str(p)
assert p.is_file()
p2 = p / ".." / "hosts"
assert str(p2) == "/etc/hosts", str(p2)
# Test glob
p3 = EPath("msc://default/etc/")
assert p3.is_dir()
for i in p3.glob("*.conf"):
assert str(i).endswith(".conf")
# Test open file
assert p.size() > 0
with p.open("r") as fp:
assert len(fp.read()) > 0
# Test move and delete
p4 = EPath("msc://default/tmp/random_file_0001")
p4.unlink()
with p4.open("w") as fp:
fp.write("*****")
assert p4.is_file()
p5 = EPath("msc://default/tmp/random_file_0002")
p5.unlink()
assert p5.is_file() is False
p4.move(p5)
assert p5.is_file()
assert p4.is_file() is False
p5.unlink()
assert p5.is_file() is False
# Test pickle / unpickle
p5serialized = pickle.dumps(p5)
p5unserialized = pickle.loads(p5serialized)
assert p5unserialized == p5
assert str(p5unserialized) == str(p5)
def test_multiprocessing(self):
"""Test EPath in multiprocessing context"""
p = EPath("/tmp/path/subpath.txt")
orig_start_method = multiprocessing.get_start_method()
try:
multiprocessing.set_start_method("spawn", force=True)
proc = multiprocessing.Process(target=_multiproc_test_func, args=(p, True))
proc.start()
proc.join()
assert proc.exitcode == 0
multiprocessing.set_start_method("fork", force=True)
proc = multiprocessing.Process(target=_multiproc_test_func, args=(p, True))
proc.start()
proc.join()
assert proc.exitcode == 0
finally:
multiprocessing.set_start_method(orig_start_method, force=True)
def test_multiprocessing_msc(self):
"""Test EPath in multiprocessing context"""
p = EPath("msc://default/tmp/random_file_0001")
with p.open("w") as fp:
fp.write("*****")
orig_start_method = multiprocessing.get_start_method()
try:
multiprocessing.set_start_method("spawn", force=True)
proc = multiprocessing.Process(target=_multiproc_test_func, args=(p, True))
proc.start()
proc.join()
assert proc.exitcode == 0
multiprocessing.set_start_method("fork", force=True)
proc = multiprocessing.Process(target=_multiproc_test_func, args=(p, True))
proc.start()
proc.join()
assert proc.exitcode == 0
finally:
multiprocessing.set_start_method(orig_start_method, force=True)
p.unlink()
def test_msc_s3(self):
# Test S3 with MSC
with setup_s3_emulator(profile_name="s3test_msc"):
p = EPath("msc://s3test_msc/test/dir/file.txt")
assert not p.is_file()
p.write_text("dummy")
assert p.is_file()
assert p.size() > 0
assert p.read_text() == "dummy"
# TODO: Fix when fixed in MSC.
# assert EPath("msc://s3test_msc/test").is_dir()
assert EPath("msc://s3test_msc/test/dir").is_dir()
p.unlink()
assert not p.is_file()
# assert not EPath("msc://s3test_msc/test").is_dir()
assert not EPath("msc://s3test_msc/test/dir").is_dir()
def _multiproc_test_func(p: EPath, test_function: bool):
"""Helper function for multiprocessing test"""
print(f"str: {str(p)}")
if test_function:
print(f"is_file: {p.is_file()}")
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import gc
import tempfile
import threading
import time
import unittest
from pathlib import Path
from typing import Any, Dict, Optional
from megatron.energon.cache import DecodeFileStore, FileCacheLazy, FileStore, FileStoreCachePool
from megatron.energon.flavors.webdataset.sample_decoder import SampleDecoder
from megatron.energon.source_info import SourceInfo
class MockFileStore(FileStore):
"""Mock implementation of FileStore for testing"""
def __init__(self, data: Optional[Dict[str, Any]] = None, path: str = "mock_store"):
self._data = data if data is not None else {}
self._path = path
def __getitem__(self, key: str) -> tuple[Any, SourceInfo]:
return self._data[key], SourceInfo(
dataset_path=self._path,
index=None,
shard_name=None,
file_names=(key,),
)
def get_path(self) -> str:
return self._path
class MockDecoder(SampleDecoder):
"""Mock decoder for DecodeFileStore"""
def decode(self, fname: str, raw: bytes) -> Any:
return f"{fname}: {raw.decode()}"
class TestFileStoreCachePool(unittest.TestCase):
"""Test cases for FileStoreCachePool"""
def setUp(self):
"""Setup test environment before each test"""
# Create a temporary directory
self.temp_dir = tempfile.TemporaryDirectory()
self.temp_path = Path(self.temp_dir.name)
def tearDown(self):
"""Clean up after each test"""
self.temp_dir.cleanup()
def test_get_method(self):
"""Test the synchronous get method"""
# Create mock file stores
mock_raw_file_store = MockFileStore(
{
"file1": b"test data 1",
"file2": b"test data 2",
"file3": b"test data 3",
}
)
mock_decode_file_store = DecodeFileStore(
decoder=MockDecoder(),
inner_reader=mock_raw_file_store,
)
pool = FileStoreCachePool(parent_cache_dir=self.temp_path)
try:
# get should directly read from the dataset without caching
sample_for_source_info = {"__sources__": []}
result = pool.get(mock_raw_file_store, "file1", sample_for_source_info)
assert result == b"test data 1"
assert len(sample_for_source_info["__sources__"]) == 1
assert (
sample_for_source_info["__sources__"][0].dataset_path
== mock_raw_file_store.get_path()
)
assert sample_for_source_info["__sources__"][0].index is None
assert sample_for_source_info["__sources__"][0].shard_name is None
assert sample_for_source_info["__sources__"][0].file_names == ("file1",)
# get should directly read from the dataset without caching
sample_for_source_info = {"__sources__": []}
result = pool.get(mock_decode_file_store, "file1", sample_for_source_info)
assert result == "file1: test data 1"
assert len(sample_for_source_info["__sources__"]) == 1
assert (
sample_for_source_info["__sources__"][0].dataset_path
== mock_decode_file_store.get_path()
)
assert sample_for_source_info["__sources__"][0].index is None
assert sample_for_source_info["__sources__"][0].shard_name is None
assert sample_for_source_info["__sources__"][0].file_names == ("file1",)
finally:
pool.close()
def test_get_lazy_method(self):
"""Test the lazy get method for background prefetching"""
pool = FileStoreCachePool(parent_cache_dir=self.temp_path)
# Create mock file stores
mock_raw_file_store = MockFileStore(
{
"file1": b"test data 1",
}
)
try:
# Request lazy loading
lazy_ref = pool.get_lazy(mock_raw_file_store, "file1")
# Verify the return type
assert isinstance(lazy_ref, FileCacheLazy)
# Wait for the background task
lazy_ref.entry.send_to_cache_future.result()
# Check that the file exists in the cache directory
cache_files = list(pool.cache_dir.glob("*"))
assert len(cache_files) == 1
# Get the data
result = lazy_ref.get()
assert result == b"test data 1"
finally:
pool.close()
def test_shared_references(self):
"""Test that multiple references share the same background task"""
pool = FileStoreCachePool(parent_cache_dir=self.temp_path)
# Create mock file stores
mock_raw_file_store = MockFileStore(
{
"file1": b"test data 1",
}
)
try:
# Check that the file exists in the cache directory
cache_files = list(pool.cache_dir.rglob("*"))
assert len(cache_files) == 0
# Request lazy loading for the same file twice
lazy_ref1 = pool.get_lazy(mock_raw_file_store, "file1")
lazy_ref2 = pool.get_lazy(mock_raw_file_store, "file1")
# Check that they share the same entry
assert lazy_ref1.entry is lazy_ref2.entry
# Check that refcount is 2
assert lazy_ref1.entry.refcount == 2
# Wait for the background task
lazy_ref1.entry.send_to_cache_future.result()
# Check that the file exists in the cache directory
cache_files = list(pool.cache_dir.rglob("*"))
assert len(cache_files) == 1, cache_files
# Get data from both references
sample_with_source_info = {"__sources__": []}
result1 = lazy_ref1.get(sample_with_source_info)
assert lazy_ref1.entry.refcount == 1
sample_with_source_info2 = {"__sources__": []}
result2 = lazy_ref2.get(sample_with_source_info2)
assert lazy_ref1.entry.refcount == 0
# Check that the file exists in the cache directory
cache_files = list(pool.cache_dir.rglob("*"))
assert len(cache_files) == 0
assert result1 == b"test data 1"
assert result2 == b"test data 1"
assert (
sample_with_source_info["__sources__"][0].dataset_path
== sample_with_source_info2["__sources__"][0].dataset_path
)
assert sample_with_source_info["__sources__"][0].index is None
assert sample_with_source_info["__sources__"][0].shard_name is None
assert (
sample_with_source_info["__sources__"][0].file_names
== sample_with_source_info2["__sources__"][0].file_names
)
finally:
pool.close()
def test_cache_size_management(self):
"""Test that the cache respects size limits and evicts files"""
# Create a cache pool with strict limits
pool = FileStoreCachePool(
parent_cache_dir=self.temp_path,
max_cache_size_gbytes=0.0001, # ~100KB
max_cache_count=2,
num_workers=1,
)
# Set to a safe byte size
pool.max_cache_size = 75_000
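# With a 75_000 byte limit and max_cache_count=2, at most two files (e.g. one 50 kB and one 25 kB file) fit in the cache at a time.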
mock_raw_file_store = MockFileStore(
{
"large_file1": b"a" * 50_000,
"large_file2": b"b" * 50_000,
"large_file3": b"c" * 50_000,
"large_file4": b"d" * 25_000,
"large_file5": b"e" * 25_000,
"large_file6": b"f" * 25_000,
}
)
try:
# Enqueue all fetches
lazy1 = pool.get_lazy(mock_raw_file_store, "large_file1")
lazy2 = pool.get_lazy(mock_raw_file_store, "large_file2")
lazy3 = pool.get_lazy(mock_raw_file_store, "large_file3")
lazy4 = pool.get_lazy(mock_raw_file_store, "large_file4")
lazy2_2 = pool.get_lazy(mock_raw_file_store, "large_file2")
lazy2_3 = pool.get_lazy(mock_raw_file_store, "large_file2")
lazy3_2 = pool.get_lazy(mock_raw_file_store, "large_file3")
lazy5 = pool.get_lazy(mock_raw_file_store, "large_file5")
lazy6 = pool.get_lazy(mock_raw_file_store, "large_file6")
lazy6_2 = pool.get_lazy(mock_raw_file_store, "large_file6")
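# Helper: snapshot each lazy reference as (name, refcount, state), where state is "consumed" (data already read),
# "cached" (written to the cache dir), or "pending" (still queued in the background).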
def status():
return [
(
name,
lazy.entry.refcount,
"consumed"
if lazy._data
else ("cached" if lazy.entry.send_to_cache_future.done() else "pending"),
)
for lazy, name in (
[
(lazy1, "1"),
(lazy2, "2"),
(lazy2_2, "2_2"),
(lazy2_3, "2_3"),
(lazy3, "3"),
(lazy3_2, "3_2"),
(lazy4, "4"),
(lazy5, "5"),
(lazy6, "6"),
]
+ ([(lazy6_2, "6_2")] if lazy6_2 is not None else [])
)
]
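# Helper: render a human-readable cache overview for assertion failure messages.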
def txt_status():
out = []
for lazy in [
lazy1,
lazy2,
lazy2_2,
lazy2_3,
lazy3,
lazy3_2,
lazy4,
lazy5,
lazy6,
] + ([lazy6_2] if lazy6_2 is not None else []):
if lazy._data is not None:
out.append(
f" - {lazy.fname} [{lazy.entry.data_size}b, {lazy.entry.refcount}refs] consumed"
)
elif lazy.entry.send_to_cache_future.done():
out.append(
f" - {lazy.fname} [{lazy.entry.data_size}b, {lazy.entry.refcount}refs] cached"
)
else:
out.append(
f" - {lazy.fname} [{lazy.entry.data_size}b, {lazy.entry.refcount}refs] pending"
)
return (
f"Cached Count: {pool.current_cache_count}, Cache size: {pool.current_cache_size}\n"
+ "\n".join(out)
)
# lazy2_2 and lazy2_3 should share the same entry as lazy2
assert lazy2_2.entry is lazy2.entry
assert lazy2_3.entry is lazy2.entry
lazy1.entry.send_to_cache_future.result(timeout=1)
# Wait for the background tasks to finish
time.sleep(0.5)
print("Checking cache status")
# They should not be able to finish, because the cache is full
# Queue state: [2<50>, 3<50>, 4<25>, 5<25>, 6<25>], cached out: [1<50>], removed: []
assert status() == [
("1", 1, "cached"),
("2", 3, "pending"),
("2_2", 3, "pending"),
("2_3", 3, "pending"),
("3", 2, "pending"),
("3_2", 2, "pending"),
("4", 1, "pending"),
("5", 1, "pending"),
("6", 2, "pending"),
("6_2", 2, "pending"),
], txt_status()
# Check cache count and size before second file
assert pool.current_cache_count == 1, pool.current_cache_count
assert pool.current_cache_size == 50_000, pool.current_cache_size
print("Fetching lazy2_3")
# Now, fetching the second file should still work directly, bypassing the cache.
# It will, however, requeue the fetch of the second file to the background thread for the remaining lazies.
result2_3 = lazy2_3.get()
assert result2_3 == b"b" * 50_000
# They should not be able to finish, because the cache is full
# Queue state: [3<50>, 4<25>, 5<25>, 6<25>, 2<50>], cached out: [1<50>], removed: []
assert status() == [
("1", 1, "cached"),
("2", 2, "pending"),
("2_2", 2, "pending"),
("2_3", 2, "consumed"),
("3", 2, "pending"),
("3_2", 2, "pending"),
("4", 1, "pending"),
("5", 1, "pending"),
("6", 2, "pending"),
("6_2", 2, "pending"),
], txt_status()
# Fetch
result1 = lazy1.get()
assert result1 == b"a" * 50_000
lazy3.entry.send_to_cache_future.result(timeout=1)
time.sleep(0.5)
# Second file is now queued at the end.
# Files 3 and 4 should now be cached.
# Queue state: [5<25>, 6<25>, 2<50>], cached out: [3<50>, 4<25>], removed: [1<50>]
assert status() == [
("1", 0, "consumed"),
("2", 2, "pending"),
("2_2", 2, "pending"),
("2_3", 2, "consumed"),
("3", 2, "cached"),
("3_2", 2, "cached"),
("4", 1, "cached"),
("5", 1, "pending"),
("6", 2, "pending"),
("6_2", 2, "pending"),
], txt_status()
assert pool.current_cache_count == 2
assert pool.current_cache_size == 75_000
result3 = lazy3.get()
assert result3 == b"c" * 50_000
time.sleep(0.5)
# The space used by large_file3 is still occupied in the cache
# Queue state: [5<25>, 6<25>, 2<50>], cached out: [3<50>, 4<25>], removed: [1<50>]
assert status() == [
("1", 0, "consumed"),
("2", 2, "pending"),
("2_2", 2, "pending"),
("2_3", 2, "consumed"),
("3", 1, "consumed"),
("3_2", 1, "cached"),
("4", 1, "cached"),
("5", 1, "pending"),
("6", 2, "pending"),
("6_2", 2, "pending"),
], txt_status()
assert pool.current_cache_count == 2
assert pool.current_cache_size == 75_000
result3_2 = lazy3_2.get()
assert result3_2 == b"c" * 50_000
time.sleep(0.5)
# The space used by large_file3 has been freed; 4, 5, and 6 should fit now, large_file2 not yet
# Queue state: [6<25>, 2<50>], cached out: [5<25>, 4<25>], removed: [1<50>, 3<50>]
assert status() == [
("1", 0, "consumed"),
("2", 2, "pending"),
("2_2", 2, "pending"),
("2_3", 2, "consumed"),
("3", 0, "consumed"),
("3_2", 0, "consumed"),
("4", 1, "cached"),
("5", 1, "cached"),
("6", 2, "pending"),
("6_2", 2, "pending"),
], txt_status()
assert pool.current_cache_count == 2
assert pool.current_cache_size == 50_000
result4 = lazy4.get()
assert result4 == b"d" * 25_000
time.sleep(0.5)
# large_file4 was consumed, freeing its space, so large_file6 could be cached; still no space for large_file2
# Queue state: [2<50>], cached out: [5<25>, 6<25>], removed: [1<50>, 3<50>, 4<25>]
assert status() == [
("1", 0, "consumed"),
("2", 2, "pending"),
("2_2", 2, "pending"),
("2_3", 2, "consumed"),
("3", 0, "consumed"),
("3_2", 0, "consumed"),
("4", 0, "consumed"),
("5", 1, "cached"),
("6", 2, "cached"),
("6_2", 2, "cached"),
], txt_status()
assert pool.current_cache_count == 2
assert pool.current_cache_size == 50_000
result5 = lazy5.get()
assert result5 == b"e" * 25_000
time.sleep(0.5)
# Now large_file2 can be cached
# Queue state: [], cached out: [6<25>, 2<50>], removed: [1<50>, 3<50>, 4<25>, 5<25>]
assert status() == [
("1", 0, "consumed"),
("2", 2, "cached"),
("2_2", 2, "cached"),
("2_3", 2, "consumed"),
("3", 0, "consumed"),
("3_2", 0, "consumed"),
("4", 0, "consumed"),
("5", 0, "consumed"),
("6", 2, "cached"),
("6_2", 2, "cached"),
], txt_status()
assert pool.current_cache_count == 2
assert pool.current_cache_size == 75_000
result6 = lazy6.get()
assert result6 == b"f" * 25_000
# Queue state: [], cached out: [6<25>, 2<50>], removed: [1<50>, 3<50>, 4<25>, 5<25>]
assert status() == [
("1", 0, "consumed"),
("2", 2, "cached"),
("2_2", 2, "cached"),
("2_3", 2, "consumed"),
("3", 0, "consumed"),
("3_2", 0, "consumed"),
("4", 0, "consumed"),
("5", 0, "consumed"),
("6", 1, "consumed"),
("6_2", 1, "cached"),
], txt_status()
assert pool.current_cache_count == 2
assert pool.current_cache_size == 75_000
result2 = lazy2.get()
assert result2 == b"b" * 50_000
# Queue state: [], cached out: [6<25>, 2<50>], removed: [1<50>, 3<50>, 4<25>, 5<25>]
assert status() == [
("1", 0, "consumed"),
("2", 1, "consumed"),
("2_2", 1, "cached"),
("2_3", 1, "consumed"),
("3", 0, "consumed"),
("3_2", 0, "consumed"),
("4", 0, "consumed"),
("5", 0, "consumed"),
("6", 1, "consumed"),
("6_2", 1, "cached"),
], txt_status()
assert pool.current_cache_count == 2
assert pool.current_cache_size == 75_000
result2_2 = lazy2_2.get()
assert result2_2 == b"b" * 50_000
# Cache should only contain large_file6 now
# Queue state: [], cached out: [6<25>], removed: [1<50>, 3<50>, 4<25>, 5<25>, 2<50>]
assert status() == [
("1", 0, "consumed"),
("2", 0, "consumed"),
("2_2", 0, "consumed"),
("2_3", 0, "consumed"),
("3", 0, "consumed"),
("3_2", 0, "consumed"),
("4", 0, "consumed"),
("5", 0, "consumed"),
("6", 1, "consumed"),
("6_2", 1, "cached"),
], txt_status()
assert pool.current_cache_count == 1, txt_status()
assert pool.current_cache_size == 25_000
# Delete the last reference to large_file6, it should be removed from the cache
lazy6_2 = None
gc.collect()
# Cache should be empty now
# Queue state: [], cached out: [], removed: [1<50>, 3<50>, 4<25>, 5<25>, 6<25>, 2<50>]
assert status() == [
("1", 0, "consumed"),
("2", 0, "consumed"),
("2_2", 0, "consumed"),
("2_3", 0, "consumed"),
("3", 0, "consumed"),
("3_2", 0, "consumed"),
("4", 0, "consumed"),
("5", 0, "consumed"),
("6", 0, "consumed"),
], txt_status()
assert pool.current_cache_count == 0, txt_status()
assert pool.current_cache_size == 0
# Check that the cache directory is empty
assert not list(pool.cache_dir.glob("*"))
finally:
pool.close()
def test_raw_method(self):
"""Test the 'raw' caching method with DecodeFileStore"""
pool = FileStoreCachePool(parent_cache_dir=self.temp_path, method="raw")
mock_raw_file_store = MockFileStore(
{
"file1": b"test data 1",
}
)
mock_decode_file_store = DecodeFileStore(
decoder=MockDecoder(),
inner_reader=mock_raw_file_store,
)
try:
# Request lazy loading
lazy_ref = pool.get_lazy(mock_decode_file_store, "file1")
# Wait for background task
time.sleep(0.5)
# Get the data - should be decoded
sample_with_source_info = {"__sources__": []}
result = lazy_ref.get(sample_with_source_info)
assert result == "file1: test data 1"
assert (
sample_with_source_info["__sources__"][0].dataset_path
== mock_decode_file_store.get_path()
)
assert sample_with_source_info["__sources__"][0].index is None
assert sample_with_source_info["__sources__"][0].shard_name is None
assert sample_with_source_info["__sources__"][0].file_names == ("file1",)
finally:
pool.close()
def test_pickle_method(self):
"""Test the 'pickle' caching method"""
pool = FileStoreCachePool(parent_cache_dir=self.temp_path, method="pickle")
mock_raw_file_store = MockFileStore(
{
"file1": b"test data 1",
}
)
mock_decode_file_store = DecodeFileStore(
decoder=MockDecoder(),
inner_reader=mock_raw_file_store,
)
try:
# Request lazy loading
lazy_ref = pool.get_lazy(mock_decode_file_store, "file1")
# Wait for background task
lazy_ref.entry.send_to_cache_future.result()
# Get the data - should be unpickled correctly
sample_with_source_info = {"__sources__": []}
result = lazy_ref.get(sample_with_source_info)
assert result == "file1: test data 1"
assert (
sample_with_source_info["__sources__"][0].dataset_path
== mock_decode_file_store.get_path()
)
assert sample_with_source_info["__sources__"][0].index is None
assert sample_with_source_info["__sources__"][0].shard_name is None
assert sample_with_source_info["__sources__"][0].file_names == ("file1",)
# Request lazy loading
lazy_ref = pool.get_lazy(mock_raw_file_store, "file1")
# Wait for background task
lazy_ref.entry.send_to_cache_future.result()
# Get the data - should be unpickled correctly
sample_with_source_info = {"__sources__": []}
result = lazy_ref.get(sample_with_source_info)
assert result == b"test data 1"
assert (
sample_with_source_info["__sources__"][0].dataset_path
== mock_raw_file_store.get_path()
)
assert sample_with_source_info["__sources__"][0].index is None
assert sample_with_source_info["__sources__"][0].shard_name is None
assert sample_with_source_info["__sources__"][0].file_names == ("file1",)
finally:
pool.close()
def test_concurrent_access(self):
"""Test concurrent access to the cache pool"""
pool = FileStoreCachePool(parent_cache_dir=self.temp_path)
mock_raw_file_store = MockFileStore(
{
"file1": b"test data 1",
}
)
results = []
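# Each worker thread requests the same file via get_lazy and records the returned data.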
def worker(filename):
lazy_ref = pool.get_lazy(mock_raw_file_store, filename)
result, source_info = lazy_ref.get()
results.append(result)
assert source_info.dataset_path == mock_raw_file_store.get_path()
assert source_info.index is None
assert source_info.shard_name is None
assert source_info.file_names == (filename,)
try:
# Start multiple threads accessing the same file
threads = []
for i in range(5):
t = threading.Thread(target=worker, args=("file1",))
threads.append(t)
t.start()
# Wait for all threads to complete
for t in threads:
t.join()
# All threads should get the correct result
for r in results:
assert r == b"test data 1"
finally:
pool.close()
def test_to_cache(self):
"""Test that the cache out method works"""
pool = FileStoreCachePool(parent_cache_dir=self.temp_path)
try:
# Get the data - should be pickled / unpickled correctly
result = pool.to_cache((1, "some_data", 2), "file1")
cache_path = result.cache_path
# Check that the cache file exists
assert cache_path is not None
assert cache_path.is_file()
assert pool.cache_dir == cache_path.parent
# Verify that the data is read correctly, also two times.
assert result.get() == (1, "some_data", 2)
assert result.get() == (1, "some_data", 2)
# Verify that the cache file is deleted now that we've read the data.
assert result.cache_path is None
assert not cache_path.is_file()
# Verify that the cache file is deleted when the object is deleted before reading the file.
result2 = pool.to_cache((1, "some_data", 2), "file2")
assert result2.cache_path is not None
assert result2.cache_path.is_file()
assert result2.cache_path != cache_path
cache_path = result2.cache_path
del result2
gc.collect()
assert not cache_path.is_file()
finally:
pool.close()
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
"""This module defines tests for meta datasets."""
import gc
import json
import logging
import random
import sys
import tempfile
import unittest
import warnings
from collections import Counter
from pathlib import Path
from typing import Iterable
import torch
from click.testing import CliRunner
from megatron.energon import (
Cooker,
CrudeSample,
DefaultTaskEncoder,
Sample,
WorkerConfig,
basic_sample_keys,
edataclass,
get_loader,
get_train_dataset,
stateless,
)
from megatron.energon.tools.prepare import command as prepare_command
from tests.epath_s3_emulator import setup_s3_emulator
# Speed up tests significantly by reducing the torch status check interval for broken worker shutdown
try:
torch.utils.data._utils.worker.MP_STATUS_CHECK_INTERVAL = 0.1
torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL = 0.1
except AttributeError:
pass
@edataclass
class TextSample(Sample):
idx: int
text: str
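# Cooker: converts a crude JSONL sample into a TextSample.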
@stateless()
def cook_text(sample: CrudeSample) -> TextSample:
return TextSample(
**basic_sample_keys(sample), idx=sample["json"]["idx"], text=sample["json"]["txt"]
)
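# Task encoder that registers the cooker above for crude samples.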
class SimpleCookingTaskEncoder(DefaultTaskEncoder):
cookers = [Cooker(cook=cook_text)]
class TestJsonlDataset(unittest.TestCase):
# Set up the test fixture
def setUp(self):
random.seed(42)
logging.basicConfig(stream=sys.stderr, level=logging.INFO)
warnings.simplefilter("ignore", ResourceWarning)
# Create a temporary directory
self.temp_dir = tempfile.TemporaryDirectory()
self.dataset_path = Path(self.temp_dir.name)
# self.dataset_path = Path("./test_dataset")
self.dataset_path.mkdir(exist_ok=True, parents=True)
# Create small dummy datasets
self.create_text_test_dataset(self.dataset_path / "ds1.jsonl", range(55), range(55))
self.create_text_test_dataset(
self.dataset_path / "ds2.jsonl", range(100, 155), range(100, 155)
)
self.create_text_test_dataset(self.dataset_path / "ds3.jsonl", range(200, 255), range(55))
self.mds_all_path = self.dataset_path / "metadataset_all.yaml"
with open(self.mds_all_path, "w") as f:
f.write(
"\n".join(
[
"__module__: megatron.energon",
"__class__: MetadatasetV2",
"splits:",
" train:",
" blend:",
" - path: ds1.jsonl",
" subflavors:",
" ds: ds1",
" - path: ds2.jsonl",
" subflavors:",
" ds: ds2",
" - path: ds3.jsonl",
" subflavors:",
" ds: ds3",
]
)
)
def tearDown(self):
# Remove all temporary files
gc.collect()
self.temp_dir.cleanup()
@staticmethod
def create_text_test_dataset(
path: Path, txt_range: Iterable[int], key_range: Iterable[int], prefix: str = ""
):
"""Creates a small dummy test dataset for testing purposes."""
# Write jsonl file
with open(path, "w") as wf:
for key, txt in zip(key_range, txt_range):
# Write JSON entries to the file, one per line.
wf.write(json.dumps({"idx": key, "txt": f"{prefix}{txt}"}) + "\n")
from megatron.energon.flavors import CrudeJsonlDatasetFactory
CrudeJsonlDatasetFactory.prepare_dataset(path)
def test_dataset(self):
torch.manual_seed(42)
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=0,
seed_offset=42,
)
# Train mode dataset
train_dataset = get_train_dataset(
self.dataset_path / "ds1.jsonl",
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
task_encoder=SimpleCookingTaskEncoder(),
)
print(len(train_dataset))
assert len(train_dataset) == 55, f"Expected 55 samples, got {len(train_dataset)}"
train_loader1 = get_loader(train_dataset)
train_order1 = [
text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text
]
print(train_order1[:10])
print(Counter(train_order1))
assert len(Counter(train_order1)) == 55
assert all(v == 10 for v in Counter(train_order1).values())
def test_metadataset_all(self):
torch.manual_seed(42)
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=2,
seed_offset=42,
)
# Train mode dataset
train_dataset = get_train_dataset(
self.mds_all_path,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
task_encoder=SimpleCookingTaskEncoder(),
)
print(len(train_dataset))
assert len(train_dataset) == 55 * 3, f"Expected 55 * 3 samples, got {len(train_dataset)}"
train_loader1 = get_loader(train_dataset)
train_order1 = [
text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text
]
print(train_order1[:10])
print(Counter(train_order1))
assert len(Counter(train_order1)) == 55 * 3
assert all(2 <= v <= 5 for v in Counter(train_order1).values())
def test_metadataset_multirank(self):
torch.manual_seed(42)
sample_counts = Counter()
expected_lens = [19, 19, 17]
for cur_rank in range(3):
worker_config = WorkerConfig(
rank=cur_rank,
world_size=3,
num_workers=5,
seed_offset=42,
)
# Train mode dataset
train_dataset = get_train_dataset(
self.dataset_path / "ds1.jsonl",
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
task_encoder=SimpleCookingTaskEncoder(),
repeat=False,
)
print(len(train_dataset))
assert len(train_dataset) == expected_lens[cur_rank], (
f"Expected {expected_lens[cur_rank]} samples, got {len(train_dataset)}"
)
train_loader1 = get_loader(train_dataset)
for data in train_loader1:
sample_counts[int(data.text[0])] += 1
for i in range(55):
assert sample_counts[i] == 1, (
f"Sample {i} should have been seen exactly once, but was seen {sample_counts[i]} times."
)
def test_s3(self):
# Create a metadataset configuration that references the dataset stored on S3 via MSC
mixed_mds_path = self.dataset_path / "metadataset_mixed.yaml"
with open(mixed_mds_path, "w") as f:
f.write(
"\n".join(
[
"__module__: megatron.energon",
"__class__: MetadatasetV2",
"splits:",
" train:",
" path: msc://s3test_jsonl_dataset/test/dataset/metadataset_all.yaml",
]
)
)
with setup_s3_emulator(profile_name="s3test_jsonl_dataset") as emu:
# Upload the dataset to the S3 emulator
# EPath(self.dataset_path).copy(EPath("msc://s3/test/dataset"))
emu.add_file(self.dataset_path, "test/dataset")
train_dataset = get_loader(
get_train_dataset(
mixed_mds_path,
worker_config=WorkerConfig(
rank=0,
world_size=1,
num_workers=2,
),
batch_size=1,
shuffle_buffer_size=10,
max_samples_per_sequence=None,
virtual_epoch_length=55 * 10,
task_encoder=SimpleCookingTaskEncoder(),
)
)
data = list(enumerate(train_dataset))
assert len(data) == 55 * 10, len(data)
cnt = Counter(t for _, entry in data for t in entry.text)
assert len(cnt) == 55 * 3
assert all(2 <= v <= 5 for v in cnt.values())
def test_prepare(self):
print("Creating new dataset")
with open(self.dataset_path / "ds_prep.jsonl", "w") as f:
for i in range(10):
f.write(json.dumps({"idx": i, "txt": f"{i}"}) + "\n\n")
runner = CliRunner()
result = runner.invoke(
prepare_command,
[str(self.dataset_path / "ds_prep.jsonl")],
catch_exceptions=False,
)
print(result.stdout)
assert result.exit_code == 0, "Prepare failed, see output"
assert "Done" in result.stdout, "Prepare failed, see output"
assert "Found 10 samples" in result.stdout, "Prepare failed, see output"
assert (self.dataset_path / "ds_prep.jsonl.idx").exists()
torch.manual_seed(42)
# Train mode dataset
train_loader = get_loader(
get_train_dataset(
self.dataset_path / "ds_prep.jsonl",
worker_config=WorkerConfig(
rank=0,
world_size=1,
num_workers=0,
seed_offset=42,
),
batch_size=1,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
task_encoder=SimpleCookingTaskEncoder(),
)
)
assert len(train_loader) == 10, f"Expected 10 samples, got {len(train_loader)}"
train_order1 = [text for _, data in zip(range(50), train_loader) for text in data.text]
print(train_order1[:10])
print(Counter(train_order1))
assert len(Counter(train_order1)) == 10
assert all(v == 5 for v in Counter(train_order1).values())
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
"""This module defines tests for meta datasets."""
import gc
import logging
import sys
import tempfile
import time
import unittest
import warnings
from collections import Counter
from pathlib import Path
from typing import Any, Iterable
import torch
import webdataset as wds
from megatron.energon import (
SavableDataset,
WorkerConfig,
get_loader,
get_savable_loader,
get_train_dataset,
get_val_dataset,
load_dataset,
)
from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME
from megatron.energon.metadataset.loader_interface import DatasetBlendMode
from megatron.energon.wrappers.blend_dataset import BlendDataset
# Speed up tests significantly by reducing the torch status check interval for broken worker shutdown
try:
torch.utils.data._utils.worker.MP_STATUS_CHECK_INTERVAL = 0.1
torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL = 0.1
except AttributeError:
pass
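# Truncate long RNG state blobs (bytes/str/sequences) so printed loader states stay readable.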
def _norng_state(state):
if isinstance(state, bytes):
if len(state) > 100:
return state[:5] + f"...<len={len(state)}>".encode()
return state
elif isinstance(state, str):
if len(state) > 100:
return state[:5] + f"...<len={len(state)}>"
return state
elif isinstance(state, dict):
return {k: _norng_state(v) for k, v in state.items()}
elif isinstance(state, (list, tuple)):
if len(state) > 100:
state = state[:5]
return type(state)(_norng_state(v) for v in state)
else:
return state
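# Walk the wrapper chain (each wrapper exposes .dataset) down to the underlying BlendDataset.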
def get_blend_dataset(ds: SavableDataset):
if isinstance(ds, BlendDataset):
return ds
else:
if hasattr(ds, "dataset"):
return get_blend_dataset(ds.dataset)
else:
raise ValueError("No blend dataset found")
def assert_nested_equal(a: Any, b: Any, path: str = "") -> None:
"""
Recursively checks that two nested data structures (consisting of dicts, lists, tuples,
and other basic types) are equal. If they are not equal, prints the path of the first mismatch
and raises an AssertionError.
Args:
a: First nested structure to compare.
b: Second nested structure to compare.
path: Internal parameter used to pass the current traversal path (do not set this manually).
Raises:
AssertionError: If a mismatch is found.
"""
# Check if types differ
if type(a) is not type(b):
mismatch_details = f"Type mismatch at {path or '<root>'}: {type(a)} != {type(b)}"
print(mismatch_details)
raise AssertionError(mismatch_details)
# If they are both dictionaries, compare each key and value
if isinstance(a, dict):
# Check if they have the same keys
a_keys = set(a.keys())
b_keys = set(b.keys())
if a_keys != b_keys:
missing_in_a = b_keys - a_keys
missing_in_b = a_keys - b_keys
mismatch_details = (
f"Key mismatch at {path or '<root>'}:\n"
+ "Missing in first object: "
+ ", ".join(f"[{k}]={b[k]!r}" for k in missing_in_a)
+ "\n"
+ "Missing in second object: "
+ ", ".join(f"[{k}]={a[k]!r}" for k in missing_in_b)
+ "\n"
)
print(mismatch_details)
raise AssertionError(mismatch_details)
for key in a:
sub_path = f"{path}['{key}']" if path else f"['{key}']"
assert_nested_equal(a[key], b[key], sub_path)
# If they are lists (or tuples), compare elements in order
elif isinstance(a, (list, tuple)):
if len(a) != len(b):
mismatch_details = f"Length mismatch at {path or '<root>'}: {len(a)} != {len(b)}"
print(mismatch_details)
raise AssertionError(mismatch_details)
for index, (item_a, item_b) in enumerate(zip(a, b)):
sub_path = f"{path}[{index}]" if path else f"[{index}]"
assert_nested_equal(item_a, item_b, sub_path)
# Otherwise, compare values directly
else:
if a != b:
mismatch_details = f"Value mismatch at {path or '<root>'}: {repr(a)} != {repr(b)}"
print(mismatch_details)
raise AssertionError(mismatch_details)
class TestDataset(unittest.TestCase):
# Set up the test fixture
def setUp(self):
logging.basicConfig(stream=sys.stderr, level=logging.INFO)
warnings.simplefilter("ignore", ResourceWarning)
# Create a temporary directory
self.temp_dir = tempfile.TemporaryDirectory()
self.dataset_path = Path(self.temp_dir.name)
# self.dataset_path = Path("./test_dataset")
self.dataset_path.mkdir(exist_ok=True, parents=True)
(self.dataset_path / "ds1").mkdir(exist_ok=True, parents=True)
(self.dataset_path / "ds2").mkdir(exist_ok=True, parents=True)
# Create small dummy text datasets
self.create_text_test_dataset(self.dataset_path / "ds1", range(55), range(55))
self.create_text_test_dataset(self.dataset_path / "ds2", range(100, 155), range(100, 155))
self.create_text_test_dataset(self.dataset_path / "ds3", range(200, 255), range(0, 55))
self.mds_path = self.dataset_path / "metadataset.yaml"
with open(self.mds_path, "w") as f:
f.write(
"\n".join(
[
"__module__: megatron.energon",
"__class__: Metadataset",
"splits:",
" train:",
" datasets:",
" - weight: 1",
" path: ds1",
" subflavor: ds1",
" subflavors:",
" source: metadataset.yaml",
" number: 43",
" mds: mds",
" shuffle_over_epochs_multiplier: 3",
" - weight: 1",
" path: ds2",
" subflavor: ds2",
" subflavors:",
" source: metadataset.yaml",
" number: 44",
" mds: mds",
" val:",
" datasets:",
" - weight: 1",
" path: ds1",
" split_part: train",
" - weight: 1",
" path: ds2",
" split_part: train",
]
)
)
self.nested_mds_path = self.dataset_path / "nested_metadataset.yaml"
with open(self.nested_mds_path, "w") as f:
f.write(
"\n".join(
[
"splits:",
" train:",
" datasets:",
" - weight: 4",
" path: ./metadataset.yaml",
" split_part: train",
" subflavor: train",
" subflavors:",
" source: nested_metadataset.yaml",
" mds: nested_train",
" - path: ./metadataset.yaml",
" split_part: val",
" subflavors:",
" source: nested_metadataset.yaml",
" mds: nested_val",
]
)
)
print(self.dataset_path)
def tearDown(self):
# Remove all temporary files
gc.collect()
self.temp_dir.cleanup()
@staticmethod
def create_text_test_dataset(path: Path, txt_range: Iterable[int], key_range: Iterable[int]):
"""Creates a small dummy test dataset for testing purposes."""
# Create unique text samples
(path / "parts").mkdir(exist_ok=True, parents=True)
# Initialize the ShardWriter
with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=10) as shard_writer:
for key, txt in zip(key_range, txt_range):
# Write individual files to shards
shard_writer.write(
{
"__key__": f"{key:06d}",
"txt": f"{txt}".encode(),
},
)
total_shards = shard_writer.shard
from megatron.energon.flavors import BaseWebdatasetFactory
BaseWebdatasetFactory.prepare_dataset(
path,
[f"parts/data-{{0..{total_shards - 1}}}.tar"],
split_parts_ratio=[("train", 1.0)],
shuffle_seed=None,
)
with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f:
f.write(
"\n".join(
[
"sample_type:",
" __module__: megatron.energon",
" __class__: TextSample",
"field_map:",
" text: txt",
"subflavors:",
" source: dataset.yaml",
" dataset.yaml: true",
" number: 42",
]
)
)
def test_metadataset(self):
torch.manual_seed(42)
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=0,
seed_offset=42,
)
# Train mode dataset
train_dataset = get_train_dataset(
self.mds_path,
worker_config=worker_config,
batch_size=10,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
)
print(len(train_dataset))
assert len(train_dataset) == 11
train_loader1 = get_loader(train_dataset)
train_order1 = [
text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text
]
print(train_order1[:10])
print(Counter(train_order1))
assert len(Counter(train_order1)) == 110
assert all(48 <= v <= 52 for v in Counter(train_order1).values())
train_subflavors = [
subflavor["__subflavor__"]
for idx, data in zip(range(55), train_loader1)
for subflavor in data.__subflavors__
]
print(train_subflavors[:10])
print(Counter(train_subflavors))
assert len(Counter(train_subflavors)) == 2
assert all(250 <= v <= 300 for v in Counter(train_subflavors).values())
# Train mode dataset
train_dataset = get_train_dataset(
self.mds_path,
worker_config=worker_config,
batch_size=10,
shuffle_buffer_size=25,
max_samples_per_sequence=25,
)
print(len(train_dataset))
assert len(train_dataset) == 11
train_loader1 = get_loader(train_dataset)
train_order1 = [
text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text
]
print(train_order1[:10])
print(Counter(train_order1))
assert len(Counter(train_order1)) == 110
assert all(48 <= v <= 52 for v in Counter(train_order1).values())
# Val mode dataset
val_dataset = get_val_dataset(self.mds_path, worker_config=worker_config, batch_size=10)
print(len(val_dataset))
assert len(val_dataset) == 11
val_loader1 = get_loader(val_dataset)
val_order1 = [text for data in val_loader1 for text in data.text]
assert len(val_order1) == 110
print(Counter(val_order1))
assert all(v == 1 for v in Counter(val_order1).values())
def test_nested_metadataset(self):
torch.manual_seed(42)
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=0,
)
dataset = load_dataset(self.nested_mds_path)
raw_datasets = dataset.get_datasets(
training=False, split_part="train", worker_config=worker_config
)
assert raw_datasets.blend_mode == DatasetBlendMode.DATASET_WEIGHT
assert [raw_dataset.weight for raw_dataset in raw_datasets.datasets] == [0.4, 0.4, 0.1, 0.1]
assert [raw_dataset.dataset.paths[0].name for raw_dataset in raw_datasets.datasets] == [
"ds1",
"ds2",
"ds1",
"ds2",
]
print([raw_dataset.dataset.subflavors for raw_dataset in raw_datasets.datasets])
assert [raw_dataset.dataset.subflavors for raw_dataset in raw_datasets.datasets] == [
{
"source": "nested_metadataset.yaml",
"dataset.yaml": True,
"number": 43,
"mds": "nested_train",
"__subflavor__": "train",
},
{
"source": "nested_metadataset.yaml",
"dataset.yaml": True,
"number": 44,
"mds": "nested_train",
"__subflavor__": "train",
},
{
"source": "nested_metadataset.yaml",
"dataset.yaml": True,
"number": 42,
"mds": "nested_val",
},
{
"source": "nested_metadataset.yaml",
"dataset.yaml": True,
"number": 42,
"mds": "nested_val",
},
]
# Train mode dataset
train_dataset = get_train_dataset(
self.nested_mds_path,
worker_config=worker_config,
batch_size=10,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
)
print(len(train_dataset))
assert len(train_dataset) == 22
train_loader1 = get_loader(train_dataset)
train_order1 = [
text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text
]
print(train_order1[:10])
print(Counter(train_order1))
assert len(Counter(train_order1)) == 110
assert all(48 <= v <= 53 for v in Counter(train_order1).values())
train_subflavors = [
subflavor.get("__subflavor__")
for idx, data in zip(range(55), train_loader1)
for subflavor in data.__subflavors__
]
cnt = Counter(train_subflavors)
print(train_subflavors[:10])
print(cnt)
avg = 55 * 10 / 5
assert len(Counter(train_subflavors)) == 2
assert avg * 4 - 40 < cnt["train"] < avg * 4 + 40
assert avg - 10 < cnt[None] < avg + 10
train_subflavorss = [
tuple(subflavor.items())
for idx, data in zip(range(55), train_loader1)
for subflavor in data.__subflavors__
]
cnt = Counter(train_subflavorss)
print(train_subflavorss[:10])
print(cnt)
assert len(Counter(train_subflavorss)) == 3
assert (
avg * 2 - 20
< cnt[
(
("source", "nested_metadataset.yaml"),
("dataset.yaml", True),
("number", 43),
("__subflavor__", "train"),
("mds", "nested_train"),
)
]
< avg * 2 + 20
)
assert (
avg * 2 - 20
< cnt[
(
("source", "nested_metadataset.yaml"),
("dataset.yaml", True),
("number", 44),
("__subflavor__", "train"),
("mds", "nested_train"),
)
]
< avg * 2 + 20
)
assert (
avg * 1 - 20
< cnt[
(
("source", "nested_metadataset.yaml"),
("dataset.yaml", True),
("number", 42),
("mds", "nested_val"),
)
]
< avg * 1 + 20
)
# Train mode dataset
train_dataset = get_train_dataset(
self.mds_path,
worker_config=worker_config,
batch_size=10,
shuffle_buffer_size=25,
max_samples_per_sequence=25,
)
print(len(train_dataset))
assert len(train_dataset) == 11
train_loader1 = get_loader(train_dataset)
train_order1 = [
text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text
]
print(train_order1[:10])
print(Counter(train_order1))
assert len(Counter(train_order1)) == 110
assert all(48 <= v <= 52 for v in Counter(train_order1).values())
# Val mode dataset
val_dataset = get_val_dataset(self.mds_path, worker_config=worker_config, batch_size=10)
print(len(val_dataset))
assert len(val_dataset) == 11
val_loader1 = get_loader(val_dataset)
val_order1 = [text for data in val_loader1 for text in data.text]
assert len(val_order1) == 110
print(Counter(val_order1))
assert all(v == 1 for v in Counter(val_order1).values())
def test_worker_sample_balance(self):
torch.manual_seed(42)
for num_workers in [6, 30]:
samples_per_global_worker = Counter()
for rank in range(2):
wc = WorkerConfig(
rank=rank,
world_size=2,
num_workers=num_workers,
)
train_dataset = get_train_dataset(
self.nested_mds_path,
worker_config=wc,
batch_size=1,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
)
blend_dataset = get_blend_dataset(train_dataset)
assert isinstance(blend_dataset, BlendDataset)
ds_weights = blend_dataset.dataset_weights
assert len(ds_weights) == 4 # 4 datasets
# We are now going to count the number of samples that was assigned to each
# globally unique worker. This corresponds to the shard_ranges that energon
# prints out when the dataset is built.
for ds, w in ds_weights:
worker_slice_offsets = ds.dataset.dataset.workers_slice_offsets
assert len(worker_slice_offsets) == num_workers
for worker_idx, slice_offsets in enumerate(worker_slice_offsets):
samples_per_global_worker[(rank, worker_idx)] += (
slice_offsets[-1] - slice_offsets[0]
)
print(samples_per_global_worker)
# Check that the sample assignment is balanced across all global workers
if num_workers == 6:
assert list(samples_per_global_worker.values()) == [
19, # rank 0
18,
18,
19,
18,
18,
19, # rank 1
18,
18,
19,
18,
18,
]
elif num_workers == 30:
# This should match the pattern of the first 40 items of a generalized bit
# reversal sequence of length 60.
# With 4 * 55 = 220 samples distributed over 60 workers, 220 % 60 = 40 samples remain.
assert list(samples_per_global_worker.values()) == [
4,
4,
4,
4,
3,
4,
3,
4,
4,
4,
3,
4,
3,
4,
3,
4,
4,
4,
4,
3,
4,
3,
4,
4,
4,
3,
4,
3,
4,
3,
4,
4,
4,
4,
3,
4,
3,
4,
4,
4,
3,
4,
3,
4,
3,
4,
4,
4,
4,
3,
4,
3,
4,
4,
4,
3,
4,
3,
4,
3,
]
def test_save_restore_state_train(self):
torch.manual_seed(42)
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=0,
seed_offset=42,
)
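# Build a fresh loader with an identical configuration so that saved states can be restored and compared.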
def new_loader():
return get_savable_loader(
get_train_dataset(
self.mds_path,
worker_config=worker_config,
batch_size=10,
parallel_shard_iters=2,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
shuffle_over_epochs_multiplier=2,
),
)
# Train mode dataset
loader = new_loader()
state_0 = loader.save_state_rank()
order_0 = [data.text for idx, data in zip(range(10), loader)]
state_1 = loader.save_state_rank()
# print("save state done")
order_1 = [data.text for idx, data in zip(range(20), loader)]
state_2 = loader.save_state_rank()
# print("save state done")
# Iterated 30 samples, afterwards 50 samples. Checkpoint should be around that
order_2 = [data.text for idx, data in zip(range(20), loader)]
state_3 = loader.save_state_rank()
# print("save state done")
# Iterated 50 samples, afterwards 53 samples. Checkpoint should be around that
order_3 = [data.text for idx, data in zip(range(3), loader)]
state_4 = loader.save_state_rank()
# print("save state done")
# Dataset size is 55, want to save one sample before end of epoch
# Iterated 53 samples, afterwards 54 samples. Checkpoint should be around that
order_4 = [data.text for idx, data in zip(range(1), loader)]
state_5 = loader.save_state_rank()
# print("save state done")
# Dataset size is 55, want to save one sample before end of epoch
# Iterated 54 samples, afterwards 55 samples. Checkpoint should be around that
order_5 = [data.text for idx, data in zip(range(1), loader)]
state_6 = loader.save_state_rank()
# print("save state done")
# Dataset size is 55, want to save one sample before end of epoch
# Iterated 55 samples, afterwards 75 samples. Checkpoint should be around that
order_6 = [data.text for idx, data in zip(range(70), loader)]
loader = new_loader()
print("state_1:", _norng_state(state_1))
loader.restore_state_rank(state_1)
order_1_rest = [data.text for idx, data in zip(range(len(order_1)), loader)]
assert order_1 == order_1_rest
loader = new_loader()
loader.restore_state_rank(state_0)
order_0_rest = [data.text for idx, data in zip(range(len(order_0)), loader)]
assert order_0 == order_0_rest
loader = new_loader()
print("state_2:", _norng_state(state_2))
loader.restore_state_rank(state_2)
order_2_rest = [data.text for idx, data in zip(range(len(order_2)), loader)]
print("order_2:", order_2)
print("order_2_rest:", order_2_rest)
assert order_2 == order_2_rest
loader = new_loader()
print("state_3:", _norng_state(state_3))
loader.restore_state_rank(state_3)
order_3_rest = [data.text for idx, data in zip(range(len(order_3)), loader)]
print("order_3:", order_3)
print("order_3_rest:", order_3_rest)
assert order_3 == order_3_rest
loader = new_loader()
print("state_4:", _norng_state(state_4))
loader.restore_state_rank(state_4)
order_4_rest = [data.text for idx, data in zip(range(len(order_4)), loader)]
print("order_4:", order_4)
print("order_4_rest:", order_4_rest)
assert order_4 == order_4_rest
loader = new_loader()
print("state_5:", _norng_state(state_5))
loader.restore_state_rank(state_5)
order_5_rest = [data.text for idx, data in zip(range(len(order_5)), loader)]
print("order_5:", order_5)
print("order_5_rest:", order_5_rest)
assert order_5 == order_5_rest
loader = new_loader()
print("state_6:", _norng_state(state_6))
loader.restore_state_rank(state_6)
order_6_rest = [data.text for idx, data in zip(range(len(order_6)), loader)]
print("order_6:", order_6)
print("order_6_rest:", order_6_rest)
assert order_6 == order_6_rest
wrk_cfg = worker_config.config()
assert wrk_cfg == {
"rank": 0,
"world_size": 1,
"num_workers": 0,
"data_parallel_group": None,
}
print("loader.config():")
print(loader.config())
print()
reference_config = {
"type": "SavableDataLoader",
"num_workers": 0,
"persistent_workers": False,
"pin_memory": True,
"prefetch_factor": None,
"dataset": {
"type": "MapDataset",
"dataset": {
"type": "BatchDataset",
"batch_size": 10,
"batcher": "megatron.energon.task_encoder.base.DefaultTaskEncoder.batch",
"batcher_stateless": True,
"drop_last": False,
"error_handler": "megatron.energon.wrappers._log_exception.log_exception",
"worker_config": wrk_cfg,
"dataset": {
"type": "MapDataset",
"dataset": {
"type": "BlendDataset",
"dataset_weights": [
(
{
"type": "RepeatDataset",
"dataset": {
"type": "MapDataset",
"dataset": {
"type": "WebdatasetSampleLoaderDataset",
"joins": 1,
"len": 55,
"slice_offsets": [[0, 10, 20, 30, 40, 50, 55]],
"worker_config": wrk_cfg,
"shuffle_over_epochs": 6,
"parallel_slice_iters": 2,
},
"map_fn": "megatron.energon.flavors.webdataset.base_webdataset.BaseWebdatasetFactory._load_sample_raw",
"map_fn_config": {
"type": "StandardWebdatasetFactory",
"training": True,
"_path": str(self.dataset_path / "ds1"),
"shards": [
{
"name": "parts/data-0.tar",
"count": 10,
"_path": str(
self.dataset_path
/ "ds1/parts/data-0.tar"
),
},
{
"name": "parts/data-1.tar",
"count": 10,
"_path": str(
self.dataset_path
/ "ds1/parts/data-1.tar"
),
},
{
"name": "parts/data-2.tar",
"count": 10,
"_path": str(
self.dataset_path
/ "ds1/parts/data-2.tar"
),
},
{
"name": "parts/data-3.tar",
"count": 10,
"_path": str(
self.dataset_path
/ "ds1/parts/data-3.tar"
),
},
{
"name": "parts/data-4.tar",
"count": 10,
"_path": str(
self.dataset_path
/ "ds1/parts/data-4.tar"
),
},
{
"name": "parts/data-5.tar",
"count": 5,
"_path": str(
self.dataset_path
/ "ds1/parts/data-5.tar"
),
},
],
"sample_excludes": [],
"shuffle_over_epochs": 6,
"parallel_shard_iters": 2,
"max_samples_per_sequence": None,
"subset": None,
"subflavors": {
"source": "metadataset.yaml",
"dataset.yaml": True,
"number": 43,
"mds": "mds",
"__subflavor__": "ds1",
},
"sample_loader": "megatron.energon.flavors.webdataset.default_generic_webdataset.DefaultGenericWebdatasetFactory.__init__.<locals>.<lambda>",
"image_decode": "torchrgb",
"av_decode": "AVDecoder",
"video_decode_audio": False,
"guess_content": False,
},
"map_fn_stateless": True,
},
"repeats": None,
"worker_config": wrk_cfg,
},
0.5,
),
(
{
"type": "RepeatDataset",
"dataset": {
"type": "MapDataset",
"dataset": {
"type": "WebdatasetSampleLoaderDataset",
"joins": 1,
"len": 55,
"slice_offsets": [[0, 10, 20, 30, 40, 50, 55]],
"worker_config": wrk_cfg,
"shuffle_over_epochs": 2,
"parallel_slice_iters": 2,
},
"map_fn": "megatron.energon.flavors.webdataset.base_webdataset.BaseWebdatasetFactory._load_sample_raw",
"map_fn_config": {
"type": "StandardWebdatasetFactory",
"training": True,
"_path": str(self.dataset_path / "ds2"),
"shards": [
{
"name": "parts/data-0.tar",
"count": 10,
"_path": str(
self.dataset_path
/ "ds2/parts/data-0.tar"
),
},
{
"name": "parts/data-1.tar",
"count": 10,
"_path": str(
self.dataset_path
/ "ds2/parts/data-1.tar"
),
},
{
"name": "parts/data-2.tar",
"count": 10,
"_path": str(
self.dataset_path
/ "ds2/parts/data-2.tar"
),
},
{
"name": "parts/data-3.tar",
"count": 10,
"_path": str(
self.dataset_path
/ "ds2/parts/data-3.tar"
),
},
{
"name": "parts/data-4.tar",
"count": 10,
"_path": str(
self.dataset_path
/ "ds2/parts/data-4.tar"
),
},
{
"name": "parts/data-5.tar",
"count": 5,
"_path": str(
self.dataset_path
/ "ds2/parts/data-5.tar"
),
},
],
"sample_excludes": [],
"shuffle_over_epochs": 2,
"parallel_shard_iters": 2,
"max_samples_per_sequence": None,
"subset": None,
"subflavors": {
"source": "metadataset.yaml",
"dataset.yaml": True,
"number": 44,
"mds": "mds",
"__subflavor__": "ds2",
},
"sample_loader": "megatron.energon.flavors.webdataset.default_generic_webdataset.DefaultGenericWebdatasetFactory.__init__.<locals>.<lambda>",
"image_decode": "torchrgb",
"av_decode": "AVDecoder",
"video_decode_audio": False,
"guess_content": False,
},
"map_fn_stateless": True,
},
"repeats": None,
"worker_config": wrk_cfg,
},
0.5,
),
],
"worker_config": wrk_cfg,
},
"map_fn": "megatron.energon.task_encoder.base.DefaultTaskEncoder.encode_sample",
"map_fn_stateless": True,
},
},
"map_fn": "megatron.energon.task_encoder.base.DefaultTaskEncoder.encode_batch",
"map_fn_stateless": True,
},
}
print("Comparing dataset configs in test_save_restore_state_train.")
assert_nested_equal(loader.config(), reference_config)
def test_save_restore_state_train_workers(self):
torch.manual_seed(42)
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=1,
seed_offset=42,
)
def new_loader():
return get_savable_loader(
get_train_dataset(
self.mds_path,
worker_config=worker_config,
batch_size=10,
parallel_shard_iters=2,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
),
checkpoint_every_sec=0.5,
checkpoint_every_min_n_samples=1,
)
# Train mode dataset
loader = new_loader()
state_0 = loader.save_state_rank()
order_0 = [data.text for idx, data in zip(range(10), loader)]
time.sleep(0.5)
state_1 = loader.save_state_rank()
# print("save state done")
order_1 = [data.text for idx, data in zip(range(20), loader)]
# Ensure a checkpoint is created on next()
time.sleep(1.5)
state_2 = loader.save_state_rank()
# print("save state done")
# Iterated 30 samples, afterwards 50 samples. Checkpoint should be around that
order_2 = [data.text for idx, data in zip(range(20), loader)]
state_3 = loader.save_state_rank()
# print("save state done")
# Iterated 50 samples, afterwards 53 samples. Checkpoint should be around that
order_3 = [data.text for idx, data in zip(range(3), loader)]
# Ensure a checkpoint is created on next()
time.sleep(1.5)
state_4 = loader.save_state_rank()
# print("save state done")
# Dataset size is 55, want to save one sample before end of epoch
# Iterated 1 sample, afterwards 54 samples. Checkpoint should be around that
order_4 = [data.text for idx, data in zip(range(1), loader)]
# Ensure a checkpoint is created on next()
time.sleep(1.5)
state_5 = loader.save_state_rank()
# print("save state done")
# Dataset size is 55, want to save one sample before end of epoch
# Iterated 1 sample, afterwards 55 samples. Checkpoint should be around that
order_5 = [data.text for idx, data in zip(range(1), loader)]
# Ensure a checkpoint is created on next()
time.sleep(1.5)
state_6 = loader.save_state_rank()
# print("save state done")
# Dataset size is 55, want to save one sample before end of epoch
# Iterated 1 sample, afterwards 55 samples. Checkpoint should be around that
order_6 = [data.text for idx, data in zip(range(10), loader)]
loader = new_loader()
print("state_1:", _norng_state(state_1))
loader.restore_state_rank(state_1)
order_1_rest = [data.text for idx, data in zip(range(len(order_1)), loader)]
print("order_1:", order_1)
print("order_1_rest:", order_1_rest)
assert order_1 == order_1_rest
loader = new_loader()
loader.restore_state_rank(state_0)
order_0_rest = [data.text for idx, data in zip(range(len(order_0)), loader)]
assert order_0 == order_0_rest
loader = new_loader()
print("state_2:", _norng_state(state_2))
loader.restore_state_rank(state_2)
order_2_rest = [data.text for idx, data in zip(range(len(order_2)), loader)]
print("order_2:", order_2)
print("order_2_rest:", order_2_rest)
assert order_2 == order_2_rest
loader = new_loader()
print("state_3:", _norng_state(state_3))
loader.restore_state_rank(state_3)
order_3_rest = [data.text for idx, data in zip(range(len(order_3)), loader)]
print("order_3:", order_3)
print("order_3_rest:", order_3_rest)
assert order_3 == order_3_rest
loader = new_loader()
print("state_4:", _norng_state(state_4))
loader.restore_state_rank(state_4)
order_4_rest = [data.text for idx, data in zip(range(len(order_4)), loader)]
print("order_4:", order_4)
print("order_4_rest:", order_4_rest)
assert order_4 == order_4_rest
loader = new_loader()
print("state_5:", _norng_state(state_5))
loader.restore_state_rank(state_5)
order_5_rest = [data.text for idx, data in zip(range(len(order_5)), loader)]
print("order_5:", order_5)
print("order_5_rest:", order_5_rest)
assert order_5 == order_5_rest
loader = new_loader()
print("state_6:", _norng_state(state_6))
loader.restore_state_rank(state_6)
order_6_rest = [data.text for idx, data in zip(range(len(order_6)), loader)]
print("order_6:", order_6)
print("order_6_rest:", order_6_rest)
assert order_6 == order_6_rest
def test_save_restore_state_train_epochize_workers(self):
torch.manual_seed(42)
psi = 2
vel = 19
sbs = 10
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=2,
seed_offset=42,
)
# Train mode dataset
torch.manual_seed(42)
loader = get_savable_loader(
get_train_dataset(
self.mds_path,
worker_config=worker_config,
batch_size=1,
parallel_shard_iters=psi,
virtual_epoch_length=vel,
shuffle_buffer_size=sbs,
max_samples_per_sequence=sbs,
),
)
state_0 = loader.save_state_rank()
order_1 = [data.text[0] for data in loader]
state_1 = loader.save_state_rank()
order_2 = [data.text[0] for data in loader]
state_2 = loader.save_state_rank()
order_3 = [data.text[0] for idx, data in zip(range(17), loader)]
torch.manual_seed(42)
loader = get_savable_loader(
get_train_dataset(
self.mds_path,
worker_config=worker_config,
batch_size=1,
parallel_shard_iters=psi,
virtual_epoch_length=vel,
shuffle_buffer_size=sbs,
max_samples_per_sequence=sbs,
),
)
print("state_0:", _norng_state(state_0))
loader.restore_state_rank(state_0)
order_5 = [data.text[0] for data in loader]
print("order_1:", order_1)
print("order_5:", order_5)
assert order_1 == order_5
torch.manual_seed(42)
loader = get_savable_loader(
get_train_dataset(
self.mds_path,
worker_config=worker_config,
batch_size=1,
parallel_shard_iters=psi,
virtual_epoch_length=vel,
shuffle_buffer_size=sbs,
max_samples_per_sequence=sbs,
),
)
print("state_1:", _norng_state(state_1))
loader.restore_state_rank(state_1)
order_6 = [data.text[0] for data in loader]
print("order_2:", order_2)
print("order_6:", order_6)
assert order_2 == order_6
torch.manual_seed(42)
loader = get_savable_loader(
get_train_dataset(
self.mds_path,
worker_config=worker_config,
batch_size=1,
parallel_shard_iters=psi,
virtual_epoch_length=vel,
shuffle_buffer_size=sbs,
max_samples_per_sequence=sbs,
),
)
print("state_2:", _norng_state(state_2))
loader.restore_state_rank(state_2)
order_7 = [data.text[0] for idx, data in zip(range(17), loader)]
print("order_3:", order_3)
print("order_7:", order_7)
assert order_3 == order_7
def test_save_restore_state_val(self):
torch.manual_seed(42)
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=0,
seed_offset=42,
)
# Train mode dataset
loader = get_savable_loader(
get_val_dataset(self.mds_path, worker_config=worker_config, batch_size=10),
)
state_0 = loader.save_state_rank()
order_1 = [data.text for idx, data in zip(range(55 * 20), loader)]
state_1 = loader.save_state_rank()
# print("save state done")
order_2 = [data.text for idx, data in zip(range(55 * 20), loader)]
loader = get_savable_loader(
get_val_dataset(self.mds_path, worker_config=worker_config, batch_size=10),
)
loader.restore_state_rank(state_1)
order_3 = [data.text for idx, data in zip(range(55 * 20), loader)]
assert order_2 == order_3
loader = get_savable_loader(
get_val_dataset(self.mds_path, worker_config=worker_config, batch_size=10),
)
loader.restore_state_rank(state_0)
order_4 = [data.text for idx, data in zip(range(55 * 20), loader)]
assert order_1 == order_4
def test_blending_randomness(self):
import random
import numpy
for num_workers in [0, 1, 2]: # Especially also check the num_workers=0 case
world_size = 4
micro_batch_size = 1
seed = 42
configs = (
WorkerConfig(rank=0, world_size=world_size, num_workers=num_workers),
WorkerConfig(rank=1, world_size=world_size, num_workers=num_workers),
WorkerConfig(rank=2, world_size=world_size, num_workers=num_workers),
)
all_ranks_subflavors = []
for rank_config in configs:
torch.manual_seed(seed)
numpy.random.seed(seed)
random.seed(seed)
ds = get_train_dataset(
self.mds_path,
split_part="train",
worker_config=rank_config,
batch_size=micro_batch_size,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
)
loader = get_loader(ds)
subflavors = [
data.__subflavors__[0].get("__subflavor__")
for idx, data in zip(range(25), loader)
]
all_ranks_subflavors.append(subflavors)
print(f"Subflavors for rank {rank_config.rank}:", subflavors)
# Assert that all ranks got different data
for i in range(len(all_ranks_subflavors)):
for j in range(i + 1, len(all_ranks_subflavors)):
assert all_ranks_subflavors[i] != all_ranks_subflavors[j], (
f"Rank {i} and rank {j} got the same subflavors."
)
# Delete the locals holding the loader, otherwise it might be kept alive
# (note: locals().clear() has no effect inside a function)
del ds, loader, subflavors
gc.collect()
def test_slice_iter_shuffle_over_epochs(self):
torch.manual_seed(42)
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=0,
seed_offset=42,
)
def new_loader():
return get_savable_loader(
get_train_dataset(
self.mds_path,
worker_config=worker_config,
batch_size=10,
parallel_shard_iters=2,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
shuffle_over_epochs_multiplier=-1,
),
)
# Train mode dataset
loader = new_loader()
_ = [data.text for idx, data in zip(range(1000), loader)]
def test_save_restore_next(self):
torch.manual_seed(42)
wc = WorkerConfig(
rank=0,
world_size=1,
num_workers=6,
)
initial_loader = get_savable_loader(
get_train_dataset(
self.nested_mds_path,
worker_config=wc,
batch_size=1,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
),
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=0,
)
skip_initial = 9
previous_cp = initial_loader.save_state_rank()
print("initial_samples:")
for i, sample in zip(range(skip_initial), initial_loader):
print(f"sample[@{i}]: {sample.text}")
print("previous_cp:", previous_cp)
rst_loader = get_savable_loader(
get_train_dataset(
self.nested_mds_path,
worker_config=wc,
batch_size=1,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
),
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=0,
)
rst_loader.restore_state_rank(previous_cp)
for i, rst_sample in zip(range(1), rst_loader):
print(f"rst_sample[@{i}]: {rst_sample.text}")
assert sample.text == rst_sample.text, f"{sample} != {rst_sample}"
assert sample.__key__ == rst_sample.__key__, f"{sample} != {rst_sample}"
assert sample.__restore_key__ == rst_sample.__restore_key__, f"{sample} != {rst_sample}"
previous_cp = initial_loader.save_state_rank()
# Save the state, then store the next 20 samples for reference.
state_initial = initial_loader.save_state_rank()
print("state_initial:", str(state_initial))
initial_samples = [sample for _, sample in zip(range(20), initial_loader)]
print(
"initial_samples:"
+ "".join(
f"\n [@{idx}] {sample.text}"
for idx, sample in enumerate(initial_samples, start=skip_initial)
)
)
del initial_loader
gc.collect()
second_loader = get_savable_loader(
get_train_dataset(
self.nested_mds_path,
worker_config=wc,
batch_size=1,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
),
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=0,
)
second_loader.restore_state_rank(state_initial)
# Save the state again, to check that it is the same as the just restored state
same_state = second_loader.save_state_rank()
print("same_state:", same_state)
assert same_state == state_initial
for offset in range(10):
try:
# Save state and restore in next loader
state_offset = second_loader.save_state_rank()
# Get 1 sample from the current loader
samples = [sample for _, sample in zip(range(1), second_loader)]
assert len(samples) == 1
sample = samples[0]
# Check that the sample is the same as the initial loader's reference sample
print(f"sample[@{offset + skip_initial}]: {sample.text}")
try:
assert sample.text == initial_samples[offset].text, (
f"{sample} != {initial_samples[offset]}"
)
assert sample.__key__ == initial_samples[offset].__key__, (
f"{sample} != {initial_samples[offset]}"
)
assert sample.__restore_key__ == initial_samples[offset].__restore_key__, (
f"{sample} != {initial_samples[offset]}"
)
except Exception as e:
print(
"samples:"
+ f"\n [@{offset + skip_initial}] {sample.text}"
+ "".join(
f"\n [@{idx}] {sample.text}"
for idx, sample in zip(
range(skip_initial + offset + 1, skip_initial + offset + 6),
second_loader,
)
)
)
raise ValueError(f"Failed to iterate @{offset + skip_initial} samples") from e
# Restore state in a new loader
ref_loader = get_savable_loader(
get_train_dataset(
self.nested_mds_path,
worker_config=wc,
batch_size=1,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
),
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=0,
)
ref_loader.restore_state_rank(state_offset)
# Get 6 samples from the restored loader; only the first is compared, the rest are printed for debugging
next_loader_samples = [sample for _, sample in zip(range(6), ref_loader)]
assert len(next_loader_samples) == 6
next_loader_sample = next_loader_samples[0]
print(
"next_loader_samples:"
+ f"\n [@{offset + skip_initial}] {sample.text}"
+ "".join(
f"\n [@{idx}] {sample}"
for idx, sample in zip(
range(skip_initial + offset, skip_initial + offset + 6),
next_loader_samples,
)
)
)
assert next_loader_sample.text == sample.text, f"{next_loader_sample} != {sample}"
assert next_loader_sample.__key__ == sample.__key__, (
f"{next_loader_sample} != {sample}"
)
assert next_loader_sample.__restore_key__ == sample.__restore_key__, (
f"{next_loader_sample} != {sample}"
)
except Exception as e:
raise ValueError(f"Failed to iterate @{skip_initial}+{offset} samples") from e
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
"""This module defines tests for meta datasets."""
import gc
import logging
import sys
import tempfile
import unittest
import warnings
from pathlib import Path
from typing import Iterable
import torch
import webdataset as wds
from megatron.energon import (
BlendDataset,
SavableDataset,
WorkerConfig,
get_savable_loader,
get_train_dataset,
)
from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME
# Speed up tests significantly by reducing the torch status check interval for broken worker shutdown
try:
torch.utils.data._utils.worker.MP_STATUS_CHECK_INTERVAL = 0.1
torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL = 0.1
except AttributeError:
pass
def _norng_state(state):
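"""Truncate long bytes/str/sequence values in a saved state so printed states stay readable."""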
if isinstance(state, bytes):
if len(state) > 100:
return state[:5] + f"...<len={len(state)}>".encode()
return state
elif isinstance(state, str):
if len(state) > 100:
return state[:5] + f"...<len={len(state)}>"
return state
elif isinstance(state, dict):
return {k: _norng_state(v) for k, v in state.items()}
elif isinstance(state, (list, tuple)):
if len(state) > 100:
state = state[:5]
return type(state)(_norng_state(v) for v in state)
else:
return state
def get_blend_dataset(ds: SavableDataset):
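"""Unwrap nested dataset wrappers until the underlying BlendDataset is found."""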
if isinstance(ds, BlendDataset):
return ds
else:
if hasattr(ds, "dataset"):
return get_blend_dataset(ds.dataset)
else:
raise ValueError("No blend dataset found")
class TestDataset(unittest.TestCase):
# Set up the test fixture
def setUp(self):
logging.basicConfig(stream=sys.stderr, level=logging.INFO)
warnings.simplefilter("ignore", ResourceWarning)
# Create a temporary directory
self.temp_dir = tempfile.TemporaryDirectory()
self.dataset_path = Path(self.temp_dir.name)
# self.dataset_path = Path("./test_dataset")
self.dataset_path.mkdir(exist_ok=True, parents=True)
(self.dataset_path / "ds1").mkdir(exist_ok=True, parents=True)
(self.dataset_path / "ds2").mkdir(exist_ok=True, parents=True)
(self.dataset_path / "ds3").mkdir(exist_ok=True, parents=True)
# Create small dummy text datasets
self.create_text_test_dataset(self.dataset_path / "ds1", range(55), range(55))
self.create_text_test_dataset(self.dataset_path / "ds2", range(100, 107), range(100, 107))
self.create_text_test_dataset(self.dataset_path / "ds3", range(200, 255), range(0, 55))
self.mds_path = self.dataset_path / "metadataset_v2.yaml"
with open(self.mds_path, "w") as f:
f.write(
"\n".join(
[
"__module__: megatron.energon",
"__class__: MetadatasetV2",
"splits:",
" train:",
" blend:",
" - weight: 1",
" path: ds1",
" - weight: 1",
" path: ds2",
" - weight: 1",
" path: ds3",
]
)
)
print(self.dataset_path)
def tearDown(self):
# Remove all temporary files
gc.collect()
self.temp_dir.cleanup()
@staticmethod
def create_text_test_dataset(path: Path, txt_range: Iterable[int], key_range: Iterable[int]):
"""Creates a small dummy test dataset for testing purposes."""
# Create unique text samples for the given key and text ranges
(path / "parts").mkdir(exist_ok=True, parents=True)
# Initialize the ShardWriter
with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=10) as shard_writer:
for key, txt in zip(key_range, txt_range):
# Write individual files to shards
shard_writer.write(
{
"__key__": f"{key:06d}",
"txt": f"{txt}".encode(),
},
)
total_shards = shard_writer.shard
from megatron.energon.flavors import BaseWebdatasetFactory
BaseWebdatasetFactory.prepare_dataset(
path,
[f"parts/data-{{0..{total_shards - 1}}}.tar"],
split_parts_ratio=[("train", 1.0)],
shuffle_seed=None,
)
with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f:
f.write(
"\n".join(
[
"__module__: megatron.energon",
"__class__: TextWebdataset",
"field_map:",
" text: txt",
"subflavors:",
" source: dataset.yaml",
" dataset.yaml: true",
" number: 42",
]
)
)
def test_metadataset_few_samples_save_restore(self):
torch.manual_seed(42)
worker_config = WorkerConfig(
rank=0,
world_size=32,
num_workers=1,
seed_offset=42,
)
# Train mode dataset
train_dataset = get_train_dataset(
self.mds_path,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=100,
max_samples_per_sequence=None,
)
print(len(train_dataset))
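# With 117 samples over 32 ranks, rank 0 is assigned 2 samples from ds1, 0 from ds2 (it only has 7), and 2 from ds3.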
assert len(train_dataset) == 4
# The middle dataset should have 0 samples assigned to this rank
blend_ds = get_blend_dataset(train_dataset)
assert len(blend_ds.dataset_weights[1][0].dataset.dataset.workers_slice_offsets[0]) == 1
assert len(blend_ds.dataset_weights[1][0].dataset.dataset) == 0
train_loader = get_savable_loader(
train_dataset,
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=1,
)
# Load 3 samples
list(zip(train_loader, range(3)))
# Save state mid epoch
state1 = train_loader.save_state_rank()
# Load 5 samples
data1b = list(zip(train_loader, range(5)))
# Restore state
train_loader = get_savable_loader(
get_train_dataset(
self.mds_path,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=100,
max_samples_per_sequence=None,
),
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=1,
)
train_loader.restore_state_rank(state1)
# Load 5 samples
data2_restore = list(zip(train_loader, range(5)))
# Check that the restored state is the same
order1b = [(s[0].__key__[0], int(s[0].text[0])) for s in data1b]
order2 = [(s[0].__key__[0], int(s[0].text[0])) for s in data2_restore]
print("order1b")
print(order1b)
print("order2")
print(order2)
assert order1b == order2, "The restored state does not match the original state."
def test_too_few_samples(self):
# Most ranks will only get a single sample, as there are 117 samples in total and 100 ranks (17 ranks get two)
ws = 100
lens = []
for i_rank in range(ws):
worker_config = WorkerConfig(rank=i_rank, world_size=ws, num_workers=0)
loader = get_savable_loader(
get_train_dataset(
self.mds_path,
batch_size=1,
worker_config=worker_config,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
),
checkpoint_every_min_n_samples=1,
checkpoint_every_sec=0,
)
lens.append(len(loader))
txts = []
for i, sample in zip(range(10), loader):
txts.extend(sample.text)
assert len(set(txts)) == len(loader), (
f"Rank {i_rank} should have exactly {len(loader)} sample, but got {txts}"
)
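# 117 samples over 100 ranks: the 17 ranks listed below receive 2 samples, the remaining 83 receive 1.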
assert lens == [
2 if i in [0, 3, 6, 12, 18, 25, 31, 37, 43, 50, 56, 62, 68, 75, 81, 87, 93] else 1
for i in range(100)
]
if __name__ == "__main__":
# unittest.main()
ds = TestDataset()
ds.setUp()
ds.test_metadataset_few_samples_save_restore()
ds.tearDown()
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
"""This module defines tests for meta datasets."""
import gc
import logging
import random
import sys
import tempfile
import unittest
import warnings
from collections import Counter
from pathlib import Path
from typing import Iterable
from unittest.mock import patch
import torch
import webdataset as wds
from megatron.energon import (
Sample,
TextSample,
WorkerConfig,
get_loader,
get_savable_loader,
get_train_dataset,
load_dataset,
)
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib.epath import EPath
from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME
from megatron.energon.metadataset.loader import prepare_metadataset
from megatron.energon.metadataset.loader_interface import DatasetBlendMode
from megatron.energon.task_encoder.base import DefaultTaskEncoder
from megatron.energon.wrappers.watchdog_dataset import WatchdogDataset
from tests.epath_s3_emulator import setup_s3_emulator
# Speed up tests significantly by reducing the torch status check interval for broken worker shutdown
try:
torch.utils.data._utils.worker.MP_STATUS_CHECK_INTERVAL = 0.1
torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL = 0.1
except AttributeError:
pass
def _norng_state(state):
if isinstance(state, bytes):
if len(state) > 100:
return state[:5] + f"...<len={len(state)}>".encode()
return state
elif isinstance(state, str):
if len(state) > 100:
return state[:5] + f"...<len={len(state)}>"
return state
elif isinstance(state, dict):
return {k: _norng_state(v) for k, v in state.items()}
elif isinstance(state, (list, tuple)):
if len(state) > 100:
state = state[:5]
return type(state)(_norng_state(v) for v in state)
else:
return state
@edataclass
class TestJoinedSample(Sample):
text1: torch.Tensor
text2: torch.Tensor
@staticmethod
def from_joined(ds1: TextSample, ds2: TextSample) -> "TestJoinedSample":
return TestJoinedSample.derive_from(
ds1,
text1=ds1.text,
text2=ds2.text,
)
def test_joiner(text1: TextSample, text2: TextSample) -> TestJoinedSample:
return TestJoinedSample.derive_from(text1, text1=f"j{text1.text}", text2=f"j{text2.text}")
class TestDataset(unittest.TestCase):
# Set up the test fixture
def setUp(self):
random.seed(42)
logging.basicConfig(stream=sys.stderr, level=logging.INFO)
warnings.simplefilter("ignore", ResourceWarning)
# Create a temporary directory
self.temp_dir = tempfile.TemporaryDirectory()
self.dataset_path = Path(self.temp_dir.name)
# self.dataset_path = Path("./test_dataset")
self.dataset_path.mkdir(exist_ok=True, parents=True)
# Create small dummy datasets
self.create_text_test_dataset(self.dataset_path / "ds1", range(55), range(55))
self.create_text_test_dataset(self.dataset_path / "ds2", range(100, 155), range(100, 155))
self.create_text_test_dataset(self.dataset_path / "ds3", range(200, 255), range(55))
# Create a shuffled dataset for joining with ds1. It overlaps ds1 but includes more samples
shuffled_range_100 = list(range(100))
random.shuffle(shuffled_range_100)
self.create_text_test_dataset(
self.dataset_path / "ds1b", shuffled_range_100, shuffled_range_100, prefix="B"
)
shuffled_range_100 = list(range(100))
random.shuffle(shuffled_range_100)
self.create_text_test_dataset(
self.dataset_path / "ds1c", shuffled_range_100, shuffled_range_100, prefix="C"
)
self.mds_path = self.dataset_path / "metadataset_v2.yaml"
with open(self.mds_path, "w") as f:
f.write(
"\n".join(
[
"__module__: megatron.energon",
"__class__: MetadatasetV2",
"splits:",
" train:",
" blend:",
" - weight: 1",
" path: ds1",
" subflavors:",
" source: metadataset_v2.yaml",
" number: 43",
" mds: mds",
" shuffle_over_epochs_multiplier: 3",
" - weight: 1",
" path: ds2",
" subflavors:",
" source: metadataset_v2.yaml",
" number: 44",
" mds: mds",
" val:",
" blend:",
" - weight: 1",
" path: ds1",
" split_part: train",
" - weight: 1",
" path: ds2",
" split_part: train",
]
)
)
self.nested_mds_path = self.dataset_path / "nested_metadataset_v2.yaml"
with open(self.nested_mds_path, "w") as f:
f.write(
"\n".join(
[
"__module__: megatron.energon",
"__class__: MetadatasetV2",
"splits:",
" train:",
" blend:",
" - weight: 4",
" path: ./metadataset_v2.yaml",
" split_part: train",
" subflavors:",
" source: nested_metadataset.yaml",
" mds: nested_train",
" - path: ./metadataset_v2.yaml",
" split_part: val",
" subflavors:",
" source: nested_metadataset.yaml",
" mds: nested_val",
]
)
)
print(self.dataset_path)
def tearDown(self):
# Remove all temporary files
gc.collect()
self.temp_dir.cleanup()
@staticmethod
def create_text_test_dataset(
path: Path, txt_range: Iterable[int], key_range: Iterable[int], prefix: str = ""
):
"""Creates a small dummy test dataset for testing purposes."""
# Create unique text samples for the given key and text ranges
(path / "parts").mkdir(exist_ok=True, parents=True)
# Initialize the ShardWriter
with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=10) as shard_writer:
for key, txt in zip(key_range, txt_range):
# Write individual files to shards
shard_writer.write(
{
"__key__": f"{key:06d}",
"txt": f"{prefix}{txt}".encode(),
},
)
total_shards = shard_writer.shard
from megatron.energon.flavors import BaseWebdatasetFactory
BaseWebdatasetFactory.prepare_dataset(
path,
[f"parts/data-{{0..{total_shards - 1}}}.tar"],
split_parts_ratio=[("train", 1.0)],
shuffle_seed=None,
)
with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f:
f.write(
"\n".join(
[
"sample_type:",
" __module__: megatron.energon",
" __class__: TextSample",
"field_map:",
" text: txt",
"subflavors:",
" source: dataset.yaml",
" dataset.yaml: true",
" number: 42",
]
)
)
def test_metadataset(self):
torch.manual_seed(42)
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=0,
seed_offset=42,
)
# Train mode dataset
train_dataset = get_train_dataset(
self.mds_path,
worker_config=worker_config,
batch_size=10,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
)
print(len(train_dataset))
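# ds1 and ds2 contribute 55 samples each; with batch size 10 this gives 11 batches per epoch.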
assert len(train_dataset) == 11
train_loader1 = get_loader(train_dataset)
train_order1 = [
text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text
]
print(train_order1[:10])
print(Counter(train_order1))
assert len(Counter(train_order1)) == 110
assert all(48 <= v <= 52 for v in Counter(train_order1).values())
def test_nested_metadataset(self):
torch.manual_seed(42)
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=0,
)
dataset = load_dataset(self.nested_mds_path)
raw_datasets = dataset.get_datasets(
training=False, split_part="train", worker_config=worker_config
)
assert raw_datasets.blend_mode == DatasetBlendMode.DATASET_WEIGHT
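# The outer blend weights (4 and the default 1) are normalized to 0.8/0.2 and split evenly over each inner split's two datasets, giving 0.4, 0.4, 0.1, 0.1.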
assert [raw_dataset.weight for raw_dataset in raw_datasets.datasets] == [
0.4,
0.4,
0.1,
0.1,
], [raw_dataset.weight for raw_dataset in raw_datasets.datasets]
assert [raw_dataset.dataset.paths[0].name for raw_dataset in raw_datasets.datasets] == [
"ds1",
"ds2",
"ds1",
"ds2",
]
print([raw_dataset.dataset.subflavors for raw_dataset in raw_datasets.datasets])
assert [raw_dataset.dataset.subflavors for raw_dataset in raw_datasets.datasets] == [
{
"source": "nested_metadataset.yaml",
"dataset.yaml": True,
"number": 43,
"mds": "nested_train",
},
{
"source": "nested_metadataset.yaml",
"dataset.yaml": True,
"number": 44,
"mds": "nested_train",
},
{
"source": "nested_metadataset.yaml",
"dataset.yaml": True,
"number": 42,
"mds": "nested_val",
},
{
"source": "nested_metadataset.yaml",
"dataset.yaml": True,
"number": 42,
"mds": "nested_val",
},
]
def test_joined_metadataset(self):
torch.manual_seed(42)
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=0,
seed_offset=42,
)
# Create a joined dataset configuration
joined_mds_path = self.dataset_path / "joined_metadataset_v2.yaml"
with open(joined_mds_path, "w") as f:
f.write(
"\n".join(
[
"__module__: megatron.energon",
"__class__: MetadatasetV2",
"splits:",
" train:",
" join:",
" ds1:",
" path: ds1",
" subflavors:",
" source1: ds1",
" number: 43",
" ds2:",
" path: ds3",
" subflavors:",
" source2: ds3",
" number: 44",
" joiner:",
f" __module__: {TestJoinedSample.__module__}",
f" __class__: {TestJoinedSample.__name__}",
]
)
)
prepare_metadataset(EPath(joined_mds_path))
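# Joined datasets must be prepared first; this precomputes the key matching between the joined datasets (stored in a .cache folder next to the yaml).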
# Train mode dataset
train_dataset = get_train_dataset(
joined_mds_path,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
)
print(len(train_dataset))
assert len(train_dataset) == 55
train_loader = get_savable_loader(
train_dataset,
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=1,
)
data = list(zip(range(2 * 55), train_loader))
txt1_order = [data.text1[0] for idx, data in data]
txt2_order = [data.text2[0] for idx, data in data]
key_order = [data.__key__[0] for idx, data in data]
# ds1 has 55 samples, key range 0:55, txt range 0:55
# ds3 has 55 samples, key range 0:55, txt range 200:255
# Joining results in: 0:55
print("txt1:", txt1_order)
# Joining results in: 200:255
print("txt2:", txt2_order)
# Joining results in: 0:55
print("key:", key_order)
# Check matching
assert all(int(txt1) + 200 == int(txt2) for txt1, txt2 in zip(txt1_order, txt2_order))
# Check frequency
assert set(txt1_order) == set(str(i) for i in range(0, 55))
assert set(txt2_order) == set(str(i) for i in range(200, 255))
# Every item must occur 2 times (2*55 samples in total).
assert Counter(txt1_order).most_common(1)[0][1] == 2
state = train_loader.save_state_rank()
# Iterate 60 more items
data = list(zip(range(60), train_loader))
txt1_order = [data.text1 for idx, data in data]
txt2_order = [data.text2 for idx, data in data]
key_order = [data.__key__ for idx, data in data]
# Restore state
train_loader = get_savable_loader(
get_train_dataset(
joined_mds_path,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
),
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=1,
)
train_loader.restore_state_rank(state)
# Iterate 60 more items
data = list(zip(range(60), train_loader))
txt1_order_rest = [data.text1 for idx, data in data]
txt2_order_rest = [data.text2 for idx, data in data]
key_order_rest = [data.__key__ for idx, data in data]
# Verify matching
assert txt1_order == txt1_order_rest
assert txt2_order == txt2_order_rest
assert key_order == key_order_rest
def test_joined_metadataset_joiner(self):
torch.manual_seed(42)
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=0,
seed_offset=42,
)
# Create a joined dataset configuration
joined_mds_path = self.dataset_path / "joined_metadataset_joiner.yaml"
with open(joined_mds_path, "w") as f:
f.write(
"\n".join(
[
"__module__: megatron.energon",
"__class__: MetadatasetV2",
"splits:",
" train:",
" blend:",
" - weight: 1",
" join:",
" text1:",
" path: ds1",
" subflavors:",
" source1: ds1",
" number: 43",
" text2:",
" path: ds3",
" subflavors:",
" source2: ds3",
" number: 44",
" joiner:",
f" __module__: {test_joiner.__module__}",
f" __function__: {test_joiner.__name__}",
]
)
)
prepare_metadataset(EPath(joined_mds_path))
# Train mode dataset
train_dataset = get_train_dataset(
joined_mds_path,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
)
print(len(train_dataset))
assert len(train_dataset) == 55
train_loader = get_savable_loader(
train_dataset,
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=1,
)
data = list(zip(range(2 * 55), train_loader))
txt1_order = [data.text1[0] for idx, data in data]
txt2_order = [data.text2[0] for idx, data in data]
key_order = [data.__key__[0] for idx, data in data]
# ds1 has 55 samples, key range 0:55, txt range 0:55
# ds3 has 55 samples, key range 0:55, txt range 200:255
# Joining results in: 0:55, with prefix "j"
print("txt1:", txt1_order)
# Joining results in: 200:255, with prefix "j"
print("txt2:", txt2_order)
# Joining results in: 0:55
print("key:", key_order)
# Check matching
assert all(
int(txt1[1:]) + 200 == int(txt2[1:]) for txt1, txt2 in zip(txt1_order, txt2_order)
)
# Check frequency
assert set(txt1_order) == set(f"j{i}" for i in range(0, 55))
assert set(txt2_order) == set(f"j{i}" for i in range(200, 255))
# Every item must occur 2 times (2*55 samples in total).
assert Counter(txt1_order).most_common(1)[0][1] == 2
def test_left_join(self):
torch.manual_seed(42)
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=0,
seed_offset=42,
)
# Create a joined dataset configuration
joined_mds_path = self.dataset_path / "left_join.yaml"
with open(joined_mds_path, "w") as f:
f.write(
"\n".join(
[
"__module__: megatron.energon",
"__class__: MetadatasetV2",
"splits:",
" train:",
" blend:",
" - weight: 1",
" join:",
" text1:",
" path: ds1",
" subflavors:",
" source1: ds1",
" number: 43",
" text2:",
" path: ds1b",
" nonmatch: skip",
" subflavors:",
" source2: ds1b",
" number: 44",
" joiner:",
f" __module__: {test_joiner.__module__}",
f" __function__: {test_joiner.__name__}",
]
)
)
prepare_metadataset(EPath(joined_mds_path))
# Train mode dataset
train_dataset = get_train_dataset(
joined_mds_path,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
)
print(len(train_dataset))
assert len(train_dataset) == 55, len(train_dataset)
train_loader = get_savable_loader(
train_dataset,
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=1,
)
data = list(zip(range(2 * 55), train_loader))
txt1_order = [data.text1[0] for idx, data in data]
txt2_order = [data.text2[0] for idx, data in data]
key_order = [data.__key__[0] for idx, data in data]
# ds1 has 55 samples, key range 0:55, txt range 0:55
# ds1b has 100 samples, key range 0:100, txt range B0:B99; the left join keeps the 55 keys present in ds1
# Joining results in: 0:55, with prefix "j"
print("txt1:", txt1_order)
# Joining results in: B0:B55, with prefix "j"
print("txt2:", txt2_order)
# Joining results in: 0:55
print("key:", key_order)
# Check matching
assert all(int(txt1[1:]) == int(txt2[2:]) for txt1, txt2 in zip(txt1_order, txt2_order))
# Check frequency
assert set(txt1_order) == set(f"j{i}" for i in range(55))
assert set(txt2_order) == set(f"jB{i}" for i in range(55))
# Every item must occur 2 times (2*55 samples in total).
assert Counter(txt1_order).most_common(1)[0][1] == 2
# Test that changing the file works as expected
with open(joined_mds_path, "w") as f:
f.write(
"\n".join(
[
"__module__: megatron.energon",
"__class__: MetadatasetV2",
"splits:",
" train:",
" blend:",
" - weight: 1",
" join:",
" text1:",
" path: ds1c",
" subflavors:",
" source1: ds1c",
" number: 43",
" text2:",
" path: ds1b",
" nonmatch: skip",
" subflavors:",
" source2: ds1b",
" number: 44",
" joiner:",
f" __module__: {test_joiner.__module__}",
f" __function__: {test_joiner.__name__}",
" - weight: 1",
" join:",
" text1:",
" path: ds1b",
" text2:",
" path: ds1",
" nonmatch: skip",
" joiner:",
f" __module__: {test_joiner.__module__}",
f" __function__: {test_joiner.__name__}",
]
)
)
# Expect this to fail: the existing preparation no longer matches the changed config.
with self.assertRaises(Exception):
# Train mode dataset
train_dataset = get_train_dataset(
joined_mds_path,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
)
# Should succeed after re-running the preparation
prepare_metadataset(EPath(joined_mds_path))
train_dataset = get_train_dataset(
joined_mds_path,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
)
# Check that there are no leftover files in the cache folder
cache_folder = joined_mds_path.with_name(joined_mds_path.name + ".cache")
assert sum(1 for f in cache_folder.iterdir() if f.is_file()) == 2, list(
cache_folder.iterdir()
)
def test_left_join_exclude(self):
torch.manual_seed(42)
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=0,
seed_offset=42,
)
# Create a joined dataset configuration
orig_split_path = self.dataset_path / "ds1" / ".nv-meta" / "split.yaml"
exclude_split_path = self.dataset_path / "ds1" / ".nv-meta" / "exclude_split.yaml"
with open(exclude_split_path, "w") as f:
f.write(
"\n".join(
[
orig_split_path.read_text(),
"exclude:",
' - "parts/data-0.tar/000000"',
' - "parts/data-0.tar/000001"',
' - "parts/data-0.tar/000002"',
' - "parts/data-0.tar/000003"',
' - "parts/data-0.tar/000004"',
' - "parts/data-1.tar"',
' - "parts/data-2.tar/000029"',
]
)
)
# Create a joined dataset configuration
joined_mds_path = self.dataset_path / "left_join.yaml"
with open(joined_mds_path, "w") as f:
f.write(
"\n".join(
[
"__module__: megatron.energon",
"__class__: MetadatasetV2",
"splits:",
" train:",
" blend:",
" - weight: 1",
" join:",
" text1:",
" path: ds1",
" split_config: exclude_split.yaml",
" text2:",
" path: ds1b",
" nonmatch: skip",
" joiner:",
f" __module__: {test_joiner.__module__}",
f" __function__: {test_joiner.__name__}",
]
)
)
prepare_metadataset(EPath(joined_mds_path))
# Train mode dataset
train_dataset = get_train_dataset(
joined_mds_path,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
)
print(len(train_dataset))
assert len(train_dataset) == 55 - 16, len(train_dataset)
train_loader = get_savable_loader(
train_dataset,
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=1,
)
data = list(zip(range(2 * 55), train_loader))
txt1_order = [data.text1[0] for idx, data in data]
txt2_order = [data.text2[0] for idx, data in data]
key_order = [data.__key__[0] for idx, data in data]
# ds1 has 55 samples, key range 0:55, txt range 0:55
# ds1b has 100 samples, key range 0:100, txt range B0:B99; the left join keeps the non-excluded keys of ds1
# Joining results in: 0:55, with prefix "j"
print("txt1:", txt1_order)
# Joining results in: the matching B-prefixed texts (excluded keys removed), with prefix "j"
print("txt2:", txt2_order)
# Joining results in: 0:55
print("key:", key_order)
# Check matching
assert all(int(txt1[1:]) == int(txt2[2:]) for txt1, txt2 in zip(txt1_order, txt2_order))
# Check frequency
set_filtered_nums = set(range(5, 10)) | set(range(20, 29)) | set(range(30, 55))
assert set(txt1_order) == set(f"j{i}" for i in set_filtered_nums)
assert set(txt2_order) == set(f"jB{i}" for i in set_filtered_nums)
def test_joined_metadataset_prepare_mock(self):
torch.manual_seed(42)
# Create a joined dataset configuration
joined_mds_path = self.dataset_path / "joined_metadataset_prepare_mock.yaml"
with open(joined_mds_path, "w") as f:
f.write(
"\n".join(
[
"__module__: megatron.energon",
"__class__: MetadatasetV2",
"splits:",
" train:",
" join:",
" - path: ds1",
" - path: ds3",
" joiner:",
" __module__: __main__",
" __class__: NonExistantSample",
]
)
)
prepare_metadataset(EPath(joined_mds_path))
# Create a joined dataset configuration
joined_mds_path = self.dataset_path / "joined_metadataset_prepare_mock2.yaml"
with open(joined_mds_path, "w") as f:
f.write(
"\n".join(
[
"__module__: megatron.energon",
"__class__: MetadatasetV2",
"splits:",
" train:",
" join:",
" - path: ds1",
" - path: ds3",
" joiner:",
" __module__: non_existant_module",
" __class__: MyCaptioningSample",
]
)
)
prepare_metadataset(EPath(joined_mds_path))
def test_metadataset_fixed_epochs(self):
torch.manual_seed(42)
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=0,
seed_offset=42,
)
# Create a joined dataset configuration
fixed_epochs_mds_path = self.dataset_path / "metadataset_fixed_epochs.yaml"
with open(fixed_epochs_mds_path, "w") as f:
f.write(
"\n".join(
[
"__module__: megatron.energon",
"__class__: MetadatasetV2",
"splits:",
" train:",
" blend_epochized:",
" - repetitions: 2",
" path: ds1",
" subflavors:",
" source: ds1",
" number: 43",
" - repetitions: 3",
" path: ds2",
" subflavors:",
" source: ds2",
" number: 42",
]
)
)
# Train mode dataset
train_dataset = get_train_dataset(
fixed_epochs_mds_path,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
repeat=False,
)
print(len(train_dataset))
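# blend_epochized repeats ds1 twice and ds2 three times, so one epoch has 5 * 55 samples.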
assert len(train_dataset) == 5 * 55, len(train_dataset)
train_loader = get_savable_loader(
train_dataset,
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=1,
)
data = list(enumerate(train_loader))
txt_order = [data.text[0] for idx, data in data]
key_order = [data.__subflavors__[0]["source"] + "/" + data.__key__[0] for idx, data in data]
print("txt1:", txt_order)
print("key:", key_order)
assert len(txt_order) == 5 * 55, Counter(txt_order)
ds1_keys = [key for key in key_order if key.startswith("ds1/")]
ds2_keys = [key for key in key_order if key.startswith("ds2/")]
txt_cnt = Counter(txt_order)
ds1_key_cnt = Counter(ds1_keys)
ds2_key_cnt = Counter(ds2_keys)
assert len(ds1_keys) == 2 * 55, (len(ds1_keys), ds1_key_cnt)
assert len(ds2_keys) == 3 * 55, (len(ds2_keys), ds2_key_cnt)
assert all(ds1_key_cnt[key] == 2 for key in ds1_keys)
assert all(ds2_key_cnt[key] == 3 for key in ds2_keys)
assert all(txt_cnt[key] in (2, 3) for key in txt_order)
# Next epoch
data = list(enumerate(train_loader))
print([data.text[0] for idx, data in data])
assert len(data) == 5 * 55, len(data)
# Next epoch
data1 = list(zip(range(3 * 55), train_loader))
assert len(data1) == 3 * 55, len(data1)
# Save state mid epoch
state1 = train_loader.save_state_rank()
print(state1)
data2 = list(enumerate(train_loader))
assert len(data2) == 2 * 55
txt_order = [data.text[0] for idx, data in data1 + data2]
key_order = [
data.__subflavors__[0]["source"] + "/" + data.__key__[0] for idx, data in data1 + data2
]
assert len(txt_order) == 5 * 55, Counter(txt_order)
ds1_keys = [key for key in key_order if key.startswith("ds1/")]
ds2_keys = [key for key in key_order if key.startswith("ds2/")]
txt_cnt = Counter(txt_order)
ds1_key_cnt = Counter(ds1_keys)
ds2_key_cnt = Counter(ds2_keys)
assert len(ds1_keys) == 2 * 55, (len(ds1_keys), ds1_key_cnt)
assert len(ds2_keys) == 3 * 55, (len(ds2_keys), ds2_key_cnt)
assert all(ds1_key_cnt[key] == 2 for key in ds1_keys)
assert all(ds2_key_cnt[key] == 3 for key in ds2_keys)
assert all(txt_cnt[key] in (2, 3) for key in txt_order)
# Restore state
train_loader = get_savable_loader(
get_train_dataset(
fixed_epochs_mds_path,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
repeat=False,
),
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=1,
)
train_loader.restore_state_rank(state1)
data2_restore = list(enumerate(train_loader))
assert len(data2_restore) == 2 * 55
txt_order_rst = [data.text[0] for idx, data in data1 + data2_restore]
key_order_rst = [
data.__subflavors__[0]["source"] + "/" + data.__key__[0]
for idx, data in data1 + data2_restore
]
assert len(txt_order_rst) == 5 * 55, Counter(txt_order_rst)
assert txt_order_rst == txt_order
assert key_order_rst == key_order
ds1_keys_rst = [key for key in key_order_rst if key.startswith("ds1/")]
ds2_keys_rst = [key for key in key_order_rst if key.startswith("ds2/")]
txt_cnt_rst = Counter(txt_order_rst)
ds1_key_cnt_rst = Counter(ds1_keys_rst)
ds2_key_cnt_rst = Counter(ds2_keys_rst)
assert len(ds1_keys_rst) == 2 * 55, (len(ds1_keys_rst), ds1_key_cnt_rst)
assert len(ds2_keys_rst) == 3 * 55, (len(ds2_keys_rst), ds2_key_cnt_rst)
assert all(ds1_key_cnt_rst[key] == 2 for key in ds1_keys_rst)
assert all(ds2_key_cnt_rst[key] == 3 for key in ds2_keys_rst)
assert all(txt_cnt_rst[key] in (2, 3) for key in txt_order_rst)
def test_metadataset_fixed_fractional_epochs(self):
torch.manual_seed(42)
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=0,
seed_offset=42,
)
# Create a joined dataset configuration
fixed_epochs_mds_path = self.dataset_path / "metadataset_fixed_epochs.yaml"
with open(fixed_epochs_mds_path, "w") as f:
f.write(
"\n".join(
[
"__module__: megatron.energon",
"__class__: MetadatasetV2",
"splits:",
" train:",
" blend_epochized:",
" - repetitions: 0.7",
" path: ds1",
" subflavors:",
" source: ds1",
" number: 43",
" - repetitions: 1.5",
" path: ds2",
" subflavors:",
" source: ds2",
" number: 42",
]
)
)
# ===== Part 1: Verify fractions =====
# Train mode dataset
train_dataset = get_train_dataset(
fixed_epochs_mds_path,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=None,
shuffle_over_epochs_multiplier=None,
parallel_shard_iters=1,
max_samples_per_sequence=None,
repeat=False,
)
train_loader = get_savable_loader(
train_dataset,
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=1,
)
assert len(train_loader) == 38 + 55 + 27, len(train_loader)
data = list(enumerate(train_loader))
# Check the overall number of samples
# Should be 0.7*len(ds1) + 1.5*len(ds2) = 0.7*55 + 1.5*55 = 38 + 55 + 27 (floor rounding)
assert len(data) == 38 + 55 + 27, len(data)
sample_counts = Counter([int(s[1].text[0]) for s in data])
# The first 70% of samples from ds1 (0 to incl. 37) should be repeated only once
assert all(sample_counts[sample] == 1 for sample in range(38))
# Since ds2 is repeated 1.5 times, the first 50% of samples from ds2 (100 to incl. 126) should be repeated twice
assert all(sample_counts[sample] == 2 for sample in range(100, 127))
# The remaining samples from ds2 (127 to incl. 154) should be repeated only once
assert all(sample_counts[sample] == 1 for sample in range(127, 155))
# ===== Part 2: Save and restore state =====
# Now let's check if the state is stored and restored correctly
train_loader = get_savable_loader(
get_train_dataset(
fixed_epochs_mds_path,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=None,
shuffle_over_epochs_multiplier=None,
parallel_shard_iters=1,
max_samples_per_sequence=None,
repeat=False,
),
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=1,
)
data1 = list(zip(range(95), train_loader))
state1 = train_loader.save_state_rank()
train_loader = get_savable_loader(
get_train_dataset(
fixed_epochs_mds_path,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=None,
shuffle_over_epochs_multiplier=None,
parallel_shard_iters=1,
max_samples_per_sequence=None,
repeat=False,
),
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=1,
)
train_loader.restore_state_rank(state1)
data2_restore = list(enumerate(train_loader))
total_samples_save_restore = len(data1) + len(data2_restore)
assert total_samples_save_restore == len(data), (
"Total number of samples do not match when using save/restore"
)
sample_counts_save_restore = Counter(
[int(s[1].text[0]) for d in [data1, data2_restore] for s in d]
)
assert sample_counts_save_restore == sample_counts, (
"Sample counts do not match when using save/restore"
)
# ===== Part 3: Check if the state is restored correctly when saving right at the end of a dataset =====
torch.manual_seed(42)
train_loader = get_savable_loader(
get_train_dataset(
fixed_epochs_mds_path,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=None,
shuffle_over_epochs_multiplier=None,
parallel_shard_iters=1,
max_samples_per_sequence=None,
repeat=False,
),
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=1,
)
ds1_counter = 0
data1 = []
for idx, sample in enumerate(train_loader):
data1.append((idx, sample))
if sample.__subflavors__[0]["source"] == "ds1":
ds1_counter += 1
if ds1_counter == 38:
# Stop right after the last sample from ds1
break
state1 = train_loader.save_state_rank()
train_loader = get_savable_loader(
get_train_dataset(
fixed_epochs_mds_path,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=None,
shuffle_over_epochs_multiplier=None,
parallel_shard_iters=1,
max_samples_per_sequence=None,
repeat=False,
),
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=1,
)
train_loader.restore_state_rank(state1)
data2_restore = list(enumerate(train_loader))
total_samples_save_restore = len(data1) + len(data2_restore)
assert total_samples_save_restore == len(data), (
"Total number of samples do not match when using save/restore"
)
sample_counts_save_restore = Counter(
[int(s[1].text[0]) for d in [data1, data2_restore] for s in d]
)
assert sample_counts_save_restore == sample_counts, (
"Sample counts do not match when using save/restore"
)
# Try in repeat mode
# Train mode dataset
train_loader = get_savable_loader(
get_train_dataset(
fixed_epochs_mds_path,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=None,
shuffle_over_epochs_multiplier=None,
parallel_shard_iters=1,
max_samples_per_sequence=None,
),
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=1,
)
data = list(zip(range(200), train_loader))
assert len(train_loader) == 38 + 55 + 27, len(train_loader)
# Check the overall number of samples
# Should be 0.7*len(ds1) + 1.5*len(ds2) = 38 + 55 + 27 (floor rounding)
assert len(data) == 200, len(data)
# ===== Part 4: Test count for multiple workers =====
worker_config = WorkerConfig(
rank=0,
world_size=2,
num_workers=2,
seed_offset=42,
)
# Train mode dataset
train_loader = get_savable_loader(
get_train_dataset(
fixed_epochs_mds_path,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=None,
shuffle_over_epochs_multiplier=None,
parallel_shard_iters=1,
max_samples_per_sequence=None,
repeat=False,
),
checkpoint_every_sec=0,
checkpoint_every_min_n_samples=1,
)
# TODO: This should be exactly 60. There is a corresponding TODO in the repeat_dataset.py
assert len(train_loader) == 58, len(train_loader)
data = list(enumerate(train_loader))
# Check the overall number of samples
# Should be 0.7*len(ds1) + 1.5*len(ds2) = 0.7*55 + 1.5*55 = 38 + 55 + 27 (floor rounding)
# TODO: This should be exactly 60. There is a corresponding TODO in the repeat_dataset.py
assert len(data) == 58, len(data)
@patch.object(WatchdogDataset, "_watchdog_trigger")
def test_watchdog_dataset(self, mock_watchdog_trigger):
class TestTaskEncoder(DefaultTaskEncoder):
def __init__(self):
super().__init__()
self.did_sleep = False
def encode_sample(self, sample: TextSample) -> TextSample:
if sample.text == "13":
import time
if not self.did_sleep:
print("Sleeping for 5 seconds on encode_sample to simulate stuck worker")
time.sleep(5)
self.did_sleep = True
return sample
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=0,
seed_offset=42,
)
# Train mode dataset
train_dataset = get_train_dataset(
self.mds_path,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=None,
max_samples_per_sequence=None,
task_encoder=TestTaskEncoder(),
)
train_loader = get_loader(
train_dataset,
watchdog_timeout_seconds=3,
fail_on_timeout=False,
)
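# The task encoder above sleeps 5 seconds on one sample; with a 3 second watchdog timeout the (mocked) watchdog trigger must fire.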
for idx, data in enumerate(train_loader):
print(idx, data.text[0])
if idx > 255:
break
mock_watchdog_trigger.assert_called()
def test_dataset_absolute_nested_subset_fail(self):
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=0,
seed_offset=42,
)
ratio_mds_path = self.dataset_path / "metadataset_ratio.yaml"
with open(ratio_mds_path, "w") as f:
f.write(
"\n".join(
[
"__module__: megatron.energon",
"__class__: MetadatasetV2",
"splits:",
" train:",
# Absolute range on outer level should fail
" subset: {range: [50, 55]}",
" blend_epochized:",
" - path: ds1",
" subflavors:",
" source: ds1",
" number: 43",
" - repetitions: 2",
" path: ds2",
" subflavors:",
" source: ds2",
" number: 42",
]
)
)
try:
get_loader(
get_train_dataset(
ratio_mds_path,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=None,
shuffle_over_epochs_multiplier=None,
parallel_shard_iters=1,
max_samples_per_sequence=None,
repeat=False,
)
)
assert False, "Should have failed"
except Exception as e:
assert "only allowed for a leaf dataset" in str(
e
) or "only use absolute subset ranges for a leaf dataset" in str(e), str(e)
return
def test_dataset_with_subset_end_keyword(self):
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=0,
seed_offset=42,
)
ratio_mds_path = self.dataset_path / "metadataset_ratio.yaml"
with open(ratio_mds_path, "w") as f:
f.write(
"\n".join(
[
"__module__: megatron.energon",
"__class__: MetadatasetV2",
"splits:",
" train:",
# Absolute range: [50, end]
# I.e. corresponds to sample range: [50, 55] (end is not included, so up to 54)
" subset: {range: [50, end]}",
" path: ds1",
" subflavors:",
" source: ds1",
" number: 43",
]
)
)
loader = get_loader(
get_train_dataset(
ratio_mds_path,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=None,
shuffle_over_epochs_multiplier=None,
parallel_shard_iters=1,
max_samples_per_sequence=None,
repeat=False,
)
)
all_numbers = [int(s.text[0]) for s in loader]
assert all_numbers == [50, 51, 52, 53, 54], "Subset range [50, end] should be [50, 55]"
def test_dataset_with_subset_ratio(self):
worker_config = WorkerConfig(
rank=0,
world_size=1,
num_workers=0,
seed_offset=42,
)
ratio_mds_path = self.dataset_path / "metadataset_ratio.yaml"
with open(ratio_mds_path, "w") as f:
f.write(
"\n".join(
[
"__module__: megatron.energon",
"__class__: MetadatasetV2",
"splits:",
" train:",
# 20% of the dataset will be from ds1, 80% from ds2
# I.e. sample range: [0.2*55, 0.8*55] = [11, 44]
" subset: {range: [20%, 80%]}",
" blend_epochized:",
" - path: ds1",
" subflavors:",
" source: ds1",
" number: 43",
" - repetitions: 2",
" path: ds2",
" subflavors:",
" source: ds2",
" number: 42",
]
)
)
loader = get_loader(
get_train_dataset(
ratio_mds_path,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=None,
shuffle_over_epochs_multiplier=None,
parallel_shard_iters=1,
max_samples_per_sequence=None,
repeat=False,
)
)
data = list(enumerate(loader))
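# The [20%, 80%] subset of 55 samples keeps indices 11..43 (33 samples) of each dataset; ds2 is additionally repeated twice.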
assert len(data) == 33 + 33 * 2, len(data)
sample_counts = Counter([int(s[1].text[0]) for s in data])
assert all(sample_counts[sample] == 0 for sample in range(11)), sample_counts
assert all(sample_counts[sample] == 1 for sample in range(11, 44)), sample_counts
assert all(sample_counts[sample] == 0 for sample in range(44, 55)), sample_counts
assert all(sample_counts[sample] == 0 for sample in range(100, 111)), sample_counts
assert all(sample_counts[sample] == 2 for sample in range(111, 144)), sample_counts
assert all(sample_counts[sample] == 0 for sample in range(144, 155)), sample_counts
assert sample_counts.total() == 33 + 33 * 2, sample_counts.total()
# Combine with subset_samples
ratio2_mds_path = self.dataset_path / "metadataset_ratio2.yaml"
with open(ratio2_mds_path, "w") as f:
f.write(
"\n".join(
[
"__module__: megatron.energon",
"__class__: MetadatasetV2",
"splits:",
" train:",
# take [10, 30] from ds1, [20, 40] from ds2 and then only [20%, 80%]
# I.e. sample range: [14, 26], 2 * [124, 136]
" subset: {range: [20%, 80%]}",
" blend_epochized:",
" - path: ds1",
" subset: {range: [10, 30]}",
" subflavors:",
" source: ds1",
" number: 43",
" - repetitions: 2",
" subset: {range: [20, 40]}",
" path: ds2",
" subflavors:",
" source: ds2",
" number: 42",
]
)
)
loader = get_loader(
get_train_dataset(
ratio2_mds_path,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=None,
shuffle_over_epochs_multiplier=None,
parallel_shard_iters=1,
max_samples_per_sequence=None,
repeat=False,
)
)
data = list(enumerate(loader))
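# Each inner subset has 20 samples; the outer [20%, 80%] keeps 12 of each (ds1 once, ds2 twice).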
assert len(data) == 12 + 12 * 2, len(data)
sample_counts = Counter([int(s[1].text[0]) for s in data])
assert all(sample_counts[sample] == 0 for sample in range(14)), sample_counts
assert all(sample_counts[sample] == 1 for sample in range(14, 26)), sample_counts
assert all(sample_counts[sample] == 0 for sample in range(26, 55)), sample_counts
assert all(sample_counts[sample] == 0 for sample in range(100, 124)), sample_counts
assert all(sample_counts[sample] == 2 for sample in range(124, 136)), sample_counts
assert all(sample_counts[sample] == 0 for sample in range(136, 155)), sample_counts
assert sample_counts.total() == 12 + 12 * 2, sample_counts.total()
# Combine with subset_ratio and subset_samples and nested metadataset
nested_mds_path = self.dataset_path / "metadataset_nested_subset.yaml"
with open(nested_mds_path, "w") as f:
f.write(
"\n".join(
[
"__module__: megatron.energon",
"__class__: MetadatasetV2",
"splits:",
" train:",
" subset: {range: [0%, 50%]}",
" blend_epochized:",
" - path: ds3",
# take [30, 50] from ds3, then first 50%, resulting in samples [230, 240]
" subset: {range: [30, 50]}",
" subflavors:",
" source: ds3",
" number: 45",
" - repetitions: 2",
# Inner sample range: [14, 26], 2 * [124, 136], total=12*3=36
# Applying subset ratio 25%-75%: [17, 23], 2*[127, 133], total=3*6=18
# Applying outer 50%: [17, 20], 2*[127, 130], total=3*3=9
# Applying repetition: 2*[17, 20], 4*[127, 130], total=2*9=18
" subset: {range: [25%, 75%]}",
" path: metadataset_ratio2.yaml",
]
)
)
loader = get_loader(
get_train_dataset(
nested_mds_path,
worker_config=worker_config,
batch_size=1,
shuffle_buffer_size=None,
shuffle_over_epochs_multiplier=None,
parallel_shard_iters=1,
max_samples_per_sequence=None,
repeat=False,
)
)
data = list(enumerate(loader))
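# 10 samples from ds3 (texts 230..239) plus two repetitions of the 9 samples selected from the nested metadataset.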
assert len(data) == 10 + 9 * 2, len(data)
sample_counts = Counter([int(s[1].text[0]) for s in data])
assert all(sample_counts[sample] == 0 for sample in range(17)), sample_counts
assert all(sample_counts[sample] == 2 for sample in range(17, 20)), sample_counts
assert all(sample_counts[sample] == 0 for sample in range(20, 55)), sample_counts
assert all(sample_counts[sample] == 0 for sample in range(100, 127)), sample_counts
assert all(sample_counts[sample] == 4 for sample in range(127, 130)), sample_counts
assert all(sample_counts[sample] == 0 for sample in range(130, 155)), sample_counts
assert all(sample_counts[sample] == 0 for sample in range(200, 230)), sample_counts
assert all(sample_counts[sample] == 1 for sample in range(230, 240)), sample_counts
assert all(sample_counts[sample] == 0 for sample in range(240, 255)), sample_counts
assert sample_counts.total() == 10 + 9 * 2, sample_counts.total()
def test_s3(self):
# Create a joined dataset configuration
mixed_mds_path = self.dataset_path / "metadataset_mixed.yaml"
with open(mixed_mds_path, "w") as f:
f.write(
"\n".join(
[
"__module__: megatron.energon",
"__class__: MetadatasetV2",
"splits:",
" train:",
" path: msc://s3test_metadataset/test/dataset/nested_metadataset_v2.yaml",
]
)
)
with setup_s3_emulator(profile_name="s3test_metadataset") as emu:
# Upload the dataset to the S3 emulator
# EPath(self.dataset_path).copy(EPath("msc://s3/test/dataset"))
emu.add_file(self.dataset_path, "test/dataset")
train_dataset = get_loader(
get_train_dataset(
mixed_mds_path,
worker_config=WorkerConfig(
rank=0,
world_size=1,
num_workers=2,
),
batch_size=1,
shuffle_buffer_size=10,
max_samples_per_sequence=None,
virtual_epoch_length=10,
)
)
data = list(enumerate(train_dataset))
assert len(data) == 10, len(data)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
"""This module tests image transforms."""
import logging
import random
import sys
import unittest
from typing import Any, Tuple
import numpy as np
import torch
from PIL import Image
from torchvision import transforms as T
from torchvision.transforms import InterpolationMode
from megatron.energon.transforms import CustomTransform, MergeTransform
class TestTransforms(unittest.TestCase):
# Set up the test fixture
def setUp(self):
logging.basicConfig(stream=sys.stderr, level=logging.INFO)
def tearDown(self):
pass
@staticmethod
def reset_rng(seed: int = 42):
"""Creates a PIL image with random noise."""
np.random.seed(seed)
torch.manual_seed(seed)
random.seed(seed)
@staticmethod
def get_test_image(width: int, height: int):
"""Creates a PIL image with random noise."""
arr = np.zeros((width, height, 3), dtype=np.uint8)
# Some colorful borders
arr[0, :, :] = [255, 0, 0]
arr[:, 0, :] = [255, 255, 0]
arr[-1, :, :] = [255, 255, 255]
arr[:, -1, :] = [0, 255, 0]
# A single white pixel
if width > 3 and height > 3:
arr[3, 3, :] = [255, 255, 255]
# And in the middle some noise
if width > 10 and height > 10:
arr[5:-5, 5:-5, :] = np.random.randint(0, 255, (width - 10, height - 10, 3))
return Image.fromarray(arr)
@staticmethod
def get_test_image_soft(width: int, height: int):
"""Creates a PIL image smooth content"""
arr = np.zeros((width, height, 3), dtype=np.uint8)
# Fill the red channel with a smooth gradient from left to right.
arr[:, :, 0] = np.arange(width)[:, None] / width * 255
# The same for green from top to bottom:
arr[:, :, 1] = np.arange(height)[None, :] / height * 255
return Image.fromarray(arr)
def _apply_and_compare(
self, testable_transform, img, atol=2, seed=42, msg=None, only_nonblack=False
):
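"""Apply the transform via MergeTransform and via torchvision directly, then assert both results match."""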
# Then transform using our method
merge_transform = MergeTransform([testable_transform])
self.reset_rng(seed=seed)
test_result = merge_transform(img)
# And also transform using torchvision directly
self.reset_rng(seed=seed)
ref_result = testable_transform(img)
# Then compare the sizes and the images contents
self.assertEqual(test_result.size, ref_result.size)
# Check that image contents are close
np_test = np.array(test_result)
np_ref = np.array(ref_result)
if only_nonblack:
nonblack_mask = (np_test > 0) & (np_ref > 0)
np_test = np_test[nonblack_mask]
np_ref = np_ref[nonblack_mask]
# The maximum allowed absolute difference between pixel values is atol (default 2, uint8)
self.assertTrue(np.allclose(np_test, np_ref, atol=atol), msg=msg)
def test_resize(self):
"""Tests ResizeMapper"""
MAX_SIZE = 150
# These are the different setups we test. Each entry is a tuple of
# (source size, resize_kwargs)
size_list = [ # source size (w, h), resize_kwargs
[(100, 100), {"size": (100, 100)}],
[(200, 50), {"size": (100, 100)}],
[(50, 50), {"size": (100, 100)}],
[(500, 500), {"size": (10, 10)}],
[(1, 2), {"size": (1, 3)}], # Scale width by 1.5x
[(50, 100), {"size": 100, "max_size": MAX_SIZE}], # Test max_size
]
for source_size, resize_kwargs in size_list:
logging.info(
f"Testing Resize with source size {source_size} and resize_kwargs {resize_kwargs}"
)
# Create a test image of the given source size
img = TestTransforms.get_test_image(*source_size)
transform = T.Resize(**resize_kwargs, interpolation=InterpolationMode.NEAREST)
self._apply_and_compare(
transform,
img,
msg=f"Resize: source_size={source_size}, resize_kwargs={resize_kwargs}",
)
def test_random_resized_crop(self):
"""Tests RandomResizedCropMapper"""
randcrop = T.RandomResizedCrop(
90, scale=(0.3, 0.7), ratio=(0.75, 1.3), interpolation=InterpolationMode.BILINEAR
)
source_size = (50, 60)
logging.info(f"Testing RandomResizedCrop with source size {source_size}")
# Create a test image of the given source size
img = TestTransforms.get_test_image_soft(*source_size)
self._apply_and_compare(randcrop, img, msg="RandomResizedCrop")
def test_random_flip(self):
source_size = (55, 33)
img = TestTransforms.get_test_image(*source_size)
logging.info("Testing RandomHorizontalFlip 5 times")
for idx in range(5):
randhflip = T.RandomHorizontalFlip(p=0.8)
self._apply_and_compare(randhflip, img, seed=idx, msg="RandomHorizontalFlip")
logging.info("Testing RandomVerticalFlip 5 times")
for idx in range(5):
randvflip = T.RandomVerticalFlip(p=0.8)
self._apply_and_compare(randvflip, img, seed=idx, msg="RandomVerticalFlip")
def test_random_rotation(self):
source_size = (55, 33)
img = TestTransforms.get_test_image_soft(*source_size)
logging.info("Testing RandomRotation without expand")
for idx in range(5):
randrot = T.RandomRotation((-90, 269), interpolation=InterpolationMode.BILINEAR)
self._apply_and_compare(
randrot,
img,
seed=idx,
msg="RandomRotation without expand",
)
logging.info("Testing RandomRotation with expand")
for idx in range(5):
randrot = T.RandomRotation(
(-180, 269), interpolation=InterpolationMode.BILINEAR, expand=True
)
self._apply_and_compare(
randrot,
img,
seed=idx,
msg="RandomRotation with expand",
)
def test_random_crop(self):
source_size = (155, 120)
img = TestTransforms.get_test_image(*source_size)
size_list = [ # crop size (w, h)
(155, 120), # Same size
(100, 50),
3, # Single int as size
120,
(155, 8), # One dimension same size
]
logging.info("Testing RandomCrop")
for idx, size in enumerate(size_list):
randcrop = T.RandomCrop(size)
self._apply_and_compare(
randcrop,
img,
seed=idx,
msg=f"RandomCrop: crop size={size}",
)
# Test `pad_if_needed` (Crop size larger than image size)
randcrop = T.RandomCrop((500, 500), pad_if_needed=True)
self._apply_and_compare(randcrop, img)
def test_random_perspective(self):
source_size = (128, 133)
img = TestTransforms.get_test_image_soft(*source_size)
logging.info("Testing RandomPerspective")
for idx in range(5):
randpersp = T.RandomPerspective(interpolation=InterpolationMode.BILINEAR)
self._apply_and_compare(
randpersp,
img,
seed=idx,
msg=f"RandomPerspective: source_size={source_size}",
only_nonblack=True, # Sometimes one pixel is off
)
def test_center_crop(self):
source_size_list = [ # source size (w, h)
(155, 120),
(154, 119),
]
crop_size_list = [ # crop size (w, h)
(155, 120), # Same size
(100, 50),
3, # Single int as size
120,
(200, 50), # Larger than the image in x direction
(50, 200), # Larger than the image in y direction
(200, 200), # Larger than the image in both directions
]
logging.info("Testing CenterCrop")
for source_size in source_size_list:
img = TestTransforms.get_test_image(*source_size)
for idx, crop_size in enumerate(crop_size_list):
centcrop = T.CenterCrop(crop_size)
self._apply_and_compare(
centcrop,
img,
seed=idx,
msg=f"CenterCrop: source_size={source_size}, crop_size={crop_size}",
)
def test_custom(self):
"""Tests if a custom transform works"""
source_size = (128, 133)
class FixedTranslate(CustomTransform):
"""Translates the image by 5 pixels in both x and y direction"""
def __init__(self):
pass
def apply_transform(
self, matrix: np.ndarray, dst_size: np.ndarray
) -> Tuple[Any, Any, Any]:
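# Compose a fixed 5-pixel translation into the accumulated affine matrix and
# return the unchanged destination size together with the applied parameters.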
matrix = self.translate(5, 5) @ matrix
return matrix, dst_size, (self.__class__.__name__, (5, 5))
img = TestTransforms.get_test_image(*source_size)
merge_transform = MergeTransform([FixedTranslate()])
test_result = merge_transform(img)
reference_img = Image.new(img.mode, img.size, (0, 0, 0))
reference_img.paste(img, (5, 5))
self.assertTrue(
np.allclose(np.array(test_result), np.array(reference_img), atol=1),
msg="FixedTranslate",
)
def test_merge(self):
"""Tests if two merged transforms yield the same result.
Merging RandomCrop and RandomPerspective."""
source_size = (128, 133)
img = TestTransforms.get_test_image_soft(*source_size)
randcrop = T.RandomCrop((70, 70))
randrot = T.RandomRotation((45, 269), interpolation=InterpolationMode.BILINEAR)
merge_transform = MergeTransform([randrot, randcrop])
self.reset_rng(1)
test_result = merge_transform(img)
self.reset_rng(1)
ref_result = randcrop(randrot(img))
self.assertTrue(
np.allclose(np.array(test_result), np.array(ref_result), atol=1),
msg="MergeTransform of RandomRotation and RandomCrop",
)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import multiprocessing
import unittest
from dataclasses import dataclass
from megatron.energon.fork_hook import (
DataclassForkMixin,
ForkMixin,
WeakCallbacks,
after_in_child_fork_hook,
after_in_parent_fork_hook,
before_fork_hook,
)
class TestWeakref(unittest.TestCase):
def test_weakcallbacks(self):
# Just test the internal WeakCallbacks class.
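# The registry holds hooks via weak references: registering the same bound
# method twice keeps a single entry, and deleting the owning object (or the
# function) removes its hook automatically.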
a_called = 0
fn_called = 0
class A:
def method(self):
nonlocal a_called
a_called += 1
def fn():
nonlocal fn_called
fn_called += 1
a = A()
registry = WeakCallbacks()
registry.add_hook(a.method)
registry.add_hook(fn)
registry.add_hook(a.method)
registry.run()
assert a_called == 1, a_called
assert fn_called == 1, fn_called
assert len(registry._hooks) == 2, len(registry._hooks)
del a
assert len(registry._hooks) == 1, len(registry._hooks)
registry.run()
assert a_called == 1, a_called
assert fn_called == 2, fn_called
del fn
assert len(registry._hooks) == 0, len(registry._hooks)
registry.run()
assert a_called == 1, a_called
assert fn_called == 2, fn_called
assert len(registry._hooks) == 0, len(registry._hooks)
def test_fork_weakref(self):
# Verify that the fork hooks are called correctly, and that gc works correctly.
_a_before_fork_called = 0
_a_after_in_child_fork_called = 0
_a_after_in_parent_fork_called = 0
class A(ForkMixin):
def __before_fork__(self):
nonlocal _a_before_fork_called
_a_before_fork_called += 1
def __after_in_child_fork__(self):
nonlocal _a_after_in_child_fork_called
_a_after_in_child_fork_called += 1
def __after_in_parent_fork__(self):
nonlocal _a_after_in_parent_fork_called
_a_after_in_parent_fork_called += 1
_b_before_fork_called = 0
_b_after_in_child_fork_called = 0
_b_after_in_parent_fork_called = 0
@dataclass
class B(DataclassForkMixin):
def __before_fork__(self):
nonlocal _b_before_fork_called
_b_before_fork_called += 1
def __after_in_child_fork__(self):
nonlocal _b_after_in_child_fork_called
_b_after_in_child_fork_called += 1
def __after_in_parent_fork__(self):
nonlocal _b_after_in_parent_fork_called
_b_after_in_parent_fork_called += 1
a = A()
b = B()
_before_fork_called = 0
_after_in_child_fork_called = 0
_after_in_parent_fork_called = 0
def before_fork():
nonlocal _before_fork_called
_before_fork_called += 1
def after_in_child_fork():
nonlocal _after_in_child_fork_called
_after_in_child_fork_called += 1
def after_in_parent_fork():
nonlocal _after_in_parent_fork_called
_after_in_parent_fork_called += 1
before_fork_hook(before_fork)
after_in_child_fork_hook(after_in_child_fork)
after_in_parent_fork_hook(after_in_parent_fork)
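# The hooks fire around process forking, so force the "fork" start method;
# a spawn-based start would not exercise them.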
multiprocessing.set_start_method("fork", force=True)
def process_verify_fork_hooks_1():
# Verify in the process that the fork hooks were called
assert _before_fork_called == 1, _before_fork_called
assert _after_in_child_fork_called == 1, _after_in_child_fork_called
# This was not called in the child process
assert _after_in_parent_fork_called == 0, _after_in_parent_fork_called
assert _a_before_fork_called == 1, _a_before_fork_called
assert _a_after_in_child_fork_called == 1, _a_after_in_child_fork_called
assert _a_after_in_parent_fork_called == 0, _a_after_in_parent_fork_called
assert _b_before_fork_called == 1, _b_before_fork_called
assert _b_after_in_child_fork_called == 1, _b_after_in_child_fork_called
assert _b_after_in_parent_fork_called == 0, _b_after_in_parent_fork_called
p1 = multiprocessing.Process(target=process_verify_fork_hooks_1)
p1.start()
p1.join()
assert p1.exitcode == 0, p1.exitcode
assert _before_fork_called == 1, _before_fork_called
assert _after_in_child_fork_called == 0, _after_in_child_fork_called
assert _after_in_parent_fork_called == 1, _after_in_parent_fork_called
assert _a_before_fork_called == 1, _a_before_fork_called
assert _a_after_in_child_fork_called == 0, _a_after_in_child_fork_called
assert _a_after_in_parent_fork_called == 1, _a_after_in_parent_fork_called
assert _b_before_fork_called == 1, _b_before_fork_called
assert _b_after_in_child_fork_called == 0, _b_after_in_child_fork_called
assert _b_after_in_parent_fork_called == 1, _b_after_in_parent_fork_called
_a_before_fork_called = 0
_a_after_in_child_fork_called = 0
_a_after_in_parent_fork_called = 0
_b_before_fork_called = 0
_b_after_in_child_fork_called = 0
_b_after_in_parent_fork_called = 0
_before_fork_called = 0
_after_in_child_fork_called = 0
_after_in_parent_fork_called = 0
del a
del b
del before_fork
del after_in_child_fork
del after_in_parent_fork
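# With the last strong references gone, the weakly-held hooks should be
# collected, so a second fork must not trigger any of the callbacks.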
def process_verify_fork_hooks_2():
assert _before_fork_called == 0, _before_fork_called
assert _after_in_child_fork_called == 0, _after_in_child_fork_called
assert _after_in_parent_fork_called == 0, _after_in_parent_fork_called
assert _a_before_fork_called == 0, _a_before_fork_called
assert _a_after_in_child_fork_called == 0, _a_after_in_child_fork_called
assert _a_after_in_parent_fork_called == 0, _a_after_in_parent_fork_called
assert _b_before_fork_called == 0, _b_before_fork_called
assert _b_after_in_child_fork_called == 0, _b_after_in_child_fork_called
assert _b_after_in_parent_fork_called == 0, _b_after_in_parent_fork_called
p2 = multiprocessing.Process(target=process_verify_fork_hooks_2)
p2.start()
p2.join()
assert p2.exitcode == 0, p2.exitcode
assert _before_fork_called == 0, _before_fork_called
assert _after_in_child_fork_called == 0, _after_in_child_fork_called
assert _after_in_parent_fork_called == 0, _after_in_parent_fork_called
assert _a_before_fork_called == 0, _a_before_fork_called
assert _a_after_in_child_fork_called == 0, _a_after_in_child_fork_called
assert _a_after_in_parent_fork_called == 0, _a_after_in_parent_fork_called
assert _b_before_fork_called == 0, _b_before_fork_called
assert _b_after_in_child_fork_called == 0, _b_after_in_child_fork_called
assert _b_after_in_parent_fork_called == 0, _b_after_in_parent_fork_called
[flake8]
max-line-length = 100
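# Ignored checks: E203 (whitespace before ':'), E501 (line too long),
# F401 (unused import), E402 (module-level import not at top of file),
# E714 (use 'is not' for identity tests)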
extend-ignore = E203,E501,F401,E402,E714
per-file-ignores = __init__.py:F401
---
name: BUG
about: Report a bug that needs attention
title: "[BUG]"
labels: ''
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**To Reproduce**
Steps to reproduce the behavior. The easier it is to reproduce, the faster it will get maintainer attention.
**Expected behavior**
A clear and concise description of what you expected to happen.
**Stack trace/logs**
If applicable, add the stack trace or logs from the time of the error.
**Environment (please complete the following information):**
- Megatron-LM commit ID
- PyTorch version
- CUDA version
- NCCL version
**Proposed fix**
If you have a proposal for how to fix the issue state it here or link to a PR.
**Additional context**
Add any other context about the problem here.
---
name: ENHANCEMENT
about: Suggest an idea to improve this project
title: "[ENHANCEMENT]"
labels: ''
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Proposed implementation**
If you have a proposed implementation for the feature state it here or link to a PR.
**Additional context**
Add any other context or screenshots about the feature request here.
---
name: QUESTION
about: Ask a question about Megatron-LM that is not a bug, regression or enhancement request
title: "[QUESTION]"
labels: ''
assignees: ''
---
**Your question**
Ask a clear and concise question about Megatron-LM.
---
name: REGRESSION
about: Report a regression in speed or accuracy due to a Megatron-LM update
title: "[REGRESSION]"
labels: ''
assignees: ''
---
**Describe the regression**
A clear and concise description of what the regression is.
**To Reproduce**
Steps to reproduce the behavior. The easier it is to reproduce, the faster it will get maintainer attention.
**Previous performance**
What speed or accuracy did you previously see?
**New performance**
What speed or accuracy do you see after the update?
**Stack trace/logs**
If applicable, add the stack trace or logs related to the regression.
**Environment (please complete the following information):**
- Previous Megatron-LM commit ID
- New Megatron-LM commit ID
- Previous PyTorch version
- New PyTorch version
- Previous CUDA version
- New CUDA version
- Previous NCCL version
- New NCCL version
**Proposed fix**
If you have a proposal for how to fix the issue state it here or link to a PR.
**Additional context**
Add any other context about the problem here.
# This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time.
#
# You can adjust the behavior by modifying this file.
# For more information, see:
# https://github.com/actions/stale
name: Mark stale issues and pull requests
on:
schedule:
- cron: '15 18 * * *'
jobs:
stale:
runs-on: ubuntu-latest
permissions:
issues: write
pull-requests: write
steps:
- uses: actions/stale@v5
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
days-before-stale: 60
stale-issue-message: 'Marking as stale. No activity in 60 days.'
stale-pr-message: 'Marking as stale. No activity in 60 days.'
stale-issue-label: 'stale'
stale-pr-label: 'stale'
remove-stale-when-updated: true
operations-per-run: 1000
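# -1 disables automatic closing; items are only labeled as stale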
days-before-close: -1