"src/lib/vscode:/vscode.git/clone" did not exist on "09a81eb22576df8b0e72a880b6e2b15c0b39d9ae"
Commit 1b9facac authored by Tri Dao
Browse files

Fix QKV interface to allocate output in Python

parent 5badfb78
......@@ -47,8 +47,9 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
out, softmax_lse, S_dmask = _flash_attn_forward(
qkv[:, 0], qkv[:, 1], qkv[:, 2], cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
qkv[:, 0], qkv[:, 1], qkv[:, 2], torch.empty_like(qkv[:, 0]), cu_seqlens, cu_seqlens,
max_seqlen, max_seqlen, dropout_p, softmax_scale, causal=causal,
return_softmax=return_softmax
)
ctx.save_for_backward(qkv, out, softmax_lse, cu_seqlens, rng_state)
ctx.dropout_p = dropout_p
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment