Unverified Commit 74f441f4 authored by fhl2000's avatar fhl2000 Committed by GitHub
Browse files

[Core] Allow full cudagraph with separate attention routines and orthogonal to...


[Core] Allow full cudagraph with separate attention routines and orthogonal to compilation, add support for FA2 and FlashInfer (#20059)
Signed-off-by: fhl <2410591650@qq.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
parent a0632a3e
...@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Optional ...@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Optional
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.config import CUDAGraphMode
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
...@@ -100,16 +101,17 @@ class XPUPlatform(Platform): ...@@ -100,16 +101,17 @@ class XPUPlatform(Platform):
# Instances created using VllmConfig() typically have model_config as # Instances created using VllmConfig() typically have model_config as
# None by default. The modification involves adding a check to prevent # None by default. The modification involves adding a check to prevent
# potential null exceptions check and update model config. # potential null exceptions check and update model config.
if model_config is not None: if model_config is not None and model_config.dtype == torch.bfloat16 \
if model_config.dtype == torch.bfloat16: and not cls.device_support_bf16():
bf16_supported = cls.device_support_bf16() model_config.dtype = torch.float16
if not bf16_supported:
model_config.dtype = torch.float16 compilation_config = vllm_config.compilation_config
if not model_config.enforce_eager: if compilation_config.cudagraph_mode is None or \
logger.warning( compilation_config.cudagraph_mode.max_cudagraph_mode() \
"CUDA graph is not supported on XPU, fallback to the eager " != CUDAGraphMode.NONE:
"mode.") logger.info("[XPU] CUDA graph is not supported on XPU, "
model_config.enforce_eager = True "disabling cudagraphs.")
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
# check and update parallel config # check and update parallel config
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention.""" """Attention layer with FlashAttention."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import ClassVar, Optional from typing import Optional
import numpy as np import numpy as np
import torch import torch
...@@ -154,9 +154,26 @@ def _get_sliding_window_configs( ...@@ -154,9 +154,26 @@ def _get_sliding_window_configs(
class FlashAttentionMetadataBuilder( class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]): AttentionMetadataBuilder[FlashAttentionMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ # FA3:
AttentionCGSupport.NEVER if get_flash_attn_version() == 2 \ # Supports full cudagraphs for all cases.
else AttentionCGSupport.ALWAYS #
# FA2:
# For FA2, if a graph is captured with max_query_len=1 (which is what we
# capture by default for num_tokens <= max_num_seqs when there is no
# spec-decode), then these graphs will not work for mixed prefill-decode
# (unlike FA3). This is due to special max_query_len=1 packed-GQA handling
# in FA2.
# In summary, if we are running with spec-decode the graphs would work
# for mixed prefill-decode and uniform-decode. But for non-spec-decode
# the graphs would not work for mixed prefill-decode; sort of the inverse
# of UNIFORM_SINGLE_TOKEN_DECODE.
# There's probably a better way to describe this using `AttentionCGSupport`,
# but for now just set it to `UNIFORM_BATCH` to get us to drop down
# to FULL_AND_PIECEWISE.
# TODO(luka, lucas): audit FA2 as part of:
# https://github.com/vllm-project/vllm/issues/22945
cudagraph_support = AttentionCGSupport.ALWAYS \
if get_flash_attn_version() == 3 else AttentionCGSupport.UNIFORM_BATCH
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
...@@ -177,17 +194,13 @@ class FlashAttentionMetadataBuilder( ...@@ -177,17 +194,13 @@ class FlashAttentionMetadataBuilder(
self.max_num_splits = 0 # No upper bound on the number of splits. self.max_num_splits = 0 # No upper bound on the number of splits.
self.aot_schedule = (get_flash_attn_version() == 3) self.aot_schedule = (get_flash_attn_version() == 3)
self.use_full_cuda_graph = self.compilation_config.full_cuda_graph
if self.use_full_cuda_graph: self.use_full_cuda_graph = \
if not self.aot_schedule: self.compilation_config.cudagraph_mode.has_full_cudagraphs()
raise ValueError(
"AoT scheduling is required for full cuda graph.") if self.use_full_cuda_graph and self.aot_schedule:
capture_sizes = self.compilation_config.cudagraph_capture_sizes self.max_cudagraph_size = self.compilation_config.max_capture_size
if not capture_sizes:
raise ValueError(
"cudagraph_capture_sizes should not be None when "
"full_cuda_graph is True.")
self.max_cudagraph_size = max(capture_sizes)
if self.max_cudagraph_size > 992: if self.max_cudagraph_size > 992:
# This condition derives from FA3's internal heuristic. # This condition derives from FA3's internal heuristic.
# TODO(woosuk): Support larger cudagraph sizes. # TODO(woosuk): Support larger cudagraph sizes.
...@@ -310,9 +323,9 @@ class FlashAttentionMetadataBuilder( ...@@ -310,9 +323,9 @@ class FlashAttentionMetadataBuilder(
seqlens=seq_lens, seqlens=seq_lens,
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
causal=causal) causal=causal)
# For FA3 + full cudagraph
if self.use_full_cuda_graph: max_num_splits = 0
assert scheduler_metadata is not None if self.use_full_cuda_graph and scheduler_metadata is not None:
n = scheduler_metadata.shape[0] n = scheduler_metadata.shape[0]
self.scheduler_metadata[:n] = scheduler_metadata self.scheduler_metadata[:n] = scheduler_metadata
# NOTE(woosuk): We should zero out the rest of the scheduler # NOTE(woosuk): We should zero out the rest of the scheduler
...@@ -322,14 +335,12 @@ class FlashAttentionMetadataBuilder( ...@@ -322,14 +335,12 @@ class FlashAttentionMetadataBuilder(
self.scheduler_metadata[n:] = 0 self.scheduler_metadata[n:] = 0
scheduler_metadata = self.scheduler_metadata[:n] scheduler_metadata = self.scheduler_metadata[:n]
max_num_splits = 0 if num_actual_tokens <= self.max_cudagraph_size:
if (self.use_full_cuda_graph # NOTE(woosuk): Setting num_splits > 1 may increase the memory
and num_actual_tokens <= self.max_cudagraph_size): # usage, because the intermediate buffers of size [num_splits,
# NOTE(woosuk): Setting num_splits > 1 may increase the memory # num_heads, num_tokens, head_size] are allocated. Therefore,
# usage, because the intermediate buffers of size [num_splits, # we only set num_splits when using cuda graphs.
# num_heads, num_tokens, head_size] are allocated. Therefore, max_num_splits = self.max_num_splits
# we only set num_splits when using cuda graphs.
max_num_splits = self.max_num_splits
attn_metadata = FlashAttentionMetadata( attn_metadata = FlashAttentionMetadata(
num_actual_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens,
...@@ -350,11 +361,6 @@ class FlashAttentionMetadataBuilder( ...@@ -350,11 +361,6 @@ class FlashAttentionMetadataBuilder(
causal=causal) causal=causal)
return attn_metadata return attn_metadata
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
# Full CUDA Graph always supported (FA2 support checked separately)
return True
def use_cascade_attention(self, *args, **kwargs) -> bool: def use_cascade_attention(self, *args, **kwargs) -> bool:
return use_cascade_attention(*args, **kwargs) return use_cascade_attention(*args, **kwargs)
......
...@@ -17,7 +17,7 @@ from flashinfer.prefill import trtllm_batch_context_with_kv_cache ...@@ -17,7 +17,7 @@ from flashinfer.prefill import trtllm_batch_context_with_kv_cache
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionType) AttentionType)
from vllm.config import VllmConfig from vllm.config import CUDAGraphMode, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import cdiv, is_pin_memory_available from vllm.utils import cdiv, is_pin_memory_available
from vllm.utils.flashinfer import use_trtllm_attention from vllm.utils.flashinfer import use_trtllm_attention
...@@ -183,8 +183,8 @@ class FlashInferMetadata: ...@@ -183,8 +183,8 @@ class FlashInferMetadata:
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
reorder_batch_threshold: ClassVar[int] = 1 reorder_batch_threshold: ClassVar[int] = 1
...@@ -203,7 +203,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -203,7 +203,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.kv_cache_spec.block_size) self.kv_cache_spec.block_size)
max_num_reqs = vllm_config.scheduler_config.max_num_seqs max_num_reqs = vllm_config.scheduler_config.max_num_seqs
max_num_pages = max_num_reqs * max_num_pages_per_req max_num_pages = max_num_reqs * max_num_pages_per_req
self.enable_cuda_graph = self.compilation_config.full_cuda_graph self.enable_cuda_graph = self.compilation_config.cudagraph_mode.\
decode_mode() == CUDAGraphMode.FULL
if self.enable_cuda_graph: if self.enable_cuda_graph:
# For full cudagraph capture, one `decode_wrapper` for each batch # For full cudagraph capture, one `decode_wrapper` for each batch
# size is needed for FlashInfer. # size is needed for FlashInfer.
...@@ -586,10 +587,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -586,10 +587,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
return self.build(0, m) return self.build(0, m)
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
return common_attn_metadata.max_query_len == 1
def use_cascade_attention(self, *args, **kwargs) -> bool: def use_cascade_attention(self, *args, **kwargs) -> bool:
if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype: if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
# TODO: The cascade wrapper currently does not support setting # TODO: The cascade wrapper currently does not support setting
......
...@@ -89,8 +89,8 @@ class Mamba2AttentionMetadata: ...@@ -89,8 +89,8 @@ class Mamba2AttentionMetadata:
class Mamba2AttentionMetadataBuilder( class Mamba2AttentionMetadataBuilder(
AttentionMetadataBuilder[Mamba2AttentionMetadata]): AttentionMetadataBuilder[Mamba2AttentionMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
reorder_batch_threshold: ClassVar[int] = 1 reorder_batch_threshold: ClassVar[int] = 1
...@@ -203,7 +203,3 @@ class Mamba2AttentionMetadataBuilder( ...@@ -203,7 +203,3 @@ class Mamba2AttentionMetadataBuilder(
m.max_query_len = 1 # decode-only m.max_query_len = 1 # decode-only
return self.build(0, m) return self.build(0, m)
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
return common_attn_metadata.max_query_len == 1
...@@ -575,7 +575,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -575,7 +575,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
"MLA only supports decode-only full CUDAGraph capture. " \ "MLA only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq." "Make sure all cudagraph capture sizes <= max_num_seq."
m.max_query_len = 1 # decode-only assert m.max_query_len == 1 # decode-only
return self.build(0, m) return self.build(0, m)
...@@ -728,10 +728,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -728,10 +728,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
return attn_metadata return attn_metadata
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
return common_attn_metadata.max_query_len == 1
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
""" """
......
...@@ -22,7 +22,7 @@ logger = init_logger(__name__) ...@@ -22,7 +22,7 @@ logger = init_logger(__name__)
class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]): class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
# enable full CUDA Graph support for decode-only capture # enable full CUDA Graph support for decode-only capture
attn_cudagraph_support: ClassVar[ attn_cudagraph_support: ClassVar[
AttentionCGSupport] = AttentionCGSupport.PURE_DECODE_ONLY AttentionCGSupport] = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
class CutlassMLABackend(MLACommonBackend): class CutlassMLABackend(MLACommonBackend):
......
...@@ -55,8 +55,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): ...@@ -55,8 +55,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY AttentionCGSupport.UNIFORM_BATCH
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
...@@ -73,7 +73,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): ...@@ -73,7 +73,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
device_properties = torch.cuda.get_device_properties(self.device) device_properties = torch.cuda.get_device_properties(self.device)
num_sms = device_properties.multi_processor_count num_sms = device_properties.multi_processor_count
if self.compilation_config.full_cuda_graph: if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.cg_buf_tile_scheduler_metadata = torch.zeros( self.cg_buf_tile_scheduler_metadata = torch.zeros(
# Upper bound on size (<= #SMs, TileSchedulerMetaDataSize) # Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
# TileSchedulerMetaDataSize = 8 # TileSchedulerMetaDataSize = 8
...@@ -95,7 +95,10 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): ...@@ -95,7 +95,10 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
1, # MQA for the decode path 1, # MQA for the decode path
) )
if self.compilation_config.full_cuda_graph: # TODO: we can disambiguate between decode and mixed-prefill decode here
# so we can only use the persistent buffer if a cudagraph is actually
# being used.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
assert self.cg_buf_tile_scheduler_metadata is not None assert self.cg_buf_tile_scheduler_metadata is not None
assert self.cg_buf_num_splits is not None assert self.cg_buf_num_splits is not None
......
...@@ -65,8 +65,10 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): ...@@ -65,8 +65,10 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ # TODO(luka, lucas): audit this as part of:
AttentionCGSupport.PURE_DECODE_ONLY # https://github.com/vllm-project/vllm/issues/22945
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
...@@ -82,7 +84,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ...@@ -82,7 +84,10 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
max_num_pages = max_num_reqs * max_num_pages_per_req max_num_pages = max_num_reqs * max_num_pages_per_req
# Preparing persistent buffers # Preparing persistent buffers
if vllm_config.compilation_config.full_cuda_graph: # TODO: we can disambiguate between decode and mixed-prefill decode here
# so we can only use the persistent buffer if a cudagraph is actually
# being used.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.paged_kv_indptr = torch.zeros(max_num_reqs + 1, self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
dtype=torch.int32, dtype=torch.int32,
device=device) device=device)
...@@ -120,7 +125,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): ...@@ -120,7 +125,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
block_table_bounds.cumsum(dim=0, dtype=torch.int32) block_table_bounds.cumsum(dim=0, dtype=torch.int32)
]) ])
if self.compilation_config.full_cuda_graph: if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
num_actual_pages = paged_kv_indices.size(0) num_actual_pages = paged_kv_indices.size(0)
......
...@@ -311,11 +311,6 @@ class AiterFlashAttentionMetadataBuilder( ...@@ -311,11 +311,6 @@ class AiterFlashAttentionMetadataBuilder(
) )
return attn_metadata return attn_metadata
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
# Full CUDA Graph always supported (FA2 support checked separately)
return True
def use_cascade_attention(self, *args, **kwargs) -> bool: def use_cascade_attention(self, *args, **kwargs) -> bool:
return False return False
......
...@@ -58,8 +58,7 @@ class TritonAttentionMetadata: ...@@ -58,8 +58,7 @@ class TritonAttentionMetadata:
class TritonAttentionMetadataBuilder( class TritonAttentionMetadataBuilder(
AttentionMetadataBuilder[TritonAttentionMetadata]): AttentionMetadataBuilder[TritonAttentionMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
AttentionCGSupport.ALWAYS
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
...@@ -132,11 +131,6 @@ class TritonAttentionMetadataBuilder( ...@@ -132,11 +131,6 @@ class TritonAttentionMetadataBuilder(
) )
return attn_metadata return attn_metadata
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
# Full CUDA Graph always supported
return True
class TritonAttentionBackend(AttentionBackend): class TritonAttentionBackend(AttentionBackend):
......
...@@ -158,18 +158,21 @@ class AttentionCGSupport(enum.Enum): ...@@ -158,18 +158,21 @@ class AttentionCGSupport(enum.Enum):
Here we do not consider the cascade attention, as currently Here we do not consider the cascade attention, as currently
it is never cudagraph supported.""" it is never cudagraph supported."""
ALWAYS = 3
"""Cudagraph always supported; supports mixed-prefill-decode"""
UNIFORM_BATCH = 2
"""Cudagraph supported for batches that only contain query lengths that are
the same; this can be used for spec-decode,
i.e. "decodes" are 1 + num_speculative_tokens"""
UNIFORM_SINGLE_TOKEN_DECODE = 1
"""Cudagraph supported for batches that only contain query_len==1 decodes"""
NEVER = 0 NEVER = 0
"""NO cudagraph support""" """NO cudagraph support"""
PURE_DECODE_ONLY = 1
"""Cudagraph supported for pure decode, need to run without
cudagraph for mixed prefill-decode batches"""
ALWAYS = 2
"""Cudagraph always supported"""
class AttentionMetadataBuilder(abc.ABC, Generic[M]): class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# Does this backend/builder support CUDA Graphs for attention. # Does this backend/builder support CUDA Graphs for attention (default: no).
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.NEVER AttentionCGSupport.NEVER
# Does this backend/builder reorder the batch? # Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query # If not, set this to None. Otherwise set it to the query
...@@ -199,13 +202,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): ...@@ -199,13 +202,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
""" """
raise NotImplementedError raise NotImplementedError
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
"""
Can this batch (with given metadata) use CUDA Graphs for attention.
"""
return False
def build_for_cudagraph_capture( def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata) -> M: self, common_attn_metadata: CommonAttentionMetadata) -> M:
""" """
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor
from vllm.logger import init_logger
logger = init_logger(__name__)
class CudagraphDispatcher:
    """
    Runtime dispatcher that selects which set of cudagraphs a batch may use.

    Two sets of dispatch keys are kept: one for the PIECEWISE runtime mode
    and one for the FULL runtime mode. Which keys get registered depends on
    attention support and on the cudagraph mode set in CompilationConfig.
    The registered keys are the only source of truth for which cudagraphs
    can be dispatched at runtime.

    At runtime, `dispatch` takes an input key (batch descriptor) and yields
    the runtime cudagraph mode (FULL, PIECEWISE, or NONE for no cudagraph)
    together with the valid key to use. The result is communicated via the
    forward context; the cudagraph wrappers then trust the dispatched key to
    either capture or replay (when the mode matches), or pass through to the
    underlying runnable without cudagraph (when the mode does not match or
    is NONE).
    """

    def __init__(self, vllm_config: VllmConfig):
        """Store the config and create the empty per-mode key sets.

        Asserts that a cudagraph mode requiring piecewise compilation is
        only combined with CompilationLevel.PIECEWISE and splitting ops
        that contain attention.
        """
        compilation_config = vllm_config.compilation_config
        self.vllm_config = vllm_config
        self.compilation_config = compilation_config
        self.cudagraph_mode = compilation_config.cudagraph_mode

        # Valid cudagraph dispatching keys, grouped by runtime mode.
        self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = {
            mode: set()
            for mode in (CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL)
        }

        assert (
            not self.cudagraph_mode.requires_piecewise_compilation()
            or (self.compilation_config.level == CompilationLevel.PIECEWISE
                and self.compilation_config.splitting_ops_contain_attention())
        ), ("Compilation level should be CompilationLevel.PIECEWISE when "
            "cudagraph_mode piecewise cudagraphs is used, "
            f"cudagraph_mode={self.cudagraph_mode}, "
            f"compilation_level={self.compilation_config.level}, "
            f"splitting_ops={self.compilation_config.splitting_ops}")

        # Flipped to True once initialize_cudagraph_keys() has run.
        self.keys_initialized = False

    def add_cudagraph_key(self, runtime_mode: CUDAGraphMode,
                          batch_descriptor: BatchDescriptor):
        """Register `batch_descriptor` as a valid key for `runtime_mode`.

        Only PIECEWISE and FULL are accepted as runtime modes.
        """
        assert runtime_mode in (
            CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
        ), f"Invalid cudagraph runtime mode: {runtime_mode}"
        self.cudagraph_keys[runtime_mode].add(batch_descriptor)

    def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode,
                                  uniform_decode_query_len: int):
        """Populate the dispatch keys for the given cudagraph mode.

        This should be called only after the attention backend is
        initialized.
        """
        # Note: every possibly-valid key is created here, without any
        # guarantee that each one will actually be used. For example, keys
        # for piecewise cudagraphs are created whenever piecewise
        # compilation is on (which is always valid), but for an attention
        # backend that supports a unified routine, capturing/replaying the
        # piecewise cudagraphs may never be triggered depending on
        # CompilationConfig.cudagraph_mode. In addition, if lazy capturing
        # is allowed in a future PR, some keys may never be triggered.
        capture_sizes = self.compilation_config.cudagraph_capture_sizes

        mode_for_mixed = cudagraph_mode.mixed_mode()
        if mode_for_mixed != CUDAGraphMode.NONE:
            for num_tokens in capture_sizes:
                descriptor = BatchDescriptor(num_tokens=num_tokens,
                                             uniform_decode=False)
                self.add_cudagraph_key(mode_for_mixed, descriptor)

        # When the decode mode is FULL and mixed batches do not already get
        # full cudagraphs, add dedicated uniform-decode keys here.
        wants_decode_keys = (
            cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
            and cudagraph_mode.separate_routine())
        if wants_decode_keys:
            max_seqs = self.vllm_config.scheduler_config.max_num_seqs
            decode_token_limit = uniform_decode_query_len * max_seqs
            decode_sizes = [
                size for size in capture_sizes
                if uniform_decode_query_len <= size <= decode_token_limit
            ]
            for num_tokens in decode_sizes:
                descriptor = BatchDescriptor(num_tokens=num_tokens,
                                             uniform_decode=True)
                self.add_cudagraph_key(CUDAGraphMode.FULL, descriptor)

        self.keys_initialized = True

    def dispatch(
        self, batch_descriptor: BatchDescriptor
    ) -> tuple[CUDAGraphMode, Optional[BatchDescriptor]]:
        """
        Map a batch descriptor to a cudagraph runtime mode and a valid key.

        The returned descriptor may differ from the input one, because a
        uniform batch can be dispatched to a graph that supports a more
        general batch (uniform to non-uniform). When no cudagraph applies,
        `(CUDAGraphMode.NONE, None)` is returned.
        """
        # Without initialized keys there is nothing to dispatch to.
        if not self.keys_initialized:
            logger.warning_once("cudagraph dispatching keys are not "
                                "initialized. No cudagraph will be used.")
            return CUDAGraphMode.NONE, None

        full_keys = self.cudagraph_keys[CUDAGraphMode.FULL]

        # 1) An exact match among the full cudagraph keys.
        if batch_descriptor in full_keys:
            return CUDAGraphMode.FULL, batch_descriptor

        # 2) The non-uniform variant among the full cudagraph keys.
        relaxed_key = batch_descriptor.non_uniform
        if relaxed_key in full_keys:
            return CUDAGraphMode.FULL, relaxed_key

        # 3) The non-uniform variant among the more "general" piecewise
        #    cudagraph keys.
        if relaxed_key in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
            return CUDAGraphMode.PIECEWISE, relaxed_key

        # 4) Nothing matched: run without cudagraphs.
        return CUDAGraphMode.NONE, None
This diff is collapsed.
...@@ -322,16 +322,11 @@ class Worker(WorkerBase): ...@@ -322,16 +322,11 @@ class Worker(WorkerBase):
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
max_num_reqs = min(self.scheduler_config.max_num_seqs, max_num_reqs = min(self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens) self.scheduler_config.max_num_batched_tokens)
# activate building attn_metadata for this dummy run to avoid
# potential illegal memory access for full cudagraph relay.
attn_cudagraph = self.compilation_config.full_cuda_graph and\
not self.model_config.enforce_eager
# We skip EPLB here since we don't want to record dummy metrics # We skip EPLB here since we don't want to record dummy metrics
hidden_states, last_hidden_states = \ hidden_states, last_hidden_states = \
self.model_runner._dummy_run( self.model_runner._dummy_run(
num_tokens=max_num_reqs, num_tokens=max_num_reqs,
capture_attn_cudagraph=attn_cudagraph,
skip_eplb=True, skip_eplb=True,
) )
if self.model_runner.is_pooling_model: if self.model_runner.is_pooling_model:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment