"example/vscode:/vscode.git/clone" did not exist on "a4f24233e51854c4b5cb7d75637fa0f235f78f8e"
Commit 4776c8c0 authored by Qianfeng Zhang
Browse files

Use un-rolled gemm for Gemm-0

parent 00fe0752
...@@ -330,23 +330,14 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -330,23 +330,14 @@ struct BlockFmhaPipelineQRKSVSAsync
// ensure k is completely updated on LDS // ensure k is completely updated on LDS
block_sync_lds(); block_sync_lds();
// for kQKHeaddim == 96 (kSubQKHeaddim == 128), we need to use k0_loops
if constexpr(kQKHeaddim == kSubQKHeaddim)
{
gemm_0(s_acc, q, k_lds_window);
}
else
{
static_for<0, k0_loops, 1>{}([&](auto i_k0) { static_for<0, k0_loops, 1>{}([&](auto i_k0) {
gemm_0(s_acc, gemm_0(
get_slice_tile( s_acc,
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}), get_slice_tile(q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
get_slice_tile(k_lds_window, get_slice_tile(k_lds_window,
sequence<0, i_k0 * kK0>{}, sequence<0, i_k0 * kK0>{},
sequence<kN0, (i_k0 + 1) * kK0>{})); sequence<kN0, (i_k0 + 1) * kK0>{}));
}); });
}
__builtin_amdgcn_sched_barrier(0); // prevent from messing up the order of global loads __builtin_amdgcn_sched_barrier(0); // prevent from messing up the order of global loads
......
...@@ -57,6 +57,7 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy ...@@ -57,6 +57,7 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
sequence<0, 1>>{}); sequence<0, 1>>{});
} }
/*
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{ {
...@@ -71,13 +72,14 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy ...@@ -71,13 +72,14 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
typename Problem::SaccDataType, typename Problem::SaccDataType,
Problem::kNumGemm0Warps * get_warp_size(), Problem::kNumGemm0Warps * get_warp_size(),
TileGemmShape< TileGemmShape<
sequence<Problem::BlockFmhaShape::kM0, Problem::BlockFmhaShape::kN0, BlockGemmK>, sequence<Problem::BlockFmhaShape::kM0, Problem::BlockFmhaShape::kN0,
typename Problem::BlockFmhaShape::Gemm0BlockWarps, BlockGemmK>, typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename
typename Problem::BlockFmhaShape::Gemm0WarpTile>>; Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr auto warp_gemm = []() { constexpr auto warp_gemm = []() {
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); constexpr index_t WarpGemmM =
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); static_assert(WarpGemmM == 4 ||
WarpGemmM == 16 || WarpGemmM == 32);
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> && if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> && std::is_same_v<typename Problem::KDataType, half_t> &&
...@@ -118,14 +120,15 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy ...@@ -118,14 +120,15 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::QDataType, BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::QDataType,
typename Problem::KDataType, typename Problem::KDataType,
typename Problem::SaccDataType, typename Problem::SaccDataType,
typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename
decltype(warp_gemm)>; Problem::BlockFmhaShape::Gemm0BlockWarps, decltype(warp_gemm)>;
if constexpr(1 < Problem::kNumGemm0Warps) if constexpr(1 < Problem::kNumGemm0Warps)
return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{}; return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
else else
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{}; return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
} }
*/
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ() CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment