"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "f0c6d9784b6b5ec01e3c3a3795d22680567429aa"
Unverified commit 9c902b19, authored by Liangsheng Yin, committed by GitHub

Decode Incrementally (#517)

parent 111991fe
import sglang as sgl

# Regex for the character card. The keys are Chinese: 姓名 (name), 学院 (house),
# 血型 (blood status), 职业 (occupation), 魔杖 (wand: 材质 wood, 杖芯 core, 长度 length),
# 存活 (alive/dead), 守护神 (patronus), 博格特 (boggart).
character_regex = (
    r"""\{\n"""
    + r""" "姓名": "[^"]{1,32}",\n"""
    + r""" "学院": "(格兰芬多|赫奇帕奇|拉文克劳|斯莱特林)",\n"""
    + r""" "血型": "(纯血|混血|麻瓜)",\n"""
    + r""" "职业": "(学生|教师|傲罗|魔法部|食死徒|凤凰社成员)",\n"""
    + r""" "魔杖": \{\n"""
    + r""" "材质": "[^"]{1,32}",\n"""
    + r""" "杖芯": "[^"]{1,32}",\n"""
    + r""" "长度": [0-9]{1,2}\.[0-9]{0,2}\n"""
    + r""" \},\n"""
    + r""" "存活": "(存活|死亡)",\n"""
    + r""" "守护神": "[^"]{1,32}",\n"""
    + r""" "博格特": "[^"]{1,32}"\n"""
    + r"""\}"""
)


@sgl.function
def character_gen(s, name):
    # "<name> is a character in the Harry Potter novels. Please fill in the
    # following information about this character."
    s += name + " 是一名哈利波特系列小说中的角色。请填写以下关于这个角色的信息。"
    # "Here is an example", followed by a filled-in card for Harry Potter.
    s += """\
这是一个例子
{
"姓名": "哈利波特",
"学院": "格兰芬多",
"血型": "混血",
"职业": "学生",
"魔杖": {
"材质": "冬青木",
"杖芯": "凤凰尾羽",
"长度": 11.0
},
"存活": "存活",
"守护神": "麋鹿",
"博格特": "摄魂怪"
}
"""
    # "Now please fill in the information for <name>:"
    s += f"现在请你填写{name}的信息:\n"
    s += sgl.gen("json_output", max_tokens=256, regex=character_regex)


def main():
    backend = sgl.RuntimeEndpoint("http://localhost:30000")
    sgl.set_default_backend(backend)
    ret = character_gen.run(name="赫敏格兰杰", temperature=0)
    print(ret.text())


if __name__ == "__main__":
    main()
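A quick sanity check (illustrative, not part of the commit, run after the definitions above): sgl.gen is constrained by character_regex, so the text returned for "json_output" should match the pattern, provided generation is not cut short by max_tokens. The snippet below reuses the values from the in-prompt example, formatted with the single leading space before each key that the pattern requires.

import re

sample = (
    '{\n'
    ' "姓名": "哈利波特",\n'
    ' "学院": "格兰芬多",\n'
    ' "血型": "混血",\n'
    ' "职业": "学生",\n'
    ' "魔杖": {\n'
    ' "材质": "冬青木",\n'
    ' "杖芯": "凤凰尾羽",\n'
    ' "长度": 11.0\n'
    ' },\n'
    ' "存活": "存活",\n'
    ' "守护神": "麋鹿",\n'
    ' "博格特": "摄魂怪"\n'
    '}'
)
# fullmatch succeeds: the sample is exactly the shape the regex enforces.
print(re.fullmatch(character_regex, sample) is not None)  # True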
@@ -3,8 +3,8 @@ from typing import Dict, Optional, Union
 from outlines.caching import cache as disk_cache
 from outlines.caching import disable_cache
-from outlines.fsm.fsm import RegexFSM
-from outlines.fsm.regex import FSMInfo, make_deterministic_fsm
+from outlines.fsm.guide import RegexGuide
+from outlines.fsm.regex import FSMInfo, make_deterministic_fsm, make_byte_level_fsm
 from outlines.models.transformers import TransformerTokenizer
 from pydantic import BaseModel
@@ -28,11 +28,12 @@ except ImportError:
 __all__ = [
-    "RegexFSM",
+    "RegexGuide",
     "FSMInfo",
     "make_deterministic_fsm",
     "build_regex_from_object",
     "TransformerTokenizer",
     "disk_cache",
     "disable_cache",
+    "make_byte_level_fsm",
 ]

 """Cache for the compressed finite state machine."""
-from sglang.srt.constrained import RegexFSM, TransformerTokenizer
+from sglang.srt.constrained import RegexGuide, TransformerTokenizer
 from sglang.srt.constrained.base_cache import BaseCache
@@ -26,4 +26,4 @@ class FSMCache(BaseCache):
         )

     def init_value(self, regex):
-        return RegexFSM(regex, self.outlines_tokenizer)
+        return RegexGuide(regex, self.outlines_tokenizer)
@@ -2,20 +2,41 @@
 Faster constrained decoding.
 Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
 """

-import interegular
-from sglang.srt.constrained import FSMInfo, disk_cache, make_deterministic_fsm
+import dataclasses
+from collections import defaultdict
+
+import interegular
+import outlines.caching
+
+from sglang.srt.constrained import (
+    FSMInfo,
+    disk_cache,
+    make_deterministic_fsm,
+    make_byte_level_fsm,
+)
 from sglang.srt.constrained.base_cache import BaseCache

 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"


+@dataclasses.dataclass
+class JumpEdge:
+    symbol: str = None
+    symbol_next_state: int = None
+    byte: int = None
+    byte_next_state: int = None
+
+
 class JumpForwardMap:
     def __init__(self, regex_string):
         @disk_cache()
         def _init_state_to_jump_forward(regex_string):
             regex_pattern = interegular.parse_pattern(regex_string)
-            regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
+            byte_fsm = make_byte_level_fsm(
+                regex_pattern.to_fsm().reduce(), keep_utf8=True
+            )
+            regex_fsm, _ = make_deterministic_fsm(byte_fsm)
             fsm_info: FSMInfo = regex_fsm.fsm_info
@@ -25,40 +46,91 @@ class JumpForwardMap:
                 id_to_symbol.setdefault(id_, []).append(symbol)

             transitions = fsm_info.transitions

-            dirty_states = set()
+            outgoings_ct = defaultdict(int)
             state_to_jump_forward = {}

             for (state, id_), next_state in transitions.items():
-                if state in dirty_states:
-                    continue
-                if state in state_to_jump_forward:
-                    dirty_states.add(state)
-                    del state_to_jump_forward[state]
-                    continue
-                if len(id_to_symbol[id_]) > 1:
-                    dirty_states.add(state)
-                    continue
-
-                state_to_jump_forward[state] = (id_to_symbol[id_][0], next_state)
+                if id_ == fsm_info.alphabet_anything_value:
+                    continue
+                symbols = id_to_symbol[id_]
+                for c in symbols:
+                    if len(c) > 1:
+                        # Skip byte level transitions
+                        continue
+
+                    outgoings_ct[state] += 1
+                    if outgoings_ct[state] > 1:
+                        if state in state_to_jump_forward:
+                            del state_to_jump_forward[state]
+                        break
+
+                    state_to_jump_forward[state] = JumpEdge(
+                        symbol=c,
+                        symbol_next_state=next_state,
+                    )
+
+            # Process the byte level jump forward
+            outgoings_ct = defaultdict(int)
+            for (state, id_), next_state in transitions.items():
+                if id_ == fsm_info.alphabet_anything_value:
+                    continue
+                symbols = id_to_symbol[id_]
+                for c in symbols:
+                    byte_ = None
+                    if len(c) == 1 and ord(c) < 0x80:
+                        # ASCII character
+                        byte_ = ord(c)
+                    elif len(c) == 2:
+                        byte_ = int(symbols[0], 16)
+
+                    if byte_ is not None:
+                        outgoings_ct[state] += 1
+                        if outgoings_ct[state] > 1:
+                            if state in state_to_jump_forward:
+                                del state_to_jump_forward[state]
+                            break
+                        e = state_to_jump_forward.get(state, JumpEdge())
+                        e.byte = byte_
+                        e.byte_next_state = next_state
+                        state_to_jump_forward[state] = e

             return state_to_jump_forward

         self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)

-    def valid_states(self):
-        return self.state_to_jump_forward.keys()
+    def jump_forward_symbol(self, state):
+        jump_forward_str = ""
+        next_state = state
+        while state in self.state_to_jump_forward:
+            e = self.state_to_jump_forward[state]
+            if e.symbol is None:
+                break
+            jump_forward_str += e.symbol
+            next_state = e.symbol_next_state
+            state = next_state
+
+        return jump_forward_str, next_state

-    def jump_forward(self, state):
+    def jump_forward_byte(self, state):
         if state not in self.state_to_jump_forward:
             return None

-        jump_forward_str = ""
+        jump_forward_bytes = []
         next_state = None
         while state in self.state_to_jump_forward:
-            symbol, next_state = self.state_to_jump_forward[state]
-            jump_forward_str += symbol
+            e = self.state_to_jump_forward[state]
+            assert e.byte is not None and e.byte_next_state is not None
+            jump_forward_bytes.append((e.byte, e.byte_next_state))
+            next_state = e.byte_next_state
             state = next_state
-        return jump_forward_str, next_state
+
+        return jump_forward_bytes
+
+    def is_jump_forward_symbol_state(self, state):
+        return (
+            state in self.state_to_jump_forward
+            and self.state_to_jump_forward[state].symbol is not None
+        )


 class JumpForwardCache(BaseCache):
@@ -69,12 +141,21 @@ class JumpForwardCache(BaseCache):
         return JumpForwardMap(regex)


-def test_main():
-    regex_string = r"The google's DNS sever address is " + IP_REGEX
+def test_main(regex_string):
     jump_forward_map = JumpForwardMap(regex_string)
-    for state in jump_forward_map.valid_states():
-        print(state, f'"{jump_forward_map.jump_forward(state)}"')
+    for state, e in jump_forward_map.state_to_jump_forward.items():
+        if e.symbol is not None:
+            jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)
+            print(f"{state} -> {next_state}", jump_forward_str)
+        bytes_ = jump_forward_map.jump_forward_byte(state)
+        print(f"{state} -> {bytes_[-1][1]}", [hex(b) for b, _ in bytes_])


 if __name__ == "__main__":
-    test_main()
+    import outlines
+
+    outlines.caching.clear_cache()
+    test_main(r"The google's DNS sever address is " + IP_REGEX)
+    test_main(r"霍格沃茨特快列车|霍比特人比尔博")
+    # 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ...
+    # 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ...
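The second test case is the interesting one: as the byte annotations above show, 霍格沃茨特快列车 and 霍比特人比尔博 share all three bytes of 霍 plus the lead byte \xe6 of the following character, so a deterministic byte-level jump can stop in the middle of a multi-byte UTF-8 sequence. A small standalone illustration (plain Python, not part of the commit):

a = "霍格沃茨特快列车".encode("utf-8")
b = "霍比特人比尔博".encode("utf-8")

# Length of the common byte prefix shared by the two alternatives.
common = 0
while common < min(len(a), len(b)) and a[common] == b[common]:
    common += 1

print(common, a[:common])                            # 4 b'\xe9\x9c\x8d\xe6'
print(a[:common].decode("utf-8", errors="replace"))  # '霍�' -- ends on a dangling lead byte

Jumps that start or end inside a character are why JumpForwardMap keeps byte-level edges alongside the symbol edges, and why the batching code below consumes leading continuation bytes (0x80-0xBF) as explicit <0x..> byte tokens before falling back to the string-level jump.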
@@ -3,12 +3,17 @@
 from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import List
+import warnings

 import numpy as np
 import torch

 from sglang.srt.managers.controller.radix_cache import RadixCache
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.constrained.jump_forward import JumpForwardMap
+from sglang.srt.constrained import RegexGuide
+
+INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5


 class ForwardMode(IntEnum):
@@ -64,12 +69,15 @@ class Req:
     def __init__(self, rid, origin_input_text, origin_input_ids):
         self.rid = rid
         self.origin_input_text = origin_input_text
+        self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
         self.origin_input_ids = origin_input_ids
-        self.origin_input_ids_unpadded = origin_input_ids  # before image padding
-        self.prev_output_str = ""
-        self.prev_output_ids = []
-        self.output_ids = []  # For incremental decode
-        self.input_ids = None  # input_ids = origin_input_ids + prev_output_ids
+        self.output_ids = []  # Each decode stage's output ids
+        self.input_ids = None  # input_ids = origin_input_ids + output_ids
+
+        self.decoded_text = ""
+        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
+        self.read_offset = None

         # The number of decoded tokens for token usage report. Note that
         # this does not include the jump forward tokens.
@@ -109,20 +117,54 @@ class Req:
         self.last_update_decode_tokens = 0

         # Constrained decoding
-        self.regex_fsm = None
-        self.regex_fsm_state = 0
-        self.jump_forward_map = None
+        self.regex_fsm: RegexGuide = None
+        self.regex_fsm_state: int = 0
+        self.jump_forward_map: JumpForwardMap = None

     # whether request reached finished condition
     def finished(self) -> bool:
         return self.finished_reason is not None

-    def partial_decode(self, ids):
-        first_token = self.tokenizer.convert_ids_to_tokens(ids[0])
-        first_token = (
-            first_token.decode() if isinstance(first_token, bytes) else first_token
-        )
-        return (" " if first_token.startswith("▁") else "") + self.tokenizer.decode(ids)
+    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
+    def init_detokenize_incrementally(self):
+        first_iter = self.surr_offset is None or self.read_offset is None
+
+        if first_iter:
+            self.read_offset = len(self.origin_input_ids_unpadded)
+            self.surr_offset = max(
+                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
+            )
+
+        all_ids = self.origin_input_ids_unpadded + self.output_ids
+        surr_ids = all_ids[self.surr_offset : self.read_offset]
+        read_ids = all_ids[self.surr_offset :]
+
+        return surr_ids, read_ids, len(all_ids)
+
+    def detokenize_incrementally(self, inplace: bool = True):
+        surr_ids, read_ids, num_all_tokens = self.init_detokenize_incrementally()
+
+        surr_text = self.tokenizer.decode(
+            surr_ids,
+            skip_special_tokens=self.sampling_params.skip_special_tokens,
+            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
+        )
+        new_text = self.tokenizer.decode(
+            read_ids,
+            skip_special_tokens=self.sampling_params.skip_special_tokens,
+            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
+        )
+
+        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
+            new_text = new_text[len(surr_text) :]
+            if inplace:
+                self.decoded_text += new_text
+                self.surr_offset = self.read_offset
+                self.read_offset = num_all_tokens
+
+            return True, new_text
+
+        return False, ""

     def max_new_tokens(self):
         return self.sampling_params.max_new_tokens
@@ -131,18 +173,17 @@ class Req:
         if self.finished():
             return

-        if (
-            len(self.prev_output_ids) + len(self.output_ids)
-            >= self.sampling_params.max_new_tokens
-        ):
-            self.finished_reason = FINISH_LENGTH(len(self.prev_output_ids) + len(self.output_ids))
+        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
+            self.finished_reason = FINISH_LENGTH(len(self.output_ids))
             return

         if (
             self.output_ids[-1] == self.tokenizer.eos_token_id
             and not self.sampling_params.ignore_eos
         ):
-            self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.tokenizer.eos_token_id)
+            self.finished_reason = FINISH_MATCHED_TOKEN(
+                matched=self.tokenizer.eos_token_id
+            )
             return

         if len(self.sampling_params.stop_strs) > 0:
@@ -151,61 +192,59 @@ class Req:
             )

             for stop_str in self.sampling_params.stop_strs:
-                # FIXME: (minor) try incremental match in prev_output_str
-                if stop_str in tail_str or stop_str in self.prev_output_str:
+                if stop_str in tail_str or stop_str in self.decoded_text:
                     self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                     return

     def jump_forward_and_retokenize(self, jump_forward_str, next_state):
-        # FIXME: This logic does not really solve the problem of determining whether
-        # there should be a leading space.
-        cur_output_str = self.partial_decode(self.output_ids)
-
+        # TODO(lsyin): apply re-tokenize only for decode tokens so that we do not need origin_input_text anymore
         if self.origin_input_text is None:
             # Recovering text can only use unpadded ids
             self.origin_input_text = self.tokenizer.decode(
                 self.origin_input_ids_unpadded
             )

-        all_text = (
-            self.origin_input_text
-            + self.prev_output_str
-            + cur_output_str
-            + jump_forward_str
-        )
+        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
         all_ids = self.tokenizer.encode(all_text)
         prompt_tokens = len(self.origin_input_ids_unpadded)
-        self.origin_input_ids = all_ids[:prompt_tokens]
-        self.origin_input_ids_unpadded = self.origin_input_ids
-        # NOTE: the output ids may not strictly correspond to the output text
-        old_prev_output_ids = self.prev_output_ids
-        self.prev_output_ids = all_ids[prompt_tokens:]
-        self.prev_output_str = self.prev_output_str + cur_output_str + jump_forward_str
-        self.output_ids = []
+
+        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
+            # TODO(lsyin): fix token fusion
+            warnings.warn(
+                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
+            )
+            return False
+
+        old_output_ids = self.output_ids
+        self.output_ids = all_ids[prompt_tokens:]
+        self.decoded_text = self.decoded_text + jump_forward_str
+        self.surr_offset = prompt_tokens
+        self.read_offset = len(all_ids)
+
+        # NOTE: A trick to reduce the surrounding tokens decoding overhead
+        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
+            surr_text_ = self.tokenizer.decode(
+                all_ids[self.read_offset - i : self.read_offset]
+            )
+            if not surr_text_.endswith("�"):
+                self.surr_offset = self.read_offset - i
+                break

         self.regex_fsm_state = next_state

         if self.return_logprob:
             # For fast-forward part's logprobs
             k = 0
-            for i, old_id in enumerate(old_prev_output_ids):
-                if old_id == self.prev_output_ids[i]:
+            for i, old_id in enumerate(old_output_ids):
+                if old_id == self.output_ids[i]:
                     k = k + 1
                 else:
                     break
             self.decode_token_logprobs = self.decode_token_logprobs[:k]
             self.decode_top_logprobs = self.decode_top_logprobs[:k]
             self.logprob_start_len = prompt_tokens + k
-            self.last_update_decode_tokens = len(self.prev_output_ids) - k
-
-        # print("=" * 100)
-        # print(f"Catch jump forward:\n{jump_forward_str}")
-        # print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
-        # print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
-        # print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
-        # print("*" * 100)
+            self.last_update_decode_tokens = len(self.output_ids) - k
+
+        return True

     def __repr__(self):
         return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
@@ -381,7 +420,10 @@ class Batch:
         sorted_indices = [i for i in range(len(self.reqs))]
         # TODO(lsyin): improve the priority of retraction
         sorted_indices.sort(
-            key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
+            key=lambda i: (
+                len(self.reqs[i].output_ids),
+                -len(self.reqs[i].origin_input_ids),
+            ),
             reverse=True,
         )
@@ -403,14 +445,9 @@ class Batch:
             # release the last node
             self.tree_cache.dec_lock_ref(req.last_node)

-            cur_output_str = req.partial_decode(req.output_ids)
-            req.prev_output_str = req.prev_output_str + cur_output_str
-            req.prev_output_ids.extend(req.output_ids)
-
             req.prefix_indices = None
             req.last_node = None
             req.extend_input_len = 0
-            req.output_ids = []

             # For incremental logprobs
             req.last_update_decode_tokens = 0
@@ -428,18 +465,53 @@ class Batch:
        for i, req in enumerate(self.reqs):
            if req.jump_forward_map is not None:
-                res = req.jump_forward_map.jump_forward(req.regex_fsm_state)
-                if res is not None:
-                    jump_forward_str, next_state = res
-                    if len(jump_forward_str) <= 1:
+                jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
+                    req.regex_fsm_state
+                )
+                if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
+                    suffix_bytes = []
+                    continuation_range = range(0x80, 0xC0)
+                    cur_state = req.regex_fsm_state
+                    while (
+                        len(jump_forward_bytes)
+                        and jump_forward_bytes[0][0] in continuation_range
+                    ):
+                        # continuation bytes
+                        byte_edge = jump_forward_bytes.pop(0)
+                        suffix_bytes.append(byte_edge[0])
+                        cur_state = byte_edge[1]
+
+                    suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
+                    suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)
+
+                    # Current ids, for cache and revert
+                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
+                    cur_output_ids = req.output_ids
+
+                    req.output_ids.extend(suffix_ids)
+                    decode_res, new_text = req.detokenize_incrementally(inplace=False)
+                    if not decode_res:
+                        req.output_ids = cur_output_ids
                        continue

-                    if req_pool_indices_cpu is None:
-                        req_pool_indices_cpu = self.req_pool_indices.tolist()
+                    jump_forward_str, next_state = (
+                        req.jump_forward_map.jump_forward_symbol(cur_state)
+                    )
+
+                    # Make the incrementally decoded text part of jump_forward_str
+                    # so that the UTF-8 will not corrupt
+                    jump_forward_str = new_text + jump_forward_str
+                    if not req.jump_forward_and_retokenize(
+                        jump_forward_str, next_state
+                    ):
+                        req.output_ids = cur_output_ids
+                        continue

                    # insert the old request into tree_cache
+                    if req_pool_indices_cpu is None:
+                        req_pool_indices_cpu = self.req_pool_indices.tolist()
                    self.tree_cache.cache_req(
-                        token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                        token_ids=cur_all_ids,
                        last_uncached_pos=len(req.prefix_indices),
                        req_pool_idx=req_pool_indices_cpu[i],
                    )
@@ -447,9 +519,6 @@ class Batch:
                    # unlock the last node
                    self.tree_cache.dec_lock_ref(req.last_node)

-                    # jump-forward
-                    req.jump_forward_and_retokenize(jump_forward_str, next_state)
-
                    # re-applying image padding
                    if req.pixel_values is not None:
                        (
@@ -583,7 +652,7 @@ class Batch:
            if req.regex_fsm is not None:
                allowed_mask.zero_()
                allowed_mask[
-                    req.regex_fsm.allowed_token_ids(req.regex_fsm_state)
+                    req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
                ] = 1
                logits[i].masked_fill_(~allowed_mask, float("-inf"))
@@ -602,7 +671,7 @@ class Batch:
        batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
        for i, req in enumerate(self.reqs):
            if req.regex_fsm is not None:
-                req.regex_fsm_state = req.regex_fsm.next_state(
+                req.regex_fsm_state = req.regex_fsm.get_next_state(
                    req.regex_fsm_state, batch_next_token_ids_cpu[i]
                )
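For reference, the surrounding-window trick behind init_detokenize_incrementally / detokenize_incrementally above can be reproduced in isolation. A minimal sketch, assuming a Hugging Face tokenizer ("gpt2" and the helper name detokenize_step are illustrative, not names from the diff):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

def detokenize_step(all_ids, surr_offset, read_offset):
    # Decode a small "surrounding" window together with the new tokens so the
    # tokenizer's merge rules (leading spaces, byte fallback) stay consistent.
    surr_text = tokenizer.decode(all_ids[surr_offset:read_offset])
    new_text = tokenizer.decode(all_ids[surr_offset:])
    if len(new_text) > len(surr_text) and not new_text.endswith("�"):
        # Safe to emit: the tail is not an incomplete UTF-8 sequence.
        return new_text[len(surr_text):], read_offset, len(all_ids)
    # Otherwise hold the text back until more tokens arrive.
    return "", surr_offset, read_offset

ids = tokenizer.encode("incremental detokenization, one token at a time")
text, surr, read = "", 0, 0
for n in range(1, len(ids) + 1):  # feed the ids one token at a time
    piece, surr, read = detokenize_step(ids[:n], surr, read)
    text += piece
print(text == tokenizer.decode(ids))  # True: the emitted pieces reassemble the full text

Holding text back whenever the decoded tail ends in "�" is what keeps multi-byte UTF-8 characters from being emitted half-finished.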
@@ -21,7 +21,13 @@ from sglang.srt.managers.io_struct import (
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.controller.infer_batch import BaseFinishReason, Batch, FINISH_ABORT, ForwardMode, Req
+from sglang.srt.managers.controller.infer_batch import (
+    BaseFinishReason,
+    Batch,
+    FINISH_ABORT,
+    ForwardMode,
+    Req,
+)
 from sglang.srt.managers.controller.model_runner import ModelRunner
 from sglang.srt.managers.controller.radix_cache import RadixCache
 from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
@@ -98,8 +104,11 @@ class ModelTpServer:
                 else server_args.max_prefill_tokens
             ),
         )
-        self.max_running_requests = (self.max_total_num_tokens // 2
-            if server_args.max_running_requests is None else server_args.max_running_requests)
+        self.max_running_requests = (
+            self.max_total_num_tokens // 2
+            if server_args.max_running_requests is None
+            else server_args.max_running_requests
+        )
         self.int_token_logit_bias = torch.tensor(
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
         )
@@ -314,10 +323,7 @@ class ModelTpServer:
         # Compute matched prefix length
         for req in self.forward_queue:
-            assert (
-                len(req.output_ids) == 0
-            ), "The output ids should be empty when prefilling"
-            req.input_ids = req.origin_input_ids + req.prev_output_ids
+            req.input_ids = req.origin_input_ids + req.output_ids
             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
             if req.return_logprob:
                 prefix_indices = prefix_indices[: req.logprob_start_len]
@@ -464,7 +470,7 @@ class ModelTpServer:
         pt = 0
         for i, req in enumerate(batch.reqs):
             req.completion_tokens_wo_jump_forward += 1
-            req.output_ids = [next_token_ids[i]]
+            req.output_ids.append(next_token_ids[i])
             req.check_finished()

             if req.return_logprob:
@@ -524,7 +530,7 @@ class ModelTpServer:
         req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
         for i, req in enumerate(batch.reqs):
             new_prefix_indices, new_last_node = self.tree_cache.cache_req(
-                token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                 last_uncached_pos=len(req.prefix_indices),
                 req_pool_idx=req_pool_indices_cpu[i],
                 del_in_memory_pool=False,
@@ -596,8 +602,9 @@ class ModelTpServer:
     def handle_finished_requests(self, batch: Batch):
         output_rids = []
-        prev_output_strs = []
-        output_tokens = []
+        decoded_texts = []
+        surr_output_ids = []
+        read_output_ids = []
         output_skip_special_tokens = []
         output_spaces_between_special_tokens = []
         output_meta_info = []
@@ -620,8 +627,10 @@ class ModelTpServer:
                 )
             ):
                 output_rids.append(req.rid)
-                prev_output_strs.append(req.prev_output_str)
-                output_tokens.append(req.output_ids)
+                decoded_texts.append(req.decoded_text)
+                surr_ids, read_ids, _ = req.init_detokenize_incrementally()
+                surr_output_ids.append(surr_ids)
+                read_output_ids.append(read_ids)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
                 )
@@ -631,7 +640,7 @@ class ModelTpServer:
                 meta_info = {
                     "prompt_tokens": len(req.origin_input_ids),
-                    "completion_tokens": len(req.prev_output_ids) + len(req.output_ids),
+                    "completion_tokens": len(req.output_ids),
                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                     "finish_reason": str(req.finished_reason),
                 }
@@ -657,8 +666,9 @@ class ModelTpServer:
             self.out_pyobjs.append(
                 BatchTokenIDOut(
                     output_rids,
-                    prev_output_strs,
-                    output_tokens,
+                    decoded_texts,
+                    surr_output_ids,
+                    read_output_ids,
                     output_skip_special_tokens,
                     output_spaces_between_special_tokens,
                     output_meta_info,
@@ -673,7 +683,7 @@ class ModelTpServer:
            for i in finished_indices:
                req = batch.reqs[i]
                self.tree_cache.cache_req(
-                    token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                    token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                    last_uncached_pos=len(req.prefix_indices),
                    req_pool_idx=req_pool_indices_cpu[i],
                )
@@ -790,4 +800,4 @@ class ModelTpClient:
            return _func

-        self.step = async_wrap("step")
\ No newline at end of file
+        self.step = async_wrap("step")
@@ -39,30 +39,24 @@ class DetokenizerManager:
            recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
            assert isinstance(recv_obj, BatchTokenIDOut)

-            output_tokens = recv_obj.output_tokens
-
            # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
-            output_strs = self.tokenizer.batch_decode(
-                output_tokens,
+            surr_texts = self.tokenizer.batch_decode(
+                recv_obj.surr_output_ids,
+                skip_special_tokens=recv_obj.skip_special_tokens[0],
+                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
+            )
+            read_texts = self.tokenizer.batch_decode(
+                recv_obj.read_output_ids,
                skip_special_tokens=recv_obj.skip_special_tokens[0],
-                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
-                    0
-                ],
+                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
            )

            # Trim stop str
            # TODO(lmzheng): handle the case where multiple stop strs are hit
-            for i in range(len(output_strs)):
-                if len(output_tokens[i]) > 0:
-                    first_token = self.tokenizer.convert_ids_to_tokens(
-                        int(output_tokens[i][0])
-                    )
-                    if not isinstance(first_token, str):
-                        first_token = first_token.decode("utf-8", errors="ignore")
-                    if first_token.startswith("▁"):
-                        output_strs[i] = " " + output_strs[i]
-
-                output_strs[i] = recv_obj.prev_output_strs[i] + output_strs[i]
+            output_strs = []
+            for i in range(len(recv_obj.rids)):
+                new_text = read_texts[i][len(surr_texts[i]) :]
+                output_strs.append(recv_obj.decoded_texts[i] + new_text)

                if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
                    pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
@@ -111,13 +111,15 @@ class TokenizedGenerateReqInput:
 @dataclass
 class BatchTokenIDOut:
     rids: List[str]
-    prev_output_strs: List[str]
-    output_tokens: List[List[int]]
+    decoded_texts: List[str]
+    surr_output_ids: List[List[int]]
+    read_output_ids: List[List[int]]
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]


 @dataclass
 class BatchStrOut:
     rids: List[str]