Unverified commit 51e5b3e3, authored by Matthew Bonanni, committed by GitHub
Browse files

[Bugfix] Fix ViT with FlashAttention on ROCm (#30703)


Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
parent ec154c36
...@@ -464,7 +464,10 @@ class MultiHeadAttention(nn.Module): ...@@ -464,7 +464,10 @@ class MultiHeadAttention(nn.Module):
} }
self.fa_version = None self.fa_version = None
if self.attn_backend == AttentionBackendEnum.FLASH_ATTN: if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
and current_platform.is_cuda()
):
self.fa_version = get_flash_attn_version() self.fa_version = get_flash_attn_version()
assert self._flash_attn_varlen_func is not None assert self._flash_attn_varlen_func is not None
self._flash_attn_varlen_func = functools.partial( self._flash_attn_varlen_func = functools.partial(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment