Unverified Commit aee30630 authored by Zhiqiang Xie and committed by GitHub

Add a pointer to the real KV cache pool (#4113)

parent 286e6540
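This commit gives TokenToKVPoolAllocator a back-pointer to the physical KV cache it allocates slots for: the allocator keeps managing free indices, but callers that only hold the allocator can now reach the real KVCache through a new get_kvcache() accessor, and ModelRunner passes the pool in at construction time. The sketch below illustrates the pattern with hypothetical stand-in classes (FakeKVCache and FakeAllocator are not SGLang names), assuming a flat [size, head_dim] layout purely for demonstration.

# Minimal sketch of the pattern this commit introduces (stand-in classes,
# not the actual SGLang implementation).
import torch


class FakeKVCache:
    """Stand-in for the physical KV storage (hypothetical, illustration only)."""

    def __init__(self, size: int, head_dim: int):
        self.k = torch.zeros(size, head_dim)
        self.v = torch.zeros(size, head_dim)


class FakeAllocator:
    """Stand-in for TokenToKVPoolAllocator: it manages indices, not data."""

    def __init__(self, size: int, kvcache: FakeKVCache):
        self.free_slots = list(range(size))
        self._kvcache = kvcache  # the new back-pointer to the physical pool

    def get_kvcache(self) -> FakeKVCache:
        return self._kvcache

    def alloc(self, need_size: int):
        # Hand out `need_size` free slot indices, or None if not enough remain.
        if need_size > len(self.free_slots):
            return None
        out, self.free_slots = self.free_slots[:need_size], self.free_slots[need_size:]
        return out


cache = FakeKVCache(size=16, head_dim=8)
allocator = FakeAllocator(size=16, kvcache=cache)
slots = allocator.alloc(4)                # indices into the physical pool
allocator.get_kvcache().k[slots] = 1.0    # reach the real data via the allocator

The diff below first registers the abstract KVCache interface above the allocator in the memory-pool module, then wires the pointer through ModelRunner.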
@@ -20,9 +20,8 @@ Memory pool.
 
 SGLang has two levels of memory pool.
 ReqToTokenPool maps a a request to its token locations.
-TokenToKVPoolAllocator maps a token location to its KV cache data.
-KVCache actually holds the physical kv cache. Allocation indices are allocated
-by TokenToKVPoolAllocator
+TokenToKVPoolAllocator manages the indices to kv cache data.
+KVCache actually holds the physical kv cache.
 """
 
 import abc
@@ -92,14 +91,40 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))
 
 
+class KVCache(abc.ABC):
+    @abc.abstractmethod
+    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def set_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ) -> None:
+        raise NotImplementedError()
+
+
 class TokenToKVPoolAllocator:
-    """A memory pool that maps a token location to its kv cache data."""
+    """An allocator managing the indices to kv cache data."""
 
     def __init__(
         self,
         size: int,
         dtype: torch.dtype,
         device: str,
+        kvcache: KVCache,
     ):
         self.size = size
         self.dtype = dtype
@@ -110,9 +135,14 @@ class TokenToKVPoolAllocator:
         self.free_group = []
         self.clear()
 
+        self._kvcache = kvcache
+
     def available_size(self):
         return len(self.free_slots)
 
+    def get_kvcache(self):
+        return self._kvcache
+
     def alloc(self, need_size: int):
         if need_size > len(self.free_slots):
             return None
@@ -147,31 +177,6 @@ class TokenToKVPoolAllocator:
         self.free_group = []
 
 
-class KVCache(abc.ABC):
-    @abc.abstractmethod
-    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def set_kv_buffer(
-        self,
-        layer: RadixAttention,
-        loc: torch.Tensor,
-        cache_k: torch.Tensor,
-        cache_v: torch.Tensor,
-    ) -> None:
-        raise NotImplementedError()
-
-
 class MHATokenToKVPool(KVCache):
     def __init__(
...
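With the accessor in place, code that receives only the allocator can still reach the per-layer buffers. A hypothetical call site follows; only get_kvcache() and get_kv_buffer() come from the diff above, while `allocator` and `layer_id` are assumed to already exist in the caller.

# Hypothetical consumer: `allocator` is a TokenToKVPoolAllocator instance and
# `layer_id` an int; both are assumed to be in scope.
kv_cache = allocator.get_kvcache()                     # the physical KVCache pool
k_buffer, v_buffer = kv_cache.get_kv_buffer(layer_id)  # per-layer K/V tensors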
@@ -710,15 +710,6 @@ class ModelRunner:
             # Draft worker shares req_to_token_pool with the target worker.
             assert self.is_draft_worker
 
-        if self.token_to_kv_pool_allocator is None:
-            self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
-                self.max_total_num_tokens,
-                dtype=self.kv_cache_dtype,
-                device=self.device,
-            )
-        else:
-            assert self.is_draft_worker
-
         if (
             self.model_config.attention_arch == AttentionArch.MLA
             and not self.server_args.disable_mla
@@ -753,6 +744,17 @@ class ModelRunner:
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
             )
+
+        if self.token_to_kv_pool_allocator is None:
+            self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+                self.max_total_num_tokens,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+                kvcache=self.token_to_kv_pool,
+            )
+        else:
+            assert self.is_draft_worker
+
         logger.info(
             f"Memory pool end. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
...
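The two ModelRunner hunks are the ordering consequence of the new constructor argument: the allocator used to be created before the physical pool, but it now needs kvcache=self.token_to_kv_pool, so its construction moves below the MHATokenToKVPool (or MLA-variant) setup, while draft workers still take the else branch and assert that they reuse the target worker's allocator. A condensed, commented sketch of the new initialization order follows; it assumes only the attributes visible in the diff above and is not the complete ModelRunner initialization code.

# Simplified ordering sketch, assuming the ModelRunner attributes shown above.
# 1) Build the physical pool first (MHATokenToKVPool or the MLA variant).
# self.token_to_kv_pool = MHATokenToKVPool(...)
# 2) Only then build the allocator, handing it the pool it indexes into.
# self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
#     self.max_total_num_tokens,
#     dtype=self.kv_cache_dtype,
#     device=self.device,
#     kvcache=self.token_to_kv_pool,
# )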