Unverified commit 1ae9b059 authored by Max Podkorytov, committed by GitHub

Fix enable memory efficient attention on ROCm (#10564)

* Fix enabling memory-efficient attention on ROCm

when calling the CK implementation

* Update attention_processor.py

refactor of picking an element from a set
parent aad69ac2
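
For context on the "picking a set element" refactor: `SUPPORTED_DTYPES` on xformers attention op classes is a Python set, which is unordered and not indexable, so the patch binds one arbitrary element via star-unpacking. A minimal standalone sketch of the idiom (not diffusers code):

```python
import torch

# SUPPORTED_DTYPES on an xformers op class is a set, for example:
supported = {torch.float16, torch.bfloat16}

# Sets cannot be indexed, so star-unpacking binds one (arbitrary)
# element and discards the rest:
dtype, *_ = supported

# Equivalent spelling: dtype = next(iter(supported))
```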
@@ -405,11 +405,12 @@ class Attention(nn.Module):
             else:
                 try:
                     # Make sure we can run the memory efficient attention
-                    _ = xformers.ops.memory_efficient_attention(
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                    )
+                    dtype = None
+                    if attention_op is not None:
+                        op_fw, op_bw = attention_op
+                        dtype, *_ = op_fw.SUPPORTED_DTYPES
+                    q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
+                    _ = xformers.ops.memory_efficient_attention(q, q, q)
                 except Exception as e:
                     raise e
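
For how this code path is exercised: diffusers lets callers pass an explicit `(forward, backward)` op pair via `attention_op`, and with the fix above the dtype of the probe tensor follows whatever the forward op supports. A minimal usage sketch, assuming an xformers build that ships the ROCm Composable Kernel backend as `xformers.ops.fmha.ck` (the module path and the checkpoint name are illustrative, not confirmed by this diff):

```python
import torch
from diffusers import DiffusionPipeline
from xformers.ops.fmha import ck  # ROCm CK backend; assumed module path

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
).to("cuda")  # ROCm builds of PyTorch expose HIP GPUs as "cuda"

# Passing the CK op pair makes the probe above read ck.FwOp.SUPPORTED_DTYPES,
# so the test tensors are created in fp16/bf16 instead of the fp32 default,
# which the CK kernels do not support.
pipe.enable_xformers_memory_efficient_attention(attention_op=(ck.FwOp, ck.BwOp))
```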