"docs/vscode:/vscode.git/clone" did not exist on "e23705e5577387872dd55ebf6db81bd59df928f1"
Unverified commit a6ca736c, authored by Lianmin Zheng, committed by GitHub

Simplify stream_output (#2398)

parent f62055b5
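
Summary: this commit replaces the per-request `meta_info` dicts and the tuple-style logprob lists that used to flow from the scheduler to the detokenizer and tokenizer manager with flat, column-oriented fields, and it splits every `*_logprobs` structure into parallel `*_val` / `*_idx` lists. A minimal sketch of the representation change (the numeric values are made up for illustration):

# Before: one list of (logprob, token_id) tuples per position.
output_token_logprobs = [(-0.12, 42), (-1.30, 7)]

# After: two parallel lists, which are simpler to serialize and batch.
output_token_logprobs_val = [-0.12, -1.30]
output_token_logprobs_idx = [42, 7]

# The tuple view can still be recovered on demand:
assert list(zip(output_token_logprobs_val, output_token_logprobs_idx)) == [(-0.12, 42), (-1.3, 7)]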
@@ -39,10 +39,12 @@ class LogitsProcessorOutput:
     # The logprobs of input tokens. shape: [#token, vocab_size]
     input_token_logprobs: torch.Tensor = None

-    # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    input_top_logprobs: List = None
-    # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    output_top_logprobs: List = None
+    # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k]
+    input_top_logprobs_val: List = None
+    input_top_logprobs_idx: List = None
+    # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k]
+    output_top_logprobs_val: List = None
+    output_top_logprobs_idx: List = None


 @dataclasses.dataclass
@@ -125,12 +127,15 @@ class LogitsProcessor(nn.Module):
         indices = ret.indices.tolist()

         if logits_metadata.forward_mode.is_decode():
-            output_top_logprobs = []
+            output_top_logprobs_val = []
+            output_top_logprobs_idx = []
             for i, k in enumerate(logits_metadata.top_logprobs_nums):
-                output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k])))
-            return None, output_top_logprobs
+                output_top_logprobs_val.append(values[i][:k])
+                output_top_logprobs_idx.append(indices[i][:k])
+            return None, None, output_top_logprobs_val, output_top_logprobs_idx
         else:
-            input_top_logprobs, output_top_logprobs = [], []
+            input_top_logprobs_val, input_top_logprobs_idx = [], []
+            output_top_logprobs_val, output_top_logprobs_idx = [], []

             pt = 0
             for k, pruned_len in zip(
@@ -138,27 +143,36 @@ class LogitsProcessor(nn.Module):
                 logits_metadata.extend_logprob_pruned_lens_cpu,
             ):
                 if pruned_len <= 0:
-                    input_top_logprobs.append([])
-                    output_top_logprobs.append([])
+                    input_top_logprobs_val.append([])
+                    input_top_logprobs_idx.append([])
+                    output_top_logprobs_val.append([])
+                    output_top_logprobs_idx.append([])
                     continue

-                input_top_logprobs.append(
-                    [
-                        list(zip(values[pt + j][:k], indices[pt + j][:k]))
-                        for j in range(pruned_len - 1)
-                    ]
+                input_top_logprobs_val.append(
+                    [values[pt + j][:k] for j in range(pruned_len - 1)]
                 )
-                output_top_logprobs.append(
+                input_top_logprobs_idx.append(
+                    [indices[pt + j][:k] for j in range(pruned_len - 1)]
+                )
+                output_top_logprobs_val.append(
+                    list(
+                        values[pt + pruned_len - 1][:k],
+                    )
+                )
+                output_top_logprobs_idx.append(
                     list(
-                        zip(
-                            values[pt + pruned_len - 1][:k],
-                            indices[pt + pruned_len - 1][:k],
-                        )
+                        indices[pt + pruned_len - 1][:k],
                     )
                 )
                 pt += pruned_len

-            return input_top_logprobs, output_top_logprobs
+            return (
+                input_top_logprobs_val,
+                input_top_logprobs_idx,
+                output_top_logprobs_val,
+                output_top_logprobs_idx,
+            )

     def forward(
         self,
@@ -193,29 +207,22 @@ class LogitsProcessor(nn.Module):
         if not logits_metadata.return_logprob:
             return LogitsProcessorOutput(
                 next_token_logits=last_logits,
-                next_token_logprobs=None,
-                normalized_prompt_logprobs=None,
-                input_token_logprobs=None,
-                input_top_logprobs=None,
-                output_top_logprobs=None,
             )
         else:
             last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)

             if logits_metadata.forward_mode.is_decode():
                 if logits_metadata.return_top_logprob:
-                    output_top_logprobs = self.get_top_logprobs(
-                        last_logprobs, logits_metadata
-                    )[1]
+                    output_top_logprobs_val, output_top_logprobs_idx = (
+                        self.get_top_logprobs(last_logprobs, logits_metadata)[2:4]
+                    )
                 else:
-                    output_top_logprobs = None
+                    output_top_logprobs_val = output_top_logprobs_idx = None

                 return LogitsProcessorOutput(
                     next_token_logits=last_logits,
                     next_token_logprobs=last_logprobs,
-                    normalized_prompt_logprobs=None,
-                    input_token_logprobs=None,
-                    input_top_logprobs=None,
-                    output_top_logprobs=output_top_logprobs,
+                    output_top_logprobs_val=output_top_logprobs_val,
+                    output_top_logprobs_idx=output_top_logprobs_idx,
                 )
             else:
                 # Slice the requested tokens to compute logprob
@@ -246,11 +253,16 @@ class LogitsProcessor(nn.Module):
             # Get the logprob of top-k tokens
             if logits_metadata.return_top_logprob:
-                input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
-                    all_logprobs, logits_metadata
-                )
+                (
+                    input_top_logprobs_val,
+                    input_top_logprobs_idx,
+                    output_top_logprobs_val,
+                    output_top_logprobs_idx,
+                ) = self.get_top_logprobs(all_logprobs, logits_metadata)
             else:
-                input_top_logprobs = output_top_logprobs = None
+                input_top_logprobs_val = input_top_logprobs_idx = (
+                    output_top_logprobs_val
+                ) = output_top_logprobs_idx = None

             # Compute the normalized logprobs for the requested tokens.
             # Note that we pad a zero at the end for easy batching.
@@ -273,8 +285,10 @@ class LogitsProcessor(nn.Module):
                 next_token_logprobs=last_logprobs,
                 normalized_prompt_logprobs=normalized_prompt_logprobs,
                 input_token_logprobs=input_token_logprobs,
-                input_top_logprobs=input_top_logprobs,
-                output_top_logprobs=output_top_logprobs,
+                input_top_logprobs_val=input_top_logprobs_val,
+                input_top_logprobs_idx=input_top_logprobs_idx,
+                output_top_logprobs_val=output_top_logprobs_val,
+                output_top_logprobs_idx=output_top_logprobs_idx,
             )

     def _get_logits(
......
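After the change above, `LogitsProcessor.get_top_logprobs` always returns a 4-tuple `(input_top_logprobs_val, input_top_logprobs_idx, output_top_logprobs_val, output_top_logprobs_idx)`, and decode-only callers slice `[2:4]` to keep just the output-position halves. A hedged sketch of that calling convention with stand-in data (the helper below is illustrative, not SGLang code):

def fake_get_top_logprobs():
    # Stand-in for the 4-tuple returned by get_top_logprobs in extend mode.
    input_val, input_idx = [[[-0.5, -1.1]]], [[[11, 12]]]
    output_val, output_idx = [[-0.1, -2.3]], [[42, 7]]
    return input_val, input_idx, output_val, output_idx

# Decode-mode callers (see the forward() hunk above and the CUDA graph runner below)
# keep only the output-position entries.
output_top_logprobs_val, output_top_logprobs_idx = fake_get_top_logprobs()[2:4]
assert output_top_logprobs_idx == [[42, 7]]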
@@ -17,7 +17,7 @@ import dataclasses
 import logging
 import signal
 from collections import OrderedDict
-from typing import List, Union
+from typing import Dict, List, Union

 import psutil
 import setproctitle
@@ -76,17 +76,25 @@ class DetokenizerManager:
         self.decode_status = LimitedCapacityDict()

-    def trim_eos(self, output: Union[str, List[int]], finished_reason, no_stop_trim):
-        if no_stop_trim:
+    def trim_matched_stop(
+        self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
+    ):
+        if no_stop_trim or not finished_reason:
+            return output
+
+        matched = finished_reason.get("matched", None)
+        if not matched:
             return output

-        # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
-        if isinstance(finished_reason, FINISH_MATCHED_STR) and isinstance(output, str):
-            pos = output.find(finished_reason.matched)
+        # TODO(lmzheng): handle the case where multiple stop strs are hit
+
+        # Trim stop str.
+        if isinstance(matched, str) and isinstance(output, str):
+            pos = output.find(matched)
             return output[:pos] if pos != -1 else output
-        if isinstance(finished_reason, FINISH_MATCHED_TOKEN) and isinstance(
-            output, list
-        ):
+
+        # Trim stop token.
+        if isinstance(matched, int) and isinstance(output, list):
             assert len(output) > 0
             return output[:-1]
         return output
@@ -125,9 +133,9 @@ class DetokenizerManager:
                 s.decode_ids = recv_obj.decode_ids[i]

                 read_ids.append(
-                    self.trim_eos(
+                    self.trim_matched_stop(
                         s.decode_ids[s.surr_offset :],
-                        recv_obj.finished_reason[i],
+                        recv_obj.finished_reasons[i],
                         recv_obj.no_stop_trim[i],
                     )
                 )
@@ -150,7 +158,7 @@ class DetokenizerManager:
         for i in range(bs):
             s = self.decode_status[recv_obj.rids[i]]
             new_text = read_texts[i][len(surr_texts[i]) :]
-            if recv_obj.finished_reason[i] is None:
+            if recv_obj.finished_reasons[i] is None:
                 # Streaming chunk: update the decode status
                 if len(new_text) > 0 and not new_text.endswith("�"):
                     s.decoded_text = s.decoded_text + new_text
@@ -161,9 +169,9 @@ class DetokenizerManager:
                 new_text = find_printable_text(new_text)

             output_strs.append(
-                self.trim_eos(
+                self.trim_matched_stop(
                     s.decoded_text + new_text,
-                    recv_obj.finished_reason[i],
+                    recv_obj.finished_reasons[i],
                     recv_obj.no_stop_trim[i],
                 )
             )
@@ -171,9 +179,20 @@ class DetokenizerManager:
         self.send_to_tokenizer.send_pyobj(
             BatchStrOut(
                 rids=recv_obj.rids,
+                finished_reasons=recv_obj.finished_reasons,
                 output_strs=output_strs,
-                meta_info=recv_obj.meta_info,
-                finished_reason=recv_obj.finished_reason,
+                prompt_tokens=recv_obj.prompt_tokens,
+                completion_tokens=recv_obj.completion_tokens,
+                cached_tokens=recv_obj.cached_tokens,
+                input_token_logprobs_val=recv_obj.input_token_logprobs_val,
+                input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
+                output_token_logprobs_val=recv_obj.output_token_logprobs_val,
+                output_token_logprobs_idx=recv_obj.output_token_logprobs_idx,
+                input_top_logprobs_val=recv_obj.input_top_logprobs_val,
+                input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
+                output_top_logprobs_val=recv_obj.output_top_logprobs_val,
+                output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
+                normalized_prompt_logprob=recv_obj.normalized_prompt_logprob,
             )
         )
......
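`trim_matched_stop` now receives the JSON-serialized finish reason (a dict) instead of a FINISH_MATCHED_STR / FINISH_MATCHED_TOKEN object, and it dispatches on the type of the `matched` entry. A standalone copy of the new logic with illustrative inputs (the dict contents here are made up; only the `matched` key is read):

def trim_matched_stop(output, finished_reason, no_stop_trim):
    # Simplified copy of the method above, kept free of SGLang imports.
    if no_stop_trim or not finished_reason:
        return output
    matched = finished_reason.get("matched", None)
    if not matched:
        return output
    if isinstance(matched, str) and isinstance(output, str):
        pos = output.find(matched)
        return output[:pos] if pos != -1 else output
    if isinstance(matched, int) and isinstance(output, list):
        assert len(output) > 0
        return output[:-1]
    return output

# Matched stop string: everything from the match onward is trimmed.
assert trim_matched_stop("Paris.</s>", {"matched": "</s>"}, False) == "Paris."
# Matched stop token: the final token id is dropped.
assert trim_matched_stop([12, 7, 2], {"matched": 2}, False) == [12, 7]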
@@ -308,6 +308,9 @@ class TokenizedEmbeddingReqInput:
 class BatchTokenIDOut:
     # The request id
     rids: List[str]
+    # The finish reason
+    finished_reasons: List[BaseFinishReason]
+    # For incremental decoding
     # The version id to sync decode status with in detokenizer_manager
     vids: List[int]
     decoded_texts: List[str]
@@ -315,35 +318,61 @@ class BatchTokenIDOut:
     read_offsets: List[int]
     # Only used when `--skip-tokenizer-init`
     output_ids: Optional[List[int]]
+    # Detokenization configs
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
-    meta_info: List[Dict]
-    finished_reason: List[BaseFinishReason]
     no_stop_trim: List[bool]
+    # Token counts
+    prompt_tokens: List[int]
+    completion_tokens: List[int]
+    cached_tokens: List[int]
+    # Logprobs
+    input_token_logprobs_val: List[float]
+    input_token_logprobs_idx: List[int]
+    output_token_logprobs_val: List[float]
+    output_token_logprobs_idx: List[int]
+    input_top_logprobs_val: List[List]
+    input_top_logprobs_idx: List[List]
+    output_top_logprobs_val: List[List]
+    output_top_logprobs_idx: List[List]
+    normalized_prompt_logprob: List[float]


 @dataclass
 class BatchStrOut:
     # The request id
     rids: List[str]
+    # The finish reason
+    finished_reasons: List[dict]
     # The output decoded strings
     output_strs: List[str]
-    # The meta info
-    meta_info: List[Dict]
-    # The finish reason
-    finished_reason: List[BaseFinishReason]
+
+    # Token counts
+    prompt_tokens: List[int]
+    completion_tokens: List[int]
+    cached_tokens: List[int]
+    # Logprobs
+    input_token_logprobs_val: List[float]
+    input_token_logprobs_idx: List[int]
+    output_token_logprobs_val: List[float]
+    output_token_logprobs_idx: List[int]
+    input_top_logprobs_val: List[List]
+    input_top_logprobs_idx: List[List]
+    output_top_logprobs_val: List[List]
+    output_top_logprobs_idx: List[List]
+    normalized_prompt_logprob: List[float]


 @dataclass
 class BatchEmbeddingOut:
     # The request id
     rids: List[str]
+    # The finish reason
+    finished_reasons: List[BaseFinishReason]
     # The output embedding
     embeddings: List[List[float]]
-    # The meta info
-    meta_info: List[Dict]
-    # The finish reason
-    finished_reason: List[BaseFinishReason]
+    # Token counts
+    prompt_tokens: List[int]


 @dataclass
......
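The batch output dataclasses are now column-oriented: the i-th element of every list belongs to request i, and the old `meta_info: List[Dict]` field is gone. A rough sketch of how a consumer rebuilds one request's view from such columns (the dataclass below is a trimmed stand-in for illustration, not the full BatchStrOut, and all values are made up):

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class MiniBatchStrOut:  # trimmed stand-in
    rids: List[str]
    finished_reasons: List[Optional[dict]]
    output_strs: List[str]
    prompt_tokens: List[int]
    completion_tokens: List[int]

batch = MiniBatchStrOut(
    rids=["req-0", "req-1"],
    finished_reasons=[None, {"type": "length"}],
    output_strs=["Hello", "Hi there"],
    prompt_tokens=[5, 9],
    completion_tokens=[1, 2],
)

i = 1  # per-request view for the second request
per_req = {
    "id": batch.rids[i],
    "finish_reason": batch.finished_reasons[i],
    "text": batch.output_strs[i],
    "prompt_tokens": batch.prompt_tokens[i],
    "completion_tokens": batch.completion_tokens[i],
}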
@@ -200,6 +200,9 @@ class Req:
         origin_input_text: str,
         origin_input_ids: Tuple[int],
         sampling_params: SamplingParams,
+        return_logprob: bool = False,
+        top_logprobs_num: int = 0,
+        stream: bool = False,
         origin_input_ids_unpadded: Optional[Tuple[int]] = None,
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
@@ -217,10 +220,11 @@ class Req:
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
         self.session_id = session_id
+        self.input_embeds = input_embeds

+        # Sampling info
         self.sampling_params = sampling_params
         self.lora_path = lora_path
-        self.input_embeds = input_embeds

         # Memory pool info
         self.req_pool_idx = None
@@ -228,8 +232,8 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
-        self.stream = False
         self.to_abort = False
+        self.stream = stream

         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -241,13 +245,9 @@ class Req:
         # 2: read_offset
         # 3: last token
         self.vid = 0  # version id to sync decode status with in detokenizer_manager
+        self.decoded_text = ""
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
-        self.decoded_text = ""
-
-        # The number of decoded tokens for token usage report. Note that
-        # this does not include the jump forward tokens.
-        self.completion_tokens_wo_jump_forward = 0

         # For multimodal inputs
         self.image_inputs: Optional[ImageInputs] = None
@@ -256,22 +256,34 @@ class Req:
         self.prefix_indices = []
         self.extend_input_len = 0
         self.last_node = None
+
+        # Chunked prefill
         self.is_being_chunked = 0

         # For retraction
         self.is_retracted = False

         # Logprobs (arguments)
-        self.return_logprob = False
+        self.return_logprob = return_logprob
         self.logprob_start_len = 0
-        self.top_logprobs_num = 0
+        self.top_logprobs_num = top_logprobs_num

         # Logprobs (return value)
         self.normalized_prompt_logprob = None
-        self.input_token_logprobs = None
-        self.input_top_logprobs = None
-        self.output_token_logprobs = []
-        self.output_top_logprobs = []
+        self.input_token_logprobs_val = None
+        self.input_token_logprobs_idx = None
+        self.input_top_logprobs_val = None
+        self.input_top_logprobs_idx = None
+
+        if return_logprob:
+            self.output_token_logprobs_val = []
+            self.output_token_logprobs_idx = []
+            self.output_top_logprobs_val = []
+            self.output_top_logprobs_idx = []
+        else:
+            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
+                self.output_top_logprobs_val
+            ) = self.output_top_logprobs_idx = None

         # Logprobs (internal values)
         # The tokens is prefilled but need to be considered as decode tokens
@@ -295,8 +307,8 @@ class Req:
         else:
             self.image_inputs.merge(image_inputs)

-    # whether request reached finished condition
     def finished(self) -> bool:
+        # Whether request reached finished condition
         return self.finished_reason is not None

     def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
@@ -454,8 +466,10 @@ class Req:
                 k = k + 1
             else:
                 break
-        self.output_token_logprobs = self.output_token_logprobs[:k]
-        self.output_top_logprobs = self.output_top_logprobs[:k]
+        self.output_token_logprobs_val = self.output_token_logprobs_val[:k]
+        self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
+        self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
+        self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
         self.logprob_start_len = prompt_tokens + k
         self.last_update_decode_tokens = len(self.output_ids) - k
......
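`Req` now takes `return_logprob`, `top_logprobs_num`, and `stream` at construction time, so the scheduler no longer patches these attributes after the object is created, and the output-side logprob buffers are only allocated for requests that will actually fill them. A toy illustration of that pattern (not SGLang code):

class ToyReq:
    def __init__(self, rid, return_logprob=False, top_logprobs_num=0, stream=False):
        self.rid = rid
        self.return_logprob = return_logprob
        self.top_logprobs_num = top_logprobs_num
        self.stream = stream
        # Buffers exist only for requests that asked for logprobs.
        if return_logprob:
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = None

req = ToyReq("req-0", return_logprob=True, top_logprobs_num=5, stream=True)
assert req.output_token_logprobs_val == []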
@@ -515,6 +515,9 @@ class Scheduler:
             recv_req.input_text,
             recv_req.input_ids,
             recv_req.sampling_params,
+            return_logprob=recv_req.return_logprob,
+            top_logprobs_num=recv_req.top_logprobs_num,
+            stream=recv_req.stream,
             lora_path=recv_req.lora_path,
             input_embeds=recv_req.input_embeds,
         )
@@ -558,9 +561,6 @@ class Scheduler:
             return

         # Copy more attributes
-        req.return_logprob = recv_req.return_logprob
-        req.top_logprobs_num = recv_req.top_logprobs_num
-        req.stream = recv_req.stream
         req.logprob_start_len = recv_req.logprob_start_len

         if req.logprob_start_len == -1:
@@ -982,7 +982,6 @@ class Scheduler:
                     continue

                 if req.is_being_chunked <= 0:
-                    req.completion_tokens_wo_jump_forward += 1
                     req.output_ids.append(next_token_id)
                     req.check_finished()
@@ -1035,7 +1034,7 @@ class Scheduler:
                     # being chunked reqs' prefill is not finished
                     req.is_being_chunked -= 1

-        self.stream_output(batch.reqs, skip_stream_req)
+        self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids, bid = result
@@ -1065,7 +1064,6 @@ class Scheduler:
                 self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                 continue

-            req.completion_tokens_wo_jump_forward += 1
             req.output_ids.append(next_token_id)
             req.check_finished()
@@ -1073,11 +1071,15 @@ class Scheduler:
                 self.tree_cache.cache_finished_req(req)

             if req.return_logprob:
-                req.output_token_logprobs.append(
-                    (next_token_logprobs[i], next_token_id)
-                )
+                req.output_token_logprobs_val.append(next_token_logprobs[i])
+                req.output_token_logprobs_idx.append(next_token_id)
                 if req.top_logprobs_num > 0:
-                    req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
+                    req.output_top_logprobs_val.append(
+                        logits_output.output_top_logprobs_val[i]
+                    )
+                    req.output_top_logprobs_idx.append(
+                        logits_output.output_top_logprobs_idx[i]
+                    )

             if req.grammar is not None:
                 req.grammar.accept_token(next_token_id)
@@ -1088,7 +1090,7 @@ class Scheduler:
             self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()

-        self.stream_output(batch.reqs)
+        self.stream_output(batch.reqs, batch.return_logprob)

         self.token_to_kv_pool.free_group_end()
@@ -1108,9 +1110,8 @@ class Scheduler:
         output: LogitsProcessorOutput,
     ):
         """Attach logprobs to the return values."""
-        req.output_token_logprobs.append(
-            (output.next_token_logprobs[i], next_token_ids[i])
-        )
+        req.output_token_logprobs_val.append(output.next_token_logprobs[i])
+        req.output_token_logprobs_idx.append(next_token_ids[i])

         # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
         num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
@@ -1118,173 +1119,195 @@ class Scheduler:
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

-        if req.input_token_logprobs is None:
-            input_token_logprobs = output.input_token_logprobs[
+        if req.input_token_logprobs_val is None:
+            input_token_logprobs_val = output.input_token_logprobs[
                 pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
             ]
-            input_token_ids = req.fill_ids[
+
+            input_token_logprobs_idx = req.fill_ids[
                 len(req.fill_ids)
                 - num_input_logprobs
                 + 1 : len(req.fill_ids)
                 - req.last_update_decode_tokens
             ]
             # Clip the padded hash values from image tokens.
             # Otherwise, it will lead to detokenization errors.
-            input_token_ids = [
+            input_token_logprobs_idx = [
                 x if x < self.model_config.vocab_size - 1 else 0
-                for x in input_token_ids
+                for x in input_token_logprobs_idx
             ]
-            req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))

             if (
                 req.logprob_start_len == 0
             ):  # The first token does not have logprob, pad it.
-                req.input_token_logprobs = [
-                    (None, req.fill_ids[0])
-                ] + req.input_token_logprobs
+                input_token_logprobs_val = [None] + input_token_logprobs_val
+                input_token_logprobs_idx = [req.fill_ids[0]] + input_token_logprobs_idx
+
+            req.input_token_logprobs_val = input_token_logprobs_val
+            req.input_token_logprobs_idx = input_token_logprobs_idx

         if req.last_update_decode_tokens != 0:
             # Some decode tokens are re-computed in an extend batch
-            req.output_token_logprobs.extend(
-                list(
-                    zip(
-                        output.input_token_logprobs[
-                            pt
-                            + num_input_logprobs
-                            - 1
-                            - req.last_update_decode_tokens : pt
-                            + num_input_logprobs
-                            - 1
-                        ],
-                        req.fill_ids[
-                            len(req.fill_ids)
-                            - req.last_update_decode_tokens : len(req.fill_ids)
-                        ],
-                    )
-                )
+            req.output_token_logprobs_val.extend(
+                output.input_token_logprobs[
+                    pt
+                    + num_input_logprobs
+                    - 1
+                    - req.last_update_decode_tokens : pt
+                    + num_input_logprobs
+                    - 1
+                ],
+            )
+            req.output_token_logprobs_idx.extend(
+                req.fill_ids[
+                    len(req.fill_ids)
+                    - req.last_update_decode_tokens : len(req.fill_ids)
+                ]
             )

         if req.top_logprobs_num > 0:
-            if req.input_top_logprobs is None:
-                req.input_top_logprobs = output.input_top_logprobs[i]
+            if req.input_top_logprobs_val is None:
+                req.input_top_logprobs_val = output.input_top_logprobs_val[i]
+                req.input_top_logprobs_idx = output.input_top_logprobs_idx[i]
                 if req.logprob_start_len == 0:
-                    req.input_top_logprobs = [None] + req.input_top_logprobs
+                    req.input_top_logprobs_val = [None] + req.input_top_logprobs_val
+                    req.input_top_logprobs_idx = [None] + req.input_top_logprobs_idx

             if req.last_update_decode_tokens != 0:
-                req.output_top_logprobs.extend(
-                    output.input_top_logprobs[i][-req.last_update_decode_tokens :]
+                req.output_top_logprobs_val.extend(
+                    output.input_top_logprobs_val[i][-req.last_update_decode_tokens :]
                 )
-            req.output_top_logprobs.append(output.output_top_logprobs[i])
+                req.output_top_logprobs_idx.extend(
+                    output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
+                )
+            req.output_top_logprobs_val.append(output.output_top_logprobs_val[i])
+            req.output_top_logprobs_idx.append(output.output_top_logprobs_idx[i])

         return num_input_logprobs

-    def stream_output(self, reqs: List[Req], skip_req: Optional[Req] = None):
+    def stream_output(
+        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
+    ):
         """Stream the output to detokenizer."""
-        output_rids = []
-        output_meta_info: List[dict] = []
-        output_finished_reason: List[BaseFinishReason] = []
+        rids = []
+        finished_reasons: List[BaseFinishReason] = []
+
         if self.is_generation:
-            output_vids = []
+            vids = []
             decoded_texts = []
-            output_read_ids = []
-            output_read_offsets = []
+            decode_ids_list = []
+            read_offsets = []
             output_ids = []
-            output_skip_special_tokens = []
-            output_spaces_between_special_tokens = []
-            output_no_stop_trim = []
-        else:  # embedding or reward model
-            output_embeddings = []
-
-        is_stream_iter = self.forward_ct_decode % self.stream_interval == 0
-
-        for req in reqs:
-            if req is skip_req:
-                continue
-
-            # TODO(lianmin): revisit this for overlap + retract + stream
-            if req.finished() or (
-                req.stream and (is_stream_iter or len(req.output_ids) == 1)
-            ):
-                output_rids.append(req.rid)
-                output_finished_reason.append(req.finished_reason)
-                if self.is_generation:
-                    output_vids.append(req.vid)
+            skip_special_tokens = []
+            spaces_between_special_tokens = []
+            no_stop_trim = []
+            prompt_tokens = []
+            completion_tokens = []
+            cached_tokens = []
+
+            if return_logprob:
+                input_token_logprobs_val = []
+                input_token_logprobs_idx = []
+                output_token_logprobs_val = []
+                output_token_logprobs_idx = []
+                input_top_logprobs_val = []
+                input_top_logprobs_idx = []
+                output_top_logprobs_val = []
+                output_top_logprobs_idx = []
+                normalized_prompt_logprob = []
+            else:
+                input_token_logprobs_val = input_token_logprobs_idx = (
+                    output_token_logprobs_val
+                ) = output_token_logprobs_idx = input_top_logprobs_val = (
+                    input_top_logprobs_idx
+                ) = output_top_logprobs_val = output_top_logprobs_idx = (
+                    normalized_prompt_logprob
+                ) = None
+
+            for req in reqs:
+                if req is skip_req:
+                    continue
+
+                # TODO(lianmin): revisit this for overlap + retract + stream
+                if (
+                    req.finished()
+                    # If stream, follow the given stream_interval
+                    or (req.stream and len(req.output_ids) % self.stream_interval == 0)
+                    # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
+                    or (not req.stream and len(req.output_ids) % 50 == 0)
+                ):
+                    rids.append(req.rid)
+                    finished_reasons.append(
+                        req.finished_reason.to_json() if req.finished_reason else None
+                    )
+                    vids.append(req.vid)
                     decoded_texts.append(req.decoded_text)
-                    read_ids, read_offset = req.init_incremental_detokenize()
-                    output_read_ids.append(read_ids)
-                    output_read_offsets.append(read_offset)
+                    decode_ids, read_offset = req.init_incremental_detokenize()
+                    decode_ids_list.append(decode_ids)
+                    read_offsets.append(read_offset)
                     if self.skip_tokenizer_init:
                         output_ids.append(req.output_ids)
-                    output_skip_special_tokens.append(
-                        req.sampling_params.skip_special_tokens
-                    )
-                    output_spaces_between_special_tokens.append(
+                    skip_special_tokens.append(req.sampling_params.skip_special_tokens)
+                    spaces_between_special_tokens.append(
                         req.sampling_params.spaces_between_special_tokens
                     )
-                    output_no_stop_trim.append(req.sampling_params.no_stop_trim)
-
-                    meta_info = {
-                        "prompt_tokens": len(req.origin_input_ids),
-                        "completion_tokens": len(req.output_ids),
-                        "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
-                        "cached_tokens": req.cached_tokens,
-                        "finish_reason": (
-                            req.finished_reason.to_json()
-                            if req.finished_reason is not None
-                            else None
-                        ),
-                    }
-                    if req.return_logprob:
-                        (
-                            meta_info["input_token_logprobs"],
-                            meta_info["output_token_logprobs"],
-                            meta_info["input_top_logprobs"],
-                            meta_info["output_top_logprobs"],
-                            meta_info["normalized_prompt_logprob"],
-                        ) = (
-                            req.input_token_logprobs,
-                            req.output_token_logprobs,
-                            req.input_top_logprobs,
-                            req.output_top_logprobs,
-                            req.normalized_prompt_logprob,
-                        )
-                    output_meta_info.append(meta_info)
-                else:  # embedding or reward model
-                    output_embeddings.append(req.embedding)
-                    meta_info = {
-                        "prompt_tokens": len(req.origin_input_ids),
-                    }
-                    output_meta_info.append(meta_info)
-
-        # Send to detokenizer
-        if output_rids:
-            if self.is_generation:
-                self.send_to_detokenizer.send_pyobj(
-                    BatchTokenIDOut(
-                        output_rids,
-                        output_vids,
-                        decoded_texts,
-                        output_read_ids,
-                        output_read_offsets,
-                        output_ids,
-                        output_skip_special_tokens,
-                        output_spaces_between_special_tokens,
-                        output_meta_info,
-                        output_finished_reason,
-                        output_no_stop_trim,
-                    )
-                )
-            else:  # embedding or reward model
+                    no_stop_trim.append(req.sampling_params.no_stop_trim)
+
+                    prompt_tokens.append(len(req.origin_input_ids))
+                    completion_tokens.append(len(req.output_ids))
+                    cached_tokens.append(req.cached_tokens)
+
+                    if return_logprob:
+                        input_token_logprobs_val.append(req.input_token_logprobs_val)
+                        input_token_logprobs_idx.append(req.input_token_logprobs_idx)
+                        output_token_logprobs_val.append(req.output_token_logprobs_val)
+                        output_token_logprobs_idx.append(req.output_token_logprobs_idx)
+                        input_top_logprobs_val.append(req.input_top_logprobs_val)
+                        input_top_logprobs_idx.append(req.input_top_logprobs_idx)
+                        output_top_logprobs_val.append(req.output_top_logprobs_val)
+                        output_top_logprobs_idx.append(req.output_top_logprobs_idx)
+                        normalized_prompt_logprob.append(req.normalized_prompt_logprob)
+
+            # Send to detokenizer
+            if rids:
                 self.send_to_detokenizer.send_pyobj(
-                    BatchEmbeddingOut(
-                        output_rids,
-                        output_embeddings,
-                        output_meta_info,
-                        output_finished_reason,
+                    BatchTokenIDOut(
+                        rids,
+                        finished_reasons,
+                        vids,
+                        decoded_texts,
+                        decode_ids_list,
+                        read_offsets,
+                        output_ids,
+                        skip_special_tokens,
+                        spaces_between_special_tokens,
+                        no_stop_trim,
+                        prompt_tokens,
+                        completion_tokens,
+                        cached_tokens,
+                        input_token_logprobs_val,
+                        input_token_logprobs_idx,
+                        output_token_logprobs_val,
+                        output_token_logprobs_idx,
+                        input_top_logprobs_val,
+                        input_top_logprobs_idx,
+                        output_top_logprobs_val,
+                        output_top_logprobs_idx,
+                        normalized_prompt_logprob,
                     )
                 )
+        else:  # embedding or reward model
+            embeddings = []
+            prompt_tokens = []
+            for req in reqs:
+                assert req.finished()
+                rids.append(req.rid)
+                finished_reasons.append(req.finished_reason.to_json())
+                embeddings.append(req.embedding)
+                prompt_tokens.append(len(req.origin_input_ids))
+            self.send_to_detokenizer.send_pyobj(
+                BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
+            )

     def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
         # Check if other DP workers have running batches
......
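The rewritten `stream_output` also changes when a request is flushed to the detokenizer: on finish, every `stream_interval` tokens for streaming requests, and every 50 tokens for non-streaming requests so incremental detokenization still makes progress. A self-contained sketch of that predicate, with the constant 50 taken from the diff above:

def should_emit(finished: bool, stream: bool, num_output_ids: int, stream_interval: int) -> bool:
    return (
        finished
        or (stream and num_output_ids % stream_interval == 0)
        or (not stream and num_output_ids % 50 == 0)
    )

assert should_emit(finished=False, stream=True, num_output_ids=8, stream_interval=8)
assert not should_emit(finished=False, stream=False, num_output_ids=49, stream_interval=8)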
@@ -22,7 +22,7 @@ import signal
 import sys
 import time
 import uuid
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import fastapi
 import uvloop
@@ -76,6 +76,7 @@ class ReqState:
     out_list: List
     finished: bool
     event: asyncio.Event
+    obj: Any

     # For metrics
     created_time: float
@@ -283,7 +284,7 @@ class TokenizerManager:
     ):
         """Wait for the response of one request."""
         event = asyncio.Event()
-        state = ReqState([], False, event, created_time=created_time)
+        state = ReqState([], False, event, obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state

         while True:
@@ -295,15 +296,7 @@ class TokenizerManager:
                     raise ValueError(f"Abort request {obj.rid}")
                 continue

-            if isinstance(obj, GenerateReqInput):
-                out = self.convert_logprob_style(
-                    state.out_list[-1],
-                    obj.return_logprob,
-                    obj.top_logprobs_num,
-                    obj.return_text_in_logprobs,
-                )
-            else:  # isinstance(obj, (EmbeddingReqInput,))
-                out = state.out_list[-1]
+            out = state.out_list[-1]

             state.out_list = []
             if state.finished:
@@ -315,7 +308,13 @@ class TokenizerManager:
                 break

             state.event.clear()
-            yield out
+
+            if obj.stream:
+                yield out
+            else:
+                if request is not None and await request.is_disconnected():
+                    self.abort_request(obj.rid)
+                    raise ValueError(f"Abort request {obj.rid}")

     async def _handle_batch_request(
         self,
@@ -609,29 +608,55 @@ class TokenizerManager:
                 if state is None:
                     continue

-                recv_obj.meta_info[i]["id"] = rid
+                meta_info = {
+                    "id": rid,
+                    "finish_reason": recv_obj.finished_reasons[i],
+                    "prompt_tokens": recv_obj.prompt_tokens[i],
+                }
+
+                if getattr(state.obj, "return_logprob", False):
+                    self.convert_logprob_style(
+                        meta_info,
+                        state.obj.top_logprobs_num,
+                        state.obj.return_text_in_logprobs,
+                        recv_obj,
+                        i,
+                    )
+
                 if isinstance(recv_obj, BatchStrOut):
                     out_dict = {
                         "text": recv_obj.output_strs[i],
-                        "meta_info": recv_obj.meta_info[i],
+                        "meta_info": {
+                            **meta_info,
+                            "completion_tokens": recv_obj.completion_tokens[i],
+                            "cached_tokens": recv_obj.cached_tokens[i],
+                        },
                     }
                 elif isinstance(recv_obj, BatchTokenIDOut):
                     out_dict = {
                         "token_ids": recv_obj.output_ids[i],
-                        "meta_info": recv_obj.meta_info[i],
+                        "meta_info": {
+                            **meta_info,
+                            "completion_tokens": recv_obj.completion_tokens[i],
+                            "cached_tokens": recv_obj.cached_tokens[i],
+                        },
                     }
                 else:
                     assert isinstance(recv_obj, BatchEmbeddingOut)
                     out_dict = {
                         "embedding": recv_obj.embeddings[i],
-                        "meta_info": recv_obj.meta_info[i],
+                        "meta_info": meta_info,
                     }
                 state.out_list.append(out_dict)
-                state.finished = recv_obj.finished_reason[i] is not None
+                state.finished = recv_obj.finished_reasons[i] is not None
                 state.event.set()

                 if self.enable_metrics:
-                    completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
+                    completion_tokens = (
+                        recv_obj.completion_tokens[i]
+                        if recv_obj.completion_tokens
+                        else 0
+                    )

                     if state.first_token_time is None:
                         state.first_token_time = time.time()
@@ -647,7 +672,7 @@ class TokenizerManager:
                     if state.finished:
                         self.metrics_collector.inc_prompt_tokens(
-                            recv_obj.meta_info[i]["prompt_tokens"]
+                            recv_obj.prompt_tokens[i]
                         )
                         self.metrics_collector.inc_generation_tokens(
                             completion_tokens
@@ -696,57 +721,73 @@ class TokenizerManager:
     def convert_logprob_style(
         self,
-        ret: dict,
-        return_logprob: bool,
+        meta_info: dict,
         top_logprobs_num: int,
         return_text_in_logprobs: bool,
+        recv_obj: BatchStrOut,
+        recv_obj_index: int,
     ):
-        if return_logprob:
-            ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
-                ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
-            )
-            ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
-                ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
+        meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
+            recv_obj.input_token_logprobs_val[recv_obj_index],
+            recv_obj.input_token_logprobs_idx[recv_obj_index],
+            return_text_in_logprobs,
+        )
+        meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
+            recv_obj.output_token_logprobs_val[recv_obj_index],
+            recv_obj.output_token_logprobs_idx[recv_obj_index],
+            return_text_in_logprobs,
+        )
+        meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[
+            recv_obj_index
+        ]
+
+        if top_logprobs_num > 0:
+            meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                recv_obj.input_top_logprobs_val[recv_obj_index],
+                recv_obj.input_top_logprobs_idx[recv_obj_index],
+                return_text_in_logprobs,
+            )
+            meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                recv_obj.output_top_logprobs_val[recv_obj_index],
+                recv_obj.output_top_logprobs_idx[recv_obj_index],
+                return_text_in_logprobs,
             )
-            if top_logprobs_num > 0:
-                ret["meta_info"]["input_top_logprobs"] = (
-                    self.detokenize_top_logprobs_tokens(
-                        ret["meta_info"]["input_top_logprobs"],
-                        return_text_in_logprobs,
-                    )
-                )
-                ret["meta_info"]["output_top_logprobs"] = (
-                    self.detokenize_top_logprobs_tokens(
-                        ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
-                    )
-                )
-        return ret

     def detokenize_logprob_tokens(
-        self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
+        self,
+        token_logprobs_val: List[float],
+        token_logprobs_idx: List[int],
+        decode_to_text: bool,
     ):
-        # TODO(lianmin): This should run on DetokenizerManager
         if not decode_to_text:
-            return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
-
-        assert self.tokenizer is not None
-        token_ids = [tid for _, tid in token_logprobs]
-        token_texts = self.tokenizer.batch_decode(token_ids)
-        return [
-            (logprob, token_id, token_text)
-            for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
-        ]
+            return [
+                (logprob, token_id, None)
+                for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
+            ]
+        else:
+            assert self.tokenizer is not None
+            token_texts = self.tokenizer.batch_decode(token_logprobs_idx)
+            return list(zip(token_logprobs_val, token_logprobs_idx, token_texts))

-    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
+    def detokenize_top_logprobs_tokens(
+        self,
+        token_logprobs_val: List[float],
+        token_logprobs_idx: List[int],
+        decode_to_text: bool,
+    ):
         # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
         # We should batch all top-k tokens in all positions.
-        for i, token_top_logprobs in enumerate(top_logprobs):
-            if token_top_logprobs:
-                top_logprobs[i] = self.detokenize_logprob_tokens(
-                    token_top_logprobs, decode_to_text
+        ret = []
+        for i in range(len(token_logprobs_val)):
+            if token_logprobs_val[i]:
+                ret.append(
+                    self.detokenize_logprob_tokens(
+                        token_logprobs_val[i], token_logprobs_idx[i], decode_to_text
+                    )
                 )
-        return top_logprobs
+            else:
+                ret.append(None)
+        return ret


 class SignalHandler:
......
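With the tuple lists gone, `detokenize_logprob_tokens` consumes the parallel `*_val` / `*_idx` columns and still returns `(logprob, token_id, text_or_None)` triples for the API response. A sketch of the output format with made-up values; with `decode_to_text=True` the third element would be the decoded token text instead of None:

token_logprobs_val = [-0.12, -1.30]
token_logprobs_idx = [42, 7]

triples = [
    (logprob, token_id, None)
    for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
]
assert triples == [(-0.12, 42, None), (-1.3, 7, None)]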
@@ -400,9 +400,14 @@ class CudaGraphRunner:
                 forward_mode=ForwardMode.DECODE,
                 top_logprobs_nums=forward_batch.top_logprobs_nums,
             )
-            logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
+            (
+                logits_output.output_top_logprobs_val,
+                logits_output.output_top_logprobs_idx,
+            ) = LogitsProcessor.get_top_logprobs(
                 next_token_logprobs, logits_metadata
-            )[1]
+            )[
+                2:4
+            ]
         else:
             logits_output = LogitsProcessorOutput(
                 next_token_logits=next_token_logits,
......
@@ -720,13 +720,13 @@ def run_and_check_memory_leak(
     # Clean up everything
     kill_process_tree(process.pid)
+    kill_process_tree(process.pid)
     stdout.close()
     stderr.close()
     if os.path.exists(STDOUT_FILENAME):
         os.remove(STDOUT_FILENAME)
     if os.path.exists(STDERR_FILENAME):
         os.remove(STDERR_FILENAME)
-    kill_process_tree(process.pid)
     t.join()

     # Assert success
@@ -734,7 +734,7 @@ def run_and_check_memory_leak(
     has_leak = False
     has_abort = False
     for line in output_lines:
-        if "The server is fired" in line:
+        if "Uvicorn running" in line:
             has_new_server = True
         if "leak" in line:
             has_leak = True
......
@@ -95,15 +95,6 @@ class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
         self.assertIsInstance(js_obj["name"], str)
         self.assertIsInstance(js_obj["population"], int)

-        # Make sure jump forward is triggered
-        # NOTE: The overlap scheduler does not support jump forward so we only do this test
-        # when --disable-overlap-schedule is set.
-        if self.check_jump_forward:
-            self.assertGreater(
-                ret["meta_info"]["completion_tokens"],
-                ret["meta_info"]["completion_tokens_wo_jump_forward"],
-            )
-
     def test_json_generate(self):
         self.run_decode(json_schema=self.json_schema)
......