Commit e148767a authored by rocking's avatar rocking
Browse files

Support unpad lse layout for splitkv

parent e6c489df
......@@ -283,7 +283,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_randval,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.split_stride_lse_acc,
args.split_stride_o_acc,
......@@ -375,9 +374,7 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_o_acc,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.batch_stride_lse,
args.split_stride_lse_acc,
args.split_stride_o_acc);
}
......
......@@ -91,7 +91,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t split_stride_lse_acc;
......@@ -102,6 +101,7 @@ struct FmhaFwdSplitKVCombineKernel
{
void* lse_ptr = nullptr;
ck_tile::index_t nhead_stride_lse = 0;
ck_tile::index_t batch_stride_lse_acc = 0;
ck_tile::index_t batch_stride_lse = 0;
};
......@@ -166,7 +166,6 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc,
nhead_stride_o_acc,
nhead_stride_o,
batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
......@@ -178,6 +177,7 @@ struct FmhaFwdSplitKVCombineKernel
{
kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse_acc = batch_stride_lse_acc;
kargs.batch_stride_lse = batch_stride_lse;
}
if constexpr(kDoFp8StaticQuant)
......@@ -206,9 +206,7 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_lse_acc,
ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc)
{
......@@ -225,7 +223,6 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc,
nhead_stride_o_acc,
nhead_stride_o,
batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
......@@ -237,7 +234,6 @@ struct FmhaFwdSplitKVCombineKernel
{
kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse = batch_stride_lse;
}
if constexpr(kDoFp8StaticQuant)
{
......@@ -274,18 +270,13 @@ struct FmhaFwdSplitKVCombineKernel
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
const long_index_t batch_offset_lse_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
const long_index_t batch_offset_o_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
long_index_t batch_offset_lse_acc = 0;
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
if constexpr(kStoreLSE)
{
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
}
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
......@@ -293,6 +284,12 @@ struct FmhaFwdSplitKVCombineKernel
batch_offset_o = query_start * kargs.row_stride_o;
if constexpr(kStoreLSE)
{
batch_offset_lse_acc = query_start;
batch_offset_lse = query_start;
}
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
......@@ -307,6 +304,13 @@ struct FmhaFwdSplitKVCombineKernel
else
{
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
if constexpr(kStoreLSE)
{
batch_offset_lse_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
}
}
// for simplicity, batch stride we just modify the pointer
......
......@@ -47,6 +47,7 @@ struct FmhaFwdSplitKVKernel
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
......@@ -136,7 +137,6 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t nhead_stride_lse_acc;
ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t split_stride_lse_acc;
......@@ -216,6 +216,7 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_lse_acc;
};
struct GroupModeKargs
......@@ -313,7 +314,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v,
nhead_stride_lse_acc,
nhead_stride_o_acc,
batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
......@@ -323,7 +323,8 @@ struct FmhaFwdSplitKVKernel
{}, // placeholder for dropout
batch_stride_q,
batch_stride_k,
batch_stride_v};
batch_stride_v,
batch_stride_lse_acc};
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
......@@ -394,7 +395,6 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse_acc,
ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t batch_stride_lse_acc,
ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc,
......@@ -433,7 +433,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v,
nhead_stride_lse_acc,
nhead_stride_o_acc,
batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
......@@ -511,8 +510,7 @@ struct FmhaFwdSplitKVKernel
long_index_t batch_offset_v = 0;
long_index_t batch_offset_bias = 0;
long_index_t batch_offset_randval = 0;
const long_index_t batch_offset_lse_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
long_index_t batch_offset_lse_acc = 0;
const long_index_t batch_offset_o_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
......@@ -540,6 +538,10 @@ struct FmhaFwdSplitKVKernel
{
batch_offset_randval = query_start * kargs.stride_randval;
}
if constexpr(kStoreLSE)
{
batch_offset_lse_acc = query_start;
}
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
......@@ -576,6 +578,11 @@ struct FmhaFwdSplitKVKernel
batch_offset_randval =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
}
if constexpr(kStoreLSE)
{
batch_offset_lse_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
}
}
// for simplicity, batch stride we just modify the pointer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment