Unverified commit 39191c85 authored by Liangsheng Yin, committed by GitHub

Cache optimizations (#418)

parent 562b8857
......@@ -25,5 +25,8 @@ class GlobalConfig:
         # adjust_cache: Adjust the position embedding of KV cache.
         self.concate_and_append_mode = "no_adjust"
 
+        # Request dependency time due to network delay
+        self.request_dependency_time = 0.03
+
 
 global_config = GlobalConfig()

deleted file: backend_config.py
-"""
-Backend configurations, may vary with different serving platforms.
-"""
-
-from dataclasses import dataclass
-
-
-@dataclass
-class BackendConfig:
-    extend_dependency_time: float = 0.03
-
-
-GLOBAL_BACKEND_CONFIG = BackendConfig()
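For context: the per-backend `BackendConfig` dataclass is deleted and its single field moves into `GlobalConfig` under a new name, leaving one config object for the serving path. A minimal sketch of the resulting shape (all other `GlobalConfig` fields omitted):

```python
# Minimal sketch of the consolidated config (other fields omitted).
class GlobalConfig:
    def __init__(self):
        # Request dependency time due to network delay
        self.request_dependency_time = 0.03  # was BackendConfig.extend_dependency_time

global_config = GlobalConfig()

# Call sites switch from GLOBAL_BACKEND_CONFIG.extend_dependency_time to:
print(global_config.request_dependency_time)  # 0.03
```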
......@@ -335,20 +335,20 @@ class Batch:
             req = self.reqs[idx]
             retracted_reqs.append(req)
 
-            self.tree_cache.dec_ref_counter(req.last_node)
+            # TODO: apply more fine-grained retraction
+            last_uncached_pos = len(req.prefix_indices)
+            token_indices = self.req_to_token_pool.req_to_token[
+                req_pool_indices_cpu[idx]
+            ][last_uncached_pos : seq_lens_cpu[idx]]
+            self.token_to_kv_pool.dec_refs(token_indices)
+
+            self.tree_cache.dec_lock_ref(req.last_node)
             req.prefix_indices = None
             req.last_node = None
             req.extend_input_len = 0
             req.output_ids = []
             req.regex_fsm_state = 0
 
-            # TODO: apply more fine-grained retraction
-            token_indices = self.req_to_token_pool.req_to_token[
-                req_pool_indices_cpu[idx]
-            ][: seq_lens_cpu[idx]]
-            self.token_to_kv_pool.dec_refs(token_indices)
-
         self.filter_batch(sorted_indices)
 
         return retracted_reqs
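Retraction now frees only the KV slots past the request's cached prefix (`last_uncached_pos = len(req.prefix_indices)`) rather than the whole row, because requests no longer take their own reference on the prefix (see the `token_to_kv_pool.add_refs` removal further down in this commit). A toy illustration of the indexing, with made-up sizes and fake slot numbers:

```python
import torch

seq_len = 10                # tokens currently held by the request
prefix_len = 6              # tokens already owned by the radix cache
req_to_token_row = torch.arange(100, 100 + seq_len)  # fake KV slot indices

last_uncached_pos = prefix_len
token_indices = req_to_token_row[last_uncached_pos:seq_len]
print(token_indices.tolist())  # [106, 107, 108, 109] -> only these are freed
```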
......@@ -367,20 +367,18 @@ class Batch:
                 if len(jump_forward_str) <= 1:
                     continue
 
-                # insert the old request into tree_cache
-                token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1]
                 if req_pool_indices_cpu is None:
                     req_pool_indices_cpu = self.req_pool_indices.tolist()
-                req_pool_idx = req_pool_indices_cpu[i]
-                indices = self.req_to_token_pool.req_to_token[
-                    req_pool_idx, : len(token_ids_in_memory)
-                ]
-                prefix_len = self.tree_cache.insert(
-                    token_ids_in_memory, indices.clone()
+
+                # insert the old request into tree_cache
+                self.tree_cache.cache_req(
+                    token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                    last_uncached_pos=len(req.prefix_indices),
+                    req_pool_idx=req_pool_indices_cpu[i],
                 )
-                self.token_to_kv_pool.dec_refs(indices[:prefix_len])
-                self.req_to_token_pool.free(req_pool_idx)
-                self.tree_cache.dec_ref_counter(req.last_node)
+
+                # unlock the last node
+                self.tree_cache.dec_lock_ref(req.last_node)
 
                 # jump-forward
                 req.jump_forward_and_retokenize(jump_forward_str, next_state)
......
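`cache_req` (added to `RadixCache` later in this commit) bundles what used to be four separate steps at every call site. Roughly, its `del_in_memory_pool=True` path is equivalent to the removed code above, sketched here against hypothetical pool objects with the same interfaces:

```python
def cache_req_equivalent(tree_cache, req_to_token_pool, token_to_kv_pool,
                         token_ids, last_uncached_pos, req_pool_idx):
    # 1) Look up the request's KV slot indices.
    indices = req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
    # 2) Insert the tokens into the radix tree; it keeps one reference
    #    in the memory pool for the first new_prefix_len tokens.
    new_prefix_len = tree_cache.insert(token_ids, indices.clone())
    # 3) For the span the tree already cached, free the request's
    #    duplicate KV entries.
    token_to_kv_pool.dec_refs(indices[last_uncached_pos:new_prefix_len])
    # 4) Release the request-to-token row.
    req_to_token_pool.free(req_pool_idx)
```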
......@@ -5,7 +5,7 @@ import uvloop
 import zmq
 import zmq.asyncio
 
-from sglang.srt.backend_config import GLOBAL_BACKEND_CONFIG
+from sglang import global_config
 from sglang.srt.managers.router.model_rpc import ModelRpcClient
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import get_exception_traceback
......@@ -30,7 +30,7 @@ class RouterManager:
         self.recv_reqs = []
 
         # Init some configs
-        self.extend_dependency_time = GLOBAL_BACKEND_CONFIG.extend_dependency_time
+        self.request_dependency_time = global_config.request_dependency_time
 
     async def loop_for_forward(self):
         while True:
......@@ -46,9 +46,9 @@ class RouterManager:
             if len(out_pyobjs) != 0:
                 has_finished = any([obj.finished for obj in out_pyobjs])
                 if has_finished:
-                    if self.extend_dependency_time > 0:
+                    if self.request_dependency_time > 0:
                         slept = True
-                        await asyncio.sleep(self.extend_dependency_time)
+                        await asyncio.sleep(self.request_dependency_time)
 
             if not slept:
                 await asyncio.sleep(0.0006)
......
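Behaviorally the rename keeps the same trick: when a request finishes, the forward loop sleeps briefly so a dependent follow-up request (for example, the next call of a multi-call program) can arrive before the next scheduling step, instead of the batch being rebuilt without it. A runnable toy of the loop above, with a stand-in `Obj` type:

```python
import asyncio
from dataclasses import dataclass

# Toy sketch of loop_for_forward (not the real RouterManager): after any
# request finishes, sleep request_dependency_time so a dependent
# follow-up request can arrive before the next scheduling step.
request_dependency_time = 0.03

@dataclass
class Obj:
    finished: bool

async def loop_for_forward(steps):
    for out_pyobjs in steps:          # one entry per forward step
        slept = False
        if len(out_pyobjs) != 0:
            has_finished = any(obj.finished for obj in out_pyobjs)
            if has_finished and request_dependency_time > 0:
                slept = True
                await asyncio.sleep(request_dependency_time)
        if not slept:
            await asyncio.sleep(0.0006)

asyncio.run(loop_for_forward([[Obj(False)], [Obj(True)], []]))
```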
......@@ -117,7 +117,11 @@ class ModelRpcServer:
             logger.info(f"server_args: {server_args.print_mode_args()}")
 
         # Init cache
-        self.tree_cache = RadixCache(disable=server_args.disable_radix_cache)
+        self.tree_cache = RadixCache(
+            req_to_token_pool=self.model_runner.req_to_token_pool,
+            token_to_kv_pool=self.model_runner.token_to_kv_pool,
+            disable=server_args.disable_radix_cache,
+        )
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.scheduler = Scheduler(
             self.schedule_heuristic,
......@@ -203,6 +207,8 @@ class ModelRpcServer:
             # Run new fill batch
             self.forward_fill_batch(new_batch)
 
+            self.cache_filled_batch(new_batch)
+
             if not new_batch.is_empty():
                 if self.running_batch is None:
                     self.running_batch = new_batch
......@@ -349,20 +355,19 @@ class ModelRpcServer:
                 and req.extend_input_len + new_batch_input_tokens
                 < self.max_prefill_num_token
             ):
-                delta = self.tree_cache.inc_ref_counter(req.last_node)
+                delta = self.tree_cache.inc_lock_ref(req.last_node)
                 available_size += delta
 
                 if not (
                     req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                     < available_size
                 ):
-                    # Undo the insertion
-                    delta = self.tree_cache.dec_ref_counter(req.last_node)
+                    # Undo locking
+                    delta = self.tree_cache.dec_lock_ref(req.last_node)
                     available_size += delta
                     break
                 else:
                     # Add this request to the running batch
-                    self.token_to_kv_pool.add_refs(req.prefix_indices)
                     can_run_list.append(req)
                     new_batch_total_tokens += (
                         req.extend_input_len + req.max_new_tokens()
......@@ -477,6 +482,18 @@ class ModelRpcServer:
         self.handle_finished_requests(batch)
 
+    def cache_filled_batch(self, batch: Batch):
+        req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
+        for i, req in enumerate(batch.reqs):
+            new_prefix_indices, new_last_node = self.tree_cache.cache_req(
+                token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                last_uncached_pos=len(req.prefix_indices),
+                req_pool_idx=req_pool_indices_cpu[i],
+                del_in_memory_pool=False,
+                old_last_node=req.last_node,
+            )
+            req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
+
     def forward_decode_batch(self, batch: Batch):
         # check if decode out of memory
         if not batch.check_decode_mem():
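`cache_filled_batch` runs right after prefill: each request's tokens go into the tree while the request keeps running (`del_in_memory_pool=False`), and the request's lock moves from the node it held at schedule time to the deepest node matching its current tokens. A stripped-down sketch of that handoff, using a stand-in `Node` rather than the real `TreeNode` (the real `inc_lock_ref`/`dec_lock_ref` walk the whole path up to the root):

```python
class Node:
    def __init__(self):
        self.lock_ref = 0

old_last_node, new_last_node = Node(), Node()
old_last_node.lock_ref = 1    # locked when the request was scheduled

# Inside cache_req(..., del_in_memory_pool=False, old_last_node=...):
old_last_node.lock_ref -= 1   # dec_lock_ref(old_last_node), simplified
new_last_node.lock_ref += 1   # inc_lock_ref(new_last_node), simplified

assert (old_last_node.lock_ref, new_last_node.lock_ref) == (0, 1)
```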
......@@ -636,17 +653,13 @@ class ModelRpcServer:
             req_pool_indices_cpu = batch.req_pool_indices.tolist()
             for i in finished_indices:
                 req = batch.reqs[i]
-                req_pool_idx = req_pool_indices_cpu[i]
-                token_ids = tuple(req.input_ids + req.output_ids)
-                seq_len = len(token_ids) - 1
-                indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len]
-                prefix_len = self.tree_cache.insert(
-                    token_ids[:seq_len], indices.clone()
+                self.tree_cache.cache_req(
+                    token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                    last_uncached_pos=len(req.prefix_indices),
+                    req_pool_idx=req_pool_indices_cpu[i],
                 )
-                self.token_to_kv_pool.dec_refs(indices[:prefix_len])
-                self.req_to_token_pool.free(req_pool_idx)
-                self.tree_cache.dec_ref_counter(req.last_node)
+                self.tree_cache.dec_lock_ref(req.last_node)
 
         # Update batch tensors
         if unfinished_indices:
......
......@@ -11,7 +11,7 @@ class TreeNode:
         self.parent = None
         self.key = None
         self.value = None
-        self.ref_counter = 0
+        self.lock_ref = 0
         self.last_access_time = time.time()
 
     def __lt__(self, other: "TreeNode"):
......@@ -28,7 +28,9 @@ def _key_match(key0, key1):
 class RadixCache:
-    def __init__(self, disable: bool = False):
+    def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
+        self.req_to_token_pool = req_to_token_pool
+        self.token_to_kv_pool = token_to_kv_pool
         self.disable = disable
         self.reset()
......@@ -38,7 +40,7 @@ class RadixCache:
         self.root_node = TreeNode()
         self.root_node.key = []
         self.root_node.value = []
-        self.root_node.ref_counter = 1
+        self.root_node.lock_ref = 1
         self.evictable_size_ = 0
 
     def match_prefix(self, key):
......@@ -50,6 +52,8 @@ class RadixCache:
         self._match_prefix_helper(self.root_node, key, value, last_node)
         if value:
             value = torch.concat(value)
+        else:
+            value = torch.tensor([], dtype=torch.int64)
         return value, last_node[0]
 
     def insert(self, key, value=None):
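The new `else` branch makes the no-match case return an empty int64 tensor instead of the empty Python list the helper started with, so callers such as `cache_req` can uniformly take `len(...)` and slice the result. A two-line demonstration of the fallback:

```python
import torch

value = []  # nothing matched, so the list of per-node tensors is empty
value = torch.concat(value) if value else torch.tensor([], dtype=torch.int64)
print(value.dtype, value.numel())  # torch.int64 0
```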
......@@ -60,6 +64,34 @@ class RadixCache:
             value = [x for x in key]
         return self._insert_helper(self.root_node, key, value)
 
+    def cache_req(
+        self,
+        token_ids,
+        last_uncached_pos,
+        req_pool_idx,
+        del_in_memory_pool=True,
+        old_last_node=None,
+    ):
+        # Insert the request into radix cache
+        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
+        new_prefix_len = self.insert(token_ids, indices.clone())
+
+        # Radix Cache takes one ref in memory pool
+        self.token_to_kv_pool.dec_refs(indices[last_uncached_pos:new_prefix_len])
+
+        if del_in_memory_pool:
+            self.req_to_token_pool.free(req_pool_idx)
+        else:
+            cached_indices, new_last_node = self.match_prefix(token_ids)
+            assert len(cached_indices) == len(token_ids)
+
+            self.req_to_token_pool.req_to_token[
+                req_pool_idx, last_uncached_pos : len(cached_indices)
+            ] = cached_indices[last_uncached_pos:]
+            self.dec_lock_ref(old_last_node)
+            self.inc_lock_ref(new_last_node)
+            return cached_indices, new_last_node
+
     def pretty_print(self):
         self._print_helper(self.root_node, 0)
         print(f"#tokens: {self.total_size()}")
......@@ -80,7 +112,7 @@ class RadixCache:
             if x == self.root_node:
                 break
-            if x.ref_counter > 0:
+            if x.lock_ref > 0:
                 continue
 
             num_evicted += evict_callback(x.value)
......@@ -89,23 +121,23 @@ class RadixCache:
             if len(x.parent.children) == 0:
                 heapq.heappush(leaves, x.parent)
 
-    def inc_ref_counter(self, node):
+    def inc_lock_ref(self, node: TreeNode):
         delta = 0
         while node != self.root_node:
-            if node.ref_counter == 0:
+            if node.lock_ref == 0:
                 self.evictable_size_ -= len(node.value)
                 delta -= len(node.value)
-            node.ref_counter += 1
+            node.lock_ref += 1
             node = node.parent
         return delta
 
-    def dec_ref_counter(self, node):
+    def dec_lock_ref(self, node: TreeNode):
         delta = 0
         while node != self.root_node:
-            if node.ref_counter == 1:
+            if node.lock_ref == 1:
                 self.evictable_size_ += len(node.value)
                 delta += len(node.value)
-            node.ref_counter -= 1
+            node.lock_ref -= 1
             node = node.parent
         return delta
......@@ -131,12 +163,12 @@ class RadixCache:
             last_node[0] = child
             self._match_prefix_helper(child, key[prefix_len:], value, last_node)
 
-    def _split_node(self, key, child, split_len):
+    def _split_node(self, key, child: TreeNode, split_len):
         # new_node -> child
         new_node = TreeNode()
         new_node.children = {key[split_len:][0]: child}
         new_node.parent = child.parent
-        new_node.ref_counter = child.ref_counter
+        new_node.lock_ref = child.lock_ref
         new_node.key = child.key[:split_len]
         new_node.value = child.value[:split_len]
         child.parent = new_node
......@@ -176,11 +208,9 @@ class RadixCache:
             self.evictable_size_ += len(value)
         return 0
 
-    def _print_helper(self, node, indent):
+    def _print_helper(self, node: TreeNode, indent):
         for _, child in node.children.items():
-            print(
-                " " * indent, len(child.key), child.key[:10], f"r={child.ref_counter}"
-            )
+            print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}")
             self._print_helper(child, indent=indent + 2)
 
     def _delete_leaf(self, node):
......@@ -211,7 +241,7 @@ class RadixCache:
 if __name__ == "__main__":
-    tree = RadixCache()
+    tree = RadixCache(None, None, False)
 
     tree.insert("Hello")
     tree.insert("Hello")
......
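The `ref_counter` to `lock_ref` rename reflects what the counter actually is: a pin count that keeps a path unevictable while requests use it, with `evictable_size_` tracking how many cached tokens remain fair game for eviction. A self-contained toy of that accounting on a hand-built two-node chain (stand-in classes, not the real ones):

```python
# Locking a path removes its tokens from the evictable pool; unlocking
# returns them. This mirrors inc_lock_ref on a hand-built chain.
class TreeNode:
    def __init__(self, value, parent=None):
        self.value = value          # cached token slots at this node
        self.parent = parent
        self.lock_ref = 0

root = TreeNode([])                 # stand-in root (pinned in the real cache)
a = TreeNode([1, 2, 3], parent=root)
b = TreeNode([4, 5], parent=a)

evictable_size = len(a.value) + len(b.value)   # 5 tokens evictable

def inc_lock_ref(node, root):
    global evictable_size
    while node != root:
        if node.lock_ref == 0:
            evictable_size -= len(node.value)
        node.lock_ref += 1
        node = node.parent

inc_lock_ref(b, root)
print(evictable_size)  # 0 -> the whole locked path is now unevictable
```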
......@@ -27,44 +27,33 @@ class Scheduler:
             return forward_queue
         elif self.schedule_heuristic == "fcfs":
             return forward_queue
-        elif self.schedule_heuristic == "weight":
+        elif self.schedule_heuristic == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
             for req in forward_queue:
                 last_node_to_reqs[req.last_node].append(req)
-            for node in last_node_to_reqs:
-                last_node_to_reqs[node].sort(key=lambda x: -len(x.prefix_indices))
 
             node_to_weight = defaultdict(int)
-            self._calc_weight_recursive(
-                self.tree_cache.root_node, last_node_to_reqs, node_to_weight
-            )
+            for node in last_node_to_reqs:
+                node_to_weight[node] = len(last_node_to_reqs[node])
+            self.calc_weight(self.tree_cache.root_node, node_to_weight)
 
-            tmp_queue = []
-            self._get_weight_priority_recursive(
-                self.tree_cache.root_node, node_to_weight, last_node_to_reqs, tmp_queue
+            q = []
+            self.get_dfs_priority(
+                self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q
             )
-            assert len(tmp_queue) == len(forward_queue)
-            return tmp_queue
+            assert len(q) == len(forward_queue)
+            return q
         else:
             raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")
 
-    def _calc_weight_recursive(self, cur_node, last_node_to_reqs, node_to_weight):
-        node_to_weight[cur_node] = 1
-        if cur_node in last_node_to_reqs:
-            node_to_weight[cur_node] += len(last_node_to_reqs[cur_node])
+    def calc_weight(self, cur_node, node_to_weight):
         for child in cur_node.children.values():
-            self._calc_weight_recursive(child, last_node_to_reqs, node_to_weight)
+            self.calc_weight(child, node_to_weight)
+            node_to_weight[cur_node] += node_to_weight[child]
 
-    def _get_weight_priority_recursive(
-        self, cur_node, node_to_wight, last_node_to_reqs, tmp_queue
-    ):
-        visit_list = [child for child in cur_node.children.values()]
-        visit_list.sort(key=lambda x: -node_to_wight[x])
-        # for node in visit_list:
-        #     print(f"{node_to_wight[node]} {len(node.value) if node.value is not None else 0}")
-        for child in visit_list:
-            self._get_weight_priority_recursive(
-                child, node_to_wight, last_node_to_reqs, tmp_queue
-            )
-        tmp_queue.extend(last_node_to_reqs[cur_node])
+    def get_dfs_priority(self, cur_node, node_to_priority, last_node_to_reqs, q):
+        childs = [child for child in cur_node.children.values()]
+        childs.sort(key=lambda x: -node_to_priority[x])
+        for child in childs:
+            self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
+        q.extend(last_node_to_reqs[cur_node])
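The renamed `dfs-weight` heuristic weights every tree node by the number of queued requests anchored in its subtree, then emits requests in a depth-first traversal that visits heavier children first, so requests sharing a cached prefix end up adjacent in the batch order. A self-contained toy run (hand-built nodes, not the real `TreeNode`):

```python
from collections import defaultdict

class Node:
    def __init__(self, name, children=()):
        self.name, self.children = name, {c.name: c for c in children}

leaf_a, leaf_b = Node("a"), Node("b")
root = Node("root", [leaf_a, leaf_b])

last_node_to_reqs = {leaf_a: ["req1"], leaf_b: ["req2", "req3"]}

# Seed weights with per-node request counts, then fold them upward.
node_to_weight = defaultdict(int)
for node, reqs in last_node_to_reqs.items():
    node_to_weight[node] = len(reqs)

def calc_weight(cur, w):
    for child in cur.children.values():
        calc_weight(child, w)
        w[cur] += w[child]

def get_dfs_priority(cur, w, node_to_reqs, q):
    childs = sorted(cur.children.values(), key=lambda x: -w[x])
    for child in childs:
        get_dfs_priority(child, w, node_to_reqs, q)
    q.extend(node_to_reqs.get(cur, []))

calc_weight(root, node_to_weight)
q = []
get_dfs_priority(root, node_to_weight, last_node_to_reqs, q)
print(q)  # ['req2', 'req3', 'req1'] -> heavier subtree scheduled first
```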
......@@ -149,7 +149,8 @@ class ServerArgs:
         "--schedule-heuristic",
         type=str,
         default=ServerArgs.schedule_heuristic,
-        help="Schudule mode: [lpm, weight, random, fcfs]",
+        choices=["lpm", "random", "fcfs", "dfs-weight"],
+        help="Scheduling Heuristic.",
     )
     parser.add_argument(
         "--schedule-conservativeness",
......