Commit 3bc18127, authored by Chaojun Zhang, committed by GitHub

[XPU] Whisper model support on XPU Platform (#25123)


Signed-off-by: chzhang <chaojun.zhang@intel.com>
parent bec060fd
@@ -391,8 +391,8 @@ class MultiHeadAttention(nn.Module):
             backend = _Backend.FLASH_ATTN
             use_upstream_fa = True
-        if current_platform.is_rocm():
-            # currently, only torch_sdpa is supported on rocm
+        if current_platform.is_rocm() or current_platform.is_xpu():
+            # currently, only torch_sdpa is supported on rocm/xpu
             self.attn_backend = _Backend.TORCH_SDPA
         else:
...
@@ -282,7 +282,7 @@ def bind_kv_cache(
     # TODO - analyze where runner_kv_caches is used and the right
     # way to ensure it properly reflects multiple attention layers
     # in the same decoder block.
-    if current_platform.is_cuda():
+    if current_platform.is_cuda() or current_platform.is_xpu():
         # We know that the GPU runner is not impacted by this
         # case. Some test code depends on runner_kv_caches, but
         # not in a way that's impacted by ignoring this.
...
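
The first hunk extends an existing platform check so that XPU takes the same TORCH_SDPA fallback as ROCm. The sketch below illustrates that dispatch pattern in isolation. It is not the vLLM implementation: `Backend`, `Platform`, and `select_attn_backend` are hypothetical stand-ins for `_Backend`, `current_platform`, and the logic inside `MultiHeadAttention`, and because the `else:` branch is elided in the diff, returning the requested backend unchanged is an assumption.

```python
from enum import Enum, auto


class Backend(Enum):
    """Hypothetical stand-in for the _Backend enum seen in the diff."""
    FLASH_ATTN = auto()
    TORCH_SDPA = auto()


class Platform:
    """Hypothetical stand-in for current_platform."""

    def __init__(self, name: str) -> None:
        self.name = name

    def is_rocm(self) -> bool:
        return self.name == "rocm"

    def is_xpu(self) -> bool:
        return self.name == "xpu"


def select_attn_backend(platform: Platform, requested: Backend) -> Backend:
    # Mirrors the changed condition: ROCm and XPU both fall back to TORCH_SDPA.
    if platform.is_rocm() or platform.is_xpu():
        return Backend.TORCH_SDPA
    # Assumption: the elided else-branch keeps the requested backend.
    return requested
```

With this sketch, `select_attn_backend(Platform("xpu"), Backend.FLASH_ATTN)` returns `Backend.TORCH_SDPA`, while a platform that is neither ROCm nor XPU keeps the backend it asked for.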