Unverified commit 2aaf22c4, authored by Makcum888e and committed by GitHub

Optimization for AscendPagedTokenToKVPoolAllocator (#8293)


Co-authored-by: ronnie_zheng <zl19940307@163.com>
Co-authored-by: VDV1985 <vladdv85@mail.ru>
parent 29a610b4
@@ -632,27 +632,6 @@ def alloc_extend_kernel_ascend(
         out_indices[end_pos[i] - num3 : end_pos[i]] = (
             free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
         ).view(-1)
-    return num_new_pages
-
-
-def alloc_decode_kernel_ascend(
-    seq_lens,
-    last_loc,
-    free_pages,
-    out_indices,
-    page_size,
-):
-    num_new_pages = (seq_lens + page_size - 1) // page_size - (
-        seq_lens - 1 + page_size - 1
-    ) // page_size
-    end_new_pages = torch.cumsum(num_new_pages, 0)
-    start_new_pages = end_new_pages - num_new_pages
-    for i in range(len(seq_lens)):
-        if num_new_pages[i]:
-            out_indices[i] = free_pages[start_new_pages[i]] * page_size
-        else:
-            out_indices[i] = last_loc[i] + 1
-    return num_new_pages
 
 
 class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
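The deleted alloc_decode_kernel_ascend looped over the batch and derived each request's page need from two ceiling divisions. Because a decode step appends exactly one token, that difference is 1 precisely when the new length lands on the first slot of a page, which is the identity the rewritten alloc_decode below relies on via seq_lens % page_size == 1. A minimal standalone check of that equivalence (illustrative values, not part of the patch):

import torch

# Illustrative check: with one token appended per decode step, the ceil-based
# page count from the removed kernel equals the modulo-based test used by the
# new alloc_decode. page_size and seq_lens here are made-up example values.
page_size = 4
seq_lens = torch.arange(1, 13)  # sequence lengths after the decode step

ceil_based = (seq_lens + page_size - 1) // page_size - (
    seq_lens - 1 + page_size - 1
) // page_size
modulo_based = (seq_lens % page_size == 1).int()

assert (ceil_based == modulo_based).all()
print(ceil_based.tolist())  # [1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0]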
@@ -667,7 +646,6 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
         need_sort: bool,
     ):
         super().__init__(size, page_size, dtype, device, kvcache, need_sort)
-        self.ret_values = torch.empty((), dtype=torch.int32, device=self.device)
 
     def alloc_extend(
         self,
@@ -681,17 +659,25 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
             (last_loc + 1) % self.page_size == prefix_lens % self.page_size
         )
 
-        bs = len(prefix_lens)
-        if self.need_sort and self.estimated_num_new_pages(bs, extend_num_tokens) > len(
-            self.free_pages
-        ):
+        estimated_num_new_pages = (
+            (
+                (seq_lens + self.page_size - 1) // self.page_size
+                - (prefix_lens + self.page_size - 1) // self.page_size
+            )
+            .sum()
+            .item()
+        )
+        if self.need_sort and estimated_num_new_pages > len(self.free_pages):
             self.merge_and_sort_free()
 
+        if estimated_num_new_pages > len(self.free_pages):
+            return None
+
         out_indices = torch.empty(
             (extend_num_tokens,), dtype=torch.int32, device=self.device
         )
 
-        self.ret_values = alloc_extend_kernel_ascend(
+        alloc_extend_kernel_ascend(
             prefix_lens,
             seq_lens,
             last_loc,
@@ -704,11 +690,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
         if self.debug_mode:
             assert len(torch.unique(out_indices)) == len(out_indices)
 
-        num_new_pages = self.ret_values.sum()
-        if num_new_pages > len(self.free_pages):
-            return None
-
-        self.free_pages = self.free_pages[num_new_pages:]
+        self.free_pages = self.free_pages[estimated_num_new_pages:]
         return out_indices
 
     def alloc_decode(
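In the new alloc_extend, the number of pages to hand out is estimated up front on the host: each request needs ceil(seq_len / page_size) - ceil(prefix_len / page_size) fresh pages, and the batch total drives the merge_and_sort_free call, the early return on exhaustion, and the final free_pages slice, so the kernel's return value (the old self.ret_values round trip) is no longer needed. A small sketch of that arithmetic with made-up lengths and page size:

import torch

# Sketch of the host-side page estimate used by the new alloc_extend.
# page_size, prefix_lens and seq_lens are made-up example values.
page_size = 128
prefix_lens = torch.tensor([200, 0, 128])  # tokens already cached per request
seq_lens = torch.tensor([300, 50, 200])    # target lengths after the extend

estimated_num_new_pages = (
    (
        (seq_lens + page_size - 1) // page_size
        - (prefix_lens + page_size - 1) // page_size
    )
    .sum()
    .item()
)
print(estimated_num_new_pages)  # (3 - 2) + (1 - 0) + (2 - 1) = 3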
@@ -721,33 +703,26 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
             (last_loc + 2) % self.page_size == seq_lens % self.page_size
         )
 
-        bs = len(seq_lens)
-        if self.need_sort and self.estimated_num_new_pages(bs, 1) > len(
-            self.free_pages
-        ):
+        need_new_pages = (seq_lens % self.page_size == 1).int()
+        num_new_pages = need_new_pages.sum().item()
+
+        if num_new_pages > len(self.free_pages):
             self.merge_and_sort_free()
 
-        out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
-
-        self.ret_values = alloc_decode_kernel_ascend(
-            seq_lens,
-            last_loc,
-            self.free_pages,
-            out_indices,
-            self.page_size,
-        )
+        if num_new_pages > len(self.free_pages):
+            return None
+
+        end_new_pages = torch.cumsum(need_new_pages, 0)
+        start_new_pages = end_new_pages - need_new_pages
+        if num_new_pages == 0:
+            out_indices = last_loc + 1
+        else:
+            out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[
+                start_new_pages
+            ] * self.page_size * need_new_pages
 
         if self.debug_mode:
             assert len(torch.unique(out_indices)) == len(out_indices)
 
-        num_new_pages = self.ret_values.sum()
-        if num_new_pages > len(self.free_pages):
-            return None
-
         self.free_pages = self.free_pages[num_new_pages:]
-        return out_indices
-
-    def clear(self):
-        super().clear()
-        self.free_pages = self.free_pages.to(torch.int32)
-        self.release_pages = self.release_pages.to(torch.int32)
+        return out_indices.int()
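Instead of the removed per-request loop, the new alloc_decode builds out_indices in one vectorized blend: rows that start a new page take the first slot of the next free page, everyone else takes the slot right after last_loc. A standalone sketch of that expression with made-up values (page_size=4, a tiny free-page pool):

import torch

# Standalone sketch of the vectorized decode-time index computation.
# page_size, seq_lens, last_loc and free_pages are made-up example values.
page_size = 4
seq_lens = torch.tensor([5, 7, 9])     # lengths after appending one token
last_loc = torch.tensor([19, 33, 47])  # slot holding each request's previous token
free_pages = torch.tensor([6, 7, 8], dtype=torch.int32)

need_new_pages = (seq_lens % page_size == 1).int()                  # [1, 0, 1]
start_new_pages = torch.cumsum(need_new_pages, 0) - need_new_pages  # [0, 1, 1]

out_indices = (last_loc + 1) * (1 - need_new_pages) + free_pages[
    start_new_pages
] * page_size * need_new_pages
print(out_indices.tolist())  # [24, 34, 28]: rows 0 and 2 claim pages 6 and 7, row 1 stays on its page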