Unverified commit 6c0baee6, authored by Patrick von Platen and committed by GitHub
Browse files

[Voxtral Realtime] Refactor & Improve buffering logic (#34428)


Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 1100a976
...@@ -31,7 +31,7 @@ partial-json-parser # used for parsing partial JSON outputs ...@@ -31,7 +31,7 @@ partial-json-parser # used for parsing partial JSON outputs
pyzmq >= 25.0.0 pyzmq >= 25.0.0
msgspec msgspec
gguf >= 0.17.0 gguf >= 0.17.0
mistral_common[image] >= 1.9.0 mistral_common[image] >= 1.9.1
opencv-python-headless >= 4.13.0 # required for video IO opencv-python-headless >= 4.13.0 # required for video IO
pyyaml pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
......
...@@ -23,7 +23,7 @@ jiwer # required for audio tests ...@@ -23,7 +23,7 @@ jiwer # required for audio tests
timm # required for internvl test timm # required for internvl test
transformers_stream_generator # required for qwen-vl test transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test matplotlib # required for qwen-vl test
mistral_common[image,audio] >= 1.9.0 # required for voxtral test mistral_common[image,audio] >= 1.9.1 # required for voxtral test
num2words # required for smolvlm test num2words # required for smolvlm test
opencv-python-headless >= 4.13.0 # required for video test opencv-python-headless >= 4.13.0 # required for video test
datamodel_code_generator # required for minicpm3 test datamodel_code_generator # required for minicpm3 test
......
...@@ -30,7 +30,7 @@ torchaudio==2.10.0 ...@@ -30,7 +30,7 @@ torchaudio==2.10.0
torchvision==0.25.0 torchvision==0.25.0
transformers_stream_generator # required for qwen-vl test transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test matplotlib # required for qwen-vl test
mistral_common[image,audio] >= 1.9.0 # required for voxtral test mistral_common[image,audio] >= 1.9.1 # required for voxtral test
num2words # required for smolvlm test num2words # required for smolvlm test
open_clip_torch==2.32.0 # Required for nemotron_vl test, Nemotron Parse in test_common.py open_clip_torch==2.32.0 # Required for nemotron_vl test, Nemotron Parse in test_common.py
opencv-python-headless >= 4.13.0 # required for video test opencv-python-headless >= 4.13.0 # required for video test
......
...@@ -499,7 +499,7 @@ mbstrdecoder==1.1.3 ...@@ -499,7 +499,7 @@ mbstrdecoder==1.1.3
# typepy # typepy
mdurl==0.1.2 mdurl==0.1.2
# via markdown-it-py # via markdown-it-py
mistral-common==1.9.0 mistral-common==1.9.1
# via -r requirements/test.in # via -r requirements/test.in
mlflow==2.22.0 mlflow==2.22.0
# via terratorch # via terratorch
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from dataclasses import asdict from dataclasses import asdict
import pytest import pytest
...@@ -10,14 +9,13 @@ from mistral_common.protocol.transcription.request import ( ...@@ -10,14 +9,13 @@ from mistral_common.protocol.transcription.request import (
StreamingMode, StreamingMode,
TranscriptionRequest, TranscriptionRequest,
) )
from mistral_common.tokens.tokenizers.audio import AudioConfig
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from vllm import LLM, EngineArgs, SamplingParams from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.audio import AudioAsset from vllm.assets.audio import AudioAsset
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs.data import TokensPrompt from vllm.model_executor.models.voxtral_realtime import VoxtralRealtimeBuffer
from vllm.v1.engine.async_llm import AsyncLLM, StreamingInput from vllm.v1.engine.async_llm import AsyncLLM
MODEL_NAME = "mistralai/Voxtral-Mini-4B-Realtime-2602" MODEL_NAME = "mistralai/Voxtral-Mini-4B-Realtime-2602"
ENGINE_CONFIG = dict( ENGINE_CONFIG = dict(
...@@ -114,136 +112,40 @@ def test_voxtral_realtime_forward(audio_assets, tokenizer, engine): ...@@ -114,136 +112,40 @@ def test_voxtral_realtime_forward(audio_assets, tokenizer, engine):
assert texts == EXPECTED_TEXT assert texts == EXPECTED_TEXT
class RealTimeAudioInput:
    """
    Stream an audio file window by window, as if it arrived in real time.

    The whole (already padded) audio is held in memory; every call to
    `add_tokens` enqueues one `StreamingInput` holding the next audio window
    together with the token ids passed in. `generator` exposes the queue as an
    async iterator that ends once `stop` has enqueued the `None` sentinel.
    """

    def __init__(self, tokenizer: MistralTokenizer) -> None:
        """Read the streaming parameters from the tokenizer's audio config."""
        self._tokenizer = tokenizer
        self._config: AudioConfig = (
            self._tokenizer.instruct_tokenizer.audio_encoder.audio_config
        )

        self._look_ahead_in_ms = self._config.streaming_look_ahead_ms
        self._look_back_in_ms = self._config.streaming_look_back_ms
        self._sampling_rate = self._config.sampling_rate

        # Set by `create` once the audio has been encoded.
        self._audio: Audio | None = None

        # mutable objects
        # [_start, _end) is the current window in samples, before the
        # look-back / look-ahead context is added in `add_tokens`.
        self._start = 0
        n_left_pad_samples = (
            self._config.raw_audio_length_per_tok * self._config.n_left_pad_tokens
        )
        self._end = self.streaming_delay + n_left_pad_samples + self.streaming_size
        # `None` is the end-of-stream sentinel consumed by `generator`.
        self._queue: asyncio.Queue[StreamingInput | None] = asyncio.Queue()

    @classmethod
    async def create(cls, audio: Audio, tokenizer: MistralTokenizer):
        """Build an instance for `audio` and enqueue the first streaming input."""
        self = cls(tokenizer)

        # we're doing "OFFLINE" encoding here to right & left pad the audio since
        # we have access to the whole audio
        # if we'd do an actual online realtime streaming application we
        # should instead pass `StreamingMode.ONLINE`
        req = TranscriptionRequest(
            streaming=StreamingMode.OFFLINE,
            audio=RawAudio.from_audio(audio),
            language=None,
        )
        audio_enc = self._tokenizer.encode_transcription(req)
        self._audio = audio_enc.audios[0]

        # add first request
        await self.add_tokens(audio_enc.tokens)
        return self

    @property
    def look_ahead(self) -> int:
        """Look-ahead context length in samples."""
        return self._get_len_in_samples(self._look_ahead_in_ms)

    @property
    def look_back(self) -> int:
        """Look-back context length in samples."""
        return self._get_len_in_samples(self._look_back_in_ms)

    @property
    def streaming_delay(self) -> int:
        """Configured transcription delay in samples."""
        return self._get_len_in_samples(self._config.transcription_delay_ms)

    @property
    def streaming_size(self) -> int:
        """Number of samples in one frame (1000 / frame_rate milliseconds)."""
        stream_size_in_ms = 1000 / self._config.frame_rate
        return self._get_len_in_samples(stream_size_in_ms)

    def _get_len_in_samples(self, len_in_ms: float) -> int:
        """Convert a duration in milliseconds to a whole number of samples.

        Asserts that the duration maps onto an integer sample count at the
        configured sampling rate.
        """
        n_samples = self._sampling_rate * len_in_ms / 1000
        assert n_samples.is_integer(), n_samples
        return int(n_samples)

    async def add_tokens(self, tokens: list[int]) -> None:
        """Enqueue `tokens` with the next audio window, then advance the window.

        Once the window start has moved past the end of the audio, the stream
        is stopped instead and nothing else is enqueued.
        """
        assert self._audio is not None
        if self._start >= len(self._audio.audio_array):
            self.stop()
            return

        # Widen the window by the look-back / look-ahead context; the slice
        # start is clamped to the beginning of the audio.
        _end = self._end + self.look_ahead
        _start = max(0, self._start - self.look_back)

        multi_modal_data = {"audio": (self._audio.audio_array[_start:_end], None)}
        prompt = TokensPrompt(
            prompt_token_ids=tokens, multi_modal_data=multi_modal_data
        )
        await self._queue.put(StreamingInput(prompt))

        # increase
        self._start = self._end
        self._end = self._end + self.streaming_size

    def stop(self):
        """Enqueue the `None` sentinel that terminates `generator`."""
        self._queue.put_nowait(None)

    async def generator(self):
        """Yield queued streaming inputs until the `None` sentinel is read."""
        while (item := await self._queue.get()) is not None:
            yield item
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_voxtral_realtime_generator(audio_assets, tokenizer, async_engine): async def test_voxtral_realtime_generator(audio_assets, tokenizer, async_engine):
sampling_params = SamplingParams(temperature=0.0, max_tokens=1) sampling_params = SamplingParams(temperature=0.0, max_tokens=1)
audio_config = tokenizer.instruct_tokenizer.audio_encoder.audio_config
output_tokens_list = [] output_tokens_list = []
for i, audio_asset in enumerate(audio_assets): for i, audio_asset in enumerate(audio_assets):
output_tokens = [] output_tokens = []
audio = Audio.from_file(audio_asset.get_local_path(), strict=False) audio = Audio.from_file(audio_asset.get_local_path(), strict=False)
streaming_input = await RealTimeAudioInput.create(
audio=audio, tokenizer=tokenizer req = TranscriptionRequest(
streaming=StreamingMode.OFFLINE,
audio=RawAudio.from_audio(audio),
language=None,
) )
audio_enc = tokenizer.encode_transcription(req)
buffer = VoxtralRealtimeBuffer(audio_config, audio_enc.tokens)
await buffer.append_audio(audio_enc.audios[0].audio_array)
await buffer.append_audio(None)
request_id = f"session-{i}" request_id = f"session-{i}"
async for resp in async_engine.generate( async for resp in async_engine.generate(
prompt=streaming_input.generator(), prompt=buffer.get_input_stream(),
sampling_params=sampling_params, sampling_params=sampling_params,
request_id=request_id, request_id=request_id,
): ):
tokens = resp.outputs[0].token_ids[-1:] tokens = resp.outputs[0].token_ids[-1:]
output_tokens.extend(tokens) output_tokens.extend(tokens)
await streaming_input.add_tokens(tokens) await buffer.append_tokens(tokens)
output_tokens_list.append(output_tokens) output_tokens_list.append(output_tokens)
texts = [tokenizer.decode(output_tokens) for output_tokens in output_tokens_list] texts = [tokenizer.decode(output_tokens) for output_tokens in output_tokens_list]
texts[1] = texts[1].replace("a base hit", "OBS").replace("oh my", "oh, my") texts[1] = texts[1].replace("a base hit", "OBS").replace("oh my", "oh, my")
assert texts == EXPECTED_TEXT assert texts == EXPECTED_TEXT
...@@ -155,9 +155,7 @@ class VoxtralProcessorAdapter: ...@@ -155,9 +155,7 @@ class VoxtralProcessorAdapter:
assert audio.ndim == 1 assert audio.ndim == 1
if not self._audio_processor.audio_config.is_streaming: if not self._audio_processor.audio_config.is_streaming:
audio = self._audio_processor.pad( audio = self._audio_processor.pad(audio, self.sampling_rate)
audio, self.sampling_rate, is_online_streaming=False
)
audio_tokens = [self.begin_audio_token_id] + [ audio_tokens = [self.begin_audio_token_id] + [
self.audio_token_id self.audio_token_id
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import asyncio import asyncio
import math import math
from collections.abc import AsyncGenerator, Mapping from collections.abc import AsyncGenerator, Iterable, Iterator, Mapping
from typing import Literal from typing import Literal
import numpy as np import numpy as np
...@@ -18,7 +18,7 @@ from mistral_common.tokens.tokenizers.audio import Audio, AudioConfig ...@@ -18,7 +18,7 @@ from mistral_common.tokens.tokenizers.audio import Audio, AudioConfig
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.envs import VLLM_ENGINE_ITERATION_TIMEOUT_S from vllm.envs import VLLM_ENGINE_ITERATION_TIMEOUT_S
from vllm.inputs.data import PromptType, TokensPrompt from vllm.inputs.data import PromptType, StreamingInput, TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsRealtime from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsRealtime
from vllm.model_executor.models.voxtral import ( from vllm.model_executor.models.voxtral import (
...@@ -47,8 +47,6 @@ from .utils import ( ...@@ -47,8 +47,6 @@ from .utils import (
logger = init_logger(__name__) logger = init_logger(__name__)
_PRE_ALLOCATE_BUFFER_SIZE_IN_S = 30
class VoxtralRealtimeMultiModalProcessor(VoxtralMultiModalProcessor): class VoxtralRealtimeMultiModalProcessor(VoxtralMultiModalProcessor):
def __init__( def __init__(
...@@ -130,84 +128,81 @@ def _expand_tensor(input_tensor: torch.Tensor, scaling: int) -> torch.Tensor: ...@@ -130,84 +128,81 @@ def _expand_tensor(input_tensor: torch.Tensor, scaling: int) -> torch.Tensor:
class VoxtralRealtimeBuffer: class VoxtralRealtimeBuffer:
def __init__(self, config: AudioConfig) -> None: def __init__(self, config: AudioConfig, prompt_tokens: list[int]) -> None:
self._config = config self._config = config
self._look_ahead_in_ms = config.streaming_look_ahead_ms _look_ahead_in_ms = self._config.streaming_look_ahead_ms
self._look_back_in_ms = config.streaming_look_back_ms _look_back_in_ms = self._config.streaming_look_back_ms
self._look_ahead_in_samples = self._ms_to_samples(_look_ahead_in_ms)
self._sampling_rate = self._config.sampling_rate self._look_back_in_samples = self._ms_to_samples(_look_back_in_ms)
self._look_ahead = self._get_len_in_samples(self._look_ahead_in_ms) # None signals the end
self._look_back = self._get_len_in_samples(self._look_back_in_ms) self._audio_queue: asyncio.Queue[np.ndarray | None] = asyncio.Queue()
self._streaming_size = self._get_len_in_samples(1000 / self._config.frame_rate) self._leftover: np.ndarray | None = None
self._token_queue: asyncio.Queue[int] = asyncio.Queue()
# mutable objects
streaming_delay = self._get_len_in_samples(self._config.transcription_delay_ms) self._initial_end = len(prompt_tokens) * self._config.raw_audio_length_per_tok
self._start = 0 for token in prompt_tokens:
self._end = streaming_delay + self._streaming_size self._token_queue.put_nowait(token)
# always pre-allocate 30 second buffers def _generate_frame_size_and_num_tokens(self) -> Iterator[tuple[int, int]]:
self._buffer_size = _PRE_ALLOCATE_BUFFER_SIZE_IN_S * self._sampling_rate streaming_step_size = self._ms_to_samples(1000 / self._config.frame_rate)
self._buffer: np.ndarray = np.empty(self._buffer_size, dtype=np.float32) start = 0
self._filled_buffer_len = 0 end = self._initial_end
while True:
@property frame_start = max(start - self._look_back_in_samples, 0)
def start_idx(self): frame_end = end + self._look_ahead_in_samples
return max(self._start - self._look_back, 0) frame_size = frame_end - frame_start
num_tokens = (end - start) / self._config.raw_audio_length_per_tok
@property assert num_tokens.is_integer()
def end_idx(self): yield frame_size, int(num_tokens)
return self._end + self._look_ahead start = end
end += streaming_step_size
@property
def is_audio_complete(self) -> bool: def _ms_to_samples(self, ms: float) -> int:
return self._filled_buffer_len >= self.end_idx len_ = self._config.sampling_rate * ms / 1000
assert len_.is_integer(), len_
def _get_len_in_samples(self, len_in_ms: float) -> int: return int(len_)
_len_in_s = self._sampling_rate * len_in_ms / 1000
assert _len_in_s.is_integer(), _len_in_s async def append_audio(self, audio_array: np.ndarray | None) -> None:
len_in_s = int(_len_in_s) await self._audio_queue.put(audio_array)
return len_in_s async def append_tokens(self, tokens: Iterable[int]) -> None:
for token in tokens:
def _allocate_new_buffer(self) -> None: await self._token_queue.put(token)
# allocate new buffer
new_buffer = np.empty(self._buffer_size, dtype=np.float32) async def get_input_stream(self) -> AsyncGenerator[StreamingInput]:
left_to_copy = max(self._filled_buffer_len - self.start_idx, 0) for frame_size, num_tokens in self._generate_frame_size_and_num_tokens():
next_tokens = [await self._token_queue.get() for _ in range(num_tokens)]
if left_to_copy > 0:
new_buffer[:left_to_copy] = self._buffer[ audio_arrays: list[np.ndarray] = (
self.start_idx : self._filled_buffer_len [self._leftover] if self._leftover is not None else []
]
del self._buffer
self._buffer = new_buffer
self._filled_buffer_len = left_to_copy
self._start = self._look_back
self._end = self._start + self._streaming_size
def write_audio(self, audio: np.ndarray) -> None:
put_end_idx = self._filled_buffer_len + len(audio)
if put_end_idx > self._buffer_size:
self._allocate_new_buffer()
self._buffer[self._filled_buffer_len : self._filled_buffer_len + len(audio)] = (
audio
) )
self._filled_buffer_len += len(audio) while sum(len(arr) for arr in audio_arrays) < frame_size:
arr = await self._audio_queue.get()
def read_audio(self) -> np.ndarray | None: if arr is None:
if not self.is_audio_complete: return
return None audio_arrays.append(arr)
audio_array = np.concatenate(audio_arrays)
frame = audio_array[:frame_size]
# The current stride took look_ahead_in_samples audio of the next sample
# In addition the next sample will take look_back_in_samples audio of
# the current sample => So let's put both of this into the leftover
stride = (
frame_size - self._look_ahead_in_samples - self._look_back_in_samples
)
assert stride > 0, f"{stride=} must be positive"
audio = self._buffer[self.start_idx : self.end_idx] self._leftover = audio_array[stride:]
self._start = self._end
self._end += self._streaming_size
return audio yield StreamingInput(
TokensPrompt(
prompt_token_ids=next_tokens,
multi_modal_data={"audio": (frame, None)},
)
)
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(
...@@ -234,7 +229,7 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim ...@@ -234,7 +229,7 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim
) )
audio_config = self.tokenizer.instruct.audio_encoder.audio_config audio_config = self.tokenizer.instruct.audio_encoder.audio_config
self.n_delay_tokens = audio_config.num_delay_tokens self.n_delay_tokens = audio_config.get_num_delay_tokens()
# for realtime transcription # for realtime transcription
@classmethod @classmethod
...@@ -248,45 +243,47 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim ...@@ -248,45 +243,47 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim
audio_encoder = tokenizer.instruct.audio_encoder audio_encoder = tokenizer.instruct.audio_encoder
config = audio_encoder.audio_config config = audio_encoder.audio_config
buffer = VoxtralRealtimeBuffer(config) # Get prompt tokens (streaming prefix tokens) without encoding audio
is_first_yield = True prompt_tokens = (
tokenizer.instruct.start() + audio_encoder.encode_streaming_tokens()
async for audio in audio_stream:
buffer.write_audio(audio)
while (new_audio := buffer.read_audio()) is not None:
if is_first_yield:
# make sure that input_stream is empty
assert input_stream.empty()
audio = Audio(new_audio, config.sampling_rate, format="wav")
request = TranscriptionRequest(
streaming=StreamingMode.ONLINE,
audio=RawAudio.from_audio(audio),
language=None,
) )
# mistral tokenizer takes care
# of preparing the first prompt inputs # Get left/right padding audio
# and does some left-silence padding left_pad, right_pad = audio_encoder.get_padding_audio()
# for improved performance
audio_enc = tokenizer.mistral.encode_transcription(request) buffer = VoxtralRealtimeBuffer(config, prompt_tokens)
token_ids = audio_enc.tokens # Feed audio with padding into buffer in background
new_audio = audio_enc.audios[0].audio_array async def feed_audio():
yielded_first_chunk = False
is_first_yield = False async for audio_chunk in audio_stream:
else: if not yielded_first_chunk:
# pop last element from input_stream yielded_first_chunk = True
# Prepend left padding before first real audio
await buffer.append_audio(left_pad.audio_array)
await buffer.append_audio(audio_chunk)
# Append right padding at the end
await buffer.append_audio(right_pad.audio_array)
await buffer.append_audio(None) # signal end
# Feed output tokens back into buffer in background
async def feed_tokens():
while True:
all_outputs = await asyncio.wait_for( all_outputs = await asyncio.wait_for(
input_stream.get(), timeout=VLLM_ENGINE_ITERATION_TIMEOUT_S input_stream.get(),
timeout=VLLM_ENGINE_ITERATION_TIMEOUT_S,
) )
token_ids = all_outputs[-1:] await buffer.append_tokens(all_outputs[-1:])
multi_modal_data = {"audio": (new_audio, None)} audio_task = asyncio.create_task(feed_audio())
yield TokensPrompt( token_task = asyncio.create_task(feed_tokens())
prompt_token_ids=token_ids, multi_modal_data=multi_modal_data
) try:
async for streaming_input in buffer.get_input_stream():
yield streaming_input.prompt
finally:
audio_task.cancel()
token_task.cancel()
@property @property
def audio_config(self): def audio_config(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment