Unverified Commit 9efc3bdc authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[Model Runner V2] Fix `_compute_slot_mappings_kernel` for chunked prefill (#36580)


Signed-off-by: default avatarNick Hill <nickhill123@gmail.com>
parent 156e3355
......@@ -138,10 +138,8 @@ class BlockTables:
num_tokens_padded: int,
) -> torch.Tensor:
num_reqs = idx_mapping.shape[0]
num_tokens = positions.shape[0]
num_groups = self.num_kv_cache_groups
_compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
num_tokens,
self.max_num_batched_tokens,
idx_mapping,
query_start_loc,
......@@ -213,7 +211,6 @@ def _gather_block_tables_kernel(
@triton.jit
def _compute_slot_mappings_kernel(
num_tokens,
max_num_tokens,
idx_mapping, # [num_reqs]
query_start_loc, # [num_reqs + 1]
......@@ -236,7 +233,11 @@ def _compute_slot_mappings_kernel(
if batch_idx == tl.num_programs(1) - 1:
# Pad remaining slots to -1. This is needed for CUDA graphs.
for i in range(num_tokens, max_num_tokens, TRITON_BLOCK_SIZE):
# Start from actual token count (not padded) to cover the gap
# between actual tokens and padded tokens that can contain stale
# valid slot IDs from previous chunks during chunked prefill.
actual_num_tokens = tl.load(query_start_loc + batch_idx)
for i in range(actual_num_tokens, max_num_tokens, TRITON_BLOCK_SIZE):
offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens)
return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment