Unverified commit d4041a5e authored by pansicheng, committed by GitHub
parent 2f555c4c
@@ -289,8 +289,6 @@ class HiCacheController:
)
self.storage_backend = MooncakeStore(self.storage_config)
self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
assert self.mem_pool_host.layout == "page_first"
elif storage_backend == "hf3fs":
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
HiCacheHF3FS,
@@ -313,6 +311,8 @@ class HiCacheController:
f"Unsupported storage backend: {storage_backend}"
)
self.storage_backend.register_mem_pool_host(self.mem_pool_host)
self.enable_storage = True
# todo: threshold policy for prefetching
self.prefetch_threshold = max(prefetch_threshold, self.page_size)
@@ -335,18 +335,10 @@ class HiCacheController:
# Select the get and set functions
self.page_get_func = self._generic_page_get
self.page_set_func = self._generic_page_set
self.batch_exists_func = self.storage_backend.batch_exists
self.is_3fs_zerocopy = (
self.storage_backend_type == "hf3fs"
and self.mem_pool_host.layout == "page_first"
)
if self.storage_backend_type == "mooncake":
self.page_get_func = self._mooncake_page_get
self.page_set_func = self._mooncake_page_set
elif self.is_3fs_zerocopy:
self.page_get_func = self._3fs_zero_copy_page_get
self.page_set_func = self._3fs_zero_copy_page_set
self.batch_exists_func = self._3fs_zero_copy_batch_exists
if self.storage_backend_type in ["hf3fs", "mooncake"]:
self.page_get_func = self._page_get_zero_copy
self.page_set_func = self._page_set_zero_copy
self.device = self.mem_pool_device.device
self.layer_num = self.mem_pool_device.layer_num
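Note: with both backends now implementing the v1 interface, the earlier three-way dispatch (_mooncake_page_get, _3fs_zero_copy_page_get, generic) collapses to a single membership test. A minimal standalone restatement of the selection logic, purely for illustration:

def select_page_funcs(controller):
    # Hypothetical helper mirroring the dispatch above: hf3fs and mooncake
    # share the zero-copy v1 path; everything else stays on the generic path.
    if controller.storage_backend_type in ("hf3fs", "mooncake"):
        return controller._page_get_zero_copy, controller._page_set_zero_copy
    return controller._generic_page_get, controller._generic_page_set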
@@ -630,42 +622,19 @@ class HiCacheController:
for chunk in chunks:
self.host_mem_release_queue.put(chunk)
def _3fs_zero_copy_batch_exists(self, batch_hashes):
_batch_hashes, _, factor = self.mem_pool_host.get_buffer_with_hash(batch_hashes)
hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor
return hit_page_num
def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash(
hash_values, host_indices
)
page_data = self.storage_backend.batch_get(hashes, dsts)
if page_data:
inc = self.page_size * len(hashes) // factor
operation.increment(inc)
else:
logger.warning(
f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
)
def _mooncake_page_get(self, operation, hash_values, host_indices):
key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
hash_values,
host_indices,
self.storage_config.tp_rank,
)
get_result = self.storage_backend.batch_get(
key_strs,
target_locations=buffer_ptrs,
target_sizes=buffer_sizes,
)
if get_result != len(hash_values):
logger.warning(
f"Prefetch operation {operation.request_id} failed or partially failed."
)
if get_result != 0:
operation.increment(get_result * self.page_size)
def _page_get_zero_copy(self, operation, hash_values, host_indices):
results = self.storage_backend.batch_get_v1(hash_values, host_indices)
inc = 0
for i in range(len(hash_values)):
if not results[i]:
logger.warning(
f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
)
break
inc += self.page_size
operation.increment(inc)
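Worth noting: _page_get_zero_copy credits the operation only for the contiguous prefix of successful pages, since a prefetch is useful only up to its first missing page. A standalone illustration with made-up values:

# Made-up values illustrating the prefix accounting in _page_get_zero_copy.
page_size = 64
results = [True, True, False, True]  # per-page outcomes; the third page failed

inc = 0
for ok in results:
    if not ok:
        break
    inc += page_size

assert inc == 2 * page_size  # the trailing hit after the gap is not counted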
# todo: deprecate
def _generic_page_get(self, operation, hash_values, host_indices):
dummy_page_dst = [
self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
@@ -755,7 +724,7 @@
batch_tokens[i : i + self.page_size], last_hash
)
batch_hashes.append(last_hash)
hit_page_num = self.batch_exists_func(batch_hashes)
hit_page_num = self.storage_backend.batch_exists(batch_hashes)
hash_value.extend(batch_hashes[:hit_page_num])
storage_query_count += hit_page_num * self.page_size
if hit_page_num < len(batch_hashes):
@@ -824,34 +793,16 @@
self.backup_queue.put(operation)
return operation.id
# non-zero copy
# todo: deprecate
def _generic_page_set(self, hash_values, host_indices) -> bool:
data = [
self.mem_pool_host.get_flat_data_page(host_indices[i * self.page_size])
self.mem_pool_host.get_data_page(host_indices[i * self.page_size])
for i in range(len(hash_values))
]
return self.storage_backend.batch_set(hash_values, data)
# zero copy
def _mooncake_page_set(self, hash_values, host_indices) -> bool:
key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
hash_values,
host_indices,
self.storage_config.tp_rank,
)
success = self.storage_backend.batch_set(
key_strs,
target_locations=buffer_ptrs,
target_sizes=buffer_sizes,
)
return success
# zero copy
def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash(
hash_values, host_indices
)
return self.storage_backend.batch_set(hashes, dsts)
def _page_set_zero_copy(self, hash_values, host_indices) -> bool:
return all(self.storage_backend.batch_set_v1(hash_values, host_indices))
# Backup batch by batch
def _page_backup(self, operation):
......
@@ -7,6 +7,8 @@ from typing import Any, List, Optional
import torch
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
logger = logging.getLogger(__name__)
@@ -32,15 +34,46 @@ class HiCacheStorageConfig:
extra_config: Optional[dict] = None
@dataclass
class HiCacheStorageExtraInfo:
extra_info: Optional[dict] = None
class HiCacheStorage(ABC):
"""
HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
It abstracts the underlying storage mechanism, allowing different implementations to be used.
"""
# todo, potentially pass model and TP configs into storage backend
# todo, the page size of the storage backend does not have to be the same as the host memory pool's
def register_mem_pool_host(self, mem_pool_host: HostKVCache):
self.mem_pool_host = mem_pool_host
def batch_get_v1(
self,
keys: List[str],
host_indices: torch.Tensor,
extra_info: Optional[HiCacheStorageExtraInfo] = None,
) -> List[bool]:
"""
Retrieve values for multiple keys into the host memory pool.
Returns a list of booleans indicating success for each key.
"""
pass
def batch_set_v1(
self,
keys: List[str],
host_indices: torch.Tensor,
extra_info: Optional[HiCacheStorageExtraInfo] = None,
) -> List[bool]:
"""
Store values for multiple keys from the host memory pool.
Returns a list of booleans indicating success for each key.
"""
pass
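A hedged sketch of how a caller might drive the v1 interface; backend stands for any concrete HiCacheStorage, and the keys, indices, and helper name are placeholders:

import torch

def prefetch_pages(backend, keys, host_indices: torch.Tensor, page_size: int) -> int:
    # host_indices covers len(keys) * page_size token slots in the host pool;
    # batch_get_v1 writes directly into those slots and reports per-key success.
    assert len(host_indices) == len(keys) * page_size
    ok = backend.batch_get_v1(keys, host_indices)
    return sum(ok)  # number of pages actually retrieved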
@abstractmethod
def get(
self,
@@ -54,6 +87,7 @@ class HiCacheStorage(ABC):
"""
pass
# TODO: Deprecate
@abstractmethod
def batch_get(
self,
@@ -81,6 +115,7 @@ class HiCacheStorage(ABC):
"""
pass
# TODO: Deprecate
@abstractmethod
def batch_set(
self,
@@ -103,6 +138,7 @@ class HiCacheStorage(ABC):
"""
pass
# TODO: Use a finer-grained return type (e.g., List[bool])
def batch_exists(self, keys: List[str]) -> int:
"""
Check if the keys exist in the storage.
@@ -114,6 +150,9 @@ class HiCacheStorage(ABC):
return i
return len(keys)
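As the TODO above says, the return value is coarse: it is the length of the longest existing prefix of keys, so a miss hides any later hits. A self-contained example of the contract (store is a stand-in set, not a real backend):

store = {"h0", "h1", "h3"}  # stand-in for the backend's key set

def batch_exists(keys):
    for i, key in enumerate(keys):
        if key not in store:
            return i
    return len(keys)

assert batch_exists(["h0", "h1", "h2", "h3"]) == 2  # "h3" exists but is not counted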
def clear(self) -> None:
pass
def get_stats(self):
return None
......
@@ -140,7 +140,7 @@ class HostKVCache(abc.ABC):
raise NotImplementedError()
@abc.abstractmethod
def get_flat_data_page(self, index) -> torch.Tensor:
def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
"""
Get a data page from the host memory pool, flattened to 1-D when `flat` is True.
"""
@@ -461,16 +461,19 @@ class MHATokenToKVPoolHost(HostKVCache):
else:
raise ValueError(f"Unsupported IO backend: {io_backend}")
def get_flat_data_page(self, index) -> torch.Tensor:
def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
if self.layout == "layer_first":
return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
data_page = self.kv_buffer[:, :, index : index + self.page_size, :, :]
elif self.layout == "page_first":
return self.kv_buffer[:, index : index + self.page_size, :, :, :].flatten()
data_page = self.kv_buffer[:, index : index + self.page_size, :, :, :]
elif self.layout == "page_first_direct":
real_index = index // self.page_size
return self.kv_buffer[:, real_index : real_index + 1, :, :, :, :].flatten()
data_page = self.kv_buffer[:, real_index : real_index + 1, :, :, :, :]
else:
raise ValueError(f"Unsupported layout: {self.layout}")
if flat:
data_page = data_page.flatten()
return data_page
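The new flat switch lets zero-copy callers keep the page as a structured view into kv_buffer, while the generic path still gets the flattened copy it serializes. A toy illustration with a stand-in buffer (all shapes made up):

import torch

# Stand-in page_first-style buffer: [kv, token slots, heads, head_dim].
kv_buffer = torch.zeros(2, 128, 8, 16)
page_size, index = 64, 64

view = kv_buffer[:, index : index + page_size]  # flat=False: a view, no copy
flat_page = view.flatten()                      # flat=True: contiguous 1-D copy

assert view.shape == (2, page_size, 8, 16)
assert flat_page.shape == (2 * page_size * 8 * 16,)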
def get_dummy_flat_data_page(self) -> torch.Tensor:
return torch.zeros(
@@ -507,9 +510,12 @@
else:
raise ValueError(f"Unsupported layout: {self.layout}")
def get_buffer_meta(self, keys, indices, local_rank):
def get_page_buffer_meta(self, indices):
""" "
meta data for zero copy
"""
assert len(indices) % self.page_size == 0
ptr_list = []
key_list = []
kv_buffer_data_ptr = self.kv_buffer.data_ptr()
indices = indices.tolist()
v_offset = (
@@ -519,48 +525,52 @@
* self.head_dim
* self.dtype.itemsize
)
for index in range(0, len(indices), self.page_size):
k_ptr = (
kv_buffer_data_ptr
+ indices[index]
* self.layer_num
if self.layout == "layer_first":
for index in range(0, len(indices), self.page_size):
for layer_id in range(self.layer_num):
k_ptr = (
kv_buffer_data_ptr
+ indices[index]
* self.head_num
* self.head_dim
* self.dtype.itemsize
+ layer_id
* self.size
* self.head_num
* self.head_dim
* self.dtype.itemsize
)
v_ptr = k_ptr + v_offset
ptr_list.append(k_ptr)
ptr_list.append(v_ptr)
element_size = (
self.dtype.itemsize * self.page_size * self.head_num * self.head_dim
)
element_size_list = [element_size] * len(ptr_list)
elif self.layout in ["page_first", "page_first_direct"]:
for index in range(0, len(indices), self.page_size):
k_ptr = (
kv_buffer_data_ptr
+ indices[index]
* self.layer_num
* self.head_num
* self.head_dim
* self.dtype.itemsize
)
v_ptr = k_ptr + v_offset
ptr_list.append(k_ptr)
ptr_list.append(v_ptr)
element_size = (
self.layer_num
* self.dtype.itemsize
* self.page_size
* self.head_num
* self.head_dim
* self.dtype.itemsize
)
v_ptr = k_ptr + v_offset
ptr_list.append(k_ptr)
ptr_list.append(v_ptr)
key_ = keys[index // self.page_size]
key_list.append(f"{key_}_{local_rank}_k")
key_list.append(f"{key_}_{local_rank}_v")
element_size = (
self.layer_num
* self.dtype.itemsize
* self.page_size
* self.head_num
* self.head_dim
)
element_size_list = [element_size] * len(key_list)
return key_list, ptr_list, element_size_list
def get_buffer_with_hash(self, keys, indices=None):
assert self.layout == "page_first"
assert indices is None or (len(keys) == (len(indices) // self.page_size))
key_list = []
buf_list = []
for i in range(len(keys)):
key = keys[i]
key_list.append(f"{key}-k")
key_list.append(f"{key}-v")
if indices is not None:
index = indices[i * self.page_size]
buf_list.append(self.k_buffer[index : index + self.page_size])
buf_list.append(self.v_buffer[index : index + self.page_size])
return key_list, buf_list, 2
element_size_list = [element_size] * len(ptr_list)
else:
raise ValueError(f"Unsupported layout: {self.layout}")
return ptr_list, element_size_list
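To make the page_first pointer arithmetic concrete, here is the byte offset of one page's K and V regions with made-up sizes; v_offset is assumed (the defining lines are partly elided above) to span the whole K plane, i.e. layer_num * size * head_num * head_dim * itemsize:

# Worked example of the page_first pointer math; all sizes are made up.
layer_num, head_num, head_dim = 4, 8, 128
size, page_size = 1024, 64   # total token slots, tokens per page
itemsize = 2                 # bytes per element, e.g. fp16
base = 0                     # stand-in for kv_buffer.data_ptr()

token_bytes = layer_num * head_num * head_dim * itemsize      # 8192
v_offset = layer_num * size * head_num * head_dim * itemsize  # assumed K-plane size

first_token = 128  # first host index of the page
k_ptr = base + first_token * token_bytes
v_ptr = k_ptr + v_offset
element_size = layer_num * itemsize * page_size * head_num * head_dim  # bytes per page, per K or V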
class MLATokenToKVPoolHost(HostKVCache):
@@ -736,16 +746,19 @@ class MLATokenToKVPoolHost(HostKVCache):
else:
raise ValueError(f"Unsupported IO backend: {io_backend}")
def get_flat_data_page(self, index) -> torch.Tensor:
def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
if self.layout == "layer_first":
return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
data_page = self.kv_buffer[:, index : index + self.page_size, :, :]
elif self.layout == "page_first":
return self.kv_buffer[index : index + self.page_size, :, :, :].flatten()
data_page = self.kv_buffer[index : index + self.page_size, :, :, :]
elif self.layout == "page_first_direct":
real_index = index // self.page_size
return self.kv_buffer[real_index : real_index + 1, :, :, :, :].flatten()
data_page = self.kv_buffer[real_index : real_index + 1, :, :, :, :]
else:
raise ValueError(f"Unsupported layout: {self.layout}")
if flat:
data_page = data_page.flatten()
return data_page
def get_dummy_flat_data_page(self) -> torch.Tensor:
return torch.zeros(
@@ -787,40 +800,51 @@ class MLATokenToKVPoolHost(HostKVCache):
else:
raise ValueError(f"Unsupported layout: {self.layout}")
def get_buffer_meta(self, keys, indices, local_rank):
def get_page_buffer_meta(self, indices):
""" "
meta data for zero copy
"""
assert len(indices) % self.page_size == 0
ptr_list = []
key_list = []
kv_buffer_data_ptr = self.kv_buffer.data_ptr()
indices = indices.tolist()
for index in range(0, len(indices), self.page_size):
k_ptr = (
kv_buffer_data_ptr
+ indices[index]
* self.layer_num
if self.layout == "layer_first":
for index in range(0, len(indices), self.page_size):
for layer_id in range(self.layer_num):
k_ptr = (
kv_buffer_data_ptr
+ indices[index]
* (self.kv_lora_rank + self.qk_rope_head_dim)
* self.dtype.itemsize
+ layer_id
* self.size
* (self.kv_lora_rank + self.qk_rope_head_dim)
* self.dtype.itemsize
)
ptr_list.append(k_ptr)
element_size = (
self.dtype.itemsize
* self.page_size
* (self.kv_lora_rank + self.qk_rope_head_dim)
)
element_size_list = [element_size] * len(ptr_list)
elif self.layout in ["page_first", "page_first_direct"]:
for index in range(0, len(indices), self.page_size):
k_ptr = (
kv_buffer_data_ptr
+ indices[index]
* self.layer_num
* (self.kv_lora_rank + self.qk_rope_head_dim)
* self.dtype.itemsize
)
ptr_list.append(k_ptr)
element_size = (
self.layer_num
* self.dtype.itemsize
* self.page_size
* (self.kv_lora_rank + self.qk_rope_head_dim)
)
ptr_list.append(k_ptr)
key_ = keys[index // self.page_size]
key_list.append(f"{key_}_k")
element_size = (
self.layer_num
* self.dtype.itemsize
* self.page_size
* (self.kv_lora_rank + self.qk_rope_head_dim)
)
element_size_list = [element_size] * len(key_list)
return key_list, ptr_list, element_size_list
def get_buffer_with_hash(self, keys, indices=None):
assert self.layout == "page_first"
assert indices is None or (len(keys) == (len(indices) // self.page_size))
buf_list = []
if indices is not None:
for i in range(len(keys)):
index = indices[i * self.page_size]
buf_list.append(self.kv_buffer[index : index + self.page_size])
return keys, buf_list, 1
element_size_list = [element_size] * len(ptr_list)
else:
raise ValueError(f"Unsupported layout: {self.layout}")
return ptr_list, element_size_list
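The MLA variant follows the same pattern but stores one compressed entry per token, so each page yields a single pointer and the entry width is kv_lora_rank + qk_rope_head_dim. A quick arithmetic check with made-up sizes:

# MLA per-page byte count with made-up sizes.
layer_num, kv_lora_rank, qk_rope_head_dim = 4, 512, 64
page_size, itemsize = 64, 2

entry_width = kv_lora_rank + qk_rope_head_dim    # 576
element_size = layer_num * itemsize * page_size * entry_width
assert element_size == 4 * 2 * 64 * 576          # 294912 bytes per page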
@@ -12,7 +12,12 @@ from typing import Any, List, Optional, Tuple
import torch
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
from sglang.srt.mem_cache.hicache_storage import (
HiCacheStorage,
HiCacheStorageConfig,
HiCacheStorageExtraInfo,
)
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsClient
from sglang.srt.metrics.collector import StorageMetrics
@@ -178,11 +183,14 @@ class HiCacheHF3FS(HiCacheStorage):
self.skip_backup = True
self.rank = 0
self.is_zero_copy = False
logger.info(
f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
f"file_path={self.file_path}, "
f"file_size={self.file_size / (2 ** 30):.2f} GB, "
f"num_pages={self.num_pages}"
f"num_pages={self.num_pages}, "
f"is_mla_model={self.is_mla_model}"
)
self.ac = AtomicCounter(self.numjobs)
@@ -323,25 +331,12 @@ class HiCacheHF3FS(HiCacheStorage):
use_mock_client=use_mock_client,
)
def get(
self,
key: str,
target_location: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> torch.Tensor | None:
return self.batch_get(
[key],
[target_location] if target_location is not None else None,
[target_sizes] if target_sizes is not None else None,
)[0]
@synchronized()
def batch_get(
def _batch_get(
self,
keys: List[str],
target_locations: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> List[torch.Tensor | None]:
values: List[torch.Tensor],
) -> List[bool]:
page_indices = self.metadata_client.get_page_indices(self.rank, keys)
batch_indices, file_offsets = [], []
@@ -350,15 +345,9 @@
batch_indices.append(i)
file_offsets.append(page_index * self.bytes_per_page)
if target_locations is not None:
for target_location in target_locations:
assert target_location.is_contiguous()
file_results = target_locations
else:
file_results = [
torch.empty(self.numel, dtype=self.dtype)
for _ in range(len(batch_indices))
]
for target_location in values:
assert target_location.is_contiguous()
file_results = values
start_time = time.perf_counter()
@@ -379,12 +368,10 @@
ionum / (end_time - start_time) * self.gb_per_page
)
results = [None] * len(keys)
for batch_index, file_result, read_result in zip(
batch_indices, file_results, read_results
):
results = [False] * len(keys)
for batch_index, read_result in zip(batch_indices, read_results):
if read_result == self.bytes_per_page:
results[batch_index] = file_result
results[batch_index] = True
else:
logger.error(
f"[Rank {self.rank}] HiCacheHF3FS get {keys[batch_index]} failed"
@@ -392,28 +379,12 @@
return results
def set(
self,
key: str,
value: Optional[Any] = None,
target_location: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> bool:
return self.batch_set(
[key],
[value] if value is not None else None,
[target_location] if target_location is not None else None,
[target_sizes] if target_sizes is not None else None,
)
@synchronized()
def batch_set(
def _batch_set(
self,
keys: List[str],
values: Optional[Any] = None,
target_locations: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> bool:
) -> List[bool]:
# In MLA backend, only one rank needs to backup the KV cache
if self.skip_backup:
return [True] * len(keys)
@@ -474,7 +445,7 @@
self.rank, written_keys_to_confirm, pages_to_release
)
return all(results)
return results
def delete(self, key: str) -> None:
self.metadata_client.delete_keys(self.rank, [key])
@@ -484,21 +455,25 @@
return result[0] if result else False
def batch_exists(self, keys: List[str]) -> int:
factor = 1
if self.is_zero_copy and not self.is_mla_model:
keys = self._get_mha_zero_copy_keys(keys)
factor = 2
results = self.metadata_client.exists(self.rank, keys)
for i in range(len(keys)):
if not results[i]:
return i
return len(keys)
i = 0
while i < len(keys) and results[i]:
i += 1
return i // factor
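With MHA zero-copy, every logical page expands into a -k and a -v key, so the existing-prefix length over the expanded list is divided by factor = 2 to recover whole pages. A small illustration:

# Three logical pages expand to six keys; a miss on page 1's V entry
# truncates the prefix at 3, and 3 // 2 = 1 full page is reported.
keys = ["p0-k", "p0-v", "p1-k", "p1-v", "p2-k", "p2-v"]
results = [True, True, True, False, True, True]

i = 0
while i < len(keys) and results[i]:
    i += 1
assert i // 2 == 1  # only p0 counts as fully present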
def clear(self) -> bool:
def clear(self) -> None:
try:
self.metadata_client.clear(self.rank)
logger.info(f"Cleared HiCacheHF3FS for rank {self.rank}")
return True
except Exception as e:
logger.error(f"Failed to clear HiCacheHF3FS: {e}")
return False
def close(self) -> None:
try:
@@ -521,3 +496,139 @@
self.prefetch_bandwidth.clear()
self.backup_bandwidth.clear()
return storage_metrics
def register_mem_pool_host(self, mem_pool_host: HostKVCache):
super().register_mem_pool_host(mem_pool_host)
self.is_zero_copy = self.mem_pool_host.layout == "page_first"
logger.info(f"{self.is_zero_copy=}")
def _get_mha_zero_copy_keys(self, keys: List[str]) -> List[str]:
_keys = []
for k in keys:
_keys.append(f"{k}-k")
_keys.append(f"{k}-v")
return _keys
def _get_mha_zero_copy_values(
self, values: List[torch.Tensor]
) -> List[torch.Tensor]:
_values = []
for value in values:
_values.append(value[0])
_values.append(value[1])
return _values
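For MHA pages fetched with get_data_page(..., flat=False), index 0 of the page tensor is assumed to hold K and index 1 to hold V, so keys and values fan out in lockstep. A standalone sketch:

import torch

# Stand-in page tensor: [kv, tokens, heads, head_dim]; index 0 is K, 1 is V
# (an assumption about the page_first page layout, for illustration only).
page = torch.zeros(2, 64, 8, 128)
keys, values = ["h0"], [page]

flat_keys, flat_values = [], []
for k, v in zip(keys, values):
    flat_keys += [f"{k}-k", f"{k}-v"]
    flat_values += [v[0], v[1]]

assert flat_keys == ["h0-k", "h0-v"] and len(flat_values) == 2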
def _batch_get_preprocess(self, keys, host_indices):
page_num = len(host_indices) // self.mem_pool_host.page_size
# host_indices to kv_buffer
flat = not self.is_zero_copy
values = (
[
self.mem_pool_host.get_data_page(host_indices[i * self.mem_pool_host.page_size], flat=flat)
for i in range(page_num)
]
if self.is_zero_copy
else [
self.mem_pool_host.get_dummy_flat_data_page() for _ in range(page_num)
]
)
if self.is_zero_copy and not self.is_mla_model:
keys = self._get_mha_zero_copy_keys(keys)
values = self._get_mha_zero_copy_values(values)
return keys, values
def _batch_get_postprocess(self, host_indices, values, results):
page_num = len(host_indices) // self.mem_pool_host.page_size
if self.is_zero_copy:
if not self.is_mla_model:
results = [
(results[2 * i] and results[2 * i + 1]) for i in range(page_num)
]
results = results[:page_num]
return results
for i in range(page_num):
if not results[i]:
break
self.mem_pool_host.set_from_flat_data_page(
host_indices[i * self.mem_pool_host.page_size], values[i]
)
return results
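In the MHA zero-copy case the backend reports one result per expanded key, so a page counts as retrieved only if both its -k and -v reads succeeded. The reduction in isolation:

# Per-key results (k0, v0, k1, v1, ...) collapse to per-page booleans.
results = [True, True, True, False]  # page 0 ok; page 1's V read failed
page_num = 2
per_page = [results[2 * i] and results[2 * i + 1] for i in range(page_num)]
assert per_page == [True, False]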
def batch_get_v1(
self,
keys: List[str],
host_indices: torch.Tensor,
extra_info: Optional[HiCacheStorageExtraInfo] = None,
) -> List[bool]:
keys, values = self._batch_get_preprocess(keys, host_indices)
results = self._batch_get(keys, values)
return self._batch_get_postprocess(host_indices, values, results)
def _batch_set_preprocess(self, keys, host_indices):
page_num = len(host_indices) // self.mem_pool_host.page_size
# host_indices to kv_buffer
flat = not self.is_zero_copy
values = [
self.mem_pool_host.get_data_page(host_indices[i * self.mem_pool_host.page_size], flat=flat)
for i in range(page_num)
]
if self.is_zero_copy and not self.is_mla_model:
keys = self._get_mha_zero_copy_keys(keys)
values = self._get_mha_zero_copy_values(values)
return keys, values
def batch_set_v1(
self,
keys: List[str],
host_indices: torch.Tensor,
extra_info: Optional[HiCacheStorageExtraInfo] = None,
) -> List[bool]:
keys, values = self._batch_set_preprocess(keys, host_indices)
results = self._batch_set(keys, values)
return results
# Deprecated
def get(
self,
key: str,
target_location: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> torch.Tensor | None:
pass
# Deprecated
def batch_get(
self,
keys: List[str],
target_locations: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> List[torch.Tensor | None] | int:
pass
# Deprecated
def set(
self,
key: str,
value: Optional[Any] = None,
target_location: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> bool:
pass
# Deprecated
def batch_set(
self,
keys: List[str],
values: Optional[Any] = None,
target_locations: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> bool:
pass
@@ -7,7 +7,12 @@ from typing import Any, List, Optional
import torch
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
from sglang.srt.mem_cache.hicache_storage import (
HiCacheStorage,
HiCacheStorageConfig,
HiCacheStorageExtraInfo,
)
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024 # 4 GiB
DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024 # 16 MB
@@ -183,7 +188,12 @@ class MooncakeStore(HiCacheStorage):
assert self.store.is_exist(warmup_key) == 1
assert self.store.get(warmup_key) == warmup_value
def register_buffer(self, buffer: torch.Tensor) -> None:
def register_mem_pool_host(self, mem_pool_host: HostKVCache):
super().register_mem_pool_host(mem_pool_host)
assert (
self.mem_pool_host.layout == "page_first"
), "mooncake store storage backend only support page first layout"
buffer = self.mem_pool_host.kv_buffer
try:
buffer_ptr = buffer.data_ptr()
buffer_size = buffer.numel() * buffer.element_size()
@@ -194,6 +204,97 @@
logger.error("Failed to register buffer to Mooncake Store: %s", err)
raise TypeError("Mooncake Store Register Buffer Error.") from err
def _get_mha_buffer_meta(self, keys, indices):
ptr_list, element_size_list = self.mem_pool_host.get_page_buffer_meta(indices)
key_list = []
for key_ in keys:
key_list.append(f"{key_}_{self.local_rank}_k")
key_list.append(f"{key_}_{self.local_rank}_v")
assert len(key_list) == len(ptr_list)
return key_list, ptr_list, element_size_list
def _get_mla_buffer_meta(self, keys, indices):
ptr_list, element_size_list = self.mem_pool_host.get_page_buffer_meta(indices)
key_list = []
for key_ in keys:
key_list.append(f"{key_}_k")
assert len(key_list) == len(ptr_list)
return key_list, ptr_list, element_size_list
def _batch_preprocess(self, keys, host_indices):
assert len(keys) > 0
assert len(keys) == len(host_indices) // self.mem_pool_host.page_size
if self.is_mla_backend:
return self._get_mla_buffer_meta(keys, host_indices)
else:
return self._get_mha_buffer_meta(keys, host_indices)
def _batch_postprocess(self, results: List[int], is_set_operate=False):
"""
Refer to https://github.com/kvcache-ai/Mooncake/blob/main/mooncake-store/include/pybind_client.h.
For batch_get_into, results is a vector of integers where each element is
the number of bytes read on success, or a negative value on error.
For batch_put_from, results is a vector of integers where each element is
0 on success, or a negative value on error.
"""
if self.is_mla_backend:
return [k_res == 0 if is_set_operate else k_res > 0 for k_res in results]
else:
kv_pairs = zip(results[::2], results[1::2])
return [
(
(k_res == 0 and v_res == 0)
if is_set_operate
else (k_res > 0 and v_res > 0)
)
for k_res, v_res in kv_pairs
]
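Per the docstring, get results are byte counts (positive on success) while set results are status codes (0 on success). A worked example of the non-MLA mapping, where (k, v) codes are interleaved:

# Worked example of _batch_postprocess for the MHA case; codes are made up.
get_results = [4096, 4096, 4096, -1]  # bytes read; page 1's V read failed
set_results = [0, 0, -2, 0]           # status codes; page 1's K write failed

def postprocess(results, is_set):
    pairs = zip(results[::2], results[1::2])
    return [(k == 0 and v == 0) if is_set else (k > 0 and v > 0) for k, v in pairs]

assert postprocess(get_results, is_set=False) == [True, False]
assert postprocess(set_results, is_set=True) == [True, False]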
def batch_get_v1(
self,
keys: List[str],
host_indices: torch.Tensor,
extra_info: Optional[HiCacheStorageExtraInfo] = None,
) -> List[bool]:
key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess(keys, host_indices)
get_results = self._get_batch_zero_copy_impl(
key_strs, buffer_ptrs, buffer_sizes
)
return self._batch_postprocess(get_results, is_set_operate=False)
def batch_set_v1(
self,
keys: List[str],
host_indices: torch.Tensor,
extra_info: Optional[HiCacheStorageExtraInfo] = None,
) -> List[bool]:
key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess(keys, host_indices)
exist_result = self._batch_exist(key_strs)
set_keys = []
set_buffer_ptrs = []
set_buffer_sizes = []
set_indices = []
set_results = [-1] * len(keys)
for i in range(len(keys)):
if exist_result[i] != 1:
set_keys.append(keys[i])
set_buffer_ptrs.append(buffer_ptrs[i])
set_buffer_sizes.append(buffer_sizes[i])
set_indices.append(i)
else:
set_results[i] = 0
# Only set non-existing keys to storage
if len(set_keys) > 0:
put_results = self._put_batch_zero_copy_impl(
set_keys, set_buffer_ptrs, set_buffer_sizes
)
for i in range(len(set_indices)):
set_results[set_indices[i]] = put_results[i]
return self._batch_postprocess(set_results, is_set_operate=True)
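batch_set_v1 probes existence first so pages already in the store are skipped and reported as successes (code 0), and only the filtered lists are written (see the fix above). The dedup flow in isolation, with stand-in exist/put callables:

# Dedup sketch with stand-in callables: exist returns 1 for keys already
# stored; put returns 0 per key on success.
def dedup_set(keys, exist, put):
    present = exist(keys)
    todo = [i for i, p in enumerate(present) if p != 1]
    results = [-1] * len(keys)
    for i in range(len(keys)):
        if i not in todo:
            results[i] = 0  # already stored counts as success
    if todo:
        put_codes = put([keys[i] for i in todo])
        for j, i in enumerate(todo):
            results[i] = put_codes[j]
    return results

assert dedup_set(["a", "b"], exist=lambda ks: [1, 0], put=lambda ks: [0]) == [0, 0]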
def set(
self,
key,
......