Unverified commit 0b2aa8a7 authored by Zhang Junda, committed by GitHub

Introduce CPU tensor as metadata to avoid blocking GPU kernel launch (#10720)


Co-authored-by: hnyls2002 <lsyincs@gmail.com>
parent 609f65ba
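
The idea behind the change: batch metadata such as `seq_lens` previously lived only on the GPU, so any host-side decision that needed its values (page counts, allocation sizes) had to call `.item()` or `.cpu()`, which drains the CUDA launch queue before returning. This commit gives each such tensor a CPU twin that is updated in lockstep, so the bookkeeping becomes plain host arithmetic. A minimal sketch of the pattern, not taken from the commit itself (tensor values are made up, and the GPU branch only runs when CUDA is present):

import torch

page_size = 16

if torch.cuda.is_available():
    seq_lens = torch.tensor([30, 45, 7], dtype=torch.int64, device="cuda")
    # Before: reading the GPU tensor forces a device synchronization, so the
    # scheduler stalls until every queued kernel has finished.
    total_tokens = seq_lens.sum().item()

# After: a CPU twin is kept in lockstep with the device tensor, so the same
# bookkeeping is plain host arithmetic and never touches the GPU stream.
seq_lens_cpu = torch.tensor([30, 45, 7], dtype=torch.int64)
seq_lens_cpu.add_(1)  # mirrors seq_lens.add_(1) on the device
total_tokens = seq_lens_cpu.sum().item()  # no sync: data already on the host

The diff below threads `seq_lens_cpu` and `prefix_lens_cpu` through the scheduler (`ScheduleBatch`), the paged allocators, and the speculative-decoding paths so that every consumer of the lengths can stay on the host.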
@@ -76,6 +76,7 @@ class ScheduleBatchDisaggregationDecodeMixin:
             req_pool_indices, dtype=torch.int64, device=self.device
         )
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
+        self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
         self.orig_seq_lens = torch.tensor(
             seq_lens, dtype=torch.int32, device=self.device
         )
...
@@ -900,6 +900,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     token_type_ids: torch.Tensor = None  # shape: [b], int64
     req_pool_indices: torch.Tensor = None  # shape: [b], int64
     seq_lens: torch.Tensor = None  # shape: [b], int64
+    seq_lens_cpu: torch.Tensor = None  # shape: [b], int64
     # The output locations of the KV cache
     out_cache_loc: torch.Tensor = None  # shape: [b], int64
     output_ids: torch.Tensor = None  # shape: [b], int64
@@ -1055,7 +1056,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def alloc_paged_token_slots_extend(
         self,
         prefix_lens: torch.Tensor,
+        prefix_lens_cpu: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: torch.Tensor,
         last_loc: torch.Tensor,
         extend_num_tokens: int,
         backup_state: bool = False,
@@ -1063,7 +1066,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         # Over estimate the number of tokens: assume each request needs a new page.
         num_tokens = (
             extend_num_tokens
-            + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
+            + len(seq_lens_cpu) * self.token_to_kv_pool_allocator.page_size
         )
         self._evict_tree_cache_if_needed(num_tokens)
@@ -1071,7 +1074,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             state = self.token_to_kv_pool_allocator.backup_state()
         out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
-            prefix_lens, seq_lens, last_loc, extend_num_tokens
+            prefix_lens,
+            prefix_lens_cpu,
+            seq_lens,
+            seq_lens_cpu,
+            last_loc,
+            extend_num_tokens,
         )
         if out_cache_loc is None:
             error_msg = (
@@ -1090,6 +1098,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def alloc_paged_token_slots_decode(
         self,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: torch.Tensor,
         last_loc: torch.Tensor,
         backup_state: bool = False,
     ):
@@ -1100,7 +1109,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if backup_state:
             state = self.token_to_kv_pool_allocator.backup_state()
-        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
+        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(
+            seq_lens, seq_lens_cpu, last_loc
+        )
         if out_cache_loc is None:
             error_msg = (
                 f"Decode out of memory. Try to lower your batch size.\n"
@@ -1169,6 +1180,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
+        self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
         if not decoder_out_cache_loc:
             self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
@@ -1217,12 +1229,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
+        seq_lens_cpu_tensor = torch.tensor(seq_lens, dtype=torch.int64)
         orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
         prefix_lens_tensor = torch.tensor(
             prefix_lens, dtype=torch.int64, device=self.device
         )
+        prefix_lens_cpu_tensor = torch.tensor(prefix_lens, dtype=torch.int64)

         token_type_ids_tensor = None
         if len(token_type_ids) > 0:
@@ -1349,13 +1363,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 prefix_lens_tensor,
             )
             out_cache_loc = self.alloc_paged_token_slots_extend(
-                prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
+                prefix_lens_tensor,
+                prefix_lens_cpu_tensor,
+                seq_lens_tensor,
+                seq_lens_cpu_tensor,
+                last_loc,
+                extend_num_tokens,
             )

         # Set fields
         self.input_ids = input_ids_tensor
         self.req_pool_indices = req_pool_indices_tensor
         self.seq_lens = seq_lens_tensor
+        self.seq_lens_cpu = seq_lens_cpu_tensor
         self.orig_seq_lens = orig_seq_lens_tensor
         self.out_cache_loc = out_cache_loc
         self.input_embeds = (
@@ -1498,7 +1518,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         )

         retracted_reqs = []
-        seq_lens_cpu = self.seq_lens.cpu().numpy()
         first_iter = True
         while first_iter or (
             not self.check_decode_mem(selected_indices=sorted_indices)
@@ -1548,7 +1567,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
         req = self.reqs[idx]
-        seq_lens_cpu = self.seq_lens.cpu().numpy()
+        seq_lens_cpu = self.seq_lens_cpu.numpy()
         if server_args.disaggregation_mode == "decode":
             req.offload_kv_cache(
@@ -1592,6 +1611,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.forward_mode = ForwardMode.IDLE
         self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
         self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
+        self.seq_lens_cpu = torch.empty(0, dtype=torch.int64)
         self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
         self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
         self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
@@ -1651,10 +1671,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if self.enable_overlap:
             # Do not use in-place operations in the overlap mode
             self.seq_lens = self.seq_lens + 1
+            self.seq_lens_cpu = self.seq_lens_cpu + 1
             self.orig_seq_lens = self.orig_seq_lens + 1
         else:
             # A faster in-place version
             self.seq_lens.add_(1)
+            self.seq_lens_cpu.add_(1)
             self.orig_seq_lens.add_(1)
         self.seq_lens_sum += bs
@@ -1673,7 +1695,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 self.req_pool_indices, self.seq_lens - 2
             ]
             self.out_cache_loc = self.alloc_paged_token_slots_decode(
-                self.seq_lens, last_loc
+                self.seq_lens, self.seq_lens_cpu, last_loc
             )
             self.req_to_token_pool.write(
@@ -1719,6 +1741,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
         self.req_pool_indices = self.req_pool_indices[keep_indices_device]
         self.seq_lens = self.seq_lens[keep_indices_device]
+        self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
         self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
         self.out_cache_loc = None
         self.seq_lens_sum = self.seq_lens.sum().item()
@@ -1759,6 +1782,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             [self.req_pool_indices, other.req_pool_indices]
         )
         self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
+        self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu])
         self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
         self.out_cache_loc = None
         self.seq_lens_sum += other.seq_lens_sum
@@ -1802,9 +1826,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.sampling_info.grammars = None

         seq_lens_cpu = (
-            seq_lens_cpu_cache
-            if seq_lens_cpu_cache is not None
-            else self.seq_lens.cpu()
+            seq_lens_cpu_cache if seq_lens_cpu_cache is not None else self.seq_lens_cpu
         )

         global bid
...
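Note how every mutation of `seq_lens` above (increment, filter, merge, idle reset) is mirrored on `seq_lens_cpu`; the twin is only useful if the two never diverge. A simplified, CPU-only sketch of that invariant (names and values are illustrative, not from the commit):

import torch

seq_lens = torch.tensor([8, 12], dtype=torch.int64)  # stand-in for the device tensor
seq_lens_cpu = seq_lens.clone()                      # the CPU twin

# filter_batch: apply the same index selection to both copies
keep = torch.tensor([1])
seq_lens, seq_lens_cpu = seq_lens[keep], seq_lens_cpu[keep]

# merge_batch: concatenate both copies with the other batch's tensors
other = torch.tensor([5], dtype=torch.int64)
seq_lens = torch.cat([seq_lens, other])
seq_lens_cpu = torch.cat([seq_lens_cpu, other])

assert seq_lens.tolist() == seq_lens_cpu.tolist()  # twins stayed in lockstep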
@@ -27,7 +27,7 @@ import triton
 import triton.language as tl

 from sglang.srt.mem_cache.memory_pool import SWAKVPool
-from sglang.srt.utils import get_bool_env_var, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, get_num_new_pages, next_power_of_2

 if TYPE_CHECKING:
     from sglang.srt.mem_cache.memory_pool import KVCache
@@ -294,7 +294,6 @@ def alloc_extend_kernel(
     last_loc_ptr,
     free_page_ptr,
     out_indices,
-    ret_values,
     bs_upper: tl.constexpr,
     page_size: tl.constexpr,
     max_num_extend_tokens: tl.constexpr,
@@ -323,13 +322,6 @@ def alloc_extend_kernel(
     sum_num_new_pages = tl.sum(num_new_pages)
     new_page_start_loc = sum_num_new_pages - num_page_start_loc_self

-    # Return value
-    if pid == tl.num_programs(0) - 1:
-        merged_value = (sum_num_new_pages.to(tl.int64)) << 32 | sum_extend_lens.to(
-            tl.int64
-        )
-        tl.store(ret_values, merged_value)
-
     # Part 1: fill the old partial page
     last_loc = tl.load(last_loc_ptr + pid)
     num_part1 = (
@@ -381,7 +373,6 @@ def alloc_decode_kernel(
     last_loc_ptr,
     free_page_ptr,
     out_indices,
-    ret_values,
     bs_upper: tl.constexpr,
     page_size: tl.constexpr,
 ):
@@ -404,10 +395,6 @@ def alloc_decode_kernel(
     sum_num_new_pages = tl.sum(num_new_pages)
     new_page_start_loc = sum_num_new_pages - num_page_start_loc_self

-    # Return value
-    if pid == tl.num_programs(0) - 1:
-        tl.store(ret_values, sum_num_new_pages)
-
     if num_page_start_loc_self == 0:
         last_loc = tl.load(last_loc_ptr + pid)
         tl.store(out_indices + pid, last_loc + 1)
@@ -438,7 +425,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         super().__init__(size, page_size, dtype, device, kvcache, need_sort)
         self.num_pages = size // page_size
         self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
-        self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
         self.seen_max_num_extend_tokens_next_power_of_2 = 1
         self.clear()
@@ -468,7 +454,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
     def alloc_extend(
         self,
         prefix_lens: torch.Tensor,
+        prefix_lens_cpu: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: torch.Tensor,
         last_loc: torch.Tensor,
         extend_num_tokens: int,
     ):
@@ -497,7 +485,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             last_loc,
             self.free_pages,
             out_indices,
-            self.ret_values,
             next_power_of_2(bs),
             self.page_size,
             self.seen_max_num_extend_tokens_next_power_of_2,
@@ -506,8 +493,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         if self.debug_mode:
             assert len(torch.unique(out_indices)) == len(out_indices)

-        merged_value = self.ret_values.item()
-        num_new_pages = merged_value >> 32
+        num_new_pages = get_num_new_pages(prefix_lens_cpu, seq_lens_cpu, self.page_size)
         if num_new_pages > len(self.free_pages):
             return None
@@ -517,6 +503,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
     def alloc_decode(
         self,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: torch.Tensor,
         last_loc: torch.Tensor,
     ):
         if self.debug_mode:
@@ -534,7 +521,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             last_loc,
             self.free_pages,
             out_indices,
-            self.ret_values,
             next_power_of_2(bs),
             self.page_size,
         )
@@ -542,7 +528,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         if self.debug_mode:
             assert len(torch.unique(out_indices)) == len(out_indices)

-        num_new_pages = self.ret_values.item()
+        num_new_pages = get_num_new_pages(
+            seq_lens_cpu - 1, seq_lens_cpu, self.page_size, decode=True
+        )
         if num_new_pages > len(self.free_pages):
             return None
...
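The deleted `ret_values` scalar is where the old blocking behavior lived: both Triton kernels wrote their page count into a GPU scalar, and `alloc_extend`/`alloc_decode` had to read it back with `.item()` before deciding whether the allocation fits. A sketch of just that cost (hypothetical standalone code; in the real allocator the scalar is written by the kernel):

import torch

if torch.cuda.is_available():
    ret_values = torch.zeros((), dtype=torch.int64, device="cuda")
    # .item() copies device -> host and waits for every kernel queued ahead
    # of it, so the allocator stalled on the GPU stream to read one integer.
    num_new_pages = ret_values.item()

The replacement derives the same count from `prefix_lens_cpu`/`seq_lens_cpu` via `get_num_new_pages` (a worked example follows the `sglang.srt.utils` hunk at the end), so the host decides without waiting.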
@@ -69,7 +69,9 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
     def alloc_extend(
         self,
         prefix_lens: torch.Tensor,
+        prefix_lens_cpu: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: torch.Tensor,
         last_loc: torch.Tensor,
         extend_num_tokens: int,
     ):
@@ -80,8 +82,8 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
         num_new_pages = (
             (
-                (seq_lens + self.page_size - 1) // self.page_size
-                - (prefix_lens + self.page_size - 1) // self.page_size
+                (seq_lens_cpu + self.page_size - 1) // self.page_size
+                - (prefix_lens_cpu + self.page_size - 1) // self.page_size
             )
             .sum()
             .item()
@@ -115,6 +117,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
     def alloc_decode(
         self,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: torch.Tensor,
         last_loc: torch.Tensor,
     ):
         if self.debug_mode:
@@ -123,7 +126,8 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
         )
         need_new_pages = (seq_lens % self.page_size == 1).int()
-        num_new_pages = need_new_pages.sum().item()
+        need_new_pages_cpu = (seq_lens_cpu % self.page_size == 1).int()
+        num_new_pages = need_new_pages_cpu.sum().item()
         if num_new_pages > len(self.free_pages):
             self.merge_and_sort_free()
...
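The Ascend decode path relies on `seq_lens` already including the token being decoded: a remainder of 1 modulo the page size means that token landed in the first slot of a brand-new page. A quick illustration with made-up lengths:

import torch

page_size = 16
seq_lens_cpu = torch.tensor([16, 17, 33, 5], dtype=torch.int64)

need_new_pages_cpu = (seq_lens_cpu % page_size == 1).int()
print(need_new_pages_cpu.tolist())  # [0, 1, 1, 0]: only 17 and 33 open a new page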
@@ -104,14 +104,21 @@ class EagleVerifyInput(SpecInput):
             end_offset = batch.seq_lens + self.draft_token_num
         else:
             prefix_lens = batch.seq_lens
+            prefix_lens_cpu = batch.seq_lens_cpu
             end_offset = prefix_lens + self.draft_token_num
+            end_offset_cpu = prefix_lens_cpu + self.draft_token_num
             last_loc = get_last_loc(
                 batch.req_to_token_pool.req_to_token,
                 batch.req_pool_indices,
                 prefix_lens,
             )
             batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
-                prefix_lens, end_offset, last_loc, len(batch.input_ids)
+                prefix_lens,
+                prefix_lens_cpu,
+                end_offset,
+                end_offset_cpu,
+                last_loc,
+                len(batch.input_ids),
             )
             self.last_loc = last_loc
@@ -380,6 +387,8 @@ class EagleVerifyInput(SpecInput):
         verified_id = predict[accept_index]
         evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
         evict_mask[accept_index] = False
+        accept_length_cpu = accept_length.cpu()
+        accept_length_list = accept_length_cpu.tolist()

         if page_size == 1:
             # TODO: boolean array index leads to a device sync. Remove it.
@@ -456,13 +465,15 @@ class EagleVerifyInput(SpecInput):
             else:
                 batch.out_cache_loc = tgt_cache_loc
             batch.seq_lens.add_(accept_length + 1)
+            batch.seq_lens_cpu.add_(accept_length_cpu + 1)

             draft_input = EagleDraftInput(
                 hidden_states=batch.spec_info.hidden_states[accept_index],
                 verified_id=verified_id,
                 accept_length=accept_length,
-                accept_length_cpu=accept_length.tolist(),
+                accept_length_cpu=accept_length_list,
                 seq_lens_for_draft_extend=batch.seq_lens,
+                seq_lens_for_draft_extend_cpu=batch.seq_lens_cpu,
                 req_pool_indices_for_draft_extend=batch.req_pool_indices,
             )
@@ -485,15 +496,15 @@ class EagleVerifyInput(SpecInput):
                 next_power_of_2(bs),
             )
             batch.seq_lens.add_(accept_length + 1)
+            batch.seq_lens_cpu.add_(accept_length_cpu + 1)

-            accept_length_cpu = accept_length.tolist()
             if len(unfinished_accept_index) > 0:
                 unfinished_accept_index = torch.cat(unfinished_accept_index)
                 unfinished_index_device = torch.tensor(
                     unfinished_index, dtype=torch.int64, device=predict.device
                 )
                 draft_input_accept_length_cpu = [
-                    accept_length_cpu[i] for i in unfinished_index
+                    accept_length_list[i] for i in unfinished_index
                 ]
                 if page_size == 1 or self.topk == 1:
                     batch.out_cache_loc = batch.out_cache_loc[unfinished_accept_index]
@@ -508,6 +519,7 @@ class EagleVerifyInput(SpecInput):
                         unfinished_index_device,
                         batch.seq_lens,
                     )
+                    batch.seq_lens_cpu.add_(accept_length_cpu + 1)
                     filter_finished_cache_loc_kernel[(bs,)](
                         batch.out_cache_loc,
                         tgt_cache_loc,
@@ -525,6 +537,7 @@ class EagleVerifyInput(SpecInput):
                     accept_length_cpu=draft_input_accept_length_cpu,
                     accept_length=accept_length[unfinished_index_device],
                     seq_lens_for_draft_extend=batch.seq_lens[unfinished_index_device],
+                    seq_lens_for_draft_extend_cpu=batch.seq_lens_cpu[unfinished_index],
                     req_pool_indices_for_draft_extend=batch.req_pool_indices[
                         unfinished_index_device
                     ],
@@ -542,7 +555,7 @@ class EagleVerifyInput(SpecInput):
             draft_input=draft_input,
             logits_output=logits_output,
             verified_id=verified_id,
-            accept_length_per_req_cpu=accept_length_cpu,
+            accept_length_per_req_cpu=accept_length_list,
             accepted_indices=accept_index,
         )
@@ -575,6 +588,7 @@ class EagleDraftInput(SpecInput):
     # Inputs for draft extend
     # shape: (b,)
     seq_lens_for_draft_extend: torch.Tensor = None
+    seq_lens_for_draft_extend_cpu: torch.Tensor = None
     req_pool_indices_for_draft_extend: torch.Tensor = None

     def __post_init__(self):
@@ -631,6 +645,7 @@ class EagleDraftInput(SpecInput):
         batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
         batch.extend_num_tokens = sum(batch.extend_lens)
         batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
+        batch.seq_lens_cpu = batch.spec_info.seq_lens_for_draft_extend_cpu
         batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
         batch.return_logprob = False
         batch.return_hidden_states = False
...
@@ -543,6 +543,8 @@ class EAGLEWorker(TpModelWorker):
                 batch.seq_lens,
                 self.speculative_num_steps,
             )
+            prefix_lens_cpu = batch.seq_lens_cpu
+            seq_lens_cpu = batch.seq_lens_cpu + self.speculative_num_steps
             extend_num_tokens = num_seqs * self.speculative_num_steps
         else:
             # In this case, the last partial page needs to be duplicated.
@@ -578,14 +580,23 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.page_size,
             )
-            # TODO(lmzheng): remove this device sync
-            extend_num_tokens = torch.sum(self.extend_lens).item()
+            prefix_lens_cpu = batch.seq_lens_cpu
+            last_page_lens = prefix_lens_cpu % self.page_size
+            num_new_pages_per_topk = (
+                last_page_lens + self.speculative_num_steps + self.page_size - 1
+            ) // self.page_size
+            seq_lens_cpu = (
+                prefix_lens_cpu // self.page_size * self.page_size
+                + num_new_pages_per_topk * (self.page_size * self.topk)
+            )
+            extend_num_tokens = torch.sum((seq_lens_cpu - prefix_lens_cpu)).item()

         out_cache_loc, token_to_kv_pool_state_backup = (
             batch.alloc_paged_token_slots_extend(
                 prefix_lens,
+                prefix_lens_cpu,
                 seq_lens,
+                seq_lens_cpu,
                 last_loc,
                 extend_num_tokens,
                 backup_state=True,
@@ -1003,6 +1014,7 @@ class EAGLEWorker(TpModelWorker):
         assert isinstance(batch.spec_info, EagleDraftInput)

         # Backup fields that will be modified in-place
         seq_lens_backup = batch.seq_lens.clone()
+        seq_lens_cpu_backup = batch.seq_lens_cpu.clone()
         req_pool_indices_backup = batch.req_pool_indices
         accept_length_backup = batch.spec_info.accept_length
         return_logprob_backup = batch.return_logprob
@@ -1081,6 +1093,7 @@ class EAGLEWorker(TpModelWorker):
             ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
         )
         batch.seq_lens = seq_lens_backup
+        batch.seq_lens_cpu = seq_lens_cpu_backup
         batch.req_pool_indices = req_pool_indices_backup
         batch.spec_info.accept_length = accept_length_backup
         batch.return_logprob = return_logprob_backup
...
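The new host-side math in the `else` branch replaces the `torch.sum(self.extend_lens).item()` device sync flagged by the old TODO. It rounds each prefix down to a page boundary, then reserves enough whole pages for `speculative_num_steps` draft tokens in each of the `topk` branches (the last partial page is duplicated per branch, per the comment above). A worked example with assumed parameters:

page_size, speculative_num_steps, topk = 16, 4, 2
prefix_len = 30  # tokens already in the KV cache for one request

last_page_len = prefix_len % page_size  # 14 tokens on the partial page
num_new_pages_per_topk = (
    last_page_len + speculative_num_steps + page_size - 1
) // page_size  # ceil((14 + 4) / 16) = 2 pages per branch

seq_len = prefix_len // page_size * page_size + num_new_pages_per_topk * (
    page_size * topk
)  # 16 + 2 * 32 = 80
print(seq_len - prefix_len)  # 50 slots reserved, computed entirely on the host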
@@ -77,6 +77,7 @@ class NgramVerifyInput(SpecInput):
             batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
             end_offset = batch.seq_lens + self.draft_token_num
         else:
+            # TODO(lsyin): add prefix lens cpu here to support page size > 1
             prefix_lens = batch.seq_lens
             end_offset = prefix_lens + self.draft_token_num
             last_loc = get_last_loc(
@@ -405,10 +406,13 @@ class NgramVerifyInput(SpecInput):
         self._fill_requests(batch, logits_output)
         self._free_cache(batch, page_size)

+        accept_length_cpu = self.accept_length.cpu()
+        num_accepted_tokens = accept_length_cpu.sum().item()
         batch.seq_lens.add_(self.accept_length + 1)
-        batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
+        batch.seq_lens_cpu.add_(accept_length_cpu + 1)

-        return logits_output, self.verified_id, self.accept_length.sum().item()
+        return logits_output, self.verified_id, num_accepted_tokens

     def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
         pass
...
@@ -3250,6 +3250,30 @@ def get_extend_input_len_swa_limit(
     return page_size + 2 * max(sliding_window_size, chunked_prefill_size)


+def get_num_new_pages(
+    prefix_lens: torch.Tensor,
+    seq_lens: torch.Tensor,
+    page_size: int,
+    decode: bool = False,
+) -> int:
+    """
+    Get the number of new pages for the given prefix and sequence lengths.
+    CPU tensors are used so the computation never blocks GPU kernel launches.
+    """
+    cpu_device = torch.device("cpu")
+    assert prefix_lens.device == cpu_device
+    assert seq_lens.device == cpu_device
+
+    num_pages_after = (seq_lens + page_size - 1) // page_size
+    num_pages_before = (prefix_lens + page_size - 1) // page_size
+    num_new_pages = num_pages_after - num_pages_before
+    extend_lens = seq_lens - prefix_lens
+    sum_num_new_pages = torch.sum(num_new_pages).to(torch.int64)
+    if decode:
+        return sum_num_new_pages.item()
+    merged_value = sum_num_new_pages << 32 | torch.sum(extend_lens).to(torch.int64)
+    return merged_value.item() >> 32
+
+
 class CachedKernel:
     """
     Wrapper that allows kernel[grid](...) syntax with caching based on a key function.
...
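To make the arithmetic in `get_num_new_pages` concrete, here is a worked example with made-up lengths (page size 16). The count is simply "pages needed after the extend" minus "pages already covered by the prefix", summed over the batch:

import torch

prefix_lens = torch.tensor([10, 16, 0], dtype=torch.int64)  # already in the cache
seq_lens = torch.tensor([18, 20, 5], dtype=torch.int64)     # lengths after extend
page_size = 16

num_pages_after = (seq_lens + page_size - 1) // page_size      # [2, 2, 1]
num_pages_before = (prefix_lens + page_size - 1) // page_size  # [1, 1, 0]
print(int((num_pages_after - num_pages_before).sum()))  # 3 new pages

For decode, the call site passes `seq_lens_cpu - 1` as the prefix, since exactly one token is appended per request.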