Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3339cba3
Unverified
Commit
3339cba3
authored
Jul 26, 2025
by
Benji Beck
Committed by
GitHub
Jul 26, 2025
Browse files
Migrate FuyuImagePatchInputs to TensorSchema (#21662)
Signed-off-by:
Benji Beck
<
benjibeck@meta.com
>
parent
0b8caf90
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
57 additions
and
45 deletions
+57
-45
tests/standalone_tests/test_tensor_schema.py
tests/standalone_tests/test_tensor_schema.py
+22
-0
vllm/model_executor/models/fuyu.py
vllm/model_executor/models/fuyu.py
+19
-36
vllm/utils/tensor_schema.py
vllm/utils/tensor_schema.py
+16
-9
No files found.
tests/standalone_tests/test_tensor_schema.py
View file @
3339cba3
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
import
pytest
import
pytest
import
torch
import
torch
from
vllm.model_executor.models.fuyu
import
FuyuImagePatchInputs
from
vllm.model_executor.models.phi3v
import
Phi3VImagePixelInputs
from
vllm.model_executor.models.phi3v
import
Phi3VImagePixelInputs
...
@@ -124,3 +125,24 @@ def test_tensor_schema_with_invalid_resolve_binding_dims():
...
@@ -124,3 +125,24 @@ def test_tensor_schema_with_invalid_resolve_binding_dims():
"w"
:
336
"w"
:
336
},
},
)
)
def
test_tensor_schema_with_list_of_symbolic_dim
():
flat_data
=
torch
.
stack
([
torch
.
randn
(
768
)
for
_
in
range
(
3
)])
# (bn=3, fn)
patches_per_image
=
[
64
,
64
,
64
]
# len = bn = 3
FuyuImagePatchInputs
(
flat_data
=
flat_data
,
patches_per_image
=
patches_per_image
,
)
def
test_tensor_schema_with_list_of_symbolic_dim_mismatch_in_length
():
flat_data
=
torch
.
stack
([
torch
.
randn
(
768
)
for
_
in
range
(
4
)])
# (bn=4, fn)
patches_per_image
=
[
64
,
64
,
64
]
# len = 3 ≠ bn
with
pytest
.
raises
(
ValueError
,
match
=
"expected 'bn'=4, got 3"
):
FuyuImagePatchInputs
(
flat_data
=
flat_data
,
patches_per_image
=
patches_per_image
,
)
\ No newline at end of file
vllm/model_executor/models/fuyu.py
View file @
3339cba3
...
@@ -19,7 +19,7 @@
...
@@ -19,7 +19,7 @@
""" PyTorch Fuyu model."""
""" PyTorch Fuyu model."""
import
math
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Literal
,
Optional
,
TypedDict
from
typing
import
Annotated
,
Literal
,
Optional
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -40,6 +40,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
...
@@ -40,6 +40,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptUpdate
,
PromptUpdateDetails
)
PromptUpdate
,
PromptUpdateDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
flatten_bn
,
maybe_prefix
,
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
flatten_bn
,
maybe_prefix
,
...
@@ -50,18 +51,24 @@ _IMAGE_TOKEN_ID = 71011
...
@@ -50,18 +51,24 @@ _IMAGE_TOKEN_ID = 71011
_NEWLINE_TOKEN_ID
=
71019
_NEWLINE_TOKEN_ID
=
71019
class
FuyuImagePatchInputs
(
TypedDict
):
class
FuyuImagePatchInputs
(
TensorSchema
):
type
:
Literal
[
"image_patches"
]
flat_data
:
torch
.
Tensor
"""
"""
Shape:
Dimensions:
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
- bn: Batch size * number of images
- fn: Num channels * patch_size_x * patch_size_y
"""
"""
patches_per_image
:
list
[
int
]
type
:
Literal
[
"image_patches"
]
=
"image_patches"
flat_data
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bn"
,
"fn"
),
]
patches_per_image
:
Annotated
[
list
[
int
],
TensorShape
(
"bn"
)]
"""
"""
The number of total patches for each image in the batch.
The number of total patches for each image in the batch.
This is used to split the embeddings which has the first two dimensions
This is used to split the embeddings which has the first two dimensions
flattened just like `flat_data`.
flattened just like `flat_data`.
"""
"""
...
@@ -297,42 +304,18 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -297,42 +304,18 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
self
.
language_model
.
make_empty_intermediate_tensors
)
def
_validate_pixel_values
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
h
=
w
=
self
.
config
.
patch_size
num_channels
=
self
.
config
.
num_channels
expected_dims
=
num_channels
*
h
*
w
def
_validate_shape
(
d
:
torch
.
Tensor
):
actual_dims
=
d
.
size
(
-
1
)
if
actual_dims
!=
expected_dims
:
expected_expr
=
str
(
expected_dims
)
raise
ValueError
(
"The expected shape of pixel values per image per batch "
f
"per patch is
{
expected_expr
}
. "
f
"You supplied
{
tuple
(
d
.
shape
)
}
."
)
for
d
in
data
:
_validate_shape
(
d
)
return
data
.
to
(
self
.
vision_embed_tokens
.
weight
.
dtype
)
def
_parse_and_validate_image_input
(
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
FuyuImagePatchInputs
]:
self
,
**
kwargs
:
object
)
->
Optional
[
FuyuImagePatchInputs
]:
image_patches
=
kwargs
.
pop
(
"image_patches"
,
None
)
image_patches
=
kwargs
.
pop
(
"image_patches"
,
None
)
if
image_patches
is
not
None
:
if
image_patches
is
not
None
:
if
not
isinstance
(
image_patches
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of image patches. "
f
"Got type:
{
type
(
image_patches
)
}
"
)
image_patches_flat
=
flatten_bn
(
image_patches
)
image_patches_flat
=
flatten_bn
(
image_patches
)
flat_data
=
flatten_bn
(
image_patches
,
concat
=
True
).
data
.
to
(
self
.
vision_embed_tokens
.
weight
.
dtype
)
return
FuyuImagePatchInputs
(
return
FuyuImagePatchInputs
(
type
=
"image_patches"
,
type
=
"image_patches"
,
flat_data
=
self
.
_validate_pixel_values
(
flat_data
=
flat_data
,
flatten_bn
(
image_patches_flat
,
concat
=
True
)),
patches_per_image
=
[
x
.
size
(
0
)
for
x
in
image_patches_flat
],
patches_per_image
=
[
x
.
size
(
0
)
for
x
in
image_patches_flat
],
resolve_bindings
=
{
"fn"
:
self
.
image_feature_size
},
)
)
return
None
return
None
...
...
vllm/utils/tensor_schema.py
View file @
3339cba3
...
@@ -86,9 +86,6 @@ class TensorSchema:
...
@@ -86,9 +86,6 @@ class TensorSchema:
expected_shape
:
tuple
[
Union
[
int
,
str
],
...],
expected_shape
:
tuple
[
Union
[
int
,
str
],
...],
dynamic_dims
:
set
[
str
,
...])
->
tuple
[
int
,
...]:
dynamic_dims
:
set
[
str
,
...])
->
tuple
[
int
,
...]:
"""Validate a list/tuple of tensors and return the actual shape."""
"""Validate a list/tuple of tensors and return the actual shape."""
if
not
value
:
raise
ValueError
(
f
"
{
field_name
}
is an empty list"
)
# Ensure all tensors in the list have the same
# Ensure all tensors in the list have the same
# shape, besides dynamic dimensions
# shape, besides dynamic dimensions
first
=
value
[
0
]
first
=
value
[
0
]
...
@@ -117,6 +114,7 @@ class TensorSchema:
...
@@ -117,6 +114,7 @@ class TensorSchema:
int
],
int
],
dynamic_dims
:
set
[
str
,
...])
->
None
:
dynamic_dims
:
set
[
str
,
...])
->
None
:
"""Validate that the actual tensor shape matches the expected shape."""
"""Validate that the actual tensor shape matches the expected shape."""
if
len
(
actual_shape
)
!=
len
(
expected_shape
):
if
len
(
actual_shape
)
!=
len
(
expected_shape
):
raise
ValueError
(
f
"
{
field_name
}
has rank
{
len
(
actual_shape
)
}
"
raise
ValueError
(
f
"
{
field_name
}
has rank
{
len
(
actual_shape
)
}
"
f
"but expected
{
len
(
expected_shape
)
}
"
)
f
"but expected
{
len
(
expected_shape
)
}
"
)
...
@@ -160,12 +158,11 @@ class TensorSchema:
...
@@ -160,12 +158,11 @@ class TensorSchema:
# Skip validation when Union contains None
# Skip validation when Union contains None
if
type
(
None
)
in
args
:
if
type
(
None
)
in
args
:
continue
continue
#
If not optional
, raise error
#
Otherwise field is required
, raise error
raise
ValueError
(
f
"Required field '
{
field_name
}
' is missing"
)
raise
ValueError
(
f
"Required field '
{
field_name
}
' is missing"
)
# Field exists, proceed with validation
# Field exists, proceed with validation
value
=
getattr
(
self
,
field_name
)
value
=
getattr
(
self
,
field_name
)
if
get_origin
(
field_type
)
is
not
None
:
if
get_origin
(
field_type
)
is
not
None
:
args
=
get_args
(
field_type
)
args
=
get_args
(
field_type
)
...
@@ -173,13 +170,23 @@ class TensorSchema:
...
@@ -173,13 +170,23 @@ class TensorSchema:
if
isinstance
(
arg
,
TensorShape
):
if
isinstance
(
arg
,
TensorShape
):
expected_shape
=
arg
.
resolve
(
**
self
.
_resolve_bindings
)
expected_shape
=
arg
.
resolve
(
**
self
.
_resolve_bindings
)
if
isinstance
(
value
,
(
list
,
tuple
)):
if
isinstance
(
value
,
(
list
,
tuple
)):
actual_shape
=
self
.
_validate_nested_tensors
(
# list/tuple of Tensors → shape = (len(value), ...)
value
,
field_name
,
expected_shape
,
if
value
and
isinstance
(
value
[
0
],
torch
.
Tensor
):
arg
.
dynamic_dims
)
actual_shape
=
self
.
_validate_nested_tensors
(
value
,
field_name
,
expected_shape
,
arg
.
dynamic_dims
)
elif
value
:
# list/tuple of scalars → shape = (len(value),)
actual_shape
=
(
len
(
value
),
)
else
:
raise
ValueError
(
f
"
{
field_name
}
is an empty list"
)
# Tensor → shape = tensor.shape
elif
isinstance
(
value
,
torch
.
Tensor
):
elif
isinstance
(
value
,
torch
.
Tensor
):
actual_shape
=
value
.
shape
actual_shape
=
value
.
shape
# Otherwise, it's an unsupported type
else
:
else
:
type_names
=
[]
type_names
=
[]
for
arg
in
args
:
for
arg
in
args
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment