Unverified commit 7d004799, authored by Liangsheng Yin, committed by GitHub

Clean up ascend allocator (#11152)

parent 083629c2
@@ -493,7 +493,11 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         if self.debug_mode:
             assert len(torch.unique(out_indices)) == len(out_indices)

-        num_new_pages = get_num_new_pages(prefix_lens_cpu, seq_lens_cpu, self.page_size)
+        num_new_pages = get_num_new_pages(
+            seq_lens=seq_lens_cpu,
+            page_size=self.page_size,
+            prefix_lens=prefix_lens_cpu,
+        )
         if num_new_pages > len(self.free_pages):
             return None
@@ -529,7 +533,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             assert len(torch.unique(out_indices)) == len(out_indices)

         num_new_pages = get_num_new_pages(
-            seq_lens_cpu - 1, seq_lens_cpu, self.page_size, decode=True
+            seq_lens=seq_lens_cpu,
+            page_size=self.page_size,
+            decode=True,
         )
         if num_new_pages > len(self.free_pages):
             return None
...
 from __future__ import annotations

-from typing import TYPE_CHECKING

 import torch

 from sglang.srt.mem_cache.allocator import PagedTokenToKVPoolAllocator
+from sglang.srt.utils import get_num_new_pages

-if TYPE_CHECKING:
-    from sglang.srt.mem_cache.memory_pool import KVCache


 def alloc_extend_kernel_ascend(
@@ -80,13 +76,10 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
                 (last_loc + 1) % self.page_size == prefix_lens % self.page_size
             )

-        num_new_pages = (
-            (
-                (seq_lens_cpu + self.page_size - 1) // self.page_size
-                - (prefix_lens_cpu + self.page_size - 1) // self.page_size
-            )
-            .sum()
-            .item()
-        )
+        num_new_pages = get_num_new_pages(
+            seq_lens=seq_lens_cpu,
+            page_size=self.page_size,
+            prefix_lens=prefix_lens_cpu,
+        )
         if self.need_sort and num_new_pages > len(self.free_pages):
             self.merge_and_sort_free()
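
For reference, a minimal standalone sketch of the extend-path arithmetic that the removed inline expression and the shared get_num_new_pages helper both compute: the pages needed after the extend minus the pages already covered by the prefix, summed over the batch. The page size and tensor values below are made-up examples, not taken from the PR.

import torch

# Hypothetical example values; page_size=4 is only for illustration.
page_size = 4
prefix_lens_cpu = torch.tensor([5, 8, 0])  # tokens already cached per request
seq_lens_cpu = torch.tensor([9, 8, 3])     # total tokens after the extend

# Pages needed before and after the extend, via ceil division.
pages_before = (prefix_lens_cpu + page_size - 1) // page_size  # [2, 2, 0]
pages_after = (seq_lens_cpu + page_size - 1) // page_size      # [3, 2, 1]

num_new_pages = (pages_after - pages_before).sum().item()      # 1 + 0 + 1 = 2
assert num_new_pages == 2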
@@ -125,9 +118,11 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
                 (last_loc + 2) % self.page_size == seq_lens % self.page_size
             )

-        need_new_pages = (seq_lens % self.page_size == 1).int()
-        need_new_pages_cpu = (seq_lens_cpu % self.page_size == 1).int()
-        num_new_pages = need_new_pages_cpu.sum().item()
+        num_new_pages = get_num_new_pages(
+            seq_lens=seq_lens_cpu,
+            page_size=self.page_size,
+            decode=True,
+        )
         if num_new_pages > len(self.free_pages):
             self.merge_and_sort_free()
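
A minimal sketch of the decode-path check that get_num_new_pages now performs for this call site: each running request grows by exactly one token per decode step, so a request needs a fresh page precisely when its new length lands on the first slot of a page. The values below are hypothetical.

import torch

page_size = 4
# Hypothetical sequence lengths *after* appending the decode token.
seq_lens_cpu = torch.tensor([9, 12, 13])

# 9 and 13 have just crossed a page boundary (slot 1 of a new page); 12 has not.
need_new_pages = (seq_lens_cpu % page_size == 1).int()  # [1, 0, 1]
num_new_pages = need_new_pages.sum().item()             # 2
assert num_new_pages == 2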
@@ -135,6 +130,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
         if num_new_pages > len(self.free_pages):
             return None

+        need_new_pages = (seq_lens % self.page_size == 1).int()
         end_new_pages = torch.cumsum(need_new_pages, 0)
         start_new_pages = end_new_pages - need_new_pages
         if num_new_pages == 0:
...
@@ -3251,17 +3251,24 @@ def get_extend_input_len_swa_limit(
 def get_num_new_pages(
-    prefix_lens: torch.Tensor,
     seq_lens: torch.Tensor,
     page_size: int,
+    prefix_lens: Optional[torch.Tensor] = None,
     decode: bool = False,
 ) -> torch.Tensor:
     """
-    Get the number of new pages for the given prefix and sequence lengths. We use cpu tensors to avoid blocking kernel launch.
+    Get the number of new pages for the given prefix and sequence lengths.
+    We use cpu tensors to avoid blocking kernel launch.
     """
     cpu_device = torch.device("cpu")
-    assert prefix_lens.device == cpu_device
     assert seq_lens.device == cpu_device

+    if prefix_lens is None or decode:
+        # NOTE: Special case for handling decode, which prefix lens is `seq_lens - 1`.
+        assert decode
+        return (seq_lens % page_size == 1).int().sum().item()
+
+    assert prefix_lens.device == cpu_device
     num_pages_after = (seq_lens + page_size - 1) // page_size
     num_pages_before = (prefix_lens + page_size - 1) // page_size
     num_new_pages = num_pages_after - num_pages_before
...
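
A usage sketch of the updated helper, assuming torch and sglang are installed; the import path follows the diff's own import line (from sglang.srt.utils import get_num_new_pages), and the tensors are illustrative CPU tensors (the helper asserts they live on CPU so the count never forces a device sync).

import torch

from sglang.srt.utils import get_num_new_pages

page_size = 128
seq_lens_cpu = torch.tensor([300, 257])     # lengths after the current step
prefix_lens_cpu = torch.tensor([100, 256])  # already-cached lengths (extend path)

# Extend/prefill path: pages after minus pages before, summed over the batch.
# Per request: ceil(300/128) - ceil(100/128) = 2 and ceil(257/128) - ceil(256/128) = 1,
# i.e. 3 new pages in total.
extend_pages = get_num_new_pages(
    seq_lens=seq_lens_cpu,
    page_size=page_size,
    prefix_lens=prefix_lens_cpu,
)

# Decode path: prefix_lens is implicitly seq_lens - 1, so only the modulo check runs.
# 300 % 128 == 44 (no new page) and 257 % 128 == 1 (one new page), i.e. 1 in total.
decode_pages = get_num_new_pages(
    seq_lens=seq_lens_cpu,
    page_size=page_size,
    decode=True,
)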