Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
32c9eff2
Unverified
Commit
32c9eff2
authored
Jan 06, 2025
by
Jee Jee Li
Committed by
GitHub
Jan 06, 2025
Browse files
[Bugfix][V1] Fix molmo text-only inputs (#11676)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
4ca5d40a
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
123 additions
and
42 deletions
+123
-42
tests/models/decoder_only/vision_language/test_models.py
tests/models/decoder_only/vision_language/test_models.py
+10
-0
tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
...els/decoder_only/vision_language/vlm_utils/model_utils.py
+96
-3
vllm/model_executor/models/molmo.py
vllm/model_executor/models/molmo.py
+17
-39
No files found.
tests/models/decoder_only/vision_language/test_models.py
View file @
32c9eff2
...
...
@@ -341,6 +341,16 @@ VLM_TEST_SETTINGS = {
),
hf_output_post_proc
=
model_utils
.
minicpmv_trunc_hf_output
,
),
"molmo"
:
VLMTestInfo
(
models
=
[
"allenai/Molmo-7B-D-0924"
],
test_type
=
(
VLMTestType
.
IMAGE
),
prompt_formatter
=
lambda
img_prompt
:
"User: "
+
img_prompt
+
" Assistant:"
,
# noqa: E501
max_model_len
=
4096
,
max_num_seqs
=
2
,
image_size_factors
=
[(),(
1.0
,
1.0
,
1.0
)],
patch_hf_runner
=
model_utils
.
mlomo_patch_hf_runner
,
postprocess_inputs
=
model_utils
.
molmo_post_processor
,
),
# Tests for phi3v currently live in another file because of a bug in
# transformers. Once this issue is fixed, we can enable them here instead.
# https://github.com/huggingface/transformers/issues/34307
...
...
tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
View file @
32c9eff2
...
...
@@ -5,17 +5,20 @@ typically specific to a small subset of models.
import
re
import
types
from
pathlib
import
PosixPath
from
typing
import
Callable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
from
PIL.Image
import
Image
from
transformers
import
AutoConfig
,
AutoTokenizer
,
BatchEncoding
from
transformers
import
(
AutoConfig
,
AutoTokenizer
,
BatchEncoding
,
GenerationConfig
)
from
vllm.sequence
import
SampleLogprobs
from
vllm.transformers_utils.tokenizer
import
patch_padding_side
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
.....conftest
import
HfRunner
,
ImageAsset
,
_ImageAssets
from
.....conftest
import
(
HfRunner
,
ImageAsset
,
PromptAudioInput
,
PromptImageInput
,
PromptVideoInput
,
_ImageAssets
)
from
....utils
import
TokensTextLogprobs
from
.types
import
RunnerOutput
...
...
@@ -222,6 +225,11 @@ def wrap_inputs_post_processor(hf_inputs: BatchEncoding, dtype: str):
return
{
"model_inputs"
:
hf_inputs
}
def
molmo_post_processor
(
hf_inputs
:
BatchEncoding
,
dtype
:
str
):
hf_inputs
=
cast_dtype_post_processor
(
"images"
)(
hf_inputs
,
dtype
)
return
{
k
:
v
.
unsqueeze
(
0
)
for
k
,
v
in
hf_inputs
.
items
()}
####### Prompt path encoders for models that need models on disk
def
qwen_prompt_path_encoder
(
tmp_path
:
PosixPath
,
prompt
:
str
,
assets
:
Union
[
List
[
ImageAsset
],
...
...
@@ -451,3 +459,88 @@ def mantis_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
hf_model
.
model
.
generate
=
types
.
MethodType
(
_generate
,
hf_model
.
model
)
return
hf_model
def
_generate_greedy_logprobs_limit
(
self
,
prompts
:
List
[
str
],
max_tokens
:
int
,
num_logprobs
:
int
,
images
:
Optional
[
PromptImageInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
videos
:
Optional
[
PromptVideoInput
]
=
None
,
**
kwargs
:
Any
,
)
->
List
[
TokensTextLogprobs
]:
all_inputs
=
self
.
get_inputs
(
prompts
,
images
=
images
,
videos
=
videos
,
audios
=
audios
)
# Process in batches for inference.
if
len
(
all_inputs
):
input_ids_lst
=
[]
images_lst
=
[]
images_input_idx_lst
=
[]
imges_masks_lst
=
[]
for
inputs
in
all_inputs
:
input_ids_lst
.
append
(
inputs
[
"input_ids"
])
images_lst
.
append
(
inputs
[
"images"
])
images_input_idx_lst
.
append
(
inputs
[
"image_input_idx"
])
imges_masks_lst
.
append
(
inputs
[
"image_masks"
])
batch_inputs
=
{}
batch_inputs
[
'input_ids'
]
=
torch
.
cat
(
input_ids_lst
,
dim
=
0
)
batch_inputs
[
'images'
]
=
torch
.
cat
(
images_lst
,
dim
=
0
)
batch_inputs
[
'image_input_idx'
]
=
torch
.
cat
(
images_input_idx_lst
,
dim
=
0
)
batch_inputs
[
'image_masks'
]
=
torch
.
cat
(
imges_masks_lst
,
dim
=
0
)
outputs
=
self
.
model
.
generate_from_batch
(
batch
=
self
.
wrap_device
(
batch_inputs
,
device
=
self
.
model
.
device
.
type
),
generation_config
=
GenerationConfig
(
max_new_tokens
=
max_tokens
,
stop_strings
=
"<|endoftext|>"
,
do_sample
=
False
,
),
tokenizer
=
self
.
tokenizer
,
output_hidden_states
=
True
,
return_dict_in_generate
=
True
,
)
all_logprobs
:
List
[
List
[
Dict
[
int
,
float
]]]
=
[]
all_output_ids
:
List
[
List
[
int
]]
=
[]
all_output_strs
:
List
[
str
]
=
[]
for
index
in
range
(
len
(
all_inputs
)):
(
seq_logprobs_lst
,
output_len
,
)
=
self
.
_hidden_states_to_logprobs
(
outputs
.
hidden_states
,
num_logprobs
)
all_logprobs
.
append
(
seq_logprobs_lst
)
seq_ids
=
outputs
.
sequences
[
index
]
output_ids
=
seq_ids
[
-
output_len
:]
all_output_ids
.
append
(
output_ids
.
tolist
())
all_output_strs
.
append
(
self
.
tokenizer
.
decode
(
output_ids
))
outputs
=
zip
(
all_output_ids
,
all_output_strs
,
all_logprobs
)
return
[(
output_ids
,
output_str
,
output_logprobs
)
for
output_ids
,
output_str
,
output_logprobs
in
outputs
]
####### Molmo-specific HuggingFace runner patchers
def
mlomo_patch_hf_runner
(
hf_model
:
HfRunner
)
->
HfRunner
:
"""Patches and returns an instance of the HfRunner to use for Molmo."""
hf_processor
=
hf_model
.
processor
def
_processor
(
*
args
,
**
kwargs
):
return
hf_processor
.
process
(
*
args
,
**
kwargs
)
hf_model
.
processor
=
_processor
setattr
(
# noqa: B010
hf_model
,
"generate_greedy_logprobs_limit"
,
types
.
MethodType
(
_generate_greedy_logprobs_limit
,
hf_model
),
)
return
hf_model
vllm/model_executor/models/molmo.py
View file @
32c9eff2
...
...
@@ -1081,45 +1081,25 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
else
:
out
=
processor
.
process
(
None
,
image
,
tokens
=
inputs
[
"prompt_token_ids"
])
# If there is no image, return directly.
if
image
is
None
:
new_prompt_token_ids
=
out
[
"input_ids"
].
tolist
()
prompt
=
inputs
.
get
(
"prompt"
)
if
prompt
is
None
:
prompt
=
tokenizer
.
decode
(
new_prompt_token_ids
)
return
token_inputs
(
prompt_token_ids
=
new_prompt_token_ids
,
prompt
=
prompt
,
)
image_processor
=
processor
.
image_processor
max_total_crops
=
1
+
image_processor
.
max_crops
if
image
is
not
None
:
images
,
image_input_idx
,
image_masks
=
pad_images
(
max_total_crops
,
out
[
"images"
],
out
[
"image_input_idx"
],
out
.
get
(
"image_masks"
),
)
else
:
base_image_input_size
=
image_processor
.
base_image_input_size
image_patch_size
=
image_processor
.
image_patch_size
image_num_patch
=
(
base_image_input_size
[
0
]
//
image_patch_size
,
base_image_input_size
[
1
]
//
image_patch_size
,
)
n_pixels
=
image_patch_size
*
image_patch_size
*
3
n_patches
=
image_num_patch
[
0
]
*
image_num_patch
[
1
]
image_length_w
=
image_processor
.
image_token_length_w
image_length_h
=
image_processor
.
image_token_length_h
tokens_per_image
=
image_length_w
*
image_length_h
images
=
torch
.
full
(
(
max_total_crops
,
n_patches
,
n_pixels
),
-
1
,
dtype
=
torch
.
float32
,
)
image_input_idx
=
torch
.
full
(
(
max_total_crops
,
tokens_per_image
),
-
1
,
dtype
=
torch
.
int32
,
)
if
image_processor
.
image_padding_mask
:
image_masks
=
torch
.
full
(
(
max_total_crops
,
n_patches
),
-
1
,
dtype
=
torch
.
float32
,
)
image_data
=
dict
(
images
=
images
,
image_input_idx
=
image_input_idx
,
...
...
@@ -1143,11 +1123,9 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
offset
=
i
size
+=
1
image_data
[
"image_start_end"
]
=
(
offset
,
offset
+
size
)
prompt
=
inputs
.
get
(
"prompt"
)
if
prompt
is
None
:
prompt
=
tokenizer
.
decode
(
new_prompt_token_ids
)
return
token_inputs
(
prompt_token_ids
=
new_prompt_token_ids
,
prompt
=
prompt
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment