Commit 904f15a4 authored by 0x3f3f3f3fun

Sorry, I forgot to upload the code for issue #26.

parent 049f559f
@@ -173,8 +173,7 @@ class CrossAttention(nn.Module):
         # force cast to fp32 to avoid overflowing
         if _ATTN_PRECISION =="fp32":
             # with torch.autocast(enabled=False, device_type = 'cuda'):
-            with torch.autocast(enabled=False, device_type=x.device):
-            # with torch.autocast(enabled=False, device_type="cuda" if str(x.device).startswith("cuda") else "cpu"):
+            with torch.autocast(enabled=False, device_type="cuda" if str(x.device).startswith("cuda") else "cpu"):
                 q, k = q.float(), k.float()
                 sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
         else:
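For context, the failure this commit fixes can be reproduced outside the model: `torch.autocast` expects `device_type` to be a string such as "cuda" or "cpu", not a `torch.device` object, so `device_type=x.device` raises an error. The sketch below is a minimal, hypothetical reproduction, assuming only a recent PyTorch install; the tensor `x` here is a stand-in for the attention module's input.

```python
# Minimal sketch of the bug this commit fixes (assumption: recent PyTorch;
# `x` is a stand-in for the attention module's input tensor).
import torch

x = torch.randn(2, 4)  # on CPU here; in a GPU run it would be a CUDA tensor

# Before the fix: torch.autocast requires `device_type` as a string
# ("cuda" or "cpu"), so passing the torch.device object fails. The exact
# exception type varies across PyTorch versions.
try:
    with torch.autocast(enabled=False, device_type=x.device):
        pass
except (TypeError, ValueError, RuntimeError) as e:
    print(f"torch.device rejected: {e}")

# After the fix: derive the string from the tensor's device, as the diff does.
device_type = "cuda" if str(x.device).startswith("cuda") else "cpu"
with torch.autocast(enabled=False, device_type=device_type):
    q = x.float()  # fp32 math runs with autocast disabled, avoiding fp16 overflow
```

Note that the fix keys off `str(x.device)`, so a device string like "cuda:1" also maps to "cuda", which keeps the check correct on multi-GPU setups.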