Unverified commit 70645f4d, authored by Zhiqiang Xie and committed by GitHub

upstream hicache fixes (#5570)

parent 188f0955
@@ -571,6 +571,14 @@ class Req:
             self.prefix_indices, self.last_node = tree_cache.match_prefix(
                 rid=self.rid, key=self.adjust_max_prefix_ids()
             )
+        elif enable_hierarchical_cache:
+            # in case last_node is evicted during scheduling, we need to update the prefix_indices
+            while self.last_node.evicted:
+                self.prefix_indices = self.prefix_indices[
+                    : -len(self.last_node.host_value)
+                ]
+                self.last_node = self.last_node.parent
+
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

     def adjust_max_prefix_ids(self):
......
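A minimal, self-contained sketch of the rollback that the new elif branch performs, assuming a simplified node type with only the fields the loop touches (evicted, host_value, parent); the names Node and trim_prefix are hypothetical, not part of the diff. Each evicted ancestor accounts for len(host_value) tokens whose device copies are gone, so the matched prefix is trimmed by that amount while walking up the tree.

    # Hypothetical, simplified illustration of the rollback loop above; the real
    # TreeNode in the radix cache has many more fields.
    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class Node:
        host_value: List[int]            # tokens backed up on host for this node
        evicted: bool = False            # device copy has been evicted
        parent: Optional["Node"] = None

    def trim_prefix(prefix_indices: List[int], last_node: Node):
        # Walk up past evicted nodes, dropping their token spans from the prefix.
        while last_node.evicted:
            prefix_indices = prefix_indices[: -len(last_node.host_value)]
            last_node = last_node.parent
        return prefix_indices, last_node

    root = Node(host_value=[])
    mid = Node(host_value=[3, 4], evicted=True, parent=root)
    leaf = Node(host_value=[5, 6, 7], evicted=True, parent=mid)

    # A prefix of 7 device indices shrinks by 3 (leaf) and then 2 (mid).
    print(trim_prefix(list(range(7)), leaf)[0])  # [0, 1]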
@@ -489,6 +489,8 @@ class Scheduler(
                 tp_cache_group=self.tp_cpu_group,
                 page_size=self.page_size,
                 hicache_ratio=server_args.hicache_ratio,
+                hicache_size=server_args.hicache_size,
+                hicache_write_policy=server_args.hicache_write_policy,
             )
         else:
             self.tree_cache = RadixCache(
......
@@ -29,15 +29,17 @@ class HiRadixCache(RadixCache):
         tp_cache_group: torch.distributed.ProcessGroup,
         page_size: int,
         hicache_ratio: float,
+        hicache_size: int,
+        hicache_write_policy: str,
     ):
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
             self.token_to_kv_pool_host = MHATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio, page_size
+                self.kv_cache, hicache_ratio, hicache_size, page_size
             )
         elif isinstance(self.kv_cache, MLATokenToKVPool):
             self.token_to_kv_pool_host = MLATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio, page_size
+                self.kv_cache, hicache_ratio, hicache_size, page_size
             )
         else:
             raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
@@ -50,6 +52,7 @@ class HiRadixCache(RadixCache):
             self.token_to_kv_pool_host,
             page_size,
             load_cache_event=self.load_cache_event,
+            write_policy=hicache_write_policy,
         )

         # record the nodes with ongoing write through
@@ -57,7 +60,9 @@ class HiRadixCache(RadixCache):
         # record the node segments with ongoing load back
         self.ongoing_load_back = {}
         # todo: dynamically adjust the threshold
-        self.write_through_threshold = 1
+        self.write_through_threshold = (
+            1 if hicache_write_policy == "write_through" else 3
+        )
         self.load_back_threshold = 10
         super().__init__(
             req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
@@ -76,7 +81,7 @@ class HiRadixCache(RadixCache):
             height += 1
         return height

-    def write_backup(self, node: TreeNode):
+    def write_backup(self, node: TreeNode, write_back=False):
         host_indices = self.cache_controller.write(
             device_indices=node.value,
             node_id=node.id,
@@ -90,21 +95,29 @@ class HiRadixCache(RadixCache):
         if host_indices is not None:
             node.host_value = host_indices
             self.ongoing_write_through[node.id] = node
-            self.inc_lock_ref(node)
+            if not write_back:
+                # no need to lock nodes if write back
+                self.inc_lock_ref(node)
         else:
             return 0

         return len(host_indices)

     def inc_hit_count(self, node: TreeNode):
-        if self.cache_controller.write_policy != "write_through_selective":
+        if node.backuped or self.cache_controller.write_policy == "write_back":
             return
         node.hit_count += 1

-        if node.host_value is None and node.hit_count > self.write_through_threshold:
+        if node.hit_count >= self.write_through_threshold:
             self.write_backup(node)
             node.hit_count = 0

-    def writing_check(self):
+    def writing_check(self, write_back=False):
+        if write_back:
+            # blocking till all write back complete
+            while len(self.ongoing_write_through) > 0:
+                ack_id = self.cache_controller.ack_write_queue.get()
+                del self.ongoing_write_through[ack_id]
+            return
         queue_size = torch.tensor(
             self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
         )
@@ -143,29 +156,25 @@ class HiRadixCache(RadixCache):
         heapq.heapify(leaves)

         num_evicted = 0
-        pending_nodes = []
+        write_back_nodes = []
         while num_evicted < num_tokens and len(leaves):
             x = heapq.heappop(leaves)
             if x.lock_ref > 0:
                 continue

-            if x.host_value is None:
+            if not x.backuped:
                 if self.cache_controller.write_policy == "write_back":
-                    num_evicted += self.write_backup(x)
-                    pending_nodes.append(x)
-                elif self.cache_controller.write_policy == "write_through_selective":
-                    num_evicted += self._evict_write_through_selective(x)
+                    # write to host if the node is not backuped
+                    num_evicted += self.write_backup(x, write_back=True)
+                    write_back_nodes.append(x)
                 else:
-                    assert (
-                        self.cache_controller.write_policy != "write_through"
-                    ), "write_through should be inclusive"
-                    raise NotImplementedError
+                    num_evicted += self._evict_regular(x)
             else:
-                num_evicted += self._evict_write_through(x)
+                num_evicted += self._evict_backuped(x)

             for child in x.parent.children.values():
-                if child in pending_nodes:
+                if child in write_back_nodes:
                     continue
                 if not child.evicted:
                     break
@@ -174,15 +183,12 @@ class HiRadixCache(RadixCache):
                 heapq.heappush(leaves, x.parent)

         if self.cache_controller.write_policy == "write_back":
-            # blocking till all write back complete
-            while len(self.ongoing_write_through) > 0:
-                self.writing_check()
-                time.sleep(0.1)
-            for node in pending_nodes:
-                assert node.host_value is not None
-                self._evict_write_through(node)
+            self.writing_check(write_back=True)
+            for node in write_back_nodes:
+                assert node.backuped
+                self._evict_backuped(node)

-    def _evict_write_through(self, node: TreeNode):
+    def _evict_backuped(self, node: TreeNode):
         # evict a node already written to host
         num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
         assert num_evicted > 0
@@ -190,7 +196,7 @@ class HiRadixCache(RadixCache):
         node.value = None
         return num_evicted

-    def _evict_write_through_selective(self, node: TreeNode):
+    def _evict_regular(self, node: TreeNode):
         # evict a node not initiated write to host
         self.cache_controller.mem_pool_device_allocator.free(node.value)
         num_evicted = len(node.value)
@@ -339,11 +345,13 @@ class HiRadixCache(RadixCache):
             prefix_len = self.key_match_fn(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
+                self.inc_hit_count(new_node)
                 if not new_node.evicted:
                     value.append(new_node.value)
                 node = new_node
                 break
             else:
+                self.inc_hit_count(child)
                 if not child.evicted:
                     value.append(child.value)
                 node = child
@@ -369,7 +377,7 @@ class HiRadixCache(RadixCache):
         else:
             new_node.value = child.value[:split_len]
             child.value = child.value[split_len:]
-        if child.host_value is not None:
+        if child.backuped:
             new_node.host_value = child.host_value[:split_len]
             child.host_value = child.host_value[split_len:]
         child.parent = new_node
@@ -426,8 +434,8 @@ class HiRadixCache(RadixCache):
             node.children[child_key] = new_node
             self.evictable_size_ += len(value)
-            if self.cache_controller.write_policy == "write_through":
-                self.write_backup(new_node)
+            if self.cache_controller.write_policy != "write_back":
+                self.inc_hit_count(new_node)

         return total_prefix_length

     def _collect_leaves_device(self):
......
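A small sketch of the hit-count gating that inc_hit_count now applies, assuming a bare node object; maybe_write_through and write_backup_fn are hypothetical stand-ins for the HiRadixCache methods, only the threshold logic mirrors the diff. With write_through the threshold is 1, so a node is backed up on its first access; with write_through_selective it is 3; with write_back, host copies are created only during eviction, so the counter is never used.

    # Hypothetical stand-alone sketch of the write-through hit-count policy.
    class FakeNode:
        def __init__(self):
            self.hit_count = 0
            self.backuped = False

    def maybe_write_through(node, write_policy, write_backup_fn):
        # Threshold mirrors the constructor: 1 for write_through, 3 otherwise.
        threshold = 1 if write_policy == "write_through" else 3
        if node.backuped or write_policy == "write_back":
            return  # already on host, or host copies are made only at eviction
        node.hit_count += 1
        if node.hit_count >= threshold:
            write_backup_fn(node)
            node.hit_count = 0

    node = FakeNode()
    for _ in range(3):
        maybe_write_through(
            node, "write_through_selective", lambda n: setattr(n, "backuped", True)
        )
    assert node.backuped  # backed up on the third access under the selective policy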
@@ -624,26 +624,27 @@ class HostKVCache(abc.ABC):
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
+        host_size: int,
         pin_memory: bool,
         device: str,
         page_size: int,
     ):
-        assert (
-            host_to_device_ratio >= 1
-        ), "The host memory should be larger than the device memory with the current protocol"
-        # todo, other ways of configuring the size
-
         self.device_pool = device_pool
-        self.host_to_device_ratio = host_to_device_ratio
+        self.dtype = device_pool.store_dtype
         self.pin_memory = pin_memory
         self.device = device
         self.page_size = page_size
+        self.size_per_token = self.get_size_per_token()

-        self.size = int(device_pool.size * host_to_device_ratio)
+        if host_size > 0:
+            self.size = int(host_size * 1e9 // self.size_per_token)
+        else:
+            self.size = int(device_pool.size * host_to_device_ratio)
         # Align the host memory pool size to the page size
         self.size = self.size - (self.size % self.page_size)
-        self.dtype = device_pool.store_dtype
-        self.size_per_token = self.get_size_per_token()
+
+        assert (
+            self.size > device_pool.size
+        ), "The host memory should be larger than the device memory with the current protocol"

         # Verify there is enough available host memory.
         host_mem = psutil.virtual_memory()
@@ -795,12 +796,13 @@ class MHATokenToKVPoolHost(HostKVCache):
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
+        host_size: int,
         page_size: int,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, pin_memory, device, page_size
+            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
         )

     def get_size_per_token(self):
@@ -869,12 +871,13 @@ class MLATokenToKVPoolHost(HostKVCache):
         self,
         device_pool: MLATokenToKVPool,
         host_to_device_ratio: float,
+        host_size: int,
         page_size: int,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, pin_memory, device, page_size
+            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
         )

     def get_size_per_token(self):
......
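The sizing logic above can be summarized as: an explicit --hicache-size budget (in GB) overrides the ratio, the token count is aligned down to a whole number of pages, and the host pool must end up larger than the device pool. A small sketch of that arithmetic with made-up numbers; host_pool_tokens is a hypothetical helper, not part of the diff.

    def host_pool_tokens(
        host_size_gb: int,
        host_to_device_ratio: float,
        device_pool_tokens: int,
        size_per_token_bytes: int,
        page_size: int,
    ) -> int:
        # Mirrors HostKVCache.__init__: a GB budget overrides the ratio.
        if host_size_gb > 0:
            size = int(host_size_gb * 1e9 // size_per_token_bytes)
        else:
            size = int(device_pool_tokens * host_to_device_ratio)
        # Align down to a whole number of pages.
        size -= size % page_size
        assert size > device_pool_tokens, "host pool must exceed the device pool"
        return size

    # 100 GB budget, 160 KiB per token, page size 32, 200k device tokens (all made up).
    print(host_pool_tokens(100, 2.0, 200_000, 160 * 1024, 32))  # 610336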
@@ -180,6 +180,8 @@ class ServerArgs:
     tool_call_parser: Optional[str] = None
     enable_hierarchical_cache: bool = False
     hicache_ratio: float = 2.0
+    hicache_size: int = 0
+    hicache_write_policy: str = "write_through_selective"
     flashinfer_mla_disable_ragged: bool = False
     warmups: Optional[str] = None
     moe_dense_tp_size: Optional[int] = None
@@ -1116,10 +1118,22 @@ class ServerArgs:
         parser.add_argument(
             "--hicache-ratio",
             type=float,
-            required=False,
             default=ServerArgs.hicache_ratio,
             help="The ratio of the size of host KV cache memory pool to the size of device pool.",
         )
+        parser.add_argument(
+            "--hicache-size",
+            type=int,
+            default=ServerArgs.hicache_size,
+            help="The size of host KV cache memory pool in gigabytes, which will override the hicache_ratio if set.",
+        )
+        parser.add_argument(
+            "--hicache-write-policy",
+            type=str,
+            choices=["write_back", "write_through", "write_through_selective"],
+            default=ServerArgs.hicache_write_policy,
+            help="The write policy of hierarchical cache.",
+        )
         parser.add_argument(
             "--enable-deepep-moe",
             action="store_true",
......
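For reference, the new options are passed on the server command line alongside --enable-hierarchical-cache, as the updated tests below also do. The launch module and model path in this sketch are assumptions for illustration; only the hicache flags come from this diff.

    # Hypothetical launch invocation; "sglang.launch_server" and the model path
    # are placeholders, the hicache flags are the ones added in this change.
    import subprocess

    subprocess.run(
        [
            "python", "-m", "sglang.launch_server",
            "--model-path", "meta-llama/Llama-3.1-8B-Instruct",  # placeholder
            "--enable-hierarchical-cache",
            "--hicache-size", "100",  # host pool budget in GB; overrides --hicache-ratio
            "--hicache-write-policy", "write_through_selective",
        ],
        check=True,
    )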
@@ -23,6 +23,10 @@ class TestHiCache(CustomTestCase):
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             other_args=[
                 "--enable-hierarchical-cache",
+                "--mem-fraction-static",
+                0.7,
+                "--hicache-size",
+                100,
             ],
         )
......
@@ -24,6 +24,8 @@ class TestHierarchicalMLA(CustomTestCase):
             other_args=[
                 "--trust-remote-code",
                 "--enable-hierarchical-cache",
+                "--hicache-ratio",
+                2,
             ],
         )
......
@@ -24,7 +24,9 @@ class TestHiCachePage(CustomTestCase):
             other_args=[
                 "--enable-hierarchical-cache",
                 "--page-size",
-                "32",
+                32,
+                "--hicache-write-policy",
+                "write_back",
             ],
         )
......