Unverified commit 55349e36, authored by huangtingwei and committed by GitHub
Browse files

Support DP attention with the Mooncake store (#9684)

parent e1f7cf57
...@@ -636,6 +636,7 @@ class HiCacheController: ...@@ -636,6 +636,7 @@ class HiCacheController:
key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta( key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
hash_values, hash_values,
host_indices, host_indices,
self.storage_config.tp_rank,
) )
get_result = self.storage_backend.batch_get( get_result = self.storage_backend.batch_get(
key_strs, key_strs,
...@@ -838,6 +839,7 @@ class HiCacheController: ...@@ -838,6 +839,7 @@ class HiCacheController:
key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta( key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
hash_values, hash_values,
host_indices, host_indices,
self.storage_config.tp_rank,
) )
success = self.storage_backend.batch_set( success = self.storage_backend.batch_set(
key_strs, key_strs,
......
...@@ -7,7 +7,6 @@ from functools import wraps ...@@ -7,7 +7,6 @@ from functools import wraps
import psutil import psutil
import torch import torch
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
from sglang.srt.utils import is_npu from sglang.srt.utils import is_npu
...@@ -464,8 +463,7 @@ class MHATokenToKVPoolHost(HostKVCache): ...@@ -464,8 +463,7 @@ class MHATokenToKVPoolHost(HostKVCache):
else: else:
raise ValueError(f"Unsupported layout: {self.layout}") raise ValueError(f"Unsupported layout: {self.layout}")
def get_buffer_meta(self, keys, indices): def get_buffer_meta(self, keys, indices, local_rank):
local_rank = get_tensor_model_parallel_rank()
ptr_list = [] ptr_list = []
key_list = [] key_list = []
kv_buffer_data_ptr = self.kv_buffer.data_ptr() kv_buffer_data_ptr = self.kv_buffer.data_ptr()
...@@ -704,7 +702,7 @@ class MLATokenToKVPoolHost(HostKVCache): ...@@ -704,7 +702,7 @@ class MLATokenToKVPoolHost(HostKVCache):
else: else:
raise ValueError(f"Unsupported layout: {self.layout}") raise ValueError(f"Unsupported layout: {self.layout}")
def get_buffer_meta(self, keys, indices): def get_buffer_meta(self, keys, indices, local_rank):
ptr_list = [] ptr_list = []
key_list = [] key_list = []
kv_buffer_data_ptr = self.kv_buffer.data_ptr() kv_buffer_data_ptr = self.kv_buffer.data_ptr()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment