Unverified Commit 9a7ced4e authored by Yuwei An, committed by GitHub

[Feature] LMCache Connector Integration (#9741)


Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
Signed-off-by: YuhanLiu11 <yliu738@wisc.edu>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent cb3918a0
@@ -656,6 +656,21 @@ class Scheduler(
page_size=self.page_size,
disable=server_args.disable_radix_cache,
)
elif server_args.enable_lmcache:
from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
LMCRadixCache,
)
self.tree_cache = LMCRadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
page_size=self.page_size,
disable=server_args.disable_radix_cache,
model_config=self.model_config,
tp_size=self.tp_size,
rank=self.tp_rank,
tp_group=self.tp_group,
)
else:
self.tree_cache = RadixCache(
req_to_token_pool=self.req_to_token_pool,
@@ -1411,9 +1426,11 @@ class Scheduler(
_, _, available_size, evictable_size = self._get_token_info()
protected_size = self.tree_cache.protected_size()
memory_leak = (available_size + evictable_size) != (
# self.max_total_num_tokens
# if not self.enable_hierarchical_cache
# else self.max_total_num_tokens - protected_size
self.max_total_num_tokens
- protected_size
)
token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
...
@@ -369,7 +369,6 @@ class MHATokenToKVPool(KVCache):
# same applies to get_value_buffer and get_kv_buffer
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
return self._get_key_buffer(layer_id)
def _get_value_buffer(self, layer_id: int):
...
# LMCache Connector for SGLang
This document describes how to use LMCache as the KV cache management backend for the SGLang engine.
For more details about LMCache, please refer to https://lmcache.ai.
## Install LMCache
### Method 1: with pip
```bash
pip install lmcache
```
### Method 2: from source
Clone the LMCache project:
```bash
git clone https://github.com/LMCache/LMCache
```
Install:
```bash
cd LMCache
pip install -e . --no-build-isolation
```
## Use LMCache
First, set up the LMCache configuration. An example config is provided in `example_config.yaml`; for more settings, please refer to https://docs.lmcache.ai/api_reference/configurations.html.
Then launch the SGLang serving engine with LMCache enabled:
```bash
export LMCACHE_USE_EXPERIMENTAL=True
export LMCACHE_CONFIG_FILE=example_config.yaml
python -m sglang.launch_server \
--model-path MODEL \
--enable-lmcache
```
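To verify the setup, send a regular generation request to the running server. Below is a minimal sketch, assuming the server launched above listens on SGLang's default port 30000 and is queried through the native `/generate` endpoint; repeated prompts that share a prefix should hit KV reuse through LMCache:
```python
import requests

# Minimal smoke test (illustrative): send the same prompt twice; the second
# request should reuse the KV cache of the shared prefix through LMCache.
# Assumes the server above is running on SGLang's default port 30000.
prompt = "Explain what a KV cache is in one paragraph."

for attempt in range(2):
    response = requests.post(
        "http://localhost:30000/generate",
        json={
            "text": prompt,
            "sampling_params": {"temperature": 0, "max_new_tokens": 64},
        },
    )
    print(f"attempt {attempt}:", response.json()["text"][:80], "...")
```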
# Basic configurations
chunk_size: 256
# CPU offloading configurations
local_cpu: true
use_layerwise: true
max_local_cpu_size: 10 # size of the CPU offloading buffer, in GB
from __future__ import annotations
import logging
import threading
from typing import TYPE_CHECKING, List, Optional
import torch
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
try:
from lmcache.integration.sglang.sglang_adapter import (
LMCacheLayerwiseConnector,
LoadMetadata,
StoreMetadata,
)
except ImportError as e:
raise RuntimeError(
"LMCache is not installed. Please install it by running `pip install lmcache`"
) from e
if TYPE_CHECKING:
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.managers.schedule_batch import Req
logger = logging.getLogger(__name__)
class LayerTransferCounter:
"""Minimal adapter that lets the memory pool notify LMCache per-layer.
The KV pool calls `wait_until(layer_id)` right before handing out a layer's KV
buffers, which we translate into a `load_kv_layerwise(layer_id)` call on the
LMCache connector within the provided CUDA stream.
"""
def __init__(
self,
num_layers: int,
load_stream: torch.cuda.Stream,
lmc_connector: LMCacheLayerwiseConnector,
printable: bool = False,
):
self.num_layers = num_layers
self.load_stream = load_stream
self.lmc_connector = lmc_connector
def wait_until(self, layer_id: int):
# Ensure ordering of the async loads wrt compute stream(s).
self.load_stream.synchronize()
with self.load_stream:
self.lmc_connector.load_kv_layerwise(layer_id)
class LMCRadixCache(RadixCache):
"""RadixCache + LMCache IO.
This subclass adds:
- LMCache connector setup (device/host buffers, TP rank/size)
- Two CUDA streams for async load/store
- Layer-wise transfer executor wiring to the KV cache
- Overridden `match_prefix` to fetch missing prefix chunks from LMCache
- Extended cache_finalization paths to store back into LMCache
- Eviction barrier that respects any in-flight host->device stores
"""
def __init__(
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
page_size: int,
disable: bool = False,
enable_kv_cache_events: bool = False,
model_config: Optional["ModelConfig"] = None,
tp_size: int = 1,
rank: int = 0,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
super().__init__(
req_to_token_pool=req_to_token_pool,
token_to_kv_pool_allocator=token_to_kv_pool_allocator,
page_size=page_size,
disable=disable,
enable_kv_cache_events=enable_kv_cache_events,
)
kvcache = self.token_to_kv_pool_allocator.get_kvcache()
self.lmcache_connector = LMCacheLayerwiseConnector(
sgl_config=model_config,
tp_size=tp_size,
rank=rank,
# NOTE: The original implementation reached into the allocator's private
# `_kvcache.k_buffer` / `.v_buffer`; here we go through the public
# `get_kvcache()` accessor above and read its buffers directly.
k_pool=kvcache.k_buffer,
v_pool=kvcache.v_buffer,
tp_group=tp_group,
)
self.load_stream = torch.cuda.Stream()
self.store_stream = torch.cuda.Stream()
self.layer_done_executor = LayerTransferCounter(
num_layers=(
model_config.num_hidden_layers if model_config is not None else 0
),
load_stream=self.load_stream,
lmc_connector=self.lmcache_connector,
)
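# Register with the KV pool so that `get_key_buffer()` / `get_value_buffer()`
# (see MHATokenToKVPool) wait for the per-layer LMCache load on `load_stream`
# before attention reads that layer.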
kvcache.register_layer_transfer_counter(self.layer_done_executor)
self._in_flight_nodes: list[TreeNode] = []
self._node_lock = threading.Lock()
def reset(self): # type: ignore[override]
super().reset()
if hasattr(self, "_in_flight_nodes"):
with self._node_lock:
self._in_flight_nodes.clear()
def match_prefix(self, key: List[int], **kwargs) -> MatchResult: # type: ignore[override]
"""Match cached prefix; if there's a tail miss, prefetch from LMCache.
Reuses the base matching logic to obtain (value, last_node). If there
remains a *page-aligned* uncached suffix and there is room (or after
eviction), we allocate token slots and trigger an async LMCache load
into those slots, then materialize a new child node for the retrieved
chunk.
"""
if self.disable or not key:
return super().match_prefix(key, **kwargs)
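# Truncate the key to a page-aligned length so partial pages are never
# matched or fetched.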
if self.page_size != 1:
aligned_len = len(key) // self.page_size * self.page_size
key = key[:aligned_len]
base_res = super().match_prefix(key, **kwargs)
value: torch.Tensor = base_res.device_indices
last_node: TreeNode = base_res.last_device_node
if value.numel() == len(key):
return base_res
uncached_len = len(key) - value.numel()
if uncached_len == 0:
return base_res
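# LMCache stores KV in fixed-size chunks, so retrieval restarts from the last
# chunk boundary inside the already-matched prefix; `prefix_pad` counts the
# matched tokens past that boundary.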
chunk_size = self.lmcache_connector.chunk_size()
prefix_pad = value.numel() % chunk_size
if self.token_to_kv_pool_allocator.available_size() < uncached_len:
self.evict(uncached_len)
token_slots = self.token_to_kv_pool_allocator.alloc(uncached_len)
if token_slots is None:
return base_res
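# Slot mapping for the full page-aligned key: -1 marks positions whose KV is
# already resident on device, followed by the freshly allocated slots for the
# uncached tail.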
slot_mapping = torch.cat(
[
torch.full((value.numel(),), -1, dtype=torch.int64, device=self.device),
token_slots.detach().clone().to(torch.int64).to(self.device),
]
)
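# Issue the LMCache retrieval asynchronously on the dedicated load stream;
# `num_retrieved` counts tokens returned starting from the chunk-aligned offset.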
with torch.cuda.stream(self.load_stream):
num_retrieved = self.lmcache_connector.start_load_kv(
LoadMetadata(
token_ids=key, # full page-aligned key
slot_mapping=slot_mapping,
offset=value.numel() - prefix_pad, # LMCache offset convention
)
)
logger.debug("num_retrieved_tokens: %s", num_retrieved)
if num_retrieved > 0:
self.token_to_kv_pool_allocator.free(
token_slots[(num_retrieved - prefix_pad) :]
)
else:
self.token_to_kv_pool_allocator.free(token_slots)
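# Graft the retrieved chunk into the radix tree as a new child of the matched
# node and extend the returned device indices accordingly.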
if num_retrieved > 0:
fetched = num_retrieved - prefix_pad
new_node = TreeNode()
start = value.numel()
end = start + fetched
new_node.key = key[start:end]
new_node.value = token_slots[:fetched]
new_node.parent = last_node
last_node.children[self.get_child_key_fn(new_node.key)] = new_node
last_node = new_node
value = torch.cat([value, token_slots[:fetched]])
self.evictable_size_ += fetched
self._record_store_event(new_node.parent)
self._record_store_event(new_node)
return MatchResult(
device_indices=value,
last_device_node=last_node,
last_host_node=last_node,
)
return base_res
def cache_finished_req(self, req: "Req") -> None: # type: ignore[override]
"""On request completion, insert device KV into radix and store to LMCache."""
super().cache_finished_req(req)
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
]
new_last_node = self.match_prefix(token_ids).last_device_node
assert new_last_node is not None
self.inc_lock_ref(new_last_node)
store_md = StoreMetadata(
last_node=new_last_node,
token_ids=token_ids,
kv_indices=kv_indices,
offset=0,
)
with torch.cuda.stream(self.store_stream):
self.lmcache_connector.store_kv(store_md)
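# Keep the node locked while the store is in flight; evict() synchronizes
# `store_stream` and releases these locks before evicting.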
with self._node_lock:
self._in_flight_nodes.append(new_last_node)
def evict(self, num_tokens: int) -> None: # type: ignore[override]
"""Before base eviction, wait for any outstanding stores and release locks."""
if self.disable:
return
self.store_stream.synchronize()
with self._node_lock:
for node in self._in_flight_nodes:
self.dec_lock_ref(node)
self._in_flight_nodes.clear()
super().evict(num_tokens)
def pretty_print(self): # type: ignore[override]
super().pretty_print()
try:
logger.debug(
"evictable=%d protected=%d", self.evictable_size_, self.protected_size_
)
except Exception: # pragma: no cover
pass
if __name__ == "__main__":
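# Ad-hoc smoke test mirroring RadixCache's debug main. Note that LMCRadixCache
# requires a real ReqToTokenPool / KV pool allocator and a configured LMCache
# backend; with `None` pools the constructor will fail, so this block is only a
# sketch of the intended usage.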
cache = LMCRadixCache(
req_to_token_pool=None,
token_to_kv_pool_allocator=None,
page_size=1,
disable=False,
enable_kv_cache_events=False,
model_config=None,
tp_size=1,
rank=0,
tp_group=None,
)
cache.insert([1, 2, 3], torch.tensor([10, 11, 12], dtype=torch.int64))
cache.insert([1, 2, 3, 4], torch.tensor([10, 11, 12, 13], dtype=torch.int64))
cache.pretty_print()
try:
from lmcache.integration.sglang.sglang_adapter import (
LMCacheLayerwiseConnector,
LoadMetadata,
StoreMetadata,
)
except ImportError as e:
raise RuntimeError(
"LMCache is not installed. Please install it by running `pip install lmcache`, "
"or `pip install -e .` from the LMCache source root."
) from e
import os
import torch
from sglang.srt.configs.model_config import ModelConfig
os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
os.environ["LMCACHE_CONFIG_FILE"] = "example_config.yaml"
def test_load_store_metadata():
model_config = ModelConfig(
model_path="Qwen/Qwen3-4B",
)
# Generate Dummy KV Cache
head_num = model_config.num_key_value_heads
head_dim = model_config.head_dim
layer_num = model_config.num_hidden_layers
buffer_size = 256
input_id_len = 16
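# Dummy per-layer KV pools: each layer gets a [buffer_size, head_num, head_dim]
# key and value tensor, mimicking the layout of SGLang's MHA token-to-KV pool.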
k_buffer = [
torch.randn(buffer_size, head_num, head_dim, dtype=torch.bfloat16).cuda()
for _ in range(layer_num)
]
v_buffer = [
torch.randn(buffer_size, head_num, head_dim, dtype=torch.bfloat16).cuda()
for _ in range(layer_num)
]
connector = LMCacheLayerwiseConnector(model_config, 1, 0, k_buffer, v_buffer)
fake_token_ids = torch.randint(0, model_config.vocab_size, (input_id_len,)).tolist()
fake_kv_indices = torch.randint(0, buffer_size, (input_id_len,))
offset = 0
store_metadata = StoreMetadata(
last_node=None,
token_ids=fake_token_ids,
kv_indices=fake_kv_indices,
offset=offset,
)
load_metadata = LoadMetadata(
token_ids=fake_token_ids,
slot_mapping=fake_kv_indices,
offset=offset,
)
current_stream = torch.cuda.current_stream()
retrieve_token_num = connector.start_load_kv(load_metadata)
assert retrieve_token_num == 0
connector.store_kv(store_metadata)
current_stream.synchronize()
# check retrieve
gt_key_buffer = [
torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
for _ in range(layer_num)
]
gt_value_buffer = [
torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
for _ in range(layer_num)
]
for i in range(layer_num):
gt_key_buffer[i] = k_buffer[i][fake_kv_indices]
gt_value_buffer[i] = v_buffer[i][fake_kv_indices]
# clear the k_buffer and v_buffer
for i in range(layer_num):
k_buffer[i].zero_()
v_buffer[i].zero_()
retrieve_token_num = connector.start_load_kv(load_metadata)
assert retrieve_token_num == input_id_len
for i in range(layer_num):
current_stream.synchronize()
connector.load_kv_layerwise(i)
current_stream.synchronize()
test_key_buffer = [
torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
for _ in range(layer_num)
]
test_value_buffer = [
torch.zeros(input_id_len, head_num, head_dim, dtype=torch.bfloat16).cuda()
for _ in range(layer_num)
]
for i in range(layer_num):
test_key_buffer[i] = k_buffer[i][fake_kv_indices]
test_value_buffer[i] = v_buffer[i][fake_kv_indices]
for i in range(layer_num):
assert torch.allclose(test_key_buffer[i], gt_key_buffer[i])
assert torch.allclose(test_value_buffer[i], gt_value_buffer[i])
print("================================================")
print("TEST_LOAD_STORE_METADATA PASSED!")
print("================================================")
connector.close()
if __name__ == "__main__":
test_load_store_metadata()
@@ -303,6 +303,8 @@ class ServerArgs:
hicache_storage_backend: Optional[str] = None
hicache_storage_prefetch_policy: str = "best_effort"
hicache_storage_backend_extra_config: Optional[str] = None
# LMCache
enable_lmcache: bool = False
# Double Sparsity
enable_double_sparsity: bool = False
@@ -1735,6 +1737,12 @@ class ServerArgs:
default=ServerArgs.hicache_storage_backend_extra_config,
help="A dictionary in JSON string format containing extra configuration for the storage backend.",
)
# LMCache
parser.add_argument(
"--enable-lmcache",
action="store_true",
help="Using LMCache as an alternative hierarchical cache solution",
)
# Double Sparsity
parser.add_argument(
...