"tests/cpp/test_spmat_csr.cc" did not exist on "870da747eaee96313aa798806a22573ea1e1e8eb"
Unverified Commit 2aaf22c4 authored by Makcum888e, committed by GitHub

Optimization for AscendPagedTokenToKVPoolAllocator (#8293)


Co-authored-by: ronnie_zheng <zl19940307@163.com>
Co-authored-by: VDV1985 <vladdv85@mail.ru>
parent 29a610b4
@@ -632,27 +632,6 @@ def alloc_extend_kernel_ascend(
         out_indices[end_pos[i] - num3 : end_pos[i]] = (
             free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
         ).view(-1)
     return num_new_pages
 
 
-def alloc_decode_kernel_ascend(
-    seq_lens,
-    last_loc,
-    free_pages,
-    out_indices,
-    page_size,
-):
-    num_new_pages = (seq_lens + page_size - 1) // page_size - (
-        seq_lens - 1 + page_size - 1
-    ) // page_size
-    end_new_pages = torch.cumsum(num_new_pages, 0)
-    start_new_pages = end_new_pages - num_new_pages
-    for i in range(len(seq_lens)):
-        if num_new_pages[i]:
-            out_indices[i] = free_pages[start_new_pages[i]] * page_size
-        else:
-            out_indices[i] = last_loc[i] + 1
-    return num_new_pages
-
-
 class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
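Both allocation kernels turn token counts into page counts with integer ceiling division: a sequence of n tokens occupies (n + page_size - 1) // page_size pages. A standalone sketch of that arithmetic, with made-up values (illustrative only, not part of the diff):

import torch

page_size = 128
seq_lens = torch.tensor([1, 128, 129, 300])

# Pages needed to hold each sequence: ceil(seq_len / page_size),
# expressed with integer arithmetic only.
pages_needed = (seq_lens + page_size - 1) // page_size
print(pages_needed)  # tensor([1, 1, 2, 3])

# The deleted alloc_decode_kernel_ascend computed how many *extra* pages a
# one-token decode step needs per sequence: pages(seq_len) - pages(seq_len - 1).
new_pages = pages_needed - (seq_lens - 1 + page_size - 1) // page_size
print(new_pages)  # tensor([1, 0, 1, 0])

For a single-token decode step this difference is 1 exactly when seq_lens % page_size == 1, which is the check the reworked alloc_decode further down relies on instead of the deleted kernel.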
@@ -667,7 +646,6 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
         need_sort: bool,
     ):
         super().__init__(size, page_size, dtype, device, kvcache, need_sort)
-        self.ret_values = torch.empty((), dtype=torch.int32, device=self.device)
 
     def alloc_extend(
         self,
@@ -681,17 +659,25 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
                 (last_loc + 1) % self.page_size == prefix_lens % self.page_size
             )
 
         bs = len(prefix_lens)
-        if self.need_sort and self.estimated_num_new_pages(bs, extend_num_tokens) > len(
-            self.free_pages
-        ):
+        estimated_num_new_pages = (
+            (
+                (seq_lens + self.page_size - 1) // self.page_size
+                - (prefix_lens + self.page_size - 1) // self.page_size
+            )
+            .sum()
+            .item()
+        )
+        if self.need_sort and estimated_num_new_pages > len(self.free_pages):
             self.merge_and_sort_free()
 
+        if estimated_num_new_pages > len(self.free_pages):
+            return None
+
         out_indices = torch.empty(
             (extend_num_tokens,), dtype=torch.int32, device=self.device
         )
-        self.ret_values = alloc_extend_kernel_ascend(
+        alloc_extend_kernel_ascend(
            prefix_lens,
            seq_lens,
            last_loc,
@@ -704,11 +690,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
         if self.debug_mode:
             assert len(torch.unique(out_indices)) == len(out_indices)
 
-        num_new_pages = self.ret_values.sum()
-        if num_new_pages > len(self.free_pages):
-            return None
-
-        self.free_pages = self.free_pages[num_new_pages:]
+        self.free_pages = self.free_pages[estimated_num_new_pages:]
         return out_indices
 
     def alloc_decode(
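The reworked alloc_extend no longer routes the page count through self.ret_values (a device tensor); it derives the number of new pages on the host as ceil(seq_len / page_size) - ceil(prefix_len / page_size) summed over the batch, then consumes exactly that many IDs from the head of free_pages. A rough sketch of the same bookkeeping on CPU tensors, with made-up lengths (not from the diff):

import torch

page_size = 128
prefix_lens = torch.tensor([100, 128, 250])        # tokens already cached per request
seq_lens = torch.tensor([130, 200, 260])           # total tokens after the extend
free_pages = torch.arange(10, dtype=torch.int32)   # head of the free-page list

# Same arithmetic as the new alloc_extend: pages after minus pages before,
# summed over the batch and pulled to the host as a plain int.
estimated_num_new_pages = (
    (
        (seq_lens + page_size - 1) // page_size
        - (prefix_lens + page_size - 1) // page_size
    )
    .sum()
    .item()
)
print(estimated_num_new_pages)  # 3

# Allocation consumes exactly that many IDs from the head of the free list.
allocated = free_pages[:estimated_num_new_pages]
free_pages = free_pages[estimated_num_new_pages:]
print(allocated.tolist(), len(free_pages))  # [0, 1, 2] 7

Slicing free_pages by the host-side estimate only works if the estimate agrees with what alloc_extend_kernel_ascend actually writes; presumably the kernel fills new pages contiguously from the head of the free list, so the two counts match.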
@@ -721,33 +703,26 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
                 (last_loc + 2) % self.page_size == seq_lens % self.page_size
             )
 
         bs = len(seq_lens)
-        if self.need_sort and self.estimated_num_new_pages(bs, 1) > len(
-            self.free_pages
-        ):
+        need_new_pages = (seq_lens % self.page_size == 1).int()
+        num_new_pages = need_new_pages.sum().item()
+        if num_new_pages > len(self.free_pages):
             self.merge_and_sort_free()
 
-        out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
-        self.ret_values = alloc_decode_kernel_ascend(
-            seq_lens,
-            last_loc,
-            self.free_pages,
-            out_indices,
-            self.page_size,
-        )
+        if num_new_pages > len(self.free_pages):
+            return None
+
+        end_new_pages = torch.cumsum(need_new_pages, 0)
+        start_new_pages = end_new_pages - need_new_pages
+        if num_new_pages == 0:
+            out_indices = last_loc + 1
+        else:
+            out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[
+                start_new_pages
+            ] * self.page_size * need_new_pages
 
         if self.debug_mode:
             assert len(torch.unique(out_indices)) == len(out_indices)
 
-        num_new_pages = self.ret_values.sum()
-        if num_new_pages > len(self.free_pages):
-            return None
-
         self.free_pages = self.free_pages[num_new_pages:]
-        return out_indices
-
-    def clear(self):
-        super().clear()
-        self.free_pages = self.free_pages.to(torch.int32)
-        self.release_pages = self.release_pages.to(torch.int32)
+        return out_indices.int()
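The new alloc_decode replaces the per-sequence Python loop of the deleted alloc_decode_kernel_ascend with tensor operations: a sequence needs a fresh page exactly when its new length lands on the first slot of a page (seq_lens % page_size == 1), a cumulative sum hands the i-th such sequence the i-th entry of free_pages, and the final expression blends last_loc + 1 with the first slot of the newly granted page. A small CPU-only sketch of the same logic with invented values (in the allocator these are NPU tensors):

import torch

page_size = 4
seq_lens = torch.tensor([5, 6, 9])      # lengths after appending the decode token
last_loc = torch.tensor([11, 20, 27])   # slot of each sequence's previous token
free_pages = torch.tensor([7, 9], dtype=torch.int32)

# A sequence crosses into a new page exactly when its new length lands on
# the first slot of a page.
need_new_pages = (seq_lens % page_size == 1).int()   # tensor([1, 0, 1])
num_new_pages = need_new_pages.sum().item()          # 2

# Prefix sums map the i-th sequence that needs a page to free_pages[i];
# for the other sequences the gathered value is masked out below.
end_new_pages = torch.cumsum(need_new_pages, 0)
start_new_pages = end_new_pages - need_new_pages     # tensor([0, 1, 1])

if num_new_pages == 0:
    out_indices = last_loc + 1
else:
    out_indices = (last_loc + 1) * (1 - need_new_pages) + free_pages[
        start_new_pages
    ] * page_size * need_new_pages

print(out_indices.int().tolist())  # [28, 21, 36]
free_pages = free_pages[num_new_pages:]

The old loop tested num_new_pages[i] element by element, which on device tensors implies a host synchronization per iteration; the vectorized form avoids that and makes the separate decode kernel unnecessary.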