Unverified Commit 7d004799 authored by Liangsheng Yin, committed by GitHub

Clean up ascend allocator (#11152)

parent 083629c2
@@ -493,7 +493,11 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)
-        num_new_pages = get_num_new_pages(prefix_lens_cpu, seq_lens_cpu, self.page_size)
+        num_new_pages = get_num_new_pages(
+            seq_lens=seq_lens_cpu,
+            page_size=self.page_size,
+            prefix_lens=prefix_lens_cpu,
+        )
        if num_new_pages > len(self.free_pages):
            return None
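
For context on the extend path: the call above asks the shared helper for the total number of fresh pages needed to grow each request from its cached prefix length to its full sequence length, i.e. ceil(seq_len / page_size) - ceil(prefix_len / page_size), summed over the batch. A minimal standalone sketch with made-up lengths (not part of the diff; page_size=4, prefix_lens=[3, 8], seq_lens=[5, 9] are illustrative only):

import torch

page_size = 4
prefix_lens = torch.tensor([3, 8])   # tokens already cached per request
seq_lens = torch.tensor([5, 9])      # target lengths after the extend step

# Pages occupied before and after the extend, using ceiling division.
pages_before = (prefix_lens + page_size - 1) // page_size   # -> [1, 2]
pages_after = (seq_lens + page_size - 1) // page_size       # -> [2, 3]
print((pages_after - pages_before).sum().item())             # -> 2 new pages
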
@@ -529,7 +533,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
            assert len(torch.unique(out_indices)) == len(out_indices)
        num_new_pages = get_num_new_pages(
-            seq_lens_cpu - 1, seq_lens_cpu, self.page_size, decode=True
+            seq_lens=seq_lens_cpu,
+            page_size=self.page_size,
+            decode=True,
        )
        if num_new_pages > len(self.free_pages):
            return None
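
On the decode path, the old call spelled out the prefix as seq_lens_cpu - 1; the new call relies on decode=True instead, under which a request needs a fresh page exactly when the freshly appended token starts one, i.e. seq_len % page_size == 1. A small check that the two formulations agree for a sample batch (illustrative values only; a paged allocator with page_size > 1 is assumed):

import torch

page_size, seq_lens = 4, torch.tensor([1, 4, 5, 9])

# Old formulation: prefix is seq_lens - 1, count page-boundary crossings.
old = ((seq_lens + page_size - 1) // page_size
       - (seq_lens - 1 + page_size - 1) // page_size).sum().item()

# New formulation: a decode step opens a page iff the new length starts one.
new = (seq_lens % page_size == 1).int().sum().item()

assert old == new == 3   # lengths 1, 5, and 9 each start a new page
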
......
from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from sglang.srt.mem_cache.allocator import PagedTokenToKVPoolAllocator

if TYPE_CHECKING:
    from sglang.srt.mem_cache.memory_pool import KVCache

from sglang.srt.utils import get_num_new_pages


def alloc_extend_kernel_ascend(
@@ -80,13 +76,10 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
            )
-        num_new_pages = (
-            (
-                (seq_lens_cpu + self.page_size - 1) // self.page_size
-                - (prefix_lens_cpu + self.page_size - 1) // self.page_size
-            )
-            .sum()
-            .item()
+        num_new_pages = get_num_new_pages(
+            seq_lens=seq_lens_cpu,
+            page_size=self.page_size,
+            prefix_lens=prefix_lens_cpu,
        )
        if self.need_sort and num_new_pages > len(self.free_pages):
            self.merge_and_sort_free()
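
The chained expression deleted here and the shared helper compute the same quantity, so the Ascend allocator no longer carries its own copy of the page math. A quick standalone check (illustrative tensors only; the helper's arithmetic is mirrored locally from the utils.py hunk further below rather than imported):

import torch

page_size = 128
prefix_lens_cpu = torch.tensor([0, 130, 256])
seq_lens_cpu = torch.tensor([40, 300, 257])

# Old inline form that this hunk deletes.
old = (
    (
        (seq_lens_cpu + page_size - 1) // page_size
        - (prefix_lens_cpu + page_size - 1) // page_size
    )
    .sum()
    .item()
)

# New form: ceil-div per request, the same arithmetic get_num_new_pages performs.
pages_after = (seq_lens_cpu + page_size - 1) // page_size
pages_before = (prefix_lens_cpu + page_size - 1) // page_size
assert old == (pages_after - pages_before).sum().item() == 3
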
@@ -125,9 +118,11 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
                (last_loc + 2) % self.page_size == seq_lens % self.page_size
            )
-        need_new_pages = (seq_lens % self.page_size == 1).int()
-        need_new_pages_cpu = (seq_lens_cpu % self.page_size == 1).int()
-        num_new_pages = need_new_pages_cpu.sum().item()
+        num_new_pages = get_num_new_pages(
+            seq_lens=seq_lens_cpu,
+            page_size=self.page_size,
+            decode=True,
+        )
        if num_new_pages > len(self.free_pages):
            self.merge_and_sort_free()
@@ -135,6 +130,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
            if num_new_pages > len(self.free_pages):
                return None
+        need_new_pages = (seq_lens % self.page_size == 1).int()
        end_new_pages = torch.cumsum(need_new_pages, 0)
        start_new_pages = end_new_pages - need_new_pages
        if num_new_pages == 0:
......
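
The moved line marks which requests need a fresh page during decode, and the cumsum pair turns that 0/1 mask into per-request offsets, presumably used to slice the freshly allocated pages per request. A small sketch with illustrative lengths (page_size=4 assumed):

import torch

page_size = 4
seq_lens = torch.tensor([5, 6, 9, 13])               # lengths after appending one token

need_new_pages = (seq_lens % page_size == 1).int()   # -> [1, 0, 1, 1]
end_new_pages = torch.cumsum(need_new_pages, 0)      # -> [1, 1, 2, 3]
start_new_pages = end_new_pages - need_new_pages     # -> [0, 1, 1, 2]

# Request i that needs a page takes new_pages[start_new_pages[i]:end_new_pages[i]];
# requests that do not need one get an empty slice.
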
@@ -3251,17 +3251,24 @@ def get_extend_input_len_swa_limit(

def get_num_new_pages(
-    prefix_lens: torch.Tensor,
    seq_lens: torch.Tensor,
    page_size: int,
+    prefix_lens: Optional[torch.Tensor] = None,
    decode: bool = False,
) -> torch.Tensor:
    """
-    Get the number of new pages for the given prefix and sequence lengths. We use cpu tensors to avoid blocking kernel launch.
+    Get the number of new pages for the given prefix and sequence lengths.
+    We use cpu tensors to avoid blocking kernel launch.
    """
    cpu_device = torch.device("cpu")
-    assert prefix_lens.device == cpu_device
    assert seq_lens.device == cpu_device
+    if prefix_lens is None or decode:
+        # NOTE: Special case for decode, where prefix_lens is `seq_lens - 1`.
+        assert decode
+        return (seq_lens % page_size == 1).int().sum().item()
+    assert prefix_lens.device == cpu_device
    num_pages_after = (seq_lens + page_size - 1) // page_size
    num_pages_before = (prefix_lens + page_size - 1) // page_size
    num_new_pages = num_pages_after - num_pages_before
......
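
The tail of this hunk is truncated; presumably the per-request difference is summed and returned, as the older inline Ascend code did. A self-contained sketch of the refactored helper under that assumption, with a usage example (names mirror the diff, the int return annotation and the _sketch suffix are mine, and this is not the verbatim final code):

from typing import Optional

import torch


def get_num_new_pages_sketch(
    seq_lens: torch.Tensor,
    page_size: int,
    prefix_lens: Optional[torch.Tensor] = None,
    decode: bool = False,
) -> int:
    """Sketch of the refactored helper: one entry point for extend and decode."""
    assert seq_lens.device == torch.device("cpu")
    if prefix_lens is None or decode:
        # Decode: prefix is seq_lens - 1, so a new page is needed exactly
        # when the appended token starts a fresh page.
        assert decode
        return (seq_lens % page_size == 1).int().sum().item()
    assert prefix_lens.device == torch.device("cpu")
    pages_after = (seq_lens + page_size - 1) // page_size
    pages_before = (prefix_lens + page_size - 1) // page_size
    return (pages_after - pages_before).sum().item()


# Extend: 2 new pages; decode: 1 new page (the request reaching length 9).
print(get_num_new_pages_sketch(torch.tensor([5, 9]), 4, prefix_lens=torch.tensor([3, 8])))
print(get_num_new_pages_sketch(torch.tensor([8, 9]), 4, decode=True))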