Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ead21102
Unverified
Commit
ead21102
authored
Jun 19, 2025
by
Alex Brooks
Committed by
GitHub
Jun 19, 2025
Browse files
[Core][Bugfix] Fix Online MM Beam Search (#19688)
Signed-off-by:
Alex-Brooks
<
Alex.Brooks@ibm.com
>
parent
01220ce8
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
45 additions
and
12 deletions
+45
-12
tests/entrypoints/openai/test_vision.py
tests/entrypoints/openai/test_vision.py
+27
-4
vllm/engine/protocol.py
vllm/engine/protocol.py
+11
-2
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+7
-6
No files found.
tests/entrypoints/openai/test_vision.py
View file @
ead21102
...
...
@@ -25,6 +25,25 @@ TEST_IMAGE_URLS = [
"https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png"
,
]
EXPECTED_MM_BEAM_SEARCH_RES
=
[
[
"The image shows a wooden boardwalk leading through a"
,
"The image shows a wooden boardwalk extending into a"
,
],
[
"The image shows two parrots perched on"
,
"The image shows two birds perched on a cur"
,
],
[
"The image shows a Venn diagram with three over"
,
"This image shows a Venn diagram with three over"
,
],
[
"This image displays a gradient of colors ranging from"
,
"This image displays a gradient of colors transitioning from"
,
],
]
@
pytest
.
fixture
(
scope
=
"module"
)
def
server
():
...
...
@@ -270,10 +289,13 @@ async def test_single_chat_session_image_base64encoded(
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
@
pytest
.
mark
.
parametrize
(
"image_
url"
,
TEST_IMAGE_URLS
)
@
pytest
.
mark
.
parametrize
(
"image_
idx"
,
list
(
range
(
len
(
TEST_IMAGE_URLS
)
)))
async
def
test_single_chat_session_image_base64encoded_beamsearch
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
,
image_
url
:
str
,
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
,
image_
idx
:
int
,
base64_encoded_image
:
dict
[
str
,
str
]):
# NOTE: This test also validates that we pass MM data through beam search
image_url
=
TEST_IMAGE_URLS
[
image_idx
]
expected_res
=
EXPECTED_MM_BEAM_SEARCH_RES
[
image_idx
]
messages
=
[{
"role"
:
...
...
@@ -297,10 +319,11 @@ async def test_single_chat_session_image_base64encoded_beamsearch(
messages
=
messages
,
n
=
2
,
max_completion_tokens
=
10
,
temperature
=
0.0
,
extra_body
=
dict
(
use_beam_search
=
True
))
assert
len
(
chat_completion
.
choices
)
==
2
assert
chat_completion
.
choices
[
0
]
.
message
.
content
!
=
chat_completion
.
choices
[
1
].
message
.
content
for
actual
,
expected_str
in
zip
(
chat_completion
.
choices
,
expected_res
):
assert
actual
.
message
.
content
=
=
expected_str
@
pytest
.
mark
.
asyncio
...
...
vllm/engine/protocol.py
View file @
ead21102
...
...
@@ -88,9 +88,18 @@ class EngineClient(ABC):
if
processed_inputs
[
"type"
]
==
"embeds"
:
raise
NotImplementedError
prompt_token_ids
=
processed_inputs
[
"prompt_token_ids"
]
# This is a workaround to fix multimodal beam search; it is a
# band-aid fix for two small problems:
# 1. `multi_modal_data` on the processed_inputs currently resolves to
# `None`.
# 2. The preprocessing above expands the multimodal placeholders. However,
# this happens again in generation, so the double expansion causes
# a mismatch.
# TODO - would be ideal to handle this more gracefully.
prompt_token_ids
=
prompt
.
get
(
"prompt_token_ids"
)
multi_modal_data
=
prompt
.
get
(
"multi_modal_data"
)
prompt_text
=
processed_inputs
.
get
(
"prompt"
)
multi_modal_data
=
processed_inputs
.
get
(
"multi_modal_data"
)
mm_processor_kwargs
=
processed_inputs
.
get
(
"mm_processor_kwargs"
)
tokenized_length
=
len
(
prompt_token_ids
)
...
...
vllm/entrypoints/llm.py
View file @
ead21102
...
...
@@ -15,7 +15,8 @@ from tqdm.auto import tqdm
from
typing_extensions
import
TypeVar
,
deprecated
from
vllm.beam_search
import
(
BeamSearchInstance
,
BeamSearchOutput
,
BeamSearchSequence
,
get_beam_search_score
)
BeamSearchSequence
,
create_sort_beams_key_function
)
from
vllm.config
import
(
CompilationConfig
,
ModelDType
,
TokenizerMode
,
is_init_field
)
from
vllm.engine.arg_utils
import
(
EngineArgs
,
HfOverrides
,
PoolerConfig
,
...
...
@@ -575,10 +576,11 @@ class LLM:
lora_requests
=
self
.
_get_beam_search_lora_requests
(
lora_request
,
prompts
)
def
sort_beams_key
(
x
:
BeamSearchSequence
)
->
float
:
return
get_beam_search_score
(
x
.
tokens
,
x
.
cum_logprob
,
tokenizer
.
eos_token_id
,
length_penalty
)
tokenizer
=
self
.
get_tokenizer
()
sort_beams_key
=
create_sort_beams_key_function
(
tokenizer
.
eos_token_id
,
length_penalty
,
)
def
create_tokens_prompt_from_beam
(
beam
:
BeamSearchSequence
)
->
TokensPrompt
:
...
...
@@ -593,7 +595,6 @@ class LLM:
"mm_processor_kwargs"
]
=
beam
.
mm_processor_kwargs
return
TokensPrompt
(
**
token_prompt_kwargs
)
tokenizer
=
self
.
get_tokenizer
()
# generate 2 * beam_width candidates at each step
# following the huggingface transformers implementation
# at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment