Unverified Commit ab5bbf5a authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Bugfix][Kernel] Fix CUDA 11.8 being broken by FA3 build (#12375)


Signed-off-by: default avatarLucas Wilkinson <lwilkinson@neuralmagic.com>
parent 3bb8e2c9
...@@ -576,7 +576,7 @@ else() ...@@ -576,7 +576,7 @@ else()
FetchContent_Declare( FetchContent_Declare(
vllm-flash-attn vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 90eacc1af2a7c3de62ea249e929ed5faccf38954 GIT_TAG 0aff05f577e8a10086066a00618609199b25231d
GIT_PROGRESS TRUE GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types # Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
......
...@@ -598,7 +598,10 @@ if _is_hip(): ...@@ -598,7 +598,10 @@ if _is_hip():
if _is_cuda(): if _is_cuda():
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C")) ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C")) if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.0"):
# FA3 requires CUDA 12.0 or later
ext_modules.append(
CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator")) ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
if _build_custom_ops(): if _build_custom_ops():
......
...@@ -6,7 +6,9 @@ import torch ...@@ -6,7 +6,9 @@ import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import (cascade_attention, from vllm.v1.attention.backends.flash_attn import (cascade_attention,
merge_attn_states) merge_attn_states)
from vllm.vllm_flash_attn import flash_attn_varlen_func from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
flash_attn_varlen_func,
is_fa_version_supported)
NUM_HEADS = [(4, 4), (8, 2), (16, 2)] NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 192, 256] HEAD_SIZES = [128, 192, 256]
...@@ -91,10 +93,9 @@ def test_cascade( ...@@ -91,10 +93,9 @@ def test_cascade(
fa_version: int, fa_version: int,
) -> None: ) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6) if not is_fa_version_supported(fa_version):
or torch.cuda.get_device_capability() == (8, 9)): pytest.skip(f"Flash attention version {fa_version} not supported due "
pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to " f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
"insufficient shared memory for some shapes")
current_platform.seed_everything(0) current_platform.seed_everything(0)
......
...@@ -4,8 +4,10 @@ import pytest ...@@ -4,8 +4,10 @@ import pytest
import torch import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.vllm_flash_attn import (flash_attn_varlen_func, from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
flash_attn_with_kvcache) flash_attn_varlen_func,
flash_attn_with_kvcache,
is_fa_version_supported)
NUM_HEADS = [(4, 4), (8, 2), (16, 2)] NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 256] HEAD_SIZES = [128, 256]
...@@ -95,10 +97,9 @@ def test_flash_attn_with_paged_kv( ...@@ -95,10 +97,9 @@ def test_flash_attn_with_paged_kv(
fa_version: int, fa_version: int,
) -> None: ) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6) if not is_fa_version_supported(fa_version):
or torch.cuda.get_device_capability() == (8, 9)): pytest.skip(f"Flash attention version {fa_version} not supported due "
pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to " f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
"insufficient shared memory for some shapes")
current_platform.seed_everything(0) current_platform.seed_everything(0)
num_seqs = len(kv_lens) num_seqs = len(kv_lens)
...@@ -182,11 +183,9 @@ def test_varlen_with_paged_kv( ...@@ -182,11 +183,9 @@ def test_varlen_with_paged_kv(
fa_version: int, fa_version: int,
) -> None: ) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6) if not is_fa_version_supported(fa_version):
or torch.cuda.get_device_capability() == (8, 9)): pytest.skip(f"Flash attention version {fa_version} not supported due "
pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to " f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
"insufficient shared memory for some shapes")
current_platform.seed_everything(0) current_platform.seed_everything(0)
num_seqs = len(seq_lens) num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens] query_lens = [x[0] for x in seq_lens]
......
...@@ -18,17 +18,20 @@ from vllm.attention.backends.utils import ( ...@@ -18,17 +18,20 @@ from vllm.attention.backends.utils import (
get_seq_len_block_table_args, is_all_cross_attn_metadata_set, get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
is_all_encoder_attn_metadata_set, is_block_tables_empty) is_all_encoder_attn_metadata_set, is_block_tables_empty)
from vllm.envs import VLLM_FLASH_ATTN_VERSION from vllm.envs import VLLM_FLASH_ATTN_VERSION
from vllm.logger import init_logger
from vllm.multimodal import MultiModalPlaceholderMap from vllm.multimodal import MultiModalPlaceholderMap
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.utils import async_tensor_h2d, make_tensor_with_pad
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
flash_attn_varlen_func,
flash_attn_with_kvcache,
is_fa_version_supported)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder, from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata) ModelInputForGPUWithSamplingMetadata)
from vllm.vllm_flash_attn import (flash_attn_varlen_func, logger = init_logger(__name__)
flash_attn_with_kvcache,
is_fa_version_supported)
class FlashAttentionBackend(AttentionBackend): class FlashAttentionBackend(AttentionBackend):
...@@ -652,6 +655,11 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -652,6 +655,11 @@ class FlashAttentionImpl(AttentionImpl):
assert VLLM_FLASH_ATTN_VERSION in [2, 3] assert VLLM_FLASH_ATTN_VERSION in [2, 3]
self.fa_version = VLLM_FLASH_ATTN_VERSION self.fa_version = VLLM_FLASH_ATTN_VERSION
if not is_fa_version_supported(self.fa_version):
logger.error("Cannot use FA version %d is not supported due to %s",
self.fa_version,
fa_version_unsupported_reason(self.fa_version))
assert is_fa_version_supported(self.fa_version) assert is_fa_version_supported(self.fa_version)
def forward( def forward(
......
...@@ -10,11 +10,15 @@ import triton.language as tl ...@@ -10,11 +10,15 @@ import triton.language as tl
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType) AttentionMetadata, AttentionType)
from vllm.envs import VLLM_FLASH_ATTN_VERSION from vllm.envs import VLLM_FLASH_ATTN_VERSION
from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cdiv from vllm.utils import cdiv
from vllm.vllm_flash_attn import (flash_attn_varlen_func, from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
flash_attn_varlen_func,
is_fa_version_supported) is_fa_version_supported)
logger = init_logger(__name__)
class FlashAttentionBackend(AttentionBackend): class FlashAttentionBackend(AttentionBackend):
...@@ -143,6 +147,11 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -143,6 +147,11 @@ class FlashAttentionImpl(AttentionImpl):
assert VLLM_FLASH_ATTN_VERSION in [2, 3] assert VLLM_FLASH_ATTN_VERSION in [2, 3]
self.fa_version = VLLM_FLASH_ATTN_VERSION self.fa_version = VLLM_FLASH_ATTN_VERSION
if not is_fa_version_supported(self.fa_version):
logger.error("Cannot use FA version %d is not supported due to %s",
self.fa_version,
fa_version_unsupported_reason(self.fa_version))
assert is_fa_version_supported(self.fa_version) assert is_fa_version_supported(self.fa_version)
def forward( def forward(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment