Unverified Commit 9376ac36 authored by Zhiqiang Xie's avatar Zhiqiang Xie Committed by GitHub
Browse files

Memory pool fix for upstream change about eagle (#4170)

parent 94a2b9d3
...@@ -22,7 +22,10 @@ from typing import List, Optional ...@@ -22,7 +22,10 @@ from typing import List, Optional
import torch import torch
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MHATokenToKVPoolHost from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPoolHost,
TokenToKVPoolAllocator,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -127,12 +130,12 @@ class HiCacheController: ...@@ -127,12 +130,12 @@ class HiCacheController:
def __init__( def __init__(
self, self,
mem_pool_device: MHATokenToKVPool, token_to_kv_pool_allocator: TokenToKVPoolAllocator,
mem_pool_host: MHATokenToKVPoolHost, mem_pool_host: MHATokenToKVPoolHost,
write_policy: str = "write_through_selective", write_policy: str = "write_through_selective",
): ):
self.mem_pool_device_allocator = token_to_kv_pool_allocator
self.mem_pool_device = mem_pool_device self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
self.mem_pool_host = mem_pool_host self.mem_pool_host = mem_pool_host
self.write_policy = write_policy self.write_policy = write_policy
...@@ -216,7 +219,7 @@ class HiCacheController: ...@@ -216,7 +219,7 @@ class HiCacheController:
""" """
Load KV caches from host memory to device memory. Load KV caches from host memory to device memory.
""" """
device_indices = self.mem_pool_device.alloc(len(host_indices)) device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
if device_indices is None: if device_indices is None:
return None return None
self.mem_pool_host.protect_load(host_indices) self.mem_pool_host.protect_load(host_indices)
...@@ -417,7 +420,7 @@ class HiCacheController: ...@@ -417,7 +420,7 @@ class HiCacheController:
self, device_indices: torch.Tensor, host_indices: torch.Tensor self, device_indices: torch.Tensor, host_indices: torch.Tensor
) -> int: ) -> int:
if self.mem_pool_host.is_synced(host_indices): if self.mem_pool_host.is_synced(host_indices):
self.mem_pool_device.free(device_indices) self.mem_pool_device_allocator.free(device_indices)
self.mem_pool_host.update_backup(host_indices) self.mem_pool_host.update_backup(host_indices)
return len(device_indices) return len(device_indices)
else: else:
......
...@@ -7,9 +7,9 @@ import torch ...@@ -7,9 +7,9 @@ import torch
from sglang.srt.managers.cache_controller import HiCacheController from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.mem_cache.memory_pool import ( from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool,
MHATokenToKVPoolHost, MHATokenToKVPoolHost,
ReqToTokenPool, ReqToTokenPool,
TokenToKVPoolAllocator,
) )
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
...@@ -21,11 +21,13 @@ class HiRadixCache(RadixCache): ...@@ -21,11 +21,13 @@ class HiRadixCache(RadixCache):
def __init__( def __init__(
self, self,
req_to_token_pool: ReqToTokenPool, req_to_token_pool: ReqToTokenPool,
token_to_kv_pool: MHATokenToKVPool, token_to_kv_pool_allocator: TokenToKVPoolAllocator,
): ):
self.token_to_kv_pool_host = MHATokenToKVPoolHost(token_to_kv_pool) self.token_to_kv_pool_host = MHATokenToKVPoolHost(
token_to_kv_pool_allocator.get_kvcache()
)
self.cache_controller = HiCacheController( self.cache_controller = HiCacheController(
token_to_kv_pool, self.token_to_kv_pool_host token_to_kv_pool_allocator, self.token_to_kv_pool_host
) )
# record the nodes with ongoing write through # record the nodes with ongoing write through
...@@ -35,7 +37,7 @@ class HiRadixCache(RadixCache): ...@@ -35,7 +37,7 @@ class HiRadixCache(RadixCache):
# todo: dynamically adjust the threshold # todo: dynamically adjust the threshold
self.write_through_threshold = 1 self.write_through_threshold = 1
self.load_back_threshold = 10 self.load_back_threshold = 10
super().__init__(req_to_token_pool, token_to_kv_pool, disable=False) super().__init__(req_to_token_pool, token_to_kv_pool_allocator, disable=False)
def reset(self): def reset(self):
TreeNode.counter = 0 TreeNode.counter = 0
...@@ -160,7 +162,7 @@ class HiRadixCache(RadixCache): ...@@ -160,7 +162,7 @@ class HiRadixCache(RadixCache):
def _evict_write_through_selective(self, node: TreeNode): def _evict_write_through_selective(self, node: TreeNode):
# evict a node not initiated write to host # evict a node not initiated write to host
self.cache_controller.mem_pool_device.free(node.value) self.cache_controller.mem_pool_device_allocator.free(node.value)
num_evicted = len(node.value) num_evicted = len(node.value)
self._delete_leaf(node) self._delete_leaf(node)
return num_evicted return num_evicted
...@@ -270,28 +272,27 @@ class HiRadixCache(RadixCache): ...@@ -270,28 +272,27 @@ class HiRadixCache(RadixCache):
return last_node, prefix_indices return last_node, prefix_indices
def _match_prefix_helper( def _match_prefix_helper(self, node: TreeNode, key: List):
self, node: TreeNode, key: List, value, last_node: TreeNode
):
node.last_access_time = time.time() node.last_access_time = time.time()
if len(key) == 0: value = []
return while len(key) > 0 and key[0] in node.children.keys():
if key[0] in node.children.keys():
child = node.children[key[0]] child = node.children[key[0]]
child.last_access_time = time.time()
prefix_len = _key_match(child.key, key) prefix_len = _key_match(child.key, key)
if prefix_len < len(child.key): if prefix_len < len(child.key):
new_node = self._split_node(child.key, child, prefix_len) new_node = self._split_node(child.key, child, prefix_len)
self.inc_hit_count(new_node) self.inc_hit_count(new_node)
if not new_node.evicted: if not new_node.evicted:
value.append(new_node.value) value.append(new_node.value)
last_node[0] = new_node node = new_node
break
else: else:
self.inc_hit_count(child) self.inc_hit_count(child)
if not child.evicted: if not child.evicted:
value.append(child.value) value.append(child.value)
last_node[0] = child node = child
self._match_prefix_helper(child, key[prefix_len:], value, last_node) key = key[prefix_len:]
return value, node
def _split_node(self, key, child: TreeNode, split_len: int): def _split_node(self, key, child: TreeNode, split_len: int):
# child node split into new_node -> child # child node split into new_node -> child
......
...@@ -470,7 +470,7 @@ class MHATokenToKVPoolHost: ...@@ -470,7 +470,7 @@ class MHATokenToKVPoolHost:
def __init__( def __init__(
self, self,
device_pool: MHATokenToKVPool, device_pool: MHATokenToKVPool,
host_to_device_ratio: float = 2.0, host_to_device_ratio: float = 3.0,
pin_memory: bool = False, # no need to use pin memory with the double buffering pin_memory: bool = False, # no need to use pin memory with the double buffering
device: str = "cpu", device: str = "cpu",
): ):
......
...@@ -24,14 +24,10 @@ import requests ...@@ -24,14 +24,10 @@ import requests
from IPython.display import HTML, display from IPython.display import HTML, display
from tqdm import tqdm from tqdm import tqdm
from sglang.srt.openai_api.protocol import ChatCompletionMessageContentPart
from sglang.srt.utils import kill_process_tree from sglang.srt.utils import kill_process_tree
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# type of content fields, can be only prompts or with images/videos
MsgContent = Union[str, List[ChatCompletionMessageContentPart]]
def get_exception_traceback(): def get_exception_traceback():
etype, value, tb = sys.exc_info() etype, value, tb = sys.exc_info()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment