"ppstructure/vscode:/vscode.git/clone" did not exist on "55d8af94f52dc8fd4c6356ef98c5a703e2870dbf"
Commit 898dd4bb authored by Tri Dao's avatar Tri Dao
Browse files

Pass seqused_k to _flash_attn_varlen_forward

parent 7ef24848
...@@ -77,12 +77,13 @@ def _flash_attn_varlen_forward( ...@@ -77,12 +77,13 @@ def _flash_attn_varlen_forward(
dropout_p, dropout_p,
softmax_scale, softmax_scale,
causal, causal,
window_size, window_size=(-1, -1),
softcap, softcap=0.0,
alibi_slopes, alibi_slopes=None,
return_softmax, return_softmax=False,
block_table=None, block_table=None,
leftpad_k=None, leftpad_k=None,
seqused_k=None,
): ):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)] q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
...@@ -93,7 +94,7 @@ def _flash_attn_varlen_forward( ...@@ -93,7 +94,7 @@ def _flash_attn_varlen_forward(
None, None,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_k, cu_seqlens_k,
None, seqused_k,
leftpad_k, leftpad_k,
block_table, block_table,
alibi_slopes, alibi_slopes,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment