Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d8937de4
Unverified
Commit
d8937de4
authored
Jul 27, 2025
by
Benji Beck
Committed by
GitHub
Jul 27, 2025
Browse files
Migrate Gemma3ImagePixelInputs to TensorSchema (#21676)
Signed-off-by:
Benji Beck
<
benjibeck@meta.com
>
parent
e626d286
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
25 deletions
+21
-25
vllm/model_executor/models/gemma3_mm.py
vllm/model_executor/models/gemma3_mm.py
+21
-25
No files found.
vllm/model_executor/models/gemma3_mm.py
View file @
d8937de4
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Any
,
Literal
,
Optional
,
TypedDict
from
typing
import
Annotated
,
Any
,
Literal
,
Optional
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -31,6 +31,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
...
@@ -31,6 +31,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
# yapf: enable
# yapf: enable
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
)
SupportsMultiModal
,
SupportsPP
)
...
@@ -42,18 +43,21 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
...
@@ -42,18 +43,21 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class
Gemma3ImagePixelInputs
(
TypedDict
):
class
Gemma3ImagePixelInputs
(
TensorSchema
):
type
:
Literal
[
"pixel_values"
]
pixel_values
:
torch
.
Tensor
"""
"""
Shape: `(num_patches_total, num_channels, height, width)`
Dimensions:
- p: Number of patches total (over each image over each prompt in the
`num_patches_total` is the total number of patches
batch)
over each image over each prompt in the batch.
- c: Number of channels (3)
- h: Height of each patch
- w: Width of each patch
- bn: Batch size * number of images
"""
"""
type
:
Literal
[
"pixel_values"
]
=
"pixel_values"
pixel_values
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"p"
,
3
,
"h"
,
"w"
)]
num_patches
:
torch
.
Tensor
num_patches
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bn"
)]
"""Shape: `(batch_size * num_images)`"""
Gemma3ImageInputs
=
Gemma3ImagePixelInputs
Gemma3ImageInputs
=
Gemma3ImagePixelInputs
...
@@ -523,15 +527,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -523,15 +527,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
def
dtype
(
self
):
def
dtype
(
self
):
return
next
(
self
.
parameters
()).
dtype
return
next
(
self
.
parameters
()).
dtype
def
_validate_pixel_values
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
image_size
=
self
.
config
.
vision_config
.
image_size
expected_dims
=
(
3
,
image_size
,
image_size
)
if
data
.
shape
[
1
:]
!=
expected_dims
:
raise
ValueError
(
"The expected shape of pixel values per image per batch is "
f
"
{
expected_dims
}
. You supplied
{
tuple
(
data
.
shape
)
}
."
)
return
data
def
_parse_and_validate_image_input
(
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
Gemma3ImageInputs
]:
self
,
**
kwargs
:
object
)
->
Optional
[
Gemma3ImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
...
@@ -549,14 +544,15 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -549,14 +544,15 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
raise
ValueError
(
"Incorrect type of num_crops. "
raise
ValueError
(
"Incorrect type of num_crops. "
f
"Got type:
{
type
(
num_crops
)
}
"
)
f
"Got type:
{
type
(
num_crops
)
}
"
)
pixel_values
=
flatten_bn
(
pixel_values
,
concat
=
True
)
image_size
=
self
.
config
.
vision_config
.
image_size
num_crops
=
flatten_bn
(
num_crops
,
concat
=
True
)
return
Gemma3ImagePixelInputs
(
return
Gemma3ImagePixelInputs
(
type
=
"pixel_values"
,
pixel_values
=
flatten_bn
(
pixel_values
,
concat
=
True
),
pixel_values
=
self
.
_validate_pixel_values
(
pixel_values
),
num_patches
=
flatten_bn
(
num_crops
,
concat
=
True
)
+
1
,
num_patches
=
num_crops
+
1
,
resolve_bindings
=
{
)
"h"
:
image_size
,
"w"
:
image_size
})
def
_image_pixels_to_features
(
def
_image_pixels_to_features
(
self
,
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment