gaoqiong / flash-attention · Commits · 1c41d2b0
"...serve/git@developer.sourcefind.cn:guobj/qwen_lmdeploy.git" did not exist on "edb7c6eccf41a378d2a05dc87f8c10924fd324a4"
Commit 1c41d2b0, authored Aug 01, 2023 by Tri Dao
Parent: a4e5d1ed

Fix race condition in bwd (overwriting sK)
Showing 3 changed files with 26 additions and 18 deletions:
  csrc/flash_attn/src/flash_bwd_kernel.h   +5   -3
  setup.py                                 +1   -0
  tests/test_flash_attn.py                 +20  -15
csrc/flash_attn/src/flash_bwd_kernel.h
@@ -1020,9 +1020,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
 ...
     Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV);       // ((Atom,AtomNum), MMA_N, MMA_N)
     Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV);    // ((Atom,AtomNum),PIPE_M,PIPE_N)
 
-    // We don't need syncthreads here since we're writing to the same location as sK and sV.
-    // Unless Is_V_in_regs. If Is_last, there's already a __syncthreads() at the end of the loop.
-    if (Kernel_traits::Is_V_in_regs && !Is_last) { __syncthreads(); }
+    // We need syncthreads here since we're writing to the same location as sK and sV.
+    // Without syncthreads, some thread might modify the location of sK while another thread
+    // is reading it for dQ gemm, leading to a race condition.
+    // If Is_last, there's already a __syncthreads() at the end of the loop.
+    if (!Is_last) { __syncthreads(); }
 
     copy(smem_thr_copy_dKV, taccdKrdK, taccdKsdK);
     copy(smem_thr_copy_dKV, taccdVrdV, taccdVsdV);
 ...
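The change above drops the Is_V_in_regs condition so the barrier is always issued before the dK/dV results are copied into shared memory that aliases sK and sV, except on the last iteration, which already ends with a __syncthreads(). As a standalone illustration of the hazard class (a minimal sketch, not the flash-attention kernel; the kernel name, buffer size, and launch shape are made up for the demo), the following CUDA program reuses a shared-memory buffer that other threads may still be reading, and the guarded barrier is what keeps it race-free:

// Illustration only: a shared-memory buffer that some threads are still reading
// (standing in for sK feeding the dQ gemm) gets overwritten by other threads
// (standing in for the dK/dV epilogue) unless a __syncthreads() separates the
// read phase from the buffer reuse.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void reuse_smem(const float* in, float* out, bool use_barrier) {
    __shared__ float buf[256];
    int tid = threadIdx.x;

    buf[tid] = in[tid];                   // phase 1: stage data into shared memory
    __syncthreads();                      // staged data now visible to all threads

    // phase 2: cross-thread read, like reading sK for the dQ gemm
    float other = buf[(tid + 1) % blockDim.x];

    // phase 3: reuse the same buffer for a new result, like writing dK/dV into
    // smem that aliases sK. Without the barrier, a fast thread can overwrite
    // buf[tid] before the neighbour that reads it in phase 2 has gotten there.
    if (use_barrier) { __syncthreads(); }
    buf[tid] = 2.0f * other;

    __syncthreads();
    out[tid] = buf[tid];
}

int main() {
    const int n = 256;
    float *in, *out;
    cudaMallocManaged(&in, n * sizeof(float));
    cudaMallocManaged(&out, n * sizeof(float));
    for (int i = 0; i < n; ++i) in[i] = float(i);

    reuse_smem<<<1, n>>>(in, out, /*use_barrier=*/true);
    cudaDeviceSynchronize();
    printf("out[0] = %.1f (expect 2.0 with the barrier in place)\n", out[0]);

    cudaFree(in);
    cudaFree(out);
    return 0;
}

Running it with use_barrier=false leaves the cross-thread read in phase 2 and the overwrite in phase 3 unordered, which is the same read/overwrite pattern the commit guards against with the unconditional __syncthreads().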
setup.py
@@ -172,6 +172,7 @@ ext_modules.append(
 ...
                     "--expt-extended-lambda",
                     "--use_fast_math",
                     "--ptxas-options=-v",
+                    # "--ptxas-options=-O2",
                     "-lineinfo"
                 ]
                 + generator_flag
 ...
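For context on the flags around the new (commented-out) line: --ptxas-options=-v asks ptxas for a per-kernel resource report (registers, shared memory, spills), -lineinfo embeds source-line information for profilers such as Nsight Compute, and --ptxas-options=-O2 would dial ptxas back from its default optimization level (-O3) if ever uncommented. A minimal sketch, not part of this repo, showing the same flags on a standalone kernel (file name and kernel are hypothetical):

// Build with the same extra nvcc flags the setup.py passes, e.g.:
//   nvcc -lineinfo --ptxas-options=-v -o axpy axpy.cu
// ptxas then prints something like "Used N registers, M bytes smem" per kernel,
// which is the register/shared-memory readout these flags exist for.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void axpy(int n, float a, const float* x, float* y) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) y[i] = a * x[i] + y[i];    // y <- a*x + y
}

int main() {
    const int n = 1 << 10;
    float *x, *y;
    cudaMallocManaged(&x, n * sizeof(float));
    cudaMallocManaged(&y, n * sizeof(float));
    for (int i = 0; i < n; ++i) { x[i] = 1.0f; y[i] = 2.0f; }

    axpy<<<(n + 255) / 256, 256>>>(n, 3.0f, x, y);
    cudaDeviceSynchronize();
    printf("y[0] = %.1f (expect 5.0)\n", y[0]);

    cudaFree(x);
    cudaFree(y);
    return 0;
}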
tests/test_flash_attn.py
@@ -785,44 +785,49 @@ def test_flash_attn_varlen_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_
 ...
 # @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 @pytest.mark.parametrize('dtype', [torch.float16])
-@pytest.mark.parametrize('causal', [False, True])
-# @pytest.mark.parametrize('causal', [True])
+# @pytest.mark.parametrize('causal', [False, True])
+@pytest.mark.parametrize('causal', [False])
 # @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
-@pytest.mark.parametrize('d', [64])
+@pytest.mark.parametrize('d', [128])
 # @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
-@pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
-# @pytest.mark.parametrize('seqlen', [193])
+# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
+@pytest.mark.parametrize('seqlen', [128])
 # @pytest.mark.parametrize('dropout_p', [0.0, 0.17])
 @pytest.mark.parametrize('dropout_p', [0.0])
 def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
-    if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
-        pytest.skip()  # Reference implementation OOM
     device = 'cuda'
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 32
+    batch_size = 60  # Sometimes we need large batch size for the race conditions to trigger
     nheads = 4
     qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True)
     out0, lse0, _ = flash_attn_qkvpacked_func(
         qkv, dropout_p, return_attn_probs=True, causal=causal
     )
     g = torch.randn_like(out0)
-    dqkv0, = torch.autograd.grad(out0, qkv, g)
+    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+        dqkv0, = torch.autograd.grad(out0, qkv, g)
+        # Numerical error if we just do any arithmetic on dq
+        dq_atol = 2 * ((dqkv0[:, :, 0] + 0.3 - 0.3) - dqkv0[:, :, 0]).abs().max().item()
 
-    for _ in range(200):
+    for i in range(200):
         torch.random.manual_seed(0)
         out, lse, S_dmask = flash_attn_qkvpacked_func(
             qkv, dropout_p, return_attn_probs=True, causal=causal
         )
         assert torch.equal(out, out0)
         assert torch.equal(lse, lse0)
-        if not (is_sm75 and d == 128):
+        # sm_lse has some parts that are uninitialized from torch.empty
+        # assert torch.equal(sm_lse, sm_lse_0)
+        if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
             dqkv, = torch.autograd.grad(out, qkv, g)
-            assert torch.equal(dqkv[:, :, 0], dqkv0[:, :, 0])
+            dq_equal = torch.allclose(dqkv[:, :, 0], dqkv0[:, :, 0], atol=dq_atol)
+            if not dq_equal:
+                dq0 = dqkv0[:, :, 0]
+                dq = dqkv[:, :, 0]
+                print(f'Iter {i}, {dq_atol = }, dQ max diff: {(dqkv[:, :, 0] - dqkv0[:, :, 0]).abs().max().item()}')
+            assert dq_equal
             assert torch.equal(dqkv[:, :, 1], dqkv0[:, :, 1])
             assert torch.equal(dqkv[:, :, 2], dqkv0[:, :, 2])
 ...
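The reworked test detects the race indirectly: with a fixed seed it reruns the same forward and backward 200 times on a large batch and asserts that each run is bitwise identical to the first, because a shared-memory race surfaces as occasional run-to-run nondeterminism rather than a consistent numerical error. A minimal CUDA sketch of that detection pattern (hypothetical kernel, sizes, and iteration count; not the repo's pytest harness):

// Determinism harness: run the same kernel repeatedly on identical input and
// compare every run against the first. A race typically shows up as a sporadic
// mismatch, so many iterations (and, in the real test, a large batch) help.
#include <cstdio>
#include <cstring>
#include <cuda_runtime.h>

__global__ void neighbor_scale(const float* in, float* out) {
    __shared__ float buf[256];
    int t = threadIdx.x;
    buf[t] = in[t];
    __syncthreads();
    float v = buf[(t + 1) % blockDim.x];   // cross-thread read of shared memory
    __syncthreads();                       // the barrier that keeps the reuse below race-free
    buf[t] = 2.0f * v;                     // reuse the buffer for the result
    __syncthreads();
    out[t] = buf[t];
}

int main() {
    const int n = 256, iters = 200;
    float *in, *out;
    float ref[256];
    cudaMallocManaged(&in, n * sizeof(float));
    cudaMallocManaged(&out, n * sizeof(float));
    for (int i = 0; i < n; ++i) in[i] = float(i);

    for (int it = 0; it < iters; ++it) {
        neighbor_scale<<<1, n>>>(in, out);
        cudaDeviceSynchronize();
        if (it == 0) {
            std::memcpy(ref, out, sizeof(ref));            // first run is the reference
        } else if (std::memcmp(ref, out, sizeof(ref)) != 0) {
            printf("iter %d: output differs from run 0 -> nondeterminism\n", it);
            return 1;
        }
    }
    printf("all %d runs matched run 0\n", iters);
    cudaFree(in);
    cudaFree(out);
    return 0;
}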