Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bd2659a5
Unverified
Commit
bd2659a5
authored
Mar 08, 2026
by
Alex Brooks
Committed by
GitHub
Mar 08, 2026
Browse files
Increase Flexibility for OOV Multimodal Token Handling (#34858)
Signed-off-by:
Alex Brooks
<
albrooks@redhat.com
>
parent
90512b2e
Changes
28
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
61 additions
and
59 deletions
+61
-59
vllm/model_executor/models/clip.py
vllm/model_executor/models/clip.py
+0
-4
vllm/model_executor/models/eagle2_5_vl.py
vllm/model_executor/models/eagle2_5_vl.py
+0
-2
vllm/model_executor/models/ernie45_vl.py
vllm/model_executor/models/ernie45_vl.py
+0
-2
vllm/model_executor/models/funasr.py
vllm/model_executor/models/funasr.py
+0
-1
vllm/model_executor/models/gemma3_mm.py
vllm/model_executor/models/gemma3_mm.py
+5
-2
vllm/model_executor/models/gemma3n_mm.py
vllm/model_executor/models/gemma3n_mm.py
+0
-2
vllm/model_executor/models/granite_speech.py
vllm/model_executor/models/granite_speech.py
+6
-3
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+31
-18
vllm/model_executor/models/interns1.py
vllm/model_executor/models/interns1.py
+0
-2
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+0
-2
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+5
-0
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+5
-3
vllm/model_executor/models/molmo2.py
vllm/model_executor/models/molmo2.py
+0
-2
vllm/model_executor/models/nemotron_vl.py
vllm/model_executor/models/nemotron_vl.py
+0
-2
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+0
-2
vllm/model_executor/models/qwen2_5_omni_thinker.py
vllm/model_executor/models/qwen2_5_omni_thinker.py
+9
-3
vllm/model_executor/models/qwen3_5.py
vllm/model_executor/models/qwen3_5.py
+0
-2
vllm/model_executor/models/qwen3_5_mtp.py
vllm/model_executor/models/qwen3_5_mtp.py
+0
-2
vllm/model_executor/models/qwen3_asr.py
vllm/model_executor/models/qwen3_asr.py
+0
-2
vllm/model_executor/models/qwen3_omni_moe_thinker.py
vllm/model_executor/models/qwen3_omni_moe_thinker.py
+0
-3
No files found.
vllm/model_executor/models/clip.py
View file @
bd2659a5
...
@@ -931,13 +931,11 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
...
@@ -931,13 +931,11 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
embed_input_ids
:
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
],
embed_input_ids
:
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
],
*
,
*
,
is_multimodal
:
torch
.
Tensor
|
None
,
is_multimodal
:
torch
.
Tensor
|
None
,
handle_oov_mm_token
:
bool
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
inputs_embeds
=
super
().
_embed_text_input_ids
(
inputs_embeds
=
super
().
_embed_text_input_ids
(
input_ids
,
input_ids
,
embed_input_ids
,
embed_input_ids
,
is_multimodal
=
is_multimodal
,
is_multimodal
=
is_multimodal
,
handle_oov_mm_token
=
handle_oov_mm_token
,
)
)
# NOTE: inputs_embeds in model runner has size text_config.projection_dim
# NOTE: inputs_embeds in model runner has size text_config.projection_dim
...
@@ -966,7 +964,6 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
...
@@ -966,7 +964,6 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
*
,
*
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
handle_oov_mm_token
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
self
.
_is_text_input
=
(
self
.
_is_text_input
=
(
multimodal_embeddings
is
None
or
len
(
multimodal_embeddings
)
==
0
multimodal_embeddings
is
None
or
len
(
multimodal_embeddings
)
==
0
...
@@ -980,7 +977,6 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
...
@@ -980,7 +977,6 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
input_ids
,
input_ids
,
multimodal_embeddings
=
multimodal_embeddings
,
multimodal_embeddings
=
multimodal_embeddings
,
is_multimodal
=
is_multimodal
,
is_multimodal
=
is_multimodal
,
handle_oov_mm_token
=
handle_oov_mm_token
,
)
)
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
def
embed_multimodal
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
...
...
vllm/model_executor/models/eagle2_5_vl.py
View file @
bd2659a5
...
@@ -416,7 +416,6 @@ class Eagle2_5_VLForConditionalGeneration(
...
@@ -416,7 +416,6 @@ class Eagle2_5_VLForConditionalGeneration(
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
*
,
*
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
handle_oov_mm_token
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Embed input IDs with optional multimodal embeddings."""
"""Embed input IDs with optional multimodal embeddings."""
if
multimodal_embeddings
is
None
or
is_multimodal
is
None
:
if
multimodal_embeddings
is
None
or
is_multimodal
is
None
:
...
@@ -426,7 +425,6 @@ class Eagle2_5_VLForConditionalGeneration(
...
@@ -426,7 +425,6 @@ class Eagle2_5_VLForConditionalGeneration(
input_ids
,
input_ids
,
multimodal_embeddings
=
multimodal_embeddings
,
multimodal_embeddings
=
multimodal_embeddings
,
is_multimodal
=
is_multimodal
,
is_multimodal
=
is_multimodal
,
handle_oov_mm_token
=
handle_oov_mm_token
,
)
)
def
forward
(
def
forward
(
...
...
vllm/model_executor/models/ernie45_vl.py
View file @
bd2659a5
...
@@ -1664,7 +1664,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(
...
@@ -1664,7 +1664,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
*
,
*
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
handle_oov_mm_token
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
multimodal_embeddings
is
not
None
and
len
(
multimodal_embeddings
)
>
0
:
if
multimodal_embeddings
is
not
None
and
len
(
multimodal_embeddings
)
>
0
:
self
.
_set_visual_token_mask
(
input_ids
)
self
.
_set_visual_token_mask
(
input_ids
)
...
@@ -1677,7 +1676,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(
...
@@ -1677,7 +1676,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(
input_ids
,
input_ids
,
multimodal_embeddings
=
multimodal_embeddings
,
multimodal_embeddings
=
multimodal_embeddings
,
is_multimodal
=
is_multimodal
,
is_multimodal
=
is_multimodal
,
handle_oov_mm_token
=
handle_oov_mm_token
,
)
)
def
forward
(
def
forward
(
...
...
vllm/model_executor/models/funasr.py
View file @
bd2659a5
...
@@ -975,7 +975,6 @@ class FunASRForConditionalGeneration(
...
@@ -975,7 +975,6 @@ class FunASRForConditionalGeneration(
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
*
,
*
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
handle_oov_mm_token
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
model
.
decoder
.
embed_input_ids
(
input_ids
)
inputs_embeds
=
self
.
model
.
decoder
.
embed_input_ids
(
input_ids
)
...
...
vllm/model_executor/models/gemma3_mm.py
View file @
bd2659a5
...
@@ -507,6 +507,11 @@ class Gemma3ForConditionalGeneration(
...
@@ -507,6 +507,11 @@ class Gemma3ForConditionalGeneration(
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
multimodal_config
=
multimodal_config
self
.
multimodal_config
=
multimodal_config
self
.
configure_mm_token_handling
(
vocab_size
=
config
.
text_config
.
vocab_size
,
mm_token_ids
=
[
config
.
image_token_index
],
)
with
self
.
_mark_tower_model
(
vllm_config
,
"image"
):
with
self
.
_mark_tower_model
(
vllm_config
,
"image"
):
self
.
vision_tower
=
SiglipVisionModel
(
self
.
vision_tower
=
SiglipVisionModel
(
config
.
vision_config
,
config
.
vision_config
,
...
@@ -587,7 +592,6 @@ class Gemma3ForConditionalGeneration(
...
@@ -587,7 +592,6 @@ class Gemma3ForConditionalGeneration(
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
*
,
*
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
handle_oov_mm_token
:
bool
=
True
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# Early return for text-only inference (no multimodal data)
# Early return for text-only inference (no multimodal data)
if
multimodal_embeddings
is
None
or
is_multimodal
is
None
:
if
multimodal_embeddings
is
None
or
is_multimodal
is
None
:
...
@@ -598,7 +602,6 @@ class Gemma3ForConditionalGeneration(
...
@@ -598,7 +602,6 @@ class Gemma3ForConditionalGeneration(
input_ids
,
input_ids
,
multimodal_embeddings
=
multimodal_embeddings
,
multimodal_embeddings
=
multimodal_embeddings
,
is_multimodal
=
is_multimodal
,
is_multimodal
=
is_multimodal
,
handle_oov_mm_token
=
handle_oov_mm_token
,
)
)
def
forward
(
def
forward
(
...
...
vllm/model_executor/models/gemma3n_mm.py
View file @
bd2659a5
...
@@ -685,7 +685,6 @@ class Gemma3nForConditionalGeneration(
...
@@ -685,7 +685,6 @@ class Gemma3nForConditionalGeneration(
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
*
,
*
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
handle_oov_mm_token
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
# NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
# them here, as the model forward has only access to the input_embeds.
# them here, as the model forward has only access to the input_embeds.
...
@@ -710,7 +709,6 @@ class Gemma3nForConditionalGeneration(
...
@@ -710,7 +709,6 @@ class Gemma3nForConditionalGeneration(
input_ids
,
input_ids
,
multimodal_embeddings
=
multimodal_embeddings
,
multimodal_embeddings
=
multimodal_embeddings
,
is_multimodal
=
is_multimodal
,
is_multimodal
=
is_multimodal
,
handle_oov_mm_token
=
handle_oov_mm_token
,
)
)
def
forward
(
def
forward
(
...
...
vllm/model_executor/models/granite_speech.py
View file @
bd2659a5
...
@@ -600,6 +600,12 @@ class GraniteSpeechForConditionalGeneration(
...
@@ -600,6 +600,12 @@ class GraniteSpeechForConditionalGeneration(
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
cache_config
=
cache_config
self
.
cache_config
=
cache_config
# Check for OOV tokens to see if offsets need to be preserved
self
.
configure_mm_token_handling
(
vocab_size
=
config
.
text_config
.
vocab_size
,
mm_token_ids
=
[
config
.
audio_token_index
],
)
with
self
.
_mark_language_model
(
vllm_config
):
with
self
.
_mark_language_model
(
vllm_config
):
# The language model is typically a Granite LLM
# The language model is typically a Granite LLM
self
.
language_model
=
init_vllm_registered_model
(
self
.
language_model
=
init_vllm_registered_model
(
...
@@ -793,8 +799,6 @@ class GraniteSpeechForConditionalGeneration(
...
@@ -793,8 +799,6 @@ class GraniteSpeechForConditionalGeneration(
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
*
,
*
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token
:
bool
=
True
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# This is to satisfy the type checker for each overload
# This is to satisfy the type checker for each overload
if
multimodal_embeddings
is
None
or
is_multimodal
is
None
:
if
multimodal_embeddings
is
None
or
is_multimodal
is
None
:
...
@@ -804,7 +808,6 @@ class GraniteSpeechForConditionalGeneration(
...
@@ -804,7 +808,6 @@ class GraniteSpeechForConditionalGeneration(
input_ids
,
input_ids
,
multimodal_embeddings
=
multimodal_embeddings
,
multimodal_embeddings
=
multimodal_embeddings
,
is_multimodal
=
is_multimodal
,
is_multimodal
=
is_multimodal
,
handle_oov_mm_token
=
handle_oov_mm_token
,
)
)
def
forward
(
def
forward
(
...
...
vllm/model_executor/models/interfaces.py
View file @
bd2659a5
...
@@ -130,6 +130,13 @@ class SupportsMultiModal(Protocol):
...
@@ -130,6 +130,13 @@ class SupportsMultiModal(Protocol):
Set internally by `_mark_tower_model`.
Set internally by `_mark_tower_model`.
"""
"""
_has_oov_mm_tokens
:
bool
=
False
"""
In general, this should be set at init time by invoking
`configure_mm_token_handling` and passing all potentially
OOV multimodal tokens.
"""
@
classmethod
@
classmethod
def
get_placeholder_str
(
cls
,
modality
:
str
,
i
:
int
)
->
str
|
None
:
def
get_placeholder_str
(
cls
,
modality
:
str
,
i
:
int
)
->
str
|
None
:
"""
"""
...
@@ -149,6 +156,17 @@ class SupportsMultiModal(Protocol):
...
@@ -149,6 +156,17 @@ class SupportsMultiModal(Protocol):
"""
"""
...
...
def
configure_mm_token_handling
(
self
,
vocab_size
:
int
,
mm_token_ids
:
list
[
int
]):
"""Check if any multimodal tokens are out of vocabulary. If so, we will
explicitly mask all multimodal tokens out when computing text embeddings,
since the multimodal embeddings will be scattered over the results.
"""
self
.
_has_oov_mm_tokens
=
any
(
tok_id
>=
vocab_size
for
tok_id
in
mm_token_ids
)
logger
.
info
(
"Contains out of vocabulary multimodal tokens? %s"
,
self
.
_has_oov_mm_tokens
,
)
def
get_language_model
(
self
)
->
VllmModel
:
def
get_language_model
(
self
)
->
VllmModel
:
"""
"""
Returns the underlying language model used for text generation.
Returns the underlying language model used for text generation.
...
@@ -324,7 +342,6 @@ class SupportsMultiModal(Protocol):
...
@@ -324,7 +342,6 @@ class SupportsMultiModal(Protocol):
multimodal_embeddings
:
MultiModalEmbeddings
,
multimodal_embeddings
:
MultiModalEmbeddings
,
*
,
*
,
is_multimodal
:
torch
.
Tensor
,
is_multimodal
:
torch
.
Tensor
,
handle_oov_mm_token
:
bool
=
False
,
)
->
Tensor
:
...
)
->
Tensor
:
...
def
_embed_text_input_ids
(
def
_embed_text_input_ids
(
...
@@ -333,17 +350,14 @@ class SupportsMultiModal(Protocol):
...
@@ -333,17 +350,14 @@ class SupportsMultiModal(Protocol):
embed_input_ids
:
Callable
[[
Tensor
],
Tensor
],
embed_input_ids
:
Callable
[[
Tensor
],
Tensor
],
*
,
*
,
is_multimodal
:
Tensor
|
None
,
is_multimodal
:
Tensor
|
None
,
handle_oov_mm_token
:
bool
,
)
->
Tensor
:
)
->
Tensor
:
if
handle_oov_mm_token
and
is_multimodal
is
not
None
:
if
is_multimodal
is
not
None
and
self
.
_has_oov_mm_tokens
:
is_text
=
~
is_multimodal
# Force all input IDs to be in vocab; we do this instead of squeezing
text_embeds
=
embed_input_ids
(
input_ids
[
is_text
])
# to ensure that any external configuration requiring offset tracking,
# e.g., LoRA, are applied correctly regardless of whether or not
return
torch
.
empty
(
# we have multimodal tokens.
(
input_ids
.
shape
[
0
],
text_embeds
.
shape
[
1
]),
in_vocab_ids
=
input_ids
.
masked_fill
(
is_multimodal
,
0
)
dtype
=
text_embeds
.
dtype
,
return
embed_input_ids
(
in_vocab_ids
)
device
=
text_embeds
.
device
,
).
masked_scatter_
(
is_text
.
unsqueeze_
(
-
1
),
text_embeds
)
return
embed_input_ids
(
input_ids
)
return
embed_input_ids
(
input_ids
)
...
@@ -353,7 +367,6 @@ class SupportsMultiModal(Protocol):
...
@@ -353,7 +367,6 @@ class SupportsMultiModal(Protocol):
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
*
,
*
,
is_multimodal
:
Tensor
|
None
=
None
,
is_multimodal
:
Tensor
|
None
=
None
,
handle_oov_mm_token
:
bool
=
False
,
)
->
Tensor
:
)
->
Tensor
:
"""
"""
Apply token embeddings to `input_ids`.
Apply token embeddings to `input_ids`.
...
@@ -361,19 +374,19 @@ class SupportsMultiModal(Protocol):
...
@@ -361,19 +374,19 @@ class SupportsMultiModal(Protocol):
If `multimodal_embeddings` is passed, scatter them into
If `multimodal_embeddings` is passed, scatter them into
`input_ids` according to the mask `is_multimodal`.
`input_ids` according to the mask `is_multimodal`.
In case the multi-modal token IDs exceed the vocabulary size of
NOTE: If this model has multimodal tokens that are out of vocabulary
the language model, you can set `handle_oov_mm_token=False`
(i.e., self._has_oov_mm_tokens=True), the input_ids will be copied
to avoid calling the language model's `embed_input_ids` method
and masked to 0 during the forward pass for the text embeddings.
on those tokens. Note however that doing so increases memory usage
as an additional buffer is needed to hold the input embeddings.
"""
"""
from
.utils
import
_merge_multimodal_embeddings
from
.utils
import
_merge_multimodal_embeddings
# Get text embeddings first; multimodal embeddings will clobber
# any invalid contents in the indices of multimodal embeddings
# for the in vocabulary and out of vocabulary case.
inputs_embeds
=
self
.
_embed_text_input_ids
(
inputs_embeds
=
self
.
_embed_text_input_ids
(
input_ids
,
input_ids
,
self
.
get_language_model
().
embed_input_ids
,
self
.
get_language_model
().
embed_input_ids
,
is_multimodal
=
is_multimodal
,
is_multimodal
=
is_multimodal
,
handle_oov_mm_token
=
handle_oov_mm_token
,
)
)
if
multimodal_embeddings
is
None
or
len
(
multimodal_embeddings
)
==
0
:
if
multimodal_embeddings
is
None
or
len
(
multimodal_embeddings
)
==
0
:
...
...
vllm/model_executor/models/interns1.py
View file @
bd2659a5
...
@@ -764,7 +764,6 @@ class InternS1ForConditionalGeneration(
...
@@ -764,7 +764,6 @@ class InternS1ForConditionalGeneration(
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
*
,
*
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
handle_oov_mm_token
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
multimodal_embeddings
is
not
None
and
len
(
multimodal_embeddings
)
>
0
:
if
multimodal_embeddings
is
not
None
and
len
(
multimodal_embeddings
)
>
0
:
self
.
_set_visual_token_mask
(
input_ids
)
self
.
_set_visual_token_mask
(
input_ids
)
...
@@ -777,7 +776,6 @@ class InternS1ForConditionalGeneration(
...
@@ -777,7 +776,6 @@ class InternS1ForConditionalGeneration(
input_ids
,
input_ids
,
multimodal_embeddings
=
multimodal_embeddings
,
multimodal_embeddings
=
multimodal_embeddings
,
is_multimodal
=
is_multimodal
,
is_multimodal
=
is_multimodal
,
handle_oov_mm_token
=
handle_oov_mm_token
,
)
)
def
forward
(
def
forward
(
...
...
vllm/model_executor/models/internvl.py
View file @
bd2659a5
...
@@ -1347,7 +1347,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
...
@@ -1347,7 +1347,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
*
,
*
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
handle_oov_mm_token
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
multimodal_embeddings
is
not
None
and
len
(
multimodal_embeddings
)
>
0
:
if
multimodal_embeddings
is
not
None
and
len
(
multimodal_embeddings
)
>
0
:
self
.
_set_visual_token_mask
(
input_ids
)
self
.
_set_visual_token_mask
(
input_ids
)
...
@@ -1360,7 +1359,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
...
@@ -1360,7 +1359,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
input_ids
,
input_ids
,
multimodal_embeddings
=
multimodal_embeddings
,
multimodal_embeddings
=
multimodal_embeddings
,
is_multimodal
=
is_multimodal
,
is_multimodal
=
is_multimodal
,
handle_oov_mm_token
=
handle_oov_mm_token
,
)
)
def
forward
(
def
forward
(
...
...
vllm/model_executor/models/llava.py
View file @
bd2659a5
...
@@ -544,6 +544,11 @@ class LlavaForConditionalGeneration(
...
@@ -544,6 +544,11 @@ class LlavaForConditionalGeneration(
self
.
config
=
config
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
self
.
multimodal_config
=
multimodal_config
self
.
configure_mm_token_handling
(
vocab_size
=
config
.
text_config
.
vocab_size
,
mm_token_ids
=
[
config
.
image_token_index
],
)
# NOTE: These are special cases for Pixtral-12B in the HF-format
# NOTE: These are special cases for Pixtral-12B in the HF-format
# https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json # noqa
# https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json # noqa
if
(
if
(
...
...
vllm/model_executor/models/llava_next.py
View file @
bd2659a5
...
@@ -270,6 +270,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
...
@@ -270,6 +270,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
self
.
config
=
config
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
self
.
multimodal_config
=
multimodal_config
self
.
configure_mm_token_handling
(
vocab_size
=
config
.
text_config
.
vocab_size
,
mm_token_ids
=
[
config
.
image_token_index
],
)
with
self
.
_mark_tower_model
(
vllm_config
,
"image"
):
with
self
.
_mark_tower_model
(
vllm_config
,
"image"
):
self
.
vision_tower
=
init_vision_tower_for_llava
(
self
.
vision_tower
=
init_vision_tower_for_llava
(
config
,
config
,
...
@@ -497,8 +502,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
...
@@ -497,8 +502,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
*
,
*
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
# Multi-modal token ID may exceed vocab size
handle_oov_mm_token
:
bool
=
True
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# This is to satisfy the type checker for each overload
# This is to satisfy the type checker for each overload
if
multimodal_embeddings
is
None
or
is_multimodal
is
None
:
if
multimodal_embeddings
is
None
or
is_multimodal
is
None
:
...
@@ -508,7 +511,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
...
@@ -508,7 +511,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
input_ids
,
input_ids
,
multimodal_embeddings
=
multimodal_embeddings
,
multimodal_embeddings
=
multimodal_embeddings
,
is_multimodal
=
is_multimodal
,
is_multimodal
=
is_multimodal
,
handle_oov_mm_token
=
handle_oov_mm_token
,
)
)
def
forward
(
def
forward
(
...
...
vllm/model_executor/models/molmo2.py
View file @
bd2659a5
...
@@ -2711,13 +2711,11 @@ class Molmo2ForConditionalGeneration(
...
@@ -2711,13 +2711,11 @@ class Molmo2ForConditionalGeneration(
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
*
,
*
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
handle_oov_mm_token
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
_embed_text_input_ids
(
inputs_embeds
=
self
.
_embed_text_input_ids
(
input_ids
,
input_ids
,
self
.
get_language_model
().
embed_input_ids
,
self
.
get_language_model
().
embed_input_ids
,
is_multimodal
=
is_multimodal
,
is_multimodal
=
is_multimodal
,
handle_oov_mm_token
=
handle_oov_mm_token
,
)
)
if
multimodal_embeddings
is
None
or
len
(
multimodal_embeddings
)
==
0
:
if
multimodal_embeddings
is
None
or
len
(
multimodal_embeddings
)
==
0
:
...
...
vllm/model_executor/models/nemotron_vl.py
View file @
bd2659a5
...
@@ -628,7 +628,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
...
@@ -628,7 +628,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
*
,
*
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
handle_oov_mm_token
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
multimodal_embeddings
is
not
None
and
len
(
multimodal_embeddings
)
>
0
:
if
multimodal_embeddings
is
not
None
and
len
(
multimodal_embeddings
)
>
0
:
self
.
_set_visual_token_mask
(
input_ids
)
self
.
_set_visual_token_mask
(
input_ids
)
...
@@ -641,7 +640,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
...
@@ -641,7 +640,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
input_ids
,
input_ids
,
multimodal_embeddings
=
multimodal_embeddings
,
multimodal_embeddings
=
multimodal_embeddings
,
is_multimodal
=
is_multimodal
,
is_multimodal
=
is_multimodal
,
handle_oov_mm_token
=
handle_oov_mm_token
,
)
)
def
forward
(
def
forward
(
...
...
vllm/model_executor/models/phi3v.py
View file @
bd2659a5
...
@@ -663,13 +663,11 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
...
@@ -663,13 +663,11 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
*
,
*
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
handle_oov_mm_token
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
_embed_text_input_ids
(
inputs_embeds
=
self
.
_embed_text_input_ids
(
input_ids
,
input_ids
,
self
.
embed_tokens
,
self
.
embed_tokens
,
is_multimodal
=
is_multimodal
,
is_multimodal
=
is_multimodal
,
handle_oov_mm_token
=
handle_oov_mm_token
,
)
)
if
multimodal_embeddings
is
None
or
len
(
multimodal_embeddings
)
==
0
:
if
multimodal_embeddings
is
None
or
len
(
multimodal_embeddings
)
==
0
:
...
...
vllm/model_executor/models/qwen2_5_omni_thinker.py
View file @
bd2659a5
...
@@ -1428,11 +1428,19 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
...
@@ -1428,11 +1428,19 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
*
,
*
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
handle_oov_mm_token
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
multimodal_embeddings
is
None
or
is_multimodal
is
None
:
if
multimodal_embeddings
is
None
or
is_multimodal
is
None
:
return
super
().
embed_input_ids
(
input_ids
)
return
super
().
embed_input_ids
(
input_ids
)
inputs_embeds
=
self
.
_embed_text_input_ids
(
input_ids
,
self
.
get_language_model
().
embed_input_ids
,
is_multimodal
=
is_multimodal
,
)
if
len
(
multimodal_embeddings
)
==
0
:
return
inputs_embeds
# Check for audio-in-video: interleaved video and audio tokens
# Check for audio-in-video: interleaved video and audio tokens
# in the multimodal region. Only use the interleaved path when
# in the multimodal region. Only use the interleaved path when
# needed; otherwise fall back to the default parent implementation.
# needed; otherwise fall back to the default parent implementation.
...
@@ -1450,7 +1458,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
...
@@ -1450,7 +1458,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
input_ids
,
input_ids
,
self
.
get_language_model
().
embed_input_ids
,
self
.
get_language_model
().
embed_input_ids
,
is_multimodal
=
is_multimodal
,
is_multimodal
=
is_multimodal
,
handle_oov_mm_token
=
handle_oov_mm_token
,
)
)
return
merge_interleaved_embeddings
(
return
merge_interleaved_embeddings
(
inputs_embeds
,
inputs_embeds
,
...
@@ -1467,7 +1474,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
...
@@ -1467,7 +1474,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
input_ids
,
input_ids
,
multimodal_embeddings
=
multimodal_embeddings
,
multimodal_embeddings
=
multimodal_embeddings
,
is_multimodal
=
is_multimodal
,
is_multimodal
=
is_multimodal
,
handle_oov_mm_token
=
handle_oov_mm_token
,
)
)
def
forward
(
def
forward
(
...
...
vllm/model_executor/models/qwen3_5.py
View file @
bd2659a5
...
@@ -672,13 +672,11 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
...
@@ -672,13 +672,11 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid)
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
*
,
*
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
handle_oov_mm_token
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
_embed_text_input_ids
(
inputs_embeds
=
self
.
_embed_text_input_ids
(
input_ids
,
input_ids
,
self
.
language_model
.
embed_input_ids
,
self
.
language_model
.
embed_input_ids
,
is_multimodal
=
is_multimodal
,
is_multimodal
=
is_multimodal
,
handle_oov_mm_token
=
handle_oov_mm_token
,
)
)
if
multimodal_embeddings
is
None
or
len
(
multimodal_embeddings
)
==
0
:
if
multimodal_embeddings
is
None
or
len
(
multimodal_embeddings
)
==
0
:
...
...
vllm/model_executor/models/qwen3_5_mtp.py
View file @
bd2659a5
...
@@ -380,13 +380,11 @@ class Qwen3_5MTP(nn.Module, SupportsMultiModal):
...
@@ -380,13 +380,11 @@ class Qwen3_5MTP(nn.Module, SupportsMultiModal):
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
*
,
*
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
handle_oov_mm_token
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
_embed_text_input_ids
(
inputs_embeds
=
self
.
_embed_text_input_ids
(
input_ids
,
input_ids
,
self
.
model
.
embed_input_ids
,
self
.
model
.
embed_input_ids
,
is_multimodal
=
is_multimodal
,
is_multimodal
=
is_multimodal
,
handle_oov_mm_token
=
handle_oov_mm_token
,
)
)
if
multimodal_embeddings
is
None
or
len
(
multimodal_embeddings
)
==
0
:
if
multimodal_embeddings
is
None
or
len
(
multimodal_embeddings
)
==
0
:
...
...
vllm/model_executor/models/qwen3_asr.py
View file @
bd2659a5
...
@@ -389,13 +389,11 @@ class Qwen3ASRForConditionalGeneration(
...
@@ -389,13 +389,11 @@ class Qwen3ASRForConditionalGeneration(
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
*
,
*
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
handle_oov_mm_token
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
_embed_text_input_ids
(
inputs_embeds
=
self
.
_embed_text_input_ids
(
input_ids
,
input_ids
,
self
.
language_model
.
embed_input_ids
,
self
.
language_model
.
embed_input_ids
,
is_multimodal
=
is_multimodal
,
is_multimodal
=
is_multimodal
,
handle_oov_mm_token
=
handle_oov_mm_token
,
)
)
if
multimodal_embeddings
is
None
or
len
(
multimodal_embeddings
)
==
0
:
if
multimodal_embeddings
is
None
or
len
(
multimodal_embeddings
)
==
0
:
...
...
vllm/model_executor/models/qwen3_omni_moe_thinker.py
View file @
bd2659a5
...
@@ -1851,13 +1851,11 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
...
@@ -1851,13 +1851,11 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
*
,
*
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
handle_oov_mm_token
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
_embed_text_input_ids
(
inputs_embeds
=
self
.
_embed_text_input_ids
(
input_ids
,
input_ids
,
self
.
language_model
.
embed_input_ids
,
self
.
language_model
.
embed_input_ids
,
is_multimodal
=
is_multimodal
,
is_multimodal
=
is_multimodal
,
handle_oov_mm_token
=
handle_oov_mm_token
,
)
)
if
multimodal_embeddings
is
None
or
len
(
multimodal_embeddings
)
==
0
:
if
multimodal_embeddings
is
None
or
len
(
multimodal_embeddings
)
==
0
:
...
@@ -1962,7 +1960,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
...
@@ -1962,7 +1960,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
input_ids
,
input_ids
,
multimodal_embeddings
=
multimodal_embeddings
,
multimodal_embeddings
=
multimodal_embeddings
,
is_multimodal
=
is_multimodal
,
is_multimodal
=
is_multimodal
,
handle_oov_mm_token
=
handle_oov_mm_token
,
)
)
def
forward
(
def
forward
(
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment