Unverified commit d4041a5e authored by pansicheng, committed by GitHub
parent 2f555c4c
@@ -289,8 +289,6 @@ class HiCacheController:
             )
             self.storage_backend = MooncakeStore(self.storage_config)
-            self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
-            assert self.mem_pool_host.layout == "page_first"
         elif storage_backend == "hf3fs":
             from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
                 HiCacheHF3FS,
@@ -313,6 +311,8 @@ class HiCacheController:
                 f"Unsupported storage backend: {storage_backend}"
             )
+        self.storage_backend.register_mem_pool_host(self.mem_pool_host)
+
         self.enable_storage = True
         # todo: threshold policy for prefetching
         self.prefetch_threshold = max(prefetch_threshold, self.page_size)
@@ -335,18 +335,10 @@ class HiCacheController:
         # Select the get and set functions
         self.page_get_func = self._generic_page_get
         self.page_set_func = self._generic_page_set
-        self.batch_exists_func = self.storage_backend.batch_exists
-        self.is_3fs_zerocopy = (
-            self.storage_backend_type == "hf3fs"
-            and self.mem_pool_host.layout == "page_first"
-        )
-        if self.storage_backend_type == "mooncake":
-            self.page_get_func = self._mooncake_page_get
-            self.page_set_func = self._mooncake_page_set
-        elif self.is_3fs_zerocopy:
-            self.page_get_func = self._3fs_zero_copy_page_get
-            self.page_set_func = self._3fs_zero_copy_page_set
-            self.batch_exists_func = self._3fs_zero_copy_batch_exists
+        if self.storage_backend_type in ["hf3fs", "mooncake"]:
+            self.page_get_func = self._page_get_zero_copy
+            self.page_set_func = self._page_set_zero_copy

         self.device = self.mem_pool_device.device
         self.layer_num = self.mem_pool_device.layer_num
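Taken together, the controller hunks replace per-backend wiring (a `register_buffer` call plus layout assert for Mooncake, a layout-dependent `is_3fs_zerocopy` flag for HF3FS) with one registration hook and one dispatch test. A condensed sketch of the resulting flow; the standalone function shape and the `from_env_config(...)` call are illustrative assumptions, not verbatim code from this commit:

```python
# Condensed, illustrative view of the new wiring (not a verbatim excerpt).
def _init_storage_backend(self, storage_backend: str):
    if storage_backend == "mooncake":
        self.storage_backend = MooncakeStore(self.storage_config)
    elif storage_backend == "hf3fs":
        self.storage_backend = HiCacheHF3FS.from_env_config(...)  # args elided here
    else:
        raise NotImplementedError(f"Unsupported storage backend: {storage_backend}")

    # One generic hook replaces register_buffer() + the layout assert;
    # each backend now validates its own layout inside register_mem_pool_host.
    self.storage_backend.register_mem_pool_host(self.mem_pool_host)

    # Both zero-copy-capable backends share a single handler pair that
    # delegates to batch_get_v1 / batch_set_v1.
    self.page_get_func = self._generic_page_get
    self.page_set_func = self._generic_page_set
    if self.storage_backend_type in ["hf3fs", "mooncake"]:
        self.page_get_func = self._page_get_zero_copy
        self.page_set_func = self._page_set_zero_copy
```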
@@ -630,42 +622,19 @@ class HiCacheController:
         for chunk in chunks:
             self.host_mem_release_queue.put(chunk)

-    def _3fs_zero_copy_batch_exists(self, batch_hashes):
-        _batch_hashes, _, factor = self.mem_pool_host.get_buffer_with_hash(batch_hashes)
-        hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor
-        return hit_page_num
-
-    def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
-        hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash(
-            hash_values, host_indices
-        )
-        page_data = self.storage_backend.batch_get(hashes, dsts)
-        if page_data:
-            inc = self.page_size * len(hashes) // factor
-            operation.increment(inc)
-        else:
-            logger.warning(
-                f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
-            )
-
-    def _mooncake_page_get(self, operation, hash_values, host_indices):
-        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
-            hash_values,
-            host_indices,
-            self.storage_config.tp_rank,
-        )
-        get_result = self.storage_backend.batch_get(
-            key_strs,
-            target_locations=buffer_ptrs,
-            target_sizes=buffer_sizes,
-        )
-        if get_result != len(hash_values):
-            logger.warning(
-                f"Prefetch operation {operation.request_id} failed or partially failed."
-            )
-        if get_result != 0:
-            operation.increment(get_result * self.page_size)
+    def _page_get_zero_copy(self, operation, hash_values, host_indices):
+        results = self.storage_backend.batch_get_v1(hash_values, host_indices)
+        inc = 0
+        for i in range(len(hash_values)):
+            if not results[i]:
+                logger.warning(
+                    f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
+                )
+                break
+            inc += self.page_size
+        operation.increment(inc)

+    # todo: deprecate
     def _generic_page_get(self, operation, hash_values, host_indices):
         dummy_page_dst = [
             self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
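`_page_get_zero_copy` credits only the contiguous prefix of successfully fetched pages, so a hit that lands after a miss is not counted. A toy check of that accounting, with a stand-in operation class (hypothetical, for illustration only):

```python
class _Op:  # stand-in for the controller's prefetch operation object
    def __init__(self):
        self.completed_tokens = 0

    def increment(self, num_tokens):
        self.completed_tokens += num_tokens

page_size = 64
results = [True, True, False, True]  # page 3 was fetched, but after a miss

op = _Op()
inc = 0
for ok in results:
    if not ok:
        break  # stop at the first failed page, as in the diff
    inc += page_size
op.increment(inc)

assert op.completed_tokens == 2 * page_size  # the trailing hit is discarded
```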
@@ -755,7 +724,7 @@ class HiCacheController:
                     batch_tokens[i : i + self.page_size], last_hash
                 )
                 batch_hashes.append(last_hash)
-            hit_page_num = self.batch_exists_func(batch_hashes)
+            hit_page_num = self.storage_backend.batch_exists(batch_hashes)
             hash_value.extend(batch_hashes[:hit_page_num])
             storage_query_count += hit_page_num * self.page_size
             if hit_page_num < len(batch_hashes):
@@ -824,34 +793,16 @@ class HiCacheController:
             self.backup_queue.put(operation)
             return operation.id

-    # non-zero copy
+    # todo: deprecate
     def _generic_page_set(self, hash_values, host_indices) -> bool:
         data = [
-            self.mem_pool_host.get_flat_data_page(host_indices[i * self.page_size])
+            self.mem_pool_host.get_data_page(host_indices[i * self.page_size])
             for i in range(len(hash_values))
         ]
         return self.storage_backend.batch_set(hash_values, data)

-    # zero copy
-    def _mooncake_page_set(self, hash_values, host_indices) -> bool:
-        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
-            hash_values,
-            host_indices,
-            self.storage_config.tp_rank,
-        )
-        success = self.storage_backend.batch_set(
-            key_strs,
-            target_locations=buffer_ptrs,
-            target_sizes=buffer_sizes,
-        )
-        return success
-
-    # zero copy
-    def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
-        hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash(
-            hash_values, host_indices
-        )
-        return self.storage_backend.batch_set(hashes, dsts)
+    def _page_set_zero_copy(self, hash_values, host_indices) -> bool:
+        return all(self.storage_backend.batch_set_v1(hash_values, host_indices))

     # Backup batch by batch
     def _page_backup(self, operation):
...
@@ -7,6 +7,8 @@ from typing import Any, List, Optional

 import torch

+from sglang.srt.mem_cache.memory_pool_host import HostKVCache
+
 logger = logging.getLogger(__name__)

@@ -32,15 +34,46 @@ class HiCacheStorageConfig:
     extra_config: Optional[dict] = None


+@dataclass
+class HiCacheStorageExtraInfo:
+    extra_info: Optional[dict] = None
+
+
 class HiCacheStorage(ABC):
     """
     HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
     It abstracts the underlying storage mechanism, allowing different implementations to be used.
     """

+    # todo, potentially pass model and TP configs into storage backend
     # todo, the page size of storage backend does not have to be the same as host memory pool

+    def register_mem_pool_host(self, mem_pool_host: HostKVCache):
+        self.mem_pool_host = mem_pool_host
+
+    def batch_get_v1(
+        self,
+        keys: List[str],
+        host_indices: torch.Tensor,
+        extra_info: Optional[HiCacheStorageExtraInfo] = None,
+    ) -> List[bool]:
+        """
+        Retrieve values for multiple keys.
+        Returns a list of booleans indicating whether each key was retrieved.
+        """
+        pass
+
+    def batch_set_v1(
+        self,
+        keys: List[str],
+        host_indices: torch.Tensor,
+        extra_info: Optional[HiCacheStorageExtraInfo] = None,
+    ) -> List[bool]:
+        """
+        Store values for multiple keys.
+        Returns a list of booleans indicating whether each key was stored.
+        """
+        pass
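For backend authors, a minimal in-memory backend can illustrate the v1 contract: one key per page, `host_indices` carrying `page_size` entries per page, and a per-key success list as the result. The `DictStorage` class and its dict store below are hypothetical; the legacy abstract methods are stubbed only to satisfy the ABC:

```python
from typing import List, Optional

import torch

class DictStorage(HiCacheStorage):  # hypothetical example backend
    def __init__(self):
        self.kv = {}

    def batch_get_v1(self, keys, host_indices, extra_info=None) -> List[bool]:
        results = []
        for i, key in enumerate(keys):
            page = self.kv.get(key)
            if page is not None:
                # copy the stored flat page back into the host pool
                self.mem_pool_host.set_from_flat_data_page(
                    host_indices[i * self.mem_pool_host.page_size], page
                )
            results.append(page is not None)
        return results

    def batch_set_v1(self, keys, host_indices, extra_info=None) -> List[bool]:
        for i, key in enumerate(keys):
            self.kv[key] = self.mem_pool_host.get_data_page(
                host_indices[i * self.mem_pool_host.page_size], flat=True
            ).clone()
        return [True] * len(keys)

    def exists(self, key) -> bool:
        return key in self.kv

    # Legacy interface from this diff, stubbed for brevity.
    def get(self, key, target_location=None, target_sizes=None):
        return self.kv.get(key)

    def set(self, key, value=None, target_location=None, target_sizes=None):
        self.kv[key] = value
        return True

    def batch_get(self, keys, target_locations=None, target_sizes=None):
        return [self.kv.get(k) for k in keys]

    def batch_set(self, keys, values=None, target_locations=None, target_sizes=None):
        for k, v in zip(keys, values or []):
            self.kv[k] = v
        return True
```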
     @abstractmethod
     def get(
         self,
@@ -54,6 +87,7 @@ class HiCacheStorage(ABC):
         """
         pass

+    # TODO: Deprecate
     @abstractmethod
     def batch_get(
         self,
@@ -81,6 +115,7 @@ class HiCacheStorage(ABC):
         """
         pass

+    # TODO: Deprecate
     @abstractmethod
     def batch_set(
         self,
@@ -103,6 +138,7 @@ class HiCacheStorage(ABC):
         """
         pass

+    # TODO: Use a finer-grained return type (e.g., List[bool])
     def batch_exists(self, keys: List[str]) -> int:
         """
         Check if the keys exist in the storage.
@@ -114,6 +150,9 @@ class HiCacheStorage(ABC):
                 return i
         return len(keys)

+    def clear(self) -> None:
+        pass
+
     def get_stats(self):
         return None
...
@@ -140,7 +140,7 @@ class HostKVCache(abc.ABC):
         raise NotImplementedError()

     @abc.abstractmethod
-    def get_flat_data_page(self, index) -> torch.Tensor:
+    def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
         """
         Get a flat data page from the host memory pool.
         """
@@ -461,16 +461,19 @@ class MHATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported IO backend: {io_backend}")

-    def get_flat_data_page(self, index) -> torch.Tensor:
+    def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
         if self.layout == "layer_first":
-            return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
+            data_page = self.kv_buffer[:, :, index : index + self.page_size, :, :]
         elif self.layout == "page_first":
-            return self.kv_buffer[:, index : index + self.page_size, :, :, :].flatten()
+            data_page = self.kv_buffer[:, index : index + self.page_size, :, :, :]
         elif self.layout == "page_first_direct":
             real_index = index // self.page_size
-            return self.kv_buffer[:, real_index : real_index + 1, :, :, :, :].flatten()
+            data_page = self.kv_buffer[:, real_index : real_index + 1, :, :, :, :]
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
+        if flat:
+            data_page = data_page.flatten()
+        return data_page
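A quick illustration of the new `flat` flag, using a stand-in `page_first` buffer with made-up dimensions. `flat=True` (the default) preserves the old `get_flat_data_page` behavior, while `flat=False` returns the sliced page for zero-copy consumers:

```python
import torch

# Stand-in for MHATokenToKVPoolHost.kv_buffer in page_first layout:
# (2, tokens, layer_num, head_num, head_dim), hypothetical sizes.
page_size, layer_num, head_num, head_dim, size = 4, 2, 8, 16, 64
kv_buffer = torch.zeros(2, size, layer_num, head_num, head_dim)

index = 8  # starting token index of a page
page = kv_buffer[:, index : index + page_size, :, :, :]  # flat=False: a view
flat = page.flatten()  # flat=True: flattened (copies when non-contiguous)

assert flat.numel() == 2 * page_size * layer_num * head_num * head_dim
```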
     def get_dummy_flat_data_page(self) -> torch.Tensor:
         return torch.zeros(
@@ -507,9 +510,12 @@ class MHATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")

-    def get_buffer_meta(self, keys, indices, local_rank):
+    def get_page_buffer_meta(self, indices):
+        """
+        Metadata for zero-copy transfers.
+        """
+        assert len(indices) % self.page_size == 0
         ptr_list = []
-        key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
         indices = indices.tolist()
         v_offset = (
@@ -519,48 +525,52 @@ class MHATokenToKVPoolHost(HostKVCache):
             * self.head_dim
             * self.dtype.itemsize
         )
-        for index in range(0, len(indices), self.page_size):
-            k_ptr = (
-                kv_buffer_data_ptr
-                + indices[index]
-                * self.layer_num
-                * self.head_num
-                * self.head_dim
-                * self.dtype.itemsize
-            )
-            v_ptr = k_ptr + v_offset
-            ptr_list.append(k_ptr)
-            ptr_list.append(v_ptr)
-            key_ = keys[index // self.page_size]
-            key_list.append(f"{key_}_{local_rank}_k")
-            key_list.append(f"{key_}_{local_rank}_v")
-        element_size = (
-            self.layer_num
-            * self.dtype.itemsize
-            * self.page_size
-            * self.head_num
-            * self.head_dim
-        )
-        element_size_list = [element_size] * len(key_list)
-        return key_list, ptr_list, element_size_list
-
-    def get_buffer_with_hash(self, keys, indices=None):
-        assert self.layout == "page_first"
-        assert indices is None or (len(keys) == (len(indices) // self.page_size))
-        key_list = []
-        buf_list = []
-        for i in range(len(keys)):
-            key = keys[i]
-            key_list.append(f"{key}-k")
-            key_list.append(f"{key}-v")
-            if indices is not None:
-                index = indices[i * self.page_size]
-                buf_list.append(self.k_buffer[index : index + self.page_size])
-                buf_list.append(self.v_buffer[index : index + self.page_size])
-        return key_list, buf_list, 2
+        if self.layout == "layer_first":
+            for index in range(0, len(indices), self.page_size):
+                for layer_id in range(self.layer_num):
+                    k_ptr = (
+                        kv_buffer_data_ptr
+                        + indices[index]
+                        * self.head_num
+                        * self.head_dim
+                        * self.dtype.itemsize
+                        + layer_id
+                        * self.size
+                        * self.head_num
+                        * self.head_dim
+                        * self.dtype.itemsize
+                    )
+                    v_ptr = k_ptr + v_offset
+                    ptr_list.append(k_ptr)
+                    ptr_list.append(v_ptr)
+            element_size = (
+                self.dtype.itemsize * self.page_size * self.head_num * self.head_dim
+            )
+            element_size_list = [element_size] * len(ptr_list)
+        elif self.layout in ["page_first", "page_first_direct"]:
+            for index in range(0, len(indices), self.page_size):
+                k_ptr = (
+                    kv_buffer_data_ptr
+                    + indices[index]
+                    * self.layer_num
+                    * self.head_num
+                    * self.head_dim
+                    * self.dtype.itemsize
+                )
+                v_ptr = k_ptr + v_offset
+                ptr_list.append(k_ptr)
+                ptr_list.append(v_ptr)
+            element_size = (
+                self.layer_num
+                * self.dtype.itemsize
+                * self.page_size
+                * self.head_num
+                * self.head_dim
+            )
+            element_size_list = [element_size] * len(ptr_list)
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        return ptr_list, element_size_list
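To sanity-check the `page_first` pointer math above: in a contiguous `(2, size, layer_num, head_num, head_dim)` buffer, token `t` begins `t * layer_num * head_num * head_dim * itemsize` bytes into the K half, the V half sits one full K-plane (`v_offset`) further, and one page spans `element_size` bytes. A numeric check with made-up dimensions:

```python
import torch

size, layer_num, head_num, head_dim, page_size = 32, 2, 4, 8, 4
kv = torch.zeros(2, size, layer_num, head_num, head_dim, dtype=torch.float16)
itemsize = kv.element_size()
base = kv.data_ptr()

token = 5
k_ptr_off = token * layer_num * head_num * head_dim * itemsize
v_offset = size * layer_num * head_num * head_dim * itemsize  # one full K plane
element_size = layer_num * itemsize * page_size * head_num * head_dim

assert base + k_ptr_off == kv[0, token].data_ptr()
assert base + k_ptr_off + v_offset == kv[1, token].data_ptr()
assert base + k_ptr_off + element_size == kv[0, token + page_size].data_ptr()
```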

 class MLATokenToKVPoolHost(HostKVCache):
@@ -736,16 +746,19 @@ class MLATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported IO backend: {io_backend}")

-    def get_flat_data_page(self, index) -> torch.Tensor:
+    def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
         if self.layout == "layer_first":
-            return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
+            data_page = self.kv_buffer[:, index : index + self.page_size, :, :]
         elif self.layout == "page_first":
-            return self.kv_buffer[index : index + self.page_size, :, :, :].flatten()
+            data_page = self.kv_buffer[index : index + self.page_size, :, :, :]
         elif self.layout == "page_first_direct":
             real_index = index // self.page_size
-            return self.kv_buffer[real_index : real_index + 1, :, :, :, :].flatten()
+            data_page = self.kv_buffer[real_index : real_index + 1, :, :, :, :]
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
+        if flat:
+            data_page = data_page.flatten()
+        return data_page

     def get_dummy_flat_data_page(self) -> torch.Tensor:
         return torch.zeros(
@@ -787,40 +800,51 @@ class MLATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")

-    def get_buffer_meta(self, keys, indices, local_rank):
+    def get_page_buffer_meta(self, indices):
+        """
+        Metadata for zero-copy transfers.
+        """
+        assert len(indices) % self.page_size == 0
         ptr_list = []
-        key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
         indices = indices.tolist()
-        for index in range(0, len(indices), self.page_size):
-            k_ptr = (
-                kv_buffer_data_ptr
-                + indices[index]
-                * self.layer_num
-                * (self.kv_lora_rank + self.qk_rope_head_dim)
-                * self.dtype.itemsize
-            )
-            ptr_list.append(k_ptr)
-            key_ = keys[index // self.page_size]
-            key_list.append(f"{key_}_k")
-        element_size = (
-            self.layer_num
-            * self.dtype.itemsize
-            * self.page_size
-            * (self.kv_lora_rank + self.qk_rope_head_dim)
-        )
-        element_size_list = [element_size] * len(key_list)
-        return key_list, ptr_list, element_size_list
-
-    def get_buffer_with_hash(self, keys, indices=None):
-        assert self.layout == "page_first"
-        assert indices is None or (len(keys) == (len(indices) // self.page_size))
-        buf_list = []
-        if indices is not None:
-            for i in range(len(keys)):
-                index = indices[i * self.page_size]
-                buf_list.append(self.kv_buffer[index : index + self.page_size])
-        return keys, buf_list, 1
+        if self.layout == "layer_first":
+            for index in range(0, len(indices), self.page_size):
+                for layer_id in range(self.layer_num):
+                    k_ptr = (
+                        kv_buffer_data_ptr
+                        + indices[index]
+                        * (self.kv_lora_rank + self.qk_rope_head_dim)
+                        * self.dtype.itemsize
+                        + layer_id
+                        * self.size
+                        * (self.kv_lora_rank + self.qk_rope_head_dim)
+                        * self.dtype.itemsize
+                    )
+                    ptr_list.append(k_ptr)
+            element_size = (
+                self.dtype.itemsize
+                * self.page_size
+                * (self.kv_lora_rank + self.qk_rope_head_dim)
+            )
+            element_size_list = [element_size] * len(ptr_list)
+        elif self.layout in ["page_first", "page_first_direct"]:
+            for index in range(0, len(indices), self.page_size):
+                k_ptr = (
+                    kv_buffer_data_ptr
+                    + indices[index]
+                    * self.layer_num
+                    * (self.kv_lora_rank + self.qk_rope_head_dim)
+                    * self.dtype.itemsize
+                )
+                ptr_list.append(k_ptr)
+            element_size = (
+                self.layer_num
+                * self.dtype.itemsize
+                * self.page_size
+                * (self.kv_lora_rank + self.qk_rope_head_dim)
+            )
+            element_size_list = [element_size] * len(ptr_list)
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        return ptr_list, element_size_list
@@ -12,7 +12,12 @@ from typing import Any, List, Optional, Tuple

 import torch

-from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
+from sglang.srt.mem_cache.hicache_storage import (
+    HiCacheStorage,
+    HiCacheStorageConfig,
+    HiCacheStorageExtraInfo,
+)
+from sglang.srt.mem_cache.memory_pool_host import HostKVCache
 from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsClient
 from sglang.srt.metrics.collector import StorageMetrics
@@ -178,11 +183,14 @@ class HiCacheHF3FS(HiCacheStorage):
             self.skip_backup = True
             self.rank = 0

+        self.is_zero_copy = False
+
         logger.info(
             f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
             f"file_path={self.file_path}, "
             f"file_size={self.file_size / (2 ** 30):.2f} GB, "
-            f"num_pages={self.num_pages}"
+            f"num_pages={self.num_pages}, "
+            f"is_mla_model={self.is_mla_model}"
         )

         self.ac = AtomicCounter(self.numjobs)
@@ -323,25 +331,12 @@ class HiCacheHF3FS(HiCacheStorage):
             use_mock_client=use_mock_client,
         )

-    def get(
-        self,
-        key: str,
-        target_location: Optional[Any] = None,
-        target_sizes: Optional[Any] = None,
-    ) -> torch.Tensor | None:
-        return self.batch_get(
-            [key],
-            [target_location] if target_location is not None else None,
-            [target_sizes] if target_sizes is not None else None,
-        )[0]
-
     @synchronized()
-    def batch_get(
+    def _batch_get(
         self,
         keys: List[str],
-        target_locations: Optional[Any] = None,
-        target_sizes: Optional[Any] = None,
-    ) -> List[torch.Tensor | None]:
+        values: List[torch.Tensor],
+    ) -> List[bool]:
         page_indices = self.metadata_client.get_page_indices(self.rank, keys)

         batch_indices, file_offsets = [], []
@@ -350,15 +345,9 @@ class HiCacheHF3FS(HiCacheStorage):
                 batch_indices.append(i)
                 file_offsets.append(page_index * self.bytes_per_page)

-        if target_locations is not None:
-            for target_location in target_locations:
-                assert target_location.is_contiguous()
-            file_results = target_locations
-        else:
-            file_results = [
-                torch.empty(self.numel, dtype=self.dtype)
-                for _ in range(len(batch_indices))
-            ]
+        for target_location in values:
+            assert target_location.is_contiguous()
+        file_results = values

         start_time = time.perf_counter()
@@ -379,12 +368,10 @@ class HiCacheHF3FS(HiCacheStorage):
             ionum / (end_time - start_time) * self.gb_per_page
         )

-        results = [None] * len(keys)
-        for batch_index, file_result, read_result in zip(
-            batch_indices, file_results, read_results
-        ):
+        results = [False] * len(keys)
+        for batch_index, read_result in zip(batch_indices, read_results):
             if read_result == self.bytes_per_page:
-                results[batch_index] = file_result
+                results[batch_index] = True
             else:
                 logger.error(
                     f"[Rank {self.rank}] HiCacheHF3FS get {keys[batch_index]} failed"
@@ -392,28 +379,12 @@ class HiCacheHF3FS(HiCacheStorage):

         return results

-    def set(
-        self,
-        key: str,
-        value: Optional[Any] = None,
-        target_location: Optional[Any] = None,
-        target_sizes: Optional[Any] = None,
-    ) -> bool:
-        return self.batch_set(
-            [key],
-            [value] if value is not None else None,
-            [target_location] if target_location is not None else None,
-            [target_sizes] if target_sizes is not None else None,
-        )
-
     @synchronized()
-    def batch_set(
+    def _batch_set(
         self,
         keys: List[str],
         values: Optional[Any] = None,
-        target_locations: Optional[Any] = None,
-        target_sizes: Optional[Any] = None,
-    ) -> bool:
+    ) -> List[bool]:
         # In MLA backend, only one rank needs to backup the KV cache
         if self.skip_backup:
             return True
@@ -474,7 +445,7 @@ class HiCacheHF3FS(HiCacheStorage):
             self.rank, written_keys_to_confirm, pages_to_release
         )

-        return all(results)
+        return results

     def delete(self, key: str) -> None:
         self.metadata_client.delete_keys(self.rank, [key])
@@ -484,21 +455,25 @@ class HiCacheHF3FS(HiCacheStorage):
         return result[0] if result else False

     def batch_exists(self, keys: List[str]) -> int:
+        factor = 1
+        if self.is_zero_copy and not self.is_mla_model:
+            keys = self._get_mha_zero_copy_keys(keys)
+            factor = 2
+
         results = self.metadata_client.exists(self.rank, keys)
-        for i in range(len(keys)):
-            if not results[i]:
-                return i
-        return len(keys)
+        i = 0
+        while i < len(keys) and results[i]:
+            i += 1
+        return i // factor

-    def clear(self) -> bool:
+    def clear(self) -> None:
         try:
             self.metadata_client.clear(self.rank)
             logger.info(f"Cleared HiCacheHF3FS for rank {self.rank}")
-            return True
         except Exception as e:
             logger.error(f"Failed to clear HiCacheHF3FS: {e}")
-            return False

     def close(self) -> None:
         try:
@@ -521,3 +496,139 @@ class HiCacheHF3FS(HiCacheStorage):
         self.prefetch_bandwidth.clear()
         self.backup_bandwidth.clear()
         return storage_metrics
+
+    def register_mem_pool_host(self, mem_pool_host: HostKVCache):
+        super().register_mem_pool_host(mem_pool_host)
+        self.is_zero_copy = self.mem_pool_host.layout == "page_first"
+        logger.info(f"{self.is_zero_copy=}")
+
+    def _get_mha_zero_copy_keys(self, keys: List[str]) -> List[str]:
+        _keys = []
+        for k in keys:
+            _keys.append(f"{k}-k")
+            _keys.append(f"{k}-v")
+        return _keys
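Each MHA page hash fans out into separate K and V entries under the `-k`/`-v` suffix convention, which is why `batch_exists` divides its hit count by `factor = 2`:

```python
keys = ["abc123", "def456"]
expanded = []
for k in keys:
    expanded += [f"{k}-k", f"{k}-v"]
assert expanded == ["abc123-k", "abc123-v", "def456-k", "def456-v"]
# A page only counts as a hit if both of its entries exist, so the raw
# hit count over `expanded` is divided by 2 to recover the page count.
```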
+    def _get_mha_zero_copy_values(
+        self, values: List[torch.Tensor]
+    ) -> List[torch.Tensor]:
+        _values = []
+        for value in values:
+            _values.append(value[0])
+            _values.append(value[1])
+        return _values
+
+    def _batch_get_preprocess(self, keys, host_indices):
+        page_num = len(host_indices) // self.mem_pool_host.page_size
+        # host_indices to kv_buffer
+        flat = not self.is_zero_copy
+        values = (
+            [
+                self.mem_pool_host.get_data_page(
+                    host_indices[i * self.mem_pool_host.page_size], flat=flat
+                )
+                for i in range(page_num)
+            ]
+            if self.is_zero_copy
+            else [
+                self.mem_pool_host.get_dummy_flat_data_page() for _ in range(page_num)
+            ]
+        )
+
+        if self.is_zero_copy and not self.is_mla_model:
+            keys = self._get_mha_zero_copy_keys(keys)
+            values = self._get_mha_zero_copy_values(values)
+
+        return keys, values
+
+    def _batch_get_postprocess(self, host_indices, values, results):
+        page_num = len(host_indices) // self.mem_pool_host.page_size
+
+        if self.is_zero_copy:
+            if not self.is_mla_model:
+                results = [
+                    (results[2 * i] and results[2 * i + 1]) for i in range(page_num)
+                ]
+            results = results[:page_num]
+            return results
+
+        for i in range(page_num):
+            if not results[i]:
+                break
+            self.mem_pool_host.set_from_flat_data_page(
+                host_indices[i * self.mem_pool_host.page_size], values[i]
+            )
+
+        return results
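In the MHA zero-copy path, `_batch_get` returns one result per tensor (K and V separately), and `_batch_get_postprocess` folds them back to one result per page:

```python
results = [True, True, True, False]  # k0, v0, k1, v1
page_num = 2
folded = [results[2 * i] and results[2 * i + 1] for i in range(page_num)]
assert folded == [True, False]  # page 1 fails because its V read failed
```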
+    def batch_get_v1(
+        self,
+        keys: List[str],
+        host_indices: torch.Tensor,
+        extra_info: Optional[HiCacheStorageExtraInfo] = None,
+    ) -> List[bool]:
+        keys, values = self._batch_get_preprocess(keys, host_indices)
+        results = self._batch_get(keys, values)
+        return self._batch_get_postprocess(host_indices, values, results)
+
+    def _batch_set_preprocess(self, keys, host_indices):
+        page_num = len(host_indices) // self.mem_pool_host.page_size
+        # host_indices to kv_buffer
+        flat = not self.is_zero_copy
+        values = [
+            self.mem_pool_host.get_data_page(
+                host_indices[i * self.mem_pool_host.page_size], flat=flat
+            )
+            for i in range(page_num)
+        ]
+
+        if self.is_zero_copy and not self.is_mla_model:
+            keys = self._get_mha_zero_copy_keys(keys)
+            values = self._get_mha_zero_copy_values(values)
+
+        return keys, values
+
+    def batch_set_v1(
+        self,
+        keys: List[str],
+        host_indices: torch.Tensor,
+        extra_info: Optional[HiCacheStorageExtraInfo] = None,
+    ) -> List[bool]:
+        keys, values = self._batch_set_preprocess(keys, host_indices)
+        results = self._batch_set(keys, values)
+        return results
+
+    # Deprecated
+    def get(
+        self,
+        key: str,
+        target_location: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
+    ) -> torch.Tensor | None:
+        pass
+
+    # Deprecated
+    def batch_get(
+        self,
+        keys: List[str],
+        target_locations: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
+    ) -> List[torch.Tensor | None] | int:
+        pass
+
+    # Deprecated
+    def set(
+        self,
+        key: str,
+        value: Optional[Any] = None,
+        target_location: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
+    ) -> bool:
+        pass
+
+    # Deprecated
+    def batch_set(
+        self,
+        keys: List[str],
+        values: Optional[Any] = None,
+        target_locations: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
+    ) -> bool:
+        pass
...
@@ -7,7 +7,12 @@ from typing import Any, List, Optional

 import torch

-from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
+from sglang.srt.mem_cache.hicache_storage import (
+    HiCacheStorage,
+    HiCacheStorageConfig,
+    HiCacheStorageExtraInfo,
+)
+from sglang.srt.mem_cache.memory_pool_host import HostKVCache

 DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024  # 4 GiB
 DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024  # 16 MB
@@ -183,7 +188,12 @@ class MooncakeStore(HiCacheStorage):
             assert self.store.is_exist(warmup_key) == 1
             assert self.store.get(warmup_key) == warmup_value

-    def register_buffer(self, buffer: torch.Tensor) -> None:
+    def register_mem_pool_host(self, mem_pool_host: HostKVCache):
+        super().register_mem_pool_host(mem_pool_host)
+        assert (
+            self.mem_pool_host.layout == "page_first"
+        ), "The Mooncake storage backend only supports the page_first layout"
+        buffer = self.mem_pool_host.kv_buffer
         try:
             buffer_ptr = buffer.data_ptr()
             buffer_size = buffer.numel() * buffer.element_size()
@@ -194,6 +204,97 @@ class MooncakeStore(HiCacheStorage):
             logger.error("Failed to register buffer to Mooncake Store: %s", err)
             raise TypeError("Mooncake Store Register Buffer Error.") from err
+    def _get_mha_buffer_meta(self, keys, indices):
+        ptr_list, element_size_list = self.mem_pool_host.get_page_buffer_meta(indices)
+        key_list = []
+        for key_ in keys:
+            key_list.append(f"{key_}_{self.local_rank}_k")
+            key_list.append(f"{key_}_{self.local_rank}_v")
+        assert len(key_list) == len(ptr_list)
+        return key_list, ptr_list, element_size_list
+
+    def _get_mla_buffer_meta(self, keys, indices):
+        ptr_list, element_size_list = self.mem_pool_host.get_page_buffer_meta(indices)
+        key_list = []
+        for key_ in keys:
+            key_list.append(f"{key_}_k")
+        assert len(key_list) == len(ptr_list)
+        return key_list, ptr_list, element_size_list
+
+    def _batch_preprocess(self, keys, host_indices):
+        assert len(keys) > 0
+        assert len(keys) == len(host_indices) // self.mem_pool_host.page_size
+        if self.is_mla_backend:
+            return self._get_mla_buffer_meta(keys, host_indices)
+        else:
+            return self._get_mha_buffer_meta(keys, host_indices)
+
+    def _batch_postprocess(self, results: List[int], is_set_operate=False):
+        """
+        Refer to https://github.com/kvcache-ai/Mooncake/blob/main/mooncake-store/include/pybind_client.h
+        For batch_get_into, results is a vector of integers where each element
+        is the number of bytes read on success, or a negative value on error.
+        For batch_put_from, results is a vector of integers where each element
+        is 0 on success, or a negative value on error.
+        """
+        if self.is_mla_backend:
+            return [k_res == 0 if is_set_operate else k_res > 0 for k_res in results]
+        else:
+            kv_pairs = zip(results[::2], results[1::2])
+            return [
+                (
+                    (k_res == 0 and v_res == 0)
+                    if is_set_operate
+                    else (k_res > 0 and v_res > 0)
+                )
+                for k_res, v_res in kv_pairs
+            ]
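Per the pybind client docs cited in the docstring, gets report bytes read and puts report 0 on success; a worked example of the MHA folding:

```python
# batch_get_into: positive byte counts mean success
get_results = [4096, 4096, 4096, -1]  # order: k0, v0, k1, v1
pages_ok = [k > 0 and v > 0 for k, v in zip(get_results[::2], get_results[1::2])]
assert pages_ok == [True, False]  # page 1 fails on its V read

# batch_put_from: 0 means success
set_results = [0, 0, -5, 0]
pages_ok = [k == 0 and v == 0 for k, v in zip(set_results[::2], set_results[1::2])]
assert pages_ok == [True, False]  # page 1 fails on its K write
```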
+    def batch_get_v1(
+        self,
+        keys: List[str],
+        host_indices: torch.Tensor,
+        extra_info: Optional[HiCacheStorageExtraInfo] = None,
+    ) -> List[bool]:
+        key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess(keys, host_indices)
+        get_results = self._get_batch_zero_copy_impl(key_strs, buffer_ptrs, buffer_sizes)
+        return self._batch_postprocess(get_results, is_set_operate=False)
+
+    def batch_set_v1(
+        self,
+        keys: List[str],
+        host_indices: torch.Tensor,
+        extra_info: Optional[HiCacheStorageExtraInfo] = None,
+    ) -> List[bool]:
+        key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess(keys, host_indices)
+        exist_result = self._batch_exist(key_strs)
+
+        set_keys = []
+        set_buffer_ptrs = []
+        set_buffer_sizes = []
+        set_indices = []
+        set_results = [-1] * len(key_strs)
+        for i in range(len(key_strs)):
+            if exist_result[i] != 1:
+                set_keys.append(key_strs[i])
+                set_buffer_ptrs.append(buffer_ptrs[i])
+                set_buffer_sizes.append(buffer_sizes[i])
+                set_indices.append(i)
+            else:
+                set_results[i] = 0
+
+        # Only set non-existing keys to storage
+        if len(set_keys) > 0:
+            put_results = self._put_batch_zero_copy_impl(
+                set_keys, set_buffer_ptrs, set_buffer_sizes
+            )
+            for i in range(len(set_indices)):
+                set_results[set_indices[i]] = put_results[i]
+
+        return self._batch_postprocess(set_results, is_set_operate=True)
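`batch_set_v1` pre-checks existence so that already-stored entries are reported as successes (result code 0) without being rewritten; only the misses are sent to the put path. A toy trace of that bookkeeping:

```python
exist_result = [1, 0, 1]      # 1 == key already stored
set_results = [-1] * 3
to_put = []
for i, exists in enumerate(exist_result):
    if exists != 1:
        to_put.append(i)      # only misses are written
    else:
        set_results[i] = 0    # existing keys short-circuit to "success"
assert to_put == [1] and set_results == [0, -1, 0]
# after the put, slot 1 receives the backend's return code
```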
     def set(
         self,
         key,
...