Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e6c9053f
Unverified
Commit
e6c9053f
authored
Mar 27, 2025
by
Cyrus Leung
Committed by
GitHub
Mar 27, 2025
Browse files
[Misc] Clean up `scatter_patch_features` (#15559)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
43ed4143
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
84 additions
and
138 deletions
+84
-138
vllm/model_executor/models/gemma3_mm.py
vllm/model_executor/models/gemma3_mm.py
+6
-11
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+9
-12
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+10
-12
vllm/model_executor/models/molmo.py
vllm/model_executor/models/molmo.py
+32
-73
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+7
-11
vllm/model_executor/models/vision.py
vllm/model_executor/models/vision.py
+20
-19
No files found.
vllm/model_executor/models/gemma3_mm.py
View file @
e6c9053f
...
@@ -30,7 +30,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
...
@@ -30,7 +30,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
# yapf: enable
# yapf: enable
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
flatten_2d_lists
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
)
SupportsMultiModal
,
SupportsPP
)
...
@@ -60,7 +59,7 @@ class Gemma3ImagePixelInputs(TypedDict):
...
@@ -60,7 +59,7 @@ class Gemma3ImagePixelInputs(TypedDict):
A boolean mask indicating which image embeddings correspond
A boolean mask indicating which image embeddings correspond
to patch tokens.
to patch tokens.
Shape: `(batch_size
,
num_images, num_embeds)`
Shape: `(batch_size
*
num_images, num_embeds)`
"""
"""
...
@@ -593,6 +592,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -593,6 +592,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
pixel_values
=
flatten_bn
(
pixel_values
,
concat
=
True
)
pixel_values
=
flatten_bn
(
pixel_values
,
concat
=
True
)
num_crops
=
flatten_bn
(
num_crops
,
concat
=
True
)
num_crops
=
flatten_bn
(
num_crops
,
concat
=
True
)
embed_is_patch
=
flatten_bn
(
embed_is_patch
)
return
Gemma3ImagePixelInputs
(
return
Gemma3ImagePixelInputs
(
type
=
"pixel_values"
,
type
=
"pixel_values"
,
...
@@ -635,14 +635,10 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -635,14 +635,10 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
image_features
=
self
.
_process_image_input
(
image_input
)
image_features
=
self
.
_process_image_input
(
image_input
)
if
kwargs
.
get
(
"v0_path"
,
False
):
return
scatter_patch_features
(
return
image_features
image_features
,
image_input
[
"embed_is_patch"
],
return
flatten_2d_lists
(
)
scatter_patch_features
(
*
args
)
for
args
in
zip
(
image_features
,
image_input
[
"embed_is_patch"
],
))
def
get_input_embeddings
(
def
get_input_embeddings
(
self
,
self
,
...
@@ -671,7 +667,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -671,7 +667,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
# condition is for v0 compatibility.
elif
inputs_embeds
is
None
:
elif
inputs_embeds
is
None
:
kwargs
.
update
({
"v0_path"
:
True
})
vision_embeddings
=
self
.
get_multimodal_embeddings
(
**
kwargs
)
vision_embeddings
=
self
.
get_multimodal_embeddings
(
**
kwargs
)
inputs_embeds
=
self
.
get_input_embeddings
(
input_ids
,
inputs_embeds
=
self
.
get_input_embeddings
(
input_ids
,
...
...
vllm/model_executor/models/internvl.py
View file @
e6c9053f
...
@@ -35,7 +35,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
...
@@ -35,7 +35,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
flatten_2d_lists
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
...
@@ -66,13 +65,13 @@ class InternVLImagePixelInputs(TypedDict):
...
@@ -66,13 +65,13 @@ class InternVLImagePixelInputs(TypedDict):
A boolean mask indicating which image embeddings correspond
A boolean mask indicating which image embeddings correspond
to patch tokens.
to patch tokens.
Shape: `(batch_size
,
num_images, num_embeds)`
Shape: `(batch_size
*
num_images, num_embeds)`
"""
"""
class
InternVLImageEmbeddingInputs
(
TypedDict
):
class
InternVLImageEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"image_embeds"
]
type
:
Literal
[
"image_embeds"
]
data
:
Nested
Tensor
s
data
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
"""
A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
or a list of tensors of shape `(total_image_feature_size, hidden_size)`
or a list of tensors of shape `(total_image_feature_size, hidden_size)`
...
@@ -867,6 +866,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -867,6 +866,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
pixel_values_flat
=
flatten_bn
(
pixel_values_flat
,
concat
=
True
)
pixel_values_flat
=
flatten_bn
(
pixel_values_flat
,
concat
=
True
)
image_num_patches
=
flatten_bn
(
image_num_patches
,
concat
=
True
)
image_num_patches
=
flatten_bn
(
image_num_patches
,
concat
=
True
)
embed_is_patch
=
flatten_bn
(
embed_is_patch
)
return
InternVLImagePixelInputs
(
return
InternVLImagePixelInputs
(
type
=
"pixel_values"
,
type
=
"pixel_values"
,
...
@@ -881,7 +881,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -881,7 +881,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def
_process_image_input
(
def
_process_image_input
(
self
,
self
,
image_input
:
InternVLImageInputs
,
image_input
:
InternVLImageInputs
,
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
...]]:
)
->
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
],
tuple
[
torch
.
Tensor
,
...]]:
if
image_input
[
"type"
]
==
"image_embeds"
:
if
image_input
[
"type"
]
==
"image_embeds"
:
return
image_input
[
"data"
]
return
image_input
[
"data"
]
...
@@ -921,15 +921,13 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -921,15 +921,13 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
image_features
=
self
.
_process_image_input
(
image_input
)
image_features
=
self
.
_process_image_input
(
image_input
)
if
(
kwargs
.
get
(
"v0_path"
,
False
)
if
image_input
[
"type"
]
!=
"pixel_values"
:
or
image_input
[
"type"
]
!=
"pixel_values"
):
return
image_features
return
image_features
return
flatten_2d_lists
(
return
scatter_patch_features
(
scatter_patch_features
(
*
args
)
for
args
in
zip
(
image_features
,
image_features
,
image_input
[
"embed_is_patch"
],
image_input
[
"embed_is_patch"
],
)
))
def
get_input_embeddings
(
def
get_input_embeddings
(
self
,
self
,
...
@@ -964,7 +962,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -964,7 +962,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
# condition is for v0 compatibility.
elif
inputs_embeds
is
None
:
elif
inputs_embeds
is
None
:
kwargs
.
update
({
"v0_path"
:
True
})
vision_embeddings
=
self
.
get_multimodal_embeddings
(
**
kwargs
)
vision_embeddings
=
self
.
get_multimodal_embeddings
(
**
kwargs
)
inputs_embeds
=
self
.
get_input_embeddings
(
input_ids
,
inputs_embeds
=
self
.
get_input_embeddings
(
input_ids
,
vision_embeddings
)
vision_embeddings
)
...
...
vllm/model_executor/models/llava.py
View file @
e6c9053f
...
@@ -35,7 +35,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
...
@@ -35,7 +35,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptReplacement
,
PromptUpdate
)
PromptReplacement
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
flatten_2d_lists
from
.clip
import
CLIPVisionModel
from
.clip
import
CLIPVisionModel
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
...
@@ -73,7 +72,7 @@ class PixtralHFImagePixelInputs(TypedDict):
...
@@ -73,7 +72,7 @@ class PixtralHFImagePixelInputs(TypedDict):
A boolean mask indicating which image embeddings correspond
A boolean mask indicating which image embeddings correspond
to patch tokens.
to patch tokens.
Shape: `(batch_size
,
num_images, num_embeds)`
Shape: `(batch_size
*
num_images, num_embeds)`
"""
"""
...
@@ -618,6 +617,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -618,6 +617,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
raise
ValueError
(
"Incorrect type of embed_is_patch. "
raise
ValueError
(
"Incorrect type of embed_is_patch. "
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
embed_is_patch
=
flatten_bn
(
embed_is_patch
)
return
PixtralHFImagePixelInputs
(
return
PixtralHFImagePixelInputs
(
type
=
"pixel_values_pixtral"
,
type
=
"pixel_values_pixtral"
,
pixel_values
=
flatten_bn
(
pixel_values
),
pixel_values
=
flatten_bn
(
pixel_values
),
...
@@ -713,18 +714,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -713,18 +714,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
None
vision_embedding
s
=
self
.
_process_image_input
(
image_input
)
image_feature
s
=
self
.
_process_image_input
(
image_input
)
if
(
kwargs
.
get
(
"v0_path"
,
False
)
if
image_input
[
"type"
]
!=
"pixel_values_pixtral"
:
or
image_input
[
"type"
]
!=
"pixel_values_pixtral"
):
# The path is used for pixtral (V0 only) and llava (V0/V1)
# The path is used for pixtral (V0 only) and llava (V0/V1)
return
vision_embedding
s
return
image_feature
s
return
flatten_2d_lists
(
return
scatter_patch_features
(
scatter_patch_features
(
*
args
)
for
args
in
zip
(
image_features
,
vision_embeddings
,
image_input
[
"embed_is_patch"
],
image_input
[
"embed_is_patch"
],
)
))
def
get_input_embeddings
(
def
get_input_embeddings
(
self
,
self
,
...
@@ -790,7 +789,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -790,7 +789,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
# condition is for v0 compatibility.
elif
inputs_embeds
is
None
:
elif
inputs_embeds
is
None
:
kwargs
.
update
({
"v0_path"
:
True
})
vision_embeddings
=
self
.
get_multimodal_embeddings
(
**
kwargs
)
vision_embeddings
=
self
.
get_multimodal_embeddings
(
**
kwargs
)
inputs_embeds
=
self
.
get_input_embeddings
(
input_ids
,
inputs_embeds
=
self
.
get_input_embeddings
(
input_ids
,
vision_embeddings
)
vision_embeddings
)
...
...
vllm/model_executor/models/molmo.py
View file @
e6c9053f
...
@@ -49,7 +49,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
...
@@ -49,7 +49,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptInsertion
,
PromptUpdate
)
PromptInsertion
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
flatten_2d_lists
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
,
SupportsQuant
)
SupportsMultiModal
,
SupportsPP
,
SupportsQuant
)
...
@@ -72,17 +71,17 @@ POOLING_SIZE = 2
...
@@ -72,17 +71,17 @@ POOLING_SIZE = 2
class
MolmoImageInputs
(
TypedDict
):
class
MolmoImageInputs
(
TypedDict
):
images
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
images
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""Shape: `(batch_size, num_crops, num_patch, patch_dim)`"""
"""Shape: `(batch_size
* num_images
, num_crops, num_patch, patch_dim)`"""
image_masks
:
Optional
[
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]]
image_masks
:
Optional
[
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]]
"""Shape: `(batch_size, num_crops, num_patch)`"""
"""Shape: `(batch_size
* num_images
, num_crops, num_patch)`"""
feat_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
feat_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
"""
A boolean mask indicating which image features correspond
A boolean mask indicating which image features correspond
to patch tokens.
to patch tokens.
Shape: `(batch_size, num_crops, num_patch)`
Shape: `(batch_size
* num_images
, num_crops, num_patch)`
"""
"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
...
@@ -90,7 +89,7 @@ class MolmoImageInputs(TypedDict):
...
@@ -90,7 +89,7 @@ class MolmoImageInputs(TypedDict):
A boolean mask indicating which image embeddings correspond
A boolean mask indicating which image embeddings correspond
to patch tokens.
to patch tokens.
Shape: `(batch_size, num_embeds)`
Shape: `(batch_size
* num_images
, num_embeds)`
"""
"""
num_crops
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
num_crops
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
...
@@ -696,9 +695,10 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
...
@@ -696,9 +695,10 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
return
image_features
return
image_features
def
forward
(
def
forward
(
self
,
images
:
torch
.
Tensor
,
image_masks
:
torch
.
Tensor
self
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
images
:
torch
.
Tensor
,
image_masks
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
# image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim) # noqa: E501
# image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim) # noqa: E501
batch_size
,
num_image
=
images
.
shape
[:
2
]
batch_size
,
num_image
=
images
.
shape
[:
2
]
images
=
images
.
to
(
device
=
self
.
device
,
dtype
=
self
.
dtype
)
images
=
images
.
to
(
device
=
self
.
device
,
dtype
=
self
.
dtype
)
...
@@ -1491,6 +1491,8 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
...
@@ -1491,6 +1491,8 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
f
"Got type:
{
type
(
img_patch_id
)
}
"
)
f
"Got type:
{
type
(
img_patch_id
)
}
"
)
self
.
img_patch_id
=
img_patch_id
.
flatten
().
unique
().
item
()
self
.
img_patch_id
=
img_patch_id
.
flatten
().
unique
().
item
()
embed_is_patch
=
flatten_bn
(
embed_is_patch
)
return
MolmoImageInputs
(
return
MolmoImageInputs
(
images
=
images
,
images
=
images
,
image_masks
=
image_masks
,
image_masks
=
image_masks
,
...
@@ -1502,13 +1504,17 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
...
@@ -1502,13 +1504,17 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
def
_process_image_input
(
def
_process_image_input
(
self
,
self
,
image_input
:
MolmoImageInputs
,
image_input
:
MolmoImageInputs
,
)
->
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]:
)
->
list
[
torch
.
Tensor
]:
if
isinstance
(
image_input
[
"images"
],
list
):
images
=
image_input
[
"images"
]
image_masks
=
image_input
[
"image_masks"
]
feat_is_patch
=
image_input
[
"feat_is_patch"
]
num_crops
=
image_input
[
"num_crops"
]
if
isinstance
(
images
,
list
):
# Call the vision backbone on the whole batch at once
# Call the vision backbone on the whole batch at once
images_flat
=
flatten_bn
(
image_input
[
"images"
],
concat
=
True
)
images_flat
=
flatten_bn
(
images
,
concat
=
True
)
image_masks_flat
=
(
None
if
(
image_masks
:
=
image_masks_flat
=
(
None
if
image_masks
is
None
else
flatten_bn
(
image_input
[
"image_masks"
])
is
None
image_masks
,
concat
=
True
))
else
flatten_bn
(
image_masks
,
concat
=
True
))
image_features_flat
=
self
.
vision_backbone
(
image_features_flat
=
self
.
vision_backbone
(
images
=
images_flat
.
unsqueeze
(
0
),
images
=
images_flat
.
unsqueeze
(
0
),
...
@@ -1517,63 +1523,19 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
...
@@ -1517,63 +1523,19 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
).
squeeze
(
0
)
).
squeeze
(
0
)
# Reconstruct the batch dimension
# Reconstruct the batch dimension
image_features
=
image_features_flat
.
split
(
num_crops_per_image
=
[
nc
.
sum
().
item
()
for
nc
in
num_crops
]
image_input
[
"num_crops"
].
sum
(
-
1
).
tolist
()
)
image_features
=
image_features_flat
.
split
(
num_crops_per_image
)
else
:
else
:
image_features
=
self
.
vision_backbone
(
image_features
=
self
.
vision_backbone
(
images
=
image
_input
[
"images"
]
,
images
=
image
s
,
image_masks
=
image_
input
[
"image_
masks
"
]
,
image_masks
=
image_masks
,
)
)
return
image_features
# Only the features corresponding to patch tokens are relevant
return
[
def
_get_mm_embeds
(
self
,
features
:
torch
.
Tensor
,
# Shape: (num_crop, num_patch, d)
feat_is_patch
:
torch
.
Tensor
,
# Shape: (num_crop, num_patch)
num_crops
:
torch
.
Tensor
,
# Shape: (num_images,)
embed_is_patch
:
torch
.
Tensor
,
# Shape: (num_embeds,)
)
->
tuple
[
torch
.
Tensor
,
...]:
"""
Scatter the patch features into a contiguous tensor that corresponds
to the embedding tokens defined by the multimodal processor.
Note:
The original code only considers patch tokens as feature
tokens, but our processor considers all image-related tokens
as feature tokens because the feature tokens need to be
consecutive in `input_ids`.
Example:
A simplified example for one item in the batch:
.. code-block::
Embedding tokens (from HF processor):
[<start> <patch> <patch> <col> <patch> <patch> <col> <end> ]
embed_is_patch (from HF processor):
[ False True True False True True False False ]
Encoder outputs (from model):
[ p1 p2 0 p3 p4 0 ]
feat_is_patch (from HF processor):
[ True True False True True False ]
The resulting embedding tensor is:
[ nan p1 p2 nan p3 p4 nan nan ]
"""
num_crops_per_image
=
num_crops
.
tolist
()
feats_per_image
=
features
.
split
(
num_crops_per_image
)
f_is_patch_per_image
=
feat_is_patch
.
split
(
num_crops_per_image
)
features
=
torch
.
cat
([
feats
[
f_is_patch
]
feats
[
f_is_patch
]
for
feats
,
f_is_patch
in
zip
(
feats_per_image
,
f_is_patch_per_image
)
for
feats
,
f_is_patch
in
zip
(
image_features
,
feat_is_patch
)
])
]
return
scatter_patch_features
(
features
,
embed_is_patch
)
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
...
@@ -1583,13 +1545,10 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
...
@@ -1583,13 +1545,10 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
image_features
=
self
.
_process_image_input
(
image_input
)
image_features
=
self
.
_process_image_input
(
image_input
)
return
flatten_2d_lists
(
return
scatter_patch_features
(
self
.
_get_mm_embeds
(
*
args
)
for
args
in
zip
(
image_features
,
image_features
,
image_input
[
"embed_is_patch"
],
image_input
[
"feat_is_patch"
],
)
image_input
[
"num_crops"
],
image_input
[
"embed_is_patch"
],
))
def
get_input_embeddings
(
def
get_input_embeddings
(
self
,
self
,
...
...
vllm/model_executor/models/pixtral.py
View file @
e6c9053f
...
@@ -42,7 +42,6 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
...
@@ -42,7 +42,6 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.tokenizer
import
(
MistralTokenizer
,
from
vllm.transformers_utils.tokenizer
import
(
MistralTokenizer
,
cached_tokenizer_from_config
)
cached_tokenizer_from_config
)
from
vllm.utils
import
flatten_2d_lists
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
from
.utils
import
(
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
...
@@ -74,7 +73,7 @@ class PixtralImagePixelInputs(TypedDict):
...
@@ -74,7 +73,7 @@ class PixtralImagePixelInputs(TypedDict):
A boolean mask indicating which image embeddings correspond
A boolean mask indicating which image embeddings correspond
to patch tokens.
to patch tokens.
Shape: `(batch_size
,
num_images, num_embeds)`
Shape: `(batch_size
*
num_images, num_embeds)`
"""
"""
...
@@ -387,6 +386,8 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -387,6 +386,8 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
raise
ValueError
(
"Incorrect type of embed_is_patch. "
raise
ValueError
(
"Incorrect type of embed_is_patch. "
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
embed_is_patch
=
flatten_bn
(
embed_is_patch
)
return
PixtralImagePixelInputs
(
return
PixtralImagePixelInputs
(
type
=
"pixel_values"
,
type
=
"pixel_values"
,
images
=
flatten_bn
(
images
),
images
=
flatten_bn
(
images
),
...
@@ -428,14 +429,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -428,14 +429,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
image_features
=
self
.
_process_image_input
(
image_input
)
image_features
=
self
.
_process_image_input
(
image_input
)
if
kwargs
.
get
(
"v0_path"
,
False
):
return
scatter_patch_features
(
return
image_features
image_features
,
image_input
[
"embed_is_patch"
],
return
flatten_2d_lists
(
)
scatter_patch_features
(
*
args
)
for
args
in
zip
(
image_features
,
image_input
[
"embed_is_patch"
],
))
def
get_input_embeddings
(
def
get_input_embeddings
(
self
,
self
,
...
@@ -467,7 +464,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -467,7 +464,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
# condition is for v0 compatibility.
elif
inputs_embeds
is
None
:
elif
inputs_embeds
is
None
:
kwargs
.
update
({
"v0_path"
:
True
})
vision_embeddings
=
self
.
get_multimodal_embeddings
(
**
kwargs
)
vision_embeddings
=
self
.
get_multimodal_embeddings
(
**
kwargs
)
inputs_embeds
=
self
.
get_input_embeddings
(
input_ids
,
inputs_embeds
=
self
.
get_input_embeddings
(
input_ids
,
vision_embeddings
)
vision_embeddings
)
...
...
vllm/model_executor/models/vision.py
View file @
e6c9053f
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Sequence
from
typing
import
Final
,
Generic
,
Optional
,
Protocol
,
TypeVar
,
Union
,
cast
from
typing
import
Final
,
Generic
,
Optional
,
Protocol
,
TypeVar
,
Union
,
cast
import
torch
import
torch
...
@@ -154,8 +155,8 @@ def resolve_visual_encoder_outputs(
...
@@ -154,8 +155,8 @@ def resolve_visual_encoder_outputs(
def
scatter_patch_features
(
def
scatter_patch_features
(
features
:
torch
.
Tensor
,
patches
:
Union
[
torch
.
Tensor
,
Sequence
[
torch
.
Tensor
]]
,
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
embed_is_patch
:
Union
[
torch
.
Tensor
,
Sequence
[
torch
.
Tensor
]],
)
->
tuple
[
torch
.
Tensor
,
...]:
)
->
tuple
[
torch
.
Tensor
,
...]:
"""
"""
Scatter the patch features into a contiguous tensor that corresponds
Scatter the patch features into a contiguous tensor that corresponds
...
@@ -165,8 +166,8 @@ def scatter_patch_features(
...
@@ -165,8 +166,8 @@ def scatter_patch_features(
can be filtered out by :func`select_patch_features`.
can be filtered out by :func`select_patch_features`.
Args:
Args:
featur
es: The patch features
, concatenated across
each image.
patch
es: The patch features
for
each image.
Shape: `(num_
patch
, feature_depth)`
Shape: `(num_
images, <patch_dims>
, feature_depth)`
embed_is_patch: A boolean mask indicating which image embeddings
embed_is_patch: A boolean mask indicating which image embeddings
correspond to patch tokens for each image.
correspond to patch tokens for each image.
Shape: `(num_images, num_embeds)`
Shape: `(num_images, num_embeds)`
...
@@ -194,21 +195,21 @@ def scatter_patch_features(
...
@@ -194,21 +195,21 @@ def scatter_patch_features(
The resulting embedding tensor is:
The resulting embedding tensor is:
[ nan p1 p2 nan p3 p4 nan nan ]
[ nan p1 p2 nan p3 p4 nan nan ]
"""
"""
num_
embeds_p
er_image
=
[
if
len
(
patches
)
!=
len
(
embed
_i
s_p
atch
):
e_is_patch
.
numel
()
for
e_is_patch
in
embed_is_patch
raise
ValueError
(
f
"Inconsistent num_images:
{
len
(
patches
)
=
}
vs. "
]
f
"
{
len
(
embed_is_patch
)
=
}
"
)
if
isinstance
(
embed_is_patch
,
torch
.
Tensor
):
embed_
is_
patch
_flat
=
embed_is_patch
.
view
(
-
1
)
def
get_
embed_
one
(
patch
es_one
:
torch
.
Tensor
,
e_is_patch
:
torch
.
Tensor
):
e
lse
:
e
mbed_one
=
patches_one
.
new_full
(
embed
_is_patch
_flat
=
torch
.
cat
(
embed_is_patch
)
(
e
_is_patch
.
shape
[
0
],
patches_one
.
shape
[
-
1
]),
fill_value
=
torch
.
nan
,
embeds_flat
=
features
.
new_full
(
)
(
sum
(
num_embeds_per_image
),
features
.
shape
[
-
1
]),
embed_one
[
e_is_patch
]
=
patches_one
.
flatten
(
0
,
-
2
)
fill_value
=
torch
.
nan
,
return
embed_one
)
embeds_flat
[
embed_is_patch_flat
]
=
features
.
flatten
(
0
,
-
2
)
return
tuple
(
get_embed_one
(
patches_one
,
e_is_patch
)
return
embeds_flat
.
split
(
num_
embeds_p
er_image
)
for
patches_one
,
e_is_patch
in
zip
(
patches
,
embed
_i
s_p
atch
)
)
def
select_patch_features
(
def
select_patch_features
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment