Unverified Commit 3bfe55a0 authored by Giancarlo Delfin's avatar Giancarlo Delfin Committed by GitHub
Browse files

[Model Runner V2] Disable piecewise cudagraph mode fallback for eagle draft decodes (#39773)


Signed-off-by: default avatarGiancarlo Delfin <gdelfin@inferact.ai>
parent b569620f
...@@ -115,13 +115,21 @@ class EagleSpeculator: ...@@ -115,13 +115,21 @@ class EagleSpeculator:
cudagraph_mode, cudagraph_mode,
self.num_speculative_steps + 1, self.num_speculative_steps + 1,
) )
# Initialize cudagraph manager for draft generation (draft positions > 0).
# PIECEWISE cudagraphs are not supported for eagle draft decodes.
# PIECEWISE pads num_tokens to the next capture size without padding
# num_reqs, which can cause attention backends to read past the
# valid per-request metadata (e.g. FlashInfer's kv_indptr buffer).
if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL:
cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
else:
cudagraph_mode = CUDAGraphMode.NONE
# Initialize cudagraph manager for draft decodes (draft positions > 0).
self.decode_cudagraph_manager = EagleCudaGraphManager( self.decode_cudagraph_manager = EagleCudaGraphManager(
self.vllm_config, self.vllm_config,
self.device, self.device,
# Only use FULL graph mode, if available, because draft decodes cudagraph_mode,
# only consist of a single token.
cudagraph_mode.decode_mode(),
decode_query_len=1, decode_query_len=1,
) )
# Share a single pool between prefill and decode since they never # Share a single pool between prefill and decode since they never
...@@ -366,11 +374,8 @@ class EagleSpeculator: ...@@ -366,11 +374,8 @@ class EagleSpeculator:
# Capture the decode draft generation loop (model forward + # Capture the decode draft generation loop (model forward +
# compute_logits + gumbel_sample + update_eagle_inputs, for # compute_logits + gumbel_sample + update_eagle_inputs, for
# each step). # each step). For FULL graphs, the entire multi-step loop is
# For FULL graphs, the entire multi-step loop is recorded as # recorded as one graph.
# one graph. For PIECEWISE, only the model's compiled regions
# are captured, and the rest (compute_logits, gumbel_sample,
# update_eagle_inputs) runs eagerly.
assert self.decode_cudagraph_manager is not None assert self.decode_cudagraph_manager is not None
self.decode_cudagraph_manager.capture( self.decode_cudagraph_manager.capture(
self.generate_draft, self.generate_draft,
...@@ -629,6 +634,11 @@ def _prepare_eagle_inputs_kernel( ...@@ -629,6 +634,11 @@ def _prepare_eagle_inputs_kernel(
block = i + tl.arange(0, BLOCK_SIZE) block = i + tl.arange(0, BLOCK_SIZE)
mask = block < max_num_reqs mask = block < max_num_reqs
tl.store(eagle_seq_lens_ptr + block, 0, mask=mask) tl.store(eagle_seq_lens_ptr + block, 0, mask=mask)
# Pad last_token_indices for CUDA graphs.
for i in range(num_reqs, max_num_reqs, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < max_num_reqs
tl.store(last_token_indices_ptr + block, 0, mask=mask)
def prepare_eagle_inputs( def prepare_eagle_inputs(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment