gaoqiong / flash-attention / Commits

Commit db2f8069
Authored Nov 19, 2023 by Tri Dao
Parent: 43bb6d8a

    Write zero to out / grad if seqlen_q or seqlen_k is zero
Showing 3 changed files with 62 additions and 48 deletions:

    csrc/flash_attn/flash_api.cpp            +16   -3
    csrc/flash_attn/src/flash_bwd_kernel.h    +3   -2
    csrc/flash_attn/src/flash_fwd_kernel.h   +43  -43
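Taken together, the change makes the degenerate sequence lengths well defined: with seqlen_k == 0 the forward wrapper no longer launches the kernel and instead zero-fills out and fills softmax_lse with +infinity, and with seqlen_q == 0 the backward wrapper zero-fills dk, dv, and softmax_d; previously these buffers could be left uninitialized because everything exited early. The snippet below is an illustration only, not code from the commit: a minimal libtorch sketch of the forward-side branch with made-up tensor shapes, mirroring the guard added to mha_fwd rather than calling the real API.

// Minimal sketch (not from the commit): mimics the seqlen_k == 0 fallback
// added to mha_fwd, using plain libtorch tensors with hypothetical shapes.
#include <torch/torch.h>
#include <limits>

int main() {
    const int64_t batch_size = 2, seqlen_q = 4, seqlen_k = 0, num_heads = 3, head_size = 64;

    // Buffers that mha_fwd would normally fill by running the CUDA kernel.
    auto out         = torch::empty({batch_size, seqlen_q, num_heads, head_size}, torch::kFloat16);
    auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, torch::kFloat32);

    if (seqlen_k > 0) {
        // run_mha_fwd(params, stream);  // real path: launch the attention kernel
    } else {
        // Empty key sequence: define the result instead of leaving it uninitialized.
        out.zero_();
        softmax_lse.fill_(std::numeric_limits<float>::infinity());
    }
    return 0;
}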
csrc/flash_attn/flash_api.cpp

@@ -405,8 +405,14 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
         params.philox_args = gen->philox_cuda_state(counter_offset);
     }
 
-    auto stream = at::cuda::getCurrentCUDAStream().stream();
-    run_mha_fwd(params, stream);
+    if (seqlen_k > 0) {
+        auto stream = at::cuda::getCurrentCUDAStream().stream();
+        run_mha_fwd(params, stream);
+    } else {
+        // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
+        out.zero_();
+        softmax_lse.fill_(std::numeric_limits<float>::infinity());
+    }
 
     at::Tensor out_padded = out;
     if (head_size_og % 8 != 0) {

@@ -794,7 +800,14 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
         params.rng_state[1] = std::get<1>(seeds);
     }
 
-    launch(params, stream, /*configure=*/false);
+    if (seqlen_q > 0) {
+        launch(params, stream, /*configure=*/false);
+    } else {
+        // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
+        dk.zero_();
+        dv.zero_();
+        softmax_d.zero_();
+    }
 
     // For MQA/GQA we need to sum dK and dV across the groups
     if (num_heads_k != num_heads) {
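mha_bwd gets the mirror-image guard: with no query rows there is nothing to backpropagate through, so dk, dv, and softmax_d are defined as zero rather than left as-is (dq presumably needs no treatment here because it has seqlen_q == 0 rows and is therefore empty). Below is a hedged sketch of the same branch, again with placeholder shapes and not the repository's actual API.

// Minimal sketch (not from the commit): mimics the seqlen_q == 0 fallback
// added to mha_bwd. Shapes are hypothetical placeholders.
#include <torch/torch.h>

int main() {
    const int64_t batch_size = 2, seqlen_q = 0, seqlen_k = 128, num_heads = 3, head_size = 64;

    // Gradient buffers that the backward kernel would normally fill.
    auto dk        = torch::empty({batch_size, seqlen_k, num_heads, head_size});
    auto dv        = torch::empty({batch_size, seqlen_k, num_heads, head_size});
    auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q});

    if (seqlen_q > 0) {
        // launch(params, stream, /*configure=*/false);  // real path
    } else {
        // No query rows: every key/value receives zero gradient.
        dk.zero_();
        dv.zero_();
        softmax_d.zero_();
    }
    return 0;
}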
csrc/flash_attn/src/flash_bwd_kernel.h

@@ -444,7 +444,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bi
     constexpr bool Double_buffer = !Kernel_traits::No_double_buffer;
 
     const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
-    if (n_block * kBlockN >= binfo.actual_seqlen_k || binfo.actual_seqlen_q == 0) return;
+    if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
 
     int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
     if (Is_local) {

@@ -672,7 +672,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bi
     // We might need to exit early and write 0 to dK and dV for those blocks.
     // Otherwise we get wrong result for the case where we don't enter the for loop.
     // And we might read OOB elements from gQ and gdO.
-    if (Is_local && m_block < m_block_min) {
+    // This also covers the case where actual_seqlen_q == 0
+    if ((Is_local || !Is_even_MN) && m_block < m_block_min) {
         const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
           + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
         const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
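The kernel-side half of the fix follows a general CUDA pattern: a thread block that discovers it has no valid work (here, a column block of dK/dV whose query-row loop would never execute) must not simply return, because its slice of the output would then keep whatever the allocation happened to contain; it has to write zeros to its tile first. The toy kernel below is not FlashAttention code; it is a self-contained sketch of that pattern with invented names and shapes.

// Toy illustration (not from the repo): each block owns TILE rows of one
// segment's output. A block whose segment has no valid rows must still
// zero its tile before exiting, or the buffer keeps uninitialized data.
#include <cstdio>
#include <cuda_runtime.h>

constexpr int TILE = 128;

__global__ void process_or_zero(float *out, const int *seg_len, int rows_per_seg) {
    const int seg  = blockIdx.y;          // segment index (think: batch entry)
    const int row0 = blockIdx.x * TILE;   // first row this block owns
    float *tile = out + seg * rows_per_seg + row0;

    if (row0 >= seg_len[seg]) {
        // No valid rows for this block (covers seg_len[seg] == 0): zero, then exit.
        for (int i = threadIdx.x; i < TILE; i += blockDim.x) { tile[i] = 0.f; }
        return;
    }
    // Normal path: "compute" the valid rows, zero the tail of the tile.
    for (int i = threadIdx.x; i < TILE; i += blockDim.x) {
        tile[i] = (row0 + i < seg_len[seg]) ? 1.f : 0.f;
    }
}

int main() {
    const int num_seg = 2, rows_per_seg = 256;
    float *out;   cudaMalloc(&out, num_seg * rows_per_seg * sizeof(float));
    int *seg_len; cudaMalloc(&seg_len, num_seg * sizeof(int));
    const int h_len[num_seg] = {40, 0};   // the second segment is empty
    cudaMemcpy(seg_len, h_len, sizeof(h_len), cudaMemcpyHostToDevice);

    dim3 grid(rows_per_seg / TILE, num_seg);
    process_or_zero<<<grid, 128>>>(out, seg_len, rows_per_seg);
    cudaDeviceSynchronize();

    float h_out[num_seg * rows_per_seg];
    cudaMemcpy(h_out, out, sizeof(h_out), cudaMemcpyDeviceToHost);
    printf("first row of the empty segment: %f (expected 0)\n", h_out[rows_per_seg]);
    cudaFree(out); cudaFree(seg_len);
    return 0;
}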
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -91,7 +91,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;
 
     const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
-    if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
+    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
 
     const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
     int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);

@@ -101,50 +101,50 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
         //     printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
         // }
-        // We exit early and write 0 to gO and gLSE.
-        // Otherwise we might read OOB elements from gK and gV.
-        if (n_block_max <= n_block_min) {
-            // Save seed and offset for backward. If we don't have this here, the 0-th thread block might
-            // exit early and no one saves the rng state.
-            if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
-                auto seeds = at::cuda::philox::unpack(params.philox_args);
-                params.rng_state[0] = std::get<0>(seeds);
-                params.rng_state[1] = std::get<1>(seeds);
-            }
-            const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
-                + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
-            const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
-            Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
-                                    Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                                    make_stride(params.o_row_stride, _1{}));
-            Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
-                                      Shape<Int<kBlockM>>{}, Stride<_1>{});
-            typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
-            auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
-            Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
-            Tensor tOrO = make_tensor<Element>(shape(tOgO));
-            clear(tOrO);
-            // Construct identity layout for sO
-            Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
-            // Repeat the partitioning with identity layouts
-            Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
-            Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
-            if (!Is_even_K) {
-                #pragma unroll
-                for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
-            }
-            // Clear_OOB_K must be false since we don't want to write zeros to gmem
-            flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-                gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
-            );
-            #pragma unroll
-            for (int m = 0; m < size<1>(tOgO); ++m) {
-                const int row = get<0>(tOcO(0, m, 0));
-                if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; }
-            }
-            return;
-        }
     }
+    // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0.
+    // Otherwise we might read OOB elements from gK and gV.
+    if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
+        // Save seed and offset for backward. If we don't have this here, the 0-th thread block might
+        // exit early and no one saves the rng state.
+        if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
+            auto seeds = at::cuda::philox::unpack(params.philox_args);
+            params.rng_state[0] = std::get<0>(seeds);
+            params.rng_state[1] = std::get<1>(seeds);
+        }
+        const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+            + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
+        const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
+        Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
+                                Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                                make_stride(params.o_row_stride, _1{}));
+        Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
+                                  Shape<Int<kBlockM>>{}, Stride<_1>{});
+        typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
+        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
+        Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
+        Tensor tOrO = make_tensor<Element>(shape(tOgO));
+        clear(tOrO);
+        // Construct identity layout for sO
+        Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
+        // Repeat the partitioning with identity layouts
+        Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
+        Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
+        if (!Is_even_K) {
+            #pragma unroll
+            for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
+        }
+        // Clear_OOB_K must be false since we don't want to write zeros to gmem
+        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+            gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
+        );
+        #pragma unroll
+        for (int m = 0; m < size<1>(tOgO); ++m) {
+            const int row = get<0>(tOcO(0, m, 0));
+            if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; }
+        }
+        return;
+    }
     // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); }
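Why the forward kernel can drop its separate actual_seqlen_k == 0 early-return: with zero keys, n_block_max collapses onto n_block_min, so the block falls into the early-exit branch shown above, which zero-fills gO and writes INFINITY into gLSE; adding !Is_even_MN to that branch's condition is what lets blocks in uneven / variable-length batches take this path as well. Below is a small host-side sketch of the index arithmetic, using the same formulas as the hunk but with cute::ceil_div replaced by a plain stand-in and the tile sizes picked arbitrarily.

// Host-side sketch (hypothetical values): with actual_seqlen_k == 0 the block
// range is empty, so the zero-writing early-exit branch is the one taken.
#include <algorithm>
#include <cstdio>

static int ceil_div(int a, int b) { return (a + b - 1) / b; }  // stand-in for cute::ceil_div

int main() {
    const int  kBlockM = 128, kBlockN = 128;
    const int  m_block = 0;
    const int  actual_seqlen_q = 7, actual_seqlen_k = 0;        // empty key sequence
    const bool Is_local = false;
    const int  window_size_left = 0;

    const int n_block_min = !Is_local ? 0
        : std::max(0, (m_block * kBlockM + actual_seqlen_k - actual_seqlen_q - window_size_left) / kBlockN);
    const int n_block_max = ceil_div(actual_seqlen_k, kBlockN); // ceil_div(0, 128) == 0

    printf("n_block_min = %d, n_block_max = %d\n", n_block_min, n_block_max);
    printf("early exit (zero gO, INFINITY to gLSE): %s\n",
           n_block_max <= n_block_min ? "yes" : "no");
    return 0;
}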