Unverified Commit 466166dc authored by NekoMimiUnagi's avatar NekoMimiUnagi Committed by GitHub
Browse files

[Frontend] Add optional token-level progress bar to `LLM.beam_search` (#19301)


Signed-off-by: default avatarRuosen Li <rxl190028@utdallas.edu>
Signed-off-by: default avatarAaron Pham <contact@aarnphm.xyz>
Signed-off-by: default avatarUbuntu <ubuntu@ip-172-31-71-179.ec2.internal>
Co-authored-by: default avatar22quinn <33176974+22quinn@users.noreply.github.com>
parent 1d0ae26c
......@@ -552,6 +552,7 @@ class LLM:
prompts: list[Union[TokensPrompt, TextPrompt]],
params: BeamSearchParams,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
use_tqdm: bool = False,
) -> list[BeamSearchOutput]:
"""
Generate sequences using beam search.
......@@ -561,6 +562,7 @@ class LLM:
of token IDs.
params: The beam search parameters.
lora_request: LoRA request to use for generation, if any.
use_tqdm: Whether to use tqdm to display the progress bar.
"""
# TODO: how does beam search work together with length penalty,
# frequency, penalty, and stopping criteria, etc.?
......@@ -623,7 +625,18 @@ class LLM:
**mm_kwargs,
), )
for _ in range(max_tokens):
token_iter = range(max_tokens)
if use_tqdm:
token_iter = tqdm(token_iter,
desc="Beam search",
unit="token",
unit_scale=False)
logger.warning(
"The progress bar shows the upper bound on token steps and "
"may finish early due to stopping conditions. It does not "
"reflect instance-level progress.")
for _ in token_iter:
all_beams: list[BeamSearchSequence] = list(
sum((instance.beams for instance in instances), []))
pos = [0] + list(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment