Unverified Commit 9c902b19 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Decode Incrementally (#517)

parent 111991fe
import sglang as sgl
character_regex = (
r"""\{\n"""
+ r""" "姓名": "[^"]{1,32}",\n"""
+ r""" "学院": "(格兰芬多|赫奇帕奇|拉文克劳|斯莱特林)",\n"""
+ r""" "血型": "(纯血|混血|麻瓜)",\n"""
+ r""" "职业": "(学生|教师|傲罗|魔法部|食死徒|凤凰社成员)",\n"""
+ r""" "魔杖": \{\n"""
+ r""" "材质": "[^"]{1,32}",\n"""
+ r""" "杖芯": "[^"]{1,32}",\n"""
+ r""" "长度": [0-9]{1,2}\.[0-9]{0,2}\n"""
+ r""" \},\n"""
+ r""" "存活": "(存活|死亡)",\n"""
+ r""" "守护神": "[^"]{1,32}",\n"""
+ r""" "博格特": "[^"]{1,32}"\n"""
+ r"""\}"""
)
@sgl.function
def character_gen(s, name):
s += name + " 是一名哈利波特系列小说中的角色。请填写以下关于这个角色的信息。"
s += """\
这是一个例子
{
"姓名": "哈利波特",
"学院": "格兰芬多",
"血型": "混血",
"职业": "学生",
"魔杖": {
"材质": "冬青木",
"杖芯": "凤凰尾羽",
"长度": 11.0
},
"存活": "存活",
"守护神": "麋鹿",
"博格特": "摄魂怪"
}
"""
s += f"现在请你填写{name}的信息:\n"
s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
def main():
backend = sgl.RuntimeEndpoint("http://localhost:30000")
sgl.set_default_backend(backend)
ret = character_gen.run(name="赫敏格兰杰", temperature=0)
print(ret.text())
if __name__ == "__main__":
main()
......@@ -3,8 +3,8 @@ from typing import Dict, Optional, Union
from outlines.caching import cache as disk_cache
from outlines.caching import disable_cache
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.regex import FSMInfo, make_deterministic_fsm
from outlines.fsm.guide import RegexGuide
from outlines.fsm.regex import FSMInfo, make_deterministic_fsm, make_byte_level_fsm
from outlines.models.transformers import TransformerTokenizer
from pydantic import BaseModel
......@@ -28,11 +28,12 @@ except ImportError:
__all__ = [
"RegexFSM",
"RegexGuide",
"FSMInfo",
"make_deterministic_fsm",
"build_regex_from_object",
"TransformerTokenizer",
"disk_cache",
"disable_cache",
"make_byte_level_fsm",
]
"""Cache for the compressed finite state machine."""
from sglang.srt.constrained import RegexFSM, TransformerTokenizer
from sglang.srt.constrained import RegexGuide, TransformerTokenizer
from sglang.srt.constrained.base_cache import BaseCache
......@@ -26,4 +26,4 @@ class FSMCache(BaseCache):
)
def init_value(self, regex):
return RegexFSM(regex, self.outlines_tokenizer)
return RegexGuide(regex, self.outlines_tokenizer)
......@@ -2,20 +2,41 @@
Faster constrained decoding.
Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
"""
import interegular
from sglang.srt.constrained import FSMInfo, disk_cache, make_deterministic_fsm
import interegular
import dataclasses
from collections import defaultdict
import outlines.caching
from sglang.srt.constrained import (
FSMInfo,
disk_cache,
make_deterministic_fsm,
make_byte_level_fsm,
)
from sglang.srt.constrained.base_cache import BaseCache
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
@dataclasses.dataclass
class JumpEdge:
symbol: str = None
symbol_next_state: int = None
byte: int = None
byte_next_state: int = None
class JumpForwardMap:
def __init__(self, regex_string):
@disk_cache()
def _init_state_to_jump_forward(regex_string):
regex_pattern = interegular.parse_pattern(regex_string)
regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
byte_fsm = make_byte_level_fsm(
regex_pattern.to_fsm().reduce(), keep_utf8=True
)
regex_fsm, _ = make_deterministic_fsm(byte_fsm)
fsm_info: FSMInfo = regex_fsm.fsm_info
......@@ -25,40 +46,91 @@ class JumpForwardMap:
id_to_symbol.setdefault(id_, []).append(symbol)
transitions = fsm_info.transitions
dirty_states = set()
outgoings_ct = defaultdict(int)
state_to_jump_forward = {}
for (state, id_), next_state in transitions.items():
if state in dirty_states:
if id_ == fsm_info.alphabet_anything_value:
continue
if state in state_to_jump_forward:
dirty_states.add(state)
del state_to_jump_forward[state]
continue
if len(id_to_symbol[id_]) > 1:
dirty_states.add(state)
symbols = id_to_symbol[id_]
for c in symbols:
if len(c) > 1:
# Skip byte level transitions
continue
outgoings_ct[state] += 1
if outgoings_ct[state] > 1:
if state in state_to_jump_forward:
del state_to_jump_forward[state]
break
state_to_jump_forward[state] = JumpEdge(
symbol=c,
symbol_next_state=next_state,
)
# Process the byte level jump forward
outgoings_ct = defaultdict(int)
for (state, id_), next_state in transitions.items():
if id_ == fsm_info.alphabet_anything_value:
continue
state_to_jump_forward[state] = (id_to_symbol[id_][0], next_state)
symbols = id_to_symbol[id_]
for c in symbols:
byte_ = None
if len(c) == 1 and ord(c) < 0x80:
# ASCII character
byte_ = ord(c)
elif len(c) == 2:
byte_ = int(symbols[0], 16)
if byte_ is not None:
outgoings_ct[state] += 1
if outgoings_ct[state] > 1:
if state in state_to_jump_forward:
del state_to_jump_forward[state]
break
e = state_to_jump_forward.get(state, JumpEdge())
e.byte = byte_
e.byte_next_state = next_state
state_to_jump_forward[state] = e
return state_to_jump_forward
self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)
def valid_states(self):
return self.state_to_jump_forward.keys()
def jump_forward_symbol(self, state):
jump_forward_str = ""
next_state = state
while state in self.state_to_jump_forward:
e = self.state_to_jump_forward[state]
if e.symbol is None:
break
jump_forward_str += e.symbol
next_state = e.symbol_next_state
state = next_state
return jump_forward_str, next_state
def jump_forward(self, state):
def jump_forward_byte(self, state):
if state not in self.state_to_jump_forward:
return None
jump_forward_str = ""
jump_forward_bytes = []
next_state = None
while state in self.state_to_jump_forward:
symbol, next_state = self.state_to_jump_forward[state]
jump_forward_str += symbol
e = self.state_to_jump_forward[state]
assert e.byte is not None and e.byte_next_state is not None
jump_forward_bytes.append((e.byte, e.byte_next_state))
next_state = e.byte_next_state
state = next_state
return jump_forward_str, next_state
return jump_forward_bytes
def is_jump_forward_symbol_state(self, state):
return (
state in self.state_to_jump_forward
and self.state_to_jump_forward[state].symbol is not None
)
class JumpForwardCache(BaseCache):
......@@ -69,12 +141,21 @@ class JumpForwardCache(BaseCache):
return JumpForwardMap(regex)
def test_main():
regex_string = r"The google's DNS sever address is " + IP_REGEX
def test_main(regex_string):
jump_forward_map = JumpForwardMap(regex_string)
for state in jump_forward_map.valid_states():
print(state, f'"{jump_forward_map.jump_forward(state)}"')
for state, e in jump_forward_map.state_to_jump_forward.items():
if e.symbol is not None:
jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)
print(f"{state} -> {next_state}", jump_forward_str)
bytes_ = jump_forward_map.jump_forward_byte(state)
print(f"{state} -> {bytes_[-1][1]}", [hex(b) for b, _ in bytes_])
if __name__ == "__main__":
test_main()
import outlines
outlines.caching.clear_cache()
test_main(r"The google's DNS sever address is " + IP_REGEX)
test_main(r"霍格沃茨特快列车|霍比特人比尔博")
# 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ...
# 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ...
......@@ -3,12 +3,17 @@
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import List
import warnings
import numpy as np
import torch
from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.constrained import RegexGuide
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
class ForwardMode(IntEnum):
......@@ -64,12 +69,15 @@ class Req:
def __init__(self, rid, origin_input_text, origin_input_ids):
self.rid = rid
self.origin_input_text = origin_input_text
self.origin_input_ids_unpadded = origin_input_ids # Before image padding
self.origin_input_ids = origin_input_ids
self.origin_input_ids_unpadded = origin_input_ids # before image padding
self.prev_output_str = ""
self.prev_output_ids = []
self.output_ids = []
self.input_ids = None # input_ids = origin_input_ids + prev_output_ids
self.output_ids = [] # Each decode stage's output ids
self.input_ids = None # input_ids = origin_input_ids + output_ids
# For incremental decode
self.decoded_text = ""
self.surr_offset = None # Surrounding offset to defeat the cleanup algorithm
self.read_offset = None
# The number of decoded tokens for token usage report. Note that
# this does not include the jump forward tokens.
......@@ -109,20 +117,54 @@ class Req:
self.last_update_decode_tokens = 0
# Constrained decoding
self.regex_fsm = None
self.regex_fsm_state = 0
self.jump_forward_map = None
self.regex_fsm: RegexGuide = None
self.regex_fsm_state: int = 0
self.jump_forward_map: JumpForwardMap = None
# whether request reached finished condition
def finished(self) -> bool:
return self.finished_reason is not None
def partial_decode(self, ids):
first_token = self.tokenizer.convert_ids_to_tokens(ids[0])
first_token = (
first_token.decode() if isinstance(first_token, bytes) else first_token
# Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
def init_detokenize_incrementally(self):
first_iter = self.surr_offset is None or self.read_offset is None
if first_iter:
self.read_offset = len(self.origin_input_ids_unpadded)
self.surr_offset = max(
self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
)
all_ids = self.origin_input_ids_unpadded + self.output_ids
surr_ids = all_ids[self.surr_offset : self.read_offset]
read_ids = all_ids[self.surr_offset :]
return surr_ids, read_ids, len(all_ids)
def detokenize_incrementally(self, inplace: bool = True):
surr_ids, read_ids, num_all_tokens = self.init_detokenize_incrementally()
surr_text = self.tokenizer.decode(
surr_ids,
skip_special_tokens=self.sampling_params.skip_special_tokens,
spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
)
return (" " if first_token.startswith("▁") else "") + self.tokenizer.decode(ids)
new_text = self.tokenizer.decode(
read_ids,
skip_special_tokens=self.sampling_params.skip_special_tokens,
spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
)
if len(new_text) > len(surr_text) and not new_text.endswith("�"):
new_text = new_text[len(surr_text) :]
if inplace:
self.decoded_text += new_text
self.surr_offset = self.read_offset
self.read_offset = num_all_tokens
return True, new_text
return False, ""
def max_new_tokens(self):
return self.sampling_params.max_new_tokens
......@@ -131,18 +173,17 @@ class Req:
if self.finished():
return
if (
len(self.prev_output_ids) + len(self.output_ids)
>= self.sampling_params.max_new_tokens
):
self.finished_reason = FINISH_LENGTH(len(self.prev_output_ids) + len(self.output_ids))
if len(self.output_ids) >= self.sampling_params.max_new_tokens:
self.finished_reason = FINISH_LENGTH(len(self.output_ids))
return
if (
self.output_ids[-1] == self.tokenizer.eos_token_id
and not self.sampling_params.ignore_eos
):
self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.tokenizer.eos_token_id)
self.finished_reason = FINISH_MATCHED_TOKEN(
matched=self.tokenizer.eos_token_id
)
return
if len(self.sampling_params.stop_strs) > 0:
......@@ -151,61 +192,59 @@ class Req:
)
for stop_str in self.sampling_params.stop_strs:
# FIXME: (minor) try incremental match in prev_output_str
if stop_str in tail_str or stop_str in self.prev_output_str:
if stop_str in tail_str or stop_str in self.decoded_text:
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
return
def jump_forward_and_retokenize(self, jump_forward_str, next_state):
# FIXME: This logic does not really solve the problem of determining whether
# there should be a leading space.
cur_output_str = self.partial_decode(self.output_ids)
# TODO(lsyin): apply re-tokenize only for decode tokens so that we do not need origin_input_text anymore
if self.origin_input_text is None:
# Recovering text can only use unpadded ids
self.origin_input_text = self.tokenizer.decode(
self.origin_input_ids_unpadded
)
all_text = (
self.origin_input_text
+ self.prev_output_str
+ cur_output_str
+ jump_forward_str
)
all_text = self.origin_input_text + self.decoded_text + jump_forward_str
all_ids = self.tokenizer.encode(all_text)
prompt_tokens = len(self.origin_input_ids_unpadded)
self.origin_input_ids = all_ids[:prompt_tokens]
self.origin_input_ids_unpadded = self.origin_input_ids
# NOTE: the output ids may not strictly correspond to the output text
old_prev_output_ids = self.prev_output_ids
self.prev_output_ids = all_ids[prompt_tokens:]
self.prev_output_str = self.prev_output_str + cur_output_str + jump_forward_str
self.output_ids = []
if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
# TODO(lsyin): fix token fusion
warnings.warn(
"Token fusion between input and output, try to avoid this by removing the space at the end of the input."
)
return False
old_output_ids = self.output_ids
self.output_ids = all_ids[prompt_tokens:]
self.decoded_text = self.decoded_text + jump_forward_str
self.surr_offset = prompt_tokens
self.read_offset = len(all_ids)
# NOTE: A trick to reduce the surrouding tokens decoding overhead
for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
surr_text_ = self.tokenizer.decode(
all_ids[self.read_offset - i : self.read_offset]
)
if not surr_text_.endswith("�"):
self.surr_offset = self.read_offset - i
break
self.regex_fsm_state = next_state
if self.return_logprob:
# For fast-forward part's logprobs
k = 0
for i, old_id in enumerate(old_prev_output_ids):
if old_id == self.prev_output_ids[i]:
for i, old_id in enumerate(old_output_ids):
if old_id == self.output_ids[i]:
k = k + 1
else:
break
self.decode_token_logprobs = self.decode_token_logprobs[:k]
self.decode_top_logprobs = self.decode_top_logprobs[:k]
self.logprob_start_len = prompt_tokens + k
self.last_update_decode_tokens = len(self.prev_output_ids) - k
self.last_update_decode_tokens = len(self.output_ids) - k
# print("=" * 100)
# print(f"Catch jump forward:\n{jump_forward_str}")
# print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
# print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
# print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
# print("*" * 100)
return True
def __repr__(self):
return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
......@@ -381,7 +420,10 @@ class Batch:
sorted_indices = [i for i in range(len(self.reqs))]
# TODO(lsyin): improve the priority of retraction
sorted_indices.sort(
key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
key=lambda i: (
len(self.reqs[i].output_ids),
-len(self.reqs[i].origin_input_ids),
),
reverse=True,
)
......@@ -403,14 +445,9 @@ class Batch:
# release the last node
self.tree_cache.dec_lock_ref(req.last_node)
cur_output_str = req.partial_decode(req.output_ids)
req.prev_output_str = req.prev_output_str + cur_output_str
req.prev_output_ids.extend(req.output_ids)
req.prefix_indices = None
req.last_node = None
req.extend_input_len = 0
req.output_ids = []
# For incremental logprobs
req.last_update_decode_tokens = 0
......@@ -428,18 +465,53 @@ class Batch:
for i, req in enumerate(self.reqs):
if req.jump_forward_map is not None:
res = req.jump_forward_map.jump_forward(req.regex_fsm_state)
if res is not None:
jump_forward_str, next_state = res
if len(jump_forward_str) <= 1:
jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
req.regex_fsm_state
)
if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
suffix_bytes = []
continuation_range = range(0x80, 0xC0)
cur_state = req.regex_fsm_state
while (
len(jump_forward_bytes)
and jump_forward_bytes[0][0] in continuation_range
):
# continuation bytes
byte_edge = jump_forward_bytes.pop(0)
suffix_bytes.append(byte_edge[0])
cur_state = byte_edge[1]
suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)
# Current ids, for cache and revert
cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
cur_output_ids = req.output_ids
req.output_ids.extend(suffix_ids)
decode_res, new_text = req.detokenize_incrementally(inplace=False)
if not decode_res:
req.output_ids = cur_output_ids
continue
if req_pool_indices_cpu is None:
req_pool_indices_cpu = self.req_pool_indices.tolist()
jump_forward_str, next_state = (
req.jump_forward_map.jump_forward_symbol(cur_state)
)
# Make the incrementally decoded text part of jump_forward_str
# so that the UTF-8 will not corrupt
jump_forward_str = new_text + jump_forward_str
if not req.jump_forward_and_retokenize(
jump_forward_str, next_state
):
req.output_ids = cur_output_ids
continue
# insert the old request into tree_cache
if req_pool_indices_cpu is None:
req_pool_indices_cpu = self.req_pool_indices.tolist()
self.tree_cache.cache_req(
token_ids=tuple(req.input_ids + req.output_ids)[:-1],
token_ids=cur_all_ids,
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req_pool_indices_cpu[i],
)
......@@ -447,9 +519,6 @@ class Batch:
# unlock the last node
self.tree_cache.dec_lock_ref(req.last_node)
# jump-forward
req.jump_forward_and_retokenize(jump_forward_str, next_state)
# re-applying image padding
if req.pixel_values is not None:
(
......@@ -583,7 +652,7 @@ class Batch:
if req.regex_fsm is not None:
allowed_mask.zero_()
allowed_mask[
req.regex_fsm.allowed_token_ids(req.regex_fsm_state)
req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
] = 1
logits[i].masked_fill_(~allowed_mask, float("-inf"))
......@@ -602,7 +671,7 @@ class Batch:
batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
for i, req in enumerate(self.reqs):
if req.regex_fsm is not None:
req.regex_fsm_state = req.regex_fsm.next_state(
req.regex_fsm_state = req.regex_fsm.get_next_state(
req.regex_fsm_state, batch_next_token_ids_cpu[i]
)
......
......@@ -21,7 +21,13 @@ from sglang.srt.managers.io_struct import (
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.controller.infer_batch import BaseFinishReason, Batch, FINISH_ABORT, ForwardMode, Req
from sglang.srt.managers.controller.infer_batch import (
BaseFinishReason,
Batch,
FINISH_ABORT,
ForwardMode,
Req,
)
from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
......@@ -98,8 +104,11 @@ class ModelTpServer:
else server_args.max_prefill_tokens
),
)
self.max_running_requests = (self.max_total_num_tokens // 2
if server_args.max_running_requests is None else server_args.max_running_requests)
self.max_running_requests = (
self.max_total_num_tokens // 2
if server_args.max_running_requests is None
else server_args.max_running_requests
)
self.int_token_logit_bias = torch.tensor(
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
)
......@@ -314,10 +323,7 @@ class ModelTpServer:
# Compute matched prefix length
for req in self.forward_queue:
assert (
len(req.output_ids) == 0
), "The output ids should be empty when prefilling"
req.input_ids = req.origin_input_ids + req.prev_output_ids
req.input_ids = req.origin_input_ids + req.output_ids
prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
if req.return_logprob:
prefix_indices = prefix_indices[: req.logprob_start_len]
......@@ -464,7 +470,7 @@ class ModelTpServer:
pt = 0
for i, req in enumerate(batch.reqs):
req.completion_tokens_wo_jump_forward += 1
req.output_ids = [next_token_ids[i]]
req.output_ids.append(next_token_ids[i])
req.check_finished()
if req.return_logprob:
......@@ -524,7 +530,7 @@ class ModelTpServer:
req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
for i, req in enumerate(batch.reqs):
new_prefix_indices, new_last_node = self.tree_cache.cache_req(
token_ids=tuple(req.input_ids + req.output_ids)[:-1],
token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req_pool_indices_cpu[i],
del_in_memory_pool=False,
......@@ -596,8 +602,9 @@ class ModelTpServer:
def handle_finished_requests(self, batch: Batch):
output_rids = []
prev_output_strs = []
output_tokens = []
decoded_texts = []
surr_output_ids = []
read_output_ids = []
output_skip_special_tokens = []
output_spaces_between_special_tokens = []
output_meta_info = []
......@@ -620,8 +627,10 @@ class ModelTpServer:
)
):
output_rids.append(req.rid)
prev_output_strs.append(req.prev_output_str)
output_tokens.append(req.output_ids)
decoded_texts.append(req.decoded_text)
surr_ids, read_ids, _ = req.init_detokenize_incrementally()
surr_output_ids.append(surr_ids)
read_output_ids.append(read_ids)
output_skip_special_tokens.append(
req.sampling_params.skip_special_tokens
)
......@@ -631,7 +640,7 @@ class ModelTpServer:
meta_info = {
"prompt_tokens": len(req.origin_input_ids),
"completion_tokens": len(req.prev_output_ids) + len(req.output_ids),
"completion_tokens": len(req.output_ids),
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
"finish_reason": str(req.finished_reason),
}
......@@ -657,8 +666,9 @@ class ModelTpServer:
self.out_pyobjs.append(
BatchTokenIDOut(
output_rids,
prev_output_strs,
output_tokens,
decoded_texts,
surr_output_ids,
read_output_ids,
output_skip_special_tokens,
output_spaces_between_special_tokens,
output_meta_info,
......@@ -673,7 +683,7 @@ class ModelTpServer:
for i in finished_indices:
req = batch.reqs[i]
self.tree_cache.cache_req(
token_ids=tuple(req.input_ids + req.output_ids)[:-1],
token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req_pool_indices_cpu[i],
)
......@@ -790,4 +800,4 @@ class ModelTpClient:
return _func
self.step = async_wrap("step")
\ No newline at end of file
self.step = async_wrap("step")
......@@ -39,30 +39,24 @@ class DetokenizerManager:
recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
assert isinstance(recv_obj, BatchTokenIDOut)
output_tokens = recv_obj.output_tokens
# TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
output_strs = self.tokenizer.batch_decode(
output_tokens,
surr_texts = self.tokenizer.batch_decode(
recv_obj.surr_output_ids,
skip_special_tokens=recv_obj.skip_special_tokens[0],
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
)
read_texts = self.tokenizer.batch_decode(
recv_obj.read_output_ids,
skip_special_tokens=recv_obj.skip_special_tokens[0],
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
0
],
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
)
# Trim stop str
# TODO(lmzheng): handle the case where multiple stop strs are hit
for i in range(len(output_strs)):
if len(output_tokens[i]) > 0:
first_token = self.tokenizer.convert_ids_to_tokens(
int(output_tokens[i][0])
)
if not isinstance(first_token, str):
first_token = first_token.decode("utf-8", errors="ignore")
if first_token.startswith("▁"):
output_strs[i] = " " + output_strs[i]
output_strs[i] = recv_obj.prev_output_strs[i] + output_strs[i]
output_strs = []
for i in range(len(recv_obj.rids)):
new_text = read_texts[i][len(surr_texts[i]) :]
output_strs.append(recv_obj.decoded_texts[i] + new_text)
if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
......
......@@ -111,13 +111,15 @@ class TokenizedGenerateReqInput:
@dataclass
class BatchTokenIDOut:
rids: List[str]
prev_output_strs: List[str]
output_tokens: List[List[int]]
decoded_texts: List[str]
surr_output_ids: List[List[int]]
read_output_ids: List[List[int]]
skip_special_tokens: List[bool]
spaces_between_special_tokens: List[bool]
meta_info: List[Dict]
finished_reason: List[BaseFinishReason]
@dataclass
class BatchStrOut:
rids: List[str]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment