Commit 1ffc8a73 (unverified), authored by Nick Hill, committed by GitHub
Browse files

[BugFix] Typing fixes to RequestOutput.prompt and beam search (#9473)

parent 944dd8ed
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional from typing import Dict, List, Optional
from vllm.sequence import Logprob
@dataclass @dataclass
...@@ -11,6 +13,7 @@ class BeamSearchSequence: ...@@ -11,6 +13,7 @@ class BeamSearchSequence:
""" """
# The tokens include the prompt. # The tokens include the prompt.
tokens: List[int] tokens: List[int]
logprobs: List[Dict[int, Logprob]]
cum_logprob: float = 0.0 cum_logprob: float = 0.0
text: Optional[str] = None text: Optional[str] = None
...@@ -28,7 +31,7 @@ class BeamSearchInstance: ...@@ -28,7 +31,7 @@ class BeamSearchInstance:
def __init__(self, prompt_tokens: List[int]): def __init__(self, prompt_tokens: List[int]):
self.beams: List[BeamSearchSequence] = [ self.beams: List[BeamSearchSequence] = [
BeamSearchSequence(tokens=prompt_tokens) BeamSearchSequence(tokens=prompt_tokens, logprobs=[])
] ]
self.completed: List[BeamSearchSequence] = [] self.completed: List[BeamSearchSequence] = []
......
...@@ -59,7 +59,7 @@ class EngineClient(ABC): ...@@ -59,7 +59,7 @@ class EngineClient(ABC):
async def beam_search( async def beam_search(
self, self,
prompt: Union[PromptType, List[int]], prompt: Union[str, List[int]],
request_id: str, request_id: str,
params: BeamSearchParams, params: BeamSearchParams,
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
...@@ -71,9 +71,13 @@ class EngineClient(ABC): ...@@ -71,9 +71,13 @@ class EngineClient(ABC):
length_penalty = params.length_penalty length_penalty = params.length_penalty
tokenizer = await self.get_tokenizer(lora_request=None) tokenizer = await self.get_tokenizer(lora_request=None)
tokenizedPrompt = prompt if isinstance( if isinstance(prompt, str):
prompt, list) else tokenizer.encode(prompt) tokenized_prompt = tokenizer.encode(prompt)
tokenizedLength = len(tokenizedPrompt) prompt_text = prompt
else:
tokenized_prompt = prompt
prompt_text = None
tokenized_length = len(tokenized_prompt)
sort_beams_key = create_sort_beams_key_function( sort_beams_key = create_sort_beams_key_function(
tokenizer.eos_token_id, length_penalty) tokenizer.eos_token_id, length_penalty)
...@@ -81,7 +85,11 @@ class EngineClient(ABC): ...@@ -81,7 +85,11 @@ class EngineClient(ABC):
beam_search_params = SamplingParams(logprobs=2 * beam_width, beam_search_params = SamplingParams(logprobs=2 * beam_width,
max_tokens=1, max_tokens=1,
temperature=temperature) temperature=temperature)
all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)] all_beams = [
BeamSearchSequence(tokens=tokenized_prompt,
logprobs=[],
cum_logprob=0)
]
completed = [] completed = []
for _ in range(max_tokens): for _ in range(max_tokens):
...@@ -114,6 +122,7 @@ class EngineClient(ABC): ...@@ -114,6 +122,7 @@ class EngineClient(ABC):
for token_id, logprob_obj in logprobs.items(): for token_id, logprob_obj in logprobs.items():
new_beam = BeamSearchSequence( new_beam = BeamSearchSequence(
tokens=current_beam.tokens + [token_id], tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs + [logprobs],
cum_logprob=current_beam.cum_logprob + cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob) logprob_obj.logprob)
...@@ -131,22 +140,22 @@ class EngineClient(ABC): ...@@ -131,22 +140,22 @@ class EngineClient(ABC):
best_beams = sorted_completed[:beam_width] best_beams = sorted_completed[:beam_width]
for beam in best_beams: for beam in best_beams:
beam.text = tokenizer.decode(beam.tokens[tokenizedLength:]) beam.text = tokenizer.decode(beam.tokens[tokenized_length:])
beam_search_output = RequestOutput( beam_search_output = RequestOutput(
request_id=request_id, request_id=request_id,
prompt=prompt, prompt=prompt_text,
outputs=[ outputs=[
CompletionOutput( CompletionOutput(
text=beam.text, text=beam.text,
cumulative_logprob=beam.cum_logprob, cumulative_logprob=beam.cum_logprob,
token_ids=beam.tokens, token_ids=beam.tokens[tokenized_length:],
index=i, index=i,
logprobs=beam.cum_logprob, logprobs=beam.logprobs,
) for (i, beam) in enumerate(best_beams) ) for (i, beam) in enumerate(best_beams)
], ],
finished=True, finished=True,
prompt_token_ids=tokenizedPrompt, prompt_token_ids=tokenized_prompt,
prompt_logprobs=None) prompt_logprobs=None)
yield beam_search_output yield beam_search_output
......
...@@ -433,6 +433,7 @@ class LLM: ...@@ -433,6 +433,7 @@ class LLM:
for token_id, logprob_obj in logprobs.items(): for token_id, logprob_obj in logprobs.items():
new_beam = BeamSearchSequence( new_beam = BeamSearchSequence(
tokens=current_beam.tokens + [token_id], tokens=current_beam.tokens + [token_id],
logprobs=current_beam.logprobs + [logprobs],
cum_logprob=current_beam.cum_logprob + cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob) logprob_obj.logprob)
......
...@@ -4,7 +4,6 @@ from typing import List, Optional ...@@ -4,7 +4,6 @@ from typing import List, Optional
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Union from typing import Union
from vllm.inputs import PromptType
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sampling_params import RequestOutputKind from vllm.sampling_params import RequestOutputKind
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
...@@ -93,7 +92,7 @@ class RequestOutput: ...@@ -93,7 +92,7 @@ class RequestOutput:
def __init__( def __init__(
self, self,
request_id: str, request_id: str,
prompt: Optional[PromptType], prompt: Optional[str],
prompt_token_ids: Optional[List[int]], prompt_token_ids: Optional[List[int]],
prompt_logprobs: Optional[PromptLogprobs], prompt_logprobs: Optional[PromptLogprobs],
outputs: List[CompletionOutput], outputs: List[CompletionOutput],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment