Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
437c3ce0
Unverified
Commit
437c3ce0
authored
Aug 31, 2025
by
Benji Beck
Committed by
GitHub
Sep 01, 2025
Browse files
Migrate Phi4 inputs to TensorSchema (#23471)
Signed-off-by:
Benji Beck
<
benjibeck@meta.com
>
parent
499b074b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
129 additions
and
73 deletions
+129
-73
vllm/model_executor/models/phi4_multimodal.py
vllm/model_executor/models/phi4_multimodal.py
+72
-39
vllm/model_executor/models/phi4mm.py
vllm/model_executor/models/phi4mm.py
+57
-34
No files found.
vllm/model_executor/models/phi4_multimodal.py
View file @
437c3ce0
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Any
,
Literal
,
Optional
,
TypedDict
,
Union
from
typing
import
Annotated
,
Any
,
Literal
,
Optional
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -40,6 +40,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
...
@@ -40,6 +40,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_list_of
from
vllm.utils
import
is_list_of
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.idefics2_vision_model
import
Idefics2VisionTransformer
from
.idefics2_vision_model
import
Idefics2VisionTransformer
from
.interfaces
import
MultiModalEmbeddings
,
SupportsLoRA
,
SupportsMultiModal
from
.interfaces
import
MultiModalEmbeddings
,
SupportsLoRA
,
SupportsMultiModal
...
@@ -615,50 +616,90 @@ class Phi4MMAudioEmbedding(nn.Module):
...
@@ -615,50 +616,90 @@ class Phi4MMAudioEmbedding(nn.Module):
return
loaded_params
return
loaded_params
class
Phi4MMImagePixelInputs
(
TypedDict
):
class
Phi4MMImagePixelInputs
(
TensorSchema
):
type
:
Literal
[
"pixel_values"
]
data
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
"""
Shape:
Dimensions:
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
- bn: Batch size * number of images
- p: Number of patches (1 + num_patches)
Note that `num_patches` may be different per batch and image,
- c: Number of channels (3)
in which case the data is passed as a list instead of a batched tensor.
- h: Height of each image patch
- w: Width of each image patch
- nc: Number of crops
- H_mask: Height of attention mask
- W_mask: Width of attention mask
"""
"""
image_sizes
:
torch
.
Tensor
type
:
Literal
[
"pixel_values"
]
"""
Shape: `(batch_size * num_images, 2)`
This should be in `(height, width)` format.
data
:
Annotated
[
"""
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
TensorShape
(
"bn"
,
"p"
,
3
,
"h"
,
"w"
,
dynamic_dims
=
{
"p"
}
),
# may be different per batch and image
]
num_img_tokens
:
list
[
int
]
image_sizes
:
Annotated
[
"""Shape: `(batch_size * num_images)`"""
torch
.
Tensor
,
TensorShape
(
"bn"
,
2
),
# (height, width)
]
image_attention_mask
:
torch
.
Tensor
num_img_tokens
:
Annotated
[
"""Shape: `(batch_size * num_images, H_mask, W_mask)`"""
list
[
int
],
TensorShape
(
"bn"
),
]
image_attention_mask
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bn"
,
"nc"
,
32
,
32
),
# H_mask, W_mask
]
class
Phi4MMImageEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"image_embeds"
]
data
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
class
Phi4MMImageEmbeddingInputs
(
TensorSchema
):
"""
"""
Dimensions:
- bn: Batch size * number of images
- f: Image feature size
- h: Hidden size (must match language model backbone)
"""
type
:
Literal
[
"image_embeds"
]
data
:
Annotated
[
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
TensorShape
(
"bn"
,
"f"
,
"h"
),
]
class
Phi4MMAudioFeatureInputs
(
TensorSchema
):
"""
Dimensions:
- bn: Batch size * number of audios
- f: Number of Mel filterbank bins (80)
- t: Time frames (M)
"""
class
Phi4MMAudioFeatureInputs
(
TypedDict
):
type
:
Literal
[
"audio_features"
]
type
:
Literal
[
"audio_features"
]
data
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""Shape: `(batch_size * num_audios, 80, M)"""
data
:
Annotated
[
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
TensorShape
(
"bn"
,
"t"
,
80
,
dynamic_dims
=
{
"t"
}),
]
class
Phi4MMAudioEmbeddingInputs
(
TensorSchema
):
"""
Dimensions:
- b: Batch size
- n: Number of audios
- f: Audio feature size
- h: Hidden size (must match language model backbone)
"""
class
Phi4MMAudioEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"audio_embeds"
]
type
:
Literal
[
"audio_embeds"
]
data
:
NestedTensors
"""Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)"""
data
:
Annotated
[
NestedTensors
,
TensorShape
(
"b"
,
"n"
,
"f"
,
"h"
),
]
Phi4MMImageInput
=
Union
[
Phi4MMImagePixelInputs
,
Phi4MMImageEmbeddingInputs
]
Phi4MMImageInput
=
Union
[
Phi4MMImagePixelInputs
,
Phi4MMImageEmbeddingInputs
]
...
@@ -1170,18 +1211,10 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
...
@@ -1170,18 +1211,10 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
return
None
return
None
if
audio_features
is
not
None
:
if
audio_features
is
not
None
:
if
not
isinstance
(
audio_features
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of audio features. "
f
"Got type:
{
type
(
audio_features
)
}
"
)
return
Phi4MMAudioFeatureInputs
(
type
=
"audio_features"
,
return
Phi4MMAudioFeatureInputs
(
type
=
"audio_features"
,
data
=
flatten_bn
(
audio_features
))
data
=
flatten_bn
(
audio_features
))
if
audio_embeds
is
not
None
:
if
audio_embeds
is
not
None
:
if
not
isinstance
(
audio_embeds
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of audio embeds. "
f
"Got type:
{
type
(
audio_embeds
)
}
"
)
return
Phi4MMAudioEmbeddingInputs
(
type
=
"audio_embeds"
,
return
Phi4MMAudioEmbeddingInputs
(
type
=
"audio_embeds"
,
data
=
audio_embeds
)
data
=
audio_embeds
)
...
@@ -1259,7 +1292,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
...
@@ -1259,7 +1292,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
elif
isinstance
(
image_sizes
,
torch
.
Tensor
):
elif
isinstance
(
image_sizes
,
torch
.
Tensor
):
image_sizes
=
image_sizes
.
flatten
(
0
,
1
)
image_sizes
=
image_sizes
.
flatten
(
0
,
1
)
else
:
else
:
raise
ValueError
(
"Incorrect image_
attention_mask
inputs"
)
raise
ValueError
(
"Incorrect image_
sizes
inputs"
)
if
isinstance
(
num_img_tokens
,
list
):
if
isinstance
(
num_img_tokens
,
list
):
num_img_tokens
=
[
num_img_tokens
=
[
...
@@ -1269,7 +1302,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
...
@@ -1269,7 +1302,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
elif
isinstance
(
num_img_tokens
,
torch
.
Tensor
):
elif
isinstance
(
num_img_tokens
,
torch
.
Tensor
):
num_img_tokens
=
num_img_tokens
.
flatten
(
0
,
1
).
tolist
()
num_img_tokens
=
num_img_tokens
.
flatten
(
0
,
1
).
tolist
()
else
:
else
:
raise
ValueError
(
"Incorrect
image_attention_mask
inputs"
)
raise
ValueError
(
"Incorrect
num_img_tokens
inputs"
)
return
Phi4MMImagePixelInputs
(
return
Phi4MMImagePixelInputs
(
type
=
"pixel_values"
,
type
=
"pixel_values"
,
...
...
vllm/model_executor/models/phi4mm.py
View file @
437c3ce0
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Any
,
Literal
,
Optional
,
TypedDict
,
Union
from
typing
import
Annotated
,
Any
,
Literal
,
Optional
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -31,6 +31,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
...
@@ -31,6 +31,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_list_of
from
vllm.utils
import
is_list_of
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.idefics2_vision_model
import
Idefics2VisionTransformer
from
.idefics2_vision_model
import
Idefics2VisionTransformer
from
.interfaces
import
MultiModalEmbeddings
,
SupportsLoRA
,
SupportsMultiModal
from
.interfaces
import
MultiModalEmbeddings
,
SupportsLoRA
,
SupportsMultiModal
...
@@ -391,41 +392,71 @@ class Phi4MMImageEncoder(nn.Module):
...
@@ -391,41 +392,71 @@ class Phi4MMImageEncoder(nn.Module):
return
img_set_tensor
return
img_set_tensor
class
Phi4MMImagePixelInputs
(
TypedDict
):
class
Phi4MMImagePixelInputs
(
TensorSchema
):
type
:
Literal
[
"pixel_values"
]
data
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
"""
Shape:
Dimensions:
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
- bn: Batch size * number of images
- p: Number of patches (1 + num_patches)
Note that `num_patches` may be different per batch and image,
- c: Number of channels (3)
in which case the data is passed as a list instead of a batched tensor.
- h: Height of each image patch
- w: Width of each image patch
- nc: Number of crops
- H_mask: Height of attention mask
- W_mask: Width of attention mask
"""
"""
image_sizes
:
torch
.
Tensor
type
:
Literal
[
"pixel_values"
]
"""
Shape: `(batch_size * num_images, 2)`
This should be in `(height, width)` format.
data
:
Annotated
[
"""
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
TensorShape
(
"bn"
,
"p"
,
3
,
"h"
,
"w"
,
dynamic_dims
=
{
"p"
}
),
# may be different per batch and image
]
image_sizes
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bn"
,
2
),
# (height, width)
]
num_img_tokens
:
list
[
int
]
num_img_tokens
:
Annotated
[
"""Shape: `(batch_size * num_images)`"""
list
[
int
],
TensorShape
(
"bn"
),
]
image_attention_mask
:
torch
.
Tensor
image_attention_mask
:
Annotated
[
"""Shape: `(batch_size * num_images, H_mask, W_mask)`"""
torch
.
Tensor
,
TensorShape
(
"bn"
,
"nc"
,
32
,
32
),
# H_mask, W_mask
]
class
Phi4MMAudioFeatureInputs
(
TypedDict
):
class
Phi4MMAudioFeatureInputs
(
TensorSchema
):
"""
Dimensions:
- bn: Batch size * number of audios
- t: Time frames (M)
"""
type
:
Literal
[
"audio_features"
]
type
:
Literal
[
"audio_features"
]
data
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""Shape: `(batch_size * num_audios, 80, M)"""
data
:
Annotated
[
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
TensorShape
(
"bn"
,
"t"
,
80
,
dynamic_dims
=
{
"t"
}),
]
class
Phi4MMAudioEmbeddingInputs
(
TypedDict
):
class
Phi4MMAudioEmbeddingInputs
(
TensorSchema
):
"""
Dimensions:
- b: Batch size
- n: Number of audios
- f: Audio feature size
- h: Hidden size (must match language model backbone)
"""
type
:
Literal
[
"audio_embeds"
]
type
:
Literal
[
"audio_embeds"
]
data
:
NestedTensors
data
:
Annotated
[
"""Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)"""
NestedTensors
,
TensorShape
(
"b"
,
"n"
,
"f"
,
"h"
),
]
Phi4MMAudioInputs
=
Union
[
Phi4MMAudioFeatureInputs
,
Phi4MMAudioEmbeddingInputs
]
Phi4MMAudioInputs
=
Union
[
Phi4MMAudioFeatureInputs
,
Phi4MMAudioEmbeddingInputs
]
...
@@ -985,18 +1016,10 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
...
@@ -985,18 +1016,10 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
return
None
return
None
if
audio_features
is
not
None
:
if
audio_features
is
not
None
:
if
not
isinstance
(
audio_features
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of audio features. "
f
"Got type:
{
type
(
audio_features
)
}
"
)
return
Phi4MMAudioFeatureInputs
(
type
=
"audio_features"
,
return
Phi4MMAudioFeatureInputs
(
type
=
"audio_features"
,
data
=
flatten_bn
(
audio_features
))
data
=
flatten_bn
(
audio_features
))
if
audio_embeds
is
not
None
:
if
audio_embeds
is
not
None
:
if
not
isinstance
(
audio_embeds
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of audio embeds. "
f
"Got type:
{
type
(
audio_embeds
)
}
"
)
return
Phi4MMAudioEmbeddingInputs
(
type
=
"audio_embeds"
,
return
Phi4MMAudioEmbeddingInputs
(
type
=
"audio_embeds"
,
data
=
audio_embeds
)
data
=
audio_embeds
)
...
@@ -1074,7 +1097,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
...
@@ -1074,7 +1097,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
elif
isinstance
(
image_sizes
,
torch
.
Tensor
):
elif
isinstance
(
image_sizes
,
torch
.
Tensor
):
image_sizes
=
image_sizes
.
flatten
(
0
,
1
)
image_sizes
=
image_sizes
.
flatten
(
0
,
1
)
else
:
else
:
raise
ValueError
(
"Incorrect image_
attention_mask
inputs"
)
raise
ValueError
(
"Incorrect image_
sizes
inputs"
)
if
isinstance
(
num_img_tokens
,
list
):
if
isinstance
(
num_img_tokens
,
list
):
num_img_tokens
=
[
num_img_tokens
=
[
...
@@ -1084,7 +1107,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
...
@@ -1084,7 +1107,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
elif
isinstance
(
num_img_tokens
,
torch
.
Tensor
):
elif
isinstance
(
num_img_tokens
,
torch
.
Tensor
):
num_img_tokens
=
num_img_tokens
.
flatten
(
0
,
1
).
tolist
()
num_img_tokens
=
num_img_tokens
.
flatten
(
0
,
1
).
tolist
()
else
:
else
:
raise
ValueError
(
"Incorrect
image_attention_mask
inputs"
)
raise
ValueError
(
"Incorrect
num_img_tokens
inputs"
)
return
Phi4MMImagePixelInputs
(
return
Phi4MMImagePixelInputs
(
type
=
"pixel_values"
,
type
=
"pixel_values"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment