Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c9d3ecf0
Unverified
Commit
c9d3ecf0
authored
Feb 13, 2025
by
Cyrus Leung
Committed by
GitHub
Feb 13, 2025
Browse files
[VLM] Merged multi-modal processor for Molmo (#12966)
parent
fdcf64d3
Changes
9
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
750 additions
and
498 deletions
+750
-498
docs/source/models/supported_models.md
docs/source/models/supported_models.md
+1
-1
tests/models/decoder_only/language/test_models.py
tests/models/decoder_only/language/test_models.py
+1
-1
tests/models/decoder_only/vision_language/test_models.py
tests/models/decoder_only/vision_language/test_models.py
+2
-3
tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
...els/decoder_only/vision_language/vlm_utils/model_utils.py
+21
-77
tests/models/multimodal/processing/test_common.py
tests/models/multimodal/processing/test_common.py
+2
-0
tests/models/registry.py
tests/models/registry.py
+1
-0
vllm/model_executor/models/molmo.py
vllm/model_executor/models/molmo.py
+681
-342
vllm/multimodal/inputs.py
vllm/multimodal/inputs.py
+40
-40
vllm/utils.py
vllm/utils.py
+1
-34
No files found.
docs/source/models/supported_models.md
View file @
c9d3ecf0
...
@@ -793,7 +793,7 @@ See [this page](#generative-models) for more information on how to use generativ
...
@@ -793,7 +793,7 @@ See [this page](#generative-models) for more information on how to use generativ
-
*
`MolmoForCausalLM`
-
*
`MolmoForCausalLM`
*
Molmo
*
Molmo
*
T + I
*
T + I
*
`allenai/Molmo-7B-D-0924`
,
`allenai/Molmo-7
2
B-0924`
, etc.
*
`allenai/Molmo-7B-D-0924`
,
`allenai/Molmo-7B
-O
-0924`
, etc.
*
✅︎
*
✅︎
*
✅︎
*
✅︎
*
✅︎
*
✅︎
...
...
tests/models/decoder_only/language/test_models.py
View file @
c9d3ecf0
...
@@ -27,7 +27,7 @@ from ...utils import check_logprobs_close
...
@@ -27,7 +27,7 @@ from ...utils import check_logprobs_close
marks
=
[
pytest
.
mark
.
core_model
,
pytest
.
mark
.
cpu_model
],
marks
=
[
pytest
.
mark
.
core_model
,
pytest
.
mark
.
cpu_model
],
),
),
pytest
.
param
(
pytest
.
param
(
"THUDM/chatglm3-6b"
,
#
C
hat
GLM
(text-only)
"THUDM/chatglm3-6b"
,
#
c
hat
glm
(text-only)
),
),
pytest
.
param
(
pytest
.
param
(
"meta-llama/Llama-3.2-1B-Instruct"
,
# llama
"meta-llama/Llama-3.2-1B-Instruct"
,
# llama
...
...
tests/models/decoder_only/vision_language/test_models.py
View file @
c9d3ecf0
...
@@ -404,11 +404,10 @@ VLM_TEST_SETTINGS = {
...
@@ -404,11 +404,10 @@ VLM_TEST_SETTINGS = {
"molmo"
:
VLMTestInfo
(
"molmo"
:
VLMTestInfo
(
models
=
[
"allenai/Molmo-7B-D-0924"
],
models
=
[
"allenai/Molmo-7B-D-0924"
],
test_type
=
(
VLMTestType
.
IMAGE
),
test_type
=
(
VLMTestType
.
IMAGE
),
prompt_formatter
=
lambda
img_prompt
:
"User: "
+
img_prompt
+
" Assistant:"
,
# noqa: E501
prompt_formatter
=
identity
,
max_model_len
=
4096
,
max_model_len
=
4096
,
max_num_seqs
=
2
,
max_num_seqs
=
2
,
image_size_factors
=
[(),(
1.0
,
1.0
,
1.0
)],
patch_hf_runner
=
model_utils
.
molmo_patch_hf_runner
,
patch_hf_runner
=
model_utils
.
mlomo_patch_hf_runner
,
postprocess_inputs
=
model_utils
.
molmo_post_processor
,
postprocess_inputs
=
model_utils
.
molmo_post_processor
,
),
),
# Tests for phi3v currently live in another file because of a bug in
# Tests for phi3v currently live in another file because of a bug in
...
...
tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
View file @
c9d3ecf0
...
@@ -6,7 +6,7 @@ typically specific to a small subset of models.
...
@@ -6,7 +6,7 @@ typically specific to a small subset of models.
import
re
import
re
import
types
import
types
from
pathlib
import
PosixPath
from
pathlib
import
PosixPath
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Callable
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
from
PIL.Image
import
Image
from
PIL.Image
import
Image
...
@@ -17,9 +17,7 @@ from vllm.sequence import SampleLogprobs
...
@@ -17,9 +17,7 @@ from vllm.sequence import SampleLogprobs
from
vllm.transformers_utils.tokenizer
import
patch_padding_side
from
vllm.transformers_utils.tokenizer
import
patch_padding_side
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
.....conftest
import
(
HfRunner
,
ImageAsset
,
PromptAudioInput
,
from
.....conftest
import
HfRunner
,
ImageAsset
,
_ImageAssets
PromptImageInput
,
PromptVideoInput
,
_ImageAssets
)
from
....utils
import
TokensTextLogprobs
from
.types
import
RunnerOutput
from
.types
import
RunnerOutput
...
@@ -522,74 +520,7 @@ def minicpmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
...
@@ -522,74 +520,7 @@ def minicpmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
return
hf_model
return
hf_model
def
_generate_greedy_logprobs_limit
(
def
molmo_patch_hf_runner
(
hf_model
:
HfRunner
)
->
HfRunner
:
self
,
prompts
:
List
[
str
],
max_tokens
:
int
,
num_logprobs
:
int
,
images
:
Optional
[
PromptImageInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
videos
:
Optional
[
PromptVideoInput
]
=
None
,
**
kwargs
:
Any
,
)
->
List
[
TokensTextLogprobs
]:
all_inputs
=
self
.
get_inputs
(
prompts
,
images
=
images
,
videos
=
videos
,
audios
=
audios
)
# Process in batches for inference.
if
len
(
all_inputs
):
input_ids_lst
=
[]
images_lst
=
[]
images_input_idx_lst
=
[]
imges_masks_lst
=
[]
for
inputs
in
all_inputs
:
input_ids_lst
.
append
(
inputs
[
"input_ids"
])
images_lst
.
append
(
inputs
[
"images"
])
images_input_idx_lst
.
append
(
inputs
[
"image_input_idx"
])
imges_masks_lst
.
append
(
inputs
[
"image_masks"
])
batch_inputs
=
{}
batch_inputs
[
'input_ids'
]
=
torch
.
cat
(
input_ids_lst
,
dim
=
0
)
batch_inputs
[
'images'
]
=
torch
.
cat
(
images_lst
,
dim
=
0
)
batch_inputs
[
'image_input_idx'
]
=
torch
.
cat
(
images_input_idx_lst
,
dim
=
0
)
batch_inputs
[
'image_masks'
]
=
torch
.
cat
(
imges_masks_lst
,
dim
=
0
)
outputs
=
self
.
model
.
generate_from_batch
(
batch
=
self
.
wrap_device
(
batch_inputs
,
device
=
self
.
model
.
device
.
type
),
generation_config
=
GenerationConfig
(
max_new_tokens
=
max_tokens
,
stop_strings
=
"<|endoftext|>"
,
do_sample
=
False
,
),
tokenizer
=
self
.
tokenizer
,
output_hidden_states
=
True
,
return_dict_in_generate
=
True
,
)
all_logprobs
:
List
[
List
[
Dict
[
int
,
float
]]]
=
[]
all_output_ids
:
List
[
List
[
int
]]
=
[]
all_output_strs
:
List
[
str
]
=
[]
for
index
in
range
(
len
(
all_inputs
)):
(
seq_logprobs_lst
,
output_len
,
)
=
self
.
_hidden_states_to_logprobs
(
outputs
.
hidden_states
,
num_logprobs
)
all_logprobs
.
append
(
seq_logprobs_lst
)
seq_ids
=
outputs
.
sequences
[
index
]
output_ids
=
seq_ids
[
-
output_len
:]
all_output_ids
.
append
(
output_ids
.
tolist
())
all_output_strs
.
append
(
self
.
tokenizer
.
decode
(
output_ids
))
outputs
=
zip
(
all_output_ids
,
all_output_strs
,
all_logprobs
)
return
[(
output_ids
,
output_str
,
output_logprobs
)
for
output_ids
,
output_str
,
output_logprobs
in
outputs
]
####### Molmo-specific HuggingFace runner patchers
def
mlomo_patch_hf_runner
(
hf_model
:
HfRunner
)
->
HfRunner
:
"""Patches and returns an instance of the HfRunner to use for Molmo."""
"""Patches and returns an instance of the HfRunner to use for Molmo."""
hf_processor
=
hf_model
.
processor
hf_processor
=
hf_model
.
processor
...
@@ -598,10 +529,23 @@ def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
...
@@ -598,10 +529,23 @@ def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
hf_model
.
processor
=
_processor
hf_model
.
processor
=
_processor
setattr
(
# noqa: B010
def
_generate
(
self
,
max_new_tokens
=
None
,
do_sample
=
None
,
**
kwargs
):
hf_model
,
batch
=
{
"generate_greedy_logprobs_limit"
,
k
:
kwargs
.
pop
(
k
)
types
.
MethodType
(
_generate_greedy_logprobs_limit
,
hf_model
),
for
k
in
(
"input_ids"
,
"images"
,
"image_input_idx"
,
"image_masks"
)
if
k
in
kwargs
}
return
self
.
generate_from_batch
(
batch
,
generation_config
=
GenerationConfig
(
max_new_tokens
=
max_new_tokens
,
stop_strings
=
"<|endoftext|>"
,
do_sample
=
do_sample
,
),
**
kwargs
,
)
)
hf_model
.
model
.
generate
=
types
.
MethodType
(
_generate
,
hf_model
.
model
)
return
hf_model
return
hf_model
tests/models/multimodal/processing/test_common.py
View file @
c9d3ecf0
...
@@ -168,6 +168,8 @@ def _test_processing_correctness(
...
@@ -168,6 +168,8 @@ def _test_processing_correctness(
"mistral-community/pixtral-12b"
,
"mistral-community/pixtral-12b"
,
"openbmb/MiniCPM-o-2_6"
,
"openbmb/MiniCPM-o-2_6"
,
"openbmb/MiniCPM-V-2_6"
,
"openbmb/MiniCPM-V-2_6"
,
"allenai/Molmo-7B-D-0924"
,
"allenai/Molmo-7B-O-0924"
,
"nvidia/NVLM-D-72B"
,
"nvidia/NVLM-D-72B"
,
"Qwen/Qwen-VL-Chat"
,
"Qwen/Qwen-VL-Chat"
,
"Qwen/Qwen2-VL-2B-Instruct"
,
"Qwen/Qwen2-VL-2B-Instruct"
,
...
...
tests/models/registry.py
View file @
c9d3ecf0
...
@@ -256,6 +256,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
...
@@ -256,6 +256,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"MiniCPMV"
:
_HfExamplesInfo
(
"openbmb/MiniCPM-V-2_6"
,
"MiniCPMV"
:
_HfExamplesInfo
(
"openbmb/MiniCPM-V-2_6"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"MolmoForCausalLM"
:
_HfExamplesInfo
(
"allenai/Molmo-7B-D-0924"
,
"MolmoForCausalLM"
:
_HfExamplesInfo
(
"allenai/Molmo-7B-D-0924"
,
extras
=
{
"olmo"
:
"allenai/Molmo-7B-O-0924"
},
# noqa: E501
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"NVLM_D"
:
_HfExamplesInfo
(
"nvidia/NVLM-D-72B"
,
"NVLM_D"
:
_HfExamplesInfo
(
"nvidia/NVLM-D-72B"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
...
...
vllm/model_executor/models/molmo.py
View file @
c9d3ecf0
This diff is collapsed.
Click to expand it.
vllm/multimodal/inputs.py
View file @
c9d3ecf0
vllm/utils.py
View file @
c9d3ecf0
...
@@ -33,8 +33,7 @@ from dataclasses import dataclass, field
...
@@ -33,8 +33,7 @@ from dataclasses import dataclass, field
from
functools
import
cache
,
lru_cache
,
partial
,
wraps
from
functools
import
cache
,
lru_cache
,
partial
,
wraps
from
typing
import
(
TYPE_CHECKING
,
Any
,
AsyncGenerator
,
Awaitable
,
Callable
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
AsyncGenerator
,
Awaitable
,
Callable
,
Dict
,
Generator
,
Generic
,
Iterator
,
List
,
Literal
,
Dict
,
Generator
,
Generic
,
Iterator
,
List
,
Literal
,
NamedTuple
,
Optional
,
Tuple
,
Type
,
TypeVar
,
Union
,
NamedTuple
,
Optional
,
Tuple
,
Type
,
TypeVar
,
Union
)
overload
)
from
uuid
import
uuid4
from
uuid
import
uuid4
import
cloudpickle
import
cloudpickle
...
@@ -826,38 +825,6 @@ JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"],
...
@@ -826,38 +825,6 @@ JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"],
"""A nested JSON structure where the leaves need not be JSON-serializable."""
"""A nested JSON structure where the leaves need not be JSON-serializable."""
@
overload
def
json_map_leaves
(
func
:
Callable
[[
T
],
U
],
value
:
Dict
[
str
,
JSONTree
[
T
]],
)
->
Dict
[
str
,
JSONTree
[
U
]]:
...
@
overload
def
json_map_leaves
(
func
:
Callable
[[
T
],
U
],
value
:
List
[
JSONTree
[
T
]],
)
->
List
[
JSONTree
[
U
]]:
...
@
overload
def
json_map_leaves
(
func
:
Callable
[[
T
],
U
],
value
:
Tuple
[
JSONTree
[
T
],
...],
)
->
Tuple
[
JSONTree
[
U
],
...]:
...
@
overload
def
json_map_leaves
(
func
:
Callable
[[
T
],
U
],
value
:
JSONTree
[
T
],
)
->
JSONTree
[
U
]:
...
def
json_map_leaves
(
func
:
Callable
[[
T
],
U
],
value
:
JSONTree
[
T
])
->
JSONTree
[
U
]:
def
json_map_leaves
(
func
:
Callable
[[
T
],
U
],
value
:
JSONTree
[
T
])
->
JSONTree
[
U
]:
if
isinstance
(
value
,
dict
):
if
isinstance
(
value
,
dict
):
return
{
k
:
json_map_leaves
(
func
,
v
)
for
k
,
v
in
value
.
items
()}
return
{
k
:
json_map_leaves
(
func
,
v
)
for
k
,
v
in
value
.
items
()}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment