gaoqiong / flash-attention · Commits

Commit 9e5e8bc9, authored Aug 21, 2023 by Tri Dao

Change causal mask to be aligned to bottom-right instead of top-left
Parent: e07aa036
Showing 11 changed files with 573 additions and 204 deletions (+573, -204)
README.md                                         +26   -0
benchmarks/benchmark_causal.py                    +10   -14
csrc/flash_attn/src/block_info.h                   +2   -2
csrc/flash_attn/src/flash_bwd_kernel.h            +19   -52
csrc/flash_attn/src/flash_fwd_kernel.h            +69   -22
csrc/flash_attn/src/flash_fwd_launch_template.h    +6   -8
csrc/flash_attn/src/softmax.h                     +22   -22
flash_attn/__init__.py                             +1   -1
flash_attn/flash_attn_interface.py                +48   -0
tests/test_flash_attn.py                         +369   -82
training/Dockerfile                                +1   -1
README.md  +26 -0

@@ -136,6 +136,32 @@ flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
```python
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
```

## Changes in v2.1 (compared to v2.0)

If seqlen_q != seqlen_k and causal=True, the causal mask is aligned to the
bottom right corner of the attention matrix, instead of the top-left corner.

For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:

v2.0:
    1 0 0 0 0
    1 1 0 0 0
v2.1:
    1 1 1 1 0
    1 1 1 1 1

If seqlen_q = 5 and seqlen_k = 2, the causal mask is:

v2.0:
    1 0
    1 1
    1 1
    1 1
    1 1
v2.1:
    0 0
    0 0
    0 0
    1 0
    1 1

If the row of the mask is all zero, the output will be zero.

## Performance
...
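For reference, a minimal PyTorch sketch (not part of this commit; the helper name is made up for illustration) that reproduces the v2.1 bottom-right-aligned mask described above:

```python
import torch

def bottom_right_causal_mask(seqlen_q: int, seqlen_k: int) -> torch.Tensor:
    """True = keep, False = masked out, aligned to the bottom-right corner."""
    row = torch.arange(seqlen_q).unsqueeze(1)   # query index i
    col = torch.arange(seqlen_k).unsqueeze(0)   # key index j
    # Query i may attend to key j iff j <= i + seqlen_k - seqlen_q.
    return col <= row + (seqlen_k - seqlen_q)

print(bottom_right_causal_mask(2, 5).int())  # rows: 1 1 1 1 0 / 1 1 1 1 1
print(bottom_right_causal_mask(5, 2).int())  # first three rows all zero, then 1 0 / 1 1
```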
benchmarks/benchmark_causal.py  +10 -14

@@ -15,12 +15,7 @@ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
 # from triton.ops.flash_attention import attention as attention_triton
-try:
-    from fav2 import flash_attn_qkvpacked_func as fav2_qkvpacked_func
-    from fav2 import flash_attn_kvpacked_func as fav2_kvpacked_func
-except ImportError:
-    fav2_qkvpacked_func = None
-    fav2_kvpacked_func = None
+from flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func
 try:
     from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax
...
@@ -80,8 +75,8 @@ def attention_megatron(qkv):
 torch.manual_seed(0)
 repeats = 30
-batch_size = 2
-seqlen = 8192
+batch_size = 8
+seqlen = 2048
 nheads = 12
 headdim = 128
 # nheads = 24
...
@@ -90,8 +85,8 @@ headdim = 128
 # seqlen = 512
 # nheads = 8
 # headdim = 128
-dropout_p = 0.1
-causal = False
+dropout_p = 0.0
+causal = True
 dtype = torch.float16
 device = 'cuda'
...
@@ -100,20 +95,20 @@ qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=d
 cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device)
-# qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
+qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
 # benchmark_all(flash_attn_varlen_qkvpacked_func, qkv_unpad,
 #               cu_seqlens, seqlen, dropout_p, causal=causal, repeats=repeats, desc='FlashAttention')
 # pytorch_profiler(flash_attn_varlen_qkvpacked_func, qkv_unpad,
 #                  cu_seqlens, seqlen, dropout_p, causal=causal, backward=True)
-# if fav2_qkvpacked_func is not None:
-#     benchmark_all(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')
-#     pytorch_profiler(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True)
+benchmark_forward(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')
+pytorch_profiler(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, backward=False)
 # for dropout_p in [0.1, 0.0]:
 #     for causal in [False, True]:
 #         print(f"### {dropout_p = }, {causal = } ###")
 #         pytorch_profiler(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True)
 # nheads_k = 2
 # q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
 # kv = torch.randn(batch_size, seqlen, 2, nheads_k, headdim, device=device, dtype=dtype,
...
@@ -151,6 +146,7 @@ cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch
 flops = 4 * batch_size * seqlen ** 2 * nheads * headdim
 ideal_a100_time = flops / 312 / 1e9
 print(f"Ideal A100 fwd time: {ideal_a100_time:.3f} ms, bwd time: {ideal_a100_time * 2.5:.3f} ms")
+exit(0)

 def time_fwd_bwd(func, *args, **kwargs):
...
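As a side note on the ideal-time estimate in the last hunk above (not part of the commit): attention does two GEMMs per head, QK^T and PV, each seqlen^2 * headdim multiply-accumulates, hence the 4 * batch * seqlen^2 * nheads * headdim FLOP count; dividing by the A100's roughly 312 TFLOP/s fp16 tensor-core peak gives the printed time. A quick arithmetic sketch:

```python
batch_size, seqlen, nheads, headdim = 8, 2048, 12, 128
# Two matmuls (Q @ K^T and P @ V), 2 FLOPs per multiply-add each.
flops = 4 * batch_size * seqlen ** 2 * nheads * headdim        # ~2.06e11 FLOPs
ideal_a100_time_ms = flops / 312e9                             # 312 GFLOP per millisecond peak
print(f"Ideal A100 fwd time: {ideal_a100_time_ms:.3f} ms")     # ~0.661 ms; bwd is taken as 2.5x
```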
csrc/flash_attn/src/block_info.h  +2 -2

@@ -32,8 +32,8 @@ struct BlockInfo {
     const int sum_s_q;
     const int sum_s_k;
-    const uint32_t actual_seqlen_q;
-    const uint32_t actual_seqlen_k;
+    const int actual_seqlen_q;
+    const int actual_seqlen_k;
 };

////////////////////////////////////////////////////////////////////////////////////////////////////
...
csrc/flash_attn/src/flash_bwd_kernel.h  +19 -52

@@ -659,46 +659,14 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
    tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.d_rounded;

    int m_block = m_block_max - 1;
-    int m_block_min = !Is_causal ? 0 : (n_block * kBlockN - int(binfo.actual_seqlen_k - binfo.actual_seqlen_q)) / kBlockM;
-    m_block_min = m_block_min < 0 ? 0 : m_block_min;
-    // We might need to exit early and write 0 to dK and dV.
-    // Otherwise we get wrong result for the case where we don't enter the for loop.
-    // And we might read OOB elements from gQ and gdO.
-    // TODO: what if we're not parallelizing, do we need to compute dot_do_o?
-    if (Is_causal && m_block < m_block_min) {
-        const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
-            + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
-        const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
-            + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
-        Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
-                                 Shape<Int<kBlockN>, Int<kHeadDim>>{},
-                                 make_stride(params.dk_row_stride, _1{}));
-        Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
-                                 Shape<Int<kBlockN>, Int<kHeadDim>>{},
-                                 make_stride(params.dv_row_stride, _1{}));
-        typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV;
-        auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
-        Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
-        Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
-        Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
-        Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
-        clear(tdKrdK);
-        clear(tdVrdV);
-        Tensor cdKV = make_identity_tensor(make_shape(size<0>(gdK), size<1>(gdK)));    // (BLK_N,BLK_K) -> (blk_n,blk_k)
-        Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
-        Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
-        #pragma unroll
-        for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
-        // Clear_OOB_K must be false since we don't want to write zeros to gmem
-        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-            gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN);
-        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-            gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN);
-        return;
-    }
+    int m_block_min = !Is_causal ? 0 : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k) / kBlockM);
+    // We're guaranteed that m_block_min <= m_block:
+    // We checked earlier that n_block * kBlockN < actual_seqlen_k, so in the causal case,
+    // n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k < actual_seqlen_q.
+    // So m_block_min <= (actual_seqlen_q - 1) / kBlockM.
+    // Recall that m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM) = (actual_seqlen_q + kBlockM - 1) / kBlockM.
+    // So m_block_max - 1 = (actual_seqlen_q - 1) / kBlockM.
+    // We conclude that m_block_min <= m_block, so we will always have at least 1 iteration of the for loop.

    if (Double_buffer && m_block % 2 == 1) {  // Double buffer for sQ
        tQsQ.data() = tQsQ.data() + size(sQ);
...
@@ -743,7 +711,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
    Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
    #pragma unroll
    for (int mi = 0; mi < size(lse); ++mi) {
        // Using uint32_t row makes it 10us slower on d=128, not sure why.
        const int row = get<0>(taccScS_row(mi));
        lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0;
    }
...
@@ -824,11 +791,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
            // TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short
            // (e.g., 256 and 2), the 2nd block of seqlen_q (from 128 to 255), we're not doing causal masking.
            // But we still want to mask out elements beyond actual_seqlen_k.
-            if (m_block * kBlockM < (n_block + 1) * kBlockN
+            if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k
                || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
                flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
-                                         binfo.actual_seqlen_q, binfo.actual_seqlen_k,
-                                         m_block * kBlockM + get<0>(taccScS_row(0)),
+                                         binfo.actual_seqlen_k,
+                                         m_block * kBlockM + get<0>(taccScS_row(0)), binfo.actual_seqlen_q,
                                         // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
                                         AtomLayoutMS * 16);
            }
...
@@ -837,11 +804,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
            // Compute the exponential value.
            flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
            if (Is_dropout) {
-                uint32_t warp_id = tidx / 32;
-                uint32_t block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
+                int warp_id = tidx / 32;
+                int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
                // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
                static_assert(MMA_N_SdP % 2 == 0);
-                uint32_t block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
+                int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
                Tensor scores_dropped = make_tensor(scores.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMmaSdP>(scores.layout()));
                flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(scores_dropped, params.p_dropout_in_uint8_t,
                                                                          seed, offset,
...
@@ -1341,7 +1308,6 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
    Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
    #pragma unroll
    for (int mi = 0; mi < size(lse); ++mi) {
        // Using uint32_t row makes it 10us slower on d=128, not sure why.
        const int row = get<0>(taccScS_row(mi));
        lse(mi) = row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0;
    }
...
@@ -1379,18 +1345,19 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
            // the corresponding values of K would be 0, so the result would still be correct.
            if (Is_causal && m_block * kBlockM < (n_block + 1) * kBlockN) {
                flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
-                                         binfo.actual_seqlen_q, binfo.actual_seqlen_k,
-                                         m_block * kBlockM + get<0>(taccScS_row(0)),
+                                         binfo.actual_seqlen_k,
+                                         m_block * kBlockM + get<0>(taccScS_row(0)),
                                         // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
+                                         binfo.actual_seqlen_q,
                                         AtomLayoutMS * 16);
            }
            // Compute the exponential value.
            flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
            if (Is_dropout) {
-                uint32_t warp_id = tidx / 32;
-                uint32_t block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
+                int warp_id = tidx / 32;
+                int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
                // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
                static_assert(MMA_N_SdP % 2 == 0);
-                uint32_t block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
+                int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
                Tensor scores_dropped = make_tensor(scores.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMmaSdP>(scores.layout()));
                flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(scores_dropped, params.p_dropout_in_uint8_t,
                                                                          seed, offset,
...
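A rough Python sketch (illustration only; the block sizes here are arbitrary defaults, not the kernel's) of the query-block range the backward pass now covers per key block: with the bottom-right-aligned mask, the first query row that can attend to key column n_block * kBlockN is n_block * kBlockN + seqlen_q - seqlen_k, so m_block_min never exceeds the last query block and the old early exit that wrote zeros to dK and dV is no longer needed.

```python
def m_block_range(n_block, seqlen_q, seqlen_k, kBlockM=64, kBlockN=64, causal=True):
    """Query-block range [m_block_min, m_block_max) visited for one key block."""
    m_block_max = (seqlen_q + kBlockM - 1) // kBlockM
    if not causal:
        return 0, m_block_max
    # Bottom-right alignment: query i sees key j iff j <= i + seqlen_k - seqlen_q,
    # so the first query row covering this key block is its start column shifted by seqlen_q - seqlen_k.
    m_block_min = max(0, (n_block * kBlockN + seqlen_q - seqlen_k) // kBlockM)
    return m_block_min, m_block_max

# The kernel only reaches this point when n_block * kBlockN < seqlen_k, which is what
# guarantees m_block_min <= m_block_max - 1 (at least one loop iteration).
print(m_block_range(n_block=1, seqlen_q=256, seqlen_k=128))  # (3, 4)
```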
csrc/flash_attn/src/flash_fwd_kernel.h  +69 -22

@@ -118,7 +118,7 @@ inline __device__ void write_softmax_to_gmem(
////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {

    using Element = typename Kernel_traits::Element;
...
@@ -130,8 +130,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    // The thread index.
    const int tidx = threadIdx.x;
-    // The global block index.
-    const int block_id = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z;

    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kBlockN = Kernel_traits::kBlockN;
...
@@ -139,16 +137,60 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    constexpr int kNWarps = Kernel_traits::kNWarps;
    constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;

-    const BlockInfo</*Varlen=*/!Is_even_N> binfo(params, bidb);
+    const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
    if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;

    int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
    if (Is_causal) {
-        n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM + int(binfo.actual_seqlen_k - binfo.actual_seqlen_q), kBlockN));
+        n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN));
        // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
        //     printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
        // }
+        // We exit early and write 0 to gO and gLSE.
+        // Otherwise we might read OOB elements from gK and gV.
+        if (n_block_max <= 0) {
+            // Save seed and offset for backward. If we don't have this here, the 0-th thread block might
+            // exit early and no one saves the rng state.
+            if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
+                auto seeds = at::cuda::philox::unpack(params.philox_args);
+                params.rng_state[0] = std::get<0>(seeds);
+                params.rng_state[1] = std::get<1>(seeds);
+            }
+            const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+                + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
+            const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
+            Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
+                                    Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                                    make_stride(params.o_row_stride, _1{}));
+            Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
+                                      Shape<Int<kBlockM>>{}, Stride<_1>{});
+            typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
+            auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
+            Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
+            Tensor tOrO = make_tensor<Element>(shape(tOgO));
+            clear(tOrO);
+            // Construct identity layout for sO
+            Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
+            // Repeat the partitioning with identity layouts
+            Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
+            Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
+            if (!Is_even_K) {
+                #pragma unroll
+                for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
+            }
+            // Clear_OOB_K must be false since we don't want to write zeros to gmem
+            flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+                gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM);
+            #pragma unroll
+            for (int m = 0; m < size<1>(tOgO); ++m) {
+                const int row = get<0>(tOcO(0, m, 0));
+                if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; }
+            }
+            return;
+        }
    }
    // We iterate over the blocks in reverse order. This is because the last block is the only one
...
@@ -275,8 +317,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    Tensor tQrQ = make_fragment_like(tQgQ);
    // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
-    flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
-                                                 binfo.actual_seqlen_q - m_block * kBlockM);
+    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
+                                       binfo.actual_seqlen_q - m_block * kBlockM);
    if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }

    // // Copy rmem to smem
...
@@ -298,8 +340,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    int n_block = n_block_max - 1;
    // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
-    flash::copy<Is_even_N, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
-                                      binfo.actual_seqlen_k - n_block * kBlockN);
+    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
+                                       binfo.actual_seqlen_k - n_block * kBlockN);
    cute::cp_async_fence();

    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
    // __syncthreads();
...
@@ -317,7 +359,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;

    // Save seed and offset for backward.
-    if (block_id == 0 && tidx == 0) {
+    if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
        params.rng_state[0] = seed;
        params.rng_state[1] = std::get<1>(seeds);
    }
...
@@ -330,7 +372,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
    // We will have at least 1 "masking" iteration.
-    constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1;
+    // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
+    // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
+    constexpr int n_masking_steps = !Is_causal
+        ? 1
+        : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
    #pragma unroll
    for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
        Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
...
@@ -344,7 +390,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
        } else {
            // Clear the smem tiles to account for predicated off loads
-            flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>(
+            flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
                gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN);
        }
...
@@ -363,7 +409,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
        // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
        // can produce Inf / NaN.
        if (!Is_causal) {
-            if (!Is_even_N) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
+            if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
        } else {
            // Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});    // (BLK_M,BLK_N) -> (blk_m,blk_n)
            // Tensor taccScS = thr_mma.partition_C(caccS);    // (MMA,MMA_M,MMA_N)
...
@@ -376,9 +422,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
            // Idk why it's get<1> and not get<0> of the stride.
            // if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
            // I can't get the stride from idx_row
-            flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_q, binfo.actual_seqlen_k,
+            flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
                                     // m_block * kBlockM + get<0>(idx_row(0)),
                                     m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                                     binfo.actual_seqlen_q,
                                     kNWarps * 16);
                                     // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
                                     // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
...
@@ -405,8 +452,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
        // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
        // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
-        uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
-        uint32_t block_col_idx = n_block * (kBlockN / 32);
+        int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
+        int block_col_idx = n_block * (kBlockN / 32);
        if (Return_softmax) {
            Tensor tOrP_copy = make_fragment_like(tOrP);
            cute::copy(tOrP, tOrP_copy);
...
@@ -468,8 +515,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
        // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
        // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
-        uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
-        uint32_t block_col_idx = n_block * (kBlockN / 32);
+        int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
+        int block_col_idx = n_block * (kBlockN / 32);
        if (Return_softmax) {
            Tensor tOrP_copy = make_fragment_like(tOrP);
            cute::copy(tOrP, tOrP_copy);
...
@@ -563,14 +610,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
    }
    // Clear_OOB_K must be false since we don't want to write zeros to gmem
-    flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn(const Params &params) {
    const int m_block = blockIdx.x;
    // The block index for the batch.
...
@@ -586,7 +633,7 @@ inline __device__ void compute_attn(const Params &params) {
    // the attention matrix. This way, as long as we have the batch, head, and the location of
    // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
-    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
+    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
}

////////////////////////////////////////////////////////////////////////////////////////////////////
...
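Similarly, a small sketch (not from the commit; block sizes are illustrative) of the key-block bound computed in the forward kernel above: under the bottom-right-aligned mask a query block can end up with no visible key blocks at all when seqlen_q > seqlen_k, which is exactly the new n_block_max <= 0 path that writes O = 0 and LSE = inf and returns early.

```python
import math

def n_block_max(m_block, seqlen_q, seqlen_k, kBlockM=128, kBlockN=128, causal=True):
    """Number of key blocks visited for one query block (illustrative block sizes)."""
    nmax = math.ceil(seqlen_k / kBlockN)
    if causal:
        # Last visible key column for the last row of this query block:
        # j <= i + seqlen_k - seqlen_q with i = (m_block + 1) * kBlockM - 1.
        nmax = min(nmax, math.ceil(((m_block + 1) * kBlockM + seqlen_k - seqlen_q) / kBlockN))
    return nmax

print(n_block_max(0, seqlen_q=512, seqlen_k=128))  # -2, i.e. <= 0: early exit, zero output
print(n_block_max(3, seqlen_q=512, seqlen_k=128))  # 1: only the last query block attends to keys
```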
csrc/flash_attn/src/flash_fwd_launch_template.h  +6 -8

@@ -10,9 +10,9 @@
#include "flash.h"
#include "flash_fwd_kernel.h"

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
__global__ void flash_fwd_kernel(Flash_fwd_params params) {
-    flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params);
+    flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Return_softmax>(params);
}

template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
...
@@ -26,17 +26,15 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid(num_m_block, params.b, params.h);
-    // We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check
-    // for cu_seqlens_q as well.
-    const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0;
+    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    const bool return_softmax = params.p_ptr != nullptr;
-    BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
+    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
        BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
            BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
                // Will only return softmax if dropout, to reduce compilation time.
-                auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
-                // auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true, ReturnSoftmaxConst && Is_dropout>;
+                auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
+                // auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, true, ReturnSoftmaxConst && Is_dropout>;
                if (smem_size >= 48 * 1024) {
                    C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
...
csrc/flash_attn/src/softmax.h  +22 -22

@@ -117,18 +117,18 @@ inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tens
}

template <typename Engine, typename Layout>
-inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t max_seqlen_k,
-                                  const uint32_t col_idx_offset_ = 0) {
+inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
+                                  const int col_idx_offset_ = 0) {
    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
-    const uint32_t lane_id = threadIdx.x % 32;
-    const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+    const int lane_id = threadIdx.x % 32;
+    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
    #pragma unroll
    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-        const uint32_t col_idx_base = col_idx_offset + nj * 8;
+        const int col_idx_base = col_idx_offset + nj * 8;
        #pragma unroll
        for (int j = 0; j < size<1, 0>(tensor); ++j) {
-            const uint32_t col_idx = col_idx_base + j;
+            const int col_idx = col_idx_base + j;
            if (col_idx >= max_seqlen_k) {
                // Without the "make_coord" we get wrong results
                #pragma unroll
...
@@ -141,28 +141,28 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t
}

template <typename Engine, typename Layout>
-inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const uint32_t col_idx_offset_,
-                                         const uint32_t max_seqlen_q, const uint32_t max_seqlen_k,
-                                         const uint32_t row_idx_offset_, const uint32_t warp_row_stride) {
+inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
+                                         const int max_seqlen_k, const int row_idx_offset_,
+                                         const int max_seqlen_q, const int warp_row_stride) {
    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
-    const uint32_t lane_id = threadIdx.x % 32;
-    // const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4;
-    const uint32_t row_idx_offset = row_idx_offset_;
-    const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+    const int lane_id = threadIdx.x % 32;
+    // const int row_idx_offset = row_idx_offset_ + lane_id / 4;
+    const int row_idx_offset = row_idx_offset_;
+    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
    #pragma unroll
    for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
-        const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride;
+        const int row_idx_base = row_idx_offset + mi * warp_row_stride;
        #pragma unroll
        for (int i = 0; i < size<0, 0>(tensor); ++i) {
-            const uint32_t row_idx = row_idx_base + i * 8;
-            const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q);
+            const int row_idx = row_idx_base + i * 8;
+            const int col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q);
            #pragma unroll
            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-                const uint32_t col_idx_base = col_idx_offset + nj * 8;
+                const int col_idx_base = col_idx_offset + nj * 8;
                #pragma unroll
                for (int j = 0; j < size<1, 0>(tensor); ++j) {
-                    const uint32_t col_idx = col_idx_base + j;
+                    const int col_idx = col_idx_base + j;
                    if (col_idx >= col_idx_limit) {
                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                    }
...
@@ -180,7 +180,7 @@ inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const u
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void apply_mask_causal_w_idx(
    Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
-    const uint32_t col_idx_offset_, const uint32_t max_seqlen_k, const uint32_t row_idx_offset_)
+    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset_)
{
    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
...
@@ -189,7 +189,7 @@ inline __device__ void apply_mask_causal_w_idx(
    CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
-        const uint32_t col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
+        const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
        #pragma unroll
        for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
            if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
...
@@ -207,8 +207,8 @@ inline __device__ void apply_mask_causal_w_idx(
template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor, uint8_t p_dropout_in_uint8_t,
                                     unsigned long long seed, unsigned long long offset,
-                                     uint32_t block_row_start, uint32_t block_col_start,
-                                     uint32_t block_row_stride) {
+                                     int block_row_start, int block_col_start,
+                                     int block_row_stride) {
    // tensor has shape (8, MMA_M, MMA_N / 2)
    using T = typename Engine::value_type;
    auto encode_dropout = [](bool keep, T val) {
...
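For intuition about the new apply_mask_causal above, a simple PyTorch reference (illustration only; it operates on a whole [seqlen_q, seqlen_k] score matrix rather than per-thread register fragments): each query row keeps key columns below col_idx_limit = min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q), which is the bottom-right alignment.

```python
import torch

def apply_mask_causal_ref(scores: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int) -> torch.Tensor:
    """scores: [max_seqlen_q, max_seqlen_k]; masked entries are set to -inf before softmax."""
    for row_idx in range(max_seqlen_q):
        col_idx_limit = min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q)
        scores[row_idx, max(0, col_idx_limit):] = float('-inf')
    return scores

print(apply_mask_causal_ref(torch.zeros(2, 5), 2, 5))
# row 0 masks only the last key column, row 1 masks nothing -> matches the README example
```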
flash_attn/__init__.py  +1 -1

-__version__ = "2.0.9"
+__version__ = "2.1.0"

from flash_attn.flash_attn_interface import (
    flash_attn_func,
...
flash_attn/flash_attn_interface.py  +48 -0

@@ -528,6 +528,18 @@ def flash_attn_kvpacked_func(
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
kv: (batch_size, seqlen, 2, nheads_k, headdim)
...
...
@@ -559,6 +571,18 @@ def flash_attn_func(
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim)
...
...
@@ -645,6 +669,18 @@ def flash_attn_varlen_kvpacked_func(
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
...
...
@@ -703,6 +739,18 @@ def flash_attn_varlen_func(
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
...
...
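A short usage sketch of the interface documented above (not part of the diff; assumes a CUDA device with fp16 support), showing the bottom-right-aligned causal behavior when seqlen_q != seqlen_k:

```python
import torch
from flash_attn import flash_attn_func

batch, nheads, headdim = 2, 8, 64
q = torch.randn(batch, 2, nheads, headdim, device='cuda', dtype=torch.float16)
k = torch.randn(batch, 5, nheads, headdim, device='cuda', dtype=torch.float16)
v = torch.randn(batch, 5, nheads, headdim, device='cuda', dtype=torch.float16)

# With causal=True and seqlen_q (2) != seqlen_k (5), the mask is aligned to the
# bottom-right corner: query 0 sees keys 0-3, query 1 sees keys 0-4.
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
print(out.shape)  # torch.Size([2, 2, 8, 64])
```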
tests/test_flash_attn.py  +369 -82

(This diff is collapsed.)
training/Dockerfile  +1 -1

@@ -89,7 +89,7 @@ RUN pip install flash-attn==2.0.9
# Install CUDA extensions for cross-entropy, fused dense, layer norm
RUN git clone https://github.com/HazyResearch/flash-attention \
-    && cd flash-attention && git checkout v2.0.9 \
+    && cd flash-attention && git checkout v2.1.0 \
    && cd csrc/fused_softmax && pip install . && cd ../../ \
    && cd csrc/rotary && pip install . && cd ../../ \
    && cd csrc/xentropy && pip install . && cd ../../ \
...