Commit 466166dc (unverified), authored by NekoMimiUnagi, committed by GitHub
Browse files

[Frontend] Add optional token-level progress bar to `LLM.beam_search` (#19301)


Signed-off-by: Ruosen Li <rxl190028@utdallas.edu>
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Signed-off-by: Ubuntu <ubuntu@ip-172-31-71-179.ec2.internal>
Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
parent 1d0ae26c
...@@ -552,6 +552,7 @@ class LLM: ...@@ -552,6 +552,7 @@ class LLM:
prompts: list[Union[TokensPrompt, TextPrompt]], prompts: list[Union[TokensPrompt, TextPrompt]],
params: BeamSearchParams, params: BeamSearchParams,
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
use_tqdm: bool = False,
) -> list[BeamSearchOutput]: ) -> list[BeamSearchOutput]:
""" """
Generate sequences using beam search. Generate sequences using beam search.
...@@ -561,6 +562,7 @@ class LLM: ...@@ -561,6 +562,7 @@ class LLM:
of token IDs. of token IDs.
params: The beam search parameters. params: The beam search parameters.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
use_tqdm: Whether to use tqdm to display the progress bar.
""" """
# TODO: how does beam search work together with length penalty, # TODO: how does beam search work together with length penalty,
# frequency, penalty, and stopping criteria, etc.? # frequency, penalty, and stopping criteria, etc.?
...@@ -623,7 +625,18 @@ class LLM: ...@@ -623,7 +625,18 @@ class LLM:
**mm_kwargs, **mm_kwargs,
), ) ), )
for _ in range(max_tokens): token_iter = range(max_tokens)
if use_tqdm:
token_iter = tqdm(token_iter,
desc="Beam search",
unit="token",
unit_scale=False)
logger.warning(
"The progress bar shows the upper bound on token steps and "
"may finish early due to stopping conditions. It does not "
"reflect instance-level progress.")
for _ in token_iter:
all_beams: list[BeamSearchSequence] = list( all_beams: list[BeamSearchSequence] = list(
sum((instance.beams for instance in instances), [])) sum((instance.beams for instance in instances), []))
pos = [0] + list( pos = [0] + list(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment