Unverified Commit e24e0a43 authored by Lain's avatar Lain Committed by GitHub
Browse files

[Attention] relax the head dim 512 and paged kv for sm90+FA4 (#38835)


Signed-off-by: default avatarSiyuan Fu <siyuanf@nvidia.com>
Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
parent b55d830e
...@@ -39,7 +39,7 @@ else() ...@@ -39,7 +39,7 @@ else()
FetchContent_Declare( FetchContent_Declare(
vllm-flash-attn vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG c0ec424fd8a546d0cbbf4bf050bbcfe837c55afb GIT_TAG f5bc33cfc02c744d24a2e9d50e6db656de40611c
GIT_PROGRESS TRUE GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types # Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
......
...@@ -154,6 +154,17 @@ def get_flash_attn_version( ...@@ -154,6 +154,17 @@ def get_flash_attn_version(
return None return None
def is_fa_version_supported(fa_version: int) -> bool:
try:
from vllm.vllm_flash_attn.flash_attn_interface import (
is_fa_version_supported as _is_fa_version_supported,
)
return _is_fa_version_supported(fa_version)
except ImportError:
return False
def flash_attn_supports_fp8() -> bool: def flash_attn_supports_fp8() -> bool:
return ( return (
get_flash_attn_version() == 3 get_flash_attn_version() == 3
......
...@@ -10,6 +10,7 @@ import numpy as np ...@@ -10,6 +10,7 @@ import numpy as np
import torch import torch
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
...@@ -20,6 +21,7 @@ from vllm.v1.attention.backend import ( ...@@ -20,6 +21,7 @@ from vllm.v1.attention.backend import (
from vllm.v1.attention.backends.fa_utils import ( from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_fp8, flash_attn_supports_fp8,
get_flash_attn_version, get_flash_attn_version,
is_fa_version_supported,
is_flash_attn_varlen_func_available, is_flash_attn_varlen_func_available,
) )
from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens
...@@ -45,7 +47,6 @@ from vllm.config import ( ...@@ -45,7 +47,6 @@ from vllm.config import (
from vllm.config.cache import CacheDType from vllm.config.cache import CacheDType
from vllm.distributed.parallel_state import get_dcp_group from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import cdiv, round_up from vllm.utils.math_utils import cdiv, round_up
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
...@@ -170,7 +171,13 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -170,7 +171,13 @@ class FlashAttentionBackend(AttentionBackend):
@classmethod @classmethod
def supports_head_size(cls, head_size: int) -> bool: def supports_head_size(cls, head_size: int) -> bool:
return head_size % 8 == 0 and head_size <= 256 if head_size % 8 != 0:
return False
if head_size <= 256:
return True
if is_fa_version_supported(4):
return head_size <= 512
return False
@classmethod @classmethod
def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool: def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
...@@ -618,6 +625,14 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -618,6 +625,14 @@ class FlashAttentionImpl(AttentionImpl):
requires_alibi=alibi_slopes is not None, requires_alibi=alibi_slopes is not None,
head_size=head_size, head_size=head_size,
) )
# head_size > 256 requires FA4 on SM90+; force upgrade from FA3
if (
head_size > 256
and self.vllm_flash_attn_version == 3
and current_platform.is_cuda()
and current_platform.is_device_capability_family(90)
):
self.vllm_flash_attn_version = 4
logger.info_once( logger.info_once(
"Using FlashAttention version %s", "Using FlashAttention version %s",
self.vllm_flash_attn_version, self.vllm_flash_attn_version,
......
...@@ -366,14 +366,7 @@ def flash_attn_varlen_func( ...@@ -366,14 +366,7 @@ def flash_attn_varlen_func(
) )
elif fa_version == 4: elif fa_version == 4:
assert alibi_slopes is None, "Alibi is not supported in FA4" assert alibi_slopes is None, "Alibi is not supported in FA4"
# FA4 on SM90 doesn't support paged KV; SM100+ does
from vllm.platforms import current_platform
if block_table is not None and current_platform.is_device_capability_family(90):
raise NotImplementedError(
"FA4 with paged KV is not supported on SM90 (Hopper). "
"Use FA3 or upgrade to Blackwell (SM100+)."
)
from vllm.vllm_flash_attn.cute.interface import _flash_attn_fwd from vllm.vllm_flash_attn.cute.interface import _flash_attn_fwd
out, softmax_lse = _flash_attn_fwd( out, softmax_lse = _flash_attn_fwd(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment