Commit ca1a816d authored by Po Yen Chen
Browse files

Move splitkv partitioner logic into splitkv kernel

parent f31fad7d
...@@ -107,9 +107,7 @@ using fmha_epilogue = ...@@ -107,9 +107,7 @@ using fmha_epilogue =
false, false>>; false, false>>;
using fmha_kernel = using fmha_kernel =
ck_tile::FmhaFwdSplitKVKernel<ck_tile::FmhaFwdSplitKVTilePartitioner<fmha_shape>, ck_tile::FmhaFwdSplitKVKernel<fmha_pipeline, fmha_epilogue>;
fmha_pipeline,
fmha_epilogue>;
static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{ {{
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
......
...@@ -17,10 +17,9 @@ ...@@ -17,10 +17,9 @@
namespace ck_tile { namespace ck_tile {
template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_> template <typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdSplitKVKernel struct FmhaFwdSplitKVKernel
{ {
using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>; using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>; using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
...@@ -60,7 +59,7 @@ struct FmhaFwdSplitKVKernel ...@@ -60,7 +59,7 @@ struct FmhaFwdSplitKVKernel
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; }; template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
// clang-format on // clang-format on
__host__ static std::string GetName() CK_TILE_HOST static std::string GetName()
{ {
// sync with generate.py // sync with generate.py
// clang-format off // clang-format off
...@@ -235,7 +234,7 @@ struct FmhaFwdSplitKVKernel ...@@ -235,7 +234,7 @@ struct FmhaFwdSplitKVKernel
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>; using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
template <bool Cond = !kIsGroupMode> template <bool Cond = !kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs> CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr, MakeKargs(const void* q_ptr,
const void* k_ptr, const void* k_ptr,
const void* v_ptr, const void* v_ptr,
...@@ -359,7 +358,7 @@ struct FmhaFwdSplitKVKernel ...@@ -359,7 +358,7 @@ struct FmhaFwdSplitKVKernel
} }
template <bool Cond = kIsGroupMode> template <bool Cond = kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs> CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr, MakeKargs(const void* q_ptr,
const void* k_ptr, const void* k_ptr,
const void* v_ptr, const void* v_ptr,
...@@ -473,16 +472,60 @@ struct FmhaFwdSplitKVKernel ...@@ -473,16 +472,60 @@ struct FmhaFwdSplitKVKernel
return kargs; return kargs;
} }
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size, CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead, ck_tile::index_t nhead,
ck_tile::index_t max_seqlen_q, ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v, ck_tile::index_t hdim_v,
ck_tile::index_t num_splits) ck_tile::index_t num_splits)
{ {
return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v, num_splits); // TODO: this may need tuning
if constexpr(kIsGroupMode)
{
return dim3(nhead,
batch_size,
ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits);
}
else
{
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits,
nhead,
batch_size);
}
}
// Decode the current workgroup's blockIdx into the work item it should process.
//
// Returns ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch).
// One grid axis carries a flattened (tile_m, tile_n, split) index in which the
// split index varies fastest and tile_n second; the remaining two axes carry
// the head and batch indices. Which axis carries which depends on kIsGroupMode
// (see the two branches below).
CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
{
// number of kN1-wide tiles needed to cover hdim_v (rounded up)
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
// helper: returns (dividend / divisor, dividend % divisor); the remainder is
// derived from the quotient instead of using a separate '%' operation
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
index_t modulus = dividend - quotient * divisor;
return ck_tile::make_tuple(quotient, modulus);
};
if constexpr(kIsGroupMode)
{
// group mode: blockIdx.z is the flattened tile/split index,
// blockIdx.x is the head index and blockIdx.y is the batch index
const auto [mn, i_split] = f(blockIdx.z, kargs.num_splits);
const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1);
const index_t i_nhead = blockIdx.x;
const index_t i_batch = blockIdx.y;
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
}
else
{
// batch mode: blockIdx.x is the flattened tile/split index,
// blockIdx.y is the head index and blockIdx.z is the batch index
const auto [mn, i_split] = f(blockIdx.x, kargs.num_splits);
const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1);
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
}
} }
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{ {
...@@ -495,8 +538,7 @@ struct FmhaFwdSplitKVKernel ...@@ -495,8 +538,7 @@ struct FmhaFwdSplitKVKernel
__shared__ char smem_ptr[GetSmemSize()]; __shared__ char smem_ptr[GetSmemSize()];
// divide problem // divide problem
const auto [i_tile_m, i_tile_n, i_split, i_nhead, i_batch] = const auto [i_tile_m, i_tile_n, i_split, i_nhead, i_batch] = GetTileIndex(kargs);
TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v, kargs.num_splits);
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
/// Maps the launch grid of the split-KV forward FMHA kernel onto work items.
///
/// @tparam BlockFmhaShape_ Block shape type; must expose the compile-time tile
///         sizes kM0, kN0, kK0, kN1 and kK1 (cv/ref qualifiers are stripped).
template <typename BlockFmhaShape_>
struct FmhaFwdSplitKVTilePartitioner
{
    using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;

    // Tile sizes re-exported from the block shape.
    static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
    static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
    static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
    static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
    static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;

    /// Compute the launch grid.
    ///
    /// @return dim3 whose x extent is
    ///         ceil(max_seqlen_q / kM0) * ceil(hdim_v / kN1) * num_splits,
    ///         whose y extent is nhead and whose z extent is batch_size.
    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
                                                ck_tile::index_t nhead,
                                                ck_tile::index_t max_seqlen_q,
                                                ck_tile::index_t hdim_v,
                                                ck_tile::index_t num_splits)
    {
        // TODO: this may need tuning
        const auto num_tile_m0 = ck_tile::integer_divide_ceil(max_seqlen_q, kM0);
        const auto num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);

        return dim3(num_tile_m0 * num_tile_n1 * num_splits, nhead, batch_size);
    }

    /// Decode blockIdx into (i_tile_m, i_tile_n, i_split, i_nhead, i_batch).
    ///
    /// blockIdx.x is treated as a flattened (tile_m, tile_n, split) index with
    /// the split index varying fastest and tile_n second; blockIdx.y is the
    /// head index and blockIdx.z the batch index. The first parameter
    /// (seqlen_q) is unused.
    CK_TILE_DEVICE auto
    operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v, ck_tile::index_t num_splits)
    {
        const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);

        // Peel off the split index first; remainders are derived from the
        // quotient rather than via a separate '%' operation.
        const index_t flat_idx = blockIdx.x;
        const index_t mn       = flat_idx / num_splits;
        const index_t i_split  = flat_idx - mn * num_splits;

        // What is left is a row-major (tile_m, tile_n) index.
        const index_t i_tile_m = mn / num_tile_n1;
        const index_t i_tile_n = mn - i_tile_m * num_tile_n1;

        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;

        return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
    }
};
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment