Unverified commit cbbde3d0, authored by Zhongdongming Dai, committed by GitHub
Browse files

feat: add initial audio/TTS pipeline support for vLLM-omni backend (#7495)


Signed-off-by: Zhongdongming Dai <zhongdongmin@nvidia.com>
parent 76c70f41
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Protocol types for audio generation (TTS).
These types follow the vLLM-Omni OpenAICreateSpeechRequest format,
with TTS-specific parameters as top-level fields (not nested in nvext).
Note: These Pydantic models mirror the Rust protocol types in
lib/llm/src/protocols/openai/audios.rs. Ideally these should be
code-generated from the Rust definitions; for now they are maintained
manually and must be kept in sync.
"""
from typing import Literal, Optional
from pydantic import BaseModel, Field
class NvCreateAudioSpeechRequest(BaseModel):
    """Request body for audio speech generation (/v1/audio/speech endpoint).

    Follows vLLM-Omni's OpenAICreateSpeechRequest format: the TTS-specific
    parameters are plain top-level fields. Only ``input`` is required; every
    other field is optional and has a default.
    """

    # Standard OpenAI params
    input: str
    """The text to synthesize into speech."""
    model: Optional[str] = None
    """The TTS model to use."""
    voice: Optional[str] = None
    """Voice/speaker name (e.g., 'vivian', 'ryan', 'aiden')."""
    # NOTE(review): AudioGenerationHandler.format_output also branches on the
    # values "url" and "b64_json", which this Literal rejects at validation
    # time -- confirm whether those values are meant to be accepted here.
    response_format: Optional[
        Literal["wav", "pcm", "flac", "mp3", "aac", "opus"]
    ] = "wav"
    """Output format."""
    # Validated by pydantic to lie within [0.25, 4.0].
    speed: Optional[float] = Field(default=1.0, ge=0.25, le=4.0)
    """Speed factor."""
    # Qwen3-TTS specific params (top-level, matching vLLM-Omni)
    task_type: Optional[Literal["CustomVoice", "VoiceDesign", "Base"]] = None
    """TTS task type."""
    language: Optional[str] = None
    """Language: Auto, Chinese, English, Japanese, Korean, etc."""
    instructions: Optional[str] = None
    """Voice style/emotion instructions (for VoiceDesign)."""
    ref_audio: Optional[str] = None
    """Reference audio URL or base64 (for voice cloning with Base task)."""
    ref_text: Optional[str] = None
    """Reference transcript (for voice cloning with Base task)."""
    # No bounds are declared on the model itself; the audio handler checks the
    # value against its configured min/max when the request is validated.
    max_new_tokens: Optional[int] = None
    """Maximum tokens to generate (default: 2048)."""
class AudioData(BaseModel):
    """One generated audio item in a speech response.

    Both fields default to ``None``; the model itself does not enforce that
    exactly one of them is set.
    """

    url: Optional[str] = None
    """URL of the generated audio (if response_format is 'url')."""
    b64_json: Optional[str] = None
    """Base64-encoded audio data."""
class NvAudioSpeechResponse(BaseModel):
    """Response structure for audio speech generation.

    ``id``, ``model`` and ``created`` are required; the remaining fields
    default to the values of a successful, finished generation.
    """

    id: str
    """Unique identifier for the response."""
    object: str = "audio.speech"
    """Object type."""
    model: str
    """Model used for generation."""
    # Free-form string; the handlers in this change emit "completed" or "failed".
    status: str = "completed"
    """Generation status."""
    # The 0-100 range is documented only; no validator enforces it.
    progress: int = 100
    """Progress percentage (0-100)."""
    created: int
    """Unix timestamp of creation."""
    data: list[AudioData] = []
    """List of generated audio data."""
    error: Optional[str] = None
    """Error message if generation failed."""
    # NOTE(review): AudioGenerationHandler.format_output fills this with the
    # time spent extracting and encoding the waveform, not the engine's
    # generation time -- confirm which duration clients expect.
    inference_time_s: Optional[float] = None
    """Inference time in seconds."""
...@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union ...@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel from pydantic import BaseModel
from dynamo.common.protocols.audio_protocol import NvCreateAudioSpeechRequest
from dynamo.common.protocols.image_protocol import NvCreateImageRequest from dynamo.common.protocols.image_protocol import NvCreateImageRequest
from dynamo.common.protocols.video_protocol import NvCreateVideoRequest from dynamo.common.protocols.video_protocol import NvCreateVideoRequest
from dynamo.llm import ModelType from dynamo.llm import ModelType
...@@ -91,8 +92,7 @@ def parse_request_type( ...@@ -91,8 +92,7 @@ def parse_request_type(
return NvCreateVideoRequest(**raw_request), RequestType.VIDEO_GENERATION return NvCreateVideoRequest(**raw_request), RequestType.VIDEO_GENERATION
if modality is OutputModality.AUDIO: if modality is OutputModality.AUDIO:
# Audio protocol types are not yet defined; pass through the raw dict. return NvCreateAudioSpeechRequest(**raw_request), RequestType.AUDIO_GENERATION
return raw_request, RequestType.AUDIO_GENERATION
# Text Modality # Text Modality
return raw_request, RequestType.CHAT_COMPLETION return raw_request, RequestType.CHAT_COMPLETION
...@@ -158,6 +158,52 @@ class OmniArgGroup(ArgGroup): ...@@ -158,6 +158,52 @@ class OmniArgGroup(ArgGroup):
help="Disable torch.compile and force eager execution for diffusion models.", help="Disable torch.compile and force eager execution for diffusion models.",
) )
# TTS parameters
tts_g = parser.add_argument_group(
"Omni TTS Options",
"TTS/audio-specific parameters for vLLM-Omni speech generation.",
)
add_argument(
tts_g,
flag_name="--tts-max-instructions-length",
env_var="DYN_OMNI_TTS_MAX_INSTRUCTIONS_LENGTH",
default=500,
arg_type=int,
help="Maximum character length for TTS voice instructions.",
)
add_argument(
tts_g,
flag_name="--tts-max-new-tokens-min",
env_var="DYN_OMNI_TTS_MAX_NEW_TOKENS_MIN",
default=1,
arg_type=int,
help="Minimum allowed value for max_new_tokens in TTS requests.",
)
add_argument(
tts_g,
flag_name="--tts-max-new-tokens-max",
env_var="DYN_OMNI_TTS_MAX_NEW_TOKENS_MAX",
default=4096,
arg_type=int,
help="Maximum allowed value for max_new_tokens in TTS requests.",
)
add_argument(
tts_g,
flag_name="--tts-ref-audio-timeout",
env_var="DYN_OMNI_TTS_REF_AUDIO_TIMEOUT",
default=15,
arg_type=int,
help="Timeout in seconds for downloading reference audio URLs.",
)
add_argument(
tts_g,
flag_name="--tts-ref-audio-max-bytes",
env_var="DYN_OMNI_TTS_REF_AUDIO_MAX_BYTES",
default=50 * 1024 * 1024,
arg_type=int,
help="Maximum size in bytes for reference audio files (default: 50MB).",
)
# Diffusion parallel configuration # Diffusion parallel configuration
add_argument( add_argument(
g, g,
...@@ -217,6 +263,13 @@ class OmniConfig(DynamoRuntimeConfig): ...@@ -217,6 +263,13 @@ class OmniConfig(DynamoRuntimeConfig):
ring_degree: int = 1 ring_degree: int = 1
cfg_parallel_size: int = 1 cfg_parallel_size: int = 1
# TTS parameters
tts_max_instructions_length: int = 500
tts_max_new_tokens_min: int = 1
tts_max_new_tokens_max: int = 4096
tts_ref_audio_timeout: int = 15
tts_ref_audio_max_bytes: int = 50 * 1024 * 1024
def validate(self) -> None: def validate(self) -> None:
DynamoRuntimeConfig.validate(self) DynamoRuntimeConfig.validate(self)
if self.default_video_fps <= 0: if self.default_video_fps <= 0:
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Audio/TTS handler utilities for the vLLM-Omni backend.
Extracted from omni_handler.py to keep modality-specific logic separate.
OmniHandler holds an instance as ``self.audio`` (composition).
"""
import asyncio
import base64
import logging
import time
import uuid
from io import BytesIO
from typing import Any, Dict
from vllm_omni.inputs.data import OmniTextPrompt
from dynamo.common.protocols.audio_protocol import (
AudioData,
NvAudioSpeechResponse,
NvCreateAudioSpeechRequest,
)
from dynamo.common.storage import upload_to_fs
from dynamo.common.utils.output_modalities import RequestType
logger = logging.getLogger(__name__)
# model_stage names that receive Qwen3-TTS-specific prompt format
# (prompt_token_ids + additional_information). Other audio models
# (MiMo-Audio, Qwen3-Omni, Stable Audio, etc.) use a plain text prompt.
# Mirrors vLLM-Omni's _TTS_MODEL_STAGES in serving_speech.py.
# The same set is also compared against hf_config.model_type as a fallback.
_TTS_MODEL_STAGES: set[str] = {"qwen3_tts"}
# Fallback language set used when model config is unavailable.
# Stored in display case; request validation lower-cases these before
# comparing them with the requested language.
_TTS_LANGUAGES_FALLBACK: set[str] = {
    "Auto",
    "Chinese",
    "English",
    "Japanese",
    "Korean",
    "German",
    "French",
    "Russian",
    "Portuguese",
    "Spanish",
    "Italian",
}
class AudioGenerationHandler:
"""Handles audio/TTS request processing for the vLLM-Omni backend.
Instantiated by OmniHandler during initialization and held as a
composition attribute (``self.audio``). This keeps
audio-specific logic (validation, prompt building, encoding) out
of the orchestrator.
"""
def __init__(self, config, engine_client, media_output_fs, media_output_http_url):
self.config = config
self.engine_client = engine_client
self.media_output_fs = media_output_fs
self.media_output_http_url = media_output_http_url
self._tts_tokenizer: Any = None
# Cache TTS capabilities from model config at init.
self._tts_supported_speakers: set = self._load_supported_speakers()
self._tts_supported_languages: set = self._load_supported_languages()
if self._tts_supported_speakers:
logger.info(
"Loaded %d TTS speakers: %s",
len(self._tts_supported_speakers),
sorted(self._tts_supported_speakers),
)
if self._tts_supported_languages:
logger.info(
"Loaded %d TTS languages: %s",
len(self._tts_supported_languages),
sorted(self._tts_supported_languages),
)
# -- TTS capability loading from model config -----------------------------
def _load_supported_speakers(self) -> set:
"""Load supported speakers from model config (case-insensitive).
Reads ``hf_config.talker_config.spk_id`` or ``speaker_id``,
matching vLLM-Omni's ``_load_supported_speakers()``.
"""
try:
hf_config = self.engine_client.model_config.hf_config
talker_config = getattr(hf_config, "talker_config", None)
if talker_config is None:
return set()
for attr_name in ("spk_id", "speaker_id"):
speakers_dict = getattr(talker_config, attr_name, None)
if speakers_dict and isinstance(speakers_dict, dict):
return {s.lower() for s in speakers_dict.keys()}
except Exception as e:
logger.warning("Could not load speakers from model config: %s", e)
return set()
def _load_supported_languages(self) -> set:
"""Load supported languages from model config.
Reads ``hf_config.talker_config.codec_language_id``.
"""
try:
hf_config = self.engine_client.model_config.hf_config
talker_config = getattr(hf_config, "talker_config", None)
if talker_config is None:
return set()
lang_dict = getattr(talker_config, "codec_language_id", None)
if lang_dict and isinstance(lang_dict, dict):
return {lang.lower() for lang in lang_dict.keys()}
except Exception as e:
logger.warning("Could not load languages from model config: %s", e)
return set()
# -- TTS model detection --------------------------------------------------
def _is_tts_model(self) -> bool:
"""Check if the loaded model is a Qwen3-TTS-style model.
Searches for a TTS model_stage in the engine's stage list,
stage configs, or model config. Supports multiple vLLM-Omni versions.
"""
# Try stage_list
stage_list = getattr(self.engine_client, "stage_list", None)
if stage_list:
for stage in stage_list:
ms = getattr(stage, "model_stage", None)
logger.debug("_is_tts_model: stage=%s model_stage=%s", stage, ms)
if ms in _TTS_MODEL_STAGES:
return True
# Try stage_configs
stage_configs = getattr(self.engine_client, "stage_configs", None)
if stage_configs:
for cfg in stage_configs:
engine_args = (
cfg.get("engine_args", {})
if isinstance(cfg, dict)
else getattr(cfg, "engine_args", {})
)
ms = (
engine_args.get("model_stage")
if isinstance(engine_args, dict)
else getattr(engine_args, "model_stage", None)
)
logger.debug("_is_tts_model: stage_config model_stage=%s", ms)
if ms in _TTS_MODEL_STAGES:
return True
# Try model_config.hf_config.model_type (universal fallback)
try:
model_type = self.engine_client.model_config.hf_config.model_type
logger.debug("_is_tts_model: hf_config.model_type=%s", model_type)
if model_type in _TTS_MODEL_STAGES:
return True
except (AttributeError, TypeError) as e:
logger.debug("_is_tts_model: hf_config fallback failed: %s", e)
logger.warning(
"_is_tts_model: could not detect TTS model. "
"stage_list=%s, stage_configs=%s",
stage_list is not None,
stage_configs is not None,
)
return False
# -- Audio engine input construction --------------------------------------
async def build_engine_inputs(self, req: NvCreateAudioSpeechRequest):
    """Turn an audio/TTS request into engine inputs.

    Mirrors the two paths of vLLM-Omni's serving_speech.py:

    * Qwen3-TTS models get ``prompt_token_ids`` plus
      ``additional_information`` (delegated to ``_engine_inputs_tts``).
    * Any other audio model gets a plain text prompt built from
      ``req.input``.

    Raises:
        ValueError: If ``req.input`` is empty or whitespace-only, or if the
            TTS path rejects the request.
    """
    # Imported lazily: omni_handler imports this module at load time.
    from dynamo.vllm.omni.omni_handler import EngineInputs

    text = req.input
    if not (text and text.strip()):
        raise ValueError("Input text cannot be empty")

    if self._is_tts_model():
        tts_inputs = await self._engine_inputs_tts(req)
        return tts_inputs

    # Generic audio model – plain text prompt (same as image/video)
    generic_prompt = OmniTextPrompt(prompt=req.input)
    preview = req.input[:50]
    logger.info(f"Audio request (generic): input='{preview}...'")
    engine_kwargs = dict(
        prompt=generic_prompt,
        sampling_params_list=None,
        request_type=RequestType.AUDIO_GENERATION,
        response_format=req.response_format,
        speed=req.speed or 1.0,
    )
    return EngineInputs(**engine_kwargs)
# -- Qwen3-TTS-specific helpers -------------------------------------------
async def _engine_inputs_tts(self, req: NvCreateAudioSpeechRequest):
    """Build engine inputs for Qwen3-TTS models.

    Validates the request, lower-cases ``req.voice`` in place, assembles
    the ``additional_information`` dict (every value wrapped in a
    one-element list) and pairs it with placeholder ``prompt_token_ids``
    whose length comes from ``_estimate_tts_prompt_len``.

    Raises:
        ValueError: If validation or reference-audio resolution fails.
    """
    from dynamo.vllm.omni.omni_handler import EngineInputs

    self._validate_tts_request(req)

    # Only touch req.voice when the caller actually supplied one.
    if req.voice is not None:
        req.voice = req.voice.lower()

    task_type = req.task_type or "CustomVoice"
    info: Dict[str, Any] = {}
    info["text"] = [req.input]
    info["task_type"] = [task_type]
    info["language"] = [req.language or "Auto"]
    info["instruct"] = [req.instructions or ""]
    info["max_new_tokens"] = [req.max_new_tokens or 2048]

    # An explicit voice wins; CustomVoice otherwise gets a default speaker.
    if req.voice is not None:
        info["speaker"] = [req.voice]
    elif task_type == "CustomVoice":
        info["speaker"] = ["Vivian"]

    if req.ref_audio is not None:
        waveform, sample_rate = await self._resolve_ref_audio(req.ref_audio)
        info["ref_audio"] = [[waveform, sample_rate]]

    if req.ref_text is not None:
        info["ref_text"] = [req.ref_text]

    if task_type == "VoiceDesign":
        info["non_streaming_mode"] = [True]

    prompt_len = self._estimate_tts_prompt_len(info)
    tts_prompt = {
        "prompt_token_ids": [1] * prompt_len,
        "additional_information": info,
    }

    preview = req.input[:50]
    speaker_label = info.get("speaker", ["N/A"])[0]
    logger.info(
        f"Audio TTS request: input='{preview}...', "
        f"voice={speaker_label}, "
        f"task_type={task_type}, prompt_len={prompt_len}"
    )

    return EngineInputs(
        prompt=tts_prompt,
        sampling_params_list=None,
        request_type=RequestType.AUDIO_GENERATION,
        response_format=req.response_format,
        speed=req.speed or 1.0,
    )
def _validate_tts_request(self, req: NvCreateAudioSpeechRequest) -> None:
"""Validate Qwen3-TTS-specific request parameters."""
task_type = req.task_type or "CustomVoice"
_ALLOWED_TASK_TYPES = {"CustomVoice", "VoiceDesign", "Base"}
if task_type not in _ALLOWED_TASK_TYPES:
raise ValueError(
f"Invalid task_type '{task_type}'. "
f"Supported: {', '.join(sorted(_ALLOWED_TASK_TYPES))}"
)
if req.language is not None:
supported_langs = self._tts_supported_languages or {
lang.lower() for lang in _TTS_LANGUAGES_FALLBACK
}
if req.language.lower() not in supported_langs and req.language != "Auto":
raise ValueError(
f"Invalid language '{req.language}'. "
f"Supported: Auto, {', '.join(sorted(supported_langs))}"
)
if task_type == "CustomVoice" and req.voice is not None:
if self._tts_supported_speakers:
if req.voice.lower() not in self._tts_supported_speakers:
raise ValueError(
f"Invalid voice '{req.voice}'. "
f"Supported: {', '.join(self._tts_supported_speakers)}"
)
if task_type == "Base" and req.ref_audio is None:
raise ValueError("Base task requires 'ref_audio' for voice cloning")
if task_type != "Base":
if req.ref_text is not None:
raise ValueError("'ref_text' is only valid for Base task")
if task_type == "VoiceDesign" and not req.instructions:
raise ValueError(
"VoiceDesign task requires 'instructions' to describe the voice"
)
if (
req.instructions
and len(req.instructions) > self.config.tts_max_instructions_length
):
raise ValueError(
f"Instructions too long "
f"(max {self.config.tts_max_instructions_length} characters)"
)
if req.max_new_tokens is not None:
if req.max_new_tokens < self.config.tts_max_new_tokens_min:
raise ValueError(
f"max_new_tokens must be at least "
f"{self.config.tts_max_new_tokens_min}"
)
if req.max_new_tokens > self.config.tts_max_new_tokens_max:
raise ValueError(
f"max_new_tokens cannot exceed "
f"{self.config.tts_max_new_tokens_max}"
)
async def _resolve_ref_audio(self, ref_audio_str: str) -> tuple:
"""Download or decode reference audio for voice cloning (Base task)."""
import io
import soundfile as sf
if ref_audio_str.startswith(("http://", "https://")):
import ipaddress
import socket
from urllib.parse import urlparse
import aiohttp
parsed = urlparse(ref_audio_str)
if not parsed.hostname:
raise ValueError("Invalid ref_audio URL")
for info in socket.getaddrinfo(
parsed.hostname, parsed.port or 443, type=socket.SOCK_STREAM
):
ip_str = str(info[4][0]).split("%", 1)[0]
addr = ipaddress.ip_address(ip_str)
if addr.is_private or addr.is_loopback:
raise ValueError(
f"ref_audio URL resolves to blocked address: {addr}"
)
async with aiohttp.ClientSession() as session:
async with session.get(
ref_audio_str,
timeout=aiohttp.ClientTimeout(
total=self.config.tts_ref_audio_timeout
),
) as resp:
if resp.status != 200:
raise ValueError(
f"Failed to download ref_audio: HTTP {resp.status}"
)
audio_bytes = await resp.read()
if len(audio_bytes) > self.config.tts_ref_audio_max_bytes:
raise ValueError(
f"ref_audio too large "
f"({len(audio_bytes)} bytes, "
f"max {self.config.tts_ref_audio_max_bytes})"
)
elif ref_audio_str.startswith("data:"):
_, encoded = ref_audio_str.split(",", 1)
audio_bytes = base64.b64decode(encoded)
if len(audio_bytes) > self.config.tts_ref_audio_max_bytes:
raise ValueError(
f"ref_audio data URI too large "
f"({len(audio_bytes)} bytes, "
f"max {self.config.tts_ref_audio_max_bytes})"
)
else:
raise ValueError(
"ref_audio must be a URL (http/https) or base64 data URI (data:...)"
)
wav_data, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32")
return wav_data, int(sr)
def _estimate_tts_prompt_len(self, tts_params: Dict[str, Any]) -> int:
    """Estimate the Qwen3-TTS prompt length for ``tts_params``.

    Delegates to the talker model's
    ``estimate_prompt_len_from_additional_information``, creating and
    caching a tokenizer for ``config.model`` on first use.

    Args:
        tts_params: The ``additional_information`` dict built for the
            request.

    Returns:
        The estimated length, or 2048 when anything in the estimation
        fails (the failure is logged as a warning).
    """
    fallback_len = 2048
    try:
        from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_talker import (
            Qwen3TTSTalkerForConditionalGeneration as talker_cls,
        )

        # Build the tokenizer once and reuse it for later requests.
        if getattr(self, "_tts_tokenizer", None) is None:
            from transformers import AutoTokenizer

            self._tts_tokenizer = AutoTokenizer.from_pretrained(
                self.config.model,
                trust_remote_code=True,
                padding_side="left",
            )

        def _tokenize(text):
            return self._tts_tokenizer(text, padding=False)["input_ids"]

        model_cfg = self.engine_client.model_config.hf_config
        talker_config = getattr(model_cfg, "talker_config", None)
        selected_task = (tts_params.get("task_type") or ["CustomVoice"])[0]

        if talker_config:
            language_ids = getattr(talker_config, "codec_language_id", None)
            dialect_flags = getattr(talker_config, "spk_is_dialect", None)
        else:
            language_ids = None
            dialect_flags = None

        return talker_cls.estimate_prompt_len_from_additional_information(
            additional_information=tts_params,
            task_type=selected_task,
            tokenize_prompt=_tokenize,
            codec_language_id=language_ids,
            spk_is_dialect=dialect_flags,
        )
    except Exception as exc:
        logger.warning(
            "Failed to estimate TTS prompt length, using fallback 2048: %s", exc
        )
        return fallback_len
# -- Audio output formatting ----------------------------------------------
def _extract_audio_tensor(self, mm_output: Dict[str, Any]) -> tuple:
"""Extract audio tensor and sample rate from multimodal_output dict."""
import numpy as np
import torch
audio_key = "audio" if "audio" in mm_output else "model_outputs"
audio_val = mm_output.get(audio_key)
if audio_val is None:
raise ValueError(
f"No audio data in multimodal_output. Keys: {list(mm_output.keys())}"
)
if isinstance(audio_val, list):
audio_val = torch.cat(audio_val, dim=-1)
if hasattr(audio_val, "float"):
audio_np = audio_val.float().detach().cpu().numpy()
elif isinstance(audio_val, np.ndarray):
audio_np = audio_val.astype(np.float32)
else:
audio_np = np.array(audio_val, dtype=np.float32)
if audio_np.ndim > 1:
audio_np = audio_np.squeeze()
sr_raw = mm_output.get("sr", 24000)
if isinstance(sr_raw, list):
sr_raw = sr_raw[-1] if sr_raw else 24000
sample_rate = sr_raw.item() if hasattr(sr_raw, "item") else int(sr_raw)
return audio_np, sample_rate
def _encode_audio(
    self, audio_np, sample_rate: int, fmt: str = "wav", speed: float = 1.0
) -> tuple:
    """Encode a waveform to audio bytes with soundfile.

    When ``speed`` differs from 1.0 the waveform is first time-stretched
    with librosa; if librosa cannot be imported the adjustment is skipped
    with a warning. An unknown ``fmt`` falls back to WAV with a warning.

    Args:
        audio_np: Waveform samples handed to ``soundfile.write``.
        sample_rate: Sample rate written to the output.
        fmt: One of wav, pcm, flac, mp3, aac, opus (case-insensitive).
        speed: Playback-rate factor.

    Returns:
        ``(audio_bytes, media_type)``.
    """
    import soundfile as sf

    # NOTE(review): "aac" is mapped to a soundfile format named "AAC";
    # confirm the installed libsndfile actually provides it.
    encoders = {
        "wav": ("WAV", "audio/wav", {}),
        "pcm": ("RAW", "audio/pcm", {"subtype": "PCM_16"}),
        "flac": ("FLAC", "audio/flac", {}),
        "mp3": ("MP3", "audio/mpeg", {}),
        "aac": ("AAC", "audio/aac", {}),
        "opus": ("OGG", "audio/ogg", {"subtype": "OPUS"}),
    }

    samples = audio_np
    if speed != 1.0:
        try:
            import librosa

            samples = librosa.effects.time_stretch(y=samples, rate=speed)
        except ImportError:
            logger.warning("librosa not installed, ignoring speed adjustment")

    fmt = (fmt or "wav").lower()
    spec = encoders.get(fmt)
    if spec is None:
        logger.warning(f"Unsupported format '{fmt}', defaulting to wav")
        fmt = "wav"
        spec = encoders[fmt]

    container, media_type, write_options = spec
    sink = BytesIO()
    sf.write(sink, samples, sample_rate, format=container, **write_options)
    payload = sink.getvalue()
    return payload, media_type
async def format_output(
    self,
    mm_output: Dict[str, Any],
    request_id: str,
    response_format: str | None = None,
    request_type: RequestType = RequestType.AUDIO_GENERATION,
    speed: float = 1.0,
) -> Dict[str, Any] | None:
    """Turn the engine's multimodal audio output into a response dict.

    The waveform is extracted, encoded in a worker thread and returned
    either as an uploaded URL (``response_format == "url"``) or as base64.
    ``None``, ``"url"`` and ``"b64_json"`` all encode as WAV; any other
    value is used as the encoding format.

    Args:
        mm_output: The stage's ``multimodal_output`` dict.
        request_id: Used as the response id and in the storage path.
        response_format: Requested format, see above.
        request_type: Accepted for interface parity; not read here.
        speed: Speed factor forwarded to the encoder.

    Returns:
        ``NvAudioSpeechResponse(...).model_dump()``; ``status`` is
        ``"failed"`` (with ``error`` set) when there is no output or any
        step raises.
    """
    if not mm_output:
        empty_response = NvAudioSpeechResponse(
            id=request_id,
            model=self.config.served_model_name or self.config.model,
            status="failed",
            created=int(time.time()),
            error="No audio generated",
        )
        return empty_response.model_dump()

    try:
        started_at = time.time()
        waveform, sample_rate = self._extract_audio_tensor(mm_output)

        # Delivery-style formats carry WAV; anything else names the codec.
        if response_format in (None, "url", "b64_json"):
            encode_fmt = "wav"
        else:
            encode_fmt = response_format
        assert encode_fmt is not None

        audio_bytes, media_type = await asyncio.to_thread(
            self._encode_audio, waveform, sample_rate, encode_fmt, speed
        )
        logger.info(
            f"Audio encoded for request {request_id}: "
            f"{len(waveform)} samples, sr={sample_rate}, "
            f"{len(audio_bytes)} bytes {encode_fmt}"
        )
        elapsed = time.time() - started_at

        if response_format == "url":
            # Opus is written in an OGG container, so use that extension.
            ext = "ogg" if encode_fmt == "opus" else encode_fmt
            storage_path = f"audios/{request_id}/{uuid.uuid4()}.{ext}"
            url = await upload_to_fs(
                self.media_output_fs,
                storage_path,
                audio_bytes,
                self.media_output_http_url,
            )
            item = AudioData(url=url)
        else:
            item = AudioData(b64_json=base64.b64encode(audio_bytes).decode("utf-8"))

        return NvAudioSpeechResponse(
            id=request_id,
            object="audio.speech",
            model=self.config.served_model_name or self.config.model,
            status="completed",
            progress=100,
            created=int(time.time()),
            data=[item],
            inference_time_s=elapsed,
        ).model_dump()
    except Exception as exc:
        logger.error(f"Failed to process audio for request {request_id}: {exc}")
        return NvAudioSpeechResponse(
            id=request_id,
            object="audio.speech",
            model=self.config.served_model_name or self.config.model,
            status="failed",
            progress=0,
            created=int(time.time()),
            data=[],
            error=str(exc),
        ).model_dump()
...@@ -159,8 +159,30 @@ class BaseOmniHandler(BaseWorkerHandler[Dict[str, Any], Dict[str, Any]]): ...@@ -159,8 +159,30 @@ class BaseOmniHandler(BaseWorkerHandler[Dict[str, Any], Dict[str, Any]]):
request, self.default_sampling_params, self.model_max_len request, self.default_sampling_params, self.model_max_len
) )
def _error_chunk(self, request_id: str, error_message: str) -> Dict[str, Any]: def _error_chunk(
"""Create an error chunk in OpenAI format.""" self,
request_id: str,
error_message: str,
request_type=None,
) -> Dict[str, Any]:
"""Create an error response matching the expected protocol for the request type.
For AUDIO_GENERATION returns NvAudioSpeechResponse format.
For all other types returns OpenAI chat.completion.chunk format.
"""
from dynamo.common.utils.output_modalities import RequestType
if request_type == RequestType.AUDIO_GENERATION:
from dynamo.common.protocols.audio_protocol import NvAudioSpeechResponse
return NvAudioSpeechResponse(
id=request_id,
model=self.config.served_model_name or self.config.model,
status="failed",
created=int(time.time()),
error=error_message,
).model_dump()
return { return {
"id": request_id, "id": request_id,
"created": int(time.time()), "created": int(time.time()),
......
...@@ -20,6 +20,10 @@ from dynamo.runtime import DistributedRuntime ...@@ -20,6 +20,10 @@ from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.vllm.health_check import VllmOmniHealthCheckPayload from dynamo.vllm.health_check import VllmOmniHealthCheckPayload
from dynamo.vllm.main import setup_metrics_collection from dynamo.vllm.main import setup_metrics_collection
from dynamo.vllm.omni.tts_utils import (
cleanup_dummy_tokenizer_for_tts,
ensure_dummy_tokenizer_for_tts,
)
from .args import OmniConfig, parse_omni_args from .args import OmniConfig, parse_omni_args
...@@ -69,6 +73,15 @@ async def init_omni( ...@@ -69,6 +73,15 @@ async def init_omni(
if model_type is None: if model_type is None:
model_type = ModelType.Images model_type = ModelType.Images
# Audio/TTS models (e.g., Qwen3-TTS) don't ship a standard tokenizer.json,
# which causes register_model to fail when building the ModelDeploymentCard.
# Create a minimal placeholder so the Rust card loader doesn't bail,
# then delete it immediately after so vLLM-Omni's inference-time
# AutoTokenizer.from_pretrained() doesn't pick up the fake file.
dummy_tokenizer_paths = []
if "audio" in config.output_modalities:
dummy_tokenizer_paths = ensure_dummy_tokenizer_for_tts(config.model)
await register_model( await register_model(
ModelInput.Text, ModelInput.Text,
model_type, model_type,
...@@ -78,6 +91,9 @@ async def init_omni( ...@@ -78,6 +91,9 @@ async def init_omni(
kv_cache_block_size=config.engine_args.block_size, kv_cache_block_size=config.engine_args.block_size,
) )
if dummy_tokenizer_paths:
cleanup_dummy_tokenizer_for_tts(dummy_tokenizer_paths)
logger.info("Starting to serve Omni worker endpoint...") logger.info("Starting to serve Omni worker endpoint...")
health_check_payload = ( health_check_payload = (
......
...@@ -17,6 +17,7 @@ from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt ...@@ -17,6 +17,7 @@ from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt
from dynamo._core import Context from dynamo._core import Context
from dynamo.common.multimodal import ImageLoader from dynamo.common.multimodal import ImageLoader
from dynamo.common.protocols.audio_protocol import NvCreateAudioSpeechRequest
from dynamo.common.protocols.image_protocol import ( from dynamo.common.protocols.image_protocol import (
ImageData, ImageData,
NvCreateImageRequest, NvCreateImageRequest,
...@@ -36,6 +37,7 @@ from dynamo.common.utils.video_utils import ( ...@@ -36,6 +37,7 @@ from dynamo.common.utils.video_utils import (
parse_size, parse_size,
) )
from dynamo.llm.exceptions import EngineShutdown from dynamo.llm.exceptions import EngineShutdown
from dynamo.vllm.omni.audio_handler import AudioGenerationHandler
from dynamo.vllm.omni.base_handler import BaseOmniHandler from dynamo.vllm.omni.base_handler import BaseOmniHandler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -57,17 +59,19 @@ class EngineInputs: ...@@ -57,17 +59,19 @@ class EngineInputs:
image requests). None means use the default for the request type. image requests). None means use the default for the request type.
""" """
prompt: OmniTextPrompt prompt: Union[OmniTextPrompt, Dict[str, Any]]
sampling_params_list: list | None = None sampling_params_list: list | None = None
request_type: RequestType = RequestType.CHAT_COMPLETION request_type: RequestType = RequestType.CHAT_COMPLETION
fps: int = 0 fps: int = 0
speed: float = 1.0
response_format: str | None = None response_format: str | None = None
class OmniHandler(BaseOmniHandler): class OmniHandler(BaseOmniHandler):
"""Unified handler for multi-stage pipelines using vLLM-Omni. """Unified handler for multi-stage pipelines using vLLM-Omni.
Handles text-to-text, text-to-image, and text-to-video generation. Handles text-to-text, text-to-image, text-to-video, and text-to-audio generation.
Audio/TTS logic is delegated to AudioGenerationHandler via composition.
""" """
def __init__( def __init__(
...@@ -100,6 +104,14 @@ class OmniHandler(BaseOmniHandler): ...@@ -100,6 +104,14 @@ class OmniHandler(BaseOmniHandler):
self.media_output_http_url = media_output_http_url self.media_output_http_url = media_output_http_url
self._image_loader = ImageLoader() self._image_loader = ImageLoader()
# Audio/TTS handler — composition, not inheritance.
self.audio = AudioGenerationHandler(
config=config,
engine_client=self.engine_client,
media_output_fs=media_output_fs,
media_output_http_url=media_output_http_url,
)
async def generate( async def generate(
self, request: Dict[str, Any], context: Context self, request: Dict[str, Any], context: Context
) -> AsyncGenerator[Dict[str, Any], None]: ) -> AsyncGenerator[Dict[str, Any], None]:
...@@ -154,7 +166,14 @@ class OmniHandler(BaseOmniHandler): ...@@ -154,7 +166,14 @@ class OmniHandler(BaseOmniHandler):
} }
return return
inputs = self.build_engine_inputs(parsed_request, request_type, image=image) try:
inputs = await self.build_engine_inputs(
parsed_request, request_type, image=image
)
except (ValueError, NotImplementedError) as e:
logger.error(f"Invalid request {request_id}: {e}")
yield self._error_chunk(request_id, str(e), request_type)
return
generate_kwargs: Dict[str, Any] = { generate_kwargs: Dict[str, Any] = {
"prompt": inputs.prompt, "prompt": inputs.prompt,
...@@ -207,17 +226,33 @@ class OmniHandler(BaseOmniHandler): ...@@ -207,17 +226,33 @@ class OmniHandler(BaseOmniHandler):
if chunk: if chunk:
yield chunk yield chunk
elif stage_output.final_output_type == "audio":
mm_output = stage_output.multimodal_output
if mm_output:
chunk = await self.audio.format_output(
mm_output,
request_id,
response_format=inputs.response_format,
request_type=inputs.request_type,
speed=inputs.speed,
)
if chunk:
yield chunk
except EngineShutdown: except EngineShutdown:
logger.info(f"Request {request_id} aborted due to shutdown") logger.info(f"Request {request_id} aborted due to shutdown")
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error during generation for request {request_id}: {e}") logger.error(f"Error during generation for request {request_id}: {e}")
yield self._error_chunk(request_id, str(e)) yield self._error_chunk(request_id, str(e), inputs.request_type)
def build_engine_inputs( async def build_engine_inputs(
self, self,
parsed_request: Union[ parsed_request: Union[
NvCreateImageRequest, NvCreateVideoRequest, Dict[str, Any] NvCreateImageRequest,
NvCreateVideoRequest,
NvCreateAudioSpeechRequest,
Dict[str, Any],
], ],
request_type: RequestType, request_type: RequestType,
image: PIL.Image.Image | None = None, image: PIL.Image.Image | None = None,
...@@ -226,7 +261,7 @@ class OmniHandler(BaseOmniHandler): ...@@ -226,7 +261,7 @@ class OmniHandler(BaseOmniHandler):
Args: Args:
parsed_request: Output from parse_request_type -- a Pydantic model parsed_request: Output from parse_request_type -- a Pydantic model
for image/video requests, or a raw dict for chat completions. for image/video/audio requests, or a raw dict for chat completions.
request_type: The RequestType determined by parse_request_type. request_type: The RequestType determined by parse_request_type.
image: Pre-loaded PIL Image for I2V requests (from input_reference). image: Pre-loaded PIL Image for I2V requests (from input_reference).
...@@ -242,29 +277,24 @@ class OmniHandler(BaseOmniHandler): ...@@ -242,29 +277,24 @@ class OmniHandler(BaseOmniHandler):
elif request_type == RequestType.VIDEO_GENERATION: elif request_type == RequestType.VIDEO_GENERATION:
assert isinstance(parsed_request, NvCreateVideoRequest) assert isinstance(parsed_request, NvCreateVideoRequest)
return self._engine_inputs_from_video(parsed_request, image=image) return self._engine_inputs_from_video(parsed_request, image=image)
elif request_type == RequestType.AUDIO_GENERATION: elif request_type == RequestType.AUDIO_GENERATION:
raise NotImplementedError("Audio generation is not yet supported") assert isinstance(parsed_request, NvCreateAudioSpeechRequest)
return await self.audio.build_engine_inputs(parsed_request)
raise ValueError(f"Unknown request type: {request_type}") raise ValueError(f"Unknown request type: {request_type}")
def _engine_inputs_from_chat(self, request: Dict[str, Any]) -> EngineInputs: def _engine_inputs_from_chat(self, request: Dict[str, Any]) -> EngineInputs:
"""Build engine inputs from a chat completions request dict.""" """Build engine inputs from a chat completions request dict."""
# Chat completions request does not support extra_body passthrough
# So, we can't extract any diffusion related params from the raw_request
# It falls back to default sampling params
text_prompt = self._extract_text_prompt(request) text_prompt = self._extract_text_prompt(request)
if text_prompt is None: if text_prompt is None:
raise ValueError("No user message found in chat completion request") raise ValueError("No user message found in chat completion request")
prompt = OmniTextPrompt(prompt=text_prompt) prompt = OmniTextPrompt(prompt=text_prompt)
sampling_params_list = None
return EngineInputs( return EngineInputs(
prompt=prompt, prompt=prompt,
sampling_params_list=sampling_params_list, sampling_params_list=None,
request_type=RequestType.CHAT_COMPLETION, request_type=RequestType.CHAT_COMPLETION,
fps=0, fps=0,
) )
...@@ -276,9 +306,9 @@ class OmniHandler(BaseOmniHandler): ...@@ -276,9 +306,9 @@ class OmniHandler(BaseOmniHandler):
prompt = OmniTextPrompt( prompt = OmniTextPrompt(
prompt=req.prompt, prompt=req.prompt,
negative_prompt=nvext.negative_prompt negative_prompt=(
if nvext and nvext.negative_prompt nvext.negative_prompt if nvext and nvext.negative_prompt else None
else None, ),
) )
sp = OmniDiffusionSamplingParams( sp = OmniDiffusionSamplingParams(
...@@ -331,9 +361,9 @@ class OmniHandler(BaseOmniHandler): ...@@ -331,9 +361,9 @@ class OmniHandler(BaseOmniHandler):
prompt = OmniTextPrompt( prompt = OmniTextPrompt(
prompt=req.prompt, prompt=req.prompt,
negative_prompt=nvext.negative_prompt negative_prompt=(
if nvext and nvext.negative_prompt nvext.negative_prompt if nvext and nvext.negative_prompt else None
else None, ),
) )
if image is not None: if image is not None:
...@@ -577,9 +607,11 @@ class OmniHandler(BaseOmniHandler): ...@@ -577,9 +607,11 @@ class OmniHandler(BaseOmniHandler):
"role": "assistant", "role": "assistant",
"content": delta_text, "content": delta_text,
}, },
"finish_reason": normalize_finish_reason(output.finish_reason) "finish_reason": (
normalize_finish_reason(output.finish_reason)
if output.finish_reason if output.finish_reason
else None, else None
),
} }
], ],
} }
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""TTS/audio utility functions for the vLLM-Omni backend."""
import json
import logging
from pathlib import Path
logger = logging.getLogger(__name__)
def ensure_dummy_tokenizer_for_tts(model: str) -> list[Path]:
    """Create a minimal tokenizer.json for TTS models that lack one.

    Audio/TTS models (e.g., Qwen3-TTS) use a custom speech tokenizer and don't
    ship the standard tokenizer.json expected by the Rust ModelDeploymentCard
    loader. This writes a placeholder so register_model doesn't fail.

    This is a short-term workaround. The long-term fix is making TokenizerKind
    optional in ModelDeploymentCard::from_repo_checkout().

    Args:
        model: Hugging Face repo id to look up in the local HF cache.

    Returns:
        The list of created dummy paths so the caller can delete them after
        registration (otherwise the fake tokenizer poisons vLLM-Omni's
        inference-time AutoTokenizer.from_pretrained call). Empty when the
        model is not cached or every revision already has a tokenizer.json.

    Raises:
        OSError: If a placeholder cannot be written. Placeholders written
            earlier in the same call are removed before the error propagates,
            so a failed call does not leave stubs behind in the cache.
    """
    from huggingface_hub import scan_cache_dir

    # Minimal but valid HF tokenizer JSON that tokenizers can parse without
    # crashing. The "model" key with type "BPE" is the minimum required
    # structure.
    placeholder = json.dumps(
        {
            "version": "1.0",
            "model": {"type": "BPE", "vocab": {}, "merges": []},
        }
    )

    created: list[Path] = []
    for repo in scan_cache_dir().repos:
        # The cache scan also lists datasets/spaces, which may share an id
        # with a model repo; only the model repo should be patched.
        if repo.repo_id != model or repo.repo_type != "model":
            continue

        try:
            for revision in repo.revisions:
                tokenizer_path = Path(revision.snapshot_path) / "tokenizer.json"
                if tokenizer_path.exists():
                    continue
                logger.warning(
                    "TTS model %s has no tokenizer.json; "
                    "creating a minimal placeholder at %s",
                    model,
                    tokenizer_path,
                )
                # Track the path before writing so a partially written file
                # is also removed by the rollback below.
                created.append(tokenizer_path)
                tokenizer_path.write_text(placeholder, encoding="utf-8")
        except OSError:
            # The caller never receives `created` on failure, so undo our own
            # writes here instead of leaking stubs into the shared cache.
            for stale in created:
                try:
                    stale.unlink(missing_ok=True)
                except OSError as cleanup_err:
                    logger.warning(
                        "Failed to remove dummy tokenizer %s: %s",
                        stale,
                        cleanup_err,
                    )
            raise
        return created
    return created
def cleanup_dummy_tokenizer_for_tts(paths: list[Path]):
    """Delete the placeholder tokenizer.json files listed in ``paths``.

    Intended to run once register_model() has finished, so that the stub
    written by ensure_dummy_tokenizer_for_tts is gone before vLLM-Omni loads
    its tokenizer at inference time (AutoTokenizer.from_pretrained would
    otherwise pick up the stub and crash).

    Removal is best-effort: a path that is already missing is not an error,
    and an OSError on one path is logged as a warning without stopping the
    loop.
    """
    for stub in paths:
        try:
            stub.unlink(missing_ok=True)
        except OSError as err:
            logger.warning("Failed to remove dummy tokenizer %s: %s", stub, err)
        else:
            logger.info("Removed dummy tokenizer placeholder: %s", stub)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for AudioGenerationHandler."""
from unittest.mock import MagicMock
import pytest
try:
from dynamo.common.protocols.audio_protocol import NvCreateAudioSpeechRequest
from dynamo.common.utils.output_modalities import RequestType
from dynamo.vllm.omni.audio_handler import AudioGenerationHandler
except ImportError:
pytest.skip("vLLM omni dependencies not available", allow_module_level=True)
pytestmark = [
pytest.mark.unit,
pytest.mark.vllm,
pytest.mark.gpu_0,
pytest.mark.pre_merge,
]
def _make_audio_handler(**config_overrides):
    """Build an AudioGenerationHandler wired to mocks.

    The config is a MagicMock pre-populated with the TTS limits the tests
    rely on; keyword arguments replace or extend those defaults. The engine
    client is a MagicMock whose ``model_config.hf_config`` is an attribute-less
    mock (``spec=[]``). No media storage is configured.
    """
    settings = {
        "model": "test-tts-model",
        "served_model_name": None,
        "tts_max_instructions_length": 500,
        "tts_max_new_tokens_min": 1,
        "tts_max_new_tokens_max": 4096,
        "tts_ref_audio_timeout": 15,
        "tts_ref_audio_max_bytes": 50 * 1024 * 1024,
    }
    settings.update(config_overrides)

    config = MagicMock()
    for name, value in settings.items():
        setattr(config, name, value)

    engine_client = MagicMock()
    engine_client.model_config.hf_config = MagicMock(spec=[])

    return AudioGenerationHandler(
        config=config,
        engine_client=engine_client,
        media_output_fs=None,
        media_output_http_url=None,
    )
class TestValidateTtsRequest:
    """Tests for _validate_tts_request."""

    @pytest.mark.asyncio
    async def test_empty_input_rejected(self):
        """Whitespace-only input is rejected when building engine inputs."""
        handler = _make_audio_handler()
        req = NvCreateAudioSpeechRequest(input=" ")
        with pytest.raises(ValueError, match="Input text cannot be empty"):
            await handler.build_engine_inputs(req)

    def test_invalid_task_type_rejected_by_pydantic(self):
        """Pydantic Literal validation rejects invalid task_type at construction."""
        # pydantic's ValidationError subclasses ValueError; asserting on that
        # instead of bare Exception keeps unrelated failures (e.g. a TypeError
        # from a changed signature) from passing this test by accident.
        with pytest.raises(ValueError):
            NvCreateAudioSpeechRequest(input="hello", task_type="Banana")

    def test_valid_task_types_accepted(self):
        """Each task type validates once its required companion field is set."""
        handler = _make_audio_handler()
        for task in ("CustomVoice", "VoiceDesign", "Base"):
            req = NvCreateAudioSpeechRequest(input="hello", task_type=task)
            if task == "VoiceDesign":
                req.instructions = "cheerful"
            elif task == "Base":
                req.ref_audio = "data:audio/wav;base64,AAAA"
            handler._validate_tts_request(req)

    def test_voice_design_requires_instructions(self):
        """VoiceDesign without instructions is rejected."""
        handler = _make_audio_handler()
        req = NvCreateAudioSpeechRequest(input="hello", task_type="VoiceDesign")
        with pytest.raises(ValueError, match="instructions"):
            handler._validate_tts_request(req)

    def test_base_requires_ref_audio(self):
        """Base without ref_audio is rejected."""
        handler = _make_audio_handler()
        req = NvCreateAudioSpeechRequest(input="hello", task_type="Base")
        with pytest.raises(ValueError, match="ref_audio"):
            handler._validate_tts_request(req)

    def test_ref_text_only_for_base(self):
        """ref_text combined with a non-Base task is rejected."""
        handler = _make_audio_handler()
        req = NvCreateAudioSpeechRequest(
            input="hello", task_type="CustomVoice", ref_text="foo"
        )
        with pytest.raises(ValueError, match="only valid for Base"):
            handler._validate_tts_request(req)

    def test_instructions_length_enforced(self):
        """Instructions longer than the configured maximum are rejected."""
        handler = _make_audio_handler(tts_max_instructions_length=10)
        req = NvCreateAudioSpeechRequest(input="hello", instructions="x" * 11)
        with pytest.raises(ValueError, match="Instructions too long"):
            handler._validate_tts_request(req)

    def test_max_new_tokens_range(self):
        """max_new_tokens below the minimum or above the maximum is rejected."""
        handler = _make_audio_handler()
        req = NvCreateAudioSpeechRequest(input="hello", max_new_tokens=0)
        with pytest.raises(ValueError, match="at least"):
            handler._validate_tts_request(req)
        req = NvCreateAudioSpeechRequest(input="hello", max_new_tokens=99999)
        with pytest.raises(ValueError, match="cannot exceed"):
            handler._validate_tts_request(req)

    def test_invalid_voice_rejected_when_speakers_loaded(self):
        """A voice outside the loaded speaker set is rejected."""
        handler = _make_audio_handler()
        handler._tts_supported_speakers = {"vivian", "ryan"}
        req = NvCreateAudioSpeechRequest(input="hello", voice="nonexistent")
        with pytest.raises(ValueError, match="Invalid voice"):
            handler._validate_tts_request(req)

    def test_valid_voice_accepted(self):
        """A known voice passes even when its casing differs from the set."""
        handler = _make_audio_handler()
        handler._tts_supported_speakers = {"vivian", "ryan"}
        req = NvCreateAudioSpeechRequest(input="hello", voice="Vivian")
        handler._validate_tts_request(req)  # Should not raise

    def test_invalid_language_rejected_when_languages_loaded(self):
        """A language outside the loaded language set is rejected."""
        handler = _make_audio_handler()
        handler._tts_supported_languages = {"english", "chinese"}
        req = NvCreateAudioSpeechRequest(input="hello", language="Klingon")
        with pytest.raises(ValueError, match="Invalid language"):
            handler._validate_tts_request(req)

    def test_auto_language_always_accepted(self):
        """"Auto" passes even though it is not in the loaded language set."""
        handler = _make_audio_handler()
        handler._tts_supported_languages = {"english"}
        req = NvCreateAudioSpeechRequest(input="hello", language="Auto")
        handler._validate_tts_request(req)  # Should not raise
class TestIsTtsModel:
    """Tests for _is_tts_model detection."""

    @staticmethod
    def _handler_with_stage_list(stage_list):
        """Return a mocked handler whose engine client reports ``stage_list``."""
        handler = _make_audio_handler()
        handler.engine_client.stage_list = stage_list
        return handler

    @staticmethod
    def _stage(model_stage):
        """Return a mock stage carrying the given ``model_stage`` name."""
        stage = MagicMock()
        stage.model_stage = model_stage
        return stage

    def test_qwen3_tts_detected(self):
        """A stage named "qwen3_tts" marks the model as TTS."""
        handler = self._handler_with_stage_list([self._stage("qwen3_tts")])
        assert handler._is_tts_model() is True

    def test_non_tts_model(self):
        """A "diffusion" stage alone does not mark the model as TTS."""
        handler = self._handler_with_stage_list([self._stage("diffusion")])
        assert handler._is_tts_model() is False

    def test_no_stage_list(self):
        """A missing stage list is treated as non-TTS."""
        handler = self._handler_with_stage_list(None)
        assert handler._is_tts_model() is False
class TestEngineInputsFromAudio:
    """Tests for build_engine_inputs."""

    @pytest.mark.asyncio
    async def test_generic_path_for_non_tts(self):
        """Non-TTS model gets plain text prompt."""
        diffusion_stage = MagicMock()
        diffusion_stage.model_stage = "diffusion"
        handler = _make_audio_handler()
        handler.engine_client.stage_list = [diffusion_stage]

        result = await handler.build_engine_inputs(
            NvCreateAudioSpeechRequest(input="Hello world")
        )

        assert result.request_type == RequestType.AUDIO_GENERATION
        assert result.prompt["prompt"] == "Hello world"
        assert result.sampling_params_list is None

    @pytest.mark.asyncio
    async def test_empty_input_rejected(self):
        """Whitespace-only input raises a ValueError mentioning "empty"."""
        handler = _make_audio_handler()
        blank_request = NvCreateAudioSpeechRequest(input=" ")
        with pytest.raises(ValueError, match="empty"):
            await handler.build_engine_inputs(blank_request)

    @pytest.mark.asyncio
    async def test_speed_propagated(self):
        """Speed from request is stored in EngineInputs."""
        handler = _make_audio_handler()
        handler.engine_client.stage_list = None  # non-TTS path

        result = await handler.build_engine_inputs(
            NvCreateAudioSpeechRequest(input="hello", speed=2.0)
        )

        assert result.speed == 2.0
class TestExtractAudioTensor:
    """Tests for _extract_audio_tensor."""

    def test_extracts_from_audio_key(self):
        """Samples under the "audio" key are returned with the sample rate."""
        import numpy as np

        samples = np.array([0.1, -0.2, 0.3], dtype=np.float32)
        audio_np, sr = _make_audio_handler()._extract_audio_tensor(
            {"audio": samples, "sr": 24000}
        )
        assert sr == 24000
        assert len(audio_np) == 3

    def test_extracts_from_model_outputs_key(self):
        """Samples under the "model_outputs" key are accepted as well."""
        import numpy as np

        samples = np.array([0.5, -0.5], dtype=np.float32)
        audio_np, sr = _make_audio_handler()._extract_audio_tensor(
            {"model_outputs": samples, "sr": 16000}
        )
        assert sr == 16000
        assert len(audio_np) == 2

    def test_missing_audio_raises(self):
        """A payload carrying only a sample rate is rejected."""
        handler = _make_audio_handler()
        with pytest.raises(ValueError, match="No audio data"):
            handler._extract_audio_tensor({"sr": 24000})

    def test_squeezes_extra_dims(self):
        """A (1, N) array comes back one-dimensional."""
        import numpy as np

        batched = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
        audio_np, _ = _make_audio_handler()._extract_audio_tensor(
            {"audio": batched, "sr": 24000}
        )
        assert audio_np.ndim == 1
class TestEncodeAudio:
    """Tests for _encode_audio."""

    def test_wav_encoding(self):
        """WAV output is non-empty, RIFF-prefixed, and typed audio/wav."""
        import numpy as np

        silence = np.zeros(2400, dtype=np.float32)
        payload, mime = _make_audio_handler()._encode_audio(silence, 24000, "wav")
        assert mime == "audio/wav"
        assert len(payload) > 0
        assert payload[:4] == b"RIFF"  # WAV header

    def test_unsupported_format_falls_back_to_wav(self):
        """An unknown format name yields the audio/wav media type."""
        import numpy as np

        silence = np.zeros(100, dtype=np.float32)
        _, mime = _make_audio_handler()._encode_audio(silence, 24000, "xyz")
        assert mime == "audio/wav"

    def test_default_format_is_wav(self):
        """Omitting the format argument yields the audio/wav media type."""
        import numpy as np

        silence = np.zeros(100, dtype=np.float32)
        _, mime = _make_audio_handler()._encode_audio(silence, 24000)
        assert mime == "audio/wav"
class TestFormatAudioChunk:
    """Tests for format_output."""

    @pytest.mark.asyncio
    async def test_empty_mm_output_returns_error(self):
        """An empty multimodal output produces a failed result with an error."""
        handler = _make_audio_handler()
        result = await handler.format_output({}, "req-1")
        assert result["status"] == "failed"
        assert "No audio generated" in result["error"]

    @pytest.mark.asyncio
    async def test_successful_generation(self):
        """A valid audio payload produces one completed, base64-encoded item."""
        import numpy as np

        handler = _make_audio_handler()
        # Fixed, in-range sine wave: keeps the fixture reproducible and inside
        # [-1, 1], unlike unseeded Gaussian noise.
        phase = np.linspace(0.0, 2.0 * np.pi * 40.0, 4800)
        mm = {"audio": (0.5 * np.sin(phase)).astype(np.float32), "sr": 24000}
        result = await handler.format_output(mm, "req-1")
        assert result["status"] == "completed"
        assert result["object"] == "audio.speech"
        assert len(result["data"]) == 1
        assert result["data"][0]["b64_json"] is not None
...@@ -8,6 +8,7 @@ import pytest ...@@ -8,6 +8,7 @@ import pytest
try: try:
from PIL import Image from PIL import Image
from dynamo.common.protocols.audio_protocol import NvCreateAudioSpeechRequest
from dynamo.common.protocols.image_protocol import NvCreateImageRequest from dynamo.common.protocols.image_protocol import NvCreateImageRequest
from dynamo.common.protocols.video_protocol import NvCreateVideoRequest, VideoNvExt from dynamo.common.protocols.video_protocol import NvCreateVideoRequest, VideoNvExt
from dynamo.common.utils.output_modalities import RequestType from dynamo.common.utils.output_modalities import RequestType
...@@ -86,20 +87,22 @@ class TestPrepareImageOutput: ...@@ -86,20 +87,22 @@ class TestPrepareImageOutput:
class TestBuildEngineInputs: class TestBuildEngineInputs:
def test_chat_completion(self): @pytest.mark.asyncio
async def test_chat_completion(self):
"""Chat request extracts text prompt with no sampling params.""" """Chat request extracts text prompt with no sampling params."""
handler = _make_handler() handler = _make_handler()
raw = {"messages": [{"role": "user", "content": "hello"}]} raw = {"messages": [{"role": "user", "content": "hello"}]}
inputs = handler.build_engine_inputs(raw, RequestType.CHAT_COMPLETION) inputs = await handler.build_engine_inputs(raw, RequestType.CHAT_COMPLETION)
assert inputs.request_type == RequestType.CHAT_COMPLETION assert inputs.request_type == RequestType.CHAT_COMPLETION
assert inputs.prompt["prompt"] == "hello" assert inputs.prompt["prompt"] == "hello"
assert inputs.sampling_params_list is None assert inputs.sampling_params_list is None
def test_image_generation(self): @pytest.mark.asyncio
async def test_image_generation(self):
"""Image request parses prompt, size, and creates diffusion sampling params.""" """Image request parses prompt, size, and creates diffusion sampling params."""
handler = _make_handler() handler = _make_handler()
req = NvCreateImageRequest(prompt="a cat", size="512x512") req = NvCreateImageRequest(prompt="a cat", size="512x512")
inputs = handler.build_engine_inputs(req, RequestType.IMAGE_GENERATION) inputs = await handler.build_engine_inputs(req, RequestType.IMAGE_GENERATION)
assert inputs.request_type == RequestType.IMAGE_GENERATION assert inputs.request_type == RequestType.IMAGE_GENERATION
assert inputs.prompt["prompt"] == "a cat" assert inputs.prompt["prompt"] == "a cat"
assert len(inputs.sampling_params_list) == 1 assert len(inputs.sampling_params_list) == 1
...@@ -107,22 +110,38 @@ class TestBuildEngineInputs: ...@@ -107,22 +110,38 @@ class TestBuildEngineInputs:
assert sp.height == 512 assert sp.height == 512
assert sp.width == 512 assert sp.width == 512
def test_video_generation(self): @pytest.mark.asyncio
async def test_video_generation(self):
"""Video request parses prompt, size, seconds, and sets fps.""" """Video request parses prompt, size, seconds, and sets fps."""
handler = _make_handler() handler = _make_handler()
req = NvCreateVideoRequest( req = NvCreateVideoRequest(
prompt="a drone", model="test", size="832x480", seconds=2 prompt="a drone", model="test", size="832x480", seconds=2
) )
inputs = handler.build_engine_inputs(req, RequestType.VIDEO_GENERATION) inputs = await handler.build_engine_inputs(req, RequestType.VIDEO_GENERATION)
assert inputs.request_type == RequestType.VIDEO_GENERATION assert inputs.request_type == RequestType.VIDEO_GENERATION
assert inputs.prompt["prompt"] == "a drone" assert inputs.prompt["prompt"] == "a drone"
assert inputs.fps > 0 assert inputs.fps > 0
def test_audio_not_implemented(self): @pytest.mark.asyncio
"""Audio generation raises NotImplementedError.""" async def test_audio_generation_delegates_toaudio(self):
"""Audio request delegates to audio."""
handler = _make_handler() handler = _make_handler()
with pytest.raises(NotImplementedError): expected = EngineInputs(
handler.build_engine_inputs({}, RequestType.AUDIO_GENERATION) prompt={"prompt": "Hello world"},
request_type=RequestType.AUDIO_GENERATION,
)
async def mock_engine_inputs(req):
return expected
handler.audio = MagicMock()
handler.audio.build_engine_inputs = mock_engine_inputs
inputs = await handler.build_engine_inputs(
NvCreateAudioSpeechRequest(input="Hello world"),
RequestType.AUDIO_GENERATION,
)
assert inputs.request_type == RequestType.AUDIO_GENERATION
assert inputs.prompt["prompt"] == "Hello world"
class TestFormatTextChunk: class TestFormatTextChunk:
...@@ -254,7 +273,8 @@ class TestFormatVideoChunk: ...@@ -254,7 +273,8 @@ class TestFormatVideoChunk:
class TestI2VEngineInputs: class TestI2VEngineInputs:
"""Tests for image-to-video: multi_modal_data attachment, I2V nvext params, and protocol fields.""" """Tests for image-to-video: multi_modal_data attachment, I2V nvext params, and protocol fields."""
def test_t2v_no_multi_modal_data_and_i2v_attaches_image(self): @pytest.mark.asyncio
async def test_t2v_no_multi_modal_data_and_i2v_attaches_image(self):
"""T2V has no multi_modal_data; I2V attaches image to prompt.""" """T2V has no multi_modal_data; I2V attaches image to prompt."""
handler = _make_handler() handler = _make_handler()
req = NvCreateVideoRequest( req = NvCreateVideoRequest(
...@@ -262,15 +282,18 @@ class TestI2VEngineInputs: ...@@ -262,15 +282,18 @@ class TestI2VEngineInputs:
) )
# T2V: no image # T2V: no image
t2v = handler.build_engine_inputs(req, RequestType.VIDEO_GENERATION) t2v = await handler.build_engine_inputs(req, RequestType.VIDEO_GENERATION)
assert "multi_modal_data" not in t2v.prompt assert "multi_modal_data" not in t2v.prompt
# I2V: image attached # I2V: image attached
img = Image.new("RGB", (64, 64), color="red") img = Image.new("RGB", (64, 64), color="red")
i2v = handler.build_engine_inputs(req, RequestType.VIDEO_GENERATION, image=img) i2v = await handler.build_engine_inputs(
req, RequestType.VIDEO_GENERATION, image=img
)
assert i2v.prompt["multi_modal_data"]["image"] is img assert i2v.prompt["multi_modal_data"]["image"] is img
def test_i2v_nvext_params_on_sampling_params(self): @pytest.mark.asyncio
async def test_i2v_nvext_params_on_sampling_params(self):
"""boundary_ratio and guidance_scale_2 are forwarded to sampling params.""" """boundary_ratio and guidance_scale_2 are forwarded to sampling params."""
handler = _make_handler() handler = _make_handler()
req = NvCreateVideoRequest( req = NvCreateVideoRequest(
...@@ -281,9 +304,8 @@ class TestI2VEngineInputs: ...@@ -281,9 +304,8 @@ class TestI2VEngineInputs:
boundary_ratio=0.875, guidance_scale_2=1.0, num_inference_steps=40 boundary_ratio=0.875, guidance_scale_2=1.0, num_inference_steps=40
), ),
) )
sp = handler.build_engine_inputs( result = await handler.build_engine_inputs(req, RequestType.VIDEO_GENERATION)
req, RequestType.VIDEO_GENERATION sp = result.sampling_params_list[0]
).sampling_params_list[0]
assert sp.boundary_ratio == 0.875 assert sp.boundary_ratio == 0.875
assert sp.guidance_scale_2 == 1.0 assert sp.guidance_scale_2 == 1.0
assert sp.num_inference_steps == 40 assert sp.num_inference_steps == 40
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
title: vLLM-Omni title: vLLM-Omni
--- ---
Dynamo supports multimodal generation through the [vLLM-Omni](https://github.com/vllm-project/vllm-omni) backend. This integration exposes text-to-text, text-to-image, and text-to-video capabilities via OpenAI-compatible API endpoints. Dynamo supports multimodal generation through the [vLLM-Omni](https://github.com/vllm-project/vllm-omni) backend. This integration exposes text-to-text, text-to-image, text-to-video, and text-to-audio (TTS) capabilities via OpenAI-compatible API endpoints.
## Prerequisites ## Prerequisites
...@@ -26,8 +26,9 @@ pip install git+https://github.com/vllm-project/vllm-omni.git@v0.16.0rc1 ...@@ -26,8 +26,9 @@ pip install git+https://github.com/vllm-project/vllm-omni.git@v0.16.0rc1
| Text-to-Image | `/v1/chat/completions`, `/v1/images/generations` | `image` | | Text-to-Image | `/v1/chat/completions`, `/v1/images/generations` | `image` |
| Text-to-Video | `/v1/videos` | `video` | | Text-to-Video | `/v1/videos` | `video` |
| Image-to-Video | `/v1/videos` | `video` | | Image-to-Video | `/v1/videos` | `video` |
| Text-to-Audio (TTS) | `/v1/audio/speech` | `audio` |
The `--output-modalities` flag determines which endpoint(s) the worker registers. When set to `image`, both `/v1/chat/completions` (returns inline base64 images) and `/v1/images/generations` are available. When set to `video`, the worker serves `/v1/videos`. The `--output-modalities` flag determines which endpoint(s) the worker registers. When set to `image`, both `/v1/chat/completions` (returns inline base64 images) and `/v1/images/generations` are available. When set to `video`, the worker serves `/v1/videos`. When set to `audio`, the worker serves `/v1/audio/speech`.
## Tested Models ## Tested Models
...@@ -37,6 +38,7 @@ The `--output-modalities` flag determines which endpoint(s) the worker registers ...@@ -37,6 +38,7 @@ The `--output-modalities` flag determines which endpoint(s) the worker registers
| Text-to-Image | `Qwen/Qwen-Image`, `AIDC-AI/Ovis-Image-7B` | | Text-to-Image | `Qwen/Qwen-Image`, `AIDC-AI/Ovis-Image-7B` |
| Text-to-Video | `Wan-AI/Wan2.1-T2V-1.3B-Diffusers`, `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | | Text-to-Video | `Wan-AI/Wan2.1-T2V-1.3B-Diffusers`, `Wan-AI/Wan2.2-T2V-A14B-Diffusers` |
| Image-to-Video | `Wan-AI/Wan2.2-TI2V-5B-Diffusers`, `Wan-AI/Wan2.2-I2V-A14B-Diffusers` | | Image-to-Video | `Wan-AI/Wan2.2-TI2V-5B-Diffusers`, `Wan-AI/Wan2.2-I2V-A14B-Diffusers` |
| Text-to-Audio (TTS) | `Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice`, `Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign` |
To run a non-default model, pass `--model` to any launch script: To run a non-default model, pass `--model` to any launch script:
...@@ -203,13 +205,80 @@ The `input_reference` field accepts: ...@@ -203,13 +205,80 @@ The `input_reference` field accepts:
The I2V-specific `nvext` fields (`boundary_ratio`, `guidance_scale_2`) control the dual-expert MoE denoising schedule in Wan2.x models. See [Wan2.2-I2V model card](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) for details. The I2V-specific `nvext` fields (`boundary_ratio`, `guidance_scale_2`) control the dual-expert MoE denoising schedule in Wan2.x models. See [Wan2.2-I2V model card](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) for details.
## Text-to-Audio (TTS)
Launch using the provided script with `Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice`:
```bash
bash examples/backends/vllm/launch/agg_omni_audio.sh
```
### CustomVoice (predefined speakers)
```bash
curl -X POST http://localhost:8000/v1/audio/speech \
-H "Content-Type: application/json" \
-d '{
"input": "Hello, how are you?",
"voice": "vivian",
"language": "English"
}' --output output.wav
```
### CustomVoice with style instructions
```bash
curl -X POST http://localhost:8000/v1/audio/speech \
-H "Content-Type: application/json" \
-d '{
"input": "I am so excited!",
"voice": "vivian",
"instructions": "Speak with great enthusiasm"
}' --output excited.wav
```
### VoiceDesign (describe a voice)
```bash
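# Terminal 1: launch the worker with the VoiceDesign model (the script stays in the foreground)
bash examples/backends/vllm/launch/agg_omni_audio.sh --model Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign
# Terminal 2: once the worker is up, send the request
curl -X POST http://localhost:8000/v1/audio/speech \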
-H "Content-Type: application/json" \
-d '{
"input": "Hello world",
"task_type": "VoiceDesign",
"instructions": "A warm, friendly female voice with a gentle tone"
}' --output voicedesign.wav
```
### Parameters
The `/v1/audio/speech` endpoint follows the [vLLM-Omni](https://docs.vllm.ai/projects/vllm-omni/en/latest/user_guide/examples/online_serving/qwen3_tts/) API format. All TTS-specific parameters are top-level fields:
| Field | Description | Default |
|---|---|---|
| `input` | Text to synthesize (required) | -- |
| `model` | TTS model name | auto-detected |
| `voice` | Speaker name (e.g., vivian, ryan). Validated against model config. | Vivian |
| `response_format` | Audio format: wav, mp3, pcm, flac, aac, opus | wav |
| `speed` | Speed factor (0.25-4.0) | 1.0 |
| `task_type` | CustomVoice, VoiceDesign, or Base (Qwen3-TTS) | CustomVoice |
| `language` | Language name (e.g., Auto, Chinese, English, Japanese, Korean). Validated against model config. | Auto |
| `instructions` | Voice style/emotion description. Required for VoiceDesign. | -- |
| `ref_audio` | Reference audio URL or base64 data URI. Required for the Base task (voice cloning), which is not yet supported. | -- |
| `ref_text` | Transcript of reference audio (Base task) | -- |
| `max_new_tokens` | Maximum tokens to generate (1-4096) | 2048 |
Available voices and languages are loaded dynamically from the model's `config.json` at startup. Non-Qwen3-TTS audio models (e.g., MiMo-Audio) use a generic text prompt and ignore TTS-specific parameters.
## CLI Reference ## CLI Reference
The omni backend uses a dedicated entrypoint: `python -m dynamo.vllm.omni`. The omni backend uses a dedicated entrypoint: `python -m dynamo.vllm.omni`.
| Flag | Description | | Flag | Description |
|---|---| |---|---|
| `--output-modalities <modality>` | Output modality: `text`, `image`, or `video` | | `--omni` | Enable the vLLM-Omni orchestrator (required for all omni workloads) |
| `--output-modalities <modality>` | Output modality: `text`, `image`, `video`, or `audio` |
| `--stage-configs-path <path>` | Path to stage config YAML (optional; vLLM-Omni uses model defaults if omitted) | | `--stage-configs-path <path>` | Path to stage config YAML (optional; vLLM-Omni uses model defaults if omitted) |
| `--boundary-ratio <float>` | MoE expert switching boundary (default: 0.875) | | `--boundary-ratio <float>` | MoE expert switching boundary (default: 0.875) |
| `--flow-shift <float>` | Scheduler flow_shift (5.0 for 720p, 12.0 for 480p) | | `--flow-shift <float>` | Scheduler flow_shift (5.0 for 720p, 12.0 for 480p) |
...@@ -231,7 +300,7 @@ The omni backend uses a dedicated entrypoint: `python -m dynamo.vllm.omni`. ...@@ -231,7 +300,7 @@ The omni backend uses a dedicated entrypoint: `python -m dynamo.vllm.omni`.
## Storage Configuration ## Storage Configuration
Generated images and videos are stored via [fsspec](https://filesystem-spec.readthedocs.io/), which supports local filesystems, S3, GCS, and Azure Blob. Generated images, videos, and audio files are stored via [fsspec](https://filesystem-spec.readthedocs.io/), which supports local filesystems, S3, GCS, and Azure Blob.
By default, media is written to the local filesystem at `file:///tmp/dynamo_media`. To use cloud storage: By default, media is written to the local filesystem at `file:///tmp/dynamo_media`. To use cloud storage:
...@@ -254,3 +323,5 @@ Omni pipelines are configured via YAML stage configs. See [`examples/backends/vl ...@@ -254,3 +323,5 @@ Omni pipelines are configured via YAML stage configs. See [`examples/backends/vl
- Image input is supported only for I2V via `input_reference` in `/v1/videos`. Other endpoints accept text prompts only. - Image input is supported only for I2V via `input_reference` in `/v1/videos`. Other endpoints accept text prompts only.
- KV cache events are not published for omni workers. - KV cache events are not published for omni workers.
- Each worker supports a single output modality at a time. - Each worker supports a single output modality at a time.
- Audio: streaming (`stream: true`) is not yet supported.
- Audio: Base task (voice cloning) is not yet supported.
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Launches a Dynamo frontend plus a single vLLM-Omni audio/TTS worker.
#
# Usage: agg_omni_audio.sh [--model <hf-model-id>] [extra worker args...]
#   --model   TTS model to serve (defaults to the CustomVoice checkpoint below).
#   Any other argument is forwarded unchanged to `python -m dynamo.vllm.omni`.
# Environment: DYN_HTTP_PORT (default 8000), DYN_SYSTEM_PORT (default 8081).
set -e
# On any exit, take down every process in this script's process group.
trap 'echo Cleaning up...; kill 0' EXIT

SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
# Expected to provide print_launch_banner, print_curl_footer and wait_any_exit.
source "$SCRIPT_DIR/../../../common/launch_utils.sh"

MODEL="Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"

# Parse command line arguments
EXTRA_ARGS=()
while [[ $# -gt 0 ]]; do
    case $1 in
        --model)
            # Without this check a trailing `--model` leaves MODEL empty and
            # `shift 2` aborts the script under `set -e` with no explanation.
            if [[ $# -lt 2 || -z "$2" ]]; then
                echo "Error: --model requires a value" >&2
                exit 1
            fi
            MODEL="$2"
            shift 2
            ;;
        *)
            EXTRA_ARGS+=("$1")
            shift
            ;;
    esac
done

HTTP_PORT="${DYN_HTTP_PORT:-8000}"

print_launch_banner --no-curl "Launching vLLM-Omni Audio/TTS (1 GPU)" "$MODEL" "$HTTP_PORT"
print_curl_footer <<CURL
  curl -X POST http://localhost:${HTTP_PORT}/v1/audio/speech \\
    -H 'Content-Type: application/json' \\
    -d '{
      "input": "Hey, this is generated using Dynamo!",
      "model": "${MODEL}",
      "voice": "vivian",
      "language": "English"
    }' \\
    -o dynamo-audio.wav
CURL

python -m dynamo.frontend &
FRONTEND_PID=$!

# Give the frontend a moment to start before the worker is launched.
sleep 2

echo "Starting Omni Audio worker..."
DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT:-8081} \
    python -m dynamo.vllm.omni \
    --model "$MODEL" \
    --output-modalities audio \
    --media-output-fs-url file:///tmp/dynamo_media \
    --trust-remote-code \
    --enforce-eager \
    "${EXTRA_ARGS[@]}" &

# Exit on first worker failure; kill 0 in the EXIT trap tears down the rest
wait_any_exit
...@@ -19,6 +19,7 @@ use crate::protocols::openai::ParsingOptions; ...@@ -19,6 +19,7 @@ use crate::protocols::openai::ParsingOptions;
use crate::types::{ use crate::types::{
generic::tensor::TensorStreamingEngine, generic::tensor::TensorStreamingEngine,
openai::{ openai::{
audios::OpenAIAudiosStreamingEngine,
chat_completions::OpenAIChatCompletionsStreamingEngine, chat_completions::OpenAIChatCompletionsStreamingEngine,
completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine, completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine,
images::OpenAIImagesStreamingEngine, videos::OpenAIVideosStreamingEngine, images::OpenAIImagesStreamingEngine, videos::OpenAIVideosStreamingEngine,
...@@ -149,6 +150,13 @@ impl Model { ...@@ -149,6 +150,13 @@ impl Model {
.any(|entry| entry.value().has_videos_engine()) .any(|entry| entry.value().has_videos_engine())
} }
/// Returns `true` as soon as one WorkerSet of this model reports an audios
/// engine, and `false` when none does.
pub fn has_audios_engine(&self) -> bool {
    for entry in self.worker_sets.iter() {
        if entry.value().has_audios_engine() {
            return true;
        }
    }
    false
}
/// Whether this model should be visible in /v1/models. /// Whether this model should be visible in /v1/models.
pub fn is_displayable(&self) -> bool { pub fn is_displayable(&self) -> bool {
let has_serving_engine = |ws: &WorkerSet| { let has_serving_engine = |ws: &WorkerSet| {
...@@ -158,6 +166,7 @@ impl Model { ...@@ -158,6 +166,7 @@ impl Model {
|| ws.has_images_engine() || ws.has_images_engine()
|| ws.has_tensor_engine() || ws.has_tensor_engine()
|| ws.has_videos_engine() || ws.has_videos_engine()
|| ws.has_audios_engine()
}; };
let has_any_serving_engine = self.worker_sets.iter().any(|entry| { let has_any_serving_engine = self.worker_sets.iter().any(|entry| {
...@@ -207,6 +216,11 @@ impl Model { ...@@ -207,6 +216,11 @@ impl Model {
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone())) .ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone()))
} }
/// Return the audios engine of the worker set chosen by
/// `select_worker_set_with`.
///
/// # Errors
/// Yields `ModelManagerError::ModelNotFound` (carrying this model's name)
/// when no worker set provides an audios engine.
pub fn get_audios_engine(&self) -> Result<OpenAIAudiosStreamingEngine, ModelManagerError> {
    match self.select_worker_set_with(|set| set.audios_engine.clone()) {
        Some(engine) => Ok(engine),
        None => Err(ModelManagerError::ModelNotFound(self.name.clone())),
    }
}
pub fn get_tensor_engine(&self) -> Result<TensorStreamingEngine, ModelManagerError> { pub fn get_tensor_engine(&self) -> Result<TensorStreamingEngine, ModelManagerError> {
self.select_worker_set_with(|ws| ws.tensor_engine.clone()) self.select_worker_set_with(|ws| ws.tensor_engine.clone())
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone())) .ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone()))
......
...@@ -24,6 +24,7 @@ use crate::{ ...@@ -24,6 +24,7 @@ use crate::{
types::{ types::{
generic::tensor::TensorStreamingEngine, generic::tensor::TensorStreamingEngine,
openai::{ openai::{
audios::OpenAIAudiosStreamingEngine,
chat_completions::OpenAIChatCompletionsStreamingEngine, chat_completions::OpenAIChatCompletionsStreamingEngine,
completions::OpenAICompletionsStreamingEngine, completions::OpenAICompletionsStreamingEngine,
embeddings::OpenAIEmbeddingsStreamingEngine, images::OpenAIImagesStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine, images::OpenAIImagesStreamingEngine,
...@@ -290,6 +291,16 @@ impl ModelManager { ...@@ -290,6 +291,16 @@ impl ModelManager {
.get_videos_engine() .get_videos_engine()
} }
/// Look up `model` and return its audios engine.
///
/// # Errors
/// Yields `ModelManagerError::ModelNotFound` when `model` is not registered;
/// otherwise forwards whatever the model's own `get_audios_engine` returns.
pub fn get_audios_engine(
    &self,
    model: &str,
) -> Result<OpenAIAudiosStreamingEngine, ModelManagerError> {
    let Some(entry) = self.models.get(model) else {
        return Err(ModelManagerError::ModelNotFound(model.to_string()));
    };
    entry.get_audios_engine()
}
// -- Combined engine + parsing options (atomically from one WorkerSet) -- // -- Combined engine + parsing options (atomically from one WorkerSet) --
pub fn get_chat_completions_engine_with_parsing( pub fn get_chat_completions_engine_with_parsing(
...@@ -456,6 +467,27 @@ impl ModelManager { ...@@ -456,6 +467,27 @@ impl ModelManager {
Ok(()) Ok(())
} }
/// Register a local audios engine for `model`.
///
/// The model entry is fetched (or created) first; if it already reports an
/// audios engine the call is rejected. Otherwise the engine is placed in a
/// fresh `WorkerSet` whose namespace is `__local_audios_<model>`, built with
/// the given checksum and a default `ModelDeploymentCard`.
///
/// # Errors
/// Yields `ModelManagerError::ModelAlreadyExists` when the model already has
/// an audios engine.
pub fn add_audios_model(
    &self,
    model: &str,
    card_checksum: &str,
    engine: OpenAIAudiosStreamingEngine,
) -> Result<(), ModelManagerError> {
    let entry = self.get_or_create_model(model);
    if entry.has_audios_engine() {
        return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
    }

    // Local engines live in a synthetic, per-model namespace.
    let namespace = format!("__local_audios_{}", model);
    let mut worker_set = WorkerSet::new(
        namespace.clone(),
        card_checksum.to_string(),
        ModelDeploymentCard::default(),
    );
    worker_set.audios_engine = Some(engine);

    entry.add_worker_set(namespace, Arc::new(worker_set));
    Ok(())
}
pub fn add_prefill_model( pub fn add_prefill_model(
&self, &self,
model: &str, model: &str,
......
...@@ -34,6 +34,7 @@ use crate::{ ...@@ -34,6 +34,7 @@ use crate::{
protocols::{ protocols::{
common::llm_backend::EmbeddingsEngineOutput, common::llm_backend::EmbeddingsEngineOutput,
openai::{ openai::{
audios::{NvAudioSpeechResponse, NvCreateAudioSpeechRequest},
chat_completions::{ chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse, NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
}, },
...@@ -685,7 +686,19 @@ impl ModelWatcher { ...@@ -685,7 +686,19 @@ impl ModelWatcher {
worker_set.videos_engine = Some(Arc::new(videos_router)); worker_set.videos_engine = Some(Arc::new(videos_router));
} }
// TODO: add audio models support if card.model_type.supports_audios() {
let audios_router = PushRouter::<
NvCreateAudioSpeechRequest,
Annotated<NvAudioSpeechResponse>,
>::from_client_with_threshold(
client.clone(),
self.router_config.router_mode,
None,
None,
)
.await?;
worker_set.audios_engine = Some(Arc::new(audios_router));
}
} else if card.model_input == ModelInput::Text && card.model_type.supports_chat() { } else if card.model_input == ModelInput::Text && card.model_type.supports_chat() {
// Case: Text + Chat (pure text-to-text, no diffusion) // Case: Text + Chat (pure text-to-text, no diffusion)
let push_router = PushRouter::< let push_router = PushRouter::<
......
...@@ -16,6 +16,7 @@ use crate::{ ...@@ -16,6 +16,7 @@ use crate::{
types::{ types::{
generic::tensor::TensorStreamingEngine, generic::tensor::TensorStreamingEngine,
openai::{ openai::{
audios::OpenAIAudiosStreamingEngine,
chat_completions::OpenAIChatCompletionsStreamingEngine, chat_completions::OpenAIChatCompletionsStreamingEngine,
completions::OpenAICompletionsStreamingEngine, completions::OpenAICompletionsStreamingEngine,
embeddings::OpenAIEmbeddingsStreamingEngine, images::OpenAIImagesStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine, images::OpenAIImagesStreamingEngine,
...@@ -41,6 +42,7 @@ pub struct WorkerSet { ...@@ -41,6 +42,7 @@ pub struct WorkerSet {
pub(crate) embeddings_engine: Option<OpenAIEmbeddingsStreamingEngine>, pub(crate) embeddings_engine: Option<OpenAIEmbeddingsStreamingEngine>,
pub(crate) images_engine: Option<OpenAIImagesStreamingEngine>, pub(crate) images_engine: Option<OpenAIImagesStreamingEngine>,
pub(crate) videos_engine: Option<OpenAIVideosStreamingEngine>, pub(crate) videos_engine: Option<OpenAIVideosStreamingEngine>,
pub(crate) audios_engine: Option<OpenAIAudiosStreamingEngine>,
pub(crate) tensor_engine: Option<TensorStreamingEngine>, pub(crate) tensor_engine: Option<TensorStreamingEngine>,
/// KV router for this set's workers (if KV mode) /// KV router for this set's workers (if KV mode)
...@@ -65,6 +67,7 @@ impl WorkerSet { ...@@ -65,6 +67,7 @@ impl WorkerSet {
embeddings_engine: None, embeddings_engine: None,
images_engine: None, images_engine: None,
videos_engine: None, videos_engine: None,
audios_engine: None,
tensor_engine: None, tensor_engine: None,
kv_router: None, kv_router: None,
worker_monitor: None, worker_monitor: None,
...@@ -104,6 +107,10 @@ impl WorkerSet { ...@@ -104,6 +107,10 @@ impl WorkerSet {
self.videos_engine.is_some() self.videos_engine.is_some()
} }
/// Whether this worker set has an audios engine configured
/// (i.e. `audios_engine` is `Some`).
pub fn has_audios_engine(&self) -> bool {
self.audios_engine.is_some()
}
pub fn has_tensor_engine(&self) -> bool { pub fn has_tensor_engine(&self) -> bool {
self.tensor_engine.is_some() self.tensor_engine.is_some()
} }
...@@ -119,6 +126,7 @@ impl WorkerSet { ...@@ -119,6 +126,7 @@ impl WorkerSet {
&& !self.has_embeddings_engine() && !self.has_embeddings_engine()
&& !self.has_images_engine() && !self.has_images_engine()
&& !self.has_videos_engine() && !self.has_videos_engine()
&& !self.has_audios_engine()
&& !self.has_tensor_engine() && !self.has_tensor_engine()
} }
......
...@@ -304,6 +304,9 @@ pub enum Endpoint { ...@@ -304,6 +304,9 @@ pub enum Endpoint {
/// OAI Videos /// OAI Videos
Videos, Videos,
/// OAI Audio Speech
Audios,
/// OAI Responses /// OAI Responses
Responses, Responses,
...@@ -1026,6 +1029,7 @@ impl std::fmt::Display for Endpoint { ...@@ -1026,6 +1029,7 @@ impl std::fmt::Display for Endpoint {
Endpoint::Embeddings => write!(f, "embeddings"), Endpoint::Embeddings => write!(f, "embeddings"),
Endpoint::Images => write!(f, "images"), Endpoint::Images => write!(f, "images"),
Endpoint::Videos => write!(f, "videos"), Endpoint::Videos => write!(f, "videos"),
Endpoint::Audios => write!(f, "audios"),
Endpoint::Responses => write!(f, "responses"), Endpoint::Responses => write!(f, "responses"),
Endpoint::AnthropicMessages => write!(f, "anthropic_messages"), Endpoint::AnthropicMessages => write!(f, "anthropic_messages"),
Endpoint::Tensor => write!(f, "tensor"), Endpoint::Tensor => write!(f, "tensor"),
...@@ -1041,6 +1045,7 @@ impl Endpoint { ...@@ -1041,6 +1045,7 @@ impl Endpoint {
Endpoint::Embeddings => "embeddings", Endpoint::Embeddings => "embeddings",
Endpoint::Images => "images", Endpoint::Images => "images",
Endpoint::Videos => "videos", Endpoint::Videos => "videos",
Endpoint::Audios => "audios",
Endpoint::Responses => "responses", Endpoint::Responses => "responses",
Endpoint::AnthropicMessages => "anthropic_messages", Endpoint::AnthropicMessages => "anthropic_messages",
Endpoint::Tensor => "tensor", Endpoint::Tensor => "tensor",
......
...@@ -46,6 +46,7 @@ use crate::engines::ValidateRequest; ...@@ -46,6 +46,7 @@ use crate::engines::ValidateRequest;
use crate::protocols::openai::chat_completions::aggregator::ChatCompletionAggregator; use crate::protocols::openai::chat_completions::aggregator::ChatCompletionAggregator;
use crate::protocols::openai::nvext::apply_header_routing_overrides; use crate::protocols::openai::nvext::apply_header_routing_overrides;
use crate::protocols::openai::{ use crate::protocols::openai::{
audios::{NvAudioSpeechResponse, NvCreateAudioSpeechRequest},
chat_completions::{ chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionResponse, NvCreateChatCompletionRequest, NvCreateChatCompletionResponse,
NvCreateChatCompletionStreamResponse, NvCreateChatCompletionStreamResponse,
...@@ -2201,6 +2202,113 @@ pub fn videos_router( ...@@ -2201,6 +2202,113 @@ pub fn videos_router(
(vec![doc, stream_doc], router) (vec![doc, stream_doc], router)
} }
/// Axum handler for the audio speech (TTS) endpoint.
///
/// Flow:
/// 1. Reject the request if the service is not ready.
/// 2. Resolve the target model (request field, else the first name returned
///    by `model_display_names()`), then fetch its audios engine.
/// 3. Run the engine, fold the annotated stream into a single
///    [`NvAudioSpeechResponse`], observing metrics on each stream item.
/// 4. If the folded response has `status == "failed"`, return it as JSON with
///    HTTP 400 (the inflight guard is *not* marked ok in that case).
/// 5. If the first data item carries decodable `b64_json`, return the raw
///    audio bytes with a content type derived from `response_format`;
///    otherwise return the response as JSON.
///
/// # Errors
/// Returns an [`ErrorResponse`] when the service is not ready, the model has
/// no audios engine, generation fails to start, or the stream cannot be
/// folded.
async fn audio_speech(
    State(state): State<Arc<service_v2::State>>,
    headers: HeaderMap,
    Json(request): Json<NvCreateAudioSpeechRequest>,
) -> Result<Response, ErrorResponse> {
    // return a 503 if the service is not ready
    check_ready(&state)?;

    // Captured up front: `request` is moved into the Context below.
    let response_format = request.response_format.clone();
    let request_id = get_or_create_request_id(request.user.as_deref(), &headers);
    let request = Context::with_id(request, request_id);
    let request_id = request.id().to_string();
    let streaming = false;

    // model is optional in the request; fall back to the first registered model
    // NOTE(review): "first" is whatever `model_display_names()` yields first, which
    // may not be an audio-capable model — confirm ordering guarantees.
    let model = request.model.clone().unwrap_or_else(|| {
        state
            .manager()
            .model_display_names()
            .into_iter()
            .next()
            .unwrap_or_default()
    });

    let http_queue_guard = state.metrics_clone().create_http_queue_guard(&model);
    let engine = state
        .manager()
        .get_audios_engine(&model)
        .map_err(|_| ErrorMessage::model_not_found())?;
    let mut inflight =
        state
            .metrics_clone()
            .create_inflight_guard(&model, Endpoint::Audios, streaming);
    let mut response_collector = state.metrics_clone().create_response_collector(&model);

    let stream = engine
        .generate(request)
        .await
        .map_err(|e| ErrorMessage::from_anyhow(e, "Failed to generate audio"))?;

    // The queue guard is handed to the per-item metrics observer via an Option.
    let mut http_queue_guard = Some(http_queue_guard);
    let stream = stream.inspect(move |response| {
        process_response_and_observe_metrics(
            response,
            &mut response_collector,
            &mut http_queue_guard,
        );
    });

    let response = NvAudioSpeechResponse::from_annotated_stream(stream)
        .await
        .map_err(|e| {
            tracing::error!("Failed to fold audio stream for {}: {:?}", request_id, e);
            ErrorMessage::internal_server_error("Failed to fold audio stream")
        })?;

    // Check for failure before marking success
    if response.status == "failed" {
        return Ok((axum::http::StatusCode::BAD_REQUEST, Json(response)).into_response());
    }
    inflight.mark_ok();

    // If response contains b64_json audio data, decode and return as binary
    // (matching OpenAI/vLLM-Omni behavior: curl --output file.wav)
    let b64_audio = response
        .data
        .first()
        .and_then(|item| item.b64_json.as_ref());
    if let Some(b64) = b64_audio {
        match base64::engine::general_purpose::STANDARD.decode(b64) {
            Ok(audio_bytes) => {
                let content_type = match response_format.as_deref().unwrap_or("wav") {
                    "mp3" => "audio/mpeg",
                    "flac" => "audio/flac",
                    "pcm" => "audio/pcm",
                    "aac" => "audio/aac",
                    "opus" => "audio/ogg; codecs=opus",
                    _ => "audio/wav",
                };
                // Tuple-based response construction is infallible, so no
                // unwrap/panic path is needed in the handler.
                return Ok((
                    [(axum::http::header::CONTENT_TYPE, content_type)],
                    audio_bytes,
                )
                    .into_response());
            }
            Err(e) => {
                // Surface the problem instead of silently dropping it; the
                // JSON fallback below is still returned to the client.
                tracing::warn!(
                    "Failed to decode b64_json audio for {}: {}; returning JSON response",
                    request_id,
                    e
                );
            }
        }
    }

    // Fallback: return JSON (url format responses)
    Ok(Json(response).into_response())
}
/// Create an Axum [`Router`] for the Audio Speech endpoint.
///
/// `path` overrides the route; when `None`, the default `/v1/audio/speech`
/// is used. The route accepts `POST` and is wrapped with the JSON error
/// middleware and the configured request body limit.
///
/// Returns the route documentation entries together with the router.
pub fn audios_router(
    state: Arc<service_v2::State>,
    path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
    // Build the default lazily so no String is allocated when a path is given.
    let path = path.unwrap_or_else(|| "/v1/audio/speech".to_string());
    let doc = RouteDoc::new(axum::http::Method::POST, &path);
    let router = Router::new()
        .route(&path, post(audio_speech))
        .layer(middleware::from_fn(smart_json_error_middleware))
        .layer(axum::extract::DefaultBodyLimit::max(get_body_limit()))
        .with_state(state);
    (vec![doc], router)
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
......
...@@ -56,6 +56,7 @@ struct StateFlags { ...@@ -56,6 +56,7 @@ struct StateFlags {
embeddings_endpoints_enabled: AtomicBool, embeddings_endpoints_enabled: AtomicBool,
images_endpoints_enabled: AtomicBool, images_endpoints_enabled: AtomicBool,
videos_endpoints_enabled: AtomicBool, videos_endpoints_enabled: AtomicBool,
audios_endpoints_enabled: AtomicBool,
responses_endpoints_enabled: AtomicBool, responses_endpoints_enabled: AtomicBool,
anthropic_endpoints_enabled: AtomicBool, anthropic_endpoints_enabled: AtomicBool,
} }
...@@ -68,8 +69,7 @@ impl StateFlags { ...@@ -68,8 +69,7 @@ impl StateFlags {
EndpointType::Embedding => self.embeddings_endpoints_enabled.load(Ordering::Relaxed), EndpointType::Embedding => self.embeddings_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::Images => self.images_endpoints_enabled.load(Ordering::Relaxed), EndpointType::Images => self.images_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::Videos => self.videos_endpoints_enabled.load(Ordering::Relaxed), EndpointType::Videos => self.videos_endpoints_enabled.load(Ordering::Relaxed),
// TODO: add audios_endpoints_enabled flag EndpointType::Audios => self.audios_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::Audios => false,
EndpointType::Responses => self.responses_endpoints_enabled.load(Ordering::Relaxed), EndpointType::Responses => self.responses_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::AnthropicMessages => { EndpointType::AnthropicMessages => {
self.anthropic_endpoints_enabled.load(Ordering::Relaxed) self.anthropic_endpoints_enabled.load(Ordering::Relaxed)
...@@ -94,8 +94,9 @@ impl StateFlags { ...@@ -94,8 +94,9 @@ impl StateFlags {
EndpointType::Videos => self EndpointType::Videos => self
.videos_endpoints_enabled .videos_endpoints_enabled
.store(enabled, Ordering::Relaxed), .store(enabled, Ordering::Relaxed),
// TODO: add audios_endpoints_enabled flag EndpointType::Audios => self
EndpointType::Audios => {} .audios_endpoints_enabled
.store(enabled, Ordering::Relaxed),
EndpointType::Responses => self EndpointType::Responses => self
.responses_endpoints_enabled .responses_endpoints_enabled
.store(enabled, Ordering::Relaxed), .store(enabled, Ordering::Relaxed),
...@@ -122,6 +123,7 @@ impl State { ...@@ -122,6 +123,7 @@ impl State {
embeddings_endpoints_enabled: AtomicBool::new(false), embeddings_endpoints_enabled: AtomicBool::new(false),
images_endpoints_enabled: AtomicBool::new(false), images_endpoints_enabled: AtomicBool::new(false),
videos_endpoints_enabled: AtomicBool::new(false), videos_endpoints_enabled: AtomicBool::new(false),
audios_endpoints_enabled: AtomicBool::new(false),
responses_endpoints_enabled: AtomicBool::new(false), responses_endpoints_enabled: AtomicBool::new(false),
anthropic_endpoints_enabled: AtomicBool::new(false), anthropic_endpoints_enabled: AtomicBool::new(false),
}, },
...@@ -587,6 +589,7 @@ impl HttpServiceConfigBuilder { ...@@ -587,6 +589,7 @@ impl HttpServiceConfigBuilder {
super::openai::embeddings_router(state.clone(), var(HTTP_SVC_EMB_PATH_ENV).ok()); super::openai::embeddings_router(state.clone(), var(HTTP_SVC_EMB_PATH_ENV).ok());
let (images_docs, images_route) = super::openai::images_router(state.clone(), None); let (images_docs, images_route) = super::openai::images_router(state.clone(), None);
let (videos_docs, videos_route) = super::openai::videos_router(state.clone(), None); let (videos_docs, videos_route) = super::openai::videos_router(state.clone(), None);
let (audios_docs, audios_route) = super::openai::audios_router(state.clone(), None);
let (responses_docs, responses_route) = super::openai::responses_router( let (responses_docs, responses_route) = super::openai::responses_router(
state.clone(), state.clone(),
request_template.clone(), request_template.clone(),
...@@ -598,6 +601,7 @@ impl HttpServiceConfigBuilder { ...@@ -598,6 +601,7 @@ impl HttpServiceConfigBuilder {
endpoint_routes.insert(EndpointType::Embedding, (embed_docs, embed_route)); endpoint_routes.insert(EndpointType::Embedding, (embed_docs, embed_route));
endpoint_routes.insert(EndpointType::Images, (images_docs, images_route)); endpoint_routes.insert(EndpointType::Images, (images_docs, images_route));
endpoint_routes.insert(EndpointType::Videos, (videos_docs, videos_route)); endpoint_routes.insert(EndpointType::Videos, (videos_docs, videos_route));
endpoint_routes.insert(EndpointType::Audios, (audios_docs, audios_route));
endpoint_routes.insert(EndpointType::Responses, (responses_docs, responses_route)); endpoint_routes.insert(EndpointType::Responses, (responses_docs, responses_route));
if env_is_truthy(env_llm::DYN_ENABLE_ANTHROPIC_API) { if env_is_truthy(env_llm::DYN_ENABLE_ANTHROPIC_API) {
......
...@@ -11,6 +11,7 @@ use super::{ ...@@ -11,6 +11,7 @@ use super::{
use crate::protocols::openai::common_ext::CommonExtProvider; use crate::protocols::openai::common_ext::CommonExtProvider;
use crate::types::TokenIdType; use crate::types::TokenIdType;
pub mod audios;
pub mod chat_completions; pub mod chat_completions;
pub mod common_ext; pub mod common_ext;
pub mod completions; pub mod completions;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment