"docs/vscode:/vscode.git/clone" did not exist on "ba87c1607cae2ae00ab2547e911a101ed27ea18b"
Commit 4cc514f8 authored by danyao12
Browse files

fix unpadded lse issue in fwd splitkv

parent 15758862
...@@ -99,10 +99,9 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -99,10 +99,9 @@ struct FmhaFwdSplitKVCombineKernel
struct CommonLSEKargs struct CommonLSEKargs
{ {
void* lse_ptr = nullptr; void* lse_ptr = nullptr;
ck_tile::index_t nhead_stride_lse = 0; ck_tile::index_t nhead_stride_lse = 0;
ck_tile::index_t batch_stride_lse_acc = 0; ck_tile::index_t batch_stride_lse = 0;
ck_tile::index_t batch_stride_lse = 0;
}; };
struct Fp8StaticQuantKargs struct Fp8StaticQuantKargs
...@@ -116,6 +115,7 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -116,6 +115,7 @@ struct FmhaFwdSplitKVCombineKernel
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>> std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
{ {
ck_tile::index_t batch_stride_o; ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_lse_acc;
}; };
struct GroupModeKargs struct GroupModeKargs
...@@ -171,14 +171,14 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -171,14 +171,14 @@ struct FmhaFwdSplitKVCombineKernel
split_stride_o_acc}, // args for common karg split_stride_o_acc}, // args for common karg
{}, // placeholder for lse {}, // placeholder for lse
{}, // placeholder for fp8_static_quant args {}, // placeholder for fp8_static_quant args
batch_stride_o}; batch_stride_o,
batch_stride_lse_acc};
if constexpr(kStoreLSE) if constexpr(kStoreLSE)
{ {
kargs.lse_ptr = lse_ptr; kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse; kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse_acc = batch_stride_lse_acc; kargs.batch_stride_lse = batch_stride_lse;
kargs.batch_stride_lse = batch_stride_lse;
} }
if constexpr(kDoFp8StaticQuant) if constexpr(kDoFp8StaticQuant)
{ {
...@@ -282,12 +282,12 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -282,12 +282,12 @@ struct FmhaFwdSplitKVCombineKernel
// get starting offset for each batch // get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
batch_offset_o = query_start * kargs.row_stride_o; batch_offset_o = query_start * kargs.row_stride_o;
batch_offset_lse_acc = query_start;
if constexpr(kStoreLSE) if constexpr(kStoreLSE)
{ {
batch_offset_lse_acc = query_start; batch_offset_lse = query_start;
batch_offset_lse = query_start;
} }
// get real # queries & # keys under group mode // get real # queries & # keys under group mode
...@@ -303,12 +303,11 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -303,12 +303,11 @@ struct FmhaFwdSplitKVCombineKernel
} }
else else
{ {
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o; batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
if constexpr(kStoreLSE) if constexpr(kStoreLSE)
{ {
batch_offset_lse_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse; batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
} }
} }
......
...@@ -47,7 +47,6 @@ struct FmhaFwdSplitKVKernel ...@@ -47,7 +47,6 @@ struct FmhaFwdSplitKVKernel
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>; using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking; static constexpr bool kHasMask = FmhaMask::IsMasking;
...@@ -520,8 +519,9 @@ struct FmhaFwdSplitKVKernel ...@@ -520,8 +519,9 @@ struct FmhaFwdSplitKVKernel
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q; batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k; batch_offset_k = key_start * kargs.stride_k;
batch_offset_lse_acc = query_start;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{ {
batch_offset_v = key_start * kargs.stride_v; batch_offset_v = key_start * kargs.stride_v;
...@@ -538,10 +538,6 @@ struct FmhaFwdSplitKVKernel ...@@ -538,10 +538,6 @@ struct FmhaFwdSplitKVKernel
{ {
batch_offset_randval = query_start * kargs.stride_randval; batch_offset_randval = query_start * kargs.stride_randval;
} }
if constexpr(kStoreLSE)
{
batch_offset_lse_acc = query_start;
}
// get real # queries & # keys under group mode // get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
...@@ -566,9 +562,10 @@ struct FmhaFwdSplitKVKernel ...@@ -566,9 +562,10 @@ struct FmhaFwdSplitKVKernel
} }
else else
{ {
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q; batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k; batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v; batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias; batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
...@@ -578,11 +575,6 @@ struct FmhaFwdSplitKVKernel ...@@ -578,11 +575,6 @@ struct FmhaFwdSplitKVKernel
batch_offset_randval = batch_offset_randval =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval; static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
} }
if constexpr(kStoreLSE)
{
batch_offset_lse_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
}
} }
// for simplicity, batch stride we just modify the pointer // for simplicity, batch stride we just modify the pointer
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment