"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "3f6a74723fca58d0fa55ffa3bcb6abb1fd4f2e4b"
Unverified commit 142ac080, authored by Chen Zhang, committed by GitHub
Browse files

[Frontend] Optimize beam search performance by limiting concurrency (#23599)


Signed-off-by: Chen Zhang <zhangch99@outlook.com>
parent 32102644
...@@ -96,7 +96,6 @@ def run_vllm( ...@@ -96,7 +96,6 @@ def run_vllm(
end = time.perf_counter() end = time.perf_counter()
else: else:
assert lora_requests is None, "BeamSearch API does not support LoRA" assert lora_requests is None, "BeamSearch API does not support LoRA"
prompts = [request.prompt for request in requests]
# output_len should be the same for all requests. # output_len should be the same for all requests.
output_len = requests[0].expected_output_len output_len = requests[0].expected_output_len
for request in requests: for request in requests:
......
...@@ -1022,15 +1022,17 @@ class VllmRunner: ...@@ -1022,15 +1022,17 @@ class VllmRunner:
images: Optional[PromptImageInput] = None, images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None, videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None, audios: Optional[PromptAudioInput] = None,
concurrency_limit: Optional[int] = None,
) -> list[tuple[list[list[int]], list[str]]]: ) -> list[tuple[list[list[int]], list[str]]]:
inputs = self.get_inputs(prompts, inputs = self.get_inputs(prompts,
images=images, images=images,
videos=videos, videos=videos,
audios=audios) audios=audios)
outputs = self.llm.beam_search( outputs = self.llm.beam_search(inputs,
inputs, BeamSearchParams(beam_width=beam_width,
BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens)) max_tokens=max_tokens),
concurrency_limit=concurrency_limit)
returned_outputs = [] returned_outputs = []
for output in outputs: for output in outputs:
token_ids = [x.tokens for x in output.sequences] token_ids = [x.tokens for x in output.sequences]
......
...@@ -67,6 +67,59 @@ def test_beam_search_single_input( ...@@ -67,6 +67,59 @@ def test_beam_search_single_input(
f"vLLM: {vllm_output_ids}") f"vLLM: {vllm_output_ids}")
@pytest.mark.skip_v1  # FIXME: This fails on V1 right now.
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
@pytest.mark.parametrize("beam_width", BEAM_WIDTHS)
def test_beam_search_with_concurrency_limit(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    beam_width: int,
) -> None:
    """Check that a concurrency limit does not change beam search results.

    The prompts are run once in a single ``generate_beam_search`` call with
    ``concurrency_limit=2``, and once more as separate calls of two prompts
    each with no limit. The token ids of every returned beam must match
    between the two runs; the decoded texts are printed for inspection only.
    """
    # NOTE(review): prompts 1, 3 and 7 were reported to fail for an unknown
    # reason even without a concurrency limit, yet this slice keeps the first
    # eight prompts and therefore still includes them -- confirm the intent.
    example_prompts = example_prompts[:8]
    concurrency_limit = 2
    # The comparison is only meaningful if the limit actually splits the work.
    assert len(example_prompts) > concurrency_limit

    with vllm_runner(model, dtype=dtype) as vllm_model:
        outputs_with_limit = vllm_model.generate_beam_search(
            example_prompts,
            beam_width,
            max_tokens,
            concurrency_limit=concurrency_limit,
        )

        # Reference run: feed the prompts manually in chunks of the same size
        # as the limit, each chunk as its own unlimited call.
        outputs_without_limit = []
        chunk_starts = range(0, len(example_prompts), concurrency_limit)
        for chunk_start in chunk_starts:
            chunk_end = chunk_start + concurrency_limit
            chunk_outputs = vllm_model.generate_beam_search(
                example_prompts[chunk_start:chunk_end],
                beam_width,
                max_tokens,
            )
            outputs_without_limit.extend(chunk_outputs)

    # Collect every mismatch before failing so the log shows all of them.
    mismatch_found = False
    for i, _ in enumerate(example_prompts):
        output_ids_with_limit, texts_limited = outputs_with_limit[i]
        output_ids_without_limit, texts_unlimited = outputs_without_limit[i]

        paired_texts = zip(texts_limited, texts_unlimited)
        for j, (limited_text, unlimited_text) in enumerate(paired_texts):
            print(f">>>{j}-th with limit output:")
            print(limited_text)
            print(f">>>{j}-th without limit output:")
            print(unlimited_text)

        assert len(output_ids_with_limit) == len(output_ids_without_limit)
        paired_ids = zip(output_ids_with_limit, output_ids_without_limit)
        for j, (limited_ids, unlimited_ids) in enumerate(paired_ids):
            if limited_ids == unlimited_ids:
                continue
            print(f"Test{i} output{j}:\n+limit: {output_ids_with_limit}\n"
                  f"-limit: {output_ids_without_limit}")
            mismatch_found = True

    assert not mismatch_found
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", MAX_TOKENS) @pytest.mark.parametrize("max_tokens", MAX_TOKENS)
@pytest.mark.parametrize("beam_width", MM_BEAM_WIDTHS) @pytest.mark.parametrize("beam_width", MM_BEAM_WIDTHS)
......
...@@ -523,6 +523,7 @@ class LLM: ...@@ -523,6 +523,7 @@ class LLM:
params: BeamSearchParams, params: BeamSearchParams,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
use_tqdm: bool = False, use_tqdm: bool = False,
concurrency_limit: Optional[int] = None,
) -> list[BeamSearchOutput]: ) -> list[BeamSearchOutput]:
""" """
Generate sequences using beam search. Generate sequences using beam search.
...@@ -533,6 +534,8 @@ class LLM: ...@@ -533,6 +534,8 @@ class LLM:
params: The beam search parameters. params: The beam search parameters.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
use_tqdm: Whether to use tqdm to display the progress bar. use_tqdm: Whether to use tqdm to display the progress bar.
concurrency_limit: The maximum number of concurrent requests.
If None, the number of concurrent requests is unlimited.
""" """
# TODO: how does beam search work together with length penalty, # TODO: how does beam search work together with length penalty,
# frequency, penalty, and stopping criteria, etc.? # frequency, penalty, and stopping criteria, etc.?
...@@ -551,6 +554,15 @@ class LLM: ...@@ -551,6 +554,15 @@ class LLM:
length_penalty, length_penalty,
) )
if use_tqdm and concurrency_limit is not None:
logger.warning(
"Progress bar is not supported when using concurrency_limit. "
"Disabling progress bar.")
use_tqdm = False
if concurrency_limit is None:
concurrency_limit = len(prompts)
def create_tokens_prompt_from_beam( def create_tokens_prompt_from_beam(
beam: BeamSearchSequence) -> TokensPrompt: beam: BeamSearchSequence) -> TokensPrompt:
token_prompt_kwargs: TokensPrompt = { token_prompt_kwargs: TokensPrompt = {
...@@ -595,73 +607,79 @@ class LLM: ...@@ -595,73 +607,79 @@ class LLM:
**mm_kwargs, **mm_kwargs,
), ) ), )
token_iter = range(max_tokens) for prompt_start in range(0, len(prompts), concurrency_limit):
if use_tqdm: instances_batch = instances[prompt_start:prompt_start +
token_iter = tqdm(token_iter, concurrency_limit]
desc="Beam search",
unit="token", token_iter = range(max_tokens)
unit_scale=False) if use_tqdm:
logger.warning( token_iter = tqdm(token_iter,
"The progress bar shows the upper bound on token steps and " desc="Beam search",
"may finish early due to stopping conditions. It does not " unit="token",
"reflect instance-level progress.") unit_scale=False)
logger.warning(
for _ in token_iter: "The progress bar shows the upper bound on token steps and "
all_beams: list[BeamSearchSequence] = list( "may finish early due to stopping conditions. It does not "
sum((instance.beams for instance in instances), [])) "reflect instance-level progress.")
pos = [0] + list( for _ in token_iter:
itertools.accumulate( all_beams: list[BeamSearchSequence] = list(
len(instance.beams) for instance in instances)) sum((instance.beams for instance in instances_batch), []))
instance_start_and_end: list[tuple[int, int]] = list( pos = [0] + list(
zip(pos[:-1], pos[1:])) itertools.accumulate(
len(instance.beams) for instance in instances_batch))
if len(all_beams) == 0: instance_start_and_end: list[tuple[int, int]] = list(
break zip(pos[:-1], pos[1:]))
# create the corresponding batch entries for prompt & optional lora if len(all_beams) == 0:
prompts_batch, lora_req_batch = zip( break
*[(create_tokens_prompt_from_beam(beam), beam.lora_request)
for beam in all_beams]) # create corresponding batch entries for prompt & optional lora
prompts_batch, lora_req_batch = zip(
# only runs for one step *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
# we don't need to use tqdm here for beam in all_beams])
output = self.generate(prompts_batch,
sampling_params=beam_search_params, # only runs for one step
use_tqdm=False, # we don't need to use tqdm here
lora_request=lora_req_batch) output = self.generate(prompts_batch,
sampling_params=beam_search_params,
for (start, end), instance in zip(instance_start_and_end, use_tqdm=False,
instances): lora_request=lora_req_batch)
instance_new_beams = []
for i in range(start, end): for (start, end), instance in zip(instance_start_and_end,
current_beam = all_beams[i] instances_batch):
result = output[i] instance_new_beams = []
for i in range(start, end):
if result.outputs[0].logprobs is not None: current_beam = all_beams[i]
# if `result.outputs[0].logprobs` is None, it means result = output[i]
# the sequence is completed because of the max-model-len
# or abortion. we don't need to add it to the new beams. if result.outputs[0].logprobs is not None:
logprobs = result.outputs[0].logprobs[0] # if `result.outputs[0].logprobs` is None, it means
for token_id, logprob_obj in logprobs.items(): # the sequence is completed because of the
new_beam = BeamSearchSequence( # max-model-len or abortion. we don't need to add
tokens=current_beam.tokens + [token_id], # it to the new beams.
logprobs=current_beam.logprobs + [logprobs], logprobs = result.outputs[0].logprobs[0]
lora_request=current_beam.lora_request, for token_id, logprob_obj in logprobs.items():
cum_logprob=current_beam.cum_logprob + new_beam = BeamSearchSequence(
logprob_obj.logprob, tokens=current_beam.tokens + [token_id],
multi_modal_data=current_beam.multi_modal_data, logprobs=current_beam.logprobs +
mm_processor_kwargs=current_beam. [logprobs],
mm_processor_kwargs) lora_request=current_beam.lora_request,
cum_logprob=current_beam.cum_logprob +
if token_id == tokenizer.eos_token_id and \ logprob_obj.logprob,
not ignore_eos: multi_modal_data=current_beam.
instance.completed.append(new_beam) multi_modal_data,
else: mm_processor_kwargs=current_beam.
instance_new_beams.append(new_beam) mm_processor_kwargs)
sorted_beams = sorted(instance_new_beams,
key=sort_beams_key, if token_id == tokenizer.eos_token_id and \
reverse=True) not ignore_eos:
instance.beams = sorted_beams[:beam_width] instance.completed.append(new_beam)
else:
instance_new_beams.append(new_beam)
sorted_beams = sorted(instance_new_beams,
key=sort_beams_key,
reverse=True)
instance.beams = sorted_beams[:beam_width]
outputs = [] outputs = []
for instance in instances: for instance in instances:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment