Commit 7cd4e574 authored by Po Yen Chen's avatar Po Yen Chen
Browse files

Add group mode block-mapping for fmha splitkv kernel

parent db952741
......@@ -60,7 +60,7 @@ struct FmhaFwdSplitKVKernel
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
// clang-format on
__host__ static std::string GetName()
CK_TILE_HOST static std::string GetName()
{
// sync with generate.py
// clang-format off
......@@ -237,7 +237,7 @@ struct FmhaFwdSplitKVKernel
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
template <bool Cond = !kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
......@@ -361,7 +361,7 @@ struct FmhaFwdSplitKVKernel
}
template <bool Cond = kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
......@@ -482,10 +482,20 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t num_splits)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits,
nhead,
batch_size);
if constexpr(kIsGroupMode)
{
return dim3(nhead,
batch_size,
ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits);
}
else
{
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits,
nhead,
batch_size);
}
}
CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
......@@ -498,15 +508,27 @@ struct FmhaFwdSplitKVKernel
return ck_tile::make_tuple(quotient, modulus);
};
const auto [mn, i_split] = f(blockIdx.x, kargs.num_splits);
const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1);
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
if constexpr(kIsGroupMode)
{
const auto [mn, i_split] = f(blockIdx.z, kargs.num_splits);
const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1);
const index_t i_nhead = blockIdx.x;
const index_t i_batch = blockIdx.y;
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
}
else
{
const auto [mn, i_split] = f(blockIdx.x, kargs.num_splits);
const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1);
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
}
}
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment