Unverified Commit 00c3d68e authored by Peter Salas's avatar Peter Salas Committed by GitHub
Browse files

[Frontend][Core] Add plumbing to support audio language models (#7446)

parent e20233d3
...@@ -112,6 +112,8 @@ autodoc_mock_imports = [ ...@@ -112,6 +112,8 @@ autodoc_mock_imports = [
"tensorizer", "tensorizer",
"pynvml", "pynvml",
"outlines", "outlines",
"librosa",
"soundfile",
"gguf", "gguf",
"lark", "lark",
] ]
......
...@@ -15,14 +15,14 @@ This document walks you through the steps to extend a vLLM model so that it acce ...@@ -15,14 +15,14 @@ This document walks you through the steps to extend a vLLM model so that it acce
It is assumed that you have already implemented the model in vLLM according to :ref:`these steps <adding_a_new_model>`. It is assumed that you have already implemented the model in vLLM according to :ref:`these steps <adding_a_new_model>`.
Further update the model as follows: Further update the model as follows:
- Implement the :class:`~vllm.model_executor.models.interfaces.SupportsVision` interface. - Implement the :class:`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.
.. code-block:: diff .. code-block:: diff
+ from vllm.model_executor.models.interfaces import SupportsVision + from vllm.model_executor.models.interfaces import SupportsMultiModal
- class YourModelForImage2Seq(nn.Module): - class YourModelForImage2Seq(nn.Module):
+ class YourModelForImage2Seq(nn.Module, SupportsVision): + class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
.. note:: .. note::
The model class does not have to be named :code:`*ForCausalLM`. The model class does not have to be named :code:`*ForCausalLM`.
...@@ -51,11 +51,11 @@ This decorator accepts a function that maps multi-modal inputs to the keyword ar ...@@ -51,11 +51,11 @@ This decorator accepts a function that maps multi-modal inputs to the keyword ar
.. code-block:: diff .. code-block:: diff
from vllm.model_executor.models.interfaces import SupportsVision from vllm.model_executor.models.interfaces import SupportsMultiModal
+ from vllm.multimodal import MULTIMODAL_REGISTRY + from vllm.multimodal import MULTIMODAL_REGISTRY
+ @MULTIMODAL_REGISTRY.register_image_input_mapper() + @MULTIMODAL_REGISTRY.register_image_input_mapper()
class YourModelForImage2Seq(nn.Module, SupportsVision): class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
A default mapper is available for each modality in the core vLLM library. This input mapper will be used if you do not provide your own function. A default mapper is available for each modality in the core vLLM library. This input mapper will be used if you do not provide your own function.
...@@ -72,13 +72,13 @@ and register it via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.regis ...@@ -72,13 +72,13 @@ and register it via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.regis
.. code-block:: diff .. code-block:: diff
from vllm.inputs import INPUT_REGISTRY from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsVision from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_image_input_mapper()
+ @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>) + @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
@INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>) @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
class YourModelForImage2Seq(nn.Module, SupportsVision): class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
Here are some examples: Here are some examples:
...@@ -98,13 +98,13 @@ In such cases, you can define your own dummy data by registering a factory metho ...@@ -98,13 +98,13 @@ In such cases, you can define your own dummy data by registering a factory metho
.. code-block:: diff .. code-block:: diff
from vllm.inputs import INPUT_REGISTRY from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsVision from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>) @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
+ @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>) + @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
class YourModelForImage2Seq(nn.Module, SupportsVision): class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
.. note:: .. note::
The dummy data should have the maximum possible number of multi-modal tokens, as described in the previous step. The dummy data should have the maximum possible number of multi-modal tokens, as described in the previous step.
...@@ -128,14 +128,14 @@ You can register input processors via :meth:`INPUT_REGISTRY.register_input_proce ...@@ -128,14 +128,14 @@ You can register input processors via :meth:`INPUT_REGISTRY.register_input_proce
.. code-block:: diff .. code-block:: diff
from vllm.inputs import INPUT_REGISTRY from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsVision from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>) @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
@INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>) @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
+ @INPUT_REGISTRY.register_input_processor(<your_input_processor>) + @INPUT_REGISTRY.register_input_processor(<your_input_processor>)
class YourModelForImage2Seq(nn.Module, SupportsVision): class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
A common use case of input processors is inserting placeholder tokens to leverage the vLLM framework for attention mask generation. A common use case of input processors is inserting placeholder tokens to leverage the vLLM framework for attention mask generation.
Here are some examples: Here are some examples:
......
...@@ -20,4 +20,6 @@ outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0 ...@@ -20,4 +20,6 @@ outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
typing_extensions >= 4.10 typing_extensions >= 4.10
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
pyzmq pyzmq
librosa # Required for audio processing
soundfile # Required for audio processing
gguf == 0.9.1 gguf == 0.9.1
import math
import sys
import time
from typing import Dict, List, Optional, Tuple, Union, cast
from unittest.mock import patch
import librosa
import numpy as np
import openai
import pytest
import requests
import torch
from vllm import ModelRegistry
from vllm.config import MultiModalConfig
from vllm.inputs import INPUT_REGISTRY
from vllm.inputs.data import LLMInputs
from vllm.inputs.registry import InputContext
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
from vllm.multimodal.utils import encode_audio_base64, fetch_audio
from vllm.utils import get_open_port
from ...utils import VLLM_PATH
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()
MODEL_NAME = "facebook/opt-125m"
TEST_AUDIO_URLS = [
"https://upload.wikimedia.org/wikipedia/en/b/bf/Dave_Niehaus_Winning_Call_1995_AL_Division_Series.ogg",
]
def server_function(port):
def fake_input_mapper(ctx: InputContext, data: object):
assert isinstance(data, tuple)
(audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data)
# Resample it to 1 sample per second
audio = librosa.resample(audio, orig_sr=sr, target_sr=1)
return MultiModalInputs({"processed_audio": torch.from_numpy(audio)})
def fake_input_processor(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "audio" not in multi_modal_data:
return llm_inputs
audio, sr = multi_modal_data.get("audio")
audio_duration = math.ceil(len(audio) / sr)
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
cached_get_tokenizer(ctx.model_config.tokenizer),
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
image_token_id=62, # "_"
repeat_count=audio_duration)
return LLMInputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
@MULTIMODAL_REGISTRY.register_input_mapper("audio", fake_input_mapper)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"audio", lambda *_, **__: 100)
@INPUT_REGISTRY.register_input_processor(fake_input_processor)
class FakeAudioModel(OPTForCausalLM, SupportsMultiModal):
def __init__(self, *args, multimodal_config: MultiModalConfig,
**kwargs):
assert multimodal_config is not None
super().__init__(*args, **kwargs)
def forward(
self,
*args,
processed_audio: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
return super().forward(*args, **kwargs)
ModelRegistry.register_model("OPTForCausalLM", FakeAudioModel)
with patch("vllm.entrypoints.chat_utils._mm_token_str",
lambda *_, **__: "_"):
sys.argv = ["placeholder.py"] + \
(f"--model {MODEL_NAME} --gpu-memory-utilization 0.10 "
"--dtype bfloat16 --enforce-eager --api-key token-abc123 "
f"--port {port} --chat-template {chatml_jinja_path} "
"--disable-frontend-multiprocessing").split()
import runpy
runpy.run_module('vllm.entrypoints.openai.api_server',
run_name='__main__')
@pytest.fixture(scope="module")
def client():
port = get_open_port()
ctx = torch.multiprocessing.get_context("spawn")
server = ctx.Process(target=server_function, args=(port, ))
server.start()
MAX_SERVER_START_WAIT_S = 60
client = openai.AsyncOpenAI(
base_url=f"http://localhost:{port}/v1",
api_key="token-abc123",
)
# run health check
health_url = f"http://localhost:{port}/health"
start = time.time()
while True:
try:
if requests.get(health_url).status_code == 200:
break
except Exception as err:
result = server.exitcode
if result is not None:
raise RuntimeError("Server exited unexpectedly.") from err
time.sleep(0.5)
if time.time() - start > MAX_SERVER_START_WAIT_S:
raise RuntimeError("Server failed to start in time.") from err
try:
yield client
finally:
server.kill()
@pytest.fixture(scope="session")
def base64_encoded_audio() -> Dict[str, str]:
return {
audio_url: encode_audio_base64(*fetch_audio(audio_url))
for audio_url in TEST_AUDIO_URLS
}
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
model_name: str, audio_url: str):
messages = [{
"role":
"user",
"content": [
{
"type": "audio_url",
"audio_url": {
"url": audio_url
}
},
{
"type": "text",
"text": "What's happening in this audio?"
},
],
}]
# test single completion
chat_completion = await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=10,
logprobs=True,
top_logprobs=5)
assert len(chat_completion.choices) == 1
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=36, total_tokens=46)
message = choice.message
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 10
assert message.role == "assistant"
messages.append({"role": "assistant", "content": message.content})
# test multi-turn dialogue
messages.append({"role": "user", "content": "express your result in json"})
chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=10,
)
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_single_chat_session_audio_base64encoded(
client: openai.AsyncOpenAI, model_name: str, audio_url: str,
base64_encoded_audio: Dict[str, str]):
messages = [{
"role":
"user",
"content": [
{
"type": "audio_url",
"audio_url": {
"url":
f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}"
}
},
{
"type": "text",
"text": "What's happening in this audio?"
},
],
}]
# test single completion
chat_completion = await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=10,
logprobs=True,
top_logprobs=5)
assert len(chat_completion.choices) == 1
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=36, total_tokens=46)
message = choice.message
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 10
assert message.role == "assistant"
messages.append({"role": "assistant", "content": message.content})
# test multi-turn dialogue
messages.append({"role": "user", "content": "express your result in json"})
chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=10,
)
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_chat_streaming_audio(client: openai.AsyncOpenAI,
model_name: str, audio_url: str):
messages = [{
"role":
"user",
"content": [
{
"type": "audio_url",
"audio_url": {
"url": audio_url
}
},
{
"type": "text",
"text": "What's happening in this audio?"
},
],
}]
# test single completion
chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=10,
temperature=0.0,
)
output = chat_completion.choices[0].message.content
stop_reason = chat_completion.choices[0].finish_reason
# test streaming
stream = await client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=10,
temperature=0.0,
stream=True,
)
chunks: List[str] = []
finish_reason_count = 0
async for chunk in stream:
delta = chunk.choices[0].delta
if delta.role:
assert delta.role == "assistant"
if delta.content:
chunks.append(delta.content)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
# finish reason should only return in last block
assert finish_reason_count == 1
assert chunk.choices[0].finish_reason == stop_reason
assert delta.content
assert "".join(chunks) == output
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str,
audio_url: str):
messages = [{
"role":
"user",
"content": [
{
"type": "audio_url",
"audio_url": {
"url": audio_url
}
},
{
"type": "audio_url",
"audio_url": {
"url": audio_url
}
},
{
"type": "text",
"text": "What's happening in this audio?"
},
],
}]
with pytest.raises(openai.BadRequestError): # test multi-audio input
await client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=10,
temperature=0.0,
)
# the server should still work afterwards
completion = await client.completions.create(
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
completion = completion.choices[0].text
assert completion is not None and len(completion) >= 0
...@@ -2,7 +2,8 @@ import codecs ...@@ -2,7 +2,8 @@ import codecs
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
from typing import Any, Awaitable, Iterable, List, Optional, Tuple, Union, cast from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple,
Union, cast)
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
...@@ -21,12 +22,27 @@ from typing_extensions import Required, TypedDict ...@@ -21,12 +22,27 @@ from typing_extensions import Required, TypedDict
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import async_get_and_parse_image from vllm.multimodal.utils import (async_get_and_parse_audio,
async_get_and_parse_image)
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__) logger = init_logger(__name__)
class AudioURL(TypedDict, total=False):
url: Required[str]
"""
Either a URL of the audio or a data URL with base64 encoded audio data.
"""
class ChatCompletionContentPartAudioParam(TypedDict, total=False):
audio_url: Required[AudioURL]
type: Required[Literal["audio_url"]]
"""The type of the content part."""
class CustomChatCompletionContentPartParam(TypedDict, total=False): class CustomChatCompletionContentPartParam(TypedDict, total=False):
__pydantic_config__ = ConfigDict(extra="allow") # type: ignore __pydantic_config__ = ConfigDict(extra="allow") # type: ignore
...@@ -35,6 +51,7 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False): ...@@ -35,6 +51,7 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
ChatCompletionContentPartParam = Union[OpenAIChatCompletionContentPartParam, ChatCompletionContentPartParam = Union[OpenAIChatCompletionContentPartParam,
ChatCompletionContentPartAudioParam,
CustomChatCompletionContentPartParam] CustomChatCompletionContentPartParam]
...@@ -97,10 +114,11 @@ def load_chat_template( ...@@ -97,10 +114,11 @@ def load_chat_template(
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def _image_token_str(model_config: ModelConfig, def _mm_token_str(model_config: ModelConfig, tokenizer: PreTrainedTokenizer,
tokenizer: PreTrainedTokenizer) -> Optional[str]: modality: Literal["image", "audio"]) -> Optional[str]:
# TODO: Let user specify how to insert image tokens into prompt # TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template) # (similar to chat template)
if modality == "image":
model_type = model_config.hf_config.model_type model_type = model_config.hf_config.model_type
if model_type == "phi3_v": if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer # Workaround since this token is not defined in the tokenizer
...@@ -114,17 +132,23 @@ def _image_token_str(model_config: ModelConfig, ...@@ -114,17 +132,23 @@ def _image_token_str(model_config: ModelConfig,
return tokenizer.decode(model_config.hf_config.image_token_index) return tokenizer.decode(model_config.hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat"): if model_type in ("chameleon", "internvl_chat"):
return "<image>" return "<image>"
raise TypeError(f"Unknown model type: {model_type}") raise TypeError(f"Unknown model type: {model_type}")
elif modality == "audio":
raise TypeError("No audio models are supported yet.")
else:
raise TypeError(f"Unknown modality: {modality}")
# TODO: Let user specify how to insert image tokens into prompt # TODO: Let user specify how to insert multimodal tokens into prompt
# (similar to chat template) # (similar to chat template)
def _get_full_image_text_prompt(image_token_str: str, text_prompt: str) -> str: def _get_full_multimodal_text_prompt(placeholder_token_str: str,
"""Combine image and text prompts for vision language model""" text_prompt: str) -> str:
"""Combine multimodal prompts for a multimodal language model"""
# NOTE: For now we assume all model architectures use the same # NOTE: For now we assume all model architectures use the same
# image + text prompt format. This may change in the future. # placeholder + text prompt format. This may change in the future.
return f"{image_token_str}\n{text_prompt}" return f"{placeholder_token_str}\n{text_prompt}"
def _parse_chat_message_content_parts( def _parse_chat_message_content_parts(
...@@ -135,6 +159,7 @@ def _parse_chat_message_content_parts( ...@@ -135,6 +159,7 @@ def _parse_chat_message_content_parts(
) -> ChatMessageParseResult: ) -> ChatMessageParseResult:
texts: List[str] = [] texts: List[str] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = [] mm_futures: List[Awaitable[MultiModalDataDict]] = []
modality: Literal["image", "audio"] = "image"
for part in parts: for part in parts:
part_type = part["type"] part_type = part["type"]
...@@ -142,9 +167,10 @@ def _parse_chat_message_content_parts( ...@@ -142,9 +167,10 @@ def _parse_chat_message_content_parts(
text = cast(ChatCompletionContentPartTextParam, part)["text"] text = cast(ChatCompletionContentPartTextParam, part)["text"]
texts.append(text) texts.append(text)
elif part_type == "image_url": elif part_type == "image_url":
modality = "image"
if len(mm_futures) > 0: if len(mm_futures) > 0:
raise NotImplementedError( raise NotImplementedError(
"Multiple 'image_url' input is currently not supported.") "Multiple multimodal inputs is currently not supported.")
image_url = cast(ChatCompletionContentPartImageParam, image_url = cast(ChatCompletionContentPartImageParam,
part)["image_url"] part)["image_url"]
...@@ -156,21 +182,32 @@ def _parse_chat_message_content_parts( ...@@ -156,21 +182,32 @@ def _parse_chat_message_content_parts(
image_future = async_get_and_parse_image(image_url["url"]) image_future = async_get_and_parse_image(image_url["url"])
mm_futures.append(image_future) mm_futures.append(image_future)
elif part_type == "audio_url":
modality = "audio"
if len(mm_futures) > 0:
raise NotImplementedError(
"Multiple multimodal inputs is currently not supported.")
audio_url = cast(ChatCompletionContentPartAudioParam,
part)["audio_url"]
audio_future = async_get_and_parse_audio(audio_url["url"])
mm_futures.append(audio_future)
else: else:
raise NotImplementedError(f"Unknown part type: {part_type}") raise NotImplementedError(f"Unknown part type: {part_type}")
text_prompt = "\n".join(texts) text_prompt = "\n".join(texts)
if mm_futures: if mm_futures:
image_token_str = _image_token_str(model_config, tokenizer) placeholder_token_str = _mm_token_str(model_config, tokenizer,
if image_token_str is not None: modality)
if image_token_str in text_prompt: if placeholder_token_str is not None:
if placeholder_token_str in text_prompt:
logger.warning( logger.warning(
"Detected image token string in the text prompt. " "Detected multi-modal token string in the text prompt. "
"Skipping prompt formatting.") "Skipping prompt formatting.")
else: else:
text_prompt = _get_full_image_text_prompt( text_prompt = _get_full_multimodal_text_prompt(
image_token_str=image_token_str, placeholder_token_str=placeholder_token_str,
text_prompt=text_prompt, text_prompt=text_prompt,
) )
......
...@@ -44,6 +44,7 @@ if TYPE_CHECKING: ...@@ -44,6 +44,7 @@ if TYPE_CHECKING:
VLLM_WORKER_MULTIPROC_METHOD: str = "fork" VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
VLLM_IMAGE_FETCH_TIMEOUT: int = 5 VLLM_IMAGE_FETCH_TIMEOUT: int = 5
VLLM_AUDIO_FETCH_TIMEOUT: int = 5
VLLM_TARGET_DEVICE: str = "cuda" VLLM_TARGET_DEVICE: str = "cuda"
MAX_JOBS: Optional[str] = None MAX_JOBS: Optional[str] = None
NVCC_THREADS: Optional[str] = None NVCC_THREADS: Optional[str] = None
...@@ -321,6 +322,11 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -321,6 +322,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_IMAGE_FETCH_TIMEOUT": "VLLM_IMAGE_FETCH_TIMEOUT":
lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")), lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")),
# Timeout for fetching audio when serving multimodal models
# Default is 5 seconds
"VLLM_AUDIO_FETCH_TIMEOUT":
lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "5")),
# Path to the XLA persistent cache directory. # Path to the XLA persistent cache directory.
# Only used for XLA devices such as TPUs. # Only used for XLA devices such as TPUs.
"VLLM_XLA_CACHE_PATH": "VLLM_XLA_CACHE_PATH":
......
...@@ -38,7 +38,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -38,7 +38,7 @@ from vllm.model_executor.model_loader.weight_utils import (
safetensors_weights_iterator) safetensors_weights_iterator)
from vllm.model_executor.models.interfaces import (has_inner_state, from vllm.model_executor.models.interfaces import (has_inner_state,
supports_lora, supports_lora,
supports_vision) supports_multimodal)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
...@@ -131,7 +131,7 @@ def _get_model_initialization_kwargs( ...@@ -131,7 +131,7 @@ def _get_model_initialization_kwargs(
"be added in the future. If this is important to you, " "be added in the future. If this is important to you, "
"please open an issue on github.") "please open an issue on github.")
if supports_vision(model_class): if supports_multimodal(model_class):
if multimodal_config is None: if multimodal_config is None:
raise ValueError("Provide vision related configurations " raise ValueError("Provide vision related configurations "
"through LLM entrypoint or engine arguments.") "through LLM entrypoint or engine arguments.")
......
...@@ -20,8 +20,8 @@ from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData ...@@ -20,8 +20,8 @@ from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
from .blip import (BlipVisionModel, dummy_image_for_blip, from .blip import (BlipVisionModel, dummy_image_for_blip,
get_max_blip_image_tokens) get_max_blip_image_tokens)
from .interfaces import SupportsVision from .interfaces import SupportsMultiModal
from .utils import merge_vision_embeddings from .utils import merge_multimodal_embeddings
_KEYS_TO_MODIFY_MAPPING = { _KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head", "language_model.lm_head": "lm_head",
...@@ -457,7 +457,7 @@ def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -457,7 +457,7 @@ def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2)
@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2) @INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
class Blip2ForConditionalGeneration(nn.Module, SupportsVision): class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self, def __init__(self,
config: Blip2Config, config: Blip2Config,
...@@ -621,8 +621,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision): ...@@ -621,8 +621,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds, inputs_embeds = merge_multimodal_embeddings(
vision_embeddings, input_ids, inputs_embeds, vision_embeddings,
BLIP2_IMAGE_TOKEN_ID) BLIP2_IMAGE_TOKEN_ID)
input_ids = None input_ids = None
......
...@@ -35,7 +35,7 @@ from vllm.multimodal.image import (cached_get_tokenizer, ...@@ -35,7 +35,7 @@ from vllm.multimodal.image import (cached_get_tokenizer,
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
from .interfaces import SupportsVision from .interfaces import SupportsMultiModal
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -886,7 +886,7 @@ class ChameleonModel(nn.Module): ...@@ -886,7 +886,7 @@ class ChameleonModel(nn.Module):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon)
@INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon) @INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon)
class ChameleonForConditionalGeneration(nn.Module, SupportsVision): class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__( def __init__(
self, self,
......
...@@ -40,8 +40,8 @@ from vllm.multimodal.image import (cached_get_image_processor, ...@@ -40,8 +40,8 @@ from vllm.multimodal.image import (cached_get_image_processor,
cached_get_tokenizer) cached_get_tokenizer)
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
from .interfaces import SupportsVision from .interfaces import SupportsMultiModal
from .utils import merge_vision_embeddings from .utils import merge_multimodal_embeddings
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -209,7 +209,7 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object): ...@@ -209,7 +209,7 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu) @INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
class FuyuForCausalLM(nn.Module, SupportsVision): class FuyuForCausalLM(nn.Module, SupportsMultiModal):
def __init__(self, def __init__(self,
config: FuyuConfig, config: FuyuConfig,
...@@ -271,8 +271,8 @@ class FuyuForCausalLM(nn.Module, SupportsVision): ...@@ -271,8 +271,8 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
if image_input is not None: if image_input is not None:
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.embed_tokens(input_ids) inputs_embeds = self.language_model.model.embed_tokens(input_ids)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds, inputs_embeds = merge_multimodal_embeddings(
vision_embeddings, input_ids, inputs_embeds, vision_embeddings,
self.image_token_id) self.image_token_id)
else: else:
......
...@@ -10,12 +10,15 @@ logger = init_logger(__name__) ...@@ -10,12 +10,15 @@ logger = init_logger(__name__)
@runtime_checkable @runtime_checkable
class SupportsVision(Protocol): class SupportsMultiModal(Protocol):
"""The interface required for all vision language models (VLMs).""" """
The interface required for all multimodal (vision or audio) language
models.
"""
supports_vision: ClassVar[Literal[True]] = True supports_multimodal: ClassVar[Literal[True]] = True
""" """
A flag that indicates this model supports vision inputs. A flag that indicates this model supports multimodal inputs.
Note: Note:
There is no need to redefine this flag if this class is in the There is no need to redefine this flag if this class is in the
...@@ -29,30 +32,31 @@ class SupportsVision(Protocol): ...@@ -29,30 +32,31 @@ class SupportsVision(Protocol):
# We can't use runtime_checkable with ClassVar for issubclass checks # We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead # so we need to treat the class as an instance and use isinstance instead
@runtime_checkable @runtime_checkable
class _SupportsVisionType(Protocol): class _SupportsMultiModalType(Protocol):
supports_vision: Literal[True] supports_multimodal: Literal[True]
def __call__(self, *, multimodal_config: MultiModalConfig) -> None: def __call__(self, *, multimodal_config: MultiModalConfig) -> None:
... ...
@overload @overload
def supports_vision(model: Type[object]) -> TypeIs[Type[SupportsVision]]: def supports_multimodal(
model: Type[object]) -> TypeIs[Type[SupportsMultiModal]]:
... ...
@overload @overload
def supports_vision(model: object) -> TypeIs[SupportsVision]: def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
... ...
def supports_vision( def supports_multimodal(
model: Union[Type[object], object], model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsVision]], TypeIs[SupportsVision]]: ) -> Union[TypeIs[Type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
if isinstance(model, type): if isinstance(model, type):
return isinstance(model, _SupportsVisionType) return isinstance(model, _SupportsMultiModalType)
return isinstance(model, SupportsVision) return isinstance(model, SupportsMultiModal)
@runtime_checkable @runtime_checkable
......
...@@ -27,9 +27,9 @@ from vllm.sequence import IntermediateTensors, SamplerOutput ...@@ -27,9 +27,9 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_num_patches) get_clip_num_patches)
from .interfaces import SupportsVision from .interfaces import SupportsMultiModal
from .utils import (filter_weights, init_vllm_registered_model, from .utils import (filter_weights, init_vllm_registered_model,
merge_vision_embeddings) merge_multimodal_embeddings)
IMG_START = '<img>' IMG_START = '<img>'
IMG_END = '</img>' IMG_END = '</img>'
...@@ -292,7 +292,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int): ...@@ -292,7 +292,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl)
@INPUT_REGISTRY.register_input_processor(input_processor_for_internvl) @INPUT_REGISTRY.register_input_processor(input_processor_for_internvl)
class InternVLChatModel(nn.Module, SupportsVision): class InternVLChatModel(nn.Module, SupportsMultiModal):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
...@@ -451,8 +451,8 @@ class InternVLChatModel(nn.Module, SupportsVision): ...@@ -451,8 +451,8 @@ class InternVLChatModel(nn.Module, SupportsVision):
inputs_embeds = self.language_model.model.get_input_embeddings( inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids) input_ids)
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds, inputs_embeds = merge_multimodal_embeddings(
vision_embeddings, input_ids, inputs_embeds, vision_embeddings,
self.img_context_token_id) self.img_context_token_id)
input_ids = None input_ids = None
else: else:
......
...@@ -19,12 +19,12 @@ from vllm.sequence import IntermediateTensors, SamplerOutput ...@@ -19,12 +19,12 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (CLIPVisionModel, dummy_image_for_clip, from .clip import (CLIPVisionModel, dummy_image_for_clip,
dummy_seq_data_for_clip, get_max_clip_image_tokens, dummy_seq_data_for_clip, get_max_clip_image_tokens,
input_processor_for_clip) input_processor_for_clip)
from .interfaces import SupportsVision from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens, dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
input_processor_for_siglip) input_processor_for_siglip)
from .utils import (filter_weights, init_vllm_registered_model, from .utils import (filter_weights, init_vllm_registered_model,
merge_vision_embeddings) merge_multimodal_embeddings)
class LlavaImagePixelInputs(TypedDict): class LlavaImagePixelInputs(TypedDict):
...@@ -181,7 +181,7 @@ def _init_vision_tower(hf_config: LlavaConfig): ...@@ -181,7 +181,7 @@ def _init_vision_tower(hf_config: LlavaConfig):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava) @INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsVision): class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self, def __init__(self,
config: LlavaConfig, config: LlavaConfig,
...@@ -338,7 +338,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision): ...@@ -338,7 +338,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
inputs_embeds = self.language_model.model.get_input_embeddings( inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids) input_ids)
inputs_embeds = merge_vision_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings, input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index) self.config.image_token_index)
......
...@@ -23,13 +23,13 @@ from vllm.sequence import IntermediateTensors, SamplerOutput ...@@ -23,13 +23,13 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (CLIPVisionModel, dummy_image_for_clip, from .clip import (CLIPVisionModel, dummy_image_for_clip,
dummy_seq_data_for_clip, get_clip_image_feature_size, dummy_seq_data_for_clip, get_clip_image_feature_size,
get_clip_patch_grid_length, input_processor_for_clip) get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsVision from .interfaces import SupportsMultiModal
from .llava import LlavaMultiModalProjector from .llava import LlavaMultiModalProjector
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_siglip_image_feature_size, dummy_seq_data_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip) get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (filter_weights, init_vllm_registered_model, from .utils import (filter_weights, init_vllm_registered_model,
merge_vision_embeddings) merge_multimodal_embeddings)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -275,7 +275,7 @@ def _init_vision_tower(hf_config: LlavaNextConfig): ...@@ -275,7 +275,7 @@ def _init_vision_tower(hf_config: LlavaNextConfig):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next) @INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsVision): class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self, def __init__(self,
config: LlavaNextConfig, config: LlavaNextConfig,
...@@ -571,7 +571,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision): ...@@ -571,7 +571,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
inputs_embeds = self.language_model.model.get_input_embeddings( inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids) input_ids)
inputs_embeds = merge_vision_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings, input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index) self.config.image_token_index)
......
...@@ -48,7 +48,7 @@ from vllm.model_executor.layers.sampler import Sampler ...@@ -48,7 +48,7 @@ from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsVision from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.minicpm import MiniCPMModel from vllm.model_executor.models.minicpm import MiniCPMModel
from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.models.qwen2 import Qwen2Model
...@@ -479,7 +479,7 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -479,7 +479,7 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
return llm_inputs return llm_inputs
class MiniCPMVBaseModel(nn.Module, SupportsVision): class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
""" """
The abstract class of MiniCPMV can only be inherited, but cannot be The abstract class of MiniCPMV can only be inherited, but cannot be
instantiated. instantiated.
......
...@@ -19,10 +19,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -19,10 +19,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_tokenizer from vllm.multimodal.image import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from .interfaces import SupportsVision from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens) dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .utils import merge_vision_embeddings from .utils import merge_multimodal_embeddings
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -130,7 +130,7 @@ class PaliGemmaMultiModalProjector(nn.Module): ...@@ -130,7 +130,7 @@ class PaliGemmaMultiModalProjector(nn.Module):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma) @INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision): class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self, def __init__(self,
config: PaliGemmaConfig, config: PaliGemmaConfig,
...@@ -244,7 +244,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision): ...@@ -244,7 +244,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
inputs_embeds = merge_vision_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings, input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index) self.config.image_token_index)
......
...@@ -42,8 +42,8 @@ from vllm.sequence import IntermediateTensors, SamplerOutput ...@@ -42,8 +42,8 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
input_processor_for_clip) input_processor_for_clip)
from .interfaces import SupportsVision from .interfaces import SupportsMultiModal
from .utils import merge_vision_embeddings from .utils import merge_multimodal_embeddings
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -453,7 +453,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -453,7 +453,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v) @INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsVision): class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
...@@ -568,8 +568,8 @@ class Phi3VForCausalLM(nn.Module, SupportsVision): ...@@ -568,8 +568,8 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
if image_input is not None: if image_input is not None:
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.model.get_input_embeddings(input_ids) inputs_embeds = self.model.get_input_embeddings(input_ids)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds, inputs_embeds = merge_multimodal_embeddings(
vision_embeddings, input_ids, inputs_embeds, vision_embeddings,
self.image_token_id) self.image_token_id)
input_ids = None input_ids = None
else: else:
......
...@@ -54,41 +54,42 @@ def init_vllm_registered_model( ...@@ -54,41 +54,42 @@ def init_vllm_registered_model(
) )
def merge_vision_embeddings(input_ids: torch.Tensor, def merge_multimodal_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor, inputs_embeds: torch.Tensor,
vision_embeddings: BatchedTensors, multimodal_embeddings: BatchedTensors,
image_token_id: int) -> torch.Tensor: placeholder_token_id: int) -> torch.Tensor:
""" """
Merge ``vision_embeddings`` into ``inputs_embeds`` by overwriting the Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
positions in ``inputs_embeds`` corresponding to placeholder image tokens in positions in ``inputs_embeds`` corresponding to placeholder tokens in
``input_ids``. ``input_ids``.
Note: Note:
This updates ``inputs_embeds`` in place. This updates ``inputs_embeds`` in place.
""" """
mask = (input_ids == image_token_id) mask = (input_ids == placeholder_token_id)
num_expected_tokens = mask.sum() num_expected_tokens = mask.sum()
if isinstance(vision_embeddings, torch.Tensor): if isinstance(multimodal_embeddings, torch.Tensor):
batch_size, batch_tokens, *_, embed_dim = vision_embeddings.shape batch_size, batch_tokens, *_, embed_dim = multimodal_embeddings.shape
total_tokens = batch_size * batch_tokens total_tokens = batch_size * batch_tokens
if num_expected_tokens != total_tokens: if num_expected_tokens != total_tokens:
expr = f"{batch_size} x {batch_tokens}" expr = f"{batch_size} x {batch_tokens}"
raise ValueError( raise ValueError(
f"Attempted to assign {expr} = {total_tokens} " f"Attempted to assign {expr} = {total_tokens} "
f"image tokens to {num_expected_tokens} placeholders") f"multimodal tokens to {num_expected_tokens} placeholders")
inputs_embeds[mask] = vision_embeddings.view(total_tokens, embed_dim) inputs_embeds[mask] = multimodal_embeddings.view(
total_tokens, embed_dim)
else: else:
size_per_batch = [t.shape[0] for t in vision_embeddings] size_per_batch = [t.shape[0] for t in multimodal_embeddings]
total_tokens = sum(size_per_batch) total_tokens = sum(size_per_batch)
if num_expected_tokens != total_tokens: if num_expected_tokens != total_tokens:
expr = ' + '.join(map(str, size_per_batch)) expr = ' + '.join(map(str, size_per_batch))
raise ValueError( raise ValueError(
f"Attempted to assign {expr} = {total_tokens} " f"Attempted to assign {expr} = {total_tokens} "
f"image tokens to {num_expected_tokens} placeholders") f"multimodal tokens to {num_expected_tokens} placeholders")
inputs_embeds[mask] = torch.cat(vision_embeddings) inputs_embeds[mask] = torch.cat(multimodal_embeddings)
return inputs_embeds return inputs_embeds
......
from vllm.inputs.registry import InputContext
from vllm.multimodal.base import MultiModalInputs, MultiModalPlugin
class AudioPlugin(MultiModalPlugin):
"""Plugin for audio data."""
def get_data_key(self) -> str:
return "audio"
def _default_input_mapper(self, ctx: InputContext,
data: object) -> MultiModalInputs:
raise NotImplementedError("There is no default audio input mapper")
def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
raise NotImplementedError(
"There is no default maximum multimodal tokens")
...@@ -3,8 +3,9 @@ from abc import ABC, abstractmethod ...@@ -3,8 +3,9 @@ from abc import ABC, abstractmethod
from collections import UserDict, defaultdict from collections import UserDict, defaultdict
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Type, TypedDict, TypeVar, Union, cast from typing import Tuple, Type, TypedDict, TypeVar, Union, cast
import numpy as np
import torch import torch
import torch.types import torch.types
from PIL import Image from PIL import Image
...@@ -121,6 +122,9 @@ class MultiModalDataBuiltins(TypedDict, total=False): ...@@ -121,6 +122,9 @@ class MultiModalDataBuiltins(TypedDict, total=False):
image: Image.Image image: Image.Image
"""The input image.""" """The input image."""
audio: Tuple[np.ndarray, Union[int, float]]
"""The input audio and its sampling rate."""
MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]] MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]]
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment