Commit 8f1dc2a4 authored by MaYuhang's avatar MaYuhang
Browse files

issue/265 perf(llm): replace O(n²) full-sequence detokenize with incremental decode

parent 5fb56f97
...@@ -269,11 +269,15 @@ class LLMEngine: ...@@ -269,11 +269,15 @@ class LLMEngine:
req.is_prefill = False req.is_prefill = False
req.generated_token_ids.append(token_id) req.generated_token_ids.append(token_id)
decoded_text = self.detokenize(req.generated_token_ids) pending_tokens = req.generated_token_ids[req._pending_token_offset :]
req.generated_text = decoded_text delta = self.tokenizer.decode(pending_tokens)
holds_back_incomplete_utf8 = bool(decoded_text) and decoded_text.endswith( holds_back = bool(delta) and delta.endswith("\ufffd")
"\ufffd"
) last_committed_text = req.generated_text
if not holds_back:
req.generated_text = last_committed_text + delta
req._pending_token_offset = len(req.generated_token_ids)
is_finished = self._check_request_finished(req, token_id) is_finished = self._check_request_finished(req, token_id)
...@@ -281,25 +285,28 @@ class LLMEngine: ...@@ -281,25 +285,28 @@ class LLMEngine:
# For offline generation (no output queue), keep the fast incremental path. # For offline generation (no output queue), keep the fast incremental path.
if req._output_queue is None: if req._output_queue is None:
if is_finished: if is_finished:
if holds_back_incomplete_utf8:
req.generated_text = decoded_text[:-1]
req.mark_finished(req.finish_reason) req.mark_finished(req.finish_reason)
else: else:
if (holds_back_incomplete_utf8 and not is_finished) or ( if holds_back and not is_finished:
is_finished
and req.finish_reason
in (FinishReason.LENGTH, FinishReason.STOP_STRING)
):
token_text = "" token_text = ""
else: else:
last_len = getattr(req, "_stream_last_yielded_length", 0) if is_finished and req.finish_reason in (
token_text = decoded_text[last_len:] FinishReason.EOS_TOKEN,
if token_text: FinishReason.LENGTH,
req._stream_last_yielded_length = len(decoded_text) FinishReason.STOP_STRING,
):
token_text = ""
else:
token_text = req.generated_text[
req._stream_last_yielded_length :
]
if token_text:
req._stream_last_yielded_length = len(req.generated_text)
if is_finished:
req.mark_finished(req.finish_reason)
if is_finished:
req.mark_finished(req.finish_reason)
output = TokenOutput( output = TokenOutput(
request_id=req.request_id, request_id=req.request_id,
token_id=token_id, token_id=token_id,
......
...@@ -152,6 +152,7 @@ class InferenceRequest: ...@@ -152,6 +152,7 @@ class InferenceRequest:
# Streaming helpers (vLLM-style UTF-8 buffering at the chunking layer) # Streaming helpers (vLLM-style UTF-8 buffering at the chunking layer)
# Used by the engine to compute "delta" text chunks from a full decode. # Used by the engine to compute "delta" text chunks from a full decode.
self._stream_last_yielded_length: int = 0 self._stream_last_yielded_length: int = 0
self._pending_token_offset: int = 0
@property @property
def output_queue(self) -> janus.Queue: def output_queue(self) -> janus.Queue:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment