Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
74704d45
Unverified
Commit
74704d45
authored
Oct 14, 2025
by
Cyrus Leung
Committed by
GitHub
Oct 14, 2025
Browse files
[Model] Use merge_by_field_config for MM models (O-P) (#26776)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
d2f816d6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
30 additions
and
122 deletions
+30
-122
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+7
-16
vllm/model_executor/models/phi4_multimodal.py
vllm/model_executor/models/phi4_multimodal.py
+11
-53
vllm/model_executor/models/phi4mm.py
vllm/model_executor/models/phi4mm.py
+12
-53
No files found.
vllm/model_executor/models/phi3v.py
View file @
74704d45
...
...
@@ -56,7 +56,6 @@ from vllm.multimodal.processing import (
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_list_of
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.clip
import
CLIPVisionModel
...
...
@@ -70,7 +69,6 @@ from .utils import (
AutoWeightsLoader
,
WeightsMapper
,
_merge_multimodal_embeddings
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
)
...
...
@@ -564,6 +562,8 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
dummy_inputs
=
Phi3VDummyInputsBuilder
,
)
class
Phi3VForCausalLM
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
,
SupportsQuant
):
merge_by_field_config
=
True
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"model.vision_embed_tokens.wte"
:
"embed_tokens"
,
...
...
@@ -631,8 +631,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
if
pixel_values
is
not
None
:
return
Phi3VImagePixelInputs
(
type
=
"pixel_values"
,
pixel_values
=
flatten_bn
(
pixel_values
)
,
image_sizes
=
flatten_bn
(
image_sizes
,
concat
=
True
),
pixel_values
=
pixel_values
,
image_sizes
=
image_sizes
,
resolve_bindings
=
{
"h"
:
CLIP_VIT_LARGE_PATCH14_336_CONFIG
.
image_size
,
"w"
:
CLIP_VIT_LARGE_PATCH14_336_CONFIG
.
image_size
,
...
...
@@ -642,7 +642,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
if
image_embeds
is
not
None
:
return
Phi3VImageEmbeddingInputs
(
type
=
"image_embeds"
,
data
=
flatten_bn
(
image_embeds
)
,
data
=
image_embeds
,
)
raise
AssertionError
(
"This line should be unreachable."
)
...
...
@@ -652,19 +652,10 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
image_input
:
Phi3VImageInputs
,
)
->
torch
.
Tensor
:
if
image_input
[
"type"
]
==
"image_embeds"
:
image_data
=
image_input
[
"data"
]
if
is_list_of
(
image_data
,
torch
.
Tensor
):
# it's already a list of tensors
return
image_data
if
len
(
image_data
.
shape
)
==
3
:
# 3D tensor
return
list
(
torch
.
unbind
(
image_data
,
dim
=
0
))
raise
ValueError
(
"We expect batched 2D tensors; "
"this can be either a list of 2D tensors or a single 3D tensor."
)
return
image_input
[
"data"
]
assert
self
.
vision_embed_tokens
is
not
None
image_embeds
=
self
.
vision_embed_tokens
(
image_input
[
"pixel_values"
],
image_input
[
"image_sizes"
]
)
...
...
vllm/model_executor/models/phi4_multimodal.py
View file @
74704d45
...
...
@@ -64,7 +64,6 @@ from vllm.multimodal.processing import (
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_list_of
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.idefics2_vision_model
import
Idefics2VisionTransformer
...
...
@@ -72,7 +71,6 @@ from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
)
...
...
@@ -672,7 +670,7 @@ class Phi4MMImagePixelInputs(TensorSchema):
type
:
Literal
[
"pixel_values"
]
data
:
Annotated
[
pixel_values
:
Annotated
[
torch
.
Tensor
|
list
[
torch
.
Tensor
],
TensorShape
(
"bn"
,
"p"
,
3
,
"h"
,
"w"
,
dynamic_dims
=
{
"p"
}
...
...
@@ -721,7 +719,7 @@ class Phi4MMAudioFeatureInputs(TensorSchema):
type
:
Literal
[
"audio_features"
]
data
:
Annotated
[
audio_features
:
Annotated
[
torch
.
Tensor
|
list
[
torch
.
Tensor
],
TensorShape
(
"bn"
,
"t"
,
80
,
dynamic_dims
=
{
"t"
}),
]
...
...
@@ -1189,6 +1187,8 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
Implements the Phi-4-multimodal-instruct model in vLLM.
"""
merge_by_field_config
=
True
packed_modules_mapping
=
{
"qkv_proj"
:
[
"qkv_proj"
,
...
...
@@ -1273,7 +1273,8 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
if
audio_features
is
not
None
:
return
Phi4MMAudioFeatureInputs
(
type
=
"audio_features"
,
data
=
flatten_bn
(
audio_features
)
type
=
"audio_features"
,
audio_features
=
audio_features
,
)
if
audio_embeds
is
not
None
:
...
...
@@ -1298,7 +1299,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
if
audio_input
[
"type"
]
==
"audio_embeds"
:
return
audio_input
[
"data"
]
audio_features
=
audio_input
[
"
data
"
]
audio_features
=
audio_input
[
"
audio_features
"
]
# (e.g. multiple examples) and the second dim is the multi-audio dim
# (e.g. multiple audios in the same example)
...
...
@@ -1315,8 +1316,8 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Phi4MMImagePixelInputs
|
None
:
image_
pixel_values
:
NestedTensors
=
kwargs
.
get
(
"image_pixel_values"
)
if
image_
pixel_values
is
None
:
pixel_values
=
kwargs
.
get
(
"image_pixel_values"
)
if
pixel_values
is
None
:
return
None
image_sizes
=
kwargs
.
get
(
"image_sizes"
)
...
...
@@ -1328,52 +1329,9 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
and
num_img_tokens
is
not
None
),
"Missing image inputs"
if
is_list_of
(
image_pixel_values
,
torch
.
Tensor
):
assert
all
(
p
.
dim
()
==
5
for
p
in
image_pixel_values
),
(
"Incorrect image inputs"
)
# list len is batch_size.
# each tensor has dimension: num_img_per_example, num_hd_patches,
# channels, height, width.
# need to pad along num_hd_patches.
# mask size num_img_per_prompt, num_hd_patches, feat_h, heat_w.
image_pixel_values
=
cat_with_pad
(
image_pixel_values
,
dim
=
0
)
elif
isinstance
(
image_pixel_values
,
torch
.
Tensor
):
# dimension: batch_size, num_img_per_example, num_hd_patches,
# channels, height, width.
# we flatten first 2 dims to make it a single large batch for
# SigLIP Encoder.
assert
image_pixel_values
.
dim
()
==
6
,
"Incorrect image inputs"
image_pixel_values
=
image_pixel_values
.
flatten
(
0
,
1
)
else
:
raise
ValueError
(
"Incorrect image_pixel_values inputs"
)
if
isinstance
(
image_attention_mask
,
list
):
image_attention_mask
=
cat_with_pad
(
image_attention_mask
,
dim
=
0
)
elif
isinstance
(
image_attention_mask
,
torch
.
Tensor
):
image_attention_mask
=
image_attention_mask
.
flatten
(
0
,
1
)
else
:
raise
ValueError
(
"Incorrect image_attention_mask inputs"
)
if
isinstance
(
image_sizes
,
list
):
image_sizes
=
torch
.
cat
(
image_sizes
,
dim
=
0
)
elif
isinstance
(
image_sizes
,
torch
.
Tensor
):
image_sizes
=
image_sizes
.
flatten
(
0
,
1
)
else
:
raise
ValueError
(
"Incorrect image_sizes inputs"
)
if
isinstance
(
num_img_tokens
,
list
):
num_img_tokens
=
[
n
for
num_tensor
in
num_img_tokens
for
n
in
num_tensor
.
tolist
()
]
elif
isinstance
(
num_img_tokens
,
torch
.
Tensor
):
num_img_tokens
=
num_img_tokens
.
flatten
(
0
,
1
).
tolist
()
else
:
raise
ValueError
(
"Incorrect num_img_tokens inputs"
)
return
Phi4MMImagePixelInputs
(
type
=
"pixel_values"
,
data
=
image_
pixel_values
,
pixel_values
=
pixel_values
,
image_sizes
=
image_sizes
,
image_attention_mask
=
image_attention_mask
,
num_img_tokens
=
num_img_tokens
,
...
...
@@ -1405,7 +1363,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
image_embeds
=
image_input
[
"image_embeds"
].
type
(
self
.
visual
.
dtype
)
else
:
dtype
=
next
(
self
.
image_embed
.
parameters
()).
dtype
pixel_values
=
image_input
[
"
data
"
].
to
(
dtype
)
pixel_values
=
image_input
[
"
pixel_values
"
].
to
(
dtype
)
image_sizes
=
image_input
[
"image_sizes"
]
image_attention_mask
=
image_input
[
"image_attention_mask"
]
image_embeds
=
self
.
image_embed
(
...
...
vllm/model_executor/models/phi4mm.py
View file @
74704d45
...
...
@@ -50,13 +50,12 @@ from vllm.multimodal.processing import (
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_list_of
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.idefics2_vision_model
import
Idefics2VisionTransformer
from
.interfaces
import
MultiModalEmbeddings
,
SupportsLoRA
,
SupportsMultiModal
from
.phi4mm_audio
import
AudioEmbedding
from
.utils
import
AutoWeightsLoader
,
WeightsMapper
,
flatten_bn
,
maybe_prefix
from
.utils
import
AutoWeightsLoader
,
WeightsMapper
,
maybe_prefix
# <|endoftext10|> (see vocab.json in hf model)
_IMAGE_PLACEHOLDER_TOKEN_ID
=
200010
...
...
@@ -467,7 +466,7 @@ class Phi4MMImagePixelInputs(TensorSchema):
type
:
Literal
[
"pixel_values"
]
data
:
Annotated
[
pixel_values
:
Annotated
[
torch
.
Tensor
|
list
[
torch
.
Tensor
],
TensorShape
(
"bn"
,
"p"
,
3
,
"h"
,
"w"
,
dynamic_dims
=
{
"p"
}
...
...
@@ -499,7 +498,7 @@ class Phi4MMAudioFeatureInputs(TensorSchema):
type
:
Literal
[
"audio_features"
]
data
:
Annotated
[
audio_features
:
Annotated
[
torch
.
Tensor
|
list
[
torch
.
Tensor
],
TensorShape
(
"bn"
,
"t"
,
80
,
dynamic_dims
=
{
"t"
}),
]
...
...
@@ -986,6 +985,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
Implements the Phi-4-multimodal-instruct model in vLLM.
"""
merge_by_field_config
=
True
packed_modules_mapping
=
{
"qkv_proj"
:
[
"qkv_proj"
,
...
...
@@ -1094,7 +1095,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
if
audio_features
is
not
None
:
return
Phi4MMAudioFeatureInputs
(
type
=
"audio_features"
,
data
=
flatten_bn
(
audio_features
)
type
=
"audio_features"
,
audio_features
=
audio_features
,
)
if
audio_embeds
is
not
None
:
...
...
@@ -1119,7 +1121,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
if
audio_input
[
"type"
]
==
"audio_embeds"
:
return
audio_input
[
"data"
]
audio_features
=
audio_input
[
"
data
"
]
audio_features
=
audio_input
[
"
audio_features
"
]
# (e.g. multiple examples) and the second dim is the multi-audio dim
# (e.g. multiple audios in the same example)
...
...
@@ -1136,8 +1138,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Phi4MMImagePixelInputs
|
None
:
input_image_embeds
:
NestedTensor
s
=
kwargs
.
get
(
"input_image_embeds"
)
if
input_image_embed
s
is
None
:
pixel_value
s
=
kwargs
.
get
(
"input_image_embeds"
)
if
pixel_value
s
is
None
:
return
None
image_sizes
=
kwargs
.
get
(
"image_sizes"
)
...
...
@@ -1149,52 +1151,9 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
and
num_img_tokens
is
not
None
),
"Missing image inputs"
if
is_list_of
(
input_image_embeds
,
torch
.
Tensor
):
assert
all
(
p
.
dim
()
==
5
for
p
in
input_image_embeds
),
(
"Incorrect image inputs"
)
# list len is batch_size.
# each tensor has dimension: num_img_per_example, num_hd_patches,
# channels, height, width.
# need to pad along num_hd_patches.
# mask size num_img_per_prompt, num_hd_patches, feat_h, heat_w.
input_image_embeds
=
cat_with_pad
(
input_image_embeds
,
dim
=
0
)
elif
isinstance
(
input_image_embeds
,
torch
.
Tensor
):
# dimension: batch_size, num_img_per_example, num_hd_patches,
# channels, height, width.
# we flatten first 2 dims to make it a single large batch for
# SigLIP Encoder.
assert
input_image_embeds
.
dim
()
==
6
,
"Incorrect image inputs"
input_image_embeds
=
input_image_embeds
.
flatten
(
0
,
1
)
else
:
raise
ValueError
(
"Incorrect input_image_embeds inputs"
)
if
isinstance
(
image_attention_mask
,
list
):
image_attention_mask
=
cat_with_pad
(
image_attention_mask
,
dim
=
0
)
elif
isinstance
(
image_attention_mask
,
torch
.
Tensor
):
image_attention_mask
=
image_attention_mask
.
flatten
(
0
,
1
)
else
:
raise
ValueError
(
"Incorrect image_attention_mask inputs"
)
if
isinstance
(
image_sizes
,
list
):
image_sizes
=
torch
.
cat
(
image_sizes
,
dim
=
0
)
elif
isinstance
(
image_sizes
,
torch
.
Tensor
):
image_sizes
=
image_sizes
.
flatten
(
0
,
1
)
else
:
raise
ValueError
(
"Incorrect image_sizes inputs"
)
if
isinstance
(
num_img_tokens
,
list
):
num_img_tokens
=
[
n
for
num_tensor
in
num_img_tokens
for
n
in
num_tensor
.
tolist
()
]
elif
isinstance
(
num_img_tokens
,
torch
.
Tensor
):
num_img_tokens
=
num_img_tokens
.
flatten
(
0
,
1
).
tolist
()
else
:
raise
ValueError
(
"Incorrect num_img_tokens inputs"
)
return
Phi4MMImagePixelInputs
(
type
=
"pixel_values"
,
data
=
input_image_embed
s
,
pixel_values
=
pixel_value
s
,
image_sizes
=
image_sizes
,
image_attention_mask
=
image_attention_mask
,
num_img_tokens
=
num_img_tokens
,
...
...
@@ -1223,7 +1182,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
self
,
image_input
:
Phi4MMImagePixelInputs
)
->
list
[
torch
.
Tensor
]:
dtype
=
next
(
self
.
vision_encoder
.
parameters
()).
dtype
pixel_values
=
image_input
[
"
data
"
].
to
(
dtype
)
pixel_values
=
image_input
[
"
pixel_values
"
].
to
(
dtype
)
image_sizes
=
image_input
[
"image_sizes"
]
image_attention_mask
=
image_input
[
"image_attention_mask"
]
image_embeds
=
self
.
vision_encoder
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment