Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3ea57a56
Unverified
Commit
3ea57a56
authored
Jul 27, 2025
by
Benji Beck
Committed by
GitHub
Jul 27, 2025
Browse files
Migrate Idefics3ImagePixelInputs and Idefics3ImageEmbeddingInputs to … (#21683)
Signed-off-by:
Benji Beck
<
benjibeck@meta.com
>
parent
75856bc2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
28 additions
and
41 deletions
+28
-41
vllm/model_executor/models/idefics3.py
vllm/model_executor/models/idefics3.py
+28
-41
No files found.
vllm/model_executor/models/idefics3.py
View file @
3ea57a56
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
import
math
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Literal
,
Optional
,
TypedDict
,
Union
from
typing
import
Annotated
,
Literal
,
Optional
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -45,6 +45,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
...
@@ -45,6 +45,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
# yapf: enable
# yapf: enable
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
# yapf: disable
# yapf: disable
from
.idefics2_vision_model
import
(
from
.idefics2_vision_model
import
(
...
@@ -56,26 +57,30 @@ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
...
@@ -56,26 +57,30 @@ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings
)
merge_multimodal_embeddings
)
class
Idefics3ImagePixelInputs
(
TypedDict
):
class
Idefics3ImagePixelInputs
(
TensorSchema
):
type
:
Literal
[
"pixel_values"
]
pixel_values
:
torch
.
Tensor
"""
"""
Shape: `(batch_size * num_images * num_patches,
Dimensions:
num_channels, height, width)`
- bn: Batch size * number of images
- bnp: Batch size * number of images * number of patches
- c: Number of channels (3)
- h: Height
- w: Width
"""
"""
type
:
Literal
[
"pixel_values"
]
pixel_values
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bnp"
,
3
,
"h"
,
"w"
)]
pixel_attention_mask
:
torch
.
Tensor
pixel_attention_mask
:
torch
.
Tensor
num_patches
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bn"
)]
num_patches
:
torch
.
Tensor
"""Shape: `(batch_size * num_images)`"""
class
Idefics3ImageEmbeddingInputs
(
TypedDict
):
class
Idefics3ImageEmbeddingInputs
(
TensorSchema
):
type
:
Literal
[
"image_embeds"
]
data
:
torch
.
Tensor
"""
"""
Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
Dimensions:
`hidden_size` must match the hidden size of language model backbone.
- bn: Batch size * number of images
- f: Image feature size
- h: Hidden size (must match the hidden size of language model backbone)
"""
"""
type
:
Literal
[
"image_embeds"
]
data
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bn"
,
"f"
,
"h"
)]
ImageInputs
=
Union
[
Idefics3ImagePixelInputs
,
Idefics3ImageEmbeddingInputs
]
ImageInputs
=
Union
[
Idefics3ImagePixelInputs
,
Idefics3ImageEmbeddingInputs
]
...
@@ -614,25 +619,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -614,25 +619,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
self
.
lm_head
.
weight
=
self
.
model
.
text_model
.
wte
.
weight
self
.
lm_head
.
weight
=
self
.
model
.
text_model
.
wte
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
text_config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
text_config
.
vocab_size
)
def
_validate_pixel_values
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
h
=
w
=
self
.
config
.
vision_config
.
image_size
expected_dims
=
(
3
,
h
,
w
)
def
_validate_shape
(
d
:
torch
.
Tensor
):
actual_dims
=
tuple
(
d
.
shape
)
if
actual_dims
!=
expected_dims
:
expected_expr
=
str
(
expected_dims
)
raise
ValueError
(
"The expected shape of pixel values per image per batch "
f
" per patch is
{
expected_expr
}
. "
f
"You supplied
{
tuple
(
d
.
shape
)
}
."
)
for
d
in
data
:
_validate_shape
(
d
)
return
data
def
_parse_and_validate_image_input
(
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
ImageInputs
]:
self
,
**
kwargs
:
object
)
->
Optional
[
ImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
...
@@ -666,16 +652,17 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -666,16 +652,17 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
raise
ValueError
(
"Incorrect type of num_patches. "
raise
ValueError
(
"Incorrect type of num_patches. "
f
"Got type:
{
type
(
num_patches
)
}
"
)
f
"Got type:
{
type
(
num_patches
)
}
"
)
pixel_values
=
flatten_bn
(
pixel_values
,
concat
=
True
)
expected_h
=
expected_w
=
self
.
config
.
vision_config
.
image_size
pixel_attention_mask
=
flatten_bn
(
pixel_attention_mask
,
concat
=
True
)
num_patches
=
flatten_bn
(
num_patches
,
concat
=
True
)
return
Idefics3ImagePixelInputs
(
return
Idefics3ImagePixelInputs
(
type
=
"pixel_values"
,
type
=
"pixel_values"
,
pixel_values
=
self
.
_validate_pixel_values
(
pixel_values
),
pixel_values
=
flatten_bn
(
pixel_values
,
concat
=
True
),
pixel_attention_mask
=
pixel_attention_mask
,
pixel_attention_mask
=
flatten_bn
(
pixel_attention_mask
,
num_patches
=
num_patches
,
concat
=
True
),
num_patches
=
flatten_bn
(
num_patches
,
concat
=
True
),
resolve_bindings
=
{
"h"
:
expected_h
,
"w"
:
expected_w
},
)
)
raise
AssertionError
(
"This line should be unreachable."
)
raise
AssertionError
(
"This line should be unreachable."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment