Unverified Commit 9376ac36 authored by Zhiqiang Xie's avatar Zhiqiang Xie Committed by GitHub
Browse files

Memory pool fix for upstream change about eagle (#4170)

parent 94a2b9d3
......@@ -22,7 +22,10 @@ from typing import List, Optional
import torch
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MHATokenToKVPoolHost
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPoolHost,
TokenToKVPoolAllocator,
)
logger = logging.getLogger(__name__)
......@@ -127,12 +130,12 @@ class HiCacheController:
def __init__(
self,
mem_pool_device: MHATokenToKVPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
mem_pool_host: MHATokenToKVPoolHost,
write_policy: str = "write_through_selective",
):
self.mem_pool_device = mem_pool_device
self.mem_pool_device_allocator = token_to_kv_pool_allocator
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
self.mem_pool_host = mem_pool_host
self.write_policy = write_policy
......@@ -216,7 +219,7 @@ class HiCacheController:
"""
Load KV caches from host memory to device memory.
"""
device_indices = self.mem_pool_device.alloc(len(host_indices))
device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
if device_indices is None:
return None
self.mem_pool_host.protect_load(host_indices)
......@@ -417,7 +420,7 @@ class HiCacheController:
self, device_indices: torch.Tensor, host_indices: torch.Tensor
) -> int:
if self.mem_pool_host.is_synced(host_indices):
self.mem_pool_device.free(device_indices)
self.mem_pool_device_allocator.free(device_indices)
self.mem_pool_host.update_backup(host_indices)
return len(device_indices)
else:
......
......@@ -7,9 +7,9 @@ import torch
from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool,
MHATokenToKVPoolHost,
ReqToTokenPool,
TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
......@@ -21,11 +21,13 @@ class HiRadixCache(RadixCache):
def __init__(
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool: MHATokenToKVPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
):
self.token_to_kv_pool_host = MHATokenToKVPoolHost(token_to_kv_pool)
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
token_to_kv_pool_allocator.get_kvcache()
)
self.cache_controller = HiCacheController(
token_to_kv_pool, self.token_to_kv_pool_host
token_to_kv_pool_allocator, self.token_to_kv_pool_host
)
# record the nodes with ongoing write through
......@@ -35,7 +37,7 @@ class HiRadixCache(RadixCache):
# todo: dynamically adjust the threshold
self.write_through_threshold = 1
self.load_back_threshold = 10
super().__init__(req_to_token_pool, token_to_kv_pool, disable=False)
super().__init__(req_to_token_pool, token_to_kv_pool_allocator, disable=False)
def reset(self):
TreeNode.counter = 0
......@@ -160,7 +162,7 @@ class HiRadixCache(RadixCache):
def _evict_write_through_selective(self, node: TreeNode):
# evict a node not initiated write to host
self.cache_controller.mem_pool_device.free(node.value)
self.cache_controller.mem_pool_device_allocator.free(node.value)
num_evicted = len(node.value)
self._delete_leaf(node)
return num_evicted
......@@ -270,28 +272,27 @@ class HiRadixCache(RadixCache):
return last_node, prefix_indices
def _match_prefix_helper(
self, node: TreeNode, key: List, value, last_node: TreeNode
):
def _match_prefix_helper(self, node: TreeNode, key: List):
node.last_access_time = time.time()
if len(key) == 0:
return
if key[0] in node.children.keys():
value = []
while len(key) > 0 and key[0] in node.children.keys():
child = node.children[key[0]]
child.last_access_time = time.time()
prefix_len = _key_match(child.key, key)
if prefix_len < len(child.key):
new_node = self._split_node(child.key, child, prefix_len)
self.inc_hit_count(new_node)
if not new_node.evicted:
value.append(new_node.value)
last_node[0] = new_node
node = new_node
break
else:
self.inc_hit_count(child)
if not child.evicted:
value.append(child.value)
last_node[0] = child
self._match_prefix_helper(child, key[prefix_len:], value, last_node)
node = child
key = key[prefix_len:]
return value, node
def _split_node(self, key, child: TreeNode, split_len: int):
# child node split into new_node -> child
......
......@@ -470,7 +470,7 @@ class MHATokenToKVPoolHost:
def __init__(
self,
device_pool: MHATokenToKVPool,
host_to_device_ratio: float = 2.0,
host_to_device_ratio: float = 3.0,
pin_memory: bool = False, # no need to use pin memory with the double buffering
device: str = "cpu",
):
......
......@@ -24,14 +24,10 @@ import requests
from IPython.display import HTML, display
from tqdm import tqdm
from sglang.srt.openai_api.protocol import ChatCompletionMessageContentPart
from sglang.srt.utils import kill_process_tree
logger = logging.getLogger(__name__)
# type of content fields, can be only prompts or with images/videos
MsgContent = Union[str, List[ChatCompletionMessageContentPart]]
def get_exception_traceback():
etype, value, tb = sys.exc_info()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment