Unverified Commit e1da249c authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Model Runner V2] Minor refactor for `compute_slot_mappings` (#32794)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent 9b693d02
...@@ -116,24 +116,26 @@ class BlockTables: ...@@ -116,24 +116,26 @@ class BlockTables:
def compute_slot_mappings( def compute_slot_mappings(
self, self,
idx_mapping: torch.Tensor,
query_start_loc: torch.Tensor, query_start_loc: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
num_reqs = query_start_loc.shape[0] - 1 num_reqs = idx_mapping.shape[0]
num_tokens = positions.shape[0] num_tokens = positions.shape[0]
num_groups = self.num_kv_cache_groups num_groups = self.num_kv_cache_groups
_compute_slot_mappings_kernel[(num_groups, num_reqs + 1)]( _compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
num_tokens, num_tokens,
self.max_num_batched_tokens, self.max_num_batched_tokens,
idx_mapping,
query_start_loc, query_start_loc,
positions, positions,
self.input_block_table_ptrs, self.block_table_ptrs,
self.block_table_strides, self.block_table_strides,
self.block_sizes_tensor, self.block_sizes_tensor,
self.slot_mappings, self.slot_mappings,
self.slot_mappings.stride(0), self.slot_mappings.stride(0),
PAD_ID=PAD_SLOT_ID, PAD_ID=PAD_SLOT_ID,
BLOCK_SIZE=1024, # type: ignore TRITON_BLOCK_SIZE=1024, # type: ignore
) )
return self.slot_mappings[:, :num_tokens] return self.slot_mappings[:, :num_tokens]
...@@ -176,42 +178,44 @@ def _gather_block_tables_kernel( ...@@ -176,42 +178,44 @@ def _gather_block_tables_kernel(
def _compute_slot_mappings_kernel( def _compute_slot_mappings_kernel(
num_tokens, num_tokens,
max_num_tokens, max_num_tokens,
cu_num_tokens, # [num_reqs + 1] idx_mapping, # [num_reqs]
query_start_loc, # [num_reqs + 1]
pos, # [num_tokens] pos, # [num_tokens]
block_table_ptrs, # [num_kv_cache_groups] block_table_ptrs, # [num_kv_cache_groups]
block_table_strides, # [num_kv_cache_groups] block_table_strides, # [num_kv_cache_groups]
page_sizes, # [num_kv_cache_groups] block_sizes, # [num_kv_cache_groups]
slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens] slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens]
slot_mappings_stride, slot_mappings_stride,
PAD_ID: tl.constexpr, PAD_ID: tl.constexpr,
BLOCK_SIZE: tl.constexpr, TRITON_BLOCK_SIZE: tl.constexpr,
): ):
# kv cache group id # kv cache group id
group_id = tl.program_id(0) group_id = tl.program_id(0)
req_idx = tl.program_id(1) batch_idx = tl.program_id(1)
slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride
if req_idx == tl.num_programs(1) - 1: if batch_idx == tl.num_programs(1) - 1:
# Pad remaining slots to -1. This is needed for CUDA graphs. # Pad remaining slots to -1. This is needed for CUDA graphs.
for i in range(num_tokens, max_num_tokens, BLOCK_SIZE): for i in range(num_tokens, max_num_tokens, TRITON_BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE) offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens) tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens)
return return
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32) block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
block_table_stride = tl.load(block_table_strides + group_id) block_table_stride = tl.load(block_table_strides + group_id)
page_size = tl.load(page_sizes + group_id) block_size = tl.load(block_sizes + group_id)
start_idx = tl.load(cu_num_tokens + req_idx) req_state_idx = tl.load(idx_mapping + batch_idx)
end_idx = tl.load(cu_num_tokens + req_idx + 1) start_idx = tl.load(query_start_loc + batch_idx)
for i in range(start_idx, end_idx, BLOCK_SIZE): end_idx = tl.load(query_start_loc + batch_idx + 1)
offset = i + tl.arange(0, BLOCK_SIZE) for i in range(start_idx, end_idx, TRITON_BLOCK_SIZE):
offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
positions = tl.load(pos + offset, mask=offset < end_idx, other=0) positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
block_indices = positions // page_size block_indices = positions // block_size
block_numbers = tl.load( block_numbers = tl.load(
block_table_ptr + req_idx * block_table_stride + block_indices block_table_ptr + req_state_idx * block_table_stride + block_indices
) )
slot_ids = block_numbers * page_size + positions % page_size slot_ids = block_numbers * block_size + positions % block_size
tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx) tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
......
...@@ -607,7 +607,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -607,7 +607,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Compute slot mappings: [num_kv_cache_groups, num_tokens] # Compute slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings = self.block_tables.compute_slot_mappings( slot_mappings = self.block_tables.compute_slot_mappings(
query_start_loc, self.input_buffers.positions[:num_tokens] idx_mapping,
query_start_loc,
self.input_buffers.positions[:num_tokens],
) )
# Layer name -> attention metadata. # Layer name -> attention metadata.
......
...@@ -138,6 +138,7 @@ class EagleSpeculator: ...@@ -138,6 +138,7 @@ class EagleSpeculator:
) -> None: ) -> None:
pos = self.input_buffers.positions[:num_reqs] pos = self.input_buffers.positions[:num_reqs]
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
idx_mapping = self.idx_mapping[:num_reqs]
for step in range(1, self.num_speculative_steps): for step in range(1, self.num_speculative_steps):
# Run the eagle model. # Run the eagle model.
last_hidden_states, hidden_states = self.run_model( last_hidden_states, hidden_states = self.run_model(
...@@ -149,7 +150,7 @@ class EagleSpeculator: ...@@ -149,7 +150,7 @@ class EagleSpeculator:
# used for draft and target sampling. # used for draft and target sampling.
draft_tokens = gumbel_sample( draft_tokens = gumbel_sample(
logits, logits,
self.idx_mapping[:num_reqs], idx_mapping,
self.temperature, self.temperature,
self.seeds, self.seeds,
pos + 1, pos + 1,
...@@ -166,7 +167,9 @@ class EagleSpeculator: ...@@ -166,7 +167,9 @@ class EagleSpeculator:
self.hidden_states, self.hidden_states,
self.max_model_len, self.max_model_len,
) )
self.block_tables.compute_slot_mappings(query_start_loc, pos) self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)
def capture_model(self) -> None: def capture_model(self) -> None:
if self.num_speculative_steps == 1: if self.num_speculative_steps == 1:
...@@ -279,7 +282,9 @@ class EagleSpeculator: ...@@ -279,7 +282,9 @@ class EagleSpeculator:
self.max_num_reqs, self.max_num_reqs,
) )
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings(query_start_loc, pos) slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)
cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs) cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs)
if cudagraph_size is not None: if cudagraph_size is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment