Unverified Commit 8b5014d3 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Attention] FA4 integration (#32974)


Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
Signed-off-by: default avatarLucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
Co-authored-by: default avatarTyler Michael Smith <tyler@neuralmagic.com>
parent 57a96e26
......@@ -9,6 +9,7 @@ steps:
- tests/v1
commands:
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
# split the test to avoid interference
- pytest -v -s -m 'not cpu_test' v1/core
- pytest -v -s v1/executor
......
......@@ -3,6 +3,8 @@
# vllm-flash-attn built from source
vllm/vllm_flash_attn/*
!vllm/vllm_flash_attn/__init__.py
!vllm/vllm_flash_attn/flash_attn_interface.py
# OpenAI triton kernels copied from source
vllm/third_party/triton_kernels/*
......
......@@ -17,7 +17,8 @@ endif()
# They should be identical but if they aren't, this is a massive footgun.
#
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3).
# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2), --component _vllm_fa3_C (for FA3),
# or --component _vllm_fa4_cutedsl_C (for FA4 CuteDSL Python files).
# If no component is specified, vllm-flash-attn is still installed.
# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
......@@ -38,7 +39,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 5824e6e2008271063c3229ab3e7032bd74abbbc6
GIT_TAG 140c00c0241bb60cc6e44e7c1be9998d4b20d8d2
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
......@@ -46,38 +47,62 @@ else()
endif()
# Ensure the vllm/vllm_flash_attn directory exists before installation
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" ALL_COMPONENTS)
# Make sure vllm-flash-attn install rules are nested under vllm/
# This is here to support installing all components under the same prefix with cmake --install.
# setup.py installs every component separately but uses the same prefix for all.
# ALL_COMPONENTS is used to avoid duplication for FA2 and FA3,
# and these statements don't hurt when installing neither component.
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" ALL_COMPONENTS)
install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS)
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" ALL_COMPONENTS)
# Install rules for FA components need the install prefix nested under vllm/
# These run at install time, before the FA library's own install rules
foreach(_FA_COMPONENT _vllm_fa2_C _vllm_fa3_C)
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT ${_FA_COMPONENT})
install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT ${_FA_COMPONENT})
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" COMPONENT ${_FA_COMPONENT})
endforeach()
# Fetch the vllm-flash-attn library
FetchContent_MakeAvailable(vllm-flash-attn)
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")
# Restore the install prefix
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS)
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
# Restore the install prefix after FA's install rules
foreach(_FA_COMPONENT _vllm_fa2_C _vllm_fa3_C)
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" COMPONENT ${_FA_COMPONENT})
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" COMPONENT ${_FA_COMPONENT})
endforeach()
# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in
# case only one is built, in the case both are built redundant work is done)
install(
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
DESTINATION vllm/vllm_flash_attn
COMPONENT _vllm_fa2_C
FILES_MATCHING PATTERN "*.py"
)
# Install shared Python files for both FA2 and FA3 components
foreach(_FA_COMPONENT _vllm_fa2_C _vllm_fa3_C)
# Ensure the vllm/vllm_flash_attn directory exists before installation
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")"
COMPONENT ${_FA_COMPONENT})
install(
# Copy vllm_flash_attn python files (except __init__.py and flash_attn_interface.py
# which are source-controlled in vllm)
install(
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
DESTINATION vllm/vllm_flash_attn
COMPONENT _vllm_fa3_C
COMPONENT ${_FA_COMPONENT}
FILES_MATCHING PATTERN "*.py"
)
PATTERN "__init__.py" EXCLUDE
PATTERN "flash_attn_interface.py" EXCLUDE
)
endforeach()
#
# FA4 CuteDSL component
# This is a Python-only component that copies the flash_attn/cute directory
# and transforms imports to match our package structure.
#
add_custom_target(_vllm_fa4_cutedsl_C)
# Copy flash_attn/cute directory (needed for FA4) and transform imports
# The cute directory uses flash_attn.cute imports internally, which we replace
# with vllm.vllm_flash_attn.cute to match our package structure.
install(CODE "
file(GLOB_RECURSE CUTE_PY_FILES \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute/*.py\")
foreach(SRC_FILE \${CUTE_PY_FILES})
file(RELATIVE_PATH REL_PATH \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute\" \${SRC_FILE})
set(DST_FILE \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn/cute/\${REL_PATH}\")
get_filename_component(DST_DIR \${DST_FILE} DIRECTORY)
file(MAKE_DIRECTORY \${DST_DIR})
file(READ \${SRC_FILE} FILE_CONTENTS)
string(REPLACE \"flash_attn.cute\" \"vllm.vllm_flash_attn.cute\" FILE_CONTENTS \"\${FILE_CONTENTS}\")
file(WRITE \${DST_FILE} \"\${FILE_CONTENTS}\")
endforeach()
" COMPONENT _vllm_fa4_cutedsl_C)
......@@ -168,6 +168,7 @@ Priority is **1 = highest** (tried first).
| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x |
| `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 |
| `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x |
| `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 |
| `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any |
| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
| `ROCM_AITER_FA` | | fp16, bf16 | `auto` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
......@@ -178,7 +179,7 @@ Priority is **1 = highest** (tried first).
> **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which supports sinks. Disable via `--attention-config.use_trtllm_attention=0`.
>
> **\*** Specify the FlashAttention version via `--attention-config.flash_attn_version=2` or `3`. Default is FA3 on SM90, FA2 otherwise.
> **\*** Specify the FlashAttention version via `--attention-config.flash_attn_version=2`, `3`, or `4`. Default is FA4 on SM100+ (Blackwell), FA3 on SM90 (Hopper), FA2 otherwise.
## MLA (Multi-head Latent Attention) Backends
......
......@@ -11,3 +11,7 @@ torchaudio==2.10.0
torchvision==0.25.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
# FlashInfer should be updated together with the Dockerfile
flashinfer-python==0.6.4
# QuACK and Cutlass DSL for FA4 (cute-DSL implementation)
nvidia-cutlass-dsl>=4.4.0.dev1
quack-kernels>=0.2.7
......@@ -976,6 +976,11 @@ if _is_cuda():
):
# FA3 requires CUDA 12.3 or later
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
# FA4 CuteDSL - Python-only component for FA4's cute DSL support
# Optional since this doesn't produce a .so file, just copies Python files
ext_modules.append(
CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa4_cutedsl_C", optional=True)
)
if envs.VLLM_USE_PRECOMPILED or (
CUDA_HOME and get_nvcc_cuda_version() >= Version("12.9")
):
......
......@@ -563,14 +563,53 @@ def analyze_backend(backend_name: str, class_path: str) -> dict[str, Any] | None
# ---------------------------------------------------------------------------
# Special backend variant parsers (FA2/FA3, FlashInfer TRTLLM, MLA prefill)
# Special backend variant parsers (FA2/FA3/FA4, FlashInfer TRTLLM, MLA prefill)
# ---------------------------------------------------------------------------
def _parse_fa4_supported_caps() -> str | None:
"""Parse flash_attn_interface.py for FA4 supported compute capabilities.
Looks for `cc not in [9, 10, 11]` pattern in _is_fa4_supported().
"""
fa_interface_file = (
REPO_ROOT / "vllm" / "vllm_flash_attn" / "flash_attn_interface.py"
)
if not fa_interface_file.exists():
return None
try:
tree = ast.parse(fa_interface_file.read_text())
except Exception:
return None
for node in ast.walk(tree):
if not isinstance(node, ast.FunctionDef) or node.name != "_is_fa4_supported":
continue
for n in ast.walk(node):
if not (
isinstance(n, ast.Compare)
and len(n.ops) == 1
and isinstance(n.ops[0], ast.NotIn)
and isinstance(n.comparators[0], ast.List)
):
continue
caps: list[int] = [
e.value
for e in n.comparators[0].elts
if isinstance(e, ast.Constant) and isinstance(e.value, int)
]
if caps:
caps.sort()
return f"{caps[0]}.x-{caps[-1]}.x"
return None
def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
"""Parse fa_utils.py to detect FA2 vs FA3 feature differences.
"""Parse fa_utils.py to detect FA2 vs FA3 vs FA4 feature differences.
Returns a dict with 'fa2' and 'fa3' keys containing their respective
Returns a dict with 'fa2', 'fa3', and 'fa4' keys containing their respective
feature overrides for compute capability, KV cache dtypes, and sink support.
"""
if not FA_UTILS_FILE.exists():
......@@ -585,6 +624,7 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
fa3_supports_fp8 = False
fa3_supports_sinks = False
fa3_compute_cap: str | None = None
fa4_compute_cap: str | None = None
for node in ast.walk(tree):
if not isinstance(node, ast.FunctionDef):
......@@ -614,14 +654,12 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
fa3_supports_sinks = True
break
# Check get_flash_attn_version for FA3 compute capability
# Look for the ternary: 3 if (device_capability.major == 9 ...) else 2
# Check get_flash_attn_version for FA3/FA4 compute capability
if node.name == "get_flash_attn_version":
for n in ast.walk(node):
# Look for IfExp (ternary) with `device_capability.major == 9`
# Handle IfExp (ternary) with `device_capability.major == 9`
if isinstance(n, ast.IfExp):
test = n.test
# Check if test is a BoolOp (and) containing the major check
if isinstance(test, ast.BoolOp):
for val in test.values:
if (
......@@ -634,6 +672,38 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
fa3_compute_cap = f"{val.comparators[0].value}.x"
break
# Handle If statements for FA3/FA4 detection
# e.g. `if device_capability.major == 9` -> FA3
# `elif device_capability.major >= 10` -> FA4
if isinstance(n, ast.If):
test = n.test
comparisons = (
[v for v in test.values if isinstance(v, ast.Compare)]
if isinstance(test, ast.BoolOp)
else [test]
if isinstance(test, ast.Compare)
else []
)
for comp in comparisons:
if not (
isinstance(comp.left, ast.Attribute)
and comp.left.attr == "major"
and comp.comparators
and isinstance(comp.comparators[0], ast.Constant)
and isinstance(comp.comparators[0].value, int)
):
continue
op = comp.ops[0]
val = comp.comparators[0].value
if isinstance(op, ast.Eq) and fa3_compute_cap is None:
fa3_compute_cap = f"{val}.x"
elif isinstance(op, ast.GtE) and fa4_compute_cap is None:
fa4_compute_cap = f"≥{val}.0"
# Fallback: try to parse FA4 compute caps from flash_attn_interface.py
if fa4_compute_cap is None:
fa4_compute_cap = _parse_fa4_supported_caps()
return {
"fa2": {
"supports_fp8": False,
......@@ -644,6 +714,11 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
"supports_fp8": fa3_supports_fp8,
"supports_sink": fa3_supports_sinks,
},
"fa4": {
"compute_capability": fa4_compute_cap,
"supports_fp8": False,
"supports_sink": False,
},
}
......@@ -760,7 +835,7 @@ def parse_mla_prefill_backends() -> list[dict[str, Any]]:
# ---------------------------------------------------------------------------
# Backend variant expansion (FA2/FA3, FlashInfer native/TRTLLM)
# Backend variant expansion (FA2/FA3/FA4, FlashInfer native/TRTLLM)
# ---------------------------------------------------------------------------
......@@ -768,7 +843,7 @@ def _expand_flash_attn_variants(
all_backends: list[dict[str, Any]],
fa_features: dict[str, dict[str, Any]],
) -> list[dict[str, Any]]:
"""Expand FLASH_ATTN into FA2 and FA3 variants with different capabilities."""
"""Expand FLASH_ATTN into FA2, FA3, and FA4 variants."""
expanded = []
for backend in all_backends:
if backend["name"] != "FLASH_ATTN":
......@@ -801,6 +876,18 @@ def _expand_flash_attn_variants(
expanded.append(fa2)
expanded.append(fa3)
# Create FA4 entry if FA4 features are available
if "fa4" in fa_features:
fa4 = backend.copy()
fa4["version"] = "FA4*"
fa4["_sort_key"] = "FLASH_ATTN"
fa4["_sort_order"] = 2
if fa_features["fa4"].get("compute_capability"):
fa4["compute_capability"] = fa_features["fa4"]["compute_capability"]
fa4["supports_sink"] = fa_features["fa4"]["supports_sink"]
expanded.append(fa4)
return expanded
......@@ -1360,7 +1447,8 @@ def generate_docs() -> str:
if fa_features:
footnotes.append(
"> **\\*** Specify the FlashAttention version via "
"`--attention-config.flash_attn_version=2` or `3`. Default is FA3 on SM90, "
"`--attention-config.flash_attn_version=2`, `3`, or `4`. "
"Default is FA4 on SM100+ (Blackwell), FA3 on SM90 (Hopper), "
"FA2 otherwise."
)
if footnotes:
......
......@@ -16,8 +16,8 @@ class AttentionConfig:
backend: AttentionBackendEnum | None = None
"""Attention backend to use. If None, will be selected automatically."""
flash_attn_version: Literal[2, 3] | None = None
"""Force vllm to use a specific flash-attention version (2 or 3).
flash_attn_version: Literal[2, 3, 4] | None = None
"""Force vllm to use a specific flash-attention version (2, 3, or 4).
Only valid when using the flash-attention backend."""
use_prefill_decode_attention: bool = False
......
......@@ -2014,7 +2014,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# RoCM and the latter has an additional parameter to control
# FA2 vs FA3
self.flash_attn_varlen_func = flash_attn_varlen_func
self.vllm_flash_attn_version = get_flash_attn_version()
self.vllm_flash_attn_version = get_flash_attn_version(
head_size=self.qk_head_dim
)
if self.vllm_flash_attn_version is not None:
self.flash_attn_varlen_func = functools.partial(
flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version
......
......@@ -204,7 +204,9 @@ class MMEncoderAttention(CustomOp):
}
self._fa_version = (
get_flash_attn_version() if self.is_flash_attn_backend else None
get_flash_attn_version(head_size=head_size)
if self.is_flash_attn_backend
else None
)
if self.attn_backend == AttentionBackendEnum.FLASHINFER:
......
......@@ -52,7 +52,9 @@ elif current_platform.is_rocm():
reshape_and_cache_flash = ops.reshape_and_cache_flash
def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
def get_flash_attn_version(
requires_alibi: bool = False, head_size: int | None = None
) -> int | None:
# import here to avoid circular dependencies
from vllm.platforms import current_platform
......@@ -72,9 +74,15 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
assert device_capability is not None
# 1. default version depending on platform
fa_version = (
3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2
)
if device_capability.major == 9 and is_fa_version_supported(3):
# Hopper (SM90): prefer FA3
fa_version = 3
elif device_capability.major == 10 and is_fa_version_supported(4):
# Blackwell (SM100+, restrict to SM100 for now): prefer FA4
fa_version = 4
else:
# Fallback to FA2
fa_version = 2
# 2. override if passed by environment or config
from vllm.config import get_current_vllm_config_or_none
......@@ -87,12 +95,12 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
fa_version = vllm_config.attention_config.flash_attn_version
# 3. fallback for unsupported combinations
if device_capability.major == 10 and fa_version == 3:
if device_capability.major >= 10 and fa_version == 3:
logger.warning_once(
"Cannot use FA version 3 on Blackwell platform, "
"defaulting to FA version 2."
"defaulting to FA version 4 if supported, otherwise FA2."
)
fa_version = 2
fa_version = 4 if is_fa_version_supported(4) else 2
if requires_alibi and fa_version == 3:
logger.warning_once(
......@@ -100,6 +108,28 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
)
fa_version = 2
if requires_alibi and fa_version == 4:
logger.warning_once(
"Cannot use FA version 4 with ALiBi, defaulting to FA version 2."
)
fa_version = 2
# FA4 on SM100 (Blackwell) has TMEM capacity limits that restrict
# supported head dimensions.
# See: https://github.com/Dao-AILab/flash-attention/issues/1959
if (
fa_version == 4
and device_capability.major >= 10
and head_size is not None
and head_size > 128
):
logger.warning_once(
"FA4 on Blackwell does not support head_size=%d due to TMEM "
"capacity limits, defaulting to FA version 2.",
head_size,
)
fa_version = 2
if not is_fa_version_supported(fa_version):
logger.error(
"Cannot use FA version %d is not supported due to %s",
......@@ -139,6 +169,10 @@ def flash_attn_supports_mla():
return is_fa_version_supported(
3
) and current_platform.is_device_capability_family(90)
# NOTE(Lucas): FA4 CuteDSL does NOT currently support MLA's non-standard
# head dimensions (576 for qk, 512 for v) due to TMEM capacity limits.
except (ImportError, AssertionError):
pass
return False
......
......@@ -580,7 +580,15 @@ class FlashAttentionImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.attn_type = attn_type
self.vllm_flash_attn_version = get_flash_attn_version()
self.vllm_flash_attn_version = get_flash_attn_version(
requires_alibi=alibi_slopes is not None,
head_size=head_size,
)
logger.info_once(
"Using FlashAttention version %s",
self.vllm_flash_attn_version,
scope="local",
)
# Cache the batch invariant result for use in forward passes
self.batch_invariant_enabled = vllm_is_batch_invariant()
......
......@@ -137,7 +137,7 @@ class CudagraphDispatcher:
num_tokens_padded = self._bs_to_padded_graph_size[num_tokens]
if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL):
num_reqs = num_tokens_padded // uniform_decode_query_len
num_reqs = min(num_tokens_padded // uniform_decode_query_len, max_num_seqs)
assert num_tokens_padded % uniform_decode_query_len == 0
else:
uniform_decode = False
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.vllm_flash_attn.flash_attn_interface import (
FA2_AVAILABLE,
FA3_AVAILABLE,
fa_version_unsupported_reason,
flash_attn_varlen_func,
get_scheduler_metadata,
is_fa_version_supported,
)
if not (FA2_AVAILABLE or FA3_AVAILABLE):
raise ImportError(
"vllm.vllm_flash_attn requires the CUDA flash attention extensions "
"(_vllm_fa2_C or _vllm_fa3_C). On ROCm, use upstream flash_attn."
)
__all__ = [
"fa_version_unsupported_reason",
"flash_attn_varlen_func",
"get_scheduler_metadata",
"is_fa_version_supported",
]
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment