Unverified Commit e7e3e6d2 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Voxtral (#20970)


Signed-off-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
parent 4ffd963f
......@@ -10,7 +10,7 @@ on HuggingFace model repository.
import os
from dataclasses import asdict
from typing import NamedTuple, Optional
from typing import Any, NamedTuple, Optional
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
......@@ -30,7 +30,9 @@ question_per_audio_count = {
class ModelRequestData(NamedTuple):
engine_args: EngineArgs
prompt: str
prompt: Optional[str] = None
prompt_token_ids: Optional[dict[str, list[int]]] = None
multi_modal_data: Optional[dict[str, Any]] = None
stop_token_ids: Optional[list[int]] = None
lora_requests: Optional[list[LoRARequest]] = None
......@@ -40,6 +42,60 @@ class ModelRequestData(NamedTuple):
# Unless specified, these settings have been tested to work on a single L4.
# Voxtral
def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
from mistral_common.audio import Audio
from mistral_common.protocol.instruct.messages import (
AudioChunk,
RawAudio,
TextChunk,
UserMessage,
)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
model_name = "mistralai/Voxtral-Mini-3B-2507"
tokenizer = MistralTokenizer.from_hf_hub(model_name)
engine_args = EngineArgs(
model=model_name,
max_model_len=8192,
max_num_seqs=2,
limit_mm_per_prompt={"audio": audio_count},
config_format="mistral",
load_format="mistral",
tokenizer_mode="mistral",
enforce_eager=True,
enable_chunked_prefill=False,
)
text_chunk = TextChunk(text=question)
audios = [
Audio.from_file(str(audio_assets[i].get_local_path()), strict=False)
for i in range(audio_count)
]
audio_chunks = [
AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios
]
messages = [UserMessage(content=[*audio_chunks, text_chunk])]
req = ChatCompletionRequest(messages=messages, model=model_name)
tokens = tokenizer.encode_chat_completion(req)
prompt_ids, audios = tokens.tokens, tokens.audios
audios_and_sr = [(au.audio_array, au.sampling_rate) for au in audios]
multi_modal_data = {"audio": audios_and_sr}
return ModelRequestData(
engine_args=engine_args,
prompt_token_ids=prompt_ids,
multi_modal_data=multi_modal_data,
)
# Granite Speech
def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
# NOTE - the setting in this example are somehat different than what is
......@@ -243,6 +299,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
model_example_map = {
"voxtral": run_voxtral,
"granite_speech": run_granite_speech,
"minicpmo": run_minicpmo,
"phi4_mm": run_phi4mm,
......@@ -311,16 +368,24 @@ def main(args):
temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
)
mm_data = {}
if audio_count > 0:
mm_data = {
"audio": [
asset.audio_and_sample_rate for asset in audio_assets[:audio_count]
]
}
mm_data = req_data.multi_modal_data
if not mm_data:
mm_data = {}
if audio_count > 0:
mm_data = {
"audio": [
asset.audio_and_sample_rate for asset in audio_assets[:audio_count]
]
}
assert args.num_prompts > 0
inputs = {"prompt": req_data.prompt, "multi_modal_data": mm_data}
inputs = {"multi_modal_data": mm_data}
if req_data.prompt:
inputs["prompt"] = req_data.prompt
else:
inputs["prompt_token_ids"] = req_data.prompt_token_ids
if args.num_prompts > 1:
# Batch inference
inputs = [inputs] * args.num_prompts
......
......@@ -33,7 +33,7 @@ pyzmq >= 25.0.0
msgspec
gguf >= 0.13.0
importlib_metadata; python_version < '3.10'
mistral_common[opencv] >= 1.6.2
mistral_common[opencv] >= 1.8.0
opencv-python-headless >= 4.11.0 # required for video IO
pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
......
......@@ -23,7 +23,7 @@ jiwer # required for audio tests
timm # required for internvl test
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
mistral_common[opencv] >= 1.6.2 # required for pixtral test
mistral_common[opencv] >= 1.8.0 # required for voxtral test
num2words # required for smolvlm test
opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
......
......@@ -28,7 +28,7 @@ torchvision==0.22.0
transformers_stream_generator # required for qwen-vl test
mamba_ssm # required for plamo2 test
matplotlib # required for qwen-vl test
mistral_common[opencv] >= 1.7.0 # required for pixtral test
mistral_common[opencv] >= 1.8.0 # required for voxtral test
num2words # required for smolvlm test
opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
......
......@@ -305,7 +305,7 @@ mbstrdecoder==1.1.3
# typepy
mdurl==0.1.2
# via markdown-it-py
mistral-common==1.7.0
mistral-common==1.8.0
# via -r requirements/test.in
more-itertools==10.5.0
# via lm-eval
......@@ -518,6 +518,8 @@ pyasn1-modules==0.4.2
# via google-auth
pybind11==2.13.6
# via lm-eval
pycountry==24.6.1
# via pydantic-extra-types
pycparser==2.22
# via cffi
pycryptodomex==3.22.0
......@@ -528,9 +530,12 @@ pydantic==2.11.5
# datamodel-code-generator
# mistral-common
# mteb
# pydantic-extra-types
# ray
pydantic-core==2.33.2
# via pydantic
pydantic-extra-types==2.10.5
# via mistral-common
pygments==2.18.0
# via rich
pyparsing==3.2.0
......@@ -835,6 +840,7 @@ typing-extensions==4.12.2
# pqdm
# pydantic
# pydantic-core
# pydantic-extra-types
# torch
# typer
# typing-inspection
......
......@@ -692,7 +692,8 @@ setup(
"tensorizer": ["tensorizer==2.10.1"],
"fastsafetensors": ["fastsafetensors >= 0.1.10"],
"runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
"audio": ["librosa", "soundfile"], # Required for audio processing
"audio": ["librosa", "soundfile",
"mistral_common[audio]"], # Required for audio processing
"video": [] # Kept for backwards compatibility
},
cmdclass=cmdclass,
......
......@@ -17,6 +17,11 @@ from vllm.assets.audio import AudioAsset
from ...utils import RemoteOpenAIServer
MISTRAL_FORMAT_ARGS = [
"--tokenizer_mode", "mistral", "--config_format", "mistral",
"--load_format", "mistral"
]
@pytest.fixture
def mary_had_lamb():
......@@ -33,9 +38,18 @@ def winning_call():
@pytest.mark.asyncio
async def test_basic_audio(mary_had_lamb):
model_name = "openai/whisper-large-v3-turbo"
@pytest.mark.parametrize(
"model_name",
["openai/whisper-large-v3-turbo", "mistralai/Voxtral-Mini-3B-2507"])
async def test_basic_audio(mary_had_lamb, model_name):
server_args = ["--enforce-eager"]
if model_name.startswith("mistralai"):
server_args += MISTRAL_FORMAT_ARGS
# TODO(PATRICK) - REMOVE AFTER RELEASE
return # skip for now
# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
......@@ -65,10 +79,13 @@ async def test_bad_requests(mary_had_lamb):
@pytest.mark.asyncio
async def test_long_audio_request(mary_had_lamb):
model_name = "openai/whisper-large-v3-turbo"
@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3-turbo"])
async def test_long_audio_request(mary_had_lamb, model_name):
server_args = ["--enforce-eager"]
if model_name.startswith("openai"):
return
mary_had_lamb.seek(0)
audio, sr = librosa.load(mary_had_lamb)
# Add small silence after each audio for repeatability in the split process
......@@ -87,7 +104,8 @@ async def test_long_audio_request(mary_had_lamb):
response_format="text",
temperature=0.0)
out = json.loads(transcription)['text']
assert out.count("Mary had a little lamb") == 10
counts = out.count("Mary had a little lamb")
assert counts == 10, counts
@pytest.mark.asyncio
......
......@@ -440,6 +440,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
tokenizer="Isotr0py/Florence-2-tokenizer", # noqa: E501
trust_remote_code=True), # noqa: E501
"MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501
"VoxtralForConditionalGeneration": _HfExamplesInfo("mistralai/Voxtral-Mini-3B-2507", is_available_online=False, tokenizer_mode="mistral"), # noqa: E501
"WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501
# [Cross-encoder]
......@@ -513,4 +514,4 @@ class HfExampleModels:
raise ValueError(f"No example model defined for {model_id}")
HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
\ No newline at end of file
HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
......@@ -112,6 +112,7 @@ class OpenAISpeechToText(OpenAIServing):
prompt = self.model_cls.get_generation_prompt(
audio=chunk,
stt_config=self.asr_config,
model_config=self.model_config,
language=lang,
task_type=self.task_type,
request_prompt=request.prompt)
......
......@@ -722,7 +722,8 @@ class SupportsTranscription(Protocol):
@classmethod
def get_generation_prompt(cls, audio: np.ndarray,
stt_config: SpeechToTextConfig, language: str,
stt_config: SpeechToTextConfig,
model_config: ModelConfig, language: str,
task_type: str,
request_prompt: str) -> PromptType:
"""Get the prompt for the ASR model.
......
......@@ -231,6 +231,7 @@ _MULTIMODAL_MODELS = {
"Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
"TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501
"Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501
"VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501
# [Encoder-decoder]
"Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
......
This diff is collapsed.
......@@ -3,6 +3,7 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from contextlib import nullcontext
from typing import Optional, TypedDict, Union, cast
import numpy as np
......@@ -13,6 +14,7 @@ from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
from transformers.models.whisper.modeling_whisper import sinusoids
from vllm.attention import Attention, AttentionType
from vllm.attention.layer import MultiHeadAttention
from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig,
VllmConfig)
from vllm.distributed import get_tensor_model_parallel_world_size
......@@ -26,6 +28,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
......@@ -178,6 +181,7 @@ class WhisperAttention(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
standalone_encoder: bool = False,
):
super().__init__()
self.embed_dim = embed_dim
......@@ -213,16 +217,24 @@ class WhisperAttention(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
attn_type=self.attn_type,
)
if standalone_encoder:
self.attn = MultiHeadAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
)
else:
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
attn_type=self.attn_type,
)
def _init_qkv(
self,
......@@ -357,7 +369,11 @@ class WhisperMLP(nn.Module):
class WhisperEncoderLayer(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
is_standalone_encoder: bool = False):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
......@@ -371,6 +387,7 @@ class WhisperEncoderLayer(nn.Module):
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
standalone_encoder=is_standalone_encoder,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.mlp = WhisperMLP(
......@@ -462,10 +479,16 @@ class WhisperDecoderLayer(nn.Module):
class WhisperEncoder(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
is_standalone_encoder: bool = False,
init_in_fp32: bool = False):
super().__init__()
config = vllm_config.model_config.hf_config
embed_dim = config.d_model
self.is_standalone_encoder = is_standalone_encoder
self.num_mel_bins = config.num_mel_bins
self.max_source_positions = config.max_source_positions
self.embed_scale = (math.sqrt(embed_dim)
......@@ -480,17 +503,25 @@ class WhisperEncoder(nn.Module):
kernel_size=3,
stride=2,
padding=1)
self.embed_positions = nn.Embedding(self.max_source_positions,
embed_dim)
self.start_layer, self.end_layer, self.layers = make_layers(
config.encoder_layers,
lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config,
prefix=f"{prefix}.layers"),
prefix=f"{prefix}.layers",
is_standalone_encoder=
is_standalone_encoder),
prefix=f"{prefix}.layers",
)
self.layer_norm = nn.LayerNorm(config.d_model)
with torch.no_grad():
maybe_fp32_init_ctx = set_default_torch_dtype(
torch.float32) if init_in_fp32 else nullcontext()
with (
torch.no_grad(),
maybe_fp32_init_ctx,
):
self.embed_positions = nn.Embedding(self.max_source_positions,
embed_dim)
self.embed_positions.weight.copy_(
sinusoids(*self.embed_positions.weight.shape))
......@@ -499,8 +530,10 @@ class WhisperEncoder(nn.Module):
for features in input_features:
embeds = nn.functional.gelu(self.conv1(features))
embeds = nn.functional.gelu(self.conv2(embeds))
embeds = embeds.permute(1, 0)
embeds = embeds + self.embed_positions.weight[:embeds.size(0), :]
embeds = embeds.transpose(-1, -2)
embeds = (embeds +
self.embed_positions.weight[:embeds.size(-2), :]).to(
embeds.dtype)
hidden_states.append(embeds)
hidden_states = torch.cat(hidden_states)
......@@ -792,10 +825,14 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
f"or {list(ISO639_1_OTHER_LANGS.values())}")
@classmethod
def get_generation_prompt(cls, audio: np.ndarray,
stt_config: SpeechToTextConfig, language: str,
task_type: str,
request_prompt: str) -> PromptType:
def get_generation_prompt(
cls,
audio: np.ndarray,
model_config: ModelConfig, # not needed here
stt_config: SpeechToTextConfig,
language: str,
task_type: str,
request_prompt: str) -> PromptType:
prompt = {
"encoder_prompt": {
# Whisper does not support encoder prompt.
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from transformers import PretrainedConfig
from transformers import PretrainedConfig, WhisperConfig
from vllm.logger import init_logger
......@@ -24,9 +24,21 @@ def adapt_config_dict(config_dict: dict[str, Any],
if bool(config_dict.get("yarn")):
config_dict = _remap_mistral_yarn_args(config_dict)
if bool((config_dict.get("multimodal") or {}).get("vision_encoder_args")
or config_dict.get("vision_encoder")):
is_vision = ((config_dict.get("multimodal")
or {}).get("vision_encoder_args")
or config_dict.get("vision_encoder"))
is_audio = bool(
((config_dict.get("multimodal") or {}).get("whisper_model_args")
or {}).get("encoder_args"))
assert not (is_vision and is_audio), \
"Vision and audio are mutually exclusive"
if is_vision:
config_dict = _remap_mistral_vision_args(config_dict)
if is_audio:
config_dict = _remap_mistral_audio_args(config_dict)
config = PretrainedConfig.from_dict(config_dict)
......@@ -118,3 +130,35 @@ def _remap_mistral_quantization_args(config: dict) -> dict:
config["quantization_config"] = quantization_config
return config
def _remap_mistral_audio_args(config: dict) -> dict:
whisper_args = config["multimodal"].pop("whisper_model_args")
encoder_args = whisper_args["encoder_args"]
downsample_args = whisper_args["downsample_args"]
quant_config = config.get("quantization_config")
config = {
"model_type":
"whixtral",
"architectures": ["VoxtralForConditionalGeneration"],
"text_config":
PretrainedConfig.from_dict(config),
"audio_config":
WhisperConfig(
num_mel_bins=encoder_args["audio_encoding_args"]["num_mel_bins"],
window_size=encoder_args["audio_encoding_args"]["window_size"],
sampling_rate=encoder_args["audio_encoding_args"]["sampling_rate"],
hop_length=encoder_args["audio_encoding_args"]["hop_length"],
downsample_factor=downsample_args["downsample_factor"],
d_model=encoder_args["dim"],
encoder_layers=encoder_args["n_layers"],
encoder_ffn_dim=encoder_args["hidden_dim"],
encoder_attention_heads=encoder_args["n_heads"],
vocab_size=encoder_args["vocab_size"],
max_source_positions=encoder_args["max_source_positions"],
)
}
if quant_config:
config["quantization_config"] = quant_config
return config
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment