Unverified Commit 8b5014d3 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Attention] FA4 integration (#32974)


Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
Signed-off-by: default avatarLucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
Co-authored-by: default avatarTyler Michael Smith <tyler@neuralmagic.com>
parent 57a96e26
...@@ -9,6 +9,7 @@ steps: ...@@ -9,6 +9,7 @@ steps:
- tests/v1 - tests/v1
commands: commands:
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt - uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
# split the test to avoid interference # split the test to avoid interference
- pytest -v -s -m 'not cpu_test' v1/core - pytest -v -s -m 'not cpu_test' v1/core
- pytest -v -s v1/executor - pytest -v -s v1/executor
......
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
# vllm-flash-attn built from source # vllm-flash-attn built from source
vllm/vllm_flash_attn/* vllm/vllm_flash_attn/*
!vllm/vllm_flash_attn/__init__.py
!vllm/vllm_flash_attn/flash_attn_interface.py
# OpenAI triton kernels copied from source # OpenAI triton kernels copied from source
vllm/third_party/triton_kernels/* vllm/third_party/triton_kernels/*
......
...@@ -17,7 +17,8 @@ endif() ...@@ -17,7 +17,8 @@ endif()
# They should be identical but if they aren't, this is a massive footgun. # They should be identical but if they aren't, this is a massive footgun.
# #
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place. # The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3). # To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2), --component _vllm_fa3_C (for FA3),
# or --component _vllm_fa4_cutedsl_C (for FA4 CuteDSL Python files).
# If no component is specified, vllm-flash-attn is still installed. # If no component is specified, vllm-flash-attn is still installed.
# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading. # If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
...@@ -38,7 +39,7 @@ else() ...@@ -38,7 +39,7 @@ else()
FetchContent_Declare( FetchContent_Declare(
vllm-flash-attn vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 5824e6e2008271063c3229ab3e7032bd74abbbc6 GIT_TAG 140c00c0241bb60cc6e44e7c1be9998d4b20d8d2
GIT_PROGRESS TRUE GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types # Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
...@@ -46,38 +47,62 @@ else() ...@@ -46,38 +47,62 @@ else()
endif() endif()
# Ensure the vllm/vllm_flash_attn directory exists before installation # Install rules for FA components need the install prefix nested under vllm/
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" ALL_COMPONENTS) # These run at install time, before the FA library's own install rules
foreach(_FA_COMPONENT _vllm_fa2_C _vllm_fa3_C)
# Make sure vllm-flash-attn install rules are nested under vllm/ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT ${_FA_COMPONENT})
# This is here to support installing all components under the same prefix with cmake --install. install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT ${_FA_COMPONENT})
# setup.py installs every component separately but uses the same prefix for all. install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" COMPONENT ${_FA_COMPONENT})
# ALL_COMPONENTS is used to avoid duplication for FA2 and FA3, endforeach()
# and these statements don't hurt when installing neither component.
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" ALL_COMPONENTS)
install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS)
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" ALL_COMPONENTS)
# Fetch the vllm-flash-attn library # Fetch the vllm-flash-attn library
FetchContent_MakeAvailable(vllm-flash-attn) FetchContent_MakeAvailable(vllm-flash-attn)
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}") message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")
# Restore the install prefix # Restore the install prefix after FA's install rules
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS) foreach(_FA_COMPONENT _vllm_fa2_C _vllm_fa3_C)
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" COMPONENT ${_FA_COMPONENT})
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" COMPONENT ${_FA_COMPONENT})
endforeach()
# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in # Install shared Python files for both FA2 and FA3 components
# case only one is built, in the case both are built redundant work is done) foreach(_FA_COMPONENT _vllm_fa2_C _vllm_fa3_C)
install( # Ensure the vllm/vllm_flash_attn directory exists before installation
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")"
DESTINATION vllm/vllm_flash_attn COMPONENT ${_FA_COMPONENT})
COMPONENT _vllm_fa2_C
FILES_MATCHING PATTERN "*.py"
)
install( # Copy vllm_flash_attn python files (except __init__.py and flash_attn_interface.py
# which are source-controlled in vllm)
install(
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
DESTINATION vllm/vllm_flash_attn DESTINATION vllm/vllm_flash_attn
COMPONENT _vllm_fa3_C COMPONENT ${_FA_COMPONENT}
FILES_MATCHING PATTERN "*.py" FILES_MATCHING PATTERN "*.py"
) PATTERN "__init__.py" EXCLUDE
PATTERN "flash_attn_interface.py" EXCLUDE
)
endforeach()
#
# FA4 CuteDSL component
# This is a Python-only component that copies the flash_attn/cute directory
# and transforms imports to match our package structure.
#
add_custom_target(_vllm_fa4_cutedsl_C)
# Copy flash_attn/cute directory (needed for FA4) and transform imports
# The cute directory uses flash_attn.cute imports internally, which we replace
# with vllm.vllm_flash_attn.cute to match our package structure.
install(CODE "
file(GLOB_RECURSE CUTE_PY_FILES \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute/*.py\")
foreach(SRC_FILE \${CUTE_PY_FILES})
file(RELATIVE_PATH REL_PATH \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute\" \${SRC_FILE})
set(DST_FILE \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn/cute/\${REL_PATH}\")
get_filename_component(DST_DIR \${DST_FILE} DIRECTORY)
file(MAKE_DIRECTORY \${DST_DIR})
file(READ \${SRC_FILE} FILE_CONTENTS)
string(REPLACE \"flash_attn.cute\" \"vllm.vllm_flash_attn.cute\" FILE_CONTENTS \"\${FILE_CONTENTS}\")
file(WRITE \${DST_FILE} \"\${FILE_CONTENTS}\")
endforeach()
" COMPONENT _vllm_fa4_cutedsl_C)
...@@ -168,6 +168,7 @@ Priority is **1 = highest** (tried first). ...@@ -168,6 +168,7 @@ Priority is **1 = highest** (tried first).
| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x | | `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x |
| `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 | | `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 |
| `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x | | `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x |
| `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 |
| `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any | | `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any |
| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any | | `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
| `ROCM_AITER_FA` | | fp16, bf16 | `auto` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_FA` | | fp16, bf16 | `auto` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
...@@ -178,7 +179,7 @@ Priority is **1 = highest** (tried first). ...@@ -178,7 +179,7 @@ Priority is **1 = highest** (tried first).
> **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which supports sinks. Disable via `--attention-config.use_trtllm_attention=0`. > **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which supports sinks. Disable via `--attention-config.use_trtllm_attention=0`.
> >
> **\*** Specify the FlashAttention version via `--attention-config.flash_attn_version=2` or `3`. Default is FA3 on SM90, FA2 otherwise. > **\*** Specify the FlashAttention version via `--attention-config.flash_attn_version=2`, `3`, or `4`. Default is FA4 on SM100+ (Blackwell), FA3 on SM90 (Hopper), FA2 otherwise.
## MLA (Multi-head Latent Attention) Backends ## MLA (Multi-head Latent Attention) Backends
......
...@@ -11,3 +11,7 @@ torchaudio==2.10.0 ...@@ -11,3 +11,7 @@ torchaudio==2.10.0
torchvision==0.25.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version torchvision==0.25.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
# FlashInfer should be updated together with the Dockerfile # FlashInfer should be updated together with the Dockerfile
flashinfer-python==0.6.4 flashinfer-python==0.6.4
# QuACK and Cutlass DSL for FA4 (cute-DSL implementation)
nvidia-cutlass-dsl>=4.4.0.dev1
quack-kernels>=0.2.7
...@@ -976,6 +976,11 @@ if _is_cuda(): ...@@ -976,6 +976,11 @@ if _is_cuda():
): ):
# FA3 requires CUDA 12.3 or later # FA3 requires CUDA 12.3 or later
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C")) ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
# FA4 CuteDSL - Python-only component for FA4's cute DSL support
# Optional since this doesn't produce a .so file, just copies Python files
ext_modules.append(
CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa4_cutedsl_C", optional=True)
)
if envs.VLLM_USE_PRECOMPILED or ( if envs.VLLM_USE_PRECOMPILED or (
CUDA_HOME and get_nvcc_cuda_version() >= Version("12.9") CUDA_HOME and get_nvcc_cuda_version() >= Version("12.9")
): ):
......
...@@ -563,14 +563,53 @@ def analyze_backend(backend_name: str, class_path: str) -> dict[str, Any] | None ...@@ -563,14 +563,53 @@ def analyze_backend(backend_name: str, class_path: str) -> dict[str, Any] | None
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Special backend variant parsers (FA2/FA3, FlashInfer TRTLLM, MLA prefill) # Special backend variant parsers (FA2/FA3/FA4, FlashInfer TRTLLM, MLA prefill)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _parse_fa4_supported_caps() -> str | None:
"""Parse flash_attn_interface.py for FA4 supported compute capabilities.
Looks for `cc not in [9, 10, 11]` pattern in _is_fa4_supported().
"""
fa_interface_file = (
REPO_ROOT / "vllm" / "vllm_flash_attn" / "flash_attn_interface.py"
)
if not fa_interface_file.exists():
return None
try:
tree = ast.parse(fa_interface_file.read_text())
except Exception:
return None
for node in ast.walk(tree):
if not isinstance(node, ast.FunctionDef) or node.name != "_is_fa4_supported":
continue
for n in ast.walk(node):
if not (
isinstance(n, ast.Compare)
and len(n.ops) == 1
and isinstance(n.ops[0], ast.NotIn)
and isinstance(n.comparators[0], ast.List)
):
continue
caps: list[int] = [
e.value
for e in n.comparators[0].elts
if isinstance(e, ast.Constant) and isinstance(e.value, int)
]
if caps:
caps.sort()
return f"{caps[0]}.x-{caps[-1]}.x"
return None
def parse_flash_attn_features() -> dict[str, dict[str, Any]]: def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
"""Parse fa_utils.py to detect FA2 vs FA3 feature differences. """Parse fa_utils.py to detect FA2 vs FA3 vs FA4 feature differences.
Returns a dict with 'fa2' and 'fa3' keys containing their respective Returns a dict with 'fa2', 'fa3', and 'fa4' keys containing their respective
feature overrides for compute capability, KV cache dtypes, and sink support. feature overrides for compute capability, KV cache dtypes, and sink support.
""" """
if not FA_UTILS_FILE.exists(): if not FA_UTILS_FILE.exists():
...@@ -585,6 +624,7 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]: ...@@ -585,6 +624,7 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
fa3_supports_fp8 = False fa3_supports_fp8 = False
fa3_supports_sinks = False fa3_supports_sinks = False
fa3_compute_cap: str | None = None fa3_compute_cap: str | None = None
fa4_compute_cap: str | None = None
for node in ast.walk(tree): for node in ast.walk(tree):
if not isinstance(node, ast.FunctionDef): if not isinstance(node, ast.FunctionDef):
...@@ -614,14 +654,12 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]: ...@@ -614,14 +654,12 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
fa3_supports_sinks = True fa3_supports_sinks = True
break break
# Check get_flash_attn_version for FA3 compute capability # Check get_flash_attn_version for FA3/FA4 compute capability
# Look for the ternary: 3 if (device_capability.major == 9 ...) else 2
if node.name == "get_flash_attn_version": if node.name == "get_flash_attn_version":
for n in ast.walk(node): for n in ast.walk(node):
# Look for IfExp (ternary) with `device_capability.major == 9` # Handle IfExp (ternary) with `device_capability.major == 9`
if isinstance(n, ast.IfExp): if isinstance(n, ast.IfExp):
test = n.test test = n.test
# Check if test is a BoolOp (and) containing the major check
if isinstance(test, ast.BoolOp): if isinstance(test, ast.BoolOp):
for val in test.values: for val in test.values:
if ( if (
...@@ -634,6 +672,38 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]: ...@@ -634,6 +672,38 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
fa3_compute_cap = f"{val.comparators[0].value}.x" fa3_compute_cap = f"{val.comparators[0].value}.x"
break break
# Handle If statements for FA3/FA4 detection
# e.g. `if device_capability.major == 9` -> FA3
# `elif device_capability.major >= 10` -> FA4
if isinstance(n, ast.If):
test = n.test
comparisons = (
[v for v in test.values if isinstance(v, ast.Compare)]
if isinstance(test, ast.BoolOp)
else [test]
if isinstance(test, ast.Compare)
else []
)
for comp in comparisons:
if not (
isinstance(comp.left, ast.Attribute)
and comp.left.attr == "major"
and comp.comparators
and isinstance(comp.comparators[0], ast.Constant)
and isinstance(comp.comparators[0].value, int)
):
continue
op = comp.ops[0]
val = comp.comparators[0].value
if isinstance(op, ast.Eq) and fa3_compute_cap is None:
fa3_compute_cap = f"{val}.x"
elif isinstance(op, ast.GtE) and fa4_compute_cap is None:
fa4_compute_cap = f"≥{val}.0"
# Fallback: try to parse FA4 compute caps from flash_attn_interface.py
if fa4_compute_cap is None:
fa4_compute_cap = _parse_fa4_supported_caps()
return { return {
"fa2": { "fa2": {
"supports_fp8": False, "supports_fp8": False,
...@@ -644,6 +714,11 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]: ...@@ -644,6 +714,11 @@ def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
"supports_fp8": fa3_supports_fp8, "supports_fp8": fa3_supports_fp8,
"supports_sink": fa3_supports_sinks, "supports_sink": fa3_supports_sinks,
}, },
"fa4": {
"compute_capability": fa4_compute_cap,
"supports_fp8": False,
"supports_sink": False,
},
} }
...@@ -760,7 +835,7 @@ def parse_mla_prefill_backends() -> list[dict[str, Any]]: ...@@ -760,7 +835,7 @@ def parse_mla_prefill_backends() -> list[dict[str, Any]]:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Backend variant expansion (FA2/FA3, FlashInfer native/TRTLLM) # Backend variant expansion (FA2/FA3/FA4, FlashInfer native/TRTLLM)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
...@@ -768,7 +843,7 @@ def _expand_flash_attn_variants( ...@@ -768,7 +843,7 @@ def _expand_flash_attn_variants(
all_backends: list[dict[str, Any]], all_backends: list[dict[str, Any]],
fa_features: dict[str, dict[str, Any]], fa_features: dict[str, dict[str, Any]],
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Expand FLASH_ATTN into FA2 and FA3 variants with different capabilities.""" """Expand FLASH_ATTN into FA2, FA3, and FA4 variants."""
expanded = [] expanded = []
for backend in all_backends: for backend in all_backends:
if backend["name"] != "FLASH_ATTN": if backend["name"] != "FLASH_ATTN":
...@@ -801,6 +876,18 @@ def _expand_flash_attn_variants( ...@@ -801,6 +876,18 @@ def _expand_flash_attn_variants(
expanded.append(fa2) expanded.append(fa2)
expanded.append(fa3) expanded.append(fa3)
# Create FA4 entry if FA4 features are available
if "fa4" in fa_features:
fa4 = backend.copy()
fa4["version"] = "FA4*"
fa4["_sort_key"] = "FLASH_ATTN"
fa4["_sort_order"] = 2
if fa_features["fa4"].get("compute_capability"):
fa4["compute_capability"] = fa_features["fa4"]["compute_capability"]
fa4["supports_sink"] = fa_features["fa4"]["supports_sink"]
expanded.append(fa4)
return expanded return expanded
...@@ -1360,7 +1447,8 @@ def generate_docs() -> str: ...@@ -1360,7 +1447,8 @@ def generate_docs() -> str:
if fa_features: if fa_features:
footnotes.append( footnotes.append(
"> **\\*** Specify the FlashAttention version via " "> **\\*** Specify the FlashAttention version via "
"`--attention-config.flash_attn_version=2` or `3`. Default is FA3 on SM90, " "`--attention-config.flash_attn_version=2`, `3`, or `4`. "
"Default is FA4 on SM100+ (Blackwell), FA3 on SM90 (Hopper), "
"FA2 otherwise." "FA2 otherwise."
) )
if footnotes: if footnotes:
......
...@@ -16,8 +16,8 @@ class AttentionConfig: ...@@ -16,8 +16,8 @@ class AttentionConfig:
backend: AttentionBackendEnum | None = None backend: AttentionBackendEnum | None = None
"""Attention backend to use. If None, will be selected automatically.""" """Attention backend to use. If None, will be selected automatically."""
flash_attn_version: Literal[2, 3] | None = None flash_attn_version: Literal[2, 3, 4] | None = None
"""Force vllm to use a specific flash-attention version (2 or 3). """Force vllm to use a specific flash-attention version (2, 3, or 4).
Only valid when using the flash-attention backend.""" Only valid when using the flash-attention backend."""
use_prefill_decode_attention: bool = False use_prefill_decode_attention: bool = False
......
...@@ -2014,7 +2014,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -2014,7 +2014,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# RoCM and the latter has an additional parameter to control # RoCM and the latter has an additional parameter to control
# FA2 vs FA3 # FA2 vs FA3
self.flash_attn_varlen_func = flash_attn_varlen_func self.flash_attn_varlen_func = flash_attn_varlen_func
self.vllm_flash_attn_version = get_flash_attn_version() self.vllm_flash_attn_version = get_flash_attn_version(
head_size=self.qk_head_dim
)
if self.vllm_flash_attn_version is not None: if self.vllm_flash_attn_version is not None:
self.flash_attn_varlen_func = functools.partial( self.flash_attn_varlen_func = functools.partial(
flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version
......
...@@ -204,7 +204,9 @@ class MMEncoderAttention(CustomOp): ...@@ -204,7 +204,9 @@ class MMEncoderAttention(CustomOp):
} }
self._fa_version = ( self._fa_version = (
get_flash_attn_version() if self.is_flash_attn_backend else None get_flash_attn_version(head_size=head_size)
if self.is_flash_attn_backend
else None
) )
if self.attn_backend == AttentionBackendEnum.FLASHINFER: if self.attn_backend == AttentionBackendEnum.FLASHINFER:
......
...@@ -52,7 +52,9 @@ elif current_platform.is_rocm(): ...@@ -52,7 +52,9 @@ elif current_platform.is_rocm():
reshape_and_cache_flash = ops.reshape_and_cache_flash reshape_and_cache_flash = ops.reshape_and_cache_flash
def get_flash_attn_version(requires_alibi: bool = False) -> int | None: def get_flash_attn_version(
requires_alibi: bool = False, head_size: int | None = None
) -> int | None:
# import here to avoid circular dependencies # import here to avoid circular dependencies
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -72,9 +74,15 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None: ...@@ -72,9 +74,15 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
assert device_capability is not None assert device_capability is not None
# 1. default version depending on platform # 1. default version depending on platform
fa_version = ( if device_capability.major == 9 and is_fa_version_supported(3):
3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2 # Hopper (SM90): prefer FA3
) fa_version = 3
elif device_capability.major == 10 and is_fa_version_supported(4):
# Blackwell (SM100+, restrict to SM100 for now): prefer FA4
fa_version = 4
else:
# Fallback to FA2
fa_version = 2
# 2. override if passed by environment or config # 2. override if passed by environment or config
from vllm.config import get_current_vllm_config_or_none from vllm.config import get_current_vllm_config_or_none
...@@ -87,12 +95,12 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None: ...@@ -87,12 +95,12 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
fa_version = vllm_config.attention_config.flash_attn_version fa_version = vllm_config.attention_config.flash_attn_version
# 3. fallback for unsupported combinations # 3. fallback for unsupported combinations
if device_capability.major == 10 and fa_version == 3: if device_capability.major >= 10 and fa_version == 3:
logger.warning_once( logger.warning_once(
"Cannot use FA version 3 on Blackwell platform, " "Cannot use FA version 3 on Blackwell platform, "
"defaulting to FA version 2." "defaulting to FA version 4 if supported, otherwise FA2."
) )
fa_version = 2 fa_version = 4 if is_fa_version_supported(4) else 2
if requires_alibi and fa_version == 3: if requires_alibi and fa_version == 3:
logger.warning_once( logger.warning_once(
...@@ -100,6 +108,28 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None: ...@@ -100,6 +108,28 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
) )
fa_version = 2 fa_version = 2
if requires_alibi and fa_version == 4:
logger.warning_once(
"Cannot use FA version 4 with ALiBi, defaulting to FA version 2."
)
fa_version = 2
# FA4 on SM100 (Blackwell) has TMEM capacity limits that restrict
# supported head dimensions.
# See: https://github.com/Dao-AILab/flash-attention/issues/1959
if (
fa_version == 4
and device_capability.major >= 10
and head_size is not None
and head_size > 128
):
logger.warning_once(
"FA4 on Blackwell does not support head_size=%d due to TMEM "
"capacity limits, defaulting to FA version 2.",
head_size,
)
fa_version = 2
if not is_fa_version_supported(fa_version): if not is_fa_version_supported(fa_version):
logger.error( logger.error(
"Cannot use FA version %d is not supported due to %s", "Cannot use FA version %d is not supported due to %s",
...@@ -139,6 +169,10 @@ def flash_attn_supports_mla(): ...@@ -139,6 +169,10 @@ def flash_attn_supports_mla():
return is_fa_version_supported( return is_fa_version_supported(
3 3
) and current_platform.is_device_capability_family(90) ) and current_platform.is_device_capability_family(90)
# NOTE(Lucas): FA4 CuteDSL does NOT currently support MLA's non-standard
# head dimensions (576 for qk, 512 for v) due to TMEM capacity limits.
except (ImportError, AssertionError): except (ImportError, AssertionError):
pass pass
return False return False
......
...@@ -580,7 +580,15 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -580,7 +580,15 @@ class FlashAttentionImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.attn_type = attn_type self.attn_type = attn_type
self.vllm_flash_attn_version = get_flash_attn_version() self.vllm_flash_attn_version = get_flash_attn_version(
requires_alibi=alibi_slopes is not None,
head_size=head_size,
)
logger.info_once(
"Using FlashAttention version %s",
self.vllm_flash_attn_version,
scope="local",
)
# Cache the batch invariant result for use in forward passes # Cache the batch invariant result for use in forward passes
self.batch_invariant_enabled = vllm_is_batch_invariant() self.batch_invariant_enabled = vllm_is_batch_invariant()
......
...@@ -137,7 +137,7 @@ class CudagraphDispatcher: ...@@ -137,7 +137,7 @@ class CudagraphDispatcher:
num_tokens_padded = self._bs_to_padded_graph_size[num_tokens] num_tokens_padded = self._bs_to_padded_graph_size[num_tokens]
if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL): if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL):
num_reqs = num_tokens_padded // uniform_decode_query_len num_reqs = min(num_tokens_padded // uniform_decode_query_len, max_num_seqs)
assert num_tokens_padded % uniform_decode_query_len == 0 assert num_tokens_padded % uniform_decode_query_len == 0
else: else:
uniform_decode = False uniform_decode = False
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.vllm_flash_attn.flash_attn_interface import (
FA2_AVAILABLE,
FA3_AVAILABLE,
fa_version_unsupported_reason,
flash_attn_varlen_func,
get_scheduler_metadata,
is_fa_version_supported,
)
if not (FA2_AVAILABLE or FA3_AVAILABLE):
raise ImportError(
"vllm.vllm_flash_attn requires the CUDA flash attention extensions "
"(_vllm_fa2_C or _vllm_fa3_C). On ROCm, use upstream flash_attn."
)
__all__ = [
"fa_version_unsupported_reason",
"flash_attn_varlen_func",
"get_scheduler_metadata",
"is_fa_version_supported",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2023, Tri Dao.
# ruff: noqa: E501
import torch
# isort: off
# We need to import the CUDA kernels after importing torch
# Use relative import to support build-from-source installation in vLLM
try:
from . import _vllm_fa2_C # type: ignore[attr-defined] # noqa: F401
FA2_UNAVAILABLE_REASON = None
FA2_AVAILABLE = True
except ImportError as e:
FA2_UNAVAILABLE_REASON = str(e)
FA2_AVAILABLE = False
try:
from . import _vllm_fa3_C # type: ignore[attr-defined] # noqa: F401
FA3_UNAVAILABLE_REASON = None
FA3_AVAILABLE = True
except ImportError as e:
FA3_UNAVAILABLE_REASON = str(e)
FA3_AVAILABLE = False
try:
import os
_cute_interface_path = os.path.join(
os.path.dirname(__file__), "cute", "interface.py"
)
if not os.path.exists(_cute_interface_path):
raise ImportError("vllm.vllm_flash_attn.cute.interface not found")
FA4_UNAVAILABLE_REASON = None
FA4_AVAILABLE = True
except (ImportError, ModuleNotFoundError) as e:
FA4_UNAVAILABLE_REASON = str(e)
FA4_AVAILABLE = False
# isort: on
DEFAULT_FA_VERSION = 2
def _is_fa2_supported() -> tuple[bool, str | None]:
if not FA2_AVAILABLE:
return False, f"FA2 is unavailable due to: {FA2_UNAVAILABLE_REASON}"
from vllm.platforms import current_platform
if not current_platform.has_device_capability(80):
return False, "FA2 is only supported on devices with compute capability >= 8"
return True, None
def _is_fa3_supported() -> tuple[bool, str | None]:
if not FA3_AVAILABLE:
return False, f"FA3 is unavailable due to: {FA3_UNAVAILABLE_REASON}"
from vllm.platforms import current_platform
if not current_platform.is_device_capability_family(90):
return False, "FA3 is only supported on devices with compute capability 9.x"
return True, None
def _is_fa4_supported() -> tuple[bool, str | None]:
if not FA4_AVAILABLE:
return False, f"FA4 is unavailable due to: {FA4_UNAVAILABLE_REASON}"
from vllm.platforms import current_platform
if not (
current_platform.is_device_capability_family(90)
or current_platform.is_device_capability_family(100)
or current_platform.is_device_capability_family(110)
):
return (
False,
"FA4 is only supported on devices with compute capability 9.x, 10.x, or 11.x",
)
return True, None
def is_fa_version_supported(fa_version: int) -> bool:
if fa_version == 2:
return _is_fa2_supported()[0]
elif fa_version == 3:
return _is_fa3_supported()[0]
elif fa_version == 4:
return _is_fa4_supported()[0]
else:
raise ValueError(f"Unsupported FA version: {fa_version}")
def fa_version_unsupported_reason(fa_version: int) -> str | None:
if fa_version == 2:
return _is_fa2_supported()[1]
elif fa_version == 3:
return _is_fa3_supported()[1]
elif fa_version == 4:
return _is_fa4_supported()[1]
else:
raise ValueError(f"Unsupported FA version: {fa_version}")
#
# For vLLM we only care about `flash_attn_varlen_func` and
# `flash_attn_with_kvcache` so we only maintain wrappers for these two.
#
def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
# NOTE only used in FA3
def get_scheduler_metadata(
batch_size,
max_seqlen_q,
max_seqlen_k,
num_heads_q,
num_heads_kv,
headdim,
cache_seqlens: torch.Tensor,
qkv_dtype=torch.bfloat16,
headdim_v=None,
cu_seqlens_q: torch.Tensor | None = None,
cu_seqlens_k_new: torch.Tensor | None = None,
cache_leftpad: torch.Tensor | None = None,
page_size: int | None = None,
max_seqlen_k_new=0,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
has_softcap=False,
num_splits=0, # Can be tuned for speed
pack_gqa=None, # Can be tuned for speed
sm_margin=0, # Can be tuned if some SMs are used for communication
):
cache_seqlens = maybe_contiguous(cache_seqlens)
if headdim_v is None:
headdim_v = headdim
scheduler_metadata = torch.ops._vllm_fa3_C.get_scheduler_metadata(
batch_size,
max_seqlen_q,
max_seqlen_k,
num_heads_q,
num_heads_kv,
headdim,
headdim_v,
qkv_dtype,
cache_seqlens,
cu_seqlens_q,
None, # cu_seqlens_k
cu_seqlens_k_new,
None, # seqused_q
cache_leftpad,
page_size,
max_seqlen_k_new,
causal,
window_size[0],
window_size[1],
has_softcap,
num_splits,
pack_gqa,
sm_margin,
)
return scheduler_metadata
def flash_attn_varlen_func(
q,
k,
v,
max_seqlen_q,
cu_seqlens_q,
max_seqlen_k,
cu_seqlens_k=None, # only used for non-paged prefill
seqused_k=None,
q_v=None,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size: list[int] | None = None,
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
block_table=None,
return_softmax_lse=False,
out=None,
# FA3 Only
scheduler_metadata=None,
q_descale=None,
k_descale=None,
v_descale=None,
num_splits: int = 0,
# Version selector
fa_version: int = DEFAULT_FA_VERSION,
s_aux=None,
cp_world_size=1,
cp_rank=0,
cp_tot_seqused_k=None,
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
max_seqlen_q: int. Maximum query sequence length in the batch.
max_seqlen_k: int. Maximum key sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
softcap: float. Anything > 0 activates softcapping attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
assert cu_seqlens_k is not None or seqused_k is not None, (
"cu_seqlens_k or seqused_k must be provided"
)
assert cu_seqlens_k is None or seqused_k is None, (
"cu_seqlens_k and seqused_k cannot be provided at the same time"
)
assert block_table is None or seqused_k is not None, (
"seqused_k must be provided if block_table is provided"
)
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
# custom op does not support non-tuple input
real_window_size: tuple[int, int]
if window_size is None:
real_window_size = (-1, -1)
else:
assert len(window_size) == 2
real_window_size = (window_size[0], window_size[1])
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q)
if fa_version == 2:
if (
scheduler_metadata is not None
and q_descale is not None
and k_descale is not None
and v_descale is not None
):
raise NotImplementedError(
"FA2 does not support scheduler_metadata, q_descale, "
"k_descale, v_descale"
)
if s_aux is not None:
raise NotImplementedError("FA2 does not support s_aux")
if num_splits > 1:
raise NotImplementedError("FA2 does not support num_splits > 1")
out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd(
q,
k,
v,
out,
cu_seqlens_q,
# cu_seqlens_k not used since we use seqused_k, but flash_api.cpp
# still wants it so we pass all zeros
dummy_cu_seqlens_k if cu_seqlens_k is None else cu_seqlens_k,
seqused_k,
None,
block_table,
alibi_slopes,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
False,
causal,
real_window_size[0],
real_window_size[1],
softcap,
return_softmax_lse and dropout_p > 0,
num_splits,
None,
)
elif fa_version == 3:
assert alibi_slopes is None, "Alibi is not supported in FA3"
out, softmax_lse, _, _ = torch.ops._vllm_fa3_C.fwd(
q,
k,
v,
None,
None, # k_new, v_new
q_v,
out,
cu_seqlens_q,
cu_seqlens_k, # cu_seqlens_k
None, # cu_seqlens_k_new
None,
seqused_k, # seqused_q, seqused_k
max_seqlen_q,
max_seqlen_k,
block_table,
None, # kv_batch_idx
None, # leftpad_k
None,
None,
None, # rotary_cos, rotary_sin, seqlens_rotary
q_descale,
k_descale,
v_descale,
softmax_scale,
causal,
real_window_size[0],
real_window_size[1],
softcap,
True, # rotary_interleaved
scheduler_metadata,
num_splits,
None, # pack_gqa
0, # sm_margin
s_aux, # s_aux
cp_world_size,
cp_rank,
cp_tot_seqused_k,
)
elif fa_version == 4:
assert alibi_slopes is None, "Alibi is not supported in FA4"
# FA4 on SM90 doesn't support paged KV; SM100+ does
from vllm.platforms import current_platform
if block_table is not None and current_platform.is_device_capability_family(90):
raise NotImplementedError(
"FA4 with paged KV is not supported on SM90 (Hopper). "
"Use FA3 or upgrade to Blackwell (SM100+)."
)
from vllm.vllm_flash_attn.cute.interface import _flash_attn_fwd
out, softmax_lse = _flash_attn_fwd(
q,
k,
v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
seqused_k=seqused_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
page_table=block_table,
softmax_scale=softmax_scale,
causal=causal,
softcap=softcap,
window_size_left=real_window_size[0] if real_window_size[0] >= 0 else None,
window_size_right=real_window_size[1] if real_window_size[1] >= 0 else None,
num_splits=num_splits,
return_lse=return_softmax_lse,
out=out,
)
else:
raise ValueError(f"Unsupported FA version: {fa_version}")
return (out, softmax_lse) if return_softmax_lse else out
def sparse_attn_func(
q,
k,
v,
block_count,
block_offset,
column_count,
column_index,
dropout_p=0.0,
softmax_scale=None,
causal=False,
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
*,
return_softmax_lse=False,
out=None,
):
"""Compute attention with vertical and slash sparsity patterns.
Most Arguments are the same with the flash_attn_func interface, except for 4 extra args:
block_count and block_offset for slash sparsity patterns, and
column_count and column_index for vertical sparsity patterns.
For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim)
v: (batch_size, seqlen, nheads_k, headdim)
block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_lse = torch.ops._vllm_fa2_C.fwd_sparse(
q,
k,
v,
block_count,
block_offset,
column_count,
column_index,
out,
alibi_slopes,
dropout_p,
softmax_scale,
causal,
softcap,
return_attn_probs and dropout_p > 0,
None,
)
return (out, softmax_lse) if return_softmax_lse else out
def sparse_attn_varlen_func(
q,
k,
v,
block_count,
block_offset,
column_count,
column_index,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p=0.0,
softmax_scale=None,
causal=False,
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
*,
return_softmax_lse=False,
out=None,
):
"""Compute attention with vertical and slash sparsity patterns.
Most Arguments are the same with the flash_attn_varlen_func interface, except for 4 extra args:
block_count and block_offset for slash sparsity patterns, and
column_count and column_index for vertical sparsity patterns.
For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490.
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
max_seqlen_q: int. Maximum query sequence length in the batch.
max_seqlen_k: int. Maximum key sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
softcap: float. Anything > 0 activates softcapping attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd_sparse(
q,
k,
v,
block_count,
block_offset,
column_count,
column_index,
out,
cu_seqlens_q,
cu_seqlens_k,
None,
alibi_slopes,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
False,
causal,
softcap,
return_attn_probs and dropout_p > 0,
None,
)
return (out, softmax_lse) if return_softmax_lse else out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment