Unverified Commit 6bbc5323 authored by Markus Krimmel, committed by GitHub

fix: cast the alibi slopes to torch.float32 (#846)

parent 4a73e903
@@ -101,6 +101,8 @@ class FlashSelfAttention(nn.Module):
         assert qkv.is_cuda
         causal = self.causal if causal is None else causal
         unpadded = cu_seqlens is not None
+        if self.alibi_slopes is not None:
+            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
         if unpadded:
             assert cu_seqlens.dtype == torch.int32
             assert max_seqlen is not None
@@ -185,6 +187,8 @@ class FlashCrossAttention(nn.Module):
         assert q.is_cuda and kv.is_cuda
         causal = self.causal if causal is None else causal
         unpadded = cu_seqlens is not None
+        if self.alibi_slopes is not None:
+            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
         if unpadded:
             assert cu_seqlens.dtype == torch.int32
             assert max_seqlen is not None
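
For context, here is a minimal sketch of where such slopes come from and why the cast matters: the FlashAttention kernels expect alibi_slopes as an fp32 tensor of shape (nheads,) or (batch, nheads), while surrounding modules may keep tensors in fp16/bf16 (e.g. under autocast), so the forward passes above cast defensively before invoking the kernel. The helper name get_alibi_slopes and the power-of-two head count are illustrative assumptions, not part of this commit.

import torch

def get_alibi_slopes(nheads: int) -> torch.Tensor:
    # Slopes form a geometric sequence 2^(-8/n), 2^(-16/n), ... per the
    # ALiBi paper; this simplified sketch assumes nheads is a power of two.
    ratio = 2.0 ** (-8.0 / nheads)
    slopes = torch.tensor([ratio ** (i + 1) for i in range(nheads)])
    # The kernel expects fp32 slopes of shape (nheads,); this cast mirrors
    # the one added in this commit.
    return slopes.to(torch.float32)

slopes = get_alibi_slopes(8)
print(slopes.dtype, slopes.shape)  # torch.float32 torch.Size([8])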