"vscode:/vscode.git/clone" did not exist on "9c8b2c2a8a0bbfa0c43ed7b220642dc3e802633c"
Unverified Commit ccf90ba7 authored by Giancarlo Delfin's avatar Giancarlo Delfin Committed by GitHub
Browse files

[Model Runner V2] Add full cuda graph support for eagle prefill (#37588)


Signed-off-by: default avatarGiancarlo Delfin <gdelfin@inferact.ai>
parent 6adacfcb
......@@ -464,7 +464,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
slot_mappings_by_layer = self.execute_model_state.slot_mappings_by_layer
hidden_states = self.execute_model_state.hidden_states
aux_hidden_states = self.execute_model_state.aux_hidden_states
num_tokens_across_dp = self.execute_model_state.num_tokens_across_dp
self.execute_model_state = None
# dummy run the eagle speculator's propose to ensure DP/EP sync.
......@@ -496,7 +495,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
next_prefill_tokens=self.req_states.next_prefill_tokens,
temperature=self.sampler.sampling_states.temperature.gpu,
seeds=self.sampler.sampling_states.seeds.gpu,
num_tokens_across_dp=num_tokens_across_dp,
dummy_run=True,
skip_attn_for_dummy_run=skip_attn,
mm_inputs=mm_inputs,
......@@ -1110,7 +1108,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states=hidden_states,
aux_hidden_states=aux_hidden_states,
kv_connector_output=kv_connector_output,
num_tokens_across_dp=num_tokens_across_dp,
)
if not self.is_last_pp_rank:
......@@ -1135,7 +1132,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states = self.execute_model_state.hidden_states
aux_hidden_states = self.execute_model_state.aux_hidden_states
kv_connector_output = self.execute_model_state.kv_connector_output
num_tokens_across_dp = self.execute_model_state.num_tokens_across_dp
self.execute_model_state = None
if not self.is_last_pp_rank:
......@@ -1228,7 +1224,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.req_states.next_prefill_tokens,
self.sampler.sampling_states.temperature.gpu,
self.sampler.sampling_states.seeds.gpu,
num_tokens_across_dp=num_tokens_across_dp,
mm_inputs=mm_inputs,
)
self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
......@@ -1336,4 +1331,3 @@ class ExecuteModelState(NamedTuple):
hidden_states: torch.Tensor | None
aux_hidden_states: list[torch.Tensor] | None
kv_connector_output: KVConnectorOutput | None
num_tokens_across_dp: torch.Tensor | None
......@@ -19,21 +19,16 @@ from vllm.v1.worker.utils import AttentionGroup
class EagleCudaGraphManager(CudaGraphManager):
"""CudaGraphManager for Eagle speculative decoding (FULL mode only)."""
"""CudaGraphManager for Eagle speculative decoding."""
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
cudagraph_mode: CUDAGraphMode,
draft_tokens: torch.Tensor,
decode_query_len: int,
):
assert not cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE), (
"EagleCudaGraphManager does not support PIECEWISE mode yet"
)
# Eagle always uses uniform decode with query_len=1
super().__init__(vllm_config, device, cudagraph_mode, decode_query_len=1)
self.draft_tokens = draft_tokens
super().__init__(vllm_config, device, cudagraph_mode, decode_query_len)
# Use a dedicated pool for Eagle to avoid memory overlap with the main
# model's cudagraph. The base class uses a shared global pool, but Eagle's
......@@ -44,7 +39,7 @@ class EagleCudaGraphManager(CudaGraphManager):
def capture(
self,
generate_fn: Callable,
forward_fn: Callable,
model_state: ModelState,
input_buffers: InputBuffers,
block_tables: BlockTables,
......@@ -52,7 +47,7 @@ class EagleCudaGraphManager(CudaGraphManager):
kv_cache_config: KVCacheConfig,
progress_bar_desc: str = "Capturing CUDA graphs",
) -> None:
"""Capture CUDA graphs for Eagle speculative decoding (FULL mode only)."""
"""Capture CUDA graphs for Eagle."""
def create_forward_fn(
desc: BatchExecutionDescriptor,
......@@ -74,7 +69,7 @@ class EagleCudaGraphManager(CudaGraphManager):
kv_cache_config,
)
return lambda cg_mode: generate_fn(
return lambda cg_mode: forward_fn(
num_reqs,
num_tokens,
attn_metadata,
......@@ -84,8 +79,3 @@ class EagleCudaGraphManager(CudaGraphManager):
)
super().capture(create_forward_fn, progress_bar_desc)
def run_fullgraph(self, desc: BatchExecutionDescriptor) -> torch.Tensor:
"""Replay a captured FULL cudagraph and return draft tokens."""
super().run_fullgraph(desc)
return self.draft_tokens
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment