Unverified Commit a9ef49c1 authored by Liangsheng Yin, committed by GitHub

Detokenize incrementally when streaming (#653)

parent 21ba3a88
@@ -136,7 +136,33 @@ class RadixAttention(nn.Module):
         return self.decode_forward(q, k, v, input_metadata)
 
     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-        key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
-        key_buffer[input_metadata.out_cache_loc] = cache_k
-        value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
-        value_buffer[input_metadata.out_cache_loc] = cache_v
+        kv_cache = input_metadata.token_to_kv_pool.kv_data[self.layer_id]
+        _store_kv_cache(cache_k, cache_v, kv_cache, input_metadata.out_cache_loc)
+
+
+try:
+
+    @torch.library.custom_op("mylib::store_kv_cache", mutates_args={"kv_cache"})
+    def _store_kv_cache(
+        k: torch.Tensor,
+        v: torch.Tensor,
+        kv_cache: torch.Tensor,
+        cache_loc: torch.Tensor,
+    ) -> None:
+        kv_cache[cache_loc, 0] = k
+        kv_cache[cache_loc, 1] = v
+
+    @_store_kv_cache.register_fake
+    def _(k, v, kv_cache, cache_loc):
+        pass
+
+except:
+
+    def _store_kv_cache(
+        k: torch.Tensor,
+        v: torch.Tensor,
+        kv_cache: torch.Tensor,
+        cache_loc: torch.Tensor,
+    ) -> None:
+        kv_cache[cache_loc, 0] = k
+        kv_cache[cache_loc, 1] = v
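Note: the try/except above registers the in-place KV-cache write as a PyTorch custom op when `torch.library.custom_op` is available (PyTorch 2.4+), so the mutation stays opaque to tracing, with `register_fake` supplying the meta implementation; older versions fall back to the plain function. A minimal, self-contained sketch of the same pattern; the `demo::write_cache` name and the toy shapes are illustrative, not from this repo:

import torch

@torch.library.custom_op("demo::write_cache", mutates_args={"cache"})
def write_cache(vals: torch.Tensor, cache: torch.Tensor, loc: torch.Tensor) -> None:
    # In-place scatter: copy `vals` into the cache rows selected by `loc`.
    cache[loc] = vals

@write_cache.register_fake
def _(vals, cache, loc):
    # The op has no outputs and only mutates `cache`, so the fake impl is a no-op.
    pass

cache = torch.zeros(8, 4)
write_cache(torch.ones(2, 4), cache, torch.tensor([1, 5]))
assert cache[1].sum() == 4 and cache[5].sum() == 4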
@@ -82,6 +82,14 @@ class Req:
         self.input_ids = None  # input_ids = origin_input_ids + output_ids
 
         # For incremental decoding
+        # ----- | --------- read_ids -------|
+        # ----- |   surr_ids  |
+        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
+        # ----- ^ ----------- ^ ----------- ^
+        # ----- 1 ----------- 2 ----------- 3
+        # 1: surr_offset
+        # 2: read_offset
+        # 3: last token
         self.decoded_text = ""
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
@@ -132,7 +140,7 @@ class Req:
         return self.finished_reason is not None
 
     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
-    def init_detokenize_incrementally(self):
+    def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None
 
         if first_iter:
@@ -142,13 +150,11 @@ class Req:
             )
 
         all_ids = self.origin_input_ids_unpadded + self.output_ids
-        surr_ids = all_ids[self.surr_offset : self.read_offset]
-        read_ids = all_ids[self.surr_offset :]
-
-        return surr_ids, read_ids, len(all_ids)
-
-    def detokenize_incrementally(self, inplace: bool = True):
-        surr_ids, read_ids, num_all_tokens = self.init_detokenize_incrementally()
+        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
+
+    def get_next_inc_detokenization(self):
+        read_ids, read_offset = self.init_incremental_detokenize()
+        surr_ids = read_ids[:read_offset]
 
         surr_text = self.tokenizer.decode(
             surr_ids,
@@ -162,13 +168,7 @@ class Req:
         )
 
         if len(new_text) > len(surr_text) and not new_text.endswith("�"):
-            new_text = new_text[len(surr_text) :]
-            if inplace:
-                self.decoded_text += new_text
-                self.surr_offset = self.read_offset
-                self.read_offset = num_all_tokens
-
-            return True, new_text
+            return True, new_text[len(surr_text) :]
 
         return False, ""
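Note: the two methods above separate "compute the decode window" from a non-mutating "peek at the new text". As a standalone restatement of that core logic, the sketch below assumes only a tokenizer object exposing `decode(ids) -> str` (e.g. a Hugging Face tokenizer); the helper name and return convention are illustrative:

def incremental_decode(tokenizer, all_ids, surr_offset, read_offset):
    """Return (new_text, new_surr_offset, new_read_offset)."""
    # Decode the surrounding window alone, then the window plus the new tokens.
    surr_text = tokenizer.decode(all_ids[surr_offset:read_offset])
    full_text = tokenizer.decode(all_ids[surr_offset:])
    if len(full_text) > len(surr_text) and not full_text.endswith("�"):
        # The new tokens form complete characters: emit the suffix and advance.
        return full_text[len(surr_text):], read_offset, len(all_ids)
    # Incomplete character at the end: emit nothing and keep the offsets.
    return "", surr_offset, read_offset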
@@ -501,7 +501,7 @@ class Batch:
 
                 cur_output_ids = req.output_ids
                 req.output_ids.extend(suffix_ids)
-                decode_res, new_text = req.detokenize_incrementally(inplace=False)
+                decode_res, new_text = req.get_next_inc_detokenization()
                 if not decode_res:
                     req.output_ids = cur_output_ids
                     continue
@@ -590,8 +590,8 @@ class ModelTpServer:
     def handle_finished_requests(self, batch: Batch):
         output_rids = []
         decoded_texts = []
-        surr_output_ids = []
-        read_output_ids = []
+        output_read_ids = []
+        output_read_offsets = []
         output_skip_special_tokens = []
         output_spaces_between_special_tokens = []
         output_meta_info = []
@@ -615,9 +615,9 @@ class ModelTpServer:
             ):
                 output_rids.append(req.rid)
                 decoded_texts.append(req.decoded_text)
-                surr_ids, read_ids, _ = req.init_detokenize_incrementally()
-                surr_output_ids.append(surr_ids)
-                read_output_ids.append(read_ids)
+                read_ids, read_offset = req.init_incremental_detokenize()
+                output_read_ids.append(read_ids)
+                output_read_offsets.append(read_offset)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
                 )
@@ -654,8 +654,8 @@ class ModelTpServer:
                 BatchTokenIDOut(
                     output_rids,
                     decoded_texts,
-                    surr_output_ids,
-                    read_output_ids,
+                    output_read_ids,
+                    output_read_offsets,
                     output_skip_special_tokens,
                     output_spaces_between_special_tokens,
                     output_meta_info,
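Note: `surr_output_ids` was always a prefix of `read_output_ids`, so the server now ships a single id list per request plus an integer `read_offset`, and the receiver recovers both views by slicing. A toy illustration with made-up token ids:

read_ids = [11, 12, 13, 14, 15]    # tokens from surr_offset onward (hypothetical ids)
read_offset = 2                    # boundary between surrounding and newly read tokens
surr_ids = read_ids[:read_offset]  # [11, 12] -- previously sent as a separate list
new_ids = read_ids[read_offset:]   # [13, 14, 15] -- the tokens to detokenize next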
"""DetokenizerManager is a process that detokenizes the token ids.""" """DetokenizerManager is a process that detokenizes the token ids."""
import asyncio import asyncio
import dataclasses
import inspect import inspect
from typing import List
import uvloop import uvloop
import zmq import zmq
...@@ -16,6 +18,14 @@ from sglang.utils import find_printable_text, get_exception_traceback, graceful_ ...@@ -16,6 +18,14 @@ from sglang.utils import find_printable_text, get_exception_traceback, graceful_
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@dataclasses.dataclass
class DecodeStatus:
decoded_text: str
decode_ids: List[int]
surr_offset: int
read_offset: int
class DetokenizerManager: class DetokenizerManager:
def __init__( def __init__(
self, self,
...@@ -35,19 +45,42 @@ class DetokenizerManager: ...@@ -35,19 +45,42 @@ class DetokenizerManager:
trust_remote_code=server_args.trust_remote_code, trust_remote_code=server_args.trust_remote_code,
) )
self.decode_status = {}
async def handle_loop(self): async def handle_loop(self):
while True: while True:
recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj() recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
assert isinstance(recv_obj, BatchTokenIDOut) assert isinstance(recv_obj, BatchTokenIDOut)
bs = len(recv_obj.rids)
# FIXME: incremental detokenize is not compatible with jump forward
# Initialize decode status
read_ids, surr_ids = [], []
for i in range(bs):
rid = recv_obj.rids[i]
if rid not in self.decode_status:
s = DecodeStatus(
decoded_text=recv_obj.decoded_texts[i],
decode_ids=recv_obj.decode_ids[i],
surr_offset=0,
read_offset=recv_obj.read_offsets[i],
)
self.decode_status[rid] = s
else:
s = self.decode_status[rid]
s.decode_ids = recv_obj.decode_ids[i]
read_ids.append(s.decode_ids[s.surr_offset :])
surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])
# TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
surr_texts = self.tokenizer.batch_decode( surr_texts = self.tokenizer.batch_decode(
recv_obj.surr_output_ids, surr_ids,
skip_special_tokens=recv_obj.skip_special_tokens[0], skip_special_tokens=recv_obj.skip_special_tokens[0],
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0], spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
) )
read_texts = self.tokenizer.batch_decode( read_texts = self.tokenizer.batch_decode(
recv_obj.read_output_ids, read_ids,
skip_special_tokens=recv_obj.skip_special_tokens[0], skip_special_tokens=recv_obj.skip_special_tokens[0],
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0], spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
) )
...@@ -55,11 +88,20 @@ class DetokenizerManager: ...@@ -55,11 +88,20 @@ class DetokenizerManager:
# Trim stop str # Trim stop str
# TODO(lmzheng): handle the case where multiple stop strs are hit # TODO(lmzheng): handle the case where multiple stop strs are hit
output_strs = [] output_strs = []
for i in range(len(recv_obj.rids)): for i in range(bs):
s = self.decode_status[recv_obj.rids[i]]
new_text = read_texts[i][len(surr_texts[i]) :] new_text = read_texts[i][len(surr_texts[i]) :]
if recv_obj.finished_reason[i] is None: if recv_obj.finished_reason[i] is None:
new_text = find_printable_text(new_text) # Streaming chunk: update the decode status
output_strs.append(recv_obj.decoded_texts[i] + new_text) if len(new_text) > 0 and not new_text.endswith("�"):
s.decoded_text = s.decoded_text + new_text
s.surr_offset = s.read_offset
s.read_offset = len(s.decode_ids)
new_text = ""
else:
new_text = find_printable_text(new_text)
output_strs.append(s.decoded_text + new_text)
if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR): if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
pos = output_strs[i].find(recv_obj.finished_reason[i].matched) pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
......
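Note: per request, `DecodeStatus` carries the committed text and the two offsets across streaming chunks, and text is only committed once the decoded suffix no longer ends with the "�" replacement character. A standalone sketch of that bookkeeping; it omits the `find_printable_text` fallback and takes any `decode(ids) -> str` callable, so names and structure here are illustrative:

from dataclasses import dataclass, field
from typing import Callable, List

@dataclass
class StreamState:
    decoded_text: str = ""
    decode_ids: List[int] = field(default_factory=list)
    surr_offset: int = 0
    read_offset: int = 0

def step(state: StreamState, decode: Callable[[List[int]], str], new_ids: List[int]) -> str:
    """Consume a chunk of token ids and return the text that is safe to show so far."""
    state.decode_ids.extend(new_ids)
    surr_text = decode(state.decode_ids[state.surr_offset : state.read_offset])
    read_text = decode(state.decode_ids[state.surr_offset :])
    new_text = read_text[len(surr_text):]
    if new_text and not new_text.endswith("�"):
        # Commit the finished characters and advance both offsets.
        state.decoded_text += new_text
        state.surr_offset = state.read_offset
        state.read_offset = len(state.decode_ids)
    return state.decoded_text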
@@ -111,8 +111,8 @@ class TokenizedGenerateReqInput:
 class BatchTokenIDOut:
     rids: List[str]
     decoded_texts: List[str]
-    surr_output_ids: List[List[int]]
-    read_output_ids: List[List[int]]
+    decode_ids: List[int]
+    read_offsets: List[int]
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]