Unverified Commit e7e3e6d2 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Voxtral (#20970)


Signed-off-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
parent 4ffd963f
...@@ -10,7 +10,7 @@ on HuggingFace model repository. ...@@ -10,7 +10,7 @@ on HuggingFace model repository.
import os import os
from dataclasses import asdict from dataclasses import asdict
from typing import NamedTuple, Optional from typing import Any, NamedTuple, Optional
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from transformers import AutoTokenizer from transformers import AutoTokenizer
...@@ -30,7 +30,9 @@ question_per_audio_count = { ...@@ -30,7 +30,9 @@ question_per_audio_count = {
class ModelRequestData(NamedTuple): class ModelRequestData(NamedTuple):
engine_args: EngineArgs engine_args: EngineArgs
prompt: str prompt: Optional[str] = None
prompt_token_ids: Optional[dict[str, list[int]]] = None
multi_modal_data: Optional[dict[str, Any]] = None
stop_token_ids: Optional[list[int]] = None stop_token_ids: Optional[list[int]] = None
lora_requests: Optional[list[LoRARequest]] = None lora_requests: Optional[list[LoRARequest]] = None
...@@ -40,6 +42,60 @@ class ModelRequestData(NamedTuple): ...@@ -40,6 +42,60 @@ class ModelRequestData(NamedTuple):
# Unless specified, these settings have been tested to work on a single L4. # Unless specified, these settings have been tested to work on a single L4.
# Voxtral
def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
    """Build the engine arguments and a pre-tokenized request for Voxtral.

    The chat request is tokenized here with the mistral_common tokenizer, so
    the returned data carries ``prompt_token_ids`` and ``multi_modal_data``
    and leaves the text ``prompt`` unset.

    Args:
        question: Text placed after the audio chunks in the user message.
        audio_count: Number of entries of ``audio_assets`` to attach.

    Returns:
        A ``ModelRequestData`` holding the engine arguments, the encoded
        prompt tokens and the audio inputs.
    """
    # Imported inside the function so mistral_common is only required when
    # this example is actually run.
    from mistral_common.audio import Audio
    from mistral_common.protocol.instruct.messages import (
        AudioChunk,
        RawAudio,
        TextChunk,
        UserMessage,
    )
    from mistral_common.protocol.instruct.request import ChatCompletionRequest
    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

    model_name = "mistralai/Voxtral-Mini-3B-2507"
    mistral_tokenizer = MistralTokenizer.from_hf_hub(model_name)

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=8192,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
        config_format="mistral",
        load_format="mistral",
        tokenizer_mode="mistral",
        enforce_eager=True,
        enable_chunked_prefill=False,
    )

    question_chunk = TextChunk(text=question)

    # Load the first ``audio_count`` assets from their local paths.
    loaded_clips = []
    for asset_idx in range(audio_count):
        clip_path = str(audio_assets[asset_idx].get_local_path())
        loaded_clips.append(Audio.from_file(clip_path, strict=False))

    # Wrap every clip in a chat-message chunk; audio precedes the question.
    message_content = []
    for clip in loaded_clips:
        message_content.append(AudioChunk(input_audio=RawAudio.from_audio(clip)))
    message_content.append(question_chunk)

    chat_request = ChatCompletionRequest(
        messages=[UserMessage(content=message_content)],
        model=model_name,
    )
    encoded = mistral_tokenizer.encode_chat_completion(chat_request)

    # Use the audio returned by the tokenizer, paired with its sampling rate.
    audio_inputs = []
    for encoded_audio in encoded.audios:
        audio_inputs.append(
            (encoded_audio.audio_array, encoded_audio.sampling_rate)
        )

    return ModelRequestData(
        engine_args=engine_args,
        prompt_token_ids=encoded.tokens,
        multi_modal_data={"audio": audio_inputs},
    )
# Granite Speech # Granite Speech
def run_granite_speech(question: str, audio_count: int) -> ModelRequestData: def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
# NOTE - the settings in this example are somewhat different than what is # NOTE - the settings in this example are somewhat different than what is
...@@ -243,6 +299,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData: ...@@ -243,6 +299,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
model_example_map = { model_example_map = {
"voxtral": run_voxtral,
"granite_speech": run_granite_speech, "granite_speech": run_granite_speech,
"minicpmo": run_minicpmo, "minicpmo": run_minicpmo,
"phi4_mm": run_phi4mm, "phi4_mm": run_phi4mm,
...@@ -311,6 +368,8 @@ def main(args): ...@@ -311,6 +368,8 @@ def main(args):
temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
) )
mm_data = req_data.multi_modal_data
if not mm_data:
mm_data = {} mm_data = {}
if audio_count > 0: if audio_count > 0:
mm_data = { mm_data = {
...@@ -320,7 +379,13 @@ def main(args): ...@@ -320,7 +379,13 @@ def main(args):
} }
assert args.num_prompts > 0 assert args.num_prompts > 0
inputs = {"prompt": req_data.prompt, "multi_modal_data": mm_data} inputs = {"multi_modal_data": mm_data}
if req_data.prompt:
inputs["prompt"] = req_data.prompt
else:
inputs["prompt_token_ids"] = req_data.prompt_token_ids
if args.num_prompts > 1: if args.num_prompts > 1:
# Batch inference # Batch inference
inputs = [inputs] * args.num_prompts inputs = [inputs] * args.num_prompts
......
...@@ -33,7 +33,7 @@ pyzmq >= 25.0.0 ...@@ -33,7 +33,7 @@ pyzmq >= 25.0.0
msgspec msgspec
gguf >= 0.13.0 gguf >= 0.13.0
importlib_metadata; python_version < '3.10' importlib_metadata; python_version < '3.10'
mistral_common[opencv] >= 1.6.2 mistral_common[opencv] >= 1.8.0
opencv-python-headless >= 4.11.0 # required for video IO opencv-python-headless >= 4.11.0 # required for video IO
pyyaml pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
......
...@@ -23,7 +23,7 @@ jiwer # required for audio tests ...@@ -23,7 +23,7 @@ jiwer # required for audio tests
timm # required for internvl test timm # required for internvl test
transformers_stream_generator # required for qwen-vl test transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test matplotlib # required for qwen-vl test
mistral_common[opencv] >= 1.6.2 # required for pixtral test mistral_common[opencv] >= 1.8.0 # required for voxtral test
num2words # required for smolvlm test num2words # required for smolvlm test
opencv-python-headless >= 4.11.0 # required for video test opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test datamodel_code_generator # required for minicpm3 test
......
...@@ -28,7 +28,7 @@ torchvision==0.22.0 ...@@ -28,7 +28,7 @@ torchvision==0.22.0
transformers_stream_generator # required for qwen-vl test transformers_stream_generator # required for qwen-vl test
mamba_ssm # required for plamo2 test mamba_ssm # required for plamo2 test
matplotlib # required for qwen-vl test matplotlib # required for qwen-vl test
mistral_common[opencv] >= 1.7.0 # required for pixtral test mistral_common[opencv] >= 1.8.0 # required for voxtral test
num2words # required for smolvlm test num2words # required for smolvlm test
opencv-python-headless >= 4.11.0 # required for video test opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test datamodel_code_generator # required for minicpm3 test
......
...@@ -305,7 +305,7 @@ mbstrdecoder==1.1.3 ...@@ -305,7 +305,7 @@ mbstrdecoder==1.1.3
# typepy # typepy
mdurl==0.1.2 mdurl==0.1.2
# via markdown-it-py # via markdown-it-py
mistral-common==1.7.0 mistral-common==1.8.0
# via -r requirements/test.in # via -r requirements/test.in
more-itertools==10.5.0 more-itertools==10.5.0
# via lm-eval # via lm-eval
...@@ -518,6 +518,8 @@ pyasn1-modules==0.4.2 ...@@ -518,6 +518,8 @@ pyasn1-modules==0.4.2
# via google-auth # via google-auth
pybind11==2.13.6 pybind11==2.13.6
# via lm-eval # via lm-eval
pycountry==24.6.1
# via pydantic-extra-types
pycparser==2.22 pycparser==2.22
# via cffi # via cffi
pycryptodomex==3.22.0 pycryptodomex==3.22.0
...@@ -528,9 +530,12 @@ pydantic==2.11.5 ...@@ -528,9 +530,12 @@ pydantic==2.11.5
# datamodel-code-generator # datamodel-code-generator
# mistral-common # mistral-common
# mteb # mteb
# pydantic-extra-types
# ray # ray
pydantic-core==2.33.2 pydantic-core==2.33.2
# via pydantic # via pydantic
pydantic-extra-types==2.10.5
# via mistral-common
pygments==2.18.0 pygments==2.18.0
# via rich # via rich
pyparsing==3.2.0 pyparsing==3.2.0
...@@ -835,6 +840,7 @@ typing-extensions==4.12.2 ...@@ -835,6 +840,7 @@ typing-extensions==4.12.2
# pqdm # pqdm
# pydantic # pydantic
# pydantic-core # pydantic-core
# pydantic-extra-types
# torch # torch
# typer # typer
# typing-inspection # typing-inspection
......
...@@ -692,7 +692,8 @@ setup( ...@@ -692,7 +692,8 @@ setup(
"tensorizer": ["tensorizer==2.10.1"], "tensorizer": ["tensorizer==2.10.1"],
"fastsafetensors": ["fastsafetensors >= 0.1.10"], "fastsafetensors": ["fastsafetensors >= 0.1.10"],
"runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"], "runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
"audio": ["librosa", "soundfile"], # Required for audio processing "audio": ["librosa", "soundfile",
"mistral_common[audio]"], # Required for audio processing
"video": [] # Kept for backwards compatibility "video": [] # Kept for backwards compatibility
}, },
cmdclass=cmdclass, cmdclass=cmdclass,
......
...@@ -17,6 +17,11 @@ from vllm.assets.audio import AudioAsset ...@@ -17,6 +17,11 @@ from vllm.assets.audio import AudioAsset
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
# Extra server CLI flags appended for models whose name starts with
# "mistralai"; written one flag/value pair per line.
MISTRAL_FORMAT_ARGS = [
    "--tokenizer_mode", "mistral",
    "--config_format", "mistral",
    "--load_format", "mistral",
]
@pytest.fixture @pytest.fixture
def mary_had_lamb(): def mary_had_lamb():
...@@ -33,9 +38,18 @@ def winning_call(): ...@@ -33,9 +38,18 @@ def winning_call():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_basic_audio(mary_had_lamb): @pytest.mark.parametrize(
model_name = "openai/whisper-large-v3-turbo" "model_name",
["openai/whisper-large-v3-turbo", "mistralai/Voxtral-Mini-3B-2507"])
async def test_basic_audio(mary_had_lamb, model_name):
server_args = ["--enforce-eager"] server_args = ["--enforce-eager"]
if model_name.startswith("mistralai"):
server_args += MISTRAL_FORMAT_ARGS
# TODO(PATRICK) - REMOVE AFTER RELEASE
return # skip for now
# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb. # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
with RemoteOpenAIServer(model_name, server_args) as remote_server: with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client() client = remote_server.get_async_client()
...@@ -65,10 +79,13 @@ async def test_bad_requests(mary_had_lamb): ...@@ -65,10 +79,13 @@ async def test_bad_requests(mary_had_lamb):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_long_audio_request(mary_had_lamb): @pytest.mark.parametrize("model_name", ["openai/whisper-large-v3-turbo"])
model_name = "openai/whisper-large-v3-turbo" async def test_long_audio_request(mary_had_lamb, model_name):
server_args = ["--enforce-eager"] server_args = ["--enforce-eager"]
if model_name.startswith("openai"):
return
mary_had_lamb.seek(0) mary_had_lamb.seek(0)
audio, sr = librosa.load(mary_had_lamb) audio, sr = librosa.load(mary_had_lamb)
# Add small silence after each audio for repeatability in the split process # Add small silence after each audio for repeatability in the split process
...@@ -87,7 +104,8 @@ async def test_long_audio_request(mary_had_lamb): ...@@ -87,7 +104,8 @@ async def test_long_audio_request(mary_had_lamb):
response_format="text", response_format="text",
temperature=0.0) temperature=0.0)
out = json.loads(transcription)['text'] out = json.loads(transcription)['text']
assert out.count("Mary had a little lamb") == 10 counts = out.count("Mary had a little lamb")
assert counts == 10, counts
@pytest.mark.asyncio @pytest.mark.asyncio
......
...@@ -440,6 +440,7 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -440,6 +440,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
tokenizer="Isotr0py/Florence-2-tokenizer", # noqa: E501 tokenizer="Isotr0py/Florence-2-tokenizer", # noqa: E501
trust_remote_code=True), # noqa: E501 trust_remote_code=True), # noqa: E501
"MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501 "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501
"VoxtralForConditionalGeneration": _HfExamplesInfo("mistralai/Voxtral-Mini-3B-2507", is_available_online=False, tokenizer_mode="mistral"), # noqa: E501
"WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501 "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501
# [Cross-encoder] # [Cross-encoder]
......
...@@ -112,6 +112,7 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -112,6 +112,7 @@ class OpenAISpeechToText(OpenAIServing):
prompt = self.model_cls.get_generation_prompt( prompt = self.model_cls.get_generation_prompt(
audio=chunk, audio=chunk,
stt_config=self.asr_config, stt_config=self.asr_config,
model_config=self.model_config,
language=lang, language=lang,
task_type=self.task_type, task_type=self.task_type,
request_prompt=request.prompt) request_prompt=request.prompt)
......
...@@ -722,7 +722,8 @@ class SupportsTranscription(Protocol): ...@@ -722,7 +722,8 @@ class SupportsTranscription(Protocol):
@classmethod @classmethod
def get_generation_prompt(cls, audio: np.ndarray, def get_generation_prompt(cls, audio: np.ndarray,
stt_config: SpeechToTextConfig, language: str, stt_config: SpeechToTextConfig,
model_config: ModelConfig, language: str,
task_type: str, task_type: str,
request_prompt: str) -> PromptType: request_prompt: str) -> PromptType:
"""Get the prompt for the ASR model. """Get the prompt for the ASR model.
......
...@@ -231,6 +231,7 @@ _MULTIMODAL_MODELS = { ...@@ -231,6 +231,7 @@ _MULTIMODAL_MODELS = {
"Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"), "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
"TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501 "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501
"Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501 "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501
"VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501
# [Encoder-decoder] # [Encoder-decoder]
"Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
......
This diff is collapsed.
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from contextlib import nullcontext
from typing import Optional, TypedDict, Union, cast from typing import Optional, TypedDict, Union, cast
import numpy as np import numpy as np
...@@ -13,6 +14,7 @@ from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor, ...@@ -13,6 +14,7 @@ from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
from transformers.models.whisper.modeling_whisper import sinusoids from transformers.models.whisper.modeling_whisper import sinusoids
from vllm.attention import Attention, AttentionType from vllm.attention import Attention, AttentionType
from vllm.attention.layer import MultiHeadAttention
from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig, from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig,
VllmConfig) VllmConfig)
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
...@@ -26,6 +28,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -26,6 +28,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
...@@ -178,6 +181,7 @@ class WhisperAttention(nn.Module): ...@@ -178,6 +181,7 @@ class WhisperAttention(nn.Module):
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
standalone_encoder: bool = False,
): ):
super().__init__() super().__init__()
self.embed_dim = embed_dim self.embed_dim = embed_dim
...@@ -213,6 +217,14 @@ class WhisperAttention(nn.Module): ...@@ -213,6 +217,14 @@ class WhisperAttention(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.out_proj", prefix=f"{prefix}.out_proj",
) )
if standalone_encoder:
self.attn = MultiHeadAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
)
else:
self.attn = Attention( self.attn = Attention(
self.num_heads, self.num_heads,
self.head_dim, self.head_dim,
...@@ -357,7 +369,11 @@ class WhisperMLP(nn.Module): ...@@ -357,7 +369,11 @@ class WhisperMLP(nn.Module):
class WhisperEncoderLayer(nn.Module): class WhisperEncoderLayer(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
is_standalone_encoder: bool = False):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
...@@ -371,6 +387,7 @@ class WhisperEncoderLayer(nn.Module): ...@@ -371,6 +387,7 @@ class WhisperEncoderLayer(nn.Module):
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
standalone_encoder=is_standalone_encoder,
) )
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.mlp = WhisperMLP( self.mlp = WhisperMLP(
...@@ -462,10 +479,16 @@ class WhisperDecoderLayer(nn.Module): ...@@ -462,10 +479,16 @@ class WhisperDecoderLayer(nn.Module):
class WhisperEncoder(nn.Module): class WhisperEncoder(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
is_standalone_encoder: bool = False,
init_in_fp32: bool = False):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
embed_dim = config.d_model embed_dim = config.d_model
self.is_standalone_encoder = is_standalone_encoder
self.num_mel_bins = config.num_mel_bins self.num_mel_bins = config.num_mel_bins
self.max_source_positions = config.max_source_positions self.max_source_positions = config.max_source_positions
self.embed_scale = (math.sqrt(embed_dim) self.embed_scale = (math.sqrt(embed_dim)
...@@ -480,17 +503,25 @@ class WhisperEncoder(nn.Module): ...@@ -480,17 +503,25 @@ class WhisperEncoder(nn.Module):
kernel_size=3, kernel_size=3,
stride=2, stride=2,
padding=1) padding=1)
self.embed_positions = nn.Embedding(self.max_source_positions,
embed_dim)
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.encoder_layers, config.encoder_layers,
lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config, lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config,
prefix=f"{prefix}.layers"), prefix=f"{prefix}.layers",
is_standalone_encoder=
is_standalone_encoder),
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
self.layer_norm = nn.LayerNorm(config.d_model) self.layer_norm = nn.LayerNorm(config.d_model)
with torch.no_grad(): maybe_fp32_init_ctx = set_default_torch_dtype(
torch.float32) if init_in_fp32 else nullcontext()
with (
torch.no_grad(),
maybe_fp32_init_ctx,
):
self.embed_positions = nn.Embedding(self.max_source_positions,
embed_dim)
self.embed_positions.weight.copy_( self.embed_positions.weight.copy_(
sinusoids(*self.embed_positions.weight.shape)) sinusoids(*self.embed_positions.weight.shape))
...@@ -499,8 +530,10 @@ class WhisperEncoder(nn.Module): ...@@ -499,8 +530,10 @@ class WhisperEncoder(nn.Module):
for features in input_features: for features in input_features:
embeds = nn.functional.gelu(self.conv1(features)) embeds = nn.functional.gelu(self.conv1(features))
embeds = nn.functional.gelu(self.conv2(embeds)) embeds = nn.functional.gelu(self.conv2(embeds))
embeds = embeds.permute(1, 0) embeds = embeds.transpose(-1, -2)
embeds = embeds + self.embed_positions.weight[:embeds.size(0), :] embeds = (embeds +
self.embed_positions.weight[:embeds.size(-2), :]).to(
embeds.dtype)
hidden_states.append(embeds) hidden_states.append(embeds)
hidden_states = torch.cat(hidden_states) hidden_states = torch.cat(hidden_states)
...@@ -792,8 +825,12 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, ...@@ -792,8 +825,12 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
f"or {list(ISO639_1_OTHER_LANGS.values())}") f"or {list(ISO639_1_OTHER_LANGS.values())}")
@classmethod @classmethod
def get_generation_prompt(cls, audio: np.ndarray, def get_generation_prompt(
stt_config: SpeechToTextConfig, language: str, cls,
audio: np.ndarray,
model_config: ModelConfig, # not needed here
stt_config: SpeechToTextConfig,
language: str,
task_type: str, task_type: str,
request_prompt: str) -> PromptType: request_prompt: str) -> PromptType:
prompt = { prompt = {
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any from typing import Any
from transformers import PretrainedConfig from transformers import PretrainedConfig, WhisperConfig
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -24,9 +24,21 @@ def adapt_config_dict(config_dict: dict[str, Any], ...@@ -24,9 +24,21 @@ def adapt_config_dict(config_dict: dict[str, Any],
if bool(config_dict.get("yarn")): if bool(config_dict.get("yarn")):
config_dict = _remap_mistral_yarn_args(config_dict) config_dict = _remap_mistral_yarn_args(config_dict)
if bool((config_dict.get("multimodal") or {}).get("vision_encoder_args")
or config_dict.get("vision_encoder")): is_vision = ((config_dict.get("multimodal")
or {}).get("vision_encoder_args")
or config_dict.get("vision_encoder"))
is_audio = bool(
((config_dict.get("multimodal") or {}).get("whisper_model_args")
or {}).get("encoder_args"))
assert not (is_vision and is_audio), \
"Vision and audio are mutually exclusive"
if is_vision:
config_dict = _remap_mistral_vision_args(config_dict) config_dict = _remap_mistral_vision_args(config_dict)
if is_audio:
config_dict = _remap_mistral_audio_args(config_dict)
config = PretrainedConfig.from_dict(config_dict) config = PretrainedConfig.from_dict(config_dict)
...@@ -118,3 +130,35 @@ def _remap_mistral_quantization_args(config: dict) -> dict: ...@@ -118,3 +130,35 @@ def _remap_mistral_quantization_args(config: dict) -> dict:
config["quantization_config"] = quantization_config config["quantization_config"] = quantization_config
return config return config
def _remap_mistral_audio_args(config: dict) -> dict:
    """Remap a Mistral-format audio config into a Voxtral-style config dict.

    ``whisper_model_args`` is popped out of ``config["multimodal"]``, so the
    caller's dict is modified in place. Whatever remains of ``config`` is
    turned into the ``text_config`` entry, and the encoder/downsample
    arguments are passed as keyword arguments to a ``WhisperConfig`` stored
    under ``audio_config``.

    Args:
        config: Mistral-format config dict containing
            ``multimodal.whisper_model_args`` with ``encoder_args`` and
            ``downsample_args``.

    Returns:
        A new dict with ``model_type``, ``architectures``, ``text_config``
        and ``audio_config``, plus ``quantization_config`` when the input
        had a truthy one.

    Raises:
        KeyError: If any of the required keys is missing.
    """
    whisper_model_args = config["multimodal"].pop("whisper_model_args")
    encoder = whisper_model_args["encoder_args"]
    downsample = whisper_model_args["downsample_args"]

    # Read before ``config`` is wrapped so it can be re-exposed at top level.
    quantization = config.get("quantization_config")

    text_config = PretrainedConfig.from_dict(config)

    audio_encoding = encoder["audio_encoding_args"]
    audio_config = WhisperConfig(
        num_mel_bins=audio_encoding["num_mel_bins"],
        window_size=audio_encoding["window_size"],
        sampling_rate=audio_encoding["sampling_rate"],
        hop_length=audio_encoding["hop_length"],
        downsample_factor=downsample["downsample_factor"],
        d_model=encoder["dim"],
        encoder_layers=encoder["n_layers"],
        encoder_ffn_dim=encoder["hidden_dim"],
        encoder_attention_heads=encoder["n_heads"],
        vocab_size=encoder["vocab_size"],
        max_source_positions=encoder["max_source_positions"],
    )

    remapped = {
        "model_type": "whixtral",
        "architectures": ["VoxtralForConditionalGeneration"],
        "text_config": text_config,
        "audio_config": audio_config,
    }
    if quantization:
        remapped["quantization_config"] = quantization
    return remapped
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment