"vscode:/vscode.git/clone" did not exist on "9da448508925db4e8509f13bc50a9e0d18302c5c"
Unverified commit 0b2aa8a7, authored by Zhang Junda, committed by GitHub

Introduce CPU tensor as metadata to avoid blocking GPU kernel launch (#10720)


Co-authored-by: hnyls2002 <lsyincs@gmail.com>
parent 609f65ba
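
This change threads a CPU-resident mirror of the sequence-length metadata (`seq_lens_cpu`, `prefix_lens_cpu`) through the scheduler, the paged KV-cache allocators, and the speculative-decoding paths, so that page accounting reads CPU tensors instead of calling `.item()` or `.cpu()` on GPU tensors, which forces the host to synchronize and stalls subsequent kernel launches. A minimal sketch of the pattern (illustrative only, not code from this commit; assumes a CUDA device is available):

```python
import torch

# GPU tensor plus a CPU mirror that is kept in lockstep.
seq_lens = torch.tensor([5, 9, 3], dtype=torch.int64, device="cuda")
seq_lens_cpu = torch.tensor([5, 9, 3], dtype=torch.int64)

# Blocking: .item() on a GPU tensor waits for all queued kernels to
# finish before the value can be read back to the host.
total = seq_lens.sum().item()

# Non-blocking: the same reduction on the CPU mirror never touches the
# device, so the host can keep enqueueing kernels.
total = seq_lens_cpu.sum().item()

# The cost of the mirror is that every mutation must be applied to both:
seq_lens.add_(1)
seq_lens_cpu.add_(1)
```

The trade-off is that every site that mutates `seq_lens` must also update the CPU copy, which is why the diff below touches the scheduler, both allocators, and the EAGLE/ngram speculative paths.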
@@ -76,6 +76,7 @@ class ScheduleBatchDisaggregationDecodeMixin:
             req_pool_indices, dtype=torch.int64, device=self.device
         )
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
+        self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
         self.orig_seq_lens = torch.tensor(
             seq_lens, dtype=torch.int32, device=self.device
         )
...
@@ -900,6 +900,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     token_type_ids: torch.Tensor = None  # shape: [b], int64
     req_pool_indices: torch.Tensor = None  # shape: [b], int64
     seq_lens: torch.Tensor = None  # shape: [b], int64
+    seq_lens_cpu: torch.Tensor = None  # shape: [b], int64
     # The output locations of the KV cache
     out_cache_loc: torch.Tensor = None  # shape: [b], int64
     output_ids: torch.Tensor = None  # shape: [b], int64
@@ -1055,7 +1056,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def alloc_paged_token_slots_extend(
         self,
         prefix_lens: torch.Tensor,
+        prefix_lens_cpu: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: torch.Tensor,
         last_loc: torch.Tensor,
         extend_num_tokens: int,
         backup_state: bool = False,
@@ -1063,7 +1066,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         # Over estimate the number of tokens: assume each request needs a new page.
         num_tokens = (
             extend_num_tokens
-            + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
+            + len(seq_lens_cpu) * self.token_to_kv_pool_allocator.page_size
         )
         self._evict_tree_cache_if_needed(num_tokens)
@@ -1071,7 +1074,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             state = self.token_to_kv_pool_allocator.backup_state()
         out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
-            prefix_lens, seq_lens, last_loc, extend_num_tokens
+            prefix_lens,
+            prefix_lens_cpu,
+            seq_lens,
+            seq_lens_cpu,
+            last_loc,
+            extend_num_tokens,
         )
         if out_cache_loc is None:
             error_msg = (
@@ -1090,6 +1098,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def alloc_paged_token_slots_decode(
         self,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: torch.Tensor,
         last_loc: torch.Tensor,
         backup_state: bool = False,
     ):
@@ -1100,7 +1109,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if backup_state:
             state = self.token_to_kv_pool_allocator.backup_state()
-        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
+        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(
+            seq_lens, seq_lens_cpu, last_loc
+        )
         if out_cache_loc is None:
             error_msg = (
                 f"Decode out of memory. Try to lower your batch size.\n"
@@ -1169,6 +1180,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
+        self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
         if not decoder_out_cache_loc:
             self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
@@ -1217,12 +1229,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
+        seq_lens_cpu_tensor = torch.tensor(seq_lens, dtype=torch.int64)
         orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
         prefix_lens_tensor = torch.tensor(
             prefix_lens, dtype=torch.int64, device=self.device
         )
+        prefix_lens_cpu_tensor = torch.tensor(prefix_lens, dtype=torch.int64)
         token_type_ids_tensor = None
         if len(token_type_ids) > 0:
@@ -1349,13 +1363,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             prefix_lens_tensor,
         )
         out_cache_loc = self.alloc_paged_token_slots_extend(
-            prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
+            prefix_lens_tensor,
+            prefix_lens_cpu_tensor,
+            seq_lens_tensor,
+            seq_lens_cpu_tensor,
+            last_loc,
+            extend_num_tokens,
         )
         # Set fields
         self.input_ids = input_ids_tensor
         self.req_pool_indices = req_pool_indices_tensor
         self.seq_lens = seq_lens_tensor
+        self.seq_lens_cpu = seq_lens_cpu_tensor
         self.orig_seq_lens = orig_seq_lens_tensor
         self.out_cache_loc = out_cache_loc
         self.input_embeds = (
@@ -1498,7 +1518,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         )
         retracted_reqs = []
-        seq_lens_cpu = self.seq_lens.cpu().numpy()
         first_iter = True
         while first_iter or (
             not self.check_decode_mem(selected_indices=sorted_indices)
@@ -1548,7 +1567,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
         req = self.reqs[idx]
-        seq_lens_cpu = self.seq_lens.cpu().numpy()
+        seq_lens_cpu = self.seq_lens_cpu.numpy()
         if server_args.disaggregation_mode == "decode":
             req.offload_kv_cache(
@@ -1592,6 +1611,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.forward_mode = ForwardMode.IDLE
         self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
         self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
+        self.seq_lens_cpu = torch.empty(0, dtype=torch.int64)
         self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
         self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
         self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
@@ -1651,10 +1671,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if self.enable_overlap:
             # Do not use in-place operations in the overlap mode
             self.seq_lens = self.seq_lens + 1
+            self.seq_lens_cpu = self.seq_lens_cpu + 1
             self.orig_seq_lens = self.orig_seq_lens + 1
         else:
             # A faster in-place version
             self.seq_lens.add_(1)
+            self.seq_lens_cpu.add_(1)
             self.orig_seq_lens.add_(1)
         self.seq_lens_sum += bs
@@ -1673,7 +1695,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.req_pool_indices, self.seq_lens - 2
         ]
         self.out_cache_loc = self.alloc_paged_token_slots_decode(
-            self.seq_lens, last_loc
+            self.seq_lens, self.seq_lens_cpu, last_loc
         )
         self.req_to_token_pool.write(
@@ -1719,6 +1741,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
         self.req_pool_indices = self.req_pool_indices[keep_indices_device]
         self.seq_lens = self.seq_lens[keep_indices_device]
+        self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
         self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
         self.out_cache_loc = None
         self.seq_lens_sum = self.seq_lens.sum().item()
@@ -1759,6 +1782,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             [self.req_pool_indices, other.req_pool_indices]
         )
         self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
+        self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu])
         self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
         self.out_cache_loc = None
         self.seq_lens_sum += other.seq_lens_sum
@@ -1802,9 +1826,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.sampling_info.grammars = None
         seq_lens_cpu = (
-            seq_lens_cpu_cache
-            if seq_lens_cpu_cache is not None
-            else self.seq_lens.cpu()
+            seq_lens_cpu_cache if seq_lens_cpu_cache is not None else self.seq_lens_cpu
         )
         global bid
...
@@ -27,7 +27,7 @@ import triton
 import triton.language as tl
 from sglang.srt.mem_cache.memory_pool import SWAKVPool
-from sglang.srt.utils import get_bool_env_var, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, get_num_new_pages, next_power_of_2
 if TYPE_CHECKING:
     from sglang.srt.mem_cache.memory_pool import KVCache
@@ -294,7 +294,6 @@ def alloc_extend_kernel(
     last_loc_ptr,
     free_page_ptr,
     out_indices,
-    ret_values,
     bs_upper: tl.constexpr,
     page_size: tl.constexpr,
     max_num_extend_tokens: tl.constexpr,
@@ -323,13 +322,6 @@ def alloc_extend_kernel(
     sum_num_new_pages = tl.sum(num_new_pages)
     new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
-    # Return value
-    if pid == tl.num_programs(0) - 1:
-        merged_value = (sum_num_new_pages.to(tl.int64)) << 32 | sum_extend_lens.to(
-            tl.int64
-        )
-        tl.store(ret_values, merged_value)
     # Part 1: fill the old partial page
     last_loc = tl.load(last_loc_ptr + pid)
     num_part1 = (
@@ -381,7 +373,6 @@ def alloc_decode_kernel(
     last_loc_ptr,
     free_page_ptr,
     out_indices,
-    ret_values,
     bs_upper: tl.constexpr,
     page_size: tl.constexpr,
 ):
@@ -404,10 +395,6 @@ def alloc_decode_kernel(
     sum_num_new_pages = tl.sum(num_new_pages)
     new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
-    # Return value
-    if pid == tl.num_programs(0) - 1:
-        tl.store(ret_values, sum_num_new_pages)
     if num_page_start_loc_self == 0:
         last_loc = tl.load(last_loc_ptr + pid)
         tl.store(out_indices + pid, last_loc + 1)
@@ -438,7 +425,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         super().__init__(size, page_size, dtype, device, kvcache, need_sort)
         self.num_pages = size // page_size
         self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
-        self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
         self.seen_max_num_extend_tokens_next_power_of_2 = 1
         self.clear()
@@ -468,7 +454,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
     def alloc_extend(
         self,
         prefix_lens: torch.Tensor,
+        prefix_lens_cpu: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: torch.Tensor,
         last_loc: torch.Tensor,
         extend_num_tokens: int,
     ):
@@ -497,7 +485,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             last_loc,
             self.free_pages,
             out_indices,
-            self.ret_values,
             next_power_of_2(bs),
             self.page_size,
             self.seen_max_num_extend_tokens_next_power_of_2,
@@ -506,8 +493,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         if self.debug_mode:
             assert len(torch.unique(out_indices)) == len(out_indices)
-        merged_value = self.ret_values.item()
-        num_new_pages = merged_value >> 32
+        num_new_pages = get_num_new_pages(prefix_lens_cpu, seq_lens_cpu, self.page_size)
         if num_new_pages > len(self.free_pages):
             return None
@@ -517,6 +503,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
     def alloc_decode(
         self,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: torch.Tensor,
         last_loc: torch.Tensor,
     ):
         if self.debug_mode:
@@ -534,7 +521,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             last_loc,
             self.free_pages,
             out_indices,
-            self.ret_values,
             next_power_of_2(bs),
             self.page_size,
         )
@@ -542,7 +528,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         if self.debug_mode:
             assert len(torch.unique(out_indices)) == len(out_indices)
-        num_new_pages = self.ret_values.item()
+        num_new_pages = get_num_new_pages(
+            seq_lens_cpu - 1, seq_lens_cpu, self.page_size, decode=True
+        )
         if num_new_pages > len(self.free_pages):
             return None
...
@@ -69,7 +69,9 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
     def alloc_extend(
         self,
         prefix_lens: torch.Tensor,
+        prefix_lens_cpu: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: torch.Tensor,
         last_loc: torch.Tensor,
         extend_num_tokens: int,
     ):
@@ -80,8 +82,8 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
         num_new_pages = (
             (
-                (seq_lens + self.page_size - 1) // self.page_size
-                - (prefix_lens + self.page_size - 1) // self.page_size
+                (seq_lens_cpu + self.page_size - 1) // self.page_size
+                - (prefix_lens_cpu + self.page_size - 1) // self.page_size
             )
             .sum()
             .item()
@@ -115,6 +117,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
     def alloc_decode(
         self,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: torch.Tensor,
         last_loc: torch.Tensor,
     ):
         if self.debug_mode:
@@ -123,7 +126,8 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
         )
         need_new_pages = (seq_lens % self.page_size == 1).int()
-        num_new_pages = need_new_pages.sum().item()
+        need_new_pages_cpu = (seq_lens_cpu % self.page_size == 1).int()
+        num_new_pages = need_new_pages_cpu.sum().item()
         if num_new_pages > len(self.free_pages):
             self.merge_and_sort_free()
...
@@ -104,14 +104,21 @@ class EagleVerifyInput(SpecInput):
             end_offset = batch.seq_lens + self.draft_token_num
         else:
             prefix_lens = batch.seq_lens
+            prefix_lens_cpu = batch.seq_lens_cpu
             end_offset = prefix_lens + self.draft_token_num
+            end_offset_cpu = prefix_lens_cpu + self.draft_token_num
             last_loc = get_last_loc(
                 batch.req_to_token_pool.req_to_token,
                 batch.req_pool_indices,
                 prefix_lens,
             )
             batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
-                prefix_lens, end_offset, last_loc, len(batch.input_ids)
+                prefix_lens,
+                prefix_lens_cpu,
+                end_offset,
+                end_offset_cpu,
+                last_loc,
+                len(batch.input_ids),
             )
             self.last_loc = last_loc
@@ -380,6 +387,8 @@ class EagleVerifyInput(SpecInput):
         verified_id = predict[accept_index]
         evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
         evict_mask[accept_index] = False
+        accept_length_cpu = accept_length.cpu()
+        accept_length_list = accept_length_cpu.tolist()
         if page_size == 1:
             # TODO: boolean array index leads to a device sync. Remove it.
@@ -456,13 +465,15 @@ class EagleVerifyInput(SpecInput):
             else:
                 batch.out_cache_loc = tgt_cache_loc
             batch.seq_lens.add_(accept_length + 1)
+            batch.seq_lens_cpu.add_(accept_length_cpu + 1)
             draft_input = EagleDraftInput(
                 hidden_states=batch.spec_info.hidden_states[accept_index],
                 verified_id=verified_id,
                 accept_length=accept_length,
-                accept_length_cpu=accept_length.tolist(),
+                accept_length_cpu=accept_length_list,
                 seq_lens_for_draft_extend=batch.seq_lens,
+                seq_lens_for_draft_extend_cpu=batch.seq_lens_cpu,
                 req_pool_indices_for_draft_extend=batch.req_pool_indices,
             )
@@ -485,15 +496,15 @@ class EagleVerifyInput(SpecInput):
                 next_power_of_2(bs),
             )
             batch.seq_lens.add_(accept_length + 1)
-            accept_length_cpu = accept_length.tolist()
+            batch.seq_lens_cpu.add_(accept_length_cpu + 1)
             if len(unfinished_accept_index) > 0:
                 unfinished_accept_index = torch.cat(unfinished_accept_index)
                 unfinished_index_device = torch.tensor(
                     unfinished_index, dtype=torch.int64, device=predict.device
                 )
                 draft_input_accept_length_cpu = [
-                    accept_length_cpu[i] for i in unfinished_index
+                    accept_length_list[i] for i in unfinished_index
                 ]
                 if page_size == 1 or self.topk == 1:
                     batch.out_cache_loc = batch.out_cache_loc[unfinished_accept_index]
@@ -508,6 +519,7 @@ class EagleVerifyInput(SpecInput):
                     unfinished_index_device,
                     batch.seq_lens,
                 )
+                batch.seq_lens_cpu.add_(accept_length_cpu + 1)
                 filter_finished_cache_loc_kernel[(bs,)](
                     batch.out_cache_loc,
                     tgt_cache_loc,
@@ -525,6 +537,7 @@ class EagleVerifyInput(SpecInput):
                 accept_length_cpu=draft_input_accept_length_cpu,
                 accept_length=accept_length[unfinished_index_device],
                 seq_lens_for_draft_extend=batch.seq_lens[unfinished_index_device],
+                seq_lens_for_draft_extend_cpu=batch.seq_lens_cpu[unfinished_index],
                 req_pool_indices_for_draft_extend=batch.req_pool_indices[
                     unfinished_index_device
                 ],
@@ -542,7 +555,7 @@ class EagleVerifyInput(SpecInput):
             draft_input=draft_input,
             logits_output=logits_output,
             verified_id=verified_id,
-            accept_length_per_req_cpu=accept_length_cpu,
+            accept_length_per_req_cpu=accept_length_list,
             accepted_indices=accept_index,
         )
@@ -575,6 +588,7 @@ class EagleDraftInput(SpecInput):
     # Inputs for draft extend
     # shape: (b,)
     seq_lens_for_draft_extend: torch.Tensor = None
+    seq_lens_for_draft_extend_cpu: torch.Tensor = None
     req_pool_indices_for_draft_extend: torch.Tensor = None

     def __post_init__(self):
@@ -631,6 +645,7 @@ class EagleDraftInput(SpecInput):
         batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
         batch.extend_num_tokens = sum(batch.extend_lens)
         batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
+        batch.seq_lens_cpu = batch.spec_info.seq_lens_for_draft_extend_cpu
         batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
         batch.return_logprob = False
         batch.return_hidden_states = False
...
@@ -543,6 +543,8 @@ class EAGLEWorker(TpModelWorker):
                 batch.seq_lens,
                 self.speculative_num_steps,
             )
+            prefix_lens_cpu = batch.seq_lens_cpu
+            seq_lens_cpu = batch.seq_lens_cpu + self.speculative_num_steps
             extend_num_tokens = num_seqs * self.speculative_num_steps
         else:
             # In this case, the last partial page needs to be duplicated.
@@ -578,14 +580,23 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.page_size,
             )
-            # TODO(lmzheng): remove this device sync
-            extend_num_tokens = torch.sum(self.extend_lens).item()
+            prefix_lens_cpu = batch.seq_lens_cpu
+            last_page_lens = prefix_lens_cpu % self.page_size
+            num_new_pages_per_topk = (
+                last_page_lens + self.speculative_num_steps + self.page_size - 1
+            ) // self.page_size
+            seq_lens_cpu = (
+                prefix_lens_cpu // self.page_size * self.page_size
+                + num_new_pages_per_topk * (self.page_size * self.topk)
+            )
+            extend_num_tokens = torch.sum((seq_lens_cpu - prefix_lens_cpu)).item()
         out_cache_loc, token_to_kv_pool_state_backup = (
             batch.alloc_paged_token_slots_extend(
                 prefix_lens,
+                prefix_lens_cpu,
                 seq_lens,
+                seq_lens_cpu,
                 last_loc,
                 extend_num_tokens,
                 backup_state=True,
@@ -1003,6 +1014,7 @@ class EAGLEWorker(TpModelWorker):
         assert isinstance(batch.spec_info, EagleDraftInput)
         # Backup fields that will be modified in-place
         seq_lens_backup = batch.seq_lens.clone()
+        seq_lens_cpu_backup = batch.seq_lens_cpu.clone()
         req_pool_indices_backup = batch.req_pool_indices
         accept_length_backup = batch.spec_info.accept_length
         return_logprob_backup = batch.return_logprob
@@ -1081,6 +1093,7 @@ class EAGLEWorker(TpModelWorker):
             ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
         )
         batch.seq_lens = seq_lens_backup
+        batch.seq_lens_cpu = seq_lens_cpu_backup
         batch.req_pool_indices = req_pool_indices_backup
         batch.spec_info.accept_length = accept_length_backup
         batch.return_logprob = return_logprob_backup
...
@@ -77,6 +77,7 @@ class NgramVerifyInput(SpecInput):
             batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
             end_offset = batch.seq_lens + self.draft_token_num
         else:
+            # TODO(lsyin): add prefix lens cpu here to support page size > 1
            prefix_lens = batch.seq_lens
            end_offset = prefix_lens + self.draft_token_num
            last_loc = get_last_loc(
@@ -405,10 +406,13 @@ class NgramVerifyInput(SpecInput):
         self._fill_requests(batch, logits_output)
         self._free_cache(batch, page_size)
+        accept_length_cpu = self.accept_length.cpu()
+        num_accepted_tokens = accept_length_cpu.sum().item()
         batch.seq_lens.add_(self.accept_length + 1)
+        batch.seq_lens_cpu.add_(accept_length_cpu + 1)
         batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
-        return logits_output, self.verified_id, self.accept_length.sum().item()
+        return logits_output, self.verified_id, num_accepted_tokens

     def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
         pass
...
@@ -3250,6 +3250,30 @@ def get_extend_input_len_swa_limit(
     return page_size + 2 * max(sliding_window_size, chunked_prefill_size)

+def get_num_new_pages(
+    prefix_lens: torch.Tensor,
+    seq_lens: torch.Tensor,
+    page_size: int,
+    decode: bool = False,
+) -> int:
+    """
+    Get the number of new pages for the given prefix and sequence lengths.
+    We use CPU tensors to avoid blocking the GPU kernel launch.
+    """
+    cpu_device = torch.device("cpu")
+    assert prefix_lens.device == cpu_device
+    assert seq_lens.device == cpu_device
+
+    num_pages_after = (seq_lens + page_size - 1) // page_size
+    num_pages_before = (prefix_lens + page_size - 1) // page_size
+    num_new_pages = num_pages_after - num_pages_before
+    extend_lens = seq_lens - prefix_lens
+
+    sum_num_new_pages = torch.sum(num_new_pages).to(torch.int64)
+    if decode:
+        return sum_num_new_pages.item()
+    merged_value = sum_num_new_pages << 32 | torch.sum(extend_lens).to(torch.int64)
+    return merged_value.item() >> 32

 class CachedKernel:
     """
     Wrapper that allows kernel[grid](...) syntax with caching based on a key function.
...
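For concreteness, here is a small worked example of the ceil-division page accounting that `get_num_new_pages` performs; the values are made up for illustration, and the computation never touches the GPU:

```python
import torch

page_size = 4
prefix_lens_cpu = torch.tensor([5, 9, 3], dtype=torch.int64)  # tokens already cached
seq_lens_cpu = torch.tensor([8, 9, 7], dtype=torch.int64)     # lengths after the extend

# Same ceil-division accounting as get_num_new_pages, on CPU tensors only.
pages_after = (seq_lens_cpu + page_size - 1) // page_size      # [2, 3, 2]
pages_before = (prefix_lens_cpu + page_size - 1) // page_size  # [2, 3, 1]
num_new_pages = int((pages_after - pages_before).sum())        # 1

# alloc_extend compares this count against len(self.free_pages)
# without ever synchronizing with the GPU.
assert num_new_pages == 1
```

Only the third request crosses a page boundary (3 tokens to 7 tokens with 4-token pages), so a single new page is needed, and the allocator can check free-page availability without the device round-trip the old `ret_values.item()` readback required.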