Commit ca1b1e72 (unverified), authored by Woosuk Kwon, committed by GitHub
Browse files

[Model Runner V2] Refactor prefill token preparation (#29712)


Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent 762a4a6c
...@@ -78,7 +78,7 @@ class CudaGraphManager: ...@@ -78,7 +78,7 @@ class CudaGraphManager:
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
) -> None: ) -> None:
num_reqs = min(num_tokens, self.max_num_reqs) num_reqs = min(num_tokens, self.max_num_reqs)
input_ids = input_buffers.input_ids.gpu[:num_tokens] input_ids = input_buffers.input_ids[:num_tokens]
positions = input_buffers.positions[:num_tokens] positions = input_buffers.positions[:num_tokens]
attn_metadata = prepare_inputs_to_capture( attn_metadata = prepare_inputs_to_capture(
num_reqs, num_reqs,
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
import numba
import numpy as np import numpy as np
import torch import torch
...@@ -30,15 +29,12 @@ class InputBuffers: ...@@ -30,15 +29,12 @@ class InputBuffers:
self.pin_memory = pin_memory self.pin_memory = pin_memory
self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32) self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32)
self.input_ids = self._make_buffer(max_num_tokens, dtype=torch.int32) self.input_ids = torch.zeros(max_num_tokens, dtype=torch.int32, device=device)
self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device) self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32) self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device) self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
self.cu_num_logits = self._make_buffer(max_num_reqs + 1, dtype=torch.int32) self.cu_num_logits = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
# Spec decoding.
self.next_prefill_tokens = self._make_buffer(max_num_reqs, dtype=torch.int32)
# Structured outputs. # Structured outputs.
self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32) self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32)
self.grammar_bitmask = self._make_buffer( self.grammar_bitmask = self._make_buffer(
...@@ -120,7 +116,7 @@ class InputBatch: ...@@ -120,7 +116,7 @@ class InputBatch:
input_buffers.seq_lens[num_reqs:] = 0 input_buffers.seq_lens[num_reqs:] = 0
seq_lens = input_buffers.seq_lens[:num_reqs] seq_lens = input_buffers.seq_lens[:num_reqs]
input_ids = input_buffers.input_ids.copy_to_gpu(num_tokens) input_ids = input_buffers.input_ids[:num_tokens]
positions = input_buffers.positions[:num_tokens] positions = input_buffers.positions[:num_tokens]
# attn_metadata = defaultdict(lambda: None) # attn_metadata = defaultdict(lambda: None)
logits_indices = query_start_loc[1:] - 1 logits_indices = query_start_loc[1:] - 1
...@@ -146,41 +142,63 @@ class InputBatch: ...@@ -146,41 +142,63 @@ class InputBatch:
) )
@numba.njit(cache=True) @triton.jit
def _prepare_prefill_inputs( def _prepare_prefill_inputs_kernel(
idx_mapping: np.ndarray, # [B] input_ids_ptr,
query_lens: np.ndarray, # [B] next_prefill_tokens_ptr,
query_start_loc: np.ndarray, # [B + 1] idx_mapping_ptr,
prefill_token_ids: np.ndarray, # [N, max_model_len] query_start_loc_ptr,
num_computed_prefill_tokens: np.ndarray, # [N] prefill_token_ids_ptr,
input_ids: np.ndarray, # [num_input_tokens] prefill_token_ids_stride,
) -> None: prefill_lens_ptr,
num_reqs = idx_mapping.shape[0] num_computed_tokens_ptr,
query_starts = query_start_loc[:num_reqs] BLOCK_SIZE: tl.constexpr,
query_ends = query_start_loc[1 : num_reqs + 1] ):
starts = num_computed_prefill_tokens[idx_mapping] batch_idx = tl.program_id(0)
ends = starts + query_lens req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
for i in range(num_reqs): prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
input_ids[query_starts[i] : query_ends[i]] = prefill_token_ids[ num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
idx_mapping[i], starts[i] : ends[i] if num_computed >= prefill_len:
] # Not prefill.
return
query_start = tl.load(query_start_loc_ptr + batch_idx)
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
query_len = query_end - query_start
prefill_ptr = prefill_token_ids_ptr + req_state_idx * prefill_token_ids_stride
for i in range(0, query_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < query_len
tokens = tl.load(prefill_ptr + num_computed + block, mask=mask)
tl.store(input_ids_ptr + query_start + block, tokens, mask=mask)
next_pos = num_computed + query_len
if next_pos < prefill_len:
next_token = tl.load(prefill_ptr + next_pos)
tl.store(next_prefill_tokens_ptr + req_state_idx, next_token)
def prepare_prefill_inputs( def prepare_prefill_inputs(
idx_mapping: np.ndarray, input_ids: torch.Tensor,
num_scheduled_tokens: np.ndarray, next_prefill_tokens: torch.Tensor,
query_start_loc: np.ndarray, idx_mapping: torch.Tensor,
prefill_token_ids: np.ndarray, query_start_loc: torch.Tensor,
num_computed_prefill_tokens: np.ndarray, prefill_token_ids: torch.Tensor,
input_ids: np.ndarray, prefill_len: torch.Tensor,
num_computed_tokens: torch.Tensor,
) -> None: ) -> None:
_prepare_prefill_inputs( num_reqs = idx_mapping.shape[0]
_prepare_prefill_inputs_kernel[(num_reqs,)](
input_ids,
next_prefill_tokens,
idx_mapping, idx_mapping,
num_scheduled_tokens,
query_start_loc, query_start_loc,
prefill_token_ids, prefill_token_ids,
num_computed_prefill_tokens, prefill_token_ids.stride(0),
input_ids, prefill_len,
num_computed_tokens,
BLOCK_SIZE=1024,
) )
......
...@@ -104,11 +104,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -104,11 +104,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if self.use_async_scheduling: if self.use_async_scheduling:
self.input_prep_event = torch.cuda.Event() self.input_prep_event = torch.cuda.Event()
self.structured_outputs_event = torch.cuda.Event() self.structured_outputs_event = torch.cuda.Event()
self.spec_decode_event = torch.cuda.Event()
else: else:
self.input_prep_event = None self.input_prep_event = None
self.structured_outputs_event = None self.structured_outputs_event = None
self.spec_decode_event = None
if self.speculative_config is not None: if self.speculative_config is not None:
self.do_spec_decode = True self.do_spec_decode = True
...@@ -412,9 +410,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -412,9 +410,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cu_num_new_blocks[i].append(x + len(block_ids)) cu_num_new_blocks[i].append(x + len(block_ids))
new_block_ids[i].extend(block_ids) new_block_ids[i].extend(block_ids)
overwrite.append(True) overwrite.append(True)
# Update the GPU tensors for request states.
if scheduler_output.scheduled_new_reqs:
self.req_states.prefill_len.copy_to_gpu()
# Add new blocks for the existing requests. # Add new blocks for the existing requests.
cached_reqs = scheduler_output.scheduled_cached_reqs cached_reqs = scheduler_output.scheduled_cached_reqs
...@@ -507,16 +502,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -507,16 +502,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
query_start_loc_cpu = self.input_buffers.query_start_loc.cpu[: num_reqs + 1] query_start_loc_cpu = self.input_buffers.query_start_loc.cpu[: num_reqs + 1]
query_start_loc_np = self.input_buffers.query_start_loc.np[: num_reqs + 1] query_start_loc_np = self.input_buffers.query_start_loc.np[: num_reqs + 1]
# Copy prefill tokens from CPU to GPU. # Get prefill tokens.
prepare_prefill_inputs( prepare_prefill_inputs(
idx_mapping_np, self.input_buffers.input_ids,
num_scheduled_tokens, self.req_states.next_prefill_tokens,
query_start_loc_np, idx_mapping,
self.req_states.prefill_token_ids.np, query_start_loc_gpu,
self.req_states.num_computed_prefill_tokens, self.req_states.prefill_token_ids.gpu,
self.input_buffers.input_ids.np, self.req_states.prefill_len.gpu,
self.req_states.num_computed_tokens,
) )
self.input_buffers.input_ids.copy_to_gpu(num_tokens)
# Prepare positions and seq_lens. # Prepare positions and seq_lens.
prepare_pos_seq_lens( prepare_pos_seq_lens(
...@@ -531,7 +526,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -531,7 +526,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Some input token ids are directly read from the last sampled tokens # Some input token ids are directly read from the last sampled tokens
# and draft tokens. Also, get the logits indices to sample tokens from. # and draft tokens. Also, get the logits indices to sample tokens from.
logits_indices = combine_sampled_and_draft_tokens( logits_indices = combine_sampled_and_draft_tokens(
self.input_buffers.input_ids.gpu, self.input_buffers.input_ids,
idx_mapping, idx_mapping,
self.req_states.last_sampled_tokens, self.req_states.last_sampled_tokens,
query_start_loc_gpu, query_start_loc_gpu,
...@@ -572,7 +567,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -572,7 +567,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_cache_config=self.kv_cache_config, kv_cache_config=self.kv_cache_config,
) )
input_ids = self.input_buffers.input_ids.gpu[:num_tokens_after_padding] input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
positions = self.input_buffers.positions[:num_tokens_after_padding] positions = self.input_buffers.positions[:num_tokens_after_padding]
return InputBatch( return InputBatch(
req_ids=req_ids, req_ids=req_ids,
...@@ -782,20 +777,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -782,20 +777,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_sampled: torch.Tensor, num_sampled: torch.Tensor,
num_rejected: torch.Tensor, num_rejected: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
num_reqs = input_batch.num_reqs
idx_mapping_np = input_batch.idx_mapping_np
with async_barrier(self.spec_decode_event):
self.input_buffers.next_prefill_tokens.np[:num_reqs] = (
self.req_states.prefill_token_ids.np[
idx_mapping_np,
self.req_states.num_computed_prefill_tokens[idx_mapping_np],
]
)
next_prefill_tokens = self.input_buffers.next_prefill_tokens.copy_to_gpu(
num_reqs
)
assert self.speculator is not None assert self.speculator is not None
last_sampled_tokens = self.req_states.last_sampled_tokens[
input_batch.idx_mapping
]
next_prefill_tokens = self.req_states.next_prefill_tokens[
input_batch.idx_mapping
]
draft_tokens = self.speculator.propose( draft_tokens = self.speculator.propose(
input_batch, input_batch,
sampling_metadata, sampling_metadata,
...@@ -803,7 +791,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -803,7 +791,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
aux_hidden_states, aux_hidden_states,
num_sampled, num_sampled,
num_rejected, num_rejected,
self.req_states.last_sampled_tokens, last_sampled_tokens,
next_prefill_tokens, next_prefill_tokens,
) )
return draft_tokens return draft_tokens
......
...@@ -121,7 +121,7 @@ class EagleSpeculator: ...@@ -121,7 +121,7 @@ class EagleSpeculator:
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
): ):
ret_hidden_states = self.model( ret_hidden_states = self.model(
input_ids=self.input_buffers.input_ids.gpu[:num_tokens], input_ids=self.input_buffers.input_ids[:num_tokens],
positions=self.input_buffers.positions[:num_tokens], positions=self.input_buffers.positions[:num_tokens],
hidden_states=self.hidden_states[:num_tokens], hidden_states=self.hidden_states[:num_tokens],
) )
...@@ -194,7 +194,7 @@ class EagleSpeculator: ...@@ -194,7 +194,7 @@ class EagleSpeculator:
num_sampled: torch.Tensor, num_sampled: torch.Tensor,
# [num_reqs] # [num_reqs]
num_rejected: torch.Tensor, num_rejected: torch.Tensor,
# [max_num_reqs, 1] # [num_reqs]
last_sampled: torch.Tensor, last_sampled: torch.Tensor,
# [num_reqs] # [num_reqs]
next_prefill_tokens: torch.Tensor, next_prefill_tokens: torch.Tensor,
...@@ -316,7 +316,6 @@ def _prepare_eagle_inputs_kernel( ...@@ -316,7 +316,6 @@ def _prepare_eagle_inputs_kernel(
eagle_positions_ptr, eagle_positions_ptr,
target_input_ids_ptr, target_input_ids_ptr,
target_positions_ptr, target_positions_ptr,
idx_mapping_ptr,
last_sampled_ptr, last_sampled_ptr,
next_prefill_tokens_ptr, next_prefill_tokens_ptr,
num_sampled_ptr, num_sampled_ptr,
...@@ -335,8 +334,7 @@ def _prepare_eagle_inputs_kernel( ...@@ -335,8 +334,7 @@ def _prepare_eagle_inputs_kernel(
num_sampled = tl.load(num_sampled_ptr + batch_idx) num_sampled = tl.load(num_sampled_ptr + batch_idx)
if num_sampled > 0: if num_sampled > 0:
req_state_idx = tl.load(idx_mapping_ptr + batch_idx) next_token = tl.load(last_sampled_ptr + batch_idx).to(tl.int32)
next_token = tl.load(last_sampled_ptr + req_state_idx).to(tl.int32)
else: else:
# Chunked prefilling. # Chunked prefilling.
# Get the next prefill token. # Get the next prefill token.
...@@ -368,9 +366,9 @@ def prepare_eagle_inputs( ...@@ -368,9 +366,9 @@ def prepare_eagle_inputs(
num_sampled: torch.Tensor, num_sampled: torch.Tensor,
# [num_reqs] # [num_reqs]
num_rejected: torch.Tensor, num_rejected: torch.Tensor,
# [max_num_reqs, 1] # [num_reqs]
last_sampled: torch.Tensor, last_sampled: torch.Tensor,
# [max_num_reqs] # [num_reqs]
next_prefill_tokens: torch.Tensor, next_prefill_tokens: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
num_reqs = input_batch.num_reqs num_reqs = input_batch.num_reqs
...@@ -381,11 +379,10 @@ def prepare_eagle_inputs( ...@@ -381,11 +379,10 @@ def prepare_eagle_inputs(
) )
_prepare_eagle_inputs_kernel[(num_reqs,)]( _prepare_eagle_inputs_kernel[(num_reqs,)](
last_token_indices, last_token_indices,
input_buffers.input_ids.gpu, input_buffers.input_ids,
input_buffers.positions, input_buffers.positions,
input_batch.input_ids, input_batch.input_ids,
input_batch.positions, input_batch.positions,
input_batch.idx_mapping,
last_sampled, last_sampled,
next_prefill_tokens, next_prefill_tokens,
num_sampled, num_sampled,
...@@ -485,7 +482,7 @@ def prepare_eagle_decode( ...@@ -485,7 +482,7 @@ def prepare_eagle_decode(
last_token_indices, last_token_indices,
target_seq_lens, target_seq_lens,
num_rejected, num_rejected,
input_buffers.input_ids.gpu, input_buffers.input_ids,
input_buffers.positions, input_buffers.positions,
input_hidden_states, input_hidden_states,
input_hidden_states.stride(0), input_hidden_states.stride(0),
...@@ -553,7 +550,7 @@ def update_eagle_inputs( ...@@ -553,7 +550,7 @@ def update_eagle_inputs(
): ):
num_reqs, hidden_size = output_hidden_states.shape num_reqs, hidden_size = output_hidden_states.shape
_update_eagle_inputs_kernel[(num_reqs,)]( _update_eagle_inputs_kernel[(num_reqs,)](
input_buffers.input_ids.gpu, input_buffers.input_ids,
input_buffers.positions, input_buffers.positions,
hidden_states, hidden_states,
hidden_states.stride(0), hidden_states.stride(0),
......
...@@ -117,8 +117,7 @@ class RequestState: ...@@ -117,8 +117,7 @@ class RequestState:
self.prefill_token_ids = UvaBuffer( self.prefill_token_ids = UvaBuffer(
self.max_num_reqs, self.max_model_len, dtype=torch.int32 self.max_num_reqs, self.max_model_len, dtype=torch.int32
) )
self.prefill_len = self._make_buffer(self.max_num_reqs, dtype=torch.int32) self.prefill_len = UvaBuffer(self.max_num_reqs, dtype=torch.int32)
# Number of computed tokens. # Number of computed tokens.
self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32) self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
self.num_computed_tokens = torch.zeros( self.num_computed_tokens = torch.zeros(
...@@ -140,6 +139,9 @@ class RequestState: ...@@ -140,6 +139,9 @@ class RequestState:
dtype=torch.int64, dtype=torch.int64,
device=device, device=device,
) )
self.next_prefill_tokens = torch.zeros(
self.max_num_reqs, dtype=torch.int32, device=device
)
# LoRA. # LoRA.
self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32) self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32)
...@@ -380,13 +382,13 @@ def _expand_sampling_metadata_kernel( ...@@ -380,13 +382,13 @@ def _expand_sampling_metadata_kernel(
expanded_top_p_ptr, expanded_top_p_ptr,
top_k_ptr, top_k_ptr,
expanded_top_k_ptr, expanded_top_k_ptr,
seeds_ptr,
rep_penalty_ptr, rep_penalty_ptr,
expanded_rep_penalty_ptr, expanded_rep_penalty_ptr,
freq_penalty_ptr, freq_penalty_ptr,
expanded_freq_penalty_ptr, expanded_freq_penalty_ptr,
pres_penalty_ptr, pres_penalty_ptr,
expanded_pres_penalty_ptr, expanded_pres_penalty_ptr,
seeds_ptr,
expanded_seeds_ptr, expanded_seeds_ptr,
cu_num_logits_ptr, cu_num_logits_ptr,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment