"vscode:/vscode.git/clone" did not exist on "28e0750847ded93158a66efdcbc869d87463b38f"
Unverified Commit 8f121f78 authored by zhrrr's avatar zhrrr Committed by GitHub
Browse files

[Model Runner V2] support auto resolve cudagraph mode/sizes based on attn backend (#32936)


Signed-off-by: default avatarzhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
parent cb5f7501
...@@ -26,6 +26,8 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer ...@@ -26,6 +26,8 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionCGSupport
from vllm.v1.kv_cache_interface import KVCacheConfig
else: else:
VllmConfig = object VllmConfig = object
...@@ -1241,6 +1243,152 @@ class CompilationConfig: ...@@ -1241,6 +1243,152 @@ class CompilationConfig:
assert "none" in self.custom_ops assert "none" in self.custom_ops
return f"+{op}" in self.custom_ops return f"+{op}" in self.custom_ops
def resolve_cudagraph_mode_and_sizes(
self,
min_cg_support: "AttentionCGSupport",
min_cg_attn_backend: str | None,
uniform_decode_query_len: int = 1,
tensor_parallel_size: int = 1,
kv_cache_config: "KVCacheConfig | None" = None,
max_num_reqs: int | None = None,
is_profiling: bool = False,
) -> CUDAGraphMode:
from vllm.v1.attention.backend import AttentionCGSupport
cudagraph_mode = self.cudagraph_mode
if cudagraph_mode is None or cudagraph_mode == CUDAGraphMode.NONE:
self.cudagraph_mode = CUDAGraphMode.NONE
return CUDAGraphMode.NONE
# Check cudagraph for mixed batch is supported
if (
cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL
and min_cg_support != AttentionCGSupport.ALWAYS
):
msg = (
f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
f"with {min_cg_attn_backend} backend (support: "
f"{min_cg_support})"
)
if min_cg_support == AttentionCGSupport.NEVER:
# if not supported any full cudagraphs, just raise it.
msg += (
"; please try cudagraph_mode=PIECEWISE, and "
"make sure compilation mode is VLLM_COMPILE"
)
raise ValueError(msg)
# attempt to resolve the full cudagraph related mode
if self.splitting_ops_contain_attention():
msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE"
cudagraph_mode = CUDAGraphMode.FULL_AND_PIECEWISE
else:
msg += "; setting cudagraph_mode=FULL_DECODE_ONLY"
cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
logger.warning(msg)
# check that if we are doing decode full-cudagraphs it is supported
if (
cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and min_cg_support == AttentionCGSupport.NEVER
):
msg = (
f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
f"with {min_cg_attn_backend} backend (support: "
f"{min_cg_support})"
)
if self.mode == CompilationMode.VLLM_COMPILE and (
self.splitting_ops_contain_attention()
or self.use_inductor_graph_partition
):
msg += (
"; setting cudagraph_mode=PIECEWISE because "
"attention is compiled piecewise"
)
cudagraph_mode = CUDAGraphMode.PIECEWISE
else:
msg += (
"; setting cudagraph_mode=NONE because "
"attention is not compiled piecewise"
)
cudagraph_mode = CUDAGraphMode.NONE
logger.warning(msg)
# check that if we are doing spec-decode + decode full-cudagraphs it is
# supported
if (
cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and uniform_decode_query_len > 1
and min_cg_support.value < AttentionCGSupport.UNIFORM_BATCH.value
):
msg = (
f"CUDAGraphMode.{cudagraph_mode.name} is not supported"
f" with spec-decode for attention backend "
f"{min_cg_attn_backend} (support: {min_cg_support})"
)
if self.splitting_ops_contain_attention():
msg += "; setting cudagraph_mode=PIECEWISE"
cudagraph_mode = CUDAGraphMode.PIECEWISE
else:
msg += "; setting cudagraph_mode=NONE"
cudagraph_mode = CUDAGraphMode.NONE
logger.warning(msg)
# double check that we can support full cudagraph if they are requested
# even after automatic downgrades
if (
cudagraph_mode.has_full_cudagraphs()
and min_cg_support == AttentionCGSupport.NEVER
):
raise ValueError(
f"CUDAGraphMode.{cudagraph_mode.name} is not "
f"supported with {min_cg_attn_backend} backend ("
f"support:{min_cg_support}) "
"; please try cudagraph_mode=PIECEWISE, "
"and make sure compilation mode is VLLM_COMPILE"
)
# Adjust cudagraph sizes to be a multiple of uniform_decode_query_len
# to avoid: https://github.com/vllm-project/vllm/issues/28207 and temp-fix:
# https://github.com/vllm-project/vllm/issues/28207#issuecomment-3504004536
# Will be removed in the near future when we have separate cudagraph capture
# sizes for decode and mixed prefill-decode.
if (
cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and uniform_decode_query_len > 1
):
self.adjust_cudagraph_sizes_for_spec_decode(
uniform_decode_query_len,
tensor_parallel_size,
)
# For Mamba models with FULL decode cudagraphs, each decode
# sequence needs one Mamba cache block. The decode cudagraph
# dispatcher already caps batch sizes at max_num_seqs, so we just
# need to verify that enough blocks exist. Raising here instead
# of silently capping cudagraph_capture_sizes avoids unintended
# restrictions on PIECEWISE (prefill) cudagraphs.
# See: https://github.com/vllm-project/vllm/issues/34094
if (
kv_cache_config is not None
and max_num_reqs is not None
and cudagraph_mode.has_full_cudagraphs()
and not is_profiling
and kv_cache_config.has_mamba_layers
and max_num_reqs > kv_cache_config.num_blocks
):
raise ValueError(
f"max_num_seqs ({max_num_reqs}) exceeds available Mamba cache "
f"blocks ({kv_cache_config.num_blocks}). Each decode sequence "
"requires one Mamba cache block, so CUDA graph capture cannot "
"proceed. Please lower max_num_seqs to at most "
f"{kv_cache_config.num_blocks} or increase "
"gpu_memory_utilization."
)
self.cudagraph_mode = cudagraph_mode
return cudagraph_mode
def adjust_cudagraph_sizes_for_spec_decode( def adjust_cudagraph_sizes_for_spec_decode(
self, uniform_decode_query_len: int, tensor_parallel_size: int self, uniform_decode_query_len: int, tensor_parallel_size: int
): ):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, cast from typing import Any, cast
import numpy as np import numpy as np
...@@ -8,7 +9,11 @@ import torch ...@@ -8,7 +9,11 @@ import torch
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.attention.backend import AttentionBackend, CommonAttentionMetadata from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
CommonAttentionMetadata,
)
from vllm.v1.kv_cache_interface import ( from vllm.v1.kv_cache_interface import (
AttentionSpec, AttentionSpec,
KVCacheConfig, KVCacheConfig,
...@@ -18,6 +23,12 @@ from vllm.v1.kv_cache_interface import ( ...@@ -18,6 +23,12 @@ from vllm.v1.kv_cache_interface import (
from vllm.v1.worker.utils import AttentionGroup, bind_kv_cache from vllm.v1.worker.utils import AttentionGroup, bind_kv_cache
@dataclass(frozen=True)
class AttentionCGSupportInfo:
min_cg_support: AttentionCGSupport = AttentionCGSupport.ALWAYS
min_cg_attn_backend: str | None = None
def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]: def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
kv_cache_spec: dict[str, KVCacheSpec] = {} kv_cache_spec: dict[str, KVCacheSpec] = {}
layer_type = cast(type[Any], AttentionLayerBase) layer_type = cast(type[Any], AttentionLayerBase)
...@@ -34,10 +45,17 @@ def init_attn_backend( ...@@ -34,10 +45,17 @@ def init_attn_backend(
vllm_config: VllmConfig, vllm_config: VllmConfig,
device: torch.device, device: torch.device,
active_layer_names: set[str] | None = None, active_layer_names: set[str] | None = None,
): ) -> tuple[
dict[str, type[AttentionBackend]],
list[list[AttentionGroup]],
AttentionCGSupportInfo,
]:
attn_backends: dict[str, type[AttentionBackend]] = {} attn_backends: dict[str, type[AttentionBackend]] = {}
attn_groups: list[list[AttentionGroup]] = [] attn_groups: list[list[AttentionGroup]] = []
attn_backend_workspace: torch.Tensor | None = None attn_backend_workspace: torch.Tensor | None = None
# Find minimum cudagraph support across all attention backends
min_cg_support = AttentionCGSupport.ALWAYS
min_cg_attn_backend = None
for kv_cache_group_id, kv_cache_group_spec in enumerate( for kv_cache_group_id, kv_cache_group_spec in enumerate(
kv_cache_config.kv_cache_groups kv_cache_config.kv_cache_groups
): ):
...@@ -86,8 +104,24 @@ def init_attn_backend( ...@@ -86,8 +104,24 @@ def init_attn_backend(
else: else:
if hasattr(builder, "set_workspace_buffer"): if hasattr(builder, "set_workspace_buffer"):
builder.set_workspace_buffer(attn_backend_workspace) builder.set_workspace_buffer(attn_backend_workspace)
# Check cudagraph support for the attention backend
cg_support = builder.get_cudagraph_support(
vllm_config,
cast(AttentionSpec, kv_cache_group_spec.kv_cache_spec),
)
if cg_support.value < min_cg_support.value:
min_cg_support = cg_support
min_cg_attn_backend = attn_backend.__name__
attn_groups.append(groups) attn_groups.append(groups)
return attn_backends, attn_groups
return (
attn_backends,
attn_groups,
AttentionCGSupportInfo(
min_cg_support=min_cg_support,
min_cg_attn_backend=min_cg_attn_backend,
),
)
def _allocate_kv_cache(kv_cache_config: KVCacheConfig, device: torch.device): def _allocate_kv_cache(kv_cache_config: KVCacheConfig, device: torch.device):
...@@ -110,7 +144,7 @@ def _allocate_kv_cache(kv_cache_config: KVCacheConfig, device: torch.device): ...@@ -110,7 +144,7 @@ def _allocate_kv_cache(kv_cache_config: KVCacheConfig, device: torch.device):
def _reshape_kv_cache( def _reshape_kv_cache(
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
kv_cache_raw_tensors: dict[str, torch.Tensor], kv_cache_raw_tensors: dict[str, torch.Tensor],
attn_backends: dict[str, AttentionBackend], attn_backends: dict[str, type[AttentionBackend]],
cache_dtype: str, cache_dtype: str,
) -> dict[str, torch.Tensor]: ) -> dict[str, torch.Tensor]:
kv_caches: dict[str, torch.Tensor] = {} kv_caches: dict[str, torch.Tensor] = {}
...@@ -158,7 +192,7 @@ def init_kv_cache( ...@@ -158,7 +192,7 @@ def init_kv_cache(
runner_kv_caches: list[torch.Tensor], runner_kv_caches: list[torch.Tensor],
forward_context: dict[str, Any], forward_context: dict[str, Any],
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
attn_backends: dict[str, AttentionBackend], attn_backends: dict[str, type[AttentionBackend]],
device: torch.device, device: torch.device,
cache_dtype: str, cache_dtype: str,
) -> dict[str, torch.Tensor]: ) -> dict[str, torch.Tensor]:
......
...@@ -20,7 +20,7 @@ def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | N ...@@ -20,7 +20,7 @@ def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | N
def sync_cudagraph_and_dp_padding( def sync_cudagraph_and_dp_padding(
cudagraph_manager: CudaGraphManager, cudagraph_manager: CudaGraphManager | None,
desired_batch_desc: BatchExecutionDescriptor, desired_batch_desc: BatchExecutionDescriptor,
num_tokens: int, num_tokens: int,
num_reqs: int, num_reqs: int,
...@@ -61,6 +61,10 @@ def sync_cudagraph_and_dp_padding( ...@@ -61,6 +61,10 @@ def sync_cudagraph_and_dp_padding(
num_reqs=num_reqs, num_reqs=num_reqs,
), num_tokens_across_dp ), num_tokens_across_dp
assert cudagraph_manager is not None, (
"cudagraph_manager should only be None during profile run, "
"where synced_cg_mode must be NONE across all DP ranks"
)
synced_num_tokens = int(num_tokens_across_dp.max().item()) synced_num_tokens = int(num_tokens_across_dp.max().item())
synced_uniform_token_count = uniform_token_counts_across_dp[0] synced_uniform_token_count = uniform_token_counts_across_dp[0]
# If ranks disagree on the uniform token count, or its 0 (means None) set to None # If ranks disagree on the uniform token count, or its 0 (means None) set to None
...@@ -79,3 +83,41 @@ def sync_cudagraph_and_dp_padding( ...@@ -79,3 +83,41 @@ def sync_cudagraph_and_dp_padding(
num_tokens_across_dp[:] = synced_desc.num_tokens num_tokens_across_dp[:] = synced_desc.num_tokens
return synced_desc, num_tokens_across_dp return synced_desc, num_tokens_across_dp
def dispatch_cg_and_sync_dp(
cudagraph_manager: CudaGraphManager | None,
num_reqs: int,
num_tokens: int,
uniform_token_count: int | None,
dp_size: int,
dp_rank: int,
need_eager: bool = False,
) -> tuple[BatchExecutionDescriptor, torch.Tensor | None]:
if need_eager:
batch_desc = BatchExecutionDescriptor(
cg_mode=CUDAGraphMode.NONE,
num_tokens=num_tokens,
num_reqs=num_reqs,
)
else:
assert cudagraph_manager is not None, (
"cudagraph_manager should only be None during profile run, "
"where need_eager must be True"
)
batch_desc = cudagraph_manager.dispatch(
num_reqs, num_tokens, uniform_token_count
)
if dp_size == 1:
return batch_desc, None
return sync_cudagraph_and_dp_padding(
cudagraph_manager,
batch_desc,
num_tokens,
num_reqs,
uniform_token_count,
dp_size,
dp_rank,
)
...@@ -61,7 +61,7 @@ from vllm.v1.worker.gpu.cudagraph_utils import ( ...@@ -61,7 +61,7 @@ from vllm.v1.worker.gpu.cudagraph_utils import (
ModelCudaGraphManager, ModelCudaGraphManager,
get_uniform_token_count, get_uniform_token_count,
) )
from vllm.v1.worker.gpu.dp_utils import sync_cudagraph_and_dp_padding from vllm.v1.worker.gpu.dp_utils import dispatch_cg_and_sync_dp
from vllm.v1.worker.gpu.eplb_utils import EPLBController, step_eplb_after from vllm.v1.worker.gpu.eplb_utils import EPLBController, step_eplb_after
from vllm.v1.worker.gpu.input_batch import ( from vllm.v1.worker.gpu.input_batch import (
InputBatch, InputBatch,
...@@ -176,6 +176,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -176,6 +176,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Draft tokens propagation - for spec-dec + struct outputs. # Draft tokens propagation - for spec-dec + struct outputs.
self.draft_tokens_handler = DraftTokensHandler(self.device) self.draft_tokens_handler = DraftTokensHandler(self.device)
self.uniform_decode_query_len = 1 + self.num_speculative_steps
# Pooling models. # Pooling models.
self.is_pooling_model = self.model_config.runner_type == "pooling" self.is_pooling_model = self.model_config.runner_type == "pooling"
...@@ -224,14 +225,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -224,14 +225,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
device=self.device, device=self.device,
) )
# CUDA graphs. # For CUDA graphs, and will init cudagraph_manager after init_attn_backend.
self.decode_query_len = self.num_speculative_steps + 1 self.decode_query_len = self.num_speculative_steps + 1
self.cudagraph_manager = ModelCudaGraphManager( self.cudagraph_manager: ModelCudaGraphManager | None = None
self.vllm_config,
self.device,
self.compilation_config.cudagraph_mode,
decode_query_len=self.decode_query_len,
)
# LoRA-related workers. # LoRA-related workers.
self.lora_state = LoraState(max_num_reqs=self.max_num_reqs) self.lora_state = LoraState(max_num_reqs=self.max_num_reqs)
# KV Connector if configured. # KV Connector if configured.
...@@ -361,9 +357,26 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -361,9 +357,26 @@ class GPUModelRunner(LoRAModelRunnerMixin):
cp_interleave=self.cp_interleave, cp_interleave=self.cp_interleave,
) )
self.attn_backends, self.attn_groups = init_attn_backend( self.attn_backends, self.attn_groups, attn_cg_support = init_attn_backend(
self.kv_cache_config, self.vllm_config, self.device self.kv_cache_config, self.vllm_config, self.device
) )
cudagraph_mode = self.compilation_config.resolve_cudagraph_mode_and_sizes(
attn_cg_support.min_cg_support,
attn_cg_support.min_cg_attn_backend,
self.uniform_decode_query_len,
self.parallel_config.tensor_parallel_size,
self.kv_cache_config,
self.max_num_reqs,
)
self.cudagraph_manager = ModelCudaGraphManager(
self.vllm_config,
self.device,
cudagraph_mode,
decode_query_len=self.decode_query_len,
)
if self.speculator is not None:
self.speculator.init_cudagraph_manager(cudagraph_mode)
check_attention_cp_compatibility(self.vllm_config) check_attention_cp_compatibility(self.vllm_config)
if self.speculator is not None: if self.speculator is not None:
# HACK(woosuk) # HACK(woosuk)
...@@ -437,6 +450,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -437,6 +450,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
dummy_run=True, dummy_run=True,
skip_attn_for_dummy_run=skip_attn, skip_attn_for_dummy_run=skip_attn,
is_profile=is_profile,
) )
self.kv_connector.set_disabled(False) self.kv_connector.set_disabled(False)
...@@ -486,6 +500,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -486,6 +500,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dummy_run=True, dummy_run=True,
skip_attn_for_dummy_run=skip_attn, skip_attn_for_dummy_run=skip_attn,
mm_inputs=mm_inputs, mm_inputs=mm_inputs,
is_profile=is_profile,
) )
assert hidden_states is not None # Last PP rank always has hidden_states assert hidden_states is not None # Last PP rank always has hidden_states
...@@ -547,6 +562,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -547,6 +562,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
@torch.inference_mode() @torch.inference_mode()
def capture_model(self) -> int: def capture_model(self) -> int:
assert self.cudagraph_manager is not None
if not self.cudagraph_manager.needs_capture(): if not self.cudagraph_manager.needs_capture():
logger.warning( logger.warning(
"Skipping CUDA graph capture. To turn on CUDA graph capture, " "Skipping CUDA graph capture. To turn on CUDA graph capture, "
...@@ -915,6 +931,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -915,6 +931,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
dummy_run: bool = False, dummy_run: bool = False,
skip_attn_for_dummy_run: bool = False, skip_attn_for_dummy_run: bool = False,
is_profile: bool = False,
) -> ModelRunnerOutput | IntermediateTensors | None: ) -> ModelRunnerOutput | IntermediateTensors | None:
if not dummy_run: if not dummy_run:
# Update the request states. # Update the request states.
...@@ -934,33 +951,21 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -934,33 +951,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
max_query_len = max(scheduler_output.num_scheduled_tokens.values()) max_query_len = max(scheduler_output.num_scheduled_tokens.values())
uniform_tok_count = get_uniform_token_count(num_reqs, num_toks, max_query_len) uniform_tok_count = get_uniform_token_count(num_reqs, num_toks, max_query_len)
batch_desc = self.cudagraph_manager.dispatch(
num_reqs, num_toks, uniform_tok_count
)
num_tokens_across_dp = None
skip_compiled = False skip_compiled = False
if self.is_encoder_decoder and scheduler_output.scheduled_encoder_inputs: if self.is_encoder_decoder and scheduler_output.scheduled_encoder_inputs:
# Encoder-decoder models such as Whisper should run eager/non-compiled # Encoder-decoder models such as Whisper should run eager/non-compiled
# when encoder inputs are scheduled, because this step updates # when encoder inputs are scheduled, because this step updates
# cross-attention cache with dynamic encoder outputs. # cross-attention cache with dynamic encoder outputs.
# Override batch_desc to NONE.
skip_compiled = True skip_compiled = True
batch_desc = BatchExecutionDescriptor(
cg_mode=CUDAGraphMode.NONE,
num_tokens=num_toks,
num_reqs=num_reqs,
)
if self.dp_size > 1: batch_desc, num_tokens_across_dp = dispatch_cg_and_sync_dp(
batch_desc, num_tokens_across_dp = sync_cudagraph_and_dp_padding(
self.cudagraph_manager, self.cudagraph_manager,
batch_desc,
num_toks,
num_reqs, num_reqs,
num_toks,
uniform_tok_count, uniform_tok_count,
self.dp_size, self.dp_size,
self.dp_rank, self.dp_rank,
need_eager=is_profile or skip_compiled,
) )
if batch_desc.num_tokens == 0: if batch_desc.num_tokens == 0:
...@@ -1059,6 +1064,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1059,6 +1064,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Use explicit cudagraph replay for FULL mode. # Use explicit cudagraph replay for FULL mode.
# NOTE(woosuk): Here, we don't need to pass the input tensors, # NOTE(woosuk): Here, we don't need to pass the input tensors,
# because they are already copied to the CUDA graph input buffers. # because they are already copied to the CUDA graph input buffers.
assert self.cudagraph_manager is not None
self.kv_connector.pre_forward(scheduler_output) self.kv_connector.pre_forward(scheduler_output)
model_output = self.cudagraph_manager.run_fullgraph(batch_desc) model_output = self.cudagraph_manager.run_fullgraph(batch_desc)
else: else:
......
...@@ -19,10 +19,7 @@ from vllm.v1.worker.gpu.attn_utils import ( ...@@ -19,10 +19,7 @@ from vllm.v1.worker.gpu.attn_utils import (
init_attn_backend, init_attn_backend,
) )
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cudagraph_utils import ( from vllm.v1.worker.gpu.dp_utils import dispatch_cg_and_sync_dp
BatchExecutionDescriptor,
)
from vllm.v1.worker.gpu.dp_utils import sync_cudagraph_and_dp_padding
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
...@@ -98,15 +95,19 @@ class EagleSpeculator: ...@@ -98,15 +95,19 @@ class EagleSpeculator:
device=device, device=device,
) )
# currently we don't support PIECEWISE for Eagle. self.cudagraph_manager: EagleCudaGraphManager | None = None
cudagraph_mode = vllm_config.compilation_config.cudagraph_mode
def init_cudagraph_manager(self, cudagraph_mode: CUDAGraphMode) -> None:
if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL: if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL:
cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
else: else:
cudagraph_mode = CUDAGraphMode.NONE cudagraph_mode = CUDAGraphMode.NONE
self.cudagraph_manager = EagleCudaGraphManager( self.cudagraph_manager = EagleCudaGraphManager(
vllm_config, device, cudagraph_mode, self.draft_tokens self.vllm_config,
self.device,
cudagraph_mode,
self.draft_tokens,
) )
def load_model(self, target_model: nn.Module) -> None: def load_model(self, target_model: nn.Module) -> None:
...@@ -133,7 +134,7 @@ class EagleSpeculator: ...@@ -133,7 +134,7 @@ class EagleSpeculator:
) -> None: ) -> None:
self.model_state = model_state self.model_state = model_state
self.kv_cache_config = kv_cache_config self.kv_cache_config = kv_cache_config
_, self.attn_groups = init_attn_backend( _, self.attn_groups, _ = init_attn_backend(
kv_cache_config, kv_cache_config,
self.vllm_config, self.vllm_config,
self.device, self.device,
...@@ -242,29 +243,6 @@ class EagleSpeculator: ...@@ -242,29 +243,6 @@ class EagleSpeculator:
idx_mapping, query_start_loc, pos, num_tokens_padded idx_mapping, query_start_loc, pos, num_tokens_padded
) )
def _dispatch_and_sync_dp(
self,
cudagraph_manager: EagleCudaGraphManager,
num_reqs: int,
num_tokens: int,
uniform_token_count: int | None,
) -> tuple[BatchExecutionDescriptor, torch.Tensor | None]:
batch_desc = cudagraph_manager.dispatch(
num_reqs, num_tokens, uniform_token_count
)
num_tokens_across_dp = None
if self.dp_size > 1:
batch_desc, num_tokens_across_dp = sync_cudagraph_and_dp_padding(
cudagraph_manager,
batch_desc,
num_tokens,
num_reqs,
uniform_token_count,
self.dp_size,
self.dp_rank,
)
return batch_desc, num_tokens_across_dp
def _build_draft_attn_metadata( def _build_draft_attn_metadata(
self, self,
num_reqs: int, num_reqs: int,
...@@ -303,8 +281,10 @@ class EagleSpeculator: ...@@ -303,8 +281,10 @@ class EagleSpeculator:
return attn_metadata return attn_metadata
def capture_model(self) -> None: def capture_model(self) -> None:
assert self.cudagraph_manager is not None
if self.num_speculative_steps == 1: if self.num_speculative_steps == 1:
return return
logger.info("Capturing model for Eagle speculator...") logger.info("Capturing model for Eagle speculator...")
self.cudagraph_manager.capture( self.cudagraph_manager.capture(
self.generate_draft, self.generate_draft,
...@@ -342,6 +322,7 @@ class EagleSpeculator: ...@@ -342,6 +322,7 @@ class EagleSpeculator:
dummy_run: bool = False, dummy_run: bool = False,
skip_attn_for_dummy_run: bool = False, skip_attn_for_dummy_run: bool = False,
mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
is_profile: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
# NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the # NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
# number of rejected tokens, we maintain the size of eagle's input_ids and # number of rejected tokens, we maintain the size of eagle's input_ids and
...@@ -430,11 +411,14 @@ class EagleSpeculator: ...@@ -430,11 +411,14 @@ class EagleSpeculator:
# Each request produces exactly 1 token per draft decode step, # Each request produces exactly 1 token per draft decode step,
# enabling FULL cudagraph. # enabling FULL cudagraph.
decode_batch_desc, num_tokens_across_dp = self._dispatch_and_sync_dp( decode_batch_desc, num_tokens_across_dp = dispatch_cg_and_sync_dp(
self.cudagraph_manager, self.cudagraph_manager,
num_reqs, num_reqs,
num_reqs, num_reqs,
uniform_token_count=1, uniform_token_count=1,
dp_size=self.dp_size,
dp_rank=self.dp_rank,
need_eager=is_profile,
) )
attn_metadata_updated = None attn_metadata_updated = None
...@@ -461,6 +445,7 @@ class EagleSpeculator: ...@@ -461,6 +445,7 @@ class EagleSpeculator:
) )
if decode_batch_desc.cg_mode == CUDAGraphMode.FULL: if decode_batch_desc.cg_mode == CUDAGraphMode.FULL:
assert self.cudagraph_manager is not None
self.cudagraph_manager.run_fullgraph(decode_batch_desc) self.cudagraph_manager.run_fullgraph(decode_batch_desc)
else: else:
self.generate_draft( self.generate_draft(
......
...@@ -6307,7 +6307,7 @@ class GPUModelRunner( ...@@ -6307,7 +6307,7 @@ class GPUModelRunner(
cudagraph_mode. cudagraph_mode.
""" """
min_cg_support = AttentionCGSupport.ALWAYS min_cg_support = AttentionCGSupport.ALWAYS
min_cg_backend_name = None min_cg_attn_backend = None
for attn_backend_set, kv_cache_group in zip( for attn_backend_set, kv_cache_group in zip(
attention_backends, kv_cache_groups attention_backends, kv_cache_groups
...@@ -6320,152 +6320,18 @@ class GPUModelRunner( ...@@ -6320,152 +6320,18 @@ class GPUModelRunner(
) )
if cg_support.value < min_cg_support.value: if cg_support.value < min_cg_support.value:
min_cg_support = cg_support min_cg_support = cg_support
min_cg_backend_name = attn_backend.__name__ min_cg_attn_backend = attn_backend.__name__
# Flexible resolve the cudagraph mode cudagraph_mode = self.compilation_config.resolve_cudagraph_mode_and_sizes(
cudagraph_mode = self.compilation_config.cudagraph_mode min_cg_support,
assert cudagraph_mode is not None min_cg_attn_backend,
# check cudagraph for mixed batch is supported self.uniform_decode_query_len,
if ( self.parallel_config.tensor_parallel_size,
cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL self.kv_cache_config,
and min_cg_support != AttentionCGSupport.ALWAYS self.max_num_reqs,
): is_profiling=is_profiling,
msg = (
f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
f"with {min_cg_backend_name} backend (support: "
f"{min_cg_support})"
)
if min_cg_support == AttentionCGSupport.NEVER:
# if not supported any full cudagraphs, just raise it.
msg += (
"; please try cudagraph_mode=PIECEWISE, and "
"make sure compilation mode is VLLM_COMPILE"
)
raise ValueError(msg)
# attempt to resolve the full cudagraph related mode
if self.compilation_config.splitting_ops_contain_attention():
msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE"
cudagraph_mode = self.compilation_config.cudagraph_mode = (
CUDAGraphMode.FULL_AND_PIECEWISE
)
else:
msg += "; setting cudagraph_mode=FULL_DECODE_ONLY"
cudagraph_mode = self.compilation_config.cudagraph_mode = (
CUDAGraphMode.FULL_DECODE_ONLY
)
logger.warning(msg)
# check that if we are doing decode full-cudagraphs it is supported
if (
cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and min_cg_support == AttentionCGSupport.NEVER
):
msg = (
f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
f"with {min_cg_backend_name} backend (support: "
f"{min_cg_support})"
)
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE and (
self.compilation_config.splitting_ops_contain_attention()
or self.compilation_config.use_inductor_graph_partition
):
msg += (
"; setting cudagraph_mode=PIECEWISE because "
"attention is compiled piecewise"
)
cudagraph_mode = self.compilation_config.cudagraph_mode = (
CUDAGraphMode.PIECEWISE
)
else:
msg += (
"; setting cudagraph_mode=NONE because "
"attention is not compiled piecewise"
)
cudagraph_mode = self.compilation_config.cudagraph_mode = (
CUDAGraphMode.NONE
)
logger.warning(msg)
# check that if we are doing spec-decode + decode full-cudagraphs it is
# supported
if (
cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and self.uniform_decode_query_len > 1
and min_cg_support.value < AttentionCGSupport.UNIFORM_BATCH.value
):
msg = (
f"CUDAGraphMode.{cudagraph_mode.name} is not supported"
f" with spec-decode for attention backend "
f"{min_cg_backend_name} (support: {min_cg_support})"
)
if self.compilation_config.splitting_ops_contain_attention():
msg += "; setting cudagraph_mode=PIECEWISE"
cudagraph_mode = self.compilation_config.cudagraph_mode = (
CUDAGraphMode.PIECEWISE
)
else:
msg += "; setting cudagraph_mode=NONE"
cudagraph_mode = self.compilation_config.cudagraph_mode = (
CUDAGraphMode.NONE
)
logger.warning(msg)
# double check that we can support full cudagraph if they are requested
# even after automatic downgrades
if (
cudagraph_mode.has_full_cudagraphs()
and min_cg_support == AttentionCGSupport.NEVER
):
raise ValueError(
f"CUDAGraphMode.{cudagraph_mode.name} is not "
f"supported with {min_cg_backend_name} backend ("
f"support:{min_cg_support}) "
"; please try cudagraph_mode=PIECEWISE, "
"and make sure compilation mode is VLLM_COMPILE"
)
# if we have dedicated decode cudagraphs, and spec-decode is enabled,
# we need to adjust the cudagraph sizes to be a multiple of the uniform
# decode query length to avoid: https://github.com/vllm-project/vllm/issues/28207
# temp-fix: https://github.com/vllm-project/vllm/issues/28207#issuecomment-3504004536
# Will be removed in the near future when we have separate cudagraph capture
# sizes for decode and mixed prefill-decode.
if (
cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and cudagraph_mode.separate_routine()
and self.uniform_decode_query_len > 1
):
self.compilation_config.adjust_cudagraph_sizes_for_spec_decode(
self.uniform_decode_query_len, self.parallel_config.tensor_parallel_size
)
# For Mamba models with FULL decode cudagraphs, each decode
# sequence needs one Mamba cache block. The decode cudagraph
# dispatcher already caps batch sizes at max_num_seqs, so we just
# need to verify that enough blocks exist. Raising here instead
# of silently capping cudagraph_capture_sizes avoids unintended
# restrictions on PIECEWISE (prefill) cudagraphs.
# See: https://github.com/vllm-project/vllm/issues/34094
if cudagraph_mode.has_full_cudagraphs() and not is_profiling:
has_mamba = any(
isinstance(g.kv_cache_spec, MambaSpec) for g in kv_cache_groups
)
if has_mamba and self.kv_cache_config is not None:
num_blocks = self.kv_cache_config.num_blocks
if self.max_num_reqs > num_blocks:
raise ValueError(
f"max_num_seqs ({self.max_num_reqs}) exceeds "
f"available Mamba cache blocks ({num_blocks}). "
f"Each decode sequence requires one Mamba cache "
f"block, so CUDA graph capture cannot proceed. "
f"Please lower max_num_seqs to at most "
f"{num_blocks} or increase "
f"gpu_memory_utilization."
) )
# Trigger cudagraph dispatching keys initialization after # Trigger cudagraph dispatching keys initialization after
# resolved cudagraph mode. # resolved cudagraph mode.
self.compilation_config.cudagraph_mode = cudagraph_mode
self.cudagraph_dispatcher.initialize_cudagraph_keys( self.cudagraph_dispatcher.initialize_cudagraph_keys(
cudagraph_mode, self.uniform_decode_query_len cudagraph_mode, self.uniform_decode_query_len
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment