Commit ceea2722 authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Move code interleaving

parent 76871a6f
...@@ -324,46 +324,33 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -324,46 +324,33 @@ struct BlockFmhaPipelineQRKSVSAsync
{ {
if(i_total_loops == 0) // executed by fist iteration if(i_total_loops == 0) // executed by fist iteration
{ {
auto k_lds_window_tmp = if(i_total_loops < num_total_loop)
get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kN0, kK0>{}); {
store_tile(k_lds_window_tmp, k_tiles[I0]); auto k_lds_window_tmp =
get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kN0, kK0>{});
clear_tile(s_acc); // initialize C store_tile(k_lds_window_tmp, k_tiles[I0]);
__builtin_amdgcn_sched_barrier(0); clear_tile(s_acc); // initialize C
static_for<0, k0_loops, 1>{}([&](auto i_k0) { static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
if constexpr(i_k0 < k0_loops - 1)
{
k_tiles[number<i_k0 + 1>{}] = load_tile(k_dram_window); k_tiles[number<i_k0 + 1>{}] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {0, kK0}); move_tile_window(k_dram_window, {0, kK0});
};
__builtin_amdgcn_sched_barrier(0);
block_sync_lds(); block_sync_lds();
// execute current unroll of gemm_0 // execute current unroll of gemm_0
gemm_0(s_acc, gemm_0(s_acc,
get_slice_tile( get_slice_tile(
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}), q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
k_lds_window_tmp); k_lds_window_tmp);
if constexpr(i_k0 < k0_loops - 1)
{
k_lds_window_tmp = get_slice_tile( k_lds_window_tmp = get_slice_tile(
k_lds_window, k_lds_window,
sequence<((i_k0 + 1) % NumKLdsBuffers) * kN0, 0>{}, sequence<((i_k0 + 1) % NumKLdsBuffers) * kN0, 0>{},
sequence<(((i_k0 + 1) % NumKLdsBuffers) + 1) * kN0, kK0>{}); sequence<(((i_k0 + 1) % NumKLdsBuffers) + 1) * kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tiles[number<i_k0 + 1>{}]); store_tile(k_lds_window_tmp, k_tiles[number<i_k0 + 1>{}]);
}; });
});
move_tile_window(k_dram_window, {0, -k0_loops * kK0});
// executed if the first iteration is not the last iteration move_tile_window(k_dram_window, {kN0, -k0_loops * kK0});
if(i_total_loops < num_total_loop - 1)
{
move_tile_window(k_dram_window, {kN0, 0});
static_for<0, k0_loops, 1>{}([&](auto i_k0) { static_for<0, k0_loops, 1>{}([&](auto i_k0) {
k_tiles[number<i_k0>{}] = load_tile(k_dram_window); k_tiles[number<i_k0>{}] = load_tile(k_dram_window);
...@@ -373,11 +360,53 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -373,11 +360,53 @@ struct BlockFmhaPipelineQRKSVSAsync
}); });
move_tile_window(k_dram_window, {0, -(k0_loops - 1) * kK0}); move_tile_window(k_dram_window, {0, -(k0_loops - 1) * kK0});
block_sync_lds();
// execute last unroll of gemm_0
gemm_0(s_acc,
get_slice_tile(q,
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kM0, k0_loops * kK0>{}),
k_lds_window_tmp);
}
else
{
auto k_lds_window_tmp =
get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tiles[I0]);
clear_tile(s_acc); // initialize C
static_for<0, k0_loops, 1>{}([&](auto i_k0) {
if constexpr(i_k0 < k0_loops - 1)
{
k_tiles[number<i_k0 + 1>{}] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {0, kK0});
};
block_sync_lds();
// execute current unroll of gemm_0
gemm_0(s_acc,
get_slice_tile(
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
k_lds_window_tmp);
if constexpr(i_k0 < k0_loops - 1)
{
k_lds_window_tmp = get_slice_tile(
k_lds_window,
sequence<((i_k0 + 1) % NumKLdsBuffers) * kN0, 0>{},
sequence<(((i_k0 + 1) % NumKLdsBuffers) + 1) * kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tiles[number<i_k0 + 1>{}]);
};
});
// move_tile_window(k_dram_window, {0, -k0_loops * kK0});
} }
} }
else // executed by intermediate and last iteration else // executed by intermediate and last iteration
{ {
if(i_total_loops < num_total_loop - 1) if(i_total_loops < num_total_loop - 1) // intermediate iteration
{ {
move_tile_window(k_dram_window, {kN0, 0}); move_tile_window(k_dram_window, {kN0, 0});
...@@ -388,13 +417,13 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -388,13 +417,13 @@ struct BlockFmhaPipelineQRKSVSAsync
sequence<((i_k0 % NumKLdsBuffers) + 1) * kN0, kK0>{}); sequence<((i_k0 % NumKLdsBuffers) + 1) * kN0, kK0>{});
store_tile(k_lds_window_tmp, k_tiles[number<i_k0>{}]); store_tile(k_lds_window_tmp, k_tiles[number<i_k0>{}]);
if constexpr(i_k0 == 0)
clear_tile(s_acc);
k_tiles[number<i_k0>{}] = load_tile(k_dram_window); k_tiles[number<i_k0>{}] = load_tile(k_dram_window);
if constexpr(i_k0 < k0_loops - 1) if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0}); move_tile_window(k_dram_window, {0, kK0});
if constexpr(i_k0 == 0)
clear_tile(s_acc);
block_sync_lds(); block_sync_lds();
// execute last unroll of gemm_0 // execute last unroll of gemm_0
gemm_0(s_acc, gemm_0(s_acc,
...@@ -405,7 +434,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -405,7 +434,7 @@ struct BlockFmhaPipelineQRKSVSAsync
move_tile_window(k_dram_window, {0, -(k0_loops - 1) * kK0}); move_tile_window(k_dram_window, {0, -(k0_loops - 1) * kK0});
} }
else else // last iteration
{ {
static_for<0, k0_loops, 1>{}([&](auto i_k0) { static_for<0, k0_loops, 1>{}([&](auto i_k0) {
auto k_lds_window_tmp = auto k_lds_window_tmp =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment