Commit c0724fc9 authored by alexeykondrat, committed by GitHub

[ROCm][Hardware][AMD] Adding Navi21 to fallback to naive attention if Triton is not used (#4658)

parent 86b45ae0
...
@@ -231,8 +231,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                 self.attn_func = triton_attention
                 logger.debug("Using Triton FA in ROCmBackend")
             else:
-                # if not using triton, navi3x not use flash-attn either
-                if torch.cuda.get_device_capability()[0] == 11:
+                # if not using triton, navi3x/navi21/navi10 do not use flash-attn
+                # either
+                if torch.cuda.get_device_capability()[0] != 9:
                     self.use_naive_attn = True
                 else:
                     try:
...
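
The effect of the changed condition can be sketched as a small pure function. This is an illustration only, not code from the commit: the function name `select_attention_path` and the string labels are hypothetical, and the sketch assumes (as the new condition implies) that a major capability of 9 identifies the devices that still attempt the flash-attn path. The body of the `try:` branch is elided in the diff, so the sketch does not model what happens if that attempt fails.

```python
def select_attention_path(use_triton: bool, capability_major: int) -> str:
    """Return a label for the attention path chosen by the patched logic.

    Hypothetical helper for illustration; the real code sets attributes on
    ROCmFlashAttentionImpl instead of returning a string.
    """
    if use_triton:
        # Triton flash attention is used regardless of the device.
        return "triton"
    if capability_major != 9:
        # After this commit: any device whose major capability is not 9
        # (the comment names navi3x, navi21, navi10) uses naive attention.
        return "naive"
    # Major capability 9: the code goes on to try the flash-attn path.
    return "flash_attn"


def select_attention_path_before(use_triton: bool, capability_major: int) -> str:
    """Same sketch for the pre-commit condition (== 11 only)."""
    if use_triton:
        return "triton"
    if capability_major == 11:
        return "naive"
    return "flash_attn"
```

In the real code the second argument would come from `torch.cuda.get_device_capability()[0]`. Comparing the two helpers with a major capability of 10 shows the behavioral change: the old condition would have gone on to try flash-attn, while the new one selects naive attention.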