Unverified Commit ccf90ba7 authored by Giancarlo Delfin's avatar Giancarlo Delfin Committed by GitHub
Browse files

[Model Runner V2] Add full cuda graph support for eagle prefill (#37588)


Signed-off-by: default avatarGiancarlo Delfin <gdelfin@inferact.ai>
parent 6adacfcb
...@@ -464,7 +464,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -464,7 +464,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
slot_mappings_by_layer = self.execute_model_state.slot_mappings_by_layer slot_mappings_by_layer = self.execute_model_state.slot_mappings_by_layer
hidden_states = self.execute_model_state.hidden_states hidden_states = self.execute_model_state.hidden_states
aux_hidden_states = self.execute_model_state.aux_hidden_states aux_hidden_states = self.execute_model_state.aux_hidden_states
num_tokens_across_dp = self.execute_model_state.num_tokens_across_dp
self.execute_model_state = None self.execute_model_state = None
# dummy run the eagle speculator's propose to ensure DP/EP sync. # dummy run the eagle speculator's propose to ensure DP/EP sync.
...@@ -496,7 +495,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -496,7 +495,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
next_prefill_tokens=self.req_states.next_prefill_tokens, next_prefill_tokens=self.req_states.next_prefill_tokens,
temperature=self.sampler.sampling_states.temperature.gpu, temperature=self.sampler.sampling_states.temperature.gpu,
seeds=self.sampler.sampling_states.seeds.gpu, seeds=self.sampler.sampling_states.seeds.gpu,
num_tokens_across_dp=num_tokens_across_dp,
dummy_run=True, dummy_run=True,
skip_attn_for_dummy_run=skip_attn, skip_attn_for_dummy_run=skip_attn,
mm_inputs=mm_inputs, mm_inputs=mm_inputs,
...@@ -1110,7 +1108,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1110,7 +1108,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states=hidden_states, hidden_states=hidden_states,
aux_hidden_states=aux_hidden_states, aux_hidden_states=aux_hidden_states,
kv_connector_output=kv_connector_output, kv_connector_output=kv_connector_output,
num_tokens_across_dp=num_tokens_across_dp,
) )
if not self.is_last_pp_rank: if not self.is_last_pp_rank:
...@@ -1135,7 +1132,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1135,7 +1132,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states = self.execute_model_state.hidden_states hidden_states = self.execute_model_state.hidden_states
aux_hidden_states = self.execute_model_state.aux_hidden_states aux_hidden_states = self.execute_model_state.aux_hidden_states
kv_connector_output = self.execute_model_state.kv_connector_output kv_connector_output = self.execute_model_state.kv_connector_output
num_tokens_across_dp = self.execute_model_state.num_tokens_across_dp
self.execute_model_state = None self.execute_model_state = None
if not self.is_last_pp_rank: if not self.is_last_pp_rank:
...@@ -1228,7 +1224,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1228,7 +1224,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.req_states.next_prefill_tokens, self.req_states.next_prefill_tokens,
self.sampler.sampling_states.temperature.gpu, self.sampler.sampling_states.temperature.gpu,
self.sampler.sampling_states.seeds.gpu, self.sampler.sampling_states.seeds.gpu,
num_tokens_across_dp=num_tokens_across_dp,
mm_inputs=mm_inputs, mm_inputs=mm_inputs,
) )
self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
...@@ -1336,4 +1331,3 @@ class ExecuteModelState(NamedTuple): ...@@ -1336,4 +1331,3 @@ class ExecuteModelState(NamedTuple):
hidden_states: torch.Tensor | None hidden_states: torch.Tensor | None
aux_hidden_states: list[torch.Tensor] | None aux_hidden_states: list[torch.Tensor] | None
kv_connector_output: KVConnectorOutput | None kv_connector_output: KVConnectorOutput | None
num_tokens_across_dp: torch.Tensor | None
...@@ -19,21 +19,16 @@ from vllm.v1.worker.utils import AttentionGroup ...@@ -19,21 +19,16 @@ from vllm.v1.worker.utils import AttentionGroup
class EagleCudaGraphManager(CudaGraphManager): class EagleCudaGraphManager(CudaGraphManager):
"""CudaGraphManager for Eagle speculative decoding (FULL mode only).""" """CudaGraphManager for Eagle speculative decoding."""
def __init__( def __init__(
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
device: torch.device, device: torch.device,
cudagraph_mode: CUDAGraphMode, cudagraph_mode: CUDAGraphMode,
draft_tokens: torch.Tensor, decode_query_len: int,
): ):
assert not cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE), ( super().__init__(vllm_config, device, cudagraph_mode, decode_query_len)
"EagleCudaGraphManager does not support PIECEWISE mode yet"
)
# Eagle always uses uniform decode with query_len=1
super().__init__(vllm_config, device, cudagraph_mode, decode_query_len=1)
self.draft_tokens = draft_tokens
# Use a dedicated pool for Eagle to avoid memory overlap with the main # Use a dedicated pool for Eagle to avoid memory overlap with the main
# model's cudagraph. The base class uses a shared global pool, but Eagle's # model's cudagraph. The base class uses a shared global pool, but Eagle's
...@@ -44,7 +39,7 @@ class EagleCudaGraphManager(CudaGraphManager): ...@@ -44,7 +39,7 @@ class EagleCudaGraphManager(CudaGraphManager):
def capture( def capture(
self, self,
generate_fn: Callable, forward_fn: Callable,
model_state: ModelState, model_state: ModelState,
input_buffers: InputBuffers, input_buffers: InputBuffers,
block_tables: BlockTables, block_tables: BlockTables,
...@@ -52,7 +47,7 @@ class EagleCudaGraphManager(CudaGraphManager): ...@@ -52,7 +47,7 @@ class EagleCudaGraphManager(CudaGraphManager):
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
progress_bar_desc: str = "Capturing CUDA graphs", progress_bar_desc: str = "Capturing CUDA graphs",
) -> None: ) -> None:
"""Capture CUDA graphs for Eagle speculative decoding (FULL mode only).""" """Capture CUDA graphs for Eagle."""
def create_forward_fn( def create_forward_fn(
desc: BatchExecutionDescriptor, desc: BatchExecutionDescriptor,
...@@ -74,7 +69,7 @@ class EagleCudaGraphManager(CudaGraphManager): ...@@ -74,7 +69,7 @@ class EagleCudaGraphManager(CudaGraphManager):
kv_cache_config, kv_cache_config,
) )
return lambda cg_mode: generate_fn( return lambda cg_mode: forward_fn(
num_reqs, num_reqs,
num_tokens, num_tokens,
attn_metadata, attn_metadata,
...@@ -84,8 +79,3 @@ class EagleCudaGraphManager(CudaGraphManager): ...@@ -84,8 +79,3 @@ class EagleCudaGraphManager(CudaGraphManager):
) )
super().capture(create_forward_fn, progress_bar_desc) super().capture(create_forward_fn, progress_bar_desc)
def run_fullgraph(self, desc: BatchExecutionDescriptor) -> torch.Tensor:
"""Replay a captured FULL cudagraph and return draft tokens."""
super().run_fullgraph(desc)
return self.draft_tokens
...@@ -19,11 +19,16 @@ from vllm.v1.worker.gpu.attn_utils import ( ...@@ -19,11 +19,16 @@ from vllm.v1.worker.gpu.attn_utils import (
init_attn_backend, init_attn_backend,
) )
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cudagraph_utils import (
get_uniform_token_count,
)
from vllm.v1.worker.gpu.dp_utils import dispatch_cg_and_sync_dp from vllm.v1.worker.gpu.dp_utils import dispatch_cg_and_sync_dp
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import (
EagleCudaGraphManager,
)
from vllm.v1.worker.gpu.spec_decode.eagle.utils import load_eagle_model from vllm.v1.worker.gpu.spec_decode.eagle.utils import load_eagle_model
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -76,6 +81,9 @@ class EagleSpeculator: ...@@ -76,6 +81,9 @@ class EagleSpeculator:
dtype=torch.int64, dtype=torch.int64,
device=device, device=device,
) )
self.last_token_indices = torch.zeros(
self.max_num_reqs, dtype=torch.int64, device=device
)
self.supports_mm_inputs = MULTIMODAL_REGISTRY.supports_multimodal_inputs( self.supports_mm_inputs = MULTIMODAL_REGISTRY.supports_multimodal_inputs(
self.draft_model_config self.draft_model_config
...@@ -95,20 +103,30 @@ class EagleSpeculator: ...@@ -95,20 +103,30 @@ class EagleSpeculator:
device=device, device=device,
) )
self.cudagraph_manager: EagleCudaGraphManager | None = None self.prefill_cudagraph_manager: EagleCudaGraphManager | None = None
self.decode_cudagraph_manager: EagleCudaGraphManager | None = None
def init_cudagraph_manager(self, cudagraph_mode: CUDAGraphMode) -> None: def init_cudagraph_manager(self, cudagraph_mode: CUDAGraphMode) -> None:
if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL: cudagraph_mode = self.vllm_config.compilation_config.cudagraph_mode
cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY # Initialize cudagraph manager for draft prefill (draft position 0).
else: self.prefill_cudagraph_manager = EagleCudaGraphManager(
cudagraph_mode = CUDAGraphMode.NONE
self.cudagraph_manager = EagleCudaGraphManager(
self.vllm_config, self.vllm_config,
self.device, self.device,
cudagraph_mode, cudagraph_mode,
self.draft_tokens, self.num_speculative_steps + 1,
) )
# Initialize cudagraph manager for draft generation (draft positions > 0).
self.decode_cudagraph_manager = EagleCudaGraphManager(
self.vllm_config,
self.device,
# Only use FULL graph mode, if available, because draft decodes
# only consist of a single token.
cudagraph_mode.decode_mode(),
decode_query_len=1,
)
# Share a single pool between prefill and decode since they never
# execute concurrently.
self.decode_cudagraph_manager.pool = self.prefill_cudagraph_manager.pool
def load_model(self, target_model: nn.Module) -> None: def load_model(self, target_model: nn.Module) -> None:
target_attn_layer_names = get_layers_from_vllm_config( target_attn_layer_names = get_layers_from_vllm_config(
...@@ -189,6 +207,47 @@ class EagleSpeculator: ...@@ -189,6 +207,47 @@ class EagleSpeculator:
last_hidden_states, hidden_states = ret_hidden_states last_hidden_states, hidden_states = ret_hidden_states
return last_hidden_states, hidden_states return last_hidden_states, hidden_states
def prefill(
self,
num_reqs: int,
num_tokens: int,
attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None,
num_tokens_across_dp: torch.Tensor | None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
) -> None:
last_token_indices = self.last_token_indices[:num_reqs]
pos = self.input_buffers.positions[last_token_indices]
idx_mapping = self.idx_mapping[:num_reqs]
last_hidden_states, hidden_states = self.run_model(
num_tokens,
attn_metadata,
slot_mappings,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
mm_inputs=mm_inputs,
)
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states)
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling.
self.draft_tokens[:num_reqs, 0] = gumbel_sample(
logits,
idx_mapping,
self.temperature,
self.seeds,
pos + 1,
apply_temperature=True,
processed_logits_out=self.draft_logits[:, 0]
if self.draft_logits is not None
else None,
)
self.hidden_states[:num_reqs] = hidden_states[last_token_indices]
self.input_buffers.positions[:num_reqs] = pos
def generate_draft( def generate_draft(
self, self,
num_reqs: int, num_reqs: int,
...@@ -281,19 +340,46 @@ class EagleSpeculator: ...@@ -281,19 +340,46 @@ class EagleSpeculator:
return attn_metadata return attn_metadata
def capture_model(self) -> None: def capture_model(self) -> None:
assert self.cudagraph_manager is not None logger.info("Capturing model for Eagle speculator...")
# Reset indices to zeros to prevent stale values from prior
# dummy runs to cause out-of-bounds indexing during capture.
self.last_token_indices.zero_()
# Capture the prefill routine (model forward + compute_logits +
# gumbel_sample).
# For FULL graphs, the entire routine is recorded as one graph.
# For PIECEWISE, only the model's compiled regions are captured
# and the rest (compute_logits, gumbel_sample) runs eagerly.
assert self.prefill_cudagraph_manager is not None
self.prefill_cudagraph_manager.capture(
self.prefill,
self.model_state,
self.input_buffers,
self.block_tables,
self.attn_groups,
self.kv_cache_config,
progress_bar_desc="Capturing eagle prefill CUDA graphs",
)
if self.num_speculative_steps == 1: if self.num_speculative_steps == 1:
return return
logger.info("Capturing model for Eagle speculator...") # Capture the decode draft generation loop (model forward +
self.cudagraph_manager.capture( # compute_logits + gumbel_sample + update_eagle_inputs, for
# each step).
# For FULL graphs, the entire multi-step loop is recorded as
# one graph. For PIECEWISE, only the model's compiled regions
# are captured, and the rest (compute_logits, gumbel_sample,
# update_eagle_inputs) runs eagerly.
assert self.decode_cudagraph_manager is not None
self.decode_cudagraph_manager.capture(
self.generate_draft, self.generate_draft,
self.model_state, self.model_state,
self.input_buffers, self.input_buffers,
self.block_tables, self.block_tables,
self.attn_groups, self.attn_groups,
self.kv_cache_config, self.kv_cache_config,
progress_bar_desc="Capturing eagle CUDA graphs", progress_bar_desc="Capturing eagle decode CUDA graphs",
) )
@torch.inference_mode() @torch.inference_mode()
...@@ -324,6 +410,10 @@ class EagleSpeculator: ...@@ -324,6 +410,10 @@ class EagleSpeculator:
mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
is_profile: bool = False, is_profile: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
num_tokens = input_batch.num_tokens_after_padding
num_reqs = input_batch.num_reqs
max_query_len = input_batch.num_scheduled_tokens.max()
# NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the # NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
# number of rejected tokens, we maintain the size of eagle's input_ids and # number of rejected tokens, we maintain the size of eagle's input_ids and
# hidden_states the same as the target model's. This means, we pad each # hidden_states the same as the target model's. This means, we pad each
...@@ -337,82 +427,88 @@ class EagleSpeculator: ...@@ -337,82 +427,88 @@ class EagleSpeculator:
) )
else: else:
hidden_states = last_hidden_states hidden_states = last_hidden_states
num_tokens = input_batch.num_tokens_after_padding self.hidden_states[:num_tokens].copy_(hidden_states)
self.hidden_states[:num_tokens] = hidden_states
# Copy temperature, seeds, and idx mapping to the pre-allocated buffers.
# NOTE(woosuk): For draft sampling, we only consider the temperature
# and ignore the other sampling parameters such as top_k and top_p,
# for simplicity and performance.
# While this may slightly degrade the acceptance rate, it does not
# affect the output distribution after rejection sampling.
self.temperature.copy_(temperature)
self.seeds.copy_(seeds)
self.idx_mapping[:num_reqs].copy_(input_batch.idx_mapping)
# Get the input ids and last token indices for the speculator. # Get the input ids and last token indices for the speculator.
last_token_indices = prepare_eagle_inputs( prepare_eagle_inputs(
self.input_buffers, self.input_buffers,
input_batch, input_batch,
self.last_token_indices,
num_sampled, num_sampled,
num_rejected, num_rejected,
last_sampled, last_sampled,
next_prefill_tokens, next_prefill_tokens,
self.max_num_reqs,
) )
# Prefill: Run the eagle speculator with eager mode. # When all requests are decoding (no true prefills), each has
# TODO(woosuk): Support CUDA graph for prefill. # num_speculative_steps + 1 tokens, enabling FULL graph replay.
last_hidden_states, hidden_states = self.run_model( # Mixed or prefill-only batches fall back to PIECEWISE.
prefill_batch_desc, num_tokens_across_dp = dispatch_cg_and_sync_dp(
self.prefill_cudagraph_manager,
num_reqs,
num_tokens, num_tokens,
attn_metadata, get_uniform_token_count(num_reqs, num_tokens, max_query_len),
slot_mappings, dp_size=self.dp_size,
num_tokens_across_dp=num_tokens_across_dp, dp_rank=self.dp_rank,
mm_inputs=mm_inputs, need_eager=is_profile,
) )
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states)
num_reqs = input_batch.num_reqs if prefill_batch_desc.cg_mode == CUDAGraphMode.FULL:
# NOTE(woosuk): For draft sampling, we only consider the temperature # It is necessary to rebuild the attention metadata when
# and ignore the other sampling parameters such as top_k and top_p, # replaying the FULL graph so that any attention metadata
# for simplicity and performance. # builder state is updated.
# While this may slightly degrade the acceptance rate, it does not self._build_draft_attn_metadata(
# affect the output distribution after rejection sampling. num_reqs=num_reqs,
idx_mapping = self.idx_mapping[:num_reqs] num_reqs_padded=prefill_batch_desc.num_reqs or num_reqs,
idx_mapping.copy_(input_batch.idx_mapping) num_tokens_padded=prefill_batch_desc.num_tokens,
self.temperature.copy_(temperature) max_query_len=self.num_speculative_steps + 1,
self.seeds.copy_(seeds) )
# Replay the full graph for draft prefill.
# Gather the values and copy them to the pre-allocated buffers. assert self.prefill_cudagraph_manager is not None
pos = self.input_buffers.positions[:num_reqs] self.prefill_cudagraph_manager.run_fullgraph(prefill_batch_desc)
torch.gather(input_batch.positions, 0, last_token_indices, out=pos) else:
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise # The target model's attention metadata and slot mappings
# used for draft and target sampling. # can directly be used for draft prefill, because of the
draft_tokens = gumbel_sample( # identical batch shape and KV cache layout.
logits, self.prefill(
idx_mapping, num_reqs,
self.temperature, prefill_batch_desc.num_tokens,
self.seeds, attn_metadata,
pos + 1, slot_mappings,
apply_temperature=True, num_tokens_across_dp=num_tokens_across_dp,
processed_logits_out=self.draft_logits[:, 0] cudagraph_runtime_mode=prefill_batch_desc.cg_mode,
if self.draft_logits is not None mm_inputs=mm_inputs,
else None, )
)
if self.num_speculative_steps == 1: if self.num_speculative_steps == 1:
# Early exit. # Early exit.
return draft_tokens.view(-1, 1) return self.draft_tokens[:num_reqs, :1]
# Save the draft tokens for the first step.
self.draft_tokens[:num_reqs, 0] = draft_tokens
# Prepare the inputs for the decode steps. # Prepare the inputs for the decode steps.
prepare_eagle_decode( prepare_eagle_decode(
draft_tokens, self.draft_tokens[:num_reqs, 0],
hidden_states,
last_token_indices,
input_batch.seq_lens, input_batch.seq_lens,
num_rejected, num_rejected,
self.input_buffers, self.input_buffers,
self.hidden_states,
self.max_model_len, self.max_model_len,
self.max_num_reqs, self.max_num_reqs,
) )
# Each request produces exactly 1 token per draft decode step, # Each request produces exactly 1 token per draft generation step,
# enabling FULL cudagraph. # enabling FULL graph replay.
decode_batch_desc, num_tokens_across_dp = dispatch_cg_and_sync_dp( decode_batch_desc, num_tokens_across_dp = dispatch_cg_and_sync_dp(
self.cudagraph_manager, self.decode_cudagraph_manager,
num_reqs, num_reqs,
num_reqs, num_reqs,
uniform_token_count=1, uniform_token_count=1,
...@@ -426,12 +522,12 @@ class EagleSpeculator: ...@@ -426,12 +522,12 @@ class EagleSpeculator:
if not (dummy_run and skip_attn_for_dummy_run): if not (dummy_run and skip_attn_for_dummy_run):
# Build attention metadata and slot mappings for the draft # Build attention metadata and slot mappings for the draft
# decode steps. It is necessary to rebuild the attention # decode steps. It is necessary to rebuild the attention
# metadata even when replaying the FULL cudagraph so that # metadata even when replaying the FULL graph so that any
# any attention metadata builder state is updated. # attention metadata builder state is updated.
slot_mappings = self.block_tables.compute_slot_mappings( slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping, self.idx_mapping[:num_reqs],
self.input_buffers.query_start_loc[: num_reqs + 1], self.input_buffers.query_start_loc[: num_reqs + 1],
pos, self.input_buffers.positions[:num_reqs],
decode_batch_desc.num_tokens, decode_batch_desc.num_tokens,
) )
slot_mappings_updated = build_slot_mappings_by_layer( slot_mappings_updated = build_slot_mappings_by_layer(
...@@ -445,8 +541,9 @@ class EagleSpeculator: ...@@ -445,8 +541,9 @@ class EagleSpeculator:
) )
if decode_batch_desc.cg_mode == CUDAGraphMode.FULL: if decode_batch_desc.cg_mode == CUDAGraphMode.FULL:
assert self.cudagraph_manager is not None # Replay the full graph for draft generation.
self.cudagraph_manager.run_fullgraph(decode_batch_desc) assert self.decode_cudagraph_manager is not None
self.decode_cudagraph_manager.run_fullgraph(decode_batch_desc)
else: else:
self.generate_draft( self.generate_draft(
num_reqs, num_reqs,
...@@ -464,6 +561,8 @@ def _prepare_eagle_inputs_kernel( ...@@ -464,6 +561,8 @@ def _prepare_eagle_inputs_kernel(
last_token_indices_ptr, last_token_indices_ptr,
eagle_input_ids_ptr, eagle_input_ids_ptr,
eagle_positions_ptr, eagle_positions_ptr,
eagle_query_start_loc_ptr,
eagle_seq_lens_ptr,
target_input_ids_ptr, target_input_ids_ptr,
target_positions_ptr, target_positions_ptr,
idx_mapping_ptr, idx_mapping_ptr,
...@@ -472,20 +571,24 @@ def _prepare_eagle_inputs_kernel( ...@@ -472,20 +571,24 @@ def _prepare_eagle_inputs_kernel(
num_sampled_ptr, num_sampled_ptr,
num_rejected_ptr, num_rejected_ptr,
query_start_loc_ptr, query_start_loc_ptr,
seq_lens_ptr,
max_num_reqs,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
batch_idx = tl.program_id(0) req_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx) num_reqs = tl.num_programs(0)
req_state_idx = tl.load(idx_mapping_ptr + req_idx)
query_start = tl.load(query_start_loc_ptr + batch_idx) query_start = tl.load(query_start_loc_ptr + req_idx)
query_end = tl.load(query_start_loc_ptr + batch_idx + 1) query_end = tl.load(query_start_loc_ptr + req_idx + 1)
query_len = query_end - query_start query_len = query_end - query_start
seq_len = tl.load(seq_lens_ptr + req_idx)
# Get the true query length and next token after accounting for rejected tokens. # Get the true query length and next token after accounting for rejected tokens.
num_rejected = tl.load(num_rejected_ptr + batch_idx) num_rejected = tl.load(num_rejected_ptr + req_idx)
query_len -= num_rejected query_len -= num_rejected
num_sampled = tl.load(num_sampled_ptr + batch_idx) num_sampled = tl.load(num_sampled_ptr + req_idx)
if num_sampled > 0: if num_sampled > 0:
next_token = tl.load(last_sampled_ptr + req_state_idx).to(tl.int32) next_token = tl.load(last_sampled_ptr + req_state_idx).to(tl.int32)
else: else:
...@@ -501,7 +604,7 @@ def _prepare_eagle_inputs_kernel( ...@@ -501,7 +604,7 @@ def _prepare_eagle_inputs_kernel(
tl.store(eagle_input_ids_ptr + query_start + block - 1, input_ids, mask=mask) tl.store(eagle_input_ids_ptr + query_start + block - 1, input_ids, mask=mask)
last_token_index = query_start + query_len - 1 last_token_index = query_start + query_len - 1
tl.store(last_token_indices_ptr + batch_idx, last_token_index) tl.store(last_token_indices_ptr + req_idx, last_token_index)
tl.store(eagle_input_ids_ptr + last_token_index, next_token) tl.store(eagle_input_ids_ptr + last_token_index, next_token)
# Copy positions. # Copy positions.
...@@ -511,11 +614,29 @@ def _prepare_eagle_inputs_kernel( ...@@ -511,11 +614,29 @@ def _prepare_eagle_inputs_kernel(
target_pos = tl.load(target_positions_ptr + query_start + block, mask=mask) target_pos = tl.load(target_positions_ptr + query_start + block, mask=mask)
tl.store(eagle_positions_ptr + query_start + block, target_pos, mask=mask) tl.store(eagle_positions_ptr + query_start + block, target_pos, mask=mask)
# Copy query start locations.
tl.store(eagle_query_start_loc_ptr + req_idx, query_start)
# Copy sequence lengths.
tl.store(eagle_seq_lens_ptr + req_idx, seq_len)
if req_idx == (num_reqs - 1):
# Pad query_start_loc for CUDA graphs.
for i in range(num_reqs, max_num_reqs + 1, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < max_num_reqs + 1
tl.store(eagle_query_start_loc_ptr + block, query_end, mask=mask)
# Pad seq_lens for CUDA graphs.
for i in range(num_reqs, max_num_reqs, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < max_num_reqs
tl.store(eagle_seq_lens_ptr + block, 0, mask=mask)
def prepare_eagle_inputs( def prepare_eagle_inputs(
input_buffers: InputBuffers, input_buffers: InputBuffers,
input_batch: InputBatch, input_batch: InputBatch,
# [num_reqs] # [num_reqs]
last_token_indices: torch.Tensor,
# [num_reqs]
num_sampled: torch.Tensor, num_sampled: torch.Tensor,
# [num_reqs] # [num_reqs]
num_rejected: torch.Tensor, num_rejected: torch.Tensor,
...@@ -523,17 +644,15 @@ def prepare_eagle_inputs( ...@@ -523,17 +644,15 @@ def prepare_eagle_inputs(
last_sampled: torch.Tensor, last_sampled: torch.Tensor,
# [max_num_reqs] # [max_num_reqs]
next_prefill_tokens: torch.Tensor, next_prefill_tokens: torch.Tensor,
max_num_reqs,
) -> torch.Tensor: ) -> torch.Tensor:
num_reqs = input_batch.num_reqs num_reqs = input_batch.num_reqs
last_token_indices = torch.empty(
num_reqs,
dtype=torch.int64,
device=num_sampled.device,
)
_prepare_eagle_inputs_kernel[(num_reqs,)]( _prepare_eagle_inputs_kernel[(num_reqs,)](
last_token_indices, last_token_indices,
input_buffers.input_ids, input_buffers.input_ids,
input_buffers.positions, input_buffers.positions,
input_buffers.query_start_loc,
input_buffers.seq_lens,
input_batch.input_ids, input_batch.input_ids,
input_batch.positions, input_batch.positions,
input_batch.idx_mapping, input_batch.idx_mapping,
...@@ -542,6 +661,8 @@ def prepare_eagle_inputs( ...@@ -542,6 +661,8 @@ def prepare_eagle_inputs(
num_sampled, num_sampled,
num_rejected, num_rejected,
input_batch.query_start_loc, input_batch.query_start_loc,
input_batch.seq_lens,
max_num_reqs,
BLOCK_SIZE=1024, BLOCK_SIZE=1024,
) )
return last_token_indices return last_token_indices
...@@ -550,18 +671,13 @@ def prepare_eagle_inputs( ...@@ -550,18 +671,13 @@ def prepare_eagle_inputs(
@triton.jit @triton.jit
def _prepare_eagle_docode_kernel( def _prepare_eagle_docode_kernel(
draft_tokens_ptr, draft_tokens_ptr,
output_hidden_states_ptr, draft_tokens_stride,
output_hidden_states_stride,
last_token_indices_ptr,
target_seq_lens_ptr, target_seq_lens_ptr,
num_rejected_ptr, num_rejected_ptr,
input_ids_ptr, input_ids_ptr,
positions_ptr, positions_ptr,
input_hidden_states_ptr,
input_hidden_states_stride,
query_start_loc_ptr, query_start_loc_ptr,
seq_lens_ptr, seq_lens_ptr,
hidden_size,
max_model_len, max_model_len,
max_num_reqs, max_num_reqs,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
...@@ -584,24 +700,9 @@ def _prepare_eagle_docode_kernel( ...@@ -584,24 +700,9 @@ def _prepare_eagle_docode_kernel(
return return
# draft token -> input id. # draft token -> input id.
draft_token = tl.load(draft_tokens_ptr + req_idx) draft_token = tl.load(draft_tokens_ptr + req_idx * draft_tokens_stride)
tl.store(input_ids_ptr + req_idx, draft_token) tl.store(input_ids_ptr + req_idx, draft_token)
# output hidden states -> input hidden states.
src_idx = tl.load(last_token_indices_ptr + req_idx)
for i in range(0, hidden_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < hidden_size
output_hidden_states = tl.load(
output_hidden_states_ptr + src_idx * output_hidden_states_stride + block,
mask=mask,
)
tl.store(
input_hidden_states_ptr + req_idx * input_hidden_states_stride + block,
output_hidden_states,
mask=mask,
)
# Compute position and seq_lens. # Compute position and seq_lens.
# NOTE(woosuk): To prevent out-of-range access, we clamp these values # NOTE(woosuk): To prevent out-of-range access, we clamp these values
# if they reach the max model length. # if they reach the max model length.
...@@ -618,31 +719,22 @@ def _prepare_eagle_docode_kernel( ...@@ -618,31 +719,22 @@ def _prepare_eagle_docode_kernel(
def prepare_eagle_decode( def prepare_eagle_decode(
draft_tokens: torch.Tensor, draft_tokens: torch.Tensor,
output_hidden_states: torch.Tensor,
last_token_indices: torch.Tensor,
target_seq_lens: torch.Tensor, target_seq_lens: torch.Tensor,
num_rejected: torch.Tensor, num_rejected: torch.Tensor,
input_buffers: InputBuffers, input_buffers: InputBuffers,
input_hidden_states: torch.Tensor,
max_model_len: int, max_model_len: int,
max_num_reqs: int, max_num_reqs: int,
): ):
num_reqs = draft_tokens.shape[0] num_reqs = draft_tokens.shape[0]
hidden_size = output_hidden_states.shape[-1]
_prepare_eagle_docode_kernel[(num_reqs + 1,)]( _prepare_eagle_docode_kernel[(num_reqs + 1,)](
draft_tokens, draft_tokens,
output_hidden_states, draft_tokens.stride(0),
output_hidden_states.stride(0),
last_token_indices,
target_seq_lens, target_seq_lens,
num_rejected, num_rejected,
input_buffers.input_ids, input_buffers.input_ids,
input_buffers.positions, input_buffers.positions,
input_hidden_states,
input_hidden_states.stride(0),
input_buffers.query_start_loc, input_buffers.query_start_loc,
input_buffers.seq_lens, input_buffers.seq_lens,
hidden_size,
max_model_len, max_model_len,
max_num_reqs, max_num_reqs,
BLOCK_SIZE=1024, BLOCK_SIZE=1024,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment