Unverified Commit 80dc76e1 authored by ykwd, committed by GitHub

[Fix] HiCache Bugfix & Mooncake Error Handling Enhance (#8901)


Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent 9b08d975
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache

 from sglang.srt.distributed import get_tensor_model_parallel_rank
-from sglang.srt.mem_cache.memory_pool_host import MLATokenToKVPoolHost
+from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool

 logger = logging.getLogger(__name__)
@@ -240,28 +240,38 @@ class HiCacheController:
         self.io_backend = io_backend

         self.enable_storage = False
-        self.is_mla = isinstance(self.mem_pool_host, MLATokenToKVPoolHost)
         # todo: move backend initialization to storage backend module
         if storage_backend is not None:
             self.storage_backend_type = storage_backend
-            from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
+            from sglang.srt.mem_cache.hicache_storage import get_hash_str

-            if storage_backend == "file":
-                self.storage_backend = HiCacheFile(is_mla=self.is_mla)
-                self.get_hash_str = get_hash_str
+            self.get_hash_str = get_hash_str
+            # Currently, AscendMLAPagedTokenToKVPool is a subclass of MLATokenToKVPool.
+            is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
+            # In the MLA backend, only one rank needs to back up the KV cache
+            self.backup_skip = (
+                is_mla_backend
+                # todo: for load balancing, decide which rank backs up the KV cache by hash value
+                and get_tensor_model_parallel_rank() != 0
+                # todo: support other storage backends
+                and self.storage_backend_type in ["file", "mooncake"]
+            )
+
+            if storage_backend == "file":
+                from sglang.srt.mem_cache.hicache_storage import HiCacheFile
+
+                self.storage_backend = HiCacheFile(is_mla_backend=is_mla_backend)
             elif storage_backend == "nixl":
                 from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl

                 self.storage_backend = HiCacheNixl()
-                self.get_hash_str = get_hash_str
             elif storage_backend == "mooncake":
                 from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
                     MooncakeStore,
-                    get_hash_str_mooncake,
                 )

-                self.storage_backend = MooncakeStore(is_mla=self.is_mla)
-                self.get_hash_str = get_hash_str_mooncake
+                self.storage_backend = MooncakeStore(is_mla_backend=is_mla_backend)
                 self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
                 assert self.mem_pool_host.layout == "page_first"
             elif storage_backend == "hf3fs":
@@ -281,7 +291,6 @@ class HiCacheController:
                 self.storage_backend = HiCacheHF3FS.from_env_config(
                     bytes_per_page, dtype
                 )
-                self.get_hash_str = get_hash_str
             else:
                 raise NotImplementedError(
                     f"Unsupported storage backend: {storage_backend}"
@@ -400,15 +409,6 @@ class HiCacheController:
         self.prefetch_thread.start()
         self.backup_thread.start()

-    @property
-    def backup_skip(self):
-        return (
-            self.is_mla
-            and get_tensor_model_parallel_rank() != 0
-            # todo: only support file and mooncake
-            and self.storage_backend_type in ["file", "mooncake"]
-        )
-
     def write(
         self,
         device_indices: torch.Tensor,
@@ -570,57 +570,91 @@ class HiCacheController:
         operation.mark_done()
         return operation.completed_tokens, operation.hash_value

-    def zerocopy_page_transfer(self, operation, batch_size=8):
-        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
-            operation.hash_value, operation.host_indices
-        )
-        for i in range(0, len(hashes), batch_size):
-            page_hashes = hashes[i : i + batch_size]
-            page_dsts = dsts[i : i + batch_size]
-            page_data = self.storage_backend.batch_get(page_hashes, page_dsts)
-            if page_data is None:
-                logger.warning(
-                    f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
-                )
-                break
-            completed_tokens = operation.completed_tokens
-            if operation.increment(self.page_size * len(page_hashes)):
-                for i in range(len(page_hashes)):
-                    completed_tokens += self.page_size
-            else:
-                break
+    # zero copy
+    def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
+        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
+            hash_values, host_indices
+        )
+        page_data = self.storage_backend.batch_get(hashes, dsts)
+        if page_data:
+            operation.increment(self.page_size * len(hashes))
+        else:
+            logger.warning(
+                f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
+            )

-    def generic_page_transfer(self, operation, batch_size=8):
-        for i in range(0, len(operation.hash_value), batch_size):
-            page_hashes = operation.hash_value[i : i + batch_size]
-            # todo: zero copy
-            dummy_page_dst = [
-                self.mem_pool_host.get_dummy_flat_data_page()
-                for _ in range(len(page_hashes))
-            ]
-            page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
-            if page_data is None:
-                logger.warning(
-                    f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
-                )
-                break
-            completed_tokens = operation.completed_tokens
-            if operation.increment(self.page_size * len(page_hashes)):
-                for i in range(len(page_hashes)):
-                    self.mem_pool_host.set_from_flat_data_page(
-                        operation.host_indices[completed_tokens],
-                        page_data[i],
-                    )
-                    completed_tokens += self.page_size
-            else:
-                break
+    # zero copy
+    def _mooncake_page_get(self, operation, hash_values, host_indices):
+        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
+            hash_values,
+            host_indices,
+        )
+        get_result = self.storage_backend.batch_get(
+            key_strs,
+            target_location=buffer_ptrs,
+            target_sizes=buffer_sizes,
+        )
+        if get_result != len(hash_values):
+            logger.warning(
+                f"Prefetch operation {operation.request_id} failed or partially failed."
+            )
+        if get_result != 0:
+            operation.increment(get_result * self.page_size)

-    def mooncake_page_transfer(self, operation):
-        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
-            operation.hash_value, operation.host_indices
-        )
-        self.storage_backend.batch_get(key_strs, buffer_ptrs, buffer_sizes)
-        operation.increment(len(operation.hash_value) * self.page_size)
+    # non-zero copy
+    def _generic_page_get(self, operation, hash_values, host_indices):
+        # todo: zero copy
+        dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
+            hash_values
+        )
+        page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst)
+        if page_data is None:
+            return
+        for i in range(len(hash_values)):
+            if page_data[i] is None:
+                logger.warning(
+                    f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
+                )
+                break
+            self.mem_pool_host.set_from_flat_data_page(
+                host_indices[operation.completed_tokens],
+                page_data[i],
+            )
+            if not operation.increment(self.page_size):
+                break  # Operation terminated by controller
+
+    def _page_transfer(self, operation):
+        # Select the get function and batch size
+        if self.is_mooncake_backend():
+            get_func = self._mooncake_page_get
+            batch_size = 128
+        elif self.storage_backend_type == "hf3fs":
+            if self.mem_pool_host.layout == "page_first":
+                get_func = self._3fs_zero_copy_page_get
+            elif self.mem_pool_host.layout == "layer_first":
+                get_func = self._generic_page_get
+            batch_size = 128
+        else:
+            get_func = self._generic_page_get
+            batch_size = 8
+        # Transfer batch by batch
+        for i in range(0, len(operation.hash_value), batch_size):
+            batch_hashes = operation.hash_value[i : i + batch_size]
+            batch_host_indices = operation.host_indices[
+                i * self.page_size : (i + len(batch_hashes)) * self.page_size
+            ]
+            prev_completed_tokens = operation.completed_tokens
+            # Fetch one batch of pages; get_func advances completed_tokens on success
+            get_func(operation, batch_hashes, batch_host_indices)
+            # Check termination
+            if (
+                operation.completed_tokens
+                != prev_completed_tokens + len(batch_hashes) * self.page_size
+            ):
+                break  # Some pages failed, or the operation was terminated by the controller
+        # Release pre-allocated host memory that was not filled
+        self.mem_pool_host.free(operation.host_indices[operation.completed_tokens :])

     def is_mooncake_backend(self):
         return self.storage_backend_type == "mooncake"
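A quick worked sketch of the index arithmetic in `_page_transfer` (toy numbers, not from the change): each batch of page hashes is paired with a page-aligned slice of host indices, so every hash maps to exactly `page_size` host slots.

```python
# Hypothetical illustration of the hash/index pairing in _page_transfer.
page_size = 4
batch_size = 2
hash_value = ["h0", "h1", "h2", "h3", "h4"]               # 5 pages
host_indices = list(range(len(hash_value) * page_size))   # 20 token slots

for i in range(0, len(hash_value), batch_size):
    batch_hashes = hash_value[i : i + batch_size]
    batch_host_indices = host_indices[
        i * page_size : (i + len(batch_hashes)) * page_size
    ]
    # every page hash maps to exactly page_size host slots
    assert len(batch_host_indices) == len(batch_hashes) * page_size
```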
@@ -632,15 +666,7 @@ class HiCacheController:
         while not self.stop_event.is_set():
             try:
                 operation = self.prefetch_buffer.get(block=True, timeout=1)
-                if self.is_mooncake_backend():
-                    self.mooncake_page_transfer(operation)
-                elif self.storage_backend_type == "hf3fs":
-                    if self.mem_pool_host.layout == "page_first":
-                        self.zerocopy_page_transfer(operation, batch_size=128)
-                    elif self.mem_pool_host.layout == "layer_first":
-                        self.generic_page_transfer(operation, batch_size=128)
-                else:
-                    self.generic_page_transfer(operation)
+                self._page_transfer(operation)

                 if self.tp_world_size > 1:
                     # to ensure all TP workers release the host memory at the same time
@@ -662,6 +688,27 @@ class HiCacheController:
             # todo: more sophisticated rate limiting based on storage backend performance
             return True

+    def _generic_storage_hit_query(self, operation) -> tuple[list[str], int]:
+        last_hash = operation.last_hash
+        tokens_to_fetch = operation.token_ids
+
+        storage_query_count = 0
+        remaining_tokens = len(tokens_to_fetch)
+        hash_value = []
+        while remaining_tokens >= self.page_size:
+            last_hash = self.get_hash_str(
+                tokens_to_fetch[
+                    storage_query_count : storage_query_count + self.page_size
+                ],
+                last_hash,
+            )
+            hash_value.append(last_hash)
+            storage_query_count += self.page_size
+            remaining_tokens -= self.page_size
+        # defer the existence check to a single batch_exists call
+        hit_page_num = self.storage_backend.batch_exists(hash_value)
+        return hash_value[:hit_page_num], hit_page_num * self.page_size
+
     def prefetch_thread_func(self):
         """
         Manage prefetching operations from storage backend to host memory.
@@ -675,37 +722,11 @@ class HiCacheController:
                 if operation is None:
                     continue

                 storage_hit_count = 0
                 if (
                     operation.host_indices is not None
                 ) and self.prefetch_rate_limit_check():
-                    last_hash = operation.last_hash
-                    tokens_to_fetch = operation.token_ids
-
-                    remaining_tokens = len(tokens_to_fetch)
-                    hash_value = []
-                    while remaining_tokens >= self.page_size:
-                        last_hash = self.get_hash_str(
-                            tokens_to_fetch[
-                                storage_hit_count : storage_hit_count + self.page_size
-                            ],
-                            last_hash,
-                        )
-                        # todo, more unified interface
-                        if not self.is_mooncake_backend():
-                            if not self.storage_backend.exists(last_hash):
-                                break
-                        hash_value.append(last_hash)
-                        storage_hit_count += self.page_size
-                        remaining_tokens -= self.page_size
-
-                    if self.is_mooncake_backend():
-                        # deferring to batch exists for mooncake store
-                        exist_result = self.storage_backend.exists(hash_value)
-                        storage_hit_count = (
-                            sum(1 for v in exist_result.values() if v != 0)
-                            * self.page_size
-                        )
+                    hash_value, storage_hit_count = self._generic_storage_hit_query(
+                        operation
+                    )

                 if self.tp_world_size > 1:
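For context, a minimal sketch of the chained page-hash scheme that `_generic_storage_hit_query` relies on, assuming a toy `get_hash_str` (the real one lives in `sglang.srt.mem_cache.hicache_storage`): each page's key commits to all prior pages, so a prefix count from `batch_exists` translates directly into a token hit count.

```python
import hashlib

# Toy stand-in for sglang's get_hash_str: hash of prior hash + current page tokens.
def get_hash_str(token_ids, prior_hash=None):
    h = hashlib.sha256()
    if prior_hash:
        h.update(prior_hash.encode())
    h.update(",".join(map(str, token_ids)).encode())
    return h.hexdigest()

page_size = 2
tokens = [10, 11, 12, 13, 14]  # 2 full pages + 1 leftover token (never hashed)
hashes, last = [], None
for i in range(0, len(tokens) - len(tokens) % page_size, page_size):
    last = get_hash_str(tokens[i : i + page_size], last)
    hashes.append(last)

# Suppose batch_exists reports only the first page exists:
hit_page_num = 1
print(hashes[:hit_page_num], hit_page_num * page_size)  # 1 page -> 2 hit tokens
```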
@@ -755,59 +776,64 @@ class HiCacheController:
             self.backup_queue.put(operation)
             return operation.id

-    def zerocopy_page_backup(self, operation, batch_size=8):
-        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
-            operation.hash_value, operation.host_indices
-        )
-        for i in range(0, len(hashes), batch_size):
-            page_hashes = hashes[i : i + batch_size]
-            page_data = dsts[i : i + batch_size]
-            success = self.storage_backend.batch_set(page_hashes, page_data)
-            if not success:
-                logger.warning(f"Failed to write page {page_hashes} to storage.")
-                break
-            operation.completed_tokens += self.page_size * len(page_hashes)
+    # non-zero copy
+    def _generic_page_set(self, hash_values, host_indices) -> bool:
+        data = [
+            self.mem_pool_host.get_flat_data_page(host_indices[i * self.page_size])
+            for i in range(len(hash_values))
+        ]
+        return self.storage_backend.batch_set(hash_values, data)

-    def generic_page_backup(self, operation, batch_size=8):
-        for i in range(0, len(operation.hash_value), batch_size):
-            page_hashes = operation.hash_value[i : i + batch_size]
-            page_data = [
-                self.mem_pool_host.get_flat_data_page(
-                    operation.host_indices[j * self.page_size]
-                )
-                for j in range(i, i + len(page_hashes))
-            ]
-            success = self.storage_backend.batch_set(page_hashes, page_data)
-            if not success:
-                logger.warning(f"Failed to write page {page_hashes} to storage.")
-                break
-            operation.completed_tokens += self.page_size * len(page_hashes)
+    # zero copy
+    def _mooncake_page_set(self, hash_values, host_indices) -> bool:
+        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
+            hash_values,
+            host_indices,
+        )
+        success = self.storage_backend.batch_set(
+            key_strs,
+            target_location=buffer_ptrs,
+            target_sizes=buffer_sizes,
+        )
+        return success

-    def mooncake_page_backup(self, operation):
-        if len(operation.hash_value):
-            exist_hashvalues = self.storage_backend.exists(operation.hash_value)
-            indices = operation.host_indices.tolist()
-            non_exist_keys = []
-            non_exist_indices = []
-            for i in range(len(operation.hash_value)):
-                if not exist_hashvalues[operation.hash_value[i]]:
-                    non_exist_keys.append(operation.hash_value[i])
-                    non_exist_indices.extend(
-                        indices[i * self.page_size : (i + 1) * self.page_size]
-                    )
-            if len(non_exist_keys) > 0:
-                key_strs, buffer_ptrs, buffer_sizes = (
-                    self.mem_pool_host.get_buffer_meta(
-                        non_exist_keys, non_exist_indices
-                    )
-                )
-                # TODO: check the return value of batch set to see how many tokens are set successfully
-                self.storage_backend.batch_set(
-                    key_strs,
-                    target_location=buffer_ptrs,
-                    target_sizes=buffer_sizes,
-                )
-            operation.completed_tokens += len(operation.hash_value) * self.page_size
+    # zero copy
+    def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
+        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
+            hash_values, host_indices
+        )
+        return self.storage_backend.batch_set(hashes, dsts)
+
+    # Backup batch by batch
+    def _page_backup(self, operation):
+        # Select the set function and batch size
+        if self.is_mooncake_backend():
+            backup_set_func = self._mooncake_page_set
+            batch_size = 128
+        elif self.storage_backend_type == "hf3fs":
+            if self.mem_pool_host.layout == "page_first":
+                backup_set_func = self._3fs_zero_copy_page_set
+            elif self.mem_pool_host.layout == "layer_first":
+                backup_set_func = self._generic_page_set
+            batch_size = 128
+        else:
+            backup_set_func = self._generic_page_set
+            batch_size = 8
+        # Backup batch by batch
+        for i in range(0, len(operation.hash_value), batch_size):
+            batch_hashes = operation.hash_value[i : i + batch_size]
+            batch_host_indices = operation.host_indices[
+                i * self.page_size : (i + len(batch_hashes)) * self.page_size
+            ]
+            # Write one batch of pages and record whether it succeeded.
+            # todo: allow partial success
+            success = backup_set_func(batch_hashes, batch_host_indices)
+            if not success:
+                logger.warning(
+                    f"Failed to write a batch of {len(batch_hashes)} pages to storage."
+                )
+                break
+            operation.completed_tokens += self.page_size * len(batch_hashes)

     def backup_thread_func(self):
         """
@@ -820,15 +846,7 @@ class HiCacheController:
                     continue

                 if not self.backup_skip:
-                    if self.is_mooncake_backend():
-                        self.mooncake_page_backup(operation)
-                    elif self.storage_backend_type == "hf3fs":
-                        if self.mem_pool_host.layout == "page_first":
-                            self.zerocopy_page_backup(operation, batch_size=128)
-                        elif self.mem_pool_host.layout == "layer_first":
-                            self.generic_page_backup(operation, batch_size=128)
-                    else:
-                        self.generic_page_backup(operation)
+                    self._page_backup(operation)
                     min_completed_tokens = operation.completed_tokens
                 else:
                     min_completed_tokens = len(operation.token_ids)
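A small worked example (hypothetical numbers) of how `completed_tokens` advances in the batch-by-batch backup loop: a failed batch stops the loop, so `completed_tokens` always counts a page-aligned prefix of the operation.

```python
# Hypothetical illustration of the prefix accounting in _page_backup.
page_size, batch_size = 4, 2
hashes = ["h0", "h1", "h2", "h3", "h4"]
batch_results = [True, False, True]  # pretend the second batch_set fails

completed_tokens = 0
for batch_idx, i in enumerate(range(0, len(hashes), batch_size)):
    batch = hashes[i : i + batch_size]
    if not batch_results[batch_idx]:
        break  # stop at the first failed batch; later pages are not counted
    completed_tokens += page_size * len(batch)

assert completed_tokens == 8  # only the first batch (2 pages x 4 tokens) counts
```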
......
@@ -60,7 +60,7 @@ class HiCacheStorage(ABC):
         keys: List[str],
         target_locations: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
-    ) -> List[torch.Tensor | None]:
+    ) -> List[torch.Tensor | None] | int:
         """
         Retrieve values for multiple keys.
         Returns a list of tensors or None for each key.
...@@ -96,17 +96,28 @@ class HiCacheStorage(ABC): ...@@ -96,17 +96,28 @@ class HiCacheStorage(ABC):
pass pass
@abstractmethod @abstractmethod
def exists(self, key: str) -> bool | dict: def exists(self, key: str) -> bool:
""" """
Check if the key exists in the storage. Check if the key exists in the storage.
Returns True if the key exists, False otherwise. Returns True if the key exists, False otherwise.
""" """
pass pass
def batch_exists(self, keys: List[str]) -> int:
"""
Check if the keys exist in the storage.
return the number of consecutive existing keys from the start.
Can be overridden by subclasses for more efficient implementation.
"""
for i in range(len(keys)):
if not self.exists(keys[i]):
return i
return len(keys)
class HiCacheFile(HiCacheStorage): class HiCacheFile(HiCacheStorage):
def __init__(self, file_path: str = "/tmp/hicache", is_mla: bool = False): def __init__(self, file_path: str = "/tmp/hicache", is_mla_backend: bool = False):
self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path) self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
if is_dp_attention_enabled(): if is_dp_attention_enabled():
tp_rank = get_attention_tp_rank() tp_rank = get_attention_tp_rank()
@@ -115,7 +126,9 @@ class HiCacheFile(HiCacheStorage):
             tp_rank = get_tensor_model_parallel_rank()
             tp_size = get_tensor_model_parallel_world_size()

-        self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""
+        self.tp_suffix = (
+            f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla_backend else ""
+        )
         if not os.path.exists(self.file_path) and tp_rank == 0:
             os.makedirs(self.file_path)
             logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
......
@@ -465,6 +465,7 @@ class MHATokenToKVPoolHost(HostKVCache):
             raise ValueError(f"Unsupported layout: {self.layout}")

     def get_buffer_meta(self, keys, indices):
+        local_rank = get_tensor_model_parallel_rank()
         ptr_list = []
         key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
@@ -488,8 +489,8 @@ class MHATokenToKVPoolHost(HostKVCache):
             ptr_list.append(k_ptr)
             ptr_list.append(v_ptr)
             key_ = keys[index // self.page_size]
-            key_list.append(f"{key_}_{get_tensor_model_parallel_rank()}_k")
-            key_list.append(f"{key_}_{get_tensor_model_parallel_rank()}_v")
+            key_list.append(f"{key_}_{local_rank}_k")
+            key_list.append(f"{key_}_{local_rank}_v")
         element_size = (
             self.layer_num
             * self.dtype.itemsize
@@ -704,6 +705,7 @@ class MLATokenToKVPoolHost(HostKVCache):
             raise ValueError(f"Unsupported layout: {self.layout}")

     def get_buffer_meta(self, keys, indices):
+        local_rank = get_tensor_model_parallel_rank()
         ptr_list = []
         key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
@@ -717,7 +719,7 @@ class MLATokenToKVPoolHost(HostKVCache):
             )
             ptr_list.append(k_ptr)
             key_ = keys[index // self.page_size]
-            key_list.append(f"{key_}_k")
+            key_list.append(f"{key_}_{local_rank}_k")
         element_size = (
             self.layer_num
             * self.dtype.itemsize
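For clarity, a sketch of the per-page storage key naming that `get_buffer_meta` produces after these hunks (toy values; real page hashes come from the hash chain): the MHA pool emits one `_k` and one `_v` key per page, each qualified by the TP rank, while the MLA pool emits a single rank-qualified `_k` key.

```python
# Hypothetical illustration of the per-page key layout in get_buffer_meta.
def page_keys(page_hash: str, local_rank: int, is_mla_pool: bool) -> list:
    if is_mla_pool:
        return [f"{page_hash}_{local_rank}_k"]  # MLA: combined KV, one key per page
    return [f"{page_hash}_{local_rank}_k", f"{page_hash}_{local_rank}_v"]

print(page_keys("abc123", 0, is_mla_pool=False))  # ['abc123_0_k', 'abc123_0_v']
print(page_keys("abc123", 0, is_mla_pool=True))   # ['abc123_0_k']
```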
......
@@ -55,12 +55,11 @@ Launch Mooncake meta server:
 python -m mooncake.http_metadata_server
 ```

-Start the SGLang server with Mooncake enabled. Mooncake configuration can be provided via environment variables:
+Start the SGLang server with Mooncake enabled. Mooncake configuration can be provided via environment variables. Note that, for optimal performance, the Mooncake backend currently supports only the `page_first` layout.

 ```bash
 MOONCAKE_TE_META_DATA_SERVER="http://127.0.0.1:8080/metadata" \
 MOONCAKE_GLOBAL_SEGMENT_SIZE=4294967296 \
-MOONCAKE_LOCAL_BUFFER_SIZE=134217728 \
 MOONCAKE_PROTOCOL="rdma" \
 MOONCAKE_DEVICE="erdma_0,erdma_1" \
 MOONCAKE_MASTER=127.0.0.1:50051 \
......
@@ -13,21 +13,11 @@ from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.mem_cache.hicache_storage import HiCacheStorage

 DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024  # 4 GiB
-DEFAULT_LOCAL_BUFFER_SIZE = 128 * 1024 * 1024  # 128 MB
+DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024  # 16 MB

 logger = logging.getLogger(__name__)

-
-def get_hash_str_mooncake(token_ids: List[int], prior_hash: str = None):
-    prefix_str = ""
-    if prior_hash:
-        prefix_str = hashlib.sha256(prior_hash.encode()).hexdigest()
-    current_token_ids_bytes = np.array(token_ids).tobytes()
-    current_hash_object = hashlib.sha256(current_token_ids_bytes)
-    current_hash_hex = current_hash_object.hexdigest()
-    return f"{prefix_str}_{int(current_hash_hex[:16], 16)}"
-

 @dataclass
 class MooncakeStoreConfig:
     local_hostname: str
@@ -54,9 +44,8 @@ class MooncakeStoreConfig:
             global_segment_size=config.get(
                 "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
            ),
-            local_buffer_size=config.get(
-                "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE
-            ),
+            # The zero-copy interface does not need a local buffer
+            local_buffer_size=DEFAULT_LOCAL_BUFFER_SIZE,
             protocol=config.get("protocol", "tcp"),
             device_name=config.get("device_name", "auto"),
             master_server_address=config.get("master_server_address"),
@@ -79,9 +68,8 @@ class MooncakeStoreConfig:
             global_segment_size=int(
                 os.getenv("MOONCAKE_GLOBAL_SEGMENT_SIZE", DEFAULT_GLOBAL_SEGMENT_SIZE)
             ),
-            local_buffer_size=int(
-                os.getenv("MOONCAKE_LOCAL_BUFFER_SIZE", DEFAULT_LOCAL_BUFFER_SIZE)
-            ),
+            # The zero-copy interface does not need a local buffer
+            local_buffer_size=DEFAULT_LOCAL_BUFFER_SIZE,
             protocol=os.getenv("MOONCAKE_PROTOCOL", "tcp"),
             device_name=os.getenv("MOONCAKE_DEVICE", "auto"),
             master_server_address=os.getenv("MOONCAKE_MASTER"),
@@ -96,7 +84,15 @@ class MooncakeStoreConfig:

 class MooncakeStore(HiCacheStorage):
-    def __init__(self, is_mla: bool = False):
+    def __init__(self, is_mla_backend: bool = False):
+        """
+        Initialize the MooncakeStore.
+
+        Args:
+            is_mla_backend: True if the KV cache backend is MLA.
+        """
+        self.is_mla_backend = is_mla_backend
         try:
             from mooncake.store import MooncakeDistributedStore
         except ImportError as e:
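A hypothetical wiring sketch for the renamed flag, mirroring the controller change above (`mem_pool_device` / `mem_pool_host` stand in for the controller's device and host KV pools):

```python
# Hypothetical usage sketch; not code from this change.
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import MooncakeStore

def make_mooncake_store(mem_pool_device, mem_pool_host) -> MooncakeStore:
    # The flag is now derived from the *device* pool type, not the host pool.
    is_mla_backend = isinstance(mem_pool_device, MLATokenToKVPool)
    store = MooncakeStore(is_mla_backend=is_mla_backend)
    # Zero-copy get/put requires the host KV buffer to be registered first.
    store.register_buffer(mem_pool_host.kv_buffer)
    return store
```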
@@ -126,7 +122,6 @@ class MooncakeStore(HiCacheStorage):
             logger.info("Connect to Mooncake store successfully.")
             self.warmup()
             logger.info("Mooncake store warmup successfully.")
-            self.is_mla = is_mla

         except ValueError as e:
             logger.error("Configuration loading failed: %s", e)
@@ -135,14 +130,14 @@ class MooncakeStore(HiCacheStorage):
             logger.error("An error occurred while loading the configuration: %s", exc)
             raise

+        self.local_rank = get_tensor_model_parallel_rank()
+
     def warmup(self):
         warmup_key = "sglang_mooncake_store_warmup_key" + uuid.uuid4().hex
-        # 10 MB
-        warmup_value = bytes(10 * 1024 * 1024)
-        self.store.put(warmup_key, warmup_value)
+        warmup_value = bytes(4 * 1024)  # 4 KB
+        assert self.store.put(warmup_key, warmup_value) == 0
         assert self.store.is_exist(warmup_key) == 1
-        self.store.get(warmup_key)
+        assert self.store.get(warmup_key) == warmup_value
+        self.store.remove(warmup_key)

     def register_buffer(self, buffer: torch.Tensor) -> None:
         try:
@@ -162,78 +157,95 @@ class MooncakeStore(HiCacheStorage):
         target_location: Optional[List[int]] = None,
         target_sizes: Optional[List[int]] = None,
     ) -> bool:
-        assert len(key) == len(target_location) == len(target_sizes)
-        if len(key) == 0:
-            return
-
-        for i in range(len(key)):
-            if key[i] is None or target_location[i] is None or target_sizes[i] is None:
-                return
-
-        self._put_batch_zero_copy_impl(key, target_location, target_sizes)
+        return self.batch_set([key], [value], [target_location], [target_sizes])

     def batch_set(
         self,
         keys: List[str],
+        value: Optional[Any] = None,
         target_location: Optional[List[int]] = None,
         target_sizes: Optional[List[int]] = None,
     ) -> bool:
         assert len(keys) == len(target_location) == len(target_sizes)
         if len(keys) == 0:
-            return
+            return False

         for i in range(len(keys)):
             if keys[i] is None or target_location[i] is None or target_sizes[i] is None:
-                return
+                return False

-        self._put_batch_zero_copy_impl(keys, target_location, target_sizes)
+        exist_result = self._batch_exist(keys)
+        set_keys = []
+        set_target_locations = []
+        set_target_sizes = []
+        set_indices = []
+        for i in range(len(keys)):
+            if exist_result[i] != 1:
+                set_keys.append(keys[i])
+                set_target_locations.append(target_location[i])
+                set_target_sizes.append(target_sizes[i])
+                set_indices.append(i)
+        # Only set non-existing keys to storage
+        put_result = self._put_batch_zero_copy_impl(
+            set_keys, set_target_locations, set_target_sizes
+        )
+        for i in range(len(set_indices)):
+            if put_result[i] == 0:
+                exist_result[set_indices[i]] = 1
+
+        success_count = 0
+        for i in range(len(keys)):
+            if exist_result[i] == 0:
+                break
+            success_count += 1
+        # TODO: return the number of consecutive successful operations from the start.
+        return success_count == len(keys)

     def get(
         self,
         key,
         target_location: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
-    ) -> torch.Tensor | None:
-        assert len(key) == len(target_location) == len(target_sizes)
-        if len(key) == 0:
-            return
-
-        for i in range(len(key)):
-            if key[i] is None or target_location[i] is None or target_sizes[i] is None:
-                return
-
-        return self._get_batch_zero_copy_impl(key, target_location, target_sizes)
+    ) -> bool:
+        return self.batch_get([key], [target_location], [target_sizes]) == 1

     def batch_get(
         self,
         keys: List[str],
         target_location: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
-    ) -> torch.Tensor | None:
+    ) -> int:
         assert len(keys) == len(target_location) == len(target_sizes)
         if len(keys) == 0:
-            return
-
-        for i in range(len(keys)):
-            if keys[i] is None or target_location[i] is None or target_sizes[i] is None:
-                return
-
-        return self._get_batch_zero_copy_impl(keys, target_location, target_sizes)
+            return 0
+        get_result = self._get_batch_zero_copy_impl(keys, target_location, target_sizes)
+        if self.is_mla_backend:
+            key_multiplier = 1
+        else:
+            key_multiplier = 2
+        for i in range(len(keys)):
+            if get_result[i] < 0:
+                return i // key_multiplier
+        return len(keys) // key_multiplier

-    def exists(self, keys) -> bool | dict:
-        _keys = []
-        local_rank = get_tensor_model_parallel_rank()
-        for key in keys:
-            if key is None:
-                return None
-
-            if self.is_mla:
-                _keys.append(f"{key}_k")
-            else:
-                _keys.append(f"{key}_{local_rank}_k")
-        result = {k: v for k, v in zip(keys, self.store.batch_is_exist(_keys))}
-        return result
+    def exists(self, key) -> bool:
+        return self.batch_exists([key]) > 0
+
+    def batch_exists(self, keys) -> int:
+        if self.is_mla_backend:
+            query_keys = [f"{key}_k" for key in keys]
+            key_multiplier = 1
+        else:
+            query_keys = []
+            for key in keys:
+                query_keys.append(f"{key}_{self.local_rank}_k")
+                query_keys.append(f"{key}_{self.local_rank}_v")
+            key_multiplier = 2
+
+        exist_result = self._batch_exist(query_keys)
+        for i in range(len(query_keys)):
+            if exist_result[i] != 1:
+                return i // key_multiplier
+        return len(query_keys) // key_multiplier

     def delete(self, key) -> None:
         raise (NotImplementedError)
@@ -248,18 +260,13 @@ class MooncakeStore(HiCacheStorage):
     def _put_batch_zero_copy_impl(
         self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
-    ) -> None:
-        try:
-            self.store.batch_put_from(key_strs, buffer_ptrs, buffer_sizes)
-        except TypeError as err:
-            logger.error("Failed to put value to Mooncake Store: %s", err)
-            raise TypeError("Mooncake Store Put Type Error.") from err
+    ) -> List[int]:
+        return self.store.batch_put_from(key_strs, buffer_ptrs, buffer_sizes)

     def _get_batch_zero_copy_impl(
         self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
-    ) -> None:
-        try:
-            self.store.batch_get_into(key_strs, buffer_ptrs, buffer_sizes)
-        except TypeError as err:
-            logger.error("Failed to get value from Mooncake Store: %s", err)
-            raise TypeError("Mooncake Store Get Type Error.") from err
+    ) -> List[int]:
+        return self.store.batch_get_into(key_strs, buffer_ptrs, buffer_sizes)
+
+    def _batch_exist(self, key_strs: List[str]) -> List[int]:
+        return self.store.batch_is_exist(key_strs)