gaoqiong / flash-attention · Commits

Commit 386e3911 (unverified)
Authored Jan 03, 2024 by jiaxingli; committed by GitHub on Jan 02, 2024

Fix: implement deterministic backward in mha (#748)

* fix deterministic
Parent: 1a2c3e8c

Showing 1 changed file with 8 additions and 2 deletions.
flash_attn/modules/mha.py @ 386e3911 (+8 / -2)

@@ -61,7 +61,7 @@ class FlashSelfAttention(nn.Module):
                            (default: 0.0)
     """

-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None):
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None, deterministic=False):
         super().__init__()
         assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
@@ -69,6 +69,7 @@ class FlashSelfAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
         self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
+        self.deterministic = deterministic

     def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
         """Implements the multihead softmax attention.
@@ -103,6 +104,7 @@ class FlashSelfAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                deterministic=self.deterministic,
             )
         else:
             return flash_attn_qkvpacked_func(
@@ -111,6 +113,7 @@ class FlashSelfAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                deterministic=self.deterministic,
             )
@@ -125,7 +128,7 @@ class FlashCrossAttention(nn.Module):
                            (default: 0.0)
     """

-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None):
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None, deterministic=False):
         super().__init__()
         assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
         assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
@@ -133,6 +136,7 @@ class FlashCrossAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
         self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
+        self.deterministic = deterministic

     def forward(
         self,
@@ -180,6 +184,7 @@ class FlashCrossAttention(nn.Module):
                 softmax_scale=self.softmax_scale,
                 causal=causal,
                 alibi_slopes=self.alibi_slopes,
+                deterministic=self.deterministic,
             )
         else:
             batch_size, seqlen_q = q.shape[0], q.shape[1]
@@ -192,6 +197,7 @@ class FlashCrossAttention(nn.Module):
                 causal=causal,
                 softmax_scale=self.softmax_scale,
                 alibi_slopes=self.alibi_slopes,
+                deterministic=self.deterministic,
             )
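Usage note (not part of the commit): a minimal sketch of the new deterministic flag added to FlashSelfAttention by this change. The module name and forward signature come from the diff above; the tensor shapes, seed, and the bit-identical-gradient check are illustrative assumptions, and the example assumes flash-attn (at or after this commit) is installed on a machine with a CUDA GPU.

# Hypothetical usage sketch of FlashSelfAttention with deterministic backward.
import torch
from flash_attn.modules.mha import FlashSelfAttention

torch.manual_seed(0)
# deterministic=True is forwarded to flash_attn_qkvpacked_func /
# flash_attn_varlen_qkvpacked_func, as shown in the diff above.
attn = FlashSelfAttention(causal=True, attention_dropout=0.0, deterministic=True).cuda()

batch, seqlen, nheads, headdim = 2, 128, 8, 64  # illustrative sizes only
# Packed QKV layout expected by FlashSelfAttention: (batch, seqlen, 3, nheads, headdim)
qkv = torch.randn(batch, seqlen, 3, nheads, headdim,
                  device="cuda", dtype=torch.float16, requires_grad=True)

out = attn(qkv)
out.sum().backward()
grad_first = qkv.grad.clone()

qkv.grad = None
attn(qkv).sum().backward()
# With deterministic=True (and dropout disabled), repeated backward passes on the
# same inputs are expected to produce bit-identical gradients.
assert torch.equal(qkv.grad, grad_first)

Per the flash-attn kernel docstrings, the deterministic backward implementation is slightly slower and uses more memory than the default non-deterministic path; the forward pass is deterministic either way.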