Commit e148767a authored by rocking's avatar rocking
Browse files

Support unpad lse layout for splitkv

parent e6c489df
...@@ -283,7 +283,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args) ...@@ -283,7 +283,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_randval, args.nhead_stride_randval,
args.nhead_stride_lse_acc, args.nhead_stride_lse_acc,
args.nhead_stride_o_acc, args.nhead_stride_o_acc,
args.batch_stride_lse_acc,
args.batch_stride_o_acc, args.batch_stride_o_acc,
args.split_stride_lse_acc, args.split_stride_lse_acc,
args.split_stride_o_acc, args.split_stride_o_acc,
...@@ -375,9 +374,7 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args) ...@@ -375,9 +374,7 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_o_acc, args.nhead_stride_o_acc,
args.nhead_stride_lse, args.nhead_stride_lse,
args.nhead_stride_o, args.nhead_stride_o,
args.batch_stride_lse_acc,
args.batch_stride_o_acc, args.batch_stride_o_acc,
args.batch_stride_lse,
args.split_stride_lse_acc, args.split_stride_lse_acc,
args.split_stride_o_acc); args.split_stride_o_acc);
} }
......
...@@ -55,7 +55,7 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -55,7 +55,7 @@ struct FmhaFwdSplitKVCombineKernel
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
_SS_(FmhaPipeline::name) + _SS_(FmhaPipeline::name) +
(pn.empty() ? "" : "_" + pn) + (pn.empty() ? "" : "_" + pn) +
(kStoreLSE ? "_lse" : "" ) + (kStoreLSE ? "_lse" : "" ) +
(kDoFp8StaticQuant ? "_squant" : "" ); (kDoFp8StaticQuant ? "_squant" : "" );
#undef _SS_ #undef _SS_
#undef _TS_ #undef _TS_
...@@ -91,7 +91,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -91,7 +91,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t nhead_stride_o_acc; ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t nhead_stride_o; ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc; ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t split_stride_lse_acc; ck_tile::index_t split_stride_lse_acc;
...@@ -100,9 +99,10 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -100,9 +99,10 @@ struct FmhaFwdSplitKVCombineKernel
struct CommonLSEKargs struct CommonLSEKargs
{ {
void* lse_ptr = nullptr; void* lse_ptr = nullptr;
ck_tile::index_t nhead_stride_lse = 0; ck_tile::index_t nhead_stride_lse = 0;
ck_tile::index_t batch_stride_lse = 0; ck_tile::index_t batch_stride_lse_acc = 0;
ck_tile::index_t batch_stride_lse = 0;
}; };
struct Fp8StaticQuantKargs struct Fp8StaticQuantKargs
...@@ -166,7 +166,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -166,7 +166,6 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc, nhead_stride_lse_acc,
nhead_stride_o_acc, nhead_stride_o_acc,
nhead_stride_o, nhead_stride_o,
batch_stride_lse_acc,
batch_stride_o_acc, batch_stride_o_acc,
split_stride_lse_acc, split_stride_lse_acc,
split_stride_o_acc}, // args for common karg split_stride_o_acc}, // args for common karg
...@@ -176,9 +175,10 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -176,9 +175,10 @@ struct FmhaFwdSplitKVCombineKernel
if constexpr(kStoreLSE) if constexpr(kStoreLSE)
{ {
kargs.lse_ptr = lse_ptr; kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse; kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse = batch_stride_lse; kargs.batch_stride_lse_acc = batch_stride_lse_acc;
kargs.batch_stride_lse = batch_stride_lse;
} }
if constexpr(kDoFp8StaticQuant) if constexpr(kDoFp8StaticQuant)
{ {
...@@ -206,9 +206,7 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -206,9 +206,7 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_lse_acc,
ck_tile::index_t batch_stride_o_acc, ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc) ck_tile::index_t split_stride_o_acc)
{ {
...@@ -225,7 +223,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -225,7 +223,6 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc, nhead_stride_lse_acc,
nhead_stride_o_acc, nhead_stride_o_acc,
nhead_stride_o, nhead_stride_o,
batch_stride_lse_acc,
batch_stride_o_acc, batch_stride_o_acc,
split_stride_lse_acc, split_stride_lse_acc,
split_stride_o_acc}, // args for common karg split_stride_o_acc}, // args for common karg
...@@ -237,7 +234,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -237,7 +234,6 @@ struct FmhaFwdSplitKVCombineKernel
{ {
kargs.lse_ptr = lse_ptr; kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse; kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse = batch_stride_lse;
} }
if constexpr(kDoFp8StaticQuant) if constexpr(kDoFp8StaticQuant)
{ {
...@@ -274,17 +270,12 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -274,17 +270,12 @@ struct FmhaFwdSplitKVCombineKernel
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
const long_index_t batch_offset_lse_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
const long_index_t batch_offset_o_acc = const long_index_t batch_offset_o_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc; static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
if constexpr(kStoreLSE) long_index_t batch_offset_lse_acc = 0;
{ long_index_t batch_offset_lse = 0;
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse; long_index_t batch_offset_o = 0;
}
if constexpr(kIsGroupMode) if constexpr(kIsGroupMode)
{ {
...@@ -293,6 +284,12 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -293,6 +284,12 @@ struct FmhaFwdSplitKVCombineKernel
batch_offset_o = query_start * kargs.row_stride_o; batch_offset_o = query_start * kargs.row_stride_o;
if constexpr(kStoreLSE)
{
batch_offset_lse_acc = query_start;
batch_offset_lse = query_start;
}
// get real # queries & # keys under group mode // get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
...@@ -307,6 +304,13 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -307,6 +304,13 @@ struct FmhaFwdSplitKVCombineKernel
else else
{ {
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o; batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
if constexpr(kStoreLSE)
{
batch_offset_lse_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
}
} }
// for simplicity, batch stride we just modify the pointer // for simplicity, batch stride we just modify the pointer
......
...@@ -47,6 +47,7 @@ struct FmhaFwdSplitKVKernel ...@@ -47,6 +47,7 @@ struct FmhaFwdSplitKVKernel
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>; using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking; static constexpr bool kHasMask = FmhaMask::IsMasking;
...@@ -85,7 +86,7 @@ struct FmhaFwdSplitKVKernel ...@@ -85,7 +86,7 @@ struct FmhaFwdSplitKVKernel
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" ); (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
#undef _SS_ #undef _SS_
#undef _TS_ #undef _TS_
...@@ -136,7 +137,6 @@ struct FmhaFwdSplitKVKernel ...@@ -136,7 +137,6 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t nhead_stride_lse_acc; ck_tile::index_t nhead_stride_lse_acc;
ck_tile::index_t nhead_stride_o_acc; ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc; ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t split_stride_lse_acc; ck_tile::index_t split_stride_lse_acc;
...@@ -216,6 +216,7 @@ struct FmhaFwdSplitKVKernel ...@@ -216,6 +216,7 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_lse_acc;
}; };
struct GroupModeKargs struct GroupModeKargs
...@@ -313,7 +314,6 @@ struct FmhaFwdSplitKVKernel ...@@ -313,7 +314,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v, nhead_stride_v,
nhead_stride_lse_acc, nhead_stride_lse_acc,
nhead_stride_o_acc, nhead_stride_o_acc,
batch_stride_lse_acc,
batch_stride_o_acc, batch_stride_o_acc,
split_stride_lse_acc, split_stride_lse_acc,
split_stride_o_acc}, // args for common karg split_stride_o_acc}, // args for common karg
...@@ -323,7 +323,8 @@ struct FmhaFwdSplitKVKernel ...@@ -323,7 +323,8 @@ struct FmhaFwdSplitKVKernel
{}, // placeholder for dropout {}, // placeholder for dropout
batch_stride_q, batch_stride_q,
batch_stride_k, batch_stride_k,
batch_stride_v}; batch_stride_v,
batch_stride_lse_acc};
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
...@@ -394,7 +395,6 @@ struct FmhaFwdSplitKVKernel ...@@ -394,7 +395,6 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_lse_acc,
ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t batch_stride_lse_acc,
ck_tile::index_t batch_stride_o_acc, ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc, ck_tile::index_t split_stride_o_acc,
...@@ -433,7 +433,6 @@ struct FmhaFwdSplitKVKernel ...@@ -433,7 +433,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v, nhead_stride_v,
nhead_stride_lse_acc, nhead_stride_lse_acc,
nhead_stride_o_acc, nhead_stride_o_acc,
batch_stride_lse_acc,
batch_stride_o_acc, batch_stride_o_acc,
split_stride_lse_acc, split_stride_lse_acc,
split_stride_o_acc}, // args for common karg split_stride_o_acc}, // args for common karg
...@@ -511,8 +510,7 @@ struct FmhaFwdSplitKVKernel ...@@ -511,8 +510,7 @@ struct FmhaFwdSplitKVKernel
long_index_t batch_offset_v = 0; long_index_t batch_offset_v = 0;
long_index_t batch_offset_bias = 0; long_index_t batch_offset_bias = 0;
long_index_t batch_offset_randval = 0; long_index_t batch_offset_randval = 0;
const long_index_t batch_offset_lse_acc = long_index_t batch_offset_lse_acc = 0;
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
const long_index_t batch_offset_o_acc = const long_index_t batch_offset_o_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc; static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
...@@ -540,6 +538,10 @@ struct FmhaFwdSplitKVKernel ...@@ -540,6 +538,10 @@ struct FmhaFwdSplitKVKernel
{ {
batch_offset_randval = query_start * kargs.stride_randval; batch_offset_randval = query_start * kargs.stride_randval;
} }
if constexpr(kStoreLSE)
{
batch_offset_lse_acc = query_start;
}
// get real # queries & # keys under group mode // get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
...@@ -576,6 +578,11 @@ struct FmhaFwdSplitKVKernel ...@@ -576,6 +578,11 @@ struct FmhaFwdSplitKVKernel
batch_offset_randval = batch_offset_randval =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval; static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
} }
if constexpr(kStoreLSE)
{
batch_offset_lse_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
}
} }
// for simplicity, batch stride we just modify the pointer // for simplicity, batch stride we just modify the pointer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment