Unverified commit 51e5b3e3, authored by Matthew Bonanni, committed by GitHub
Browse files

[Bugfix] Fix ViT with FlashAttention on ROCm (#30703)


Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
parent ec154c36
......@@ -464,7 +464,10 @@ class MultiHeadAttention(nn.Module):
}
self.fa_version = None
-if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
+if (
+self.attn_backend == AttentionBackendEnum.FLASH_ATTN
+and current_platform.is_cuda()
+):
self.fa_version = get_flash_attn_version()
assert self._flash_attn_varlen_func is not None
self._flash_attn_varlen_func = functools.partial(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment