Commit c1cacde6 authored by weishb

vllm-omni_0.15.0.rc1+fix1 first commit

parent 35607782
import os
import sys
from pathlib import Path
import pytest
import torch
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
# ruff: noqa: E402
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
from vllm_omni import Omni
from vllm_omni.outputs import OmniRequestOutput
# os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
models = ["Wan-AI/Wan2.2-T2V-A14B-Diffusers"]
@pytest.mark.parametrize("model_name", models)
def test_video_diffusion_model(model_name: str):
m = Omni(
model=model_name,
boundary_ratio=0.875,
flow_shift=5.0,
)
# Use minimal settings for testing
# num_frames must satisfy: num_frames % vae_scale_factor_temporal == 1
# For Wan2.2, vae_scale_factor_temporal=4, so valid values are 5, 9, 13, 17, ...
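# As a reference sketch (not used by this test), an arbitrary frame count can be
# snapped down to the nearest valid value:
#   num_frames = ((requested - 1) // vae_scale_factor_temporal) * vae_scale_factor_temporal + 1
#   e.g. with vae_scale_factor_temporal=4: requested=8 -> 5, requested=9 -> 9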
height = 480
width = 640
num_frames = 5
outputs = m.generate(
prompts="A cat sitting on a table",
sampling_params_list=OmniDiffusionSamplingParams(
height=height,
width=width,
num_frames=num_frames,
num_inference_steps=2,
guidance_scale=1.0,
generator=torch.Generator("cuda").manual_seed(42),
),
)
first_output = outputs[0]
assert first_output.final_output_type == "image"
if not hasattr(first_output, "request_output") or not first_output.request_output:
raise ValueError("No request_output found in OmniRequestOutput")
req_out = first_output.request_output[0]
if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"):
raise ValueError("Invalid request_output structure or missing 'images' key")
frames = req_out.images[0]
assert frames is not None
assert hasattr(frames, "shape")
# frames shape: (batch, num_frames, height, width, channels)
assert frames.shape[1] == num_frames
assert frames.shape[2] == height
assert frames.shape[3] == width
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
System test for TeaCache backend.
This test verifies that TeaCache acceleration works correctly with diffusion models.
It uses minimal settings to keep test time short for CI.
"""
import os
import sys
from pathlib import Path
import pytest
import torch
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
# ruff: noqa: E402
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
from vllm_omni import Omni
from vllm_omni.outputs import OmniRequestOutput
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
# Use random weights model for testing
models = ["riverclouds/qwen_image_random"]
@pytest.mark.parametrize("model_name", models)
def test_teacache(model_name: str):
"""Test TeaCache backend with diffusion model."""
# Configure TeaCache with default settings for fast testing
cache_config = {
"rel_l1_thresh": 0.2, # Default threshold
}
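# rel_l1_thresh roughly controls how aggressively TeaCache skips work: the cached
# transformer output is reused while the accumulated relative L1 change of the
# timestep-modulated input stays below this threshold, so a larger value skips
# more steps (faster, lower fidelity). Hedged summary of the general TeaCache
# idea, not a description of this backend's internals.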
m = None
try:
m = Omni(
model=model_name,
cache_backend="tea_cache",
cache_config=cache_config,
)
# Use minimal settings for fast testing
height = 256
width = 256
num_inference_steps = 4 # Minimal steps for fast test
outputs = m.generate(
"a photo of a cat sitting on a laptop keyboard",
OmniDiffusionSamplingParams(
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=0.0,
generator=torch.Generator("cuda").manual_seed(42),
num_outputs_per_prompt=1, # Single output for speed
),
)
# Extract images from request_output[0].images
first_output = outputs[0]
assert first_output.final_output_type == "image"
if not hasattr(first_output, "request_output") or not first_output.request_output:
raise ValueError("No request_output found in OmniRequestOutput")
req_out = first_output.request_output[0]
if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"):
raise ValueError("Invalid request_output structure or missing 'images' key")
images = req_out.images
# Verify generation succeeded
assert images is not None
assert len(images) == 1
# Check image size
assert images[0].width == width
assert images[0].height == height
except Exception as e:
print(f"Test failed with error: {e}")
raise
finally:
if m is not None and hasattr(m, "close"):
m.close()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import sys
import time
from pathlib import Path
import numpy as np
import pytest
import torch
from PIL import Image
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
# ruff: noqa: E402
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
from tests.utils import GPUMemoryMonitor
from vllm_omni import Omni
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
# os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
PROMPT = "a photo of a cat sitting on a laptop keyboard"
def _get_zimage_model() -> str:
# Allow overriding the model for local/offline environments.
# Can be either a HuggingFace repo id or a local path.
return os.environ.get("VLLM_TEST_ZIMAGE_MODEL", "Tongyi-MAI/Z-Image-Turbo")
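# Example override for local/offline runs (hypothetical path):
#   VLLM_TEST_ZIMAGE_MODEL=/models/Z-Image-Turbo pytest <this test file>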
def _pil_to_float_rgb_tensor(img: Image.Image) -> torch.Tensor:
"""Convert PIL image to float32 RGB tensor in [0, 1] with shape [H, W, 3]."""
arr = np.asarray(img.convert("RGB"), dtype=np.float32) / 255.0
return torch.from_numpy(arr)
def _diff_metrics(a: Image.Image, b: Image.Image) -> tuple[float, float]:
"""Return (mean_abs_diff, max_abs_diff) over RGB pixels in [0, 1]."""
ta = _pil_to_float_rgb_tensor(a)
tb = _pil_to_float_rgb_tensor(b)
assert ta.shape == tb.shape, f"Image shapes differ: {ta.shape} vs {tb.shape}"
abs_diff = torch.abs(ta - tb)
return abs_diff.mean().item(), abs_diff.max().item()
def _extract_single_image(outputs) -> Image.Image:
first_output = outputs[0]
assert first_output.final_output_type == "image"
if not hasattr(first_output, "request_output") or not first_output.request_output:
raise ValueError("No request_output found in OmniRequestOutput")
req_out = first_output.request_output[0]
if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"):
raise ValueError("Invalid request_output structure or missing 'images' key")
images = req_out.images
if images is None or len(images) != 1:
raise ValueError(f"Expected 1 image, got {0 if images is None else len(images)}")
return images[0]
def _run_zimage_generate(
*, tp_size: int, height: int, width: int, num_inference_steps: int, seed: int
) -> tuple[Image.Image, float, float]:
torch.cuda.empty_cache()
device_index = torch.cuda.current_device()
monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02)
monitor.start()
m = Omni(
model=_get_zimage_model(),
parallel_config=DiffusionParallelConfig(tensor_parallel_size=tp_size),
)
try:
# NOTE: Omni closes itself when a generate() call is exhausted.
# To avoid measuring teardown time (process shutdown, memory cleanup),
# we measure the latency to produce *subsequent* outputs within a single
# generator run.
#
# This also serves as a warmup: the first output may include extra
# compilation/caching overhead, while later outputs are closer to
# steady-state inference.
num_requests = 4 # 1 warmup + 3 timed
gen = m.generate(
[PROMPT] * num_requests,
OmniDiffusionSamplingParams(
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=0.0,
seed=seed,
num_outputs_per_prompt=1,
),
py_generator=True,
)
warmup_output = next(gen)
t_prev = time.perf_counter()
per_request_times_s: list[float] = []
last_output = warmup_output
for _ in range(num_requests - 1):
last_output = next(gen)
t_now = time.perf_counter()
per_request_times_s.append(t_now - t_prev)
t_prev = t_now
# Ensure the generator is fully consumed so it can clean up.
for _ in gen:
pass
median_time_s = float(np.median(per_request_times_s))
peak_memory_mb = monitor.peak_used_mb
return _extract_single_image([last_output]), median_time_s, peak_memory_mb
finally:
monitor.stop()
cleanup_dist_env_and_memory()
@pytest.mark.integration
def test_zimage_tensor_parallel_tp2(tmp_path: Path):
if current_omni_platform.is_npu() or current_omni_platform.is_rocm():
pytest.skip("Z-Image TP e2e test is only supported on CUDA for now.")
if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
pytest.skip("Z-Image TP=2 requires >= 2 CUDA devices.")
height = 512
width = 512
num_inference_steps = 2
seed = 42
tp1_img, tp1_time_s, tp1_peak_mem = _run_zimage_generate(
tp_size=1,
height=height,
width=width,
num_inference_steps=num_inference_steps,
seed=seed,
)
tp2_img, tp2_time_s, tp2_peak_mem = _run_zimage_generate(
tp_size=2,
height=height,
width=width,
num_inference_steps=num_inference_steps,
seed=seed,
)
tp1_path = tmp_path / "zimage_tp1.png"
tp2_path = tmp_path / "zimage_tp2.png"
tp1_img.save(tp1_path)
tp2_img.save(tp2_path)
assert tp1_img.width == width and tp1_img.height == height
assert tp2_img.width == width and tp2_img.height == height
mean_abs_diff, max_abs_diff = _diff_metrics(tp1_img, tp2_img)
mean_threshold = 3e-2
max_threshold = 5e-1
print(
"Z-Image TP image diff stats (TP=1 vs TP=2): "
f"mean_abs_diff={mean_abs_diff:.6e}, max_abs_diff={max_abs_diff:.6e}; "
f"thresholds: mean<={mean_threshold:.6e}, max<={max_threshold:.6e}; "
f"tp1_img={tp1_path}, tp2_img={tp2_path}"
)
assert mean_abs_diff <= mean_threshold and max_abs_diff <= max_threshold, (
f"Image diff exceeded threshold: mean_abs_diff={mean_abs_diff:.6e}, max_abs_diff={max_abs_diff:.6e} "
f"(thresholds: mean<={mean_threshold:.6e}, max<={max_threshold:.6e})"
)
print(f"Z-Image TP perf (lower is better): tp1_time_s={tp1_time_s:.6f}, tp2_time_s={tp2_time_s:.6f}")
assert tp2_time_s < tp1_time_s, f"Expected TP=2 to be faster than TP=1 (tp1={tp1_time_s}, tp2={tp2_time_s})"
print(f"Z-Image TP peak memory (MB): tp1_peak_mem={tp1_peak_mem:.2f}, tp2_peak_mem={tp2_peak_mem:.2f}")
assert tp2_peak_mem < tp1_peak_mem, (
f"Expected TP=2 to use less peak memory than TP=1 (tp1={tp1_peak_mem}, tp2={tp2_peak_mem})"
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import functools
import os
import signal
import subprocess
import sys
import tempfile
from collections.abc import Callable
from contextlib import ExitStack, suppress
from pathlib import Path
from typing import Any, Literal
import cloudpickle
from typing_extensions import ParamSpec
from vllm.platforms import current_platform
VLLM_PATH = Path(__file__).parent.parent.parent
"""Path to root of the vLLM repository."""
_P = ParamSpec("_P")
def fork_new_process_for_each_test(func: Callable[_P, None]) -> Callable[_P, None]:
"""Decorator to fork a new process for each test function.
See https://github.com/vllm-project/vllm/issues/7053 for more details.
"""
@functools.wraps(func)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
# Make the process the leader of its own process group
# to avoid sending SIGTERM to the parent process
os.setpgrp()
from _pytest.outcomes import Skipped
# Create a unique temporary file to store exception info from child
# process. Use test function name and process ID to avoid collisions.
with (
tempfile.NamedTemporaryFile(
delete=False,
mode="w+b",
prefix=f"vllm_test_{func.__name__}_{os.getpid()}_",
suffix=".exc",
) as exc_file,
ExitStack() as delete_after,
):
exc_file_path = exc_file.name
delete_after.callback(os.remove, exc_file_path)
pid = os.fork()
print(f"Fork a new process to run a test {pid}")
if pid == 0:
# Parent process responsible for deleting, don't delete
# in child.
delete_after.pop_all()
try:
func(*args, **kwargs)
except Skipped as e:
# convert Skipped to exit code 0
print(str(e))
os._exit(0)
except Exception as e:
import traceback
tb_string = traceback.format_exc()
# Try to serialize the exception object first
exc_to_serialize: dict[str, Any]
try:
# First, try to pickle the actual exception with
# its traceback.
exc_to_serialize = {"pickled_exception": e}
# Test if it can be pickled
cloudpickle.dumps(exc_to_serialize)
except (Exception, KeyboardInterrupt):
# Fall back to string-based approach.
exc_to_serialize = {
"exception_type": type(e).__name__,
"exception_msg": str(e),
"traceback": tb_string,
}
try:
with open(exc_file_path, "wb") as f:
cloudpickle.dump(exc_to_serialize, f)
except Exception:
# Fallback: just print the traceback.
print(tb_string)
os._exit(1)
else:
os._exit(0)
else:
pgid = os.getpgid(pid)
_pid, _exitcode = os.waitpid(pid, 0)
# ignore SIGTERM signal itself
old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN)
# kill all child processes
os.killpg(pgid, signal.SIGTERM)
# restore the signal handler
signal.signal(signal.SIGTERM, old_signal_handler)
if _exitcode != 0:
# Try to read the exception from the child process
exc_info = {}
if os.path.exists(exc_file_path):
with (
contextlib.suppress(Exception),
open(exc_file_path, "rb") as f,
):
exc_info = cloudpickle.load(f)
original_exception = exc_info.get("pickled_exception")
if original_exception is not None and isinstance(original_exception, Exception):
# Re-raise the actual exception object if it was
# successfully pickled.
raise original_exception
if (original_tb := exc_info.get("traceback")) is not None:
# Use string-based traceback for fallback case
raise AssertionError(
f"Test {func.__name__} failed when called with"
f" args {args} and kwargs {kwargs}"
f" (exit code: {_exitcode}):\n{original_tb}"
) from None
# Fallback to the original generic error
raise AssertionError(
f"function {func.__name__} failed when called with"
f" args {args} and kwargs {kwargs}"
f" (exit code: {_exitcode})"
) from None
return wrapper
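# Example usage (sketch): isolate a GPU-heavy test in its own forked process so
# that CUDA/driver state cannot leak between tests.
#
#     @fork_new_process_for_each_test
#     def test_load_model():
#         ...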
def spawn_new_process_for_each_test(f: Callable[_P, None]) -> Callable[_P, None]:
"""Decorator to spawn a new process for each test function."""
@functools.wraps(f)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
# Check if we're already in a subprocess
if os.environ.get("RUNNING_IN_SUBPROCESS") == "1":
# If we are, just run the function directly
return f(*args, **kwargs)
import torch.multiprocessing as mp
with suppress(RuntimeError):
mp.set_start_method("spawn")
# Get the module
module_name = f.__module__
# Create a process with environment variable set
env = os.environ.copy()
env["RUNNING_IN_SUBPROCESS"] = "1"
with tempfile.TemporaryDirectory() as tempdir:
output_filepath = os.path.join(tempdir, "new_process.tmp")
# `cloudpickle` allows pickling complex functions directly
input_bytes = cloudpickle.dumps((f, output_filepath))
repo_root = str(VLLM_PATH.resolve())
env = dict(env or os.environ)
env["PYTHONPATH"] = repo_root + os.pathsep + env.get("PYTHONPATH", "")
cmd = [sys.executable, "-m", f"{module_name}"]
returned = subprocess.run(cmd, input=input_bytes, capture_output=True, env=env)
# check if the subprocess is successful
try:
returned.check_returncode()
except Exception as e:
# wrap raised exception to provide more information
raise RuntimeError(f"Error raised in subprocess:\n{returned.stderr.decode()}") from e
return wrapper
def create_new_process_for_each_test(
method: Literal["spawn", "fork"] | None = None,
) -> Callable[[Callable[_P, None]], Callable[_P, None]]:
"""Creates a decorator that runs each test function in a new process.
Args:
method: The process creation method. Can be either "spawn" or "fork".
If not specified, it defaults to "spawn" on XPU platforms and
"fork" otherwise (including ROCm; see the TODO below).
Returns:
A decorator to run test functions in separate processes.
"""
if method is None:
# TODO: Find out why spawn is not working correctly on ROCm
# The test content will not run and tests passed immediately.
# For now, using `fork` for ROCm as it can run with `fork`
# and tests are running correctly.
use_spawn = current_platform.is_xpu()
method = "spawn" if use_spawn else "fork"
assert method in ["spawn", "fork"], "Method must be either 'spawn' or 'fork'"
if method == "fork":
return fork_new_process_for_each_test
return spawn_new_process_for_each_test
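# Example usage (sketch): explicitly request a spawned subprocess per test,
# e.g. when forking is unsafe after CUDA has been initialized.
#
#     @create_new_process_for_each_test("spawn")
#     def test_engine_startup():
#         ...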
# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
# Stage 0: Thinker (multimodal understanding + text generation)
# Stage 1: Talker (text embeddings → 16-layer RVQ codec codes)
# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
# The following config has been verified on 2x H100-80G GPUs.
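# Example launch (sketch; mirrors the serve invocation used by the test helpers
# in this commit):
#   python -m vllm_omni.entrypoints.cli.main serve Qwen/Qwen3-Omni-30B-A3B-Instruct \
#     --omni --stage-configs-path <path to this yaml> --stage-init-timeout 120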
stage_args:
- stage_id: 0
stage_type: llm # Use llm stage type to launch OmniLLM
runtime:
devices: "0"
max_batch_size: 5
engine_args:
model_stage: thinker
model_arch: Qwen3OmniMoeForConditionalGeneration
worker_type: ar
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
gpu_memory_utilization: 0.9
enforce_eager: false
trust_remote_code: true
engine_output_type: latent # Output hidden states for talker
distributed_executor_backend: "mp"
enable_prefix_caching: false
max_num_batched_tokens: 32768
hf_config_name: thinker_config
tensor_parallel_size: 1
load_format: dummy
final_output: true
final_output_type: text
is_comprehension: true
default_sampling_params:
temperature: 0.4
top_p: 0.9
top_k: 1
max_tokens: 100
seed: 42
detokenize: True
repetition_penalty: 1.05
- stage_id: 1
stage_type: llm # Use llm stage type to launch OmniLLM
runtime:
devices: "1"
max_batch_size: 5
engine_args:
model_stage: talker
model_arch: Qwen3OmniMoeForConditionalGeneration
worker_type: ar
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
gpu_memory_utilization: 0.6
enforce_eager: false
trust_remote_code: true
engine_output_type: latent # Output codec codes for code2wav
# tensor_parallel_size: 2
enable_prefix_caching: false
distributed_executor_backend: "mp"
hf_config_name: talker_config
load_format: dummy
engine_input_source: [0]
custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
# final_output: true
# final_output_type: text
default_sampling_params:
temperature: 0.9
top_k: 50
max_tokens: 100
seed: 42
detokenize: False
repetition_penalty: 1.05
stop_token_ids: [2150]
- stage_id: 2
stage_type: llm # Use llm stage type to launch OmniLLM
runtime:
devices: "1"
max_batch_size: 1
engine_args:
model_stage: code2wav
model_arch: Qwen3OmniMoeForConditionalGeneration
worker_type: generation
scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
enforce_eager: true
trust_remote_code: true
async_scheduling: false
enable_prefix_caching: false
engine_output_type: audio # Final output: audio waveform
gpu_memory_utilization: 0.1
distributed_executor_backend: "mp"
max_num_batched_tokens: 1000000
hf_config_name: thinker_config
load_format: dummy
engine_input_source: [1]
custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
final_output: true
final_output_type: audio
default_sampling_params:
temperature: 0.0
top_p: 1.0
top_k: -1
max_tokens: 200
seed: 42
detokenize: True
repetition_penalty: 1.1
# The following config has been verified on 2x H100-80G GPUs.
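# Single-stage, thinker-only variant (text output only): runs the thinker with
# tensor_parallel_size=2 across devices "0,1" instead of the full
# thinker/talker/code2wav pipeline.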
stage_args:
- stage_id: 0
runtime:
devices: "0,1"
max_batch_size: 5
engine_args:
model_stage: thinker
model_arch: Qwen3OmniMoeForConditionalGeneration
worker_type: ar
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
gpu_memory_utilization: 0.6
enforce_eager: true
trust_remote_code: true
engine_output_type: latent # Output hidden states for talker
distributed_executor_backend: "mp"
enable_prefix_caching: false
hf_config_name: thinker_config
tensor_parallel_size: 2
load_format: dummy
final_output: true
final_output_type: text
is_comprehension: true
default_sampling_params:
temperature: 0.4
top_p: 0.9
top_k: 1
max_tokens: 100
seed: 42
detokenize: True
repetition_penalty: 1.05
# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
# Stage 0: Thinker (multimodal understanding + text generation)
# Stage 1: Talker (text embeddings → 16-layer RVQ codec codes)
# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
# The following config has been verified on 2x H100-80G GPUs.
stage_args:
- stage_id: 0
runtime:
devices: "0"
max_batch_size: 5
engine_args:
model_stage: thinker
model_arch: Qwen3OmniMoeForConditionalGeneration
worker_type: ar
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
gpu_memory_utilization: 0.9
enforce_eager: false
trust_remote_code: true
engine_output_type: latent # Output hidden states for talker
distributed_executor_backend: "mp"
enable_prefix_caching: false
hf_config_name: thinker_config
tensor_parallel_size: 1
final_output: true
final_output_type: text
is_comprehension: true
default_sampling_params:
temperature: 0.4
top_p: 0.9
top_k: 1
max_tokens: 100
seed: 42
detokenize: True
repetition_penalty: 1.05
- stage_id: 1
runtime:
devices: "1"
max_batch_size: 5
engine_args:
model_stage: talker
model_arch: Qwen3OmniMoeForConditionalGeneration
worker_type: ar
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
gpu_memory_utilization: 0.6
enforce_eager: true
trust_remote_code: true
engine_output_type: latent # Output codec codes for code2wav
# tensor_parallel_size: 2
enable_prefix_caching: false
distributed_executor_backend: "mp"
hf_config_name: talker_config
engine_input_source: [0]
custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
# final_output: true
# final_output_type: text
default_sampling_params:
temperature: 0.9
top_k: 50
max_tokens: 1000
seed: 42
detokenize: False
repetition_penalty: 1.05
stop_token_ids: [2150]
- stage_id: 2
runtime:
devices: "1"
max_batch_size: 1
engine_args:
model_stage: code2wav
model_arch: Qwen3OmniMoeForConditionalGeneration
worker_type: generation
scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
enforce_eager: true
trust_remote_code: true
enable_prefix_caching: false
engine_output_type: audio # Final output: audio waveform
gpu_memory_utilization: 0.1
distributed_executor_backend: "mp"
max_num_batched_tokens: 1000000
hf_config_name: thinker_config
async_scheduling: false
engine_input_source: [1]
custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
final_output: true
final_output_type: audio
default_sampling_params:
temperature: 0.0
top_p: 1.0
top_k: -1
max_tokens: 2000
seed: 42
detokenize: True
repetition_penalty: 1.1
import asyncio
import os
import sys
from contextlib import ExitStack
from pathlib import Path
import pytest
from vllm import SamplingParams
from vllm.inputs import PromptType
from vllm_omni.entrypoints.async_omni import AsyncOmni, ClientRequestState
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
SEED = 42
stage_config = str(Path(__file__).parent / "stage_configs" / "qwen3_omni_thinker_ci.yaml")
model = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
async def generate(
engine: AsyncOmni,
request_id: str,
prompt: PromptType,
max_tokens: int,
) -> tuple[int, str]:
# Ensure generate doesn't complete too fast for cancellation test.
await asyncio.sleep(0.2)
thinker_sampling_params = SamplingParams(
temperature=0.4, # Deterministic
top_p=0.9,
top_k=1,
max_tokens=max_tokens,
repetition_penalty=1.05,
stop_token_ids=[151645], # Qwen EOS token <|im_end|>
seed=SEED,
)
sampling_params_list = [
thinker_sampling_params,
]
count = 0
async for omni_output in engine.generate(
prompt=prompt,
request_id=request_id,
sampling_params_list=sampling_params_list,
output_modalities=["text"],
):
stage_id = omni_output.stage_id
out = omni_output.request_output
if stage_id == 0:
num_tokens = sum(len(output.token_ids) for output in out.outputs)
count = num_tokens
await asyncio.sleep(0.0)
return count, request_id
@pytest.mark.asyncio
async def test_abort():
with ExitStack() as after:
# Avoid SHM IPC in tests to prevent /dev/shm exhaustion and SIGBUS.
engine = AsyncOmni(
model=model,
stage_configs_path=stage_config,
shm_threshold_bytes=sys.maxsize,
)
after.callback(engine.shutdown)
# Keep token counts modest to reduce flakiness on slow test hardware.
NUM_REQUESTS = 3
NUM_EXPECTED_TOKENS = 64
NUM_EXPECTED_TOKENS_LONG = 256
REQUEST_IDS_TO_ABORT = [1]
prompt = "Hello my name is Robert and "
request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]
# Create concurrent requests.
tasks: list[asyncio.Task] = []
for idx, request_id in enumerate(request_ids):
max_tokens = NUM_EXPECTED_TOKENS_LONG if (idx in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS
tasks.append(asyncio.create_task(generate(engine, request_id, prompt, max_tokens)))
# API server cancels requests when they disconnect.
# Explicitly abort in the engine to avoid orphaned requests hanging.
for idx in REQUEST_IDS_TO_ABORT:
tasks[idx].cancel()
await engine.abort(request_ids[idx])
await asyncio.sleep(0.1)
# Confirm the other requests are okay.
for idx, task in enumerate(tasks):
# Confirm that it was actually canceled.
if idx in REQUEST_IDS_TO_ABORT:
with pytest.raises((asyncio.CancelledError, GeneratorExit)):
await asyncio.wait_for(task, timeout=60)
else:
# Otherwise, make sure the request was not impacted.
num_generated_tokens, request_id = await asyncio.wait_for(task, timeout=180)
expected_tokens = NUM_EXPECTED_TOKENS
assert num_generated_tokens == expected_tokens, (
f"{request_id} generated {num_generated_tokens} but expected {expected_tokens}"
)
# Confirm we can do another generation.
request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}"
task = asyncio.create_task(generate(engine, request_id, prompt, NUM_EXPECTED_TOKENS))
num_generated_tokens, request_id = await task
assert num_generated_tokens == NUM_EXPECTED_TOKENS
await asyncio.sleep(5)
@pytest.mark.asyncio
async def test_build_and_log_summary(monkeypatch):
from vllm_omni.entrypoints.utils import get_final_stage_id_for_e2e
RealCRS = ClientRequestState
capture_metrics = {}
class MockCRS(RealCRS):
def __init__(self, request_id: str):
super().__init__(request_id)
capture_metrics[request_id] = self
monkeypatch.setattr("vllm_omni.entrypoints.async_omni.ClientRequestState", MockCRS)
monkeypatch.setattr("vllm_omni.entrypoints.client_request_state.ClientRequestState", MockCRS)
with ExitStack() as after:
# Avoid SHM IPC in tests to prevent /dev/shm exhaustion and SIGBUS.
engine = AsyncOmni(
model=model,
stage_configs_path=stage_config,
shm_threshold_bytes=sys.maxsize,
)
after.callback(engine.shutdown)
prompt = "Hello my name is Robert and "
NUM_EXPECTED_TOKENS = 64
NUM_REQUESTS = 3
request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]
# Create concurrent requests.
tasks: list[asyncio.Task] = []
for idx, request_id in enumerate(request_ids):
tasks.append(asyncio.create_task(generate(engine, request_id, prompt, NUM_EXPECTED_TOKENS)))
# Confirm the requests are okay.
for idx, task in enumerate(tasks):
await task
output_modalities = ["text"]
final_stage_id_for_e2e = get_final_stage_id_for_e2e(
output_modalities, engine.output_modalities, engine.stage_list
)
summary = capture_metrics[request_ids[idx]].metrics.build_and_log_summary(final_stage_id_for_e2e)
# Check that total tokens matches sum of stage tokens.
assert summary["e2e_total_tokens"] == sum(stage["tokens"] for stage in summary["stages"])
# Check that total time matches sum of stage times.
assert summary["e2e_total_time_ms"] >= sum(stage["total_time_ms"] for stage in summary["stages"])
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
E2E online serving test for Qwen-Image-Edit-2509 multi-image input.
"""
import base64
import os
import signal
import socket
import subprocess
import sys
import threading
import time
from io import BytesIO
from typing import Any
import openai
import pytest
import requests
from PIL import Image
from vllm.assets.image import ImageAsset
from vllm.utils.network_utils import get_open_port
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
# Increase timeout for downloading assets from S3 (default 5s is too short for CI)
os.environ.setdefault("VLLM_IMAGE_FETCH_TIMEOUT", "60")
models = ["Qwen/Qwen-Image-Edit-2509"]
test_params = models
t2i_models = ["Tongyi-MAI/Z-Image-Turbo"]
class OmniServer:
"""Omniserver for vLLM-Omni tests."""
def __init__(
self,
model: str,
serve_args: list[str],
*,
env_dict: dict[str, str] | None = None,
) -> None:
self.model = model
self.serve_args = serve_args
self.env_dict = env_dict
self.proc: subprocess.Popen | None = None
self.host = "127.0.0.1"
self.port = get_open_port()
def _start_server(self) -> None:
"""Start the vLLM-Omni server subprocess."""
env = os.environ.copy()
env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
if self.env_dict is not None:
env.update(self.env_dict)
cmd = [
sys.executable,
"-m",
"vllm_omni.entrypoints.cli.main",
"serve",
self.model,
"--omni",
"--host",
self.host,
"--port",
str(self.port),
] + self.serve_args
print(f"Launching OmniServer with: {' '.join(cmd)}")
self.proc = subprocess.Popen(
cmd,
env=env,
cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))), # Set working directory to vllm-omni root
start_new_session=True,
)
# Wait for server to be ready
max_wait = 600 # 10 minutes
start_time = time.time()
while time.time() - start_time < max_wait:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.settimeout(1)
result = sock.connect_ex((self.host, self.port))
if result == 0:
print(f"Server ready on {self.host}:{self.port}")
return
except Exception:
pass
time.sleep(2)
raise RuntimeError(f"Server failed to start within {max_wait} seconds")
def __enter__(self):
self._start_server()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.proc:
try:
os.killpg(self.proc.pid, signal.SIGTERM)
except ProcessLookupError:
pass
try:
self.proc.wait(timeout=30)
except subprocess.TimeoutExpired:
try:
os.killpg(self.proc.pid, signal.SIGKILL)
except ProcessLookupError:
pass
self.proc.wait()
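# Example usage (sketch, mirroring the fixtures below):
#   with OmniServer("Qwen/Qwen-Image-Edit-2509", ["--num-gpus", "1"]) as server:
#       client = openai.OpenAI(base_url=f"http://{server.host}:{server.port}/v1", api_key="EMPTY")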
@pytest.fixture
def omni_server(request):
"""Start vLLM-Omni server as a subprocess with actual model weights."""
model = request.param
with OmniServer(model, ["--num-gpus", "1"]) as server:
yield server
@pytest.fixture
def client(omni_server):
"""OpenAI client for the running vLLM-Omni server."""
return openai.OpenAI(
base_url=f"http://{omni_server.host}:{omni_server.port}/v1",
api_key="EMPTY",
)
@pytest.fixture(scope="session")
def base64_encoded_images() -> list[str]:
"""Base64 encoded PNG images for testing."""
images = [
ImageAsset("cherry_blossom").pil_image.convert("RGB"),
ImageAsset("stop_sign").pil_image.convert("RGB"),
]
encoded: list[str] = []
for img in images:
with BytesIO() as buffer:
img.save(buffer, format="PNG")
encoded.append(base64.b64encode(buffer.getvalue()).decode("utf-8"))
return encoded
def dummy_messages_from_image_data(
image_data_urls: list[str],
content_text: str = "Combine these two images into one scene.",
):
"""Create messages with image data URLs for OpenAI API."""
content = [{"type": "text", "text": content_text}]
for image_url in image_data_urls:
content.append({"type": "image_url", "image_url": {"url": image_url}})
return [{"role": "user", "content": content}]
def _extract_image_data_url(message_content) -> str:
assert isinstance(message_content, list) and len(message_content) >= 1
content_part = message_content[0]
if isinstance(content_part, dict):
image_url = content_part.get("image_url", {}).get("url", "")
else:
image_url_obj = getattr(content_part, "image_url", None)
if isinstance(image_url_obj, dict):
image_url = image_url_obj.get("url", "")
else:
image_url = getattr(image_url_obj, "url", "")
assert isinstance(image_url, str) and image_url
return image_url
def _decode_data_url_to_image_bytes(data_url: str) -> bytes:
assert data_url.startswith("data:image")
_, b64_data = data_url.split(",", 1)
return base64.b64decode(b64_data)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_i2i_multi_image_input_qwen_image_edit_2509(
omni_server,
base64_encoded_images: list[str],
) -> None:
"""Test multi-image input editing via OpenAI API with concurrent requests."""
image_data_urls = [f"data:image/png;base64,{img}" for img in base64_encoded_images]
messages = dummy_messages_from_image_data(image_data_urls)
barrier = threading.Barrier(2)
results: list[tuple[int, int]] = []
def _call_chat(width: int, height: int) -> None:
client = openai.OpenAI(
base_url=f"http://{omni_server.host}:{omni_server.port}/v1",
api_key="EMPTY",
)
barrier.wait()
chat_completion = client.chat.completions.create(
model=omni_server.model,
messages=messages,
extra_body={
"height": height,
"width": width,
"num_inference_steps": 2,
"guidance_scale": 0.0,
"seed": 42,
},
)
assert len(chat_completion.choices) == 1
choice = chat_completion.choices[0]
assert choice.finish_reason == "stop"
assert choice.message.role == "assistant"
image_data_url = _extract_image_data_url(choice.message.content)
image_bytes = _decode_data_url_to_image_bytes(image_data_url)
img = Image.open(BytesIO(image_bytes))
img.load()
results.append(img.size)
threads = [
threading.Thread(target=_call_chat, args=(1248, 832)),
threading.Thread(target=_call_chat, args=(1024, 768)),
]
for t in threads:
t.start()
for t in threads:
t.join()
# TODO @ZJY
# assert (1248, 832) in results
# assert (1024, 768) in results
@pytest.mark.parametrize("omni_server", t2i_models, indirect=True)
def test_t2i_concurrent_requests_different_sizes(omni_server) -> None:
"""Test /v1/images/generations concurrent requests with different sizes."""
base_url = f"http://{omni_server.host}:{omni_server.port}"
url = f"{base_url}/v1/images/generations"
barrier = threading.Barrier(2)
results: list[tuple[int, int]] = []
def _call_generate(size: str) -> None:
payload: dict[str, Any] = {
"prompt": "cute cat playing with a ball",
"n": 1,
"size": size,
"response_format": "b64_json",
"num_inference_steps": 2,
}
barrier.wait()
response = requests.post(url, json=payload, timeout=120)
assert response.status_code == 200
data = response.json()
image_b64 = data["data"][0]["b64_json"]
image_bytes = base64.b64decode(image_b64)
img = Image.open(BytesIO(image_bytes))
img.load()
results.append(img.size)
threads = [
threading.Thread(target=_call_generate, args=("512x512",)),
threading.Thread(target=_call_generate, args=("768x512",)),
]
for t in threads:
t.start()
for t in threads:
t.join()
assert (512, 512) in results
assert (768, 512) in results
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
E2E online serving test for /v1/images/generations with per-request LoRA.
This validates:
- The API server accepts a per-request `lora` object in the Images API payload.
- LoRA can be switched per request (adapter A -> adapter B -> no LoRA).
- Output correctness is asserted using a small image slice with tolerance.
"""
import base64
import json
import os
from io import BytesIO
from pathlib import Path
import numpy as np
import pytest
import requests
import torch
from PIL import Image
from safetensors.torch import save_file
from tests.conftest import OmniServer
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
MODEL = "Tongyi-MAI/Z-Image-Turbo"
PROMPT = "a photo of a cat sitting on a laptop keyboard"
SIZE = "256x256"
SEED = 42
@pytest.fixture(scope="module")
def omni_server():
with OmniServer(MODEL, ["--num-gpus", "1"]) as server:
yield server
def _write_zimage_lora(adapter_dir: Path, *, q_scale: float = 0.0, k_scale: float = 0.0, v_scale: float = 0.0):
adapter_dir.mkdir(parents=True, exist_ok=True)
# Z-Image transformer uses dim=3840 by default.
dim = 3840
module_name = "transformer.layers.0.attention.to_qkv"
rank = 1
lora_a = torch.zeros((rank, dim), dtype=torch.float32)
lora_a[0, 0] = 1.0
# QKVParallelLinear packs (Q, K, V) => out dim is 3 * dim (tp=1).
lora_b = torch.zeros((3 * dim, rank), dtype=torch.float32)
if q_scale:
lora_b[:dim, 0] = q_scale
if k_scale:
lora_b[dim : 2 * dim, 0] = k_scale
if v_scale:
lora_b[2 * dim :, 0] = v_scale
save_file(
{
f"base_model.model.{module_name}.lora_A.weight": lora_a,
f"base_model.model.{module_name}.lora_B.weight": lora_b,
},
str(adapter_dir / "adapter_model.safetensors"),
)
(adapter_dir / "adapter_config.json").write_text(
json.dumps(
{
"r": rank,
"lora_alpha": rank,
"target_modules": [module_name],
}
),
encoding="utf-8",
)
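# Note on the adapter above: with rank 1, lora_A selects input feature 0 and
# lora_B writes a constant column into the chosen Q/K/V output slice, so the
# LoRA delta for that slice is roughly scale * x[0] * (q|k|v)_scale per output
# element. This makes the adapter's effect easy to toggle and compare across
# requests in the assertions below.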
def _post_images(server: OmniServer, payload: dict) -> Image.Image:
url = f"http://{server.host}:{server.port}/v1/images/generations"
resp = requests.post(url, json=payload, headers={"Authorization": "Bearer EMPTY"}, timeout=900)
resp.raise_for_status()
data = resp.json()
b64 = data["data"][0]["b64_json"]
img_bytes = base64.b64decode(b64)
img = Image.open(BytesIO(img_bytes))
img.load()
return img.convert("RGB")
def _image_blue_tail_slice(img: Image.Image) -> np.ndarray:
arr = np.asarray(img, dtype=np.uint8)
assert arr.ndim == 3 and arr.shape[-1] == 3
tail = arr[-3:, -3:, -1].astype(np.float32)
assert tail.shape == (3, 3)
return tail
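# The 3x3 bottom-right patch of the blue channel is used as a cheap, deterministic
# fingerprint of the generated image for the tolerance checks below.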
def _slice_diff_stats(actual: np.ndarray, expected: np.ndarray) -> tuple[float, float]:
diff = np.abs(actual - expected)
return float(diff.max()), float(diff.mean())
def _assert_slice_close(
actual: np.ndarray,
expected: np.ndarray,
*,
label: str,
base_max: float,
base_mean: float,
) -> None:
assert actual.shape == (3, 3)
assert expected.shape == (3, 3)
max_diff, mean_diff = _slice_diff_stats(actual, expected)
# NOTE: Different attention backends / torch.compile can introduce small
# floating-point drift that shows up as a few LSBs in uint8 pixels. Keep
# the reset check tolerant but bounded to avoid flaky CI.
max_thresh = max(10.0, base_max + 4.0)
mean_thresh = max(6.0, base_mean + 4.0)
assert max_diff <= max_thresh and mean_diff <= mean_thresh, (
f"{label} slice mismatch (max={max_diff:.1f} > {max_thresh:.1f} or "
f"mean={mean_diff:.1f} > {mean_thresh:.1f}): {actual.tolist()}"
)
def _assert_slice_diff(actual: np.ndarray, baseline: np.ndarray, *, label: str) -> None:
assert actual.shape == (3, 3)
assert baseline.shape == (3, 3)
diff = np.abs(actual - baseline).mean()
assert diff > 0.1, f"{label} slice diff too small: {diff} ({actual.tolist()} vs {baseline.tolist()})"
def _basic_payload() -> dict:
return {
"prompt": PROMPT,
"n": 1,
"size": SIZE,
"num_inference_steps": 2,
"guidance_scale": 0.0,
"seed": SEED,
}
def test_images_generations_per_request_lora_switching(omni_server: OmniServer, tmp_path: Path) -> None:
# Base generation.
base_img = _post_images(omni_server, _basic_payload())
base_slice = _image_blue_tail_slice(base_img)
base_ref_img = _post_images(omni_server, _basic_payload())
base_ref_slice = _image_blue_tail_slice(base_ref_img)
base_ref_max, base_ref_mean = _slice_diff_stats(base_ref_slice, base_slice)
# Adapter A: apply delta to V slice only.
lora_a_dir = tmp_path / "zimage_lora_a"
_write_zimage_lora(lora_a_dir, v_scale=8.0)
payload_a = _basic_payload()
payload_a["lora"] = {"name": "a", "path": str(lora_a_dir), "scale": 64.0}
img_a = _post_images(omni_server, payload_a)
a_slice = _image_blue_tail_slice(img_a)
_assert_slice_diff(a_slice, base_slice, label="lora_a_vs_base")
a_vs_base = float(np.abs(a_slice - base_slice).mean())
# Adapter B: apply delta to K slice only (should differ from adapter A).
lora_b_dir = tmp_path / "zimage_lora_b"
_write_zimage_lora(lora_b_dir, k_scale=4.0)
payload_b = _basic_payload()
payload_b["lora"] = {"name": "b", "path": str(lora_b_dir), "scale": 64.0}
img_b = _post_images(omni_server, payload_b)
b_slice = _image_blue_tail_slice(img_b)
_assert_slice_diff(b_slice, base_slice, label="lora_b_vs_base")
_assert_slice_diff(b_slice, a_slice, label="lora_b_vs_lora_a")
b_vs_base = float(np.abs(b_slice - base_slice).mean())
b_vs_a = float(np.abs(b_slice - a_slice).mean())
# Ensure switching back to no-LoRA restores the base output.
base_img_2 = _post_images(omni_server, _basic_payload())
base_slice_2 = _image_blue_tail_slice(base_img_2)
_, base_reset_mean = _slice_diff_stats(base_slice_2, base_slice)
_assert_slice_close(
base_slice_2,
base_slice,
label="base_after_reset",
base_max=base_ref_max,
base_mean=base_ref_mean,
)
# Ensure LoRA effects are clearly above the baseline drift.
min_delta = max(base_reset_mean + 1.0, 1.5)
assert a_vs_base > min_delta, f"lora_a_vs_base drift too small: {a_vs_base} <= {min_delta}"
assert b_vs_base > min_delta, f"lora_b_vs_base drift too small: {b_vs_base} <= {min_delta}"
assert b_vs_a > min_delta, f"lora_b_vs_lora_a drift too small: {b_vs_a} <= {min_delta}"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
E2E Online tests for Qwen3-Omni model with video input and audio output.
"""
import os
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
import concurrent.futures
import threading
import time
from pathlib import Path
import openai
import pytest
from tests.conftest import (
OmniServer,
convert_audio_to_text,
cosine_similarity_text,
dummy_messages_from_mix_data,
generate_synthetic_audio,
generate_synthetic_image,
generate_synthetic_video,
merge_base64_and_convert_to_text,
modify_stage_config,
)
from vllm_omni.platforms import current_omni_platform
models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]
def get_default_config():
return str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_ci.yaml")
def get_chunk_config():
path = modify_stage_config(
get_default_config(),
updates={
"async_chunk": True,
"stage_args": {
0: {
"engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk"
},
1: {
"engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk"
},
},
},
deletes={"stage_args": {2: ["custom_process_input_func"]}},
)
return path
CHUNK_CONFIG_PATH = get_chunk_config()
# CI stage config for 2xH100-80G GPUs or AMD GPU MI325
if current_omni_platform.is_rocm():
# ROCm stage config optimized for MI325 GPU
stage_configs = [str(Path(__file__).parent / "stage_configs" / "rocm" / "qwen3_omni_ci.yaml")]
else:
stage_configs = [get_default_config(), CHUNK_CONFIG_PATH]
# Create parameter combinations for model and stage config
test_params = [(model, stage_config) for model in models for stage_config in stage_configs]
_omni_server_lock = threading.Lock()
@pytest.fixture(scope="module")
def omni_server(request):
"""Start vLLM-Omni server as a subprocess with actual model weights.
Uses module scope so the server starts only once per test module.
Multi-stage initialization can take 10-20+ minutes.
"""
with _omni_server_lock:
model, stage_config_path = request.param
print(f"Starting OmniServer with model: {model}")
with OmniServer(model, ["--stage-configs-path", stage_config_path, "--stage-init-timeout", "120"]) as server:
print("OmniServer started successfully")
yield server
print("OmniServer stopping...")
print("OmniServer stopped")
@pytest.fixture
def client(omni_server):
"""OpenAI client for the running vLLM-Omni server."""
return openai.OpenAI(
base_url=f"http://{omni_server.host}:{omni_server.port}/v1",
api_key="EMPTY",
)
def get_system_prompt():
return {
"role": "system",
"content": [
{
"type": "text",
"text": (
"You are Qwen, a virtual human developed by the Qwen Team, "
"Alibaba Group, capable of perceiving auditory and visual inputs, "
"as well as generating text and speech."
),
}
],
}
def dummy_messages_from_video_data(
video_data_url: str,
content_text: str = "Describe the video briefly.",
):
"""Create messages with video data URL for OpenAI API."""
return [
get_system_prompt(),
{
"role": "user",
"content": [
{"type": "video_url", "video_url": {"url": video_data_url}},
{"type": "text", "text": content_text},
],
},
]
def get_prompt(prompt_type="text_only"):
prompts = {
"text_only": "What is the capital of China? Answer in 20 words.",
"mix": "What is recited in the audio? What is in this image? Describe the video briefly.",
}
return prompts.get(prompt_type, prompts["text_only"])
def get_max_batch_size(size_type="few"):
batch_sizes = {"few": 5, "medium": 100, "large": 256}
return batch_sizes.get(size_type, 5)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_mix_to_text_audio_001(client: openai.OpenAI, omni_server, request) -> None:
"""
Test multi-modal input processing and text/audio output generation via OpenAI API.
Deploy Setting: default yaml
Input Modal: text + audio + video + image
Output Modal: text + audio
Input Setting: stream=True
Datasets: single request
"""
# Test single completion
e2e_list = list()
video_data_url = f"data:video/mp4;base64,{generate_synthetic_video(224, 224, 300)['base64']}"
image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}"
audio_data_url = f"data:audio/wav;base64,{generate_synthetic_audio(5, 1)['base64']}"
messages = dummy_messages_from_mix_data(
system_prompt=get_system_prompt(),
video_data_url=video_data_url,
image_data_url=image_data_url,
audio_data_url=audio_data_url,
content_text=get_prompt("mix"),
)
# Test single completion
start_time = time.perf_counter()
chat_completion = client.chat.completions.create(model=omni_server.model, messages=messages, stream=True)
text_content = ""
audio_data = []
for chunk in chat_completion:
for choice in chunk.choices:
if hasattr(choice, "delta"):
content = getattr(choice.delta, "content", None)
else:
content = None
modality = getattr(chunk, "modality", None)
if modality == "audio" and content:
audio_data.append(content)
elif modality == "text" and content:
# Text chunk - accumulate text content
text_content += content if content else ""
# Verify E2E
current_e2e = time.perf_counter() - start_time
print(f"the request e2e is: {current_e2e}")
# TODO: Verify the E2E latency after confirmation baseline.
e2e_list.append(current_e2e)
print(f"the avg e2e is: {sum(e2e_list) / len(e2e_list)}")
# Verify all completions succeeded
assert len(audio_data) > 0, "No audio output is generated"
# Verify text output success
assert text_content is not None and len(text_content) >= 2, "No text output is generated"
assert any(
keyword in text_content.lower() for keyword in ["square", "quadrate", "sphere", "globe", "circle", "round"]
), "The output does not contain any of the keywords."
# Verify text output same as audio output
audio_content = merge_base64_and_convert_to_text(audio_data)
print(f"text content is: {text_content}")
print(f"audio content is: {audio_content}")
similarity = cosine_similarity_text(audio_content.lower(), text_content.lower())
print(f"similarity is: {similarity}")
assert similarity > 0.9, "The audio content is not the same as the text"
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_text_to_text_audio_001(client: openai.OpenAI, omni_server) -> None:
"""
Test text input processing and text/audio output generation via OpenAI API.
Deploy Setting: default yaml
Input Modal: text
Output Modal: text + audio
Datasets: few requests
"""
num_concurrent_requests = get_max_batch_size()
messages = dummy_messages_from_mix_data(system_prompt=get_system_prompt(), content_text=get_prompt())
e2e_list = list()
with concurrent.futures.ThreadPoolExecutor(max_workers=num_concurrent_requests) as executor:
# Submit multiple completion requests concurrently
futures = [
executor.submit(client.chat.completions.create, model=omni_server.model, messages=messages)
for _ in range(num_concurrent_requests)
]
start_time = time.perf_counter()
# Wait for all requests to complete and collect results
chat_completions = list()
for future in concurrent.futures.as_completed(futures):
chat_completions.append(future.result())
# Verify E2E
current_e2e = time.perf_counter() - start_time
print(f"the request e2e is: {current_e2e}")
# TODO: Verify the E2E latency after confirmation baseline.
e2e_list.append(current_e2e)
print(f"the avg e2e is: {sum(e2e_list) / len(e2e_list)}")
# Verify all completions succeeded
assert len(chat_completions) == num_concurrent_requests, "Not all requests succeeded."
for chat_completion in chat_completions:
# Verify audio output success
audio_data = None
text_content = None
for choice in chat_completion.choices:
if choice.message.audio is not None:
audio_message = choice.message
audio_data = audio_message.audio.data
assert audio_message.audio.expires_at > time.time(), "The generated audio has expired."
if choice.message.content is not None:
# Verify text output success
text_content = choice.message.content
assert "beijing" in text_content.lower(), "The output do not contain keywords."
# Verify text output same as audio output
audio_content = convert_audio_to_text(audio_data)
print(f"text content is: {text_content}")
print(f"audio content is: {audio_content}")
similarity = cosine_similarity_text(audio_content.lower(), text_content.lower())
print(f"similarity is: {similarity}")
assert similarity > 0.9, "The audio content is not the same as the text"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
E2E Online tests for Qwen3-Omni model.
"""
import concurrent.futures
import os
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
import time
from pathlib import Path
import openai
import pytest
from tests.conftest import (
OmniServer,
convert_audio_to_text,
cosine_similarity_text,
dummy_messages_from_mix_data,
generate_synthetic_audio,
generate_synthetic_image,
modify_stage_config,
)
models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]
# CI stage config for 2*H100-80G GPUs
stage_configs = [str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_ci.yaml")]
# Create parameter combinations for model and stage config
test_params = [(model, stage_config) for model in models for stage_config in stage_configs]
def client(omni_server):
"""OpenAI client for the running vLLM-Omni server."""
return openai.OpenAI(
base_url=f"http://{omni_server.host}:{omni_server.port}/v1",
api_key="EMPTY",
)
def get_system_prompt():
return {
"role": "system",
"content": [
{
"type": "text",
"text": (
"You are Qwen, a virtual human developed by the Qwen Team, "
"Alibaba Group, capable of perceiving auditory and visual inputs, "
"as well as generating text and speech."
),
}
],
}
def get_prompt(prompt_type="text_only"):
prompts = {
"text_only": "What is the capital of China?",
"mix": "What is recited in the audio? What is in this image? Describe the video briefly.",
}
return prompts.get(prompt_type, prompts["text_only"])
def get_max_batch_size(size_type="few"):
batch_sizes = {"few": 5, "medium": 100, "large": 256}
return batch_sizes.get(size_type, 5)
def get_deploy_config(deploy_type="TP1"):
result = {
"TP1": {
"stage_args": {
0: {
"engine_args.gpu_memory_utilization": 0.95,
"engine_args.tensor_parallel_size": 1,
"runtime.devices": "0",
},
2: {"runtime.devices": "1"},
}
}
}
return result.get(deploy_type, result["TP1"])
@pytest.mark.parametrize("test_config", test_params)
def test_text_to_text_001(test_config: tuple[str, str]) -> None:
"""Test processing text, generating text output via OpenAI API."""
model, stage_config_path = test_config
with OmniServer(model, ["--stage-configs-path", stage_config_path, "--stage-init-timeout", "90"]) as server:
messages = dummy_messages_from_mix_data(system_prompt=get_system_prompt(), content_text=get_prompt())
# Test single completion
api_client = client(server)
start_time = time.perf_counter()
chat_completion = api_client.chat.completions.create(
model=server.model, messages=messages, max_tokens=20, modalities=["text"]
)
# Verify E2E
print(f"the request e2e is: {time.perf_counter() - start_time}")
# TODO: Verify the E2E latency after confirmation baseline.
# Verify only output text
assert len(chat_completion.choices) == 1, "The generated content includes more than just text."
# Verify text output success
text_choice = chat_completion.choices[0]
assert text_choice.message.content is not None, "No text output is generated"
assert chat_completion.usage.completion_tokens <= 20, "The output length exceeds the requested max_tokens."
assert "beijing" in text_choice.message.content.lower(), "The output do not contain keywords."
@pytest.mark.parametrize("test_config", test_params)
def test_audio_to_text_001(test_config: tuple[str, str]) -> None:
"""Test processing text, generating text output via OpenAI API."""
model, stage_config_path = test_config
deploy_config = get_deploy_config()
deploy_config[0]["default_sampling_params.ignore_eos"] = True
stage_config_path = modify_stage_config(stage_config_path, deploy_config)
with OmniServer(model, ["--stage-configs-path", stage_config_path, "--stage-init-timeout", "90"]) as server:
audio_data_url = f"data:audio/wav;base64,{generate_synthetic_audio(1, 1)['base64']}"
messages = dummy_messages_from_mix_data(audio_data_url=audio_data_url)
# Test single completion
api_client = client(server)
start_time = time.perf_counter()
chat_completion = api_client.chat.completions.create(
model=server.model, messages=messages, max_tokens=200, modalities=["text"]
)
# Verify only output text
assert len(chat_completion.choices) == 1, "The generated content includes more than just text."
# Verify text output success
text_choice = chat_completion.choices[0]
assert text_choice.message.content is not None, "No text output is generated"
assert chat_completion.usage.completion_tokens == 200, (
"The output length differs from the requested max_tokens."
)
# Verify E2E
print(f"the request e2e is: {time.perf_counter() - start_time}")
# TODO: Verify the E2E latency after confirmation baseline.
@pytest.mark.parametrize("test_config", test_params)
def test_audio_to_text_audio_001(test_config: tuple[str, str]) -> None:
"""Test processing text, generating audio output via OpenAI API."""
model, stage_config_path = test_config
num_concurrent_requests = get_max_batch_size()
stage_config_path = modify_stage_config(
stage_config_path,
{
"stage_args": {
0: {"runtime.max_batch_size": num_concurrent_requests},
1: {"runtime.max_batch_size": num_concurrent_requests},
}
},
)
with OmniServer(model, ["--stage-configs-path", stage_config_path, "--stage-init-timeout", "90"]) as server:
audio_data_url = []
for _ in range(5):
audio_data_url.append(f"data:audio/wav;base64,{generate_synthetic_audio(1, 5)['base64']}")
messages = dummy_messages_from_mix_data(audio_data_url=audio_data_url)
# Test single completion
api_client = client(server)
e2e_list = list()
with concurrent.futures.ThreadPoolExecutor(max_workers=num_concurrent_requests) as executor:
# Submit multiple completion requests concurrently
futures = [
executor.submit(api_client.chat.completions.create, model=server.model, messages=messages)
for _ in range(num_concurrent_requests)
]
start_time = time.perf_counter()
# Wait for all requests to complete and collect results
chat_completions = list()
for future in concurrent.futures.as_completed(futures):
chat_completions.append(future.result())
# Verify E2E
current_e2e = time.perf_counter() - start_time
print(f"the request e2e is: {current_e2e}")
# TODO: Verify the E2E latency after confirmation baseline.
e2e_list.append(current_e2e)
print(f"the avg e2e is: {sum(e2e_list) / len(e2e_list)}")
# Verify all completions succeeded
assert len(chat_completions) == num_concurrent_requests, "Not all requests succeeded."
for chat_completion in chat_completions:
# Verify audio output success
audio_message = chat_completion.choices[1].message
audio_data = audio_message.audio.data
assert audio_data is not None, "No audio output is generated"
assert audio_message.audio.expires_at > time.time(), "The generated audio has expired."
# Verify text output success
text_choice = chat_completion.choices[0]
text_content = text_choice.message.content
assert text_choice.message.content is not None, "No text output is generated"
# Verify text output same as audio output
audio_content = convert_audio_to_text(audio_data)
print(f"text content is: {text_content}")
print(f"audio content is: {audio_content}")
similarity = cosine_similarity_text(audio_content, text_content)
print(f"similarity between audio and text is: {similarity}")
assert similarity > 0.9, "The audio content is not the same as the text"
@pytest.mark.parametrize("test_config", test_params)
def test_image_to_text_001(test_config: tuple[str, str]) -> None:
"""Test processing text, generating text output via OpenAI API."""
model, stage_config_path = test_config
deploy_config = get_deploy_config()
stage_config_path = modify_stage_config(stage_config_path, deploy_config)
with OmniServer(model, ["--stage-configs-path", stage_config_path, "--stage-init-timeout", "90"]) as server:
image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}"
messages = dummy_messages_from_mix_data(image_data_url=image_data_url)
# Test single completion
api_client = client(server)
start_time = time.perf_counter()
chat_completion = api_client.chat.completions.create(
model=server.model, messages=messages, max_tokens=100, modalities=["text"]
)
# Verify E2E
print(f"the request e2e is: {time.perf_counter() - start_time}")
# TODO: Verify the E2E latency after confirmation baseline.
# Verify only output text
assert len(chat_completion.choices) == 1, "The generated content includes more than just text."
# Verify text output success
text_choice = chat_completion.choices[0]
text_content = text_choice.message.content
assert text_content is not None, "No text output is generated"
assert chat_completion.usage.completion_tokens <= 100, "The output length exceeds the requested max_tokens."
assert "square" in text_content.lower(), "The output do not contain keywords."
@pytest.mark.parametrize("test_config", test_params)
def test_image_to_text_audio_001(test_config: tuple[str, str]) -> None:
"""Test processing text, generating audio output via OpenAI API."""
model, stage_config_path = test_config
num_concurrent_requests = 5
stage_config_path = modify_stage_config(
stage_config_path,
{
"stage_args": {
0: {"runtime.max_batch_size": num_concurrent_requests},
1: {"runtime.max_batch_size": num_concurrent_requests},
}
},
)
with OmniServer(model, ["--stage-configs-path", stage_config_path, "--stage-init-timeout", "90"]) as server:
image_data_url = []
for _ in range(4):
image_data_url.append(f"data:image/jpeg;base64,{generate_synthetic_image(1280, 720)['base64']}")
messages = dummy_messages_from_mix_data(image_data_url=image_data_url)
# Create the API client and submit concurrent completion requests
api_client = client(server)
e2e_list = list()
with concurrent.futures.ThreadPoolExecutor(max_workers=num_concurrent_requests) as executor:
# Submit multiple completion requests concurrently
futures = [
executor.submit(
api_client.chat.completions.create,
model=server.model,
messages=messages,
)
for _ in range(num_concurrent_requests)
]
start_time = time.perf_counter()
# Wait for all requests to complete and collect results
chat_completions = list()
for future in concurrent.futures.as_completed(futures):
chat_completions.append(future.result())
# Verify E2E
current_e2e = time.perf_counter() - start_time
print(f"the request e2e is: {current_e2e}")
# TODO: Verify the E2E latency after confirmation baseline.
e2e_list.append(current_e2e)
print(f"the avg e2e is: {sum(e2e_list) / len(e2e_list)}")
# Verify all completions succeeded
assert len(chat_completions) == num_concurrent_requests, "Not all requests succeeded."
for chat_completion in chat_completions:
# Verify audio output success
audio_message = chat_completion.choices[1].message
audio_data = audio_message.audio.data
assert audio_data is not None, "No audio output is generated"
assert audio_message.audio.expires_at > time.time(), "The generated audio has expired."
# Verify text output success
text_choice = chat_completion.choices[0]
text_content = text_choice.message.content
assert text_content is not None, "No text output is generated"
assert "square" in text_content.lower(), "The output do not contain keywords."
# Verify text output same as audio output
audio_content = convert_audio_to_text(audio_data)
print(f"text content is: {text_content}")
print(f"audio content is: {audio_content}")
similarity = cosine_similarity_text(audio_content, text_content)
print(f"similarity between audio and text is: {similarity}")
assert similarity > 0.9, "The audio content does not match the text content"
# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
# Stage 0: Thinker (multimodal understanding + text generation)
# Stage 1: Talker (text embeddings → RVQ codec codes)
# Stage 2: Code2Wav (RVQ codec codes → audio waveform)
# The following config has been verified on 2x H100-80G GPUs.
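#
# Data flow: stage 0 feeds stage 1 via engine_input_source/thinker2talker, and stage 1 feeds
# stage 2 via talker2code2wav; stage 0 emits the final text and stage 2 the final audio
# (final_output/final_output_type). The system tests in this repo pass this file to the server
# through the --stage-configs-path flag and may override per-stage runtime.max_batch_size via
# modify_stage_config before launching OmniServer.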
stage_args:
- stage_id: 0
runtime:
devices: "0"
max_batch_size: 5
engine_args:
model_stage: thinker
model_arch: Qwen3OmniMoeForConditionalGeneration
worker_type: ar
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
gpu_memory_utilization: 0.9
enforce_eager: false
trust_remote_code: true
engine_output_type: latent # Output hidden states for talker
distributed_executor_backend: "mp"
max_num_batched_tokens: 32768
max_model_len: 32768
enable_prefix_caching: false
hf_config_name: thinker_config
tensor_parallel_size: 1
final_output: true
final_output_type: text
is_comprehension: true
default_sampling_params:
temperature: 0.4
top_p: 0.9
top_k: 1
max_tokens: 100
seed: 42
ignore_eos: False
detokenize: True
repetition_penalty: 1.05
- stage_id: 1
runtime:
devices: "1"
max_batch_size: 5
engine_args:
model_stage: talker
model_arch: Qwen3OmniMoeForConditionalGeneration
worker_type: ar
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
gpu_memory_utilization: 0.6
enforce_eager: false
trust_remote_code: true
engine_output_type: latent # Output codec codes for code2wav
enable_prefix_caching: false
max_num_batched_tokens: 32768
max_model_len: 32768
distributed_executor_backend: "mp"
hf_config_name: talker_config
engine_input_source: [0]
custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
default_sampling_params:
temperature: 0.9
top_k: 50
max_tokens: 1000
seed: 42
detokenize: False
repetition_penalty: 1.05
stop_token_ids: [2150]
- stage_id: 2
runtime:
devices: "1"
max_batch_size: 1
engine_args:
model_stage: code2wav
model_arch: Qwen3OmniMoeForConditionalGeneration
worker_type: generation
scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
enforce_eager: true
trust_remote_code: true
enable_prefix_caching: false
engine_output_type: audio # Final output: audio waveform
gpu_memory_utilization: 0.1
distributed_executor_backend: "mp"
max_num_batched_tokens: 100000
hf_config_name: thinker_config
async_scheduling: false
engine_input_source: [1]
custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
final_output: true
final_output_type: audio
default_sampling_params:
temperature: 0.0
top_p: 1.0
top_k: -1
max_tokens: 2000
seed: 42
detokenize: True
repetition_penalty: 1.1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for async image generation API endpoints.
This module contains unit tests and integration tests (with mocking) for the
OpenAI-compatible async text-to-image generation API endpoints in api_server.py.
"""
import base64
import io
from argparse import Namespace
from unittest.mock import AsyncMock, Mock
import pytest
from fastapi.testclient import TestClient
from PIL import Image
from vllm import SamplingParams
from vllm_omni.entrypoints.openai.image_api_utils import (
encode_image_base64,
parse_size,
)
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
# Unit Tests
def test_parse_size_valid():
"""Test size parsing with valid inputs"""
assert parse_size("1024x1024") == (1024, 1024)
assert parse_size("512x768") == (512, 768)
assert parse_size("256x256") == (256, 256)
assert parse_size("1792x1024") == (1792, 1024)
assert parse_size("1024x1792") == (1024, 1792)
def test_parse_size_invalid():
"""Test size parsing with invalid inputs"""
with pytest.raises(ValueError, match="Invalid size format"):
parse_size("invalid")
with pytest.raises(ValueError, match="Invalid size format"):
parse_size("1024")
with pytest.raises(ValueError, match="Invalid size format"):
parse_size("1024x")
with pytest.raises(ValueError, match="Invalid size format"):
parse_size("x1024")
def test_parse_size_negative():
"""Test size parsing with negative or zero dimensions"""
with pytest.raises(ValueError, match="positive integers"):
parse_size("0x1024")
with pytest.raises(ValueError, match="positive integers"):
parse_size("1024x0")
with pytest.raises(ValueError):
parse_size("-1024x1024")
def test_parse_size_edge_cases():
"""Test size parsing with edge cases like empty strings and non-integers"""
# Empty string
with pytest.raises(ValueError, match="non-empty string"):
parse_size("")
# Non-integer dimensions
with pytest.raises(ValueError, match="must be integers"):
parse_size("abc x def")
with pytest.raises(ValueError, match="must be integers"):
parse_size("1024.5x768.5")
# Missing separator (user might forget 'x')
with pytest.raises(ValueError, match="separator"):
parse_size("1024 1024")
def test_encode_image_base64():
"""Test image encoding to base64"""
# Create a simple test image
img = Image.new("RGB", (64, 64), color="red")
b64_str = encode_image_base64(img)
# Should be valid base64
assert isinstance(b64_str, str)
assert len(b64_str) > 0
# Should decode back to PNG
decoded = base64.b64decode(b64_str)
decoded_img = Image.open(io.BytesIO(decoded))
# Verify properties
assert decoded_img.size == (64, 64)
assert decoded_img.format == "PNG"
# Integration Tests (with mocking)
class MockGenerationResult:
"""Mock result object from AsyncOmniDiffusion.generate()"""
def __init__(self, images):
self.images = images
class FakeAsyncOmni:
"""Fake AsyncOmni that yields a single diffusion output."""
def __init__(self):
self.stage_list = ["llm", "diffusion"]
self.default_sampling_params_list = [SamplingParams(temperature=0.1), OmniDiffusionSamplingParams()]
self.captured_sampling_params_list = None
self.captured_prompt = None
async def generate(self, prompt, request_id, sampling_params_list):
self.captured_sampling_params_list = sampling_params_list
self.captured_prompt = prompt
images = [Image.new("RGB", (64, 64), color="green")]
yield MockGenerationResult(images)
@pytest.fixture
def mock_async_diffusion():
"""Mock AsyncOmniDiffusion instance that returns fake images"""
mock = Mock()
mock.is_running = True # For health endpoint
mock.check_health = AsyncMock() # For LLM mode health check
async def generate(**kwargs):
# Return n PIL images wrapped in result object
print("!!!!!!!!!!!!!!!!!!!!! kwargs", kwargs)
n = kwargs["sampling_params_list"][0].num_outputs_per_prompt
mock.captured_sampling_params_list = kwargs["sampling_params_list"]
mock.captured_prompt = kwargs["prompt"]
images = [Image.new("RGB", (64, 64), color="blue") for _ in range(n)]
return MockGenerationResult(images)
mock.generate = AsyncMock(side_effect=generate)
return mock
@pytest.fixture
def test_client(mock_async_diffusion):
"""Create test client with mocked async diffusion engine"""
from fastapi import FastAPI
from vllm_omni.entrypoints.openai.api_server import router
app = FastAPI()
app.include_router(router)
# Set up app state with diffusion engine
app.state.engine_client = mock_async_diffusion
app.state.diffusion_engine = mock_async_diffusion # Also set for health endpoint
app.state.stage_configs = [{"stage_type": "diffusion"}]
app.state.diffusion_model_name = "Qwen/Qwen-Image" # For models endpoint
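# default_sampling_params is a JSON string mapping stage index -> default diffusion params;
# with a single diffusion stage the key is "0" (the two-stage AsyncOmni fixture below uses "1").
# The values here (num_inference_steps=4, guidance_scale=7.5) are what the /v1/images/edits
# default tests expect when a request omits them.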
app.state.args = Namespace(
default_sampling_params='{"0": {"num_inference_steps":4, "guidance_scale":7.5}}',
max_generated_image_size=4096, # 64*64
)
return TestClient(app)
@pytest.fixture
def async_omni_test_client():
"""Create test client with mocked AsyncOmni engine."""
from fastapi import FastAPI
from vllm_omni.entrypoints.openai.api_server import router
app = FastAPI()
app.include_router(router)
app.state.engine_client = FakeAsyncOmni()
app.state.stage_configs = [{"stage_type": "llm"}, {"stage_type": "diffusion"}]
app.state.args = Namespace(
default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5}}',
max_generated_image_size=4096, # 64*64
)
return TestClient(app)
def test_health_endpoint(test_client):
"""Test health check endpoint for diffusion mode"""
response = test_client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "healthy"
def test_health_endpoint_no_engine():
"""Test health check endpoint when no engine is initialized"""
from fastapi import FastAPI
from vllm_omni.entrypoints.openai.api_server import router
app = FastAPI()
app.include_router(router)
# Don't set any engine
client = TestClient(app)
response = client.get("/health")
assert response.status_code == 503
data = response.json()
assert data["status"] == "unhealthy"
def test_models_endpoint(test_client):
"""Test /v1/models endpoint for diffusion mode"""
response = test_client.get("/v1/models")
assert response.status_code == 200
data = response.json()
assert data["object"] == "list"
assert len(data["data"]) == 1
assert data["data"][0]["id"] == "Qwen/Qwen-Image"
assert data["data"][0]["object"] == "model"
def test_models_endpoint_no_engine():
"""Test /v1/models endpoint when no engine is initialized"""
from fastapi import FastAPI
from vllm_omni.entrypoints.openai.api_server import router
app = FastAPI()
app.include_router(router)
# Don't set any engine
client = TestClient(app)
response = client.get("/v1/models")
assert response.status_code == 200
data = response.json()
assert data["object"] == "list"
assert len(data["data"]) == 0
def test_generate_single_image(test_client):
"""Test generating a single image"""
response = test_client.post(
"/v1/images/generations",
json={
"prompt": "a cat",
"n": 1,
"size": "1024x1024",
},
)
assert response.status_code == 200
data = response.json()
# Check response structure
assert "created" in data
assert isinstance(data["created"], int)
assert "data" in data
assert len(data["data"]) == 1
assert "b64_json" in data["data"][0]
# Verify image can be decoded
img_bytes = base64.b64decode(data["data"][0]["b64_json"])
img = Image.open(io.BytesIO(img_bytes))
assert img.size == (64, 64) # Our mock returns 64x64 images
def test_generate_images_async_omni_sampling_params(async_omni_test_client):
"""Test AsyncOmni path uses per-stage sampling params."""
response = async_omni_test_client.post(
"/v1/images/generations",
json={
"prompt": "a cat",
"n": 2,
"size": "256x256",
"seed": 7,
},
)
assert response.status_code == 200
engine = async_omni_test_client.app.state.engine_client
captured = engine.captured_sampling_params_list
assert captured is not None
assert len(captured) == 2
assert captured[0].temperature == 0.1
assert captured[1].num_outputs_per_prompt == 2
assert captured[1].height == 256
assert captured[1].width == 256
assert captured[1].seed == 7
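# In the two-stage path, captured[0] is the LLM stage's SamplingParams (FakeAsyncOmni's default
# temperature) and captured[1] is the diffusion stage's OmniDiffusionSamplingParams built from
# the request body (n, size, seed).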
def test_generate_multiple_images(test_client):
"""Test generating multiple images"""
response = test_client.post(
"/v1/images/generations",
json={
"prompt": "a dog",
"n": 3,
"size": "512x512",
},
)
assert response.status_code == 200
data = response.json()
assert len(data["data"]) == 3
# All images should be valid
for img_data in data["data"]:
assert "b64_json" in img_data
img_bytes = base64.b64decode(img_data["b64_json"])
img = Image.open(io.BytesIO(img_bytes))
assert img.format == "PNG"
def test_with_negative_prompt(test_client):
"""Test with negative prompt"""
response = test_client.post(
"/v1/images/generations",
json={
"prompt": "beautiful landscape",
"negative_prompt": "blurry, low quality",
"size": "1024x1024",
},
)
assert response.status_code == 200
def test_with_seed(test_client):
"""Test with seed for reproducibility"""
response = test_client.post(
"/v1/images/generations",
json={
"prompt": "a tree",
"seed": 42,
"size": "1024x1024",
},
)
assert response.status_code == 200
def test_with_custom_parameters(test_client):
"""Test with custom diffusion parameters"""
response = test_client.post(
"/v1/images/generations",
json={
"prompt": "a mountain",
"size": "1024x1024",
"num_inference_steps": 100,
"true_cfg_scale": 5.5,
"seed": 123,
},
)
assert response.status_code == 200
def test_invalid_size(test_client):
"""Test with invalid size parameter - rejected by Pydantic"""
response = test_client.post(
"/v1/images/generations",
json={
"prompt": "a cat",
"size": "invalid",
},
)
# Pydantic validation errors return 422 (Unprocessable Entity)
# "invalid" has no "x" so Pydantic rejects it
assert response.status_code == 422
# Check error detail contains size validation message
detail = str(response.json()["detail"])
assert "size" in detail.lower() or "invalid" in detail.lower()
def test_invalid_size_parse_error(test_client):
"""Test with malformed size - passes Pydantic but fails parse_size()"""
response = test_client.post(
"/v1/images/generations",
json={
"prompt": "a cat",
"size": "1024x", # Has "x" so Pydantic accepts, but parse_size() rejects
},
)
# parse_size() raises ValueError → endpoint converts to 400 (Bad Request)
assert response.status_code == 400
detail = str(response.json()["detail"])
assert "size" in detail.lower() or "invalid" in detail.lower()
def test_missing_prompt(test_client):
"""Test with missing required prompt field"""
response = test_client.post(
"/v1/images/generations",
json={
"size": "1024x1024",
},
)
# Pydantic validation error
assert response.status_code == 422
def test_invalid_n_parameter(test_client):
"""Test with invalid n parameter (out of range)"""
# n < 1
response = test_client.post(
"/v1/images/generations",
json={
"prompt": "a cat",
"n": 0,
},
)
assert response.status_code == 422
# n > 10
response = test_client.post(
"/v1/images/generations",
json={
"prompt": "a cat",
"n": 11,
},
)
assert response.status_code == 422
def test_url_response_format_not_supported(test_client):
"""Test that URL format returns error"""
response = test_client.post(
"/v1/images/generations",
json={
"prompt": "a cat",
"response_format": "url",
},
)
# Pydantic validation errors return 422 (Unprocessable Entity)
assert response.status_code == 422
# Check error mentions response_format or b64_json
detail = str(response.json()["detail"])
assert "b64_json" in detail.lower() or "response" in detail.lower()
def test_model_not_loaded():
"""Test error when diffusion engine is not initialized"""
from fastapi import FastAPI
from vllm_omni.entrypoints.openai.api_server import router
app = FastAPI()
app.include_router(router)
# Don't set diffusion_engine to simulate uninitialized state
app.state.diffusion_engine = None
client = TestClient(app)
response = client.post(
"/v1/images/generations",
json={
"prompt": "a cat",
},
)
assert response.status_code == 503
assert "not initialized" in response.json()["detail"].lower()
def test_different_image_sizes(test_client):
"""Test various valid image sizes"""
sizes = ["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]
for size in sizes:
response = test_client.post(
"/v1/images/generations",
json={
"prompt": "a test image",
"size": size,
},
)
assert response.status_code == 200, f"Failed for size {size}"
def test_parameter_validation():
"""Test Pydantic model validation"""
from vllm_omni.entrypoints.openai.protocol.images import ImageGenerationRequest
# Valid request - optional parameters default to None
req = ImageGenerationRequest(prompt="test")
assert req.prompt == "test"
assert req.n == 1
assert req.model is None
assert req.size is None # Engine will use model defaults
assert req.num_inference_steps is None # Engine will use model defaults
assert req.true_cfg_scale is None # Engine will use model defaults
# Invalid num_inference_steps (out of range)
with pytest.raises(ValueError):
ImageGenerationRequest(prompt="test", num_inference_steps=0)
with pytest.raises(ValueError):
ImageGenerationRequest(prompt="test", num_inference_steps=201)
# Invalid guidance_scale (out of range)
with pytest.raises(ValueError):
ImageGenerationRequest(prompt="test", guidance_scale=-1.0)
with pytest.raises(ValueError):
ImageGenerationRequest(prompt="test", guidance_scale=21.0)
# Pass-Through Tests
def test_parameters_passed_through(test_client, mock_async_diffusion):
"""Verify all parameters passed through without modification"""
response = test_client.post(
"/v1/images/generations",
json={
"prompt": "test",
"num_inference_steps": 100,
"guidance_scale": 7.5,
"true_cfg_scale": 3.0,
"seed": 42,
},
)
assert response.status_code == 200
# Ensure generate() was called exactly once
mock_async_diffusion.generate.assert_awaited_once()
call_kwargs = mock_async_diffusion.generate.call_args[1]["sampling_params_list"][0]
assert call_kwargs.num_inference_steps == 100
assert call_kwargs.guidance_scale == 7.5
assert call_kwargs.true_cfg_scale == 3.0
assert call_kwargs.seed == 42
def test_model_field_omitted_works(test_client):
"""Test that omitting model field works"""
response = test_client.post(
"/v1/images/generations",
json={
"prompt": "test",
"size": "1024x1024",
# model field omitted
},
)
assert response.status_code == 200
def make_test_image_bytes(size=(64, 64)) -> bytes:
"""Return PNG-encoded bytes for a solid-color test image of the given size."""
img = Image.new("RGB", size)
buf = io.BytesIO()
img.save(buf, format="PNG")
return buf.getvalue()
def test_image_edit_images_processing(async_omni_test_client):
"""Test that uploads via "image", "image[]", and base64 data URLs are all decoded into ordered PIL images."""
img_bytes_1 = make_test_image_bytes((16, 16))
img_bytes_2 = make_test_image_bytes((32, 32))
# uploadfile with image key
response = async_omni_test_client.post(
"/v1/images/edits",
files=[
("image", img_bytes_1),
("image", img_bytes_2),
],
data={"prompt": "hello world."},
)
assert response.status_code == 200
engine = async_omni_test_client.app.state.engine_client
captured_prompt = engine.captured_prompt
processed_images = captured_prompt["multi_modal_data"]["image"]
assert len(processed_images) == 2
assert isinstance(processed_images[0], Image.Image)
assert isinstance(processed_images[1], Image.Image)
assert processed_images[0].size == (16, 16)
assert processed_images[1].size == (32, 32)
# uploadfile with image[] key
response = async_omni_test_client.post(
"/v1/images/edits",
files=[
("image[]", img_bytes_2),
("image[]", img_bytes_1),
],
data={"prompt": "hello world."},
)
assert response.status_code == 200
engine = async_omni_test_client.app.state.engine_client
captured_prompt = engine.captured_prompt
processed_images = captured_prompt["multi_modal_data"]["image"]
assert len(processed_images) == 2
assert isinstance(processed_images[0], Image.Image)
assert isinstance(processed_images[1], Image.Image)
assert processed_images[0].size == (32, 32)
assert processed_images[1].size == (16, 16)
# base64 url
buf1 = io.BytesIO()
img1 = Image.new("RGB", (16, 16))
img1.save(buf1, format="PNG")
b64_1 = "data:image/png;base64," + base64.b64encode(buf1.getvalue()).decode()
buf2 = io.BytesIO()
img2 = Image.new("RGB", (24, 24))
img2.save(buf2, format="PNG")
b64_2 = "data:image/png;base64," + base64.b64encode(buf2.getvalue()).decode()
response = async_omni_test_client.post(
"/v1/images/edits",
data={
"prompt": "hello from base64",
"url": [b64_1, b64_2],
},
)
assert response.status_code == 200
processed_images = engine.captured_prompt["multi_modal_data"]["image"]
assert len(processed_images) == 2
assert isinstance(processed_images[0], Image.Image)
assert isinstance(processed_images[1], Image.Image)
assert processed_images[0].size == (16, 16)
assert processed_images[1].size == (24, 24)
def test_image_edit_parameter_pass(async_omni_test_client):
"""Test that edit parameters are forwarded to the engine and reflected in the response metadata."""
img_bytes_1 = make_test_image_bytes((16, 16))
# uploadfile with image key
response = async_omni_test_client.post(
"/v1/images/edits",
files=[("image", img_bytes_1)],
data={
"prompt": "hello world.",
"size": "16x24",
"output_format": "jpeg",
"num_inference_steps": 20,
"guidance_scale": 8.0,
"seed": 1234,
"negative_prompt": "negative",
"n": 2,
},
)
assert response.status_code == 200
engine = async_omni_test_client.app.state.engine_client
captured_prompt = engine.captured_prompt
captured_sampling_params = engine.captured_sampling_params_list[-1]
assert captured_prompt["prompt"] == "hello world."
assert captured_prompt["negative_prompt"] == "negative"
assert captured_sampling_params.num_inference_steps == 20
assert captured_sampling_params.guidance_scale == 8.0
assert captured_sampling_params.seed == 1234
assert captured_sampling_params.num_outputs_per_prompt == 2
assert captured_sampling_params.width == 16
assert captured_sampling_params.height == 24
data = response.json()
# All images should be valid
for img_data in data["data"]:
assert "b64_json" in img_data
img_bytes = base64.b64decode(img_data["b64_json"])
img = Image.open(io.BytesIO(img_bytes))
assert img.format.lower() == "jpeg"
assert data["output_format"] == "jpeg"
assert data["size"] == "16x24"
def test_image_edit_parameter_default(async_omni_test_client):
"""Test that size="auto" uses the input image dimensions and oversized requests are rejected (two-stage client)."""
img_bytes_1 = make_test_image_bytes((24, 16))
# uploadfile with image key
response = async_omni_test_client.post(
"/v1/images/edits",
files=[("image", img_bytes_1)],
data={
"prompt": "hello world.",
"size": "auto",
},
)
assert response.status_code == 200
engine = async_omni_test_client.app.state.engine_client
captured_sampling_params = engine.captured_sampling_params_list[-1]
assert captured_sampling_params.width == 24
assert captured_sampling_params.height == 16
assert captured_sampling_params.num_outputs_per_prompt == 1
assert captured_sampling_params.num_inference_steps == 4
assert captured_sampling_params.guidance_scale == 7.5
response = async_omni_test_client.post(
"/v1/images/edits",
files=[("image", img_bytes_1)],
data={
"prompt": "hello world.",
"size": "96x96",
},
)
assert response.status_code == 400
def test_image_edit_parameter_default_single_stage(test_client):
"""Test the same default-size and size-limit behavior on the single-stage diffusion client."""
img_bytes_1 = make_test_image_bytes((24, 16))
# uploadfile with image key
response = test_client.post(
"/v1/images/edits",
files=[("image", img_bytes_1)],
data={
"prompt": "hello world.",
},
)
assert response.status_code == 200
engine = test_client.app.state.engine_client
captured_sampling_params = engine.captured_sampling_params_list[0]
assert captured_sampling_params.width == 24
assert captured_sampling_params.height == 16
assert captured_sampling_params.num_outputs_per_prompt == 1
assert captured_sampling_params.num_inference_steps == 4
assert captured_sampling_params.guidance_scale == 7.5
response = test_client.post(
"/v1/images/edits",
files=[("image", img_bytes_1)],
data={
"prompt": "hello world.",
"size": "96x96",
},
)
assert response.status_code == 400
def test_image_edit_compression_jpeg(test_client):
"""Test that lower output_compression values produce smaller JPEG payloads."""
img_bytes_1 = make_test_image_bytes((16, 16))
# uploadfile with image key
response = test_client.post(
"/v1/images/edits",
files=[("image", img_bytes_1)],
data={"prompt": "hello world.", "output_format": "jpeg", "output_compression": 100},
)
assert response.status_code == 200
data = response.json()
img_bytes_100 = base64.b64decode(data["data"][0]["b64_json"])
img = Image.open(io.BytesIO(img_bytes_100))
assert img.format.lower() == "jpeg"
response = test_client.post(
"/v1/images/edits",
files=[("image", img_bytes_1)],
data={
"prompt": "hello world.",
"output_format": "jpeg",
"output_compression": 50,
},
)
assert response.status_code == 200
data = response.json()
img_bytes_50 = base64.b64decode(data["data"][0]["b64_json"])
response = test_client.post(
"/v1/images/edits",
files=[("image", img_bytes_1)],
data={
"prompt": "hello world.",
"output_format": "jpeg",
"output_compression": 10,
},
)
assert response.status_code == 200
data = response.json()
img_bytes_10 = base64.b64decode(data["data"][0]["b64_json"])
assert len(img_bytes_10) < len(img_bytes_50)
assert len(img_bytes_50) < len(img_bytes_100)
def test_image_edit_compression_png(async_omni_test_client):
"""Test that lower output_compression values produce smaller PNG payloads."""
img_bytes_1 = make_test_image_bytes((16, 16))
# uploadfile with image key
response = async_omni_test_client.post(
"/v1/images/edits",
files=[("image", img_bytes_1)],
data={"prompt": "hello world.", "output_format": "PNG", "output_compression": 100},
)
assert response.status_code == 200
data = response.json()
img_bytes_100 = base64.b64decode(data["data"][0]["b64_json"])
img = Image.open(io.BytesIO(img_bytes_100))
assert img.format.lower() == "png"
response = async_omni_test_client.post(
"/v1/images/edits",
files=[("image", img_bytes_1)],
data={
"prompt": "hello world.",
"output_format": "PNG",
"output_compression": 50,
},
)
assert response.status_code == 200
data = response.json()
img_bytes_50 = base64.b64decode(data["data"][0]["b64_json"])
response = async_omni_test_client.post(
"/v1/images/edits",
files=[("image", img_bytes_1)],
data={
"prompt": "hello world.",
"output_format": "PNG",
"output_compression": 10,
},
)
assert response.status_code == 200
data = response.json()
img_bytes_10 = base64.b64decode(data["data"][0]["b64_json"])
assert len(img_bytes_10) < len(img_bytes_50)
assert len(img_bytes_50) < len(img_bytes_100)
# SPDX-License-Identifier: Apache-2.0
"""
Unit tests for OmniOpenAIServingChat sampling params handling.
Tests that standard OpenAI API parameters (max_tokens, temperature, etc.)
are correctly applied to the comprehension stage while preserving YAML defaults.
"""
from unittest.mock import MagicMock
import pytest
from vllm.sampling_params import SamplingParams
@pytest.fixture
def mock_comprehension_stage():
"""Create a mock comprehension stage with is_comprehension=True."""
stage = MagicMock()
stage.is_comprehension = True
stage.model_stage = "comprehension"
return stage
@pytest.fixture
def mock_other_stage():
"""Create a mock non-comprehension stage."""
stage = MagicMock()
stage.is_comprehension = False
stage.model_stage = "other"
return stage
@pytest.fixture
def default_comprehension_params():
"""Default sampling params for comprehension stage (from YAML)."""
return SamplingParams(
temperature=0.4,
top_p=0.9,
top_k=1,
max_tokens=2048,
seed=42,
repetition_penalty=1.05,
)
@pytest.fixture
def default_other_params():
"""Default sampling params for non-comprehension stage (from YAML)."""
return SamplingParams(
temperature=0.9,
top_k=50,
max_tokens=4096,
seed=42,
)
@pytest.fixture
def mock_engine_client(mock_comprehension_stage, mock_other_stage, default_comprehension_params, default_other_params):
"""Create mock engine client with stage_list and default_sampling_params_list."""
engine_client = MagicMock()
engine_client.stage_list = [mock_comprehension_stage, mock_other_stage]
engine_client.default_sampling_params_list = [
default_comprehension_params,
default_other_params,
]
return engine_client
@pytest.fixture
def serving_chat(mock_engine_client):
"""Create OmniOpenAIServingChat instance with mocked dependencies."""
from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
# Create instance without calling __init__
instance = object.__new__(OmniOpenAIServingChat)
instance.engine_client = mock_engine_client
return instance
@pytest.fixture
def mock_request():
"""Create a mock request with all OpenAI sampling params set to None."""
request = MagicMock()
# OpenAI standard sampling fields
request.temperature = None
request.top_p = None
request.max_tokens = None
request.seed = None
request.stop = None
request.frequency_penalty = None
request.presence_penalty = None
return request
# =============================================================================
# Tests for _OPENAI_SAMPLING_FIELDS constant
# =============================================================================
def test_openai_sampling_fields_contains_expected_fields():
"""Test that _OPENAI_SAMPLING_FIELDS contains all expected OpenAI params."""
from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
expected_fields = {
"temperature",
"top_p",
"max_tokens",
"seed",
"stop",
"frequency_penalty",
"presence_penalty",
}
assert OmniOpenAIServingChat._OPENAI_SAMPLING_FIELDS == expected_fields
# =============================================================================
# Tests for _build_sampling_params_list_from_request
# =============================================================================
def test_preserves_yaml_defaults_when_no_request_params(serving_chat, mock_request):
"""Test that YAML defaults are preserved when request has no params."""
result = serving_chat._build_sampling_params_list_from_request(mock_request)
assert len(result) == 2
comprehension_params = result[0]
assert comprehension_params.temperature == 0.4
assert comprehension_params.top_p == 0.9
assert comprehension_params.top_k == 1 # YAML custom param preserved
assert comprehension_params.max_tokens == 2048
assert comprehension_params.seed == 42
assert comprehension_params.repetition_penalty == 1.05 # YAML custom param preserved
def test_request_temperature_overrides_yaml_default(serving_chat, mock_request):
"""Test that request temperature overrides YAML default."""
mock_request.temperature = 0.8
result = serving_chat._build_sampling_params_list_from_request(mock_request)
comprehension_params = result[0]
assert comprehension_params.temperature == 0.8 # Overridden
assert comprehension_params.seed == 42 # Preserved from YAML
assert comprehension_params.top_k == 1 # YAML custom param preserved
def test_request_top_p_overrides_yaml_default(serving_chat, mock_request):
"""Test that request top_p overrides YAML default."""
mock_request.top_p = 0.95
result = serving_chat._build_sampling_params_list_from_request(mock_request)
comprehension_params = result[0]
assert comprehension_params.top_p == 0.95 # Overridden
assert comprehension_params.temperature == 0.4 # Preserved from YAML
def test_request_max_tokens_overrides_yaml_default(serving_chat, mock_request):
"""Test that request max_tokens overrides YAML default."""
mock_request.max_tokens = 100
result = serving_chat._build_sampling_params_list_from_request(mock_request)
assert result[0].max_tokens == 100
def test_max_tokens_uses_yaml_default_when_not_specified(serving_chat, mock_request):
"""Test that max_tokens falls back to YAML default when not in request."""
result = serving_chat._build_sampling_params_list_from_request(mock_request)
assert result[0].max_tokens == 2048
def test_request_seed_overrides_yaml_default(serving_chat, mock_request):
"""Test that request seed overrides YAML default."""
mock_request.seed = 123
result = serving_chat._build_sampling_params_list_from_request(mock_request)
comprehension_params = result[0]
assert comprehension_params.seed == 123 # Overridden
assert comprehension_params.temperature == 0.4 # Preserved from YAML
def test_request_frequency_penalty_overrides(serving_chat, mock_request):
"""Test that request frequency_penalty is applied."""
mock_request.frequency_penalty = 0.5
result = serving_chat._build_sampling_params_list_from_request(mock_request)
assert result[0].frequency_penalty == 0.5
def test_request_presence_penalty_overrides(serving_chat, mock_request):
"""Test that request presence_penalty is applied."""
mock_request.presence_penalty = 0.3
result = serving_chat._build_sampling_params_list_from_request(mock_request)
assert result[0].presence_penalty == 0.3
def test_non_comprehension_stages_use_cloned_defaults(serving_chat, mock_request):
"""Test that non-comprehension stages always use cloned YAML defaults."""
mock_request.max_tokens = 50
mock_request.temperature = 0.1
result = serving_chat._build_sampling_params_list_from_request(mock_request)
other_params = result[1]
assert other_params.temperature == 0.9 # YAML default (not affected by request)
assert other_params.max_tokens == 4096 # YAML default (not affected by request)
assert other_params.top_k == 50 # YAML default
assert other_params.seed == 42 # YAML default
def test_multiple_params_override_together(serving_chat, mock_request):
"""Test that multiple request params can override together."""
mock_request.max_tokens = 200
mock_request.temperature = 0.7
mock_request.top_p = 0.85
mock_request.seed = 999
result = serving_chat._build_sampling_params_list_from_request(mock_request)
comprehension_params = result[0]
# Overridden by request
assert comprehension_params.temperature == 0.7
assert comprehension_params.top_p == 0.85
assert comprehension_params.max_tokens == 200
assert comprehension_params.seed == 999
# Preserved from YAML (not in _OPENAI_SAMPLING_FIELDS)
assert comprehension_params.top_k == 1
assert comprehension_params.repetition_penalty == 1.05
def test_yaml_custom_params_not_overridden_by_request(serving_chat, mock_request):
"""Test that YAML custom params (top_k, repetition_penalty) are not affected."""
# Even if request has these attributes, they should not override YAML
# because they're not in _OPENAI_SAMPLING_FIELDS
mock_request.top_k = 100 # Not in allowlist
mock_request.repetition_penalty = 2.0 # Not in allowlist
result = serving_chat._build_sampling_params_list_from_request(mock_request)
comprehension_params = result[0]
assert comprehension_params.top_k == 1 # YAML default preserved
assert comprehension_params.repetition_penalty == 1.05 # YAML default preserved
# =============================================================================
# Tests for _apply_request_overrides
# =============================================================================
def test_apply_request_overrides_clones_params(serving_chat, mock_request, default_comprehension_params):
"""Test that _apply_request_overrides returns a cloned object."""
result = serving_chat._apply_request_overrides(default_comprehension_params, mock_request)
assert result is not default_comprehension_params # Different object
def test_apply_request_overrides_preserves_defaults(serving_chat, mock_request, default_comprehension_params):
"""Test that _apply_request_overrides preserves defaults when request has None."""
result = serving_chat._apply_request_overrides(default_comprehension_params, mock_request)
assert result.temperature == 0.4
assert result.top_p == 0.9
assert result.seed == 42
assert result.top_k == 1 # YAML custom param
def test_apply_request_overrides_applies_values(serving_chat, mock_request, default_comprehension_params):
"""Test that _apply_request_overrides applies non-None request values."""
mock_request.temperature = 0.8
mock_request.seed = 123
result = serving_chat._apply_request_overrides(default_comprehension_params, mock_request)
assert result.temperature == 0.8 # Overridden
assert result.seed == 123 # Overridden
assert result.top_p == 0.9 # Preserved from default
assert result.top_k == 1 # YAML custom param preserved
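# A minimal sketch of the override behavior exercised above (illustrative only; the helper
# name below is hypothetical and the real logic lives on OmniOpenAIServingChat):
def _sketch_apply_request_overrides(default_params, request, allowed_fields):
    import copy
    params = copy.deepcopy(default_params)  # clone so the YAML defaults are never mutated
    for field in allowed_fields:  # only OpenAI-standard fields may override
        value = getattr(request, field, None)
        if value is not None:  # None means "not specified" -> keep the YAML default
            setattr(params, field, value)
    return params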
# =============================================================================
# Tests for _get_comprehension_stage_index
# =============================================================================
def test_get_comprehension_stage_index_finds_first_stage(mock_engine_client):
"""Test finding comprehension stage when it's at index 0."""
from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
instance = object.__new__(OmniOpenAIServingChat)
instance.engine_client = mock_engine_client
assert instance._get_comprehension_stage_index() == 0
def test_get_comprehension_stage_index_finds_second_stage():
"""Test finding comprehension stage when it's at index 1."""
from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
instance = object.__new__(OmniOpenAIServingChat)
other = MagicMock()
other.is_comprehension = False
comprehension = MagicMock()
comprehension.is_comprehension = True
instance.engine_client = MagicMock()
instance.engine_client.stage_list = [other, comprehension]
assert instance._get_comprehension_stage_index() == 1
def test_get_comprehension_stage_index_raises_when_not_found():
"""Test that ValueError is raised when no comprehension stage exists."""
from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
instance = object.__new__(OmniOpenAIServingChat)
stage1 = MagicMock()
stage1.is_comprehension = False
stage2 = MagicMock()
stage2.is_comprehension = False
instance.engine_client = MagicMock()
instance.engine_client.stage_list = [stage1, stage2]
with pytest.raises(ValueError, match="No comprehension stage"):
instance._get_comprehension_stage_index()
# tests/entrypoints/openai/test_serving_speech.py
import logging
from inspect import Signature, signature
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
import torch
from fastapi import FastAPI
from fastapi.testclient import TestClient
from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin
from vllm_omni.entrypoints.openai.protocol.audio import CreateAudio, OpenAICreateSpeechRequest
from vllm_omni.entrypoints.openai.serving_speech import (
OmniOpenAIServingSpeech,
)
from vllm_omni.outputs import OmniRequestOutput
logger = logging.getLogger(__name__)
class TestAudioMixin:
@pytest.fixture
def audio_mixin(self):
return AudioMixin()
def test_stereo_to_mono_conversion(self, audio_mixin):
stereo_tensor = np.random.rand(24000, 2).astype(np.float32)
audio_obj = CreateAudio(audio_tensor=stereo_tensor)
with (
patch.object(
audio_mixin, "_apply_speed_adjustment", side_effect=lambda tensor, speed, sr: (tensor, sr)
) as mock_speed,
patch("soundfile.write") as _,
):
audio_mixin.create_audio(audio_obj)
# Check that the tensor passed to speed adjustment is mono
mock_speed.assert_called_once()
adjusted_tensor = mock_speed.call_args[0][0]
assert len(adjusted_tensor) == 24000
@patch("librosa.effects.time_stretch")
def test_speed_adjustment(self, mock_time_stretch, audio_mixin):
mock_time_stretch.return_value = np.zeros(12000)
audio_tensor = np.random.rand(24000).astype(np.float32)
adjusted_audio, _ = audio_mixin._apply_speed_adjustment(audio_tensor, speed=2.0, sample_rate=24000)
mock_time_stretch.assert_called_with(y=audio_tensor, rate=2.0)
assert adjusted_audio.shape == (12000,)
@patch("soundfile.write")
def test_unsupported_format_fallback(self, mock_write, audio_mixin, caplog):
audio_tensor = np.random.rand(24000).astype(np.float32)
# Use a format that is not in the list of supported formats
audio_obj = CreateAudio(audio_tensor=audio_tensor, response_format="vorbis")
audio_mixin.create_audio(audio_obj)
# Should fall back to 'wav'
mock_write.assert_called_once()
write_kwargs = mock_write.call_args.kwargs
assert write_kwargs["format"] == "WAV"
def test_mono_audio_preservation(self, audio_mixin):
"""Test that mono (1D) audio tensors are processed correctly and passed to writer."""
mono_tensor = np.random.rand(24000).astype(np.float32)
audio_obj = CreateAudio(audio_tensor=mono_tensor)
with patch("soundfile.write") as mock_write:
audio_mixin.create_audio(audio_obj)
mock_write.assert_called_once()
# Verify the tensor passed to soundfile.write is the exact 1D tensor
output_tensor = mock_write.call_args[0][1]
assert output_tensor.ndim == 1
assert output_tensor.shape == (24000,)
assert np.array_equal(output_tensor, mono_tensor)
def test_stereo_audio_preservation(self, audio_mixin):
"""Test that stereo (2D) audio tensors are processed correctly and preserved."""
stereo_tensor = np.random.rand(24000, 2).astype(np.float32)
audio_obj = CreateAudio(audio_tensor=stereo_tensor)
with patch("soundfile.write") as mock_write:
audio_mixin.create_audio(audio_obj)
mock_write.assert_called_once()
# Verify the tensor passed to soundfile.write is the exact 2D tensor
output_tensor = mock_write.call_args[0][1]
assert output_tensor.ndim == 2
assert output_tensor.shape == (24000, 2)
assert np.array_equal(output_tensor, stereo_tensor)
def test_speed_adjustment_bypass(self, audio_mixin):
"""Test that speed=1.0 bypasses the expensive librosa time stretching."""
audio_tensor = np.random.rand(24000).astype(np.float32)
with patch("librosa.effects.time_stretch") as mock_time_stretch:
# speed=1.0 should return immediately without calling librosa
result, _ = audio_mixin._apply_speed_adjustment(audio_tensor, speed=1.0, sample_rate=24000)
mock_time_stretch.assert_not_called()
assert np.array_equal(result, audio_tensor)
@patch("librosa.effects.time_stretch")
def test_speed_adjustment_stereo_handling(self, mock_time_stretch, audio_mixin):
"""Test that speed adjustment is attempted on stereo inputs."""
stereo_tensor = np.random.rand(24000, 2).astype(np.float32)
# Mock return value representing a sped-up version (half length)
mock_time_stretch.return_value = np.zeros((12000, 2), dtype=np.float32)
result, _ = audio_mixin._apply_speed_adjustment(stereo_tensor, speed=2.0, sample_rate=24000)
mock_time_stretch.assert_called_once()
# Ensure the stereo tensor was passed to librosa
call_args = mock_time_stretch.call_args
assert np.array_equal(call_args.kwargs["y"], stereo_tensor)
assert call_args.kwargs["rate"] == 2.0
assert result.shape == (12000, 2)
# Helper to create mock model output for endpoint tests
def create_mock_audio_output_for_test(
request_id: str = "speech-mock-123",
) -> OmniRequestOutput:
class MockCompletionOutput:
def __init__(self, index: int = 0):
self.index = index
self.text = ""
self.token_ids = []
self.finish_reason = "stop"
self.stop_reason = None
self.logprobs = None
class MockRequestOutput:
def __init__(self, request_id: str, audio_tensor: torch.Tensor):
self.request_id = request_id
self.outputs = [MockCompletionOutput(index=0)]
self.multimodal_output = {"audio": audio_tensor}
self.finished = True
self.prompt_token_ids = None
self.encoder_prompt_token_ids = None
self.num_cached_tokens = None
self.prompt_logprobs = None
self.kv_transfer_params = None
num_samples = 24000
audio_tensor = torch.sin(torch.linspace(0, 440 * 2 * torch.pi, num_samples))
mock_request_output = MockRequestOutput(request_id=request_id, audio_tensor=audio_tensor)
return OmniRequestOutput(
stage_id=0,
final_output_type="audio",
request_output=mock_request_output,
)
def create_mock_audio_output_on_completion_for_test(
request_id: str = "speech-mock-completion-123",
) -> OmniRequestOutput:
class MockCompletionOutput:
def __init__(self, audio_tensor: torch.Tensor, index: int = 0):
self.index = index
self.text = ""
self.token_ids = []
self.finish_reason = "stop"
self.stop_reason = None
self.logprobs = None
self.multimodal_output = {"audio": audio_tensor, "sr": 24000}
class MockRequestOutput:
def __init__(self, request_id: str, audio_tensor: torch.Tensor):
self.request_id = request_id
self.outputs = [MockCompletionOutput(audio_tensor=audio_tensor, index=0)]
self.multimodal_output = {}
self.finished = True
self.prompt_token_ids = None
self.encoder_prompt_token_ids = None
self.num_cached_tokens = None
self.prompt_logprobs = None
self.kv_transfer_params = None
num_samples = 24000
audio_tensor = torch.sin(torch.linspace(0, 440 * 2 * torch.pi, num_samples))
mock_request_output = MockRequestOutput(request_id=request_id, audio_tensor=audio_tensor)
return OmniRequestOutput(
stage_id=0,
final_output_type="audio",
request_output=mock_request_output,
)
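# The two helpers above differ only in where the audio tensor is attached: the first puts it on
# request_output.multimodal_output, the second on the completion output's multimodal_output
# (with an explicit sample rate), so the tests below exercise both places the speech endpoint
# may read audio from.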
@pytest.fixture
def test_app():
# Mock the engine client
mock_engine_client = MagicMock()
mock_engine_client.errored = False
async def mock_generate_fn(*args, **kwargs):
yield create_mock_audio_output_for_test(request_id=kwargs.get("request_id"))
mock_engine_client.generate = MagicMock(side_effect=mock_generate_fn)
mock_engine_client.default_sampling_params_list = [{}]
# Mock models to have an is_base_model method
mock_models = MagicMock()
mock_models.is_base_model.return_value = True
mock_request_logger = MagicMock()
speech_server = OmniOpenAIServingSpeech(
engine_client=mock_engine_client,
models=mock_models,
request_logger=mock_request_logger,
)
# Patch the signature of create_speech to remove 'raw_request' for FastAPI route introspection
original_create_speech = speech_server.create_speech
sig = signature(original_create_speech)
new_parameters = [param for name, param in sig.parameters.items() if name != "raw_request"]
new_sig = Signature(parameters=new_parameters, return_annotation=sig.return_annotation)
async def awaitable_patched_create_speech(*args, **kwargs):
return await original_create_speech(*args, **kwargs)
awaitable_patched_create_speech.__signature__ = new_sig
speech_server.create_speech = awaitable_patched_create_speech
app = FastAPI()
app.add_api_route("/v1/audio/speech", speech_server.create_speech, methods=["POST"], response_model=None)
# Add list_voices endpoint
async def list_voices():
speakers = sorted(speech_server.supported_speakers) if speech_server.supported_speakers else []
return {"voices": speakers}
app.add_api_route("/v1/audio/voices", list_voices, methods=["GET"])
return app
@pytest.fixture
def client(test_app):
return TestClient(test_app)
class TestSpeechAPI:
def test_create_speech_success(self, client):
payload = {
"input": "Hello world",
"model": "tts-model",
"voice": "alloy",
"response_format": "wav",
}
response = client.post("/v1/audio/speech", json=payload)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/wav"
assert len(response.content) > 0
def test_create_speech_mp3_format(self, client):
payload = {
"input": "Hello world",
"model": "tts-model",
"voice": "alloy",
"response_format": "mp3",
}
response = client.post("/v1/audio/speech", json=payload)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/mpeg"
assert len(response.content) > 0
def test_create_speech_reads_audio_from_completion_output(self, test_app):
mock_engine_client = MagicMock()
mock_engine_client.errored = False
async def mock_generate_fn(*args, **kwargs):
yield create_mock_audio_output_on_completion_for_test(request_id=kwargs.get("request_id"))
mock_engine_client.generate = MagicMock(side_effect=mock_generate_fn)
mock_engine_client.default_sampling_params_list = [{}]
mock_models = MagicMock()
mock_models.is_base_model.return_value = True
speech_server = OmniOpenAIServingSpeech(
engine_client=mock_engine_client,
models=mock_models,
request_logger=MagicMock(),
)
original_create_speech = speech_server.create_speech
sig = signature(original_create_speech)
new_parameters = [param for name, param in sig.parameters.items() if name != "raw_request"]
new_sig = Signature(parameters=new_parameters, return_annotation=sig.return_annotation)
async def awaitable_patched_create_speech(*args, **kwargs):
return await original_create_speech(*args, **kwargs)
awaitable_patched_create_speech.__signature__ = new_sig
speech_server.create_speech = awaitable_patched_create_speech
app = FastAPI()
app.add_api_route("/v1/audio/speech", speech_server.create_speech, methods=["POST"], response_model=None)
client = TestClient(app)
payload = {
"input": "Hello world",
"model": "tts-model",
"voice": "alloy",
"response_format": "wav",
}
response = client.post("/v1/audio/speech", json=payload)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/wav"
assert len(response.content) > 0
def test_create_speech_invalid_format(self, client):
payload = {
"input": "Hello world",
"model": "tts-model",
"voice": "alloy",
"response_format": "invalid_format",
}
response = client.post("/v1/audio/speech", json=payload)
assert response.status_code == 422 # Unprocessable Entity
@patch("vllm_omni.entrypoints.openai.serving_speech.OmniOpenAIServingSpeech.create_audio")
def test_speed_parameter_is_used(self, mock_create_audio, test_app):
client = TestClient(test_app)
mock_audio_response = MagicMock()
mock_audio_response.audio_data = b"dummy_audio"
mock_audio_response.media_type = "audio/wav"
mock_create_audio.return_value = mock_audio_response
payload = {
"input": "This should be fast.",
"model": "tts-model",
"voice": "alloy",
"response_format": "wav",
"speed": 2.5,
}
client.post("/v1/audio/speech", json=payload)
mock_create_audio.assert_called_once()
call_args = mock_create_audio.call_args[0]
audio_obj = call_args[0]
assert isinstance(audio_obj, CreateAudio)
assert audio_obj.speed == 2.5
def test_list_voices_endpoint(self, client):
response = client.get("/v1/audio/voices")
assert response.status_code == 200
assert "voices" in response.json()
class TestTTSMethods:
"""Unit tests for TTS validation and parameter building."""
@pytest.fixture
def speech_server(self):
mock_engine_client = MagicMock()
mock_engine_client.errored = False
mock_engine_client.stage_list = None
mock_models = MagicMock()
mock_models.is_base_model.return_value = True
return OmniOpenAIServingSpeech(
engine_client=mock_engine_client,
models=mock_models,
request_logger=MagicMock(),
)
def test_is_tts_model(self, speech_server):
"""Test TTS model detection."""
# No stage_list -> False
assert speech_server._is_tts_model() is False
# With qwen3_tts stage -> True
mock_stage = MagicMock()
mock_stage.model_stage = "qwen3_tts"
speech_server.engine_client.stage_list = [mock_stage]
assert speech_server._is_tts_model() is True
def test_build_tts_prompt(self, speech_server):
"""Test TTS prompt format."""
prompt = speech_server._build_tts_prompt("Hello")
assert prompt == "<|im_start|>assistant\nHello<|im_end|>\n<|im_start|>assistant\n"
def test_validate_tts_request_basic(self, speech_server):
"""Test basic validation cases."""
# Empty input
req = OpenAICreateSpeechRequest(input="")
assert speech_server._validate_tts_request(req) == "Input text cannot be empty"
# Invalid language
req = OpenAICreateSpeechRequest(input="Hello", language="InvalidLang")
assert "Invalid language" in speech_server._validate_tts_request(req)
# When no speakers loaded, any voice is accepted (unconstrained)
req = OpenAICreateSpeechRequest(input="Hello", voice="Invalid")
assert speech_server._validate_tts_request(req) is None
# Valid request
req = OpenAICreateSpeechRequest(input="Hello", voice="Vivian")
assert speech_server._validate_tts_request(req) is None
def test_validate_tts_request_task_types(self, speech_server):
"""Test task-specific validation."""
# Base task requires ref_audio
req = OpenAICreateSpeechRequest(input="Hello", task_type="Base")
assert "ref_audio" in speech_server._validate_tts_request(req)
# VoiceDesign requires instructions
req = OpenAICreateSpeechRequest(input="Hello", task_type="VoiceDesign")
assert "instructions" in speech_server._validate_tts_request(req)
# ref_text only for Base
req = OpenAICreateSpeechRequest(input="Hello", ref_text="text")
assert "Base task" in speech_server._validate_tts_request(req)
def test_build_tts_params(self, speech_server):
"""Test TTS parameter building."""
req = OpenAICreateSpeechRequest(input="Hello", voice="Ryan", language="English")
params = speech_server._build_tts_params(req)
assert params["text"] == ["Hello"]
assert params["speaker"] == ["Ryan"]
assert params["language"] == ["English"]
assert params["task_type"] == ["CustomVoice"]
assert "max_new_tokens" not in params
def test_build_tts_params_with_explicit_max_new_tokens(self, speech_server):
"""Test explicit max_new_tokens override."""
req = OpenAICreateSpeechRequest(
input="Hello",
task_type="Base",
ref_audio="data:audio/wav;base64,AAAA",
max_new_tokens=128,
)
params = speech_server._build_tts_params(req)
assert params["max_new_tokens"] == [128]
def test_load_supported_speakers(self):
"""Test _load_supported_speakers."""
mock_engine_client = MagicMock()
mock_engine_client.errored = False
mock_engine_client.stage_list = None
# Mock talker_config with mixed-case speaker names
mock_talker_config = MagicMock()
mock_talker_config.spk_id = {"Ryan": 0, "Vivian": 1, "Aiden": 2}
mock_engine_client.model_config.hf_config.talker_config = mock_talker_config
mock_models = MagicMock()
mock_models.is_base_model.return_value = True
server = OmniOpenAIServingSpeech(
engine_client=mock_engine_client,
models=mock_models,
request_logger=MagicMock(),
)
# Verify speakers are normalized to lowercase
assert server.supported_speakers == {"ryan", "vivian", "aiden"}
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm_omni.entrypoints import omni as omni_module
from vllm_omni.entrypoints.async_omni import AsyncOmni
def test_default_stage_config_includes_cache_backend(monkeypatch):
"""Ensure cache_backend/cache_config are preserved in default diffusion stage."""
monkeypatch.setattr(omni_module, "load_stage_configs_from_model", lambda model, base_engine_args=None: [])
monkeypatch.setattr(omni_module, "resolve_model_config_path", lambda model: None)
monkeypatch.setattr(AsyncOmni, "_start_stages", lambda self, model: None)
monkeypatch.setattr(AsyncOmni, "_wait_for_stages_ready", lambda self, timeout=0: None)
omni = AsyncOmni(
model="dummy-model",
cache_backend="cache_dit",
cache_config='{"Fn_compute_blocks": 2}',
vae_use_slicing=True,
ulysses_degree=2,
)
stage_cfg = omni.stage_configs[0]
engine_args = stage_cfg.engine_args
assert engine_args.get("cache_backend") == "cache_dit"
cache_config = engine_args.get("cache_config")
assert cache_config["Fn_compute_blocks"] == 2
assert engine_args.get("vae_use_slicing") is True
parallel_config = engine_args.get("parallel_config")
if hasattr(parallel_config, "get"):
ulysses_degree = parallel_config.get("ulysses_degree")
else:
ulysses_degree = getattr(parallel_config, "ulysses_degree", None)
assert ulysses_degree == 2
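# Note: cache_config is accepted as a JSON string (as above) and surfaces as a parsed dict in
# engine_args; the next test covers the default applied when it is omitted entirely.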
def test_default_cache_config_used_when_missing(monkeypatch):
"""Ensure default cache_config is applied when cache_backend is set."""
monkeypatch.setattr(omni_module, "load_stage_configs_from_model", lambda model, base_engine_args=None: [])
monkeypatch.setattr(omni_module, "resolve_model_config_path", lambda model: None)
monkeypatch.setattr(AsyncOmni, "_start_stages", lambda self, model: None)
monkeypatch.setattr(AsyncOmni, "_wait_for_stages_ready", lambda self, timeout=0: None)
omni = AsyncOmni(
model="dummy-model",
cache_backend="cache_dit",
)
engine_args = omni.stage_configs[0].engine_args
cache_config = engine_args.get("cache_config")
assert cache_config is not None
assert cache_config["Fn_compute_blocks"] == 1
def test_default_stage_devices_from_sequence_parallel(monkeypatch):
"""Ensure devices list reflects sequence parallel size when no parallel_config is provided."""
monkeypatch.setattr(omni_module, "load_stage_configs_from_model", lambda model, base_engine_args=None: [])
monkeypatch.setattr(omni_module, "resolve_model_config_path", lambda model: None)
monkeypatch.setattr(AsyncOmni, "_start_stages", lambda self, model: None)
monkeypatch.setattr(AsyncOmni, "_wait_for_stages_ready", lambda self, timeout=0: None)
omni = AsyncOmni(
model="dummy-model",
ulysses_degree=2,
ring_degree=2,
)
stage_cfg = omni.stage_configs[0]
runtime = stage_cfg.runtime
if hasattr(runtime, "get"):
devices = runtime.get("devices")
else:
devices = getattr(runtime, "devices", None)
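    # Expected device list (illustrative reasoning, assuming devices are derived from the
    # sequence-parallel world size): ulysses_degree * ring_degree = 2 * 2 = 4 -> "0,1,2,3".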
assert devices == "0,1,2,3"
import uuid
import warnings
from queue import Empty, Queue
from typing import Any
from unittest.mock import MagicMock
import pytest
from vllm_omni.entrypoints.stage_utils import SHUTDOWN_TASK
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
# Suppress noisy DeprecationWarnings from optional Swig bindings imported by vLLM dependencies.
warnings.filterwarnings(
"ignore",
message=r"builtin type SwigPy.*has no __module__ attribute",
category=DeprecationWarning,
)
class _FakeEngineArgs(dict):
"""Fake engine args that can be used both as object attributes and as **kwargs."""
def __init__(self, args_dict: dict[str, Any]):
super().__init__(args_dict)
# Add required attributes if not present
if "model_stage" not in self:
self["model_stage"] = None
if "engine_output_type" not in self:
self["engine_output_type"] = None
# Also set as attributes for object-style access
for key, value in self.items():
setattr(self, key, value)
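    # Illustrative usage (hypothetical values), showing the dual dict/attribute access this stub provides:
    #   args = _FakeEngineArgs({"cache_backend": "tea_cache"})
    #   args["cache_backend"] == args.cache_backend   # True
    #   SomeEngine(**args)                             # SomeEngine is a placeholder; any **kwargs consumer works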
class _FakeStageConfig:
"""Fake stage config object that mimics the real stage config structure."""
def __init__(self, config_dict: dict[str, Any]):
# engine_args needs to work both as object (for OmniStage) and as dict (for **kwargs)
engine_args_dict = config_dict.get("engine_args", {})
self.engine_args = _FakeEngineArgs(engine_args_dict)
self.final_output = config_dict.get("final_output", False)
self.final_output_type = config_dict.get("final_output_type", None)
self.stage_id = config_dict.get("stage_id", 0)
# Store original dict for reference
self._config_dict = config_dict
class _FakeQueue:
"""Fake queue using standard library Queue to replace mp.Queue."""
def __init__(self, maxsize=0):
self._queue = Queue(maxsize=maxsize)
def put(self, item):
self._queue.put(item)
def put_nowait(self, item):
self._queue.put_nowait(item)
def get(self):
return self._queue.get()
def get_nowait(self):
return self._queue.get_nowait()
def empty(self):
return self._queue.empty()
class _FakeStage:
"""Lightweight Stage stub for multi-process pipeline version with queue support."""
def __init__(self, config, stage_init_timeout: int = 300):
# Handle both dict and object configs
if isinstance(config, dict):
config = _FakeStageConfig(config)
self.config = config
self.stage_config = config
self.engine = None
self.engine_outputs = None
# Set attributes that OmniStage expects
self.stage_id = getattr(config, "stage_id", 0)
self.engine_args = config.engine_args
self.model_stage = getattr(config.engine_args, "model_stage", None)
self.stage_type = "diffusion"
# set default sampling params
self.default_sampling_params = OmniDiffusionSamplingParams(num_inference_steps=1)
# Allow configuring final_output and final_output_type
self.final_output = config.final_output if hasattr(config, "final_output") else False
self.final_output_type = getattr(config, "final_output_type", None)
# Configurable processing logic, default returns placeholder
processed_input = getattr(config, "_config_dict", {}).get("processed_input", ["processed"])
self._processed_input = processed_input
# Queue references (set by attach_queues)
self._in_q = None
self._out_q = None
self._proc = None # Mock process reference
self._stage_init_timeout = max(0, int(stage_init_timeout))
def attach_queues(self, in_q, out_q):
"""Attach input and output queues."""
self._in_q = in_q
self._out_q = out_q
def init_stage_worker(
self,
model: str,
*,
is_async: bool = False,
shm_threshold_bytes: int = 65536,
ctx=None,
batch_timeout: int = 10,
**kwargs,
):
"""Mock init_stage_worker: don't start real process, just send stage_ready message."""
# Create a mock process object
self._proc = MagicMock()
self._proc.start = MagicMock()
self._proc.join = MagicMock()
self._proc.is_alive = MagicMock(return_value=False)
self._proc.terminate = MagicMock()
# Send stage_ready message to output queue
if self._out_q is not None:
try:
self._out_q.put_nowait({"type": "stage_ready", "stage_id": self.stage_id})
except Exception:
pass
def stop_stage_worker(self):
"""Mock stop_stage_worker: clean up queue references."""
if self._in_q is not None:
try:
self._in_q.put_nowait(SHUTDOWN_TASK)
except Exception:
pass
def submit(self, payload: dict[str, Any]):
"""Submit task to input queue."""
if self._in_q is not None:
self._in_q.put(payload)
def try_collect(self) -> Any:
"""Non-blocking collect from output queue."""
if self._out_q is None:
return None
try:
return self._out_q.get_nowait()
except Empty:
return None
def set_engine_outputs(self, outputs):
"""Set engine outputs for the stage."""
self.engine_outputs = outputs
def process_engine_inputs(self, stage_list, prompts):
"""Process engine inputs: return preset processed result."""
return self._processed_input
class _FakeEngine:
"""Lightweight Engine stub: provides generate iterator output."""
def __init__(self, outputs: list[Any]):
self._outputs = outputs
def generate(self, prompts, sampling_params):
# Record the most recent prompts for outer assertions
self._last_prompts = prompts
# Simplified: return preset list at once, ensuring iterability
yield from self._outputs
@pytest.fixture
def fake_stage_config():
return {
# Don't include 'model' in engine_args since it's passed separately
"engine_args": {},
"final_output": True,
"final_output_type": "text",
# Second stage will use processed_input to verify the chain
"processed_input": ["processed-by-stage"],
}
def _setup_engine_mocks(monkeypatch):
"""Helper function to set up common engine mocks."""
fake_engine = MagicMock()
# Add necessary attributes to fake_engine
fake_engine.tokenizer = MagicMock()
fake_engine.log_stats = False
fake_engine.vllm_config = MagicMock()
fake_engine.vllm_config.model_config = MagicMock()
fake_engine.vllm_config.model_config.io_processor_plugin = None
fake_engine.get_supported_tasks = MagicMock(return_value=[])
fake_engine.model_config = MagicMock()
fake_engine.model_config.io_processor_plugin = None
# Add registry with resolve_model_cls method
fake_registry = MagicMock()
fake_registry.resolve_model_cls = MagicMock(return_value=(MagicMock(), "test_arch"))
fake_engine.model_config.registry = fake_registry
fake_engine.vllm_config.model_config.registry = fake_registry
monkeypatch.setattr(
"vllm.v1.engine.llm_engine.LLMEngine.from_engine_args",
lambda **kw: fake_engine,
raising=False,
)
# Mock model_config.registry.resolve_model_cls to return a tuple
# Use a real class instead of MagicMock to avoid inspect.getsource issues
class FakeModelClass:
pass
monkeypatch.setattr(
"vllm.model_executor.model_loader.utils.get_model_architecture",
lambda model_config: (FakeModelClass, "test_arch"),
raising=False,
)
monkeypatch.setattr(
"vllm.model_executor.model_loader.utils._get_model_architecture",
lambda model_config: (FakeModelClass, "test_arch"),
raising=False,
)
# Mock try_create_mm_pooling_model_cls to return the class as-is
monkeypatch.setattr(
"vllm.model_executor.models.adapters.try_create_mm_pooling_model_cls",
lambda model_cls: model_cls,
raising=False,
)
# Mock _enable_processor_cache to return False
monkeypatch.setattr(
"vllm.multimodal.cache._enable_processor_cache",
lambda model_config, mm_registry: False,
raising=False,
)
# Mock get_io_processor to return None
monkeypatch.setattr(
"vllm.plugins.io_processors.get_io_processor",
lambda vllm_config, io_processor_plugin: None,
raising=False,
)
def _setup_multiprocessing_mocks(monkeypatch):
"""Helper function to set up multiprocessing mocks."""
import multiprocessing as mp
# Mock Process
fake_process_class = MagicMock()
fake_process_instance = MagicMock()
fake_process_instance.start = MagicMock()
fake_process_instance.join = MagicMock()
fake_process_instance.is_alive = MagicMock(return_value=False)
fake_process_instance.terminate = MagicMock()
fake_process_class.return_value = fake_process_instance
# Mock get_context to return a context with Queue that returns _FakeQueue
fake_ctx = MagicMock()
fake_ctx.Queue = lambda maxsize=0: _FakeQueue(maxsize=maxsize)
fake_ctx.Process = fake_process_class
def _mock_get_context(method):
return fake_ctx
monkeypatch.setattr(mp, "get_context", _mock_get_context, raising=False)
monkeypatch.setattr(mp, "Process", fake_process_class, raising=False)
def _setup_ipc_mocks(monkeypatch):
"""Helper function to set up IPC function mocks."""
# Mock _encode: simple serialization
def _fake_encode(obj, threshold, obj_key, shm_key):
return {obj_key: obj}
# Mock _load: extract object from result
def _fake_load(result, obj_key, shm_key):
return result.get(obj_key)
    # Mock _set: serialize the object to bytes (stand-in for size accounting)
def _fake_set(obj):
return str(obj).encode()
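    # Minimal sketch of the mocked IPC round trip (the real _encode/_load may use shared memory
    # above shm_threshold_bytes; that behavior is bypassed by these stubs):
    #   payload = _fake_encode({"x": 1}, threshold=0, obj_key="obj", shm_key="shm")
    #   _fake_load(payload, obj_key="obj", shm_key="shm")  # -> {"x": 1}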
monkeypatch.setattr("vllm_omni.entrypoints.omni._encode", _fake_encode, raising=False)
monkeypatch.setattr("vllm_omni.entrypoints.omni._load", _fake_load, raising=False)
monkeypatch.setattr("vllm_omni.entrypoints.omni._set", _fake_set, raising=False)
def _setup_log_mocks(monkeypatch):
"""Helper function to set up logging and stats mocks."""
# Mock OrchestratorMetrics to be a simple class that doesn't require file operations
class _FakeOrchestratorMetrics:
def __init__(self, num_stages, enable_stats, wall_start_ts):
self.num_stages = num_stages
self.enable_stats = enable_stats
self.stage_first_ts = [None] * num_stages
self.stage_last_ts = [None] * num_stages
self.e2e_done = set()
def on_stage_metrics(self, stage_id, req_id, metrics):
pass
def on_finalize_request(self, stage_id, req_id, start_ts):
self.e2e_done.add(req_id)
def on_forward(self, from_stage, to_stage, req_id, size_bytes, tx_ms, use_shm):
pass
def build_and_log_summary(self, final_stage_id):
return "Fake summary"
monkeypatch.setattr(
"vllm_omni.entrypoints.omni.OrchestratorMetrics",
_FakeOrchestratorMetrics,
raising=False,
)
@pytest.fixture(autouse=True)
def mock_get_config(monkeypatch):
"""Auto-mock get_config and related model loading functions to avoid model path validation."""
# CRITICAL: Mock tokenizer-related imports FIRST, before any module imports
# This prevents ImportError when async_omni is imported (which happens via omni_stage)
import sys
fake_tokenizer = MagicMock()
fake_tokenizer.encode = MagicMock(return_value=[1, 2, 3])
fake_tokenizer.decode = MagicMock(return_value="test")
# Mock init_tokenizer_from_configs (used in async_omni)
def _mock_init_tokenizer_from_configs(model_config=None, **kwargs):
return fake_tokenizer
# Strategy 1: Mock in the original location (vllm.transformers_utils.tokenizer)
# This works if the module hasn't been imported yet
monkeypatch.setattr(
"vllm.transformers_utils.tokenizer.init_tokenizer_from_configs",
_mock_init_tokenizer_from_configs,
raising=False,
)
# Strategy 2: If the module is already in sys.modules, patch it directly
tokenizer_module_path = "vllm.transformers_utils.tokenizer"
if tokenizer_module_path in sys.modules:
tokenizer_module = sys.modules[tokenizer_module_path]
setattr(tokenizer_module, "init_tokenizer_from_configs", _mock_init_tokenizer_from_configs)
# CRITICAL: Mock length_from_prompt_token_ids_or_embeds BEFORE trying to mock async_omni
# This is because async_omni imports processor.py, which imports this function at module level
# Mock length_from_prompt_token_ids_or_embeds (used in processor.py)
def _mock_length_from_prompt_token_ids_or_embeds(prompt_token_ids=None, prompt_embeds=None):
# Return a reasonable default length
if prompt_token_ids is not None:
if isinstance(prompt_token_ids, list):
return len(prompt_token_ids)
elif hasattr(prompt_token_ids, "shape"):
return prompt_token_ids.shape[-1] if len(prompt_token_ids.shape) > 0 else 1
if prompt_embeds is not None:
if hasattr(prompt_embeds, "shape"):
return prompt_embeds.shape[-2] if len(prompt_embeds.shape) > 1 else 1
return 10 # Default length
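    # Examples of what this stub returns (illustrative; torch is not imported here):
    #   _mock_length_from_prompt_token_ids_or_embeds(prompt_token_ids=[1, 2, 3])         # -> 3
    #   _mock_length_from_prompt_token_ids_or_embeds(prompt_embeds=torch.zeros(1, 5, 8)) # -> 5 (shape[-2])
    #   _mock_length_from_prompt_token_ids_or_embeds()                                   # -> 10 (default)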
# Mock in vllm.utils
monkeypatch.setattr(
"vllm.utils.length_from_prompt_token_ids_or_embeds",
_mock_length_from_prompt_token_ids_or_embeds,
raising=False,
)
# Also mock in processor module if it's imported
monkeypatch.setattr(
"vllm_omni.engine.input_processor.length_from_prompt_token_ids_or_embeds",
_mock_length_from_prompt_token_ids_or_embeds,
raising=False,
)
# If processor module is already imported, patch it directly
processor_module_path = "vllm_omni.engine.input_processor"
if processor_module_path in sys.modules:
processor_module = sys.modules[processor_module_path]
setattr(
processor_module, "length_from_prompt_token_ids_or_embeds", _mock_length_from_prompt_token_ids_or_embeds
)
# Strategy 3: Now mock async_omni AFTER length_from_prompt_token_ids_or_embeds is mocked
# This prevents ImportError when async_omni imports processor.py
monkeypatch.setattr(
"vllm_omni.entrypoints.async_omni.init_tokenizer_from_configs",
_mock_init_tokenizer_from_configs,
raising=False,
)
# Strategy 4: If async_omni is already imported, patch it directly
async_omni_path = "vllm_omni.entrypoints.async_omni"
if async_omni_path in sys.modules:
async_omni_module = sys.modules[async_omni_path]
setattr(async_omni_module, "init_tokenizer_from_configs", _mock_init_tokenizer_from_configs)
# Now mock get_config and other functions
fake_hf_config = MagicMock()
fake_hf_config.model_type = "qwen2_5_omni"
def _mock_get_config(model, **kwargs):
return fake_hf_config
monkeypatch.setattr(
"vllm.transformers_utils.config.get_config",
_mock_get_config,
raising=False,
)
monkeypatch.setattr(
"vllm_omni.entrypoints.utils.get_config",
_mock_get_config,
raising=False,
)
# Mock transformers' cached_file to avoid downloading model configs
def _mock_cached_file(path_or_repo_id, *args, **kwargs):
import os
import tempfile
fake_config_file = os.path.join(tempfile.gettempdir(), "fake_config.json")
if not os.path.exists(fake_config_file):
with open(fake_config_file, "w") as f:
f.write('{"model_type": "qwen2_5_omni"}')
return fake_config_file
monkeypatch.setattr(
"transformers.utils.hub.cached_file",
_mock_cached_file,
raising=False,
)
monkeypatch.setattr(
"transformers.utils.hub.cached_files",
lambda path_or_repo_id, filenames, **kwargs: (
[_mock_cached_file(path_or_repo_id, filenames[0])] if filenames else None
),
raising=False,
)
def test_initialize_stage_configs_called_when_none(monkeypatch, fake_stage_config):
"""Test that stage configs are auto-loaded when stage_configs_path is None."""
def _fake_loader(model: str, base_engine_args=None):
return [
_FakeStageConfig(fake_stage_config),
_FakeStageConfig(fake_stage_config),
]
# Remove modules from cache BEFORE setting mocks
import sys
for module_name in [
"vllm_omni.entrypoints.utils",
"vllm_omni.entrypoints.omni",
"vllm_omni.entrypoints.omni_stage",
]:
if module_name in sys.modules:
del sys.modules[module_name]
# Set up mocks
_setup_engine_mocks(monkeypatch)
_setup_multiprocessing_mocks(monkeypatch)
_setup_ipc_mocks(monkeypatch)
_setup_log_mocks(monkeypatch)
# Mock load_stage_configs_from_model
monkeypatch.setattr(
"vllm_omni.entrypoints.utils.load_stage_configs_from_model",
_fake_loader,
raising=False,
)
# Replace OmniStage
monkeypatch.setattr(
"vllm_omni.entrypoints.omni_stage.OmniStage",
lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
raising=False,
)
# Import the module after mocks are set
import vllm_omni.entrypoints.omni as omni_module
# Patch the imported function and class in the module
monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))
from vllm_omni.entrypoints.omni import Omni
omni = Omni(model="any", init_timeout=1)
# Verify: auto-loaded stage_configs and stage_list have consistent count
assert isinstance(omni.stage_configs, list)
assert len(omni.stage_configs) == 2
assert len(omni.stage_list) == 2
# Verify: each Stage is _FakeStage instance
for st in omni.stage_list:
assert isinstance(st, _FakeStage)
# Verify: queues are attached
for st in omni.stage_list:
assert st._in_q is not None
assert st._out_q is not None
# Verify: all stages are ready
assert len(omni._stages_ready) == 2
def test_generate_raises_on_length_mismatch(monkeypatch, fake_stage_config):
"""Test that generate raises ValueError when sampling_params_list length doesn't match."""
def _fake_loader(model: str, base_engine_args=None):
return [_FakeStageConfig(fake_stage_config)]
import sys
for module_name in [
"vllm_omni.entrypoints.utils",
"vllm_omni.entrypoints.omni",
"vllm_omni.entrypoints.omni_stage",
]:
if module_name in sys.modules:
del sys.modules[module_name]
_setup_engine_mocks(monkeypatch)
_setup_multiprocessing_mocks(monkeypatch)
_setup_ipc_mocks(monkeypatch)
_setup_log_mocks(monkeypatch)
monkeypatch.setattr(
"vllm_omni.entrypoints.utils.load_stage_configs_from_model",
_fake_loader,
raising=False,
)
monkeypatch.setattr(
"vllm_omni.entrypoints.omni_stage.OmniStage",
lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
raising=False,
)
import vllm_omni.entrypoints.omni as omni_module
monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))
from vllm_omni.entrypoints.omni import Omni
omni = Omni(model="any", init_timeout=1)
with pytest.raises(ValueError):
omni.generate(prompts=["hi"], sampling_params_list=[])
def test_generate_pipeline_and_final_outputs(monkeypatch, fake_stage_config):
"""Test multi-stage generation pipeline with queue polling."""
stage_cfg0 = dict(fake_stage_config)
stage_cfg1 = dict(fake_stage_config)
stage_cfg1["processed_input"] = ["processed-for-stage-1"]
def _fake_loader(model: str, base_engine_args=None):
return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)]
import sys
for module_name in [
"vllm_omni.entrypoints.utils",
"vllm_omni.entrypoints.omni",
"vllm_omni.entrypoints.omni_stage",
]:
if module_name in sys.modules:
del sys.modules[module_name]
_setup_engine_mocks(monkeypatch)
_setup_multiprocessing_mocks(monkeypatch)
_setup_ipc_mocks(monkeypatch)
_setup_log_mocks(monkeypatch)
monkeypatch.setattr(
"vllm_omni.entrypoints.utils.load_stage_configs_from_model",
_fake_loader,
raising=False,
)
monkeypatch.setattr(
"vllm_omni.entrypoints.omni_stage.OmniStage",
lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
raising=False,
)
import vllm_omni.entrypoints.omni as omni_module
monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))
# Mock uuid.uuid4() to return a predictable value for request ID generation
test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000")
monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid)
monkeypatch.setattr(omni_module, "uuid", uuid)
from vllm_omni.entrypoints.omni import Omni
omni = Omni(model="any", init_timeout=1)
# Generate the expected request ID format: "0_<uuid>"
expected_request_id = f"0_{test_uuid}"
# Simulate worker behavior: manually put results into output queues
# Note: We put results before calling generate, which simulates worker processes
# that have already completed. The polling loop will collect them in stage order.
# Stage 0 output (will be collected first)
omni.stage_list[0]._out_q.put_nowait(
{
"request_id": expected_request_id,
"engine_outputs": [{"stage": 0, "text": "s0"}],
"metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
}
)
# Stage 1 output (will be collected after stage 0 forwards to it)
# Note: In real flow, stage 1 result would appear after stage 0 forwards,
# but for testing we pre-populate it. The polling loop processes stages
# in order, so stage 0 result will be collected first, then forwarded,
# then stage 1 result will be collected.
omni.stage_list[1]._out_q.put_nowait(
{
"request_id": expected_request_id,
"engine_outputs": [{"stage": 1, "text": "s1"}],
"metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
}
)
sampling_params_list = [
OmniDiffusionSamplingParams(num_inference_steps=1),
OmniDiffusionSamplingParams(num_inference_steps=1, max_sequence_length=10),
]
prompts = ["hi"]
outputs = omni.generate(prompts=prompts, sampling_params_list=sampling_params_list)
# Both stages have final_output=True, so should aggregate two OmniRequestOutput
assert len(outputs) == 2
# Verify stage outputs are set
assert omni.stage_list[0].engine_outputs == [{"stage": 0, "text": "s0"}]
assert omni.stage_list[1].engine_outputs == [{"stage": 1, "text": "s1"}]
# Verify stage 0 input queue received the task
assert not omni.stage_list[0]._in_q.empty()
    # Sanity check on the stage 1 stub: process_engine_inputs still returns its preset value
    assert omni.stage_list[1].process_engine_inputs([], []) is not None
def test_generate_pipeline_with_batch_input(monkeypatch, fake_stage_config):
"""Test single-stage generation pipeline with multiple inputs in one batch."""
stage_cfg0 = dict(fake_stage_config)
stage_cfg1 = dict(fake_stage_config)
stage_cfg0["final_output"] = False
def _fake_loader(model: str, base_engine_args=None):
return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)]
import sys
for module_name in [
"vllm_omni.entrypoints.utils",
"vllm_omni.entrypoints.omni",
"vllm_omni.entrypoints.omni_stage",
]:
if module_name in sys.modules:
del sys.modules[module_name]
_setup_engine_mocks(monkeypatch)
_setup_multiprocessing_mocks(monkeypatch)
_setup_ipc_mocks(monkeypatch)
_setup_log_mocks(monkeypatch)
monkeypatch.setattr(
"vllm_omni.entrypoints.utils.load_stage_configs_from_model",
_fake_loader,
raising=False,
)
monkeypatch.setattr(
"vllm_omni.entrypoints.omni_stage.OmniStage",
lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
raising=False,
)
import vllm_omni.entrypoints.omni as omni_module
monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))
# Mock uuid.uuid4() to return a predictable value for request ID generation
test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000")
monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid)
monkeypatch.setattr(omni_module, "uuid", uuid)
from vllm_omni.entrypoints.omni import Omni
omni = Omni(model="any", init_timeout=1)
# Generate the expected request ID format: "0_<uuid>"
expected_request_id = f"0_{test_uuid}"
# Simulate worker behavior: manually put results into output queues
# Note: We put results before calling generate, which simulates worker processes
# that have already completed. The polling loop will collect them in stage order.
omni.stage_list[0]._out_q.put_nowait(
{
"request_id": expected_request_id,
"engine_outputs": [{"stage": 0, "text": "s0"}],
"metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
}
)
omni.stage_list[0]._out_q.put_nowait(
{
"request_id": expected_request_id,
"engine_outputs": [{"stage": 0, "text": "s0"}],
"metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
}
)
omni.stage_list[1]._out_q.put_nowait(
{
"request_id": expected_request_id,
"engine_outputs": [{"stage": 1}],
"metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
}
)
omni.stage_list[1]._out_q.put_nowait(
{
"request_id": expected_request_id,
"engine_outputs": [{"stage": 1}],
"metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
}
)
outputs = omni.generate(
prompts=[
{
"prompt": "hi",
"negative_prompt": "hi",
"multi_modal_data": {"image": ["dog.jpg", "cat.jpg"]},
},
{
"prompt": "hi",
"negative_prompt": "hi",
"multi_modal_data": {"image": ["dog.jpg", "cat.jpg"]},
},
],
sampling_params_list=[
OmniDiffusionSamplingParams(num_inference_steps=1),
OmniDiffusionSamplingParams(num_inference_steps=1),
],
)
assert len(outputs) == 2
def test_generate_no_final_output_returns_empty(monkeypatch, fake_stage_config):
"""Test that generate returns empty list when all stages have final_output=False."""
stage_cfg0 = dict(fake_stage_config)
stage_cfg1 = dict(fake_stage_config)
stage_cfg0["final_output"] = False
stage_cfg1["final_output"] = False
def _fake_loader(model: str, base_engine_args=None):
return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)]
import sys
for module_name in [
"vllm_omni.entrypoints.utils",
"vllm_omni.entrypoints.omni",
"vllm_omni.entrypoints.omni_stage",
]:
if module_name in sys.modules:
del sys.modules[module_name]
_setup_engine_mocks(monkeypatch)
_setup_multiprocessing_mocks(monkeypatch)
_setup_ipc_mocks(monkeypatch)
_setup_log_mocks(monkeypatch)
monkeypatch.setattr(
"vllm_omni.entrypoints.utils.load_stage_configs_from_model",
_fake_loader,
raising=False,
)
monkeypatch.setattr(
"vllm_omni.entrypoints.omni_stage.OmniStage",
lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
raising=False,
)
import vllm_omni.entrypoints.omni as omni_module
monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))
# Mock uuid.uuid4() to return a predictable value for request ID generation
test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000")
monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid)
monkeypatch.setattr(omni_module, "uuid", uuid)
from vllm_omni.entrypoints.omni import Omni
omni = Omni(model="any", init_timeout=1)
# Generate the expected request ID format: "0_<uuid>"
expected_request_id = f"0_{test_uuid}"
# Simulate worker behavior: put results into output queues
omni.stage_list[0]._out_q.put_nowait(
{
"request_id": expected_request_id,
"engine_outputs": [{"stage": 0}],
"metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
}
)
omni.stage_list[1]._out_q.put_nowait(
{
"request_id": expected_request_id,
"engine_outputs": [{"stage": 1}],
"metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
}
)
outputs = omni.generate(
prompts=["p"],
sampling_params_list=[
OmniDiffusionSamplingParams(num_inference_steps=1),
OmniDiffusionSamplingParams(num_inference_steps=1, max_sequence_length=10),
],
)
assert outputs == []
def test_generate_sampling_params_none_use_default(monkeypatch, fake_stage_config):
"""Test that generate uses default sampling params when sampling_params_list is None."""
stage_cfg0 = dict(fake_stage_config)
stage_cfg1 = dict(fake_stage_config)
stage_cfg0["final_output"] = False
stage_cfg1["final_output"] = False
def _fake_loader(model: str, base_engine_args=None):
return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)]
import sys
for module_name in [
"vllm_omni.entrypoints.utils",
"vllm_omni.entrypoints.omni",
"vllm_omni.entrypoints.omni_stage",
]:
if module_name in sys.modules:
del sys.modules[module_name]
_setup_engine_mocks(monkeypatch)
_setup_multiprocessing_mocks(monkeypatch)
_setup_ipc_mocks(monkeypatch)
_setup_log_mocks(monkeypatch)
monkeypatch.setattr(
"vllm_omni.entrypoints.utils.load_stage_configs_from_model",
_fake_loader,
raising=False,
)
monkeypatch.setattr(
"vllm_omni.entrypoints.omni_stage.OmniStage",
lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
raising=False,
)
import vllm_omni.entrypoints.omni as omni_module
monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))
# Mock uuid.uuid4() to return a predictable value for request ID generation
test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000")
monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid)
monkeypatch.setattr(omni_module, "uuid", uuid)
from vllm_omni.entrypoints.omni import Omni
omni = Omni(model="any", init_timeout=1)
# Generate the expected request ID format: "0_<uuid>"
expected_request_id = f"0_{test_uuid}"
# Simulate worker behavior: put results into output queues
omni.stage_list[0]._out_q.put_nowait(
{
"request_id": expected_request_id,
"engine_outputs": [{"stage": 0}],
"metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
}
)
omni.stage_list[1]._out_q.put_nowait(
{
"request_id": expected_request_id,
"engine_outputs": [{"stage": 1}],
"metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
}
)
# Use the default sampling params
omni.generate(prompts=["p"], sampling_params_list=None)
def test_wait_for_stages_ready_timeout(monkeypatch, fake_stage_config):
"""Test that _wait_for_stages_ready handles timeout correctly."""
def _fake_loader(model: str, base_engine_args=None):
return [_FakeStageConfig(fake_stage_config)]
import sys
for module_name in [
"vllm_omni.entrypoints.utils",
"vllm_omni.entrypoints.omni",
"vllm_omni.entrypoints.omni_stage",
]:
if module_name in sys.modules:
del sys.modules[module_name]
_setup_engine_mocks(monkeypatch)
_setup_multiprocessing_mocks(monkeypatch)
_setup_ipc_mocks(monkeypatch)
_setup_log_mocks(monkeypatch)
monkeypatch.setattr(
"vllm_omni.entrypoints.utils.load_stage_configs_from_model",
_fake_loader,
raising=False,
)
# Create a stage that doesn't send stage_ready message
class _FakeStageNoReady(_FakeStage):
def init_stage_worker(self, *args, **kwargs):
# Don't send stage_ready message
self._proc = MagicMock()
self._proc.start = MagicMock()
self._proc.join = MagicMock()
self._proc.is_alive = MagicMock(return_value=False)
self._proc.terminate = MagicMock()
monkeypatch.setattr(
"vllm_omni.entrypoints.omni_stage.OmniStage",
lambda cfg, **kwargs: _FakeStageNoReady(cfg, **kwargs),
raising=False,
)
import vllm_omni.entrypoints.omni as omni_module
monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStageNoReady(cfg, **kwargs))
from vllm_omni.entrypoints.omni import Omni
# Use very short timeout
omni = Omni(model="any", init_timeout=0.01)
# Verify that no stages are ready
assert len(omni._stages_ready) == 0
def test_generate_handles_error_messages(monkeypatch, fake_stage_config):
"""Test that generate handles error messages from stages correctly."""
def _fake_loader(model: str, base_engine_args=None):
return [_FakeStageConfig(fake_stage_config)]
import sys
for module_name in [
"vllm_omni.entrypoints.utils",
"vllm_omni.entrypoints.omni",
"vllm_omni.entrypoints.omni_stage",
]:
if module_name in sys.modules:
del sys.modules[module_name]
_setup_engine_mocks(monkeypatch)
_setup_multiprocessing_mocks(monkeypatch)
_setup_ipc_mocks(monkeypatch)
_setup_log_mocks(monkeypatch)
monkeypatch.setattr(
"vllm_omni.entrypoints.utils.load_stage_configs_from_model",
_fake_loader,
raising=False,
)
monkeypatch.setattr(
"vllm_omni.entrypoints.omni_stage.OmniStage",
lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
raising=False,
)
import vllm_omni.entrypoints.omni as omni_module
monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))
# Mock uuid.uuid4() to return a predictable value for request ID generation
test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000")
monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid)
monkeypatch.setattr(omni_module, "uuid", uuid)
from vllm_omni.entrypoints.omni import Omni
omni = Omni(model="any", init_timeout=1)
# Generate the expected request ID format: "0_<uuid>"
expected_request_id = f"0_{test_uuid}"
# Put error message in output queue
omni.stage_list[0]._out_q.put_nowait(
{
"request_id": expected_request_id,
"error": "test error",
}
)
# Also put a valid result after error to allow the loop to complete
# (error handling continues the loop, so we need a valid result to finish)
omni.stage_list[0]._out_q.put_nowait(
{
"request_id": expected_request_id,
"engine_outputs": [{"stage": 0, "text": "result"}],
"metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
}
)
# Generate should handle error gracefully (log but continue)
sampling_params_list = [OmniDiffusionSamplingParams(num_inference_steps=1)]
outputs = omni.generate(prompts=["hi"], sampling_params_list=sampling_params_list)
# Should return final output (error was logged but didn't stop processing)
assert isinstance(outputs, list)
# Since final_output=True, should have one output
assert len(outputs) == 1
def test_close_sends_shutdown_signal(monkeypatch, fake_stage_config):
"""Test that close() sends shutdown signal to all input queues."""
def _fake_loader(model: str, base_engine_args=None):
return [_FakeStageConfig(fake_stage_config)]
import sys
for module_name in [
"vllm_omni.entrypoints.utils",
"vllm_omni.entrypoints.omni",
"vllm_omni.entrypoints.omni_stage",
]:
if module_name in sys.modules:
del sys.modules[module_name]
_setup_engine_mocks(monkeypatch)
_setup_multiprocessing_mocks(monkeypatch)
_setup_ipc_mocks(monkeypatch)
_setup_log_mocks(monkeypatch)
monkeypatch.setattr(
"vllm_omni.entrypoints.utils.load_stage_configs_from_model",
_fake_loader,
raising=False,
)
monkeypatch.setattr(
"vllm_omni.entrypoints.omni_stage.OmniStage",
lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
raising=False,
)
import vllm_omni.entrypoints.omni as omni_module
monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))
from vllm_omni.entrypoints.omni import Omni
omni = Omni(model="any", init_timeout=1)
# Call close
omni.close()
    # Verify the shutdown sentinel (SHUTDOWN_TASK) was sent to the input queue
# Use get_nowait to avoid blocking (close() uses put_nowait, so should be safe)
try:
shutdown_signal = omni.stage_list[0]._in_q.get_nowait()
assert shutdown_signal == SHUTDOWN_TASK
except Empty:
# If queue was already empty or only had stage_ready, that's also acceptable
# The important thing is that close() was called without error
pass
# Verify stop_stage_worker was called (process should be set)
assert omni.stage_list[0]._proc is not None