Unverified commit 39191c85, authored by Liangsheng Yin and committed by GitHub
Browse files

Cache optimizations (#418)

parent 562b8857
...@@ -25,5 +25,8 @@ class GlobalConfig: ...@@ -25,5 +25,8 @@ class GlobalConfig:
# adjust_cache: Adjust the position embedding of KV cache. # adjust_cache: Adjust the position embedding of KV cache.
self.concate_and_append_mode = "no_adjust" self.concate_and_append_mode = "no_adjust"
# Request dependency time due to network delay
self.request_dependency_time = 0.03
global_config = GlobalConfig() global_config = GlobalConfig()
"""
Backend configurations, may vary with different serving platforms.
"""
from dataclasses import dataclass
@dataclass
class BackendConfig:
    """Backend tuning knobs that may vary between serving platforms."""

    # Delay, in seconds, that RouterManager hands to ``asyncio.sleep`` after a
    # forward step reports at least one finished request. A value <= 0 skips
    # that extra sleep.
    extend_dependency_time: float = 0.03
GLOBAL_BACKEND_CONFIG = BackendConfig()
...@@ -335,20 +335,20 @@ class Batch: ...@@ -335,20 +335,20 @@ class Batch:
req = self.reqs[idx] req = self.reqs[idx]
retracted_reqs.append(req) retracted_reqs.append(req)
self.tree_cache.dec_ref_counter(req.last_node) # TODO: apply more fine-grained retraction
last_uncached_pos = len(req.prefix_indices)
token_indices = self.req_to_token_pool.req_to_token[
req_pool_indices_cpu[idx]
][last_uncached_pos : seq_lens_cpu[idx]]
self.token_to_kv_pool.dec_refs(token_indices)
self.tree_cache.dec_lock_ref(req.last_node)
req.prefix_indices = None req.prefix_indices = None
req.last_node = None req.last_node = None
req.extend_input_len = 0 req.extend_input_len = 0
req.output_ids = [] req.output_ids = []
req.regex_fsm_state = 0 req.regex_fsm_state = 0
# TODO: apply more fine-grained retraction
token_indices = self.req_to_token_pool.req_to_token[
req_pool_indices_cpu[idx]
][: seq_lens_cpu[idx]]
self.token_to_kv_pool.dec_refs(token_indices)
self.filter_batch(sorted_indices) self.filter_batch(sorted_indices)
return retracted_reqs return retracted_reqs
...@@ -367,20 +367,18 @@ class Batch: ...@@ -367,20 +367,18 @@ class Batch:
if len(jump_forward_str) <= 1: if len(jump_forward_str) <= 1:
continue continue
# insert the old request into tree_cache
token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1]
if req_pool_indices_cpu is None: if req_pool_indices_cpu is None:
req_pool_indices_cpu = self.req_pool_indices.tolist() req_pool_indices_cpu = self.req_pool_indices.tolist()
req_pool_idx = req_pool_indices_cpu[i]
indices = self.req_to_token_pool.req_to_token[ # insert the old request into tree_cache
req_pool_idx, : len(token_ids_in_memory) self.tree_cache.cache_req(
] token_ids=tuple(req.input_ids + req.output_ids)[:-1],
prefix_len = self.tree_cache.insert( last_uncached_pos=len(req.prefix_indices),
token_ids_in_memory, indices.clone() req_pool_idx=req_pool_indices_cpu[i],
) )
self.token_to_kv_pool.dec_refs(indices[:prefix_len])
self.req_to_token_pool.free(req_pool_idx) # unlock the last node
self.tree_cache.dec_ref_counter(req.last_node) self.tree_cache.dec_lock_ref(req.last_node)
# jump-forward # jump-forward
req.jump_forward_and_retokenize(jump_forward_str, next_state) req.jump_forward_and_retokenize(jump_forward_str, next_state)
......
...@@ -5,7 +5,7 @@ import uvloop ...@@ -5,7 +5,7 @@ import uvloop
import zmq import zmq
import zmq.asyncio import zmq.asyncio
from sglang.srt.backend_config import GLOBAL_BACKEND_CONFIG from sglang import global_config
from sglang.srt.managers.router.model_rpc import ModelRpcClient from sglang.srt.managers.router.model_rpc import ModelRpcClient
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_exception_traceback from sglang.srt.utils import get_exception_traceback
...@@ -30,7 +30,7 @@ class RouterManager: ...@@ -30,7 +30,7 @@ class RouterManager:
self.recv_reqs = [] self.recv_reqs = []
# Init some configs # Init some configs
self.extend_dependency_time = GLOBAL_BACKEND_CONFIG.extend_dependency_time self.request_dependency_time = global_config.request_dependency_time
async def loop_for_forward(self): async def loop_for_forward(self):
while True: while True:
...@@ -46,9 +46,9 @@ class RouterManager: ...@@ -46,9 +46,9 @@ class RouterManager:
if len(out_pyobjs) != 0: if len(out_pyobjs) != 0:
has_finished = any([obj.finished for obj in out_pyobjs]) has_finished = any([obj.finished for obj in out_pyobjs])
if has_finished: if has_finished:
if self.extend_dependency_time > 0: if self.request_dependency_time > 0:
slept = True slept = True
await asyncio.sleep(self.extend_dependency_time) await asyncio.sleep(self.request_dependency_time)
if not slept: if not slept:
await asyncio.sleep(0.0006) await asyncio.sleep(0.0006)
......
...@@ -117,7 +117,11 @@ class ModelRpcServer: ...@@ -117,7 +117,11 @@ class ModelRpcServer:
logger.info(f"server_args: {server_args.print_mode_args()}") logger.info(f"server_args: {server_args.print_mode_args()}")
# Init cache # Init cache
self.tree_cache = RadixCache(disable=server_args.disable_radix_cache) self.tree_cache = RadixCache(
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
disable=server_args.disable_radix_cache,
)
self.tree_cache_metrics = {"total": 0, "hit": 0} self.tree_cache_metrics = {"total": 0, "hit": 0}
self.scheduler = Scheduler( self.scheduler = Scheduler(
self.schedule_heuristic, self.schedule_heuristic,
...@@ -203,6 +207,8 @@ class ModelRpcServer: ...@@ -203,6 +207,8 @@ class ModelRpcServer:
# Run new fill batch # Run new fill batch
self.forward_fill_batch(new_batch) self.forward_fill_batch(new_batch)
self.cache_filled_batch(new_batch)
if not new_batch.is_empty(): if not new_batch.is_empty():
if self.running_batch is None: if self.running_batch is None:
self.running_batch = new_batch self.running_batch = new_batch
...@@ -349,20 +355,19 @@ class ModelRpcServer: ...@@ -349,20 +355,19 @@ class ModelRpcServer:
and req.extend_input_len + new_batch_input_tokens and req.extend_input_len + new_batch_input_tokens
< self.max_prefill_num_token < self.max_prefill_num_token
): ):
delta = self.tree_cache.inc_ref_counter(req.last_node) delta = self.tree_cache.inc_lock_ref(req.last_node)
available_size += delta available_size += delta
if not ( if not (
req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
< available_size < available_size
): ):
# Undo the insertion # Undo locking
delta = self.tree_cache.dec_ref_counter(req.last_node) delta = self.tree_cache.dec_lock_ref(req.last_node)
available_size += delta available_size += delta
break break
else: else:
# Add this request to the running batch # Add this request to the running batch
self.token_to_kv_pool.add_refs(req.prefix_indices)
can_run_list.append(req) can_run_list.append(req)
new_batch_total_tokens += ( new_batch_total_tokens += (
req.extend_input_len + req.max_new_tokens() req.extend_input_len + req.max_new_tokens()
...@@ -477,6 +482,18 @@ class ModelRpcServer: ...@@ -477,6 +482,18 @@ class ModelRpcServer:
self.handle_finished_requests(batch) self.handle_finished_requests(batch)
def cache_filled_batch(self, batch: Batch):
    """Insert every request of a freshly filled batch into the radix cache.

    Called right after ``forward_fill_batch``. Each request keeps its slot in
    the request-to-token pool (``del_in_memory_pool=False``); its
    ``prefix_indices`` and ``last_node`` are replaced by the values that
    ``tree_cache.cache_req`` returns.

    Args:
        batch: The batch that was just run. ``batch.req_pool_indices`` is
            converted to a Python list once, so the loop below does not index
            into the tensor per request.
    """
    req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
    for i, req in enumerate(batch.reqs):
        new_prefix_indices, new_last_node = self.tree_cache.cache_req(
            # The last token is dropped, as at every other cache_req call site.
            token_ids=tuple(req.input_ids + req.output_ids)[:-1],
            last_uncached_pos=len(req.prefix_indices),
            req_pool_idx=req_pool_indices_cpu[i],
            del_in_memory_pool=False,
            old_last_node=req.last_node,
        )
        req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
def forward_decode_batch(self, batch: Batch): def forward_decode_batch(self, batch: Batch):
# check if decode out of memory # check if decode out of memory
if not batch.check_decode_mem(): if not batch.check_decode_mem():
...@@ -636,17 +653,13 @@ class ModelRpcServer: ...@@ -636,17 +653,13 @@ class ModelRpcServer:
req_pool_indices_cpu = batch.req_pool_indices.tolist() req_pool_indices_cpu = batch.req_pool_indices.tolist()
for i in finished_indices: for i in finished_indices:
req = batch.reqs[i] req = batch.reqs[i]
req_pool_idx = req_pool_indices_cpu[i] self.tree_cache.cache_req(
token_ids = tuple(req.input_ids + req.output_ids) token_ids=tuple(req.input_ids + req.output_ids)[:-1],
seq_len = len(token_ids) - 1 last_uncached_pos=len(req.prefix_indices),
indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len] req_pool_idx=req_pool_indices_cpu[i],
prefix_len = self.tree_cache.insert(
token_ids[:seq_len], indices.clone()
) )
self.token_to_kv_pool.dec_refs(indices[:prefix_len]) self.tree_cache.dec_lock_ref(req.last_node)
self.req_to_token_pool.free(req_pool_idx)
self.tree_cache.dec_ref_counter(req.last_node)
# Update batch tensors # Update batch tensors
if unfinished_indices: if unfinished_indices:
......
...@@ -11,7 +11,7 @@ class TreeNode: ...@@ -11,7 +11,7 @@ class TreeNode:
self.parent = None self.parent = None
self.key = None self.key = None
self.value = None self.value = None
self.ref_counter = 0 self.lock_ref = 0
self.last_access_time = time.time() self.last_access_time = time.time()
def __lt__(self, other: "TreeNode"): def __lt__(self, other: "TreeNode"):
...@@ -28,7 +28,9 @@ def _key_match(key0, key1): ...@@ -28,7 +28,9 @@ def _key_match(key0, key1):
class RadixCache: class RadixCache:
def __init__(self, disable: bool = False): def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool = token_to_kv_pool
self.disable = disable self.disable = disable
self.reset() self.reset()
...@@ -38,7 +40,7 @@ class RadixCache: ...@@ -38,7 +40,7 @@ class RadixCache:
self.root_node = TreeNode() self.root_node = TreeNode()
self.root_node.key = [] self.root_node.key = []
self.root_node.value = [] self.root_node.value = []
self.root_node.ref_counter = 1 self.root_node.lock_ref = 1
self.evictable_size_ = 0 self.evictable_size_ = 0
def match_prefix(self, key): def match_prefix(self, key):
...@@ -50,6 +52,8 @@ class RadixCache: ...@@ -50,6 +52,8 @@ class RadixCache:
self._match_prefix_helper(self.root_node, key, value, last_node) self._match_prefix_helper(self.root_node, key, value, last_node)
if value: if value:
value = torch.concat(value) value = torch.concat(value)
else:
value = torch.tensor([], dtype=torch.int64)
return value, last_node[0] return value, last_node[0]
def insert(self, key, value=None): def insert(self, key, value=None):
...@@ -60,6 +64,34 @@ class RadixCache: ...@@ -60,6 +64,34 @@ class RadixCache:
value = [x for x in key] value = [x for x in key]
return self._insert_helper(self.root_node, key, value) return self._insert_helper(self.root_node, key, value)
def cache_req(
    self,
    token_ids,
    last_uncached_pos,
    req_pool_idx,
    del_in_memory_pool=True,
    old_last_node=None,
):
    """Insert a request's tokens into the radix cache and fix up pool state.

    Args:
        token_ids: Token sequence to cache. Callers pass
            ``input_ids + output_ids`` without the last token.
        last_uncached_pos: Start of the span whose KV references are dropped
            here. Callers pass ``len(req.prefix_indices)``.
        req_pool_idx: Row of ``req_to_token_pool.req_to_token`` that belongs
            to the request.
        del_in_memory_pool: If true, free the request's row in the
            request-to-token pool. Otherwise keep the row and rewrite it from
            the cached indices.
        old_last_node: Node whose lock is released when the row is kept.
            ``None`` means there is no previous lock to release.

    Returns:
        ``None`` when ``del_in_memory_pool`` is true, otherwise the tuple
        ``(cached_indices, new_last_node)`` obtained from ``match_prefix``.
    """
    # Insert the request into radix cache
    indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
    new_prefix_len = self.insert(token_ids, indices.clone())

    # Radix Cache takes one ref in memory pool
    self.token_to_kv_pool.dec_refs(indices[last_uncached_pos:new_prefix_len])

    if del_in_memory_pool:
        self.req_to_token_pool.free(req_pool_idx)
        return None

    cached_indices, new_last_node = self.match_prefix(token_ids)
    assert len(cached_indices) == len(token_ids)

    # Rewrite the request's row from the cache's indices, starting at the
    # first previously uncached position.
    self.req_to_token_pool.req_to_token[
        req_pool_idx, last_uncached_pos : len(cached_indices)
    ] = cached_indices[last_uncached_pos:]

    # Move the lock from the old last node to the new one. Without the guard,
    # the default old_last_node=None would be dereferenced in dec_lock_ref.
    if old_last_node is not None:
        self.dec_lock_ref(old_last_node)
    self.inc_lock_ref(new_last_node)
    return cached_indices, new_last_node
def pretty_print(self): def pretty_print(self):
self._print_helper(self.root_node, 0) self._print_helper(self.root_node, 0)
print(f"#tokens: {self.total_size()}") print(f"#tokens: {self.total_size()}")
...@@ -80,7 +112,7 @@ class RadixCache: ...@@ -80,7 +112,7 @@ class RadixCache:
if x == self.root_node: if x == self.root_node:
break break
if x.ref_counter > 0: if x.lock_ref > 0:
continue continue
num_evicted += evict_callback(x.value) num_evicted += evict_callback(x.value)
...@@ -89,23 +121,23 @@ class RadixCache: ...@@ -89,23 +121,23 @@ class RadixCache:
if len(x.parent.children) == 0: if len(x.parent.children) == 0:
heapq.heappush(leaves, x.parent) heapq.heappush(leaves, x.parent)
def inc_ref_counter(self, node): def inc_lock_ref(self, node: TreeNode):
delta = 0 delta = 0
while node != self.root_node: while node != self.root_node:
if node.ref_counter == 0: if node.lock_ref == 0:
self.evictable_size_ -= len(node.value) self.evictable_size_ -= len(node.value)
delta -= len(node.value) delta -= len(node.value)
node.ref_counter += 1 node.lock_ref += 1
node = node.parent node = node.parent
return delta return delta
def dec_ref_counter(self, node): def dec_lock_ref(self, node: TreeNode):
delta = 0 delta = 0
while node != self.root_node: while node != self.root_node:
if node.ref_counter == 1: if node.lock_ref == 1:
self.evictable_size_ += len(node.value) self.evictable_size_ += len(node.value)
delta += len(node.value) delta += len(node.value)
node.ref_counter -= 1 node.lock_ref -= 1
node = node.parent node = node.parent
return delta return delta
...@@ -131,12 +163,12 @@ class RadixCache: ...@@ -131,12 +163,12 @@ class RadixCache:
last_node[0] = child last_node[0] = child
self._match_prefix_helper(child, key[prefix_len:], value, last_node) self._match_prefix_helper(child, key[prefix_len:], value, last_node)
def _split_node(self, key, child, split_len): def _split_node(self, key, child: TreeNode, split_len):
# new_node -> child # new_node -> child
new_node = TreeNode() new_node = TreeNode()
new_node.children = {key[split_len:][0]: child} new_node.children = {key[split_len:][0]: child}
new_node.parent = child.parent new_node.parent = child.parent
new_node.ref_counter = child.ref_counter new_node.lock_ref = child.lock_ref
new_node.key = child.key[:split_len] new_node.key = child.key[:split_len]
new_node.value = child.value[:split_len] new_node.value = child.value[:split_len]
child.parent = new_node child.parent = new_node
...@@ -176,11 +208,9 @@ class RadixCache: ...@@ -176,11 +208,9 @@ class RadixCache:
self.evictable_size_ += len(value) self.evictable_size_ += len(value)
return 0 return 0
def _print_helper(self, node, indent): def _print_helper(self, node: TreeNode, indent):
for _, child in node.children.items(): for _, child in node.children.items():
print( print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}")
" " * indent, len(child.key), child.key[:10], f"r={child.ref_counter}"
)
self._print_helper(child, indent=indent + 2) self._print_helper(child, indent=indent + 2)
def _delete_leaf(self, node): def _delete_leaf(self, node):
...@@ -211,7 +241,7 @@ class RadixCache: ...@@ -211,7 +241,7 @@ class RadixCache:
if __name__ == "__main__": if __name__ == "__main__":
tree = RadixCache() tree = RadixCache(None, None, False)
tree.insert("Hello") tree.insert("Hello")
tree.insert("Hello") tree.insert("Hello")
......
...@@ -27,44 +27,33 @@ class Scheduler: ...@@ -27,44 +27,33 @@ class Scheduler:
return forward_queue return forward_queue
elif self.schedule_heuristic == "fcfs": elif self.schedule_heuristic == "fcfs":
return forward_queue return forward_queue
elif self.schedule_heuristic == "weight": elif self.schedule_heuristic == "dfs-weight":
last_node_to_reqs = defaultdict(list) last_node_to_reqs = defaultdict(list)
for req in forward_queue: for req in forward_queue:
last_node_to_reqs[req.last_node].append(req) last_node_to_reqs[req.last_node].append(req)
for node in last_node_to_reqs:
last_node_to_reqs[node].sort(key=lambda x: -len(x.prefix_indices))
node_to_weight = defaultdict(int) node_to_weight = defaultdict(int)
self._calc_weight_recursive( for node in last_node_to_reqs:
self.tree_cache.root_node, last_node_to_reqs, node_to_weight node_to_weight[node] = len(last_node_to_reqs[node])
) self.calc_weight(self.tree_cache.root_node, node_to_weight)
tmp_queue = [] q = []
self._get_weight_priority_recursive( self.get_dfs_priority(
self.tree_cache.root_node, node_to_weight, last_node_to_reqs, tmp_queue self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q
) )
assert len(tmp_queue) == len(forward_queue) assert len(q) == len(forward_queue)
return tmp_queue return q
else: else:
raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}") raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")
def _calc_weight_recursive(self, cur_node, last_node_to_reqs, node_to_weight): def calc_weight(self, cur_node, node_to_weight):
node_to_weight[cur_node] = 1
if cur_node in last_node_to_reqs:
node_to_weight[cur_node] += len(last_node_to_reqs[cur_node])
for child in cur_node.children.values(): for child in cur_node.children.values():
self._calc_weight_recursive(child, last_node_to_reqs, node_to_weight) self.calc_weight(child, node_to_weight)
node_to_weight[cur_node] += node_to_weight[child] node_to_weight[cur_node] += node_to_weight[child]
def _get_weight_priority_recursive( def get_dfs_priority(self, cur_node, node_to_priority, last_node_to_reqs, q):
self, cur_node, node_to_wight, last_node_to_reqs, tmp_queue childs = [child for child in cur_node.children.values()]
): childs.sort(key=lambda x: -node_to_priority[x])
visit_list = [child for child in cur_node.children.values()] for child in childs:
visit_list.sort(key=lambda x: -node_to_wight[x]) self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
# for node in visit_list: q.extend(last_node_to_reqs[cur_node])
# print(f"{node_to_wight[node]} {len(node.value) if node.value is not None else 0}")
for child in visit_list:
self._get_weight_priority_recursive(
child, node_to_wight, last_node_to_reqs, tmp_queue
)
tmp_queue.extend(last_node_to_reqs[cur_node])
...@@ -149,7 +149,8 @@ class ServerArgs: ...@@ -149,7 +149,8 @@ class ServerArgs:
"--schedule-heuristic", "--schedule-heuristic",
type=str, type=str,
default=ServerArgs.schedule_heuristic, default=ServerArgs.schedule_heuristic,
help="Schudule mode: [lpm, weight, random, fcfs]", choices=["lpm", "random", "fcfs", "dfs-weight"],
help="Scheduling Heuristic.",
) )
parser.add_argument( parser.add_argument(
"--schedule-conservativeness", "--schedule-conservativeness",
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment