Commit c1cacde6 authored by weishb

vllm-omni_0.15.0.rc1+fix1 first commit

parent 35607782
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Image-to-Video generation example using Wan2.2 I2V/TI2V models.
Supports:
- Wan2.2-I2V-A14B-Diffusers: MoE model with CLIP image encoder
- Wan2.2-TI2V-5B-Diffusers: Unified T2V+I2V model (dense 5B)
Usage:
# I2V-A14B (MoE)
python image_to_video.py --model Wan-AI/Wan2.2-I2V-A14B-Diffusers \
--image input.jpg --prompt "A cat playing with yarn"
# TI2V-5B (unified)
python image_to_video.py --model Wan-AI/Wan2.2-TI2V-5B-Diffusers \
--image input.jpg --prompt "A cat playing with yarn"
"""
import argparse
import os
from pathlib import Path
import numpy as np
import PIL.Image
import torch
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Generate a video from an image with Wan2.2 I2V/TI2V.")
parser.add_argument(
"--model",
default="Wan-AI/Wan2.2-I2V-A14B-Diffusers",
help="Diffusers Wan2.2 I2V model ID or local path.",
)
parser.add_argument("--image", required=True, help="Path to input image.")
parser.add_argument("--prompt", default="", help="Text prompt describing the desired motion.")
parser.add_argument("--negative_prompt", default="", help="Negative prompt.")
parser.add_argument("--seed", type=int, default=42, help="Random seed.")
parser.add_argument("--guidance_scale", type=float, default=5.0, help="CFG scale.")
parser.add_argument(
"--guidance_scale_high", type=float, default=None, help="Optional separate CFG for high-noise (MoE only)."
)
parser.add_argument(
"--height", type=int, default=None, help="Video height (auto-calculated from image if not set)."
)
parser.add_argument("--width", type=int, default=None, help="Video width (auto-calculated from image if not set).")
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames.")
parser.add_argument("--num_inference_steps", type=int, default=50, help="Sampling steps.")
parser.add_argument("--boundary_ratio", type=float, default=0.875, help="Boundary split ratio for MoE models.")
parser.add_argument(
"--flow_shift", type=float, default=5.0, help="Scheduler flow_shift (5.0 for 720p, 12.0 for 480p)."
)
parser.add_argument("--output", type=str, default="i2v_output.mp4", help="Path to save the video (mp4).")
parser.add_argument("--fps", type=int, default=16, help="Frames per second for the output video.")
parser.add_argument(
"--vae_use_slicing",
action="store_true",
help="Enable VAE slicing for memory optimization.",
)
parser.add_argument(
"--vae_use_tiling",
action="store_true",
help="Enable VAE tiling for memory optimization.",
)
parser.add_argument(
"--enable-cpu-offload",
action="store_true",
help="Enable CPU offloading for diffusion models.",
)
parser.add_argument(
"--enable-layerwise-offload",
action="store_true",
help="Enable layerwise (blockwise) offloading on DiT modules.",
)
parser.add_argument(
"--layerwise-num-gpu-layers",
type=int,
default=1,
help="Number of ready layers (blocks) to keep on GPU during generation.",
)
parser.add_argument(
"--cfg_parallel_size",
type=int,
default=1,
choices=[1, 2],
help="Number of GPUs used for classifier free guidance parallel size.",
)
parser.add_argument(
"--enforce_eager",
action="store_true",
help="Disable torch.compile and force eager execution.",
)
return parser.parse_args()
def calculate_dimensions(image: PIL.Image.Image, max_area: int = 480 * 832) -> tuple[int, int]:
"""Calculate output dimensions maintaining aspect ratio."""
aspect_ratio = image.height / image.width
mod_value = 16 # Must be divisible by 16
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
return height, width
def main():
args = parse_args()
generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(args.seed)
# Load input image
image = PIL.Image.open(args.image).convert("RGB")
# Calculate dimensions if not provided
height = args.height
width = args.width
if height is None or width is None:
# Default to 480P area for I2V
calc_height, calc_width = calculate_dimensions(image, max_area=480 * 832)
height = height or calc_height
width = width or calc_width
# Resize image to target dimensions
image = image.resize((width, height), PIL.Image.Resampling.LANCZOS)
# Check if profiling is requested via environment variable
profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR"))
parallel_config = DiffusionParallelConfig(
cfg_parallel_size=args.cfg_parallel_size,
)
omni = Omni(
model=args.model,
enable_layerwise_offload=args.enable_layerwise_offload,
layerwise_num_gpu_layers=args.layerwise_num_gpu_layers,
vae_use_slicing=args.vae_use_slicing,
vae_use_tiling=args.vae_use_tiling,
boundary_ratio=args.boundary_ratio,
flow_shift=args.flow_shift,
enable_cpu_offload=args.enable_cpu_offload,
parallel_config=parallel_config,
enforce_eager=args.enforce_eager,
)
if profiler_enabled:
print("[Profiler] Starting profiling...")
omni.start_profile()
# Print generation configuration
print(f"\n{'=' * 60}")
print("Generation Configuration:")
print(f" Model: {args.model}")
print(f" Inference steps: {args.num_inference_steps}")
print(f" Frames: {args.num_frames}")
print(f" Parallel configuration: cfg_parallel_size={args.cfg_parallel_size}")
print(f" Video size: {args.width}x{args.height}")
print(f"{'=' * 60}\n")
# omni.generate() returns the generation outputs, handled below as a list of OmniRequestOutput
frames = omni.generate(
{
"prompt": args.prompt,
"negative_prompt": args.negative_prompt,
"multi_modal_data": {"image": image},
},
OmniDiffusionSamplingParams(
height=height,
width=width,
generator=generator,
guidance_scale=args.guidance_scale,
guidance_scale_2=args.guidance_scale_high,
num_inference_steps=args.num_inference_steps,
num_frames=args.num_frames,
),
)
# Extract video frames from OmniRequestOutput
if isinstance(frames, list) and len(frames) > 0:
first_item = frames[0]
# Check if it's an OmniRequestOutput
if hasattr(first_item, "final_output_type"):
if first_item.final_output_type != "image":
raise ValueError(
f"Unexpected output type '{first_item.final_output_type}', expected 'image' for video generation."
)
# Pipeline mode: extract from nested request_output
if hasattr(first_item, "is_pipeline_output") and first_item.is_pipeline_output:
if isinstance(first_item.request_output, list) and len(first_item.request_output) > 0:
inner_output = first_item.request_output[0]
if isinstance(inner_output, OmniRequestOutput) and hasattr(inner_output, "images"):
frames = inner_output.images[0] if inner_output.images else None
if frames is None:
raise ValueError("No video frames found in output.")
# Diffusion mode: use direct images field
elif hasattr(first_item, "images") and first_item.images:
frames = first_item.images
else:
raise ValueError("No video frames found in OmniRequestOutput.")
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
try:
from diffusers.utils import export_to_video
except ImportError:
raise ImportError("diffusers is required for export_to_video.")
# frames may be np.ndarray (preferred) or torch.Tensor
# export_to_video expects a list of frames with values in [0, 1]
if isinstance(frames, torch.Tensor):
video_tensor = frames.detach().cpu()
if video_tensor.dim() == 5:
# [B, C, F, H, W] or [B, F, H, W, C]
if video_tensor.shape[1] in (3, 4):
video_tensor = video_tensor[0].permute(1, 2, 3, 0)
else:
video_tensor = video_tensor[0]
elif video_tensor.dim() == 4 and video_tensor.shape[0] in (3, 4):
video_tensor = video_tensor.permute(1, 2, 3, 0)
# If float, assume [-1,1] and normalize to [0,1]
if video_tensor.is_floating_point():
video_tensor = video_tensor.clamp(-1, 1) * 0.5 + 0.5
video_array = video_tensor.float().numpy()
else:
video_array = frames
if hasattr(video_array, "shape") and video_array.ndim == 5:
video_array = video_array[0]
# Convert 4D array (frames, H, W, C) to list of frames for export_to_video
if isinstance(video_array, np.ndarray) and video_array.ndim == 4:
video_array = list(video_array)
export_to_video(video_array, str(output_path), fps=args.fps)
print(f"Saved generated video to {output_path}")
if profiler_enabled:
print("\n[Profiler] Stopping profiler and collecting results...")
profile_results = omni.stop_profile()
if profile_results and isinstance(profile_results, dict):
traces = profile_results.get("traces", [])
print("\n" + "=" * 60)
print("PROFILING RESULTS:")
for rank, trace in enumerate(traces):
print(f"\nRank {rank}:")
if trace:
print(f" • Trace: {trace}")
if not traces:
print(" No traces collected.")
print("=" * 60)
else:
print("[Profiler] No valid profiling data returned.")
if __name__ == "__main__":
main()
# LoRA Inference Examples
This directory contains examples for using LoRA (Low-Rank Adaptation) adapters with vLLM-omni diffusion models for offline inference.
The example uses `stabilityai/stable-diffusion-3.5-medium` as the default model, but you can replace it with other models supported by vLLM-omni.
## Overview
Similar to vLLM, vLLM-omni uses a unified LoRA handling mechanism:
- **Pre-loaded LoRA**: Loaded at initialization via `--lora-path` (pre-loaded into the cache)
- **Per-request LoRA**: Loaded on demand. In this example, the LoRA is passed via `--lora-request-path` with each request
Both approaches use the same underlying mechanism: all LoRA adapters are handled uniformly through `set_active_adapter()`. If a request provides no LoRA request, all adapters are deactivated.
## Usage
### Pre-loaded LoRA (via --lora-path)
Load a LoRA adapter at initialization. This adapter is pre-loaded into the cache and can be activated by requests:
```bash
python -m examples.offline_inference.lora_inference.lora_inference \
--prompt "A piece of cheesecake" \
--lora-path /path/to/lora/ \
--lora-scale 1.0 \
--num_inference_steps 50 \
--height 1024 \
--width 1024 \
--output output_preloaded.png
```
**Note**: When using `--lora-path`, the adapter is loaded at init time with a stable ID derived from the adapter path. This example activates it automatically for the request.
### Per-request LoRA (via --lora-request-path)
Load a LoRA adapter on-demand for each request:
```bash
python -m examples.offline_inference.lora_inference.lora_inference \
--prompt "A piece of cheesecake" \
--lora-request-path /path/to/lora/ \
--lora-scale 1.0 \
--num_inference_steps 50 \
--height 1024 \
--width 1024 \
--output output_per_request.png
```
### No LoRA
If no LoRA request is provided, the base model is used without any LoRA adapters:
```bash
python -m examples.offline_inference.lora_inference.lora_inference \
--prompt "A piece of cheesecake" \
--num_inference_steps 50 \
--height 1024 \
--width 1024 \
--output output_no_lora.png
```
## Parameters
### LoRA Parameters
- `--lora-path`: Path to LoRA adapter folder to pre-load at initialization (loads into cache with a stable ID derived from the path)
- `--lora-request-path`: Path to LoRA adapter folder for per-request loading
- `--lora-request-id`: Integer ID for the LoRA adapter (optional). If not provided and `--lora-request-path` is set, a stable ID is derived from the path.
- `--lora-scale`: Scale factor for LoRA weights (default: 1.0). Higher values increase the influence of the LoRA adapter.
### Standard Parameters
- `--prompt`: Text prompt for image generation (required)
- `--seed`: Random seed for reproducibility (default: 42)
- `--height`: Image height in pixels (default: 1024)
- `--width`: Image width in pixels (default: 1024)
- `--num_inference_steps`: Number of denoising steps (default: 50)
- `--output`: Output file path (default: `lora_output.png`)
## How LoRA Works
All LoRA adapters are handled uniformly:
1. **Initialization**: If `--lora-path` is provided, the adapter is loaded into cache with a stable ID derived from the adapter path
2. **Per-request**: If `--lora-request-path` is provided, the adapter is loaded/activated for that request
3. **No LoRA**: If no LoRA request is provided (`req.lora_request` is None), all adapters are deactivated
The system uses LRU cache management - adapters are cached and evicted when the cache is full (unless pinned).
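For reference, below is a minimal sketch of the per-request flow described above, assuming a local PEFT-format adapter directory (the path is a placeholder); the imports and fields mirror `lora_inference.py` in this directory.

```python
# Minimal sketch of per-request LoRA activation (mirrors lora_inference.py).
from pathlib import Path

from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.lora.request import LoRARequest
from vllm_omni.lora.utils import stable_lora_int_id

lora_path = "/path/to/lora/"  # PEFT-format adapter directory (placeholder)
omni = Omni(model="stabilityai/stable-diffusion-3.5-medium")

# A stable integer ID derived from the adapter path keeps repeated requests
# pointing at the same cached adapter.
lora_request = LoRARequest(
    lora_name=Path(lora_path).stem,
    lora_int_id=stable_lora_int_id(lora_path),
    lora_path=lora_path,
)

sampling_params = OmniDiffusionSamplingParams(height=1024, width=1024, num_inference_steps=50)
sampling_params.lora_request = lora_request  # omit to run the base model (adapters deactivated)
sampling_params.lora_scale = 1.0

outputs = omni.generate("A piece of cheesecake", sampling_params)
```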
## LoRA Adapter Format
LoRA adapters must be in PEFT (Parameter-Efficient Fine-Tuning) format. A typical LoRA adapter directory structure:
```
lora_adapter/
├── adapter_config.json
└── adapter_model.safetensors
```
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
from pathlib import Path
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.lora.request import LoRARequest
from vllm_omni.lora.utils import stable_lora_int_id
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Generate images with LoRA adapters.")
parser.add_argument("--model", default="stabilityai/stable-diffusion-3.5-medium", help="Model name or path.")
parser.add_argument("--prompt", required=True, help="Text prompt for image generation.")
parser.add_argument("--seed", type=int, default=42, help="Random seed for deterministic results.")
parser.add_argument("--height", type=int, default=1024, help="Height of generated image.")
parser.add_argument("--width", type=int, default=1024, help="Width of generated image.")
parser.add_argument(
"--num_inference_steps",
type=int,
default=50,
help="Number of denoising steps for the diffusion sampler.",
)
parser.add_argument(
"--output",
type=str,
default="lora_output.png",
help="Path to save the generated image (PNG).",
)
parser.add_argument(
"--lora-path",
type=str,
default=None,
help="Path to LoRA adapter folder to pre-load at initialization (PEFT format). "
"Note: pre-loading populates the cache; you still need to pass a lora_request to activate it.",
)
parser.add_argument(
"--lora-request-path",
type=str,
default=None,
help="Path to LoRA adapter folder for per-request activation (dynamic LoRA). "
"If --lora-request-id is not provided, a stable ID will be derived from this path.",
)
parser.add_argument(
"--lora-request-id",
type=int,
default=None,
help="Integer ID for the LoRA adapter (for dynamic LoRA). "
"If not provided and --lora-request-path is set, will derive a stable ID from the path.",
)
parser.add_argument(
"--lora-scale",
type=float,
default=1.0,
help="Scale factor for LoRA weights (default: 1.0).",
)
return parser.parse_args()
def main():
args = parse_args()
model = args.model
omni_kwargs = {}
if args.lora_path:
omni_kwargs["lora_path"] = args.lora_path
print(f"Using static LoRA from: {args.lora_path}")
omni = Omni(model=model, **omni_kwargs)
lora_request = None
if args.lora_request_path:
if args.lora_request_id is None:
lora_request_id = stable_lora_int_id(args.lora_request_path)
else:
lora_request_id = args.lora_request_id
lora_name = Path(args.lora_request_path).stem
lora_request = LoRARequest(
lora_name=lora_name,
lora_int_id=lora_request_id,
lora_path=args.lora_request_path,
)
print(f"Using per-request LoRA: name={lora_name}, id={lora_request_id}, scale={args.lora_scale}")
elif args.lora_path:
# pre-loaded LoRA
lora_request_id = stable_lora_int_id(args.lora_path)
lora_request = LoRARequest(
lora_name="preloaded",
lora_int_id=lora_request_id,
lora_path=args.lora_path,
)
print(f"Activating pre-loaded LoRA: id={lora_request_id}, scale={args.lora_scale}")
sampling_params = OmniDiffusionSamplingParams(
height=args.height,
width=args.width,
num_inference_steps=args.num_inference_steps,
)
if lora_request:
sampling_params.lora_request = lora_request
sampling_params.lora_scale = args.lora_scale
outputs = omni.generate(args.prompt, sampling_params)
if not outputs or len(outputs) == 0:
raise ValueError("No output generated from omni.generate()")
if isinstance(outputs, list):
first_output = outputs[0]
else:
first_output = outputs
images = None
if hasattr(first_output, "images") and first_output.images:
images = first_output.images
elif hasattr(first_output, "request_output") and first_output.request_output:
req_out = first_output.request_output
if isinstance(req_out, list) and len(req_out) > 0:
req_out = req_out[0]
if hasattr(req_out, "images") and req_out.images:
images = req_out.images
if not images:
raise ValueError("No images found in request_output")
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
suffix = output_path.suffix or ".png"
stem = output_path.stem or "lora_output"
if len(images) <= 1:
images[0].save(output_path)
print(f"Saved generated image to {output_path}")
else:
for idx, img in enumerate(images):
save_path = output_path.parent / f"{stem}_{idx}{suffix}"
img.save(save_path)
print(f"Saved generated image to {save_path}")
if __name__ == "__main__":
main()
# Qwen2.5-Omni
## Setup
Please refer to the [stage configuration documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/) to configure memory allocation appropriately for your hardware setup.
## Run examples
### Multiple Prompts
Change into the example folder:
```bash
cd examples/offline_inference/qwen2_5_omni
```
Then run the command below. Note: for processing large volumes of data, it uses py_generator mode, which returns a Python generator from the Omni class (a minimal code sketch of this mode follows the command).
```bash
bash run_multiple_prompts.sh
```
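A minimal sketch of what py_generator mode looks like in code, assuming the same model and setup as `end2end.py` in this folder; the prompt and sampling values are illustrative placeholders only.

```python
# Sketch: consume stage outputs as they become available (py_generator=True).
from vllm.sampling_params import SamplingParams
from vllm_omni.entrypoints.omni import Omni

omni_llm = Omni(model="Qwen/Qwen2.5-Omni-7B")
prompts = [{"prompt": "..."}]  # build prompts as end2end.py does
# One SamplingParams per stage (thinker, talker, code2wav), simplified here.
sampling_params_list = [SamplingParams(max_tokens=2048)] * 3

for stage_outputs in omni_llm.generate(prompts, sampling_params_list, py_generator=True):
    if stage_outputs.final_output_type == "text":
        for output in stage_outputs.request_output:
            print(output.request_id, output.outputs[0].text)

omni_llm.close()
```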
### Single Prompt
Change into the example folder:
```bash
cd examples/offline_inference/qwen2_5_omni
```
Then run the command below.
```bash
bash run_single_prompt.sh
```
### Modality control
If you want to control the output modalities, e.g. to output only text, you can run the command below (a short sketch of how this is applied to the prompts follows):
```bash
python end2end.py --output-wav output_audio \
--query-type use_mixed_modalities \
--modalities text
```
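Under the hood (see `end2end.py`), the `--modalities` value is split on commas and attached to every prompt dict before generation; a minimal sketch, with the prompt itself left as a placeholder:

```python
# Sketch: how --modalities is applied to the prompts in end2end.py.
prompts = [{"prompt": "..."}]  # prompts built as in end2end.py
output_modalities = "text".split(",")  # value passed via --modalities
for prompt in prompts:
    prompt["modalities"] = output_modalities
```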
#### Using Local Media Files
The `end2end.py` script supports local media files (audio, video, image) via CLI arguments:
```bash
# Use single local media files
python end2end.py --query-type use_image --image-path /path/to/image.jpg
python end2end.py --query-type use_video --video-path /path/to/video.mp4
python end2end.py --query-type use_audio --audio-path /path/to/audio.wav
# Combine multiple local media files
python end2end.py --query-type use_mixed_modalities \
--video-path /path/to/video.mp4 \
--image-path /path/to/image.jpg \
--audio-path /path/to/audio.wav
# Use audio from video file
python end2end.py --query-type use_audio_in_video --video-path /path/to/video.mp4
```
If media file paths are not provided, the script will use default assets. Supported query types:
- `use_image`: Image input only
- `use_video`: Video input only
- `use_audio`: Audio input only
- `use_mixed_modalities`: Audio + image + video
- `use_audio_in_video`: Extract audio from video
- `text`: Text-only query
### FAQ
If you encounter an error about the audio backend used by librosa, try installing ffmpeg with the command below.
```
sudo apt update
sudo apt install ffmpeg
```
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use vLLM-Omni for running offline inference
with the correct prompt format on Qwen2.5-Omni
"""
import os
import time
from typing import NamedTuple
import librosa
import numpy as np
import soundfile as sf
from PIL import Image
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset, video_to_ndarrays
from vllm.multimodal.image import convert_image_mode
from vllm.sampling_params import SamplingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm_omni.entrypoints.omni import Omni
SEED = 42
class QueryResult(NamedTuple):
inputs: dict
limit_mm_per_prompt: dict[str, int]
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.
default_system = (
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
"Group, capable of perceiving auditory and visual inputs, as well as "
"generating text and speech."
)
def get_text_query(question: str = None) -> QueryResult:
if question is None:
question = "Explain the system architecture for a scalable audio generation pipeline. Answer in 15 words."
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
return QueryResult(
inputs={
"prompt": prompt,
},
limit_mm_per_prompt={},
)
def get_mixed_modalities_query(
video_path: str | None = None,
image_path: str | None = None,
audio_path: str | None = None,
num_frames: int = 16,
sampling_rate: int = 16000,
) -> QueryResult:
question = "What is recited in the audio? What is the content of this image? Why is this video funny?"
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
"<|vision_bos|><|IMAGE|><|vision_eos|>"
"<|vision_bos|><|VIDEO|><|vision_eos|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
# Load video
if video_path:
if not os.path.exists(video_path):
raise FileNotFoundError(f"Video file not found: {video_path}")
video_frames = video_to_ndarrays(video_path, num_frames=num_frames)
else:
video_frames = VideoAsset(name="baby_reading", num_frames=num_frames).np_ndarrays
# Load image
if image_path:
if not os.path.exists(image_path):
raise FileNotFoundError(f"Image file not found: {image_path}")
pil_image = Image.open(image_path)
image_data = convert_image_mode(pil_image, "RGB")
else:
image_data = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
# Load audio
if audio_path:
if not os.path.exists(audio_path):
raise FileNotFoundError(f"Audio file not found: {audio_path}")
audio_signal, sr = librosa.load(audio_path, sr=sampling_rate)
audio_data = (audio_signal.astype(np.float32), sr)
else:
audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"audio": audio_data,
"image": image_data,
"video": video_frames,
},
},
limit_mm_per_prompt={"audio": 1, "image": 1, "video": 1},
)
def get_use_audio_in_video_query(
video_path: str | None = None, num_frames: int = 16, sampling_rate: int = 16000
) -> QueryResult:
question = "Describe the content of the video, then convert what the baby say into text."
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|><|audio_bos|><|AUDIO|><|audio_eos|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
if video_path:
if not os.path.exists(video_path):
raise FileNotFoundError(f"Video file not found: {video_path}")
video_frames = video_to_ndarrays(video_path, num_frames=num_frames)
# Extract audio from video file
audio_signal, sr = librosa.load(video_path, sr=sampling_rate)
audio = (audio_signal.astype(np.float32), sr)
else:
asset = VideoAsset(name="baby_reading", num_frames=num_frames)
video_frames = asset.np_ndarrays
audio = asset.get_audio(sampling_rate=sampling_rate)
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"video": video_frames,
"audio": audio,
},
"mm_processor_kwargs": {
"use_audio_in_video": True,
},
},
limit_mm_per_prompt={"audio": 1, "video": 1},
)
def get_multi_audios_query(audio_path: str | None = None, sampling_rate: int = 16000) -> QueryResult:
question = "Are these two audio clips the same?"
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
"<|audio_bos|><|AUDIO|><|audio_eos|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
if audio_path:
if not os.path.exists(audio_path):
raise FileNotFoundError(f"Audio file not found: {audio_path}")
audio_signal, sr = librosa.load(audio_path, sr=sampling_rate)
audio_data = (audio_signal.astype(np.float32), sr)
# Use the provided audio as the first audio, default as second
audio_list = [
audio_data,
AudioAsset("mary_had_lamb").audio_and_sample_rate,
]
else:
audio_list = [
AudioAsset("winning_call").audio_and_sample_rate,
AudioAsset("mary_had_lamb").audio_and_sample_rate,
]
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"audio": audio_list,
},
},
limit_mm_per_prompt={
"audio": 2,
},
)
def get_image_query(question: str = None, image_path: str | None = None) -> QueryResult:
if question is None:
question = "What is the content of this image?"
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|vision_bos|><|IMAGE|><|vision_eos|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
if image_path:
if not os.path.exists(image_path):
raise FileNotFoundError(f"Image file not found: {image_path}")
pil_image = Image.open(image_path)
image_data = convert_image_mode(pil_image, "RGB")
else:
image_data = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"image": image_data,
},
},
limit_mm_per_prompt={"image": 1},
)
def get_video_query(question: str = None, video_path: str | None = None, num_frames: int = 16) -> QueryResult:
if question is None:
question = "Why is this video funny?"
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
if video_path:
if not os.path.exists(video_path):
raise FileNotFoundError(f"Video file not found: {video_path}")
video_frames = video_to_ndarrays(video_path, num_frames=num_frames)
else:
video_frames = VideoAsset(name="baby_reading", num_frames=num_frames).np_ndarrays
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"video": video_frames,
},
},
limit_mm_per_prompt={"video": 1},
)
def get_audio_query(question: str = None, audio_path: str | None = None, sampling_rate: int = 16000) -> QueryResult:
if question is None:
question = "What is the content of this audio?"
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
if audio_path:
if not os.path.exists(audio_path):
raise FileNotFoundError(f"Audio file not found: {audio_path}")
audio_signal, sr = librosa.load(audio_path, sr=sampling_rate)
audio_data = (audio_signal.astype(np.float32), sr)
else:
audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"audio": audio_data,
},
},
limit_mm_per_prompt={"audio": 1},
)
query_map = {
"use_mixed_modalities": get_mixed_modalities_query,
"use_audio_in_video": get_use_audio_in_video_query,
"use_multi_audios": get_multi_audios_query,
"use_image": get_image_query,
"use_video": get_video_query,
"use_audio": get_audio_query,
"text": get_text_query,
}
def main(args):
model_name = "Qwen/Qwen2.5-Omni-7B"
# Get paths from args
video_path = getattr(args, "video_path", None)
image_path = getattr(args, "image_path", None)
audio_path = getattr(args, "audio_path", None)
num_frames = getattr(args, "num_frames", 16)
sampling_rate = getattr(args, "sampling_rate", 16000)
# Get the query function and call it with appropriate parameters
query_func = query_map[args.query_type]
if args.query_type == "use_mixed_modalities":
query_result = query_func(
video_path=video_path,
image_path=image_path,
audio_path=audio_path,
num_frames=num_frames,
sampling_rate=sampling_rate,
)
elif args.query_type == "use_audio_in_video":
query_result = query_func(video_path=video_path, num_frames=num_frames, sampling_rate=sampling_rate)
elif args.query_type == "use_multi_audios":
query_result = query_func(audio_path=audio_path, sampling_rate=sampling_rate)
elif args.query_type == "use_image":
query_result = query_func(image_path=image_path)
elif args.query_type == "use_video":
query_result = query_func(video_path=video_path, num_frames=num_frames)
elif args.query_type == "use_audio":
query_result = query_func(audio_path=audio_path, sampling_rate=sampling_rate)
else:
query_result = query_func()
omni_llm = Omni(
model=model_name,
log_stats=args.enable_stats,
stage_init_timeout=args.stage_init_timeout,
batch_timeout=args.batch_timeout,
init_timeout=args.init_timeout,
shm_threshold_bytes=args.shm_threshold_bytes,
)
thinker_sampling_params = SamplingParams(
temperature=0.0, # Deterministic - no randomness
top_p=1.0, # Disable nucleus sampling
top_k=-1, # Disable top-k sampling
max_tokens=2048,
seed=SEED, # Fixed seed for sampling
detokenize=True,
repetition_penalty=1.1,
)
talker_sampling_params = SamplingParams(
temperature=0.9,
top_p=0.8,
top_k=40,
max_tokens=2048,
seed=SEED, # Fixed seed for sampling
detokenize=True,
repetition_penalty=1.05,
stop_token_ids=[8294],
)
code2wav_sampling_params = SamplingParams(
temperature=0.0, # Deterministic - no randomness
top_p=1.0, # Disable nucleus sampling
top_k=-1, # Disable top-k sampling
max_tokens=2048,
seed=SEED, # Fixed seed for sampling
detokenize=True,
repetition_penalty=1.1,
)
sampling_params_list = [
thinker_sampling_params,
talker_sampling_params,
code2wav_sampling_params,
]
if args.txt_prompts is None:
prompts = [query_result.inputs for _ in range(args.num_prompts)]
else:
assert args.query_type == "text", "txt-prompts is only supported for text query type"
with open(args.txt_prompts, encoding="utf-8") as f:
lines = [ln.strip() for ln in f.readlines()]
prompts = [get_text_query(ln).inputs for ln in lines if ln != ""]
print(f"[Info] Loaded {len(prompts)} prompts from {args.txt_prompts}")
if args.modalities is not None:
output_modalities = args.modalities.split(",")
for i, prompt in enumerate(prompts):
prompt["modalities"] = output_modalities
profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR"))
if profiler_enabled:
omni_llm.start_profile(stages=[0])
omni_generator = omni_llm.generate(prompts, sampling_params_list, py_generator=args.py_generator)
# Determine output directory: prefer --output-dir; fallback to --output-wav
output_dir = args.output_dir if getattr(args, "output_dir", None) else args.output_wav
os.makedirs(output_dir, exist_ok=True)
total_requests = len(prompts)
processed_count = 0
for stage_outputs in omni_generator:
if stage_outputs.final_output_type == "text":
for output in stage_outputs.request_output:
request_id = output.request_id
text_output = output.outputs[0].text
# Save aligned text file per request
prompt_text = output.prompt
out_txt = os.path.join(output_dir, f"{request_id}.txt")
lines = []
lines.append("Prompt:\n")
lines.append(str(prompt_text) + "\n")
lines.append("vllm_text_output:\n")
lines.append(str(text_output).strip() + "\n")
try:
with open(out_txt, "w", encoding="utf-8") as f:
f.writelines(lines)
except Exception as e:
print(f"[Warn] Failed writing text file {out_txt}: {e}")
print(f"Request ID: {request_id}, Text saved to {out_txt}")
elif stage_outputs.final_output_type == "audio":
for output in stage_outputs.request_output:
request_id = output.request_id
audio_tensor = output.outputs[0].multimodal_output["audio"]
output_wav = os.path.join(output_dir, f"output_{request_id}.wav")
sf.write(output_wav, audio_tensor.detach().cpu().numpy(), samplerate=24000)
print(f"Request ID: {request_id}, Saved audio to {output_wav}")
processed_count += len(stage_outputs.request_output)
if profiler_enabled and processed_count >= total_requests:
print(f"[Info] Processed {processed_count}/{total_requests}. Stopping profiler inside active loop...")
# Stop the profiler while workers are still alive
omni_llm.stop_profile()
print("[Info] Waiting 30s for workers to write massive trace files to disk...")
time.sleep(30)
print("[Info] Trace export wait finished.")
omni_llm.close()
def parse_args():
parser = FlexibleArgumentParser(description="Demo on using vLLM for offline inference with audio language models")
parser.add_argument(
"--query-type",
"-q",
type=str,
default="use_mixed_modalities",
choices=query_map.keys(),
help="Query type.",
)
parser.add_argument(
"--enable-stats",
action="store_true",
default=False,
help="Enable writing detailed statistics (default: disabled)",
)
parser.add_argument(
"--stage-init-timeout",
type=int,
default=300,
help="Timeout for initializing a single stage in seconds (default: 300)",
)
parser.add_argument(
"--batch-timeout",
type=int,
default=5,
help="Timeout for batching in seconds (default: 5)",
)
parser.add_argument(
"--init-timeout",
type=int,
default=300,
help="Timeout for initializing stages in seconds (default: 300)",
)
parser.add_argument(
"--shm-threshold-bytes",
type=int,
default=65536,
help="Threshold for using shared memory in bytes (default: 65536)",
)
parser.add_argument(
"--output-wav",
default="output_audio",
help="[Deprecated] Output wav directory (use --output-dir).",
)
parser.add_argument(
"--num-prompts",
type=int,
default=1,
help="Number of prompts to generate.",
)
parser.add_argument(
"--txt-prompts",
type=str,
default=None,
help="Path to a .txt file with one prompt per line (preferred).",
)
parser.add_argument(
"--video-path",
"-v",
type=str,
default=None,
help="Path to local video file. If not provided, uses default video asset.",
)
parser.add_argument(
"--image-path",
"-i",
type=str,
default=None,
help="Path to local image file. If not provided, uses default image asset.",
)
parser.add_argument(
"--audio-path",
"-a",
type=str,
default=None,
help="Path to local audio file. If not provided, uses default audio asset.",
)
parser.add_argument(
"--num-frames",
type=int,
default=16,
help="Number of frames to extract from video (default: 16).",
)
parser.add_argument(
"--sampling-rate",
type=int,
default=16000,
help="Sampling rate for audio loading (default: 16000).",
)
parser.add_argument(
"--worker-backend", type=str, default="multi_process", choices=["multi_process", "ray"], help="backend"
)
parser.add_argument(
"--ray-address",
type=str,
default=None,
help="Address of the Ray cluster.",
)
parser.add_argument(
"--modalities",
type=str,
default=None,
help="Modalities to use for the prompts.",
)
parser.add_argument(
"--py-generator",
action="store_true",
default=False,
help="Use py_generator mode. The returned type of Omni.generate() is a Python Generator object.",
)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
main(args)
#!/usr/bin/env python3
import argparse
def extract_prompt(line: str) -> str | None:
# Extract the content between the first '|' and the second '|'
i = line.find("|")
if i == -1:
return None
j = line.find("|", i + 1)
if j == -1:
return None
return line[i + 1 : j].strip()
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--input", "-i", required=True, help="Input .lst file path")
parser.add_argument("--output", "-o", required=True, help="Output file path")
parser.add_argument(
"--topk",
"-k",
type=int,
default=100,
help="Extract the top K prompts (default: 100)",
)
args = parser.parse_args()
prompts = []
with open(args.input, encoding="utf-8", errors="ignore") as f:
for line in f:
if len(prompts) >= args.topk:
break
p = extract_prompt(line.rstrip("\n"))
if p:
prompts.append(p)
with open(args.output, "w", encoding="utf-8") as f:
for p in prompts:
f.write(p + "\n")
if __name__ == "__main__":
main()
python end2end.py --output-wav output_audio \
--query-type text \
--txt-prompts ../qwen3_omni/text_prompts_10.txt \
--py-generator
python end2end.py --output-wav output_audio \
--query-type use_mixed_modalities
# Qwen3-Omni
## Setup
Please refer to the [stage configuration documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/) to configure memory allocation appropriately for your hardware setup.
## Run examples
### Multiple Prompts
Change into the example folder:
```bash
cd examples/offline_inference/qwen3_omni
```
Then run the command below. Note: for processing large volumes of data, it uses py_generator mode, which returns a Python generator from the Omni class.
```bash
bash run_multiple_prompts.sh
```
### Single Prompt
Change into the example folder:
```bash
cd examples/offline_inference/qwen3_omni
```
Then run the command below.
```bash
bash run_single_prompt.sh
```
If you do not have enough memory, you can run the thinker stage with tensor parallelism. Just run the command below.
```bash
bash run_single_prompt_tp.sh
```
### Modality control
If you want to control the output modalities, e.g. to output only text, you can run the command below:
```bash
python end2end.py --output-wav output_audio \
--query-type use_audio \
--modalities text
```
#### Using Local Media Files
The `end2end.py` script supports local media files (audio, video, image) via command-line arguments:
```bash
# Use local video file
python end2end.py --query-type use_video --video-path /path/to/video.mp4
# Use local image file
python end2end.py --query-type use_image --image-path /path/to/image.jpg
# Use local audio file
python end2end.py --query-type use_audio --audio-path /path/to/audio.wav
# Combine multiple local media files
python end2end.py --query-type use_mixed_modalities \
--video-path /path/to/video.mp4 \
--image-path /path/to/image.jpg \
--audio-path /path/to/audio.wav
```
If media file paths are not provided, the script will use default assets. Supported query types:
- `use_video`: Video input
- `use_image`: Image input
- `use_audio`: Audio input
- `text`: Text-only query
- `use_multi_audios`: Multiple audio inputs
- `use_mixed_modalities`: Combination of video, image, and audio inputs
### FAQ
If you encounter an error about the audio backend used by librosa, try installing ffmpeg with the command below.
```
sudo apt update
sudo apt install ffmpeg
```
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use vLLM for running offline inference
with the correct prompt format on Qwen3-Omni (thinker only).
"""
import os
import time
from typing import NamedTuple
import librosa
import numpy as np
import soundfile as sf
import vllm
from PIL import Image
from vllm import SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset, video_to_ndarrays
from vllm.multimodal.image import convert_image_mode
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm_omni.entrypoints.omni import Omni
SEED = 42
class QueryResult(NamedTuple):
inputs: dict
limit_mm_per_prompt: dict[str, int]
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.
default_system = (
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
"Group, capable of perceiving auditory and visual inputs, as well as "
"generating text and speech."
)
def get_text_query(question: str = None) -> QueryResult:
if question is None:
question = "Explain the system architecture for a scalable audio generation pipeline. Answer in 15 words."
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
return QueryResult(
inputs={
"prompt": prompt,
},
limit_mm_per_prompt={},
)
def get_video_query(question: str = None, video_path: str | None = None, num_frames: int = 16) -> QueryResult:
if question is None:
question = "Why is this video funny?"
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
if video_path:
if not os.path.exists(video_path):
raise FileNotFoundError(f"Video file not found: {video_path}")
video_frames = video_to_ndarrays(video_path, num_frames=num_frames)
else:
video_frames = VideoAsset(name="baby_reading", num_frames=num_frames).np_ndarrays
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"video": video_frames,
},
},
limit_mm_per_prompt={"video": 1},
)
def get_image_query(question: str = None, image_path: str | None = None) -> QueryResult:
if question is None:
question = "What is the content of this image?"
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
if image_path:
if not os.path.exists(image_path):
raise FileNotFoundError(f"Image file not found: {image_path}")
pil_image = Image.open(image_path)
image_data = convert_image_mode(pil_image, "RGB")
else:
image_data = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"image": image_data,
},
},
limit_mm_per_prompt={"image": 1},
)
def get_audio_query(question: str = None, audio_path: str | None = None, sampling_rate: int = 16000) -> QueryResult:
if question is None:
question = "What is the content of this audio?"
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
if audio_path:
if not os.path.exists(audio_path):
raise FileNotFoundError(f"Audio file not found: {audio_path}")
audio_signal, sr = librosa.load(audio_path, sr=sampling_rate)
audio_data = (audio_signal.astype(np.float32), sr)
else:
audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"audio": audio_data,
},
},
limit_mm_per_prompt={"audio": 1},
)
def get_mixed_modalities_query(
video_path: str | None = None,
image_path: str | None = None,
audio_path: str | None = None,
num_frames: int = 16,
sampling_rate: int = 16000,
) -> QueryResult:
question = "What is recited in the audio? What is the content of this image? Why is this video funny?"
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|>"
"<|vision_start|><|image_pad|><|vision_end|>"
"<|vision_start|><|video_pad|><|vision_end|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
# Load video
if video_path:
if not os.path.exists(video_path):
raise FileNotFoundError(f"Video file not found: {video_path}")
video_frames = video_to_ndarrays(video_path, num_frames=num_frames)
else:
video_frames = VideoAsset(name="baby_reading", num_frames=num_frames).np_ndarrays
# Load image
if image_path:
if not os.path.exists(image_path):
raise FileNotFoundError(f"Image file not found: {image_path}")
pil_image = Image.open(image_path)
image_data = convert_image_mode(pil_image, "RGB")
else:
image_data = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
# Load audio
if audio_path:
if not os.path.exists(audio_path):
raise FileNotFoundError(f"Audio file not found: {audio_path}")
audio_signal, sr = librosa.load(audio_path, sr=sampling_rate)
audio_data = (audio_signal.astype(np.float32), sr)
else:
audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"audio": audio_data,
"image": image_data,
"video": video_frames,
},
},
limit_mm_per_prompt={"audio": 1, "image": 1, "video": 1},
)
def get_multi_audios_query() -> QueryResult:
question = "Are these two audio clips the same?"
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|>"
"<|audio_start|><|audio_pad|><|audio_end|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"audio": [
AudioAsset("winning_call").audio_and_sample_rate,
AudioAsset("mary_had_lamb").audio_and_sample_rate,
],
},
},
limit_mm_per_prompt={
"audio": 2,
},
)
# def get_use_audio_in_video_query(video_path: str | None = None) -> QueryResult:
# question = (
# "Describe the content of the video in details, then convert what the "
# "baby say into text."
# )
# prompt = (
# f"<|im_start|>system\n{default_system}<|im_end|>\n"
# "<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>"
# f"{question}<|im_end|>\n"
# f"<|im_start|>assistant\n"
# )
# if video_path:
# if not os.path.exists(video_path):
# raise FileNotFoundError(f"Video file not found: {video_path}")
# video_frames = video_to_ndarrays(video_path, num_frames=16)
# else:
# video_frames = VideoAsset(name="baby_reading", num_frames=16).np_ndarrays
# audio = extract_video_audio(video_path, sampling_rate=16000)
# return QueryResult(
# inputs={
# "prompt": prompt,
# "multi_modal_data": {
# "video": video_frames,
# "audio": audio,
# },
# "mm_processor_kwargs": {
# "use_audio_in_video": True,
# },
# },
# limit_mm_per_prompt={"audio": 1, "video": 1},
# )
def get_use_audio_in_video_query() -> QueryResult:
question = "Describe the content of the video in details, then convert what the baby say into text."
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
asset = VideoAsset(name="baby_reading", num_frames=16)
audio = asset.get_audio(sampling_rate=16000)
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"video": asset.np_ndarrays,
"audio": audio,
},
"mm_processor_kwargs": {
"use_audio_in_video": True,
},
},
limit_mm_per_prompt={"audio": 1, "video": 1},
)
query_map = {
"text": get_text_query,
"use_audio": get_audio_query,
"use_image": get_image_query,
"use_video": get_video_query,
"use_multi_audios": get_multi_audios_query,
"use_mixed_modalities": get_mixed_modalities_query,
"use_audio_in_video": get_use_audio_in_video_query,
}
def main(args):
model_name = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
print("=" * 20, "\n", f"vllm version: {vllm.__version__}", "\n", "=" * 20)
# Get paths from args
video_path = getattr(args, "video_path", None)
image_path = getattr(args, "image_path", None)
audio_path = getattr(args, "audio_path", None)
# Get the query function and call it with appropriate parameters
query_func = query_map[args.query_type]
if args.query_type == "use_video":
query_result = query_func(video_path=video_path, num_frames=getattr(args, "num_frames", 16))
elif args.query_type == "use_image":
query_result = query_func(image_path=image_path)
elif args.query_type == "use_audio":
query_result = query_func(audio_path=audio_path, sampling_rate=getattr(args, "sampling_rate", 16000))
elif args.query_type == "use_mixed_modalities":
query_result = query_func(
video_path=video_path,
image_path=image_path,
audio_path=audio_path,
num_frames=getattr(args, "num_frames", 16),
sampling_rate=getattr(args, "sampling_rate", 16000),
)
elif args.query_type == "use_multi_audios":
query_result = query_func()
elif args.query_type == "use_audio_in_video":
query_result = query_func()
else:
query_result = query_func()
omni_llm = Omni(
model=model_name,
stage_configs_path=args.stage_configs_path,
log_stats=args.enable_stats,
stage_init_timeout=args.stage_init_timeout,
)
thinker_sampling_params = SamplingParams(
temperature=0.9,
top_p=0.9,
top_k=-1,
max_tokens=1200,
repetition_penalty=1.05,
logit_bias={},
seed=SEED,
)
talker_sampling_params = SamplingParams(
temperature=0.9,
top_k=50,
max_tokens=4096,
seed=SEED,
detokenize=False,
repetition_penalty=1.05,
stop_token_ids=[2150], # TALKER_CODEC_EOS_TOKEN_ID
)
# Sampling parameters for Code2Wav stage (audio generation)
code2wav_sampling_params = SamplingParams(
temperature=0.0,
top_p=1.0,
top_k=-1,
max_tokens=4096 * 16,
seed=SEED,
detokenize=True,
repetition_penalty=1.1,
)
sampling_params_list = [
thinker_sampling_params,
talker_sampling_params, # code predictor is integrated into talker for Qwen3 Omni
code2wav_sampling_params,
]
if args.txt_prompts is None:
prompts = [query_result.inputs for _ in range(args.num_prompts)]
else:
assert args.query_type == "text", "txt-prompts is only supported for text query type"
with open(args.txt_prompts, encoding="utf-8") as f:
lines = [ln.strip() for ln in f.readlines()]
prompts = [get_text_query(ln).inputs for ln in lines if ln != ""]
print(f"[Info] Loaded {len(prompts)} prompts from {args.txt_prompts}")
if args.modalities is not None:
output_modalities = args.modalities.split(",")
for i, prompt in enumerate(prompts):
prompt["modalities"] = output_modalities
profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR"))
if profiler_enabled:
omni_llm.start_profile(stages=[0])
omni_generator = omni_llm.generate(prompts, sampling_params_list, py_generator=args.py_generator)
# Determine output directory: prefer --output-dir; fallback to --output-wav
output_dir = args.output_dir if getattr(args, "output_dir", None) else args.output_wav
os.makedirs(output_dir, exist_ok=True)
total_requests = len(prompts)
processed_count = 0
print(f"query type: {args.query_type}")
for stage_outputs in omni_generator:
if stage_outputs.final_output_type == "text":
for output in stage_outputs.request_output:
request_id = output.request_id
text_output = output.outputs[0].text
# Save aligned text file per request
prompt_text = output.prompt
out_txt = os.path.join(output_dir, f"{request_id}.txt")
lines = []
lines.append("Prompt:\n")
lines.append(str(prompt_text) + "\n")
lines.append("vllm_text_output:\n")
lines.append(str(text_output).strip() + "\n")
try:
with open(out_txt, "w", encoding="utf-8") as f:
f.writelines(lines)
except Exception as e:
print(f"[Warn] Failed writing text file {out_txt}: {e}")
print(f"Request ID: {request_id}, Text saved to {out_txt}")
elif stage_outputs.final_output_type == "audio":
for output in stage_outputs.request_output:
request_id = output.request_id
audio_tensor = output.outputs[0].multimodal_output["audio"]
output_wav = os.path.join(output_dir, f"output_{request_id}.wav")
# Convert to numpy array and ensure correct format
audio_numpy = audio_tensor.float().detach().cpu().numpy()
# Ensure audio is 1D (flatten if needed)
if audio_numpy.ndim > 1:
audio_numpy = audio_numpy.flatten()
# Save audio file with explicit WAV format
sf.write(output_wav, audio_numpy, samplerate=24000, format="WAV")
print(f"Request ID: {request_id}, Saved audio to {output_wav}")
processed_count += len(stage_outputs.request_output)
if profiler_enabled and processed_count >= total_requests:
print(f"[Info] Processed {processed_count}/{total_requests}. Stopping profiler inside active loop...")
# Stop the profiler while workers are still alive
omni_llm.stop_profile()
print("[Info] Waiting 30s for workers to write trace files to disk...")
time.sleep(30)
print("[Info] Trace export wait time finished.")
omni_llm.close()
def parse_args():
parser = FlexibleArgumentParser(description="Demo on using vLLM for offline inference with audio language models")
parser.add_argument(
"--query-type",
"-q",
type=str,
default="use_mixed_modalities",
choices=query_map.keys(),
help="Query type.",
)
parser.add_argument(
"--enable-stats",
action="store_true",
default=False,
help="Enable writing detailed statistics (default: disabled)",
)
parser.add_argument(
"--stage-init-timeout",
type=int,
default=300,
help="Timeout for initializing a single stage in seconds (default: 300)",
)
parser.add_argument(
"--batch-timeout",
type=int,
default=5,
help="Timeout for batching in seconds (default: 5)",
)
parser.add_argument(
"--init-timeout",
type=int,
default=300,
help="Timeout for initializing stages in seconds (default: 300)",
)
parser.add_argument(
"--shm-threshold-bytes",
type=int,
default=65536,
help="Threshold for using shared memory in bytes (default: 65536)",
)
parser.add_argument(
"--output-wav",
default="output_audio",
help="[Deprecated] Output wav directory (use --output-dir).",
)
parser.add_argument(
"--num-prompts",
type=int,
default=1,
help="Number of prompts to generate.",
)
parser.add_argument(
"--txt-prompts",
type=str,
default=None,
help="Path to a .txt file with one prompt per line (preferred).",
)
parser.add_argument(
"--stage-configs-path",
type=str,
default=None,
help="Path to a stage configs file.",
)
parser.add_argument(
"--video-path",
"-v",
type=str,
default=None,
help="Path to local video file. If not provided, uses default video asset.",
)
parser.add_argument(
"--image-path",
"-i",
type=str,
default=None,
help="Path to local image file. If not provided, uses default image asset.",
)
parser.add_argument(
"--audio-path",
"-a",
type=str,
default=None,
help="Path to local audio file. If not provided, uses default audio asset.",
)
parser.add_argument(
"--num-frames",
type=int,
default=16,
help="Number of frames to extract from video (default: 16).",
)
parser.add_argument(
"--sampling-rate",
type=int,
default=16000,
help="Sampling rate for audio loading (default: 16000).",
)
parser.add_argument(
"--log-dir",
type=str,
default="logs",
help="Log directory (default: logs).",
)
parser.add_argument(
"--modalities",
type=str,
default=None,
help="Output modalities to use for the prompts.",
)
parser.add_argument(
"--py-generator",
action="store_true",
default=False,
help="Use py_generator mode. The returned type of Omni.generate() is a Python Generator object.",
)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
main(args)
python end2end.py --output-wav output_audio \
--query-type text \
--txt-prompts text_prompts_10.txt \
--py-generator
python end2end.py --output-wav output_audio \
--query-type use_audio
python end2end.py --output-wav output_audio \
--query-type use_audio \
--stage-init-timeout 300
# --stage-init-timeout sets the maximum wait so that two vLLM stages do not initialize at the same time on the same card.
What is the capital of France?
How many planets are in our solar system?
What is the largest ocean on Earth?
Who wrote the novel "1984"?
What is the chemical symbol for water?
What year did World War II end?
What is the tallest mountain in the world?
What is the speed of light in vacuum?
Who painted the Mona Lisa?
What is the smallest prime number?
# Qwen3-TTS Offline Inference
This directory contains an offline demo for running Qwen3 TTS models with vLLM Omni. It builds task-specific inputs and generates WAV files locally.
## Model Overview
Qwen3 TTS provides multiple task variants for speech generation:
- **CustomVoice**: Generate speech with a known speaker identity (speaker ID) and optional instruction.
- **VoiceDesign**: Generate speech from text plus a descriptive instruction that designs a new voice.
- **Base**: Voice cloning using a reference audio + reference transcript, with optional mode selection.
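For orientation, here is a minimal sketch of the request that `end2end.py` builds for the CustomVoice task; the prompt template and `additional_information` fields mirror the script in this directory, and the concrete values are examples only.

```python
# Sketch: a single CustomVoice request as constructed in end2end.py.
text = "She said she would be here by noon."
inputs = {
    "prompt": f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n",
    "additional_information": {
        "task_type": ["CustomVoice"],
        "text": [text],
        "language": ["English"],
        "speaker": ["Ryan"],
        "instruct": ["Very happy."],
        "max_new_tokens": [2048],
    },
}
```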
## Setup
Please refer to the [stage configuration documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/) to configure memory allocation appropriately for your hardware setup.
### ROCm Dependencies
You will need to install two additional dependencies: `onnxruntime-rocm` and `sox`.
```
pip uninstall onnxruntime  # onnxruntime must be removed before installing onnxruntime-rocm
pip install onnxruntime-rocm sox
```
## Quick Start
Run a single sample for a task:
```
python end2end.py --query-type CustomVoice
```
Generated audio files are saved to `output_audio/` by default.
## Task Usage
### CustomVoice
Single sample:
```
python end2end.py --query-type CustomVoice
```
Batch sample (multiple prompts in one run):
```
python end2end.py --query-type CustomVoice --use-batch-sample
```
### VoiceDesign
Single sample:
```
python end2end.py --query-type VoiceDesign
```
Batch sample:
```
python end2end.py --query-type VoiceDesign --use-batch-sample
```
### Base (Voice Clone)
Single sample:
```
python end2end.py --query-type Base
```
Batch sample:
```
python end2end.py --query-type Base --use-batch-sample
```
Mode selection for Base:
- `--mode-tag icl` (default): standard mode
- `--mode-tag xvec_only`: enable `x_vector_only_mode` in the request
Examples:
```
python end2end.py --query-type Base --mode-tag icl
```
## Notes
- The script uses the model paths embedded in `end2end.py`. Update them if your local cache path differs.
- Use `--output-dir` (preferred) or `--output-wav` to change the output folder.
"""Offline inference demo for Qwen3 TTS via vLLM Omni.
Provides single and batch sample inputs for CustomVoice, VoiceDesign, and Base
tasks, then runs Omni generation and saves output wav files.
"""
import os
from typing import NamedTuple
import soundfile as sf
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
from vllm import SamplingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm_omni import Omni
class QueryResult(NamedTuple):
"""Container for a prepared Omni request."""
inputs: dict
model_name: str
def get_custom_voice_query(use_batch_sample: bool = False) -> QueryResult:
"""Build CustomVoice sample inputs.
Args:
use_batch_sample: When True, return a batch of prompts; otherwise a single prompt.
Returns:
QueryResult with Omni inputs and the CustomVoice model path.
"""
task_type = "CustomVoice"
if use_batch_sample:
texts = ["其实我真的有发现,我是一个特别善于观察别人情绪的人。", "She said she would be here by noon."]
instructs = ["", "Very happy."]
languages = ["Chinese", "English"]
speakers = ["Vivian", "Ryan"]
inputs = []
for text, instruct, language, speaker in zip(texts, instructs, languages, speakers):
prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n"
inputs.append(
{
"prompt": prompt,
"additional_information": {
"task_type": [task_type],
"text": [text],
"instruct": [instruct],
"language": [language],
"speaker": [speaker],
"max_new_tokens": [2048],
},
}
)
else:
text = "其实我真的有发现,我是一个特别善于观察别人情绪的人。"
language = "Chinese"
speaker = "Vivian"
instruct = "用特别愤怒的语气说"
prompts = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n"
inputs = {
"prompt": prompts,
"additional_information": {
"task_type": [task_type],
"text": [text],
"language": [language],
"speaker": [speaker],
"instruct": [instruct],
"max_new_tokens": [2048],
},
}
return QueryResult(
inputs=inputs,
model_name="Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice",
)
def get_voice_design_query(use_batch_sample: bool = False) -> QueryResult:
"""Build VoiceDesign sample inputs.
Args:
use_batch_sample: When True, return a batch of prompts; otherwise a single prompt.
Returns:
QueryResult with Omni inputs and the VoiceDesign model path.
"""
task_type = "VoiceDesign"
if use_batch_sample:
texts = [
"哥哥,你回来啦,人家等了你好久好久了,要抱抱!",
"It's in the top drawer... wait, it's empty? No way, that's impossible! I'm sure I put it there!",
]
instructs = [
"体现撒娇稚嫩的萝莉女声,音调偏高且起伏明显,营造出黏人、做作又刻意卖萌的听觉效果。",
"Speak in an incredulous tone, but with a hint of panic beginning to creep into your voice.",
]
languages = ["Chinese", "English"]
inputs = []
for text, instruct, language in zip(texts, instructs, languages):
prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n"
inputs.append(
{
"prompt": prompt,
"additional_information": {
"task_type": [task_type],
"text": [text],
"language": [language],
"instruct": [instruct],
"max_new_tokens": [2048],
"non_streaming_mode": [True],
},
}
)
else:
text = "哥哥,你回来啦,人家等了你好久好久了,要抱抱!"
instruct = "体现撒娇稚嫩的萝莉女声,音调偏高且起伏明显,营造出黏人、做作又刻意卖萌的听觉效果。"
language = "Chinese"
prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n"
inputs = {
"prompt": prompt,
"additional_information": {
"task_type": [task_type],
"text": [text],
"language": [language],
"instruct": [instruct],
"max_new_tokens": [2048],
"non_streaming_mode": [True],
},
}
return QueryResult(
inputs=inputs,
model_name="Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign",
)
def get_base_query(use_batch_sample: bool = False, mode_tag: str = "icl") -> QueryResult:
"""Build Base (voice clone) sample inputs.
Args:
use_batch_sample: When True, return a batch of prompts (Case 2).
mode_tag: "icl" or "xvec_only" to control x_vector_only_mode behavior.
Returns:
QueryResult with Omni inputs and the Base model path.
"""
task_type = "Base"
ref_audio_path_1 = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-TTS-Repo/clone_2.wav"
ref_audio_single = ref_audio_path_1
ref_text_single = (
"Okay. Yeah. I resent you. I love you. I respect you. But you know what? You blew it! And thanks to you."
)
syn_text_single = "Good one. Okay, fine, I'm just gonna leave this sock monkey here. Goodbye."
syn_lang_single = "Auto"
x_vector_only_mode = mode_tag == "xvec_only"
if use_batch_sample:
syn_text_batch = [
"Good one. Okay, fine, I'm just gonna leave this sock monkey here. Goodbye.",
"其实我真的有发现,我是一个特别善于观察别人情绪的人。",
]
syn_lang_batch = ["Chinese", "English"]
inputs = []
for text, language in zip(syn_text_batch, syn_lang_batch):
prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n"
inputs.append(
{
"prompt": prompt,
"additional_information": {
"task_type": [task_type],
"ref_audio": [ref_audio_single],
"ref_text": [ref_text_single],
"text": [text],
"language": [language],
"x_vector_only_mode": [x_vector_only_mode],
"max_new_tokens": [2048],
},
}
)
else:
prompt = f"<|im_start|>assistant\n{syn_text_single}<|im_end|>\n<|im_start|>assistant\n"
inputs = {
"prompt": prompt,
"additional_information": {
"task_type": [task_type],
"ref_audio": [ref_audio_single],
"ref_text": [ref_text_single],
"text": [syn_text_single],
"language": [syn_lang_single],
"x_vector_only_mode": [x_vector_only_mode],
"max_new_tokens": [2048],
},
}
return QueryResult(
inputs=inputs,
model_name="Qwen/Qwen3-TTS-12Hz-1.7B-Base",
)
def main(args):
"""Run offline inference with Omni using prepared sample inputs.
Args:
args: Parsed CLI args from parse_args().
"""
query_func = query_map[args.query_type]
if args.query_type in {"CustomVoice", "VoiceDesign"}:
query_result = query_func(use_batch_sample=args.use_batch_sample)
elif args.query_type == "Base":
query_result = query_func(
use_batch_sample=args.use_batch_sample,
mode_tag=args.mode_tag,
)
else:
query_result = query_func()
model_name = query_result.model_name
omni = Omni(
model=model_name,
stage_configs_path=args.stage_configs_path,
log_stats=args.enable_stats,
stage_init_timeout=args.stage_init_timeout,
)
sampling_params = SamplingParams(
temperature=0.9,
top_p=1.0,
top_k=50,
max_tokens=2048,
seed=42,
detokenize=False,
repetition_penalty=1.05,
)
sampling_params_list = [
sampling_params,
]
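# Prefer --output-dir when provided; otherwise fall back to the deprecated --output-wav directory.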
output_dir = args.output_dir if getattr(args, "output_dir", None) else args.output_wav
os.makedirs(output_dir, exist_ok=True)
omni_generator = omni.generate(query_result.inputs, sampling_params_list)
for stage_outputs in omni_generator:
for output in stage_outputs.request_output:
request_id = output.request_id
audio_tensor = output.outputs[0].multimodal_output["audio"]
output_wav = os.path.join(output_dir, f"output_{request_id}.wav")
audio_samplerate = output.outputs[0].multimodal_output["sr"].item()
# Convert to numpy array and ensure correct format
audio_numpy = audio_tensor.float().detach().cpu().numpy()
# Ensure audio is 1D (flatten if needed)
if audio_numpy.ndim > 1:
audio_numpy = audio_numpy.flatten()
# Save audio file with explicit WAV format
sf.write(output_wav, audio_numpy, samplerate=audio_samplerate, format="WAV")
print(f"Request ID: {request_id}, Saved audio to {output_wav}")
def parse_args():
"""Parse CLI arguments for offline TTS inference.
Returns:
argparse.Namespace with CLI options.
"""
parser = FlexibleArgumentParser(description="Demo on using vLLM-Omni for offline inference with Qwen3 TTS models")
parser.add_argument(
"--query-type",
"-q",
type=str,
default="CustomVoice",
choices=query_map.keys(),
help="Query type.",
)
parser.add_argument(
"--enable-stats",
action="store_true",
default=False,
help="Enable writing detailed statistics (default: disabled)",
)
parser.add_argument(
"--stage-init-timeout",
type=int,
default=300,
help="Timeout for initializing a single stage in seconds (default: 300)",
)
parser.add_argument(
"--batch-timeout",
type=int,
default=5,
help="Timeout for batching in seconds (default: 5)",
)
parser.add_argument(
"--init-timeout",
type=int,
default=300,
help="Timeout for initializing stages in seconds (default: 300)",
)
parser.add_argument(
"--shm-threshold-bytes",
type=int,
default=65536,
help="Threshold for using shared memory in bytes (default: 65536)",
)
parser.add_argument(
"--output-dir",
default=None,
help="Output directory for generated wav files (default: output_audio).",
)
parser.add_argument(
"--output-wav",
default="output_audio",
help="[Deprecated] Output wav directory (use --output-dir).",
)
parser.add_argument(
"--num-prompts",
type=int,
default=1,
help="Number of prompts to generate.",
)
parser.add_argument(
"--txt-prompts",
type=str,
default=None,
help="Path to a .txt file with one prompt per line (preferred).",
)
parser.add_argument(
"--stage-configs-path",
type=str,
default=None,
help="Path to a stage configs file.",
)
parser.add_argument(
"--audio-path",
"-a",
type=str,
default=None,
help="Path to local audio file. If not provided, uses default audio asset.",
)
parser.add_argument(
"--sampling-rate",
type=int,
default=16000,
help="Sampling rate for audio loading (default: 16000).",
)
parser.add_argument(
"--log-dir",
type=str,
default="logs",
help="Log directory (default: logs).",
)
parser.add_argument(
"--py-generator",
action="store_true",
default=False,
help="Use py_generator mode. The returned type of Omni.generate() is a Python Generator object.",
)
parser.add_argument(
"--use-batch-sample",
action="store_true",
default=False,
help="Use batch input sample for CustomVoice/VoiceDesign/Base query.",
)
parser.add_argument(
"--mode-tag",
type=str,
default="icl",
choices=["icl", "xvec_only"],
help="Mode tag for Base query x_vector_only_mode (default: icl).",
)
return parser.parse_args()
query_map = {
"CustomVoice": get_custom_voice_query,
"VoiceDesign": get_voice_design_query,
"Base": get_base_query,
}
if __name__ == "__main__":
args = parse_args()
main(args)
# Text-To-Audio
The `stabilityai/stable-audio-open-1.0` pipeline generates audio from text prompts.
## Prerequisites
If you use a gated model (e.g., `stabilityai/stable-audio-open-1.0`), ensure you have access:
1. **Accept Model License**: Visit the model page on Hugging Face (e.g., [stabilityai/stable-audio-open-1.0](https://huggingface.co/stabilityai/stable-audio-open-1.0)) and accept the user agreement.
2. **Authenticate**: Log in to Hugging Face locally to access the gated model.
```bash
huggingface-cli login
```
## Local CLI Usage
```bash
python text_to_audio.py \
--model stabilityai/stable-audio-open-1.0 \
--prompt "The sound of a hammer hitting a wooden surface" \
--negative_prompt "Low quality" \
--seed 42 \
--guidance_scale 7.0 \
--audio_length 10.0 \
--num_inference_steps 100 \
--output stable_audio_output.wav
```
Key arguments:
- `--prompt`: text description (string).
- `--negative_prompt`: negative prompt for classifier-free guidance.
- `--seed`: integer seed for deterministic generation.
- `--guidance_scale`: classifier-free guidance scale.
- `--audio_length`: audio duration in seconds.
- `--num_inference_steps`: diffusion sampling steps (more steps = higher quality, slower).
- `--output`: path to save the generated WAV file.
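The same generation can also be driven from Python. Below is a minimal sketch mirroring what `text_to_audio.py` (included after this section) does, assuming `soundfile` is installed:
```python
import soundfile as sf
import torch

from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform

if __name__ == "__main__":
    omni = Omni(model="stabilityai/stable-audio-open-1.0")
    generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(42)
    outputs = omni.generate(
        {"prompt": "The sound of a hammer hitting a wooden surface", "negative_prompt": "Low quality"},
        OmniDiffusionSamplingParams(
            generator=generator,
            guidance_scale=7.0,
            num_inference_steps=100,
            # Audio window in seconds, passed through to the pipeline.
            extra_args={"audio_start_in_s": 0.0, "audio_end_in_s": 10.0},
        ),
    )
    audio = outputs[0].request_output[0].multimodal_output["audio"]
    if isinstance(audio, torch.Tensor):
        audio = audio.cpu().float().numpy()
    if audio.ndim == 3:  # [batch, channels, samples] -> first waveform
        audio = audio[0]
    # soundfile expects [samples, channels]; Stable Audio outputs 44.1 kHz audio.
    sf.write("stable_audio_output.wav", audio.T, 44100)
```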
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example script for text-to-audio generation using Stable Audio Open.
This script demonstrates how to generate audio from text prompts using
the Stable Audio Open model with vLLM-Omni.
Usage:
python text_to_audio.py --prompt "The sound of a dog barking"
python text_to_audio.py --prompt "A piano playing a gentle melody" --audio_length 10.0
python text_to_audio.py --prompt "Thunder and rain sounds" --negative_prompt "Low quality"
"""
import argparse
import time
from pathlib import Path
import numpy as np
import torch
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Generate audio with Stable Audio Open.")
parser.add_argument(
"--model",
default="stabilityai/stable-audio-open-1.0",
help="Stable Audio model name or local path.",
)
parser.add_argument(
"--prompt",
default="The sound of a hammer hitting a wooden surface.",
help="Text prompt for audio generation.",
)
parser.add_argument(
"--negative_prompt",
default="Low quality.",
help="Negative prompt for classifier-free guidance.",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="Random seed for deterministic results.",
)
parser.add_argument(
"--guidance_scale",
type=float,
default=7.0,
help="Classifier-free guidance scale.",
)
parser.add_argument(
"--audio_start",
type=float,
default=0.0,
help="Audio start time in seconds.",
)
parser.add_argument(
"--audio_length",
type=float,
default=10.0,
help="Audio length in seconds (max ~47s for stable-audio-open-1.0).",
)
parser.add_argument(
"--num_inference_steps",
type=int,
default=100,
help="Number of denoising steps for the diffusion sampler.",
)
parser.add_argument(
"--num_waveforms",
type=int,
default=1,
help="Number of audio waveforms to generate for the given prompt.",
)
parser.add_argument(
"--output",
type=str,
default="stable_audio_output.wav",
help="Path to save the generated audio (WAV format).",
)
parser.add_argument(
"--sample_rate",
type=int,
default=44100,
help="Sample rate for output audio (Stable Audio uses 44100 Hz).",
)
return parser.parse_args()
def save_audio(audio_data: np.ndarray, output_path: str, sample_rate: int = 44100):
"""Save audio data to a WAV file."""
try:
import soundfile as sf
sf.write(output_path, audio_data, sample_rate)
except ImportError:
try:
import scipy.io.wavfile as wav
# Ensure audio is in the correct format for scipy
if audio_data.dtype == np.float32 or audio_data.dtype == np.float64:
# Normalize to int16 range
audio_data = np.clip(audio_data, -1.0, 1.0)
audio_data = (audio_data * 32767).astype(np.int16)
wav.write(output_path, sample_rate, audio_data)
except ImportError:
raise ImportError(
"Either 'soundfile' or 'scipy' is required to save audio files. "
"Install with: pip install soundfile or pip install scipy"
)
def main():
args = parse_args()
generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(args.seed)
print(f"\n{'=' * 60}")
print("Stable Audio Open - Text-to-Audio Generation")
print(f"{'=' * 60}")
print(f" Model: {args.model}")
print(f" Prompt: {args.prompt}")
print(f" Negative prompt: {args.negative_prompt}")
print(f" Audio length: {args.audio_length}s")
print(f" Inference steps: {args.num_inference_steps}")
print(f" Guidance scale: {args.guidance_scale}")
print(f" Seed: {args.seed}")
print(f"{'=' * 60}\n")
# Initialize Omni with Stable Audio model
omni = Omni(model=args.model)
# Calculate audio end time
audio_end_in_s = args.audio_start + args.audio_length
# Time profiling for generation
generation_start = time.perf_counter()
# Generate audio
outputs = omni.generate(
{
"prompt": args.prompt,
"negative_prompt": args.negative_prompt,
},
OmniDiffusionSamplingParams(
generator=generator,
guidance_scale=args.guidance_scale,
num_inference_steps=args.num_inference_steps,
num_outputs_per_prompt=args.num_waveforms,
extra_args={
"audio_start_in_s": args.audio_start,
"audio_end_in_s": audio_end_in_s,
},
),
)
generation_end = time.perf_counter()
generation_time = generation_end - generation_start
print(f"Total generation time: {generation_time:.2f} seconds")
# Process and save audio
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
suffix = output_path.suffix or ".wav"
stem = output_path.stem or "stable_audio_output"
# Extract audio from omni.generate() outputs
if not outputs:
raise ValueError("No output generated from omni.generate()")
output = outputs[0]
if not hasattr(output, "request_output") or not output.request_output:
raise ValueError("No request_output found in OmniRequestOutput")
request_output = output.request_output[0]
if not hasattr(request_output, "multimodal_output"):
raise ValueError("No multimodal_output found in request_output")
audio = request_output.multimodal_output.get("audio")
if audio is None:
raise ValueError("No audio output found in request_output")
# Handle different output formats
if isinstance(audio, torch.Tensor):
audio = audio.cpu().float().numpy()
# Audio shape is typically [batch, channels, samples] or [channels, samples]
if audio.ndim == 3:
# [batch, channels, samples]
if args.num_waveforms <= 1:
audio_data = audio[0].T # [samples, channels]
save_audio(audio_data, str(output_path), args.sample_rate)
print(f"Saved generated audio to {output_path}")
else:
for idx in range(audio.shape[0]):
audio_data = audio[idx].T # [samples, channels]
save_path = output_path.parent / f"{stem}_{idx}{suffix}"
save_audio(audio_data, str(save_path), args.sample_rate)
print(f"Saved generated audio to {save_path}")
elif audio.ndim == 2:
# [channels, samples]
audio_data = audio.T # [samples, channels]
save_audio(audio_data, str(output_path), args.sample_rate)
print(f"Saved generated audio to {output_path}")
else:
# [samples] - mono audio
save_audio(audio, str(output_path), args.sample_rate)
print(f"Saved generated audio to {output_path}")
print(f"\nGenerated {args.audio_length}s of audio at {args.sample_rate} Hz")
if __name__ == "__main__":
main()
# Text-To-Image
This folder provides several entrypoints for experimenting with `Qwen/Qwen-Image`, `Qwen/Qwen-Image-2512`, and `Tongyi-MAI/Z-Image-Turbo` using vLLM-Omni:
- `text_to_image.py`: command-line script for single image generation with advanced options.
- `web_demo.py`: lightweight Gradio UI for interactive prompt/seed/CFG exploration.
Note that when you pass in multiple independent prompts, they will be processed sequentially. Batching requests is currently not supported.
## Basic Usage
```python
from vllm_omni.entrypoints.omni import Omni
if __name__ == "__main__":
omni = Omni(model="Qwen/Qwen-Image")
prompt = "a cup of coffee on the table"
outputs = omni.generate(prompt)
images = outputs[0].request_output[0].images
images[0].save("coffee.png")
```
Or put more than one prompt in a request.
```python
from vllm_omni.entrypoints.omni import Omni
if __name__ == "__main__":
omni = Omni(model="Qwen/Qwen-Image")
prompts = [
"a cup of coffee on a table",
"a toy dinosaur on a sandy beach",
"a fox waking up in bed and yawning",
]
outputs = omni.generate(prompts)
for i, output in enumerate(outputs):
output.request_output[0].images[0].save(f"{i}.jpg")
```
!!! info
However, passing multiple prompts in one request is not currently recommended:
not all models support batch inference,
and batching requests usually does not yield a significant performance improvement, despite the impression that it does.
This feature exists primarily for interface compatibility with vLLM and to leave room for future improvements.
!!! info
For diffusion pipelines, the stage config field `stage_args.[].runtime.max_batch_size` is 1 by default, and the input
list is sliced into single-item requests before being fed into the diffusion pipeline. For models that internally support
batched inputs, you can [modify this configuration](../../../configuration/stage_configs.md) to let the model accept a longer batch of prompts.
Apart from string prompts, vLLM-Omni also supports dictionary prompts in the same style as vLLM.
This is useful for models that support negative prompts.
```python
from vllm_omni.entrypoints.omni import Omni
if __name__ == "__main__":
omni = Omni(model="Qwen/Qwen-Image")
outputs = omni.generate([
{
"prompt": "a cup of coffee on a table"
"negative_prompt": "low resolution"
},
{
"prompt": "a toy dinosaur on a sandy beach"
"negative_prompt": "cinematic, realistic"
}
])
for i, output in enumerate(outputs):
output.request_output[0].images[0].save(f"{i}.jpg")
```
## Local CLI Usage
```bash
python text_to_image.py \
--model Tongyi-MAI/Z-Image-Turbo \
--prompt "a cup of coffee on the table" \
--seed 42 \
--cfg_scale 4.0 \
--num_images_per_prompt 1 \
--num_inference_steps 50 \
--height 1024 \
--width 1024 \
--output outputs/coffee.png
```
Key arguments:
- `--prompt`: text description (string).
- `--seed`: integer seed for deterministic sampling.
- `--cfg_scale`: true CFG scale (model-specific guidance strength).
- `--num_images_per_prompt`: number of images to generate per prompt (saves as `output`, `output_1`, ...).
- `--num_inference_steps`: diffusion sampling steps (more steps = higher quality, slower).
- `--height/--width`: output resolution (defaults to 1024x1024).
- `--output`: path to save the generated PNG.
- `--vae_use_slicing`: enable VAE slicing for memory optimization.
- `--vae_use_tiling`: enable VAE tiling for memory optimization.
- `--cfg_parallel_size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](../../../docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel).
- `--enable-cpu-offload`: enable CPU offloading for diffusion models.
> ℹ️ If you encounter OOM errors, try using `--vae_use_slicing` and `--vae_use_tiling` to reduce memory usage.
> ℹ️ Qwen-Image currently publishes best-effort presets at `1328x1328`, `1664x928`, `928x1664`, `1472x1140`, `1140x1472`, `1584x1056`, and `1056x1584`. Adjust `--height/--width` accordingly for the most reliable outcomes.
## Web UI Demo
Launch the gradio demo:
```bash
python gradio_demo.py --port 7862
```
Then open `http://localhost:7862/` in your local browser to interact with the web UI.
import argparse
from functools import lru_cache
import gradio as gr
import torch
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
ASPECT_RATIOS: dict[str, tuple[int, int]] = {
"1:1": (1328, 1328),
"16:9": (1664, 928),
"9:16": (928, 1664),
"4:3": (1472, 1140),
"3:4": (1140, 1472),
"3:2": (1584, 1056),
"2:3": (1056, 1584),
}
ASPECT_RATIO_CHOICES = [f"{ratio} ({w}x{h})" for ratio, (w, h) in ASPECT_RATIOS.items()]
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Gradio demo for Qwen-Image offline inference.")
parser.add_argument("--model", default="Qwen/Qwen-Image", help="Diffusion model name or local path.")
parser.add_argument(
"--height",
type=int,
default=1328,
help="Default image height (must match one of the supported presets).",
)
parser.add_argument(
"--width",
type=int,
default=1328,
help="Default image width (must match one of the supported presets).",
)
parser.add_argument("--default-prompt", default="a cup of coffee on the table", help="Initial prompt shown in UI.")
parser.add_argument("--default-seed", type=int, default=42, help="Initial seed shown in UI.")
parser.add_argument("--default-cfg-scale", type=float, default=4.0, help="Initial CFG scale shown in UI.")
parser.add_argument(
"--num_inference_steps",
type=int,
default=50,
help="Default number of denoising steps shown in the UI.",
)
parser.add_argument("--ip", default="127.0.0.1", help="Host/IP for Gradio `launch`.")
parser.add_argument("--port", type=int, default=7862, help="Port for Gradio `launch`.")
parser.add_argument("--share", action="store_true", help="Share the Gradio demo publicly.")
args = parser.parse_args()
args.aspect_ratio_label = next(
(ratio for ratio, dims in ASPECT_RATIOS.items() if dims == (args.width, args.height)),
None,
)
if args.aspect_ratio_label is None:
supported = ", ".join(f"{ratio} ({w}x{h})" for ratio, (w, h) in ASPECT_RATIOS.items())
parser.error(f"Unsupported resolution {args.width}x{args.height}. Please pick one of: {supported}.")
return args
@lru_cache(maxsize=1)
def get_omni(model_name: str) -> Omni:
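# Cache a single Omni engine per model so repeated UI requests reuse the already-loaded pipeline.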
# Enable VAE memory optimizations on NPU
vae_use_slicing = current_omni_platform.is_npu()
vae_use_tiling = current_omni_platform.is_npu()
return Omni(
model=model_name,
vae_use_slicing=vae_use_slicing,
vae_use_tiling=vae_use_tiling,
)
def build_demo(args: argparse.Namespace) -> gr.Blocks:
omni = get_omni(args.model)
def run_inference(
prompt: str,
seed_value: float,
cfg_scale_value: float,
resolution_choice: str,
num_steps_value: float,
num_images_choice: float,
):
if not prompt or not prompt.strip():
raise gr.Error("Please enter a non-empty prompt.")
ratio_label = resolution_choice.split(" ", 1)[0]
if ratio_label not in ASPECT_RATIOS:
raise gr.Error(f"Unsupported aspect ratio: {ratio_label}")
width, height = ASPECT_RATIOS[ratio_label]
try:
seed = int(seed_value)
num_steps = int(num_steps_value)
num_images = int(num_images_choice)
except (TypeError, ValueError) as exc:
raise gr.Error("Seed, inference steps, and number of images must be valid integers.") from exc
if num_steps <= 0:
raise gr.Error("Inference steps must be a positive integer.")
if num_images not in {1, 2, 3, 4}:
raise gr.Error("Number of images must be 1, 2, 3, or 4.")
generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(seed)
outputs = omni.generate(
prompt.strip(),
OmniDiffusionSamplingParams(
height=height,
width=width,
generator=generator,
true_cfg_scale=float(cfg_scale_value),
num_inference_steps=num_steps,
num_outputs_per_prompt=num_images,
),
)
images_outputs = []
for output in outputs:
req_out = output.request_output[0]
if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"):
raise ValueError("Invalid request_output structure or missing 'images' key")
images = req_out.images
if not images:
raise ValueError("No images found in request_output")
# Extend the list with individual images (not append the entire list)
images_outputs.extend(images)
if len(images_outputs) >= num_images:
break
# Return only the requested number of images
return images_outputs[:num_images]
with gr.Blocks(
title="vLLM-Omni Web Serving Demo",
css="""
/* Left column button width */
.left-column button {
width: 100%;
}
/* Right preview area: fixed height, hide unnecessary buttons */
.fixed-image {
height: 660px;
display: flex;
flex-direction: column;
justify-content: center;
align-items: center;
}
.fixed-image .duplicate-button,
.fixed-image .svelte-drgfj2 {
display: none !important;
}
/* Gallery container: fill available space and center content */
#image-gallery {
width: 100%;
height: 100%;
display: flex;
align-items: center;
justify-content: center;
}
/* Gallery grid: center horizontally and vertically, set gap */
#image-gallery .grid {
display: flex;
flex-wrap: wrap;
justify-content: center;
align-items: center;
align-content: center;
gap: 16px;
width: 100%;
height: 100%;
}
/* Gallery grid items: center content */
#image-gallery .grid > div {
display: flex;
align-items: center;
justify-content: center;
}
/* Gallery images: limit max height, maintain aspect ratio */
.fixed-image img {
max-height: 660px !important;
width: auto !important;
object-fit: contain;
}
""",
) as demo:
gr.Markdown("# vLLM-Omni Web Serving Demo")
gr.Markdown(f"**Model:** {args.model}")
with gr.Row():
with gr.Column(scale=1, elem_classes="left-column"):
prompt_input = gr.Textbox(
label="Prompt",
value=args.default_prompt,
placeholder="Describe the image you want to generate...",
lines=5,
)
seed_input = gr.Number(label="Seed", value=args.default_seed, precision=0)
cfg_input = gr.Number(label="CFG Scale", value=args.default_cfg_scale)
steps_input = gr.Number(
label="Inference Steps",
value=args.num_inference_steps,
precision=0,
minimum=1,
)
aspect_dropdown = gr.Dropdown(
label="Aspect Ratio (W:H)",
choices=ASPECT_RATIO_CHOICES,
value=f"{args.aspect_ratio_label} ({ASPECT_RATIOS[args.aspect_ratio_label][0]}x{ASPECT_RATIOS[args.aspect_ratio_label][1]})",
)
num_images = gr.Dropdown(
label="Number of images",
choices=["1", "2", "3", "4"],
value="1",
)
generate_btn = gr.Button("Generate", variant="primary")
with gr.Column(scale=2, elem_classes="fixed-image"):
gallery = gr.Gallery(
label="Preview",
columns=2,
rows=2,
height=660,
allow_preview=True,
show_label=True,
elem_id="image-gallery",
)
generate_btn.click(
fn=run_inference,
inputs=[prompt_input, seed_input, cfg_input, aspect_dropdown, steps_input, num_images],
outputs=gallery,
)
return demo
def main():
args = parse_args()
demo = build_demo(args)
demo.launch(server_name=args.ip, server_port=args.port, share=args.share)
if __name__ == "__main__":
main()