gaoqiong / flash-attention · Commits

Commit b6aa059b, authored Mar 30, 2023 by Kirthi Shankar Sivamani
Parent: 009a3e71

Add option for deterministic execution

Showing 1 changed file with 36 additions and 21 deletions.

flash_attn/flash_attn_interface.py (+36 -21)
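In short, the commit threads a new `deterministic` flag through every `torch.autograd.Function` and public wrapper in `flash_attn_interface.py`; when the flag is set, the backward kernels are launched with `num_splits=1` so gradients accumulate in a fixed order. A minimal usage sketch (tensor shapes and values are illustrative, not from the commit):

```python
import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func

# Packed QKV for a batch of 4 sequences of length 128 each (shapes illustrative).
total, nheads, headdim = 512, 8, 64
qkv = torch.randn(total, 3, nheads, headdim,
                  device='cuda', dtype=torch.float16, requires_grad=True)
# cu_seqlens holds cumulative sequence boundaries: [0, 128, 256, 384, 512].
cu_seqlens = torch.arange(0, total + 1, 128, device='cuda', dtype=torch.int32)

out = flash_attn_unpadded_qkvpacked_func(
    qkv, cu_seqlens, 128, 0.1,
    deterministic=True,  # new in this commit: backward runs with num_splits=1
)
out.sum().backward()  # repeated runs now produce identical qkv.grad
```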
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -50,7 +50,8 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens
 class FlashAttnQKVPackedFunc(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax):
+    def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax,
+                deterministic):
         # Save rng_state because the backward pass will regenerate the dropout mask
         rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
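The comment in this hunk points at a pattern that recurs in the backward hunks below: the forward pass snapshots the CUDA RNG state, and the backward pass rewinds to it so the dropout mask is regenerated identically, then restores the current state afterwards. A standalone sketch of that pattern, using only real PyTorch calls (the function bodies are placeholders, not the fused kernels):

```python
import torch

def forward_step(x, dropout_p):
    # Snapshot the CUDA RNG state before the dropout-consuming kernel runs.
    rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
    out = torch.nn.functional.dropout(x, dropout_p)  # stand-in for the fused kernel
    return out, rng_state

def backward_step(rng_state):
    if rng_state is not None:
        cur_rng_state = torch.cuda.get_rng_state()  # remember current position
        torch.cuda.set_rng_state(rng_state)         # rewind: dropout mask replays
        # ... run the backward kernel here; it regenerates the same mask ...
        torch.cuda.set_rng_state(cur_rng_state)     # restore for subsequent ops
```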
@@ -65,6 +66,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         ctx.max_seqlen = max_seqlen
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
 
     @staticmethod
@@ -77,18 +79,19 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         _flash_attn_backward(
             dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse,
             dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens, cu_seqlens,
-            ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
+            num_splits=1 if ctx.deterministic else 0,
         )
         if rng_state is not None:
             torch.cuda.set_rng_state(cur_rng_state)
-        return dqkv, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None
 
 
 class FlashAttnKVPackedFunc(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
-                softmax_scale, causal, return_softmax):
+                softmax_scale, causal, return_softmax, deterministic):
         # Save rng_state because the backward pass will regenerate the dropout mask
         rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
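The `num_splits=1` argument forces the backward kernel to accumulate dq/dk/dv in a single pass, rather than letting several thread blocks add partial results in whatever order they finish (`0` appears to leave the split count to the implementation's heuristic). Order matters because floating-point addition is not associative; a quick illustration unrelated to the kernel itself:

```python
import torch

x = torch.randn(100_000, dtype=torch.float32)
forward_sum = x.sum()
reverse_sum = x.flip(0).sum()
# Same numbers, different accumulation order: the results usually differ slightly.
print((forward_sum - reverse_sum).abs().item())
```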
@@ -103,6 +106,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         ctx.max_seqlen_k = max_seqlen_k
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
 
     @staticmethod
@@ -116,18 +120,19 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         _flash_attn_backward(
             dout, q, kv[:, 0], kv[:, 1], out, softmax_lse,
             dq, dkv[:, 0], dkv[:, 1], cu_seqlens_q, cu_seqlens_k,
-            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
+            num_splits=1 if ctx.deterministic else 0,
         )
         if rng_state is not None:
             torch.cuda.set_rng_state(cur_rng_state)
-        return dq, dkv, None, None, None, None, None, None, None, None
+        return dq, dkv, None, None, None, None, None, None, None, None, None
 
 
 class FlashAttnFunc(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
-                softmax_scale, causal, return_softmax):
+                softmax_scale, causal, return_softmax, deterministic):
         # Save rng_state because the backward pass will regenerate the dropout mask
         rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
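Each `backward` in this diff also returns one more `None` than before. `torch.autograd.Function` requires `backward` to return exactly one gradient per argument of `forward`, and non-differentiable inputs such as the new `deterministic` flag take a `None` slot. A self-contained toy example of the rule (not code from this repository):

```python
import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor, deterministic):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, dout):
        # One return value per forward input: a gradient for x,
        # None for factor, None for the new deterministic flag.
        return dout * ctx.factor, None, None

x = torch.ones(3, requires_grad=True)
Scale.apply(x, 2.0, True).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```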
@@ -142,6 +147,7 @@ class FlashAttnFunc(torch.autograd.Function):
         ctx.max_seqlen_k = max_seqlen_k
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
 
     @staticmethod
@@ -153,18 +159,19 @@ class FlashAttnFunc(torch.autograd.Function):
         dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
         _flash_attn_backward(
             dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
-            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal
+            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
+            num_splits=1 if ctx.deterministic else 0,
         )
         if rng_state is not None:
             torch.cuda.set_rng_state(cur_rng_state)
-        return dq, dk, dv, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None
 
 
 class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p,
-                softmax_scale, causal, return_softmax):
+                softmax_scale, causal, return_softmax, deterministic):
         # Save rng_state because the backward pass will regenerate the dropout mask
         if dropout_p > 0:
             rng_state0 = torch.cuda.get_rng_state()
@@ -196,6 +203,7 @@ class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
         ctx.batch_size0 = batch_size0
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
+        ctx.deterministic = deterministic
         if not return_softmax:
             return out
         else:
@@ -223,7 +231,7 @@ class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
             dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse0,
             dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[:batch_size0 + 1],
             cu_seqlens[:batch_size0 + 1], ctx.max_seqlen0, ctx.max_seqlen0, ctx.dropout_p,
-            ctx.softmax_scale, ctx.causal
+            ctx.softmax_scale, ctx.causal, num_splits=1 if ctx.deterministic else 0,
         )
         s = torch.cuda.Stream()
         with torch.cuda.stream(s):
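The split variant launches its second backward call on a side CUDA stream (`s`) so the two halves of the batch can overlap; the next hunk shows the matching `wait_stream` that joins the streams before the gradients are used. A minimal sketch of that launch/join pattern, with placeholder work standing in for the `_flash_attn_backward` calls:

```python
import torch

def half_backward(x):
    return x * 2  # placeholder for one _flash_attn_backward call

x = torch.ones(4, device='cuda')
y0 = half_backward(x)                 # first half runs on the default stream
s = torch.cuda.Stream()
with torch.cuda.stream(s):
    y1 = half_backward(x)             # second half enqueued on side stream s
torch.cuda.current_stream().wait_stream(s)  # join: default stream waits for s
print(y0 + y1)
```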
@@ -231,16 +239,17 @@ class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
                 dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse1,
                 dqkv[:, 0], dqkv[:, 1], dqkv[:, 2],
                 cu_seqlens[batch_size0:], cu_seqlens[batch_size0:],
                 ctx.max_seqlen1, ctx.max_seqlen1, ctx.dropout_p,
-                ctx.softmax_scale, ctx.causal, generator=generator1
+                ctx.softmax_scale, ctx.causal, generator=generator1,
+                num_splits=1 if ctx.deterministic else 0,
             )
         torch.cuda.current_stream().wait_stream(s)
         if rng_state0 is not None:
             torch.cuda.set_rng_state(cur_rng_state)
-        return dqkv, None, None, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None, None, None
 
 
 def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None,
-                                       causal=False, return_attn_probs=False):
+                                       causal=False, return_attn_probs=False, deterministic=False):
     """dropout_p should be set to 0.0 during evaluation
     Arguments:
         qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
@@ -254,6 +263,7 @@ def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, s
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
             testing only. The returned probabilities are not guaranteed to be correct
             (they might not have the right scaling).
+        deterministic: bool. Whether or not to ensure deterministic execution.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
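The new docstring line is terse; in practice "deterministic execution" means bitwise-reproducible gradients across runs, at some speed cost from the single-split backward. A sketch of how one might check it (assumes a CUDA device with flash-attn installed; shapes are illustrative):

```python
import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func

def run_once(deterministic):
    torch.manual_seed(0)
    qkv = torch.randn(256, 3, 8, 64, device='cuda', dtype=torch.float16,
                      requires_grad=True)
    cu_seqlens = torch.tensor([0, 128, 256], device='cuda', dtype=torch.int32)
    out = flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, 128, 0.0,
                                             deterministic=deterministic)
    out.sum().backward()
    return qkv.grad.clone()

# With deterministic=True, repeated runs should match bitwise.
assert torch.equal(run_once(True), run_once(True))
```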
@@ -264,12 +274,12 @@ def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, s
         pattern (negative means that location was dropped, nonnegative means it was kept).
     """
     return FlashAttnQKVPackedFunc.apply(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale,
-                                        causal, return_attn_probs)
+                                        causal, return_attn_probs, deterministic)
 
 
 def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                       dropout_p, softmax_scale=None, causal=False,
-                                      return_attn_probs=False):
+                                      return_attn_probs=False, deterministic=False):
     """dropout_p should be set to 0.0 during evaluation
     Arguments:
         q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
@@ -287,6 +297,7 @@ def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seq
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
             testing only. The returned probabilities are not guaranteed to be correct
             (they might not have the right scaling).
+        deterministic: bool. Whether or not to ensure deterministic execution.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
@@ -298,11 +309,12 @@ def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seq
     """
     return FlashAttnKVPackedFunc.apply(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-                                       dropout_p, softmax_scale, causal, return_attn_probs)
+                                       dropout_p, softmax_scale, causal, return_attn_probs, deterministic)
 
 
 def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-                             dropout_p, softmax_scale=None, causal=False, return_attn_probs=False):
+                             dropout_p, softmax_scale=None, causal=False, return_attn_probs=False,
+                             deterministic=False):
     """dropout_p should be set to 0.0 during evaluation
     Arguments:
         q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
@@ -321,6 +333,7 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
             testing only. The returned probabilities are not guaranteed to be correct
             (they might not have the right scaling).
+        deterministic: bool. Whether or not to ensure deterministic execution.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
@@ -331,12 +344,12 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
         pattern (negative means that location was dropped, nonnegative means it was kept).
     """
     return FlashAttnFunc.apply(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-                               dropout_p, softmax_scale, causal, return_attn_probs)
+                               dropout_p, softmax_scale, causal, return_attn_probs, deterministic)
 
 
 def flash_attn_unpadded_qkvpacked_split_func(qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0,
                                              dropout_p, softmax_scale=None, causal=False,
-                                             return_attn_probs=False):
+                                             return_attn_probs=False, deterministic=False):
     """
     Split attention into 2 kernels running on 2 separate streams for performance reason:
     e.g., if the batch has some sequences of length <= 128 and some > 128, it might be faster to
@@ -358,6 +371,7 @@ def flash_attn_unpadded_qkvpacked_split_func(
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
             testing only. The returned probabilities are not guaranteed to be correct
             (they might not have the right scaling).
+        deterministic: bool. Whether or not to ensure deterministic execution.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
@@ -368,7 +382,8 @@ def flash_attn_unpadded_qkvpacked_split_func(
         pattern (negative means that location was dropped, nonnegative means it was kept).
     """
     return FlashAttnQKVPackedSplitFunc.apply(qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0,
-                                             dropout_p, softmax_scale, causal, return_attn_probs)
+                                             dropout_p, softmax_scale, causal, return_attn_probs,
+                                             deterministic)
 
 
 def flash_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False,