Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
997c8811
Unverified
Commit
997c8811
authored
Mar 26, 2025
by
Cyrus Leung
Committed by
GitHub
Mar 26, 2025
Browse files
[Model] Support multi-image for Molmo (#15438)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
e42389f9
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
39 additions
and
35 deletions
+39
-35
docs/source/models/supported_models.md
docs/source/models/supported_models.md
+1
-1
tests/models/decoder_only/vision_language/test_models.py
tests/models/decoder_only/vision_language/test_models.py
+1
-1
vllm/model_executor/models/molmo.py
vllm/model_executor/models/molmo.py
+28
-29
vllm/model_executor/models/vision.py
vllm/model_executor/models/vision.py
+9
-4
No files found.
docs/source/models/supported_models.md
View file @
997c8811
...
@@ -853,7 +853,7 @@ See [this page](#generative-models) for more information on how to use generativ
...
@@ -853,7 +853,7 @@ See [this page](#generative-models) for more information on how to use generativ
*
*
-
*
`MolmoForCausalLM`
-
*
`MolmoForCausalLM`
*
Molmo
*
Molmo
*
T + I
*
T + I
<sup>
+
</sup>
*
`allenai/Molmo-7B-D-0924`
,
`allenai/Molmo-7B-O-0924`
, etc.
*
`allenai/Molmo-7B-D-0924`
,
`allenai/Molmo-7B-O-0924`
, etc.
*
✅︎
*
✅︎
*
✅︎
*
✅︎
...
...
tests/models/decoder_only/vision_language/test_models.py
View file @
997c8811
...
@@ -431,7 +431,7 @@ VLM_TEST_SETTINGS = {
...
@@ -431,7 +431,7 @@ VLM_TEST_SETTINGS = {
),
),
"molmo"
:
VLMTestInfo
(
"molmo"
:
VLMTestInfo
(
models
=
[
"allenai/Molmo-7B-D-0924"
],
models
=
[
"allenai/Molmo-7B-D-0924"
],
test_type
=
(
VLMTestType
.
IMAGE
),
test_type
=
(
VLMTestType
.
IMAGE
,
VLMTestType
.
MULTI_IMAGE
),
prompt_formatter
=
identity
,
prompt_formatter
=
identity
,
max_model_len
=
4096
,
max_model_len
=
4096
,
max_num_seqs
=
2
,
max_num_seqs
=
2
,
...
...
vllm/model_executor/models/molmo.py
View file @
997c8811
...
@@ -57,7 +57,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
...
@@ -57,7 +57,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
is_pp_missing_parameter
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
,
merge_multimodal_embeddings
)
maybe_prefix
,
merge_multimodal_embeddings
)
from
.vision
import
select_patch_features
from
.vision
import
scatter_patch_features
,
select_patch_features
# TODO: hard-coded for now. Consider making it configurable.
# TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS
=
[
-
2
,
-
9
]
VIT_LAYERS
=
[
-
2
,
-
9
]
...
@@ -71,13 +71,13 @@ POOLING_SIZE = 2
...
@@ -71,13 +71,13 @@ POOLING_SIZE = 2
class
MolmoImageInputs
(
TypedDict
):
class
MolmoImageInputs
(
TypedDict
):
images
:
Union
[
torch
.
Tensor
,
L
ist
[
torch
.
Tensor
]]
images
:
Union
[
torch
.
Tensor
,
l
ist
[
torch
.
Tensor
]]
"""Shape: `(batch_size, num_crops, num_patch, patch_dim)`"""
"""Shape: `(batch_size, num_crops, num_patch, patch_dim)`"""
image_masks
:
Optional
[
Union
[
torch
.
Tensor
,
L
ist
[
torch
.
Tensor
]]]
image_masks
:
Optional
[
Union
[
torch
.
Tensor
,
l
ist
[
torch
.
Tensor
]]]
"""Shape: `(batch_size, num_crops, num_patch)`"""
"""Shape: `(batch_size, num_crops, num_patch)`"""
feat_is_patch
:
Union
[
torch
.
Tensor
,
L
ist
[
torch
.
Tensor
]]
feat_is_patch
:
Union
[
torch
.
Tensor
,
l
ist
[
torch
.
Tensor
]]
"""
"""
A boolean mask indicating which image features correspond
A boolean mask indicating which image features correspond
to patch tokens.
to patch tokens.
...
@@ -85,7 +85,7 @@ class MolmoImageInputs(TypedDict):
...
@@ -85,7 +85,7 @@ class MolmoImageInputs(TypedDict):
Shape: `(batch_size, num_crops, num_patch)`
Shape: `(batch_size, num_crops, num_patch)`
"""
"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
L
ist
[
torch
.
Tensor
]]
embed_is_patch
:
Union
[
torch
.
Tensor
,
l
ist
[
torch
.
Tensor
]]
"""
"""
A boolean mask indicating which image embeddings correspond
A boolean mask indicating which image embeddings correspond
to patch tokens.
to patch tokens.
...
@@ -93,7 +93,7 @@ class MolmoImageInputs(TypedDict):
...
@@ -93,7 +93,7 @@ class MolmoImageInputs(TypedDict):
Shape: `(batch_size, num_embeds)`
Shape: `(batch_size, num_embeds)`
"""
"""
num_crops
:
torch
.
Tensor
num_crops
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""Shape: `(batch_size, num_images)`"""
"""Shape: `(batch_size, num_images)`"""
...
@@ -1144,13 +1144,7 @@ class MolmoProcessorWrapper:
...
@@ -1144,13 +1144,7 @@ class MolmoProcessorWrapper:
image_input_idx
=
outputs
.
pop
(
"image_input_idx"
,
None
)
image_input_idx
=
outputs
.
pop
(
"image_input_idx"
,
None
)
if
image_input_idx
is
not
None
:
if
image_input_idx
is
not
None
:
input_is_patch
=
input_ids
==
self
.
image_patch_id
feat_is_patch
=
image_input_idx
>=
0
image_input_idx_flat
:
torch
.
Tensor
=
image_input_idx
.
view
(
-
1
)
image_valid_flat
=
image_input_idx_flat
>=
0
feat_is_patch_flat
=
image_valid_flat
.
clone
()
feat_is_patch_flat
[
image_valid_flat
]
=
(
input_is_patch
[
image_input_idx_flat
[
image_valid_flat
]])
feat_is_patch
=
feat_is_patch_flat
.
view
(
*
image_input_idx
.
shape
)
input_is_embed
=
torch
.
isin
(
input_is_embed
=
torch
.
isin
(
input_ids
,
input_ids
,
...
@@ -1165,6 +1159,17 @@ class MolmoProcessorWrapper:
...
@@ -1165,6 +1159,17 @@ class MolmoProcessorWrapper:
embed_is_patch
=
embed_ids
==
self
.
image_patch_id
embed_is_patch
=
embed_ids
==
self
.
image_patch_id
assert
embed_is_patch
.
sum
()
==
feat_is_patch
.
sum
()
assert
embed_is_patch
.
sum
()
==
feat_is_patch
.
sum
()
# image_tokens = extra_joint + joint
# Both `extra_joint` and `joint` have `im_start_id` and `im_end_id`
embed_start
=
torch
.
nonzero
(
embed_ids
==
self
.
im_start_id
)[::
2
,
0
]
embed_end
=
torch
.
nonzero
(
embed_ids
==
self
.
im_end_id
)[
1
::
2
,
0
]
assert
len
(
embed_start
)
==
len
(
embed_end
)
==
len
(
images
)
embed_is_patch
=
[
embed_is_patch
[
start
:
end
+
1
]
for
start
,
end
in
zip
(
embed_start
,
embed_end
)
]
tilings
=
[
tilings
=
[
self
.
select_tiling
(
self
.
select_tiling
(
image_width
=
image
.
size
[
0
],
image_width
=
image
.
size
[
0
],
...
@@ -1180,7 +1185,7 @@ class MolmoProcessorWrapper:
...
@@ -1180,7 +1185,7 @@ class MolmoProcessorWrapper:
outputs
[
"num_crops"
]
=
num_crops
outputs
[
"num_crops"
]
=
num_crops
outputs
[
"img_patch_id"
]
=
self
.
image_patch_id
outputs
[
"img_patch_id"
]
=
self
.
image_patch_id
return
BatchFeature
(
outputs
,
tensor_type
=
return_tensors
)
return
BatchFeature
(
outputs
)
class
MolmoProcessingInfo
(
BaseProcessingInfo
):
class
MolmoProcessingInfo
(
BaseProcessingInfo
):
...
@@ -1190,9 +1195,7 @@ class MolmoProcessingInfo(BaseProcessingInfo):
...
@@ -1190,9 +1195,7 @@ class MolmoProcessingInfo(BaseProcessingInfo):
return
MolmoProcessorWrapper
(
processor
)
return
MolmoProcessorWrapper
(
processor
)
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
# TODO: Investigate different `embed_is_patch` between cache/no-cache
return
{
"image"
:
None
}
# in multi-image case
return
{
"image"
:
1
}
def
get_mm_max_tokens_per_item
(
def
get_mm_max_tokens_per_item
(
self
,
self
,
...
@@ -1325,7 +1328,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
...
@@ -1325,7 +1328,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
"image"
,
num_crops
),
"image"
,
num_crops
),
feat_is_patch
=
MultiModalFieldConfig
.
flat_from_sizes
(
feat_is_patch
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
num_crops
),
"image"
,
num_crops
),
embed_is_patch
=
MultiModalFieldConfig
.
shar
ed
(
"image"
,
num_images
),
embed_is_patch
=
MultiModalFieldConfig
.
batch
ed
(
"image"
),
num_crops
=
MultiModalFieldConfig
.
batched
(
"image"
),
num_crops
=
MultiModalFieldConfig
.
batched
(
"image"
),
img_patch_id
=
MultiModalFieldConfig
.
shared
(
"image"
,
num_images
),
img_patch_id
=
MultiModalFieldConfig
.
shared
(
"image"
,
num_images
),
)
)
...
@@ -1499,7 +1502,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
...
@@ -1499,7 +1502,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
def
_process_image_input
(
def
_process_image_input
(
self
,
self
,
image_input
:
MolmoImageInputs
,
image_input
:
MolmoImageInputs
,
)
->
Union
[
torch
.
Tensor
,
L
ist
[
torch
.
Tensor
]]:
)
->
Union
[
torch
.
Tensor
,
l
ist
[
torch
.
Tensor
]]:
if
isinstance
(
image_input
[
"images"
],
list
):
if
isinstance
(
image_input
[
"images"
],
list
):
# Call the vision backbone on the whole batch at once
# Call the vision backbone on the whole batch at once
images_flat
=
flatten_bn
(
image_input
[
"images"
],
concat
=
True
)
images_flat
=
flatten_bn
(
image_input
[
"images"
],
concat
=
True
)
...
@@ -1530,7 +1533,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
...
@@ -1530,7 +1533,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
feat_is_patch
:
torch
.
Tensor
,
# Shape: (num_crop, num_patch)
feat_is_patch
:
torch
.
Tensor
,
# Shape: (num_crop, num_patch)
num_crops
:
torch
.
Tensor
,
# Shape: (num_images,)
num_crops
:
torch
.
Tensor
,
# Shape: (num_images,)
embed_is_patch
:
torch
.
Tensor
,
# Shape: (num_embeds,)
embed_is_patch
:
torch
.
Tensor
,
# Shape: (num_embeds,)
)
->
list
[
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
...
]:
"""
"""
Scatter the patch features into a contiguous tensor that corresponds
Scatter the patch features into a contiguous tensor that corresponds
to the embedding tokens defined by the multimodal processor.
to the embedding tokens defined by the multimodal processor.
...
@@ -1565,16 +1568,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
...
@@ -1565,16 +1568,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
feats_per_image
=
features
.
split
(
num_crops_per_image
)
feats_per_image
=
features
.
split
(
num_crops_per_image
)
f_is_patch_per_image
=
feat_is_patch
.
split
(
num_crops_per_image
)
f_is_patch_per_image
=
feat_is_patch
.
split
(
num_crops_per_image
)
_
,
_
,
embed_dim
=
features
.
shape
features
=
torch
.
cat
([
(
num_embeds
,
)
=
embed_is_patch
.
shape
feats
[
f_is_patch
]
for
feats
,
f_is_patch
in
zip
(
feats_per_image
,
f_is_patch_per_image
)
embeds_in_batch
=
list
[
torch
.
Tensor
]()
])
for
feats
,
f_is_patch
in
zip
(
feats_per_image
,
f_is_patch_per_image
):
embeds
=
feats
.
new_full
((
num_embeds
,
embed_dim
),
torch
.
nan
)
embeds
[
embed_is_patch
]
=
feats
[
f_is_patch
]
embeds_in_batch
.
append
(
embeds
)
return
embed
s
_i
n_b
atch
return
scatter_patch_features
(
features
,
embed_i
s_p
atch
)
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
...
...
vllm/model_executor/models/vision.py
View file @
997c8811
...
@@ -155,7 +155,7 @@ def resolve_visual_encoder_outputs(
...
@@ -155,7 +155,7 @@ def resolve_visual_encoder_outputs(
def
scatter_patch_features
(
def
scatter_patch_features
(
features
:
torch
.
Tensor
,
features
:
torch
.
Tensor
,
embed_is_patch
:
torch
.
Tensor
,
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
)
->
tuple
[
torch
.
Tensor
,
...]:
)
->
tuple
[
torch
.
Tensor
,
...]:
"""
"""
Scatter the patch features into a contiguous tensor that corresponds
Scatter the patch features into a contiguous tensor that corresponds
...
@@ -194,14 +194,19 @@ def scatter_patch_features(
...
@@ -194,14 +194,19 @@ def scatter_patch_features(
The resulting embedding tensor is:
The resulting embedding tensor is:
[ nan p1 p2 nan p3 p4 nan nan ]
[ nan p1 p2 nan p3 p4 nan nan ]
"""
"""
num_images
,
num_embeds
=
embed_is_patch
.
shape
num_embeds_per_image
=
[
num_embeds_per_image
=
[
num_embeds
]
*
num_images
e_is_patch
.
numel
()
for
e_is_patch
in
embed_is_patch
]
if
isinstance
(
embed_is_patch
,
torch
.
Tensor
):
embed_is_patch_flat
=
embed_is_patch
.
view
(
-
1
)
else
:
embed_is_patch_flat
=
torch
.
cat
(
embed_is_patch
)
embeds_flat
=
features
.
new_full
(
embeds_flat
=
features
.
new_full
(
(
sum
(
num_embeds_per_image
),
features
.
shape
[
-
1
]),
(
sum
(
num_embeds_per_image
),
features
.
shape
[
-
1
]),
fill_value
=
torch
.
nan
,
fill_value
=
torch
.
nan
,
)
)
embeds_flat
[
embed_is_patch
.
view
(
-
1
)
]
=
features
.
flatten
(
0
,
-
2
)
embeds_flat
[
embed_is_patch
_flat
]
=
features
.
flatten
(
0
,
-
2
)
return
embeds_flat
.
split
(
num_embeds_per_image
)
return
embeds_flat
.
split
(
num_embeds_per_image
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment