"tests/vscode:/vscode.git/clone" did not exist on "fe3d29ac4ded1535a75ca3a545a5bd8a9643514a"
Unverified Commit a99801e0 authored by YiXR's avatar YiXR Committed by GitHub
Browse files

[Performance][PD Disaggregation] optimize TokenToKVPoolAllocator by sorting free pages (#8133)


Signed-off-by: default avatarXingrui Yi <yixingrui@linux.alibaba.com>
Co-authored-by: default avatarXingrui Yi <yixingrui@linux.alibaba.com>
parent 4c605235
......@@ -51,6 +51,7 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
self._kvcache = kvcache
self.free_pages = None
self.release_pages = None
self.is_not_in_free_group = True
self.free_group = []
......@@ -58,16 +59,16 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
return ""
def available_size(self):
return len(self.free_pages) * self.page_size
return (len(self.free_pages) + len(self.release_pages)) * self.page_size
def get_kvcache(self):
return self._kvcache
def restore_state(self, free_pages):
self.free_pages = free_pages
def restore_state(self, state):
self.free_pages, self.release_pages = state
def backup_state(self):
return self.free_pages
return (self.free_pages, self.release_pages)
def free_group_begin(self):
self.is_not_in_free_group = False
......@@ -78,6 +79,14 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
if self.free_group:
self.free(torch.cat(self.free_group))
def merge_and_sort_free(self):
if len(self.release_pages) > 0:
self.free_pages = torch.cat((self.free_pages, self.release_pages))
self.free_pages, _ = torch.sort(self.free_pages)
self.release_pages = torch.empty(
(0,), dtype=self.release_pages.dtype, device=self.device
)
def get_cpu_copy(self, *args, **kwargs):
# FIXME: reuse the get_cpu_copy after paged allocator is implemented
raise NotImplementedError()
......@@ -119,12 +128,15 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
)
self.is_not_in_free_group = True
self.free_group = []
self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device)
def available_size(self):
# To avoid minor "len(free_pages) * 1" overhead
return len(self.free_pages)
return len(self.free_pages) + len(self.release_pages)
def alloc(self, need_size: int):
if need_size > len(self.free_pages):
self.merge_and_sort_free()
if need_size > len(self.free_pages):
return None
......@@ -137,7 +149,7 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
return
if self.is_not_in_free_group:
self.free_pages = torch.cat((self.free_pages, free_index))
self.release_pages = torch.cat((self.release_pages, free_index))
else:
self.free_group.append(free_index)
......@@ -421,6 +433,8 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
), "The allocation size should be page-aligned"
num_pages = need_size // self.page_size
if num_pages > len(self.free_pages):
self.merge_and_sort_free()
if num_pages > len(self.free_pages):
return None
......@@ -446,6 +460,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
)
estimated_num_new_pages = (
(
(seq_lens + self.page_size - 1) // self.page_size
- (prefix_lens + self.page_size - 1) // self.page_size
)
.sum()
.item()
)
if estimated_num_new_pages > len(self.free_pages):
self.merge_and_sort_free()
bs = len(prefix_lens)
out_indices = torch.empty(
(extend_num_tokens,), dtype=torch.int64, device=self.device
......@@ -483,6 +508,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
(last_loc + 2) % self.page_size == seq_lens % self.page_size
)
estimated_num_new_pages = (
(
(seq_lens + self.page_size - 1) // self.page_size
- (seq_lens - 1 + self.page_size - 1) // self.page_size
)
.sum()
.item()
)
if estimated_num_new_pages > len(self.free_pages):
self.merge_and_sort_free()
bs = len(seq_lens)
out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
alloc_decode_kernel[(bs,)](
......@@ -511,7 +547,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
if self.is_not_in_free_group:
free_page_indices = torch.unique(free_index // self.page_size)
self.free_pages = torch.cat((free_page_indices, self.free_pages))
self.release_pages = torch.cat((free_page_indices, self.release_pages))
else:
self.free_group.append(free_index)
......@@ -525,6 +561,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
)
self.is_not_in_free_group = True
self.free_group = []
self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device)
def get_cpu_copy(self, indices):
return self._kvcache.get_cpu_copy(indices)
......@@ -633,6 +670,17 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
)
estimated_num_new_pages = (
(
(seq_lens + self.page_size - 1) // self.page_size
- (prefix_lens + self.page_size - 1) // self.page_size
)
.sum()
.item()
)
if estimated_num_new_pages > len(self.free_pages):
self.merge_and_sort_free()
bs = len(prefix_lens)
out_indices = torch.empty(
(extend_num_tokens,), dtype=torch.int32, device=self.device
......@@ -668,6 +716,17 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
(last_loc + 2) % self.page_size == seq_lens % self.page_size
)
estimated_num_new_pages = (
(
(seq_lens + self.page_size - 1) // self.page_size
- (seq_lens - 1 + self.page_size - 1) // self.page_size
)
.sum()
.item()
)
if estimated_num_new_pages > len(self.free_pages):
self.merge_and_sort_free()
bs = len(seq_lens)
out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
......@@ -692,3 +751,4 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
def clear(self):
super().clear()
self.free_pages = self.free_pages.to(torch.int32)
self.release_pages = self.release_pages.to(torch.int32)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment