Commit aa304c67 authored by Qianfeng Zhang
Browse files

Prefetch first v_tile at earlier time for both kPreloadNextWholeIterationK true/false paths

parent e472af36
......@@ -355,6 +355,10 @@ struct BlockFmhaPipelineQRKSVSAsync
store_tile(k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}],
k_tiles[number<k0_loops - 1>{}]);
// prefetch first v_tile
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0});
static_for<0, k0_loops, 1>{}([&](auto i_k0) {
......@@ -395,6 +399,10 @@ struct BlockFmhaPipelineQRKSVSAsync
store_tile(k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}],
k_tiles[number<k0_loops - 1>{}]);
// prefetch first v_tile
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
block_sync_lds();
gemm_0(s_acc,
q_tiles[number<k0_loops - 1>{}],
......@@ -409,7 +417,20 @@ struct BlockFmhaPipelineQRKSVSAsync
{
move_tile_window(k_dram_window, {kN0, 0});
static_for<0, k0_loops, 1>{}([&](auto i_k0) {
store_tile(k_lds_windows[I0], k_tiles[I0]);
// prefetch first v_tile
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
k_tiles[I0] = load_tile(k_dram_window);
move_tile_window(k_dram_window, {0, kK0});
clear_tile(s_acc);
block_sync_lds();
gemm_0(s_acc, q_tiles[I0], k_lds_windows[I0]);
static_for<1, k0_loops, 1>{}([&](auto i_k0) {
store_tile(k_lds_windows[number<i_k0 % NumKLdsBuffers>{}],
k_tiles[number<i_k0>{}]);
......@@ -417,9 +438,6 @@ struct BlockFmhaPipelineQRKSVSAsync
if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0});
if constexpr(i_k0 == 0)
clear_tile(s_acc);
block_sync_lds();
gemm_0(s_acc,
q_tiles[number<i_k0>{}],
......@@ -430,6 +448,10 @@ struct BlockFmhaPipelineQRKSVSAsync
}
else // last iteration
{
// prefetch first v_tile
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
static_for<0, k0_loops, 1>{}([&](auto i_k0) {
store_tile(k_lds_windows[number<i_k0 % NumKLdsBuffers>{}],
k_tiles[number<i_k0>{}]);
......@@ -447,7 +469,7 @@ struct BlockFmhaPipelineQRKSVSAsync
}
else // only preload one unroll of K for next iteration
{
static_for<0, k0_loops, 1>{}([&](auto i_k0) {
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
store_tile(k_lds_windows[number<i_k0 % NumKLdsBuffers>{}], k_tiles[I0]);
if constexpr(i_k0 == 0)
clear_tile(s_acc);
......@@ -463,13 +485,24 @@ struct BlockFmhaPipelineQRKSVSAsync
q_tiles[number<i_k0>{}],
k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]);
});
store_tile(k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}], k_tiles[I0]);
// prefetch first v_tile
v_tiles[I0] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
block_sync_lds();
gemm_0(s_acc,
q_tiles[number<k0_loops - 1>{}],
k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]);
};
__builtin_amdgcn_sched_barrier(0);
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
static_for<0, NumPrefetchV, 1>{}([&](auto i_buf) {
static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) {
v_tiles[i_buf] = load_tile(v_dram_window);
move_tile_window(v_dram_window, {0, kK1});
});
......
Markdown is supported
Attach a file by dragging &amp; dropping, selecting, or pasting it.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment