Commit 992b7f32 authored by ThomasNing
Browse files

revert back the pipeline code of fmha

parent f6ceef78
...@@ -75,11 +75,14 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true> ...@@ -75,11 +75,14 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{ {
using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::QDataType, using BlockGemmProblem =
typename Problem::KDataType, typename Problem::SaccDataType, Problem::kBlockSize, BlockGemmPipelineProblem<typename Problem::QDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0, typename Problem::KDataType,
Problem::BlockFmhaShape::BlockTile::kN0, typename Problem::SaccDataType,
Problem::BlockFmhaShape::BlockTile::kK0>>; Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>>;
constexpr auto warp_gemm = []() { constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> && if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
...@@ -195,11 +198,14 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false> ...@@ -195,11 +198,14 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{ {
using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::QDataType, using BlockGemmProblem =
typename Problem::KDataType, typename Problem::SaccDataType, Problem::kBlockSize, BlockGemmPipelineProblem<typename Problem::QDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0, typename Problem::KDataType,
Problem::BlockFmhaShape::BlockTile::kN0, typename Problem::SaccDataType,
Problem::BlockFmhaShape::BlockTile::kK0>>; Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>>;
constexpr auto warp_gemm = []() { constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> && if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
...@@ -486,7 +492,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -486,7 +492,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template <typename Problem, index_t IBuf = 0> template <typename Problem, index_t IBuf = 0>
CK_TILE_HOST_DEVICE static constexpr auto CK_TILE_HOST_DEVICE static constexpr auto
MakeKLdsStoreBlockDescriptor(number<IBuf> = number<0>{}) MakeKLdsStoreBlockDescriptor(number<IBuf> = number<0>{})
{ {
// K is always k-major, we use async-copy to load into LDS // K is always k-major, we use async-copy to load into LDS
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
...@@ -541,7 +547,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -541,7 +547,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM #if K_LDS_LOAD_USE_OFFSET_TRANSFORM
template <typename Problem, index_t IBuf = 0> template <typename Problem, index_t IBuf = 0>
CK_TILE_HOST_DEVICE static constexpr auto CK_TILE_HOST_DEVICE static constexpr auto
MakeKLdsLoadBlockDescriptor(number<IBuf> = number<0>{}) MakeKLdsLoadBlockDescriptor(number<IBuf> = number<0>{})
{ {
// K is always k-major, we use async-copy to load into LDS // K is always k-major, we use async-copy to load into LDS
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
...@@ -936,11 +942,14 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -936,11 +942,14 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
{ {
using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::PDataType, using BlockGemmProblem =
typename Problem::VDataType, typename Problem::OaccDataType, Problem::kBlockSize, BlockGemmPipelineProblem<typename Problem::PDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0, typename Problem::VDataType,
Problem::BlockFmhaShape::BlockTile::kN1, typename Problem::OaccDataType,
Problem::BlockFmhaShape::BlockTile::kK1>>; Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>>;
auto warp_gemm = [&]() { auto warp_gemm = [&]() {
if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> && if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> &&
...@@ -978,4 +987,4 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -978,4 +987,4 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
} }
}; };
} // namespace ck_tile } // namespace ck_tile
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment