Commit 4cc514f8 authored by danyao12's avatar danyao12
Browse files

fix unpadded lse issue in fwd splitkv

parent 15758862
......@@ -99,10 +99,9 @@ struct FmhaFwdSplitKVCombineKernel
struct CommonLSEKargs
{
void* lse_ptr = nullptr;
ck_tile::index_t nhead_stride_lse = 0;
ck_tile::index_t batch_stride_lse_acc = 0;
ck_tile::index_t batch_stride_lse = 0;
void* lse_ptr = nullptr;
ck_tile::index_t nhead_stride_lse = 0;
ck_tile::index_t batch_stride_lse = 0;
};
struct Fp8StaticQuantKargs
......@@ -116,6 +115,7 @@ struct FmhaFwdSplitKVCombineKernel
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
{
ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_lse_acc;
};
struct GroupModeKargs
......@@ -171,14 +171,14 @@ struct FmhaFwdSplitKVCombineKernel
split_stride_o_acc}, // args for common karg
{}, // placeholder for lse
{}, // placeholder for fp8_static_quant args
batch_stride_o};
batch_stride_o,
batch_stride_lse_acc};
if constexpr(kStoreLSE)
{
kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse_acc = batch_stride_lse_acc;
kargs.batch_stride_lse = batch_stride_lse;
kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse = batch_stride_lse;
}
if constexpr(kDoFp8StaticQuant)
{
......@@ -282,12 +282,12 @@ struct FmhaFwdSplitKVCombineKernel
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
batch_offset_o = query_start * kargs.row_stride_o;
batch_offset_o = query_start * kargs.row_stride_o;
batch_offset_lse_acc = query_start;
if constexpr(kStoreLSE)
{
batch_offset_lse_acc = query_start;
batch_offset_lse = query_start;
batch_offset_lse = query_start;
}
// get real # queries & # keys under group mode
......@@ -303,12 +303,11 @@ struct FmhaFwdSplitKVCombineKernel
}
else
{
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
if constexpr(kStoreLSE)
{
batch_offset_lse_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
}
}
......
......@@ -47,7 +47,6 @@ struct FmhaFwdSplitKVKernel
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
......@@ -520,8 +519,9 @@ struct FmhaFwdSplitKVKernel
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
batch_offset_lse_acc = query_start;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
batch_offset_v = key_start * kargs.stride_v;
......@@ -538,10 +538,6 @@ struct FmhaFwdSplitKVKernel
{
batch_offset_randval = query_start * kargs.stride_randval;
}
if constexpr(kStoreLSE)
{
batch_offset_lse_acc = query_start;
}
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
......@@ -566,9 +562,10 @@ struct FmhaFwdSplitKVKernel
}
else
{
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
......@@ -578,11 +575,6 @@ struct FmhaFwdSplitKVKernel
batch_offset_randval =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
}
if constexpr(kStoreLSE)
{
batch_offset_lse_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
}
}
// for simplicity, batch stride we just modify the pointer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment