Unverified Commit 93088b69 authored by ykwd's avatar ykwd Committed by GitHub

[Hicache] Mooncake API Fix & Test, and Improved Readme (#9951)


Co-authored-by: default avatarTeng Ma <sima.mt@alibaba-inc.com>
parent 453511ac
@@ -659,7 +659,7 @@ class HiCacheController:
            )
            get_result = self.storage_backend.batch_get(
                key_strs,
                target_locations=buffer_ptrs,
                target_sizes=buffer_sizes,
            )
            if get_result != len(hash_values):
@@ -843,7 +843,7 @@ class HiCacheController:
            )
            success = self.storage_backend.batch_set(
                key_strs,
                target_locations=buffer_ptrs,
                target_sizes=buffer_sizes,
            )
            return success
...
# Mooncake as L3 KV Cache
This document describes how to use Mooncake as the L3 KV cache for SGLang. This document describes how to use Mooncake as the L3 KV cache for SGLang.
## About Mooncake
Mooncake aims to enhance the inference efficiency of large language models (LLMs), especially in slow object storage environments, by constructing a multi-level caching pool on high-speed interconnected DRAM/SSD resources. Compared to traditional caching systems, Mooncake utilizes (GPUDirect) RDMA technology to transfer data directly in a zero-copy manner, while maximizing the use of multi-NIC resources on a single machine.
For more details about Mooncake, please refer to [Mooncake project](https://github.com/kvcache-ai/Mooncake) and [Mooncake documents](https://kvcache-ai.github.io/Mooncake/).
## Install Mooncake
@@ -41,30 +46,108 @@ Install Mooncake:
sudo make install
```
## Deploy Mooncake
**Mooncake** is a distributed system that efficiently aggregates memory resources across multiple servers. It can also be deployed on a single server for simpler setups.
When integrated with **SGLang**, the system conceptually consists of four key components: the `master service`, the `metadata service`, the `store service`, and the `SGLang server`. The `master service` and `metadata service` are responsible for object and metadata maintenance. The `store service` manages a contiguous memory segment that contributes to the distributed KV cache, making its memory accessible to both local and remote `SGLang servers`. Data transfer occurs directly between the `store service` and `SGLang servers`, bypassing the `master service`.
### Single Server Deployment
**Launch Mooncake `metadata service`:**
```bash
python -m mooncake.http_metadata_server
```
**Launch Mooncake `master service`:**
```bash
mooncake_master
```
**Launch Mooncake `store service`:**
First, create and save a configuration file in JSON format. For example:
```json
{
"local_hostname": "localhost",
"metadata_server": "http://localhost:8080/metadata",
"master_server_address": "localhost:50051",
"protocol": "rdma",
"device_name": "mlx5_0,mlx5_1",
"global_segment_size": 2684354560,
"local_buffer_size": 0
}
```
Parameter Explanation:
* `local_hostname`: The hostname of the `store service`.
* `metadata_server`: The network address of the `metadata service`. The default port is 8080.
* `master_server_address`: The network address of the `master service`. The default port is 50051.
* `protocol`: The protocol used by Mooncake. Supported values are `"rdma"` or `"tcp"`. For optimal performance, `"rdma"` is recommended.
* `device_name`: The RDMA devices used by Mooncake. This parameter is required only when the protocol is set to `"rdma"`. Available devices can be listed using the `ibv_devices` command.
* `global_segment_size`: The amount of memory (in bytes) contributed to the global memory pool. A larger value allows Mooncake to cache more KV tensors.
* `local_buffer_size`: The local buffer is used for request operations such as `Get` or `Put`. Here it is set to 0 because this instance functions solely as a storage server, contributing memory to the global pool without issuing any request operations.
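As a sanity check on the byte values above, the sizes in the sample configuration can be derived programmatically (a small illustrative snippet, not part of Mooncake):

```python
# Derive the byte sizes used in the sample config above.
GiB = 1024 ** 3

config = {
    "global_segment_size": int(2.5 * GiB),  # 2.5 GiB contributed to the pool
    "local_buffer_size": 0,                 # pure storage node: no Get/Put buffer
}

print(config["global_segment_size"])  # 2684354560
```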
Then start the `store service`:
```bash
python -m mooncake.mooncake_store_service --config=[config_path]
```
Note: To get started quickly, if `MOONCAKE_GLOBAL_SEGMENT_SIZE` is set to a non-zero value when starting the `SGLang server`, launching the `store service` can be skipped. In this case, the `SGLang server` also fulfills the role of the `store service`.
**Start the `SGLang server` with Mooncake enabled:**
Mooncake configuration can be provided via environment variables. Note that, for optimal performance, the Mooncake backend currently supports only the `page_first` layout (which optimizes memory access patterns for KV cache operations).
```bash
MOONCAKE_TE_META_DATA_SERVER="http://127.0.0.1:8080/metadata" \
MOONCAKE_MASTER=127.0.0.1:50051 \
MOONCAKE_PROTOCOL="rdma" \
MOONCAKE_DEVICE="mlx5_0,mlx5_1" \
MOONCAKE_GLOBAL_SEGMENT_SIZE=4294967296 \
python -m sglang.launch_server \
    --enable-hierarchical-cache \
    --hicache-storage-backend mooncake \
    --model-path [model_path]
```
Parameter Explanation:
* `MOONCAKE_TE_META_DATA_SERVER`: The network address of the `metadata service`. The default port is 8080.
* `MOONCAKE_MASTER`: The network address of the `master service`. The default port is 50051.
* `MOONCAKE_PROTOCOL`: The protocol used by Mooncake. Supported values are `"rdma"` or `"tcp"`. For optimal performance, `"rdma"` is recommended.
* `MOONCAKE_DEVICE`: The RDMA devices used by Mooncake. This parameter is required only when the protocol is set to `"rdma"`. Available devices can be listed using the `ibv_devices` command.
* `MOONCAKE_GLOBAL_SEGMENT_SIZE`: The amount of memory (in bytes) contributed to the global memory pool. If at least one `store service` is launched, this value can be set to `0`, in which case the `SGLang server` contributes no memory to the pool. Note that KV tensors cached in contributed memory are lost once the contributing process terminates; however, this does not cause any system errors.
**Important: Understanding Global Segment Size**
`global_segment_size` for the `store service` and `MOONCAKE_GLOBAL_SEGMENT_SIZE` for the `SGLang server` specify the amount of memory each instance contributes to the distributed memory pool. The total memory available for KV cache storage across the cluster is the sum of the memory contributed by all instances.
Adjust this value according to the system's available memory and expected cache requirements.
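One way to pick a value is to contribute a fraction of the memory a host can spare. The helper below is purely illustrative (`pick_segment_size`, the 50% fraction, and the 8 GiB reserve are assumptions, not Mooncake defaults):

```python
# Hypothetical sizing helper: contribute a fraction of spare memory,
# after reserving headroom for other processes on the host.
def pick_segment_size(available_bytes: int, fraction: float = 0.5,
                      reserve_bytes: int = 8 * 1024**3) -> int:
    """Return the bytes to contribute to the global memory pool."""
    spare = max(available_bytes - reserve_bytes, 0)
    return int(spare * fraction)

# A host with 64 GiB available, keeping 8 GiB in reserve,
# would contribute half of the remaining 56 GiB, i.e. 28 GiB.
print(pick_segment_size(64 * 1024**3))  # 30064771072
```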
### Distributed Deployment
Distributed deployment of Mooncake is straightforward. Similar to the single-node setup, start one `metadata service` and one `master service` for this cluster. Then start a `store service` on each server.
Mooncake also supports a high availability mode. This mode enhances fault tolerance by running the `master service` as a cluster of multiple master nodes coordinated through an `etcd` cluster. The master nodes use `etcd` to elect a leader, which is responsible for handling client requests. For more details about deploying in this mode, please refer to our [documents](https://kvcache-ai.github.io/Mooncake/).
## Test Mooncake Store
This test is intended for developers to quickly verify that the MooncakeStore class interfaces are functioning correctly.
First, start the `metadata service` and `master service`. Then run `test_mooncake_store.py`. A global segment size of 16 MB is enough for this test.
```bash
MOONCAKE_TE_META_DATA_SERVER="http://127.0.0.1:8080/metadata" \
MOONCAKE_MASTER=127.0.0.1:50051 \
MOONCAKE_PROTOCOL="rdma" \
MOONCAKE_DEVICE="mlx5_0,mlx5_1" \
MOONCAKE_GLOBAL_SEGMENT_SIZE=16777216 \
python3 [path of test_mooncake_store.py]
```
If all tests pass, the message "✅ All tests passed" will be printed at the end.
import json
import logging
import os
@@ -6,10 +5,8 @@ import uuid
from dataclasses import dataclass
from typing import Any, List, Optional
import torch
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024  # 4 GiB
@@ -154,21 +151,36 @@ class MooncakeStore(HiCacheStorage):
        target_location: Optional[List[int]] = None,
        target_sizes: Optional[List[int]] = None,
    ) -> bool:
        # Only support zero copy set for now
        assert target_location is not None and target_sizes is not None
        exist_result = self._batch_exist([key])
        if exist_result[0] == 1:
            return True
        put_result = self._put_batch_zero_copy_impl(
            [key], [target_location], [target_sizes]
        )
        return put_result[0] == 0
    def batch_set(
        self,
        keys: List[str],
        values: Optional[List[torch.Tensor]] = None,
        target_locations: Optional[List[int]] = None,
        target_sizes: Optional[List[int]] = None,
    ) -> bool:
        # Only support zero copy set for now
        assert target_locations is not None and target_sizes is not None
        assert len(keys) == len(target_locations) == len(target_sizes)
        if len(keys) == 0:
            return False

        for i in range(len(keys)):
            if (
                keys[i] is None
                or target_locations[i] is None
                or target_sizes[i] is None
            ):
                return False

        exist_result = self._batch_exist(keys)
@@ -179,7 +191,7 @@ class MooncakeStore(HiCacheStorage):
        for i in range(len(keys)):
            if exist_result[i] != 1:
                set_keys.append(keys[i])
                set_target_locations.append(target_locations[i])
                set_target_sizes.append(target_sizes[i])
                set_indices.append(i)
        # Only set non-existing keys to storage
@@ -204,18 +216,24 @@ class MooncakeStore(HiCacheStorage):
        target_location: Optional[Any] = None,
        target_sizes: Optional[Any] = None,
    ) -> bool:
        assert target_location is not None and target_sizes is not None
        get_result = self._get_batch_zero_copy_impl(
            [key], [target_location], [target_sizes]
        )
        return get_result[0] >= 0
    def batch_get(
        self,
        keys: List[str],
        target_locations: Optional[Any] = None,
        target_sizes: Optional[Any] = None,
    ) -> int:
        assert len(keys) == len(target_locations) == len(target_sizes)
        if len(keys) == 0:
            return 0
        get_result = self._get_batch_zero_copy_impl(
            keys, target_locations, target_sizes
        )
        if self.is_mla_backend:
            key_multiplier = 1
        else:
@@ -226,7 +244,8 @@ class MooncakeStore(HiCacheStorage):
            return len(keys) // key_multiplier

    def exists(self, key) -> bool:
        exist_result = self._batch_exist([key])
        return exist_result[0] == 1

    def batch_exists(self, keys) -> int:
        if self.is_mla_backend:
...
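The `key_multiplier` accounting in `batch_get` above follows the key naming scheme used by the test below: MLA models store a single `_k` entry per logical KV item, while non-MLA models store separate `_{tp_rank}_k` and `_{tp_rank}_v` entries, so the raw key count must be divided by 2. A minimal sketch of that accounting (the function name is illustrative, not the actual implementation):

```python
# Illustrative sketch: map raw storage keys fetched back to logical KV items.
# Non-MLA models write one "_k" and one "_v" entry per item (multiplier 2);
# MLA models write a single "_k" entry per item (multiplier 1).
def logical_items(raw_keys_fetched: int, is_mla_backend: bool) -> int:
    key_multiplier = 1 if is_mla_backend else 2
    return raw_keys_fetched // key_multiplier

print(logical_items(26, is_mla_backend=False))  # 13
print(logical_items(13, is_mla_backend=True))   # 13
```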
import logging
import uuid
import torch
from mooncake_store import MooncakeStore
from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
def generate_batch_query_keys(kv_num: int, config: HiCacheStorageConfig):
keys = []
for _ in range(kv_num):
key = "test_" + str(uuid.uuid4())
keys.append(key)
set_keys = []
for key in keys:
if config.is_mla_model:
set_keys.append(key + "_k")
else:
set_keys.append(key + f"_{config.tp_rank}_k")
set_keys.append(key + f"_{config.tp_rank}_v")
get_keys = set_keys
exist_keys = keys
return set_keys, get_keys, exist_keys
def test_single_operation():
"""Test the set API with a single key-value pair."""
print("=" * 100)
print("Testing single operation")
    buffer_size = 1024 * 1024 * 16  # 16M float32 elements (64 MB)
value_elements = 1024
store = MooncakeStore()
buffer = torch.randn(buffer_size, dtype=torch.float32)
store.register_buffer(buffer)
value_size = value_elements * buffer.element_size()
key = str(uuid.uuid4())
set_slice = buffer[:value_elements]
get_slice = buffer[value_elements : 2 * value_elements]
set_location = set_slice.data_ptr()
get_location = get_slice.data_ptr()
# Test set operation
result = store.set(key, target_location=set_location, target_sizes=value_size)
assert result is True, f"❌set operation failed for key: {key}"
# Test exists operation
assert store.exists(key), f"❌key {key} should exist after set operation"
# Test get operation
result = store.get(key, target_location=get_location, target_sizes=value_size)
assert result is True, f"❌get operation failed for key: {key}"
# Compare the data using proper tensor indices
assert torch.allclose(
set_slice, get_slice, atol=1e-6
), f"❌get operation failed for key: {key}"
logger.info(f"✅ Single operation passed")
def test_batch_operation(config: HiCacheStorageConfig):
"""Test the batch set/get APIs with multiple key-value pairs."""
print("=" * 100)
print(f"Testing batch operation with config: {config}")
    buffer_size = 1024 * 1024 * 16  # 16M float32 elements (64 MB)
value_elements = 256
kv_num = 13
store = MooncakeStore(config)
buffer = torch.randn(buffer_size, dtype=torch.float32)
store.register_buffer(buffer)
value_size = value_elements * buffer.element_size()
set_keys, get_keys, exist_keys = generate_batch_query_keys(kv_num, config)
set_slices = [
buffer[i * value_elements : (i + 1) * value_elements]
for i in range(len(set_keys))
]
set_locations = [set_slice.data_ptr() for set_slice in set_slices]
target_sizes = [value_size for _ in range(len(set_keys))]
# Test batch set operation
result = store.batch_set(
set_keys, target_locations=set_locations, target_sizes=target_sizes
)
assert result is True, f"❌batch set operation failed"
# Test batch exists operation
assert store.batch_exists(
exist_keys
), f"❌keys should exist after batch set operation"
# Test batch get operation
get_slices = [
buffer[
(len(set_keys) + i)
* value_elements : (len(set_keys) + i + 1)
* value_elements
]
for i in range(len(get_keys))
]
get_locations = [get_slice.data_ptr() for get_slice in get_slices]
result = store.batch_get(
get_keys, target_locations=get_locations, target_sizes=target_sizes
)
assert result == kv_num, f"❌batch get operation failed"
for i in range(len(get_keys)):
assert torch.allclose(
set_slices[i], get_slices[i], atol=1e-6
), f"❌batch get operation failed for key: {get_keys[i]}"
logger.info(f"✅ Batch operation passed")
if __name__ == "__main__":
test_single_operation()
test_batch_operation(
HiCacheStorageConfig(
is_mla_model=False,
tp_rank=0,
tp_size=1,
model_name=None,
is_page_first_layout=True,
)
)
test_batch_operation(
HiCacheStorageConfig(
is_mla_model=True,
tp_rank=0,
tp_size=1,
model_name=None,
is_page_first_layout=True,
)
)
test_batch_operation(
HiCacheStorageConfig(
is_mla_model=False,
tp_rank=1,
tp_size=4,
model_name=None,
is_page_first_layout=True,
)
)
test_batch_operation(
HiCacheStorageConfig(
is_mla_model=True,
tp_rank=3,
tp_size=8,
model_name=None,
is_page_first_layout=True,
)
)
logger.info(f"✅ All tests passed")
import torch
from mooncake_store import MooncakeStore
def test_init_and_warmup():
store = MooncakeStore()
assert store.store is not None
def test_register_buffer():
store = MooncakeStore()
tensor = torch.zeros(1024, dtype=torch.float32)
store.register_buffer(tensor)
def test_set_and_get():
store = MooncakeStore()
key = ["test_key_" + str(i) for i in range(2)]
tensor = torch.arange(256, dtype=torch.float32).cuda()
ptrs = [tensor.data_ptr(), tensor.data_ptr()]
sizes = [tensor.numel() * tensor.element_size()] * 2
store.set(key, target_location=ptrs, target_sizes=sizes)
store.get(key, target_location=ptrs, target_sizes=sizes)
def test_exists():
store = MooncakeStore()
keys = ["test_key_0", "non_existent_key"]
result = store.exists(keys)
assert isinstance(result, dict)
assert "test_key_0" in result
if __name__ == "__main__":
test_init_and_warmup()
test_register_buffer()
test_set_and_get()
test_exists()