test_realtime_validation.py 12.2 KB
Newer Older
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import json
6
import warnings
7
8

import numpy as np
9
import pybase64 as base64
10
11
12
import pytest
import websockets

13
14
from tests.entrypoints.openai.conftest import add_attention_backend
from tests.utils import ROCM_ENV_OVERRIDES, ROCM_EXTRA_ARGS, RemoteOpenAIServer
15
from vllm.assets.audio import AudioAsset
16
from vllm.multimodal.media.audio import load_audio
17

18
19
20
21
22
23
24
# On ROCm, first-use JIT compilation can run longer than the default 60s
# engine iteration timeout, which causes a silent deadlock in feed_tokens.
# Start from the ROCm overrides and raise that timeout for these tests.
REALTIME_ENV_OVERRIDES = dict(ROCM_ENV_OVERRIDES)
REALTIME_ENV_OVERRIDES["VLLM_ENGINE_ITERATION_TIMEOUT_S"] = "600"

25
26
27
28
29
30
31
# Server flags that set the tokenizer mode, config format and load format
# to "mistral", followed by the extra ROCm arguments.
MISTRAL_FORMAT_ARGS = [
    "--tokenizer_mode",
    "mistral",
    "--config_format",
    "mistral",
    "--load_format",
    "mistral",
] + ROCM_EXTRA_ARGS
33

34
# Model name used to parametrize the tests below; it is also the model the
# test server is launched with.
MODEL_NAME = "mistralai/Voxtral-Mini-4B-Realtime-2602"
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58


def _get_websocket_url(server: RemoteOpenAIServer) -> str:
    """Return the ``/v1/realtime`` WebSocket URL for *server*.

    Every ``http://`` in the server's root URL is replaced with ``ws://``
    before the realtime path is appended.
    """
    return server.url_root.replace("http://", "ws://") + "/v1/realtime"


async def receive_event(ws, timeout: float = 60.0) -> dict:
    """Wait for the next WebSocket message and decode it as JSON.

    Args:
        ws: Connection object exposing an awaitable ``recv()``.
        timeout: Maximum number of seconds to wait for the message.

    Returns:
        The result of ``json.loads`` applied to the received message.
    """
    return json.loads(await asyncio.wait_for(ws.recv(), timeout))


async def send_event(ws, event: dict) -> None:
    """Serialize *event* with ``json.dumps`` and send it over *ws*.

    Args:
        ws: Connection object exposing an awaitable ``send()``.
        event: Event payload to serialize.
    """
    payload = json.dumps(event)
    await ws.send(payload)


@pytest.fixture
def mary_had_lamb_audio_chunks() -> list[str]:
    """Return the "mary_had_lamb" audio as base64 chunks for streaming.

    The asset is loaded with ``sr=16000`` and ``mono=True``, scaled by 32767
    and cast to int16, then cut into ~0.1 second pieces (1600 samples each;
    the last piece may be shorter).

    Returns:
        One base64-encoded string of raw int16 sample bytes per chunk, in
        the order the chunks occur in the audio.
    """
    path = AudioAsset("mary_had_lamb").get_local_path()
    audio, _ = load_audio(str(path), sr=16000, mono=True)

    # Convert once up front rather than per chunk: the scaling and cast are
    # element-wise, so slicing afterwards yields the same bytes.
    samples_int16 = (audio * 32767).astype(np.int16)

    # ~0.1 second of audio per chunk (1600 samples at 16kHz).
    chunk_size = 1600
    return [
        base64.b64encode(
            samples_int16[start : start + chunk_size].tobytes()
        ).decode("utf-8")
        for start in range(0, len(samples_int16), chunk_size)
    ]


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_multi_chunk_streaming(
    model_name, mary_had_lamb_audio_chunks, rocm_aiter_fa_attention
):
    """Test streaming multiple audio chunks before committing.

    After a single-chunk warm-up request, every chunk is appended between a
    non-final and a final commit. The concatenated ``transcription.delta``
    texts must equal both the ``text`` of the ``transcription.done`` event
    and the expected transcript.
    """
    server_args = ["--enforce-eager", "--max-model-len", "2048"]

    if model_name.startswith("mistralai"):
        server_args += MISTRAL_FORMAT_ARGS

    add_attention_backend(server_args, rocm_aiter_fa_attention)

    with RemoteOpenAIServer(
        model_name, server_args, env_dict=REALTIME_ENV_OVERRIDES
    ) as remote_server:
        ws_url = _get_websocket_url(remote_server)
        async with websockets.connect(ws_url) as ws:
            # Receive session.created
            event = await receive_event(ws, timeout=30.0)
            assert event["type"] == "session.created"

            await send_event(ws, {"type": "session.update", "model": model_name})

            # Wait for the server to acknowledge the session update.
            # receive_event relies on asyncio.wait_for, which raised
            # asyncio.TimeoutError (distinct from the builtin TimeoutError)
            # before Python 3.11, so both are caught here.
            ack_timeout = 5.0
            try:
                while True:
                    event = await receive_event(ws, timeout=ack_timeout)
                    if event["type"] == "session.updated":
                        break
            except (TimeoutError, asyncio.TimeoutError):
                warnings.warn(
                    f"session.updated not received within {ack_timeout}s after "
                    "session.update. The server may not implement this event.",
                    stacklevel=2,
                )

            # (ROCm) Warm-up: send a non-final commit (required to start
            # transcription) with a small audio chunk to trigger aiter
            # compilation on first use.
            await send_event(ws, {"type": "input_audio_buffer.commit"})
            await send_event(
                ws,
                {
                    "type": "input_audio_buffer.append",
                    "audio": mary_had_lamb_audio_chunks[0],
                },
            )
            await send_event(ws, {"type": "input_audio_buffer.commit", "final": True})

            # (ROCm) Drain all warm-up responses with generous timeout for
            # JIT compilation. An error event also ends the warm-up; it is
            # not treated as a failure here.
            while True:
                event = await receive_event(ws, timeout=600.0)
                if event["type"] in ("transcription.done", "error"):
                    break

            # Now send the real test audio
            await send_event(ws, {"type": "input_audio_buffer.commit"})

            # Send multiple audio chunks
            for chunk in mary_had_lamb_audio_chunks:
                await send_event(
                    ws, {"type": "input_audio_buffer.append", "audio": chunk}
                )

            # Send commit to end
            await send_event(ws, {"type": "input_audio_buffer.commit", "final": True})

            # Collect transcription deltas until transcription.done arrives.
            deltas = []
            while True:
                event = await receive_event(ws, timeout=60.0)
                event_type = event["type"]

                if event_type == "transcription.delta":
                    deltas.append(event["delta"])
                elif event_type == "transcription.done":
                    assert "text" in event
                    break
                elif event_type == "error":
                    pytest.fail(f"Received error: {event}")
            full_text = "".join(deltas)

            # Verify transcription contains expected content
            assert event["type"] == "transcription.done"
            assert event["text"] == full_text
            assert full_text == (
                " First words I spoke in the original phonograph."
                " A little piece of practical poetry. Mary had a little lamb,"
                " it sleeps with quite a flow, and everywhere that Mary went,"
                " the lamb was sure to go."
            )
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_empty_commit_does_not_crash_engine(
    model_name, mary_had_lamb_audio_chunks, rocm_aiter_fa_attention
):
    """Test that committing without audio does not crash the engine.

    Regression test for https://github.com/vllm-project/vllm/issues/34532.
    An empty commit (no prior input_audio_buffer.append) used to trigger
    ``AssertionError: For realtime you must provide a multimodal_embedding
    at every step`` which killed the entire engine process, disconnecting
    every connected client.

    Two connections are opened against the same server: the first sends
    only commits, the second streams the full audio and must reach
    ``transcription.done`` without an ``error`` event.
    """

    async def _start_session(ws, ack_timeout: float = 5.0) -> None:
        """Consume session.created, send session.update and await its ack."""
        event = await receive_event(ws, timeout=30.0)
        assert event["type"] == "session.created"

        await send_event(ws, {"type": "session.update", "model": model_name})

        # receive_event relies on asyncio.wait_for, which raised
        # asyncio.TimeoutError (distinct from the builtin TimeoutError)
        # before Python 3.11, so both are caught here.
        try:
            while True:
                event = await receive_event(ws, timeout=ack_timeout)
                if event["type"] == "session.updated":
                    break
        except (TimeoutError, asyncio.TimeoutError):
            warnings.warn(
                f"session.updated not received within {ack_timeout}s after "
                "session.update. The server may not implement this event.",
                stacklevel=2,
            )

    server_args = ["--enforce-eager", "--max-model-len", "2048"]

    if model_name.startswith("mistralai"):
        server_args += MISTRAL_FORMAT_ARGS

    add_attention_backend(server_args, rocm_aiter_fa_attention)

    with RemoteOpenAIServer(
        model_name, server_args, env_dict=REALTIME_ENV_OVERRIDES
    ) as remote_server:
        ws_url = _get_websocket_url(remote_server)

        # --- First connection: empty commit (no audio appended) ----------
        async with websockets.connect(ws_url) as ws:
            await _start_session(ws)

            # Start generation without sending any audio
            await send_event(ws, {"type": "input_audio_buffer.commit"})

            # Immediately signal end-of-audio
            await send_event(ws, {"type": "input_audio_buffer.commit", "final": True})

            # We should get *some* response (error or empty transcription),
            # but the engine must NOT crash.
            # (ROCm) Use generous timeout for first request (aiter JIT compilation)
            event = await receive_event(ws, timeout=360.0)
            assert event["type"] in (
                "error",
                "transcription.done",
                "transcription.delta",
            )

        # --- Second connection: normal transcription ---------------------
        # Verifies the engine is still alive after the empty commit above.
        async with websockets.connect(ws_url) as ws:
            await _start_session(ws)

            # Start transcription
            await send_event(ws, {"type": "input_audio_buffer.commit"})

            for chunk in mary_had_lamb_audio_chunks:
                await send_event(
                    ws, {"type": "input_audio_buffer.append", "audio": chunk}
                )

            await send_event(ws, {"type": "input_audio_buffer.commit", "final": True})

            # Read events until transcription.done; any error event fails
            # the test, and a missing event surfaces as a receive timeout.
            while True:
                event = await receive_event(ws, timeout=60.0)
                if event["type"] == "transcription.done":
                    break
                if event["type"] == "error":
                    pytest.fail(f"Engine error after empty commit: {event}")
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_session_update_invalid_model_returns_error(
    model_name, rocm_aiter_fa_attention
):
    """Check that ``session.update`` naming an invalid model is rejected.

    The first event received after the update must be an ``error`` event
    whose ``error`` field contains the requested model name.
    """
    cli_args = ["--enforce-eager", "--max-model-len", "2048"]
    if model_name.startswith("mistralai"):
        cli_args.extend(MISTRAL_FORMAT_ARGS)
    add_attention_backend(cli_args, rocm_aiter_fa_attention)

    with RemoteOpenAIServer(
        model_name, cli_args, env_dict=REALTIME_ENV_OVERRIDES
    ) as remote_server:
        async with websockets.connect(_get_websocket_url(remote_server)) as ws:
            created = await receive_event(ws, timeout=30.0)
            assert created["type"] == "session.created"

            # Request a model that doesn't exist.
            update_request = {"type": "session.update", "model": "nonexistent-model"}
            await send_event(ws, update_request)

            reply = await receive_event(ws, timeout=10.0)
            assert reply["type"] == "error"
            assert "nonexistent-model" in reply["error"]


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_commit_without_session_update_returns_error(
    model_name, rocm_aiter_fa_attention
):
    """Check that a commit sent before any ``session.update`` is rejected.

    The commit must not fall through to processing: the first event received
    afterwards has to be an ``error`` event whose ``code`` contains
    ``model_not_validated``.
    """
    cli_args = ["--enforce-eager", "--max-model-len", "2048"]
    if model_name.startswith("mistralai"):
        cli_args.extend(MISTRAL_FORMAT_ARGS)
    add_attention_backend(cli_args, rocm_aiter_fa_attention)

    with RemoteOpenAIServer(
        model_name, cli_args, env_dict=REALTIME_ENV_OVERRIDES
    ) as remote_server:
        async with websockets.connect(_get_websocket_url(remote_server)) as ws:
            created = await receive_event(ws, timeout=30.0)
            assert created["type"] == "session.created"

            # Commit straight away, skipping session.update entirely.
            final_commit = {"type": "input_audio_buffer.commit", "final": True}
            await send_event(ws, final_commit)

            reply = await receive_event(ws, timeout=10.0)
            assert reply["type"] == "error"
            assert "model_not_validated" in reply.get("code", "")