Unverified commit 0418b9d4, authored by YiXR, committed by GitHub
Browse files

[Optimization] Update estimated_num_new_pages logic in TokenToKVPoolAllocator (#8794)


Signed-off-by: Xingrui Yi <yixingrui@linux.alibaba.com>
Co-authored-by: Xingrui Yi <yixingrui@linux.alibaba.com>
parent e322a94d
......@@ -43,12 +43,14 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
dtype: torch.dtype,
device: str,
kvcache: KVCache,
need_sort: bool,
):
self.size = size
self.page_size = page_size
self.dtype = dtype
self.device = device
self._kvcache = kvcache
self.need_sort = need_sort
self.free_pages = None
self.release_pages = None
......@@ -79,6 +81,9 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
if self.free_group:
self.free(torch.cat(self.free_group))
def estimated_num_new_pages(self, bs, extend_num_tokens):
    """Cheap upper-bound estimate of KV-cache pages needed by a batch.

    Rounds ``extend_num_tokens`` up to a whole number of pages and scales
    by the batch size ``bs``. Callers compare this over-approximation
    against the free-page count to decide whether the released pages must
    be merged and sorted back into the free pool before allocating.
    """
    # -(-a // b) is ceiling division for a positive page size.
    pages_ceil = -(-extend_num_tokens // self.page_size)
    return bs * pages_ceil
def merge_and_sort_free(self):
if len(self.release_pages) > 0:
self.free_pages = torch.cat((self.free_pages, self.release_pages))
......@@ -117,8 +122,15 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
"""An allocator managing the indices to kv cache data."""
def __init__(self, size: int, dtype: torch.dtype, device: str, kvcache: KVCache):
super().__init__(size, 1, dtype, device, kvcache)
def __init__(
self,
size: int,
dtype: torch.dtype,
device: str,
kvcache: KVCache,
need_sort: bool,
):
super().__init__(size, 1, dtype, device, kvcache, need_sort)
self.clear()
def clear(self):
......@@ -135,7 +147,7 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
return len(self.free_pages) + len(self.release_pages)
def alloc(self, need_size: int):
if need_size > len(self.free_pages):
if self.need_sort and need_size > len(self.free_pages):
self.merge_and_sort_free()
if need_size > len(self.free_pages):
return None
......@@ -149,7 +161,10 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
return
if self.is_not_in_free_group:
self.release_pages = torch.cat((self.release_pages, free_index))
if self.need_sort:
self.release_pages = torch.cat((self.release_pages, free_index))
else:
self.free_pages = torch.cat((self.free_pages, free_index))
else:
self.free_group.append(free_index)
......@@ -170,8 +185,9 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
dtype: torch.dtype,
device: str,
kvcache: SWAKVPool,
need_sort: bool,
):
super().__init__(size, 1, dtype, device, kvcache)
super().__init__(size, 1, dtype, device, kvcache, need_sort)
assert isinstance(kvcache, SWAKVPool)
self._size_full = size
self._size_swa = size_swa
......@@ -180,12 +196,14 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
dtype,
device,
kvcache.full_kv_pool,
need_sort,
)
self.swa_attn_allocator = TokenToKVPoolAllocator(
size_swa,
dtype,
device,
kvcache.swa_kv_pool,
need_sort,
)
self.full_to_swa_index_mapping = torch.empty(
size + size_swa + 1,
......@@ -418,8 +436,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
dtype: torch.dtype,
device: str,
kvcache: KVCache,
need_sort: bool,
):
super().__init__(size, page_size, dtype, device, kvcache)
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
self.num_pages = size // page_size
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
......@@ -433,7 +452,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
), "The allocation size should be page-aligned"
num_pages = need_size // self.page_size
if num_pages > len(self.free_pages):
if self.need_sort and num_pages > len(self.free_pages):
self.merge_and_sort_free()
if num_pages > len(self.free_pages):
return None
......@@ -460,18 +479,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
)
estimated_num_new_pages = (
(
(seq_lens + self.page_size - 1) // self.page_size
- (prefix_lens + self.page_size - 1) // self.page_size
)
.sum()
.item()
)
if estimated_num_new_pages > len(self.free_pages):
bs = len(prefix_lens)
if self.need_sort and self.estimated_num_new_pages(bs, extend_num_tokens) > len(
self.free_pages
):
self.merge_and_sort_free()
bs = len(prefix_lens)
out_indices = torch.empty(
(extend_num_tokens,), dtype=torch.int64, device=self.device
)
......@@ -508,18 +521,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
(last_loc + 2) % self.page_size == seq_lens % self.page_size
)
estimated_num_new_pages = (
(
(seq_lens + self.page_size - 1) // self.page_size
- (seq_lens - 1 + self.page_size - 1) // self.page_size
)
.sum()
.item()
)
if estimated_num_new_pages > len(self.free_pages):
bs = len(seq_lens)
if self.need_sort and self.estimated_num_new_pages(bs, 1) > len(
self.free_pages
):
self.merge_and_sort_free()
bs = len(seq_lens)
out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
alloc_decode_kernel[(bs,)](
seq_lens,
......@@ -547,7 +554,10 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
if self.is_not_in_free_group:
free_page_indices = torch.unique(free_index // self.page_size)
self.release_pages = torch.cat((free_page_indices, self.release_pages))
if self.need_sort:
self.release_pages = torch.cat((free_page_indices, self.release_pages))
else:
self.free_pages = torch.cat((free_page_indices, self.free_pages))
else:
self.free_group.append(free_index)
......@@ -654,8 +664,9 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
dtype: torch.dtype,
device: str,
kvcache: KVCache,
need_sort: bool,
):
super().__init__(size, page_size, dtype, device, kvcache)
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
self.ret_values = torch.empty((), dtype=torch.int32, device=self.device)
def alloc_extend(
......@@ -670,18 +681,12 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
)
estimated_num_new_pages = (
(
(seq_lens + self.page_size - 1) // self.page_size
- (prefix_lens + self.page_size - 1) // self.page_size
)
.sum()
.item()
)
if estimated_num_new_pages > len(self.free_pages):
bs = len(prefix_lens)
if self.need_sort and self.estimated_num_new_pages(bs, extend_num_tokens) > len(
self.free_pages
):
self.merge_and_sort_free()
bs = len(prefix_lens)
out_indices = torch.empty(
(extend_num_tokens,), dtype=torch.int32, device=self.device
)
......@@ -716,18 +721,12 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
(last_loc + 2) % self.page_size == seq_lens % self.page_size
)
estimated_num_new_pages = (
(
(seq_lens + self.page_size - 1) // self.page_size
- (seq_lens - 1 + self.page_size - 1) // self.page_size
)
.sum()
.item()
)
if estimated_num_new_pages > len(self.free_pages):
bs = len(seq_lens)
if self.need_sort and self.estimated_num_new_pages(bs, 1) > len(
self.free_pages
):
self.merge_and_sort_free()
bs = len(seq_lens)
out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
self.ret_values = alloc_decode_kernel_ascend(
......
......@@ -1300,6 +1300,8 @@ class ModelRunner:
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=self.server_args.disaggregation_mode
in ("decode", "prefill"),
)
else:
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
......@@ -1307,6 +1309,8 @@ class ModelRunner:
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=self.server_args.disaggregation_mode
in ("decode", "prefill"),
)
else:
if _is_npu:
......@@ -1316,6 +1320,8 @@ class ModelRunner:
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=self.server_args.disaggregation_mode
in ("decode", "prefill"),
)
else:
self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
......@@ -1324,6 +1330,8 @@ class ModelRunner:
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=self.server_args.disaggregation_mode
in ("decode", "prefill"),
)
else:
assert self.is_draft_worker
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.