gaoqiong / flash-attention · Commit 9e5e8bc9

Change causal mask to be aligned to bottom-right instead of top-left

Authored Aug 21, 2023 by Tri Dao
Parent: e07aa036
Showing 11 changed files with 573 additions and 204 deletions (+573 -204)
README.md                                        +26   -0
benchmarks/benchmark_causal.py                   +10   -14
csrc/flash_attn/src/block_info.h                 +2    -2
csrc/flash_attn/src/flash_bwd_kernel.h           +19   -52
csrc/flash_attn/src/flash_fwd_kernel.h           +69   -22
csrc/flash_attn/src/flash_fwd_launch_template.h  +6    -8
csrc/flash_attn/src/softmax.h                    +22   -22
flash_attn/__init__.py                           +1    -1
flash_attn/flash_attn_interface.py               +48   -0
tests/test_flash_attn.py                         +369  -82
training/Dockerfile                              +1    -1
README.md (view file @ 9e5e8bc9)

@@ -136,6 +136,32 @@ flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
```python
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
```
## Changes in v2.1 (compared to v2.0)
If seqlen_q != seqlen_k and causal=True, the causal mask is aligned to the
bottom right corner of the attention matrix, instead of the top-left corner.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
v2.0:
1 0 0 0 0
1 1 0 0 0
v2.1:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
v2.0:
1 0
1 1
1 1
1 1
1 1
v2.1:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
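For reference, the same alignment rule can be written in a few lines of PyTorch. This is a hedged sketch for illustration only (the helper name is ours, not part of this commit): position (i, j) is kept exactly when j - i <= seqlen_k - seqlen_q, which reproduces the two examples above.

```python
import torch

def bottom_right_causal_mask(seqlen_q, seqlen_k):
    # Keep (i, j) iff j - i <= seqlen_k - seqlen_q, i.e. the causal mask is
    # aligned to the bottom-right corner of the (seqlen_q, seqlen_k) matrix.
    row = torch.arange(seqlen_q).unsqueeze(-1)   # (seqlen_q, 1)
    col = torch.arange(seqlen_k)                 # (seqlen_k,)
    return (col - row) <= (seqlen_k - seqlen_q)  # True = keep

print(bottom_right_causal_mask(2, 5).int())  # [[1 1 1 1 0], [1 1 1 1 1]]
print(bottom_right_causal_mask(5, 2).int())  # first three rows all 0, then [1 0], [1 1]
```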
## Performance
...
...
benchmarks/benchmark_causal.py (view file @ 9e5e8bc9)

```diff
@@ -15,12 +15,7 @@ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
 # from triton.ops.flash_attention import attention as attention_triton
-try:
-    from fav2 import flash_attn_qkvpacked_func as fav2_qkvpacked_func
-    from fav2 import flash_attn_kvpacked_func as fav2_kvpacked_func
-except ImportError:
-    fav2_qkvpacked_func = None
-    fav2_kvpacked_func = None
+from flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func
 try:
     from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax
```
```diff
@@ -80,8 +75,8 @@ def attention_megatron(qkv):
 torch.manual_seed(0)
 repeats = 30
-batch_size = 2
-seqlen = 8192
+batch_size = 8
+seqlen = 2048
 nheads = 12
 headdim = 128
 # nheads = 24
```
```diff
@@ -90,8 +85,8 @@ headdim = 128
 # seqlen = 512
 # nheads = 8
 # headdim = 128
-dropout_p = 0.1
-causal = False
+dropout_p = 0.0
+causal = True
 dtype = torch.float16
 device = 'cuda'
```
```diff
@@ -100,20 +95,20 @@ qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=d
 cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                           device=qkv.device)
-# qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
+qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
 # benchmark_all(flash_attn_varlen_qkvpacked_func, qkv_unpad,
 #               cu_seqlens, seqlen, dropout_p, causal=causal, repeats=repeats, desc='FlashAttention')
 # pytorch_profiler(flash_attn_varlen_qkvpacked_func, qkv_unpad,
 #                  cu_seqlens, seqlen, dropout_p, causal=causal, backward=True)
 # if fav2_qkvpacked_func is not None:
 #     benchmark_all(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')
 #     pytorch_profiler(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True)
 benchmark_forward(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')
 pytorch_profiler(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, backward=False)
 # for dropout_p in [0.1, 0.0]:
 #     for causal in [False, True]:
 #         print(f"### {dropout_p = }, {causal = } ###")
 #         pytorch_profiler(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True)
 # nheads_k = 2
 # q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
 # kv = torch.randn(batch_size, seqlen, 2, nheads_k, headdim, device=device, dtype=dtype,
```
```diff
@@ -151,6 +146,7 @@ cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch
 flops = 4 * batch_size * seqlen ** 2 * nheads * headdim
 ideal_a100_time = flops / 312 / 1e9
 print(f"Ideal A100 fwd time: {ideal_a100_time:.3f}ms, bwd time: {ideal_a100_time * 2.5:.3f}ms")
+exit(0)

 def time_fwd_bwd(func, *args, **kwargs):
```
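As a sanity check on the formula above, here is the same arithmetic spelled out. This is a hedged sketch; the numbers assume the benchmark defaults set earlier (batch_size=8, seqlen=2048, nheads=12, headdim=128) and the A100's roughly 312 TFLOP/s of FP16/BF16 tensor-core throughput.

```python
batch_size, seqlen, nheads, headdim = 8, 2048, 12, 128

# Attention is roughly two matmuls (Q @ K^T and P @ V), each 2 * seqlen^2 * headdim FLOPs per head.
flops = 4 * batch_size * seqlen**2 * nheads * headdim      # ~2.06e11 FLOPs for these sizes
ideal_a100_time = flops / 312 / 1e9                         # FLOPs / (312e12 FLOP/s), expressed in ms

print(f"{flops:.3e} FLOPs -> ideal fwd ~{ideal_a100_time:.3f} ms, "
      f"bwd ~{ideal_a100_time * 2.5:.3f} ms")               # backward is ~2.5x the forward FLOPs
```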
csrc/flash_attn/src/block_info.h (view file @ 9e5e8bc9)

```diff
@@ -32,8 +32,8 @@ struct BlockInfo {
     const int sum_s_q;
     const int sum_s_k;
-    const uint32_t actual_seqlen_q;
-    const uint32_t actual_seqlen_k;
+    const int actual_seqlen_q;
+    const int actual_seqlen_k;
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////
```
csrc/flash_attn/src/flash_bwd_kernel.h (view file @ 9e5e8bc9)

```diff
@@ -659,46 +659,14 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.d_rounded;

     int m_block = m_block_max - 1;
-    int m_block_min = !Is_causal ? 0 : (n_block * kBlockN - int(binfo.actual_seqlen_k - binfo.actual_seqlen_q)) / kBlockM;
-    m_block_min = m_block_min < 0 ? 0 : m_block_min;
-    // We might need to exit early and write 0 to dK and dV.
-    // Otherwise we get wrong result for the case where we don't enter the for loop.
-    // And we might read OOB elements from gQ and gdO.
-    // TODO: what if we're not parallelizing, do we need to compute dot_do_o?
-    if (Is_causal && m_block < m_block_min) {
-        const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
-            + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
-        const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
-            + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
-        Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
-                                 Shape<Int<kBlockN>, Int<kHeadDim>>{},
-                                 make_stride(params.dk_row_stride, _1{}));
-        Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
-                                 Shape<Int<kBlockN>, Int<kHeadDim>>{},
-                                 make_stride(params.dv_row_stride, _1{}));
-        typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV;
-        auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
-        Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
-        Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
-        Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
-        Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
-        clear(tdKrdK);
-        clear(tdVrdV);
-        Tensor cdKV = make_identity_tensor(make_shape(size<0>(gdK), size<1>(gdK)));    // (BLK_N,BLK_K) -> (blk_n,blk_k)
-        Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
-        Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
-        #pragma unroll
-        for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
-        // Clear_OOB_K must be false since we don't want to write zeros to gmem
-        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-            gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
-        );
-        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-            gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
-        );
-        return;
-    }
+    int m_block_min = !Is_causal ? 0 : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k) / kBlockM);
+    // We're guaranteed that m_block_min <= m_block:
+    // We checked earlier that n_block * kBlockN < actual_seqlen_k, so in the causal case,
+    // n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k < actual_seqlen_q.
+    // So m_block_min <= (actual_seqlen_q - 1) / kBlockM.
+    // Recall that m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM) = (actual_seqlen_q + kBlockM - 1) / kBlockM.
+    // So m_block_m - 1 = (actual_seqlen_q - 1) / kBlockM.
+    // We conclude that m_block_min <= m_block, so we will always have at least 1 iteration of the for loop.

     if (Double_buffer && m_block % 2 == 1) {  // Double buffer for sQ
         tQsQ.data() = tQsQ.data() + size(sQ);
```
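The argument in the new comment block can be checked numerically. Below is a small, hedged sketch in plain Python (not part of the kernel; block sizes are illustrative): for any column block that passes the earlier n_block * kBlockN < actual_seqlen_k check, the computed m_block_min never exceeds the last row block, so the m_block loop always runs at least once.

```python
def ceil_div(a, b):
    return (a + b - 1) // b

def check(seqlen_q, seqlen_k, kBlockM=64, kBlockN=64):
    m_block_max = ceil_div(seqlen_q, kBlockM)
    for n_block in range(ceil_div(seqlen_k, kBlockN)):
        # every such n_block satisfies n_block * kBlockN < seqlen_k
        m_block_min = max(0, (n_block * kBlockN + seqlen_q - seqlen_k) // kBlockM)
        assert m_block_min <= m_block_max - 1  # at least one loop iteration

for sq, sk in [(2, 5), (5, 2), (127, 513), (1023, 1024), (2048, 2048)]:
    check(sq, sk)
print("m_block_min <= m_block_max - 1 holds for all tested shapes")
```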
```diff
@@ -743,7 +711,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
     #pragma unroll
     for (int mi = 0; mi < size(lse); ++mi) {
         // Using uint32_t row makes it 10us slower on d=128, not sure why.
         const int row = get<0>(taccScS_row(mi));
         lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0;
     }
```
```diff
@@ -824,11 +791,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             // TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short
             // (e.g., 256 and 2), the 2nd block of seqlen_q (from 128 to 255), we're not doing causal masking.
             // But we still want to mask out elements beyond actual_seqlen_k.
-            if (m_block * kBlockM < (n_block + 1) * kBlockN) {
+            if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k
+                || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
                 flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
-                                         binfo.actual_seqlen_q, binfo.actual_seqlen_k,
-                                         m_block * kBlockM + get<0>(taccScS_row(0)),
+                                         binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
+                                         binfo.actual_seqlen_q,
                                          // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
                                          AtomLayoutMS * 16);
             }
```
```diff
@@ -837,11 +804,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             // Compute the exponential value.
             flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
             if (Is_dropout) {
-                uint32_t warp_id = tidx / 32;
-                uint32_t block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
+                int warp_id = tidx / 32;
+                int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
                 // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
                 static_assert(MMA_N_SdP % 2 == 0);
-                uint32_t block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
+                int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
                 Tensor scores_dropped = make_tensor(scores.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMmaSdP>(scores.layout()));
                 flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
                     scores_dropped, params.p_dropout_in_uint8_t, seed, offset,
```
```diff
@@ -1341,7 +1308,6 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
     Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
     #pragma unroll
     for (int mi = 0; mi < size(lse); ++mi) {
         // Using uint32_t row makes it 10us slower on d=128, not sure why.
         const int row = get<0>(taccScS_row(mi));
         lse(mi) = row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0;
     }
```
```diff
@@ -1379,18 +1345,19 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
             // the corresponding values of K would be 0, so the result would still be correct.
             if (Is_causal && m_block * kBlockM < (n_block + 1) * kBlockN) {
                 flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
-                                         binfo.actual_seqlen_q, binfo.actual_seqlen_k,
-                                         m_block * kBlockM + get<0>(taccScS_row(0)),
+                                         binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
                                          // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
+                                         binfo.actual_seqlen_q,
                                          AtomLayoutMS * 16);
             }
             // Compute the exponential value.
             flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
             if (Is_dropout) {
-                uint32_t warp_id = tidx / 32;
-                uint32_t block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
+                int warp_id = tidx / 32;
+                int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
                 // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
                 static_assert(MMA_N_SdP % 2 == 0);
-                uint32_t block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
+                int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
                 Tensor scores_dropped = make_tensor(scores.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMmaSdP>(scores.layout()));
                 flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
                     scores_dropped, params.p_dropout_in_uint8_t, seed, offset,
```
csrc/flash_attn/src/flash_fwd_kernel.h (view file @ 9e5e8bc9)

```diff
@@ -118,7 +118,7 @@ inline __device__ void write_softmax_to_gmem(
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
 inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {

     using Element = typename Kernel_traits::Element;
```
```diff
@@ -130,8 +130,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // The thread index.
     const int tidx = threadIdx.x;
-    // The global block index.
-    const int block_id = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z;

     constexpr int kBlockM = Kernel_traits::kBlockM;
     constexpr int kBlockN = Kernel_traits::kBlockN;
```
```diff
@@ -139,16 +137,60 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     constexpr int kNWarps = Kernel_traits::kNWarps;
     constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;

-    const BlockInfo</*Varlen=*/!Is_even_N> binfo(params, bidb);
+    const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
     if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;

     int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
     if (Is_causal) {
-        n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM + int(binfo.actual_seqlen_k - binfo.actual_seqlen_q), kBlockN));
+        n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN));
+        // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
+        //     printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
+        // }
+        // We exit early and write 0 to gO and gLSE.
+        // Otherwise we might read OOB elements from gK and gV.
+        if (n_block_max <= 0) {
+            // Save seed and offset for backward. If we don't have this here, the 0-th thread block might
+            // exit early and no one saves the rng state.
+            if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
+                auto seeds = at::cuda::philox::unpack(params.philox_args);
+                params.rng_state[0] = std::get<0>(seeds);
+                params.rng_state[1] = std::get<1>(seeds);
+            }
+            const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+                + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
+            const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
+            Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
+                                    Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                                    make_stride(params.o_row_stride, _1{}));
+            Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
+                                      Shape<Int<kBlockM>>{}, Stride<_1>{});
+            typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
+            auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
+            Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
+            Tensor tOrO = make_tensor<Element>(shape(tOgO));
+            clear(tOrO);
+            // Construct identity layout for sO
+            Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
+            // Repeat the partitioning with identity layouts
+            Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
+            Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
+            if (!Is_even_K) {
+                #pragma unroll
+                for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
+            }
+            // Clear_OOB_K must be false since we don't want to write zeros to gmem
+            flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+                gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
+            );
+            #pragma unroll
+            for (int m = 0; m < size<1>(tOgO); ++m) {
+                const int row = get<0>(tOcO(0, m, 0));
+                if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; }
+            }
+            return;
+        }
     }
     // We iterate over the blocks in reverse order. This is because the last block is the only one
```
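The capped n_block_max is what makes the early exit above reachable: with bottom-right alignment, a query block that lies entirely above the causal band has no key blocks to visit and writes zero output. A hedged Python sketch of the same arithmetic (names mirror the kernel; the block sizes are illustrative, not the actual kernel-traits values):

```python
def ceil_div(a, b):
    return (a + b - 1) // b

def n_block_max_for(m_block, seqlen_q, seqlen_k, kBlockM=128, kBlockN=128, causal=True):
    n_block_max = ceil_div(seqlen_k, kBlockN)
    if causal:
        # bottom-right alignment: query row i may see keys up to i + seqlen_k - seqlen_q
        n_block_max = min(n_block_max,
                          ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q, kBlockN))
    return n_block_max

# seqlen_q = 1024, seqlen_k = 128: the first seven query blocks get n_block_max <= 0
# and take the "write zeros to gO/gLSE" path; only the last block sees the key block.
print([n_block_max_for(m, 1024, 128) for m in range(ceil_div(1024, 128))])
```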
```diff
@@ -275,8 +317,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor tQrQ = make_fragment_like(tQgQ);
     // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
-    flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
-                                                 binfo.actual_seqlen_q - m_block * kBlockM);
+    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
+                                       binfo.actual_seqlen_q - m_block * kBlockM);
     if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }

     // // Copy rmem to smem
```
```diff
@@ -298,8 +340,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     int n_block = n_block_max - 1;
     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
-    flash::copy<Is_even_N, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
-                                      binfo.actual_seqlen_k - n_block * kBlockN);
+    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
+                                       binfo.actual_seqlen_k - n_block * kBlockN);
     cute::cp_async_fence();

     // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
     // __syncthreads();
```
```diff
@@ -317,7 +359,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;

     // Save seed and offset for backward.
-    if (block_id == 0 && tidx == 0) {
+    if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
         params.rng_state[0] = seed;
         params.rng_state[1] = std::get<1>(seeds);
     }
```
```diff
@@ -330,7 +372,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
     // We will have at least 1 "masking" iteration.
-    constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1;
+    // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
+    // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
+    constexpr int n_masking_steps = !Is_causal
+        ? 1
+        : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
     #pragma unroll
     for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
```
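The extra masking step when seqlen_k does not divide evenly into key blocks can be written out directly. A hedged sketch of the same constant (plain Python; block sizes are illustrative):

```python
def ceil_div(a, b):
    return (a + b - 1) // b

def n_masking_steps(is_causal, is_even_mn, kBlockM=128, kBlockN=128):
    if not is_causal:
        return 1                            # only the last (possibly ragged) key block needs masking
    steps = ceil_div(kBlockM, kBlockN)      # key blocks the causal diagonal can cross
    # if seqlen_k ends mid-block, the diagonal can straddle one extra key block
    return steps if is_even_mn else steps + 1

print(n_masking_steps(True, True))   # 1
print(n_masking_steps(True, False))  # 2
```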
```diff
@@ -344,7 +390,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
         } else {
             // Clear the smem tiles to account for predicated off loads
-            flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>(
+            flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
                 gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
             );
         }
```
```diff
@@ -363,7 +409,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
         // can produce Inf / NaN.
         if (!Is_causal) {
-            if (!Is_even_N) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
+            if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
         } else {
             // Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});    // (BLK_M,BLK_N) -> (blk_m,blk_n)
             // Tensor taccScS = thr_mma.partition_C(caccS);    // (MMA,MMA_M,MMA_N)
```
```diff
@@ -376,9 +422,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             // Idk why it's get<1> and not get<0> of the stride.
             // if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
             // I can't get the stride from idx_row
-            flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_q, binfo.actual_seqlen_k,
+            flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
                                      // m_block * kBlockM + get<0>(idx_row(0)),
                                      m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                                     binfo.actual_seqlen_q,
                                      kNWarps * 16);
             // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
             // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
```
```diff
@@ -405,8 +452,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
         // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
         Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
-        uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
-        uint32_t block_col_idx = n_block * (kBlockN / 32);
+        int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
+        int block_col_idx = n_block * (kBlockN / 32);
         if (Return_softmax) {
             Tensor tOrP_copy = make_fragment_like(tOrP);
             cute::copy(tOrP, tOrP_copy);
```
```diff
@@ -468,8 +515,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
         // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
         Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
-        uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
-        uint32_t block_col_idx = n_block * (kBlockN / 32);
+        int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
+        int block_col_idx = n_block * (kBlockN / 32);
         if (Return_softmax) {
             Tensor tOrP_copy = make_fragment_like(tOrP);
             cute::copy(tOrP, tOrP_copy);
```
```diff
@@ -563,14 +610,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
     }
     // Clear_OOB_K must be false since we don't want to write zeros to gmem
-    flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
         gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
     );
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
 inline __device__ void compute_attn(const Params &params) {
     const int m_block = blockIdx.x;
     // The block index for the batch.
```

```diff
@@ -586,7 +633,7 @@ inline __device__ void compute_attn(const Params &params) {
     // the attention matrix. This way, as long as we have the batch, head, and the location of
     // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.

-    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
+    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////
```
csrc/flash_attn/src/flash_fwd_launch_template.h (view file @ 9e5e8bc9)

```diff
@@ -10,9 +10,9 @@
 #include "flash.h"
 #include "flash_fwd_kernel.h"

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
 __global__ void flash_fwd_kernel(Flash_fwd_params params) {
-    flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params);
+    flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Return_softmax>(params);
 }

 template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
```
```diff
@@ -26,17 +26,15 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
     dim3 grid(num_m_block, params.b, params.h);
-    // We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check
-    // for cu_seqlens_q as well.
-    const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0;
+    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
     const bool is_even_K = params.d == Kernel_traits::kHeadDim;
     const bool return_softmax = params.p_ptr != nullptr;
-    BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
+    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
         BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
             BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
                 // Will only return softmax if dropout, to reduce compilation time.
-                auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
-                // auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true, ReturnSoftmaxConst && Is_dropout>;
+                auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
+                // auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, true, ReturnSoftmaxConst && Is_dropout>;
                 if (smem_size >= 48 * 1024) {
                     C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
```
csrc/flash_attn/src/softmax.h (view file @ 9e5e8bc9)

```diff
@@ -117,18 +117,18 @@ inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tens
 }

 template <typename Engine, typename Layout>
-inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t max_seqlen_k,
-                                  const uint32_t col_idx_offset_ = 0) {
+inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
+                                  const int col_idx_offset_ = 0) {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
-    const uint32_t lane_id = threadIdx.x % 32;
-    const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+    const int lane_id = threadIdx.x % 32;
+    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
     #pragma unroll
     for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-        const uint32_t col_idx_base = col_idx_offset + nj * 8;
+        const int col_idx_base = col_idx_offset + nj * 8;
         #pragma unroll
         for (int j = 0; j < size<1, 0>(tensor); ++j) {
-            const uint32_t col_idx = col_idx_base + j;
+            const int col_idx = col_idx_base + j;
             if (col_idx >= max_seqlen_k) {
                 // Without the "make_coord" we get wrong results
                 #pragma unroll
```
```diff
@@ -141,28 +141,28 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t
 }

 template <typename Engine, typename Layout>
-inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const uint32_t col_idx_offset_,
-                                         const uint32_t max_seqlen_q, const uint32_t max_seqlen_k,
-                                         const uint32_t row_idx_offset_, const uint32_t warp_row_stride) {
+inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
+                                         const int max_seqlen_k, const int row_idx_offset_,
+                                         const int max_seqlen_q, const int warp_row_stride) {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
-    const uint32_t lane_id = threadIdx.x % 32;
-    // const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4;
-    const uint32_t row_idx_offset = row_idx_offset_;
-    const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+    const int lane_id = threadIdx.x % 32;
+    // const int row_idx_offset = row_idx_offset_ + lane_id / 4;
+    const int row_idx_offset = row_idx_offset_;
+    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
     #pragma unroll
     for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
-        const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride;
+        const int row_idx_base = row_idx_offset + mi * warp_row_stride;
         #pragma unroll
         for (int i = 0; i < size<0, 0>(tensor); ++i) {
-            const uint32_t row_idx = row_idx_base + i * 8;
-            const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q);
+            const int row_idx = row_idx_base + i * 8;
+            const int col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q);
             #pragma unroll
             for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-                const uint32_t col_idx_base = col_idx_offset + nj * 8;
+                const int col_idx_base = col_idx_offset + nj * 8;
                 #pragma unroll
                 for (int j = 0; j < size<1, 0>(tensor); ++j) {
-                    const uint32_t col_idx = col_idx_base + j;
+                    const int col_idx = col_idx_base + j;
                     if (col_idx >= col_idx_limit) {
                         tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                     }
```
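The per-element predicate above is the whole bottom-right story: a column is kept iff col_idx < min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q). A hedged reference version in Python (illustration only, ignoring the warp/MMA tiling of the real kernel):

```python
import numpy as np

def apply_mask_causal_ref(scores, max_seqlen_q, max_seqlen_k):
    """scores: (max_seqlen_q, max_seqlen_k); masked entries are set to -inf in place."""
    for row_idx in range(max_seqlen_q):
        col_idx_limit = min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q)
        scores[row_idx, col_idx_limit:] = -np.inf
    return scores

s = apply_mask_causal_ref(np.zeros((2, 5)), 2, 5)
print(np.isfinite(s).astype(int))   # [[1 1 1 1 0], [1 1 1 1 1]] -- matches the README example
```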
```diff
@@ -180,7 +180,7 @@ inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const u
 template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
 inline __device__ void apply_mask_causal_w_idx(
     Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
-    const uint32_t col_idx_offset_, const uint32_t max_seqlen_k, const uint32_t row_idx_offset_)
+    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset_)
 {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout0::rank == 2, "Only support 2D Tensor");
```
```diff
@@ -189,7 +189,7 @@ inline __device__ void apply_mask_causal_w_idx(
     CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
     #pragma unroll
     for (int mi = 0; mi < size<0>(tensor); ++mi) {
-        const uint32_t col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
+        const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
         #pragma unroll
         for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
             if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
```
```diff
@@ -207,8 +207,8 @@ inline __device__ void apply_mask_causal_w_idx(
 template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
 inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor, uint8_t p_dropout_in_uint8_t,
                                      unsigned long long seed, unsigned long long offset,
-                                     uint32_t block_row_start, uint32_t block_col_start,
-                                     uint32_t block_row_stride) {
+                                     int block_row_start, int block_col_start,
+                                     int block_row_stride) {
     // tensor has shape (8, MMA_M, MMA_N / 2)
     using T = typename Engine::value_type;
     auto encode_dropout = [](bool keep, T val) {
```
flash_attn/__init__.py (view file @ 9e5e8bc9)

```diff
-__version__ = "2.0.9"
+__version__ = "2.1.0"

 from flash_attn.flash_attn_interface import (
     flash_attn_func,
```
flash_attn/flash_attn_interface.py (view file @ 9e5e8bc9)

@@ -528,6 +528,18 @@ def flash_attn_kvpacked_func(
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
kv: (batch_size, seqlen, 2, nheads_k, headdim)
...
...
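A hedged usage sketch of the behaviour documented in the new docstring text (it assumes a CUDA GPU with the flash-attn package installed; the shapes are illustrative):

```python
import torch
from flash_attn import flash_attn_func

batch, nheads, d = 2, 4, 64
q = torch.randn(batch, 2, nheads, d, device="cuda", dtype=torch.float16)  # seqlen_q = 2
k = torch.randn(batch, 5, nheads, d, device="cuda", dtype=torch.float16)  # seqlen_k = 5
v = torch.randn_like(k)

# With causal=True and seqlen_q != seqlen_k, the mask is aligned to the bottom-right corner:
# query position 0 attends to key positions 0-3, query position 1 to all 5 keys.
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)   # (batch, 2, nheads, d)
print(out.shape)
```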
@@ -559,6 +571,18 @@ def flash_attn_func(
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim)
...
...
@@ -645,6 +669,18 @@ def flash_attn_varlen_kvpacked_func(
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
...
...
@@ -703,6 +739,18 @@ def flash_attn_varlen_func(
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
...
...
tests/test_flash_attn.py (view file @ 9e5e8bc9)

```diff
@@ -29,9 +29,11 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
     if mode == "full":
         lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
     elif mode == "random":
-        lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen, (batch_size, 1), device=device)
+        lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device)
     elif mode == "third":
-        lengths = torch.randint(max_seqlen // 3, max_seqlen, (batch_size, 1), device=device)
+        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
     padding_mask = (
         repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
     )
```

```diff
@@ -146,6 +148,23 @@ def generate_qkv(
     )


+def construct_causal_mask(
+    seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None
+):
+    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
+    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
+    sk = (
+        seqlen_k
+        if key_padding_mask is None
+        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
+    )
+    sq = (
+        seqlen_q
+        if query_padding_mask is None
+        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
+    )
+    return col_idx > row_idx + sk - sq
+
+
 def attention_ref(
     q,
     k,
```
```diff
@@ -190,11 +209,16 @@ def attention_ref(
     if key_padding_mask is not None:
         scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
     if causal:
-        causal_mask = torch.triu(
-            torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1
-        )
+        # causal_mask = torch.triu(
+        #     torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1
+        # )
+        causal_mask = construct_causal_mask(
+            seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device
+        )
         scores.masked_fill_(causal_mask, float("-inf"))
     attention = torch.softmax(scores, dim=-1)
+    if causal:
+        # Some rows are completely masked out so we fill them with zero instead of NaN
+        attention = attention.masked_fill(torch.all(causal_mask, dim=-1, keepdim=True), 0.0)
     dropout_scaling = 1.0 / (1 - dropout_p)
     # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
     # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
```

```diff
@@ -300,19 +324,19 @@ def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask
 def convert_flash_attn_S_to_softmax(
-    S, query_padding_mask, key_padding_mask, head_dim, is_dropout, causal=False
+    S, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, head_dim, is_dropout, causal=False
 ):
     """FlashAttention stores the S matrix in a different way.
     Arguments:
         S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded)
-        query_padding_mask: (batch_size, seqlen_q)
-        key_padding_mask: (batch_size, seqlen_k)
+        query_padding_mask: (batch_size, seqlen_q_rounded)
+        key_padding_mask: (batch_size, seqlen_k_rounded)
     """
-    seqlen_q, seqlen_k = S.shape[-2:]
+    seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:]
     warps_n = 4
     blocksize_m, blocksize_n = _get_block_size(S.device, head_dim, is_dropout, causal)
-    nblocks_n = (seqlen_k + blocksize_n - 1) // blocksize_n
-    nblocks_m = (seqlen_q + blocksize_m - 1) // blocksize_m
+    nblocks_n = (seqlen_k_rounded + blocksize_n - 1) // blocksize_n
+    nblocks_m = (seqlen_q_rounded + blocksize_m - 1) // blocksize_m
     mmas_n = (blocksize_n + 16 - 1) // 16
     S_flat = rearrange(
         S,
```
```diff
@@ -331,37 +355,30 @@ def convert_flash_attn_S_to_softmax(
         c2=2,
         four=4,
     )
     if causal:
-        causal_mask = torch.triu(
-            torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=S.device), 1
-        )
+        # causal_mask = torch.triu(
+        #     torch.ones(seqlen_q_rounded, seqlen_k_rounded, dtype=torch.bool, device=q.device), 1
+        # )
+        causal_mask = construct_causal_mask(
+            seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, S.device
+        )
+        causal_mask = F.pad(
+            causal_mask, (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q), value=True
+        )
         S_converted.masked_fill_(causal_mask, 0.0)
     # Need to zero out things not in attention_mask in case S was initialized with random values
     # and some of those values aren't overwritten.
-    seqlen_q_og = query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q
+    seqlen_q_og = query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
     if query_padding_mask is not None:
-        if seqlen_q_og < seqlen_q:
-            query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q - seqlen_q_og))
-        else:
-            query_padding_mask = query_padding_mask[:, :seqlen_q]
+        query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og))
         S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
     seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k
     if key_padding_mask is not None:
-        if seqlen_k_og < seqlen_k:
-            key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k - seqlen_k_og))
-        else:
-            key_padding_mask = key_padding_mask[:, :seqlen_k]
+        key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og))
         S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
-    if seqlen_q_og < seqlen_q:
-        S_converted = S_converted[:, :, :seqlen_q_og, :]
-    else:
-        S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q))
-    if seqlen_k_og < seqlen_k:
-        S_converted = S_converted[:, :, :, :seqlen_k_og]
-    else:
-        S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k))
-    return S_converted
+    S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded))
+    S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded))
+    return S_converted[:, :, :seqlen_q, :seqlen_k]


 def normalize_flash_attn_S(
```
```diff
@@ -390,20 +407,26 @@ def normalize_flash_attn_S(
     if key_padding_mask is not None:
         scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
     if causal:
-        causal_mask = torch.triu(
-            torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1
-        )
+        # causal_mask = torch.triu(
+        #     torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1
+        # )
+        causal_mask = construct_causal_mask(
+            seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device
+        )
         scores.masked_fill_(causal_mask, float("-inf"))
     _, block_size_n = _get_block_size(scores.device, head_dim, is_dropout, causal)
     scores_block = scores.split(block_size_n, dim=-1)
     lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
     lse = torch.logsumexp(lse_block, dim=-1)
+    # lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf
+    # so that when we do torch.exp(m - lse), we get 0.0 instead of NaN.
+    lse[lse == float("-inf")] = float("inf")
     scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1)
     cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1)
     attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1)
     attn_norm = torch.cat(
         [
-            a / rearrange(torch.exp(lse - m), "b h s -> b h s 1")
+            a * rearrange(torch.exp(m - lse), "b h s -> b h s 1")
             for a, m in zip(attn_unnorm_block, cummax_block)
         ],
         dim=-1,
```
```diff
@@ -428,8 +451,11 @@ def get_dropout_fraction(
     if key_padding_mask is not None:
         dropped.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
     if causal:
-        causal_mask = torch.triu(
-            torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=dropout_mask.device), 1
-        )
+        # causal_mask = torch.triu(
+        #     torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=dropout_mask.device), 1
+        # )
+        causal_mask = construct_causal_mask(
+            seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, dropout_mask.device
+        )
         dropped.masked_fill_(causal_mask, False)
     dropped_total = dropped.sum()
```

```diff
@@ -447,9 +473,9 @@ def get_dropout_fraction(
         numel_per_batch = query_lengths * key_lengths
     else:
         numel_per_batch = torch.where(
-            query_lengths <= key_lengths,
-            query_lengths * (query_lengths + 1) / 2,
-            query_lengths * key_lengths - (key_lengths * (key_lengths - 1) / 2),
+            key_lengths <= query_lengths,
+            key_lengths * (key_lengths + 1) / 2,
+            query_lengths * key_lengths - (query_lengths * (query_lengths - 1) / 2),
         )
     return dropped_total / (numel_per_batch.sum() * nheads)
```
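The updated element count matches the bottom-right mask geometry: when key_lengths <= query_lengths only the last key_lengths rows contain any unmasked entries, forming a triangle of k*(k+1)/2 cells; otherwise the full rectangle minus the masked upper-left triangle remains. A small hedged check in plain Python:

```python
def kept_elems_bottom_right(sq, sk):
    # brute force: count (i, j) with j - i <= sk - sq
    return sum(1 for i in range(sq) for j in range(sk) if j - i <= sk - sq)

def kept_elems_formula(sq, sk):
    if sk <= sq:
        return sk * (sk + 1) // 2
    return sq * sk - sq * (sq - 1) // 2

for sq, sk in [(2, 5), (5, 2), (113, 203), (203, 113), (128, 128)]:
    assert kept_elems_bottom_right(sq, sk) == kept_elems_formula(sq, sk)
print("formula matches brute-force count")
```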
```diff
@@ -483,8 +509,8 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype):
     )
     if dropout_p > 0.0:
         S_dmask_converted = convert_flash_attn_S_to_softmax(
-            S_dmask, None, None, d, dropout_p > 0.0, causal=causal
-        )[:, :, :seqlen, :seqlen]
+            S_dmask, seqlen, seqlen, None, None, d, dropout_p > 0.0, causal=causal
+        )
         dropout_mask = S_dmask_converted >= 0
         attn_unnorm = S_dmask_converted.abs()
         attn = normalize_flash_attn_S(
```
```diff
@@ -596,8 +622,8 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
     out = output_pad_fn(out_unpad)
     if dropout_p > 0.0:
         S_dmask_converted = convert_flash_attn_S_to_softmax(
-            S_dmask, key_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
-        )[:, :, :seqlen, :seqlen]
+            S_dmask, seqlen, seqlen, key_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
+        )
         dropout_mask = S_dmask_converted >= 0
         attn_unnorm = S_dmask_converted.abs()
         attn = normalize_flash_attn_S(
```
```diff
@@ -665,19 +691,19 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
 @pytest.mark.parametrize("kvpacked", [True, False])
-# @pytest.mark.parametrize('kvpacked', [False])
+# @pytest.mark.parametrize("kvpacked", [False])
 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
-# @pytest.mark.parametrize('dtype', [torch.bfloat16])
+# @pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
-# @pytest.mark.parametrize('mha_type', ["mha"])
+# @pytest.mark.parametrize("mha_type", ["mha"])
 @pytest.mark.parametrize("causal", [False, True])
-# @pytest.mark.parametrize('causal', [False])
+# @pytest.mark.parametrize("causal", [True])
 @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [56, 80])
-# @pytest.mark.parametrize('d', [64])
+# @pytest.mark.parametrize("d", [64])
 @pytest.mark.parametrize(
     "seqlen_q,seqlen_k",
     [
```
```diff
@@ -693,9 +719,9 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
         (2048, 2048),
     ],
 )
-# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
-# @pytest.mark.parametrize('dropout_p', [0.0])
+# @pytest.mark.parametrize("dropout_p", [0.17])
 def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, dtype, kvpacked):
     if (
         max(seqlen_q, seqlen_k) >= 2048
```
```diff
@@ -732,8 +758,8 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
     )
     if dropout_p > 0.0:
         S_dmask_converted = convert_flash_attn_S_to_softmax(
-            S_dmask, None, None, d, dropout_p > 0.0, causal=causal
-        )[:, :, :seqlen_q, :seqlen_k]
+            S_dmask, seqlen_q, seqlen_k, None, None, d, dropout_p > 0.0, causal=causal
+        )
         dropout_mask = S_dmask_converted >= 0
         attn_unnorm = S_dmask_converted.abs()
         if kvpacked:
```
```diff
@@ -969,8 +995,8 @@ def test_flash_attn_varlen_output(
     out = output_pad_fn(out_unpad)
     if dropout_p > 0.0:
         S_dmask_converted = convert_flash_attn_S_to_softmax(
-            S_dmask, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
-        )[:, :, :seqlen_q, :seqlen_k]
+            S_dmask, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
+        )
         dropout_mask = S_dmask_converted >= 0
         attn_unnorm = S_dmask_converted.abs()
         if kvpacked:
```
@@ -1101,53 +1127,314 @@ def test_flash_attn_varlen_output(
@
pytest
.
mark
.
parametrize
(
"dtype"
,
([
torch
.
float16
]
if
is_sm75
else
[
torch
.
float16
,
torch
.
bfloat16
]))
# @pytest.mark.parametrize('dtype', [torch.float16])
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@
pytest
.
mark
.
parametrize
(
"d"
,
[
32
,
40
,
59
,
64
,
80
,
96
,
111
,
128
,
160
,
192
,
224
,
256
])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64, 128])
@
pytest
.
mark
.
parametrize
(
"swap_sq_sk"
,
[
False
,
True
])
# @pytest.mark.parametrize("swap_sq_sk", [True])
@
pytest
.
mark
.
parametrize
(
"seqlen_q,seqlen_k"
,
[
(
1
,
239
),
(
3
,
799
),
(
127
,
512
),
(
127
,
513
),
(
113
,
203
),
(
128
,
217
),
(
113
,
211
),
(
108
,
256
),
(
256
,
512
),
(
1023
,
1024
),
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def
test_flash_attn_causal
(
seqlen_q
,
seqlen_k
,
swap_sq_sk
,
d
,
dtype
):
if
(
max
(
seqlen_q
,
seqlen_k
)
>=
2048
and
torch
.
cuda
.
get_device_properties
(
"cuda"
).
total_memory
<=
16
*
2
**
30
):
pytest
.
skip
()
# Reference implementation OOM
if
swap_sq_sk
:
seqlen_q
,
seqlen_k
=
seqlen_k
,
seqlen_q
device
=
"cuda"
causal
=
True
# set seed
torch
.
random
.
manual_seed
(
0
)
batch_size
=
16
nheads
=
9
q
=
torch
.
randn
(
batch_size
,
seqlen_q
,
nheads
,
d
,
device
=
device
,
dtype
=
dtype
,
requires_grad
=
True
)
k
=
torch
.
randn
(
batch_size
,
seqlen_k
,
nheads
,
d
,
device
=
device
,
dtype
=
dtype
,
requires_grad
=
True
)
v
=
torch
.
randn
(
batch_size
,
seqlen_k
,
nheads
,
d
,
device
=
device
,
dtype
=
dtype
,
requires_grad
=
True
)
out
=
flash_attn_func
(
q
,
k
,
v
,
0.0
,
causal
=
causal
)
out_ref
,
attn_ref
=
attention_ref
(
q
,
k
,
v
,
None
,
None
,
0.0
,
None
,
causal
=
causal
)
out_pt
,
attn_pt
=
attention_ref
(
q
,
k
,
v
,
None
,
None
,
0.0
,
None
,
causal
=
causal
,
upcast
=
False
,
reorder_ops
=
True
,
)
print
(
f
"Output max diff:
{
(
out
-
out_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"Output mean diff:
{
(
out
-
out_ref
).
abs
().
mean
().
item
()
}
"
)
print
(
f
"Pytorch max diff:
{
(
out_pt
-
out_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"Pytorch mean diff:
{
(
out_pt
-
out_ref
).
abs
().
mean
().
item
()
}
"
)
g
=
torch
.
randn_like
(
out
)
do_o
=
(
g
.
float
()
*
out
.
float
()).
sum
(
-
1
)
if
d
<=
MAX_HEADDIM_SM8x
or
(
is_sm80
or
is_sm90
):
(
dq
,
dk
,
dv
,
)
=
torch
.
autograd
.
grad
(
out
,
(
q
,
k
,
v
),
g
)
(
dq_ref
,
dk_ref
,
dv_ref
,
)
=
torch
.
autograd
.
grad
(
out_ref
,
(
q
,
k
,
v
),
g
)
(
dq_pt
,
dk_pt
,
dv_pt
,
)
=
torch
.
autograd
.
grad
(
out_pt
,
(
q
,
k
,
v
),
g
)
print
(
f
"dQ max diff:
{
(
dq
-
dq_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"dK max diff:
{
(
dk
-
dk_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"dV max diff:
{
(
dv
-
dv_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"dQ mean diff:
{
(
dq
-
dq_ref
).
abs
().
mean
().
item
()
}
"
)
print
(
f
"dK mean diff:
{
(
dk
-
dk_ref
).
abs
().
mean
().
item
()
}
"
)
print
(
f
"dV mean diff:
{
(
dv
-
dv_ref
).
abs
().
mean
().
item
()
}
"
)
print
(
f
"dQ Pytorch max diff:
{
(
dq_pt
-
dq_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"dK Pytorch max diff:
{
(
dk_pt
-
dk_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"dV Pytorch max diff:
{
(
dv_pt
-
dv_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"dQ Pytorch mean diff:
{
(
dq_pt
-
dq_ref
).
abs
().
mean
().
item
()
}
"
)
print
(
f
"dK Pytorch mean diff:
{
(
dk_pt
-
dk_ref
).
abs
().
mean
().
item
()
}
"
)
print
(
f
"dV Pytorch mean diff:
{
(
dv_pt
-
dv_ref
).
abs
().
mean
().
item
()
}
"
)
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
assert
(
out
-
out_ref
).
abs
().
max
().
item
()
<=
2
*
(
out_pt
-
out_ref
).
abs
().
max
().
item
()
+
1e-5
if
d
<=
MAX_HEADDIM_SM8x
or
(
is_sm80
or
is_sm90
):
assert
(
dq
-
dq_ref
).
abs
().
max
().
item
()
<=
2
*
(
dq_pt
-
dq_ref
).
abs
().
max
().
item
()
+
1e-5
assert
(
dk
-
dk_ref
).
abs
().
max
().
item
()
<=
2
*
(
dk_pt
-
dk_ref
).
abs
().
max
().
item
()
+
1e-5
assert
(
dv
-
dv_ref
).
abs
().
max
().
item
()
<=
2
*
(
dv_pt
-
dv_ref
).
abs
().
max
().
item
()
+
1e-5
@
pytest
.
mark
.
parametrize
(
"dtype"
,
([
torch
.
float16
]
if
is_sm75
else
[
torch
.
float16
,
torch
.
bfloat16
]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@
pytest
.
mark
.
parametrize
(
"d"
,
[
32
,
40
,
59
,
64
,
80
,
96
,
111
,
128
,
160
,
192
,
224
,
256
])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [128])
@
pytest
.
mark
.
parametrize
(
"swap_sq_sk"
,
[
False
,
True
])
# @pytest.mark.parametrize("swap_sq_sk", [True])
@
pytest
.
mark
.
parametrize
(
"seqlen_q,seqlen_k"
,
[
(
1
,
239
),
(
3
,
799
),
(
127
,
512
),
(
127
,
513
),
(
113
,
203
),
(
128
,
217
),
(
113
,
211
),
(
108
,
256
),
(
256
,
512
),
(
1023
,
1024
),
],
)
# @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    causal = True
    # set seed
    torch.random.manual_seed(0)
    batch_size = 16
    nheads = 9
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
        dq_pad_fn,
        dk_pad_fn,
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
    out_unpad = flash_attn_varlen_func(
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        0.0,
        causal=causal,
    )
    out = output_pad_fn(out_unpad)
    out_ref, attn_ref = attention_ref(
        q, k, v, query_padding_mask, key_padding_mask, 0.0, None, causal=causal
    )
    out_pt, attn_pt = attention_ref(
        q,
        k,
        v,
        query_padding_mask,
        key_padding_mask,
        0.0,
        None,
        causal=causal,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        (
            dq_unpad,
            dk_unpad,
            dv_unpad,
        ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
        dq = dq_pad_fn(dq_unpad)
        dk = dk_pad_fn(dk_unpad)
        dv = dk_pad_fn(dv_unpad)
        (
            dq_ref,
            dk_ref,
            dv_ref,
        ) = torch.autograd.grad(out_ref, (q, k, v), g)
        (
            dq_pt,
            dk_pt,
            dv_pt,
        ) = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5

    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5
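For reference, the variable-length entry point exercised by this test takes unpadded ("packed") tensors of shape (total_tokens, nheads, headdim) together with cumulative sequence lengths. Below is a minimal usage sketch, assuming equal-length sequences with no padding so the cu_seqlens tensor can be built with `torch.arange`, and assuming the top-level `flash_attn` package re-exports `flash_attn_varlen_func` (it also lives in `flash_attn.flash_attn_interface`); shapes and sizes are illustrative:

```python
import torch
from flash_attn import flash_attn_varlen_func

batch_size, seqlen, nheads, d = 2, 128, 4, 64
# Packed layout: (total_tokens, nheads, headdim) with total_tokens = batch_size * seqlen
q = torch.randn(batch_size * seqlen, nheads, d, device="cuda", dtype=torch.float16)
k = torch.randn(batch_size * seqlen, nheads, d, device="cuda", dtype=torch.float16)
v = torch.randn(batch_size * seqlen, nheads, d, device="cuda", dtype=torch.float16)
# Cumulative sequence lengths [0, 128, 256]: token boundaries of each sequence in the packed layout
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device="cuda")
# Positional 0.0 is dropout_p, as in the test call above
out = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, seqlen, seqlen, 0.0, causal=True)
# out has shape (total_tokens, nheads, headdim)
```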
# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
# @pytest.mark.parametrize('causal', [True])
# NOTE: this hunk shows two versions of the test interleaved: one takes a single seqlen and a
# packed qkv tensor, the other takes (seqlen_q, seqlen_k) and separate q/k/v tensors; adjacent
# near-duplicate lines below are the two sides of the diff.
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])
@pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [128])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
@pytest.mark.parametrize("seqlen", [128])
# @pytest.mark.parametrize('dropout_p', [0.0, 0.17])
@pytest.mark.parametrize("dropout_p", [0.0])
def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (239, 1),
        (3, 799),
        (799, 3),
        (1024, 128),
        (97, 97),
        (128, 128),
        (200, 200),
        (256, 256),
        (257, 257),
        (384, 384),
        (512, 512),
        (768, 768),
        (1024, 1024),
    ],
)
@pytest.mark.parametrize('dropout_p', [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.0])
def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype):
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 60  # Sometimes we need large batch size for the race conditions to trigger
    nheads = 4
    qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True)
    out0, lse0, _ = flash_attn_qkvpacked_func(qkv, dropout_p, return_attn_probs=True, causal=causal)
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    torch.random.manual_seed(42)
    out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
    g = torch.randn_like(out0)
    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        (dqkv0,) = torch.autograd.grad(out0, qkv, g)
        (
            dq0,
            dk0,
            dv0,
        ) = torch.autograd.grad(out0, (q, k, v), g)
        # Numerical error if we just do any arithmetic on dq
        dq_atol = 2 * ((dqkv0[:, :, 0] + 0.3 - 0.3) - dqkv0[:, :, 0]).abs().max().item()
        dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item()

    for i in range(200):
        torch.random.manual_seed(0)
        out, lse, S_dmask = flash_attn_qkvpacked_func(qkv, dropout_p, return_attn_probs=True, causal=causal)
    for i in range(250):
        torch.random.manual_seed(42)
        out, lse, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
        assert torch.equal(out, out0)
        assert torch.equal(lse, lse0)

        if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
            (dqkv,) = torch.autograd.grad(out, qkv, g)
            dq_equal = torch.allclose(dqkv[:, :, 0], dqkv0[:, :, 0], atol=dq_atol)
            (
                dq,
                dk,
                dv,
            ) = torch.autograd.grad(out, (q, k, v), g)
            dq_equal = torch.allclose(dq, dq0, atol=dq_atol)
            if not dq_equal:
                dq0 = dqkv0[:, :, 0]
                dq = dqkv[:, :, 0]
                print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dqkv[:, :, 0] - dqkv0[:, :, 0]).abs().max().item()}")
                print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}")
            assert torch.equal(dv, dv0)
            assert torch.equal(dk, dk0)
            assert dq_equal
            assert torch.equal(dqkv[:, :, 1], dqkv0[:, :, 1])
            assert torch.equal(dqkv[:, :, 2], dqkv0[:, :, 2])
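The race-condition test above re-runs identical forward and backward passes many times and requires reproducible results: exact equality for the output, the logsumexp, dK and dV, with dQ only required to match within a small tolerance. A minimal sketch of that reproducibility check for the forward pass only; the shapes, iteration count, and use of the top-level `flash_attn_func` export are illustrative assumptions:

```python
import torch
from flash_attn import flash_attn_func

q = torch.randn(2, 128, 4, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 128, 4, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 128, 4, 64, device="cuda", dtype=torch.float16)

torch.random.manual_seed(42)
out0 = flash_attn_func(q, k, v, 0.0, causal=True)  # positional 0.0 is dropout_p
for _ in range(10):
    torch.random.manual_seed(42)  # same seed so dropout (if enabled) would draw the same mask
    out = flash_attn_func(q, k, v, 0.0, causal=True)
    assert torch.equal(out, out0)  # the forward pass is expected to be bitwise reproducible
```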
@pytest.mark.parametrize("dtype", [torch.float16])
...
...
training/Dockerfile View file @ 9e5e8bc9
...
...
@@ -89,7 +89,7 @@ RUN pip install flash-attn==2.0.9
 # Install CUDA extensions for cross-entropy, fused dense, layer norm
 RUN git clone https://github.com/HazyResearch/flash-attention \
-    && cd flash-attention && git checkout v2.0.9 \
+    && cd flash-attention && git checkout v2.1.0 \
     && cd csrc/fused_softmax && pip install . && cd ../../ \
     && cd csrc/rotary && pip install . && cd ../../ \
     && cd csrc/xentropy && pip install . && cd ../../ \
...
...