Unverified Commit aee30630 authored by Zhiqiang Xie, committed by GitHub

Add a pointer to the real KV cache pool (#4113)

parent 286e6540
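
The change in a nutshell: TokenToKVPoolAllocator now receives the KVCache it indexes and exposes it through get_kvcache(), so components that only hold the allocator can reach the physical pool. A minimal sketch of that relationship, using illustrative stand-in classes rather than the real SGLang ones:

    import torch


    class FakeKVCache:
        """Stand-in for the physical KV cache pool (e.g. MHATokenToKVPool)."""

        def __init__(self, size: int, head_dim: int):
            self.k_buffer = torch.zeros(size, head_dim)
            self.v_buffer = torch.zeros(size, head_dim)


    class FakeAllocator:
        """Manages indices into the KV cache and keeps a pointer to it."""

        def __init__(self, size: int, kvcache: FakeKVCache):
            self.free_slots = list(range(size))
            self._kvcache = kvcache  # the back-pointer this commit adds

        def get_kvcache(self) -> FakeKVCache:
            return self._kvcache


    pool = FakeKVCache(size=16, head_dim=8)
    allocator = FakeAllocator(size=16, kvcache=pool)
    # Code that only holds the allocator can now reach the physical pool:
    assert allocator.get_kvcache() is pool
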
......@@ -20,9 +20,8 @@ Memory pool.
SGLang has two levels of memory pool.
ReqToTokenPool maps a request to its token locations.
TokenToKVPoolAllocator maps a token location to its KV cache data.
KVCache actually holds the physical kv cache. Allocation indices are allocated
by TokenToKVPoolAllocator
TokenToKVPoolAllocator manages the indices to kv cache data.
KVCache actually holds the physical kv cache.
"""
import abc
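
As context for the docstring above, the two levels compose into a simple lookup chain: ReqToTokenPool maps a request to token locations, and those locations index the buffers held by KVCache, with TokenToKVPoolAllocator deciding which locations are free. A hedged illustration with made-up shapes and names (the real pools add paging, devices, and memory-saver support):

    import torch

    num_reqs, max_context_len, num_tokens, head_dim = 4, 32, 128, 8

    # Level 1: request index -> token locations (ReqToTokenPool).
    req_to_token = torch.zeros(num_reqs, max_context_len, dtype=torch.int64)

    # Level 2: token location -> physical KV data (KVCache); the
    # TokenToKVPoolAllocator manages which of these locations are free.
    k_buffer = torch.zeros(num_tokens, head_dim)
    v_buffer = torch.zeros(num_tokens, head_dim)

    # Suppose request 0 has 3 tokens stored at locations 5, 6, 7:
    req_to_token[0, :3] = torch.tensor([5, 6, 7])
    locs = req_to_token[0, :3]
    k_of_req0 = k_buffer[locs]   # gather the physical keys for request 0
    print(k_of_req0.shape)       # torch.Size([3, 8])
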
......@@ -92,14 +91,40 @@ class ReqToTokenPool:
self.free_slots = list(range(self.size))
class KVCache(abc.ABC):
@abc.abstractmethod
def get_key_buffer(self, layer_id: int) -> torch.Tensor:
raise NotImplementedError()
@abc.abstractmethod
def get_value_buffer(self, layer_id: int) -> torch.Tensor:
raise NotImplementedError()
@abc.abstractmethod
def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError()
@abc.abstractmethod
def set_kv_buffer(
self,
layer: RadixAttention,
loc: torch.Tensor,
cache_k: torch.Tensor,
cache_v: torch.Tensor,
) -> None:
raise NotImplementedError()
class TokenToKVPoolAllocator:
"""A memory pool that maps a token location to its kv cache data."""
"""An allocator managing the indices to kv cache data."""
def __init__(
self,
size: int,
dtype: torch.dtype,
device: str,
kvcache: KVCache,
):
self.size = size
self.dtype = dtype
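
A minimal, hypothetical subclass sketch of the KVCache interface added above; buffer shapes are illustrative, and a plain layer_id replaces the RadixAttention argument of set_kv_buffer to keep the sketch self-contained:

    from typing import List, Tuple

    import torch


    class ToyKVCache:
        def __init__(self, size: int, num_layers: int, head_num: int, head_dim: int):
            self.k_buffers: List[torch.Tensor] = [
                torch.zeros(size, head_num, head_dim) for _ in range(num_layers)
            ]
            self.v_buffers: List[torch.Tensor] = [
                torch.zeros(size, head_num, head_dim) for _ in range(num_layers)
            ]

        def get_key_buffer(self, layer_id: int) -> torch.Tensor:
            return self.k_buffers[layer_id]

        def get_value_buffer(self, layer_id: int) -> torch.Tensor:
            return self.v_buffers[layer_id]

        def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
            return self.k_buffers[layer_id], self.v_buffers[layer_id]

        def set_kv_buffer(
            self,
            layer_id: int,  # the real interface takes a RadixAttention layer here
            loc: torch.Tensor,
            cache_k: torch.Tensor,
            cache_v: torch.Tensor,
        ) -> None:
            self.k_buffers[layer_id][loc] = cache_k
            self.v_buffers[layer_id][loc] = cache_v
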
......@@ -110,9 +135,14 @@ class TokenToKVPoolAllocator:
self.free_group = []
self.clear()
self._kvcache = kvcache
def available_size(self):
return len(self.free_slots)
def get_kvcache(self):
return self._kvcache
def alloc(self, need_size: int):
if need_size > len(self.free_slots):
return None
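
A tiny self-contained sketch of the allocation contract visible above: alloc draws from a free list and returns None when it cannot satisfy the request, so callers must check the result before writing to the returned indices (the real allocator also batches frees via free_group):

    free_slots = list(range(8))


    def alloc(need_size: int):
        global free_slots
        if need_size > len(free_slots):
            return None
        out, free_slots = free_slots[:need_size], free_slots[need_size:]
        return out


    print(alloc(5))          # [0, 1, 2, 3, 4]
    print(alloc(5))          # None: only 3 slots remain
    print(len(free_slots))   # 3, i.e. what available_size() would report
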
......@@ -147,31 +177,6 @@ class TokenToKVPoolAllocator:
self.free_group = []
class KVCache(abc.ABC):
@abc.abstractmethod
def get_key_buffer(self, layer_id: int) -> torch.Tensor:
raise NotImplementedError()
@abc.abstractmethod
def get_value_buffer(self, layer_id: int) -> torch.Tensor:
raise NotImplementedError()
@abc.abstractmethod
def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError()
@abc.abstractmethod
def set_kv_buffer(
self,
layer: RadixAttention,
loc: torch.Tensor,
cache_k: torch.Tensor,
cache_v: torch.Tensor,
) -> None:
raise NotImplementedError()
class MHATokenToKVPool(KVCache):
def __init__(
......
......@@ -710,15 +710,6 @@ class ModelRunner:
# Draft worker shares req_to_token_pool with the target worker.
assert self.is_draft_worker
if self.token_to_kv_pool_allocator is None:
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
self.max_total_num_tokens,
dtype=self.kv_cache_dtype,
device=self.device,
)
else:
assert self.is_draft_worker
if (
self.model_config.attention_arch == AttentionArch.MLA
and not self.server_args.disable_mla
......@@ -753,6 +744,17 @@ class ModelRunner:
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
)
if self.token_to_kv_pool_allocator is None:
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
self.max_total_num_tokens,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
)
else:
assert self.is_draft_worker
logger.info(
f"Memory pool end. "
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
......
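
A self-contained sketch of the new construction order in ModelRunner, using stand-in classes (the real code first picks an MHA or MLA pool variant, and draft workers reuse the target worker's allocator instead of building a new one):

    class StubKVCache:
        def __init__(self, size: int):
            self.size = size


    class StubAllocator:
        def __init__(self, size: int, kvcache: "StubKVCache"):
            self.size = size
            self._kvcache = kvcache


    max_total_num_tokens = 1024

    # The physical pool is now created first ...
    token_to_kv_pool = StubKVCache(max_total_num_tokens)

    # ... and the allocator afterwards, so the pool can be passed in.
    token_to_kv_pool_allocator = StubAllocator(
        max_total_num_tokens, kvcache=token_to_kv_pool
    )
    assert token_to_kv_pool_allocator._kvcache is token_to_kv_pool
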