Unverified Commit 00c3d68e authored by Peter Salas's avatar Peter Salas Committed by GitHub
Browse files

[Frontend][Core] Add plumbing to support audio language models (#7446)

parent e20233d3
......@@ -112,6 +112,8 @@ autodoc_mock_imports = [
"tensorizer",
"pynvml",
"outlines",
"librosa",
"soundfile",
"gguf",
"lark",
]
......
......@@ -15,14 +15,14 @@ This document walks you through the steps to extend a vLLM model so that it acce
It is assumed that you have already implemented the model in vLLM according to :ref:`these steps <adding_a_new_model>`.
Further update the model as follows:
- Implement the :class:`~vllm.model_executor.models.interfaces.SupportsVision` interface.
- Implement the :class:`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.
.. code-block:: diff
+ from vllm.model_executor.models.interfaces import SupportsVision
+ from vllm.model_executor.models.interfaces import SupportsMultiModal
- class YourModelForImage2Seq(nn.Module):
+ class YourModelForImage2Seq(nn.Module, SupportsVision):
+ class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
.. note::
The model class does not have to be named :code:`*ForCausalLM`.
......@@ -51,11 +51,11 @@ This decorator accepts a function that maps multi-modal inputs to the keyword ar
.. code-block:: diff
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.model_executor.models.interfaces import SupportsMultiModal
+ from vllm.multimodal import MULTIMODAL_REGISTRY
+ @MULTIMODAL_REGISTRY.register_image_input_mapper()
class YourModelForImage2Seq(nn.Module, SupportsVision):
class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
A default mapper is available for each modality in the core vLLM library. This input mapper will be used if you do not provide your own function.
......@@ -72,13 +72,13 @@ and register it via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.regis
.. code-block:: diff
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_image_input_mapper()
+ @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
@INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
class YourModelForImage2Seq(nn.Module, SupportsVision):
class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
Here are some examples:
......@@ -98,13 +98,13 @@ In such cases, you can define your own dummy data by registering a factory metho
.. code-block:: diff
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
+ @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
class YourModelForImage2Seq(nn.Module, SupportsVision):
class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
.. note::
The dummy data should have the maximum possible number of multi-modal tokens, as described in the previous step.
......@@ -128,14 +128,14 @@ You can register input processors via :meth:`INPUT_REGISTRY.register_input_proce
.. code-block:: diff
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
@INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
+ @INPUT_REGISTRY.register_input_processor(<your_input_processor>)
class YourModelForImage2Seq(nn.Module, SupportsVision):
class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
A common use case of input processors is inserting placeholder tokens to leverage the vLLM framework for attention mask generation.
Here are some examples:
......
......@@ -20,4 +20,6 @@ outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
typing_extensions >= 4.10
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
pyzmq
librosa # Required for audio processing
soundfile # Required for audio processing
gguf == 0.9.1
import math
import sys
import time
from typing import Dict, List, Optional, Tuple, Union, cast
from unittest.mock import patch
import librosa
import numpy as np
import openai
import pytest
import requests
import torch
from vllm import ModelRegistry
from vllm.config import MultiModalConfig
from vllm.inputs import INPUT_REGISTRY
from vllm.inputs.data import LLMInputs
from vllm.inputs.registry import InputContext
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
from vllm.multimodal.utils import encode_audio_base64, fetch_audio
from vllm.utils import get_open_port
from ...utils import VLLM_PATH
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()
MODEL_NAME = "facebook/opt-125m"
TEST_AUDIO_URLS = [
"https://upload.wikimedia.org/wikipedia/en/b/bf/Dave_Niehaus_Winning_Call_1995_AL_Division_Series.ogg",
]
def server_function(port):
def fake_input_mapper(ctx: InputContext, data: object):
assert isinstance(data, tuple)
(audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data)
# Resample it to 1 sample per second
audio = librosa.resample(audio, orig_sr=sr, target_sr=1)
return MultiModalInputs({"processed_audio": torch.from_numpy(audio)})
def fake_input_processor(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "audio" not in multi_modal_data:
return llm_inputs
audio, sr = multi_modal_data.get("audio")
audio_duration = math.ceil(len(audio) / sr)
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
cached_get_tokenizer(ctx.model_config.tokenizer),
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
image_token_id=62, # "_"
repeat_count=audio_duration)
return LLMInputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
@MULTIMODAL_REGISTRY.register_input_mapper("audio", fake_input_mapper)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"audio", lambda *_, **__: 100)
@INPUT_REGISTRY.register_input_processor(fake_input_processor)
class FakeAudioModel(OPTForCausalLM, SupportsMultiModal):
def __init__(self, *args, multimodal_config: MultiModalConfig,
**kwargs):
assert multimodal_config is not None
super().__init__(*args, **kwargs)
def forward(
self,
*args,
processed_audio: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
return super().forward(*args, **kwargs)
ModelRegistry.register_model("OPTForCausalLM", FakeAudioModel)
with patch("vllm.entrypoints.chat_utils._mm_token_str",
lambda *_, **__: "_"):
sys.argv = ["placeholder.py"] + \
(f"--model {MODEL_NAME} --gpu-memory-utilization 0.10 "
"--dtype bfloat16 --enforce-eager --api-key token-abc123 "
f"--port {port} --chat-template {chatml_jinja_path} "
"--disable-frontend-multiprocessing").split()
import runpy
runpy.run_module('vllm.entrypoints.openai.api_server',
run_name='__main__')
@pytest.fixture(scope="module")
def client():
port = get_open_port()
ctx = torch.multiprocessing.get_context("spawn")
server = ctx.Process(target=server_function, args=(port, ))
server.start()
MAX_SERVER_START_WAIT_S = 60
client = openai.AsyncOpenAI(
base_url=f"http://localhost:{port}/v1",
api_key="token-abc123",
)
# run health check
health_url = f"http://localhost:{port}/health"
start = time.time()
while True:
try:
if requests.get(health_url).status_code == 200:
break
except Exception as err:
result = server.exitcode
if result is not None:
raise RuntimeError("Server exited unexpectedly.") from err
time.sleep(0.5)
if time.time() - start > MAX_SERVER_START_WAIT_S:
raise RuntimeError("Server failed to start in time.") from err
try:
yield client
finally:
server.kill()
@pytest.fixture(scope="session")
def base64_encoded_audio() -> Dict[str, str]:
return {
audio_url: encode_audio_base64(*fetch_audio(audio_url))
for audio_url in TEST_AUDIO_URLS
}
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
model_name: str, audio_url: str):
messages = [{
"role":
"user",
"content": [
{
"type": "audio_url",
"audio_url": {
"url": audio_url
}
},
{
"type": "text",
"text": "What's happening in this audio?"
},
],
}]
# test single completion
chat_completion = await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=10,
logprobs=True,
top_logprobs=5)
assert len(chat_completion.choices) == 1
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=36, total_tokens=46)
message = choice.message
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 10
assert message.role == "assistant"
messages.append({"role": "assistant", "content": message.content})
# test multi-turn dialogue
messages.append({"role": "user", "content": "express your result in json"})
chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=10,
)
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_single_chat_session_audio_base64encoded(
client: openai.AsyncOpenAI, model_name: str, audio_url: str,
base64_encoded_audio: Dict[str, str]):
messages = [{
"role":
"user",
"content": [
{
"type": "audio_url",
"audio_url": {
"url":
f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}"
}
},
{
"type": "text",
"text": "What's happening in this audio?"
},
],
}]
# test single completion
chat_completion = await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=10,
logprobs=True,
top_logprobs=5)
assert len(chat_completion.choices) == 1
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=36, total_tokens=46)
message = choice.message
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 10
assert message.role == "assistant"
messages.append({"role": "assistant", "content": message.content})
# test multi-turn dialogue
messages.append({"role": "user", "content": "express your result in json"})
chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=10,
)
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_chat_streaming_audio(client: openai.AsyncOpenAI,
model_name: str, audio_url: str):
messages = [{
"role":
"user",
"content": [
{
"type": "audio_url",
"audio_url": {
"url": audio_url
}
},
{
"type": "text",
"text": "What's happening in this audio?"
},
],
}]
# test single completion
chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=10,
temperature=0.0,
)
output = chat_completion.choices[0].message.content
stop_reason = chat_completion.choices[0].finish_reason
# test streaming
stream = await client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=10,
temperature=0.0,
stream=True,
)
chunks: List[str] = []
finish_reason_count = 0
async for chunk in stream:
delta = chunk.choices[0].delta
if delta.role:
assert delta.role == "assistant"
if delta.content:
chunks.append(delta.content)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
# finish reason should only return in last block
assert finish_reason_count == 1
assert chunk.choices[0].finish_reason == stop_reason
assert delta.content
assert "".join(chunks) == output
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str,
audio_url: str):
messages = [{
"role":
"user",
"content": [
{
"type": "audio_url",
"audio_url": {
"url": audio_url
}
},
{
"type": "audio_url",
"audio_url": {
"url": audio_url
}
},
{
"type": "text",
"text": "What's happening in this audio?"
},
],
}]
with pytest.raises(openai.BadRequestError): # test multi-audio input
await client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=10,
temperature=0.0,
)
# the server should still work afterwards
completion = await client.completions.create(
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
completion = completion.choices[0].text
assert completion is not None and len(completion) >= 0
......@@ -2,7 +2,8 @@ import codecs
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any, Awaitable, Iterable, List, Optional, Tuple, Union, cast
from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple,
Union, cast)
# yapf conflicts with isort for this block
# yapf: disable
......@@ -21,12 +22,27 @@ from typing_extensions import Required, TypedDict
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import async_get_and_parse_image
from vllm.multimodal.utils import (async_get_and_parse_audio,
async_get_and_parse_image)
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
class AudioURL(TypedDict, total=False):
url: Required[str]
"""
Either a URL of the audio or a data URL with base64 encoded audio data.
"""
class ChatCompletionContentPartAudioParam(TypedDict, total=False):
audio_url: Required[AudioURL]
type: Required[Literal["audio_url"]]
"""The type of the content part."""
class CustomChatCompletionContentPartParam(TypedDict, total=False):
__pydantic_config__ = ConfigDict(extra="allow") # type: ignore
......@@ -35,6 +51,7 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
ChatCompletionContentPartParam = Union[OpenAIChatCompletionContentPartParam,
ChatCompletionContentPartAudioParam,
CustomChatCompletionContentPartParam]
......@@ -97,10 +114,11 @@ def load_chat_template(
@lru_cache(maxsize=None)
def _image_token_str(model_config: ModelConfig,
tokenizer: PreTrainedTokenizer) -> Optional[str]:
def _mm_token_str(model_config: ModelConfig, tokenizer: PreTrainedTokenizer,
modality: Literal["image", "audio"]) -> Optional[str]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
if modality == "image":
model_type = model_config.hf_config.model_type
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
......@@ -114,17 +132,23 @@ def _image_token_str(model_config: ModelConfig,
return tokenizer.decode(model_config.hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat"):
return "<image>"
raise TypeError(f"Unknown model type: {model_type}")
elif modality == "audio":
raise TypeError("No audio models are supported yet.")
else:
raise TypeError(f"Unknown modality: {modality}")
# TODO: Let user specify how to insert image tokens into prompt
# TODO: Let user specify how to insert multimodal tokens into prompt
# (similar to chat template)
def _get_full_image_text_prompt(image_token_str: str, text_prompt: str) -> str:
"""Combine image and text prompts for vision language model"""
def _get_full_multimodal_text_prompt(placeholder_token_str: str,
text_prompt: str) -> str:
"""Combine multimodal prompts for a multimodal language model"""
# NOTE: For now we assume all model architectures use the same
# image + text prompt format. This may change in the future.
return f"{image_token_str}\n{text_prompt}"
# placeholder + text prompt format. This may change in the future.
return f"{placeholder_token_str}\n{text_prompt}"
def _parse_chat_message_content_parts(
......@@ -135,6 +159,7 @@ def _parse_chat_message_content_parts(
) -> ChatMessageParseResult:
texts: List[str] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
modality: Literal["image", "audio"] = "image"
for part in parts:
part_type = part["type"]
......@@ -142,9 +167,10 @@ def _parse_chat_message_content_parts(
text = cast(ChatCompletionContentPartTextParam, part)["text"]
texts.append(text)
elif part_type == "image_url":
modality = "image"
if len(mm_futures) > 0:
raise NotImplementedError(
"Multiple 'image_url' input is currently not supported.")
"Multiple multimodal inputs is currently not supported.")
image_url = cast(ChatCompletionContentPartImageParam,
part)["image_url"]
......@@ -156,21 +182,32 @@ def _parse_chat_message_content_parts(
image_future = async_get_and_parse_image(image_url["url"])
mm_futures.append(image_future)
elif part_type == "audio_url":
modality = "audio"
if len(mm_futures) > 0:
raise NotImplementedError(
"Multiple multimodal inputs is currently not supported.")
audio_url = cast(ChatCompletionContentPartAudioParam,
part)["audio_url"]
audio_future = async_get_and_parse_audio(audio_url["url"])
mm_futures.append(audio_future)
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
text_prompt = "\n".join(texts)
if mm_futures:
image_token_str = _image_token_str(model_config, tokenizer)
if image_token_str is not None:
if image_token_str in text_prompt:
placeholder_token_str = _mm_token_str(model_config, tokenizer,
modality)
if placeholder_token_str is not None:
if placeholder_token_str in text_prompt:
logger.warning(
"Detected image token string in the text prompt. "
"Detected multi-modal token string in the text prompt. "
"Skipping prompt formatting.")
else:
text_prompt = _get_full_image_text_prompt(
image_token_str=image_token_str,
text_prompt = _get_full_multimodal_text_prompt(
placeholder_token_str=placeholder_token_str,
text_prompt=text_prompt,
)
......
......@@ -44,6 +44,7 @@ if TYPE_CHECKING:
VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
VLLM_AUDIO_FETCH_TIMEOUT: int = 5
VLLM_TARGET_DEVICE: str = "cuda"
MAX_JOBS: Optional[str] = None
NVCC_THREADS: Optional[str] = None
......@@ -321,6 +322,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_IMAGE_FETCH_TIMEOUT":
lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")),
# Timeout for fetching audio when serving multimodal models
# Default is 5 seconds
"VLLM_AUDIO_FETCH_TIMEOUT":
lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "5")),
# Path to the XLA persistent cache directory.
# Only used for XLA devices such as TPUs.
"VLLM_XLA_CACHE_PATH":
......
......@@ -38,7 +38,7 @@ from vllm.model_executor.model_loader.weight_utils import (
safetensors_weights_iterator)
from vllm.model_executor.models.interfaces import (has_inner_state,
supports_lora,
supports_vision)
supports_multimodal)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available
......@@ -131,7 +131,7 @@ def _get_model_initialization_kwargs(
"be added in the future. If this is important to you, "
"please open an issue on github.")
if supports_vision(model_class):
if supports_multimodal(model_class):
if multimodal_config is None:
raise ValueError("Provide vision related configurations "
"through LLM entrypoint or engine arguments.")
......
......@@ -20,8 +20,8 @@ from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
from .blip import (BlipVisionModel, dummy_image_for_blip,
get_max_blip_image_tokens)
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings
from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
......@@ -457,7 +457,7 @@ def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2)
@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self,
config: Blip2Config,
......@@ -621,8 +621,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
vision_embeddings,
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
BLIP2_IMAGE_TOKEN_ID)
input_ids = None
......
......@@ -35,7 +35,7 @@ from vllm.multimodal.image import (cached_get_tokenizer,
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
from vllm.utils import print_warning_once
from .interfaces import SupportsVision
from .interfaces import SupportsMultiModal
logger = init_logger(__name__)
......@@ -886,7 +886,7 @@ class ChameleonModel(nn.Module):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon)
@INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon)
class ChameleonForConditionalGeneration(nn.Module, SupportsVision):
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(
self,
......
......@@ -40,8 +40,8 @@ from vllm.multimodal.image import (cached_get_image_processor,
cached_get_tokenizer)
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings
from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings
logger = init_logger(__name__)
......@@ -209,7 +209,7 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
class FuyuForCausalLM(nn.Module, SupportsVision):
class FuyuForCausalLM(nn.Module, SupportsMultiModal):
def __init__(self,
config: FuyuConfig,
......@@ -271,8 +271,8 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
vision_embeddings,
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)
else:
......
......@@ -10,12 +10,15 @@ logger = init_logger(__name__)
@runtime_checkable
class SupportsVision(Protocol):
"""The interface required for all vision language models (VLMs)."""
class SupportsMultiModal(Protocol):
"""
The interface required for all multimodal (vision or audio) language
models.
"""
supports_vision: ClassVar[Literal[True]] = True
supports_multimodal: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports vision inputs.
A flag that indicates this model supports multimodal inputs.
Note:
There is no need to redefine this flag if this class is in the
......@@ -29,30 +32,31 @@ class SupportsVision(Protocol):
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsVisionType(Protocol):
supports_vision: Literal[True]
class _SupportsMultiModalType(Protocol):
supports_multimodal: Literal[True]
def __call__(self, *, multimodal_config: MultiModalConfig) -> None:
...
@overload
def supports_vision(model: Type[object]) -> TypeIs[Type[SupportsVision]]:
def supports_multimodal(
model: Type[object]) -> TypeIs[Type[SupportsMultiModal]]:
...
@overload
def supports_vision(model: object) -> TypeIs[SupportsVision]:
def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
...
def supports_vision(
def supports_multimodal(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsVision]], TypeIs[SupportsVision]]:
) -> Union[TypeIs[Type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
if isinstance(model, type):
return isinstance(model, _SupportsVisionType)
return isinstance(model, _SupportsMultiModalType)
return isinstance(model, SupportsVision)
return isinstance(model, SupportsMultiModal)
@runtime_checkable
......
......@@ -27,9 +27,9 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_num_patches)
from .interfaces import SupportsVision
from .interfaces import SupportsMultiModal
from .utils import (filter_weights, init_vllm_registered_model,
merge_vision_embeddings)
merge_multimodal_embeddings)
IMG_START = '<img>'
IMG_END = '</img>'
......@@ -292,7 +292,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl)
@INPUT_REGISTRY.register_input_processor(input_processor_for_internvl)
class InternVLChatModel(nn.Module, SupportsVision):
class InternVLChatModel(nn.Module, SupportsMultiModal):
def __init__(self,
config: PretrainedConfig,
......@@ -451,8 +451,8 @@ class InternVLChatModel(nn.Module, SupportsVision):
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
vision_embeddings,
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.img_context_token_id)
input_ids = None
else:
......
......@@ -19,12 +19,12 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (CLIPVisionModel, dummy_image_for_clip,
dummy_seq_data_for_clip, get_max_clip_image_tokens,
input_processor_for_clip)
from .interfaces import SupportsVision
from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
input_processor_for_siglip)
from .utils import (filter_weights, init_vllm_registered_model,
merge_vision_embeddings)
merge_multimodal_embeddings)
class LlavaImagePixelInputs(TypedDict):
......@@ -181,7 +181,7 @@ def _init_vision_tower(hf_config: LlavaConfig):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsVision):
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self,
config: LlavaConfig,
......@@ -338,7 +338,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_vision_embeddings(
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
......
......@@ -23,13 +23,13 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (CLIPVisionModel, dummy_image_for_clip,
dummy_seq_data_for_clip, get_clip_image_feature_size,
get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsVision
from .interfaces import SupportsMultiModal
from .llava import LlavaMultiModalProjector
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (filter_weights, init_vllm_registered_model,
merge_vision_embeddings)
merge_multimodal_embeddings)
logger = init_logger(__name__)
......@@ -275,7 +275,7 @@ def _init_vision_tower(hf_config: LlavaNextConfig):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self,
config: LlavaNextConfig,
......@@ -571,7 +571,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_vision_embeddings(
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
......
......@@ -48,7 +48,7 @@ from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.minicpm import MiniCPMModel
from vllm.model_executor.models.qwen2 import Qwen2Model
......@@ -479,7 +479,7 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
return llm_inputs
class MiniCPMVBaseModel(nn.Module, SupportsVision):
class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
"""
The abstract class of MiniCPMV can only be inherited, but cannot be
instantiated.
......
......@@ -19,10 +19,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SamplerOutput
from .interfaces import SupportsVision
from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .utils import merge_vision_embeddings
from .utils import merge_multimodal_embeddings
logger = init_logger(__name__)
......@@ -130,7 +130,7 @@ class PaliGemmaMultiModalProjector(nn.Module):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self,
config: PaliGemmaConfig,
......@@ -244,7 +244,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
inputs_embeds = merge_vision_embeddings(
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
......
......@@ -42,8 +42,8 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
input_processor_for_clip)
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings
from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings
logger = init_logger(__name__)
......@@ -453,7 +453,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsVision):
class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
def __init__(self,
config: PretrainedConfig,
......@@ -568,8 +568,8 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.model.get_input_embeddings(input_ids)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
vision_embeddings,
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)
input_ids = None
else:
......
......@@ -54,41 +54,42 @@ def init_vllm_registered_model(
)
def merge_vision_embeddings(input_ids: torch.Tensor,
def merge_multimodal_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
vision_embeddings: BatchedTensors,
image_token_id: int) -> torch.Tensor:
multimodal_embeddings: BatchedTensors,
placeholder_token_id: int) -> torch.Tensor:
"""
Merge ``vision_embeddings`` into ``inputs_embeds`` by overwriting the
positions in ``inputs_embeds`` corresponding to placeholder image tokens in
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
positions in ``inputs_embeds`` corresponding to placeholder tokens in
``input_ids``.
Note:
This updates ``inputs_embeds`` in place.
"""
mask = (input_ids == image_token_id)
mask = (input_ids == placeholder_token_id)
num_expected_tokens = mask.sum()
if isinstance(vision_embeddings, torch.Tensor):
batch_size, batch_tokens, *_, embed_dim = vision_embeddings.shape
if isinstance(multimodal_embeddings, torch.Tensor):
batch_size, batch_tokens, *_, embed_dim = multimodal_embeddings.shape
total_tokens = batch_size * batch_tokens
if num_expected_tokens != total_tokens:
expr = f"{batch_size} x {batch_tokens}"
raise ValueError(
f"Attempted to assign {expr} = {total_tokens} "
f"image tokens to {num_expected_tokens} placeholders")
f"multimodal tokens to {num_expected_tokens} placeholders")
inputs_embeds[mask] = vision_embeddings.view(total_tokens, embed_dim)
inputs_embeds[mask] = multimodal_embeddings.view(
total_tokens, embed_dim)
else:
size_per_batch = [t.shape[0] for t in vision_embeddings]
size_per_batch = [t.shape[0] for t in multimodal_embeddings]
total_tokens = sum(size_per_batch)
if num_expected_tokens != total_tokens:
expr = ' + '.join(map(str, size_per_batch))
raise ValueError(
f"Attempted to assign {expr} = {total_tokens} "
f"image tokens to {num_expected_tokens} placeholders")
f"multimodal tokens to {num_expected_tokens} placeholders")
inputs_embeds[mask] = torch.cat(vision_embeddings)
inputs_embeds[mask] = torch.cat(multimodal_embeddings)
return inputs_embeds
......
from vllm.inputs.registry import InputContext
from vllm.multimodal.base import MultiModalInputs, MultiModalPlugin
class AudioPlugin(MultiModalPlugin):
"""Plugin for audio data."""
def get_data_key(self) -> str:
return "audio"
def _default_input_mapper(self, ctx: InputContext,
data: object) -> MultiModalInputs:
raise NotImplementedError("There is no default audio input mapper")
def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
raise NotImplementedError(
"There is no default maximum multimodal tokens")
......@@ -3,8 +3,9 @@ from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from typing import Any, Callable, Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Type, TypedDict, TypeVar, Union, cast
from typing import Tuple, Type, TypedDict, TypeVar, Union, cast
import numpy as np
import torch
import torch.types
from PIL import Image
......@@ -121,6 +122,9 @@ class MultiModalDataBuiltins(TypedDict, total=False):
image: Image.Image
"""The input image."""
audio: Tuple[np.ndarray, Union[int, float]]
"""The input audio and its sampling rate."""
MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]]
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment