"experiments/vscode:/vscode.git/clone" did not exist on "474d6aa90101a7ffcea814b5803df7ab87e064c8"
Commit ca1a816d authored by Po Yen Chen's avatar Po Yen Chen
Browse files

Move splitkv partitioner logic into splitkv kernel

parent f31fad7d
......@@ -107,9 +107,7 @@ using fmha_epilogue =
false, false>>;
using fmha_kernel =
ck_tile::FmhaFwdSplitKVKernel<ck_tile::FmhaFwdSplitKVTilePartitioner<fmha_shape>,
fmha_pipeline,
fmha_epilogue>;
ck_tile::FmhaFwdSplitKVKernel<fmha_pipeline, fmha_epilogue>;
static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
......
......@@ -16,7 +16,6 @@
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
......
......@@ -17,10 +17,9 @@
namespace ck_tile {
template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_>
template <typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdSplitKVKernel
{
using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
......@@ -60,7 +59,7 @@ struct FmhaFwdSplitKVKernel
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
// clang-format on
__host__ static std::string GetName()
CK_TILE_HOST static std::string GetName()
{
// sync with generate.py
// clang-format off
......@@ -235,7 +234,7 @@ struct FmhaFwdSplitKVKernel
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
template <bool Cond = !kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
......@@ -359,7 +358,7 @@ struct FmhaFwdSplitKVKernel
}
template <bool Cond = kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
......@@ -473,16 +472,60 @@ struct FmhaFwdSplitKVKernel
return kargs;
}
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_splits)
{
return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v, num_splits);
// TODO: this may need tuning
if constexpr(kIsGroupMode)
{
return dim3(nhead,
batch_size,
ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits);
}
else
{
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits,
nhead,
batch_size);
}
}
// Decode the current workgroup's block index into the tuple
// (i_tile_m, i_tile_n, i_split, i_nhead, i_batch).
//
// One grid axis carries a linearized (tile_m, tile_n, split) index, with the split
// index varying fastest and the n-tile index next. The remaining two axes carry the
// head and batch indices. Which axis plays which role depends on kIsGroupMode.
CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
{
    const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);

    // Returns (numerator / denominator, numerator % denominator); the remainder is
    // derived from the quotient rather than computed with a second division.
    const auto div_mod = [](index_t numerator, index_t denominator) {
        index_t whole = numerator / denominator;
        index_t rest  = numerator - whole * denominator;
        return ck_tile::make_tuple(whole, rest);
    };

    // Shared decode step: peel off the split index first, then separate the m/n tiles.
    const auto decode = [&](index_t linear_id, index_t i_nhead, index_t i_batch) {
        const auto [tile_mn, i_split]   = div_mod(linear_id, kargs.num_splits);
        const auto [i_tile_m, i_tile_n] = div_mod(tile_mn, num_tile_n1);
        return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
    };

    if constexpr(kIsGroupMode)
    {
        // group mode: x = head, y = batch, z = linearized tile/split index
        return decode(blockIdx.z, blockIdx.x, blockIdx.y);
    }
    else
    {
        // batch mode: x = linearized tile/split index, y = head, z = batch
        return decode(blockIdx.x, blockIdx.y, blockIdx.z);
    }
}
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
......@@ -495,8 +538,7 @@ struct FmhaFwdSplitKVKernel
__shared__ char smem_ptr[GetSmemSize()];
// divide problem
const auto [i_tile_m, i_tile_n, i_split, i_nhead, i_batch] =
TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v, kargs.num_splits);
const auto [i_tile_m, i_tile_n, i_split, i_nhead, i_batch] = GetTileIndex(kargs);
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// Tile partitioner for the FMHA forward split-KV kernel.
//
// Computes the launch grid on the host and, on the device, decodes the block index
// into (i_tile_m, i_tile_n, i_split, i_nhead, i_batch). Tile extents are taken from
// BlockFmhaShape_.
template <typename BlockFmhaShape_>
struct FmhaFwdSplitKVTilePartitioner
{
    using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;

    // Tile extents re-exported from the block shape.
    static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
    static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
    static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
    static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
    static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;

    // Launch grid: x = (#m-tiles * #n-tiles * num_splits), y = nhead, z = batch_size.
    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
                                                ck_tile::index_t nhead,
                                                ck_tile::index_t max_seqlen_q,
                                                ck_tile::index_t hdim_v,
                                                ck_tile::index_t num_splits)
    {
        // TODO: this may need tuning
        const auto num_tile_m0 = ck_tile::integer_divide_ceil(max_seqlen_q, kM0);
        const auto num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);

        return dim3(num_tile_m0 * num_tile_n1 * num_splits, nhead, batch_size);
    }

    // Decode blockIdx into (i_tile_m, i_tile_n, i_split, i_nhead, i_batch).
    // blockIdx.x is unpacked with the split index varying fastest, then the n-tile
    // index; blockIdx.y and blockIdx.z are returned as the head and batch indices.
    // The first argument (seqlen_q) is unused.
    CK_TILE_DEVICE auto
    operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v, ck_tile::index_t num_splits)
    {
        const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);

        // Returns (numerator / denominator, numerator % denominator); the remainder
        // is derived from the quotient rather than computed with a second division.
        const auto div_mod = [](index_t numerator, index_t denominator) {
            index_t whole = numerator / denominator;
            index_t rest  = numerator - whole * denominator;
            return ck_tile::make_tuple(whole, rest);
        };

        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;

        const auto [tile_mn, i_split]   = div_mod(blockIdx.x, num_splits);
        const auto [i_tile_m, i_tile_n] = div_mod(tile_mn, num_tile_n1);

        return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
    }
};
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment