Unverified Commit 59fe6f29 authored by sihao_li's avatar sihao_li Committed by GitHub
Browse files

[XPU]fallback to TRITON_ATTN on xpu when use float32 dtype (#31762)


Signed-off-by: default avatarsihao.li <sihao.li@intel.com>
parent e7596371
......@@ -52,11 +52,18 @@ class XPUPlatform(Platform):
"only NHD layout is supported by XPU attention kernels."
)
dtype = attn_selector_config.dtype
if attn_selector_config.use_sparse:
raise NotImplementedError("Sparse Attention is not supported on XPU.")
if selected_backend == AttentionBackendEnum.TRITON_ATTN:
logger.info_once("Using Triton backend.")
return AttentionBackendEnum.TRITON_ATTN.get_path()
elif dtype == torch.float32:
logger.warning_once(
"Flash Attention on XPU does not support float32 dtype. "
"Falling back to Triton Attention backend."
)
return AttentionBackendEnum.TRITON_ATTN.get_path()
elif selected_backend == AttentionBackendEnum.FLASH_ATTN:
logger.info_once("Using Flash Attention backend.")
return AttentionBackendEnum.FLASH_ATTN.get_path()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment