Commit aa304c67 authored by Qianfeng Zhang
Browse files

Prefetch the first v_tile earlier in both the kPreloadNextWholeIterationK true and false paths

parent e472af36
...@@ -355,6 +355,10 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -355,6 +355,10 @@ struct BlockFmhaPipelineQRKSVSAsync
store_tile(k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}], store_tile(k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}],
k_tiles[number<k0_loops - 1>{}]); k_tiles[number<k0_loops - 1>{}]);
// prefetch first v_tile
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0}); move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0});
static_for<0, k0_loops, 1>{}([&](auto i_k0) { static_for<0, k0_loops, 1>{}([&](auto i_k0) {
...@@ -395,6 +399,10 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -395,6 +399,10 @@ struct BlockFmhaPipelineQRKSVSAsync
store_tile(k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}], store_tile(k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}],
k_tiles[number<k0_loops - 1>{}]); k_tiles[number<k0_loops - 1>{}]);
// prefetch first v_tile
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
block_sync_lds(); block_sync_lds();
gemm_0(s_acc, gemm_0(s_acc,
q_tiles[number<k0_loops - 1>{}], q_tiles[number<k0_loops - 1>{}],
...@@ -409,7 +417,20 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -409,7 +417,20 @@ struct BlockFmhaPipelineQRKSVSAsync
{ {
move_tile_window(k_dram_window, {kN0, 0}); move_tile_window(k_dram_window, {kN0, 0});
static_for<0, k0_loops, 1>{}([&](auto i_k0) { store_tile(k_lds_windows[I0], k_tiles[I0]);
// prefetch first v_tile
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
k_tiles[I0] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {0, kK0});
clear_tile(s_acc);
block_sync_lds();
gemm_0(s_acc, q_tiles[I0], k_lds_windows[I0]);
static_for<1, k0_loops, 1>{}([&](auto i_k0) {
store_tile(k_lds_windows[number<i_k0 % NumKLdsBuffers>{}], store_tile(k_lds_windows[number<i_k0 % NumKLdsBuffers>{}],
k_tiles[number<i_k0>{}]); k_tiles[number<i_k0>{}]);
...@@ -417,9 +438,6 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -417,9 +438,6 @@ struct BlockFmhaPipelineQRKSVSAsync
if constexpr(i_k0 < k0_loops - 1) if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0}); move_tile_window(k_dram_window, {0, kK0});
if constexpr(i_k0 == 0)
clear_tile(s_acc);
block_sync_lds(); block_sync_lds();
gemm_0(s_acc, gemm_0(s_acc,
q_tiles[number<i_k0>{}], q_tiles[number<i_k0>{}],
...@@ -430,6 +448,10 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -430,6 +448,10 @@ struct BlockFmhaPipelineQRKSVSAsync
} }
else // last iteration else // last iteration
{ {
// prefetch first v_tile
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
static_for<0, k0_loops, 1>{}([&](auto i_k0) { static_for<0, k0_loops, 1>{}([&](auto i_k0) {
store_tile(k_lds_windows[number<i_k0 % NumKLdsBuffers>{}], store_tile(k_lds_windows[number<i_k0 % NumKLdsBuffers>{}],
k_tiles[number<i_k0>{}]); k_tiles[number<i_k0>{}]);
...@@ -447,7 +469,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -447,7 +469,7 @@ struct BlockFmhaPipelineQRKSVSAsync
} }
else // only preload one unroll of K for next iteration else // only preload one unroll of K for next iteration
{ {
static_for<0, k0_loops, 1>{}([&](auto i_k0) { static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
store_tile(k_lds_windows[number<i_k0 % NumKLdsBuffers>{}], k_tiles[I0]); store_tile(k_lds_windows[number<i_k0 % NumKLdsBuffers>{}], k_tiles[I0]);
if constexpr(i_k0 == 0) if constexpr(i_k0 == 0)
clear_tile(s_acc); clear_tile(s_acc);
...@@ -463,13 +485,24 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -463,13 +485,24 @@ struct BlockFmhaPipelineQRKSVSAsync
q_tiles[number<i_k0>{}], q_tiles[number<i_k0>{}],
k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]); k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]);
}); });
store_tile(k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}], k_tiles[I0]);
// prefetch first v_tile
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
block_sync_lds();
gemm_0(s_acc,
q_tiles[number<k0_loops - 1>{}],
k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]);
}; };
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
const auto bias_tile = load_tile(bias_dram_window); // load bias tile const auto bias_tile = load_tile(bias_dram_window); // load bias tile
static_for<0, NumPrefetchV, 1>{}([&](auto i_buf) { static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) {
v_tiles[i_buf] = load_tile(v_dram_window); v_tiles[i_buf] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1}); move_tile_window(v_dram_window, {0, kK1});
}); });
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment