Unverified Commit 3a3b06ee authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Improve error message for `is_multimodal` (#30483)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent f4417f84
...@@ -53,6 +53,22 @@ The output embeddings must be one of the following formats: ...@@ -53,6 +53,22 @@ The output embeddings must be one of the following formats:
""" """
def _require_is_multimodal(is_multimodal: Tensor | None) -> Tensor:
"""
A helper function to be used in the context of
[vllm.model_executor.models.interfaces.SupportsMultiModal.embed_input_ids][]
to provide a better error message.
"""
if is_multimodal is None:
raise ValueError(
"`embed_input_ids` now requires `is_multimodal` arg, "
"please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229."
)
return is_multimodal
@runtime_checkable @runtime_checkable
class SupportsMultiModal(Protocol): class SupportsMultiModal(Protocol):
"""The interface required for all multi-modal models.""" """The interface required for all multi-modal models."""
...@@ -190,12 +206,10 @@ class SupportsMultiModal(Protocol): ...@@ -190,12 +206,10 @@ class SupportsMultiModal(Protocol):
if multimodal_embeddings is None or len(multimodal_embeddings) == 0: if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds return inputs_embeds
assert is_multimodal is not None
return _merge_multimodal_embeddings( return _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=_require_is_multimodal(is_multimodal),
) )
......
...@@ -64,6 +64,7 @@ from .interfaces import ( ...@@ -64,6 +64,7 @@ from .interfaces import (
SupportsMultiModal, SupportsMultiModal,
SupportsPP, SupportsPP,
SupportsQuant, SupportsQuant,
_require_is_multimodal,
) )
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
...@@ -687,12 +688,10 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant) ...@@ -687,12 +688,10 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0: if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds return inputs_embeds
assert is_multimodal is not None
return _merge_multimodal_embeddings( return _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings, multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal, is_multimodal=_require_is_multimodal(is_multimodal),
) )
def forward( def forward(
......
...@@ -93,6 +93,7 @@ from .interfaces import ( ...@@ -93,6 +93,7 @@ from .interfaces import (
SupportsMRoPE, SupportsMRoPE,
SupportsMultiModal, SupportsMultiModal,
SupportsPP, SupportsPP,
_require_is_multimodal,
) )
from .qwen2_5_vl import ( from .qwen2_5_vl import (
Qwen2_5_VisionAttention, Qwen2_5_VisionAttention,
...@@ -1572,7 +1573,7 @@ class Qwen3VLForConditionalGeneration( ...@@ -1572,7 +1573,7 @@ class Qwen3VLForConditionalGeneration(
if multimodal_embeddings is None or len(multimodal_embeddings) == 0: if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds return inputs_embeds
assert is_multimodal is not None is_multimodal = _require_is_multimodal(is_multimodal)
if self.use_deepstack: if self.use_deepstack:
( (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment