Unverified commit 56f45edd, authored by rookie, committed by GitHub
Browse files

[Frontend] Optimize beam search loop by sorting and then splicing (#19347)


Signed-off-by: zhangguozhu <zhangguozhu@360.cn>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: zhangguozhu <zhangguozhu@360.cn>
Co-authored-by: mgoin <mgoin64@gmail.com>
parent 82b05b15
...@@ -10,6 +10,7 @@ from concurrent.futures import ThreadPoolExecutor ...@@ -10,6 +10,7 @@ from concurrent.futures import ThreadPoolExecutor
from http import HTTPStatus from http import HTTPStatus
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
import numpy as np
import torch import torch
from fastapi import Request from fastapi import Request
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
...@@ -389,8 +390,9 @@ class OpenAIServing: ...@@ -389,8 +390,9 @@ class OpenAIServing:
sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty) sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
logprobs_num = 2 * beam_width
beam_search_params = SamplingParams( beam_search_params = SamplingParams(
logprobs=2 * beam_width, logprobs=logprobs_num,
max_tokens=1, max_tokens=1,
temperature=temperature, temperature=temperature,
) )
...@@ -443,40 +445,75 @@ class OpenAIServing: ...@@ -443,40 +445,75 @@ class OpenAIServing:
output = [x[0] for x in await asyncio.gather(*tasks)] output = [x[0] for x in await asyncio.gather(*tasks)]
new_beams = [] new_beams = []
for i, current_beam in enumerate(all_beams): # Store all new tokens generated by beam
result = output[i] all_beams_token_id = []
# Store the cumulative probability of all tokens
# generated by beam search
all_beams_logprob = []
# Iterate through all beam inference results
for i, result in enumerate(output):
current_beam = all_beams[i]
if result.outputs[0].logprobs is not None: if result.outputs[0].logprobs is not None:
logprobs = result.outputs[0].logprobs[0] logprobs = result.outputs[0].logprobs[0]
for token_id, logprob_obj in logprobs.items(): all_beams_token_id.extend(list(logprobs.keys()))
if token_id == eos_token_id and not ignore_eos: all_beams_logprob.extend(
[
current_beam.cum_logprob + obj.logprob
for obj in logprobs.values()
]
)
# Handle the token for the end of sentence (EOS)
all_beams_token_id = np.array(all_beams_token_id)
all_beams_logprob = np.array(all_beams_logprob)
if not ignore_eos:
# Get the index position of eos token in all generated results
eos_idx = np.where(all_beams_token_id == eos_token_id)[0]
for idx in eos_idx:
current_beam = all_beams[idx // logprobs_num]
result = output[idx // logprobs_num]
assert result.outputs[0].logprobs is not None
logprobs_entry = result.outputs[0].logprobs[0]
completed.append( completed.append(
BeamSearchSequence( BeamSearchSequence(
tokens=current_beam.tokens + [token_id] tokens=current_beam.tokens + [eos_token_id]
if include_stop_str_in_output if include_stop_str_in_output
else current_beam.tokens, else current_beam.tokens,
logprobs=current_beam.logprobs + [logprobs], logprobs=current_beam.logprobs + [logprobs_entry],
cum_logprob=current_beam.cum_logprob cum_logprob=float(all_beams_logprob[idx]),
+ logprob_obj.logprob,
finish_reason="stop", finish_reason="stop",
stop_reason=eos_token_id, stop_reason=eos_token_id,
) )
) )
else: # After processing, set the log probability of the eos condition
# to negative infinity.
all_beams_logprob[eos_idx] = -np.inf
# Processing non-EOS tokens
# Get indices of the top beam_width probabilities
topn_idx = np.argpartition(np.negative(all_beams_logprob), beam_width)[
:beam_width
]
for idx in topn_idx:
current_beam = all_beams[idx // logprobs_num]
result = output[idx // logprobs_num]
token_id = int(all_beams_token_id[idx])
assert result.outputs[0].logprobs is not None
logprobs_entry = result.outputs[0].logprobs[0]
new_beams.append( new_beams.append(
BeamSearchSequence( BeamSearchSequence(
tokens=current_beam.tokens + [token_id], tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs + [logprobs], logprobs=current_beam.logprobs + [logprobs_entry],
lora_request=current_beam.lora_request, lora_request=current_beam.lora_request,
cum_logprob=current_beam.cum_logprob cum_logprob=float(all_beams_logprob[idx]),
+ logprob_obj.logprob,
multi_modal_data=current_beam.multi_modal_data, multi_modal_data=current_beam.multi_modal_data,
mm_processor_kwargs=current_beam.mm_processor_kwargs, mm_processor_kwargs=current_beam.mm_processor_kwargs,
) )
) )
sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True) all_beams = new_beams
all_beams = sorted_beams[:beam_width]
completed.extend(all_beams) completed.extend(all_beams)
sorted_completed = sorted(completed, key=sort_beams_key, reverse=True) sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment