Commit 992b7f32 authored by ThomasNing's avatar ThomasNing
Browse files

revert back the pipeline code of fmha

parent f6ceef78
...@@ -75,11 +75,14 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true> ...@@ -75,11 +75,14 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{ {
using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::QDataType, using BlockGemmProblem =
typename Problem::KDataType, typename Problem::SaccDataType, Problem::kBlockSize, BlockGemmPipelineProblem<typename Problem::QDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0, typename Problem::KDataType,
Problem::BlockFmhaShape::BlockTile::kN0, typename Problem::SaccDataType,
Problem::BlockFmhaShape::BlockTile::kK0>>; Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>>;
constexpr auto warp_gemm = []() { constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> && if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
...@@ -195,11 +198,14 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false> ...@@ -195,11 +198,14 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{ {
using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::QDataType, using BlockGemmProblem =
typename Problem::KDataType, typename Problem::SaccDataType, Problem::kBlockSize, BlockGemmPipelineProblem<typename Problem::QDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0, typename Problem::KDataType,
Problem::BlockFmhaShape::BlockTile::kN0, typename Problem::SaccDataType,
Problem::BlockFmhaShape::BlockTile::kK0>>; Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>>;
constexpr auto warp_gemm = []() { constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> && if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
...@@ -936,11 +942,14 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -936,11 +942,14 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
{ {
using BlockGemmProblem = BlockGemmPipelineProblem < typename Problem::PDataType, using BlockGemmProblem =
typename Problem::VDataType, typename Problem::OaccDataType, Problem::kBlockSize, BlockGemmPipelineProblem<typename Problem::PDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0, typename Problem::VDataType,
Problem::BlockFmhaShape::BlockTile::kN1, typename Problem::OaccDataType,
Problem::BlockFmhaShape::BlockTile::kK1>>; Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>>;
auto warp_gemm = [&]() { auto warp_gemm = [&]() {
if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> && if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> &&
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment