"tests/vscode:/vscode.git/clone" did not exist on "853c371fc33e7c99aa2ab9f6e2cd7cbd1cadcf99"
Unverified Commit 9efc3bdc authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[Model Runner V2] Fix `_compute_slot_mappings_kernel` for chunked prefill (#36580)


Signed-off-by: default avatarNick Hill <nickhill123@gmail.com>
parent 156e3355
...@@ -138,10 +138,8 @@ class BlockTables: ...@@ -138,10 +138,8 @@ class BlockTables:
num_tokens_padded: int, num_tokens_padded: int,
) -> torch.Tensor: ) -> torch.Tensor:
num_reqs = idx_mapping.shape[0] num_reqs = idx_mapping.shape[0]
num_tokens = positions.shape[0]
num_groups = self.num_kv_cache_groups num_groups = self.num_kv_cache_groups
_compute_slot_mappings_kernel[(num_groups, num_reqs + 1)]( _compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
num_tokens,
self.max_num_batched_tokens, self.max_num_batched_tokens,
idx_mapping, idx_mapping,
query_start_loc, query_start_loc,
...@@ -213,7 +211,6 @@ def _gather_block_tables_kernel( ...@@ -213,7 +211,6 @@ def _gather_block_tables_kernel(
@triton.jit @triton.jit
def _compute_slot_mappings_kernel( def _compute_slot_mappings_kernel(
num_tokens,
max_num_tokens, max_num_tokens,
idx_mapping, # [num_reqs] idx_mapping, # [num_reqs]
query_start_loc, # [num_reqs + 1] query_start_loc, # [num_reqs + 1]
...@@ -236,7 +233,11 @@ def _compute_slot_mappings_kernel( ...@@ -236,7 +233,11 @@ def _compute_slot_mappings_kernel(
if batch_idx == tl.num_programs(1) - 1: if batch_idx == tl.num_programs(1) - 1:
# Pad remaining slots to -1. This is needed for CUDA graphs. # Pad remaining slots to -1. This is needed for CUDA graphs.
for i in range(num_tokens, max_num_tokens, TRITON_BLOCK_SIZE): # Start from actual token count (not padded) to cover the gap
# between actual tokens and padded tokens that can contain stale
# valid slot IDs from previous chunks during chunked prefill.
actual_num_tokens = tl.load(query_start_loc + batch_idx)
for i in range(actual_num_tokens, max_num_tokens, TRITON_BLOCK_SIZE):
offset = i + tl.arange(0, TRITON_BLOCK_SIZE) offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens) tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens)
return return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment