Unverified Commit 26f0bedc authored by Liangsheng Yin, committed by GitHub

jump-forward rename (#144)

parent 82fa69b3
@@ -219,7 +219,7 @@ def main(args):
     with open(args.result_file, "a") as fout:
         value = {
-            "task": "json_fast_forward",
+            "task": "json_jump_forward",
             "backend": args.backend,
             "latency": round(latency, 3),
             "num_jsons": args.num_jsons,
...
@@ -122,7 +122,7 @@ def main(args):
     with open(args.result_file, "a") as fout:
         value = {
-            "task": "json_fast_forward",
+            "task": "json_jump_forward",
             "backend": args.backend,
             "latency": round(latency, 3),
             "num_jsons": args.num_jsons,
...
@@ -6,10 +6,10 @@ from sglang.srt.constrained.regex import FSMInfo, make_deterministic_fsm
 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"


-class FastForwardMap:
+class JumpForwardMap:
     def __init__(self, regex_string):
         @disk_cache()
-        def _init_state_to_fast_forward(regex_string):
+        def _init_state_to_jump_forward(regex_string):
             regex_pattern = interegular.parse_pattern(regex_string)
             regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
@@ -22,54 +22,54 @@ class FastForwardMap:
             transitions = fsm_info.transitions

             dirty_states = set()
-            state_to_fast_forward = {}
+            state_to_jump_forward = {}
             for (state, id_), next_state in transitions.items():
                 if state in dirty_states:
                     continue
-                if state in state_to_fast_forward:
+                if state in state_to_jump_forward:
                     dirty_states.add(state)
-                    del state_to_fast_forward[state]
+                    del state_to_jump_forward[state]
                     continue
                 if len(id_to_symbol[id_]) > 1:
                     dirty_states.add(state)
                     continue
-                state_to_fast_forward[state] = (id_to_symbol[id_][0], next_state)
+                state_to_jump_forward[state] = (id_to_symbol[id_][0], next_state)

-            return state_to_fast_forward
+            return state_to_jump_forward

-        self.state_to_fast_forward = _init_state_to_fast_forward(regex_string)
+        self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)

     def valid_states(self):
-        return self.state_to_fast_forward.keys()
+        return self.state_to_jump_forward.keys()

-    def fast_forward(self, state):
-        if state not in self.state_to_fast_forward:
+    def jump_forward(self, state):
+        if state not in self.state_to_jump_forward:
             return None

-        fast_forward_str = ""
+        jump_forward_str = ""
         next_state = None
-        while state in self.state_to_fast_forward:
-            symbol, next_state = self.state_to_fast_forward[state]
-            fast_forward_str += symbol
+        while state in self.state_to_jump_forward:
+            symbol, next_state = self.state_to_jump_forward[state]
+            jump_forward_str += symbol
             state = next_state
-        return fast_forward_str, next_state
+        return jump_forward_str, next_state


-class FastForwardCache(BaseCache):
+class JumpForwardCache(BaseCache):
     def __init__(self):
         super().__init__()

     def init_value(self, regex):
-        return FastForwardMap(regex)
+        return JumpForwardMap(regex)


 def test_main():
     regex_string = r"The google's DNS sever address is " + IP_REGEX
-    fast_forward_map = FastForwardMap(regex_string)
-    for state in fast_forward_map.valid_states():
-        print(state, f'"{fast_forward_map.fast_forward(state)}"')
+    jump_forward_map = JumpForwardMap(regex_string)
+    for state in jump_forward_map.valid_states():
+        print(state, f'"{jump_forward_map.jump_forward(state)}"')


 if __name__ == "__main__":
...
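For reference, here is a minimal sketch of exercising the renamed JumpForwardMap above on its own, mirroring test_main; it assumes sglang is installed at this revision, and the regex is the one from the diff:

from sglang.srt.constrained.jump_forward import JumpForwardMap

IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

# Build the map once per regex. For every FSM state whose only outgoing
# transition consumes a single fixed character, the map stores that character
# and the destination state.
jump_forward_map = JumpForwardMap(r"The google's DNS sever address is " + IP_REGEX)

for state in jump_forward_map.valid_states():
    # jump_forward() walks the chain of such states and returns the whole
    # deterministic string plus the state reached after appending it.
    jump_str, next_state = jump_forward_map.jump_forward(state)
    print(state, repr(jump_str), "->", next_state)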
@@ -61,7 +61,7 @@ class DetokenizerManager:
                         output_strs[i] = " " + output_strs[i]

                 output_strs[i] = (
-                    recv_obj.output_and_fast_forward_strs[i] + output_strs[i]
+                    recv_obj.output_and_jump_forward_strs[i] + output_strs[i]
                 )

             self.send_to_tokenizer.send_pyobj(
...
@@ -81,7 +81,7 @@ class TokenizedGenerateReqInput:
 class BatchTokenIDOut:
     rids: List[str]
     output_tokens: List[List[int]]
-    output_and_fast_forward_strs: List[str]
+    output_and_jump_forward_strs: List[str]
     hit_stop_str: List[Optional[str]]
     skip_special_tokens: List[bool]
     meta_info: List[Dict]
...
@@ -53,13 +53,13 @@ class Req:
         # For constrained decoding
         self.regex_fsm = None
         self.regex_fsm_state = 0
-        self.fast_forward_map = None
-        self.output_and_fast_forward_str = ""
+        self.jump_forward_map = None
+        self.output_and_jump_forward_str = ""

     def max_new_tokens(self):
         return self.sampling_params.max_new_tokens

-    def fast_forward_and_retokenize(self, fast_forward_str, next_state):
+    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
         old_output_str = self.tokenizer.decode(self.output_ids)
         # FIXME: This logic does not really solve the problem of determining whether
         # there should be a leading space.
@@ -71,35 +71,35 @@ class Req:
             old_output_str = " " + old_output_str
         new_input_string = (
             self.input_text
-            + self.output_and_fast_forward_str
+            + self.output_and_jump_forward_str
             + old_output_str
-            + fast_forward_str
+            + jump_forward_str
         )

         new_input_ids = self.tokenizer.encode(new_input_string)
         if self.pixel_values is not None:
             # NOTE: This is a hack because the old input_ids contains the image padding
-            fast_forward_tokens_len = len(self.tokenizer.encode(fast_forward_str))
+            jump_forward_tokens_len = len(self.tokenizer.encode(jump_forward_str))
         else:
-            fast_forward_tokens_len = (
+            jump_forward_tokens_len = (
                 len(new_input_ids) - len(self.input_ids) - len(self.output_ids)
             )

         # print("=" * 100)
-        # print(f"Catch fast forward:\n{fast_forward_str}")
+        # print(f"Catch jump forward:\n{jump_forward_str}")
         # print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
         # print(self.tokenizer.convert_ids_to_tokens(new_input_ids))

         self.input_ids = new_input_ids
         self.output_ids = []
         self.sampling_params.max_new_tokens = max(
-            self.sampling_params.max_new_tokens - fast_forward_tokens_len, 0
+            self.sampling_params.max_new_tokens - jump_forward_tokens_len, 0
         )
         self.regex_fsm_state = next_state
-        self.output_and_fast_forward_str = (
-            self.output_and_fast_forward_str + old_output_str + fast_forward_str
+        self.output_and_jump_forward_str = (
+            self.output_and_jump_forward_str + old_output_str + jump_forward_str
         )

-        # print(f"Output and fast forward str:\n{self.output_and_fast_forward_str}")
+        # print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
         # print("*" * 100)

     def check_finished(self):
@@ -327,18 +327,18 @@ class Batch:
         return retracted_reqs

-    def check_for_fast_forward(self):
-        fast_forward_reqs = []
+    def check_for_jump_forward(self):
+        jump_forward_reqs = []
         filter_indices = [i for i in range(len(self.reqs))]

         req_pool_indices_cpu = None

         for i, req in enumerate(self.reqs):
-            if req.fast_forward_map is not None:
-                res = req.fast_forward_map.fast_forward(req.regex_fsm_state)
+            if req.jump_forward_map is not None:
+                res = req.jump_forward_map.jump_forward(req.regex_fsm_state)
                 if res is not None:
-                    fast_forward_str, next_state = res
-                    if len(fast_forward_str) <= 1:
+                    jump_forward_str, next_state = res
+                    if len(jump_forward_str) <= 1:
                         continue

                     # insert the old request into tree_cache
@@ -356,16 +356,16 @@ class Batch:
                     self.req_to_token_pool.free(req_pool_idx)
                     self.tree_cache.dec_ref_counter(req.last_node)

-                    # fast forward
-                    req.fast_forward_and_retokenize(fast_forward_str, next_state)
+                    # jump-forward
+                    req.jump_forward_and_retokenize(jump_forward_str, next_state)

-                    fast_forward_reqs.append(req)
+                    jump_forward_reqs.append(req)
                     filter_indices.remove(i)

         if len(filter_indices) < len(self.reqs):
             self.filter_batch(filter_indices)

-        return fast_forward_reqs
+        return jump_forward_reqs

     def prepare_for_decode(self, input_ids=None):
         if input_ids is None:
...
@@ -11,7 +11,7 @@ import rpyc
 import torch
 from rpyc.utils.classic import obtain
 from rpyc.utils.server import ThreadedServer
-from sglang.srt.constrained.fast_forward import FastForwardCache
+from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import (
@@ -49,7 +49,7 @@ class ModelRpcServer(rpyc.Service):
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
         self.schedule_heuristic = server_args.schedule_heuristic
-        self.no_regex_fast_forward = server_args.no_regex_fast_forward
+        self.no_regex_jump_forward = server_args.no_regex_jump_forward

         # Init model and tokenizer
         self.model_config = ModelConfig(
@@ -127,7 +127,7 @@ class ModelRpcServer(rpyc.Service):
                 "trust_remote_code": server_args.trust_remote_code,
             },
         )
-        self.fast_forward_cache = FastForwardCache()
+        self.jump_forward_cache = JumpForwardCache()

         # Init new token estimation
         self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
@@ -254,8 +254,8 @@ class ModelRpcServer(rpyc.Service):
         # Init regex fsm
         if req.sampling_params.regex is not None:
             req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
-            if not self.no_regex_fast_forward:
-                req.fast_forward_map = self.fast_forward_cache.query(
+            if not self.no_regex_jump_forward:
+                req.jump_forward_map = self.jump_forward_cache.query(
                     req.sampling_params.regex
                 )
@@ -369,8 +369,8 @@ class ModelRpcServer(rpyc.Service):
             logger.debug(
                 f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
                 f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
-                f"ff_cache_hit_rate: {100.0 * self.fast_forward_cache.get_cache_hit_rate():.2f}%. "
-                f"ff_cache_avg_init_time: {self.fast_forward_cache.get_avg_init_time():.2f}s. "
+                f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
+                f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
             )

         new_batch = Batch.init_new(
@@ -437,12 +437,12 @@ class ModelRpcServer(rpyc.Service):
                 self.min_new_token_ratio,
             )

-        if not self.no_regex_fast_forward:
-            # check for fast forward
-            fast_forward_reqs = batch.check_for_fast_forward()
+        if not self.no_regex_jump_forward:
+            # check for jump-forward
+            jump_forward_reqs = batch.check_for_jump_forward()

-            # check for image fast forward
-            for req in fast_forward_reqs:
+            # check for image jump-forward
+            for req in jump_forward_reqs:
                 if req.pixel_values is not None:
                     (
                         req.input_ids,
@@ -454,7 +454,7 @@ class ModelRpcServer(rpyc.Service):
                         req.image_size,
                     )

-            self.forward_queue.extend(fast_forward_reqs)
+            self.forward_queue.extend(jump_forward_reqs)

         if batch.is_empty():
             return
@@ -478,7 +478,7 @@ class ModelRpcServer(rpyc.Service):
     def handle_finished_requests(self, batch: Batch):
         output_rids = []
         output_tokens = []
-        output_and_fast_forward_strs = []
+        output_and_jump_forward_strs = []
         output_hit_stop_str = []
         output_skip_special_tokens = []
         output_meta_info = []
@@ -502,7 +502,7 @@ class ModelRpcServer(rpyc.Service):
             ):
                 output_rids.append(req.rid)
                 output_tokens.append(req.output_ids)
-                output_and_fast_forward_strs.append(req.output_and_fast_forward_str)
+                output_and_jump_forward_strs.append(req.output_and_jump_forward_str)
                 output_hit_stop_str.append(req.hit_stop_str)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
@@ -523,7 +523,7 @@ class ModelRpcServer(rpyc.Service):
                 BatchTokenIDOut(
                     output_rids,
                     output_tokens,
-                    output_and_fast_forward_strs,
+                    output_and_jump_forward_strs,
                     output_hit_stop_str,
                     output_skip_special_tokens,
                     output_meta_info,
...
@@ -25,7 +25,7 @@ class ServerArgs:
     disable_log_stats: bool = False
     log_stats_interval: int = 10
     log_level: str = "info"
-    no_regex_fast_forward: bool = False
+    no_regex_jump_forward: bool = False

     def __post_init__(self):
         if self.tokenizer_path is None:
@@ -172,9 +172,9 @@ class ServerArgs:
             help="Log stats interval in second.",
         )
         parser.add_argument(
-            "--no-regex-fast-forward",
+            "--no-regex-jump-forward",
             action="store_true",
-            help="Disable regex fast forward",
+            help="Disable regex jump-forward",
         )

     @classmethod
...
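As a usage note, the renamed option is exposed both as the ServerArgs field no_regex_jump_forward and as the CLI flag --no-regex-jump-forward (replacing --no-regex-fast-forward). Below is a minimal sketch of setting it programmatically, assuming the field layout in this diff; the model path is illustrative and not part of the change:

from sglang.srt.server_args import ServerArgs

# Disable the regex jump-forward optimization for this server instance.
server_args = ServerArgs(
    model_path="meta-llama/Llama-2-7b-chat-hf",  # illustrative
    no_regex_jump_forward=True,
)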
@@ -12,7 +12,7 @@ import sglang as sgl
 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

-ip_fast_forward = (
+ip_jump_forward = (
     r"The google's DNS sever address is "
     + IP_REGEX
     + r" and "
@@ -32,11 +32,11 @@ def regex_gen(s):
         "answer",
         max_tokens=128,
         temperature=0,
-        regex=ip_fast_forward,
+        regex=ip_jump_forward,
     )
     # fmt: on

-json_fast_forward = (
+json_jump_forward = (
     r"""The information about Hogwarts is in the following JSON format\.\n"""
     + r"""\n\{\n"""
     + r""" "name": "[\w\d\s]*",\n"""
@@ -54,7 +54,7 @@ def json_gen(s):
         "json",
         max_tokens=128,
         temperature=0,
-        regex=json_fast_forward,
+        regex=json_jump_forward,
     )
     # fmt: on
...
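For context, a constrained program such as regex_gen above is run against a live sglang server. The sketch below uses the standard sglang frontend calls (set_default_backend, RuntimeEndpoint, .run()); these calls, the decorator assumption, and the endpoint URL are illustrative rather than part of this diff:

import sglang as sgl

# Point the frontend at a running sglang server (URL is illustrative).
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

# Assuming regex_gen is decorated with @sgl.function, the "answer" field is
# constrained by ip_jump_forward, so its fixed literal prefix can be filled in
# by the jump-forward path instead of being decoded token by token.
state = regex_gen.run()
print(state["answer"])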