gaoqiong / flash-attention · Commits

Commit b1fbbd83, authored Aug 29, 2023 by Tri Dao

Implement splitKV attention

Parent: 7a983df7
Changes: 25 files
Showing 20 of the changed files with 825 additions and 11 deletions (+825 -11)
csrc/flash_attn/flash_api.cpp                              +66  -1
csrc/flash_attn/src/flash.h                                +5   -0
csrc/flash_attn/src/flash_bwd_launch_template.h            +6   -6
csrc/flash_attn/src/flash_fwd_kernel.h                     +569 -0
csrc/flash_attn/src/flash_fwd_launch_template.h            +74  -4
csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu   +7   -0
csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu   +7   -0
csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu   +7   -0
csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu   +7   -0
csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu   +7   -0
csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu   +7   -0
csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu   +7   -0
csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu   +7   -0
csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu   +7   -0
csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu   +7   -0
csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu    +7   -0
csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu    +7   -0
csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu    +7   -0
csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu    +7   -0
csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu    +7   -0
csrc/flash_attn/flash_api.cpp
...
@@ -178,11 +178,57 @@ void set_params_dgrad(Flash_bwd_params &params,

void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    FP16_SWITCH(!params.is_bf16, [&] {
        FWD_HEADDIM_SWITCH(params.d, [&] {
            if (params.num_splits <= 1) {  // If we don't set it num_splits == 0
                run_mha_fwd_<elem_type, kHeadDim>(params, stream);
            } else {
                run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
            }
        });
    });
}
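For readers unfamiliar with the dispatch pattern above: FP16_SWITCH and FWD_HEADDIM_SWITCH turn runtime values (params.is_bf16, params.d) into compile-time constants (elem_type, kHeadDim) so they can be used as template arguments. A minimal sketch of how such a switch macro is typically written (simplified for illustration, not the exact macro shipped in the repository's static_switch.h):

// Simplified BOOL_SWITCH-style macro, shown only to illustrate the dispatch idiom above.
// The real FP16_SWITCH / FWD_HEADDIM_SWITCH macros in this repo follow the same pattern.
#define SIMPLE_BOOL_SWITCH(COND, CONST_NAME, ...)      \
    [&] {                                              \
        if (COND) {                                    \
            constexpr static bool CONST_NAME = true;   \
            return __VA_ARGS__();                      \
        } else {                                       \
            constexpr static bool CONST_NAME = false;  \
            return __VA_ARGS__();                      \
        }                                              \
    }()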
// Find the number of splits that maximizes the occupancy. For example, if we have
// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
// splits as that would incur more HBM reads/writes.
// So we find the best efficiency, then find the smallest number of splits that gets 85%
// of the best efficiency.
inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
    // If we have enough to almost fill the SMs, then just use 1 split
    if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
    max_splits = std::min({max_splits, num_SMs, num_n_blocks});
    float max_efficiency = 0.f;
    std::vector<float> efficiency;
    efficiency.reserve(max_splits);
    auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
    // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
    // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
    // (i.e. it's 11 splits anyway).
    // So we check if the number of blocks per split is the same as the previous num_splits.
    auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
        return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
    };
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        if (!is_split_eligible(num_splits)) {
            efficiency.push_back(0.f);
        } else {
            float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
            float eff = n_waves / ceil(n_waves);
            // printf("num_splits = %d, eff = %f\n", num_splits, eff);
            if (eff > max_efficiency) { max_efficiency = eff; }
            efficiency.push_back(eff);
        }
    }
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        if (!is_split_eligible(num_splits)) { continue; }
        if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
            // printf("num_splits chosen = %d\n", num_splits);
            return num_splits;
        }
    }
    return 1;
}
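As a sanity check on the comment above, here is a small standalone program (illustration only, not part of the commit) that re-derives the quoted efficiency numbers for batch * n_heads = 48 and 108 SMs:

// Illustration only: re-derives the wave efficiency used by num_splits_heuristic above.
#include <cmath>
#include <cstdio>

int main() {
    const int batch_nheads_mblocks = 48;  // batch * n_heads * num_m_blocks
    const int num_SMs = 108;              // e.g. an A100
    for (int num_splits = 1; num_splits <= 4; ++num_splits) {
        float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
        float eff = n_waves / std::ceil(n_waves);   // how full the last wave is
        std::printf("num_splits = %d: waves = %.2f, efficiency = %.2f\n", num_splits, n_waves, eff);
    }
    // Prints efficiencies 0.44, 0.89, 0.67, 0.89: the heuristic picks 2 splits,
    // the smallest split count within 85% of the best efficiency.
    return 0;
}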
std::vector<at::Tensor>
mha_fwd(const at::Tensor &q,   // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,   // batch_size x seqlen_k x num_heads_k x head_size
...
@@ -294,6 +340,25 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
                     softmax_scale,
                     is_causal);

    // This needs to match with run_mha_fwd_splitkv_dispatch
    const int block_n = is_sm90 || is_sm8x
        ? (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64))
        : (head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64));
    const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;
    // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
    // In any case we don't expect seqlen_q to be larger than 64 for inference.
    const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
    params.num_splits = 1;
    if (p_dropout == 0.0f) {  // SplitKV is not implemented for dropout
        params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks,
                                                 dprops->multiProcessorCount, num_n_blocks, 64);
        if (params.num_splits > 1) {
            at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
            at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
            params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
            params.oaccum_ptr = out_accum.data_ptr();
        }
    }

    // number of times random will be generated per thread, to offset philox counter in thc random
    // state
    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
...
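To make the sizes above concrete, consider a hypothetical decoding call (numbers chosen purely for illustration, not taken from the commit): head_size = 128 on an SM8x/SM90 device gives block_n = 128, so seqlen_k = 4096 yields num_n_blocks = 32; seqlen_q = 1 gives num_m_blocks = 1, so the heuristic sees batch_size * num_heads units of work. If it were to choose, say, num_splits = 4, the two fp32 scratch tensors allocated above would have shapes (4, batch_size, num_heads, 1) for the partial LSE values and (4, batch_size, num_heads, 1, head_size_rounded) for the partial outputs; they are reduced later by the split-KV combine kernel.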
csrc/flash_attn/src/flash.h
...
@@ -53,6 +53,7 @@ struct Flash_fwd_params : public Qkv_params {

    // The O matrix (output).
    void * __restrict__ o_ptr;
    void * __restrict__ oaccum_ptr;

    // The stride between rows of O.
    index_t o_batch_stride;
...
@@ -64,6 +65,7 @@ struct Flash_fwd_params : public Qkv_params {

    // The pointer to the softmax sum.
    void * __restrict__ softmax_lse_ptr;
    void * __restrict__ softmax_lseaccum_ptr;

    // The dimensions.
    int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
...
@@ -96,6 +98,8 @@ struct Flash_fwd_params : public Qkv_params {

    bool is_bf16;
    bool is_causal;

    int num_splits;  // For split-KV version
};

////////////////////////////////////////////////////////////////////////////////////////////////////
...
@@ -140,5 +144,6 @@ struct Flash_bwd_params : public Flash_fwd_params {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);
csrc/flash_attn/src/flash_bwd_launch_template.h
...
@@ -64,7 +64,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
            BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
                auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, IsEvenKConst>;
                // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
                if constexpr (smem_size_dq_dk_dv >= 48 * 1024) {
                    C10_CUDA_CHECK(cudaFuncSetAttribute(
                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
                }
...
@@ -75,7 +75,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
        });
    });
    auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
    if constexpr (Kernel_traits::kSmemdQSize >= 48 * 1024) {
        C10_CUDA_CHECK(cudaFuncSetAttribute(
            kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
    }
...
@@ -103,7 +103,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,
        BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
            auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenNConst, IsEvenKConst>;
            // auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
            if constexpr (smem_size_dq_dk_dv >= 48 * 1024) {
                C10_CUDA_CHECK(cudaFuncSetAttribute(
                    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
            }
...
@@ -114,7 +114,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,
    });
    auto kernel_dkv = &flash_bwd_convert_dkv_kernel<Kernel_traits>;
    if constexpr (Kernel_traits::kSmemKVSize >= 48 * 1024) {
        C10_CUDA_CHECK(cudaFuncSetAttribute(
            kernel_dkv, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemKVSize));
    }
...
@@ -147,7 +147,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool con
    // BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
    //     // auto kernel = &flash_bwd_dq_dk_dv_loop_kernel<Kernel_traits, Is_dropout, IsCausalConst>;
    //     auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMConst, IsEvenKConst>;
    //     if constexpr (smem_size_dq_dk_dv >= 48 * 1024) {
    //         C10_CUDA_CHECK(cudaFuncSetAttribute(
    //             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
    //     }
...
@@ -159,7 +159,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool con
    //     });
    //     auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
    //     if constexpr (Kernel_traits::kSmemdQSize >= 48 * 1024) {
    //         C10_CUDA_CHECK(cudaFuncSetAttribute(
    //             kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
    //     }
...
csrc/flash_attn/src/flash_fwd_kernel.h
...
@@ -617,6 +617,407 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi

////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, typename Params>
inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {

    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    // Shared memory.
    extern __shared__ char smem_[];

    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    constexpr int kNWarps = Kernel_traits::kNWarps;

    const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

    const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
    const int n_block_min = n_split_idx * n_blocks_per_split;
    int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split);
    if (Is_causal) {
        n_block_max = std::min(n_block_max,
                               cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN));
    }
    if (n_block_min >= n_block_max) {  // This also covers the case where n_block_max <= 0
        // We exit early and write 0 to gOaccum and -inf to gLSEaccum.
        // Otherwise we might read OOB elements from gK and gV,
        // or get wrong results when we combine gOaccum from different blocks.
        const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded;
        const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
        Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
                                     Shape<Int<kBlockM>, Int<kHeadDim>>{},
                                     Stride<Int<kHeadDim>, _1>{});
        Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lseaccum),
                                       Shape<Int<kBlockM>>{}, Stride<_1>{});

        typename Kernel_traits::GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
        auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
        Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
        Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
        clear(tOrOaccum);
        // Construct identity layout for sO
        Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
        // Repeat the partitioning with identity layouts
        Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO);
        Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
        if (!Is_even_K) {
            #pragma unroll
            for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
        }
        // Clear_OOB_K must be false since we don't want to write zeros to gmem
        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
            gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
        );
        #pragma unroll
        for (int m = 0; m < size<1>(tOgOaccum); ++m) {
            const int row = get<0>(tOcO(0, m, 0));
            if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = -INFINITY; }
        }
        return;
    }

    // We iterate over the blocks in reverse order. This is because the last block is the only one
    // that needs masking when we read K and V from global memory. Moreover, iterating in reverse
    // might save us 1 register (we just need n_block instead of both n_block and n_block_max).

    const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
        + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
    // We move K and V to the last block.
    const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
        + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
    const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
        + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;

    Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
                            Shape<Int<kBlockM>, Int<kHeadDim>>{},
                            make_stride(params.q_row_stride, _1{}));
    Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
                            Shape<Int<kBlockN>, Int<kHeadDim>>{},
                            make_stride(params.k_row_stride, _1{}));
    Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
                            Shape<Int<kBlockN>, Int<kHeadDim>>{},
                            make_stride(params.v_row_stride, _1{}));

    Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                            typename Kernel_traits::SmemLayoutQ{});
    // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem;
    Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
                            typename Kernel_traits::SmemLayoutKV{});
    Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
    Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
    Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});

    typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
    auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);

    Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
    Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
    Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K)
    Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
    Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
    Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);

    typename Kernel_traits::TiledMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tidx);
    Tensor tSrQ  = thr_mma.partition_fragment_A(sQ);                           // (MMA,MMA_M,MMA_K)
    Tensor tSrK  = thr_mma.partition_fragment_B(sK);                           // (MMA,MMA_N,MMA_K)
    Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);                 // (MMA, MMA_K,MMA_N)

    Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // MMA, MMA_M, MMA_K

    //
    // Copy Atom retiling
    //

    auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
    auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
    Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);

    auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
    auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
    Tensor tSsK = smem_thr_copy_K.partition_S(sK);

    auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
    auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
    Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);

    // TODO: this might need to change if we change the mma instruction in SM70
    Tensor scores_max = make_tensor<ElementAccum>(Shape<Int<2 * size<1>(acc_o)>>{});
    Tensor scores_sum = make_fragment_like(scores_max);

    //
    // PREDICATES
    //

    // // Allocate predicate tensors for m and n
    // Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{});
    // Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});

    // Construct identity layout for sQ and sK
    Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));   // (BLK_N,BLK_K) -> (blk_n,blk_k)

    // Repeat the partitioning with identity layouts
    Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ);       // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
    Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV);    // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)

    // Allocate predicate tensors for k
    Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
    Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));

    // Set predicates for k bounds
    if (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
        #pragma unroll
        for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
    }

    // Prologue

    Tensor tQrQ = make_fragment_like(tQgQ);
    // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                       binfo.actual_seqlen_q - m_block * kBlockM);
    if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }

    if (Kernel_traits::Share_Q_K_smem) {
        flash::cp_async_wait<0>();
        __syncthreads();
        Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));            // M
        cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
        __syncthreads();
    }

    int n_block = n_block_max - 1;
    // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
                                       binfo.actual_seqlen_k - n_block * kBlockN);
    cute::cp_async_fence();
    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
    // __syncthreads();

    if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
        flash::cp_async_wait<1>();
        __syncthreads();
        Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));            // M
        cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
    }

    clear(acc_o);

    // For performance reason, we separate out two kinds of iterations:
    // those that need masking on S, and those that don't.
    // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
    // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
    // We will have at least 1 "masking" iteration.

    // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
    // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
    constexpr int n_masking_steps = !Is_causal
        ? 1
        : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
    #pragma unroll
    for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
        Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
        clear(acc_s);
        flash::cp_async_wait<0>();
        __syncthreads();

        // Advance gV
        if (masking_step > 0) {
            tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
        } else {
            // Clear the smem tiles to account for predicated off loads
            flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
                gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
            );
        }
        cute::cp_async_fence();

        flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
            smem_thr_copy_Q, smem_thr_copy_K
        );
        // if (cute::thread0()) { print(acc_s); }

        // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
        // if (cute::thread0()) { print(scores); }
        // We don't put the masking before the matmul S = Q K^T because we don't clear sK
        // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
        // can produce Inf / NaN.
        if (!Is_causal) {
            if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
        } else {
            flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
                                     m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
                                     binfo.actual_seqlen_q, kNWarps * 16);
        }

        flash::cp_async_wait<0>();
        __syncthreads();
        if (n_block > n_block_min) {
            // Advance gK
            tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
            // This cp_async_fence needs to be in the if block, otherwise the synchronization
            // isn't right and we get race conditions.
            cute::cp_async_fence();
        }

        // TODO: when we have key_padding_mask we'll need to Check_inf
        masking_step == 0
            ? softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
            : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);

        // Convert scores from fp32 to fp16/bf16
        Tensor rP = flash::convert_type<Element>(scores);
        // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
        // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));

        flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
        // if (cute::thread0()) { print(scores); }

        // This check is at the end of the loop since we always have at least 1 iteration
        if (n_masking_steps > 1 && n_block <= n_block_min) {
            --n_block;
            break;
        }
    }

    // These are the iterations where we don't need masking on S
    for (; n_block >= n_block_min; --n_block) {
        Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
        clear(acc_s);
        flash::cp_async_wait<0>();
        __syncthreads();
        // Advance gV
        tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
        cute::cp_async_fence();

        flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
            smem_thr_copy_Q, smem_thr_copy_K
        );

        flash::cp_async_wait<0>();
        __syncthreads();
        if (n_block > n_block_min) {
            // Advance gK
            tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
            // This cp_async_fence needs to be in the if block, otherwise the synchronization
            // isn't right and we get race conditions.
            cute::cp_async_fence();
        }

        // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
        softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);

        Tensor rP = flash::convert_type<Element>(scores);
        // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
        // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));

        flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
    }

    // Epilogue

    // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
    Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
    Tensor lse = make_fragment_like(scores_sum);
    #pragma unroll
    for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
        float sum = scores_sum(mi);
        float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
        lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum);
        float scale = inv_sum;
        #pragma unroll
        for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
    }
    // if (cute::thread0()) { print(acc_o_rowcol); }

    Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementAccum *>(smem_)), typename Kernel_traits::SmemLayoutO{});    // (SMEM_M,SMEM_N)
    // Partition sO to match the accumulator partitioning
    auto smem_tiled_copy_Oaccum = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomOaccum{}, tiled_mma);
    auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
    Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(acc_o);        // ((Atom,AtomNum), MMA_M, MMA_N)
    Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum);   // ((Atom,AtomNum),PIPE_M,PIPE_N)

    // sO has the same size as sQ, so we don't need to sync here.
    if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }

    cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);

    const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded;
    const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;

    Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
                                 Shape<Int<kBlockM>, Int<kHeadDim>>{},
                                 Stride<Int<kHeadDim>, _1>{});
    Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lseaccum),
                                   Shape<Int<kBlockM>>{}, Stride<_1>{});

    typename Kernel_traits::GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
    auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
    Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum);       // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);

    __syncthreads();

    Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
    cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);

    Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor taccOcO = thr_mma.partition_C(caccO);                                  // (MMA,MMA_M,MMA_K)
    static_assert(decltype(size<0>(taccOcO))::value == 4);
    // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
    Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
    CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));                     // MMA_M
    if (get<1>(taccOcO_row(0)) == 0) {
        #pragma unroll
        for (int mi = 0; mi < size(lse); ++mi) {
            const int row = get<0>(taccOcO_row(mi));
            if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
        }
    }

    // Construct identity layout for sO
    Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    // Repeat the partitioning with identity layouts
    Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO);                           // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
    Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
    if (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
    }
    // Clear_OOB_K must be false since we don't want to write zeros to gmem
    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
    );
}
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn(const Params &params) {
    const int m_block = blockIdx.x;
...
@@ -638,4 +1039,172 @@ inline __device__ void compute_attn(const Params &params) {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, typename Params>
inline __device__ void compute_attn_splitkv(const Params &params) {
    const int m_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.z / params.h;
    // The block index for the head.
    const int bidh = blockIdx.z - bidb * params.h;
    const int n_split_idx = blockIdx.y;
    const int num_n_splits = gridDim.y;
    flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
}
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits, int Log_max_splits, bool Is_even_K, typename Params>
inline __device__ void combine_attn_seqk_parallel(const Params &params) {
    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;
    constexpr int kMaxSplits = 1 << Log_max_splits;
    constexpr int kBlockM = 16;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;

    static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128");
    // static_assert(kMaxSplits <= 8, "kMaxSplits must be <= 8 for now, will extend layer");
    static_assert(kBlockM == 16 || kBlockM == 32, "kBlockM must be 16 or 32");
    static_assert(Kernel_traits::kNThreads == 128, "We assume that each block has 128 threads");

    // Shared memory.
    // kBlockM + 1 instead of kBlockM to reduce bank conflicts.
    __shared__ ElementAccum sLSE[kMaxSplits][kBlockM + 1];

    // The thread and block index.
    const int tidx = threadIdx.x;
    const int bidx = blockIdx.x;

    const index_t row_offset_lse = bidx * kBlockM;
    Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lse),
                                   Shape<Int<kMaxSplits>, Int<kBlockM>>{},
                                   make_stride(params.b * params.h * params.seqlen_q, _1{}));
    Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
                              Shape<Int<kBlockM>>{}, Stride<_1>{});
    constexpr int kNLsePerThread = (kMaxSplits * kBlockM + Kernel_traits::kNThreads - 1) / Kernel_traits::kNThreads;

    // Read the LSE values from gmem and store them in shared memory, then tranpose them.
    constexpr int kRowsPerLoadLSE = Kernel_traits::kNThreads / kBlockM;
    #pragma unroll
    for (int l = 0; l < kNLsePerThread; ++l) {
        const int row = l * kRowsPerLoadLSE + tidx / kBlockM;
        const int col = tidx % kBlockM;
        ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
        if (row < kMaxSplits) { sLSE[row][col] = lse; }
        // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); }
    }
    // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); }
    __syncthreads();
    Tensor lse_accum = make_tensor<ElementAccum>(Shape<Int<kNLsePerThread>>{});
    constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits);
    // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits
    // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads,
    // 16 rows, so each time we load we can load 8 rows).
    // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose;
    // static_assert(kThreadsPerSplit <= 32);
    static_assert(kRowsPerLoadTranspose <= 32);
    static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits);
    #pragma unroll
    for (int l = 0; l < kNLsePerThread; ++l) {
        const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
        const int col = tidx / kRowsPerLoadTranspose;
        lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -INFINITY;
        // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); }
    }

    // Compute the logsumexp of the LSE along the split dimension.
    ElementAccum lse_max = lse_accum(0);
    #pragma unroll
    for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); }
    MaxOp<float> max_op;
    lse_max = Allreduce<kRowsPerLoadTranspose>::run(lse_max, max_op);
    lse_max = lse_max == -INFINITY ? 0.0f : lse_max;  // In case all local LSEs are -inf
    float lse_sum = expf(lse_accum(0) - lse_max);
    #pragma unroll
    for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); }
    SumOp<float> sum_op;
    lse_sum = Allreduce<kRowsPerLoadTranspose>::run(lse_sum, sum_op);
    ElementAccum lse_logsum = logf(lse_sum) + lse_max;
    // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
    if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; }
    // Store the scales exp(lse - lse_logsum) in shared memory.
    #pragma unroll
    for (int l = 0; l < kNLsePerThread; ++l) {
        const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
        const int col = tidx / kRowsPerLoadTranspose;
        if (row < params.num_splits && col < kBlockM) { sLSE[row][col] = expf(lse_accum(l) - lse_logsum); }
    }
    __syncthreads();

    const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded;
    Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
                                 Shape<Int<kBlockM>, Int<kHeadDim>>{},
                                 Stride<Int<kHeadDim>, _1>{});
    typename Kernel_traits::GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
    auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
    Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
    Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));
    Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
    clear(tOrO);

    // Predicates
    Tensor cOaccum = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
    // Repeat the partitioning with identity layouts
    Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum);
    Tensor tOpOaccum = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
    if (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; }
    }
    // Load Oaccum in then scale and accumulate to O
    #pragma unroll 2
    for (int split = 0; split < params.num_splits; ++split) {
        flash::copy</*Is_even_MN=*/false, Is_even_K>(
            gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM
        );
        #pragma unroll
        for (int m = 0; m < size<1>(tOrOaccum); ++m) {
            int row = get<0>(tOcOaccum(0, m, 0));
            ElementAccum lse_scale = sLSE[split][row];
            #pragma unroll
            for (int k = 0; k < size<2>(tOrOaccum); ++k) {
                #pragma unroll
                for (int i = 0; i < size<0>(tOrOaccum); ++i) {
                    tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k);
                }
            }
            // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); print(tOrO); }
        }
        tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded;
    }
    // if (cute::thread0()) { print(tOrO); }

    Tensor rO = flash::convert_type<Element>(tOrO);
    // Write to gO
    #pragma unroll
    for (int m = 0; m < size<1>(rO); ++m) {
        const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0));
        if (idx < params.b * params.h * params.seqlen_q) {
            const int batch_idx = idx / (params.h * params.seqlen_q);
            const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q;
            // The index to the rows of Q
            const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q;
            auto o_ptr = reinterpret_cast<Element *>(params.o_ptr) + batch_idx * params.o_batch_stride
                + head_idx * params.o_head_stride + row * params.o_row_stride;
            #pragma unroll
            for (int k = 0; k < size<2>(rO); ++k) {
                if (Is_even_K || tOpOaccum(k)) {
                    const int col = get<1>(tOcOaccum(0, m, k));
                    Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col),
                                            Shape<Int<decltype(size<0>(rO))::value>>{}, Stride<_1>{});
                    // TODO: Should check if this is using vectorized store, but it seems pretty fast
                    copy(rO(_, m, k), gO);
                    // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); }
                    // reinterpret_cast<uint64_t *>(o_ptr)[col / 4] = recast<uint64_t>(rO)(0, m, k);
                }
            }
        }
    }
}

}  // namespace flash
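For reference, the combine step above is the standard log-sum-exp merge of the per-split partial results; written out as an equation (a gloss on the code, not text from the commit), with per-split partial outputs $O_s$ and per-split log-sum-exp values $\mathrm{LSE}_s$ for splits $s = 1, \dots, S$:

\[ \mathrm{LSE} = \log \sum_{s=1}^{S} e^{\mathrm{LSE}_s}, \qquad O = \sum_{s=1}^{S} e^{\mathrm{LSE}_s - \mathrm{LSE}} \, O_s. \]

The kernel evaluates $\mathrm{LSE}$ in the max-shifted form $\mathrm{lse\_max} + \log \sum_s e^{\mathrm{LSE}_s - \mathrm{lse\_max}}$ for numerical stability, and the scales $e^{\mathrm{LSE}_s - \mathrm{LSE}}$ are the values it stashes back into sLSE before accumulating the partial outputs.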
csrc/flash_attn/src/flash_fwd_launch_template.h
...
@@ -15,6 +15,17 @@ __global__ void flash_fwd_kernel(Flash_fwd_params params) {
    flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Return_softmax>(params);
}

template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K>
__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
    flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K>(params);
}

template<typename Kernel_traits, int Log_max_splits, bool Is_even_K>
__global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) {
    static_assert(Log_max_splits >= 1);
    flash::combine_attn_seqk_parallel<Kernel_traits, Log_max_splits, Is_even_K>(params);
}

template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr size_t smem_size = Kernel_traits::kSmemSize;
...
@@ -35,13 +46,13 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
                // Will only return softmax if dropout, to reduce compilation time.
                auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
                // auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, true, ReturnSoftmaxConst && Is_dropout>;
                if constexpr (smem_size >= 48 * 1024) {
                    C10_CUDA_CHECK(cudaFuncSetAttribute(
                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
                }
                // int ctas_per_sm;
                // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
                // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
                kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
...
@@ -50,6 +61,65 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    });
}

template<typename Kernel_traits>
void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr size_t smem_size = Kernel_traits::kSmemSize;
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid(num_m_block, params.num_splits, params.b * params.h);
    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr
        && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    // TODO: do we want to guarantee that seqlen_q <= seqlen_k? That would simplify the kernel a bit.
    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
            BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
                auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst, IsEvenKConst>;
                // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
                if constexpr (smem_size >= 48 * 1024) {
                    C10_CUDA_CHECK(cudaFuncSetAttribute(
                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
                }
                kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
            });
        });
    });

    dim3 grid_combine((params.b * params.h * params.seqlen_q + 16 - 1) / 16);
    BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
        if (params.num_splits <= 2) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 4) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 8) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 16) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 32) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 64) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        // } else if (params.num_splits <= 128) {
        //     flash_fwd_splitkv_combine_kernel<Kernel_traits, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        }
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
}

template<typename T, int Headdim>
void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
    constexpr int kBlockM = 64;  // Fixed for all head dimensions
    if (!is_sm8x) {  // A100, H100
        // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
        // and for headdim 192 with block size 64 x 128.
        constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 160 ? 128 : 64);
        run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
    } else {  // Only 99KB of smem, so we have to set kBlockN smaller for Headdim 160 and above
        constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
        run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
    }
}

template<typename T>
void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr int Headdim = 32;
...
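As a concrete illustration of the launch geometry above (the shapes here are hypothetical, chosen for a typical decoding workload, and num_splits simply stands in for whatever num_splits_heuristic returns): the split kernel launches a (num_m_block, num_splits, batch * heads) grid, and the combine kernel launches one block per 16 rows of the flattened (batch, heads, seqlen_q) output.

// Sketch of the grid sizes computed by run_flash_splitkv_fwd above, for a hypothetical decode shape.
#include <cstdio>

int main() {
    const int b = 4, h = 32, seqlen_q = 1, seqlen_k = 8192;
    const int kBlockM = 64, kBlockN = 128;   // head dim 128 -> 64 x 128 tiles per the dispatch above
    const int num_m_block  = (seqlen_q + kBlockM - 1) / kBlockM;   // 1
    const int num_n_blocks = (seqlen_k + kBlockN - 1) / kBlockN;   // 64 KV blocks available to split
    const int num_splits   = 8;                                    // placeholder for the heuristic's choice
    std::printf("num_n_blocks = %d (KV blocks available to split)\n", num_n_blocks);
    std::printf("splitkv grid = (%d, %d, %d)\n", num_m_block, num_splits, b * h);
    std::printf("combine grid = %d blocks (16 rows each)\n", (b * h * seqlen_q + 16 - 1) / 16);
    return 0;
}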
csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu (new file, mode 100644)
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu (new file, mode 100644)
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu (new file, mode 100644)
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu (new file, mode 100644)
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu (new file, mode 100644)
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu (new file, mode 100644)
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu (new file, mode 100644)
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 224>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu (new file, mode 100644)
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu (new file, mode 100644)
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu (new file, mode 100644)
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu (new file, mode 100644)
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu (new file, mode 100644)
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu (new file, mode 100644)
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu (new file, mode 100644)
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu (new file, mode 100644)
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream);