Unverified Commit 74f441f4 authored by fhl2000's avatar fhl2000 Committed by GitHub
Browse files

[Core] Allow full cudagraph with separate attention routines and orthogonal to...


[Core] Allow full cudagraph with separate attention routines and orthogonal to compilation, add support for FA2 and FlashInfer (#20059)
Signed-off-by: default avatarfhl <2410591650@qq.com>
Signed-off-by: default avatarfhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: default avatarLucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: default avatarLuka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: default avatarLucas Wilkinson <LucasWilkinson@users.noreply.github.com>
parent a0632a3e
...@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Optional ...@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Optional
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.config import CUDAGraphMode
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
...@@ -100,16 +101,17 @@ class XPUPlatform(Platform): ...@@ -100,16 +101,17 @@ class XPUPlatform(Platform):
# Instances created using VllmConfig() typically have model_config as # Instances created using VllmConfig() typically have model_config as
# None by default. The modification involves adding a check to prevent # None by default. The modification involves adding a check to prevent
# potential null exceptions check and update model config. # potential null exceptions check and update model config.
if model_config is not None: if model_config is not None and model_config.dtype == torch.bfloat16 \
if model_config.dtype == torch.bfloat16: and not cls.device_support_bf16():
bf16_supported = cls.device_support_bf16() model_config.dtype = torch.float16
if not bf16_supported:
model_config.dtype = torch.float16 compilation_config = vllm_config.compilation_config
if not model_config.enforce_eager: if compilation_config.cudagraph_mode is None or \
logger.warning( compilation_config.cudagraph_mode.max_cudagraph_mode() \
"CUDA graph is not supported on XPU, fallback to the eager " != CUDAGraphMode.NONE:
"mode.") logger.info("[XPU] CUDA graph is not supported on XPU, "
model_config.enforce_eager = True "disabling cudagraphs.")
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
# check and update parallel config # check and update parallel config
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention.""" """Attention layer with FlashAttention."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import ClassVar, Optional from typing import Optional
import numpy as np import numpy as np
import torch import torch
...@@ -154,9 +154,26 @@ def _get_sliding_window_configs( ...@@ -154,9 +154,26 @@ def _get_sliding_window_configs(
class FlashAttentionMetadataBuilder( class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]): AttentionMetadataBuilder[FlashAttentionMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ # FA3:
AttentionCGSupport.NEVER if get_flash_attn_version() == 2 \ # Supports full cudagraphs for all cases.
else AttentionCGSupport.ALWAYS #
# FA2:
# For FA2, a graph is captured with max_query_len=1, (which is what we
# capture by default for num_tokens <= max_num_seqs when there is no
# spec-decode) then these graphs will not work for mixed prefill-decode
# (unlike FA3). This is due to special max_query_len=1 packed-GQA handling
# in FA2.
# In summary if we are running with spec decodes the graphs would
# work for mixed prefill-decode and uniform-decode. But for non-spec decodes
# the graphs would not work for mixed prefill-decode; sorta the inverse
# of UNIFORM_SINGLE_TOKEN_DECODE.
# There's probably a better way to describe this using `AttentionCGSupport`
# but for now just set it to `UNIFORM_BATCH` to get us to drop down
# to FULL_AND_PIECEWISE.
# TODO(luka, lucas): audit FA2 as part of:
# https://github.com/vllm-project/vllm/issues/22945
cudagraph_support = AttentionCGSupport.ALWAYS \
if get_flash_attn_version() == 3 else AttentionCGSupport.UNIFORM_BATCH
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
...@@ -177,17 +194,13 @@ class FlashAttentionMetadataBuilder( ...@@ -177,17 +194,13 @@ class FlashAttentionMetadataBuilder(
self.max_num_splits = 0 # No upper bound on the number of splits. self.max_num_splits = 0 # No upper bound on the number of splits.
self.aot_schedule = (get_flash_attn_version() == 3) self.aot_schedule = (get_flash_attn_version() == 3)
self.use_full_cuda_graph = self.compilation_config.full_cuda_graph
if self.use_full_cuda_graph: self.use_full_cuda_graph = \
if not self.aot_schedule: self.compilation_config.cudagraph_mode.has_full_cudagraphs()
raise ValueError(
"AoT scheduling is required for full cuda graph.") if self.use_full_cuda_graph and self.aot_schedule:
capture_sizes = self.compilation_config.cudagraph_capture_sizes self.max_cudagraph_size = self.compilation_config.max_capture_size
if not capture_sizes:
raise ValueError(
"cudagraph_capture_sizes should not be None when "
"full_cuda_graph is True.")
self.max_cudagraph_size = max(capture_sizes)
if self.max_cudagraph_size > 992: if self.max_cudagraph_size > 992:
# This condition derives from FA3's internal heuristic. # This condition derives from FA3's internal heuristic.
# TODO(woosuk): Support larger cudagraph sizes. # TODO(woosuk): Support larger cudagraph sizes.
...@@ -310,9 +323,9 @@ class FlashAttentionMetadataBuilder( ...@@ -310,9 +323,9 @@ class FlashAttentionMetadataBuilder(
seqlens=seq_lens, seqlens=seq_lens,
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
causal=causal) causal=causal)
# For FA3 + full cudagraph
if self.use_full_cuda_graph: max_num_splits = 0
assert scheduler_metadata is not None if self.use_full_cuda_graph and scheduler_metadata is not None:
n = scheduler_metadata.shape[0] n = scheduler_metadata.shape[0]
self.scheduler_metadata[:n] = scheduler_metadata self.scheduler_metadata[:n] = scheduler_metadata
# NOTE(woosuk): We should zero out the rest of the scheduler # NOTE(woosuk): We should zero out the rest of the scheduler
...@@ -322,14 +335,12 @@ class FlashAttentionMetadataBuilder( ...@@ -322,14 +335,12 @@ class FlashAttentionMetadataBuilder(
self.scheduler_metadata[n:] = 0 self.scheduler_metadata[n:] = 0
scheduler_metadata = self.scheduler_metadata[:n] scheduler_metadata = self.scheduler_metadata[:n]
max_num_splits = 0 if num_actual_tokens <= self.max_cudagraph_size:
if (self.use_full_cuda_graph # NOTE(woosuk): Setting num_splits > 1 may increase the memory
and num_actual_tokens <= self.max_cudagraph_size): # usage, because the intermediate buffers of size [num_splits,
# NOTE(woosuk): Setting num_splits > 1 may increase the memory # num_heads, num_tokens, head_size] are allocated. Therefore,
# usage, because the intermediate buffers of size [num_splits, # we only set num_splits when using cuda graphs.
# num_heads, num_tokens, head_size] are allocated. Therefore, max_num_splits = self.max_num_splits
# we only set num_splits when using cuda graphs.
max_num_splits = self.max_num_splits
attn_metadata = FlashAttentionMetadata( attn_metadata = FlashAttentionMetadata(
num_actual_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens,
...@@ -350,11 +361,6 @@ class FlashAttentionMetadataBuilder( ...@@ -350,11 +361,6 @@ class FlashAttentionMetadataBuilder(
causal=causal) causal=causal)
return attn_metadata return attn_metadata
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
# Full CUDA Graph always supported (FA2 support checked separately)
return True
def use_cascade_attention(self, *args, **kwargs) -> bool: def use_cascade_attention(self, *args, **kwargs) -> bool:
return use_cascade_attention(*args, **kwargs) return use_cascade_attention(*args, **kwargs)
......
...@@ -17,7 +17,7 @@ from flashinfer.prefill import trtllm_batch_context_with_kv_cache ...@@ -17,7 +17,7 @@ from flashinfer.prefill import trtllm_batch_context_with_kv_cache
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionType) AttentionType)
from vllm.config import VllmConfig from vllm.config import CUDAGraphMode, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import cdiv, is_pin_memory_available from vllm.utils import cdiv, is_pin_memory_available
from vllm.utils.flashinfer import use_trtllm_attention from vllm.utils.flashinfer import use_trtllm_attention
...@@ -183,8 +183,8 @@ class FlashInferMetadata: ...@@ -183,8 +183,8 @@ class FlashInferMetadata:
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
reorder_batch_threshold: ClassVar[int] = 1 reorder_batch_threshold: ClassVar[int] = 1
...@@ -203,7 +203,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -203,7 +203,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.kv_cache_spec.block_size) self.kv_cache_spec.block_size)
max_num_reqs = vllm_config.scheduler_config.max_num_seqs max_num_reqs = vllm_config.scheduler_config.max_num_seqs
max_num_pages = max_num_reqs * max_num_pages_per_req max_num_pages = max_num_reqs * max_num_pages_per_req
self.enable_cuda_graph = self.compilation_config.full_cuda_graph self.enable_cuda_graph = self.compilation_config.cudagraph_mode.\
decode_mode() == CUDAGraphMode.FULL
if self.enable_cuda_graph: if self.enable_cuda_graph:
# For full cudagraph capture, one `decode_wrapper` for each batch # For full cudagraph capture, one `decode_wrapper` for each batch
# size is needed for FlashInfer. # size is needed for FlashInfer.
...@@ -586,10 +587,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -586,10 +587,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
return self.build(0, m) return self.build(0, m)
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
return common_attn_metadata.max_query_len == 1
def use_cascade_attention(self, *args, **kwargs) -> bool: def use_cascade_attention(self, *args, **kwargs) -> bool:
if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype: if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
# TODO: The cascade wrapper currently does not support setting # TODO: The cascade wrapper currently does not support setting
......
...@@ -89,8 +89,8 @@ class Mamba2AttentionMetadata: ...@@ -89,8 +89,8 @@ class Mamba2AttentionMetadata:
class Mamba2AttentionMetadataBuilder( class Mamba2AttentionMetadataBuilder(
AttentionMetadataBuilder[Mamba2AttentionMetadata]): AttentionMetadataBuilder[Mamba2AttentionMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
reorder_batch_threshold: ClassVar[int] = 1 reorder_batch_threshold: ClassVar[int] = 1
...@@ -203,7 +203,3 @@ class Mamba2AttentionMetadataBuilder( ...@@ -203,7 +203,3 @@ class Mamba2AttentionMetadataBuilder(
m.max_query_len = 1 # decode-only m.max_query_len = 1 # decode-only
return self.build(0, m) return self.build(0, m)
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
return common_attn_metadata.max_query_len == 1
...@@ -575,7 +575,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -575,7 +575,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
"MLA only supports decode-only full CUDAGraph capture. " \ "MLA only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq." "Make sure all cudagraph capture sizes <= max_num_seq."
m.max_query_len = 1 # decode-only assert m.max_query_len == 1 # decode-only
return self.build(0, m) return self.build(0, m)
...@@ -728,10 +728,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -728,10 +728,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
return attn_metadata return attn_metadata
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
return common_attn_metadata.max_query_len == 1
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
""" """
......
...@@ -22,7 +22,7 @@ logger = init_logger(__name__) ...@@ -22,7 +22,7 @@ logger = init_logger(__name__)
class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]): class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
# enable full CUDA Graph support for decode-only capture # enable full CUDA Graph support for decode-only capture
attn_cudagraph_support: ClassVar[ attn_cudagraph_support: ClassVar[
AttentionCGSupport] = AttentionCGSupport.PURE_DECODE_ONLY AttentionCGSupport] = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
class CutlassMLABackend(MLACommonBackend): class CutlassMLABackend(MLACommonBackend):
......
...@@ -55,8 +55,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): ...@@ -55,8 +55,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY AttentionCGSupport.UNIFORM_BATCH
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
...@@ -73,7 +73,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): ...@@ -73,7 +73,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
device_properties = torch.cuda.get_device_properties(self.device) device_properties = torch.cuda.get_device_properties(self.device)
num_sms = device_properties.multi_processor_count num_sms = device_properties.multi_processor_count
if self.compilation_config.full_cuda_graph: if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.cg_buf_tile_scheduler_metadata = torch.zeros( self.cg_buf_tile_scheduler_metadata = torch.zeros(
# Upper bound on size (<= #SMs, TileSchedulerMetaDataSize) # Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
# TileSchedulerMetaDataSize = 8 # TileSchedulerMetaDataSize = 8
...@@ -95,7 +95,10 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): ...@@ -95,7 +95,10 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
1, # MQA for the decode path 1, # MQA for the decode path
) )
if self.compilation_config.full_cuda_graph: # TODO: we can disambiguate between decode and mixed-prefill decode here
# so we can only use the persistent buffer if a cudagraph is actually
# being used.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
assert self.cg_buf_tile_scheduler_metadata is not None assert self.cg_buf_tile_scheduler_metadata is not None
assert self.cg_buf_num_splits is not None assert self.cg_buf_num_splits is not None
......
...@@ -65,8 +65,10 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): ...@@ -65,8 +65,10 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ # TODO(luka, lucas): audit this as part of:
AttentionCGSupport.PURE_DECODE_ONLY # https://github.com/vllm-project/vllm/issues/22945
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
...@@ -82,7 +84,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ...@@ -82,7 +84,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
max_num_pages = max_num_reqs * max_num_pages_per_req max_num_pages = max_num_reqs * max_num_pages_per_req
# Preparing persistent buffers # Preparing persistent buffers
if vllm_config.compilation_config.full_cuda_graph: # TODO: we can disambiguate between decode and mixed-prefill decode here
# so we can only use the persistent buffer if a cudagraph is actually
# being used.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.paged_kv_indptr = torch.zeros(max_num_reqs + 1, self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
dtype=torch.int32, dtype=torch.int32,
device=device) device=device)
...@@ -120,7 +125,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ...@@ -120,7 +125,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
block_table_bounds.cumsum(dim=0, dtype=torch.int32) block_table_bounds.cumsum(dim=0, dtype=torch.int32)
]) ])
if self.compilation_config.full_cuda_graph: if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
num_actual_pages = paged_kv_indices.size(0) num_actual_pages = paged_kv_indices.size(0)
......
...@@ -311,11 +311,6 @@ class AiterFlashAttentionMetadataBuilder( ...@@ -311,11 +311,6 @@ class AiterFlashAttentionMetadataBuilder(
) )
return attn_metadata return attn_metadata
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
# Full CUDA Graph always supported (FA2 support checked separately)
return True
def use_cascade_attention(self, *args, **kwargs) -> bool: def use_cascade_attention(self, *args, **kwargs) -> bool:
return False return False
......
...@@ -58,8 +58,7 @@ class TritonAttentionMetadata: ...@@ -58,8 +58,7 @@ class TritonAttentionMetadata:
class TritonAttentionMetadataBuilder( class TritonAttentionMetadataBuilder(
AttentionMetadataBuilder[TritonAttentionMetadata]): AttentionMetadataBuilder[TritonAttentionMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
AttentionCGSupport.ALWAYS
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
...@@ -132,11 +131,6 @@ class TritonAttentionMetadataBuilder( ...@@ -132,11 +131,6 @@ class TritonAttentionMetadataBuilder(
) )
return attn_metadata return attn_metadata
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
# Full CUDA Graph always supported
return True
class TritonAttentionBackend(AttentionBackend): class TritonAttentionBackend(AttentionBackend):
......
...@@ -158,18 +158,21 @@ class AttentionCGSupport(enum.Enum): ...@@ -158,18 +158,21 @@ class AttentionCGSupport(enum.Enum):
Here we do not consider the cascade attention, as currently Here we do not consider the cascade attention, as currently
it is never cudagraph supported.""" it is never cudagraph supported."""
ALWAYS = 3
"""Cudagraph always supported; supports mixed-prefill-decode"""
UNIFORM_BATCH = 2
"""Cudagraph supported for batches that only contain query lengths that are
the same, this can be used for spec-decode
i.e. "decodes" are 1 + num_speculative_tokens"""
UNIFORM_SINGLE_TOKEN_DECODE = 1
"""Cudagraph supported for batches that only contain query_len==1 decodes"""
NEVER = 0 NEVER = 0
"""NO cudagraph support""" """NO cudagraph support"""
PURE_DECODE_ONLY = 1
"""Cudagraph supported for pure decode, need to run without
cudagraph for mixed prefill-decode batches"""
ALWAYS = 2
"""Cudagraph always supported"""
class AttentionMetadataBuilder(abc.ABC, Generic[M]): class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# Does this backend/builder support CUDA Graphs for attention. # Does this backend/builder support CUDA Graphs for attention (default: no).
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.NEVER AttentionCGSupport.NEVER
# Does this backend/builder reorder the batch? # Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query # If not, set this to None. Otherwise set it to the query
...@@ -199,13 +202,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): ...@@ -199,13 +202,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
""" """
raise NotImplementedError raise NotImplementedError
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
"""
Can this batch (with given metadata) use CUDA Graphs for attention.
"""
return False
def build_for_cudagraph_capture( def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata) -> M: self, common_attn_metadata: CommonAttentionMetadata) -> M:
""" """
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor
from vllm.logger import init_logger
logger = init_logger(__name__)
class CudagraphDispatcher:
    """
    Runtime dispatcher that selects which set of cudagraphs a batch may use.

    Two sets of dispatch keys are kept, one for the PIECEWISE and one for the
    FULL cudagraph runtime mode. They are populated by
    `initialize_cudagraph_keys` from the cudagraph mode and the configured
    capture sizes. The stored keys are the only source of truth for which
    cudagraphs can be dispatched at runtime.

    At runtime, `dispatch` maps an input key (batch descriptor) to a runtime
    cudagraph mode (FULL, PIECEWISE, or NONE for no cudagraph) together with
    the valid key to use. After dispatching (communicated via the forward
    context), the cudagraph wrappers trust the dispatched key to either
    capture or replay (if the mode matches), or to pass through to the
    underlying runnable without cudagraph (if the mode does not match or is
    NONE).
    """

    def __init__(self, vllm_config: VllmConfig):
        """Store the config and create the (initially empty) key sets.

        Asserts that a cudagraph mode requiring piecewise compilation is only
        combined with CompilationLevel.PIECEWISE and splitting ops that
        contain attention.
        """
        compilation_config = vllm_config.compilation_config
        self.vllm_config = vllm_config
        self.compilation_config = compilation_config
        self.cudagraph_mode = compilation_config.cudagraph_mode

        # Valid cudagraph dispatching keys, grouped by runtime mode.
        self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = {
            mode: set()
            for mode in (CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL)
        }

        assert (not self.cudagraph_mode.requires_piecewise_compilation()
                or (compilation_config.level == CompilationLevel.PIECEWISE
                    and compilation_config.splitting_ops_contain_attention())
                ), (
            "Compilation level should be CompilationLevel.PIECEWISE when "
            "cudagraph_mode piecewise cudagraphs is used, "
            f"cudagraph_mode={self.cudagraph_mode}, "
            f"compilation_level={self.compilation_config.level}, "
            f"splitting_ops={self.compilation_config.splitting_ops}")

        # Flipped to True by `initialize_cudagraph_keys`; until then
        # `dispatch` always answers "no cudagraph".
        self.keys_initialized = False

    def add_cudagraph_key(self, runtime_mode: CUDAGraphMode,
                          batch_descriptor: BatchDescriptor):
        """Register `batch_descriptor` as a valid key for `runtime_mode`.

        Only the PIECEWISE and FULL runtime modes own a key set.
        """
        assert runtime_mode in (CUDAGraphMode.PIECEWISE,
                                CUDAGraphMode.FULL), \
            f"Invalid cudagraph runtime mode: {runtime_mode}"
        self.cudagraph_keys[runtime_mode].add(batch_descriptor)

    def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode,
                                  uniform_decode_query_len: int):
        """Populate the dispatch keys for the given cudagraph mode.

        This should be called only after the attention backend is
        initialized.

        Note: every key that could be valid is created, but there is no
        guarantee that all of them will be used. For example, keys for
        piecewise cudagraphs are created whenever piecewise compilation is
        used, which is always valid, but for an attention backend that
        supports a unified routine the piecewise cudagraphs may never be
        captured/replayed, depending on CompilationConfig.cudagraph_mode. In
        addition, if lazy capturing is allowed in a future PR, some keys may
        never be triggered.
        """
        capture_sizes = self.compilation_config.cudagraph_capture_sizes

        # Keys for general (mixed) batches: one per capture size.
        mixed_mode = cudagraph_mode.mixed_mode()
        if mixed_mode != CUDAGraphMode.NONE:
            for num_tokens in capture_sizes:
                self.add_cudagraph_key(
                    mixed_mode,
                    BatchDescriptor(num_tokens=num_tokens,
                                    uniform_decode=False))

        # If the decode cudagraph mode is FULL and mixed-mode full cudagraphs
        # do not already cover it, add the uniform-decode keys here.
        if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
                and cudagraph_mode.separate_routine()):
            max_num_tokens = (uniform_decode_query_len *
                              self.vllm_config.scheduler_config.max_num_seqs)
            for num_tokens in capture_sizes:
                # Keep only capture sizes inside
                # [uniform_decode_query_len, max_num_tokens].
                if not (uniform_decode_query_len <= num_tokens <=
                        max_num_tokens):
                    continue
                self.add_cudagraph_key(
                    CUDAGraphMode.FULL,
                    BatchDescriptor(num_tokens=num_tokens,
                                    uniform_decode=True))

        self.keys_initialized = True

    def dispatch(
        self, batch_descriptor: BatchDescriptor
    ) -> tuple[CUDAGraphMode, Optional[BatchDescriptor]]:
        """
        Given a batch descriptor, dispatch to a cudagraph mode.

        A new batch descriptor is returned as we might dispatch a uniform
        batch to a graph that supports a more general batch (uniform to
        non-uniform). Returns (CUDAGraphMode.NONE, None) when no stored key
        matches or when the keys have not been initialized yet.
        """
        # If not initialized, just skip dispatching.
        if not self.keys_initialized:
            logger.warning_once("cudagraph dispatching keys are not "
                                "initialized. No cudagraph will be used.")
            return CUDAGraphMode.NONE, None

        # An exact match among the full cudagraph keys wins.
        if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
            return CUDAGraphMode.FULL, batch_descriptor

        # Otherwise fall back to the non-uniform key: first among the full
        # cudagraphs, then among the more "general" piecewise cudagraphs.
        relaxed_key = batch_descriptor.non_uniform
        for mode in (CUDAGraphMode.FULL, CUDAGraphMode.PIECEWISE):
            if relaxed_key in self.cudagraph_keys[mode]:
                return mode, relaxed_key

        # Finally, just return no cudagraphs.
        return CUDAGraphMode.NONE, None
...@@ -21,7 +21,9 @@ from vllm.attention import Attention, AttentionType ...@@ -21,7 +21,9 @@ from vllm.attention import Attention, AttentionType
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.config import (CompilationLevel, VllmConfig, from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
get_layers_from_vllm_config, update_config) get_layers_from_vllm_config, update_config)
from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import (get_kv_transfer_group, from vllm.distributed.kv_transfer import (get_kv_transfer_group,
...@@ -29,7 +31,8 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group, ...@@ -29,7 +31,8 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_pp_group, get_tp_group, graph_capture, is_global_first_rank, get_pp_group, get_tp_group, graph_capture, is_global_first_rank,
prepare_communication_buffer_for_model) prepare_communication_buffer_for_model)
from vllm.forward_context import DPMetadata, set_forward_context from vllm.forward_context import (BatchDescriptor, DPMetadata,
set_forward_context)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
...@@ -48,13 +51,15 @@ from vllm.sampling_params import SamplingType ...@@ -48,13 +51,15 @@ from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size, GiB_bytes, LazyLoader, cdiv, check_use_alibi,
is_pin_memory_available, round_up, supports_dynamo) get_dtype_size, is_pin_memory_available, round_up,
supports_dynamo)
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
make_kv_sharing_fast_prefill_attention_metadata, make_kv_sharing_fast_prefill_attention_metadata,
reorder_batch_to_split_decodes_and_prefills) reorder_batch_to_split_decodes_and_prefills)
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import (AttentionSpec, from vllm.v1.kv_cache_interface import (AttentionSpec,
ChunkedLocalAttentionSpec, ChunkedLocalAttentionSpec,
FullAttentionSpec, KVCacheConfig, FullAttentionSpec, KVCacheConfig,
...@@ -218,11 +223,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -218,11 +223,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
is_spec_decode=bool(self.vllm_config.speculative_config), is_spec_decode=bool(self.vllm_config.speculative_config),
) )
self.use_cuda_graph = (
self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE
and self.vllm_config.compilation_config.use_cudagraph
and not self.model_config.enforce_eager)
# TODO(woosuk): Provide an option to tune the max cudagraph batch size. # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
# The convention is different. # The convention is different.
# self.cudagraph_batch_sizes sorts in ascending order. # self.cudagraph_batch_sizes sorts in ascending order.
...@@ -230,8 +230,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -230,8 +230,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.cudagraph_batch_sizes = list( self.cudagraph_batch_sizes = list(
reversed(self.compilation_config.cudagraph_capture_sizes)) reversed(self.compilation_config.cudagraph_capture_sizes))
self.full_cuda_graph = self.compilation_config.full_cuda_graph
# Cache the device properties. # Cache the device properties.
self._init_device_properties() self._init_device_properties()
...@@ -326,6 +324,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -326,6 +324,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.kv_sharing_fast_prefill_logits_indices = torch.zeros( self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
self.max_num_tokens, dtype=torch.int32, device=self.device) self.max_num_tokens, dtype=torch.int32, device=self.device)
self.uniform_decode_query_len = 1 if not self.speculative_config else \
1 + self.speculative_config.num_speculative_tokens
# Cudagraph dispatcher for runtime cudagraph dispatching.
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
self.mm_budget = (MultiModalBudget( self.mm_budget = (MultiModalBudget(
self.model_config, self.model_config,
self.scheduler_config, self.scheduler_config,
...@@ -471,7 +475,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -471,7 +475,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert (task := pooling_params.task) is not None, ( assert (task := pooling_params.task) is not None, (
"You did not set `task` in the API") "You did not set `task` in the API")
model = cast(VllmModelForPooling, self.model) model = cast(VllmModelForPooling, self.get_model())
to_update = model.pooler.get_pooling_updates(task) to_update = model.pooler.get_pooling_updates(task)
to_update.apply(pooling_params) to_update.apply(pooling_params)
...@@ -679,13 +683,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -679,13 +683,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _prepare_inputs( def _prepare_inputs(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
) -> tuple[dict[str, ) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata],
Any], bool, torch.Tensor, Optional[SpecDecodeMetadata], np.ndarray, Optional[CommonAttentionMetadata], int]:
np.ndarray, Optional[CommonAttentionMetadata]]:
""" """
:return: tuple[ :return: tuple[
attn_metadata: layer-to-attention_metadata mapping, attn_metadata: layer-to-attention_metadata mapping,
attention_cuda_graphs: whether attention can run in cudagraph
logits_indices, spec_decode_metadata logits_indices, spec_decode_metadata
] ]
""" """
...@@ -820,7 +822,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -820,7 +822,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# valid, we fill the padded indices with the last index. # valid, we fill the padded indices with the last index.
self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_( self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
logits_indices[-1].item()) logits_indices[-1].item())
if (self.use_cuda_graph if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and num_logits <= self.cudagraph_batch_sizes[-1]): and num_logits <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs. # Use piecewise CUDA graphs.
# Add padding to the batch size. # Add padding to the batch size.
...@@ -925,17 +927,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -925,17 +927,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
continue continue
attn_metadata[layer_name] = attn_metadata_i attn_metadata[layer_name] = attn_metadata_i
attention_cuda_graphs = all(
g.metadata_builder.can_run_in_cudagraph(common_attn_metadata)
for g in self._attn_group_iterator())
# Hot-Swap lora model # Hot-Swap lora model
if self.lora_config: if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens) self.set_active_loras(self.input_batch, num_scheduled_tokens)
return (attn_metadata, attention_cuda_graphs, logits_indices, return (attn_metadata, logits_indices, spec_decode_metadata,
spec_decode_metadata, num_scheduled_tokens, num_scheduled_tokens, spec_decode_common_attn_metadata,
spec_decode_common_attn_metadata) max_num_scheduled_tokens)
def _compute_cascade_attn_prefix_len( def _compute_cascade_attn_prefix_len(
self, self,
...@@ -1259,6 +1257,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1259,6 +1257,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return mm_embeds return mm_embeds
def get_model(self) -> nn.Module:
    """Return the raw model, taking it out of the cudagraph wrapper if any."""
    candidate = self.model
    if not isinstance(candidate, CUDAGraphWrapper):
        # Not wrapped: self.model already is the raw model.
        return candidate
    return candidate.unwrap()
def get_supported_generation_tasks(self) -> list[GenerationTask]: def get_supported_generation_tasks(self) -> list[GenerationTask]:
...@@ -1415,9 +1416,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1415,9 +1416,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return return
assert self.eplb_state is not None assert self.eplb_state is not None
assert is_mixture_of_experts(self.model) model = self.get_model()
assert is_mixture_of_experts(model)
self.eplb_state.step( self.eplb_state.step(
self.model, model,
is_dummy, is_dummy,
is_profile, is_profile,
log_stats=self.parallel_config.eplb_log_balancedness, log_stats=self.parallel_config.eplb_log_balancedness,
...@@ -1507,15 +1509,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1507,15 +1509,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.vllm_config) self.vllm_config)
# Prepare the decoder inputs. # Prepare the decoder inputs.
(attn_metadata, attention_cuda_graphs, logits_indices, (attn_metadata, logits_indices, spec_decode_metadata,
spec_decode_metadata, num_scheduled_tokens_np, num_scheduled_tokens_np, spec_decode_common_attn_metadata,
spec_decode_common_attn_metadata) = ( max_query_len) = (self._prepare_inputs(scheduler_output))
self._prepare_inputs(scheduler_output))
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.use_cuda_graph if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs. # Use CUDA graphs.
# Add padding to the batch size. # Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph( num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_scheduled_tokens) num_scheduled_tokens)
...@@ -1581,10 +1582,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1581,10 +1582,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
intermediate_tensors = self.sync_and_slice_intermediate_tensors( intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_input_tokens, intermediate_tensors, True) num_input_tokens, intermediate_tensors, True)
# Some attention backends only support CUDA Graphs in pure decode. uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
# If attention doesn't support CUDA Graphs for this batch, but we num_scheduled_tokens == self.input_batch.num_reqs * max_query_len)
# compiled with full CUDA graphs, we have to skip them entirely. batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs uniform_decode=uniform_decode)
cudagraph_runtime_mode, batch_descriptor = \
self.cudagraph_dispatcher.dispatch(batch_descriptor)
# Run the model. # Run the model.
# Use persistent buffers for CUDA graphs. # Use persistent buffers for CUDA graphs.
...@@ -1593,10 +1596,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1593,10 +1596,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.vllm_config, self.vllm_config,
num_tokens=num_input_tokens, num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs, cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
), self.maybe_get_kv_connector_output( ), self.maybe_get_kv_connector_output(
scheduler_output) as kv_connector_output: scheduler_output) as kv_connector_output:
model_output = self.model( model_output = self.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
...@@ -2021,20 +2024,31 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2021,20 +2024,31 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.model.compile( self.model.compile(
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend=backend) backend=backend)
return
# For other compilation levels, cudagraph behavior is controlled by
# vLLM's CUDAGraphWrapper and CudagraphDispatcher.
# Wrap the model with a full cudagraph wrapper if needed.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.model = CUDAGraphWrapper(self.model,
self.vllm_config,
runtime_mode=CUDAGraphMode.FULL)
def reload_weights(self) -> None:
    """Reload weights in place into the already-loaded model.

    Raises:
        AssertionError: if ``self.model`` is missing or ``None``.
    """
    model_is_loaded = getattr(self, "model", None) is not None
    assert model_is_loaded, "Cannot reload weights before model is loaded."
    loader = get_model_loader(self.load_config)
    logger.info("Reloading weights inplace...")
    # Weights are loaded into the model returned by get_model(), not
    # directly into self.model.
    loader.load_weights(self.get_model(), model_config=self.model_config)
def save_tensorized_model(
    self,
    tensorizer_config: "TensorizerConfig",
) -> None:
    """Save the model returned by get_model() via TensorizerLoader.

    Args:
        tensorizer_config: forwarded unchanged to
            ``TensorizerLoader.save_model``.
    """
    unwrapped = self.get_model()
    save_kwargs = {
        "tensorizer_config": tensorizer_config,
        "model_config": self.model_config,
    }
    TensorizerLoader.save_model(unwrapped, **save_kwargs)
...@@ -2210,31 +2224,82 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2210,31 +2224,82 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _dummy_run( def _dummy_run(
self, self,
num_tokens: int, num_tokens: int,
capture_attn_cudagraph: bool = False, cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
force_attention: bool = False,
uniform_decode: bool = False,
skip_eplb: bool = False, skip_eplb: bool = False,
is_profile: bool = False, is_profile: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
"""
Run a dummy forward pass to warm up/profile run or capture the
CUDA graph for the model.
Args:
num_tokens: Number of tokens to run the dummy forward pass.
cudagraph_runtime_mode: used to control the behavior.
- CUDAGraphMode.NONE: No cudagraph, for warm up and profile run
- CUDAGraphMode.PIECEWISE: Piecewise cudagraph.
- CUDAGraphMode.FULL: Full cudagraph, attention metadata is
needed.
force_attention: If True, always create attention metadata. Used to
warm up attention backend when mode is NONE.
uniform_decode: If True, the batch is a uniform decode batch.
skip_eplb: If True, skip EPLB state update.
is_profile: If True, this is a profile run.
"""
assert cudagraph_runtime_mode in {
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
}
# Padding for DP # Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_tokens += num_pad num_tokens += num_pad
# If cudagraph_mode.decode_mode() == FULL and
# cudagraph_mode.separate_routine(), we are using different graphs
# and/or modes for mixed prefill-decode batches vs. uniform decode
# batches. A uniform decode batch is one in which all requests have an
# identical query length, except possibly a (shorter) virtual request
# in the batch that accounts for padding.
# A uniform decode batch can be either a common pure decode, where
# max_query_len == 1, or a speculative decode, where
# max_query_len == 1 + num_spec_decode_tokens.
# When max_query_len is set to 1, we switch to and capture the optimized
# routine of FA2 for pure decode, i.e., Flashdecode + an optimization
# for GQA/MQA.
max_query_len = self.uniform_decode_query_len if uniform_decode else \
num_tokens
# Set num_scheduled_tokens based on num_tokens and max_num_seqs # Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for dummy run with LoRA so that the num_reqs collectively # for dummy run with LoRA so that the num_reqs collectively
# has num_tokens in total. # has num_tokens in total.
assert num_tokens <= self.scheduler_config.max_num_batched_tokens assert num_tokens <= self.scheduler_config.max_num_batched_tokens
max_num_reqs = self.scheduler_config.max_num_seqs max_num_reqs = self.scheduler_config.max_num_seqs
num_reqs = min(num_tokens, max_num_reqs) if uniform_decode:
min_tokens_per_req = num_tokens // num_reqs num_reqs = cdiv(num_tokens, max_query_len)
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs assert num_reqs <= max_num_reqs, \
num_scheduled_tokens_list[-1] += num_tokens % num_reqs "Do not capture num_reqs > max_num_reqs for uniform batch"
num_scheduled_tokens_list = [max_query_len] * num_reqs
if num_tokens % max_query_len != 0:
num_scheduled_tokens_list[-1] = num_tokens % max_query_len
else:
num_reqs = min(num_tokens, max_num_reqs)
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert sum(num_scheduled_tokens_list) == num_tokens assert sum(num_scheduled_tokens_list) == num_tokens
assert len(num_scheduled_tokens_list) == num_reqs assert len(num_scheduled_tokens_list) == num_reqs
num_scheduled_tokens = np.array(num_scheduled_tokens_list, num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32) dtype=np.int32)
attn_metadata: Optional[dict[str, Any]] = None attn_metadata: Optional[dict[str, Any]] = None
if capture_attn_cudagraph:
# If force_attention is True, we always capture attention. Otherwise,
# it only happens for cudagraph_runtime_mode=FULL.
if force_attention or cudagraph_runtime_mode == \
CUDAGraphMode.FULL:
attn_metadata = {} attn_metadata = {}
# Make sure max_model_len is used at the graph capture time. # Make sure max_model_len is used at the graph capture time.
...@@ -2255,7 +2320,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2255,7 +2320,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_computed_tokens_cpu_tensor[:num_reqs], num_computed_tokens_cpu_tensor[:num_reqs],
num_reqs=num_reqs, num_reqs=num_reqs,
num_actual_tokens=num_tokens, num_actual_tokens=num_tokens,
max_query_len=num_tokens, max_query_len=max_query_len,
block_table_tensor=self.input_batch.block_table[ block_table_tensor=self.input_batch.block_table[
kv_cache_group_id].get_device_tensor()[:num_reqs], kv_cache_group_id].get_device_tensor()[:num_reqs],
slot_mapping=self.input_batch. slot_mapping=self.input_batch.
...@@ -2299,12 +2364,26 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2299,12 +2364,26 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
intermediate_tensors = self.sync_and_slice_intermediate_tensors( intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_tokens, None, False) num_tokens, None, False)
if cudagraph_runtime_mode == CUDAGraphMode.NONE:
batch_descriptor = None
else:
# filter out the valid batch descriptor
_cg_mode, batch_descriptor = \
self.cudagraph_dispatcher.dispatch(
BatchDescriptor(num_tokens=num_tokens,
uniform_decode=uniform_decode))
# sanity check
assert cudagraph_runtime_mode == _cg_mode, (
f"Cudagraph runtime mode mismatch at dummy_run. "
f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.")
with self.maybe_randomize_inputs(input_ids), set_forward_context( with self.maybe_randomize_inputs(input_ids), set_forward_context(
attn_metadata, attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=num_tokens, num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp): num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor):
outputs = self.model( outputs = self.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
...@@ -2436,7 +2515,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2436,7 +2515,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dtype=torch.int32, dtype=torch.int32,
device=self.device) device=self.device)
model = cast(VllmModelForPooling, self.model) model = cast(VllmModelForPooling, self.get_model())
dummy_pooling_params = PoolingParams(task=task) dummy_pooling_params = PoolingParams(task=task)
to_update = model.pooler.get_pooling_updates(task) to_update = model.pooler.get_pooling_updates(task)
to_update.apply(dummy_pooling_params) to_update.apply(dummy_pooling_params)
...@@ -2546,12 +2625,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2546,12 +2625,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
gc.collect() gc.collect()
def capture_model(self) -> None: def capture_model(self) -> None:
if not self.use_cuda_graph: if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
logger.warning( logger.warning(
"Skipping CUDA graph capture. To turn on CUDA graph capture, " "Skipping CUDA graph capture. To turn on CUDA graph capture, "
"set -O %s and ensure `use_cudagraph` was not manually set to " "ensure `cudagraph_mode` was not manually set to `NONE`")
"False", CompilationLevel.PIECEWISE)
return return
else:
self.initialize_cudagraph_capture()
compilation_counter.num_gpu_runner_capture_triggers += 1 compilation_counter.num_gpu_runner_capture_triggers += 1
...@@ -2576,25 +2656,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2576,25 +2656,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Trigger CUDA graph capture for specific shapes. # Trigger CUDA graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes # Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes. # can reuse the memory pool allocated for the large shapes.
set_cudagraph_capturing_enabled(True)
with freeze_gc(), graph_capture(device=self.device): with freeze_gc(), graph_capture(device=self.device):
full_cg = self.full_cuda_graph cudagraph_mode = self.compilation_config.cudagraph_mode
# Only rank 0 should print progress bar during capture if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
compilation_cases = reversed(self.cudagraph_batch_sizes) cudagraph_runtime_mode = cudagraph_mode.mixed_mode()
if is_global_first_rank():
compilation_cases = tqdm( compilation_cases = list(reversed(self.cudagraph_batch_sizes))
list(compilation_cases), self._capture_cudagraphs(
disable=not self.load_config.use_tqdm_on_load, compilation_cases,
desc="Capturing CUDA graph shapes") cudagraph_runtime_mode=cudagraph_runtime_mode,
for num_tokens in compilation_cases: uniform_decode=False)
# We skip EPLB here since we don't want to record dummy metrics
for _ in range( # Capture full cudagraph for uniform decode batches if we have
self.compilation_config.cudagraph_num_of_warmups): # dont already have full mixed prefill-decode cudagraphs
self._dummy_run(num_tokens, if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \
capture_attn_cudagraph=full_cg, cudagraph_mode.separate_routine():
skip_eplb=True) max_num_tokens = self.scheduler_config.max_num_seqs * \
self._dummy_run(num_tokens, self.uniform_decode_query_len
capture_attn_cudagraph=full_cg, decode_cudagraph_batch_sizes = [
skip_eplb=True) x for x in self.cudagraph_batch_sizes if
x <= max_num_tokens and x >= self.uniform_decode_query_len
]
compilation_cases_decode = list(
reversed(decode_cudagraph_batch_sizes))
self._capture_cudagraphs(
compilation_cases=compilation_cases_decode,
cudagraph_runtime_mode=CUDAGraphMode.FULL,
uniform_decode=True)
# Disable cudagraph capturing globally, so that any unexpected cudagraph
# capture after this point is detected and raises an error.
# Note: we don't do this inside the graph_capture context manager because
# we may add lazy capturing in the future, which would still allow
# capturing after this point.
set_cudagraph_capturing_enabled(False)
end_time = time.perf_counter() end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0] end_free_gpu_memory = torch.cuda.mem_get_info()[0]
...@@ -2604,6 +2700,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2604,6 +2700,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, cuda_graph_size / (1 << 30)) elapsed_time, cuda_graph_size / (1 << 30))
def _capture_cudagraphs(self, compilation_cases: list[int],
                        cudagraph_runtime_mode: CUDAGraphMode,
                        uniform_decode: bool):
    """Warm up, then run one capture pass per size in `compilation_cases`.

    Args:
        compilation_cases: token counts to run, processed in the given
            order.
        cudagraph_runtime_mode: mode passed to the capturing `_dummy_run`;
            must be CUDAGraphMode.FULL or CUDAGraphMode.PIECEWISE.
        uniform_decode: forwarded to `_dummy_run`; also selects the
            "decode" vs. "mixed prefill-decode" progress-bar label.
    """
    capturable_modes = [CUDAGraphMode.FULL, CUDAGraphMode.PIECEWISE]
    assert cudagraph_runtime_mode != CUDAGraphMode.NONE and \
        cudagraph_runtime_mode in capturable_modes

    # Only rank 0 should print a progress bar during capture.
    cases = compilation_cases
    if is_global_first_rank():
        batch_kind = "decode" if uniform_decode else "mixed prefill-decode"
        cases = tqdm(
            compilation_cases,
            disable=not self.load_config.use_tqdm_on_load,
            desc="Capturing CUDA graphs ({}, {})".format(
                batch_kind, cudagraph_runtime_mode.name))

    # Warm-up runs always use CUDAGraphMode.NONE. Whether attention
    # metadata is built during warm-up is a separate decision, made via
    # `force_attention`: it is requested only when the capture mode is
    # FULL, and not when it is PIECEWISE.
    warm_up_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL

    # EPLB is skipped here since we don't want to record dummy metrics.
    for num_tokens in cases:
        for _ in range(self.compilation_config.cudagraph_num_of_warmups):
            self._dummy_run(num_tokens,
                            cudagraph_runtime_mode=CUDAGraphMode.NONE,
                            force_attention=warm_up_attention,
                            uniform_decode=uniform_decode,
                            skip_eplb=True)
        # The run that uses the requested cudagraph mode.
        self._dummy_run(num_tokens,
                        cudagraph_runtime_mode=cudagraph_runtime_mode,
                        uniform_decode=uniform_decode,
                        skip_eplb=True)
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
""" """
Initialize the attention backends and attention metadata builders. Initialize the attention backends and attention metadata builders.
...@@ -2648,25 +2779,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2648,25 +2779,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata_builder_i, attn_metadata_builder_i,
layer_names) layer_names)
attn_groups.append(attn_group) attn_groups.append(attn_group)
if self.full_cuda_graph:
if attn_metadata_builder_i.attn_cudagraph_support == \
AttentionCGSupport.NEVER:
raise ValueError(
f"Full CUDAGraph not supported for "
f"{attn_backend.__name__}. Turn off "
f"CompilationConfig.full_cuda_graph or use a "
f" different attention backend.")
if attn_metadata_builder_i.attn_cudagraph_support == \
AttentionCGSupport.PURE_DECODE_ONLY:
# Limit the max cudagraph size to the max number of
# sequences for pure decode only cudagraph backend,
# whose max_query_len is 1.
self.cudagraph_batch_sizes = [
size for size in self.cudagraph_batch_sizes
if size <= self.scheduler_config.max_num_seqs
]
return attn_groups return attn_groups
for kv_cache_group_spec in kv_cache_config.kv_cache_groups: for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
...@@ -2734,6 +2846,75 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2734,6 +2846,75 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
"All or none of the layers are expected to be encoder-only" "All or none of the layers are expected to be encoder-only"
self.is_encoder_only_model = True self.is_encoder_only_model = True
def initialize_cudagraph_capture(self) -> None:
    """Reconcile the configured cudagraph mode with attention support.

    Finds the lowest `cudagraph_support` level among the attention
    metadata builders, then either downgrades
    `compilation_config.cudagraph_mode` (logging a warning) or raises
    when the requested mode cannot be honored. Finally initializes the
    cudagraph dispatcher keys with the resulting mode.

    Raises:
        ValueError: if full cudagraphs are requested while some builder
            reports AttentionCGSupport.NEVER.
    """
    compilation_config = self.compilation_config

    # Track the least capable builder (by enum value) and its class name.
    min_cg_support = AttentionCGSupport.ALWAYS
    min_cg_builder_name = None
    for attn_group in self._attn_group_iterator():
        builder = attn_group.metadata_builder
        if builder.cudagraph_support.value >= min_cg_support.value:
            continue
        min_cg_support = builder.cudagraph_support
        min_cg_builder_name = builder.__class__.__name__

    # Flexibly resolve the cudagraph mode, starting from the configured one.
    cudagraph_mode = compilation_config.cudagraph_mode

    # Step 1: full cudagraphs for mixed batches need ALWAYS support.
    if cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL \
            and min_cg_support != AttentionCGSupport.ALWAYS:
        msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
               f"with {min_cg_builder_name} backend (support: "
               f"{min_cg_support})")
        if min_cg_support == AttentionCGSupport.NEVER:
            # No full cudagraph is possible at all, so fail right away.
            msg += "; please try cudagraph_mode=PIECEWISE, and "\
                "make sure compilation level is piecewise"
            raise ValueError(msg)

        # Otherwise fall back to a mode without full mixed-batch graphs.
        if compilation_config.splitting_ops_contain_attention():
            msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE"
            fallback_mode = CUDAGraphMode.FULL_AND_PIECEWISE
        else:
            msg += "; setting cudagraph_mode=FULL_DECODE_ONLY"
            fallback_mode = CUDAGraphMode.FULL_DECODE_ONLY
        cudagraph_mode = fallback_mode
        compilation_config.cudagraph_mode = fallback_mode
        logger.warning(msg)

    # Step 2: spec-decode combined with full decode cudagraphs needs at
    # least UNIFORM_BATCH support.
    wants_full_decode = cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
    if (wants_full_decode and self.uniform_decode_query_len > 1
            and min_cg_support.value
            < AttentionCGSupport.UNIFORM_BATCH.value):
        msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported"
               f" with spec-decode for attention backend "
               f"{min_cg_builder_name} (support: {min_cg_support})")
        if compilation_config.splitting_ops_contain_attention():
            msg += "; setting cudagraph_mode=PIECEWISE"
            fallback_mode = CUDAGraphMode.PIECEWISE
        else:
            msg += "; setting cudagraph_mode=NONE"
            fallback_mode = CUDAGraphMode.NONE
        cudagraph_mode = fallback_mode
        compilation_config.cudagraph_mode = fallback_mode
        logger.warning(msg)

    # Step 3: re-check after the automatic downgrades that any remaining
    # full cudagraph request is actually supported.
    if cudagraph_mode.has_full_cudagraphs() \
            and min_cg_support == AttentionCGSupport.NEVER:
        raise ValueError(f"CUDAGraphMode.{cudagraph_mode.name} is not "
                         f"supported with {min_cg_builder_name} backend ("
                         f"support:{min_cg_support}) "
                         "; please try cudagraph_mode=PIECEWISE, "
                         "and make sure compilation level is piecewise")

    # Initialize the dispatching keys here, once the checks above have
    # settled the mode.
    self.cudagraph_dispatcher.initialize_cudagraph_keys(
        compilation_config.cudagraph_mode, self.uniform_decode_query_len)
def calculate_reorder_batch_threshold(self) -> None: def calculate_reorder_batch_threshold(self) -> None:
""" """
Check that if any backends reorder batches; that the reordering Check that if any backends reorder batches; that the reordering
......
...@@ -322,16 +322,11 @@ class Worker(WorkerBase): ...@@ -322,16 +322,11 @@ class Worker(WorkerBase):
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
max_num_reqs = min(self.scheduler_config.max_num_seqs, max_num_reqs = min(self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens) self.scheduler_config.max_num_batched_tokens)
# activate building attn_metadata for this dummy run to avoid
# potential illegal memory access for full cudagraph relay.
attn_cudagraph = self.compilation_config.full_cuda_graph and\
not self.model_config.enforce_eager
# We skip EPLB here since we don't want to record dummy metrics # We skip EPLB here since we don't want to record dummy metrics
hidden_states, last_hidden_states = \ hidden_states, last_hidden_states = \
self.model_runner._dummy_run( self.model_runner._dummy_run(
num_tokens=max_num_reqs, num_tokens=max_num_reqs,
capture_attn_cudagraph=attn_cudagraph,
skip_eplb=True, skip_eplb=True,
) )
if self.model_runner.is_pooling_model: if self.model_runner.is_pooling_model:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment