Commit 4776c8c0 authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Use un-rolled gemm for Gemm-0

parent 00fe0752
...@@ -330,23 +330,14 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -330,23 +330,14 @@ struct BlockFmhaPipelineQRKSVSAsync
// ensure k is completely updated on LDS // ensure k is completely updated on LDS
block_sync_lds(); block_sync_lds();
// for kQKHeaddim == 96 (kSubQKHeaddim == 128), we need to use k0_loops static_for<0, k0_loops, 1>{}([&](auto i_k0) {
if constexpr(kQKHeaddim == kSubQKHeaddim) gemm_0(
{ s_acc,
gemm_0(s_acc, q, k_lds_window); get_slice_tile(q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
} get_slice_tile(k_lds_window,
else sequence<0, i_k0 * kK0>{},
{ sequence<kN0, (i_k0 + 1) * kK0>{}));
});
static_for<0, k0_loops, 1>{}([&](auto i_k0) {
gemm_0(s_acc,
get_slice_tile(
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
get_slice_tile(k_lds_window,
sequence<0, i_k0 * kK0>{},
sequence<kN0, (i_k0 + 1) * kK0>{}));
});
}
__builtin_amdgcn_sched_barrier(0); // prevent from messing up the order of global loads __builtin_amdgcn_sched_barrier(0); // prevent from messing up the order of global loads
......
...@@ -57,75 +57,78 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy ...@@ -57,75 +57,78 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
sequence<0, 1>>{}); sequence<0, 1>>{});
} }
template <typename Problem> /*
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() template <typename Problem>
{ CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
constexpr index_t BlockGemmK = (KLoadOnce && Problem::BlockFmhaShape::kQKHeaddim == {
Problem::BlockFmhaShape::kSubQKHeaddim) constexpr index_t BlockGemmK = (KLoadOnce && Problem::BlockFmhaShape::kQKHeaddim ==
? Problem::BlockFmhaShape::kSubQKHeaddim Problem::BlockFmhaShape::kSubQKHeaddim)
: Problem::BlockFmhaShape::kK0; ? Problem::BlockFmhaShape::kSubQKHeaddim
: Problem::BlockFmhaShape::kK0;
using GemmProblem = BlockGemmProblem<
typename Problem::QDataType, using GemmProblem = BlockGemmProblem<
typename Problem::KDataType, typename Problem::QDataType,
typename Problem::SaccDataType, typename Problem::KDataType,
Problem::kNumGemm0Warps * get_warp_size(), typename Problem::SaccDataType,
TileGemmShape< Problem::kNumGemm0Warps * get_warp_size(),
sequence<Problem::BlockFmhaShape::kM0, Problem::BlockFmhaShape::kN0, BlockGemmK>, TileGemmShape<
typename Problem::BlockFmhaShape::Gemm0BlockWarps, sequence<Problem::BlockFmhaShape::kM0, Problem::BlockFmhaShape::kN0,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>; BlockGemmK>, typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename
Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr auto warp_gemm = []() {
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); constexpr auto warp_gemm = []() {
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); constexpr index_t WarpGemmM =
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); static_assert(WarpGemmM == 4 ||
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> && WarpGemmM == 16 || WarpGemmM == 32);
std::is_same_v<typename Problem::KDataType, half_t> &&
std::is_same_v<typename Problem::SaccDataType, float>) if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
{ std::is_same_v<typename Problem::KDataType, half_t> &&
if constexpr(WarpGemmM == 32) std::is_same_v<typename Problem::SaccDataType, float>)
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; {
else if constexpr(WarpGemmM == 16) if constexpr(WarpGemmM == 32)
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}; return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
else // WarpGemmM == 4 else if constexpr(WarpGemmM == 16)
return WarpGemmMfmaF16F16F32M4N64K16{}; return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
} else // WarpGemmM == 4
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> && return WarpGemmMfmaF16F16F32M4N64K16{};
std::is_same_v<typename Problem::KDataType, bf16_t> && }
std::is_same_v<typename Problem::SaccDataType, float>) else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
{ std::is_same_v<typename Problem::KDataType, bf16_t> &&
if constexpr(WarpGemmM == 32) std::is_same_v<typename Problem::SaccDataType, float>)
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; {
else if constexpr(WarpGemmM == 16) if constexpr(WarpGemmM == 32)
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}; return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
else // WarpGemmM == 4 else if constexpr(WarpGemmM == 16)
return WarpGemmMfmaBf16Bf16F32M4N64K16{}; return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
} else // WarpGemmM == 4
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> && return WarpGemmMfmaBf16Bf16F32M4N64K16{};
std::is_same_v<typename Problem::KDataType, fp8_t> && }
std::is_same_v<typename Problem::SaccDataType, float>) else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
{ std::is_same_v<typename Problem::KDataType, fp8_t> &&
static_assert(WarpGemmM == 32); std::is_same_v<typename Problem::SaccDataType, float>)
{
// TODO: hard coded here. Otherwise, it may incorrect result static_assert(WarpGemmM == 32);
constexpr index_t swizzle_factor = 4;
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution< // TODO: hard coded here. Otherwise, it may incorrect result
swizzle_factor>{}; constexpr index_t swizzle_factor = 4;
} // TODO - bf8_t return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
}(); swizzle_factor>{};
} // TODO - bf8_t
using BlockGemmPolicy = }();
BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::QDataType,
typename Problem::KDataType, using BlockGemmPolicy =
typename Problem::SaccDataType, BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::QDataType,
typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename Problem::KDataType,
decltype(warp_gemm)>; typename Problem::SaccDataType,
typename
if constexpr(1 < Problem::kNumGemm0Warps) Problem::BlockFmhaShape::Gemm0BlockWarps, decltype(warp_gemm)>;
return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
else if constexpr(1 < Problem::kNumGemm0Warps)
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{}; return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
} else
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
}
*/
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ() CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment