Unverified Commit a99801e0 authored by YiXR's avatar YiXR Committed by GitHub
Browse files

[Performance][PD Disaggregation] optimize TokenToKVPoolAllocator by sorting free pages (#8133)


Signed-off-by: default avatarXingrui Yi <yixingrui@linux.alibaba.com>
Co-authored-by: default avatarXingrui Yi <yixingrui@linux.alibaba.com>
parent 4c605235
...@@ -51,6 +51,7 @@ class BaseTokenToKVPoolAllocator(abc.ABC): ...@@ -51,6 +51,7 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
self._kvcache = kvcache self._kvcache = kvcache
self.free_pages = None self.free_pages = None
self.release_pages = None
self.is_not_in_free_group = True self.is_not_in_free_group = True
self.free_group = [] self.free_group = []
...@@ -58,16 +59,16 @@ class BaseTokenToKVPoolAllocator(abc.ABC): ...@@ -58,16 +59,16 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
return "" return ""
def available_size(self): def available_size(self):
return len(self.free_pages) * self.page_size return (len(self.free_pages) + len(self.release_pages)) * self.page_size
def get_kvcache(self): def get_kvcache(self):
return self._kvcache return self._kvcache
def restore_state(self, free_pages): def restore_state(self, state):
self.free_pages = free_pages self.free_pages, self.release_pages = state
def backup_state(self): def backup_state(self):
return self.free_pages return (self.free_pages, self.release_pages)
def free_group_begin(self): def free_group_begin(self):
self.is_not_in_free_group = False self.is_not_in_free_group = False
...@@ -78,6 +79,14 @@ class BaseTokenToKVPoolAllocator(abc.ABC): ...@@ -78,6 +79,14 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
if self.free_group: if self.free_group:
self.free(torch.cat(self.free_group)) self.free(torch.cat(self.free_group))
def merge_and_sort_free(self):
if len(self.release_pages) > 0:
self.free_pages = torch.cat((self.free_pages, self.release_pages))
self.free_pages, _ = torch.sort(self.free_pages)
self.release_pages = torch.empty(
(0,), dtype=self.release_pages.dtype, device=self.device
)
def get_cpu_copy(self, *args, **kwargs): def get_cpu_copy(self, *args, **kwargs):
# FIXME: reuse the get_cpu_copy after paged allocator is implemented # FIXME: reuse the get_cpu_copy after paged allocator is implemented
raise NotImplementedError() raise NotImplementedError()
...@@ -119,12 +128,15 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): ...@@ -119,12 +128,15 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
) )
self.is_not_in_free_group = True self.is_not_in_free_group = True
self.free_group = [] self.free_group = []
self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device)
def available_size(self): def available_size(self):
# To avoid minor "len(free_pages) * 1" overhead # To avoid minor "len(free_pages) * 1" overhead
return len(self.free_pages) return len(self.free_pages) + len(self.release_pages)
def alloc(self, need_size: int): def alloc(self, need_size: int):
if need_size > len(self.free_pages):
self.merge_and_sort_free()
if need_size > len(self.free_pages): if need_size > len(self.free_pages):
return None return None
...@@ -137,7 +149,7 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): ...@@ -137,7 +149,7 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
return return
if self.is_not_in_free_group: if self.is_not_in_free_group:
self.free_pages = torch.cat((self.free_pages, free_index)) self.release_pages = torch.cat((self.release_pages, free_index))
else: else:
self.free_group.append(free_index) self.free_group.append(free_index)
...@@ -421,6 +433,8 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): ...@@ -421,6 +433,8 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
), "The allocation size should be page-aligned" ), "The allocation size should be page-aligned"
num_pages = need_size // self.page_size num_pages = need_size // self.page_size
if num_pages > len(self.free_pages):
self.merge_and_sort_free()
if num_pages > len(self.free_pages): if num_pages > len(self.free_pages):
return None return None
...@@ -446,6 +460,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): ...@@ -446,6 +460,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
(last_loc + 1) % self.page_size == prefix_lens % self.page_size (last_loc + 1) % self.page_size == prefix_lens % self.page_size
) )
estimated_num_new_pages = (
(
(seq_lens + self.page_size - 1) // self.page_size
- (prefix_lens + self.page_size - 1) // self.page_size
)
.sum()
.item()
)
if estimated_num_new_pages > len(self.free_pages):
self.merge_and_sort_free()
bs = len(prefix_lens) bs = len(prefix_lens)
out_indices = torch.empty( out_indices = torch.empty(
(extend_num_tokens,), dtype=torch.int64, device=self.device (extend_num_tokens,), dtype=torch.int64, device=self.device
...@@ -483,6 +508,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): ...@@ -483,6 +508,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
(last_loc + 2) % self.page_size == seq_lens % self.page_size (last_loc + 2) % self.page_size == seq_lens % self.page_size
) )
estimated_num_new_pages = (
(
(seq_lens + self.page_size - 1) // self.page_size
- (seq_lens - 1 + self.page_size - 1) // self.page_size
)
.sum()
.item()
)
if estimated_num_new_pages > len(self.free_pages):
self.merge_and_sort_free()
bs = len(seq_lens) bs = len(seq_lens)
out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device) out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
alloc_decode_kernel[(bs,)]( alloc_decode_kernel[(bs,)](
...@@ -511,7 +547,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): ...@@ -511,7 +547,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
if self.is_not_in_free_group: if self.is_not_in_free_group:
free_page_indices = torch.unique(free_index // self.page_size) free_page_indices = torch.unique(free_index // self.page_size)
self.free_pages = torch.cat((free_page_indices, self.free_pages)) self.release_pages = torch.cat((free_page_indices, self.release_pages))
else: else:
self.free_group.append(free_index) self.free_group.append(free_index)
...@@ -525,6 +561,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): ...@@ -525,6 +561,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
) )
self.is_not_in_free_group = True self.is_not_in_free_group = True
self.free_group = [] self.free_group = []
self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device)
def get_cpu_copy(self, indices): def get_cpu_copy(self, indices):
return self._kvcache.get_cpu_copy(indices) return self._kvcache.get_cpu_copy(indices)
...@@ -633,6 +670,17 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator): ...@@ -633,6 +670,17 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
(last_loc + 1) % self.page_size == prefix_lens % self.page_size (last_loc + 1) % self.page_size == prefix_lens % self.page_size
) )
estimated_num_new_pages = (
(
(seq_lens + self.page_size - 1) // self.page_size
- (prefix_lens + self.page_size - 1) // self.page_size
)
.sum()
.item()
)
if estimated_num_new_pages > len(self.free_pages):
self.merge_and_sort_free()
bs = len(prefix_lens) bs = len(prefix_lens)
out_indices = torch.empty( out_indices = torch.empty(
(extend_num_tokens,), dtype=torch.int32, device=self.device (extend_num_tokens,), dtype=torch.int32, device=self.device
...@@ -668,6 +716,17 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator): ...@@ -668,6 +716,17 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
(last_loc + 2) % self.page_size == seq_lens % self.page_size (last_loc + 2) % self.page_size == seq_lens % self.page_size
) )
estimated_num_new_pages = (
(
(seq_lens + self.page_size - 1) // self.page_size
- (seq_lens - 1 + self.page_size - 1) // self.page_size
)
.sum()
.item()
)
if estimated_num_new_pages > len(self.free_pages):
self.merge_and_sort_free()
bs = len(seq_lens) bs = len(seq_lens)
out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device) out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
...@@ -692,3 +751,4 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator): ...@@ -692,3 +751,4 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
def clear(self): def clear(self):
super().clear() super().clear()
self.free_pages = self.free_pages.to(torch.int32) self.free_pages = self.free_pages.to(torch.int32)
self.release_pages = self.release_pages.to(torch.int32)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment