gaoqiong / flash-attention

Commit c65b5106, authored Aug 16, 2023 by Tri Dao
Fix Bwd NaN for varlen when seqlen_q >> seqlen_k and causal
Parent: 0f7853c6

Showing 4 changed files with 43 additions and 3 deletions (+43, -3)
.github/workflows/publish.yml           +4  -1
csrc/flash_attn/src/flash_bwd_kernel.h  +5  -1
flash_attn/__init__.py                  +1  -1
tests/test_flash_attn.py                +33 -0
.github/workflows/publish.yml
@@ -43,7 +43,7 @@ jobs:
       # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
       # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
       os: [ubuntu-20.04]
-      python-version: ['3.7', '3.8', '3.9', '3.10']
+      python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
       torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.0.dev20230731']
       cuda-version: ['11.6.2', '11.7.1', '11.8.0', '12.1.0']
       # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
@@ -52,6 +52,9 @@ jobs:
       # when building without C++11 ABI and using it on nvcr images.
       cxx11_abi: ['FALSE', 'TRUE']
       exclude:
+        # Pytorch <= 1.12 does not support Python 3.11
+        - torch-version: '1.12'
+          python-version: '3.11'
         # Pytorch >= 2.0 only supports Python >= 3.8
         - torch-version: '2.0.1'
           python-version: '3.7'
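For context, each value added to the matrix above fans out into one wheel-build job per (python, torch, cuda, cxx11_abi) combination, with the exclude entries pruning unsupported pairs. The sketch below is not part of the commit; it only enumerates the matrix values shown in this diff and filters them according to the intent of the two exclude rules, so the effect of adding '3.11' can be eyeballed.

```python
# Illustration only, not part of the workflow: enumerate the build matrix from
# publish.yml and drop combinations according to the intent of the excludes.
from itertools import product

python_versions = ['3.7', '3.8', '3.9', '3.10', '3.11']
torch_versions = ['1.12.1', '1.13.1', '2.0.1', '2.1.0.dev20230731']
cuda_versions = ['11.6.2', '11.7.1', '11.8.0', '12.1.0']
cxx11_abis = ['FALSE', 'TRUE']

def excluded(python_v, torch_v):
    # Mirrors the intent of the `exclude:` entries above.
    if torch_v.startswith('1.12') and python_v == '3.11':
        return True  # Pytorch <= 1.12 does not support Python 3.11
    if torch_v == '2.0.1' and python_v == '3.7':
        return True  # Pytorch >= 2.0 only supports Python >= 3.8
    return False

jobs = [combo for combo in product(python_versions, torch_versions, cuda_versions, cxx11_abis)
        if not excluded(combo[0], combo[1])]
print(f"{len(jobs)} wheel-build jobs after exclusions")
```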
csrc/flash_attn/src/flash_bwd_kernel.h
@@ -820,7 +820,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             }
         } else {
             // Putting this causal masking right after acc_s is *much* slower for some reason.
-            if (m_block * kBlockM < (n_block + 1) * kBlockN) {
+            // TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short
+            // (e.g., 256 and 2), the 2nd block of seqlen_q (from 128 to 255), we're not doing causal masking.
+            // But we still want to mask out elements beyond actual_seqlen_k.
+            if (m_block * kBlockM < (n_block + 1) * kBlockN
+                || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
                 flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
                                          binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
                                          // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
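To see why the extra ||-clause is needed, take the example from the comment: seqlen_q = 256 and seqlen_k = 2. Assuming kBlockM = kBlockN = 128 (an illustrative choice; the actual block sizes depend on the head dimension), the second query block (m_block = 1) lies entirely below the causal diagonal of the only key block (n_block = 0), so the old test 128 < 128 is false and no mask was applied, even though key columns 2..127 fall beyond actual_seqlen_k. A minimal Python sketch of the two predicates, not the CUDA kernel itself:

```python
# Minimal sketch: evaluate the old vs. new masking predicate from
# compute_dq_dk_dv_1colblock for the example given in the code comment.
# kBlockM = kBlockN = 128 is an assumption for illustration.

kBlockM = kBlockN = 128

def needs_mask_old(m_block, n_block):
    # Pre-fix condition: mask only blocks that straddle the causal diagonal.
    return m_block * kBlockM < (n_block + 1) * kBlockN

def needs_mask_new(m_block, n_block, actual_seqlen_k, is_even_mn):
    # Post-fix condition: also mask the key block when it runs past
    # actual_seqlen_k (the varlen / seqlen_k % 128 != 0 case).
    return (m_block * kBlockM < (n_block + 1) * kBlockN
            or (not is_even_mn and (n_block + 1) * kBlockN >= actual_seqlen_k))

# seqlen_q = 256, seqlen_k = 2: query block 1 covers rows 128..255,
# key block 0 covers columns 0..127 but only columns 0..1 are real.
print(needs_mask_old(m_block=1, n_block=0))
# False -> columns beyond seqlen_k stay unmasked, giving NaN in dQ
print(needs_mask_new(m_block=1, n_block=0, actual_seqlen_k=2, is_even_mn=False))
# True  -> columns >= actual_seqlen_k get masked
```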
flash_attn/__init__.py
-__version__ = "2.0.7"
+__version__ = "2.0.8"
 from flash_attn.flash_attn_interface import flash_attn_func
 from flash_attn.flash_attn_interface import flash_attn_kvpacked_func
...
tests/test_flash_attn.py
@@ -924,3 +924,36 @@ def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype):
     assert (q.grad - q_ref.grad).abs().max().item() <= 2 * (q_pt.grad - q_ref.grad).abs().max().item()
     assert (k.grad - k_ref.grad).abs().max().item() <= 2 * (k_pt.grad - k_ref.grad).abs().max().item()
     assert (v.grad - v_ref.grad).abs().max().item() <= 2 * (v_pt.grad - v_ref.grad).abs().max().item()
+
+
+@pytest.mark.parametrize('dtype', [torch.float16])
+@pytest.mark.parametrize('causal', [False, True])
+# @pytest.mark.parametrize('causal', [False])
+@pytest.mark.parametrize('d', [16, 32, 64])
+# @pytest.mark.parametrize('d', [16])
+def test_flash_attn_bwd_varlen_overflow(d, causal, dtype):
+    """ We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
+    in the case where seqlen % 128 != 0 or varlen.
+    """
+    device = 'cuda'
+    # set seed
+    torch.random.manual_seed(0)
+    nheads = 5
+    q_cuseqlen = torch.tensor([0, 76, 110, 256], device=device, dtype=torch.int32)
+    k_cuseqlen = torch.tensor([0, 1, 2, 3], device=device, dtype=torch.int32)
+    Mq = 256
+    Mk = 3
+
+    q = torch.randn([Mq, nheads, d], dtype=dtype, device=device) * 3
+    k, v = [torch.randn([Mk, nheads, d], dtype=dtype, device=device) * 3 for _ in range(2)]
+    q.requires_grad_(True)
+    k.requires_grad_(True)
+    v.requires_grad_(True)
+
+    out = flash_attn_varlen_func(q, k, v, q_cuseqlen, k_cuseqlen, Mq, Mk, causal=causal)
+    g = torch.randn_like(out)
+    out.backward(g)
+
+    assert not q.grad.isnan().any()
+    assert not k.grad.isnan().any()
+    assert not v.grad.isnan().any()
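The new test packs three variable-length sequences into each of q and k/v: cu_seqlens_q = [0, 76, 110, 256] gives query lengths 76, 34 and 146, while cu_seqlens_k = [0, 1, 2, 3] gives a single key per sequence, which is exactly the seqlen_q >> seqlen_k regime from the commit title. A small illustrative snippet (not part of the test) showing how the per-sequence lengths fall out of the cumulative offsets:

```python
# Illustration only: recover per-sequence lengths from the cumulative
# sequence-length tensors passed to flash_attn_varlen_func in the new test.
q_cuseqlen = [0, 76, 110, 256]   # boundaries of the packed q tensor (total_q = 256)
k_cuseqlen = [0, 1, 2, 3]        # boundaries of the packed k/v tensors (total_k = 3)

seqlens_q = [b - a for a, b in zip(q_cuseqlen, q_cuseqlen[1:])]  # [76, 34, 146]
seqlens_k = [b - a for a, b in zip(k_cuseqlen, k_cuseqlen[1:])]  # [1, 1, 1]

for i, (sq, sk) in enumerate(zip(seqlens_q, seqlens_k)):
    # Each batch element has a long query sequence attending to a single key,
    # so every key block extends far past actual_seqlen_k and must be masked.
    print(f"seq {i}: seqlen_q={sq:3d}  seqlen_k={sk}")
```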