"vllm/vscode:/vscode.git/clone" did not exist on "c0d00f5be6d3ed390534dd909c82b639baf2d359"
Unverified Commit 483463f7 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[MRV2] Extensible CG dispatch rework (#35959)


Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
parent 4e571ce6
...@@ -97,6 +97,9 @@ class CUDAGraphMode(enum.Enum): ...@@ -97,6 +97,9 @@ class CUDAGraphMode(enum.Enum):
def __str__(self) -> str: def __str__(self) -> str:
return self.name return self.name
# Truthiness of a mode: evaluates to False only when the mode compares equal
# to CUDAGraphMode.NONE, so callers can write `if mode:` as a shorthand for
# "some form of CUDA graph is enabled".
def __bool__(self) -> bool:
return self != CUDAGraphMode.NONE
@config @config
class PassConfig: class PassConfig:
......
...@@ -104,19 +104,24 @@ class BlockTables: ...@@ -104,19 +104,24 @@ class BlockTables:
self.num_blocks.copy_to_uva() self.num_blocks.copy_to_uva()
def gather_block_tables( def gather_block_tables(
self, idx_mapping: torch.Tensor self,
idx_mapping: torch.Tensor,
num_reqs_padded: int,
) -> tuple[torch.Tensor, ...]: ) -> tuple[torch.Tensor, ...]:
num_reqs = idx_mapping.shape[0] num_reqs = idx_mapping.shape[0]
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)]( # Launch kernel with num_reqs_padded to fuse zeroing of padded rows.
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs_padded)](
idx_mapping, idx_mapping,
self.block_table_ptrs, self.block_table_ptrs,
self.input_block_table_ptrs, self.input_block_table_ptrs,
self.block_table_strides, self.block_table_strides,
self.num_blocks.gpu, self.num_blocks.gpu,
self.num_blocks.gpu.stride(0), self.num_blocks.gpu.stride(0),
num_reqs,
self.input_block_tables[0].shape[1], # max_num_blocks
BLOCK_SIZE=1024, # type: ignore BLOCK_SIZE=1024, # type: ignore
) )
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables) return tuple(bt[:num_reqs_padded] for bt in self.input_block_tables)
def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]: def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
# NOTE(woosuk): The output may be used for CUDA graph capture. # NOTE(woosuk): The output may be used for CUDA graph capture.
...@@ -130,6 +135,7 @@ class BlockTables: ...@@ -130,6 +135,7 @@ class BlockTables:
idx_mapping: torch.Tensor, idx_mapping: torch.Tensor,
query_start_loc: torch.Tensor, query_start_loc: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
num_tokens_padded: int,
) -> torch.Tensor: ) -> torch.Tensor:
num_reqs = idx_mapping.shape[0] num_reqs = idx_mapping.shape[0]
num_tokens = positions.shape[0] num_tokens = positions.shape[0]
...@@ -151,7 +157,7 @@ class BlockTables: ...@@ -151,7 +157,7 @@ class BlockTables:
PAD_ID=PAD_SLOT_ID, PAD_ID=PAD_SLOT_ID,
TRITON_BLOCK_SIZE=1024, # type: ignore TRITON_BLOCK_SIZE=1024, # type: ignore
) )
return self.slot_mappings[:, :num_tokens] return self.slot_mappings[:, :num_tokens_padded]
def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor: def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
# Fill the entire slot_mappings tensor, not just the first `num_tokens` entries. # Fill the entire slot_mappings tensor, not just the first `num_tokens` entries.
...@@ -173,21 +179,31 @@ def _gather_block_tables_kernel( ...@@ -173,21 +179,31 @@ def _gather_block_tables_kernel(
block_table_strides, # [num_kv_cache_groups] block_table_strides, # [num_kv_cache_groups]
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs] num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
num_blocks_stride, num_blocks_stride,
num_reqs, # actual number of requests (for padding)
max_num_blocks, # stride for zeroing padded rows
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
# kv cache group id # kv cache group id
group_id = tl.program_id(0) group_id = tl.program_id(0)
batch_idx = tl.program_id(1) batch_idx = tl.program_id(1)
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
stride = tl.load(block_table_strides + group_id)
dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
dst_row_ptr = dst_block_table_ptr + batch_idx * stride
if batch_idx >= num_reqs:
# Zero out padded rows.
for i in tl.range(0, max_num_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
tl.store(dst_row_ptr + offset, 0, mask=offset < max_num_blocks)
return
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
num_blocks = tl.load(group_num_blocks_ptr + req_idx) num_blocks = tl.load(group_num_blocks_ptr + req_idx)
stride = tl.load(block_table_strides + group_id)
src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32) src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32)
src_row_ptr = src_block_table_ptr + req_idx * stride src_row_ptr = src_block_table_ptr + req_idx * stride
dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
dst_row_ptr = dst_block_table_ptr + batch_idx * stride
for i in tl.range(0, num_blocks, BLOCK_SIZE): for i in tl.range(0, num_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE) offset = i + tl.arange(0, BLOCK_SIZE)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass
from typing import Any from typing import Any
import torch import torch
...@@ -11,78 +13,262 @@ from vllm.config import VllmConfig ...@@ -11,78 +13,262 @@ from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.offloader.base import get_offloader from vllm.model_executor.offloader.base import get_offloader
from vllm.utils.math_utils import cdiv from vllm.platforms import current_platform
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.utils import AttentionGroup from vllm.v1.worker.utils import AttentionGroup
logger = init_logger(__name__)
@dataclass(frozen=True)
class BatchExecutionDescriptor:
"""Describes the shape of a batch and the CUDA graph (CG) mode to run it with.

Used to match batch shapes between graph capture and runtime. The dataclass
is frozen, so instances are immutable and can be used as dictionary keys.
"""
# CUDA graph mode the batch is executed with.
cg_mode: CUDAGraphMode
# Number of tokens in the batch.
num_tokens: int
num_reqs: int | None # None means no request padding is needed (PIECEWISE graphs)
# Per-request token count for descriptors that only serve uniform batches;
# None means the descriptor places no constraint on it.
uniform_token_count: int | None = None
def _is_compatible(
desc: BatchExecutionDescriptor,
num_reqs: int,
num_tokens: int,
uniform_token_count: int | None,
) -> bool:
# desc.uniform_token_count=None (PIECEWISE) can handle any uniform_token_count
# desc.num_reqs=None means no request padding needed (PIECEWISE)
return (
(
desc.uniform_token_count is None
or desc.uniform_token_count == uniform_token_count
)
and (desc.num_reqs is None or desc.num_reqs >= num_reqs)
and desc.num_tokens >= num_tokens
)
def get_uniform_token_count(
    num_reqs: int,
    num_tokens: int,
    max_query_len: int,
) -> int | None:
    """Return the uniform token count if the batch is uniform, else None.

    A batch is uniform if all requests have the same number of tokens, i.e.
    the total token count equals ``max_query_len * num_reqs``.

    Args:
        num_reqs: Number of requests in the batch.
        num_tokens: Total number of tokens in the batch.
        max_query_len: Largest per-request token count in the batch.

    Returns:
        ``max_query_len`` when the batch is uniform, otherwise None. An empty
        batch (``num_reqs <= 0``) is never considered uniform.
    """
    # An empty batch has no per-request token count. Returning early also
    # avoids the division by zero the floor-division form would hit.
    if num_reqs <= 0:
        return None
    # With num_reqs > 0 the product equality already implies
    # num_tokens // num_reqs == max_query_len, so a single check suffices.
    if num_tokens == max_query_len * num_reqs:
        return max_query_len
    return None
class CudaGraphManager: class CudaGraphManager:
def __init__( def __init__(
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
use_aux_hidden_state_outputs: bool,
device: torch.device, device: torch.device,
cudagraph_mode: CUDAGraphMode,
decode_query_len: int,
): ):
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs
self.device = device self.device = device
self.max_num_reqs = vllm_config.scheduler_config.max_num_seqs
self.max_model_len = vllm_config.model_config.max_model_len self.compilation_config = vllm_config.compilation_config
self.max_num_reqs = self.scheduler_config.max_num_seqs assert self.compilation_config is not None
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens self.cudagraph_mode = cudagraph_mode
self.decode_query_len = decode_query_len
self.dp_size = vllm_config.parallel_config.data_parallel_size self.dp_size = vllm_config.parallel_config.data_parallel_size
self.uniform_decode_query_len = 1 self.graphs: dict[BatchExecutionDescriptor, torch.cuda.CUDAGraph] = {}
spec_config = vllm_config.speculative_config self.pool = current_platform.get_global_graph_pool() if cudagraph_mode else None
if spec_config is not None:
self.uniform_decode_query_len += spec_config.num_speculative_tokens
self.compilation_config = vllm_config.compilation_config self._graphs_captured = False
assert self.compilation_config is not None self._candidates: list[list[BatchExecutionDescriptor]] = []
self.cudagraph_mode = self.compilation_config.cudagraph_mode self._capture_descs: dict[CUDAGraphMode, list[BatchExecutionDescriptor]] = {}
self._init_candidates()
def _init_candidates(self) -> None:
"""Build priority-ordered candidate lists for each token count."""
capture_sizes = self.compilation_config.cudagraph_capture_sizes
if not (self.cudagraph_mode and capture_sizes):
return
use_uniform_decode_cudagraph = ( capture_sizes = sorted(capture_sizes)
self.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL max_decode_tokens = self.max_num_reqs * self.decode_query_len
and self.cudagraph_mode.separate_routine() decode_mode = self.cudagraph_mode.decode_mode()
mixed_mode = self.cudagraph_mode.mixed_mode()
separate_decode_routine = self.cudagraph_mode.separate_routine()
descs_by_token_count = defaultdict(list)
descs_by_mode = defaultdict(list)
for num_tokens in capture_sizes:
# Capture uniform-decode-specific graphs if required
# (i.e. separate decode routine)
if (
separate_decode_routine
and decode_mode
and self.decode_query_len <= num_tokens <= max_decode_tokens
):
desc = BatchExecutionDescriptor(
cg_mode=decode_mode,
num_tokens=num_tokens,
num_reqs=num_tokens // self.decode_query_len,
uniform_token_count=self.decode_query_len,
)
descs_by_mode[decode_mode].append(desc)
descs_by_token_count[num_tokens].append(desc)
if mixed_mode:
# for PIECEWISE graphs there is no limit on requests when replaying
# i.e. no request padding is needed
# so we leave it as None
num_reqs = (
min(num_tokens, self.max_num_reqs)
if mixed_mode == CUDAGraphMode.FULL
else None
) )
self.cudagraph_sizes, self.uniform_decode_cudagraph_sizes = get_cudagraph_sizes( desc = BatchExecutionDescriptor(
self.compilation_config.cudagraph_capture_sizes, cg_mode=mixed_mode,
self.max_num_reqs, num_tokens=num_tokens,
self.max_num_tokens, num_reqs=num_reqs,
self.cudagraph_mode,
self.uniform_decode_query_len,
use_uniform_decode_cudagraph,
) )
descs_by_mode[mixed_mode].append(desc)
descs_by_token_count[num_tokens].append(desc)
self.graphs: dict[int, torch.cuda.CUDAGraph] = {} if not descs_by_token_count:
self.pool = None return
if self.cudagraph_mode != CUDAGraphMode.NONE:
self.pool = torch.cuda.graph_pool_handle() sorted_padded = sorted(descs_by_token_count.keys())
self.hidden_states: torch.Tensor | None = None self._candidates = [[] for _ in range(sorted_padded[-1] + 1)]
self.aux_hidden_states: list[torch.Tensor] = []
current_range_start = 0
for cg_size in sorted_padded:
for i in range(current_range_start, cg_size + 1):
self._candidates[i] = descs_by_token_count[cg_size]
current_range_start = cg_size + 1
for mode, descs in descs_by_mode.items():
descs.sort(key=lambda d: d.num_tokens, reverse=True)
self._capture_descs[mode] = descs
def needs_capture(self) -> bool: def needs_capture(self) -> bool:
return len(self.cudagraph_sizes) > 0 return len(self._capture_descs) > 0
def get_cudagraph_size( @torch.inference_mode()
self, num_tokens: int, uniform_decode: bool = False def capture(
) -> int | None: self,
if uniform_decode and self.uniform_decode_cudagraph_sizes: create_forward_fn: Callable[
return self.uniform_decode_cudagraph_sizes.get(num_tokens) [BatchExecutionDescriptor], Callable[[CUDAGraphMode], None]
return self.cudagraph_sizes.get(num_tokens) ],
progress_bar_desc: str = "Capturing CUDA graphs",
) -> None:
"""Capture CUDA graphs.
Args:
create_forward_fn: Factory that prepares inputs (OUTSIDE graph) and
returns a function that runs forward with a given CUDAGraphMode.
"""
with graph_capture(device=self.device):
# Capture in order: PIECEWISE first, then FULL. PIECEWISE has larger
# activations so FULL activations should fit in already allocated
# buffers in the graph pool.
for mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]:
if mode not in self._capture_descs:
continue
descs = self._capture_descs[mode]
if is_global_first_rank():
descs = tqdm(descs, desc=f"{progress_bar_desc} ({mode.name})")
for desc in descs:
# Prepare inputs and get forward function
forward_fn = create_forward_fn(desc)
# Warmup
forward_fn(CUDAGraphMode.NONE)
def capture_graph( # Capture
logger.debug(
"CG Capture: mode=%s, batch_desc=%s", desc.cg_mode.name, desc
)
if desc.cg_mode == CUDAGraphMode.PIECEWISE:
forward_fn(CUDAGraphMode.PIECEWISE)
else:
assert desc not in self.graphs, (
f"Graph already captured for {desc}"
)
graph = torch.cuda.CUDAGraph()
# Sync offloader's copy stream before capture.
# Ensure any pre-capture prefetches from offloader are complete.
get_offloader().sync_prev_onload()
with torch.cuda.graph(graph, self.pool):
forward_fn(CUDAGraphMode.NONE)
# Join offloader's copy stream after forward to avoid
# unjoined stream error. The last layer's start_prefetch
# forks copy_stream, but wait_prefetch only happens in
# the next forward pass.
get_offloader().join_after_forward()
self.graphs[desc] = graph
self._graphs_captured = True
def dispatch(
self, self,
num_reqs: int,
num_tokens: int, num_tokens: int,
capture_cg_mode: CUDAGraphMode, uniform_token_count: int | None,
) -> BatchExecutionDescriptor:
"""Find matching cudagraph descriptor from priority-ordered candidates."""
if self._graphs_captured and 0 < num_tokens < len(self._candidates):
for desc in self._candidates[num_tokens]:
if _is_compatible(desc, num_reqs, num_tokens, uniform_token_count):
return desc
return BatchExecutionDescriptor(
cg_mode=CUDAGraphMode.NONE, num_tokens=num_tokens, num_reqs=num_reqs
)
def run_fullgraph(self, desc: BatchExecutionDescriptor):
    """Replay the FULL cudagraph that was captured for ``desc``.

    Raises:
        AssertionError: If ``desc`` is not a FULL-mode descriptor or no graph
            has been captured for it.
    """
    assert desc.cg_mode == CUDAGraphMode.FULL, (
        f"Expected FULL mode, got {desc.cg_mode}"
    )
    captured_graph = self.graphs.get(desc)
    assert captured_graph is not None, f"No cudagraph for {desc}"

    # The offloader must be synced before replay when switching from
    # eager/piecewise execution to a full cudagraph (e.g., prefill → decode).
    # start_prefetch in the previous eager iteration may have queued H2D
    # copies on copy_stream that the events captured in the graph cannot see;
    # without this sync, replay could overwrite static buffers while those
    # copies are still in flight.
    get_offloader().sync_prev_onload()
    captured_graph.replay()
class ModelCudaGraphManager(CudaGraphManager):
"""CudaGraphManager with model-specific capture and hidden state management."""
def __init__(
    self,
    vllm_config: VllmConfig,
    device: torch.device,
    cudagraph_mode: CUDAGraphMode,
    decode_query_len: int,
):
    """Set up the base manager and empty model-output state.

    All four arguments are forwarded unchanged to the parent constructor.
    """
    super().__init__(vllm_config, device, cudagraph_mode, decode_query_len)

    # Aux hidden-state outputs start disabled.
    self.use_aux_hidden_state_outputs = False
    # Output buffers start out empty / unset.
    self.aux_hidden_states: list[torch.Tensor] = []
    self.hidden_states: torch.Tensor | None = None
def capture(
self,
model: nn.Module, model: nn.Module,
model_state: ModelState, model_state: ModelState,
input_buffers: InputBuffers, input_buffers: InputBuffers,
...@@ -90,33 +276,22 @@ class CudaGraphManager: ...@@ -90,33 +276,22 @@ class CudaGraphManager:
attn_groups: list[list[AttentionGroup]], attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
has_lora: bool = False, has_lora: bool = False,
uniform_decode: bool = False, use_aux_hidden_state_outputs: bool = False,
progress_bar_desc: str = "Capturing CUDA graphs",
) -> None: ) -> None:
# select and check capture function """Capture CUDA graphs for model forward pass."""
assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], ( self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs
f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}"
)
if capture_cg_mode == CUDAGraphMode.PIECEWISE:
capture_fn = self._capture_piecewise_graph
else:
capture_fn = self._capture_full_graph
# prepare inputs
if uniform_decode:
num_reqs = min(
cdiv(num_tokens, self.uniform_decode_query_len),
self.max_num_reqs,
)
else:
num_reqs = min(num_tokens, self.max_num_reqs)
model_inputs = {
"input_ids": input_buffers.input_ids[:num_tokens],
"positions": input_buffers.positions[:num_tokens],
# NOTE: Values returned by `prepare_dummy_inputs` will override the
# default values above.
**model_state.prepare_dummy_inputs(num_reqs, num_tokens),
}
def create_forward_fn(
desc: BatchExecutionDescriptor,
) -> Callable[[CUDAGraphMode], None]:
num_tokens = desc.num_tokens
num_reqs = desc.num_reqs or min(num_tokens, self.max_num_reqs)
num_tokens_across_dp = (
torch.full((self.dp_size,), num_tokens, dtype=torch.int32, device="cpu")
if self.dp_size > 1
else None
)
attn_metadata, slot_mappings = prepare_inputs_to_capture( attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs, num_reqs,
num_tokens, num_tokens,
...@@ -126,262 +301,56 @@ class CudaGraphManager: ...@@ -126,262 +301,56 @@ class CudaGraphManager:
attn_groups, attn_groups,
kv_cache_config, kv_cache_config,
) )
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
# Warm up. def forward_fn(cg_mode: CUDAGraphMode) -> None:
batch_descriptor = (
BatchDescriptor(num_tokens=num_tokens)
if cg_mode == CUDAGraphMode.PIECEWISE
else None
)
with set_forward_context( with set_forward_context(
attn_metadata, attn_metadata if cg_mode != CUDAGraphMode.PIECEWISE else None,
self.vllm_config, self.vllm_config,
num_tokens=num_tokens, num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE, cudagraph_runtime_mode=cg_mode,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
slot_mapping=slot_mappings, slot_mapping=slot_mappings,
batch_descriptor=batch_descriptor,
): ):
model_inputs = {
"input_ids": input_buffers.input_ids[:num_tokens],
"positions": input_buffers.positions[:num_tokens],
}
model_output = model(**model_inputs) model_output = model(**model_inputs)
if self.use_aux_hidden_state_outputs: if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output hidden_states, aux_hidden_states = model_output
else: else:
hidden_states = model_output hidden_states = model_output
aux_hidden_states = None aux_hidden_states = []
# Allocate output buffers if not already done.
if self.hidden_states is None: if self.hidden_states is None:
self.hidden_states = torch.empty_like(hidden_states) self.hidden_states = torch.empty_like(hidden_states)
if self.use_aux_hidden_state_outputs and not self.aux_hidden_states: if self.use_aux_hidden_state_outputs and not self.aux_hidden_states:
self.aux_hidden_states = [torch.empty_like(x) for x in aux_hidden_states] self.aux_hidden_states = [
torch.empty_like(x) for x in aux_hidden_states
capture_fn( ]
num_tokens=num_tokens,
num_reqs=num_reqs,
model=model,
model_inputs=model_inputs,
num_tokens_across_dp=num_tokens_across_dp,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings,
has_lora=has_lora,
)
def _capture_full_graph(
self,
num_tokens: int,
num_reqs: int,
model: nn.Module,
model_inputs: dict[str, torch.Tensor | None],
num_tokens_across_dp: torch.Tensor,
attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None,
has_lora: bool = False,
) -> None:
assert attn_metadata is not None
# Capture the graph.
assert num_tokens not in self.graphs
graph = torch.cuda.CUDAGraph()
# Sync offloader's copy stream before capture.
# Ensure any pre-capture prefetches from offloader are complete.
get_offloader().sync_prev_onload()
with (
set_forward_context(
attn_metadata=attn_metadata,
vllm_config=self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp,
slot_mapping=slot_mappings,
),
torch.cuda.graph(graph, self.pool),
):
model_output = model(**model_inputs)
# Join offloader's copy stream after forward to avoid unjoined
# stream error. The last layer's start_prefetch forks copy_stream,
# but wait_prefetch only happens in the next forward pass.
get_offloader().join_after_forward()
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = None
# Copy outputs to the output buffers.
assert self.hidden_states is not None
self.hidden_states[:num_tokens] = hidden_states self.hidden_states[:num_tokens] = hidden_states
if self.use_aux_hidden_state_outputs: for i, aux in enumerate(aux_hidden_states):
for i, aux_hidden in enumerate(aux_hidden_states): self.aux_hidden_states[i][:num_tokens] = aux
self.aux_hidden_states[i][:num_tokens] = aux_hidden
self.graphs[num_tokens] = graph
def _capture_piecewise_graph( return forward_fn
self,
num_tokens: int,
num_reqs: int,
model: nn.Module,
model_inputs: dict[str, torch.Tensor | None],
num_tokens_across_dp: torch.Tensor,
attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None,
has_lora: bool = False,
) -> None:
# create batch descriptor for piecewise cudagraph dispatch key
batch_descriptor = BatchDescriptor(num_tokens=num_tokens, has_lora=has_lora)
# Capture run - CUDAGraphWrapper inside torch.compile will auto capture. super().capture(create_forward_fn, progress_bar_desc)
with set_forward_context(
attn_metadata=None, # piecewise no need attn_metadata
vllm_config=self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
num_tokens_across_dp=num_tokens_across_dp,
batch_descriptor=batch_descriptor,
slot_mapping=slot_mappings,
):
model(**model_inputs)
@torch.inference_mode()
def capture(
    self,
    model: nn.Module,
    model_state: ModelState,
    input_buffers: InputBuffers,
    block_tables: BlockTables,
    attn_groups: list[list[AttentionGroup]],
    kv_cache_config: KVCacheConfig,
    has_lora: bool = False,
) -> None:
    """Capture CUDA graphs in up to two phases.

    Phase 1 captures graphs for mixed prefill-decode batches when the mixed
    mode is not NONE. Phase 2 captures FULL graphs for uniform decode batches
    when uniform-decode capture sizes exist.
    """
    # Arguments shared by both capture phases.
    shared_kwargs = {
        "device": self.device,
        "capture_fn": self.capture_graph,
        "model": model,
        "model_state": model_state,
        "input_buffers": input_buffers,
        "block_tables": block_tables,
        "attn_groups": attn_groups,
        "kv_cache_config": kv_cache_config,
        "has_lora": has_lora,
    }

    # Phase 1: mixed prefill-decode batches, if that mode is enabled.
    mixed_mode = self.cudagraph_mode.mixed_mode()
    if mixed_mode != CUDAGraphMode.NONE:
        capture_graphs(
            cudagraph_sizes=self.cudagraph_sizes,
            capture_cudagraph_mode=mixed_mode,
            desc=f"Capturing CUDA graphs (mixed, {mixed_mode.name})",
            uniform_decode=False,
            **shared_kwargs,
        )

    # Phase 2: FULL graphs for uniform decode batches. Only needed when a
    # separate routine is used for decode batches and the decode mode is FULL;
    # in that case the uniform-decode size table is non-empty.
    if self.uniform_decode_cudagraph_sizes:
        capture_graphs(
            cudagraph_sizes=self.uniform_decode_cudagraph_sizes,
            capture_cudagraph_mode=CUDAGraphMode.FULL,
            desc="Capturing CUDA graphs (decode, FULL)",
            uniform_decode=True,
            **shared_kwargs,
        )
def get_cudagraph_runtime_mode(
    self, num_reqs: int, num_tokens: int, max_query_len: int
) -> tuple[CUDAGraphMode, int | None]:
    """Pick the CUDA graph mode and padded size for a batch.

    Returns:
        ``(mode, size)``. ``(CUDAGraphMode.NONE, None)`` is returned when no
        capture size covers the batch, or when a FULL graph would be needed
        but has not been captured for that size.
    """
    uniform_decode = (
        max_query_len == self.uniform_decode_query_len
        and num_tokens == max_query_len * num_reqs
    )

    padded_size = self.get_cudagraph_size(num_tokens, uniform_decode)
    if padded_size is None:
        return CUDAGraphMode.NONE, None

    configured = self.cudagraph_mode
    runtime_mode = (
        configured.decode_mode() if uniform_decode else configured.mixed_mode()
    )

    # If the graph wasn't captured yet, fall back to eager. This might happen
    # when the dummy run is called before capture.
    if runtime_mode == CUDAGraphMode.FULL and padded_size not in self.graphs:
        return CUDAGraphMode.NONE, None
    return runtime_mode, padded_size
def run_fullgraph( def run_fullgraph(
self, num_tokens: int self, desc: BatchExecutionDescriptor
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
assert num_tokens in self.graphs, f"No cudagraph for {num_tokens} tokens" """Replay a captured FULL cudagraph and return hidden states."""
# Sync offloader before replay - needed when transitioning from super().run_fullgraph(desc)
# eager/piecewise to full cudagraph (e.g., prefill → decode).
# The previous eager iteration's start_prefetch may have queued
# H2D copies on copy_stream that the graph's captured events
# cannot see. Without this, replay could overwrite static buffers
# while those copies are still in flight.
get_offloader().sync_prev_onload()
self.graphs[num_tokens].replay()
assert self.hidden_states is not None assert self.hidden_states is not None
hidden_states = self.hidden_states[:num_tokens] hidden_states = self.hidden_states[: desc.num_tokens]
if not self.use_aux_hidden_state_outputs: if not self.use_aux_hidden_state_outputs:
return hidden_states return hidden_states
return hidden_states, [x[:num_tokens] for x in self.aux_hidden_states] return hidden_states, [x[: desc.num_tokens] for x in self.aux_hidden_states]
def get_cudagraph_sizes(
    capture_sizes: list[int] | None,
    max_num_reqs: int,
    max_num_tokens: int,
    cudagraph_mode: CUDAGraphMode,
    uniform_decode_query_len: int = 1,
    uniform_decode_cudagraph: bool = False,
) -> tuple[dict[int, int], dict[int, int]]:
    """Build lookup tables from a token count to the capture size covering it.

    Returns:
        ``(cudagraph_sizes, uniform_decode_cudagraph_sizes)``. The first maps
        every token count from 1 up to the largest capture size to the
        smallest capture size that is >= it. The second is the subset of that
        mapping usable for uniform decode batches, and is empty unless
        ``uniform_decode_cudagraph`` is set. Both are empty when the mode is
        NONE or no capture sizes are given.
    """
    # Both FULL and PIECEWISE modes are supported; only NONE disables graphs.
    if cudagraph_mode == CUDAGraphMode.NONE or not capture_sizes:
        return {}, {}

    ordered_sizes = sorted(capture_sizes)
    largest = ordered_sizes[-1]
    # Every n <= largest has a match, so next() cannot be exhausted here.
    cudagraph_sizes: dict[int, int] = {
        n: next(size for size in ordered_sizes if n <= size)
        for n in range(1, largest + 1)
    }

    uniform_decode_cudagraph_sizes: dict[int, int] = {}
    if uniform_decode_cudagraph:
        # Uniform decode batches are bounded by max_num_reqs requests of
        # uniform_decode_query_len tokens each.
        decode_token_limit = max_num_reqs * uniform_decode_query_len
        uniform_decode_cudagraph_sizes = {
            n: size
            for n, size in cudagraph_sizes.items()
            if uniform_decode_query_len <= size <= decode_token_limit
        }
    return cudagraph_sizes, uniform_decode_cudagraph_sizes
def capture_graphs(
    cudagraph_sizes: dict[int, int],
    device: torch.device,
    capture_fn: Callable,
    capture_cudagraph_mode: CUDAGraphMode,
    desc: str = "Capturing CUDA graphs",
    **capture_kwargs,
) -> None:
    """Invoke ``capture_fn`` once per distinct capture size, largest first.

    Args:
        cudagraph_sizes: Mapping whose values are the sizes to capture.
        device: Device passed to the ``graph_capture`` context.
        capture_fn: Called as ``capture_fn(size, mode, **capture_kwargs)``.
        capture_cudagraph_mode: Mode forwarded to ``capture_fn``.
        desc: Progress-bar label (shown on the global first rank only).
        **capture_kwargs: Extra keyword arguments forwarded to ``capture_fn``.
    """
    # Deduplicate the target sizes and capture the larger graphs first.
    pending = sorted({*cudagraph_sizes.values()}, reverse=True)
    # Only the global first rank wraps the iteration in a progress bar.
    iterator = tqdm(pending, desc=desc) if is_global_first_rank() else pending
    with graph_capture(device=device):
        for capture_size in iterator:
            capture_fn(capture_size, capture_cudagraph_mode, **capture_kwargs)
def prepare_inputs_to_capture( def prepare_inputs_to_capture(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import get_dp_group from vllm.distributed.parallel_state import get_dp_group
from vllm.v1.worker.gpu.cudagraph_utils import (
BatchExecutionDescriptor,
CudaGraphManager,
)
def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | None: def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | None:
...@@ -12,66 +19,63 @@ def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | N ...@@ -12,66 +19,63 @@ def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | N
return torch.full((dp_size,), num_tokens, dtype=torch.int32, device="cpu") return torch.full((dp_size,), num_tokens, dtype=torch.int32, device="cpu")
def get_batch_metadata_across_dp( def sync_cudagraph_and_dp_padding(
cudagraph_manager: CudaGraphManager,
desired_batch_desc: BatchExecutionDescriptor,
num_tokens: int, num_tokens: int,
cudagraph_size: int, num_reqs: int,
cudagraph_runtime_mode: int, uniform_token_count: int | None,
dp_size: int, dp_size: int,
dp_rank: int, dp_rank: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[BatchExecutionDescriptor, torch.Tensor | None]:
assert dp_size > 1 """
# Use CPU group to avoid CPU-GPU synchronization. Coordinates the batch descriptor and DP padding across all ranks.
Returns (synced_batch_desc, num_tokens_across_dp).
"""
assert dp_size > 1, "DP size must be greater than 1"
group = get_dp_group().cpu_group group = get_dp_group().cpu_group
tensor = torch.zeros(3, dp_size, dtype=torch.int32, device="cpu") tensor = torch.zeros(3, dp_size, dtype=torch.int32, device="cpu")
tensor[0][dp_rank] = num_tokens tensor[0][dp_rank] = num_tokens
tensor[1][dp_rank] = cudagraph_size tensor[1][dp_rank] = desired_batch_desc.cg_mode.value
tensor[2][dp_rank] = cudagraph_runtime_mode tensor[2][dp_rank] = uniform_token_count or 0 # (0 means None)
dist.all_reduce(tensor, group=group) dist.all_reduce(tensor, group=group)
return tensor[0], tensor[1], tensor[2]
num_tokens_across_dp = tensor[0]
cg_mode_across_dp = tensor[1]
uniform_token_counts_across_dp = tensor[2]
def get_cudagraph_and_dp_padding( if torch.all(num_tokens_across_dp == 0).item():
num_tokens: int, synced_desc = BatchExecutionDescriptor(
cudagraph_size: int | None, cg_mode=CUDAGraphMode.NONE, num_tokens=0, num_reqs=0
cudagraph_runtime_mode: int, )
dp_size: int, return synced_desc, None
dp_rank: int,
) -> tuple[int, torch.Tensor | None, int]:
if dp_size == 1:
if cudagraph_size is not None:
return cudagraph_size, None, cudagraph_runtime_mode
else:
return num_tokens, None, cudagraph_runtime_mode
# Convert None to -1 for sync (indicates no cudagraph available) synced_cg_mode = CUDAGraphMode(int(cg_mode_across_dp.min().item()))
if num_tokens == 0:
cudagraph_size = 0
elif cudagraph_size is None:
cudagraph_size = -1
num_tokens_across_dp, cudagraph_size_across_dp, cudagraph_mode_across_dp = ( # If any rank wants to run eager, all ranks run eager
get_batch_metadata_across_dp( if synced_cg_mode == CUDAGraphMode.NONE:
num_tokens, cudagraph_size, cudagraph_runtime_mode, dp_size, dp_rank return BatchExecutionDescriptor(
) cg_mode=CUDAGraphMode.NONE,
num_tokens=num_tokens,
num_reqs=num_reqs,
), num_tokens_across_dp
synced_num_tokens = int(num_tokens_across_dp.max().item())
synced_uniform_token_count = uniform_token_counts_across_dp[0]
# If ranks disagree on the uniform token count, or it is 0 (which encodes None), set it to None
if synced_uniform_token_count == 0 or not torch.all(
uniform_token_counts_across_dp == synced_uniform_token_count
):
synced_uniform_token_count = None
# Dispatch for the final synced values, use num_reqs instead of synced_num_reqs
# so we don't perform request padding for PIECEWISE graphs
synced_desc = cudagraph_manager.dispatch(
num_reqs, synced_num_tokens, synced_uniform_token_count
) )
if torch.all(num_tokens_across_dp == 0).item():
# All ranks have zero tokens to run.
return 0, None, 0
# Synchronize cudagraph_runtime_mode across ranks by taking the minimum. # Update num_tokens_across_dp to reflect padded size.
synced_cudagraph_mode = int(cudagraph_mode_across_dp.min().item()) num_tokens_across_dp[:] = synced_desc.num_tokens
# Check if all ranks have valid cudagraph_size.
all_have_cudagraph = torch.all(cudagraph_size_across_dp != -1).item()
if synced_cudagraph_mode != 0 and all_have_cudagraph: return synced_desc, num_tokens_across_dp
# All ranks use cudagraph. Pad to max cudagraph_size.
max_cudagraph_size = int(cudagraph_size_across_dp.max().item())
num_tokens_across_dp[:] = max_cudagraph_size
return max_cudagraph_size, num_tokens_across_dp, synced_cudagraph_mode
else:
# Fall back to eager mode (no cudagraph).
# Either some rank doesn't have cudagraph size or mode is NONE.
synced_cudagraph_mode = 0
num_tokens_across_dp = torch.clamp(num_tokens_across_dp, min=1)
num_tokens_after_padding = int(num_tokens_across_dp[dp_rank].item())
return num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode
...@@ -37,6 +37,7 @@ class InputBatch: ...@@ -37,6 +37,7 @@ class InputBatch:
# batch_idx -> req_id # batch_idx -> req_id
req_ids: list[str] req_ids: list[str]
num_reqs: int num_reqs: int
num_reqs_after_padding: int
# batch_idx -> req_state_idx # batch_idx -> req_state_idx
idx_mapping: torch.Tensor idx_mapping: torch.Tensor
...@@ -123,6 +124,7 @@ class InputBatch: ...@@ -123,6 +124,7 @@ class InputBatch:
return cls( return cls(
req_ids=req_ids, req_ids=req_ids,
num_reqs=num_reqs, num_reqs=num_reqs,
num_reqs_after_padding=num_reqs,
idx_mapping=idx_mapping, idx_mapping=idx_mapping,
idx_mapping_np=idx_mapping_np, idx_mapping_np=idx_mapping_np,
expanded_idx_mapping=expanded_idx_mapping, expanded_idx_mapping=expanded_idx_mapping,
...@@ -330,7 +332,8 @@ def combine_sampled_and_draft_tokens( ...@@ -330,7 +332,8 @@ def combine_sampled_and_draft_tokens(
cu_num_logits: torch.Tensor, cu_num_logits: torch.Tensor,
num_logits: int, num_logits: int,
) -> torch.Tensor: ) -> torch.Tensor:
num_reqs = seq_lens.shape[0] # use idx_mapping.shape[0] for actual request count
num_reqs = idx_mapping.shape[0]
num_speculative_steps = draft_tokens.shape[-1] num_speculative_steps = draft_tokens.shape[-1]
logits_indices = torch.empty( logits_indices = torch.empty(
......
...@@ -40,7 +40,6 @@ from vllm.model_executor.model_loader import get_model_loader ...@@ -40,7 +40,6 @@ from vllm.model_executor.model_loader import get_model_loader
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.utils.math_utils import cdiv
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
...@@ -57,8 +56,12 @@ from vllm.v1.worker.gpu.attn_utils import ( ...@@ -57,8 +56,12 @@ from vllm.v1.worker.gpu.attn_utils import (
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager from vllm.v1.worker.gpu.cudagraph_utils import (
from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding BatchExecutionDescriptor,
ModelCudaGraphManager,
get_uniform_token_count,
)
from vllm.v1.worker.gpu.dp_utils import sync_cudagraph_and_dp_padding
from vllm.v1.worker.gpu.input_batch import ( from vllm.v1.worker.gpu.input_batch import (
InputBatch, InputBatch,
InputBuffers, InputBuffers,
...@@ -137,6 +140,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -137,6 +140,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.is_first_pp_rank = True self.is_first_pp_rank = True
self.is_last_pp_rank = True self.is_last_pp_rank = True
# Data parallelism.
self.dp_size = self.parallel_config.data_parallel_size
self.dp_rank = self.parallel_config.data_parallel_rank
# Decode context parallelism. # Decode context parallelism.
self.dcp_size = self.parallel_config.decode_context_parallel_size self.dcp_size = self.parallel_config.decode_context_parallel_size
self.use_dcp = self.dcp_size > 1 self.use_dcp = self.dcp_size > 1
...@@ -193,10 +200,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -193,10 +200,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs) self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)
# CUDA graphs. # CUDA graphs.
self.cudagraph_manager = CudaGraphManager( self.decode_query_len = self.num_speculative_steps + 1
self.cudagraph_manager = ModelCudaGraphManager(
self.vllm_config, self.vllm_config,
self.use_aux_hidden_state_outputs,
self.device, self.device,
self.compilation_config.cudagraph_mode,
decode_query_len=self.decode_query_len,
) )
# Structured outputs worker. # Structured outputs worker.
self.structured_outputs_worker = StructuredOutputsWorker( self.structured_outputs_worker = StructuredOutputsWorker(
...@@ -331,17 +340,18 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -331,17 +340,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
**kwargs, **kwargs,
) -> tuple[torch.Tensor | None, torch.Tensor | None]: ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
# Create a dummy scheduler output. # Create a dummy scheduler output.
if uniform_decode:
# Align tokens to uniform_decode_query_len for cudagraph
# compatibility across DP ranks.
query_len = self.cudagraph_manager.uniform_decode_query_len
num_reqs = min(cdiv(num_tokens, query_len), self.max_num_reqs)
num_tokens = num_reqs * query_len
num_tokens_per_request = [query_len] * num_reqs
else:
num_reqs = min(num_tokens, self.max_num_reqs) num_reqs = min(num_tokens, self.max_num_reqs)
if uniform_decode:
# HACK(lucas): The worker is currently shared between MRV1 and MRV2, and for
# spec-decode with MTP we want the dummy runs to use 1+num_speculative_tokens,
# so we take the max here. This will likely be changed in the worker eventually:
# https://github.com/vllm-project/vllm/pull/35243
num_tokens = max(num_tokens, self.decode_query_len)
num_reqs = num_tokens // self.decode_query_len
assert num_tokens % self.decode_query_len == 0
num_tokens_per_request = [num_tokens // num_reqs] * num_reqs num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
num_tokens_per_request[-1] += num_tokens % num_reqs num_tokens_per_request[-1] += num_tokens % num_reqs
assert sum(num_tokens_per_request) == num_tokens assert sum(num_tokens_per_request) == num_tokens
num_scheduled_tokens = { num_scheduled_tokens = {
f"_dummy_req_{i}": n for i, n in enumerate(num_tokens_per_request) f"_dummy_req_{i}": n for i, n in enumerate(num_tokens_per_request)
...@@ -498,13 +508,14 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -498,13 +508,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
with self.maybe_setup_dummy_loras(self.lora_config): with self.maybe_setup_dummy_loras(self.lora_config):
self.cudagraph_manager.capture( self.cudagraph_manager.capture(
model=self.model, self.model,
model_state=self.model_state, self.model_state,
input_buffers=self.input_buffers, self.input_buffers,
block_tables=self.block_tables, self.block_tables,
attn_groups=self.attn_groups, self.attn_groups,
kv_cache_config=self.kv_cache_config, self.kv_cache_config,
has_lora=self.lora_config is not None, has_lora=self.lora_config is not None,
use_aux_hidden_state_outputs=self.use_aux_hidden_state_outputs,
) )
if self.speculator is not None: if self.speculator is not None:
self.speculator.capture_model() self.speculator.capture_model()
...@@ -592,9 +603,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -592,9 +603,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) )
def prepare_inputs( def prepare_inputs(
self, scheduler_output: SchedulerOutput, num_tokens_after_padding: int self, scheduler_output: SchedulerOutput, batch_desc: BatchExecutionDescriptor
) -> InputBatch: ) -> InputBatch:
num_tokens = scheduler_output.total_num_scheduled_tokens num_tokens = scheduler_output.total_num_scheduled_tokens
num_tokens_after_padding = batch_desc.num_tokens
assert num_tokens > 0 assert num_tokens > 0
num_tokens_per_req = scheduler_output.num_scheduled_tokens num_tokens_per_req = scheduler_output.num_scheduled_tokens
num_reqs = len(num_tokens_per_req) num_reqs = len(num_tokens_per_req)
...@@ -644,6 +656,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -644,6 +656,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) )
# Get query_start_loc. # Get query_start_loc.
# batch_desc.num_reqs is None for PIECEWISE graphs (no request padding needed), so fall back to the actual num_reqs
num_reqs_padded = batch_desc.num_reqs or num_reqs
query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32) query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32)
query_start_loc_np[0] = 0 query_start_loc_np[0] = 0
np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1 : num_reqs + 1]) np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1 : num_reqs + 1])
...@@ -651,8 +665,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -651,8 +665,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Some attention backends like FA3 require query_start_loc to be non-decreasing. # Some attention backends like FA3 require query_start_loc to be non-decreasing.
query_start_loc_np[num_reqs + 1 :] = num_tokens query_start_loc_np[num_reqs + 1 :] = num_tokens
async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc) async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc)
query_start_loc_np = query_start_loc_np[: num_reqs + 1] query_start_loc_np = query_start_loc_np[: num_reqs_padded + 1]
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] query_start_loc = self.input_buffers.query_start_loc[: num_reqs_padded + 1]
# Get prefill tokens if any. # Get prefill tokens if any.
if self.req_states.any_prefills(idx_mapping_np): if self.req_states.any_prefills(idx_mapping_np):
...@@ -674,7 +688,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -674,7 +688,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.input_buffers.positions, self.input_buffers.positions,
self.input_buffers.seq_lens, self.input_buffers.seq_lens,
) )
seq_lens = self.input_buffers.seq_lens[:num_reqs] seq_lens = self.input_buffers.seq_lens[:num_reqs_padded]
dcp_local_seq_lens = None dcp_local_seq_lens = None
if self.use_dcp: if self.use_dcp:
...@@ -687,7 +701,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -687,7 +701,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.dcp_rank, self.dcp_rank,
self.cp_interleave, self.cp_interleave,
) )
dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs] dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs_padded]
# Some input token ids are directly read from the last sampled tokens # Some input token ids are directly read from the last sampled tokens
# and draft tokens. Also, get the logits indices to sample tokens from. # and draft tokens. Also, get the logits indices to sample tokens from.
...@@ -706,6 +720,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -706,6 +720,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return InputBatch( return InputBatch(
req_ids=req_ids, req_ids=req_ids,
num_reqs=num_reqs, num_reqs=num_reqs,
num_reqs_after_padding=num_reqs_padded,
idx_mapping=idx_mapping, idx_mapping=idx_mapping,
idx_mapping_np=idx_mapping_np, idx_mapping_np=idx_mapping_np,
expanded_idx_mapping=expanded_idx_mapping, expanded_idx_mapping=expanded_idx_mapping,
...@@ -729,13 +744,18 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -729,13 +744,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def prepare_attn( def prepare_attn(
self, input_batch: InputBatch self, input_batch: InputBatch
) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]: ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks] # Block tables: num_kv_cache_groups x [num_reqs_padded, max_num_blocks].
block_tables = self.block_tables.gather_block_tables(input_batch.idx_mapping) block_tables = self.block_tables.gather_block_tables(
# Compute slot mappings: [num_kv_cache_groups, num_tokens] input_batch.idx_mapping,
num_reqs_padded=input_batch.num_reqs_after_padding,
)
# Slot mappings: [num_kv_cache_groups, num_tokens_padded].
# Kernel pads beyond num_tokens with PAD_SLOT_ID.
slot_mappings = self.block_tables.compute_slot_mappings( slot_mappings = self.block_tables.compute_slot_mappings(
input_batch.idx_mapping, input_batch.idx_mapping,
input_batch.query_start_loc, input_batch.query_start_loc,
input_batch.positions, input_batch.positions,
num_tokens_padded=input_batch.num_tokens_after_padding,
) )
return block_tables, slot_mappings return block_tables, slot_mappings
...@@ -851,27 +871,29 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -851,27 +871,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
empty_output = self.kv_connector.no_forward(scheduler_output) empty_output = self.kv_connector.no_forward(scheduler_output)
return empty_output return empty_output
# Get local cudagraph mode and size. # Get batch descriptor and sync across DP ranks.
local_cudagraph_mode, local_cudagraph_size = ( num_reqs = len(scheduler_output.num_scheduled_tokens)
self.cudagraph_manager.get_cudagraph_runtime_mode( num_toks = scheduler_output.total_num_scheduled_tokens
num_reqs=len(scheduler_output.num_scheduled_tokens), max_query_len = max(scheduler_output.num_scheduled_tokens.values())
num_tokens=scheduler_output.total_num_scheduled_tokens, uniform_tok_count = get_uniform_token_count(num_reqs, num_toks, max_query_len)
max_query_len=max(scheduler_output.num_scheduled_tokens.values()),
)
)
# DP sync: num_tokens + cudagraph_size + cudagraph_mode batch_desc = self.cudagraph_manager.dispatch(
num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode = ( num_reqs, num_toks, uniform_tok_count
get_cudagraph_and_dp_padding(
scheduler_output.total_num_scheduled_tokens,
local_cudagraph_size,
local_cudagraph_mode.value,
self.parallel_config.data_parallel_size,
self.parallel_config.data_parallel_rank,
) )
num_tokens_across_dp = None
if self.dp_size > 1:
batch_desc, num_tokens_across_dp = sync_cudagraph_and_dp_padding(
self.cudagraph_manager,
batch_desc,
num_toks,
num_reqs,
uniform_tok_count,
self.dp_size,
self.dp_rank,
) )
cudagraph_runtime_mode = CUDAGraphMode(synced_cudagraph_mode)
if num_tokens_after_padding == 0: if batch_desc.num_tokens == 0:
# All DP ranks have zero tokens to run. # All DP ranks have zero tokens to run.
empty_output = self.kv_connector.no_forward(scheduler_output) empty_output = self.kv_connector.no_forward(scheduler_output)
return empty_output return empty_output
...@@ -879,9 +901,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -879,9 +901,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if not dummy_run: if not dummy_run:
# Common case. # Common case.
# Prepare all the inputs and copy to the input buffers. # Prepare all the inputs and copy to the input buffers.
input_batch = self.prepare_inputs( input_batch = self.prepare_inputs(scheduler_output, batch_desc)
scheduler_output, num_tokens_after_padding
)
block_tables, slot_mappings = self.prepare_attn(input_batch) block_tables, slot_mappings = self.prepare_attn(input_batch)
if self.lora_config: if self.lora_config:
...@@ -894,9 +914,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -894,9 +914,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self._set_active_loras(*lora_inputs) self._set_active_loras(*lora_inputs)
else: else:
# No actual tokens to run. A dummy run for DP or memory profiling. # No actual tokens to run. A dummy run for DP or memory profiling.
num_reqs = min(num_tokens_after_padding, self.max_num_reqs)
input_batch = InputBatch.make_dummy( input_batch = InputBatch.make_dummy(
num_reqs, num_tokens_after_padding, self.input_buffers batch_desc.num_reqs or num_reqs,
batch_desc.num_tokens,
self.input_buffers,
) )
if not skip_attn_for_dummy_run: if not skip_attn_for_dummy_run:
block_tables, slot_mappings = self.prepare_dummy_attn(input_batch) block_tables, slot_mappings = self.prepare_dummy_attn(input_batch)
...@@ -948,14 +969,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -948,14 +969,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
model_inputs["intermediate_tensors"] = intermediate_tensors model_inputs["intermediate_tensors"] = intermediate_tensors
# Run model. # Run model.
if cudagraph_runtime_mode == CUDAGraphMode.FULL: if batch_desc.cg_mode == CUDAGraphMode.FULL:
# Use explicit cudagraph replay for FULL mode. # Use explicit cudagraph replay for FULL mode.
# NOTE(woosuk): Here, we don't need to pass the input tensors, # NOTE(woosuk): Here, we don't need to pass the input tensors,
# because they are already copied to the CUDA graph input buffers. # because they are already copied to the CUDA graph input buffers.
self.kv_connector.pre_forward(scheduler_output) self.kv_connector.pre_forward(scheduler_output)
model_output = self.cudagraph_manager.run_fullgraph( model_output = self.cudagraph_manager.run_fullgraph(batch_desc)
input_batch.num_tokens_after_padding
)
if self.use_aux_hidden_state_outputs: if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output hidden_states, aux_hidden_states = model_output
else: else:
...@@ -972,7 +991,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -972,7 +991,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
attn_metadata, attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=input_batch.num_tokens_after_padding, num_tokens=input_batch.num_tokens_after_padding,
cudagraph_runtime_mode=cudagraph_runtime_mode, cudagraph_runtime_mode=batch_desc.cg_mode,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
batch_descriptor=batch_descriptor, batch_descriptor=batch_descriptor,
slot_mapping=slot_mappings_by_layer, slot_mapping=slot_mappings_by_layer,
......
...@@ -142,12 +142,15 @@ class DefaultModelState(ModelState): ...@@ -142,12 +142,15 @@ class DefaultModelState(ModelState):
attn_groups: list[list[AttentionGroup]], attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
) -> dict[str, Any]: ) -> dict[str, Any]:
# Use padded sizes - padding is handled by model_runner.prepare_attn.
num_reqs = input_batch.num_reqs_after_padding
num_tokens = input_batch.num_tokens_after_padding
query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np) query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
max_query_len = input_batch.num_scheduled_tokens.max().item() max_query_len = input_batch.num_scheduled_tokens.max().item()
attn_metadata = build_attn_metadata( attn_metadata = build_attn_metadata(
attn_groups=attn_groups, attn_groups=attn_groups,
num_reqs=input_batch.num_reqs, num_reqs=num_reqs,
num_tokens=input_batch.num_tokens, num_tokens=num_tokens,
query_start_loc_gpu=input_batch.query_start_loc, query_start_loc_gpu=input_batch.query_start_loc,
query_start_loc_cpu=query_start_loc_cpu, query_start_loc_cpu=query_start_loc_cpu,
max_query_len=max_query_len, max_query_len=max_query_len,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable from collections.abc import Callable
from typing import Any
import torch import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode from vllm.config.compilation import CUDAGraphMode
from vllm.model_executor.offloader.base import get_offloader
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cudagraph_utils import ( from vllm.v1.worker.gpu.cudagraph_utils import (
capture_graphs, BatchExecutionDescriptor,
get_cudagraph_sizes, CudaGraphManager,
prepare_inputs_to_capture, prepare_inputs_to_capture,
) )
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.utils import AttentionGroup from vllm.v1.worker.utils import AttentionGroup
class EagleCudaGraphManager: class EagleCudaGraphManager(CudaGraphManager):
def __init__(self, vllm_config: VllmConfig, device: torch.device): """CudaGraphManager for Eagle speculative decoding (FULL mode only)."""
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
self.device = device
self.max_model_len = vllm_config.model_config.max_model_len def __init__(
self.max_num_reqs = self.scheduler_config.max_num_seqs self,
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens vllm_config: VllmConfig,
self.dp_size = vllm_config.parallel_config.data_parallel_size device: torch.device,
self.compilation_config = vllm_config.compilation_config cudagraph_mode: CUDAGraphMode,
assert self.compilation_config is not None draft_tokens: torch.Tensor,
):
# NOTE(woosuk): For Eagle, we only use CUDA graphs for decode. assert not cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE), (
self.cudagraph_mode = self.compilation_config.cudagraph_mode.decode_mode() "EagleCudaGraphManager does not support PIECEWISE mode yet"
# only need to capture uniform decode cudagraph sizes (the 2nd return value)
_, self.cudagraph_sizes = get_cudagraph_sizes(
self.compilation_config.cudagraph_capture_sizes,
self.max_num_reqs,
self.max_num_tokens,
self.cudagraph_mode,
uniform_decode_query_len=1,
uniform_decode_cudagraph=True,
) )
# Eagle always uses uniform decode with query_len=1
self.graphs: dict[int, torch.cuda.CUDAGraph] = {} super().__init__(vllm_config, device, cudagraph_mode, decode_query_len=1)
self.pool = None self.draft_tokens = draft_tokens
if self.cudagraph_mode != CUDAGraphMode.NONE:
# Use a dedicated pool for Eagle to avoid memory overlap with the main
# model's cudagraph. The base class uses a shared global pool, but Eagle's
# internal allocations (e.g., gumbel_sample temporaries) can conflict with
# the main model's allocations when sharing the same pool.
if cudagraph_mode:
self.pool = torch.cuda.graph_pool_handle() self.pool = torch.cuda.graph_pool_handle()
def get_cudagraph_size(self, num_tokens: int) -> int | None: def capture(
return self.cudagraph_sizes.get(num_tokens)
def get_cudagraph_runtime_mode(
self, num_tokens: int
) -> tuple[CUDAGraphMode, int | None]:
cudagraph_size = self.get_cudagraph_size(num_tokens)
if cudagraph_size is None:
cudagraph_mode = CUDAGraphMode.NONE
else:
cudagraph_mode = self.cudagraph_mode
if (
cudagraph_mode == CUDAGraphMode.FULL
and cudagraph_size is not None
and cudagraph_size not in self.graphs
):
# If graph wasn't captured yet, fall back to eager.
# This might happen when the dummy run is called before capture.
cudagraph_mode = CUDAGraphMode.NONE
cudagraph_size = None
return cudagraph_mode, cudagraph_size
def capture_graph(
self, self,
num_tokens: int,
capture_cg_mode: CUDAGraphMode,
generate_fn: Callable, generate_fn: Callable,
model_state: ModelState, model_state: ModelState,
input_buffers: InputBuffers, input_buffers: InputBuffers,
block_tables: BlockTables, block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]], attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
progress_bar_desc: str = "Capturing CUDA graphs",
) -> None: ) -> None:
assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], ( """Capture CUDA graphs for Eagle speculative decoding (FULL mode only)."""
f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}"
def create_forward_fn(
desc: BatchExecutionDescriptor,
) -> Callable[[CUDAGraphMode], None]:
num_tokens = desc.num_tokens
num_reqs = desc.num_reqs or min(num_tokens, self.max_num_reqs)
num_tokens_across_dp = (
torch.full((self.dp_size,), num_tokens, dtype=torch.int32, device="cpu")
if self.dp_size > 1
else None
) )
if capture_cg_mode == CUDAGraphMode.PIECEWISE:
capture_fn = self._capture_piecewise_graph
else:
capture_fn = self._capture_full_graph
num_reqs = min(num_tokens, self.max_num_reqs)
attn_metadata, slot_mappings = prepare_inputs_to_capture( attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs, num_reqs,
num_tokens, num_tokens,
...@@ -104,111 +73,19 @@ class EagleCudaGraphManager: ...@@ -104,111 +73,19 @@ class EagleCudaGraphManager:
attn_groups, attn_groups,
kv_cache_config, kv_cache_config,
) )
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
# Warm up.
generate_fn(
num_reqs,
num_tokens,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
CUDAGraphMode.NONE,
)
# Capture the graph.
capture_fn(
num_reqs=num_reqs,
num_tokens=num_tokens,
generate_fn=generate_fn,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings,
num_tokens_across_dp=num_tokens_across_dp,
)
def _capture_full_graph(
self,
num_reqs: int,
num_tokens: int,
generate_fn: Callable,
attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
num_tokens_across_dp: torch.Tensor,
) -> None:
assert num_tokens not in self.graphs
graph = torch.cuda.CUDAGraph()
# Sync offloader's copy stream before capture.
# Ensure any pre-capture prefetches from offloader are complete.
get_offloader().sync_prev_onload()
with torch.cuda.graph(graph, self.pool):
generate_fn(
num_reqs,
num_tokens,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
CUDAGraphMode.NONE,
)
# Join offloader's copy stream after forward to avoid unjoined
# stream error. The last layer's start_prefetch forks copy_stream,
# but wait_prefetch only happens in the next forward pass.
get_offloader().join_after_forward()
self.graphs[num_tokens] = graph
def _capture_piecewise_graph( return lambda cg_mode: generate_fn(
self,
num_reqs: int,
num_tokens: int,
generate_fn: Callable,
attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
num_tokens_across_dp: torch.Tensor,
) -> None:
generate_fn(
num_reqs, num_reqs,
num_tokens, num_tokens,
attn_metadata, attn_metadata,
slot_mappings, slot_mappings,
num_tokens_across_dp, num_tokens_across_dp,
CUDAGraphMode.PIECEWISE, cg_mode,
) )
@torch.inference_mode() super().capture(create_forward_fn, progress_bar_desc)
def capture(
self,
generate_fn: Callable,
model_state: ModelState,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
) -> None:
if self.cudagraph_mode == CUDAGraphMode.NONE:
return
capture_graphs(
self.cudagraph_sizes,
self.device,
self.capture_graph,
capture_cudagraph_mode=self.cudagraph_mode,
desc=f"Capturing eagle CUDA graphs ({self.cudagraph_mode.name})",
generate_fn=generate_fn,
model_state=model_state,
input_buffers=input_buffers,
block_tables=block_tables,
attn_groups=attn_groups,
kv_cache_config=kv_cache_config,
)
def run_fullgraph(self, num_tokens: int) -> None: def run_fullgraph(self, desc: BatchExecutionDescriptor) -> torch.Tensor:
assert num_tokens in self.graphs """Replay a captured FULL cudagraph and return draft tokens."""
# Sync offloader before replay - needed when transitioning from super().run_fullgraph(desc)
# eager/piecewise to full cudagraph (e.g., prefill → decode). return self.draft_tokens
# The previous eager iteration's start_prefetch may have queued
# H2D copies on copy_stream that the graph's captured events
# cannot see. Without this, replay could overwrite static buffers
# while those copies are still in flight.
get_offloader().sync_prev_onload()
self.graphs[num_tokens].replay()
...@@ -16,7 +16,7 @@ from vllm.v1.worker.gpu.attn_utils import ( ...@@ -16,7 +16,7 @@ from vllm.v1.worker.gpu.attn_utils import (
build_slot_mappings_by_layer, build_slot_mappings_by_layer,
) )
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding from vllm.v1.worker.gpu.dp_utils import sync_cudagraph_and_dp_padding
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
...@@ -75,7 +75,16 @@ class EagleSpeculator: ...@@ -75,7 +75,16 @@ class EagleSpeculator:
device=device, device=device,
) )
self.cudagraph_manager = EagleCudaGraphManager(vllm_config, device) # currently we don't support PIECEWISE for Eagle.
cudagraph_mode = vllm_config.compilation_config.cudagraph_mode
if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL:
cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
else:
cudagraph_mode = CUDAGraphMode.NONE
self.cudagraph_manager = EagleCudaGraphManager(
vllm_config, device, cudagraph_mode, self.draft_tokens
)
def load_model(self, target_model: nn.Module) -> None: def load_model(self, target_model: nn.Module) -> None:
self.model = load_eagle_model(target_model, self.vllm_config) self.model = load_eagle_model(target_model, self.vllm_config)
...@@ -171,7 +180,7 @@ class EagleSpeculator: ...@@ -171,7 +180,7 @@ class EagleSpeculator:
) )
if attn_metadata is not None: if attn_metadata is not None:
self.block_tables.compute_slot_mappings( self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos idx_mapping, query_start_loc, pos, num_tokens_padded
) )
def capture_model(self) -> None: def capture_model(self) -> None:
...@@ -185,6 +194,7 @@ class EagleSpeculator: ...@@ -185,6 +194,7 @@ class EagleSpeculator:
self.block_tables, self.block_tables,
self.attn_groups, self.attn_groups,
self.kv_cache_config, self.kv_cache_config,
progress_bar_desc="Capturing eagle CUDA graphs",
) )
@torch.inference_mode() @torch.inference_mode()
...@@ -251,6 +261,7 @@ class EagleSpeculator: ...@@ -251,6 +261,7 @@ class EagleSpeculator:
logits = self.model.compute_logits(sample_hidden_states) logits = self.model.compute_logits(sample_hidden_states)
num_reqs = input_batch.num_reqs num_reqs = input_batch.num_reqs
num_reqs_padded = input_batch.num_reqs_after_padding
# NOTE(woosuk): For draft sampling, we only consider the temperature # NOTE(woosuk): For draft sampling, we only consider the temperature
# and ignore the other sampling parameters such as top_k and top_p, # and ignore the other sampling parameters such as top_k and top_p,
# for simplicity and performance. # for simplicity and performance.
...@@ -292,48 +303,52 @@ class EagleSpeculator: ...@@ -292,48 +303,52 @@ class EagleSpeculator:
self.max_num_reqs, self.max_num_reqs,
) )
if not (dummy_run and skip_attn_for_dummy_run): # Get batch descriptor and sync across DP ranks.
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] # Eagle uses FULL-only mode, dispatch with uniform_token_count=1 for decode
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)
cudagraph_mode, cudagraph_size = ( batch_desc = self.cudagraph_manager.dispatch(num_reqs, num_reqs, 1)
self.cudagraph_manager.get_cudagraph_runtime_mode(num_reqs) num_tokens_across_dp = None
)
num_tokens_padded, num_tokens_across_dp, synced_cudagraph_mode = ( if self.dp_size > 1:
get_cudagraph_and_dp_padding( batch_desc, num_tokens_across_dp = sync_cudagraph_and_dp_padding(
self.cudagraph_manager,
batch_desc,
num_reqs, num_reqs,
cudagraph_size, num_reqs,
cudagraph_mode.value, 1, # uniform_token_count
self.dp_size, self.dp_size,
self.dp_rank, self.dp_rank,
) )
if not (dummy_run and skip_attn_for_dummy_run):
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos, batch_desc.num_tokens
) )
cudagraph_mode = CUDAGraphMode(synced_cudagraph_mode)
if cudagraph_mode == CUDAGraphMode.FULL: if batch_desc.cg_mode == CUDAGraphMode.FULL:
# Run full CUDA graph. return self.cudagraph_manager.run_fullgraph(batch_desc)[:num_reqs]
self.cudagraph_manager.run_fullgraph(num_tokens_padded)
return self.draft_tokens[:num_reqs]
# Run eager or piecewise CUDA graph. # Run eager or piecewise CUDA graph.
attn_metadata_updated = None attn_metadata_updated = None
slot_mappings_updated = None slot_mappings_updated = None
if not (dummy_run and skip_attn_for_dummy_run): if not (dummy_run and skip_attn_for_dummy_run):
query_start_loc_cpu = torch.arange( query_start_loc_cpu = torch.arange(
num_reqs + 1, dtype=torch.int32, device="cpu" num_reqs_padded + 1, dtype=torch.int32, device="cpu"
) )
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables] block_tables = [
x[:num_reqs_padded] for x in self.block_tables.input_block_tables
]
# FIXME(woosuk): This is UNSAFE!! # FIXME(woosuk): This is UNSAFE!!
attn_metadata_updated = build_attn_metadata( attn_metadata_updated = build_attn_metadata(
attn_groups=self.attn_groups, attn_groups=self.attn_groups,
num_reqs=num_reqs, num_reqs=num_reqs_padded,
num_tokens=num_reqs, num_tokens=num_reqs_padded,
query_start_loc_gpu=query_start_loc, query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu, query_start_loc_cpu=query_start_loc_cpu,
max_query_len=1, max_query_len=1,
seq_lens=self.input_buffers.seq_lens[:num_reqs], seq_lens=self.input_buffers.seq_lens[:num_reqs_padded],
max_seq_len=self.max_model_len, max_seq_len=self.max_model_len,
block_tables=block_tables, block_tables=block_tables,
slot_mappings=slot_mappings, slot_mappings=slot_mappings,
...@@ -345,11 +360,11 @@ class EagleSpeculator: ...@@ -345,11 +360,11 @@ class EagleSpeculator:
self.generate_draft( self.generate_draft(
num_reqs, num_reqs,
num_tokens_padded, batch_desc.num_tokens,
attn_metadata_updated, attn_metadata_updated,
slot_mappings_updated, slot_mappings_updated,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_mode, cudagraph_runtime_mode=batch_desc.cg_mode,
) )
return self.draft_tokens[:num_reqs] return self.draft_tokens[:num_reqs]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment