Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
90f9c2eb
Unverified
Commit
90f9c2eb
authored
Jun 16, 2025
by
Russell Bryant
Committed by
GitHub
Jun 16, 2025
Browse files
[V1] Change return type on get_multimodal_embeddings() (#19446)
Signed-off-by:
Russell Bryant
<
rbryant@redhat.com
>
parent
387bdf0a
Changes
37
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
59 additions
and
56 deletions
+59
-56
vllm/model_executor/models/aria.py
vllm/model_executor/models/aria.py
+3
-3
vllm/model_executor/models/aya_vision.py
vllm/model_executor/models/aya_vision.py
+3
-3
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+3
-3
vllm/model_executor/models/chameleon.py
vllm/model_executor/models/chameleon.py
+3
-3
vllm/model_executor/models/deepseek_vl2.py
vllm/model_executor/models/deepseek_vl2.py
+3
-3
vllm/model_executor/models/florence2.py
vllm/model_executor/models/florence2.py
+3
-3
vllm/model_executor/models/fuyu.py
vllm/model_executor/models/fuyu.py
+3
-3
vllm/model_executor/models/gemma3_mm.py
vllm/model_executor/models/gemma3_mm.py
+3
-3
vllm/model_executor/models/glm4v.py
vllm/model_executor/models/glm4v.py
+3
-3
vllm/model_executor/models/granite_speech.py
vllm/model_executor/models/granite_speech.py
+2
-1
vllm/model_executor/models/idefics3.py
vllm/model_executor/models/idefics3.py
+3
-3
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+2
-2
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+3
-2
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+3
-3
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+4
-4
vllm/model_executor/models/llava_next_video.py
vllm/model_executor/models/llava_next_video.py
+3
-3
vllm/model_executor/models/llava_onevision.py
vllm/model_executor/models/llava_onevision.py
+3
-2
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+3
-3
vllm/model_executor/models/minimax_vl_01.py
vllm/model_executor/models/minimax_vl_01.py
+3
-3
vllm/model_executor/models/mistral3.py
vllm/model_executor/models/mistral3.py
+3
-3
No files found.
vllm/model_executor/models/aria.py
View file @
90f9c2eb
...
@@ -601,11 +601,11 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -601,11 +601,11 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
return
self
.
language_model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]
:
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
[]
multimodal_embeddings
=
self
.
_process_image_input
(
image_input
)
multimodal_embeddings
=
self
.
_process_image_input
(
image_input
)
return
multimodal_embeddings
return
multimodal_embeddings
...
...
vllm/model_executor/models/aya_vision.py
View file @
90f9c2eb
...
@@ -406,11 +406,11 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -406,11 +406,11 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
return
self
.
language_model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]
:
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
[]
return
self
.
_process_image_input
(
image_input
,
**
kwargs
)
return
self
.
_process_image_input
(
image_input
,
**
kwargs
)
...
...
vllm/model_executor/models/blip2.py
View file @
90f9c2eb
...
@@ -627,11 +627,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -627,11 +627,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
return
self
.
language_model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]
:
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
[]
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
return
vision_embeddings
return
vision_embeddings
...
...
vllm/model_executor/models/chameleon.py
View file @
90f9c2eb
...
@@ -987,11 +987,11 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -987,11 +987,11 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
model
return
self
.
model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]
:
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
[]
assert
self
.
model
.
vqmodel
is
not
None
assert
self
.
model
.
vqmodel
is
not
None
image_tokens
=
self
.
model
.
get_image_tokens
(
image_input
[
"data"
].
to
(
image_tokens
=
self
.
model
.
get_image_tokens
(
image_input
[
"data"
].
to
(
self
.
config
.
torch_dtype
))
self
.
config
.
torch_dtype
))
...
...
vllm/model_executor/models/deepseek_vl2.py
View file @
90f9c2eb
...
@@ -586,11 +586,11 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -586,11 +586,11 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
return
self
.
language_model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]
:
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
[]
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
return
vision_embeddings
return
vision_embeddings
...
...
vllm/model_executor/models/florence2.py
View file @
90f9c2eb
...
@@ -1032,11 +1032,11 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1032,11 +1032,11 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
return
self
.
language_model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]
:
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
[]
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
return
vision_embeddings
return
vision_embeddings
...
...
vllm/model_executor/models/fuyu.py
View file @
90f9c2eb
...
@@ -324,11 +324,11 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -324,11 +324,11 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
return
self
.
language_model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]
:
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
[]
return
self
.
_process_image_input
(
image_input
)
return
self
.
_process_image_input
(
image_input
)
...
...
vllm/model_executor/models/gemma3_mm.py
View file @
90f9c2eb
...
@@ -568,11 +568,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -568,11 +568,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
return
self
.
language_model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]
:
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
[]
return
self
.
_process_image_input
(
image_input
)
return
self
.
_process_image_input
(
image_input
)
...
...
vllm/model_executor/models/glm4v.py
View file @
90f9c2eb
...
@@ -593,11 +593,11 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
...
@@ -593,11 +593,11 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
transformer
return
self
.
transformer
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]
:
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
[]
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
return
vision_embeddings
return
vision_embeddings
...
...
vllm/model_executor/models/granite_speech.py
View file @
90f9c2eb
...
@@ -706,10 +706,11 @@ class GraniteSpeechForConditionalGeneration(
...
@@ -706,10 +706,11 @@ class GraniteSpeechForConditionalGeneration(
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
self
,
**
kwargs
:
object
,
**
kwargs
:
object
,
)
->
Optional
[
MultiModalEmbeddings
]
:
)
->
MultiModalEmbeddings
:
"""Compute the audio embeddings if audio inputs are present."""
"""Compute the audio embeddings if audio inputs are present."""
audio_input
=
self
.
_parse_and_validate_audio_input
(
**
kwargs
)
audio_input
=
self
.
_parse_and_validate_audio_input
(
**
kwargs
)
if
audio_input
is
None
:
if
audio_input
is
None
:
return
[]
return
None
return
None
audio_features
=
self
.
_process_audio_input
(
audio_input
)
audio_features
=
self
.
_process_audio_input
(
audio_input
)
return
audio_features
return
audio_features
...
...
vllm/model_executor/models/idefics3.py
View file @
90f9c2eb
...
@@ -706,11 +706,11 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -706,11 +706,11 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
model
return
self
.
model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]
:
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
[]
return
self
.
_process_image_input
(
image_input
)
return
self
.
_process_image_input
(
image_input
)
...
...
vllm/model_executor/models/interfaces.py
View file @
90f9c2eb
...
@@ -44,8 +44,8 @@ class SupportsMultiModal(Protocol):
...
@@ -44,8 +44,8 @@ class SupportsMultiModal(Protocol):
MRO of your model class.
MRO of your model class.
"""
"""
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]
:
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
"""
"""
Returns multimodal embeddings generated from multimodal kwargs
Returns multimodal embeddings generated from multimodal kwargs
to be merged with text embeddings.
to be merged with text embeddings.
...
...
vllm/model_executor/models/internvl.py
View file @
90f9c2eb
...
@@ -1304,11 +1304,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -1304,11 +1304,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
return
self
.
language_model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]
:
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
modalities
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
modalities
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
if
not
modalities
:
if
not
modalities
:
return
[]
return
None
return
None
# The result multimodal_embeddings is tuple of tensors, with each
# The result multimodal_embeddings is tuple of tensors, with each
...
...
vllm/model_executor/models/llava.py
View file @
90f9c2eb
...
@@ -659,11 +659,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -659,11 +659,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
return
self
.
language_model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]
:
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
[]
return
self
.
_process_image_input
(
image_input
)
return
self
.
_process_image_input
(
image_input
)
...
...
vllm/model_executor/models/llava_next.py
View file @
90f9c2eb
...
@@ -478,11 +478,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -478,11 +478,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
return
self
.
language_model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]
:
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
[]
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
return
vision_embeddings
return
vision_embeddings
...
@@ -492,7 +492,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -492,7 +492,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings
:
Optional
[
MultiModalEmbeddings
]
=
None
,
multimodal_embeddings
:
Optional
[
MultiModalEmbeddings
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
multimodal_embeddings
is
None
:
if
not
multimodal_embeddings
:
return
self
.
language_model
.
get_input_embeddings
(
input_ids
)
return
self
.
language_model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
embed_multimodal
(
inputs_embeds
=
embed_multimodal
(
...
...
vllm/model_executor/models/llava_next_video.py
View file @
90f9c2eb
...
@@ -401,11 +401,11 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -401,11 +401,11 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
return
self
.
language_model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]
:
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
video_input
=
self
.
_parse_and_validate_video_input
(
**
kwargs
)
video_input
=
self
.
_parse_and_validate_video_input
(
**
kwargs
)
if
video_input
is
None
:
if
video_input
is
None
:
return
None
return
[]
vision_embeddings
=
self
.
_process_video_pixels
(
video_input
)
vision_embeddings
=
self
.
_process_video_pixels
(
video_input
)
return
vision_embeddings
return
vision_embeddings
...
...
vllm/model_executor/models/llava_onevision.py
View file @
90f9c2eb
...
@@ -839,11 +839,12 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -839,11 +839,12 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
return
self
.
language_model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]
:
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
mm_input_by_modality
=
self
.
_parse_and_validate_multimodal_inputs
(
mm_input_by_modality
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
**
kwargs
)
if
not
mm_input_by_modality
:
if
not
mm_input_by_modality
:
return
[]
return
None
return
None
# The result multimodal_embeddings is tuple of tensors, with each
# The result multimodal_embeddings is tuple of tensors, with each
...
...
vllm/model_executor/models/minicpmv.py
View file @
90f9c2eb
...
@@ -878,11 +878,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -878,11 +878,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
llm
return
self
.
llm
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]
:
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
modalities
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
modalities
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
if
not
modalities
:
if
not
modalities
:
return
None
return
[]
return
self
.
_process_multimodal_inputs
(
modalities
)
return
self
.
_process_multimodal_inputs
(
modalities
)
...
...
vllm/model_executor/models/minimax_vl_01.py
View file @
90f9c2eb
...
@@ -318,11 +318,11 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -318,11 +318,11 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
raise
AssertionError
(
"This line should be unreachable."
)
raise
AssertionError
(
"This line should be unreachable."
)
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]
:
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
[]
return
self
.
_process_image_input
(
image_input
)
return
self
.
_process_image_input
(
image_input
)
...
...
vllm/model_executor/models/mistral3.py
View file @
90f9c2eb
...
@@ -495,11 +495,11 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
...
@@ -495,11 +495,11 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
return
self
.
language_model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]
:
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
[]
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment