Commit 7f417161 authored by zhuwenwen
Browse files

Switch MultiHeadAttention (MHA) to the flash-attn (FA) implementation on ROCm

parent f3731273
......@@ -416,11 +416,14 @@ class MultiHeadAttention(nn.Module):
backend = _Backend.FLASH_ATTN
use_upstream_fa = True
if current_platform.is_rocm() or current_platform.is_xpu():
if current_platform.is_xpu():
# currently, only torch_sdpa is supported on xpu
self.attn_backend = _Backend.TORCH_SDPA
elif current_platform.is_rocm():
self.attn_backend = backend if backend in {
_Backend.FLASH_ATTN,
} else _Backend.TORCH_SDPA
else:
self.attn_backend = backend if backend in {
_Backend.TORCH_SDPA,
_Backend.XFORMERS,
......@@ -438,8 +441,12 @@ class MultiHeadAttention(nn.Module):
from flash_attn import flash_attn_varlen_func
self._flash_attn_varlen_func = flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
self._flash_attn_varlen_func = flash_attn_varlen_func
if current_platform.is_rocm():
from flash_attn import flash_attn_varlen_func
self._flash_attn_varlen_func = flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
self._flash_attn_varlen_func = flash_attn_varlen_func
logger.info_once(
f"MultiHeadAttention attn_backend: {self.attn_backend}, "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment