Unverified commit a1c07e8d, authored by Po Yen Chen, committed by GitHub
Browse files

[CK_TILE] Change output accum tensor layout of fmha fwd split-kv & combine kernels (#1527)

* Use same layout for o_acc and o tensor

* Use better param names in partitioner

* Remove redundant kargs 'max_seqlen_q'

* Use better param names in splitkv kernel

* Add comment for additional kernel arguments

* Sync empty-loop early-return logic between pipelines

* Pass more arguments to cmake in scripts

* Align backslashes

* Fix wrong o_acc tensor view strides

* Change o_acc layout if o_perm=0

* Handle whole row masked via attn_bias

* Use vector width = 1 for o_acc

* Use more even split sizes
parent 4cd1dc7f
...@@ -552,16 +552,33 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -552,16 +552,33 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
#endif #endif
auto get_lengths = [&](bool permute, struct
ck_tile::index_t b /*batch*/, {
ck_tile::index_t h /*nhead*/, auto operator()(bool permute,
ck_tile::index_t s /*seqlen*/, ck_tile::index_t b /*batch*/,
ck_tile::index_t d /*hdim*/) { ck_tile::index_t h /*nhead*/,
if(permute) ck_tile::index_t s /*seqlen*/,
return std::array<ck_tile::index_t, 4>{b, h, s, d}; ck_tile::index_t d /*hdim*/)
else {
return std::array<ck_tile::index_t, 4>{b, s, h, d}; if(permute)
}; return std::array<ck_tile::index_t, 4>{b, h, s, d};
else
return std::array<ck_tile::index_t, 4>{b, s, h, d};
}
auto operator()(bool permute,
ck_tile::index_t ns /*num_splits*/,
ck_tile::index_t b /*batch*/,
ck_tile::index_t h /*nhead*/,
ck_tile::index_t s /*seqlen*/,
ck_tile::index_t d /*hdim*/)
{
if(permute)
return std::array<ck_tile::index_t, 5>{ns, b, h, s, d};
else
return std::array<ck_tile::index_t, 5>{ns, b, s, h, d};
}
} get_lengths;
bool is_v_rowmajor = vlayout == std::string("r"); bool is_v_rowmajor = vlayout == std::string("r");
...@@ -617,7 +634,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -617,7 +634,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1}); : std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
ck_tile::HostTensor<OaccDataType> o_acc_host( ck_tile::HostTensor<OaccDataType> o_acc_host(
1 < num_splits || use_kvcache 1 < num_splits || use_kvcache
? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v} ? get_lengths(o_perm, num_splits, shape_batch, nhead, shape_seqlen_q, hdim_v)
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1}); : std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
// batch mode of lse data layout is [batch, nhead, seqlen_q] // batch mode of lse data layout is [batch, nhead, seqlen_q]
...@@ -854,7 +871,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -854,7 +871,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}(); }();
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
const ck_tile::index_t stride_randval = (max_seqlen_k); const ck_tile::index_t stride_randval = (max_seqlen_k);
const ck_tile::index_t stride_o_acc = hdim_v; const ck_tile::index_t stride_o_acc = (o_perm ? hdim_v : nhead * hdim_v);
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
// setup nhead_stride_* arguments // setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
...@@ -881,7 +898,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -881,7 +898,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
const ck_tile::index_t nhead_stride_lse_acc = shape_seqlen_q; const ck_tile::index_t nhead_stride_lse_acc = shape_seqlen_q;
const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v); const ck_tile::index_t nhead_stride_o_acc = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
// setup batch_stride_* arguments // setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
...@@ -897,12 +914,12 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -897,12 +914,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q); const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o_acc = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
// setup split_stride_* arguments (only used in split-kv kernel) // setup split_stride_* arguments (only used in split-kv kernel)
const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q); const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q);
const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v); const ck_tile::index_t split_stride_o_acc = (shape_batch * nhead * shape_seqlen_q * hdim_v);
args.q_ptr = q_buf.GetDeviceBuffer(); args.q_ptr = q_buf.GetDeviceBuffer();
args.k_ptr = k_buf.GetDeviceBuffer(); args.k_ptr = k_buf.GetDeviceBuffer();
......
...@@ -398,10 +398,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) ...@@ -398,10 +398,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.nhead_stride_bias, args.nhead_stride_bias,
args.nhead_stride_lse_acc, args.nhead_stride_lse_acc,
args.nhead_stride_o_acc, args.nhead_stride_o_acc,
args.batch_stride_k, args.batch_stride_k, // only used for paged-kvcache
args.batch_stride_v, args.batch_stride_v, // only used for paged-kvcache
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.split_stride_lse_acc, args.split_stride_lse_acc,
args.split_stride_o_acc, args.split_stride_o_acc,
args.window_size_left, args.window_size_left,
...@@ -475,7 +473,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args) ...@@ -475,7 +473,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.lse_ptr, args.lse_ptr,
args.o_ptr, args.o_ptr,
args.batch, args.batch,
args.max_seqlen_q,
args.seqstart_q_ptr, args.seqstart_q_ptr,
args.hdim_v, args.hdim_v,
args.num_splits, args.num_splits,
...@@ -486,7 +483,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args) ...@@ -486,7 +483,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.nhead_stride_o_acc, args.nhead_stride_o_acc,
args.nhead_stride_lse, args.nhead_stride_lse,
args.nhead_stride_o, args.nhead_stride_o,
args.batch_stride_o_acc,
args.split_stride_lse_acc, args.split_stride_lse_acc,
args.split_stride_o_acc); args.split_stride_o_acc);
} }
...@@ -497,7 +493,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args) ...@@ -497,7 +493,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.lse_ptr, args.lse_ptr,
args.o_ptr, args.o_ptr,
args.batch, args.batch,
args.max_seqlen_q,
args.seqlen_q, args.seqlen_q,
args.hdim_v, args.hdim_v,
args.num_splits, args.num_splits,
......
...@@ -308,9 +308,9 @@ struct SimplifiedGenericAttentionMask ...@@ -308,9 +308,9 @@ struct SimplifiedGenericAttentionMask
{ {
auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width); auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);
const index_t x_per_split = ck_tile::max(1, x_total / num_splits); const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits));
const index_t split_start = x_per_split * i_split; const index_t split_start = x_per_split * i_split;
const index_t split_end = (i_split == num_splits - 1 ? x_total : split_start + x_per_split); const index_t split_end = split_start + x_per_split;
return ck_tile::make_tuple(ck_tile::max(origin_start, split_start), return ck_tile::make_tuple(ck_tile::max(origin_start, split_start),
ck_tile::min(origin_end, split_end)); ck_tile::min(origin_end, split_end));
......
...@@ -78,8 +78,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -78,8 +78,6 @@ struct FmhaFwdSplitKVCombineKernel
void* o_ptr; void* o_ptr;
ck_tile::index_t batch; ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_q;
ck_tile::index_t hdim_v; ck_tile::index_t hdim_v;
ck_tile::index_t num_splits; ck_tile::index_t num_splits;
...@@ -91,8 +89,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -91,8 +89,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t nhead_stride_o_acc; ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t nhead_stride_o; ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t split_stride_lse_acc; ck_tile::index_t split_stride_lse_acc;
ck_tile::index_t split_stride_o_acc; ck_tile::index_t split_stride_o_acc;
}; };
...@@ -114,8 +110,9 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -114,8 +110,9 @@ struct FmhaFwdSplitKVCombineKernel
std::conditional_t<kStoreLSE, CommonLSEKargs, EmptyKargs<0>>, std::conditional_t<kStoreLSE, CommonLSEKargs, EmptyKargs<0>>,
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>> std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
{ {
ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_lse_acc; ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t batch_stride_o;
}; };
struct GroupModeKargs struct GroupModeKargs
...@@ -135,7 +132,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -135,7 +132,6 @@ struct FmhaFwdSplitKVCombineKernel
void* lse_ptr, void* lse_ptr,
void* o_ptr, void* o_ptr,
ck_tile::index_t batch, ck_tile::index_t batch,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_q,
ck_tile::index_t hdim_v, ck_tile::index_t hdim_v,
ck_tile::index_t num_splits, ck_tile::index_t num_splits,
...@@ -157,7 +153,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -157,7 +153,6 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_ptr, o_acc_ptr,
o_ptr, o_ptr,
batch, batch,
max_seqlen_q,
seqlen_q, seqlen_q,
hdim_v, hdim_v,
num_splits, num_splits,
...@@ -166,13 +161,13 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -166,13 +161,13 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc, nhead_stride_lse_acc,
nhead_stride_o_acc, nhead_stride_o_acc,
nhead_stride_o, nhead_stride_o,
batch_stride_o_acc,
split_stride_lse_acc, split_stride_lse_acc,
split_stride_o_acc}, // args for common karg split_stride_o_acc}, // args for common karg
{}, // placeholder for lse {}, // placeholder for lse
{}, // placeholder for fp8_static_quant args {}, // placeholder for fp8_static_quant args
batch_stride_o, batch_stride_lse_acc,
batch_stride_lse_acc}; batch_stride_o_acc,
batch_stride_o};
if constexpr(kStoreLSE) if constexpr(kStoreLSE)
{ {
...@@ -195,7 +190,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -195,7 +190,6 @@ struct FmhaFwdSplitKVCombineKernel
void* lse_ptr, void* lse_ptr,
void* o_ptr, void* o_ptr,
ck_tile::index_t batch, ck_tile::index_t batch,
ck_tile::index_t max_seqlen_q,
const void* seqstart_q_ptr, const void* seqstart_q_ptr,
ck_tile::index_t hdim_v, ck_tile::index_t hdim_v,
ck_tile::index_t num_splits, ck_tile::index_t num_splits,
...@@ -206,7 +200,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -206,7 +200,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc) ck_tile::index_t split_stride_o_acc)
{ {
...@@ -214,7 +207,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -214,7 +207,6 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_ptr, o_acc_ptr,
o_ptr, o_ptr,
batch, batch,
max_seqlen_q,
-1, // seqlen will be updated by another pointer -1, // seqlen will be updated by another pointer
hdim_v, hdim_v,
num_splits, num_splits,
...@@ -223,7 +215,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -223,7 +215,6 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc, nhead_stride_lse_acc,
nhead_stride_o_acc, nhead_stride_o_acc,
nhead_stride_o, nhead_stride_o,
batch_stride_o_acc,
split_stride_lse_acc, split_stride_lse_acc,
split_stride_o_acc}, // args for common karg split_stride_o_acc}, // args for common karg
{}, // placeholder for lse {}, // placeholder for lse
...@@ -243,12 +234,12 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -243,12 +234,12 @@ struct FmhaFwdSplitKVCombineKernel
return kargs; return kargs;
} }
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size_, __host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead_, ck_tile::index_t nhead,
ck_tile::index_t seqlen_q_, ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v_) ck_tile::index_t hdim_v)
{ {
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_); return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v);
} }
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
...@@ -270,10 +261,8 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -270,10 +261,8 @@ struct FmhaFwdSplitKVCombineKernel
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
const long_index_t batch_offset_o_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
long_index_t batch_offset_lse_acc = 0; long_index_t batch_offset_lse_acc = 0;
long_index_t batch_offset_o_acc = 0;
long_index_t batch_offset_lse = 0; long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0; long_index_t batch_offset_o = 0;
...@@ -282,14 +271,16 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -282,14 +271,16 @@ struct FmhaFwdSplitKVCombineKernel
// get starting offset for each batch // get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
batch_offset_o = query_start * kargs.row_stride_o;
batch_offset_lse_acc = query_start; batch_offset_lse_acc = query_start;
batch_offset_o_acc = query_start * kargs.row_stride_o_acc;
if constexpr(kStoreLSE) if constexpr(kStoreLSE)
{ {
batch_offset_lse = query_start; batch_offset_lse = query_start;
} }
batch_offset_o = query_start * kargs.row_stride_o;
// get real # queries & # keys under group mode // get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
...@@ -303,13 +294,15 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -303,13 +294,15 @@ struct FmhaFwdSplitKVCombineKernel
} }
else else
{ {
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc; batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
batch_offset_o_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
if constexpr(kStoreLSE) if constexpr(kStoreLSE)
{ {
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse; batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
} }
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
} }
// for simplicity, batch stride we just modify the pointer // for simplicity, batch stride we just modify the pointer
...@@ -341,7 +334,7 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -341,7 +334,7 @@ struct FmhaFwdSplitKVCombineKernel
auto o_acc_dram = [&]() { auto o_acc_dram = [&]() {
const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>( const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
o_acc_ptr, o_acc_ptr,
make_tuple(kargs.num_splits, kargs.max_seqlen_q, kargs.hdim_v), make_tuple(kargs.num_splits, kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1), make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1),
number<FmhaPipeline::kAlignmentOacc>{}, number<FmhaPipeline::kAlignmentOacc>{},
number<1>{}); number<1>{});
...@@ -351,14 +344,14 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -351,14 +344,14 @@ struct FmhaFwdSplitKVCombineKernel
make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}), make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
sequence<false, kPadSeqLenQ, kPadHeadDimV>{}); sequence<false, kPadSeqLenQ, kPadHeadDimV>{});
const index_t padded_max_seqlen_q = const index_t padded_seqlen_q =
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}]; o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}];
const index_t padded_hdim_v = const index_t padded_hdim_v =
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}]; o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}];
return transform_tensor_view( return transform_tensor_view(
o_acc_dram_view, o_acc_dram_view,
make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_max_seqlen_q)), make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_seqlen_q)),
make_pass_through_transform(padded_hdim_v)), make_pass_through_transform(padded_hdim_v)),
make_tuple(sequence<0, 1>{}, sequence<2>{}), make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{})); make_tuple(sequence<0>{}, sequence<1>{}));
...@@ -417,7 +410,7 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -417,7 +410,7 @@ struct FmhaFwdSplitKVCombineKernel
identity{}, // lse_element_func identity{}, // lse_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
kargs.num_splits, kargs.num_splits,
kargs.max_seqlen_q, kargs.seqlen_q,
smem_ptr); smem_ptr);
} }
else else
...@@ -426,7 +419,7 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -426,7 +419,7 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_dram_window, o_acc_dram_window,
lse_dram_window, lse_dram_window,
kargs.num_splits, kargs.num_splits,
kargs.max_seqlen_q, kargs.seqlen_q,
smem_ptr); smem_ptr);
} }
}(); }();
......
...@@ -13,21 +13,20 @@ struct FmhaFwdSplitKVCombineTilePartitioner ...@@ -13,21 +13,20 @@ struct FmhaFwdSplitKVCombineTilePartitioner
static constexpr ck_tile::index_t kM0 = kM0_; static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN1 = kN1_; static constexpr ck_tile::index_t kN1 = kN1_;
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead_, ck_tile::index_t nhead,
ck_tile::index_t seqlen_q_, ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v_) ck_tile::index_t hdim_v)
{ {
// TODO: this may need tuning // TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) * return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
ck_tile::integer_divide_ceil(hdim_v_, kN1), ck_tile::integer_divide_ceil(hdim_v, kN1),
nhead_, nhead,
batch_size_); batch_size);
} }
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v) CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
{ {
// const index_t num_tile_m0 = seqlen_q / kM0;
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1); const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
const index_t i_block = blockIdx.x; const index_t i_block = blockIdx.x;
......
...@@ -135,9 +135,6 @@ struct FmhaFwdSplitKVKernel ...@@ -135,9 +135,6 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t nhead_stride_lse_acc; ck_tile::index_t nhead_stride_lse_acc;
ck_tile::index_t nhead_stride_o_acc; ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t split_stride_lse_acc; ck_tile::index_t split_stride_lse_acc;
ck_tile::index_t split_stride_o_acc; ck_tile::index_t split_stride_o_acc;
}; };
...@@ -201,6 +198,8 @@ struct FmhaFwdSplitKVKernel ...@@ -201,6 +198,8 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
}; };
struct GroupModeKargs struct GroupModeKargs
...@@ -217,8 +216,8 @@ struct FmhaFwdSplitKVKernel ...@@ -217,8 +216,8 @@ struct FmhaFwdSplitKVKernel
const int32_t* seqstart_k_ptr; const int32_t* seqstart_k_ptr;
const int32_t* seqlen_k_ptr; const int32_t* seqlen_k_ptr;
ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_k; // only used for paged-kvcache
ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_v; // only used for paged-kvcache
}; };
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>; using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
...@@ -296,8 +295,6 @@ struct FmhaFwdSplitKVKernel ...@@ -296,8 +295,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v, nhead_stride_v,
nhead_stride_lse_acc, nhead_stride_lse_acc,
nhead_stride_o_acc, nhead_stride_o_acc,
batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc, split_stride_lse_acc,
split_stride_o_acc}, // args for common karg split_stride_o_acc}, // args for common karg
{}, // placeholder for bias {}, // placeholder for bias
...@@ -307,7 +304,9 @@ struct FmhaFwdSplitKVKernel ...@@ -307,7 +304,9 @@ struct FmhaFwdSplitKVKernel
reinterpret_cast<const int32_t*>(seqlen_k_ptr), reinterpret_cast<const int32_t*>(seqlen_k_ptr),
batch_stride_q, batch_stride_q,
batch_stride_k, batch_stride_k,
batch_stride_v}; batch_stride_v,
batch_stride_lse_acc,
batch_stride_o_acc};
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
...@@ -375,10 +374,8 @@ struct FmhaFwdSplitKVKernel ...@@ -375,10 +374,8 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_lse_acc,
ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_k, // only used for paged-kvcache
ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_v, // only used for paged-kvcache
ck_tile::index_t batch_stride_lse_acc,
ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc, ck_tile::index_t split_stride_o_acc,
ck_tile::index_t window_size_left, ck_tile::index_t window_size_left,
...@@ -412,8 +409,6 @@ struct FmhaFwdSplitKVKernel ...@@ -412,8 +409,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v, nhead_stride_v,
nhead_stride_lse_acc, nhead_stride_lse_acc,
nhead_stride_o_acc, nhead_stride_o_acc,
batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc, split_stride_lse_acc,
split_stride_o_acc}, // args for common karg split_stride_o_acc}, // args for common karg
{}, // placeholder for bias {}, // placeholder for bias
...@@ -452,11 +447,11 @@ struct FmhaFwdSplitKVKernel ...@@ -452,11 +447,11 @@ struct FmhaFwdSplitKVKernel
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size, __host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead, ck_tile::index_t nhead,
ck_tile::index_t seqlen_q, ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v, ck_tile::index_t hdim_v,
ck_tile::index_t num_splits) ck_tile::index_t num_splits)
{ {
return TilePartitioner::GridSize(batch_size, nhead, seqlen_q, hdim_v, num_splits); return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v, num_splits);
} }
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
...@@ -483,8 +478,7 @@ struct FmhaFwdSplitKVKernel ...@@ -483,8 +478,7 @@ struct FmhaFwdSplitKVKernel
long_index_t batch_offset_v = 0; long_index_t batch_offset_v = 0;
long_index_t batch_offset_bias = 0; long_index_t batch_offset_bias = 0;
long_index_t batch_offset_lse_acc = 0; long_index_t batch_offset_lse_acc = 0;
const long_index_t batch_offset_o_acc = long_index_t batch_offset_o_acc = 0;
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
if constexpr(kIsGroupMode) if constexpr(kIsGroupMode)
{ {
...@@ -492,9 +486,9 @@ struct FmhaFwdSplitKVKernel ...@@ -492,9 +486,9 @@ struct FmhaFwdSplitKVKernel
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q; batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k; batch_offset_k = key_start * kargs.stride_k;
batch_offset_lse_acc = query_start;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{ {
batch_offset_v = key_start * kargs.stride_v; batch_offset_v = key_start * kargs.stride_v;
...@@ -508,6 +502,9 @@ struct FmhaFwdSplitKVKernel ...@@ -508,6 +502,9 @@ struct FmhaFwdSplitKVKernel
batch_offset_bias = query_start * kargs.stride_bias + key_start; batch_offset_bias = query_start * kargs.stride_bias + key_start;
} }
batch_offset_lse_acc = query_start;
batch_offset_o_acc = query_start * kargs.stride_o_acc;
// get real # queries & # keys under group mode // get real # queries & # keys under group mode
kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch]; kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
...@@ -545,6 +542,7 @@ struct FmhaFwdSplitKVKernel ...@@ -545,6 +542,7 @@ struct FmhaFwdSplitKVKernel
batch_offset_k = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k; batch_offset_k = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v; batch_offset_v = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc; batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
batch_offset_o_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
...@@ -895,8 +893,8 @@ struct FmhaFwdSplitKVKernel ...@@ -895,8 +893,8 @@ struct FmhaFwdSplitKVKernel
const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>( const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
o_acc_ptr, o_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_v), make_tuple(kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.hdim_v, 1), make_tuple(kargs.stride_o_acc, 1),
number<FmhaPipeline::kAlignmentO>{}, number<1>{},
number<1>{}); number<1>{});
return pad_tensor_view( return pad_tensor_view(
......
...@@ -20,12 +20,12 @@ struct FmhaFwdSplitKVTilePartitioner ...@@ -20,12 +20,12 @@ struct FmhaFwdSplitKVTilePartitioner
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size, __host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead, ck_tile::index_t nhead,
ck_tile::index_t seqlen_q, ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v, ck_tile::index_t hdim_v,
ck_tile::index_t num_splits) ck_tile::index_t num_splits)
{ {
// TODO: this may need tuning // TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q, kM0) * return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
ck_tile::integer_divide_ceil(hdim_v, kN1), ck_tile::integer_divide_ceil(hdim_v, kN1),
nhead * num_splits, nhead * num_splits,
batch_size); batch_size);
......
...@@ -107,7 +107,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline ...@@ -107,7 +107,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const LSEElementFunction& lse_element_func, const LSEElementFunction& lse_element_func,
const OaccElementFunction& o_acc_element_func, const OaccElementFunction& o_acc_element_func,
index_t num_splits, index_t num_splits,
index_t max_seqlen_q, index_t seqlen_q,
void* smem_ptr) const void* smem_ptr) const
{ {
// lse_acc tile in LDS // lse_acc tile in LDS
...@@ -261,7 +261,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline ...@@ -261,7 +261,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist); auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist);
clear_tile(o_acc); clear_tile(o_acc);
const index_t padded_max_seqlen_q = integer_divide_ceil(max_seqlen_q, kM0) * kM0; const index_t padded_seqlen_q = integer_divide_ceil(seqlen_q, kM0) * kM0;
for(index_t i_split = 0; i_split < num_splits; ++i_split) for(index_t i_split = 0; i_split < num_splits; ++i_split)
{ {
...@@ -282,7 +282,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline ...@@ -282,7 +282,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
}); });
} }
move_tile_window(o_acc_dram_window, {padded_max_seqlen_q, 0}); move_tile_window(o_acc_dram_window, {padded_seqlen_q, 0});
} }
o_acc = tile_elementwise_in(o_acc_element_func, o_acc); o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
...@@ -297,7 +297,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline ...@@ -297,7 +297,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const OaccDramBlockWindow& o_acc_dram_block_window, const OaccDramBlockWindow& o_acc_dram_block_window,
LSEDramBlockWindow& lse_dram_block_window, LSEDramBlockWindow& lse_dram_block_window,
index_t num_splits, index_t num_splits,
index_t max_seqlen_q, index_t seqlen_q,
void* smem_ptr) const void* smem_ptr) const
{ {
return operator()(lse_acc_dram_block_window, return operator()(lse_acc_dram_block_window,
...@@ -306,7 +306,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline ...@@ -306,7 +306,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
identity{}, identity{},
identity{}, identity{},
num_splits, num_splits,
max_seqlen_q, seqlen_q,
smem_ptr); smem_ptr);
} }
}; };
......
...@@ -64,8 +64,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -64,8 +64,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>(); return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}(); }();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentBias = static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>(); kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
...@@ -212,8 +210,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -212,8 +210,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX( const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split); q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
// check early exit if masked and no work to do. // check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kHasUnevenSplits) if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
{ {
const index_t original_num_total_loop = const index_t original_num_total_loop =
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
...@@ -616,7 +614,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -616,7 +614,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0); constexpr auto i_idx = make_tuple(idx0);
const auto tmp = [&]() { const auto tmp = [&]() {
if constexpr(FmhaMask::IsMasking) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{ {
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
} }
......
...@@ -7,8 +7,10 @@ MY_PROJECT_SOURCE=$1 ...@@ -7,8 +7,10 @@ MY_PROJECT_SOURCE=$1
if [ $# -ge 2 ] ; then if [ $# -ge 2 ] ; then
GPU_TARGETS=$2 GPU_TARGETS=$2
REST_ARGS=${@:3}
else else
GPU_TARGETS="gfx908;gfx90a;gfx940" GPU_TARGETS="gfx908;gfx90a;gfx940"
REST_ARGS=
fi fi
cmake \ cmake \
...@@ -20,4 +22,5 @@ cmake ...@@ -20,4 +22,5 @@ cmake
-D GPU_TARGETS=$GPU_TARGETS \ -D GPU_TARGETS=$GPU_TARGETS \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
-D USE_BITINT_EXTENSION_INT4=OFF \ -D USE_BITINT_EXTENSION_INT4=OFF \
$REST_ARGS \
${MY_PROJECT_SOURCE} ${MY_PROJECT_SOURCE}
...@@ -7,8 +7,10 @@ MY_PROJECT_SOURCE=$1 ...@@ -7,8 +7,10 @@ MY_PROJECT_SOURCE=$1
if [ $# -ge 2 ] ; then if [ $# -ge 2 ] ; then
GPU_TARGETS=$2 GPU_TARGETS=$2
REST_ARGS=${@:3}
else else
GPU_TARGETS="gfx908;gfx90a;gfx940" GPU_TARGETS="gfx908;gfx90a;gfx940"
REST_ARGS=
fi fi
cmake \ cmake \
...@@ -20,5 +22,6 @@ cmake ...@@ -20,5 +22,6 @@ cmake
-D GPU_TARGETS=$GPU_TARGETS \ -D GPU_TARGETS=$GPU_TARGETS \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
-D USE_BITINT_EXTENSION_INT4=OFF \ -D USE_BITINT_EXTENSION_INT4=OFF \
$REST_ARGS \
${MY_PROJECT_SOURCE} ${MY_PROJECT_SOURCE}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment