gaoqiong / flash-attention

Commit a4f148b6, authored Jul 31, 2023 by Tri Dao
Fix masking of bwd when seqlen is not divisible by 128
Parent: 184b992d

Showing 4 changed files with 145 additions and 37 deletions:

  csrc/flash_attn/src/flash_bwd_kernel.h           +37  -26
  csrc/flash_attn/src/flash_bwd_launch_template.h   +9   -9
  csrc/flash_attn/src/softmax.h                     +5   -2
  tests/test_flash_attn.py                         +94   -0
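
In short, the commit does two things, both visible in the hunks below: the backward kernels now specialize on Is_even_MN (both seqlen_q % kBlockM == 0 and seqlen_k % kBlockN == 0) instead of Is_even_M, and the non-causal backward pass now masks score columns at or beyond actual_seqlen_k before exponentiation. As an aside (not part of the commit), a minimal PyTorch sketch of the fp16 overflow that the new in-kernel comment describes, where an unmasked garbage score turns into Inf in dS and then NaN in dQ:

import torch

# Illustration only: a product that is finite in fp32 overflows once cast to fp16,
# and any subsequent 0 * inf produces NaN, which then propagates into dQ.
acc_s = torch.tensor(300.0)               # garbage value in an out-of-range, unmasked column
dP = torch.tensor(400.0)
dS_fp16 = (acc_s * dP).to(torch.float16)  # 120000 > 65504 (fp16 max) -> inf
print(dS_fp16)                            # tensor(inf, dtype=torch.float16)
print(dS_fp16 * 0.0)                      # tensor(nan, dtype=torch.float16)
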

csrc/flash_attn/src/flash_bwd_kernel.h

@@ -415,7 +415,7 @@ inline __device__ void convert_dKV(const Params &params) {

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_M, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
 inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {

     using Element = typename Kernel_traits::Element;
@@ -436,7 +436,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
     constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP;
     constexpr bool Double_buffer = !Kernel_traits::No_double_buffer;

-    const BlockInfo</*Varlen=*/!Is_even_M> binfo(params, bidb);
+    const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
     if (n_block * kBlockN >= binfo.actual_seqlen_k || binfo.actual_seqlen_q == 0) return;

     int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
@@ -668,10 +668,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
         #pragma unroll
         for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
         // Clear_OOB_K must be false since we don't want to write zeros to gmem
-        flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
            gmem_thr_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
        );
-        flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
            gmem_thr_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
        );
        return;
@@ -687,7 +687,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
     if (Kernel_traits::Is_V_in_regs) {
         // Clear the smem tiles to account for predicated off loads
-        flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/true>(
+        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
            gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
        );
         flash::cp_async_fence();
@@ -697,18 +697,18 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
     Tensor tdOrO = make_fragment_like(tdOgO);
     if (!Is_first) {
         // Clear the smem tiles to account for predicated off loads
-        flash::copy<Is_even_M, Is_even_K, /*Clear_OOB_MN=*/true>(
+        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
            gmem_thr_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
        );
     } else {
-        flash::copy<Is_even_M, Is_even_K, /*Clear_OOB_MN=*/true>(
+        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
            gmem_thr_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
        );
-        flash::copy<Is_even_M, Is_even_K, /*Clear_OOB_MN=*/true>(
+        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
            gmem_thr_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
        );
     }
-    flash::copy<Is_even_M, Is_even_K, /*Clear_OOB_MN=*/true>(
+    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
        gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
    );
@@ -722,7 +722,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
     for (int mi = 0; mi < size(lse); ++mi) {
         // Using uint32_t row makes it 10us slower on d=128, not sure why.
         const int row = get<0>(taccScS_row(mi));
-        lse(mi) = Is_even_M || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0;
+        lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0;
     }

     // Tensor tKrK = make_fragment_like(tKsK);
@@ -730,11 +730,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
     // copy(gmem_thr_copy_QKV, tKgK, tKrK);
     // // if (cute::thread(1, 0)) { print(tKrK); }
-    flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/true>(
+    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
        gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
    );
     if (!Kernel_traits::Is_V_in_regs) {
-        flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/true>(
+        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
            gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
        );
     }
@@ -783,15 +783,26 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
         // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
         // if (cute::thread(32, 0)) { print(scores); }
-        // We don't need to mask out the elements beyond actual_seqlen_k, because acc_s would
-        // be some finite value for those indices. In the end when we multiply with K to get dQ,
-        // the corresponding values of K would be 0, so the result would still be correct.
-        // Putting this causal masking right after acc_s is *much* slower for some reason.
-        if (Is_causal && m_block * kBlockM < (n_block + 1) * kBlockN) {
-            flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
-                                     binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
-                                     // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
-                                     AtomLayoutMS * 16);
-        }
+        // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond
+        // actual_seqlen_k, because acc_s would be some finite value for those indices.
+        // In the end when we multiply with K to get dQ, the corresponding values of K would be 0,
+        // so the result would still be correct.
+        // However, it's possible that the values in acc_s are so large that they overflow
+        // when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ.
+        // So we need to mask out the elements beyond actual_seqlen_k.
+        if (!Is_causal) {
+            if (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k) {
+                flash::apply_mask(scores, binfo.actual_seqlen_k,
+                                  n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16);
+            }
+        } else {
+            // Putting this causal masking right after acc_s is *much* slower for some reason.
+            if (m_block * kBlockM < (n_block + 1) * kBlockN) {
+                flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
+                                         binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
+                                         // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
+                                         AtomLayoutMS * 16);
+            }
+        }
         // if (cute::thread(32, 0)) { print(scores); }
         // Compute the exponential value.
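
For reference (not code from the commit), a rough PyTorch-level sketch of what the new non-causal branch above does: only when the block is uneven and this key block reaches past actual_seqlen_k, columns at or beyond actual_seqlen_k are set to -inf before the exponential, so exp() turns them into exact zeros and they cannot overflow downstream. All names here are illustrative.

import torch

def mask_scores_reference(scores, actual_seqlen_k, col_offset, is_even_mn):
    # scores: one (rows, cols) block of attention logits
    # col_offset: global column index of scores[:, 0]
    #             (n_block * kBlockN plus the per-warp offset in the kernel)
    rows, cols = scores.shape
    if not is_even_mn and col_offset + cols >= actual_seqlen_k:
        col_idx = col_offset + torch.arange(cols)
        scores[:, col_idx >= actual_seqlen_k] = float('-inf')
    return scores

# Example: a block of 8 columns starting at global column 120, with actual_seqlen_k = 125:
blk = mask_scores_reference(torch.zeros(2, 8), actual_seqlen_k=125, col_offset=120, is_even_mn=False)
# columns 125, 126, 127 are now -inf; exp(-inf) == 0, so they drop out of dP and dS
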
@@ -978,7 +989,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
     Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
     #pragma unroll
     for (int m = 0; m < size<1>(tdQgdQ); ++m) {
-        if (Is_even_M || get<0>(tdQcdQ(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {
+        if (Is_even_MN || get<0>(tdQcdQ(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {
             copy(gmem_thr_copy_dQ, tdQrdQ(_, m, _), tdQgdQ(_, m, _));
         }
     }
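
The dQ epilogue above keeps the same row predicate, just on the renamed flag: a row of the tile is written back only when the whole tile is known to be in bounds (Is_even_MN) or the row index within the tile is below the number of valid rows remaining. A toy PyTorch sketch of that predication (illustrative names, not the kernel's tensors):

import torch

def store_dq_rows(dq_tile, dq_gmem, m_block, k_block_m, actual_seqlen_q, is_even_mn):
    # Write one kBlockM-row tile of dQ back to "global memory", skipping rows that
    # fall beyond actual_seqlen_q, like the predicated copy in the kernel does.
    rows_left = actual_seqlen_q - m_block * k_block_m
    for m in range(dq_tile.shape[0]):
        if is_even_mn or m < rows_left:
            dq_gmem[m_block * k_block_m + m] = dq_tile[m]

# seqlen_q = 200 with kBlockM = 128: the second tile only has 72 valid rows, and the
# predicate is what keeps the copy from running past the end of the 200-row buffer.
dq_gmem = torch.zeros(200, 4)
store_dq_rows(torch.ones(128, 4), dq_gmem, m_block=1, k_block_m=128,
              actual_seqlen_q=200, is_even_mn=False)
assert dq_gmem[128:].eq(1).all()
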
@@ -1044,10 +1055,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
     #pragma unroll
     for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }

     // Clear_OOB_K must be false since we don't want to write zeros to gmem
-    flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_thr_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
    );
-    flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_thr_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
    );
@@ -1487,7 +1498,7 @@ inline __device__ void compute_dq_dk_dv(const Params &params) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_M, bool Is_even_K, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, typename Params>
 inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {

     const int n_block = blockIdx.x;
@@ -1496,7 +1507,7 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
     // The block index for the head.
     const int bidh = blockIdx.z;

-    compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
+    compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////

csrc/flash_attn/src/flash_bwd_launch_template.h

@@ -23,9 +23,9 @@ __global__ void flash_bwd_dq_dk_dv_loop_kernel(Flash_bwd_params params) {
     flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K>(params);
 }

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_M, bool Is_even_K>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K>
 __global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(Flash_bwd_params params) {
-    flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K>(params);
+    flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K>(params);
 }

 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K>
@@ -53,17 +53,17 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
     flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
     C10_CUDA_KERNEL_LAUNCH_CHECK();

-    // We also use is_even_M to set Unpadded in the BlockInfo constructor, so we need to check
-    // for cu_seqlens_q as well.
-    const bool is_even_M = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0;
+    // We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not
+    // a multiple of kBlockN, we'll need to apply mask in the loop.
+    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0 && params.seqlen_k % Kernel_traits::kBlockN == 0;
     const bool is_even_K = params.d == Kernel_traits::kHeadDim;
     constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock;
     // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
     BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
-        BOOL_SWITCH(is_even_M, IsEvenMConst, [&] {
+        BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-                auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMConst, IsEvenKConst>;
-                // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMConst, true>;
+                auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, IsEvenKConst>;
+                // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
                 if (smem_size_dq_dk_dv >= 48 * 1024)  {
                     C10_CUDA_CHECK(cudaFuncSetAttribute(
                         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
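
The dispatch predicate introduced above, written out at the Python level (a sketch of the condition only; the real block sizes come from Kernel_traits, with 128 being the case called out in the commit title):

def is_even_mn(seqlen_q, seqlen_k, cu_seqlens_q=None, cu_seqlens_k=None,
               k_block_m=128, k_block_n=128):
    # Mirror of the is_even_MN check in run_flash_bwd_seqk_parallel: the unmasked kernel
    # specialization is only picked when the batch is not variable-length and both the
    # query and key sequence lengths are multiples of their block sizes.
    return (cu_seqlens_q is None and cu_seqlens_k is None
            and seqlen_q % k_block_m == 0 and seqlen_k % k_block_n == 0)

assert is_even_mn(256, 256)
assert not is_even_mn(256, 200)   # seqlen_k not divisible by 128 -> take the masked path
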
@@ -102,7 +102,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,
     BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
         BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
             auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenNConst, IsEvenKConst>;
-            // auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, true, true>;
+            // auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
             if (smem_size_dq_dk_dv >= 48 * 1024)  {
                 C10_CUDA_CHECK(cudaFuncSetAttribute(
                     kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));

csrc/flash_attn/src/softmax.h

@@ -117,15 +117,18 @@ inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, ...
 }

 template <typename Engine, typename Layout>
-inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t max_seqlen_k) {
+inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t max_seqlen_k,
+                                  const uint32_t col_idx_offset_ = 0) {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
     const uint32_t lane_id = threadIdx.x % 32;
+    const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
     #pragma unroll
     for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+        const uint32_t col_idx_base = col_idx_offset + nj * 8;
         #pragma unroll
         for (int j = 0; j < size<1, 0>(tensor); ++j) {
-            const uint32_t col_idx = nj * 8 + j + (lane_id % 4) * 2;
+            const uint32_t col_idx = col_idx_base + j;
             if (col_idx >= max_seqlen_k) {
                 // Without the "make_coord" we get wrong results
                 #pragma unroll
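
The apply_mask change is a refactor plus a new col_idx_offset_ parameter (default 0): with the default, the per-element column index is unchanged, and the backward kernel can now pass the global column offset so masking is applied against actual_seqlen_k. A small pure-Python check of the index arithmetic (illustration, not part of the commit):

def col_idx_old(lane_id, nj, j):
    return nj * 8 + j + (lane_id % 4) * 2

def col_idx_new(lane_id, nj, j, col_idx_offset_=0):
    col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2
    col_idx_base = col_idx_offset + nj * 8
    return col_idx_base + j

# With col_idx_offset_ == 0 the two formulas agree for every lane and element,
# so existing callers see no behaviour change.
assert all(col_idx_old(l, nj, j) == col_idx_new(l, nj, j)
           for l in range(32) for nj in range(4) for j in range(2))
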

tests/test_flash_attn.py

@@ -825,3 +825,97 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
     assert torch.equal(dqkv[:, :, 0], dqkv0[:, :, 0])
     assert torch.equal(dqkv[:, :, 1], dqkv0[:, :, 1])
     assert torch.equal(dqkv[:, :, 2], dqkv0[:, :, 2])
+
+
+@pytest.mark.parametrize('dtype', [torch.float16])
+@pytest.mark.parametrize('causal', [False, True])
+# @pytest.mark.parametrize('causal', [False])
+@pytest.mark.parametrize('d', [16, 32, 64])
+# @pytest.mark.parametrize('d', [16])
+@pytest.mark.parametrize('seqlen', [1, 2, 5, 17, 128])
+# @pytest.mark.parametrize('seqlen', [2])
+def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype):
+    """ We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
+    in the case where seqlen % 128 != 0.
+    """
+    device = 'cuda'
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 2
+    nheads = 5
+    q = torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 5
+    k, v = [torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 3
+            for _ in range(2)]
+    q.requires_grad_(True)
+    k.requires_grad_(True)
+    v.requires_grad_(True)
+    out = flash_attn_func(q, k, v, causal=causal)
+    g = torch.randn_like(out)
+    out.backward(g)
+    q_pt = q.detach().clone().requires_grad_(True)
+    k_pt = k.detach().clone().requires_grad_(True)
+    v_pt = v.detach().clone().requires_grad_(True)
+    out_pt, _ = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)
+    out_pt.backward(g)
+    q_ref = q.detach().clone().requires_grad_(True)
+    k_ref = k.detach().clone().requires_grad_(True)
+    v_ref = v.detach().clone().requires_grad_(True)
+    out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)
+    out_ref.backward(g)
+    print(f'dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}')
+    print(f'dK max diff: {(k.grad - k_ref.grad).abs().max().item()}')
+    print(f'dV max diff: {(v.grad - v_ref.grad).abs().max().item()}')
+    print(f'dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}')
+    print(f'dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}')
+    print(f'dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}')
+    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
+    assert (q.grad - q_ref.grad).abs().max().item() <= 5 * (q_pt.grad - q_ref.grad).abs().max().item() + 1e-3
+    assert (k.grad - k_ref.grad).abs().max().item() <= 5 * (k_pt.grad - k_ref.grad).abs().max().item() + 1e-3
+    assert (v.grad - v_ref.grad).abs().max().item() <= 5 * (v_pt.grad - v_ref.grad).abs().max().item() + 1e-3
+
+
+@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
+# @pytest.mark.parametrize('dtype', [torch.bfloat16])
+@pytest.mark.parametrize('causal', [False, True])
+# @pytest.mark.parametrize('causal', [False])
+@pytest.mark.parametrize('d', [64, 128])
+# @pytest.mark.parametrize('d', [64])
+@pytest.mark.parametrize('seqlen', [97, 128, 200, 256])
+# @pytest.mark.parametrize('seqlen', [128])
+def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype):
+    """ We previously had a bug where we were using the wrong strides of dout, which shows up
+    when dout is not contiguous.
+    """
+    device = 'cuda'
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 5
+    nheads = 2
+    q, k, v = [torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda", requires_grad=True)
+               for _ in range(3)]
+    out = rearrange(flash_attn_func(q, k, v, causal=causal), "b s ... -> s b ...")
+    # So g is not contiguous
+    g = torch.randn(seqlen, 2 * batch_size, nheads, d, dtype=dtype, device="cuda")[:, ::2]
+    out.backward(g)
+    q_pt = q.detach().clone().requires_grad_(True)
+    k_pt = k.detach().clone().requires_grad_(True)
+    v_pt = v.detach().clone().requires_grad_(True)
+    out_pt, attn_pt = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)
+    out_pt = rearrange(out_pt, "b s ... -> s b ...")
+    out_pt.backward(g)
+    q_ref = q.detach().clone().requires_grad_(True)
+    k_ref = k.detach().clone().requires_grad_(True)
+    v_ref = v.detach().clone().requires_grad_(True)
+    out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)
+    out_ref = rearrange(out_ref, "b s ... -> s b ...")
+    out_ref.backward(g)
+    print(f'dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}')
+    print(f'dK max diff: {(k.grad - k_ref.grad).abs().max().item()}')
+    print(f'dV max diff: {(v.grad - v_ref.grad).abs().max().item()}')
+    print(f'dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}')
+    print(f'dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}')
+    print(f'dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}')
+    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
+    assert (q.grad - q_ref.grad).abs().max().item() <= 2 * (q_pt.grad - q_ref.grad).abs().max().item()
+    assert (k.grad - k_ref.grad).abs().max().item() <= 2 * (k_pt.grad - k_ref.grad).abs().max().item()
+    assert (v.grad - v_ref.grad).abs().max().item() <= 2 * (v_pt.grad - v_ref.grad).abs().max().item()