Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
39b643dc
Unverified
Commit
39b643dc
authored
Oct 03, 2025
by
Cyrus Leung
Committed by
GitHub
Oct 02, 2025
Browse files
[Model] Use `merge_by_field_config` for MM models (G) (#26117)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
711f4856
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
56 additions
and
108 deletions
+56
-108
vllm/model_executor/models/gemma3_mm.py
vllm/model_executor/models/gemma3_mm.py
+14
-21
vllm/model_executor/models/gemma3n_mm.py
vllm/model_executor/models/gemma3n_mm.py
+31
-38
vllm/model_executor/models/glm4_1v.py
vllm/model_executor/models/glm4_1v.py
+2
-36
vllm/model_executor/models/glm4v.py
vllm/model_executor/models/glm4v.py
+6
-9
vllm/model_executor/models/granite_speech.py
vllm/model_executor/models/granite_speech.py
+3
-4
No files found.
vllm/model_executor/models/gemma3_mm.py
View file @
39b643dc
...
...
@@ -36,7 +36,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
)
from
.siglip
import
SiglipVisionModel
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
flatten_bn
,
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
init_vllm_registered_model
,
maybe_prefix
)
logger
=
init_logger
(
__name__
)
...
...
@@ -289,7 +289,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
processor
=
hf_processor
)
for
size
in
image_sizes
]
processed_outputs
[
"num_
crop
s"
]
=
torch
.
tensor
(
num_crops
)
processed_outputs
[
"num_
patche
s"
]
=
torch
.
tensor
(
num_crops
)
+
1
return
processed_outputs
...
...
@@ -298,12 +298,12 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
num_
crop
s
=
hf_inputs
.
get
(
"num_
crop
s"
,
torch
.
empty
(
0
))
num_
patche
s
=
hf_inputs
.
get
(
"num_
patche
s"
,
torch
.
empty
(
0
))
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
num_
crops
+
1
),
num_
crop
s
=
MultiModalFieldConfig
.
batched
(
"image"
),
"image"
,
num_
patches
),
num_
patche
s
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
def
_get_prompt_updates
(
...
...
@@ -460,6 +460,8 @@ class Gemma3MultiModalProjector(nn.Module):
dummy_inputs
=
Gemma3DummyInputsBuilder
)
class
Gemma3ForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
,
SupportsLoRA
):
merge_by_field_config
=
True
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
@@ -526,29 +528,20 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
Gemma3ImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
num_
crop
s
=
kwargs
.
pop
(
"num_
crop
s"
,
None
)
num_
patche
s
=
kwargs
.
pop
(
"num_
patche
s"
,
None
)
image_embeds
=
kwargs
.
pop
(
"image_embeds"
,
None
)
assert
image_embeds
is
None
,
"Gemma3 does not support image_embeds."
if
pixel_values
is
None
:
return
None
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
if
not
isinstance
(
num_crops
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of num_crops. "
f
"Got type:
{
type
(
num_crops
)
}
"
)
image_size
=
self
.
config
.
vision_config
.
image_size
return
Gemma3ImagePixelInputs
(
pixel_values
=
flatten_bn
(
pixel_values
,
concat
=
True
),
num_patches
=
flatten_bn
(
num_crops
,
concat
=
True
)
+
1
,
resolve_bindings
=
{
"h"
:
image_size
,
"w"
:
image_size
})
return
Gemma3ImagePixelInputs
(
pixel_values
=
pixel_values
,
num_patches
=
num_patches
,
resolve_bindings
=
{
"h"
:
image_size
,
"w"
:
image_size
})
def
_image_pixels_to_features
(
self
,
...
...
vllm/model_executor/models/gemma3n_mm.py
View file @
39b643dc
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Any
,
Literal
,
Optional
,
TypedDict
,
Union
,
cast
from
typing
import
Annotated
,
Any
,
Literal
,
Optional
,
Union
,
cast
import
numpy
as
np
import
torch
...
...
@@ -41,6 +41,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
# yapf: enable
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsTranscription
)
...
...
@@ -54,17 +55,28 @@ TOKENS_PER_IMAGE = 256
TOKENS_PER_AUDIO
=
188
class
Gemma3nImagePixelInputs
(
TypedDict
):
pixel_values
:
torch
.
Tensor
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
class
Gemma3nImagePixelInputs
(
TensorSchema
):
"""
Dimensions:
- bn: Batch size * number of images
- c: Number of channels (3)
- h: Height of each patch
- w: Width of each patch
"""
type
:
Literal
[
"pixel_values"
]
=
"pixel_values"
pixel_values
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bn"
,
3
,
"h"
,
"w"
)]
class
Gemma3nAudioInputs
(
TypedDict
):
input_features
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
input_features_padded
:
torch
.
Tensor
"""Shape: `(batch_size * num_audio, seq_length, num_features)`"""
input_features_mask
:
torch
.
Tensor
"""Shape: `(batch_size * num_audio, seq_length)`"""
class
Gemma3nAudioInputs
(
TensorSchema
):
"""
Dimensions:
- bn: Batch size * number of audios
- s: seq_length
- f: num_features
"""
type
:
Literal
[
"audio"
]
=
"audio"
input_features_padded
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bn"
,
"s"
,
"f"
)]
input_features_mask
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bn"
,
"s"
)]
Gemma3nImageInputs
=
Gemma3nImagePixelInputs
...
...
@@ -212,9 +224,9 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
batched
(
"image"
),
input_features
=
MultiModalFieldConfig
.
batched
(
"audio"
),
input_features_padded
=
MultiModalFieldConfig
.
batched
(
"audio"
),
input_features_mask
=
MultiModalFieldConfig
.
batched
(
"audio"
))
input_features_mask
=
MultiModalFieldConfig
.
batched
(
"audio"
),
)
def
_get_prompt_updates
(
self
,
...
...
@@ -422,6 +434,7 @@ class Gemma3nMultimodalEmbedder(nn.Module):
dummy_inputs
=
Gemma3nDummyInputsBuilder
)
class
Gemma3nForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsTranscription
):
merge_by_field_config
=
True
supported_languages
=
ISO639_1_SUPPORTED_LANGS
packed_modules_mapping
=
{
...
...
@@ -482,14 +495,6 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
device
=
self
.
language_model
.
model
.
embed_tokens
.
weight
.
device
,
dtype
=
self
.
language_model
.
model
.
embed_tokens
.
weight
.
dtype
)
@
property
def
dtype
(
self
):
return
next
(
self
.
parameters
()).
dtype
def
_validate_pixel_values
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# TODO check if there are any
return
data
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
Gemma3nImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
...
...
@@ -499,34 +504,22 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
if
pixel_values
is
None
:
return
None
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
pixel_values
=
flatten_bn
(
pixel_values
,
concat
=
True
)
pixel_values
=
pixel_values
.
contiguous
()
return
Gemma3nImagePixelInputs
(
pixel_values
=
self
.
_validate_pixel_values
(
pixel_values
),
)
return
Gemma3nImagePixelInputs
(
pixel_values
=
pixel_values
)
def
_parse_and_validate_audio_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
Gemma3nAudioInputs
]:
input_features
=
kwargs
.
pop
(
"input_features"
,
None
)
if
input_features
is
None
:
input_features_padded
=
kwargs
.
pop
(
"input_features_padded"
,
None
)
if
input_features_padded
is
None
:
return
None
input_features_mask
=
kwargs
.
pop
(
"input_features_mask"
,
None
)
if
input_features_mask
is
None
:
return
None
input_features_padded
=
kwargs
.
pop
(
"input_features_padded"
,
None
)
if
input_features_padded
is
None
:
return
None
return
Gemma3nAudioInputs
(
input_features
=
input_features
,
input_features_mask
=
input_features_mask
,
input_features_padded
=
input_features_padded
,
input_features_mask
=
input_features_mask
,
)
def
_parse_and_validate_multimodal_inputs
(
self
,
**
kwargs
:
object
)
->
dict
:
...
...
@@ -539,7 +532,7 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
)
and
"image"
not
in
mm_input_by_modality
:
mm_input_by_modality
[
"image"
]
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
input_key
==
"input_features"
\
if
input_key
==
"input_features
_padded
"
\
and
"audio"
not
in
mm_input_by_modality
:
mm_input_by_modality
[
"audio"
]
=
self
.
_parse_and_validate_audio_input
(
**
kwargs
)
...
...
vllm/model_executor/models/glm4_1v.py
View file @
39b643dc
...
...
@@ -1319,6 +1319,8 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):
)
class
Glm4vForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsLoRA
,
SupportsPP
):
merge_by_field_config
=
True
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
@@ -1381,22 +1383,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
def
_validate_and_reshape_mm_tensor
(
self
,
mm_input
:
object
,
name
:
str
)
->
torch
.
Tensor
:
if
not
isinstance
(
mm_input
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
f
"Incorrect type of
{
name
}
. Got type:
{
type
(
mm_input
)
}
"
)
if
isinstance
(
mm_input
,
torch
.
Tensor
):
if
mm_input
.
ndim
==
2
:
return
mm_input
if
mm_input
.
ndim
!=
3
:
raise
ValueError
(
f
"
{
name
}
should be 2D or batched 3D tensor. "
f
"Got ndim:
{
mm_input
.
ndim
}
"
f
"(shape=
{
mm_input
.
shape
}
)"
)
return
mm_input
.
reshape
(
-
1
,
mm_input
.
shape
[
-
1
])
else
:
return
torch
.
concat
(
mm_input
)
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
Glm4vImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
...
...
@@ -1407,11 +1393,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
return
None
if
pixel_values
is
not
None
:
pixel_values
=
self
.
_validate_and_reshape_mm_tensor
(
pixel_values
,
"image pixel values"
)
image_grid_thw
=
self
.
_validate_and_reshape_mm_tensor
(
image_grid_thw
,
"image grid_thw"
)
return
Glm4vImagePixelInputs
(
type
=
"pixel_values"
,
pixel_values
=
pixel_values
,
...
...
@@ -1419,11 +1400,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
)
if
image_embeds
is
not
None
:
image_embeds
=
self
.
_validate_and_reshape_mm_tensor
(
image_embeds
,
"image embeds"
)
image_grid_thw
=
self
.
_validate_and_reshape_mm_tensor
(
image_grid_thw
,
"image grid_thw"
)
return
Glm4vImageEmbeddingInputs
(
type
=
"image_embeds"
,
image_embeds
=
image_embeds
,
...
...
@@ -1440,11 +1416,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
return
None
if
pixel_values_videos
is
not
None
:
pixel_values_videos
=
self
.
_validate_and_reshape_mm_tensor
(
pixel_values_videos
,
"video pixel values"
)
video_grid_thw
=
self
.
_validate_and_reshape_mm_tensor
(
video_grid_thw
,
"video grid_thw"
)
return
Glm4vVideoPixelInputs
(
type
=
"pixel_values_videos"
,
pixel_values_videos
=
pixel_values_videos
,
...
...
@@ -1452,11 +1423,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
)
if
video_embeds
is
not
None
:
video_embeds
=
self
.
_validate_and_reshape_mm_tensor
(
video_embeds
,
"video embeds"
)
video_grid_thw
=
self
.
_validate_and_reshape_mm_tensor
(
video_grid_thw
,
"video grid_thw"
)
return
Glm4vVideoEmbeddingInputs
(
type
=
"video_embeds"
,
video_embeds
=
video_embeds
,
...
...
vllm/model_executor/models/glm4v.py
View file @
39b643dc
...
...
@@ -43,7 +43,6 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from
.chatglm
import
ChatGLMBaseModel
,
ChatGLMModel
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
)
from
.utils
import
flatten_bn
class
GLMVImagePixelInputs
(
TensorSchema
):
...
...
@@ -529,8 +528,9 @@ class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):
@
MULTIMODAL_REGISTRY
.
register_processor
(
GLM4VMultiModalProcessor
,
info
=
GLM4VProcessingInfo
,
dummy_inputs
=
GLM4VDummyInputsBuilder
)
class
GLM4VForCausalLM
(
ChatGLMBaseModel
,
SupportsLoRA
,
SupportsPP
,
SupportsMultiModal
):
class
GLM4VForCausalLM
(
ChatGLMBaseModel
,
SupportsMultiModal
,
SupportsLoRA
,
SupportsPP
):
merge_by_field_config
=
True
packed_modules_mapping
=
{
"query_key_value"
:
[
"query_key_value"
],
...
...
@@ -574,14 +574,9 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
if
pixel_values
is
not
None
:
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
expected_h
=
expected_w
=
self
.
config
.
vision_config
[
"image_size"
]
return
GLMVImagePixelInputs
(
type
=
"pixel_values"
,
data
=
flatten_bn
(
pixel_values
,
concat
=
True
),
data
=
pixel_values
,
resolve_bindings
=
{
"h"
:
expected_h
,
"w"
:
expected_w
...
...
@@ -598,6 +593,8 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
transformer
get_input_embeddings
=
SupportsMultiModal
.
get_input_embeddings
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
...
...
vllm/model_executor/models/granite_speech.py
View file @
39b643dc
...
...
@@ -168,10 +168,8 @@ class GraniteSpeechMultiModalProcessor(
# Calculate the number of audio tokens per entry in the batch;
# This is used to split the batch back out after padding.
audio_token_index
=
self
.
info
.
get_hf_config
().
audio_token_index
processed_outputs
[
"audio_embed_sizes"
]
=
[
torch
.
sum
(
indices
==
audio_token_index
).
item
()
for
indices
in
processed_outputs
[
"input_ids"
]
]
processed_outputs
[
"audio_embed_sizes"
]
=
(
processed_outputs
[
"input_ids"
]
==
audio_token_index
).
sum
(
-
1
)
return
processed_outputs
...
...
@@ -527,6 +525,7 @@ class GraniteSpeechForConditionalGeneration(
SupportsPP
,
SupportsLoRA
,
):
merge_by_field_config
=
True
packed_modules_mapping
=
{
"qkv_proj"
:
[
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment