gaoqiong / flash-attention · Commit b6aa059b

Commit b6aa059b, authored Mar 30, 2023 by Kirthi Shankar Sivamani
Parent: 009a3e71

Add option for deterministic execution

Showing 1 changed file (flash_attn/flash_attn_interface.py) with 36 additions and 21 deletions (+36, -21).
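The commit threads a deterministic flag (default False) through each autograd Function and its public wrapper, and forwards it to _flash_attn_backward as num_splits=1 if ctx.deterministic else 0. A minimal usage sketch of the updated qkvpacked wrapper follows; the shapes, dtype, and CUDA placement are assumptions drawn from the docstrings in this file, not part of the commit:

import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func

# Assumed setup (illustrative only): two sequences of lengths 128 and 64
# packed into one unpadded batch of 192 tokens, fp16 on CUDA.
total, nheads, headdim = 192, 8, 64
qkv = torch.randn(total, 3, nheads, headdim, device="cuda", dtype=torch.float16,
                  requires_grad=True)
cu_seqlens = torch.tensor([0, 128, 192], device="cuda", dtype=torch.int32)

# New in this commit: deterministic=True makes the backward pass run with
# num_splits=1 so that gradients are reproducible from run to run.
out = flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen=128,
                                         dropout_p=0.1, causal=True,
                                         deterministic=True)
out.sum().backward()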
@@ -50,7 +50,8 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens
 class FlashAttnQKVPackedFunc(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax):
+    def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax,
+                deterministic):
         # Save rng_state because the backward pass will regenerate the dropout mask
         rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
@@ -65,6 +66,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         ctx.max_seqlen = max_seqlen
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
 
     @staticmethod
@@ -77,18 +79,19 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         _flash_attn_backward(
             dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse,
             dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens, cu_seqlens,
-            ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
+            num_splits=1 if ctx.deterministic else 0,
         )
         if rng_state is not None:
             torch.cuda.set_rng_state(cur_rng_state)
-        return dqkv, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None
 
 
 class FlashAttnKVPackedFunc(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
-                softmax_scale, causal, return_softmax):
+                softmax_scale, causal, return_softmax, deterministic):
         # Save rng_state because the backward pass will regenerate the dropout mask
         rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
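Each Function stores the flag on ctx in forward and, as above, converts it to a num_splits argument in backward: 1 forces a single-split backward kernel, while 0 (the prior behavior) lets the implementation choose the split count, which can make the gradient accumulation order vary between runs. A hedged reproducibility check, assuming a CUDA build with this change; the helper name and shapes are illustrative, not from the commit:

import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func

def qkv_grad(deterministic, seed=0):
    # Illustrative helper: fixed seed and input, one forward/backward pass.
    torch.manual_seed(seed)
    qkv = torch.randn(256, 3, 8, 64, device="cuda", dtype=torch.float16,
                      requires_grad=True)
    cu_seqlens = torch.tensor([0, 128, 256], device="cuda", dtype=torch.int32)
    out = flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, 128, 0.0,
                                             deterministic=deterministic)
    out.sum().backward()
    return qkv.grad

# With deterministic=True repeated runs are expected to give bitwise-equal
# gradients; with the default False they may differ in the last bits.
assert torch.equal(qkv_grad(True), qkv_grad(True))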
@@ -103,6 +106,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         ctx.max_seqlen_k = max_seqlen_k
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
 
     @staticmethod
@@ -116,18 +120,19 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         _flash_attn_backward(
             dout, q, kv[:, 0], kv[:, 1], out, softmax_lse,
             dq, dkv[:, 0], dkv[:, 1], cu_seqlens_q, cu_seqlens_k,
-            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
+            num_splits=1 if ctx.deterministic else 0,
         )
         if rng_state is not None:
             torch.cuda.set_rng_state(cur_rng_state)
-        return dq, dkv, None, None, None, None, None, None, None, None
+        return dq, dkv, None, None, None, None, None, None, None, None, None
 
 
 class FlashAttnFunc(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
-                softmax_scale, causal, return_softmax):
+                softmax_scale, causal, return_softmax, deterministic):
         # Save rng_state because the backward pass will regenerate the dropout mask
         rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
@@ -142,6 +147,7 @@ class FlashAttnFunc(torch.autograd.Function):
         ctx.max_seqlen_k = max_seqlen_k
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
 
     @staticmethod
@@ -153,18 +159,19 @@ class FlashAttnFunc(torch.autograd.Function):
         dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
         _flash_attn_backward(
             dout, q, k, v, out, softmax_lse,
             dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
-            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
+            num_splits=1 if ctx.deterministic else 0,
         )
         if rng_state is not None:
             torch.cuda.set_rng_state(cur_rng_state)
-        return dq, dk, dv, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None
 
 
 class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p,
-                softmax_scale, causal, return_softmax):
+                softmax_scale, causal, return_softmax, deterministic):
         # Save rng_state because the backward pass will regenerate the dropout mask
         if dropout_p > 0:
             rng_state0 = torch.cuda.get_rng_state()
@@ -196,6 +203,7 @@ class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
         ctx.batch_size0 = batch_size0
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.deterministic = deterministic
         if not return_softmax:
             return out
         else:
@@ -223,7 +231,7 @@ class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
             dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse0,
             dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[:batch_size0 + 1],
             cu_seqlens[:batch_size0 + 1], ctx.max_seqlen0, ctx.max_seqlen0, ctx.dropout_p,
-            ctx.softmax_scale, ctx.causal
+            ctx.softmax_scale, ctx.causal, num_splits=1 if ctx.deterministic else 0,
         )
         s = torch.cuda.Stream()
         with torch.cuda.stream(s):
@@ -231,16 +239,17 @@ class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
                 dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse1,
                 dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[batch_size0:],
                 cu_seqlens[batch_size0:], ctx.max_seqlen1, ctx.max_seqlen1, ctx.dropout_p,
-                ctx.softmax_scale, ctx.causal, generator=generator1
+                ctx.softmax_scale, ctx.causal, generator=generator1,
+                num_splits=1 if ctx.deterministic else 0,
             )
         torch.cuda.current_stream().wait_stream(s)
         if rng_state0 is not None:
             torch.cuda.set_rng_state(cur_rng_state)
-        return dqkv, None, None, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None, None, None
 
 
 def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None,
-                                       causal=False, return_attn_probs=False):
+                                       causal=False, return_attn_probs=False, deterministic=False):
     """dropout_p should be set to 0.0 during evaluation
     Arguments:
         qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
@@ -254,6 +263,7 @@ def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, s
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
             testing only. The returned probabilities are not guaranteed to be correct
             (they might not have the right scaling).
+        deterministic: bool. Whether or not to ensure deterministic execution.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
@@ -264,12 +274,12 @@ def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, s
             pattern (negative means that location was dropped, nonnegative means it was kept).
     """
     return FlashAttnQKVPackedFunc.apply(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale,
-                                        causal, return_attn_probs)
+                                        causal, return_attn_probs, deterministic)
 
 
 def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                       dropout_p, softmax_scale=None, causal=False,
-                                      return_attn_probs=False):
+                                      return_attn_probs=False, deterministic=False):
     """dropout_p should be set to 0.0 during evaluation
     Arguments:
         q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
@@ -287,6 +297,7 @@ def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seq
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
             testing only. The returned probabilities are not guaranteed to be correct
             (they might not have the right scaling).
+        deterministic: bool. Whether or not to ensure deterministic execution.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
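The KV-packed wrapper exposes the same flag; a cross-attention style sketch, with shapes and the fp16/CUDA setup assumed rather than taken from the commit:

import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func

# Assumed setup: two query sequences of 64 tokens each attending to two
# key/value sequences of 128 tokens each (total_q = 128, total_k = 256).
q = torch.randn(128, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
kv = torch.randn(256, 2, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
cu_seqlens_q = torch.tensor([0, 64, 128], device="cuda", dtype=torch.int32)
cu_seqlens_k = torch.tensor([0, 128, 256], device="cuda", dtype=torch.int32)

out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k,
                                        max_seqlen_q=64, max_seqlen_k=128,
                                        dropout_p=0.0, deterministic=True)
out.sum().backward()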
@@ -298,11 +309,12 @@ def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seq
     """
     return FlashAttnKVPackedFunc.apply(q, kv, cu_seqlens_q, cu_seqlens_k,
                                        max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal,
-                                       return_attn_probs)
+                                       return_attn_probs, deterministic)
 
 
 def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-                             dropout_p, softmax_scale=None, causal=False, return_attn_probs=False):
+                             dropout_p, softmax_scale=None, causal=False, return_attn_probs=False,
+                             deterministic=False):
     """dropout_p should be set to 0.0 during evaluation
     Arguments:
         q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
@@ -321,6 +333,7 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
             testing only. The returned probabilities are not guaranteed to be correct
             (they might not have the right scaling).
+        deterministic: bool. Whether or not to ensure deterministic execution.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
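flash_attn_unpadded_func takes the flag the same way when q, k, and v are passed separately; another hedged sketch with assumed shapes:

import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_func

# Assumed setup: self-attention over two variable-length sequences (128 + 64
# tokens), with separate q, k, v tensors of shape (total, nheads, headdim).
q = torch.randn(192, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(192, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(192, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
cu_seqlens = torch.tensor([0, 128, 192], device="cuda", dtype=torch.int32)

out = flash_attn_unpadded_func(q, k, v, cu_seqlens, cu_seqlens,
                               max_seqlen_q=128, max_seqlen_k=128,
                               dropout_p=0.0, causal=True, deterministic=True)
out.sum().backward()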
@@ -331,12 +344,12 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
             pattern (negative means that location was dropped, nonnegative means it was kept).
     """
     return FlashAttnFunc.apply(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-                               dropout_p, softmax_scale, causal, return_attn_probs)
+                               dropout_p, softmax_scale, causal, return_attn_probs, deterministic)
 
 
 def flash_attn_unpadded_qkvpacked_split_func(
     qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p, softmax_scale=None,
-    causal=False, return_attn_probs=False):
+    causal=False, return_attn_probs=False, deterministic=False):
     """
     Split attention into 2 kernels running on 2 separate streams for performance reason:
     e.g., if the batch has some sequences of length <= 128 and some > 128, it might be faster to
@@ -358,6 +371,7 @@ def flash_attn_unpadded_qkvpacked_split_func(
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
             testing only. The returned probabilities are not guaranteed to be correct
             (they might not have the right scaling).
+        deterministic: bool. Whether or not to ensure deterministic execution.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
@@ -368,7 +382,8 @@ def flash_attn_unpadded_qkvpacked_split_func(
             pattern (negative means that location was dropped, nonnegative means it was kept).
     """
     return FlashAttnQKVPackedSplitFunc.apply(qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0,
-                                             dropout_p, softmax_scale, causal, return_attn_probs)
+                                             dropout_p, softmax_scale, causal, return_attn_probs,
+                                             deterministic)
 
 
 def flash_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False,
...