Unverified Commit 48d6bea1 authored by Hanming Lu's avatar Hanming Lu Committed by GitHub
Browse files

[GDN/SWA] mamba and swa radix cache edge case fix (#12111)


Co-authored-by: default avataryizhang2077 <1109276519@qq.com>
parent 1689c0e3
......@@ -2386,6 +2386,12 @@ class Scheduler(
- self.tree_cache.swa_evictable_size()
)
num_tokens = max(num_tokens_full, num_tokens_swa)
elif self.is_hybrid_gdn:
num_tokens = (
self.max_total_num_tokens
- self.token_to_kv_pool_allocator.available_size()
- self.tree_cache.full_evictable_size()
)
else:
num_tokens = (
self.max_total_num_tokens
......
......@@ -20,11 +20,11 @@ The radix tree data structure for managing the hybrid (full and Mamba) KV cache.
"""
import heapq
import time
from collections import defaultdict
from typing import TYPE_CHECKING, List, Optional, Tuple
import torch
from numpy import float64
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
......@@ -46,6 +46,7 @@ logger = logging.getLogger(__name__)
class TreeNode:
counter = 0
last_access_time_counter_float = float64(1.0)
def __init__(self, id: Optional[int] = None):
self.children = defaultdict(TreeNode)
......@@ -61,7 +62,7 @@ class TreeNode:
self.full_lock_ref = 0
self.mamba_lock_ref = 0
# last access time is only used for sanity check. LRU is maintained by the lru list.
self.last_access_time = time.monotonic()
self.last_access_time = get_last_access_time()
self.hit_count = 0
# store the host indices of KV cache
......@@ -90,6 +91,12 @@ class TreeNode:
return self.last_access_time < other.last_access_time
def get_last_access_time() -> float64:
ret = TreeNode.last_access_time_counter_float
TreeNode.last_access_time_counter_float += 1.0
return ret
class LRUList:
def __init__(self, mamba: bool = False):
self.mamba = mamba
......@@ -382,8 +389,6 @@ class MambaRadixCache(BasePrefixCache):
# copy mamba state to req local space if cow is true
if cow_mamba and last_node.mamba_value is not None:
assert req.req_pool_idx is None # req_pool_idx is uninitialed
# for reqs without mamba cache
if req.mamba_pool_idx is None:
dst_index = self.req_to_token_pool.mamba_pool.alloc(1)
......@@ -421,7 +426,7 @@ class MambaRadixCache(BasePrefixCache):
value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)
return self._insert_helper(self.root_node, key, value, mamba_value)
def cache_finished_req(self, req: Req) -> None:
def cache_finished_req(self, req: Req, is_insert=True) -> None:
"""Cache request when it finishes."""
if self.disable:
kv_indices = self.req_to_token_pool.req_to_token[
......@@ -449,15 +454,20 @@ class MambaRadixCache(BasePrefixCache):
.clone()
)
if is_insert:
new_prefix_len, mamba_exist = self.insert(
RadixKey(token_ids[:page_aligned_len], req.extra_key),
page_aligned_kv_indices,
mamba_value,
)
self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len]
)
else:
self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : page_aligned_len]
)
mamba_exist = True
self.req_to_token_pool.free(req.req_pool_idx, free_mamba_cache=mamba_exist)
self.dec_lock_ref(req.last_node)
......@@ -767,15 +777,18 @@ class MambaRadixCache(BasePrefixCache):
# update time for matched nodes, and make nodes closer to root to be least recently used
# this allows mamba to evict nodes closer to root first
self.full_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node)
self.mamba_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node)
node_update = best_last_node
self.full_lru_list.reset_node_and_parents_mru(node_update, self.root_node)
self.mamba_lru_list.reset_node_and_parents_mru(node_update, self.root_node)
# This last_access_time is for sanity check, can be deleted after validation in production
cur_time = time.monotonic()
while node:
node.last_access_time = cur_time
cur_time -= 0.0001
node = node.parent
cur_time = get_last_access_time()
while node_update:
node_update.last_access_time = cur_time
cur_time -= (
0.00001 # assuming less than 100000 nodes in a branch of the tree
)
node_update = node_update.parent
return value[:best_value_len], best_last_node
......@@ -791,7 +804,7 @@ class MambaRadixCache(BasePrefixCache):
new_node.value = child.value[:split_len]
# child time should be later than parent's time for mamba tombstone
child.last_access_time = time.monotonic()
child.last_access_time = get_last_access_time()
self.full_lru_list.remove_node(child)
if child.mamba_value is not None:
......@@ -819,7 +832,7 @@ class MambaRadixCache(BasePrefixCache):
# Update the last access time from root to leaf, so that
# mamba will tombstone the node closer to root first
assert mamba_value is not None, "Mamba value should not be None here."
node.last_access_time = time.monotonic()
node.last_access_time = get_last_access_time()
if node != self.root_node:
self.full_lru_list.reset_node_mru(node)
if node.mamba_value is not None:
......@@ -832,7 +845,7 @@ class MambaRadixCache(BasePrefixCache):
total_prefix_length = 0
while len(key) > 0 and child_key in node.children.keys():
node = node.children[child_key]
node.last_access_time = time.monotonic()
node.last_access_time = get_last_access_time()
self.full_lru_list.reset_node_mru(node)
if node.mamba_value is not None:
self.mamba_lru_list.reset_node_mru(node)
......@@ -856,17 +869,21 @@ class MambaRadixCache(BasePrefixCache):
new_node.value = value
new_node.mamba_value = mamba_value
self.full_lru_list.insert_mru(new_node)
self.full_evictable_size_ += len(value)
self.mamba_evictable_size_ += len(mamba_value)
self.mamba_lru_list.insert_mru(new_node)
node.children[child_key] = new_node
self.full_evictable_size_ += len(value)
self.mamba_evictable_size_ += len(mamba_value)
elif node.mamba_value is None: # add for mamba tombstone
node.mamba_value = mamba_value
self.mamba_evictable_size_ += len(mamba_value)
self.full_lru_list.reset_node_mru(node)
self.mamba_lru_list.insert_mru(node)
else:
self.mamba_evictable_size_ += len(mamba_value)
node.last_access_time = get_last_access_time()
else: # mamba value already exists
mamba_value_exist = True
self.full_lru_list.reset_node_mru(node)
self.mamba_lru_list.reset_node_mru(node)
node.last_access_time = get_last_access_time()
return total_prefix_length, mamba_value_exist
......
......@@ -20,12 +20,12 @@ The radix tree data structure for managing the hybrid (full and SWA) KV cache.
"""
import heapq
import time
from collections import defaultdict
from functools import partial
from typing import TYPE_CHECKING, List, Optional, Tuple
import torch
from numpy import float64
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
......@@ -50,6 +50,7 @@ class TreeNode:
counter = 0
swa_uuid_counter = 1
last_access_time_counter_float = float64(1.0)
def __init__(self, id: Optional[int] = None):
self.children = defaultdict(TreeNode)
......@@ -64,7 +65,7 @@ class TreeNode:
self.full_lock_ref = 0
self.swa_lock_ref = 0
# last access time is only used for sanity check. LRU is maintained by the lru list.
self.last_access_time = time.monotonic()
self.last_access_time = get_last_access_time()
self.hit_count = 0
# store the host indices of KV cache
......@@ -99,6 +100,12 @@ def gen_swa_uuid() -> int:
return TreeNode.swa_uuid_counter
def get_last_access_time() -> float64:
ret = TreeNode.last_access_time_counter_float
TreeNode.last_access_time_counter_float += 1.0
return ret
class LRUList:
def __init__(self, swa: bool = False):
self.swa = swa
......@@ -841,15 +848,18 @@ class SWARadixCache(BasePrefixCache):
# update time for matched nodes, and make nodes closer to root to be least recently used
# this allows swa to evict nodes closer to root first
self.full_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node)
self.swa_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node)
node_update = best_last_node
self.full_lru_list.reset_node_and_parents_mru(node_update, self.root_node)
self.swa_lru_list.reset_node_and_parents_mru(node_update, self.root_node)
# This last_access_time is for sanity check, can be deleted after validation in production
cur_time = time.monotonic()
while node:
node.last_access_time = cur_time
cur_time -= 0.0001
node = node.parent
cur_time = get_last_access_time()
while node_update:
node_update.last_access_time = cur_time
cur_time -= (
0.00001 # assuming less than 100000 nodes in a branch of the tree
)
node_update = node_update.parent
return value[:best_value_len], best_last_node
......@@ -867,7 +877,7 @@ class SWARadixCache(BasePrefixCache):
new_node.swa_uuid = child.swa_uuid
child.swa_uuid = None
# child time should be later than parent's time for swa tombstone
child.last_access_time = time.monotonic()
child.last_access_time = get_last_access_time()
# remove the child from the lru lists because it is being split
self.full_lru_list.remove_node(child)
......@@ -892,7 +902,7 @@ class SWARadixCache(BasePrefixCache):
) -> int:
# Update the last access time from root to leaf, so that
# swa will tombstone the node closer to root first
node.last_access_time = time.monotonic()
node.last_access_time = get_last_access_time()
if node != self.root_node:
self.full_lru_list.reset_node_mru(node)
if not node.swa_tombstone:
......@@ -905,7 +915,7 @@ class SWARadixCache(BasePrefixCache):
total_prefix_length = 0
while len(key) > 0 and child_key in node.children.keys():
node = node.children[child_key]
node.last_access_time = time.monotonic()
node.last_access_time = get_last_access_time()
self.full_lru_list.reset_node_mru(node)
if not node.swa_tombstone:
self.swa_lru_list.reset_node_mru(node)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment