Unverified commit 0418b9d4, authored by YiXR, committed by GitHub

[Optimization] Update estimated_num_new_pages logic in TokenToKVPoolAllocator (#8794)

Signed-off-by: Xingrui Yi <yixingrui@linux.alibaba.com>
Co-authored-by: Xingrui Yi <yixingrui@linux.alibaba.com>
parent e322a94d
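Taken together, the hunks below make two changes: every allocator constructor now takes a need_sort flag, so the merge-and-sort path for released pages is only exercised when it is actually needed (prefill/decode disaggregation), and the tensor-based estimated_num_new_pages computation in alloc_extend/alloc_decode is replaced by a scalar upper bound built from bs, extend_num_tokens, and page_size. The sketch below contrasts the two estimates with made-up numbers; it is illustrative only and uses plain Python in place of the real torch tensors.

import math

page_size = 16
prefix_lens = [10, 40, 7]    # tokens already cached per request (illustrative)
seq_lens = [25, 50, 30]      # total tokens per request after the extend step
extend_num_tokens = sum(s - p for s, p in zip(seq_lens, prefix_lens))  # 48
bs = len(prefix_lens)

# Old logic: exact number of newly opened pages, summed per sequence
# (in the real code this ran on torch tensors and ended with .sum().item()).
old_estimate = sum(
    math.ceil(s / page_size) - math.ceil(p / page_size)
    for s, p in zip(seq_lens, prefix_lens)
)  # 3

# New logic (estimated_num_new_pages): a coarser bound from host-side scalars only.
new_estimate = bs * ((extend_num_tokens + page_size - 1) // page_size)  # 3 * 3 = 9

# Because ceil is subadditive, the new value never undershoots the old one,
# so it remains a safe trigger for merge_and_sort_free().
assert new_estimate >= old_estimate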
@@ -43,12 +43,14 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
         dtype: torch.dtype,
         device: str,
         kvcache: KVCache,
+        need_sort: bool,
     ):
         self.size = size
         self.page_size = page_size
         self.dtype = dtype
         self.device = device
         self._kvcache = kvcache
+        self.need_sort = need_sort

         self.free_pages = None
         self.release_pages = None
@@ -79,6 +81,9 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
         if self.free_group:
             self.free(torch.cat(self.free_group))

+    def estimated_num_new_pages(self, bs, extend_num_tokens):
+        return bs * ((extend_num_tokens + self.page_size - 1) // self.page_size)
+
     def merge_and_sort_free(self):
         if len(self.release_pages) > 0:
             self.free_pages = torch.cat((self.free_pages, self.release_pages))
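For reference, the helper added above is plain integer arithmetic. A worked example with hypothetical numbers:

# Hypothetical values; page_size would come from the allocator instance.
bs, extend_num_tokens, page_size = 4, 35, 16
estimated = bs * ((extend_num_tokens + page_size - 1) // page_size)
assert estimated == 4 * 3 == 12  # ceil(35 / 16) == 3 pages per request, times bs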
@@ -117,8 +122,15 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
 class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
     """An allocator managing the indices to kv cache data."""

-    def __init__(self, size: int, dtype: torch.dtype, device: str, kvcache: KVCache):
-        super().__init__(size, 1, dtype, device, kvcache)
+    def __init__(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        device: str,
+        kvcache: KVCache,
+        need_sort: bool,
+    ):
+        super().__init__(size, 1, dtype, device, kvcache, need_sort)
         self.clear()

     def clear(self):
@@ -135,7 +147,7 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         return len(self.free_pages) + len(self.release_pages)

     def alloc(self, need_size: int):
-        if need_size > len(self.free_pages):
+        if self.need_sort and need_size > len(self.free_pages):
             self.merge_and_sort_free()
         if need_size > len(self.free_pages):
             return None
@@ -149,7 +161,10 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             return

         if self.is_not_in_free_group:
-            self.release_pages = torch.cat((self.release_pages, free_index))
+            if self.need_sort:
+                self.release_pages = torch.cat((self.release_pages, free_index))
+            else:
+                self.free_pages = torch.cat((self.free_pages, free_index))
         else:
             self.free_group.append(free_index)
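The same need_sort gate shows up in free(): when sorting is disabled, freed indices go straight back onto free_pages instead of being parked in release_pages for a later merge_and_sort_free(). A toy illustration of the two paths, using bare tensors rather than the real allocator:

import torch

free_pages = torch.tensor([5, 6, 7], dtype=torch.int64)
release_pages = torch.tensor([], dtype=torch.int64)
freed = torch.tensor([2, 3], dtype=torch.int64)

need_sort = False  # e.g. the non-disaggregated case
if need_sort:
    # Deferred path: freed indices wait in release_pages until a later
    # merge_and_sort_free() folds them back in sorted order.
    release_pages = torch.cat((release_pages, freed))
else:
    # Immediate path: freed indices are reusable right away; no sort ever runs.
    free_pages = torch.cat((free_pages, freed))

print(free_pages.tolist())  # [5, 6, 7, 2, 3]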
@@ -170,8 +185,9 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         dtype: torch.dtype,
         device: str,
         kvcache: SWAKVPool,
+        need_sort: bool,
     ):
-        super().__init__(size, 1, dtype, device, kvcache)
+        super().__init__(size, 1, dtype, device, kvcache, need_sort)
         assert isinstance(kvcache, SWAKVPool)
         self._size_full = size
         self._size_swa = size_swa
...@@ -180,12 +196,14 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): ...@@ -180,12 +196,14 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
dtype, dtype,
device, device,
kvcache.full_kv_pool, kvcache.full_kv_pool,
need_sort,
) )
self.swa_attn_allocator = TokenToKVPoolAllocator( self.swa_attn_allocator = TokenToKVPoolAllocator(
size_swa, size_swa,
dtype, dtype,
device, device,
kvcache.swa_kv_pool, kvcache.swa_kv_pool,
need_sort,
) )
self.full_to_swa_index_mapping = torch.empty( self.full_to_swa_index_mapping = torch.empty(
size + size_swa + 1, size + size_swa + 1,
@@ -418,8 +436,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         dtype: torch.dtype,
         device: str,
         kvcache: KVCache,
+        need_sort: bool,
     ):
-        super().__init__(size, page_size, dtype, device, kvcache)
+        super().__init__(size, page_size, dtype, device, kvcache, need_sort)
         self.num_pages = size // page_size
         self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
         self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
@@ -433,7 +452,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         ), "The allocation size should be page-aligned"

         num_pages = need_size // self.page_size
-        if num_pages > len(self.free_pages):
+        if self.need_sort and num_pages > len(self.free_pages):
             self.merge_and_sort_free()
         if num_pages > len(self.free_pages):
             return None
@@ -460,18 +479,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
                 (last_loc + 1) % self.page_size == prefix_lens % self.page_size
             )

-        estimated_num_new_pages = (
-            (
-                (seq_lens + self.page_size - 1) // self.page_size
-                - (prefix_lens + self.page_size - 1) // self.page_size
-            )
-            .sum()
-            .item()
-        )
-        if estimated_num_new_pages > len(self.free_pages):
+        bs = len(prefix_lens)
+        if self.need_sort and self.estimated_num_new_pages(bs, extend_num_tokens) > len(
+            self.free_pages
+        ):
             self.merge_and_sort_free()

-        bs = len(prefix_lens)
         out_indices = torch.empty(
             (extend_num_tokens,), dtype=torch.int64, device=self.device
         )
@@ -508,18 +521,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
                 (last_loc + 2) % self.page_size == seq_lens % self.page_size
             )

-        estimated_num_new_pages = (
-            (
-                (seq_lens + self.page_size - 1) // self.page_size
-                - (seq_lens - 1 + self.page_size - 1) // self.page_size
-            )
-            .sum()
-            .item()
-        )
-        if estimated_num_new_pages > len(self.free_pages):
+        bs = len(seq_lens)
+        if self.need_sort and self.estimated_num_new_pages(bs, 1) > len(
+            self.free_pages
+        ):
             self.merge_and_sort_free()

-        bs = len(seq_lens)
         out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
         alloc_decode_kernel[(bs,)](
             seq_lens,
@@ -547,7 +554,10 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         if self.is_not_in_free_group:
             free_page_indices = torch.unique(free_index // self.page_size)
-            self.release_pages = torch.cat((free_page_indices, self.release_pages))
+            if self.need_sort:
+                self.release_pages = torch.cat((free_page_indices, self.release_pages))
+            else:
+                self.free_pages = torch.cat((free_page_indices, self.free_pages))
         else:
             self.free_group.append(free_index)
@@ -654,8 +664,9 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
         dtype: torch.dtype,
         device: str,
         kvcache: KVCache,
+        need_sort: bool,
     ):
-        super().__init__(size, page_size, dtype, device, kvcache)
+        super().__init__(size, page_size, dtype, device, kvcache, need_sort)
         self.ret_values = torch.empty((), dtype=torch.int32, device=self.device)

     def alloc_extend(
@@ -670,18 +681,12 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
                 (last_loc + 1) % self.page_size == prefix_lens % self.page_size
             )

-        estimated_num_new_pages = (
-            (
-                (seq_lens + self.page_size - 1) // self.page_size
-                - (prefix_lens + self.page_size - 1) // self.page_size
-            )
-            .sum()
-            .item()
-        )
-        if estimated_num_new_pages > len(self.free_pages):
+        bs = len(prefix_lens)
+        if self.need_sort and self.estimated_num_new_pages(bs, extend_num_tokens) > len(
+            self.free_pages
+        ):
             self.merge_and_sort_free()

-        bs = len(prefix_lens)
         out_indices = torch.empty(
             (extend_num_tokens,), dtype=torch.int32, device=self.device
         )
@@ -716,18 +721,12 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
                 (last_loc + 2) % self.page_size == seq_lens % self.page_size
             )

-        estimated_num_new_pages = (
-            (
-                (seq_lens + self.page_size - 1) // self.page_size
-                - (seq_lens - 1 + self.page_size - 1) // self.page_size
-            )
-            .sum()
-            .item()
-        )
-        if estimated_num_new_pages > len(self.free_pages):
+        bs = len(seq_lens)
+        if self.need_sort and self.estimated_num_new_pages(bs, 1) > len(
+            self.free_pages
+        ):
             self.merge_and_sort_free()

-        bs = len(seq_lens)
         out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
         self.ret_values = alloc_decode_kernel_ascend(
@@ -1300,6 +1300,8 @@ class ModelRunner:
                     dtype=self.kv_cache_dtype,
                     device=self.device,
                     kvcache=self.token_to_kv_pool,
+                    need_sort=self.server_args.disaggregation_mode
+                    in ("decode", "prefill"),
                 )
             else:
                 self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
@@ -1307,6 +1309,8 @@ class ModelRunner:
                     dtype=self.kv_cache_dtype,
                     device=self.device,
                     kvcache=self.token_to_kv_pool,
+                    need_sort=self.server_args.disaggregation_mode
+                    in ("decode", "prefill"),
                 )
         else:
             if _is_npu:
@@ -1316,6 +1320,8 @@ class ModelRunner:
                     dtype=self.kv_cache_dtype,
                     device=self.device,
                     kvcache=self.token_to_kv_pool,
+                    need_sort=self.server_args.disaggregation_mode
+                    in ("decode", "prefill"),
                 )
             else:
                 self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
@@ -1324,6 +1330,8 @@ class ModelRunner:
                     dtype=self.kv_cache_dtype,
                     device=self.device,
                     kvcache=self.token_to_kv_pool,
+                    need_sort=self.server_args.disaggregation_mode
+                    in ("decode", "prefill"),
                 )
         else:
             assert self.is_draft_worker
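In the ModelRunner hunks above, the flag is derived from the server's disaggregation mode, so sorting of released pages is only kept for prefill/decode disaggregated deployments. A minimal sketch of that derivation (the helper name below is illustrative, not part of the sglang API):

def make_need_sort(disaggregation_mode: str) -> bool:
    # Mirrors the diff: only the PD-disaggregation roles need sorted free pages.
    return disaggregation_mode in ("decode", "prefill")

assert make_need_sort("prefill") is True
assert make_need_sort("decode") is True
assert make_need_sort("anything-else") is False  # non-disaggregated modes leave sorting off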