Commit 992b7f32 authored by ThomasNing
Browse files

Revert the FMHA pipeline code

parent f6ceef78
......@@ -75,11 +75,14 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::QDataType,
typename Problem::KDataType, typename Problem::SaccDataType, Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
Problem::BlockFmhaShape::BlockTile::kN0,
Problem::BlockFmhaShape::BlockTile::kK0>>;
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>>;
constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
......@@ -195,11 +198,14 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::QDataType,
typename Problem::KDataType, typename Problem::SaccDataType, Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
Problem::BlockFmhaShape::BlockTile::kN0,
Problem::BlockFmhaShape::BlockTile::kK0>>;
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>>;
constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
......@@ -486,7 +492,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template <typename Problem, index_t IBuf = 0>
CK_TILE_HOST_DEVICE static constexpr auto
MakeKLdsStoreBlockDescriptor(number<IBuf> = number<0>{})
MakeKLdsStoreBlockDescriptor(number<IBuf> = number<0>{})
{
// K is always k-major, we use async-copy to load into LDS
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
......@@ -541,7 +547,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
template <typename Problem, index_t IBuf = 0>
CK_TILE_HOST_DEVICE static constexpr auto
MakeKLdsLoadBlockDescriptor(number<IBuf> = number<0>{})
MakeKLdsLoadBlockDescriptor(number<IBuf> = number<0>{})
{
// K is always k-major, we use async-copy to load into LDS
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
......@@ -936,11 +942,14 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
{
using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::PDataType,
typename Problem::VDataType, typename Problem::OaccDataType, Problem::kBlockSize,
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
Problem::BlockFmhaShape::BlockTile::kN1,
Problem::BlockFmhaShape::BlockTile::kK1>>;
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>>;
auto warp_gemm = [&]() {
if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> &&
......@@ -978,4 +987,4 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
}
};
} // namespace ck_tile
} // namespace ck_tile
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment