Unverified commit 6bbc5323 authored by Markus Krimmel, committed by GitHub

fix: cast the alibi slopes to torch.float32 (#846)

parent 4a73e903
@@ -101,6 +101,8 @@ class FlashSelfAttention(nn.Module):
         assert qkv.is_cuda
         causal = self.causal if causal is None else causal
         unpadded = cu_seqlens is not None
+        if self.alibi_slopes is not None:
+            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
         if unpadded:
             assert cu_seqlens.dtype == torch.int32
             assert max_seqlen is not None
@@ -185,6 +187,8 @@ class FlashCrossAttention(nn.Module):
         assert q.is_cuda and kv.is_cuda
         causal = self.causal if causal is None else causal
         unpadded = cu_seqlens is not None
+        if self.alibi_slopes is not None:
+            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
         if unpadded:
             assert cu_seqlens.dtype == torch.int32
             assert max_seqlen is not None
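
For context, here is a minimal sketch of how the dtype problem this commit guards against can arise. It assumes flash-attn is installed with a CUDA device available, and that FlashSelfAttention stores the slopes as a floating-point buffer that module-wide casts such as .half() affect; the get_alibi_slopes helper is a hypothetical illustration for this sketch, not part of the library.

import torch
from flash_attn.modules.mha import FlashSelfAttention

def get_alibi_slopes(nheads: int) -> torch.Tensor:
    # Hypothetical helper: standard ALiBi slopes for a power-of-two head
    # count, a geometric sequence starting at 2 ** (-8 / nheads).
    start = 2 ** (-8.0 / nheads)
    return torch.tensor([start ** (i + 1) for i in range(nheads)])

batch, seqlen, nheads, headdim = 2, 128, 8, 64
attn = FlashSelfAttention(causal=True,
                          alibi_slopes=get_alibi_slopes(nheads)).cuda()

# Casting the whole module for mixed precision (common in fp16 inference)
# also casts the stored slopes to fp16 ...
attn = attn.half()

qkv = torch.randn(batch, seqlen, 3, nheads, headdim,
                  device="cuda", dtype=torch.float16)

# ... but the kernel expects fp32 slopes. With this commit, forward()
# re-casts self.alibi_slopes to torch.float32 before the call instead of
# failing on the dtype mismatch.
out = attn(qkv)
print(out.shape)  # torch.Size([2, 128, 8, 64])

The same cast is applied in FlashCrossAttention, which hits the identical issue when a model holding separate q and kv projections is cast to half precision.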