Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f32a5bc5
Unverified
Commit
f32a5bc5
authored
Aug 28, 2025
by
Benji Beck
Committed by
GitHub
Aug 28, 2025
Browse files
Migrate Llama4ImagePatchInputs to TensorSchema (#22021)
Signed-off-by:
Benji Beck
<
benjibeck@meta.com
>
parent
8805ad9f
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
23 additions
and
18 deletions
+23
-18
vllm/model_executor/models/mllama4.py
vllm/model_executor/models/mllama4.py
+23
-18
No files found.
vllm/model_executor/models/mllama4.py
View file @
f32a5bc5
...
...
@@ -19,7 +19,7 @@
import
math
from
collections.abc
import
Iterable
,
Mapping
from
itertools
import
tee
from
typing
import
Literal
,
Optional
,
TypedDict
,
Union
from
typing
import
Annotated
,
Literal
,
Optional
,
Union
import
torch
from
torch
import
nn
...
...
@@ -53,6 +53,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.utils
import
run_dp_sharded_vision_model
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.llama4
import
Llama4ForCausalLM
...
...
@@ -60,14 +61,22 @@ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings
)
class
Llama4ImagePatchInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
flat_data
:
torch
.
Tensor
class
Llama4ImagePatchInputs
(
TensorSchema
):
"""
Shape:
`(batch_size * num_chunks, num_channels, image_size, image_size)`
Dimensions:
- batch_size: Batch size
- total_num_chunks: Batch size * number of chunks
- num_channels: Number of channels
- image_size: Size of each image
"""
patches_per_image
:
torch
.
Tensor
type
:
Literal
[
"pixel_values"
]
=
"pixel_values"
flat_data
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"total_num_chunks"
,
"num_channels"
,
"image_size"
,
"image_size"
)]
patches_per_image
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"batch_size"
)]
"""
The number of total patches for each image in the batch.
...
...
@@ -75,13 +84,11 @@ class Llama4ImagePatchInputs(TypedDict):
flattened just like `flat_data`.
"""
aspect_ratios
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]
]
aspect_ratios
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"batch_size"
,
2
)
]
"""
A list of aspect ratios corresponding to the number of tiles
in each dimension that each image in the batch corresponds to.
Shape:
`(batch_size, ratio)` where ratio is a pair `(ratio_h, ratio_w)`
Each aspect ratio is a pair (ratio_h, ratio_w).
"""
...
...
@@ -623,7 +630,7 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
for
(
r_h
,
r_w
)
in
aspect_ratios
]
processed_outputs
[
"aspect_ratios"
]
=
aspect_ratios
processed_outputs
[
"aspect_ratios"
]
=
torch
.
tensor
(
aspect_ratios
)
processed_outputs
[
"patches_per_image"
]
=
torch
.
tensor
(
patches_per_image
)
...
...
@@ -770,11 +777,9 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
# TODO: confirm handling for variable lengths
flat_pixel_values
=
flatten_bn
(
pixel_values
,
concat
=
True
)
patches_per_image
=
flatten_bn
(
kwargs
.
pop
(
"patches_per_image"
))
aspect_ratios
=
kwargs
.
pop
(
"aspect_ratios"
,
None
)
if
not
isinstance
(
aspect_ratios
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of aspect_ratios. "
f
"Got type:
{
type
(
aspect_ratios
)
}
"
)
aspect_ratios
=
kwargs
.
pop
(
"aspect_ratios"
)
if
aspect_ratios
.
ndim
==
3
:
aspect_ratios
=
aspect_ratios
.
squeeze
(
1
)
return
Llama4ImagePatchInputs
(
type
=
"pixel_values"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment