Unverified Commit cbbde3d0 authored by Zhongdongming Dai's avatar Zhongdongming Dai Committed by GitHub
Browse files

feat: add initial audio/TTS pipeline support for vLLM-omni backend (#7495)


Signed-off-by: default avatarZhongdongming Dai <zhongdongmin@nvidia.com>
parent 76c70f41
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Protocol types for audio generation (TTS).
These types follow the vLLM-Omni OpenAICreateSpeechRequest format,
with TTS-specific parameters as top-level fields (not nested in nvext).
Note: These Pydantic models mirror the Rust protocol types in
lib/llm/src/protocols/openai/audios.rs. Ideally these should be
code-generated from the Rust definitions; for now they are maintained
manually and must be kept in sync.
"""
from typing import Literal, Optional
from pydantic import BaseModel, Field
class NvCreateAudioSpeechRequest(BaseModel):
"""Request for audio speech generation (/v1/audio/speech endpoint).
Follows vLLM-Omni's OpenAICreateSpeechRequest format.
"""
# Standard OpenAI params
input: str
"""The text to synthesize into speech."""
model: Optional[str] = None
"""The TTS model to use."""
voice: Optional[str] = None
"""Voice/speaker name (e.g., 'vivian', 'ryan', 'aiden')."""
response_format: Optional[
Literal["wav", "pcm", "flac", "mp3", "aac", "opus"]
] = "wav"
"""Output format."""
speed: Optional[float] = Field(default=1.0, ge=0.25, le=4.0)
"""Speed factor."""
# Qwen3-TTS specific params (top-level, matching vLLM-Omni)
task_type: Optional[Literal["CustomVoice", "VoiceDesign", "Base"]] = None
"""TTS task type."""
language: Optional[str] = None
"""Language: Auto, Chinese, English, Japanese, Korean, etc."""
instructions: Optional[str] = None
"""Voice style/emotion instructions (for VoiceDesign)."""
ref_audio: Optional[str] = None
"""Reference audio URL or base64 (for voice cloning with Base task)."""
ref_text: Optional[str] = None
"""Reference transcript (for voice cloning with Base task)."""
max_new_tokens: Optional[int] = None
"""Maximum tokens to generate (default: 2048)."""
class AudioData(BaseModel):
"""Audio data in response."""
url: Optional[str] = None
"""URL of the generated audio (if response_format is 'url')."""
b64_json: Optional[str] = None
"""Base64-encoded audio data."""
class NvAudioSpeechResponse(BaseModel):
"""Response structure for audio speech generation."""
id: str
"""Unique identifier for the response."""
object: str = "audio.speech"
"""Object type."""
model: str
"""Model used for generation."""
status: str = "completed"
"""Generation status."""
progress: int = 100
"""Progress percentage (0-100)."""
created: int
"""Unix timestamp of creation."""
data: list[AudioData] = []
"""List of generated audio data."""
error: Optional[str] = None
"""Error message if generation failed."""
inference_time_s: Optional[float] = None
"""Inference time in seconds."""
......@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel
from dynamo.common.protocols.audio_protocol import NvCreateAudioSpeechRequest
from dynamo.common.protocols.image_protocol import NvCreateImageRequest
from dynamo.common.protocols.video_protocol import NvCreateVideoRequest
from dynamo.llm import ModelType
......@@ -91,8 +92,7 @@ def parse_request_type(
return NvCreateVideoRequest(**raw_request), RequestType.VIDEO_GENERATION
if modality is OutputModality.AUDIO:
# Audio protocol types are not yet defined; pass through the raw dict.
return raw_request, RequestType.AUDIO_GENERATION
return NvCreateAudioSpeechRequest(**raw_request), RequestType.AUDIO_GENERATION
# Text Modality
return raw_request, RequestType.CHAT_COMPLETION
......@@ -158,6 +158,52 @@ class OmniArgGroup(ArgGroup):
help="Disable torch.compile and force eager execution for diffusion models.",
)
# TTS parameters
tts_g = parser.add_argument_group(
"Omni TTS Options",
"TTS/audio-specific parameters for vLLM-Omni speech generation.",
)
add_argument(
tts_g,
flag_name="--tts-max-instructions-length",
env_var="DYN_OMNI_TTS_MAX_INSTRUCTIONS_LENGTH",
default=500,
arg_type=int,
help="Maximum character length for TTS voice instructions.",
)
add_argument(
tts_g,
flag_name="--tts-max-new-tokens-min",
env_var="DYN_OMNI_TTS_MAX_NEW_TOKENS_MIN",
default=1,
arg_type=int,
help="Minimum allowed value for max_new_tokens in TTS requests.",
)
add_argument(
tts_g,
flag_name="--tts-max-new-tokens-max",
env_var="DYN_OMNI_TTS_MAX_NEW_TOKENS_MAX",
default=4096,
arg_type=int,
help="Maximum allowed value for max_new_tokens in TTS requests.",
)
add_argument(
tts_g,
flag_name="--tts-ref-audio-timeout",
env_var="DYN_OMNI_TTS_REF_AUDIO_TIMEOUT",
default=15,
arg_type=int,
help="Timeout in seconds for downloading reference audio URLs.",
)
add_argument(
tts_g,
flag_name="--tts-ref-audio-max-bytes",
env_var="DYN_OMNI_TTS_REF_AUDIO_MAX_BYTES",
default=50 * 1024 * 1024,
arg_type=int,
help="Maximum size in bytes for reference audio files (default: 50MB).",
)
# Diffusion parallel configuration
add_argument(
g,
......@@ -217,6 +263,13 @@ class OmniConfig(DynamoRuntimeConfig):
ring_degree: int = 1
cfg_parallel_size: int = 1
# TTS parameters
tts_max_instructions_length: int = 500
tts_max_new_tokens_min: int = 1
tts_max_new_tokens_max: int = 4096
tts_ref_audio_timeout: int = 15
tts_ref_audio_max_bytes: int = 50 * 1024 * 1024
def validate(self) -> None:
DynamoRuntimeConfig.validate(self)
if self.default_video_fps <= 0:
......
This diff is collapsed.
......@@ -159,8 +159,30 @@ class BaseOmniHandler(BaseWorkerHandler[Dict[str, Any], Dict[str, Any]]):
request, self.default_sampling_params, self.model_max_len
)
def _error_chunk(self, request_id: str, error_message: str) -> Dict[str, Any]:
"""Create an error chunk in OpenAI format."""
def _error_chunk(
self,
request_id: str,
error_message: str,
request_type=None,
) -> Dict[str, Any]:
"""Create an error response matching the expected protocol for the request type.
For AUDIO_GENERATION returns NvAudioSpeechResponse format.
For all other types returns OpenAI chat.completion.chunk format.
"""
from dynamo.common.utils.output_modalities import RequestType
if request_type == RequestType.AUDIO_GENERATION:
from dynamo.common.protocols.audio_protocol import NvAudioSpeechResponse
return NvAudioSpeechResponse(
id=request_id,
model=self.config.served_model_name or self.config.model,
status="failed",
created=int(time.time()),
error=error_message,
).model_dump()
return {
"id": request_id,
"created": int(time.time()),
......
......@@ -20,6 +20,10 @@ from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.vllm.health_check import VllmOmniHealthCheckPayload
from dynamo.vllm.main import setup_metrics_collection
from dynamo.vllm.omni.tts_utils import (
cleanup_dummy_tokenizer_for_tts,
ensure_dummy_tokenizer_for_tts,
)
from .args import OmniConfig, parse_omni_args
......@@ -69,6 +73,15 @@ async def init_omni(
if model_type is None:
model_type = ModelType.Images
# Audio/TTS models (e.g., Qwen3-TTS) don't ship a standard tokenizer.json,
# which causes register_model to fail when building the ModelDeploymentCard.
# Create a minimal placeholder so the Rust card loader doesn't bail,
# then delete it immediately after so vLLM-Omni's inference-time
# AutoTokenizer.from_pretrained() doesn't pick up the fake file.
dummy_tokenizer_paths = []
if "audio" in config.output_modalities:
dummy_tokenizer_paths = ensure_dummy_tokenizer_for_tts(config.model)
await register_model(
ModelInput.Text,
model_type,
......@@ -78,6 +91,9 @@ async def init_omni(
kv_cache_block_size=config.engine_args.block_size,
)
if dummy_tokenizer_paths:
cleanup_dummy_tokenizer_for_tts(dummy_tokenizer_paths)
logger.info("Starting to serve Omni worker endpoint...")
health_check_payload = (
......
......@@ -17,6 +17,7 @@ from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt
from dynamo._core import Context
from dynamo.common.multimodal import ImageLoader
from dynamo.common.protocols.audio_protocol import NvCreateAudioSpeechRequest
from dynamo.common.protocols.image_protocol import (
ImageData,
NvCreateImageRequest,
......@@ -36,6 +37,7 @@ from dynamo.common.utils.video_utils import (
parse_size,
)
from dynamo.llm.exceptions import EngineShutdown
from dynamo.vllm.omni.audio_handler import AudioGenerationHandler
from dynamo.vllm.omni.base_handler import BaseOmniHandler
logger = logging.getLogger(__name__)
......@@ -57,17 +59,19 @@ class EngineInputs:
image requests). None means use the default for the request type.
"""
prompt: OmniTextPrompt
prompt: Union[OmniTextPrompt, Dict[str, Any]]
sampling_params_list: list | None = None
request_type: RequestType = RequestType.CHAT_COMPLETION
fps: int = 0
speed: float = 1.0
response_format: str | None = None
class OmniHandler(BaseOmniHandler):
"""Unified handler for multi-stage pipelines using vLLM-Omni.
Handles text-to-text, text-to-image, and text-to-video generation.
Handles text-to-text, text-to-image, text-to-video, and text-to-audio generation.
Audio/TTS logic is delegated to AudioGenerationHandler via composition.
"""
def __init__(
......@@ -100,6 +104,14 @@ class OmniHandler(BaseOmniHandler):
self.media_output_http_url = media_output_http_url
self._image_loader = ImageLoader()
# Audio/TTS handler — composition, not inheritance.
self.audio = AudioGenerationHandler(
config=config,
engine_client=self.engine_client,
media_output_fs=media_output_fs,
media_output_http_url=media_output_http_url,
)
async def generate(
self, request: Dict[str, Any], context: Context
) -> AsyncGenerator[Dict[str, Any], None]:
......@@ -154,7 +166,14 @@ class OmniHandler(BaseOmniHandler):
}
return
inputs = self.build_engine_inputs(parsed_request, request_type, image=image)
try:
inputs = await self.build_engine_inputs(
parsed_request, request_type, image=image
)
except (ValueError, NotImplementedError) as e:
logger.error(f"Invalid request {request_id}: {e}")
yield self._error_chunk(request_id, str(e), request_type)
return
generate_kwargs: Dict[str, Any] = {
"prompt": inputs.prompt,
......@@ -207,17 +226,33 @@ class OmniHandler(BaseOmniHandler):
if chunk:
yield chunk
elif stage_output.final_output_type == "audio":
mm_output = stage_output.multimodal_output
if mm_output:
chunk = await self.audio.format_output(
mm_output,
request_id,
response_format=inputs.response_format,
request_type=inputs.request_type,
speed=inputs.speed,
)
if chunk:
yield chunk
except EngineShutdown:
logger.info(f"Request {request_id} aborted due to shutdown")
raise
except Exception as e:
logger.error(f"Error during generation for request {request_id}: {e}")
yield self._error_chunk(request_id, str(e))
yield self._error_chunk(request_id, str(e), inputs.request_type)
def build_engine_inputs(
async def build_engine_inputs(
self,
parsed_request: Union[
NvCreateImageRequest, NvCreateVideoRequest, Dict[str, Any]
NvCreateImageRequest,
NvCreateVideoRequest,
NvCreateAudioSpeechRequest,
Dict[str, Any],
],
request_type: RequestType,
image: PIL.Image.Image | None = None,
......@@ -226,7 +261,7 @@ class OmniHandler(BaseOmniHandler):
Args:
parsed_request: Output from parse_request_type -- a Pydantic model
for image/video requests, or a raw dict for chat completions.
for image/video/audio requests, or a raw dict for chat completions.
request_type: The RequestType determined by parse_request_type.
image: Pre-loaded PIL Image for I2V requests (from input_reference).
......@@ -242,29 +277,24 @@ class OmniHandler(BaseOmniHandler):
elif request_type == RequestType.VIDEO_GENERATION:
assert isinstance(parsed_request, NvCreateVideoRequest)
return self._engine_inputs_from_video(parsed_request, image=image)
elif request_type == RequestType.AUDIO_GENERATION:
raise NotImplementedError("Audio generation is not yet supported")
assert isinstance(parsed_request, NvCreateAudioSpeechRequest)
return await self.audio.build_engine_inputs(parsed_request)
raise ValueError(f"Unknown request type: {request_type}")
def _engine_inputs_from_chat(self, request: Dict[str, Any]) -> EngineInputs:
"""Build engine inputs from a chat completions request dict."""
# Chat completions request does not support extra_body passthrough
# So, we can't extract any diffusion related params from the raw_request
# It falls back to default sampling params
text_prompt = self._extract_text_prompt(request)
if text_prompt is None:
raise ValueError("No user message found in chat completion request")
prompt = OmniTextPrompt(prompt=text_prompt)
sampling_params_list = None
return EngineInputs(
prompt=prompt,
sampling_params_list=sampling_params_list,
sampling_params_list=None,
request_type=RequestType.CHAT_COMPLETION,
fps=0,
)
......@@ -276,9 +306,9 @@ class OmniHandler(BaseOmniHandler):
prompt = OmniTextPrompt(
prompt=req.prompt,
negative_prompt=nvext.negative_prompt
if nvext and nvext.negative_prompt
else None,
negative_prompt=(
nvext.negative_prompt if nvext and nvext.negative_prompt else None
),
)
sp = OmniDiffusionSamplingParams(
......@@ -331,9 +361,9 @@ class OmniHandler(BaseOmniHandler):
prompt = OmniTextPrompt(
prompt=req.prompt,
negative_prompt=nvext.negative_prompt
if nvext and nvext.negative_prompt
else None,
negative_prompt=(
nvext.negative_prompt if nvext and nvext.negative_prompt else None
),
)
if image is not None:
......@@ -577,9 +607,11 @@ class OmniHandler(BaseOmniHandler):
"role": "assistant",
"content": delta_text,
},
"finish_reason": normalize_finish_reason(output.finish_reason)
"finish_reason": (
normalize_finish_reason(output.finish_reason)
if output.finish_reason
else None,
else None
),
}
],
}
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""TTS/audio utility functions for the vLLM-Omni backend."""
import json
import logging
from pathlib import Path
logger = logging.getLogger(__name__)
def ensure_dummy_tokenizer_for_tts(model: str) -> list[Path]:
"""Create a minimal tokenizer.json for TTS models that lack one.
Audio/TTS models (e.g., Qwen3-TTS) use a custom speech tokenizer and don't
ship the standard tokenizer.json expected by the Rust ModelDeploymentCard
loader. This writes a placeholder so register_model doesn't fail.
Returns the list of created dummy paths so the caller can delete them
after registration (otherwise the fake tokenizer poisons vLLM-Omni's
inference-time AutoTokenizer.from_pretrained call).
This is a short-term workaround. The long-term fix is making TokenizerKind
optional in ModelDeploymentCard::from_repo_checkout().
"""
from huggingface_hub import scan_cache_dir
created: list[Path] = []
cache_info = scan_cache_dir()
for repo in cache_info.repos:
if repo.repo_id == model:
for revision in repo.revisions:
tokenizer_path = Path(revision.snapshot_path) / "tokenizer.json"
if not tokenizer_path.exists():
logger.warning(
"TTS model %s has no tokenizer.json; "
"creating a minimal placeholder at %s",
model,
tokenizer_path,
)
# Write a minimal but valid HF tokenizer JSON that
# tokenizers.TokenizerFast.from_file() can parse without
# crashing. The "model" key with type "BPE" is the
# minimum required structure.
minimal_tokenizer = {
"version": "1.0",
"model": {"type": "BPE", "vocab": {}, "merges": []},
}
tokenizer_path.write_text(json.dumps(minimal_tokenizer))
created.append(tokenizer_path)
return created
return created
def cleanup_dummy_tokenizer_for_tts(paths: list[Path]):
"""Remove dummy tokenizer.json files created by ensure_dummy_tokenizer_for_tts.
Must be called after register_model() completes so the fake tokenizer
doesn't interfere with vLLM-Omni's inference-time tokenizer loading
(AutoTokenizer.from_pretrained picks up our stub and crashes).
"""
for path in paths:
try:
path.unlink(missing_ok=True)
logger.info("Removed dummy tokenizer placeholder: %s", path)
except OSError as e:
logger.warning("Failed to remove dummy tokenizer %s: %s", path, e)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for AudioGenerationHandler."""
from unittest.mock import MagicMock
import pytest
try:
from dynamo.common.protocols.audio_protocol import NvCreateAudioSpeechRequest
from dynamo.common.utils.output_modalities import RequestType
from dynamo.vllm.omni.audio_handler import AudioGenerationHandler
except ImportError:
pytest.skip("vLLM omni dependencies not available", allow_module_level=True)
pytestmark = [
pytest.mark.unit,
pytest.mark.vllm,
pytest.mark.gpu_0,
pytest.mark.pre_merge,
]
def _make_audio_handler(**config_overrides):
"""Create an AudioGenerationHandler with mocked dependencies."""
config = MagicMock()
config.model = "test-tts-model"
config.served_model_name = None
config.tts_max_instructions_length = 500
config.tts_max_new_tokens_min = 1
config.tts_max_new_tokens_max = 4096
config.tts_ref_audio_timeout = 15
config.tts_ref_audio_max_bytes = 50 * 1024 * 1024
for k, v in config_overrides.items():
setattr(config, k, v)
engine_client = MagicMock()
engine_client.model_config.hf_config = MagicMock(spec=[])
handler = AudioGenerationHandler(
config=config,
engine_client=engine_client,
media_output_fs=None,
media_output_http_url=None,
)
return handler
class TestValidateTtsRequest:
"""Tests for _validate_tts_request."""
@pytest.mark.asyncio
async def test_empty_input_rejected(self):
handler = _make_audio_handler()
req = NvCreateAudioSpeechRequest(input=" ")
with pytest.raises(ValueError, match="Input text cannot be empty"):
await handler.build_engine_inputs(req)
def test_invalid_task_type_rejected_by_pydantic(self):
"""Pydantic Literal validation rejects invalid task_type at construction."""
with pytest.raises(Exception):
NvCreateAudioSpeechRequest(input="hello", task_type="Banana")
def test_valid_task_types_accepted(self):
handler = _make_audio_handler()
for task in ("CustomVoice", "VoiceDesign", "Base"):
req = NvCreateAudioSpeechRequest(input="hello", task_type=task)
if task == "VoiceDesign":
req.instructions = "cheerful"
elif task == "Base":
req.ref_audio = "data:audio/wav;base64,AAAA"
handler._validate_tts_request(req)
def test_voice_design_requires_instructions(self):
handler = _make_audio_handler()
req = NvCreateAudioSpeechRequest(input="hello", task_type="VoiceDesign")
with pytest.raises(ValueError, match="instructions"):
handler._validate_tts_request(req)
def test_base_requires_ref_audio(self):
handler = _make_audio_handler()
req = NvCreateAudioSpeechRequest(input="hello", task_type="Base")
with pytest.raises(ValueError, match="ref_audio"):
handler._validate_tts_request(req)
def test_ref_text_only_for_base(self):
handler = _make_audio_handler()
req = NvCreateAudioSpeechRequest(
input="hello", task_type="CustomVoice", ref_text="foo"
)
with pytest.raises(ValueError, match="only valid for Base"):
handler._validate_tts_request(req)
def test_instructions_length_enforced(self):
handler = _make_audio_handler(tts_max_instructions_length=10)
req = NvCreateAudioSpeechRequest(input="hello", instructions="x" * 11)
with pytest.raises(ValueError, match="Instructions too long"):
handler._validate_tts_request(req)
def test_max_new_tokens_range(self):
handler = _make_audio_handler()
req = NvCreateAudioSpeechRequest(input="hello", max_new_tokens=0)
with pytest.raises(ValueError, match="at least"):
handler._validate_tts_request(req)
req = NvCreateAudioSpeechRequest(input="hello", max_new_tokens=99999)
with pytest.raises(ValueError, match="cannot exceed"):
handler._validate_tts_request(req)
def test_invalid_voice_rejected_when_speakers_loaded(self):
handler = _make_audio_handler()
handler._tts_supported_speakers = {"vivian", "ryan"}
req = NvCreateAudioSpeechRequest(input="hello", voice="nonexistent")
with pytest.raises(ValueError, match="Invalid voice"):
handler._validate_tts_request(req)
def test_valid_voice_accepted(self):
handler = _make_audio_handler()
handler._tts_supported_speakers = {"vivian", "ryan"}
req = NvCreateAudioSpeechRequest(input="hello", voice="Vivian")
handler._validate_tts_request(req) # Should not raise
def test_invalid_language_rejected_when_languages_loaded(self):
handler = _make_audio_handler()
handler._tts_supported_languages = {"english", "chinese"}
req = NvCreateAudioSpeechRequest(input="hello", language="Klingon")
with pytest.raises(ValueError, match="Invalid language"):
handler._validate_tts_request(req)
def test_auto_language_always_accepted(self):
handler = _make_audio_handler()
handler._tts_supported_languages = {"english"}
req = NvCreateAudioSpeechRequest(input="hello", language="Auto")
handler._validate_tts_request(req) # Should not raise
class TestIsTtsModel:
"""Tests for _is_tts_model detection."""
def test_qwen3_tts_detected(self):
handler = _make_audio_handler()
stage = MagicMock()
stage.model_stage = "qwen3_tts"
handler.engine_client.stage_list = [stage]
assert handler._is_tts_model() is True
def test_non_tts_model(self):
handler = _make_audio_handler()
stage = MagicMock()
stage.model_stage = "diffusion"
handler.engine_client.stage_list = [stage]
assert handler._is_tts_model() is False
def test_no_stage_list(self):
handler = _make_audio_handler()
handler.engine_client.stage_list = None
assert handler._is_tts_model() is False
class TestEngineInputsFromAudio:
"""Tests for build_engine_inputs."""
@pytest.mark.asyncio
async def test_generic_path_for_non_tts(self):
"""Non-TTS model gets plain text prompt."""
handler = _make_audio_handler()
stage = MagicMock()
stage.model_stage = "diffusion"
handler.engine_client.stage_list = [stage]
req = NvCreateAudioSpeechRequest(input="Hello world")
inputs = await handler.build_engine_inputs(req)
assert inputs.request_type == RequestType.AUDIO_GENERATION
assert inputs.prompt["prompt"] == "Hello world"
assert inputs.sampling_params_list is None
@pytest.mark.asyncio
async def test_empty_input_rejected(self):
handler = _make_audio_handler()
req = NvCreateAudioSpeechRequest(input=" ")
with pytest.raises(ValueError, match="empty"):
await handler.build_engine_inputs(req)
@pytest.mark.asyncio
async def test_speed_propagated(self):
"""Speed from request is stored in EngineInputs."""
handler = _make_audio_handler()
handler.engine_client.stage_list = None # non-TTS path
req = NvCreateAudioSpeechRequest(input="hello", speed=2.0)
inputs = await handler.build_engine_inputs(req)
assert inputs.speed == 2.0
class TestExtractAudioTensor:
"""Tests for _extract_audio_tensor."""
def test_extracts_from_audio_key(self):
import numpy as np
handler = _make_audio_handler()
mm = {"audio": np.array([0.1, -0.2, 0.3], dtype=np.float32), "sr": 24000}
audio_np, sr = handler._extract_audio_tensor(mm)
assert sr == 24000
assert len(audio_np) == 3
def test_extracts_from_model_outputs_key(self):
import numpy as np
handler = _make_audio_handler()
mm = {
"model_outputs": np.array([0.5, -0.5], dtype=np.float32),
"sr": 16000,
}
audio_np, sr = handler._extract_audio_tensor(mm)
assert sr == 16000
assert len(audio_np) == 2
def test_missing_audio_raises(self):
handler = _make_audio_handler()
with pytest.raises(ValueError, match="No audio data"):
handler._extract_audio_tensor({"sr": 24000})
def test_squeezes_extra_dims(self):
import numpy as np
handler = _make_audio_handler()
mm = {"audio": np.array([[0.1, 0.2, 0.3]], dtype=np.float32), "sr": 24000}
audio_np, _ = handler._extract_audio_tensor(mm)
assert audio_np.ndim == 1
class TestEncodeAudio:
"""Tests for _encode_audio."""
def test_wav_encoding(self):
import numpy as np
handler = _make_audio_handler()
audio = np.zeros(2400, dtype=np.float32)
audio_bytes, media_type = handler._encode_audio(audio, 24000, "wav")
assert media_type == "audio/wav"
assert len(audio_bytes) > 0
assert audio_bytes[:4] == b"RIFF" # WAV header
def test_unsupported_format_falls_back_to_wav(self):
import numpy as np
handler = _make_audio_handler()
audio = np.zeros(100, dtype=np.float32)
_, media_type = handler._encode_audio(audio, 24000, "xyz")
assert media_type == "audio/wav"
def test_default_format_is_wav(self):
import numpy as np
handler = _make_audio_handler()
audio = np.zeros(100, dtype=np.float32)
_, media_type = handler._encode_audio(audio, 24000)
assert media_type == "audio/wav"
class TestFormatAudioChunk:
"""Tests for format_output."""
@pytest.mark.asyncio
async def test_empty_mm_output_returns_error(self):
handler = _make_audio_handler()
result = await handler.format_output({}, "req-1")
assert result["status"] == "failed"
assert "No audio generated" in result["error"]
@pytest.mark.asyncio
async def test_successful_generation(self):
import numpy as np
handler = _make_audio_handler()
mm = {"audio": np.random.randn(4800).astype(np.float32), "sr": 24000}
result = await handler.format_output(mm, "req-1")
assert result["status"] == "completed"
assert result["object"] == "audio.speech"
assert len(result["data"]) == 1
assert result["data"][0]["b64_json"] is not None
......@@ -8,6 +8,7 @@ import pytest
try:
from PIL import Image
from dynamo.common.protocols.audio_protocol import NvCreateAudioSpeechRequest
from dynamo.common.protocols.image_protocol import NvCreateImageRequest
from dynamo.common.protocols.video_protocol import NvCreateVideoRequest, VideoNvExt
from dynamo.common.utils.output_modalities import RequestType
......@@ -86,20 +87,22 @@ class TestPrepareImageOutput:
class TestBuildEngineInputs:
def test_chat_completion(self):
@pytest.mark.asyncio
async def test_chat_completion(self):
"""Chat request extracts text prompt with no sampling params."""
handler = _make_handler()
raw = {"messages": [{"role": "user", "content": "hello"}]}
inputs = handler.build_engine_inputs(raw, RequestType.CHAT_COMPLETION)
inputs = await handler.build_engine_inputs(raw, RequestType.CHAT_COMPLETION)
assert inputs.request_type == RequestType.CHAT_COMPLETION
assert inputs.prompt["prompt"] == "hello"
assert inputs.sampling_params_list is None
def test_image_generation(self):
@pytest.mark.asyncio
async def test_image_generation(self):
"""Image request parses prompt, size, and creates diffusion sampling params."""
handler = _make_handler()
req = NvCreateImageRequest(prompt="a cat", size="512x512")
inputs = handler.build_engine_inputs(req, RequestType.IMAGE_GENERATION)
inputs = await handler.build_engine_inputs(req, RequestType.IMAGE_GENERATION)
assert inputs.request_type == RequestType.IMAGE_GENERATION
assert inputs.prompt["prompt"] == "a cat"
assert len(inputs.sampling_params_list) == 1
......@@ -107,22 +110,38 @@ class TestBuildEngineInputs:
assert sp.height == 512
assert sp.width == 512
def test_video_generation(self):
@pytest.mark.asyncio
async def test_video_generation(self):
"""Video request parses prompt, size, seconds, and sets fps."""
handler = _make_handler()
req = NvCreateVideoRequest(
prompt="a drone", model="test", size="832x480", seconds=2
)
inputs = handler.build_engine_inputs(req, RequestType.VIDEO_GENERATION)
inputs = await handler.build_engine_inputs(req, RequestType.VIDEO_GENERATION)
assert inputs.request_type == RequestType.VIDEO_GENERATION
assert inputs.prompt["prompt"] == "a drone"
assert inputs.fps > 0
def test_audio_not_implemented(self):
"""Audio generation raises NotImplementedError."""
@pytest.mark.asyncio
async def test_audio_generation_delegates_toaudio(self):
"""Audio request delegates to audio."""
handler = _make_handler()
with pytest.raises(NotImplementedError):
handler.build_engine_inputs({}, RequestType.AUDIO_GENERATION)
expected = EngineInputs(
prompt={"prompt": "Hello world"},
request_type=RequestType.AUDIO_GENERATION,
)
async def mock_engine_inputs(req):
return expected
handler.audio = MagicMock()
handler.audio.build_engine_inputs = mock_engine_inputs
inputs = await handler.build_engine_inputs(
NvCreateAudioSpeechRequest(input="Hello world"),
RequestType.AUDIO_GENERATION,
)
assert inputs.request_type == RequestType.AUDIO_GENERATION
assert inputs.prompt["prompt"] == "Hello world"
class TestFormatTextChunk:
......@@ -254,7 +273,8 @@ class TestFormatVideoChunk:
class TestI2VEngineInputs:
"""Tests for image-to-video: multi_modal_data attachment, I2V nvext params, and protocol fields."""
def test_t2v_no_multi_modal_data_and_i2v_attaches_image(self):
@pytest.mark.asyncio
async def test_t2v_no_multi_modal_data_and_i2v_attaches_image(self):
"""T2V has no multi_modal_data; I2V attaches image to prompt."""
handler = _make_handler()
req = NvCreateVideoRequest(
......@@ -262,15 +282,18 @@ class TestI2VEngineInputs:
)
# T2V: no image
t2v = handler.build_engine_inputs(req, RequestType.VIDEO_GENERATION)
t2v = await handler.build_engine_inputs(req, RequestType.VIDEO_GENERATION)
assert "multi_modal_data" not in t2v.prompt
# I2V: image attached
img = Image.new("RGB", (64, 64), color="red")
i2v = handler.build_engine_inputs(req, RequestType.VIDEO_GENERATION, image=img)
i2v = await handler.build_engine_inputs(
req, RequestType.VIDEO_GENERATION, image=img
)
assert i2v.prompt["multi_modal_data"]["image"] is img
def test_i2v_nvext_params_on_sampling_params(self):
@pytest.mark.asyncio
async def test_i2v_nvext_params_on_sampling_params(self):
"""boundary_ratio and guidance_scale_2 are forwarded to sampling params."""
handler = _make_handler()
req = NvCreateVideoRequest(
......@@ -281,9 +304,8 @@ class TestI2VEngineInputs:
boundary_ratio=0.875, guidance_scale_2=1.0, num_inference_steps=40
),
)
sp = handler.build_engine_inputs(
req, RequestType.VIDEO_GENERATION
).sampling_params_list[0]
result = await handler.build_engine_inputs(req, RequestType.VIDEO_GENERATION)
sp = result.sampling_params_list[0]
assert sp.boundary_ratio == 0.875
assert sp.guidance_scale_2 == 1.0
assert sp.num_inference_steps == 40
......
......@@ -4,7 +4,7 @@
title: vLLM-Omni
---
Dynamo supports multimodal generation through the [vLLM-Omni](https://github.com/vllm-project/vllm-omni) backend. This integration exposes text-to-text, text-to-image, and text-to-video capabilities via OpenAI-compatible API endpoints.
Dynamo supports multimodal generation through the [vLLM-Omni](https://github.com/vllm-project/vllm-omni) backend. This integration exposes text-to-text, text-to-image, text-to-video, and text-to-audio (TTS) capabilities via OpenAI-compatible API endpoints.
## Prerequisites
......@@ -26,8 +26,9 @@ pip install git+https://github.com/vllm-project/vllm-omni.git@v0.16.0rc1
| Text-to-Image | `/v1/chat/completions`, `/v1/images/generations` | `image` |
| Text-to-Video | `/v1/videos` | `video` |
| Image-to-Video | `/v1/videos` | `video` |
| Text-to-Audio (TTS) | `/v1/audio/speech` | `audio` |
The `--output-modalities` flag determines which endpoint(s) the worker registers. When set to `image`, both `/v1/chat/completions` (returns inline base64 images) and `/v1/images/generations` are available. When set to `video`, the worker serves `/v1/videos`.
The `--output-modalities` flag determines which endpoint(s) the worker registers. When set to `image`, both `/v1/chat/completions` (returns inline base64 images) and `/v1/images/generations` are available. When set to `video`, the worker serves `/v1/videos`. When set to `audio`, the worker serves `/v1/audio/speech`.
## Tested Models
......@@ -37,6 +38,7 @@ The `--output-modalities` flag determines which endpoint(s) the worker registers
| Text-to-Image | `Qwen/Qwen-Image`, `AIDC-AI/Ovis-Image-7B` |
| Text-to-Video | `Wan-AI/Wan2.1-T2V-1.3B-Diffusers`, `Wan-AI/Wan2.2-T2V-A14B-Diffusers` |
| Image-to-Video | `Wan-AI/Wan2.2-TI2V-5B-Diffusers`, `Wan-AI/Wan2.2-I2V-A14B-Diffusers` |
| Text-to-Audio (TTS) | `Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice`, `Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign` |
To run a non-default model, pass `--model` to any launch script:
......@@ -203,13 +205,80 @@ The `input_reference` field accepts:
The I2V-specific `nvext` fields (`boundary_ratio`, `guidance_scale_2`) control the dual-expert MoE denoising schedule in Wan2.x models. See [Wan2.2-I2V model card](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers) for details.
## Text-to-Audio (TTS)
Launch using the provided script with `Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice`:
```bash
bash examples/backends/vllm/launch/agg_omni_audio.sh
```
### CustomVoice (predefined speakers)
```bash
curl -X POST http://localhost:8000/v1/audio/speech \
-H "Content-Type: application/json" \
-d '{
"input": "Hello, how are you?",
"voice": "vivian",
"language": "English"
}' --output output.wav
```
### CustomVoice with style instructions
```bash
curl -X POST http://localhost:8000/v1/audio/speech \
-H "Content-Type: application/json" \
-d '{
"input": "I am so excited!",
"voice": "vivian",
"instructions": "Speak with great enthusiasm"
}' --output excited.wav
```
### VoiceDesign (describe a voice)
```bash
bash examples/backends/vllm/launch/agg_omni_audio.sh --model Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign
curl -X POST http://localhost:8000/v1/audio/speech \
-H "Content-Type: application/json" \
-d '{
"input": "Hello world",
"task_type": "VoiceDesign",
"instructions": "A warm, friendly female voice with a gentle tone"
}' --output voicedesign.wav
```
### Parameters
The `/v1/audio/speech` endpoint follows the [vLLM-Omni](https://docs.vllm.ai/projects/vllm-omni/en/latest/user_guide/examples/online_serving/qwen3_tts/) API format. All TTS-specific parameters are top-level fields:
| Field | Description | Default |
|---|---|---|
| `input` | Text to synthesize (required) | -- |
| `model` | TTS model name | auto-detected |
| `voice` | Speaker name (e.g., vivian, ryan). Validated against model config. | Vivian |
| `response_format` | Audio format: wav, mp3, pcm, flac, aac, opus | wav |
| `speed` | Speed factor (0.25-4.0) | 1.0 |
| `task_type` | CustomVoice, VoiceDesign, or Base (Qwen3-TTS) | CustomVoice |
| `language` | Language code. Validated against model config. | Auto |
| `instructions` | Voice style/emotion description. Required for VoiceDesign. | -- |
| `ref_audio` | Reference audio URL or base64 data URI. Required for Base. | -- |
| `ref_text` | Transcript of reference audio (Base task) | -- |
| `max_new_tokens` | Maximum tokens to generate (1-4096) | 2048 |
Available voices and languages are loaded dynamically from the model's `config.json` at startup. Non-Qwen3-TTS audio models (e.g., MiMo-Audio) use a generic text prompt and ignore TTS-specific parameters.
## CLI Reference
The omni backend uses a dedicated entrypoint: `python -m dynamo.vllm.omni`.
| Flag | Description |
|---|---|
| `--output-modalities <modality>` | Output modality: `text`, `image`, or `video` |
| `--omni` | Enable the vLLM-Omni orchestrator (required for all omni workloads) |
| `--output-modalities <modality>` | Output modality: `text`, `image`, `video`, or `audio` |
| `--stage-configs-path <path>` | Path to stage config YAML (optional; vLLM-Omni uses model defaults if omitted) |
| `--boundary-ratio <float>` | MoE expert switching boundary (default: 0.875) |
| `--flow-shift <float>` | Scheduler flow_shift (5.0 for 720p, 12.0 for 480p) |
......@@ -231,7 +300,7 @@ The omni backend uses a dedicated entrypoint: `python -m dynamo.vllm.omni`.
## Storage Configuration
Generated images and videos are stored via [fsspec](https://filesystem-spec.readthedocs.io/), which supports local filesystems, S3, GCS, and Azure Blob.
Generated images, videos, and audio files are stored via [fsspec](https://filesystem-spec.readthedocs.io/), which supports local filesystems, S3, GCS, and Azure Blob.
By default, media is written to the local filesystem at `file:///tmp/dynamo_media`. To use cloud storage:
......@@ -254,3 +323,5 @@ Omni pipelines are configured via YAML stage configs. See [`examples/backends/vl
- Image input is supported only for I2V via `input_reference` in `/v1/videos`. Other endpoints accept text prompts only.
- KV cache events are not published for omni workers.
- Each worker supports a single output modality at a time.
- Audio: streaming (`stream: true`) is not yet supported.
- Audio: Base task (voice cloning) is not yet supported.
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
set -e
trap 'echo Cleaning up...; kill 0' EXIT
SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
source "$SCRIPT_DIR/../../../common/launch_utils.sh"
MODEL="Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"
# Parse command line arguments
EXTRA_ARGS=()
while [[ $# -gt 0 ]]; do
case $1 in
--model)
MODEL="$2"
shift 2
;;
*)
EXTRA_ARGS+=("$1")
shift
;;
esac
done
HTTP_PORT="${DYN_HTTP_PORT:-8000}"
print_launch_banner --no-curl "Launching vLLM-Omni Audio/TTS (1 GPU)" "$MODEL" "$HTTP_PORT"
print_curl_footer <<CURL
curl -X POST http://localhost:${HTTP_PORT}/v1/audio/speech \\
-H 'Content-Type: application/json' \\
-d '{
"input": "Hey, this is generated using Dynamo!",
"model": "${MODEL}",
"voice": "vivian",
"language": "English"
}' \\
-o dynamo-audio.wav
CURL
python -m dynamo.frontend &
FRONTEND_PID=$!
sleep 2
echo "Starting Omni Audio worker..."
DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT:-8081} \
python -m dynamo.vllm.omni \
--model "$MODEL" \
--output-modalities audio \
--media-output-fs-url file:///tmp/dynamo_media \
--trust-remote-code \
--enforce-eager \
"${EXTRA_ARGS[@]}" &
# Exit on first worker failure; kill 0 in the EXIT trap tears down the rest
wait_any_exit
......@@ -19,6 +19,7 @@ use crate::protocols::openai::ParsingOptions;
use crate::types::{
generic::tensor::TensorStreamingEngine,
openai::{
audios::OpenAIAudiosStreamingEngine,
chat_completions::OpenAIChatCompletionsStreamingEngine,
completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine,
images::OpenAIImagesStreamingEngine, videos::OpenAIVideosStreamingEngine,
......@@ -149,6 +150,13 @@ impl Model {
.any(|entry| entry.value().has_videos_engine())
}
/// Check if any WorkerSet has an audios engine.
pub fn has_audios_engine(&self) -> bool {
self.worker_sets
.iter()
.any(|entry| entry.value().has_audios_engine())
}
/// Whether this model should be visible in /v1/models.
pub fn is_displayable(&self) -> bool {
let has_serving_engine = |ws: &WorkerSet| {
......@@ -158,6 +166,7 @@ impl Model {
|| ws.has_images_engine()
|| ws.has_tensor_engine()
|| ws.has_videos_engine()
|| ws.has_audios_engine()
};
let has_any_serving_engine = self.worker_sets.iter().any(|entry| {
......@@ -207,6 +216,11 @@ impl Model {
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone()))
}
pub fn get_audios_engine(&self) -> Result<OpenAIAudiosStreamingEngine, ModelManagerError> {
self.select_worker_set_with(|ws| ws.audios_engine.clone())
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone()))
}
pub fn get_tensor_engine(&self) -> Result<TensorStreamingEngine, ModelManagerError> {
self.select_worker_set_with(|ws| ws.tensor_engine.clone())
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone()))
......
......@@ -24,6 +24,7 @@ use crate::{
types::{
generic::tensor::TensorStreamingEngine,
openai::{
audios::OpenAIAudiosStreamingEngine,
chat_completions::OpenAIChatCompletionsStreamingEngine,
completions::OpenAICompletionsStreamingEngine,
embeddings::OpenAIEmbeddingsStreamingEngine, images::OpenAIImagesStreamingEngine,
......@@ -290,6 +291,16 @@ impl ModelManager {
.get_videos_engine()
}
pub fn get_audios_engine(
&self,
model: &str,
) -> Result<OpenAIAudiosStreamingEngine, ModelManagerError> {
self.models
.get(model)
.ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))?
.get_audios_engine()
}
// -- Combined engine + parsing options (atomically from one WorkerSet) --
pub fn get_chat_completions_engine_with_parsing(
......@@ -456,6 +467,27 @@ impl ModelManager {
Ok(())
}
pub fn add_audios_model(
&self,
model: &str,
card_checksum: &str,
engine: OpenAIAudiosStreamingEngine,
) -> Result<(), ModelManagerError> {
let model_entry = self.get_or_create_model(model);
if model_entry.has_audios_engine() {
return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
}
let namespace = format!("__local_audios_{}", model);
let mut ws = WorkerSet::new(
namespace.clone(),
card_checksum.to_string(),
ModelDeploymentCard::default(),
);
ws.audios_engine = Some(engine);
model_entry.add_worker_set(namespace, Arc::new(ws));
Ok(())
}
pub fn add_prefill_model(
&self,
model: &str,
......
......@@ -34,6 +34,7 @@ use crate::{
protocols::{
common::llm_backend::EmbeddingsEngineOutput,
openai::{
audios::{NvAudioSpeechResponse, NvCreateAudioSpeechRequest},
chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
},
......@@ -685,7 +686,19 @@ impl ModelWatcher {
worker_set.videos_engine = Some(Arc::new(videos_router));
}
// TODO: add audio models support
if card.model_type.supports_audios() {
let audios_router = PushRouter::<
NvCreateAudioSpeechRequest,
Annotated<NvAudioSpeechResponse>,
>::from_client_with_threshold(
client.clone(),
self.router_config.router_mode,
None,
None,
)
.await?;
worker_set.audios_engine = Some(Arc::new(audios_router));
}
} else if card.model_input == ModelInput::Text && card.model_type.supports_chat() {
// Case: Text + Chat (pure text-to-text, no diffusion)
let push_router = PushRouter::<
......
......@@ -16,6 +16,7 @@ use crate::{
types::{
generic::tensor::TensorStreamingEngine,
openai::{
audios::OpenAIAudiosStreamingEngine,
chat_completions::OpenAIChatCompletionsStreamingEngine,
completions::OpenAICompletionsStreamingEngine,
embeddings::OpenAIEmbeddingsStreamingEngine, images::OpenAIImagesStreamingEngine,
......@@ -41,6 +42,7 @@ pub struct WorkerSet {
pub(crate) embeddings_engine: Option<OpenAIEmbeddingsStreamingEngine>,
pub(crate) images_engine: Option<OpenAIImagesStreamingEngine>,
pub(crate) videos_engine: Option<OpenAIVideosStreamingEngine>,
pub(crate) audios_engine: Option<OpenAIAudiosStreamingEngine>,
pub(crate) tensor_engine: Option<TensorStreamingEngine>,
/// KV router for this set's workers (if KV mode)
......@@ -65,6 +67,7 @@ impl WorkerSet {
embeddings_engine: None,
images_engine: None,
videos_engine: None,
audios_engine: None,
tensor_engine: None,
kv_router: None,
worker_monitor: None,
......@@ -104,6 +107,10 @@ impl WorkerSet {
self.videos_engine.is_some()
}
pub fn has_audios_engine(&self) -> bool {
self.audios_engine.is_some()
}
pub fn has_tensor_engine(&self) -> bool {
self.tensor_engine.is_some()
}
......@@ -119,6 +126,7 @@ impl WorkerSet {
&& !self.has_embeddings_engine()
&& !self.has_images_engine()
&& !self.has_videos_engine()
&& !self.has_audios_engine()
&& !self.has_tensor_engine()
}
......
......@@ -304,6 +304,9 @@ pub enum Endpoint {
/// OAI Videos
Videos,
/// OAI Audio Speech
Audios,
/// OAI Responses
Responses,
......@@ -1026,6 +1029,7 @@ impl std::fmt::Display for Endpoint {
Endpoint::Embeddings => write!(f, "embeddings"),
Endpoint::Images => write!(f, "images"),
Endpoint::Videos => write!(f, "videos"),
Endpoint::Audios => write!(f, "audios"),
Endpoint::Responses => write!(f, "responses"),
Endpoint::AnthropicMessages => write!(f, "anthropic_messages"),
Endpoint::Tensor => write!(f, "tensor"),
......@@ -1041,6 +1045,7 @@ impl Endpoint {
Endpoint::Embeddings => "embeddings",
Endpoint::Images => "images",
Endpoint::Videos => "videos",
Endpoint::Audios => "audios",
Endpoint::Responses => "responses",
Endpoint::AnthropicMessages => "anthropic_messages",
Endpoint::Tensor => "tensor",
......
......@@ -46,6 +46,7 @@ use crate::engines::ValidateRequest;
use crate::protocols::openai::chat_completions::aggregator::ChatCompletionAggregator;
use crate::protocols::openai::nvext::apply_header_routing_overrides;
use crate::protocols::openai::{
audios::{NvAudioSpeechResponse, NvCreateAudioSpeechRequest},
chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionResponse,
NvCreateChatCompletionStreamResponse,
......@@ -2201,6 +2202,113 @@ pub fn videos_router(
(vec![doc, stream_doc], router)
}
async fn audio_speech(
State(state): State<Arc<service_v2::State>>,
headers: HeaderMap,
Json(request): Json<NvCreateAudioSpeechRequest>,
) -> Result<Response, ErrorResponse> {
// return a 503 if the service is not ready
check_ready(&state)?;
let response_format = request.response_format.clone();
let request_id = get_or_create_request_id(request.user.as_deref(), &headers);
let request = Context::with_id(request, request_id);
let request_id = request.id().to_string();
let streaming = false;
// model is optional in the request; fall back to the first registered model
let model = request.model.clone().unwrap_or_else(|| {
state
.manager()
.model_display_names()
.into_iter()
.next()
.unwrap_or_default()
});
let http_queue_guard = state.metrics_clone().create_http_queue_guard(&model);
let engine = state
.manager()
.get_audios_engine(&model)
.map_err(|_| ErrorMessage::model_not_found())?;
let mut inflight =
state
.metrics_clone()
.create_inflight_guard(&model, Endpoint::Audios, streaming);
let mut response_collector = state.metrics_clone().create_response_collector(&model);
let stream = engine
.generate(request)
.await
.map_err(|e| ErrorMessage::from_anyhow(e, "Failed to generate audio"))?;
let mut http_queue_guard = Some(http_queue_guard);
let stream = stream.inspect(move |response| {
process_response_and_observe_metrics(
response,
&mut response_collector,
&mut http_queue_guard,
);
});
let response = NvAudioSpeechResponse::from_annotated_stream(stream)
.await
.map_err(|e| {
tracing::error!("Failed to fold audio stream for {}: {:?}", request_id, e);
ErrorMessage::internal_server_error("Failed to fold audio stream")
})?;
// Check for failure before marking success
if response.status == "failed" {
return Ok((axum::http::StatusCode::BAD_REQUEST, Json(response)).into_response());
}
inflight.mark_ok();
// If response contains b64_json audio data, decode and return as binary
// (matching OpenAI/vLLM-Omni behavior: curl --output file.wav)
if let Some(first) = response.data.first()
&& let Some(b64) = &first.b64_json
&& let Ok(audio_bytes) = base64::engine::general_purpose::STANDARD.decode(b64)
{
let content_type = match response_format.as_deref().unwrap_or("wav") {
"mp3" => "audio/mpeg",
"flac" => "audio/flac",
"pcm" => "audio/pcm",
"aac" => "audio/aac",
"opus" => "audio/ogg; codecs=opus",
_ => "audio/wav",
};
return Ok(Response::builder()
.header("content-type", content_type)
.body(axum::body::Body::from(audio_bytes))
.unwrap());
}
// Fallback: return JSON (url format responses)
Ok(Json(response).into_response())
}
/// Create an Axum [`Router`] for the Audio Speech endpoint
/// Default path is `/v1/audio/speech`
pub fn audios_router(
state: Arc<service_v2::State>,
path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
let path = path.unwrap_or("/v1/audio/speech".to_string());
let doc = RouteDoc::new(axum::http::Method::POST, &path);
let router = Router::new()
.route(&path, post(audio_speech))
.layer(middleware::from_fn(smart_json_error_middleware))
.layer(axum::extract::DefaultBodyLimit::max(get_body_limit()))
.with_state(state);
(vec![doc], router)
}
#[cfg(test)]
mod tests {
......
......@@ -56,6 +56,7 @@ struct StateFlags {
embeddings_endpoints_enabled: AtomicBool,
images_endpoints_enabled: AtomicBool,
videos_endpoints_enabled: AtomicBool,
audios_endpoints_enabled: AtomicBool,
responses_endpoints_enabled: AtomicBool,
anthropic_endpoints_enabled: AtomicBool,
}
......@@ -68,8 +69,7 @@ impl StateFlags {
EndpointType::Embedding => self.embeddings_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::Images => self.images_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::Videos => self.videos_endpoints_enabled.load(Ordering::Relaxed),
// TODO: add audios_endpoints_enabled flag
EndpointType::Audios => false,
EndpointType::Audios => self.audios_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::Responses => self.responses_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::AnthropicMessages => {
self.anthropic_endpoints_enabled.load(Ordering::Relaxed)
......@@ -94,8 +94,9 @@ impl StateFlags {
EndpointType::Videos => self
.videos_endpoints_enabled
.store(enabled, Ordering::Relaxed),
// TODO: add audios_endpoints_enabled flag
EndpointType::Audios => {}
EndpointType::Audios => self
.audios_endpoints_enabled
.store(enabled, Ordering::Relaxed),
EndpointType::Responses => self
.responses_endpoints_enabled
.store(enabled, Ordering::Relaxed),
......@@ -122,6 +123,7 @@ impl State {
embeddings_endpoints_enabled: AtomicBool::new(false),
images_endpoints_enabled: AtomicBool::new(false),
videos_endpoints_enabled: AtomicBool::new(false),
audios_endpoints_enabled: AtomicBool::new(false),
responses_endpoints_enabled: AtomicBool::new(false),
anthropic_endpoints_enabled: AtomicBool::new(false),
},
......@@ -587,6 +589,7 @@ impl HttpServiceConfigBuilder {
super::openai::embeddings_router(state.clone(), var(HTTP_SVC_EMB_PATH_ENV).ok());
let (images_docs, images_route) = super::openai::images_router(state.clone(), None);
let (videos_docs, videos_route) = super::openai::videos_router(state.clone(), None);
let (audios_docs, audios_route) = super::openai::audios_router(state.clone(), None);
let (responses_docs, responses_route) = super::openai::responses_router(
state.clone(),
request_template.clone(),
......@@ -598,6 +601,7 @@ impl HttpServiceConfigBuilder {
endpoint_routes.insert(EndpointType::Embedding, (embed_docs, embed_route));
endpoint_routes.insert(EndpointType::Images, (images_docs, images_route));
endpoint_routes.insert(EndpointType::Videos, (videos_docs, videos_route));
endpoint_routes.insert(EndpointType::Audios, (audios_docs, audios_route));
endpoint_routes.insert(EndpointType::Responses, (responses_docs, responses_route));
if env_is_truthy(env_llm::DYN_ENABLE_ANTHROPIC_API) {
......
......@@ -11,6 +11,7 @@ use super::{
use crate::protocols::openai::common_ext::CommonExtProvider;
use crate::types::TokenIdType;
pub mod audios;
pub mod chat_completions;
pub mod common_ext;
pub mod completions;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment