Unverified Commit fce17048 authored by yi wang's avatar yi wang Committed by GitHub
Browse files

integrate AIBrix KVcache (#10376)

parent 3d40794f
...@@ -41,7 +41,7 @@ repos: ...@@ -41,7 +41,7 @@ repos:
hooks: hooks:
- id: codespell - id: codespell
additional_dependencies: ['tomli'] additional_dependencies: ['tomli']
args: ['--toml', 'python/pyproject.toml', '-L', 'cann,thi,makro,wil,rouge'] args: ['--toml', 'python/pyproject.toml', '-L', 'cann,thi,makro,wil,rouge,PRIS']
exclude: | exclude: |
(?x)^( (?x)^(
test/srt/test_reasoning_parser\.py| test/srt/test_reasoning_parser\.py|
......
...@@ -289,6 +289,14 @@ class HiCacheController: ...@@ -289,6 +289,14 @@ class HiCacheController:
) )
self.storage_backend = MooncakeStore(self.storage_config) self.storage_backend = MooncakeStore(self.storage_config)
elif storage_backend == "aibrix":
from sglang.srt.mem_cache.storage.aibrix_kvcache.aibrix_kvcache_storage import (
AibrixKVCacheStorage,
)
self.storage_backend = AibrixKVCacheStorage(
self.storage_config, self.mem_pool_host
)
elif storage_backend == "hf3fs": elif storage_backend == "hf3fs":
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import ( from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
HiCacheHF3FS, HiCacheHF3FS,
......
# AIBrix KVCache as L3 KV Cache
This document provides brief instructions for setting up an AIBrix KVCache storage backend together with an SGLang runtime environment from scratch, describing how to use AIBrix KVCache as the L3 KV cache for SGLang.
The process consists of three main steps:
## Step 1: Install AIBrix KVCache
Refer to the [AIBrix KVCache documentation](https://github.com/vllm-project/aibrix/blob/main/python/aibrix_kvcache/README.md) to install AIBrix KVCache.
## Step2: Deploy AIBrix Distributed KVCache Storage
AIBrix KVCache currently supports multiple distributed KVCache backends, including ByteDance's open-source InfiniStore and the not-yet-open-sourced PrisKV, incubated by ByteDance's PrisDB, IAAS, and DMI teams.
For the Infinistore installation process, please refer to [this link](https://github.com/bytedance/InfiniStore).
PrisKV for AIBrix KVCache is currently in the open-source preparation stage, and no public documentation is available yet.
## Step3: Deploy Model Serving
For information on configuring a distributed KVCache backend for AIBrixKVCache, please refer to [this link](https://aibrix.readthedocs.io/latest/designs/aibrix-kvcache-offloading-framework.html)
Using PrisKV as an example, the startup command is as follows:
```bash
export AIBRIX_KV_CACHE_OL_L1_CACHE_ENABLED="0"
export AIBRIX_KV_CACHE_OL_L2_CACHE_BACKEND="PRIS"
export AIBRIX_KV_CACHE_OL_PRIS_REMOTE_ADDR="127.0.0.1"
export AIBRIX_KV_CACHE_OL_PRIS_REMOTE_PORT="6379"
export AIBRIX_KV_CACHE_OL_PRIS_PASSWORD="kvcache-redis"
MODEL_LENGTH=32768&&NCCL_MIN_NCHANNELS=24&&NCCL_IB_QPS_PER_CONNECTION=8&&NCCL_DEBUG=INFO \
python3 -m sglang.launch_server \
--model-path /code/models/Qwen3-32B \
--host 0.0.0.0 --port 8080 \
--enable-hierarchical-cache \
--hicache-storage-backend aibrix \
--page-size 16 \
--hicache-write-policy write_back \
--enable-metrics --hicache-ratio=2
```
import logging
from typing import Any, List, Optional
import torch
from aibrix_kvcache import (
BaseKVCacheManager,
BlockHashes,
KVCacheBlockLayout,
KVCacheBlockSpec,
KVCacheConfig,
KVCacheTensorSpec,
ModelSpec,
)
from aibrix_kvcache.common.absl_logging import log_every_n_seconds
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
logger = logging.getLogger(__name__)
class AibrixKVCacheStorage(HiCacheStorage):
    """HiCacheStorage backend that offloads KV cache pages to AIBrix KVCache.

    Keys are page-level hashes; each page covers ``page_size`` tokens. Only the
    MHA (non-MLA) layout is supported: pages are described to AIBrix via an
    NCLD ``KVCacheBlockSpec`` (ntokens, caches, layers, dim).
    """

    def __init__(self, storage_config: HiCacheStorageConfig, mem_pool: HostKVCache):
        if storage_config is not None:
            self.is_mla_backend = storage_config.is_mla_model
            self.local_rank = storage_config.tp_rank
        else:
            # No config: assume a single-rank, non-MLA deployment.
            self.is_mla_backend = False
            self.local_rank = 0
        if self.is_mla_backend:
            raise NotImplementedError(
                "MLA is not supported by AibrixKVCacheStorage yet."
            )
        kv_cache = mem_pool.device_pool
        self.page_size = mem_pool.page_size
        self.kv_cache_dtype = kv_cache.dtype
        self.layer_num = kv_cache.layer_num
        # Globally-unique head ids: each TP rank owns a disjoint head range.
        self.kv_head_ids = [
            self.local_rank * kv_cache.head_num + i for i in range(kv_cache.head_num)
        ]
        # start_layer/end_layer bound the layer range for pipeline parallelism.
        self.layer_ids = range(kv_cache.start_layer, kv_cache.end_layer)
        self.block_spec = KVCacheBlockSpec(
            block_ntokens=self.page_size,
            block_dtype=self.kv_cache_dtype,
            block_layout=KVCacheBlockLayout(KVCacheBlockLayout.NCLD),
            tensor_spec=KVCacheTensorSpec(
                heads=self.kv_head_ids,
                layers=self.layer_ids,
                head_size=kv_cache.head_dim,
            ),
        )
        logger.info(self.block_spec)
        # NOTE(review): 102400 appears to be a max-model-len placeholder for
        # ModelSpec — confirm against aibrix_kvcache's ModelSpec contract.
        config = KVCacheConfig(block_spec=self.block_spec, model_spec=ModelSpec(102400))
        self.kv_cache_manager = BaseKVCacheManager(config)

    def _aibrix_kvcache_metrics_report(self):
        """Return the current metrics summary, resetting the counters.

        Fix: the original discarded ``summary()``'s return value and returned
        ``None``, so the caller's rate-limited log line always printed ``None``.
        """
        summary = self.kv_cache_manager.metrics.summary()
        self.kv_cache_manager.metrics.reset()
        return summary

    def batch_get(
        self,
        keys: List[str],
        target_locations: List[torch.Tensor],
        target_sizes: Optional[Any] = None,
    ) -> List[torch.Tensor | None]:
        """Fetch the pages for ``keys`` into ``target_locations`` (flattened).

        Returns ``target_locations`` on a full hit, ``[None] * len(keys)`` when
        the acquire fails. A partial hit trips the length assertion below —
        fetches are currently all-or-nothing.
        """
        block_hash = BlockHashes(keys, self.page_size)
        status = self.kv_cache_manager.acquire(None, block_hash)
        # The log line is throttled to once per second; note the report itself
        # (including the counter reset) still runs on every call because the
        # argument is evaluated eagerly.
        log_every_n_seconds(
            logger, logging.INFO, self._aibrix_kvcache_metrics_report(), 1
        )
        if status.is_ok():
            num_fetched_tokens, handle = status.value
            kv_blocks = handle.to_tensors()
            assert len(kv_blocks) == len(target_locations)
            for i in range(len(kv_blocks)):
                assert (
                    target_locations[i].nbytes == kv_blocks[i].nbytes
                ), f"{target_locations[i].nbytes}, {kv_blocks[i].nbytes}"
                target_locations[i].copy_(kv_blocks[i].flatten())
            # Done copying out: hand the blocks back to the cache manager.
            handle.release()
            return target_locations
        return [None] * len(keys)

    def get(
        self,
        key: str,
        target_location: Optional[Any] = None,
        target_size: Optional[Any] = None,
    ) -> torch.Tensor | None:
        """Single-key convenience wrapper around :meth:`batch_get`."""
        return self.batch_get([key], [target_location], [target_size])[0]

    def batch_set(
        self,
        keys: List[str],
        values: Optional[Any] = None,
        target_locations: Optional[Any] = None,
        target_sizes: Optional[Any] = None,
    ) -> bool:
        """Store one tensor per page key; True only if every token was committed."""
        block_hash = BlockHashes(keys, self.page_size)
        status = self.kv_cache_manager.allocate_for(None, block_hash)
        if not status.is_ok():
            logger.warning(
                f"aibrix_kvcache set allocate failed, error_code {status.error_code}"
            )
            return False
        handle = status.value
        tensors = handle.to_tensors()
        if len(tensors) != len(values):
            logger.warning("aibrix_kvcache set allocate not enough")
            # Fix: release the partially allocated handle instead of leaking it.
            handle.release()
            return False
        for i in range(len(tensors)):
            assert (
                tensors[i].nbytes == values[i].nbytes
            ), f"{tensors[i].nbytes}, {values[i].nbytes}"
            # copy_ writes through the reshaped view into the handle's buffer.
            # (A trailing no-op reshape whose result was discarded has been
            # removed from the original.)
            tensors[i].reshape(values[i].shape).copy_(values[i])
        # put() takes ownership of the handle and reports committed tokens.
        status = self.kv_cache_manager.put(None, block_hash, handle)
        if not status.is_ok():
            logger.info(
                f"AIBrix KVCache Storage set failed, error_code {status.error_code}"
            )
            return False
        completed = status.value
        return completed == len(keys) * self.page_size

    def set(
        self,
        key: str,
        value: Optional[Any] = None,
        target_location: Optional[Any] = None,
        target_size: Optional[Any] = None,
    ) -> bool:
        """Single-key convenience wrapper around :meth:`batch_set`."""
        return self.batch_set([key], [value], [target_location], [target_size])

    def batch_exists(self, keys: List[str]) -> int:
        """Return how many of the leading ``keys`` (pages) already exist."""
        block_hash = BlockHashes(keys, self.page_size)
        status = self.kv_cache_manager.exists(None, block_hash)
        if status.is_ok():
            # exists() counts tokens; convert back to whole pages.
            return status.value // self.page_size
        return 0

    def exists(self, key: str) -> bool | dict:
        """True if the single page ``key`` is present in the backend."""
        return self.batch_exists([key]) > 0
import logging
import os
import torch
import torch.distributed
from aibrix_kvcache import (
BaseKVCacheManager,
GroupAwareKVCacheManager,
KVCacheBlockLayout,
KVCacheBlockSpec,
KVCacheConfig,
KVCacheMetrics,
KVCacheTensorSpec,
ModelSpec,
TokenListView,
)
from aibrix_kvcache.common.absl_logging import getLogger, log_every_n_seconds, log_if
from aibrix_kvcache_storage import AibrixKVCacheStorage
from torch.distributed import Backend, ProcessGroup
from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
from sglang.srt.mem_cache.memory_pool_host import MHATokenToKVPoolHost
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
def setup():
    """Point torch.distributed at a single-process local rendezvous."""
    rendezvous_env = {
        "RANK": "0",
        "WORLD_SIZE": "1",
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": "63886",
    }
    for name, value in rendezvous_env.items():
        os.environ[name] = value
class AIBrixKVCacheStorageTest:
    """Smoke test: set/get/exists round-trips through AibrixKVCacheStorage."""

    def test_with_page_size(self):
        """Round-trip a full key batch and a partial-key subset for page sizes 1-2."""
        config = HiCacheStorageConfig(
            tp_rank=0,
            tp_size=1,
            is_mla_model=False,
            is_page_first_layout=True,
            model_name="test",
        )
        for page_size in range(1, 3):
            logger.info(f"page_size: {page_size}")
            batch_size = 2
            head_num = 1
            layer_num = 64
            head_dim = 128
            kv_cache = MHATokenToKVPool(
                1024,
                page_size,
                torch.float16,
                head_num,
                head_dim,
                layer_num,
                "cpu",
                False,
                0,
                layer_num,
            )
            mem_pool = MHATokenToKVPoolHost(kv_cache, 2, 0, page_size, "layer_first")
            query_length = batch_size * 2
            partial = batch_size
            self.aibrix_kvcache = AibrixKVCacheStorage(config, mem_pool)
            # One block per key: (2 for K/V, layers, tokens, heads, head_dim).
            target_shape = (2, layer_num, page_size, head_num, head_dim)
            rand_tensor = [
                torch.rand(target_shape, dtype=torch.float16)
                for _ in range(query_length)
            ]
            keys = ["hash" + str(i) for i in range(query_length)]
            partial_keys = keys[batch_size:query_length]
            assert self.aibrix_kvcache.batch_exists(keys) == 0
            assert self.aibrix_kvcache.batch_set(keys, rand_tensor)
            # Pre-fill destinations with random data so a no-op get is caught.
            get_tensor = [
                torch.rand(target_shape, dtype=torch.float16).flatten()
                for _ in range(query_length)
            ]
            self.aibrix_kvcache.batch_get(keys, get_tensor)
            for i in range(query_length):
                assert torch.equal(get_tensor[i], rand_tensor[i].flatten())
            # Fix: the original bound an unused `ret` from a redundant extra
            # batch_exists call; the dead call/variable have been removed.
            assert self.aibrix_kvcache.batch_exists(keys) == query_length
            assert self.aibrix_kvcache.batch_exists(partial_keys) == partial
            partial_get_tensor = [
                torch.rand(target_shape, dtype=torch.float16).flatten()
                for _ in range(partial)
            ]
            self.aibrix_kvcache.batch_get(partial_keys, partial_get_tensor)
            for i in range(partial):
                # partial_keys starts at index batch_size == partial, so the
                # i-th partial result must match rand_tensor[i + partial].
                assert torch.equal(
                    partial_get_tensor[i], rand_tensor[i + partial].flatten()
                )
            log_every_n_seconds(
                logger,
                logging.INFO,
                self.aibrix_kvcache.kv_cache_manager.metrics.summary(),
                1,
            )
if __name__ == "__main__":
    # Manual test entry point: set up the rendezvous env, then run the suite.
    setup()
    AIBrixKVCacheStorageTest().test_with_page_size()
...@@ -2154,7 +2154,7 @@ class ServerArgs: ...@@ -2154,7 +2154,7 @@ class ServerArgs:
parser.add_argument( parser.add_argument(
"--hicache-storage-backend", "--hicache-storage-backend",
type=str, type=str,
choices=["file", "mooncake", "hf3fs", "nixl"], choices=["file", "mooncake", "hf3fs", "nixl", "aibrix"],
default=ServerArgs.hicache_storage_backend, default=ServerArgs.hicache_storage_backend,
help="The storage backend for hierarchical KV cache.", help="The storage backend for hierarchical KV cache.",
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment