Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3339cba3
Unverified
Commit
3339cba3
authored
Jul 26, 2025
by
Benji Beck
Committed by
GitHub
Jul 26, 2025
Browse files
Migrate FuyuImagePatchInputs to TensorSchema (#21662)
Signed-off-by:
Benji Beck
<
benjibeck@meta.com
>
parent
0b8caf90
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
57 additions
and
45 deletions
+57
-45
tests/standalone_tests/test_tensor_schema.py
tests/standalone_tests/test_tensor_schema.py
+22
-0
vllm/model_executor/models/fuyu.py
vllm/model_executor/models/fuyu.py
+19
-36
vllm/utils/tensor_schema.py
vllm/utils/tensor_schema.py
+16
-9
No files found.
tests/standalone_tests/test_tensor_schema.py
View file @
3339cba3
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
import
pytest
import
pytest
import
torch
import
torch
from
vllm.model_executor.models.fuyu
import
FuyuImagePatchInputs
from
vllm.model_executor.models.phi3v
import
Phi3VImagePixelInputs
from
vllm.model_executor.models.phi3v
import
Phi3VImagePixelInputs
...
@@ -124,3 +125,24 @@ def test_tensor_schema_with_invalid_resolve_binding_dims():
...
@@ -124,3 +125,24 @@ def test_tensor_schema_with_invalid_resolve_binding_dims():
"w"
:
336
"w"
:
336
},
},
)
)
def
test_tensor_schema_with_list_of_symbolic_dim
():
flat_data
=
torch
.
stack
([
torch
.
randn
(
768
)
for
_
in
range
(
3
)])
# (bn=3, fn)
patches_per_image
=
[
64
,
64
,
64
]
# len = bn = 3
FuyuImagePatchInputs
(
flat_data
=
flat_data
,
patches_per_image
=
patches_per_image
,
)
def
test_tensor_schema_with_list_of_symbolic_dim_mismatch_in_length
():
flat_data
=
torch
.
stack
([
torch
.
randn
(
768
)
for
_
in
range
(
4
)])
# (bn=4, fn)
patches_per_image
=
[
64
,
64
,
64
]
# len = 3 ≠ bn
with
pytest
.
raises
(
ValueError
,
match
=
"expected 'bn'=4, got 3"
):
FuyuImagePatchInputs
(
flat_data
=
flat_data
,
patches_per_image
=
patches_per_image
,
)
\ No newline at end of file
vllm/model_executor/models/fuyu.py
View file @
3339cba3
...
@@ -19,7 +19,7 @@
...
@@ -19,7 +19,7 @@
""" PyTorch Fuyu model."""
""" PyTorch Fuyu model."""
import
math
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Literal
,
Optional
,
TypedDict
from
typing
import
Annotated
,
Literal
,
Optional
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -40,6 +40,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
...
@@ -40,6 +40,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptUpdate
,
PromptUpdateDetails
)
PromptUpdate
,
PromptUpdateDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
flatten_bn
,
maybe_prefix
,
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
flatten_bn
,
maybe_prefix
,
...
@@ -50,15 +51,21 @@ _IMAGE_TOKEN_ID = 71011
...
@@ -50,15 +51,21 @@ _IMAGE_TOKEN_ID = 71011
_NEWLINE_TOKEN_ID
=
71019
_NEWLINE_TOKEN_ID
=
71019
class
FuyuImagePatchInputs
(
TypedDict
):
class
FuyuImagePatchInputs
(
TensorSchema
):
type
:
Literal
[
"image_patches"
]
flat_data
:
torch
.
Tensor
"""
"""
Shape:
Dimensions:
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
- bn: Batch size * number of images
- fn: Num channels * patch_size_x * patch_size_y
"""
"""
patches_per_image
:
list
[
int
]
type
:
Literal
[
"image_patches"
]
=
"image_patches"
flat_data
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bn"
,
"fn"
),
]
patches_per_image
:
Annotated
[
list
[
int
],
TensorShape
(
"bn"
)]
"""
"""
The number of total patches for each image in the batch.
The number of total patches for each image in the batch.
...
@@ -297,42 +304,18 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -297,42 +304,18 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
self
.
language_model
.
make_empty_intermediate_tensors
)
def
_validate_pixel_values
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
h
=
w
=
self
.
config
.
patch_size
num_channels
=
self
.
config
.
num_channels
expected_dims
=
num_channels
*
h
*
w
def
_validate_shape
(
d
:
torch
.
Tensor
):
actual_dims
=
d
.
size
(
-
1
)
if
actual_dims
!=
expected_dims
:
expected_expr
=
str
(
expected_dims
)
raise
ValueError
(
"The expected shape of pixel values per image per batch "
f
"per patch is
{
expected_expr
}
. "
f
"You supplied
{
tuple
(
d
.
shape
)
}
."
)
for
d
in
data
:
_validate_shape
(
d
)
return
data
.
to
(
self
.
vision_embed_tokens
.
weight
.
dtype
)
def
_parse_and_validate_image_input
(
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
FuyuImagePatchInputs
]:
self
,
**
kwargs
:
object
)
->
Optional
[
FuyuImagePatchInputs
]:
image_patches
=
kwargs
.
pop
(
"image_patches"
,
None
)
image_patches
=
kwargs
.
pop
(
"image_patches"
,
None
)
if
image_patches
is
not
None
:
if
image_patches
is
not
None
:
if
not
isinstance
(
image_patches
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of image patches. "
f
"Got type:
{
type
(
image_patches
)
}
"
)
image_patches_flat
=
flatten_bn
(
image_patches
)
image_patches_flat
=
flatten_bn
(
image_patches
)
flat_data
=
flatten_bn
(
image_patches
,
concat
=
True
).
data
.
to
(
self
.
vision_embed_tokens
.
weight
.
dtype
)
return
FuyuImagePatchInputs
(
return
FuyuImagePatchInputs
(
type
=
"image_patches"
,
type
=
"image_patches"
,
flat_data
=
self
.
_validate_pixel_values
(
flat_data
=
flat_data
,
flatten_bn
(
image_patches_flat
,
concat
=
True
)),
patches_per_image
=
[
x
.
size
(
0
)
for
x
in
image_patches_flat
],
patches_per_image
=
[
x
.
size
(
0
)
for
x
in
image_patches_flat
],
resolve_bindings
=
{
"fn"
:
self
.
image_feature_size
},
)
)
return
None
return
None
...
...
vllm/utils/tensor_schema.py
View file @
3339cba3
...
@@ -86,9 +86,6 @@ class TensorSchema:
...
@@ -86,9 +86,6 @@ class TensorSchema:
expected_shape
:
tuple
[
Union
[
int
,
str
],
...],
expected_shape
:
tuple
[
Union
[
int
,
str
],
...],
dynamic_dims
:
set
[
str
,
...])
->
tuple
[
int
,
...]:
dynamic_dims
:
set
[
str
,
...])
->
tuple
[
int
,
...]:
"""Validate a list/tuple of tensors and return the actual shape."""
"""Validate a list/tuple of tensors and return the actual shape."""
if
not
value
:
raise
ValueError
(
f
"
{
field_name
}
is an empty list"
)
# Ensure all tensors in the list have the same
# Ensure all tensors in the list have the same
# shape, besides dynamic dimensions
# shape, besides dynamic dimensions
first
=
value
[
0
]
first
=
value
[
0
]
...
@@ -117,6 +114,7 @@ class TensorSchema:
...
@@ -117,6 +114,7 @@ class TensorSchema:
int
],
int
],
dynamic_dims
:
set
[
str
,
...])
->
None
:
dynamic_dims
:
set
[
str
,
...])
->
None
:
"""Validate that the actual tensor shape matches the expected shape."""
"""Validate that the actual tensor shape matches the expected shape."""
if
len
(
actual_shape
)
!=
len
(
expected_shape
):
if
len
(
actual_shape
)
!=
len
(
expected_shape
):
raise
ValueError
(
f
"
{
field_name
}
has rank
{
len
(
actual_shape
)
}
"
raise
ValueError
(
f
"
{
field_name
}
has rank
{
len
(
actual_shape
)
}
"
f
"but expected
{
len
(
expected_shape
)
}
"
)
f
"but expected
{
len
(
expected_shape
)
}
"
)
...
@@ -160,12 +158,11 @@ class TensorSchema:
...
@@ -160,12 +158,11 @@ class TensorSchema:
# Skip validation when Union contains None
# Skip validation when Union contains None
if
type
(
None
)
in
args
:
if
type
(
None
)
in
args
:
continue
continue
#
If not optional
, raise error
#
Otherwise field is required
, raise error
raise
ValueError
(
f
"Required field '
{
field_name
}
' is missing"
)
raise
ValueError
(
f
"Required field '
{
field_name
}
' is missing"
)
# Field exists, proceed with validation
# Field exists, proceed with validation
value
=
getattr
(
self
,
field_name
)
value
=
getattr
(
self
,
field_name
)
if
get_origin
(
field_type
)
is
not
None
:
if
get_origin
(
field_type
)
is
not
None
:
args
=
get_args
(
field_type
)
args
=
get_args
(
field_type
)
...
@@ -173,13 +170,23 @@ class TensorSchema:
...
@@ -173,13 +170,23 @@ class TensorSchema:
if
isinstance
(
arg
,
TensorShape
):
if
isinstance
(
arg
,
TensorShape
):
expected_shape
=
arg
.
resolve
(
**
self
.
_resolve_bindings
)
expected_shape
=
arg
.
resolve
(
**
self
.
_resolve_bindings
)
if
isinstance
(
value
,
(
list
,
tuple
)):
if
isinstance
(
value
,
(
list
,
tuple
)):
# list/tuple of Tensors → shape = (len(value), ...)
if
value
and
isinstance
(
value
[
0
],
torch
.
Tensor
):
actual_shape
=
self
.
_validate_nested_tensors
(
actual_shape
=
self
.
_validate_nested_tensors
(
value
,
field_name
,
expected_shape
,
value
,
field_name
,
expected_shape
,
arg
.
dynamic_dims
)
arg
.
dynamic_dims
)
elif
value
:
# list/tuple of scalars → shape = (len(value),)
actual_shape
=
(
len
(
value
),
)
else
:
raise
ValueError
(
f
"
{
field_name
}
is an empty list"
)
# Tensor → shape = tensor.shape
elif
isinstance
(
value
,
torch
.
Tensor
):
elif
isinstance
(
value
,
torch
.
Tensor
):
actual_shape
=
value
.
shape
actual_shape
=
value
.
shape
# Otherwise, it's an unsupported type
else
:
else
:
type_names
=
[]
type_names
=
[]
for
arg
in
args
:
for
arg
in
args
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment