gaoqiong / flash-attention · Commits

Commit ac3b684c
authored Apr 17, 2023 by Tri Dao

Have a separate nn.Dropout module in SelfAttention module

parent df1344f8
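In short: instead of storing the dropout probability as a bare float (self.dropout_p) and applying it by hand with F.dropout, each attention module now registers a proper nn.Dropout submodule (self.drop). The eager-mode SelfAttention/CrossAttention paths call the module directly, and the FlashAttention call sites read the probability back via self.drop.p. A minimal sketch of the before/after pattern (the ToyBefore/ToyAfter class names are illustrative, not from the repo):

    import torch.nn as nn
    import torch.nn.functional as F

    class ToyBefore(nn.Module):
        def __init__(self, attention_dropout=0.1):
            super().__init__()
            self.dropout_p = attention_dropout  # bare float

        def forward(self, attention):
            # train/eval gating must be done by hand
            return F.dropout(attention, self.dropout_p if self.training else 0.0)

    class ToyAfter(nn.Module):
        def __init__(self, attention_dropout=0.1):
            super().__init__()
            self.drop = nn.Dropout(attention_dropout)  # registered submodule

        def forward(self, attention):
            # nn.Dropout disables itself in eval() mode; the probability
            # stays readable as self.drop.p for code that needs the float
            return self.drop(attention)

One practical upside: as a registered submodule, the dropout now shows up in the module tree (repr, named_modules), so it is visible to hooks and model surgery in a way a bare float is not.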
Showing 1 changed file with 12 additions and 12 deletions.

flash_attn/modules/mha.py  (+12 -12)  · view file @ ac3b684c
@@ -55,7 +55,7 @@ class FlashSelfAttention(nn.Module):
         assert flash_attn_qkvpacked_func is not None, 'FlashAttention Triton is not installed'
         self.causal = causal
         self.softmax_scale = softmax_scale
-        self.dropout_p = attention_dropout
+        self.drop = nn.Dropout(attention_dropout)
         self.triton = triton

     def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
@@ -84,13 +84,13 @@ class FlashSelfAttention(nn.Module):
             assert max_seqlen is not None
             assert isinstance(max_seqlen, int)
             return flash_attn_unpadded_qkvpacked_func(
-                qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0,
+                qkv, cu_seqlens, max_seqlen, self.drop.p if self.training else 0.0,
                 softmax_scale=self.softmax_scale, causal=causal
             )
         else:
             batch_size, seqlen = qkv.shape[0], qkv.shape[1]
             # Triton version doesn't support dropout
-            if self.triton and (self.dropout_p == 0 or not self.training):
+            if self.triton and (self.drop.p == 0 or not self.training):
                 output = flash_attn_qkvpacked_func(qkv, None, causal, self.softmax_scale)
             else:
                 qkv = rearrange(qkv, 'b s ... -> (b s) ...')
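Note that the fused kernels apply dropout internally, so self.drop is not called on the attention matrix here; the call sites pass a plain float, read back as self.drop.p and gated on self.training exactly as the old self.dropout_p was. A hedged sketch of that gating (kernel_dropout_p is a hypothetical helper, not in the repo):

    import torch.nn as nn

    def kernel_dropout_p(module: nn.Module) -> float:
        # Same pattern as the call sites above: the fused kernel receives
        # a float probability, forced to 0.0 outside of training.
        return module.drop.p if module.training else 0.0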
@@ -98,7 +98,7 @@ class FlashSelfAttention(nn.Module):
                 cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                           device=qkv.device)
                 output = flash_attn_unpadded_qkvpacked_func(
-                    qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0,
+                    qkv, cu_seqlens, max_seqlen, self.drop.p if self.training else 0.0,
                     softmax_scale=self.softmax_scale, causal=causal
                 )
                 output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
@@ -124,7 +124,7 @@ class FlashCrossAttention(nn.Module):
         assert flash_attn_kvpacked_func is not None, 'FlashAttention Triton is not installed'
         self.causal = causal
         self.softmax_scale = softmax_scale
-        self.dropout_p = attention_dropout
+        self.drop = nn.Dropout(attention_dropout)
         self.triton = triton

     def forward(self, q, kv, causal=None, cu_seqlens=None, max_seqlen=None,
@@ -156,14 +156,14 @@ class FlashCrossAttention(nn.Module):
             assert isinstance(max_seqlen, int)
             return flash_attn_unpadded_kvpacked_func(
                 q, kv, cu_seqlens, cu_seqlens_k, max_seqlen, max_seqlen_k,
-                self.dropout_p if self.training else 0.0,
+                self.drop.p if self.training else 0.0,
                 softmax_scale=self.softmax_scale, causal=causal
             )
         else:
             batch_size, seqlen_q = q.shape[0], q.shape[1]
             seqlen_k = kv.shape[1]
             assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
-            if self.triton and (self.dropout_p == 0.0 or not self.training):  # Triton version doesn't support dropout
+            if self.triton and (self.drop.p == 0.0 or not self.training):  # Triton version doesn't support dropout
                 output = flash_attn_kvpacked_func(q, kv, None, causal, self.softmax_scale)
             else:
                 q = rearrange(q, 'b s ... -> (b s) ...')
@@ -174,7 +174,7 @@ class FlashCrossAttention(nn.Module):
                                           dtype=torch.int32, device=kv.device)
                 output = flash_attn_unpadded_kvpacked_func(
                     q, kv, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
-                    self.dropout_p if self.training else 0.0,
+                    self.drop.p if self.training else 0.0,
                     softmax_scale=self.softmax_scale, causal=causal
                 )
                 output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
@@ -195,7 +195,7 @@ class SelfAttention(nn.Module):
         super().__init__()
         self.causal = causal
         self.softmax_scale = softmax_scale
-        self.dropout_p = attention_dropout
+        self.drop = nn.Dropout(attention_dropout)

     def forward(self, qkv, causal=None, key_padding_mask=None):
         """Implements the multihead softmax attention.
@@ -224,7 +224,7 @@ class SelfAttention(nn.Module):
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
             scores = scores + causal_mask.to(dtype=scores.dtype)
         attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
-        attention_drop = F.dropout(attention, self.dropout_p if self.training else 0.0)
+        attention_drop = self.drop(attention)
         output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
         return output
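In the eager-mode classes the module is called directly: nn.Dropout is a no-op in eval mode, so self.drop(attention) behaves like the old F.dropout(attention, self.dropout_p if self.training else 0.0). A quick equivalence sketch:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    x = torch.randn(2, 3)
    drop = nn.Dropout(0.5)

    # eval mode: both formulations are the identity
    drop.eval()
    assert torch.equal(drop(x), x)
    assert torch.equal(F.dropout(x, 0.5 if drop.training else 0.0), x)

    # train mode: kept entries are rescaled by 1 / (1 - p) = 2,
    # dropped entries become zero (the random mask differs per call)
    drop.train()
    y = drop(x)
    assert ((y == 0) | torch.isclose(y, 2 * x)).all()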
@@ -243,7 +243,7 @@ class CrossAttention(nn.Module):
         super().__init__()
         self.causal = causal
         self.softmax_scale = softmax_scale
-        self.dropout_p = attention_dropout
+        self.drop = nn.Dropout(attention_dropout)

     def forward(self, q, kv, causal=None, key_padding_mask=None):
         """Implements the multihead softmax attention.
@@ -276,7 +276,7 @@ class CrossAttention(nn.Module):
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
             scores = scores + causal_mask.to(dtype=scores.dtype)
         attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
-        attention_drop = F.dropout(attention, self.dropout_p if self.training else 0.0)
+        attention_drop = self.drop(attention)
         output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
         return output