gaoqiong / flash-attention · Commits

Commit 6998e0ec
authored Nov 09, 2022 by Tri Dao

Fix out-of-bound memory read

parent 908a5b22
Showing 4 changed files with 28 additions and 3 deletions (+28, -3):

    csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu   +1  -1
    csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h          +18 -0
    csrc/flash_attn/src/fmha_fprop_kernel_1xN.h                +2 -0
    tests/test_flash_attn.py                                    +7 -2
csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu

@@ -7,7 +7,7 @@
 #include "fmha_dgrad_kernel_1xN_loop.h"

 // Pick whether we should parallelize across seqlen_k (num_splits > 1) or not (num_splits=1).
 // Parallelizing will have better occupancy, but has some overhead due to having to zero out
-// dq
+// dq_tmp and having to copy dq_tmp to dq.
 int num_splits_heuristic_bwd(int batch_nheads, int num_SMs, int ctas_per_sm, int seqlen, int blocksize, bool is_causal) {
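The changed comment documents the trade-off behind num_splits_heuristic_bwd: splitting the backward pass across seqlen_k (num_splits > 1) keeps more SMs busy, but each split must zero a temporary dq_tmp buffer and copy it back into dq. The Python sketch below is only a toy model of that trade-off; the utilization formula, the 5% per-split penalty, and the candidate split counts are assumptions for illustration, not the heuristic implemented in this file (which also considers seqlen, blocksize, and is_causal).

import math

def toy_num_splits_heuristic_bwd(batch_nheads, num_SMs, ctas_per_sm, candidates=(1, 2, 4, 8)):
    """Toy model of the occupancy-vs-overhead trade-off described in the comment.

    Assumed scoring: SM-slot utilization per wave, discounted by a made-up 5% cost
    per extra split for zeroing dq_tmp and copying it back into dq.
    """
    capacity = num_SMs * ctas_per_sm
    best, best_score = 1, -1.0
    for num_splits in candidates:
        n_blocks = batch_nheads * num_splits           # CTAs launched
        n_waves = math.ceil(n_blocks / capacity)       # times the GPU is refilled
        utilization = n_blocks / (capacity * n_waves)  # fraction of SM slots kept busy
        score = utilization * (1.0 - 0.05 * (num_splits - 1))
        if score > best_score:
            best, best_score = num_splits, score
    return best

# Few batch*head blocks -> splitting pays off; many blocks -> keep num_splits small.
print(toy_num_splits_heuristic_bwd(batch_nheads=8, num_SMs=108, ctas_per_sm=2))    # 8
print(toy_num_splits_heuristic_bwd(batch_nheads=512, num_SMs=108, ctas_per_sm=2))  # 2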
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h

@@ -271,6 +271,24 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
     int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0;
+    // Otherwise we'd be reading out-of-bound memory before the loop
+    if (begin * Cta_tile_p::M >= binfo.actual_seqlen_q) {
+        // Still need to zero out dk and dv before returning
+        static_assert(Smem_tile_dk::NUM_LDS == Smem_tile_dv::NUM_LDS);
+        uint4 dkv_out[Smem_tile_dk::NUM_LDS];
+        #pragma unroll
+        for (int i = 0; i < Smem_tile_dk::NUM_LDS; ++i) {
+            dkv_out[i] = make_uint4(0u, 0u, 0u, 0u);
+        }
+        Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts,
+                             params.dk_head_stride_in_elts, params.d, binfo, tidx, false);
+        if (!Is_first) { gmem_dk.move(loop_step_idx); }
+        gmem_dk.store(dkv_out);
+        Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts,
+                             params.dv_head_stride_in_elts, params.d, binfo, tidx, false);
+        if (!Is_first) { gmem_dv.move(loop_step_idx); }
+        gmem_dv.store(dkv_out);
+        return;
+    }
     const int steps = (params.seqlen_q + Cta_tile_p::M - 1) / Cta_tile_p::M - begin;
     // Wind gmem tiles to the correct position.
     gmem_q.move(begin);
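The intent of the added block can be checked with plain index arithmetic. In the causal backward pass, the K/V block handled at loop_step_idx only sees query row-blocks from begin onward; when begin * M is already past the end of the sequence there is nothing to read, but dk and dv for this block still have to be written out as zeros. Below is a small Python illustration of that check with made-up tile sizes (M and N stand in for Cta_tile_p::M and Cta_tile_p::N, and a single seqlen plays both params.seqlen_q and binfo.actual_seqlen_q); it is not the CUDA kernel.

def dgrad_block_plan(loop_step_idx, seqlen_q, M=16, N=128, is_causal=True):
    """Which query row-blocks the K/V block at loop_step_idx would process."""
    assert N % M == 0
    begin = loop_step_idx * N // M if is_causal else 0
    if begin * M >= seqlen_q:
        # The early-return case added by this commit: nothing to read, but the
        # kernel still stores zeroed dk/dv tiles for this block before returning.
        return None
    steps = (seqlen_q + M - 1) // M - begin
    return begin, steps

# For a 100-token sequence with a 128-wide K/V block, every query row that could
# attend to the second K/V block lies past the end of the sequence:
print(dgrad_block_plan(loop_step_idx=0, seqlen_q=100))  # (0, 7)
print(dgrad_block_plan(loop_step_idx=1, seqlen_q=100))  # None -> write zero dk/dv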
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h

@@ -280,6 +280,8 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     // k * gridDim.z + 1 for integer k.
     const int begin_mod_z = begin % gridDim.z;
     begin = begin_mod_z <= blockIdx.z ? begin - begin_mod_z : begin + gridDim.z - begin_mod_z;
+    // Otherwise we'd be reading out-of-bound memory before the loop
+    if ((begin + blockIdx.z) * Cta_tile_p::M >= binfo.actual_seqlen_q) return;
     const int steps_og = steps;
     steps -= begin;
     gmem_q.move(begin + blockIdx.z);
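Here the forward kernel is parallelized over blockIdx.z, so the z-th block processes query row-blocks begin + blockIdx.z, begin + blockIdx.z + gridDim.z, and so on; the added check bails out when that first row-block already lies past binfo.actual_seqlen_q. The Python sketch below only mirrors the arithmetic of the two added lines plus the snapping of begin that precedes them (whose motivating comment is partly cut off in this view); the concrete numbers are made up.

def fprop_split_start(begin, blockIdx_z, gridDim_z, M, actual_seqlen_q):
    """First row-block index for one z-split, or None if it has no work."""
    # Pre-existing code: snap `begin` so that the row-blocks owned by this z-split
    # stay aligned to the begin + blockIdx_z + k * gridDim_z pattern.
    begin_mod_z = begin % gridDim_z
    begin = begin - begin_mod_z if begin_mod_z <= blockIdx_z else begin + gridDim_z - begin_mod_z
    # Check added by this commit: returning here avoids reading out-of-bound
    # memory before the main loop for splits that start past the sequence end.
    if (begin + blockIdx_z) * M >= actual_seqlen_q:
        return None
    return begin

# With 4 splits, 16-row blocks and a 40-row sequence, only one split has work left:
for z in range(4):
    print(z, fprop_split_start(begin=2, blockIdx_z=z, gridDim_z=4, M=16, actual_seqlen_q=40))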
tests/test_flash_attn.py

@@ -12,6 +12,11 @@ from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_unpadded
 from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_split_func
 from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
+try:
+    from flash_attn.flash_attn_triton import flash_attn_func
+except (ImportError, AttributeError):  # Older version of Triton doesn't have tl.constexpr
+    flash_attn_func = None
+

 is_sm75 = torch.cuda.get_device_capability('cuda') == (7, 5)
 is_sm80 = torch.cuda.get_device_capability('cuda') == (8, 0)

@@ -857,9 +862,8 @@ def test_flash_attn_multigpu():
     assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()

-from flash_attn.flash_attn_triton import flash_attn_func
-
+@pytest.mark.skipif(flash_attn_func is None, reason='Triton is not installed or is too old')
 @pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100')
 @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])

@@ -930,6 +934,7 @@ def test_flash_attn_triton_output(seqlen_q, seqlen_k, d, causal, dtype, bias_sha
     assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()

+@pytest.mark.skipif(flash_attn_func is None, reason='Triton is not installed or is too old')
 @pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100')
 @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])
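The test change follows a common pytest pattern: import an optional dependency at module scope inside try/except, fall back to None, and gate the affected tests with skipif so a missing or too-old Triton makes them skip instead of breaking test collection. A generic sketch of that pattern (module and function names here are placeholders, not real flash-attention names):

# Minimal sketch of the optional-dependency pattern used in tests/test_flash_attn.py.
import pytest

try:
    from some_optional_lib import fast_op   # hypothetical optional dependency
except (ImportError, AttributeError):       # old versions may lack the needed symbols
    fast_op = None

@pytest.mark.skipif(fast_op is None, reason='some_optional_lib is not installed or is too old')
def test_fast_op():
    assert fast_op(2) == 4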