"csrc/vscode:/vscode.git/clone" did not exist on "eb5741ad422f04d0bac60c9b6c07183e0431ce8c"
Unverified commit ef7865b4, authored by Zhong Qishuai, committed by GitHub
Browse files

[Frontend] re-enable multi-modality input in the new beam search implementation (#9427)

Signed-off-by: Qishuai <Ferdinandzhong@gmail.com>
parent eae3d481
...@@ -107,6 +107,42 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI, ...@@ -107,6 +107,42 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
assert message.content is not None and len(message.content) >= 0 assert message.content is not None and len(message.content) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI,
                                                    model_name: str,
                                                    image_url: str):
    """Check that beam search on an image+text chat prompt yields two
    choices whose message contents differ.

    The image is referenced by URL; the request asks for ``n=2`` results
    with logprobs enabled and ``use_beam_search`` passed via ``extra_body``.
    """
    image_part = {
        "type": "image_url",
        "image_url": {
            "url": image_url
        },
    }
    question_part = {
        "type": "text",
        "text": "What's in this image?",
    }
    messages = [{
        "role": "user",
        "content": [image_part, question_part],
    }]

    chat_completion = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        n=2,
        max_tokens=10,
        logprobs=True,
        top_logprobs=5,
        extra_body={"use_beam_search": True},
    )

    # Two results were requested, and they are expected to be different.
    choices = chat_completion.choices
    assert len(choices) == 2
    assert choices[0].message.content != choices[1].message.content
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
...@@ -162,6 +198,41 @@ async def test_single_chat_session_image_base64encoded( ...@@ -162,6 +198,41 @@ async def test_single_chat_session_image_base64encoded(
assert message.content is not None and len(message.content) >= 0 assert message.content is not None and len(message.content) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_single_chat_session_image_base64encoded_beamsearch(
        client: openai.AsyncOpenAI, model_name: str, image_url: str,
        base64_encoded_image: Dict[str, str]):
    """Check that beam search on a base64-embedded image prompt yields two
    choices whose message contents differ.

    The image payload is looked up in ``base64_encoded_image`` by
    ``image_url`` and sent inline as a ``data:image/jpeg;base64,...`` URL.
    """
    data_url = f"data:image/jpeg;base64,{base64_encoded_image[image_url]}"
    image_part = {
        "type": "image_url",
        "image_url": {
            "url": data_url
        },
    }
    question_part = {
        "type": "text",
        "text": "What's in this image?",
    }
    messages = [{
        "role": "user",
        "content": [image_part, question_part],
    }]

    chat_completion = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        n=2,
        max_tokens=10,
        extra_body={"use_beam_search": True},
    )

    # Two results were requested, and they are expected to be different.
    choices = chat_completion.choices
    assert len(choices) == 2
    assert choices[0].message.content != choices[1].message.content
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
......
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from vllm.sequence import Logprob from vllm.sequence import Logprob
if TYPE_CHECKING:
from vllm.multimodal import MultiModalDataDict
@dataclass @dataclass
class BeamSearchSequence: class BeamSearchSequence:
...@@ -16,6 +19,10 @@ class BeamSearchSequence: ...@@ -16,6 +19,10 @@ class BeamSearchSequence:
logprobs: List[Dict[int, Logprob]] logprobs: List[Dict[int, Logprob]]
cum_logprob: float = 0.0 cum_logprob: float = 0.0
text: Optional[str] = None text: Optional[str] = None
finish_reason: Optional[str] = None
stop_reason: Union[int, str, None] = None
multi_modal_data: Optional["MultiModalDataDict"] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None
@dataclass @dataclass
......
...@@ -6,6 +6,7 @@ from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function ...@@ -6,6 +6,7 @@ from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import DecodingConfig, ModelConfig from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptType, TokensPrompt from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
...@@ -59,7 +60,8 @@ class EngineClient(ABC): ...@@ -59,7 +60,8 @@ class EngineClient(ABC):
async def beam_search( async def beam_search(
self, self,
prompt: Union[str, List[int]], prompt: Union[PromptType, List[int]],
model_config: ModelConfig,
request_id: str, request_id: str,
params: BeamSearchParams, params: BeamSearchParams,
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
...@@ -69,32 +71,40 @@ class EngineClient(ABC): ...@@ -69,32 +71,40 @@ class EngineClient(ABC):
ignore_eos = params.ignore_eos ignore_eos = params.ignore_eos
temperature = params.temperature temperature = params.temperature
length_penalty = params.length_penalty length_penalty = params.length_penalty
include_stop_str_in_output = params.include_stop_str_in_output
tokenizer = await self.get_tokenizer(lora_request=None) tokenizer = await self.get_tokenizer()
if isinstance(prompt, str): input_preprocessor = InputPreprocessor(model_config, tokenizer)
tokenized_prompt = tokenizer.encode(prompt)
prompt_text = prompt (prompt_text, prompt_token_ids, multi_modal_data,
else: mm_processor_kwargs) = input_preprocessor._extract_prompt_components(
tokenized_prompt = prompt prompt,
prompt_text = None request_id=request_id,
tokenized_length = len(tokenized_prompt) )
tokenized_length = len(prompt_token_ids)
sort_beams_key = create_sort_beams_key_function( sort_beams_key = create_sort_beams_key_function(
tokenizer.eos_token_id, length_penalty) tokenizer.eos_token_id, length_penalty)
beam_search_params = SamplingParams(logprobs=2 * beam_width, beam_search_params = SamplingParams(
logprobs=2 * beam_width,
max_tokens=1, max_tokens=1,
temperature=temperature) temperature=temperature,
)
all_beams = [ all_beams = [
BeamSearchSequence(tokens=tokenized_prompt, BeamSearchSequence(tokens=prompt_token_ids,
cum_logprob=0,
logprobs=[], logprobs=[],
cum_logprob=0) multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs)
] ]
completed = [] completed = []
for _ in range(max_tokens): for _ in range(max_tokens):
prompts_batch = [ prompts_batch = [
TokensPrompt(prompt_token_ids=beam.tokens) TokensPrompt(prompt_token_ids=beam.tokens,
multi_modal_data=beam.multi_modal_data,
mm_processor_kwargs=beam.mm_processor_kwargs)
for beam in all_beams for beam in all_beams
] ]
...@@ -120,17 +130,31 @@ class EngineClient(ABC): ...@@ -120,17 +130,31 @@ class EngineClient(ABC):
if result.outputs[0].logprobs is not None: if result.outputs[0].logprobs is not None:
logprobs = result.outputs[0].logprobs[0] logprobs = result.outputs[0].logprobs[0]
for token_id, logprob_obj in logprobs.items(): for token_id, logprob_obj in logprobs.items():
new_beam = BeamSearchSequence(
tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs + [logprobs],
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob)
if token_id == tokenizer.eos_token_id and \ if token_id == tokenizer.eos_token_id and \
not ignore_eos: not ignore_eos:
completed.append(new_beam) completed.append(
BeamSearchSequence(
tokens=current_beam.tokens +
[token_id] if include_stop_str_in_output
else current_beam.tokens,
logprobs=current_beam.logprobs +
[logprobs],
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob,
finish_reason="stop",
stop_reason=tokenizer.eos_token_id))
else: else:
new_beams.append(new_beam) new_beams.append(
BeamSearchSequence(
tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs +
[logprobs],
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob,
multi_modal_data=current_beam.
multi_modal_data,
mm_processor_kwargs=current_beam.
mm_processor_kwargs))
sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True) sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
all_beams = sorted_beams[:beam_width] all_beams = sorted_beams[:beam_width]
...@@ -151,16 +175,18 @@ class EngineClient(ABC): ...@@ -151,16 +175,18 @@ class EngineClient(ABC):
request_id=request_id, request_id=request_id,
prompt=prompt_text, prompt=prompt_text,
outputs=[ outputs=[
CompletionOutput( CompletionOutput(text=beam.text,
text=beam.text,
cumulative_logprob=beam.cum_logprob, cumulative_logprob=beam.cum_logprob,
token_ids=beam.tokens[tokenized_length:], token_ids=beam.tokens[tokenized_length:],
index=i, index=i,
logprobs=beam.logprobs, logprobs=beam.logprobs,
) for (i, beam) in enumerate(best_beams) finish_reason=beam.finish_reason if
beam.finish_reason is not None else "length",
stop_reason=beam.stop_reason)
for (i, beam) in enumerate(best_beams)
], ],
finished=True, finished=True,
prompt_token_ids=tokenized_prompt, prompt_token_ids=prompt_token_ids,
prompt_logprobs=None) prompt_logprobs=None)
yield beam_search_output yield beam_search_output
......
...@@ -308,7 +308,7 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -308,7 +308,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
ignore_eos=self.ignore_eos, ignore_eos=self.ignore_eos,
temperature=temperature, temperature=temperature,
length_penalty=self.length_penalty, length_penalty=self.length_penalty,
) include_stop_str_in_output=self.include_stop_str_in_output)
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens max_tokens = self.max_tokens
...@@ -606,7 +606,7 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -606,7 +606,7 @@ class CompletionRequest(OpenAIBaseModel):
ignore_eos=self.ignore_eos, ignore_eos=self.ignore_eos,
temperature=temperature, temperature=temperature,
length_penalty=self.length_penalty, length_penalty=self.length_penalty,
) include_stop_str_in_output=self.include_stop_str_in_output)
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens max_tokens = self.max_tokens
......
...@@ -236,9 +236,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -236,9 +236,10 @@ class OpenAIServingChat(OpenAIServing):
if isinstance(sampling_params, BeamSearchParams): if isinstance(sampling_params, BeamSearchParams):
result_generator = self.engine_client.beam_search( result_generator = self.engine_client.beam_search(
engine_inputs['prompt_token_ids'], prompt=engine_inputs,
request_id, model_config=self.model_config,
sampling_params, request_id=request_id,
params=sampling_params,
) )
else: else:
result_generator = self.engine_client.generate( result_generator = self.engine_client.generate(
......
...@@ -150,9 +150,13 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -150,9 +150,13 @@ class OpenAIServingCompletion(OpenAIServing):
if isinstance(sampling_params, BeamSearchParams): if isinstance(sampling_params, BeamSearchParams):
generator = self.engine_client.beam_search( generator = self.engine_client.beam_search(
prompt_inputs["prompt_token_ids"], prompt={
request_id_item, "prompt_token_ids":
sampling_params, prompt_inputs["prompt_token_ids"]
},
model_config=self.model_config,
request_id=request_id,
params=sampling_params,
) )
else: else:
generator = self.engine_client.generate( generator = self.engine_client.generate(
......
...@@ -500,3 +500,4 @@ class BeamSearchParams( ...@@ -500,3 +500,4 @@ class BeamSearchParams(
ignore_eos: bool = False ignore_eos: bool = False
temperature: float = 0.0 temperature: float = 0.0
length_penalty: float = 1.0 length_penalty: float = 1.0
include_stop_str_in_output: bool = False
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment