Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cb15ee28
Unverified
Commit
cb15ee28
authored
Nov 15, 2025
by
tingtinggithub
Committed by
GitHub
Nov 15, 2025
Browse files
Allow Gemma3 to take image embeddings (#28483)
Signed-off-by:
tingtinggithub
<
streamttt@gmail.com
>
parent
f36292db
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
69 additions
and
29 deletions
+69
-29
docs/models/supported_models.md
docs/models/supported_models.md
+1
-1
vllm/model_executor/models/gemma3_mm.py
vllm/model_executor/models/gemma3_mm.py
+55
-22
vllm/multimodal/parse.py
vllm/multimodal/parse.py
+6
-5
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+7
-1
No files found.
docs/models/supported_models.md
View file @
cb15ee28
...
@@ -669,7 +669,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
...
@@ -669,7 +669,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
`DeepseekOCRForCausalLM`
| DeepSeek-OCR | T + I
<sup>
+
</sup>
|
`deepseek-ai/DeepSeek-OCR`
, etc. | | ✅︎ |
|
`DeepseekOCRForCausalLM`
| DeepSeek-OCR | T + I
<sup>
+
</sup>
|
`deepseek-ai/DeepSeek-OCR`
, etc. | | ✅︎ |
|
`Ernie4_5_VLMoeForConditionalGeneration`
| Ernie4.5-VL | T + I
<sup>
+
</sup>
/ V
<sup>
+
</sup>
|
`baidu/ERNIE-4.5-VL-28B-A3B-PT`
,
`baidu/ERNIE-4.5-VL-424B-A47B-PT`
| | ✅︎ |
|
`Ernie4_5_VLMoeForConditionalGeneration`
| Ernie4.5-VL | T + I
<sup>
+
</sup>
/ V
<sup>
+
</sup>
|
`baidu/ERNIE-4.5-VL-28B-A3B-PT`
,
`baidu/ERNIE-4.5-VL-424B-A47B-PT`
| | ✅︎ |
|
`FuyuForCausalLM`
| Fuyu | T + I |
`adept/fuyu-8b`
, etc. | | ✅︎ |
|
`FuyuForCausalLM`
| Fuyu | T + I |
`adept/fuyu-8b`
, etc. | | ✅︎ |
|
`Gemma3ForConditionalGeneration`
| Gemma 3 | T + I
<sup>
+
</sup>
|
`google/gemma-3-4b-it`
,
`google/gemma-3-27b-it`
, etc. | ✅︎ | ✅︎ |
|
`Gemma3ForConditionalGeneration`
| Gemma 3 | T + I
<sup>
E
+
</sup>
|
`google/gemma-3-4b-it`
,
`google/gemma-3-27b-it`
, etc. | ✅︎ | ✅︎ |
|
`Gemma3nForConditionalGeneration`
| Gemma 3n | T + I + A |
`google/gemma-3n-E2B-it`
,
`google/gemma-3n-E4B-it`
, etc. | | |
|
`Gemma3nForConditionalGeneration`
| Gemma 3n | T + I + A |
`google/gemma-3n-E2B-it`
,
`google/gemma-3n-E4B-it`
, etc. | | |
|
`GLM4VForCausalLM`
<sup>
^
</sup>
| GLM-4V | T + I |
`zai-org/glm-4v-9b`
,
`zai-org/cogagent-9b-20241220`
, etc. | ✅︎ | ✅︎ |
|
`GLM4VForCausalLM`
<sup>
^
</sup>
| GLM-4V | T + I |
`zai-org/glm-4v-9b`
,
`zai-org/cogagent-9b-20241220`
, etc. | ✅︎ | ✅︎ |
|
`Glm4vForConditionalGeneration`
| GLM-4.1V-Thinking | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`zai-org/GLM-4.1V-9B-Thinking`
, etc. | ✅︎ | ✅︎ |
|
`Glm4vForConditionalGeneration`
| GLM-4.1V-Thinking | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`zai-org/GLM-4.1V-9B-Thinking`
, etc. | ✅︎ | ✅︎ |
...
...
vllm/model_executor/models/gemma3_mm.py
View file @
cb15ee28
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Annotated
,
Any
,
Literal
from
typing
import
Annotated
,
Any
,
Literal
,
TypeAlias
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -20,7 +20,12 @@ from vllm.multimodal.inputs import (
...
@@ -20,7 +20,12 @@ from vllm.multimodal.inputs import (
MultiModalFieldConfig
,
MultiModalFieldConfig
,
MultiModalKwargsItems
,
MultiModalKwargsItems
,
)
)
from
vllm.multimodal.parse
import
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
from
vllm.multimodal.parse
import
(
ImageEmbeddingItems
,
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
,
)
from
vllm.multimodal.processing
import
(
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseMultiModalProcessor
,
BaseProcessingInfo
,
BaseProcessingInfo
,
...
@@ -71,7 +76,15 @@ class Gemma3ImagePixelInputs(TensorSchema):
...
@@ -71,7 +76,15 @@ class Gemma3ImagePixelInputs(TensorSchema):
num_patches
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bn"
)]
num_patches
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bn"
)]
Gemma3ImageInputs
=
Gemma3ImagePixelInputs
class
Gemma3ImageEmbeddingInputs
(
TensorSchema
):
type
:
Literal
[
"image_embeds"
]
=
"image_embeds"
image_embeds
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"ni"
,
"nf"
,
"hs"
),
]
Gemma3ImageInputs
:
TypeAlias
=
Gemma3ImagePixelInputs
|
Gemma3ImageEmbeddingInputs
class
Gemma3ProcessingInfo
(
BaseProcessingInfo
):
class
Gemma3ProcessingInfo
(
BaseProcessingInfo
):
...
@@ -178,8 +191,9 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
...
@@ -178,8 +191,9 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
def
get_image_repl
(
def
get_image_repl
(
self
,
self
,
*
,
*
,
image_width
:
int
,
image_width
:
int
|
None
,
image_height
:
int
,
image_height
:
int
|
None
,
num_crops
:
int
|
None
=
None
,
processor
:
Gemma3Processor
|
None
,
processor
:
Gemma3Processor
|
None
,
)
->
PromptUpdateDetails
[
str
]:
)
->
PromptUpdateDetails
[
str
]:
if
processor
is
None
:
if
processor
is
None
:
...
@@ -187,11 +201,13 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
...
@@ -187,11 +201,13 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
boi_token
=
processor
.
boi_token
boi_token
=
processor
.
boi_token
num_crops
=
self
.
get_num_crops
(
if
num_crops
is
None
:
image_width
=
image_width
,
assert
image_width
is
not
None
and
image_height
is
not
None
image_height
=
image_height
,
num_crops
=
self
.
get_num_crops
(
processor
=
processor
,
image_width
=
image_width
,
)
image_height
=
image_height
,
processor
=
processor
,
)
if
num_crops
==
0
:
if
num_crops
==
0
:
image_text
=
boi_token
image_text
=
boi_token
...
@@ -321,6 +337,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
...
@@ -321,6 +337,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
return
dict
(
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
num_patches
),
pixel_values
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
num_patches
),
num_patches
=
MultiModalFieldConfig
.
batched
(
"image"
),
num_patches
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
)
def
_get_prompt_updates
(
def
_get_prompt_updates
(
...
@@ -333,7 +350,19 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
...
@@ -333,7 +350,19 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
image_token
=
hf_processor
.
boi_token
image_token
=
hf_processor
.
boi_token
def
get_replacement_gemma3
(
item_idx
:
int
):
def
get_replacement_gemma3
(
item_idx
:
int
):
images
=
mm_items
.
get_items
(
"image"
,
ImageProcessorItems
)
images
=
mm_items
.
get_items
(
"image"
,
(
ImageEmbeddingItems
,
ImageProcessorItems
)
)
if
isinstance
(
images
,
ImageEmbeddingItems
):
# For image embedding inputs, only support no crops cases
# since it's not supported in hf processor anyway
return
self
.
info
.
get_image_repl
(
image_width
=
None
,
image_height
=
None
,
num_crops
=
0
,
processor
=
hf_processor
,
)
image_size
=
images
.
get_image_size
(
item_idx
)
image_size
=
images
.
get_image_size
(
item_idx
)
return
self
.
info
.
get_image_repl
(
return
self
.
info
.
get_image_repl
(
...
@@ -557,17 +586,19 @@ class Gemma3ForConditionalGeneration(
...
@@ -557,17 +586,19 @@ class Gemma3ForConditionalGeneration(
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
num_patches
=
kwargs
.
pop
(
"num_patches"
,
None
)
num_patches
=
kwargs
.
pop
(
"num_patches"
,
None
)
image_embeds
=
kwargs
.
pop
(
"image_embeds"
,
None
)
image_embeds
=
kwargs
.
pop
(
"image_embeds"
,
None
)
assert
image_embeds
is
None
,
"Gemma3 does not support image_embeds."
if
pixel_values
is
None
:
return
None
image_size
=
self
.
config
.
vision_config
.
image_size
if
pixel_values
is
not
None
:
image_size
=
self
.
config
.
vision_config
.
image_size
return
Gemma3ImagePixelInputs
(
return
Gemma3ImagePixelInputs
(
pixel_values
=
pixel_values
,
pixel_values
=
pixel_values
,
num_patches
=
num_patches
,
num_patches
=
num_patches
,
resolve_bindings
=
{
"h"
:
image_size
,
"w"
:
image_size
},
resolve_bindings
=
{
"h"
:
image_size
,
"w"
:
image_size
},
)
)
elif
image_embeds
is
not
None
:
return
Gemma3ImageEmbeddingInputs
(
image_embeds
=
image_embeds
,
type
=
"image_embeds"
,
)
def
_image_pixels_to_features
(
def
_image_pixels_to_features
(
self
,
self
,
...
@@ -579,7 +610,9 @@ class Gemma3ForConditionalGeneration(
...
@@ -579,7 +610,9 @@ class Gemma3ForConditionalGeneration(
def
_process_image_input
(
def
_process_image_input
(
self
,
self
,
image_input
:
Gemma3ImageInputs
,
image_input
:
Gemma3ImageInputs
,
)
->
list
[
torch
.
Tensor
]:
)
->
torch
.
Tensor
|
list
[
torch
.
Tensor
]:
if
image_input
[
"type"
]
==
"image_embeds"
:
return
image_input
[
"image_embeds"
]
assert
self
.
vision_tower
is
not
None
assert
self
.
vision_tower
is
not
None
pixel_values
=
image_input
[
"pixel_values"
]
pixel_values
=
image_input
[
"pixel_values"
]
...
...
vllm/multimodal/parse.py
View file @
cb15ee28
...
@@ -359,8 +359,9 @@ class MultiModalDataParser:
...
@@ -359,8 +359,9 @@ class MultiModalDataParser:
)
)
self
.
video_needs_metadata
=
video_needs_metadata
self
.
video_needs_metadata
=
video_needs_metadata
def
_is_embeddings
(
@
classmethod
self
,
data
:
object
def
is_embeddings
(
cls
,
data
:
object
)
->
TypeGuard
[
torch
.
Tensor
|
list
[
torch
.
Tensor
]]:
)
->
TypeGuard
[
torch
.
Tensor
|
list
[
torch
.
Tensor
]]:
if
isinstance
(
data
,
torch
.
Tensor
):
if
isinstance
(
data
,
torch
.
Tensor
):
return
data
.
ndim
==
3
return
data
.
ndim
==
3
...
@@ -420,7 +421,7 @@ class MultiModalDataParser:
...
@@ -420,7 +421,7 @@ class MultiModalDataParser:
):
):
return
None
return
None
if
self
.
_
is_embeddings
(
data
):
if
self
.
is_embeddings
(
data
):
return
AudioEmbeddingItems
(
data
)
return
AudioEmbeddingItems
(
data
)
data_items
:
list
[
AudioItem
]
data_items
:
list
[
AudioItem
]
...
@@ -458,7 +459,7 @@ class MultiModalDataParser:
...
@@ -458,7 +459,7 @@ class MultiModalDataParser:
if
self
.
_is_empty
(
data
):
if
self
.
_is_empty
(
data
):
return
None
return
None
if
self
.
_
is_embeddings
(
data
):
if
self
.
is_embeddings
(
data
):
return
ImageEmbeddingItems
(
data
)
return
ImageEmbeddingItems
(
data
)
if
(
if
(
...
@@ -484,7 +485,7 @@ class MultiModalDataParser:
...
@@ -484,7 +485,7 @@ class MultiModalDataParser:
if
self
.
_is_empty
(
data
):
if
self
.
_is_empty
(
data
):
return
None
return
None
if
self
.
_
is_embeddings
(
data
):
if
self
.
is_embeddings
(
data
):
return
VideoEmbeddingItems
(
data
)
return
VideoEmbeddingItems
(
data
)
data_items
:
list
[
VideoItem
]
data_items
:
list
[
VideoItem
]
...
...
vllm/v1/engine/processor.py
View file @
cb15ee28
...
@@ -14,6 +14,7 @@ from vllm.lora.request import LoRARequest
...
@@ -14,6 +14,7 @@ from vllm.lora.request import LoRARequest
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.multimodal.cache
import
processor_cache_from_config
from
vllm.multimodal.cache
import
processor_cache_from_config
from
vllm.multimodal.inputs
import
MultiModalFeatureSpec
,
MultiModalUUIDDict
from
vllm.multimodal.inputs
import
MultiModalFeatureSpec
,
MultiModalUUIDDict
from
vllm.multimodal.parse
import
MultiModalDataParser
from
vllm.multimodal.processing
import
EncDecMultiModalProcessor
from
vllm.multimodal.processing
import
EncDecMultiModalProcessor
from
vllm.multimodal.utils
import
argsort_mm_positions
from
vllm.multimodal.utils
import
argsort_mm_positions
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
...
@@ -340,7 +341,12 @@ class Processor:
...
@@ -340,7 +341,12 @@ class Processor:
mm_uuids
:
dict
[
str
,
list
[
str
|
None
]
|
str
]
=
{}
mm_uuids
:
dict
[
str
,
list
[
str
|
None
]
|
str
]
=
{}
for
modality
,
data
in
mm_data
.
items
():
for
modality
,
data
in
mm_data
.
items
():
n
=
len
(
data
)
if
isinstance
(
data
,
list
)
else
1
# Hash each item for embedding inputs.
n
=
(
len
(
data
)
if
isinstance
(
data
,
list
)
or
MultiModalDataParser
.
is_embeddings
(
data
)
else
1
)
mm_uuids
[
modality
]
=
[
f
"
{
request_id
}
-
{
modality
}
-
{
i
}
"
for
i
in
range
(
n
)]
mm_uuids
[
modality
]
=
[
f
"
{
request_id
}
-
{
modality
}
-
{
i
}
"
for
i
in
range
(
n
)]
return
mm_uuids
return
mm_uuids
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment