Unverified Commit 0edda320 authored by huangtingwei, committed by GitHub

Support page first layout zero copy for mooncake store (#8651)


Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent 924827c3
```diff
@@ -260,6 +260,7 @@ class HiCacheController:
             self.storage_backend = MooncakeStore()
             self.get_hash_str = get_hash_str_mooncake
             self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
+            assert self.mem_pool_host.layout == "page_first"
         elif storage_backend == "hf3fs":
             from sglang.srt.distributed import get_tensor_model_parallel_rank
             from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
```
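This assertion ties Mooncake's zero-copy buffer registration to the `page_first` host layout, in which every layer of a page sits contiguously in the registered buffer. A minimal sketch of the difference between the two layouts (tensor names and dimensions here are illustrative assumptions, not SGLang internals):

```python
import torch

layer_num, pool_size, head_num, head_dim = 4, 1024, 8, 128

# layer_first: the buffer is one slab per layer, so a single logical page
# is scattered across layer_num disjoint regions of the host buffer.
layer_first = torch.empty(layer_num, pool_size, head_num, head_dim)

# page_first: all layers of a token sit back to back, so a page is one
# contiguous byte range that can be registered with Mooncake once and
# shipped zero-copy as a single object.
page_first = torch.empty(pool_size, layer_num, head_num, head_dim)
```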
```diff
@@ -472,27 +472,26 @@ class MHATokenToKVPoolHost(HostKVCache):
             * self.dtype.itemsize
         )
         for index in range(0, len(indices), self.page_size):
-            for layer_id in range(self.layer_num):
-                k_ptr = (
-                    kv_buffer_data_ptr
-                    + indices[index]
-                    * self.head_num
-                    * self.head_dim
-                    * self.dtype.itemsize
-                    + layer_id
-                    * self.size
-                    * self.head_num
-                    * self.head_dim
-                    * self.dtype.itemsize
-                )
-                v_ptr = k_ptr + v_offset
-                ptr_list.append(k_ptr)
-                ptr_list.append(v_ptr)
-                key_ = keys[index // self.page_size]
-                key_list.append(f"{key_}_{layer_id}_k")
-                key_list.append(f"{key_}_{layer_id}_v")
+            k_ptr = (
+                kv_buffer_data_ptr
+                + indices[index]
+                * self.layer_num
+                * self.head_num
+                * self.head_dim
+                * self.dtype.itemsize
+            )
+            v_ptr = k_ptr + v_offset
+            ptr_list.append(k_ptr)
+            ptr_list.append(v_ptr)
+            key_ = keys[index // self.page_size]
+            key_list.append(f"{key_}_k")
+            key_list.append(f"{key_}_v")
         element_size = (
-            self.dtype.itemsize * self.page_size * self.head_num * self.head_dim
+            self.layer_num
+            * self.dtype.itemsize
+            * self.page_size
+            * self.head_num
+            * self.head_dim
         )
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list
```
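With `page_first`, the inner per-layer loop collapses: one pointer and one key now cover a page across all layers, and `element_size` gains a `layer_num` factor. A standalone worked example of the arithmetic (the constants are made up for illustration):

```python
# Illustrative constants; real values depend on the model and config.
layer_num, head_num, head_dim = 32, 8, 128
page_size, itemsize = 64, 2  # 64 tokens per page, fp16

# In page_first layout a token's K data for all layers is contiguous,
# so the per-token stride already includes the layer dimension.
token_stride = layer_num * head_num * head_dim * itemsize  # 65536 bytes

# Mirrors: kv_buffer_data_ptr + indices[index] * token_stride
kv_buffer_data_ptr = 0  # placeholder for self.kv_buffer.data_ptr()
k_ptr = kv_buffer_data_ptr + 4096 * token_stride  # page starting at token 4096

# One object per page spans every layer, hence the new layer_num factor
# in element_size that the old per-layer objects did not need.
element_size = layer_num * itemsize * page_size * head_num * head_dim
print(k_ptr, element_size)  # 268435456 4194304
```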
```diff
@@ -687,22 +686,19 @@ class MLATokenToKVPoolHost(HostKVCache):
         key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
         for index in range(0, len(indices), self.page_size):
-            for layer_id in range(self.layer_num):
-                k_ptr = (
-                    kv_buffer_data_ptr
-                    + indices[index]
-                    * (self.kv_lora_rank + self.qk_rope_head_dim)
-                    * self.dtype.itemsize
-                    + layer_id
-                    * self.size
-                    * (self.kv_lora_rank + self.qk_rope_head_dim)
-                    * self.dtype.itemsize
-                )
-                ptr_list.append(k_ptr)
-                key_ = keys[index // self.page_size]
-                key_list.append(f"{key_}_{layer_id}_k")
+            k_ptr = (
+                kv_buffer_data_ptr
+                + indices[index]
+                * self.layer_num
+                * (self.kv_lora_rank + self.qk_rope_head_dim)
+                * self.dtype.itemsize
+            )
+            ptr_list.append(k_ptr)
+            key_ = keys[index // self.page_size]
+            key_list.append(f"{key_}_k")
         element_size = (
-            self.dtype.itemsize
+            self.layer_num
+            * self.dtype.itemsize
             * self.page_size
             * (self.kv_lora_rank + self.qk_rope_head_dim)
         )
```
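The MLA pool follows the same pattern; only the per-token width changes, from `head_num * head_dim` to `kv_lora_rank + qk_rope_head_dim`, and there is no separate V object because MLA stores a single compressed KV tensor per layer. A quick check of the sizes with assumed DeepSeek-style dimensions:

```python
# Assumed dimensions in the spirit of a DeepSeek-style MLA config.
layer_num, page_size, itemsize = 61, 64, 2
kv_lora_rank, qk_rope_head_dim = 512, 64

token_width = kv_lora_rank + qk_rope_head_dim       # 576 elements per layer
token_stride = layer_num * token_width * itemsize   # bytes per token, all layers
element_size = layer_num * itemsize * page_size * token_width  # bytes per page object
print(token_stride, element_size)  # 70272 4497408
```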
```diff
@@ -223,13 +223,11 @@ class MooncakeStore(HiCacheStorage):
     def exists(self, keys) -> bool | dict:
         _keys = []
-        local_rank = torch.cuda.current_device()
         for key in keys:
             if key is None:
                 return None
-            # Since mooncake store is stored in layer by layer,
-            # only the first layer is checked here.
-            _keys.append(f"{key}_{local_rank}_k")
+
+            _keys.append(f"{key}_k")
         result = {k: v for k, v in zip(keys, self.store.batch_is_exist(_keys))}
         return result
```
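The existence probe mirrors the new key scheme: with one object per page there is no per-layer (or device-rank) suffix left to check. A before/after sketch of the key shapes (the hash value is hypothetical):

```python
page_hash = "3fa9c0"  # hypothetical page hash
layer_num = 32

# Before: one store object per (page, layer), so keys carried a layer suffix.
old_keys = [f"{page_hash}_{layer_id}_k" for layer_id in range(layer_num)]

# After: a single object holds the page across all layers.
new_key = f"{page_hash}_k"
print(old_keys[0], "->", new_key)
```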
```diff
@@ -575,6 +575,11 @@ class ServerArgs:
                 "Pipeline parallelism is incompatible with overlap schedule."
             )
 
+        if self.hicache_storage_backend == "mooncake":
+            # to use mooncake storage backend, the following conditions must be met:
+            self.hicache_io_backend = "kernel"
+            self.hicache_mem_layout = "page_first"
+
         # Speculative Decoding
         if self.speculative_algorithm == "NEXTN":
             # NEXTN shares the same implementation of EAGLE
```
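Selecting the Mooncake backend now pins the two settings the zero-copy path depends on, silently overriding whatever the user passed. A standalone sketch of the effect (the dataclass is a hypothetical stand-in; only the attribute names and values come from the hunk above):

```python
from dataclasses import dataclass

@dataclass
class HiCacheArgs:  # hypothetical stand-in for the relevant ServerArgs fields
    hicache_storage_backend: str = "mooncake"
    hicache_io_backend: str = "direct"       # user-supplied value
    hicache_mem_layout: str = "layer_first"  # user-supplied value

args = HiCacheArgs()
if args.hicache_storage_backend == "mooncake":
    # Mooncake zero-copy requires kernel IO and the page_first layout,
    # so both are forced regardless of the user's choices.
    args.hicache_io_backend = "kernel"
    args.hicache_mem_layout = "page_first"

print(args.hicache_io_backend, args.hicache_mem_layout)  # kernel page_first
```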