Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9d197280
Unverified
Commit
9d197280
authored
Jul 26, 2025
by
Benji Beck
Committed by
GitHub
Jul 26, 2025
Browse files
Migrate AriaImagePixelInputs to TensorSchema for shape validation (#21620)
Signed-off-by:
Benji Beck
<
benjibeck@meta.com
>
parent
e98def43
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
29 deletions
+21
-29
vllm/model_executor/models/aria.py
vllm/model_executor/models/aria.py
+21
-29
No files found.
vllm/model_executor/models/aria.py
View file @
9d197280
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Optional
,
TypedDict
,
Union
from
typing
import
Annotated
,
Optional
,
Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -29,6 +29,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
...
@@ -29,6 +29,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptUpdate
)
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
# yapf: disable
# yapf: disable
from
.idefics2_vision_model
import
Idefics2VisionConfig
from
.idefics2_vision_model
import
Idefics2VisionConfig
...
@@ -42,15 +43,26 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
...
@@ -42,15 +43,26 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
merge_multimodal_embeddings
)
merge_multimodal_embeddings
)
class
AriaImagePixelInputs
(
TypedDict
):
class
AriaImagePixelInputs
(
TensorSchema
):
pixel_values
:
torch
.
Tensor
pixel_mask
:
Optional
[
torch
.
Tensor
]
"""
"""
Shape:
Dimensions:
pixel_values: `(batch_size * num_images, num_channels, height, width)`
- b: Batch size
pixel_mask: `(batch_size * num_images, height, width)`
- n: Number of images
- c: Number of channels
- h: Height of each image
- w: Width of each image
"""
"""
pixel_values
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bn"
,
3
,
"h"
,
"w"
),
]
pixel_mask
:
Annotated
[
Optional
[
torch
.
Tensor
],
TensorShape
(
"bn"
,
"h"
,
"w"
),
]
class
AriaVisionTransformer
(
Idefics3VisionTransformer
,
SupportsQuant
):
class
AriaVisionTransformer
(
Idefics3VisionTransformer
,
SupportsQuant
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
]}
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
]}
...
@@ -540,12 +552,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -540,12 +552,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
self
.
vocab_size
,
logit_scale
)
self
.
vocab_size
,
logit_scale
)
def
_validate_image_sizes
(
self
,
images
:
list
[
torch
.
Tensor
])
->
list
[
torch
.
Tensor
]:
if
not
all
(
img
.
shape
==
images
[
0
].
shape
for
img
in
images
):
raise
ValueError
(
"All images must be the same size"
)
return
images
def
_parse_and_validate_image_input
(
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
AriaImagePixelInputs
]:
self
,
**
kwargs
:
object
)
->
Optional
[
AriaImagePixelInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
...
@@ -554,23 +560,9 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -554,23 +560,9 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
if
pixel_values
is
None
:
if
pixel_values
is
None
:
return
None
return
None
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
pixel_values
=
self
.
_validate_image_sizes
(
pixel_values
)
pixel_values
=
flatten_bn
(
pixel_values
,
concat
=
True
)
if
pixel_mask
is
not
None
:
if
not
isinstance
(
pixel_mask
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of pixel mask. "
f
"Got type:
{
type
(
pixel_mask
)
}
"
)
pixel_mask
=
flatten_bn
(
pixel_mask
,
concat
=
True
)
return
AriaImagePixelInputs
(
return
AriaImagePixelInputs
(
pixel_values
=
pixel_values
,
pixel_values
=
flatten_bn
(
pixel_values
,
concat
=
True
),
pixel_mask
=
pixel_mask
,
pixel_mask
=
flatten_bn
(
pixel_mask
,
concat
=
True
)
,
)
)
def
_create_patch_attention_mask
(
def
_create_patch_attention_mask
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment