Commit c1cacde6 authored by weishb

vllm-omni_0.15.0.rc1+fix1 first commit

parent 35607782
import subprocess
from pathlib import Path
import pytest
from tests.conftest import OmniServer
models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]
stage_configs = [str(Path(__file__).parent.parent / "e2e" / "stage_configs" / "qwen3_omni_ci.yaml")]
# Create parameter combinations for model and stage config
test_params = [(model, stage_config) for model in models for stage_config in stage_configs]
@pytest.fixture(scope="module")
def omni_server(request):
"""Start vLLM-Omni server as a subprocess with actual model weights.
Uses module scope so the server starts only once per parameter combination in this module.
Multi-stage initialization can take 10-20+ minutes.
"""
model, stage_config_path = request.param
print(f"Starting OmniServer with model: {model}")
print("This may take 10-20+ minutes for initialization...")
with OmniServer(model, ["--stage-configs-path", stage_config_path, "--stage-init-timeout", "120"]) as server:
print("OmniServer started successfully")
yield server
print("OmniServer stopped")
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_bench_serve_chat(omni_server):
command = [
"vllm",
"bench",
"serve",
"--omni",
"--model",
omni_server.model,
"--port",
str(omni_server.port),
"--dataset-name",
"random",
"--random-input-len",
"32",
"--random-output-len",
"4",
"--num-prompts",
"5",
"--endpoint",
"/v1/chat/completions",
"--backend",
"openai-chat-omni",
]
result = subprocess.run(command, capture_output=True, text=True)
print(result.stdout)
print(result.stderr)
assert result.returncode == 0, f"Benchmark failed: {result.stderr}"
import base64
import datetime
import io
import math
import os
import random
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
# Set CPU device for CI environments without GPU
if "VLLM_TARGET_DEVICE" not in os.environ:
os.environ["VLLM_TARGET_DEVICE"] = "cpu"
import gc
import socket
import subprocess
import sys
import time
from pathlib import Path
from typing import Any
import numpy as np
import psutil
import pytest
import torch
import yaml
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
from vllm.logger import init_logger
from vllm.utils.network_utils import get_open_port
logger = init_logger(__name__)
@pytest.fixture(autouse=True)
def default_vllm_config():
"""Set a default VllmConfig for all tests.
This fixture is auto-used for all tests to ensure that any test
that directly instantiates vLLM CustomOps (e.g., RMSNorm, LayerNorm)
or model components has the required VllmConfig context.
This fixture is required for vLLM 0.14.0+ where CustomOp initialization
requires a VllmConfig context set via set_current_vllm_config().
"""
from vllm.config import DeviceConfig, VllmConfig, set_current_vllm_config
# Use CPU device if no GPU is available (e.g., in CI environments)
has_gpu = torch.cuda.is_available() and torch.cuda.device_count() > 0
device = "cuda" if has_gpu else "cpu"
device_config = DeviceConfig(device=device)
with set_current_vllm_config(VllmConfig(device_config=device_config)):
yield
@pytest.fixture(autouse=True)
def clean_gpu_memory_between_tests():
print("\n=== PRE-TEST GPU CLEANUP ===")
_run_pre_test_cleanup()
yield
_run_post_test_cleanup()
def _run_pre_test_cleanup(enable_force=False):
if os.getenv("VLLM_TEST_CLEAN_GPU_MEMORY", "0") != "1" and not enable_force:
print("GPU cleanup disabled")
return
print("Pre-test GPU status:")
num_gpus = torch.cuda.device_count()
if num_gpus > 0:
try:
from tests.utils import wait_for_gpu_memory_to_clear
wait_for_gpu_memory_to_clear(
devices=list(range(num_gpus)),
threshold_ratio=0.05,
)
except Exception as e:
print(f"Pre-test cleanup note: {e}")
def _run_post_test_cleanup(enable_force=False):
if os.getenv("VLLM_TEST_CLEAN_GPU_MEMORY", "0") != "1" and not enable_force:
print("GPU cleanup disabled")
return
if torch.cuda.is_available():
gc.collect()
torch.cuda.empty_cache()
print("Post-test GPU status:")
_print_gpu_processes()
def _print_gpu_processes():
"""Print GPU information including nvidia-smi and system processes"""
print("\n" + "=" * 80)
print("NVIDIA GPU Information (nvidia-smi)")
print("=" * 80)
try:
nvidia_result = subprocess.run(
["nvidia-smi"],
capture_output=True,
text=True,
timeout=5,
)
if nvidia_result.returncode == 0:
lines = nvidia_result.stdout.strip().split("\n")
for line in lines[:20]:
print(line)
if len(lines) > 20:
print(f"... (showing first 20 of {len(lines)} lines)")
else:
print("nvidia-smi command failed")
except (subprocess.TimeoutExpired, FileNotFoundError):
print("nvidia-smi not available or timed out")
except Exception as e:
print(f"Error running nvidia-smi: {e}")
print("\n" + "=" * 80)
print("Detailed GPU Processes (nvidia-smi pmon)")
print("=" * 80)
try:
pmon_result = subprocess.run(
["nvidia-smi", "pmon", "-c", "1"],
capture_output=True,
text=True,
timeout=3,
)
if pmon_result.returncode == 0 and pmon_result.stdout.strip():
print(pmon_result.stdout)
else:
print("No active GPU processes found via nvidia-smi pmon")
except Exception:
print("nvidia-smi pmon not available")
print("\n" + "=" * 80)
print("System Processes with GPU keywords")
print("=" * 80)
def dummy_messages_from_mix_data(
system_prompt: dict[str, Any] | None = None,
video_data_url: Any = None,
audio_data_url: Any = None,
image_data_url: Any = None,
content_text: str | None = None,
):
"""Create messages with video、image、audio data URL for OpenAI API."""
if content_text is not None:
content = [{"type": "text", "text": content_text}]
else:
content = []
media_items = []
if isinstance(video_data_url, list):
for video_url in video_data_url:
media_items.append((video_url, "video"))
else:
media_items.append((video_data_url, "video"))
if isinstance(image_data_url, list):
for url in image_data_url:
media_items.append((url, "image"))
else:
media_items.append((image_data_url, "image"))
if isinstance(audio_data_url, list):
for url in audio_data_url:
media_items.append((url, "audio"))
else:
media_items.append((audio_data_url, "audio"))
content.extend(
{"type": f"{media_type}_url", f"{media_type}_url": {"url": url}}
for url, media_type in media_items
if url is not None
)
messages = [{"role": "user", "content": content}]
if system_prompt is not None:
messages = [system_prompt] + messages
return messages
def generate_synthetic_audio(
duration: int, # seconds
num_channels: int,  # 1: mono, 2: stereo, 5: 5.1 surround sound
sample_rate: int = 48000,  # Defaults to 48000 Hz.
save_to_file: bool = False,
) -> dict[str, Any]:
""" "Generate synthetic audio with rain."""
import soundfile as sf
# Initialize audio data array
num_samples = int(sample_rate * duration)
audio_data = np.zeros((num_samples, num_channels), dtype=np.float32)
# Configure parameters based on rain intensity
drop_density = 10 # Number of raindrops per second
drop_volume = 0.15 # Volume of individual raindrops
background_volume = 0.02 # Volume of background rain noise
# Pink noise sounds more natural than white noise for rain
white_noise = np.random.randn(num_samples)
pink_noise = np.convolve(white_noise, np.ones(8) / 8, mode="same")
pink_noise = pink_noise / np.max(np.abs(pink_noise)) if np.max(np.abs(pink_noise)) > 0 else pink_noise
bg_noise = pink_noise * background_volume
# Add background noise to all channels
for ch in range(num_channels):
audio_data[:, ch] += bg_noise
# Total number of raindrops = density × duration × channels for stereo effect
total_drops = int(drop_density * duration * num_channels)
for _ in range(total_drops):
# Random timing for raindrop
drop_time = random.uniform(0, duration)
# Random duration of raindrop sound (0.01-0.05 seconds)
drop_duration = random.uniform(0.01, 0.05)
# Random frequency gives variation in raindrop pitch
drop_freq = random.uniform(500, 5000) # Hz
# Random channel selection for stereo positioning
channel = random.randint(0, num_channels - 1)
# Calculate sample positions for this raindrop
start_sample = int(drop_time * sample_rate)
drop_samples = int(drop_duration * sample_rate)
end_sample = min(start_sample + drop_samples, num_samples)
if start_sample < end_sample:
# Generate the raindrop sound
num_drop_samples = end_sample - start_sample
t = np.arange(num_drop_samples) / sample_rate
# Basic sine wave for raindrop sound
drop_sound = drop_volume * np.sin(2 * math.pi * drop_freq * t)
# Apply envelope for natural attack and decay
envelope = np.ones(num_drop_samples)
attack_samples = int(num_drop_samples * 0.1) # 10% of samples for attack
decay_samples = num_drop_samples - attack_samples
if attack_samples > 0:
# Linear attack: volume increases from 0 to 1
envelope[:attack_samples] = np.linspace(0, 1, attack_samples)
if decay_samples > 0:
# Exponential decay for natural sound fade
decay = np.exp(-8 * t[attack_samples:] / drop_duration)
envelope[attack_samples:] = decay
# Apply envelope to raindrop sound
drop_sound *= envelope
# Add raindrop sound to selected channel
audio_data[start_sample:end_sample, channel] += drop_sound
# Add a simple reverb effect for realism
# Reverb simulates sound reflections in the environment
if duration > 2:
# Single delay reverb (100ms delay)
delay_samples = int(0.1 * sample_rate)
if delay_samples < num_samples:
for ch in range(num_channels):
delayed = np.zeros(num_samples)
delayed[delay_samples:] = audio_data[:-delay_samples, ch] * 0.3
audio_data[:, ch] += delayed
# Normalize audio to prevent clipping
# Find the maximum amplitude and scale to 80% of full scale
max_amp = np.max(np.abs(audio_data))
if max_amp > 0:
audio_data = audio_data / max_amp * 0.8
# Handle file saving
audio_bytes = None
if save_to_file:
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
output_path = f"audio_{num_channels}ch_{timestamp}.wav"
try:
sf.write(output_path, audio_data, sample_rate, format="WAV", subtype="PCM_16")
print(f"Audio saved: {output_path}")
with open(output_path, "rb") as f:
audio_bytes = f.read()
except Exception as e:
print(f"Save failed: {e}")
save_to_file = False
# If not saving or save failed, create in memory
if not save_to_file or audio_bytes is None:
buffer = io.BytesIO()
sf.write(buffer, audio_data, sample_rate, format="WAV", subtype="PCM_16")
buffer.seek(0)
audio_bytes = buffer.read()
# Return result
base64_audio = base64.b64encode(audio_bytes).decode("utf-8")
result = {
"base64": base64_audio,
}
if save_to_file and output_path:
result["file_path"] = output_path
return result
def generate_synthetic_video(width: int, height: int, num_frames: int, save_to_file: bool = False) -> str:
"""Generate synthetic video with bouncing balls and return base64 string."""
import cv2
import imageio
# Create random balls
num_balls = random.randint(3, 8)
balls = []
for _ in range(num_balls):
radius = min(width, height) // 8
if radius < 1:
raise ValueError(f"Video dimensions ({width}x{height}) are too small for synthetic video generation")
x = random.randint(radius, width - radius)
y = random.randint(radius, height - radius)
speed = random.uniform(3.0, 8.0)
angle = random.uniform(0, 2 * math.pi)
vx = speed * math.cos(angle)
vy = speed * math.sin(angle)
# OpenCV uses BGR format, but imageio expects RGB
# We'll create in BGR first, then convert to RGB later
color_bgr = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255))
balls.append({"x": x, "y": y, "vx": vx, "vy": vy, "radius": radius, "color_bgr": color_bgr})
# Generate video frames
video_frames = []
for frame_idx in range(num_frames):
# Create black background (BGR format)
frame_bgr = np.zeros((height, width, 3), dtype=np.uint8)
for ball in balls:
# Update position
ball["x"] += ball["vx"]
ball["y"] += ball["vy"]
# Boundary collision detection
if ball["x"] - ball["radius"] <= 0 or ball["x"] + ball["radius"] >= width:
ball["vx"] = -ball["vx"]
ball["x"] = max(ball["radius"], min(width - ball["radius"], ball["x"]))
if ball["y"] - ball["radius"] <= 0 or ball["y"] + ball["radius"] >= height:
ball["vy"] = -ball["vy"]
ball["y"] = max(ball["radius"], min(height - ball["radius"], ball["y"]))
# Use cv2 to draw circle
x, y = int(ball["x"]), int(ball["y"])
radius = ball["radius"]
# Draw solid circle (main circle)
cv2.circle(frame_bgr, (x, y), radius, ball["color_bgr"], -1)
# Add simple 3D effect: draw a brighter center
if radius > 3: # Only add highlight when radius is large enough
highlight_radius = max(1, radius // 2)
highlight_x = max(highlight_radius, min(x - radius // 4, width - highlight_radius))
highlight_y = max(highlight_radius, min(y - radius // 4, height - highlight_radius))
# Create highlight color (brighter)
highlight_color = tuple(min(c + 40, 255) for c in ball["color_bgr"])
cv2.circle(frame_bgr, (highlight_x, highlight_y), highlight_radius, highlight_color, -1)
# Convert BGR to RGB for imageio
frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
video_frames.append(frame_rgb)
video_bytes = None
saved_file_path = None
buffer = io.BytesIO()
writer_kwargs = {
"format": "mp4",
"fps": 30,
"codec": "libx264",
"quality": 7,
"pixelformat": "yuv420p",
"macro_block_size": 16,
"ffmpeg_params": [
"-preset",
"medium",
"-crf",
"23",
"-movflags",
"+faststart",
"-pix_fmt",
"yuv420p",
"-vf",
f"scale={width}:{height}",
],
}
if save_to_file:
import datetime
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
output_path = f"video_{width}x{height}_{timestamp}.mp4"
try:
with imageio.get_writer(output_path, **writer_kwargs) as writer:
for frame in video_frames:
writer.append_data(frame)
saved_file_path = output_path
print(f"Video saved to: {saved_file_path}")
with open(output_path, "rb") as f:
video_bytes = f.read()
except Exception as e:
print(f"Warning: Failed to save video to file {output_path}: {e}")
save_to_file = False
if not save_to_file or video_bytes is None:
with imageio.get_writer(buffer, **writer_kwargs) as writer:
for frame in video_frames:
writer.append_data(frame)
buffer.seek(0)
video_bytes = buffer.read()
base64_video = base64.b64encode(video_bytes).decode("utf-8")
result = {
"base64": base64_video,
}
if save_to_file and saved_file_path:
result["file_path"] = saved_file_path
return result
def generate_synthetic_image(width: int, height: int, save_to_file: bool = False) -> Any:
"""Generate synthetic image with randomly colored squares and return base64 string."""
from PIL import Image, ImageDraw
# Create white background
image = Image.new("RGB", (width, height), (255, 255, 255))
draw = ImageDraw.Draw(image)
# Generate random number of squares
num_squares = random.randint(3, 8)
for _ in range(num_squares):
# Random square size
square_size = random.randint(min(width, height) // 8, min(width, height) // 4)
# Random position
x = random.randint(0, width - square_size - 1)
y = random.randint(0, height - square_size - 1)
# Random color
color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
# Random border width
border_width = random.randint(1, 5)
# Draw square
draw.rectangle([x, y, x + square_size, y + square_size], fill=color, outline=(0, 0, 0), width=border_width)
# Handle file saving
image_bytes = None
saved_file_path = None
if save_to_file:
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
output_path = f"image_{width}x{height}_{timestamp}.jpg"
try:
# Save image to file
image.save(output_path, format="JPEG", quality=85, optimize=True)
saved_file_path = output_path
print(f"Image saved to: {saved_file_path}")
# Read file for base64 encoding
with open(output_path, "rb") as f:
image_bytes = f.read()
except Exception as e:
print(f"Warning: Failed to save image to file {output_path}: {e}")
save_to_file = False
# If not saving or save failed, create in memory
if not save_to_file or image_bytes is None:
buffer = io.BytesIO()
image.save(buffer, format="JPEG", quality=85, optimize=True)
buffer.seek(0)
image_bytes = buffer.read()
# Generate base64
base64_image = base64.b64encode(image_bytes).decode("utf-8")
# Return result
result = {
"base64": base64_image,
}
if save_to_file and saved_file_path:
result["file_path"] = saved_file_path
return result
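# Illustrative usage sketch (not part of the original tests): combining the synthetic
# media generators above with dummy_messages_from_mix_data() to build an OpenAI-style
# chat message list. The data-URL prefixes and the example function itself are
# assumptions for demonstration only.
def _example_multimodal_messages() -> list[dict[str, Any]]:
    image = generate_synthetic_image(64, 64)
    audio = generate_synthetic_audio(duration=2, num_channels=1, sample_rate=16000)
    return dummy_messages_from_mix_data(
        system_prompt={"role": "system", "content": "You are a helpful assistant."},
        image_data_url=f"data:image/jpeg;base64,{image['base64']}",
        audio_data_url=f"data:audio/wav;base64,{audio['base64']}",
        content_text="Describe the image and the background sound.",
    )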
def preprocess_text(text):
import re
word_to_num = {
"zero": "0",
"one": "1",
"two": "2",
"three": "3",
"four": "4",
"five": "5",
"six": "6",
"seven": "7",
"eight": "8",
"nine": "9",
"ten": "10",
}
for word, num in word_to_num.items():
pattern = r"\b" + re.escape(word) + r"\b"
text = re.sub(pattern, num, text, flags=re.IGNORECASE)
text = re.sub(r"[^\w\s]", "", text)
text = re.sub(r"\s+", " ", text)
return text.lower().strip()
def cosine_similarity_text(text1, text2, n: int = 3):
from collections import Counter
if not text1 or not text2:
return 0.0
text1 = preprocess_text(text1)
text2 = preprocess_text(text2)
ngrams1 = [text1[i : i + n] for i in range(len(text1) - n + 1)]
ngrams2 = [text2[i : i + n] for i in range(len(text2) - n + 1)]
counter1 = Counter(ngrams1)
counter2 = Counter(ngrams2)
all_ngrams = set(counter1.keys()) | set(counter2.keys())
vec1 = [counter1.get(ng, 0) for ng in all_ngrams]
vec2 = [counter2.get(ng, 0) for ng in all_ngrams]
dot_product = sum(a * b for a, b in zip(vec1, vec2))
norm1 = sum(a * a for a in vec1) ** 0.5
norm2 = sum(b * b for b in vec2) ** 0.5
if norm1 == 0 or norm2 == 0:
return 0.0
return dot_product / (norm1 * norm2)
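# Worked illustration (assumption, not exercised by the original tests): the character
# n-gram cosine similarity above is meant for loose transcription comparison. After
# preprocess_text(), "There are two cats" and "there are 2 cats." both normalize to
# "there are 2 cats", so their trigram count vectors coincide and the similarity is ~1.0,
# while sentences sharing no trigrams score exactly 0.0.
def _example_text_similarity() -> None:
    assert cosine_similarity_text("There are two cats", "there are 2 cats.") > 0.99
    assert cosine_similarity_text("hello world", "completely different") == 0.0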
def convert_audio_to_text(audio_data):
"""
Convert base64 encoded audio data to text using speech recognition.
"""
import whisper
audio_data = base64.b64decode(audio_data)
output_path = f"./test_{int(time.time())}"
with open(output_path, "wb") as audio_file:
audio_file.write(audio_data)
print(f"audio data is saved: {output_path}")
model = whisper.load_model("base")
text = model.transcribe(
output_path,
temperature=0.0,
word_timestamps=True,
condition_on_previous_text=False,
)["text"]
if text:
return text
else:
return ""
def merge_base64_and_convert_to_text(base64_list):
"""
Merge a list of base64 encoded audio data and convert to text.
"""
import whisper
from pydub import AudioSegment
merged_audio = None
for base64_data in base64_list:
audio_data = base64.b64decode(base64_data.split(",", 1)[-1])
seg = AudioSegment.from_file(io.BytesIO(audio_data))
if merged_audio is None:
merged_audio = seg
else:
merged_audio += seg
output_path = f"./test_{int(time.time())}"
merged_audio.export(output_path, format="wav")
model = whisper.load_model("base")
text = model.transcribe(
output_path,
temperature=0.0,
word_timestamps=True,
condition_on_previous_text=False,
)["text"]
if text:
return text
else:
return ""
def modify_stage_config(
yaml_path: str,
updates: dict[str, Any],
deletes: dict[str, Any] | None = None,
) -> str:
"""
Modify configurations in a YAML file. Supports both top-level and stage-specific
changes, covering addition, modification, and deletion of configuration entries.
Args:
yaml_path: Path to the YAML configuration file.
updates: Dictionary containing both top-level and stage-specific modifications to add or update.
Format: {
'async_chunk': True,
'stage_args': {
0: {'engine_args.max_model_len': 5800},
1: {'runtime.max_batch_size': 2}
}
}
deletes: Dictionary containing configurations to delete.
Format: {
'old_config': None, # Delete entire key
'stage_args': {
0: ['engine_args.old_param'],
1: ['runtime.unused_setting']
}
}
Returns:
str: Path to the newly created modified YAML file with timestamp suffix.
"""
path = Path(yaml_path)
if not path.exists():
raise FileNotFoundError(f"YAML file does not exist: {path}")
try:
with open(yaml_path, encoding="utf-8") as f:
config = yaml.safe_load(f) or {}
except Exception as e:
raise ValueError(f"Cannot parse YAML file: {e}")
# Helper function to apply update
def apply_update(config_dict: dict, key_path: str, value: Any) -> None:
"""Apply update to dictionary using dot-separated path."""
# Handle direct list assignment (e.g., engine_input_source: [1, 2])
if "." not in key_path:
# Simple key, set directly
config_dict[key_path] = value
return
current = config_dict
keys = key_path.split(".")
for i in range(len(keys) - 1):
key = keys[i]
# Handle list indices
if key.isdigit() and isinstance(current, list):
index = int(key)
if index < 0:
raise ValueError(f"Negative list index not allowed: {index}")
if index >= len(current):
# Expand list if needed
while len(current) <= index:
# If we need to go deeper (more keys after this), create a dict
# Otherwise, create None placeholder
current.append({} if i < len(keys) - 2 else None)
current = current[index]
elif isinstance(current, dict):
# Handle dictionary keys
if key not in current:
# If there are more keys after this, create appropriate structure
if i < len(keys) - 1:
# Check if next key is a digit (list index) or string (dict key)
if keys[i + 1].isdigit():
current[key] = []
else:
current[key] = {}
else:
# This is the last key, create based on value type
current[key] = [] if isinstance(value, list) else {}
elif not isinstance(current[key], (dict, list)) and i < len(keys) - 1:
# If current value is not dict/list but we need to go deeper, replace it
if keys[i + 1].isdigit():
current[key] = []
else:
current[key] = {}
current = current[key]
else:
# Current is not a dict or list, cannot traverse further
raise TypeError(
f"Cannot access {'.'.join(keys[: i + 1])} as a dict/list. It's a {type(current).__name__}"
)
# Set the final value
last_key = keys[-1]
if isinstance(current, list) and last_key.isdigit():
# Setting a value in a list by index
index = int(last_key)
if index < 0:
raise ValueError(f"Negative list index not allowed: {index}")
if index >= len(current):
# Expand list if needed
while len(current) <= index:
current.append(None)
current[index] = value
elif isinstance(current, dict):
# Special case: if the value is a list and we're setting a top-level key
# Example: updating engine_input_source with [1, 2]
current[last_key] = value
else:
# Current is not a dict, cannot set key
raise TypeError(f"Cannot set value at {key_path}. Current type is {type(current).__name__}, expected dict.")
# Helper function to delete by path
def delete_by_path(config_dict: dict, path: str) -> None:
"""Delete configuration by dot-separated path."""
if not path:
return
current = config_dict
keys = path.split(".")
# Traverse to the parent
for i in range(len(keys) - 1):
key = keys[i]
# Handle list indices
if key.isdigit() and isinstance(current, list):
index = int(key)
if index < 0 or index >= len(current):
raise KeyError(f"List index {index} out of bounds")
current = current[index]
elif isinstance(current, dict):
if key not in current:
raise KeyError(f"Path {'.'.join(keys[: i + 1])} does not exist")
current = current[key]
else:
raise TypeError(
f"Cannot access {'.'.join(keys[: i + 1])} as a dict/list. It's a {type(current).__name__}"
)
# Delete the item
last_key = keys[-1]
if isinstance(current, list) and last_key.isdigit():
index = int(last_key)
if index < 0 or index >= len(current):
raise KeyError(f"List index {index} out of bounds")
del current[index]
elif isinstance(current, dict) and last_key in current:
del current[last_key]
else:
raise KeyError(f"Path {path} does not exist")
# Apply deletions first
if deletes:
for key, value in deletes.items():
if key == "stage_args":
if value and isinstance(value, dict):
stage_args = config.get("stage_args", [])
if not stage_args:
raise ValueError("stage_args does not exist in config")
for stage_id, delete_paths in value.items():
if not delete_paths:
continue
# Find stage by ID
target_stage = None
for stage in stage_args:
if stage.get("stage_id") == stage_id:
target_stage = stage
break
if target_stage is None:
available_ids = [s.get("stage_id") for s in stage_args if "stage_id" in s]
raise KeyError(f"Stage ID {stage_id} not found, available: {available_ids}")
# Delete specified paths in this stage
for path in delete_paths:
if path: # Skip empty paths
delete_by_path(target_stage, path)
elif "." in key:
# Delete using dot-separated path
delete_by_path(config, key)
elif value is None and key in config:
# Delete entire key
del config[key]
# Apply updates
for key, value in updates.items():
if key == "stage_args":
if value and isinstance(value, dict):
stage_args = config.get("stage_args", [])
if not stage_args:
raise ValueError("stage_args does not exist in config")
for stage_id, stage_updates in value.items():
# Find stage by ID
target_stage = None
for stage in stage_args:
if stage.get("stage_id") == stage_id:
target_stage = stage
break
if target_stage is None:
available_ids = [s.get("stage_id") for s in stage_args if "stage_id" in s]
raise KeyError(f"Stage ID {stage_id} not found, available: {available_ids}")
# Apply updates to this stage
for path, val in stage_updates.items():
# Check if this is a simple key (not dot-separated)
# Example: 'engine_input_source' vs 'engine_args.max_model_len'
if "." not in path:
# Direct key assignment (e.g., updating a list value)
target_stage[path] = val
else:
# Dot-separated path (e.g., nested dict access)
apply_update(target_stage, path, val)
elif "." in key:
# Apply using dot-separated path
apply_update(config, key, value)
else:
# Direct top-level key
config[key] = value
# Save to new file with timestamp
timestamp = int(time.time())
base_name = yaml_path.rsplit(".", 1)[0] if "." in yaml_path else yaml_path
output_path = f"{base_name}_{timestamp}.yaml"
with open(output_path, "w", encoding="utf-8") as f:
yaml.dump(config, f, default_flow_style=None, sort_keys=False, allow_unicode=True, indent=2)
return output_path
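# Illustrative usage sketch (assumption, mirroring the docstring examples above): the
# path and keys below are hypothetical placeholders, not values used by the original tests.
def _example_modify_stage_config() -> str:
    return modify_stage_config(
        "path/to/stage_config.yaml",  # hypothetical path
        updates={
            "async_chunk": True,
            "stage_args": {0: {"engine_args.max_model_len": 5800}},
        },
        deletes={
            "stage_args": {1: ["runtime.unused_setting"]},
        },
    )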
class OmniServer:
"""Omniserver for vLLM-Omni tests."""
def __init__(
self,
model: str,
serve_args: list[str],
*,
env_dict: dict[str, str] | None = None,
) -> None:
_run_pre_test_cleanup(enable_force=True)
_run_post_test_cleanup(enable_force=True)
cleanup_dist_env_and_memory()
self.model = model
self.serve_args = serve_args
self.env_dict = env_dict
self.proc: subprocess.Popen | None = None
self.host = "127.0.0.1"
self.port = get_open_port()
def _start_server(self) -> None:
"""Start the vLLM-Omni server subprocess."""
env = os.environ.copy()
env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
if self.env_dict is not None:
env.update(self.env_dict)
cmd = [
sys.executable,
"-m",
"vllm_omni.entrypoints.cli.main",
"serve",
self.model,
"--omni",
"--host",
self.host,
"--port",
str(self.port),
] + self.serve_args
print(f"Launching OmniServer with: {' '.join(cmd)}")
self.proc = subprocess.Popen(
cmd,
env=env,
cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))), # Set working directory to vllm-omni root
)
# Wait for server to be ready
max_wait = 1200 # 20 minutes
start_time = time.time()
while time.time() - start_time < max_wait:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.settimeout(1)
result = sock.connect_ex((self.host, self.port))
if result == 0:
print(f"Server ready on {self.host}:{self.port}")
return
except Exception:
pass
time.sleep(2)
raise RuntimeError(f"Server failed to start within {max_wait} seconds")
def _kill_process_tree(self, pid):
"""kill process and its children with verification"""
try:
parent = psutil.Process(pid)
children = parent.children(recursive=True)
# Get all PIDs first
all_pids = [pid] + [child.pid for child in children]
# Terminate children
for child in children:
try:
child.terminate()
except psutil.NoSuchProcess:
pass
# Wait for children
gone, still_alive = psutil.wait_procs(children, timeout=10)
# Kill remaining children
for child in still_alive:
try:
child.kill()
except psutil.NoSuchProcess:
pass
# Terminate parent
try:
parent.terminate()
parent.wait(timeout=10)
except (psutil.NoSuchProcess, psutil.TimeoutExpired):
try:
parent.kill()
except psutil.NoSuchProcess:
pass
# VERIFICATION: Check if all processes are gone
time.sleep(1) # Give system time
alive_processes = []
for check_pid in all_pids:
if psutil.pid_exists(check_pid):
alive_processes.append(check_pid)
if alive_processes:
print(f"Warning: Processes still alive: {alive_processes}")
# Optional: Try system kill
import subprocess
for alive_pid in alive_processes:
try:
subprocess.run(["kill", "-9", str(alive_pid)], timeout=2)
except Exception as e:
print(f"Cleanup failed: {e}")
except psutil.NoSuchProcess:
pass
def __enter__(self):
self._start_server()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.proc:
self._kill_process_tree(self.proc.pid)
_run_pre_test_cleanup(enable_force=True)
_run_post_test_cleanup(enable_force=True)
cleanup_dist_env_and_memory()
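# Illustrative sketch (assumption, not part of the original suite): issuing a single chat
# request against a running OmniServer. The `requests` dependency, the payload fields, and
# the response parsing assume an OpenAI-compatible /v1/chat/completions endpoint, as
# exercised by the benchmark test above.
def _example_chat_request(server: OmniServer) -> str:
    import requests

    resp = requests.post(
        f"http://{server.host}:{server.port}/v1/chat/completions",
        json={
            "model": server.model,
            "messages": dummy_messages_from_mix_data(content_text="Hello!"),
            "max_tokens": 16,
        },
        timeout=60,
    )
    resp.raise_for_status()
    return resp.json()["choices"][0]["message"]["content"]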
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import pickle
import tempfile
import pytest
import torch
from vllm_omni.diffusion.attention.layer import Attention
from vllm_omni.diffusion.data import (
DiffusionParallelConfig,
OmniDiffusionConfig,
)
from vllm_omni.diffusion.distributed.parallel_state import (
destroy_distributed_env,
init_distributed_environment,
initialize_model_parallel,
)
from vllm_omni.diffusion.forward_context import set_forward_context
from vllm_omni.platforms import current_omni_platform
def update_environment_variables(envs_dict: dict[str, str]):
"""Update multiple environment variables with logging."""
for k, v in envs_dict.items():
os.environ[k] = v
class TestAttentionModel(torch.nn.Module):
"""Test model using Attention layer."""
def __init__(
self,
num_heads: int,
head_size: int,
hidden_size: int,
causal: bool = False,
num_kv_heads: int | None = None,
scatter_idx: int = 2,
gather_idx: int = 1,
use_sync: bool = False,
):
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.hidden_size = hidden_size
self.attention = Attention(
num_heads=num_heads,
head_size=head_size,
causal=causal,
softmax_scale=1.0 / (head_size**0.5),
num_kv_heads=num_kv_heads,
scatter_idx=scatter_idx,
gather_idx=gather_idx,
use_sync=use_sync,
)
# Linear projection layers for Q, K, V
self.q_proj = torch.nn.Linear(hidden_size, num_heads * head_size)
self.k_proj = torch.nn.Linear(hidden_size, (num_kv_heads or num_heads) * head_size)
self.v_proj = torch.nn.Linear(hidden_size, (num_kv_heads or num_heads) * head_size)
self.o_proj = torch.nn.Linear(num_heads * head_size, hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""Forward pass through attention layer."""
batch_size, seq_len, _ = hidden_states.shape
# Project to Q, K, V
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
# Reshape to (batch_size, seq_len, num_heads, head_size)
q = q.view(batch_size, seq_len, self.num_heads, self.head_size)
k = k.view(batch_size, seq_len, k.shape[-1] // self.head_size, self.head_size)
v = v.view(batch_size, seq_len, v.shape[-1] // self.head_size, self.head_size)
# Apply attention
attn_output = self.attention(q, k, v)
# Reshape back and project
attn_output = attn_output.view(batch_size, seq_len, -1)
output = self.o_proj(attn_output)
return output
class TestMultiLayerAttentionModel(torch.nn.Module):
"""Test model with multiple attention layers."""
def __init__(
self,
num_layers: int,
num_heads: int,
head_size: int,
hidden_size: int,
causal: bool = True,
num_kv_heads: int | None = None,
scatter_idx: int = 2,
gather_idx: int = 1,
use_sync: bool = False,
):
super().__init__()
self.num_layers = num_layers
self.layers = torch.nn.ModuleList(
[
TestAttentionModel(
num_heads=num_heads,
head_size=head_size,
hidden_size=hidden_size,
causal=causal,
num_kv_heads=num_kv_heads,
scatter_idx=scatter_idx,
gather_idx=gather_idx,
use_sync=use_sync,
)
for _ in range(num_layers)
]
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""Forward pass through multiple attention layers."""
for layer in self.layers:
hidden_states = hidden_states + layer(hidden_states)
return hidden_states
@pytest.mark.parametrize(
"test_model_cls",
[
TestMultiLayerAttentionModel,
],
)
@pytest.mark.parametrize("ulysses_degree", [2])
@pytest.mark.parametrize("ring_degree", [2])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("num_heads", [8])
@pytest.mark.parametrize("head_size", [8])
@pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) # [torch.float16, torch.bfloat16]
@pytest.mark.parametrize("use_sync", [False])
@pytest.mark.parametrize("dynamic", [False])
@pytest.mark.parametrize("use_compile", [False])
@pytest.mark.parametrize("attn_backend", ["sdpa", "flash_attn"])
def test_sequence_parallel(
ulysses_degree: int,
ring_degree: int,
test_model_cls: type[torch.nn.Module],
dtype: torch.dtype,
causal: bool,
use_sync: bool,
dynamic: bool,
use_compile: bool,
batch_size: int,
seq_len: int,
num_heads: int,
head_size: int,
attn_backend: str,
):
"""Test Ulysses attention by comparing with and without SP enabled."""
sequence_parallel_size = ulysses_degree * ring_degree
# Skip if not enough GPUs available
available_gpus = current_omni_platform.get_device_count()
if available_gpus < sequence_parallel_size:
pytest.skip(f"Test requires {sequence_parallel_size} GPUs but only {available_gpus} available")
# Create temporary files to share results between processes
with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as f:
baseline_output_file = f.name
with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as f:
sp_output_file = f.name
with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as f:
model_state_file = f.name
with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as f:
input_data_file = f.name
try:
# Step 1: Run without SP (baseline with ulysses_degree=1, ring_degree=1)
print("\n[Baseline] Running without SP (ulysses_degree=1, ring_degree=1)...")
torch.multiprocessing.spawn(
ulysses_attention_on_test_model,
args=(
1, # num_processes = 1 for baseline
test_model_cls,
batch_size,
seq_len,
num_heads,
head_size,
dtype,
causal,
use_sync,
dynamic,
use_compile,
1, # ulysses_degree = 1
1, # ring_degree = 1
1, # sequence_parallel_size = 1
baseline_output_file,
model_state_file,
input_data_file,
True, # is_baseline
attn_backend,
),
nprocs=1,
)
# Step 2: Run with SP enabled
print(f"\n[SP Test] Running with SP (ulysses_degree={ulysses_degree}, ring_degree={ring_degree})...")
torch.multiprocessing.spawn(
ulysses_attention_on_test_model,
args=(
sequence_parallel_size, # num_processes
test_model_cls,
batch_size,
seq_len,
num_heads,
head_size,
dtype,
causal,
use_sync,
dynamic,
use_compile,
ulysses_degree,
ring_degree,
sequence_parallel_size,
sp_output_file,
model_state_file,
input_data_file,
False, # is_baseline
attn_backend,
),
nprocs=sequence_parallel_size,
)
# Step 3: Verify input consistency and compare outputs
print(f"\n{'=' * 80}")
print("Verifying input data consistency...")
with open(input_data_file, "rb") as f:
input_data = pickle.load(f)
input_checksum = hash(input_data.tobytes())
print(f" Input data shape: {input_data.shape}")
print(f" Input data checksum: {input_checksum}")
print(" ✓ Both baseline and SP used the same input data")
print(f"\n{'=' * 80}")
print("Comparing outputs between baseline and SP...")
with open(baseline_output_file, "rb") as f:
baseline_output = pickle.load(f)
with open(sp_output_file, "rb") as f:
sp_output = pickle.load(f)
# Convert to tensors for comparison
baseline_tensor = torch.tensor(baseline_output)
sp_tensor = torch.tensor(sp_output)
print(f" Baseline output shape: {baseline_tensor.shape}")
print(f" SP output shape: {sp_tensor.shape}")
assert baseline_tensor.shape == sp_tensor.shape, "Output shapes must match!"
# Calculate differences
abs_diff = torch.abs(baseline_tensor - sp_tensor)
max_abs_diff = abs_diff.max().item()
mean_abs_diff = abs_diff.mean().item()
# Calculate relative difference (avoid division by zero)
baseline_abs = torch.abs(baseline_tensor)
relative_diff = abs_diff / (baseline_abs + 1e-8)
max_relative_diff = relative_diff.max().item()
mean_relative_diff = relative_diff.mean().item()
print(f"\n{'=' * 80}")
print("Output Difference Analysis:")
print(f" - Max absolute difference: {max_abs_diff:.6e}")
print(f" - Mean absolute difference: {mean_abs_diff:.6e}")
print(f" - Max relative difference: {max_relative_diff:.6e}")
print(f" - Mean relative difference: {mean_relative_diff:.6e}")
print(f" - Baseline output range: [{baseline_tensor.min().item():.6e}, {baseline_tensor.max().item():.6e}]")
print(f" - SP output range: [{sp_tensor.min().item():.6e}, {sp_tensor.max().item():.6e}]")
print(f"{'=' * 80}\n")
# Assert that differences are within acceptable tolerance
# For FP16/BF16, we expect some numerical differences due to different computation order under parallelism.
# If we use the same backend (e.g. Flash Attention) for both baseline and SP, differences should be smaller.
if dtype == torch.float16:
atol, rtol = 5e-2, 5e-2 # Increased tolerance for Ring Attention
elif dtype == torch.bfloat16:
atol, rtol = 5e-2, 5e-2 # Increased tolerance for Ring Attention
else:
atol, rtol = 1e-5, 1e-4
assert max_abs_diff < atol or max_relative_diff < rtol, (
f"Output difference too large: max_abs_diff={max_abs_diff:.6e}, "
f"max_relative_diff={max_relative_diff:.6e}, "
f"tolerance: atol={atol}, rtol={rtol}"
)
print("✓ Test passed: SP output matches baseline within tolerance")
finally:
# Clean up temporary files
for f in [baseline_output_file, sp_output_file, model_state_file, input_data_file]:
if os.path.exists(f):
os.remove(f)
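# Worked illustration (assumption): with the default parametrization above,
# sequence_parallel_size = ulysses_degree * ring_degree = 2 * 2 = 4, so the seq_len=16
# input is split into 16 // 4 = 4 tokens per rank before the forward pass, and the
# per-rank outputs are gathered back along the sequence dimension afterwards.
def _example_sp_chunk_size(seq_len: int = 16, ulysses_degree: int = 2, ring_degree: int = 2) -> int:
    sequence_parallel_size = ulysses_degree * ring_degree
    return seq_len // sequence_parallel_size  # -> 4 tokens per rank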
def ulysses_attention_on_test_model(
local_rank: int,
world_size: int,
test_model_cls: type[torch.nn.Module],
batch_size: int,
seq_len: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
causal: bool,
use_sync: bool,
dynamic: bool,
use_compile: bool,
ulysses_degree: int,
ring_degree: int,
sequence_parallel_size: int,
output_file: str,
model_state_file: str,
input_data_file: str,
is_baseline: bool,
attn_backend: str,
):
"""Run Ulysses attention test on a test model and save results for comparison."""
# Use fixed seed for reproducibility across baseline and SP runs
RANDOM_SEED = 42
current_omni_platform.seed_everything(RANDOM_SEED)
mode_str = "Baseline (no SP)" if is_baseline else f"SP (ulysses={ulysses_degree}, ring={ring_degree})"
print(f"\n[{mode_str}] Rank {local_rank}/{world_size} - Random seed set to {RANDOM_SEED}")
device = torch.device(f"{current_omni_platform.device_type}:{local_rank}")
current_omni_platform.set_device(device)
torch.set_default_device(device)
torch.set_default_dtype(dtype)
update_environment_variables(
{
"RANK": str(local_rank),
"LOCAL_RANK": str(local_rank),
"WORLD_SIZE": str(world_size),
"MASTER_ADDR": "localhost",
"MASTER_PORT": "12345",
}
)
# Initialize distributed environment
init_distributed_environment()
# Set up OmniDiffusionConfig with parallel config
parallel_config = DiffusionParallelConfig(
pipeline_parallel_size=1,
data_parallel_size=1,
tensor_parallel_size=1,
sequence_parallel_size=sequence_parallel_size,
ulysses_degree=ulysses_degree,
ring_degree=ring_degree,
cfg_parallel_size=1,
)
od_config = OmniDiffusionConfig(
model="test_model",
dtype=dtype,
parallel_config=parallel_config,
attention_backend=attn_backend, # Set the attention backend here
)
# Initialize model parallel
initialize_model_parallel(
data_parallel_size=1,
cfg_parallel_size=1,
sequence_parallel_size=sequence_parallel_size,
ulysses_degree=ulysses_degree,
ring_degree=ring_degree,
tensor_parallel_size=1,
pipeline_parallel_size=1,
)
# Set the config so Attention can access it
with set_forward_context(omni_diffusion_config=od_config):
# Create model
hidden_size = num_heads * head_size
# Create model with appropriate parameters
model_kwargs = {
"num_heads": num_heads,
"head_size": head_size,
"hidden_size": hidden_size,
"causal": causal,
"num_kv_heads": None,
"scatter_idx": 2,
"gather_idx": 1,
"use_sync": use_sync,
}
if test_model_cls == TestMultiLayerAttentionModel:
model_kwargs["num_layers"] = 2
model = test_model_cls(**model_kwargs)
model = model.to(device).to(dtype)
# For baseline: Generate and save model state and input data
# This ensures both baseline and SP use exactly the same initialization
if is_baseline and local_rank == 0:
# Save model state for reuse (before any computation)
model_state = {k: v.cpu() for k, v in model.state_dict().items()}
with open(model_state_file, "wb") as f:
pickle.dump(model_state, f)
# Generate and save full input data with fixed seed
# Reinitialize RNG to ensure reproducibility
torch.manual_seed(42)
current_omni_platform.seed_everything(42)
full_hidden_states = torch.randn(
(batch_size, seq_len, hidden_size),
dtype=dtype,
device="cpu",
)
with open(input_data_file, "wb") as f:
pickle.dump(full_hidden_states.detach().cpu().float().numpy(), f)
print("[Baseline] Saved model state and input data")
# Synchronize to ensure baseline has saved data before SP loads it
if world_size > 1:
torch.distributed.barrier()
# IMPORTANT: Both baseline and SP load the same model state and input data
# This ensures exact same initialization and input for fair comparison
with open(model_state_file, "rb") as f:
model_state = pickle.load(f)
model.load_state_dict({k: v.to(device).to(dtype) for k, v in model_state.items()})
with open(input_data_file, "rb") as f:
full_hidden_states_np = pickle.load(f)
full_hidden_states = torch.from_numpy(full_hidden_states_np).to(device).to(dtype)
print(f"[Rank {local_rank}] Loaded model state and full input data with shape {full_hidden_states.shape}")
# Split input sequence according to sequence parallel BEFORE model forward
# Each rank gets a contiguous chunk of the sequence dimension
local_seq_len = seq_len // sequence_parallel_size
start_idx = local_rank * local_seq_len
end_idx = start_idx + local_seq_len
hidden_states = full_hidden_states[:, start_idx:end_idx, :].contiguous()
print(
f"[Rank {local_rank}] Split input: local_seq_len={local_seq_len}, "
f"indices=[{start_idx}:{end_idx}], local_shape={hidden_states.shape}"
)
if dynamic:
torch._dynamo.mark_dynamic(hidden_states, 0)
torch._dynamo.mark_dynamic(hidden_states, 1)
# Compile model if requested
if use_compile:
model = torch.compile(model)
# Run forward pass with local sequence chunk
print(f"[Rank {local_rank}] Running forward pass...")
output = model(hidden_states)
print(f"[Rank {local_rank}] Forward pass completed, output shape: {output.shape}")
# Verify output shape
assert output.shape == (batch_size, local_seq_len, hidden_size), (
f"Output shape mismatch: expected {(batch_size, local_seq_len, hidden_size)}, got {output.shape}"
)
# Gather outputs from all ranks AFTER computation
if world_size > 1:
print(f"[Rank {local_rank}] Gathering outputs from all {world_size} ranks...")
# Gather all outputs to rank 0
gathered_outputs = [torch.zeros_like(output) for _ in range(world_size)]
torch.distributed.all_gather(gathered_outputs, output)
if local_rank == 0:
# Concatenate along sequence dimension to reconstruct full sequence
full_output = torch.cat(gathered_outputs, dim=1)
print(f"[Rank 0] Gathered and concatenated outputs: {full_output.shape}")
# Verify the full output shape matches expected
assert full_output.shape == (batch_size, seq_len, hidden_size), (
f"Gathered output shape mismatch: expected {(batch_size, seq_len, hidden_size)}, "
f"got {full_output.shape}"
)
else:
full_output = None
else:
# For baseline (world_size=1), output is already complete
full_output = output
print(f"[Rank 0] No gather needed (world_size=1), output shape: {full_output.shape}")
# Save output from rank 0 for comparison
if local_rank == 0:
output_np = full_output.detach().cpu().float().numpy()
with open(output_file, "wb") as f:
pickle.dump(output_np, f)
mode_str = "baseline (no SP)" if is_baseline else f"SP (ulysses={ulysses_degree}, ring={ring_degree})"
print(
f"\n[{mode_str}] ✓ Saved output with shape {full_output.shape}:\n"
f" - batch_size={batch_size}, seq_len={seq_len}\n"
f" - num_heads={num_heads}, head_size={head_size}\n"
f" - dtype={dtype}, causal={causal}, use_sync={use_sync}\n"
)
destroy_distributed_env()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test script for FlashAttention backend with padding handling.
This script tests two main scenarios:
1. Case 1: Comparing padded vs unpadded inputs for batch_size=1
2. Case 2: Comparing FlashAttention and SDPA backends for batch_size=2 with padding
"""
import pytest
import torch
from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
from vllm_omni.diffusion.attention.backends.flash_attn import FlashAttentionImpl
from vllm_omni.diffusion.attention.backends.sdpa import SDPAImpl
def create_attention_mask(batch_size: int, seq_len: int, valid_len: int, device: torch.device) -> torch.Tensor:
"""
Create attention mask where first valid_len tokens are valid (1) and rest are padding (0).
Args:
batch_size: Batch size
seq_len: Total sequence length (including padding)
valid_len: Number of valid (non-padded) tokens
device: Device on which to allocate the mask
Returns:
Attention mask of shape (batch_size, seq_len)
"""
mask = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=device)
mask[:, :valid_len] = True
return mask
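# Worked illustration (assumption, not part of the original tests): for batch_size=1,
# seq_len=5, valid_len=3 the helper above returns tensor([[True, True, True, False, False]]),
# i.e. the last two positions are treated as padding by the attention backends under test.
def _example_padding_mask() -> torch.Tensor:
    return create_attention_mask(1, 5, 3, torch.device("cpu"))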
def pad_tensor(tensor: torch.Tensor, target_seq_len: int, pad_value: float = 0.0) -> torch.Tensor:
"""
Pad tensor along sequence dimension (dim=1).
Args:
tensor: Input tensor of shape (batch_size, seq_len, num_heads, head_dim)
target_seq_len: Target sequence length after padding
pad_value: Value to use for padding
Returns:
Padded tensor of shape (batch_size, target_seq_len, num_heads, head_dim)
"""
batch_size, seq_len, num_heads, head_dim = tensor.shape
if target_seq_len <= seq_len:
return tensor
padding = torch.full(
(batch_size, target_seq_len - seq_len, num_heads, head_dim), pad_value, dtype=tensor.dtype, device=tensor.device
)
return torch.cat([tensor, padding], dim=1)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="FlashAttention requires CUDA")
def test_padding_equivalence():
"""
Case 1: Test that padded and unpadded inputs produce similar outputs.
- Input A: batch_size=1, hidden_states (1, 48), encoder_hidden_states (1, 16)
Concatenated length: 64, NO attention_mask
- Input B: Same data but padded: hidden_states (1, 58), encoder_hidden_states (1, 26)
Concatenated length: 84, WITH attention_mask
Expected: Output A and Output B should be very close.
"""
device = torch.device("cuda")
dtype = torch.bfloat16
# Configuration
batch_size = 1
hidden_seq_len = 48
encoder_seq_len = 16
pad_length = 10
num_heads = 8
head_dim = 64
# Initialize FlashAttention
fa_impl = FlashAttentionImpl(
num_heads=num_heads, head_size=head_dim, softmax_scale=1.0 / (head_dim**0.5), causal=False
)
# Create base tensors with random values (same for both A and B)
torch.manual_seed(42)
hidden_states_base = torch.randn(batch_size, hidden_seq_len, num_heads, head_dim, device=device, dtype=dtype)
encoder_hidden_states_base = torch.randn(
batch_size, encoder_seq_len, num_heads, head_dim, device=device, dtype=dtype
)
# ========== Input A: Unpadded, no attention mask ==========
query_a = torch.cat([hidden_states_base, encoder_hidden_states_base], dim=1)
key_a = query_a.clone()
value_a = query_a.clone()
attn_metadata_a = AttentionMetadata(attn_mask=None)
output_a = fa_impl.forward(query=query_a, key=key_a, value=value_a, attn_metadata=attn_metadata_a)
# ========== Input B: Padded with attention mask ==========
hidden_states_padded = pad_tensor(hidden_states_base, hidden_seq_len + pad_length)
encoder_hidden_states_padded = pad_tensor(encoder_hidden_states_base, encoder_seq_len + pad_length)
query_b = torch.cat([hidden_states_padded, encoder_hidden_states_padded], dim=1)
key_b = query_b.clone()
value_b = query_b.clone()
# Create attention mask
attn_mask_b = torch.cat(
[
create_attention_mask(batch_size, hidden_seq_len + pad_length, hidden_seq_len, device),
create_attention_mask(batch_size, encoder_seq_len + pad_length, encoder_seq_len, device),
],
dim=1,
)
attn_metadata_b = AttentionMetadata(attn_mask=attn_mask_b)
output_b = fa_impl.forward(query=query_b, key=key_b, value=value_b, attn_metadata=attn_metadata_b)
# Extract non-padded portion from output_b
output_b_unpadded = torch.cat(
[
output_b[:, :hidden_seq_len, :, :],
output_b[:, hidden_seq_len + pad_length : hidden_seq_len + pad_length + encoder_seq_len, :, :],
],
dim=1,
)
# Compare outputs
max_diff = torch.max(torch.abs(output_a - output_b_unpadded)).item()
mean_diff = torch.mean(torch.abs(output_a - output_b_unpadded)).item()
print("\n=== Case 1: Padding Equivalence Test ===")
print(f"Output A shape: {output_a.shape}")
print(f"Output B shape: {output_b.shape}")
print(f"Output B unpadded shape: {output_b_unpadded.shape}")
print(f"Max absolute difference: {max_diff:.6f}")
print(f"Mean absolute difference: {mean_diff:.6f}")
# Assert that outputs are close
# Using higher tolerance for bfloat16
assert max_diff < 0.1, f"Max difference {max_diff} exceeds threshold 0.1"
assert mean_diff < 0.01, f"Mean difference {mean_diff} exceeds threshold 0.01"
print("✓ Case 1 PASSED: Padded and unpadded outputs are very close!")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="FlashAttention requires CUDA")
def test_fa_vs_sdpa():
"""
Case 2: Compare FlashAttention and SDPA backends with padding.
- batch_size=2
- hidden_states: (2, 48) padded to (2, 58)
- encoder_hidden_states: (2, 16) padded to (2, 26)
- Concatenated length: 84
- Compare FA and SDPA outputs
Expected: FA and SDPA outputs should be very close.
"""
device = torch.device("cuda")
dtype = torch.bfloat16
# Configuration
batch_size = 2
hidden_seq_len = 48
encoder_seq_len = 16
pad_length = 10
num_heads = 8
head_dim = 64
# Initialize both backends
fa_impl = FlashAttentionImpl(
num_heads=num_heads, head_size=head_dim, softmax_scale=1.0 / (head_dim**0.5), causal=False
)
sdpa_impl = SDPAImpl(num_heads=num_heads, head_size=head_dim, softmax_scale=1.0 / (head_dim**0.5), causal=False)
# Create base tensors
torch.manual_seed(123)
hidden_states_base = torch.randn(batch_size, hidden_seq_len, num_heads, head_dim, device=device, dtype=dtype)
encoder_hidden_states_base = torch.randn(
batch_size, encoder_seq_len, num_heads, head_dim, device=device, dtype=dtype
)
# Pad tensors
hidden_states_padded = pad_tensor(hidden_states_base, hidden_seq_len + pad_length)
encoder_hidden_states_padded = pad_tensor(encoder_hidden_states_base, encoder_seq_len + pad_length)
# Concatenate
query = torch.cat([hidden_states_padded, encoder_hidden_states_padded], dim=1)
key = query.clone()
value = query.clone()
# Create attention mask
attn_mask = torch.cat(
[
create_attention_mask(batch_size, hidden_seq_len + pad_length, hidden_seq_len, device),
create_attention_mask(batch_size, encoder_seq_len + pad_length, encoder_seq_len, device),
],
dim=1,
)
attn_metadata = AttentionMetadata(attn_mask=attn_mask)
# Run FlashAttention
output_fa = fa_impl.forward(query=query.clone(), key=key.clone(), value=value.clone(), attn_metadata=attn_metadata)
# Run SDPA
# The SDPA backend expects an additive float mask that broadcasts against the attention scores.
# Expand the 2D boolean key-padding mask (batch_size, seq_len) to 4D and convert it to floats.
if attn_mask is not None:
# Expand mask for SDPA: (batch_size, seq_len) -> (batch_size, 1, 1, seq_len)
attn_mask_4d = attn_mask.unsqueeze(1).unsqueeze(2)
# Convert bool to float: True -> 0.0, False -> -inf
attn_mask_float = torch.zeros_like(attn_mask_4d, dtype=dtype)
attn_mask_float.masked_fill_(~attn_mask_4d, float("-inf"))
attn_metadata_sdpa = AttentionMetadata(attn_mask=attn_mask_float)
else:
attn_metadata_sdpa = AttentionMetadata(attn_mask=None)
output_sdpa = sdpa_impl.forward(
query=query.clone(), key=key.clone(), value=value.clone(), attn_metadata=attn_metadata_sdpa
)
# Compare outputs (only compare valid regions)
output_fa_valid = torch.cat(
[
output_fa[:, :hidden_seq_len, :, :],
output_fa[:, hidden_seq_len + pad_length : hidden_seq_len + pad_length + encoder_seq_len, :, :],
],
dim=1,
)
output_sdpa_valid = torch.cat(
[
output_sdpa[:, :hidden_seq_len, :, :],
output_sdpa[:, hidden_seq_len + pad_length : hidden_seq_len + pad_length + encoder_seq_len, :, :],
],
dim=1,
)
max_diff = torch.max(torch.abs(output_fa_valid - output_sdpa_valid)).item()
mean_diff = torch.mean(torch.abs(output_fa_valid - output_sdpa_valid)).item()
print("\n=== Case 2: FA vs SDPA Comparison ===")
print(f"Batch size: {batch_size}")
print(f"FA output shape: {output_fa.shape}")
print(f"SDPA output shape: {output_sdpa.shape}")
print(f"Max absolute difference (valid region): {max_diff:.6f}")
print(f"Mean absolute difference (valid region): {mean_diff:.6f}")
# Assert that outputs are close
# Using higher tolerance for bfloat16 and different implementations
assert max_diff < 0.01, f"Max difference {max_diff} exceeds threshold 0.01"
assert mean_diff < 0.001, f"Mean difference {mean_diff} exceeds threshold 0.001"
print("✓ Case 2 PASSED: FA and SDPA outputs are very close!")
if __name__ == "__main__":
print("Running FlashAttention Padding Tests...")
print("=" * 60)
# Try to run CUDA tests
if torch.cuda.is_available():
try:
print("\n[Running Case 1: Padding Equivalence for FA]")
test_padding_equivalence()
except Exception as e:
print(f"✗ Case 1 failed: {e}")
import traceback
traceback.print_exc()
try:
print("\n[Running Case 2: FA vs SDPA]")
test_fa_vs_sdpa()
except Exception as e:
print(f"✗ Case 2 failed: {e}")
import traceback
traceback.print_exc()
else:
raise RuntimeError("CUDA is not available")
print("\n" + "=" * 60)
print("Test suite completed!")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for cache backends (cache-dit and teacache).
"""
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for cache backends (cache-dit and teacache).
This module tests the cache backend implementations:
- CacheDiTBackend: cache-dit acceleration backend
- TeaCacheBackend: TeaCache hook-based backend
- Cache selector function: get_cache_backend
- DiffusionCacheConfig: configuration dataclass
"""
from unittest.mock import Mock, patch
import pytest
from vllm_omni.diffusion.cache.cache_dit_backend import (
CacheDiTBackend,
)
from vllm_omni.diffusion.cache.selector import get_cache_backend
from vllm_omni.diffusion.cache.teacache.backend import TeaCacheBackend
from vllm_omni.diffusion.data import DiffusionCacheConfig
class TestCacheDiTBackend:
"""Test CacheDiTBackend implementation."""
def test_init_with_dict(self):
"""Test initialization with dictionary config."""
config_dict = {"Fn_compute_blocks": 4, "max_warmup_steps": 8}
backend = CacheDiTBackend(config_dict)
assert backend.config.Fn_compute_blocks == 4
assert backend.config.max_warmup_steps == 8
assert backend.enabled is False
def test_init_with_config_object(self):
"""Test initialization with DiffusionCacheConfig object."""
config = DiffusionCacheConfig(Fn_compute_blocks=4)
backend = CacheDiTBackend(config)
assert backend.config.Fn_compute_blocks == 4
assert backend.enabled is False
@patch("vllm_omni.diffusion.cache.cache_dit_backend.cache_dit")
def test_enable_single_transformer(self, mock_cache_dit):
"""Test enabling cache-dit on single-transformer pipeline."""
# Mock pipeline
mock_pipeline = Mock()
mock_pipeline.__class__.__name__ = "DiTPipeline"
mock_transformer = Mock()
mock_pipeline.transformer = mock_transformer
# Mock cache_dit functions
mock_cache_dit.enable_cache = Mock()
mock_cache_dit.refresh_context = Mock()
backend = CacheDiTBackend({"Fn_compute_blocks": 2})
backend.enable(mock_pipeline)
# Verify cache-dit was enabled
assert backend.enabled is True
assert backend._refresh_func is not None
mock_cache_dit.enable_cache.assert_called_once()
@patch("vllm_omni.diffusion.cache.cache_dit_backend.cache_dit")
def test_refresh(self, mock_cache_dit):
"""Test refreshing cache context with SCM mask policy updates when num_inference_steps changes."""
# Mock pipeline
mock_pipeline = Mock()
mock_pipeline.__class__.__name__ = "DiTPipeline"
mock_transformer = Mock()
mock_pipeline.transformer = mock_transformer
# Mock cache_dit functions
mock_cache_dit.enable_cache = Mock()
mock_cache_dit.refresh_context = Mock()
mock_steps_mask_50 = [1, 0, 1, 0, 1] * 10 # Mock mask for 50 steps
mock_steps_mask_100 = [1, 0, 1, 0, 1] * 20 # Mock mask for 100 steps
mock_cache_dit.steps_mask = Mock(side_effect=[mock_steps_mask_50, mock_steps_mask_100])
# Enable cache-dit with SCM enabled (using mask policy)
config = DiffusionCacheConfig(
scm_steps_mask_policy="fast",
scm_steps_policy="dynamic",
)
backend = CacheDiTBackend(config)
backend.enable(mock_pipeline)
# First refresh with 50 steps
backend.refresh(mock_pipeline, num_inference_steps=50)
assert backend._last_num_inference_steps == 50
# Verify steps_mask was called with mask policy (not direct steps mask)
mock_cache_dit.steps_mask.assert_called_with(mask_policy="fast", total_steps=50)
assert mock_cache_dit.steps_mask.call_count == 1
# Verify refresh_context was called with cache_config (SCM path)
mock_cache_dit.refresh_context.assert_called_once()
call_args = mock_cache_dit.refresh_context.call_args
assert call_args[0][0] == mock_transformer
# Check that cache_config was passed (not num_inference_steps directly when SCM is enabled)
assert "cache_config" in call_args[1]
cache_config_arg = call_args[1]["cache_config"]
assert cache_config_arg is not None
# Change num_inference_steps and refresh again
mock_cache_dit.refresh_context.reset_mock()
backend.refresh(mock_pipeline, num_inference_steps=100)
# Verify steps_mask was called again with new num_inference_steps (using mask policy)
assert mock_cache_dit.steps_mask.call_count == 2
# Check the last call was with 100 steps and mask policy
assert mock_cache_dit.steps_mask.call_args_list[-1].kwargs["total_steps"] == 100
assert mock_cache_dit.steps_mask.call_args_list[-1].kwargs["mask_policy"] == "fast"
# Verify refresh_context was called again with updated mask
mock_cache_dit.refresh_context.assert_called_once()
call_args = mock_cache_dit.refresh_context.call_args
assert call_args[0][0] == mock_transformer
assert "cache_config" in call_args[1]
assert backend._last_num_inference_steps == 100
class TestTeaCacheBackend:
"""Test TeaCacheBackend implementation."""
def test_init(self):
"""Test initialization."""
config = DiffusionCacheConfig(rel_l1_thresh=0.3)
backend = TeaCacheBackend(config)
assert backend.config.rel_l1_thresh == 0.3
assert backend.enabled is False
@patch("vllm_omni.diffusion.cache.teacache.backend.apply_teacache_hook")
def test_enable(self, mock_apply_hook):
"""Test enabling TeaCache on pipeline."""
# Mock pipeline
mock_pipeline = Mock()
mock_pipeline.__class__.__name__ = "QwenImagePipeline"
mock_transformer = Mock()
mock_transformer.__class__.__name__ = "QwenImageTransformer2DModel"
mock_pipeline.transformer = mock_transformer
config = DiffusionCacheConfig(rel_l1_thresh=0.3)
backend = TeaCacheBackend(config)
backend.enable(mock_pipeline)
# Verify hook was applied
assert backend.enabled is True
mock_apply_hook.assert_called_once()
@patch("vllm_omni.diffusion.cache.teacache.backend.apply_teacache_hook")
def test_enable_with_coefficients(self, mock_apply_hook):
"""Test enabling TeaCache with custom coefficients."""
mock_pipeline = Mock()
mock_pipeline.__class__.__name__ = "QwenImagePipeline"
mock_transformer = Mock()
mock_transformer.__class__.__name__ = "QwenImageTransformer2DModel"
mock_pipeline.transformer = mock_transformer
config = DiffusionCacheConfig(rel_l1_thresh=0.3, coefficients=[1.0, 0.5, 0.2, 0.1, 0.05])
backend = TeaCacheBackend(config)
backend.enable(mock_pipeline)
assert backend.enabled is True
mock_apply_hook.assert_called_once()
@patch("vllm_omni.diffusion.cache.teacache.backend.apply_teacache_hook")
def test_refresh(self, mock_apply_hook):
"""Test refreshing TeaCache state."""
mock_pipeline = Mock()
mock_pipeline.__class__.__name__ = "QwenImagePipeline"
mock_transformer = Mock()
mock_transformer.__class__.__name__ = "QwenImageTransformer2DModel"
mock_pipeline.transformer = mock_transformer
# Mock hook registry
mock_hook = Mock()
mock_registry = Mock()
mock_registry.get_hook = Mock(return_value=mock_hook)
mock_registry.reset_hook = Mock()
mock_transformer._hook_registry = mock_registry
config = DiffusionCacheConfig()
backend = TeaCacheBackend(config)
backend.enable(mock_pipeline)
# Test refresh
backend.refresh(mock_pipeline, num_inference_steps=50)
mock_registry.reset_hook.assert_called_once()
class TestCacheSelector:
"""Test cache backend selector function."""
def test_get_cache_backend_none(self):
"""Test getting None backend."""
backend = get_cache_backend(None, None)
assert backend is None
backend = get_cache_backend("none", None)
assert backend is None
def test_get_cache_backend_cache_dit(self):
"""Test getting cache-dit backend."""
config_dict = {"Fn_compute_blocks": 4}
backend = get_cache_backend("cache_dit", config_dict)
assert isinstance(backend, CacheDiTBackend)
assert backend.config.Fn_compute_blocks == 4
def test_get_cache_backend_tea_cache(self):
"""Test getting teacache backend."""
config_dict = {"rel_l1_thresh": 0.3}
backend = get_cache_backend("tea_cache", config_dict)
assert isinstance(backend, TeaCacheBackend)
assert backend.config.rel_l1_thresh == 0.3
def test_get_cache_backend_invalid(self):
"""Test getting invalid backend raises error."""
with pytest.raises(ValueError, match="Unsupported cache backend"):
get_cache_backend("invalid_backend", {})
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for CFG (Classifier-Free Guidance) parallel functionality.
This test verifies that predict_noise_maybe_with_cfg produces numerically
equivalent results with and without CFG parallel using fixed random inputs.
"""
import os
import pytest
import torch
from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
from vllm_omni.diffusion.distributed.parallel_state import (
destroy_distributed_env,
get_classifier_free_guidance_rank,
get_classifier_free_guidance_world_size,
init_distributed_environment,
initialize_model_parallel,
)
from vllm_omni.platforms import current_omni_platform
def update_environment_variables(envs_dict: dict[str, str]):
"""Update multiple environment variables."""
for k, v in envs_dict.items():
os.environ[k] = v
class SimpleTransformer(torch.nn.Module):
"""Simple transformer model for testing with random initialization.
Contains:
- Input projection (conv to hidden_dim)
- QKV projection layers
- Self-attention layer
- Output projection
"""
def __init__(self, in_channels: int = 4, hidden_dim: int = 128, num_heads: int = 8):
super().__init__()
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.head_dim = hidden_dim // num_heads
assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
# Input projection: (B, C, H, W) -> (B, hidden_dim, H, W)
self.input_proj = torch.nn.Conv2d(in_channels, hidden_dim, 1)
# QKV projection layers
self.q_proj = torch.nn.Linear(hidden_dim, hidden_dim)
self.k_proj = torch.nn.Linear(hidden_dim, hidden_dim)
self.v_proj = torch.nn.Linear(hidden_dim, hidden_dim)
# Output projection after attention
self.out_proj = torch.nn.Linear(hidden_dim, hidden_dim)
# Final output projection: (B, hidden_dim, H, W) -> (B, C, H, W)
self.final_proj = torch.nn.Conv2d(hidden_dim, in_channels, 1)
# Layer norm
self.norm1 = torch.nn.LayerNorm(hidden_dim)
self.norm2 = torch.nn.LayerNorm(hidden_dim)
def forward(self, x: torch.Tensor, **kwargs) -> tuple[torch.Tensor]:
"""Forward pass with self-attention.
Args:
x: Input tensor of shape (B, C, H, W)
        Returns:
            A one-element tuple containing the output tensor of shape (B, C, H, W)
"""
B, C, H, W = x.shape
# Input projection
x = self.input_proj(x) # (B, hidden_dim, H, W)
# Reshape to sequence: (B, hidden_dim, H, W) -> (B, H*W, hidden_dim)
x = x.flatten(2).transpose(1, 2) # (B, H*W, hidden_dim)
# Self-attention with residual connection
residual = x
x = self.norm1(x)
# QKV projection
q = self.q_proj(x) # (B, H*W, hidden_dim)
k = self.k_proj(x) # (B, H*W, hidden_dim)
v = self.v_proj(x) # (B, H*W, hidden_dim)
# Reshape for multi-head attention: (B, H*W, hidden_dim) -> (B, num_heads, H*W, head_dim)
seq_len = H * W
q = q.view(B, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(B, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(B, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# Scaled dot-product attention
scale = self.head_dim**-0.5
attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale # (B, num_heads, H*W, H*W)
attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1)
attn_output = torch.matmul(attn_weights, v)
attn_output = attn_output.transpose(1, 2).contiguous().view(B, seq_len, self.hidden_dim)
attn_output = self.out_proj(attn_output)
x = residual + attn_output
residual = x
x = self.norm2(x)
x = residual + x
x = x.transpose(1, 2).view(B, self.hidden_dim, H, W)
out = self.final_proj(x)
return (out,)
class TestCFGPipeline(CFGParallelMixin):
"""Test pipeline using CFGParallelMixin."""
def __init__(self, in_channels: int = 4, hidden_dim: int = 128, seed: int = 42):
# Set seed BEFORE creating transformer to ensure consistent layer initialization
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
self.transformer = SimpleTransformer(in_channels, hidden_dim)
# Re-initialize all parameters with fixed seed for full reproducibility
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
for param in self.transformer.parameters():
torch.nn.init.normal_(param, mean=0.0, std=0.02)
def _test_cfg_parallel_worker(
local_rank: int,
world_size: int,
cfg_parallel_size: int,
dtype: torch.dtype,
test_config: dict,
result_queue: torch.multiprocessing.Queue,
):
"""Worker function for CFG parallel test."""
device = torch.device(f"{current_omni_platform.device_type}:{local_rank}")
current_omni_platform.set_device(device)
update_environment_variables(
{
"RANK": str(local_rank),
"LOCAL_RANK": str(local_rank),
"WORLD_SIZE": str(world_size),
"MASTER_ADDR": "localhost",
"MASTER_PORT": "29502",
}
)
init_distributed_environment()
initialize_model_parallel(cfg_parallel_size=cfg_parallel_size)
cfg_rank = get_classifier_free_guidance_rank()
cfg_world_size = get_classifier_free_guidance_world_size()
assert cfg_world_size == cfg_parallel_size
# Create pipeline with same seed to ensure identical model weights across all ranks
# Note: model_seed is set inside TestCFGPipeline.__init__
pipeline = TestCFGPipeline(
in_channels=test_config["channels"],
hidden_dim=test_config["hidden_dim"],
seed=test_config["model_seed"],
)
pipeline.transformer = pipeline.transformer.to(device=device, dtype=dtype)
pipeline.transformer.eval() # Set to eval mode for deterministic behavior
# Create fixed inputs with explicit seed setting for reproducibility
# Set both CPU and CUDA seeds to ensure identical inputs across all ranks
torch.manual_seed(test_config["input_seed"])
if torch.cuda.is_available():
torch.cuda.manual_seed_all(test_config["input_seed"])
batch_size = test_config["batch_size"]
channels = test_config["channels"]
height = test_config["height"]
width = test_config["width"]
# Positive input
positive_input = torch.randn(batch_size, channels, height, width, dtype=dtype, device=device)
# Negative input with different seed
torch.manual_seed(test_config["input_seed"] + 1)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(test_config["input_seed"] + 1)
negative_input = torch.randn(batch_size, channels, height, width, dtype=dtype, device=device)
# Prepare kwargs for predict_noise_maybe_with_cfg
positive_kwargs = {"x": positive_input}
negative_kwargs = {"x": negative_input}
with torch.no_grad():
# Call predict_noise_maybe_with_cfg
noise_pred = pipeline.predict_noise_maybe_with_cfg(
do_true_cfg=True,
true_cfg_scale=test_config["cfg_scale"],
positive_kwargs=positive_kwargs,
negative_kwargs=negative_kwargs,
cfg_normalize=test_config["cfg_normalize"],
)
# Only rank 0 has valid output in CFG parallel mode
if cfg_rank == 0:
assert noise_pred is not None
result_queue.put(noise_pred.cpu())
else:
assert noise_pred is None
destroy_distributed_env()
def _test_cfg_sequential_worker(
local_rank: int,
world_size: int,
dtype: torch.dtype,
test_config: dict,
result_queue: torch.multiprocessing.Queue,
):
"""Worker function for sequential CFG test (baseline)."""
device = torch.device(f"{current_omni_platform.device_type}:{local_rank}")
current_omni_platform.set_device(device)
update_environment_variables(
{
"RANK": str(local_rank),
"LOCAL_RANK": str(local_rank),
"WORLD_SIZE": str(world_size),
"MASTER_ADDR": "localhost",
"MASTER_PORT": "29503",
}
)
init_distributed_environment()
initialize_model_parallel(cfg_parallel_size=1) # No CFG parallel
cfg_world_size = get_classifier_free_guidance_world_size()
assert cfg_world_size == 1
# Create pipeline with same seed to ensure identical model weights as CFG parallel
# Note: model_seed is set inside TestCFGPipeline.__init__
pipeline = TestCFGPipeline(
in_channels=test_config["channels"],
hidden_dim=test_config["hidden_dim"],
seed=test_config["model_seed"],
)
pipeline.transformer = pipeline.transformer.to(device=device, dtype=dtype)
pipeline.transformer.eval()
# Create fixed inputs (same seed as CFG parallel to ensure identical inputs)
# Set both CPU and CUDA seeds for full reproducibility
torch.manual_seed(test_config["input_seed"])
if torch.cuda.is_available():
torch.cuda.manual_seed_all(test_config["input_seed"])
batch_size = test_config["batch_size"]
channels = test_config["channels"]
height = test_config["height"]
width = test_config["width"]
# Positive input
positive_input = torch.randn(batch_size, channels, height, width, dtype=dtype, device=device)
# Negative input with different seed
torch.manual_seed(test_config["input_seed"] + 1)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(test_config["input_seed"] + 1)
negative_input = torch.randn(batch_size, channels, height, width, dtype=dtype, device=device)
positive_kwargs = {"x": positive_input}
negative_kwargs = {"x": negative_input}
with torch.no_grad():
noise_pred = pipeline.predict_noise_maybe_with_cfg(
do_true_cfg=True,
true_cfg_scale=test_config["cfg_scale"],
positive_kwargs=positive_kwargs,
negative_kwargs=negative_kwargs,
cfg_normalize=test_config["cfg_normalize"],
)
# Sequential CFG always returns output
assert noise_pred is not None
result_queue.put(noise_pred.cpu())
destroy_distributed_env()
@pytest.mark.parametrize("cfg_parallel_size", [2])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("cfg_normalize", [False, True])
def test_predict_noise_maybe_with_cfg(cfg_parallel_size: int, dtype: torch.dtype, batch_size: int, cfg_normalize: bool):
"""
Test that predict_noise_maybe_with_cfg produces identical results
with and without CFG parallel.
Args:
cfg_parallel_size: Number of GPUs for CFG parallel
dtype: Data type for computation
batch_size: Batch size for testing
cfg_normalize: Whether to normalize CFG output
"""
available_gpus = current_omni_platform.get_device_count()
if available_gpus < cfg_parallel_size:
pytest.skip(f"Test requires {cfg_parallel_size} GPUs but only {available_gpus} available")
test_config = {
"batch_size": batch_size,
"channels": 4,
"height": 16,
"width": 16,
"hidden_dim": 128,
"cfg_scale": 7.5,
"cfg_normalize": cfg_normalize,
"model_seed": 42, # Fixed seed for model initialization
"input_seed": 123, # Fixed seed for input generation
}
mp_context = torch.multiprocessing.get_context("spawn")
manager = mp_context.Manager()
baseline_queue = manager.Queue()
cfg_parallel_queue = manager.Queue()
# Run baseline (sequential CFG) on single GPU
torch.multiprocessing.spawn(
_test_cfg_sequential_worker,
args=(1, dtype, test_config, baseline_queue),
nprocs=1,
)
# Run CFG parallel on multiple GPUs
torch.multiprocessing.spawn(
_test_cfg_parallel_worker,
args=(cfg_parallel_size, cfg_parallel_size, dtype, test_config, cfg_parallel_queue),
nprocs=cfg_parallel_size,
)
# Get results from queues
baseline_output = baseline_queue.get()
cfg_parallel_output = cfg_parallel_queue.get()
# Verify shapes match
assert baseline_output.shape == cfg_parallel_output.shape, (
f"Shape mismatch: baseline {baseline_output.shape} vs CFG parallel {cfg_parallel_output.shape}"
)
# Verify numerical equivalence with appropriate tolerances
if dtype == torch.float32:
rtol, atol = 1e-5, 1e-5
elif dtype == torch.bfloat16:
rtol, atol = 1e-2, 1e-2
else:
rtol, atol = 1e-3, 1e-3
torch.testing.assert_close(
cfg_parallel_output,
baseline_output,
rtol=rtol,
atol=atol,
msg=(
f"CFG parallel output differs from sequential CFG\n"
f" dtype={dtype}, batch_size={batch_size}, cfg_normalize={cfg_normalize}\n"
f" Max diff: {(cfg_parallel_output - baseline_output).abs().max().item():.6e}"
),
)
print(
f"✓ Test passed: cfg_size={cfg_parallel_size}, dtype={dtype}, "
f"batch_size={batch_size}, cfg_normalize={cfg_normalize}"
)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_predict_noise_without_cfg(dtype: torch.dtype):
"""
Test predict_noise_maybe_with_cfg when do_true_cfg=False.
When CFG is disabled, only the positive branch should be computed.
This test runs on a single GPU without distributed environment.
"""
available_gpus = current_omni_platform.get_device_count()
if available_gpus < 1:
pytest.skip("Test requires at least 1 GPU")
device = torch.device(f"{current_omni_platform.device_type}:0")
current_omni_platform.set_device(device)
# Create pipeline without distributed environment
pipeline = TestCFGPipeline(in_channels=4, hidden_dim=128, seed=42)
pipeline.transformer = pipeline.transformer.to(device=device, dtype=dtype)
pipeline.transformer.eval()
# Set seed for input generation
torch.manual_seed(123)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(123)
positive_input = torch.randn(1, 4, 16, 16, dtype=dtype, device=device)
with torch.no_grad():
noise_pred = pipeline.predict_noise_maybe_with_cfg(
do_true_cfg=False, # No CFG
true_cfg_scale=7.5,
positive_kwargs={"x": positive_input},
negative_kwargs=None,
cfg_normalize=False,
)
# Should always return output when do_true_cfg=False
assert noise_pred is not None
assert noise_pred.shape == (1, 4, 16, 16)
print(f"✓ Test passed: predict_noise without CFG (dtype={dtype})")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for SeqAllToAll4D and SeqAllToAll5D communication primitives."""
import os
import pytest
import torch
from vllm_omni.diffusion.distributed.comm import RingComm, SeqAllToAll4D, SeqAllToAll5D
from vllm_omni.diffusion.distributed.parallel_state import (
destroy_distributed_env,
get_sp_group,
init_distributed_environment,
initialize_model_parallel,
)
from vllm_omni.platforms import current_omni_platform
def update_environment_variables(envs_dict: dict[str, str]):
"""Update multiple environment variables with logging."""
for k, v in envs_dict.items():
os.environ[k] = v
@pytest.mark.parametrize("world_size", [2, 4])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seq_len_per_rank", [8])
@pytest.mark.parametrize("num_heads", [8])
@pytest.mark.parametrize("head_size", [32])
@pytest.mark.parametrize("use_sync", [False, True])
def test_4d_identity(
world_size: int,
dtype: torch.dtype,
batch_size: int,
seq_len_per_rank: int,
num_heads: int,
head_size: int,
use_sync: bool,
):
"""Test that two consecutive all-to-all operations return the original input."""
# Skip if not enough GPUs available
available_gpus = current_omni_platform.get_device_count()
if available_gpus < world_size:
pytest.skip(f"Test requires {world_size} GPUs but only {available_gpus} available")
# Ensure num_heads is divisible by world_size
if num_heads % world_size != 0:
pytest.skip(f"num_heads ({num_heads}) not divisible by world_size ({world_size})")
# Run test with multiprocessing spawn
torch.multiprocessing.spawn(
_test_4d_identity_worker,
args=(
world_size,
dtype,
batch_size,
seq_len_per_rank,
num_heads,
head_size,
use_sync,
),
nprocs=world_size,
)
def _test_4d_identity_worker(
local_rank: int,
world_size: int,
dtype: torch.dtype,
batch_size: int,
seq_len_per_rank: int,
num_heads: int,
head_size: int,
use_sync: bool,
):
"""Worker function for test_4d_identity."""
# Set device
device = torch.device(f"{current_omni_platform.device_type}:{local_rank}")
current_omni_platform.set_device(device)
# Set environment variables for distributed training
update_environment_variables(
{
"RANK": str(local_rank),
"LOCAL_RANK": str(local_rank),
"WORLD_SIZE": str(world_size),
"MASTER_ADDR": "localhost",
"MASTER_PORT": "29500",
}
)
# Initialize distributed environment
init_distributed_environment()
    initialize_model_parallel(ulysses_degree=world_size)  # test Ulysses SP by default
    sp_group = get_sp_group().ulysses_group  # use the Ulysses SP group, not the ring SP group
# Create input tensor: (bs, seqlen/P, hc, hs)
torch.manual_seed(42 + local_rank)
input_tensor = torch.randn(
batch_size,
seq_len_per_rank,
num_heads,
head_size,
dtype=dtype,
device=device,
)
# Save original input for comparison
original_input = input_tensor.clone()
# First all-to-all: (bs, seqlen/P, hc, hs) -> (bs, seqlen, hc/P, hs)
intermediate = SeqAllToAll4D.apply(
sp_group,
input_tensor,
2, # scatter head dimension
1, # gather sequence dimension
use_sync,
)
# Verify intermediate shape
expected_shape = (
batch_size,
seq_len_per_rank * world_size,
num_heads // world_size,
head_size,
)
assert intermediate.shape == expected_shape, (
f"Intermediate shape mismatch: expected {expected_shape}, got {intermediate.shape}"
)
# Second all-to-all: (bs, seqlen, hc/P, hs) -> (bs, seqlen/P, hc, hs)
output = SeqAllToAll4D.apply(
sp_group,
intermediate,
1, # scatter sequence dimension
2, # gather head dimension
use_sync,
)
# Verify output shape matches input
assert output.shape == original_input.shape, (
f"Output shape mismatch: expected {original_input.shape}, got {output.shape}"
)
# Verify output matches original input
torch.testing.assert_close(
output,
original_input,
rtol=1e-5,
atol=1e-5,
msg="Output does not match original input after two all-to-all operations",
)
# Cleanup distributed environment
destroy_distributed_env()
@pytest.mark.parametrize("world_size", [2, 4])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seq_len_per_rank", [8])
@pytest.mark.parametrize("num_heads", [8])
@pytest.mark.parametrize("head_size", [32])
@pytest.mark.parametrize("use_sync", [False, True])
def test_5d_identity(
world_size: int,
dtype: torch.dtype,
batch_size: int,
seq_len_per_rank: int,
num_heads: int,
head_size: int,
use_sync: bool,
):
"""Test that two consecutive all-to-all operations return the original input."""
# Skip if not enough GPUs available
available_gpus = current_omni_platform.get_device_count()
if available_gpus < world_size:
pytest.skip(f"Test requires {world_size} GPUs but only {available_gpus} available")
# Ensure num_heads is divisible by world_size
if num_heads % world_size != 0:
pytest.skip(f"num_heads ({num_heads}) not divisible by world_size ({world_size})")
# Run test with multiprocessing spawn
torch.multiprocessing.spawn(
_test_5d_identity_worker,
args=(
world_size,
dtype,
batch_size,
seq_len_per_rank,
num_heads,
head_size,
use_sync,
),
nprocs=world_size,
)
def _test_5d_identity_worker(
local_rank: int,
world_size: int,
dtype: torch.dtype,
batch_size: int,
seq_len_per_rank: int,
num_heads: int,
head_size: int,
use_sync: bool,
):
"""Worker function for test_5d_identity."""
# Set device
device = torch.device(f"{current_omni_platform.device_type}:{local_rank}")
current_omni_platform.set_device(device)
# Set environment variables for distributed training
update_environment_variables(
{
"RANK": str(local_rank),
"LOCAL_RANK": str(local_rank),
"WORLD_SIZE": str(world_size),
"MASTER_ADDR": "localhost",
"MASTER_PORT": "29500",
}
)
# Initialize distributed environment
init_distributed_environment()
    initialize_model_parallel(ulysses_degree=world_size)  # test Ulysses SP by default
    sp_group = get_sp_group().ulysses_group  # use the Ulysses SP group, not the ring SP group
# Create input tensor: (bs, seqlen/P, 3, hc, hs)
# The '3' dimension is for Q, K, V
torch.manual_seed(42 + local_rank)
input_tensor = torch.randn(
batch_size,
seq_len_per_rank,
3, # Q, K, V
num_heads,
head_size,
dtype=dtype,
device=device,
)
# Save original input for comparison
original_input = input_tensor.clone()
# First all-to-all: (bs, seqlen/P, 3, hc, hs) -> (bs, seqlen, 3, hc/P, hs)
intermediate = SeqAllToAll5D.apply(
sp_group,
input_tensor,
3, # scatter head dimension
1, # gather sequence dimension
use_sync,
)
# Verify intermediate shape
expected_shape = (
batch_size,
seq_len_per_rank * world_size,
3,
num_heads // world_size,
head_size,
)
assert intermediate.shape == expected_shape, (
f"Intermediate shape mismatch: expected {expected_shape}, got {intermediate.shape}"
)
# Second all-to-all: (bs, seqlen, 3, hc/P, hs) -> (bs, seqlen/P, 3, hc, hs)
output = SeqAllToAll5D.apply(
sp_group,
intermediate,
1, # scatter sequence dimension
3, # gather head dimension
use_sync,
)
# Verify output shape matches input
assert output.shape == original_input.shape, (
f"Output shape mismatch: expected {original_input.shape}, got {output.shape}"
)
# Verify output matches original input
torch.testing.assert_close(
output,
original_input,
rtol=1e-5,
atol=1e-5,
msg="Output does not match original input after two all-to-all operations",
)
# Cleanup distributed environment
destroy_distributed_env()
@pytest.mark.parametrize("world_size", [2, 4])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("num_heads", [8])
@pytest.mark.parametrize("head_size", [128])
def test_ring_p2p(
world_size: int,
dtype: torch.dtype,
batch_size: int,
num_heads: int,
head_size: int,
):
"""Test Ring P2P communication (send_recv)."""
# Skip if not enough GPUs available
available_gpus = current_omni_platform.get_device_count()
if available_gpus < world_size:
pytest.skip(f"Test requires {world_size} GPUs but only {available_gpus} available")
torch.multiprocessing.spawn(
_test_ring_p2p_worker,
args=(world_size, dtype, batch_size, num_heads, head_size),
nprocs=world_size,
)
def _test_ring_p2p_worker(
local_rank: int,
world_size: int,
dtype: torch.dtype,
batch_size: int,
num_heads: int,
head_size: int,
):
"""Worker for Ring P2P test."""
    import sys
    import time
# Set device
device = torch.device(f"{current_omni_platform.device_type}:{local_rank}")
current_omni_platform.set_device(device)
# Set env vars
# Use a different port to avoid conflict with other tests if run in parallel
update_environment_variables(
{
"RANK": str(local_rank),
"LOCAL_RANK": str(local_rank),
"WORLD_SIZE": str(world_size),
"MASTER_ADDR": "localhost",
"MASTER_PORT": "29501",
}
)
# Init distributed
try:
init_distributed_environment()
# Ring degree = world_size to test ring group
initialize_model_parallel(ring_degree=world_size)
sp_group = get_sp_group()
print(f"[Rank {local_rank}] Initialized. Ring group size: {sp_group.ring_group.size()}")
sys.stdout.flush()
# Create RingComm
comm = RingComm(sp_group.ring_group)
# Create tensor: rank-specific data
# (batch, num_heads, head_size)
# Fill with rank value + 1 to avoid 0 and make verification easy
input_tensor = torch.full(
(batch_size, num_heads, head_size), fill_value=float(local_rank + 1), dtype=dtype, device=device
)
print(f"[Rank {local_rank}] Input sum: {input_tensor.sum().item()}")
sys.stdout.flush()
# Send input, receive from prev
# RingComm.send_recv sends to next, receives from prev
        t0 = time.time()
recv_tensor = comm.send_recv(input_tensor)
comm.commit()
comm.wait()
        t1 = time.time()
print(f"[Rank {local_rank}] Communication done in {t1 - t0:.4f}s")
# Verify
# Expected value: from (rank - 1) % world_size
prev_rank = (local_rank - 1 + world_size) % world_size
expected_value = float(prev_rank + 1)
recv_sum = recv_tensor.sum().item()
print(f"[Rank {local_rank}] Received sum: {recv_sum}, Expected value: {expected_value}")
sys.stdout.flush()
expected_tensor = torch.full_like(recv_tensor, fill_value=expected_value)
# Use a slightly loose tolerance for bfloat16
torch.testing.assert_close(
recv_tensor, expected_tensor, rtol=1e-3, atol=1e-3, msg=f"[Rank {local_rank}] Data mismatch!"
)
print(f"[Rank {local_rank}] Verification PASSED")
except Exception as e:
print(f"[Rank {local_rank}] FAILED with error: {e}")
import traceback
traceback.print_exc()
        raise
finally:
destroy_distributed_env()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the Sequence Parallelism (SP) framework.
These tests verify the SP plan mechanism and hooks work correctly without
requiring a distributed environment. They test:
1. _sp_plan validation (sp_plan.py)
2. Hook utilities and submodule resolution (sequence_parallel.py)
3. Model _sp_plan definitions
4. Tensor sharding simulation
Note: Our "Sequence Parallelism" (SP) corresponds to "Context Parallelism" (CP)
in diffusers. We use "Sequence Parallelism" to align with vLLM-Omni terminology.
"""
import pytest
import torch
import torch.nn as nn
from vllm_omni.diffusion.distributed.sp_plan import (
SequenceParallelInput,
SequenceParallelOutput,
SequenceParallelPartialInput,
get_sp_plan_from_model,
validate_sp_plan,
)
def is_distributed_initialized() -> bool:
"""Check if distributed environment is initialized."""
try:
from vllm_omni.diffusion.distributed.parallel_state import get_sp_group
get_sp_group()
return True
except (AssertionError, ImportError):
return False
# Decorator to skip tests that require distributed environment
requires_distributed = pytest.mark.skipif(
not is_distributed_initialized(),
reason="Requires initialized distributed environment (SP group)",
)
# Module-level markers: these tests are diffusion + parallel related
pytestmark = [
pytest.mark.diffusion,
pytest.mark.parallel,
]
# =============================================================================
# Tests for sp_plan.py
# =============================================================================
@pytest.mark.cpu
class TestSequenceParallelPlanValidation:
"""Test _sp_plan validation logic."""
def test_valid_simple_plan(self):
"""Test a simple valid _sp_plan."""
plan = {
"rope": {
0: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True),
1: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True),
},
"blocks.0": {
"hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3),
},
"proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
}
# Should not raise
validate_sp_plan(plan)
def test_valid_partial_input_plan(self):
"""Test a valid _sp_plan with SequenceParallelPartialInput."""
plan = {
"pos_embed": {
0: SequenceParallelPartialInput(
split_dim=0,
text_len_source="txt_ids",
expected_dims=2,
split_output=True,
),
},
"blocks.0": {
"hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3),
},
}
# Should not raise
validate_sp_plan(plan)
def test_invalid_plan_type(self):
"""Test that non-dict plan raises error."""
with pytest.raises(ValueError, match="must be a dict"):
validate_sp_plan("not a dict")
def test_invalid_module_key_type(self):
"""Test that non-string module keys raise error."""
plan = {123: {"hidden_states": SequenceParallelInput(split_dim=1)}}
with pytest.raises(ValueError, match="keys must be strings"):
validate_sp_plan(plan)
def test_invalid_output_index_without_split_output(self):
"""Test that integer keys require split_output=True."""
plan = {
"rope": {
0: SequenceParallelInput(split_dim=1, split_output=False), # Invalid
}
}
with pytest.raises(ValueError, match="split_output=True"):
validate_sp_plan(plan)
@pytest.mark.cpu
class TestGetSpPlanFromModel:
"""Test get_sp_plan_from_model utility."""
def test_model_with_sp_plan(self):
"""Test getting _sp_plan from a model that has one."""
class ModelWithPlan(nn.Module):
_sp_plan = {
"layer": {
"x": SequenceParallelInput(split_dim=1),
}
}
model = ModelWithPlan()
plan = get_sp_plan_from_model(model)
assert plan is not None
assert "layer" in plan
def test_model_without_sp_plan(self):
"""Test getting _sp_plan from a model without one."""
class ModelWithoutPlan(nn.Module):
pass
model = ModelWithoutPlan()
plan = get_sp_plan_from_model(model)
assert plan is None
@pytest.mark.cpu
class TestSequenceParallelInputTypes:
"""Test SequenceParallelInput and related types."""
def test_sequence_parallel_input_repr(self):
"""Test SequenceParallelInput repr."""
spi = SequenceParallelInput(split_dim=1, expected_dims=3, split_output=True)
assert "split_dim=1" in repr(spi)
assert "expected_dims=3" in repr(spi)
assert "split_output=True" in repr(spi)
def test_sequence_parallel_output_repr(self):
"""Test SequenceParallelOutput repr."""
spo = SequenceParallelOutput(gather_dim=1, expected_dims=3)
assert "gather_dim=1" in repr(spo)
assert "expected_dims=3" in repr(spo)
def test_sequence_parallel_partial_input_repr(self):
"""Test SequenceParallelPartialInput repr."""
sppi = SequenceParallelPartialInput(
split_dim=0,
text_len_source="txt_ids",
expected_dims=2,
split_output=True,
)
assert "split_dim=0" in repr(sppi)
assert "txt_ids" in repr(sppi)
assert "expected_dims=2" in repr(sppi)
assert "split_output=True" in repr(sppi)
def test_sequence_parallel_partial_input_with_int_source(self):
"""Test SequenceParallelPartialInput with integer text_len_source."""
sppi = SequenceParallelPartialInput(
split_dim=0,
text_len_source=512, # Fixed length
expected_dims=2,
)
assert sppi.text_len_source == 512
# =============================================================================
# Tests for sequence_parallel.py
# =============================================================================
@pytest.mark.cpu
class TestModuleForwardMetadata:
"""Test ModuleForwardMetadata parameter resolution."""
def test_get_parameter_from_kwargs(self):
"""Test getting parameter from kwargs."""
from vllm_omni.diffusion.hooks.sequence_parallel import ModuleForwardMetadata
class DummyModule(nn.Module):
def forward(self, hidden_states, encoder_hidden_states):
pass
metadata = ModuleForwardMetadata()
metadata._cls = DummyModule
kwargs = {"hidden_states": torch.randn(2, 4, 8)}
val, is_kwarg, index = metadata._get_parameter_from_args_kwargs("hidden_states", (), kwargs)
assert is_kwarg is True
assert index is None
assert val.shape == (2, 4, 8)
def test_get_parameter_from_args(self):
"""Test getting parameter from positional args."""
from vllm_omni.diffusion.hooks.sequence_parallel import ModuleForwardMetadata
class DummyModule(nn.Module):
def forward(self, hidden_states, encoder_hidden_states):
pass
metadata = ModuleForwardMetadata()
metadata._cls = DummyModule
tensor = torch.randn(2, 4, 8)
args = (tensor,)
val, is_kwarg, index = metadata._get_parameter_from_args_kwargs("hidden_states", args, {})
assert is_kwarg is False
assert index == 0
assert torch.equal(val, tensor)
def test_parameter_caching(self):
"""Test that parameter indices are cached."""
from vllm_omni.diffusion.hooks.sequence_parallel import ModuleForwardMetadata
class DummyModule(nn.Module):
def forward(self, a, b, c):
pass
metadata = ModuleForwardMetadata()
metadata._cls = DummyModule
# First call - should populate cache
args = (torch.randn(1), torch.randn(1), torch.randn(1))
metadata._get_parameter_from_args_kwargs("b", args, {})
# Check cache was populated
assert metadata.cached_parameter_indices is not None
assert metadata.cached_parameter_indices["a"] == 0
assert metadata.cached_parameter_indices["b"] == 1
assert metadata.cached_parameter_indices["c"] == 2
@pytest.mark.cpu
class TestGetSubmoduleByName:
"""Test _get_submodule_by_name function."""
def test_root_module(self):
"""Test getting root module with empty string."""
from vllm_omni.diffusion.hooks.sequence_parallel import _get_submodule_by_name
model = nn.Linear(10, 10)
submodule = _get_submodule_by_name(model, "")
assert submodule is model
def test_simple_submodule(self):
"""Test getting a simple submodule."""
from vllm_omni.diffusion.hooks.sequence_parallel import _get_submodule_by_name
class Model(nn.Module):
def __init__(self):
super().__init__()
self.layer = nn.Linear(10, 10)
model = Model()
submodule = _get_submodule_by_name(model, "layer")
assert submodule is model.layer
def test_nested_submodule(self):
"""Test getting a nested submodule."""
from vllm_omni.diffusion.hooks.sequence_parallel import _get_submodule_by_name
class Model(nn.Module):
def __init__(self):
super().__init__()
self.encoder = nn.Sequential(nn.Linear(10, 10), nn.ReLU())
model = Model()
submodule = _get_submodule_by_name(model, "encoder.0")
assert isinstance(submodule, nn.Linear)
def test_module_list_by_index(self):
"""Test getting element from ModuleList by index."""
from vllm_omni.diffusion.hooks.sequence_parallel import _get_submodule_by_name
class Model(nn.Module):
def __init__(self):
super().__init__()
self.blocks = nn.ModuleList([nn.Linear(10, 10) for _ in range(3)])
model = Model()
submodule = _get_submodule_by_name(model, "blocks.0")
assert submodule is model.blocks[0]
submodule = _get_submodule_by_name(model, "blocks.2")
assert submodule is model.blocks[2]
def test_wildcard_modulelist(self):
"""Test wildcard matching for ModuleList."""
from vllm_omni.diffusion.hooks.sequence_parallel import _get_submodule_by_name
class Model(nn.Module):
def __init__(self):
super().__init__()
self.blocks = nn.ModuleList([nn.Linear(10, 10) for _ in range(3)])
model = Model()
submodules = _get_submodule_by_name(model, "blocks.*")
assert isinstance(submodules, list)
assert len(submodules) == 3
for i, sm in enumerate(submodules):
assert sm is model.blocks[i]
def test_module_dict(self):
"""Test getting submodule from ModuleDict."""
from vllm_omni.diffusion.hooks.sequence_parallel import _get_submodule_by_name
class Model(nn.Module):
def __init__(self):
super().__init__()
self.outputs = nn.ModuleDict({"main": nn.Linear(10, 10), "aux": nn.Linear(10, 5)})
model = Model()
submodule = _get_submodule_by_name(model, "outputs.main")
assert submodule is model.outputs["main"]
submodule = _get_submodule_by_name(model, "outputs.aux")
assert submodule is model.outputs["aux"]
def test_invalid_submodule_raises(self):
"""Test that invalid submodule path raises error."""
from vllm_omni.diffusion.hooks.sequence_parallel import _get_submodule_by_name
class Model(nn.Module):
def __init__(self):
super().__init__()
self.layer = nn.Linear(10, 10)
model = Model()
with pytest.raises(ValueError, match="not a submodule"):
_get_submodule_by_name(model, "nonexistent")
def test_multiple_wildcards_raises(self):
"""Test that multiple wildcards raise error."""
from vllm_omni.diffusion.hooks.sequence_parallel import _get_submodule_by_name
model = nn.Linear(10, 10)
with pytest.raises(ValueError, match="only be used once"):
_get_submodule_by_name(model, "a.*.b.*")
@pytest.mark.cpu
class TestHookRegistration:
"""Test hook registration logic (without distributed backend)."""
def test_plan_validation_before_apply(self):
"""Test that invalid plans are rejected before hook registration."""
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.proj_in = nn.Linear(10, 10)
self.proj_out = nn.Linear(10, 10)
def forward(self, x):
return self.proj_out(self.proj_in(x))
# Invalid plan (non-string key)
invalid_plan = {
123: {"x": SequenceParallelInput(split_dim=1)},
}
with pytest.raises(ValueError):
validate_sp_plan(invalid_plan)
def test_valid_plan_structure_for_model(self):
"""Test that a valid plan can be defined for a model."""
class SimpleModel(nn.Module):
_sp_plan = {
"proj_in": {"x": SequenceParallelInput(split_dim=1, expected_dims=3)},
"proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
}
def __init__(self):
super().__init__()
self.proj_in = nn.Linear(10, 10)
self.proj_out = nn.Linear(10, 10)
def forward(self, x):
return self.proj_out(self.proj_in(x))
model = SimpleModel()
plan = get_sp_plan_from_model(model)
assert plan is not None
assert "proj_in" in plan
assert "proj_out" in plan
# Verify submodules exist
from vllm_omni.diffusion.hooks.sequence_parallel import _get_submodule_by_name
assert _get_submodule_by_name(model, "proj_in") is model.proj_in
assert _get_submodule_by_name(model, "proj_out") is model.proj_out
# =============================================================================
# Tests for model _sp_plan definitions
# =============================================================================
@pytest.mark.L4
class TestModelSpPlans:
"""Test that model _sp_plan definitions are valid.
These tests import actual model classes to verify _sp_plan structure.
May require GPU for model imports.
"""
def test_zimage_transformer_sp_plan(self):
"""Test ZImageTransformer2DModel _sp_plan structure.
The plan specifies:
- unified_prepare: Shard all 4 outputs (unified, cos, sin, attn_mask)
- all_final_layer.2-1: Gather outputs after final layer
Note: _sp_plan corresponds to diffusers' _cp_plan (Context Parallelism)
"""
try:
from vllm_omni.diffusion.models.z_image.z_image_transformer import ZImageTransformer2DModel
plan = getattr(ZImageTransformer2DModel, "_sp_plan", None)
assert plan is not None, "ZImageTransformer2DModel should define _sp_plan"
assert isinstance(plan, dict)
assert "unified_prepare" in plan
unified_prepare_plan = plan["unified_prepare"]
# Check all 4 outputs are sharded with split_output=True
assert 0 in unified_prepare_plan # unified
assert 1 in unified_prepare_plan # unified_cos
assert 2 in unified_prepare_plan # unified_sin
assert 3 in unified_prepare_plan # unified_attn_mask
# Check output gathering
assert "all_final_layer.2-1" in plan
validate_sp_plan(plan)
except ImportError:
pytest.skip("ZImageTransformer2DModel not available")
def test_qwen_image_transformer_sp_plan(self):
"""Test QwenImageTransformer2DModel _sp_plan structure.
Qwen-Image follows the diffusers pattern similar to Z-Image:
- image_rope_prepare: Shards hidden_states and vid_freqs together
- proj_out: Gathers output
Key insight: hidden_states and vid_freqs MUST be sharded together
to maintain dimension alignment for RoPE computation.
Note: _sp_plan corresponds to diffusers' _cp_plan (Context Parallelism)
"""
try:
from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import (
QwenImageTransformer2DModel,
)
plan = getattr(QwenImageTransformer2DModel, "_sp_plan", None)
assert plan is not None, "QwenImageTransformer2DModel should define _sp_plan"
assert isinstance(plan, dict)
# Check image_rope_prepare sharding
assert "image_rope_prepare" in plan
rope_plan = plan["image_rope_prepare"]
# hidden_states (index 0)
assert 0 in rope_plan
assert rope_plan[0].split_dim == 1
assert rope_plan[0].split_output is True
# vid_freqs (index 1)
assert 1 in rope_plan
assert rope_plan[1].split_dim == 0
assert rope_plan[1].split_output is True
# txt_freqs (index 2) should NOT be in plan (kept replicated)
assert 2 not in rope_plan
# Check output gathering at proj_out
assert "proj_out" in plan
proj_out_plan = plan["proj_out"]
assert proj_out_plan.gather_dim == 1
validate_sp_plan(plan)
except ImportError:
pytest.skip("QwenImageTransformer2DModel not available")
# =============================================================================
# Tests for tensor sharding simulation (no distributed required)
# =============================================================================
@pytest.mark.cpu
class TestMockSharding:
"""Test tensor sharding logic (mocked, no distributed)."""
def test_shard_tensor_simulation(self):
"""Simulate tensor sharding without distributed backend."""
# Create a test tensor
batch_size, seq_len, hidden_dim = 2, 16, 64
tensor = torch.randn(batch_size, seq_len, hidden_dim)
# Simulate sharding for world_size=4
world_size = 4
rank = 1
# Manual chunking (what sp_shard does internally)
chunks = tensor.chunk(world_size, dim=1)
sharded = chunks[rank]
assert sharded.shape == (batch_size, seq_len // world_size, hidden_dim)
assert sharded.shape == (2, 4, 64)
def test_partial_shard_simulation(self):
"""Simulate partial sharding (text kept, image sharded)."""
# Create a test tensor with [text, image] concatenated
batch_size = 2
text_len = 8
image_len = 16
hidden_dim = 64
text_part = torch.randn(batch_size, text_len, hidden_dim)
image_part = torch.randn(batch_size, image_len, hidden_dim)
tensor = torch.cat([text_part, image_part], dim=1)
assert tensor.shape == (batch_size, text_len + image_len, hidden_dim)
# Simulate partial sharding for world_size=4, rank=1
world_size = 4
rank = 1
dim = 1
# Extract parts
text_kept = tensor.narrow(dim, 0, text_len)
image_full = tensor.narrow(dim, text_len, image_len)
# Shard only image part
image_chunks = image_full.chunk(world_size, dim=dim)
image_sharded = image_chunks[rank]
# Concatenate back
result = torch.cat([text_kept, image_sharded], dim=dim)
expected_len = text_len + image_len // world_size
assert result.shape == (batch_size, expected_len, hidden_dim)
assert result.shape == (2, 8 + 4, 64) # text_len + image_len/4
def test_gather_tensor_simulation(self):
"""Simulate tensor gathering without distributed backend."""
# Create sharded tensors (as if from different ranks)
batch_size, shard_seq_len, hidden_dim = 2, 4, 64
world_size = 4
shards = [torch.randn(batch_size, shard_seq_len, hidden_dim) for _ in range(world_size)]
# Simulate gathering (concatenate along dim 1)
gathered = torch.cat(shards, dim=1)
assert gathered.shape == (batch_size, shard_seq_len * world_size, hidden_dim)
assert gathered.shape == (2, 16, 64)
def test_padding_simulation(self):
"""Simulate padding for non-divisible sequence lengths."""
# Create tensor with non-divisible sequence length
batch_size, seq_len, hidden_dim = 2, 17, 64 # 17 not divisible by 4
tensor = torch.randn(batch_size, seq_len, hidden_dim)
world_size = 4
dim = 1
# Calculate padding needed
remainder = seq_len % world_size
if remainder != 0:
pad_size = world_size - remainder
else:
pad_size = 0
assert pad_size == 3 # 17 + 3 = 20, divisible by 4
# Pad tensor
if pad_size > 0:
pad_shape = list(tensor.shape)
pad_shape[dim] = pad_size
padding = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)
padded = torch.cat([tensor, padding], dim=dim)
else:
padded = tensor
assert padded.shape == (batch_size, seq_len + pad_size, hidden_dim)
assert padded.shape == (2, 20, 64)
# Now can shard evenly
chunks = padded.chunk(world_size, dim=dim)
assert all(c.shape == (2, 5, 64) for c in chunks)
# =============================================================================
# Additional tests for sequence_parallel.py coverage
# =============================================================================
@pytest.mark.cpu
class TestUnwrapModule:
"""Test _unwrap_module function."""
def test_unwrap_simple_module(self):
"""Test that a simple module returns itself."""
from vllm_omni.diffusion.hooks.sequence_parallel import _unwrap_module
module = nn.Linear(10, 10)
result = _unwrap_module(module)
assert result is module
def test_unwrap_sequential_single(self):
"""Test unwrapping a Sequential with single child."""
from vllm_omni.diffusion.hooks.sequence_parallel import _unwrap_module
inner = nn.Linear(10, 10)
wrapper = nn.Sequential(inner)
result = _unwrap_module(wrapper)
# Should unwrap to the inner module
assert result is inner
def test_unwrap_nested_wrapper(self):
"""Test unwrapping nested single-child wrappers."""
from vllm_omni.diffusion.hooks.sequence_parallel import _unwrap_module
inner = nn.Linear(10, 10)
wrapper1 = nn.Sequential(inner)
wrapper2 = nn.Sequential(wrapper1)
result = _unwrap_module(wrapper2)
# Should fully unwrap to the innermost module
assert result is inner
@pytest.mark.cpu
class TestSequenceParallelSplitHookInit:
"""Test SequenceParallelSplitHook initialization and setup."""
def test_hook_init(self):
"""Test SequenceParallelSplitHook initialization."""
from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig
from vllm_omni.diffusion.hooks.sequence_parallel import SequenceParallelSplitHook
metadata = {
"hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3),
}
config = SequenceParallelConfig(ulysses_degree=2, ring_degree=1)
hook = SequenceParallelSplitHook(metadata, config)
assert hook.metadata == metadata
assert hook.config == config
assert hook.module_forward_metadata is None # Not initialized until initialize_hook
def test_hook_initialize(self):
"""Test SequenceParallelSplitHook.initialize_hook."""
from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig
from vllm_omni.diffusion.hooks.sequence_parallel import SequenceParallelSplitHook
class DummyModule(nn.Module):
def forward(self, hidden_states, encoder_hidden_states):
return hidden_states
metadata = {
"hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3),
}
config = SequenceParallelConfig(ulysses_degree=2, ring_degree=1)
hook = SequenceParallelSplitHook(metadata, config)
module = DummyModule()
# Initialize hook
result = hook.initialize_hook(module)
assert result is module
assert hook.module_forward_metadata is not None
assert hook.module_forward_metadata._cls is DummyModule
@pytest.mark.cpu
class TestSequenceParallelGatherHookInit:
"""Test SequenceParallelGatherHook initialization."""
def test_hook_init_single_output(self):
"""Test SequenceParallelGatherHook with single output."""
from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig
from vllm_omni.diffusion.hooks.sequence_parallel import SequenceParallelGatherHook
metadata = SequenceParallelOutput(gather_dim=1, expected_dims=3)
config = SequenceParallelConfig(ulysses_degree=2, ring_degree=1)
hook = SequenceParallelGatherHook(metadata, config)
# Single output should be wrapped in a list
assert isinstance(hook.metadata, list)
assert len(hook.metadata) == 1
assert hook.metadata[0].gather_dim == 1
def test_hook_init_multiple_outputs(self):
"""Test SequenceParallelGatherHook with multiple outputs."""
from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig
from vllm_omni.diffusion.hooks.sequence_parallel import SequenceParallelGatherHook
metadata = [
SequenceParallelOutput(gather_dim=1, expected_dims=3),
SequenceParallelOutput(gather_dim=2, expected_dims=4),
]
config = SequenceParallelConfig(ulysses_degree=2, ring_degree=1)
hook = SequenceParallelGatherHook(metadata, config)
assert len(hook.metadata) == 2
assert hook.metadata[0].gather_dim == 1
assert hook.metadata[1].gather_dim == 2
@pytest.mark.cpu
class TestResolveTextLen:
"""Test _resolve_text_len in SequenceParallelSplitHook."""
def test_resolve_int_source(self):
"""Test resolving text length from integer source."""
from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig
from vllm_omni.diffusion.hooks.sequence_parallel import SequenceParallelSplitHook
class DummyModule(nn.Module):
def forward(self, x, txt_ids):
return x
partial_input = SequenceParallelPartialInput(
split_dim=1,
text_len_source=256, # Fixed integer
expected_dims=3,
)
config = SequenceParallelConfig(ulysses_degree=2, ring_degree=1)
hook = SequenceParallelSplitHook({"x": partial_input}, config)
hook.initialize_hook(DummyModule())
# Resolve with integer source
text_len = hook._resolve_text_len(partial_input, (), {})
assert text_len == 256
def test_resolve_string_source_from_tensor(self):
"""Test resolving text length from tensor parameter."""
from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig
from vllm_omni.diffusion.hooks.sequence_parallel import SequenceParallelSplitHook
class DummyModule(nn.Module):
def forward(self, x, txt_ids):
return x
partial_input = SequenceParallelPartialInput(
split_dim=1,
text_len_source="txt_ids", # Get from parameter
expected_dims=3,
)
config = SequenceParallelConfig(ulysses_degree=2, ring_degree=1)
hook = SequenceParallelSplitHook({"x": partial_input}, config)
hook.initialize_hook(DummyModule())
# Provide txt_ids tensor
txt_ids = torch.randn(128, 64) # shape[0] = 128
kwargs = {"txt_ids": txt_ids}
text_len = hook._resolve_text_len(partial_input, (), kwargs)
assert text_len == 128
def test_resolve_text_len_caching(self):
"""Test that text length is cached."""
from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig
from vllm_omni.diffusion.hooks.sequence_parallel import SequenceParallelSplitHook
class DummyModule(nn.Module):
def forward(self, x, txt_ids):
return x
partial_input = SequenceParallelPartialInput(
split_dim=1,
text_len_source="txt_ids",
expected_dims=3,
)
config = SequenceParallelConfig(ulysses_degree=2, ring_degree=1)
hook = SequenceParallelSplitHook({"x": partial_input}, config)
hook.initialize_hook(DummyModule())
txt_ids = torch.randn(64, 32)
kwargs = {"txt_ids": txt_ids}
# First call - should populate cache
hook._resolve_text_len(partial_input, (), kwargs)
assert "txt_ids" in hook._text_len_cache
assert hook._text_len_cache["txt_ids"] == 64
# Second call - should use cache
text_len = hook._resolve_text_len(partial_input, (), kwargs)
assert text_len == 64
@pytest.mark.cpu
class TestHookNameTemplates:
"""Test hook name template generation."""
def test_input_hook_name(self):
"""Test input hook name format."""
from vllm_omni.diffusion.hooks.sequence_parallel import _SP_INPUT_HOOK_TEMPLATE
name = _SP_INPUT_HOOK_TEMPLATE.format("blocks.0")
assert name == "sp_input---blocks.0"
def test_output_hook_name(self):
"""Test output hook name format."""
from vllm_omni.diffusion.hooks.sequence_parallel import _SP_OUTPUT_HOOK_TEMPLATE
name = _SP_OUTPUT_HOOK_TEMPLATE.format("proj_out")
assert name == "sp_output---proj_out"
@pytest.mark.cpu
class TestApplyRemoveSequenceParallel:
"""Test apply_sequence_parallel and remove_sequence_parallel functions."""
def test_apply_sp_registers_hooks(self):
"""Test that apply_sequence_parallel registers hooks on modules."""
from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig
from vllm_omni.diffusion.hooks.sequence_parallel import (
_SP_INPUT_HOOK_TEMPLATE,
_SP_OUTPUT_HOOK_TEMPLATE,
apply_sequence_parallel,
)
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.proj_in = nn.Linear(10, 10)
self.proj_out = nn.Linear(10, 10)
def forward(self, hidden_states):
x = self.proj_in(hidden_states)
return self.proj_out(x)
model = SimpleModel()
config = SequenceParallelConfig(ulysses_degree=2, ring_degree=1)
plan = {
"proj_in": {"hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3)},
"proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
}
# Apply SP
apply_sequence_parallel(model, config, plan)
# Check hooks are registered
assert hasattr(model.proj_in, "_hook_registry")
assert hasattr(model.proj_out, "_hook_registry")
proj_in_registry = model.proj_in._hook_registry
proj_out_registry = model.proj_out._hook_registry
assert _SP_INPUT_HOOK_TEMPLATE.format("proj_in") in proj_in_registry._hooks
assert _SP_OUTPUT_HOOK_TEMPLATE.format("proj_out") in proj_out_registry._hooks
def test_remove_sp_removes_hooks(self):
"""Test that remove_sequence_parallel removes hooks from modules."""
from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig
from vllm_omni.diffusion.hooks.sequence_parallel import (
_SP_INPUT_HOOK_TEMPLATE,
_SP_OUTPUT_HOOK_TEMPLATE,
apply_sequence_parallel,
remove_sequence_parallel,
)
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.proj_in = nn.Linear(10, 10)
self.proj_out = nn.Linear(10, 10)
def forward(self, hidden_states):
x = self.proj_in(hidden_states)
return self.proj_out(x)
model = SimpleModel()
config = SequenceParallelConfig(ulysses_degree=2, ring_degree=1)
plan = {
"proj_in": {"hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3)},
"proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
}
# Apply then remove SP
apply_sequence_parallel(model, config, plan)
remove_sequence_parallel(model, plan)
# Check hooks are removed
proj_in_registry = model.proj_in._hook_registry
proj_out_registry = model.proj_out._hook_registry
assert _SP_INPUT_HOOK_TEMPLATE.format("proj_in") not in proj_in_registry._hooks
assert _SP_OUTPUT_HOOK_TEMPLATE.format("proj_out") not in proj_out_registry._hooks
def test_apply_sp_with_wildcard(self):
"""Test apply_sequence_parallel with wildcard module names."""
from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig
from vllm_omni.diffusion.hooks.sequence_parallel import (
_SP_INPUT_HOOK_TEMPLATE,
apply_sequence_parallel,
)
class Block(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 10)
def forward(self, x):
return self.linear(x)
class Model(nn.Module):
def __init__(self):
super().__init__()
self.blocks = nn.ModuleList([Block() for _ in range(3)])
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
model = Model()
config = SequenceParallelConfig(ulysses_degree=2, ring_degree=1)
plan = {
"blocks.*": {"x": SequenceParallelInput(split_dim=1, expected_dims=3)},
}
# Apply SP
apply_sequence_parallel(model, config, plan)
# Check all blocks have hooks registered
for i, block in enumerate(model.blocks):
assert hasattr(block, "_hook_registry")
registry = block._hook_registry
assert _SP_INPUT_HOOK_TEMPLATE.format("blocks.*") in registry._hooks
@pytest.mark.cpu
class TestDimensionValidation:
"""Test expected_dims validation in hooks."""
def test_skip_shard_on_wrong_dims(self):
"""Test that sharding is skipped when tensor dims don't match expected."""
from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig
from vllm_omni.diffusion.hooks.sequence_parallel import SequenceParallelSplitHook
class DummyModule(nn.Module):
def forward(self, x):
return x
# Expect 3D tensor
metadata = {
"x": SequenceParallelInput(split_dim=1, expected_dims=3),
}
config = SequenceParallelConfig(ulysses_degree=2, ring_degree=1)
hook = SequenceParallelSplitHook(metadata, config)
hook.initialize_hook(DummyModule())
# Provide 4D tensor (wrong dims)
tensor_4d = torch.randn(2, 4, 8, 16)
# _prepare_sp_input should return tensor unchanged when dims don't match
result = hook._prepare_sp_input(tensor_4d, metadata["x"], (), {})
# Since expected_dims=3 but tensor has 4 dims, should return original
assert result.shape == tensor_4d.shape
@pytest.mark.cpu
class TestSequenceParallelConfig:
"""Test SequenceParallelConfig dataclass."""
def test_config_defaults_invalid(self):
"""Test that SequenceParallelConfig with default values raises error.
At least one of ulysses_degree or ring_degree must be > 1 to enable SP.
"""
from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig
with pytest.raises(ValueError, match="must be > 1"):
SequenceParallelConfig() # Both defaults are 1, which is invalid
def test_config_ulysses_only(self):
"""Test SequenceParallelConfig with Ulysses only."""
from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig
config = SequenceParallelConfig(ulysses_degree=4, ring_degree=1)
assert config.sequence_parallel_size == 4
def test_config_ring_only(self):
"""Test SequenceParallelConfig with Ring only."""
from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig
config = SequenceParallelConfig(ulysses_degree=1, ring_degree=4)
assert config.sequence_parallel_size == 4
def test_config_hybrid(self):
"""Test SequenceParallelConfig with hybrid (Ulysses + Ring)."""
from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelConfig
config = SequenceParallelConfig(ulysses_degree=2, ring_degree=4)
assert config.sequence_parallel_size == 8
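# The three tests above pin down the contract of SequenceParallelConfig:
# sequence_parallel_size == ulysses_degree * ring_degree, and the all-default
# (1, 1) combination is rejected. A minimal, self-contained sketch of that
# contract for illustration (not the vllm_omni implementation):
from dataclasses import dataclass


@dataclass
class _SketchSPConfig:
    ulysses_degree: int = 1
    ring_degree: int = 1

    def __post_init__(self):
        if self.ulysses_degree <= 1 and self.ring_degree <= 1:
            raise ValueError("ulysses_degree or ring_degree must be > 1")

    @property
    def sequence_parallel_size(self) -> int:
        return self.ulysses_degree * self.ring_degree


# e.g. _SketchSPConfig(ulysses_degree=2, ring_degree=4).sequence_parallel_size == 8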
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from dataclasses import dataclass
import torch
from vllm_omni.diffusion.lora.layers.base_linear import DiffusionBaseLinearLayerWithLoRA
@dataclass
class _DummyLoRAConfig:
fully_sharded_loras: bool = False
class _DummyQuantMethod:
def __init__(self, weight: torch.Tensor):
self._weight = weight
def apply(self, _base_layer, x: torch.Tensor, bias: torch.Tensor | None):
y = x @ self._weight.t()
if bias is not None:
y = y + bias
return y
def test_diffusion_base_linear_apply_multi_slice():
# Build a fake diffusion LoRA layer with 2 slices and rank=2.
layer = DiffusionBaseLinearLayerWithLoRA.__new__(DiffusionBaseLinearLayerWithLoRA)
layer.tp_size = 1
layer.lora_config = _DummyLoRAConfig()
in_dim = 3
out_slices = (2, 1)
rank = 2
# Base weight: the 3x3 identity, so the base output equals the input.
base_weight = torch.tensor(
[
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
]
)
layer.base_layer = type("Base", (), {})()
layer.base_layer.quant_method = _DummyQuantMethod(base_weight)
# Allocate stacked weights: (max_loras=1, 1, rank, in_dim) and (1, 1, out_slice, rank)
a0 = torch.zeros((1, 1, rank, in_dim))
b0 = torch.zeros((1, 1, out_slices[0], rank))
a1 = torch.zeros((1, 1, rank, in_dim))
b1 = torch.zeros((1, 1, out_slices[1], rank))
# Slice 0: delta0 = (x @ A0.T) @ B0.T
A0 = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]) # (2, 3)
B0 = torch.tensor([[1.0, 0.0], [0.0, 1.0]]) # (2, 2)
a0[0, 0, :, :] = A0
b0[0, 0, :, :] = B0
# Slice 1: delta1 = (x @ A1.T) @ B1.T
A1 = torch.tensor([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]) # (2, 3)
B1 = torch.tensor([[2.0, 0.0]]) # (1, 2)
a1[0, 0, :, :] = A1
b1[0, 0, :, :] = B1
layer.lora_a_stacked = (a0, a1)
layer.lora_b_stacked = (b0, b1)
layer.output_slices = out_slices
x = torch.tensor([[1.0, 2.0, 3.0]])
out = layer.apply(x)
# Base weight is identity, so the base output equals x: [1, 2, 3]
base_out = x @ base_weight.t()
# delta0:
# (x @ A0.T) = [1,2]
# [1,2] @ B0.T = [1,2]
delta0 = torch.tensor([[1.0, 2.0]])
# delta1:
# (x @ A1.T) = [3,1]
# [3,1] @ B1.T = [6]
delta1 = torch.tensor([[6.0]])
expected = torch.cat([base_out[:, :2] + delta0, base_out[:, 2:3] + delta1], dim=-1)
assert torch.allclose(out, expected)
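# Independent recomputation of the expected value asserted above, assuming the
# usual per-slice LoRA update y_i = x @ W_i.T + (x @ A_i.T) @ B_i.T followed by
# concatenation over the output slices; this only re-derives the test arithmetic.
def _expected_multi_slice_output() -> torch.Tensor:
    x = torch.tensor([[1.0, 2.0, 3.0]])
    base = x @ torch.eye(3).t()                                   # [[1, 2, 3]]
    A0 = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
    B0 = torch.eye(2)
    A1 = torch.tensor([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
    B1 = torch.tensor([[2.0, 0.0]])
    delta0 = (x @ A0.t()) @ B0.t()                                # [[1, 2]]
    delta1 = (x @ A1.t()) @ B1.t()                                # [[6]]
    return torch.cat([base[:, :2] + delta0, base[:, 2:] + delta1], dim=-1)  # [[2, 4, 9]]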
def test_diffusion_base_linear_reset_lora_disables_fast_path(monkeypatch):
# Verify that after reset_lora(), apply() skips LoRA matmuls even if the
# LoRA tensors are still allocated and non-empty.
from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
layer = DiffusionBaseLinearLayerWithLoRA.__new__(DiffusionBaseLinearLayerWithLoRA)
layer.tp_size = 1
layer.lora_config = _DummyLoRAConfig()
in_dim = 2
out_dim = 2
rank = 1
base_weight = torch.eye(in_dim)
layer.base_layer = type("Base", (), {})()
layer.base_layer.quant_method = _DummyQuantMethod(base_weight)
a = torch.ones((1, 1, rank, in_dim))
b = torch.tensor([[[[1.0], [2.0]]]]) # (1,1,out_dim,rank)
layer.lora_a_stacked = (a,)
layer.lora_b_stacked = (b,)
layer.output_slices = (out_dim,)
layer._diffusion_lora_active_slices = (True,)
x = torch.tensor([[1.0, 2.0]])
out_active = layer.apply(x)
# base = x = [1, 2]; x @ A.T = [3]; delta = [3] @ B.T = [3, 6]; total = [4, 8]
assert torch.allclose(out_active, torch.tensor([[4.0, 8.0]]))
monkeypatch.setattr(BaseLinearLayerWithLoRA, "reset_lora", lambda self, index: None)
layer.reset_lora(0)
assert layer._diffusion_lora_active_slices == (False,)
out_inactive = layer.apply(x)
assert torch.allclose(out_inactive, x)
def test_diffusion_base_linear_apply_respects_inactive_slices():
# Build a fake diffusion LoRA layer with 2 slices and rank=2.
layer = DiffusionBaseLinearLayerWithLoRA.__new__(DiffusionBaseLinearLayerWithLoRA)
layer.tp_size = 1
layer.lora_config = _DummyLoRAConfig()
in_dim = 3
out_slices = (2, 1)
rank = 2
base_weight = torch.tensor(
[
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
]
)
layer.base_layer = type("Base", (), {})()
layer.base_layer.quant_method = _DummyQuantMethod(base_weight)
a0 = torch.zeros((1, 1, rank, in_dim))
b0 = torch.zeros((1, 1, out_slices[0], rank))
a1 = torch.zeros((1, 1, rank, in_dim))
b1 = torch.zeros((1, 1, out_slices[1], rank))
A0 = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]) # (2, 3)
B0 = torch.tensor([[1.0, 0.0], [0.0, 1.0]]) # (2, 2)
a0[0, 0, :, :] = A0
b0[0, 0, :, :] = B0
A1 = torch.tensor([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]) # (2, 3)
B1 = torch.tensor([[2.0, 0.0]]) # (1, 2)
a1[0, 0, :, :] = A1
b1[0, 0, :, :] = B1
layer.lora_a_stacked = (a0, a1)
layer.lora_b_stacked = (b0, b1)
layer.output_slices = out_slices
layer._diffusion_lora_active_slices = (True, False)
x = torch.tensor([[1.0, 2.0, 3.0]])
out = layer.apply(x)
# Only the first slice should be adapted: base output is [1, 2, 3], slice 0
# gains delta [1, 2] -> [2, 4], and the inactive slice 1 keeps 3.
expected = torch.tensor([[2.0, 4.0, 3.0]])
assert torch.allclose(out, expected)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import torch
from vllm.lora.lora_weights import LoRALayerWeights
from vllm.lora.utils import get_supported_lora_modules
from vllm.model_executor.layers.linear import LinearBase
from vllm_omni.diffusion.lora.manager import DiffusionLoRAManager
from vllm_omni.lora.request import LoRARequest
class _DummyLoRALayer:
def __init__(self, n_slices: int, output_slices: tuple[int, ...]):
self.n_slices = n_slices
self.output_slices = output_slices
self.set_calls: list[
tuple[list[torch.Tensor | None] | torch.Tensor, list[torch.Tensor | None] | torch.Tensor]
] = []
self.reset_calls: int = 0
def set_lora(self, index: int, lora_a, lora_b):
assert index == 0
self.set_calls.append((lora_a, lora_b))
def reset_lora(self, index: int):
assert index == 0
self.reset_calls += 1
class _FakeLinearBase(LinearBase):
def __init__(self):
torch.nn.Module.__init__(self)
def test_lora_manager_supported_modules_are_stable_with_wrapped_layers(monkeypatch):
# Simulate a pipeline that already contains LoRA wrappers where the original
# LinearBase is nested under ".base_layer".
import vllm_omni.diffusion.lora.manager as manager_mod
class _DummyBaseLayerWithLoRA(torch.nn.Module):
def __init__(self, base_layer: torch.nn.Module):
super().__init__()
self.base_layer = base_layer
monkeypatch.setattr(manager_mod, "BaseLayerWithLoRA", _DummyBaseLayerWithLoRA)
pipeline = torch.nn.Module()
pipeline.transformer = torch.nn.Module()
pipeline.transformer.foo = _DummyBaseLayerWithLoRA(_FakeLinearBase())
# vLLM helper would see only the nested LinearBase and yield "base_layer".
assert get_supported_lora_modules(pipeline) == ["base_layer"]
manager = DiffusionLoRAManager(
pipeline=pipeline,
device=torch.device("cpu"),
dtype=torch.bfloat16,
max_cached_adapters=1,
)
assert "foo" in manager._supported_lora_modules
assert "base_layer" not in manager._supported_lora_modules
def test_lora_manager_replace_layers_does_not_rewrap_base_layer(monkeypatch):
import vllm_omni.diffusion.lora.manager as manager_mod
class _DummyBaseLayerWithLoRA(torch.nn.Module):
def __init__(self, base_layer: torch.nn.Module):
super().__init__()
self.base_layer = base_layer
monkeypatch.setattr(manager_mod, "BaseLayerWithLoRA", _DummyBaseLayerWithLoRA)
def _fake_from_layer_diffusion(*, layer: torch.nn.Module, **_kwargs):
if isinstance(layer, _FakeLinearBase):
return _DummyBaseLayerWithLoRA(layer)
return layer
replace_calls: list[str] = []
def _fake_replace_submodule(root: torch.nn.Module, module_name: str, submodule: torch.nn.Module):
replace_calls.append(module_name)
setattr(root, module_name, submodule)
monkeypatch.setattr(manager_mod, "from_layer_diffusion", _fake_from_layer_diffusion)
monkeypatch.setattr(manager_mod, "replace_submodule", _fake_replace_submodule)
pipeline = torch.nn.Module()
pipeline.transformer = torch.nn.Module()
pipeline.transformer.foo = _FakeLinearBase()
manager = DiffusionLoRAManager(
pipeline=pipeline,
device=torch.device("cpu"),
dtype=torch.bfloat16,
max_cached_adapters=1,
)
peft_helper = type("_PH", (), {"r": 1})()
manager._replace_layers_with_lora(peft_helper)
manager._replace_layers_with_lora(peft_helper)
# Only the top-level layer should have been replaced; nested ".base_layer"
# must be skipped to avoid nesting LoRA wrappers.
assert replace_calls == ["foo"]
def test_lora_manager_replaces_packed_layer_when_targeting_sublayers(monkeypatch):
import vllm_omni.diffusion.lora.manager as manager_mod
class _DummyBaseLayerWithLoRA(torch.nn.Module):
def __init__(self, base_layer: torch.nn.Module):
super().__init__()
self.base_layer = base_layer
monkeypatch.setattr(manager_mod, "BaseLayerWithLoRA", _DummyBaseLayerWithLoRA)
def _fake_from_layer_diffusion(*, layer: torch.nn.Module, **_kwargs):
return _DummyBaseLayerWithLoRA(layer)
replace_calls: list[str] = []
def _fake_replace_submodule(root: torch.nn.Module, module_name: str, submodule: torch.nn.Module):
replace_calls.append(module_name)
setattr(root, module_name, submodule)
monkeypatch.setattr(manager_mod, "from_layer_diffusion", _fake_from_layer_diffusion)
monkeypatch.setattr(manager_mod, "replace_submodule", _fake_replace_submodule)
pipeline = torch.nn.Module()
pipeline.packed_modules_mapping = {"to_qkv": ["to_q", "to_k", "to_v"]}
pipeline.transformer = torch.nn.Module()
pipeline.transformer.to_qkv = _FakeLinearBase()
manager = DiffusionLoRAManager(
pipeline=pipeline,
device=torch.device("cpu"),
dtype=torch.bfloat16,
max_cached_adapters=1,
)
# Treat the dummy layer as a packed 3-slice projection so the manager uses
# `packed_modules_mapping` to decide replacement based on target_modules.
monkeypatch.setattr(manager, "_get_packed_modules_list", lambda _module: ["q", "k", "v"])
peft_helper = type("_PH", (), {"r": 1, "target_modules": ["to_q"]})()
manager._replace_layers_with_lora(peft_helper)
assert replace_calls == ["to_qkv"]
def test_lora_manager_activates_fused_lora_on_packed_layer():
manager = DiffusionLoRAManager(
pipeline=torch.nn.Module(),
device=torch.device("cpu"),
dtype=torch.bfloat16,
max_cached_adapters=1,
)
packed_layer = _DummyLoRALayer(n_slices=3, output_slices=(2, 1, 1))
manager._lora_modules = {"transformer.blocks.0.attn.to_qkv": packed_layer}
rank = 2
A = torch.ones((rank, 4))
B = torch.arange(0, sum(packed_layer.output_slices) * rank, dtype=torch.bfloat16).view(-1, rank)
lora = LoRALayerWeights(
module_name="transformer.blocks.0.attn.to_qkv",
rank=rank,
lora_alpha=rank,
lora_a=A,
lora_b=B,
)
manager._registered_adapters = {
7: type(
"LM",
(),
{
"id": 7,
"loras": {"transformer.blocks.0.attn.to_qkv": lora},
"get_lora": lambda self, k: self.loras.get(k),
},
)()
}
manager._adapter_scales = {7: 0.5}
manager._activate_adapter(7)
assert packed_layer.reset_calls == 0
assert len(packed_layer.set_calls) == 1
lora_a_list, lora_b_list = packed_layer.set_calls[0]
assert isinstance(lora_a_list, list)
assert isinstance(lora_b_list, list)
assert len(lora_a_list) == 3
assert len(lora_b_list) == 3
assert all(torch.allclose(a, A) for a in lora_a_list)
# B should be split into 3 slices and scaled.
b0, b1, b2 = lora_b_list
assert b0.shape[0] == 2 and b1.shape[0] == 1 and b2.shape[0] == 1
assert torch.allclose(torch.cat([b0, b1, b2], dim=0), B * 0.5)
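# The assertions above imply that a fused LoRA B matrix is split along dim 0
# into per-slice blocks sized by output_slices and scaled before set_lora().
# A minimal sketch of that step (assumed behaviour, not the DiffusionLoRAManager
# code); `_split_and_scale_lora_b` is hypothetical.
def _split_and_scale_lora_b(lora_b: torch.Tensor, output_slices: tuple[int, ...], scale: float) -> list[torch.Tensor]:
    return [chunk * scale for chunk in torch.split(lora_b, list(output_slices), dim=0)]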
def test_lora_manager_activates_packed_lora_from_sublayers():
pipeline = torch.nn.Module()
pipeline.packed_modules_mapping = {"to_qkv": ["to_q", "to_k", "to_v"]}
manager = DiffusionLoRAManager(
pipeline=pipeline,
device=torch.device("cpu"),
dtype=torch.bfloat16,
max_cached_adapters=1,
)
packed_layer = _DummyLoRALayer(n_slices=3, output_slices=(2, 1, 1))
manager._lora_modules = {"transformer.blocks.0.attn.to_qkv": packed_layer}
rank = 2
loras: dict[str, LoRALayerWeights] = {}
for name, out_dim in zip(["to_q", "to_k", "to_v"], [2, 1, 1]):
loras[f"transformer.blocks.0.attn.{name}"] = LoRALayerWeights(
module_name=f"transformer.blocks.0.attn.{name}",
rank=rank,
lora_alpha=rank,
lora_a=torch.ones((rank, 4)) * (1 if name == "to_q" else 2),
lora_b=torch.ones((out_dim, rank)) * (3 if name == "to_q" else 4),
)
manager._registered_adapters = {
1: type("LM", (), {"id": 1, "loras": loras, "get_lora": lambda self, k: self.loras.get(k)})()
}
manager._adapter_scales = {1: 2.0}
manager._activate_adapter(1)
assert packed_layer.reset_calls == 0
assert len(packed_layer.set_calls) == 1
lora_a_list, lora_b_list = packed_layer.set_calls[0]
assert isinstance(lora_a_list, list)
assert isinstance(lora_b_list, list)
assert len(lora_a_list) == 3
assert len(lora_b_list) == 3
# Scale should apply to B only.
assert torch.allclose(lora_b_list[0], torch.ones((2, rank)) * 3 * 2.0)
assert torch.allclose(lora_b_list[1], torch.ones((1, rank)) * 4 * 2.0)
assert torch.allclose(lora_b_list[2], torch.ones((1, rank)) * 4 * 2.0)
def _dummy_lora_request(adapter_id: int) -> LoRARequest:
return LoRARequest(
lora_name=f"adapter_{adapter_id}",
lora_int_id=adapter_id,
lora_path=f"/tmp/adapter_{adapter_id}",
)
def test_lora_manager_evicts_lru_adapter_when_cache_full(monkeypatch):
manager = DiffusionLoRAManager(
pipeline=torch.nn.Module(),
device=torch.device("cpu"),
dtype=torch.bfloat16,
max_cached_adapters=2,
)
def _fake_load(_req: LoRARequest):
lora_model = type("LM", (), {"id": _req.lora_int_id})()
peft_helper = type("PH", (), {})()
return lora_model, peft_helper
monkeypatch.setattr(manager, "_load_adapter", _fake_load)
monkeypatch.setattr(manager, "_replace_layers_with_lora", lambda _peft: None)
monkeypatch.setattr(manager, "_activate_adapter", lambda _adapter_id: None)
req1 = _dummy_lora_request(1)
req2 = _dummy_lora_request(2)
req3 = _dummy_lora_request(3)
manager.set_active_adapter(req1, lora_scale=1.0)
manager.set_active_adapter(req2, lora_scale=1.0)
# Touch adapter 1 so adapter 2 becomes LRU.
manager.set_active_adapter(req1, lora_scale=1.0)
manager.set_active_adapter(req3, lora_scale=1.0)
assert set(manager.list_adapters()) == {1, 3}
def test_lora_manager_does_not_evict_pinned_adapter(monkeypatch):
manager = DiffusionLoRAManager(
pipeline=torch.nn.Module(),
device=torch.device("cpu"),
dtype=torch.bfloat16,
max_cached_adapters=2,
)
def _fake_load(_req: LoRARequest):
lora_model = type("LM", (), {"id": _req.lora_int_id})()
peft_helper = type("PH", (), {})()
return lora_model, peft_helper
monkeypatch.setattr(manager, "_load_adapter", _fake_load)
monkeypatch.setattr(manager, "_replace_layers_with_lora", lambda _peft: None)
monkeypatch.setattr(manager, "_activate_adapter", lambda _adapter_id: None)
manager.set_active_adapter(_dummy_lora_request(1), lora_scale=1.0)
assert manager.pin_adapter(1)
manager.set_active_adapter(_dummy_lora_request(2), lora_scale=1.0)
manager.set_active_adapter(_dummy_lora_request(3), lora_scale=1.0)
assert set(manager.list_adapters()) == {1, 3}
def test_lora_manager_warns_when_all_adapters_pinned(monkeypatch):
manager = DiffusionLoRAManager(
pipeline=torch.nn.Module(),
device=torch.device("cpu"),
dtype=torch.bfloat16,
max_cached_adapters=2,
)
def _fake_load(_req: LoRARequest):
lora_model = type("LM", (), {"id": _req.lora_int_id})()
peft_helper = type("PH", (), {})()
return lora_model, peft_helper
monkeypatch.setattr(manager, "_load_adapter", _fake_load)
monkeypatch.setattr(manager, "_replace_layers_with_lora", lambda _peft: None)
monkeypatch.setattr(manager, "_activate_adapter", lambda _adapter_id: None)
manager.set_active_adapter(_dummy_lora_request(1), lora_scale=1.0)
manager.set_active_adapter(_dummy_lora_request(2), lora_scale=1.0)
assert manager.pin_adapter(1)
assert manager.pin_adapter(2)
manager.max_cached_adapters = 1
manager._evict_if_needed()
assert set(manager.list_adapters()) == {1, 2}
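# The three eviction tests above describe an LRU cache that skips pinned
# adapters and leaves everything in place (with a warning) when all entries
# are pinned. A minimal, self-contained sketch of that policy (assumed
# semantics, not the DiffusionLoRAManager implementation):
from collections import OrderedDict


def _evict_lru_unpinned(cache: OrderedDict, pinned: set, capacity: int) -> None:
    """Drop least-recently-used unpinned entries until the cache fits."""
    while len(cache) > capacity:
        victim = next((k for k in cache if k not in pinned), None)
        if victim is None:
            break  # everything is pinned; the real manager logs a warning here
        cache.pop(victim)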
import pytest
from vllm_omni.diffusion.models.z_image.z_image_transformer import validate_zimage_tp_constraints
def test_validate_zimage_tp_constraints_tp2_ok():
ffn_hidden_dim, final_out_dims, supported_tp = validate_zimage_tp_constraints(
dim=3840,
n_heads=30,
n_kv_heads=30,
in_channels=16,
all_patch_size=(2,),
all_f_patch_size=(1,),
tensor_parallel_size=2,
)
assert ffn_hidden_dim == 10240
assert final_out_dims == [64]
assert supported_tp == [1, 2]
def test_validate_zimage_tp_constraints_tp4_fails_on_heads():
with pytest.raises(ValueError, match=r"n_heads % tensor_parallel_size"):
validate_zimage_tp_constraints(
dim=3840,
n_heads=30,
n_kv_heads=30,
in_channels=16,
all_patch_size=(2,),
all_f_patch_size=(1,),
tensor_parallel_size=4,
)
def test_validate_zimage_tp_constraints_tp3_fails_on_ffn_hidden_dim():
with pytest.raises(ValueError, match=r"ffn_hidden_dim % tensor_parallel_size"):
validate_zimage_tp_constraints(
dim=3840,
n_heads=30,
n_kv_heads=30,
in_channels=16,
all_patch_size=(2,),
all_f_patch_size=(1,),
tensor_parallel_size=3,
)
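# Both failure cases above come down to divisibility requirements: n_heads and
# ffn_hidden_dim must be divisible by tensor_parallel_size. A hedged sketch of
# those checks (the real validate_zimage_tp_constraints also derives
# ffn_hidden_dim, final_out_dims and the supported TP list, omitted here):
def _check_tp_divisibility(n_heads: int, ffn_hidden_dim: int, tensor_parallel_size: int) -> None:
    if n_heads % tensor_parallel_size != 0:
        raise ValueError("n_heads % tensor_parallel_size must be 0")
    if ffn_hidden_dim % tensor_parallel_size != 0:
        raise ValueError("ffn_hidden_dim % tensor_parallel_size must be 0")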
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for DiffusionWorker class.
This module tests the DiffusionWorker implementation:
- load_weights: Loading model weights
- sleep: Putting worker into sleep mode (levels 1 and 2)
- wake_up: Waking worker from sleep mode
"""
from unittest.mock import Mock, patch
import pytest
import torch
from vllm_omni.diffusion.worker.diffusion_worker import DiffusionWorker
@pytest.fixture
def mock_od_config():
"""Create a mock OmniDiffusionConfig."""
config = Mock()
config.num_gpus = 1
config.master_port = 12345
config.enable_sleep_mode = False
config.cache_backend = None
config.cache_config = None
config.model = "test-model"
return config
@pytest.fixture
def mock_gpu_worker(mock_od_config):
"""Create a DiffusionWorker with mocked initialization."""
with patch.object(DiffusionWorker, "init_device"):
worker = DiffusionWorker(local_rank=0, rank=0, od_config=mock_od_config)
# Mock the model_runner with pipeline
worker.model_runner = Mock()
worker.model_runner.pipeline = Mock()
worker.device = torch.device("cuda", 0)
worker._sleep_saved_buffers = {}
return worker
class TestDiffusionWorkerLoadWeights:
"""Test DiffusionWorker.load_weights method."""
def test_load_weights_calls_pipeline(self, mock_gpu_worker):
"""Test that load_weights delegates to model_runner.load_weights."""
# Setup mock weights
mock_weights = [
("layer1.weight", torch.randn(10, 10)),
("layer2.weight", torch.randn(20, 20)),
]
expected_loaded = {"layer1.weight", "layer2.weight"}
# Configure model_runner mock
mock_gpu_worker.model_runner.load_weights = Mock(return_value=expected_loaded)
# Call load_weights
result = mock_gpu_worker.load_weights(mock_weights)
# Verify model_runner.load_weights was called with the weights
mock_gpu_worker.model_runner.load_weights.assert_called_once_with(mock_weights)
assert result == expected_loaded
def test_load_weights_empty_iterable(self, mock_gpu_worker):
"""Test load_weights with empty weights iterable."""
mock_gpu_worker.model_runner.load_weights = Mock(return_value=set())
result = mock_gpu_worker.load_weights([])
mock_gpu_worker.model_runner.load_weights.assert_called_once_with([])
assert result == set()
class TestDiffusionWorkerSleep:
"""Test DiffusionWorker.sleep method."""
@patch("vllm_omni.diffusion.worker.diffusion_worker.current_omni_platform")
@patch("vllm.device_allocator.cumem.CuMemAllocator")
def test_sleep_level_1(self, mock_allocator_class, mock_platform, mock_gpu_worker):
"""Test sleep mode level 1 (offload weights only)."""
# Setup memory info mocks
# Before sleep: 1GB free
# After sleep: 3GB free (freed 2GB)
mock_platform.get_free_memory.side_effect = [
1 * 1024**3, # Before sleep
3 * 1024**3, # After sleep
]
mock_platform.get_device_total_memory.return_value = 8 * 1024**3
# Setup allocator mock
mock_allocator = Mock()
mock_allocator_class.get_instance = Mock(return_value=mock_allocator)
mock_allocator.sleep = Mock()
# Call sleep with level 1
result = mock_gpu_worker.sleep(level=1)
# Verify sleep was called with correct tags
mock_allocator.sleep.assert_called_once_with(offload_tags=("weights",))
assert result is True
# Verify buffers were NOT saved (level 1 doesn't save buffers)
assert len(mock_gpu_worker._sleep_saved_buffers) == 0
@patch("vllm_omni.diffusion.worker.diffusion_worker.current_omni_platform")
@patch("vllm.device_allocator.cumem.CuMemAllocator")
def test_sleep_level_2(self, mock_allocator_class, mock_platform, mock_gpu_worker):
"""Test sleep mode level 2 (offload all, save buffers)."""
# Setup memory info mocks
mock_platform.get_free_memory.side_effect = [
1 * 1024**3, # Before sleep
5 * 1024**3, # After sleep (freed 4GB)
]
mock_platform.get_device_total_memory.return_value = 8 * 1024**3
# Setup allocator mock
mock_allocator = Mock()
mock_allocator_class.get_instance = Mock(return_value=mock_allocator)
mock_allocator.sleep = Mock()
# Mock pipeline buffers
mock_buffer1 = torch.randn(10, 10)
mock_buffer2 = torch.randn(20, 20)
mock_gpu_worker.model_runner.pipeline.named_buffers = Mock(
return_value=[
("buffer1", mock_buffer1),
("buffer2", mock_buffer2),
]
)
# Call sleep with level 2
result = mock_gpu_worker.sleep(level=2)
# Verify sleep was called with empty tags (offload all)
mock_allocator.sleep.assert_called_once_with(offload_tags=tuple())
assert result is True
# Verify buffers were saved
assert len(mock_gpu_worker._sleep_saved_buffers) == 2
assert "buffer1" in mock_gpu_worker._sleep_saved_buffers
assert "buffer2" in mock_gpu_worker._sleep_saved_buffers
@patch("vllm_omni.diffusion.worker.diffusion_worker.current_omni_platform")
@patch("vllm.device_allocator.cumem.CuMemAllocator")
def test_sleep_memory_freed_validation(self, mock_allocator_class, mock_platform, mock_gpu_worker):
"""Test that sleep validates memory was actually freed."""
# Simulate memory increase (should trigger assertion error)
mock_platform.get_free_memory.side_effect = [
3 * 1024**3, # Before sleep: 3GB free
1 * 1024**3, # After sleep: 1GB free (negative freed!)
]
mock_platform.get_device_total_memory.return_value = 8 * 1024**3
mock_allocator = Mock()
mock_allocator_class.get_instance = Mock(return_value=mock_allocator)
mock_allocator.sleep = Mock()
# This should raise an assertion error
with pytest.raises(AssertionError, match="Memory usage increased after sleeping"):
mock_gpu_worker.sleep(level=1)
class TestDiffusionWorkerWakeUp:
"""Test DiffusionWorker.wake_up method."""
@patch("vllm.device_allocator.cumem.CuMemAllocator")
def test_wake_up_without_buffers(self, mock_allocator_class, mock_gpu_worker):
"""Test wake_up without saved buffers (level 1 sleep)."""
# Setup allocator mock
mock_allocator = Mock()
mock_allocator_class.get_instance = Mock(return_value=mock_allocator)
mock_allocator.wake_up = Mock()
# Ensure no saved buffers
mock_gpu_worker._sleep_saved_buffers = {}
# Call wake_up
result = mock_gpu_worker.wake_up(tags=["weights"])
# Verify allocator.wake_up was called
mock_allocator.wake_up.assert_called_once_with(["weights"])
assert result is True
@patch("vllm.device_allocator.cumem.CuMemAllocator")
def test_wake_up_with_buffers(self, mock_allocator_class, mock_gpu_worker):
"""Test wake_up with saved buffers (level 2 sleep)."""
# Setup allocator mock
mock_allocator = Mock()
mock_allocator_class.get_instance = Mock(return_value=mock_allocator)
mock_allocator.wake_up = Mock()
# Create saved buffers
saved_buffer1 = torch.randn(10, 10)
saved_buffer2 = torch.randn(20, 20)
mock_gpu_worker._sleep_saved_buffers = {
"buffer1": saved_buffer1,
"buffer2": saved_buffer2,
}
# Mock pipeline buffers (these will be restored)
mock_buffer1 = Mock()
mock_buffer1.data = Mock()
mock_buffer2 = Mock()
mock_buffer2.data = Mock()
mock_gpu_worker.model_runner.pipeline.named_buffers = Mock(
return_value=[
("buffer1", mock_buffer1),
("buffer2", mock_buffer2),
]
)
# Call wake_up
result = mock_gpu_worker.wake_up(tags=None)
# Verify allocator.wake_up was called
mock_allocator.wake_up.assert_called_once_with(None)
# Verify buffers were restored
mock_buffer1.data.copy_.assert_called_once()
mock_buffer2.data.copy_.assert_called_once()
# Verify saved buffers were cleared
assert len(mock_gpu_worker._sleep_saved_buffers) == 0
assert result is True
@patch("vllm.device_allocator.cumem.CuMemAllocator")
def test_wake_up_partial_buffer_restore(self, mock_allocator_class, mock_gpu_worker):
"""Test wake_up only restores buffers that were saved."""
# Setup allocator mock
mock_allocator = Mock()
mock_allocator_class.get_instance = Mock(return_value=mock_allocator)
mock_allocator.wake_up = Mock()
# Only save buffer1, not buffer2
saved_buffer1 = torch.randn(10, 10)
mock_gpu_worker._sleep_saved_buffers = {
"buffer1": saved_buffer1,
}
# Mock pipeline has both buffers
mock_buffer1 = Mock()
mock_buffer1.data = Mock()
mock_buffer2 = Mock()
mock_buffer2.data = Mock()
mock_gpu_worker.model_runner.pipeline.named_buffers = Mock(
return_value=[
("buffer1", mock_buffer1),
("buffer2", mock_buffer2),
]
)
# Call wake_up
result = mock_gpu_worker.wake_up()
# Verify only buffer1 was restored
mock_buffer1.data.copy_.assert_called_once()
# buffer2 should NOT be restored since it wasn't saved
mock_buffer2.data.copy_.assert_not_called()
assert result is True
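# Taken together, the sleep/wake tests above describe the contract: level 1
# offloads only weights, while level 2 offloads everything and snapshots the
# pipeline buffers so wake_up() can copy them back and then clear the snapshot.
# A minimal sketch of that buffer handling (assumed behaviour, not the
# DiffusionWorker code); both helpers are hypothetical.
def _save_buffers(pipeline) -> dict:
    return {name: buf.detach().clone() for name, buf in pipeline.named_buffers()}


def _restore_buffers(pipeline, saved: dict) -> None:
    for name, buf in pipeline.named_buffers():
        if name in saved:
            buf.data.copy_(saved[name])
    saved.clear()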
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock
import pytest
from vllm_omni.distributed.omni_connectors.adapter import try_recv_via_connector, try_send_via_connector
from vllm_omni.distributed.omni_connectors.connectors.shm_connector import SharedMemoryConnector
from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec, OmniTransferConfig
from vllm_omni.distributed.omni_connectors.utils.initialization import get_connectors_config_for_stage
@pytest.fixture
def mock_objects():
return {"connector": MagicMock(), "metrics": MagicMock(), "queue_fn": MagicMock()}
def test_send_success(mock_objects):
"""Test try_send_via_connector success path."""
# Setup
mock_connector = mock_objects["connector"]
mock_metrics = mock_objects["metrics"]
mock_queue_fn = mock_objects["queue_fn"]
stage_id = 0
next_stage_id = 1
req_id = "req_123"
inputs = {"input_ids": [1, 2, 3]}
sampling_params = {"temperature": 0.7}
prompt = "test prompt"
# Mock connector.put return
# Returns: (success, size, metadata)
mock_metadata = {"handle": "xyz"}
mock_connector.put.return_value = (True, 100, mock_metadata)
# Execute
result = try_send_via_connector(
connector=mock_connector,
stage_id=stage_id,
next_stage_id=next_stage_id,
req_id=req_id,
next_inputs=inputs,
sampling_params=sampling_params,
original_prompt=prompt,
next_stage_queue_submit_fn=mock_queue_fn,
metrics=mock_metrics,
)
# Verify
assert result is True
# 1. Verify connector.put called correctly
mock_connector.put.assert_called_once()
args, _ = mock_connector.put.call_args
assert args[0] == "0" # from_stage
assert args[1] == "1" # to_stage
assert args[2] == req_id
# Verify payload structure in put
payload = args[3]
assert payload["engine_inputs"] == inputs
assert payload["sampling_params"] == sampling_params
# 2. Verify queue notification submitted
mock_queue_fn.assert_called_once()
notify_payload = mock_queue_fn.call_args[0][0]
assert notify_payload["request_id"] == req_id
assert notify_payload["from_connector"] is True
assert notify_payload["connector_metadata"] == mock_metadata
# 3. Verify metrics recorded
mock_metrics.on_forward.assert_called_once()
def test_send_fail(mock_objects):
"""Test try_send_via_connector when connector fails."""
mock_connector = mock_objects["connector"]
mock_metrics = mock_objects["metrics"]
mock_queue_fn = mock_objects["queue_fn"]
mock_connector.put.return_value = (False, 0, None)
result = try_send_via_connector(
connector=mock_connector,
stage_id=0,
next_stage_id=1,
req_id="req_fail",
next_inputs={},
sampling_params={},
original_prompt="",
next_stage_queue_submit_fn=mock_queue_fn,
metrics=mock_metrics,
)
assert result is False
mock_queue_fn.assert_not_called()
def test_recv_success(mock_objects):
"""Test try_recv_via_connector success path."""
mock_connector = mock_objects["connector"]
# Setup task received from queue
task = {
"request_id": "req_recv",
"from_connector": True,
"from_stage": "0",
"connector_metadata": {"handle": "xyz"},
}
# Setup connectors dict
connectors = {("0", "1"): mock_connector}
# Mock connector.get return
expected_data = {"engine_inputs": {"ids": [1]}}
# get returns: (data_obj, size)
mock_connector.get.return_value = (expected_data, 50)
# serialize_obj needed for metrics calculation if size not returned directly
mock_connector.serialize_obj.return_value = b"bytes"
# Execute
# We are stage 1 receiving from stage 0
inputs, rx_metrics = try_recv_via_connector(task, connectors, stage_id=1)
# Verify
assert inputs == expected_data["engine_inputs"]
assert rx_metrics is not None
mock_connector.get.assert_called_once_with("0", "1", "req_recv", metadata={"handle": "xyz"})
def test_recv_no_connector():
"""Test recv fails when no connector exists for edge."""
task = {"request_id": "req_missing", "from_connector": True, "from_stage": "0"}
connectors = {} # Empty connectors
inputs, _ = try_recv_via_connector(task, connectors, stage_id=1)
assert inputs is None
def test_shm_connector_flow():
"""
Verify the full flow: Send -> Adapter -> Connector -> Adapter -> Recv.
Using real SharedMemoryConnector (inline mode for simplicity).
"""
# 1. Setup Connector
config = {"shm_threshold_bytes": 1024} # Large threshold to use inline
connector = SharedMemoryConnector(config)
connectors_map = {("0", "1"): connector}
# 2. Setup Data
stage_id = 0
next_stage_id = 1
req_id = "flow_req"
inputs = {"tokens": [10, 20, 30]}
sampling_params = {"n": 1}
# Queue capture mechanism
queue_capture = []
def mock_submit(payload):
queue_capture.append(payload)
mock_metrics = MagicMock()
# 3. Send
success = try_send_via_connector(
connector=connector,
stage_id=stage_id,
next_stage_id=next_stage_id,
req_id=req_id,
next_inputs=inputs,
sampling_params=sampling_params,
original_prompt="prompt",
next_stage_queue_submit_fn=mock_submit,
metrics=mock_metrics,
)
assert success is True
assert len(queue_capture) == 1
# 4. Recv
# The 'task' is what would be popped from the queue
received_task = queue_capture[0]
# Verify queue payload contains what we expect
assert received_task["from_connector"] is True
assert received_task["from_stage"] == "0"
# Decode
decoded_inputs, _ = try_recv_via_connector(received_task, connectors_map, stage_id=1)
# 5. Verify Data Integrity
assert decoded_inputs == inputs
def test_get_connectors_for_stage():
"""Test filtering logic for stage config."""
# Config has edges: 0->1, 1->2
config = OmniTransferConfig(connectors={("0", "1"): ConnectorSpec(name="C1"), ("1", "2"): ConnectorSpec(name="C2")})
# Get config for Stage 1
# Stage 1 receives from 0 (input) and sends to 2 (output)
# get_connectors_config_for_stage ONLY returns INPUT connectors for the worker to initialize
stage_config = get_connectors_config_for_stage(config, stage_id=1)
# Should contain "from_stage_0"
assert "from_stage_0" in stage_config
assert stage_config["from_stage_0"]["spec"]["name"] == "C1"
# Should NOT contain "from_stage_1" or related to output
assert "from_stage_1" not in stage_config
# Verify Stage 2
stage_2_config = get_connectors_config_for_stage(config, stage_id=2)
assert "from_stage_1" in stage_2_config
assert stage_2_config["from_stage_1"]["spec"]["name"] == "C2"
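# The test above pins down that get_connectors_config_for_stage returns only
# the *input* edges of a stage, keyed as "from_stage_<sender>". A minimal
# sketch of that filtering under those assumptions (not the actual helper);
# `_input_connectors_for_stage` is hypothetical.
def _input_connectors_for_stage(connectors: dict, stage_id: int) -> dict:
    return {
        f"from_stage_{src}": spec
        for (src, dst), spec in connectors.items()
        if int(dst) == stage_id
    }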
def test_recv_with_missing_metadata():
"""Test recv when queue payload is malformed (missing metadata)."""
# Connector expects metadata but task doesn't have it
task = {
"request_id": "req_bad",
"from_connector": True,
"from_stage": "0",
# Missing "connector_metadata"
}
mock_conn = MagicMock()
# get() is called without usable metadata here; the adapter is expected to
# catch the resulting exception and return None.
mock_conn.get.side_effect = Exception("Get failed")
connectors = {("0", "1"): mock_conn}
inputs, _ = try_recv_via_connector(task, connectors, stage_id=1)
assert inputs is None
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock
import pytest
from vllm_omni.distributed.omni_connectors.connectors.shm_connector import SharedMemoryConnector
from vllm_omni.distributed.omni_connectors.factory import OmniConnectorFactory
from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec
from vllm_omni.distributed.omni_connectors.utils.serialization import OmniSerializer
def test_basic_serialization():
"""Test basic msgpack serialization."""
data = {"key": "value", "list": [1, 2, 3]}
serialized = OmniSerializer.serialize(data)
assert isinstance(serialized, bytes)
deserialized = OmniSerializer.deserialize(serialized)
assert data == deserialized
def test_tensor_serialization():
"""Test torch.Tensor serialization."""
import torch
tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
serialized = OmniSerializer.serialize(tensor)
deserialized = OmniSerializer.deserialize(serialized)
assert torch.equal(tensor, deserialized)
def test_ndarray_serialization():
"""Test numpy.ndarray serialization."""
import numpy as np
arr = np.array([[1, 2, 3], [4, 5, 6]])
serialized = OmniSerializer.serialize(arr)
deserialized = OmniSerializer.deserialize(serialized)
assert np.array_equal(arr, deserialized)
def test_create_shm_connector():
"""Test creating SharedMemoryConnector via Factory."""
spec = ConnectorSpec(name="SharedMemoryConnector", extra={"shm_threshold_bytes": 1024})
connector = OmniConnectorFactory.create_connector(spec)
assert isinstance(connector, SharedMemoryConnector)
assert connector.threshold == 1024
def test_create_unknown_connector():
"""Test error when creating unknown connector."""
spec = ConnectorSpec(name="UnknownConnector")
with pytest.raises(ValueError):
OmniConnectorFactory.create_connector(spec)
@pytest.fixture
def shm_connector():
config = {"shm_threshold_bytes": 100} # Small threshold for testing
return SharedMemoryConnector(config)
def test_put_get_inline(shm_connector):
"""Test inline transfer for small data."""
data = {"small": "data"}
# Ensure data is smaller than threshold (100 bytes)
success, size, metadata = shm_connector.put("stage_0", "stage_1", "req_1", data)
assert success is True
assert "inline_bytes" in metadata
assert "shm" not in metadata
# Retrieve
retrieved_data, ret_size = shm_connector.get("stage_0", "stage_1", "req_1", metadata)
assert data == retrieved_data
assert size == ret_size
def test_put_get_shm(shm_connector, monkeypatch):
"""Test SHM transfer logic for large data (Mocked)."""
# Create data larger than 100 bytes
data = {"large": "x" * 200}
# Mock SHM return values
mock_handle = {"name": "test_shm", "size": 200}
mock_write = MagicMock(return_value=mock_handle)
monkeypatch.setattr("vllm_omni.distributed.omni_connectors.connectors.shm_connector.shm_write_bytes", mock_write)
# When reading, return the serialized bytes of the data
serialized_data = shm_connector.serialize_obj(data)
mock_read = MagicMock(return_value=serialized_data)
monkeypatch.setattr("vllm_omni.distributed.omni_connectors.connectors.shm_connector.shm_read_bytes", mock_read)
# Put
success, size, metadata = shm_connector.put("stage_0", "stage_1", "req_2", data)
assert success is True
# Should use SHM because data > threshold
assert "shm" in metadata
assert metadata["shm"] == mock_handle
assert "inline_bytes" not in metadata
mock_write.assert_called_once()
# Get
retrieved_data, ret_size = shm_connector.get("stage_0", "stage_1", "req_2", metadata)
assert data == retrieved_data
mock_read.assert_called_once_with(mock_handle)
def test_get_invalid_metadata(shm_connector):
"""Test get with invalid metadata."""
result = shm_connector.get("stage_0", "stage_1", "req_3", {})
assert result is None
result = shm_connector.get("stage_0", "stage_1", "req_3", {"unknown": "format"})
assert result is None
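# The inline/SHM tests above reduce to a size check at put() time: payloads
# whose serialized form fits under the threshold travel inline inside the
# metadata, while larger ones are written to shared memory and only a handle
# is passed along. A hedged sketch of that decision (the metadata keys mirror
# the assertions above, but this is not the SharedMemoryConnector code):
def _choose_transport(serialized: bytes, threshold: int, write_shm) -> dict:
    if len(serialized) <= threshold:
        return {"inline_bytes": serialized}
    return {"shm": write_shm(serialized)}  # write_shm returns a handle, e.g. {"name": ..., "size": ...}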
import pytest
import torch
from tests.utils import hardware_test
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.distributed.omni_connectors.kv_transfer_manager import (
OmniKVCacheConfig,
OmniKVTransferManager,
)
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
class MockConnector:
def __init__(self):
self.store = {}
def put(self, from_stage, to_stage, put_key, data):
# The manager now passes full key as put_key
key = f"{from_stage}->{to_stage}:{put_key}"
self.store[key] = data
return True, len(str(data)), None # (success, size, metadata)
def get(self, from_stage, to_stage, get_key, metadata=None):
# The manager now passes full key as get_key
key = f"{from_stage}->{to_stage}:{get_key}"
if key in self.store:
return self.store[key], len(str(self.store[key]))
return None
@pytest.fixture
def mock_connector():
return MockConnector()
@pytest.fixture
def kv_config():
return OmniKVCacheConfig(
connector_config={"type": "mock"},
from_stage="stage1",
to_stage="stage2",
stage_id="stage2", # Acting as receiver for some tests
need_recv_cache=True,
need_send_cache=True,
recv_timeout=1.0, # Short timeout for tests
)
@pytest.fixture
def common_constants():
return {
"num_layers": 2,
"num_heads": 4,
"head_dim": 16,
"block_size": 8,
"seq_len": 20,
"req_id": "req_test_1",
}
@pytest.mark.cache
@hardware_test(
res={"cuda": "L4"},
num_cards=2,
)
def test_manager_extraction(kv_config, mock_connector, common_constants):
"""Test extraction and sending logic in OmniKVTransferManager."""
num_layers = common_constants["num_layers"]
block_size = common_constants["block_size"]
num_heads = common_constants["num_heads"]
head_dim = common_constants["head_dim"]
seq_len = common_constants["seq_len"]
req_id = common_constants["req_id"]
num_blocks = 10
kv_caches = []
for _ in range(num_layers):
k_cache = torch.randn(num_blocks, block_size, num_heads, head_dim)
v_cache = torch.randn(num_blocks, block_size, num_heads, head_dim)
# Stack K and V to create [2, num_blocks, block_size, n_heads, head_dim]
layer_cache = torch.stack([k_cache, v_cache], dim=0)
kv_caches.append(layer_cache)
block_ids = [1, 3, 5]
finished_reqs = {req_id: {"block_ids": block_ids, "seq_len": seq_len}}
manager = OmniKVTransferManager(kv_config)
# Mock the connector factory or injection
manager._connector = mock_connector
processed = manager.handle_finished_requests_kv_transfer(finished_reqs, kv_caches, block_size, "float32")
assert req_id in processed
# Check if data was put into connector
# Manager builds full key: omni_{from}_to_{to}_kv_cache_{req_id}
full_request_id = f"omni_stage1_to_stage2_kv_cache_{req_id}"
expected_key = f"stage1->stage2:{full_request_id}"
assert expected_key in mock_connector.store
data = mock_connector.store[expected_key]
assert data["request_id"] == req_id
assert "layer_blocks" in data
assert len(data["layer_blocks"]["key_cache"]) == num_layers
# Verify shape of extracted tensor: [seq_len, heads, dim]
# Note: Manager detaches and moves to CPU
expected_shape = (seq_len, num_heads, head_dim)
assert data["layer_blocks"]["key_cache"][0].shape == expected_shape
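# The shape assertion above implies the manager gathers the allocated blocks
# from the paged cache and flattens them into a contiguous [seq_len, heads,
# dim] tensor. A minimal sketch of that gather under those assumptions (not
# the OmniKVTransferManager code); `_extract_key_cache` is hypothetical.
def _extract_key_cache(layer_cache: torch.Tensor, block_ids: list[int], seq_len: int) -> torch.Tensor:
    # layer_cache: [2, num_blocks, block_size, num_heads, head_dim]; index 0 holds the keys.
    key_blocks = layer_cache[0, block_ids]                 # [len(block_ids), block_size, heads, dim]
    flat = key_blocks.reshape(-1, *key_blocks.shape[2:])   # [len(block_ids) * block_size, heads, dim]
    return flat[:seq_len].detach().cpu()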
@pytest.mark.cache
@hardware_test(
res={"cuda": "L4"},
num_cards=2,
)
def test_manager_reception(kv_config, mock_connector, common_constants):
"""Test reception and injection logic in OmniKVTransferManager."""
num_layers = common_constants["num_layers"]
block_size = common_constants["block_size"]
num_heads = common_constants["num_heads"]
head_dim = common_constants["head_dim"]
seq_len = common_constants["seq_len"]
req_id = common_constants["req_id"]
expected_shape = (seq_len, num_heads, head_dim)
key_cache = [torch.randn(expected_shape) for _ in range(num_layers)]
value_cache = [torch.randn(expected_shape) for _ in range(num_layers)]
layer_blocks = {"key_cache": key_cache, "value_cache": value_cache}
metadata = {
"block_size": block_size,
"num_layers": num_layers,
"dtype": "float32",
"seq_len": seq_len,
}
data_to_receive = {
"request_id": req_id,
"layer_blocks": layer_blocks,
"metadata": metadata,
"block_ids": [],
}
# The kv_config fixture sets from_stage="stage1" and stage_id="stage2", so recv_stages=("stage1", "stage2").
manager = OmniKVTransferManager(kv_config)
manager._connector = mock_connector
# Pre-populate connector with data
# Manager builds full key: omni_{from}_to_{to}_kv_cache_{req_id}
full_request_id = f"omni_stage1_to_stage2_kv_cache_{req_id}"
store_key = f"stage1->stage2:{full_request_id}"
mock_connector.store[store_key] = data_to_receive
req = OmniDiffusionRequest(
prompts=["test_recv"],
sampling_params=OmniDiffusionSamplingParams(),
request_ids=[req_id],
)
# Note: req.need_kv_receive is not set here; the runner, not the manager,
# checks that flag. The manager's receive_kv_cache matches only on request_id
# (tightening that lookup in the manager is a noted follow-up).
success = manager.receive_kv_cache(req, target_device=torch.device("cpu"))
assert success
assert hasattr(req, "past_key_values")
assert hasattr(req, "kv_metadata")
assert len(req.past_key_values.key_cache) == num_layers
assert torch.allclose(req.past_key_values.key_cache[0], key_cache[0])
assert req.kv_metadata["seq_len"] == seq_len
@pytest.mark.cache
@hardware_test(
res={"cuda": "L4"},
num_cards=2,
)
def test_integration_flow(common_constants):
"""Simulate extraction -> connector -> reception."""
num_layers = common_constants["num_layers"]
block_size = common_constants["block_size"]
num_heads = common_constants["num_heads"]
head_dim = common_constants["head_dim"]
req_id = common_constants["req_id"]
sender_config = OmniKVCacheConfig(
connector_config={"type": "mock"}, from_stage="sender", to_stage="receiver", need_send_cache=True
)
sender_manager = OmniKVTransferManager(sender_config)
connector = MockConnector()
sender_manager._connector = connector # Shared connector
# Create Data
num_blocks = 5
kv_caches = []
for _ in range(num_layers):
layer = torch.randn(2, num_blocks, block_size, num_heads, head_dim)
kv_caches.append(layer)
finished_reqs = {req_id: {"block_ids": [0, 1], "seq_len": 10}}
# Send
sender_manager.handle_finished_requests_kv_transfer(finished_reqs, kv_caches, block_size, "float32")
receiver_config = OmniKVCacheConfig(
connector_config={"type": "mock"},
from_stage="sender",
stage_id="receiver",
need_recv_cache=True,
recv_timeout=1.0,
)
receiver_manager = OmniKVTransferManager(receiver_config)
receiver_manager._connector = connector # Share the same mock connector instance
req = OmniDiffusionRequest(
prompts=["test_integ"],
sampling_params=OmniDiffusionSamplingParams(),
request_ids=[req_id],
)
# Receive
success = receiver_manager.receive_kv_cache(req)
# Verify
assert success
assert req.past_key_values is not None
assert req.kv_metadata["seq_len"] == 10
@pytest.mark.cache
@hardware_test(
res={"cuda": "L4"},
num_cards=2,
)
def test_manager_extraction_no_connector(kv_config, common_constants):
"""Test extraction when connector is unavailable (should still return IDs)."""
block_size = common_constants["block_size"]
req_id = common_constants["req_id"]
manager = OmniKVTransferManager(kv_config)
# Force connector to be None
manager._connector = None
manager.config.connector_config = None
finished_reqs = {req_id: {"block_ids": [1, 2], "seq_len": 10}}
processed = manager.handle_finished_requests_kv_transfer(
finished_reqs, kv_caches=[], block_size=block_size, cache_dtype="float32"
)
assert req_id in processed
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from pathlib import Path
import pytest
# Use the new import path for initialization utilities
from vllm_omni.distributed.omni_connectors.utils.initialization import load_omni_transfer_config
def get_config_files():
"""Helper to find config files."""
# This file lives at vllm-omni/tests/distributed/omni_connectors/test_omni_connector_configs.py,
# so the repository root is four .parent hops up
# (omni_connectors -> distributed -> tests -> vllm-omni); resolve() keeps the
# lookup independent of the current working directory.
base_dir = Path(__file__).resolve().parent.parent.parent.parent
config_dir = base_dir / "vllm_omni" / "model_executor" / "stage_configs"
if not config_dir.exists():
return []
return list(config_dir.glob("qwen*.yaml"))
# Collect files at module level for parametrization
config_files = get_config_files()
@pytest.mark.skipif(len(config_files) == 0, reason="No config files found or directory missing")
@pytest.mark.parametrize("yaml_file", config_files, ids=lambda p: p.name)
def test_load_qwen_yaml_configs(yaml_file):
"""
Scan and test loading of all qwen*.yaml config files.
This ensures that existing stage configs are compatible with the OmniConnector system.
"""
print(f"Testing config load: {yaml_file.name}")
try:
# Load with the default shm threshold; the threshold does not affect whether
# loading succeeds.
config = load_omni_transfer_config(yaml_file)
assert config is not None, "Config should not be None"
# Basic validation: configs may omit the 'runtime' or 'connectors' sections
# and rely on auto-SHM, and a single-stage pipeline may have zero edges, so
# we do not assert len(config.connectors) > 0 here.
print(f" -> Successfully loaded. Connectors: {len(config.connectors)}")
except Exception as e:
pytest.fail(f"Failed to load config {yaml_file.name}: {e}")