Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ef7865b4
Unverified
Commit
ef7865b4
authored
Oct 29, 2024
by
Zhong Qishuai
Committed by
GitHub
Oct 29, 2024
Browse files
[Frontend] re-enable multi-modality input in the new beam search implementation (#9427)
Signed-off-by: Qishuai Ferdinandzhong@gmail.com
parent
eae3d481
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
150 additions
and
40 deletions
+150
-40
tests/entrypoints/openai/test_vision.py
tests/entrypoints/openai/test_vision.py
+71
-0
vllm/beam_search.py
vllm/beam_search.py
+8
-1
vllm/engine/protocol.py
vllm/engine/protocol.py
+57
-31
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+2
-2
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+4
-3
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+7
-3
vllm/sampling_params.py
vllm/sampling_params.py
+1
-0
No files found.
tests/entrypoints/openai/test_vision.py
View file @
ef7865b4
...
@@ -107,6 +107,42 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
...
@@ -107,6 +107,42 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
assert
message
.
content
is
not
None
and
len
(
message
.
content
)
>=
0
assert
message
.
content
is
not
None
and
len
(
message
.
content
)
>=
0
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
@
pytest
.
mark
.
parametrize
(
"image_url"
,
TEST_IMAGE_URLS
)
async
def
test_single_chat_session_image_beamsearch
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
,
image_url
:
str
):
messages
=
[{
"role"
:
"user"
,
"content"
:
[
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
image_url
}
},
{
"type"
:
"text"
,
"text"
:
"What's in this image?"
},
],
}]
chat_completion
=
await
client
.
chat
.
completions
.
create
(
model
=
model_name
,
messages
=
messages
,
n
=
2
,
max_tokens
=
10
,
logprobs
=
True
,
top_logprobs
=
5
,
extra_body
=
dict
(
use_beam_search
=
True
))
assert
len
(
chat_completion
.
choices
)
==
2
assert
chat_completion
.
choices
[
0
].
message
.
content
!=
chat_completion
.
choices
[
1
].
message
.
content
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
@
pytest
.
mark
.
parametrize
(
"image_url"
,
TEST_IMAGE_URLS
)
@
pytest
.
mark
.
parametrize
(
"image_url"
,
TEST_IMAGE_URLS
)
...
@@ -162,6 +198,41 @@ async def test_single_chat_session_image_base64encoded(
...
@@ -162,6 +198,41 @@ async def test_single_chat_session_image_base64encoded(
assert
message
.
content
is
not
None
and
len
(
message
.
content
)
>=
0
assert
message
.
content
is
not
None
and
len
(
message
.
content
)
>=
0
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
@
pytest
.
mark
.
parametrize
(
"image_url"
,
TEST_IMAGE_URLS
)
async
def
test_single_chat_session_image_base64encoded_beamsearch
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
,
image_url
:
str
,
base64_encoded_image
:
Dict
[
str
,
str
]):
messages
=
[{
"role"
:
"user"
,
"content"
:
[
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
f
"data:image/jpeg;base64,
{
base64_encoded_image
[
image_url
]
}
"
}
},
{
"type"
:
"text"
,
"text"
:
"What's in this image?"
},
],
}]
chat_completion
=
await
client
.
chat
.
completions
.
create
(
model
=
model_name
,
messages
=
messages
,
n
=
2
,
max_tokens
=
10
,
extra_body
=
dict
(
use_beam_search
=
True
))
assert
len
(
chat_completion
.
choices
)
==
2
assert
chat_completion
.
choices
[
0
].
message
.
content
!=
chat_completion
.
choices
[
1
].
message
.
content
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
@
pytest
.
mark
.
parametrize
(
"image_url"
,
TEST_IMAGE_URLS
)
@
pytest
.
mark
.
parametrize
(
"image_url"
,
TEST_IMAGE_URLS
)
...
...
vllm/beam_search.py
View file @
ef7865b4
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Union
from
vllm.sequence
import
Logprob
from
vllm.sequence
import
Logprob
if
TYPE_CHECKING
:
from
vllm.multimodal
import
MultiModalDataDict
@
dataclass
@
dataclass
class
BeamSearchSequence
:
class
BeamSearchSequence
:
...
@@ -16,6 +19,10 @@ class BeamSearchSequence:
...
@@ -16,6 +19,10 @@ class BeamSearchSequence:
logprobs
:
List
[
Dict
[
int
,
Logprob
]]
logprobs
:
List
[
Dict
[
int
,
Logprob
]]
cum_logprob
:
float
=
0.0
cum_logprob
:
float
=
0.0
text
:
Optional
[
str
]
=
None
text
:
Optional
[
str
]
=
None
finish_reason
:
Optional
[
str
]
=
None
stop_reason
:
Union
[
int
,
str
,
None
]
=
None
multi_modal_data
:
Optional
[
"MultiModalDataDict"
]
=
None
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
@
dataclass
@
dataclass
...
...
vllm/engine/protocol.py
View file @
ef7865b4
...
@@ -6,6 +6,7 @@ from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
...
@@ -6,6 +6,7 @@ from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from
vllm.config
import
DecodingConfig
,
ModelConfig
from
vllm.config
import
DecodingConfig
,
ModelConfig
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.inputs.data
import
PromptType
,
TokensPrompt
from
vllm.inputs.data
import
PromptType
,
TokensPrompt
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
...
@@ -59,7 +60,8 @@ class EngineClient(ABC):
...
@@ -59,7 +60,8 @@ class EngineClient(ABC):
async
def
beam_search
(
async
def
beam_search
(
self
,
self
,
prompt
:
Union
[
str
,
List
[
int
]],
prompt
:
Union
[
PromptType
,
List
[
int
]],
model_config
:
ModelConfig
,
request_id
:
str
,
request_id
:
str
,
params
:
BeamSearchParams
,
params
:
BeamSearchParams
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
...
@@ -69,32 +71,40 @@ class EngineClient(ABC):
...
@@ -69,32 +71,40 @@ class EngineClient(ABC):
ignore_eos
=
params
.
ignore_eos
ignore_eos
=
params
.
ignore_eos
temperature
=
params
.
temperature
temperature
=
params
.
temperature
length_penalty
=
params
.
length_penalty
length_penalty
=
params
.
length_penalty
include_stop_str_in_output
=
params
.
include_stop_str_in_output
tokenizer
=
await
self
.
get_tokenizer
(
lora_request
=
None
)
tokenizer
=
await
self
.
get_tokenizer
()
if
isinstance
(
prompt
,
str
):
input_preprocessor
=
InputPreprocessor
(
model_config
,
tokenizer
)
tokenized_prompt
=
tokenizer
.
encode
(
prompt
)
prompt_text
=
prompt
(
prompt_text
,
prompt_token_ids
,
multi_modal_data
,
else
:
mm_processor_kwargs
)
=
input_preprocessor
.
_extract_prompt_components
(
tokenized_prompt
=
prompt
prompt
,
prompt_text
=
None
request_id
=
request_id
,
tokenized_length
=
len
(
tokenized_prompt
)
)
tokenized_length
=
len
(
prompt_token_ids
)
sort_beams_key
=
create_sort_beams_key_function
(
sort_beams_key
=
create_sort_beams_key_function
(
tokenizer
.
eos_token_id
,
length_penalty
)
tokenizer
.
eos_token_id
,
length_penalty
)
beam_search_params
=
SamplingParams
(
logprobs
=
2
*
beam_width
,
beam_search_params
=
SamplingParams
(
max_tokens
=
1
,
logprobs
=
2
*
beam_width
,
temperature
=
temperature
)
max_tokens
=
1
,
temperature
=
temperature
,
)
all_beams
=
[
all_beams
=
[
BeamSearchSequence
(
tokens
=
tokenized_prompt
,
BeamSearchSequence
(
tokens
=
prompt_token_ids
,
cum_logprob
=
0
,
logprobs
=
[],
logprobs
=
[],
cum_logprob
=
0
)
multi_modal_data
=
multi_modal_data
,
mm_processor_kwargs
=
mm_processor_kwargs
)
]
]
completed
=
[]
completed
=
[]
for
_
in
range
(
max_tokens
):
for
_
in
range
(
max_tokens
):
prompts_batch
=
[
prompts_batch
=
[
TokensPrompt
(
prompt_token_ids
=
beam
.
tokens
)
TokensPrompt
(
prompt_token_ids
=
beam
.
tokens
,
multi_modal_data
=
beam
.
multi_modal_data
,
mm_processor_kwargs
=
beam
.
mm_processor_kwargs
)
for
beam
in
all_beams
for
beam
in
all_beams
]
]
...
@@ -120,17 +130,31 @@ class EngineClient(ABC):
...
@@ -120,17 +130,31 @@ class EngineClient(ABC):
if
result
.
outputs
[
0
].
logprobs
is
not
None
:
if
result
.
outputs
[
0
].
logprobs
is
not
None
:
logprobs
=
result
.
outputs
[
0
].
logprobs
[
0
]
logprobs
=
result
.
outputs
[
0
].
logprobs
[
0
]
for
token_id
,
logprob_obj
in
logprobs
.
items
():
for
token_id
,
logprob_obj
in
logprobs
.
items
():
new_beam
=
BeamSearchSequence
(
tokens
=
current_beam
.
tokens
+
[
token_id
],
logprobs
=
current_beam
.
logprobs
+
[
logprobs
],
cum_logprob
=
current_beam
.
cum_logprob
+
logprob_obj
.
logprob
)
if
token_id
==
tokenizer
.
eos_token_id
and
\
if
token_id
==
tokenizer
.
eos_token_id
and
\
not
ignore_eos
:
not
ignore_eos
:
completed
.
append
(
new_beam
)
completed
.
append
(
BeamSearchSequence
(
tokens
=
current_beam
.
tokens
+
[
token_id
]
if
include_stop_str_in_output
else
current_beam
.
tokens
,
logprobs
=
current_beam
.
logprobs
+
[
logprobs
],
cum_logprob
=
current_beam
.
cum_logprob
+
logprob_obj
.
logprob
,
finish_reason
=
"stop"
,
stop_reason
=
tokenizer
.
eos_token_id
))
else
:
else
:
new_beams
.
append
(
new_beam
)
new_beams
.
append
(
BeamSearchSequence
(
tokens
=
current_beam
.
tokens
+
[
token_id
],
logprobs
=
current_beam
.
logprobs
+
[
logprobs
],
cum_logprob
=
current_beam
.
cum_logprob
+
logprob_obj
.
logprob
,
multi_modal_data
=
current_beam
.
multi_modal_data
,
mm_processor_kwargs
=
current_beam
.
mm_processor_kwargs
))
sorted_beams
=
sorted
(
new_beams
,
key
=
sort_beams_key
,
reverse
=
True
)
sorted_beams
=
sorted
(
new_beams
,
key
=
sort_beams_key
,
reverse
=
True
)
all_beams
=
sorted_beams
[:
beam_width
]
all_beams
=
sorted_beams
[:
beam_width
]
...
@@ -151,16 +175,18 @@ class EngineClient(ABC):
...
@@ -151,16 +175,18 @@ class EngineClient(ABC):
request_id
=
request_id
,
request_id
=
request_id
,
prompt
=
prompt_text
,
prompt
=
prompt_text
,
outputs
=
[
outputs
=
[
CompletionOutput
(
CompletionOutput
(
text
=
beam
.
text
,
text
=
beam
.
text
,
cumulative_logprob
=
beam
.
cum_logprob
,
cumulative_logprob
=
beam
.
cum_logprob
,
token_ids
=
beam
.
tokens
[
tokenized_length
:],
token_ids
=
beam
.
tokens
[
tokenized_length
:],
index
=
i
,
index
=
i
,
logprobs
=
beam
.
logprobs
,
logprobs
=
beam
.
logprobs
,
finish_reason
=
beam
.
finish_reason
if
)
for
(
i
,
beam
)
in
enumerate
(
best_beams
)
beam
.
finish_reason
is
not
None
else
"length"
,
stop_reason
=
beam
.
stop_reason
)
for
(
i
,
beam
)
in
enumerate
(
best_beams
)
],
],
finished
=
True
,
finished
=
True
,
prompt_token_ids
=
tokenized_prompt
,
prompt_token_ids
=
prompt_token_ids
,
prompt_logprobs
=
None
)
prompt_logprobs
=
None
)
yield
beam_search_output
yield
beam_search_output
...
...
vllm/entrypoints/openai/protocol.py
View file @
ef7865b4
...
@@ -308,7 +308,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -308,7 +308,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
ignore_eos
=
self
.
ignore_eos
,
ignore_eos
=
self
.
ignore_eos
,
temperature
=
temperature
,
temperature
=
temperature
,
length_penalty
=
self
.
length_penalty
,
length_penalty
=
self
.
length_penalty
,
)
include_stop_str_in_output
=
self
.
include_stop_str_in_output
)
def
to_sampling_params
(
self
,
default_max_tokens
:
int
)
->
SamplingParams
:
def
to_sampling_params
(
self
,
default_max_tokens
:
int
)
->
SamplingParams
:
max_tokens
=
self
.
max_tokens
max_tokens
=
self
.
max_tokens
...
@@ -606,7 +606,7 @@ class CompletionRequest(OpenAIBaseModel):
...
@@ -606,7 +606,7 @@ class CompletionRequest(OpenAIBaseModel):
ignore_eos
=
self
.
ignore_eos
,
ignore_eos
=
self
.
ignore_eos
,
temperature
=
temperature
,
temperature
=
temperature
,
length_penalty
=
self
.
length_penalty
,
length_penalty
=
self
.
length_penalty
,
)
include_stop_str_in_output
=
self
.
include_stop_str_in_output
)
def
to_sampling_params
(
self
,
default_max_tokens
:
int
)
->
SamplingParams
:
def
to_sampling_params
(
self
,
default_max_tokens
:
int
)
->
SamplingParams
:
max_tokens
=
self
.
max_tokens
max_tokens
=
self
.
max_tokens
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
ef7865b4
...
@@ -236,9 +236,10 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -236,9 +236,10 @@ class OpenAIServingChat(OpenAIServing):
if
isinstance
(
sampling_params
,
BeamSearchParams
):
if
isinstance
(
sampling_params
,
BeamSearchParams
):
result_generator
=
self
.
engine_client
.
beam_search
(
result_generator
=
self
.
engine_client
.
beam_search
(
engine_inputs
[
'prompt_token_ids'
],
prompt
=
engine_inputs
,
request_id
,
model_config
=
self
.
model_config
,
sampling_params
,
request_id
=
request_id
,
params
=
sampling_params
,
)
)
else
:
else
:
result_generator
=
self
.
engine_client
.
generate
(
result_generator
=
self
.
engine_client
.
generate
(
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
ef7865b4
...
@@ -150,9 +150,13 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -150,9 +150,13 @@ class OpenAIServingCompletion(OpenAIServing):
if
isinstance
(
sampling_params
,
BeamSearchParams
):
if
isinstance
(
sampling_params
,
BeamSearchParams
):
generator
=
self
.
engine_client
.
beam_search
(
generator
=
self
.
engine_client
.
beam_search
(
prompt_inputs
[
"prompt_token_ids"
],
prompt
=
{
request_id_item
,
"prompt_token_ids"
:
sampling_params
,
prompt_inputs
[
"prompt_token_ids"
]
},
model_config
=
self
.
model_config
,
request_id
=
request_id
,
params
=
sampling_params
,
)
)
else
:
else
:
generator
=
self
.
engine_client
.
generate
(
generator
=
self
.
engine_client
.
generate
(
...
...
vllm/sampling_params.py
View file @
ef7865b4
...
@@ -500,3 +500,4 @@ class BeamSearchParams(
...
@@ -500,3 +500,4 @@ class BeamSearchParams(
ignore_eos
:
bool
=
False
ignore_eos
:
bool
=
False
temperature
:
float
=
0.0
temperature
:
float
=
0.0
length_penalty
:
float
=
1.0
length_penalty
:
float
=
1.0
include_stop_str_in_output
:
bool
=
False
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment