Unverified commit 23322431, authored by fhl2000, committed by GitHub
Browse files

[V1][CUDA] Full cudagraph support for FlashInfer (#21367)

parent 3654847d
...@@ -25,7 +25,8 @@ if is_flash_attn_varlen_func_available(): ...@@ -25,7 +25,8 @@ if is_flash_attn_varlen_func_available():
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import cdiv from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
get_kv_cache_layout) get_kv_cache_layout)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
...@@ -153,7 +154,9 @@ def _get_sliding_window_configs( ...@@ -153,7 +154,9 @@ def _get_sliding_window_configs(
class FlashAttentionMetadataBuilder( class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]): AttentionMetadataBuilder[FlashAttentionMetadata]):
full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3 attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.NEVER if get_flash_attn_version() == 2 \
else AttentionCGSupport.ALWAYS
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
......
This diff is collapsed.
...@@ -18,6 +18,7 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend, ...@@ -18,6 +18,7 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl, MLACommonImpl,
MLACommonMetadata, MLACommonMetadata,
MLACommonMetadataBuilder) MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -54,7 +55,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): ...@@ -54,7 +55,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
full_cudagraph_supported: ClassVar[bool] = True # Decode-only attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
......
...@@ -17,6 +17,7 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend, ...@@ -17,6 +17,7 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl, MLACommonImpl,
MLACommonMetadata, MLACommonMetadata,
MLACommonMetadataBuilder) MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
# yapf: enable # yapf: enable
...@@ -64,7 +65,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): ...@@ -64,7 +65,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
full_cudagraph_supported: ClassVar[bool] = True # decode only attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
......
...@@ -18,7 +18,8 @@ from vllm.config import VllmConfig ...@@ -18,7 +18,8 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata) CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
...@@ -57,7 +58,8 @@ class TritonAttentionMetadata: ...@@ -57,7 +58,8 @@ class TritonAttentionMetadata:
class TritonAttentionMetadataBuilder( class TritonAttentionMetadataBuilder(
AttentionMetadataBuilder[TritonAttentionMetadata]): AttentionMetadataBuilder[TritonAttentionMetadata]):
full_cudagraph_supported: ClassVar[bool] = True attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.ALWAYS
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc import abc
import enum
import functools import functools
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass, make_dataclass from dataclasses import dataclass, make_dataclass
...@@ -65,9 +66,24 @@ class CommonAttentionMetadata: ...@@ -65,9 +66,24 @@ class CommonAttentionMetadata:
M = TypeVar("M") M = TypeVar("M")
@enum.unique
class AttentionCGSupport(enum.Enum):
    """Level of CUDA graph support offered by an attention backend.

    Cascade attention is not considered here, as it currently never
    supports CUDA graphs.
    """

    NEVER = 0
    """No CUDA graph support."""

    PURE_DECODE_ONLY = 1
    """CUDA graphs are supported for pure-decode batches only; mixed
    prefill-decode batches must run without CUDA graphs."""

    ALWAYS = 2
    """CUDA graphs are always supported."""
class AttentionMetadataBuilder(abc.ABC, Generic[M]): class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# Does this backend/builder support CUDA Graphs for attention. # Does this backend/builder support CUDA Graphs for attention.
full_cudagraph_supported: ClassVar[bool] = False attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.NEVER
@abstractmethod @abstractmethod
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
......
...@@ -47,7 +47,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, ...@@ -47,7 +47,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
is_pin_memory_available, round_up, supports_dynamo) is_pin_memory_available, round_up, supports_dynamo)
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata, AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
make_kv_sharing_fast_prefill_attention_metadata, make_kv_sharing_fast_prefill_attention_metadata,
make_local_attention_virtual_batches) make_local_attention_virtual_batches)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
...@@ -2619,12 +2619,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2619,12 +2619,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.device, self.device,
) )
if (self.full_cuda_graph if self.full_cuda_graph:
and not attn_metadata_builder_i.full_cudagraph_supported): if attn_metadata_builder_i.attn_cudagraph_support == \
raise ValueError( AttentionCGSupport.NEVER:
f"Full CUDAGraph not supported for " raise ValueError(f"Full CUDAGraph not supported for "
f"{attn_backend_i.__name__}. Turn off CompilationConfig." f"{attn_backend_i.__name__}. Turn off "
f"full_cuda_graph or use a different attention backend.") f"CompilationConfig.full_cuda_graph or use a "
f" different attention backend.")
if attn_metadata_builder_i.attn_cudagraph_support == \
AttentionCGSupport.PURE_DECODE_ONLY:
# Limit the max cudagraph size to the max number of
# sequences for pure decode only cudagraph backend,
# whose max_query_len is 1.
self.cudagraph_batch_sizes = [
size for size in self.cudagraph_batch_sizes
if size <= self.scheduler_config.max_num_seqs
]
return attn_backend_i, attn_metadata_builder_i return attn_backend_i, attn_metadata_builder_i
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
......
...@@ -321,11 +321,16 @@ class Worker(WorkerBase): ...@@ -321,11 +321,16 @@ class Worker(WorkerBase):
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
max_num_reqs = min(self.scheduler_config.max_num_seqs, max_num_reqs = min(self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens) self.scheduler_config.max_num_batched_tokens)
# Activate building attn_metadata for this dummy run to avoid a
# potential illegal memory access during full cudagraph replay.
attn_cudagraph = self.compilation_config.full_cuda_graph and\
not self.model_config.enforce_eager
# We skip EPLB here since we don't want to record dummy metrics # We skip EPLB here since we don't want to record dummy metrics
hidden_states, last_hidden_states = \ hidden_states, last_hidden_states = \
self.model_runner._dummy_run( self.model_runner._dummy_run(
num_tokens=max_num_reqs, num_tokens=max_num_reqs,
capture_attn_cudagraph=attn_cudagraph,
skip_eplb=True, skip_eplb=True,
) )
if self.model_runner.is_pooling_model: if self.model_runner.is_pooling_model:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment