OpenDAS / vllm-omni

Commit c1cacde6, authored Mar 25, 2026 by weishb
vllm-omni_0.15.0.rc1+fix1 first commit
Parent: 35607782
Changes: 306 files. Showing the first 20 changed files, with 4886 additions and 0 deletions (+4886 −0).
tests/e2e/offline_inference/test_t2v_model.py  +64 −0
tests/e2e/offline_inference/test_teacache.py  +88 −0
tests/e2e/offline_inference/test_zimage_tensor_parallel.py  +185 −0
tests/e2e/offline_inference/utils.py  +210 −0
tests/e2e/online_serving/__init__.py  +0 −0
tests/e2e/online_serving/stage_configs/qwen3_omni_ci.yaml  +103 −0
tests/e2e/online_serving/stage_configs/qwen3_omni_thinker_ci.yaml  +31 −0
tests/e2e/online_serving/stage_configs/rocm/qwen3_omni_ci.yaml  +95 −0
tests/e2e/online_serving/test_async_omni.py  +161 −0
tests/e2e/online_serving/test_image_gen_edit.py  +273 −0
tests/e2e/online_serving/test_images_generations_lora.py  +193 −0
tests/e2e/online_serving/test_qwen3_omni.py  +273 −0
tests/e2e/online_serving/test_qwen3_omni_expansion.py  +312 −0
tests/e2e/stage_configs/qwen3_omni_ci.yaml  +98 −0
tests/entrypoints/openai_api/__init__.py  +0 −0
tests/entrypoints/openai_api/test_image_server.py  +816 −0
tests/entrypoints/openai_api/test_serving_chat_sampling_params.py  +333 −0
tests/entrypoints/openai_api/test_serving_speech.py  +473 −0
tests/entrypoints/test_async_omni_diffusion_config.py  +75 −0
tests/entrypoints/test_omni_diffusion.py  +1103 −0
tests/e2e/offline_inference/test_t2v_model.py (new file, mode 100644)
import os
import sys
from pathlib import Path

import pytest
import torch
from vllm_omni.inputs.data import OmniDiffusionSamplingParams

# ruff: noqa: E402
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from vllm_omni import Omni
from vllm_omni.outputs import OmniRequestOutput

# os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

models = ["Wan-AI/Wan2.2-T2V-A14B-Diffusers"]


@pytest.mark.parametrize("model_name", models)
def test_video_diffusion_model(model_name: str):
    m = Omni(
        model=model_name,
        boundary_ratio=0.875,
        flow_shift=5.0,
    )
    # Use minimal settings for testing
    # num_frames must satisfy: num_frames % vae_scale_factor_temporal == 1
    # For Wan2.2, vae_scale_factor_temporal=4, so valid values are 5, 9, 13, 17, ...
    height = 480
    width = 640
    num_frames = 5
    outputs = m.generate(
        prompts="A cat sitting on a table",
        sampling_params_list=OmniDiffusionSamplingParams(
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=2,
            guidance_scale=1.0,
            generator=torch.Generator("cuda").manual_seed(42),
        ),
    )
    first_output = outputs[0]
    assert first_output.final_output_type == "image"
    if not hasattr(first_output, "request_output") or not first_output.request_output:
        raise ValueError("No request_output found in OmniRequestOutput")
    req_out = first_output.request_output[0]
    if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"):
        raise ValueError("Invalid request_output structure or missing 'images' key")
    frames = req_out.images[0]
    assert frames is not None
    assert hasattr(frames, "shape")
    # frames shape: (batch, num_frames, height, width, channels)
    assert frames.shape[1] == num_frames
    assert frames.shape[2] == height
    assert frames.shape[3] == width
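The frame-count constraint in the comment above generalizes beyond num_frames=5. A minimal sketch of the rule, assuming vae_scale_factor_temporal=4 as stated in the comment (a real model would expose this via its VAE config); note the formula also admits 1, but the test uses 5 as the smallest multi-frame clip:

def valid_frame_counts(max_frames: int, vae_scale_factor_temporal: int = 4) -> list[int]:
    """Return frame counts f with f % vae_scale_factor_temporal == 1."""
    return [f for f in range(1, max_frames + 1) if f % vae_scale_factor_temporal == 1]

assert valid_frame_counts(17) == [1, 5, 9, 13, 17]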
tests/e2e/offline_inference/test_teacache.py (new file, mode 100644)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
System test for TeaCache backend.

This test verifies that TeaCache acceleration works correctly with diffusion models.
It uses minimal settings to keep test time short for CI.
"""
import os
import sys
from pathlib import Path

import pytest
import torch
from vllm_omni.inputs.data import OmniDiffusionSamplingParams

# ruff: noqa: E402
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from vllm_omni import Omni
from vllm_omni.outputs import OmniRequestOutput

os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"

# Use random weights model for testing
models = ["riverclouds/qwen_image_random"]


@pytest.mark.parametrize("model_name", models)
def test_teacache(model_name: str):
    """Test TeaCache backend with diffusion model."""
    # Configure TeaCache with default settings for fast testing
    cache_config = {
        "rel_l1_thresh": 0.2,  # Default threshold
    }
    m = None
    try:
        m = Omni(
            model=model_name,
            cache_backend="tea_cache",
            cache_config=cache_config,
        )
        # Use minimal settings for fast testing
        height = 256
        width = 256
        num_inference_steps = 4  # Minimal steps for fast test
        outputs = m.generate(
            "a photo of a cat sitting on a laptop keyboard",
            OmniDiffusionSamplingParams(
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=0.0,
                generator=torch.Generator("cuda").manual_seed(42),
                num_outputs_per_prompt=1,  # Single output for speed
            ),
        )
        # Extract images from request_output[0]['images']
        first_output = outputs[0]
        assert first_output.final_output_type == "image"
        if not hasattr(first_output, "request_output") or not first_output.request_output:
            raise ValueError("No request_output found in OmniRequestOutput")
        req_out = first_output.request_output[0]
        if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"):
            raise ValueError("Invalid request_output structure or missing 'images' key")
        images = req_out.images
        # Verify generation succeeded
        assert images is not None
        assert len(images) == 1
        # Check image size
        assert images[0].width == width
        assert images[0].height == height
    except Exception as e:
        print(f"Test failed with error: {e}")
        raise
    finally:
        if m is not None and hasattr(m, "close"):
            m.close()
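For context on what rel_l1_thresh gates: TeaCache-style caching skips recomputing a transformer block when its input barely changed between diffusion steps, reusing the cached output instead. A minimal, hypothetical sketch of that decision rule, not vllm-omni's actual implementation:

import torch

def should_reuse_cache(prev_inp: torch.Tensor, cur_inp: torch.Tensor,
                       rel_l1_thresh: float = 0.2) -> bool:
    """Hypothetical gate: reuse the cached block output when the relative L1
    distance between consecutive step inputs stays under rel_l1_thresh
    (0.2 is the default used in the test above). Higher thresholds skip
    more steps: faster, but lossier."""
    rel_l1 = (cur_inp - prev_inp).abs().mean() / prev_inp.abs().mean()
    return rel_l1.item() < rel_l1_thresh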
tests/e2e/offline_inference/test_zimage_tensor_parallel.py (new file, mode 100644)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import sys
import time
from pathlib import Path

import numpy as np
import pytest
import torch
from PIL import Image
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
from vllm_omni.inputs.data import OmniDiffusionSamplingParams

# ruff: noqa: E402
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from tests.utils import GPUMemoryMonitor
from vllm_omni import Omni
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform

# os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

PROMPT = "a photo of a cat sitting on a laptop keyboard"


def _get_zimage_model() -> str:
    # Allow overriding the model for local/offline environments.
    # Can be either a HuggingFace repo id or a local path.
    return os.environ.get("VLLM_TEST_ZIMAGE_MODEL", "Tongyi-MAI/Z-Image-Turbo")


def _pil_to_float_rgb_tensor(img: Image.Image) -> torch.Tensor:
    """Convert PIL image to float32 RGB tensor in [0, 1] with shape [H, W, 3]."""
    arr = np.asarray(img.convert("RGB"), dtype=np.float32) / 255.0
    return torch.from_numpy(arr)


def _diff_metrics(a: Image.Image, b: Image.Image) -> tuple[float, float]:
    """Return (mean_abs_diff, max_abs_diff) over RGB pixels in [0, 1]."""
    ta = _pil_to_float_rgb_tensor(a)
    tb = _pil_to_float_rgb_tensor(b)
    assert ta.shape == tb.shape, f"Image shapes differ: {ta.shape} vs {tb.shape}"
    abs_diff = torch.abs(ta - tb)
    return abs_diff.mean().item(), abs_diff.max().item()


def _extract_single_image(outputs) -> Image.Image:
    first_output = outputs[0]
    assert first_output.final_output_type == "image"
    if not hasattr(first_output, "request_output") or not first_output.request_output:
        raise ValueError("No request_output found in OmniRequestOutput")
    req_out = first_output.request_output[0]
    if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"):
        raise ValueError("Invalid request_output structure or missing 'images' key")
    images = req_out.images
    if images is None or len(images) != 1:
        raise ValueError(f"Expected 1 image, got {0 if images is None else len(images)}")
    return images[0]


def _run_zimage_generate(
    *, tp_size: int, height: int, width: int, num_inference_steps: int, seed: int
) -> tuple[Image.Image, float, float]:
    torch.cuda.empty_cache()
    device_index = torch.cuda.current_device()
    monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02)
    monitor.start()
    m = Omni(
        model=_get_zimage_model(),
        parallel_config=DiffusionParallelConfig(tensor_parallel_size=tp_size),
    )
    try:
        # NOTE: Omni closes itself when a generate() call is exhausted.
        # To avoid measuring teardown time (process shutdown, memory cleanup),
        # we measure the latency to produce *subsequent* outputs within a single
        # generator run.
        #
        # This also serves as a warmup: the first output may include extra
        # compilation/caching overhead, while later outputs are closer to
        # steady-state inference.
        num_requests = 4  # 1 warmup + 3 timed
        gen = m.generate(
            [PROMPT] * num_requests,
            OmniDiffusionSamplingParams(
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=0.0,
                seed=seed,
                num_outputs_per_prompt=1,
            ),
            py_generator=True,
        )
        warmup_output = next(gen)
        t_prev = time.perf_counter()
        per_request_times_s: list[float] = []
        last_output = warmup_output
        for _ in range(num_requests - 1):
            last_output = next(gen)
            t_now = time.perf_counter()
            per_request_times_s.append(t_now - t_prev)
            t_prev = t_now
        # Ensure the generator is fully consumed so it can clean up.
        for _ in gen:
            pass
        median_time_s = float(np.median(per_request_times_s))
        peak_memory_mb = monitor.peak_used_mb
        return _extract_single_image([last_output]), median_time_s, peak_memory_mb
    finally:
        monitor.stop()
        cleanup_dist_env_and_memory()


@pytest.mark.integration
def test_zimage_tensor_parallel_tp2(tmp_path: Path):
    if current_omni_platform.is_npu() or current_omni_platform.is_rocm():
        pytest.skip("Z-Image TP e2e test is only supported on CUDA for now.")
    if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
        pytest.skip("Z-Image TP=2 requires >= 2 CUDA devices.")
    height = 512
    width = 512
    num_inference_steps = 2
    seed = 42
    tp1_img, tp1_time_s, tp1_peak_mem = _run_zimage_generate(
        tp_size=1,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        seed=seed,
    )
    tp2_img, tp2_time_s, tp2_peak_mem = _run_zimage_generate(
        tp_size=2,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        seed=seed,
    )
    tp1_path = tmp_path / "zimage_tp1.png"
    tp2_path = tmp_path / "zimage_tp2.png"
    tp1_img.save(tp1_path)
    tp2_img.save(tp2_path)
    assert tp1_img.width == width and tp1_img.height == height
    assert tp2_img.width == width and tp2_img.height == height
    mean_abs_diff, max_abs_diff = _diff_metrics(tp1_img, tp2_img)
    mean_threshold = 3e-2
    max_threshold = 5e-1
    print(
        "Z-Image TP image diff stats (TP=1 vs TP=2): "
        f"mean_abs_diff={mean_abs_diff:.6e}, max_abs_diff={max_abs_diff:.6e}; "
        f"thresholds: mean<={mean_threshold:.6e}, max<={max_threshold:.6e}; "
        f"tp1_img={tp1_path}, tp2_img={tp2_path}"
    )
    assert mean_abs_diff <= mean_threshold and max_abs_diff <= max_threshold, (
        f"Image diff exceeded threshold: mean_abs_diff={mean_abs_diff:.6e}, max_abs_diff={max_abs_diff:.6e} "
        f"(thresholds: mean<={mean_threshold:.6e}, max<={max_threshold:.6e})"
    )
    print(f"Z-Image TP perf (lower is better): tp1_time_s={tp1_time_s:.6f}, tp2_time_s={tp2_time_s:.6f}")
    assert tp2_time_s < tp1_time_s, f"Expected TP=2 to be faster than TP=1 (tp1={tp1_time_s}, tp2={tp2_time_s})"
    print(f"Z-Image TP peak memory (MB): tp1_peak_mem={tp1_peak_mem:.2f}, tp2_peak_mem={tp2_peak_mem:.2f}")
    assert tp2_peak_mem < tp1_peak_mem, (
        f"Expected TP=2 to use less peak memory than TP=1 (tp1={tp1_peak_mem}, tp2={tp2_peak_mem})"
    )
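The warmup-then-median timing pattern in _run_zimage_generate is generic. A minimal sketch of it in isolation (hypothetical helper, not part of this commit), which discards the first, compilation-heavy output and aggregates the steady-state inter-output gaps with a median to resist outliers:

import time
from statistics import median

def timed_iter(gen, warmup: int = 1) -> tuple[list, float]:
    """Drain a generator, timing the gap after each item past `warmup`."""
    outputs, times = [], []
    t_prev = None
    for i, item in enumerate(gen):
        now = time.perf_counter()
        if i >= warmup and t_prev is not None:
            times.append(now - t_prev)  # steady-state gap only
        t_prev = now
        outputs.append(item)
    return outputs, median(times) if times else float("nan")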
tests/e2e/offline_inference/utils.py (new file, mode 100644)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import functools
import os
import signal
import subprocess
import sys
import tempfile
from collections.abc import Callable
from contextlib import ExitStack, suppress
from pathlib import Path
from typing import Any, Literal

import cloudpickle
from typing_extensions import ParamSpec
from vllm.platforms import current_platform

VLLM_PATH = Path(__file__).parent.parent.parent
"""Path to the root of the vllm-omni repository."""

_P = ParamSpec("_P")


def fork_new_process_for_each_test(func: Callable[_P, None]) -> Callable[_P, None]:
    """Decorator to fork a new process for each test function.

    See https://github.com/vllm-project/vllm/issues/7053 for more details.
    """

    @functools.wraps(func)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
        # Make the process the leader of its own process group
        # to avoid sending SIGTERM to the parent process
        os.setpgrp()
        from _pytest.outcomes import Skipped

        # Create a unique temporary file to store exception info from child
        # process. Use test function name and process ID to avoid collisions.
        with (
            tempfile.NamedTemporaryFile(
                delete=False,
                mode="w+b",
                prefix=f"vllm_test_{func.__name__}_{os.getpid()}_",
                suffix=".exc",
            ) as exc_file,
            ExitStack() as delete_after,
        ):
            exc_file_path = exc_file.name
            delete_after.callback(os.remove, exc_file_path)
            pid = os.fork()
            print(f"Fork a new process to run a test {pid}")
            if pid == 0:
                # Parent process is responsible for deleting; don't delete
                # in the child.
                delete_after.pop_all()
                try:
                    func(*args, **kwargs)
                except Skipped as e:
                    # convert Skipped to exit code 0
                    print(str(e))
                    os._exit(0)
                except Exception as e:
                    import traceback

                    tb_string = traceback.format_exc()
                    # Try to serialize the exception object first
                    exc_to_serialize: dict[str, Any]
                    try:
                        # First, try to pickle the actual exception with
                        # its traceback.
                        exc_to_serialize = {"pickled_exception": e}
                        # Test if it can be pickled
                        cloudpickle.dumps(exc_to_serialize)
                    except (Exception, KeyboardInterrupt):
                        # Fall back to string-based approach.
                        exc_to_serialize = {
                            "exception_type": type(e).__name__,
                            "exception_msg": str(e),
                            "traceback": tb_string,
                        }
                    try:
                        with open(exc_file_path, "wb") as f:
                            cloudpickle.dump(exc_to_serialize, f)
                    except Exception:
                        # Fallback: just print the traceback.
                        print(tb_string)
                    os._exit(1)
                else:
                    os._exit(0)
            else:
                pgid = os.getpgid(pid)
                _pid, _exitcode = os.waitpid(pid, 0)
                # ignore SIGTERM signal itself
                old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN)
                # kill all child processes
                os.killpg(pgid, signal.SIGTERM)
                # restore the signal handler
                signal.signal(signal.SIGTERM, old_signal_handler)
                if _exitcode != 0:
                    # Try to read the exception from the child process
                    exc_info = {}
                    if os.path.exists(exc_file_path):
                        with (
                            contextlib.suppress(Exception),
                            open(exc_file_path, "rb") as f,
                        ):
                            exc_info = cloudpickle.load(f)
                    original_exception = exc_info.get("pickled_exception")
                    if original_exception is not None and isinstance(original_exception, Exception):
                        # Re-raise the actual exception object if it was
                        # successfully pickled.
                        raise original_exception
                    if (original_tb := exc_info.get("traceback")) is not None:
                        # Use string-based traceback for fallback case
                        raise AssertionError(
                            f"Test {func.__name__} failed when called with"
                            f" args {args} and kwargs {kwargs}"
                            f" (exit code: {_exitcode}):\n{original_tb}"
                        ) from None
                    # Fallback to the original generic error
                    raise AssertionError(
                        f"function {func.__name__} failed when called with"
                        f" args {args} and kwargs {kwargs}"
                        f" (exit code: {_exitcode})"
                    ) from None

    return wrapper


def spawn_new_process_for_each_test(f: Callable[_P, None]) -> Callable[_P, None]:
    """Decorator to spawn a new process for each test function."""

    @functools.wraps(f)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
        # Check if we're already in a subprocess
        if os.environ.get("RUNNING_IN_SUBPROCESS") == "1":
            # If we are, just run the function directly
            return f(*args, **kwargs)

        import torch.multiprocessing as mp

        with suppress(RuntimeError):
            mp.set_start_method("spawn")

        # Get the module
        module_name = f.__module__

        # Create a process with environment variable set
        env = os.environ.copy()
        env["RUNNING_IN_SUBPROCESS"] = "1"

        with tempfile.TemporaryDirectory() as tempdir:
            output_filepath = os.path.join(tempdir, "new_process.tmp")

            # `cloudpickle` allows pickling complex functions directly
            input_bytes = cloudpickle.dumps((f, output_filepath))

            repo_root = str(VLLM_PATH.resolve())
            env = dict(env or os.environ)
            env["PYTHONPATH"] = repo_root + os.pathsep + env.get("PYTHONPATH", "")
            cmd = [sys.executable, "-m", f"{module_name}"]
            returned = subprocess.run(cmd, input=input_bytes, capture_output=True, env=env)

            # check if the subprocess is successful
            try:
                returned.check_returncode()
            except Exception as e:
                # wrap raised exception to provide more information
                raise RuntimeError(f"Error raised in subprocess:\n{returned.stderr.decode()}") from e

    return wrapper


def create_new_process_for_each_test(
    method: Literal["spawn", "fork"] | None = None,
) -> Callable[[Callable[_P, None]], Callable[_P, None]]:
    """Creates a decorator that runs each test function in a new process.

    Args:
        method: The process creation method. Can be either "spawn" or "fork".
            If not specified, it defaults to "spawn" on XPU platforms and
            "fork" otherwise (including ROCm; see the TODO below).

    Returns:
        A decorator to run test functions in separate processes.
    """
    if method is None:
        # TODO: Find out why spawn is not working correctly on ROCm.
        # The test content would not run, and tests passed immediately.
        # For now, ROCm uses `fork`, which runs the tests correctly.
        use_spawn = current_platform.is_xpu()
        method = "spawn" if use_spawn else "fork"

    assert method in ["spawn", "fork"], "Method must be either 'spawn' or 'fork'"

    if method == "fork":
        return fork_new_process_for_each_test

    return spawn_new_process_for_each_test
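A usage sketch for the factory above (hypothetical test body; the fork/spawn choice is resolved per platform as documented):

from tests.e2e.offline_inference.utils import create_new_process_for_each_test

@create_new_process_for_each_test()  # fork by default, spawn on XPU
def test_isolated_gpu_state():
    # Runs in a fresh process, so CUDA context leaks or sys.path edits from
    # other tests cannot affect it; a non-zero child exit re-raises here.
    assert 1 + 1 == 2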
tests/e2e/online_serving/__init__.py (new file, mode 100644, empty)
tests/e2e/online_serving/stage_configs/qwen3_omni_ci.yaml (new file, mode 100644)
# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
# Stage 0: Thinker (multimodal understanding + text generation)
# Stage 1: Talker (text embeddings → 16-layer RVQ codec codes)
# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
# The following config has been verified on 2x H100-80G GPUs.
stage_args:
  - stage_id: 0
    stage_type: llm  # Use llm stage type to launch OmniLLM
    runtime:
      devices: "0"
      max_batch_size: 5
    engine_args:
      model_stage: thinker
      model_arch: Qwen3OmniMoeForConditionalGeneration
      worker_type: ar
      scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
      gpu_memory_utilization: 0.9
      enforce_eager: false
      trust_remote_code: true
      engine_output_type: latent  # Output hidden states for talker
      distributed_executor_backend: "mp"
      enable_prefix_caching: false
      max_num_batched_tokens: 32768
      hf_config_name: thinker_config
      tensor_parallel_size: 1
      load_format: dummy
      final_output: true
      final_output_type: text
      is_comprehension: true
    default_sampling_params:
      temperature: 0.4
      top_p: 0.9
      top_k: 1
      max_tokens: 100
      seed: 42
      detokenize: True
      repetition_penalty: 1.05
  - stage_id: 1
    stage_type: llm  # Use llm stage type to launch OmniLLM
    runtime:
      devices: "1"
      max_batch_size: 5
    engine_args:
      model_stage: talker
      model_arch: Qwen3OmniMoeForConditionalGeneration
      worker_type: ar
      scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
      gpu_memory_utilization: 0.6
      enforce_eager: false
      trust_remote_code: true
      engine_output_type: latent  # Output codec codes for code2wav
      # tensor_parallel_size: 2
      enable_prefix_caching: false
      distributed_executor_backend: "mp"
      hf_config_name: talker_config
      load_format: dummy
      engine_input_source: [0]
      custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
      # final_output: true
      # final_output_type: text
    default_sampling_params:
      temperature: 0.9
      top_k: 50
      max_tokens: 100
      seed: 42
      detokenize: False
      repetition_penalty: 1.05
      stop_token_ids: [2150]
  - stage_id: 2
    stage_type: llm  # Use llm stage type to launch OmniLLM
    runtime:
      devices: "1"
      max_batch_size: 1
    engine_args:
      model_stage: code2wav
      model_arch: Qwen3OmniMoeForConditionalGeneration
      worker_type: generation
      scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
      enforce_eager: true
      trust_remote_code: true
      async_scheduling: false
      enable_prefix_caching: false
      engine_output_type: audio  # Final output: audio waveform
      gpu_memory_utilization: 0.1
      distributed_executor_backend: "mp"
      max_num_batched_tokens: 1000000
      hf_config_name: thinker_config
      load_format: dummy
      engine_input_source: [1]
      custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
      final_output: true
      final_output_type: audio
    default_sampling_params:
      temperature: 0.0
      top_p: 1.0
      top_k: -1
      max_tokens: 200
      seed: 42
      detokenize: True
      repetition_penalty: 1.1
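To see how a stage config like this is consumed, here is a minimal launch sketch mirroring the subprocess command that the OmniServer test helper (in the online_serving tests below) builds; paths and timeout values are taken from those tests, not from any documented CLI contract:

import subprocess
import sys

cmd = [
    sys.executable, "-m", "vllm_omni.entrypoints.cli.main",
    "serve", "Qwen/Qwen3-Omni-30B-A3B-Instruct",
    "--omni",
    "--stage-configs-path", "tests/e2e/online_serving/stage_configs/qwen3_omni_ci.yaml",
    "--stage-init-timeout", "120",
]
# The CI tests then poll the host:port until the server accepts connections.
proc = subprocess.Popen(cmd, start_new_session=True)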
tests/e2e/online_serving/stage_configs/qwen3_omni_thinker_ci.yaml (new file, mode 100644)
# The following config has been verified on 2x H100-80G GPUs.
stage_args:
  - stage_id: 0
    runtime:
      devices: "0,1"
      max_batch_size: 5
    engine_args:
      model_stage: thinker
      model_arch: Qwen3OmniMoeForConditionalGeneration
      worker_type: ar
      scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
      gpu_memory_utilization: 0.6
      enforce_eager: true
      trust_remote_code: true
      engine_output_type: latent  # Output hidden states for talker
      distributed_executor_backend: "mp"
      enable_prefix_caching: false
      hf_config_name: thinker_config
      tensor_parallel_size: 2
      load_format: dummy
      final_output: true
      final_output_type: text
      is_comprehension: true
    default_sampling_params:
      temperature: 0.4
      top_p: 0.9
      top_k: 1
      max_tokens: 100
      seed: 42
      detokenize: True
      repetition_penalty: 1.05
tests/e2e/online_serving/stage_configs/rocm/qwen3_omni_ci.yaml (new file, mode 100644)
# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
# Stage 0: Thinker (multimodal understanding + text generation)
# Stage 1: Talker (text embeddings → 16-layer RVQ codec codes)
# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
# The following config has been verified on AMD MI325 GPUs.
stage_args:
  - stage_id: 0
    runtime:
      devices: "0"
      max_batch_size: 5
    engine_args:
      model_stage: thinker
      model_arch: Qwen3OmniMoeForConditionalGeneration
      worker_type: ar
      scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
      gpu_memory_utilization: 0.9
      enforce_eager: false
      trust_remote_code: true
      engine_output_type: latent  # Output hidden states for talker
      distributed_executor_backend: "mp"
      enable_prefix_caching: false
      hf_config_name: thinker_config
      tensor_parallel_size: 1
      final_output: true
      final_output_type: text
      is_comprehension: true
    default_sampling_params:
      temperature: 0.4
      top_p: 0.9
      top_k: 1
      max_tokens: 100
      seed: 42
      detokenize: True
      repetition_penalty: 1.05
  - stage_id: 1
    runtime:
      devices: "1"
      max_batch_size: 5
    engine_args:
      model_stage: talker
      model_arch: Qwen3OmniMoeForConditionalGeneration
      worker_type: ar
      scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
      gpu_memory_utilization: 0.6
      enforce_eager: true
      trust_remote_code: true
      engine_output_type: latent  # Output codec codes for code2wav
      # tensor_parallel_size: 2
      enable_prefix_caching: false
      distributed_executor_backend: "mp"
      hf_config_name: talker_config
      engine_input_source: [0]
      custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
      # final_output: true
      # final_output_type: text
    default_sampling_params:
      temperature: 0.9
      top_k: 50
      max_tokens: 1000
      seed: 42
      detokenize: False
      repetition_penalty: 1.05
      stop_token_ids: [2150]
  - stage_id: 2
    runtime:
      devices: "1"
      max_batch_size: 1
    engine_args:
      model_stage: code2wav
      model_arch: Qwen3OmniMoeForConditionalGeneration
      worker_type: generation
      scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
      enforce_eager: true
      trust_remote_code: true
      enable_prefix_caching: false
      engine_output_type: audio  # Final output: audio waveform
      gpu_memory_utilization: 0.1
      distributed_executor_backend: "mp"
      max_num_batched_tokens: 1000000
      hf_config_name: thinker_config
      async_scheduling: false
      engine_input_source: [1]
      custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
      final_output: true
      final_output_type: audio
    default_sampling_params:
      temperature: 0.0
      top_p: 1.0
      top_k: -1
      max_tokens: 2000
      seed: 42
      detokenize: True
      repetition_penalty: 1.1
tests/e2e/online_serving/test_async_omni.py (new file, mode 100644)
import asyncio
import os
import sys
from contextlib import ExitStack
from pathlib import Path

import pytest
from vllm import SamplingParams
from vllm.inputs import PromptType
from vllm_omni.entrypoints.async_omni import AsyncOmni, ClientRequestState

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

SEED = 42
stage_config = str(Path(__file__).parent / "stage_configs" / "qwen3_omni_thinker_ci.yaml")
model = "Qwen/Qwen3-Omni-30B-A3B-Instruct"


async def generate(
    engine: AsyncOmni,
    request_id: str,
    prompt: PromptType,
    max_tokens: int,
) -> tuple[int, str]:
    # Ensure generate doesn't complete too fast for the cancellation test.
    await asyncio.sleep(0.2)
    thinker_sampling_params = SamplingParams(
        temperature=0.4,  # Deterministic
        top_p=0.9,
        top_k=1,
        max_tokens=max_tokens,
        repetition_penalty=1.05,
        stop_token_ids=[151645],  # Qwen EOS token <|im_end|>
        seed=SEED,
    )
    sampling_params_list = [
        thinker_sampling_params,
    ]
    count = 0
    async for omni_output in engine.generate(
        prompt=prompt,
        request_id=request_id,
        sampling_params_list=sampling_params_list,
        output_modalities=["text"],
    ):
        stage_id = omni_output.stage_id
        out = omni_output.request_output
        if stage_id == 0:
            num_tokens = sum(len(output.token_ids) for output in out.outputs)
            count = num_tokens
        await asyncio.sleep(0.0)
    return count, request_id


@pytest.mark.asyncio
async def test_abort():
    with ExitStack() as after:
        # Avoid SHM IPC in tests to prevent /dev/shm exhaustion and SIGBUS.
        engine = AsyncOmni(
            model=model,
            stage_configs_path=stage_config,
            shm_threshold_bytes=sys.maxsize,
        )
        after.callback(engine.shutdown)

        # Keep token counts modest to reduce flakiness on slow test hardware.
        NUM_REQUESTS = 3
        NUM_EXPECTED_TOKENS = 64
        NUM_EXPECTED_TOKENS_LONG = 256
        REQUEST_IDS_TO_ABORT = [1]

        prompt = "Hello my name is Robert and "
        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks: list[asyncio.Task] = []
        for idx, request_id in enumerate(request_ids):
            max_tokens = NUM_EXPECTED_TOKENS_LONG if (idx in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS
            tasks.append(asyncio.create_task(generate(engine, request_id, prompt, max_tokens)))

        # API server cancels requests when they disconnect.
        # Explicitly abort in the engine to avoid orphaned requests hanging.
        for idx in REQUEST_IDS_TO_ABORT:
            tasks[idx].cancel()
            await engine.abort(request_ids[idx])
            await asyncio.sleep(0.1)

        # Confirm the other requests are okay.
        for idx, task in enumerate(tasks):
            # Confirm that it was actually canceled.
            if idx in REQUEST_IDS_TO_ABORT:
                with pytest.raises((asyncio.CancelledError, GeneratorExit)):
                    await asyncio.wait_for(task, timeout=60)
            else:
                # Otherwise, make sure the request was not impacted.
                num_generated_tokens, request_id = await asyncio.wait_for(task, timeout=180)
                expected_tokens = NUM_EXPECTED_TOKENS
                assert num_generated_tokens == expected_tokens, (
                    f"{request_id} generated {num_generated_tokens} but expected {expected_tokens}"
                )

        # Confirm we can do another generation.
        request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}"
        task = asyncio.create_task(generate(engine, request_id, prompt, NUM_EXPECTED_TOKENS))
        num_generated_tokens, request_id = await task
        assert num_generated_tokens == NUM_EXPECTED_TOKENS
        await asyncio.sleep(5)


@pytest.mark.asyncio
async def test_build_and_log_summary(monkeypatch):
    from vllm_omni.entrypoints.utils import get_final_stage_id_for_e2e

    RealCRS = ClientRequestState
    capture_metrics = {}

    class MockCRS(RealCRS):
        def __init__(self, request_id: str):
            super().__init__(request_id)
            capture_metrics[request_id] = self

    monkeypatch.setattr("vllm_omni.entrypoints.async_omni.ClientRequestState", MockCRS)
    monkeypatch.setattr("vllm_omni.entrypoints.client_request_state.ClientRequestState", MockCRS)

    with ExitStack() as after:
        # Avoid SHM IPC in tests to prevent /dev/shm exhaustion and SIGBUS.
        engine = AsyncOmni(
            model=model,
            stage_configs_path=stage_config,
            shm_threshold_bytes=sys.maxsize,
        )
        after.callback(engine.shutdown)

        prompt = "Hello my name is Robert and "
        NUM_EXPECTED_TOKENS = 64
        NUM_REQUESTS = 3
        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks: list[asyncio.Task] = []
        for idx, request_id in enumerate(request_ids):
            tasks.append(asyncio.create_task(generate(engine, request_id, prompt, NUM_EXPECTED_TOKENS)))

        # Confirm the requests are okay.
        for idx, task in enumerate(tasks):
            await task
            output_modalities = ["text"]
            final_stage_id_for_e2e = get_final_stage_id_for_e2e(
                output_modalities, engine.output_modalities, engine.stage_list
            )
            summary = capture_metrics[request_ids[idx]].metrics.build_and_log_summary(final_stage_id_for_e2e)
            # Check that total tokens matches the sum of stage tokens.
            assert summary["e2e_total_tokens"] == sum(stage["tokens"] for stage in summary["stages"])
            # Check that total time is at least the sum of stage times.
            assert summary["e2e_total_time_ms"] >= sum(stage["total_time_ms"] for stage in summary["stages"])
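A minimal sketch of driving the generate helper above outside pytest. It assumes the module-level model, stage_config, and generate defined in this file, and that a bare AsyncOmni can be constructed and shut down the same way the tests do:

import asyncio

async def main() -> None:
    # Assumes the surrounding module's model / stage_config / generate.
    engine = AsyncOmni(model=model, stage_configs_path=stage_config)
    try:
        count, rid = await generate(engine, "request-0", "Hello my name is Robert and ", 64)
        print(f"{rid} produced {count} thinker tokens")
    finally:
        engine.shutdown()

if __name__ == "__main__":
    asyncio.run(main())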
tests/e2e/online_serving/test_image_gen_edit.py (new file, mode 100644)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
E2E online serving test for Qwen-Image-Edit-2509 multi-image input.
"""
import base64
import os
import signal
import socket
import subprocess
import sys
import threading
import time
from io import BytesIO
from typing import Any

import openai
import pytest
import requests
from PIL import Image
from vllm.assets.image import ImageAsset
from vllm.utils.network_utils import get_open_port

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
# Increase timeout for downloading assets from S3 (default 5s is too short for CI)
os.environ.setdefault("VLLM_IMAGE_FETCH_TIMEOUT", "60")

models = ["Qwen/Qwen-Image-Edit-2509"]
test_params = models
t2i_models = ["Tongyi-MAI/Z-Image-Turbo"]


class OmniServer:
    """Omni server helper for vLLM-Omni tests."""

    def __init__(
        self,
        model: str,
        serve_args: list[str],
        *,
        env_dict: dict[str, str] | None = None,
    ) -> None:
        self.model = model
        self.serve_args = serve_args
        self.env_dict = env_dict
        self.proc: subprocess.Popen | None = None
        self.host = "127.0.0.1"
        self.port = get_open_port()

    def _start_server(self) -> None:
        """Start the vLLM-Omni server subprocess."""
        env = os.environ.copy()
        env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
        if self.env_dict is not None:
            env.update(self.env_dict)
        cmd = [
            sys.executable,
            "-m",
            "vllm_omni.entrypoints.cli.main",
            "serve",
            self.model,
            "--omni",
            "--host",
            self.host,
            "--port",
            str(self.port),
        ] + self.serve_args
        print(f"Launching OmniServer with: {' '.join(cmd)}")
        self.proc = subprocess.Popen(
            cmd,
            env=env,
            # Set working directory to the vllm-omni root
            cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
            start_new_session=True,
        )
        # Wait for server to be ready
        max_wait = 600  # 10 minutes
        start_time = time.time()
        while time.time() - start_time < max_wait:
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                    sock.settimeout(1)
                    result = sock.connect_ex((self.host, self.port))
                    if result == 0:
                        print(f"Server ready on {self.host}:{self.port}")
                        return
            except Exception:
                pass
            time.sleep(2)
        raise RuntimeError(f"Server failed to start within {max_wait} seconds")

    def __enter__(self):
        self._start_server()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.proc:
            try:
                os.killpg(self.proc.pid, signal.SIGTERM)
            except ProcessLookupError:
                pass
            try:
                self.proc.wait(timeout=30)
            except subprocess.TimeoutExpired:
                try:
                    os.killpg(self.proc.pid, signal.SIGKILL)
                except ProcessLookupError:
                    pass
                self.proc.wait()


@pytest.fixture
def omni_server(request):
    """Start vLLM-Omni server as a subprocess with actual model weights."""
    model = request.param
    with OmniServer(model, ["--num-gpus", "1"]) as server:
        yield server


@pytest.fixture
def client(omni_server):
    """OpenAI client for the running vLLM-Omni server."""
    return openai.OpenAI(
        base_url=f"http://{omni_server.host}:{omni_server.port}/v1",
        api_key="EMPTY",
    )


@pytest.fixture(scope="session")
def base64_encoded_images() -> list[str]:
    """Base64 encoded PNG images for testing."""
    images = [
        ImageAsset("cherry_blossom").pil_image.convert("RGB"),
        ImageAsset("stop_sign").pil_image.convert("RGB"),
    ]
    encoded: list[str] = []
    for img in images:
        with BytesIO() as buffer:
            img.save(buffer, format="PNG")
            encoded.append(base64.b64encode(buffer.getvalue()).decode("utf-8"))
    return encoded


def dummy_messages_from_image_data(
    image_data_urls: list[str],
    content_text: str = "Combine these two images into one scene.",
):
    """Create messages with image data URLs for OpenAI API."""
    content = [{"type": "text", "text": content_text}]
    for image_url in image_data_urls:
        content.append({"type": "image_url", "image_url": {"url": image_url}})
    return [{"role": "user", "content": content}]


def _extract_image_data_url(message_content) -> str:
    assert isinstance(message_content, list) and len(message_content) >= 1
    content_part = message_content[0]
    if isinstance(content_part, dict):
        image_url = content_part.get("image_url", {}).get("url", "")
    else:
        image_url_obj = getattr(content_part, "image_url", None)
        if isinstance(image_url_obj, dict):
            image_url = image_url_obj.get("url", "")
        else:
            image_url = getattr(image_url_obj, "url", "")
    assert isinstance(image_url, str) and image_url
    return image_url


def _decode_data_url_to_image_bytes(data_url: str) -> bytes:
    assert data_url.startswith("data:image")
    _, b64_data = data_url.split(",", 1)
    return base64.b64decode(b64_data)


@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_i2i_multi_image_input_qwen_image_edit_2509(
    omni_server,
    base64_encoded_images: list[str],
) -> None:
    """Test multi-image input editing via OpenAI API with concurrent requests."""
    image_data_urls = [f"data:image/png;base64,{img}" for img in base64_encoded_images]
    messages = dummy_messages_from_image_data(image_data_urls)

    barrier = threading.Barrier(2)
    results: list[tuple[int, int]] = []

    def _call_chat(width: int, height: int) -> None:
        client = openai.OpenAI(
            base_url=f"http://{omni_server.host}:{omni_server.port}/v1",
            api_key="EMPTY",
        )
        barrier.wait()
        chat_completion = client.chat.completions.create(
            model=omni_server.model,
            messages=messages,
            extra_body={
                "height": height,
                "width": width,
                "num_inference_steps": 2,
                "guidance_scale": 0.0,
                "seed": 42,
            },
        )
        assert len(chat_completion.choices) == 1
        choice = chat_completion.choices[0]
        assert choice.finish_reason == "stop"
        assert choice.message.role == "assistant"
        image_data_url = _extract_image_data_url(choice.message.content)
        image_bytes = _decode_data_url_to_image_bytes(image_data_url)
        img = Image.open(BytesIO(image_bytes))
        img.load()
        results.append(img.size)

    threads = [
        threading.Thread(target=_call_chat, args=(1248, 832)),
        threading.Thread(target=_call_chat, args=(1024, 768)),
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    # TODO @ZJY
    # assert (1248, 832) in results
    # assert (1024, 768) in results


@pytest.mark.parametrize("omni_server", t2i_models, indirect=True)
def test_t2i_concurrent_requests_different_sizes(omni_server) -> None:
    """Test /v1/images/generations concurrent requests with different sizes."""
    base_url = f"http://{omni_server.host}:{omni_server.port}"
    url = f"{base_url}/v1/images/generations"

    barrier = threading.Barrier(2)
    results: list[tuple[int, int]] = []

    def _call_generate(size: str) -> None:
        payload: dict[str, Any] = {
            "prompt": "cute cat playing with a ball",
            "n": 1,
            "size": size,
            "response_format": "b64_json",
            "num_inference_steps": 2,
        }
        barrier.wait()
        response = requests.post(url, json=payload, timeout=120)
        assert response.status_code == 200
        data = response.json()
        image_b64 = data["data"][0]["b64_json"]
        image_bytes = base64.b64decode(image_b64)
        img = Image.open(BytesIO(image_bytes))
        img.load()
        results.append(img.size)

    threads = [
        threading.Thread(target=_call_generate, args=("512x512",)),
        threading.Thread(target=_call_generate, args=("768x512",)),
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    assert (512, 512) in results
    assert (768, 512) in results
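For reference, the inverse of _decode_data_url_to_image_bytes above: building a request-side data URL from a local image, matching the data:image/png;base64,... format these tests send. A minimal sketch assuming a PNG file on disk:

import base64
from pathlib import Path

def image_file_to_data_url(path: str) -> str:
    """Encode a local PNG as the data URL format used in the tests above."""
    b64 = base64.b64encode(Path(path).read_bytes()).decode("utf-8")
    return f"data:image/png;base64,{b64}"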
tests/e2e/online_serving/test_images_generations_lora.py (new file, mode 100644)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
E2E online serving test for /v1/images/generations with per-request LoRA.

This validates:
- The API server accepts a per-request `lora` object in the Images API payload.
- LoRA can be switched per request (adapter A -> adapter B -> no LoRA).
- Output correctness is asserted using a small image slice with tolerance.
"""
import base64
import json
import os
from io import BytesIO
from pathlib import Path

import numpy as np
import pytest
import requests
import torch
from PIL import Image
from safetensors.torch import save_file
from tests.conftest import OmniServer

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

MODEL = "Tongyi-MAI/Z-Image-Turbo"
PROMPT = "a photo of a cat sitting on a laptop keyboard"
SIZE = "256x256"
SEED = 42


@pytest.fixture(scope="module")
def omni_server():
    with OmniServer(MODEL, ["--num-gpus", "1"]) as server:
        yield server


def _write_zimage_lora(
    adapter_dir: Path,
    *,
    q_scale: float = 0.0,
    k_scale: float = 0.0,
    v_scale: float = 0.0,
):
    adapter_dir.mkdir(parents=True, exist_ok=True)
    # Z-Image transformer uses dim=3840 by default.
    dim = 3840
    module_name = "transformer.layers.0.attention.to_qkv"
    rank = 1
    lora_a = torch.zeros((rank, dim), dtype=torch.float32)
    lora_a[0, 0] = 1.0
    # QKVParallelLinear packs (Q, K, V) => out dim is 3 * dim (tp=1).
    lora_b = torch.zeros((3 * dim, rank), dtype=torch.float32)
    if q_scale:
        lora_b[:dim, 0] = q_scale
    if k_scale:
        lora_b[dim : 2 * dim, 0] = k_scale
    if v_scale:
        lora_b[2 * dim :, 0] = v_scale
    save_file(
        {
            f"base_model.model.{module_name}.lora_A.weight": lora_a,
            f"base_model.model.{module_name}.lora_B.weight": lora_b,
        },
        str(adapter_dir / "adapter_model.safetensors"),
    )
    (adapter_dir / "adapter_config.json").write_text(
        json.dumps(
            {
                "r": rank,
                "lora_alpha": rank,
                "target_modules": [module_name],
            }
        ),
        encoding="utf-8",
    )


def _post_images(server: OmniServer, payload: dict) -> Image.Image:
    url = f"http://{server.host}:{server.port}/v1/images/generations"
    resp = requests.post(url, json=payload, headers={"Authorization": "Bearer EMPTY"}, timeout=900)
    resp.raise_for_status()
    data = resp.json()
    b64 = data["data"][0]["b64_json"]
    img_bytes = base64.b64decode(b64)
    img = Image.open(BytesIO(img_bytes))
    img.load()
    return img.convert("RGB")


def _image_blue_tail_slice(img: Image.Image) -> np.ndarray:
    arr = np.asarray(img, dtype=np.uint8)
    assert arr.ndim == 3 and arr.shape[-1] == 3
    tail = arr[-3:, -3:, -1].astype(np.float32)
    assert tail.shape == (3, 3)
    return tail


def _slice_diff_stats(actual: np.ndarray, expected: np.ndarray) -> tuple[float, float]:
    diff = np.abs(actual - expected)
    return float(diff.max()), float(diff.mean())


def _assert_slice_close(
    actual: np.ndarray,
    expected: np.ndarray,
    *,
    label: str,
    base_max: float,
    base_mean: float,
) -> None:
    assert actual.shape == (3, 3)
    assert expected.shape == (3, 3)
    max_diff, mean_diff = _slice_diff_stats(actual, expected)
    # NOTE: Different attention backends / torch.compile can introduce small
    # floating-point drift that shows up as a few LSBs in uint8 pixels. Keep
    # the reset check tolerant but bounded to avoid flaky CI.
    max_thresh = max(10.0, base_max + 4.0)
    mean_thresh = max(6.0, base_mean + 4.0)
    assert max_diff <= max_thresh and mean_diff <= mean_thresh, (
        f"{label} slice mismatch (max={max_diff:.1f} > {max_thresh:.1f} or "
        f"mean={mean_diff:.1f} > {mean_thresh:.1f}): {actual.tolist()}"
    )


def _assert_slice_diff(actual: np.ndarray, baseline: np.ndarray, *, label: str) -> None:
    assert actual.shape == (3, 3)
    assert baseline.shape == (3, 3)
    diff = np.abs(actual - baseline).mean()
    assert diff > 0.1, f"{label} slice diff too small: {diff} ({actual.tolist()} vs {baseline.tolist()})"


def _basic_payload() -> dict:
    return {
        "prompt": PROMPT,
        "n": 1,
        "size": SIZE,
        "num_inference_steps": 2,
        "guidance_scale": 0.0,
        "seed": SEED,
    }


def test_images_generations_per_request_lora_switching(omni_server: OmniServer, tmp_path: Path) -> None:
    # Base generation.
    base_img = _post_images(omni_server, _basic_payload())
    base_slice = _image_blue_tail_slice(base_img)
    base_ref_img = _post_images(omni_server, _basic_payload())
    base_ref_slice = _image_blue_tail_slice(base_ref_img)
    base_ref_max, base_ref_mean = _slice_diff_stats(base_ref_slice, base_slice)

    # Adapter A: apply delta to V slice only.
    lora_a_dir = tmp_path / "zimage_lora_a"
    _write_zimage_lora(lora_a_dir, v_scale=8.0)
    payload_a = _basic_payload()
    payload_a["lora"] = {"name": "a", "path": str(lora_a_dir), "scale": 64.0}
    img_a = _post_images(omni_server, payload_a)
    a_slice = _image_blue_tail_slice(img_a)
    _assert_slice_diff(a_slice, base_slice, label="lora_a_vs_base")
    a_vs_base = float(np.abs(a_slice - base_slice).mean())

    # Adapter B: apply delta to K slice only (should differ from adapter A).
    lora_b_dir = tmp_path / "zimage_lora_b"
    _write_zimage_lora(lora_b_dir, k_scale=4.0)
    payload_b = _basic_payload()
    payload_b["lora"] = {"name": "b", "path": str(lora_b_dir), "scale": 64.0}
    img_b = _post_images(omni_server, payload_b)
    b_slice = _image_blue_tail_slice(img_b)
    _assert_slice_diff(b_slice, base_slice, label="lora_b_vs_base")
    _assert_slice_diff(b_slice, a_slice, label="lora_b_vs_lora_a")
    b_vs_base = float(np.abs(b_slice - base_slice).mean())
    b_vs_a = float(np.abs(b_slice - a_slice).mean())

    # Ensure switching back to no-LoRA restores the base output.
    base_img_2 = _post_images(omni_server, _basic_payload())
    base_slice_2 = _image_blue_tail_slice(base_img_2)
    _, base_reset_mean = _slice_diff_stats(base_slice_2, base_slice)
    _assert_slice_close(
        base_slice_2,
        base_slice,
        label="base_after_reset",
        base_max=base_ref_max,
        base_mean=base_ref_mean,
    )

    # Ensure LoRA effects are clearly above the baseline drift.
    min_delta = max(base_reset_mean + 1.0, 1.5)
    assert a_vs_base > min_delta, f"lora_a_vs_base drift too small: {a_vs_base} <= {min_delta}"
    assert b_vs_base > min_delta, f"lora_b_vs_base drift too small: {b_vs_base} <= {min_delta}"
    assert b_vs_a > min_delta, f"lora_b_vs_lora_a drift too small: {b_vs_a} <= {min_delta}"
tests/e2e/online_serving/test_qwen3_omni.py (new file, mode 100644)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
E2E Online tests for Qwen3-Omni model with video input and audio output.
"""
import os

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"

import concurrent.futures
import threading
import time
from pathlib import Path

import openai
import pytest
from tests.conftest import (
    OmniServer,
    convert_audio_to_text,
    cosine_similarity_text,
    dummy_messages_from_mix_data,
    generate_synthetic_audio,
    generate_synthetic_image,
    generate_synthetic_video,
    merge_base64_and_convert_to_text,
    modify_stage_config,
)
from vllm_omni.platforms import current_omni_platform

models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]


def get_default_config():
    return str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_ci.yaml")


def get_chunk_config():
    path = modify_stage_config(
        get_default_config(),
        updates={
            "async_chunk": True,
            "stage_args": {
                0: {
                    "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk"
                },
                1: {
                    "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk"
                },
            },
        },
        deletes={"stage_args": {2: ["custom_process_input_func"]}},
    )
    return path


CHUNK_CONFIG_PATH = get_chunk_config()

# CI stage config for 2x H100-80G GPUs or an AMD MI325 GPU
if current_omni_platform.is_rocm():
    # ROCm stage config optimized for the MI325 GPU
    stage_configs = [str(Path(__file__).parent / "stage_configs" / "rocm" / "qwen3_omni_ci.yaml")]
else:
    stage_configs = [get_default_config(), CHUNK_CONFIG_PATH]

# Create parameter combinations for model and stage config
test_params = [(model, stage_config) for model in models for stage_config in stage_configs]

_omni_server_lock = threading.Lock()


@pytest.fixture(scope="module")
def omni_server(request):
    """Start vLLM-Omni server as a subprocess with actual model weights.

    Uses module scope so the server starts only once per module and config.
    Multi-stage initialization can take 10-20+ minutes.
    """
    with _omni_server_lock:
        model, stage_config_path = request.param
        print(f"Starting OmniServer with model: {model}")
        with OmniServer(model, ["--stage-configs-path", stage_config_path, "--stage-init-timeout", "120"]) as server:
            print("OmniServer started successfully")
            yield server
            print("OmniServer stopping...")
        print("OmniServer stopped")


@pytest.fixture
def client(omni_server):
    """OpenAI client for the running vLLM-Omni server."""
    return openai.OpenAI(
        base_url=f"http://{omni_server.host}:{omni_server.port}/v1",
        api_key="EMPTY",
    )


def get_system_prompt():
    return {
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": (
                    "You are Qwen, a virtual human developed by the Qwen Team, "
                    "Alibaba Group, capable of perceiving auditory and visual inputs, "
                    "as well as generating text and speech."
                ),
            }
        ],
    }


def dummy_messages_from_video_data(
    video_data_url: str,
    content_text: str = "Describe the video briefly.",
):
    """Create messages with video data URL for OpenAI API."""
    return [
        get_system_prompt(),
        {
            "role": "user",
            "content": [
                {"type": "video_url", "video_url": {"url": video_data_url}},
                {"type": "text", "text": content_text},
            ],
        },
    ]


def get_prompt(prompt_type="text_only"):
    prompts = {
        "text_only": "What is the capital of China? Answer in 20 words.",
        "mix": "What is recited in the audio? What is in this image? Describe the video briefly.",
    }
    return prompts.get(prompt_type, prompts["text_only"])


def get_max_batch_size(size_type="few"):
    batch_sizes = {"few": 5, "medium": 100, "large": 256}
    return batch_sizes.get(size_type, 5)


@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_mix_to_text_audio_001(client: openai.OpenAI, omni_server, request) -> None:
    """
    Test multi-modal input processing and text/audio output generation via OpenAI API.

    Deploy Setting: default yaml
    Input Modal: text + audio + video + image
    Output Modal: text + audio
    Input Setting: stream=True
    Datasets: single request
    """
    # Test single completion
    e2e_list = list()
    video_data_url = f"data:video/mp4;base64,{generate_synthetic_video(224, 224, 300)['base64']}"
    image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}"
    audio_data_url = f"data:audio/wav;base64,{generate_synthetic_audio(5, 1)['base64']}"
    messages = dummy_messages_from_mix_data(
        system_prompt=get_system_prompt(),
        video_data_url=video_data_url,
        image_data_url=image_data_url,
        audio_data_url=audio_data_url,
        content_text=get_prompt("mix"),
    )
    start_time = time.perf_counter()
    chat_completion = client.chat.completions.create(model=omni_server.model, messages=messages, stream=True)
    text_content = ""
    audio_data = []
    for chunk in chat_completion:
        for choice in chunk.choices:
            if hasattr(choice, "delta"):
                content = getattr(choice.delta, "content", None)
            else:
                content = None
            modality = getattr(chunk, "modality", None)
            if modality == "audio" and content:
                audio_data.append(content)
            elif modality == "text" and content:
                # Text chunk - accumulate text content
                text_content += content
    # Verify E2E
    current_e2e = time.perf_counter() - start_time
    print(f"the request e2e is: {current_e2e}")
    # TODO: Verify the E2E latency after confirming a baseline.
    e2e_list.append(current_e2e)
    print(f"the avg e2e is: {sum(e2e_list) / len(e2e_list)}")
    # Verify all completions succeeded
    assert audio_data, "No audio output is generated"
    # Verify text output success
    assert text_content is not None and len(text_content) >= 2, "No text output is generated"
    assert any(
        keyword in text_content.lower()
        for keyword in ["square", "quadrate", "sphere", "globe", "circle", "round"]
    ), "The output does not contain any of the keywords."
    # Verify the text output matches the audio output
    audio_content = merge_base64_and_convert_to_text(audio_data)
    print(f"text content is: {text_content}")
    print(f"audio content is: {audio_content}")
    similarity = cosine_similarity_text(audio_content.lower(), text_content.lower())
    print(f"similarity is: {similarity}")
    assert similarity > 0.9, "The audio content is not the same as the text"


@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_text_to_text_audio_001(client: openai.OpenAI, omni_server) -> None:
    """
    Test text input processing and text/audio output generation via OpenAI API.

    Deploy Setting: default yaml
    Input Modal: text
    Output Modal: text + audio
    Datasets: few requests
    """
    num_concurrent_requests = get_max_batch_size()
    messages = dummy_messages_from_mix_data(system_prompt=get_system_prompt(), content_text=get_prompt())
    e2e_list = list()
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_concurrent_requests) as executor:
        # Submit multiple completion requests concurrently
        futures = [
            executor.submit(client.chat.completions.create, model=omni_server.model, messages=messages)
            for _ in range(num_concurrent_requests)
        ]
        start_time = time.perf_counter()
        # Wait for all requests to complete and collect results
        chat_completions = list()
        for future in concurrent.futures.as_completed(futures):
            chat_completions.append(future.result())
        # Verify E2E
        current_e2e = time.perf_counter() - start_time
        print(f"the request e2e is: {current_e2e}")
        # TODO: Verify the E2E latency after confirming a baseline.
        e2e_list.append(current_e2e)
        print(f"the avg e2e is: {sum(e2e_list) / len(e2e_list)}")
    # Verify all completions succeeded
    assert len(chat_completions) == num_concurrent_requests, "Not all requests succeeded."
    for chat_completion in chat_completions:
        # Verify audio output success
        audio_data = None
        text_content = None
        for choice in chat_completion.choices:
            if choice.message.audio is not None:
                audio_message = choice.message
                audio_data = audio_message.audio.data
                assert audio_message.audio.expires_at > time.time(), "The generated audio has expired."
            if choice.message.content is not None:
                # Verify text output success
                text_content = choice.message.content
                assert "beijing" in text_content.lower(), "The output does not contain keywords."
        # Verify the text output matches the audio output
        audio_content = convert_audio_to_text(audio_data)
        print(f"text content is: {text_content}")
        print(f"audio content is: {audio_content}")
        similarity = cosine_similarity_text(audio_content.lower(), text_content.lower())
        print(f"similarity is: {similarity}")
        assert similarity > 0.9, "The audio content is not the same as the text"
tests/e2e/online_serving/test_qwen3_omni_expansion.py (new file, mode 100644)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
E2E Online tests for Qwen3-Omni model.
"""
import concurrent.futures
import os

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

import time
from pathlib import Path

import openai
import pytest

from tests.conftest import (
    OmniServer,
    convert_audio_to_text,
    cosine_similarity_text,
    dummy_messages_from_mix_data,
    generate_synthetic_audio,
    generate_synthetic_image,
    modify_stage_config,
)

models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]

# CI stage config for 2*H100-80G GPUs
stage_configs = [
    str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_ci.yaml")
]

# Create parameter combinations for model and stage config
test_params = [(model, stage_config)
               for model in models
               for stage_config in stage_configs]
def client(omni_server):
    """OpenAI client for the running vLLM-Omni server."""
    return openai.OpenAI(
        base_url=f"http://{omni_server.host}:{omni_server.port}/v1",
        api_key="EMPTY",
    )


def get_system_prompt():
    return {
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": (
                    "You are Qwen, a virtual human developed by the Qwen Team, "
                    "Alibaba Group, capable of perceiving auditory and visual inputs, "
                    "as well as generating text and speech."
                ),
            }
        ],
    }


def get_prompt(prompt_type="text_only"):
    prompts = {
        "text_only": "What is the capital of China?",
        "mix": "What is recited in the audio? What is in this image? Describe the video briefly.",
    }
    return prompts.get(prompt_type, prompts["text_only"])


def get_max_batch_size(size_type="few"):
    batch_sizes = {"few": 5, "medium": 100, "large": 256}
    return batch_sizes.get(size_type, 5)


def get_deploy_config(deploy_type="TP1"):
    result = {
        "TP1": {
            "stage_args": {
                0: {
                    "engine_args.gpu_memory_utilization": 0.95,
                    "engine_args.tensor_parallel_size": 1,
                    "runtime.devices": "0",
                },
                2: {"runtime.devices": "1"},
            }
        }
    }
    return result.get(deploy_type, result["TP1"])
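The dotted keys produced by get_deploy_config (e.g. "engine_args.tensor_parallel_size") are consumed by modify_stage_config, which is imported from tests/conftest and not shown in this commit view. A minimal sketch of how such overrides could be merged into the per-stage YAML dicts, with the merge behavior assumed for illustration:

import yaml


def apply_dotted_overrides(config_path: str, overrides: dict) -> dict:
    # Illustrative only: walks each "a.b.c"-style key down into the
    # corresponding stage dict and sets the leaf value.
    with open(config_path) as f:
        config = yaml.safe_load(f)
    for stage_idx, stage_overrides in overrides.get("stage_args", {}).items():
        stage = config["stage_args"][stage_idx]
        for dotted_key, value in stage_overrides.items():
            node = stage
            *parents, leaf = dotted_key.split(".")
            for part in parents:
                node = node.setdefault(part, {})
            node[leaf] = value
    return config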
@pytest.mark.parametrize("test_config", test_params)
def test_text_to_text_001(test_config: tuple[str, str]) -> None:
    """Test processing text, generating text output via OpenAI API."""
    model, stage_config_path = test_config
    with OmniServer(model, ["--stage-configs-path", stage_config_path,
                            "--stage-init-timeout", "90"]) as server:
        messages = dummy_messages_from_mix_data(
            system_prompt=get_system_prompt(), content_text=get_prompt())
        # Test single completion
        api_client = client(server)
        start_time = time.perf_counter()
        chat_completion = api_client.chat.completions.create(
            model=server.model, messages=messages, max_tokens=20,
            modalities=["text"])
        # Verify E2E
        print(f"the request e2e is: {time.perf_counter() - start_time}")
        # TODO: Verify the E2E latency once the baseline is confirmed.
        # Verify only text is output
        assert len(chat_completion.choices) == 1, \
            "The generated content includes more than just text."
        # Verify text output success
        text_choice = chat_completion.choices[0]
        assert text_choice.message.content is not None, "No text output is generated"
        assert chat_completion.usage.completion_tokens <= 20, \
            "The output length exceeds the requested max_tokens."
        assert "beijing" in text_choice.message.content.lower(), \
            "The output does not contain the expected keywords."
@pytest.mark.parametrize("test_config", test_params)
def test_audio_to_text_001(test_config: tuple[str, str]) -> None:
    """Test processing audio input, generating text output via OpenAI API."""
    model, stage_config_path = test_config
    deploy_config = get_deploy_config()
    # Force the thinker stage to ignore EOS so the output length is deterministic.
    deploy_config["stage_args"][0]["default_sampling_params.ignore_eos"] = True
    stage_config_path = modify_stage_config(stage_config_path, deploy_config)
    with OmniServer(model, ["--stage-configs-path", stage_config_path,
                            "--stage-init-timeout", "90"]) as server:
        audio_data_url = f"data:audio/wav;base64,{generate_synthetic_audio(1, 1)['base64']}"
        messages = dummy_messages_from_mix_data(audio_data_url=audio_data_url)
        # Test single completion
        api_client = client(server)
        start_time = time.perf_counter()
        chat_completion = api_client.chat.completions.create(
            model=server.model, messages=messages, max_tokens=200,
            modalities=["text"])
        # Verify only text is output
        assert len(chat_completion.choices) == 1, \
            "The generated content includes more than just text."
        # Verify text output success
        text_choice = chat_completion.choices[0]
        assert text_choice.message.content is not None, "No text output is generated"
        assert chat_completion.usage.completion_tokens == 200, (
            "The output length differs from the requested max_tokens."
        )
        # Verify E2E
        print(f"the request e2e is: {time.perf_counter() - start_time}")
        # TODO: Verify the E2E latency once the baseline is confirmed.
@pytest.mark.parametrize("test_config", test_params)
def test_audio_to_text_audio_001(test_config: tuple[str, str]) -> None:
    """Test processing audio input, generating text and audio output via OpenAI API."""
    model, stage_config_path = test_config
    num_concurrent_requests = get_max_batch_size()
    stage_config_path = modify_stage_config(
        stage_config_path,
        {
            "stage_args": {
                0: {"runtime.max_batch_size": num_concurrent_requests},
                1: {"runtime.max_batch_size": num_concurrent_requests},
            }
        },
    )
    with OmniServer(model, ["--stage-configs-path", stage_config_path,
                            "--stage-init-timeout", "90"]) as server:
        audio_data_url = []
        for _ in range(5):
            audio_data_url.append(
                f"data:audio/wav;base64,{generate_synthetic_audio(1, 5)['base64']}")
        messages = dummy_messages_from_mix_data(audio_data_url=audio_data_url)
        api_client = client(server)
        e2e_list = list()
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=num_concurrent_requests) as executor:
            # Submit multiple completion requests concurrently
            futures = [
                executor.submit(api_client.chat.completions.create,
                                model=server.model,
                                messages=messages)
                for _ in range(num_concurrent_requests)
            ]
            start_time = time.perf_counter()
            # Wait for all requests to complete and collect results
            chat_completions = list()
            for future in concurrent.futures.as_completed(futures):
                chat_completions.append(future.result())
            # Verify E2E
            current_e2e = time.perf_counter() - start_time
            print(f"the request e2e is: {current_e2e}")
            # TODO: Verify the E2E latency once the baseline is confirmed.
            e2e_list.append(current_e2e)
        print(f"the avg e2e is: {sum(e2e_list) / len(e2e_list)}")
        # Verify all completions succeeded
        assert len(chat_completions) == num_concurrent_requests, "Not all requests succeeded."
        for chat_completion in chat_completions:
            # Verify audio output success
            audio_message = chat_completion.choices[1].message
            audio_data = audio_message.audio.data
            assert audio_data is not None, "No audio output is generated"
            assert audio_message.audio.expires_at > time.time(), \
                "The generated audio has expired."
            # Verify text output success
            text_choice = chat_completion.choices[0]
            text_content = text_choice.message.content
            assert text_choice.message.content is not None, "No text output is generated"
            # Verify text output same as audio output
            audio_content = convert_audio_to_text(audio_data)
            print(f"text content is: {text_content}")
            print(f"audio content is: {audio_content}")
            similarity = cosine_similarity_text(audio_content, text_content)
            print(f"similarity between audio and text is: {similarity}")
            assert similarity > 0.9, "The audio content is not the same as the text"
@pytest.mark.parametrize("test_config", test_params)
def test_image_to_text_001(test_config: tuple[str, str]) -> None:
    """Test processing image input, generating text output via OpenAI API."""
    model, stage_config_path = test_config
    deploy_config = get_deploy_config()
    stage_config_path = modify_stage_config(stage_config_path, deploy_config)
    with OmniServer(model, ["--stage-configs-path", stage_config_path,
                            "--stage-init-timeout", "90"]) as server:
        image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}"
        messages = dummy_messages_from_mix_data(image_data_url=image_data_url)
        # Test single completion
        api_client = client(server)
        start_time = time.perf_counter()
        chat_completion = api_client.chat.completions.create(
            model=server.model, messages=messages, max_tokens=100,
            modalities=["text"])
        # Verify E2E
        print(f"the request e2e is: {time.perf_counter() - start_time}")
        # TODO: Verify the E2E latency once the baseline is confirmed.
        # Verify only text is output
        assert len(chat_completion.choices) == 1, \
            "The generated content includes more than just text."
        # Verify text output success
        text_choice = chat_completion.choices[0]
        text_content = text_choice.message.content
        assert text_content is not None, "No text output is generated"
        assert chat_completion.usage.completion_tokens <= 100, \
            "The output length exceeds the requested max_tokens."
        assert "square" in text_content.lower(), \
            "The output does not contain the expected keywords."
@pytest.mark.parametrize("test_config", test_params)
def test_image_to_text_audio_001(test_config: tuple[str, str]) -> None:
    """Test processing image input, generating text and audio output via OpenAI API."""
    model, stage_config_path = test_config
    num_concurrent_requests = 5
    stage_config_path = modify_stage_config(
        stage_config_path,
        {
            "stage_args": {
                0: {"runtime.max_batch_size": num_concurrent_requests},
                1: {"runtime.max_batch_size": num_concurrent_requests},
            }
        },
    )
    with OmniServer(model, ["--stage-configs-path", stage_config_path,
                            "--stage-init-timeout", "90"]) as server:
        image_data_url = []
        for _ in range(4):
            image_data_url.append(
                f"data:image/jpeg;base64,{generate_synthetic_image(1280, 720)['base64']}")
        messages = dummy_messages_from_mix_data(image_data_url=image_data_url)
        api_client = client(server)
        e2e_list = list()
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=num_concurrent_requests) as executor:
            # Submit multiple completion requests concurrently
            futures = [
                executor.submit(api_client.chat.completions.create,
                                model=server.model,
                                messages=messages)
                for _ in range(num_concurrent_requests)
            ]
            start_time = time.perf_counter()
            # Wait for all requests to complete and collect results
            chat_completions = list()
            for future in concurrent.futures.as_completed(futures):
                chat_completions.append(future.result())
            # Verify E2E
            current_e2e = time.perf_counter() - start_time
            print(f"the request e2e is: {current_e2e}")
            # TODO: Verify the E2E latency once the baseline is confirmed.
            e2e_list.append(current_e2e)
        print(f"the avg e2e is: {sum(e2e_list) / len(e2e_list)}")
        # Verify all completions succeeded
        assert len(chat_completions) == num_concurrent_requests, "Not all requests succeeded."
        for chat_completion in chat_completions:
            # Verify audio output success
            audio_message = chat_completion.choices[1].message
            audio_data = audio_message.audio.data
            assert audio_data is not None, "No audio output is generated"
            assert audio_message.audio.expires_at > time.time(), \
                "The generated audio has expired."
            # Verify text output success
            text_choice = chat_completion.choices[0]
            text_content = text_choice.message.content
            assert text_content is not None, "No text output is generated"
            assert "square" in text_content.lower(), \
                "The output does not contain the expected keywords."
            # Verify text output same as audio output
            audio_content = convert_audio_to_text(audio_data)
            print(f"text content is: {text_content}")
            print(f"audio content is: {audio_content}")
            similarity = cosine_similarity_text(audio_content, text_content)
            print(f"similarity between audio and text is: {similarity}")
            assert similarity > 0.9, "The audio content is not the same as the text"
tests/e2e/stage_configs/qwen3_omni_ci.yaml
0 → 100644
View file @
c1cacde6
# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
# Stage 0: Thinker (multimodal understanding + text generation)
# Stage 1: Talker (text embeddings → 16-layer RVQ codec codes)
# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
# The following config has been verified on 2x H100-80G GPUs.
stage_args:
  - stage_id: 0
    runtime:
      devices: "0"
      max_batch_size: 5
    engine_args:
      model_stage: thinker
      model_arch: Qwen3OmniMoeForConditionalGeneration
      worker_type: ar
      scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
      gpu_memory_utilization: 0.9
      enforce_eager: false
      trust_remote_code: true
      engine_output_type: latent  # Output hidden states for talker
      distributed_executor_backend: "mp"
      max_num_batched_tokens: 32768
      max_model_len: 32768
      enable_prefix_caching: false
      hf_config_name: thinker_config
      tensor_parallel_size: 1
      final_output: true
      final_output_type: text
      is_comprehension: true
    default_sampling_params:
      temperature: 0.4
      top_p: 0.9
      top_k: 1
      max_tokens: 100
      seed: 42
      ignore_eos: False
      detokenize: True
      repetition_penalty: 1.05
  - stage_id: 1
    runtime:
      devices: "1"
      max_batch_size: 5
    engine_args:
      model_stage: talker
      model_arch: Qwen3OmniMoeForConditionalGeneration
      worker_type: ar
      scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
      gpu_memory_utilization: 0.6
      enforce_eager: false
      trust_remote_code: true
      engine_output_type: latent  # Output codec codes for code2wav
      enable_prefix_caching: false
      max_num_batched_tokens: 32768
      max_model_len: 32768
      distributed_executor_backend: "mp"
      hf_config_name: talker_config
      engine_input_source: [0]
      custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
    default_sampling_params:
      temperature: 0.9
      top_k: 50
      max_tokens: 1000
      seed: 42
      detokenize: False
      repetition_penalty: 1.05
      stop_token_ids: [2150]
  - stage_id: 2
    runtime:
      devices: "1"
      max_batch_size: 1
    engine_args:
      model_stage: code2wav
      model_arch: Qwen3OmniMoeForConditionalGeneration
      worker_type: generation
      scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
      enforce_eager: true
      trust_remote_code: true
      enable_prefix_caching: false
      engine_output_type: audio  # Final output: audio waveform
      gpu_memory_utilization: 0.1
      distributed_executor_backend: "mp"
      max_num_batched_tokens: 100000
      hf_config_name: thinker_config
      async_scheduling: false
      engine_input_source: [1]
      custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
      final_output: true
      final_output_type: audio
    default_sampling_params:
      temperature: 0.0
      top_p: 1.0
      top_k: -1
      max_tokens: 2000
      seed: 42
      detokenize: True
      repetition_penalty: 1.1
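A quick way to sanity-check a stage config like the one above is to load it and print the pipeline topology. The sketch below uses plain PyYAML and only keys that appear in this file; it is not a vllm-omni API:

import yaml

with open("tests/e2e/stage_configs/qwen3_omni_ci.yaml") as f:
    cfg = yaml.safe_load(f)

# Walk the thinker -> talker -> code2wav pipeline and show GPU placement.
for stage in cfg["stage_args"]:
    engine = stage["engine_args"]
    print(f"stage {stage['stage_id']}: {engine['model_stage']} "
          f"on GPU(s) {stage['runtime']['devices']}, "
          f"emits {engine['engine_output_type']}")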
tests/entrypoints/openai_api/__init__.py
0 → 100644
View file @
c1cacde6
tests/entrypoints/openai_api/test_image_server.py
0 → 100644
View file @
c1cacde6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for async image generation API endpoints.

This module contains unit tests and integration tests (with mocking) for the
OpenAI-compatible async text-to-image generation API endpoints in api_server.py.
"""
import base64
import io
from argparse import Namespace
from unittest.mock import AsyncMock, Mock

import pytest
from fastapi.testclient import TestClient
from PIL import Image
from vllm import SamplingParams

from vllm_omni.entrypoints.openai.image_api_utils import (
    encode_image_base64,
    parse_size,
)
from vllm_omni.inputs.data import OmniDiffusionSamplingParams

# Unit Tests


def test_parse_size_valid():
    """Test size parsing with valid inputs"""
    assert parse_size("1024x1024") == (1024, 1024)
    assert parse_size("512x768") == (512, 768)
    assert parse_size("256x256") == (256, 256)
    assert parse_size("1792x1024") == (1792, 1024)
    assert parse_size("1024x1792") == (1024, 1792)


def test_parse_size_invalid():
    """Test size parsing with invalid inputs"""
    with pytest.raises(ValueError, match="Invalid size format"):
        parse_size("invalid")
    with pytest.raises(ValueError, match="Invalid size format"):
        parse_size("1024")
    with pytest.raises(ValueError, match="Invalid size format"):
        parse_size("1024x")
    with pytest.raises(ValueError, match="Invalid size format"):
        parse_size("x1024")


def test_parse_size_negative():
    """Test size parsing with negative or zero dimensions"""
    with pytest.raises(ValueError, match="positive integers"):
        parse_size("0x1024")
    with pytest.raises(ValueError, match="positive integers"):
        parse_size("1024x0")
    with pytest.raises(ValueError):
        parse_size("-1024x1024")


def test_parse_size_edge_cases():
    """Test size parsing with edge cases like empty strings and non-integers"""
    # Empty string
    with pytest.raises(ValueError, match="non-empty string"):
        parse_size("")

    # Non-integer dimensions
    with pytest.raises(ValueError, match="must be integers"):
        parse_size("abc x def")
    with pytest.raises(ValueError, match="must be integers"):
        parse_size("1024.5x768.5")

    # Missing separator (user might forget 'x')
    with pytest.raises(ValueError, match="separator"):
        parse_size("1024 1024")


def test_encode_image_base64():
    """Test image encoding to base64"""
    # Create a simple test image
    img = Image.new("RGB", (64, 64), color="red")
    b64_str = encode_image_base64(img)

    # Should be valid base64
    assert isinstance(b64_str, str)
    assert len(b64_str) > 0

    # Should decode back to PNG
    decoded = base64.b64decode(b64_str)
    decoded_img = Image.open(io.BytesIO(decoded))

    # Verify properties
    assert decoded_img.size == (64, 64)
    assert decoded_img.format == "PNG"
# Integration Tests (with mocking)


class MockGenerationResult:
    """Mock result object from AsyncOmniDiffusion.generate()"""

    def __init__(self, images):
        self.images = images


class FakeAsyncOmni:
    """Fake AsyncOmni that yields a single diffusion output."""

    def __init__(self):
        self.stage_list = ["llm", "diffusion"]
        self.default_sampling_params_list = [
            SamplingParams(temperature=0.1),
            OmniDiffusionSamplingParams(),
        ]
        self.captured_sampling_params_list = None
        self.captured_prompt = None

    async def generate(self, prompt, request_id, sampling_params_list):
        self.captured_sampling_params_list = sampling_params_list
        self.captured_prompt = prompt
        images = [Image.new("RGB", (64, 64), color="green")]
        yield MockGenerationResult(images)


@pytest.fixture
def mock_async_diffusion():
    """Mock AsyncOmniDiffusion instance that returns fake images"""
    mock = Mock()
    mock.is_running = True  # For health endpoint
    mock.check_health = AsyncMock()  # For LLM mode health check

    async def generate(**kwargs):
        # Capture the call arguments and return n PIL images wrapped in a
        # result object.
        n = kwargs["sampling_params_list"][0].num_outputs_per_prompt
        mock.captured_sampling_params_list = kwargs["sampling_params_list"]
        mock.captured_prompt = kwargs["prompt"]
        images = [Image.new("RGB", (64, 64), color="blue") for _ in range(n)]
        return MockGenerationResult(images)

    mock.generate = AsyncMock(side_effect=generate)
    return mock
@pytest.fixture
def test_client(mock_async_diffusion):
    """Create test client with mocked async diffusion engine"""
    from fastapi import FastAPI

    from vllm_omni.entrypoints.openai.api_server import router

    app = FastAPI()
    app.include_router(router)

    # Set up app state with diffusion engine
    app.state.engine_client = mock_async_diffusion
    app.state.diffusion_engine = mock_async_diffusion  # Also set for health endpoint
    app.state.stage_configs = [{"stage_type": "diffusion"}]
    app.state.diffusion_model_name = "Qwen/Qwen-Image"  # For models endpoint
    app.state.args = Namespace(
        default_sampling_params='{"0": {"num_inference_steps":4, "guidance_scale":7.5}}',
        max_generated_image_size=4096,  # 64*64
    )
    return TestClient(app)


@pytest.fixture
def async_omni_test_client():
    """Create test client with mocked AsyncOmni engine."""
    from fastapi import FastAPI

    from vllm_omni.entrypoints.openai.api_server import router

    app = FastAPI()
    app.include_router(router)
    app.state.engine_client = FakeAsyncOmni()
    app.state.stage_configs = [{"stage_type": "llm"}, {"stage_type": "diffusion"}]
    app.state.args = Namespace(
        default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5}}',
        max_generated_image_size=4096,  # 64*64
    )
    return TestClient(app)
def test_health_endpoint(test_client):
    """Test health check endpoint for diffusion mode"""
    response = test_client.get("/health")
    assert response.status_code == 200
    data = response.json()
    assert data["status"] == "healthy"


def test_health_endpoint_no_engine():
    """Test health check endpoint when no engine is initialized"""
    from fastapi import FastAPI

    from vllm_omni.entrypoints.openai.api_server import router

    app = FastAPI()
    app.include_router(router)
    # Don't set any engine
    client = TestClient(app)
    response = client.get("/health")
    assert response.status_code == 503
    data = response.json()
    assert data["status"] == "unhealthy"


def test_models_endpoint(test_client):
    """Test /v1/models endpoint for diffusion mode"""
    response = test_client.get("/v1/models")
    assert response.status_code == 200
    data = response.json()
    assert data["object"] == "list"
    assert len(data["data"]) == 1
    assert data["data"][0]["id"] == "Qwen/Qwen-Image"
    assert data["data"][0]["object"] == "model"


def test_models_endpoint_no_engine():
    """Test /v1/models endpoint when no engine is initialized"""
    from fastapi import FastAPI

    from vllm_omni.entrypoints.openai.api_server import router

    app = FastAPI()
    app.include_router(router)
    # Don't set any engine
    client = TestClient(app)
    response = client.get("/v1/models")
    assert response.status_code == 200
    data = response.json()
    assert data["object"] == "list"
    assert len(data["data"]) == 0
def test_generate_single_image(test_client):
    """Test generating a single image"""
    response = test_client.post(
        "/v1/images/generations",
        json={
            "prompt": "a cat",
            "n": 1,
            "size": "1024x1024",
        },
    )
    assert response.status_code == 200
    data = response.json()

    # Check response structure
    assert "created" in data
    assert isinstance(data["created"], int)
    assert "data" in data
    assert len(data["data"]) == 1
    assert "b64_json" in data["data"][0]

    # Verify image can be decoded
    img_bytes = base64.b64decode(data["data"][0]["b64_json"])
    img = Image.open(io.BytesIO(img_bytes))
    assert img.size == (64, 64)  # Our mock returns 64x64 images


def test_generate_images_async_omni_sampling_params(async_omni_test_client):
    """Test AsyncOmni path uses per-stage sampling params."""
    response = async_omni_test_client.post(
        "/v1/images/generations",
        json={
            "prompt": "a cat",
            "n": 2,
            "size": "256x256",
            "seed": 7,
        },
    )
    assert response.status_code == 200
    engine = async_omni_test_client.app.state.engine_client
    captured = engine.captured_sampling_params_list
    assert captured is not None
    assert len(captured) == 2
    assert captured[0].temperature == 0.1
    assert captured[1].num_outputs_per_prompt == 2
    assert captured[1].height == 256
    assert captured[1].width == 256
    assert captured[1].seed == 7


def test_generate_multiple_images(test_client):
    """Test generating multiple images"""
    response = test_client.post(
        "/v1/images/generations",
        json={
            "prompt": "a dog",
            "n": 3,
            "size": "512x512",
        },
    )
    assert response.status_code == 200
    data = response.json()
    assert len(data["data"]) == 3

    # All images should be valid
    for img_data in data["data"]:
        assert "b64_json" in img_data
        img_bytes = base64.b64decode(img_data["b64_json"])
        img = Image.open(io.BytesIO(img_bytes))
        assert img.format == "PNG"


def test_with_negative_prompt(test_client):
    """Test with negative prompt"""
    response = test_client.post(
        "/v1/images/generations",
        json={
            "prompt": "beautiful landscape",
            "negative_prompt": "blurry, low quality",
            "size": "1024x1024",
        },
    )
    assert response.status_code == 200


def test_with_seed(test_client):
    """Test with seed for reproducibility"""
    response = test_client.post(
        "/v1/images/generations",
        json={
            "prompt": "a tree",
            "seed": 42,
            "size": "1024x1024",
        },
    )
    assert response.status_code == 200


def test_with_custom_parameters(test_client):
    """Test with custom diffusion parameters"""
    response = test_client.post(
        "/v1/images/generations",
        json={
            "prompt": "a mountain",
            "size": "1024x1024",
            "num_inference_steps": 100,
            "true_cfg_scale": 5.5,
            "seed": 123,
        },
    )
    assert response.status_code == 200


def test_invalid_size(test_client):
    """Test with invalid size parameter - rejected by Pydantic"""
    response = test_client.post(
        "/v1/images/generations",
        json={
            "prompt": "a cat",
            "size": "invalid",
        },
    )
    # Pydantic validation errors return 422 (Unprocessable Entity);
    # "invalid" has no "x" so Pydantic rejects it
    assert response.status_code == 422
    # Check error detail contains size validation message
    detail = str(response.json()["detail"])
    assert "size" in detail.lower() or "invalid" in detail.lower()


def test_invalid_size_parse_error(test_client):
    """Test with malformed size - passes Pydantic but fails parse_size()"""
    response = test_client.post(
        "/v1/images/generations",
        json={
            "prompt": "a cat",
            "size": "1024x",  # Has "x" so Pydantic accepts, but parse_size() rejects
        },
    )
    # parse_size() raises ValueError -> endpoint converts to 400 (Bad Request)
    assert response.status_code == 400
    detail = str(response.json()["detail"])
    assert "size" in detail.lower() or "invalid" in detail.lower()


def test_missing_prompt(test_client):
    """Test with missing required prompt field"""
    response = test_client.post(
        "/v1/images/generations",
        json={
            "size": "1024x1024",
        },
    )
    # Pydantic validation error
    assert response.status_code == 422


def test_invalid_n_parameter(test_client):
    """Test with invalid n parameter (out of range)"""
    # n < 1
    response = test_client.post(
        "/v1/images/generations",
        json={
            "prompt": "a cat",
            "n": 0,
        },
    )
    assert response.status_code == 422

    # n > 10
    response = test_client.post(
        "/v1/images/generations",
        json={
            "prompt": "a cat",
            "n": 11,
        },
    )
    assert response.status_code == 422


def test_url_response_format_not_supported(test_client):
    """Test that URL format returns error"""
    response = test_client.post(
        "/v1/images/generations",
        json={
            "prompt": "a cat",
            "response_format": "url",
        },
    )
    # Pydantic validation errors return 422 (Unprocessable Entity)
    assert response.status_code == 422
    # Check error mentions response_format or b64_json
    detail = str(response.json()["detail"])
    assert "b64_json" in detail.lower() or "response" in detail.lower()


def test_model_not_loaded():
    """Test error when diffusion engine is not initialized"""
    from fastapi import FastAPI

    from vllm_omni.entrypoints.openai.api_server import router

    app = FastAPI()
    app.include_router(router)
    # Don't set diffusion_engine to simulate uninitialized state
    app.state.diffusion_engine = None
    client = TestClient(app)
    response = client.post(
        "/v1/images/generations",
        json={
            "prompt": "a cat",
        },
    )
    assert response.status_code == 503
    assert "not initialized" in response.json()["detail"].lower()


def test_different_image_sizes(test_client):
    """Test various valid image sizes"""
    sizes = ["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]
    for size in sizes:
        response = test_client.post(
            "/v1/images/generations",
            json={
                "prompt": "a test image",
                "size": size,
            },
        )
        assert response.status_code == 200, f"Failed for size {size}"


def test_parameter_validation():
    """Test Pydantic model validation"""
    from vllm_omni.entrypoints.openai.protocol.images import ImageGenerationRequest

    # Valid request - optional parameters default to None
    req = ImageGenerationRequest(prompt="test")
    assert req.prompt == "test"
    assert req.n == 1
    assert req.model is None
    assert req.size is None  # Engine will use model defaults
    assert req.num_inference_steps is None  # Engine will use model defaults
    assert req.true_cfg_scale is None  # Engine will use model defaults

    # Invalid num_inference_steps (out of range)
    with pytest.raises(ValueError):
        ImageGenerationRequest(prompt="test", num_inference_steps=0)
    with pytest.raises(ValueError):
        ImageGenerationRequest(prompt="test", num_inference_steps=201)

    # Invalid guidance_scale (out of range)
    with pytest.raises(ValueError):
        ImageGenerationRequest(prompt="test", guidance_scale=-1.0)
    with pytest.raises(ValueError):
        ImageGenerationRequest(prompt="test", guidance_scale=21.0)
# Pass-Through Tests


def test_parameters_passed_through(test_client, mock_async_diffusion):
    """Verify all parameters passed through without modification"""
    response = test_client.post(
        "/v1/images/generations",
        json={
            "prompt": "test",
            "num_inference_steps": 100,
            "guidance_scale": 7.5,
            "true_cfg_scale": 3.0,
            "seed": 42,
        },
    )
    assert response.status_code == 200

    # Ensure generate() was called exactly once
    mock_async_diffusion.generate.assert_awaited_once()
    call_kwargs = mock_async_diffusion.generate.call_args[1]["sampling_params_list"][0]
    assert call_kwargs.num_inference_steps == 100
    assert call_kwargs.guidance_scale == 7.5
    assert call_kwargs.true_cfg_scale == 3.0
    assert call_kwargs.seed == 42


def test_model_field_omitted_works(test_client):
    """Test that omitting model field works"""
    response = test_client.post(
        "/v1/images/generations",
        json={
            "prompt": "test",
            "size": "1024x1024",
            # model field omitted
        },
    )
    assert response.status_code == 200


def make_test_image_bytes(size=(64, 64)) -> bytes:
    img = Image.new("RGB", size)
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return buf.getvalue()
def test_image_edit_images_processing(async_omni_test_client):
    img_bytes_1 = make_test_image_bytes((16, 16))
    img_bytes_2 = make_test_image_bytes((32, 32))

    # uploadfile with image key
    response = async_omni_test_client.post(
        "/v1/images/edits",
        files=[
            ("image", img_bytes_1),
            ("image", img_bytes_2),
        ],
        data={"prompt": "hello world."},
    )
    assert response.status_code == 200
    engine = async_omni_test_client.app.state.engine_client
    captured_prompt = engine.captured_prompt
    processed_images = captured_prompt["multi_modal_data"]["image"]
    assert len(processed_images) == 2
    assert isinstance(processed_images[0], Image.Image)
    assert isinstance(processed_images[1], Image.Image)
    assert processed_images[0].size == (16, 16)
    assert processed_images[1].size == (32, 32)

    # uploadfile with image[] key
    response = async_omni_test_client.post(
        "/v1/images/edits",
        files=[
            ("image[]", img_bytes_2),
            ("image[]", img_bytes_1),
        ],
        data={"prompt": "hello world."},
    )
    assert response.status_code == 200
    engine = async_omni_test_client.app.state.engine_client
    captured_prompt = engine.captured_prompt
    processed_images = captured_prompt["multi_modal_data"]["image"]
    assert len(processed_images) == 2
    assert isinstance(processed_images[0], Image.Image)
    assert isinstance(processed_images[1], Image.Image)
    assert processed_images[0].size == (32, 32)
    assert processed_images[1].size == (16, 16)

    # base64 url
    buf1 = io.BytesIO()
    img1 = Image.new("RGB", (16, 16))
    img1.save(buf1, format="PNG")
    b64_1 = "data:image/png;base64," + base64.b64encode(buf1.getvalue()).decode()
    buf2 = io.BytesIO()
    img2 = Image.new("RGB", (24, 24))
    img2.save(buf2, format="PNG")
    b64_2 = "data:image/png;base64," + base64.b64encode(buf2.getvalue()).decode()
    response = async_omni_test_client.post(
        "/v1/images/edits",
        data={
            "prompt": "hello from base64",
            "url": [b64_1, b64_2],
        },
    )
    assert response.status_code == 200
    processed_images = engine.captured_prompt["multi_modal_data"]["image"]
    assert len(processed_images) == 2
    assert isinstance(processed_images[0], Image.Image)
    assert isinstance(processed_images[1], Image.Image)
    assert processed_images[0].size == (16, 16)
    assert processed_images[1].size == (24, 24)


def test_image_edit_parameter_pass(async_omni_test_client):
    img_bytes_1 = make_test_image_bytes((16, 16))
    # uploadfile with image key
    response = async_omni_test_client.post(
        "/v1/images/edits",
        files=[("image", img_bytes_1)],
        data={
            "prompt": "hello world.",
            "size": "16x24",
            "output_format": "jpeg",
            "num_inference_steps": 20,
            "guidance_scale": 8.0,
            "seed": 1234,
            "negative_prompt": "negative",
            "n": 2,
        },
    )
    assert response.status_code == 200
    engine = async_omni_test_client.app.state.engine_client
    captured_prompt = engine.captured_prompt
    captured_sampling_params = engine.captured_sampling_params_list[-1]
    assert captured_prompt["prompt"] == "hello world."
    assert captured_prompt["negative_prompt"] == "negative"
    assert captured_sampling_params.num_inference_steps == 20
    assert captured_sampling_params.guidance_scale == 8.0
    assert captured_sampling_params.seed == 1234
    assert captured_sampling_params.num_outputs_per_prompt == 2
    assert captured_sampling_params.width == 16
    assert captured_sampling_params.height == 24
    data = response.json()
    # All images should be valid
    for img_data in data["data"]:
        assert "b64_json" in img_data
        img_bytes = base64.b64decode(img_data["b64_json"])
        img = Image.open(io.BytesIO(img_bytes))
        assert img.format.lower() == "jpeg"
    assert data["output_format"] == "jpeg"
    assert data["size"] == "16x24"


def test_image_edit_parameter_default(async_omni_test_client):
    img_bytes_1 = make_test_image_bytes((24, 16))
    # uploadfile with image key
    response = async_omni_test_client.post(
        "/v1/images/edits",
        files=[("image", img_bytes_1)],
        data={
            "prompt": "hello world.",
            "size": "auto",
        },
    )
    assert response.status_code == 200
    engine = async_omni_test_client.app.state.engine_client
    captured_sampling_params = engine.captured_sampling_params_list[-1]
    assert captured_sampling_params.width == 24
    assert captured_sampling_params.height == 16
    assert captured_sampling_params.num_outputs_per_prompt == 1
    assert captured_sampling_params.num_inference_steps == 4
    assert captured_sampling_params.guidance_scale == 7.5

    response = async_omni_test_client.post(
        "/v1/images/edits",
        files=[("image", img_bytes_1)],
        data={
            "prompt": "hello world.",
            "size": "96x96",
        },
    )
    assert response.status_code == 400


def test_image_edit_parameter_default_single_stage(test_client):
    img_bytes_1 = make_test_image_bytes((24, 16))
    # uploadfile with image key
    response = test_client.post(
        "/v1/images/edits",
        files=[("image", img_bytes_1)],
        data={
            "prompt": "hello world.",
        },
    )
    assert response.status_code == 200
    engine = test_client.app.state.engine_client
    captured_sampling_params = engine.captured_sampling_params_list[0]
    assert captured_sampling_params.width == 24
    assert captured_sampling_params.height == 16
    assert captured_sampling_params.num_outputs_per_prompt == 1
    assert captured_sampling_params.num_inference_steps == 4
    assert captured_sampling_params.guidance_scale == 7.5

    response = test_client.post(
        "/v1/images/edits",
        files=[("image", img_bytes_1)],
        data={
            "prompt": "hello world.",
            "size": "96x96",
        },
    )
    assert response.status_code == 400


def test_image_edit_compression_jpeg(test_client):
    img_bytes_1 = make_test_image_bytes((16, 16))
    # uploadfile with image key
    response = test_client.post(
        "/v1/images/edits",
        files=[("image", img_bytes_1)],
        data={
            "prompt": "hello world.",
            "output_format": "jpeg",
            "output_compression": 100,
        },
    )
    assert response.status_code == 200
    data = response.json()
    img_bytes_100 = base64.b64decode(data["data"][0]["b64_json"])
    img = Image.open(io.BytesIO(img_bytes_100))
    assert img.format.lower() == "jpeg"

    response = test_client.post(
        "/v1/images/edits",
        files=[("image", img_bytes_1)],
        data={
            "prompt": "hello world.",
            "output_format": "jpeg",
            "output_compression": 50,
        },
    )
    assert response.status_code == 200
    data = response.json()
    img_bytes_50 = base64.b64decode(data["data"][0]["b64_json"])

    response = test_client.post(
        "/v1/images/edits",
        files=[("image", img_bytes_1)],
        data={
            "prompt": "hello world.",
            "output_format": "jpeg",
            "output_compression": 10,
        },
    )
    assert response.status_code == 200
    data = response.json()
    img_bytes_10 = base64.b64decode(data["data"][0]["b64_json"])

    assert len(img_bytes_10) < len(img_bytes_50)
    assert len(img_bytes_50) < len(img_bytes_100)


def test_image_edit_compression_png(async_omni_test_client):
    img_bytes_1 = make_test_image_bytes((16, 16))
    # uploadfile with image key
    response = async_omni_test_client.post(
        "/v1/images/edits",
        files=[("image", img_bytes_1)],
        data={
            "prompt": "hello world.",
            "output_format": "PNG",
            "output_compression": 100,
        },
    )
    assert response.status_code == 200
    data = response.json()
    img_bytes_100 = base64.b64decode(data["data"][0]["b64_json"])
    img = Image.open(io.BytesIO(img_bytes_100))
    assert img.format.lower() == "png"

    response = async_omni_test_client.post(
        "/v1/images/edits",
        files=[("image", img_bytes_1)],
        data={
            "prompt": "hello world.",
            "output_format": "PNG",
            "output_compression": 50,
        },
    )
    assert response.status_code == 200
    data = response.json()
    img_bytes_50 = base64.b64decode(data["data"][0]["b64_json"])

    response = async_omni_test_client.post(
        "/v1/images/edits",
        files=[("image", img_bytes_1)],
        data={
            "prompt": "hello world.",
            "output_format": "PNG",
            "output_compression": 10,
        },
    )
    assert response.status_code == 200
    data = response.json()
    img_bytes_10 = base64.b64decode(data["data"][0]["b64_json"])

    assert len(img_bytes_10) < len(img_bytes_50)
    assert len(img_bytes_50) < len(img_bytes_100)
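Against a live server, the same generation endpoint exercised by these TestClient tests can be hit over HTTP. The sketch below assumes a server already listening on localhost:8000 and uses only the request and response fields asserted above:

import base64

import requests

resp = requests.post(
    "http://localhost:8000/v1/images/generations",
    json={"prompt": "a cat", "n": 1, "size": "1024x1024"},
    timeout=300,
)
resp.raise_for_status()
payload = resp.json()
# Each entry carries the PNG bytes base64-encoded under "b64_json".
with open("cat.png", "wb") as f:
    f.write(base64.b64decode(payload["data"][0]["b64_json"]))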
tests/entrypoints/openai_api/test_serving_chat_sampling_params.py
0 → 100644
View file @
c1cacde6
# SPDX-License-Identifier: Apache-2.0
"""
Unit tests for OmniOpenAIServingChat sampling params handling.

Tests that standard OpenAI API parameters (max_tokens, temperature, etc.)
are correctly applied to the comprehension stage while preserving YAML defaults.
"""
from unittest.mock import MagicMock

import pytest
from vllm.sampling_params import SamplingParams


@pytest.fixture
def mock_comprehension_stage():
    """Create a mock comprehension stage with is_comprehension=True."""
    stage = MagicMock()
    stage.is_comprehension = True
    stage.model_stage = "comprehension"
    return stage


@pytest.fixture
def mock_other_stage():
    """Create a mock non-comprehension stage."""
    stage = MagicMock()
    stage.is_comprehension = False
    stage.model_stage = "other"
    return stage


@pytest.fixture
def default_comprehension_params():
    """Default sampling params for comprehension stage (from YAML)."""
    return SamplingParams(
        temperature=0.4,
        top_p=0.9,
        top_k=1,
        max_tokens=2048,
        seed=42,
        repetition_penalty=1.05,
    )


@pytest.fixture
def default_other_params():
    """Default sampling params for non-comprehension stage (from YAML)."""
    return SamplingParams(
        temperature=0.9,
        top_k=50,
        max_tokens=4096,
        seed=42,
    )


@pytest.fixture
def mock_engine_client(mock_comprehension_stage, mock_other_stage,
                       default_comprehension_params, default_other_params):
    """Create mock engine client with stage_list and default_sampling_params_list."""
    engine_client = MagicMock()
    engine_client.stage_list = [mock_comprehension_stage, mock_other_stage]
    engine_client.default_sampling_params_list = [
        default_comprehension_params,
        default_other_params,
    ]
    return engine_client


@pytest.fixture
def serving_chat(mock_engine_client):
    """Create OmniOpenAIServingChat instance with mocked dependencies."""
    from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat

    # Create instance without calling __init__
    instance = object.__new__(OmniOpenAIServingChat)
    instance.engine_client = mock_engine_client
    return instance


@pytest.fixture
def mock_request():
    """Create a mock request with all OpenAI sampling params set to None."""
    request = MagicMock()
    # OpenAI standard sampling fields
    request.temperature = None
    request.top_p = None
    request.max_tokens = None
    request.seed = None
    request.stop = None
    request.frequency_penalty = None
    request.presence_penalty = None
    return request
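The fixtures set every OpenAI field to None so each test can opt fields in one at a time. The override semantics the tests below pin down can be sketched as follows; this is illustrative only (the real logic lives in OmniOpenAIServingChat, and deepcopy stands in for whatever cloning it actually uses):

import copy

_OPENAI_SAMPLING_FIELDS = {
    "temperature", "top_p", "max_tokens", "seed",
    "stop", "frequency_penalty", "presence_penalty",
}


def apply_request_overrides_sketch(default_params, request):
    # Clone the per-stage YAML defaults, then copy over only the
    # allowlisted OpenAI fields the request actually sets; YAML-only
    # params such as top_k and repetition_penalty are left untouched.
    params = copy.deepcopy(default_params)
    for field in _OPENAI_SAMPLING_FIELDS:
        value = getattr(request, field, None)
        if value is not None:
            setattr(params, field, value)
    return params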
# =============================================================================
# Tests for _OPENAI_SAMPLING_FIELDS constant
# =============================================================================


def test_openai_sampling_fields_contains_expected_fields():
    """Test that _OPENAI_SAMPLING_FIELDS contains all expected OpenAI params."""
    from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat

    expected_fields = {
        "temperature",
        "top_p",
        "max_tokens",
        "seed",
        "stop",
        "frequency_penalty",
        "presence_penalty",
    }
    assert OmniOpenAIServingChat._OPENAI_SAMPLING_FIELDS == expected_fields


# =============================================================================
# Tests for _build_sampling_params_list_from_request
# =============================================================================


def test_preserves_yaml_defaults_when_no_request_params(serving_chat, mock_request):
    """Test that YAML defaults are preserved when request has no params."""
    result = serving_chat._build_sampling_params_list_from_request(mock_request)

    assert len(result) == 2
    comprehension_params = result[0]
    assert comprehension_params.temperature == 0.4
    assert comprehension_params.top_p == 0.9
    assert comprehension_params.top_k == 1  # YAML custom param preserved
    assert comprehension_params.max_tokens == 2048
    assert comprehension_params.seed == 42
    assert comprehension_params.repetition_penalty == 1.05  # YAML custom param preserved


def test_request_temperature_overrides_yaml_default(serving_chat, mock_request):
    """Test that request temperature overrides YAML default."""
    mock_request.temperature = 0.8
    result = serving_chat._build_sampling_params_list_from_request(mock_request)

    comprehension_params = result[0]
    assert comprehension_params.temperature == 0.8  # Overridden
    assert comprehension_params.seed == 42  # Preserved from YAML
    assert comprehension_params.top_k == 1  # YAML custom param preserved


def test_request_top_p_overrides_yaml_default(serving_chat, mock_request):
    """Test that request top_p overrides YAML default."""
    mock_request.top_p = 0.95
    result = serving_chat._build_sampling_params_list_from_request(mock_request)

    comprehension_params = result[0]
    assert comprehension_params.top_p == 0.95  # Overridden
    assert comprehension_params.temperature == 0.4  # Preserved from YAML


def test_request_max_tokens_overrides_yaml_default(serving_chat, mock_request):
    """Test that request max_tokens overrides YAML default."""
    mock_request.max_tokens = 100
    result = serving_chat._build_sampling_params_list_from_request(mock_request)
    assert result[0].max_tokens == 100


def test_max_tokens_uses_yaml_default_when_not_specified(serving_chat, mock_request):
    """Test that max_tokens falls back to YAML default when not in request."""
    result = serving_chat._build_sampling_params_list_from_request(mock_request)
    assert result[0].max_tokens == 2048


def test_request_seed_overrides_yaml_default(serving_chat, mock_request):
    """Test that request seed overrides YAML default."""
    mock_request.seed = 123
    result = serving_chat._build_sampling_params_list_from_request(mock_request)

    comprehension_params = result[0]
    assert comprehension_params.seed == 123  # Overridden
    assert comprehension_params.temperature == 0.4  # Preserved from YAML


def test_request_frequency_penalty_overrides(serving_chat, mock_request):
    """Test that request frequency_penalty is applied."""
    mock_request.frequency_penalty = 0.5
    result = serving_chat._build_sampling_params_list_from_request(mock_request)
    assert result[0].frequency_penalty == 0.5


def test_request_presence_penalty_overrides(serving_chat, mock_request):
    """Test that request presence_penalty is applied."""
    mock_request.presence_penalty = 0.3
    result = serving_chat._build_sampling_params_list_from_request(mock_request)
    assert result[0].presence_penalty == 0.3


def test_non_comprehension_stages_use_cloned_defaults(serving_chat, mock_request):
    """Test that non-comprehension stages always use cloned YAML defaults."""
    mock_request.max_tokens = 50
    mock_request.temperature = 0.1
    result = serving_chat._build_sampling_params_list_from_request(mock_request)

    other_params = result[1]
    assert other_params.temperature == 0.9  # YAML default (not affected by request)
    assert other_params.max_tokens == 4096  # YAML default (not affected by request)
    assert other_params.top_k == 50  # YAML default
    assert other_params.seed == 42  # YAML default


def test_multiple_params_override_together(serving_chat, mock_request):
    """Test that multiple request params can override together."""
    mock_request.max_tokens = 200
    mock_request.temperature = 0.7
    mock_request.top_p = 0.85
    mock_request.seed = 999
    result = serving_chat._build_sampling_params_list_from_request(mock_request)

    comprehension_params = result[0]
    # Overridden by request
    assert comprehension_params.temperature == 0.7
    assert comprehension_params.top_p == 0.85
    assert comprehension_params.max_tokens == 200
    assert comprehension_params.seed == 999
    # Preserved from YAML (not in _OPENAI_SAMPLING_FIELDS)
    assert comprehension_params.top_k == 1
    assert comprehension_params.repetition_penalty == 1.05


def test_yaml_custom_params_not_overridden_by_request(serving_chat, mock_request):
    """Test that YAML custom params (top_k, repetition_penalty) are not affected."""
    # Even if request has these attributes, they should not override YAML
    # because they're not in _OPENAI_SAMPLING_FIELDS
    mock_request.top_k = 100  # Not in allowlist
    mock_request.repetition_penalty = 2.0  # Not in allowlist
    result = serving_chat._build_sampling_params_list_from_request(mock_request)

    comprehension_params = result[0]
    assert comprehension_params.top_k == 1  # YAML default preserved
    assert comprehension_params.repetition_penalty == 1.05  # YAML default preserved


# =============================================================================
# Tests for _apply_request_overrides
# =============================================================================


def test_apply_request_overrides_clones_params(serving_chat, mock_request,
                                               default_comprehension_params):
    """Test that _apply_request_overrides returns a cloned object."""
    result = serving_chat._apply_request_overrides(default_comprehension_params,
                                                   mock_request)
    assert result is not default_comprehension_params  # Different object


def test_apply_request_overrides_preserves_defaults(serving_chat, mock_request,
                                                    default_comprehension_params):
    """Test that _apply_request_overrides preserves defaults when request has None."""
    result = serving_chat._apply_request_overrides(default_comprehension_params,
                                                   mock_request)
    assert result.temperature == 0.4
    assert result.top_p == 0.9
    assert result.seed == 42
    assert result.top_k == 1  # YAML custom param


def test_apply_request_overrides_applies_values(serving_chat, mock_request,
                                                default_comprehension_params):
    """Test that _apply_request_overrides applies non-None request values."""
    mock_request.temperature = 0.8
    mock_request.seed = 123
    result = serving_chat._apply_request_overrides(default_comprehension_params,
                                                   mock_request)
    assert result.temperature == 0.8  # Overridden
    assert result.seed == 123  # Overridden
    assert result.top_p == 0.9  # Preserved from default
    assert result.top_k == 1  # YAML custom param preserved


# =============================================================================
# Tests for _get_comprehension_stage_index
# =============================================================================


def test_get_comprehension_stage_index_finds_first_stage(mock_engine_client):
    """Test finding comprehension stage when it's at index 0."""
    from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat

    instance = object.__new__(OmniOpenAIServingChat)
    instance.engine_client = mock_engine_client
    assert instance._get_comprehension_stage_index() == 0


def test_get_comprehension_stage_index_finds_second_stage():
    """Test finding comprehension stage when it's at index 1."""
    from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat

    instance = object.__new__(OmniOpenAIServingChat)
    other = MagicMock()
    other.is_comprehension = False
    comprehension = MagicMock()
    comprehension.is_comprehension = True
    instance.engine_client = MagicMock()
    instance.engine_client.stage_list = [other, comprehension]
    assert instance._get_comprehension_stage_index() == 1


def test_get_comprehension_stage_index_raises_when_not_found():
    """Test that ValueError is raised when no comprehension stage exists."""
    from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat

    instance = object.__new__(OmniOpenAIServingChat)
    stage1 = MagicMock()
    stage1.is_comprehension = False
    stage2 = MagicMock()
    stage2.is_comprehension = False
    instance.engine_client = MagicMock()
    instance.engine_client.stage_list = [stage1, stage2]
    with pytest.raises(ValueError, match="No comprehension stage"):
        instance._get_comprehension_stage_index()
tests/entrypoints/openai_api/test_serving_speech.py
0 → 100644
View file @
c1cacde6
# tests/entrypoints/openai/test_serving_speech.py
import logging
from inspect import Signature, signature
from unittest.mock import MagicMock, patch

import numpy as np
import pytest
import torch
from fastapi import FastAPI
from fastapi.testclient import TestClient

from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin
from vllm_omni.entrypoints.openai.protocol.audio import CreateAudio, OpenAICreateSpeechRequest
from vllm_omni.entrypoints.openai.serving_speech import (
    OmniOpenAIServingSpeech,
)
from vllm_omni.outputs import OmniRequestOutput

logger = logging.getLogger(__name__)


class TestAudioMixin:
    @pytest.fixture
    def audio_mixin(self):
        return AudioMixin()

    def test_stereo_to_mono_conversion(self, audio_mixin):
        stereo_tensor = np.random.rand(24000, 2).astype(np.float32)
        audio_obj = CreateAudio(audio_tensor=stereo_tensor)
        with (
            patch.object(
                audio_mixin,
                "_apply_speed_adjustment",
                side_effect=lambda tensor, speed, sr: (tensor, sr),
            ) as mock_speed,
            patch("soundfile.write") as _,
        ):
            audio_mixin.create_audio(audio_obj)
            # Check that the tensor passed to speed adjustment is mono
            mock_speed.assert_called_once()
            adjusted_tensor = mock_speed.call_args[0][0]
            assert len(adjusted_tensor) == 24000

    @patch("librosa.effects.time_stretch")
    def test_speed_adjustment(self, mock_time_stretch, audio_mixin):
        mock_time_stretch.return_value = np.zeros(12000)
        audio_tensor = np.random.rand(24000).astype(np.float32)
        adjusted_audio, _ = audio_mixin._apply_speed_adjustment(
            audio_tensor, speed=2.0, sample_rate=24000)
        mock_time_stretch.assert_called_with(y=audio_tensor, rate=2.0)
        assert adjusted_audio.shape == (12000,)

    @patch("soundfile.write")
    def test_unsupported_format_fallback(self, mock_write, audio_mixin, caplog):
        audio_tensor = np.random.rand(24000).astype(np.float32)
        # Use a format that is not in the list of supported formats
        audio_obj = CreateAudio(audio_tensor=audio_tensor, response_format="vorbis")
        audio_mixin.create_audio(audio_obj)
        # Should fall back to 'wav'
        mock_write.assert_called_once()
        write_kwargs = mock_write.call_args.kwargs
        assert write_kwargs["format"] == "WAV"

    def test_mono_audio_preservation(self, audio_mixin):
        """Test that mono (1D) audio tensors are processed correctly and passed to writer."""
        mono_tensor = np.random.rand(24000).astype(np.float32)
        audio_obj = CreateAudio(audio_tensor=mono_tensor)
        with patch("soundfile.write") as mock_write:
            audio_mixin.create_audio(audio_obj)
            mock_write.assert_called_once()
            # Verify the tensor passed to soundfile.write is the exact 1D tensor
            output_tensor = mock_write.call_args[0][1]
            assert output_tensor.ndim == 1
            assert output_tensor.shape == (24000,)
            assert np.array_equal(output_tensor, mono_tensor)

    def test_stereo_audio_preservation(self, audio_mixin):
        """Test that stereo (2D) audio tensors are processed correctly and preserved."""
        stereo_tensor = np.random.rand(24000, 2).astype(np.float32)
        audio_obj = CreateAudio(audio_tensor=stereo_tensor)
        with patch("soundfile.write") as mock_write:
            audio_mixin.create_audio(audio_obj)
            mock_write.assert_called_once()
            # Verify the tensor passed to soundfile.write is the exact 2D tensor
            output_tensor = mock_write.call_args[0][1]
            assert output_tensor.ndim == 2
            assert output_tensor.shape == (24000, 2)
            assert np.array_equal(output_tensor, stereo_tensor)

    def test_speed_adjustment_bypass(self, audio_mixin):
        """Test that speed=1.0 bypasses the expensive librosa time stretching."""
        audio_tensor = np.random.rand(24000).astype(np.float32)
        with patch("librosa.effects.time_stretch") as mock_time_stretch:
            # speed=1.0 should return immediately without calling librosa
            result, _ = audio_mixin._apply_speed_adjustment(
                audio_tensor, speed=1.0, sample_rate=24000)
            mock_time_stretch.assert_not_called()
            assert np.array_equal(result, audio_tensor)

    @patch("librosa.effects.time_stretch")
    def test_speed_adjustment_stereo_handling(self, mock_time_stretch, audio_mixin):
        """Test that speed adjustment is attempted on stereo inputs."""
        stereo_tensor = np.random.rand(24000, 2).astype(np.float32)
        # Mock return value representing a sped-up version (half length)
        mock_time_stretch.return_value = np.zeros((12000, 2), dtype=np.float32)
        result, _ = audio_mixin._apply_speed_adjustment(
            stereo_tensor, speed=2.0, sample_rate=24000)
        mock_time_stretch.assert_called_once()
        # Ensure the stereo tensor was passed to librosa
        call_args = mock_time_stretch.call_args
        assert np.array_equal(call_args.kwargs["y"], stereo_tensor)
        assert call_args.kwargs["rate"] == 2.0
        assert result.shape == (12000, 2)
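The bypass test above exists because librosa's phase-vocoder time stretch is comparatively expensive. A minimal sketch of logic consistent with these tests, illustrative only and not the actual AudioMixin implementation:

import librosa
import numpy as np


def apply_speed_adjustment_sketch(audio: np.ndarray, speed: float,
                                  sample_rate: int) -> tuple[np.ndarray, int]:
    # speed == 1.0 short-circuits without touching librosa; anything else
    # goes through librosa.effects.time_stretch at the requested rate.
    if speed == 1.0:
        return audio, sample_rate
    return librosa.effects.time_stretch(y=audio, rate=speed), sample_rate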
# Helper to create mock model output for endpoint tests
def create_mock_audio_output_for_test(
    request_id: str = "speech-mock-123",
) -> OmniRequestOutput:
    class MockCompletionOutput:
        def __init__(self, index: int = 0):
            self.index = index
            self.text = ""
            self.token_ids = []
            self.finish_reason = "stop"
            self.stop_reason = None
            self.logprobs = None

    class MockRequestOutput:
        def __init__(self, request_id: str, audio_tensor: torch.Tensor):
            self.request_id = request_id
            self.outputs = [MockCompletionOutput(index=0)]
            self.multimodal_output = {"audio": audio_tensor}
            self.finished = True
            self.prompt_token_ids = None
            self.encoder_prompt_token_ids = None
            self.num_cached_tokens = None
            self.prompt_logprobs = None
            self.kv_transfer_params = None

    num_samples = 24000
    # One second of a 440 Hz sine tone at a 24 kHz sample rate
    # (440 full cycles spread over 24000 samples).
    audio_tensor = torch.sin(torch.linspace(0, 440 * 2 * torch.pi, num_samples))
    mock_request_output = MockRequestOutput(request_id=request_id,
                                            audio_tensor=audio_tensor)
    return OmniRequestOutput(
        stage_id=0,
        final_output_type="audio",
        request_output=mock_request_output,
    )


def create_mock_audio_output_on_completion_for_test(
    request_id: str = "speech-mock-completion-123",
) -> OmniRequestOutput:
    class MockCompletionOutput:
        def __init__(self, audio_tensor: torch.Tensor, index: int = 0):
            self.index = index
            self.text = ""
            self.token_ids = []
            self.finish_reason = "stop"
            self.stop_reason = None
            self.logprobs = None
            self.multimodal_output = {"audio": audio_tensor, "sr": 24000}

    class MockRequestOutput:
        def __init__(self, request_id: str, audio_tensor: torch.Tensor):
            self.request_id = request_id
            self.outputs = [MockCompletionOutput(audio_tensor=audio_tensor, index=0)]
            self.multimodal_output = {}
            self.finished = True
            self.prompt_token_ids = None
            self.encoder_prompt_token_ids = None
            self.num_cached_tokens = None
            self.prompt_logprobs = None
            self.kv_transfer_params = None

    num_samples = 24000
    audio_tensor = torch.sin(torch.linspace(0, 440 * 2 * torch.pi, num_samples))
    mock_request_output = MockRequestOutput(request_id=request_id,
                                            audio_tensor=audio_tensor)
    return OmniRequestOutput(
        stage_id=0,
        final_output_type="audio",
        request_output=mock_request_output,
    )

@pytest.fixture
def test_app():
    # Mock the engine client
    mock_engine_client = MagicMock()
    mock_engine_client.errored = False

    async def mock_generate_fn(*args, **kwargs):
        yield create_mock_audio_output_for_test(request_id=kwargs.get("request_id"))

    mock_engine_client.generate = MagicMock(side_effect=mock_generate_fn)
    mock_engine_client.default_sampling_params_list = [{}]
    # Mock models to have an is_base_model method
    mock_models = MagicMock()
    mock_models.is_base_model.return_value = True
    mock_request_logger = MagicMock()
    speech_server = OmniOpenAIServingSpeech(
        engine_client=mock_engine_client,
        models=mock_models,
        request_logger=mock_request_logger,
    )
    # Patch the signature of create_speech to remove 'raw_request' for FastAPI route introspection
    original_create_speech = speech_server.create_speech
    _ = MagicMock(side_effect=original_create_speech)
    sig = signature(original_create_speech)
    new_parameters = [param for name, param in sig.parameters.items() if name != "raw_request"]
    new_sig = Signature(parameters=new_parameters, return_annotation=sig.return_annotation)

    async def awaitable_patched_create_speech(*args, **kwargs):
        return await original_create_speech(*args, **kwargs)

    awaitable_patched_create_speech.__signature__ = new_sig
    speech_server.create_speech = awaitable_patched_create_speech

    app = FastAPI()
    app.add_api_route("/v1/audio/speech", speech_server.create_speech, methods=["POST"], response_model=None)

    # Add list_voices endpoint
    async def list_voices():
        speakers = sorted(speech_server.supported_speakers) if speech_server.supported_speakers else []
        return {"voices": speakers}

    app.add_api_route("/v1/audio/voices", list_voices, methods=["GET"])
    return app


@pytest.fixture
def client(test_app):
    return TestClient(test_app)

class TestSpeechAPI:
    def test_create_speech_success(self, client):
        payload = {
            "input": "Hello world",
            "model": "tts-model",
            "voice": "alloy",
            "response_format": "wav",
        }
        response = client.post("/v1/audio/speech", json=payload)
        assert response.status_code == 200
        assert response.headers["content-type"] == "audio/wav"
        assert len(response.content) > 0

    def test_create_speech_mp3_format(self, client):
        payload = {
            "input": "Hello world",
            "model": "tts-model",
            "voice": "alloy",
            "response_format": "mp3",
        }
        response = client.post("/v1/audio/speech", json=payload)
        assert response.status_code == 200
        assert response.headers["content-type"] == "audio/mpeg"
        assert len(response.content) > 0

    def test_create_speech_reads_audio_from_completion_output(self, test_app):
        mock_engine_client = MagicMock()
        mock_engine_client.errored = False

        async def mock_generate_fn(*args, **kwargs):
            yield create_mock_audio_output_on_completion_for_test(request_id=kwargs.get("request_id"))

        mock_engine_client.generate = MagicMock(side_effect=mock_generate_fn)
        mock_engine_client.default_sampling_params_list = [{}]
        mock_models = MagicMock()
        mock_models.is_base_model.return_value = True
        speech_server = OmniOpenAIServingSpeech(
            engine_client=mock_engine_client,
            models=mock_models,
            request_logger=MagicMock(),
        )
        original_create_speech = speech_server.create_speech
        sig = signature(original_create_speech)
        new_parameters = [param for name, param in sig.parameters.items() if name != "raw_request"]
        new_sig = Signature(parameters=new_parameters, return_annotation=sig.return_annotation)

        async def awaitable_patched_create_speech(*args, **kwargs):
            return await original_create_speech(*args, **kwargs)

        awaitable_patched_create_speech.__signature__ = new_sig
        speech_server.create_speech = awaitable_patched_create_speech

        app = FastAPI()
        app.add_api_route("/v1/audio/speech", speech_server.create_speech, methods=["POST"], response_model=None)
        client = TestClient(app)

        payload = {
            "input": "Hello world",
            "model": "tts-model",
            "voice": "alloy",
            "response_format": "wav",
        }
        response = client.post("/v1/audio/speech", json=payload)
        assert response.status_code == 200
        assert response.headers["content-type"] == "audio/wav"
        assert len(response.content) > 0

    def test_create_speech_invalid_format(self, client):
        payload = {
            "input": "Hello world",
            "model": "tts-model",
            "voice": "alloy",
            "response_format": "invalid_format",
        }
        response = client.post("/v1/audio/speech", json=payload)
        assert response.status_code == 422  # Unprocessable Entity

    @patch("vllm_omni.entrypoints.openai.serving_speech.OmniOpenAIServingSpeech.create_audio")
    def test_speed_parameter_is_used(self, mock_create_audio, test_app):
        client = TestClient(test_app)
        mock_audio_response = MagicMock()
        mock_audio_response.audio_data = b"dummy_audio"
        mock_audio_response.media_type = "audio/wav"
        mock_create_audio.return_value = mock_audio_response
        payload = {
            "input": "This should be fast.",
            "model": "tts-model",
            "voice": "alloy",
            "response_format": "wav",
            "speed": 2.5,
        }
        client.post("/v1/audio/speech", json=payload)
        mock_create_audio.assert_called_once()
        call_args = mock_create_audio.call_args[0]
        audio_obj = call_args[0]
        assert isinstance(audio_obj, CreateAudio)
        assert audio_obj.speed == 2.5

    def test_list_voices_endpoint(self, client):
        response = client.get("/v1/audio/voices")
        assert response.status_code == 200
        assert "voices" in response.json()

class TestTTSMethods:
    """Unit tests for TTS validation and parameter building."""

    @pytest.fixture
    def speech_server(self):
        mock_engine_client = MagicMock()
        mock_engine_client.errored = False
        mock_engine_client.stage_list = None
        mock_models = MagicMock()
        mock_models.is_base_model.return_value = True
        return OmniOpenAIServingSpeech(
            engine_client=mock_engine_client,
            models=mock_models,
            request_logger=MagicMock(),
        )

    def test_is_tts_model(self, speech_server):
        """Test TTS model detection."""
        # No stage_list -> False
        assert speech_server._is_tts_model() is False
        # With qwen3_tts stage -> True
        mock_stage = MagicMock()
        mock_stage.model_stage = "qwen3_tts"
        speech_server.engine_client.stage_list = [mock_stage]
        assert speech_server._is_tts_model() is True

    def test_build_tts_prompt(self, speech_server):
        """Test TTS prompt format."""
        prompt = speech_server._build_tts_prompt("Hello")
        assert prompt == "<|im_start|>assistant\nHello<|im_end|>\n<|im_start|>assistant\n"

    def test_validate_tts_request_basic(self, speech_server):
        """Test basic validation cases."""
        # Empty input
        req = OpenAICreateSpeechRequest(input="")
        assert speech_server._validate_tts_request(req) == "Input text cannot be empty"
        # Invalid language
        req = OpenAICreateSpeechRequest(input="Hello", language="InvalidLang")
        assert "Invalid language" in speech_server._validate_tts_request(req)
        # When no speakers are loaded, any voice is accepted (unconstrained)
        req = OpenAICreateSpeechRequest(input="Hello", voice="Invalid")
        assert speech_server._validate_tts_request(req) is None
        # Valid request
        req = OpenAICreateSpeechRequest(input="Hello", voice="Vivian")
        assert speech_server._validate_tts_request(req) is None

    def test_validate_tts_request_task_types(self, speech_server):
        """Test task-specific validation."""
        # Base task requires ref_audio
        req = OpenAICreateSpeechRequest(input="Hello", task_type="Base")
        assert "ref_audio" in speech_server._validate_tts_request(req)
        # VoiceDesign requires instructions
        req = OpenAICreateSpeechRequest(input="Hello", task_type="VoiceDesign")
        assert "instructions" in speech_server._validate_tts_request(req)
        # ref_text is only valid for the Base task
        req = OpenAICreateSpeechRequest(input="Hello", ref_text="text")
        assert "Base task" in speech_server._validate_tts_request(req)

    def test_build_tts_params(self, speech_server):
        """Test TTS parameter building."""
        req = OpenAICreateSpeechRequest(input="Hello", voice="Ryan", language="English")
        params = speech_server._build_tts_params(req)
        assert params["text"] == ["Hello"]
        assert params["speaker"] == ["Ryan"]
        assert params["language"] == ["English"]
        assert params["task_type"] == ["CustomVoice"]
        assert "max_new_tokens" not in params

    def test_build_tts_params_with_explicit_max_new_tokens(self, speech_server):
        """Test explicit max_new_tokens override."""
        req = OpenAICreateSpeechRequest(
            input="Hello",
            task_type="Base",
            ref_audio="data:audio/wav;base64,AAAA",
            max_new_tokens=128,
        )
        params = speech_server._build_tts_params(req)
        assert params["max_new_tokens"] == [128]

    def test_load_supported_speakers(self):
        """Test _load_supported_speakers."""
        mock_engine_client = MagicMock()
        mock_engine_client.errored = False
        mock_engine_client.stage_list = None
        # Mock talker_config with mixed-case speaker names
        mock_talker_config = MagicMock()
        mock_talker_config.spk_id = {"Ryan": 0, "Vivian": 1, "Aiden": 2}
        mock_engine_client.model_config.hf_config.talker_config = mock_talker_config
        mock_models = MagicMock()
        mock_models.is_base_model.return_value = True
        server = OmniOpenAIServingSpeech(
            engine_client=mock_engine_client,
            models=mock_models,
            request_logger=MagicMock(),
        )
        # Verify speakers are normalized to lowercase
        assert server.supported_speakers == {"ryan", "vivian", "aiden"}
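
# Illustrative helper (not used by the tests): hitting the same
# /v1/audio/speech contract on a live vllm-omni server with a plain HTTP
# client. The host, port, and model name are placeholders, and `requests`
# is assumed to be available in the environment.
def _example_speech_request(base_url: str = "http://localhost:8000") -> bytes:
    import requests

    resp = requests.post(
        f"{base_url}/v1/audio/speech",
        json={
            "input": "Hello world",
            "model": "tts-model",
            "voice": "alloy",
            "response_format": "wav",
        },
        timeout=60,
    )
    resp.raise_for_status()  # a non-2xx status (e.g. 422 for a bad format) raises here
    return resp.content  # raw audio bytes, content-type audio/wav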
tests/entrypoints/test_async_omni_diffusion_config.py
0 → 100644
View file @
c1cacde6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm_omni.entrypoints import omni as omni_module
from vllm_omni.entrypoints.async_omni import AsyncOmni


def test_default_stage_config_includes_cache_backend(monkeypatch):
    """Ensure cache_backend/cache_config are preserved in the default diffusion stage."""
    monkeypatch.setattr(omni_module, "load_stage_configs_from_model", lambda model, base_engine_args=None: [])
    monkeypatch.setattr(omni_module, "resolve_model_config_path", lambda model: None)
    monkeypatch.setattr(AsyncOmni, "_start_stages", lambda self, model: None)
    monkeypatch.setattr(AsyncOmni, "_wait_for_stages_ready", lambda self, timeout=0: None)

    omni = AsyncOmni(
        model="dummy-model",
        cache_backend="cache_dit",
        cache_config='{"Fn_compute_blocks": 2}',
        vae_use_slicing=True,
        ulysses_degree=2,
    )
    stage_cfg = omni.stage_configs[0]
    engine_args = stage_cfg.engine_args
    assert engine_args.get("cache_backend") == "cache_dit"
    cache_config = engine_args.get("cache_config")
    assert cache_config["Fn_compute_blocks"] == 2
    assert engine_args.get("vae_use_slicing") is True
    parallel_config = engine_args.get("parallel_config")
    if hasattr(parallel_config, "get"):
        ulysses_degree = parallel_config.get("ulysses_degree")
    else:
        ulysses_degree = getattr(parallel_config, "ulysses_degree", None)
    assert ulysses_degree == 2


def test_default_cache_config_used_when_missing(monkeypatch):
    """Ensure the default cache_config is applied when cache_backend is set."""
    monkeypatch.setattr(omni_module, "load_stage_configs_from_model", lambda model, base_engine_args=None: [])
    monkeypatch.setattr(omni_module, "resolve_model_config_path", lambda model: None)
    monkeypatch.setattr(AsyncOmni, "_start_stages", lambda self, model: None)
    monkeypatch.setattr(AsyncOmni, "_wait_for_stages_ready", lambda self, timeout=0: None)

    omni = AsyncOmni(
        model="dummy-model",
        cache_backend="cache_dit",
    )
    engine_args = omni.stage_configs[0].engine_args
    cache_config = engine_args.get("cache_config")
    assert cache_config is not None
    assert cache_config["Fn_compute_blocks"] == 1


def test_default_stage_devices_from_sequence_parallel(monkeypatch):
    """Ensure the devices list reflects the sequence-parallel size when no parallel_config is provided."""
    monkeypatch.setattr(omni_module, "load_stage_configs_from_model", lambda model, base_engine_args=None: [])
    monkeypatch.setattr(omni_module, "resolve_model_config_path", lambda model: None)
    monkeypatch.setattr(AsyncOmni, "_start_stages", lambda self, model: None)
    monkeypatch.setattr(AsyncOmni, "_wait_for_stages_ready", lambda self, timeout=0: None)

    omni = AsyncOmni(
        model="dummy-model",
        ulysses_degree=2,
        ring_degree=2,
    )
    stage_cfg = omni.stage_configs[0]
    runtime = stage_cfg.runtime
    if hasattr(runtime, "get"):
        devices = runtime.get("devices")
    else:
        devices = getattr(runtime, "devices", None)
    assert devices == "0,1,2,3"
tests/entrypoints/test_omni_diffusion.py
0 → 100644
View file @
c1cacde6
import uuid
import warnings
from queue import Empty, Queue
from typing import Any
from unittest.mock import MagicMock

import pytest

from vllm_omni.entrypoints.stage_utils import SHUTDOWN_TASK
from vllm_omni.inputs.data import OmniDiffusionSamplingParams

# Suppress noisy DeprecationWarnings from optional Swig bindings imported by vLLM dependencies.
warnings.filterwarnings(
    "ignore",
    message=r"builtin type SwigPy.*has no __module__ attribute",
    category=DeprecationWarning,
)


class _FakeEngineArgs(dict):
    """Fake engine args that can be used both as object attributes and as **kwargs."""

    def __init__(self, args_dict: dict[str, Any]):
        super().__init__(args_dict)
        # Add required attributes if not present
        if "model_stage" not in self:
            self["model_stage"] = None
        if "engine_output_type" not in self:
            self["engine_output_type"] = None
        # Also set as attributes for object-style access
        for key, value in self.items():
            setattr(self, key, value)


class _FakeStageConfig:
    """Fake stage config object that mimics the real stage config structure."""

    def __init__(self, config_dict: dict[str, Any]):
        # engine_args needs to work both as object (for OmniStage) and as dict (for **kwargs)
        engine_args_dict = config_dict.get("engine_args", {})
        self.engine_args = _FakeEngineArgs(engine_args_dict)
        self.final_output = config_dict.get("final_output", False)
        self.final_output_type = config_dict.get("final_output_type", None)
        self.stage_id = config_dict.get("stage_id", 0)
        # Store original dict for reference
        self._config_dict = config_dict


class _FakeQueue:
    """Fake queue using the standard library Queue to replace mp.Queue."""

    def __init__(self, maxsize=0):
        self._queue = Queue(maxsize=maxsize)

    def put(self, item):
        self._queue.put(item)

    def put_nowait(self, item):
        self._queue.put_nowait(item)

    def get(self):
        return self._queue.get()

    def get_nowait(self):
        return self._queue.get_nowait()

    def empty(self):
        return self._queue.empty()


class _FakeStage:
    """Lightweight Stage stub for the multi-process pipeline version with queue support."""

    def __init__(self, config, stage_init_timeout: int = 300):
        # Handle both dict and object configs
        if isinstance(config, dict):
            config = _FakeStageConfig(config)
        self.config = config
        self.stage_config = config
        self.engine = None
        self.engine_outputs = None
        # Set attributes that OmniStage expects
        self.stage_id = getattr(config, "stage_id", 0)
        self.engine_args = config.engine_args
        self.model_stage = getattr(config.engine_args, "model_stage", None)
        self.stage_type = "diffusion"
        # Set default sampling params
        self.default_sampling_params = OmniDiffusionSamplingParams(num_inference_steps=1)
        # Allow configuring final_output and final_output_type
        self.final_output = config.final_output if hasattr(config, "final_output") else False
        self.final_output_type = getattr(config, "final_output_type", None)
        # Configurable processing logic; defaults to a placeholder
        processed_input = getattr(config, "_config_dict", {}).get("processed_input", ["processed"])
        self._processed_input = processed_input
        # Queue references (set by attach_queues)
        self._in_q = None
        self._out_q = None
        self._proc = None  # Mock process reference
        self._stage_init_timeout = max(0, int(stage_init_timeout))

    def attach_queues(self, in_q, out_q):
        """Attach input and output queues."""
        self._in_q = in_q
        self._out_q = out_q

    def init_stage_worker(
        self,
        model: str,
        *,
        is_async: bool = False,
        shm_threshold_bytes: int = 65536,
        ctx=None,
        batch_timeout: int = 10,
        **kwargs,
    ):
        """Mock init_stage_worker: don't start a real process, just send the stage_ready message."""
        # Create a mock process object
        self._proc = MagicMock()
        self._proc.start = MagicMock()
        self._proc.join = MagicMock()
        self._proc.is_alive = MagicMock(return_value=False)
        self._proc.terminate = MagicMock()
        # Send stage_ready message to the output queue
        if self._out_q is not None:
            try:
                self._out_q.put_nowait({"type": "stage_ready", "stage_id": self.stage_id})
            except Exception:
                pass

    def stop_stage_worker(self):
        """Mock stop_stage_worker: clean up queue references."""
        if self._in_q is not None:
            try:
                self._in_q.put_nowait(SHUTDOWN_TASK)
            except Exception:
                pass

    def submit(self, payload: dict[str, Any]):
        """Submit a task to the input queue."""
        if self._in_q is not None:
            self._in_q.put(payload)

    def try_collect(self) -> Any:
        """Non-blocking collect from the output queue."""
        if self._out_q is None:
            return None
        try:
            return self._out_q.get_nowait()
        except Empty:
            return None

    def set_engine_outputs(self, outputs):
        """Set engine outputs for the stage."""
        self.engine_outputs = outputs

    def process_engine_inputs(self, stage_list, prompts):
        """Process engine inputs: return the preset processed result."""
        return self._processed_input


class _FakeEngine:
    """Lightweight Engine stub: provides generate iterator output."""

    def __init__(self, outputs: list[Any]):
        self._outputs = outputs

    def generate(self, prompts, sampling_params):
        # Record the most recent prompts for outer assertions
        self._last_prompts = prompts
        # Simplified: yield the preset list at once, ensuring iterability
        yield from self._outputs


@pytest.fixture
def fake_stage_config():
    return {
        # Don't include 'model' in engine_args since it's passed separately
        "engine_args": {},
        "final_output": True,
        "final_output_type": "text",
        # The second stage will use processed_input to verify the chain
        "processed_input": ["processed-by-stage"],
    }
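
# A minimal sketch of how the fakes above compose (illustrative; the tests
# construct these through Omni rather than directly): a plain dict becomes a
# _FakeStageConfig, which a _FakeStage wraps, with the std-lib-backed
# _FakeQueue standing in for mp.Queue.
def _example_fake_stage_roundtrip() -> None:
    stage = _FakeStage(
        _FakeStageConfig({"engine_args": {}, "final_output": True, "final_output_type": "text"})
    )
    stage.attach_queues(_FakeQueue(), _FakeQueue())
    stage.init_stage_worker(model="any")  # emits the stage_ready handshake
    assert stage.try_collect() == {"type": "stage_ready", "stage_id": 0}
    stage.submit({"request_id": "0_demo"})  # lands on the input queue
    assert not stage._in_q.empty()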

def _setup_engine_mocks(monkeypatch):
    """Helper function to set up common engine mocks."""
    fake_engine = MagicMock()
    # Add necessary attributes to fake_engine
    fake_engine.tokenizer = MagicMock()
    fake_engine.log_stats = False
    fake_engine.vllm_config = MagicMock()
    fake_engine.vllm_config.model_config = MagicMock()
    fake_engine.vllm_config.model_config.io_processor_plugin = None
    fake_engine.get_supported_tasks = MagicMock(return_value=[])
    fake_engine.model_config = MagicMock()
    fake_engine.model_config.io_processor_plugin = None
    # Add registry with a resolve_model_cls method
    fake_registry = MagicMock()
    fake_registry.resolve_model_cls = MagicMock(return_value=(MagicMock(), "test_arch"))
    fake_engine.model_config.registry = fake_registry
    fake_engine.vllm_config.model_config.registry = fake_registry
    monkeypatch.setattr(
        "vllm.v1.engine.llm_engine.LLMEngine.from_engine_args",
        lambda **kw: fake_engine,
        raising=False,
    )

    # Mock model_config.registry.resolve_model_cls to return a tuple.
    # Use a real class instead of MagicMock to avoid inspect.getsource issues.
    class FakeModelClass:
        pass

    monkeypatch.setattr(
        "vllm.model_executor.model_loader.utils.get_model_architecture",
        lambda model_config: (FakeModelClass, "test_arch"),
        raising=False,
    )
    monkeypatch.setattr(
        "vllm.model_executor.model_loader.utils._get_model_architecture",
        lambda model_config: (FakeModelClass, "test_arch"),
        raising=False,
    )
    # Mock try_create_mm_pooling_model_cls to return the class as-is
    monkeypatch.setattr(
        "vllm.model_executor.models.adapters.try_create_mm_pooling_model_cls",
        lambda model_cls: model_cls,
        raising=False,
    )
    # Mock _enable_processor_cache to return False
    monkeypatch.setattr(
        "vllm.multimodal.cache._enable_processor_cache",
        lambda model_config, mm_registry: False,
        raising=False,
    )
    # Mock get_io_processor to return None
    monkeypatch.setattr(
        "vllm.plugins.io_processors.get_io_processor",
        lambda vllm_config, io_processor_plugin: None,
        raising=False,
    )


def _setup_multiprocessing_mocks(monkeypatch):
    """Helper function to set up multiprocessing mocks."""
    import multiprocessing as mp

    # Mock Process
    fake_process_class = MagicMock()
    fake_process_instance = MagicMock()
    fake_process_instance.start = MagicMock()
    fake_process_instance.join = MagicMock()
    fake_process_instance.is_alive = MagicMock(return_value=False)
    fake_process_instance.terminate = MagicMock()
    fake_process_class.return_value = fake_process_instance

    # Mock get_context to return a context whose Queue returns _FakeQueue
    fake_ctx = MagicMock()
    fake_ctx.Queue = lambda maxsize=0: _FakeQueue(maxsize=maxsize)
    fake_ctx.Process = fake_process_class

    def _mock_get_context(method):
        return fake_ctx

    monkeypatch.setattr(mp, "get_context", _mock_get_context, raising=False)
    monkeypatch.setattr(mp, "Process", fake_process_class, raising=False)


def _setup_ipc_mocks(monkeypatch):
    """Helper function to set up IPC function mocks."""

    # Mock _encode: simple serialization
    def _fake_encode(obj, threshold, obj_key, shm_key):
        return {obj_key: obj}

    # Mock _load: extract the object from the result
    def _fake_load(result, obj_key, shm_key):
        return result.get(obj_key)

    # Mock _set: calculate serialization size
    def _fake_set(obj):
        return str(obj).encode()

    monkeypatch.setattr("vllm_omni.entrypoints.omni._encode", _fake_encode, raising=False)
    monkeypatch.setattr("vllm_omni.entrypoints.omni._load", _fake_load, raising=False)
    monkeypatch.setattr("vllm_omni.entrypoints.omni._set", _fake_set, raising=False)


def _setup_log_mocks(monkeypatch):
    """Helper function to set up logging and stats mocks."""

    # Mock OrchestratorMetrics with a simple class that doesn't require file operations
    class _FakeOrchestratorMetrics:
        def __init__(self, num_stages, enable_stats, wall_start_ts):
            self.num_stages = num_stages
            self.enable_stats = enable_stats
            self.stage_first_ts = [None] * num_stages
            self.stage_last_ts = [None] * num_stages
            self.e2e_done = set()

        def on_stage_metrics(self, stage_id, req_id, metrics):
            pass

        def on_finalize_request(self, stage_id, req_id, start_ts):
            self.e2e_done.add(req_id)

        def on_forward(self, from_stage, to_stage, req_id, size_bytes, tx_ms, use_shm):
            pass

        def build_and_log_summary(self, final_stage_id):
            return "Fake summary"

    monkeypatch.setattr(
        "vllm_omni.entrypoints.omni.OrchestratorMetrics",
        _FakeOrchestratorMetrics,
        raising=False,
    )


@pytest.fixture(autouse=True)
def mock_get_config(monkeypatch):
    """Auto-mock get_config and related model loading functions to avoid model path validation."""
    # CRITICAL: Mock tokenizer-related imports FIRST, before any module imports.
    # This prevents an ImportError when async_omni is imported (which happens via omni_stage).
    import sys

    fake_tokenizer = MagicMock()
    fake_tokenizer.encode = MagicMock(return_value=[1, 2, 3])
    fake_tokenizer.decode = MagicMock(return_value="test")

    # Mock init_tokenizer_from_configs (used in async_omni)
    def _mock_init_tokenizer_from_configs(model_config=None, **kwargs):
        return fake_tokenizer

    # Strategy 1: Mock in the original location (vllm.transformers_utils.tokenizer).
    # This works if the module hasn't been imported yet.
    monkeypatch.setattr(
        "vllm.transformers_utils.tokenizer.init_tokenizer_from_configs",
        _mock_init_tokenizer_from_configs,
        raising=False,
    )
    # Strategy 2: If the module is already in sys.modules, patch it directly
    tokenizer_module_path = "vllm.transformers_utils.tokenizer"
    if tokenizer_module_path in sys.modules:
        tokenizer_module = sys.modules[tokenizer_module_path]
        setattr(tokenizer_module, "init_tokenizer_from_configs", _mock_init_tokenizer_from_configs)

    # CRITICAL: Mock length_from_prompt_token_ids_or_embeds BEFORE trying to mock async_omni,
    # because async_omni imports processor.py, which imports this function at module level.
    def _mock_length_from_prompt_token_ids_or_embeds(prompt_token_ids=None, prompt_embeds=None):
        # Return a reasonable default length
        if prompt_token_ids is not None:
            if isinstance(prompt_token_ids, list):
                return len(prompt_token_ids)
            elif hasattr(prompt_token_ids, "shape"):
                return prompt_token_ids.shape[-1] if len(prompt_token_ids.shape) > 0 else 1
        if prompt_embeds is not None:
            if hasattr(prompt_embeds, "shape"):
                return prompt_embeds.shape[-2] if len(prompt_embeds.shape) > 1 else 1
        return 10  # Default length

    # Mock in vllm.utils
    monkeypatch.setattr(
        "vllm.utils.length_from_prompt_token_ids_or_embeds",
        _mock_length_from_prompt_token_ids_or_embeds,
        raising=False,
    )
    # Also mock in the processor module if it's imported
    monkeypatch.setattr(
        "vllm_omni.engine.input_processor.length_from_prompt_token_ids_or_embeds",
        _mock_length_from_prompt_token_ids_or_embeds,
        raising=False,
    )
    # If the processor module is already imported, patch it directly
    processor_module_path = "vllm_omni.engine.input_processor"
    if processor_module_path in sys.modules:
        processor_module = sys.modules[processor_module_path]
        setattr(
            processor_module,
            "length_from_prompt_token_ids_or_embeds",
            _mock_length_from_prompt_token_ids_or_embeds,
        )

    # Strategy 3: Now mock async_omni AFTER length_from_prompt_token_ids_or_embeds is mocked.
    # This prevents an ImportError when async_omni imports processor.py.
    monkeypatch.setattr(
        "vllm_omni.entrypoints.async_omni.init_tokenizer_from_configs",
        _mock_init_tokenizer_from_configs,
        raising=False,
    )
    # Strategy 4: If async_omni is already imported, patch it directly
    async_omni_path = "vllm_omni.entrypoints.async_omni"
    if async_omni_path in sys.modules:
        async_omni_module = sys.modules[async_omni_path]
        setattr(async_omni_module, "init_tokenizer_from_configs", _mock_init_tokenizer_from_configs)

    # Now mock get_config and other functions
    fake_hf_config = MagicMock()
    fake_hf_config.model_type = "qwen2_5_omni"

    def _mock_get_config(model, **kwargs):
        return fake_hf_config

    monkeypatch.setattr(
        "vllm.transformers_utils.config.get_config",
        _mock_get_config,
        raising=False,
    )
    monkeypatch.setattr(
        "vllm_omni.entrypoints.utils.get_config",
        _mock_get_config,
        raising=False,
    )

    # Mock transformers' cached_file to avoid downloading model configs
    def _mock_cached_file(path_or_repo_id, *args, **kwargs):
        import os
        import tempfile

        fake_config_file = os.path.join(tempfile.gettempdir(), "fake_config.json")
        if not os.path.exists(fake_config_file):
            with open(fake_config_file, "w") as f:
                f.write('{"model_type": "qwen2_5_omni"}')
        return fake_config_file

    monkeypatch.setattr(
        "transformers.utils.hub.cached_file",
        _mock_cached_file,
        raising=False,
    )
    monkeypatch.setattr(
        "transformers.utils.hub.cached_files",
        lambda path_or_repo_id, filenames, **kwargs: (
            [_mock_cached_file(path_or_repo_id, filenames[0])] if filenames else None
        ),
        raising=False,
    )


def test_initialize_stage_configs_called_when_none(monkeypatch, fake_stage_config):
    """Test that stage configs are auto-loaded when stage_configs_path is None."""

    def _fake_loader(model: str, base_engine_args=None):
        return [
            _FakeStageConfig(fake_stage_config),
            _FakeStageConfig(fake_stage_config),
        ]

    # Remove modules from cache BEFORE setting mocks
    import sys

    for module_name in [
        "vllm_omni.entrypoints.utils",
        "vllm_omni.entrypoints.omni",
        "vllm_omni.entrypoints.omni_stage",
    ]:
        if module_name in sys.modules:
            del sys.modules[module_name]

    # Set up mocks
    _setup_engine_mocks(monkeypatch)
    _setup_multiprocessing_mocks(monkeypatch)
    _setup_ipc_mocks(monkeypatch)
    _setup_log_mocks(monkeypatch)
    # Mock load_stage_configs_from_model
    monkeypatch.setattr(
        "vllm_omni.entrypoints.utils.load_stage_configs_from_model",
        _fake_loader,
        raising=False,
    )
    # Replace OmniStage
    monkeypatch.setattr(
        "vllm_omni.entrypoints.omni_stage.OmniStage",
        lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
        raising=False,
    )

    # Import the module after mocks are set
    import vllm_omni.entrypoints.omni as omni_module

    # Patch the imported function and class in the module
    monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
    monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))

    from vllm_omni.entrypoints.omni import Omni

    omni = Omni(model="any", init_timeout=1)
    # Verify: auto-loaded stage_configs and stage_list have consistent counts
    assert isinstance(omni.stage_configs, list)
    assert len(omni.stage_configs) == 2
    assert len(omni.stage_list) == 2
    # Verify: each stage is a _FakeStage instance
    for st in omni.stage_list:
        assert isinstance(st, _FakeStage)
    # Verify: queues are attached
    for st in omni.stage_list:
        assert st._in_q is not None
        assert st._out_q is not None
    # Verify: all stages are ready
    assert len(omni._stages_ready) == 2


def test_generate_raises_on_length_mismatch(monkeypatch, fake_stage_config):
    """Test that generate raises ValueError when the sampling_params_list length doesn't match."""

    def _fake_loader(model: str, base_engine_args=None):
        return [_FakeStageConfig(fake_stage_config)]

    import sys

    for module_name in [
        "vllm_omni.entrypoints.utils",
        "vllm_omni.entrypoints.omni",
        "vllm_omni.entrypoints.omni_stage",
    ]:
        if module_name in sys.modules:
            del sys.modules[module_name]

    _setup_engine_mocks(monkeypatch)
    _setup_multiprocessing_mocks(monkeypatch)
    _setup_ipc_mocks(monkeypatch)
    _setup_log_mocks(monkeypatch)
    monkeypatch.setattr(
        "vllm_omni.entrypoints.utils.load_stage_configs_from_model",
        _fake_loader,
        raising=False,
    )
    monkeypatch.setattr(
        "vllm_omni.entrypoints.omni_stage.OmniStage",
        lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
        raising=False,
    )

    import vllm_omni.entrypoints.omni as omni_module

    monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
    monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))

    from vllm_omni.entrypoints.omni import Omni

    omni = Omni(model="any", init_timeout=1)
    with pytest.raises(ValueError):
        omni.generate(prompts=["hi"], sampling_params_list=[])


def test_generate_pipeline_and_final_outputs(monkeypatch, fake_stage_config):
    """Test the multi-stage generation pipeline with queue polling."""
    stage_cfg0 = dict(fake_stage_config)
    stage_cfg1 = dict(fake_stage_config)
    stage_cfg1["processed_input"] = ["processed-for-stage-1"]

    def _fake_loader(model: str, base_engine_args=None):
        return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)]

    import sys

    for module_name in [
        "vllm_omni.entrypoints.utils",
        "vllm_omni.entrypoints.omni",
        "vllm_omni.entrypoints.omni_stage",
    ]:
        if module_name in sys.modules:
            del sys.modules[module_name]

    _setup_engine_mocks(monkeypatch)
    _setup_multiprocessing_mocks(monkeypatch)
    _setup_ipc_mocks(monkeypatch)
    _setup_log_mocks(monkeypatch)
    monkeypatch.setattr(
        "vllm_omni.entrypoints.utils.load_stage_configs_from_model",
        _fake_loader,
        raising=False,
    )
    monkeypatch.setattr(
        "vllm_omni.entrypoints.omni_stage.OmniStage",
        lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
        raising=False,
    )

    import vllm_omni.entrypoints.omni as omni_module

    monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
    monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))

    # Mock uuid.uuid4() to return a predictable value for request ID generation
    test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000")
    monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid)
    monkeypatch.setattr(omni_module, "uuid", uuid)

    from vllm_omni.entrypoints.omni import Omni

    omni = Omni(model="any", init_timeout=1)
    # Generate the expected request ID format: "0_<uuid>"
    expected_request_id = f"0_{test_uuid}"
    # Simulate worker behavior: manually put results into the output queues.
    # Note: We put results before calling generate, which simulates worker processes
    # that have already completed. The polling loop will collect them in stage order.
    # Stage 0 output (will be collected first)
    omni.stage_list[0]._out_q.put_nowait(
        {
            "request_id": expected_request_id,
            "engine_outputs": [{"stage": 0, "text": "s0"}],
            "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
        }
    )
    # Stage 1 output (will be collected after stage 0 forwards to it).
    # Note: In the real flow, the stage 1 result would appear only after stage 0
    # forwards, but for testing we pre-populate it. The polling loop processes
    # stages in order, so the stage 0 result is collected first, then forwarded,
    # then the stage 1 result is collected.
    omni.stage_list[1]._out_q.put_nowait(
        {
            "request_id": expected_request_id,
            "engine_outputs": [{"stage": 1, "text": "s1"}],
            "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
        }
    )
    sampling_params_list = [
        OmniDiffusionSamplingParams(num_inference_steps=1),
        OmniDiffusionSamplingParams(num_inference_steps=1, max_sequence_length=10),
    ]
    prompts = ["hi"]
    outputs = omni.generate(prompts=prompts, sampling_params_list=sampling_params_list)
    # Both stages have final_output=True, so two OmniRequestOutput should be aggregated
    assert len(outputs) == 2
    # Verify stage outputs are set
    assert omni.stage_list[0].engine_outputs == [{"stage": 0, "text": "s0"}]
    assert omni.stage_list[1].engine_outputs == [{"stage": 1, "text": "s1"}]
    # Verify the stage 0 input queue received the task
    assert not omni.stage_list[0]._in_q.empty()
    # Verify stage 1 received the forwarded task (process_engine_inputs was called)
    assert omni.stage_list[1].process_engine_inputs([], []) is not None
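
# The inter-stage message schema these tests model (the real orchestrator may
# carry additional fields): each entry on a stage's output queue is a plain
# dict, either a result or an error, as the error-handling test below shows.
_EXAMPLE_RESULT_MSG = {
    "request_id": "0_<uuid>",  # "<prompt index>_<uuid4>"
    "engine_outputs": [{"stage": 0, "text": "s0"}],  # stage-specific payload
    "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
}
_EXAMPLE_ERROR_MSG = {
    "request_id": "0_<uuid>",
    "error": "test error",  # logged by the orchestrator; polling continues
}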

def test_generate_pipeline_with_batch_input(monkeypatch, fake_stage_config):
    """Test the generation pipeline with multiple inputs in one batch."""
    stage_cfg0 = dict(fake_stage_config)
    stage_cfg1 = dict(fake_stage_config)
    stage_cfg0["final_output"] = False

    def _fake_loader(model: str, base_engine_args=None):
        return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)]

    import sys

    for module_name in [
        "vllm_omni.entrypoints.utils",
        "vllm_omni.entrypoints.omni",
        "vllm_omni.entrypoints.omni_stage",
    ]:
        if module_name in sys.modules:
            del sys.modules[module_name]

    _setup_engine_mocks(monkeypatch)
    _setup_multiprocessing_mocks(monkeypatch)
    _setup_ipc_mocks(monkeypatch)
    _setup_log_mocks(monkeypatch)
    monkeypatch.setattr(
        "vllm_omni.entrypoints.utils.load_stage_configs_from_model",
        _fake_loader,
        raising=False,
    )
    monkeypatch.setattr(
        "vllm_omni.entrypoints.omni_stage.OmniStage",
        lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
        raising=False,
    )

    import vllm_omni.entrypoints.omni as omni_module

    monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
    monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))

    # Mock uuid.uuid4() to return a predictable value for request ID generation
    test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000")
    monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid)
    monkeypatch.setattr(omni_module, "uuid", uuid)

    from vllm_omni.entrypoints.omni import Omni

    omni = Omni(model="any", init_timeout=1)
    # Generate the expected request ID format: "0_<uuid>"
    expected_request_id = f"0_{test_uuid}"
    # Simulate worker behavior: manually put results into the output queues.
    # Note: We put results before calling generate, which simulates worker processes
    # that have already completed. The polling loop will collect them in stage order.
    omni.stage_list[0]._out_q.put_nowait(
        {
            "request_id": expected_request_id,
            "engine_outputs": [{"stage": 0, "text": "s0"}],
            "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
        }
    )
    omni.stage_list[0]._out_q.put_nowait(
        {
            "request_id": expected_request_id,
            "engine_outputs": [{"stage": 0, "text": "s0"}],
            "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
        }
    )
    omni.stage_list[1]._out_q.put_nowait(
        {
            "request_id": expected_request_id,
            "engine_outputs": [{"stage": 1}],
            "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
        }
    )
    omni.stage_list[1]._out_q.put_nowait(
        {
            "request_id": expected_request_id,
            "engine_outputs": [{"stage": 1}],
            "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
        }
    )
    outputs = omni.generate(
        prompts=[
            {
                "prompt": "hi",
                "negative_prompt": "hi",
                "multi_modal_data": {"image": ["dog.jpg", "cat.jpg"]},
            },
            {
                "prompt": "hi",
                "negative_prompt": "hi",
                "multi_modal_data": {"image": ["dog.jpg", "cat.jpg"]},
            },
        ],
        sampling_params_list=[
            OmniDiffusionSamplingParams(num_inference_steps=1),
            OmniDiffusionSamplingParams(num_inference_steps=1),
        ],
    )
    assert len(outputs) == 2


def test_generate_no_final_output_returns_empty(monkeypatch, fake_stage_config):
    """Test that generate returns an empty list when all stages have final_output=False."""
    stage_cfg0 = dict(fake_stage_config)
    stage_cfg1 = dict(fake_stage_config)
    stage_cfg0["final_output"] = False
    stage_cfg1["final_output"] = False

    def _fake_loader(model: str, base_engine_args=None):
        return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)]

    import sys

    for module_name in [
        "vllm_omni.entrypoints.utils",
        "vllm_omni.entrypoints.omni",
        "vllm_omni.entrypoints.omni_stage",
    ]:
        if module_name in sys.modules:
            del sys.modules[module_name]

    _setup_engine_mocks(monkeypatch)
    _setup_multiprocessing_mocks(monkeypatch)
    _setup_ipc_mocks(monkeypatch)
    _setup_log_mocks(monkeypatch)
    monkeypatch.setattr(
        "vllm_omni.entrypoints.utils.load_stage_configs_from_model",
        _fake_loader,
        raising=False,
    )
    monkeypatch.setattr(
        "vllm_omni.entrypoints.omni_stage.OmniStage",
        lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
        raising=False,
    )

    import vllm_omni.entrypoints.omni as omni_module

    monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
    monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))

    # Mock uuid.uuid4() to return a predictable value for request ID generation
    test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000")
    monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid)
    monkeypatch.setattr(omni_module, "uuid", uuid)

    from vllm_omni.entrypoints.omni import Omni

    omni = Omni(model="any", init_timeout=1)
    # Generate the expected request ID format: "0_<uuid>"
    expected_request_id = f"0_{test_uuid}"
    # Simulate worker behavior: put results into the output queues
    omni.stage_list[0]._out_q.put_nowait(
        {
            "request_id": expected_request_id,
            "engine_outputs": [{"stage": 0}],
            "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
        }
    )
    omni.stage_list[1]._out_q.put_nowait(
        {
            "request_id": expected_request_id,
            "engine_outputs": [{"stage": 1}],
            "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
        }
    )
    outputs = omni.generate(
        prompts=["p"],
        sampling_params_list=[
            OmniDiffusionSamplingParams(num_inference_steps=1),
            OmniDiffusionSamplingParams(num_inference_steps=1, max_sequence_length=10),
        ],
    )
    assert outputs == []


def test_generate_sampling_params_none_use_default(monkeypatch, fake_stage_config):
    """Test that generate uses the default sampling params when sampling_params_list is None."""
    stage_cfg0 = dict(fake_stage_config)
    stage_cfg1 = dict(fake_stage_config)
    stage_cfg0["final_output"] = False
    stage_cfg1["final_output"] = False

    def _fake_loader(model: str, base_engine_args=None):
        return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)]

    import sys

    for module_name in [
        "vllm_omni.entrypoints.utils",
        "vllm_omni.entrypoints.omni",
        "vllm_omni.entrypoints.omni_stage",
    ]:
        if module_name in sys.modules:
            del sys.modules[module_name]

    _setup_engine_mocks(monkeypatch)
    _setup_multiprocessing_mocks(monkeypatch)
    _setup_ipc_mocks(monkeypatch)
    _setup_log_mocks(monkeypatch)
    monkeypatch.setattr(
        "vllm_omni.entrypoints.utils.load_stage_configs_from_model",
        _fake_loader,
        raising=False,
    )
    monkeypatch.setattr(
        "vllm_omni.entrypoints.omni_stage.OmniStage",
        lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
        raising=False,
    )

    import vllm_omni.entrypoints.omni as omni_module

    monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
    monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))

    # Mock uuid.uuid4() to return a predictable value for request ID generation
    test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000")
    monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid)
    monkeypatch.setattr(omni_module, "uuid", uuid)

    from vllm_omni.entrypoints.omni import Omni

    omni = Omni(model="any", init_timeout=1)
    # Generate the expected request ID format: "0_<uuid>"
    expected_request_id = f"0_{test_uuid}"
    # Simulate worker behavior: put results into the output queues
    omni.stage_list[0]._out_q.put_nowait(
        {
            "request_id": expected_request_id,
            "engine_outputs": [{"stage": 0}],
            "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
        }
    )
    omni.stage_list[1]._out_q.put_nowait(
        {
            "request_id": expected_request_id,
            "engine_outputs": [{"stage": 1}],
            "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
        }
    )
    # Use the default sampling params
    omni.generate(prompts=["p"], sampling_params_list=None)


def test_wait_for_stages_ready_timeout(monkeypatch, fake_stage_config):
    """Test that _wait_for_stages_ready handles timeout correctly."""

    def _fake_loader(model: str, base_engine_args=None):
        return [_FakeStageConfig(fake_stage_config)]

    import sys

    for module_name in [
        "vllm_omni.entrypoints.utils",
        "vllm_omni.entrypoints.omni",
        "vllm_omni.entrypoints.omni_stage",
    ]:
        if module_name in sys.modules:
            del sys.modules[module_name]

    _setup_engine_mocks(monkeypatch)
    _setup_multiprocessing_mocks(monkeypatch)
    _setup_ipc_mocks(monkeypatch)
    _setup_log_mocks(monkeypatch)
    monkeypatch.setattr(
        "vllm_omni.entrypoints.utils.load_stage_configs_from_model",
        _fake_loader,
        raising=False,
    )

    # Create a stage that doesn't send the stage_ready message
    class _FakeStageNoReady(_FakeStage):
        def init_stage_worker(self, *args, **kwargs):
            # Don't send the stage_ready message
            self._proc = MagicMock()
            self._proc.start = MagicMock()
            self._proc.join = MagicMock()
            self._proc.is_alive = MagicMock(return_value=False)
            self._proc.terminate = MagicMock()

    monkeypatch.setattr(
        "vllm_omni.entrypoints.omni_stage.OmniStage",
        lambda cfg, **kwargs: _FakeStageNoReady(cfg, **kwargs),
        raising=False,
    )

    import vllm_omni.entrypoints.omni as omni_module

    monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
    monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStageNoReady(cfg, **kwargs))

    from vllm_omni.entrypoints.omni import Omni

    # Use a very short timeout
    omni = Omni(model="any", init_timeout=0.01)
    # Verify that no stages are ready
    assert len(omni._stages_ready) == 0


def test_generate_handles_error_messages(monkeypatch, fake_stage_config):
    """Test that generate handles error messages from stages correctly."""

    def _fake_loader(model: str, base_engine_args=None):
        return [_FakeStageConfig(fake_stage_config)]

    import sys

    for module_name in [
        "vllm_omni.entrypoints.utils",
        "vllm_omni.entrypoints.omni",
        "vllm_omni.entrypoints.omni_stage",
    ]:
        if module_name in sys.modules:
            del sys.modules[module_name]

    _setup_engine_mocks(monkeypatch)
    _setup_multiprocessing_mocks(monkeypatch)
    _setup_ipc_mocks(monkeypatch)
    _setup_log_mocks(monkeypatch)
    monkeypatch.setattr(
        "vllm_omni.entrypoints.utils.load_stage_configs_from_model",
        _fake_loader,
        raising=False,
    )
    monkeypatch.setattr(
        "vllm_omni.entrypoints.omni_stage.OmniStage",
        lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
        raising=False,
    )

    import vllm_omni.entrypoints.omni as omni_module

    monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
    monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))

    # Mock uuid.uuid4() to return a predictable value for request ID generation
    test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000")
    monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid)
    monkeypatch.setattr(omni_module, "uuid", uuid)

    from vllm_omni.entrypoints.omni import Omni

    omni = Omni(model="any", init_timeout=1)
    # Generate the expected request ID format: "0_<uuid>"
    expected_request_id = f"0_{test_uuid}"
    # Put an error message in the output queue
    omni.stage_list[0]._out_q.put_nowait(
        {
            "request_id": expected_request_id,
            "error": "test error",
        }
    )
    # Also put a valid result after the error so the loop can complete
    # (error handling continues the loop, so a valid result is needed to finish)
    omni.stage_list[0]._out_q.put_nowait(
        {
            "request_id": expected_request_id,
            "engine_outputs": [{"stage": 0, "text": "result"}],
            "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
        }
    )
    # Generate should handle the error gracefully (log it but continue)
    sampling_params_list = [OmniDiffusionSamplingParams(num_inference_steps=1)]
    outputs = omni.generate(prompts=["hi"], sampling_params_list=sampling_params_list)
    # Should return the final output (the error was logged but didn't stop processing)
    assert isinstance(outputs, list)
    # Since final_output=True, there should be one output
    assert len(outputs) == 1


def test_close_sends_shutdown_signal(monkeypatch, fake_stage_config):
    """Test that close() sends a shutdown signal to all input queues."""

    def _fake_loader(model: str, base_engine_args=None):
        return [_FakeStageConfig(fake_stage_config)]

    import sys

    for module_name in [
        "vllm_omni.entrypoints.utils",
        "vllm_omni.entrypoints.omni",
        "vllm_omni.entrypoints.omni_stage",
    ]:
        if module_name in sys.modules:
            del sys.modules[module_name]

    _setup_engine_mocks(monkeypatch)
    _setup_multiprocessing_mocks(monkeypatch)
    _setup_ipc_mocks(monkeypatch)
    _setup_log_mocks(monkeypatch)
    monkeypatch.setattr(
        "vllm_omni.entrypoints.utils.load_stage_configs_from_model",
        _fake_loader,
        raising=False,
    )
    monkeypatch.setattr(
        "vllm_omni.entrypoints.omni_stage.OmniStage",
        lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
        raising=False,
    )

    import vllm_omni.entrypoints.omni as omni_module

    monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
    monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))

    from vllm_omni.entrypoints.omni import Omni

    omni = Omni(model="any", init_timeout=1)
    # Call close
    omni.close()
    # Verify the shutdown signal was sent to the input queue.
    # Use get_nowait to avoid blocking (close() uses put_nowait, so this should be safe).
    try:
        shutdown_signal = omni.stage_list[0]._in_q.get_nowait()
        assert shutdown_signal == SHUTDOWN_TASK
    except Empty:
        # If the queue was already empty or only had stage_ready, that's also acceptable.
        # The important thing is that close() was called without error.
        pass
    # Verify stop_stage_worker was called (the process should be set)
    assert omni.stage_list[0]._proc is not None
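
# Taken together, these tests pin down the orchestration contract as the fakes
# model it: submit() feeds a stage's input queue, workers answer on the output
# queue, errors are logged without aborting, and close() pushes SHUTDOWN_TASK.
# A condensed sketch of one polling pass (illustrative only; the real
# Omni.generate loop also forwards results between stages and tracks metrics):
def _example_polling_pass(stage_list: list[_FakeStage]) -> None:
    for stage in stage_list:
        msg = stage.try_collect()  # non-blocking get_nowait under the hood
        if msg is None:
            continue  # nothing ready on this stage yet
        if "error" in msg:
            print(f"stage {stage.stage_id} error: {msg['error']}")  # log, keep going
            continue
        if "engine_outputs" in msg:
            stage.set_engine_outputs(msg["engine_outputs"])  # record the stage result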