"benchmarks/vscode:/vscode.git/clone" did not exist on "55dcce91df150f576c28520d987eaf1498fcb0bd"
Unverified commit 0032903a, authored by Travis Johnson, committed by GitHub
Browse files

[Bugfix] detect alibi and revert to FA2 (#15231)


Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
parent 47195057
...@@ -630,7 +630,8 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -630,7 +630,8 @@ class FlashAttentionImpl(AttentionImpl):
self.sliding_window = ((sliding_window - 1, self.sliding_window = ((sliding_window - 1,
0) if sliding_window is not None else (-1, -1)) 0) if sliding_window is not None else (-1, -1))
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
self.vllm_flash_attn_version = get_flash_attn_version() self.vllm_flash_attn_version = get_flash_attn_version(
requires_alibi=self.alibi_slopes is not None)
if (is_quantized_kv_cache(self.kv_cache_dtype) if (is_quantized_kv_cache(self.kv_cache_dtype)
and self.vllm_flash_attn_version != 3): and self.vllm_flash_attn_version != 3):
raise NotImplementedError( raise NotImplementedError(
......
...@@ -7,7 +7,7 @@ from vllm.logger import init_logger ...@@ -7,7 +7,7 @@ from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
def get_flash_attn_version() -> Optional[int]: def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
# import here to avoid circular dependencies # import here to avoid circular dependencies
from vllm.platforms import current_platform from vllm.platforms import current_platform
try: try:
...@@ -28,7 +28,13 @@ def get_flash_attn_version() -> Optional[int]: ...@@ -28,7 +28,13 @@ def get_flash_attn_version() -> Optional[int]:
# 3. fallback for unsupported combinations # 3. fallback for unsupported combinations
if device_capability.major == 10 and fa_version == 3: if device_capability.major == 10 and fa_version == 3:
logger.warning("Cannot use FA version 3 on Blackwell platform", logger.warning_once(
"Cannot use FA version 3 on Blackwell platform "
"defaulting to FA version 2.")
fa_version = 2
if requires_alibi and fa_version == 3:
logger.warning_once("Cannot use FA version 3 with ALiBi, "
"defaulting to FA version 2.") "defaulting to FA version 2.")
fa_version = 2 fa_version = 2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment