Unverified Commit 8b5014d3 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Attention] FA4 integration (#32974)


Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
Signed-off-by: default avatarLucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
Co-authored-by: default avatarTyler Michael Smith <tyler@neuralmagic.com>
parent 57a96e26
...@@ -9,6 +9,7 @@ steps: ...@@ -9,6 +9,7 @@ steps:
- tests/v1 - tests/v1
commands: commands:
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt - uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
# split the test to avoid interference # split the test to avoid interference
- pytest -v -s -m 'not cpu_test' v1/core - pytest -v -s -m 'not cpu_test' v1/core
- pytest -v -s v1/executor - pytest -v -s v1/executor
......
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
# vllm-flash-attn built from source # vllm-flash-attn built from source
vllm/vllm_flash_attn/* vllm/vllm_flash_attn/*
!vllm/vllm_flash_attn/__init__.py
!vllm/vllm_flash_attn/flash_attn_interface.py
# OpenAI triton kernels copied from source # OpenAI triton kernels copied from source
vllm/third_party/triton_kernels/* vllm/third_party/triton_kernels/*
......
...@@ -17,7 +17,8 @@ endif() ...@@ -17,7 +17,8 @@ endif()
# They should be identical but if they aren't, this is a massive footgun. # They should be identical but if they aren't, this is a massive footgun.
# #
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place. # The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3). # To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2), --component _vllm_fa3_C (for FA3),
# or --component _vllm_fa4_cutedsl_C (for FA4 CuteDSL Python files).
# If no component is specified, vllm-flash-attn is still installed. # If no component is specified, vllm-flash-attn is still installed.
# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading. # If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
...@@ -38,7 +39,7 @@ else() ...@@ -38,7 +39,7 @@ else()
FetchContent_Declare( FetchContent_Declare(
vllm-flash-attn vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 5824e6e2008271063c3229ab3e7032bd74abbbc6 GIT_TAG 140c00c0241bb60cc6e44e7c1be9998d4b20d8d2
GIT_PROGRESS TRUE GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types # Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
...@@ -46,38 +47,62 @@ else() ...@@ -46,38 +47,62 @@ else()
endif() endif()
# Ensure the vllm/vllm_flash_attn directory exists before installation # Install rules for FA components need the install prefix nested under vllm/
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" ALL_COMPONENTS) # These run at install time, before the FA library's own install rules
foreach(_FA_COMPONENT _vllm_fa2_C _vllm_fa3_C)
# Make sure vllm-flash-attn install rules are nested under vllm/ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT ${_FA_COMPONENT})
# This is here to support installing all components under the same prefix with cmake --install. install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT ${_FA_COMPONENT})
# setup.py installs every component separately but uses the same prefix for all. install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" COMPONENT ${_FA_COMPONENT})
# ALL_COMPONENTS is used to avoid duplication for FA2 and FA3, endforeach()
# and these statements don't hurt when installing neither component.
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" ALL_COMPONENTS)
install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS)
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" ALL_COMPONENTS)
# Fetch the vllm-flash-attn library # Fetch the vllm-flash-attn library
FetchContent_MakeAvailable(vllm-flash-attn) FetchContent_MakeAvailable(vllm-flash-attn)
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}") message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")
# Restore the install prefix # Restore the install prefix after FA's install rules
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS) foreach(_FA_COMPONENT _vllm_fa2_C _vllm_fa3_C)
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" COMPONENT ${_FA_COMPONENT})
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" COMPONENT ${_FA_COMPONENT})
endforeach()
# Install shared Python files for both FA2 and FA3 components
foreach(_FA_COMPONENT _vllm_fa2_C _vllm_fa3_C)
# Ensure the vllm/vllm_flash_attn directory exists before installation
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")"
COMPONENT ${_FA_COMPONENT})
# Copy vllm_flash_attn python files (except __init__.py and flash_attn_interface.py
# which are source-controlled in vllm)
install(
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
DESTINATION vllm/vllm_flash_attn
COMPONENT ${_FA_COMPONENT}
FILES_MATCHING PATTERN "*.py"
PATTERN "__init__.py" EXCLUDE
PATTERN "flash_attn_interface.py" EXCLUDE
)
endforeach()
# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in #
# case only one is built, in the case both are built redundant work is done) # FA4 CuteDSL component
install( # This is a Python-only component that copies the flash_attn/cute directory
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ # and transforms imports to match our package structure.
DESTINATION vllm/vllm_flash_attn #
COMPONENT _vllm_fa2_C add_custom_target(_vllm_fa4_cutedsl_C)
FILES_MATCHING PATTERN "*.py"
)
install( # Copy flash_attn/cute directory (needed for FA4) and transform imports
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ # The cute directory uses flash_attn.cute imports internally, which we replace
DESTINATION vllm/vllm_flash_attn # with vllm.vllm_flash_attn.cute to match our package structure.
COMPONENT _vllm_fa3_C install(CODE "
FILES_MATCHING PATTERN "*.py" file(GLOB_RECURSE CUTE_PY_FILES \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute/*.py\")
) foreach(SRC_FILE \${CUTE_PY_FILES})
file(RELATIVE_PATH REL_PATH \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute\" \${SRC_FILE})
set(DST_FILE \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn/cute/\${REL_PATH}\")
get_filename_component(DST_DIR \${DST_FILE} DIRECTORY)
file(MAKE_DIRECTORY \${DST_DIR})
file(READ \${SRC_FILE} FILE_CONTENTS)
string(REPLACE \"flash_attn.cute\" \"vllm.vllm_flash_attn.cute\" FILE_CONTENTS \"\${FILE_CONTENTS}\")
file(WRITE \${DST_FILE} \"\${FILE_CONTENTS}\")
endforeach()
" COMPONENT _vllm_fa4_cutedsl_C)
...@@ -168,6 +168,7 @@ Priority is **1 = highest** (tried first). ...@@ -168,6 +168,7 @@ Priority is **1 = highest** (tried first).
| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x | | `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x |
| `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 | | `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 |
| `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x | | `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x |
| `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 |
| `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any | | `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any |
| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any | | `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
| `ROCM_AITER_FA` | | fp16, bf16 | `auto` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_FA` | | fp16, bf16 | `auto` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
...@@ -178,7 +179,7 @@ Priority is **1 = highest** (tried first). ...@@ -178,7 +179,7 @@ Priority is **1 = highest** (tried first).
> **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which supports sinks. Disable via `--attention-config.use_trtllm_attention=0`. > **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which supports sinks. Disable via `--attention-config.use_trtllm_attention=0`.
> >
> **\*** Specify the FlashAttention version via `--attention-config.flash_attn_version=2` or `3`. Default is FA3 on SM90, FA2 otherwise. > **\*** Specify the FlashAttention version via `--attention-config.flash_attn_version=2`, `3`, or `4`. Default is FA4 on SM100+ (Blackwell), FA3 on SM90 (Hopper), FA2 otherwise.
## MLA (Multi-head Latent Attention) Backends ## MLA (Multi-head Latent Attention) Backends
......
...@@ -11,3 +11,7 @@ torchaudio==2.10.0 ...@@ -11,3 +11,7 @@ torchaudio==2.10.0
torchvision==0.25.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version torchvision==0.25.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
# FlashInfer should be updated together with the Dockerfile # FlashInfer should be updated together with the Dockerfile
flashinfer-python==0.6.4 flashinfer-python==0.6.4
# QuACK and Cutlass DSL for FA4 (cute-DSL implementation)
nvidia-cutlass-dsl>=4.4.0.dev1
quack-kernels>=0.2.7
...@@ -976,6 +976,11 @@ if _is_cuda(): ...@@ -976,6 +976,11 @@ if _is_cuda():
): ):
# FA3 requires CUDA 12.3 or later # FA3 requires CUDA 12.3 or later
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C")) ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
# FA4 CuteDSL - Python-only component for FA4's cute DSL support
# Optional since this doesn't produce a .so file, just copies Python files
ext_modules.append(
CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa4_cutedsl_C", optional=True)
)
if envs.VLLM_USE_PRECOMPILED or ( if envs.VLLM_USE_PRECOMPILED or (
CUDA_HOME and get_nvcc_cuda_version() >= Version("12.9") CUDA_HOME and get_nvcc_cuda_version() >= Version("12.9")
): ):
......
...@@ -563,14 +563,53 @@ def analyze_backend(backend_name: str, class_path: str) -> dict[str, Any] | None ...@@ -563,14 +563,53 @@ def analyze_backend(backend_name: str, class_path: str) -> dict[str, Any] | None
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Special backend variant parsers (FA2/FA3, FlashInfer TRTLLM, MLA prefill) # Special backend variant parsers (FA2/FA3/FA4, FlashInfer TRTLLM, MLA prefill)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _parse_fa4_supported_caps() -> str | None:
"""Parse flash_attn_interface.py for FA4 supported compute capabilities.
Looks for `cc not in [9, 10, 11]` pattern in _is_fa4_supported().
"""
fa_interface_file = (
REPO_ROOT / "vllm" / "vllm_flash_attn" / "flash_attn_interface.py"
)
if not fa_interface_file.exists():
return None
try:
tree = ast.parse(fa_interface_file.read_text())
except Exception:
return None
for node in ast.walk(tree):
if not isinstance(node, ast.FunctionDef) or node.name != "_is_fa4_supported":
continue
for n in ast.walk(node):
if not (
isinstance(n, ast.Compare)
and len(n.ops) == 1
and isinstance(n.ops[0], ast.NotIn)
and isinstance(n.comparators[0], ast.List)
):
continue
caps: list[int] = [
e.value
for e in n.comparators[0].elts
if isinstance(e, ast.Constant) and isinstance(e.value, int)
]
if caps:
caps.sort()
return f"{caps[0]}.x-{caps[-1]}.x"
return None
def parse_flash_attn_features() -> dict[str, dict[str, Any]]: def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
"""Parse fa_utils.py to detect FA2 vs FA3 feature differences. """Parse fa_utils.py to detect FA2 vs FA3 vs FA4 feature differences.
Returns a dict with 'fa2' and 'fa3' keys containing their respective Returns a dict with 'fa2', 'fa3', and 'fa4' keys containing their respective
feature overrides for compute capability, KV cache dtypes, and sink support. feature overrides for compute capability, KV cache dtypes, and sink support.
""" """
if not FA_UTILS_FILE.exists(): if not FA_UTILS_FILE.exists():
...@@ -585,6 +624,7 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]: ...@@ -585,6 +624,7 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
fa3_supports_fp8 = False fa3_supports_fp8 = False
fa3_supports_sinks = False fa3_supports_sinks = False
fa3_compute_cap: str | None = None fa3_compute_cap: str | None = None
fa4_compute_cap: str | None = None
for node in ast.walk(tree): for node in ast.walk(tree):
if not isinstance(node, ast.FunctionDef): if not isinstance(node, ast.FunctionDef):
...@@ -614,14 +654,12 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]: ...@@ -614,14 +654,12 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
fa3_supports_sinks = True fa3_supports_sinks = True
break break
# Check get_flash_attn_version for FA3 compute capability # Check get_flash_attn_version for FA3/FA4 compute capability
# Look for the ternary: 3 if (device_capability.major == 9 ...) else 2
if node.name == "get_flash_attn_version": if node.name == "get_flash_attn_version":
for n in ast.walk(node): for n in ast.walk(node):
# Look for IfExp (ternary) with `device_capability.major == 9` # Handle IfExp (ternary) with `device_capability.major == 9`
if isinstance(n, ast.IfExp): if isinstance(n, ast.IfExp):
test = n.test test = n.test
# Check if test is a BoolOp (and) containing the major check
if isinstance(test, ast.BoolOp): if isinstance(test, ast.BoolOp):
for val in test.values: for val in test.values:
if ( if (
...@@ -634,6 +672,38 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]: ...@@ -634,6 +672,38 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
fa3_compute_cap = f"{val.comparators[0].value}.x" fa3_compute_cap = f"{val.comparators[0].value}.x"
break break
# Handle If statements for FA3/FA4 detection
# e.g. `if device_capability.major == 9` -> FA3
# `elif device_capability.major >= 10` -> FA4
if isinstance(n, ast.If):
test = n.test
comparisons = (
[v for v in test.values if isinstance(v, ast.Compare)]
if isinstance(test, ast.BoolOp)
else [test]
if isinstance(test, ast.Compare)
else []
)
for comp in comparisons:
if not (
isinstance(comp.left, ast.Attribute)
and comp.left.attr == "major"
and comp.comparators
and isinstance(comp.comparators[0], ast.Constant)
and isinstance(comp.comparators[0].value, int)
):
continue
op = comp.ops[0]
val = comp.comparators[0].value
if isinstance(op, ast.Eq) and fa3_compute_cap is None:
fa3_compute_cap = f"{val}.x"
elif isinstance(op, ast.GtE) and fa4_compute_cap is None:
fa4_compute_cap = f"≥{val}.0"
# Fallback: try to parse FA4 compute caps from flash_attn_interface.py
if fa4_compute_cap is None:
fa4_compute_cap = _parse_fa4_supported_caps()
return { return {
"fa2": { "fa2": {
"supports_fp8": False, "supports_fp8": False,
...@@ -644,6 +714,11 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]: ...@@ -644,6 +714,11 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
"supports_fp8": fa3_supports_fp8, "supports_fp8": fa3_supports_fp8,
"supports_sink": fa3_supports_sinks, "supports_sink": fa3_supports_sinks,
}, },
"fa4": {
"compute_capability": fa4_compute_cap,
"supports_fp8": False,
"supports_sink": False,
},
} }
...@@ -760,7 +835,7 @@ def parse_mla_prefill_backends() -> list[dict[str, Any]]: ...@@ -760,7 +835,7 @@ def parse_mla_prefill_backends() -> list[dict[str, Any]]:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Backend variant expansion (FA2/FA3, FlashInfer native/TRTLLM) # Backend variant expansion (FA2/FA3/FA4, FlashInfer native/TRTLLM)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
...@@ -768,7 +843,7 @@ def _expand_flash_attn_variants( ...@@ -768,7 +843,7 @@ def _expand_flash_attn_variants(
all_backends: list[dict[str, Any]], all_backends: list[dict[str, Any]],
fa_features: dict[str, dict[str, Any]], fa_features: dict[str, dict[str, Any]],
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Expand FLASH_ATTN into FA2 and FA3 variants with different capabilities.""" """Expand FLASH_ATTN into FA2, FA3, and FA4 variants."""
expanded = [] expanded = []
for backend in all_backends: for backend in all_backends:
if backend["name"] != "FLASH_ATTN": if backend["name"] != "FLASH_ATTN":
...@@ -801,6 +876,18 @@ def _expand_flash_attn_variants( ...@@ -801,6 +876,18 @@ def _expand_flash_attn_variants(
expanded.append(fa2) expanded.append(fa2)
expanded.append(fa3) expanded.append(fa3)
# Create FA4 entry if FA4 features are available
if "fa4" in fa_features:
fa4 = backend.copy()
fa4["version"] = "FA4*"
fa4["_sort_key"] = "FLASH_ATTN"
fa4["_sort_order"] = 2
if fa_features["fa4"].get("compute_capability"):
fa4["compute_capability"] = fa_features["fa4"]["compute_capability"]
fa4["supports_sink"] = fa_features["fa4"]["supports_sink"]
expanded.append(fa4)
return expanded return expanded
...@@ -1360,7 +1447,8 @@ def generate_docs() -> str: ...@@ -1360,7 +1447,8 @@ def generate_docs() -> str:
if fa_features: if fa_features:
footnotes.append( footnotes.append(
"> **\\*** Specify the FlashAttention version via " "> **\\*** Specify the FlashAttention version via "
"`--attention-config.flash_attn_version=2` or `3`. Default is FA3 on SM90, " "`--attention-config.flash_attn_version=2`, `3`, or `4`. "
"Default is FA4 on SM100+ (Blackwell), FA3 on SM90 (Hopper), "
"FA2 otherwise." "FA2 otherwise."
) )
if footnotes: if footnotes:
......
...@@ -16,8 +16,8 @@ class AttentionConfig: ...@@ -16,8 +16,8 @@ class AttentionConfig:
backend: AttentionBackendEnum | None = None backend: AttentionBackendEnum | None = None
"""Attention backend to use. If None, will be selected automatically.""" """Attention backend to use. If None, will be selected automatically."""
flash_attn_version: Literal[2, 3] | None = None flash_attn_version: Literal[2, 3, 4] | None = None
"""Force vllm to use a specific flash-attention version (2 or 3). """Force vllm to use a specific flash-attention version (2, 3, or 4).
Only valid when using the flash-attention backend.""" Only valid when using the flash-attention backend."""
use_prefill_decode_attention: bool = False use_prefill_decode_attention: bool = False
......
...@@ -2014,7 +2014,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -2014,7 +2014,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# RoCM and the latter has an additional parameter to control # RoCM and the latter has an additional parameter to control
# FA2 vs FA3 # FA2 vs FA3
self.flash_attn_varlen_func = flash_attn_varlen_func self.flash_attn_varlen_func = flash_attn_varlen_func
self.vllm_flash_attn_version = get_flash_attn_version() self.vllm_flash_attn_version = get_flash_attn_version(
head_size=self.qk_head_dim
)
if self.vllm_flash_attn_version is not None: if self.vllm_flash_attn_version is not None:
self.flash_attn_varlen_func = functools.partial( self.flash_attn_varlen_func = functools.partial(
flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version
......
...@@ -204,7 +204,9 @@ class MMEncoderAttention(CustomOp): ...@@ -204,7 +204,9 @@ class MMEncoderAttention(CustomOp):
} }
self._fa_version = ( self._fa_version = (
get_flash_attn_version() if self.is_flash_attn_backend else None get_flash_attn_version(head_size=head_size)
if self.is_flash_attn_backend
else None
) )
if self.attn_backend == AttentionBackendEnum.FLASHINFER: if self.attn_backend == AttentionBackendEnum.FLASHINFER:
......
...@@ -52,7 +52,9 @@ elif current_platform.is_rocm(): ...@@ -52,7 +52,9 @@ elif current_platform.is_rocm():
reshape_and_cache_flash = ops.reshape_and_cache_flash reshape_and_cache_flash = ops.reshape_and_cache_flash
def get_flash_attn_version(requires_alibi: bool = False) -> int | None: def get_flash_attn_version(
requires_alibi: bool = False, head_size: int | None = None
) -> int | None:
# import here to avoid circular dependencies # import here to avoid circular dependencies
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -72,9 +74,15 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None: ...@@ -72,9 +74,15 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
assert device_capability is not None assert device_capability is not None
# 1. default version depending on platform # 1. default version depending on platform
fa_version = ( if device_capability.major == 9 and is_fa_version_supported(3):
3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2 # Hopper (SM90): prefer FA3
) fa_version = 3
elif device_capability.major == 10 and is_fa_version_supported(4):
# Blackwell (SM100+, restrict to SM100 for now): prefer FA4
fa_version = 4
else:
# Fallback to FA2
fa_version = 2
# 2. override if passed by environment or config # 2. override if passed by environment or config
from vllm.config import get_current_vllm_config_or_none from vllm.config import get_current_vllm_config_or_none
...@@ -87,12 +95,12 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None: ...@@ -87,12 +95,12 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
fa_version = vllm_config.attention_config.flash_attn_version fa_version = vllm_config.attention_config.flash_attn_version
# 3. fallback for unsupported combinations # 3. fallback for unsupported combinations
if device_capability.major == 10 and fa_version == 3: if device_capability.major >= 10 and fa_version == 3:
logger.warning_once( logger.warning_once(
"Cannot use FA version 3 on Blackwell platform, " "Cannot use FA version 3 on Blackwell platform, "
"defaulting to FA version 2." "defaulting to FA version 4 if supported, otherwise FA2."
) )
fa_version = 2 fa_version = 4 if is_fa_version_supported(4) else 2
if requires_alibi and fa_version == 3: if requires_alibi and fa_version == 3:
logger.warning_once( logger.warning_once(
...@@ -100,6 +108,28 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None: ...@@ -100,6 +108,28 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
) )
fa_version = 2 fa_version = 2
if requires_alibi and fa_version == 4:
logger.warning_once(
"Cannot use FA version 4 with ALiBi, defaulting to FA version 2."
)
fa_version = 2
# FA4 on SM100 (Blackwell) has TMEM capacity limits that restrict
# supported head dimensions.
# See: https://github.com/Dao-AILab/flash-attention/issues/1959
if (
fa_version == 4
and device_capability.major >= 10
and head_size is not None
and head_size > 128
):
logger.warning_once(
"FA4 on Blackwell does not support head_size=%d due to TMEM "
"capacity limits, defaulting to FA version 2.",
head_size,
)
fa_version = 2
if not is_fa_version_supported(fa_version): if not is_fa_version_supported(fa_version):
logger.error( logger.error(
"Cannot use FA version %d is not supported due to %s", "Cannot use FA version %d is not supported due to %s",
...@@ -139,6 +169,10 @@ def flash_attn_supports_mla(): ...@@ -139,6 +169,10 @@ def flash_attn_supports_mla():
return is_fa_version_supported( return is_fa_version_supported(
3 3
) and current_platform.is_device_capability_family(90) ) and current_platform.is_device_capability_family(90)
# NOTE(Lucas): FA4 CuteDSL does NOT currently support MLA's non-standard
# head dimensions (576 for qk, 512 for v) due to TMEM capacity limits.
except (ImportError, AssertionError): except (ImportError, AssertionError):
pass pass
return False return False
......
...@@ -580,7 +580,15 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -580,7 +580,15 @@ class FlashAttentionImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.attn_type = attn_type self.attn_type = attn_type
self.vllm_flash_attn_version = get_flash_attn_version() self.vllm_flash_attn_version = get_flash_attn_version(
requires_alibi=alibi_slopes is not None,
head_size=head_size,
)
logger.info_once(
"Using FlashAttention version %s",
self.vllm_flash_attn_version,
scope="local",
)
# Cache the batch invariant result for use in forward passes # Cache the batch invariant result for use in forward passes
self.batch_invariant_enabled = vllm_is_batch_invariant() self.batch_invariant_enabled = vllm_is_batch_invariant()
......
...@@ -137,7 +137,7 @@ class CudagraphDispatcher: ...@@ -137,7 +137,7 @@ class CudagraphDispatcher:
num_tokens_padded = self._bs_to_padded_graph_size[num_tokens] num_tokens_padded = self._bs_to_padded_graph_size[num_tokens]
if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL): if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL):
num_reqs = num_tokens_padded // uniform_decode_query_len num_reqs = min(num_tokens_padded // uniform_decode_query_len, max_num_seqs)
assert num_tokens_padded % uniform_decode_query_len == 0 assert num_tokens_padded % uniform_decode_query_len == 0
else: else:
uniform_decode = False uniform_decode = False
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.vllm_flash_attn.flash_attn_interface import (
FA2_AVAILABLE,
FA3_AVAILABLE,
fa_version_unsupported_reason,
flash_attn_varlen_func,
get_scheduler_metadata,
is_fa_version_supported,
)
if not (FA2_AVAILABLE or FA3_AVAILABLE):
raise ImportError(
"vllm.vllm_flash_attn requires the CUDA flash attention extensions "
"(_vllm_fa2_C or _vllm_fa3_C). On ROCm, use upstream flash_attn."
)
__all__ = [
"fa_version_unsupported_reason",
"flash_attn_varlen_func",
"get_scheduler_metadata",
"is_fa_version_supported",
]
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment