Unverified commit 1ae9b059, authored by Max Podkorytov and committed by GitHub

Fix enabling memory-efficient attention on ROCm (#10564)

* Fix enabling memory-efficient attention on ROCm while calling the CK implementation

* Update attention_processor.py: refactor how an element is picked from a set
parent aad69ac2
@@ -405,11 +405,12 @@ class Attention(nn.Module):
             else:
                 try:
                     # Make sure we can run the memory efficient attention
-                    _ = xformers.ops.memory_efficient_attention(
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                    )
+                    dtype = None
+                    if attention_op is not None:
+                        op_fw, op_bw = attention_op
+                        dtype, *_ = op_fw.SUPPORTED_DTYPES
+                    q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
+                    _ = xformers.ops.memory_efficient_attention(q, q, q)
                 except Exception as e:
                     raise e
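
For context, a minimal standalone sketch of the new probing logic (not part of the commit): on ROCm, xformers routes memory-efficient attention to the Composable Kernel (CK) backend, whose forward op does not support float32, so probing with default-dtype tensors fails. The fix reads a supported dtype off the forward op before building the probe tensor. `SUPPORTED_DTYPES` is a set, hence the star-unpacking to grab an arbitrary element (the "picking a set element" refactor mentioned in the commit message). The flash-attention op pair below is an illustrative stand-in for whatever `attention_op` a caller passes; a CUDA/ROCm device is assumed.

```python
import torch
import xformers.ops

# Illustrative stand-in for the optional (forward, backward) op pair that
# diffusers accepts as `attention_op`; on ROCm builds the analogous CK ops
# would appear here instead.
attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp

dtype = None
if attention_op is not None:
    op_fw, op_bw = attention_op
    # SUPPORTED_DTYPES is a set, so it cannot be indexed; star-unpacking
    # picks an arbitrary dtype the forward op actually supports.
    dtype, *_ = op_fw.SUPPORTED_DTYPES

# Probe with a tensor the chosen op can handle. With dtype=None the probe
# falls back to torch's default (float32), which the CK backend on ROCm
# rejects -- the failure this commit fixes.
q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
_ = xformers.ops.memory_efficient_attention(q, q, q)
```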