Unverified commit 70645f4d authored by Zhiqiang Xie, committed by GitHub

upstream hicache fixes (#5570)

parent 188f0955
@@ -571,6 +571,14 @@ class Req:
             self.prefix_indices, self.last_node = tree_cache.match_prefix(
                 rid=self.rid, key=self.adjust_max_prefix_ids()
             )
+        elif enable_hierarchical_cache:
+            # in case last_node is evicted during scheduling, we need to update the prefix_indices
+            while self.last_node.evicted:
+                self.prefix_indices = self.prefix_indices[
+                    : -len(self.last_node.host_value)
+                ]
+                self.last_node = self.last_node.parent
+
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

     def adjust_max_prefix_ids(self):
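Note: the new elif branch closes a race where the matched last_node can be evicted between scheduling and execution. A minimal sketch of the rollback walk, using a simplified TreeNode stand-in (only the parent, host_value, and evicted fields the loop touches; not the actual sglang class):

# Simplified stand-in for illustration only.
class TreeNode:
    def __init__(self, parent=None, host_value=(), evicted=False):
        self.parent = parent
        self.host_value = list(host_value)  # host indices backing this node's tokens
        self.evicted = evicted

root = TreeNode()
mid = TreeNode(parent=root, host_value=[0, 1, 2])
leaf = TreeNode(parent=mid, host_value=[3, 4], evicted=True)

prefix_indices = [10, 11, 12, 13, 14]  # device indices of the matched prefix
last_node = leaf
while last_node.evicted:
    # drop the tokens contributed by the evicted node, then climb one level
    prefix_indices = prefix_indices[: -len(last_node.host_value)]
    last_node = last_node.parent
assert prefix_indices == [10, 11, 12] and last_node is mid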
@@ -489,6 +489,8 @@ class Scheduler(
                 tp_cache_group=self.tp_cpu_group,
                 page_size=self.page_size,
                 hicache_ratio=server_args.hicache_ratio,
+                hicache_size=server_args.hicache_size,
+                hicache_write_policy=server_args.hicache_write_policy,
             )
         else:
             self.tree_cache = RadixCache(
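The scheduler change is pure plumbing: the two new server args, hicache_size and hicache_write_policy, are threaded from ServerArgs into the HiRadixCache constructor shown next.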
@@ -29,15 +29,17 @@ class HiRadixCache(RadixCache):
         tp_cache_group: torch.distributed.ProcessGroup,
         page_size: int,
         hicache_ratio: float,
+        hicache_size: int,
+        hicache_write_policy: str,
     ):
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
             self.token_to_kv_pool_host = MHATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio, page_size
+                self.kv_cache, hicache_ratio, hicache_size, page_size
             )
         elif isinstance(self.kv_cache, MLATokenToKVPool):
             self.token_to_kv_pool_host = MLATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio, page_size
+                self.kv_cache, hicache_ratio, hicache_size, page_size
             )
         else:
             raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
@@ -50,6 +52,7 @@ class HiRadixCache(RadixCache):
             self.token_to_kv_pool_host,
             page_size,
             load_cache_event=self.load_cache_event,
+            write_policy=hicache_write_policy,
         )

         # record the nodes with ongoing write through
@@ -57,7 +60,9 @@ class HiRadixCache(RadixCache):
         # record the node segments with ongoing load back
         self.ongoing_load_back = {}
         # todo: dynamically adjust the threshold
-        self.write_through_threshold = 1
+        self.write_through_threshold = (
+            1 if hicache_write_policy == "write_through" else 3
+        )
         self.load_back_threshold = 10
         super().__init__(
             req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
@@ -76,7 +81,7 @@ class HiRadixCache(RadixCache):
             height += 1
         return height

-    def write_backup(self, node: TreeNode):
+    def write_backup(self, node: TreeNode, write_back=False):
         host_indices = self.cache_controller.write(
             device_indices=node.value,
             node_id=node.id,
@@ -90,21 +95,29 @@ class HiRadixCache(RadixCache):
         if host_indices is not None:
             node.host_value = host_indices
             self.ongoing_write_through[node.id] = node
-            self.inc_lock_ref(node)
+            if not write_back:
+                # no need to lock nodes if write back
+                self.inc_lock_ref(node)
         else:
             return 0

         return len(host_indices)

     def inc_hit_count(self, node: TreeNode):
-        if self.cache_controller.write_policy != "write_through_selective":
+        if node.backuped or self.cache_controller.write_policy == "write_back":
             return
         node.hit_count += 1
-        if node.host_value is None and node.hit_count > self.write_through_threshold:
+        if node.hit_count >= self.write_through_threshold:
             self.write_backup(node)
             node.hit_count = 0

-    def writing_check(self):
+    def writing_check(self, write_back=False):
+        if write_back:
+            # blocking till all write back complete
+            while len(self.ongoing_write_through) > 0:
+                ack_id = self.cache_controller.ack_write_queue.get()
+                del self.ongoing_write_through[ack_id]
+            return
         queue_size = torch.tensor(
             self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
         )
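Together these three methods encode the policies: write_through backs a node up on its first hit, write_through_selective waits for the higher threshold, and write_back defers copying until eviction (and skips locking, since the node is evicted as soon as the write drains). A minimal sketch of the hit-count gate, with thresholds mirroring the diff (1 vs. 3) and the controller call reduced to a flag flip:

# Sketch only: write_backup is reduced to setting a flag.
class Node:
    def __init__(self):
        self.hit_count = 0
        self.backuped = False

def make_inc_hit_count(write_policy):
    threshold = 1 if write_policy == "write_through" else 3
    def inc_hit_count(node):
        if node.backuped or write_policy == "write_back":
            return  # already on host, or the policy defers to eviction time
        node.hit_count += 1
        if node.hit_count >= threshold:
            node.backuped = True  # stands in for write_backup(node)
            node.hit_count = 0
    return inc_hit_count

n = Node()
inc = make_inc_hit_count("write_through_selective")
inc(n); inc(n)
assert not n.backuped  # two touches: below the selective threshold
inc(n)
assert n.backuped      # third touch triggers the host backup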
@@ -143,29 +156,25 @@ class HiRadixCache(RadixCache):
         heapq.heapify(leaves)

         num_evicted = 0
-        pending_nodes = []
+        write_back_nodes = []
         while num_evicted < num_tokens and len(leaves):
             x = heapq.heappop(leaves)
             if x.lock_ref > 0:
                 continue

-            if x.host_value is None:
+            if not x.backuped:
                 if self.cache_controller.write_policy == "write_back":
-                    num_evicted += self.write_backup(x)
-                    pending_nodes.append(x)
-                elif self.cache_controller.write_policy == "write_through_selective":
-                    num_evicted += self._evict_write_through_selective(x)
+                    # write to host if the node is not backuped
+                    num_evicted += self.write_backup(x, write_back=True)
+                    write_back_nodes.append(x)
                 else:
-                    assert (
-                        self.cache_controller.write_policy != "write_through"
-                    ), "write_through should be inclusive"
-                    raise NotImplementedError
+                    num_evicted += self._evict_regular(x)
             else:
-                num_evicted += self._evict_write_through(x)
+                num_evicted += self._evict_backuped(x)

             for child in x.parent.children.values():
-                if child in pending_nodes:
+                if child in write_back_nodes:
                     continue
                 if not child.evicted:
                     break
@@ -174,15 +183,12 @@ class HiRadixCache(RadixCache):

                 heapq.heappush(leaves, x.parent)

         if self.cache_controller.write_policy == "write_back":
-            # blocking till all write back complete
-            while len(self.ongoing_write_through) > 0:
-                self.writing_check()
-                time.sleep(0.1)
-
-            for node in pending_nodes:
-                assert node.host_value is not None
-                self._evict_write_through(node)
+            self.writing_check(write_back=True)
+            for node in write_back_nodes:
+                assert node.backuped
+                self._evict_backuped(node)

-    def _evict_write_through(self, node: TreeNode):
+    def _evict_backuped(self, node: TreeNode):
         # evict a node already written to host
         num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
         assert num_evicted > 0
@@ -190,7 +196,7 @@ class HiRadixCache(RadixCache):
         node.value = None
         return num_evicted

-    def _evict_write_through_selective(self, node: TreeNode):
+    def _evict_regular(self, node: TreeNode):
         # evict a node not initiated write to host
         self.cache_controller.mem_pool_device_allocator.free(node.value)
         num_evicted = len(node.value)
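With the selective-eviction path folded away, eviction reduces to two helpers: _evict_backuped frees device indices for nodes that already have a host copy, _evict_regular frees unbackuped nodes outright, and under write_back an unbackuped node is first written to host and only evicted after writing_check(write_back=True) drains the ack queue. A condensed sketch of the per-node decision, with hypothetical free_device/write_to_host callbacks standing in for the cache controller:

# Condensed sketch; free_device and write_to_host are hypothetical stand-ins.
def evict_one(node, write_policy, write_back_nodes, free_device, write_to_host):
    if not node.backuped:
        if write_policy == "write_back":
            freed = write_to_host(node)    # copy out first...
            write_back_nodes.append(node)  # ...evict after the write drains
        else:
            freed = free_device(node)      # no host copy: free outright
    else:
        freed = free_device(node)          # host copy exists: safe to drop
    return freed

class N:
    def __init__(self, backuped, tokens):
        self.backuped, self.tokens = backuped, tokens

pending = []
freed = evict_one(N(False, 8), "write_back", pending,
                  free_device=lambda n: n.tokens, write_to_host=lambda n: n.tokens)
assert freed == 8 and len(pending) == 1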
@@ -339,11 +345,13 @@ class HiRadixCache(RadixCache):
             prefix_len = self.key_match_fn(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
+                self.inc_hit_count(new_node)
                 if not new_node.evicted:
                     value.append(new_node.value)
                 node = new_node
                 break
             else:
+                self.inc_hit_count(child)
                 if not child.evicted:
                     value.append(child.value)
                     node = child
@@ -369,7 +377,7 @@ class HiRadixCache(RadixCache):
         else:
             new_node.value = child.value[:split_len]
             child.value = child.value[split_len:]
-        if child.host_value is not None:
+        if child.backuped:
             new_node.host_value = child.host_value[:split_len]
             child.host_value = child.host_value[split_len:]
         child.parent = new_node
@@ -426,8 +434,8 @@ class HiRadixCache(RadixCache):
         node.children[child_key] = new_node
         self.evictable_size_ += len(value)

-        if self.cache_controller.write_policy == "write_through":
-            self.write_backup(new_node)
+        if self.cache_controller.write_policy != "write_back":
+            self.inc_hit_count(new_node)
         return total_prefix_length

     def _collect_leaves_device(self):
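Hit counting now happens on every prefix match and insert (inc_hit_count replaces the unconditional write_backup on write_through), so each policy's behavior falls out of the threshold alone: with threshold 1, write_through still backs up on first touch; with threshold 3, write_through_selective promotes only prefixes that are actually reused.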
@@ -624,26 +624,27 @@ class HostKVCache(abc.ABC):
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
+        host_size: int,
         pin_memory: bool,
         device: str,
         page_size: int,
     ):
-        assert (
-            host_to_device_ratio >= 1
-        ), "The host memory should be larger than the device memory with the current protocol"
-        # todo, other ways of configuring the size
         self.device_pool = device_pool
         self.host_to_device_ratio = host_to_device_ratio
         self.dtype = device_pool.store_dtype
         self.pin_memory = pin_memory
         self.device = device
         self.page_size = page_size
-        self.size = int(device_pool.size * host_to_device_ratio)
         self.size_per_token = self.get_size_per_token()
+        if host_size > 0:
+            self.size = int(host_size * 1e9 // self.size_per_token)
+        else:
+            self.size = int(device_pool.size * host_to_device_ratio)
+        # Align the host memory pool size to the page size
+        self.size = self.size - (self.size % self.page_size)
+        assert (
+            self.size > device_pool.size
+        ), "The host memory should be larger than the device memory with the current protocol"

         # Verify there is enough available host memory.
         host_mem = psutil.virtual_memory()
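The new sizing rule reads --hicache-size as gigabytes and converts it to a token capacity, aligned down to the page size. A worked example with an assumed per-token KV footprint (the real value comes from get_size_per_token; for MHA it is roughly head_num * head_dim * dtype_bytes * 2 * layer_num):

# Worked example; the model shape is an assumption for illustration.
layer_num, head_num, head_dim, dtype_bytes = 32, 8, 128, 2  # e.g. an fp16 GQA model
size_per_token = head_num * head_dim * dtype_bytes * 2 * layer_num  # K and V planes
host_size_gb, page_size = 100, 32

size = int(host_size_gb * 1e9 // size_per_token)  # 131072 B/token -> 762,939 tokens
size -= size % page_size                          # align down -> 762,912 tokens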
@@ -795,12 +796,13 @@ class MHATokenToKVPoolHost(HostKVCache):
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
+        host_size: int,
         page_size: int,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, pin_memory, device, page_size
+            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
         )

     def get_size_per_token(self):
@@ -869,12 +871,13 @@ class MLATokenToKVPoolHost(HostKVCache):
         self,
         device_pool: MLATokenToKVPool,
         host_to_device_ratio: float,
+        host_size: int,
         page_size: int,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, pin_memory, device, page_size
+            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
        )

     def get_size_per_token(self):
@@ -180,6 +180,8 @@ class ServerArgs:
     tool_call_parser: Optional[str] = None
     enable_hierarchical_cache: bool = False
     hicache_ratio: float = 2.0
+    hicache_size: int = 0
+    hicache_write_policy: str = "write_through_selective"
     flashinfer_mla_disable_ragged: bool = False
     warmups: Optional[str] = None
     moe_dense_tp_size: Optional[int] = None
@@ -1116,10 +1118,22 @@ class ServerArgs:
         parser.add_argument(
             "--hicache-ratio",
             type=float,
-            required=False,
             default=ServerArgs.hicache_ratio,
             help="The ratio of the size of host KV cache memory pool to the size of device pool.",
         )
+        parser.add_argument(
+            "--hicache-size",
+            type=int,
+            default=ServerArgs.hicache_size,
+            help="The size of host KV cache memory pool in gigabytes, which will override the hicache_ratio if set.",
+        )
+        parser.add_argument(
+            "--hicache-write-policy",
+            type=str,
+            choices=["write_back", "write_through", "write_through_selective"],
+            default=ServerArgs.hicache_write_policy,
+            help="The write policy of hierarchical cache.",
+        )
         parser.add_argument(
             "--enable-deepep-moe",
             action="store_true",
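With these flags, an explicit --hicache-size (in gigabytes) takes precedence over --hicache-ratio, and the write policy defaults to write_through_selective. A typical launch might look like: python -m sglang.launch_server --model-path <model> --enable-hierarchical-cache --hicache-size 100 --hicache-write-policy write_through (the model path here is a placeholder).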
@@ -23,6 +23,10 @@ class TestHiCache(CustomTestCase):
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             other_args=[
                 "--enable-hierarchical-cache",
+                "--mem-fraction-static",
+                0.7,
+                "--hicache-size",
+                100,
             ],
         )
@@ -24,6 +24,8 @@ class TestHierarchicalMLA(CustomTestCase):
             other_args=[
                 "--trust-remote-code",
                 "--enable-hierarchical-cache",
+                "--hicache-ratio",
+                2,
             ],
         )
@@ -24,7 +24,9 @@ class TestHiCachePage(CustomTestCase):
             other_args=[
                 "--enable-hierarchical-cache",
                 "--page-size",
-                "32",
+                32,
+                "--hicache-write-policy",
+                "write_back",
             ],
         )
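The test updates exercise the new knobs end to end: a fixed 100 GB host pool alongside --mem-fraction-static 0.7, an explicit ratio on the MLA path, and the write_back policy combined with a 32-token page size.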