Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
14fdd21d
Unverified
Commit
14fdd21d
authored
Jun 18, 2025
by
Russell Bryant
Committed by
GitHub
Jun 18, 2025
Browse files
[Core] More fixes to MultiModalEmbeddings type handling (#19715)
Signed-off-by:
Russell Bryant
<
rbryant@redhat.com
>
parent
04fefe7c
Changes
35
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
31 additions
and
16 deletions
+31
-16
vllm/model_executor/models/mllama4.py
vllm/model_executor/models/mllama4.py
+2
-1
vllm/model_executor/models/molmo.py
vllm/model_executor/models/molmo.py
+2
-1
vllm/model_executor/models/ovis.py
vllm/model_executor/models/ovis.py
+2
-1
vllm/model_executor/models/paligemma.py
vllm/model_executor/models/paligemma.py
+2
-1
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+2
-1
vllm/model_executor/models/phi4mm.py
vllm/model_executor/models/phi4mm.py
+2
-1
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+2
-1
vllm/model_executor/models/qwen2_5_omni_thinker.py
vllm/model_executor/models/qwen2_5_omni_thinker.py
+3
-2
vllm/model_executor/models/qwen2_5_vl.py
vllm/model_executor/models/qwen2_5_vl.py
+2
-1
vllm/model_executor/models/qwen2_audio.py
vllm/model_executor/models/qwen2_audio.py
+2
-1
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+2
-1
vllm/model_executor/models/qwen_vl.py
vllm/model_executor/models/qwen_vl.py
+2
-1
vllm/model_executor/models/skyworkr1v.py
vllm/model_executor/models/skyworkr1v.py
+2
-1
vllm/model_executor/models/tarsier.py
vllm/model_executor/models/tarsier.py
+2
-1
vllm/model_executor/models/ultravox.py
vllm/model_executor/models/ultravox.py
+2
-1
No files found.
vllm/model_executor/models/mllama4.py
View file @
14fdd21d
...
...
@@ -808,7 +808,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
if
multimodal_embeddings
is
not
None
and
len
(
multimodal_embeddings
)
!=
0
:
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
...
...
vllm/model_executor/models/molmo.py
View file @
14fdd21d
...
...
@@ -1487,7 +1487,8 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
multimodal_embeddings
:
Optional
[
MultiModalEmbeddings
]
=
None
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
if
multimodal_embeddings
is
not
None
\
and
len
(
multimodal_embeddings
)
!=
0
:
assert
self
.
img_patch_id
is
not
None
inputs_embeds
=
merge_multimodal_embeddings
(
...
...
vllm/model_executor/models/ovis.py
View file @
14fdd21d
...
...
@@ -515,7 +515,8 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
multimodal_embeddings
:
Optional
[
MultiModalEmbeddings
]
=
None
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
llm
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
if
multimodal_embeddings
is
not
None
\
and
len
(
multimodal_embeddings
)
!=
0
:
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
multimodal_embeddings
,
self
.
image_pad_token_id
)
...
...
vllm/model_executor/models/paligemma.py
View file @
14fdd21d
...
...
@@ -364,7 +364,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings
:
Optional
[
MultiModalEmbeddings
]
=
None
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
if
multimodal_embeddings
is
not
None
\
and
len
(
multimodal_embeddings
)
!=
0
:
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
multimodal_embeddings
,
self
.
config
.
image_token_index
)
...
...
vllm/model_executor/models/phi3v.py
View file @
14fdd21d
...
...
@@ -669,7 +669,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
multimodal_embeddings
:
Optional
[
MultiModalEmbeddings
]
=
None
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
if
multimodal_embeddings
:
if
multimodal_embeddings
is
not
None
\
and
len
(
multimodal_embeddings
)
!=
0
:
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
multimodal_embeddings
,
self
.
image_token_id
)
...
...
vllm/model_executor/models/phi4mm.py
View file @
14fdd21d
...
...
@@ -1148,7 +1148,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
multimodal_embeddings
:
Optional
[
MultiModalEmbeddings
]
=
None
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
model
.
embed_tokens
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
if
multimodal_embeddings
is
not
None
and
len
(
multimodal_embeddings
)
!=
0
:
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
multimodal_embeddings
,
[
_IMAGE_PLACEHOLDER_TOKEN_ID
,
_AUDIO_PLACEHOLDER_TOKEN_ID
])
...
...
vllm/model_executor/models/pixtral.py
View file @
14fdd21d
...
...
@@ -423,7 +423,8 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings
:
Optional
[
MultiModalEmbeddings
]
=
None
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
if
multimodal_embeddings
is
not
None
\
and
len
(
multimodal_embeddings
)
!=
0
:
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
...
...
vllm/model_executor/models/qwen2_5_omni_thinker.py
View file @
14fdd21d
...
...
@@ -805,7 +805,8 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
multimodal_embeddings
:
Optional
[
MultiModalEmbeddings
]
=
None
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
if
multimodal_embeddings
is
not
None
\
and
len
(
multimodal_embeddings
)
!=
0
:
# TODO (ywang96): support overlapping modality embeddings so that
# `use_audio_in_video` will work on V1.
...
...
@@ -845,7 +846,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
multimodal_embeddings
:
Optional
[
NestedTensors
]
=
None
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
None
:
if
multimodal_embeddings
is
None
or
len
(
multimodal_embeddings
)
==
0
:
return
inputs_embeds
for
embeddings
,
modality
in
multimodal_embeddings
:
...
...
vllm/model_executor/models/qwen2_5_vl.py
View file @
14fdd21d
...
...
@@ -1046,7 +1046,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings
:
Optional
[
MultiModalEmbeddings
]
=
None
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
if
multimodal_embeddings
is
not
None
\
and
len
(
multimodal_embeddings
)
!=
0
:
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
multimodal_embeddings
,
[
self
.
config
.
image_token_id
,
self
.
config
.
video_token_id
])
...
...
vllm/model_executor/models/qwen2_audio.py
View file @
14fdd21d
...
...
@@ -364,7 +364,8 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings
:
Optional
[
MultiModalEmbeddings
]
=
None
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
if
multimodal_embeddings
is
not
None
\
and
len
(
multimodal_embeddings
)
!=
0
:
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
multimodal_embeddings
,
self
.
config
.
audio_token_index
)
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
14fdd21d
...
...
@@ -1289,7 +1289,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings
:
Optional
[
MultiModalEmbeddings
]
=
None
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
if
multimodal_embeddings
is
not
None
\
and
len
(
multimodal_embeddings
)
!=
0
:
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
multimodal_embeddings
,
[
self
.
config
.
image_token_id
,
self
.
config
.
video_token_id
])
...
...
vllm/model_executor/models/qwen_vl.py
View file @
14fdd21d
...
...
@@ -754,7 +754,8 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
transformer
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
if
multimodal_embeddings
is
not
None
\
and
len
(
multimodal_embeddings
)
!=
0
:
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
multimodal_embeddings
,
self
.
transformer
.
visual
.
image_pad_id
)
...
...
vllm/model_executor/models/skyworkr1v.py
View file @
14fdd21d
...
...
@@ -883,7 +883,8 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
multimodal_embeddings
:
Optional
[
MultiModalEmbeddings
]
=
None
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
if
multimodal_embeddings
is
not
None
\
and
len
(
multimodal_embeddings
)
!=
0
:
assert
self
.
img_context_token_id
is
not
None
self
.
_set_visual_token_mask
(
input_ids
)
inputs_embeds
=
merge_multimodal_embeddings
(
...
...
vllm/model_executor/models/tarsier.py
View file @
14fdd21d
...
...
@@ -598,7 +598,8 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings
:
Optional
[
MultiModalEmbeddings
]
=
None
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
if
multimodal_embeddings
is
not
None
\
and
len
(
multimodal_embeddings
)
!=
0
:
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
...
...
vllm/model_executor/models/ultravox.py
View file @
14fdd21d
...
...
@@ -560,7 +560,8 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
multimodal_embeddings
:
Optional
[
MultiModalEmbeddings
]
=
None
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
if
multimodal_embeddings
is
not
None
\
and
len
(
multimodal_embeddings
)
!=
0
:
# TODO(ywang96): remove this block after v0 is deprecated.
if
not
envs
.
VLLM_USE_V1
:
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment