Commit a7cce320 authored by Jimmy Zhang

dropout variable

parent 254e8815
@@ -384,10 +384,11 @@ class FlashSelfAttention(torch.nn.Module):
             is_causal = seqlen_q == seqlen_k
             cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
                                         device=q.device)
+            self.dropout_p = 0

         output = flash_attn_unpadded_func(
             q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
-            0.0,
+            self.dropout_p,
             softmax_scale=self.softmax_scale, causal=is_causal
         )
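
The change replaces the dropout probability that was hardcoded to 0.0 in the flash_attn_unpadded_func call with the module's self.dropout_p, and forces self.dropout_p = 0 on the inference branch, so the configured dropout applies only during training. Below is a minimal sketch of how the surrounding forward plausibly looks, assuming the flash-attn v1 unpadded interface and an attention_dropout constructor argument; everything outside the lines shown in the diff (constructor names, the rearrange reshaping, cu_seqlens_q) is an assumption, not taken from this commit.

    import torch
    from einops import rearrange
    # flash-attn v1.x interface, matching the call in the diff above
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func


    class FlashSelfAttention(torch.nn.Module):
        """Minimal sketch; only dropout_p, softmax_scale, causal and the
        inference branch shown in the diff are grounded in this commit."""

        def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
            super().__init__()
            self.causal = causal
            self.softmax_scale = softmax_scale
            self.dropout_p = attention_dropout

        def forward(self, q, k, v):
            # q, k, v: (batch, seqlen, nheads, headdim)
            batch_size, seqlen_q = q.shape[0], q.shape[1]
            seqlen_k = k.shape[1]
            # flatten batch and sequence dims for the unpadded kernel
            q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in (q, k, v)]
            cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q,
                                        dtype=torch.int32, device=q.device)

            if self.training:
                # training: q, k, v share one sequence length; dropout stays active
                is_causal = self.causal
                cu_seqlens_k = cu_seqlens_q
            else:
                # inference: after the first autoregressive step seqlen_q != seqlen_k,
                # and dropout must be off regardless of the configured probability
                is_causal = seqlen_q == seqlen_k
                cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k,
                                            dtype=torch.int32, device=q.device)
                self.dropout_p = 0  # the line this commit adds

            output = flash_attn_unpadded_func(
                q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
                self.dropout_p,  # was hardcoded to 0.0 before this commit
                softmax_scale=self.softmax_scale, causal=is_causal,
            )
            return rearrange(output, '(b s) ... -> b s ...', b=batch_size)

One design note: assigning self.dropout_p = 0 on the eval path overwrites the configured probability, so a module switched back to training mode afterwards would run without attention dropout; stashing the configured value in a separate attribute would avoid that, though it is harmless in an inference-only run.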