Commit 88c4e5db authored by Tri Dao

Fix the case when dout is not contiguous

parent a1a5d2ee
@@ -38,6 +38,7 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens
     as num_splits=3), so effectively the choices are 0, 1, and 2.
     This hyperparameter can be tuned for performance, but default value (heuristic) should work fine.
     """
+    dout = dout.contiguous()  # CUDA code assumes that dout is contiguous
     _, _, _, softmax_d = flash_attn_cuda.bwd(
         dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
         max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, num_splits, generator)
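Why a non-contiguous dout can go wrong is easiest to see with a toy model of strided storage. The sketch below is plain Python, not PyTorch or the flash_attn_cuda kernel; the helper names (read_assuming_contiguous, read_with_strides, make_contiguous) are hypothetical and exist only for illustration. It assumes the kernel indexes its input as if it were densely packed in row-major order, which is what the comment in the commit suggests.

```python
# Toy model of strided tensor storage. All helper names are hypothetical;
# this is not how PyTorch or the CUDA kernel is implemented.

def read_assuming_contiguous(buffer, shape):
    """Read a 2-D array the way a kernel would if it assumed row-major layout."""
    rows, cols = shape
    return [[buffer[r * cols + c] for c in range(cols)] for r in range(rows)]

def read_with_strides(buffer, shape, strides):
    """Read a 2-D array while honoring its actual strides."""
    rows, cols = shape
    return [[buffer[r * strides[0] + c * strides[1]] for c in range(cols)]
            for r in range(rows)]

def make_contiguous(buffer, shape, strides):
    """Copy into a fresh row-major buffer (the role dout.contiguous() plays)."""
    rows, cols = shape
    new_buffer = [buffer[r * strides[0] + c * strides[1]]
                  for r in range(rows) for c in range(cols)]
    return new_buffer, (cols, 1)

# A 2x3 row-major array [[0, 1, 2], [3, 4, 5]] stored flat.
storage = [0, 1, 2, 3, 4, 5]
# Its transpose is a 3x2 view over the same storage with swapped strides.
t_shape, t_strides = (3, 2), (1, 3)

expected = read_with_strides(storage, t_shape, t_strides)  # the true transpose
wrong = read_assuming_contiguous(storage, t_shape)         # strides ignored
fixed_buffer, fixed_strides = make_contiguous(storage, t_shape, t_strides)
fixed = read_assuming_contiguous(fixed_buffer, t_shape)    # safe after the copy

print(expected)  # [[0, 3], [1, 4], [2, 5]]
print(wrong)     # [[0, 1], [2, 3], [4, 5]]
print(fixed)     # [[0, 3], [1, 4], [2, 5]]
```

In PyTorch, Tensor.contiguous() returns the tensor itself when it is already contiguous, so the added line should cost a copy only when the incoming gradient is actually a strided view.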