Commit 475c0d2c authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Use array of tiles to represent Q in vgprs

parent 119dd2ac
......@@ -179,7 +179,14 @@ struct BlockFmhaPipelineQRKSVSAsync
q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQRegTileDistribution<Problem>());
auto q = load_tile(q_dram_window);
using q_tile_type = decltype(load_tile(q_dram_window));
statically_indexed_array<q_tile_type, k0_loops> q_tiles;
static_for<0, k0_loops, 1>{}([&](auto i_k0) {
q_tiles[number<i_k0>{}] = load_tile(q_dram_window);
move_tile_window(q_dram_window, {0, kK0});
});
__builtin_amdgcn_sched_barrier(0);
......@@ -308,10 +315,7 @@ struct BlockFmhaPipelineQRKSVSAsync
block_sync_lds();
// execute current unroll of gemm_0
gemm_0(s_acc,
get_slice_tile(
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
k_lds_window_tmp);
gemm_0(s_acc, q_tiles[number<i_k0>{}], k_lds_window_tmp);
k_lds_window_tmp = get_slice_tile(
k_lds_window,
......@@ -333,11 +337,7 @@ struct BlockFmhaPipelineQRKSVSAsync
block_sync_lds();
// execute last unroll of gemm_0
gemm_0(s_acc,
get_slice_tile(q,
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kM0, k0_loops * kK0>{}),
k_lds_window_tmp);
gemm_0(s_acc, q_tiles[number<k0_loops - 1>{}], k_lds_window_tmp);
}
else // there is only single iteration
{
......@@ -356,10 +356,7 @@ struct BlockFmhaPipelineQRKSVSAsync
block_sync_lds();
// execute current unroll of gemm_0
gemm_0(s_acc,
get_slice_tile(
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
k_lds_window_tmp);
gemm_0(s_acc, q_tiles[number<i_k0>{}], k_lds_window_tmp);
if constexpr(i_k0 < k0_loops - 1)
{
......@@ -396,10 +393,7 @@ struct BlockFmhaPipelineQRKSVSAsync
block_sync_lds();
// execute last unroll of gemm_0
gemm_0(s_acc,
get_slice_tile(
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
k_lds_window_tmp);
gemm_0(s_acc, q_tiles[number<i_k0>{}], k_lds_window_tmp);
});
move_tile_window(k_dram_window, {0, -(k0_loops - 1) * kK0});
......@@ -418,10 +412,7 @@ struct BlockFmhaPipelineQRKSVSAsync
block_sync_lds();
// execute last unroll of gemm_0
gemm_0(s_acc,
get_slice_tile(
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
k_lds_window_tmp);
gemm_0(s_acc, q_tiles[number<i_k0>{}], k_lds_window_tmp);
});
};
};
......
......@@ -13,6 +13,15 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
/* AsyncCopy = */ true,
/* NumPrefetchV = */ 2>
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
{
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
return BlockGemm::template MakeABlockTileDistribution<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kK0>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment