Commit 7cd4e574 authored by Po Yen Chen
Browse files

Add group mode block-mapping for fmha splitkv kernel

parent db952741
...@@ -60,7 +60,7 @@ struct FmhaFwdSplitKVKernel ...@@ -60,7 +60,7 @@ struct FmhaFwdSplitKVKernel
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; }; template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
// clang-format on // clang-format on
__host__ static std::string GetName() CK_TILE_HOST static std::string GetName()
{ {
// sync with generate.py // sync with generate.py
// clang-format off // clang-format off
...@@ -237,7 +237,7 @@ struct FmhaFwdSplitKVKernel ...@@ -237,7 +237,7 @@ struct FmhaFwdSplitKVKernel
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>; using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
template <bool Cond = !kIsGroupMode> template <bool Cond = !kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs> CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr, MakeKargs(const void* q_ptr,
const void* k_ptr, const void* k_ptr,
const void* v_ptr, const void* v_ptr,
...@@ -361,7 +361,7 @@ struct FmhaFwdSplitKVKernel ...@@ -361,7 +361,7 @@ struct FmhaFwdSplitKVKernel
} }
template <bool Cond = kIsGroupMode> template <bool Cond = kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs> CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr, MakeKargs(const void* q_ptr,
const void* k_ptr, const void* k_ptr,
const void* v_ptr, const void* v_ptr,
...@@ -482,10 +482,20 @@ struct FmhaFwdSplitKVKernel ...@@ -482,10 +482,20 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t num_splits) ck_tile::index_t num_splits)
{ {
// TODO: this may need tuning // TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) * if constexpr(kIsGroupMode)
ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits, {
nhead, return dim3(nhead,
batch_size); batch_size,
ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits);
}
else
{
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits,
nhead,
batch_size);
}
} }
CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs) CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
...@@ -498,15 +508,27 @@ struct FmhaFwdSplitKVKernel ...@@ -498,15 +508,27 @@ struct FmhaFwdSplitKVKernel
return ck_tile::make_tuple(quotient, modulus); return ck_tile::make_tuple(quotient, modulus);
}; };
const auto [mn, i_split] = f(blockIdx.x, kargs.num_splits); if constexpr(kIsGroupMode)
const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1); {
const index_t i_nhead = blockIdx.y; const auto [mn, i_split] = f(blockIdx.z, kargs.num_splits);
const index_t i_batch = blockIdx.z; const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1);
const index_t i_nhead = blockIdx.x;
const index_t i_batch = blockIdx.y;
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
}
else
{
const auto [mn, i_split] = f(blockIdx.x, kargs.num_splits);
const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1);
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch); return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
}
} }
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment