Commit 483463f7 (unverified), authored by Lucas Wilkinson, committed by GitHub
Browse files

[MRV2] Extensible CG dispatch rework (#35959)


Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
parent 4e571ce6
......@@ -97,6 +97,9 @@ class CUDAGraphMode(enum.Enum):
def __str__(self) -> str:
return self.name
def __bool__(self) -> bool:
return self != CUDAGraphMode.NONE
@config
class PassConfig:
......
......@@ -104,19 +104,24 @@ class BlockTables:
self.num_blocks.copy_to_uva()
def gather_block_tables(
self, idx_mapping: torch.Tensor
self,
idx_mapping: torch.Tensor,
num_reqs_padded: int,
) -> tuple[torch.Tensor, ...]:
num_reqs = idx_mapping.shape[0]
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](
# Launch kernel with num_reqs_padded to fuse zeroing of padded rows.
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs_padded)](
idx_mapping,
self.block_table_ptrs,
self.input_block_table_ptrs,
self.block_table_strides,
self.num_blocks.gpu,
self.num_blocks.gpu.stride(0),
num_reqs,
self.input_block_tables[0].shape[1], # max_num_blocks
BLOCK_SIZE=1024, # type: ignore
)
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
return tuple(bt[:num_reqs_padded] for bt in self.input_block_tables)
def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
# NOTE(woosuk): The output may be used for CUDA graph capture.
......@@ -130,6 +135,7 @@ class BlockTables:
idx_mapping: torch.Tensor,
query_start_loc: torch.Tensor,
positions: torch.Tensor,
num_tokens_padded: int,
) -> torch.Tensor:
num_reqs = idx_mapping.shape[0]
num_tokens = positions.shape[0]
......@@ -151,7 +157,7 @@ class BlockTables:
PAD_ID=PAD_SLOT_ID,
TRITON_BLOCK_SIZE=1024, # type: ignore
)
return self.slot_mappings[:, :num_tokens]
return self.slot_mappings[:, :num_tokens_padded]
def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
# Fill the entire slot_mappings tensor, not just the first `num_tokens` entries.
......@@ -173,21 +179,31 @@ def _gather_block_tables_kernel(
block_table_strides, # [num_kv_cache_groups]
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
num_blocks_stride,
num_reqs, # actual number of requests (for padding)
max_num_blocks, # stride for zeroing padded rows
BLOCK_SIZE: tl.constexpr,
):
# kv cache group id
group_id = tl.program_id(0)
batch_idx = tl.program_id(1)
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
stride = tl.load(block_table_strides + group_id)
dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
dst_row_ptr = dst_block_table_ptr + batch_idx * stride
if batch_idx >= num_reqs:
# Zero out padded rows.
for i in tl.range(0, max_num_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
tl.store(dst_row_ptr + offset, 0, mask=offset < max_num_blocks)
return
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
num_blocks = tl.load(group_num_blocks_ptr + req_idx)
stride = tl.load(block_table_strides + group_id)
src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32)
src_row_ptr = src_block_table_ptr + req_idx * stride
dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
dst_row_ptr = dst_block_table_ptr + batch_idx * stride
for i in tl.range(0, num_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import torch
import torch.distributed as dist
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import get_dp_group
from vllm.v1.worker.gpu.cudagraph_utils import (
BatchExecutionDescriptor,
CudaGraphManager,
)
def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | None:
......@@ -12,66 +19,63 @@ def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | N
return torch.full((dp_size,), num_tokens, dtype=torch.int32, device="cpu")
def get_batch_metadata_across_dp(
def sync_cudagraph_and_dp_padding(
cudagraph_manager: CudaGraphManager,
desired_batch_desc: BatchExecutionDescriptor,
num_tokens: int,
cudagraph_size: int,
cudagraph_runtime_mode: int,
num_reqs: int,
uniform_token_count: int | None,
dp_size: int,
dp_rank: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert dp_size > 1
# Use CPU group to avoid CPU-GPU synchronization.
) -> tuple[BatchExecutionDescriptor, torch.Tensor | None]:
"""
Coordinates the batch descriptor and DP padding across all ranks.
Returns (synced_batch_desc, num_tokens_across_dp).
"""
assert dp_size > 1, "DP size must be greater than 1"
group = get_dp_group().cpu_group
tensor = torch.zeros(3, dp_size, dtype=torch.int32, device="cpu")
tensor[0][dp_rank] = num_tokens
tensor[1][dp_rank] = cudagraph_size
tensor[2][dp_rank] = cudagraph_runtime_mode
tensor[1][dp_rank] = desired_batch_desc.cg_mode.value
tensor[2][dp_rank] = uniform_token_count or 0 # (0 means None)
dist.all_reduce(tensor, group=group)
return tensor[0], tensor[1], tensor[2]
num_tokens_across_dp = tensor[0]
cg_mode_across_dp = tensor[1]
uniform_token_counts_across_dp = tensor[2]
def get_cudagraph_and_dp_padding(
num_tokens: int,
cudagraph_size: int | None,
cudagraph_runtime_mode: int,
dp_size: int,
dp_rank: int,
) -> tuple[int, torch.Tensor | None, int]:
if dp_size == 1:
if cudagraph_size is not None:
return cudagraph_size, None, cudagraph_runtime_mode
else:
return num_tokens, None, cudagraph_runtime_mode
if torch.all(num_tokens_across_dp == 0).item():
synced_desc = BatchExecutionDescriptor(
cg_mode=CUDAGraphMode.NONE, num_tokens=0, num_reqs=0
)
return synced_desc, None
# Convert None to -1 for sync (indicates no cudagraph available)
if num_tokens == 0:
cudagraph_size = 0
elif cudagraph_size is None:
cudagraph_size = -1
synced_cg_mode = CUDAGraphMode(int(cg_mode_across_dp.min().item()))
num_tokens_across_dp, cudagraph_size_across_dp, cudagraph_mode_across_dp = (
get_batch_metadata_across_dp(
num_tokens, cudagraph_size, cudagraph_runtime_mode, dp_size, dp_rank
)
# If any rank wants to run eager, all ranks run eager
if synced_cg_mode == CUDAGraphMode.NONE:
return BatchExecutionDescriptor(
cg_mode=CUDAGraphMode.NONE,
num_tokens=num_tokens,
num_reqs=num_reqs,
), num_tokens_across_dp
synced_num_tokens = int(num_tokens_across_dp.max().item())
synced_uniform_token_count = uniform_token_counts_across_dp[0]
# If ranks disagree on the uniform token count, or it is 0 (meaning None), set it to None
if synced_uniform_token_count == 0 or not torch.all(
uniform_token_counts_across_dp == synced_uniform_token_count
):
synced_uniform_token_count = None
# Dispatch for the final synced values, use num_reqs instead of synced_num_reqs
# so we don't perform request padding for PIECEWISE graphs
synced_desc = cudagraph_manager.dispatch(
num_reqs, synced_num_tokens, synced_uniform_token_count
)
if torch.all(num_tokens_across_dp == 0).item():
# All ranks have zero tokens to run.
return 0, None, 0
# Synchronize cudagraph_runtime_mode across ranks by taking the minimum.
synced_cudagraph_mode = int(cudagraph_mode_across_dp.min().item())
# Check if all ranks have valid cudagraph_size.
all_have_cudagraph = torch.all(cudagraph_size_across_dp != -1).item()
# Update num_tokens_across_dp to reflect padded size.
num_tokens_across_dp[:] = synced_desc.num_tokens
if synced_cudagraph_mode != 0 and all_have_cudagraph:
# All ranks use cudagraph. Pad to max cudagraph_size.
max_cudagraph_size = int(cudagraph_size_across_dp.max().item())
num_tokens_across_dp[:] = max_cudagraph_size
return max_cudagraph_size, num_tokens_across_dp, synced_cudagraph_mode
else:
# Fall back to eager mode (no cudagraph).
# Either some rank doesn't have cudagraph size or mode is NONE.
synced_cudagraph_mode = 0
num_tokens_across_dp = torch.clamp(num_tokens_across_dp, min=1)
num_tokens_after_padding = int(num_tokens_across_dp[dp_rank].item())
return num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode
return synced_desc, num_tokens_across_dp
......@@ -37,6 +37,7 @@ class InputBatch:
# batch_idx -> req_id
req_ids: list[str]
num_reqs: int
num_reqs_after_padding: int
# batch_idx -> req_state_idx
idx_mapping: torch.Tensor
......@@ -123,6 +124,7 @@ class InputBatch:
return cls(
req_ids=req_ids,
num_reqs=num_reqs,
num_reqs_after_padding=num_reqs,
idx_mapping=idx_mapping,
idx_mapping_np=idx_mapping_np,
expanded_idx_mapping=expanded_idx_mapping,
......@@ -330,7 +332,8 @@ def combine_sampled_and_draft_tokens(
cu_num_logits: torch.Tensor,
num_logits: int,
) -> torch.Tensor:
num_reqs = seq_lens.shape[0]
# use idx_mapping.shape[0] for actual request count
num_reqs = idx_mapping.shape[0]
num_speculative_steps = draft_tokens.shape[-1]
logits_indices = torch.empty(
......
......@@ -40,7 +40,6 @@ from vllm.model_executor.model_loader import get_model_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.utils.math_utils import cdiv
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
......@@ -57,8 +56,12 @@ from vllm.v1.worker.gpu.attn_utils import (
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding
from vllm.v1.worker.gpu.cudagraph_utils import (
BatchExecutionDescriptor,
ModelCudaGraphManager,
get_uniform_token_count,
)
from vllm.v1.worker.gpu.dp_utils import sync_cudagraph_and_dp_padding
from vllm.v1.worker.gpu.input_batch import (
InputBatch,
InputBuffers,
......@@ -137,6 +140,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.is_first_pp_rank = True
self.is_last_pp_rank = True
# Data parallelism.
self.dp_size = self.parallel_config.data_parallel_size
self.dp_rank = self.parallel_config.data_parallel_rank
# Decode context parallelism.
self.dcp_size = self.parallel_config.decode_context_parallel_size
self.use_dcp = self.dcp_size > 1
......@@ -193,10 +200,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)
# CUDA graphs.
self.cudagraph_manager = CudaGraphManager(
self.decode_query_len = self.num_speculative_steps + 1
self.cudagraph_manager = ModelCudaGraphManager(
self.vllm_config,
self.use_aux_hidden_state_outputs,
self.device,
self.compilation_config.cudagraph_mode,
decode_query_len=self.decode_query_len,
)
# Structured outputs worker.
self.structured_outputs_worker = StructuredOutputsWorker(
......@@ -331,17 +340,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
**kwargs,
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
# Create a dummy scheduler output.
num_reqs = min(num_tokens, self.max_num_reqs)
if uniform_decode:
# Align tokens to uniform_decode_query_len for cudagraph
# compatibility across DP ranks.
query_len = self.cudagraph_manager.uniform_decode_query_len
num_reqs = min(cdiv(num_tokens, query_len), self.max_num_reqs)
num_tokens = num_reqs * query_len
num_tokens_per_request = [query_len] * num_reqs
else:
num_reqs = min(num_tokens, self.max_num_reqs)
num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
num_tokens_per_request[-1] += num_tokens % num_reqs
# HACK(lucas): The worker is currently shared between MRV1 and MRV2, and for
# spec-decode with MTP we want the dummy runs to use 1+num_speculative_tokens
# tokens, so we take the max here. This will likely be changed in the worker
# eventually: https://github.com/vllm-project/vllm/pull/35243
num_tokens = max(num_tokens, self.decode_query_len)
num_reqs = num_tokens // self.decode_query_len
assert num_tokens % self.decode_query_len == 0
num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
num_tokens_per_request[-1] += num_tokens % num_reqs
assert sum(num_tokens_per_request) == num_tokens
num_scheduled_tokens = {
f"_dummy_req_{i}": n for i, n in enumerate(num_tokens_per_request)
......@@ -498,13 +508,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
with self.maybe_setup_dummy_loras(self.lora_config):
self.cudagraph_manager.capture(
model=self.model,
model_state=self.model_state,
input_buffers=self.input_buffers,
block_tables=self.block_tables,
attn_groups=self.attn_groups,
kv_cache_config=self.kv_cache_config,
self.model,
self.model_state,
self.input_buffers,
self.block_tables,
self.attn_groups,
self.kv_cache_config,
has_lora=self.lora_config is not None,
use_aux_hidden_state_outputs=self.use_aux_hidden_state_outputs,
)
if self.speculator is not None:
self.speculator.capture_model()
......@@ -592,9 +603,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
def prepare_inputs(
self, scheduler_output: SchedulerOutput, num_tokens_after_padding: int
self, scheduler_output: SchedulerOutput, batch_desc: BatchExecutionDescriptor
) -> InputBatch:
num_tokens = scheduler_output.total_num_scheduled_tokens
num_tokens_after_padding = batch_desc.num_tokens
assert num_tokens > 0
num_tokens_per_req = scheduler_output.num_scheduled_tokens
num_reqs = len(num_tokens_per_req)
......@@ -644,6 +656,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
# Get query_start_loc.
# batch_desc.num_reqs is None for PIECEWISE graphs (no request padding needed),
# in which case num_reqs_padded falls back to the actual num_reqs.
num_reqs_padded = batch_desc.num_reqs or num_reqs
query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32)
query_start_loc_np[0] = 0
np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1 : num_reqs + 1])
......@@ -651,8 +665,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Some attention backends like FA3 require query_start_loc to be non-decreasing.
query_start_loc_np[num_reqs + 1 :] = num_tokens
async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc)
query_start_loc_np = query_start_loc_np[: num_reqs + 1]
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
query_start_loc_np = query_start_loc_np[: num_reqs_padded + 1]
query_start_loc = self.input_buffers.query_start_loc[: num_reqs_padded + 1]
# Get prefill tokens if any.
if self.req_states.any_prefills(idx_mapping_np):
......@@ -674,7 +688,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.input_buffers.positions,
self.input_buffers.seq_lens,
)
seq_lens = self.input_buffers.seq_lens[:num_reqs]
seq_lens = self.input_buffers.seq_lens[:num_reqs_padded]
dcp_local_seq_lens = None
if self.use_dcp:
......@@ -687,7 +701,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.dcp_rank,
self.cp_interleave,
)
dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs]
dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs_padded]
# Some input token ids are directly read from the last sampled tokens
# and draft tokens. Also, get the logits indices to sample tokens from.
......@@ -706,6 +720,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return InputBatch(
req_ids=req_ids,
num_reqs=num_reqs,
num_reqs_after_padding=num_reqs_padded,
idx_mapping=idx_mapping,
idx_mapping_np=idx_mapping_np,
expanded_idx_mapping=expanded_idx_mapping,
......@@ -729,13 +744,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def prepare_attn(
self, input_batch: InputBatch
) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables = self.block_tables.gather_block_tables(input_batch.idx_mapping)
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
# Block tables: num_kv_cache_groups x [num_reqs_padded, max_num_blocks].
block_tables = self.block_tables.gather_block_tables(
input_batch.idx_mapping,
num_reqs_padded=input_batch.num_reqs_after_padding,
)
# Slot mappings: [num_kv_cache_groups, num_tokens_padded].
# Kernel pads beyond num_tokens with PAD_SLOT_ID.
slot_mappings = self.block_tables.compute_slot_mappings(
input_batch.idx_mapping,
input_batch.query_start_loc,
input_batch.positions,
num_tokens_padded=input_batch.num_tokens_after_padding,
)
return block_tables, slot_mappings
......@@ -851,27 +871,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
empty_output = self.kv_connector.no_forward(scheduler_output)
return empty_output
# Get local cudagraph mode and size.
local_cudagraph_mode, local_cudagraph_size = (
self.cudagraph_manager.get_cudagraph_runtime_mode(
num_reqs=len(scheduler_output.num_scheduled_tokens),
num_tokens=scheduler_output.total_num_scheduled_tokens,
max_query_len=max(scheduler_output.num_scheduled_tokens.values()),
)
# Get batch descriptor and sync across DP ranks.
num_reqs = len(scheduler_output.num_scheduled_tokens)
num_toks = scheduler_output.total_num_scheduled_tokens
max_query_len = max(scheduler_output.num_scheduled_tokens.values())
uniform_tok_count = get_uniform_token_count(num_reqs, num_toks, max_query_len)
batch_desc = self.cudagraph_manager.dispatch(
num_reqs, num_toks, uniform_tok_count
)
num_tokens_across_dp = None
# DP sync: num_tokens + cudagraph_size + cudagraph_mode
num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode = (
get_cudagraph_and_dp_padding(
scheduler_output.total_num_scheduled_tokens,
local_cudagraph_size,
local_cudagraph_mode.value,
self.parallel_config.data_parallel_size,
self.parallel_config.data_parallel_rank,
if self.dp_size > 1:
batch_desc, num_tokens_across_dp = sync_cudagraph_and_dp_padding(
self.cudagraph_manager,
batch_desc,
num_toks,
num_reqs,
uniform_tok_count,
self.dp_size,
self.dp_rank,
)
)
cudagraph_runtime_mode = CUDAGraphMode(synced_cudagraph_mode)
if num_tokens_after_padding == 0:
if batch_desc.num_tokens == 0:
# All DP ranks have zero tokens to run.
empty_output = self.kv_connector.no_forward(scheduler_output)
return empty_output
......@@ -879,9 +901,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if not dummy_run:
# Common case.
# Prepare all the inputs and copy to the input buffers.
input_batch = self.prepare_inputs(
scheduler_output, num_tokens_after_padding
)
input_batch = self.prepare_inputs(scheduler_output, batch_desc)
block_tables, slot_mappings = self.prepare_attn(input_batch)
if self.lora_config:
......@@ -894,9 +914,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self._set_active_loras(*lora_inputs)
else:
# No actual tokens to run. A dummy run for DP or memory profiling.
num_reqs = min(num_tokens_after_padding, self.max_num_reqs)
input_batch = InputBatch.make_dummy(
num_reqs, num_tokens_after_padding, self.input_buffers
batch_desc.num_reqs or num_reqs,
batch_desc.num_tokens,
self.input_buffers,
)
if not skip_attn_for_dummy_run:
block_tables, slot_mappings = self.prepare_dummy_attn(input_batch)
......@@ -948,14 +969,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
model_inputs["intermediate_tensors"] = intermediate_tensors
# Run model.
if cudagraph_runtime_mode == CUDAGraphMode.FULL:
if batch_desc.cg_mode == CUDAGraphMode.FULL:
# Use explicit cudagraph replay for FULL mode.
# NOTE(woosuk): Here, we don't need to pass the input tensors,
# because they are already copied to the CUDA graph input buffers.
self.kv_connector.pre_forward(scheduler_output)
model_output = self.cudagraph_manager.run_fullgraph(
input_batch.num_tokens_after_padding
)
model_output = self.cudagraph_manager.run_fullgraph(batch_desc)
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
......@@ -972,7 +991,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
attn_metadata,
self.vllm_config,
num_tokens=input_batch.num_tokens_after_padding,
cudagraph_runtime_mode=cudagraph_runtime_mode,
cudagraph_runtime_mode=batch_desc.cg_mode,
num_tokens_across_dp=num_tokens_across_dp,
batch_descriptor=batch_descriptor,
slot_mapping=slot_mappings_by_layer,
......
......@@ -142,12 +142,15 @@ class DefaultModelState(ModelState):
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
# Use padded sizes - padding is handled by model_runner.prepare_attn.
num_reqs = input_batch.num_reqs_after_padding
num_tokens = input_batch.num_tokens_after_padding
query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
max_query_len = input_batch.num_scheduled_tokens.max().item()
attn_metadata = build_attn_metadata(
attn_groups=attn_groups,
num_reqs=input_batch.num_reqs,
num_tokens=input_batch.num_tokens,
num_reqs=num_reqs,
num_tokens=num_tokens,
query_start_loc_gpu=input_batch.query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=max_query_len,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any
import torch
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.model_executor.offloader.base import get_offloader
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cudagraph_utils import (
capture_graphs,
get_cudagraph_sizes,
BatchExecutionDescriptor,
CudaGraphManager,
prepare_inputs_to_capture,
)
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.utils import AttentionGroup
class EagleCudaGraphManager:
def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
self.device = device
class EagleCudaGraphManager(CudaGraphManager):
"""CudaGraphManager for Eagle speculative decoding (FULL mode only)."""
self.max_model_len = vllm_config.model_config.max_model_len
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None
# NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
self.cudagraph_mode = self.compilation_config.cudagraph_mode.decode_mode()
# only need to capture uniform decode cudagraph sizes (the 2nd return value)
_, self.cudagraph_sizes = get_cudagraph_sizes(
self.compilation_config.cudagraph_capture_sizes,
self.max_num_reqs,
self.max_num_tokens,
self.cudagraph_mode,
uniform_decode_query_len=1,
uniform_decode_cudagraph=True,
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
cudagraph_mode: CUDAGraphMode,
draft_tokens: torch.Tensor,
):
assert not cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE), (
"EagleCudaGraphManager does not support PIECEWISE mode yet"
)
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
self.pool = None
if self.cudagraph_mode != CUDAGraphMode.NONE:
# Eagle always uses uniform decode with query_len=1
super().__init__(vllm_config, device, cudagraph_mode, decode_query_len=1)
self.draft_tokens = draft_tokens
# Use a dedicated pool for Eagle to avoid memory overlap with the main
# model's cudagraph. The base class uses a shared global pool, but Eagle's
# internal allocations (e.g., gumbel_sample temporaries) can conflict with
# the main model's allocations when sharing the same pool.
if cudagraph_mode:
self.pool = torch.cuda.graph_pool_handle()
def get_cudagraph_size(self, num_tokens: int) -> int | None:
return self.cudagraph_sizes.get(num_tokens)
def get_cudagraph_runtime_mode(
self, num_tokens: int
) -> tuple[CUDAGraphMode, int | None]:
cudagraph_size = self.get_cudagraph_size(num_tokens)
if cudagraph_size is None:
cudagraph_mode = CUDAGraphMode.NONE
else:
cudagraph_mode = self.cudagraph_mode
if (
cudagraph_mode == CUDAGraphMode.FULL
and cudagraph_size is not None
and cudagraph_size not in self.graphs
):
# If graph wasn't captured yet, fall back to eager.
# This might happen when the dummy run is called before capture.
cudagraph_mode = CUDAGraphMode.NONE
cudagraph_size = None
return cudagraph_mode, cudagraph_size
def capture_graph(
def capture(
self,
num_tokens: int,
capture_cg_mode: CUDAGraphMode,
generate_fn: Callable,
model_state: ModelState,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
progress_bar_desc: str = "Capturing CUDA graphs",
) -> None:
assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}"
)
if capture_cg_mode == CUDAGraphMode.PIECEWISE:
capture_fn = self._capture_piecewise_graph
else:
capture_fn = self._capture_full_graph
num_reqs = min(num_tokens, self.max_num_reqs)
attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs,
num_tokens,
model_state,
input_buffers,
block_tables,
attn_groups,
kv_cache_config,
)
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
# Warm up.
generate_fn(
num_reqs,
num_tokens,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
CUDAGraphMode.NONE,
)
# Capture the graph.
capture_fn(
num_reqs=num_reqs,
num_tokens=num_tokens,
generate_fn=generate_fn,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings,
num_tokens_across_dp=num_tokens_across_dp,
)
def _capture_full_graph(
self,
num_reqs: int,
num_tokens: int,
generate_fn: Callable,
attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
num_tokens_across_dp: torch.Tensor,
) -> None:
assert num_tokens not in self.graphs
graph = torch.cuda.CUDAGraph()
# Sync offloader's copy stream before capture.
# Ensure any pre-capture prefetches from offloader are complete.
get_offloader().sync_prev_onload()
"""Capture CUDA graphs for Eagle speculative decoding (FULL mode only)."""
def create_forward_fn(
desc: BatchExecutionDescriptor,
) -> Callable[[CUDAGraphMode], None]:
num_tokens = desc.num_tokens
num_reqs = desc.num_reqs or min(num_tokens, self.max_num_reqs)
num_tokens_across_dp = (
torch.full((self.dp_size,), num_tokens, dtype=torch.int32, device="cpu")
if self.dp_size > 1
else None
)
attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs,
num_tokens,
model_state,
input_buffers,
block_tables,
attn_groups,
kv_cache_config,
)
with torch.cuda.graph(graph, self.pool):
generate_fn(
return lambda cg_mode: generate_fn(
num_reqs,
num_tokens,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
CUDAGraphMode.NONE,
cg_mode,
)
# Join offloader's copy stream after forward to avoid unjoined
# stream error. The last layer's start_prefetch forks copy_stream,
# but wait_prefetch only happens in the next forward pass.
get_offloader().join_after_forward()
self.graphs[num_tokens] = graph
def _capture_piecewise_graph(
self,
num_reqs: int,
num_tokens: int,
generate_fn: Callable,
attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
num_tokens_across_dp: torch.Tensor,
) -> None:
generate_fn(
num_reqs,
num_tokens,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
CUDAGraphMode.PIECEWISE,
)
@torch.inference_mode()
def capture(
self,
generate_fn: Callable,
model_state: ModelState,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
) -> None:
if self.cudagraph_mode == CUDAGraphMode.NONE:
return
capture_graphs(
self.cudagraph_sizes,
self.device,
self.capture_graph,
capture_cudagraph_mode=self.cudagraph_mode,
desc=f"Capturing eagle CUDA graphs ({self.cudagraph_mode.name})",
generate_fn=generate_fn,
model_state=model_state,
input_buffers=input_buffers,
block_tables=block_tables,
attn_groups=attn_groups,
kv_cache_config=kv_cache_config,
)
super().capture(create_forward_fn, progress_bar_desc)
def run_fullgraph(self, num_tokens: int) -> None:
assert num_tokens in self.graphs
# Sync offloader before replay - needed when transitioning from
# eager/piecewise to full cudagraph (e.g., prefill → decode).
# The previous eager iteration's start_prefetch may have queued
# H2D copies on copy_stream that the graph's captured events
# cannot see. Without this, replay could overwrite static buffers
# while those copies are still in flight.
get_offloader().sync_prev_onload()
self.graphs[num_tokens].replay()
def run_fullgraph(self, desc: BatchExecutionDescriptor) -> torch.Tensor:
"""Replay a captured FULL cudagraph and return draft tokens."""
super().run_fullgraph(desc)
return self.draft_tokens
......@@ -16,7 +16,7 @@ from vllm.v1.worker.gpu.attn_utils import (
build_slot_mappings_by_layer,
)
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding
from vllm.v1.worker.gpu.dp_utils import sync_cudagraph_and_dp_padding
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
......@@ -75,7 +75,16 @@ class EagleSpeculator:
device=device,
)
self.cudagraph_manager = EagleCudaGraphManager(vllm_config, device)
# currently we don't support PIECEWISE for Eagle.
cudagraph_mode = vllm_config.compilation_config.cudagraph_mode
if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL:
cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
else:
cudagraph_mode = CUDAGraphMode.NONE
self.cudagraph_manager = EagleCudaGraphManager(
vllm_config, device, cudagraph_mode, self.draft_tokens
)
def load_model(self, target_model: nn.Module) -> None:
self.model = load_eagle_model(target_model, self.vllm_config)
......@@ -171,7 +180,7 @@ class EagleSpeculator:
)
if attn_metadata is not None:
self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
idx_mapping, query_start_loc, pos, num_tokens_padded
)
def capture_model(self) -> None:
......@@ -185,6 +194,7 @@ class EagleSpeculator:
self.block_tables,
self.attn_groups,
self.kv_cache_config,
progress_bar_desc="Capturing eagle CUDA graphs",
)
@torch.inference_mode()
......@@ -251,6 +261,7 @@ class EagleSpeculator:
logits = self.model.compute_logits(sample_hidden_states)
num_reqs = input_batch.num_reqs
num_reqs_padded = input_batch.num_reqs_after_padding
# NOTE(woosuk): For draft sampling, we only consider the temperature
# and ignore the other sampling parameters such as top_k and top_p,
# for simplicity and performance.
......@@ -292,48 +303,52 @@ class EagleSpeculator:
self.max_num_reqs,
)
if not (dummy_run and skip_attn_for_dummy_run):
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)
# Get batch descriptor and sync across DP ranks.
# Eagle uses FULL-only mode, dispatch with uniform_token_count=1 for decode
cudagraph_mode, cudagraph_size = (
self.cudagraph_manager.get_cudagraph_runtime_mode(num_reqs)
)
num_tokens_padded, num_tokens_across_dp, synced_cudagraph_mode = (
get_cudagraph_and_dp_padding(
batch_desc = self.cudagraph_manager.dispatch(num_reqs, num_reqs, 1)
num_tokens_across_dp = None
if self.dp_size > 1:
batch_desc, num_tokens_across_dp = sync_cudagraph_and_dp_padding(
self.cudagraph_manager,
batch_desc,
num_reqs,
cudagraph_size,
cudagraph_mode.value,
num_reqs,
1, # uniform_token_count
self.dp_size,
self.dp_rank,
)
)
cudagraph_mode = CUDAGraphMode(synced_cudagraph_mode)
if cudagraph_mode == CUDAGraphMode.FULL:
# Run full CUDA graph.
self.cudagraph_manager.run_fullgraph(num_tokens_padded)
return self.draft_tokens[:num_reqs]
if not (dummy_run and skip_attn_for_dummy_run):
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos, batch_desc.num_tokens
)
if batch_desc.cg_mode == CUDAGraphMode.FULL:
return self.cudagraph_manager.run_fullgraph(batch_desc)[:num_reqs]
# Run eager or piecewise CUDA graph.
attn_metadata_updated = None
slot_mappings_updated = None
if not (dummy_run and skip_attn_for_dummy_run):
query_start_loc_cpu = torch.arange(
num_reqs + 1, dtype=torch.int32, device="cpu"
num_reqs_padded + 1, dtype=torch.int32, device="cpu"
)
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
block_tables = [
x[:num_reqs_padded] for x in self.block_tables.input_block_tables
]
# FIXME(woosuk): This is UNSAFE!!
attn_metadata_updated = build_attn_metadata(
attn_groups=self.attn_groups,
num_reqs=num_reqs,
num_tokens=num_reqs,
num_reqs=num_reqs_padded,
num_tokens=num_reqs_padded,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=1,
seq_lens=self.input_buffers.seq_lens[:num_reqs],
seq_lens=self.input_buffers.seq_lens[:num_reqs_padded],
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
......@@ -345,11 +360,11 @@ class EagleSpeculator:
self.generate_draft(
num_reqs,
num_tokens_padded,
batch_desc.num_tokens,
attn_metadata_updated,
slot_mappings_updated,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_mode,
cudagraph_runtime_mode=batch_desc.cg_mode,
)
return self.draft_tokens[:num_reqs]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment