Unverified commit f75b61a9, authored by Tal Nir and committed by GitHub
Browse files

[Voxtral Realtime] Fix engine crash on empty multimodal embeddings (#34862)


Signed-off-by: Tal Nir <tal@nervexneurotech.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
parent 7f51e938
......@@ -121,3 +121,75 @@ async def test_multi_chunk_streaming(
" it sleeps with quite a flow, and everywhere that Mary went,"
" the lamb was sure to go."
)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_empty_commit_does_not_crash_engine(
    model_name, mary_had_lamb_audio_chunks, rocm_aiter_fa_attention
):
    """Check that an audio-less commit leaves the engine running.

    Regression test for https://github.com/vllm-project/vllm/issues/34532.

    Committing without any preceding ``input_audio_buffer.append`` used to
    raise ``AssertionError: For realtime you must provide a
    multimodal_embedding at every step``, which took down the whole engine
    process and disconnected every connected client.

    The test runs two websocket sessions against the same server:

    1. an empty commit, which may answer with an error or an (empty)
       transcription event but must answer;
    2. a regular transcription of the audio chunks, which must reach
       ``transcription.done`` without any ``error`` event, showing that the
       engine survived the first session.
    """
    server_args = ["--enforce-eager", "--max-model-len", "2048"]
    if model_name.startswith("mistralai"):
        server_args.extend(MISTRAL_FORMAT_ARGS)
    add_attention_backend(server_args, rocm_aiter_fa_attention)

    # Any of these event types counts as "the server answered" in session 1.
    acceptable_first_replies = (
        "error",
        "transcription.done",
        "transcription.delta",
    )

    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        endpoint = _get_websocket_url(remote_server)

        # Session 1: commit twice without ever appending audio.
        async with websockets.connect(endpoint) as ws:
            greeting = await receive_event(ws, timeout=30.0)
            assert greeting["type"] == "session.created"

            await send_event(ws, {"type": "session.update", "model": model_name})

            # Kick off generation with an empty audio buffer ...
            await send_event(ws, {"type": "input_audio_buffer.commit"})
            # ... and signal end-of-audio right away.
            await send_event(ws, {"type": "input_audio_buffer.commit", "final": True})

            # The kind of reply is not pinned down; only that one arrives.
            reply = await receive_event(ws, timeout=30.0)
            assert reply["type"] in acceptable_first_replies

        # Session 2: a normal transcription on a fresh connection, which only
        # works if the engine is still alive after session 1.
        async with websockets.connect(endpoint) as ws:
            greeting = await receive_event(ws, timeout=30.0)
            assert greeting["type"] == "session.created"

            await send_event(ws, {"type": "session.update", "model": model_name})
            await send_event(ws, {"type": "input_audio_buffer.commit"})

            for audio_chunk in mary_had_lamb_audio_chunks:
                await send_event(
                    ws, {"type": "input_audio_buffer.append", "audio": audio_chunk}
                )

            await send_event(ws, {"type": "input_audio_buffer.commit", "final": True})

            # Drain events until the transcription completes; other event
            # types are ignored, and an error event fails the test.
            finished = False
            while not finished:
                reply = await receive_event(ws, timeout=60.0)
                reply_type = reply["type"]
                if reply_type == "error":
                    pytest.fail(f"Engine error after empty commit: {reply}")
                finished = reply_type == "transcription.done"

            assert finished
......@@ -299,13 +299,29 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
"""Pass post-conv embeddings directly as input"""
# for realtime we simply flatten the multimodal embeddings
# to be in tensor format, we treat the input ids later
assert multimodal_embeddings is not None
assert len(multimodal_embeddings) > 0, (
"For realtime you must provide a multimodal_embedding at every step."
)
"""Pass post-conv embeddings directly as input.
For realtime models, multimodal embeddings are required at every
decode step. If they are missing (e.g. due to an empty audio
commit, encoder-cache eviction under GPU memory pressure, or a
client disconnect), return zero embeddings instead of crashing
the engine so that all other in-flight requests stay alive.
"""
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
logger.warning(
"Realtime model received empty multimodal embeddings "
"for %d input tokens. Returning zero embeddings to "
"avoid engine crash.",
input_ids.shape[0],
)
pool_size = self.config.audio_config.block_pool_size
embed_dim = self.config.audio_config.d_model * pool_size
return torch.zeros(
input_ids.shape[0],
embed_dim,
dtype=self.whisper_encoder.dtype,
device=input_ids.device,
)
mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
return mm_embeds_flat
......@@ -367,9 +383,12 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim
"""Transform audio waveforms -> initial whisper post-conv embeddings"""
audio_inputs = self._parse_and_validate_audio_arrays(**kwargs)
assert audio_inputs is not None, (
"For realtime you must provide an audio input at every step."
)
if audio_inputs is None:
logger.warning(
"Realtime model received no audio inputs in "
"embed_multimodal. Returning empty embeddings."
)
return []
def _truncate_left(
sample: torch.Tensor, mult_of: int, pos: int
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment