Unverified commit 483463f7, authored by Lucas Wilkinson and committed by GitHub
Browse files

[MRV2] Extensible CG dispatch rework (#35959)


Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
parent 4e571ce6
...@@ -97,6 +97,9 @@ class CUDAGraphMode(enum.Enum): ...@@ -97,6 +97,9 @@ class CUDAGraphMode(enum.Enum):
def __str__(self) -> str: def __str__(self) -> str:
return self.name return self.name
def __bool__(self) -> bool:
return self != CUDAGraphMode.NONE
@config @config
class PassConfig: class PassConfig:
......
...@@ -104,19 +104,24 @@ class BlockTables: ...@@ -104,19 +104,24 @@ class BlockTables:
self.num_blocks.copy_to_uva() self.num_blocks.copy_to_uva()
def gather_block_tables( def gather_block_tables(
self, idx_mapping: torch.Tensor self,
idx_mapping: torch.Tensor,
num_reqs_padded: int,
) -> tuple[torch.Tensor, ...]: ) -> tuple[torch.Tensor, ...]:
num_reqs = idx_mapping.shape[0] num_reqs = idx_mapping.shape[0]
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)]( # Launch kernel with num_reqs_padded to fuse zeroing of padded rows.
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs_padded)](
idx_mapping, idx_mapping,
self.block_table_ptrs, self.block_table_ptrs,
self.input_block_table_ptrs, self.input_block_table_ptrs,
self.block_table_strides, self.block_table_strides,
self.num_blocks.gpu, self.num_blocks.gpu,
self.num_blocks.gpu.stride(0), self.num_blocks.gpu.stride(0),
num_reqs,
self.input_block_tables[0].shape[1], # max_num_blocks
BLOCK_SIZE=1024, # type: ignore BLOCK_SIZE=1024, # type: ignore
) )
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables) return tuple(bt[:num_reqs_padded] for bt in self.input_block_tables)
def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]: def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
# NOTE(woosuk): The output may be used for CUDA graph capture. # NOTE(woosuk): The output may be used for CUDA graph capture.
...@@ -130,6 +135,7 @@ class BlockTables: ...@@ -130,6 +135,7 @@ class BlockTables:
idx_mapping: torch.Tensor, idx_mapping: torch.Tensor,
query_start_loc: torch.Tensor, query_start_loc: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
num_tokens_padded: int,
) -> torch.Tensor: ) -> torch.Tensor:
num_reqs = idx_mapping.shape[0] num_reqs = idx_mapping.shape[0]
num_tokens = positions.shape[0] num_tokens = positions.shape[0]
...@@ -151,7 +157,7 @@ class BlockTables: ...@@ -151,7 +157,7 @@ class BlockTables:
PAD_ID=PAD_SLOT_ID, PAD_ID=PAD_SLOT_ID,
TRITON_BLOCK_SIZE=1024, # type: ignore TRITON_BLOCK_SIZE=1024, # type: ignore
) )
return self.slot_mappings[:, :num_tokens] return self.slot_mappings[:, :num_tokens_padded]
def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor: def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
# Fill the entire slot_mappings tensor, not just the first `num_tokens` entries. # Fill the entire slot_mappings tensor, not just the first `num_tokens` entries.
...@@ -173,21 +179,31 @@ def _gather_block_tables_kernel( ...@@ -173,21 +179,31 @@ def _gather_block_tables_kernel(
block_table_strides, # [num_kv_cache_groups] block_table_strides, # [num_kv_cache_groups]
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs] num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
num_blocks_stride, num_blocks_stride,
num_reqs, # actual number of requests (for padding)
max_num_blocks, # stride for zeroing padded rows
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
# kv cache group id # kv cache group id
group_id = tl.program_id(0) group_id = tl.program_id(0)
batch_idx = tl.program_id(1) batch_idx = tl.program_id(1)
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
stride = tl.load(block_table_strides + group_id)
dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
dst_row_ptr = dst_block_table_ptr + batch_idx * stride
if batch_idx >= num_reqs:
# Zero out padded rows.
for i in tl.range(0, max_num_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
tl.store(dst_row_ptr + offset, 0, mask=offset < max_num_blocks)
return
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
num_blocks = tl.load(group_num_blocks_ptr + req_idx) num_blocks = tl.load(group_num_blocks_ptr + req_idx)
stride = tl.load(block_table_strides + group_id)
src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32) src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32)
src_row_ptr = src_block_table_ptr + req_idx * stride src_row_ptr = src_block_table_ptr + req_idx * stride
dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
dst_row_ptr = dst_block_table_ptr + batch_idx * stride
for i in tl.range(0, num_blocks, BLOCK_SIZE): for i in tl.range(0, num_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE) offset = i + tl.arange(0, BLOCK_SIZE)
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import get_dp_group from vllm.distributed.parallel_state import get_dp_group
from vllm.v1.worker.gpu.cudagraph_utils import (
BatchExecutionDescriptor,
CudaGraphManager,
)
def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | None: def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | None:
...@@ -12,66 +19,63 @@ def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | N ...@@ -12,66 +19,63 @@ def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | N
return torch.full((dp_size,), num_tokens, dtype=torch.int32, device="cpu") return torch.full((dp_size,), num_tokens, dtype=torch.int32, device="cpu")
def get_batch_metadata_across_dp( def sync_cudagraph_and_dp_padding(
cudagraph_manager: CudaGraphManager,
desired_batch_desc: BatchExecutionDescriptor,
num_tokens: int, num_tokens: int,
cudagraph_size: int, num_reqs: int,
cudagraph_runtime_mode: int, uniform_token_count: int | None,
dp_size: int, dp_size: int,
dp_rank: int, dp_rank: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[BatchExecutionDescriptor, torch.Tensor | None]:
assert dp_size > 1 """
# Use CPU group to avoid CPU-GPU synchronization. Coordinates the batch descriptor and DP padding across all ranks.
Returns (synced_batch_desc, num_tokens_across_dp).
"""
assert dp_size > 1, "DP size must be greater than 1"
group = get_dp_group().cpu_group group = get_dp_group().cpu_group
tensor = torch.zeros(3, dp_size, dtype=torch.int32, device="cpu") tensor = torch.zeros(3, dp_size, dtype=torch.int32, device="cpu")
tensor[0][dp_rank] = num_tokens tensor[0][dp_rank] = num_tokens
tensor[1][dp_rank] = cudagraph_size tensor[1][dp_rank] = desired_batch_desc.cg_mode.value
tensor[2][dp_rank] = cudagraph_runtime_mode tensor[2][dp_rank] = uniform_token_count or 0 # (0 means None)
dist.all_reduce(tensor, group=group) dist.all_reduce(tensor, group=group)
return tensor[0], tensor[1], tensor[2]
num_tokens_across_dp = tensor[0]
cg_mode_across_dp = tensor[1]
uniform_token_counts_across_dp = tensor[2]
def get_cudagraph_and_dp_padding( if torch.all(num_tokens_across_dp == 0).item():
num_tokens: int, synced_desc = BatchExecutionDescriptor(
cudagraph_size: int | None, cg_mode=CUDAGraphMode.NONE, num_tokens=0, num_reqs=0
cudagraph_runtime_mode: int, )
dp_size: int, return synced_desc, None
dp_rank: int,
) -> tuple[int, torch.Tensor | None, int]:
if dp_size == 1:
if cudagraph_size is not None:
return cudagraph_size, None, cudagraph_runtime_mode
else:
return num_tokens, None, cudagraph_runtime_mode
# Convert None to -1 for sync (indicates no cudagraph available) synced_cg_mode = CUDAGraphMode(int(cg_mode_across_dp.min().item()))
if num_tokens == 0:
cudagraph_size = 0
elif cudagraph_size is None:
cudagraph_size = -1
num_tokens_across_dp, cudagraph_size_across_dp, cudagraph_mode_across_dp = ( # If any rank wants to run eager, all ranks run eager
get_batch_metadata_across_dp( if synced_cg_mode == CUDAGraphMode.NONE:
num_tokens, cudagraph_size, cudagraph_runtime_mode, dp_size, dp_rank return BatchExecutionDescriptor(
) cg_mode=CUDAGraphMode.NONE,
num_tokens=num_tokens,
num_reqs=num_reqs,
), num_tokens_across_dp
synced_num_tokens = int(num_tokens_across_dp.max().item())
synced_uniform_token_count = uniform_token_counts_across_dp[0]
# If ranks disagree on the uniform token count, or it is 0 (meaning None), set it to None
if synced_uniform_token_count == 0 or not torch.all(
uniform_token_counts_across_dp == synced_uniform_token_count
):
synced_uniform_token_count = None
# Dispatch for the final synced values, use num_reqs instead of synced_num_reqs
# so we don't perform request padding for PIECEWISE graphs
synced_desc = cudagraph_manager.dispatch(
num_reqs, synced_num_tokens, synced_uniform_token_count
) )
if torch.all(num_tokens_across_dp == 0).item():
# All ranks have zero tokens to run.
return 0, None, 0
# Synchronize cudagraph_runtime_mode across ranks by taking the minimum. # Update num_tokens_across_dp to reflect padded size.
synced_cudagraph_mode = int(cudagraph_mode_across_dp.min().item()) num_tokens_across_dp[:] = synced_desc.num_tokens
# Check if all ranks have valid cudagraph_size.
all_have_cudagraph = torch.all(cudagraph_size_across_dp != -1).item()
if synced_cudagraph_mode != 0 and all_have_cudagraph: return synced_desc, num_tokens_across_dp
# All ranks use cudagraph. Pad to max cudagraph_size.
max_cudagraph_size = int(cudagraph_size_across_dp.max().item())
num_tokens_across_dp[:] = max_cudagraph_size
return max_cudagraph_size, num_tokens_across_dp, synced_cudagraph_mode
else:
# Fall back to eager mode (no cudagraph).
# Either some rank doesn't have cudagraph size or mode is NONE.
synced_cudagraph_mode = 0
num_tokens_across_dp = torch.clamp(num_tokens_across_dp, min=1)
num_tokens_after_padding = int(num_tokens_across_dp[dp_rank].item())
return num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode
...@@ -37,6 +37,7 @@ class InputBatch: ...@@ -37,6 +37,7 @@ class InputBatch:
# batch_idx -> req_id # batch_idx -> req_id
req_ids: list[str] req_ids: list[str]
num_reqs: int num_reqs: int
num_reqs_after_padding: int
# batch_idx -> req_state_idx # batch_idx -> req_state_idx
idx_mapping: torch.Tensor idx_mapping: torch.Tensor
...@@ -123,6 +124,7 @@ class InputBatch: ...@@ -123,6 +124,7 @@ class InputBatch:
return cls( return cls(
req_ids=req_ids, req_ids=req_ids,
num_reqs=num_reqs, num_reqs=num_reqs,
num_reqs_after_padding=num_reqs,
idx_mapping=idx_mapping, idx_mapping=idx_mapping,
idx_mapping_np=idx_mapping_np, idx_mapping_np=idx_mapping_np,
expanded_idx_mapping=expanded_idx_mapping, expanded_idx_mapping=expanded_idx_mapping,
...@@ -330,7 +332,8 @@ def combine_sampled_and_draft_tokens( ...@@ -330,7 +332,8 @@ def combine_sampled_and_draft_tokens(
cu_num_logits: torch.Tensor, cu_num_logits: torch.Tensor,
num_logits: int, num_logits: int,
) -> torch.Tensor: ) -> torch.Tensor:
num_reqs = seq_lens.shape[0] # use idx_mapping.shape[0] for actual request count
num_reqs = idx_mapping.shape[0]
num_speculative_steps = draft_tokens.shape[-1] num_speculative_steps = draft_tokens.shape[-1]
logits_indices = torch.empty( logits_indices = torch.empty(
......
...@@ -40,7 +40,6 @@ from vllm.model_executor.model_loader import get_model_loader ...@@ -40,7 +40,6 @@ from vllm.model_executor.model_loader import get_model_loader
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.utils.math_utils import cdiv
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
...@@ -57,8 +56,12 @@ from vllm.v1.worker.gpu.attn_utils import ( ...@@ -57,8 +56,12 @@ from vllm.v1.worker.gpu.attn_utils import (
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager from vllm.v1.worker.gpu.cudagraph_utils import (
from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding BatchExecutionDescriptor,
ModelCudaGraphManager,
get_uniform_token_count,
)
from vllm.v1.worker.gpu.dp_utils import sync_cudagraph_and_dp_padding
from vllm.v1.worker.gpu.input_batch import ( from vllm.v1.worker.gpu.input_batch import (
InputBatch, InputBatch,
InputBuffers, InputBuffers,
...@@ -137,6 +140,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -137,6 +140,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.is_first_pp_rank = True self.is_first_pp_rank = True
self.is_last_pp_rank = True self.is_last_pp_rank = True
# Data parallelism.
self.dp_size = self.parallel_config.data_parallel_size
self.dp_rank = self.parallel_config.data_parallel_rank
# Decode context parallelism. # Decode context parallelism.
self.dcp_size = self.parallel_config.decode_context_parallel_size self.dcp_size = self.parallel_config.decode_context_parallel_size
self.use_dcp = self.dcp_size > 1 self.use_dcp = self.dcp_size > 1
...@@ -193,10 +200,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -193,10 +200,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs) self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)
# CUDA graphs. # CUDA graphs.
self.cudagraph_manager = CudaGraphManager( self.decode_query_len = self.num_speculative_steps + 1
self.cudagraph_manager = ModelCudaGraphManager(
self.vllm_config, self.vllm_config,
self.use_aux_hidden_state_outputs,
self.device, self.device,
self.compilation_config.cudagraph_mode,
decode_query_len=self.decode_query_len,
) )
# Structured outputs worker. # Structured outputs worker.
self.structured_outputs_worker = StructuredOutputsWorker( self.structured_outputs_worker = StructuredOutputsWorker(
...@@ -331,17 +340,18 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -331,17 +340,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
**kwargs, **kwargs,
) -> tuple[torch.Tensor | None, torch.Tensor | None]: ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
# Create a dummy scheduler output. # Create a dummy scheduler output.
num_reqs = min(num_tokens, self.max_num_reqs)
if uniform_decode: if uniform_decode:
# Align tokens to uniform_decode_query_len for cudagraph # HACK(lucas): for now since the worker is shared between MRV1 and MRV2,
# compatibility across DP ranks. # and for spec-decode with MTP we want to make sure the dummy runs use
query_len = self.cudagraph_manager.uniform_decode_query_len # 1+num_speculative_tokens we use max here, this will likely be eventually
num_reqs = min(cdiv(num_tokens, query_len), self.max_num_reqs) # changed in the worker: https://github.com/vllm-project/vllm/pull/35243
num_tokens = num_reqs * query_len num_tokens = max(num_tokens, self.decode_query_len)
num_tokens_per_request = [query_len] * num_reqs num_reqs = num_tokens // self.decode_query_len
else: assert num_tokens % self.decode_query_len == 0
num_reqs = min(num_tokens, self.max_num_reqs) num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
num_tokens_per_request = [num_tokens // num_reqs] * num_reqs num_tokens_per_request[-1] += num_tokens % num_reqs
num_tokens_per_request[-1] += num_tokens % num_reqs
assert sum(num_tokens_per_request) == num_tokens assert sum(num_tokens_per_request) == num_tokens
num_scheduled_tokens = { num_scheduled_tokens = {
f"_dummy_req_{i}": n for i, n in enumerate(num_tokens_per_request) f"_dummy_req_{i}": n for i, n in enumerate(num_tokens_per_request)
...@@ -498,13 +508,14 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -498,13 +508,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
with self.maybe_setup_dummy_loras(self.lora_config): with self.maybe_setup_dummy_loras(self.lora_config):
self.cudagraph_manager.capture( self.cudagraph_manager.capture(
model=self.model, self.model,
model_state=self.model_state, self.model_state,
input_buffers=self.input_buffers, self.input_buffers,
block_tables=self.block_tables, self.block_tables,
attn_groups=self.attn_groups, self.attn_groups,
kv_cache_config=self.kv_cache_config, self.kv_cache_config,
has_lora=self.lora_config is not None, has_lora=self.lora_config is not None,
use_aux_hidden_state_outputs=self.use_aux_hidden_state_outputs,
) )
if self.speculator is not None: if self.speculator is not None:
self.speculator.capture_model() self.speculator.capture_model()
...@@ -592,9 +603,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -592,9 +603,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) )
def prepare_inputs( def prepare_inputs(
self, scheduler_output: SchedulerOutput, num_tokens_after_padding: int self, scheduler_output: SchedulerOutput, batch_desc: BatchExecutionDescriptor
) -> InputBatch: ) -> InputBatch:
num_tokens = scheduler_output.total_num_scheduled_tokens num_tokens = scheduler_output.total_num_scheduled_tokens
num_tokens_after_padding = batch_desc.num_tokens
assert num_tokens > 0 assert num_tokens > 0
num_tokens_per_req = scheduler_output.num_scheduled_tokens num_tokens_per_req = scheduler_output.num_scheduled_tokens
num_reqs = len(num_tokens_per_req) num_reqs = len(num_tokens_per_req)
...@@ -644,6 +656,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -644,6 +656,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) )
# Get query_start_loc. # Get query_start_loc.
# batch_desc.num_reqs is None for PIECEWISE graphs (no request padding needed), so fall back to num_reqs
num_reqs_padded = batch_desc.num_reqs or num_reqs
query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32) query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32)
query_start_loc_np[0] = 0 query_start_loc_np[0] = 0
np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1 : num_reqs + 1]) np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1 : num_reqs + 1])
...@@ -651,8 +665,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -651,8 +665,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Some attention backends like FA3 require query_start_loc to be non-decreasing. # Some attention backends like FA3 require query_start_loc to be non-decreasing.
query_start_loc_np[num_reqs + 1 :] = num_tokens query_start_loc_np[num_reqs + 1 :] = num_tokens
async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc) async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc)
query_start_loc_np = query_start_loc_np[: num_reqs + 1] query_start_loc_np = query_start_loc_np[: num_reqs_padded + 1]
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] query_start_loc = self.input_buffers.query_start_loc[: num_reqs_padded + 1]
# Get prefill tokens if any. # Get prefill tokens if any.
if self.req_states.any_prefills(idx_mapping_np): if self.req_states.any_prefills(idx_mapping_np):
...@@ -674,7 +688,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -674,7 +688,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.input_buffers.positions, self.input_buffers.positions,
self.input_buffers.seq_lens, self.input_buffers.seq_lens,
) )
seq_lens = self.input_buffers.seq_lens[:num_reqs] seq_lens = self.input_buffers.seq_lens[:num_reqs_padded]
dcp_local_seq_lens = None dcp_local_seq_lens = None
if self.use_dcp: if self.use_dcp:
...@@ -687,7 +701,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -687,7 +701,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.dcp_rank, self.dcp_rank,
self.cp_interleave, self.cp_interleave,
) )
dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs] dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs_padded]
# Some input token ids are directly read from the last sampled tokens # Some input token ids are directly read from the last sampled tokens
# and draft tokens. Also, get the logits indices to sample tokens from. # and draft tokens. Also, get the logits indices to sample tokens from.
...@@ -706,6 +720,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -706,6 +720,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return InputBatch( return InputBatch(
req_ids=req_ids, req_ids=req_ids,
num_reqs=num_reqs, num_reqs=num_reqs,
num_reqs_after_padding=num_reqs_padded,
idx_mapping=idx_mapping, idx_mapping=idx_mapping,
idx_mapping_np=idx_mapping_np, idx_mapping_np=idx_mapping_np,
expanded_idx_mapping=expanded_idx_mapping, expanded_idx_mapping=expanded_idx_mapping,
...@@ -729,13 +744,18 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -729,13 +744,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def prepare_attn( def prepare_attn(
self, input_batch: InputBatch self, input_batch: InputBatch
) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]: ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks] # Block tables: num_kv_cache_groups x [num_reqs_padded, max_num_blocks].
block_tables = self.block_tables.gather_block_tables(input_batch.idx_mapping) block_tables = self.block_tables.gather_block_tables(
# Compute slot mappings: [num_kv_cache_groups, num_tokens] input_batch.idx_mapping,
num_reqs_padded=input_batch.num_reqs_after_padding,
)
# Slot mappings: [num_kv_cache_groups, num_tokens_padded].
# Kernel pads beyond num_tokens with PAD_SLOT_ID.
slot_mappings = self.block_tables.compute_slot_mappings( slot_mappings = self.block_tables.compute_slot_mappings(
input_batch.idx_mapping, input_batch.idx_mapping,
input_batch.query_start_loc, input_batch.query_start_loc,
input_batch.positions, input_batch.positions,
num_tokens_padded=input_batch.num_tokens_after_padding,
) )
return block_tables, slot_mappings return block_tables, slot_mappings
...@@ -851,27 +871,29 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -851,27 +871,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
empty_output = self.kv_connector.no_forward(scheduler_output) empty_output = self.kv_connector.no_forward(scheduler_output)
return empty_output return empty_output
# Get local cudagraph mode and size. # Get batch descriptor and sync across DP ranks.
local_cudagraph_mode, local_cudagraph_size = ( num_reqs = len(scheduler_output.num_scheduled_tokens)
self.cudagraph_manager.get_cudagraph_runtime_mode( num_toks = scheduler_output.total_num_scheduled_tokens
num_reqs=len(scheduler_output.num_scheduled_tokens), max_query_len = max(scheduler_output.num_scheduled_tokens.values())
num_tokens=scheduler_output.total_num_scheduled_tokens, uniform_tok_count = get_uniform_token_count(num_reqs, num_toks, max_query_len)
max_query_len=max(scheduler_output.num_scheduled_tokens.values()),
) batch_desc = self.cudagraph_manager.dispatch(
num_reqs, num_toks, uniform_tok_count
) )
num_tokens_across_dp = None
# DP sync: num_tokens + cudagraph_size + cudagraph_mode if self.dp_size > 1:
num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode = ( batch_desc, num_tokens_across_dp = sync_cudagraph_and_dp_padding(
get_cudagraph_and_dp_padding( self.cudagraph_manager,
scheduler_output.total_num_scheduled_tokens, batch_desc,
local_cudagraph_size, num_toks,
local_cudagraph_mode.value, num_reqs,
self.parallel_config.data_parallel_size, uniform_tok_count,
self.parallel_config.data_parallel_rank, self.dp_size,
self.dp_rank,
) )
)
cudagraph_runtime_mode = CUDAGraphMode(synced_cudagraph_mode) if batch_desc.num_tokens == 0:
if num_tokens_after_padding == 0:
# All DP ranks have zero tokens to run. # All DP ranks have zero tokens to run.
empty_output = self.kv_connector.no_forward(scheduler_output) empty_output = self.kv_connector.no_forward(scheduler_output)
return empty_output return empty_output
...@@ -879,9 +901,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -879,9 +901,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if not dummy_run: if not dummy_run:
# Common case. # Common case.
# Prepare all the inputs and copy to the input buffers. # Prepare all the inputs and copy to the input buffers.
input_batch = self.prepare_inputs( input_batch = self.prepare_inputs(scheduler_output, batch_desc)
scheduler_output, num_tokens_after_padding
)
block_tables, slot_mappings = self.prepare_attn(input_batch) block_tables, slot_mappings = self.prepare_attn(input_batch)
if self.lora_config: if self.lora_config:
...@@ -894,9 +914,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -894,9 +914,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self._set_active_loras(*lora_inputs) self._set_active_loras(*lora_inputs)
else: else:
# No actual tokens to run. A dummy run for DP or memory profiling. # No actual tokens to run. A dummy run for DP or memory profiling.
num_reqs = min(num_tokens_after_padding, self.max_num_reqs)
input_batch = InputBatch.make_dummy( input_batch = InputBatch.make_dummy(
num_reqs, num_tokens_after_padding, self.input_buffers batch_desc.num_reqs or num_reqs,
batch_desc.num_tokens,
self.input_buffers,
) )
if not skip_attn_for_dummy_run: if not skip_attn_for_dummy_run:
block_tables, slot_mappings = self.prepare_dummy_attn(input_batch) block_tables, slot_mappings = self.prepare_dummy_attn(input_batch)
...@@ -948,14 +969,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -948,14 +969,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
model_inputs["intermediate_tensors"] = intermediate_tensors model_inputs["intermediate_tensors"] = intermediate_tensors
# Run model. # Run model.
if cudagraph_runtime_mode == CUDAGraphMode.FULL: if batch_desc.cg_mode == CUDAGraphMode.FULL:
# Use explicit cudagraph replay for FULL mode. # Use explicit cudagraph replay for FULL mode.
# NOTE(woosuk): Here, we don't need to pass the input tensors, # NOTE(woosuk): Here, we don't need to pass the input tensors,
# because they are already copied to the CUDA graph input buffers. # because they are already copied to the CUDA graph input buffers.
self.kv_connector.pre_forward(scheduler_output) self.kv_connector.pre_forward(scheduler_output)
model_output = self.cudagraph_manager.run_fullgraph( model_output = self.cudagraph_manager.run_fullgraph(batch_desc)
input_batch.num_tokens_after_padding
)
if self.use_aux_hidden_state_outputs: if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output hidden_states, aux_hidden_states = model_output
else: else:
...@@ -972,7 +991,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -972,7 +991,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
attn_metadata, attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=input_batch.num_tokens_after_padding, num_tokens=input_batch.num_tokens_after_padding,
cudagraph_runtime_mode=cudagraph_runtime_mode, cudagraph_runtime_mode=batch_desc.cg_mode,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
batch_descriptor=batch_descriptor, batch_descriptor=batch_descriptor,
slot_mapping=slot_mappings_by_layer, slot_mapping=slot_mappings_by_layer,
......
...@@ -142,12 +142,15 @@ class DefaultModelState(ModelState): ...@@ -142,12 +142,15 @@ class DefaultModelState(ModelState):
attn_groups: list[list[AttentionGroup]], attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
) -> dict[str, Any]: ) -> dict[str, Any]:
# Use padded sizes - padding is handled by model_runner.prepare_attn.
num_reqs = input_batch.num_reqs_after_padding
num_tokens = input_batch.num_tokens_after_padding
query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np) query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
max_query_len = input_batch.num_scheduled_tokens.max().item() max_query_len = input_batch.num_scheduled_tokens.max().item()
attn_metadata = build_attn_metadata( attn_metadata = build_attn_metadata(
attn_groups=attn_groups, attn_groups=attn_groups,
num_reqs=input_batch.num_reqs, num_reqs=num_reqs,
num_tokens=input_batch.num_tokens, num_tokens=num_tokens,
query_start_loc_gpu=input_batch.query_start_loc, query_start_loc_gpu=input_batch.query_start_loc,
query_start_loc_cpu=query_start_loc_cpu, query_start_loc_cpu=query_start_loc_cpu,
max_query_len=max_query_len, max_query_len=max_query_len,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable from collections.abc import Callable
from typing import Any
import torch import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode from vllm.config.compilation import CUDAGraphMode
from vllm.model_executor.offloader.base import get_offloader
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cudagraph_utils import ( from vllm.v1.worker.gpu.cudagraph_utils import (
capture_graphs, BatchExecutionDescriptor,
get_cudagraph_sizes, CudaGraphManager,
prepare_inputs_to_capture, prepare_inputs_to_capture,
) )
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.utils import AttentionGroup from vllm.v1.worker.utils import AttentionGroup
class EagleCudaGraphManager: class EagleCudaGraphManager(CudaGraphManager):
def __init__(self, vllm_config: VllmConfig, device: torch.device): """CudaGraphManager for Eagle speculative decoding (FULL mode only)."""
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
self.device = device
self.max_model_len = vllm_config.model_config.max_model_len def __init__(
self.max_num_reqs = self.scheduler_config.max_num_seqs self,
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens vllm_config: VllmConfig,
self.dp_size = vllm_config.parallel_config.data_parallel_size device: torch.device,
self.compilation_config = vllm_config.compilation_config cudagraph_mode: CUDAGraphMode,
assert self.compilation_config is not None draft_tokens: torch.Tensor,
):
# NOTE(woosuk): For Eagle, we only use CUDA graphs for decode. assert not cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE), (
self.cudagraph_mode = self.compilation_config.cudagraph_mode.decode_mode() "EagleCudaGraphManager does not support PIECEWISE mode yet"
# only need to capture uniform decode cudagraph sizes (the 2nd return value)
_, self.cudagraph_sizes = get_cudagraph_sizes(
self.compilation_config.cudagraph_capture_sizes,
self.max_num_reqs,
self.max_num_tokens,
self.cudagraph_mode,
uniform_decode_query_len=1,
uniform_decode_cudagraph=True,
) )
# Eagle always uses uniform decode with query_len=1
self.graphs: dict[int, torch.cuda.CUDAGraph] = {} super().__init__(vllm_config, device, cudagraph_mode, decode_query_len=1)
self.pool = None self.draft_tokens = draft_tokens
if self.cudagraph_mode != CUDAGraphMode.NONE:
# Use a dedicated pool for Eagle to avoid memory overlap with the main
# model's cudagraph. The base class uses a shared global pool, but Eagle's
# internal allocations (e.g., gumbel_sample temporaries) can conflict with
# the main model's allocations when sharing the same pool.
if cudagraph_mode:
self.pool = torch.cuda.graph_pool_handle() self.pool = torch.cuda.graph_pool_handle()
def get_cudagraph_size(self, num_tokens: int) -> int | None: def capture(
return self.cudagraph_sizes.get(num_tokens)
def get_cudagraph_runtime_mode(
self, num_tokens: int
) -> tuple[CUDAGraphMode, int | None]:
cudagraph_size = self.get_cudagraph_size(num_tokens)
if cudagraph_size is None:
cudagraph_mode = CUDAGraphMode.NONE
else:
cudagraph_mode = self.cudagraph_mode
if (
cudagraph_mode == CUDAGraphMode.FULL
and cudagraph_size is not None
and cudagraph_size not in self.graphs
):
# If graph wasn't captured yet, fall back to eager.
# This might happen when the dummy run is called before capture.
cudagraph_mode = CUDAGraphMode.NONE
cudagraph_size = None
return cudagraph_mode, cudagraph_size
def capture_graph(
self, self,
num_tokens: int,
capture_cg_mode: CUDAGraphMode,
generate_fn: Callable, generate_fn: Callable,
model_state: ModelState, model_state: ModelState,
input_buffers: InputBuffers, input_buffers: InputBuffers,
block_tables: BlockTables, block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]], attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
progress_bar_desc: str = "Capturing CUDA graphs",
) -> None: ) -> None:
assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], ( """Capture CUDA graphs for Eagle speculative decoding (FULL mode only)."""
f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}"
) def create_forward_fn(
if capture_cg_mode == CUDAGraphMode.PIECEWISE: desc: BatchExecutionDescriptor,
capture_fn = self._capture_piecewise_graph ) -> Callable[[CUDAGraphMode], None]:
else: num_tokens = desc.num_tokens
capture_fn = self._capture_full_graph num_reqs = desc.num_reqs or min(num_tokens, self.max_num_reqs)
num_tokens_across_dp = (
num_reqs = min(num_tokens, self.max_num_reqs) torch.full((self.dp_size,), num_tokens, dtype=torch.int32, device="cpu")
attn_metadata, slot_mappings = prepare_inputs_to_capture( if self.dp_size > 1
num_reqs, else None
num_tokens, )
model_state, attn_metadata, slot_mappings = prepare_inputs_to_capture(
input_buffers, num_reqs,
block_tables, num_tokens,
attn_groups, model_state,
kv_cache_config, input_buffers,
) block_tables,
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens) attn_groups,
kv_cache_config,
# Warm up. )
generate_fn(
num_reqs,
num_tokens,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
CUDAGraphMode.NONE,
)
# Capture the graph.
capture_fn(
num_reqs=num_reqs,
num_tokens=num_tokens,
generate_fn=generate_fn,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings,
num_tokens_across_dp=num_tokens_across_dp,
)
def _capture_full_graph(
self,
num_reqs: int,
num_tokens: int,
generate_fn: Callable,
attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
num_tokens_across_dp: torch.Tensor,
) -> None:
assert num_tokens not in self.graphs
graph = torch.cuda.CUDAGraph()
# Sync offloader's copy stream before capture.
# Ensure any pre-capture prefetches from offloader are complete.
get_offloader().sync_prev_onload()
with torch.cuda.graph(graph, self.pool): return lambda cg_mode: generate_fn(
generate_fn(
num_reqs, num_reqs,
num_tokens, num_tokens,
attn_metadata, attn_metadata,
slot_mappings, slot_mappings,
num_tokens_across_dp, num_tokens_across_dp,
CUDAGraphMode.NONE, cg_mode,
) )
# Join offloader's copy stream after forward to avoid unjoined
# stream error. The last layer's start_prefetch forks copy_stream,
# but wait_prefetch only happens in the next forward pass.
get_offloader().join_after_forward()
self.graphs[num_tokens] = graph
def _capture_piecewise_graph(
self,
num_reqs: int,
num_tokens: int,
generate_fn: Callable,
attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
num_tokens_across_dp: torch.Tensor,
) -> None:
generate_fn(
num_reqs,
num_tokens,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
CUDAGraphMode.PIECEWISE,
)
@torch.inference_mode()
def capture(
self,
generate_fn: Callable,
model_state: ModelState,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
) -> None:
if self.cudagraph_mode == CUDAGraphMode.NONE:
return
capture_graphs( super().capture(create_forward_fn, progress_bar_desc)
self.cudagraph_sizes,
self.device,
self.capture_graph,
capture_cudagraph_mode=self.cudagraph_mode,
desc=f"Capturing eagle CUDA graphs ({self.cudagraph_mode.name})",
generate_fn=generate_fn,
model_state=model_state,
input_buffers=input_buffers,
block_tables=block_tables,
attn_groups=attn_groups,
kv_cache_config=kv_cache_config,
)
def run_fullgraph(self, num_tokens: int) -> None: def run_fullgraph(self, desc: BatchExecutionDescriptor) -> torch.Tensor:
assert num_tokens in self.graphs """Replay a captured FULL cudagraph and return draft tokens."""
# Sync offloader before replay - needed when transitioning from super().run_fullgraph(desc)
# eager/piecewise to full cudagraph (e.g., prefill → decode). return self.draft_tokens
# The previous eager iteration's start_prefetch may have queued
# H2D copies on copy_stream that the graph's captured events
# cannot see. Without this, replay could overwrite static buffers
# while those copies are still in flight.
get_offloader().sync_prev_onload()
self.graphs[num_tokens].replay()
...@@ -16,7 +16,7 @@ from vllm.v1.worker.gpu.attn_utils import ( ...@@ -16,7 +16,7 @@ from vllm.v1.worker.gpu.attn_utils import (
build_slot_mappings_by_layer, build_slot_mappings_by_layer,
) )
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding from vllm.v1.worker.gpu.dp_utils import sync_cudagraph_and_dp_padding
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
...@@ -75,7 +75,16 @@ class EagleSpeculator: ...@@ -75,7 +75,16 @@ class EagleSpeculator:
device=device, device=device,
) )
self.cudagraph_manager = EagleCudaGraphManager(vllm_config, device) # currently we don't support PIECEWISE for Eagle.
cudagraph_mode = vllm_config.compilation_config.cudagraph_mode
if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL:
cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
else:
cudagraph_mode = CUDAGraphMode.NONE
self.cudagraph_manager = EagleCudaGraphManager(
vllm_config, device, cudagraph_mode, self.draft_tokens
)
def load_model(self, target_model: nn.Module) -> None: def load_model(self, target_model: nn.Module) -> None:
self.model = load_eagle_model(target_model, self.vllm_config) self.model = load_eagle_model(target_model, self.vllm_config)
...@@ -171,7 +180,7 @@ class EagleSpeculator: ...@@ -171,7 +180,7 @@ class EagleSpeculator:
) )
if attn_metadata is not None: if attn_metadata is not None:
self.block_tables.compute_slot_mappings( self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos idx_mapping, query_start_loc, pos, num_tokens_padded
) )
def capture_model(self) -> None: def capture_model(self) -> None:
...@@ -185,6 +194,7 @@ class EagleSpeculator: ...@@ -185,6 +194,7 @@ class EagleSpeculator:
self.block_tables, self.block_tables,
self.attn_groups, self.attn_groups,
self.kv_cache_config, self.kv_cache_config,
progress_bar_desc="Capturing eagle CUDA graphs",
) )
@torch.inference_mode() @torch.inference_mode()
...@@ -251,6 +261,7 @@ class EagleSpeculator: ...@@ -251,6 +261,7 @@ class EagleSpeculator:
logits = self.model.compute_logits(sample_hidden_states) logits = self.model.compute_logits(sample_hidden_states)
num_reqs = input_batch.num_reqs num_reqs = input_batch.num_reqs
num_reqs_padded = input_batch.num_reqs_after_padding
# NOTE(woosuk): For draft sampling, we only consider the temperature # NOTE(woosuk): For draft sampling, we only consider the temperature
# and ignore the other sampling parameters such as top_k and top_p, # and ignore the other sampling parameters such as top_k and top_p,
# for simplicity and performance. # for simplicity and performance.
...@@ -292,48 +303,52 @@ class EagleSpeculator: ...@@ -292,48 +303,52 @@ class EagleSpeculator:
self.max_num_reqs, self.max_num_reqs,
) )
if not (dummy_run and skip_attn_for_dummy_run): # Get batch descriptor and sync across DP ranks.
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] # Eagle uses FULL-only mode, dispatch with uniform_token_count=1 for decode
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)
cudagraph_mode, cudagraph_size = ( batch_desc = self.cudagraph_manager.dispatch(num_reqs, num_reqs, 1)
self.cudagraph_manager.get_cudagraph_runtime_mode(num_reqs) num_tokens_across_dp = None
)
num_tokens_padded, num_tokens_across_dp, synced_cudagraph_mode = ( if self.dp_size > 1:
get_cudagraph_and_dp_padding( batch_desc, num_tokens_across_dp = sync_cudagraph_and_dp_padding(
self.cudagraph_manager,
batch_desc,
num_reqs, num_reqs,
cudagraph_size, num_reqs,
cudagraph_mode.value, 1, # uniform_token_count
self.dp_size, self.dp_size,
self.dp_rank, self.dp_rank,
) )
)
cudagraph_mode = CUDAGraphMode(synced_cudagraph_mode) if not (dummy_run and skip_attn_for_dummy_run):
if cudagraph_mode == CUDAGraphMode.FULL: query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
# Run full CUDA graph. slot_mappings = self.block_tables.compute_slot_mappings(
self.cudagraph_manager.run_fullgraph(num_tokens_padded) idx_mapping, query_start_loc, pos, batch_desc.num_tokens
return self.draft_tokens[:num_reqs] )
if batch_desc.cg_mode == CUDAGraphMode.FULL:
return self.cudagraph_manager.run_fullgraph(batch_desc)[:num_reqs]
# Run eager or piecewise CUDA graph. # Run eager or piecewise CUDA graph.
attn_metadata_updated = None attn_metadata_updated = None
slot_mappings_updated = None slot_mappings_updated = None
if not (dummy_run and skip_attn_for_dummy_run): if not (dummy_run and skip_attn_for_dummy_run):
query_start_loc_cpu = torch.arange( query_start_loc_cpu = torch.arange(
num_reqs + 1, dtype=torch.int32, device="cpu" num_reqs_padded + 1, dtype=torch.int32, device="cpu"
) )
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables] block_tables = [
x[:num_reqs_padded] for x in self.block_tables.input_block_tables
]
# FIXME(woosuk): This is UNSAFE!! # FIXME(woosuk): This is UNSAFE!!
attn_metadata_updated = build_attn_metadata( attn_metadata_updated = build_attn_metadata(
attn_groups=self.attn_groups, attn_groups=self.attn_groups,
num_reqs=num_reqs, num_reqs=num_reqs_padded,
num_tokens=num_reqs, num_tokens=num_reqs_padded,
query_start_loc_gpu=query_start_loc, query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu, query_start_loc_cpu=query_start_loc_cpu,
max_query_len=1, max_query_len=1,
seq_lens=self.input_buffers.seq_lens[:num_reqs], seq_lens=self.input_buffers.seq_lens[:num_reqs_padded],
max_seq_len=self.max_model_len, max_seq_len=self.max_model_len,
block_tables=block_tables, block_tables=block_tables,
slot_mappings=slot_mappings, slot_mappings=slot_mappings,
...@@ -345,11 +360,11 @@ class EagleSpeculator: ...@@ -345,11 +360,11 @@ class EagleSpeculator:
self.generate_draft( self.generate_draft(
num_reqs, num_reqs,
num_tokens_padded, batch_desc.num_tokens,
attn_metadata_updated, attn_metadata_updated,
slot_mappings_updated, slot_mappings_updated,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_mode, cudagraph_runtime_mode=batch_desc.cg_mode,
) )
return self.draft_tokens[:num_reqs] return self.draft_tokens[:num_reqs]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment