Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
71b9cde0
Unverified
Commit
71b9cde0
authored
Apr 11, 2025
by
Travis Johnson
Committed by
GitHub
Apr 11, 2025
Browse files
[Bugfix] handle alignment of encoder_seq_lens in mllama.py (#14784)
Signed-off-by:
Travis Johnson
<
tsjohnso@us.ibm.com
>
parent
5285589f
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
82 additions
and
22 deletions
+82
-22
tests/models/encoder_decoder/vision_language/test_mllama.py
tests/models/encoder_decoder/vision_language/test_mllama.py
+50
-9
vllm/model_executor/models/mllama.py
vllm/model_executor/models/mllama.py
+32
-13
No files found.
tests/models/encoder_decoder/vision_language/test_mllama.py
View file @
71b9cde0
...
@@ -209,14 +209,15 @@ def _run_test(
...
@@ -209,14 +209,15 @@ def _run_test(
# will hurt multiprocessing backend with fork method (the default method).
# will hurt multiprocessing backend with fork method (the default method).
# max_model_len should be greater than image_feature_size
# max_model_len should be greater than image_feature_size
with
vllm_runner
(
model
,
with
vllm_runner
(
model
,
dtype
=
dtype
,
dtype
=
dtype
,
max_model_len
=
8
192
,
max_model_len
=
192
12
,
# 3 max size images
max_num_seqs
=
3
,
max_num_seqs
=
3
,
tensor_parallel_size
=
tensor_parallel_size
,
tensor_parallel_size
=
tensor_parallel_size
,
distributed_executor_backend
=
distributed_executor_backend
,
distributed_executor_backend
=
distributed_executor_backend
,
limit_mm_per_prompt
=
{
"image"
:
_LIMIT_IMAGE_PER_PROMPT
limit_mm_per_prompt
=
{
"image"
:
})
as
vllm_model
:
_LIMIT_IMAGE_PER_PROMPT
})
as
vllm_model
:
vllm_outputs_per_image
=
[
vllm_outputs_per_image
=
[
vllm_model
.
generate_greedy_logprobs
(
prompts
,
vllm_model
.
generate_greedy_logprobs
(
prompts
,
max_tokens
,
max_tokens
,
...
@@ -507,7 +508,7 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
...
@@ -507,7 +508,7 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
model
,
model
,
dtype
=
dtype
,
dtype
=
dtype
,
max_model_len
=
8192
,
max_model_len
=
8192
,
max_num_seqs
=
2
,
max_num_seqs
=
4
,
tensor_parallel_size
=
1
,
tensor_parallel_size
=
1
,
limit_mm_per_prompt
=
{
"image"
:
limit_mm_per_prompt
=
{
"image"
:
_LIMIT_IMAGE_PER_PROMPT
})
as
vllm_model
:
_LIMIT_IMAGE_PER_PROMPT
})
as
vllm_model
:
...
@@ -552,6 +553,23 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
...
@@ -552,6 +553,23 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
num_logprobs
,
num_logprobs
,
images
=
images
)
images
=
images
)
# Mixed batch with text and images with different numbers of tiles
prompts
=
[
"<|begin_of_text|>Hello!"
,
"<|begin_of_text|>Some text before.<|image|>What is in the image?"
,
# noqa: E501
"<|begin_of_text|>Some text before.<|image|>What is in the image?"
,
# noqa: E501
]
images
=
[
None
,
[
stop_sign
],
# smaller image must be 2nd for the repro
[
stop_sign
.
resize
((
448
,
448
))],
]
vllm_model
.
generate_greedy_logprobs
(
prompts
,
max_tokens
,
num_logprobs
,
images
=
images
)
class
DummyModel
:
class
DummyModel
:
image_token_id
=
MLLAMA_IMAGE_TOKEN_ID
image_token_id
=
MLLAMA_IMAGE_TOKEN_ID
...
@@ -674,3 +692,26 @@ def test_get_full_text_row_masked_out_mask(input_indices) -> None:
...
@@ -674,3 +692,26 @@ def test_get_full_text_row_masked_out_mask(input_indices) -> None:
f
"full_text_row_masked_out_mask[
{
idx
}
] must be "
\
f
"full_text_row_masked_out_mask[
{
idx
}
] must be "
\
f
"'
{
must_be_masked
}
' "
f
"'
{
must_be_masked
}
' "
idx
+=
1
idx
+=
1
@pytest.mark.core_model
@pytest.mark.parametrize(
    "encoder_seq_lens, num_tiles, expected",
    [
        # Single image request.
        ([6404], [[4]], [6404]),
        # A text-only request (encoder len 0) ahead of an image request.
        ([0, 6404], [[4]], [6404]),
        # Text-only request followed by two image requests.
        ([0, 1601, 8005], [[1], [4, 1]], [1601, 8005]),
        # Text-only requests interleaved with image requests.
        ([0, 19212, 0, 3202], [[4, 4, 4], [2]], [19212, 3202]),
    ])
def test_parse_and_validate_encoder_lens(encoder_seq_lens, num_tiles,
                                         expected) -> None:
    """Check the encoder lengths computed for mixed text/image batches.

    Calls ``_get_and_validate_encoder_lens`` unbound, with a ``DummyModel``
    standing in for ``self``, using 1601 tokens per tile.
    """
    dummy = DummyModel()
    num_tokens_per_tile = 1601
    validate = MllamaForConditionalGeneration._get_and_validate_encoder_lens
    actual_encoder_seq_lens = validate(
        dummy,
        encoder_seq_lens,
        num_tiles,
        num_tokens_per_tile,
    )
    assert actual_encoder_seq_lens == expected, \
        f"Expected {expected} but got {actual_encoder_seq_lens}"
vllm/model_executor/models/mllama.py
View file @
71b9cde0
...
@@ -1301,6 +1301,31 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1301,6 +1301,31 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
raise
AssertionError
(
"This line should be unreachable."
)
raise
AssertionError
(
"This line should be unreachable."
)
def _get_and_validate_encoder_lens(
    self,
    encoder_seq_lens: List[int],
    num_tiles: List[List[int]],
    num_tokens_per_tile: int,
) -> List[int]:
    """Return the actual number of encoder tokens for each image request.

    attn_metadata.encoder_seq_lens only counts the last group of images
    for each sample, which is used to cheat the block manager to allocate
    blocks for those images only. See MllamaMultiModalProcessor for more
    details. The lengths returned here are instead computed from every
    tile of every image of a request.

    Args:
        encoder_seq_lens: Encoder lengths taken from the attention
            metadata. Entries that are not positive (text-only requests)
            are ignored for validation.
        num_tiles: One list of per-image tile counts for each request
            that has images.
        num_tokens_per_tile: Number of encoder tokens produced per tile.

    Returns:
        One length per entry of ``num_tiles``:
        ``sum(tiles) * num_tokens_per_tile``.

    Raises:
        AssertionError: If the number of non-zero ``encoder_seq_lens``
            entries differs from ``len(num_tiles)``, or if a computed
            length is smaller than the matching metadata length.
    """
    actual_encoder_seq_lens = [
        sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
    ]

    # Remove 0 encoder len entries for text-only requests so the remaining
    # entries line up one-to-one with num_tiles.
    attn_metadata_lens = [x for x in encoder_seq_lens if x > 0]

    # Raise explicitly instead of using bare `assert` so the checks are
    # not stripped when Python runs with -O; the exception type stays
    # AssertionError.
    if len(actual_encoder_seq_lens) != len(attn_metadata_lens):
        raise AssertionError(
            f"Got tiles for {len(actual_encoder_seq_lens)} request(s) but "
            f"{len(attn_metadata_lens)} non-zero encoder_seq_lens entries.")
    for actual_len, last_group_len in zip(actual_encoder_seq_lens,
                                          attn_metadata_lens):
        if actual_len < last_group_len:
            raise AssertionError(
                f"Actual encoder length {actual_len} is smaller than the "
                f"last image group length {last_group_len}.")

    return actual_encoder_seq_lens
def
flat_encoder_result
(
self
,
cross_attention_states
:
torch
.
Tensor
,
def
flat_encoder_result
(
self
,
cross_attention_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
actual_encoder_seq_lens
:
List
[
int
]):
actual_encoder_seq_lens
:
List
[
int
]):
...
@@ -1428,20 +1453,14 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1428,20 +1453,14 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
else
:
else
:
skip_cross_attention
=
False
skip_cross_attention
=
False
# Get the actual number of encoder tokens for each sample.
num_tiles
=
[
t
.
tolist
()
for
t
in
kwargs
.
pop
(
"num_tiles"
)]
# Because attn_metadata.encoder_seq_lens only counts the last
# group of images for each sample, which is used to cheat the
# block manager to allocate blocks for those images only.
# See MllamaMultiModalProcessor for more details.
num_tiles_tensor
=
kwargs
.
pop
(
"num_tiles"
)
num_tiles
=
[
t
.
tolist
()
for
t
in
num_tiles_tensor
]
num_tokens_per_tile
=
calc_token_per_chunk
(
self
.
image_size
)
num_tokens_per_tile
=
calc_token_per_chunk
(
self
.
image_size
)
actual_encoder_seq_lens
=
[
sum
(
num_tile
)
*
num_tokens_per_tile
for
num_tile
in
num_tiles
actual_encoder_seq_lens
=
self
.
_get_and_validate_encoder_lens
(
]
attn_metadata
.
encoder_seq_lens
,
for
actual_len
,
last_group_len
in
zip
(
num_tiles
,
actual_encoder_seq_lens
,
attn_metadata
.
encoder_seq_lens
):
num_tokens_per_tile
,
assert
actual_len
>=
last_group_len
)
cross_attention_states
=
self
.
get_cross_attention_states
(
cross_attention_states
=
self
.
get_cross_attention_states
(
image_inputs
,
attn_metadata
,
actual_encoder_seq_lens
)
image_inputs
,
attn_metadata
,
actual_encoder_seq_lens
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment