Unverified commit 48d6bea1, authored by Hanming Lu and committed by GitHub
Browse files

[GDN/SWA] mamba and swa radix cache edge case fix (#12111)


Co-authored-by: yizhang2077 <1109276519@qq.com>
parent 1689c0e3
...@@ -2386,6 +2386,12 @@ class Scheduler( ...@@ -2386,6 +2386,12 @@ class Scheduler(
- self.tree_cache.swa_evictable_size() - self.tree_cache.swa_evictable_size()
) )
num_tokens = max(num_tokens_full, num_tokens_swa) num_tokens = max(num_tokens_full, num_tokens_swa)
elif self.is_hybrid_gdn:
num_tokens = (
self.max_total_num_tokens
- self.token_to_kv_pool_allocator.available_size()
- self.tree_cache.full_evictable_size()
)
else: else:
num_tokens = ( num_tokens = (
self.max_total_num_tokens self.max_total_num_tokens
......
...@@ -20,11 +20,11 @@ The radix tree data structure for managing the hybrid (full and Mamba) KV cache. ...@@ -20,11 +20,11 @@ The radix tree data structure for managing the hybrid (full and Mamba) KV cache.
""" """
import heapq import heapq
import time
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, List, Optional, Tuple from typing import TYPE_CHECKING, List, Optional, Tuple
import torch import torch
from numpy import float64
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
...@@ -46,6 +46,7 @@ logger = logging.getLogger(__name__) ...@@ -46,6 +46,7 @@ logger = logging.getLogger(__name__)
class TreeNode: class TreeNode:
counter = 0 counter = 0
last_access_time_counter_float = float64(1.0)
def __init__(self, id: Optional[int] = None): def __init__(self, id: Optional[int] = None):
self.children = defaultdict(TreeNode) self.children = defaultdict(TreeNode)
...@@ -61,7 +62,7 @@ class TreeNode: ...@@ -61,7 +62,7 @@ class TreeNode:
self.full_lock_ref = 0 self.full_lock_ref = 0
self.mamba_lock_ref = 0 self.mamba_lock_ref = 0
# last access time is only used for sanity check. LRU is maintained by the lru list. # last access time is only used for sanity check. LRU is maintained by the lru list.
self.last_access_time = time.monotonic() self.last_access_time = get_last_access_time()
self.hit_count = 0 self.hit_count = 0
# store the host indices of KV cache # store the host indices of KV cache
...@@ -90,6 +91,12 @@ class TreeNode: ...@@ -90,6 +91,12 @@ class TreeNode:
return self.last_access_time < other.last_access_time return self.last_access_time < other.last_access_time
def get_last_access_time() -> float64:
    """Return the next logical access timestamp and advance the shared counter.

    The timestamp is read from the class-level counter
    ``TreeNode.last_access_time_counter_float`` instead of a wall clock, so
    each call returns a value 1.0 larger than the previous call did.

    Returns:
        The counter value as it was before this call incremented it.
    """
    ret = TreeNode.last_access_time_counter_float
    # Advance by a whole unit: the prefix-match code in this file subtracts
    # 0.00001 per ancestor from a returned value, and a unit-sized step keeps
    # those fractional offsets above the previously issued timestamp.
    TreeNode.last_access_time_counter_float += 1.0
    return ret
class LRUList: class LRUList:
def __init__(self, mamba: bool = False): def __init__(self, mamba: bool = False):
self.mamba = mamba self.mamba = mamba
...@@ -382,8 +389,6 @@ class MambaRadixCache(BasePrefixCache): ...@@ -382,8 +389,6 @@ class MambaRadixCache(BasePrefixCache):
# copy mamba state to req local space if cow is true # copy mamba state to req local space if cow is true
if cow_mamba and last_node.mamba_value is not None: if cow_mamba and last_node.mamba_value is not None:
assert req.req_pool_idx is None # req_pool_idx is uninitialed
# for reqs without mamba cache # for reqs without mamba cache
if req.mamba_pool_idx is None: if req.mamba_pool_idx is None:
dst_index = self.req_to_token_pool.mamba_pool.alloc(1) dst_index = self.req_to_token_pool.mamba_pool.alloc(1)
...@@ -421,7 +426,7 @@ class MambaRadixCache(BasePrefixCache): ...@@ -421,7 +426,7 @@ class MambaRadixCache(BasePrefixCache):
value = torch.tensor([x for x in key.token_ids], dtype=torch.int64) value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)
return self._insert_helper(self.root_node, key, value, mamba_value) return self._insert_helper(self.root_node, key, value, mamba_value)
def cache_finished_req(self, req: Req) -> None: def cache_finished_req(self, req: Req, is_insert=True) -> None:
"""Cache request when it finishes.""" """Cache request when it finishes."""
if self.disable: if self.disable:
kv_indices = self.req_to_token_pool.req_to_token[ kv_indices = self.req_to_token_pool.req_to_token[
...@@ -449,15 +454,20 @@ class MambaRadixCache(BasePrefixCache): ...@@ -449,15 +454,20 @@ class MambaRadixCache(BasePrefixCache):
.clone() .clone()
) )
new_prefix_len, mamba_exist = self.insert( if is_insert:
RadixKey(token_ids[:page_aligned_len], req.extra_key), new_prefix_len, mamba_exist = self.insert(
page_aligned_kv_indices, RadixKey(token_ids[:page_aligned_len], req.extra_key),
mamba_value, page_aligned_kv_indices,
) mamba_value,
)
self.token_to_kv_pool_allocator.free( self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len] kv_indices[len(req.prefix_indices) : new_prefix_len]
) )
else:
self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : page_aligned_len]
)
mamba_exist = True
self.req_to_token_pool.free(req.req_pool_idx, free_mamba_cache=mamba_exist) self.req_to_token_pool.free(req.req_pool_idx, free_mamba_cache=mamba_exist)
self.dec_lock_ref(req.last_node) self.dec_lock_ref(req.last_node)
...@@ -767,15 +777,18 @@ class MambaRadixCache(BasePrefixCache): ...@@ -767,15 +777,18 @@ class MambaRadixCache(BasePrefixCache):
# update time for matched nodes, and make nodes closer to root to be least recently used # update time for matched nodes, and make nodes closer to root to be least recently used
# this allows mamba to evict nodes closer to root first # this allows mamba to evict nodes closer to root first
self.full_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node) node_update = best_last_node
self.mamba_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node) self.full_lru_list.reset_node_and_parents_mru(node_update, self.root_node)
self.mamba_lru_list.reset_node_and_parents_mru(node_update, self.root_node)
# This last_access_time is for sanity check, can be deleted after validation in production # This last_access_time is for sanity check, can be deleted after validation in production
cur_time = time.monotonic() cur_time = get_last_access_time()
while node: while node_update:
node.last_access_time = cur_time node_update.last_access_time = cur_time
cur_time -= 0.0001 cur_time -= (
node = node.parent 0.00001 # assuming less than 100000 nodes in a branch of the tree
)
node_update = node_update.parent
return value[:best_value_len], best_last_node return value[:best_value_len], best_last_node
...@@ -791,7 +804,7 @@ class MambaRadixCache(BasePrefixCache): ...@@ -791,7 +804,7 @@ class MambaRadixCache(BasePrefixCache):
new_node.value = child.value[:split_len] new_node.value = child.value[:split_len]
# child time should be later than parent's time for mamba tombstone # child time should be later than parent's time for mamba tombstone
child.last_access_time = time.monotonic() child.last_access_time = get_last_access_time()
self.full_lru_list.remove_node(child) self.full_lru_list.remove_node(child)
if child.mamba_value is not None: if child.mamba_value is not None:
...@@ -819,7 +832,7 @@ class MambaRadixCache(BasePrefixCache): ...@@ -819,7 +832,7 @@ class MambaRadixCache(BasePrefixCache):
# Update the last access time from root to leaf, so that # Update the last access time from root to leaf, so that
# mamba will tombstone the node closer to root first # mamba will tombstone the node closer to root first
assert mamba_value is not None, "Mamba value should not be None here." assert mamba_value is not None, "Mamba value should not be None here."
node.last_access_time = time.monotonic() node.last_access_time = get_last_access_time()
if node != self.root_node: if node != self.root_node:
self.full_lru_list.reset_node_mru(node) self.full_lru_list.reset_node_mru(node)
if node.mamba_value is not None: if node.mamba_value is not None:
...@@ -832,7 +845,7 @@ class MambaRadixCache(BasePrefixCache): ...@@ -832,7 +845,7 @@ class MambaRadixCache(BasePrefixCache):
total_prefix_length = 0 total_prefix_length = 0
while len(key) > 0 and child_key in node.children.keys(): while len(key) > 0 and child_key in node.children.keys():
node = node.children[child_key] node = node.children[child_key]
node.last_access_time = time.monotonic() node.last_access_time = get_last_access_time()
self.full_lru_list.reset_node_mru(node) self.full_lru_list.reset_node_mru(node)
if node.mamba_value is not None: if node.mamba_value is not None:
self.mamba_lru_list.reset_node_mru(node) self.mamba_lru_list.reset_node_mru(node)
...@@ -856,17 +869,21 @@ class MambaRadixCache(BasePrefixCache): ...@@ -856,17 +869,21 @@ class MambaRadixCache(BasePrefixCache):
new_node.value = value new_node.value = value
new_node.mamba_value = mamba_value new_node.mamba_value = mamba_value
self.full_lru_list.insert_mru(new_node) self.full_lru_list.insert_mru(new_node)
self.full_evictable_size_ += len(value)
self.mamba_evictable_size_ += len(mamba_value)
self.mamba_lru_list.insert_mru(new_node) self.mamba_lru_list.insert_mru(new_node)
node.children[child_key] = new_node node.children[child_key] = new_node
self.full_evictable_size_ += len(value)
self.mamba_evictable_size_ += len(mamba_value)
elif node.mamba_value is None: # add for mamba tombstone elif node.mamba_value is None: # add for mamba tombstone
node.mamba_value = mamba_value node.mamba_value = mamba_value
self.mamba_evictable_size_ += len(mamba_value) self.full_lru_list.reset_node_mru(node)
self.mamba_lru_list.insert_mru(node) self.mamba_lru_list.insert_mru(node)
else: self.mamba_evictable_size_ += len(mamba_value)
node.last_access_time = get_last_access_time()
else: # mamba value already exists
mamba_value_exist = True mamba_value_exist = True
self.full_lru_list.reset_node_mru(node)
self.mamba_lru_list.reset_node_mru(node) self.mamba_lru_list.reset_node_mru(node)
node.last_access_time = get_last_access_time()
return total_prefix_length, mamba_value_exist return total_prefix_length, mamba_value_exist
......
...@@ -20,12 +20,12 @@ The radix tree data structure for managing the hybrid (full and SWA) KV cache. ...@@ -20,12 +20,12 @@ The radix tree data structure for managing the hybrid (full and SWA) KV cache.
""" """
import heapq import heapq
import time
from collections import defaultdict from collections import defaultdict
from functools import partial from functools import partial
from typing import TYPE_CHECKING, List, Optional, Tuple from typing import TYPE_CHECKING, List, Optional, Tuple
import torch import torch
from numpy import float64
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
...@@ -50,6 +50,7 @@ class TreeNode: ...@@ -50,6 +50,7 @@ class TreeNode:
counter = 0 counter = 0
swa_uuid_counter = 1 swa_uuid_counter = 1
last_access_time_counter_float = float64(1.0)
def __init__(self, id: Optional[int] = None): def __init__(self, id: Optional[int] = None):
self.children = defaultdict(TreeNode) self.children = defaultdict(TreeNode)
...@@ -64,7 +65,7 @@ class TreeNode: ...@@ -64,7 +65,7 @@ class TreeNode:
self.full_lock_ref = 0 self.full_lock_ref = 0
self.swa_lock_ref = 0 self.swa_lock_ref = 0
# last access time is only used for sanity check. LRU is maintained by the lru list. # last access time is only used for sanity check. LRU is maintained by the lru list.
self.last_access_time = time.monotonic() self.last_access_time = get_last_access_time()
self.hit_count = 0 self.hit_count = 0
# store the host indices of KV cache # store the host indices of KV cache
...@@ -99,6 +100,12 @@ def gen_swa_uuid() -> int: ...@@ -99,6 +100,12 @@ def gen_swa_uuid() -> int:
return TreeNode.swa_uuid_counter return TreeNode.swa_uuid_counter
def get_last_access_time() -> float64:
    """Return the next logical access timestamp and advance the shared counter.

    The timestamp is read from the class-level counter
    ``TreeNode.last_access_time_counter_float`` instead of a wall clock, so
    each call returns a value 1.0 larger than the previous call did.

    Returns:
        The counter value as it was before this call incremented it.
    """
    ret = TreeNode.last_access_time_counter_float
    # Advance by a whole unit: the prefix-match code in this file subtracts
    # 0.00001 per ancestor from a returned value, and a unit-sized step keeps
    # those fractional offsets above the previously issued timestamp.
    TreeNode.last_access_time_counter_float += 1.0
    return ret
class LRUList: class LRUList:
def __init__(self, swa: bool = False): def __init__(self, swa: bool = False):
self.swa = swa self.swa = swa
...@@ -841,15 +848,18 @@ class SWARadixCache(BasePrefixCache): ...@@ -841,15 +848,18 @@ class SWARadixCache(BasePrefixCache):
# update time for matched nodes, and make nodes closer to root to be least recently used # update time for matched nodes, and make nodes closer to root to be least recently used
# this allows swa to evict nodes closer to root first # this allows swa to evict nodes closer to root first
self.full_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node) node_update = best_last_node
self.swa_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node) self.full_lru_list.reset_node_and_parents_mru(node_update, self.root_node)
self.swa_lru_list.reset_node_and_parents_mru(node_update, self.root_node)
# This last_access_time is for sanity check, can be deleted after validation in production # This last_access_time is for sanity check, can be deleted after validation in production
cur_time = time.monotonic() cur_time = get_last_access_time()
while node: while node_update:
node.last_access_time = cur_time node_update.last_access_time = cur_time
cur_time -= 0.0001 cur_time -= (
node = node.parent 0.00001 # assuming less than 100000 nodes in a branch of the tree
)
node_update = node_update.parent
return value[:best_value_len], best_last_node return value[:best_value_len], best_last_node
...@@ -867,7 +877,7 @@ class SWARadixCache(BasePrefixCache): ...@@ -867,7 +877,7 @@ class SWARadixCache(BasePrefixCache):
new_node.swa_uuid = child.swa_uuid new_node.swa_uuid = child.swa_uuid
child.swa_uuid = None child.swa_uuid = None
# child time should be later than parent's time for swa tombstone # child time should be later than parent's time for swa tombstone
child.last_access_time = time.monotonic() child.last_access_time = get_last_access_time()
# remove the child from the lru lists because it is being split # remove the child from the lru lists because it is being split
self.full_lru_list.remove_node(child) self.full_lru_list.remove_node(child)
...@@ -892,7 +902,7 @@ class SWARadixCache(BasePrefixCache): ...@@ -892,7 +902,7 @@ class SWARadixCache(BasePrefixCache):
) -> int: ) -> int:
# Update the last access time from root to leaf, so that # Update the last access time from root to leaf, so that
# swa will tombstone the node closer to root first # swa will tombstone the node closer to root first
node.last_access_time = time.monotonic() node.last_access_time = get_last_access_time()
if node != self.root_node: if node != self.root_node:
self.full_lru_list.reset_node_mru(node) self.full_lru_list.reset_node_mru(node)
if not node.swa_tombstone: if not node.swa_tombstone:
...@@ -905,7 +915,7 @@ class SWARadixCache(BasePrefixCache): ...@@ -905,7 +915,7 @@ class SWARadixCache(BasePrefixCache):
total_prefix_length = 0 total_prefix_length = 0
while len(key) > 0 and child_key in node.children.keys(): while len(key) > 0 and child_key in node.children.keys():
node = node.children[child_key] node = node.children[child_key]
node.last_access_time = time.monotonic() node.last_access_time = get_last_access_time()
self.full_lru_list.reset_node_mru(node) self.full_lru_list.reset_node_mru(node)
if not node.swa_tombstone: if not node.swa_tombstone:
self.swa_lru_list.reset_node_mru(node) self.swa_lru_list.reset_node_mru(node)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment