Unverified commit a6ca736c authored by Lianmin Zheng, committed by GitHub

Simplify stream_output (#2398)

parent f62055b5
@@ -39,10 +39,12 @@ class LogitsProcessorOutput:
# The logprobs of input tokens. shape: [#token, vocab_size]
input_token_logprobs: torch.Tensor = None
# The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
input_top_logprobs: List = None
# The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
output_top_logprobs: List = None
# The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k]
input_top_logprobs_val: List = None
input_top_logprobs_idx: List = None
# The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k]
output_top_logprobs_val: List = None
output_top_logprobs_idx: List = None
@dataclasses.dataclass
@@ -125,12 +127,15 @@ class LogitsProcessor(nn.Module):
indices = ret.indices.tolist()
if logits_metadata.forward_mode.is_decode():
output_top_logprobs = []
output_top_logprobs_val = []
output_top_logprobs_idx = []
for i, k in enumerate(logits_metadata.top_logprobs_nums):
output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k])))
return None, output_top_logprobs
output_top_logprobs_val.append(values[i][:k])
output_top_logprobs_idx.append(indices[i][:k])
return None, None, output_top_logprobs_val, output_top_logprobs_idx
else:
input_top_logprobs, output_top_logprobs = [], []
input_top_logprobs_val, input_top_logprobs_idx = [], []
output_top_logprobs_val, output_top_logprobs_idx = [], []
pt = 0
for k, pruned_len in zip(
@@ -138,27 +143,36 @@ class LogitsProcessor(nn.Module):
logits_metadata.extend_logprob_pruned_lens_cpu,
):
if pruned_len <= 0:
input_top_logprobs.append([])
output_top_logprobs.append([])
input_top_logprobs_val.append([])
input_top_logprobs_idx.append([])
output_top_logprobs_val.append([])
output_top_logprobs_idx.append([])
continue
input_top_logprobs.append(
[
list(zip(values[pt + j][:k], indices[pt + j][:k]))
for j in range(pruned_len - 1)
]
input_top_logprobs_val.append(
[values[pt + j][:k] for j in range(pruned_len - 1)]
)
input_top_logprobs_idx.append(
[indices[pt + j][:k] for j in range(pruned_len - 1)]
)
output_top_logprobs.append(
output_top_logprobs_val.append(
list(
zip(
values[pt + pruned_len - 1][:k],
indices[pt + pruned_len - 1][:k],
)
)
output_top_logprobs_idx.append(
list(
indices[pt + pruned_len - 1][:k],
)
)
pt += pruned_len
return input_top_logprobs, output_top_logprobs
return (
input_top_logprobs_val,
input_top_logprobs_idx,
output_top_logprobs_val,
output_top_logprobs_idx,
)
def forward(
self,
@@ -193,29 +207,22 @@ class LogitsProcessor(nn.Module):
if not logits_metadata.return_logprob:
return LogitsProcessorOutput(
next_token_logits=last_logits,
next_token_logprobs=None,
normalized_prompt_logprobs=None,
input_token_logprobs=None,
input_top_logprobs=None,
output_top_logprobs=None,
)
else:
last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
if logits_metadata.forward_mode.is_decode():
if logits_metadata.return_top_logprob:
output_top_logprobs = self.get_top_logprobs(
last_logprobs, logits_metadata
)[1]
output_top_logprobs_val, output_top_logprobs_idx = (
self.get_top_logprobs(last_logprobs, logits_metadata)[2:4]
)
else:
output_top_logprobs = None
output_top_logprobs_val = output_top_logprobs_idx = None
return LogitsProcessorOutput(
next_token_logits=last_logits,
next_token_logprobs=last_logprobs,
normalized_prompt_logprobs=None,
input_token_logprobs=None,
input_top_logprobs=None,
output_top_logprobs=output_top_logprobs,
output_top_logprobs_val=output_top_logprobs_val,
output_top_logprobs_idx=output_top_logprobs_idx,
)
else:
# Slice the requested tokens to compute logprob
@@ -246,11 +253,16 @@ class LogitsProcessor(nn.Module):
# Get the logprob of top-k tokens
if logits_metadata.return_top_logprob:
input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
all_logprobs, logits_metadata
)
(
input_top_logprobs_val,
input_top_logprobs_idx,
output_top_logprobs_val,
output_top_logprobs_idx,
) = self.get_top_logprobs(all_logprobs, logits_metadata)
else:
input_top_logprobs = output_top_logprobs = None
input_top_logprobs_val = input_top_logprobs_idx = (
output_top_logprobs_val
) = output_top_logprobs_idx = None
# Compute the normalized logprobs for the requested tokens.
# Note that we pad a zero at the end for easy batching.
@@ -273,8 +285,10 @@ class LogitsProcessor(nn.Module):
next_token_logprobs=last_logprobs,
normalized_prompt_logprobs=normalized_prompt_logprobs,
input_token_logprobs=input_token_logprobs,
input_top_logprobs=input_top_logprobs,
output_top_logprobs=output_top_logprobs,
input_top_logprobs_val=input_top_logprobs_val,
input_top_logprobs_idx=input_top_logprobs_idx,
output_top_logprobs_val=output_top_logprobs_val,
output_top_logprobs_idx=output_top_logprobs_idx,
)
def _get_logits(
......
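The hunks above change the return layout of LogitsProcessor.get_top_logprobs: instead of per-position lists of (logprob, token_id) tuples, results are carried in two parallel lists, *_val for logprobs and *_idx for token ids. Below is a minimal sketch of how the two layouts map onto each other; the helper names split_pairs and merge_pairs are illustrative and not part of the patch.

# Hypothetical helpers illustrating the layout change only.
from typing import List, Tuple


def split_pairs(
    top_logprobs: List[Tuple[float, int]],
) -> Tuple[List[float], List[int]]:
    """Convert the old [(logprob, token_id), ...] layout into parallel lists."""
    vals = [logprob for logprob, _ in top_logprobs]
    idxs = [token_id for _, token_id in top_logprobs]
    return vals, idxs


def merge_pairs(vals: List[float], idxs: List[int]) -> List[Tuple[float, int]]:
    """Reconstruct the old pair layout from the new *_val / *_idx lists."""
    return list(zip(vals, idxs))


pairs = [(-0.1, 42), (-2.3, 7)]
vals, idxs = split_pairs(pairs)          # ([-0.1, -2.3], [42, 7])
assert merge_pairs(vals, idxs) == pairs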
@@ -17,7 +17,7 @@ import dataclasses
import logging
import signal
from collections import OrderedDict
from typing import List, Union
from typing import Dict, List, Union
import psutil
import setproctitle
@@ -76,17 +76,25 @@ class DetokenizerManager:
self.decode_status = LimitedCapacityDict()
def trim_eos(self, output: Union[str, List[int]], finished_reason, no_stop_trim):
if no_stop_trim:
def trim_matched_stop(
self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
):
if no_stop_trim or not finished_reason:
return output
matched = finished_reason.get("matched", None)
if not matched:
return output
# Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
if isinstance(finished_reason, FINISH_MATCHED_STR) and isinstance(output, str):
pos = output.find(finished_reason.matched)
# TODO(lmzheng): handle the case where multiple stop strs are hit
# Trim stop str.
if isinstance(matched, str) and isinstance(output, str):
pos = output.find(matched)
return output[:pos] if pos != -1 else output
if isinstance(finished_reason, FINISH_MATCHED_TOKEN) and isinstance(
output, list
):
# Trim stop token.
if isinstance(matched, int) and isinstance(output, list):
assert len(output) > 0
return output[:-1]
return output
@@ -125,9 +133,9 @@ class DetokenizerManager:
s.decode_ids = recv_obj.decode_ids[i]
read_ids.append(
self.trim_eos(
self.trim_matched_stop(
s.decode_ids[s.surr_offset :],
recv_obj.finished_reason[i],
recv_obj.finished_reasons[i],
recv_obj.no_stop_trim[i],
)
)
@@ -150,7 +158,7 @@ class DetokenizerManager:
for i in range(bs):
s = self.decode_status[recv_obj.rids[i]]
new_text = read_texts[i][len(surr_texts[i]) :]
if recv_obj.finished_reason[i] is None:
if recv_obj.finished_reasons[i] is None:
# Streaming chunk: update the decode status
if len(new_text) > 0 and not new_text.endswith("�"):
s.decoded_text = s.decoded_text + new_text
@@ -161,9 +169,9 @@ class DetokenizerManager:
new_text = find_printable_text(new_text)
output_strs.append(
self.trim_eos(
self.trim_matched_stop(
s.decoded_text + new_text,
recv_obj.finished_reason[i],
recv_obj.finished_reasons[i],
recv_obj.no_stop_trim[i],
)
)
@@ -171,9 +179,20 @@ class DetokenizerManager:
self.send_to_tokenizer.send_pyobj(
BatchStrOut(
rids=recv_obj.rids,
finished_reasons=recv_obj.finished_reasons,
output_strs=output_strs,
meta_info=recv_obj.meta_info,
finished_reason=recv_obj.finished_reason,
prompt_tokens=recv_obj.prompt_tokens,
completion_tokens=recv_obj.completion_tokens,
cached_tokens=recv_obj.cached_tokens,
input_token_logprobs_val=recv_obj.input_token_logprobs_val,
input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
output_token_logprobs_val=recv_obj.output_token_logprobs_val,
output_token_logprobs_idx=recv_obj.output_token_logprobs_idx,
input_top_logprobs_val=recv_obj.input_top_logprobs_val,
input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
output_top_logprobs_val=recv_obj.output_top_logprobs_val,
output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
normalized_prompt_logprob=recv_obj.normalized_prompt_logprob,
)
)
......
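The renamed trim_matched_stop above no longer receives a FINISH_MATCHED_STR / FINISH_MATCHED_TOKEN object; it works on the already-serialized finish-reason dict and dispatches on the type of its "matched" entry. A self-contained sketch of the same trimming rule, assuming that dict shape; the function name and the example calls are illustrative only.

from typing import Dict, List, Optional, Union


def trim_matched_stop_sketch(
    output: Union[str, List[int]],
    finished_reason: Optional[Dict],
    no_stop_trim: bool,
) -> Union[str, List[int]]:
    if no_stop_trim or not finished_reason:
        return output
    matched = finished_reason.get("matched", None)
    if not matched:
        return output
    # Text output finished on a stop string: cut at its first occurrence.
    if isinstance(matched, str) and isinstance(output, str):
        pos = output.find(matched)
        return output[:pos] if pos != -1 else output
    # Token-id output finished on a stop token: drop the last token.
    if isinstance(matched, int) and isinstance(output, list):
        return output[:-1]
    return output


assert trim_matched_stop_sketch("Hello<eos>", {"matched": "<eos>"}, False) == "Hello"
assert trim_matched_stop_sketch([1, 2, 3], {"matched": 3}, False) == [1, 2]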
@@ -308,6 +308,9 @@ class TokenizedEmbeddingReqInput:
class BatchTokenIDOut:
# The request id
rids: List[str]
# The finish reason
finished_reasons: List[BaseFinishReason]
# For incremental decoding
# The version id to sync decode status with in detokenizer_manager
vids: List[int]
decoded_texts: List[str]
@@ -315,35 +318,61 @@ class BatchTokenIDOut:
read_offsets: List[int]
# Only used when `--skip-tokenizer-init`
output_ids: Optional[List[int]]
# Detokenization configs
skip_special_tokens: List[bool]
spaces_between_special_tokens: List[bool]
meta_info: List[Dict]
finished_reason: List[BaseFinishReason]
no_stop_trim: List[bool]
# Token counts
prompt_tokens: List[int]
completion_tokens: List[int]
cached_tokens: List[int]
# Logprobs
input_token_logprobs_val: List[float]
input_token_logprobs_idx: List[int]
output_token_logprobs_val: List[float]
output_token_logprobs_idx: List[int]
input_top_logprobs_val: List[List]
input_top_logprobs_idx: List[List]
output_top_logprobs_val: List[List]
output_top_logprobs_idx: List[List]
normalized_prompt_logprob: List[float]
@dataclass
class BatchStrOut:
# The request id
rids: List[str]
# The finish reason
finished_reasons: List[dict]
# The output decoded strings
output_strs: List[str]
# The meta info
meta_info: List[Dict]
# The finish reason
finished_reason: List[BaseFinishReason]
# Token counts
prompt_tokens: List[int]
completion_tokens: List[int]
cached_tokens: List[int]
# Logprobs
input_token_logprobs_val: List[float]
input_token_logprobs_idx: List[int]
output_token_logprobs_val: List[float]
output_token_logprobs_idx: List[int]
input_top_logprobs_val: List[List]
input_top_logprobs_idx: List[List]
output_top_logprobs_val: List[List]
output_top_logprobs_idx: List[List]
normalized_prompt_logprob: List[float]
@dataclass
class BatchEmbeddingOut:
# The request id
rids: List[str]
# The finish reason
finished_reasons: List[BaseFinishReason]
# The output embedding
embeddings: List[List[float]]
# The meta info
meta_info: List[Dict]
# The finish reason
finished_reason: List[BaseFinishReason]
# Token counts
prompt_tokens: List[int]
@dataclass
......
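In the BatchTokenIDOut / BatchStrOut / BatchEmbeddingOut hunks above, the per-request meta_info dicts are replaced by flat per-field lists (finish reasons, token counts, logprobs), all indexed by request position. A sketch of how a consumer regroups them into a per-request view; BatchStrOutSketch is a stand-in carrying only a subset of fields, not the real dataclass.

import dataclasses
from typing import List, Optional


@dataclasses.dataclass
class BatchStrOutSketch:
    rids: List[str]
    finished_reasons: List[Optional[dict]]
    output_strs: List[str]
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]


def per_request_view(batch: BatchStrOutSketch, i: int) -> dict:
    """All batched fields are parallel lists indexed by request position i."""
    return {
        "rid": batch.rids[i],
        "finish_reason": batch.finished_reasons[i],
        "text": batch.output_strs[i],
        "prompt_tokens": batch.prompt_tokens[i],
        "completion_tokens": batch.completion_tokens[i],
        "cached_tokens": batch.cached_tokens[i],
    }


batch = BatchStrOutSketch(
    rids=["r0", "r1"],
    finished_reasons=[None, {"type": "stop"}],
    output_strs=["Hi", "Bye"],
    prompt_tokens=[5, 7],
    completion_tokens=[1, 2],
    cached_tokens=[0, 3],
)
assert per_request_view(batch, 1)["finish_reason"] == {"type": "stop"}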
@@ -200,6 +200,9 @@ class Req:
origin_input_text: str,
origin_input_ids: Tuple[int],
sampling_params: SamplingParams,
return_logprob: bool = False,
top_logprobs_num: int = 0,
stream: bool = False,
origin_input_ids_unpadded: Optional[Tuple[int]] = None,
lora_path: Optional[str] = None,
input_embeds: Optional[List[List[float]]] = None,
@@ -217,10 +220,11 @@ class Req:
self.output_ids = [] # Each decode stage's output ids
self.fill_ids = None # fill_ids = origin_input_ids + output_ids
self.session_id = session_id
self.input_embeds = input_embeds
# Sampling info
self.sampling_params = sampling_params
self.lora_path = lora_path
self.input_embeds = input_embeds
# Memory pool info
self.req_pool_idx = None
@@ -228,8 +232,8 @@ class Req:
# Check finish
self.tokenizer = None
self.finished_reason = None
self.stream = False
self.to_abort = False
self.stream = stream
# For incremental decoding
# ----- | --------- read_ids -------|
@@ -241,13 +245,9 @@ class Req:
# 2: read_offset
# 3: last token
self.vid = 0 # version id to sync decode status with in detokenizer_manager
self.decoded_text = ""
self.surr_offset = None # Surrounding offset to defeat the cleanup algorithm
self.read_offset = None
# The number of decoded tokens for token usage report. Note that
# this does not include the jump forward tokens.
self.completion_tokens_wo_jump_forward = 0
self.decoded_text = ""
# For multimodal inputs
self.image_inputs: Optional[ImageInputs] = None
@@ -256,22 +256,34 @@ class Req:
self.prefix_indices = []
self.extend_input_len = 0
self.last_node = None
# Chunked prefill
self.is_being_chunked = 0
# For retraction
self.is_retracted = False
# Logprobs (arguments)
self.return_logprob = False
self.return_logprob = return_logprob
self.logprob_start_len = 0
self.top_logprobs_num = 0
self.top_logprobs_num = top_logprobs_num
# Logprobs (return value)
self.normalized_prompt_logprob = None
self.input_token_logprobs = None
self.input_top_logprobs = None
self.output_token_logprobs = []
self.output_top_logprobs = []
self.input_token_logprobs_val = None
self.input_token_logprobs_idx = None
self.input_top_logprobs_val = None
self.input_top_logprobs_idx = None
if return_logprob:
self.output_token_logprobs_val = []
self.output_token_logprobs_idx = []
self.output_top_logprobs_val = []
self.output_top_logprobs_idx = []
else:
self.output_token_logprobs_val = self.output_token_logprobs_idx = (
self.output_top_logprobs_val
) = self.output_top_logprobs_idx = None
# Logprobs (internal values)
# The tokens are prefilled but need to be considered as decode tokens
@@ -295,8 +307,8 @@ class Req:
else:
self.image_inputs.merge(image_inputs)
# whether request reached finished condition
def finished(self) -> bool:
# Whether request reached finished condition
return self.finished_reason is not None
def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
@@ -454,8 +466,10 @@ class Req:
k = k + 1
else:
break
self.output_token_logprobs = self.output_token_logprobs[:k]
self.output_top_logprobs = self.output_top_logprobs[:k]
self.output_token_logprobs_val = self.output_token_logprobs_val[:k]
self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
self.logprob_start_len = prompt_tokens + k
self.last_update_decode_tokens = len(self.output_ids) - k
......
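The Req changes above move return_logprob, top_logprobs_num, and stream into the constructor and allocate the split output-logprob buffers only when logprobs are requested. A minimal sketch of that bookkeeping under those assumptions; ReqSketch is illustrative, not the real class.

from typing import List, Optional


class ReqSketch:
    def __init__(
        self,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        stream: bool = False,
    ):
        self.return_logprob = return_logprob
        self.top_logprobs_num = top_logprobs_num
        self.stream = stream
        # Parallel output buffers exist only if logprobs will be returned.
        if return_logprob:
            self.output_token_logprobs_val: Optional[List[float]] = []
            self.output_token_logprobs_idx: Optional[List[int]] = []
        else:
            self.output_token_logprobs_val = None
            self.output_token_logprobs_idx = None


req = ReqSketch(return_logprob=True, top_logprobs_num=5, stream=True)
req.output_token_logprobs_val.append(-0.25)
req.output_token_logprobs_idx.append(42)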
@@ -515,6 +515,9 @@ class Scheduler:
recv_req.input_text,
recv_req.input_ids,
recv_req.sampling_params,
return_logprob=recv_req.return_logprob,
top_logprobs_num=recv_req.top_logprobs_num,
stream=recv_req.stream,
lora_path=recv_req.lora_path,
input_embeds=recv_req.input_embeds,
)
@@ -558,9 +561,6 @@ class Scheduler:
return
# Copy more attributes
req.return_logprob = recv_req.return_logprob
req.top_logprobs_num = recv_req.top_logprobs_num
req.stream = recv_req.stream
req.logprob_start_len = recv_req.logprob_start_len
if req.logprob_start_len == -1:
@@ -982,7 +982,6 @@ class Scheduler:
continue
if req.is_being_chunked <= 0:
req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_id)
req.check_finished()
@@ -1035,7 +1034,7 @@ class Scheduler:
# being chunked reqs' prefill is not finished
req.is_being_chunked -= 1
self.stream_output(batch.reqs, skip_stream_req)
self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
def process_batch_result_decode(self, batch: ScheduleBatch, result):
logits_output, next_token_ids, bid = result
@@ -1065,7 +1064,6 @@ class Scheduler:
self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
continue
req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_id)
req.check_finished()
@@ -1073,11 +1071,15 @@ class Scheduler:
self.tree_cache.cache_finished_req(req)
if req.return_logprob:
req.output_token_logprobs.append(
(next_token_logprobs[i], next_token_id)
)
req.output_token_logprobs_val.append(next_token_logprobs[i])
req.output_token_logprobs_idx.append(next_token_id)
if req.top_logprobs_num > 0:
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
req.output_top_logprobs_val.append(
logits_output.output_top_logprobs_val[i]
)
req.output_top_logprobs_idx.append(
logits_output.output_top_logprobs_idx[i]
)
if req.grammar is not None:
req.grammar.accept_token(next_token_id)
@@ -1088,7 +1090,7 @@ class Scheduler:
self.current_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
self.stream_output(batch.reqs)
self.stream_output(batch.reqs, batch.return_logprob)
self.token_to_kv_pool.free_group_end()
@@ -1108,9 +1110,8 @@ class Scheduler:
output: LogitsProcessorOutput,
):
"""Attach logprobs to the return values."""
req.output_token_logprobs.append(
(output.next_token_logprobs[i], next_token_ids[i])
)
req.output_token_logprobs_val.append(output.next_token_logprobs[i])
req.output_token_logprobs_idx.append(next_token_ids[i])
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
@@ -1118,38 +1119,36 @@ class Scheduler:
if req.normalized_prompt_logprob is None:
req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
if req.input_token_logprobs is None:
input_token_logprobs = output.input_token_logprobs[
if req.input_token_logprobs_val is None:
input_token_logprobs_val = output.input_token_logprobs[
pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
]
input_token_ids = req.fill_ids[
input_token_logprobs_idx = req.fill_ids[
len(req.fill_ids)
- num_input_logprobs
+ 1 : len(req.fill_ids)
- req.last_update_decode_tokens
]
# Clip the padded hash values from image tokens.
# Otherwise, it will lead to detokenization errors.
input_token_ids = [
input_token_logprobs_idx = [
x if x < self.model_config.vocab_size - 1 else 0
for x in input_token_ids
for x in input_token_logprobs_idx
]
req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))
if (
req.logprob_start_len == 0
): # The first token does not have logprob, pad it.
req.input_token_logprobs = [
(None, req.fill_ids[0])
] + req.input_token_logprobs
input_token_logprobs_val = [None] + input_token_logprobs_val
input_token_logprobs_idx = [req.fill_ids[0]] + input_token_logprobs_idx
req.input_token_logprobs_val = input_token_logprobs_val
req.input_token_logprobs_idx = input_token_logprobs_idx
if req.last_update_decode_tokens != 0:
# Some decode tokens are re-computed in an extend batch
req.output_token_logprobs.extend(
list(
zip(
req.output_token_logprobs_val.extend(
output.input_token_logprobs[
pt
+ num_input_logprobs
@@ -1158,132 +1157,156 @@ class Scheduler:
+ num_input_logprobs
- 1
],
)
req.output_token_logprobs_idx.extend(
req.fill_ids[
len(req.fill_ids)
- req.last_update_decode_tokens : len(req.fill_ids)
],
)
)
]
)
if req.top_logprobs_num > 0:
if req.input_top_logprobs is None:
req.input_top_logprobs = output.input_top_logprobs[i]
if req.input_top_logprobs_val is None:
req.input_top_logprobs_val = output.input_top_logprobs_val[i]
req.input_top_logprobs_idx = output.input_top_logprobs_idx[i]
if req.logprob_start_len == 0:
req.input_top_logprobs = [None] + req.input_top_logprobs
req.input_top_logprobs_val = [None] + req.input_top_logprobs_val
req.input_top_logprobs_idx = [None] + req.input_top_logprobs_idx
if req.last_update_decode_tokens != 0:
req.output_top_logprobs.extend(
output.input_top_logprobs[i][-req.last_update_decode_tokens :]
req.output_top_logprobs_val.extend(
output.input_top_logprobs_val[i][-req.last_update_decode_tokens :]
)
req.output_top_logprobs_idx.extend(
output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
)
req.output_top_logprobs.append(output.output_top_logprobs[i])
req.output_top_logprobs_val.append(output.output_top_logprobs_val[i])
req.output_top_logprobs_idx.append(output.output_top_logprobs_idx[i])
return num_input_logprobs
def stream_output(self, reqs: List[Req], skip_req: Optional[Req] = None):
def stream_output(
self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
):
"""Stream the output to detokenizer."""
output_rids = []
output_meta_info: List[dict] = []
output_finished_reason: List[BaseFinishReason] = []
rids = []
finished_reasons: List[BaseFinishReason] = []
if self.is_generation:
output_vids = []
vids = []
decoded_texts = []
output_read_ids = []
output_read_offsets = []
decode_ids_list = []
read_offsets = []
output_ids = []
output_skip_special_tokens = []
output_spaces_between_special_tokens = []
output_no_stop_trim = []
else: # embedding or reward model
output_embeddings = []
is_stream_iter = self.forward_ct_decode % self.stream_interval == 0
skip_special_tokens = []
spaces_between_special_tokens = []
no_stop_trim = []
prompt_tokens = []
completion_tokens = []
cached_tokens = []
if return_logprob:
input_token_logprobs_val = []
input_token_logprobs_idx = []
output_token_logprobs_val = []
output_token_logprobs_idx = []
input_top_logprobs_val = []
input_top_logprobs_idx = []
output_top_logprobs_val = []
output_top_logprobs_idx = []
normalized_prompt_logprob = []
else:
input_token_logprobs_val = input_token_logprobs_idx = (
output_token_logprobs_val
) = output_token_logprobs_idx = input_top_logprobs_val = (
input_top_logprobs_idx
) = output_top_logprobs_val = output_top_logprobs_idx = (
normalized_prompt_logprob
) = None
for req in reqs:
if req is skip_req:
continue
# TODO(lianmin): revisit this for overlap + retract + stream
if req.finished() or (
req.stream and (is_stream_iter or len(req.output_ids) == 1)
if (
req.finished()
# If stream, follow the given stream_interval
or (req.stream and len(req.output_ids) % self.stream_interval == 0)
# If not stream, we still want to output some tokens to get the benefit of incremental decoding.
or (not req.stream and len(req.output_ids) % 50 == 0)
):
output_rids.append(req.rid)
output_finished_reason.append(req.finished_reason)
if self.is_generation:
output_vids.append(req.vid)
rids.append(req.rid)
finished_reasons.append(
req.finished_reason.to_json() if req.finished_reason else None
)
vids.append(req.vid)
decoded_texts.append(req.decoded_text)
read_ids, read_offset = req.init_incremental_detokenize()
output_read_ids.append(read_ids)
output_read_offsets.append(read_offset)
decode_ids, read_offset = req.init_incremental_detokenize()
decode_ids_list.append(decode_ids)
read_offsets.append(read_offset)
if self.skip_tokenizer_init:
output_ids.append(req.output_ids)
output_skip_special_tokens.append(
req.sampling_params.skip_special_tokens
)
output_spaces_between_special_tokens.append(
skip_special_tokens.append(req.sampling_params.skip_special_tokens)
spaces_between_special_tokens.append(
req.sampling_params.spaces_between_special_tokens
)
output_no_stop_trim.append(req.sampling_params.no_stop_trim)
meta_info = {
"prompt_tokens": len(req.origin_input_ids),
"completion_tokens": len(req.output_ids),
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
"cached_tokens": req.cached_tokens,
"finish_reason": (
req.finished_reason.to_json()
if req.finished_reason is not None
else None
),
}
if req.return_logprob:
(
meta_info["input_token_logprobs"],
meta_info["output_token_logprobs"],
meta_info["input_top_logprobs"],
meta_info["output_top_logprobs"],
meta_info["normalized_prompt_logprob"],
) = (
req.input_token_logprobs,
req.output_token_logprobs,
req.input_top_logprobs,
req.output_top_logprobs,
req.normalized_prompt_logprob,
)
output_meta_info.append(meta_info)
else: # embedding or reward model
output_embeddings.append(req.embedding)
meta_info = {
"prompt_tokens": len(req.origin_input_ids),
}
output_meta_info.append(meta_info)
no_stop_trim.append(req.sampling_params.no_stop_trim)
prompt_tokens.append(len(req.origin_input_ids))
completion_tokens.append(len(req.output_ids))
cached_tokens.append(req.cached_tokens)
if return_logprob:
input_token_logprobs_val.append(req.input_token_logprobs_val)
input_token_logprobs_idx.append(req.input_token_logprobs_idx)
output_token_logprobs_val.append(req.output_token_logprobs_val)
output_token_logprobs_idx.append(req.output_token_logprobs_idx)
input_top_logprobs_val.append(req.input_top_logprobs_val)
input_top_logprobs_idx.append(req.input_top_logprobs_idx)
output_top_logprobs_val.append(req.output_top_logprobs_val)
output_top_logprobs_idx.append(req.output_top_logprobs_idx)
normalized_prompt_logprob.append(req.normalized_prompt_logprob)
# Send to detokenizer
if output_rids:
if self.is_generation:
if rids:
self.send_to_detokenizer.send_pyobj(
BatchTokenIDOut(
output_rids,
output_vids,
rids,
finished_reasons,
vids,
decoded_texts,
output_read_ids,
output_read_offsets,
decode_ids_list,
read_offsets,
output_ids,
output_skip_special_tokens,
output_spaces_between_special_tokens,
output_meta_info,
output_finished_reason,
output_no_stop_trim,
skip_special_tokens,
spaces_between_special_tokens,
no_stop_trim,
prompt_tokens,
completion_tokens,
cached_tokens,
input_token_logprobs_val,
input_token_logprobs_idx,
output_token_logprobs_val,
output_token_logprobs_idx,
input_top_logprobs_val,
input_top_logprobs_idx,
output_top_logprobs_val,
output_top_logprobs_idx,
normalized_prompt_logprob,
)
)
else: # embedding or reward model
embeddings = []
prompt_tokens = []
for req in reqs:
assert req.finished()
rids.append(req.rid)
finished_reasons.append(req.finished_reason.to_json())
embeddings.append(req.embedding)
prompt_tokens.append(len(req.origin_input_ids))
self.send_to_detokenizer.send_pyobj(
BatchEmbeddingOut(
output_rids,
output_embeddings,
output_meta_info,
output_finished_reason,
)
BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
)
def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
......
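Scheduler.stream_output above now takes return_logprob explicitly and emits a request when it finishes, when a streaming request hits stream_interval, or, for non-streaming requests, every 50 tokens so incremental detokenization still makes progress. A small sketch of that decision; should_emit is a hypothetical helper that mirrors the condition in the diff.

def should_emit(
    finished: bool,
    stream: bool,
    num_output_ids: int,
    stream_interval: int,
    non_stream_interval: int = 50,
) -> bool:
    if finished:
        return True
    if stream:
        # Streaming requests follow the configured stream interval.
        return num_output_ids % stream_interval == 0
    # Non-streaming requests are still flushed periodically.
    return num_output_ids % non_stream_interval == 0


assert should_emit(finished=False, stream=True, num_output_ids=8, stream_interval=8)
assert not should_emit(finished=False, stream=False, num_output_ids=8, stream_interval=8)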
@@ -22,7 +22,7 @@ import signal
import sys
import time
import uuid
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import fastapi
import uvloop
@@ -76,6 +76,7 @@ class ReqState:
out_list: List
finished: bool
event: asyncio.Event
obj: Any
# For metrics
created_time: float
@@ -283,7 +284,7 @@ class TokenizerManager:
):
"""Wait for the response of one request."""
event = asyncio.Event()
state = ReqState([], False, event, created_time=created_time)
state = ReqState([], False, event, obj, created_time=created_time)
self.rid_to_state[obj.rid] = state
while True:
@@ -295,14 +296,6 @@ class TokenizerManager:
raise ValueError(f"Abort request {obj.rid}")
continue
if isinstance(obj, GenerateReqInput):
out = self.convert_logprob_style(
state.out_list[-1],
obj.return_logprob,
obj.top_logprobs_num,
obj.return_text_in_logprobs,
)
else: # isinstance(obj, (EmbeddingReqInput,))
out = state.out_list[-1]
state.out_list = []
@@ -315,7 +308,13 @@ class TokenizerManager:
break
state.event.clear()
if obj.stream:
yield out
else:
if request is not None and await request.is_disconnected():
self.abort_request(obj.rid)
raise ValueError(f"Abort request {obj.rid}")
async def _handle_batch_request(
self,
@@ -609,29 +608,55 @@ class TokenizerManager:
if state is None:
continue
recv_obj.meta_info[i]["id"] = rid
meta_info = {
"id": rid,
"finish_reason": recv_obj.finished_reasons[i],
"prompt_tokens": recv_obj.prompt_tokens[i],
}
if getattr(state.obj, "return_logprob", False):
self.convert_logprob_style(
meta_info,
state.obj.top_logprobs_num,
state.obj.return_text_in_logprobs,
recv_obj,
i,
)
if isinstance(recv_obj, BatchStrOut):
out_dict = {
"text": recv_obj.output_strs[i],
"meta_info": recv_obj.meta_info[i],
"meta_info": {
**meta_info,
"completion_tokens": recv_obj.completion_tokens[i],
"cached_tokens": recv_obj.cached_tokens[i],
},
}
elif isinstance(recv_obj, BatchTokenIDOut):
out_dict = {
"token_ids": recv_obj.output_ids[i],
"meta_info": recv_obj.meta_info[i],
"meta_info": {
**meta_info,
"completion_tokens": recv_obj.completion_tokens[i],
"cached_tokens": recv_obj.cached_tokens[i],
},
}
else:
assert isinstance(recv_obj, BatchEmbeddingOut)
out_dict = {
"embedding": recv_obj.embeddings[i],
"meta_info": recv_obj.meta_info[i],
"meta_info": meta_info,
}
state.out_list.append(out_dict)
state.finished = recv_obj.finished_reason[i] is not None
state.finished = recv_obj.finished_reasons[i] is not None
state.event.set()
if self.enable_metrics:
completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
completion_tokens = (
recv_obj.completion_tokens[i]
if recv_obj.completion_tokens
else 0
)
if state.first_token_time is None:
state.first_token_time = time.time()
@@ -647,7 +672,7 @@ class TokenizerManager:
if state.finished:
self.metrics_collector.inc_prompt_tokens(
recv_obj.meta_info[i]["prompt_tokens"]
recv_obj.prompt_tokens[i]
)
self.metrics_collector.inc_generation_tokens(
completion_tokens
@@ -696,57 +721,73 @@ class TokenizerManager:
def convert_logprob_style(
self,
ret: dict,
return_logprob: bool,
meta_info: dict,
top_logprobs_num: int,
return_text_in_logprobs: bool,
recv_obj: BatchStrOut,
recv_obj_index: int,
):
if return_logprob:
ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
recv_obj.input_token_logprobs_val[recv_obj_index],
recv_obj.input_token_logprobs_idx[recv_obj_index],
return_text_in_logprobs,
)
ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
recv_obj.output_token_logprobs_val[recv_obj_index],
recv_obj.output_token_logprobs_idx[recv_obj_index],
return_text_in_logprobs,
)
meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[
recv_obj_index
]
if top_logprobs_num > 0:
ret["meta_info"]["input_top_logprobs"] = (
self.detokenize_top_logprobs_tokens(
ret["meta_info"]["input_top_logprobs"],
meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
recv_obj.input_top_logprobs_val[recv_obj_index],
recv_obj.input_top_logprobs_idx[recv_obj_index],
return_text_in_logprobs,
)
meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
recv_obj.output_top_logprobs_val[recv_obj_index],
recv_obj.output_top_logprobs_idx[recv_obj_index],
return_text_in_logprobs,
)
ret["meta_info"]["output_top_logprobs"] = (
self.detokenize_top_logprobs_tokens(
ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
)
)
return ret
def detokenize_logprob_tokens(
self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
self,
token_logprobs_val: List[float],
token_logprobs_idx: List[int],
decode_to_text: bool,
):
# TODO(lianmin): This should run on DetokenizerManager
if not decode_to_text:
return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
assert self.tokenizer is not None
token_ids = [tid for _, tid in token_logprobs]
token_texts = self.tokenizer.batch_decode(token_ids)
return [
(logprob, token_id, token_text)
for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
(logprob, token_id, None)
for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
]
else:
assert self.tokenizer is not None
token_texts = self.tokenizer.batch_decode(token_logprobs_idx)
return list(zip(token_logprobs_val, token_logprobs_idx, token_texts))
def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
def detokenize_top_logprobs_tokens(
self,
token_logprobs_val: List[float],
token_logprobs_idx: List[int],
decode_to_text: bool,
):
# TODO: The current implementation only batches the detokenization for top-k tokens per single position.
# We should batch all top-k tokens in all positions.
for i, token_top_logprobs in enumerate(top_logprobs):
if token_top_logprobs:
top_logprobs[i] = self.detokenize_logprob_tokens(
token_top_logprobs, decode_to_text
ret = []
for i in range(len(token_logprobs_val)):
if token_logprobs_val[i]:
ret.append(
self.detokenize_logprob_tokens(
token_logprobs_val[i], token_logprobs_idx[i], decode_to_text
)
return top_logprobs
)
else:
ret.append(None)
return ret
class SignalHandler:
......
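In the TokenizerManager hunks above, convert_logprob_style and detokenize_logprob_tokens now read the parallel val/idx lists straight from the received batch and build (logprob, token_id, text) triples, with text filled in only when decoding to text is requested. A sketch of that contract with a stand-in tokenizer; everything outside the diff is hypothetical.

from typing import List, Optional, Tuple


def detokenize_sketch(
    token_logprobs_val: List[float],
    token_logprobs_idx: List[int],
    decode_to_text: bool,
    tokenizer=None,
) -> List[Tuple[float, int, Optional[str]]]:
    if not decode_to_text:
        return [(v, i, None) for v, i in zip(token_logprobs_val, token_logprobs_idx)]
    token_texts = tokenizer.batch_decode(token_logprobs_idx)
    return list(zip(token_logprobs_val, token_logprobs_idx, token_texts))


class FakeTokenizer:
    def batch_decode(self, ids: List[int]) -> List[str]:
        return [f"<tok{i}>" for i in ids]


out = detokenize_sketch([-0.1, -1.5], [5, 9], True, FakeTokenizer())
assert out == [(-0.1, 5, "<tok5>"), (-1.5, 9, "<tok9>")]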
@@ -400,9 +400,14 @@ class CudaGraphRunner:
forward_mode=ForwardMode.DECODE,
top_logprobs_nums=forward_batch.top_logprobs_nums,
)
logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
(
logits_output.output_top_logprobs_val,
logits_output.output_top_logprobs_idx,
) = LogitsProcessor.get_top_logprobs(
next_token_logprobs, logits_metadata
)[1]
)[
2:4
]
else:
logits_output = LogitsProcessorOutput(
next_token_logits=next_token_logits,
......
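The CudaGraphRunner hunk above unpacks elements 2:4 (the output-position values and indices) from the four-tuple now returned by get_top_logprobs. A toy illustration of that slicing; the stub below stands in for the real method.

def get_top_logprobs_stub():
    # (input_val, input_idx, output_val, output_idx); decode mode only needs the last two.
    input_val, input_idx = None, None
    output_val, output_idx = [[-0.2]], [[11]]
    return input_val, input_idx, output_val, output_idx


output_top_logprobs_val, output_top_logprobs_idx = get_top_logprobs_stub()[2:4]
assert output_top_logprobs_val == [[-0.2]] and output_top_logprobs_idx == [[11]]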
@@ -720,13 +720,13 @@ def run_and_check_memory_leak(
# Clean up everything
kill_process_tree(process.pid)
kill_process_tree(process.pid)
stdout.close()
stderr.close()
if os.path.exists(STDOUT_FILENAME):
os.remove(STDOUT_FILENAME)
if os.path.exists(STDERR_FILENAME):
os.remove(STDERR_FILENAME)
kill_process_tree(process.pid)
t.join()
# Assert success
@@ -734,7 +734,7 @@ def run_and_check_memory_leak(
has_leak = False
has_abort = False
for line in output_lines:
if "The server is fired" in line:
if "Uvicorn running" in line:
has_new_server = True
if "leak" in line:
has_leak = True
......
@@ -95,15 +95,6 @@ class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
self.assertIsInstance(js_obj["name"], str)
self.assertIsInstance(js_obj["population"], int)
# Make sure jump forward is triggered
# NOTE: The overlap scheduler does not support jump forward so we only do this test
# when --disable-overlap-schedule is set.
if self.check_jump_forward:
self.assertGreater(
ret["meta_info"]["completion_tokens"],
ret["meta_info"]["completion_tokens_wo_jump_forward"],
)
def test_json_generate(self):
self.run_decode(json_schema=self.json_schema)
......