"...composable_kernel.git" did not exist on "de37550f728ea27c683be3f367547db80cba68a8"
Commit 475c0d2c authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Use array of tiles to represent Q in vgprs

parent 119dd2ac
...@@ -179,7 +179,14 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -179,7 +179,14 @@ struct BlockFmhaPipelineQRKSVSAsync
q_dram_block_window_tmp.get_window_lengths(), q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(), q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQRegTileDistribution<Problem>()); Policy::template MakeQRegTileDistribution<Problem>());
auto q = load_tile(q_dram_window); using q_tile_type = decltype(load_tile(q_dram_window));
statically_indexed_array<q_tile_type, k0_loops> q_tiles;
static_for<0, k0_loops, 1>{}([&](auto i_k0) {
q_tiles[number<i_k0>{}] = load_tile(q_dram_window);
move_tile_window(q_dram_window, {0, kK0});
});
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
...@@ -308,10 +315,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -308,10 +315,7 @@ struct BlockFmhaPipelineQRKSVSAsync
block_sync_lds(); block_sync_lds();
// execute current unroll of gemm_0 // execute current unroll of gemm_0
gemm_0(s_acc, gemm_0(s_acc, q_tiles[number<i_k0>{}], k_lds_window_tmp);
get_slice_tile(
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
k_lds_window_tmp);
k_lds_window_tmp = get_slice_tile( k_lds_window_tmp = get_slice_tile(
k_lds_window, k_lds_window,
...@@ -333,11 +337,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -333,11 +337,7 @@ struct BlockFmhaPipelineQRKSVSAsync
block_sync_lds(); block_sync_lds();
// execute last unroll of gemm_0 // execute last unroll of gemm_0
gemm_0(s_acc, gemm_0(s_acc, q_tiles[number<k0_loops - 1>{}], k_lds_window_tmp);
get_slice_tile(q,
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kM0, k0_loops * kK0>{}),
k_lds_window_tmp);
} }
else // there is only single iteration else // there is only single iteration
{ {
...@@ -356,10 +356,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -356,10 +356,7 @@ struct BlockFmhaPipelineQRKSVSAsync
block_sync_lds(); block_sync_lds();
// execute current unroll of gemm_0 // execute current unroll of gemm_0
gemm_0(s_acc, gemm_0(s_acc, q_tiles[number<i_k0>{}], k_lds_window_tmp);
get_slice_tile(
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
k_lds_window_tmp);
if constexpr(i_k0 < k0_loops - 1) if constexpr(i_k0 < k0_loops - 1)
{ {
...@@ -396,10 +393,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -396,10 +393,7 @@ struct BlockFmhaPipelineQRKSVSAsync
block_sync_lds(); block_sync_lds();
// execute last unroll of gemm_0 // execute last unroll of gemm_0
gemm_0(s_acc, gemm_0(s_acc, q_tiles[number<i_k0>{}], k_lds_window_tmp);
get_slice_tile(
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
k_lds_window_tmp);
}); });
move_tile_window(k_dram_window, {0, -(k0_loops - 1) * kK0}); move_tile_window(k_dram_window, {0, -(k0_loops - 1) * kK0});
...@@ -418,10 +412,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -418,10 +412,7 @@ struct BlockFmhaPipelineQRKSVSAsync
block_sync_lds(); block_sync_lds();
// execute last unroll of gemm_0 // execute last unroll of gemm_0
gemm_0(s_acc, gemm_0(s_acc, q_tiles[number<i_k0>{}], k_lds_window_tmp);
get_slice_tile(
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
k_lds_window_tmp);
}); });
}; };
}; };
......
...@@ -13,6 +13,15 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy ...@@ -13,6 +13,15 @@ struct BlockFmhaPipelineQRKSVSAsyncDefaultPolicy
/* AsyncCopy = */ true, /* AsyncCopy = */ true,
/* NumPrefetchV = */ 2> /* NumPrefetchV = */ 2>
{ {
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
{
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
return BlockGemm::template MakeABlockTileDistribution<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kK0>();
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment