Commit 10152d21, authored by Patrick von Platen, committed by GitHub

[Realtime API] Adds minimal realtime API based on websockets (#33187)


Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
parent 1a7894db
@@ -59,6 +59,8 @@ We currently support the following OpenAI APIs:
    - Only applicable to [Automatic Speech Recognition (ASR) models](../models/supported_models.md#transcription).
- [Translation API](#translations-api) (`/v1/audio/translations`)
    - Only applicable to [Automatic Speech Recognition (ASR) models](../models/supported_models.md#transcription).
- [Realtime API](#realtime-api) (`/v1/realtime`)
    - Only applicable to [Automatic Speech Recognition (ASR) models](../models/supported_models.md#transcription).
In addition, we have the following custom APIs:
@@ -567,6 +569,96 @@ The following extra parameters are supported:
--8<-- "vllm/entrypoints/openai/protocol.py:translation-extra-params"
``` ```
### Realtime API
The Realtime API provides WebSocket-based streaming audio transcription, allowing real-time speech-to-text as audio is being recorded.
!!! note
To use the Realtime API, please install with extra audio dependencies using `uv pip install vllm[audio]`.
#### Audio Format
Audio must be sent as base64-encoded PCM16 audio at 16kHz sample rate, mono channel.
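As a rough illustration of this format, the sketch below converts float samples in `[-1.0, 1.0]` to PCM16 and base64-encodes them using only the standard library. The helper name `floats_to_pcm16_base64` is hypothetical, the samples are assumed to be mono at 16 kHz already, and little-endian byte order is an assumption that matches `np.int16` on common x86/ARM hosts.

```python
import base64
import struct


def floats_to_pcm16_base64(samples: list[float]) -> str:
    """Encode mono float samples in [-1.0, 1.0] as base64 PCM16 (little-endian)."""
    # Clip first so out-of-range samples saturate instead of overflowing.
    clipped = [max(-1.0, min(1.0, s)) for s in samples]
    ints = [int(s * 32767) for s in clipped]
    # "<h" is a signed 16-bit little-endian integer.
    pcm16 = struct.pack(f"<{len(ints)}h", *ints)
    return base64.b64encode(pcm16).decode("ascii")


# Four samples become 8 bytes of PCM16 before base64 encoding.
payload = floats_to_pcm16_base64([0.0, 0.5, -1.0, 2.0])
```

The resulting string is what would go into the `audio` field of an `input_audio_buffer.append` event.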
#### Protocol Overview
1. Client connects to `ws://host/v1/realtime`
2. Server sends `session.created` event
3. Client optionally sends `session.update` with model/params
4. Client sends `input_audio_buffer.commit` when ready
5. Client sends `input_audio_buffer.append` events with base64 PCM16 chunks
6. Server sends `transcription.delta` events with incremental text
7. Server sends `transcription.done` with final text + usage
8. Repeat from step 5 for next utterance
9. Optionally, client sends `input_audio_buffer.commit` with `"final": true` to signal that audio input is finished. This is useful when streaming audio files.
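The client side of the steps above can be sketched as an ordered list of JSON messages. This is only an illustration: `build_client_messages` is a hypothetical helper, `"my-model"` is a placeholder model name, and the sketch covers a single audio stream that ends with a final commit.

```python
import base64
import json


def build_client_messages(model: str, pcm16_chunks: list[bytes]) -> list[str]:
    """Return the JSON messages a client would send for one audio stream."""
    messages: list[dict] = [
        {"type": "session.update", "model": model},  # step 3
        {"type": "input_audio_buffer.commit"},  # step 4
    ]
    for chunk in pcm16_chunks:  # step 5
        messages.append(
            {
                "type": "input_audio_buffer.append",
                "audio": base64.b64encode(chunk).decode("ascii"),
            }
        )
    # Step 9: tell the server that no more audio will follow.
    messages.append({"type": "input_audio_buffer.commit", "final": True})
    return [json.dumps(m) for m in messages]


msgs = build_client_messages("my-model", [b"\x00\x00", b"\x01\x00"])
```

Each string in `msgs` would be passed to the WebSocket's send call in order.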
#### Client → Server Events
| Event | Description |
|-------|-------------|
| `input_audio_buffer.append` | Send base64-encoded audio chunk: `{"type": "input_audio_buffer.append", "audio": "<base64>"}` |
| `input_audio_buffer.commit` | Trigger transcription processing or end: `{"type": "input_audio_buffer.commit", "final": bool}` |
| `session.update` | Configure session: `{"type": "session.update", "model": "model-name"}` |
#### Server → Client Events
| Event | Description |
|-------|-------------|
| `session.created` | Connection established with session ID and timestamp |
| `transcription.delta` | Incremental transcription text: `{"type": "transcription.delta", "delta": "text"}` |
| `transcription.done` | Final transcription with usage stats |
| `error` | Error notification with message and optional code |
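A client typically folds these events into a running transcript. The sketch below shows one way to do that; `apply_server_event` is a hypothetical helper, and the `error` and `code` field names are taken from the error event described above.

```python
import json


def apply_server_event(text_so_far: str, raw_message: str) -> tuple[str, bool]:
    """Fold one server event into the transcript; returns (transcript, finished)."""
    event = json.loads(raw_message)
    kind = event.get("type")
    if kind == "transcription.delta":
        return text_so_far + event["delta"], False
    if kind == "transcription.done":
        # The final event carries the full text, so prefer it over the deltas.
        return event["text"], True
    if kind == "error":
        raise RuntimeError(
            f"server error: {event.get('error')} (code={event.get('code')})"
        )
    # session.created and unrecognised events leave the transcript unchanged.
    return text_so_far, False


transcript, finished = "", False
for raw in (
    '{"type": "session.created"}',
    '{"type": "transcription.delta", "delta": "Hel"}',
    '{"type": "transcription.delta", "delta": "lo"}',
    '{"type": "transcription.done", "text": "Hello"}',
):
    transcript, finished = apply_server_event(transcript, raw)
```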
#### Python WebSocket Example
??? code

    ```python
    import asyncio
    import base64
    import json

    import websockets

    # Must match the model the server was started with.
    MODEL = "mistralai/Voxtral-Mini-3B-Realtime-2602"


    async def realtime_transcribe():
        uri = "ws://localhost:8000/v1/realtime"

        async with websockets.connect(uri) as ws:
            # Wait for session.created
            response = await ws.recv()
            print(f"Session: {response}")

            # Validate the model before committing
            await ws.send(json.dumps({"type": "session.update", "model": MODEL}))

            # Commit buffer to start transcription
            await ws.send(json.dumps({
                "type": "input_audio_buffer.commit"
            }))

            # Send audio chunks (raw PCM16, 16 kHz, mono)
            with open("audio.raw", "rb") as f:
                while chunk := f.read(4096):
                    await ws.send(json.dumps({
                        "type": "input_audio_buffer.append",
                        "audio": base64.b64encode(chunk).decode()
                    }))

            # Signal all audio is sent
            await ws.send(json.dumps({
                "type": "input_audio_buffer.commit",
                "final": True,
            }))

            # Receive transcription
            while True:
                response = json.loads(await ws.recv())
                if response["type"] == "transcription.delta":
                    print(response["delta"], end="", flush=True)
                elif response["type"] == "transcription.done":
                    print(f"\nFinal: {response['text']}")
                    break
                elif response["type"] == "error":
                    # Stop instead of waiting forever on a failed request
                    print(f"\nError: {response['error']}")
                    break


    asyncio.run(realtime_transcribe())
    ```
### Tokenizer API
Our Tokenizer API is a simple wrapper over [HuggingFace-style tokenizers](https://huggingface.co/docs/transformers/en/main_classes/tokenizer).
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This script demonstrates how to use the vLLM Realtime WebSocket API to perform
audio transcription by uploading an audio file.
Before running this script, you must start the vLLM server with a realtime-capable
model, for example:
vllm serve mistralai/Voxtral-Mini-3B-Realtime-2602 --enforce-eager
Requirements:
- vllm with audio support
- websockets
- librosa
- numpy
The script:
1. Connects to the Realtime WebSocket endpoint
2. Converts an audio file to PCM16 @ 16kHz
3. Sends audio chunks to the server
4. Receives and prints transcription as it streams
"""
import argparse
import asyncio
import base64
import json
import librosa
import numpy as np
import websockets
from vllm.assets.audio import AudioAsset
def audio_to_pcm16_base64(audio_path: str) -> str:
"""
Load an audio file and convert it to base64-encoded PCM16 @ 16kHz.
"""
# Load audio and resample to 16kHz mono
audio, _ = librosa.load(audio_path, sr=16000, mono=True)
# Convert to PCM16
pcm16 = (audio * 32767).astype(np.int16)
# Encode as base64
return base64.b64encode(pcm16.tobytes()).decode("utf-8")
async def realtime_transcribe(audio_path: str, host: str, port: int, model: str):
"""
Connect to the Realtime API and transcribe an audio file.
"""
uri = f"ws://{host}:{port}/v1/realtime"
async with websockets.connect(uri) as ws:
# Wait for session.created
response = json.loads(await ws.recv())
if response["type"] == "session.created":
print(f"Session created: {response['id']}")
else:
print(f"Unexpected response: {response}")
return
# Validate model
await ws.send(json.dumps({"type": "session.update", "model": model}))
# Signal ready to start
await ws.send(json.dumps({"type": "input_audio_buffer.commit"}))
# Convert audio file to base64 PCM16
print(f"Loading audio from: {audio_path}")
audio_base64 = audio_to_pcm16_base64(audio_path)
# Send audio in chunks (4 KB of raw audio is ~5.5 KB once base64-encoded)
chunk_size = 4096
audio_bytes = base64.b64decode(audio_base64)
total_chunks = (len(audio_bytes) + chunk_size - 1) // chunk_size
print(f"Sending {total_chunks} audio chunks...")
for i in range(0, len(audio_bytes), chunk_size):
chunk = audio_bytes[i : i + chunk_size]
await ws.send(
json.dumps(
{
"type": "input_audio_buffer.append",
"audio": base64.b64encode(chunk).decode("utf-8"),
}
)
)
# Signal all audio is sent
await ws.send(json.dumps({"type": "input_audio_buffer.commit", "final": True}))
print("Audio sent. Waiting for transcription...\n")
# Receive transcription
print("Transcription: ", end="", flush=True)
while True:
response = json.loads(await ws.recv())
if response["type"] == "transcription.delta":
print(response["delta"], end="", flush=True)
elif response["type"] == "transcription.done":
print(f"\n\nFinal transcription: {response['text']}")
if response.get("usage"):
print(f"Usage: {response['usage']}")
break
elif response["type"] == "error":
print(f"\nError: {response['error']}")
break
def main(args):
if args.audio_path:
audio_path = args.audio_path
else:
# Use default audio asset
audio_path = str(AudioAsset("mary_had_lamb").get_local_path())
print(f"No audio path provided, using default: {audio_path}")
asyncio.run(realtime_transcribe(audio_path, args.host, args.port, args.model))
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Realtime WebSocket Transcription Client"
)
parser.add_argument(
"--model",
type=str,
default="mistralai/Voxtral-Mini-3B-Realtime-2602",
help="Model that is served and should be pinged.",
)
parser.add_argument(
"--audio_path",
type=str,
default=None,
help="Path to the audio file to transcribe.",
)
parser.add_argument(
"--host",
type=str,
default="localhost",
help="vLLM server host (default: localhost)",
)
parser.add_argument(
"--port",
type=int,
default=8000,
help="vLLM server port (default: 8000)",
)
args = parser.parse_args()
main(args)
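The chunking arithmetic used by the script above can be checked in isolation. The following is a minimal sketch with hypothetical helper names (`iter_chunks`, `base64_len`); it assumes PCM16 mono at 16 kHz, that is, 2 bytes per sample.

```python
import base64

SAMPLE_RATE = 16_000
BYTES_PER_SAMPLE = 2  # PCM16


def iter_chunks(data: bytes, chunk_size: int = 4096):
    """Yield consecutive slices of at most chunk_size bytes."""
    for start in range(0, len(data), chunk_size):
        yield data[start : start + chunk_size]


def base64_len(n_bytes: int) -> int:
    """Length of the padded base64 encoding of n_bytes bytes."""
    return 4 * ((n_bytes + 2) // 3)


audio_bytes = bytes(10_000)  # stand-in for real PCM16 data
chunks = list(iter_chunks(audio_bytes))
# Audio duration covered by one full 4096-byte chunk, in seconds.
chunk_seconds = 4096 / BYTES_PER_SAMPLE / SAMPLE_RATE
```

Under these assumptions a 4096-byte chunk holds 2048 samples, or 0.128 s of audio, and grows by about one third when base64-encoded.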
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Minimal Gradio demo for real-time speech transcription using the vLLM Realtime API.
Start the vLLM server first:
vllm serve mistralai/Voxtral-Mini-3B-Realtime-2602 --enforce-eager
Then run this script:
python openai_realtime_microphone_client.py --host localhost --port 8000
Use --share to create a public Gradio link.
Requirements: websockets, numpy, gradio
"""
import argparse
import asyncio
import base64
import json
import queue
import threading
import gradio as gr
import numpy as np
import websockets
SAMPLE_RATE = 16_000
# Global state
audio_queue: queue.Queue = queue.Queue()
transcription_text = ""
is_running = False
ws_url = ""
model = ""
async def websocket_handler():
"""Connect to WebSocket and handle audio streaming + transcription."""
global transcription_text, is_running
async with websockets.connect(ws_url) as ws:
# Wait for session.created
await ws.recv()
# Validate model
await ws.send(json.dumps({"type": "session.update", "model": model}))
# Signal ready
await ws.send(json.dumps({"type": "input_audio_buffer.commit"}))
async def send_audio():
while is_running:
try:
chunk = await asyncio.get_event_loop().run_in_executor(
None, lambda: audio_queue.get(timeout=0.1)
)
await ws.send(
json.dumps(
{"type": "input_audio_buffer.append", "audio": chunk}
)
)
except queue.Empty:
continue
async def receive_transcription():
global transcription_text
async for message in ws:
data = json.loads(message)
if data.get("type") == "transcription.delta":
transcription_text += data["delta"]
await asyncio.gather(send_audio(), receive_transcription())
def start_websocket():
"""Start WebSocket connection in background thread."""
global is_running
is_running = True
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(websocket_handler())
except Exception as e:
print(f"WebSocket error: {e}")
def start_recording():
"""Start the transcription service."""
global transcription_text
transcription_text = ""
thread = threading.Thread(target=start_websocket, daemon=True)
thread.start()
return gr.update(interactive=False), gr.update(interactive=True), ""
def stop_recording():
"""Stop the transcription service."""
global is_running
is_running = False
return gr.update(interactive=True), gr.update(interactive=False), transcription_text
def process_audio(audio):
"""Process incoming audio and queue for streaming."""
global transcription_text
if audio is None or not is_running:
return transcription_text
sample_rate, audio_data = audio
# Convert to mono if stereo
if len(audio_data.shape) > 1:
audio_data = audio_data.mean(axis=1)
# Normalize to float
if audio_data.dtype == np.int16:
audio_float = audio_data.astype(np.float32) / 32767.0
else:
audio_float = audio_data.astype(np.float32)
# Resample to 16kHz if needed
if sample_rate != SAMPLE_RATE:
num_samples = int(len(audio_float) * SAMPLE_RATE / sample_rate)
audio_float = np.interp(
np.linspace(0, len(audio_float) - 1, num_samples),
np.arange(len(audio_float)),
audio_float,
)
# Convert to PCM16 and base64 encode
pcm16 = (audio_float * 32767).astype(np.int16)
b64_chunk = base64.b64encode(pcm16.tobytes()).decode("utf-8")
audio_queue.put(b64_chunk)
return transcription_text
# Gradio interface
with gr.Blocks(title="Real-time Speech Transcription") as demo:
gr.Markdown("# Real-time Speech Transcription")
gr.Markdown("Click **Start** and speak into your microphone.")
with gr.Row():
start_btn = gr.Button("Start", variant="primary")
stop_btn = gr.Button("Stop", variant="stop", interactive=False)
audio_input = gr.Audio(sources=["microphone"], streaming=True, type="numpy")
transcription_output = gr.Textbox(label="Transcription", lines=5)
start_btn.click(
start_recording, outputs=[start_btn, stop_btn, transcription_output]
)
stop_btn.click(stop_recording, outputs=[start_btn, stop_btn, transcription_output])
audio_input.stream(
process_audio, inputs=[audio_input], outputs=[transcription_output]
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Realtime WebSocket Transcription with Gradio"
)
parser.add_argument(
"--model",
type=str,
default="mistralai/Voxtral-Mini-3B-Realtime-2602",
help="Model that is served and should be pinged.",
)
parser.add_argument(
"--host", type=str, default="localhost", help="vLLM server host"
)
parser.add_argument("--port", type=int, default=8000, help="vLLM server port")
parser.add_argument(
"--share", action="store_true", help="Create public Gradio link"
)
args = parser.parse_args()
ws_url = f"ws://{args.host}:{args.port}/v1/realtime"
model = args.model
demo.launch(share=args.share)
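The demo above passes audio from Gradio's callback thread to the asyncio WebSocket task through a thread-safe `queue.Queue`. A minimal, self-contained sketch of that bridging pattern follows. It uses `asyncio.to_thread` and a `threading.Event` stop flag rather than the script's `run_in_executor` call and global boolean, and the names `drain_queue` and `demo` are hypothetical.

```python
import asyncio
import queue
import threading


async def drain_queue(source: queue.Queue, stop: threading.Event, sink: list) -> None:
    """Move items from a thread-safe queue into sink without blocking the loop."""
    while not stop.is_set():
        try:
            # Run the blocking get() in a worker thread; the short timeout
            # lets the loop re-check the stop flag regularly.
            item = await asyncio.to_thread(source.get, True, 0.05)
        except queue.Empty:
            continue
        sink.append(item)


async def demo() -> list:
    source: queue.Queue = queue.Queue()
    stop = threading.Event()
    sink: list = []
    for item in ("a", "b", "c"):
        source.put(item)
    task = asyncio.create_task(drain_queue(source, stop, sink))
    # Wait until everything has been forwarded, then ask the task to stop.
    while len(sink) < 3:
        await asyncio.sleep(0.01)
    stop.set()
    await task
    return sink


result = asyncio.run(demo())
```

In a real client, `sink.append(item)` would be replaced by sending an `input_audio_buffer.append` event.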
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import base64
import json
import librosa
import numpy as np
import pytest
import websockets
from vllm.assets.audio import AudioAsset
from ...utils import RemoteOpenAIServer
from .conftest import add_attention_backend
MISTRAL_FORMAT_ARGS = [
"--tokenizer_mode",
"mistral",
"--config_format",
"mistral",
"--load_format",
"mistral",
]
MODEL_NAME = "mistralai/Voxtral-Mini-3B-Realtime-2602"
def _audio_to_base64_pcm16(path: str, target_sr: int = 16000) -> str:
"""Load audio file, convert to PCM16 @ target sample rate, base64 encode."""
audio, _ = librosa.load(path, sr=target_sr, mono=True)
# Convert float32 [-1, 1] to int16 [-32768, 32767]
audio_int16 = (audio * 32767).astype(np.int16)
audio_bytes = audio_int16.tobytes()
return base64.b64encode(audio_bytes).decode("utf-8")
def _get_websocket_url(server: RemoteOpenAIServer) -> str:
"""Convert HTTP URL to WebSocket URL for realtime endpoint."""
http_url = server.url_root
ws_url = http_url.replace("http://", "ws://")
return f"{ws_url}/v1/realtime"
async def receive_event(ws, timeout: float = 60.0) -> dict:
"""Receive and parse JSON event from WebSocket."""
message = await asyncio.wait_for(ws.recv(), timeout=timeout)
return json.loads(message)
async def send_event(ws, event: dict) -> None:
"""Send JSON event to WebSocket."""
await ws.send(json.dumps(event))
@pytest.fixture
def mary_had_lamb_audio_chunks() -> list[str]:
"""Audio split into ~0.1 second chunks for streaming."""
path = AudioAsset("mary_had_lamb").get_local_path()
audio, _ = librosa.load(str(path), sr=16000, mono=True)
# Split into ~0.1 second chunks (1600 samples at 16kHz)
chunk_size = 1600
chunks = []
for i in range(0, len(audio), chunk_size):
chunk = audio[i : i + chunk_size]
chunk_int16 = (chunk * 32767).astype(np.int16)
chunk_bytes = chunk_int16.tobytes()
chunks.append(base64.b64encode(chunk_bytes).decode("utf-8"))
return chunks
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skip(reason="Voxtral streaming is not yet public")
async def test_multi_chunk_streaming(
model_name, mary_had_lamb_audio_chunks, rocm_aiter_fa_attention
):
"""Test streaming multiple audio chunks before committing."""
server_args = ["--enforce-eager"]
if model_name.startswith("mistralai"):
server_args += MISTRAL_FORMAT_ARGS
add_attention_backend(server_args, rocm_aiter_fa_attention)
with RemoteOpenAIServer(model_name, server_args) as remote_server:
ws_url = _get_websocket_url(remote_server)
async with websockets.connect(ws_url) as ws:
# Receive session.created
event = await receive_event(ws, timeout=30.0)
assert event["type"] == "session.created"
await send_event(ws, {"type": "session.update", "model": model_name})
# Send commit to start transcription
await send_event(ws, {"type": "input_audio_buffer.commit"})
# Send multiple audio chunks
for chunk in mary_had_lamb_audio_chunks:
await send_event(
ws, {"type": "input_audio_buffer.append", "audio": chunk}
)
# Send commit to end
await send_event(ws, {"type": "input_audio_buffer.commit", "final": True})
# Collect transcription deltas
full_text = ""
done_received = False
while not done_received:
event = await receive_event(ws, timeout=60.0)
if event["type"] == "transcription.delta":
full_text += event["delta"]
elif event["type"] == "transcription.done":
done_received = True
assert "text" in event
elif event["type"] == "error":
pytest.fail(f"Received error: {event}")
# Verify transcription contains expected content
assert event["type"] == "transcription.done"
assert event["text"] == full_text
assert full_text == (
" He has first words I spoke in the original phonograph."
" A little piece of practical poetry. Mary had a little lamb,"
" it squeaked with quite a flow, and everywhere that Mary went,"
" the lamb was sure to go"
)
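The `_get_websocket_url` helper in the test above only rewrites `http://`. If a server were reached over TLS, the scheme mapping could be generalised as in the sketch below; `to_realtime_ws_url` is a hypothetical name and is not part of the test suite.

```python
from urllib.parse import urlsplit, urlunsplit


def to_realtime_ws_url(http_url: str) -> str:
    """Map an http(s) base URL to the matching ws(s) realtime endpoint."""
    parts = urlsplit(http_url)
    schemes = {"http": "ws", "https": "wss"}
    if parts.scheme not in schemes:
        raise ValueError(f"unsupported scheme: {parts.scheme!r}")
    # Drop a trailing slash so the endpoint path is not doubled up.
    path = parts.path.rstrip("/") + "/v1/realtime"
    return urlunsplit((schemes[parts.scheme], parts.netloc, path, "", ""))


plain = to_realtime_ws_url("http://localhost:8000")
secure = to_realtime_ws_url("https://example.com/")
```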
@@ -19,11 +19,12 @@ import pytest
 import pytest_asyncio
 from vllm import SamplingParams
+from vllm.inputs.data import StreamingInput
 from vllm.outputs import RequestOutput
 from vllm.platforms import current_platform
 from vllm.sampling_params import RequestOutputKind
 from vllm.utils.torch_utils import set_default_torch_num_threads
-from vllm.v1.engine.async_llm import AsyncLLM, StreamingInput
+from vllm.v1.engine.async_llm import AsyncLLM
 if not current_platform.is_cuda():
     pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True)
......
@@ -7,9 +7,10 @@ from unittest.mock import AsyncMock, MagicMock
 import pytest
+from vllm.inputs.data import StreamingInput
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import RequestOutputKind, SamplingParams
-from vllm.v1.engine.async_llm import AsyncLLM, StreamingInput
+from vllm.v1.engine.async_llm import AsyncLLM
 from vllm.v1.engine.output_processor import RequestOutputCollector
......
@@ -6,7 +6,7 @@ from collections.abc import AsyncGenerator, Iterable, Mapping
 from typing import Any
 from vllm.config import ModelConfig, VllmConfig
-from vllm.inputs.data import PromptType
+from vllm.inputs.data import PromptType, StreamingInput
 from vllm.lora.request import LoRARequest
 from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.plugins.io_processors import IOProcessor
@@ -49,7 +49,7 @@ class EngineClient(ABC):
     @abstractmethod
     def generate(
         self,
-        prompt: EngineCoreRequest | PromptType,
+        prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None],
         sampling_params: SamplingParams,
         request_id: str,
         *,
......
@@ -36,6 +36,7 @@ async def serve_http(
     h11_max_header_count.
     """
     logger.info("Available routes are:")
+    # post endpoints
     for route in app.routes:
         methods = getattr(route, "methods", None)
         path = getattr(route, "path", None)
@@ -45,6 +46,17 @@ async def serve_http(
         logger.info("Route: %s, Methods: %s", path, ", ".join(methods))
+    # other endpoints
+    for route in app.routes:
+        endpoint = getattr(route, "endpoint", None)
+        methods = getattr(route, "methods", None)
+        path = getattr(route, "path", None)
+        if endpoint is None or path is None or methods is not None:
+            continue
+        logger.info("Route: %s, Endpoint: %s", path, endpoint.__name__)
     # Extract header limit options if present
     h11_max_incomplete_event_size = uvicorn_kwargs.pop(
         "h11_max_incomplete_event_size", None
......
@@ -196,6 +196,13 @@ def build_app(args: Namespace, supported_tasks: tuple["SupportedTask", ...]) ->
         register_translations_api_router(app)
+    if "realtime" in supported_tasks:
+        from vllm.entrypoints.openai.realtime.api_router import (
+            attach_router as register_realtime_api_router,
+        )
+        register_realtime_api_router(app)
     if any(task in POOLING_TASKS for task in supported_tasks):
         from vllm.entrypoints.pooling import register_pooling_api_routers
@@ -319,6 +326,11 @@ async def init_app_state(
         engine_client, state, args, request_logger, supported_tasks
     )
+    if "realtime" in supported_tasks:
+        from vllm.entrypoints.openai.realtime.api_router import init_realtime_state
+        init_realtime_state(engine_client, state, args, request_logger, supported_tasks)
     if any(task in POOLING_TASKS for task in supported_tasks):
         from vllm.entrypoints.pooling import init_pooling_state
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
from fastapi import APIRouter, FastAPI, WebSocket
from vllm.entrypoints.openai.realtime.connection import RealtimeConnection
from vllm.entrypoints.openai.realtime.serving import OpenAIServingRealtime
from vllm.logger import init_logger
logger = init_logger(__name__)
if TYPE_CHECKING:
from argparse import Namespace
from starlette.datastructures import State
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.tasks import SupportedTask
else:
RequestLogger = object
router = APIRouter()
@router.websocket("/v1/realtime")
async def realtime_endpoint(websocket: WebSocket):
"""WebSocket endpoint for realtime audio transcription.
Protocol:
1. Client connects to ws://host/v1/realtime
2. Server sends session.created event
3. Client optionally sends session.update with model/params
4. Client sends input_audio_buffer.commit when ready
5. Client sends input_audio_buffer.append events with base64 PCM16 chunks
6. Server processes and sends transcription.delta events
7. Server sends transcription.done with final text + usage
8. Repeat from step 5 for next utterance
9. Optionally, client sends input_audio_buffer.commit with final=True
to signal audio input is finished. Useful when streaming audio files
Audio format: PCM16, 16kHz, mono, base64-encoded
"""
app = websocket.app
serving = app.state.openai_serving_realtime
connection = RealtimeConnection(websocket, serving)
await connection.handle_connection()
def attach_router(app: FastAPI):
"""Attach the realtime router to the FastAPI app."""
app.include_router(router)
logger.info("Realtime API router attached")
def init_realtime_state(
engine_client: "EngineClient",
state: "State",
args: "Namespace",
request_logger: RequestLogger | None,
supported_tasks: tuple["SupportedTask", ...],
):
state.openai_serving_realtime = (
OpenAIServingRealtime(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
log_error_stack=args.log_error_stack,
)
if "realtime" in supported_tasks
else None
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import base64
import json
from collections.abc import AsyncGenerator
from http import HTTPStatus
from uuid import uuid4
import numpy as np
from fastapi import WebSocket
from starlette.websockets import WebSocketDisconnect
from vllm import envs
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.realtime.protocol import (
ErrorEvent,
InputAudioBufferAppend,
InputAudioBufferCommit,
SessionCreated,
TranscriptionDelta,
TranscriptionDone,
)
from vllm.entrypoints.openai.realtime.serving import OpenAIServingRealtime
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
logger = init_logger(__name__)
class RealtimeConnection:
"""Manages WebSocket lifecycle and state for realtime transcription.
This class handles:
- WebSocket connection lifecycle (accept, receive, send, close)
- Event routing (session.update, append, commit)
- Audio buffering via asyncio.Queue
- Generation task management
- Error handling and cleanup
"""
def __init__(self, websocket: WebSocket, serving: OpenAIServingRealtime):
self.websocket = websocket
self.connection_id = f"ws-{uuid4()}"
self.serving = serving
self.audio_queue: asyncio.Queue[np.ndarray | None] = asyncio.Queue()
self.generation_task: asyncio.Task | None = None
self._is_connected = False
self._is_input_finished = False
self._is_model_validated = False
self._max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB
async def handle_connection(self):
"""Main connection loop."""
await self.websocket.accept()
logger.debug("WebSocket connection accepted: %s", self.connection_id)
self._is_connected = True
# Send session created event
await self.send(SessionCreated())
try:
while True:
message = await self.websocket.receive_text()
try:
event = json.loads(message)
await self.handle_event(event)
except json.JSONDecodeError:
await self.send_error("Invalid JSON", "invalid_json")
except Exception as e:
logger.exception("Error handling event: %s", e)
await self.send_error(str(e), "processing_error")
except WebSocketDisconnect:
logger.debug("WebSocket disconnected: %s", self.connection_id)
self._is_connected = False
except Exception as e:
logger.exception("Unexpected error in connection: %s", e)
finally:
await self.cleanup()
def _check_model(self, model: str | None) -> None | ErrorResponse:
if self.serving._is_model_supported(model):
return None
return self.serving.create_error_response(
message=f"The model `{model}` does not exist.",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND,
param="model",
)
async def handle_event(self, event: dict):
"""Route events to handlers.
Supported event types:
- session.update: Configure model
- input_audio_buffer.append: Add audio chunk to queue
- input_audio_buffer.commit: Start transcription generation
"""
event_type = event.get("type")
if event_type == "session.update":
logger.debug("Session updated: %s", event)
self._check_model(event["model"])
self._is_model_validated = True
elif event_type == "input_audio_buffer.append":
append_event = InputAudioBufferAppend(**event)
try:
audio_bytes = base64.b64decode(append_event.audio)
# Convert PCM16 bytes to float32 numpy array
audio_array = (
np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32)
/ 32768.0
)
# Compare the decoded payload size in bytes, not the sample count
if len(audio_bytes) / 1024**2 > self._max_audio_filesize_mb:
raise VLLMValidationError(
"Maximum file size exceeded",
parameter="audio_filesize_mb",
value=len(audio_bytes) / 1024**2,
)
if len(audio_array) == 0:
raise VLLMValidationError("Can't process empty audio.")
# Put audio chunk in queue
self.audio_queue.put_nowait(audio_array)
except Exception as e:
logger.error("Failed to decode audio: %s", e)
await self.send_error("Invalid audio data", "invalid_audio")
elif event_type == "input_audio_buffer.commit":
if not self._is_model_validated:
err_msg = (
"Model not validated. Make sure to validate the"
" model by sending a session.update event."
)
await self.send_error(
err_msg,
"model_not_validated",
)
commit_event = InputAudioBufferCommit(**event)
# final signals that the audio is finished
if commit_event.final:
self._is_input_finished = True
else:
await self.start_generation()
else:
await self.send_error(f"Unknown event type: {event_type}", "unknown_event")
async def audio_stream_generator(self) -> AsyncGenerator[np.ndarray, None]:
"""Generator that yields audio chunks from the queue."""
while True:
audio_chunk = await self.audio_queue.get()
if audio_chunk is None: # Sentinel value to stop
break
yield audio_chunk
async def start_generation(self):
"""Start the transcription generation task."""
if self.generation_task is not None and not self.generation_task.done():
logger.warning("Generation already in progress, ignoring commit")
return
# Create audio stream generator
audio_stream = self.audio_stream_generator()
input_stream = asyncio.Queue[list[int]]()
# Transform to StreamingInput generator
streaming_input_gen = self.serving.transcribe_realtime(
audio_stream, input_stream
)
# Start generation task
self.generation_task = asyncio.create_task(
self._run_generation(streaming_input_gen, input_stream)
)
async def _run_generation(
self,
streaming_input_gen: AsyncGenerator,
input_stream: asyncio.Queue[list[int]],
):
"""Run the generation and stream results back to the client.
This method:
1. Creates sampling parameters from session config
2. Passes the streaming input generator to engine.generate()
3. Streams transcription.delta events as text is generated
4. Sends final transcription.done event with usage stats
5. Feeds generated token IDs back to input_stream for next iteration
6. Cleans up the audio queue
"""
request_id = f"rt-{self.connection_id}-{uuid4()}"
full_text = ""
prompt_token_ids_len: int = 0
completion_tokens_len: int = 0
try:
# Create sampling params
from vllm.sampling_params import RequestOutputKind, SamplingParams
sampling_params = SamplingParams.from_optional(
temperature=0.0,
max_tokens=1,
output_kind=RequestOutputKind.DELTA,
skip_clone=True,
)
# Pass the streaming input generator to the engine
# The engine will consume audio chunks as they arrive and
# stream back transcription results incrementally
result_gen = self.serving.engine_client.generate(
prompt=streaming_input_gen,
sampling_params=sampling_params,
request_id=request_id,
)
# Stream results back to client as they're generated
async for output in result_gen:
if output.outputs and len(output.outputs) > 0:
if not prompt_token_ids_len and output.prompt_token_ids:
prompt_token_ids_len = len(output.prompt_token_ids)
delta = output.outputs[0].text
full_text += delta
# append output to input
input_stream.put_nowait(list(output.outputs[0].token_ids))
await self.send(TranscriptionDelta(delta=delta))
completion_tokens_len += len(output.outputs[0].token_ids)
if not self._is_connected:
# finish because websocket connection was killed
break
if self.audio_queue.empty() and self._is_input_finished:
# finish because client signals that audio input
# is finished
break
usage = UsageInfo(
prompt_tokens=prompt_token_ids_len,
completion_tokens=completion_tokens_len,
total_tokens=prompt_token_ids_len + completion_tokens_len,
)
# Send final completion event
await self.send(TranscriptionDone(text=full_text, usage=usage))
# Clear queue for next utterance
while not self.audio_queue.empty():
self.audio_queue.get_nowait()
except Exception as e:
logger.exception("Error in generation: %s", e)
await self.send_error(str(e), "processing_error")
async def send(
self, event: SessionCreated | TranscriptionDelta | TranscriptionDone
):
"""Send event to client."""
data = event.model_dump_json()
await self.websocket.send_text(data)
async def send_error(self, message: str, code: str | None = None):
"""Send error event to client."""
error_event = ErrorEvent(error=message, code=code)
await self.websocket.send_text(error_event.model_dump_json())
async def cleanup(self):
"""Cleanup resources."""
# Signal audio stream to stop
self.audio_queue.put_nowait(None)
# Cancel generation task if running
if self.generation_task and not self.generation_task.done():
self.generation_task.cancel()
logger.debug("Connection cleanup complete: %s", self.connection_id)
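The `None` that `cleanup()` pushes onto `audio_queue` works as a stop sentinel for whatever coroutine drains the queue. A minimal sketch of that pattern follows; `drain_until_sentinel` and `demo` are hypothetical names used only for illustration and are not part of this commit.

```python
import asyncio
from collections.abc import AsyncGenerator


async def drain_until_sentinel(
    queue: "asyncio.Queue[bytes | None]",
) -> AsyncGenerator[bytes, None]:
    """Yield queued chunks until a None sentinel arrives (hypothetical helper)."""
    while True:
        chunk = await queue.get()
        if chunk is None:
            # Sentinel: the producer (for example cleanup()) asked us to stop.
            return
        yield chunk


async def demo() -> list[bytes]:
    queue: asyncio.Queue[bytes | None] = asyncio.Queue()
    queue.put_nowait(b"chunk-1")
    queue.put_nowait(b"chunk-2")
    queue.put_nowait(None)  # the same signal cleanup() sends
    return [chunk async for chunk in drain_until_sentinel(queue)]


chunks = asyncio.run(demo())
```

A sentinel lets the consumer finish on its own, so cancelling the generation task is only needed as a fallback.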
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from typing import Literal
from pydantic import Field
from vllm.entrypoints.openai.engine.protocol import (
OpenAIBaseModel,
UsageInfo,
)
from vllm.utils import random_uuid
# Client -> Server Events
class InputAudioBufferAppend(OpenAIBaseModel):
"""Append audio chunk to buffer"""
type: Literal["input_audio_buffer.append"] = "input_audio_buffer.append"
audio: str # base64-encoded PCM16 @ 16kHz
class InputAudioBufferCommit(OpenAIBaseModel):
"""Process accumulated audio buffer"""
type: Literal["input_audio_buffer.commit"] = "input_audio_buffer.commit"
final: bool = False
class SessionUpdate(OpenAIBaseModel):
"""Configure session parameters"""
type: Literal["session.update"] = "session.update"
model: str | None = None
# Server -> Client Events
class SessionCreated(OpenAIBaseModel):
"""Connection established notification"""
type: Literal["session.created"] = "session.created"
id: str = Field(default_factory=lambda: f"sess-{random_uuid()}")
created: int = Field(default_factory=lambda: int(time.time()))
class TranscriptionDelta(OpenAIBaseModel):
"""Incremental transcription text"""
type: Literal["transcription.delta"] = "transcription.delta"
delta: str # Incremental text
class TranscriptionDone(OpenAIBaseModel):
"""Final transcription with usage stats"""
type: Literal["transcription.done"] = "transcription.done"
text: str # Complete transcription
usage: UsageInfo | None = None
class ErrorEvent(OpenAIBaseModel):
"""Error notification"""
type: Literal["error"] = "error"
error: str
code: str | None = None
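To illustrate the client side of these events, here is a rough sketch that builds `input_audio_buffer.append` and `input_audio_buffer.commit` messages with the standard library only. The helper names are hypothetical, and little-endian sample order is an assumption; the commit only states base64-encoded PCM16 at 16 kHz, mono.

```python
import base64
import json
import struct


def pcm16_append_event(samples: list[int]) -> str:
    """Serialize int16 samples as an `input_audio_buffer.append` message."""
    # Assumption: samples are packed as little-endian signed 16-bit integers.
    raw = struct.pack(f"<{len(samples)}h", *samples)
    return json.dumps(
        {
            "type": "input_audio_buffer.append",
            "audio": base64.b64encode(raw).decode("ascii"),
        }
    )


def commit_event(final: bool = False) -> str:
    """Serialize an `input_audio_buffer.commit` message."""
    return json.dumps({"type": "input_audio_buffer.commit", "final": final})


append_msg = pcm16_append_event([0, 1000, -1000, 32767])
commit_msg = commit_event(final=True)

# Round-trip check: decode the payload back into int16 samples.
decoded = base64.b64decode(json.loads(append_msg)["audio"])
round_trip = list(struct.unpack(f"<{len(decoded) // 2}h", decoded))
```

Each string would be sent as one WebSocket text frame; the field names mirror the `InputAudioBufferAppend` and `InputAudioBufferCommit` models above.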
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from collections.abc import AsyncGenerator
from functools import cached_property
from typing import Literal, cast
import numpy as np
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.inputs.data import PromptType, StreamingInput
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import SupportsRealtime
logger = init_logger(__name__)
class OpenAIServingRealtime(OpenAIServing):
"""Realtime audio transcription service via WebSocket streaming.
Provides streaming audio-to-text transcription by transforming audio chunks
into StreamingInput objects that can be consumed by the engine.
"""
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
log_error_stack: bool = False,
):
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.task_type: Literal["realtime"] = "realtime"
logger.info("OpenAIServingRealtime initialized for task: %s", self.task_type)
@cached_property
def model_cls(self) -> type[SupportsRealtime]:
"""Get the model class that supports realtime transcription."""
from vllm.model_executor.model_loader import get_model_cls
model_cls = get_model_cls(self.model_config)
return cast(type[SupportsRealtime], model_cls)
async def transcribe_realtime(
self,
audio_stream: AsyncGenerator[np.ndarray, None],
input_stream: asyncio.Queue[list[int]],
) -> AsyncGenerator[StreamingInput, None]:
"""Transform audio stream into StreamingInput for engine.generate().
Args:
audio_stream: Async generator yielding float32 numpy audio arrays
input_stream: Queue containing context token IDs from previous
generation outputs. Used for autoregressive multi-turn
processing where each generation's output becomes the context
for the next iteration.
Yields:
StreamingInput objects containing audio prompts for the engine
"""
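# mypy likely treats the Protocol's `async def` stub as a coroutine
# returning an AsyncGenerator rather than as an async generator,
# so cast explicitly.
# TODO(Patrick) - fix this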
stream_input_iter = cast(
AsyncGenerator[PromptType, None],
self.model_cls.buffer_realtime_audio(
audio_stream, input_stream, self.model_config
),
)
async for prompt in stream_input_iter:
yield StreamingInput(prompt=prompt)
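The feedback loop described in the docstring, where each generation's output becomes the context for the next prompt, can be sketched with a toy engine. Everything here is illustrative: `prompts`, `toy_engine`, the start token `0`, and the one-second timeout are assumptions, not names or values from vLLM.

```python
import asyncio
from collections.abc import AsyncGenerator


async def prompts(
    audio_chunks: list[str],
    input_stream: "asyncio.Queue[list[int]]",
) -> AsyncGenerator[tuple[list[int], str], None]:
    """Pair each audio chunk with the last token of the previous output."""
    for i, chunk in enumerate(audio_chunks):
        if i == 0:
            token_ids = [0]  # hypothetical start-of-stream token
        else:
            # Block until the previous step has reported its output.
            previous = await asyncio.wait_for(input_stream.get(), timeout=1.0)
            token_ids = previous[-1:]
        yield token_ids, chunk


async def toy_engine() -> list[tuple[list[int], str]]:
    input_stream: asyncio.Queue[list[int]] = asyncio.Queue()
    seen = []
    async for token_ids, chunk in prompts(["a0", "a1", "a2"], input_stream):
        seen.append((token_ids, chunk))
        # Stand-in for one decoding step: "generate" one new token and
        # feed it back so the next prompt can use it as context.
        input_stream.put_nowait([token_ids[-1] + 1])
    return seen


seen = asyncio.run(toy_engine())
```

The queue decouples the two sides: the prompt generator never advances until the consumer has produced the token it depends on.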
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, cast
import torch
from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar
from vllm.sampling_params import SamplingParams
if TYPE_CHECKING:
from vllm.multimodal.inputs import (
MultiModalDataDict,
@@ -357,3 +360,15 @@ def to_enc_dec_tuple_list(
(enc_dec_prompt["encoder_prompt"], enc_dec_prompt["decoder_prompt"])
for enc_dec_prompt in enc_dec_prompts
]
@dataclass
class StreamingInput:
"""Input data for a streaming generation request.
This is used with generate() to support multi-turn streaming sessions
where inputs are provided via an async generator.
"""
prompt: PromptType
sampling_params: SamplingParams | None = None
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
- from collections.abc import Callable, Iterable, Mapping, MutableSequence
+ import asyncio
+ from collections.abc import AsyncGenerator, Callable, Iterable, Mapping, MutableSequence
from contextlib import ExitStack, contextmanager, nullcontext
from typing import (
TYPE_CHECKING,
@@ -1015,6 +1016,37 @@ class SupportsQuant:
return None
@runtime_checkable
class SupportsRealtime(Protocol):
"""The interface required for all models that support realtime transcription."""
supports_realtime: ClassVar[Literal[True]] = True
@classmethod
async def buffer_realtime_audio(
cls,
audio_stream: AsyncGenerator[np.ndarray, None],
input_stream: asyncio.Queue[list[int]],
model_config: ModelConfig,
) -> AsyncGenerator[PromptType, None]: ...
@overload
def supports_realtime(
model: type[object],
) -> TypeIs[type[SupportsRealtime]]: ...
@overload
def supports_realtime(model: object) -> TypeIs[SupportsRealtime]: ...
def supports_realtime(
model: type[object] | object,
) -> TypeIs[type[SupportsRealtime]] | TypeIs[SupportsRealtime]:
return getattr(model, "supports_realtime", False)
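The capability check above relies on a class-level flag that is visible on both the class and its instances. A small self-contained sketch of the same idea, using a hypothetical `SupportsPing` interface rather than any real vLLM protocol:

```python
from typing import ClassVar, Literal, Protocol, runtime_checkable


@runtime_checkable
class SupportsPing(Protocol):
    """Hypothetical capability interface, analogous to SupportsRealtime."""

    supports_ping: ClassVar[Literal[True]] = True


class WithPing(SupportsPing):
    pass


class WithoutPing:
    pass


def supports_ping(model: object) -> bool:
    # Works for classes and instances alike, because a class attribute
    # is reachable through both.
    return getattr(model, "supports_ping", False)


results = (
    supports_ping(WithPing),
    supports_ping(WithPing()),
    supports_ping(WithoutPing),
)
```

Checking the flag with `getattr` keeps the test cheap and avoids importing or instantiating the model just to find out what it supports.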
@runtime_checkable
class SupportsTranscription(Protocol):
"""The interface required for all models that support transcription."""
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
- import inspect
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property, partial
@@ -20,7 +19,6 @@ from mistral_common.protocol.transcription.request import TranscriptionRequest
from mistral_common.tokens.tokenizers.audio import (
Audio,
AudioEncoder,
- TranscriptionFormat,
)
from transformers import BatchFeature, TensorType, WhisperConfig
from transformers.tokenization_utils_base import TextInput
@@ -163,19 +161,10 @@ class VoxtralProcessorAdapter:
assert isinstance(audio, np.ndarray)
assert audio.ndim == 1
- # pad if necessary
- # TODO(Patrick) - remove once mistral-common is bumped
- if (
- self._audio_processor.audio_config.transcription_format
- != TranscriptionFormat.STREAMING
- ):
- sig = inspect.signature(self._audio_processor.pad)
- if "is_online_streaming" in sig.parameters:
- audio = self._audio_processor.pad(
- audio, self.sampling_rate, is_online_streaming=False
- )
- else:
- audio = self._audio_processor.pad(audio, self.sampling_rate)
+ if not self._audio_processor.audio_config.is_streaming:
+ audio = self._audio_processor.pad(
+ audio, self.sampling_rate, is_online_streaming=False
+ )
audio_tokens = [self.begin_audio_token_id] + [
self.audio_token_id
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+ import asyncio
import math
- from collections.abc import Mapping
+ from collections.abc import AsyncGenerator, Mapping
from typing import Literal, cast
import numpy as np
@@ -12,12 +13,14 @@ from mistral_common.protocol.transcription.request import (
StreamingMode,
TranscriptionRequest,
)
- from mistral_common.tokens.tokenizers.audio import Audio
+ from mistral_common.tokens.tokenizers.audio import Audio, AudioConfig
+ from vllm.compilation.decorators import support_torch_compile
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
- from vllm.inputs.data import PromptType
+ from vllm.envs import VLLM_ENGINE_ITERATION_TIMEOUT_S
+ from vllm.inputs.data import PromptType, TokensPrompt
from vllm.logger import init_logger
- from vllm.model_executor.models.interfaces import MultiModalEmbeddings
+ from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsRealtime
from vllm.model_executor.models.voxtral import (
VoxtralDummyInputsBuilder,
VoxtralForConditionalGeneration,
@@ -44,6 +47,8 @@ from .utils import (
logger = init_logger(__name__)
+ _PRE_ALLOCATE_BUFFER_SIZE_IN_S = 30
class VoxtralStreamingMultiModalProcessor(VoxtralMultiModalProcessor):
def __init__(
@@ -124,29 +129,164 @@ def _expand_tensor(input_tensor: torch.Tensor, scaling: int) -> torch.Tensor:
return (base.unsqueeze(1) + offsets).view(-1)
class VoxtralRealtimeBuffer:
def __init__(self, config: AudioConfig) -> None:
self._config = config
self._look_ahead_in_ms = config.streaming_look_ahead_ms
self._look_back_in_ms = config.streaming_look_back_ms
self._sampling_rate = self._config.sampling_rate
self._look_ahead = self._get_len_in_samples(self._look_ahead_in_ms)
self._look_back = self._get_len_in_samples(self._look_back_in_ms)
self._streaming_size = self._get_len_in_samples(1000 / self._config.frame_rate)
# mutable objects
streaming_delay = self._get_len_in_samples(self._config.transcription_delay_ms)
self._start = 0
self._end = streaming_delay + self._streaming_size
# always pre-allocate 30 second buffers
self._buffer_size = _PRE_ALLOCATE_BUFFER_SIZE_IN_S * self._sampling_rate
self._buffer: np.ndarray = np.empty(self._buffer_size, dtype=np.float32)
self._filled_buffer_len = 0
@property
def start_idx(self):
return max(self._start - self._look_back, 0)
@property
def end_idx(self):
return self._end + self._look_ahead
@property
def is_audio_complete(self) -> bool:
return self._filled_buffer_len >= self.end_idx
def _get_len_in_samples(self, len_in_ms: float) -> int:
# convert a duration in milliseconds to a whole number of samples
_len_in_samples = self._sampling_rate * len_in_ms / 1000
assert _len_in_samples.is_integer(), _len_in_samples
return int(_len_in_samples)
def _allocate_new_buffer(self) -> None:
# allocate new buffer
new_buffer = np.empty(self._buffer_size, dtype=np.float32)
left_to_copy = max(self._filled_buffer_len - self.start_idx, 0)
if left_to_copy > 0:
new_buffer[:left_to_copy] = self._buffer[
self.start_idx : self._filled_buffer_len
]
del self._buffer
self._buffer = new_buffer
self._filled_buffer_len = left_to_copy
self._start = self._look_back
self._end = self._start + self._streaming_size
def write_audio(self, audio: np.ndarray) -> None:
put_end_idx = self._filled_buffer_len + len(audio)
if put_end_idx > self._buffer_size:
self._allocate_new_buffer()
self._buffer[self._filled_buffer_len : self._filled_buffer_len + len(audio)] = (
audio
)
self._filled_buffer_len += len(audio)
def read_audio(self) -> np.ndarray | None:
if not self.is_audio_complete:
return None
audio = self._buffer[self.start_idx : self.end_idx]
self._start = self._end
self._end += self._streaming_size
return audio
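The index arithmetic in `VoxtralRealtimeBuffer` is easier to follow on small numbers. The sketch below reproduces only the window bookkeeping of `read_audio()` (look-back, look-ahead, step, initial delay) and ignores buffer reallocation; `window_bounds` is a hypothetical helper, and the sample counts are toy values, not the real Voxtral configuration.

```python
def window_bounds(
    total_samples: int,
    step: int,
    delay: int,
    look_back: int,
    look_ahead: int,
) -> list[tuple[int, int]]:
    """Return the (start_idx, end_idx) slices a reader would receive."""
    start, end = 0, delay + step  # the first window also covers the delay
    bounds = []
    # A window is readable once enough audio has arrived to cover
    # its end plus the look-ahead region.
    while total_samples >= end + look_ahead:
        bounds.append((max(start - look_back, 0), end + look_ahead))
        start, end = end, end + step  # advance by one streaming step
    return bounds


# Toy numbers: 10 samples received, 2-sample steps, 1-sample delay,
# 1 sample of look-back and 1 sample of look-ahead.
bounds = window_bounds(total_samples=10, step=2, delay=1, look_back=1, look_ahead=1)
```

With these values, consecutive windows overlap by the look-back plus the look-ahead, which is why each slice is wider than the step itself.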
@MULTIMODAL_REGISTRY.register_processor(
VoxtralStreamingMultiModalProcessor,
info=VoxtralProcessingInfo,
dummy_inputs=VoxtralDummyInputsBuilder,
)
- class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
+ @support_torch_compile
+ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration, SupportsRealtime):
requires_raw_input_tokens = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
+ assert (
+ not vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs()
+ ), (
+ "Voxtral streaming doesn't support full cudagraphs yet. "
+ "Please use PIECEWISE."
+ )
self.time_embedding: TimeEmbedding = TimeEmbedding(
dim=self.config.text_config.hidden_size
)
audio_config = self.tokenizer.instruct.audio_encoder.audio_config
- _n_delay_tokens = (
- audio_config.frame_rate * audio_config.transcription_delay_ms / 1000
- )
- assert _n_delay_tokens.is_integer(), (
- f"n_delay_tokens must be integer, got {_n_delay_tokens}"
- )
- self.n_delay_tokens = int(_n_delay_tokens)
+ self.n_delay_tokens = audio_config.num_delay_tokens
+ # for realtime transcription
@classmethod
async def buffer_realtime_audio(
cls,
audio_stream: AsyncGenerator[np.ndarray, None],
input_stream: asyncio.Queue[list[int]],
model_config: ModelConfig,
) -> AsyncGenerator[PromptType, None]:
tokenizer = cached_tokenizer_from_config(model_config)
audio_encoder = tokenizer.instruct.audio_encoder
config = audio_encoder.audio_config
buffer = VoxtralRealtimeBuffer(config)
is_first_yield = True
async for audio in audio_stream:
buffer.write_audio(audio)
while (new_audio := buffer.read_audio()) is not None:
if is_first_yield:
# make sure that input_stream is empty
assert input_stream.empty()
audio = Audio(new_audio, config.sampling_rate, format="wav")
request = TranscriptionRequest(
streaming=StreamingMode.ONLINE,
audio=RawAudio.from_audio(audio),
language=None,
)
# mistral tokenizer takes care
# of preparing the first prompt inputs
# and does some left-silence padding
# for improved performance
audio_enc = tokenizer.mistral.encode_transcription(request)
token_ids = audio_enc.tokens
new_audio = audio_enc.audios[0].audio_array
is_first_yield = False
else:
# wait for the next batch of generated token IDs from input_stream
# and keep only its last token as the prompt
all_outputs = await asyncio.wait_for(
input_stream.get(), timeout=VLLM_ENGINE_ITERATION_TIMEOUT_S
)
token_ids = all_outputs[-1:]
multi_modal_data = {"audio": (new_audio, None)}
yield TokensPrompt(
prompt_token_ids=token_ids, multi_modal_data=multi_modal_data
)
@property
def audio_config(self):
@@ -205,8 +345,9 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
# sum pool text and audio embeddings
inputs_embeds = audio_text_embeds + text_embeds
- time_tensor = torch.tensor(
- [self.n_delay_tokens],
+ time_tensor = torch.full(
+ (1,),
+ fill_value=self.n_delay_tokens,
device=inputs_embeds.device,
dtype=inputs_embeds.dtype,
)
......
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Literal, get_args
- GenerationTask = Literal["generate", "transcription"]
+ GenerationTask = Literal["generate", "transcription", "realtime"]
GENERATION_TASKS: tuple[GenerationTask, ...] = get_args(GenerationTask)
PoolingTask = Literal[
......
@@ -7,7 +7,6 @@ import time
import warnings
from collections.abc import AsyncGenerator, Iterable, Mapping
from copy import copy
- from dataclasses import dataclass
from typing import Any
import torch
@@ -19,6 +18,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs import PromptType
+ from vllm.inputs.data import StreamingInput
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
@@ -53,18 +53,6 @@ from vllm.v1.metrics.stats import IterationStats
logger = init_logger(__name__)
- @dataclass
- class StreamingInput:
- """Input data for a streaming generation request.
- This is used with generate() to support multi-turn streaming sessions
- where inputs are provided via an async generator.
- """
- prompt: PromptType
- sampling_params: SamplingParams | None = None
class InputStreamError(Exception):
"""Wrapper for errors from the input stream generator.
......