gaoqiong / flash-attention / Commits / 66a127ae

Commit 66a127ae authored Jan 20, 2024 by Tri Dao

Refactor masking in fwd pass into 1 object

parent ed4959b2
Changes 5
Showing 5 changed files with 169 additions and 147 deletions (+169 -147)
csrc/flash_attn/src/dropout.h           +2    -3
csrc/flash_attn/src/flash_bwd_kernel.h  +6    -6
csrc/flash_attn/src/flash_fwd_kernel.h  +41   -120
csrc/flash_attn/src/mask.h              +103  -0
csrc/flash_attn/src/utils.h             +17   -18
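The refactor replaces the separate flash::Alibi object and the free functions flash::apply_mask / flash::apply_mask_local with a single flash::Mask<Is_causal, Is_local, Has_alibi> object that is constructed once per row block and called in both phases of the main loop. Condensed from the flash_fwd_kernel.h hunks below (an excerpt of the changed lines, not compilable on its own):

    flash::Mask<Is_causal, Is_local, Has_alibi> mask(
        binfo.actual_seqlen_k, binfo.actual_seqlen_q,
        params.window_size_left, params.window_size_right, alibi_slope);

    // Iterations that need masking on S:
    mask.template apply_mask<Is_causal, Is_even_MN>(
        acc_s, n_block * kBlockN,
        m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);

    // Steady-state iterations: same call site, causal masking compiled out:
    mask.template apply_mask</*Causal_mask=*/false>(
        acc_s, n_block * kBlockN,
        m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);

The mask now also operates directly on the raw accumulator acc_s (layout (MMA=4, MMA_M, MMA_N)), so the intermediate convert_layout_acc_rowcol reshape into a separate "scores" view disappears from the forward pass.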
csrc/flash_attn/src/dropout.h  View file @ 66a127ae

...
@@ -25,9 +25,8 @@ struct Dropout {
     template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
     __forceinline__ __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
                                          int block_row_start, int block_col_start, int block_row_stride) {
-        // tensor_ has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
-        Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_rowcol_dropout(tensor_.layout()));
-        // tensor has shape (8, MMA_M, MMA_N / 2)
+        // convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2)
+        Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_dropout(tensor_.layout()));
         using T = typename Engine::value_type;
         auto encode_dropout = [](bool keep, T val) {
             return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
...
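apply_dropout now receives the accumulator in its native (MMA=4, MMA_M, MMA_N) layout and regroups it to (8, MMA_M, MMA_N / 2) with convert_layout_acc_dropout, instead of going through the rowcol view first. Below is a minimal host-side sketch of that index regrouping, assuming the fused first mode orders the four accumulator values of a tile before the index of the two fused N tiles; it is illustrative only, the real code manipulates cute layouts, and the MMA_M / MMA_N values are arbitrary test sizes.

    // Host-side sketch of the (4, MMA_M, MMA_N) -> (8, MMA_M, MMA_N/2) regrouping that
    // convert_layout_acc_dropout performs. Illustrative only; the real code builds a
    // cute layout, it does not compute flat indices.
    #include <cassert>

    // Flat index of element (r, m, n) in a (4, MMA_M, MMA_N) accumulator, r fastest.
    int acc_index(int r, int m, int n, int MMA_M) { return r + 4 * (m + MMA_M * n); }

    // The regrouped view (8, MMA_M, MMA_N/2): the extra factor of 2 in the first mode
    // comes from fusing two adjacent N tiles (n = 2*n2 and n = 2*n2 + 1).
    int dropout_index(int r8, int m, int n2, int MMA_M) {
        int r = r8 % 4;           // position within the original 4-element accumulator group
        int n = 2 * n2 + r8 / 4;  // which of the two fused N tiles
        return acc_index(r, m, n, MMA_M);
    }

    int main() {
        const int MMA_M = 2, MMA_N = 4;
        // Every element of the (8, MMA_M, MMA_N/2) view maps to a distinct accumulator slot.
        bool seen[4 * 2 * 4] = {};
        for (int n2 = 0; n2 < MMA_N / 2; ++n2)
            for (int m = 0; m < MMA_M; ++m)
                for (int r8 = 0; r8 < 8; ++r8) {
                    int idx = dropout_index(r8, m, n2, MMA_M);
                    assert(!seen[idx]);
                    seen[idx] = true;
                }
        return 0;
    }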
csrc/flash_attn/src/flash_bwd_kernel.h  View file @ 66a127ae

...
@@ -527,16 +527,16 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             static_assert(MMA_N_SdP % 2 == 0);
             int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
             dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-                scores, block_row_idx, block_col_idx, AtomLayoutMS
+                acc_s, block_row_idx, block_col_idx, AtomLayoutMS
             );
         }
         // Convert scores from fp32 to fp16/bf16
-        Tensor rP = !Is_dropout ? flash::convert_type<Element>(scores) : flash::convert_type_relu<Element>(scores);
-        // Reshape rP from (nrow=(2, MMA_N), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_N, MMA_N / 2)
-        // if using m16n8k16 or ((2, 2, 1), MMA_N, MMA_N) if using m16n8k8.
-        Tensor tPrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMmaSdP>(rP.layout()));
+        Tensor rP = !Is_dropout ? flash::convert_type<Element>(acc_s) : flash::convert_type_relu<Element>(acc_s);
+        // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_N, MMA_N / 2)
+        // if using m16n8k16 or (4, MMA_N, MMA_N) if using m16n8k8.
+        Tensor tPrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMmaSdP>(rP.layout()));
         Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP);     // ((Atom,AtomNum), MMA_N, MMA_N)
         cute::copy(smem_tiled_copy_PdS, tPaP, tPsP);
         // if (cute::thread0()) { print(tPaP); }
...
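In the backward kernel the dropout mask is applied to acc_s with encode_dropout_in_sign_bit=true, and the fp32 to fp16/bf16 conversion switches to convert_type_relu when dropout is enabled. The encode lambda in the dropout.h hunk above negates dropped values, which suggests the ReLU applied during conversion is what finally zeroes them; the kept values are non-negative post-softmax probabilities, so the sign bit is free to carry the mask. A host-side sketch of that two-step idea, with plain floats standing in for cute tensors and a hard-coded keep mask standing in for the Philox-generated bits; this is a guess at the intent, not the kernel's actual data flow:

    #include <algorithm>
    #include <cstdio>

    int main() {
        float scores[8] = {0.1f, 0.4f, 0.2f, 0.9f, 0.3f, 0.5f, 0.7f, 0.6f};
        bool  keep[8]   = {true, false, true, true, false, true, false, true};

        // Step 1 (apply_dropout with encode_dropout_in_sign_bit=true): dropped entries
        // are negated rather than zeroed, mirroring `keep ? val : -val` in the encode lambda.
        for (int i = 0; i < 8; ++i) {
            if (!keep[i]) { scores[i] = -scores[i]; }
        }
        // Step 2 (convert_type_relu analogue): clamping at zero during the down-conversion
        // turns the negated (dropped) entries into 0 while kept entries pass through.
        for (int i = 0; i < 8; ++i) {
            float converted = std::max(scores[i], 0.0f);
            std::printf("%d: %.4f\n", i, converted);
        }
        return 0;
    }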
csrc/flash_attn/src/flash_fwd_kernel.h  View file @ 66a127ae

...
@@ -265,8 +265,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     flash::Softmax<2 * size<1>(acc_o)> softmax;

-    const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
-    flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
+    const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
+    flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);

     // For performance reason, we separate out two kinds of iterations:
     // those that need masking on S, and those that don't.
...
@@ -304,43 +304,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         );
         // if (cute::thread0()) { print(acc_s); }
-        // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
-        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
-        // if (cute::thread0()) { print_tensor(scores); }
         // We don't put the masking before the matmul S = Q K^T because we don't clear sK
         // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
         // can produce Inf / NaN.
-        if (Has_alibi) {
-            alibi.apply_alibi(scores, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
-        }
-        if (!Is_causal && !Is_local) {
-            if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
-        } else {
-            // Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});    // (BLK_M,BLK_N) -> (blk_m,blk_n)
-            // Tensor taccScS = thr_mma.partition_C(caccS);                           // (MMA,MMA_M,MMA_N)
-            // static_assert(decltype(size<0>(taccScS))::value == 4);
-            // // Convert to ((2, 2), MMA_M, MMA_N) then take only the row indices.
-            // Tensor idx_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0);
-            // Tensor idx_rowcol = make_tensor(taccScS.data(), flash::convert_layout_acc_rowcol(taccScS.layout()));
-            // flash::apply_mask_causal_w_idx(scores, idx_rowcol, n_block * kBlockN, binfo.actual_seqlen_k,
-            //                                m_block * kBlockM);
-            // Idk why it's get<1> and not get<0> of the stride.
-            // if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
-            // I can't get the stride from idx_row
-            flash::apply_mask_local</*HasWSLeft=*/Is_local>(
-                scores, n_block * kBlockN, binfo.actual_seqlen_k,
-                // m_block * kBlockM + get<0>(idx_row(0)),
-                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                binfo.actual_seqlen_q, kNWarps * 16,
-                params.window_size_left, params.window_size_right
-                // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16
-                // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16
-            );
-            // if (cute::thread0()) { print_tensor(scores); }
-        }
+        mask.template apply_mask<Is_causal, Is_even_MN>(
+            acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
+        );

         flash::cp_async_wait<0>();
         __syncthreads();
...
@@ -358,26 +324,26 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             ? softmax.template softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2)
             : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2);

-        // Convert scores from fp32 to fp16/bf16
-        Tensor rP = flash::convert_type<Element>(scores);
+        // Convert acc_s from fp32 to fp16/bf16
+        Tensor rP = flash::convert_type<Element>(acc_s);
         int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
         int block_col_idx = n_block * (kBlockN / 32);
         if (Return_softmax) {
-            Tensor acc_s_f16 = flash::convert_type<Element>(acc_s);
-            Tensor acc_s_f16_drop = make_tensor(acc_s_f16.data(), rP.layout());
+            Tensor rP_drop = make_fragment_like(rP);
+            cute::copy(rP, rP_drop);
             dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-                acc_s_f16_drop, block_row_idx, block_col_idx, kNWarps
+                rP_drop, block_row_idx, block_col_idx, kNWarps
             );
-            cute::copy(acc_s_f16, tSgS);
+            cute::copy(rP_drop, tSgS);
             tSgS.data() = tSgS.data() + (-kBlockN);
         }
         if (Is_dropout) {
             dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
         }

-        // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
-        // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
-        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+        // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+        // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
+        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
         // if (cute::thread0()) { print(tOrP); }
         flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
         // if (cute::thread0()) { print(scores); }
...
@@ -416,44 +382,31 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             cute::cp_async_fence();
         }

-        // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
-        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
-        if (Has_alibi) {
-            alibi.apply_alibi(scores, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
-        }
-        if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
-            flash::apply_mask_local(
-                scores, n_block * kBlockN, binfo.actual_seqlen_k,
-                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                binfo.actual_seqlen_q, kNWarps * 16,
-                params.window_size_left, params.window_size_right
-            );
-        }
+        mask.template apply_mask</*Causal_mask=*/false>(
+            acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
+        );

         softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2);

-        Tensor rP = flash::convert_type<Element>(scores);
+        Tensor rP = flash::convert_type<Element>(acc_s);
         int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
         int block_col_idx = n_block * (kBlockN / 32);
         if (Return_softmax) {
-            Tensor acc_s_f16 = flash::convert_type<Element>(acc_s);
-            Tensor acc_s_f16_drop = make_tensor(acc_s_f16.data(), rP.layout());
+            Tensor rP_drop = make_fragment_like(rP);
+            cute::copy(rP, rP_drop);
             dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-                acc_s_f16_drop, block_row_idx, block_col_idx, kNWarps
+                rP_drop, block_row_idx, block_col_idx, kNWarps
            );
-            cute::copy(acc_s_f16, tSgS);
+            cute::copy(rP_drop, tSgS);
             tSgS.data() = tSgS.data() + (-kBlockN);
         }
         if (Is_dropout) {
             dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
         }

-        // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
-        // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
-        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+        // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+        // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
+        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
         flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
     }
...
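Both loop phases above now issue the same apply_mask call unconditionally; whether any masking work survives is decided at compile time inside Mask::apply_mask, via Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN and an if constexpr around the whole body (see the mask.h hunk further down). A stripped-down stand-in showing that dispatch pattern; ToyMask is hypothetical and only mirrors the compile-time switch, not the real masking:

    // Minimal sketch of the compile-time dispatch the new Mask uses: one call site,
    // with `if constexpr` removing the loop body entirely when no masking is needed.
    #include <cstdio>

    template <bool Has_alibi, bool Is_local>
    struct ToyMask {
        template <bool Causal_mask = false, bool Is_even_MN = true>
        void apply_mask(float *scores, int n) const {
            static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN;
            if constexpr (Need_masking) {
                for (int i = 0; i < n; ++i) { scores[i] = -1e30f; }  // stand-in for real masking
            }
            // else: the call compiles to nothing.
        }
    };

    int main() {
        float s[4] = {1, 2, 3, 4};
        ToyMask</*Has_alibi=*/false, /*Is_local=*/false> mask;
        mask.apply_mask</*Causal_mask=*/false>(s, 4);  // steady-state iteration: no-op
        mask.apply_mask</*Causal_mask=*/true>(s, 4);   // masking iteration: masks everything here
        std::printf("%f %f\n", s[0], s[3]);
        return 0;
    }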
@@ -845,7 +798,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     flash::Softmax<2 * size<1>(acc_o)> softmax;

     const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
-    flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
+    flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);

     // For performance reason, we separate out two kinds of iterations:
     // those that need masking on S, and those that don't.
...
@@ -883,27 +836,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         );
         // if (cute::thread0()) { print(acc_s); }
-        // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
-        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
-        if (Has_alibi) {
-            alibi.apply_alibi(scores, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
-        }
-        // if (cute::thread0()) { print(scores); }
         // We don't put the masking before the matmul S = Q K^T because we don't clear sK
         // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
         // can produce Inf / NaN.
-        if (!Is_causal && !Is_local) {
-            if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
-        } else {
-            flash::apply_mask_local(
-                scores, n_block * kBlockN, binfo.actual_seqlen_k,
-                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                binfo.actual_seqlen_q, kNWarps * 16,
-                params.window_size_left, params.window_size_right
-            );
-        }
+        mask.template apply_mask<Is_causal, Is_even_MN>(
+            acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
+        );

         flash::cp_async_wait<0>();
         __syncthreads();
...
@@ -925,14 +860,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2);
         // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }

-        // Convert scores from fp32 to fp16/bf16
-        Tensor rP = flash::convert_type<Element>(scores);
-        // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
-        // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
-        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+        // Convert acc_s from fp32 to fp16/bf16
+        Tensor rP = flash::convert_type<Element>(acc_s);
+        // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+        // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
+        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));

         flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
         // if (cute::thread0()) { print(scores); }

         // This check is at the end of the loop since we always have at least 1 iteration
         if (n_masking_steps > 1 && n_block <= n_block_min) {
...
@@ -968,28 +902,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             cute::cp_async_fence();
         }

-        // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
-        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
-        if (Has_alibi) {
-            alibi.apply_alibi(scores, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
-        }
-        if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
-            flash::apply_mask_local(
-                scores, n_block * kBlockN, binfo.actual_seqlen_k,
-                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                binfo.actual_seqlen_q, kNWarps * 16,
-                params.window_size_left, params.window_size_right
-            );
-        }
+        mask.template apply_mask</*Causal_mask=*/false>(
+            acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
+        );

         softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2);

-        Tensor rP = flash::convert_type<Element>(scores);
-        // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
-        // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
-        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+        Tensor rP = flash::convert_type<Element>(acc_s);
+        // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+        // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
+        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));

         flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
     }
...
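Inside the new Mask there is a second compile-time switch, Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask, which selects a cheaper loop that never computes row indices. Evaluating that expression, copied from the mask.h diff below, for a few combinations shows when the shortcut applies; the combinations themselves are just illustrative:

    // The Col_idx_only switch from the new mask.h, checked for a few flag combinations.
    constexpr bool col_idx_only(bool Has_alibi, bool Is_causal, bool Is_local, bool Causal_mask) {
        return !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;
    }

    // Padding-only masking (no alibi, no local, steady-state iteration): columns suffice.
    static_assert(col_idx_only(false, false, false, false), "");
    // Causal alibi in a steady-state iteration: the bias is alibi_slope * col_idx, so columns suffice.
    static_assert(col_idx_only(true, true, false, false), "");
    // Any iteration that applies the causal or local mask needs row indices too.
    static_assert(!col_idx_only(false, true, false, true), "");
    static_assert(!col_idx_only(false, false, true, false), "");

    int main() { return 0; }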
csrc/flash_attn/src/mask.h  View file @ 66a127ae

...
@@ -107,4 +107,107 @@ __forceinline__ __device__ void apply_mask_causal_w_idx(
     }
 }

+template <bool Is_causal, bool Is_local, bool Has_alibi>
+struct Mask {
+
+    const int max_seqlen_k, max_seqlen_q;
+    const int window_size_left, window_size_right;
+    const float alibi_slope;
+
+    __forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q,
+                                    const int window_size_left, const int window_size_right,
+                                    const float alibi_slope=0.f)
+        : max_seqlen_k(max_seqlen_k)
+        , max_seqlen_q(max_seqlen_q)
+        , window_size_left(window_size_left)
+        , window_size_right(window_size_right)
+        , alibi_slope(!Has_alibi ? 0.0 : alibi_slope) {
+    };
+
+    // Causal_mask: whether this particular iteration needs causal masking
+    template <bool Causal_mask=false, bool Is_even_MN=true, typename Engine, typename Layout>
+    __forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor_,
+                                               const int col_idx_offset_,
+                                               const int row_idx_offset,
+                                               const int warp_row_stride) {
+        static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local");
+        static_assert(Layout::rank == 3, "Only support 3D Tensor");
+        static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4");
+        static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN;
+        // if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); }
+        if constexpr (Need_masking) {
+            // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
+            Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout()));
+            // Do we need both row and column indices, or just column incides?
+            static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;
+            const int lane_id = threadIdx.x % 32;
+            const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+            if constexpr (Col_idx_only) {
+                #pragma unroll
+                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                    const int col_idx_base = col_idx_offset + nj * 8;
+                    #pragma unroll
+                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                        const int col_idx = col_idx_base + j;
+                        #pragma unroll
+                        for (int mi = 0; mi < size<0>(tensor); ++mi) {
+                            // No causal, no local
+                            if constexpr (Has_alibi) { tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; }
+                            if constexpr (!Is_even_MN) {
+                                if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; }
+                            }
+                        }
+                    }
+                }
+            } else {
+                #pragma unroll
+                for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+                    const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+                    #pragma unroll
+                    for (int i = 0; i < size<0, 0>(tensor); ++i) {
+                        const int row_idx = row_idx_base + i * 8;
+                        const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
+                        const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
+                        #pragma unroll
+                        for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                            const int col_idx_base = col_idx_offset + nj * 8;
+                            #pragma unroll
+                            for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                                const int col_idx = col_idx_base + j;
+                                if constexpr (Has_alibi) {
+                                    if constexpr (Is_causal) {
+                                        tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx;
+                                    } else {
+                                        tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
+                                    }
+                                }
+                                if constexpr (Causal_mask) {
+                                    if (col_idx >= col_idx_limit_right) {
+                                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
+                                    }
+                                }
+                                if constexpr (Is_local) {
+                                    if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) {
+                                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
+                                    }
+                                }
+                                if constexpr (!Causal_mask && !Is_local && !Is_even_MN) {
+                                    // Causal and Local already handles MN masking
+                                    if (col_idx >= max_seqlen_k) {
+                                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    };
+
+};
+
 }  // namespace flash
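For reference, the rules the new Mask applies can be written down on a plain matrix. The sketch below is a simplified host-side reference, not the kernel code: there is no MMA layout or warp indexing, the template flags become plain bools, the per-iteration Causal_mask flag is folded into `causal`, and it assumes the kernel passes window_size_right = 0 in the causal case. The limits follow the mask.h hunk above; the kernel additionally gates the final padding rule on !Is_even_MN.

    #include <algorithm>
    #include <cmath>
    #include <limits>
    #include <vector>

    void apply_mask_reference(std::vector<float> &scores, int rows, int cols,
                              int max_seqlen_q, int max_seqlen_k,
                              int window_size_left, int window_size_right,
                              bool causal, bool local, bool has_alibi, float alibi_slope) {
        const float neg_inf = -std::numeric_limits<float>::infinity();
        for (int r = 0; r < rows; ++r) {
            const int limit_left  = std::max(0, r + max_seqlen_k - max_seqlen_q - window_size_left);
            const int limit_right = std::min(max_seqlen_k, r + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
            for (int c = 0; c < cols; ++c) {
                float &s = scores[r * cols + c];
                if (has_alibi) {
                    // Causal alibi is a pure column bias; otherwise it decays with |distance|.
                    s += causal ? alibi_slope * c
                                : -alibi_slope * std::abs(float(r + max_seqlen_k - max_seqlen_q - c));
                }
                if (causal && c >= limit_right) { s = neg_inf; }
                if (local && (c >= limit_right || c < limit_left)) { s = neg_inf; }
                if (!causal && !local && c >= max_seqlen_k) { s = neg_inf; }  // padding columns
            }
        }
    }

    int main() {
        const int rows = 4, cols = 6;
        std::vector<float> scores(rows * cols, 1.0f);
        apply_mask_reference(scores, rows, cols, /*max_seqlen_q=*/4, /*max_seqlen_k=*/6,
                             /*window_size_left=*/0, /*window_size_right=*/0,
                             /*causal=*/true, /*local=*/false, /*has_alibi=*/false, 0.0f);
        return 0;
    }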
csrc/flash_attn/src/utils.h  View file @ 66a127ae

...
@@ -193,34 +193,33 @@ __forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
-// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
+// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
 template<typename MMA_traits, typename Layout>
-__forceinline__ __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
+__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
     using X = Underscore;
-    static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
-    static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
+    static_assert(decltype(size<0>(acc_layout))::value == 4);
+    static_assert(decltype(rank(acc_layout))::value == 3);
     constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
     static_assert(mma_shape_K == 8 || mma_shape_K == 16);
-    constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
-    auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{});  // ((2, MMA_M), (2, (2, MMA_N / 2)))
-    return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
-                       get<0, 1>(l),
-                       get<1, 1, 1>(l));
+    if constexpr (mma_shape_K == 8) {
+        return acc_layout;
+    } else {
+        auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2)))
+        return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
+    }
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
+// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
 template<typename Layout>
-__forceinline__ __device__ auto convert_layout_rowcol_dropout(Layout rowcol_layout) {
+__forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) {
     using X = Underscore;
-    static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
-    static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
-    auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<2>>>{});  // ((2, MMA_M), (2, (2, MMA_N / 2)))
-    return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
-                       get<0, 1>(l),
-                       get<1, 1, 1>(l));
+    static_assert(decltype(size<0>(acc_layout))::value == 4);
+    static_assert(decltype(rank(acc_layout))::value == 3);
+    auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2)))
+    return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////
...
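convert_layout_acc_Aregs now takes the accumulator layout directly: with an m16n8k8 atom it is returned unchanged, and with m16n8k16 two adjacent N tiles are fused into the register-fragment mode. A shape-only sketch of the resulting extents follows; the real function builds a cute layout with the matching strides, and the tile counts below are arbitrary:

    // Shape-only sketch of convert_layout_acc_Aregs: extents of the accumulator before
    // and after the conversion, for both MMA atoms mentioned in the comments above.
    #include <cstdio>

    struct Extents { int frag, mma_m, mma_n; };

    Extents convert_layout_acc_Aregs_extents(Extents acc, int mma_shape_K) {
        if (mma_shape_K == 8) {
            return acc;                                    // (4, MMA_M, MMA_N), unchanged
        }
        return {acc.frag * 2, acc.mma_m, acc.mma_n / 2};   // ((4, 2), MMA_M, MMA_N / 2)
    }

    int main() {
        Extents acc{4, 2, 8};                              // example tile counts, arbitrary
        Extents a = convert_layout_acc_Aregs_extents(acc, 16);
        std::printf("m16n8k16: (%d, %d, %d)\n", a.frag, a.mma_m, a.mma_n);  // (8, 2, 4)
        Extents b = convert_layout_acc_Aregs_extents(acc, 8);
        std::printf("m16n8k8:  (%d, %d, %d)\n", b.frag, b.mma_m, b.mma_n);  // (4, 2, 8)
        return 0;
    }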