Unverified Commit d9049592 authored by huangtingwei's avatar huangtingwei Committed by GitHub
Browse files

Support l3 cache (mooncake store) for hiradix cache (#7211)


Co-authored-by: default avatarZhiqiang Xie <xiezhq@stanford.edu>
Co-authored-by: default avatarAniZpZ <zhuangsen.zp@antgroup.com>
Co-authored-by: default avatarzuoyuan <zhangzuo21@mails.tsinghua.edu.cn>
Co-authored-by: default avatar@wangyueneng.wyn <wangyueneng.wyn@antgroup.com>
Co-authored-by: default avatarJinYan Su <jinyansu792@gmail.com>
parent 26c8a310
...@@ -26,6 +26,10 @@ if TYPE_CHECKING: ...@@ -26,6 +26,10 @@ if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool_host import HostKVCache from sglang.srt.mem_cache.memory_pool_host import HostKVCache
from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
MooncakeStore,
get_hash_str_mooncake,
)
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -125,7 +129,7 @@ class TransferBuffer: ...@@ -125,7 +129,7 @@ class TransferBuffer:
""" """
def __init__( def __init__(
self, stop_event, buffer_count: int = 3, max_buffer_size: int = 1000 self, stop_event, buffer_count: int = 3, max_buffer_size: int = 1024
) -> None: ) -> None:
self.stop_event = stop_event self.stop_event = stop_event
self.buffers = Queue(maxsize=buffer_count) self.buffers = Queue(maxsize=buffer_count)
...@@ -260,6 +264,11 @@ class HiCacheController: ...@@ -260,6 +264,11 @@ class HiCacheController:
if storage_backend == "file": if storage_backend == "file":
self.storage_backend = HiCacheFile() self.storage_backend = HiCacheFile()
self.get_hash_str = get_hash_str
elif storage_backend == "mooncake":
self.storage_backend = MooncakeStore()
self.get_hash_str = get_hash_str_mooncake
self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
elif storage_backend == "hf3fs": elif storage_backend == "hf3fs":
from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.distributed import get_tensor_model_parallel_rank
...@@ -271,6 +280,7 @@ class HiCacheController: ...@@ -271,6 +280,7 @@ class HiCacheController:
self.storage_backend = HiCacheHF3FS.from_env_config( self.storage_backend = HiCacheHF3FS.from_env_config(
rank, bytes_per_page, dtype rank, bytes_per_page, dtype
) )
self.get_hash_str = get_hash_str
else: else:
raise NotImplementedError( raise NotImplementedError(
f"Unsupported storage backend: {storage_backend}" f"Unsupported storage backend: {storage_backend}"
...@@ -532,6 +542,37 @@ class HiCacheController: ...@@ -532,6 +542,37 @@ class HiCacheController:
operation.mark_done() operation.mark_done()
return operation.completed_tokens, operation.hash_value return operation.completed_tokens, operation.hash_value
def generic_page_transfer(self, operation, batch_size=8):
    """Load prefetched pages from a generic storage backend into host memory.

    Hashes are fetched in batches of ``batch_size``. The transfer stops early
    when the backend fails to return a batch, or when the controller has
    terminated the operation (``operation.increment`` returns False), in which
    case the pre-allocated host slots beyond ``completed_tokens`` are freed.
    """
    for i in range(0, len(operation.hash_value), batch_size):
        page_hashes = operation.hash_value[i : i + batch_size]
        page_data = self.storage_backend.batch_get(page_hashes)
        if page_data is None:
            logger.warning(
                f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
            )
            break
        # Snapshot before increment: host slots are written starting here.
        completed_tokens = operation.completed_tokens
        if operation.increment(self.page_size * len(page_hashes)):
            # Use a distinct index (j) so the outer batch index i is not shadowed.
            for j in range(len(page_hashes)):
                self.mem_pool_host.set_from_flat_data_page(
                    operation.host_indices[completed_tokens],
                    page_data[j],
                )
                completed_tokens += self.page_size
        else:
            # operation terminated by controller, release pre-allocated memory
            self.mem_pool_host.free(
                operation.host_indices[operation.completed_tokens :]
            )
            break
def mooncake_page_transfer(self, operation):
    """Zero-copy prefetch path for the Mooncake backend.

    Reads every page of the operation directly into the pinned host KV
    buffer via pointer/size metadata, then marks all tokens completed.
    """
    meta = self.mem_pool_host.get_buffer_meta(
        operation.hash_value, operation.host_indices
    )
    keys, ptrs, sizes = meta
    self.storage_backend.batch_get(keys, ptrs, sizes)
    fetched_tokens = len(operation.hash_value) * self.page_size
    operation.increment(fetched_tokens)
def prefetch_io_aux_func(self): def prefetch_io_aux_func(self):
""" """
Auxiliary function conducting IO operations for prefetching. Auxiliary function conducting IO operations for prefetching.
...@@ -539,26 +580,10 @@ class HiCacheController: ...@@ -539,26 +580,10 @@ class HiCacheController:
while not self.stop_event.is_set(): while not self.stop_event.is_set():
try: try:
operation = self.prefetch_buffer.get(block=True, timeout=1) operation = self.prefetch_buffer.get(block=True, timeout=1)
page_datas = self.storage_backend.batch_get(operation.hash_value) if isinstance(self.storage_backend, MooncakeStore):
for h, page_data in zip(operation.hash_value, page_datas): self.mooncake_page_transfer(operation)
if page_data is None: else:
logger.warning( self.generic_page_transfer(operation)
f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
)
break
if operation.increment(self.page_size):
self.mem_pool_host.set_from_flat_data_page(
operation.host_indices[
operation.completed_tokens - self.page_size
],
page_data,
)
else:
# operation terminated by controller, release pre-allocated memory
self.mem_pool_host.free(
operation.host_indices[operation.completed_tokens :]
)
break
except Empty: except Empty:
continue continue
...@@ -582,18 +607,27 @@ class HiCacheController: ...@@ -582,18 +607,27 @@ class HiCacheController:
remaining_tokens = len(tokens_to_fetch) remaining_tokens = len(tokens_to_fetch)
hash_value = [] hash_value = []
while remaining_tokens >= self.page_size: while remaining_tokens >= self.page_size:
last_hash = get_hash_str( last_hash = self.get_hash_str(
tokens_to_fetch[ tokens_to_fetch[
storage_hit_count : storage_hit_count + self.page_size storage_hit_count : storage_hit_count + self.page_size
], ],
last_hash, last_hash,
) )
if self.storage_backend.exists(last_hash):
storage_hit_count += self.page_size # todo, more unified interface
hash_value.append(last_hash) if not isinstance(self.storage_backend, MooncakeStore):
remaining_tokens -= self.page_size if not self.storage_backend.exists(last_hash):
else: break
break hash_value.append(last_hash)
storage_hit_count += self.page_size
remaining_tokens -= self.page_size
if isinstance(self.storage_backend, MooncakeStore):
# deferring to batch exists for mooncake store
exist_result = self.storage_backend.exists(hash_value)
storage_hit_count = (
sum(1 for v in exist_result.values() if v != 0) * self.page_size
)
if self.tp_world_size > 1: if self.tp_world_size > 1:
storage_hit_count_tensor = torch.tensor( storage_hit_count_tensor = torch.tensor(
...@@ -641,6 +675,47 @@ class HiCacheController: ...@@ -641,6 +675,47 @@ class HiCacheController:
self.backup_queue.put(operation) self.backup_queue.put(operation)
return operation.id return operation.id
def generic_page_backup(self, operation, batch_size=8):
    """Write back pages from host memory to a generic storage backend.

    Pages are flushed in batches of ``batch_size``; the loop stops on the
    first failed batch so ``operation.completed_tokens`` only counts pages
    actually persisted.
    """
    for i in range(0, len(operation.hash_value), batch_size):
        page_hashes = operation.hash_value[i : i + batch_size]
        # host_indices is token-indexed, hence page_index * page_size.
        # Fixed: the host pool API is get_flat_data_page (singular), as used
        # by the rest of this module.
        page_data = [
            self.mem_pool_host.get_flat_data_page(
                operation.host_indices[j * self.page_size]
            )
            for j in range(i, i + len(page_hashes))
        ]
        success = self.storage_backend.batch_set(page_hashes, page_data)
        if not success:
            logger.warning(f"Failed to write page {page_hashes} to storage.")
            break
        operation.completed_tokens += self.page_size * len(page_hashes)
def mooncake_page_backup(self, operation):
    """Zero-copy backup path for the Mooncake backend.

    Only pages whose hashes are not already present in the store are
    written; completed_tokens is advanced for every page either way, since
    already-present pages need no re-upload.
    """
    hashes = operation.hash_value
    if hashes:
        presence = self.storage_backend.exists(hashes)
        host_indices = operation.host_indices.tolist()
        missing_keys = []
        missing_indices = []
        for page_no, page_hash in enumerate(hashes):
            if presence[page_hash]:
                continue
            missing_keys.append(page_hash)
            start = page_no * self.page_size
            missing_indices.extend(host_indices[start : start + self.page_size])
        if missing_keys:
            keys, ptrs, sizes = self.mem_pool_host.get_buffer_meta(
                missing_keys, missing_indices
            )
            # TODO: check the return value of batch set to see how many tokens are set successfully
            self.storage_backend.batch_set(
                keys,
                target_location=ptrs,
                target_sizes=sizes,
            )
    operation.completed_tokens += len(hashes) * self.page_size
def backup_thread_func(self): def backup_thread_func(self):
""" """
Manage backup operations from host memory to storage backend. Manage backup operations from host memory to storage backend.
...@@ -654,23 +729,25 @@ class HiCacheController: ...@@ -654,23 +729,25 @@ class HiCacheController:
last_hash = operation.last_hash last_hash = operation.last_hash
tokens_to_backup = operation.token_ids tokens_to_backup = operation.token_ids
last_hashes, data_pages = [], [] backup_hit_count = 0
for i in range(0, len(tokens_to_backup), self.page_size): remaining_tokens = len(tokens_to_backup)
last_hash = get_hash_str( hash_value = []
tokens_to_backup[i : i + self.page_size], last_hash while remaining_tokens >= self.page_size:
) last_hash = self.get_hash_str(
data_page = self.mem_pool_host.get_flat_data_page( tokens_to_backup[
operation.host_indices[i] backup_hit_count : backup_hit_count + self.page_size
],
last_hash,
) )
last_hashes.append(last_hash) backup_hit_count += self.page_size
data_pages.append(data_page) hash_value.append(last_hash)
remaining_tokens -= self.page_size
operation.hash_value = hash_value
success = self.storage_backend.batch_set(last_hashes, data_pages) if isinstance(self.storage_backend, MooncakeStore):
if not success: self.mooncake_page_backup(operation)
logger.warning(f"Failed to write page {last_hashes} to storage.")
else: else:
operation.completed_tokens += len(tokens_to_backup) self.generic_page_backup(operation)
operation.hash_value.extend(last_hashes)
min_completed_tokens = operation.completed_tokens min_completed_tokens = operation.completed_tokens
if self.tp_world_size > 1: if self.tp_world_size > 1:
......
...@@ -2,7 +2,7 @@ import hashlib ...@@ -2,7 +2,7 @@ import hashlib
import logging import logging
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Optional from typing import Any, List, Optional
import torch import torch
...@@ -39,7 +39,10 @@ class HiCacheStorage(ABC): ...@@ -39,7 +39,10 @@ class HiCacheStorage(ABC):
@abstractmethod @abstractmethod
def get( def get(
self, key: str, target_location: Optional[torch.Tensor] = None self,
key: str,
target_location: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> torch.Tensor | None: ) -> torch.Tensor | None:
""" """
Retrieve the value associated with the given key. Retrieve the value associated with the given key.
...@@ -49,7 +52,10 @@ class HiCacheStorage(ABC): ...@@ -49,7 +52,10 @@ class HiCacheStorage(ABC):
@abstractmethod @abstractmethod
def batch_get( def batch_get(
self, keys: List[str], target_locations: Optional[List[torch.Tensor]] = None self,
keys: List[str],
target_locations: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> List[torch.Tensor | None]: ) -> List[torch.Tensor | None]:
""" """
Retrieve values for multiple keys. Retrieve values for multiple keys.
...@@ -58,7 +64,13 @@ class HiCacheStorage(ABC): ...@@ -58,7 +64,13 @@ class HiCacheStorage(ABC):
pass pass
@abstractmethod @abstractmethod
def set(self, key, value) -> bool: def set(
self,
key: str,
value: Optional[Any] = None,
target_location: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> bool:
""" """
Store the value associated with the given key. Store the value associated with the given key.
Returns True if the operation was successful, False otherwise. Returns True if the operation was successful, False otherwise.
...@@ -66,7 +78,13 @@ class HiCacheStorage(ABC): ...@@ -66,7 +78,13 @@ class HiCacheStorage(ABC):
pass pass
@abstractmethod @abstractmethod
def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool: def batch_set(
self,
keys: List[str],
values: Optional[Any] = None,
target_locations: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> bool:
""" """
Store multiple key-value pairs. Store multiple key-value pairs.
Returns True if all operations were successful, False otherwise. Returns True if all operations were successful, False otherwise.
...@@ -74,7 +92,7 @@ class HiCacheStorage(ABC): ...@@ -74,7 +92,7 @@ class HiCacheStorage(ABC):
pass pass
@abstractmethod @abstractmethod
def exists(self, key: str) -> bool: def exists(self, key: str) -> bool | dict:
""" """
Check if the key exists in the storage. Check if the key exists in the storage.
Returns True if the key exists, False otherwise. Returns True if the key exists, False otherwise.
...@@ -97,7 +115,10 @@ class HiCacheFile(HiCacheStorage): ...@@ -97,7 +115,10 @@ class HiCacheFile(HiCacheStorage):
return key + self.tp_suffix return key + self.tp_suffix
def get( def get(
self, key: str, target_location: Optional[torch.Tensor] = None self,
key: str,
target_location: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> torch.Tensor | None: ) -> torch.Tensor | None:
key = self._get_suffixed_key(key) key = self._get_suffixed_key(key)
tensor_path = os.path.join(self.file_path, f"{key}.bin") tensor_path = os.path.join(self.file_path, f"{key}.bin")
...@@ -115,7 +136,8 @@ class HiCacheFile(HiCacheStorage): ...@@ -115,7 +136,8 @@ class HiCacheFile(HiCacheStorage):
def batch_get( def batch_get(
self, self,
keys: List[str], keys: List[str],
target_locations: Optional[List[torch.Tensor]] = None, target_locations: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> List[torch.Tensor | None]: ) -> List[torch.Tensor | None]:
return [ return [
self.get(key, target_location) self.get(key, target_location)
...@@ -124,7 +146,13 @@ class HiCacheFile(HiCacheStorage): ...@@ -124,7 +146,13 @@ class HiCacheFile(HiCacheStorage):
) )
] ]
def set(self, key: str, value: torch.Tensor) -> bool: def set(
self,
key: str,
value: Optional[Any] = None,
target_location: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> bool:
key = self._get_suffixed_key(key) key = self._get_suffixed_key(key)
tensor_path = os.path.join(self.file_path, f"{key}.bin") tensor_path = os.path.join(self.file_path, f"{key}.bin")
if self.exists(key): if self.exists(key):
...@@ -137,7 +165,13 @@ class HiCacheFile(HiCacheStorage): ...@@ -137,7 +165,13 @@ class HiCacheFile(HiCacheStorage):
logger.error(f"Failed to save tensor {key}: {e}") logger.error(f"Failed to save tensor {key}: {e}")
return False return False
def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool: def batch_set(
self,
keys: List[str],
values: Optional[Any] = None,
target_locations: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> bool:
for key, value in zip(keys, values): for key, value in zip(keys, values):
if not self.set(key, value): if not self.set(key, value):
return False return False
......
...@@ -594,6 +594,10 @@ class HiRadixCache(RadixCache): ...@@ -594,6 +594,10 @@ class HiRadixCache(RadixCache):
if child.backuped: if child.backuped:
new_node.host_value = child.host_value[:split_len] new_node.host_value = child.host_value[:split_len]
child.host_value = child.host_value[split_len:] child.host_value = child.host_value[split_len:]
if child.hash_value:
new_node.hash_value = child.hash_value[: split_len // self.page_size]
child.hash_value = child.hash_value[split_len // self.page_size :]
child.parent = new_node child.parent = new_node
child.key = child.key[split_len:] child.key = child.key[split_len:]
new_node.parent.children[self.get_child_key_fn(key)] = new_node new_node.parent.children[self.get_child_key_fn(key)] = new_node
......
...@@ -265,6 +265,43 @@ class MHATokenToKVPoolHost(HostKVCache): ...@@ -265,6 +265,43 @@ class MHATokenToKVPoolHost(HostKVCache):
self.head_dim, self.head_dim,
) )
def get_buffer_meta(self, keys, indices):
ptr_list = []
key_list = []
kv_buffer_data_ptr = self.kv_buffer.data_ptr()
v_offset = (
self.layer_num
* self.size
* self.head_num
* self.head_dim
* self.dtype.itemsize
)
for index in range(0, len(indices), self.page_size):
for layer_id in range(self.layer_num):
k_ptr = (
kv_buffer_data_ptr
+ indices[index]
* self.head_num
* self.head_dim
* self.dtype.itemsize
+ layer_id
* self.size
* self.head_num
* self.head_dim
* self.dtype.itemsize
)
v_ptr = k_ptr + v_offset
ptr_list.append(k_ptr)
ptr_list.append(v_ptr)
key_ = keys[index // self.page_size]
key_list.append(f"{key_}_{layer_id}_k")
key_list.append(f"{key_}_{layer_id}_v")
element_size = (
self.dtype.itemsize * self.page_size * self.head_num * self.head_dim
)
element_size_list = [element_size] * len(key_list)
return key_list, ptr_list, element_size_list
@property @property
def k_buffer(self): def k_buffer(self):
return self.kv_buffer[0] return self.kv_buffer[0]
...@@ -325,3 +362,30 @@ class MLATokenToKVPoolHost(HostKVCache): ...@@ -325,3 +362,30 @@ class MLATokenToKVPoolHost(HostKVCache):
1, 1,
self.kv_lora_rank + self.qk_rope_head_dim, self.kv_lora_rank + self.qk_rope_head_dim,
) )
def get_buffer_meta(self, keys, indices):
ptr_list = []
key_list = []
kv_buffer_data_ptr = self.kv_buffer.data_ptr()
for index in range(0, len(indices), self.page_size):
for layer_id in range(self.layer_num):
k_ptr = (
kv_buffer_data_ptr
+ indices[index]
* (self.kv_lora_rank + self.qk_rope_head_dim)
* self.dtype.itemsize
+ layer_id
* self.size
* (self.kv_lora_rank + self.qk_rope_head_dim)
* self.dtype.itemsize
)
ptr_list.append(k_ptr)
key_ = keys[index // self.page_size]
key_list.append(f"{key_}_{layer_id}_k")
element_size = (
self.dtype.itemsize
* self.page_size
* (self.kv_lora_rank + self.qk_rope_head_dim)
)
element_size_list = [element_size] * len(key_list)
return key_list, ptr_list, element_size_list
# Mooncake as L3 KV Cache
This document describes how to use Mooncake as the L3 KV cache for SGLang.
For more details about Mooncake, please refer to: https://kvcache-ai.github.io/
## Install Mooncake
### Method 1: with pip
```bash
pip install mooncake-transfer-engine
```
### Method 2: from source
Clone Mooncake project:
```bash
git clone https://github.com/kvcache-ai/Mooncake --recursive
```
Install dependencies:
```bash
cd Mooncake
bash dependencies.sh
```
Build the project. For additional build options, please refer to [the official guide](https://kvcache-ai.github.io/Mooncake/getting_started/build.html).
```bash
mkdir build
cd build
cmake ..
make -j
```
Install Mooncake:
```bash
sudo make install
```
## Use Mooncake
Launch Mooncake master server:
```bash
mooncake_master
```
Launch the Mooncake metadata server:
```bash
python -m mooncake.http_metadata_server
```
Start the SGLang server with Mooncake enabled. Mooncake configuration can be provided via environment variables:
```bash
MOONCAKE_TE_META_DATA_SERVER="http://127.0.0.1:8080/metadata" \
MOONCAKE_GLOBAL_SEGMENT_SIZE=4294967296 \
MOONCAKE_LOCAL_BUFFER_SIZE=134217728 \
MOONCAKE_PROTOCOL="rdma" \
MOONCAKE_DEVICE="erdma_0,erdma_1" \
MOONCAKE_MASTER=127.0.0.1:50051 \
python -m sglang.launch_server \
--enable-hierarchical-cache \
    --hicache-storage-backend mooncake \
--model-path [model_path]
```
import hashlib
import json
import logging
import os
import uuid
from dataclasses import dataclass
from typing import Any, List, Optional
import numpy as np
import torch
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
# Default Mooncake segment sizes; overridable via the
# MOONCAKE_GLOBAL_SEGMENT_SIZE / MOONCAKE_LOCAL_BUFFER_SIZE env vars
# (see MooncakeStoreConfig.load_from_env).
DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024  # 4 GiB
DEFAULT_LOCAL_BUFFER_SIZE = 128 * 1024 * 1024  # 128 MiB
logger = logging.getLogger(__name__)
def get_hash_str_mooncake(current_page_ids: List, prefix_block_key: str):
    """Compute the Mooncake store key for one page of token ids.

    The key is "{sha256(prefix_block_key)}_{int of the first 16 hex chars of
    sha256(page token ids)}_{tp rank}". The prefix part is empty for the
    first page of a sequence; the rank suffix keeps per-TP-rank KV shards
    from colliding in the shared store.
    """
    local_rank = get_tensor_model_parallel_rank()
    prefix_str = ""
    # Collapsed the original redundant nested check
    # (`if prefix_block_key: if len(prefix_block_key):`) into one truthiness test.
    if prefix_block_key:
        prefix_str = hashlib.sha256(prefix_block_key.encode()).hexdigest()
    current_token_ids_bytes = np.array(current_page_ids).tobytes()
    current_hash_hex = hashlib.sha256(current_token_ids_bytes).hexdigest()
    return f"{prefix_str}_{int(current_hash_hex[:16], 16)}_{local_rank}"
@dataclass
class MooncakeStoreConfig:
    """Connection and sizing parameters for a MooncakeDistributedStore."""

    local_hostname: str
    metadata_server: str
    global_segment_size: int
    local_buffer_size: int
    protocol: str
    device_name: str
    master_server_address: str

    @staticmethod
    def from_file() -> "MooncakeStoreConfig":
        """Load the config from the JSON file named by MOONCAKE_CONFIG_PATH.

        Raises:
            ValueError: if the MOONCAKE_CONFIG_PATH env var is not set.
        """
        file_path = os.getenv("MOONCAKE_CONFIG_PATH")
        if file_path is None:
            raise ValueError(
                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
            )
        with open(file_path) as fin:
            config = json.load(fin)
        return MooncakeStoreConfig(
            local_hostname=config.get("local_hostname"),
            metadata_server=config.get("metadata_server"),
            global_segment_size=config.get(
                "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
            ),
            local_buffer_size=config.get(
                "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE
            ),
            protocol=config.get("protocol", "tcp"),
            device_name=config.get("device_name", "auto"),
            master_server_address=config.get("master_server_address"),
        )

    @staticmethod
    def load_from_env() -> "MooncakeStoreConfig":
        """Load the config directly from environment variables, e.g.:

        export MOONCAKE_MASTER=10.13.3.232:50051
        export MOONCAKE_PROTOCOL="rdma"
        export MOONCAKE_DEVICE="auto"
        export MOONCAKE_TE_META_DATA_SERVER="P2PHANDSHAKE"

        (The original docstring incorrectly said this reads a config file;
        only `from_file` does.)

        Raises:
            ValueError: if MOONCAKE_MASTER is not set.
        """
        # other required environment variables...
        if not os.getenv("MOONCAKE_MASTER"):
            raise ValueError("The environment variable 'MOONCAKE_MASTER' is not set.")
        return MooncakeStoreConfig(
            local_hostname=os.getenv("LOCAL_HOSTNAME", "localhost"),
            metadata_server=os.getenv("MOONCAKE_TE_META_DATA_SERVER", "P2PHANDSHAKE"),
            global_segment_size=int(
                os.getenv("MOONCAKE_GLOBAL_SEGMENT_SIZE", DEFAULT_GLOBAL_SEGMENT_SIZE)
            ),
            local_buffer_size=int(
                os.getenv("MOONCAKE_LOCAL_BUFFER_SIZE", DEFAULT_LOCAL_BUFFER_SIZE)
            ),
            protocol=os.getenv("MOONCAKE_PROTOCOL", "tcp"),
            device_name=os.getenv("MOONCAKE_DEVICE", "auto"),
            master_server_address=os.getenv("MOONCAKE_MASTER"),
        )

    def __post_init__(self):
        # "auto" turns on Mooncake's NIC auto-discovery, restricted to the
        # listed mlx5 bonded devices.
        if self.device_name == "auto":
            os.environ["MC_MS_AUTO_DISC"] = "1"
            os.environ["MC_MS_FILTERS"] = (
                "mlx5_bond_0, mlx5_bond_1, mlx5_bond_2, mlx5_bond_3"
            )
class MooncakeStore(HiCacheStorage):
    """HiCacheStorage backend backed by a MooncakeDistributedStore (L3 cache).

    Unlike the file backend, all data movement is zero-copy: callers pass raw
    host-buffer pointers and byte sizes (built by the host pool's
    ``get_buffer_meta``) and Mooncake reads/writes the registered host KV
    buffer directly.
    """

    def __init__(self):
        try:
            from mooncake.store import MooncakeDistributedStore
        except ImportError as e:
            raise ImportError(
                "Please install mooncake by following the instructions at "
                "https://kvcache-ai.github.io/Mooncake/getting_started/build.html"
                "to run SGLang with MooncakeConnector."
            ) from e

        try:
            self.store = MooncakeDistributedStore()
            self.config = MooncakeStoreConfig.load_from_env()
            logger.info("Mooncake Configuration loaded from env successfully.")

            # setup() returns a non-zero code on failure; original behavior is
            # best-effort (log, do not raise), preserved here.
            ret_code = self.store.setup(
                self.config.local_hostname,
                self.config.metadata_server,
                self.config.global_segment_size,
                self.config.local_buffer_size,
                self.config.protocol,
                self.config.device_name,
                self.config.master_server_address,
            )
            if ret_code:
                logger.error(f"failed to setup mooncake store, error code: {ret_code}")
            logger.info("Connect to Mooncake store successfully.")
            self.warmup()
            logger.info("Mooncake store warmup successfully.")
        except ValueError as e:
            logger.error("Configuration loading failed: %s", e)
            raise
        except Exception as exc:
            logger.error("An error occurred while loading the configuration: %s", exc)
            raise

    def warmup(self):
        """Round-trip a throwaway 10 MB value to verify the connection works."""
        warmup_key = "sglang_mooncake_store_warmup_key" + uuid.uuid4().hex
        # 10 MB
        warmup_value = bytes(10 * 1024 * 1024)
        self.store.put(warmup_key, warmup_value)
        assert self.store.is_exist(warmup_key) == 1
        self.store.get(warmup_key)
        self.store.remove(warmup_key)

    def register_buffer(self, buffer: torch.Tensor) -> None:
        """Register the host KV buffer with Mooncake for zero-copy transfers.

        Raises:
            TypeError: if the underlying register call rejects the arguments.
        """
        try:
            buffer_ptr = buffer.data_ptr()
            buffer_size = buffer.numel() * buffer.element_size()
            ret_code = self.store.register_buffer(buffer_ptr, buffer_size)
            if ret_code:
                logger.error(f"failed to register buffer, error code: {ret_code}")
        except TypeError as err:
            logger.error("Failed to register buffer to Mooncake Store: %s", err)
            raise TypeError("Mooncake Store Register Buffer Error.") from err

    def set(
        self,
        key,
        value: Optional[Any] = None,
        target_location: Optional[List[int]] = None,
        target_sizes: Optional[List[int]] = None,
    ) -> bool:
        """Zero-copy batch write; key/target_location/target_sizes are parallel lists.

        Returns True on success, False for empty input or None entries
        (previously returned None despite the declared bool return type).
        """
        assert len(key) == len(target_location) == len(target_sizes)
        if len(key) == 0:
            return False
        for i in range(len(key)):
            if key[i] is None or target_location[i] is None or target_sizes[i] is None:
                return False
        self._put_batch_zero_copy_impl(key, target_location, target_sizes)
        return True

    def batch_set(
        self,
        keys: List[str],
        value: Optional[Any] = None,
        target_location: Optional[List[int]] = None,
        target_sizes: Optional[List[int]] = None,
    ) -> bool:
        """Zero-copy batch write of multiple keys (same contract as `set`).

        Returns True on success, False for empty input or None entries
        (previously returned None despite the declared bool return type).
        """
        assert len(keys) == len(target_location) == len(target_sizes)
        if len(keys) == 0:
            return False
        for i in range(len(keys)):
            if keys[i] is None or target_location[i] is None or target_sizes[i] is None:
                return False
        self._put_batch_zero_copy_impl(keys, target_location, target_sizes)
        return True

    def get(
        self,
        key,
        target_location: Optional[Any] = None,
        target_sizes: Optional[Any] = None,
    ) -> torch.Tensor | None:
        """Zero-copy batch read into the given pointers; None on invalid input."""
        assert len(key) == len(target_location) == len(target_sizes)
        if len(key) == 0:
            return None
        for i in range(len(key)):
            if key[i] is None or target_location[i] is None or target_sizes[i] is None:
                return None
        return self._get_batch_zero_copy_impl(key, target_location, target_sizes)

    def batch_get(
        self,
        keys: List[str],
        target_location: Optional[Any] = None,
        target_sizes: Optional[Any] = None,
    ) -> torch.Tensor | None:
        """Zero-copy batch read of multiple keys; None on invalid input."""
        assert len(keys) == len(target_location) == len(target_sizes)
        if len(keys) == 0:
            return None
        for i in range(len(keys)):
            if keys[i] is None or target_location[i] is None or target_sizes[i] is None:
                return None
        return self._get_batch_zero_copy_impl(keys, target_location, target_sizes)

    def exists(self, keys) -> bool | dict:
        """Batch presence check; returns {key: count} where 0 means absent.

        Returns None if any key is None.
        """
        _keys = []
        local_rank = torch.cuda.current_device()
        for key in keys:
            if key is None:
                return None
            # Since mooncake store is stored in layer by layer,
            # only the first layer is checked here.
            # NOTE(review): the probe suffix uses the CUDA device index, not
            # layer 0 — confirm this matches the "{hash}_{layer_id}_k" keys
            # written via get_buffer_meta.
            _keys.append(f"{key}_{local_rank}_k")
        result = {k: v for k, v in zip(keys, self.store.batch_is_exist(_keys))}
        return result

    def delete(self, key) -> None:
        raise NotImplementedError

    def close(self):
        # MooncakeDistributedStore will automatically call the destructor, so
        # it is unnecessary to close it manually.
        pass

    def clear(self) -> None:
        raise NotImplementedError

    def _put_batch_zero_copy_impl(
        self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
    ) -> None:
        """Forward a batch write to Mooncake's zero-copy put API."""
        try:
            self.store.batch_put_from(key_strs, buffer_ptrs, buffer_sizes)
        except TypeError as err:
            logger.error("Failed to put value to Mooncake Store: %s", err)
            raise TypeError("Mooncake Store Put Type Error.") from err

    def _get_batch_zero_copy_impl(
        self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
    ) -> None:
        """Forward a batch read to Mooncake's zero-copy get API."""
        try:
            self.store.batch_get_into(key_strs, buffer_ptrs, buffer_sizes)
        except TypeError as err:
            logger.error("Failed to get value from Mooncake Store: %s", err)
            raise TypeError("Mooncake Store Get Type Error.") from err
import torch
from mooncake_store import MooncakeStore
def test_init_and_warmup():
    """Constructing the store should connect and leave a live handle."""
    mooncake = MooncakeStore()
    assert mooncake.store is not None
def test_register_buffer():
    """A host tensor's memory can be registered for zero-copy transfers."""
    store = MooncakeStore()
    buf = torch.zeros(1024, dtype=torch.float32)
    store.register_buffer(buf)
def test_set_and_get():
    """Round-trip two keys through the zero-copy set/get paths."""
    store = MooncakeStore()
    keys = [f"test_key_{i}" for i in range(2)]
    tensor = torch.arange(256, dtype=torch.float32).cuda()
    nbytes = tensor.numel() * tensor.element_size()
    ptrs = [tensor.data_ptr()] * 2
    sizes = [nbytes] * 2
    store.set(keys, target_location=ptrs, target_sizes=sizes)
    store.get(keys, target_location=ptrs, target_sizes=sizes)
def test_exists():
    """exists() returns a per-key presence mapping."""
    store = MooncakeStore()
    probe = ["test_key_0", "non_existent_key"]
    presence = store.exists(probe)
    assert isinstance(presence, dict)
    assert "test_key_0" in presence
if __name__ == "__main__":
    # Smoke tests; these require a running Mooncake master and metadata
    # server plus the MOONCAKE_* environment variables to be set.
    test_init_and_warmup()
    test_register_buffer()
    test_set_and_get()
    test_exists()
...@@ -1476,7 +1476,7 @@ class ServerArgs: ...@@ -1476,7 +1476,7 @@ class ServerArgs:
parser.add_argument( parser.add_argument(
"--hicache-storage-backend", "--hicache-storage-backend",
type=str, type=str,
choices=["file", "hf3fs"], # todo, mooncake choices=["file", "mooncake", "hf3fs"],
default=ServerArgs.hicache_storage_backend, default=ServerArgs.hicache_storage_backend,
help="The storage backend for hierarchical KV cache.", help="The storage backend for hierarchical KV cache.",
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment