Unverified commit ed6ea065, authored by Pavani Majety, committed by GitHub
Browse files

[Hardware] Update the flash attn tag to support Blackwell (#14244)

parent 5ee10e99
...@@ -38,7 +38,7 @@ else() ...@@ -38,7 +38,7 @@ else()
FetchContent_Declare( FetchContent_Declare(
vllm-flash-attn vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 720c94869cf2e0ff5a706e9c7f1dce0939686ade GIT_TAG 9bfa9869829d8c593527eb34c5271d0090f7ccc9
GIT_PROGRESS TRUE GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types # Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
...@@ -64,4 +64,4 @@ install( ...@@ -64,4 +64,4 @@ install(
DESTINATION vllm_flash_attn DESTINATION vllm_flash_attn
COMPONENT _vllm_fa3_C COMPONENT _vllm_fa3_C
FILES_MATCHING PATTERN "*.py" FILES_MATCHING PATTERN "*.py"
) )
\ No newline at end of file
...@@ -595,7 +595,7 @@ def get_flash_attn_version(): ...@@ -595,7 +595,7 @@ def get_flash_attn_version():
# if hopper default to FA3, otherwise stick to FA2 for now # if hopper default to FA3, otherwise stick to FA2 for now
# TODO(lucas): profile FA3 on ampere to see if it makes sense to # TODO(lucas): profile FA3 on ampere to see if it makes sense to
# use FA3 as default for both # use FA3 as default for both
if current_platform.get_device_capability()[0] >= 9: if current_platform.get_device_capability()[0] == 9:
fa_version = 3 if is_fa_version_supported(3) else 2 fa_version = 3 if is_fa_version_supported(3) else 2
else: else:
fa_version = 2 fa_version = 2
...@@ -603,6 +603,11 @@ def get_flash_attn_version(): ...@@ -603,6 +603,11 @@ def get_flash_attn_version():
if envs.VLLM_FLASH_ATTN_VERSION is not None: if envs.VLLM_FLASH_ATTN_VERSION is not None:
assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3] assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3]
fa_version = envs.VLLM_FLASH_ATTN_VERSION fa_version = envs.VLLM_FLASH_ATTN_VERSION
if (current_platform.get_device_capability()[0] == 10
and envs.VLLM_FLASH_ATTN_VERSION == 3):
logger.warning("Cannot use FA version 3 on Blackwell platform",
"defaulting to FA version 2.")
fa_version = 2
if not is_fa_version_supported(fa_version): if not is_fa_version_supported(fa_version):
logger.error("Cannot use FA version %d is not supported due to %s", logger.error("Cannot use FA version %d is not supported due to %s",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment